Distributed Training and Inference
Orca Estimator provides sklearn-style APIs for transparently distributed model training and inference.
1. Estimator
To perform distributed training and inference, the user can first create an Orca Estimator from any standard (single-node) TensorFlow, Keras or PyTorch model, and then call the Estimator.fit or Estimator.predict method (using the data-parallel processing pipeline as input).
Under the hood, the Orca Estimator will replicate the model on each node in the cluster, feed the data partition (generated by the data-parallel processing pipeline) on each node to the local model replica, and synchronize model parameters using various backend technologies (such as Horovod, tf.distribute.MirroredStrategy, torch.distributed, or the parameter sync layer in BigDL).
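To make this mechanism concrete, here is a toy, single-process simulation of synchronous data-parallel training behind an sklearn-style `fit`/`predict` interface. It is a sketch under stated assumptions, not the Orca implementation: `LinearModel`, `ToyDataParallelEstimator` and `from_model` are hypothetical names, the "cluster" is a Python list of model copies, and plain gradient averaging stands in for Horovod, `torch.distributed`, `tf.distribute.MirroredStrategy` or the BigDL parameter sync layer.

```python
# Toy simulation of synchronous data-parallel training.
# All names here are hypothetical; this is NOT the Orca Estimator API.
import copy


class LinearModel:
    """A single-node model: y = w * x + b."""

    def __init__(self, w=0.0, b=0.0):
        self.w = w
        self.b = b

    def predict_one(self, x):
        return self.w * x + self.b

    def gradients(self, batch):
        """Mean-squared-error gradients (dL/dw, dL/db) over a list of (x, y)."""
        n = len(batch)
        gw = sum(2 * (self.predict_one(x) - y) * x for x, y in batch) / n
        gb = sum(2 * (self.predict_one(x) - y) for x, y in batch) / n
        return gw, gb


class ToyDataParallelEstimator:
    def __init__(self, model, num_workers, lr):
        # "Replicate the model on each node": one deep copy per simulated worker.
        self.replicas = [copy.deepcopy(model) for _ in range(num_workers)]
        self.lr = lr

    @classmethod
    def from_model(cls, model, num_workers=2, lr=0.05):
        return cls(model, num_workers, lr)

    def fit(self, partitions, epochs=1):
        assert len(partitions) == len(self.replicas)
        for _ in range(epochs):
            # Each replica sees only its own local data partition.
            local = [r.gradients(p) for r, p in zip(self.replicas, partitions)]
            # Synchronize by averaging gradients across workers. With
            # equal-sized partitions this equals the full-batch gradient.
            gw = sum(g[0] for g in local) / len(local)
            gb = sum(g[1] for g in local) / len(local)
            # Every replica applies the same update, so the copies stay identical.
            for r in self.replicas:
                r.w -= self.lr * gw
                r.b -= self.lr * gb
        return self

    def predict(self, partitions):
        # Inference needs no synchronization: each replica scores its partition.
        return [[r.predict_one(x) for x in p]
                for r, p in zip(self.replicas, partitions)]


# Usage: fit y = 3x + 1 with two simulated workers, two equal-sized shards.
data = [(x, 3 * x + 1) for x in [0.0, 0.5, 1.0, 1.5, 2.0, 2.5]]
partitions = [data[0::2], data[1::2]]
est = ToyDataParallelEstimator.from_model(LinearModel(), num_workers=2, lr=0.05)
est.fit(partitions, epochs=500)
preds = est.predict([[10.0], [20.0]])
```

Because every replica starts from the same weights and applies the same averaged update, the copies never diverge; real backends maintain this same invariant across processes and machines.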
2. TensorFlow/Keras Estimator
2.1 TensorFlow 1.15 and Keras 2.3
2.2 TensorFlow 2.x and Keras 2.4+
3. PyTorch Estimator
Using BigDL backend
The user may create a PyTorch Estimator using the BigDL backend (currently the default for PyTorch) as follows: <TODO: add a simple example>
Then the user can perform distributed model training and inference as follows: <TODO: add a simple example>
The input to the fit and predict methods can be a torch.utils.data.DataLoader, XShards, or a Data Creator Function (a function that returns a torch.utils.data.DataLoader). See the data-parallel processing pipeline page for more details. <TODO: we need to add Spark DataFrame support too>
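One practical benefit of a data creator function is that each worker can build its own loader where the data is consumed, instead of receiving a pre-built loader from the driver. The sketch below shows one way a `fit`/`predict` implementation might accept either form. `SimpleLoader`, `train_data_creator` and `resolve_input` are hypothetical, `SimpleLoader` stands in for `torch.utils.data.DataLoader`, and the `(config, batch_size)` creator signature is an assumption; check the Python API doc for the exact signature Orca expects.

```python
# Hypothetical sketch: accept either a ready-made loader or a data creator.
# SimpleLoader stands in for torch.utils.data.DataLoader; the
# (config, batch_size) creator signature is an assumption.


class SimpleLoader:
    """Minimal stand-in for a DataLoader: yields fixed-size batches."""

    def __init__(self, samples, batch_size):
        self.samples = list(samples)
        self.batch_size = batch_size

    def __iter__(self):
        for i in range(0, len(self.samples), self.batch_size):
            yield self.samples[i:i + self.batch_size]


def train_data_creator(config, batch_size):
    # Called on each worker, so the data is built (or read) locally.
    n = config.get("num_samples", 10)
    return SimpleLoader(range(n), batch_size)


def resolve_input(data, config, batch_size):
    """Return an iterable of batches from a loader or a data creator function."""
    if callable(data):
        # A creator function: invoke it to obtain the worker-local loader.
        return data(config, batch_size)
    # Already a loader (or another iterable of batches): use it as-is.
    return data


# Usage: both forms yield batches the training loop can iterate over.
batches = list(resolve_input(train_data_creator, {"num_samples": 5}, batch_size=2))
direct = SimpleLoader([7, 8, 9], batch_size=2)
same = resolve_input(direct, {}, batch_size=99)
```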
View the related Python API doc for more details.
Using torch.distributed or Horovod backend
<TODO: add description for torch.distributed or Horovod support>
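As background for these backends: Horovod and PyTorch's `DistributedDataParallel` (built on `torch.distributed`) synchronize gradients across workers with an all-reduce collective, and ring all-reduce is one widely used algorithm for it. The following is a pure-Python, single-process simulation for illustration only; `ring_allreduce` is a hypothetical helper, not part of Orca, Horovod or PyTorch.

```python
# Single-process simulation of ring all-reduce over N "workers".
# ring_allreduce is a hypothetical helper for illustration only.


def ring_allreduce(vectors):
    """Sum equal-length vectors held by N workers; every worker gets the sum."""
    n = len(vectors)
    length = len(vectors[0])
    assert all(len(v) == length for v in vectors)
    bufs = [list(v) for v in vectors]  # each worker's local buffer (inputs untouched)
    # Chunk c covers indices [bounds[c], bounds[c + 1]).
    bounds = [c * length // n for c in range(n + 1)]

    def chunk(c):
        return range(bounds[c], bounds[c + 1])

    # Phase 1, reduce-scatter: after n - 1 steps, worker i holds the
    # complete sum of chunk (i + 1) % n.
    for step in range(n - 1):
        # Collect every send before applying any: all workers send at once.
        msgs = []
        for i in range(n):
            c = (i - step) % n
            msgs.append(((i + 1) % n, c, [bufs[i][k] for k in chunk(c)]))
        for dst, c, payload in msgs:
            for k, val in zip(chunk(c), payload):
                bufs[dst][k] += val

    # Phase 2, all-gather: pass the completed chunks around the ring.
    for step in range(n - 1):
        msgs = []
        for i in range(n):
            c = (i + 1 - step) % n
            msgs.append(((i + 1) % n, c, [bufs[i][k] for k in chunk(c)]))
        for dst, c, payload in msgs:
            for k, val in zip(chunk(c), payload):
                bufs[dst][k] = val
    return bufs


# Usage: three workers, each holding a local gradient vector.
grads = [[1, 2, 3], [10, 20, 30], [100, 200, 300]]
result = ring_allreduce(grads)
```

To average rather than sum, each worker divides the result by N. In this scheme each worker sends 2(N - 1) chunks of about L/N elements each, so per-worker traffic stays close to 2L however many workers join the ring.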
For more details, view the distributed PyTorch training/inference page.