Source code for zoo.examples.orca.learn.mxnet.lenet_mnist

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Reference: https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html

import argparse

from zoo.orca import init_orca_context, stop_orca_context
from zoo.orca.learn.mxnet import Estimator, create_config


[docs]def get_train_data_iter(config, kv): from mxnet.test_utils import get_mnist_iterator from filelock import FileLock with FileLock("data.lock"): iters = get_mnist_iterator(config["batch_size"], (1, 28, 28), num_parts=kv.num_workers, part_index=kv.rank) return iters[0]
[docs]def get_test_data_iter(config, kv): from mxnet.test_utils import get_mnist_iterator from filelock import FileLock with FileLock("data.lock"): iters = get_mnist_iterator(config["batch_size"], (1, 28, 28), num_parts=kv.num_workers, part_index=kv.rank) return iters[1]
[docs]def get_model(config): import mxnet as mx from mxnet import gluon from mxnet.gluon import nn import mxnet.ndarray as F class LeNet(gluon.Block): def __init__(self, **kwargs): super(LeNet, self).__init__(**kwargs) with self.name_scope(): # layers created in name_scope will inherit name space # from parent layer. self.conv1 = nn.Conv2D(20, kernel_size=(5, 5)) self.pool1 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)) self.conv2 = nn.Conv2D(50, kernel_size=(5, 5)) self.pool2 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)) self.fc1 = nn.Dense(500) self.fc2 = nn.Dense(10) def forward(self, x): x = self.pool1(F.tanh(self.conv1(x))) x = self.pool2(F.tanh(self.conv2(x))) # 0 means copy over size from corresponding dimension. # -1 means infer size from the rest of dimensions. x = x.reshape((0, -1)) x = F.tanh(self.fc1(x)) x = F.tanh(self.fc2(x)) return x net = LeNet() net.initialize(mx.init.Xavier(magnitude=2.24), ctx=[mx.cpu()]) return net
[docs]def get_loss(config): from mxnet import gluon return gluon.loss.SoftmaxCrossEntropyLoss()
[docs]def get_metrics(config): import mxnet as mx return mx.metric.Accuracy()
if __name__ == '__main__': parser = argparse.ArgumentParser( description='Train a LeNet model for handwritten digit recognition.') parser.add_argument('--cluster_mode', type=str, default="local", help='The mode for the Spark cluster.') parser.add_argument('--cores', type=int, default=4, help='The number of cores you want to use on each node.') parser.add_argument('-n', '--num_workers', type=int, default=2, help='The number of MXNet workers to be launched.') parser.add_argument('-s', '--num_servers', type=int, help='The number of MXNet servers to be launched. If not specified, ' 'default to be equal to the number of workers.') parser.add_argument('-b', '--batch_size', type=int, default=100, help='The number of samples per gradient update for each worker.') parser.add_argument('-e', '--epochs', type=int, default=10, help='The number of epochs to train the model.') parser.add_argument('-l', '--learning_rate', type=float, default=0.02, help='Learning rate for the LeNet model.') parser.add_argument('--log_interval', type=int, default=20, help='The number of batches to wait before logging throughput and ' 'metrics information during the training process.') opt = parser.parse_args() num_nodes = 1 if opt.cluster_mode == "local" else opt.num_workers init_orca_context(cluster_mode=opt.cluster_mode, cores=opt.cores, num_nodes=num_nodes) config = create_config(optimizer="sgd", optimizer_params={'learning_rate': opt.learning_rate}, log_interval=opt.log_interval, seed=42) estimator = Estimator(config, model_creator=get_model, loss_creator=get_loss, validation_metrics_creator=get_metrics, num_workers=opt.num_workers, num_servers=opt.num_servers, eval_metrics_creator=get_metrics) estimator.fit(data=get_train_data_iter, validation_data=get_test_data_iter, epochs=opt.epochs, batch_size=opt.batch_size) estimator.shutdown() stop_orca_context()