コード例 #1
0
ファイル: test_torch.py プロジェクト: EmiCareOfCell44/BigDL
    def test_train_model_with_bn(self):
        """Train a tiny PyTorch model that contains BatchNorm1d via the Estimator.

        Six 2-feature samples with binary targets are served in batches of two,
        both as training and as validation data. Training runs for four epochs
        with an EveryEpoch checkpoint trigger, after which `to_pytorch()` is
        called on the wrapped model.
        """
        class SimpleTorchModel(nn.Module):
            """Linear(2, 4) -> BatchNorm1d(4) -> Linear(4, 1) -> sigmoid."""

            def __init__(self):
                super(SimpleTorchModel, self).__init__()
                self.dense1 = nn.Linear(2, 4)
                self.bn1 = torch.nn.BatchNorm1d(4)
                self.dense2 = nn.Linear(4, 1)

            def forward(self, x):
                normalized = self.bn1(self.dense1(x))
                return torch.sigmoid(self.dense2(normalized))

        # Wrap the PyTorch model and loss so the Estimator can drive them.
        az_model = TorchModel.from_pytorch(SimpleTorchModel())
        zoo_loss = TorchLoss.from_pytorch(torch.nn.BCELoss())

        inputs = torch.Tensor([[1, 2], [1, 3], [3, 2], [5, 6], [8, 9], [1, 9]])
        targets = torch.Tensor([[0], [0], [0], [1], [1], [1]])

        def build_featureset():
            # Each call creates its own DataLoader over the same tensors.
            loader = DataLoader(TensorDataset(inputs, targets), batch_size=2)
            return FeatureSet.pytorch_dataloader(loader)

        train_featureset = build_featureset()
        val_featureset = build_featureset()

        estimator = Estimator(az_model, optim_methods=Adam())
        estimator.train_minibatch(
            train_featureset,
            zoo_loss,
            end_trigger=MaxEpoch(4),
            checkpoint_trigger=EveryEpoch(),
            validation_set=val_featureset,
            validation_method=[Accuracy()])

        # Converting back must succeed; the result itself is not inspected.
        converted = az_model.to_pytorch()
コード例 #2
0
    def test_control_inputs(self):
        """Train one epoch with a boolean placeholder supplied via `tensor_with_value`.

        The `is_training` placeholder controls the dropout layer and is mapped
        to the tuple (True, False) when the optimizer is created.
        """
        x_data = np.random.randn(20, 10)
        y_data = np.random.randint(0, 10, size=[20])

        with tf.Graph().as_default():
            dataset = TFDataset.from_ndarrays(
                (x_data, y_data),
                batch_size=4,
                val_tensors=(x_data, y_data))
            is_training = tf.placeholder(dtype=tf.bool, shape=())
            feature_tensor, label_tensor = dataset.tensors

            # dense(8) -> dropout (gated by is_training) -> dense(10) logits
            hidden = tf.layers.dense(feature_tensor, 8)
            hidden = tf.layers.dropout(hidden, training=is_training)
            output = tf.layers.dense(hidden, 10)

            cross_entropy = tf.losses.sparse_softmax_cross_entropy(
                logits=output, labels=label_tensor)
            loss = tf.reduce_mean(cross_entropy)

            optimizer = TFOptimizer.from_loss(
                loss,
                Adam(),
                val_outputs=[output],
                val_labels=[label_tensor],
                val_method=Accuracy(),
                tensor_with_value={is_training: (True, False)},
                metrics={"loss": loss})
            optimizer.optimize(end_trigger=MaxEpoch(1))
            optimizer.sess.close()
コード例 #3
0
    def _fit_distributed(self, dataset, epochs, **kwargs):
        """Create a TFOptimizer from the Keras model and run it for `epochs` epochs.

        :param dataset: dataset passed to TFOptimizer.from_keras.
        :param epochs: number of epochs, used to build the MaxEpoch end trigger.
        :param kwargs: additional keyword arguments forwarded to
               TFOptimizer.from_keras.
        """
        tf_optimizer = TFOptimizer.from_keras(
            self.model,
            dataset,
            model_dir=self.model_dir,
            metrics=self.metric_tensors,
            optimizer=self.optimizer,
            **kwargs)
        # Keep the optimizer on the instance before training starts.
        self.tf_optimizer = tf_optimizer

        stop_after = MaxEpoch(epochs)
        self.tf_optimizer.optimize(stop_after)
コード例 #4
0
    def test_tf_optimizer_with_sparse_gradient(self):
        """Train one epoch on a loss computed from an embedding lookup.

        Random ids and labels are zipped into an RDD of [id, label] pairs, fed
        through TFDataset.from_rdd, and the looked-up embedding rows are used
        directly as logits for a sparse softmax cross-entropy loss.
        """
        ids = np.random.randint(0, 10, size=[40])
        labels = np.random.randint(0, 5, size=[40])

        paired = self.sc.parallelize(ids).zip(self.sc.parallelize(labels))
        training_rdd = paired.map(lambda x: [x[0], x[1]])

        with tf.Graph().as_default():
            dataset = TFDataset.from_rdd(
                training_rdd,
                names=["ids", "labels"],
                shapes=[[], []],
                types=[tf.int32, tf.int32],
                batch_size=8)
            id_tensor, label_tensor = dataset.tensors

            embedding_table = tf.get_variable(name="word_embedding",
                                              shape=[10, 5])
            embedding = tf.nn.embedding_lookup(embedding_table, id_tensor)

            cross_entropy = tf.losses.sparse_softmax_cross_entropy(
                logits=embedding, labels=label_tensor)
            loss = tf.reduce_mean(cross_entropy)

            optimizer = TFOptimizer.from_loss(loss, Adam(1e-3))
            optimizer.optimize(end_trigger=MaxEpoch(1))
            optimizer.sess.close()
コード例 #5
0
ファイル: tf_optimizer.py プロジェクト: EmiCareOfCell44/BigDL
    def optimize(self, end_trigger=None, checkpoint_trigger=None):
        """
        Run the training loop of this optimizer.

        :param end_trigger: BigDL's Trigger to indicate when to stop the training.
               MaxEpoch(1) is used when None.
        :param checkpoint_trigger: When to save a checkpoint and evaluate model.
               EveryEpoch() is used when None.
        :raises Exception: if the training FeatureSet has a slice count other than 1
                and the checkpoint trigger is neither EveryEpoch nor a ZooTrigger.
        """
        end_trigger = MaxEpoch(1) if end_trigger is None else end_trigger
        if checkpoint_trigger is None:
            checkpoint_trigger = EveryEpoch()

        # When the FeatureSet reports a slice count other than 1, a plain
        # EveryEpoch is swapped for ZEveryEpoch and any other non-Zoo trigger
        # is rejected.
        sliced_feature_set = (
            isinstance(self.train_data, FeatureSet)
            and self.train_data.value.getNumOfSlice() != 1)
        if sliced_feature_set:
            if isinstance(checkpoint_trigger, EveryEpoch):
                checkpoint_trigger = ZEveryEpoch()
            elif not isinstance(checkpoint_trigger, ZooTrigger):
                raise Exception(
                    "Please use a trigger defined in bigdl.dllib.utils.triggers"
                )

        # Validation arguments are only passed when both validation methods
        # and validation data are available.
        with_validation = bool(self.tf_model.val_methods) and self.val_data is not None
        train_args = {
            "train_set": self.train_data,
            "criterion": self.tf_model.criterion,
            "end_trigger": end_trigger,
            "checkpoint_trigger": checkpoint_trigger,
        }
        if with_validation:
            train_args["validation_set"] = self.val_data
            train_args["validation_method"] = self.tf_model.val_methods
        self.estimator.train_minibatch(**train_args)

        self.tf_model.training_helper_layer.get_weights_to_python()
コード例 #6
0
    def test_tf_optimizer_metrics(self):
        """Check that per-scope optimizers update only the variables they own.

        Variables under "dense/" are trained with Adam(1e-3) and must change;
        variables under "dense_1/" use SGD with a zero learning rate and must
        stay the same after one epoch.
        """
        x_data = np.random.randn(20, 10)
        y_data = np.random.randint(0, 10, size=[20])

        with tf.Graph().as_default():
            dataset = TFDataset.from_ndarrays(
                (x_data, y_data),
                batch_size=4,
                val_tensors=(x_data, y_data))
            feature_tensor, label_tensor = dataset.tensors
            hidden = tf.layers.dense(feature_tensor, 8)
            output = tf.layers.dense(hidden, 10)
            cross_entropy = tf.losses.sparse_softmax_cross_entropy(
                logits=output, labels=label_tensor)
            loss = tf.reduce_mean(cross_entropy)

            optim_by_scope = {
                "dense/": Adam(1e-3),
                "dense_1/": SGD(0.0),
            }
            optimizer = TFOptimizer.from_loss(
                loss,
                optim_by_scope,
                val_outputs=[output],
                val_labels=[label_tensor],
                val_method=Accuracy(),
                metrics={"loss": loss})

            helper_layer = optimizer.tf_model.training_helper_layer
            initial_weights = helper_layer.get_weights()
            optimizer.optimize(end_trigger=MaxEpoch(1))
            updated_weights = optimizer.tf_model.training_helper_layer.get_weights()

            # weights and bias combined with "dense/" should be updated
            for idx in (0, 1):
                assert not np.allclose(initial_weights[idx], updated_weights[idx])
            # weights and bias combined with "dense_1" should be unchanged
            for idx in (2, 3):
                assert np.allclose(initial_weights[idx], updated_weights[idx])
            optimizer.sess.close()
コード例 #7
0
    def fit(self,
            data,
            epochs=1,
            batch_size=None,
            feature_cols=None,
            label_cols=None,
            validation_data=None,
            checkpoint_trigger=None):
        """
        Train this torch model with train data.

        :param data: train data. It can be a XShards, Spark Dataframe, PyTorch DataLoader and
               PyTorch DataLoader creator function that takes config and batch_size as argument and
               returns a PyTorch DataLoader for training.
               If data is an XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or
               a list of numpy arrays.
        :param epochs: Number of epochs to train the model. Default: 1.
        :param batch_size: Batch size used for training. Must be left as None when data is a
               PyTorch DataLoader (the DataLoader's own batch size is used) and must be a
               positive integer for every other kind of data. Default: None.
        :param feature_cols: Feature column name(s) of data. Only used when data
               is a Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is
               a Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param validation_data: Validation data. XShards, PyTorch DataLoader and PyTorch DataLoader
               creator function are supported.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a list of
               numpy arrays. Default: None (no validation).
        :param checkpoint_trigger: Orca Trigger to set a checkpoint.
        :return: The trained estimator object.
        :raises ValueError: if data is not one of the supported types.
        """
        from bigdl.orca.learn.trigger import Trigger

        end_trigger = MaxEpoch(epochs)
        if isinstance(data, DataLoader):
            assert batch_size is None and data.batch_size > 0, "When using PyTorch Dataloader as " \
                                                               "input, you need to specify the " \
                                                               "batch size in DataLoader and " \
                                                               "don't specify batch_size " \
                                                               "in the fit method."
        else:
            assert batch_size is not None and batch_size > 0, "batch_size should be greater than 0"
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if self.log_dir is not None and self.app_name is not None:
            self.estimator.set_tensorboard(self.log_dir, self.app_name)

        # Compare against None explicitly: truthiness would invoke __len__ on
        # DataLoader / XShards objects, which can be expensive or raise.
        if validation_data is not None:
            assert self.metrics is not None, "You should provide metrics when creating this " \
                                             "estimator if you provide validation_data."

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data, validation_data = process_xshards_of_pandas_dataframe(
                    data,
                    feature_cols,
                    label_cols,
                    validation_data,
                    mode="fit")
            train_fset, val_fset = self._handle_xshards(data, validation_data)
            self.estimator.train(train_fset, self.loss, end_trigger,
                                 checkpoint_trigger, val_fset, self.metrics,
                                 batch_size)
        elif isinstance(data, DataFrame):
            train_fset, val_fset = self._handle_dataframe(
                data, validation_data, feature_cols, label_cols)
            self.estimator.train(train_fset, self.loss, end_trigger,
                                 checkpoint_trigger, val_fset, self.metrics,
                                 batch_size)
        elif isinstance(data, DataLoader) or callable(data) or isinstance(
                data, types.FunctionType):
            if isinstance(data, types.FunctionType):
                # Materialize the DataLoaders from their creator functions.
                data = data(self.config, batch_size)
                # validation_data is optional; calling None would raise TypeError.
                if validation_data is not None:
                    validation_data = validation_data(self.config, batch_size)
            train_fset, val_fset = self._handle_data_loader(
                data, validation_data)
            self.estimator.train_minibatch(train_fset, self.loss, end_trigger,
                                           checkpoint_trigger, val_fset,
                                           self.metrics)
        else:
            raise ValueError(
                "Data and validation data should be SparkXShards, DataLoaders or "
                "callable data_creators but get " + data.__class__.__name__)

        return self
コード例 #8
0
    def test_checkpoint(self):
        """Verify that a TFOptimizer checkpoint can be saved, found and reloaded.

        Steps:
          1. Train one epoch with `model_dir` set and record the first
             trainable variable.
          2. Walk `model_dir` for files named `optimMethod-TFParkTraining.<N>`
             and collect the versions found in the first directory that has any.
          3. Re-initialize the variables, build a second optimizer on the same
             session and load the highest version; weights must match step 1.
          4. Optimizing again with MaxEpoch(1) must not change the weights,
             while MaxEpoch(2) must.

        The temporary model directory is always removed.
        """
        import re
        import shutil

        features = np.random.randn(20, 10)
        labels = np.random.randint(0, 10, size=[20])
        with tf.Graph().as_default():
            dataset = TFDataset.from_ndarrays((features, labels),
                                              batch_size=4,
                                              val_tensors=(features, labels))
            feature_tensor, label_tensor = dataset.tensors
            features = tf.layers.dense(feature_tensor, 8)
            output = tf.layers.dense(features, 10)
            loss = tf.reduce_mean(
                tf.losses.sparse_softmax_cross_entropy(logits=output,
                                                       labels=label_tensor))
            model_dir = tempfile.mkdtemp()
            try:
                optimizer = TFOptimizer.from_loss(loss,
                                                  Adam(),
                                                  val_outputs=[output],
                                                  val_labels=[label_tensor],
                                                  val_method=Accuracy(),
                                                  metrics={"loss": loss},
                                                  model_dir=model_dir)
                optimizer.optimize(end_trigger=MaxEpoch(1))

                first_weights = optimizer.sess.run(tf.trainable_variables()[0])

                # Raw string: "\." in a normal string literal is an invalid
                # escape sequence and triggers a warning on modern Python.
                ckpt_pattern = r"^optimMethod-TFParkTraining\.[0-9]+$"
                ckpt_path = None
                versions = []
                for (root, dirs, files) in os.walk(model_dir, topdown=True):
                    temp_versions = []
                    for file_name in files:
                        if re.match(ckpt_pattern, file_name) is not None:
                            version = int(file_name.split(".")[1])
                            temp_versions.append(version)
                    if temp_versions:
                        # Stop at the first directory containing checkpoints.
                        ckpt_path = root
                        versions = temp_versions
                        break

                assert ckpt_path is not None, "Cannot find checkpoint file"
                optimizer.sess.run(
                    tf.global_variables_initializer())  # reset variable
                optimizer_load = TFOptimizer.from_loss(
                    loss,
                    Adam(),
                    session=optimizer.sess,
                    val_outputs=[output],
                    val_labels=[label_tensor],
                    val_method=Accuracy(),
                    metrics={"loss": loss},
                    model_dir=model_dir)
                optimizer_load.load_checkpoint(ckpt_path, max(versions))
                loaded_first_weights_before_train = optimizer.sess.run(
                    tf.trainable_variables()[0])
                assert np.allclose(first_weights,
                                   loaded_first_weights_before_train)
                # max epoch still 1, should not train
                optimizer_load.optimize(end_trigger=MaxEpoch(1))
                loaded_first_weights = optimizer.sess.run(
                    tf.trainable_variables()[0])
                assert np.allclose(first_weights, loaded_first_weights)

                # max epoch increase 1, should train 1 epoch
                optimizer_load.optimize(end_trigger=MaxEpoch(2))
                loaded_first_weights_2 = optimizer.sess.run(
                    tf.trainable_variables()[0])
                assert not np.allclose(first_weights, loaded_first_weights_2)
                optimizer_load.sess.close()
            finally:
                shutil.rmtree(model_dir)
コード例 #9
0
def main():
    """Train the MNIST `Net` model through the BigDL TorchModel Estimator.

    Parses the command-line options, builds MNIST train/test DataLoaders,
    initializes the Orca context (local, or the requested cluster mode with
    2 cores, 2g memory and 4 nodes) and runs minibatch training with an
    EveryEpoch checkpoint trigger and Accuracy validation on the test set.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--dir', default='/tmp/data', metavar='N',
                        help='the folder store mnist data')
    parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                        help='input batch size for training per executor(default: 256)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing per executor(default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--deploy-mode', default="local",
                        help='supported deploy mode is local, yarn-client, yarn-cluster')

    args = parser.parse_args()

    torch.manual_seed(args.seed)

    def mnist_transform():
        # Tensor conversion followed by normalization with the constants
        # (0.1307,) / (0.3081,); a fresh pipeline is built on every call.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

    train_dataset = datasets.MNIST(args.dir,
                                   train=True,
                                   download=True,
                                   transform=mnist_transform())
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_dataset = datasets.MNIST(args.dir,
                                  train=False,
                                  transform=mnist_transform())
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False)

    # init on yarn when HADOOP_CONF_DIR and ZOO_CONDA_NAME is provided.
    if args.deploy_mode == "local":
        context_options = {}
    else:
        context_options = {
            "cluster_mode": args.deploy_mode,
            "cores": 2,
            "memory": "2g",
            "num_nodes": 4,
        }
    sc = init_orca_context(**context_options)

    model = Net()
    model.train()
    criterion = nn.NLLLoss()
    adam = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Wrap the PyTorch objects so the Estimator can use them.
    zoo_model = TorchModel.from_pytorch(model)
    zoo_criterion = TorchLoss.from_pytorch(criterion)
    zoo_optim = TorchOptim.from_pytorch(adam)
    zoo_estimator = Estimator(zoo_model, optim_methods=zoo_optim)

    train_featureset = FeatureSet.pytorch_dataloader(train_loader)
    test_featureset = FeatureSet.pytorch_dataloader(test_loader)

    from bigdl.dllib.optim.optimizer import MaxEpoch, EveryEpoch
    zoo_estimator.train_minibatch(
        train_featureset,
        zoo_criterion,
        end_trigger=MaxEpoch(args.epochs),
        checkpoint_trigger=EveryEpoch(),
        validation_set=test_featureset,
        validation_method=[Accuracy()])
コード例 #10
0
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            label_cols=None,
            validation_data=None,
            session_config=None,
            checkpoint_trigger=None,
            auto_shard_files=True):
        """
        Train this keras model with train data.

        :param data: train data. It can be XShards, Spark DataFrame, tf.data.Dataset.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
               If data is tf.data.Dataset, each element is [feature tensor tuple, label tensor
               tuple]
        :param epochs: number of epochs to train.
        :param batch_size: total batch size for each iteration.
        :param feature_cols: feature column names if train data is Spark DataFrame or XShards
               of Pandas DataFrame.
        :param label_cols: label column names if train data is Spark DataFrame or XShards of
               Pandas DataFrame.
        :param validation_data: validation data. Validation data type should be the same
               as train data. Default is None (no validation data).
        :param session_config: tensorflow session configuration for training.
               Should be object of tf.ConfigProto
        :param checkpoint_trigger: when to trigger checkpoint during training.
               Should be a bigdl.orca.learn.trigger, like EveryEpoch(), SeveralIteration(
               num_iterations),etc.
        :param auto_shard_files: whether to automatically detect if the dataset is file-based
               and apply sharding on files, otherwise sharding on records. Default is True.
        :return: this estimator.
        """

        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in training"
            assert label_cols is not None, \
                "label columns is None; it should not be None in training"

        if isinstance(data, tf.data.Dataset):
            assert isinstance(data.element_spec, tuple), \
                "If data is tf.data.Dataset, each element should be " \
                "(feature tensors, label tensor), where each feature/label tensor can be " \
                "either a single tensor or a tuple of tensors"
            if validation_data is not None:
                assert isinstance(validation_data, tf.data.Dataset), \
                    "train data and validation data should be both tf.data.Dataset"
                assert isinstance(validation_data.element_spec, tuple), \
                    "If validation_data is tf.data.Dataset, each element should be " \
                    "(feature tensors, label tensor), where each feature/label tensor can be " \
                    "either a single tensor or a tuple of tensors"

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                assert feature_cols is not None, \
                    "feature columns is None; it should not be None in training"
                assert label_cols is not None, \
                    "label columns is None; it should not be None in training"
                data, validation_data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols, validation_data, "fit")

        if checkpoint_trigger is not None:
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if is_tf_data_dataset(data):
            data = data.map(_standardize_keras_target_data)
            # validation_data is optional (default None); only map it when
            # provided, otherwise None.map(...) raises AttributeError.
            if validation_data is not None:
                validation_data = validation_data.map(
                    _standardize_keras_target_data)

        memory_type = OrcaContext.train_data_store
        dataset = to_dataset(data,
                             batch_size=batch_size,
                             batch_per_thread=-1,
                             validation_data=validation_data,
                             feature_cols=feature_cols,
                             label_cols=label_cols,
                             hard_code_batch_size=False,
                             sequential_order=False,
                             shuffle=True,
                             auto_shard_files=auto_shard_files,
                             memory_type=memory_type)

        self.tf_optimizer = TFOptimizer.from_keras(
            self.model.model,
            dataset,
            model_dir=self.model.model_dir,
            session_config=session_config,
            metrics=self.metrics,
            optimizer=self.optimizer)

        if self.clip_norm:
            self.tf_optimizer.set_gradient_clipping_by_l2_norm(
                clip_norm=self.clip_norm)
        # NOTE(review): truthiness check skips clipping when either bound is 0;
        # confirm whether a zero bound should be supported.
        if self.clip_min and self.clip_max:
            self.tf_optimizer.set_constant_gradient_clipping(
                self.clip_min, self.clip_max)

        if self.load_checkpoint:
            self.tf_optimizer.load_checkpoint(self.checkpoint_path,
                                              self.checkpoint_version)

        if self.log_dir and self.app_name:
            self.tf_optimizer.estimator.set_tensorboard(
                self.log_dir, self.app_name)

        self.tf_optimizer.optimize(MaxEpoch(epochs),
                                   checkpoint_trigger=checkpoint_trigger)

        return self
コード例 #11
0
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            label_cols=None,
            validation_data=None,
            session_config=None,
            checkpoint_trigger=None,
            auto_shard_files=False,
            feed_dict=None):
        """
        Train this graph model with train data.

        :param data: train data. It can be XShards, Spark DataFrame, tf.data.Dataset.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
               If data is tf.data.Dataset, each element is a tuple of input tensors.
        :param epochs: number of epochs to train.
        :param batch_size: total batch size for each iteration.
        :param feature_cols: feature column names if train data is Spark DataFrame or XShards
               of Pandas DataFrame.
        :param label_cols: label column names if train data is Spark DataFrame or XShards of
               Pandas DataFrame.
        :param validation_data: validation data. Validation data type should be the same
               as train data.
        :param session_config: tensorflow session configuration for training.
               Should be object of tf.ConfigProto
        :param checkpoint_trigger: when to trigger checkpoint during training.
               Should be a bigdl.orca.learn.trigger, like EveryEpoch(), SeveralIteration(
               num_iterations),etc.
        :param auto_shard_files: whether to automatically detect if the dataset is file-based
               and apply sharding on files, otherwise sharding on records. Default is False.
        :param feed_dict: a dictionary. The key is TensorFlow tensor, usually a
               placeholder, the value of the dictionary is a tuple of two elements. The first one of
               the tuple is the value to feed to the tensor in training phase and the second one
               is the value to feed to the tensor in validation phase.
        :return: this estimator.
        """
        # Training needs labels, a loss and an optimizer on the estimator.
        assert self.labels is not None, \
            "labels is None; it should not be None in training"
        assert self.loss is not None, \
            "loss is None; it should not be None in training"
        assert self.optimizer is not None, \
            "optimizer is None; it should not be None in training"

        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in training"
            assert label_cols is not None, \
                "label columns is None; it should not be None in training"

        pandas_xshards = (
            isinstance(data, SparkXShards)
            and data._get_class_name() == 'pandas.core.frame.DataFrame')
        if pandas_xshards:
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in training"
            assert label_cols is not None, \
                "label columns is None; it should not be None in training"
            data, validation_data = process_xshards_of_pandas_dataframe(
                data, feature_cols, label_cols, validation_data, "fit")

        if checkpoint_trigger is not None:
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        dataset = to_dataset(
            data,
            batch_size=batch_size,
            batch_per_thread=-1,
            validation_data=validation_data,
            feature_cols=feature_cols,
            label_cols=label_cols,
            hard_code_batch_size=False,
            sequential_order=False,
            shuffle=True,
            auto_shard_files=auto_shard_files,
            memory_type=OrcaContext.train_data_store)

        # Normalize every feed_dict value to a (first, second) tuple built
        # from its first two elements.
        tensor_with_value = None
        if feed_dict is not None:
            tensor_with_value = {}
            for tensor, phase_values in feed_dict.items():
                tensor_with_value[tensor] = (phase_values[0], phase_values[1])

        if self.use_bigdl_optim:
            tf_optimizer = TFOptimizer.from_loss(
                self.loss,
                self.optimizer,
                session=self.sess,
                inputs=(self.inputs, self.labels),
                dataset=dataset,
                clip_norm=self.clip_norm,
                clip_value=self.clip_value,
                metrics=self.metrics,
                tensor_with_value=tensor_with_value,
                session_config=session_config,
                model_dir=self.model_dir,
                updates=self.updates)
        else:
            tf_optimizer = TFOptimizer.from_train_op(
                train_op=self.train_op,
                loss=self.loss,
                inputs=self.inputs,
                labels=self.labels,
                dataset=dataset,
                metrics=self.metrics,
                updates=self.updates,
                sess=self.sess,
                tensor_with_value=tensor_with_value,
                session_config=session_config,
                model_dir=self.model_dir)
        self.tf_optimizer = tf_optimizer

        if self.load_checkpoint:
            self.tf_optimizer.load_checkpoint(self.checkpoint_path,
                                              self.checkpoint_version)

        if self.log_dir and self.app_name:
            self.tf_optimizer.estimator.set_tensorboard(
                self.log_dir, self.app_name)

        stop_after = MaxEpoch(epochs)
        self.tf_optimizer.optimize(end_trigger=stop_after,
                                   checkpoint_trigger=checkpoint_trigger)
        return self
コード例 #12
0
    def fit(self,
            data,
            epochs,
            batch_size=32,
            feature_cols="features",
            label_cols="label",
            caching_sample=True,
            validation_data=None,
            validation_trigger=None,
            checkpoint_trigger=None):
        """
        Train this BigDL model with train data.

        Two input kinds are supported and are handled by different back ends:
        a Spark DataFrame is trained through ``self.nn_estimator``, while XShards
        are converted to a sample FeatureSet and trained through ``self.estimator``.

        :param data: train data. It can be XShards or Spark DataFrame.
               If data is XShards, each partition is a dictionary of  {'x': feature,
               'y': label}, where feature(label) is a numpy array or a list of numpy arrays.
        :param epochs: Number of epochs to train the model.
        :param batch_size: Batch size used for training. Default: 32.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame. Default: "features".
        :param label_cols: Label column name(s) of data. Only used when data is a Spark DataFrame.
               Default: "label".
        :param caching_sample: whether to cache the Samples after preprocessing. Only used
               when data is a Spark DataFrame. Default: True
        :param validation_data: Validation data. XShards and Spark DataFrame are supported,
               but it must be of the same kind as data.
               If data is XShards, each partition is a dictionary of  {'x': feature,
               'y': label}, where feature(label) is a numpy array or a list of numpy arrays.
        :param validation_trigger: Orca Trigger to trigger validation computation. Required
               when data is a Spark DataFrame and validation_data is provided; not used
               when data is XShards.
        :param checkpoint_trigger: Orca Trigger to set a checkpoint. For a Spark DataFrame
               it is only applied when self.model_dir is also set.
        :return: this estimator (self), so that calls can be chained.
        :raises AssertionError: if batch_size is not positive, if validation_data is given
                without metrics, if validation_data is not of the same kind as data, or if
                a Spark DataFrame validation_data is given without validation_trigger.
        :raises ValueError: if data is neither XShards nor a Spark DataFrame.
        """
        from bigdl.orca.learn.trigger import Trigger

        assert batch_size > 0, "batch_size should be greater than 0"

        if validation_data is not None:
            assert self.metrics is not None, \
                "You should provide metrics when creating this estimator if you provide " \
                "validation_data."

        if isinstance(data, DataFrame):
            if validation_data is not None:
                # Check the validation arguments up front, before validation_data is handed
                # to _combine_cols and before nn_estimator is reconfigured, so that a bad
                # argument fails with a clear message and leaves the estimator untouched.
                assert isinstance(validation_data, DataFrame), \
                    "validation_data should be a spark DataFrame."
                assert validation_trigger is not None, \
                    "You should provide validation_trigger if you provide validation_data."

            # A list of column names is merged into a single column (on both the train
            # and validation DataFrames); the helper returns the column name to use.
            if isinstance(feature_cols, list):
                data, validation_data, feature_cols = \
                    BigDLEstimator._combine_cols(data, feature_cols, col_name="features",
                                                 val_data=validation_data)

            if isinstance(label_cols, list):
                data, validation_data, label_cols = \
                    BigDLEstimator._combine_cols(data, label_cols, col_name="label",
                                                 val_data=validation_data)

            self.nn_estimator.setBatchSize(batch_size).setMaxEpoch(epochs) \
                .setCachingSample(caching_sample).setFeaturesCol(feature_cols) \
                .setLabelCol(label_cols)

            if validation_data is not None:
                validation_trigger = Trigger.convert_trigger(
                    validation_trigger)
                self.nn_estimator.setValidation(validation_trigger,
                                                validation_data, self.metrics,
                                                batch_size)
            if self.log_dir is not None and self.app_name is not None:
                from bigdl.dllib.optim.optimizer import TrainSummary
                from bigdl.dllib.optim.optimizer import ValidationSummary
                train_summary = TrainSummary(log_dir=self.log_dir,
                                             app_name=self.app_name)
                self.nn_estimator.setTrainSummary(train_summary)
                val_summary = ValidationSummary(log_dir=self.log_dir,
                                                app_name=self.app_name)
                self.nn_estimator.setValidationSummary(val_summary)
            if self.model_dir is not None and checkpoint_trigger is not None:
                checkpoint_trigger = Trigger.convert_trigger(
                    checkpoint_trigger)
                self.nn_estimator.setCheckpoint(self.model_dir,
                                                checkpoint_trigger)

            self.nn_model = self.nn_estimator.fit(data)
            self.is_nnframe_fit = True
        elif isinstance(data, SparkXShards):
            from bigdl.orca.data.utils import xshard_to_sample

            if validation_data is not None:
                # Reject a mismatched validation set before building any feature set.
                assert isinstance(validation_data, SparkXShards), \
                    "validation_data should be a XShards"

            end_trigger = MaxEpoch(epochs)
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

            train_rdd = data.rdd.flatMap(xshard_to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                val_feature_set = FeatureSet.sample_rdd(
                    validation_data.rdd.flatMap(xshard_to_sample))
            if self.log_dir is not None and self.app_name is not None:
                self.estimator.set_tensorboard(self.log_dir, self.app_name)
            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set,
                                 self.metrics, batch_size)
            self.is_nnframe_fit = False
        else:
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)
        return self