Ejemplo n.º 1
0
    def forward_qa(self, input_ids, token_type_ids, attention_mask,
                   start_positions, end_positions, global_step):
        """Compute the span-prediction QA loss plus a discriminator regularizer.

        The discriminator is pushed toward a uniform distribution over
        ``self.num_classes`` (KL or Wasserstein), weighted by ``dis_lambda``
        and optionally annealed with ``kl_coef(global_step)``.

        Args:
            input_ids / token_type_ids / attention_mask: standard BERT inputs.
            start_positions, end_positions: gold answer span indices.
            global_step: training step, used only for KL annealing.

        Returns:
            Scalar tensor: ``qa_loss + dis_lambda * disc_loss``.

        Raises:
            ValueError: if ``self.emb_disc`` or ``self.disc_loss`` is not a
                recognized setting (previously these fell through silently
                and crashed later with an opaque ``TypeError`` on ``None``).
        """
        sequence_output, pooled_output = self.bert(
            input_ids,
            token_type_ids,
            attention_mask,
            output_all_encoded_layers=False)

        # Choose the discriminator input according to the configured variant.
        if self.emb_disc == 'qa':
            q, a = self.get_avg_qa_emb(input_ids, sequence_output,
                                       start_positions, end_positions,
                                       token_type_ids)
            log_prob = self.discriminator(torch.cat((a, q), 1))
        elif self.emb_disc == 'pool':
            log_prob = self.discriminator(pooled_output)
        elif self.emb_disc == 'cls':
            cls_embedding = sequence_output[:, 0]  # [b, d]: [CLS] representation
            if self.concat:
                sep_embedding = self.get_sep_embedding(input_ids,
                                                       sequence_output)
                hidden = torch.cat([cls_embedding, sep_embedding], dim=1)
            else:
                hidden = cls_embedding
            # detach: in this variant the discriminator loss must not
            # backpropagate into the encoder.
            log_prob = self.discriminator(hidden.detach())
        else:
            raise ValueError(f"unsupported emb_disc: {self.emb_disc!r}")

        if self.disc_loss == 'wass':
            criterion = wasserstein_dist
        elif self.disc_loss == 'ce':
            # KLDivLoss expects log-probabilities as input and probabilities
            # as targets.
            criterion = nn.KLDivLoss(reduction="batchmean")
        else:
            raise ValueError(f"unsupported disc_loss: {self.disc_loss!r}")

        # Uniform target distribution over the discriminator classes.
        targets = torch.ones_like(log_prob) * (1 / self.num_classes)
        # BUG FIX: the original mutated self.dis_lambda in place
        # (self.dis_lambda = self.dis_lambda * kl_coef(global_step)), so the
        # coefficient compounded multiplicatively on every forward call.
        # Compute the annealed weight locally instead.
        dis_lambda = self.dis_lambda
        if self.anneal:
            dis_lambda = dis_lambda * kl_coef(global_step)
        kld = dis_lambda * criterion(log_prob, targets)

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # If we are on multi-GPU, split adds a dimension.
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # Sometimes the start/end positions are outside our model inputs;
        # clamp them to the sentinel index so CrossEntropyLoss ignores them.
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        qa_loss = (start_loss + end_loss) / 2
        total_loss = qa_loss + kld
        return total_loss
Ejemplo n.º 2
0
 def _compute_loss(self, batch, total_step=0):
     """Return (reconstruction loss, KL divergence, annealing coefficient).

     The KL term is the closed-form divergence between the approximate
     posterior N(mu, exp(log_var)) and a standard normal prior, averaged
     over the batch.
     """
     logits, mu, log_var = self.model(batch.orig, batch.para)
     batch_size, seq_len, _ = logits.size()
     target, _ = batch.para
     flat_logits = logits.view(batch_size * seq_len, -1)
     recon_loss = self.criterion(flat_logits, target.view(-1))
     # -0.5 * sum(1 + log_var - mu^2 - exp(log_var)), per sample, then mean.
     per_sample_kl = torch.sum(
         (log_var - log_var.exp() - mu.pow(2) + 1) * -0.5, dim=1)
     kl_loss = per_sample_kl.mean()
     coef = kl_coef(total_step)  # KL annealing weight for the caller
     return recon_loss, kl_loss, coef
Ejemplo n.º 3
0
 def compute_loss(self, batch, total_step=0):
     """Return (reconstruction loss, KL divergence, annealing coefficient).

     The KL term averages pairwise divergences between each variational
     distribution and its matching prior (paired by dict iteration order).
     """
     variational_params, prior_params, recon_logits = self.model(batch)
     batch_size, seq_len, _ = recon_logits.size()
     target, _ = batch.summ
     recon_loss = self.criterion(
         recon_logits.view(batch_size * seq_len, -1), target.view(-1))
     n_pairs = len(prior_params)
     kl_loss = sum(
         self.kl_div_two_normal(var_par, prior_par) / n_pairs
         for var_par, prior_par in zip(variational_params.values(),
                                       prior_params.values()))
     coef = kl_coef(total_step)  # KL annealing weight for the caller
     return recon_loss, kl_loss, coef
Ejemplo n.º 4
0
    def _compute_loss(self, batch, total_step=0):  # overriding
        """Return (reconstruction loss, KL divergence, annealing coef, BoW loss).

        Extends the base VAE loss with a bag-of-words auxiliary loss: the
        average negative log-probability that the BoW head assigns to each
        target token.
        """
        logits, mu, log_var, bow_logits = self.model(batch.orig, batch.para)
        batch_size, seq_len, _ = logits.size()
        target, _ = batch.para  # (B, L)
        num_tokens = (target != PAD_IDX).sum().float()
        # Zero out the PAD column of the BoW logits before the softmax.
        # NOTE(review): multiplying sets the PAD logit to 0 rather than -inf,
        # so PAD still participates in the softmax normalization — confirm
        # this is intended.
        mask = torch.ones_like(bow_logits)
        mask[:, PAD_IDX] = 0

        recon_loss = self.criterion(
            logits.view(batch_size * seq_len, -1), target.view(-1))
        per_sample_kl = torch.sum(
            (log_var - log_var.exp() - mu.pow(2) + 1) * -0.5, dim=1)
        kl_loss = per_sample_kl.mean()
        coef = kl_coef(total_step)  # KL annealing weight for the caller
        # Gather the BoW log-prob of every target token, normalize by the
        # number of non-PAD tokens.
        token_log_probs = (bow_logits * mask).log_softmax(dim=-1).gather(
            1, target)
        bow_loss = token_log_probs.sum() * -1 / num_tokens
        return recon_loss, kl_loss, coef, bow_loss
Ejemplo n.º 5
0
    def forward_qa(self, input_ids, token_type_ids, attention_mask,
                   start_positions, end_positions, global_step):
        """Compute the span-prediction QA loss plus a KL discriminator term.

        The discriminator reads the [CLS] embedding (optionally concatenated
        with a [SEP] embedding) and is regularized toward a uniform
        distribution over ``self.num_classes`` via KL divergence, weighted by
        ``dis_lambda`` and optionally annealed with ``kl_coef(global_step)``.

        Args:
            input_ids / token_type_ids / attention_mask: standard BERT inputs.
            start_positions, end_positions: gold answer span indices.
            global_step: training step, used only for KL annealing.

        Returns:
            Scalar tensor: ``qa_loss + dis_lambda * kld``.
        """
        sequence_output, _ = self.bert(input_ids,
                                       token_type_ids,
                                       attention_mask,
                                       output_all_encoded_layers=False)
        cls_embedding = sequence_output[:, 0]
        if self.concat:
            sep_embedding = self.get_sep_embedding(input_ids, sequence_output)
            hidden = torch.cat([cls_embedding, sep_embedding], dim=1)
        else:
            hidden = cls_embedding  # [b, d] : [CLS] representation
        # Not detached here: this variant lets the discriminator loss
        # backpropagate into the encoder.
        log_prob = self.discriminator(hidden)
        # Uniform target distribution over the discriminator classes.
        targets = torch.ones_like(log_prob) * (1 / self.num_classes)
        # As with NLLLoss, the input given is expected to contain
        # log-probabilities; the targets are given as probabilities.
        kl_criterion = nn.KLDivLoss(reduction="batchmean")
        # BUG FIX: the original mutated self.dis_lambda in place
        # (self.dis_lambda = self.dis_lambda * kl_coef(global_step)), so the
        # coefficient compounded multiplicatively on every forward call.
        # Compute the annealed weight locally instead.
        dis_lambda = self.dis_lambda
        if self.anneal:
            dis_lambda = dis_lambda * kl_coef(global_step)
        kld = dis_lambda * kl_criterion(log_prob, targets)

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # If we are on multi-GPU, split adds a dimension.
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # Sometimes the start/end positions are outside our model inputs;
        # clamp them to the sentinel index so CrossEntropyLoss ignores them.
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        qa_loss = (start_loss + end_loss) / 2
        total_loss = qa_loss + kld
        return total_loss