zoo.orca.learn.tf2 package

Submodules

zoo.orca.learn.tf2.tf_ray_estimator module

class zoo.orca.learn.tf2.tf_ray_estimator.Estimator(model_creator, compile_args_creator=None, config=None, verbose=False, backend='tf', workers_per_node=1)

Bases: object

evaluate(data_creator, verbose=1, sample_weight=None, steps=None, callbacks=None)

Evaluates the model on the validation data set.

fit(data_creator, epochs=1, verbose=1, callbacks=None, validation_data_creator=None, class_weight=None, steps_per_epoch=None, validation_steps=None, validation_freq=1)

Trains the model for the specified number of epochs.
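`fit` accepts `validation_freq` alongside `validation_data_creator`. Assuming the estimator forwards this argument unchanged to `tf.keras.Model.fit` (an assumption, not stated on this page), the Keras semantics apply: an integer N runs validation every N epochs, and a container lists the exact 1-based epochs to validate on. The hypothetical helper `validation_epochs` below only computes which epochs would be validated; it does not call Orca or TensorFlow.

```python
from collections.abc import Container


def validation_epochs(epochs, validation_freq=1):
    """Return the 1-based epochs after which validation would run.

    Hypothetical helper mirroring the Keras ``validation_freq`` rules:
    an int N means "every N epochs"; a container lists explicit epochs.
    """
    if isinstance(validation_freq, int):
        if validation_freq < 1:
            raise ValueError("validation_freq must be >= 1")
        return [e for e in range(1, epochs + 1) if e % validation_freq == 0]
    if isinstance(validation_freq, Container):
        return [e for e in range(1, epochs + 1) if e in validation_freq]
    raise TypeError("validation_freq must be an int or a container of ints")


print(validation_epochs(6, 2))            # [2, 4, 6]
print(validation_epochs(10, [1, 2, 10]))  # [1, 2, 10]
```

With the default `validation_freq=1`, every epoch is followed by a validation pass whenever validation data is supplied.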

classmethod from_keras(model_creator, config=None, verbose=False, workers_per_node=1, compile_args_creator=None, backend='tf')

get_model()

Returns the learned model.

restore(checkpoint)

Restores the model from the provided checkpoint.

Args:
    checkpoint (str): Path to the target checkpoint file.

save(checkpoint)

Saves the model to the provided checkpoint.

Args:
    checkpoint (str): Path to the target checkpoint file.
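`save` and `restore` both take a single checkpoint path. One plausible way to implement this pair, sketched here under the assumption that a checkpoint is simply the dict returned by `TFRunner.get_state()` pickled to disk (not confirmed by this page), is a plain pickle round trip. The helpers `save_state` and `restore_state` and the state layout (`epoch`, `weights`, `optimizer_weights`) are hypothetical and stand in for real model weights.

```python
import os
import pickle
import tempfile


def save_state(state, checkpoint):
    """Pickle a runner-state dict to ``checkpoint`` and return the path."""
    with open(checkpoint, "wb") as f:
        pickle.dump(state, f)
    return checkpoint


def restore_state(checkpoint):
    """Load a state dict previously written by ``save_state``."""
    with open(checkpoint, "rb") as f:
        return pickle.load(f)


# Hypothetical state: an epoch counter plus model and optimizer weights.
state = {"epoch": 3, "weights": [[0.1, 0.2], [0.3]], "optimizer_weights": [7]}
path = os.path.join(tempfile.mkdtemp(), "model.ckpt")
save_state(state, path)
print(restore_state(path) == state)  # True
```

In a distributed setting, the restored dict would then be handed to each worker's `set_state` so that all replicas resume from the same weights.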

shutdown()

Shuts down workers and releases resources.

zoo.orca.learn.tf2.tf_runner module

class zoo.orca.learn.tf2.tf_runner.TFRunner(model_creator, compile_args_creator, config=None, verbose=False)

Bases: object

Manages a TensorFlow model for training.

find_free_port()

Finds a free port on the current node.
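A common technique for finding a free port, assumed here rather than confirmed to be what `TFRunner.find_free_port` does internally, is to bind a TCP socket to port 0 so that the operating system assigns an unused port, then read the assigned number back before closing the socket.

```python
import socket


def find_free_port():
    """Return a TCP port number that was free at the moment of the call."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Binding to port 0 asks the kernel to pick any unused port.
        s.bind(("", 0))
        return s.getsockname()[1]


port = find_free_port()
print(port)
```

Because the socket is closed before the port number is used, another process could claim the port in between; callers therefore bind to it as soon as possible.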

get_node_ip()

Returns the IP address of the current node.
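One widely used way to discover a node's outward-facing IP address, shown as a sketch and not as the actual body of `get_node_ip`, is to "connect" a UDP socket to some external address. No packet is sent; the kernel merely selects a route and, with it, a local source address. The probe address `8.8.8.8` and the loopback fallback are illustrative choices.

```python
import socket


def get_node_ip(probe_host="8.8.8.8", probe_port=80):
    """Best-effort IPv4 address of the interface used for outbound traffic."""
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        try:
            # connect() on a UDP socket transmits nothing; it only makes the
            # kernel choose a route and bind a matching local address.
            s.connect((probe_host, probe_port))
            return s.getsockname()[0]
        except OSError:
            # No usable route (e.g. an offline machine): fall back to loopback.
            return "127.0.0.1"


print(get_node_ip())
```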

get_state()

Returns the state of the runner.

set_state(state)

Sets the state of the model.

setup()

Initializes the model.

setup_distributed(urls, world_rank, world_size)

Sets up the TensorFlow distributed environment and initializes the model.

Args:
    urls (str): the URLs that each node uses to connect.
    world_rank (int): the index of the runner.
    world_size (int): the total number of runners.
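TensorFlow's multi-worker strategies, such as `tf.distribute.MultiWorkerMirroredStrategy`, discover their peers through the `TF_CONFIG` environment variable. The sketch below shows how arguments like those of `setup_distributed` could be turned into that variable. It assumes `urls` holds one `host:port` string per runner (the docstring lists the type as `str`), and the helper `build_tf_config` and the addresses are hypothetical.

```python
import json
import os


def build_tf_config(urls, world_rank, world_size):
    """Return a TF_CONFIG JSON string for runner ``world_rank`` of ``urls``."""
    if len(urls) != world_size:
        raise ValueError("expected one URL per runner")
    if not 0 <= world_rank < world_size:
        raise ValueError("world_rank must be in [0, world_size)")
    return json.dumps({
        "cluster": {"worker": list(urls)},
        "task": {"type": "worker", "index": world_rank},
    })


urls = ["10.0.0.1:12345", "10.0.0.2:12345"]  # hypothetical host:port pairs
# Each runner sets its own TF_CONFIG before creating the multi-worker strategy,
# which reads the variable when it is constructed.
os.environ["TF_CONFIG"] = build_tf_config(urls, world_rank=1, world_size=2)
print(os.environ["TF_CONFIG"])
```

Every runner receives the same cluster list and differs only in its `task.index`, which is how each process learns its own rank.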

setup_horovod()

shutdown()

Attempts to shut down the worker.

step(data_creator, epochs=1, verbose=1, callbacks=None, validation_data_creator=None, class_weight=None, steps_per_epoch=None, validation_steps=None, validation_freq=1)

Trains the model for the specified number of epochs, updating its parameters.

validate(data_creator, verbose=1, sample_weight=None, steps=None, callbacks=None)

Evaluates the model on the validation data set.

zoo.orca.learn.tf2.tf_runner.find_free_port()

Module contents