def save_checkpoint(path: str,
                    model: MoleculeModel,
                    scaler: StandardScaler = None,
                    features_scaler: StandardScaler = None,
                    args: Namespace = None):
    """
    Serializes a model together with its scaler statistics and arguments.

    :param path: Destination file for the checkpoint.
    :param model: A MoleculeModel; its ``state_dict()`` is stored under ``'state_dict'``.
    :param scaler: A StandardScaler fitted on the data, or ``None``; only its
                   ``means`` and ``stds`` are stored, under ``'data_scaler'``.
    :param features_scaler: A StandardScaler fitted on the features, or ``None``;
                            only its ``means`` and ``stds`` are stored, under ``'features_scaler'``.
    :param args: Arguments namespace, stored as-is under ``'args'``.
    """
    def _summarize(fitted):
        # Persist just the fitted statistics, never the scaler object itself.
        if fitted is None:
            return None
        return {'means': fitted.means, 'stds': fitted.stds}

    checkpoint = {}
    checkpoint['args'] = args
    checkpoint['state_dict'] = model.state_dict()
    checkpoint['data_scaler'] = _summarize(scaler)
    checkpoint['features_scaler'] = _summarize(features_scaler)
    torch.save(checkpoint, path)
def save_checkpoint(path: str,
                    model: MoleculeModel,
                    scaler: StandardScaler = None,
                    features_scaler: StandardScaler = None,
                    args: TrainArgs = None) -> None:
    """
    Writes a checkpoint holding the model weights, scaler statistics and training arguments.

    :param path: File the checkpoint is written to.
    :param model: A :class:`~chemprop.models.model.MoleculeModel` whose ``state_dict()`` is saved.
    :param scaler: A :class:`~chemprop.data.scaler.StandardScaler` fitted on the data, or ``None``.
    :param features_scaler: A :class:`~chemprop.data.scaler.StandardScaler` fitted on the features, or ``None``.
    :param args: The :class:`~chemprop.args.TrainArgs` object containing the arguments the model
                 was trained with, or ``None``. It is stored as a plain ``Namespace``.
    """
    # Store a plain Namespace instead of TrainArgs for backwards compatibility.
    stored_args = None if args is None else Namespace(**args.as_dict())
    weights = model.state_dict()

    scaler_entries = {}
    for entry_name, fitted in (('data_scaler', scaler), ('features_scaler', features_scaler)):
        # Only the fitted statistics are persisted, not the scaler object.
        scaler_entries[entry_name] = None if fitted is None else {'means': fitted.means, 'stds': fitted.stds}

    torch.save({'args': stored_args, 'state_dict': weights, **scaler_entries}, path)
def save_checkpoint(path: str, model: MoleculeModel, scaler: StandardScaler = None,
                    features_scaler: StandardScaler = None, args: TrainArgs = None):
    """
    Persists a model and its associated preprocessing state to disk.

    :param path: Path where checkpoint will be saved.
    :param model: A MoleculeModel; its ``state_dict()`` is saved.
    :param scaler: A StandardScaler fitted on the data (may be ``None``).
    :param features_scaler: A StandardScaler fitted on the features (may be ``None``).
    :param args: Training arguments (may be ``None``); converted to a ``Namespace`` before saving.
    """
    if args is None:
        namespace_args = None
    else:
        # Saved as a Namespace rather than TrainArgs for backwards compatibility.
        namespace_args = Namespace(**args.as_dict())

    payload = dict(args=namespace_args, state_dict=model.state_dict())

    if scaler is None:
        payload['data_scaler'] = None
    else:
        payload['data_scaler'] = dict(means=scaler.means, stds=scaler.stds)

    if features_scaler is None:
        payload['features_scaler'] = None
    else:
        payload['features_scaler'] = dict(means=features_scaler.means, stds=features_scaler.stds)

    torch.save(payload, path)
def load_checkpoint(path: str,
                    device: torch.device = None,
                    logger: logging.Logger = None) -> MoleculeModel:
    """
    Rebuilds a model from a checkpoint file.

    A fresh :class:`~chemprop.models.model.MoleculeModel` is constructed from the
    checkpoint's arguments. Saved parameters whose name is unknown to that model,
    or whose shape differs from it, are skipped with a warning. All remaining
    parameters are copied in.

    :param path: Path where checkpoint is saved.
    :param device: If given, overrides the device recorded in the checkpoint's arguments.
    :param logger: Logger used for output; ``print`` is used when omitted.
    :return: The loaded :class:`~chemprop.models.model.MoleculeModel`.
    """
    debug = logger.debug if logger is not None else print
    info = logger.info if logger is not None else print

    # NOTE(review): the checkpoint pickles an argparse.Namespace under 'args'. torch.load
    # defaults to weights_only=True from PyTorch 2.6 and may refuse it. Confirm the
    # supported torch version.
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    args = TrainArgs()
    args.from_dict(vars(checkpoint['args']), skip_unsettable=True)
    saved_params = checkpoint['state_dict']

    if device is not None:
        args.device = device

    model = MoleculeModel(args)
    current_params = model.state_dict()

    matched = {}
    for saved_name, saved_tensor in saved_params.items():
        # Legacy names starting with "encoder.encoder.W"/"encoder.encoder.c" are
        # rewritten to "encoder.encoder.0..." before matching against the model.
        if re.match(r'(encoder\.encoder\.)([Wc])', saved_name):
            target_name = saved_name.replace('encoder.encoder', 'encoder.encoder.0')
        else:
            target_name = saved_name

        if target_name not in current_params:
            info(f'Warning: Pretrained parameter "{saved_name}" cannot be found in model parameters.')
            continue
        if current_params[target_name].shape != saved_tensor.shape:
            info(f'Warning: Pretrained parameter "{saved_name}" '
                 f'of shape {saved_tensor.shape} does not match corresponding '
                 f'model parameter of shape {current_params[target_name].shape}.')
            continue

        debug(f'Loading pretrained parameter "{saved_name}".')
        matched[target_name] = saved_tensor

    # Parameters that were not matched keep the freshly initialized values.
    current_params.update(matched)
    model.load_state_dict(current_params)

    if args.cuda:
        debug('Moving model to cuda')
        model = model.to(args.device)

    return model
def _scaler_to_checkpoint_dict(scaler: StandardScaler):
    """
    Converts a fitted scaler into the plain dictionary stored in a checkpoint.

    :param scaler: A fitted :class:`~chemprop.data.scaler.StandardScaler`, or ``None``.
    :return: ``{"means": scaler.means, "stds": scaler.stds}``, or ``None`` when ``scaler`` is ``None``.
    """
    if scaler is None:
        return None
    return {"means": scaler.means, "stds": scaler.stds}


def save_checkpoint(
    path: str,
    model: MoleculeModel,
    scaler: StandardScaler = None,
    features_scaler: StandardScaler = None,
    atom_descriptor_scaler: StandardScaler = None,
    bond_feature_scaler: StandardScaler = None,
    args: TrainArgs = None,
) -> None:
    """
    Saves a model checkpoint.

    Each scaler is stored as a ``{"means": ..., "stds": ...}`` dictionary, or as
    ``None`` when that scaler is not provided.

    :param path: Path where checkpoint will be saved.
    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param scaler: A :class:`~chemprop.data.scaler.StandardScaler` fitted on the data.
    :param features_scaler: A :class:`~chemprop.data.scaler.StandardScaler` fitted on the features.
    :param atom_descriptor_scaler: A :class:`~chemprop.data.scaler.StandardScaler` fitted on the atom descriptors.
    :param bond_feature_scaler: A :class:`~chemprop.data.scaler.StandardScaler` fitted on the bond features.
    :param args: The :class:`~chemprop.args.TrainArgs` object containing the arguments the model was trained with.
    """
    # Convert args to namespace for backwards compatibility
    if args is not None:
        args = Namespace(**args.as_dict())

    state = {
        "args": args,
        "state_dict": model.state_dict(),
        "data_scaler": _scaler_to_checkpoint_dict(scaler),
        "features_scaler": _scaler_to_checkpoint_dict(features_scaler),
        "atom_descriptor_scaler": _scaler_to_checkpoint_dict(atom_descriptor_scaler),
        "bond_feature_scaler": _scaler_to_checkpoint_dict(bond_feature_scaler),
    }
    torch.save(state, path)
def load_checkpoint(path: str,
                    device: torch.device = None,
                    logger: logging.Logger = None,
                    template=None) -> MoleculeModel:
    """
    Loads a model checkpoint.

    Saved parameters whose names are missing from the target model, or whose
    shapes differ from it, are skipped with a warning. All others are copied in.

    :param path: Path where checkpoint is saved.
    :param device: Device where the model will be moved. When given, it overrides
                   the device stored in the checkpoint's arguments.
    :param logger: A logger. When ``None``, ``print`` is used for output.
    :param template: Optional already-constructed model to load the weights into.
                     When given, it is used instead of building a new MoleculeModel
                     from the checkpoint's arguments. Its weights are overwritten in
                     place, and it is returned (moved to ``args.device`` if ``args.cuda``).
                     The checkpoint's arguments are still parsed.
    :return: The loaded MoleculeModel.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Load model and args
    # NOTE(review): the checkpoint pickles an argparse.Namespace under 'args'. torch.load
    # defaults to weights_only=True from PyTorch 2.6 and may refuse it. Confirm the
    # supported torch version. Full unpickling can execute code, so load only trusted files.
    state = torch.load(path, map_location=lambda storage, loc: storage)
    args = TrainArgs()
    args.from_dict(vars(state['args']), skip_unsettable=True)
    loaded_state_dict = state['state_dict']

    if device is not None:
        args.device = device

    # Build model, or reuse the caller-supplied template (its weights are loaded below)
    model = template if template is not None else MoleculeModel(args)
    model_state_dict = model.state_dict()

    # Skip missing parameters and parameters of mismatched size
    pretrained_state_dict = {}
    for param_name, loaded_param in loaded_state_dict.items():
        if param_name not in model_state_dict:
            info(f'Warning: Pretrained parameter "{param_name}" cannot be found in model parameters.')
            continue
        if model_state_dict[param_name].shape != loaded_param.shape:
            info(f'Warning: Pretrained parameter "{param_name}" '
                 f'of shape {loaded_param.shape} does not match corresponding '
                 f'model parameter of shape {model_state_dict[param_name].shape}.')
            continue
        pretrained_state_dict[param_name] = loaded_param

    # Load pretrained weights; unmatched parameters keep the model's current values
    model_state_dict.update(pretrained_state_dict)
    model.load_state_dict(model_state_dict)

    if args.cuda:
        debug('Moving model to cuda')
        model = model.to(args.device)

    return model