Пример #1
0
    def load_orca_checkpoint(self, path, version=None, prefix=None):
        """
        Load an existing BigDL checkpoint and rebuild the estimators around it.

        To load a specific checkpoint, please provide both `version` and `prefix`. If `version`
        is None, then the latest checkpoint under the specified directory will be loaded.

        :param path: Path to the existing checkpoint (or directory containing Orca checkpoint
               files).
        :param version: checkpoint version, which is the suffix of model.* file, i.e., for
               model.4 file, the version is 4. If it is None, then load the latest checkpoint.
        :param prefix: optimMethod prefix, for example 'optimMethod-Sequentialf53bddcc'.
               Required when `version` is given; when `version` is None it is replaced by the
               prefix of the latest checkpoint found.
        :return: None. `self.model`, `self.optimizer`, `self.estimator`, `self.nn_estimator`
                 and `self.nn_model` are replaced on success.
        :raises ValueError: if no checkpoint can be found, if loading fails, or if the loaded
                model is not a Container.
        :raises AssertionError: if `version` is given without `prefix`.
        """
        from bigdl.dllib.nn.layer import Model, Container
        from bigdl.dllib.optim.optimizer import OptimMethod
        from bigdl.orca.learn.utils import find_latest_checkpoint
        import os

        if version is None:
            path, prefix, version = find_latest_checkpoint(path,
                                                           model_type="bigdl")
            if path is None:
                raise ValueError(
                    "Cannot find BigDL checkpoint, please check your checkpoint"
                    " path.")
        elif prefix is None:
            # Raised explicitly rather than via `assert` so the check is not stripped under
            # `python -O`; AssertionError is kept so callers see the same exception type.
            raise AssertionError(
                "You should provide optimMethod prefix, "
                "for example 'optimMethod-Sequentialf53bddcc'")

        # Load into locals first so that a failed load does not leave this object with a
        # new model paired with the old optimizer/estimators.
        try:
            model = Model.load(
                os.path.join(path, "model.{}".format(version)))
            if not isinstance(model, Container):
                raise TypeError(
                    "The loaded model should be a Container, please check your "
                    "checkpoint type.")
            optimizer = OptimMethod.load(
                os.path.join(path, "{}.{}".format(prefix, version)))
        except Exception as e:
            # Keep the ValueError contract but surface the root cause in the message and
            # in the exception chain.
            raise ValueError(
                "Cannot load BigDL checkpoint, please check your checkpoint path "
                "and checkpoint type. " + str(e)) from e

        self.model = model
        self.optimizer = optimizer
        self.estimator = SparkEstimator(self.model, self.optimizer,
                                        self.model_dir)
        self.nn_estimator = NNEstimator(self.model, self.loss,
                                        self.feature_preprocessing,
                                        self.label_preprocessing)
        if self.optimizer is not None:
            self.nn_estimator.setOptimMethod(self.optimizer)
        self.nn_model = NNModel(
            self.model, feature_preprocessing=self.feature_preprocessing)
    def load_orca_checkpoint(self, path, version=None, prefix=None):
        """
        Load an existing PyTorch checkpoint and rebuild the Spark estimator around it.

        To load a specific checkpoint, please provide both `version` and `prefix`. If `version`
        is None, then the latest checkpoint will be loaded.

        :param path: Path to the existing checkpoint (or directory containing Orca checkpoint
               files).
        :param version: checkpoint version, which is the suffix of model.* file, i.e., for
               model.4 file, the version is 4. If it is None, then load the latest checkpoint.
        :param prefix: optimMethod prefix, for example 'optimMethod-TorchModelf53bddcc'.
               Required when `version` is given; when `version` is None it is replaced by the
               prefix of the latest checkpoint found.
        :return: None. `self.model`, `self.optimizer` and `self.estimator` are replaced on
                 success.
        :raises ValueError: if no checkpoint can be found or if loading fails.
        :raises AssertionError: if `version` is given without `prefix`.
        """
        import os
        from bigdl.dllib.nn.layer import Model
        from bigdl.dllib.optim.optimizer import OptimMethod
        from bigdl.orca.learn.utils import find_latest_checkpoint
        from bigdl.orca.torch import TorchModel

        if version is None:
            path, prefix, version = find_latest_checkpoint(
                path, model_type="pytorch")
            if path is None:
                raise ValueError(
                    "Cannot find PyTorch checkpoint, please check your checkpoint"
                    " path.")
        elif prefix is None:
            # Raised explicitly rather than via `assert` so the check is not stripped under
            # `python -O`; AssertionError is kept so callers see the same exception type.
            raise AssertionError(
                "You should provide optimMethod prefix, "
                "for example 'optimMethod-TorchModelf53bddcc'")

        # Load into locals first so that a failed optimizer load does not leave this object
        # with a new model paired with the old optimizer/estimator.
        try:
            loaded_model = Model.load(
                os.path.join(path, "model.{}".format(version)))
            model = TorchModel.from_value(loaded_model.value)
            optimizer = OptimMethod.load(
                os.path.join(path, "{}.{}".format(prefix, version)))
        except Exception as e:
            # The separator before the cause was missing, which glued the underlying error
            # text onto the end of the sentence.
            raise ValueError(
                "Cannot load PyTorch checkpoint, please check your checkpoint path "
                "and checkpoint type. " + str(e)) from e

        self.model = model
        self.optimizer = optimizer
        self.estimator = SparkEstimator(self.model, self.optimizer,
                                        self.model_dir)
Пример #3
0
    def test_model_save_and_load(self):
        """
        Round-trip a small TorchModel through save/load and exercise the reloaded copy.

        This is a smoke test: it makes no assertions and passes as long as none of the
        save, load, forward and to_pytorch calls raises.
        """
        class SimpleTorchModel(nn.Module):
            # Two linear layers (2 -> 4 -> 1) with a sigmoid applied to the final output.
            def __init__(self):
                super(SimpleTorchModel, self).__init__()
                self.dense1 = nn.Linear(2, 4)
                self.dense2 = nn.Linear(4, 1)

            def forward(self, x):
                hidden = self.dense1(x)
                return torch.sigmoid(self.dense2(hidden))

        az_model = TorchModel.from_pytorch(SimpleTorchModel())

        with tempfile.TemporaryDirectory() as work_dir:
            model_file = "{}/model.obj".format(work_dir)
            # NOTE(review): the second positional argument is presumably the overwrite
            # flag of `save` -- confirm against the TorchModel/Layer API.
            az_model.save(model_file, True)

            # Rebuild a TorchModel from the value of the generically loaded model.
            restored = TorchModel.from_value(Model.load(model_file).value)

            # A batch of 16 two-feature rows, passed to forward as a numpy array.
            batch = torch.ones(16, 2).numpy()
            restored.forward(batch)
            restored.to_pytorch()