# Source code for zoo.examples.imageclassification.predict

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import heapq
import sys
from optparse import OptionParser

from zoo.common.nncontext import init_nncontext
from zoo.models.image.imageclassification import *


[docs]def predict(model_path, img_path, topN, partition_num): print("ImageClassification prediction") print("Model Path %s" % model_path) print("Image Path %s" % img_path) print("Top N : %d" % topN) imc = ImageClassifier.load_model(model_path) image_set = ImageSet.read(img_path, sc, partition_num) output = imc.predict_image_set(image_set) labelMap = imc.get_config().label_map() predicts = output.get_predict().collect() for predict in predicts: (uri, probs) = predict sortedProbs = [(prob, index) for index, prob in enumerate(probs[0])] sortedProbs.sort() print("Image : %s, top %d prediction result" % (uri, topN)) for i in range(topN): print("\t%s, %f" % (labelMap[sortedProbs[999 - i][1]], sortedProbs[999 - i][0]))
if __name__ == "__main__":
    # Command-line entry point: parse options, create the NN/Spark context,
    # run prediction, and always stop the SparkContext afterwards.
    parser = OptionParser()
    parser.add_option("-f", "--folder", type=str, dest="img_path", default=".",
                      help="Path where the images are stored")
    parser.add_option("--model", type=str, dest="model_path", default="",
                      help="Path where the model is stored")
    parser.add_option("--topN", type=int, dest="topN", default=1,
                      help="top N number")
    parser.add_option("--partition_num", type=int, dest="partition_num",
                      default=4, help="The number of partitions")
    # parse_args() reads sys.argv[1:]; passing sys.argv itself would leak the
    # script name into the (unused) positional arguments.
    (options, args) = parser.parse_args()
    if not options.model_path:
        # Fail fast with a usage message rather than letting the model loader
        # fail on an empty path.
        parser.error("--model is required: path where the model is stored")
    # ``sc`` must stay a module-level name: predict() falls back to it.
    sc = init_nncontext("Image Classification Example")
    try:
        predict(options.model_path, options.img_path, options.topN,
                options.partition_num)
        print("finished...")
    finally:
        # Release the SparkContext even if prediction raises.
        sc.stop()