def buffered_mask(self, tensor):
    dim = tensor.size(-1)
    if self._mask is None:
        self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
    if self._mask.size(0) < dim:
        self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
    return self._mask[:dim, :dim]
def buffered_future_mask(self, tensor):
    dim = tensor.size(0)
    if (not hasattr(self, '_future_mask')
            or self._future_mask is None
            or self._future_mask.device != tensor.device):
        self._future_mask = torch.triu(
            utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
    if self._future_mask.size(0) < dim:
        self._future_mask = torch.triu(
            utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
    return self._future_mask[:dim, :dim]
def buffered_future_mask(self, tensor): """Cached future mask.""" dim = tensor.size(0) #pylint: disable=access-member-before-definition, attribute-defined-outside-init if not hasattr( self, '_future_mask' ) or self._future_mask is None or self._future_mask.device != tensor.device: self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) return self._future_mask[:dim, :dim]
def forward(self, x, need_attention_weights=False):
    B, Tt, Ts, C = x.size()
    if not need_attention_weights:
        # Max-pool over the source positions allowed by the wait-k schedule.
        mask = torch.triu(utils.fill_with_neg_inf(x.new(Tt, Ts)), self.waitk)
        x, _ = (x + mask.unsqueeze(0).unsqueeze(-1)).max(dim=2)  # B, Tt, C
        return x, None
    # Output attention weights; x is B, Tt, Ts, C.
    x, indices = x.max(dim=2)
    # indices is B, Tt, C, each channel selecting a source position.
    # Crude but sufficient: count how often each source position is selected.
    attn = x.new_zeros(B, Tt, Ts)
    for i in range(Ts):
        attn[:, :, i] = indices.eq(i).sum(dim=-1)
    # Normalize the counts into a distribution over source positions.
    attn = attn / attn.sum(dim=-1, keepdim=True)
    return x, attn
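# Sketch of the masked max-pool trick above (sizes are illustrative
# assumptions): adding a -inf wait-k mask before max-pooling over the source
# axis restricts each target step to the prefix its schedule has revealed.
import torch

def fill_with_neg_inf(t):
    return t.float().fill_(float('-inf')).type_as(t)

B, Tt, Ts, C, waitk = 1, 3, 5, 2, 2
x = torch.randn(B, Tt, Ts, C)
mask = torch.triu(fill_with_neg_inf(x.new(Tt, Ts)), waitk)
pooled, _ = (x + mask.unsqueeze(0).unsqueeze(-1)).max(dim=2)  # B, Tt, C
# pooled[0, 0] depends only on x[0, 0, :2], pooled[0, 1] on x[0, 1, :3], etc.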
def buffered_future_mask_short(self, tensor, line):
    dim = tensor.size(1)
    self._future_mask = torch.triu(
        utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
    # Fully mask out every row from `line` onwards.
    self._future_mask[line:] = float('-inf')
    return self._future_mask
def mask(self, tensor):
    dim = tensor.size(-1)
    half_dim = dim // 2
    ones = tensor.new_ones(half_dim, dim).byte()
    mask = ones.triu(half_dim + 1) + ones.tril(-1)
    mask = utils.fill_with_neg_inf(tensor.new(mask.size())).masked_fill_(mask, 0)
    return mask
def _forward_alpha(self, emissions, M):
    Tt, B, Ts = emissions.size()
    alpha = utils.fill_with_neg_inf(torch.empty_like(emissions))  # Tt, B, Ts
    # Initialization (t=1): all probability mass on the first source position.
    initial = utils.fill_with_neg_inf(torch.empty_like(alpha[0]))
    initial[:, 0] = 0
    alpha[0] = emissions[0] + initial
    # Induction.
    for i in range(1, Tt):
        alpha[i] = torch.logsumexp(alpha[i - 1].unsqueeze(-1) + M[i - 1], dim=1)
        alpha[i] = alpha[i] + emissions[i]
    return alpha
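# Self-contained sketch of the log-space forward recursion above on random
# inputs; the sizes and the normalized transitions are illustrative
# assumptions, not values from the original code.
import torch

Tt, B, Ts = 3, 2, 4
emissions = torch.randn(Tt, B, Ts).log_softmax(dim=-1)  # log p(y_t | z_t=j)
M = torch.randn(Tt, B, Ts, Ts).log_softmax(dim=-1)      # log-transitions

alpha = torch.full((Tt, B, Ts), float('-inf'))
alpha[0, :, 0] = emissions[0, :, 0]  # start at source position 0
for t in range(1, Tt):
    alpha[t] = torch.logsumexp(alpha[t - 1].unsqueeze(-1) + M[t - 1], dim=1) + emissions[t]
# The sequence log-likelihood is the logsumexp over the final states.
log_likelihood = torch.logsumexp(alpha[-1], dim=-1)  # shape (B,)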
def buffered_future_mask(self, tensor):
    dim = tensor.size(0)
    # `self._future_mask.device != tensor.device` does not work in
    # TorchScript; this is a workaround.
    if (self._future_mask.size(0) == 0
            or (not self._future_mask.device == tensor.device)
            or self._future_mask.size(0) < dim):
        self._future_mask = torch.triu(
            utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1)
    return self._future_mask[:dim, :dim]
def local_mask(self, tensor, kernel_size, causal, tgt_len=None): """Locality constraint mask.""" rows = tensor.size(0) cols = tensor.size(0) if tgt_len is None else tgt_len if causal: if rows == 1: mask = utils.fill_with_neg_inf(tensor.new(1, cols)) mask[0, -kernel_size:] = 0 return mask else: diag_u, diag_l = 1, kernel_size else: diag_u, diag_l = ((kernel_size + 1) // 2, (kernel_size + 1) // 2) if kernel_size % 2 == 1 \ else (kernel_size // 2, kernel_size // 2 + 1) mask1 = torch.triu(utils.fill_with_neg_inf(tensor.new(rows, cols)), diag_u) mask2 = torch.tril(utils.fill_with_neg_inf(tensor.new(rows, cols)), -diag_l) return mask1 + mask2
def buffered_future_mask(self, tensor): """attend all surounding words except itself [[0, -inf, 0] [0, 0, -inf] [0, 0, 0]] The attention map is not ture diagonal since we predict y_{t+1} at time-step t """ dim = tensor.size(0) if (not hasattr(self, "_future_mask") or self._future_mask is None or self._future_mask.device != tensor.device): self._future_mask = torch.triu( utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) self._future_mask = torch.tril(self._future_mask, 1) if self._future_mask.size(0) < dim: self._future_mask = torch.triu( utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) self._future_mask = torch.tril(self._future_mask, 1) return self._future_mask[:dim, :dim]
def forward(self, x, need_attention_weights=False):
    # Attention scores:
    B, Tt, Ts, C = x.size()
    alpha = self.w2(self.w1(x))  # B, Tt, Ts, 1
    # For every (t, j), allow attending to the first j source positions.
    mask = torch.triu(utils.fill_with_neg_inf(x.new(Ts, Ts)), 1).type_as(alpha)
    alpha = alpha.permute(0, 1, 3, 2) + mask.unsqueeze(0).unsqueeze(0)  # B, Tt, Ts, Ts
    alpha = utils.softmax(alpha, dim=-1)
    x = torch.matmul(alpha, x)
    return x, None
def buffered_past_mask(self, tensor):
    dim = tensor.size(0)
    if self.onnx_trace:
        a = torch._dim_arange(tensor, 0).unsqueeze(0).repeat(dim, 1)
        b = torch._dim_arange(tensor, 0).unsqueeze(1).repeat(1, dim)
        past_mask = a < b
        past_mask_neg_inf = torch.where(past_mask,
                                        torch.Tensor([float("-Inf")]),
                                        torch.Tensor([0])).type_as(tensor)
        return past_mask_neg_inf
    if (not hasattr(self, '_past_mask')
            or self._past_mask is None
            or self._past_mask.device != tensor.device):
        self._past_mask = torch.tril(
            utils.fill_with_neg_inf(tensor.new(dim, dim)), -1)
    if self._past_mask.size(0) < dim:
        self._past_mask = torch.tril(
            utils.fill_with_neg_inf(self._past_mask.resize_(dim, dim)), -1)
    return self._past_mask[:dim, :dim]
def mask(self, tensor, mask_curr):
    dim = tensor.size(-1)
    half_dim = dim // 2
    add = 1 if mask_curr else 0
    ones = tensor.new_ones(half_dim, dim).byte()
    mask = ones.triu(half_dim + add) + ones.tril(-add)
    mask = utils.fill_with_neg_inf(tensor.new(mask.size())).masked_fill_(mask, 0)
    return mask
def fill_controls_emissions_grid(self, controls, emissions, indices, src_length):
    """
    Return controls (C) and emissions (E) covering the full grid:
        C: Tt, N, Ts, 2
        E: Tt, N, Ts
    """
    N = controls[0].size(0)
    tgt_length = len(controls)
    Cread = controls[0].new_zeros((tgt_length, src_length, N, 1))
    Cwrite = utils.fill_with_neg_inf(torch.empty_like(Cread))
    triu_mask = torch.triu(controls[0].new_ones(tgt_length, src_length), 1).byte()
    triu_mask = triu_mask.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, N, 1)
    Cwrite.masked_fill_(triu_mask, 0)
    C = torch.cat((Cread, Cwrite), dim=-1)
    E = utils.fill_with_neg_inf(emissions[0].new(tgt_length, src_length, N))
    for t, (subC, subE) in enumerate(zip(controls, emissions)):
        select = [indices[t]]
        C[t].index_put_(select, subC.transpose(0, 1))
        E[t].index_put_(select, subE.transpose(0, 1))
    return C.transpose(1, 2), E.transpose(1, 2)
def forward(self, x, need_attention_weights=False):
    # Attention scores:
    B, Tt, Ts, C = x.size()
    alpha = self.w2(self.w1(x))  # B, Tt, Ts, 1
    mask = torch.triu(utils.fill_with_neg_inf(x.new(Tt, Ts)), self.waitk)
    alpha = utils.softmax(alpha + mask.unsqueeze(0).unsqueeze(-1), dim=2).type_as(alpha)
    x = x.permute(0, 1, 3, 2)
    x = torch.matmul(x, alpha).squeeze(-1)
    if need_attention_weights:
        return x, alpha.squeeze(-1)
    return x, None
def _backward_beta(self, emissions, M):
    Tt, B, Ts = emissions.size()
    beta = utils.fill_with_neg_inf(torch.empty_like(emissions))  # Tt, B, Ts
    # Initialization.
    beta[-1] = 0
    for i in range(Tt - 2, -1, -1):
        beta[i] = torch.logsumexp(
            M[i].transpose(1, 2) +           # N, Ts, Ts
            beta[i + 1].unsqueeze(-1) +      # N, Ts, 1
            emissions[i + 1].unsqueeze(-1),  # N, Ts, 1
            dim=1)
    return beta
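# Continuation of the forward sketch after `_forward_alpha` above (reuses its
# `emissions`, `M`, `alpha`, and `log_likelihood`): the backward recursion,
# plus the standard consistency check that alpha and beta agree.
beta = torch.full((Tt, B, Ts), float('-inf'))
beta[-1] = 0
for t in range(Tt - 2, -1, -1):
    beta[t] = torch.logsumexp(
        M[t].transpose(1, 2) + beta[t + 1].unsqueeze(-1) + emissions[t + 1].unsqueeze(-1),
        dim=1)

# For any t, logsumexp(alpha[t] + beta[t]) equals the sequence log-likelihood.
assert torch.allclose(log_likelihood,
                      torch.logsumexp(alpha[0] + beta[0], dim=-1), atol=1e-4)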
def local_mask(self, tensor, kernel_size, causal, tgt_len=None): """Locality constraint mask.""" #if tgt_len is None: rows = tensor.size(0) cols = tgt_len #if tgt_len is None else tgt_len #else: # rows = tensor.size(0)-tgt_len # cols = tgt_len if causal: if rows == 1: mask = utils.fill_with_neg_inf(tensor.new(1, cols)) mask[0, -kernel_size:] = 0 return mask else: diag_u, diag_l = 1, kernel_size else: diag_u, diag_l = ((kernel_size + 1) // 2, (kernel_size + 1) // 2) if kernel_size % 2 == 1 \ else (kernel_size // 2, kernel_size // 2 + 1) print('diagonal u:') print(diag_u) print('diagonal l:') print(diag_l) mask1 = torch.triu(utils.fill_with_neg_inf(tensor.new(rows, cols)), diag_u) plt.imshow(mask1) plt.show() mask2 = torch.tril(utils.fill_with_neg_inf(tensor.new(rows, cols)), -diag_l) plt.imshow(mask2) plt.show() plt.imshow(mask1 + mask2) plt.show() return mask1 + mask2
def generate_mask(self, segment):
    segment = torch.cat(
        [segment.new(segment.size(0), 1).fill_(0), segment], dim=-1)
    doc_mask = segment.eq(0)
    bsz, dim = segment.size()
    mask = utils.fill_with_neg_inf(segment.new(dim, dim))
    enc_mask, dec_mask = [], []
    for batch in range(bsz):
        enc = torch.triu(mask.clone(), 1)
        enc[doc_mask[batch].expand_as(enc).byte()] = 0
        dec = torch.triu(mask.clone(), 0)
        dec[doc_mask[batch].expand_as(dec).byte()] = 0
        enc_mask.append(enc)
        dec_mask.append(dec)
    return torch.stack(enc_mask, 0), torch.stack(dec_mask, 0)
def get_transitions(self, controls): """ Inputs: controls: log(rho) & log(1-rho) read/write probabilities: (Tt, N, Ts, 2) Returns the log-transition matrix (N, Tt, Ts, Ts) k->j : p(z_t+1 = j | z_t = k) = (1-rho_tj) prod_l rho_tl """ Tt, N, Ts, _ = controls.size() # force rho_tTx = 0 controls[:, :, -1, 0] = - float('inf') controls[:, :, -1, 1] = 0 M = utils.fill_with_neg_inf(controls.new_empty((Tt, N, Ts, Ts))) for k in range(Ts): for j in range(k, Ts): M[:, :, k, j] = controls[:, :, j, 1] + torch.sum(controls[:, :, k:j, 0], dim=-1) return M
def buffered_future_mask(self, tensor):
    # Band mask for a 5-gram window: each position may attend to its two
    # neighbours on each side (the diagonal itself stays masked).
    dim = tensor.size(0)
    self._future_mask = utils.fill_with_neg_inf(tensor.new(dim, dim))
    for i in range(dim - 1):
        self._future_mask[i][i + 1] = 0
        self._future_mask[i + 1][i] = 0
        if i < dim - 2:
            self._future_mask[i][i + 2] = 0
            self._future_mask[i + 2][i] = 0
    return self._future_mask[:dim, :dim]
def mask(self, tensor):
    _, half_dim, dim = tensor.size()
    if self.onnx_trace:
        # triu and tril are not supported in ONNX.
        a = torch._dim_arange(tensor, 2).unsqueeze(0).repeat(half_dim, 1)
        b = torch._dim_arange(tensor, 1).unsqueeze(1).repeat(1, dim)
        mask = (a > b + half_dim).float() + (a < b).float()
        mask = torch.where(mask > 0,
                           torch.Tensor([0]).type_as(tensor),
                           torch.Tensor([float("-Inf")]).type_as(tensor))
    else:
        ones = tensor.new_ones(half_dim, dim).bool()
        mask = ones.triu(half_dim + 1) + ones.tril(-1)
        mask = utils.fill_with_neg_inf(tensor.new(mask.size())).masked_fill_(mask, 0)
    return mask
def get_attention_mask(self, x, src_len, waitk=None):
    if waitk is None:
        if self.multi_waitk:
            assert self.min_waitk <= self.max_waitk
            waitk = random.randint(min(self.min_waitk, src_len),
                                   min(src_len, self.max_waitk))
        else:
            waitk = self.waitk
    if waitk < src_len:
        encoder_attn_mask = torch.triu(
            utils.fill_with_neg_inf(x.new(x.size(0), src_len)), waitk)
        if waitk <= 0:
            encoder_attn_mask[:, 0] = 0
    else:
        encoder_attn_mask = None
    return encoder_attn_mask
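# Illustrative wait-k mask (assumed sizes): with waitk=2, target step t may
# attend only to source positions j < t + 2.
import torch

def fill_with_neg_inf(t):
    return t.float().fill_(float('-inf')).type_as(t)

tgt_len, src_len, waitk = 3, 5, 2
mask = torch.triu(fill_with_neg_inf(torch.empty(tgt_len, src_len)), waitk)
# tensor([[0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf]])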
def forward(self, src_tokens, src_lengths):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`

    Returns:
        dict:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of padding
              elements of shape `(batch, src_len)`
    """
    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(src_tokens)
    if self.embed_positions is not None:
        x += self.embed_positions(src_tokens)
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    if not encoder_padding_mask.any():
        encoder_padding_mask = None

    # encoder layers
    future_mask = torch.triu(
        utils.fill_with_neg_inf(x.new(x.size(0), x.size(0))), 1)
    for layer in self.layers:
        # Make the encoder unidirectional
        x = layer(x, encoder_padding_mask, self_attn_mask=future_mask)

    if self.normalize:
        x = self.layer_norm(x)

    return {
        'encoder_out': x,  # T x B x C
        'encoder_padding_mask': encoder_padding_mask,  # B x T
    }
def forward(self, src_tokens, src_lengths=None, mask=None, **kwargs):
    """
    Args:
        src_tokens (batch, src_len)
        src_lengths (batch)

    Returns:
        dict:
            - **encoder_out** (src_len, batch, embed_dim)
            - **encoder_padding_mask** (batch, src_len)
    """
    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(src_tokens)
    if self.embed_positions is not None:
        x += self.embed_positions(src_tokens)
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    if not encoder_padding_mask.any():
        encoder_padding_mask = None

    # encoder layers
    if mask is None:
        mask = torch.triu(
            utils.fill_with_neg_inf(x.new(x.size(0), x.size(0))), 1)
    for layer in self.layers:
        # Make the encoder unidirectional
        x = layer(x, encoder_padding_mask, self_attn_mask=mask)

    if self.normalize:
        x = self.layer_norm(x)

    return {
        'encoder_out': x,  # T x B x C
        'encoder_padding_mask': encoder_padding_mask,  # B x T
    }
def forward(self, x, need_attention_weights=False):
    x = F.glu(self.linear(x), dim=-1)  # B, Tt, Ts, C
    B, Tt, Ts, C = x.size()
    if not need_attention_weights:
        # Max-pool over the source positions allowed by the wait-k schedule.
        mask = torch.triu(utils.fill_with_neg_inf(x.new(Tt, Ts)), self.waitk)
        x, _ = (x + mask.unsqueeze(0).unsqueeze(-1)).max(dim=2)  # B, Tt, C
        return x, None
    # Output attention weights; x is B, Tt, Ts, C.
    x, indices = x.max(dim=2)
    # indices is B, Tt, C, each channel selecting a source position.
    # Crude but sufficient: count how often each source position is selected.
    attn = x.new_zeros(B, Tt, Ts)
    for i in range(Ts):
        attn[:, :, i] = indices.eq(i).sum(dim=-1)
    # Normalize the counts into a distribution over source positions.
    attn = attn / attn.sum(dim=-1, keepdim=True)
    return x, attn
def get_transitions(self, controls): """ Inputs: controls: log(rho) & log(1-rho) read/write probabilities: (Tt, B, Ts, 2) Returns the log-transition matrix (Tt, B, Ts, Ts) k->j : p(z_t+1 = j | z_t = k) = (1-rho_tj) prod_l rho_tl """ Tt, N, Ts, _ = controls.size() # force rho_tTx = 0 controls[:, :, -1, 0] = -float('inf') controls[:, :, -1, 1] = 0 M = utils.fill_with_neg_inf(controls.new_empty((Tt, N, Ts, Ts))) for k in range(Ts): for j in range(k, Ts): M[:, :, k, j] = controls[:, :, j, 1] + torch.sum(controls[:, :, k:j, 0], dim=-1) # print('Controls p(read)', torch.exp(controls[:,:,:,0]).round().data) # print('M(t=0)', torch.exp(M[0,0]).data) # print('M(t=ly)', torch.exp(M[-1,0]).data) # print('Sum transitions:', M.exp().sum(dim=-1)) return M
def forward_one(self, prev_output_tokens, encoder_out=None, context_size=1,
                incremental_state=None, **kwargs):
    # embed positions
    positions = self.embed_positions(
        prev_output_tokens,
        incremental_state=incremental_state,
    ) if self.embed_positions is not None else None

    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:, -1:]
        if positions is not None:
            positions = positions[:, -1:]

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)
    if self.project_in_dim is not None:
        x = self.project_in_dim(x)
    if positions is not None:
        x += positions
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)
    attn = None

    # encoder attention mask following the reading/writing schedule,
    # len_tgt x len_src
    encoder_states = encoder_out['encoder_out']  # len_src, B, C
    encoder_mask = encoder_out['encoder_padding_mask']
    if incremental_state is None:
        encoder_attn_mask = utils.fill_with_neg_inf(
            x.new(x.size(0), encoder_states.size(0)))
        upto = min(context_size + 1, encoder_states.size(0))
        encoder_attn_mask[:, :upto] = 0
    else:
        encoder_attn_mask = torch.triu(
            utils.fill_with_neg_inf(x.new(x.size(0), encoder_states.size(0))),
            context_size)

    # decoder layers
    for e, layer in enumerate(self.layers):
        x, attn = layer(
            x,
            encoder_states,
            encoder_mask,
            encoder_attn_mask=encoder_attn_mask,
            incremental_state=incremental_state,
            self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
        )

    if self.layer_norm:
        x = self.layer_norm(x)

    # Project only the last token
    x = x[-1:]
    # T x B x C -> B x T x C
    x = x.transpose(0, 1)
    if self.project_out_dim is not None:
        x = self.project_out_dim(x)

    if self.adaptive_softmax is None:
        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)
    return x, {'attn': attn}
def forward(self, x, emissions, indices, src_length, src_mask):
    """
    For N sequences in the batch of max_trg_length Tt and src_length Ts.
    Inputs:
        x: decoder states [(N, #ctx, C) x Tt]
        emissions: emissions [(N, #ctx) x Tt], i.e. log p(y_t | z_t=j, ...)
    """
    controls = [self.logsigmoid_pair(sub) for sub in x]  # [(N, #ctx, 1)] x Tt
    controls, emissions = self.fill_controls_emissions_grid(
        controls, emissions, indices, src_length)  # Tt, N, Ts
    Tt, N, Ts = emissions.size()
    with torch.no_grad():
        # Transition matrix:
        M = self.get_transitions(controls.clone())  # Tt, N, Ts, Ts
        if self.bias_emission:
            # Penalize large contexts:
            emissions = emissions - self.bias_emission * torch.arange(Ts).view(
                1, 1, -1).type_as(emissions).to(emissions)
        # Forward pass; initialization (t=1):
        alpha = utils.fill_with_neg_inf(torch.empty_like(emissions))
        initial = utils.fill_with_neg_inf(torch.empty_like(alpha[0]))
        initial[:, 0] = 0
        alpha[0] = emissions[0] + initial
        # Induction:
        for i in range(1, Tt):
            alpha[i] = torch.logsumexp(alpha[i - 1].unsqueeze(-1) + M[i - 1], dim=1)
            alpha[i] = alpha[i] + emissions[i]
        # Backward pass:
        beta = torch.empty_like(alpha).fill_(-float('inf'))
        beta[-1] = 0
        for i in range(Tt - 2, -1, -1):
            beta[i] = torch.logsumexp(
                M[i].transpose(1, 2) +           # N, Ts, Ts
                beta[i + 1].unsqueeze(-1) +      # N, Ts, 1
                emissions[i + 1].unsqueeze(-1),  # N, Ts, 1
                dim=1)
        # Posteriors:
        prior = torch.logsumexp(alpha[-1:], dim=-1, keepdim=True)
        gamma = torch.exp(alpha + beta - prior)  # Tt, N, Ts
        ksi = alpha[:-1].unsqueeze(-1) + beta[1:].unsqueeze(-2) \
            + emissions[1:].unsqueeze(-2) + M[:-1] - prior.unsqueeze(-1)
        ksi = torch.exp(ksi)
        if self.before_after:  # binarize r/w labels
            gamma = torch.cumsum(gamma, dim=-1)
            write = gamma[1:]
            read = torch.ones_like(write)
            for t in range(1, Tt):
                for j in range(Ts):
                    read[t - 1, :, j] = ksi[t - 1, :, :j + 1, j + 1:].sum(dim=-1).sum(dim=-1)
        else:
            write = gamma[1:]
            repartition = torch.cumsum(gamma, dim=-1)[:-1]  # q(z_t <= j) = R_tj + W_tj
            if self.normalize_rw:
                write = write / (repartition + 1e-6)
                read = 1 - write
            else:
                read = repartition - write
    return emissions, gamma, controls[:-1], read, write
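# Sketch of the `bias_emission` penalty used above (sizes and the bias value
# are illustrative assumptions): subtracting bias * arange(Ts) lowers the
# score of emitting from larger source contexts, biasing the read/write
# schedule toward earlier writes.
import torch

Tt, N, Ts = 2, 1, 5
emissions = torch.zeros(Tt, N, Ts)
bias = 0.5
biased = emissions - bias * torch.arange(Ts).view(1, 1, -1).type_as(emissions)
# biased[..., j] == -0.5 * j: context j is penalized linearly in its size.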
# Scratch example of building banded decoder attention masks.
rows, columns, mlen = 30, 61, 30
tensor = torch.randn(rows, columns)
all_inf = utils.fill_with_neg_inf(tensor.new(rows, columns))
# Band: mask everything above diagonal 1 + mlen and below diagonal -3.
dec_attn_mask = (torch.triu(all_inf, 1 + mlen) + torch.tril(all_inf, -3)).byte()

rows, columns = 61, 30
all_inf = utils.fill_with_neg_inf(tensor.new(rows, columns))
dec_attn_mask1 = torch.triu(all_inf, diagonal=0)
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, self_attn=False):
    """
    Args:
        prev_output_tokens (LongTensor): previous decoder outputs of shape
            `(batch, tgt_len)`, for input feeding/teacher forcing
        encoder_out (Tensor, optional): output from the encoder, used for
            encoder-side attention
        incremental_state (dict): dictionary used for storing state during
            :ref:`Incremental decoding`

    Returns:
        tuple:
            - the last decoder layer's output of shape
              `(batch, tgt_len, vocab)`
            - the last decoder layer's attention weights of shape
              `(batch, tgt_len, src_len)`
    """
    incremental_state = None
    decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
    if self_attn:
        dim = prev_output_tokens.size(1)
        self_attn_mask = torch.triu(
            utils.fill_with_neg_inf(prev_output_tokens.new(dim, dim)), 1)
        self_attn_mask = self_attn_mask.to(prev_output_tokens)[:dim, :dim]
    else:
        self_attn_mask = None

    # embed positions
    positions = self.embed_positions(
        prev_output_tokens,
    ) if self.embed_positions is not None else None

    # embed tokens and positions
    x = self.embed_tokens(prev_output_tokens)
    if self.project_in_dim is not None:
        x = self.project_in_dim(x)
    if positions is not None:
        x += positions
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)
    attn = None
    inner_states = [x]

    # decoder layers
    for layer in self.layers:
        x, attn = layer(
            x,
            encoder_out['encoder_out'] if encoder_out is not None else None,
            encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
            decoder_padding_mask,
            self_attn_mask=self_attn_mask,
        )
        inner_states.append(x)

    if self.normalize:
        x = self.layer_norm(x)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)
    if self.project_out_dim is not None:
        x = self.project_out_dim(x)

    if self.adaptive_softmax is None and self.load_softmax:
        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)

    return x, {
        'attn': attn,
        'inner_states': inner_states,
        'predicted_lengths': encoder_out['predicted_lengths'],
    }
def buffered_future_mask_base(self, tensor):
    dim = tensor.size(1)
    self._future_mask = torch.triu(
        utils.fill_with_neg_inf(tensor.new(dim, dim)), 1).float()
    return self._future_mask[:dim, :dim]