def load_orca_checkpoint(self, path, version=None, prefix=None):
    """
    Load an existing checkpoint. To load a specific checkpoint, please provide both
    `version` and `prefix`. If `version` is None, then the latest checkpoint under the
    specified directory will be loaded.

    :param path: Path to the existing checkpoint (or directory containing Orca
           checkpoint files).
    :param version: checkpoint version, which is the suffix of the model.* file, i.e.,
           for a model.4 file, the version is 4. If it is None, then load the latest
           checkpoint.
    :param prefix: optimMethod prefix, for example 'optimMethod-Sequentialf53bddcc'.
    :return:
    """
    from bigdl.dllib.nn.layer import Model, Container
    from bigdl.dllib.optim.optimizer import OptimMethod
    from bigdl.orca.learn.utils import find_latest_checkpoint
    import os

    if version is None:
        path, prefix, version = find_latest_checkpoint(path, model_type="bigdl")
        if path is None:
            raise ValueError(
                "Cannot find BigDL checkpoint, please check your checkpoint path.")
    else:
        assert prefix is not None, "You should provide optimMethod prefix, " \
                                   "for example 'optimMethod-Sequentialf53bddcc'"

    try:
        self.model = Model.load(
            os.path.join(path, "model.{}".format(version)))
        assert isinstance(self.model, Container), \
            "The loaded model should be a Container, please check your checkpoint type."
        self.optimizer = OptimMethod.load(
            os.path.join(path, "{}.{}".format(prefix, version)))
    except Exception:
        raise ValueError(
            "Cannot load BigDL checkpoint, please check your checkpoint path "
            "and checkpoint type.")

    self.estimator = SparkEstimator(self.model, self.optimizer, self.model_dir)
    self.nn_estimator = NNEstimator(self.model, self.loss, self.feature_preprocessing,
                                    self.label_preprocessing)
    if self.optimizer is not None:
        self.nn_estimator.setOptimMethod(self.optimizer)
    self.nn_model = NNModel(
        self.model, feature_preprocessing=self.feature_preprocessing)
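# A minimal usage sketch (not from the original source): how load_orca_checkpoint might be
# called on a BigDL-backend Orca Estimator. The estimator construction, checkpoint directory
# and prefix below are illustrative assumptions.
#
#   from bigdl.orca.learn.bigdl.estimator import Estimator
#
#   est = Estimator.from_bigdl(model=model, loss=loss, optimizer=optimizer)
#   est.load_orca_checkpoint("/tmp/orca_ckpt")             # load the latest checkpoint
#   est.load_orca_checkpoint("/tmp/orca_ckpt", version=4,  # or load a specific version
#                            prefix="optimMethod-Sequentialf53bddcc")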
def load_orca_checkpoint(self, path, version=None, prefix=None):
    """
    Load an existing checkpoint. To load a specific checkpoint, please provide both
    `version` and `prefix`. If `version` is None, then the latest checkpoint will be
    loaded.

    :param path: Path to the existing checkpoint (or directory containing Orca
           checkpoint files).
    :param version: checkpoint version, which is the suffix of the model.* file, i.e.,
           for a model.4 file, the version is 4. If it is None, then load the latest
           checkpoint.
    :param prefix: optimMethod prefix, for example 'optimMethod-TorchModelf53bddcc'.
    :return:
    """
    import os
    from bigdl.dllib.nn.layer import Model
    from bigdl.dllib.optim.optimizer import OptimMethod
    from bigdl.orca.learn.utils import find_latest_checkpoint
    from bigdl.orca.torch import TorchModel

    if version is None:
        path, prefix, version = find_latest_checkpoint(path, model_type="pytorch")
        if path is None:
            raise ValueError(
                "Cannot find PyTorch checkpoint, please check your checkpoint path.")
    else:
        assert prefix is not None, "You should provide optimMethod prefix, " \
                                   "for example 'optimMethod-TorchModelf53bddcc'"

    try:
        loaded_model = Model.load(
            os.path.join(path, "model.{}".format(version)))
        self.model = TorchModel.from_value(loaded_model.value)
        self.optimizer = OptimMethod.load(
            os.path.join(path, "{}.{}".format(prefix, version)))
    except Exception as e:
        raise ValueError(
            "Cannot load PyTorch checkpoint, please check your checkpoint path "
            "and checkpoint type. " + str(e))

    self.estimator = SparkEstimator(self.model, self.optimizer, self.model_dir)
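# A minimal usage sketch (not from the original source): loading a checkpoint into a PyTorch
# Orca Estimator running on the BigDL backend. The estimator construction, checkpoint
# directory and prefix below are illustrative assumptions.
#
#   from bigdl.orca.learn.pytorch import Estimator
#
#   est = Estimator.from_torch(model=torch_model, optimizer=torch_optim,
#                              loss=torch_loss, backend="bigdl")
#   est.load_orca_checkpoint("/tmp/orca_ckpt")             # load the latest checkpoint
#   est.load_orca_checkpoint("/tmp/orca_ckpt", version=4,  # or load a specific version
#                            prefix="optimMethod-TorchModelf53bddcc")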
def test_model_save_and_load(self):
    class SimpleTorchModel(nn.Module):
        def __init__(self):
            super(SimpleTorchModel, self).__init__()
            self.dense1 = nn.Linear(2, 4)
            self.dense2 = nn.Linear(4, 1)

        def forward(self, x):
            x = self.dense1(x)
            x = torch.sigmoid(self.dense2(x))
            return x

    torch_model = SimpleTorchModel()
    az_model = TorchModel.from_pytorch(torch_model)

    with tempfile.TemporaryDirectory() as tmp_dir_name:
        path = tmp_dir_name + "/model.obj"
        az_model.save(path, True)
        loaded_model = Model.load(path)
        loaded_torchModel = TorchModel.from_value(loaded_model.value)
        dummy_input = torch.ones(16, 2)
        loaded_torchModel.forward(dummy_input.numpy())
        loaded_torchModel.to_pytorch()
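# A possible follow-up check (sketch, not part of the original test): assert that the
# round-tripped model produces the same outputs as the model it was saved from; numpy.testing
# is an illustrative choice here.
#
#   import numpy as np
#
#   expected = az_model.forward(dummy_input.numpy())
#   actual = loaded_torchModel.forward(dummy_input.numpy())
#   np.testing.assert_allclose(expected, actual, rtol=1e-5)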