示例#1
0
    def fit(self, data, epochs=1, batch_size=32, validation_data=None, validation_methods=None,
            checkpoint_trigger=None):
        """
        Train the model.

        :param data: training data. Either a SparkXShards, whose partitions are converted to
               Samples with ``to_sample``, or a PyTorch DataLoader / callable data_creator.
        :param epochs: number of epochs to train; used to build a ``MaxEpoch`` end trigger.
        :param batch_size: batch size, must be greater than 0. Only forwarded to the
               estimator for SparkXShards input; it is not passed on the DataLoader path.
        :param validation_data: optional validation data. Must be a SparkXShards when
               ``data`` is a SparkXShards, otherwise a DataLoader or callable data_creator.
        :param validation_methods: validation methods forwarded to the estimator as-is.
        :param checkpoint_trigger: checkpoint trigger forwarded to the estimator as-is.
        :return: self.
        :raises ValueError: if ``data`` is not one of the supported types.
        """
        from zoo.orca.data.utils import to_sample

        end_trigger = MaxEpoch(epochs)
        assert batch_size > 0, "batch_size should be greater than 0"

        if isinstance(data, SparkXShards):
            train_rdd = data.rdd.flatMap(to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, SparkXShards), "validation_data should be a " \
                                                                  "SparkXShards"
                val_feature_set = FeatureSet.sample_rdd(validation_data.rdd.flatMap(to_sample))

            self.estimator.train(train_feature_set, self.loss, end_trigger, checkpoint_trigger,
                                 val_feature_set, validation_methods, batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            train_feature_set = FeatureSet.pytorch_dataloader(data, "", "")
            if validation_data is None:
                val_feature_set = None
            else:
                # Check the validation input itself; `data` is already known to be a
                # DataLoader or callable on this branch, so checking it would always pass.
                assert isinstance(validation_data, DataLoader) or callable(validation_data), \
                    "validation_data should be a pytorch DataLoader or a callable data_creator"
                val_feature_set = FeatureSet.pytorch_dataloader(validation_data)

            self.estimator.train_minibatch(train_feature_set, self.loss, end_trigger,
                                           checkpoint_trigger, val_feature_set, validation_methods)
        else:
            raise ValueError("Data and validation data should be SparkXShards, DataLoaders or "
                             "callable data_creators but get " + data.__class__.__name__)
        return self
示例#2
0
    def evaluate(self,
                 data,
                 batch_size=32,
                 feature_cols=None,
                 label_cols=None):
        """
        Evaluate the model against ``self.metrics``.

        :param data: evaluation data: a SparkXShards, a Spark DataFrame, a PyTorch
               DataLoader or a callable data_creator.
        :param batch_size: batch size; used for SparkXShards and DataFrame input only.
        :param feature_cols: feature column name(s); used for DataFrame input only.
        :param label_cols: label column name(s); used for DataFrame input only.
        :return: a dict of metric results built by ``bigdl_metric_results_to_dict``.
        :raises ValueError: if ``data`` is of an unsupported type.
        """
        from zoo.orca.data.utils import xshard_to_sample

        assert data is not None, "validation data shouldn't be None"
        assert self.metrics is not None, "metrics shouldn't be None, please specify the metrics" \
                                         " argument when creating this estimator."

        if isinstance(data, SparkXShards):
            sample_rdd = data.rdd.flatMap(xshard_to_sample)
        elif isinstance(data, DataFrame):
            df_schema = data.schema
            sample_rdd = data.rdd.map(
                lambda r: row_to_sample(r, df_schema, feature_cols, label_cols))
        elif isinstance(data, DataLoader) or callable(data):
            # Mini-batch path: the loader does its own batching, batch_size is not used.
            loader_set = FeatureSet.pytorch_dataloader(data)
            return bigdl_metric_results_to_dict(
                self.estimator.evaluate_minibatch(loader_set, self.metrics))
        else:
            raise ValueError(
                "Data should be a SparkXShards, a DataLoader or a callable "
                "data_creator, but get " + data.__class__.__name__)

        sample_set = FeatureSet.sample_rdd(sample_rdd)
        return bigdl_metric_results_to_dict(
            self.estimator.evaluate(sample_set, self.metrics, batch_size))
示例#3
0
 def _handle_xshards(self, data, validation_data):
     """
     Convert training (and optional validation) SparkXShards into Sample FeatureSets.

     :return: a ``(train_feature_set, val_feature_set)`` tuple; the second element is
              None when ``validation_data`` is None.
     """
     def _as_feature_set(shards):
         return FeatureSet.sample_rdd(shards.rdd.flatMap(xshard_to_sample))

     train_set = _as_feature_set(data)
     if validation_data is not None:
         assert isinstance(validation_data, SparkXShards), "validation_data should be a " \
                                                           "SparkXShards"
         return train_set, _as_feature_set(validation_data)
     return train_set, None
示例#4
0
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            labels_cols=None,
            validation_data=None,
            validation_metrics=None,
            checkpoint_trigger=None):
        """
        Train the model.

        :param data: training data. Either a SparkXShards, whose partitions are converted to
               Samples with ``to_sample``, or a PyTorch DataLoader / callable data_creator.
        :param epochs: number of epochs to train; used to build a ``MaxEpoch`` end trigger.
        :param batch_size: batch size, must be greater than 0. Only forwarded to the
               estimator for SparkXShards input.
        :param feature_cols: currently unused by this method.
        :param labels_cols: currently unused by this method.
        :param validation_data: optional validation data. Must be a SparkXShards when
               ``data`` is a SparkXShards, otherwise a DataLoader or callable data_creator.
        :param validation_metrics: Orca metrics, converted with
               ``Metrics.convert_metrics_list``.
        :param checkpoint_trigger: Orca trigger, converted with ``Trigger.convert_trigger``.
        :return: self.
        :raises ValueError: if ``data`` is not one of the supported types.
        """
        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics
        from zoo.orca.learn.trigger import Trigger

        end_trigger = MaxEpoch(epochs)
        assert batch_size > 0, "batch_size should be greater than 0"
        validation_metrics = Metrics.convert_metrics_list(validation_metrics)
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if self.log_dir is not None and self.app_name is not None:
            self.estimator.set_tensorboard(self.log_dir, self.app_name)

        if isinstance(data, SparkXShards):
            train_rdd = data.rdd.flatMap(to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, SparkXShards), "validation_data should be a " \
                                                                  "SparkXShards"
                val_feature_set = FeatureSet.sample_rdd(
                    validation_data.rdd.flatMap(to_sample))

            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set,
                                 validation_metrics, batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            train_feature_set = FeatureSet.pytorch_dataloader(data, "", "")
            if validation_data is None:
                val_feature_set = None
            else:
                # Check the validation input itself; `data` is already known to be a
                # DataLoader or callable on this branch, so checking it would always pass.
                assert isinstance(validation_data, DataLoader) or callable(validation_data), \
                    "validation_data should be a pytorch DataLoader or a callable data_creator"
                val_feature_set = FeatureSet.pytorch_dataloader(
                    validation_data)

            self.estimator.train_minibatch(train_feature_set, self.loss,
                                           end_trigger, checkpoint_trigger,
                                           val_feature_set, validation_metrics)
        else:
            raise ValueError(
                "Data and validation data should be SparkXShards, DataLoaders or "
                "callable data_creators but get " + data.__class__.__name__)
        return self
示例#5
0
    def evaluate(self,
                 data,
                 batch_size=32,
                 feature_cols=None,
                 label_cols=None,
                 validation_metrics=None):
        """
        Evaluate the model with the metrics given when this estimator was created.

        :param data: evaluation data. It can be an XShards, a Spark DataFrame, a PyTorch
               DataLoader or a PyTorch DataLoader creator function.
               If data is an XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a list of
               numpy arrays.
        :param batch_size: Batch size used for evaluation. Only used when data is an XShards
               or a Spark DataFrame; not passed on the DataLoader path. Default: 32.
        :param feature_cols: Feature column name(s) of data. Only used when data
               is a Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is
               a Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param validation_metrics: Currently ignored; evaluation always uses the
               ``metrics`` the estimator was created with (``self.metrics``).
        :return: a dict of metric results built by ``bigdl_metric_results_to_dict``.
        :raises ValueError: if data is not one of the supported types.
        """
        from zoo.orca.data.utils import xshard_to_sample

        assert data is not None, "validation data shouldn't be None"
        assert self.metrics is not None, "metrics shouldn't be None, please specify the metrics" \
                                         " argument when creating this estimator."

        if isinstance(data, SparkXShards):
            # XShards of Pandas DataFrames are first converted using the column names.
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols)
            val_feature_set = FeatureSet.sample_rdd(
                data.rdd.flatMap(xshard_to_sample))
            result = self.estimator.evaluate(val_feature_set, self.metrics,
                                             batch_size)
        elif isinstance(data, DataFrame):
            schema = data.schema
            val_feature_set = FeatureSet.sample_rdd(
                data.rdd.map(lambda row: row_to_sample(
                    row, schema, feature_cols, label_cols)))
            result = self.estimator.evaluate(val_feature_set, self.metrics,
                                             batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            val_feature_set = FeatureSet.pytorch_dataloader(data)
            result = self.estimator.evaluate_minibatch(val_feature_set,
                                                       self.metrics)
        else:
            raise ValueError(
                "Data should be a SparkXShards, a DataLoader or a callable "
                "data_creator, but get " + data.__class__.__name__)
        return bigdl_metric_results_to_dict(result)
示例#6
0
    def _handle_dataframe(self, data, validation_data, feature_cols, label_cols):
        """
        Convert training (and optional validation) Spark DataFrames into Sample FeatureSets.

        :param data: training DataFrame.
        :param validation_data: optional validation DataFrame, or None.
        :param feature_cols: feature column name(s) passed to ``row_to_sample``.
        :param label_cols: label column name(s) passed to ``row_to_sample``.
        :return: a ``(train_feature_set, val_feature_set)`` tuple; the second element is
                 None when ``validation_data`` is None.
        """
        schema = data.schema
        train_rdd = data.rdd.map(lambda row: row_to_sample(row, schema, feature_cols, label_cols))
        train_feature_set = FeatureSet.sample_rdd(train_rdd)
        if validation_data is None:
            val_feature_set = None
        else:
            assert isinstance(validation_data, DataFrame), "validation_data should also be a " \
                                                           "DataFrame"
            # Decode validation rows against the validation DataFrame's own schema: its
            # column order/types may differ from the training DataFrame's.
            val_schema = validation_data.schema
            val_feature_set = FeatureSet.sample_rdd(validation_data.rdd.map(
                lambda row: row_to_sample(row, val_schema, feature_cols, label_cols)))

        return train_feature_set, val_feature_set
示例#7
0
 def get_training_data(self):
     """
     Build a training FeatureSet of Samples from ``self.rdd``.

     Each record is flattened with ``nest.flatten`` to form the Sample features, and a
     constant ``np.array([0.0])`` is used as the label.
     """
     def _to_sample(record):
         return Sample.from_ndarray(nest.flatten(record), np.array([0.0]))

     return FeatureSet.sample_rdd(self.rdd.map(_to_sample),
                                  sequential_order=self.sequential_order,
                                  shuffle=self.shuffle)
示例#8
0
    def evaluate(self,
                 data,
                 batch_size=32,
                 feature_cols=None,
                 label_cols=None):
        """
        Evaluate the model against ``self.metrics``.

        :param data: validation data. Currently only an XShards is accepted, where each
               partition is a dictionary of {'x': feature, 'y': label} and feature(label) is a
               numpy array or a list of numpy arrays. A Spark DataFrame raises
               NotImplementedError.
        :param batch_size: batch size used for validation. Default: 32.
        :param feature_cols: not supported yet; unused.
        :param label_cols: not supported yet; unused.
        :return: a dict of metric results built by ``bigdl_metric_results_to_dict``.
        """
        assert data is not None, "validation data shouldn't be None"
        assert self.metrics is not None, "metrics shouldn't be None, please specify the metrics" \
                                         " argument when creating this estimator."

        if isinstance(data, DataFrame):
            raise NotImplementedError
        if not isinstance(data, SparkXShards):
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)

        from zoo.orca.data.utils import xshard_to_sample
        sample_set = FeatureSet.sample_rdd(data.rdd.flatMap(xshard_to_sample))
        metric_results = self.estimator.evaluate(sample_set, self.metrics, batch_size)
        return bigdl_metric_results_to_dict(metric_results)
示例#9
0
    def evaluate(self,
                 data,
                 batch_size=32,
                 feature_cols=None,
                 labels_cols=None,
                 validation_metrics=None):
        """
        Evaluate the model with the given Orca metrics.

        :param data: a SparkXShards, a PyTorch DataLoader or a callable data_creator.
        :param batch_size: batch size; only used for SparkXShards input.
        :param feature_cols: currently unused.
        :param labels_cols: currently unused.
        :param validation_metrics: Orca metrics, converted with
               ``Metrics.convert_metrics_list``.
        :return: whatever the underlying estimator's evaluate call returns.
        :raises ValueError: if ``data`` is of an unsupported type.
        """
        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics

        assert data is not None, "validation data shouldn't be None"
        metrics = Metrics.convert_metrics_list(validation_metrics)

        if isinstance(data, SparkXShards):
            sample_set = FeatureSet.sample_rdd(data.rdd.flatMap(to_sample))
            return self.estimator.evaluate(sample_set, metrics, batch_size)
        if isinstance(data, DataLoader) or callable(data):
            loader_set = FeatureSet.pytorch_dataloader(data)
            return self.estimator.evaluate_minibatch(loader_set, metrics)
        raise ValueError(
            "Data should be a SparkXShards, a DataLoader or a callable "
            "data_creator, but get " + data.__class__.__name__)
示例#10
0
    def test_estimator_train(self):
        """Train a small CNN via ``Estimator.train`` and check the prediction count."""
        batch_size = 8
        epoch_num = 5

        images, labels = TestEstimator._generate_image_data(data_num=8,
                                                            img_shape=(3, 224,
                                                                       224))

        image_rdd = self.sc.parallelize(images)
        label_rdd = self.sc.parallelize(labels)

        sample_rdd = image_rdd.zip(label_rdd).map(
            lambda img_label: zoo.common.Sample.from_ndarray(
                img_label[0], img_label[1]))

        data_set = FeatureSet.sample_rdd(sample_rdd)

        model = TestEstimator._create_cnn_model()

        optim_method = SGD(learningrate=0.01)

        estimator = Estimator(model, optim_method, "")
        estimator.set_constant_gradient_clipping(0.1, 1.2)
        estimator.train(train_set=data_set,
                        criterion=ClassNLLCriterion(),
                        end_trigger=MaxEpoch(epoch_num),
                        checkpoint_trigger=EveryEpoch(),
                        validation_set=data_set,
                        validation_method=[Top1Accuracy()],
                        batch_size=batch_size)
        predict_result = model.predict(sample_rdd)
        # Compare explicitly: `assert (a, b)` tests a non-empty tuple and always passes.
        assert predict_result.count() == 8
示例#11
0
 def get_training_data(self):
     """
     Build a training FeatureSet from the samples of ``self.text_set``.

     Each sample's ``features`` and ``labels`` are combined with ``+`` (presumably list
     concatenation) into the new features, and a constant ``[0.0]`` tensor is the label.
     """
     def _to_sample(sample):
         merged_features = sample.features + sample.labels
         return Sample.from_jtensor(features=merged_features,
                                    labels=JTensor.from_ndarray(np.array([0.0])))

     samples = self.text_set.get_samples().map(_to_sample)
     return FeatureSet.sample_rdd(samples,
                                  sequential_order=self.sequential_order,
                                  shuffle=self.shuffle)
示例#12
0
    def evaluate(self,
                 data,
                 batch_size=32,
                 feature_cols="features",
                 label_cols="label"):
        """
        Evaluate model.

        :param data: validation data. It can be an XShards or a Spark DataFrame. If data is an
               XShards, each partition is a dictionary of {'x': feature, 'y': label}, where
               feature(label) is a numpy array or a list of numpy arrays.
        :param batch_size: Batch size used for validation. Default: 32.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame; a list of names is merged with ``_combine_cols`` into one column
               (col_name "features"). Default: "features".
        :param label_cols: Label column name(s) of data. Only used when data is a Spark
               DataFrame; a list of names is merged with ``_combine_cols`` into one column
               (col_name "label"). Default: "label".
        :return: a dict of metric results built by ``bigdl_metric_results_to_dict``.
        :raises ValueError: if data is neither an XShards nor a Spark DataFrame.
        """
        assert data is not None, "validation data shouldn't be None"
        assert self.metrics is not None, "metrics shouldn't be None, please specify the metrics" \
                                         " argument when creating this estimator."

        if isinstance(data, DataFrame):
            # Pass the column list itself (as the label_cols call below does); wrapping it
            # in another list would hand _combine_cols a list of lists.
            if isinstance(feature_cols, list):
                data, _, feature_cols = \
                    BigDLEstimator._combine_cols(data, feature_cols, col_name="features")

            if isinstance(label_cols, list):
                data, _, label_cols = \
                    BigDLEstimator._combine_cols(data, label_cols, col_name="label")

            self.nn_estimator._setNNBatchSize(batch_size)._setNNFeaturesCol(feature_cols) \
                ._setNNLabelCol(label_cols)

            self.nn_estimator.setValidation(None, None, self.metrics,
                                            batch_size)
            if self.log_dir is not None and self.app_name is not None:
                from bigdl.optim.optimizer import ValidationSummary
                val_summary = ValidationSummary(log_dir=self.log_dir,
                                                app_name=self.app_name)
                self.nn_estimator.setValidationSummary(val_summary)

            result = self.nn_estimator._eval(data)

        elif isinstance(data, SparkXShards):
            from zoo.orca.data.utils import xshard_to_sample
            val_feature_set = FeatureSet.sample_rdd(
                data.rdd.flatMap(xshard_to_sample))
            result = self.estimator.evaluate(val_feature_set, self.metrics,
                                             batch_size)
        else:
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)

        return bigdl_metric_results_to_dict(result)
示例#13
0
 def get_validation_data(self):
     """
     Build a validation FeatureSet from ``self.val_rdd``, or return None if it is unset.

     Each record is flattened with ``nest.flatten`` into the Sample features; the label is
     a constant ``np.array([0.0])``.
     """
     if self.val_rdd is None:
         return None
     samples = self.val_rdd.map(
         lambda record: Sample.from_ndarray(nest.flatten(record), np.array([0.0])))
     return FeatureSet.sample_rdd(samples,
                                  sequential_order=self.sequential_order,
                                  shuffle=self.shuffle)
示例#14
0
 def get_validation_data(self):
     """
     Build a batched validation FeatureSet from ``self.val_rdd``, or return None if unset.

     Records are flattened with ``nest.flatten`` into Sample features with a constant
     ``np.array([0.0])`` label, then grouped by ``SampleToMiniBatch(self.batch_size)``.
     """
     if self.val_rdd is None:
         return None
     samples = self.val_rdd.map(
         lambda record: Sample.from_ndarray(nest.flatten(record), np.array([0.0])))
     feature_set = FeatureSet.sample_rdd(samples,
                                         sequential_order=self.sequential_order,
                                         shuffle=self.shuffle)
     return feature_set.transform(SampleToMiniBatch(self.batch_size))
示例#15
0
def get_featureset(x, y, shuffle=True):
    """
    Build a FeatureSet of Samples from paired inputs ``x`` and ``y``.

    Both are converted with ``.data.numpy()`` (presumably torch tensors), split along the
    first axis into one-row chunks, squeezed, and paired by index into Samples. The shapes
    of the first chunks are printed. Relies on a module-level SparkContext ``sc``.
    """
    x_chunks = np.split(x.data.numpy(), x.shape[0])
    y_chunks = np.split(y.data.numpy(), y.shape[0])
    print(x_chunks[0].shape)
    print(y_chunks[0].shape)
    # Index into y_chunks (rather than zip) so a shorter y still raises IndexError.
    samples = [Sample.from_ndarray(np.squeeze(chunk), np.squeeze(y_chunks[idx]))
               for idx, chunk in enumerate(x_chunks)]
    return FeatureSet.sample_rdd(sc.parallelize(samples), shuffle=shuffle)
示例#16
0
 def get_validation_data(self):
     """
     Build a validation FeatureSet from ``self.validation_text_set``, or return None.

     Each sample's ``features`` and ``labels`` are combined with ``+`` (presumably list
     concatenation) into the new features, and a constant ``[0.0]`` tensor is the label.
     """
     if self.validation_text_set is None:
         return None

     def _to_sample(sample):
         return Sample.from_jtensor(
             features=sample.features + sample.labels,
             labels=JTensor.from_ndarray(np.array([0.0])))

     samples = self.validation_text_set.get_samples().map(_to_sample)
     return FeatureSet.sample_rdd(samples,
                                  sequential_order=self.sequential_order,
                                  shuffle=self.shuffle)
示例#17
0
 def get_validation_data(self):
     """
     Build a batched validation FeatureSet from ``self.validation_text_set``, or None.

     Each sample's ``features`` and ``labels`` are combined with ``+`` (presumably list
     concatenation) into the new features with a constant ``[0.0]`` label; the result is
     grouped by ``SampleToMiniBatch(self.batch_size)``.
     """
     if self.validation_text_set is None:
         return None

     def _to_sample(sample):
         return Sample.from_jtensor(
             features=sample.features + sample.labels,
             labels=JTensor.from_ndarray(np.array([0.0])))

     samples = self.validation_text_set.get_samples().map(_to_sample)
     feature_set = FeatureSet.sample_rdd(samples,
                                         sequential_order=self.sequential_order,
                                         shuffle=self.shuffle)
     return feature_set.transform(SampleToMiniBatch(self.batch_size))
示例#18
0
    def evaluate(self, data, batch_size=32, feature_cols=None, label_cols=None):
        """
        Evaluate the model on SparkXShards data using ``self.metrics``.

        :param data: evaluation data. Only a SparkXShards is supported; a Spark DataFrame
               raises NotImplementedError.
        :param batch_size: batch size used for evaluation. Default: 32.
        :param feature_cols: currently unused.
        :param label_cols: currently unused.
        :return: a dict of metric results built by ``bigdl_metric_results_to_dict``.
        :raises NotImplementedError: if data is a Spark DataFrame.
        :raises ValueError: if data is neither a SparkXShards nor a Spark DataFrame.
        """
        assert data is not None, "validation data shouldn't be None"
        assert self.metrics is not None, "metrics shouldn't be None, please specify the metrics" \
                                         " argument when creating this estimator."

        if isinstance(data, DataFrame):
            raise NotImplementedError
        elif isinstance(data, SparkXShards):
            from zoo.orca.data.utils import xshard_to_sample

            val_feature_set = FeatureSet.sample_rdd(data.rdd.flatMap(xshard_to_sample))
            result = self.estimator.evaluate(val_feature_set, self.metrics, batch_size)
        else:
            raise ValueError("Data should be XShards or Spark DataFrame, but get " +
                             data.__class__.__name__)

        return bigdl_metric_results_to_dict(result)
示例#19
0
    def evaluate(self, data, validation_methods=None, batch_size=32):
        """
        Evaluate the model on SparkXShards data with the given Orca metrics.

        :param data: evaluation data; only a SparkXShards is supported, a Spark DataFrame
               raises NotImplementedError.
        :param validation_methods: Orca metrics, converted with
               ``Metrics.convert_metrics_list``.
        :param batch_size: batch size used for evaluation. Default: 32.
        :return: whatever ``self.estimator.evaluate`` returns.
        :raises ValueError: if data is neither a SparkXShards nor a Spark DataFrame.
        """
        assert data is not None, "validation data shouldn't be None"

        if isinstance(data, DataFrame):
            raise NotImplementedError
        if not isinstance(data, SparkXShards):
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)

        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics

        metrics = Metrics.convert_metrics_list(validation_methods)
        sample_set = FeatureSet.sample_rdd(data.rdd.flatMap(to_sample))
        return self.estimator.evaluate(sample_set, metrics, batch_size)
示例#20
0
    def fit(self,
            data,
            epochs,
            batch_size=32,
            feature_cols="features",
            label_cols="label",
            caching_sample=True,
            validation_data=None,
            validation_trigger=None,
            checkpoint_trigger=None):
        """
        Train this BigDL model with train data.

        :param data: train data. It can be XShards or Spark DataFrame.
               If data is XShards, each partition is a dictionary of {'x': feature,
               'y': label}, where feature(label) is a numpy array or a list of numpy arrays.
        :param epochs: Number of epochs to train the model.
        :param batch_size: Batch size used for training. Default: 32.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame. Default: "features".
        :param label_cols: Label column name(s) of data. Only used when data is a Spark
               DataFrame. Default: "label".
        :param caching_sample: whether to cache the Samples after preprocessing. Only used
               when data is a Spark DataFrame. Default: True
        :param validation_data: Validation data, of the same kind as data (XShards or Spark
               DataFrame). If it is XShards, each partition is a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a list of
               numpy arrays. Requires metrics to have been given when creating this estimator.
        :param validation_trigger: Orca Trigger to trigger validation computation. Required
               when validation_data is a Spark DataFrame; not used for XShards input.
        :param checkpoint_trigger: Orca Trigger to set a checkpoint.
        :return: self.
        :raises ValueError: if data is neither an XShards nor a Spark DataFrame.
        """
        from zoo.orca.learn.trigger import Trigger

        assert batch_size > 0, "batch_size should be greater than 0"

        if validation_data is not None:
            assert self.metrics is not None, \
                "You should provide metrics when creating this estimator if you provide " \
                "validation_data."

        if isinstance(data, DataFrame):
            if isinstance(feature_cols, list):
                data, validation_data, feature_cols = \
                    BigDLEstimator._combine_cols(data, feature_cols, col_name="features",
                                                 val_data=validation_data)

            if isinstance(label_cols, list):
                data, validation_data, label_cols = \
                    BigDLEstimator._combine_cols(data, label_cols, col_name="label",
                                                 val_data=validation_data)

            self.nn_estimator.setBatchSize(batch_size).setMaxEpoch(epochs)\
                .setCachingSample(caching_sample).setFeaturesCol(feature_cols)\
                .setLabelCol(label_cols)

            if validation_data is not None:
                assert isinstance(validation_data, DataFrame), \
                    "validation_data should be a spark DataFrame."
                assert validation_trigger is not None, \
                    "You should provide validation_trigger if you provide validation_data."
                validation_trigger = Trigger.convert_trigger(
                    validation_trigger)
                self.nn_estimator.setValidation(validation_trigger,
                                                validation_data, self.metrics,
                                                batch_size)
            if self.log_dir is not None and self.app_name is not None:
                from bigdl.optim.optimizer import TrainSummary
                from bigdl.optim.optimizer import ValidationSummary
                train_summary = TrainSummary(log_dir=self.log_dir,
                                             app_name=self.app_name)
                self.nn_estimator.setTrainSummary(train_summary)
                val_summary = ValidationSummary(log_dir=self.log_dir,
                                                app_name=self.app_name)
                self.nn_estimator.setValidationSummary(val_summary)
            if self.model_dir is not None and checkpoint_trigger is not None:
                checkpoint_trigger = Trigger.convert_trigger(
                    checkpoint_trigger)
                self.nn_estimator.setCheckpoint(self.model_dir,
                                                checkpoint_trigger)

            self.nn_model = self.nn_estimator.fit(data)
            self.is_nnframe_fit = True
        elif isinstance(data, SparkXShards):
            from zoo.orca.data.utils import xshard_to_sample

            end_trigger = MaxEpoch(epochs)
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

            train_rdd = data.rdd.flatMap(xshard_to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, SparkXShards), \
                    "validation_data should be a XShards"
                val_feature_set = FeatureSet.sample_rdd(
                    validation_data.rdd.flatMap(xshard_to_sample))
            if self.log_dir is not None and self.app_name is not None:
                self.estimator.set_tensorboard(self.log_dir, self.app_name)
            # NOTE(review): validation_trigger is not forwarded on this path — confirm
            # whether estimator.train should receive it.
            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set,
                                 self.metrics, batch_size)
            self.is_nnframe_fit = False
        else:
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)
        return self
示例#21
0
 def get_validation_data(self):
     """Wrap ``self.validation_rdd`` in a FeatureSet using this object's order/shuffle flags."""
     validation_samples = self.validation_rdd
     return FeatureSet.sample_rdd(
         validation_samples,
         sequential_order=self.sequential_order,
         shuffle=self.shuffle,
     )
示例#22
0
 def get_training_data(self):
     """Wrap ``self.train_rdd`` in a FeatureSet using this object's order/shuffle flags."""
     training_samples = self.train_rdd
     return FeatureSet.sample_rdd(
         training_samples,
         sequential_order=self.sequential_order,
         shuffle=self.shuffle,
     )
示例#23
0
    def fit(self,
            data,
            epochs,
            feature_cols="features",
            labels_cols="label",
            batch_size=32,
            caching_sample=True,
            val_data=None,
            val_trigger=None,
            val_methods=None,
            checkpoint_trigger=None):
        """
        Train the model on a Spark DataFrame or SparkXShards.

        :param data: training data: a Spark DataFrame or a SparkXShards.
        :param epochs: number of epochs to train.
        :param feature_cols: feature column name(s); only used for DataFrame input. A list is
               merged with ``_combine_cols`` into one column (col_name "features").
        :param labels_cols: label column name(s); only used for DataFrame input. A list is
               merged with ``_combine_cols`` into one column (col_name "label").
        :param batch_size: batch size, must be greater than 0. Default: 32.
        :param caching_sample: whether to cache Samples; only used for DataFrame input.
        :param val_data: optional validation data of the same kind as ``data``.
        :param val_trigger: Orca trigger for validation; required with DataFrame val_data,
               not used for SparkXShards input.
        :param val_methods: Orca metrics, converted with ``Metrics.convert_metrics_list``;
               required with DataFrame val_data.
        :param checkpoint_trigger: Orca trigger for checkpointing. For DataFrame input it is
               only applied when ``self.model_dir`` is set.
        :return: self.
        :raises ValueError: if data is neither a SparkXShards nor a Spark DataFrame.
        """
        from zoo.orca.learn.metrics import Metrics
        from zoo.orca.learn.trigger import Trigger

        assert batch_size > 0, "batch_size should be greater than 0"

        if isinstance(data, DataFrame):
            if isinstance(feature_cols, list):
                data, val_data, feature_cols = \
                    BigDLEstimatorWrapper._combine_cols(data, feature_cols, col_name="features",
                                                        val_data=val_data)

            if isinstance(labels_cols, list):
                data, val_data, labels_cols = \
                    BigDLEstimatorWrapper._combine_cols(data, labels_cols, col_name="label",
                                                        val_data=val_data)

            self.nn_estimator.setBatchSize(batch_size).setMaxEpoch(epochs)\
                .setCachingSample(caching_sample).setFeaturesCol(feature_cols)\
                .setLabelCol(labels_cols)

            if val_data is not None:
                assert isinstance(
                    val_data,
                    DataFrame), "val_data should be a spark DataFrame."
                assert val_trigger is not None and val_methods is not None, \
                    "You should provide val_trigger and val_methods if you provide val_data."
                val_trigger = Trigger.convert_trigger(val_trigger)
                val_methods = Metrics.convert_metrics_list(val_methods)
                self.nn_estimator.setValidation(val_trigger, val_data,
                                                val_methods, batch_size)
            if self.log_dir is not None and self.app_name is not None:
                from bigdl.optim.optimizer import TrainSummary
                from bigdl.optim.optimizer import ValidationSummary
                train_summary = TrainSummary(log_dir=self.log_dir,
                                             app_name=self.app_name)
                self.nn_estimator.setTrainSummary(train_summary)
                # Use the application name (not the log directory) so train and
                # validation summaries are recorded under the same app.
                val_summary = ValidationSummary(log_dir=self.log_dir,
                                                app_name=self.app_name)
                self.nn_estimator.setValidationSummary(val_summary)
            if self.model_dir is not None and checkpoint_trigger is not None:
                checkpoint_trigger = Trigger.convert_trigger(
                    checkpoint_trigger)
                self.nn_estimator.setCheckpoint(self.model_dir,
                                                checkpoint_trigger)

            self.nn_model = self.nn_estimator.fit(data)
            self.is_nnframe_fit = True
        elif isinstance(data, SparkXShards):
            from zoo.orca.data.utils import to_sample

            end_trigger = MaxEpoch(epochs)
            val_methods = Metrics.convert_metrics_list(val_methods)
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

            train_rdd = data.rdd.flatMap(to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if val_data is None:
                val_feature_set = None
            else:
                assert isinstance(
                    val_data, SparkXShards), "val_data should be a XShards"
                val_feature_set = FeatureSet.sample_rdd(
                    val_data.rdd.flatMap(to_sample))
            if self.log_dir is not None and self.app_name is not None:
                self.estimator.set_tensorboard(self.log_dir, self.app_name)
            # NOTE(review): val_trigger is not forwarded on this path — confirm whether
            # estimator.train should receive it.
            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set,
                                 val_methods, batch_size)
            self.is_nnframe_fit = False
        else:
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)
        return self