def test_train_model_with_bn(self):
    class SimpleTorchModel(nn.Module):
        def __init__(self):
            super(SimpleTorchModel, self).__init__()
            self.dense1 = nn.Linear(2, 4)
            self.bn1 = torch.nn.BatchNorm1d(4)
            self.dense2 = nn.Linear(4, 1)

        def forward(self, x):
            x = self.dense1(x)
            x = self.bn1(x)
            x = torch.sigmoid(self.dense2(x))
            return x

    torch_model = SimpleTorchModel()
    loss_fn = torch.nn.BCELoss()
    az_model = TorchModel.from_pytorch(torch_model)
    zoo_loss = TorchLoss.from_pytorch(loss_fn)
    inputs = torch.Tensor([[1, 2], [1, 3], [3, 2], [5, 6], [8, 9], [1, 9]])
    targets = torch.Tensor([[0], [0], [0], [1], [1], [1]])
    train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=2)
    train_featureset = FeatureSet.pytorch_dataloader(train_loader)
    val_loader = DataLoader(TensorDataset(inputs, targets), batch_size=2)
    val_featureset = FeatureSet.pytorch_dataloader(val_loader)
    zooOptimizer = Adam()
    estimator = Estimator(az_model, optim_methods=zooOptimizer)
    estimator.train_minibatch(train_featureset, zoo_loss,
                              end_trigger=MaxEpoch(4),
                              checkpoint_trigger=EveryEpoch(),
                              validation_set=val_featureset,
                              validation_method=[Accuracy()])
    trained_model = az_model.to_pytorch()
def test_control_inputs(self):
    features = np.random.randn(20, 10)
    labels = np.random.randint(0, 10, size=[20])
    with tf.Graph().as_default():
        dataset = TFDataset.from_ndarrays((features, labels), batch_size=4,
                                          val_tensors=(features, labels))
        is_training = tf.placeholder(dtype=tf.bool, shape=())
        feature_tensor, label_tensor = dataset.tensors
        features = tf.layers.dense(feature_tensor, 8)
        features = tf.layers.dropout(features, training=is_training)
        output = tf.layers.dense(features, 10)
        loss = tf.reduce_mean(
            tf.losses.sparse_softmax_cross_entropy(logits=output, labels=label_tensor))
        optimizer = TFOptimizer.from_loss(loss, Adam(),
                                          val_outputs=[output],
                                          val_labels=[label_tensor],
                                          val_method=Accuracy(),
                                          tensor_with_value={is_training: (True, False)},
                                          metrics={"loss": loss})
        optimizer.optimize(end_trigger=MaxEpoch(1))
        optimizer.sess.close()
def _fit_distributed(self, dataset, epochs, **kwargs):
    self.tf_optimizer = TFOptimizer.from_keras(self.model, dataset,
                                               model_dir=self.model_dir,
                                               metrics=self.metric_tensors,
                                               optimizer=self.optimizer,
                                               **kwargs)
    self.tf_optimizer.optimize(MaxEpoch(epochs))
def test_tf_optimizer_with_sparse_gradient(self):
    ids = np.random.randint(0, 10, size=[40])
    labels = np.random.randint(0, 5, size=[40])
    id_rdd = self.sc.parallelize(ids)
    label_rdd = self.sc.parallelize(labels)
    training_rdd = id_rdd.zip(label_rdd).map(lambda x: [x[0], x[1]])
    with tf.Graph().as_default():
        dataset = TFDataset.from_rdd(training_rdd,
                                     names=["ids", "labels"],
                                     shapes=[[], []],
                                     types=[tf.int32, tf.int32],
                                     batch_size=8)
        id_tensor, label_tensor = dataset.tensors
        embedding_table = tf.get_variable(name="word_embedding", shape=[10, 5])
        embedding = tf.nn.embedding_lookup(embedding_table, id_tensor)
        loss = tf.reduce_mean(
            tf.losses.sparse_softmax_cross_entropy(logits=embedding, labels=label_tensor))
        optimizer = TFOptimizer.from_loss(loss, Adam(1e-3))
        optimizer.optimize(end_trigger=MaxEpoch(1))
        optimizer.sess.close()
def optimize(self, end_trigger=None, checkpoint_trigger=None):
    """
    Run the training loop of this optimizer.

    :param end_trigger: BigDL's Trigger to indicate when to stop the training.
    :param checkpoint_trigger: When to save a checkpoint and evaluate the model.
    """
    if end_trigger is None:
        end_trigger = MaxEpoch(1)

    if checkpoint_trigger is None:
        checkpoint_trigger = EveryEpoch()

    if isinstance(self.train_data, FeatureSet):
        if self.train_data.value.getNumOfSlice() != 1:
            if isinstance(checkpoint_trigger, EveryEpoch):
                checkpoint_trigger = ZEveryEpoch()
            elif not isinstance(checkpoint_trigger, ZooTrigger):
                raise Exception("Please use a trigger defined in bigdl.dllib.utils.triggers")

    if self.tf_model.val_methods and self.val_data is not None:
        self.estimator.train_minibatch(train_set=self.train_data,
                                       criterion=self.tf_model.criterion,
                                       end_trigger=end_trigger,
                                       checkpoint_trigger=checkpoint_trigger,
                                       validation_set=self.val_data,
                                       validation_method=self.tf_model.val_methods)
    else:
        self.estimator.train_minibatch(train_set=self.train_data,
                                       criterion=self.tf_model.criterion,
                                       end_trigger=end_trigger,
                                       checkpoint_trigger=checkpoint_trigger)

    self.tf_model.training_helper_layer.get_weights_to_python()
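# A minimal usage sketch for optimize() (not part of the original source). It mirrors the
# tests above: build a tiny graph with TFDataset.from_ndarrays and TFOptimizer.from_loss,
# then pass explicit triggers instead of relying on the MaxEpoch(1)/EveryEpoch() defaults.
# Assumes the same imports as the tests (numpy as np, tensorflow as tf, TFDataset,
# TFOptimizer, Adam, MaxEpoch, EveryEpoch); the function name is illustrative only.
def example_optimize_with_triggers():
    features = np.random.randn(20, 10)
    labels = np.random.randint(0, 10, size=[20])
    with tf.Graph().as_default():
        dataset = TFDataset.from_ndarrays((features, labels), batch_size=4)
        feature_tensor, label_tensor = dataset.tensors
        output = tf.layers.dense(feature_tensor, 10)
        loss = tf.reduce_mean(
            tf.losses.sparse_softmax_cross_entropy(logits=output, labels=label_tensor))
        optimizer = TFOptimizer.from_loss(loss, Adam(1e-3), metrics={"loss": loss})
        # end_trigger controls when training stops; checkpoint_trigger controls when a
        # checkpoint is saved and the model is evaluated.
        optimizer.optimize(end_trigger=MaxEpoch(2), checkpoint_trigger=EveryEpoch())
        optimizer.sess.close()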
def test_tf_optimizer_metrics(self):
    features = np.random.randn(20, 10)
    labels = np.random.randint(0, 10, size=[20])
    with tf.Graph().as_default():
        dataset = TFDataset.from_ndarrays((features, labels), batch_size=4,
                                          val_tensors=(features, labels))
        feature_tensor, label_tensor = dataset.tensors
        features = tf.layers.dense(feature_tensor, 8)
        output = tf.layers.dense(features, 10)
        loss = tf.reduce_mean(
            tf.losses.sparse_softmax_cross_entropy(logits=output, labels=label_tensor))
        optimizer = TFOptimizer.from_loss(loss,
                                          {"dense/": Adam(1e-3), "dense_1/": SGD(0.0)},
                                          val_outputs=[output],
                                          val_labels=[label_tensor],
                                          val_method=Accuracy(),
                                          metrics={"loss": loss})
        initial_weights = optimizer.tf_model.training_helper_layer.get_weights()
        optimizer.optimize(end_trigger=MaxEpoch(1))
        updated_weights = optimizer.tf_model.training_helper_layer.get_weights()
        for i in [0, 1]:  # weights and bias under the "dense/" scope should be updated
            assert not np.allclose(initial_weights[i], updated_weights[i])
        for i in [2, 3]:  # weights and bias under the "dense_1/" scope should be unchanged
            assert np.allclose(initial_weights[i], updated_weights[i])
        optimizer.sess.close()
def fit(self, data, epochs=1, batch_size=None, feature_cols=None, label_cols=None,
        validation_data=None, checkpoint_trigger=None):
    """
    Train this torch model with train data.

    :param data: train data. It can be an XShards, a Spark DataFrame, a PyTorch DataLoader or
           a PyTorch DataLoader creator function that takes config and batch_size as arguments
           and returns a PyTorch DataLoader for training.
           If data is an XShards, each partition can be a Pandas DataFrame or a dictionary of
           {'x': feature, 'y': label}, where feature(label) is a numpy array or a list of
           numpy arrays.
    :param epochs: Number of epochs to train the model. Default: 1.
    :param batch_size: Batch size used for training. Only used when data is an XShards.
           Default: 32.
    :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
           DataFrame or an XShards of Pandas DataFrame. Default: None.
    :param label_cols: Label column name(s) of data. Only used when data is a Spark DataFrame
           or an XShards of Pandas DataFrame. Default: None.
    :param validation_data: Validation data. XShards, PyTorch DataLoader and PyTorch
           DataLoader creator function are supported.
           If validation_data is an XShards, each partition can be a Pandas DataFrame or a
           dictionary of {'x': feature, 'y': label}, where feature(label) is a numpy array or
           a list of numpy arrays.
    :param checkpoint_trigger: Orca Trigger to set a checkpoint.
    :return: The trained estimator object.
    """
    from bigdl.orca.learn.trigger import Trigger

    end_trigger = MaxEpoch(epochs)
    if isinstance(data, DataLoader):
        assert batch_size is None and data.batch_size > 0, \
            "When using a PyTorch DataLoader as input, you need to specify the batch size " \
            "in the DataLoader and should not specify batch_size in the fit method."
    else:
        assert batch_size is not None and batch_size > 0, "batch_size should be greater than 0"
    checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

    if self.log_dir is not None and self.app_name is not None:
        self.estimator.set_tensorboard(self.log_dir, self.app_name)

    if validation_data:
        assert self.metrics is not None, \
            "You should provide metrics when creating this estimator if you provide " \
            "validation_data."

    if isinstance(data, SparkXShards):
        if data._get_class_name() == 'pandas.core.frame.DataFrame':
            data, validation_data = process_xshards_of_pandas_dataframe(
                data, feature_cols, label_cols, validation_data, mode="fit")
        train_fset, val_fset = self._handle_xshards(data, validation_data)
        self.estimator.train(train_fset, self.loss, end_trigger, checkpoint_trigger,
                             val_fset, self.metrics, batch_size)
    elif isinstance(data, DataFrame):
        train_fset, val_fset = self._handle_dataframe(data, validation_data,
                                                      feature_cols, label_cols)
        self.estimator.train(train_fset, self.loss, end_trigger, checkpoint_trigger,
                             val_fset, self.metrics, batch_size)
    elif isinstance(data, DataLoader) or callable(data) or isinstance(data, types.FunctionType):
        if isinstance(data, types.FunctionType):
            data, validation_data = data(self.config, batch_size), \
                validation_data(self.config, batch_size)
        train_fset, val_fset = self._handle_data_loader(data, validation_data)
        self.estimator.train_minibatch(train_fset, self.loss, end_trigger,
                                       checkpoint_trigger, val_fset, self.metrics)
    else:
        raise ValueError("Data and validation data should be SparkXShards, DataLoaders or "
                         "callable data_creators but got " + data.__class__.__name__)
    return self
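# A hedged usage sketch for the fit() above (not part of the original source). It assumes
# `est` is an already-constructed Orca PyTorch estimator of this class and that EveryEpoch
# comes from bigdl.orca.learn.trigger; per the assertion above, when a DataLoader is passed
# the batch size lives in the DataLoader and batch_size must be left unset in fit().
import torch
from torch.utils.data import DataLoader, TensorDataset

inputs = torch.randn(64, 2)
targets = torch.randint(0, 2, (64, 1)).float()
train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=8, shuffle=True)
val_loader = DataLoader(TensorDataset(inputs, targets), batch_size=8)

# batch_size intentionally omitted: it is taken from the DataLoader.
est.fit(data=train_loader, epochs=2, validation_data=val_loader,
        checkpoint_trigger=EveryEpoch())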
def test_checkpoint(self):
    features = np.random.randn(20, 10)
    labels = np.random.randint(0, 10, size=[20])
    with tf.Graph().as_default():
        dataset = TFDataset.from_ndarrays((features, labels), batch_size=4,
                                          val_tensors=(features, labels))
        feature_tensor, label_tensor = dataset.tensors
        features = tf.layers.dense(feature_tensor, 8)
        output = tf.layers.dense(features, 10)
        loss = tf.reduce_mean(
            tf.losses.sparse_softmax_cross_entropy(logits=output, labels=label_tensor))

        model_dir = tempfile.mkdtemp()
        try:
            optimizer = TFOptimizer.from_loss(loss, Adam(),
                                              val_outputs=[output],
                                              val_labels=[label_tensor],
                                              val_method=Accuracy(),
                                              metrics={"loss": loss},
                                              model_dir=model_dir)
            optimizer.optimize(end_trigger=MaxEpoch(1))

            first_weights = optimizer.sess.run(tf.trainable_variables()[0])
            import re
            ckpt_path = None
            versions = []
            for (root, dirs, files) in os.walk(model_dir, topdown=True):
                temp_versions = []
                for file_name in files:
                    if re.match(r"^optimMethod-TFParkTraining\.[0-9]+$", file_name) is not None:
                        version = int(file_name.split(".")[1])
                        temp_versions.append(version)
                if temp_versions:
                    ckpt_path = root
                    versions = temp_versions
                    break
            assert ckpt_path is not None, "Cannot find checkpoint file"

            optimizer.sess.run(tf.global_variables_initializer())  # reset variables
            optimizer_load = TFOptimizer.from_loss(loss, Adam(),
                                                   session=optimizer.sess,
                                                   val_outputs=[output],
                                                   val_labels=[label_tensor],
                                                   val_method=Accuracy(),
                                                   metrics={"loss": loss},
                                                   model_dir=model_dir)
            optimizer_load.load_checkpoint(ckpt_path, max(versions))
            loaded_first_weights_before_train = optimizer.sess.run(tf.trainable_variables()[0])
            assert np.allclose(first_weights, loaded_first_weights_before_train)

            # max epoch is still 1, so this should not train
            optimizer_load.optimize(end_trigger=MaxEpoch(1))
            loaded_first_weights = optimizer.sess.run(tf.trainable_variables()[0])
            assert np.allclose(first_weights, loaded_first_weights)

            # max epoch increased by 1, so this should train for one more epoch
            optimizer_load.optimize(end_trigger=MaxEpoch(2))
            loaded_first_weights_2 = optimizer.sess.run(tf.trainable_variables()[0])
            assert not np.allclose(first_weights, loaded_first_weights_2)
            optimizer_load.sess.close()
        finally:
            import shutil
            shutil.rmtree(model_dir)
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--dir', default='/tmp/data', metavar='N',
                        help='the folder to store the MNIST data')
    parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                        help='input batch size for training per executor (default: 256)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing per executor (default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='for saving the current model')
    parser.add_argument('--deploy-mode', default="local",
                        help='supported deploy modes are local, yarn-client and yarn-cluster')
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(args.dir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(args.dir, train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=False)

    # init on yarn when HADOOP_CONF_DIR and ZOO_CONDA_NAME are provided.
    if args.deploy_mode == "local":
        sc = init_orca_context()
    else:
        sc = init_orca_context(cluster_mode=args.deploy_mode, cores=2,
                               memory="2g", num_nodes=4)

    model = Net()
    model.train()
    criterion = nn.NLLLoss()
    adam = torch.optim.Adam(model.parameters(), lr=args.lr)

    zoo_model = TorchModel.from_pytorch(model)
    zoo_criterion = TorchLoss.from_pytorch(criterion)
    zoo_optim = TorchOptim.from_pytorch(adam)
    zoo_estimator = Estimator(zoo_model, optim_methods=zoo_optim)
    train_featureset = FeatureSet.pytorch_dataloader(train_loader)
    test_featureset = FeatureSet.pytorch_dataloader(test_loader)

    from bigdl.dllib.optim.optimizer import MaxEpoch, EveryEpoch
    zoo_estimator.train_minibatch(train_featureset, zoo_criterion,
                                  end_trigger=MaxEpoch(args.epochs),
                                  checkpoint_trigger=EveryEpoch(),
                                  validation_set=test_featureset,
                                  validation_method=[Accuracy()])
def fit(self, data, epochs=1, batch_size=32, feature_cols=None, label_cols=None,
        validation_data=None, session_config=None, checkpoint_trigger=None,
        auto_shard_files=True):
    """
    Train this keras model with train data.

    :param data: train data. It can be XShards, Spark DataFrame or tf.data.Dataset.
           If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
           {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
           numpy arrays.
           If data is tf.data.Dataset, each element is a tuple of
           (feature tensor tuple, label tensor tuple).
    :param epochs: number of epochs to train.
    :param batch_size: total batch size for each iteration.
    :param feature_cols: feature column names if train data is Spark DataFrame or XShards
           of Pandas DataFrame.
    :param label_cols: label column names if train data is Spark DataFrame or XShards of
           Pandas DataFrame.
    :param validation_data: validation data. Validation data type should be the same as
           train data.
    :param session_config: tensorflow session configuration for training.
           Should be an object of tf.ConfigProto.
    :param checkpoint_trigger: when to trigger checkpoint during training.
           Should be a bigdl.orca.learn.trigger, like EveryEpoch(),
           SeveralIteration(num_iterations), etc.
    :param auto_shard_files: whether to automatically detect if the dataset is file-based
           and apply sharding on files, otherwise sharding on records. Default: True.
    """
    if isinstance(data, DataFrame):
        assert feature_cols is not None, \
            "feature columns is None; it should not be None in training"
        assert label_cols is not None, \
            "label columns is None; it should not be None in training"

    if isinstance(data, tf.data.Dataset):
        assert isinstance(data.element_spec, tuple), \
            "If data is tf.data.Dataset, each element should be " \
            "(feature tensors, label tensor), where each feature/label tensor can be " \
            "either a single tensor or a tuple of tensors"
        if validation_data is not None:
            assert isinstance(validation_data, tf.data.Dataset), \
                "train data and validation data should be both tf.data.Dataset"
            assert isinstance(validation_data.element_spec, tuple), \
                "If validation_data is tf.data.Dataset, each element should be " \
                "(feature tensors, label tensor), where each feature/label tensor can be " \
                "either a single tensor or a tuple of tensors"

    if isinstance(data, SparkXShards):
        if data._get_class_name() == 'pandas.core.frame.DataFrame':
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in training"
            assert label_cols is not None, \
                "label columns is None; it should not be None in training"
            data, validation_data = process_xshards_of_pandas_dataframe(
                data, feature_cols, label_cols, validation_data, "fit")

    if checkpoint_trigger is not None:
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

    if is_tf_data_dataset(data):
        data = data.map(_standardize_keras_target_data)
        validation_data = validation_data.map(_standardize_keras_target_data)

    memory_type = OrcaContext.train_data_store
    dataset = to_dataset(data, batch_size=batch_size, batch_per_thread=-1,
                         validation_data=validation_data,
                         feature_cols=feature_cols, label_cols=label_cols,
                         hard_code_batch_size=False,
                         sequential_order=False, shuffle=True,
                         auto_shard_files=auto_shard_files,
                         memory_type=memory_type)

    self.tf_optimizer = TFOptimizer.from_keras(self.model.model, dataset,
                                               model_dir=self.model.model_dir,
                                               session_config=session_config,
                                               metrics=self.metrics,
                                               optimizer=self.optimizer)

    if self.clip_norm:
        self.tf_optimizer.set_gradient_clipping_by_l2_norm(clip_norm=self.clip_norm)
    if self.clip_min and self.clip_max:
        self.tf_optimizer.set_constant_gradient_clipping(self.clip_min, self.clip_max)

    if self.load_checkpoint:
        self.tf_optimizer.load_checkpoint(self.checkpoint_path, self.checkpoint_version)

    if self.log_dir and self.app_name:
        self.tf_optimizer.estimator.set_tensorboard(self.log_dir, self.app_name)

    self.tf_optimizer.optimize(MaxEpoch(epochs), checkpoint_trigger=checkpoint_trigger)
    return self
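# A hedged usage sketch for the keras fit() above (not part of the original source). It
# assumes `est` is an already-constructed estimator of this class and `df`/`val_df` are
# Spark DataFrames with the listed (hypothetical) column names; per the assertions above,
# feature_cols and label_cols are required for DataFrame input, and checkpoint_trigger
# takes a bigdl.orca.learn.trigger such as EveryEpoch() or SeveralIteration(num_iterations).
est.fit(data=df,
        epochs=5,
        batch_size=64,
        feature_cols=["feature"],      # hypothetical column names
        label_cols=["label"],
        validation_data=val_df,
        checkpoint_trigger=EveryEpoch())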
def fit(self, data, epochs=1, batch_size=32, feature_cols=None, label_cols=None,
        validation_data=None, session_config=None, checkpoint_trigger=None,
        auto_shard_files=False, feed_dict=None):
    """
    Train this graph model with train data.

    :param data: train data. It can be XShards, Spark DataFrame or tf.data.Dataset.
           If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
           {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
           numpy arrays.
           If data is tf.data.Dataset, each element is a tuple of input tensors.
    :param epochs: number of epochs to train.
    :param batch_size: total batch size for each iteration.
    :param feature_cols: feature column names if train data is Spark DataFrame or XShards
           of Pandas DataFrame.
    :param label_cols: label column names if train data is Spark DataFrame or XShards of
           Pandas DataFrame.
    :param validation_data: validation data. Validation data type should be the same as
           train data.
    :param auto_shard_files: whether to automatically detect if the dataset is file-based
           and apply sharding on files, otherwise sharding on records. Default: False.
    :param session_config: tensorflow session configuration for training.
           Should be an object of tf.ConfigProto.
    :param feed_dict: a dictionary. The key is a TensorFlow tensor, usually a placeholder;
           the value is a tuple of two elements. The first element is the value to feed to
           the tensor in the training phase and the second is the value to feed in the
           validation phase.
    :param checkpoint_trigger: when to trigger checkpoint during training.
           Should be a bigdl.orca.learn.trigger, like EveryEpoch(),
           SeveralIteration(num_iterations), etc.
    """
    assert self.labels is not None, \
        "labels is None; it should not be None in training"
    assert self.loss is not None, \
        "loss is None; it should not be None in training"
    assert self.optimizer is not None, \
        "optimizer is None; it should not be None in training"

    if isinstance(data, DataFrame):
        assert feature_cols is not None, \
            "feature columns is None; it should not be None in training"
        assert label_cols is not None, \
            "label columns is None; it should not be None in training"

    if isinstance(data, SparkXShards):
        if data._get_class_name() == 'pandas.core.frame.DataFrame':
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in training"
            assert label_cols is not None, \
                "label columns is None; it should not be None in training"
            data, validation_data = process_xshards_of_pandas_dataframe(
                data, feature_cols, label_cols, validation_data, "fit")

    if checkpoint_trigger is not None:
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

    memory_type = OrcaContext.train_data_store
    dataset = to_dataset(data, batch_size=batch_size, batch_per_thread=-1,
                         validation_data=validation_data,
                         feature_cols=feature_cols, label_cols=label_cols,
                         hard_code_batch_size=False,
                         sequential_order=False, shuffle=True,
                         auto_shard_files=auto_shard_files,
                         memory_type=memory_type)

    if feed_dict is not None:
        tensor_with_value = {key: (value[0], value[1]) for key, value in feed_dict.items()}
    else:
        tensor_with_value = None

    if self.use_bigdl_optim:
        self.tf_optimizer = TFOptimizer.from_loss(
            self.loss, self.optimizer,
            session=self.sess,
            inputs=(self.inputs, self.labels),
            dataset=dataset,
            clip_norm=self.clip_norm,
            clip_value=self.clip_value,
            metrics=self.metrics,
            tensor_with_value=tensor_with_value,
            session_config=session_config,
            model_dir=self.model_dir,
            updates=self.updates)
    else:
        self.tf_optimizer = TFOptimizer.from_train_op(
            train_op=self.train_op,
            loss=self.loss,
            inputs=self.inputs,
            labels=self.labels,
            dataset=dataset,
            metrics=self.metrics,
            updates=self.updates,
            sess=self.sess,
            tensor_with_value=tensor_with_value,
            session_config=session_config,
            model_dir=self.model_dir)

    if self.load_checkpoint:
        self.tf_optimizer.load_checkpoint(self.checkpoint_path, self.checkpoint_version)

    if self.log_dir and self.app_name:
        self.tf_optimizer.estimator.set_tensorboard(self.log_dir, self.app_name)

    self.tf_optimizer.optimize(end_trigger=MaxEpoch(epochs),
                               checkpoint_trigger=checkpoint_trigger)
    return self
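# A hedged usage sketch for the graph-mode fit() above (not part of the original source),
# focusing on feed_dict: each key is a TF tensor (typically a placeholder) and each value is
# a (train_value, validation_value) pair, mirroring tensor_with_value in test_control_inputs.
# Assumes `est` is an already-constructed estimator of this class whose graph contains an
# `is_training` placeholder, and `df` is a Spark DataFrame with the listed columns.
est.fit(data=df,
        epochs=3,
        batch_size=32,
        feature_cols=["feature"],                # hypothetical column names
        label_cols=["label"],
        feed_dict={is_training: (True, False)},  # True while training, False for validation
        checkpoint_trigger=EveryEpoch())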
def fit(self, data, epochs, batch_size=32, feature_cols="features", label_cols="label",
        caching_sample=True, validation_data=None, validation_trigger=None,
        checkpoint_trigger=None):
    """
    Train this BigDL model with train data.

    :param data: train data. It can be XShards or Spark DataFrame.
           If data is XShards, each partition is a dictionary of {'x': feature, 'y': label},
           where feature(label) is a numpy array or a list of numpy arrays.
    :param epochs: Number of epochs to train the model.
    :param batch_size: Batch size used for training. Default: 32.
    :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
           DataFrame. Default: "features".
    :param label_cols: Label column name(s) of data. Only used when data is a Spark
           DataFrame. Default: "label".
    :param caching_sample: whether to cache the Samples after preprocessing. Default: True.
    :param validation_data: Validation data. XShards and Spark DataFrame are supported.
           If data is XShards, each partition is a dictionary of {'x': feature, 'y': label},
           where feature(label) is a numpy array or a list of numpy arrays.
    :param validation_trigger: Orca Trigger to trigger validation computation.
    :param checkpoint_trigger: Orca Trigger to set a checkpoint.
    :return: The trained estimator object.
    """
    from bigdl.orca.learn.trigger import Trigger

    assert batch_size > 0, "batch_size should be greater than 0"

    if validation_data is not None:
        assert self.metrics is not None, \
            "You should provide metrics when creating this estimator if you provide " \
            "validation_data."

    if isinstance(data, DataFrame):
        if isinstance(feature_cols, list):
            data, validation_data, feature_cols = \
                BigDLEstimator._combine_cols(data, feature_cols, col_name="features",
                                             val_data=validation_data)

        if isinstance(label_cols, list):
            data, validation_data, label_cols = \
                BigDLEstimator._combine_cols(data, label_cols, col_name="label",
                                             val_data=validation_data)

        self.nn_estimator.setBatchSize(batch_size).setMaxEpoch(epochs) \
            .setCachingSample(caching_sample).setFeaturesCol(feature_cols) \
            .setLabelCol(label_cols)

        if validation_data is not None:
            assert isinstance(validation_data, DataFrame), \
                "validation_data should be a spark DataFrame."
            assert validation_trigger is not None, \
                "You should provide validation_trigger if you provide validation_data."
            validation_trigger = Trigger.convert_trigger(validation_trigger)
            self.nn_estimator.setValidation(validation_trigger, validation_data,
                                            self.metrics, batch_size)
        if self.log_dir is not None and self.app_name is not None:
            from bigdl.dllib.optim.optimizer import TrainSummary
            from bigdl.dllib.optim.optimizer import ValidationSummary
            train_summary = TrainSummary(log_dir=self.log_dir, app_name=self.app_name)
            self.nn_estimator.setTrainSummary(train_summary)
            val_summary = ValidationSummary(log_dir=self.log_dir, app_name=self.app_name)
            self.nn_estimator.setValidationSummary(val_summary)
        if self.model_dir is not None and checkpoint_trigger is not None:
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)
            self.nn_estimator.setCheckpoint(self.model_dir, checkpoint_trigger)

        self.nn_model = self.nn_estimator.fit(data)
        self.is_nnframe_fit = True
    elif isinstance(data, SparkXShards):
        from bigdl.orca.data.utils import xshard_to_sample

        end_trigger = MaxEpoch(epochs)
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if isinstance(data, SparkXShards):
            train_rdd = data.rdd.flatMap(xshard_to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, SparkXShards), \
                    "validation_data should be an XShards"
                val_feature_set = FeatureSet.sample_rdd(
                    validation_data.rdd.flatMap(xshard_to_sample))
            if self.log_dir is not None and self.app_name is not None:
                self.estimator.set_tensorboard(self.log_dir, self.app_name)
            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set, self.metrics,
                                 batch_size)
            self.is_nnframe_fit = False
        else:
            raise ValueError("Data and validation data should be XShards, but got " +
                             data.__class__.__name__)
    else:
        raise ValueError("Data should be XShards or Spark DataFrame, but got " +
                         data.__class__.__name__)
    return self
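# A hedged usage sketch for the BigDL fit() above (not part of the original source). It
# assumes `est` is an already-constructed BigDLEstimator and `df`/`val_df` are Spark
# DataFrames. Per the assertions above, validation_trigger is required whenever
# validation_data is given with DataFrame input; both triggers are Orca triggers
# (e.g. EveryEpoch() from bigdl.orca.learn.trigger).
est.fit(data=df,
        epochs=10,
        batch_size=128,
        feature_cols="features",   # a single assembled column, or a list to be combined
        label_cols="label",
        validation_data=val_df,
        validation_trigger=EveryEpoch(),
        checkpoint_trigger=EveryEpoch())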