def __init__(self, model, loss, optimizer, metrics=None, model_dir=None, bigdl_type="float"):
    from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim
    # Optimizer base types used for the dispatch below.
    from torch.optim.optimizer import Optimizer as TorchOptimizer
    from zoo.orca.learn.optimizers import Optimizer as OrcaOptimizer

    # No loss given: use the default TorchLoss wrapper.
    self.loss = loss
    if self.loss is None:
        self.loss = TorchLoss()
    else:
        self.loss = TorchLoss.from_pytorch(loss)

    if optimizer is None:
        from zoo.orca.learn.optimizers import SGD
        from zoo.orca.learn.optimizers.schedule import Default
        optimizer = SGD(learningrate_schedule=Default())
    if isinstance(optimizer, TorchOptimizer):
        optimizer = TorchOptim.from_pytorch(optimizer)
    elif isinstance(optimizer, OrcaOptimizer):
        optimizer = optimizer.get_optimizer()
    else:
        raise ValueError("Only PyTorch optimizers and orca optimizers are supported")

    from zoo.orca.learn.metrics import Metrics
    self.metrics = Metrics.convert_metrics_list(metrics)

    self.log_dir = None
    self.app_name = None
    self.model_dir = model_dir
    self.model = TorchModel.from_pytorch(model)
    self.estimator = SparkEstimator(self.model, optimizer, model_dir, bigdl_type=bigdl_type)
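# A hedged construction sketch: assuming this __init__ belongs to the PyTorch
# Orca estimator class (called `Estimator` here purely for illustration) and
# that `SimpleNet` is a user-defined torch.nn.Module:
#
#   import torch
#   from torch import nn
#
#   net = SimpleNet()
#   est = Estimator(model=net,
#                   loss=nn.MSELoss(),
#                   optimizer=torch.optim.SGD(net.parameters(), lr=1e-3))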
def __init__(self, *, model, loss, optimizer=None, metrics=None,
             feature_preprocessing=None, label_preprocessing=None, model_dir=None):
    self.loss = loss
    self.optimizer = optimizer
    self.metrics = Metrics.convert_metrics_list(metrics)
    self.feature_preprocessing = feature_preprocessing
    self.label_preprocessing = label_preprocessing
    self.model_dir = model_dir
    self.model = model
    self.nn_model = NNModel(self.model, feature_preprocessing=self.feature_preprocessing)
    self.nn_estimator = NNEstimator(self.model, self.loss, self.feature_preprocessing,
                                    self.label_preprocessing)
    if self.optimizer is None:
        from bigdl.optim.optimizer import SGD
        self.optimizer = SGD()
    self.nn_estimator.setOptimMethod(self.optimizer)
    self.estimator = SparkEstimator(self.model, self.optimizer, self.model_dir)
    self.log_dir = None
    self.app_name = None
    self.is_nnframe_fit = False
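# A hedged construction sketch: `Linear` and `MSECriterion` stand in for any
# BigDL model/criterion pair; the keyword-only signature requires named
# arguments:
#
#   from bigdl.nn.layer import Linear
#   from bigdl.nn.criterion import MSECriterion
#
#   est = BigDLEstimatorWrapper(model=Linear(2, 1), loss=MSECriterion())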
def evaluate(self, data, batch_size=32, feature_cols=None, labels_cols=None,
             validation_metrics=None):
    from zoo.orca.data.utils import to_sample
    from zoo.orca.learn.metrics import Metrics

    assert data is not None, "validation data shouldn't be None"
    validation_metrics = Metrics.convert_metrics_list(validation_metrics)

    if isinstance(data, SparkXShards):
        val_feature_set = FeatureSet.sample_rdd(data.rdd.flatMap(to_sample))
        return self.estimator.evaluate(val_feature_set, validation_metrics, batch_size)
    elif isinstance(data, DataLoader) or callable(data):
        val_feature_set = FeatureSet.pytorch_dataloader(data)
        return self.estimator.evaluate_minibatch(val_feature_set, validation_metrics)
    else:
        raise ValueError("Data should be a SparkXShards, a DataLoader or a callable "
                         "data_creator, but got " + data.__class__.__name__)
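# A hedged evaluation sketch: `shards` is assumed to be a SparkXShards of
# feature/label numpy dicts and `Accuracy` an orca metric:
#
#   from zoo.orca.learn.metrics import Accuracy
#
#   result = est.evaluate(shards, batch_size=64, validation_metrics=[Accuracy()])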
def fit(self, data, epochs=1, batch_size=32, feature_cols=None, labels_cols=None,
        validation_data=None, validation_metrics=None, checkpoint_trigger=None):
    from zoo.orca.data.utils import to_sample
    from zoo.orca.learn.metrics import Metrics
    from zoo.orca.learn.trigger import Trigger

    end_trigger = MaxEpoch(epochs)
    assert batch_size > 0, "batch_size should be greater than 0"
    validation_metrics = Metrics.convert_metrics_list(validation_metrics)
    checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

    if self.log_dir is not None and self.app_name is not None:
        self.estimator.set_tensorboard(self.log_dir, self.app_name)

    if isinstance(data, SparkXShards):
        train_rdd = data.rdd.flatMap(to_sample)
        train_feature_set = FeatureSet.sample_rdd(train_rdd)
        if validation_data is None:
            val_feature_set = None
        else:
            assert isinstance(validation_data, SparkXShards), \
                "validation_data should be a SparkXShards"
            val_feature_set = FeatureSet.sample_rdd(validation_data.rdd.flatMap(to_sample))
        self.estimator.train(train_feature_set, self.loss, end_trigger, checkpoint_trigger,
                             val_feature_set, validation_metrics, batch_size)
    elif isinstance(data, DataLoader) or callable(data):
        train_feature_set = FeatureSet.pytorch_dataloader(data, "", "")
        if validation_data is None:
            val_feature_set = None
        else:
            assert isinstance(validation_data, DataLoader) or callable(validation_data), \
                "validation_data should be a pytorch DataLoader or a callable data_creator"
            val_feature_set = FeatureSet.pytorch_dataloader(validation_data)
        self.estimator.train_minibatch(train_feature_set, self.loss, end_trigger,
                                       checkpoint_trigger, val_feature_set,
                                       validation_metrics)
    else:
        raise ValueError("Data and validation data should be SparkXShards, DataLoaders or "
                         "callable data_creators, but got " + data.__class__.__name__)
    return self
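# A hedged training sketch: `data` may be a SparkXShards, a torch DataLoader,
# or a callable returning one; `train_loader` and `val_loader` below are
# user-supplied DataLoaders:
#
#   from zoo.orca.learn.metrics import Accuracy
#   from zoo.orca.learn.trigger import EveryEpoch
#
#   est.fit(train_loader, epochs=2, batch_size=32,
#           validation_data=val_loader, validation_metrics=[Accuracy()],
#           checkpoint_trigger=EveryEpoch())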
def evaluate(self, data, validation_methods=None, batch_size=32):
    assert data is not None, "validation data shouldn't be None"

    if isinstance(data, DataFrame):
        raise NotImplementedError
    elif isinstance(data, SparkXShards):
        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics

        validation_methods = Metrics.convert_metrics_list(validation_methods)
        val_feature_set = FeatureSet.sample_rdd(data.rdd.flatMap(to_sample))
        return self.estimator.evaluate(val_feature_set, validation_methods, batch_size)
    else:
        raise ValueError("Data should be XShards or Spark DataFrame, but got "
                         + data.__class__.__name__)
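# A hedged evaluation sketch: only SparkXShards input is supported here, since
# the DataFrame branch raises NotImplementedError; `shards` is assumed to be a
# SparkXShards:
#
#   from zoo.orca.learn.metrics import MAE
#
#   result = est.evaluate(shards, validation_methods=[MAE()], batch_size=64)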
def fit(self, data, epochs, feature_cols="features", labels_cols="label", batch_size=32,
        caching_sample=True, val_data=None, val_trigger=None, val_methods=None,
        checkpoint_trigger=None):
    from zoo.orca.learn.metrics import Metrics
    from zoo.orca.learn.trigger import Trigger

    assert batch_size > 0, "batch_size should be greater than 0"

    if isinstance(data, DataFrame):
        # Combine list-valued column names into single assembled columns.
        if isinstance(feature_cols, list):
            data, val_data, feature_cols = \
                BigDLEstimatorWrapper._combine_cols(data, feature_cols, col_name="features",
                                                    val_data=val_data)
        if isinstance(labels_cols, list):
            data, val_data, labels_cols = \
                BigDLEstimatorWrapper._combine_cols(data, labels_cols, col_name="label",
                                                    val_data=val_data)
        self.nn_estimator.setBatchSize(batch_size).setMaxEpoch(epochs) \
            .setCachingSample(caching_sample).setFeaturesCol(feature_cols) \
            .setLabelCol(labels_cols)
        if val_data is not None:
            assert isinstance(val_data, DataFrame), "val_data should be a spark DataFrame."
            assert val_trigger is not None and val_methods is not None, \
                "You should provide val_trigger and val_methods if you provide val_data."
            val_trigger = Trigger.convert_trigger(val_trigger)
            val_methods = Metrics.convert_metrics_list(val_methods)
            self.nn_estimator.setValidation(val_trigger, val_data, val_methods, batch_size)
        if self.log_dir is not None and self.app_name is not None:
            from bigdl.optim.optimizer import TrainSummary, ValidationSummary
            train_summary = TrainSummary(log_dir=self.log_dir, app_name=self.app_name)
            self.nn_estimator.setTrainSummary(train_summary)
            val_summary = ValidationSummary(log_dir=self.log_dir, app_name=self.app_name)
            self.nn_estimator.setValidationSummary(val_summary)
        if self.model_dir is not None and checkpoint_trigger is not None:
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)
            self.nn_estimator.setCheckpoint(self.model_dir, checkpoint_trigger)
        self.nn_model = self.nn_estimator.fit(data)
        self.is_nnframe_fit = True
    elif isinstance(data, SparkXShards):
        from zoo.orca.data.utils import to_sample

        end_trigger = MaxEpoch(epochs)
        val_methods = Metrics.convert_metrics_list(val_methods)
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)
        train_rdd = data.rdd.flatMap(to_sample)
        train_feature_set = FeatureSet.sample_rdd(train_rdd)
        if val_data is None:
            val_feature_set = None
        else:
            assert isinstance(val_data, SparkXShards), "val_data should be a SparkXShards"
            val_feature_set = FeatureSet.sample_rdd(val_data.rdd.flatMap(to_sample))
        if self.log_dir is not None and self.app_name is not None:
            self.estimator.set_tensorboard(self.log_dir, self.app_name)
        self.estimator.train(train_feature_set, self.loss, end_trigger, checkpoint_trigger,
                             val_feature_set, val_methods, batch_size)
        self.is_nnframe_fit = False
    else:
        raise ValueError("Data should be XShards or Spark DataFrame, but got "
                         + data.__class__.__name__)
    return self
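# A hedged training sketch: with a Spark DataFrame, list-valued column names
# are combined into single assembled columns; `df` and `val_df` are assumed
# DataFrames with "f1", "f2" and "label" columns:
#
#   from zoo.orca.learn.metrics import MAE
#   from zoo.orca.learn.trigger import EveryEpoch
#
#   est.fit(df, epochs=5, feature_cols=["f1", "f2"], labels_cols="label",
#           val_data=val_df, val_trigger=EveryEpoch(), val_methods=[MAE()],
#           checkpoint_trigger=EveryEpoch())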