コード例 #1
0
 def __init__(self,
              model,
              loss,
              optimizer,
              metrics=None,
              model_dir=None,
              bigdl_type="float"):
     """Build a Spark estimator around a PyTorch model, loss and optimizer.

     Args:
         model: a PyTorch model, converted via ``TorchModel.from_pytorch``.
         loss: a PyTorch criterion, or ``None`` for a bare ``TorchLoss``.
         optimizer: a PyTorch optimizer or an orca optimizer; ``None``
             selects SGD with the default learning-rate schedule.
         metrics: metric spec(s) passed to ``Metric.convert_metrics_list``.
         model_dir: directory handed to the underlying ``SparkEstimator``.
         bigdl_type: numeric type tag for the BigDL backend.

     Raises:
         ValueError: if ``optimizer`` is neither a PyTorch optimizer nor an
             orca optimizer.
     """
     from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim
     # Wrap the criterion: a bare TorchLoss when none is supplied,
     # otherwise convert the given PyTorch loss.
     self.loss = TorchLoss() if loss is None else TorchLoss.from_pytorch(loss)
     # Default to SGD with the default schedule when no optimizer is given.
     if optimizer is None:
         from zoo.orca.learn.optimizers.schedule import Default
         optimizer = SGD(learningrate_schedule=Default())
     # Normalize either optimizer flavor to a BigDL optim method.
     if isinstance(optimizer, TorchOptimizer):
         optimizer = TorchOptim.from_pytorch(optimizer)
     elif isinstance(optimizer, OrcaOptimizer):
         optimizer = optimizer.get_optimizer()
     else:
         raise ValueError(
             "Only PyTorch optimizer and orca optimizer are supported")
     from zoo.orca.learn.metrics import Metric
     self.metrics = Metric.convert_metrics_list(metrics)
     self.log_dir = None
     self.app_name = None
     self.model_dir = model_dir
     self.model = TorchModel.from_pytorch(model)
     self.estimator = SparkEstimator(self.model,
                                     optimizer,
                                     model_dir,
                                     bigdl_type=bigdl_type)
コード例 #2
0
 def __init__(self,
              *,
              model,
              loss,
              optimizer=None,
              metrics=None,
              feature_preprocessing=None,
              label_preprocessing=None,
              model_dir=None):
     """Build both an NNEstimator-based and a Spark-based estimator pair.

     Args:
         model: a BigDL model shared by both estimators.
         loss: the training criterion forwarded to ``NNEstimator``.
         optimizer: optim method; ``None`` falls back to plain BigDL SGD.
         metrics: metric spec(s) passed to ``Metric.convert_metrics_list``.
         feature_preprocessing: preprocessing applied to features.
         label_preprocessing: preprocessing applied to labels.
         model_dir: directory handed to the underlying ``SparkEstimator``.
     """
     self.model = model
     self.loss = loss
     self.metrics = Metric.convert_metrics_list(metrics)
     self.feature_preprocessing = feature_preprocessing
     self.label_preprocessing = label_preprocessing
     self.model_dir = model_dir
     # NNModel only needs the feature side of the preprocessing pair.
     self.nn_model = NNModel(
         self.model, feature_preprocessing=self.feature_preprocessing)
     self.nn_estimator = NNEstimator(self.model, self.loss,
                                     self.feature_preprocessing,
                                     self.label_preprocessing)
     # Default to plain SGD when the caller supplies no optimizer.
     if optimizer is None:
         from bigdl.optim.optimizer import SGD
         optimizer = SGD()
     self.optimizer = optimizer
     self.nn_estimator.setOptimMethod(self.optimizer)
     self.estimator = SparkEstimator(self.model, self.optimizer,
                                     self.model_dir)
     self.log_dir = None
     self.app_name = None
     self.is_nnframe_fit = False
コード例 #3
0
    def validate(self, val_iterator, info, metrics):
        """Run one standard evaluation pass over ``val_iterator``.

        The model is switched to eval mode and the entire pass runs under
        ``torch.no_grad()``.

        If overriding this method, the model and criterion are available via
        ``self.model`` and ``self.criterion``, and there is no need to call
        ``validate_batch``.

        Args:
            val_iterator (iter): iterable constructed from the validation
                dataloader.
            info (dict): extra information merged into each batch's info
                dict for custom validation operations.
            metrics: metric spec(s), converted through
                ``Metric.convert_metrics_dict`` with the pytorch backend.

        Returns:
            A dict with one entry per metric name (the value of that
            metric's ``compute()``), plus ``"val_loss"`` — the
            sample-weighted mean of the per-batch losses — and
            ``"num_samples"`` — the total number of validated samples.
        """
        # Evaluation mode: freezes dropout / batch-norm statistics.
        self.model.eval()
        metrics = Metric.convert_metrics_dict(metrics, backend="pytorch")
        weighted_loss_sum = 0
        total_samples = 0
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_iterator):
                batch_info = dict(batch_idx=batch_idx)
                batch_info.update(info)
                output, target, loss = self.forward_batch(batch, batch_info)
                batch_size = target.size(0)
                total_samples += batch_size
                # Weight each batch loss by its size so the final figure is
                # a true per-sample average.
                weighted_loss_sum += loss.item() * batch_size
                for metric in metrics.values():
                    metric(output, target)

        result = {name: m.compute() for name, m in metrics.items()}
        result["val_loss"] = weighted_loss_sum / total_samples
        result["num_samples"] = total_samples
        return result
コード例 #4
0
def validate(config, model, valid_ld, metrics, validate_batches):
    """Evaluate ``model`` on ``validate_batches`` batches from ``valid_ld``.

    The loader is re-wound with a fresh iterator whenever it is exhausted,
    so ``validate_batches`` may exceed ``len(valid_ld)``. The computed
    metric values are printed and returned.

    Args:
        config: unused here; kept for a uniform worker-function signature.
        model: the model under evaluation; called as ``model(x, y)``.
        valid_ld: the validation dataloader.
        metrics: metric spec(s), converted through
            ``Metric.convert_metrics_dict`` with the pytorch backend.
        validate_batches (int): number of batches to consume.

    Returns:
        A dict mapping each metric name to its ``compute()`` result.
    """
    import torch
    from zoo.orca.learn.metrics import Metric

    model.eval()
    metrics = Metric.convert_metrics_dict(metrics, backend="pytorch")
    valid_iter = iter(valid_ld)
    with torch.no_grad():
        for step in range(validate_batches):
            # Restart from the first batch once the loader is exhausted.
            if step > 0 and step % len(valid_ld) == 0:
                valid_iter = iter(valid_ld)
            x, y = next(valid_iter)
            o = model(x, y)
            for metric in metrics.values():
                metric(o, y)
    result = {name: metric.compute() for name, metric in metrics.items()}
    print("Validation results: " + "".join(
        "{}:{} ".format(name, value) for name, value in result.items()))
    return result