diff --git a/.gitignore b/.gitignore index 0b1f0bf..7f8f4ce 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ tmp/ .project dist/ .DS_Store +models/ +outputs/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e59eb58 --- /dev/null +++ b/README.md @@ -0,0 +1,68 @@ +# TensorFlow Java Examples + +This repository contains examples for [TensorFlow-Java](https://github.com/tensorflow/java). + +## Example Models + +There are five example models: a LeNet CNN, a VGG CNN, inference using Faster-RCNN, a linear regression and a logistic regression. + +### Faster-RCNN + +The Faster-RCNN inference example is in `org.tensorflow.model.examples.cnn.fastrcnn`. + +Download the model from https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_1024x1024/1 + +Unzip then untar the model to a local folder - I've used models/faster_rcnn_inception_resnet_v2_1024x1024. + +Create a testimages folder then add some test images into a testimages folder + +To run the example add the input image and output image as parameters: + +```shell +java -cp target/tensorflow-examples-1.0.0-tfj-1.0.0-rc.2-jar-with-dependencies.jar org.tensorflow.model.examples.cnn.fastrcnn.FasterRcnnInception testimages/image2.jpg image2rcnn.jpg +``` + +### LeNet CNN + +The LeNet example runs on MNIST which is stored in the project's resource directory. It is found in +`org.tensorflow.model.examples.cnn.lenet`, and can be run with: + +```shell +java -cp target/tensorflow-examples-1.0.0-tfj-1.0.0-rc.2-with-dependencies.jar org.tensorflow.model.examples.cnn.lenet.CnnMnist +``` + +### VGG + +The VGG11 example runs on FashionMNIST, stored in the project's resource directory. It is found in +`org.tensorflow.model.examples.cnn.vgg`, and can be run with: + +```shell +java -cp target/tensorflow-examples-1.0.0-tfj-1.0.0-rc.2-with-dependencies.jar org.tensorflow.model.examples.cnn.vgg.VGG11OnFashionMnist +``` + +### Linear Regression + +The linear regression example runs on hard coded data. It is found in `org.tensorflow.model.examples.regression.linear` +and can be run with: + +```shell +java -cp target/tensorflow-examples-1.0.0-tfj-1.0.0-rc.2-with-dependencies.jar org.tensorflow.model.examples.regression.linear.LinearRegressionExample +``` + +### Logistic Regression + +The logistic regression example runs on MNIST, stored in the project's resource directory. It is found in +`org.tensorflow.model.examples.dense.SimpleMnist`, and can be run with: + +```shell +java -cp target/tensorflow-examples-1.0.0-tfj-1.0.0-rc.2-with-dependencies.jar org.tensorflow.model.examples.dense.SimpleMnist +``` + +## Contributions + +Contributions of other example models are welcome, for instructions please see the +[Contributor guidelines](https://github.com/tensorflow/java/blob/master/CONTRIBUTING.md) in TensorFlow-Java. + +## Development + +This repository tracks TensorFlow-Java and the head will be updated with new releases of TensorFlow-Java. diff --git a/tensorflow-examples/pom.xml b/pom.xml similarity index 64% rename from tensorflow-examples/pom.xml rename to pom.xml index 46a2488..08df035 100644 --- a/tensorflow-examples/pom.xml +++ b/pom.xml @@ -1,8 +1,13 @@ - + + 4.0.0 - org.tensorflow.model + + org.tensorflow tensorflow-examples - 0.1.0-SNAPSHOT + 1.0.0-tfj-1.0.0-rc.2 TensorFlow Examples A suite of executable examples using TensorFlow Java @@ -10,21 +15,25 @@ - 1.8 - 1.8 + 17 + 17 + 17 + 1.0.0-rc.2 + org.tensorflow tensorflow-core-platform - 0.1.0-SNAPSHOT + ${tensorflow.version} org.tensorflow - tensorflow-training - 0.1.0-SNAPSHOT + tensorflow-framework + ${tensorflow.version} + @@ -40,7 +49,7 @@ - org.tensorflow.model.examples.mnist.SimpleMnist + org.tensorflow.model.examples.dense.SimpleMnist diff --git a/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java b/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java new file mode 100644 index 0000000..9e2a4c9 --- /dev/null +++ b/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java @@ -0,0 +1,359 @@ +/* + * Copyright 2021, 2024 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.model.examples.cnn.fastrcnn; +/* + +From the web page this is the output dictionary + +num_detections: a tf.int tensor with only one value, the number of detections [N]. +detection_boxes: a tf.float32 tensor of shape [N, 4] containing bounding box coordinates in the following order: [ymin, xmin, ymax, xmax]. +detection_classes: a tf.int tensor of shape [N] containing detection class index from the label file. +detection_scores: a tf.float32 tensor of shape [N] containing detection scores. +raw_detection_boxes: a tf.float32 tensor of shape [1, M, 4] containing decoded detection boxes without Non-Max suppression. M is the number of raw detections. +raw_detection_scores: a tf.float32 tensor of shape [1, M, 90] and contains class score logits for raw detection boxes. M is the number of raw detections. +detection_anchor_indices: a tf.float32 tensor of shape [N] and contains the anchor indices of the detections after NMS. +detection_multiclass_scores: a tf.float32 tensor of shape [1, N, 90] and contains class score distribution (including background) for detection boxes in the image including background class. + +However using +venv\Scripts\python.exe venv\Lib\site-packages\tensorflow\python\tools\saved_model_cli.py show --dir models\faster_rcnn_inception_resnet_v2_1024x1024 --all +2021-03-19 12:25:37.000143: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library cudart64_110.dll + +MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs: + +signature_def['__saved_model_init_op']: + The given SavedModel SignatureDef contains the following input(s): + The given SavedModel SignatureDef contains the following output(s): + outputs['__saved_model_init_op'] tensor_info: + dtype: DT_INVALID + shape: unknown_rank + name: NoOp + Method name is: + +signature_def['serving_default']: + The given SavedModel SignatureDef contains the following input(s): + inputs['input_tensor'] tensor_info: + dtype: DT_UINT8 + shape: (1, -1, -1, 3) + name: serving_default_input_tensor:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['detection_anchor_indices'] tensor_info: + dtype: DT_FLOAT + shape: (1, 300) + name: StatefulPartitionedCall:0 + outputs['detection_boxes'] tensor_info: + dtype: DT_FLOAT + shape: (1, 300, 4) + name: StatefulPartitionedCall:1 + outputs['detection_classes'] tensor_info: + dtype: DT_FLOAT + shape: (1, 300) + name: StatefulPartitionedCall:2 + outputs['detection_multiclass_scores'] tensor_info: + dtype: DT_FLOAT + shape: (1, 300, 91) + name: StatefulPartitionedCall:3 + outputs['detection_scores'] tensor_info: + dtype: DT_FLOAT + shape: (1, 300) + name: StatefulPartitionedCall:4 + outputs['num_detections'] tensor_info: + dtype: DT_FLOAT + shape: (1) + name: StatefulPartitionedCall:5 + outputs['raw_detection_boxes'] tensor_info: + dtype: DT_FLOAT + shape: (1, 300, 4) + name: StatefulPartitionedCall:6 + outputs['raw_detection_scores'] tensor_info: + dtype: DT_FLOAT + shape: (1, 300, 91) + name: StatefulPartitionedCall:7 + Method name is: tensorflow/serving/predict + +Defined Functions: + Function Name: '__call__' + Option #1 + Callable with: + Argument #1 + input_tensor: TensorSpec(shape=(1, None, None, 3), dtype=tf.uint8, name='input_tensor') + +So it appears there's a discrepancy between the web page and running saved_model_cli as +num_detections: a tf.int tensor with only one value, the number of detections [N]. +but the actual tensor is DT_FLOAT according to saved_model_cli +also the web page states +detection_classes: a tf.int tensor of shape [N] containing detection class index from the label file. +but again the actual tensor is DT_FLOAT according to saved_model_cli. +*/ + + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Result; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.image.DecodeJpeg; +import org.tensorflow.op.image.EncodeJpeg; +import org.tensorflow.op.io.ReadFile; +import org.tensorflow.op.io.WriteFile; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; + + +/** + * Loads an image using ReadFile and DecodeJpeg and then uses the saved model + * faster_rcnn/inception_resnet_v2_1024x1024/1 to detect objects with a detection score greater than 0.3 + * Uses the DrawBounding boxes + */ +public class FasterRcnnInception { + + private final static String[] cocoLabels = new String[]{ + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "street sign", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "hat", + "backpack", + "umbrella", + "shoe", + "eye glasses", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "plate", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "mirror", + "dining table", + "window", + "desk", + "toilet", + "door", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "blender", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", + "hair brush" + }; + + public static void main(String[] params) { + String outputImagePath; + String imagePath; + + if (params.length == 0) { + imagePath = "src/main/resources/fasterrcnninception/image2.jpg"; + outputImagePath = "outputs/image2rcnn.jpg"; + + } else if (params.length == 2) { + imagePath = params[0]; + outputImagePath = params[1]; + + } else { + throw new IllegalArgumentException("Exactly 0 or 2 parameters required: java FasterRcnnInception [ ]"); + } + // get path to model folder + String modelPath = "models/faster_rcnn_inception_resnet_v2_1024x1024"; + // load saved model + SavedModelBundle model = SavedModelBundle.load(modelPath, "serve"); + //create a map of the COCO 2017 labels + TreeMap cocoTreeMap = new TreeMap<>(); + float cocoCount = 0; + for (String cocoLabel : cocoLabels) { + cocoTreeMap.put(cocoCount, cocoLabel); + cocoCount++; + } + try (Graph g = new Graph(); Session s = new Session(g)) { + Ops tf = Ops.create(g); + Constant fileName = tf.constant(imagePath); + ReadFile readFile = tf.io.readFile(fileName); + Session.Runner runner = s.runner(); + DecodeJpeg.Options options = DecodeJpeg.channels(3L); + DecodeJpeg decodeImage = tf.image.decodeJpeg(readFile.contents(), options); + //fetch image from file + Shape imageShape; + try (var shapeResult = runner.fetch(decodeImage).run()) { + imageShape = shapeResult.get(0).shape(); + } + //reshape the tensor to 4D for input to model + Reshape reshape = tf.reshape(decodeImage, + tf.array(1, + imageShape.asArray()[0], + imageShape.asArray()[1], + imageShape.asArray()[2] + ) + ); + try (var reshapeResult = s.runner().fetch(reshape).run()) { + TUint8 reshapeTensor = (TUint8) reshapeResult.get(0); + Map feedDict = new HashMap<>(); + //The given SavedModel SignatureDef input + feedDict.put("input_tensor", reshapeTensor); + //The given SavedModel MetaGraphDef key + //detection_classes, detectionBoxes etc. are model output names + try (Result outputTensorMap = model.function("serving_default").call(feedDict)) { + TFloat32 numDetections = (TFloat32) outputTensorMap.get("num_detections").get(); + int numDetects = (int) numDetections.getFloat(0); + if (numDetects > 0) { + TFloat32 detectionBoxes = (TFloat32) outputTensorMap.get("detection_boxes").get(); + TFloat32 detectionScores = (TFloat32) outputTensorMap.get("detection_scores").get(); + ArrayList boxArray = new ArrayList<>(); + //TODO tf.image.combinedNonMaxSuppression + for (int n = 0; n < numDetects; n++) { + //put probability and position in outputMap + float detectionScore = detectionScores.getFloat(0, n); + //only include those classes with detection score greater than 0.3f + if (detectionScore > 0.3f) { + boxArray.add(detectionBoxes.get(0, n)); + } + } + /* These values are also returned by the FasterRCNN, but we don't use them in this example. + * TFloat32 detectionClasses = (TFloat32) outputTensorMap.get("detection_classes").get(); + * TFloat32 rawDetectionBoxes = (TFloat32) outputTensorMap.get("raw_detection_boxes").get(); + * TFloat32 rawDetectionScores = (TFloat32) outputTensorMap.get("raw_detection_scores").get(); + * TFloat32 detectionAnchorIndices = (TFloat32) outputTensorMap.get("detection_anchor_indices").get(); + * TFloat32 detectionMulticlassScores = (TFloat32) outputTensorMap.get("detection_multiclass_scores").get(); + */ + //2-D. A list of RGBA colors to cycle through for the boxes. + Operand colors = tf.constant(new float[][]{ + {0.9f, 0.3f, 0.3f, 0.0f}, + {0.3f, 0.3f, 0.9f, 0.0f}, + {0.3f, 0.9f, 0.3f, 0.0f} + }); + Shape boxesShape = Shape.of(1, boxArray.size(), 4); + int boxCount = 0; + //3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding boxes + try (TFloat32 boxes = TFloat32.tensorOf(boxesShape)) { + //batch size of 1 + boxes.setFloat(1, 0, 0, 0); + for (FloatNdArray floatNdArray : boxArray) { + boxes.set(floatNdArray, 0, boxCount); + boxCount++; + } + //Placeholders for boxes and path to outputimage + Placeholder boxesPlaceHolder = tf.placeholder(TFloat32.class, Placeholder.shape(boxesShape)); + Placeholder outImagePathPlaceholder = tf.placeholder(TString.class); + //Create JPEG from the Tensor with quality of 100% + EncodeJpeg.Options jpgOptions = EncodeJpeg.quality(100L); + //convert the 4D input image to normalised 0.0f - 1.0f + //Draw bounding boxes using boxes tensor and list of colors + //multiply by 255 then reshape and recast to TUint8 3D tensor + WriteFile writeFile = tf.io.writeFile(outImagePathPlaceholder, + tf.image.encodeJpeg( + tf.dtypes.cast(tf.reshape( + tf.math.mul( + tf.image.drawBoundingBoxes(tf.math.div( + tf.dtypes.cast(tf.constant(reshapeTensor), + TFloat32.class), + tf.constant(255.0f) + ), + boxesPlaceHolder, colors), + tf.constant(255.0f) + ), + tf.array( + imageShape.asArray()[0], + imageShape.asArray()[1], + imageShape.asArray()[2] + ) + ), TUint8.class), + jpgOptions)); + //output the JPEG to file + s.runner().feed(outImagePathPlaceholder, TString.scalarOf(outputImagePath)) + .feed(boxesPlaceHolder, boxes) + .addTarget(writeFile).run(); + } + } + } + } + } + } +} diff --git a/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/Readme.md b/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/Readme.md new file mode 100644 index 0000000..8810c01 --- /dev/null +++ b/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/Readme.md @@ -0,0 +1,16 @@ +# FasterRcnnInception + +Download the model from https://www.kaggle.com/models/tensorflow/faster-rcnn-inception-resnet-v2/tensorFlow2/1024x1024/1 + +Unzip then untar the model to a local folder - I've used models/faster_rcnn_inception_resnet_v2_1024x1024. + +Create a testimages folder then add some test images into a testimages folder + +To run the example add the input image and output image as parameters: + +FasterRcnnInception testimages/image2.jpg image2rcnn.jpg + +### Example output +Using the image2.jpg image from https://github.com/tensorflow/models/tree/master/research/object_detection/test_images +![image2rcnn.jpg.](image2rcnn.jpg "Beach") + diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java b/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java similarity index 65% rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java rename to src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java index 7819d37..12a0053 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java +++ b/src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java @@ -1,19 +1,20 @@ /* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * Copyright 2020, 2024 The TensorFlow Authors. All Rights Reserved. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= */ -package org.tensorflow.model.examples.mnist; +package org.tensorflow.model.examples.cnn.lenet; import java.util.Arrays; import java.util.logging.Level; @@ -21,9 +22,20 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; -import org.tensorflow.model.examples.mnist.data.ImageBatch; -import org.tensorflow.model.examples.mnist.data.MnistDataset; +import org.tensorflow.framework.optimizers.AdaDelta; +import org.tensorflow.framework.optimizers.AdaGrad; +import org.tensorflow.framework.optimizers.AdaGradDA; +import org.tensorflow.framework.optimizers.Adam; +import org.tensorflow.framework.optimizers.GradientDescent; +import org.tensorflow.framework.optimizers.Momentum; +import org.tensorflow.framework.optimizers.Optimizer; +import org.tensorflow.framework.optimizers.RMSProp; +import org.tensorflow.model.examples.datasets.ImageBatch; +import org.tensorflow.model.examples.datasets.mnist.MnistDataset; +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Constant; @@ -39,18 +51,6 @@ import org.tensorflow.op.nn.Softmax; import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.op.random.TruncatedNormal; -import org.tensorflow.tools.Shape; -import org.tensorflow.tools.ndarray.ByteNdArray; -import org.tensorflow.tools.ndarray.FloatNdArray; -import org.tensorflow.tools.ndarray.index.Indices; -import org.tensorflow.training.optimizers.AdaDelta; -import org.tensorflow.training.optimizers.AdaGrad; -import org.tensorflow.training.optimizers.AdaGradDA; -import org.tensorflow.training.optimizers.Adam; -import org.tensorflow.training.optimizers.GradientDescent; -import org.tensorflow.training.optimizers.Momentum; -import org.tensorflow.training.optimizers.Optimizer; -import org.tensorflow.training.optimizers.RMSProp; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TUint8; @@ -74,7 +74,11 @@ public class CnnMnist { public static final String TARGET = "target"; public static final String TRAIN = "train"; public static final String TRAINING_LOSS = "training_loss"; - public static final String INIT = "init"; + + private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz"; + private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz"; + private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz"; + private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz"; public static Graph build(String optimizerName) { Graph graph = new Graph(); @@ -82,22 +86,22 @@ public static Graph build(String optimizerName) { Ops tf = Ops.create(graph); // Inputs - Placeholder input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE, + Placeholder input = tf.withName(INPUT_NAME).placeholder(TUint8.class, Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE))); Reshape input_reshaped = tf .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)); - Placeholder labels = tf.withName(TARGET).placeholder(TUint8.DTYPE); + Placeholder labels = tf.withName(TARGET).placeholder(TUint8.class); // Scaling the features Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); Operand scaledInput = tf.math - .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.DTYPE), centeringFactor), + .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.class), centeringFactor), scalingFactor); // First conv layer Variable conv1Weights = tf.variable(tf.math.mul(tf.random - .truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE, + .truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.class, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); Conv2d conv1 = tf.nn .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); @@ -112,7 +116,7 @@ public static Graph build(String optimizerName) { // Second conv layer Variable conv2Weights = tf.variable(tf.math.mul(tf.random - .truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.DTYPE, + .truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.class, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); Conv2d conv2 = tf.nn .conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); @@ -132,7 +136,7 @@ public static Graph build(String optimizerName) { // Fully connected layer Variable fc1Weights = tf.variable(tf.math.mul(tf.random - .truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE, + .truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.class, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); Variable fc1Biases = tf .variable(tf.fill(tf.array(new int[]{512}), tf.constant(0.1f))); @@ -141,7 +145,7 @@ public static Graph build(String optimizerName) { // Softmax layer Variable fc2Weights = tf.variable(tf.math.mul(tf.random - .truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.DTYPE, + .truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.class, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); Variable fc2Biases = tf .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f))); @@ -154,8 +158,7 @@ public static Graph build(String optimizerName) { // Loss function & regularization OneHot oneHot = tf .oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); - SoftmaxCrossEntropyWithLogits batchLoss = tf.nn - .softmaxCrossEntropyWithLogits(logits, oneHot); + SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math .add(tf.nn.l2Loss(fc1Biases), @@ -165,60 +168,39 @@ public static Graph build(String optimizerName) { String lcOptimizerName = optimizerName.toLowerCase(); // Optimizer - Optimizer optimizer; - switch (lcOptimizerName) { - case "adadelta": - optimizer = new AdaDelta(graph, 1f, 0.95f, 1e-8f); - break; - case "adagradda": - optimizer = new AdaGradDA(graph, 0.01f); - break; - case "adagrad": - optimizer = new AdaGrad(graph, 0.01f); - break; - case "adam": - optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f); - break; - case "sgd": - optimizer = new GradientDescent(graph, 0.01f); - break; - case "momentum": - optimizer = new Momentum(graph, 0.01f, 0.9f, false); - break; - case "rmsprop": - optimizer = new RMSProp(graph, 0.01f, 0.9f, 0.0f, 1e-10f, false); - break; - default: - throw new IllegalArgumentException("Unknown optimizer " + optimizerName); - } - logger.info("Optimizer = " + optimizer.toString()); + Optimizer optimizer = switch (lcOptimizerName) { + case "adadelta" -> new AdaDelta(graph, 1f, 0.95f, 1e-8f); + case "adagradda" -> new AdaGradDA(graph, 0.01f); + case "adagrad" -> new AdaGrad(graph, 0.01f); + case "adam" -> new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f); + case "sgd" -> new GradientDescent(graph, 0.01f); + case "momentum" -> new Momentum(graph, 0.01f, 0.9f, false); + case "rmsprop" -> new RMSProp(graph, 0.01f, 0.9f, 0.0f, 1e-10f, false); + default -> throw new IllegalArgumentException("Unknown optimizer " + optimizerName); + }; + logger.info("Optimizer = " + optimizer); Op minimize = optimizer.minimize(loss, TRAIN); - tf.init(); - return graph; } public static void train(Session session, int epochs, int minibatchSize, MnistDataset dataset) { - // Initialises the parameters. - session.runner().addTarget(INIT).run(); - logger.info("Initialised the model parameters"); - int interval = 0; // Train the model for (int i = 0; i < epochs; i++) { for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) { - try (Tensor batchImages = TUint8.tensorOf(trainingBatch.images()); - Tensor batchLabels = TUint8.tensorOf(trainingBatch.labels()); - Tensor loss = session.runner() + try (TUint8 batchImages = TUint8.tensorOf(trainingBatch.images()); + TUint8 batchLabels = TUint8.tensorOf(trainingBatch.labels()); + var result = session.runner() .feed(TARGET, batchLabels) .feed(INPUT_NAME, batchImages) .addTarget(TRAIN) .fetch(TRAINING_LOSS) - .run().get(0).expect(TFloat32.DTYPE)) { + .run()) { + TFloat32 loss = (TFloat32) result.get(0); if (interval % 100 == 0) { logger.log(Level.INFO, - "Iteration = " + interval + ", training loss = " + loss.data().getFloat()); + "Iteration = " + interval + ", training loss = " + loss.getFloat()); } } interval++; @@ -231,17 +213,15 @@ public static void test(Session session, int minibatchSize, MnistDataset dataset int[][] confusionMatrix = new int[10][10]; for (ImageBatch trainingBatch : dataset.testBatches(minibatchSize)) { - try (Tensor transformedInput = TUint8.tensorOf(trainingBatch.images()); - Tensor outputTensor = session.runner() - .feed(INPUT_NAME, transformedInput) - .fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) { - + try (TUint8 transformedInput = TUint8.tensorOf(trainingBatch.images()); + var result = session.runner().feed(INPUT_NAME, transformedInput).fetch(OUTPUT_NAME).run()) { + TFloat32 outputTensor = (TFloat32) result.get(0); ByteNdArray labelBatch = trainingBatch.labels(); - for (int k = 0; k < labelBatch.shape().size(0); k++) { + for (int k = 0; k < labelBatch.shape().get(0); k++) { byte trueLabel = labelBatch.getByte(k); int predLabel; - predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all())); + predLabel = argmax(outputTensor.slice(Indices.at(k), Indices.all())); if (predLabel == trueLabel) { correctCount++; } @@ -268,7 +248,7 @@ public static void test(Session session, int minibatchSize, MnistDataset dataset sb.append("\n"); } - System.out.println(sb.toString()); + System.out.println(sb); } /** @@ -280,7 +260,7 @@ public static void test(Session session, int minibatchSize, MnistDataset dataset public static int argmax(FloatNdArray probabilities) { float maxVal = Float.NEGATIVE_INFINITY; int idx = 0; - for (int i = 0; i < probabilities.shape().size(0); i++) { + for (int i = 0; i < probabilities.shape().get(0); i++) { float curVal = probabilities.getFloat(i); if (curVal > maxVal) { maxVal = curVal; @@ -291,17 +271,30 @@ public static int argmax(FloatNdArray probabilities) { } public static void main(String[] args) { - logger.info( - "Usage: MNISTTest "); + int epochs; + int minibatchSize; + String optimizerName; + + if (args.length == 0) { + epochs = 1; + minibatchSize = 64; + optimizerName = "adam"; + + } else if (args.length == 3) { + epochs = Integer.parseInt(args[0]); + minibatchSize = Integer.parseInt(args[1]); + optimizerName = args[2]; + + } else { + throw new IllegalArgumentException("Usage: MNISTTest "); + } - MnistDataset dataset = MnistDataset.create(0); + MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, + TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE); logger.info("Loaded data."); - int epochs = Integer.parseInt(args[0]); - int minibatchSize = Integer.parseInt(args[1]); - - try (Graph graph = build(args[2]); + try (Graph graph = build(optimizerName); Session session = new Session(graph)) { train(session, epochs, minibatchSize, dataset); diff --git a/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMnist.java b/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMnist.java new file mode 100644 index 0000000..9396864 --- /dev/null +++ b/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMnist.java @@ -0,0 +1,50 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.model.examples.cnn.vgg; + +import java.util.logging.Logger; +import org.tensorflow.model.examples.datasets.mnist.MnistDataset; + +/** + * Trains and evaluates VGG'11 model on FashionMNIST dataset. + */ +public class VGG11OnFashionMnist { + // Hyper-parameters + public static final int EPOCHS = 1; + public static final int BATCH_SIZE = 500; + + // Fashion MNIST dataset paths + public static final String TRAINING_IMAGES_ARCHIVE = "fashionmnist/train-images-idx3-ubyte.gz"; + public static final String TRAINING_LABELS_ARCHIVE = "fashionmnist/train-labels-idx1-ubyte.gz"; + public static final String TEST_IMAGES_ARCHIVE = "fashionmnist/t10k-images-idx3-ubyte.gz"; + public static final String TEST_LABELS_ARCHIVE = "fashionmnist/t10k-labels-idx1-ubyte.gz"; + + private static final Logger logger = Logger.getLogger(VGG11OnFashionMnist.class.getName()); + + public static void main(String[] args) { + logger.info("Data loading."); + MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE); + + try (VGGModel vggModel = new VGGModel()) { + logger.info("Model training."); + vggModel.train(dataset, EPOCHS, BATCH_SIZE); + + logger.info("Model evaluation."); + vggModel.test(dataset, BATCH_SIZE); + } + } +} diff --git a/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java b/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java new file mode 100644 index 0000000..f516c87 --- /dev/null +++ b/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java @@ -0,0 +1,283 @@ +/* + * Copyright 2020, 2024 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.model.examples.cnn.vgg; + +import java.util.Arrays; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.framework.optimizers.Adam; +import org.tensorflow.framework.optimizers.Optimizer; +import org.tensorflow.model.examples.datasets.ImageBatch; +import org.tensorflow.model.examples.datasets.mnist.MnistDataset; +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Mean; +import org.tensorflow.op.nn.Conv2d; +import org.tensorflow.op.nn.MaxPool; +import org.tensorflow.op.nn.Relu; +import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; +import org.tensorflow.op.random.TruncatedNormal; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TUint8; + +/** + * Describes the VGGModel. + */ +public class VGGModel implements AutoCloseable { + private static final int PIXEL_DEPTH = 255; + private static final int NUM_CHANNELS = 1; + private static final int IMAGE_SIZE = 28; + private static final int NUM_LABELS = MnistDataset.NUM_CLASSES; + private static final long SEED = 123456789L; + + private static final String PADDING_TYPE = "SAME"; + public static final String INPUT_NAME = "input"; + public static final String OUTPUT_NAME = "output"; + public static final String TARGET = "target"; + public static final String TRAIN = "train"; + public static final String TRAINING_LOSS = "training_loss"; + + private static final Logger logger = Logger.getLogger(VGGModel.class.getName()); + + private final Graph graph; + + private final Session session; + + public VGGModel() { + graph = compile(); + session = new Session(graph); + } + + public static Graph compile() { + Graph graph = new Graph(); + + Ops tf = Ops.create(graph); + + // Inputs + Placeholder input = tf.withName(INPUT_NAME).placeholder(TUint8.class, + Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE))); + Reshape input_reshaped = tf + .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)); + Placeholder labels = tf.withName(TARGET).placeholder(TUint8.class); + + // Scaling the features + Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); + Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); + Operand scaledInput = tf.math + .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.class), centeringFactor), + scalingFactor); + + Relu relu1 = vggConv2DLayer("1", tf, scaledInput, new int[]{3, 3, NUM_CHANNELS, 32}, 32); + + MaxPool pool1 = vggMaxPool(tf, relu1); + + Relu relu2 = vggConv2DLayer("2", tf, pool1, new int[]{3, 3, 32, 64}, 64); + + MaxPool pool2 = vggMaxPool(tf, relu2); + + Relu relu3 = vggConv2DLayer("3", tf, pool2, new int[]{3, 3, 64, 128}, 128); + Relu relu4 = vggConv2DLayer("4", tf, relu3, new int[]{3, 3, 128, 128}, 128); + + MaxPool pool3 = vggMaxPool(tf, relu4); + + Relu relu5 = vggConv2DLayer("5", tf, pool3, new int[]{3, 3, 128, 256}, 256); + Relu relu6 = vggConv2DLayer("6", tf, relu5, new int[]{3, 3, 256, 256}, 256); + + MaxPool pool4 = vggMaxPool(tf, relu6); + + Relu relu7 = vggConv2DLayer("7", tf, pool4, new int[]{3, 3, 256, 256}, 256); + Relu relu8 = vggConv2DLayer("8", tf, relu7, new int[]{3, 3, 256, 256}, 256); + + MaxPool pool5 = vggMaxPool(tf, relu8); + + Reshape flatten = vggFlatten(tf, pool5); + + Add loss = buildFCLayersAndRegularization(tf, labels, flatten); + + Optimizer optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f); + + optimizer.minimize(loss, TRAIN); + + return graph; + } + + public static Add buildFCLayersAndRegularization(Ops tf, Placeholder labels, Reshape flatten) { + int fcBiasShape = 100; + int[] fcWeightShape = {256, fcBiasShape}; + + Variable fc1Weights = tf.variable(tf.math.mul(tf.random + .truncatedNormal(tf.array(fcWeightShape), TFloat32.class, + TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc1Biases = tf + .variable(tf.fill(tf.array(new int[]{fcBiasShape}), tf.constant(0.1f))); + Relu fcRelu = tf.nn + .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); + + // Softmax layer + Variable fc2Weights = tf.variable(tf.math.mul(tf.random + .truncatedNormal(tf.array(fcBiasShape, NUM_LABELS), TFloat32.class, + TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc2Biases = tf + .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f))); + + Add logits = tf.math.add(tf.linalg.matMul(fcRelu, fc2Weights), fc2Biases); + + // Predicted outputs + tf.withName(OUTPUT_NAME).nn.softmax(logits); + + // Loss function & regularization + OneHot oneHot = tf + .oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); + SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); + Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); + Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math + .add(tf.nn.l2Loss(fc1Biases), + tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); + return tf.withName(TRAINING_LOSS).math + .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); + } + + public static Reshape vggFlatten(Ops tf, MaxPool pool2) { + return tf.reshape(pool2, tf.concat(Arrays + .asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})), + tf.array(new int[]{-1})), tf.constant(0))); + } + + public static MaxPool vggMaxPool(Ops tf, Relu relu1) { + return tf.nn + .maxPool(relu1, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), + PADDING_TYPE); + } + + public static Relu vggConv2DLayer(String layerName, Ops tf, Operand scaledInput, int[] convWeightsL1Shape, int convBiasL1Shape) { + Variable conv1Weights = tf.withName("conv2d_" + layerName).variable(tf.math.mul(tf.random + .truncatedNormal(tf.array(convWeightsL1Shape), TFloat32.class, + TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Conv2d conv = tf.nn + .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable convBias = tf + .withName("bias2d_" + layerName).variable(tf.fill(tf.array(new int[]{convBiasL1Shape}), tf.constant(0.0f))); + return tf.nn.relu(tf.withName("biasAdd_" + layerName).nn.biasAdd(conv, convBias)); + } + + public void train(MnistDataset dataset, int epochs, int minibatchSize) { + int interval = 0; + // Train the model + for (int i = 0; i < epochs; i++) { + for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) { + try (TUint8 batchImages = TUint8.tensorOf(trainingBatch.images()); + TUint8 batchLabels = TUint8.tensorOf(trainingBatch.labels()); + var result = session.runner() + .feed(TARGET, batchLabels) + .feed(INPUT_NAME, batchImages) + .addTarget(TRAIN) + .fetch(TRAINING_LOSS) + .run()) { + TFloat32 loss = (TFloat32) result.get(0); + + logger.log(Level.INFO, + "Iteration = " + interval + ", training loss = " + loss.getFloat()); + + } + interval++; + } + } + } + + public void test(MnistDataset dataset, int minibatchSize) { + int correctCount = 0; + int[][] confusionMatrix = new int[10][10]; + + for (ImageBatch trainingBatch : dataset.testBatches(minibatchSize)) { + try (TUint8 transformedInput = TUint8.tensorOf(trainingBatch.images()); + var result = session.runner() + .feed(INPUT_NAME, transformedInput) + .fetch(OUTPUT_NAME).run()) { + TFloat32 outputTensor = (TFloat32) result.get(0); + + ByteNdArray labelBatch = trainingBatch.labels(); + for (int k = 0; k < labelBatch.shape().get(0); k++) { + byte trueLabel = labelBatch.getByte(k); + int predLabel; + + predLabel = argmax(outputTensor.slice(Indices.at(k), Indices.all())); + if (predLabel == trueLabel) { + correctCount++; + } + + confusionMatrix[trueLabel][predLabel]++; + } + } + } + + logger.info("Final accuracy = " + ((float) correctCount) / dataset.numTestingExamples()); + + StringBuilder sb = new StringBuilder(); + sb.append("Label"); + for (int i = 0; i < confusionMatrix.length; i++) { + sb.append(String.format("%1$5s", "" + i)); + } + sb.append("\n"); + + for (int i = 0; i < confusionMatrix.length; i++) { + sb.append(String.format("%1$5s", "" + i)); + for (int j = 0; j < confusionMatrix[i].length; j++) { + sb.append(String.format("%1$5s", "" + confusionMatrix[i][j])); + } + sb.append("\n"); + } + + System.out.println(sb); + } + + /** + * Find the maximum probability and return it's index. + * + * @param probabilities The probabilites. + * @return The index of the max. + */ + public static int argmax(FloatNdArray probabilities) { + float maxVal = Float.NEGATIVE_INFINITY; + int idx = 0; + for (int i = 0; i < probabilities.shape().get(0); i++) { + float curVal = probabilities.getFloat(i); + if (curVal > maxVal) { + maxVal = curVal; + idx = i; + } + } + return idx; + } + + @Override + public void close() { + session.close(); + graph.close(); + } +} diff --git a/src/main/java/org/tensorflow/model/examples/datasets/ImageBatch.java b/src/main/java/org/tensorflow/model/examples/datasets/ImageBatch.java new file mode 100644 index 0000000..ea1ca63 --- /dev/null +++ b/src/main/java/org/tensorflow/model/examples/datasets/ImageBatch.java @@ -0,0 +1,24 @@ +/* + * Copyright 2020, 2024 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.model.examples.datasets; + +import org.tensorflow.ndarray.ByteNdArray; + +/** + * Batch of images for batch training. + */ +public record ImageBatch(ByteNdArray images, ByteNdArray labels) { } diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatchIterator.java b/src/main/java/org/tensorflow/model/examples/datasets/ImageBatchIterator.java similarity index 76% rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatchIterator.java rename to src/main/java/org/tensorflow/model/examples/datasets/ImageBatchIterator.java index 8543501..6651cbd 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatchIterator.java +++ b/src/main/java/org/tensorflow/model/examples/datasets/ImageBatchIterator.java @@ -14,16 +14,18 @@ * limitations under the License. * ======================================================================= */ +package org.tensorflow.model.examples.datasets; -package org.tensorflow.model.examples.mnist.data; - -import static org.tensorflow.tools.ndarray.index.Indices.range; +import static org.tensorflow.ndarray.index.Indices.range; import java.util.Iterator; -import org.tensorflow.tools.ndarray.ByteNdArray; -import org.tensorflow.tools.ndarray.index.Index; -class ImageBatchIterator implements Iterator { +import org.tensorflow.ndarray.index.Index; +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.index.Index; + +/** Basic batch iterator across images presented in datset. */ +public class ImageBatchIterator implements Iterator { @Override public boolean hasNext() { @@ -38,7 +40,7 @@ public ImageBatch next() { return new ImageBatch(images.slice(range), labels.slice(range)); } - ImageBatchIterator(int batchSize, ByteNdArray images, ByteNdArray labels) { + public ImageBatchIterator(int batchSize, ByteNdArray images, ByteNdArray labels) { this.batchSize = batchSize; this.images = images; this.labels = labels; diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/MnistDataset.java b/src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java similarity index 57% rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/MnistDataset.java rename to src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java index c3f3ddc..9d7d479 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/MnistDataset.java +++ b/src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java @@ -14,37 +14,40 @@ * limitations under the License. * ======================================================================= */ +package org.tensorflow.model.examples.datasets.mnist; -package org.tensorflow.model.examples.mnist.data; - -import static org.tensorflow.tools.ndarray.index.Indices.from; -import static org.tensorflow.tools.ndarray.index.Indices.to; +import org.tensorflow.model.examples.datasets.ImageBatch; +import org.tensorflow.model.examples.datasets.ImageBatchIterator; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.NdArrays; import java.io.DataInputStream; import java.io.IOException; import java.util.zip.GZIPInputStream; -import org.tensorflow.tools.Shape; -import org.tensorflow.tools.buffer.DataBuffers; -import org.tensorflow.tools.ndarray.ByteNdArray; -import org.tensorflow.tools.ndarray.NdArrays; -public class MnistDataset { +import static org.tensorflow.ndarray.index.Indices.sliceFrom; +import static org.tensorflow.ndarray.index.Indices.sliceTo; +/** Common loader and data preprocessor for MNIST and FashionMNIST datasets. */ +public class MnistDataset { public static final int NUM_CLASSES = 10; - public static MnistDataset create(int validationSize) { + public static MnistDataset create(int validationSize, String trainingImagesArchive, String trainingLabelsArchive, + String testImagesArchive, String testLabelsArchive) { try { - ByteNdArray trainingImages = readArchive(TRAINING_IMAGES_ARCHIVE); - ByteNdArray trainingLabels = readArchive(TRAINING_LABELS_ARCHIVE); - ByteNdArray testImages = readArchive(TEST_IMAGES_ARCHIVE); - ByteNdArray testLabels = readArchive(TEST_LABELS_ARCHIVE); + ByteNdArray trainingImages = readArchive(trainingImagesArchive); + ByteNdArray trainingLabels = readArchive(trainingLabelsArchive); + ByteNdArray testImages = readArchive(testImagesArchive); + ByteNdArray testLabels = readArchive(testLabelsArchive); if (validationSize > 0) { return new MnistDataset( - trainingImages.slice(from(validationSize)), - trainingLabels.slice(from(validationSize)), - trainingImages.slice(to(validationSize)), - trainingLabels.slice(to(validationSize)), + trainingImages.slice(sliceFrom(validationSize)), + trainingLabels.slice(sliceFrom(validationSize)), + trainingImages.slice(sliceTo(validationSize)), + trainingLabels.slice(sliceTo(validationSize)), testImages, testLabels ); @@ -77,21 +80,17 @@ public long imageSize() { } public long numTrainingExamples() { - return trainingLabels.shape().size(0); + return trainingLabels.shape().get(0); } public long numTestingExamples() { - return testLabels.shape().size(0); + return testLabels.shape().get(0); } public long numValidationExamples() { - return validationLabels.shape().size(0); + return validationLabels.shape().get(0); } - private static final String TRAINING_IMAGES_ARCHIVE = "train-images-idx3-ubyte.gz"; - private static final String TRAINING_LABELS_ARCHIVE = "train-labels-idx1-ubyte.gz"; - private static final String TEST_IMAGES_ARCHIVE = "t10k-images-idx3-ubyte.gz"; - private static final String TEST_LABELS_ARCHIVE = "t10k-labels-idx1-ubyte.gz"; private static final int TYPE_UBYTE = 0x08; private final ByteNdArray trainingImages; @@ -120,24 +119,24 @@ private MnistDataset( } private static ByteNdArray readArchive(String archiveName) throws IOException { - DataInputStream archiveStream = new DataInputStream( - //new GZIPInputStream(new java.io.FileInputStream("src/main/resources/"+archiveName)) + try (DataInputStream archiveStream = new DataInputStream( new GZIPInputStream(MnistDataset.class.getClassLoader().getResourceAsStream(archiveName)) - ); - archiveStream.readShort(); // first two bytes are always 0 - byte magic = archiveStream.readByte(); - if (magic != TYPE_UBYTE) { - throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive"); - } - int numDims = archiveStream.readByte(); - long[] dimSizes = new long[numDims]; - int size = 1; // for simplicity, we assume that total size does not exceeds Integer.MAX_VALUE - for (int i = 0; i < dimSizes.length; ++i) { - dimSizes[i] = archiveStream.readInt(); - size *= dimSizes[i]; + )) { + archiveStream.readShort(); // first two bytes are always 0 + byte magic = archiveStream.readByte(); + if (magic != TYPE_UBYTE) { + throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive"); + } + int numDims = archiveStream.readByte(); + long[] dimSizes = new long[numDims]; + int size = 1; // for simplicity, we assume that total size does not exceeds Integer.MAX_VALUE + for (int i = 0; i < dimSizes.length; ++i) { + dimSizes[i] = archiveStream.readInt(); + size *= dimSizes[i]; + } + byte[] bytes = new byte[size]; + archiveStream.readFully(bytes); + return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, true, false)); } - byte[] bytes = new byte[size]; - archiveStream.readFully(bytes); - return NdArrays.wrap(DataBuffers.of(bytes, true, false), Shape.of(dimSizes)); } } diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/SimpleMnist.java b/src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java similarity index 55% rename from tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/SimpleMnist.java rename to src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java index 749a0b3..29ec654 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/SimpleMnist.java +++ b/src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java @@ -1,29 +1,49 @@ -package org.tensorflow.model.examples.mnist; +/* + * Copyright 2020, 2024 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.model.examples.dense; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.Tensor; -import org.tensorflow.model.examples.mnist.data.ImageBatch; -import org.tensorflow.model.examples.mnist.data.MnistDataset; +import org.tensorflow.framework.optimizers.GradientDescent; +import org.tensorflow.framework.optimizers.Optimizer; +import org.tensorflow.model.examples.datasets.ImageBatch; +import org.tensorflow.model.examples.datasets.mnist.MnistDataset; +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.RawOp; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; import org.tensorflow.op.nn.Softmax; -import org.tensorflow.tools.Shape; -import org.tensorflow.tools.ndarray.ByteNdArray; -import org.tensorflow.training.optimizers.GradientDescent; -import org.tensorflow.training.optimizers.Optimizer; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; public class SimpleMnist implements Runnable { + private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz"; + private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz"; + private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz"; + private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz"; public static void main(String[] args) { - MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE); + MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, + TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE); + try (Graph graph = new Graph()) { SimpleMnist mnist = new SimpleMnist(graph, dataset); mnist.run(); @@ -35,21 +55,16 @@ public void run() { Ops tf = Ops.create(graph); // Create placeholders and variables, which should fit batches of an unknown number of images - Placeholder images = tf.placeholder(TFloat32.DTYPE); - Placeholder labels = tf.placeholder(TFloat32.DTYPE); + Placeholder images = tf.placeholder(TFloat32.class); + Placeholder labels = tf.placeholder(TFloat32.class); // Create weights with an initial value of 0 Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES); - Variable weights = tf.variable(weightShape, TFloat32.DTYPE); - tf.initAdd(tf.assign(weights, tf.zerosLike(weights))); + Variable weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class)); // Create biases with an initial value of 0 Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES); - Variable biases = tf.variable(biasShape, TFloat32.DTYPE); - tf.initAdd(tf.assign(biases, tf.zerosLike(biases))); - - // Register all variable initializers for single execution - tf.init(); + Variable biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class)); // Predict the class of each image in the batch and compute the loss Softmax softmax = @@ -77,18 +92,15 @@ public void run() { // Compute the accuracy of the model Operand predicted = tf.math.argMax(softmax, tf.constant(1)); Operand expected = tf.math.argMax(labels, tf.constant(1)); - Operand accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.DTYPE), tf.array(0)); + Operand accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0)); // Run the graph try (Session session = new Session(graph)) { - // Initialize variables - session.run(tf.init()); - // Train the model for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) { - try (Tensor batchImages = preprocessImages(trainingBatch.images()); - Tensor batchLabels = preprocessLabels(trainingBatch.labels())) { + try (TFloat32 batchImages = preprocessImages(trainingBatch.images()); + TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) { session.runner() .addTarget(minimize) .feed(images.asOutput(), batchImages) @@ -99,16 +111,15 @@ public void run() { // Test the model ImageBatch testBatch = dataset.testBatch(); - try (Tensor testImages = preprocessImages(testBatch.images()); - Tensor testLabels = preprocessLabels(testBatch.labels()); - Tensor accuracyValue = session.runner() + try (TFloat32 testImages = preprocessImages(testBatch.images()); + TFloat32 testLabels = preprocessLabels(testBatch.labels()); + var result = session.runner() .fetch(accuracy) .feed(images.asOutput(), testImages) .feed(labels.asOutput(), testLabels) - .run() - .get(0) - .expect(TFloat32.DTYPE)) { - System.out.println("Accuracy: " + accuracyValue.data().getFloat()); + .run()) { + TFloat32 accuracyValue = (TFloat32) result.get(0); + System.out.println("Accuracy: " + accuracyValue.getFloat()); } } } @@ -117,21 +128,21 @@ public void run() { private static final int TRAINING_BATCH_SIZE = 100; private static final float LEARNING_RATE = 0.2f; - private static Tensor preprocessImages(ByteNdArray rawImages) { + private static TFloat32 preprocessImages(ByteNdArray rawImages) { Ops tf = Ops.create(); // Flatten images in a single dimension and normalize their pixels as floats. long imageSize = rawImages.get(0).shape().size(); return tf.math.div( tf.reshape( - tf.dtypes.cast(tf.constant(rawImages), TFloat32.DTYPE), + tf.dtypes.cast(tf.constant(rawImages), TFloat32.class), tf.array(-1L, imageSize) ), tf.constant(255.0f) ).asTensor(); } - private static Tensor preprocessLabels(ByteNdArray rawLabels) { + private static TFloat32 preprocessLabels(ByteNdArray rawLabels) { Ops tf = Ops.create(); // Map labels to one hot vectors where only the expected predictions as a value of 1.0 @@ -143,8 +154,8 @@ private static Tensor preprocessLabels(ByteNdArray rawLabels) { ).asTensor(); } - private Graph graph; - private MnistDataset dataset; + private final Graph graph; + private final MnistDataset dataset; private SimpleMnist(Graph graph, MnistDataset dataset) { this.graph = graph; diff --git a/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java b/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java new file mode 100644 index 0000000..b67deff --- /dev/null +++ b/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java @@ -0,0 +1,138 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.model.examples.regression.linear; + +import java.util.List; +import java.util.Random; +import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.framework.optimizers.GradientDescent; +import org.tensorflow.framework.optimizers.Optimizer; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Div; +import org.tensorflow.op.math.Mul; +import org.tensorflow.op.math.Pow; +import org.tensorflow.types.TFloat32; + +/** + * In this example TensorFlow finds the weight and bias of the linear regression during 1 epoch, + * training on observations one by one. + *

+ * Also, the weight and bias are extracted and printed. + */ +public class LinearRegressionExample { + /** + * Amount of data points. + */ + private static final int N = 10; + + /** + * This value is used to fill the Y placeholder in prediction. + */ + public static final float LEARNING_RATE = 0.1f; + public static final String WEIGHT_VARIABLE_NAME = "weight"; + public static final String BIAS_VARIABLE_NAME = "bias"; + + public static void main(String[] args) { + // Prepare the data + float[] xValues = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f}; + float[] yValues = new float[N]; + + Random rnd = new Random(42); + + for (int i = 0; i < yValues.length; i++) { + yValues[i] = (float) (10 * xValues[i] + 2 + 0.1 * (rnd.nextDouble() - 0.5)); + } + + try (Graph graph = new Graph()) { + Ops tf = Ops.create(graph); + + // Define placeholders + Placeholder xData = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.scalar())); + Placeholder yData = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.scalar())); + + // Define variables + Variable weight = tf.withName(WEIGHT_VARIABLE_NAME).variable(tf.constant(1f)); + Variable bias = tf.withName(BIAS_VARIABLE_NAME).variable(tf.constant(1f)); + + // Define the model function weight*x + bias + Mul mul = tf.math.mul(xData, weight); + Add yPredicted = tf.math.add(mul, bias); + + // Define loss function MSE + Pow sum = tf.math.pow(tf.math.sub(yPredicted, yData), tf.constant(2f)); + Div mse = tf.math.div(sum, tf.constant(2f * N)); + + // Back-propagate gradients to variables for training + Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE); + Op minimize = optimizer.minimize(mse); + + try (Session session = new Session(graph)) { + + // Train the model on data + for (int i = 0; i < xValues.length; i++) { + float y = yValues[i]; + float x = xValues[i]; + + try (TFloat32 xTensor = TFloat32.scalarOf(x); + TFloat32 yTensor = TFloat32.scalarOf(y)) { + + session.runner() + .addTarget(minimize) + .feed(xData.asOutput(), xTensor) + .feed(yData.asOutput(), yTensor) + .run(); + + System.out.println("Training phase"); + System.out.println("x is " + x + " y is " + y); + } + } + + // Extract linear regression model weight and bias values + try (var result = session.runner() + .fetch(WEIGHT_VARIABLE_NAME) + .fetch(BIAS_VARIABLE_NAME) + .run()) { + System.out.println("Weight is " + result.get(WEIGHT_VARIABLE_NAME)); + System.out.println("Bias is " + result.get(BIAS_VARIABLE_NAME)); + } + + // Let's predict y for x = 10f + float x = 10f; + float predictedY = 0f; + + try (TFloat32 xTensor = TFloat32.scalarOf(x); + TFloat32 yTensor = TFloat32.scalarOf(predictedY); + TFloat32 yPredictedTensor = (TFloat32)session.runner() + .feed(xData.asOutput(), xTensor) + .feed(yData.asOutput(), yTensor) + .fetch(yPredicted) + .run().get(0)) { + + predictedY = yPredictedTensor.getFloat(); + + System.out.println("Predicted value: " + predictedY); + } + } + } + } +} diff --git a/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java b/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java new file mode 100644 index 0000000..35955d7 --- /dev/null +++ b/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java @@ -0,0 +1,100 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.model.examples.tensors; + +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.types.TInt32; + +import java.util.Arrays; + +/** + * Creates a few tensors of ranks: 0, 1, 2, 3. + */ +public class TensorCreation { + + public static void main(String[] args) { + // Rank 0 Tensor + TInt32 rank0Tensor = TInt32.scalarOf(42); + + System.out.println("---- Scalar tensor ---------"); + + System.out.println("DataType: " + rank0Tensor.dataType().name()); + + System.out.println("Rank: " + rank0Tensor.shape().size()); + + System.out.println("Shape: " + Arrays.toString(rank0Tensor.shape().asArray())); + + rank0Tensor.scalars().forEach(value -> System.out.println("Value: " + value.getObject())); + + // Rank 1 Tensor + TInt32 rank1Tensor = TInt32.vectorOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + + System.out.println("---- Vector tensor ---------"); + + System.out.println("DataType: " + rank1Tensor.dataType().name()); + + System.out.println("Rank: " + rank1Tensor.shape().size()); + + System.out.println("Shape: " + Arrays.toString(rank1Tensor.shape().asArray())); + + System.out.println("6th element: " + rank1Tensor.getInt(5)); + + // Rank 2 Tensor + // 3x2 matrix of ints. + IntNdArray matrix2d = NdArrays.ofInts(Shape.of(3, 2)); + + matrix2d.set(NdArrays.vectorOf(1, 2), 0) + .set(NdArrays.vectorOf(3, 4), 1) + .set(NdArrays.vectorOf(5, 6), 2); + + TInt32 rank2Tensor = TInt32.tensorOf(matrix2d); + + System.out.println("---- Matrix tensor ---------"); + + System.out.println("DataType: " + rank2Tensor.dataType().name()); + + System.out.println("Rank: " + rank2Tensor.shape().size()); + + System.out.println("Shape: " + Arrays.toString(rank2Tensor.shape().asArray())); + + System.out.println("6th element: " + rank2Tensor.getInt(2, 1)); + + // Rank 3 Tensor + // 3*2*4 matrix of ints. + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(3, 2, 4)); + + matrix3d.elements(0).forEach(matrix -> { + matrix + .set(NdArrays.vectorOf(1, 2, 3, 4), 0) + .set(NdArrays.vectorOf(5, 6, 7, 8), 1); + }); + + TInt32 rank3Tensor = TInt32.tensorOf(matrix3d); + + System.out.println("---- Matrix tensor ---------"); + + System.out.println("DataType: " + rank3Tensor.dataType().name()); + + System.out.println("Rank: " + rank3Tensor.shape().size()); + + System.out.println("Shape: " + Arrays.toString(rank3Tensor.shape().asArray())); + + System.out.println("n-th element: " + rank3Tensor.getInt(2, 1, 3)); + } +} diff --git a/tensorflow-examples/src/main/resources/META-INF/MANIFEST.MF b/src/main/resources/META-INF/MANIFEST.MF similarity index 100% rename from tensorflow-examples/src/main/resources/META-INF/MANIFEST.MF rename to src/main/resources/META-INF/MANIFEST.MF diff --git a/src/main/resources/fashionmnist/Readme.md b/src/main/resources/fashionmnist/Readme.md new file mode 100644 index 0000000..95b6f38 --- /dev/null +++ b/src/main/resources/fashionmnist/Readme.md @@ -0,0 +1,6 @@ +This dataset is distributed under MIT License and presented in next paper. +Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. Han Xiao, Kashif Rasul, Roland Vollgraf. arXiv:1708.07747 + +The data was downloaded from the FashionMnist Repository +https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion + diff --git a/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz b/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz new file mode 100644 index 0000000..667844f Binary files /dev/null and b/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz differ diff --git a/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz b/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000..abdddb8 Binary files /dev/null and b/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz differ diff --git a/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz b/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz new file mode 100644 index 0000000..e6ee0e3 Binary files /dev/null and b/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz differ diff --git a/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz b/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000..9c4aae2 Binary files /dev/null and b/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz differ diff --git a/src/main/resources/fasterrcnninception/image2.jpg b/src/main/resources/fasterrcnninception/image2.jpg new file mode 100644 index 0000000..9eb325a Binary files /dev/null and b/src/main/resources/fasterrcnninception/image2.jpg differ diff --git a/src/main/resources/fasterrcnninception/image2rcnn.jpg b/src/main/resources/fasterrcnninception/image2rcnn.jpg new file mode 100644 index 0000000..699c949 Binary files /dev/null and b/src/main/resources/fasterrcnninception/image2rcnn.jpg differ diff --git a/tensorflow-examples/src/main/resources/t10k-images-idx3-ubyte.gz b/src/main/resources/mnist/t10k-images-idx3-ubyte.gz similarity index 100% rename from tensorflow-examples/src/main/resources/t10k-images-idx3-ubyte.gz rename to src/main/resources/mnist/t10k-images-idx3-ubyte.gz diff --git a/tensorflow-examples/src/main/resources/t10k-labels-idx1-ubyte.gz b/src/main/resources/mnist/t10k-labels-idx1-ubyte.gz similarity index 100% rename from tensorflow-examples/src/main/resources/t10k-labels-idx1-ubyte.gz rename to src/main/resources/mnist/t10k-labels-idx1-ubyte.gz diff --git a/tensorflow-examples/src/main/resources/train-images-idx3-ubyte.gz b/src/main/resources/mnist/train-images-idx3-ubyte.gz similarity index 100% rename from tensorflow-examples/src/main/resources/train-images-idx3-ubyte.gz rename to src/main/resources/mnist/train-images-idx3-ubyte.gz diff --git a/tensorflow-examples/src/main/resources/train-labels-idx1-ubyte.gz b/src/main/resources/mnist/train-labels-idx1-ubyte.gz similarity index 100% rename from tensorflow-examples/src/main/resources/train-labels-idx1-ubyte.gz rename to src/main/resources/mnist/train-labels-idx1-ubyte.gz diff --git a/tensorflow-examples-legacy/README.md b/tensorflow-examples-legacy/README.md deleted file mode 100644 index 6aa1a30..0000000 --- a/tensorflow-examples-legacy/README.md +++ /dev/null @@ -1,20 +0,0 @@ -# TensorFlow for Java: Examples - -These examples include using pre-trained models for [image -classification](label_image) and [object detection](object_detection), -and driving the [training](training) of a pre-defined model - all using the -TensorFlow Java API. - -The TensorFlow Java API does not have feature parity with the Python API. -The Java API is most suitable for inference using pre-trained models -and for training pre-defined models from a single Java process. - -Python will be the most convenient language for defining the -numerical computation of a model. - -- [Slides](https://docs.google.com/presentation/d/e/2PACX-1vQ6DzxNTBrJo7K5P8t5_rBRGnyJoPUPBVOJR4ooHCwi4TlBFnIriFmI719rDNpcQzojqsV58aUqmBBx/pub?start=false&loop=false&delayms=3000) from January 2018. -- See README.md in each subdirectory for details. - -For a recent real-world example, see the use of this API to [assess microscope -image quality](https://research.googleblog.com/2018/03/using-deep-learning-to-facilitate.html) -in the image processing package [Fiji (ImageJ)](https://fiji.sc/). diff --git a/tensorflow-examples-legacy/docker/Dockerfile b/tensorflow-examples-legacy/docker/Dockerfile deleted file mode 100644 index 7f71f83..0000000 --- a/tensorflow-examples-legacy/docker/Dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -FROM tensorflow/tensorflow:1.4.0 -WORKDIR / -RUN apt-get update -RUN apt-get -y install maven openjdk-8-jdk -RUN mvn dependency:get -Dartifact=org.tensorflow:tensorflow:1.4.0 -RUN mvn dependency:get -Dartifact=org.tensorflow:proto:1.4.0 -CMD ["/bin/bash", "-l"] diff --git a/tensorflow-examples-legacy/docker/README.md b/tensorflow-examples-legacy/docker/README.md deleted file mode 100644 index eaa3ca3..0000000 --- a/tensorflow-examples-legacy/docker/README.md +++ /dev/null @@ -1,15 +0,0 @@ -Dockerfile for building an image suitable for running the Java examples. - -Typical usage: - -``` -docker build -t java-tensorflow . -docker run -it --rm -v ${PWD}/..:/examples -w /examples java-tensorflow -``` - -That second command will pop you into a shell which has all -the dependencies required to execute the scripts and Java -examples. - -The script `sanity_test.sh` builds this container and runs a compilation -check on all the maven projects. diff --git a/tensorflow-examples-legacy/docker/sanity_test.sh b/tensorflow-examples-legacy/docker/sanity_test.sh deleted file mode 100755 index a4343f2..0000000 --- a/tensorflow-examples-legacy/docker/sanity_test.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -# -# Silly sanity test -DIR="$(cd "$(dirname "$0")" && pwd -P)" - -docker build -t java-tensorflow . -docker run -it --rm -v ${PWD}/..:/examples java-tensorflow bash /examples/docker/test_inside_container.sh diff --git a/tensorflow-examples-legacy/docker/test_inside_container.sh b/tensorflow-examples-legacy/docker/test_inside_container.sh deleted file mode 100644 index 221a023..0000000 --- a/tensorflow-examples-legacy/docker/test_inside_container.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -set -ex - -cd /examples/label_image -mvn compile - -cd /examples/object_detection -mvn compile - -cd /examples/training -mvn compile diff --git a/tensorflow-examples-legacy/label_image/.gitignore b/tensorflow-examples-legacy/label_image/.gitignore deleted file mode 100644 index 9aeb6ae..0000000 --- a/tensorflow-examples-legacy/label_image/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -images -src/main/resources -target diff --git a/tensorflow-examples-legacy/label_image/README.md b/tensorflow-examples-legacy/label_image/README.md deleted file mode 100644 index b2dbef4..0000000 --- a/tensorflow-examples-legacy/label_image/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# Image Classification Example - -1. Download the model: - - If you have [TensorFlow 1.4+ for Python installed](https://www.tensorflow.org/install/), - run `python ./download.py` - - If not, but you have [docker](https://www.docker.com/get-docker) installed, - run `download.sh`. - -2. Compile [`LabelImage.java`](src/main/java/LabelImage.java): - - ``` - mvn compile - ``` - -3. Download some sample images: - If you already have some images, great. Otherwise `download_sample_images.sh` - gets a few. - -3. Classify! - - ``` - mvn -q exec:java -Dexec.args="" - ``` diff --git a/tensorflow-examples-legacy/label_image/download.py b/tensorflow-examples-legacy/label_image/download.py deleted file mode 100644 index c9082c2..0000000 --- a/tensorflow-examples-legacy/label_image/download.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Create an image classification graph. - -Script to download a pre-trained image classifier and tweak it so that -the model accepts raw bytes of an encoded image. - -Doing so involves some model-specific normalization of an image. -Ideally, this would have been part of the image classifier model, -but the particular model being used didn't include this normalization, -so this script does the necessary tweaking. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from six.moves import urllib -import os -import zipfile -import tensorflow as tf - -URL = 'https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip' -LABELS_FILE = 'imagenet_comp_graph_label_strings.txt' -GRAPH_FILE = 'tensorflow_inception_graph.pb' - -GRAPH_INPUT_TENSOR = 'input:0' -GRAPH_PROBABILITIES_TENSOR = 'output:0' - -IMAGE_HEIGHT = 224 -IMAGE_WIDTH = 224 -MEAN = 117 -SCALE = 1 - -LOCAL_DIR = 'src/main/resources' - - -def download(): - print('Downloading %s' % URL) - zip_filename, _ = urllib.request.urlretrieve(URL) - with zipfile.ZipFile(zip_filename) as zip: - zip.extract(LABELS_FILE) - zip.extract(GRAPH_FILE) - os.rename(LABELS_FILE, os.path.join(LOCAL_DIR, 'labels.txt')) - os.rename(GRAPH_FILE, os.path.join(LOCAL_DIR, 'graph.pb')) - - -def create_graph_to_decode_and_normalize_image(): - """See file docstring. - - Returns: - input: The placeholder to feed the raw bytes of an encoded image. - y: A Tensor (the decoded, normalized image) to be fed to the graph. - """ - image = tf.placeholder(tf.string, shape=(), name='encoded_image_bytes') - with tf.name_scope("preprocess"): - y = tf.image.decode_image(image, channels=3) - y = tf.cast(y, tf.float32) - y = tf.expand_dims(y, axis=0) - y = tf.image.resize_bilinear(y, (IMAGE_HEIGHT, IMAGE_WIDTH)) - y = (y - MEAN) / SCALE - return (image, y) - - -def patch_graph(): - """Create graph.pb that applies the model in URL to raw image bytes.""" - with tf.Graph().as_default() as g: - input_image, image_normalized = create_graph_to_decode_and_normalize_image() - original_graph_def = tf.GraphDef() - with open(os.path.join(LOCAL_DIR, 'graph.pb')) as f: - original_graph_def.ParseFromString(f.read()) - softmax = tf.import_graph_def( - original_graph_def, - name='inception', - input_map={GRAPH_INPUT_TENSOR: image_normalized}, - return_elements=[GRAPH_PROBABILITIES_TENSOR]) - # We're constructing a graph that accepts a single image (as opposed to a - # batch of images), so might as well make the output be a vector of - # probabilities, instead of a batch of vectors with batch size 1. - output_probabilities = tf.squeeze(softmax, name='probabilities') - # Overwrite the graph. - with open(os.path.join(LOCAL_DIR, 'graph.pb'), 'w') as f: - f.write(g.as_graph_def().SerializeToString()) - print('------------------------------------------------------------') - print('MODEL GRAPH : graph.pb') - print('LABELS : labels.txt') - print('INPUT TENSOR : %s' % input_image.op.name) - print('OUTPUT TENSOR: %s' % output_probabilities.op.name) - - -if __name__ == '__main__': - if not os.path.exists(LOCAL_DIR): - os.makedirs(LOCAL_DIR) - download() - patch_graph() diff --git a/tensorflow-examples-legacy/label_image/download.sh b/tensorflow-examples-legacy/label_image/download.sh deleted file mode 100755 index 22ca88b..0000000 --- a/tensorflow-examples-legacy/label_image/download.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -DIR="$(cd "$(dirname "$0")" && pwd -P)" -docker run -it -v ${DIR}:/x -w /x --rm tensorflow/tensorflow:1.4.0 python download.py diff --git a/tensorflow-examples-legacy/label_image/download_sample_images.sh b/tensorflow-examples-legacy/label_image/download_sample_images.sh deleted file mode 100755 index 17deb84..0000000 --- a/tensorflow-examples-legacy/label_image/download_sample_images.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash -DIR=$(dirname $0) -mkdir -p ${DIR}/images -cd ${DIR}/images - -# Some random images -curl -o "porcupine.jpg" -L "https://cdn.pixabay.com/photo/2014/11/06/12/46/porcupines-519145_960_720.jpg" -curl -o "whale.jpg" -L "https://static.pexels.com/photos/417196/pexels-photo-417196.jpeg" -curl -o "terrier1u.jpg" -L "https://upload.wikimedia.org/wikipedia/commons/3/34/Australian_Terrier_Melly_%282%29.JPG" -curl -o "terrier2.jpg" -L "https://cdn.pixabay.com/photo/2014/05/13/07/44/yorkshire-terrier-343198_960_720.jpg" diff --git a/tensorflow-examples-legacy/label_image/pom.xml b/tensorflow-examples-legacy/label_image/pom.xml deleted file mode 100644 index 96ace38..0000000 --- a/tensorflow-examples-legacy/label_image/pom.xml +++ /dev/null @@ -1,26 +0,0 @@ - - 4.0.0 - org.myorg - label-image - 1.0-SNAPSHOT - - LabelImage - - - 1.7 - 1.7 - - - - org.tensorflow - tensorflow - 1.4.0 - - - - com.google.guava - guava - 23.6-jre - - - diff --git a/tensorflow-examples-legacy/label_image/src/main/java/LabelImage.java b/tensorflow-examples-legacy/label_image/src/main/java/LabelImage.java deleted file mode 100644 index 1bcd906..0000000 --- a/tensorflow-examples-legacy/label_image/src/main/java/LabelImage.java +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import com.google.common.io.ByteStreams; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.List; -import org.tensorflow.Graph; -import org.tensorflow.Session; -import org.tensorflow.Tensor; -import org.tensorflow.Tensors; - -/** - * Simplified version of - * https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java - */ -public class LabelImage { - public static void main(String[] args) throws Exception { - if (args.length < 1) { - System.err.println("USAGE: Provide a list of image filenames"); - System.exit(1); - } - final List labels = loadLabels(); - try (Graph graph = new Graph(); - Session session = new Session(graph)) { - graph.importGraphDef(loadGraphDef()); - - float[] probabilities = null; - for (String filename : args) { - byte[] bytes = Files.readAllBytes(Paths.get(filename)); - try (Tensor input = Tensors.create(bytes); - Tensor output = - session - .runner() - .feed("encoded_image_bytes", input) - .fetch("probabilities") - .run() - .get(0) - .expect(Float.class)) { - if (probabilities == null) { - probabilities = new float[(int) output.shape()[0]]; - } - output.copyTo(probabilities); - int label = argmax(probabilities); - System.out.printf( - "%-30s --> %-15s (%.2f%% likely)\n", - filename, labels.get(label), probabilities[label] * 100.0); - } - } - } - } - - private static byte[] loadGraphDef() throws IOException { - try (InputStream is = LabelImage.class.getClassLoader().getResourceAsStream("graph.pb")) { - return ByteStreams.toByteArray(is); - } - } - - private static ArrayList loadLabels() throws IOException { - ArrayList labels = new ArrayList(); - String line; - final InputStream is = LabelImage.class.getClassLoader().getResourceAsStream("labels.txt"); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(is))) { - while ((line = reader.readLine()) != null) { - labels.add(line); - } - } - return labels; - } - - private static int argmax(float[] probabilities) { - int best = 0; - for (int i = 1; i < probabilities.length; ++i) { - if (probabilities[i] > probabilities[best]) { - best = i; - } - } - return best; - } -} diff --git a/tensorflow-examples-legacy/object_detection/.gitignore b/tensorflow-examples-legacy/object_detection/.gitignore deleted file mode 100644 index 8149788..0000000 --- a/tensorflow-examples-legacy/object_detection/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -images -labels -models -src/main/protobuf -target diff --git a/tensorflow-examples-legacy/object_detection/README.md b/tensorflow-examples-legacy/object_detection/README.md deleted file mode 100644 index 3bb554a..0000000 --- a/tensorflow-examples-legacy/object_detection/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# Object Detection in Java - -Example of using pre-trained models of the [TensorFlow Object Detection -API](https://github.com/tensorflow/models/tree/master/research/object_detection) -in Java. - -## Quickstart - -1. Download some metadata files: - ``` - ./download.sh - ``` - -2. Download a model from the [object detection API model - zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md). - For example: - ``` - mkdir -p models - curl -L \ - http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz \ - | tar -xz -C models/ - ``` - -3. Have some test images handy. For example: - ``` - mkdir -p images - curl -L -o images/test.jpg \ - https://pixnio.com/free-images/people/mother-father-and-children-washing-dog-labrador-retriever-outside-in-the-fresh-air-725x483.jpg - ``` - -4. Compile and run! - ``` - mvn -q compile exec:java \ - -Dexec.args="models/ssd_inception_v2_coco_2017_11_17/saved_model labels/mscoco_label_map.pbtxt images/test.jpg" - ``` - -## Notes - -- This example demonstrates the use of the TensorFlow [SavedModel - format](https://www.tensorflow.org/guide/saved_model). If you have - TensorFlow for Python installed, you could explore the model to get the names - of the tensors using `saved_model_cli` command. For example: - ``` - saved_model_cli show --dir models/ssd_inception_v2_coco_2017_11_17/saved_model/ --all - ``` - -- The file in `src/main/object_detection/protos/` was generated using: - - ``` - ./download.sh - protoc -Isrc/main/protobuf --java_out=src/main/java src/main/protobuf/string_int_label_map.proto - ``` - - Where `protoc` was downloaded from - https://github.com/google/protobuf/releases/tag/v3.5.1 diff --git a/tensorflow-examples-legacy/object_detection/download.sh b/tensorflow-examples-legacy/object_detection/download.sh deleted file mode 100755 index f301af2..0000000 --- a/tensorflow-examples-legacy/object_detection/download.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -set -ex - -DIR="$(cd "$(dirname "$0")" && pwd -P)" -cd "${DIR}" - -# The protobuf file needed for mapping labels to human readable names. -# From: -# https://github.com/tensorflow/models/blob/f87a58c/research/object_detection/protos/string_int_label_map.proto -mkdir -p src/main/protobuf -curl -L -o src/main/protobuf/string_int_label_map.proto "https://raw.githubusercontent.com/tensorflow/models/f87a58cd96d45de73c9a8330a06b2ab56749a7fa/research/object_detection/protos/string_int_label_map.proto" - -# Labels from: -# https://github.com/tensorflow/models/tree/865c14c/research/object_detection/data -mkdir -p labels -curl -L -o labels/mscoco_label_map.pbtxt "https://raw.githubusercontent.com/tensorflow/models/865c14c1209cb9ae188b2a1b5f0883c72e050d4c/research/object_detection/data/mscoco_label_map.pbtxt" -curl -L -o labels/oid_bbox_trainable_label_map.pbtxt "https://raw.githubusercontent.com/tensorflow/models/865c14c1209cb9ae188b2a1b5f0883c72e050d4c/research/object_detection/data/oid_bbox_trainable_label_map.pbtxt" diff --git a/tensorflow-examples-legacy/object_detection/pom.xml b/tensorflow-examples-legacy/object_detection/pom.xml deleted file mode 100644 index c9123da..0000000 --- a/tensorflow-examples-legacy/object_detection/pom.xml +++ /dev/null @@ -1,25 +0,0 @@ - - 4.0.0 - org.myorg - detect-objects - 1.0-SNAPSHOT - - DetectObjects - - - 1.7 - 1.7 - - - - org.tensorflow - tensorflow - 1.4.0 - - - org.tensorflow - proto - 1.4.0 - - - diff --git a/tensorflow-examples-legacy/object_detection/src/main/java/DetectObjects.java b/tensorflow-examples-legacy/object_detection/src/main/java/DetectObjects.java deleted file mode 100644 index 6f74240..0000000 --- a/tensorflow-examples-legacy/object_detection/src/main/java/DetectObjects.java +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap; -import static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem; - -import com.google.protobuf.TextFormat; -import java.awt.Graphics2D; -import java.awt.image.BufferedImage; -import java.awt.image.DataBufferByte; -import java.io.File; -import java.io.IOException; -import java.io.PrintStream; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.util.List; -import java.util.Map; -import javax.imageio.ImageIO; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Tensor; -import org.tensorflow.framework.MetaGraphDef; -import org.tensorflow.framework.SignatureDef; -import org.tensorflow.framework.TensorInfo; -import org.tensorflow.types.UInt8; - -/** - * Java inference for the Object Detection API at: - * https://github.com/tensorflow/models/blob/master/research/object_detection/ - */ -public class DetectObjects { - public static void main(String[] args) throws Exception { - if (args.length < 3) { - printUsage(System.err); - System.exit(1); - } - final String[] labels = loadLabels(args[1]); - try (SavedModelBundle model = SavedModelBundle.load(args[0], "serve")) { - printSignature(model); - for (int arg = 2; arg < args.length; arg++) { - final String filename = args[arg]; - List> outputs = null; - try (Tensor input = makeImageTensor(filename)) { - outputs = - model - .session() - .runner() - .feed("image_tensor", input) - .fetch("detection_scores") - .fetch("detection_classes") - .fetch("detection_boxes") - .run(); - } - try (Tensor scoresT = outputs.get(0).expect(Float.class); - Tensor classesT = outputs.get(1).expect(Float.class); - Tensor boxesT = outputs.get(2).expect(Float.class)) { - // All these tensors have: - // - 1 as the first dimension - // - maxObjects as the second dimension - // While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates). - // This can be verified by looking at scoresT.shape() etc. - int maxObjects = (int) scoresT.shape()[1]; - float[] scores = scoresT.copyTo(new float[1][maxObjects])[0]; - float[] classes = classesT.copyTo(new float[1][maxObjects])[0]; - float[][] boxes = boxesT.copyTo(new float[1][maxObjects][4])[0]; - // Print all objects whose score is at least 0.5. - System.out.printf("* %s\n", filename); - boolean foundSomething = false; - for (int i = 0; i < scores.length; ++i) { - if (scores[i] < 0.5) { - continue; - } - foundSomething = true; - System.out.printf("\tFound %-20s (score: %.4f)\n", labels[(int) classes[i]], scores[i]); - } - if (!foundSomething) { - System.out.println("No objects detected with a high enough score."); - } - } - } - } - } - - private static void printSignature(SavedModelBundle model) throws Exception { - MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef()); - SignatureDef sig = m.getSignatureDefOrThrow("serving_default"); - int numInputs = sig.getInputsCount(); - int i = 1; - System.out.println("MODEL SIGNATURE"); - System.out.println("Inputs:"); - for (Map.Entry entry : sig.getInputsMap().entrySet()) { - TensorInfo t = entry.getValue(); - System.out.printf( - "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n", - i++, numInputs, entry.getKey(), t.getName(), t.getDtype()); - } - int numOutputs = sig.getOutputsCount(); - i = 1; - System.out.println("Outputs:"); - for (Map.Entry entry : sig.getOutputsMap().entrySet()) { - TensorInfo t = entry.getValue(); - System.out.printf( - "%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n", - i++, numOutputs, entry.getKey(), t.getName(), t.getDtype()); - } - System.out.println("-----------------------------------------------"); - } - - private static String[] loadLabels(String filename) throws Exception { - String text = new String(Files.readAllBytes(Paths.get(filename)), StandardCharsets.UTF_8); - StringIntLabelMap.Builder builder = StringIntLabelMap.newBuilder(); - TextFormat.merge(text, builder); - StringIntLabelMap proto = builder.build(); - int maxId = 0; - for (StringIntLabelMapItem item : proto.getItemList()) { - if (item.getId() > maxId) { - maxId = item.getId(); - } - } - String[] ret = new String[maxId + 1]; - for (StringIntLabelMapItem item : proto.getItemList()) { - ret[item.getId()] = item.getDisplayName(); - } - return ret; - } - - private static void bgr2rgb(byte[] data) { - for (int i = 0; i < data.length; i += 3) { - byte tmp = data[i]; - data[i] = data[i + 2]; - data[i + 2] = tmp; - } - } - - private static Tensor makeImageTensor(String filename) throws IOException { - BufferedImage img = ImageIO.read(new File(filename)); - if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) { - BufferedImage newImage = new BufferedImage( - img.getWidth(), img.getHeight(), BufferedImage.TYPE_3BYTE_BGR); - Graphics2D g = newImage.createGraphics(); - g.drawImage(img, 0, 0, img.getWidth(), img.getHeight(), null); - g.dispose(); - img = newImage; - } - - byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData(); - // ImageIO.read seems to produce BGR-encoded images, but the model expects RGB. - bgr2rgb(data); - final long BATCH_SIZE = 1; - final long CHANNELS = 3; - long[] shape = new long[] {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS}; - return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data)); - } - - private static void printUsage(PrintStream s) { - s.println("USAGE: [] []"); - s.println(""); - s.println("Where"); - s.println(" is the path to the SavedModel directory of the model to use."); - s.println(" For example, the saved_model directory in tarballs from "); - s.println( - " https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)"); - s.println(""); - s.println( - " is the path to a file containing information about the labels detected by the model."); - s.println(" For example, one of the .pbtxt files from "); - s.println( - " https://github.com/tensorflow/models/tree/master/research/object_detection/data"); - s.println(""); - s.println(" is the path to an image file."); - s.println(" Sample images can be found from the COCO, Kitti, or Open Images dataset."); - s.println( - " See: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md"); - } -} diff --git a/tensorflow-examples-legacy/object_detection/src/main/java/object_detection/protos/StringIntLabelMapOuterClass.java b/tensorflow-examples-legacy/object_detection/src/main/java/object_detection/protos/StringIntLabelMapOuterClass.java deleted file mode 100644 index a6808ce..0000000 --- a/tensorflow-examples-legacy/object_detection/src/main/java/object_detection/protos/StringIntLabelMapOuterClass.java +++ /dev/null @@ -1,1785 +0,0 @@ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: string_int_label_map.proto - -package object_detection.protos; - -public final class StringIntLabelMapOuterClass { - private StringIntLabelMapOuterClass() {} - public static void registerAllExtensions( - com.google.protobuf.ExtensionRegistryLite registry) { - } - - public static void registerAllExtensions( - com.google.protobuf.ExtensionRegistry registry) { - registerAllExtensions( - (com.google.protobuf.ExtensionRegistryLite) registry); - } - public interface StringIntLabelMapItemOrBuilder extends - // @@protoc_insertion_point(interface_extends:object_detection.protos.StringIntLabelMapItem) - com.google.protobuf.MessageOrBuilder { - - /** - *

-     * String name. The most common practice is to set this to a MID or synsets
-     * id.
-     * 
- * - * optional string name = 1; - */ - boolean hasName(); - /** - *
-     * String name. The most common practice is to set this to a MID or synsets
-     * id.
-     * 
- * - * optional string name = 1; - */ - java.lang.String getName(); - /** - *
-     * String name. The most common practice is to set this to a MID or synsets
-     * id.
-     * 
- * - * optional string name = 1; - */ - com.google.protobuf.ByteString - getNameBytes(); - - /** - *
-     * Integer id that maps to the string name above. Label ids should start from
-     * 1.
-     * 
- * - * optional int32 id = 2; - */ - boolean hasId(); - /** - *
-     * Integer id that maps to the string name above. Label ids should start from
-     * 1.
-     * 
- * - * optional int32 id = 2; - */ - int getId(); - - /** - *
-     * Human readable string label.
-     * 
- * - * optional string display_name = 3; - */ - boolean hasDisplayName(); - /** - *
-     * Human readable string label.
-     * 
- * - * optional string display_name = 3; - */ - java.lang.String getDisplayName(); - /** - *
-     * Human readable string label.
-     * 
- * - * optional string display_name = 3; - */ - com.google.protobuf.ByteString - getDisplayNameBytes(); - } - /** - * Protobuf type {@code object_detection.protos.StringIntLabelMapItem} - */ - public static final class StringIntLabelMapItem extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:object_detection.protos.StringIntLabelMapItem) - StringIntLabelMapItemOrBuilder { - private static final long serialVersionUID = 0L; - // Use StringIntLabelMapItem.newBuilder() to construct. - private StringIntLabelMapItem(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } - private StringIntLabelMapItem() { - name_ = ""; - id_ = 0; - displayName_ = ""; - } - - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } - private StringIntLabelMapItem( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - case 10: { - com.google.protobuf.ByteString bs = input.readBytes(); - bitField0_ |= 0x00000001; - name_ = bs; - break; - } - case 16: { - bitField0_ |= 0x00000002; - id_ = input.readInt32(); - break; - } - case 26: { - com.google.protobuf.ByteString bs = input.readBytes(); - bitField0_ |= 0x00000004; - displayName_ = bs; - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return object_detection.protos.StringIntLabelMapOuterClass.internal_static_object_detection_protos_StringIntLabelMapItem_descriptor; - } - - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return object_detection.protos.StringIntLabelMapOuterClass.internal_static_object_detection_protos_StringIntLabelMapItem_fieldAccessorTable - .ensureFieldAccessorsInitialized( - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.class, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder.class); - } - - private int bitField0_; - public static final int NAME_FIELD_NUMBER = 1; - private volatile java.lang.Object name_; - /** - *
-     * String name. The most common practice is to set this to a MID or synsets
-     * id.
-     * 
- * - * optional string name = 1; - */ - public boolean hasName() { - return ((bitField0_ & 0x00000001) == 0x00000001); - } - /** - *
-     * String name. The most common practice is to set this to a MID or synsets
-     * id.
-     * 
- * - * optional string name = 1; - */ - public java.lang.String getName() { - java.lang.Object ref = name_; - if (ref instanceof java.lang.String) { - return (java.lang.String) ref; - } else { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - if (bs.isValidUtf8()) { - name_ = s; - } - return s; - } - } - /** - *
-     * String name. The most common practice is to set this to a MID or synsets
-     * id.
-     * 
- * - * optional string name = 1; - */ - public com.google.protobuf.ByteString - getNameBytes() { - java.lang.Object ref = name_; - if (ref instanceof java.lang.String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - name_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } - } - - public static final int ID_FIELD_NUMBER = 2; - private int id_; - /** - *
-     * Integer id that maps to the string name above. Label ids should start from
-     * 1.
-     * 
- * - * optional int32 id = 2; - */ - public boolean hasId() { - return ((bitField0_ & 0x00000002) == 0x00000002); - } - /** - *
-     * Integer id that maps to the string name above. Label ids should start from
-     * 1.
-     * 
- * - * optional int32 id = 2; - */ - public int getId() { - return id_; - } - - public static final int DISPLAY_NAME_FIELD_NUMBER = 3; - private volatile java.lang.Object displayName_; - /** - *
-     * Human readable string label.
-     * 
- * - * optional string display_name = 3; - */ - public boolean hasDisplayName() { - return ((bitField0_ & 0x00000004) == 0x00000004); - } - /** - *
-     * Human readable string label.
-     * 
- * - * optional string display_name = 3; - */ - public java.lang.String getDisplayName() { - java.lang.Object ref = displayName_; - if (ref instanceof java.lang.String) { - return (java.lang.String) ref; - } else { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - if (bs.isValidUtf8()) { - displayName_ = s; - } - return s; - } - } - /** - *
-     * Human readable string label.
-     * 
- * - * optional string display_name = 3; - */ - public com.google.protobuf.ByteString - getDisplayNameBytes() { - java.lang.Object ref = displayName_; - if (ref instanceof java.lang.String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - displayName_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } - } - - private byte memoizedIsInitialized = -1; - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; - - memoizedIsInitialized = 1; - return true; - } - - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - if (((bitField0_ & 0x00000001) == 0x00000001)) { - com.google.protobuf.GeneratedMessageV3.writeString(output, 1, name_); - } - if (((bitField0_ & 0x00000002) == 0x00000002)) { - output.writeInt32(2, id_); - } - if (((bitField0_ & 0x00000004) == 0x00000004)) { - com.google.protobuf.GeneratedMessageV3.writeString(output, 3, displayName_); - } - unknownFields.writeTo(output); - } - - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; - - size = 0; - if (((bitField0_ & 0x00000001) == 0x00000001)) { - size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, name_); - } - if (((bitField0_ & 0x00000002) == 0x00000002)) { - size += com.google.protobuf.CodedOutputStream - .computeInt32Size(2, id_); - } - if (((bitField0_ & 0x00000004) == 0x00000004)) { - size += com.google.protobuf.GeneratedMessageV3.computeStringSize(3, displayName_); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } - - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem)) { - return super.equals(obj); - } - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem other = (object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem) obj; - - boolean result = true; - result = result && (hasName() == other.hasName()); - if (hasName()) { - result = result && getName() - .equals(other.getName()); - } - result = result && (hasId() == other.hasId()); - if (hasId()) { - result = result && (getId() - == other.getId()); - } - result = result && (hasDisplayName() == other.hasDisplayName()); - if (hasDisplayName()) { - result = result && getDisplayName() - .equals(other.getDisplayName()); - } - result = result && unknownFields.equals(other.unknownFields); - return result; - } - - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - if (hasName()) { - hash = (37 * hash) + NAME_FIELD_NUMBER; - hash = (53 * hash) + getName().hashCode(); - } - if (hasId()) { - hash = (37 * hash) + ID_FIELD_NUMBER; - hash = (53 * hash) + getId(); - } - if (hasDisplayName()) { - hash = (37 * hash) + DISPLAY_NAME_FIELD_NUMBER; - hash = (53 * hash) + getDisplayName().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } - - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - public Builder newBuilderForType() { return newBuilder(); } - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } - public static Builder newBuilder(object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } - - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } - /** - * Protobuf type {@code object_detection.protos.StringIntLabelMapItem} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:object_detection.protos.StringIntLabelMapItem) - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItemOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return object_detection.protos.StringIntLabelMapOuterClass.internal_static_object_detection_protos_StringIntLabelMapItem_descriptor; - } - - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return object_detection.protos.StringIntLabelMapOuterClass.internal_static_object_detection_protos_StringIntLabelMapItem_fieldAccessorTable - .ensureFieldAccessorsInitialized( - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.class, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder.class); - } - - // Construct using object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } - - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } - private void maybeForceBuilderInitialization() { - if (com.google.protobuf.GeneratedMessageV3 - .alwaysUseFieldBuilders) { - } - } - public Builder clear() { - super.clear(); - name_ = ""; - bitField0_ = (bitField0_ & ~0x00000001); - id_ = 0; - bitField0_ = (bitField0_ & ~0x00000002); - displayName_ = ""; - bitField0_ = (bitField0_ & ~0x00000004); - return this; - } - - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return object_detection.protos.StringIntLabelMapOuterClass.internal_static_object_detection_protos_StringIntLabelMapItem_descriptor; - } - - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem getDefaultInstanceForType() { - return object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.getDefaultInstance(); - } - - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem build() { - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } - - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem buildPartial() { - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem result = new object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem(this); - int from_bitField0_ = bitField0_; - int to_bitField0_ = 0; - if (((from_bitField0_ & 0x00000001) == 0x00000001)) { - to_bitField0_ |= 0x00000001; - } - result.name_ = name_; - if (((from_bitField0_ & 0x00000002) == 0x00000002)) { - to_bitField0_ |= 0x00000002; - } - result.id_ = id_; - if (((from_bitField0_ & 0x00000004) == 0x00000004)) { - to_bitField0_ |= 0x00000004; - } - result.displayName_ = displayName_; - result.bitField0_ = to_bitField0_; - onBuilt(); - return result; - } - - public Builder clone() { - return (Builder) super.clone(); - } - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return (Builder) super.setField(field, value); - } - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return (Builder) super.clearField(field); - } - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return (Builder) super.clearOneof(oneof); - } - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return (Builder) super.setRepeatedField(field, index, value); - } - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return (Builder) super.addRepeatedField(field, value); - } - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem) { - return mergeFrom((object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem)other); - } else { - super.mergeFrom(other); - return this; - } - } - - public Builder mergeFrom(object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem other) { - if (other == object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.getDefaultInstance()) return this; - if (other.hasName()) { - bitField0_ |= 0x00000001; - name_ = other.name_; - onChanged(); - } - if (other.hasId()) { - setId(other.getId()); - } - if (other.hasDisplayName()) { - bitField0_ |= 0x00000004; - displayName_ = other.displayName_; - onChanged(); - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - public final boolean isInitialized() { - return true; - } - - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } - private int bitField0_; - - private java.lang.Object name_ = ""; - /** - *
-       * String name. The most common practice is to set this to a MID or synsets
-       * id.
-       * 
- * - * optional string name = 1; - */ - public boolean hasName() { - return ((bitField0_ & 0x00000001) == 0x00000001); - } - /** - *
-       * String name. The most common practice is to set this to a MID or synsets
-       * id.
-       * 
- * - * optional string name = 1; - */ - public java.lang.String getName() { - java.lang.Object ref = name_; - if (!(ref instanceof java.lang.String)) { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - if (bs.isValidUtf8()) { - name_ = s; - } - return s; - } else { - return (java.lang.String) ref; - } - } - /** - *
-       * String name. The most common practice is to set this to a MID or synsets
-       * id.
-       * 
- * - * optional string name = 1; - */ - public com.google.protobuf.ByteString - getNameBytes() { - java.lang.Object ref = name_; - if (ref instanceof String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - name_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } - } - /** - *
-       * String name. The most common practice is to set this to a MID or synsets
-       * id.
-       * 
- * - * optional string name = 1; - */ - public Builder setName( - java.lang.String value) { - if (value == null) { - throw new NullPointerException(); - } - bitField0_ |= 0x00000001; - name_ = value; - onChanged(); - return this; - } - /** - *
-       * String name. The most common practice is to set this to a MID or synsets
-       * id.
-       * 
- * - * optional string name = 1; - */ - public Builder clearName() { - bitField0_ = (bitField0_ & ~0x00000001); - name_ = getDefaultInstance().getName(); - onChanged(); - return this; - } - /** - *
-       * String name. The most common practice is to set this to a MID or synsets
-       * id.
-       * 
- * - * optional string name = 1; - */ - public Builder setNameBytes( - com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - bitField0_ |= 0x00000001; - name_ = value; - onChanged(); - return this; - } - - private int id_ ; - /** - *
-       * Integer id that maps to the string name above. Label ids should start from
-       * 1.
-       * 
- * - * optional int32 id = 2; - */ - public boolean hasId() { - return ((bitField0_ & 0x00000002) == 0x00000002); - } - /** - *
-       * Integer id that maps to the string name above. Label ids should start from
-       * 1.
-       * 
- * - * optional int32 id = 2; - */ - public int getId() { - return id_; - } - /** - *
-       * Integer id that maps to the string name above. Label ids should start from
-       * 1.
-       * 
- * - * optional int32 id = 2; - */ - public Builder setId(int value) { - bitField0_ |= 0x00000002; - id_ = value; - onChanged(); - return this; - } - /** - *
-       * Integer id that maps to the string name above. Label ids should start from
-       * 1.
-       * 
- * - * optional int32 id = 2; - */ - public Builder clearId() { - bitField0_ = (bitField0_ & ~0x00000002); - id_ = 0; - onChanged(); - return this; - } - - private java.lang.Object displayName_ = ""; - /** - *
-       * Human readable string label.
-       * 
- * - * optional string display_name = 3; - */ - public boolean hasDisplayName() { - return ((bitField0_ & 0x00000004) == 0x00000004); - } - /** - *
-       * Human readable string label.
-       * 
- * - * optional string display_name = 3; - */ - public java.lang.String getDisplayName() { - java.lang.Object ref = displayName_; - if (!(ref instanceof java.lang.String)) { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - if (bs.isValidUtf8()) { - displayName_ = s; - } - return s; - } else { - return (java.lang.String) ref; - } - } - /** - *
-       * Human readable string label.
-       * 
- * - * optional string display_name = 3; - */ - public com.google.protobuf.ByteString - getDisplayNameBytes() { - java.lang.Object ref = displayName_; - if (ref instanceof String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - displayName_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } - } - /** - *
-       * Human readable string label.
-       * 
- * - * optional string display_name = 3; - */ - public Builder setDisplayName( - java.lang.String value) { - if (value == null) { - throw new NullPointerException(); - } - bitField0_ |= 0x00000004; - displayName_ = value; - onChanged(); - return this; - } - /** - *
-       * Human readable string label.
-       * 
- * - * optional string display_name = 3; - */ - public Builder clearDisplayName() { - bitField0_ = (bitField0_ & ~0x00000004); - displayName_ = getDefaultInstance().getDisplayName(); - onChanged(); - return this; - } - /** - *
-       * Human readable string label.
-       * 
- * - * optional string display_name = 3; - */ - public Builder setDisplayNameBytes( - com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - bitField0_ |= 0x00000004; - displayName_ = value; - onChanged(); - return this; - } - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:object_detection.protos.StringIntLabelMapItem) - } - - // @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) - private static final object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem DEFAULT_INSTANCE; - static { - DEFAULT_INSTANCE = new object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem(); - } - - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem getDefaultInstance() { - return DEFAULT_INSTANCE; - } - - @java.lang.Deprecated public static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - public StringIntLabelMapItem parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new StringIntLabelMapItem(input, extensionRegistry); - } - }; - - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - - } - - public interface StringIntLabelMapOrBuilder extends - // @@protoc_insertion_point(interface_extends:object_detection.protos.StringIntLabelMap) - com.google.protobuf.MessageOrBuilder { - - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - java.util.List - getItemList(); - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem getItem(int index); - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - int getItemCount(); - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - java.util.List - getItemOrBuilderList(); - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItemOrBuilder getItemOrBuilder( - int index); - } - /** - * Protobuf type {@code object_detection.protos.StringIntLabelMap} - */ - public static final class StringIntLabelMap extends - com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:object_detection.protos.StringIntLabelMap) - StringIntLabelMapOrBuilder { - private static final long serialVersionUID = 0L; - // Use StringIntLabelMap.newBuilder() to construct. - private StringIntLabelMap(com.google.protobuf.GeneratedMessageV3.Builder builder) { - super(builder); - } - private StringIntLabelMap() { - item_ = java.util.Collections.emptyList(); - } - - @java.lang.Override - public final com.google.protobuf.UnknownFieldSet - getUnknownFields() { - return this.unknownFields; - } - private StringIntLabelMap( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - this(); - if (extensionRegistry == null) { - throw new java.lang.NullPointerException(); - } - int mutable_bitField0_ = 0; - com.google.protobuf.UnknownFieldSet.Builder unknownFields = - com.google.protobuf.UnknownFieldSet.newBuilder(); - try { - boolean done = false; - while (!done) { - int tag = input.readTag(); - switch (tag) { - case 0: - done = true; - break; - default: { - if (!parseUnknownField( - input, unknownFields, extensionRegistry, tag)) { - done = true; - } - break; - } - case 10: { - if (!((mutable_bitField0_ & 0x00000001) == 0x00000001)) { - item_ = new java.util.ArrayList(); - mutable_bitField0_ |= 0x00000001; - } - item_.add( - input.readMessage(object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.PARSER, extensionRegistry)); - break; - } - } - } - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw e.setUnfinishedMessage(this); - } catch (java.io.IOException e) { - throw new com.google.protobuf.InvalidProtocolBufferException( - e).setUnfinishedMessage(this); - } finally { - if (((mutable_bitField0_ & 0x00000001) == 0x00000001)) { - item_ = java.util.Collections.unmodifiableList(item_); - } - this.unknownFields = unknownFields.build(); - makeExtensionsImmutable(); - } - } - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return object_detection.protos.StringIntLabelMapOuterClass.internal_static_object_detection_protos_StringIntLabelMap_descriptor; - } - - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return object_detection.protos.StringIntLabelMapOuterClass.internal_static_object_detection_protos_StringIntLabelMap_fieldAccessorTable - .ensureFieldAccessorsInitialized( - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap.class, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap.Builder.class); - } - - public static final int ITEM_FIELD_NUMBER = 1; - private java.util.List item_; - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public java.util.List getItemList() { - return item_; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public java.util.List - getItemOrBuilderList() { - return item_; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public int getItemCount() { - return item_.size(); - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem getItem(int index) { - return item_.get(index); - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItemOrBuilder getItemOrBuilder( - int index) { - return item_.get(index); - } - - private byte memoizedIsInitialized = -1; - public final boolean isInitialized() { - byte isInitialized = memoizedIsInitialized; - if (isInitialized == 1) return true; - if (isInitialized == 0) return false; - - memoizedIsInitialized = 1; - return true; - } - - public void writeTo(com.google.protobuf.CodedOutputStream output) - throws java.io.IOException { - for (int i = 0; i < item_.size(); i++) { - output.writeMessage(1, item_.get(i)); - } - unknownFields.writeTo(output); - } - - public int getSerializedSize() { - int size = memoizedSize; - if (size != -1) return size; - - size = 0; - for (int i = 0; i < item_.size(); i++) { - size += com.google.protobuf.CodedOutputStream - .computeMessageSize(1, item_.get(i)); - } - size += unknownFields.getSerializedSize(); - memoizedSize = size; - return size; - } - - @java.lang.Override - public boolean equals(final java.lang.Object obj) { - if (obj == this) { - return true; - } - if (!(obj instanceof object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap)) { - return super.equals(obj); - } - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap other = (object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap) obj; - - boolean result = true; - result = result && getItemList() - .equals(other.getItemList()); - result = result && unknownFields.equals(other.unknownFields); - return result; - } - - @java.lang.Override - public int hashCode() { - if (memoizedHashCode != 0) { - return memoizedHashCode; - } - int hash = 41; - hash = (19 * hash) + getDescriptor().hashCode(); - if (getItemCount() > 0) { - hash = (37 * hash) + ITEM_FIELD_NUMBER; - hash = (53 * hash) + getItemList().hashCode(); - } - hash = (29 * hash) + unknownFields.hashCode(); - memoizedHashCode = hash; - return hash; - } - - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseFrom( - java.nio.ByteBuffer data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseFrom( - java.nio.ByteBuffer data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseFrom( - com.google.protobuf.ByteString data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseFrom( - com.google.protobuf.ByteString data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseFrom(byte[] data) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseFrom( - byte[] data, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return PARSER.parseFrom(data, extensionRegistry); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseDelimitedFrom(java.io.InputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseDelimitedFrom( - java.io.InputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseDelimitedWithIOException(PARSER, input, extensionRegistry); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseFrom( - com.google.protobuf.CodedInputStream input) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input); - } - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parseFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - return com.google.protobuf.GeneratedMessageV3 - .parseWithIOException(PARSER, input, extensionRegistry); - } - - public Builder newBuilderForType() { return newBuilder(); } - public static Builder newBuilder() { - return DEFAULT_INSTANCE.toBuilder(); - } - public static Builder newBuilder(object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap prototype) { - return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); - } - public Builder toBuilder() { - return this == DEFAULT_INSTANCE - ? new Builder() : new Builder().mergeFrom(this); - } - - @java.lang.Override - protected Builder newBuilderForType( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - Builder builder = new Builder(parent); - return builder; - } - /** - * Protobuf type {@code object_detection.protos.StringIntLabelMap} - */ - public static final class Builder extends - com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:object_detection.protos.StringIntLabelMap) - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapOrBuilder { - public static final com.google.protobuf.Descriptors.Descriptor - getDescriptor() { - return object_detection.protos.StringIntLabelMapOuterClass.internal_static_object_detection_protos_StringIntLabelMap_descriptor; - } - - protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internalGetFieldAccessorTable() { - return object_detection.protos.StringIntLabelMapOuterClass.internal_static_object_detection_protos_StringIntLabelMap_fieldAccessorTable - .ensureFieldAccessorsInitialized( - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap.class, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap.Builder.class); - } - - // Construct using object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap.newBuilder() - private Builder() { - maybeForceBuilderInitialization(); - } - - private Builder( - com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { - super(parent); - maybeForceBuilderInitialization(); - } - private void maybeForceBuilderInitialization() { - if (com.google.protobuf.GeneratedMessageV3 - .alwaysUseFieldBuilders) { - getItemFieldBuilder(); - } - } - public Builder clear() { - super.clear(); - if (itemBuilder_ == null) { - item_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000001); - } else { - itemBuilder_.clear(); - } - return this; - } - - public com.google.protobuf.Descriptors.Descriptor - getDescriptorForType() { - return object_detection.protos.StringIntLabelMapOuterClass.internal_static_object_detection_protos_StringIntLabelMap_descriptor; - } - - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap getDefaultInstanceForType() { - return object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap.getDefaultInstance(); - } - - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap build() { - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap result = buildPartial(); - if (!result.isInitialized()) { - throw newUninitializedMessageException(result); - } - return result; - } - - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap buildPartial() { - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap result = new object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap(this); - int from_bitField0_ = bitField0_; - if (itemBuilder_ == null) { - if (((bitField0_ & 0x00000001) == 0x00000001)) { - item_ = java.util.Collections.unmodifiableList(item_); - bitField0_ = (bitField0_ & ~0x00000001); - } - result.item_ = item_; - } else { - result.item_ = itemBuilder_.build(); - } - onBuilt(); - return result; - } - - public Builder clone() { - return (Builder) super.clone(); - } - public Builder setField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return (Builder) super.setField(field, value); - } - public Builder clearField( - com.google.protobuf.Descriptors.FieldDescriptor field) { - return (Builder) super.clearField(field); - } - public Builder clearOneof( - com.google.protobuf.Descriptors.OneofDescriptor oneof) { - return (Builder) super.clearOneof(oneof); - } - public Builder setRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - int index, java.lang.Object value) { - return (Builder) super.setRepeatedField(field, index, value); - } - public Builder addRepeatedField( - com.google.protobuf.Descriptors.FieldDescriptor field, - java.lang.Object value) { - return (Builder) super.addRepeatedField(field, value); - } - public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap) { - return mergeFrom((object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap)other); - } else { - super.mergeFrom(other); - return this; - } - } - - public Builder mergeFrom(object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap other) { - if (other == object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap.getDefaultInstance()) return this; - if (itemBuilder_ == null) { - if (!other.item_.isEmpty()) { - if (item_.isEmpty()) { - item_ = other.item_; - bitField0_ = (bitField0_ & ~0x00000001); - } else { - ensureItemIsMutable(); - item_.addAll(other.item_); - } - onChanged(); - } - } else { - if (!other.item_.isEmpty()) { - if (itemBuilder_.isEmpty()) { - itemBuilder_.dispose(); - itemBuilder_ = null; - item_ = other.item_; - bitField0_ = (bitField0_ & ~0x00000001); - itemBuilder_ = - com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ? - getItemFieldBuilder() : null; - } else { - itemBuilder_.addAllMessages(other.item_); - } - } - } - this.mergeUnknownFields(other.unknownFields); - onChanged(); - return this; - } - - public final boolean isInitialized() { - return true; - } - - public Builder mergeFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws java.io.IOException { - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap parsedMessage = null; - try { - parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); - } catch (com.google.protobuf.InvalidProtocolBufferException e) { - parsedMessage = (object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap) e.getUnfinishedMessage(); - throw e.unwrapIOException(); - } finally { - if (parsedMessage != null) { - mergeFrom(parsedMessage); - } - } - return this; - } - private int bitField0_; - - private java.util.List item_ = - java.util.Collections.emptyList(); - private void ensureItemIsMutable() { - if (!((bitField0_ & 0x00000001) == 0x00000001)) { - item_ = new java.util.ArrayList(item_); - bitField0_ |= 0x00000001; - } - } - - private com.google.protobuf.RepeatedFieldBuilderV3< - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItemOrBuilder> itemBuilder_; - - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public java.util.List getItemList() { - if (itemBuilder_ == null) { - return java.util.Collections.unmodifiableList(item_); - } else { - return itemBuilder_.getMessageList(); - } - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public int getItemCount() { - if (itemBuilder_ == null) { - return item_.size(); - } else { - return itemBuilder_.getCount(); - } - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem getItem(int index) { - if (itemBuilder_ == null) { - return item_.get(index); - } else { - return itemBuilder_.getMessage(index); - } - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public Builder setItem( - int index, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem value) { - if (itemBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureItemIsMutable(); - item_.set(index, value); - onChanged(); - } else { - itemBuilder_.setMessage(index, value); - } - return this; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public Builder setItem( - int index, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder builderForValue) { - if (itemBuilder_ == null) { - ensureItemIsMutable(); - item_.set(index, builderForValue.build()); - onChanged(); - } else { - itemBuilder_.setMessage(index, builderForValue.build()); - } - return this; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public Builder addItem(object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem value) { - if (itemBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureItemIsMutable(); - item_.add(value); - onChanged(); - } else { - itemBuilder_.addMessage(value); - } - return this; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public Builder addItem( - int index, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem value) { - if (itemBuilder_ == null) { - if (value == null) { - throw new NullPointerException(); - } - ensureItemIsMutable(); - item_.add(index, value); - onChanged(); - } else { - itemBuilder_.addMessage(index, value); - } - return this; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public Builder addItem( - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder builderForValue) { - if (itemBuilder_ == null) { - ensureItemIsMutable(); - item_.add(builderForValue.build()); - onChanged(); - } else { - itemBuilder_.addMessage(builderForValue.build()); - } - return this; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public Builder addItem( - int index, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder builderForValue) { - if (itemBuilder_ == null) { - ensureItemIsMutable(); - item_.add(index, builderForValue.build()); - onChanged(); - } else { - itemBuilder_.addMessage(index, builderForValue.build()); - } - return this; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public Builder addAllItem( - java.lang.Iterable values) { - if (itemBuilder_ == null) { - ensureItemIsMutable(); - com.google.protobuf.AbstractMessageLite.Builder.addAll( - values, item_); - onChanged(); - } else { - itemBuilder_.addAllMessages(values); - } - return this; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public Builder clearItem() { - if (itemBuilder_ == null) { - item_ = java.util.Collections.emptyList(); - bitField0_ = (bitField0_ & ~0x00000001); - onChanged(); - } else { - itemBuilder_.clear(); - } - return this; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public Builder removeItem(int index) { - if (itemBuilder_ == null) { - ensureItemIsMutable(); - item_.remove(index); - onChanged(); - } else { - itemBuilder_.remove(index); - } - return this; - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder getItemBuilder( - int index) { - return getItemFieldBuilder().getBuilder(index); - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItemOrBuilder getItemOrBuilder( - int index) { - if (itemBuilder_ == null) { - return item_.get(index); } else { - return itemBuilder_.getMessageOrBuilder(index); - } - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public java.util.List - getItemOrBuilderList() { - if (itemBuilder_ != null) { - return itemBuilder_.getMessageOrBuilderList(); - } else { - return java.util.Collections.unmodifiableList(item_); - } - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder addItemBuilder() { - return getItemFieldBuilder().addBuilder( - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.getDefaultInstance()); - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder addItemBuilder( - int index) { - return getItemFieldBuilder().addBuilder( - index, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.getDefaultInstance()); - } - /** - * repeated .object_detection.protos.StringIntLabelMapItem item = 1; - */ - public java.util.List - getItemBuilderList() { - return getItemFieldBuilder().getBuilderList(); - } - private com.google.protobuf.RepeatedFieldBuilderV3< - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItemOrBuilder> - getItemFieldBuilder() { - if (itemBuilder_ == null) { - itemBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3< - object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItem.Builder, object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMapItemOrBuilder>( - item_, - ((bitField0_ & 0x00000001) == 0x00000001), - getParentForChildren(), - isClean()); - item_ = null; - } - return itemBuilder_; - } - public final Builder setUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.setUnknownFields(unknownFields); - } - - public final Builder mergeUnknownFields( - final com.google.protobuf.UnknownFieldSet unknownFields) { - return super.mergeUnknownFields(unknownFields); - } - - - // @@protoc_insertion_point(builder_scope:object_detection.protos.StringIntLabelMap) - } - - // @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) - private static final object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap DEFAULT_INSTANCE; - static { - DEFAULT_INSTANCE = new object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap(); - } - - public static object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap getDefaultInstance() { - return DEFAULT_INSTANCE; - } - - @java.lang.Deprecated public static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { - public StringIntLabelMap parsePartialFrom( - com.google.protobuf.CodedInputStream input, - com.google.protobuf.ExtensionRegistryLite extensionRegistry) - throws com.google.protobuf.InvalidProtocolBufferException { - return new StringIntLabelMap(input, extensionRegistry); - } - }; - - public static com.google.protobuf.Parser parser() { - return PARSER; - } - - @java.lang.Override - public com.google.protobuf.Parser getParserForType() { - return PARSER; - } - - public object_detection.protos.StringIntLabelMapOuterClass.StringIntLabelMap getDefaultInstanceForType() { - return DEFAULT_INSTANCE; - } - - } - - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_object_detection_protos_StringIntLabelMapItem_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_object_detection_protos_StringIntLabelMapItem_fieldAccessorTable; - private static final com.google.protobuf.Descriptors.Descriptor - internal_static_object_detection_protos_StringIntLabelMap_descriptor; - private static final - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable - internal_static_object_detection_protos_StringIntLabelMap_fieldAccessorTable; - - public static com.google.protobuf.Descriptors.FileDescriptor - getDescriptor() { - return descriptor; - } - private static com.google.protobuf.Descriptors.FileDescriptor - descriptor; - static { - java.lang.String[] descriptorData = { - "\n\032string_int_label_map.proto\022\027object_det" + - "ection.protos\"G\n\025StringIntLabelMapItem\022\014" + - "\n\004name\030\001 \001(\t\022\n\n\002id\030\002 \001(\005\022\024\n\014display_name" + - "\030\003 \001(\t\"Q\n\021StringIntLabelMap\022<\n\004item\030\001 \003(" + - "\0132..object_detection.protos.StringIntLab" + - "elMapItem" - }; - com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = - new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { - public com.google.protobuf.ExtensionRegistry assignDescriptors( - com.google.protobuf.Descriptors.FileDescriptor root) { - descriptor = root; - return null; - } - }; - com.google.protobuf.Descriptors.FileDescriptor - .internalBuildGeneratedFileFrom(descriptorData, - new com.google.protobuf.Descriptors.FileDescriptor[] { - }, assigner); - internal_static_object_detection_protos_StringIntLabelMapItem_descriptor = - getDescriptor().getMessageTypes().get(0); - internal_static_object_detection_protos_StringIntLabelMapItem_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( - internal_static_object_detection_protos_StringIntLabelMapItem_descriptor, - new java.lang.String[] { "Name", "Id", "DisplayName", }); - internal_static_object_detection_protos_StringIntLabelMap_descriptor = - getDescriptor().getMessageTypes().get(1); - internal_static_object_detection_protos_StringIntLabelMap_fieldAccessorTable = new - com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( - internal_static_object_detection_protos_StringIntLabelMap_descriptor, - new java.lang.String[] { "Item", }); - } - - // @@protoc_insertion_point(outer_class_scope) -} diff --git a/tensorflow-examples-legacy/training/.gitignore b/tensorflow-examples-legacy/training/.gitignore deleted file mode 100644 index e8448ec..0000000 --- a/tensorflow-examples-legacy/training/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -target -checkpoint diff --git a/tensorflow-examples-legacy/training/README.md b/tensorflow-examples-legacy/training/README.md deleted file mode 100644 index 29d77af..0000000 --- a/tensorflow-examples-legacy/training/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# Training models in Java - -Example of training a model (and saving and restoring checkpoints) using the -TensorFlow Java API. - -## Quickstart - -1. Train for a few steps: - ``` - mvn -q compile exec:java -Dexec.args="model/graph.pb checkpoint" - ``` - -2. Resume training from previous checkpoint and train some more: - ``` - mvn -q exec:java -Dexec.args="model/graph.pb checkpoint" - ``` - -3. Delete checkpoint: - ``` - rm -rf checkpoint - ``` - - -## Details - -The model in `model/graph.pb` represents a very simple linear model: - -``` -y = x * W + b -``` - -The `graph.pb` file is generated by executing `create_graph.py` in Python. - -The training is orchestrated by `src/main/java/Train.java`, which generates -training data of the form `y = 3.0 * x + 2.0` and over time, using gradient -descent, the model should "learn" and the value of `W` should converge to 3.0, -and `b` to 2.0. diff --git a/tensorflow-examples-legacy/training/model/create_graph.py b/tensorflow-examples-legacy/training/model/create_graph.py deleted file mode 100644 index 7e043a9..0000000 --- a/tensorflow-examples-legacy/training/model/create_graph.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import print_function - -import tensorflow as tf - -x = tf.placeholder(tf.float32, name='input') -y_ = tf.placeholder(tf.float32, name='target') - -W = tf.Variable(5., name='W') -b = tf.Variable(3., name='b') - -y = x * W + b -y = tf.identity(y, name='output') - -loss = tf.reduce_mean(tf.square(y - y_)) -optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) -train_op = optimizer.minimize(loss, name='train') - -init = tf.global_variables_initializer() - -# Creating a tf.train.Saver adds operations to the graph to save and -# restore variables from checkpoints. -saver_def = tf.train.Saver().as_saver_def() - -print('Operation to initialize variables: ', init.name) -print('Tensor to feed as input data: ', x.name) -print('Tensor to feed as training targets: ', y_.name) -print('Tensor to fetch as prediction: ', y.name) -print('Operation to train one step: ', train_op.name) -print('Tensor to be fed for checkpoint filename:', saver_def.filename_tensor_name) -print('Operation to save a checkpoint: ', saver_def.save_tensor_name) -print('Operation to restore a checkpoint: ', saver_def.restore_op_name) -print('Tensor to read value of W ', W.value().name) -print('Tensor to read value of b ', b.value().name) - -with open('graph.pb', 'w') as f: - f.write(tf.get_default_graph().as_graph_def().SerializeToString()) diff --git a/tensorflow-examples-legacy/training/model/graph.pb b/tensorflow-examples-legacy/training/model/graph.pb deleted file mode 100644 index 51d946d..0000000 Binary files a/tensorflow-examples-legacy/training/model/graph.pb and /dev/null differ diff --git a/tensorflow-examples-legacy/training/pom.xml b/tensorflow-examples-legacy/training/pom.xml deleted file mode 100644 index 39dda07..0000000 --- a/tensorflow-examples-legacy/training/pom.xml +++ /dev/null @@ -1,20 +0,0 @@ - - 4.0.0 - org.myorg - training - 1.0-SNAPSHOT - - Train - - - 1.7 - 1.7 - - - - org.tensorflow - tensorflow - 1.4.0 - - - diff --git a/tensorflow-examples-legacy/training/src/main/java/Train.java b/tensorflow-examples-legacy/training/src/main/java/Train.java deleted file mode 100644 index 57176a4..0000000 --- a/tensorflow-examples-legacy/training/src/main/java/Train.java +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import java.nio.file.Files; -import java.nio.file.Paths; -import java.util.List; -import java.util.Random; -import org.tensorflow.Graph; -import org.tensorflow.Session; -import org.tensorflow.Tensor; -import org.tensorflow.Tensors; - -/** - * Training a trivial linear model. - */ -public class Train { - public static void main(String[] args) throws Exception { - if (args.length != 2) { - System.err.println("Require two arguments: The GraphDef file and checkpoint directory"); - System.exit(1); - } - - final byte[] graphDef = Files.readAllBytes(Paths.get(args[0])); - final String checkpointDir = args[1]; - final boolean checkpointExists = Files.exists(Paths.get(checkpointDir)); - - try (Graph graph = new Graph(); - Session sess = new Session(graph); - Tensor checkpointPrefix = - Tensors.create(Paths.get(checkpointDir, "ckpt").toString())) { - graph.importGraphDef(graphDef); - - // Initialize or restore. - // The names of the tensors in the graph are printed out by the program - // that created the graph: - // https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py - if (checkpointExists) { - sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run(); - } else { - sess.runner().addTarget("init").run(); - } - System.out.print("Starting from : "); - printVariables(sess); - - // Train a bunch of times. - // (Will be much more efficient if we sent batches instead of individual values). - final Random r = new Random(); - final int NUM_EXAMPLES = 500; - for (int i = 1; i <= 5; i++) { - for (int n = 0; n < NUM_EXAMPLES; n++) { - float in = r.nextFloat(); - try (Tensor input = Tensors.create(in); - Tensor target = Tensors.create(3 * in + 2)) { - // Again the tensor names are from the program that created the graph. - // https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py - sess.runner().feed("input", input).feed("target", target).addTarget("train").run(); - } - } - System.out.printf("After %5d examples: ", i*NUM_EXAMPLES); - printVariables(sess); - } - - // Checkpoint. - // The feed and target name are from the program that created the graph. - // https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py. - sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/control_dependency").run(); - - // Example of "inference" in the same graph: - try (Tensor input = Tensors.create(1.0f); - Tensor output = - sess.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class)) { - System.out.printf( - "For input %f, produced %f (ideally would produce 3*%f + 2)\n", - input.floatValue(), output.floatValue(), input.floatValue()); - } - } - } - - private static void printVariables(Session sess) { - List> values = sess.runner().fetch("W/read").fetch("b/read").run(); - System.out.printf("W = %f\tb = %f\n", values.get(0).floatValue(), values.get(1).floatValue()); - for (Tensor t : values) { - t.close(); - } - } -} diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatch.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatch.java deleted file mode 100644 index ab168d0..0000000 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/data/ImageBatch.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.tensorflow.model.examples.mnist.data; - -import org.tensorflow.tools.ndarray.ByteNdArray; - -public class ImageBatch { - - public ByteNdArray images() { - return images; - } - - public ByteNdArray labels() { - return labels; - } - - ImageBatch(ByteNdArray images, ByteNdArray labels) { - this.images = images; - this.labels = labels; - } - - private final ByteNdArray images; - private final ByteNdArray labels; -}