def test_load_state_dict(self):
    state_dict = torch_module.state_dict()
    model_definition = torch.nn.Linear(1, 1)
    # A freshly initialized model should not match the saved weights...
    assert model_definition.state_dict() != state_dict
    # ...but it should after load_torch_model applies the state dict.
    assert load_torch_model(state_dict, model_definition).state_dict() == state_dict
def load_checkpoint(
    checkpoint: Checkpoint,
    model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],
    tokenizer: Optional[Type[transformers.PreTrainedTokenizer]] = None,
    *,
    tokenizer_kwargs: Optional[Dict[str, Any]] = None,
    **pretrained_model_kwargs,
) -> Tuple[
    Union[transformers.modeling_utils.PreTrainedModel, torch.nn.Module],
    transformers.training_args.TrainingArguments,
    Optional[transformers.PreTrainedTokenizer],
    Optional["Preprocessor"],
]:
    """Load a Checkpoint from ``HuggingFaceTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and preprocessor from.
            It is expected to be from the result of a ``HuggingFaceTrainer`` run.
        model: Either a ``transformers.PreTrainedModel`` class (e.g.
            ``AutoModelForCausalLM``) or a PyTorch model to load the weights into.
            This should be the same model used for training.
        tokenizer: A ``transformers.PreTrainedTokenizer`` class to load the model
            tokenizer with. If not specified, the tokenizer will not be loaded.
            An exception is raised if a tokenizer class is specified but no
            tokenizer was found in the checkpoint.
        tokenizer_kwargs: Dict of kwargs to pass to the ``tokenizer.from_pretrained``
            call. Ignored if ``tokenizer`` is None.
        **pretrained_model_kwargs: Kwargs to pass to the ``model.from_pretrained``
            call. Ignored if ``model`` is not a ``transformers.PreTrainedModel``
            class.

    Returns:
        The model, ``TrainingArguments``, tokenizer and AIR preprocessor contained
        within the checkpoint. These can be used to initialize a
        ``transformers.Trainer`` object locally.
    """
    tokenizer_kwargs = tokenizer_kwargs or {}
    with checkpoint.as_directory() as checkpoint_path:
        preprocessor = load_preprocessor_from_dir(checkpoint_path)
        if isinstance(model, torch.nn.Module):
            state_dict = torch.load(
                os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu"
            )
            model = load_torch_model(saved_model=state_dict, model_definition=model)
        else:
            model = model.from_pretrained(checkpoint_path, **pretrained_model_kwargs)
        if tokenizer:
            tokenizer = tokenizer.from_pretrained(checkpoint_path, **tokenizer_kwargs)
        training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME)
        if os.path.exists(training_args_path):
            with open(training_args_path, "rb") as f:
                training_args = torch.load(f, map_location="cpu")
        else:
            training_args = None
    return model, training_args, tokenizer, preprocessor
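# A minimal usage sketch for the ``HuggingFaceTrainer`` variant of
# ``load_checkpoint`` above. Assumptions (not part of the source): ``result`` is
# the output of a completed ``HuggingFaceTrainer.fit()`` call, and the Auto*
# classes are one valid choice of model/tokenizer class for that run.
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer

model, training_args, tokenizer, preprocessor = load_checkpoint(
    result.checkpoint,
    AutoModelForCausalLM,
    tokenizer=AutoTokenizer,
)
# The returned objects are enough to reconstruct a local Trainer.
trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer)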
def load_checkpoint(
    checkpoint: Checkpoint, model: Optional[torch.nn.Module] = None
) -> Tuple[torch.nn.Module, Optional["Preprocessor"]]:
    """Load a Checkpoint from ``TorchTrainer``.

    Args:
        checkpoint: The checkpoint to load the model and preprocessor from.
            It is expected to be from the result of a ``TorchTrainer`` run.
        model: If the checkpoint contains a model state dict rather than the
            model itself, the state dict will be loaded into this ``model``.

    Returns:
        The model with the loaded weights and the AIR preprocessor contained
        within the checkpoint.
    """
    saved_model, preprocessor = _load_checkpoint(checkpoint, "TorchTrainer")
    model = load_torch_model(saved_model=saved_model, model_definition=model)
    return model, preprocessor
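# A minimal usage sketch for the ``TorchTrainer`` variant. Assumptions (not part
# of the source): ``result`` comes from ``TorchTrainer.fit()`` and the checkpoint
# stores a state dict, so a matching model definition must be supplied.
model, preprocessor = load_checkpoint(result.checkpoint, model=torch.nn.Linear(1, 1))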
def get_model(
    self,
    model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module],
    **pretrained_model_kwargs,
) -> Union[transformers.modeling_utils.PreTrainedModel, torch.nn.Module]:
    """Retrieve the model stored in this checkpoint."""
    with self.as_directory() as checkpoint_path:
        if isinstance(model, torch.nn.Module):
            state_dict = torch.load(
                os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu"
            )
            model = load_torch_model(saved_model=state_dict, model_definition=model)
        else:
            model = model.from_pretrained(checkpoint_path, **pretrained_model_kwargs)
    return model
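# A usage sketch for ``get_model``. Assumptions (not part of the source):
# ``checkpoint`` is an instance of the checkpoint class that defines
# ``get_model`` above, and ``MyTorchModel`` is a hypothetical ``torch.nn.Module``
# whose architecture matches the saved weights.
model = checkpoint.get_model(MyTorchModel())
# For a ``transformers`` checkpoint, pass the pretrained-model class instead:
model = checkpoint.get_model(AutoModelForCausalLM)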
def test_load_state_dict_fail(self):
    with pytest.raises(ValueError):
        # model_definition is required to load a state dict.
        load_torch_model(torch_module.state_dict())
def test_load_module(self):
    assert load_torch_model(torch_module) == torch_module