Example #1
    def _forward_expanded(self, x, incremental_state):
        '''Turn the convolution filters into band matrices and do matrix multiplication.
        This is faster when the sequence is short, but less memory efficient.
        This is not used in the decoder during inference.
        '''
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H
        assert R * H == C == self.input_size
        if self.weight_linear:
            if self.sample_lc_sp:
                # (B, H, K): one filter per batch element, from time+channel mean
                weight = self.weight_linear(
                    torch.mean(torch.mean(x, dim=0), dim=1).unsqueeze(1)
                ).view(B, H, K)
            else:
                # (B, H, K): one filter per batch element, from time mean
                weight = self.weight_linear(torch.mean(x, dim=0)).view(B, H, K)
        else:
            weight = self.weight.view(H, K)
        if self.weight_softmax:
            if self.weight_linear:
                weight = utils.softmax(
                    weight, dim=2, onnx_trace=self.onnx_trace).type_as(weight)
            else:
                weight = utils.softmax(
                    weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
        if self.weight_linear:
            weight = (weight.unsqueeze(0).expand(T, B, H, K)
                      .reshape(T * B, H, K).contiguous())
        else:
            weight = weight.view(1, H, K).expand(T * B, H, K).contiguous()
        weight = weight.view(T, B * H, K).transpose(0, 1)  # (B*H, T, K)

        x = x.view(T, B * H, R).transpose(0, 1)  # (B*H,T,R)
        P = self.padding_l
        if K > T and P == K - 1:
            weight = weight.narrow(2, K - T, T)
            K, P = T, T - 1
        # turn the convolution filters into band matrices
        weight_expanded = weight.new_zeros(B * H,
                                           T,
                                           T + K - 1,
                                           requires_grad=False)
        weight_expanded.as_strided((B * H, T, K),
                                   (T * (T + K - 1), T + K, 1)).copy_(weight)
        weight_expanded = weight_expanded.narrow(2, P, T)  # (B*H,T,T)
        weight_expanded = F.dropout(weight_expanded,
                                    self.weight_dropout,
                                    training=self.training)

        if bmm_fp16_support:  # module-level flag: whether bmm is numerically safe in fp16
            output = torch.bmm(weight_expanded, x)  # (B*H,T,R)
        else:
            output = torch.bmm(weight_expanded.float(),
                               x.float()).type_as(weight)
        output = output.transpose(0, 1).contiguous().view(T, B, C)
        return output
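
A quick sanity check for the band-matrix trick above is to compare it with a plain depthwise causal convolution. The sketch below is standalone and illustrative (toy shapes, torch only); it is not part of the original module.

import torch
import torch.nn.functional as F

T, B, H, R, K = 6, 2, 3, 4, 3
C = H * R
x = torch.randn(T, B, C)
w = torch.randn(H, K)

# reference: depthwise causal conv, each head's K taps shared by its R channels
w_full = w.repeat_interleave(R, dim=0).view(C, 1, K)
ref = F.conv1d(F.pad(x.permute(1, 2, 0), (K - 1, 0)), w_full, groups=C)
ref = ref.permute(2, 0, 1)  # back to T x B x C

# band-matrix route, mirroring _forward_expanded with P = K - 1
wt = w.view(1, H, K).expand(T * B, H, K).contiguous().view(T, B * H, K).transpose(0, 1)
xe = x.view(T, B * H, R).transpose(0, 1)
band = wt.new_zeros(B * H, T, T + K - 1)
band.as_strided((B * H, T, K), (T * (T + K - 1), T + K, 1)).copy_(wt)
band = band.narrow(2, K - 1, T)  # (B*H, T, T)
out = torch.bmm(band, xe).transpose(0, 1).contiguous().view(T, B, C)
print(torch.allclose(ref, out, atol=1e-5))  # expected: True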
Example #2
    def get_normalized_probs(self, ctc_logits, logits, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        if log_probs:
            ctc_res = utils.log_softmax(ctc_logits.float(), dim=-1)
            res = utils.log_softmax(logits.float(), dim=-1)
        else:
            ctc_res = utils.softmax(ctc_logits.float(), dim=-1)
            res = utils.softmax(logits.float(), dim=-1)
        ctc_res.batch_first = True
        res.batch_first = True

        return ctc_res, res
Example #3
    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits_ctc = net_output["logits_ctc"]
        logits = net_output["logits"]

        if log_probs:
            ctc_res = utils.log_softmax(logits_ctc.float(), dim=-1)
            res = utils.log_softmax(logits.float(), dim=-1)
        else:
            ctc_res = utils.softmax(logits_ctc.float(), dim=-1)
            res = utils.softmax(logits.float(), dim=-1)

        return ctc_res, res
    def cross_attentive_loss(self,
                             teacher_states,
                             student_states,
                             teacher_masking,
                             student_masking,
                             eps=1e-6):
        x = teacher_states.transpose(0, 1)  # from T X B X D to B X T X D
        y = student_states.transpose(0, 1)
        if self.cross_attentive_loss_with_norm:
            x = x / (x.norm(dim=2, keepdim=True) + eps)
            y = y / (y.norm(dim=2, keepdim=True) + eps)
        dim = x.size(-1)
        # lengths: batch X seqLen
        sim_scores_xy = torch.bmm(x, y.transpose(1, 2))  # batch x lenx x leny
        if y.dtype == torch.float16:
            sim_scores_xy = sim_scores_xy.float()
            y = y.float()
            x = x.float()
        if teacher_masking != []:
            assert len(teacher_masking) == 1
            sim_scores_xy = sim_scores_xy.masked_fill(
                teacher_masking[0].unsqueeze(-1), float("-inf"))
        if student_masking != []:
            sim_scores_xy = sim_scores_xy.masked_fill(
                student_masking[0].unsqueeze(1), float("-inf"))
        y_weights = utils.softmax(sim_scores_xy, dim=-1)
        if teacher_masking != []:
            y_weights = y_weights.masked_fill(teacher_masking[0].unsqueeze(-1),
                                              0)
        x_reconstruct_from_y = torch.bmm(y_weights, y)

        sim_scores_xx = torch.bmm(x, x.transpose(1, 2))  # batch x lenx x lenx
        x_weights = utils.softmax(sim_scores_xx, dim=-1)
        if teacher_masking != []:
            x_weights = x_weights.masked_fill(teacher_masking[0].unsqueeze(-1),
                                              0)

        # no gradient for teacher state
        x_reconstruct_from_x = torch.bmm(x_weights, x).detach()
        cost = (x_reconstruct_from_x - x_reconstruct_from_y).norm(dim=2)
        if teacher_masking != []:
            cost = cost.masked_fill(teacher_masking[0], 0)

        if not self.cross_attentive_loss_with_norm:
            cost = cost / dim
        return cost
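
Stripped of padding masks and fp16 casts, the loss above reduces to a short tensor pipeline. A standalone restatement with illustrative shapes (F.normalize standing in for the manual eps-normalization):

import torch
import torch.nn.functional as F

Tt, Ts, B, D = 7, 5, 2, 16
x = F.normalize(torch.randn(Tt, B, D).transpose(0, 1), dim=2)  # teacher: B x Tt x D
y = F.normalize(torch.randn(Ts, B, D).transpose(0, 1), dim=2)  # student: B x Ts x D

y_weights = (x @ y.transpose(1, 2)).softmax(dim=-1)  # attend teacher -> student
x_from_y = y_weights @ y                             # teacher rebuilt from student states
x_weights = (x @ x.transpose(1, 2)).softmax(dim=-1)
x_from_x = (x_weights @ x).detach()                  # teacher self-reconstruction, no grad
cost = (x_from_x - x_from_y).norm(dim=2)             # B x Tt, per-position loss
print(cost.shape)  # torch.Size([2, 7])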
Example #5
    def get_normalized_probs(self,
                             net_output,
                             log_probs,
                             sample,
                             adaptive_softmax=True):
        """Get normalized probabilities (or log probs) from a net's output."""
        if adaptive_softmax:
            if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
                if sample is not None:
                    assert 'target' in sample
                    target = sample['target']
                else:
                    target = None
                out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                         target=target)
                return out.exp_() if not log_probs else out

        # otherwise fall back to a standard (log-)softmax over the full vocabulary
        logits = net_output[0] if isinstance(net_output, list) else net_output
        if log_probs:
            return utils.log_softmax(logits,
                                     dim=-1,
                                     onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
Example #6
    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self,
                   'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out
        # Earlier multi-output variant, kept for reference:
        # logits_list = net_output[0]
        # if log_probs:
        #     return [utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        #             for logits in logits_list][0]
        # else:
        #     return [utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        #             for logits in logits_list][0]
        logits = net_output[0]
        if log_probs:
            return utils.log_softmax(logits,
                                     dim=-1,
                                     onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    def get_normalized_probs_with_temperature(self,
                                              net_output,
                                              log_probs,
                                              sample=None,
                                              temperature=1.):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self,
                   'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0] / temperature
        if log_probs:
            return utils.log_softmax(logits,
                                     dim=-1,
                                     onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
    def _forward_unfolded(self, x, incremental_state):
        '''The conventional implementation of convolutions.
        Unfolding the input by having a window shifting to the right.'''
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H
        assert R * H == C == self.input_size

        weight = self.weight.view(H, K)
        if incremental_state is not None:
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                input_buffer = x.new()
            x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3)
            if self.kernel_size > 1:
                self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:])
            x_unfold = x_unfold.view(T*B*H, R, -1)
        else:
            # unfold the input: T x B x C --> T' x B x C x K
            x_unfold = unfold1d(x, self.kernel_size, self.padding_l, 0)
            x_unfold = x_unfold.view(T*B*H, R, K)

        if self.weight_softmax:
            weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)

        if incremental_state is not None:
            weight = weight[:, -x_unfold.size(2):]
            K = weight.size(1)

        weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1)

        weight = self.weight_dropout_module(weight)
        output = torch.bmm(x_unfold, weight)  # T*B*H x R x 1
        output = output.view(T, B, C)
        return output
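
The incremental-state branch above maintains a rolling window of the last K - 1 input steps. A toy loop with the buffer handling inlined (illustrative shapes, not the original module):

import torch

K, B, C = 3, 1, 2
buf = torch.empty(1, B, C, 0)  # empty buffer, like x.new()
for _ in range(4):             # four decoding steps, T = 1 each
    x = torch.randn(1, B, C)
    window = torch.cat([buf, x.unsqueeze(3)], dim=3)
    buf = window[:, :, :, -(K - 1):]  # keep only the last K - 1 steps
    print(window.size(3))      # prints 1, 2, 3, 3: the window saturates at K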
Example #9
    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)
Example #10
    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0]
        if log_probs:
            if use_ort_backend:
                # ONNX Runtime export path: the traced decoder hands back a raw
                # tensor, so net_output itself holds the logits here
                return utils.log_softmax(net_output, dim=-1, onnx_trace=self.onnx_trace)

            return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
Example #11
    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self,
                   "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "source" in sample
                source = sample["source"]
            else:
                source = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=source)
            return out.exp_() if not log_probs else out

        logits = net_output[0]
        if log_probs:
            return utils.log_softmax(logits,
                                     dim=-1,
                                     onnx_trace=self.onnx_trace)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
Example #12
    def forward(self, query, value, key_padding_mask=None, state=None):
        # projected_query: 1 x bsz x embed_dim
        projected_query = self.query_proj(query).unsqueeze(0)
        key = self.value_proj(value)  # len x bsz x embed_dim
        if self.normalize:
            # normed_v = g * v / ||v||
            normed_v = self.g * self.v / torch.norm(self.v)
            attn_scores = (normed_v * torch.tanh(
                projected_query + key + self.b)).sum(dim=2)  # len x bsz
        else:
            # same reduction as above, without the weight-norm reparameterization
            attn_scores = (self.v * torch.tanh(projected_query + key)).sum(dim=2)

        if key_padding_mask is not None:
            attn_scores = attn_scores.float().masked_fill_(
                key_padding_mask,
                float('-inf'),
            ).type_as(attn_scores)  # FP16 support: cast to float and back

        attn_scores = utils.softmax(attn_scores,
                                    dim=0,
                                    onnx_trace=self.onnx_trace).type_as(
                                        attn_scores)  # len x bsz

        # sum weighted value. context: bsz x value_dim
        context = (attn_scores.unsqueeze(2) * value).sum(dim=0)
        next_state = attn_scores

        return context, attn_scores, next_state
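
For reference, the un-normalized branch scores each source position by reducing over the embedding dimension before the length-wise softmax. A toy shape walk-through (illustrative sizes; v stands in for the learned vector):

import torch

S, B, D = 5, 2, 8
v = torch.randn(D)
q = torch.randn(1, B, D)                     # projected query, broadcast over S
k = torch.randn(S, B, D)
scores = (v * torch.tanh(q + k)).sum(dim=2)  # S x B
attn = torch.softmax(scores, dim=0)
print(attn.sum(0))                           # each column sums to 1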
    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output
        if log_probs:
            return utils.log_softmax(logits, dim=-1)
        else:
            return utils.softmax(logits, dim=-1)
    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        model = models[0]

        # encoder_output = model.encoder(tbc=False, **sample["net_input"])
        # alphas = CIFFcModel.get_alphas(encoder_output)
        # decode_length = torch.round(alphas.sum(-1)).int()
        # _alphas, num_output = model.resize(alphas, decode_length, noise=0.0)
        #
        # padding_mask = ~utils.sequence_mask(decode_length).bool()
        # cif_outputs = model.cif(encoder_output['encoder_out'][:, :, :-1], _alphas)
        # hidden = model.proj(cif_outputs)
        # logits_ac = model.to_vocab_ac(hidden)
        #
        # infer_threash = self.infer_threshold if self.infer_threshold else model.args.infer_threash
        # for i in range(1):
        #     logits, gold_embedding, pred_mask, token_mask = model.bert_forward(
        #         hidden, logits_ac, padding_mask, None, 0.0,
        #         threash=infer_threash)
        #     logits = self.args.lambda_am * logits_ac + model.args.lambda_lm * logits
        # probs = utils.softmax(logits.float(), dim=-1)
        net_output = model(**sample["net_input"])
        logits = net_output['logits']
        probs = utils.softmax(logits.float(), dim=-1)
        decode_length = net_output['len_logits']

        res = []
        for distribution, length in zip(probs, decode_length):
            result = distribution.argmax(-1)
            score = 0.0
            res.append([{'tokens': result[:length], "score": score}])

        return res
Example #15
    def _forward_expanded(self, x, incremental_state):
        """Turn the convolution filters into band matrices and do matrix multiplication.
        This is faster when the sequence is short, but less memory efficient.
        This is not used in the decoder during inference.
        """
        T, B, C = x.size()
        K, H = self.kernel_size, self.num_heads
        R = C // H
        assert R * H == C == self.input_size

        weight = self.weight.view(H, K)
        if self.weight_softmax:
            weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(
                weight
            )
        weight = weight.view(1, H, K).expand(T * B, H, K).contiguous()
        weight = weight.view(T, B * H, K).transpose(0, 1)

        x = x.view(T, B * H, R).transpose(0, 1)
        P = self.padding_l
        if K > T and P == K - 1:
            weight = weight.narrow(2, K - T, T)
            K, P = T, T - 1
        # turn the convolution filters into band matrices
        weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False)
        weight_expanded.as_strided((B * H, T, K), (T * (T + K - 1), T + K, 1)).copy_(
            weight
        )
        weight_expanded = weight_expanded.narrow(2, P, T)
        weight_expanded = self.weight_dropout_module(weight_expanded)

        output = torch.bmm(weight_expanded, x)
        output = output.transpose(0, 1).contiguous().view(T, B, C)
        return output
Example #16
    def bert_forward(self, hidden, logits_ac, padding_mask, input_ids=None, gold_rate=0.0, threash=0.8):
        """
        """
        device = hidden.device

        if self.training:
            token_mask = input_ids.ne(self.tgt_dict.cls()) * \
                         input_ids.ne(self.tgt_dict.sep()) * \
                         input_ids.ne(self.tgt_dict.pad())
            gold_embedding = self.bert.embeddings.word_embeddings(input_ids)
            pred_mask = (torch.rand(input_ids.size(), device=device) > gold_rate) * token_mask
        else:  # inference
            # pad one position on each side for the [CLS]/[SEP] slots
            token_mask = F.pad(~padding_mask, [1, 1, 0, 0], value=0)
            probs = F.pad(utils.softmax(logits_ac.float(), dim=-1), [0, 0, 1, 1, 0, 0], value=0)
            confident, preds = probs.max(-1)
            preds_ids = pred2bert_input(preds, token_mask)  # helper defined elsewhere
            # preds = torch.where(token_mask, preds, input_ids)
            gold_embedding = self.bert.embeddings.word_embeddings(preds_ids)
            pred_mask = (confident < threash) * token_mask

        hidden_mix = torch.where(pred_mask[:, :, None].repeat(1, 1, hidden.size(-1)),
                                 F.pad(hidden, [0, 0, 1, 1, 0, 0], value=0),
                                 gold_embedding)

        attention_mask = padding2attention_mask(padding_mask)

        embeddings = self.bert.embeddings(inputs_embeds=hidden_mix)
        encoder_outputs = self.bert.encoder(
            embeddings,
            attention_mask=attention_mask[:, None, None, :])

        logits = self.to_vocab(encoder_outputs[0])
        logits = logits[:, 1:-1, :]

        return logits, gold_embedding, pred_mask, token_mask
Example #17
    def attn_fn(attn_weights, is_query=False):
        if attn_mask is not None:
            attn_weights += attn_mask
        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, src_len,
                                             src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, src_len,
                                             src_len)
        if is_query:
            # block each position from attending to itself, except position 0
            query_mask = torch.eye(
                attn_weights.size(-1)).to(attn_weights) * -1e9
            query_mask[0][0] = 0.0
            attn_weights = attn_weights + query_mask
        attn_weights = utils.softmax(
            attn_weights,
            dim=-1,
        ).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights,
                                 p=self.dropout,
                                 training=self.training)
        return attn_weights
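
In isolation, the query_mask above forbids every position from attending to itself except position 0 (toy 4-step scores, illustrative only):

import torch

m = torch.eye(4) * -1e9
m[0][0] = 0.0
attn = torch.softmax(torch.randn(4, 4) + m, dim=-1)
print(attn.diagonal())  # near zero everywhere except entry 0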
Example #18
    def get_normalized_probs(self,
                             net_output,
                             log_probs,
                             sample,
                             gs_tau=0.5,
                             gs_hard=False):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self,
                   'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0][0]
        orders = net_output[0][1]
        if log_probs:
            return (utils.log_softmax(logits,
                                      dim=-1,
                                      onnx_trace=self.onnx_trace),
                    self.gumbel_softmax(orders,
                                        gs_tau=gs_tau,
                                        gs_hard=gs_hard,
                                        dim=-1))
        else:
            return (utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace),
                    self.gumbel_softmax(orders,
                                        gs_tau=gs_tau,
                                        gs_hard=gs_hard,
                                        dim=-1))
Example #19
    def get_normalized_probs(self, net_output, log_probs, return_ctc=False):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits_ctc = net_output["logits_ctc"]
        logits = net_output["logits"]
        if log_probs:
            res_ctc = utils.log_softmax(logits_ctc.float(), dim=-1)
            res = utils.log_softmax(logits.float(), dim=-1)
        else:
            res_ctc = utils.softmax(logits_ctc.float(), dim=-1)
            res = utils.softmax(logits.float(), dim=-1)
        res_ctc.batch_first = True
        res.batch_first = True

        if return_ctc:
            return res_ctc, res
        else:
            return res
Example #20
    def get_normalized_probs_w2v(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["wav2vec_logits"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)
Example #21
    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self,
                   'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0]

        is_copy = ('p_copy' in net_output[1]
                   and net_output[1]['p_copy'] is not None)
        if is_copy and False:  # copy path intentionally disabled
            p_copy = net_output[1]['p_copy']
            if 'net_input' in sample.keys():
                enc_seq_ids = sample['net_input']['src_tokens']
            else:
                # for decode step
                enc_seq_ids = sample['src_tokens']
            enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat(
                1, net_output[1]['copy_attn'].size(1), 1)
            generate_prob = utils.softmax(
                logits, dim=-1, onnx_trace=self.onnx_trace) * (1 - p_copy)
            copy_prob = net_output[1]['copy_attn'] * p_copy
            final = generate_prob.scatter_add(2, enc_seq_ids, copy_prob)
            if log_probs:
                return torch.log(final + 1e-15)
            else:
                return final
        else:
            if log_probs:
                return utils.log_softmax(logits,
                                         dim=-1,
                                         onnx_trace=self.onnx_trace)
            else:
                return utils.softmax(logits,
                                     dim=-1,
                                     onnx_trace=self.onnx_trace)
Example #22
    def forward(self, x, need_attention_weights=False):
        # Attention scores:
        alpha = self.w2(self.w1(x))  # B, Tt, Ts, 1
        alpha = utils.softmax(alpha, dim=2).type_as(alpha)
        x = x.permute(0, 1, 3, 2)
        x = torch.matmul(x, alpha).squeeze(-1)
        if need_attention_weights:
            return x, alpha.squeeze(-1)
        return x, None
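
A shape walk-through of this pooling attention, with hypothetical sizes for the w1/w2 projections:

import torch
import torch.nn as nn

B, Tt, Ts, C = 2, 4, 5, 8
w1, w2 = nn.Linear(C, C), nn.Linear(C, 1)
x = torch.randn(B, Tt, Ts, C)
alpha = torch.softmax(w2(w1(x)), dim=2)  # B x Tt x Ts x 1
pooled = torch.matmul(x.permute(0, 1, 3, 2), alpha).squeeze(-1)
print(pooled.shape)  # torch.Size([2, 4, 8])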
Example #23
    def get_normalized_probs_cif(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits = self.get_logits_cif(net_output)

        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)
Example #24
    def one_step(self, x, need_attention_weights=False):
        x = x[:, -1:]  # B, 1, Ts, C
        alpha = self.w2(self.w1(x))  # B, 1, Ts, 1
        alpha = utils.softmax(alpha, dim=2)
        x = x.permute(0, 1, 3, 2)
        x = torch.matmul(x, alpha).squeeze(-1)
        if need_attention_weights:
            return x, alpha.squeeze(-1)
        return x, None
Example #25
    def forward(self, x, need_attention_weights=False):
        # Attention scores:
        B, Tt, Ts, C = x.size()
        alpha = self.w2(self.w1(x))  # B, Tt, Ts, 1
        # causal mask: every (t, j) may attend only to source positions <= j
        mask = torch.triu(utils.fill_with_neg_inf(x.new(Ts, Ts)), 1).type_as(alpha)
        alpha = alpha.permute(0, 1, 3, 2) + mask.unsqueeze(0).unsqueeze(0)  # B, Tt, Ts, Ts
        alpha = utils.softmax(alpha, dim=-1)
        x = torch.matmul(alpha, x)
        return x, None
Example #26
    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output[0]
        if log_probs:
            res = utils.log_softmax(logits.float(), dim=-1)
        else:
            res = utils.softmax(logits.float(), dim=-1)
        res.batch_first = True

        return res
Example #27
    def forward(self, x, need_attention_weights=False):
        # Attention scores:
        B, Tt, Ts, C = x.size()
        alpha = self.w2(self.w1(x))  # B, Tt, Ts, 1
        # wait-k mask: target step t attends to at most the first t + waitk source positions
        mask = torch.triu(utils.fill_with_neg_inf(x.new(Tt, Ts)), self.waitk)
        alpha = utils.softmax(alpha + mask.unsqueeze(0).unsqueeze(-1), dim=2).type_as(alpha)
        x = x.permute(0, 1, 3, 2)
        x = torch.matmul(x, alpha).squeeze(-1)
        if need_attention_weights:
            return x, alpha.squeeze(-1)
        return x, None
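
What the wait-k mask looks like on its own (toy sizes, waitk = 2; torch.full stands in for utils.fill_with_neg_inf): row t leaves exactly the first t + 2 source positions unmasked.

import torch

Tt, Ts, waitk = 3, 5, 2
mask = torch.triu(torch.full((Tt, Ts), float('-inf')), waitk)
print(mask)  # row 0 masks source positions 2..4, row 1 masks 3..4, row 2 masks 4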
    def gumbel_softmax(self, logits, gs_tau=0.5, gs_hard=False, dim=-1):
        prob = utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
        prob_clamp = torch.clamp(
            prob, self.clamp_value,
            1. - (self.decoder_max_order - 1) * self.clamp_value)
        if not gs_hard:
            logprob = torch.log(prob_clamp if self.gs_clamp else prob)
            gs = F.gumbel_softmax(logprob, tau=gs_tau, hard=False)
        else:
            # hard: a straight one-hot of the argmax
            max_idx = torch.argmax(logits, -1, keepdim=True)
            one_hot = logits.new_zeros(logits.size())
            gs = one_hot.scatter(-1, max_idx, 1)
        return gs, prob, prob_clamp
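
A standalone run of the soft branch above, with illustrative stand-ins for clamp_value and decoder_max_order:

import torch
import torch.nn.functional as F

clamp_value, max_order = 1e-3, 4
logits = torch.randn(2, max_order)
prob = logits.softmax(-1)
prob_clamp = prob.clamp(clamp_value, 1. - (max_order - 1) * clamp_value)
gs = F.gumbel_softmax(prob_clamp.log(), tau=0.5, hard=False)
print(gs.sum(-1))  # each sample still sums to 1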
    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output."""
        if 'net_input' in sample.keys():
            enc_seq_ids = sample['net_input']['src_tokens']
        else:
            enc_seq_ids = sample['src_tokens']

        # wvocab_size = net_output[0].size(2)
        # batch_size = enc_seq_ids.size(0)
        # seq_len = enc_seq_ids.size(1)
        # one_hot = torch.zeros(batch_size, seq_len, wvocab_size).cuda().scatter_(dim=2, index=enc_seq_ids.unsqueeze(-1), value=1)
        #
        # copy_probs = torch.matmul(net_output[1]['attn'], one_hot)

        # final_dist = vocab_dist.scatter_add(1, encoder_batch_extend_vocab, attn_dist)

        if hasattr(self,
                   'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0]
        if log_probs:
            generate = utils.softmax(
                logits, dim=-1,
                onnx_trace=self.onnx_trace) * net_output[1]['copy_or_generate']
            copy = net_output[1]['attn'] * (1 -
                                            net_output[1]['copy_or_generate'])
            enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat(
                1, net_output[1]['attn'].size(1), 1)
            final = generate.scatter_add(2, enc_seq_ids, copy)
            final = torch.log(final + 1e-15)
            return final
        else:
            # mix in probability space, matching the log branch above
            generate = utils.softmax(
                logits, dim=-1,
                onnx_trace=self.onnx_trace) * net_output[1]['copy_or_generate']
            copy = net_output[1]['attn'] * (1 -
                                            net_output[1]['copy_or_generate'])
            enc_seq_ids = enc_seq_ids.unsqueeze(1).repeat(
                1, net_output[1]['attn'].size(1), 1)
            final = generate.scatter_add(2, enc_seq_ids, copy)
            return final
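
The heart of the copy mechanism is the scatter_add that routes copy mass onto source token ids. On a toy vocabulary of 6 (all numbers illustrative):

import torch

V, B, Tt, Ts = 6, 1, 1, 3
generate = torch.softmax(torch.randn(B, Tt, V), -1) * 0.7  # p(generate) = 0.7
copy = torch.softmax(torch.randn(B, Tt, Ts), -1) * 0.3     # p(copy) = 0.3
src_ids = torch.tensor([[[2, 4, 2]]])                      # B x Tt x Ts source ids
final = generate.scatter_add(2, src_ids, copy)
print(final.sum(-1))  # tensor([[1.]]): still a distribution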
Example #30
    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output."""

        if hasattr(self,
                   'adaptive_softmax') and self.adaptive_softmax is not None:
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output[0],
                                                     target=target)
            return out.exp_() if not log_probs else out

        logits = net_output[0]
        copy_scores = net_output[1]["copy_scores"]

        p_copy = net_output[1]["p_copy"].float()
        if log_probs:
            return torch.log((1 - p_copy) * utils.softmax(logits, dim=-1) +
                             p_copy * copy_scores.float())
        else:
            return (1 - p_copy) * utils.softmax(
                logits, dim=-1) + p_copy * copy_scores.float()