Пример #1
0
 def __init__(self,
              model,
              loss,
              optimizer,
              metrics=None,
              model_dir=None,
              bigdl_type="float"):
     """
     Wrap a PyTorch model, loss and optimizer and build the backing SparkEstimator.

     :param model: model handed to ``TorchModel.from_pytorch``.
     :param loss: loss handed to ``TorchLoss.from_pytorch``; when None a default
            ``TorchLoss()`` is created instead.
     :param optimizer: a ``TorchOptimizer`` or ``OrcaOptimizer`` instance; when None an
            ``SGD`` with the ``Default`` learning-rate schedule is used.
     :param metrics: passed through ``Metric.convert_metrics_list``.
     :param model_dir: stored on the instance and forwarded to ``SparkEstimator``.
     :param bigdl_type: forwarded to ``SparkEstimator``.
     :raises ValueError: if ``optimizer`` is neither a PyTorch nor an orca optimizer.
     """
     from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim

     # A missing loss falls back to the default TorchLoss.
     self.loss = TorchLoss() if loss is None else TorchLoss.from_pytorch(loss)

     if optimizer is None:
         from zoo.orca.learn.optimizers.schedule import Default
         optimizer = SGD(learningrate_schedule=Default())

     # Normalise whichever optimizer flavour was supplied into the object
     # that SparkEstimator receives.
     if isinstance(optimizer, TorchOptimizer):
         converted_optim = TorchOptim.from_pytorch(optimizer)
     elif isinstance(optimizer, OrcaOptimizer):
         converted_optim = optimizer.get_optimizer()
     else:
         raise ValueError(
             "Only PyTorch optimizer and orca optimizer are supported")

     from zoo.orca.learn.metrics import Metric
     self.metrics = Metric.convert_metrics_list(metrics)

     self.log_dir = None
     self.app_name = None
     self.model_dir = model_dir

     torch_model = TorchModel.from_pytorch(model)
     self.model = torch_model
     self.estimator = SparkEstimator(torch_model,
                                     converted_optim,
                                     model_dir,
                                     bigdl_type=bigdl_type)
Пример #2
0
    def load_orca_checkpoint(self, path, version, prefix=None):
        """
        Load an existing checkpoint and rebuild the underlying estimator from it.

        The instance is only modified once both the model and the optimizer have
        been loaded successfully.

        :param path: Path to the existing checkpoint.
        :param version: checkpoint version, which is the suffix of model.* file,
               i.e., for model.4 file, the version is 4.
        :param prefix: optimMethod prefix, for example 'optimMethod-TorchModelf53bddcc'.
               Required, despite the ``None`` default.
        :return: None
        :raises AssertionError: if ``prefix`` is None.
        :raises ValueError: if the model or the optimizer file cannot be loaded.
        """
        import os
        from bigdl.nn.layer import Model
        from bigdl.optim.optimizer import OptimMethod
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is kept for existing callers.
        if prefix is None:
            raise AssertionError("You should provide optimMethod prefix, "
                                 "for example 'optimMethod-TorchModelf53bddcc'")
        try:
            # Load into locals first: if the optimizer fails to load, self.model
            # must not be left pointing at a model the estimator does not wrap.
            model = Model.load(
                os.path.join(path, "model.{}".format(version)))
            optimizer = OptimMethod.load(
                os.path.join(path, "{}.{}".format(prefix, version)))
        except Exception as err:
            raise ValueError(
                "Cannot load PyTorch checkpoint, please check your checkpoint path "
                "and checkpoint type.") from err
        self.model = model
        self.estimator = SparkEstimator(self.model, optimizer, self.model_dir)
Пример #3
0
    def load(self, model_path):
        """
        Load the Estimator state (model and possibly with optimizer) from provided model_path.
        The model file should be generated by the save method of this estimator, or by
        ``torch.save(state_dict, model_path)``, where `state_dict` can be obtained by
        the ``state_dict()`` method of a pytorch model.

        ``self.model``, ``self.optimizer`` and ``self.estimator`` are only replaced
        after everything has been loaded successfully.

        :param model_path: path to the saved model.
        :return: None
        :raises ValueError: if the model state cannot be loaded, or if an optimizer
                file exists next to the model but cannot be loaded.
        """

        from zoo.pipeline.api.torch import TorchModel
        import os

        try:
            pytorch_model = self.get_model()
            # NOTE(review): torch.load unpickles the file, so model_path should
            # only ever point at trusted checkpoints.
            pytorch_model.load_state_dict(torch.load(model_path))
            loaded_model = TorchModel.from_pytorch(pytorch_model)
        except Exception as err:
            raise ValueError(
                "Cannot load the PyTorch model. Please check your model path.") from err

        # The optimizer file is optional; without it the estimator gets None.
        loaded_optimizer = None
        optim_path = self._get_optimizer_path(model_path)
        if os.path.isfile(optim_path):
            try:
                loaded_optimizer = OptimMethod.load(optim_path)
            except Exception as err:
                raise ValueError(
                    "Cannot load the optimizer. Only `bigdl.optim.optimizer."
                    "OptimMethod` is supported for loading.") from err

        # Commit the new state in one place so a failed optimizer load cannot
        # leave a new model paired with the old optimizer/estimator.
        self.model = loaded_model
        self.optimizer = loaded_optimizer
        self.estimator = SparkEstimator(self.model, self.optimizer,
                                        self.model_dir)
Пример #4
0
    def load_orca_checkpoint(self, path, version=None, prefix=None):
        """
        Load an existing checkpoint. To load a specific checkpoint, please provide both
        `version` and `prefix`. If `version` is None, then the latest checkpoint will be loaded.

        The instance is only modified once both the model and the optimizer have
        been loaded successfully.

        :param path: Path to the existing checkpoint (or directory containing Orca checkpoint
               files).
        :param version: checkpoint version, which is the suffix of model.* file, i.e., for
               model.4 file, the version is 4. If it is None, then load the latest checkpoint.
        :param prefix: optimMethod prefix, for example 'optimMethod-TorchModelf53bddcc'.
               Ignored when `version` is None, because it is then taken from the latest
               checkpoint that was found.
        :return: None
        :raises AssertionError: if `version` is given but `prefix` is None.
        :raises ValueError: if no checkpoint can be found, or the model or optimizer
                file cannot be loaded.
        """
        import os
        from bigdl.nn.layer import Model
        from bigdl.optim.optimizer import OptimMethod
        from zoo.orca.learn.utils import find_latest_checkpoint
        from zoo.pipeline.api.torch import TorchModel

        if version is None:
            path, prefix, version = find_latest_checkpoint(path, model_type="pytorch")
            if path is None:
                raise ValueError("Cannot find PyTorch checkpoint, please check your checkpoint"
                                 " path.")
        elif prefix is None:
            # Explicit raise instead of `assert` so the check survives `python -O`;
            # the exception type is kept for existing callers.
            raise AssertionError("You should provide optimMethod prefix, "
                                 "for example 'optimMethod-TorchModelf53bddcc'")

        try:
            # Load into locals first so a failure half-way through does not leave
            # self.model and self.optimizer out of sync with self.estimator.
            loaded_model = Model.load(os.path.join(path, "model.{}".format(version)))
            model = TorchModel.from_value(loaded_model.value)
            optimizer = OptimMethod.load(os.path.join(path, "{}.{}".format(prefix, version)))
        except Exception as err:
            raise ValueError("Cannot load PyTorch checkpoint, please check your checkpoint path "
                             "and checkpoint type.") from err
        self.model = model
        self.optimizer = optimizer
        self.estimator = SparkEstimator(self.model, self.optimizer, self.model_dir)
Пример #5
0
 def __init__(self, model, loss, optimizer, model_dir=None, bigdl_type="float"):
     """
     Wrap a PyTorch model and loss and build the backing SparkEstimator.

     :param model: model handed to ``TorchModel.from_pytorch``.
     :param loss: loss handed to ``TorchLoss.from_pytorch``; when None a default
            ``TorchLoss()`` is created instead.
     :param optimizer: optimizer forwarded to ``SparkEstimator``; when None a
            ``bigdl.optim.optimizer.SGD()`` with default arguments is used.
     :param model_dir: stored on the instance and forwarded to ``SparkEstimator``.
     :param bigdl_type: forwarded to ``SparkEstimator``.
     """
     from zoo.pipeline.api.torch import TorchModel, TorchLoss
     self.loss = TorchLoss() if loss is None else TorchLoss.from_pytorch(loss)
     if optimizer is None:
         from bigdl.optim.optimizer import SGD
         optimizer = SGD()
     # Remember the directory on the instance: the checkpoint-loading methods
     # rebuild the estimator with self.model_dir, which would otherwise be unset.
     self.model_dir = model_dir
     self.model = TorchModel.from_pytorch(model)
     self.estimator = SparkEstimator(self.model, optimizer, model_dir, bigdl_type=bigdl_type)
Пример #6
0
 def load_orca_checkpoint(self, path, version, prefix=None):
     """
     Load an existing checkpoint and rebuild the underlying estimator from it.

     The instance is only modified once both the model and the optimizer have
     been loaded successfully.

     :param path: directory that contains the checkpoint files.
     :param version: suffix of the ``model.*`` file to load, e.g. 4 for ``model.4``.
     :param prefix: optimMethod prefix, for example 'optimMethod-TorchModelf53bddcc'.
            Required, despite the ``None`` default.
     :return: None
     :raises AssertionError: if ``prefix`` is None.
     :raises ValueError: if the model or the optimizer file cannot be loaded.
     """
     import os
     from bigdl.nn.layer import Model
     from bigdl.optim.optimizer import OptimMethod
     # Explicit raise instead of `assert` so the check survives `python -O`;
     # the exception type is kept for existing callers.
     if prefix is None:
         raise AssertionError("You should provide optimMethod prefix, "
                              "for example 'optimMethod-TorchModelf53bddcc'")
     try:
         # Load into locals first: if the optimizer fails to load, self.model
         # must not be left pointing at a model the estimator does not wrap.
         model = Model.load(os.path.join(path, "model.{}".format(version)))
         optimizer = OptimMethod.load(os.path.join(path, "{}.{}".format(prefix, version)))
     except Exception as err:
         raise ValueError("Cannot load PyTorch checkpoint, please check your checkpoint path "
                          "and checkpoint type.") from err
     self.model = model
     self.estimator = SparkEstimator(self.model, optimizer, self.model_dir)
Пример #7
0
 def load(self, checkpoint, loss=None):
     """
     Load the latest PyTorch checkpoint found under ``checkpoint`` and rebuild the
     underlying estimator from it.

     ``self.loss``, ``self.model`` and ``self.estimator`` are only replaced after
     the checkpoint has been loaded successfully.

     :param checkpoint: location passed to ``find_latest_checkpoint``.
     :param loss: optional loss handed to ``TorchLoss.from_pytorch``; when None the
            current ``self.loss`` is left untouched.
     :return: None
     :raises ValueError: if no checkpoint is found, or the model or optimizer file
             cannot be loaded.
     """
     from zoo.orca.learn.utils import find_latest_checkpoint
     from bigdl.nn.layer import Model
     from bigdl.optim.optimizer import OptimMethod
     import os
     new_loss = None
     if loss is not None:
         from zoo.pipeline.api.torch import TorchLoss
         # Converted up front but assigned only after the checkpoint loads, so a
         # failed load does not leave the estimator with a swapped loss.
         new_loss = TorchLoss.from_pytorch(loss)
     path, prefix, version = find_latest_checkpoint(checkpoint,
                                                    model_type="pytorch")
     if path is None:
         raise ValueError(
             "Cannot find PyTorch checkpoint, please check your checkpoint path."
         )
     try:
         # Load into locals first so a failure half-way through leaves the
         # instance exactly as it was.
         model = Model.load(
             os.path.join(path, "model.{}".format(version)))
         optimizer = OptimMethod.load(
             os.path.join(path, "{}.{}".format(prefix, version)))
     except Exception as err:
         raise ValueError(
             "Cannot load PyTorch checkpoint, please check your checkpoint path "
             "and checkpoint type.") from err
     if loss is not None:
         self.loss = new_loss
     self.model = model
     self.estimator = SparkEstimator(self.model, optimizer, self.model_dir)