zoo.orca.learn.mxnet package
Submodules
zoo.orca.learn.mxnet.mxnet_runner module
zoo.orca.learn.mxnet.mxnet_trainer module
class zoo.orca.learn.mxnet.mxnet_trainer.Estimator(config, model_creator, loss_creator=None, eval_metrics_creator=None, validation_metrics_creator=None, num_workers=None, num_servers=None, runner_cores=None)
Bases: object
MXNet Estimator provides an automatic setup for synchronous distributed MXNet training.
Parameters: config – A dictionary of training configurations. Keys must include optimizer, optimizer_params and log_interval. optimizer should be an MXNet optimizer or its string representation. optimizer_params should be a dict of parameters for that optimizer, such as learning_rate and other optimization settings. log_interval should be an integer specifying the interval for logging throughput and metrics information (if any) during training. You can call create_config to create this dictionary directly. You can specify “seed” in config to set the random seed for weight initialization, and “init” in extra_config to set the model initializer for Gluon models.
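A minimal sketch of a hand-written config dictionary, as an alternative to calling create_config. The optimizer settings and seed are placeholder values, and check_config is a hypothetical helper (not part of zoo.orca) that only checks the three required keys described above.

```python
# Hand-written config with the keys the Estimator requires.
# Values are placeholders, not recommended settings.
config = {
    "optimizer": "sgd",  # string name of an MXNet optimizer
    "optimizer_params": {"learning_rate": 0.01, "momentum": 0.9},
    "log_interval": 10,  # integer logging interval
    "seed": 42,          # optional: random seed for weight initialization
}

REQUIRED_KEYS = ("optimizer", "optimizer_params", "log_interval")


def check_config(cfg):
    """Hypothetical helper: fail early if a required key is missing or mistyped."""
    missing = [k for k in REQUIRED_KEYS if k not in cfg]
    if missing:
        raise KeyError(f"config is missing required keys: {missing}")
    if not isinstance(cfg["optimizer_params"], dict):
        raise TypeError("optimizer_params should be a dict")
    if not isinstance(cfg["log_interval"], int):
        raise TypeError("log_interval should be an integer")
    return cfg


check_config(config)
```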
Parameters: model_creator – A function that takes config as argument and returns an MXNet model. The model can be defined with either the MXNet symbolic API or the imperative (Gluon) API.
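A sketch of a model_creator using the Gluon API. The network shape, the 10-class output and the "hidden_units" config key are illustrative assumptions. The import is deferred into the function so MXNet is only loaded when the creator is actually called; this is a design choice, not a requirement of the Estimator.

```python
def model_creator(config):
    # Deferred import: MXNet is only needed when the creator is called.
    from mxnet import gluon

    # A small multilayer perceptron built with the imperative (Gluon) API.
    # "hidden_units" is a hypothetical config key used only in this sketch.
    net = gluon.nn.Sequential()
    net.add(
        gluon.nn.Dense(config.get("hidden_units", 128), activation="relu"),
        gluon.nn.Dense(10),  # assumed 10 output classes
    )
    # Parameters are left uninitialized here; the "init" option described
    # for config suggests initialization is handled by the Estimator.
    return net
```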
Parameters: loss_creator – A function that takes config as argument and returns an MXNet loss. This is not needed for the symbolic API, where the loss is already defined as the model output.
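Two sketches for the loss side, with assumed names and shapes: a loss_creator returning a Gluon loss to pair with an imperative model, and a hypothetical symbolic_model_creator whose graph ends in SoftmaxOutput, so the loss is already part of the model output and no loss_creator would be passed.

```python
def loss_creator(config):
    from mxnet import gluon
    # Softmax cross-entropy, suited to a classifier that outputs raw logits.
    return gluon.loss.SoftmaxCrossEntropyLoss()


def symbolic_model_creator(config):
    import mxnet as mx
    # With the symbolic API the loss lives in the graph: SoftmaxOutput
    # applies softmax in the forward pass and produces the cross-entropy
    # gradient in the backward pass, so no separate loss is needed.
    data = mx.sym.var("data")
    hidden = mx.sym.FullyConnected(data=data, num_hidden=128)
    hidden = mx.sym.Activation(data=hidden, act_type="relu")
    logits = mx.sym.FullyConnected(data=hidden, num_hidden=10)
    return mx.sym.SoftmaxOutput(data=logits, name="softmax")
```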
Parameters: eval_metrics_creator – A function that takes config as argument and returns one MXNet metric or a list of them, or the corresponding string representations (for example, ‘accuracy’). This is not needed if you don’t need evaluation on the training data set.
Parameters: validation_metrics_creator – A function that takes config as argument and returns one MXNet metric or a list of them, or the corresponding string representations (for example, ‘accuracy’). This is not needed if you don’t have validation data during training.
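A sketch of both metric creators: the training-set metric is given by its string name, and the validation metrics are returned as a list of MXNet metric objects. Top-5 accuracy is only an example choice, and the mx.metric module path assumes MXNet 1.x.

```python
def eval_metrics_creator(config):
    # A single metric given by its string representation.
    return "accuracy"


def validation_metrics_creator(config):
    import mxnet as mx
    # A list of metric objects (assumes the MXNet 1.x mx.metric module).
    return [mx.metric.Accuracy(), mx.metric.TopKAccuracy(top_k=5)]
```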
Parameters: num_workers – The number of workers for distributed training. Defaults to the number of nodes in the cluster.
Parameters: num_servers – The number of servers for distributed training. Default is None, in which case it equals the number of workers.
Parameters: runner_cores – The number of CPU cores allocated to each MXNet worker and server. Default is None. You may need to specify this for better performance when running on a cluster.
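The defaults for num_workers, num_servers and runner_cores chain together. The sketch below mirrors them to estimate how many CPU cores a job requests. resolve_topology is hypothetical, not part of zoo.orca, and it assumes runner_cores applies equally to every worker and every server, as the parameter description states.

```python
def resolve_topology(num_nodes, num_workers=None, num_servers=None, runner_cores=None):
    """Hypothetical helper mirroring the documented defaults.

    num_workers defaults to the number of nodes, num_servers defaults to
    num_workers, and runner_cores (if set) is allocated to each worker and
    each server.
    """
    workers = num_nodes if num_workers is None else num_workers
    servers = workers if num_servers is None else num_servers
    total_cores = None if runner_cores is None else (workers + servers) * runner_cores
    return workers, servers, total_cores


# 2 workers + 1 server, 8 cores each -> 24 cores in total.
print(resolve_topology(4, num_workers=2, num_servers=1, runner_cores=8))
```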
fit(data, epochs=1, batch_size=32, validation_data=None, train_resize_batch_num=None)
Trains an MXNet model on data, optionally validating on validation_data, for the given number of epochs.
Parameters: data – An instance of SparkXShards, or a function that takes config and kv as arguments and returns an MXNet DataIter/DataLoader for training. You can specify data-related configurations for this function in the config argument above. kv is an instance of the MXNet distributed key-value store; kv.num_workers and kv.rank can be used in this function to split data across workers if necessary.
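A sketch of a training data creator that uses kv.rank and kv.num_workers to give each worker a disjoint shard. The synthetic dataset, its shapes, and the "num_samples" and "batch_size" config keys are assumptions for illustration; whether the Estimator itself places batch_size into config is not stated here. A fixed seed makes every worker generate the same full dataset before slicing out its own part.

```python
def shard_indices(num_samples, rank, num_workers):
    """Strided split: worker `rank` takes every num_workers-th sample.

    Shard sizes differ by at most one sample across workers.
    """
    return list(range(rank, num_samples, num_workers))


def train_data_creator(config, kv):
    # Deferred imports: only needed when a worker calls the creator.
    import numpy as np
    import mxnet as mx

    # Synthetic stand-in for a real dataset (hypothetical shapes and keys).
    num_samples = config.get("num_samples", 1000)
    rng = np.random.RandomState(0)  # same data on every worker
    features = rng.rand(num_samples, 784).astype("float32")
    labels = rng.randint(0, 10, size=num_samples).astype("float32")

    # Each worker keeps only its own disjoint shard.
    idx = shard_indices(num_samples, kv.rank, kv.num_workers)
    return mx.io.NDArrayIter(
        data=features[idx],
        label=labels[idx],
        batch_size=config.get("batch_size", 32),
        shuffle=True,
    )
```

The strided split keeps shard sizes within one sample of each other, but even that small difference can leave workers with different batch counts, which is the situation train_resize_batch_num addresses.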
Parameters: epochs – The number of epochs to train the MXNet model. Default is 1.
Parameters: batch_size – The number of samples per batch for each worker. Default is 32.
Parameters: validation_data – An instance of SparkXShards, or a function that takes config and kv as arguments and returns an MXNet DataIter/DataLoader for validation. You can specify data-related configurations for this function in the config argument above. kv is an instance of the MXNet distributed key-value store; kv.num_workers and kv.rank can be used in this function to split data across workers if necessary.
Parameters: train_resize_batch_num – The number of batches per epoch to resize the training data to. Default is None. You might need to specify this if the size of the training data varies across workers: MXNet distributed training would crash when the first worker finishes training if the workers have unbalanced training data. See this issue for more details: https://github.com/apache/incubator-mxnet/issues/17651
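One way to choose train_resize_batch_num is the smallest number of full batches any worker can produce, so that no worker runs out of data before the others. This is a heuristic sketch, not the Estimator's own logic; suggest_resize_batch_num is hypothetical and assumes you know each worker's sample count. Floor division is deliberately conservative, since an iterator that pads its last batch would yield one batch more.

```python
def suggest_resize_batch_num(samples_per_worker, batch_size):
    """Hypothetical helper: largest batch count every worker can supply.

    Uses floor division, so each worker only has to yield full batches.
    """
    return min(n // batch_size for n in samples_per_worker)


# Three workers with slightly unbalanced data: 1000 // 32 = 31,
# 990 // 32 = 30, 1005 // 32 = 31, so every worker can run 30 batches.
print(suggest_resize_batch_num([1000, 990, 1005], batch_size=32))
```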