# Source code for zoo.examples.tensorflow.tfnet.predict

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from optparse import OptionParser

from zoo.tfpark import TFNet
from zoo.common.nncontext import init_nncontext
from zoo.feature.common import *
from zoo.models.image.objectdetection import *


def predict(model_path, img_path, partition_num=4, sc=None):
    """Run a TensorFlow object detection model over images and print one result.

    The images are read into an ``ImageSet``, resized to 256x256, converted
    to NHWC tensors and wrapped as samples, then fed through a ``TFNet``
    built from ``model_path``. The prediction for the first image is printed.

    :param model_path: location of the TensorFlow model, passed to ``TFNet``.
    :param img_path: location of the image(s), passed to ``ImageSet.read``.
    :param partition_num: number of partitions passed to ``ImageSet.read``.
    :param sc: context used to read the images. When omitted, the
        module-level ``sc`` (set by the script entry point) is used if it
        exists; otherwise one is obtained from ``init_nncontext``.
    :return: None. The first prediction is written to stdout.
    """
    # The original implementation silently depended on a module-level `sc`,
    # which only exists when this file is executed as a script. Keep that
    # behaviour for existing callers, but do not fail with NameError when the
    # function is imported and called from elsewhere.
    if sc is None:
        sc = globals().get("sc")
    if sc is None:
        sc = init_nncontext("TFNet Object Detection Example")

    # Tensor names of the graph's input and outputs.
    inputs = "image_tensor:0"
    outputs = [
        "num_detections:0",
        "detection_boxes:0",
        "detection_scores:0",
        "detection_classes:0",
    ]

    model = TFNet(model_path, inputs, outputs)
    image_set = ImageSet.read(img_path, sc, partition_num)
    transformer = ChainedPreprocessing([
        ImageResize(256, 256),
        ImageMatToTensor(format="NHWC"),
        ImageSetToSample(),
    ])
    transformed_image_set = image_set.transform(transformer)
    output = model.predict_image(
        transformed_image_set.to_image_frame(),
        batch_per_partition=1,
    )

    # Print the detection result of the first image.
    result = ImageSet.from_image_frame(output).get_predict().first()
    print(result)


if __name__ == "__main__":
    # Imported explicitly: `sys` was previously only available if one of the
    # star imports at the top of the file happened to re-export it.
    import sys

    parser = OptionParser()
    parser.add_option("--image", type=str, dest="img_path",
                      help="The path where the images are stored, "
                           "can be either a folder or an image path")
    parser.add_option("--model", type=str, dest="model_path",
                      help="The path of the TensorFlow object detection model")
    parser.add_option("--partition_num", type=int, dest="partition_num",
                      default=4, help="The number of partitions")

    (options, args) = parser.parse_args(sys.argv)

    # Reject missing paths before starting a context, instead of failing
    # later inside predict() with a None path.
    if not options.img_path or not options.model_path:
        parser.error("both --image and --model must be provided")

    # `sc` must stay a module-level name: predict() reads it from there when
    # it is not given a context explicitly.
    sc = init_nncontext("TFNet Object Detection Example")
    try:
        predict(options.model_path, options.img_path, options.partition_num)
        print("finished...")
    finally:
        # Always stop the context, even when prediction raises.
        sc.stop()