import warnings
from argparse import Namespace
from typing import Union

import torch

import esm
from esm.constants import proteinseq_toks


def __init__(self, device: Union[None, str, torch.device] = None, **kwargs):
    super().__init__(device, **kwargs)
    alphabet = esm.Alphabet.from_dict(proteinseq_toks)
    if torch.cuda.is_available():
        model_data = torch.load(self._options["model_file"])
    else:
        # Checkpoints saved on GPU must be remapped on CPU-only machines.
        model_data = torch.load(self._options["model_file"],
                                map_location=torch.device('cpu'))
    # Upgrade the state dict: legacy checkpoints prefix hyperparameter and
    # weight names with 'decoder_'/'decoder.', which the current layout drops.
    pra = lambda s: ''.join(s.split('decoder_')[1:] if 'decoder' in s else s)
    prs = lambda s: ''.join(s.split('decoder.')[1:] if 'decoder' in s else s)
    model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
    model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}
    model = esm.ProteinBertModel(Namespace(**model_args),
                                 len(alphabet),
                                 padding_idx=alphabet.padding_idx)
    model.load_state_dict(model_state)
    self._model = model.to(self._device)
    self._batch_converter = alphabet.get_batch_converter()
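# Downstream usage sketch for the attributes set above (the sequence and the
# repr_layers index are hypothetical; the index must match the loaded
# checkpoint, e.g. 34 for the 34-layer ESM-1 model):
#
#     labels, strs, tokens = self._batch_converter([("seq0", "MKTVRQERLK")])
#     with torch.no_grad():
#         out = self._model(tokens.to(self._device), repr_layers=[34])
#     per_residue = out["representations"][34]  # (batch, tokens, embed_dim)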
def load_model_and_alphabet_hub(model_name):
    alphabet = esm.Alphabet.from_dict(proteinseq_toks)
    url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
    if torch.cuda.is_available():
        model_data = torch.hub.load_state_dict_from_url(url, progress=False)
    else:
        # Checkpoints saved on GPU must be remapped on CPU-only machines.
        model_data = torch.hub.load_state_dict_from_url(
            url, progress=False, map_location=torch.device('cpu'))
    # Upgrade the state dict: strip the legacy 'decoder_'/'decoder.' prefixes.
    pra = lambda s: ''.join(s.split('decoder_')[1:] if 'decoder' in s else s)
    prs = lambda s: ''.join(s.split('decoder.')[1:] if 'decoder' in s else s)
    model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
    model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}
    model = esm.ProteinBertModel(Namespace(**model_args),
                                 len(alphabet),
                                 padding_idx=alphabet.padding_idx)
    model.load_state_dict(model_state)
    return model, alphabet
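# Usage sketch for the hub loader. "esm1_t34_670M_UR50S" is one of the
# published ESM-1 checkpoint names; any other released name should work. The
# first call downloads the weights into the torch hub cache, so it needs
# network access:
#
#     model, alphabet = load_model_and_alphabet_hub("esm1_t34_670M_UR50S")
#     batch_converter = alphabet.get_batch_converter()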
def load_model_and_alphabet_local(model_location):
    alphabet = esm.Alphabet.from_dict(proteinseq_toks)
    if torch.cuda.is_available():
        model_data = torch.load(model_location)
    else:
        # Mirror the hub loader: remap GPU-saved checkpoints on CPU-only hosts.
        model_data = torch.load(model_location,
                                map_location=torch.device('cpu'))
    # Upgrade the state dict: strip the legacy 'decoder_'/'decoder.' prefixes.
    pra = lambda s: ''.join(s.split('decoder_')[1:] if 'decoder' in s else s)
    prs = lambda s: ''.join(s.split('decoder.')[1:] if 'decoder' in s else s)
    model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
    model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}
    model = esm.ProteinBertModel(Namespace(**model_args),
                                 len(alphabet),
                                 padding_idx=alphabet.padding_idx)
    model.load_state_dict(model_state)
    return model, alphabet
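# What the "upgrade state dict" renaming in the loaders above does, shown on
# illustrative keys (example names only, not an exhaustive list):
#
#     prs("decoder.layers.0.fc1.weight")  ->  "layers.0.fc1.weight"
#     pra("decoder_embed_dim")            ->  "embed_dim"
#     prs("embed_tokens.weight")          ->  "embed_tokens.weight"  (unchanged)
#
# Legacy fairseq-style checkpoints prefixed module and hyperparameter names
# with 'decoder.'/'decoder_'; the current module layout drops that prefix.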
def load_model_and_alphabet_core(model_data, regression_data=None):
    if regression_data is not None:
        model_data["model"].update(regression_data["model"])

    if model_data["args"].arch == 'roberta_large':
        alphabet = esm.RobertaAlphabet.from_dict(proteinseq_toks)
        # Upgrade the state dict: strip the legacy 'encoder_'/'encoder.'
        # and 'sentence_encoder.' prefixes.
        pra = lambda s: ''.join(s.split('encoder_')[1:] if 'encoder' in s else s)
        prs1 = lambda s: ''.join(s.split('encoder.')[1:] if 'encoder' in s else s)
        prs2 = lambda s: ''.join(
            s.split('sentence_encoder.')[1:] if 'sentence_encoder' in s else s)
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs1(prs2(arg[0])): arg[1]
                       for arg in model_data["model"].items()}
        # Zero the mask-token embedding (for token drop).
        model_state["embed_tokens.weight"][alphabet.mask_idx].zero_()
    elif model_data["args"].arch == 'protein_bert_base':
        alphabet = esm.Alphabet.from_dict(proteinseq_toks)
        # Upgrade the state dict: strip the legacy 'decoder_'/'decoder.' prefixes.
        pra = lambda s: ''.join(s.split('decoder_')[1:] if 'decoder' in s else s)
        prs = lambda s: ''.join(s.split('decoder.')[1:] if 'decoder' in s else s)
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}
    else:
        raise ValueError("Unknown architecture selected")

    model = esm.ProteinBertModel(Namespace(**model_args), alphabet)

    expected_keys = set(model.state_dict().keys())
    found_keys = set(model_state.keys())

    if regression_data is None:
        # The contact-prediction head ships in a separate regression file,
        # so its two parameters are allowed to be missing here.
        expected_missing = {
            "contact_head.regression.weight",
            "contact_head.regression.bias",
        }
        error_msgs = []
        missing = (expected_keys - found_keys) - expected_missing
        if missing:
            error_msgs.append(f"Missing key(s) in state_dict: {missing}.")
        unexpected = found_keys - expected_keys
        if unexpected:
            error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected}.")
        if error_msgs:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)))
        if expected_missing - found_keys:
            warnings.warn(
                "Regression weights not found, predicting contacts will not produce correct results."
            )

    model.load_state_dict(model_state, strict=regression_data is not None)
    return model, alphabet
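if __name__ == "__main__":
    # Smoke-test sketch, not part of the loader API. The file names below are
    # assumptions: they follow the fair-esm naming scheme ("<model>.pt" plus a
    # separate "<model>-contact-regression.pt" for the contact head) and must
    # point at files you actually have on disk.
    model_data = torch.load("esm1b_t33_650M_UR50S.pt", map_location="cpu")
    regression_data = torch.load("esm1b_t33_650M_UR50S-contact-regression.pt",
                                 map_location="cpu")
    model, alphabet = load_model_and_alphabet_core(model_data, regression_data)

    batch_converter = alphabet.get_batch_converter()
    labels, strs, tokens = batch_converter([("seq0", "MKTVRQERLK")])
    model.eval()
    with torch.no_grad():
        # Layer 33 is the final layer of the 33-layer ESM-1b checkpoint;
        # adjust the index for other checkpoints.
        out = model(tokens, repr_layers=[33])
    print(out["representations"][33].shape)  # (batch, tokens, embed_dim)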