def _compute_loss(self, batch, output, target, copy_attn, align):
    """
    Compute the loss. The args must match self._make_shard_state().
    Args:
        batch: the current batch.
        output: the predicted output from the model.
        target: the validation target to compare the output with.
        copy_attn: the copy attention value.
        align: the alignment info.
    """
    target = target.view(-1)
    align = align.view(-1)
    scores = self.generator(self._bottle(output),
                            self._bottle(copy_attn),
                            batch.src_map)
    loss = self.criterion(scores, align, target)
    scores_data = scores.data.clone()
    if self.data_type == 'text':
        scores_data = inputters.TextDataset.collapse_copy_scores(
            self._unbottle(scores_data, batch.batch_size),
            batch, self.tgt_vocab, batch.dataset.src_vocabs)
    else:  # amr
        scores_data = inputters.AMRDataset.collapse_copy_scores(
            self._unbottle(scores_data, batch.batch_size),
            batch, self.tgt_vocab, batch.dataset.src_vocabs)
    scores_data = self._bottle(scores_data)

    # Correct target copy token instead of <unk>
    # tgt[i] = align[i] + len(tgt_vocab)
    # for i such that tgt[i] == 0 and align[i] != 0
    target_data = target.data.clone()
    correct_mask = target_data.eq(0) * align.data.ne(0)
    correct_copy = (align.data + len(self.tgt_vocab)) * correct_mask.long()
    target_data = target_data + correct_copy

    # Compute sum of perplexities for stats
    loss_data = loss.sum().data.clone()
    stats = self._stats(loss_data, scores_data, target_data)

    if self.normalize_by_length:
        # Compute loss as NLL divided by sequence length
        # Compute sequence lengths
        pad_ix = batch.dataset.fields['tgt'].vocab.stoi[inputters.PAD_WORD]
        tgt_lens = batch.tgt.ne(pad_ix).float().sum(0)
        # Compute total loss per sequence in batch
        loss = loss.view(-1, batch.batch_size).sum(0)
        # Divide by length of each sequence and sum
        loss = torch.div(loss, tgt_lens).sum()
    else:
        loss = loss.sum()

    return loss, stats
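# A minimal, self-contained sketch (not part of the original module) of the
# copy-target correction above, using toy tensors: positions where the target
# is <unk> (id 0) but an alignment exists are remapped into the extended copy
# vocabulary at offset len(tgt_vocab). The vocab size and values below are
# illustrative assumptions only.
import torch


def _sketch_copy_target_correction():
    tgt_vocab_size = 6                         # assumed toy vocabulary size
    target = torch.tensor([4, 0, 2, 0])
    align = torch.tensor([0, 3, 0, 0])         # 3 -> copy from source position 3
    correct_mask = target.eq(0) * align.ne(0)
    correct_copy = (align + tgt_vocab_size) * correct_mask.long()
    corrected = target + correct_copy
    # corrected == tensor([4, 9, 2, 0]): only the <unk> with a non-zero
    # alignment is moved to index 3 + 6 = 9 in the extended vocabulary.
    return corrected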
def monolithic_compute_loss(self, batch, output, attns):
    """
    Compute the forward loss for the batch.

    Args:
      batch (batch): batch of labeled examples
      output (:obj:`FloatTensor`):
          output of decoder model `[tgt_len x batch x hidden]`
      attns (dict of :obj:`FloatTensor`):
          dictionary of attention distributions
          `[tgt_len x batch x src_len]`

    Returns:
        :obj:`onmt.utils.Statistics`: loss statistics
    """
    target = batch.tgt[1:].view(-1)
    align = batch.alignment[1:].view(-1)
    copy_attn = attns.get("copy")
    scores = self.generator(self._bottle(output),
                            self._bottle(copy_attn),
                            batch.map)
    loss = self.criterion(scores, align, target)

    scores_data = scores.data.clone()
    scores_data = inputters.TextDataset.collapse_copy_scores(
        self._unbottle(scores_data, batch.batch_size),
        batch, self.tgt_vocab, batch.dataset.src_vocabs)
    scores_data = self._bottle(scores_data)

    # Correct target copy token instead of <unk>
    # tgt[i] = align[i] + len(tgt_vocab)
    # for i such that tgt[i] == 0 and align[i] != 0
    target_data = target.data.clone()
    correct_mask = target_data.eq(0) * align.data.ne(0)
    correct_copy = (align.data + len(self.tgt_vocab)) * correct_mask.long()
    target_data = target_data + correct_copy

    # Compute sum of perplexities for stats
    loss_data = loss.sum().data.clone()
    stats = self._stats(loss_data, scores_data, target_data)

    return stats
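# A small sketch (illustrative only) of the _bottle / _unbottle reshaping used
# around collapse_copy_scores above, assuming the OpenNMT-style helpers that
# merge and split the (tgt_len, batch) dimensions. The sizes are toy values.
import torch


def _sketch_bottle_unbottle():
    tgt_len, batch_size, vocab = 5, 3, 11
    scores = torch.rand(tgt_len * batch_size, vocab)    # already "bottled"
    # _unbottle: recover a (tgt_len, batch_size, vocab) view for per-batch ops
    unbottled = scores.view(-1, batch_size, scores.size(1))
    # _bottle: flatten back to (tgt_len * batch_size, vocab) for the criterion
    bottled = unbottled.view(-1, unbottled.size(2))
    assert bottled.shape == scores.shape
    return bottled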
def monolithic_compute_loss(self, batch, output, attns):
    """
    Compute the forward loss for the batch.

    Args:
      batch (batch): batch of labeled examples
      output (:obj:`FloatTensor`):
          output of decoder model `[tgt_len x batch x hidden]`
      attns (dict of :obj:`FloatTensor`):
          dictionary of attention distributions
          `[tgt_len x batch x src_len]`

    Returns:
        :obj:`onmt.utils.Statistics`: loss statistics
    """
    target = batch.tgt[1:].view(-1)
    align = batch.alignment[1:].view(-1)
    copy_attn = attns.get("copy")
    scores, coref_scores = self.generator(self._bottle(output),
                                          self._bottle(copy_attn),
                                          batch.map)
    loss = self.criterion(scores, align, target)

    # loss for coreference
    coref_vocab_loss_data, coref_attn_loss_data = 0, 0
    coref_confidence = batch.coref_score.unsqueeze(0).repeat(
        batch.tgt[1:].size(0), 1).view(-1)
    if self.coref_vocab:
        # calculate coref vocab loss
        coref_tgt = batch.coref_tgt[1:].view(-1)
        if self.coref_confscore:
            coref_vocab_loss = (
                self.criterion_coref_vocab(coref_scores, coref_tgt) *
                coref_confidence).sum()
        else:
            coref_vocab_loss = self.criterion_coref_vocab(
                coref_scores, coref_tgt).sum()
        if type(coref_vocab_loss) == int:
            coref_vocab_loss = torch.Tensor(
                [coref_vocab_loss]).type_as(coref_scores)
        coref_vocab_loss_data = coref_vocab_loss.data.clone().item()

    if self.coref_attn:
        # calculate coref attention loss
        qa_attn = attns.get("qa")
        if self.coref_confscore:
            coref_attn_loss = (
                self.criterion_coref_attn(qa_attn, batch.coref_attn_loss) *
                batch.coref_score).sum()
        else:
            coref_attn_loss = self.criterion_coref_attn(
                qa_attn, batch.coref_attn_loss).sum()
        if type(coref_attn_loss) == int:
            coref_attn_loss = torch.Tensor(
                [coref_attn_loss]).type_as(qa_attn)
        coref_attn_loss_data = coref_attn_loss.data.clone().item()

    # loss for flow tracking
    passage_attn = attns.get("passage")
    flow_loss = self.flow_criterion(passage_attn, batch.sentence_label)
    flow_loss_data = flow_loss.sum().data.clone()
    flow_history_loss = self.flow_history_criterion(
        passage_attn, batch.history_label)
    flow_history_loss_data = flow_history_loss.sum().data.clone()

    scores_data = scores.data.clone()
    scores_data = inputters.TextDataset.collapse_copy_scores(
        self._unbottle(scores_data, batch.batch_size),
        batch, self.tgt_vocab, batch.dataset.src_vocabs)
    scores_data = self._bottle(scores_data)

    # Correct target copy token instead of <unk>
    # tgt[i] = align[i] + len(tgt_vocab)
    # for i such that tgt[i] == 0 and align[i] != 0
    target_data = target.data.clone()
    correct_mask = target_data.eq(0) * align.data.ne(0)
    correct_copy = (align.data + len(self.tgt_vocab)) * correct_mask.long()
    target_data = target_data + correct_copy

    # Compute sum of perplexities for stats
    loss_data = loss.sum().data.clone()
    stats = self._stats(loss_data, coref_vocab_loss_data,
                        coref_attn_loss_data, len(batch.coref_attn_loss),
                        scores_data, target_data, flow_loss_data,
                        flow_history_loss_data)

    return stats
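# A toy sketch (not part of the original module) of the coref-confidence
# broadcasting above: a per-example confidence vector of shape (batch,) is
# repeated over tgt_len and flattened so it lines up with the flattened
# (tgt_len * batch,) token-level loss. The values below are illustrative.
import torch


def _sketch_coref_confidence_broadcast():
    tgt_len, batch_size = 4, 2
    coref_score = torch.tensor([0.9, 0.2])             # one score per example
    coref_confidence = coref_score.unsqueeze(0).repeat(tgt_len, 1).view(-1)
    # coref_confidence == tensor([0.9, 0.2, 0.9, 0.2, 0.9, 0.2, 0.9, 0.2]),
    # matching the layout of a loss flattened from (tgt_len, batch_size).
    assert coref_confidence.shape == (tgt_len * batch_size,)
    return coref_confidence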
def _compute_loss(self, batch, output, target, copy_attn, align,
                  click_score=None, click_target=None):
    """
    Compute the loss. The args must match self._make_shard_state().
    Args:
        batch: the current batch.
        output: the predicted output from the model.
        target: the validation target to compare the output with.
        copy_attn: the copy attention value.
        align: the alignment info.
        click_score: predicted scores over click candidates (optional).
        click_target: gold click labels (optional).
    """
    target = target.view(-1)
    align = align.view(-1)
    scores = self.generator(self._bottle(output),
                            self._bottle(copy_attn),
                            batch.src_map)
    loss = self.criterion(scores, align, target)
    scores_data = scores.data.clone()
    scores_data = inputters.TextDataset.collapse_copy_scores(
        self._unbottle(scores_data, batch.batch_size),
        batch, self.tgt_vocab, batch.dataset.src_vocabs)
    scores_data = self._bottle(scores_data)

    # Correct target copy token instead of <unk>
    # tgt[i] = align[i] + len(tgt_vocab)
    # for i such that tgt[i] == 0 and align[i] != 0
    target_data = target.data.clone()
    correct_mask = target_data.eq(0) * align.data.ne(0)
    correct_copy = (align.data + len(self.tgt_vocab)) * correct_mask.long()
    target_data = target_data + correct_copy

    if click_score is not None:
        loss_session = self.session_criterion(click_score, click_target[0])
        # Predicted click is the highest-scoring candidate per example;
        # torch.argsort keeps the result on click_score's device.
        click_predict = torch.argsort(-click_score.detach(), dim=1)[:, 0]
    else:
        loss_session = None
        click_predict = None

    # Compute sum of perplexities for stats
    loss_data = loss.sum().data.clone()
    # stats = self._stats(loss_data, scores_data, target_data,
    #                     loss_session.item(), click_predict, click_target)
    stats = self._stats(loss_data, scores_data, target_data)

    if self.normalize_by_length:
        # Compute loss as NLL divided by sequence length
        # Compute sequence lengths
        pad_ix = batch.dataset.fields['tgt'].vocab.stoi[inputters.PAD_WORD]
        tgt_lens = batch.tgt.ne(pad_ix).float().sum(0)
        # Compute total loss per sequence in batch
        loss = loss.view(-1, batch.batch_size).sum(0)
        # Divide by length of each sequence and sum
        loss = torch.div(loss, tgt_lens).sum()
    else:
        loss = loss.sum()

    if click_score is not None:
        print('session loss:{}'.format(loss_session.item()))
        print('sentence loss:{}'.format(loss_data))
        print('all loss:{}'.format(loss_session.item() * 500 + loss_data))
        return (self.session_weight * loss_session +
                self.explanation_weight * loss), stats
    else:
        print('sentence loss:{}'.format(loss_data))
        return loss, stats
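# A small sketch (illustrative only) of the click-prediction step above:
# taking column 0 of a descending argsort is the same as argmax over each
# row. The score values are toy numbers, not real model outputs.
import torch


def _sketch_click_prediction():
    click_score = torch.tensor([[0.1, 2.0, 0.3],
                                [1.5, 0.2, 0.9]])
    click_predict = torch.argsort(-click_score, dim=1)[:, 0]
    assert torch.equal(click_predict, click_score.argmax(dim=1))
    # click_predict == tensor([1, 0])
    return click_predict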
def compute_loss(self, batch, output, attns, normalization):
    """
    Compute the loss for the batch and backpropagate it.
    Args:
        batch: the current batch.
        output: the predicted output from the model.
        attns: dictionary of attention distributions.
        normalization: value the loss is divided by before backward().
    """
    target = batch.tgt[1:].view(-1)
    align = batch.alignment[1:].view(-1)
    copy_attn = attns.get("copy")
    scores, coref_scores = self.generator(self._bottle(output),
                                          self._bottle(copy_attn),
                                          batch.map)
    loss = self.criterion(scores, align, target)

    # loss for coreference
    coref_vocab_loss_data, coref_attn_loss_data = 0, 0
    coref_confidence = batch.coref_score.unsqueeze(0).repeat(
        batch.tgt[1:].size(0), 1).view(-1)
    if self.coref_vocab:
        # calculate coref vocab loss
        coref_tgt = batch.coref_tgt[1:].view(-1)
        if self.coref_confscore:
            coref_vocab_loss = (
                self.criterion_coref_vocab(coref_scores, coref_tgt) *
                coref_confidence).sum()
        else:
            coref_vocab_loss = self.criterion_coref_vocab(
                coref_scores, coref_tgt).sum()
        if type(coref_vocab_loss) == int:
            coref_vocab_loss = torch.Tensor(
                [coref_vocab_loss]).type_as(coref_scores)
        coref_vocab_loss_data = coref_vocab_loss.data.clone().item()

    if self.coref_attn:
        # calculate coref attention loss
        qa_attn = attns.get("qa")
        if self.coref_confscore:
            coref_attn_loss = (
                self.criterion_coref_attn(qa_attn, batch.coref_attn_loss) *
                batch.coref_score).sum()
        else:
            coref_attn_loss = self.criterion_coref_attn(
                qa_attn, batch.coref_attn_loss).sum()
        if type(coref_attn_loss) == int:
            coref_attn_loss = torch.Tensor(
                [coref_attn_loss]).type_as(qa_attn)
        coref_attn_loss_data = coref_attn_loss.data.clone().item()

    # loss for flow tracking
    passage_attn = attns.get("passage")
    flow_loss = self.flow_criterion(passage_attn, batch.sentence_label)
    flow_loss_data = flow_loss.sum().data.clone()
    flow_history_loss = self.flow_history_criterion(
        passage_attn, batch.history_label)
    flow_history_loss_data = flow_history_loss.sum().data.clone()

    scores_data = scores.data.clone()
    scores_data = inputters.TextDataset.collapse_copy_scores(
        self._unbottle(scores_data, batch.batch_size),
        batch, self.tgt_vocab, batch.dataset.src_vocabs)
    scores_data = self._bottle(scores_data)

    # Correct target copy token instead of <unk>
    # tgt[i] = align[i] + len(tgt_vocab)
    # for i such that tgt[i] == 0 and align[i] != 0
    target_data = target.data.clone()
    correct_mask = target_data.eq(0) * align.data.ne(0)
    correct_copy = (align.data + len(self.tgt_vocab)) * correct_mask.long()
    target_data = target_data + correct_copy

    # Compute sum of perplexities for stats
    loss_data = loss.sum().data.clone()
    stats = self._stats(loss_data, coref_vocab_loss_data,
                        coref_attn_loss_data, len(batch.coref_attn_loss),
                        scores_data, target_data, flow_loss_data,
                        flow_history_loss_data)

    if self.normalize_by_length:
        # Compute loss as NLL divided by sequence length
        # Compute sequence lengths
        pad_ix = batch.dataset.fields['tgt'].vocab.stoi[inputters.PAD_WORD]
        tgt_lens = batch.tgt.ne(pad_ix).float().sum(0)
        # Compute total loss per sequence in batch
        loss = loss.view(-1, batch.batch_size).sum(0)
        # Divide by length of each sequence and sum
        loss = torch.div(loss, tgt_lens).sum()
    else:
        loss = loss.sum()

    # Add the auxiliary losses, each scaled by its own weight
    if self.coref_vocab:
        loss = loss + coref_vocab_loss * self.lambda_coref_vocab
    if self.coref_attn:
        loss = loss + coref_attn_loss * self.lambda_coref_attn
    if self.flow:
        loss = loss + flow_loss.sum() * self.lambda_flow
    if self.flow_history:
        loss = loss + flow_history_loss.sum() * self.lambda_flow_history

    loss.div(float(normalization)).backward()

    return stats
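# A toy sketch (illustrative only) of the normalize_by_length branch above:
# the per-token loss (flattened from tgt_len x batch) is summed per sequence
# and divided by the number of non-pad target tokens before the final sum.
# The pad id, loss values, and target ids below are assumptions for the demo.
import torch


def _sketch_length_normalization():
    tgt_len, batch_size, pad_ix = 3, 2, 1
    loss = torch.tensor([0.5, 1.0, 0.5, 0.0, 0.5, 0.0])   # (tgt_len * batch,)
    tgt = torch.tensor([[4, 5],
                        [6, 1],
                        [7, 1]])                           # (tgt_len, batch)
    tgt_lens = tgt.ne(pad_ix).float().sum(0)               # tensor([3., 1.])
    per_seq = loss.view(-1, batch_size).sum(0)             # tensor([1.5, 1.0])
    normalized = torch.div(per_seq, tgt_lens).sum()        # 0.5 + 1.0 = 1.5
    return normalized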