def load_frzn_model(
    model: torch.nn,
    path: str,
    current_args: Namespace = None,
    cuda: bool = None,
    logger: logging.Logger = None,
) -> MoleculeModel:
    """
    Loads a model checkpoint and overwrites (freezes) the matching parameters of ``model``.

    :param model: The model whose state dict is partially overwritten with the checkpoint's weights.
    :param path: Path where checkpoint is saved.
    :param current_args: The current arguments. Replaces the arguments loaded from the checkpoint if provided.
    :param cuda: Whether to move model to cuda.
    :param logger: A logger.
    :return: The loaded MoleculeModel.
    """
    debug = logger.debug if logger is not None else print

    loaded_mpnn_model = torch.load(path, map_location=lambda storage, loc: storage)
    loaded_state_dict = loaded_mpnn_model["state_dict"]
    loaded_args = loaded_mpnn_model["args"]

    model_state_dict = model.state_dict()

    def _encoder_param_names(mol_num: int) -> list:
        # State-dict keys of the MPNN encoder belonging to molecule `mol_num`.
        # NOTE: these must be flat strings — the previous version built lists
        # of 4-tuples and flattened them into tuples, which are not valid
        # state-dict keys.
        return [
            f"encoder.encoder.{mol_num}.W_i.weight",
            f"encoder.encoder.{mol_num}.W_h.weight",
            f"encoder.encoder.{mol_num}.W_o.weight",
            f"encoder.encoder.{mol_num}.W_o.bias",
        ]

    def _ffn_param_names(num_layers: int) -> list:
        # The i-th FFN linear layer sits at Sequential index 3*i + 1
        # (consistent with the single-molecule branch below; the previous
        # version of the multi-molecule branch had a typo `i + 3 + 1`).
        return [
            name
            for i in range(num_layers)
            for name in (f"ffn.{i * 3 + 1}.weight", f"ffn.{i * 3 + 1}.bias")
        ]

    if loaded_args.number_of_molecules == 1 and current_args.number_of_molecules == 1:
        encoder_param_names = _encoder_param_names(0)
        if current_args.checkpoint_frzn is not None:
            # Freeze the MPNN
            for param_name in encoder_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

            if current_args.frzn_ffn_layers > 0:
                # Freeze MPNN and FFN layers
                ffn_param_names = _ffn_param_names(current_args.frzn_ffn_layers)
                for param_name in encoder_param_names + ffn_param_names:
                    model_state_dict = overwrite_state_dict(
                        param_name, param_name, loaded_state_dict, model_state_dict)

            if current_args.freeze_first_only:
                debug(
                    "WARNING: --freeze_first_only flag cannot be used with number_of_molecules=1 (flag is ignored)"
                )
    elif loaded_args.number_of_molecules == 1 and current_args.number_of_molecules > 1:
        # TODO(degraff): these two `if`-blocks can be condensed into one
        if (current_args.checkpoint_frzn is not None
                and current_args.freeze_first_only
                and current_args.frzn_ffn_layers <= 0):
            # Only freeze first MPNN
            for param_name in _encoder_param_names(0):
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

        if (current_args.checkpoint_frzn is not None
                and not current_args.freeze_first_only
                and current_args.frzn_ffn_layers <= 0):
            # Duplicate encoder from frozen checkpoint and overwrite all encoders
            loaded_encoder_param_names = (
                _encoder_param_names(0) * current_args.number_of_molecules)
            model_encoder_param_names = [
                name
                for mol_num in range(current_args.number_of_molecules)
                for name in _encoder_param_names(mol_num)
            ]
            for loaded_param_name, model_param_name in zip(
                    loaded_encoder_param_names, model_encoder_param_names):
                model_state_dict = overwrite_state_dict(
                    loaded_param_name, model_param_name, loaded_state_dict,
                    model_state_dict)

        if current_args.frzn_ffn_layers > 0:
            raise ValueError(
                f"Number of molecules from checkpoint_frzn ({loaded_args.number_of_molecules}) "
                f"must equal current number of molecules ({current_args.number_of_molecules})!"
            )
    elif loaded_args.number_of_molecules > 1 and current_args.number_of_molecules > 1:
        if loaded_args.number_of_molecules != current_args.number_of_molecules:
            raise ValueError(
                f"Number of molecules in checkpoint_frzn ({loaded_args.number_of_molecules}) "
                f"must either match current model ({current_args.number_of_molecules}) or equal 1."
            )
        if current_args.freeze_first_only:
            raise ValueError(
                f"Number of molecules in checkpoint_frzn ({loaded_args.number_of_molecules}) "
                "must be equal to 1 for freeze_first_only to be used!")

        # All encoders, one set of keys per molecule, as flat strings.
        all_encoder_param_names = [
            name
            for mol_num in range(current_args.number_of_molecules)
            for name in _encoder_param_names(mol_num)
        ]

        if (current_args.checkpoint_frzn is not None
                and current_args.frzn_ffn_layers <= 0):
            # Freeze every molecule's MPNN encoder.
            for param_name in all_encoder_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

        if current_args.frzn_ffn_layers > 0:
            # Freeze every encoder plus the first frzn_ffn_layers FFN layers.
            ffn_param_names = _ffn_param_names(current_args.frzn_ffn_layers)
            for param_name in all_encoder_param_names + ffn_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

        if current_args.frzn_ffn_layers >= current_args.ffn_num_layers:
            raise ValueError(
                f"Number of frozen FFN layers ({current_args.frzn_ffn_layers}) "
                f"must be less than the number of FFN layers ({current_args.ffn_num_layers})!"
            )

    # Load pretrained weights
    model.load_state_dict(model_state_dict)

    return model
def load_frzn_model(model: torch.nn,
                    path: str,
                    current_args: Namespace = None,
                    cuda: bool = None,
                    logger: logging.Logger = None) -> MoleculeModel:
    """
    Loads a model checkpoint and overwrites (freezes) the matching parameters of ``model``.

    :param model: The model whose state dict is partially overwritten with the checkpoint's weights.
    :param path: Path where checkpoint is saved.
    :param current_args: The current arguments. Replaces the arguments loaded from the checkpoint if provided.
    :param cuda: Whether to move model to cuda.
    :param logger: A logger.
    :return: The loaded MoleculeModel.
    """
    debug = logger.debug if logger is not None else print

    loaded_mpnn_model = torch.load(path, map_location=lambda storage, loc: storage)
    loaded_state_dict = loaded_mpnn_model['state_dict']
    loaded_args = loaded_mpnn_model['args']

    model_state_dict = model.state_dict()

    # BUG FIX: the original first condition was
    # `loaded_args.number_of_molecules == 1 & current_args.number_of_molecules == 1`.
    # `&` binds tighter than `==`, so it parsed as the chained comparison
    # `loaded == (1 & current) == 1`, which is wrongly True for e.g.
    # current_args.number_of_molecules == 3. All conditions now use `and`
    # with explicit comparisons.
    if (loaded_args.number_of_molecules == 1
            and current_args.number_of_molecules == 1):
        encoder_param_names = [
            'encoder.encoder.0.W_i.weight', 'encoder.encoder.0.W_h.weight',
            'encoder.encoder.0.W_o.weight', 'encoder.encoder.0.W_o.bias'
        ]
        if current_args.checkpoint_frzn is not None:
            # Freeze the MPNN
            for param_name in encoder_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

            if current_args.frzn_ffn_layers > 0:
                # The i-th FFN linear layer sits at Sequential index 3*i + 1.
                ffn_param_names = [[
                    'ffn.' + str(i * 3 + 1) + '.weight',
                    'ffn.' + str(i * 3 + 1) + '.bias'
                ] for i in range(current_args.frzn_ffn_layers)]
                ffn_param_names = [
                    item for sublist in ffn_param_names for item in sublist
                ]
                # Freeze MPNN and FFN layers
                for param_name in encoder_param_names + ffn_param_names:
                    model_state_dict = overwrite_state_dict(
                        param_name, param_name, loaded_state_dict,
                        model_state_dict)

            if current_args.freeze_first_only:
                debug(
                    'WARNING: --freeze_first_only flag cannot be used with number_of_molecules=1 (flag is ignored)'
                )
    elif (loaded_args.number_of_molecules == 1
          and current_args.number_of_molecules > 1):
        if (current_args.checkpoint_frzn is not None
                and current_args.freeze_first_only
                and not current_args.frzn_ffn_layers > 0):
            # Only freeze first MPNN
            encoder_param_names = [
                'encoder.encoder.0.W_i.weight', 'encoder.encoder.0.W_h.weight',
                'encoder.encoder.0.W_o.weight', 'encoder.encoder.0.W_o.bias'
            ]
            for param_name in encoder_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

        if (current_args.checkpoint_frzn is not None
                and not current_args.freeze_first_only
                and not current_args.frzn_ffn_layers > 0):
            # Duplicate encoder from frozen checkpoint and overwrite all encoders
            loaded_encoder_param_names = [
                'encoder.encoder.0.W_i.weight', 'encoder.encoder.0.W_h.weight',
                'encoder.encoder.0.W_o.weight', 'encoder.encoder.0.W_o.bias'
            ] * current_args.number_of_molecules
            model_encoder_param_names = [[
                'encoder.encoder.' + str(mol_num) + '.W_i.weight',
                'encoder.encoder.' + str(mol_num) + '.W_h.weight',
                'encoder.encoder.' + str(mol_num) + '.W_o.weight',
                'encoder.encoder.' + str(mol_num) + '.W_o.bias'
            ] for mol_num in range(current_args.number_of_molecules)]
            model_encoder_param_names = [
                item for sublist in model_encoder_param_names
                for item in sublist
            ]
            for loaded_param_name, model_param_name in zip(
                    loaded_encoder_param_names, model_encoder_param_names):
                model_state_dict = overwrite_state_dict(
                    loaded_param_name, model_param_name, loaded_state_dict,
                    model_state_dict)

        if current_args.frzn_ffn_layers > 0:
            # Duplicate encoder from frozen checkpoint and overwrite all encoders + FFN layers
            raise Exception(
                'Number of molecules in checkpoint_frzn must be equal to current model for ffn layers to be frozen'
            )
    elif (loaded_args.number_of_molecules > 1
          and current_args.number_of_molecules > 1):
        if loaded_args.number_of_molecules != current_args.number_of_molecules:
            raise Exception(
                'Number of molecules in checkpoint_frzn ({}) must match current model ({}) OR equal to 1.'
                .format(loaded_args.number_of_molecules,
                        current_args.number_of_molecules))

        if current_args.freeze_first_only:
            raise Exception(
                'Number of molecules in checkpoint_frzn ({}) must be equal to 1 for freeze_first_only to be used.'
                .format(loaded_args.number_of_molecules))

        if (current_args.checkpoint_frzn is not None
                and not current_args.frzn_ffn_layers > 0):
            # Freeze every molecule's MPNN encoder.
            encoder_param_names = [[
                'encoder.encoder.' + str(mol_num) + '.W_i.weight',
                'encoder.encoder.' + str(mol_num) + '.W_h.weight',
                'encoder.encoder.' + str(mol_num) + '.W_o.weight',
                'encoder.encoder.' + str(mol_num) + '.W_o.bias'
            ] for mol_num in range(current_args.number_of_molecules)]
            encoder_param_names = [
                item for sublist in encoder_param_names for item in sublist
            ]
            for param_name in encoder_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

        if current_args.frzn_ffn_layers > 0:
            # Freeze every encoder plus the first frzn_ffn_layers FFN layers.
            encoder_param_names = [[
                'encoder.encoder.' + str(mol_num) + '.W_i.weight',
                'encoder.encoder.' + str(mol_num) + '.W_h.weight',
                'encoder.encoder.' + str(mol_num) + '.W_o.weight',
                'encoder.encoder.' + str(mol_num) + '.W_o.bias'
            ] for mol_num in range(current_args.number_of_molecules)]
            encoder_param_names = [
                item for sublist in encoder_param_names for item in sublist
            ]
            ffn_param_names = [[
                'ffn.' + str(i * 3 + 1) + '.weight',
                'ffn.' + str(i * 3 + 1) + '.bias'
            ] for i in range(current_args.frzn_ffn_layers)]
            ffn_param_names = [
                item for sublist in ffn_param_names for item in sublist
            ]
            for param_name in encoder_param_names + ffn_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

        if current_args.frzn_ffn_layers >= current_args.ffn_num_layers:
            raise Exception(
                'Number of frozen FFN layers must be less than the number of FFN layers'
            )

    # Load pretrained weights
    model.load_state_dict(model_state_dict)

    return model