# Example 1
def load_frzn_model(
    model: torch.nn,
    path: str,
    current_args: Namespace = None,
    cuda: bool = None,
    logger: logging.Logger = None,
) -> MoleculeModel:
    """
    Loads a model checkpoint and overwrites (freezes) the matching weights of ``model``.

    :param model: The current model whose state dict is partially overwritten
        with the pretrained checkpoint weights.
    :param path: Path where checkpoint is saved.
    :param current_args: The current arguments. Replaces the arguments loaded from the checkpoint if provided.
    :param cuda: Whether to move model to cuda.
    :param logger: A logger.
    :return: The loaded MoleculeModel.
    """
    debug = logger.debug if logger is not None else print

    loaded_mpnn_model = torch.load(path,
                                   map_location=lambda storage, loc: storage)
    loaded_state_dict = loaded_mpnn_model["state_dict"]
    loaded_args = loaded_mpnn_model["args"]

    model_state_dict = model.state_dict()

    def _encoder_param_names(mol_num: int) -> list:
        """Parameter names of the MPNN encoder for molecule index ``mol_num``."""
        return [
            f"encoder.encoder.{mol_num}.W_i.weight",
            f"encoder.encoder.{mol_num}.W_h.weight",
            f"encoder.encoder.{mol_num}.W_o.weight",
            f"encoder.encoder.{mol_num}.W_o.bias",
        ]

    def _ffn_param_names(num_layers: int) -> list:
        """Parameter names of the first ``num_layers`` FFN linear layers.

        Linear layers sit at indices 1, 4, 7, ... of the FFN Sequential
        (dropout/activation occupy the other slots), hence ``i * 3 + 1``.
        The original code used ``i + 3 + 1`` in one place, which pointed at
        the wrong (and mostly nonexistent) layer indices.
        """
        return [
            f"ffn.{i * 3 + 1}.{suffix}"
            for i in range(num_layers)
            for suffix in ("weight", "bias")
        ]

    def _overwrite(param_pairs, state_dict):
        """Copy each (loaded_name, model_name) parameter from the checkpoint."""
        for loaded_name, model_name in param_pairs:
            state_dict = overwrite_state_dict(loaded_name, model_name,
                                              loaded_state_dict, state_dict)
        return state_dict

    if loaded_args.number_of_molecules == 1 and current_args.number_of_molecules == 1:
        encoder_param_names = _encoder_param_names(0)
        if current_args.checkpoint_frzn is not None:
            # Freeze the MPNN
            model_state_dict = _overwrite(
                [(name, name) for name in encoder_param_names],
                model_state_dict)

        if current_args.frzn_ffn_layers > 0:
            # Freeze MPNN and FFN layers
            param_names = encoder_param_names + _ffn_param_names(
                current_args.frzn_ffn_layers)
            model_state_dict = _overwrite(
                [(name, name) for name in param_names], model_state_dict)

        if current_args.freeze_first_only:
            debug(
                "WARNING: --freeze_first_only flag cannot be used with number_of_molecules=1 (flag is ignored)"
            )

    elif loaded_args.number_of_molecules == 1 and current_args.number_of_molecules > 1:
        # Validate first: FFN layers can only be frozen when molecule counts match.
        if current_args.frzn_ffn_layers > 0:
            raise ValueError(
                f"Number of molecules from checkpoint_frzn ({loaded_args.number_of_molecules}) "
                f"must equal current number of molecules ({current_args.number_of_molecules})!"
            )

        if current_args.checkpoint_frzn is not None:
            if current_args.freeze_first_only:
                # Only freeze the first MPNN encoder.
                model_state_dict = _overwrite(
                    [(name, name) for name in _encoder_param_names(0)],
                    model_state_dict)
            else:
                # Duplicate the single loaded encoder into every encoder of
                # the current model. (The original built a list of 4-tuples
                # here and zipped it against flat strings, passing tuples to
                # overwrite_state_dict — fixed to flat name-to-name pairs.)
                pairs = [
                    pair
                    for mol_num in range(current_args.number_of_molecules)
                    for pair in zip(_encoder_param_names(0),
                                    _encoder_param_names(mol_num))
                ]
                model_state_dict = _overwrite(pairs, model_state_dict)

    elif loaded_args.number_of_molecules > 1 and current_args.number_of_molecules > 1:
        if loaded_args.number_of_molecules != current_args.number_of_molecules:
            raise ValueError(
                f"Number of molecules in checkpoint_frzn ({loaded_args.number_of_molecules}) "
                f"must either match current model ({current_args.number_of_molecules}) or equal 1."
            )

        if current_args.freeze_first_only:
            raise ValueError(
                f"Number of molecules in checkpoint_frzn ({loaded_args.number_of_molecules}) "
                "must be equal to 1 for freeze_first_only to be used!")

        # Validate before doing any freezing work.
        if current_args.frzn_ffn_layers >= current_args.ffn_num_layers:
            raise ValueError(
                f"Number of frozen FFN layers ({current_args.frzn_ffn_layers}) "
                f"must be less than the number of FFN layers ({current_args.ffn_num_layers})!"
            )

        encoder_param_names = [
            name for mol_num in range(current_args.number_of_molecules)
            for name in _encoder_param_names(mol_num)
        ]

        if current_args.frzn_ffn_layers > 0:
            # Freeze all encoders plus the first frzn_ffn_layers FFN layers.
            param_names = encoder_param_names + _ffn_param_names(
                current_args.frzn_ffn_layers)
            model_state_dict = _overwrite(
                [(name, name) for name in param_names], model_state_dict)
        elif current_args.checkpoint_frzn is not None:
            # Freeze all encoders only.
            model_state_dict = _overwrite(
                [(name, name) for name in encoder_param_names],
                model_state_dict)

    # Load pretrained weights
    model.load_state_dict(model_state_dict)

    return model
# Example 2
def load_frzn_model(model: torch.nn,
                    path: str,
                    current_args: Namespace = None,
                    cuda: bool = None,
                    logger: logging.Logger = None) -> MoleculeModel:
    """
    Loads a model checkpoint and overwrites (freezes) the matching weights of ``model``.

    :param model: The current model whose state dict is partially overwritten
        with the pretrained checkpoint weights.
    :param path: Path where checkpoint is saved.
    :param current_args: The current arguments. Replaces the arguments loaded from the checkpoint if provided.
    :param cuda: Whether to move model to cuda.
    :param logger: A logger.
    :return: The loaded MoleculeModel.
    """
    debug = logger.debug if logger is not None else print

    loaded_mpnn_model = torch.load(path,
                                   map_location=lambda storage, loc: storage)
    loaded_state_dict = loaded_mpnn_model['state_dict']
    loaded_args = loaded_mpnn_model['args']

    model_state_dict = model.state_dict()

    def _encoder_param_names(mol_num):
        # Weight/bias names of the MPNN encoder for molecule index `mol_num`.
        return [
            'encoder.encoder.' + str(mol_num) + '.W_i.weight',
            'encoder.encoder.' + str(mol_num) + '.W_h.weight',
            'encoder.encoder.' + str(mol_num) + '.W_o.weight',
            'encoder.encoder.' + str(mol_num) + '.W_o.bias'
        ]

    def _ffn_param_names(num_layers):
        # FFN linear layers sit at Sequential indices 1, 4, 7, ...
        # (dropout/activation fill the other slots), hence i * 3 + 1.
        return [
            'ffn.' + str(i * 3 + 1) + suffix
            for i in range(num_layers) for suffix in ('.weight', '.bias')
        ]

    # BUG FIX: the original used bitwise `&` between comparisons
    # (`a == 1 & b == 1`), which parses as `a == (1 & b) == 1` and wrongly
    # took this branch for e.g. loaded=1, current=3. `and` gives the
    # intended boolean logic; likewise below.
    if loaded_args.number_of_molecules == 1 and current_args.number_of_molecules == 1:
        encoder_param_names = _encoder_param_names(0)
        if current_args.checkpoint_frzn is not None:
            # Freeze the MPNN
            for param_name in encoder_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict,
                    model_state_dict)

        if current_args.frzn_ffn_layers > 0:
            # Freeze MPNN and FFN layers
            for param_name in encoder_param_names + _ffn_param_names(
                    current_args.frzn_ffn_layers):
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict,
                    model_state_dict)

        if current_args.freeze_first_only:
            debug(
                'WARNING: --freeze_first_only flag cannot be used with number_of_molecules=1 (flag is ignored)'
            )

    elif loaded_args.number_of_molecules == 1 and current_args.number_of_molecules > 1:
        # Validate first: FFN layers can only be frozen when molecule counts match.
        if current_args.frzn_ffn_layers > 0:
            raise Exception(
                'Number of molecules in checkpoint_frzn must be equal to current model for ffn layers to be frozen'
            )

        if current_args.checkpoint_frzn is not None:
            if current_args.freeze_first_only:
                # Only freeze first MPNN
                for param_name in _encoder_param_names(0):
                    model_state_dict = overwrite_state_dict(
                        param_name, param_name, loaded_state_dict,
                        model_state_dict)
            else:
                # Duplicate encoder from frozen checkpoint and overwrite all encoders
                for mol_num in range(current_args.number_of_molecules):
                    for loaded_param_name, model_param_name in zip(
                            _encoder_param_names(0),
                            _encoder_param_names(mol_num)):
                        model_state_dict = overwrite_state_dict(
                            loaded_param_name, model_param_name,
                            loaded_state_dict, model_state_dict)

    elif loaded_args.number_of_molecules > 1 and current_args.number_of_molecules > 1:
        if loaded_args.number_of_molecules != current_args.number_of_molecules:
            raise Exception(
                'Number of molecules in checkpoint_frzn ({}) must match current model ({}) OR equal to 1.'
                .format(loaded_args.number_of_molecules,
                        current_args.number_of_molecules))

        if current_args.freeze_first_only:
            raise Exception(
                'Number of molecules in checkpoint_frzn ({}) must be equal to 1 for freeze_first_only to be used.'
                .format(loaded_args.number_of_molecules))

        # Validate before doing any freezing work.
        if current_args.frzn_ffn_layers >= current_args.ffn_num_layers:
            raise Exception(
                'Number of frozen FFN layers must be less than the number of FFN layers'
            )

        encoder_param_names = [
            name for mol_num in range(current_args.number_of_molecules)
            for name in _encoder_param_names(mol_num)
        ]

        if current_args.frzn_ffn_layers > 0:
            # Freeze all encoders plus the first frzn_ffn_layers FFN layers.
            for param_name in encoder_param_names + _ffn_param_names(
                    current_args.frzn_ffn_layers):
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict,
                    model_state_dict)
        elif current_args.checkpoint_frzn is not None:
            # Freeze all encoders only.
            for param_name in encoder_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict,
                    model_state_dict)

    # Load pretrained weights
    model.load_state_dict(model_state_dict)

    return model