Example #1
0
 def __init__(self,
              model,
              loss,
              optimizer,
              metrics=None,
              model_dir=None,
              bigdl_type="float"):
     """Wrap a PyTorch model, loss and optimizer and build a SparkEstimator.

     :param model: object handed to ``TorchModel.from_pytorch``.
     :param loss: object handed to ``TorchLoss.from_pytorch``; when ``None``
            a default-constructed ``TorchLoss()`` is used instead.
     :param optimizer: a ``TorchOptimizer`` or ``OrcaOptimizer`` instance;
            when ``None`` an ``SGD`` with the ``Default`` learning-rate
            schedule is created.
     :param metrics: passed through ``Metric.convert_metrics_list``.
     :param model_dir: stored on the instance and forwarded to
            ``SparkEstimator``.
     :param bigdl_type: forwarded to ``SparkEstimator`` as ``bigdl_type``.
     :raises ValueError: if the optimizer is neither a ``TorchOptimizer``
            nor an ``OrcaOptimizer``.
     """
     from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim

     # A missing loss falls back to the no-argument TorchLoss.
     self.loss = (TorchLoss() if loss is None
                  else TorchLoss.from_pytorch(loss))

     if optimizer is None:
         from zoo.orca.learn.optimizers.schedule import Default
         # NOTE(review): SGD comes from module scope; it has to satisfy one
         # of the isinstance checks below or the default itself is rejected.
         optimizer = SGD(learningrate_schedule=Default())

     # Reject unsupported optimizer types before doing any conversion.
     if not isinstance(optimizer, (TorchOptimizer, OrcaOptimizer)):
         raise ValueError(
             "Only PyTorch optimizer and orca optimizer are supported")
     if isinstance(optimizer, TorchOptimizer):
         converted_optim = TorchOptim.from_pytorch(optimizer)
     else:
         converted_optim = optimizer.get_optimizer()

     from zoo.orca.learn.metrics import Metric
     self.metrics = Metric.convert_metrics_list(metrics)

     self.model_dir = model_dir
     self.log_dir = None
     self.app_name = None

     self.model = TorchModel.from_pytorch(model)
     self.estimator = SparkEstimator(
         self.model, converted_optim, model_dir, bigdl_type=bigdl_type)
Example #2
0
 def __init__(self,
              *,
              model,
              loss,
              optimizer=None,
              metrics=None,
              feature_preprocessing=None,
              label_preprocessing=None,
              model_dir=None):
     """Store the training components and build the wrapped estimators.

     All arguments are keyword-only.

     :param model: model handed to ``NNModel``, ``NNEstimator`` and
            ``SparkEstimator``.
     :param loss: loss handed to ``NNEstimator``.
     :param optimizer: optimization method; when ``None`` a
            default-constructed ``bigdl.optim.optimizer.SGD`` is used.
     :param metrics: passed through ``Metric.convert_metrics_list``.
     :param feature_preprocessing: forwarded to ``NNModel`` and
            ``NNEstimator``.
     :param label_preprocessing: forwarded to ``NNEstimator``.
     :param model_dir: stored on the instance and forwarded to
            ``SparkEstimator``.
     """
     # Plain configuration captured on the instance.
     self.loss = loss
     self.optimizer = optimizer
     self.metrics = Metric.convert_metrics_list(metrics)
     self.feature_preprocessing = feature_preprocessing
     self.label_preprocessing = label_preprocessing
     self.model_dir = model_dir
     self.model = model

     # NNFrames-side wrappers around the same model.
     nn_model = NNModel(
         self.model,
         feature_preprocessing=self.feature_preprocessing)
     self.nn_model = nn_model
     nn_estimator = NNEstimator(
         self.model,
         self.loss,
         self.feature_preprocessing,
         self.label_preprocessing)
     self.nn_estimator = nn_estimator

     # Fall back to a default SGD only when the caller supplied nothing.
     if self.optimizer is None:
         from bigdl.optim.optimizer import SGD
         self.optimizer = SGD()
     nn_estimator.setOptimMethod(self.optimizer)

     self.estimator = SparkEstimator(
         self.model, self.optimizer, self.model_dir)

     # Run-time bookkeeping; nothing is configured at construction time.
     self.log_dir = None
     self.app_name = None
     self.is_nnframe_fit = False