コード例 #1
0
 def __init__(self,
              model,
              loss,
              optimizer,
              metrics=None,
              model_dir=None,
              bigdl_type="float"):
     """Wrap a PyTorch model, loss and optimizer in a ``SparkEstimator``.

     :param model: PyTorch model, converted with ``TorchModel.from_pytorch``.
     :param loss: PyTorch loss converted with ``TorchLoss.from_pytorch``, or
            None to fall back to a default-constructed ``TorchLoss``.
     :param optimizer: a ``TorchOptimizer`` (converted with
            ``TorchOptim.from_pytorch``), an ``OrcaOptimizer`` (unwrapped via
            ``get_optimizer()``), or None for ``SGD`` with a ``Default``
            learning-rate schedule.
     :param metrics: metrics normalised by ``Metrics.convert_metrics_list``.
     :param model_dir: stored on the instance and forwarded to
            ``SparkEstimator``.
     :param bigdl_type: forwarded to ``SparkEstimator``.
     :raises ValueError: if ``optimizer`` is of any other type.
     """
     from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim

     self.loss = TorchLoss() if loss is None else TorchLoss.from_pytorch(loss)

     if optimizer is None:
         from zoo.orca.learn.optimizers.schedule import Default
         optimizer = SGD(learningrate_schedule=Default())

     # Reduce the accepted optimizer flavours to the single object that
     # SparkEstimator is constructed with.
     if isinstance(optimizer, TorchOptimizer):
         optim_method = TorchOptim.from_pytorch(optimizer)
     elif isinstance(optimizer, OrcaOptimizer):
         optim_method = optimizer.get_optimizer()
     else:
         raise ValueError(
             "Only PyTorch optimizer and orca optimizer are supported")

     from zoo.orca.learn.metrics import Metrics
     self.metrics = Metrics.convert_metrics_list(metrics)

     # TensorBoard location; left unset here (presumably configured later
     # through a setter not visible in this snippet).
     self.log_dir = None
     self.app_name = None
     self.model_dir = model_dir
     self.model = TorchModel.from_pytorch(model)
     self.estimator = SparkEstimator(self.model,
                                     optim_method,
                                     model_dir,
                                     bigdl_type=bigdl_type)
コード例 #2
0
ファイル: estimator.py プロジェクト: gjlee0802/analytics-zoo
 def __init__(self,
              *,
              model,
              loss,
              optimizer=None,
              metrics=None,
              feature_preprocessing=None,
              label_preprocessing=None,
              model_dir=None):
     """Wrap a BigDL model in an ``NNModel``/``NNEstimator`` pair and a
     ``SparkEstimator`` that share the same model and optimizer.

     All arguments are keyword-only.

     :param model: BigDL model.
     :param loss: loss (criterion) handed to ``NNEstimator``.
     :param optimizer: optim method; None defaults to
            ``bigdl.optim.optimizer.SGD()``.
     :param metrics: metrics normalised by ``Metrics.convert_metrics_list``.
     :param feature_preprocessing: forwarded to ``NNModel`` and
            ``NNEstimator``.
     :param label_preprocessing: forwarded to ``NNEstimator``.
     :param model_dir: stored and forwarded to ``SparkEstimator``.
     """
     self.loss = loss
     self.optimizer = optimizer
     self.metrics = Metrics.convert_metrics_list(metrics)
     self.feature_preprocessing = feature_preprocessing
     self.label_preprocessing = label_preprocessing
     self.model_dir = model_dir
     self.model = model

     self.nn_model = NNModel(model,
                             feature_preprocessing=feature_preprocessing)
     self.nn_estimator = NNEstimator(model, loss, feature_preprocessing,
                                     label_preprocessing)

     # The default optimizer is created only after the NNFrames objects
     # exist, and is then shared by both estimators.
     if optimizer is None:
         from bigdl.optim.optimizer import SGD
         self.optimizer = SGD()
     self.nn_estimator.setOptimMethod(self.optimizer)
     self.estimator = SparkEstimator(model, self.optimizer, model_dir)

     # TensorBoard location (unset) and which backend performed the last fit.
     self.log_dir = None
     self.app_name = None
     self.is_nnframe_fit = False
コード例 #3
0
ファイル: estimator.py プロジェクト: hboshnak/analytics-zoo
    def evaluate(self,
                 data,
                 batch_size=32,
                 feature_cols=None,
                 labels_cols=None,
                 validation_metrics=None):
        """Evaluate the wrapped model on ``data``.

        :param data: a ``SparkXShards``, a PyTorch ``DataLoader`` or a
               callable data_creator.
        :param batch_size: batch size; only used for ``SparkXShards`` input.
        :param feature_cols: accepted for interface compatibility; unused.
        :param labels_cols: accepted for interface compatibility; unused.
        :param validation_metrics: metrics normalised by
               ``Metrics.convert_metrics_list``.
        :return: the result of ``self.estimator.evaluate`` (XShards) or
                 ``self.estimator.evaluate_minibatch`` (DataLoader/creator).
        :raises AssertionError: if ``data`` is None.
        :raises ValueError: if ``data`` is of any other type.
        """
        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics

        assert data is not None, "validation data shouldn't be None"
        metrics = Metrics.convert_metrics_list(validation_metrics)

        if isinstance(data, SparkXShards):
            samples = data.rdd.flatMap(to_sample)
            return self.estimator.evaluate(FeatureSet.sample_rdd(samples),
                                           metrics, batch_size)

        if not (isinstance(data, DataLoader) or callable(data)):
            raise ValueError(
                "Data should be a SparkXShards, a DataLoader or a callable "
                "data_creator, but get " + data.__class__.__name__)

        # DataLoader / data_creator: batching is decided by the loader itself.
        feature_set = FeatureSet.pytorch_dataloader(data)
        return self.estimator.evaluate_minibatch(feature_set, metrics)
コード例 #4
0
ファイル: estimator.py プロジェクト: hboshnak/analytics-zoo
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            labels_cols=None,
            validation_data=None,
            validation_metrics=None,
            checkpoint_trigger=None):
        """Train the wrapped model for ``epochs`` epochs.

        :param data: a ``SparkXShards``, a PyTorch ``DataLoader`` or a
               callable data_creator.
        :param epochs: number of epochs; wrapped in ``MaxEpoch``.
        :param batch_size: positive batch size; only forwarded for
               ``SparkXShards`` input.
        :param feature_cols: accepted for interface compatibility; unused.
        :param labels_cols: accepted for interface compatibility; unused.
        :param validation_data: optional validation data. Must be a
               ``SparkXShards`` when ``data`` is one, otherwise a
               ``DataLoader`` or a callable data_creator.
        :param validation_metrics: metrics normalised by
               ``Metrics.convert_metrics_list``.
        :param checkpoint_trigger: trigger normalised by
               ``Trigger.convert_trigger``.
        :return: ``self``.
        :raises AssertionError: on a non-positive ``batch_size`` or a
                ``validation_data`` of the wrong kind.
        :raises ValueError: if ``data`` is of an unsupported type.
        """
        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics
        from zoo.orca.learn.trigger import Trigger

        end_trigger = MaxEpoch(epochs)
        assert batch_size > 0, "batch_size should be greater than 0"
        validation_metrics = Metrics.convert_metrics_list(validation_metrics)
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        if self.log_dir is not None and self.app_name is not None:
            self.estimator.set_tensorboard(self.log_dir, self.app_name)

        if isinstance(data, SparkXShards):
            train_rdd = data.rdd.flatMap(to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, SparkXShards), "validation_data should be a " \
                                                                  "SparkXShards"
                val_feature_set = FeatureSet.sample_rdd(
                    validation_data.rdd.flatMap(to_sample))

            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set,
                                 validation_metrics, batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            # Training passes empty feature/label strings to
            # pytorch_dataloader, while validation below uses its defaults.
            train_feature_set = FeatureSet.pytorch_dataloader(data, "", "")
            if validation_data is None:
                val_feature_set = None
            else:
                # Check validation_data itself. Testing callable(data) here
                # would accept any validation_data whenever data is a
                # data_creator.
                assert isinstance(validation_data, DataLoader) or callable(validation_data), \
                    "validation_data should be a pytorch DataLoader or a callable data_creator"
                val_feature_set = FeatureSet.pytorch_dataloader(
                    validation_data)

            self.estimator.train_minibatch(train_feature_set, self.loss,
                                           end_trigger, checkpoint_trigger,
                                           val_feature_set, validation_metrics)
        else:
            raise ValueError(
                "Data and validation data should be SparkXShards, DataLoaders or "
                "callable data_creators but get " + data.__class__.__name__)
        return self
コード例 #5
0
ファイル: estimator.py プロジェクト: DingHe/analytics-zoo
    def evaluate(self, data, validation_methods=None, batch_size=32):
        """Evaluate the wrapped BigDL model on ``data``.

        :param data: a ``SparkXShards``. Spark ``DataFrame`` input is
               recognised but not implemented.
        :param validation_methods: metrics normalised by
               ``Metrics.convert_metrics_list``.
        :param batch_size: batch size forwarded to ``self.estimator.evaluate``.
        :return: the result of ``self.estimator.evaluate``.
        :raises AssertionError: if ``data`` is None.
        :raises NotImplementedError: for a Spark ``DataFrame``.
        :raises ValueError: for any other type of ``data``.
        """
        assert data is not None, "validation data shouldn't be None"

        if isinstance(data, DataFrame):
            raise NotImplementedError
        if not isinstance(data, SparkXShards):
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)

        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics

        metrics = Metrics.convert_metrics_list(validation_methods)
        samples = data.rdd.flatMap(to_sample)
        return self.estimator.evaluate(FeatureSet.sample_rdd(samples),
                                       metrics, batch_size)
コード例 #6
0
ファイル: estimator.py プロジェクト: DingHe/analytics-zoo
    def fit(self,
            data,
            epochs,
            feature_cols="features",
            labels_cols="label",
            batch_size=32,
            caching_sample=True,
            val_data=None,
            val_trigger=None,
            val_methods=None,
            checkpoint_trigger=None):
        """Train the wrapped BigDL model.

        A Spark ``DataFrame`` is trained through ``self.nn_estimator``
        (setting ``self.nn_model`` and ``is_nnframe_fit = True``). A
        ``SparkXShards`` is trained through ``self.estimator``
        (``is_nnframe_fit = False``).

        :param data: a Spark ``DataFrame`` or a ``SparkXShards``.
        :param epochs: number of epochs to train.
        :param feature_cols: DataFrame feature column, or a list of columns
               that ``_combine_cols`` merges with ``col_name="features"``.
        :param labels_cols: DataFrame label column, or a list of columns that
               ``_combine_cols`` merges with ``col_name="label"``.
        :param batch_size: positive batch size.
        :param caching_sample: forwarded to ``setCachingSample`` (DataFrame
               input only).
        :param val_data: optional validation data of the same kind as
               ``data``.
        :param val_trigger: validation trigger; required with DataFrame
               ``val_data``. Not referenced for XShards input.
        :param val_methods: validation metrics; required with DataFrame
               ``val_data``.
        :param checkpoint_trigger: checkpoint trigger. For DataFrame input it
               is applied only when ``self.model_dir`` is set.
        :return: ``self``.
        :raises AssertionError: on a non-positive ``batch_size``, a
                ``val_data`` of the wrong kind, or DataFrame ``val_data``
                given without ``val_trigger``/``val_methods``.
        :raises ValueError: if ``data`` is neither a DataFrame nor XShards.
        """
        from zoo.orca.learn.metrics import Metrics
        from zoo.orca.learn.trigger import Trigger

        assert batch_size > 0, "batch_size should be greater than 0"

        if isinstance(data, DataFrame):
            if isinstance(feature_cols, list):
                data, val_data, feature_cols = \
                    BigDLEstimatorWrapper._combine_cols(data, feature_cols, col_name="features",
                                                        val_data=val_data)

            if isinstance(labels_cols, list):
                data, val_data, labels_cols = \
                    BigDLEstimatorWrapper._combine_cols(data, labels_cols, col_name="label",
                                                        val_data=val_data)

            self.nn_estimator.setBatchSize(batch_size).setMaxEpoch(epochs)\
                .setCachingSample(caching_sample).setFeaturesCol(feature_cols)\
                .setLabelCol(labels_cols)

            if val_data is not None:
                assert isinstance(
                    val_data,
                    DataFrame), "val_data should be a spark DataFrame."
                assert val_trigger is not None and val_methods is not None, \
                    "You should provide val_trigger and val_methods if you provide val_data."
                val_trigger = Trigger.convert_trigger(val_trigger)
                val_methods = Metrics.convert_metrics_list(val_methods)
                self.nn_estimator.setValidation(val_trigger, val_data,
                                                val_methods, batch_size)
            if self.log_dir is not None and self.app_name is not None:
                from bigdl.optim.optimizer import TrainSummary
                from bigdl.optim.optimizer import ValidationSummary
                train_summary = TrainSummary(log_dir=self.log_dir,
                                             app_name=self.app_name)
                self.nn_estimator.setTrainSummary(train_summary)
                # Both summaries use the configured app_name, so train and
                # validation logs go under the same application name.
                val_summary = ValidationSummary(log_dir=self.log_dir,
                                                app_name=self.app_name)
                self.nn_estimator.setValidationSummary(val_summary)
            if self.model_dir is not None and checkpoint_trigger is not None:
                checkpoint_trigger = Trigger.convert_trigger(
                    checkpoint_trigger)
                self.nn_estimator.setCheckpoint(self.model_dir,
                                                checkpoint_trigger)

            self.nn_model = self.nn_estimator.fit(data)
            self.is_nnframe_fit = True
        elif isinstance(data, SparkXShards):
            from zoo.orca.data.utils import to_sample

            end_trigger = MaxEpoch(epochs)
            val_methods = Metrics.convert_metrics_list(val_methods)
            checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

            train_rdd = data.rdd.flatMap(to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if val_data is None:
                val_feature_set = None
            else:
                assert isinstance(
                    val_data, SparkXShards), "val_data should be a XShards"
                val_feature_set = FeatureSet.sample_rdd(
                    val_data.rdd.flatMap(to_sample))
            if self.log_dir is not None and self.app_name is not None:
                self.estimator.set_tensorboard(self.log_dir, self.app_name)
            # NOTE(review): val_trigger is not passed on this path.
            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set,
                                 val_methods, batch_size)
            self.is_nnframe_fit = False
        else:
            raise ValueError(
                "Data should be XShards or Spark DataFrame, but get " +
                data.__class__.__name__)
        return self