Example #1
 def buffered_mask(self, tensor):
     dim = tensor.size(-1)
     if self._mask is None:
         self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
     if self._mask.size(0) < dim:
         self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
     return self._mask[:dim, :dim]
 def buffered_future_mask(self, tensor):
     dim = tensor.size(0)
     if not hasattr(
             self, '_future_mask'
     ) or self._future_mask is None or self._future_mask.device != tensor.device:
         self._future_mask = torch.triu(
             utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
     if self._future_mask.size(0) < dim:
         self._future_mask = torch.triu(
             utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)),
             1)
     return self._future_mask[:dim, :dim]
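For reference, a minimal standalone sketch (not from the original repository) of the kind of causal mask these helpers cache. It assumes only that utils.fill_with_neg_inf fills a tensor with -inf, so plain torch.full is used in its place:

import torch

dim = 4
# -inf strictly above the diagonal blocks attention to future positions;
# everything on and below the diagonal stays 0 (i.e. unmasked).
future_mask = torch.triu(torch.full((dim, dim), float('-inf')), 1)
print(future_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])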
Example #3
 def buffered_future_mask(self, tensor):
     """Cached future mask."""
     dim = tensor.size(0)
     #pylint: disable=access-member-before-definition, attribute-defined-outside-init
     if not hasattr(
             self, '_future_mask'
     ) or self._future_mask is None or self._future_mask.device != tensor.device:
         self._future_mask = torch.triu(
             utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
     if self._future_mask.size(0) < dim:
         self._future_mask = torch.triu(
             utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)),
             1)
     return self._future_mask[:dim, :dim]
Example #4
 def forward(self, x, need_attention_weights=False):
     if not need_attention_weights:
         # Maxpool
         B, Tt, Ts, C = x.size()
         mask = torch.triu(utils.fill_with_neg_inf(x.new(Tt, Ts)),
                           self.waitk)
         # print('Mask (%d, %d):' % (Tt, Ts), mask)
         # for t in range(Tt):
         # ctx = min((t // 1 * 1)  + self.waitk, Ts)
         # print('z_%d = %d' % (t, ctx))
         x, _ = (x + mask.unsqueeze(0).unsqueeze(-1)).max(dim=2)  # B, Tt, C
         return x, None
     # Output attention weights:
     if need_attention_weights:
         # x in B, Tt, Ts, C
         B, Tt, Ts, C = x.size()
         x, indices = x.max(dim=2)
         # indices in B, Tt, C with each channel selecting a source position
         # Terrible but will do:
         attn = x.new_zeros(B, Tt, Ts)
         for i in range(Ts):
             attn[:, :, i] = indices.eq(i).sum(dim=-1)
         # Normalize
         attn = attn / attn.sum(dim=-1, keepdim=True)
     return x, attn
 def buffered_future_mask_short(self, tensor, line):
     dim = tensor.size(1)
     self._future_mask = torch.triu(
         utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
     for i in range(line, dim):
         self._future_mask[i] = float('-inf')  # mask out every position in rows `line` .. dim-1
     return self._future_mask
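To visualize the wait-k mask built in the max-pooling branch above, here is a small self-contained sketch with made-up sizes (Tt = 3 target steps, Ts = 5 source positions, waitk = 2); torch.full stands in for utils.fill_with_neg_inf:

import torch

Tt, Ts, waitk = 3, 5, 2
# Row t (target step) may attend source positions 0 .. t + waitk - 1;
# positions further to the right are masked with -inf.
mask = torch.triu(torch.full((Tt, Ts), float('-inf')), waitk)
print(mask)
# tensor([[0., 0., -inf, -inf, -inf],
#         [0., 0.,   0., -inf, -inf],
#         [0., 0.,   0.,   0., -inf]])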
Example #6
 def mask(self, tensor):
     dim = tensor.size(-1)
     half_dim = dim // 2
     ones = tensor.new_ones(half_dim, dim).byte()
     mask = ones.triu(half_dim + 1) + ones.tril(-1)
     mask = utils.fill_with_neg_inf(tensor.new(mask.size())).masked_fill_(
         mask, 0)
     return mask
Example #7
 def _forward_alpha(self, emissions, M):
     Tt, B, Ts = emissions.size()
     alpha = utils.fill_with_neg_inf(
         torch.empty_like(emissions))  # Tt, B, Ts
     # initialization  t=1
     # initial = torch.empty_like(alpha[0]).fill_(-math.log(Ts))  # log(1/Ts)
     initial = utils.fill_with_neg_inf(torch.empty_like(alpha[0]))
     initial[:, 0] = 0
     alpha[0] = emissions[0] + initial
     # print('Initialize alpha:', alpha[0])
     # induction
     for i in range(1, Tt):
         alpha[i] = torch.logsumexp(alpha[i - 1].unsqueeze(-1) + M[i - 1],
                                    dim=1)
         alpha[i] = alpha[i] + emissions[i]
         # print('Emissions@', i, emissions[i])
         # print('alpha@',i, alpha[i])
     return alpha
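The loop above is the standard forward algorithm run in log space. A minimal self-contained sketch of the same recursion, with random log-probabilities standing in for the model's emissions and transitions (shapes as in the comments: Tt target steps, B sequences, Ts source states):

import torch

Tt, B, Ts = 4, 2, 3
emissions = torch.randn(Tt, B, Ts).log_softmax(dim=-1)  # log p(y_t | z_t = j)
M = torch.randn(Tt, B, Ts, Ts).log_softmax(dim=-1)      # log p(z_{t+1} = j | z_t = k)

alpha = torch.full_like(emissions, float('-inf'))
alpha[0, :, 0] = emissions[0, :, 0]          # the chain starts in the first source state
for i in range(1, Tt):
    alpha[i] = emissions[i] + torch.logsumexp(alpha[i - 1].unsqueeze(-1) + M[i - 1], dim=1)

log_likelihood = torch.logsumexp(alpha[-1], dim=-1)  # one log-marginal per sequence
print(log_likelihood.shape)                          # torch.Size([2])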
Example #8
 def buffered_future_mask(self, tensor):
     dim = tensor.size(0)
     # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
     if (self._future_mask.size(0) == 0
             or (not self._future_mask.device == tensor.device)
             or self._future_mask.size(0) < dim):
         self._future_mask = torch.triu(
             utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1)
     return self._future_mask[:dim, :dim]
Example #9
    def local_mask(self, tensor, kernel_size, causal, tgt_len=None):
        """Locality constraint mask."""
        rows = tensor.size(0)
        cols = tensor.size(0) if tgt_len is None else tgt_len
        if causal:
            if rows == 1:
                mask = utils.fill_with_neg_inf(tensor.new(1, cols))
                mask[0, -kernel_size:] = 0
                return mask
            else:
                diag_u, diag_l = 1, kernel_size
        else:
            diag_u, diag_l = ((kernel_size + 1) // 2, (kernel_size + 1) // 2) if kernel_size % 2 == 1 \
                else (kernel_size // 2, kernel_size // 2 + 1)
        mask1 = torch.triu(utils.fill_with_neg_inf(tensor.new(rows, cols)), diag_u)
        mask2 = torch.tril(utils.fill_with_neg_inf(tensor.new(rows, cols)), -diag_l)

        return mask1 + mask2
 def buffered_future_mask(self, tensor):
     """attend all surounding words except itself
        [[0, -inf, 0]
         [0,  0, -inf]
         [0,  0,   0]]
     The attention map is not ture diagonal since we predict y_{t+1} at time-step t
     """
     dim = tensor.size(0)
     if (not hasattr(self, "_future_mask") or self._future_mask is None
             or self._future_mask.device != tensor.device):
         self._future_mask = torch.triu(
             utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
         self._future_mask = torch.tril(self._future_mask, 1)
     if self._future_mask.size(0) < dim:
         self._future_mask = torch.triu(
             utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)),
             1)
         self._future_mask = torch.tril(self._future_mask, 1)
     return self._future_mask[:dim, :dim]
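To make the docstring above concrete, a standalone reproduction of the triu-then-tril construction for dim = 4 (plain torch, with torch.full in place of utils.fill_with_neg_inf):

import torch

dim = 4
m = torch.triu(torch.full((dim, dim), float('-inf')), 1)  # -inf strictly above the diagonal
m = torch.tril(m, 1)                                      # keep only the first upper off-diagonal
print(m)
# tensor([[0., -inf,   0.,   0.],
#         [0.,   0., -inf,   0.],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])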
Example #11
 def forward(self, x, need_attention_weights=False):
     # Attention scores:
     B, Tt, Ts, C = x.size()
     alpha = self.w2(self.w1(x))  # B, Tt, Ts, 1
     # for every (t,j) allow first j
     mask = torch.triu(utils.fill_with_neg_inf(x.new(Ts, Ts)), 1).type_as(alpha)
     alpha = alpha.permute(0,1,3,2) + mask.unsqueeze(0).unsqueeze(0)  # B,Tt,Ts,Ts
     alpha = utils.softmax(alpha, dim=-1)
     x = torch.matmul(alpha, x)
     return x, None
 def buffered_past_mask(self, tensor):
     dim = tensor.size(0)
     if self.onnx_trace:
         a = torch._dim_arange(tensor, 0).unsqueeze(0).repeat(dim, 1)
         b = torch._dim_arange(tensor, 0).unsqueeze(1).repeat(1, dim)
         past_mask = a < b
         past_mask_neg_inf = torch.where(past_mask,
                                         torch.Tensor([float("-Inf")]),
                                         torch.Tensor([0])).type_as(tensor)
         return past_mask_neg_inf
     if not hasattr(
             self, '_past_mask'
     ) or self._past_mask is None or self._past_mask.device != tensor.device:
         self._past_mask = torch.tril(
             utils.fill_with_neg_inf(tensor.new(dim, dim)), -1)
     if self._past_mask.size(0) < dim:
         self._past_mask = torch.tril(
             utils.fill_with_neg_inf(self._past_mask.resize_(dim, dim)), -1)
     return self._past_mask[:dim, :dim]
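A small standalone check that the comparison-based past mask built in the ONNX branch matches the ordinary tril construction; torch.arange is used here instead of the private torch._dim_arange helper, and torch.full stands in for utils.fill_with_neg_inf:

import torch

dim = 5
a = torch.arange(dim).unsqueeze(0).repeat(dim, 1)  # column index j
b = torch.arange(dim).unsqueeze(1).repeat(1, dim)  # row index i
past_mask = torch.where(a < b, torch.tensor(float('-inf')), torch.tensor(0.0))

reference = torch.tril(torch.full((dim, dim), float('-inf')), -1)
print(torch.equal(past_mask, reference))  # True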
Example #13
    def mask(self, tensor, mask_curr):
        dim = tensor.size(-1)
        half_dim = dim // 2

        add = 1 if mask_curr else 0

        ones = tensor.new_ones(half_dim, dim).byte()
        mask = ones.triu(half_dim + add) + ones.tril(-add)
        mask = utils.fill_with_neg_inf(tensor.new(mask.size())).masked_fill_(
            mask, 0)
        return mask
Example #14
 def fill_controls_emissions_grid(self, controls, emissions, indices, src_length):
     """
     Return controls (C) and emissions (E) covering all the grid
     C : Tt, N, Ts, 2
     E : Tt, N, Ts
     """
     N = controls[0].size(0)
     tgt_length = len(controls)
     Cread = controls[0].new_zeros((tgt_length, src_length, N, 1))
     Cwrite = utils.fill_with_neg_inf(torch.empty_like(Cread))
     triu_mask = torch.triu(controls[0].new_ones(tgt_length, src_length), 1).byte()
     triu_mask = triu_mask.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, N, 1)
     Cwrite.masked_fill_(triu_mask, 0)
     C = torch.cat((Cread, Cwrite), dim=-1)
     E = utils.fill_with_neg_inf(emissions[0].new(tgt_length, src_length, N))
     for t, (subC, subE) in enumerate(zip(controls, emissions)):
         select = [indices[t]]
         C[t].index_put_(select, subC.transpose(0, 1))
         E[t].index_put_(select, subE.transpose(0, 1))
     return C.transpose(1, 2), E.transpose(1, 2)
Example #15
 def forward(self, x, need_attention_weights=False):
     # Attention scores:
     B, Tt, Ts, C = x.size()
     alpha = self.w2(self.w1(x))  # B, Tt, Ts, 1
     mask = torch.triu(utils.fill_with_neg_inf(x.new(Tt, Ts)), self.waitk)
     alpha = utils.softmax(alpha + mask.unsqueeze(0).unsqueeze(-1), dim=2).type_as(alpha)
     x = x.permute(0,1,3,2)
     x = torch.matmul(x, alpha).squeeze(-1)
     if need_attention_weights:
         return x, alpha.squeeze(-1)
     return x, None
Example #16
 def _backward_beta(self, emissions, M):
     Tt, B, Ts = emissions.size()
     beta = utils.fill_with_neg_inf(
         torch.empty_like(emissions))  # Tt, B, Ts
     # initialization
     beta[-1] = 0
     for i in range(Tt - 2, -1, -1):
         beta[i] = torch.logsumexp(
             M[i].transpose(1, 2) +  # N, Ts, Ts
             beta[i + 1].unsqueeze(-1) +  # N, Ts, 1
             emissions[i + 1].unsqueeze(-1),  # N, Ts, 1
             dim=1)
     return beta
    def local_mask(self, tensor, kernel_size, causal, tgt_len=None):
        """Locality constraint mask."""
        rows = tensor.size(0)
        cols = tensor.size(0) if tgt_len is None else tgt_len

        if causal:
            if rows == 1:
                mask = utils.fill_with_neg_inf(tensor.new(1, cols))
                mask[0, -kernel_size:] = 0
                return mask
            else:
                diag_u, diag_l = 1, kernel_size
        else:

            diag_u, diag_l = ((kernel_size + 1) // 2, (kernel_size + 1) // 2) if kernel_size % 2 == 1 \
                else (kernel_size // 2, kernel_size // 2 + 1)

        print('diagonal u:')
        print(diag_u)
        print('diagonal l:')
        print(diag_l)
        mask1 = torch.triu(utils.fill_with_neg_inf(tensor.new(rows, cols)),
                           diag_u)

        plt.imshow(mask1)
        plt.show()
        mask2 = torch.tril(utils.fill_with_neg_inf(tensor.new(rows, cols)),
                           -diag_l)
        plt.imshow(mask2)
        plt.show()
        plt.imshow(mask1 + mask2)
        plt.show()
        return mask1 + mask2
Example #18
 def generate_mask(self, segment):
     segment = torch.cat(
         [segment.new(segment.size(0), 1).fill_(0), segment], dim=-1)
     doc_mask = segment.eq(0)
     bsz, dim = segment.size()
     mask = utils.fill_with_neg_inf(segment.new(dim, dim))
     enc_mask, dec_mask = [], []
     for batch in range(bsz):
         enc = torch.triu(mask.clone(), 1)
         enc[doc_mask[batch].expand_as(enc).byte()] = 0
         dec = torch.triu(mask.clone(), 0)
         dec[doc_mask[batch].expand_as(dec).byte()] = 0
         enc_mask.append(enc)
         dec_mask.append(dec)
     return torch.stack(enc_mask, 0), torch.stack(dec_mask, 0)
Example #19
 def get_transitions(self, controls):
     """
     Inputs:
         controls:  log(rho) & log(1-rho)  read/write probabilities: (Tt, N, Ts, 2)
     Returns the log-transition matrix (Tt, N, Ts, Ts)
         k->j :  p(z_t+1 = j | z_t = k) = (1-rho_tj) prod_{l=k..j-1} rho_tl
     """
     Tt, N, Ts, _ = controls.size()
     # force rho_tTx = 0
     controls[:, :, -1, 0] = - float('inf')
     controls[:, :, -1, 1] = 0
     M = utils.fill_with_neg_inf(controls.new_empty((Tt, N, Ts, Ts)))
     for k in range(Ts):
         for j in range(k, Ts):
             M[:, :, k, j] = controls[:, :, j, 1] + torch.sum(controls[:, :, k:j, 0], dim=-1)
     return M
    def buffered_future_mask(self, tensor):
        #mask for 5-gram
        dim = tensor.size(0)
        #if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
        #self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        self._future_mask = utils.fill_with_neg_inf(tensor.new(dim, dim))
        #self._future
        for i in range(dim):
            self._future_mask[i][i + 1] = 0
            self._future_mask[i + 1][i] = 0
            if (i > dim - 3):
                break
            self._future_mask[i][i + 2] = 0
            self._future_mask[i + 2][i] = 0

        return self._future_mask[:dim, :dim]
Example #21
 def mask(self, tensor):
     _, half_dim, dim = tensor.size()
     if self.onnx_trace:
         # triu and tril are not supported in onnx
         a = torch._dim_arange(tensor, 2).unsqueeze(0).repeat(half_dim, 1)
         b = torch._dim_arange(tensor, 1).unsqueeze(1).repeat(1, dim)
         mask = (a > b + half_dim).float() + (a < b).float()
         mask = torch.where(mask > 0,
                            torch.Tensor([0]).type_as(tensor),
                            torch.Tensor([float("-Inf")]).type_as(tensor))
     else:
         ones = tensor.new_ones(half_dim, dim).bool()
         mask = ones.triu(half_dim + 1) + ones.tril(-1)
         mask = utils.fill_with_neg_inf(tensor.new(
             mask.size())).masked_fill_(mask, 0)
     return mask
Example #22
    def get_attention_mask(self, x, src_len, waitk=None):
        if waitk is None:
            if self.multi_waitk:
                assert self.min_waitk <= self.max_waitk
                waitk = random.randint(min(self.min_waitk, src_len),
                                       min(src_len, self.max_waitk))
            else:
                waitk = self.waitk

        if waitk < src_len:
            encoder_attn_mask = torch.triu(
                utils.fill_with_neg_inf(x.new(x.size(0), src_len)), waitk)
            if waitk <= 0:
                encoder_attn_mask[:, 0] = 0

        else:
            encoder_attn_mask = None
        return encoder_attn_mask
Example #23
    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # encoder layers
        future_mask = torch.triu(
            utils.fill_with_neg_inf(x.new(x.size(0), x.size(0))), 1)
        for layer in self.layers:
            # Make the encoder unidirectional
            x = layer(x, encoder_padding_mask, self_attn_mask=future_mask)

        if self.normalize:
            x = self.layer_norm(x)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }
Example #24
    def forward(self, src_tokens, src_lengths=None, mask=None, **kwargs):
        """
        Args: src_tokens (batch, src_len)
              src_lengths (batch) 
        Returns:
            dict: - **encoder_out** (src_len, batch, embed_dim)
                  - **encoder_padding_mask**  (batch, src_len)
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # encoder layers
        if mask is None:
            mask = torch.triu(
                utils.fill_with_neg_inf(x.new(x.size(0), x.size(0))), 1)
        for layer in self.layers:
            # Make the encoder unidirectional
            x = layer(
                x,
                encoder_padding_mask,
                self_attn_mask=mask,
            )

        if self.normalize:
            x = self.layer_norm(x)

        return {
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }
Example #25
 def forward(self, x, need_attention_weights=False):
     x = F.glu(self.linear(x), dim=-1)  # B, Tt, Ts, C
     if not need_attention_weights:
         # Maxpool
         B, Tt, Ts, C = x.size()
         mask = torch.triu(utils.fill_with_neg_inf(x.new(Tt, Ts)),
                           self.waitk)
         x, _ = (x + mask.unsqueeze(0).unsqueeze(-1)).max(dim=2)  # B, Tt, C
         return x, None
     # Output attention weights:
     if need_attention_weights:
         # x in B, Tt, Ts, C
         B, Tt, Ts, C = x.size()
         x, indices = x.max(dim=2)
         # indices in B, Tt, C with each channel selecting a source position
         # Terrible but will do:
         attn = x.new_zeros(B, Tt, Ts)
         for i in range(Ts):
             attn[:, :, i] = indices.eq(i).sum(dim=-1)
         # Normalize
         attn = attn / attn.sum(dim=-1, keepdim=True)
     return x, attn
Example #26
 def get_transitions(self, controls):
     """
     Inputs:
         controls:  log(rho) & log(1-rho)  read/write probabilities: (Tt, B, Ts, 2)
     Returns the log-transition matrix (Tt, B, Ts, Ts)
         k->j :  p(z_t+1 = j | z_t = k) = (1-rho_tj) prod_{l=k..j-1} rho_tl
     """
     Tt, N, Ts, _ = controls.size()
     # force rho_tTx = 0
     controls[:, :, -1, 0] = -float('inf')
     controls[:, :, -1, 1] = 0
     M = utils.fill_with_neg_inf(controls.new_empty((Tt, N, Ts, Ts)))
     for k in range(Ts):
         for j in range(k, Ts):
             M[:, :, k,
               j] = controls[:, :, j, 1] + torch.sum(controls[:, :, k:j, 0],
                                                     dim=-1)
     # print('Controls p(read)', torch.exp(controls[:,:,:,0]).round().data)
     # print('M(t=0)', torch.exp(M[0,0]).data)
     # print('M(t=ly)', torch.exp(M[-1,0]).data)
     # print('Sum transitions:', M.exp().sum(dim=-1))
     return M
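The commented-out sanity check above ('Sum transitions') can be reproduced in isolation. A sketch with made-up read/write probabilities (rho = p(read), as described in the docstring) showing that each row of exp(M) is a proper distribution over next states:

import torch

Tt, N, Ts = 2, 1, 4
rho = torch.rand(Tt, N, Ts).clamp(1e-3, 1 - 1e-3)       # p(read) at each source position
controls = torch.stack((rho.log(), (1 - rho).log()), dim=-1)
controls[:, :, -1, 0] = float('-inf')                   # force a write at the last position
controls[:, :, -1, 1] = 0.0

M = torch.full((Tt, N, Ts, Ts), float('-inf'))
for k in range(Ts):
    for j in range(k, Ts):
        M[:, :, k, j] = controls[:, :, j, 1] + controls[:, :, k:j, 0].sum(dim=-1)

print(M.exp().sum(dim=-1))  # every row sums to 1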
Example #27
    def forward_one(self,
                    prev_output_tokens,
                    encoder_out=None,
                    context_size=1,
                    incremental_state=None,
                    **kwargs):

        # embed positions
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        # encoder attn mask following the reading/writing schedule len_tgt x len_src
        encoder_states = encoder_out['encoder_out']  # len_src, B, C
        encoder_mask = encoder_out['encoder_padding_mask']
        if incremental_state is None:
            encoder_attn_mask = utils.fill_with_neg_inf(
                x.new(x.size(0), encoder_states.size(0)))
            upto = min(context_size + 1, encoder_states.size(0))
            encoder_attn_mask[:, :upto] = 0
        else:
            encoder_attn_mask = torch.triu(
                utils.fill_with_neg_inf(
                    x.new(x.size(0), encoder_states.size(0))), context_size)

        # decoder layers
        for e, layer in enumerate(self.layers):
            x, attn = layer(
                x,
                encoder_states,
                encoder_mask,
                encoder_attn_mask=encoder_attn_mask,
                incremental_state=incremental_state,
                self_attn_mask=self.buffered_future_mask(x)
                if incremental_state is None else None,
            )

        if self.layer_norm:
            x = self.layer_norm(x)

        # Project only the last token
        x = x[-1:]
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                x = F.linear(x, self.embed_tokens.weight)
            else:
                x = F.linear(x, self.embed_out)

        return x, {'attn': attn}
Example #28
    def forward(self, x, emissions, indices, src_length, src_mask):
        """
        For N sequences in the batch of max_trg_length Tt and src_length Ts
        Inputs: 
            x : decoder states [(N, #ctx, C)  x Tt]
            emissions: Emissions [(N, #ctx) x Tt]  \log p(y_t|z_t=j, ...) 

        """
        controls = [self.logsigmoid_pair(sub) for sub in x]  # [N, #ctx, 1] xTt
        controls, emissions = self.fill_controls_emissions_grid(controls, emissions, indices, src_length) #Tt, N, Ts
        Tt, N, Ts = emissions.size()
        with torch.no_grad():
            # get transition matrix:
            M = self.get_transitions(controls.clone())  # Tt, N, Ts, Ts
            # Forward
            alpha = utils.fill_with_neg_inf(torch.empty_like(emissions))
            if self.bias_emission:  # penalize large contexts:
                # print('Unbiased:', emissions[:, 0])
                emissions = emissions - self.bias_emission * torch.arange(Ts).view(1, 1, -1).type_as(emissions).to(emissions)
                # print('Biased :', emissions[:, 0])
            # initialization  t=1
            initial = utils.fill_with_neg_inf(torch.empty_like(alpha[0])) 
            initial[:, 0] = 0
            alpha[0] = emissions[0] + initial
            # induction
            for i in range(1, Tt):
                alpha[i] = torch.logsumexp(alpha[i-1].unsqueeze(-1) + M[i-1], dim=1)
                alpha[i] = alpha[i] + emissions[i]

            # Backward
            beta = torch.empty_like(alpha).fill_(-float('inf'))
            # initialization
            beta[-1] = 0
            for i in range(Tt-2, -1, -1):
                beta[i] = torch.logsumexp(M[i].transpose(1, 2) +  # N, Ts, Ts
                                          beta[i+1].unsqueeze(-1) +  # N, Ts, 1
                                          emissions[i+1].unsqueeze(-1),  # N, Ts, 1
                                          dim=1)

            # Sanity check:
            prior = torch.logsumexp(alpha[-1:], dim=-1, keepdim=True)
            # prior_1 = torch.sum(torch.exp(alpha[1]) * torch.exp(beta[1]), dim=-1)
            # prior_2 = torch.sum(torch.exp(alpha[2]) * torch.exp(beta[2]), dim=-1)
            # print('Prior with n=1:', prior_1, 'Prior with n=2', prior_2, 'Prior with n=-1:', torch.exp(prior.squeeze(-1)))
            # print('Alpha:', alpha[:, 0].exp())
            # print('Beta:', beta[:, 0].exp())

            gamma = alpha + beta - prior
            gamma = torch.exp(gamma)  # Tt, N, Ts
            ksi = alpha[:-1].unsqueeze(-1) + beta[1:].unsqueeze(-2) + emissions[1:].unsqueeze(-2) + M[:-1] - prior.unsqueeze(-1)
            ksi = torch.exp(ksi)
            # print('Sum Ksi:', ksi.sum(dim=-1).sum(dim=-1))
            # print('Sum gamma:', gamma.sum(dim=-1))

            # if self.discretize: # binarize r/w labels
                # write = gamma[1:]
                # write = write.ge(self.discretize)
                # read = 1 - write

            if self.before_after:  # cumulative (before/after) marginals
                gamma = torch.cumsum(gamma, dim=-1)

            write = gamma[1:]
            read = torch.ones_like(write)
            for t in range(1, Tt):
                for j in range(Ts):
                    read[t-1, :, j] = ksi[t-1, :, :j+1, j+1:].sum(dim=-1).sum(dim=-1)
            print('Write summed:', write.sum(dim=-1))
            print('Read summed:', read.sum(dim=-1))

                # if self.normalize_rw:
                    # denom = read + write
                    # mask = denom.eq(0)
                    # read = read / denom
                    # write = write / denom
                    # read[mask] = 0
                    # write[mask] = 0

            # elif self.before_after: #
                # before = torch.cumsum(gamma, dim=-1)  # p(z_t<=j)
                # write = before[1:]
                # read = 1 - before[1:]
            # else: 
                # write = gamma[1:]
                # repartition = torch.cumsum(gamma, dim=-1)[:-1]  # q(z_t <= j) = R_tj + W_tj
                # if self.normalize_rw:
                    # write = write / (repartition + 1e-6)
                    # read = 1 - write
                # else:
                    # read = repartition - write

        return emissions, gamma, controls[:-1], read, write
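The forward/backward recursions and the posterior gamma used above can be exercised on their own. A self-contained sketch with random log-probabilities (same shape conventions: Tt, B, Ts) that verifies, in the spirit of the commented-out sanity checks, that the posterior over source positions sums to one at every target step:

import torch

Tt, B, Ts = 4, 2, 3
emissions = torch.randn(Tt, B, Ts).log_softmax(dim=-1)
M = torch.randn(Tt, B, Ts, Ts).log_softmax(dim=-1)

# Forward
alpha = torch.full_like(emissions, float('-inf'))
alpha[0, :, 0] = emissions[0, :, 0]
for i in range(1, Tt):
    alpha[i] = emissions[i] + torch.logsumexp(alpha[i - 1].unsqueeze(-1) + M[i - 1], dim=1)

# Backward
beta = torch.full_like(emissions, float('-inf'))
beta[-1] = 0
for i in range(Tt - 2, -1, -1):
    beta[i] = torch.logsumexp(M[i].transpose(1, 2)
                              + beta[i + 1].unsqueeze(-1)
                              + emissions[i + 1].unsqueeze(-1), dim=1)

# Posterior p(z_t = j | y): sums to 1 over j for every t and every sequence.
prior = torch.logsumexp(alpha[-1:], dim=-1, keepdim=True)
gamma = torch.exp(alpha + beta - prior)
print(gamma.sum(dim=-1))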
Example #29
# (x below comes from earlier context that is not part of this snippet)
#plt.imshow(x[1, :, :, 0])
#plt.show()
#print(x[2, :, :, 3])
#plt.imshow(x[2, :, :, 3])
#plt.show()

eyed = torch.eye(30, 61)
#plt.imshow(eyed)
#plt.show()

rows = 30
columns = 61
mlen = 30

tensor = torch.randn(rows, columns)
all_inf = utils.fill_with_neg_inf(tensor.new(rows, columns))
dec_attn_mask = (torch.triu(all_inf, 1 + mlen)
                 + torch.tril(all_inf, -3)).byte()[:, :]  # -1

#plt.imshow(dec_attn_mask)
#plt.show()
tensor2 = torch.randn(columns, rows + columns)
rows = 61
columns = 30
all_inf = utils.fill_with_neg_inf(tensor.new(rows, columns))
dec_attn_mask1 = torch.triu(all_inf, diagonal=0)

#plt.imshow(dec_attn_mask1)
#plt.show()
Example #30
    def forward(self,
                prev_output_tokens,
                encoder_out=None,
                incremental_state=None,
                self_attn=False):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
        Returns:
            tuple:
                - the last decoder layer's output of shape `(batch, tgt_len,
                  vocab)`
                - the last decoder layer's attention weights of shape `(batch,
                  tgt_len, src_len)`
        """
        incremental_state = None

        decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
        if self_attn:
            dim = prev_output_tokens.size(1)
            self_attn_mask = torch.triu(
                utils.fill_with_neg_inf(prev_output_tokens.new(dim, dim)), 1)
            self_attn_mask = self_attn_mask.to(prev_output_tokens)[:dim, :dim]
        else:
            self_attn_mask = None

        # embed positions
        positions = self.embed_positions(
            prev_output_tokens, ) if self.embed_positions is not None else None

        # embed tokens and positions
        x = self.embed_tokens(prev_output_tokens)
        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None
        inner_states = [x]

        # decoder layers
        for layer in self.layers:
            x, attn = layer(x,
                            encoder_out['encoder_out']
                            if encoder_out is not None else None,
                            encoder_out['encoder_padding_mask']
                            if encoder_out is not None else None,
                            decoder_padding_mask,
                            self_attn_mask=self_attn_mask)
            inner_states.append(x)

        if self.normalize:
            x = self.layer_norm(x)
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        if self.adaptive_softmax is None and self.load_softmax:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                x = F.linear(x, self.embed_tokens.weight)
            else:
                x = F.linear(x, self.embed_out)

        return x, {
            'attn': attn,
            'inner_states': inner_states,
            'predicted_lengths': encoder_out['predicted_lengths']
        }
 def buffered_future_mask_base(self, tensor):
     dim = tensor.size(1)
     self._future_mask = torch.triu(
         utils.fill_with_neg_inf(tensor.new(dim, dim)), 1).float()
     return self._future_mask[:dim, :dim]