使用Java客户端
简介
Java应用可以直接访问TensorFlow serving加载模型提供的服务,我们需要编写Java的gRPC客户端代码。
完整例子
这里有一个导出模型使用Java来访问模型的例子 https://github.com/tobegit3hub/deep_recommend_system/tree/master/java_predict_client 。
使用时通过Maven编译即可,不同模型只需要修改一个Java文件,其他外部依赖已经管理好,建议在此项目中修改使用。
Java客户端实现原理
Java无论是服务端还是客户端都是在独立于grpc的项目中实现,代码在 https://github.com/grpc/grpc-java 。使用时需要引入grpc实现的类,建议使用maven管理依赖,在pom.xml中加入下面的依赖。
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
<version>1.0.0</version>
</dependency>
由于使用grpc还需要用到protobuf生成的Java代码,如果通过命令生成再拷贝jar文件不好管理,可以使用maven插件,把proto文件拷贝到指定目录,在编译时就会自动生成java文件放到target目录。
<build>
<extensions>
<extension>
<groupId>kr.motd.maven</groupId>
<artifactId>os-maven-plugin</artifactId>
<version>1.4.1.Final</version>
</extension>
</extensions>
<plugins>
<plugin>
<groupId>org.xolstice.maven.plugins</groupId>
<artifactId>protobuf-maven-plugin</artifactId>
<version>0.5.0</version>
<configuration>
<!--
The version of protoc must match protobuf-java. If you don't depend on
protobuf-java directly, you will be transitively depending on the
protobuf-java version that grpc depends on.
-->
<protocArtifact>com.google.protobuf:protoc:3.0.0:exe:${os.detected.classifier}</protocArtifact>
<pluginId>grpc-java</pluginId>
<pluginArtifact>io.grpc:protoc-gen-grpc-java:1.0.0:exe:${os.detected.classifier}</pluginArtifact>
</configuration>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>compile-custom</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
注意我们需要加入TensorFlow serving和TensorFlow项目的proto文件,由于我们不使用bazel编译,因此proto文件的依赖路径需要修改,建议参考上面的完整项目。
构造TensorProto对象
使用protobuf定义了请求的接口,但我们还需要构建protobuf生成代码中的TensorProto对象,本质上是一个多维数据,在C++和Python中都有函数可以直接生成。
Java可以定义多维数据,然后参考这个Stackoverflow答案来构建 http://stackoverflow.com/questions/39443019/how-can-i-create-tensorproto-for-tensorflow-in-java ,下面是一个构建二位TensorProto的代码。
// Generate features TensorProto
float[][] featuresTensorData = new float[][]{
{10f, 10f, 10f, 8f, 6f, 1f, 8f, 9f, 1f},
{10f, 10f, 10f, 8f, 6f, 1f, 8f, 9f, 1f},
};
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();
featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT);
for (int i = 0; i < featuresTensorData.length; ++i) {
for (int j = 0; j < featuresTensorData[i].length; ++j) {
featuresTensorBuilder.addFloatVal(featuresTensorData[i][j]);
}
}
TensorShapeProto.Dim dim1 = TensorShapeProto.Dim.newBuilder().setSize(2).build();
TensorShapeProto.Dim dim2 = TensorShapeProto.Dim.newBuilder().setSize(9).build();
TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(dim1).addDim(dim2).build();
featuresTensorBuilder.setTensorShape(shape);
TensorProto featuresTensorProto = featuresTensorBuilder.build();
注意除了设置data,shape和dtype都需要我们手动设置,否则服务端无法解析TensorProto成tensor对象。
读取图片文件生成TensorProto
在图像分类等场景中,我们需要读取图片文件生成TensorProto对象,才可以通过gRPC请求TensorFlow serving服务,这里提供一个Java例子,测试支持jpg和png图片格式。
这里有完整的使用CNN训练模型和inference的例子,Java客户端可以直接读取本地文件来请求服务进行预测和分类 https://github.com/tobegit3hub/deep_cnn/tree/master/java_predict_client 。
// Generate image file to array
int[][][][] featuresTensorData = new int[2][32][32][3];
String[] imageFilenames = new String[]{"../data/inference/Mew.png", "../data/inference/Pikachu.png"};
for (int i = 0; i < imageFilenames.length; i++) {
// Convert image file to multi-dimension array
File imageFile = new File(imageFilenames[i]);
try {
BufferedImage image = ImageIO.read(imageFile);
logger.info("Start to convert the image: " + imageFile.getPath());
int imageWidth = 32;
int imageHeight = 32;
int[][] imageArray = new int[imageHeight][imageWidth];
for (int row = 0; row < imageHeight; row++) {
for (int column = 0; column < imageWidth; column++) {
imageArray[row][column] = image.getRGB(column, row);
int pixel = image.getRGB(column, row);
int red = (pixel >> 16) & 0xff;
int green = (pixel >> 8) & 0xff;
int blue = pixel & 0xff;
featuresTensorData[i][row][column][0] = red;
featuresTensorData[i][row][column][1] = green;
featuresTensorData[i][row][column][2] = blue;
}
}
} catch (IOException e) {
logger.log(Level.WARNING, e.getMessage());
System.exit(1);
}
}
// Generate features TensorProto
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();
for (int i = 0; i < featuresTensorData.length; ++i) {
for (int j = 0; j < featuresTensorData[i].length; ++j) {
for (int k = 0; k < featuresTensorData[i][j].length; ++k) {
for (int l = 0; l < featuresTensorData[i][j][k].length; ++l) {
featuresTensorBuilder.addFloatVal(featuresTensorData[i][j][k][l]);
}
}
}
}
TensorShapeProto.Dim featuresDim1 = TensorShapeProto.Dim.newBuilder().setSize(2).build();
TensorShapeProto.Dim featuresDim2 = TensorShapeProto.Dim.newBuilder().setSize(32).build();
TensorShapeProto.Dim featuresDim3 = TensorShapeProto.Dim.newBuilder().setSize(32).build();
TensorShapeProto.Dim featuresDim4 = TensorShapeProto.Dim.newBuilder().setSize(3).build();
TensorShapeProto featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).addDim(featuresDim2).addDim(featuresDim3).addDim(featuresDim4).build();
featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT).setTensorShape(featuresShape);
TensorProto featuresTensorProto = featuresTensorBuilder.build();
原文: http://docs.api.xiaomi.com/cloud-ml/modelservice/0903_use_java_client.html