Example #1
    def __init__(self, hidden_size, attn_span,
                 dropout, adapt_span_params, **kargs):
        nn.Module.__init__(self)
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = hidden_size  # size of a single head
        self.attn_span = attn_span
        self.adapt_span_enabled = adapt_span_params['adapt_span_enabled']
        if self.adapt_span_enabled:
            self.adaptive_span = AdaptiveSpan(attn_span=attn_span,
                                              **adapt_span_params, **kargs)
Example #2
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# AdaptiveSpan is defined elsewhere in this codebase; the import path below is an
# assumption (adaptive_span.py in the reference implementation). The _skew and
# _unskew helpers used in forward() are module-level functions not shown here.
from adaptive_span import AdaptiveSpan


class SeqAttention(nn.Module):
    """Sequential self-attention layer.
    Each token will attend to its previous fixed number of steps.
    Note that attention doesn't include the current step itself.
    """
    def __init__(self, hidden_size, attn_span, dropout, adapt_span_params,
                 **kargs):
        nn.Module.__init__(self)
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = hidden_size  # size of a single head
        self.attn_span = attn_span
        self.adapt_span_enabled = adapt_span_params['adapt_span_enabled']
        if self.adapt_span_enabled:
            self.adaptive_span = AdaptiveSpan(attn_span=attn_span,
                                              **adapt_span_params,
                                              **kargs)

    def forward(self, query, key, value, key_pe):
        # query size = B x M x H
        # key, value sizes = B x (M+L) x H

        if self.adapt_span_enabled:
            # [optional] trim out memory to reduce unnecessary computation
            key, value, key_pe = self.adaptive_span.trim_memory(
                query, key, value, key_pe)

        # compute attention from context
        # B x M (dest) x (M+L) (src)
        attn_cont = torch.matmul(query, key.transpose(-1, -2))
        attn_cont = _unskew(attn_cont)  # B x M x L

        # compute the effect of position embedding
        attn_pos = torch.matmul(query, key_pe)  # B x M x L_pos
        attn = attn_cont + attn_pos

        attn = attn / math.sqrt(self.hidden_size)  # B x M x L_pos
        attn = F.softmax(attn, dim=-1)

        if self.adapt_span_enabled:
            # trim attention lengths according to the learned span
            attn = self.adaptive_span(attn)
        attn = self.dropout(attn)  # B x M x L_pos

        attn_cont = _skew(attn, 0)  # B x M x (L+M)
        out = torch.matmul(attn_cont, value)  # B x M x H

        return out

    def get_cache_size(self):
        if self.adapt_span_enabled:
            return self.adaptive_span.get_cache_size()
        else:
            return self.attn_span
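
A minimal usage sketch for the class above, assuming SeqAttention and the module-level _skew/_unskew helpers it calls are importable from the surrounding codebase. All sizes below are hypothetical and chosen only for illustration; with adapt_span_enabled set to False, no AdaptiveSpan is constructed and the span stays fixed at attn_span.

import torch

B, M = 2, 16   # batch size and number of query positions
H = 32         # hidden_size: size of a single head
L = 64         # attn_span: how many previous steps each token can attend to

layer = SeqAttention(
    hidden_size=H,
    attn_span=L,
    dropout=0.1,
    adapt_span_params={'adapt_span_enabled': False},  # fixed span, no AdaptiveSpan
)

query = torch.randn(B, M, H)        # B x M x H
key = torch.randn(B, M + L, H)      # B x (M+L) x H
value = torch.randn(B, M + L, H)    # B x (M+L) x H
key_pe = torch.randn(1, H, L)       # positional keys, broadcast over the batch

out = layer(query, key, value, key_pe)
print(out.shape)                 # torch.Size([2, 16, 32]), i.e. B x M x H
print(layer.get_cache_size())    # 64: equals attn_span when the adaptive span is off

When adapt_span_enabled is True, the same dict is forwarded to AdaptiveSpan via **adapt_span_params, so it would also have to carry that class's own constructor arguments.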
Example #3
    def __init__(self, hidden_size, nb_heads, attn_span, dropout,
                 adapt_span_params, pers_mem_params, **kargs):
        nn.Module.__init__(self)
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = hidden_size  # size of a single head
        self.attn_span = attn_span
        self.adapt_span_enabled = adapt_span_params['adapt_span_enabled']
        if self.adapt_span_enabled:
            self.adaptive_span = AdaptiveSpan(attn_span=attn_span,
                                              nb_heads=nb_heads,
                                              **adapt_span_params,
                                              **kargs)

        self.persistent_memory = None
        if pers_mem_params['pers_mem_size'] > 0:
            self.persistent_memory = PersistentMemory(
                pers_mem_params['pers_mem_size'], nb_heads, hidden_size,
                dropout)
            if self.adapt_span_enabled:
                self.persistent_memory.adaptive_span = self.adaptive_span
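
For orientation, a rough sketch of the two configuration dictionaries this constructor reads. Only the keys visible in the snippet are shown; since both dicts gate the construction of AdaptiveSpan and PersistentMemory, real configurations will typically carry additional entries for those classes.

# Hypothetical values; only the keys read directly in the snippet are shown.
adapt_span_params = {'adapt_span_enabled': False}  # True builds an AdaptiveSpan, and the
                                                   # dict must then also carry its arguments
pers_mem_params = {'pers_mem_size': 0}             # > 0 attaches a PersistentMemory that
                                                   # shares the adaptive span, if enabled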