Beispiel #1
0
 def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
            d_model=512, d_ff=2048, h=8, dropout=0.1, max_tgt_positions=17):
     """Helper: construct an encoder-decoder model from hyperparameters.

     The encoder is a BERT stack whose embedding layer is replaced by an
     identity on ``inputs_embeds``, so it must be fed pre-computed embeddings
     (width ``d_model``) rather than token ids. The decoder is a BERT stack
     configured with ``is_decoder=True`` over a ``tgt_vocab``-sized vocabulary.

     Args:
         src_vocab: Unused here (the encoder consumes embeddings directly);
             kept for interface compatibility.
         tgt_vocab: Target vocabulary size, used by the decoder embeddings
             and passed to ``Generator``.
         N_enc: Number of encoder layers.
         N_dec: Number of decoder layers.
         d_model: Hidden size shared by encoder, decoder and ``Generator``.
         d_ff: Intermediate (feed-forward) size.
         h: Number of attention heads.
         dropout: Dropout for hidden states and attention probabilities.
         max_tgt_positions: Size of the decoder's position-embedding table,
             i.e. the longest decoder sequence supported. Defaults to 17,
             the value previously hard-coded here.

     Returns:
         ``EncoderDecoder(encoder, decoder, Generator(d_model, tgt_vocab))``.
     """
     # Vocab/position/type sizes are minimal on purpose: the encoder's
     # embedding layer is removed below and replaced by an identity.
     enc_config = BertConfig(vocab_size=1,
                             hidden_size=d_model,
                             num_hidden_layers=N_enc,
                             num_attention_heads=h,
                             intermediate_size=d_ff,
                             hidden_dropout_prob=dropout,
                             attention_probs_dropout_prob=dropout,
                             max_position_embeddings=1,
                             type_vocab_size=1)
     # NOTE(review): recent `transformers` versions also require
     # add_cross_attention=True for a decoder to attend to encoder states;
     # confirm against the installed version and how EncoderDecoder calls it.
     dec_config = BertConfig(vocab_size=tgt_vocab,
                             hidden_size=d_model,
                             num_hidden_layers=N_dec,
                             num_attention_heads=h,
                             intermediate_size=d_ff,
                             hidden_dropout_prob=dropout,
                             attention_probs_dropout_prob=dropout,
                             max_position_embeddings=max_tgt_positions,
                             type_vocab_size=1,
                             is_decoder=True)
     encoder = BertModel(enc_config)

     def return_embeds(*args, **kwargs):
         # Identity "embedding layer": hand back the caller-supplied
         # inputs_embeds (relies on BertModel passing it by keyword).
         return kwargs['inputs_embeds']

     # The registered submodule must be removed first: nn.Module refuses to
     # assign a non-Module value over an existing child module.
     del encoder.embeddings
     encoder.embeddings = return_embeds

     decoder = BertModel(dec_config)
     model = EncoderDecoder(
         encoder,
         decoder,
         Generator(d_model, tgt_vocab))
     return model
Beispiel #2
0
class SlotAttention(BertPreTrainedModel):
    """Slot-conditioned value-type classifier and span extractor over BERT.

    A slot description is embedded with BERT's own embedding layer and
    averaged into one vector. That vector drives (a) an ``Att_Layer`` plus a
    linear classifier that predicts the value type, and (b) two linear
    projections whose dot products with every token's hidden state give
    start/end span logits.
    """

    def __init__(self, config, args):
        super().__init__(config)
        # Label count and loss weights.
        self.num_labels = config.num_labels
        self.cls_lambda = args.cls_lambda
        self.ans_lambda = args.ans_lambda

        # Sub-modules, in a fixed creation order (it determines parameter
        # registration order and RNG use during initialisation).
        self.bert = BertModel(config)
        self.start_layer = nn.Linear(config.hidden_size, config.hidden_size)
        self.end_layer = nn.Linear(config.hidden_size, config.hidden_size)
        self.type_attention = Att_Layer(config)
        self.cls_layer = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                value_types=None,
                slot_input_ids=None,
                start_positions=None,
                end_positions=None):
        """Score value types and answer spans for one slot.

        Only ``slot_input_ids[0][0]`` is embedded, so one slot description is
        used for the whole batch (NOTE(review): presumably batches are grouped
        by slot - confirm with the data loader).

        Returns:
            ``(type_logits, start_logits, end_logits)``. When ``value_types``
            is given, the tuple is prefixed by ``(total_loss, cls_loss,
            start_loss, end_loss)``; ``start_positions`` and ``end_positions``
            must then be given as well. Targets equal to -1 are ignored.
        """
        tokens, _pooled = self.bert(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    return_dict=False)
        # Hidden states with the last two axes swapped, shared by both span heads.
        tokens_t = tokens.permute(0, 2, 1)

        # Slot vector: slot-token embeddings averaged over dim -2.
        slot_vec = self.bert.embeddings(slot_input_ids[0][0]).mean(-2)

        type_logits = self.cls_layer(
            self.type_attention(tokens, slot_vec, attention_mask))

        start_logits = torch.matmul(self.start_layer(slot_vec), tokens_t)
        end_logits = torch.matmul(self.end_layer(slot_vec), tokens_t)
        logits = (type_logits, start_logits, end_logits)

        if value_types is None:
            return logits

        criterion = CrossEntropyLoss(ignore_index=-1)
        cls_loss = criterion(type_logits.view(-1, self.num_labels),
                             value_types.view(-1))
        n_start = start_logits.size(-1)
        start_loss = criterion(start_logits.reshape(-1, n_start),
                               start_positions.view(-1))
        n_end = end_logits.size(-1)
        end_loss = criterion(end_logits.reshape(-1, n_end),
                             end_positions.view(-1))

        span_loss = start_loss + end_loss
        total_loss = self.ans_lambda * span_loss + self.cls_lambda * cls_loss
        return (total_loss, cls_loss, start_loss, end_loss) + logits
Beispiel #3
0
class BertEncoder(BertPreTrainedModel):
    """Wrap a ``BertModel`` and return both its encoder output and the
    output of its embedding layer for the same input."""

    def __init__(self, config):
        super(BertEncoder, self).__init__(config)
        self.bert = BertModel(config)

    def forward(self,
                input_ids,
                token_type_ids,
                attention_mask,
                label_id=None):
        """Encode a batch with BERT.

        Args:
            input_ids: Token ids.
            token_type_ids: Segment ids.
            attention_mask: Attention (padding) mask.
            label_id: Unused; accepted for call-site compatibility.

        Returns:
            ``(bert_encode, bert_embeddings)``: the first element of the
            ``BertModel`` output, and the output of ``self.bert.embeddings``
            for the same ``input_ids`` / ``token_type_ids``.
        """
        # Pass the masks by keyword. In `transformers`, BertModel.forward's
        # positional order is (input_ids, attention_mask, token_type_ids),
        # so the former positional call swapped the two masks. Keywords are
        # correct under both `transformers` and pytorch_pretrained_bert.
        # Index [0] rather than tuple-unpacking: unpacking a ModelOutput
        # yields its key names, not tensors.
        bert_encode = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )[0]
        bert_embeddings = self.bert.embeddings(input_ids,
                                               token_type_ids=token_type_ids)

        return bert_encode, bert_embeddings
Beispiel #4
0
class RandomBert(nn.Module):
    """Classifier head on a randomly initialised, default-size BERT.

    The backbone is built from ``BertConfig()`` defaults. Of ``config``,
    only ``config['hidden_dropout_prob']`` is read, for the head's dropout.
    """

    def __init__(self, config, num_class):
        super(RandomBert, self).__init__()
        self.bert = BertModel(BertConfig())

        self.drop_out = nn.Dropout(p=config['hidden_dropout_prob'])
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_class)

    def forward(self, input_ids, attention_mask, token_type, **kwargs):
        """Classify a batch, or return BERT's embedding-layer output.

        Args:
            input_ids: Token ids (ignored when ``input_embeds`` is given).
            attention_mask: Attention mask passed to BERT.
            token_type: Segment ids passed to BERT as ``token_type_ids``.
            **kwargs:
                input_embeds: Optional pre-computed input embeddings; when
                    not None they are fed to BERT instead of ``input_ids``.
                return_embed: When not None (and ``input_embeds`` is absent),
                    return ``self.bert.embeddings(...)`` for ``input_ids``
                    instead of classifying.

        Returns:
            Classifier logits computed from BERT's pooled output, or the
            embedding-layer output when ``return_embed`` is requested.
        """
        # Dispatch explicitly. The former bare `except:` relied on BertModel
        # raising when inputs_embeds was None, and it also swallowed genuine
        # errors (and KeyboardInterrupt) from the embeddings path.
        input_embeds = kwargs.get('input_embeds')
        if input_embeds is not None:
            out = self.bert(token_type_ids=token_type,
                            attention_mask=attention_mask,
                            inputs_embeds=input_embeds)[1]
        elif kwargs.get('return_embed') is None:
            out = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type)[1]
        else:
            return self.bert.embeddings(input_ids=input_ids,
                                        token_type_ids=token_type)

        out = self.drop_out(out)
        return self.classifier(out)
Beispiel #5
0
class BertFold(nn.Module):
    """ProtBert-based model that predicts pairwise distances.

    The BERT token embeddings of ``inputs['input_ids']`` are summed with a
    linear projection of the 21-dimensional ``inputs['evo']`` features
    (presumably evolutionary profiles - confirm). The sum runs through the
    BERT encoder stack and is decoded by a ``PairwiseDistanceDecoder``. The
    BERT pooler is deleted because only per-token encoder states are used.
    """

    def __init__(
        self,
        pretrained: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """Build the model.

        Args:
            pretrained: Load ``Rostlab/prot_bert_bfd`` weights if True;
                otherwise build a randomly initialised model from its config.
            gradient_checkpointing: Gradient-checkpointing flag handed to
                BERT's config. Honoured on both the pretrained and the
                from-scratch paths.
        """
        super().__init__()
        if pretrained:
            self.bert = BertModel.from_pretrained(
                'Rostlab/prot_bert_bfd',
                gradient_checkpointing=gradient_checkpointing,
            )
        else:
            # Forward the flag here too; it used to be silently dropped
            # when building the model from scratch.
            conf = BertConfig.from_pretrained(
                'Rostlab/prot_bert_bfd',
                gradient_checkpointing=gradient_checkpointing,
            )
            self.bert = BertModel(conf)

        # noinspection PyUnresolvedReferences
        dim = self.bert.config.hidden_size

        # Projects the 21 per-position `evo` features to BERT's hidden size.
        self.evo_linear = nn.Linear(21, dim)

        self.decoder_dist = PairwiseDistanceDecoder(dim)
        # NOTE: phi/psi angle decoders (ElementwiseAngleDecoder(dim, 2)) are
        # disabled; re-add them here, in the init below, and in `forward`.

        self.evo_linear.apply(init_weights)
        self.decoder_dist.apply(init_weights)

        # `forward` never uses the pooler.
        del self.bert.pooler

    def forward(
        self,
        inputs: ProteinNetBatch,
        targets: Optional[BertFoldTargets] = None,
    ) -> BertFoldOutput:
        """Predict distances; with targets, also compute loss and metrics.

        Args:
            inputs: Batch mapping providing 'input_ids', 'attention_mask'
                and 'evo'.
            targets: Optional targets; ``targets.dist`` feeds the distance
                decoder's loss and the MAE metric.

        Returns:
            ``BertFoldOutput(y_hat=...)`` when ``targets`` is None, otherwise
            one that also carries ``loss``, ``loss_dist`` and ``mae_l_8``
            (``mae_l_8`` is ``(mean, count)``, or ``(0, 0)`` when empty).
        """
        input_ids = inputs['input_ids']
        x_emb = self.bert.embeddings(input_ids)
        x_evo = self.evo_linear(inputs['evo'].type_as(x_emb))
        x = x_emb + x_evo
        extended_attention_mask = self.bert.get_extended_attention_mask(
            inputs['attention_mask'],
            input_ids.shape,
            input_ids.device,
        )
        # Call the module (not `.forward`) so registered hooks still run.
        x = self.bert.encoder(x, attention_mask=extended_attention_mask)[0]

        targets_dist = None if targets is None else targets.dist

        # One entry per active decoder (phi/psi decoders are disabled).
        outs = [
            self.decoder_dist(x, targets_dist),
        ]

        y_hat = tuple(out.y_hat for out in outs)

        if targets is None:
            return BertFoldOutput(y_hat=y_hat)

        loss = torch.stack([out.loss for out in outs]).sum()

        # Metrics are for reporting only; keep them out of the autograd graph.
        with torch.no_grad():
            # Long-range MAE metric (contact threshold 8).
            mae_l8_fn = MAEForSeq(contact_thre=8.)
            results = mae_l8_fn(
                inputs=y_hat[0][targets.dist.indices],
                targets=targets.dist.values,
                indices=targets.dist.indices,
            )
            if len(results) > 0:
                mae_l_8 = (results.mean().detach().item(), len(results))
            else:
                mae_l_8 = (0, 0)
            # NOTE: a top-L/5 precision metric (TopLNPrecision(n=5,
            # contact_thre=8.)) was sketched here and is currently disabled.

        return BertFoldOutput(
            y_hat=y_hat,
            loss=loss,
            loss_dist=outs[0].loss_and_cnt,
            mae_l_8=mae_l_8,
        )
Beispiel #6
0
class ExampleIntentBertModel(torch.nn.Module):
    """Example-based intent classifier on a BERT built from a config.

    Queries and labelled example utterances are encoded by the same BERT.
    Each query's pooled vector is dot-producted with every example's pooled
    vector and softmax-normalised, and the weights are summed per example
    intent label to give intent probabilities.
    """

    def __init__(self,
                 model_name_or_path: str,
                 dropout: float,
                 num_intent_labels: int,
                 use_observers: bool = False):
        """Build the model.

        Args:
            model_name_or_path: Source of the BERT *config* only; the model
                is constructed from the config, so pretrained weights are
                not loaded.
            dropout: Dropout probability for the query's pooled vector.
            num_intent_labels: Number of intent classes.
            use_observers: If True, pool by averaging the last 20 sequence
                positions instead of using BERT's pooler.
        """
        super(ExampleIntentBertModel, self).__init__()
        self.bert_model = BertModel(
            BertConfig.from_pretrained(model_name_or_path,
                                       output_attentions=True))

        self.dropout = Dropout(dropout)
        self.num_intent_labels = num_intent_labels
        self.use_observers = use_observers
        # Kept for backward compatibility; nothing in this class appends to it.
        self.all_outputs = []

    def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
               token_type_ids: torch.Tensor):
        """Encode a batch and return one pooled vector per sequence.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len) mask; 1 = attend, 0 = masked.
            token_type_ids: Segment ids.

        Returns:
            Pooled representation, one row per sequence.
        """
        # Expand the 2-D mask to (batch, 1, seq_len, seq_len), then make it
        # additive: 0 where attended, -10000 where masked.
        seq_len = input_ids.size(1)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(
            2).repeat(1, 1, seq_len, 1)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.bert_model.parameters()).dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.bert_model.embeddings(
            input_ids, position_ids=None, token_type_ids=token_type_ids)
        encoder_outputs = self.bert_model.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=[None] * self.bert_model.config.num_hidden_layers)
        sequence_output = encoder_outputs[0]

        if self.use_observers:
            # Mean of the last 20 positions (presumably observer tokens
            # appended to every input - confirm with the data pipeline).
            return sequence_output[:, -20:].mean(dim=1)
        return self.bert_model.pooler(sequence_output)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor, intent_label: torch.Tensor,
                example_input: torch.Tensor, example_mask: torch.Tensor,
                example_token_types: torch.Tensor,
                example_intents: torch.Tensor):
        """Classify queries by soft voting over labelled examples.

        Args:
            input_ids, attention_mask, token_type_ids: Query batch.
            intent_label: Gold intent per query, or None to skip the loss.
            example_input, example_mask, example_token_types: Example batch.
            example_intents: 1-D intent label per example.

        Returns:
            ``(intent_probs, intent_loss)``: per-query intent probabilities
            of shape (num_queries, num_intent_labels), offset by 1e-6, and
            the NLL loss (``torch.tensor(0)`` when ``intent_label`` is None).
        """
        example_pooled_output = self.encode(input_ids=example_input,
                                            attention_mask=example_mask,
                                            token_type_ids=example_token_types)

        pooled_output = self.encode(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)

        pooled_output = self.dropout(pooled_output)
        probs = torch.softmax(pooled_output.mm(example_pooled_output.t()),
                              dim=-1)

        # Sum the softmax weight of all examples that share an intent label.
        # Allocate on probs' device/dtype (previously a hard-coded .cuda(),
        # which failed on CPU and put the buffer on the wrong GPU).
        per_intent = torch.zeros(probs.size(0), self.num_intent_labels,
                                 device=probs.device, dtype=probs.dtype)
        index = example_intents.unsqueeze(0).repeat(probs.size(0), 1)
        # The 1e-6 offset keeps log() finite for intents with zero weight.
        intent_probs = 1e-6 + per_intent.scatter_add(-1, index, probs)

        # Compute losses if labels provided
        if intent_label is not None:
            loss_fct = NLLLoss()
            intent_lp = torch.log(intent_probs)
            intent_loss = loss_fct(intent_lp.view(-1, self.num_intent_labels),
                                   intent_label.type(torch.long))
        else:
            intent_loss = torch.tensor(0)

        return intent_probs, intent_loss