Example #1
    def mask(self, inputs):
        '''
        Get a self-attention mask

        The mask will be of shape [T x T] containing elements from the set {0, -inf}

        Input shape:  (B x T x E)
        Output shape: (1 x T x T)
        '''
        if not self.causal:
            return None

        dim = inputs.shape[1]
        device = inputs.device
        # Cache one mask per device on the class-level _masks store.
        mask_store = TransformerDecoderLayer._masks.__dict__
        if device not in mask_store:
            mask = inputs.new_full((dim, dim), float('-inf'))
            mask_store[device] = triu(mask, 1, self.span, self.span)

        mask = mask_store[device]
        # Grow the cached mask if the current sequence is longer than the cached one.
        if mask.shape[0] < dim:
            mask = mask.resize_(dim, dim).fill_(float('-inf'))
            mask_store[device] = triu(mask, 1, self.span, self.span)
            mask = mask_store[device]

        return mask[None, :dim, :dim]
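
For context, a minimal sketch of how an additive {0, -inf} mask like the one returned above is typically consumed: it is added to the raw attention scores before the softmax. The function below is illustrative only and not part of the original code.

    import torch
    import torch.nn.functional as F

    def masked_attention(query, key, value, mask):
        # query/key/value: (B x T x E); mask: (1 x T x T) additive mask of {0, -inf}
        scores = query @ key.transpose(-2, -1) / key.shape[-1] ** 0.5
        if mask is not None:
            scores = scores + mask  # broadcasts over the batch; -inf entries vanish after softmax
        return F.softmax(scores, dim=-1) @ value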
Example #2
    def mask(self, inputs):
        '''
        Get a self-attention mask with elements from the set {0, -inf},
        cached per device.

        Input shape:  (B x T x E)
        Output shape: (1 x T x T)
        '''
        dim = inputs.shape[1]
        device = inputs.device
        # Cache one mask per device; the cached mask is sized to the first call's T.
        mask_store = NPLMLayer._masks.__dict__
        if device not in mask_store:
            mask = inputs.new_full((dim, dim), float('-inf'))
            mask_store[device] = triu(mask, 1, 1, 1)

        mask = mask_store[device]
        return mask[None, :dim, :dim]
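
With unit span arguments, triu(mask, 1, 1, 1) presumably reduces to an ordinary strict upper-triangular causal mask; the custom triu helper itself is not shown here, so this reading is an assumption. An equivalent construction with plain torch:

    import torch

    T = 5
    # -inf strictly above the diagonal, 0 on and below it:
    causal = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
    # Row i of `causal` lets position i attend only to positions 0..i.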
Example #3
    def masks(self, inputs, local_window_size):
        '''
        Get the sum of the causal mask and the local-context mask, cached per device.

        Input shape:  (B x T x E)
        Output shape: (1 x T x T)
        '''
        dim = inputs.shape[1]
        device = inputs.device

        # Cache the combined mask per device; it is built on the first call
        # from the causal mask plus a transposed window mask.
        masks_store = TransformerLayer._all_masks.__dict__
        if device not in masks_store:
            causal_mask = self.mask(inputs)[0]
            context_mask = inputs.new_full((dim, dim), float('-inf'))
            context_mask = triu(context_mask, local_window_size, 1, 1).t()
            masks_store[device] = causal_mask + context_mask

        mask = masks_store[device]
        return mask[None, :dim, :dim]
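
Under the same unit-span reading of triu as above (an assumption), the sum is a sliding-window mask: each position attends only to itself and the previous local_window_size - 1 positions. A small sketch of the equivalent mask with plain torch:

    import torch

    T, w = 6, 3  # illustrative sequence length and local window size
    neg_inf = float('-inf')
    causal = torch.triu(torch.full((T, T), neg_inf), diagonal=1)       # block future positions
    context = torch.triu(torch.full((T, T), neg_inf), diagonal=w).t()  # block positions more than w-1 steps back
    window_mask = causal + context
    # Row i is 0 only for columns max(0, i-w+1) .. i, so each query sees at most w tokens.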