Example #1
    def test_train_model_with_bn(self):
        class SimpleTorchModel(nn.Module):
            def __init__(self):
                super(SimpleTorchModel, self).__init__()
                self.dense1 = nn.Linear(2, 4)
                self.bn1 = torch.nn.BatchNorm1d(4)
                self.dense2 = nn.Linear(4, 1)

            def forward(self, x):
                x = self.dense1(x)
                x = self.bn1(x)
                x = torch.sigmoid(self.dense2(x))
                return x

        torch_model = SimpleTorchModel()
        loss_fn = torch.nn.BCELoss()
        az_model = TorchModel.from_pytorch(torch_model)
        zoo_loss = TorchLoss.from_pytorch(loss_fn)
        inputs = torch.Tensor([[1, 2], [1, 3], [3, 2], [5, 6], [8, 9], [1, 9]])
        targets = torch.Tensor([[0], [0], [0], [1], [1], [1]])
        train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=2)
        train_featureset = FeatureSet.pytorch_dataloader(train_loader)
        val_loader = DataLoader(TensorDataset(inputs, targets), batch_size=2)
        val_featureset = FeatureSet.pytorch_dataloader(val_loader)

        zooOptimizer = Adam()
        estimator = Estimator(az_model, optim_methods=zooOptimizer)
        estimator.train_minibatch(train_featureset,
                                  zoo_loss,
                                  end_trigger=MaxEpoch(4),
                                  checkpoint_trigger=EveryEpoch(),
                                  validation_set=val_featureset,
                                  validation_method=[Accuracy()])

        trained_model = az_model.to_pytorch()
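For comparison, the same SimpleTorchModel can be trained without the distributed wrappers. The following is a minimal, purely local PyTorch sketch of what the train_minibatch call above performs (same data, loss and MaxEpoch(4) trigger); it does not use TorchModel, FeatureSet or Estimator:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Local single-process equivalent of the distributed training in the test above.
class SimpleTorchModel(nn.Module):
    def __init__(self):
        super(SimpleTorchModel, self).__init__()
        self.dense1 = nn.Linear(2, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.dense2 = nn.Linear(4, 1)

    def forward(self, x):
        x = self.dense1(x)
        x = self.bn1(x)
        return torch.sigmoid(self.dense2(x))

model = SimpleTorchModel()
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())
inputs = torch.Tensor([[1, 2], [1, 3], [3, 2], [5, 6], [8, 9], [1, 9]])
targets = torch.Tensor([[0], [0], [0], [1], [1], [1]])
loader = DataLoader(TensorDataset(inputs, targets), batch_size=2)

for epoch in range(4):  # corresponds to end_trigger=MaxEpoch(4)
    for x_batch, y_batch in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x_batch), y_batch)
        loss.backward()
        optimizer.step()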
Example #2
    def _handle_data_loader(self, data, validation_data):
        train_feature_set = FeatureSet.pytorch_dataloader(data, "", "")
        if validation_data is None:
            val_feature_set = None
        else:
            assert isinstance(validation_data, DataLoader) or callable(validation_data), \
                "validation_data should be a pytorch DataLoader or a callable data_creator"
            val_feature_set = FeatureSet.pytorch_dataloader(validation_data)

        return train_feature_set, val_feature_set
Example #3
    def _handle_xshards(self, data, validation_data):
        train_rdd = data.rdd.flatMap(xshard_to_sample)
        train_feature_set = FeatureSet.sample_rdd(train_rdd)
        if validation_data is None:
            val_feature_set = None
        else:
            assert isinstance(validation_data, SparkXShards), \
                "validation_data should be a SparkXShards"
            val_feature_set = FeatureSet.sample_rdd(
                validation_data.rdd.flatMap(xshard_to_sample))
        return train_feature_set, val_feature_set
Example #4
    def _handle_dataframe(self, data, validation_data, feature_cols,
                          label_cols):
        schema = data.schema
        train_rdd = data.rdd.map(
            lambda row: row_to_sample(row, schema, feature_cols, label_cols))
        train_feature_set = FeatureSet.sample_rdd(train_rdd)
        if validation_data is None:
            val_feature_set = None
        else:
            assert isinstance(validation_data, DataFrame), "validation_data should also be a " \
                                                           "DataFrame"
            val_feature_set = FeatureSet.sample_rdd(
                validation_data.rdd.map(lambda row: row_to_sample(
                    row, schema, feature_cols, label_cols)))

        return train_feature_set, val_feature_set
Example #5
    def test_estimator_train(self):
        batch_size = 8
        epoch_num = 5

        images, labels = TestEstimator._generate_image_data(data_num=8,
                                                            img_shape=(3, 224,
                                                                       224))

        image_rdd = self.sc.parallelize(images)
        labels = self.sc.parallelize(labels)

        sample_rdd = image_rdd.zip(labels).map(
            lambda img_label: Sample.from_ndarray(img_label[0], img_label[1]))

        data_set = FeatureSet.sample_rdd(sample_rdd)

        model = TestEstimator._create_cnn_model()

        optim_method = SGD(learningrate=0.01)

        estimator = Estimator(model, optim_method, "")
        estimator.set_constant_gradient_clipping(0.1, 1.2)
        estimator.train(train_set=data_set,
                        criterion=ClassNLLCriterion(),
                        end_trigger=MaxEpoch(epoch_num),
                        checkpoint_trigger=EveryEpoch(),
                        validation_set=data_set,
                        validation_method=[Top1Accuracy()],
                        batch_size=batch_size)
        eval_result = estimator.evaluate(validation_set=data_set,
                                         validation_method=[Top1Accuracy()])
        assert isinstance(eval_result[0], EvaluatedResult)
        assert len(eval_result) == 1
        predict_result = model.predict(sample_rdd)
        assert predict_result.count() == 8
Example #6
 def input_fn(mode):
     if mode == tf.estimator.ModeKeys.TRAIN:
         image_set = self.get_raw_image_set(with_label=True)
         feature_set = FeatureSet.image_frame(
             image_set.to_image_frame())
         train_transformer = ChainedPreprocessing([
             ImageBytesToMat(),
             ImageResize(256, 256),
             ImageRandomCrop(224, 224),
             ImageRandomPreprocessing(ImageHFlip(), 0.5),
             ImageChannelNormalize(0.485, 0.456, 0.406, 0.229, 0.224,
                                   0.225),
             ImageMatToTensor(to_RGB=True, format="NHWC"),
             ImageSetToSample(input_keys=["imageTensor"],
                              target_keys=["label"])
         ])
         feature_set = feature_set.transform(train_transformer)
         feature_set = feature_set.transform(ImageFeatureToSample())
         training_dataset = TFDataset.from_feature_set(
             feature_set,
             features=(tf.float32, [224, 224, 3]),
             labels=(tf.int32, [1]),
             batch_size=8)
         return training_dataset
     else:
         raise NotImplementedError
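The ChainedPreprocessing pipeline above is a standard ImageNet-style training augmentation: resize to 256x256, random 224x224 crop, random horizontal flip with probability 0.5, then per-channel normalization with the usual ImageNet mean/std. As a point of reference only, a roughly equivalent local pipeline with torchvision (an assumption; the example itself feeds an Analytics Zoo FeatureSet into TFDataset) could be sketched as:

from torchvision import transforms

# Roughly equivalent local preprocessing. Note the layout differs from the
# example above: torchvision yields CHW float tensors scaled to [0, 1],
# while ImageMatToTensor(format="NHWC") yields NHWC tensors.
imagenet_train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])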
Example #7
 def _get_training_data(self):
     jvalue = callZooFunc("float", "createTFDataFeatureSet",
                          self.rdd.map(lambda x: x[0]), self.init_op_name,
                          self.table_init_op, self.output_names,
                          self.output_types, self.shard_index_op_name,
                          self.inter_threads, self.intra_threads)
     return FeatureSet(jvalue=jvalue)
Example #8
    def evaluate(self,
                 data,
                 batch_size=32,
                 feature_cols="features",
                 label_cols="label"):
        """
        Evaluate model.

        :param data: validation data. It can be an XShards or a Spark DataFrame. If data is an
               XShards, each partition is a dictionary of {'x': feature, 'y': label}, where
               feature (label) is a numpy array or a list of numpy arrays.
        :param batch_size: Batch size used for validation. Default: 32.
        :param feature_cols: (Not supported yet) Feature column name(s) of data. Only used when
               data is a Spark DataFrame. Default: "features".
        :param label_cols: (Not supported yet) Label column name(s) of data. Only used when data
               is a Spark DataFrame. Default: "label".
        :return: validation results.
        """
        assert data is not None, "validation data shouldn't be None"
        assert self.metrics is not None, "metrics shouldn't be None, please specify the metrics" \
                                         " argument when creating this estimator."

        if isinstance(data, DataFrame):
            if isinstance(feature_cols, list):
                data, _, feature_cols = \
                    BigDLEstimator._combine_cols(data, [feature_cols], col_name="features")

            if isinstance(label_cols, list):
                data, _, label_cols = \
                    BigDLEstimator._combine_cols(data, label_cols, col_name="label")

            self.nn_estimator._setNNBatchSize(batch_size)._setNNFeaturesCol(feature_cols) \
                ._setNNLabelCol(label_cols)

            self.nn_estimator.setValidation(None, None, self.metrics,
                                            batch_size)
            if self.log_dir is not None and self.app_name is not None:
                from bigdl.dllib.optim.optimizer import TrainSummary
                from bigdl.dllib.optim.optimizer import ValidationSummary
                val_summary = ValidationSummary(log_dir=self.log_dir,
                                                app_name=self.app_name)
                self.nn_estimator.setValidationSummary(val_summary)

            result = self.nn_estimator._eval(data)

        elif isinstance(data, SparkXShards):
            from bigdl.orca.data.utils import xshard_to_sample
            val_feature_set = FeatureSet.sample_rdd(
                data.rdd.flatMap(xshard_to_sample))
            result = self.estimator.evaluate(val_feature_set, self.metrics,
                                             batch_size)
        else:
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)

        return bigdl_metric_results_to_dict(result)
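Per the docstring above, evaluate accepts either a SparkXShards or a Spark DataFrame whose feature and label columns match feature_cols / label_cols. A minimal sketch of preparing such a DataFrame with plain pyspark follows; the column names use the defaults, and the final evaluate call is left as a comment because it assumes an already-fitted estimator named est (an illustrative name, not defined in this excerpt):

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Two-feature binary classification data laid out as ("features", "label"),
# matching the default feature_cols / label_cols arguments above.
df = spark.createDataFrame(
    [(Vectors.dense([1.0, 2.0]), 0.0),
     (Vectors.dense([5.0, 6.0]), 1.0)],
    ["features", "label"])

# Assuming `est` is an already-fitted BigDL estimator:
# result = est.evaluate(df, batch_size=32,
#                       feature_cols="features", label_cols="label")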
Example #9
 def _get_validation_data(self):
     if self.validation_dataset is not None:
         jvalue = callZooFunc("float", "createTFDataFeatureSet",
                              self.val_rdd.map(lambda x: x[0]),
                              self.init_op_name, self.table_init_op,
                              self.output_names, self.output_types,
                              self.shard_index_op_name, self.inter_threads,
                              self.intra_threads)
         return FeatureSet(jvalue=jvalue)
     return None
Example #10
 def create_train_features_Set(self):
     image_set = self.get_raw_image_set(with_label=True)
     feature_set = FeatureSet.image_frame(image_set.to_image_frame())
     train_transformer = ChainedPreprocessing([
         ImageBytesToMat(),
         ImageResize(256, 256),
         ImageRandomCrop(224, 224),
         ImageRandomPreprocessing(ImageHFlip(), 0.5),
         ImageChannelNormalize(0.485, 0.456, 0.406, 0.229, 0.224, 0.225),
         ImageMatToTensor(to_RGB=True, format="NHWC"),
         ImageSetToSample(input_keys=["imageTensor"], target_keys=["label"])
     ])
     feature_set = feature_set.transform(train_transformer)
     feature_set = feature_set.transform(ImageFeatureToSample())
     return feature_set
Example #11
    def test_estimator_train_imagefeature(self):
        batch_size = 8
        epoch_num = 5
        images, labels = TestEstimator._generate_image_data(data_num=8,
                                                            img_shape=(200,
                                                                       200, 3))

        image_frame = DistributedImageFrame(self.sc.parallelize(images),
                                            self.sc.parallelize(labels))

        transformer = Pipeline([
            BytesToMat(),
            Resize(256, 256),
            CenterCrop(224, 224),
            ChannelNormalize(0.485, 0.456, 0.406, 0.229, 0.224, 0.225),
            MatToTensor(),
            ImageFrameToSample(target_keys=['label'])
        ])
        data_set = FeatureSet.image_frame(image_frame).transform(transformer)

        model = TestEstimator._create_cnn_model()

        optim_method = SGD(learningrate=0.01)

        estimator = Estimator(model, optim_method, "")
        estimator.set_constant_gradient_clipping(0.1, 1.2)
        estimator.train_imagefeature(train_set=data_set,
                                     criterion=ClassNLLCriterion(),
                                     end_trigger=MaxEpoch(epoch_num),
                                     checkpoint_trigger=EveryEpoch(),
                                     validation_set=data_set,
                                     validation_method=[Top1Accuracy()],
                                     batch_size=batch_size)
        eval_result = estimator.evaluate_imagefeature(
            validation_set=data_set, validation_method=[Top1Accuracy()])
        assert isinstance(eval_result[0], EvaluatedResult)
        assert len(eval_result) == 1
        predict_result = model.predict_image(
            image_frame.transform(transformer))
        assert predict_result.get_predict().count() == 8
Example #12
    def evaluate(self,
                 data,
                 batch_size=None,
                 feature_cols=None,
                 label_cols=None,
                 validation_metrics=None):
        """
        Evaluate model.

        :param data: evaluation data. It can be an XShards, a Spark DataFrame, a PyTorch
               DataLoader or a PyTorch DataLoader creator function.
               If data is an XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature (label) is a numpy array or a list of
               numpy arrays.
        :param batch_size: Batch size used for evaluation. Required unless data is a PyTorch
               DataLoader, in which case the DataLoader's own batch size is used.
        :param feature_cols: Feature column name(s) of data. Only used when data
               is a Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is
               a Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param validation_metrics: Orca validation metrics to be computed on validation_data.
        :return: validation results.
        """
        from bigdl.orca.data.utils import xshard_to_sample

        assert data is not None, "validation data shouldn't be None"
        assert self.metrics is not None, "metrics shouldn't be None, please specify the metrics" \
                                         " argument when creating this estimator."
        if isinstance(data, DataLoader):
            assert batch_size is None and data.batch_size > 0, \
                "When using a PyTorch DataLoader as input, specify the batch size in the " \
                "DataLoader and don't pass batch_size to the evaluate method."
        else:
            assert batch_size is not None and batch_size > 0, "batch_size should be greater than 0"

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols)
            val_feature_set = FeatureSet.sample_rdd(
                data.rdd.flatMap(xshard_to_sample))
            result = self.estimator.evaluate(val_feature_set, self.metrics,
                                             batch_size)
        elif isinstance(data, DataFrame):
            schema = data.schema
            val_feature_set = FeatureSet.sample_rdd(
                data.rdd.map(lambda row: row_to_sample(
                    row, schema, feature_cols, label_cols)))
            result = self.estimator.evaluate(val_feature_set, self.metrics,
                                             batch_size)
        elif isinstance(data, DataLoader) or callable(data) or isinstance(
                data, types.FunctionType):
            if isinstance(data, types.FunctionType):
                data = data(self.config, batch_size)
            val_feature_set = FeatureSet.pytorch_dataloader(data)
            result = self.estimator.evaluate_minibatch(val_feature_set,
                                                       self.metrics)
        else:
            raise ValueError(
                "Data should be a SparkXShards, a Spark DataFrame, a DataLoader or a callable "
                "data_creator, but got " + data.__class__.__name__)
        return bigdl_metric_results_to_dict(result)
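The callable branch above invokes data(self.config, batch_size), so a data_creator is expected to take the estimator's config and a batch size and return a PyTorch DataLoader. A minimal sketch of such a creator (the tensors are random placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset

def val_data_creator(config, batch_size):
    # Receives the estimator's config dict and the batch size, and returns a
    # DataLoader, matching the data(self.config, batch_size) call above.
    inputs = torch.randn(64, 2)
    targets = torch.randint(0, 2, (64, 1)).float()
    return DataLoader(TensorDataset(inputs, targets), batch_size=batch_size)

# It would then be passed directly as `data`, with an explicit batch_size:
# est.evaluate(val_data_creator, batch_size=32)  # `est` assumed to exist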
Example #13
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--dir',
                        default='/tmp/data',
                        metavar='N',
                        help='the folder to store MNIST data')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=256,
        metavar='N',
        help='input batch size for training per executor (default: 256)')
    parser.add_argument(
        '--test-batch-size',
        type=int,
        default=1000,
        metavar='N',
        help='input batch size for testing per executor (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=2,
                        metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument(
        '--deploy-mode',
        default="local",
        help='supported deploy modes: local, yarn-client and yarn-cluster')

    args = parser.parse_args()

    torch.manual_seed(args.seed)

    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        args.dir,
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        args.dir,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                              batch_size=args.test_batch_size,
                                              shuffle=False)

    # Init a local context for local mode; otherwise init on YARN, which requires
    # HADOOP_CONF_DIR and ZOO_CONDA_NAME to be set.
    if args.deploy_mode == "local":
        sc = init_orca_context()
    else:
        sc = init_orca_context(cluster_mode=args.deploy_mode,
                               cores=2,
                               memory="2g",
                               num_nodes=4)

    model = Net()
    model.train()
    criterion = nn.NLLLoss()

    adam = torch.optim.Adam(model.parameters(), lr=args.lr)
    zoo_model = TorchModel.from_pytorch(model)
    zoo_criterion = TorchLoss.from_pytorch(criterion)
    zoo_optim = TorchOptim.from_pytorch(adam)
    zoo_estimator = Estimator(zoo_model, optim_methods=zoo_optim)
    train_featureset = FeatureSet.pytorch_dataloader(train_loader)
    test_featureset = FeatureSet.pytorch_dataloader(test_loader)
    from bigdl.dllib.optim.optimizer import MaxEpoch, EveryEpoch
    zoo_estimator.train_minibatch(train_featureset,
                                  zoo_criterion,
                                  end_trigger=MaxEpoch(args.epochs),
                                  checkpoint_trigger=EveryEpoch(),
                                  validation_set=test_featureset,
                                  validation_method=[Accuracy()])
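Net is not defined in this excerpt. Because the script trains with nn.NLLLoss, the model must output log-probabilities; a minimal stand-in along those lines (an assumption, not the script's actual architecture) would be:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # Minimal MNIST classifier emitting log-probabilities, as nn.NLLLoss
    # expects. Layer sizes are illustrative only.
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten 1x28x28 MNIST images
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)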
Example #14
    def fit(self,
            data,
            epochs,
            batch_size=32,
            feature_cols="features",
            label_cols="label",
            caching_sample=True,
            validation_data=None,
            validation_trigger=None,
            checkpoint_trigger=None):
        """
        Train this BigDL model with train data.

        :param data: train data. It can be XShards or Spark DataFrame.
               If data is XShards, each partition is a dictionary of  {'x': feature,
               'y': label}, where feature(label) is a numpy array or a list of numpy arrays.
        :param epochs: Number of epochs to train the model.
        :param batch_size: Batch size used for training. Default: 32.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame. Default: "features".
        :param label_cols: Label column name(s) of data. Only used when data is a Spark DataFrame.
               Default: "label".
        :param caching_sample: whether to cache the Samples after preprocessing. Default: True
        :param validation_data: Validation data. XShards and Spark DataFrame are supported.
               If data is XShards, each partition is a dictionary of  {'x': feature,
               'y': label}, where feature(label) is a numpy array or a list of numpy arrays.
        :param validation_trigger: Orca Trigger to trigger validation computation.
        :param checkpoint_trigger: Orca Trigger to set a checkpoint.
        :return:
        """
        from bigdl.orca.learn.trigger import Trigger

        assert batch_size > 0, "batch_size should be greater than 0"

        if validation_data is not None:
            assert self.metrics is not None, \
                "You should provide metrics when creating this estimator if you provide " \
                "validation_data."

        if isinstance(data, DataFrame):
            if isinstance(feature_cols, list):
                data, validation_data, feature_cols = \
                    BigDLEstimator._combine_cols(data, feature_cols, col_name="features",
                                                 val_data=validation_data)

            if isinstance(label_cols, list):
                data, validation_data, label_cols = \
                    BigDLEstimator._combine_cols(data, label_cols, col_name="label",
                                                 val_data=validation_data)

            self.nn_estimator.setBatchSize(batch_size).setMaxEpoch(epochs) \
                .setCachingSample(caching_sample).setFeaturesCol(feature_cols) \
                .setLabelCol(label_cols)

            if validation_data is not None:
                assert isinstance(validation_data, DataFrame), \
                    "validation_data should be a spark DataFrame."
                assert validation_trigger is not None, \
                    "You should provide validation_trigger if you provide validation_data."
                validation_trigger = Trigger.convert_trigger(
                    validation_trigger)
                self.nn_estimator.setValidation(validation_trigger,
                                                validation_data, self.metrics,
                                                batch_size)
            if self.log_dir is not None and self.app_name is not None:
                from bigdl.dllib.optim.optimizer import TrainSummary
                from bigdl.dllib.optim.optimizer import ValidationSummary
                train_summary = TrainSummary(log_dir=self.log_dir,
                                             app_name=self.app_name)
                self.nn_estimator.setTrainSummary(train_summary)
                val_summary = ValidationSummary(log_dir=self.log_dir,
                                                app_name=self.app_name)
                self.nn_estimator.setValidationSummary(val_summary)
            if self.model_dir is not None and checkpoint_trigger is not None:
                checkpoint_trigger = Trigger.convert_trigger(
                    checkpoint_trigger)
                self.nn_estimator.setCheckpoint(self.model_dir,
                                                checkpoint_trigger)

            self.nn_model = self.nn_estimator.fit(data)
            self.is_nnframe_fit = True
        elif isinstance(data, SparkXShards):
            from bigdl.orca.data.utils import xshard_to_sample

            end_trigger = MaxEpoch(epochs)
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

            if isinstance(data, SparkXShards):
                train_rdd = data.rdd.flatMap(xshard_to_sample)
                train_feature_set = FeatureSet.sample_rdd(train_rdd)
                if validation_data is None:
                    val_feature_set = None
                else:
                    assert isinstance(validation_data, SparkXShards), \
                        "validation_data should be a XShards"
                    val_feature_set = FeatureSet.sample_rdd(
                        validation_data.rdd.flatMap(xshard_to_sample))
                if self.log_dir is not None and self.app_name is not None:
                    self.estimator.set_tensorboard(self.log_dir, self.app_name)
                self.estimator.train(train_feature_set, self.loss, end_trigger,
                                     checkpoint_trigger, val_feature_set,
                                     self.metrics, batch_size)
                self.is_nnframe_fit = False
            else:
                raise ValueError(
                    "Data and validation data should be XShards, but get " +
                    data.__class__.__name__)
        else:
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)
        return self
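Per the docstring, when data is an XShards each partition is a dictionary {'x': feature, 'y': label} of numpy arrays. A minimal sketch of one partition's layout (plain numpy; building the SparkXShards itself uses bigdl.orca.data utilities that are not part of this example):

import numpy as np

# One partition in the layout fit() expects for XShards input: 'x' holds the
# features and 'y' the labels; either may also be a list of numpy arrays for
# multi-input / multi-output models.
partition = {
    "x": np.random.rand(16, 2).astype(np.float32),          # 16 samples, 2 features
    "y": np.random.randint(0, 2, size=(16, 1)).astype(np.float32),
}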
Example #15
def main():
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet18)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--max_epochs', default=90, type=int, metavar='N',
                        help='number of max epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('-p', '--print-freq', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--world-size', default=-1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--cores', default=4, type=int,
                        help='num of CPUs to use.')
    parser.add_argument('--nodes', default=1, type=int,
                        help='num of nodes to use.')
    parser.add_argument('--executor_memory', default='20g', type=str,
                        help='size of executor memory.')
    parser.add_argument('--driver_memory', default='20g', type=str,
                        help='size of driver memory.')
    parser.add_argument('--driver_cores', default=1, type=int,
                        help='num of driver cores to use.')
    args = parser.parse_args()
    if os.environ.get('HADOOP_CONF_DIR') is None:
        sc = init_spark_on_local(cores=args.cores, conf={"spark.driver.memory": "20g"})
    else:
        hadoop_conf_dir = os.environ.get('HADOOP_CONF_DIR')
        num_executors = args.nodes
        executor_memory = args.executor_memory
        driver_memory = args.driver_memory
        driver_cores = args.driver_cores
        num_cores_per_executor = args.cores
        os.environ['ZOO_MKL_NUMTHREADS'] = str(num_cores_per_executor)
        os.environ['OMP_NUM_THREADS'] = str(num_cores_per_executor)
        sc = init_spark_on_yarn(
            hadoop_conf=hadoop_conf_dir,
            conda_name=detect_conda_env_name(),  # auto detect current conda env name
            num_executors=num_executors,
            executor_cores=num_cores_per_executor,
            executor_memory=executor_memory,
            driver_memory=driver_memory,
            driver_cores=driver_cores,
            conf={"spark.rpc.message.maxSize": "1024",
                  "spark.task.maxFailures": "1",
                  "spark.driver.extraJavaOptions": "-Dbigdl.failure.retryTimes=1"})

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True)

    model = torchvision.models.resnet50()
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False)

    # 1,281,167 is the number of images in the ImageNet-1k training set.
    iterationPerEpoch = int(math.ceil(float(1281167) / args.batch_size))
    # Step schedule: multiply the learning rate by 0.1 every 30 epochs.
    step = Step(iterationPerEpoch * 30, 0.1)
    zooOptimizer = SGD(args.lr, momentum=args.momentum, dampening=0.0,
                       leaningrate_schedule=step, weightdecay=args.weight_decay)
    zooModel = TorchModel.from_pytorch(model)
    criterion = torch.nn.CrossEntropyLoss()
    zooCriterion = TorchLoss.from_pytorch(criterion)
    estimator = Estimator(zooModel, optim_methods=zooOptimizer)
    train_featureSet = FeatureSet.pytorch_dataloader(train_loader)
    test_featureSet = FeatureSet.pytorch_dataloader(val_loader)
    estimator.train_minibatch(train_featureSet, zooCriterion, end_trigger=MaxEpoch(args.max_epochs),
                              checkpoint_trigger=EveryEpoch(), validation_set=test_featureSet,
                              validation_method=[Accuracy(), Top5Accuracy()])