Example #1
    def _compute_loss(self, batch, output, target, copy_attn, align):
        """
        Compute the loss. The args must match self._make_shard_state().
        Args:
            batch: the current batch.
            output: the predicted output from the model.
            target: the validation target to compare output with.
            copy_attn: the copy attention value.
            align: the copy alignment info.
        """
        target = target.view(-1)
        align = align.view(-1)
        scores = self.generator(self._bottle(output),
                                self._bottle(copy_attn),
                                batch.src_map)
        loss = self.criterion(scores, align, target)
        scores_data = scores.data.clone()
        
        if self.data_type == 'text':
            scores_data = inputters.TextDataset.collapse_copy_scores(
                self._unbottle(scores_data, batch.batch_size),
                batch, self.tgt_vocab, batch.dataset.src_vocabs)
        else:  # amr
            scores_data = inputters.AMRDataset.collapse_copy_scores(
                self._unbottle(scores_data, batch.batch_size),
                batch, self.tgt_vocab, batch.dataset.src_vocabs)
        scores_data = self._bottle(scores_data)

        # Correct target copy token instead of <unk>
        # tgt[i] = align[i] + len(tgt_vocab)
        # for i such that tgt[i] == 0 and align[i] != 0
        target_data = target.data.clone()
        correct_mask = target_data.eq(0) * align.data.ne(0)
        correct_copy = (align.data + len(self.tgt_vocab)) * correct_mask.long()
        target_data = target_data + correct_copy

        # Compute sum of perplexities for stats
        loss_data = loss.sum().data.clone()
        stats = self._stats(loss_data, scores_data, target_data)

        if self.normalize_by_length:
            # Compute Loss as NLL divided by seq length
            # Compute Sequence Lengths
            pad_ix = batch.dataset.fields['tgt'].vocab.stoi[inputters.PAD_WORD]
            tgt_lens = batch.tgt.ne(pad_ix).float().sum(0)
            # Compute Total Loss per sequence in batch
            loss = loss.view(-1, batch.batch_size).sum(0)
            # Divide by length of each sequence and sum
            loss = torch.div(loss, tgt_lens).sum()
        else:
            loss = loss.sum()

        return loss, stats
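The `# Correct target copy token` block above is easy to misread, so here is a standalone sketch (toy tensors, not from the original source) of how `<unk>` targets that align to a source token are shifted past the end of the target vocabulary, into the extended (vocab + source) score space:

import torch

# Toy sketch: vocabulary of size 4 where index 0 is <unk>. Wherever the
# target is <unk> but the alignment points at a source token, the target
# index becomes align[i] + len(tgt_vocab).
tgt_vocab_size = 4
target = torch.tensor([2, 0, 3, 0])   # positions 1 and 3 are <unk>
align = torch.tensor([0, 1, 0, 2])    # nonzero where a source copy exists

correct_mask = target.eq(0) * align.ne(0)
correct_copy = (align + tgt_vocab_size) * correct_mask.long()
print(target + correct_copy)  # tensor([2, 5, 3, 6])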
Example #2
    def _compute_loss(self, batch, output, target, copy_attn, align, click_score=None, click_target=None):
        """
        Compute the loss. The args must match self._make_shard_state().
        Args:
            batch: the current batch.
            output: the predicted output from the model.
            target: the validation target to compare output with.
            copy_attn: the copy attention value.
            align: the copy alignment info.
            click_score: optional predicted click scores for the session task.
            click_target: optional ground-truth click labels.
        """
        target = target.view(-1)
        align = align.view(-1)
        scores = self.generator(self._bottle(output),
                                self._bottle(copy_attn),
                                batch.src_map)
        loss = self.criterion(scores, align, target)
        scores_data = scores.data.clone()
        scores_data = inputters.TextDataset.collapse_copy_scores(
            self._unbottle(scores_data, batch.batch_size),
            batch, self.tgt_vocab, batch.dataset.src_vocabs)
        scores_data = self._bottle(scores_data)

        # Correct target copy token instead of <unk>
        # tgt[i] = align[i] + len(tgt_vocab)
        # for i such that tgt[i] == 0 and align[i] != 0
        target_data = target.data.clone()
        correct_mask = target_data.eq(0) * align.data.ne(0)
        correct_copy = (align.data + len(self.tgt_vocab)) * correct_mask.long()
        target_data = target_data + correct_copy

        if click_score is not None:
            loss_session = self.session_criterion(click_score, click_target[0])
            # index of the highest-scoring click in each row; torch's own
            # argmax keeps the result a tensor on the right device
            click_predict = click_score.detach().argmax(dim=1)
        else:
            loss_session = None
            click_predict = None
        # Compute sum of perplexities for stats
        loss_data = loss.sum().data.clone()
        #stats = self._stats(loss_data, scores_data, target_data, loss_session.item(), click_predict, click_target)
        stats = self._stats(loss_data, scores_data, target_data)

        if self.normalize_by_length:
            # Compute Loss as NLL divided by seq length
            # Compute Sequence Lengths
            pad_ix = batch.dataset.fields['tgt'].vocab.stoi[inputters.PAD_WORD]
            tgt_lens = batch.tgt.ne(pad_ix).float().sum(0)
            # Compute Total Loss per sequence in batch
            loss = loss.view(-1, batch.batch_size).sum(0)
            # Divide by length of each sequence and sum
            loss = torch.div(loss, tgt_lens).sum()
        else:
            loss = loss.sum()
        if click_score is not None:
            print('session loss: {}'.format(loss_session.item()))
            print('sentence loss: {}'.format(loss_data))
            # note: this printed total uses a hard-coded weight of 500,
            # whereas the returned loss is weighted by self.session_weight
            print('all loss: {}'.format(loss_session.item() * 500 + loss_data))
            return self.session_weight * loss_session + self.explanation_weight * loss, stats
        else:
            print('sentence loss: {}'.format(loss_data))
            return loss, stats
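Both examples share the `normalize_by_length` branch. A toy walk-through (invented values; shapes assumed from the `view(-1, batch_size)` call) of what that branch computes:

import torch

# Per-token NLL arrives flattened as (tgt_len * batch,); the branch
# reshapes it per sequence, divides each sequence's total by its true
# length, and sums over the batch.
tgt_len, batch_size = 3, 2
loss = torch.tensor([1.0, 2.0,
                     1.0, 2.0,
                     0.0, 2.0])      # third token of sequence 0 is padding
tgt_lens = torch.tensor([2.0, 3.0])  # non-pad lengths per sequence

per_seq = loss.view(-1, batch_size).sum(0)  # tensor([2., 6.])
print(torch.div(per_seq, tgt_lens).sum())   # 2/2 + 6/3 = tensor(3.)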
Example #3
    def compute_loss(self, batch, output, attns, normalization):
        """
        Compute the loss. The args must match self._make_shard_state().
        Args:
            batch: the current batch.
            output: the predicted output from the model.
            attns: dict of attention distributions ("copy", "qa", "passage").
            normalization: denominator applied to the loss before backward().
        """
        target = batch.tgt[1:].view(-1)
        align = batch.alignment[1:].view(-1)
        copy_attn = attns.get("copy")

        scores, coref_scores = self.generator(self._bottle(output),
                                              self._bottle(copy_attn),
                                              batch.map)
        loss = self.criterion(scores, align, target)

        # loss for coreference
        coref_vocab_loss_data, coref_attn_loss_data = 0, 0
        coref_confidence = batch.coref_score.unsqueeze(0).repeat(
            batch.tgt[1:].size(0), 1).view(-1)
        if self.coref_vocab:
            # calculate coref vocab loss
            coref_tgt = batch.coref_tgt[1:].view(-1)
            if self.coref_confscore:
                coref_vocab_loss = (
                    self.criterion_coref_vocab(coref_scores, coref_tgt) *
                    coref_confidence).sum()
            else:
                coref_vocab_loss = (self.criterion_coref_vocab(
                    coref_scores, coref_tgt)).sum()
            if isinstance(coref_vocab_loss, int):
                # promote a plain-int zero loss to a tensor so the
                # .data/.item() calls below still work
                coref_vocab_loss = torch.Tensor(
                    [coref_vocab_loss]).type_as(coref_scores)
            coref_vocab_loss_data = coref_vocab_loss.data.clone().item()
        if self.coref_attn:
            # calculate coref attention loss
            qa_attn = attns.get("qa")
            if self.coref_confscore:
                coref_attn_loss = (
                    self.criterion_coref_attn(qa_attn, batch.coref_attn_loss) *
                    batch.coref_score).sum()
            else:
                coref_attn_loss = (self.criterion_coref_attn(
                    qa_attn, batch.coref_attn_loss)).sum()
            if isinstance(coref_attn_loss, int):
                coref_attn_loss = torch.Tensor(
                    [coref_attn_loss]).type_as(qa_attn)
            coref_attn_loss_data = coref_attn_loss.data.clone().item()

        # loss for flow tracking
        passage_attn = attns.get("passage")
        flow_loss = self.flow_criterion(passage_attn, batch.sentence_label)
        flow_loss_data = flow_loss.sum().data.clone()
        flow_history_loss = self.flow_history_criterion(
            passage_attn, batch.history_label)
        flow_history_loss_data = flow_history_loss.sum().data.clone()

        scores_data = scores.data.clone()
        scores_data = inputters.TextDataset.collapse_copy_scores(
            self._unbottle(scores_data, batch.batch_size), batch,
            self.tgt_vocab, batch.dataset.src_vocabs)
        scores_data = self._bottle(scores_data)

        # Correct target copy token instead of <unk>
        # tgt[i] = align[i] + len(tgt_vocab)
        # for i such that tgt[i] == 0 and align[i] != 0
        target_data = target.data.clone()
        correct_mask = target_data.eq(0) * align.data.ne(0)
        correct_copy = (align.data + len(self.tgt_vocab)) * correct_mask.long()
        target_data = target_data + correct_copy

        # Compute sum of perplexities for stats
        loss_data = loss.sum().data.clone()
        stats = self._stats(loss_data,
                            coref_vocab_loss_data, coref_attn_loss_data,
                            len(batch.coref_attn_loss), scores_data,
                            target_data, flow_loss_data,
                            flow_history_loss_data)

        if self.normalize_by_length:
            # Compute Loss as NLL divided by seq length
            # Compute Sequence Lengths
            pad_ix = batch.dataset.fields['tgt'].vocab.stoi[inputters.PAD_WORD]
            tgt_lens = batch.tgt.ne(pad_ix).float().sum(0)
            # Compute Total Loss per sequence in batch
            loss = loss.view(-1, batch.batch_size).sum(0)
            # Divide by length of each sequence and sum
            loss = torch.div(loss, tgt_lens).sum()
        else:
            loss = loss.sum()

        # add each enabled auxiliary loss with its matching weight
        if self.coref_vocab:
            loss = loss + coref_vocab_loss * self.lambda_coref_vocab
        if self.coref_attn:
            loss = loss + coref_attn_loss * self.lambda_coref_attn
        if self.flow:
            loss = loss + flow_loss.sum() * self.lambda_flow
        if self.flow_history:
            loss = loss + flow_history_loss.sum() * self.lambda_flow_history

        loss.div(float(normalization)).backward()

        return stats
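The closing lines fold the enabled auxiliary losses into the generation loss, then scale by `normalization` before backpropagating. A minimal sketch (hypothetical scalar values) of the gradient effect of that final `div(...).backward()`:

import torch

# The combined loss is divided by the caller's normalization count before
# backward(), so gradients come out averaged over that count.
loss = torch.tensor(8.0, requires_grad=True)
coref_vocab_loss = torch.tensor(2.0)
flow_loss = torch.tensor(1.0)
lambda_coref_vocab, lambda_flow, normalization = 0.5, 0.1, 4.0

total = loss + coref_vocab_loss * lambda_coref_vocab + flow_loss * lambda_flow
total.div(float(normalization)).backward()
print(loss.grad)  # tensor(0.2500): d(total / normalization) / d(loss) = 1/4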