zoo.orca.learn.tf2 package
Submodules
zoo.orca.learn.tf2.tf_ray_estimator module
class zoo.orca.learn.tf2.tf_ray_estimator.Estimator(model_creator, compile_args_creator=None, config=None, verbose=False, backend='tf', workers_per_node=1)
Bases: object
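The constructor takes factory functions instead of a built model and its compile arguments. The sketch below assumes that each creator is called with the config dict and that the runner compiles the model with the keyword arguments returned by compile_args_creator. The names model_creator, compile_args_creator and build_estimator, and the config keys "hidden_units" and "optimizer", are hypothetical and chosen only for illustration.

```python
# Sketch only: assumes the Estimator calls each creator with its `config` dict
# and compiles the model with the kwargs from compile_args_creator.
# All function names and config keys here are hypothetical.


def model_creator(config):
    """Build (but do not compile) a small Keras regression model."""
    import tensorflow as tf  # imported here so this module loads without TensorFlow

    hidden = config.get("hidden_units", 16)
    return tf.keras.Sequential([
        tf.keras.layers.Dense(hidden, activation="relu", input_shape=(4,)),
        tf.keras.layers.Dense(1),
    ])


def compile_args_creator(config):
    """Return keyword arguments intended for model.compile()."""
    return {
        "optimizer": config.get("optimizer", "sgd"),  # Keras string identifier
        "loss": "mse",
        "metrics": ["mae"],
    }


def build_estimator(config):
    """Construct the Estimator; needs Analytics Zoo and a running Ray context."""
    from zoo.orca.learn.tf2.tf_ray_estimator import Estimator

    return Estimator(model_creator,
                     compile_args_creator=compile_args_creator,
                     config=config,
                     workers_per_node=1)
```

Importing TensorFlow and Analytics Zoo inside the functions keeps the defining module importable where those packages are absent; the model itself is only built wherever the creator is actually called.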
evaluate(data_creator, verbose=1, sample_weight=None, steps=None, callbacks=None)
Evaluates the model on the validation data set.
fit(data_creator, epochs=1, verbose=1, callbacks=None, validation_data_creator=None, class_weight=None, steps_per_epoch=None, validation_steps=None, validation_freq=1)
Trains the model for the given number of epochs (one by default).
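fit() and evaluate() take data creators rather than datasets. A minimal sketch, assuming each creator is called with the estimator's config dict and returns a batched tf.data.Dataset of (features, label) pairs; make_samples, train_data_creator, val_data_creator, train_and_evaluate and the "batch_size" config key are hypothetical names.

```python
import random

# Sketch only: assumes fit()/evaluate() call each data creator with the
# estimator's `config` dict and expect a batched tf.data.Dataset back.
# All helper names and the "batch_size" key are hypothetical.


def make_samples(n, seed):
    """Generate n synthetic (features, label) pairs for the target y = sum(x)."""
    rng = random.Random(seed)
    xs = [[rng.uniform(-1.0, 1.0) for _ in range(4)] for _ in range(n)]
    ys = [[sum(x)] for x in xs]
    return xs, ys


def _to_dataset(xs, ys, config, shuffle=False):
    import tensorflow as tf  # only needed when a creator is actually called

    ds = tf.data.Dataset.from_tensor_slices((xs, ys))
    if shuffle:
        ds = ds.shuffle(len(xs))
    return ds.batch(config.get("batch_size", 32))


def train_data_creator(config):
    xs, ys = make_samples(1024, seed=0)
    return _to_dataset(xs, ys, config, shuffle=True)


def val_data_creator(config):
    xs, ys = make_samples(256, seed=1)
    return _to_dataset(xs, ys, config)


def train_and_evaluate(estimator):
    """Train for two epochs with per-epoch validation, then evaluate once more."""
    estimator.fit(train_data_creator, epochs=2,
                  validation_data_creator=val_data_creator)
    return estimator.evaluate(val_data_creator)
```

Seeding each creator's generator makes every call produce the same data, which keeps repeated calls of the same creator consistent with one another.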
classmethod from_keras(model_creator, config=None, verbose=False, workers_per_node=1, compile_args_creator=None, backend='tf')
restore(checkpoint)
Restores the model from the provided checkpoint.
Args:
    checkpoint (str): Path to the target checkpoint file.
zoo.orca.learn.tf2.tf_runner module
class zoo.orca.learn.tf2.tf_runner.TFRunner(model_creator, compile_args_creator, config=None, verbose=False)
Bases: object
Manages a TensorFlow model for training.
setup_distributed(urls, world_rank, world_size)
Sets up the TensorFlow distributed environment and initializes the model.
Args:
    urls (list of str): the URLs that the runners use to connect to each other, one per runner.
    world_rank (int): the index of this runner.
    world_size (int): the total number of runners.
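The three arguments map naturally onto TensorFlow's multi-worker configuration. A minimal sketch, assuming setup_distributed() describes the cluster through the standard TF_CONFIG environment variable read by tf.distribute.MultiWorkerMirroredStrategy (TF_CONFIG must be set before that strategy is created). The helper build_tf_config and the example addresses are hypothetical.

```python
import json
import os

# Sketch only: assumes the runner passes the cluster layout to TensorFlow via
# the standard TF_CONFIG variable. build_tf_config is a hypothetical helper.


def build_tf_config(urls, world_rank, world_size):
    """Return the TF_CONFIG JSON string for the runner at index world_rank."""
    if len(urls) != world_size:
        raise ValueError("expected exactly one URL per runner")
    if not 0 <= world_rank < world_size:
        raise ValueError("world_rank must lie in [0, world_size)")
    return json.dumps({
        "cluster": {"worker": list(urls)},                # every runner's address
        "task": {"type": "worker", "index": world_rank},  # this runner's slot
    })


# Example: configure the second of two runners (placeholder addresses).
urls = ["10.0.0.1:12345", "10.0.0.2:12345"]
os.environ["TF_CONFIG"] = build_tf_config(urls, world_rank=1, world_size=2)
```

Every runner receives the same cluster list and differs only in task.index, which is why world_rank and world_size are enough to place each runner in the cluster.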