def mask(self, inputs):
    '''
    Get a self-attention mask
    The mask will be of shape [T x T] containing elements from the set {0, -inf}
    Input shape: (B x T x E)
    Output shape: (1 x T x T), broadcastable over the batch dimension
    '''
    if not self.causal:
        return None

    dim = inputs.shape[1]
    device = inputs.device
    # Masks are built lazily and cached per device on the class
    mask_store = TransformerDecoderLayer._masks.__dict__
    if device not in mask_store:
        mask = inputs.new_full((dim, dim), float('-inf'))
        mask_store[device] = triu(mask, 1, self.span, self.span)

    mask = mask_store[device]
    if mask.shape[0] < dim:
        # The cached mask is too small for this sequence; rebuild it at the larger size
        mask = mask.resize_(dim, dim).fill_(float('-inf'))
        mask_store[device] = triu(mask, 1, self.span, self.span)
        mask = mask_store[device]

    return mask[None, :dim, :dim]
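# The method above relies on a project-specific `triu` utility that takes span
# arguments in addition to the diagonal; its implementation is not shown here.
# The sketch below is only an assumption about what it might do: a block-wise
# generalisation of torch.triu that, with span_i == span_j == 1, reduces to
# torch.triu(tensor, diagonal). The name `block_triu` is hypothetical.
import torch

def block_triu(tensor, diagonal=1, span_i=1, span_j=1):
    ''' Hypothetical sketch: keep entries whose column block index is at least
    the row block index plus `diagonal`; zero out everything else. '''
    rows = torch.arange(tensor.shape[0], device=tensor.device) // span_i
    cols = torch.arange(tensor.shape[1], device=tensor.device) // span_j
    keep = cols[None, :] >= rows[:, None] + diagonal
    return tensor.masked_fill(~keep, 0)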
def mask(self, inputs):
    '''
    Get a causal self-attention mask of shape (1 x T x T) containing elements
    from the set {0, -inf}, cached per device
    '''
    dim = inputs.shape[1]
    device = inputs.device
    mask_store = NPLMLayer._masks.__dict__
    if device not in mask_store:
        mask = inputs.new_full((dim, dim), float('-inf'))
        # With unit spans this is the standard strictly upper-triangular causal mask
        mask_store[device] = triu(mask, 1, 1, 1)

    mask = mask_store[device]
    return mask[None, :dim, :dim]
def masks(self, inputs, local_window_size):
    '''
    Build, cache, and return the sum of the causal mask and the local
    context mask, with shape (1 x T x T)
    '''
    dim = inputs.shape[1]
    device = inputs.device
    masks_store = TransformerLayer._all_masks.__dict__
    if device not in masks_store:
        causal_mask = self.mask(inputs)[0]
        # Restrict attention to a local window: positions too far in the past get -inf
        context_mask = inputs.new_full((dim, dim), float('-inf'))
        context_mask = triu(context_mask, local_window_size, 1, 1).t()
        masks_store[device] = causal_mask + context_mask

    mask = masks_store[device]
    return mask[None, :dim, :dim]
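# Not part of the code above: a minimal, self-contained sketch of how an
# additive {0, -inf} mask like the ones returned by these methods is typically
# consumed in scaled dot-product attention. The function name and arguments are
# hypothetical; the leading singleton axis of the mask broadcasts over the batch.
import math
import torch

def masked_attention(queries, keys, values, mask):
    # scores: (B x T x T); adding the (1 x T x T) mask leaves allowed positions
    # unchanged and drives disallowed ones to -inf
    scores = queries @ keys.transpose(-2, -1) / math.sqrt(queries.shape[-1])
    scores = scores + mask
    # -inf entries receive zero probability after the softmax
    weights = torch.softmax(scores, dim=-1)
    return weights @ values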