def __init__(self, model, loss, optimizer, metrics=None, model_dir=None, bigdl_type="float"):
    """Wrap a PyTorch model/loss/optimizer into a BigDL-backed Spark estimator.

    Args:
        model: a PyTorch model to be converted via ``TorchModel.from_pytorch``.
        loss: a PyTorch loss, or None to fall back to an empty ``TorchLoss``.
        optimizer: a PyTorch optimizer, an Orca optimizer, or None (defaults
            to SGD with the Default learning-rate schedule).
        metrics: metric or list of metrics accepted by
            ``Metric.convert_metrics_list``.
        model_dir: directory used by the underlying Spark estimator.
        bigdl_type: numeric precision tag passed to BigDL (default "float").
    """
    from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim

    # A missing loss becomes an empty TorchLoss; otherwise convert from PyTorch.
    self.loss = loss
    if self.loss is None:
        self.loss = TorchLoss()
    else:
        self.loss = TorchLoss.from_pytorch(loss)

    # No optimizer given: default to SGD with the Default schedule.
    if optimizer is None:
        from zoo.orca.learn.optimizers.schedule import Default
        optimizer = SGD(learningrate_schedule=Default())

    # Normalize the optimizer to a BigDL optim method.
    if isinstance(optimizer, TorchOptimizer):
        optimizer = TorchOptim.from_pytorch(optimizer)
    elif isinstance(optimizer, OrcaOptimizer):
        optimizer = optimizer.get_optimizer()
    else:
        raise ValueError(
            "Only PyTorch optimizer and orca optimizer are supported")

    from zoo.orca.learn.metrics import Metric
    self.metrics = Metric.convert_metrics_list(metrics)
    self.log_dir = None
    self.app_name = None
    self.model_dir = model_dir
    self.model = TorchModel.from_pytorch(model)
    self.estimator = SparkEstimator(self.model, optimizer, model_dir,
                                    bigdl_type=bigdl_type)
def __init__(self, *, model, loss, optimizer=None, metrics=None,
             feature_preprocessing=None, label_preprocessing=None,
             model_dir=None):
    """Build an NNFrames-based estimator plus a companion NNModel.

    Args:
        model: the BigDL model to train / run inference with.
        loss: the BigDL criterion used for training.
        optimizer: optim method; None falls back to plain BigDL SGD.
        metrics: metric or list of metrics accepted by
            ``Metric.convert_metrics_list``.
        feature_preprocessing: preprocessing applied to feature columns.
        label_preprocessing: preprocessing applied to label columns.
        model_dir: directory used by the underlying Spark estimator.
    """
    # Plain state captured as-is from the caller.
    self.log_dir = None
    self.app_name = None
    self.is_nnframe_fit = False
    self.loss = loss
    self.optimizer = optimizer
    self.metrics = Metric.convert_metrics_list(metrics)
    self.feature_preprocessing = feature_preprocessing
    self.label_preprocessing = label_preprocessing
    self.model_dir = model_dir
    self.model = model

    # Inference and training wrappers around the raw model.
    self.nn_model = NNModel(
        self.model, feature_preprocessing=self.feature_preprocessing)
    self.nn_estimator = NNEstimator(self.model, self.loss,
                                    self.feature_preprocessing,
                                    self.label_preprocessing)

    # Fall back to BigDL's default SGD when no optim method was supplied.
    if self.optimizer is None:
        from bigdl.optim.optimizer import SGD
        self.optimizer = SGD()
    self.nn_estimator.setOptimMethod(self.optimizer)

    self.estimator = SparkEstimator(self.model, self.optimizer, self.model_dir)
def validate(self, val_iterator, info, metrics):
    """Runs one standard validation pass over the val_iterator.

    This will call ``model.eval()`` and ``torch.no_grad`` when iterating
    over the validation dataloader. If overriding this method, you can
    access model, criterion via ``self.model`` and ``self.criterion``.

    Args:
        val_iterator (iter): Iterable constructed from the validation
            dataloader.
        info: (dict): Dictionary for information to be used for custom
            validation operations.
        metrics: metric name -> metric mapping accepted by
            ``Metric.convert_metrics_dict``.

    Returns:
        A dict of metrics from the evaluation: one entry per converted
        metric (keyed by its name), plus "val_loss" — the per-sample
        weighted mean of the batch losses returned by
        ``self.forward_batch`` — and "num_samples", the total number of
        validated samples.
    """
    # switch to evaluate mode
    self.model.eval()
    metrics = Metric.convert_metrics_dict(metrics, backend="pytorch")
    losses = []
    total_samples = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_iterator):
            batch_info = {"batch_idx": batch_idx}
            batch_info.update(info)
            output, target, loss = self.forward_batch(batch, batch_info)
            num_samples = target.size(0)
            total_samples += num_samples
            # Weight each batch loss by its sample count so the final
            # val_loss is a true per-sample mean across uneven batches.
            losses.append(loss.item() * num_samples)
            for metric in metrics.values():
                metric(output, target)

    result = {name: metric.compute() for name, metric in metrics.items()}
    # Guard against an empty validation loader: report 0.0 instead of
    # raising ZeroDivisionError.
    result["val_loss"] = (sum(losses) / total_samples) if total_samples > 0 else 0.0
    result["num_samples"] = total_samples
    return result
def validate(config, model, valid_ld, metrics, validate_batches):
    """Run ``validate_batches`` validation steps over ``valid_ld``.

    The loader is re-wound from the beginning whenever it runs out of
    batches, so ``validate_batches`` may exceed ``len(valid_ld)``.
    Computed metric values are printed and returned as a dict.
    """
    import torch
    from zoo.orca.learn.metrics import Metric

    model.eval()
    metric_objs = Metric.convert_metrics_dict(metrics, backend="pytorch")
    data_iter = iter(valid_ld)
    with torch.no_grad():
        for step in range(validate_batches):
            # Restart iteration from the beginning once the loader is
            # exhausted.
            if step > 0 and step % len(valid_ld) == 0:
                data_iter = iter(valid_ld)
            features, labels = next(data_iter)
            # NOTE(review): the model is called with both features and
            # labels — presumably its forward consumes the target too;
            # confirm against the model definition.
            preds = model(features, labels)
            for m in metric_objs.values():
                m(preds, labels)

    result = {name: m.compute() for name, m in metric_objs.items()}
    summary = "Validation results: "
    for metric_name, value in result.items():
        summary += "{}:{} ".format(metric_name, value)
    print(summary)
    return result