import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# AdaptiveSpan, PersistentMemory and the _skew/_unskew helpers used below are
# defined elsewhere in this codebase.


class SeqAttention(nn.Module):
    """Sequential self-attention layer.

    Each token will attend to its previous fixed number of steps.
    Note that attention doesn't include the current step itself.
    """
    def __init__(self, hidden_size, attn_span,
                 dropout, adapt_span_params, **kargs):
        nn.Module.__init__(self)
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = hidden_size  # size of a single head
        self.attn_span = attn_span
        self.adapt_span_enabled = adapt_span_params['adapt_span_enabled']
        if self.adapt_span_enabled:
            self.adaptive_span = AdaptiveSpan(attn_span=attn_span,
                                              **adapt_span_params, **kargs)

    def forward(self, query, key, value, key_pe):
        # query size = B x M x H
        # key, value sizes = B x (M+L) x H

        if self.adapt_span_enabled:
            # [optional] trim out memory to reduce unnecessary computation
            key, value, key_pe = self.adaptive_span.trim_memory(
                query, key, value, key_pe)

        # compute attention from context
        # B x M (dest) x (M+L) (src)
        attn_cont = torch.matmul(query, key.transpose(-1, -2))
        # realign scores so each query row keeps only its L previous positions
        attn_cont = _unskew(attn_cont)  # B x M x L

        # compute the effect of position embedding
        attn_pos = torch.matmul(query, key_pe)  # B x M x L_pos
        attn = attn_cont + attn_pos

        attn = attn / math.sqrt(self.hidden_size)  # B x M x L_pos
        attn = F.softmax(attn, dim=-1)

        if self.adapt_span_enabled:
            # trim attention lengths according to the learned span
            attn = self.adaptive_span(attn)
        attn = self.dropout(attn)  # B x M x L_pos

        # shift back to absolute positions before mixing the values
        attn_cont = _skew(attn, 0)  # B x M x (L+M)
        out = torch.matmul(attn_cont, value)  # B x M x H
        return out

    def get_cache_size(self):
        if self.adapt_span_enabled:
            return self.adaptive_span.get_cache_size()
        else:
            return self.attn_span
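

def _example_seq_attention_usage():
    """Minimal usage sketch (illustrative only).

    The sizes B, M, L, H, the dropout value and this function's name are
    assumptions, and the adaptive span is disabled so only the fixed-span path
    runs. It assumes the _skew/_unskew helpers referenced by forward() are in
    scope, as in the rest of this codebase.
    """
    B, M, L, H = 2, 64, 64, 32
    attn_layer = SeqAttention(
        hidden_size=H,
        attn_span=L,
        dropout=0.1,
        adapt_span_params={'adapt_span_enabled': False},
    )
    query = torch.randn(B, M, H)      # current block of hidden states
    key = torch.randn(B, M + L, H)    # cached memory followed by current block
    value = torch.randn(B, M + L, H)
    key_pe = torch.randn(1, H, L)     # relative position embeddings
    out = attn_layer(query, key, value, key_pe)
    assert out.shape == (B, M, H)
    return out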
    # __init__ of a persistent-memory variant of the attention layer
    # (the enclosing nn.Module subclass is not shown in this excerpt).
    def __init__(self, hidden_size, nb_heads, attn_span,
                 dropout, adapt_span_params, pers_mem_params, **kargs):
        nn.Module.__init__(self)
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = hidden_size  # size of a single head
        self.attn_span = attn_span
        self.adapt_span_enabled = adapt_span_params['adapt_span_enabled']
        if self.adapt_span_enabled:
            self.adaptive_span = AdaptiveSpan(attn_span=attn_span, nb_heads=nb_heads,
                                              **adapt_span_params, **kargs)
        self.persistent_memory = None
        if pers_mem_params['pers_mem_size'] > 0:
            self.persistent_memory = PersistentMemory(
                pers_mem_params['pers_mem_size'], nb_heads, hidden_size, dropout)
            if self.adapt_span_enabled:
                # give the persistent memory access to the same learned span
                self.persistent_memory.adaptive_span = self.adaptive_span
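

def _example_variant_params():
    """Illustrative parameter dictionaries for the constructor above.

    Only the keys the constructor actually reads ('adapt_span_enabled' and
    'pers_mem_size') are listed; the values and this helper's name are
    assumptions. Extra keys placed in adapt_span_params are forwarded to
    AdaptiveSpan through **adapt_span_params when the adaptive span is enabled.
    """
    adapt_span_params = {'adapt_span_enabled': False}  # skip AdaptiveSpan entirely
    pers_mem_params = {'pers_mem_size': 0}             # 0 slots -> no PersistentMemory module
    return adapt_span_params, pers_mem_params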