Example #1
import torch
from torch.nn import Dropout, NLLLoss
from transformers import BertConfig, BertModel


class ExampleIntentBertModel(torch.nn.Module):
    def __init__(self,
                 model_name_or_path: str,
                 dropout: float,
                 num_intent_labels: int,
                 use_observers: bool = False):
        super().__init__()
        # Build BERT from the pretrained *config* only; the weights are
        # randomly initialized unless loaded elsewhere. output_attentions=True
        # exposes per-layer attention maps.
        self.bert_model = BertModel(
            BertConfig.from_pretrained(model_name_or_path,
                                       output_attentions=True))

        self.dropout = Dropout(dropout)
        self.num_intent_labels = num_intent_labels
        self.use_observers = use_observers

    def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
               token_type_ids: torch.Tensor):
        # Expand the [batch, seq] padding mask to the [batch, 1, seq, seq]
        # shape the encoder expects.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(
            2).repeat(1, 1, input_ids.size(1), 1)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.bert_model.parameters()).dtype)

        # Convert the {0, 1} mask to an additive mask: masked positions get a
        # large negative bias so they vanish under softmax.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Run the embedding layer and encoder manually so the custom
        # extended mask can be fed in directly.
        embedding_output = self.bert_model.embeddings(
            input_ids, position_ids=None, token_type_ids=token_type_ids)
        encoder_outputs = self.bert_model.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=[None] * self.bert_model.config.num_hidden_layers)

        sequence_output = encoder_outputs[0]

        if self.use_observers:
            # Observer pooling: average the final-layer states of the last
            # 20 token positions instead of using the [CLS]-based pooler.
            pooled_output = sequence_output[:, -20:].mean(dim=1)
        else:
            pooled_output = self.bert_model.pooler(sequence_output)

        return pooled_output

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor, intent_label: torch.Tensor,
                example_input: torch.Tensor, example_mask: torch.Tensor,
                example_token_types: torch.Tensor,
                example_intents: torch.Tensor):
        # Encode the support examples and the query utterances with the
        # same encoder.
        example_pooled_output = self.encode(input_ids=example_input,
                                            attention_mask=example_mask,
                                            token_type_ids=example_token_types)

        pooled_output = self.encode(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)

        pooled_output = self.dropout(pooled_output)
        # Softmax over query-example similarities: probs[i, j] is the
        # probability that query i matches support example j.
        probs = torch.softmax(pooled_output.mm(example_pooled_output.t()),
                              dim=-1)

        # Sum each example's probability into its intent label. The 1e-6
        # floor keeps the log below finite for intents with no examples;
        # allocating on probs.device avoids the hard-coded .cuda() call.
        intent_probs = 1e-6 + torch.zeros(
            probs.size(0), self.num_intent_labels,
            device=probs.device).scatter_add(
                -1,
                example_intents.unsqueeze(0).repeat(probs.size(0), 1), probs)

        # Compute the loss when gold labels are provided
        if intent_label is not None:
            loss_fct = NLLLoss()
            intent_lp = torch.log(intent_probs)
            intent_loss = loss_fct(intent_lp.view(-1, self.num_intent_labels),
                                   intent_label.type(torch.long))
        else:
            intent_loss = torch.tensor(0.0, device=intent_probs.device)

        return intent_probs, intent_loss
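
A minimal smoke test, assuming the standard bert-base-uncased checkpoint name and illustrative shapes (4 queries, 8 support examples, 5 intent labels, sequence length 32; none of these values come from the original):

import torch

model = ExampleIntentBertModel('bert-base-uncased',
                               dropout=0.1,
                               num_intent_labels=5,
                               use_observers=True)

batch, n_examples, seq_len = 4, 8, 32
input_ids = torch.randint(1, 1000, (batch, seq_len))
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
token_type_ids = torch.zeros(batch, seq_len, dtype=torch.long)
example_input = torch.randint(1, 1000, (n_examples, seq_len))
example_mask = torch.ones(n_examples, seq_len, dtype=torch.long)
example_token_types = torch.zeros(n_examples, seq_len, dtype=torch.long)
example_intents = torch.randint(0, 5, (n_examples,))   # intent id per example
intent_label = torch.randint(0, 5, (batch,))           # gold label per query

intent_probs, intent_loss = model(input_ids, attention_mask, token_type_ids,
                                  intent_label, example_input, example_mask,
                                  example_token_types, example_intents)
print(intent_probs.shape)  # torch.Size([4, 5])
print(intent_loss.item())  # scalar NLL loss

Because classification here is a softmax over query-example similarities rather than a learned output layer, new intents can be supported at inference time simply by supplying labeled examples for them, without changing the model's parameters.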