Example #1
    # Fragment of an embedder's __init__; assumes module-level imports of
    # typing.Union, argparse.Namespace, torch, esm, and proteinseq_toks.
    def __init__(self,
                 device: Union[None, str, torch.device] = None,
                 **kwargs):
        super().__init__(device, **kwargs)

        alphabet = esm.Alphabet.from_dict(proteinseq_toks)
        # Load the checkpoint; map tensors to CPU when no GPU is available.
        if torch.cuda.is_available():
            model_data = torch.load(self._options["model_file"])
        else:
            model_data = torch.load(self._options["model_file"],
                                    map_location=torch.device('cpu'))

        # upgrade state dict: strip fairseq's leftover "decoder_" prefix from
        # hyper-parameter names and "decoder." prefix from weight keys
        pra = lambda s: ''.join(
            s.split('decoder_')[1:] if 'decoder' in s else s)
        prs = lambda s: ''.join(
            s.split('decoder.')[1:] if 'decoder' in s else s)
        model_args = {
            pra(arg[0]): arg[1]
            for arg in vars(model_data["args"]).items()
        }
        model_state = {
            prs(arg[0]): arg[1]
            for arg in model_data["model"].items()
        }
        model = esm.ProteinBertModel(Namespace(**model_args),
                                     len(alphabet),
                                     padding_idx=alphabet.padding_idx)
        model.load_state_dict(model_state)

        self._model = model.to(self._device)
        self._batch_converter = alphabet.get_batch_converter()
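To make the key renaming concrete: the pra/prs helpers only peel off a leading fairseq prefix and leave other keys untouched. A minimal, self-contained illustration (plain string manipulation, no ESM required):

# Sketch of the renaming used in the snippets on this page.
pra = lambda s: ''.join(s.split('decoder_')[1:] if 'decoder' in s else s)
prs = lambda s: ''.join(s.split('decoder.')[1:] if 'decoder' in s else s)

assert pra('decoder_embed_dim') == 'embed_dim'
assert prs('decoder.embed_tokens.weight') == 'embed_tokens.weight'
assert prs('layer_norm.weight') == 'layer_norm.weight'  # no prefix, unchanged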
Example #2
def load_model_and_alphabet_hub(model_name):
    alphabet = esm.Alphabet.from_dict(proteinseq_toks)

    url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
    if torch.cuda.is_available():
        model_data = torch.hub.load_state_dict_from_url(url, progress=False)
    else:
        model_data = torch.hub.load_state_dict_from_url(
            url, progress=False, map_location=torch.device('cpu'))

    # upgrade state dict
    pra = lambda s: ''.join(s.split('decoder_')[1:] if 'decoder' in s else s)
    prs = lambda s: ''.join(s.split('decoder.')[1:] if 'decoder' in s else s)
    model_args = {
        pra(arg[0]): arg[1]
        for arg in vars(model_data["args"]).items()
    }
    model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}

    model = esm.ProteinBertModel(Namespace(**model_args),
                                 len(alphabet),
                                 padding_idx=alphabet.padding_idx)
    model.load_state_dict(model_state)

    return model, alphabet
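A minimal usage sketch for the hub loader, assuming the same module-level imports as above (torch, esm, Namespace, proteinseq_toks); the checkpoint name is one of the published fair-esm models, and the representation-layer index must match the depth of the chosen model:

# Hedged sketch: download a checkpoint and embed one sequence.
model, alphabet = load_model_and_alphabet_hub("esm1_t34_670M_UR50S")
model.eval()

batch_converter = alphabet.get_batch_converter()
labels, strs, tokens = batch_converter([("seq0", "MKTAYIAKQRQISFVKSHFSRQ")])
with torch.no_grad():
    results = model(tokens, repr_layers=[34])  # layer 34 for the 34-layer model
per_residue = results["representations"][34]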
Example #3
def load_model_and_alphabet_local(model_location):
    alphabet = esm.Alphabet.from_dict(proteinseq_toks)

    # Map tensors to CPU when no GPU is available (GPU-saved checkpoints
    # otherwise fail to load on CPU-only machines).
    model_data = torch.load(
        model_location,
        map_location=None if torch.cuda.is_available() else torch.device('cpu'))

    # upgrade state dict
    pra = lambda s: ''.join(s.split('decoder_')[1:] if 'decoder' in s else s)
    prs = lambda s: ''.join(s.split('decoder.')[1:] if 'decoder' in s else s)
    model_args = {
        pra(arg[0]): arg[1]
        for arg in vars(model_data["args"]).items()
    }
    model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}

    model = esm.ProteinBertModel(Namespace(**model_args),
                                 len(alphabet),
                                 padding_idx=alphabet.padding_idx)
    model.load_state_dict(model_state)
    return model, alphabet
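Usage mirrors the hub loader; a minimal sketch, assuming a local checkpoint path (the path below is illustrative):

# Hedged sketch: load a checkpoint from disk instead of the hub.
model, alphabet = load_model_and_alphabet_local(
    "checkpoints/esm1_t34_670M_UR50S.pt")
model.eval()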
Example #4
def load_model_and_alphabet_core(model_data, regression_data=None):
    if regression_data is not None:
        model_data["model"].update(regression_data["model"])
    if model_data["args"].arch == 'roberta_large':
        alphabet = esm.RobertaAlphabet.from_dict(proteinseq_toks)
        # upgrade state dict: strip fairseq's "encoder_"/"encoder." and
        # "sentence_encoder." prefixes from argument and weight keys
        pra = lambda s: ''.join(
            s.split('encoder_')[1:] if 'encoder' in s else s)
        prs1 = lambda s: ''.join(
            s.split('encoder.')[1:] if 'encoder' in s else s)
        prs2 = lambda s: ''.join(
            s.split('sentence_encoder.')[1:] if 'sentence_encoder' in s else s)
        model_args = {
            pra(arg[0]): arg[1]
            for arg in vars(model_data["args"]).items()
        }
        model_state = {
            prs1(prs2(arg[0])): arg[1]
            for arg in model_data["model"].items()
        }
        model_state["embed_tokens.weight"][
            alphabet.mask_idx].zero_()  # For token drop
    elif model_data["args"].arch == 'protein_bert_base':
        alphabet = esm.Alphabet.from_dict(proteinseq_toks)

        # upgrade state dict
        pra = lambda s: ''.join(
            s.split('decoder_')[1:] if 'decoder' in s else s)
        prs = lambda s: ''.join(
            s.split('decoder.')[1:] if 'decoder' in s else s)
        model_args = {
            pra(arg[0]): arg[1]
            for arg in vars(model_data["args"]).items()
        }
        model_state = {
            prs(arg[0]): arg[1]
            for arg in model_data["model"].items()
        }
    else:
        raise ValueError("Unknown architecture selected")
    model = esm.ProteinBertModel(
        Namespace(**model_args),
        alphabet,
    )

    expected_keys = set(model.state_dict().keys())
    found_keys = set(model_state.keys())

    if regression_data is None:
        expected_missing = {
            "contact_head.regression.weight", "contact_head.regression.bias"
        }
        error_msgs = []
        missing = (expected_keys - found_keys) - expected_missing
        if missing:
            error_msgs.append(f"Missing key(s) in state_dict: {missing}.")
        unexpected = found_keys - expected_keys
        if unexpected:
            error_msgs.append(
                f"Unexpected key(s) in state_dict: {unexpected}.")

        if error_msgs:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)))
        if expected_missing - found_keys:
            warnings.warn(
                "Regression weights not found, predicting contacts will not produce correct results."
            )

    model.load_state_dict(model_state, strict=regression_data is not None)

    return model, alphabet
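The core loader expects model_data (and, for contact prediction, regression_data holding the contact-head weights) already deserialized; it also relies on module-level imports of warnings, torch, esm, argparse.Namespace, and proteinseq_toks. A hedged sketch of wiring local files into it, where the "-contact-regression" suffix follows the naming fair-esm uses for its published regression weights and the paths are illustrative:

# Hedged sketch: combine a main checkpoint with its contact-regression weights.
model_data = torch.load("esm1b_t33_650M_UR50S.pt", map_location="cpu")
regression_data = torch.load("esm1b_t33_650M_UR50S-contact-regression.pt",
                             map_location="cpu")
model, alphabet = load_model_and_alphabet_core(model_data, regression_data)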