Background
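My original goal was to port a model to Android. When that ran into a problem, I built a much simpler model to check, and it fails in exactly the same way.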
Python training code:

```python
import numpy as np
from tensorflow import keras

# A single Dense unit, i.e. the model learns y = w * x + b
model = keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')

xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=np.float32)
ys = np.array([-3.0, -1.0, 0.0, 3.0, 5.0, 7.0], dtype=np.float32)

model.fit(xs, ys, epochs=500)

keras_file = 'linear.h5'
keras.models.save_model(model, keras_file)
```
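As a sanity check on what the Android side should eventually return, the line this model learns can be computed in closed form. Minimizing mean squared error with a single Dense unit is ordinary least squares, so the plain-Python sketch below (no TensorFlow; `fit_line` is a hypothetical helper, not a Keras function) gives the values SGD should approach after 500 epochs. The trained model's prediction for x = 10 should land close to the printed value, though SGD will not match it exactly.

```python
# Closed-form least-squares fit of y = w * x + b on the same six training points.
xs = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-3.0, -1.0, 0.0, 3.0, 5.0, 7.0]


def fit_line(xs, ys):
    """Return (w, b) minimizing the mean squared error of w * x + b."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Covariance and variance sums around the means
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    w = sxy / sxx
    b = mean_y - w * mean_x
    return w, b


w, b = fit_line(xs, ys)
print(f"w = {w:.4f}, b = {b:.4f}")                 # w = 2.0286, b = -1.2095
print(f"prediction for x = 10: {w * 10 + b:.4f}")  # prediction for x = 10: 19.0762
```

So if inference on Android works, feeding `10` should give an output of roughly 19.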
After converting it to `.tflite`, I run it on Android with `Interpreter`:

```java
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

Interpreter interpreter = new Interpreter(FileUtil.loadMappedFile(activity, "linear.tflite"));
interpreter.allocateTensors();

// Output tensor: shape and data type are read from the model
int probabilityTensorIndex = 0;
int[] probabilityShape = interpreter.getOutputTensor(probabilityTensorIndex).shape();
DataType probabilityDataType = interpreter.getOutputTensor(probabilityTensorIndex).dataType();
TensorBuffer outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);

// Input tensor: same idea, then fill it with the value 10
int inputTensorIndex = 0;
DataType inputDataType = interpreter.getInputTensor(inputTensorIndex).dataType();
int[] inputShape = interpreter.getInputTensor(inputTensorIndex).shape();
TensorBuffer inputBuffer = TensorBuffer.createFixedSize(inputShape, inputDataType);
final float[] input = {10};
inputBuffer.loadArray(input);

interpreter.run(inputBuffer, outputProbabilityBuffer);
```
The error:

```
I/tflite: Initialized TensorFlow Lite runtime.
E/AndroidRuntime: FATAL EXCEPTION: inference
    Process: com.example.my1application, PID: 26839
    java.lang.IllegalArgumentException: DataType error: cannot resolve DataType of org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat
        at org.tensorflow.lite.Tensor.dataTypeOf(Tensor.java:344)
        at org.tensorflow.lite.Tensor.throwIfTypeIsIncompatible(Tensor.java:397)
        at org.tensorflow.lite.Tensor.getInputShapeIfDifferent(Tensor.java:287)
        at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:137)
        at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:316)
        at org.tensorflow.lite.Interpreter.run(Interpreter.java:277)
        at com.example.my1application.DisplayMessageActivity$1.run(DisplayMessageActivity.java:114)
        at android.os.Handler.handleCallback(Handler.java:815)
        at android.os.Handler.dispatchMessage(Handler.java:104)
        at android.os.Looper.loop(Looper.java:207)
        at android.os.HandlerThread.run(HandlerThread.java:61)
```
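The exception says `Interpreter.run` cannot work out a data type for a `TensorBufferFloat` object: `run` accepts Java arrays (including multidimensional ones) or `ByteBuffer`s, not the support library's `TensorBuffer` wrapper itself. A likely fix, following the pattern used in the TensorFlow Lite support library examples, is to pass the underlying buffers, `interpreter.run(inputBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());`, and then read the result with `outputProbabilityBuffer.getFloatArray()`. What gets handed over that way is raw float32 data in native byte order, 4 bytes per element. The Python sketch below only illustrates that layout; the hypothetical `float32_bytes` helper and the `[1, 1]` input shape (a batch of one with one feature) are assumptions, not values read from the model.

```python
import struct


def float32_bytes(values, shape):
    """Pack values as contiguous float32 in native byte order for a given tensor shape."""
    # Element count implied by the shape, e.g. [1, 1] -> 1
    count = 1
    for dim in shape:
        count *= dim
    if count != len(values):
        raise ValueError(f"shape {shape} needs {count} values, got {len(values)}")
    # '=' selects native byte order with standard sizes; 'f' is a 4-byte float32
    return struct.pack(f"={count}f", *values)


raw = float32_bytes([10.0], [1, 1])
print(len(raw))                  # 4
print(struct.unpack("=f", raw))  # (10.0,)
```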