Example #1
    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0]
        if log_probs:
            # The ORT path should operate on the extracted logits;
            # `use_ort_backend` is assumed to be defined at module level.
            if use_ort_backend:
                return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)

            return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
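A minimal call-site sketch for the method above; `model`, `sample`, and the flattening step are illustrative assumptions about a typical fairseq criterion, not part of the snippet:

# Hypothetical criterion code (assumes the usual fairseq `model`/`sample` objects).
net_output = model(**sample["net_input"])
lprobs = model.get_normalized_probs(net_output, log_probs=True, sample=sample)
lprobs = lprobs.view(-1, lprobs.size(-1))  # flatten to (batch * time, vocab) for NLL-style losses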
Example #2
    def get_normalized_probs(self, ctc_logits, logits, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        if log_probs:
            ctc_res = utils.log_softmax(ctc_logits.float(), dim=-1)
            res = utils.log_softmax(logits.float(), dim=-1)
        else:
            ctc_res = utils.softmax(ctc_logits.float(), dim=-1)
            res = utils.softmax(logits.float(), dim=-1)
        ctc_res.batch_first = True
        res.batch_first = True

        return ctc_res, res
Example #3
    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits_ctc = net_output["logits_ctc"]
        logits = net_output["logits"]

        if log_probs:
            ctc_res = utils.log_softmax(logits_ctc.float(), dim=-1)
            res = utils.log_softmax(logits.float(), dim=-1)
        else:
            ctc_res = utils.softmax(logits_ctc.float(), dim=-1)
            res = utils.softmax(logits.float(), dim=-1)

        return ctc_res, res
Example #4
    def get_logits(self, net_output):
        logits = net_output["x"]
        lprob = utils.log_softmax(logits.float(), dim=-1)
        lprob = lprob.transpose(0, 1)
        lprob.batch_first = False

        return lprob
    def get_normalized_probs_with_temperature(self,
                                              net_output,
                                              log_probs,
                                              sample=None,
                                              temperature=1.):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self,
                   'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0] / temperature
        if log_probs:
            return utils.log_softmax(logits,
                                     dim=-1,
                                     onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    def compute_mlm_loss(self,
                         enc_output,
                         target,
                         data_type=None,
                         reduce=True):
        lprobs = utils.log_softmax(enc_output, dim=-1, onnx_trace=False)
        p_lprobs = lprobs.clone()

        if data_type is not None:
            data_type = data_type.view(-1, 1).repeat(1, lprobs.size()[1])
            data_type = data_type.view(-1, 1)
            # de is monolingual, so for en->de the source label is flipped (1 - data_type)
            data_type = 1 - data_type

        lprobs = lprobs.view(-1, lprobs.size(-1))
        predict_sentence = torch.argmax(lprobs, dim=-1)
        predict_sentence = predict_sentence.view(target.size())
        target = target.view(-1, 1)

        loss, nll_loss = label_smoothed_nll_loss(lprobs,
                                                 target,
                                                 self.eps,
                                                 ignore_index=self.padding_idx,
                                                 data_type=data_type)

        return loss, predict_sentence, p_lprobs
Example #7
    def get_normalized_probs(self,
                             net_output,
                             log_probs,
                             sample,
                             adaptive_softmax=True):
        """Get normalized probabilities (or log probs) from a net's output."""
        if adaptive_softmax:
            if hasattr(
                    self,
                    'adaptive_softmax') and self.adaptive_softmax is not None:
                if sample is not None:
                    assert 'target' in sample
                    target = sample['target']
                else:
                    target = None
                out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                         target=target)
                return out.exp_() if not log_probs else out

        # Fall back to the standard (log-)softmax; net_output may be a list or a bare tensor.
        logits = net_output[0] if isinstance(net_output, list) else net_output
        if log_probs:
            return utils.log_softmax(logits,
                                     dim=-1,
                                     onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
Example #8
    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self,
                   "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "source" in sample
                source = sample["source"]
            else:
                source = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=source)
            return out.exp_() if not log_probs else out

        logits = net_output[0]
        if log_probs:
            return utils.log_softmax(logits,
                                     dim=-1,
                                     onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
 def get_normalized_probs(self, net_output, log_probs, sample):
     """Get normalized probabilities (or log probs) from a net's output."""
     logits = net_output
     if log_probs:
         return utils.log_softmax(logits, dim=-1)
     else:
         return utils.softmax(logits, dim=-1)
    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        sample['padding_idx'] = self.padding_idx
        target = model.get_targets(sample, net_output).view(-1, 1)
        non_pad_mask = target.ne(self.padding_idx)

        # compute length prediction loss
        length_lprobs = net_output[1]['predicted_lengths']
        length_target = sample['net_input']['prev_output_tokens'].ne(
            self.padding_idx).sum(-1).unsqueeze(-1)
        length_loss = -length_lprobs.gather(dim=-1, index=length_target)

        src_lprobs = utils.log_softmax(net_output[1]['encoder_out'], dim=-1)
        src_lprobs = src_lprobs.view(-1, src_lprobs.size(-1))
        src_target = sample['src_target'].view(-1, 1)
        src_non_pad_mask = src_target.ne(self.padding_idx)
        src_nll_loss = -src_lprobs.gather(dim=-1,
                                          index=src_target)[src_non_pad_mask]

        nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
        if reduce:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()
            length_loss = length_loss.sum()
            src_nll_loss = src_nll_loss.sum()
        eps_i = self.eps / lprobs.size(-1)
        loss = ((1. - self.eps) * nll_loss + eps_i * smooth_loss
                + 0.1 * length_loss + 0.01 * src_nll_loss)
        return loss, nll_loss, length_loss, src_nll_loss
Example #11
    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self,
                   'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out
        '''
        logits_list = net_output[0]
        if log_probs:
            return [utils.log_softmax(
                        logits, dim=-1, onnx_trace=self.onnx_trace)
                        for logits in logits_list][0]
        else:
            return [utils.softmax(
                        logits, dim=-1, onnx_trace=self.onnx_trace)
                        for logits in logits_list][0]
        '''
        logits = net_output[0]
        if log_probs:
            return utils.log_softmax(logits,
                                     dim=-1,
                                     onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
Example #12
    def get_normalized_probs(self,
                             net_output,
                             log_probs,
                             sample,
                             gs_tau=0.5,
                             gs_hard=False):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self,
                   'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0][0]
        orders = net_output[0][1]
        if log_probs:
            return (utils.log_softmax(logits,
                                      dim=-1,
                                      onnx_trace=self.onnx_trace),
                    self.gumbel_softmax(orders,
                                        gs_tau=gs_tau,
                                        gs_hard=gs_hard,
                                        dim=-1))
        else:
            return (utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace),
                    self.gumbel_softmax(orders,
                                        gs_tau=gs_tau,
                                        gs_hard=gs_hard,
                                        dim=-1))
Example #13
 def get_normalized_probs(self, net_output, log_probs):
     """Get normalized probabilities (or log probs) from a net's output."""
     logits = net_output["encoder_out"]
     if log_probs:
         return utils.log_softmax(logits.float(), dim=-1)
     else:
         return utils.softmax(logits.float(), dim=-1)
Example #14
    def compute_loss(self, model, net_output, sample, reduce=True):
        #get target and generated text
        target = model.get_targets(sample, net_output).view(-1, 1)
        ## semantic sim_loss
        output_tokens = net_output[0]
        sentence_tok = torch.argmax(utils.log_softmax(output_tokens, dim=-1),
                                    -1)  # maxpool
        sentence_txt = self.bpe.decode(
            self.task.target_dictionary.string(sentence_tok))
        ignore_index = self.padding_idx
        if ignore_index is not None:
            non_pad_mask = target.ne(ignore_index)
            target_ig = target[non_pad_mask]
        else:
            target_ig = target
        target_txt = self.bpe.decode(
            self.task.target_dictionary.string(target_ig))
        print("\n\n## sentence_txt: ", sentence_txt, "\n## target_txt: ",
              target_txt)

        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1, 1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return loss, nll_loss
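Several snippets here call `label_smoothed_nll_loss` without showing it. The sketch below follows the standard fairseq-style formulation (uniform label smoothing over the vocabulary); the exact helper these repositories import may differ, so treat it as an assumption:

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    # lprobs: (N, V) log-probabilities; target: (N,) or (N, 1) gold indices.
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss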
    def get_normalized_probs(self, net_output, log_probs, return_ctc=False):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits_ctc = net_output["logits_ctc"]
        logits = net_output["logits"]
        if log_probs:
            res_ctc = utils.log_softmax(logits_ctc.float(), dim=-1)
            res = utils.log_softmax(logits.float(), dim=-1)
        else:
            res_ctc = utils.softmax(logits_ctc.float(), dim=-1)
            res = utils.softmax(logits.float(), dim=-1)
        res_ctc.batch_first = True
        res.batch_first = True

        if return_ctc:
            return res_ctc, res
        else:
            return res
Example #16
 def get_normalized_probs_w2v(self, net_output, log_probs):
     """Get normalized probabilities (or log probs) from a net's output."""
     print(net_output.keys())
     logits = net_output["wav2vec_logits"]
     if log_probs:
         return utils.log_softmax(logits.float(), dim=-1)
     else:
         return utils.softmax(logits.float(), dim=-1)
Example #17
    def get_normalized_probs_cif(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits = self.get_logits_cif(net_output)

        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)
Example #18
    def get_knn_log_prob(self, queries, tgt, pad_idx):
        def dist_func(d, k, q, function=None):
            if not function:
                # Default behavior for L2 metric is to recompute distances.
                # Default behavior for IP metric is to return faiss distances.
                qsize = q.shape
                if self.metric_type == 'l2':
                    start = time.time()
                    knns_vecs = torch.from_numpy(self.keys[k]).cuda().view(
                        qsize[0], self.k, -1)
                    if self.half:
                        knns_vecs = knns_vecs.half()
                    query_vecs = q.view(qsize[0], 1,
                                        qsize[1]).repeat(1, self.k, 1)
                    l2 = torch.sum((query_vecs - knns_vecs.detach())**2, dim=2)
                    return -1 * l2
                return d

            if function == 'dot':
                qsize = q.shape
                return (torch.from_numpy(self.keys[k]).cuda() *
                        q.view(qsize[0], 1, qsize[1])).sum(dim=-1)

            if function == 'do_not_recomp_l2':
                return -1 * d

            raise ValueError("Invalid knn similarity function!")

        # queries  are TxBxC
        # reshape: (TxB)xC
        qshape = queries.shape
        queries = queries.view(-1, qshape[-1])
        tgt = tgt.contiguous().view(-1)
        dists, knns = self.get_knns(queries[tgt != pad_idx])
        # (T_reducedxB)xK
        dists = torch.from_numpy(dists).cuda()
        start = time.time()
        dists = dist_func(dists,
                          knns,
                          queries[tgt != pad_idx, :],
                          function=self.sim_func)
        probs = utils.log_softmax(dists, dim=-1)

        index_mask = torch.eq(
            torch.from_numpy(self.vals[knns]).long().cuda().squeeze(-1),
            tgt[tgt != pad_idx].unsqueeze(-1)).float()
        index_mask[index_mask == 0] = -10000  # for stability
        index_mask[index_mask == 1] = 0

        # (T_reducedxB)
        yhat_knn_prob = torch.logsumexp(probs + index_mask, dim=-1).clone()
        full_yhat_knn_prob = torch.full([qshape[0] * qshape[1]], -10000.0).cuda()
        full_yhat_knn_prob[tgt != pad_idx] = yhat_knn_prob

        # TxBx1
        return full_yhat_knn_prob.view(qshape[0], qshape[1], 1)
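The T x B x 1 tensor of kNN log-probabilities returned above is normally interpolated with the model's own log-probabilities. A hedged sketch of that combination, loosely following the usual kNN-LM recipe (the `coeff` weight and function name are illustrative, not necessarily this repository's exact code):

import torch

def combine_knn_and_vocab_probs(knn_p, vocab_p, coeff=0.25):
    # knn_p, vocab_p: log-probabilities of the same tokens, identical shapes.
    # Returns log((1 - coeff) * exp(vocab_p) + coeff * exp(knn_p)), computed stably.
    combined = torch.stack([vocab_p, knn_p], dim=0)
    weights = torch.tensor([1.0 - coeff, coeff], device=combined.device).log()
    weights = weights.view(-1, *([1] * vocab_p.dim()))
    return torch.logsumexp(combined + weights, dim=0)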
Example #19
    def forward_train(self, prev_output_tokens, encoder_out, target, **kwargs):
        print('Target tokens:', prev_output_tokens)
        # source embeddings
        src_emb = encoder_out['encoder_out']  # B, Ts, ds 
        # target embeddings:
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=None,
        ) if self.embed_positions is not None else None

        decoder_mask = prev_output_tokens.eq(self.padding_idx)
        if not decoder_mask.any():
            decoder_mask = None

        # Build the full grid
        tgt_emb = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if positions is not None:
            tgt_emb += positions
        tgt_emb = self.embedding_dropout(tgt_emb)
        batch_size = src_emb.size(0)
        src_length = src_emb.size(1)
        tgt_length = tgt_emb.size(1)

        # build 2d "image" of embeddings
        src_emb = _expand(src_emb, 1, tgt_length)  # B, Tt, Ts, ds
        tgt_emb = _expand(tgt_emb, 2, src_length)  # B, Tt, Ts, dt
        x = torch.cat((src_emb, tgt_emb), dim=3)   # B, Tt, Ts, C=ds+dt
        x = self.input_dropout(x)
        if 'embed' in self.controller_input:
            observations = x
        # pass through dense convolutional layers
        encoder_mask = encoder_out['encoder_padding_mask']
        x = self.net(
            x, 
            decoder_mask=decoder_mask,
            encoder_mask=encoder_mask,
            incremental_state=None,
        )  # B, Tt, Ts, C
        x, _ = self.aggregator(x)  # B, Tt, Ts, C
        x = self.projection(x) if self.projection is not None else x  # B, Tt, C

        if 'feat' in self.controller_input:
            if 'embed' in self.controller_input:
                observations = torch.cat((observations, x), dim=-1)
            else:
                observations = x
        # Predict
        x = self.prediction_dropout(x)
        x = self.prediction(x)  # B, Tt, Ts, V
        x = utils.log_softmax(x, dim=-1)
        x = x.view(-1, x.size(-1)).gather(
            dim=-1,
            index=target.unsqueeze(-1).expand(-1, -1, src_length).contiguous().view(-1, 1)
        ).view(batch_size, tgt_length, src_length).permute(1,0,2)  # Tt, B, Ts
        controls, gamma, read_labels, write_labels = self.hmm(observations, x)
        return x, observations, controls, gamma, read_labels, write_labels
Example #20
    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output[0]
        if log_probs:
            res = utils.log_softmax(logits.float(), dim=-1)
        else:
            res = utils.softmax(logits.float(), dim=-1)
        res.batch_first = True

        return res
Example #21
 def compute_cross_entropy(self, logits, sample, reduce=True):
     lprobs = utils.log_softmax(logits, dim=-1, onnx_trace=False)
     lprobs = lprobs.view(-1, lprobs.size(-1))
     target = sample['target'].view(-1)
     loss = F.nll_loss(
         lprobs,
         target,
         ignore_index=self.padding_idx,
         reduction='sum' if reduce else 'none',
     )
     return loss
Example #22
 def gumbel_softmax(self, logits, gs_tau=0.5, gs_hard=False, dim=-1):
     logprob = utils.log_softmax(logits,
                                 dim=dim,
                                 onnx_trace=self.onnx_trace)
     logprob = torch.clamp(logprob, math.log(0.1), math.log(0.9))
     gs = F.gumbel_softmax(
         logprob,
         tau=gs_tau,
         hard=gs_hard,
     )
     return gs
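For context on the `gs_tau`/`gs_hard` arguments used above, a small self-contained illustration of `F.gumbel_softmax` (not taken from the snippet): with `hard=True` the forward pass yields one-hot samples while gradients flow through the soft relaxation.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)
soft_sample = F.gumbel_softmax(logits, tau=0.5, hard=False)  # rows sum to 1
hard_sample = F.gumbel_softmax(logits, tau=0.5, hard=True)   # rows are one-hot (straight-through)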
Example #23
    def forward(self, sample, encoder_out, decoder_out):
        # First encode the observations
        if not self.share_embeddings:
            x = self.observation_grid(sample['src_tokens'],
                                      sample['prev_output_tokens'])
        else:
            # The writing input grid
            x = decoder_out[1].clone()

        # Cumulative ResNet:
        x = self.net(x)
        # Cell aggregation
        # The R/W decisions:
        x = self.gate_dropout(x)
        x = self.gate(x)
        s = F.logsigmoid(x)
        RWlogits = torch.cat((s, s - x), dim=-1).contiguous().float()

        with torch.no_grad():
            lprobs = decoder_out[0].clone()
            target = sample['target']
            encoder_mask = encoder_out['encoder_padding_mask']
            decoder_mask = decoder_out[2]
            # Gather the ground truth likelihoods
            B, Tt, Ts, V = lprobs.size()
            lprobs = utils.log_softmax(lprobs, dim=-1)
            scores = lprobs.view(-1, V).gather(
                dim=-1,
                index=target.unsqueeze(-1).expand(-1,
                                                  -1, Ts).contiguous().view(
                                                      -1, 1)  # BTtTs
            ).view(B, Tt, Ts)
            # Mask out padding positions (scores are log-likelihoods, so use a large negative fill)
            if encoder_mask is not None:
                scores = scores.masked_fill(encoder_mask.unsqueeze(1), -1000)
            if decoder_mask is not None:
                scores = scores.masked_fill(decoder_mask.unsqueeze(-1), -1000)

            # The Oracle
            best_context = self.oracle(scores)

            AP = best_context.add(1).float().mean(dim=1) / Ts
            print('-', round(AP.mean().data.item(), 2))
            Gamma = torch.zeros_like(scores).scatter_(
                -1, best_context.unsqueeze(-1), 1.0)  # B, Tt, Ts

        # Write beyond the ideal context
        if self.write_right:
            Gamma = Gamma.cumsum(dim=-1)
            write = Gamma[:, 1:]  # B, Tt-1, Ts
        else:
            write = Gamma[:, 1:].cumsum(dim=-1)  # B, Tt-1, Ts
        read = 1 - write
        return Gamma, RWlogits[:, :-1], read, write
Example #24
    def get_logits(self, net_output):
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))

        logits_ctc = net_output["logtis_ctc"]  # key name kept as emitted by the model
        lprob = utils.log_softmax(logits_ctc.float(), dim=-1)
        lprob = lprob.transpose(0, 1)
        lprob.batch_first = False

        return logits, lprob
Example #25
 def compute_xet_loss(self, logits, gold, padding_idx, reduce=True):
     lprobs = utils.log_softmax(logits, dim=-1)
     lprobs = lprobs.view(-1, lprobs.size(-1))
     target = gold.contiguous().view(-1)
     loss = F.nll_loss(
         lprobs,
         target,
         ignore_index=padding_idx,
         reduction='sum' if reduce else 'none',
     )
     return loss, loss
 def get_ctc_output(
         self, net_output: Tuple[Tensor,
                                 Optional[Dict[str,
                                               List[Optional[Tensor]]]]],
         sample: Optional[Dict[str, Tensor]]):
     encoder_out = net_output[1]["encoder_out"]["encoder_out"][0]
     logits = self.encoder.ctc_proj(encoder_out)  # T x B x C
     out = utils.log_softmax(logits.float(), dim=-1)
     padding_mask = net_output[1]["encoder_out"]["encoder_padding_mask"]
     lens = out.new_full((out.shape[1], ), out.shape[0]).long()
     if len(padding_mask) > 0:
         lens -= padding_mask[0].sum(dim=-1)
     return out, lens
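The `(out, lens)` pair above is already shaped for `torch.nn.functional.ctc_loss` (T x B x C log-probabilities plus per-utterance input lengths). A hedged sketch of how a criterion might consume it; the target tensors and blank index are assumptions about the surrounding task:

import torch.nn.functional as F

def compute_ctc_loss(ctc_lprobs, input_lengths, targets, target_lengths, blank_idx=0):
    # ctc_lprobs: T x B x C log-probabilities, e.g. the `out` returned by get_ctc_output.
    # input_lengths: the `lens` returned by get_ctc_output.
    return F.ctc_loss(
        ctc_lprobs,
        targets,
        input_lengths,
        target_lengths,
        blank=blank_idx,
        reduction="sum",
        zero_infinity=True,
    )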
Example #27
 def compute_loss(self, net_output, sample, reduce=True):
     lprobs = utils.log_softmax(net_output, dim=-1)
     lprobs = lprobs.view(-1, lprobs.size(-1))
     target = sample['target'].view(-1, 1)
     non_pad_mask = target.ne(self.padding_idx)
     nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
     smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
     if reduce:
         nll_loss = nll_loss.sum()
         smooth_loss = smooth_loss.sum()
     eps_i = self.eps / lprobs.size(-1)
     loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
     return loss, nll_loss
    def forward(self, sample, encoder_out, decoder_out):
        x = decoder_out[1]
        # Final LN
        if self.final_ln is not None:
            x = self.final_ln(x)
        # Aggregate
        x, _ = self.aggregator(x)
        # A stack of linear layers
        x = self.net(x)
        # The R/W decisions:
        x = self.gate(x)
        s = F.logsigmoid(x)
        RWlogits = torch.cat((s, s - x), dim=-1).float()

        lprobs = decoder_out[0]
        target = sample['target']
        encoder_mask = encoder_out['encoder_padding_mask']
        decoder_mask = decoder_out[2]

        with torch.no_grad():
            # Gather the ground truth likelihoods
            B, Tt, Ts, V = lprobs.size()
            lprobs = utils.log_softmax(lprobs, dim=-1)
            scores = lprobs.view(-1, V).gather(
                dim=-1,
                index=target.unsqueeze(-1).expand(-1,
                                                  -1, Ts).contiguous().view(
                                                      -1, 1)  # BTtTs
            ).view(B, Tt, Ts)
            # Mask out padding positions (scores are log-likelihoods, so use a large negative fill)
            if encoder_mask is not None:
                scores = scores.masked_fill(encoder_mask.unsqueeze(1), -1000)
            if decoder_mask is not None:
                scores = scores.masked_fill(decoder_mask.unsqueeze(-1), -1000)

            # The Oracle
            best_context = self.oracle(scores)

            # AP = best_context.add(1).float().mean(dim=1) / Ts
            # print('AP:', ' '.join(map(lambda x: '{:.2f}'.format(x), AP.tolist())))
            Gamma = torch.zeros_like(scores).scatter_(
                -1, best_context.unsqueeze(-1), 1.0)  # B, Tt, Ts

        # Write beyond the ideal context
        if self.write_right:
            Gamma = Gamma.cumsum(dim=-1)
            write = Gamma[:, 1:]  # B, Tt-1, Ts
        else:
            write = Gamma[:, 1:].cumsum(dim=-1)  # B, Tt-1, Ts
        read = 1 - write
        return Gamma, RWlogits[:, :-1], read, write
    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output."""
        # print('enter normalized.')
        if 'net_input' in sample.keys():
            enc_seq_ids = sample['net_input']['src_tokens']
        else:
            enc_seq_ids = sample['src_tokens']

        # wvocab_size = net_output[0].size(2)
        # batch_size = enc_seq_ids.size(0)
        # seq_len = enc_seq_ids.size(1)
        # one_hot = torch.zeros(batch_size, seq_len, wvocab_size).cuda().scatter_(dim=2, index=enc_seq_ids.unsqueeze(-1), value=1)
        #
        # copy_probs = torch.matmul(net_output[1]['attn'], one_hot)

        # final_dist = vocab_dist.scatter_add(1, encoder_batch_extend_vocab, attn_dist)

        if hasattr(self,
                   'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0]
        if log_probs:
            generate = utils.softmax(
                logits, dim=-1,
                onnx_trace=self.onnx_trace) * net_output[1]['copy_or_generate']
            copy = net_output[1]['attn'] * (1 -
                                            net_output[1]['copy_or_generate'])
            enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat(
                1, net_output[1]['attn'].size(1), 1)
            final = generate.scatter_add(2, enc_seq_ids, copy)
            final = torch.log(final + 1e-15)
            return final
        else:
            generate = utils.softmax(
                logits, dim=-1,
                onnx_trace=self.onnx_trace) * net_output[1]['copy_or_generate']
            copy = net_output[1]['attn'] * (1 -
                                            net_output[1]['copy_or_generate'])
            enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat(
                1, net_output[1]['attn'].size(1), 1)
            final = generate.scatter_add(2, enc_seq_ids, copy)
            return final
def sesim_loss(lprobs, target, epsilon, task=None, bpe=None, rewarder=None, output_tokens=None, ignore_index=None, reduce=True, loss_weight=None, debug=True):
    if loss_weight is None:
        loss_weight = -15
    ## semantic sim_loss
    sentence_tok = torch.argmax(utils.log_softmax(output_tokens, dim=-1),-1) # maxpool
    sentence_txt = bpe.decode(task.target_dictionary.string(sentence_tok)) 

    if ignore_index is not None:
        non_pad_mask = target.ne(ignore_index)
        target_ig = target[non_pad_mask]
    else:
        target_ig = target

    target_txt = bpe.decode(task.target_dictionary.string(target_ig))

    semsim_score = rewarder(target_txt, sentence_txt)
    if debug:
        print("\n\n## sentence_txt: ", sentence_txt,"\n## target_txt: ",  target_txt, "\n## Reward :", semsim_score)

    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        non_pad_mask = target.ne(ignore_index)
        nll_loss = nll_loss[non_pad_mask]
        smooth_loss = smooth_loss[non_pad_mask]
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1. - epsilon) * nll_loss + eps_i * smooth_loss
    if debug:
        print("nll_loss, smooth_loss: ",  nll_loss, smooth_loss)
        print("normal_loss, reward: ",  loss, semsim_score)
    print('loss before:', loss)
    loss = loss * (1 - semsim_score)
    print('loss after:', loss)
    #loss = loss - loss_weight * semsim_score
    # LOG : loss
    # was 1:1, increased to 1: 100 | 20191212
    # original : loss + 100*semsim_score, neg : loss - 100*semsim_score | 20191212
    if debug:
        print("==="*10)
    return loss, nll_loss, semsim_score
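`rewarder` is used only as a callable mapping a reference string and a hypothesis string to a similarity score in [0, 1]. A purely hypothetical stand-in that satisfies this interface (the real rewarder is presumably a learned semantic-similarity model, not this character-level ratio):

from difflib import SequenceMatcher

def dummy_rewarder(reference_txt, hypothesis_txt):
    # Placeholder similarity in [0, 1]; character-level ratio, not a semantic score.
    return SequenceMatcher(None, reference_txt, hypothesis_txt).ratio()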