def load_checkpoint(path: str,
                    device: torch.device = None,
                    logger: logging.Logger = None) -> MoleculeModel:
    """
    Restores a :class:`~chemprop.models.model.MoleculeModel` from a saved checkpoint file.

    A fresh model is built from the training arguments stored in the checkpoint. Each
    saved parameter is then copied into it, except parameters that have no counterpart
    in the model or whose shape disagrees with the model's. Those are reported and
    skipped, so the model keeps its freshly initialized values for them.

    :param path: Location of the checkpoint file.
    :param device: Optional device that overrides the one recorded in the checkpoint's arguments.
    :param logger: Optional logger; when omitted, both debug and info messages go to ``print``.
    :return: The reconstructed :class:`~chemprop.models.model.MoleculeModel`.
    """
    debug, info = (logger.debug, logger.info) if logger is not None else (print, print)

    # Deserialize every tensor onto CPU, regardless of the device it was saved from.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    args = TrainArgs()
    args.from_dict(vars(checkpoint['args']), skip_unsettable=True)
    saved_params = checkpoint['state_dict']

    if device is not None:
        args.device = device

    # Build a model with the saved configuration and take its (freshly initialized) parameters.
    model = MoleculeModel(args)
    current_params = model.state_dict()

    compatible = {}
    for saved_name, saved_tensor in saved_params.items():
        # Backward compatibility: names beginning "encoder.encoder.W" or "encoder.encoder.c"
        # are remapped to "encoder.encoder.0...". This presumably reflects an older checkpoint
        # layout that predates indexed encoders; confirm against the model definition.
        target_name = (saved_name.replace('encoder.encoder', 'encoder.encoder.0')
                       if re.match(r'(encoder\.encoder\.)([Wc])', saved_name)
                       else saved_name)

        if target_name not in current_params:
            info(f'Warning: Pretrained parameter "{saved_name}" cannot be found in model parameters.')
            continue

        expected_shape = current_params[target_name].shape
        if expected_shape != saved_tensor.shape:
            info(f'Warning: Pretrained parameter "{saved_name}" '
                 f'of shape {saved_tensor.shape} does not match corresponding '
                 f'model parameter of shape {expected_shape}.')
            continue

        debug(f'Loading pretrained parameter "{saved_name}".')
        compatible[target_name] = saved_tensor

    # Overlay the compatible saved tensors onto the model's own state, so that skipped
    # entries keep their initial values, then load the merged state back into the model.
    current_params.update(compatible)
    model.load_state_dict(current_params)

    if args.cuda:
        debug('Moving model to cuda')
        model = model.to(args.device)

    return model
def load_args(path: str) -> TrainArgs:
    """
    Reads back the training arguments stored inside a model checkpoint.

    :param path: Location of the saved model checkpoint.
    :return: A :class:`~chemprop.args.TrainArgs` populated from the checkpoint's ``'args'`` entry.
    """
    args = TrainArgs()

    # Deserialize onto CPU; only the stored argument namespace is needed here.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    saved_namespace = checkpoint['args']

    args.from_dict(vars(saved_namespace), skip_unsettable=True)
    return args