zoo.orca.learn.mxnet package

Submodules

zoo.orca.learn.mxnet.mxnet_runner module

class zoo.orca.learn.mxnet.mxnet_runner.MXNetRunner[source]

Bases: object

Manages an MXNet model for training.

find_free_port()[source]

Finds a free port on the current node.
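A common way to find a free port is to bind a socket to port 0 and let the operating system choose one. The sketch below uses only the standard library to illustrate that general technique. It is not taken from the MXNetRunner source, and the helper name pick_free_port is hypothetical.

```python
import socket
from contextlib import closing


def pick_free_port():
    """Ask the OS for an unused TCP port (hypothetical helper, not the library's code)."""
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # Binding to port 0 lets the OS pick any port that is currently free.
        s.bind(("", 0))
        # getsockname() returns (host, port) for an AF_INET socket.
        return s.getsockname()[1]


port = pick_free_port()
```

The port is only known to be free at the moment of the call; another process could claim it before it is reused.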

get_node_ip()[source]

Returns the IP address of the current node.

setup_distributed(env, config, model_creator, loss_creator=None, validation_metrics_creator=None, eval_metrics_creator=None)[source]

shutdown()[source]

Attempts to shut down the runner.

train(train_data, epochs=1, batch_size=32, validation_data=None, train_resize_batch_num=None)[source]

Train the model and update the model parameters.

zoo.orca.learn.mxnet.mxnet_trainer module

class zoo.orca.learn.mxnet.mxnet_trainer.Estimator(config, model_creator, loss_creator=None, eval_metrics_creator=None, validation_metrics_creator=None, num_workers=None, num_servers=None, runner_cores=None)[source]

Bases: object

MXNet Estimator provides an automatic setup for synchronous distributed MXNet training.

Parameters: config – A dictionary of training configurations. It must include the keys optimizer, optimizer_params and log_interval. optimizer should be an MXNet optimizer or its string representation. optimizer_params should be a dict that accompanies the optimizer; it can contain learning_rate and other optimization settings. log_interval should be an integer specifying the interval for logging throughput and metrics information (if any) during training. You can call create_config to create this dictionary directly. You can specify “seed” in config to set the random seed for weight initialization, and “init” in extra_config to set the model initializer for gluon models.
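As an illustration of the shape described above, the dictionary below carries the three required keys plus an optional seed. The values are arbitrary examples, and the check_config helper is hypothetical: it is not part of the library and only makes the "required keys" rule explicit.

```python
REQUIRED_KEYS = ("optimizer", "optimizer_params", "log_interval")

# Example values only; pick settings that suit your own model.
config = {
    "optimizer": "sgd",
    "optimizer_params": {"learning_rate": 0.01},
    "log_interval": 10,
    "seed": 42,  # optional: random seed for weight initialization
}


def check_config(cfg):
    """Raise if a required key is missing (hypothetical helper)."""
    missing = [key for key in REQUIRED_KEYS if key not in cfg]
    if missing:
        raise ValueError("config is missing required keys: %s" % ", ".join(missing))
    return cfg


check_config(config)
```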

Parameters: model_creator – A function that takes config as its argument and returns an MXNet model. The model can be defined using either the MXNet symbolic API or the imperative (gluon) API.

Parameters: loss_creator – A function that takes config as its argument and returns an MXNet loss. This is not needed for the symbolic API, where the loss is already defined as the model output.

Parameters: eval_metrics_creator – A function that takes config as its argument and returns one MXNet metric or a list of MXNet metrics, or their corresponding string representations, for example ‘accuracy’. This is not needed if you don’t need evaluation on the training data set.

Parameters: validation_metrics_creator – A function that takes config as its argument and returns one MXNet metric or a list of MXNet metrics, or their corresponding string representations, for example ‘accuracy’. This is not needed if you have no validation data during training.
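The creator functions described above could be sketched as follows for the gluon API, assuming MXNet 1.x is available on the workers. The layer sizes and the choice of softmax cross-entropy loss are illustrative assumptions. MXNet is imported inside the functions so that they can be defined on a machine where MXNet is not installed.

```python
def model_creator(config):
    """Return a small gluon network (architecture chosen only for illustration)."""
    from mxnet.gluon import nn

    net = nn.HybridSequential()
    net.add(nn.Dense(128, activation="relu"))
    net.add(nn.Dense(10))
    return net


def loss_creator(config):
    """Return a gluon loss; not needed when the model uses the symbolic API."""
    from mxnet import gluon

    return gluon.loss.SoftmaxCrossEntropyLoss()


def eval_metrics_creator(config):
    # A single metric given by its string representation.
    return "accuracy"


def validation_metrics_creator(config):
    # A list of metrics is also accepted.
    return ["accuracy"]
```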

Parameters: num_workers – The number of workers for distributed training. Defaults to the number of nodes in the cluster.

Parameters: num_servers – The number of servers for distributed training. Default is None, in which case it equals the number of workers.

Parameters: runner_cores – The number of CPU cores allocated to each MXNet worker and server. Default is None. You may need to specify this for better performance when running on a cluster.
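Putting the constructor arguments together might look like the sketch below. It assumes Analytics Zoo with Orca is installed and that whatever cluster context the library requires has already been set up. The worker, server and core counts are arbitrary, and build_estimator is a hypothetical wrapper. The import is deferred into the function so the snippet can be defined without the library present.

```python
def build_estimator(config, model_creator, loss_creator=None,
                    eval_metrics_creator=None, validation_metrics_creator=None):
    """Construct an MXNet Estimator (hypothetical convenience wrapper)."""
    from zoo.orca.learn.mxnet.mxnet_trainer import Estimator

    return Estimator(
        config,
        model_creator,
        loss_creator=loss_creator,
        eval_metrics_creator=eval_metrics_creator,
        validation_metrics_creator=validation_metrics_creator,
        num_workers=2,    # example value; defaults to the number of nodes
        num_servers=2,    # example value; defaults to the number of workers
        runner_cores=4,   # example value; CPU cores per worker and server
    )
```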

fit(data, epochs=1, batch_size=32, validation_data=None, train_resize_batch_num=None)[source]

Trains an MXNet model on data (optionally with validation_data) for the given number of epochs.

Parameters: data – An instance of SparkXShards or a function that takes config and kv as arguments and returns an MXNet DataIter/DataLoader for training. You can specify data-related configurations for this function in the config argument above. kv is an instance of the MXNet distributed key-value store. kv.num_workers and kv.rank can be used in this function to split data among workers if necessary.
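A data function of this form could be sketched as below. The synthetic arrays, the "batch_size" config key and the shard_bounds helper are assumptions made for illustration. The helper drops any remainder so that every worker receives the same number of samples. MXNet and NumPy are imported inside the function so the snippet can be defined without them.

```python
def shard_bounds(num_samples, num_workers, rank):
    """Return the [start, end) slice owned by worker `rank` (hypothetical helper).

    Remainder samples are dropped so that all shards have equal size.
    """
    per_worker = num_samples // num_workers
    start = rank * per_worker
    return start, start + per_worker


def train_data_creator(config, kv):
    """Build an MXNet DataIter over this worker's shard of the data."""
    import mxnet as mx
    import numpy as np

    # Synthetic data stands in for a real dataset (assumption for illustration).
    features = np.random.rand(1000, 20).astype("float32")
    labels = np.random.randint(0, 10, size=(1000,)).astype("float32")

    # kv.num_workers and kv.rank identify this worker within the job.
    start, end = shard_bounds(len(features), kv.num_workers, kv.rank)
    return mx.io.NDArrayIter(
        features[start:end],
        labels[start:end],
        batch_size=config["batch_size"],  # hypothetical data-related config key
        shuffle=True,
    )
```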

Parameters:
  • epochs – The number of epochs to train the MXNet model. Default is 1.
  • batch_size – The number of samples per batch for each worker. Default is 32.
  • validation_data – An instance of SparkXShards or a function that takes config and kv as arguments and returns an MXNet DataIter/DataLoader for validation. You can specify data-related configurations for this function in the config argument above. kv is an instance of the MXNet distributed key-value store. kv.num_workers and kv.rank can be used in this function to split data among workers if necessary.

Parameters: train_resize_batch_num – The number of batches per epoch to resize to. Default is None. You might need to specify this if the size of the training data differs across workers: MXNet distributed training would crash when the first worker finishes training if the workers have unbalanced training data. See this issue for more details: https://github.com/apache/incubator-mxnet/issues/17651
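One way to choose a value for train_resize_batch_num, when per-worker sample counts are known, is to compute the batch count each worker would produce and pick a single common number. The sketch below takes the smallest count. This is an assumption for illustration rather than a rule prescribed by the library, and common_batch_num is a hypothetical helper.

```python
import math


def common_batch_num(samples_per_worker, batch_size):
    """Return the smallest per-worker batch count (hypothetical helper)."""
    return min(math.ceil(n / batch_size) for n in samples_per_worker)


# Three workers holding unequal amounts of data.
train_resize_batch_num = common_batch_num([1000, 960, 1024], batch_size=32)

# The value would then be passed to fit, for example:
# estimator.fit(train_data_creator, epochs=2, batch_size=32,
#               train_resize_batch_num=train_resize_batch_num)
```

Here train_data_creator stands for any data function of the form described for the data parameter.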

shutdown()[source]

Shuts down runners and releases resources.

zoo.orca.learn.mxnet.utils module

zoo.orca.learn.mxnet.utils.create_config(optimizer='sgd', optimizer_params=None, log_interval=10, seed=None, extra_config=None)[source]

zoo.orca.learn.mxnet.utils.find_free_port()[source]

Module contents