Source code for zoo.orca.data.tf.data

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tensorflow as tf

from zoo.tfpark.tf_dataset import TensorMeta
from zoo.util import nest
from zoo import getOrCreateSparkContext, get_node_and_core_number
from zoo.common import callZooFunc
from zoo.feature.common import FeatureSet
from zoo.orca.data import SparkXShards
from zoo.tfpark import TFDataset


[docs]class TFDataDataset2(TFDataset): def __init__(self, dataset, batch_size, batch_per_thread, validation_dataset=None, intra_threads=None, inter_threads=None): node_num, core_num = get_node_and_core_number() self.intra_threads = intra_threads self.inter_threads = inter_threads if intra_threads is None: self.intra_threads = core_num if inter_threads is None: self.inter_threads = 1 if batch_size > 0: num_parts = dataset.xshards.num_partitions() if num_parts != node_num: dataset.xshards = dataset.xshards.repartition(node_num) assert batch_size % node_num == 0, \ "batch_size should be a multiple of num_shards, got" \ " batch_size {}, node_num {}".format(batch_size, node_num) batch_per_shard = batch_size // node_num self.drop_remainder = True elif batch_per_thread > 0: batch_per_shard = batch_per_thread self.drop_remainder = False else: raise ValueError("one of batch_size or batch_per_thread must be larger than 0") self.rdd = dataset.as_graph_rdd(batch_per_shard, drop_remainder=self.drop_remainder).cache() meta_info = self.rdd.map(lambda x: x[1]).first() tensor_structure = meta_info["tensor_structure"] self.init_op_name = meta_info["init_op_name"] self.output_names = meta_info["output_names"] self.output_types = meta_info["output_types"] self.table_init_op = meta_info["table_init_op"] if validation_dataset is not None: self.val_rdd = validation_dataset.as_graph_rdd(batch_per_shard, False).cache() meta_info = self.val_rdd.map(lambda x: x[1]).first() self.val_init_op_name = meta_info["init_op_name"] self.val_output_names = meta_info["output_names"] self.val_output_types = meta_info["output_types"] else: self.val_rdd = None self.val_init_op_name = None self.val_output_names = None self.val_output_types = None super().__init__(tensor_structure, batch_size=batch_size, batch_per_thread=batch_per_thread, hard_code_batch_size=False) self.shard_index_op_name = None self.validation_dataset = validation_dataset def _get_prediction_data(self): assert not self.drop_remainder, \ "sanity check: drop_remainder should be false in this case," \ " otherwise please report a bug" jvalue = callZooFunc("float", "createMiniBatchRDDFromTFDataset", self.rdd.map(lambda x: x[0]), self.init_op_name, self.table_init_op, self.output_names, self.output_types, self.shard_index_op_name) rdd = jvalue.value().toJavaRDD() return rdd def _get_evaluation_data(self): jvalue = callZooFunc("float", "createMiniBatchRDDFromTFDatasetEval", self.rdd.map(lambda x: x[0]), self.init_op_name, self.table_init_op, self.output_names, self.output_types, self.shard_index_op_name) rdd = jvalue.value().toJavaRDD() return rdd def _get_training_data(self): jvalue = callZooFunc("float", "createTFDataFeatureSet", self.rdd.map(lambda x: x[0]), self.init_op_name, self.table_init_op, self.output_names, self.output_types, self.shard_index_op_name, self.inter_threads, self.intra_threads) return FeatureSet(jvalue=jvalue) def _get_validation_data(self): if self.validation_dataset is not None: jvalue = callZooFunc("float", "createTFDataFeatureSet", self.val_rdd.map(lambda x: x[0]), self.init_op_name, self.table_init_op, self.output_names, self.output_types, self.shard_index_op_name, self.inter_threads, self.intra_threads) return FeatureSet(jvalue=jvalue) return None
[docs] def get_num_partitions(self): return self.rdd.getNumPartitions()
[docs]class Dataset(object): """ Represents a distributed set of elements backed by an RDD, which is created by applying tensorflow dataset transformations on each partitions. """ def __init__(self, xshards, create_dataset_fn): self.xshards = xshards self.create_dataset_fn = create_dataset_fn
[docs] def as_graph_rdd(self, batch_per_shard, drop_remainder=True): create_dataset_fn = self.create_dataset_fn def to_dataset(iter): data_list = list(iter) import tensorflow as tf if not data_list: return [] datasets = [create_dataset_fn(data) for data in data_list] from functools import reduce dataset = reduce(lambda x, y: x.concatenate(y), datasets) dataset = dataset.batch(batch_per_shard, drop_remainder) iterator = dataset.make_initializable_iterator() train_next_ops = nest.flatten(iterator.get_next()) output_types = [t.as_datatype_enum for t in nest.flatten(dataset.output_types)] init_op_name = iterator.initializer.name table_init_op = tf.tables_initializer().name output_names = [op.name for op in train_next_ops] graph = train_next_ops[0].graph flatten_shapes = nest.flatten(dataset.output_shapes) flatten_shapes = [shape[1:] for shape in flatten_shapes] flatten_tensor_structure = [TensorMeta(dtype=output_types[i], shape=list(flatten_shapes[i]), name="zoo_input_{}".format(i)) for i in range(len(flatten_shapes))] structure = dataset.output_types if isinstance(structure, tf.DType): structure = (structure,) tensor_structure = nest.pack_sequence_as(structure, flatten_tensor_structure) meta_info = { "init_op_name": init_op_name, "table_init_op": table_init_op, "output_names": output_names, "output_types": output_types, "tensor_structure": tensor_structure } return [(bytearray(graph.as_graph_def().SerializeToString()), meta_info)] graph_rdd_and_meta = self.xshards.rdd.mapPartitions(to_dataset) return graph_rdd_and_meta
[docs] @staticmethod def from_tensor_slices(xshards): return TensorSliceDataset(xshards)
[docs] def map(self, map_func): return MapDataset(self, map_func)
[docs]class TensorSliceDataset(Dataset): def __init__(self, xshards): assert isinstance(xshards, SparkXShards), \ "only datasets backed by a SparkXShards are supported" self.xshards = xshards def create_dataset_fn(data): return tf.data.Dataset.from_tensor_slices(data) super().__init__(xshards, create_dataset_fn)
[docs]class MapDataset(Dataset): def __init__(self, input_dataset, map_func): create_pre_dataset_fn = input_dataset.create_dataset_fn def create_dataset_fn(data): dataset = create_pre_dataset_fn(data) return dataset.map(map_func) super().__init__(xshards=input_dataset.xshards, create_dataset_fn=create_dataset_fn)