def load_checkpoint(path: str,
                    device: torch.device = None,
                    logger: logging.Logger = None) -> MoleculeModel:
    """
    Loads a model checkpoint.

    The training arguments stored in the checkpoint are used to build a fresh
    :class:`~chemprop.models.model.MoleculeModel`, and the saved parameters are copied into it.
    Saved parameters that have no counterpart in the new model, or whose shape does not match,
    are skipped with a warning; the model keeps its freshly initialized values for those.

    :param path: Path where checkpoint is saved.
    :param device: Device where the model will be moved. If ``None``, the device given by the
                   training arguments restored from the checkpoint is used.
    :param logger: A logger for recording output. If ``None``, messages are printed instead.
    :return: The loaded :class:`~chemprop.models.model.MoleculeModel`.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Load model and args.
    # NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
    state = torch.load(path, map_location=lambda storage, loc: storage)
    args = TrainArgs()
    args.from_dict(vars(state['args']), skip_unsettable=True)
    loaded_state_dict = state['state_dict']

    if device is not None:
        args.device = device

    # Build model
    model = MoleculeModel(args)
    model_state_dict = model.state_dict()

    # Skip missing parameters and parameters of mismatched size
    pretrained_state_dict = {}
    for loaded_param_name, loaded_param in loaded_state_dict.items():
        # Backward compatibility for parameter names: names of the form "encoder.encoder.W..." or
        # "encoder.encoder.c..." are remapped to "encoder.encoder.0.W..." / "encoder.encoder.0.c...".
        if re.match(r'(encoder\.encoder\.)([Wc])', loaded_param_name):
            param_name = loaded_param_name.replace('encoder.encoder', 'encoder.encoder.0')
        else:
            param_name = loaded_param_name

        # Load pretrained parameter, skipping unmatched parameters
        if param_name not in model_state_dict:
            info(f'Warning: Pretrained parameter "{loaded_param_name}" cannot be found in model parameters.')
        elif model_state_dict[param_name].shape != loaded_param.shape:
            info(f'Warning: Pretrained parameter "{loaded_param_name}" '
                 f'of shape {loaded_param.shape} does not match corresponding '
                 f'model parameter of shape {model_state_dict[param_name].shape}.')
        else:
            debug(f'Loading pretrained parameter "{loaded_param_name}".')
            pretrained_state_dict[param_name] = loaded_param

    # Load pretrained weights; parameters not copied above keep their initial values.
    model_state_dict.update(pretrained_state_dict)
    model.load_state_dict(model_state_dict)

    if args.cuda:
        debug('Moving model to cuda')
    # Always move to the resolved device (a no-op when it is already the CPU).
    model = model.to(args.device)

    return model
def load_args(path: str) -> TrainArgs:
    """
    Reads back the arguments that a model was trained with from its checkpoint.

    :param path: Path to the saved model checkpoint.
    :return: A :class:`~chemprop.args.TrainArgs` object populated from the checkpoint.
    """
    train_args = TrainArgs()

    # Map every tensor storage to the CPU so the checkpoint loads without a GPU.
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    saved_arguments = vars(checkpoint['args'])

    train_args.from_dict(saved_arguments, skip_unsettable=True)
    return train_args
def chemprop_train() -> None:
    """Command-line entry point behind :code:`chemprop_train`.

    Reads the Chemprop training arguments from the command line and hands them to
    cross-validation, which trains (and cross-validates) a Chemprop model.
    """
    parsed_args = TrainArgs().parse_args()
    cross_validate(args=parsed_args, train_func=run_training)
def update_prediction_args(predict_args: PredictArgs,
                           train_args: TrainArgs,
                           missing_to_defaults: bool = True,
                           validate_feature_sources: bool = True) -> None:
    """
    Updates prediction arguments with training arguments loaded from a checkpoint file.
    If an argument is present in both, the prediction argument will be used.

    Also raises errors for situations where the prediction arguments and training arguments
    are different but must match for proper function.

    :param predict_args: The :class:`~chemprop.args.PredictArgs` object containing
                         the arguments to use for making predictions.
    :param train_args: The :class:`~chemprop.args.TrainArgs` object containing
                       the arguments used to train the model previously.
    :param missing_to_defaults: Whether to replace missing training arguments with the current defaults
                                for :class:`~chemprop.args.TrainArgs`. This is used for backwards compatibility.
    :param validate_feature_sources: Indicates whether the feature sources (from path or generator) are checked
                                     for consistency between the training and prediction arguments. This is not
                                     necessary for fingerprint generation, where molecule features are not used.
    :raises ValueError: If the number of molecules, the use of atom descriptors, the use of bond features,
                        or the feature-scaling setting differs between training and prediction, or (when
                        ``validate_feature_sources`` is set) if the molecule feature sources differ.
    """
    # Copy over every training argument that the prediction arguments do not define themselves.
    for key, value in vars(train_args).items():
        if not hasattr(predict_args, key):
            setattr(predict_args, key, value)

    if missing_to_defaults:
        # If a default argument would cause different behavior than occurred in legacy checkpoints before the argument existed,
        # then that argument must be included in the `override_defaults` dictionary to force the legacy behavior.
        override_defaults = {
            'bond_features_scaling': False,
            'no_bond_features_scaling': True,
            'atom_descriptors_scaling': False,
            'no_atom_descriptors_scaling': True,
        }
        default_train_args = TrainArgs().parse_args(['--data_path', None, '--dataset_type', str(train_args.dataset_type)])
        for key, value in vars(default_train_args).items():
            if not hasattr(predict_args, key):
                setattr(predict_args, key, override_defaults.get(key, value))

    # Same number of molecules must be used in training as in making predictions
    if train_args.number_of_molecules != predict_args.number_of_molecules:
        raise ValueError('A different number of molecules was used in training '
                         f'model than is specified for prediction, {train_args.number_of_molecules} '
                         'smiles fields must be provided')

    # If atom descriptors were used during training, they must be used when predicting and vice-versa
    if train_args.atom_descriptors != predict_args.atom_descriptors:
        raise ValueError('The use of atom descriptors is inconsistent between training and prediction. '
                         'If atom descriptors were used during training, they must be specified again '
                         'during prediction using the same type of descriptors as before. '
                         'If they were not used during training, they cannot be specified during prediction.')

    # If bond features were used during training, they must be used when predicting and vice-versa
    if (train_args.bond_features_path is None) != (predict_args.bond_features_path is None):
        raise ValueError('The use of bond descriptors is different between training and prediction. If you used bond '
                         'descriptors for training, please specify a path to new bond descriptors for prediction.')

    # If the additional (molecule-level) features were scaled during training, the same must be done during prediction
    if train_args.features_scaling != predict_args.features_scaling:
        raise ValueError('If scaling of the additional features was done during training, the '
                         'same must be done during prediction.')

    if validate_feature_sources:
        # If features were used during training, they must be used when predicting, and from the same
        # kind of source (a features file and/or a features generator).
        features_path_mismatch = (train_args.features_path is None) != (predict_args.features_path is None)
        features_generator_mismatch = ((train_args.features_generator is None)
                                       != (predict_args.features_generator is None))
        if features_path_mismatch or features_generator_mismatch:
            raise ValueError('Features were used during training so they must be specified again during prediction '
                             'using the same type of features as before (with either --features_generator or '
                             '--features_path and using --no_features_scaling if applicable).')