Example #1
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            labels_cols=None,
            validation_data=None,
            validation_metrics=None,
            checkpoint_trigger=None):
        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics
        from zoo.orca.learn.trigger import Trigger

        end_trigger = MaxEpoch(epochs)
        assert batch_size > 0, "batch_size should be greater than 0"
        validation_metrics = Metrics.convert_metrics_list(validation_metrics)
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if self.log_dir is not None and self.app_name is not None:
            self.estimator.set_tensorboard(self.log_dir, self.app_name)

        if isinstance(data, SparkXShards):
            train_rdd = data.rdd.flatMap(to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, SparkXShards), "validation_data should be a " \
                                                                  "SparkXShards"
                val_feature_set = FeatureSet.sample_rdd(
                    validation_data.rdd.flatMap(to_sample))

            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set,
                                 validation_metrics, batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            train_feature_set = FeatureSet.pytorch_dataloader(data, "", "")
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, DataLoader) or callable(validation_data), \
                    "validation_data should be a PyTorch DataLoader or a callable data_creator"
                val_feature_set = FeatureSet.pytorch_dataloader(
                    validation_data)

            self.estimator.train_minibatch(train_feature_set, self.loss,
                                           end_trigger, checkpoint_trigger,
                                           val_feature_set, validation_metrics)
        else:
            raise ValueError(
                "Data and validation data should be SparkXShards, DataLoaders or "
                "callable data_creators but get " + data.__class__.__name__)
        return self
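
A hedged usage sketch for the DataLoader branch above: fit accepts either a ready-made PyTorch DataLoader or a callable that produces one. The estimator construction is not part of this snippet, so est below is an assumed, already-created Estimator; data shapes are illustrative.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative in-memory data; shapes and sizes are arbitrary.
features = torch.randn(256, 10)
labels = torch.randint(0, 2, (256,))
train_loader = DataLoader(TensorDataset(features, labels),
                          batch_size=32, shuffle=True)

# est is assumed to be an Estimator created elsewhere (construction is
# not shown in the snippet above):
# est.fit(data=train_loader, epochs=2, validation_data=train_loader)
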
Example #2
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            label_cols=None,
            validation_data=None,
            checkpoint_trigger=None):
        from zoo.orca.learn.trigger import Trigger

        end_trigger = MaxEpoch(epochs)
        assert batch_size > 0, "batch_size should be greater than 0"
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if self.log_dir is not None and self.app_name is not None:
            self.estimator.set_tensorboard(self.log_dir, self.app_name)

        if validation_data:
            assert self.metrics is not None, "You should provide metrics when creating this " \
                                             "estimator if you provide validation_data."

        if isinstance(data, SparkXShards):
            train_fset, val_fset = self._handle_xshards(data, validation_data)
            self.estimator.train(train_fset, self.loss, end_trigger,
                                 checkpoint_trigger, val_fset, self.metrics,
                                 batch_size)
        elif isinstance(data, DataFrame):
            train_fset, val_fset = self._handle_dataframe(
                data, validation_data, feature_cols, label_cols)
            self.estimator.train(train_fset, self.loss, end_trigger,
                                 checkpoint_trigger, val_fset, self.metrics,
                                 batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            train_fset, val_fset = self._handle_data_loader(
                data, validation_data)
            self.estimator.train_minibatch(train_fset, self.loss, end_trigger,
                                           checkpoint_trigger, val_fset,
                                           self.metrics)
        else:
            raise ValueError(
                "Data and validation data should be SparkXShards, DataLoaders or "
                "callable data_creators but get " + data.__class__.__name__)

        return self
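
The checkpoint_trigger argument threaded through these fit variants is normalized by Trigger.convert_trigger; the docstrings further below name EveryEpoch() and SeveralIteration(num_iterations) as typical values. A minimal sketch, assuming those classes are importable from zoo.orca.learn.trigger as the snippets' own imports and docstrings suggest:

from zoo.orca.learn.trigger import EveryEpoch, SeveralIteration

# Checkpoint at the end of every epoch:
# est.fit(data, epochs=3, checkpoint_trigger=EveryEpoch())

# Or checkpoint every 1000 training iterations:
# est.fit(data, epochs=3, checkpoint_trigger=SeveralIteration(1000))
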
Example #3
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            label_cols=None,
            validation_data=None,
            checkpoint_trigger=None):
        """
        Train this torch model with train data.

        :param data: train data. It can be an XShards, a Spark DataFrame, a PyTorch DataLoader
               or a PyTorch DataLoader creator function.
               If data is an XShards, each partition is a dictionary of {'x': feature,
               'y': label}, where feature(label) is a numpy array or a list of numpy arrays.
        :param epochs: Number of epochs to train the model. Default: 1.
        :param batch_size: Batch size used for training. Only used when data is an XShards.
               Default: 32.
        :param feature_cols: Feature column name(s) of data. Only used when data
               is a Spark DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is
               a Spark DataFrame. Default: None.
        :param validation_data: Validation data. XShards, PyTorch DataLoader and PyTorch
               DataLoader creator function are supported.
               If validation_data is an XShards, each partition is a dictionary of {'x': feature,
               'y': label}, where feature(label) is a numpy array or a list of numpy arrays.
        :param checkpoint_trigger: Orca Trigger to set a checkpoint.
        :return: The trained estimator object.
        """
        from zoo.orca.learn.trigger import Trigger

        end_trigger = MaxEpoch(epochs)
        assert batch_size > 0, "batch_size should be greater than 0"
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if self.log_dir is not None and self.app_name is not None:
            self.estimator.set_tensorboard(self.log_dir, self.app_name)

        if validation_data:
            assert self.metrics is not None, "You should provide metrics when creating this " \
                                             "estimator if you provide validation_data."

        if isinstance(data, SparkXShards):
            train_fset, val_fset = self._handle_xshards(data, validation_data)
            self.estimator.train(train_fset, self.loss, end_trigger,
                                 checkpoint_trigger, val_fset, self.metrics,
                                 batch_size)
        elif isinstance(data, DataFrame):
            train_fset, val_fset = self._handle_dataframe(
                data, validation_data, feature_cols, label_cols)
            self.estimator.train(train_fset, self.loss, end_trigger,
                                 checkpoint_trigger, val_fset, self.metrics,
                                 batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            train_fset, val_fset = self._handle_data_loader(
                data, validation_data)
            self.estimator.train_minibatch(train_fset, self.loss, end_trigger,
                                           checkpoint_trigger, val_fset,
                                           self.metrics)
        else:
            raise ValueError(
                "Data and validation data should be SparkXShards, DataLoaders or "
                "callable data_creators but get " + data.__class__.__name__)

        return self
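
For the XShards branch, each partition must carry the {'x': feature, 'y': label} dict of numpy data that the docstring above describes. A minimal sketch of building such shards from in-memory arrays; XShards.partition is assumed to be available from zoo.orca.data (it is not shown in this snippet):

import numpy as np
from zoo.orca.data import XShards  # assumption: partition() lives here

feature = np.random.randn(100, 4).astype(np.float32)
label = np.random.randint(0, 2, size=(100, 1))

# Each partition carries the {'x': ..., 'y': ...} dict the docstring requires.
shards = XShards.partition({"x": feature, "y": label})
# est.fit(shards, epochs=1, batch_size=32)  # est: hypothetical Estimator
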
Example #4
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            label_cols=None,
            validation_data=None,
            session_config=None,
            checkpoint_trigger=None,
            auto_shard_files=True):
        """
        Train this keras model with train data.

        :param data: train data. It can be XShards, Spark DataFrame, tf.data.Dataset.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
               If data is tf.data.Dataset, each element is [feature tensor tuple, label tensor
               tuple]
        :param epochs: number of epochs to train.
        :param batch_size: total batch size for each iteration.
        :param feature_cols: feature column names if train data is Spark DataFrame or XShards
               of Pandas DataFrame.
        :param label_cols: label column names if train data is Spark DataFrame or XShards of
               Pandas DataFrame.
        :param validation_data: validation data. Validation data type should be the same
               as train data.
        :param session_config: tensorflow session configuration for training.
               Should be object of tf.ConfigProto
        :param checkpoint_trigger: when to trigger checkpoint during training.
               Should be a zoo.orca.learn.trigger, like EveryEpoch(), SeveralIteration(
               num_iterations), etc.
        :param auto_shard_files: whether to automatically detect if the dataset is file-based
               and apply sharding on files, otherwise sharding on records. Default is True.
        """

        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in training"
            assert label_cols is not None, \
                "label columns is None; it should not be None in training"

        if isinstance(data, tf.data.Dataset):
            assert isinstance(data.element_spec, tuple), \
                "If data is tf.data.Dataset, each element should be " \
                "(feature tensors, label tensor), where each feature/label tensor can be " \
                "either a single tensor or a tuple of tensors"
            if validation_data is not None:
                assert isinstance(validation_data, tf.data.Dataset), \
                    "train data and validation data should be both tf.data.Dataset"
                assert isinstance(validation_data.element_spec, tuple), \
                    "If validation_data is tf.data.Dataset, each element should be " \
                    "(feature tensors, label tensor), where each feature/label tensor can be " \
                    "either a single tensor or a tuple of tensors"

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                assert feature_cols is not None, \
                    "feature columns is None; it should not be None in training"
                assert label_cols is not None, \
                    "label columns is None; it should not be None in training"
                data, validation_data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols, validation_data, "fit")

        if checkpoint_trigger is not None:
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if is_tf_data_dataset(data):
            data = data.map(_standardize_keras_target_data)
            if validation_data is not None:
                validation_data = validation_data.map(
                    _standardize_keras_target_data)

        memory_type = OrcaContext.train_data_store
        dataset = to_dataset(data,
                             batch_size=batch_size,
                             batch_per_thread=-1,
                             validation_data=validation_data,
                             feature_cols=feature_cols,
                             label_cols=label_cols,
                             hard_code_batch_size=False,
                             sequential_order=False,
                             shuffle=True,
                             auto_shard_files=auto_shard_files,
                             memory_type=memory_type)

        self.tf_optimizer = TFOptimizer.from_keras(
            self.model.model,
            dataset,
            model_dir=self.model.model_dir,
            session_config=session_config,
            metrics=self.metrics,
            optimizer=self.optimizer)

        if self.clip_norm:
            self.tf_optimizer.set_gradient_clipping_by_l2_norm(
                clip_norm=self.clip_norm)
        if self.clip_min and self.clip_max:
            self.tf_optimizer.set_constant_gradient_clipping(
                self.clip_min, self.clip_max)

        if self.load_checkpoint:
            self.tf_optimizer.load_checkpoint(self.checkpoint_path,
                                              self.checkpoint_version)

        if self.log_dir and self.app_name:
            self.tf_optimizer.estimator.set_tensorboard(
                self.log_dir, self.app_name)

        self.tf_optimizer.optimize(MaxEpoch(epochs),
                                   checkpoint_trigger=checkpoint_trigger)

        return self
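
The element_spec assertions above require each dataset element to be a (features, labels) tuple. A small sketch of a conforming tf.data.Dataset, with arbitrary shapes:

import numpy as np
import tensorflow as tf

x = np.random.rand(128, 8).astype("float32")
y = np.random.randint(0, 3, size=(128,)).astype("int32")

# from_tensor_slices over a tuple yields (feature, label) elements, so
# element_spec is a tuple, satisfying the assertion in fit().
ds = tf.data.Dataset.from_tensor_slices((x, y))
assert isinstance(ds.element_spec, tuple)
# est.fit(ds, epochs=1, batch_size=32)  # est: hypothetical keras Estimator
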
Example #5
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            label_cols=None,
            validation_data=None,
            session_config=None,
            checkpoint_trigger=None,
            auto_shard_files=False,
            feed_dict=None):
        """
        Train this graph model with train data.

        :param data: train data. It can be XShards, Spark DataFrame, tf.data.Dataset.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
               If data is tf.data.Dataset, each element is a tuple of input tensors.
        :param epochs: number of epochs to train.
        :param batch_size: total batch size for each iteration.
        :param feature_cols: feature column names if train data is Spark DataFrame or XShards
               of Pandas DataFrame.
        :param label_cols: label column names if train data is Spark DataFrame or XShards of
               Pandas DataFrame.
        :param validation_data: validation data. Validation data type should be the same
               as train data.
        :param auto_shard_files: whether to automatically detect if the dataset is file-based
               and apply sharding on files, otherwise sharding on records. Default is False.
        :param session_config: tensorflow session configuration for training.
               Should be object of tf.ConfigProto
        :param feed_dict: a dictionary. The key is TensorFlow tensor, usually a
               placeholder, the value of the dictionary is a tuple of two elements. The first one of
               the tuple is the value to feed to the tensor in training phase and the second one
               is the value to feed to the tensor in validation phase.
        :param checkpoint_trigger: when to trigger checkpoint during training.
               Should be a zoo.orca.learn.trigger, like EveryEpoch(), SeveralIteration(
               num_iterations), etc.
        """

        assert self.labels is not None, \
            "labels is None; it should not be None in training"
        assert self.loss is not None, \
            "loss is None; it should not be None in training"
        assert self.optimizer is not None, \
            "optimizer is None; it should not be None in training"

        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in training"
            assert label_cols is not None, \
                "label columns is None; it should not be None in training"

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                assert feature_cols is not None, \
                    "feature columns is None; it should not be None in training"
                assert label_cols is not None, \
                    "label columns is None; it should not be None in training"
                data, validation_data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols, validation_data, "fit")

        if checkpoint_trigger is not None:
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        memory_type = OrcaContext.train_data_store
        dataset = to_dataset(data,
                             batch_size=batch_size,
                             batch_per_thread=-1,
                             validation_data=validation_data,
                             feature_cols=feature_cols,
                             label_cols=label_cols,
                             hard_code_batch_size=False,
                             sequential_order=False,
                             shuffle=True,
                             auto_shard_files=auto_shard_files,
                             memory_type=memory_type)

        if feed_dict is not None:
            tensor_with_value = {
                key: (value[0], value[1])
                for key, value in feed_dict.items()
            }
        else:
            tensor_with_value = None

        if self.use_bigdl_optim:
            self.tf_optimizer = TFOptimizer.from_loss(
                self.loss,
                self.optimizer,
                session=self.sess,
                inputs=(self.inputs, self.labels),
                dataset=dataset,
                clip_norm=self.clip_norm,
                clip_value=self.clip_value,
                metrics=self.metrics,
                tensor_with_value=tensor_with_value,
                session_config=session_config,
                model_dir=self.model_dir,
                updates=self.updates)
        else:
            self.tf_optimizer = TFOptimizer.from_train_op(
                train_op=self.train_op,
                loss=self.loss,
                inputs=self.inputs,
                labels=self.labels,
                dataset=dataset,
                metrics=self.metrics,
                updates=self.updates,
                sess=self.sess,
                tensor_with_value=tensor_with_value,
                session_config=session_config,
                model_dir=self.model_dir)

        if self.load_checkpoint:
            self.tf_optimizer.load_checkpoint(self.checkpoint_path,
                                              self.checkpoint_version)

        if self.log_dir and self.app_name:
            self.tf_optimizer.estimator.set_tensorboard(
                self.log_dir, self.app_name)

        self.tf_optimizer.optimize(end_trigger=MaxEpoch(epochs),
                                   checkpoint_trigger=checkpoint_trigger)
        return self
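
The feed_dict convention above maps a tensor to a (training value, validation value) pair. A sketch under the same TensorFlow 1.x graph assumptions as this snippet; the keep-probability placeholder is purely illustrative:

import tensorflow as tf

# Hypothetical placeholder that should take different values in the
# training and validation phases, e.g. a dropout keep probability.
keep_prob = tf.placeholder(tf.float32, shape=(), name="keep_prob")

# First element feeds training, second feeds validation, per the docstring.
feed_dict = {keep_prob: (0.5, 1.0)}
# est.fit(data, epochs=1, feed_dict=feed_dict)  # est: hypothetical graph Estimator
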
Example #6
    def fit(self,
            data,
            epochs,
            batch_size=32,
            feature_cols="features",
            label_cols="label",
            caching_sample=True,
            validation_data=None,
            validation_trigger=None,
            checkpoint_trigger=None):
        """
        Train this BigDL model with train data.

        :param data: train data. It can be XShards or Spark DataFrame.
        If data is XShards, each partition is a dictionary of {'x': feature,
        'y': label}, where feature(label) is a numpy array or a list of numpy arrays.
        :param epochs: Number of epochs to train the model.
        :param batch_size: Batch size used for training. Default: 32.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
        DataFrame. Default: "features".
        :param label_cols: Label column name(s) of data. Only used when data is a Spark DataFrame.
        Default: "label".
        :param caching_sample: whether to cache the Samples after preprocessing. Default: True
        :param validation_data: Validation data. XShards and Spark DataFrame are supported.
        If validation_data is an XShards, each partition is a dictionary of {'x': feature,
        'y': label}, where feature(label) is a numpy array or a list of numpy arrays.
        :param validation_trigger: Orca Trigger to trigger validation computation.
        :param checkpoint_trigger: Orca Trigger to set a checkpoint.
        :return: The trained estimator object.
        """
        from zoo.orca.learn.trigger import Trigger

        assert batch_size > 0, "batch_size should be greater than 0"

        if validation_data is not None:
            assert self.metrics is not None, \
                "You should provide metrics when creating this estimator if you provide " \
                "validation_data."

        if isinstance(data, DataFrame):
            if isinstance(feature_cols, list):
                data, validation_data, feature_cols = \
                    BigDLEstimator._combine_cols(data, feature_cols, col_name="features",
                                                 val_data=validation_data)

            if isinstance(label_cols, list):
                data, validation_data, label_cols = \
                    BigDLEstimator._combine_cols(data, label_cols, col_name="label",
                                                 val_data=validation_data)

            self.nn_estimator.setBatchSize(batch_size).setMaxEpoch(epochs)\
                .setCachingSample(caching_sample).setFeaturesCol(feature_cols)\
                .setLabelCol(label_cols)

            if validation_data is not None:
                assert isinstance(validation_data, DataFrame), \
                    "validation_data should be a spark DataFrame."
                assert validation_trigger is not None, \
                    "You should provide validation_trigger if you provide validation_data."
                validation_trigger = Trigger.convert_trigger(
                    validation_trigger)
                self.nn_estimator.setValidation(validation_trigger,
                                                validation_data, self.metrics,
                                                batch_size)
            if self.log_dir is not None and self.app_name is not None:
                from bigdl.optim.optimizer import TrainSummary
                from bigdl.optim.optimizer import ValidationSummary
                train_summary = TrainSummary(log_dir=self.log_dir,
                                             app_name=self.app_name)
                self.nn_estimator.setTrainSummary(train_summary)
                val_summary = ValidationSummary(log_dir=self.log_dir,
                                                app_name=self.app_name)
                self.nn_estimator.setValidationSummary(val_summary)
            if self.model_dir is not None and checkpoint_trigger is not None:
                checkpoint_trigger = Trigger.convert_trigger(
                    checkpoint_trigger)
                self.nn_estimator.setCheckpoint(self.model_dir,
                                                checkpoint_trigger)

            self.nn_model = self.nn_estimator.fit(data)
            self.is_nnframe_fit = True
        elif isinstance(data, SparkXShards):
            from zoo.orca.data.utils import xshard_to_sample

            end_trigger = MaxEpoch(epochs)
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

            train_rdd = data.rdd.flatMap(xshard_to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, SparkXShards), \
                    "validation_data should be an XShards"
                val_feature_set = FeatureSet.sample_rdd(
                    validation_data.rdd.flatMap(xshard_to_sample))
            if self.log_dir is not None and self.app_name is not None:
                self.estimator.set_tensorboard(self.log_dir, self.app_name)
            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set,
                                 self.metrics, batch_size)
            self.is_nnframe_fit = False
        else:
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)
        return self
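
When feature_cols (or label_cols) is a list, the DataFrame branch above first merges those columns into a single "features" ("label") column via _combine_cols before configuring the NNEstimator. A hedged sketch of calling this path; the Spark session, column names and estimator est are illustrative:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [(1.0, 2.0, 0.0), (3.0, 4.0, 1.0)],
    ["f1", "f2", "label"])

# Passing a list of feature columns triggers the _combine_cols merge above:
# est.fit(df, epochs=5, batch_size=4,
#         feature_cols=["f1", "f2"], label_cols="label")
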
Example #7
    def fit(self,
            data,
            epochs,
            feature_cols="features",
            labels_cols="label",
            batch_size=32,
            caching_sample=True,
            val_data=None,
            val_trigger=None,
            val_methods=None,
            checkpoint_trigger=None):
        from zoo.orca.learn.metrics import Metrics
        from zoo.orca.learn.trigger import Trigger

        assert batch_size > 0, "batch_size should be greater than 0"

        if isinstance(data, DataFrame):
            if isinstance(feature_cols, list):
                data, val_data, feature_cols = \
                    BigDLEstimatorWrapper._combine_cols(data, feature_cols, col_name="features",
                                                        val_data=val_data)

            if isinstance(labels_cols, list):
                data, val_data, labels_cols = \
                    BigDLEstimatorWrapper._combine_cols(data, labels_cols, col_name="label",
                                                        val_data=val_data)

            self.nn_estimator.setBatchSize(batch_size).setMaxEpoch(epochs)\
                .setCachingSample(caching_sample).setFeaturesCol(feature_cols)\
                .setLabelCol(labels_cols)

            if val_data is not None:
                assert isinstance(
                    val_data,
                    DataFrame), "val_data should be a spark DataFrame."
                assert val_trigger is not None and val_methods is not None, \
                    "You should provide val_trigger and val_methods if you provide val_data."
                val_trigger = Trigger.convert_trigger(val_trigger)
                val_methods = Metrics.convert_metrics_list(val_methods)
                self.nn_estimator.setValidation(val_trigger, val_data,
                                                val_methods, batch_size)
            if self.log_dir is not None and self.app_name is not None:
                from bigdl.optim.optimizer import TrainSummary
                from bigdl.optim.optimizer import ValidationSummary
                train_summary = TrainSummary(log_dir=self.log_dir,
                                             app_name=self.app_name)
                self.nn_estimator.setTrainSummary(train_summary)
                val_summary = ValidationSummary(log_dir=self.log_dir,
                                                app_name=self.app_name)
                self.nn_estimator.setValidationSummary(val_summary)
            if self.model_dir is not None and checkpoint_trigger is not None:
                checkpoint_trigger = Trigger.convert_trigger(
                    checkpoint_trigger)
                self.nn_estimator.setCheckpoint(self.model_dir,
                                                checkpoint_trigger)

            self.nn_model = self.nn_estimator.fit(data)
            self.is_nnframe_fit = True
        elif isinstance(data, SparkXShards):
            from zoo.orca.data.utils import to_sample

            end_trigger = MaxEpoch(epochs)
            val_methods = Metrics.convert_metrics_list(val_methods)
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

            train_rdd = data.rdd.flatMap(to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if val_data is None:
                val_feature_set = None
            else:
                assert isinstance(val_data, SparkXShards), \
                    "val_data should be an XShards"
                val_feature_set = FeatureSet.sample_rdd(
                    val_data.rdd.flatMap(to_sample))
            if self.log_dir is not None and self.app_name is not None:
                self.estimator.set_tensorboard(self.log_dir, self.app_name)
            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set,
                                 val_methods, batch_size)
            self.is_nnframe_fit = False
        else:
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)
        return self
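
When log_dir and app_name are set, both DataFrame variants above attach BigDL TrainSummary/ValidationSummary objects for TensorBoard. A sketch of reading the training loss back after fit, using the same import as the code above; the paths are illustrative, and read_scalar is an assumption about BigDL's summary API:

from bigdl.optim.optimizer import TrainSummary

# Same log_dir/app_name that the estimator was configured with:
train_summary = TrainSummary(log_dir="/tmp/logs", app_name="my_app")
# After fit() has run, scalar series such as "Loss" can be read back,
# e.g. for plotting (read_scalar is assumed here):
# loss_records = train_summary.read_scalar("Loss")
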
Example #8
    def fit(self,
            data,
            epochs=1,
            batch_size=None,
            feature_cols=None,
            label_cols=None,
            validation_data=None,
            checkpoint_trigger=None):
        """
        Train this torch model with train data.

        :param data: train data. It can be an XShards, a Spark DataFrame, a PyTorch DataLoader
               or a PyTorch DataLoader creator function that takes config and batch_size as
               arguments and returns a PyTorch DataLoader for training.
               If data is an XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or
               a list of numpy arrays.
        :param epochs: Number of epochs to train the model. Default: 1.
        :param batch_size: Batch size used for training. Should be None when data is a PyTorch
               DataLoader, in which case the DataLoader's own batch size is used; otherwise it
               must be a positive integer. Default: None.
        :param feature_cols: Feature column name(s) of data. Only used when data
               is a Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is
               a Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param validation_data: Validation data. XShards, PyTorch DataLoader and PyTorch
               DataLoader creator function are supported.
               If validation_data is an XShards, each partition can be a Pandas DataFrame or a
               dictionary of {'x': feature, 'y': label}, where feature(label) is a numpy array
               or a list of numpy arrays.
        :param checkpoint_trigger: Orca Trigger to set a checkpoint.
        :return: The trained estimator object.
        """
        from zoo.orca.learn.trigger import Trigger

        end_trigger = MaxEpoch(epochs)
        if isinstance(data, DataLoader):
            assert batch_size is None and data.batch_size > 0, "When using a PyTorch DataLoader as " \
                                                               "input, you need to specify the " \
                                                               "batch size in DataLoader and " \
                                                               "don't specify batch_size " \
                                                               "in the fit method."
        else:
            assert batch_size is not None and batch_size > 0, "batch_size should be greater than 0"
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if self.log_dir is not None and self.app_name is not None:
            self.estimator.set_tensorboard(self.log_dir, self.app_name)

        if validation_data:
            assert self.metrics is not None, "You should provide metrics when creating this " \
                                             "estimator if you provide validation_data."

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data, validation_data = process_xshards_of_pandas_dataframe(
                    data,
                    feature_cols,
                    label_cols,
                    validation_data,
                    mode="fit")
            train_fset, val_fset = self._handle_xshards(data, validation_data)
            self.estimator.train(train_fset, self.loss, end_trigger,
                                 checkpoint_trigger, val_fset, self.metrics,
                                 batch_size)
        elif isinstance(data, DataFrame):
            train_fset, val_fset = self._handle_dataframe(
                data, validation_data, feature_cols, label_cols)
            self.estimator.train(train_fset, self.loss, end_trigger,
                                 checkpoint_trigger, val_fset, self.metrics,
                                 batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            if isinstance(data, types.FunctionType):
                data = data(self.config, batch_size)
                if validation_data is not None:
                    validation_data = validation_data(
                        self.config, batch_size)
            train_fset, val_fset = self._handle_data_loader(
                data, validation_data)
            self.estimator.train_minibatch(train_fset, self.loss, end_trigger,
                                           checkpoint_trigger, val_fset,
                                           self.metrics)
        else:
            raise ValueError(
                "Data and validation data should be SparkXShards, DataLoaders or "
                "callable data_creators but get " + data.__class__.__name__)

        return self
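
Example #8 additionally accepts a creator function that receives config and batch_size and returns a DataLoader, as its docstring states. A minimal sketch of such a creator; the data is synthetic and est is again an assumed, already-built Estimator:

import torch
from torch.utils.data import DataLoader, TensorDataset

def train_loader_creator(config, batch_size):
    # The creator receives the estimator's config dict and the batch_size
    # passed to fit(), and must return a DataLoader.
    x = torch.randn(512, 16)
    y = torch.randint(0, 2, (512,))
    return DataLoader(TensorDataset(x, y),
                      batch_size=batch_size, shuffle=True)

# est.fit(data=train_loader_creator, epochs=2, batch_size=32)
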
Example #9
    def fit(
        self,
        data,
        epochs=1,
        batch_size=32,
        feature_cols=None,
        labels_cols=None,
        validation_data=None,
        hard_code_batch_size=False,
        session_config=None,
        checkpoint_trigger=None,
        auto_shard_files=True,
    ):
        """
        Train this keras model with train data.

        :param data: train data. It can be XShards, Spark DataFrame, tf.data.Dataset.
               If data is XShards, each element needs to be {'x': a feature numpy array
               or a tuple of feature numpy arrays, 'y': a label numpy array or a tuple of
               label numpy arrays}
               If data is tf.data.Dataset, each element is [feature tensor tuple, label
               tensor tuple]
        :param epochs: number of epochs to train.
        :param batch_size: total batch size for each iteration.
        :param feature_cols: feature column names if train data is Spark DataFrame.
        :param labels_cols: label column names if train data is Spark DataFrame.
        :param validation_data: validation data. Validation data type should be the same
               as train data.
        :param hard_code_batch_size: whether to hard code the batch size for training.
               Default is False.
        :param session_config: tensorflow session configuration for training.
               Should be object of tf.ConfigProto
        :param checkpoint_trigger: when to trigger checkpoint during training.
               Should be a zoo.orca.learn.trigger, like EveryEpoch(), SeveralIteration(
               num_iterations), etc.
        :param auto_shard_files: whether to automatically detect if the dataset is file-based
               and apply sharding on files, otherwise sharding on records. Default is True.
        """

        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in training"
            assert labels_cols is not None, \
                "label columns is None; it should not be None in training"

        if isinstance(data, tf.data.Dataset):
            assert isinstance(data.element_spec, tuple), \
                "If data is tf.data.Dataset, each element should be " \
                "(feature tensors, label tensor), where each feature/label tensor can be " \
                "either a single tensor or a tuple of tensors"
            if validation_data is not None:
                assert isinstance(validation_data, tf.data.Dataset), \
                    "train data and validation data should be both tf.data.Dataset"
                assert isinstance(validation_data.element_spec, tuple), \
                    "If validation_data is tf.data.Dataset, each element should be " \
                    "(feature tensors, label tensor), where each feature/label tensor can be " \
                    "either a single tensor or a tuple of tensors"

        if checkpoint_trigger is not None:
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if is_tf_data_dataset(data):
            data = data.map(_standardize_keras_target_data)
            if validation_data is not None:
                validation_data = validation_data.map(
                    _standardize_keras_target_data)

        dataset = to_dataset(data,
                             batch_size=batch_size,
                             batch_per_thread=-1,
                             validation_data=validation_data,
                             feature_cols=feature_cols,
                             labels_cols=labels_cols,
                             hard_code_batch_size=hard_code_batch_size,
                             sequential_order=False,
                             shuffle=True,
                             auto_shard_files=auto_shard_files)

        if isinstance(dataset, TFNdarrayDataset):
            dataset = _standarize_feature_label_dataset(
                dataset, self.model.model)

        self.tf_optimizer = TFOptimizer.from_keras(
            self.model.model,
            dataset,
            model_dir=self.model.model_dir,
            session_config=session_config,
            metrics=self.metrics,
            optimizer=self.optimizer)

        if self.clip_norm:
            self.tf_optimizer.set_gradient_clipping_by_l2_norm(
                clip_norm=self.clip_norm)
        if self.clip_min and self.clip_max:
            self.tf_optimizer.set_constant_gradient_clipping(
                self.clip_min, self.clip_max)

        if self.load_checkpoint:
            self.tf_optimizer.load_checkpoint(self.checkpoint_path,
                                              self.checkpoint_version)

        if self.log_dir and self.app_name:
            self.tf_optimizer.estimator.set_tensorboard(
                self.log_dir, self.app_name)

        self.tf_optimizer.optimize(MaxEpoch(epochs),
                                   checkpoint_trigger=checkpoint_trigger)

        return self
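
The docstring above allows the feature side of each tf.data element to itself be a tuple of tensors, which is how multi-input keras models are fed. A sketch of such a dataset; shapes are arbitrary:

import numpy as np
import tensorflow as tf

x1 = np.random.rand(64, 8).astype("float32")
x2 = np.random.rand(64, 3).astype("float32")
y = np.random.rand(64, 1).astype("float32")

# Element structure: ((feature_1, feature_2), label), i.e. a tuple
# element_spec with a nested feature tuple, which the assertions in
# fit() accept.
ds = tf.data.Dataset.from_tensor_slices(((x1, x2), y))
assert isinstance(ds.element_spec, tuple)
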