Source code for zoo.orca.learn.tf2.tf_ray_estimator

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import pickle

import numpy as np
import ray

from zoo.orca.learn.tf2.tf_runner import TFRunner
from zoo.ray import RayContext

logger = logging.getLogger(__name__)


class Estimator:
    """Drives a group of remote ``TFRunner`` Ray actors for distributed training.

    The constructor creates the workers (either directly as Ray actors for
    the ``"tf"`` backend, or through ``HorovodRayRunner`` for ``"horovod"``)
    and the remaining methods fan calls out to them with ``ray.get``.
    """

    def __init__(self, model_creator, compile_args_creator=None, config=None,
                 verbose=False, backend="tf", workers_per_node=1):
        """Sets up the TensorFlow trainer.

        Args:
            model_creator (dict -> Model): This function takes in the `config`
                dict and returns a compiled TF model.
            compile_args_creator: Forwarded unchanged to every worker. Must
                not be None when `backend` is "horovod".
            config (dict): Configuration passed to `model_creator` and to the
                workers. When given, it is updated in place with default
                values for `inter_op_parallelism` and `intra_op_parallelism`
                if those keys are absent. When None, a new dict is used.
            verbose (bool): Forwarded to every worker.
            backend (str): Either "tf" or "horovod".
            workers_per_node (int): Number of workers started per Ray node.
                The CPU cores of a node are divided evenly between them.

        Raises:
            AssertionError: If `backend` is "horovod" and
                `compile_args_creator` is None.
            ValueError: If `backend` is neither "tf" nor "horovod".
        """
        self.model_creator = model_creator
        self.compile_args_creator = compile_args_creator
        self.verbose = verbose

        ray_ctx = RayContext.get()
        self.config = self._resolve_config(
            config, ray_ctx.ray_node_cpu_cores, workers_per_node)

        if backend == "horovod":
            assert compile_args_creator is not None, "compile_args_creator should not be None," \
                                                     " when backend is set to horovod"

        params = {
            "model_creator": model_creator,
            "compile_args_creator": compile_args_creator,
            "config": self.config,
            "verbose": self.verbose,
        }

        if backend == "tf":
            cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node
            num_nodes = ray_ctx.num_ray_nodes * workers_per_node

            worker_class = ray.remote(num_cpus=cores_per_node)(TFRunner)
            self.remote_workers = [worker_class.remote(**params)
                                   for _ in range(num_nodes)]

            ips = ray.get(
                [worker.get_node_ip.remote() for worker in self.remote_workers])
            ports = ray.get(
                [worker.find_free_port.remote() for worker in self.remote_workers])
            urls = ["{ip}:{port}".format(ip=ip, port=port)
                    for ip, port in zip(ips, ports)]

            # Get setup tasks in order to throw errors on failure
            ray.get([
                worker.setup_distributed.remote(urls, i, len(self.remote_workers))
                for i, worker in enumerate(self.remote_workers)])
        elif backend == "horovod":
            # it is necessary to call self.run first to set horovod environment
            from zoo.orca.learn.horovod.horovod_ray_runner import HorovodRayRunner
            horovod_runner = HorovodRayRunner(ray_ctx,
                                              worker_cls=TFRunner,
                                              worker_param=params,
                                              workers_per_node=workers_per_node)
            horovod_runner.run(lambda: print("worker initialized"))
            self.remote_workers = horovod_runner.remote_workers
            ray.get([
                worker.setup_horovod.remote()
                for worker in self.remote_workers])
        else:
            raise ValueError("Only \"tf\" and \"horovod\" are legal "
                             "value of backend, but got {}".format(backend))

    @staticmethod
    def _resolve_config(config, node_cpu_cores, workers_per_node):
        """Returns `config` with the parallelism defaults filled in.

        Args:
            config (dict or None): User configuration. A dict is updated in
                place and returned; None yields a new dict.
            node_cpu_cores (int): CPU cores available on one Ray node.
            workers_per_node (int): Number of workers sharing one node.

        Returns:
            dict: The configuration, guaranteed to contain
            `inter_op_parallelism` and `intra_op_parallelism`.
        """
        resolved = {} if config is None else config
        if "inter_op_parallelism" not in resolved:
            resolved["inter_op_parallelism"] = 1
        # Check the resolved dict, not the raw argument: the raw argument is
        # None by default and a membership test on None raises TypeError.
        if "intra_op_parallelism" not in resolved:
            resolved["intra_op_parallelism"] = node_cpu_cores // workers_per_node
        return resolved

    @classmethod
    def from_keras(cls, model_creator, config=None, verbose=False,
                   workers_per_node=1, compile_args_creator=None, backend="tf"):
        """Alternate constructor that forwards every argument to `__init__`."""
        return cls(model_creator, config=config, verbose=verbose,
                   workers_per_node=workers_per_node, backend=backend,
                   compile_args_creator=compile_args_creator)

    def fit(self, data_creator, epochs=1, verbose=1, callbacks=None,
            validation_data_creator=None, class_weight=None,
            steps_per_epoch=None, validation_steps=None, validation_freq=1):
        """Runs training on every worker and returns the first worker's stats.

        All arguments are forwarded as keyword arguments to each worker's
        `step` method.

        Returns:
            A shallow copy of the stats returned by the first worker.
        """
        params = dict(
            data_creator=data_creator,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            validation_data_creator=validation_data_creator,
            class_weight=class_weight,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
        )

        worker_stats = ray.get([w.step.remote(**params)
                                for w in self.remote_workers])
        return worker_stats[0].copy()

    def evaluate(self, data_creator, verbose=1, sample_weight=None,
                 steps=None, callbacks=None):
        """Evaluates the model on the validation data set.

        All arguments are forwarded as keyword arguments to each worker's
        `validate` method.

        Returns:
            A shallow copy of the stats returned by the first worker.
        """
        logger.info("Starting validation step.")
        params = dict(
            data_creator=data_creator,
            verbose=verbose,
            sample_weight=sample_weight,
            steps=steps,
            callbacks=callbacks,
        )

        # see ./tf_runner.py:setup_distributed
        # for an explanation of only taking the first worker's data
        worker_stats = ray.get([w.validate.remote(**params)
                                for w in self.remote_workers])
        return worker_stats[0].copy()

    def get_model(self):
        """Returns the learned model, rebuilt locally from the first worker's state."""
        state = ray.get(self.remote_workers[0].get_state.remote())
        return self._get_model_from_state(state)

    def save(self, checkpoint):
        """Saves the model at the provided checkpoint.

        The state of the first worker is pickled to the file.

        Args:
            checkpoint (str): Path to target checkpoint file.

        Returns:
            str: The `checkpoint` path that was written.
        """
        state = ray.get(self.remote_workers[0].get_state.remote())

        with open(checkpoint, "wb") as f:
            pickle.dump(state, f)

        return checkpoint

    def restore(self, checkpoint):
        """Restores the model from the provided checkpoint.

        The unpickled state is placed in the Ray object store once and then
        applied to every worker.

        Args:
            checkpoint (str): Path to target checkpoint file.
        """
        # NOTE: pickle.load can execute arbitrary code; only restore
        # checkpoints that come from a trusted source.
        with open(checkpoint, "rb") as f:
            state = pickle.load(f)

        state_id = ray.put(state)
        ray.get([worker.set_state.remote(state_id)
                 for worker in self.remote_workers])

    def shutdown(self):
        """Shuts down workers and releases resources."""
        for worker in self.remote_workers:
            worker.shutdown.remote()
            worker.__ray_terminate__.remote()

    def _get_model_from_state(self, state):
        """Creates model and load weights from state.

        Args:
            state (dict): Must contain `weights` and `optimizer_weights`.
                The first entry of `optimizer_weights` is replaced in place
                by an int64 numpy array.

        Returns:
            The model built by `model_creator` with both weight sets applied.
        """
        model = self.model_creator(self.config)
        model.set_weights(state["weights"])

        # This part is due to ray.get() changing scalar np.int64 object to int
        state["optimizer_weights"][0] = np.array(
            state["optimizer_weights"][0], dtype=np.int64)

        # The optimizer has no weights yet, so build the train function
        # first; set_weights is then called on the populated optimizer.
        if model.optimizer.weights == []:
            model._make_train_function()
        model.optimizer.set_weights(state["optimizer_weights"])

        return model