#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import time
import logging
import subprocess
import ray.services
import mxnet as mx
from mxnet import gluon
from zoo.ray.utils import to_list
[docs]class MXNetRunner(object):
"""Manages a MXNet model for training."""
[docs] def setup_distributed(self, env, config, model_creator, loss_creator=None,
validation_metrics_creator=None, eval_metrics_creator=None):
logging.basicConfig(level=logging.INFO) # This can print log messages to console.
self.logger = logging.getLogger()
assert isinstance(config, dict), "config must be a dict"
for param in ["optimizer", "optimizer_params", "log_interval"]:
assert param in config, param + " must be specified in config"
self.config = config
self.model_creator = model_creator
self.loss_creator = loss_creator
self.validation_metrics_creator = validation_metrics_creator
self.eval_metrics_creator = eval_metrics_creator
self.is_worker = False
env["DMLC_NODE_HOST"] = self.get_node_ip()
if env["DMLC_ROLE"] == "worker":
self.is_worker = True
if self.is_worker:
os.environ.update(env)
self.kv = mx.kv.create("dist_sync")
# Set seed so that the model on each worker is initialized with the same weights.
if "seed" in self.config:
mx.random.seed(self.config["seed"])
self.model = self.model_creator(self.config)
self.loss = self.loss_creator(self.config) if self.loss_creator else None
self.eval_metrics = self.eval_metrics_creator(self.config) \
if self.eval_metrics_creator else None
from mxnet.metric import CompositeEvalMetric
if isinstance(self.eval_metrics, list):
self.eval_metrics = CompositeEvalMetric(self.eval_metrics)
self.val_metrics = self.validation_metrics_creator(self.config) \
if self.validation_metrics_creator else None
if isinstance(self.val_metrics, list):
self.val_metrics = CompositeEvalMetric(self.val_metrics)
# For BaseModule, use symbolic API. Otherwise, use imperative API.
# TODO: change Gluon Trainer to Estimator API?
if not isinstance(self.model, mx.module.BaseModule):
assert self.loss, "Loss not defined for gluon model, please specify loss_creator"
self.trainer = gluon.Trainer(self.model.collect_params(), self.config["optimizer"],
optimizer_params=self.config["optimizer_params"],
kvstore=self.kv)
else: # Trainer is not needed for symbolic API.
self.trainer = None
else: # server
# Need to use the environment on each raylet process for the correct python environment.
# TODO: Need to kill this process manually?
modified_env = os.environ.copy()
modified_env.update(env)
# For servers, just import mxnet and no need to do anything else.
subprocess.Popen("python -c 'import mxnet'", shell=True, env=modified_env)
[docs] def train(self, train_data, epochs=1, batch_size=32,
validation_data=None, train_resize_batch_num=None):
"""Train the model and update the model parameters."""
stats = dict()
if self.is_worker:
from zoo.orca.data.shard import RayPartition
if isinstance(train_data, RayPartition):
from zoo.orca.data.utils import ray_partition_get_data_label
data, label = ray_partition_get_data_label(train_data.get_data(),
allow_tuple=False,
allow_list=False)
train_data_iter = mx.io.NDArrayIter(data=data, label=label,
batch_size=batch_size, shuffle=True)
if train_resize_batch_num is not None:
train_data_iter = mx.io.ResizeIter(train_data_iter, train_resize_batch_num)
if validation_data:
data_val, label_val = ray_partition_get_data_label(validation_data.get_data(),
allow_tuple=False,
allow_list=False)
val_data_iter = mx.io.NDArrayIter(data=data_val, label=label_val,
batch_size=batch_size, shuffle=True)
else:
val_data_iter = None
else: # data_creator functions; should return Iter or DataLoader
config = self.config
if "batch_size" not in config:
config["batch_size"] = batch_size
train_data_iter = train_data(config, self.kv)
val_data_iter = validation_data(config, self.kv) if validation_data else None
start_time = time.time()
if self.trainer: # Imperative API
def cpu_context(target_data):
if isinstance(target_data, list):
return [cpu_context(d) for d in target_data]
else:
return target_data.as_in_context(mx.cpu())
for epoch in range(epochs):
# DataLoader doesn't need to be reset.
if isinstance(train_data_iter, mx.io.DataIter):
train_data_iter.reset()
if self.eval_metrics:
self.eval_metrics.reset() # metrics will accumulate for one batch.
batch_start_time = time.time()
epoch_start_time = time.time()
for i, batch in enumerate(train_data_iter):
data = cpu_context(batch.data)
label = cpu_context(batch.label)
if not isinstance(data, list):
data = [data]
if not isinstance(label, list):
label = [label]
from mxnet import autograd as ag
with ag.record():
output = self.model(*data) # forward
if not isinstance(output, list):
output = [output]
Ls = self.loss(*output, *label)
ag.backward(Ls)
self.trainer.step(batch_size)
if self.eval_metrics:
self.eval_metrics.update(label, output)
if not (i + 1) % self.config["log_interval"]:
# This would be logged on driver for each worker process.
iteration_log = \
"Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f" \
% (epoch, i,
batch_size / (time.time() - batch_start_time),
"loss", Ls.asnumpy().mean())
if self.eval_metrics:
names, accs = self.eval_metrics.get()
names, accs = to_list(names), to_list(accs)
for name, acc in zip(names, accs):
iteration_log += " %s=%f" % (name, acc)
self.logger.info(iteration_log)
batch_start_time = time.time()
# Epoch time log.
self.logger.info("[Epoch %d] time cost: %f" %
(epoch, time.time() - epoch_start_time))
# Epoch metrics log on train data.
if self.eval_metrics:
epoch_train_log = "[Epoch %d] training: " % epoch
names, accs = self.eval_metrics.get()
names, accs = to_list(names), to_list(accs)
for name, acc in zip(names, accs):
epoch_train_log += "%s=%f " % (name, acc)
self.logger.info(epoch_train_log)
# Epoch metrics log on validation data if any.
if val_data_iter:
if isinstance(val_data_iter, mx.io.DataIter):
val_data_iter.reset()
self.val_metrics.reset()
for batch in val_data_iter:
data = cpu_context(batch.data)
label = cpu_context(batch.label)
if not isinstance(data, list):
data = [data]
if not isinstance(label, list):
label = [label]
output = self.model(*data)
if not isinstance(output, list):
output = [output]
self.val_metrics.update(label, output)
epoch_val_log = "[Epoch %d] validation: " % epoch
names, accs = self.val_metrics.get()
names, accs = to_list(names), to_list(accs)
for name, acc in zip(names, accs):
epoch_val_log += "%s=%f " % (name, acc)
self.logger.info(epoch_val_log)
# TODO: save checkpoints
if self.eval_metrics:
names, accs = self.eval_metrics.get()
names, accs = to_list(names), to_list(accs)
for name, acc in zip(names, accs):
stats[name] = acc
else: # Symbolic API
# TODO: seems no history (i.e. validation accuracy) returned by fit?
if "init" not in self.config:
from mxnet.initializer import Uniform
self.config["init"] = Uniform(0.01) # This is the default value for MXNet.
if self.eval_metrics is None:
self.eval_metrics = 'acc' # This is the default value for MXNet.
self.model.fit(train_data=train_data_iter,
num_epoch=epochs,
initializer=self.config["init"],
kvstore=self.kv,
optimizer=self.config["optimizer"],
optimizer_params=self.config["optimizer_params"],
eval_data=val_data_iter,
eval_metric=self.eval_metrics,
validation_metric=self.val_metrics,
batch_end_callback=mx.callback.Speedometer(
batch_size, self.config["log_interval"]),
epoch_end_callback=None if "model" not in self.config
else mx.callback.do_checkpoint(self.config["model"]))
epoch_time = time.time() - start_time
stats["epoch_time"] = epoch_time
if isinstance(train_data, RayPartition):
del train_data
if validation_data and isinstance(validation_data, RayPartition):
del validation_data
return stats
[docs] def shutdown(self):
"""Attempts to shut down the runner."""
del self.logger
if self.is_worker:
del self.kv
del self.model
del self.trainer
del self.loss
del self.eval_metrics
del self.val_metrics
[docs] def get_node_ip(self):
"""Returns the IP address of the current node."""
if "node_ip" not in self.__dict__:
self.node_ip = ray.services.get_node_ip_address()
return self.node_ip
[docs] def find_free_port(self):
"""Finds a free port on the current node."""
if "port" not in self.__dict__:
from zoo.orca.learn.mxnet.utils import find_free_port
self.port = find_free_port()
return self.port