def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, d_model=512,
               d_ff=2048, h=8, dropout=0.1):
    "Helper: Construct a model from hyperparameters."
    enc_config = BertConfig(vocab_size=1,
                            hidden_size=d_model,
                            num_hidden_layers=N_enc,
                            num_attention_heads=h,
                            intermediate_size=d_ff,
                            hidden_dropout_prob=dropout,
                            attention_probs_dropout_prob=dropout,
                            max_position_embeddings=1,
                            type_vocab_size=1)
    dec_config = BertConfig(vocab_size=tgt_vocab,
                            hidden_size=d_model,
                            num_hidden_layers=N_dec,
                            num_attention_heads=h,
                            intermediate_size=d_ff,
                            hidden_dropout_prob=dropout,
                            attention_probs_dropout_prob=dropout,
                            max_position_embeddings=17,
                            # max_position_embeddings=51,
                            type_vocab_size=1,
                            is_decoder=True)

    encoder = BertModel(enc_config)

    def return_embeds(*args, **kwargs):
        return kwargs['inputs_embeds']

    # Bypass the encoder's embedding layer entirely: replace it with a
    # pass-through so pre-computed `inputs_embeds` reach the encoder stack unchanged.
    del encoder.embeddings
    encoder.embeddings = return_embeds

    decoder = BertModel(dec_config)

    model = EncoderDecoder(encoder, decoder, Generator(d_model, tgt_vocab))
    return model

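# --- Usage sketch (illustrative, not from the original repo): the embedding-bypass
# trick used in make_model, shown on a tiny standalone BertModel. The config sizes,
# tensor shapes, and the name `passthrough_embeds` are assumptions made here; only
# the pass-through of `inputs_embeds` mirrors the code above. It relies on
# BertModel.forward passing `inputs_embeds` to the embeddings module by keyword.
import torch
from transformers import BertConfig, BertModel

tiny_cfg = BertConfig(vocab_size=1, hidden_size=32, num_hidden_layers=2,
                      num_attention_heads=4, intermediate_size=64,
                      max_position_embeddings=1, type_vocab_size=1)
enc = BertModel(tiny_cfg)

def passthrough_embeds(*args, **kwargs):
    # Forward externally built embeddings straight to the encoder stack.
    return kwargs['inputs_embeds']

del enc.embeddings
enc.embeddings = passthrough_embeds

x = torch.randn(2, 5, tiny_cfg.hidden_size)    # pre-computed embeddings
mask = torch.ones(2, 5, dtype=torch.long)
hidden = enc(inputs_embeds=x, attention_mask=mask).last_hidden_state
print(hidden.shape)                            # torch.Size([2, 5, 32])
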
class SlotAttention(BertPreTrainedModel):
    def __init__(self, config, args):
        super(SlotAttention, self).__init__(config)
        self.num_labels = config.num_labels
        self.cls_lambda = args.cls_lambda
        self.ans_lambda = args.ans_lambda
        self.bert = BertModel(config)
        self.start_layer = nn.Linear(config.hidden_size, config.hidden_size)
        self.end_layer = nn.Linear(config.hidden_size, config.hidden_size)
        self.type_attention = Att_Layer(config)
        self.cls_layer = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                value_types=None,
                slot_input_ids=None,
                start_positions=None,
                end_positions=None):
        sequence_output, pool_output = self.bert(input_ids,
                                                 attention_mask=attention_mask,
                                                 token_type_ids=token_type_ids,
                                                 return_dict=False)
        # Average the embeddings of each slot's sub-tokens to get one vector per slot.
        slot_hidden = self.bert.embeddings(slot_input_ids[0][0]).mean(-2)
        type_att_output = self.type_attention(sequence_output, slot_hidden,
                                              attention_mask)
        type_logits = self.cls_layer(type_att_output)

        start_hidden = self.start_layer(slot_hidden)
        end_hidden = self.end_layer(slot_hidden)
        # Score every token position against each slot vector: (B, num_slots, T).
        start_logits = torch.matmul(start_hidden,
                                    sequence_output.permute(0, 2, 1))
        end_logits = torch.matmul(end_hidden, sequence_output.permute(0, 2, 1))
        outputs = (type_logits, start_logits, end_logits)

        if value_types is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            cls_loss = loss_fct(type_logits.view(-1, self.num_labels),
                                value_types.view(-1))
            start_loss = loss_fct(start_logits.reshape(-1, start_logits.size(-1)),
                                  start_positions.view(-1))
            end_loss = loss_fct(end_logits.reshape(-1, end_logits.size(-1)),
                                end_positions.view(-1))
            total_loss = (self.ans_lambda * (start_loss + end_loss)
                          + self.cls_lambda * cls_loss)
            outputs = (total_loss, cls_loss, start_loss, end_loss) + outputs
        return outputs

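# --- Shape sketch (illustrative, not from the original repo): how the span logits
# in SlotAttention.forward broadcast. B, T, H, and num_slots are made-up sizes.
import torch

B, T, H, num_slots = 2, 16, 8, 5
sequence_output = torch.randn(B, T, H)     # token encodings from BERT
start_hidden = torch.randn(num_slots, H)   # one projected vector per slot

# (num_slots, H) x (B, H, T) broadcasts to (B, num_slots, T):
# a start-position score for every slot over every token.
start_logits = torch.matmul(start_hidden, sequence_output.permute(0, 2, 1))
print(start_logits.shape)                  # torch.Size([2, 5, 16])
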
class BertEncoder(BertPreTrainedModel):
    def __init__(self, config):
        super(BertEncoder, self).__init__(config)
        self.bert = BertModel(config)

    def forward(self, input_ids, token_type_ids, attention_mask, label_id=None):
        # output_all_encoded_layers=False):
        # Keyword arguments (and return_dict=False) keep the call correct under the
        # current transformers signature, where attention_mask precedes token_type_ids.
        bert_encode, _ = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )  # output_all_encoded_layers=output_all_encoded_layers)
        bert_embeddings = self.bert.embeddings(input_ids, token_type_ids)
        return bert_encode, bert_embeddings

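# --- Usage sketch (assumed, not from the original repo): running BertEncoder on
# dummy inputs with a deliberately small, randomly initialised BertConfig so it is
# cheap to execute. The sizes below are illustrative.
import torch
from transformers import BertConfig

cfg = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4,
                 intermediate_size=64)
bert_encoder = BertEncoder(cfg)

input_ids = torch.randint(0, cfg.vocab_size, (2, 10))
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)

encoded, embedded = bert_encoder(input_ids, token_type_ids, attention_mask)
print(encoded.shape, embedded.shape)   # both torch.Size([2, 10, 32])
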
class RandomBert(nn.Module):
    def __init__(self, config, num_class):
        super(RandomBert, self).__init__()
        # Randomly initialised BERT: default BertConfig, no pretrained weights.
        self.bert = BertModel(BertConfig())
        self.drop_out = nn.Dropout(p=config['hidden_dropout_prob'])
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_class)

    def forward(self, input_ids, attention_mask, token_type, **kwargs):
        if kwargs.get('input_embeds') is not None:
            # Pre-computed embeddings bypass the embedding layer.
            out = self.bert(token_type_ids=token_type,
                            attention_mask=attention_mask,
                            inputs_embeds=kwargs['input_embeds'])[1]
        elif kwargs.get('return_embed') is not None:
            # Caller only wants the embedding-layer output.
            return self.bert.embeddings(input_ids=input_ids,
                                        token_type_ids=token_type)
        else:
            out = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type)[1]
        out = self.drop_out(out)
        out = self.classifier(out)
        return out

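# --- Usage sketch (assumed, not from the original repo): RandomBert only needs a
# dict-like config carrying 'hidden_dropout_prob'; the dropout value, batch shape,
# and class count below are illustrative.
import torch

model = RandomBert({'hidden_dropout_prob': 0.1}, num_class=3)
input_ids = torch.randint(0, model.bert.config.vocab_size, (2, 12))
attention_mask = torch.ones_like(input_ids)
token_type = torch.zeros_like(input_ids)

logits = model(input_ids, attention_mask, token_type)                      # (2, 3)
embeds = model(input_ids, attention_mask, token_type, return_embed=True)   # (2, 12, 768)
print(logits.shape, embeds.shape)
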
class BertFold(nn.Module):
    def __init__(
            self,
            pretrained: bool = True,
            gradient_checkpointing: bool = False,
    ):
        super().__init__()
        if pretrained:
            self.bert = BertModel.from_pretrained(
                'Rostlab/prot_bert_bfd',
                gradient_checkpointing=gradient_checkpointing,
            )
        else:
            conf = BertConfig.from_pretrained('Rostlab/prot_bert_bfd')
            self.bert = BertModel(conf)

        # noinspection PyUnresolvedReferences
        dim = self.bert.config.hidden_size

        self.evo_linear = nn.Linear(21, dim)

        self.decoder_dist = PairwiseDistanceDecoder(dim)
        # self.decoder_phi = ElementwiseAngleDecoder(dim, 2)
        # self.decoder_psi = ElementwiseAngleDecoder(dim, 2)

        self.evo_linear.apply(init_weights)
        self.decoder_dist.apply(init_weights)
        # self.decoder_phi.apply(init_weights)
        # self.decoder_psi.apply(init_weights)

        del self.bert.pooler

    def forward(
            self,
            inputs: ProteinNetBatch,
            targets: Optional[BertFoldTargets] = None,
    ) -> BertFoldOutput:
        x_emb = self.bert.embeddings(inputs['input_ids'])
        x_evo = self.evo_linear(inputs['evo'].type_as(x_emb))
        x = x_emb + x_evo

        extended_attention_mask = self.bert.get_extended_attention_mask(
            inputs['attention_mask'],
            inputs['input_ids'].shape,
            inputs['input_ids'].device,
        )
        x = self.bert.encoder.forward(
            x, attention_mask=extended_attention_mask)[0]
        # x = self.bert.forward(
        #     inputs['input_ids'],
        #     attention_mask=inputs['attention_mask'],
        # )[0]
        # x = torch.cat((
        #     x,
        #     inputs['evo'].type_as(x),
        # ), dim=-1)

        targets_dist = None if targets is None else targets.dist
        # targets_phi = None if targets is None else targets.phi
        # targets_psi = None if targets is None else targets.psi

        outs = [
            self.decoder_dist.forward(x, targets_dist),
            # self.decoder_phi.forward(x, targets_phi),
            # self.decoder_psi.forward(x, targets_psi),
        ]
        y_hat = tuple(x.y_hat for x in outs)

        if targets is None:
            return BertFoldOutput(y_hat=y_hat)

        loss = torch.stack([x.loss for x in outs]).sum()

        # Collect metrics
        with torch.no_grad():
            # Long range MAE metrics
            mae_l8_fn = MAEForSeq(contact_thre=8.)
            results = mae_l8_fn(
                inputs=y_hat[0][targets.dist.indices],
                targets=targets.dist.values,
                indices=targets.dist.indices,
            )
            if len(results) > 0:
                mae_l_8 = (results.mean().detach().item(), len(results))
            else:
                mae_l_8 = (0, 0)

            # Top L/5 precision metrics
            # top_l5_precision_fn = TopLNPrecision(n=5, contact_thre=8.)
            # results = top_l5_precision_fn(
            #     inputs=out_dist.y_hat[targets.dist.indices],
            #     targets=targets.dist.values,
            #     indices=targets.dist.indices,
            #     seq_lens=attention_mask.sum(-1) - 2,
            # )
            # if len(results) > 0:
            #     top_l5_precision = (results.mean().detach().item(), len(results))
            # else:
            #     top_l5_precision = (0, 0)

        return BertFoldOutput(
            y_hat=y_hat,
            loss=loss,
            loss_dist=outs[0].loss_and_cnt,
            # loss_phi=outs[1].loss_and_cnt,
            # loss_psi=outs[2].loss_and_cnt,
            mae_l_8=mae_l_8,
        )

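# --- Pattern sketch (illustrative, not from the original repo): the manual
# "embeddings + get_extended_attention_mask + encoder" pass used in BertFold.forward,
# shown on a small generic BertModel so extra per-token features can be injected
# before the transformer stack runs. All sizes and the fake `evo` features are
# assumptions; only the call pattern mirrors the code above.
import torch
from transformers import BertConfig, BertModel

bert = BertModel(BertConfig(hidden_size=32, num_hidden_layers=2,
                            num_attention_heads=4, intermediate_size=64))
input_ids = torch.randint(0, bert.config.vocab_size, (2, 10))
attention_mask = torch.ones_like(input_ids)

evo = torch.randn(2, 10, 21)                       # stand-in evolutionary features
evo_linear = torch.nn.Linear(21, bert.config.hidden_size)

x = bert.embeddings(input_ids) + evo_linear(evo)   # add features to token embeddings
ext_mask = bert.get_extended_attention_mask(attention_mask, input_ids.shape,
                                            input_ids.device)
hidden = bert.encoder(x, attention_mask=ext_mask)[0]
print(hidden.shape)                                # torch.Size([2, 10, 32])
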
class ExampleIntentBertModel(torch.nn.Module):
    def __init__(self,
                 model_name_or_path: str,
                 dropout: float,
                 num_intent_labels: int,
                 use_observers: bool = False):
        super(ExampleIntentBertModel, self).__init__()
        #self.bert_model = BertModel.from_pretrained(model_name_or_path)
        self.bert_model = BertModel(
            BertConfig.from_pretrained(model_name_or_path,
                                       output_attentions=True))

        self.dropout = Dropout(dropout)
        self.num_intent_labels = num_intent_labels
        self.use_observers = use_observers
        self.all_outputs = []

    def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
               token_type_ids: torch.Tensor):
        # Expand the 2-D padding mask to a per-token (batch, 1, seq, seq) attention map.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).repeat(
            1, 1, input_ids.size(1), 1)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.bert_model.parameters()).dtype)

        # Combine attention maps
        padding = (input_ids.unsqueeze(1) == 0).unsqueeze(-1)
        padding = padding.repeat(1, 1, 1, padding.size(-2))

        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.bert_model.embeddings(
            input_ids, position_ids=None, token_type_ids=token_type_ids)
        encoder_outputs = self.bert_model.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=[None] * self.bert_model.config.num_hidden_layers)

        if encoder_outputs[0].size(0) == 1:
            pass
            #self.all_outputs.append(torch.cat(encoder_outputs[1], dim=0).cpu())
            #self.all_outputs.append(encoder_outputs[0][:, -20:].cpu())

        sequence_output = encoder_outputs[0]
        if self.use_observers:
            # Pool over the last 20 ("observer") positions instead of using [CLS].
            pooled_output = sequence_output[:, -20:].mean(dim=1)
        else:
            pooled_output = self.bert_model.pooler(sequence_output)

        return pooled_output

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor, intent_label: torch.Tensor,
                example_input: torch.Tensor, example_mask: torch.Tensor,
                example_token_types: torch.Tensor,
                example_intents: torch.Tensor):
        example_pooled_output = self.encode(input_ids=example_input,
                                            attention_mask=example_mask,
                                            token_type_ids=example_token_types)

        pooled_output = self.encode(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)

        pooled_output = self.dropout(pooled_output)
        # Similarity of each query utterance to every support example.
        probs = torch.softmax(pooled_output.mm(example_pooled_output.t()),
                              dim=-1)

        # Pool example probabilities into their intent labels.
        intent_probs = 1e-6 + torch.zeros(
            probs.size(0), self.num_intent_labels).cuda().scatter_add(
                -1, example_intents.unsqueeze(0).repeat(probs.size(0), 1),
                probs)

        # Compute losses if labels provided
        if intent_label is not None:
            loss_fct = NLLLoss()
            intent_lp = torch.log(intent_probs)
            intent_loss = loss_fct(intent_lp.view(-1, self.num_intent_labels),
                                   intent_label.type(torch.long))
        else:
            intent_loss = torch.tensor(0)

        return intent_probs, intent_loss

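# --- Shape sketch (illustrative, not from the original repo): how the scatter_add
# in ExampleIntentBertModel.forward pools example-level similarities into
# intent-level probabilities. The sizes, the made-up label assignment, and the use
# of CPU tensors (the model itself does this on the GPU) are assumptions.
import torch

batch, num_examples, num_intent_labels = 2, 4, 3
probs = torch.softmax(torch.randn(batch, num_examples), dim=-1)   # query-vs-example similarity
example_intents = torch.tensor([0, 2, 1, 2])                      # intent label of each example

intent_probs = 1e-6 + torch.zeros(batch, num_intent_labels).scatter_add(
    -1, example_intents.unsqueeze(0).repeat(batch, 1), probs)
print(intent_probs)   # each row sums to ~1: example probabilities pooled by intent
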