def test_topk_equals_length_attention_masked(self):
    d_model = 32
    n_heads = 4
    improved_transformer = TransformerEncoder([
        TransformerEncoderLayer(
            AttentionLayer(
                ImprovedClusteredAttention(clusters=10, topk=20),
                d_model, n_heads),
            d_model, n_heads)
        for i in range(6)
    ])
    full_transformer = TransformerEncoder([
        TransformerEncoderLayer(
            AttentionLayer(FullAttention(), d_model, n_heads),
            d_model, n_heads)
        for i in range(6)
    ])
    full_transformer = full_transformer.to("cuda")
    improved_transformer = improved_transformer.to("cuda")

    # Share weights so the two models are directly comparable
    improved_transformer.load_state_dict(full_transformer.state_dict())
    improved_transformer.eval()
    full_transformer.eval()

    x = torch.rand(100, 20, d_model).to("cuda")
    lengths = x.new_full((100,), 20, dtype=torch.int64)
    lengths[1] = 5
    lengths[10] = 10
    length_mask = LengthMask(lengths=lengths, max_len=20)

    y_full = full_transformer(x, length_mask=length_mask)
    y_improved = improved_transformer(x, length_mask=length_mask)

    # Only compare the unmasked prefixes of the shortened sequences
    self.assertLess(
        torch.max(torch.abs(y_improved[1, :5] - y_full[1, :5])),
        1e-4
    )
    self.assertLess(
        torch.max(torch.abs(y_improved[10, :10] - y_full[10, :10])),
        1e-4
    )
def test_compare_with_batch(self):
    N = 10
    L = 42
    S = 100
    D = 1024
    E = D // 4
    x = torch.rand(N, L, D)
    m = torch.rand(N, S, D)
    tests = [
        ("full", FullAttention, FullAttention,
         RecurrentFullAttention, RecurrentCrossFullAttention),
        ("linear", partial(CausalLinearAttention, E),
         partial(LinearAttention, E),
         partial(RecurrentLinearAttention, E),
         partial(RecurrentCrossLinearAttention, E))
    ]
    for name, a1, a2, a3, a4 in tests:
        dec = TransformerDecoder([
            TransformerDecoderLayer(
                AttentionLayer(a1(), D, 4),
                AttentionLayer(a2(), D, 4),
                D
            )
            for i in range(4)
        ])
        rdec = RecurrentTransformerDecoder([
            RecurrentTransformerDecoderLayer(
                RecurrentAttentionLayer(a3(), D, 4),
                RecurrentCrossAttentionLayer(a4(), D, 4),
                D
            )
            for i in range(4)
        ])
        dec.eval()
        rdec.eval()
        rdec.load_state_dict(dec.state_dict())

        x_mask = TriangularCausalMask(L)
        x_length = LengthMask(torch.full((N,), L, dtype=torch.int64))
        m_mask = FullMask(L, S)
        m_length = LengthMask(torch.full((N,), S, dtype=torch.int64))

        # Batch pass through the non-recurrent decoder
        y1 = dec(x, m, x_mask=x_mask, x_length_mask=x_length,
                 memory_mask=m_mask, memory_length_mask=m_length)

        # Step-by-step pass through the recurrent decoder
        state = None
        y2 = []
        for i in range(L):
            y2i, state = rdec(x[:, i], m, memory_length_mask=m_length,
                              state=state)
            y2.append(y2i)
        y2 = torch.stack(y2, dim=1)

        self.assertLess(torch.abs(y1 - y2).max(), 1e-5)
def test_full_attention_forward(self):
    d_model = 128
    n_heads = 4
    transformer = TransformerDecoder([
        TransformerDecoderLayer(
            AttentionLayer(FullAttention(), d_model, n_heads),  # self-attention
            AttentionLayer(FullAttention(), d_model, n_heads),  # cross-attention
            d_model
        )
        for i in range(6)
    ])
    x = torch.rand(10, 7, d_model)
    mem = torch.rand(10, 12, d_model)
    y = transformer(x, mem)
    self.assertEqual(y.shape, (10, 7, d_model))
def test_feature_map_sharing(self):
    x = torch.rand(3, 100, 4 * 32)
    f = Favor.factory(n_dims=64)
    att = AttentionLayer(LinearAttention(32, f), 4 * 32, 4)
    attn_mask = FullMask(100)
    lengths = FullMask(3, 100)
    y = att(x, x, x, attn_mask, lengths, lengths)
    y = att(y, y, y, attn_mask, lengths, lengths)
    y.sum().backward()
def test_clustered_attention_forward(self):
    d_model = 128
    n_heads = 4
    transformer = TransformerEncoder([
        TransformerEncoderLayer(
            AttentionLayer(ClusteredAttention(clusters=10), d_model, n_heads),
            d_model, n_heads)
        for i in range(6)
    ])
    x = transformer(torch.rand(100, 20, d_model))
    self.assertEqual(x.shape, (100, 20, d_model))
def test_full_attention_forward(self):
    d_model = 128
    n_heads = 4
    transformer = TransformerEncoder([
        TransformerEncoderLayer(
            AttentionLayer(FullAttention(), d_model, n_heads),
            d_model, n_heads)
        for i in range(6)
    ])
    x = transformer(torch.rand(10, 7, d_model))
    self.assertEqual(x.shape, (10, 7, d_model))
def test_improved_clustered_attention_forward(self):
    d_model = 128
    n_heads = 4
    transformer = TransformerEncoder([
        TransformerEncoderLayer(
            AttentionLayer(ImprovedClusteredAttention(clusters=10, topk=5),
                           d_model, n_heads),
            d_model, n_heads)
        for i in range(6)
    ])
    x = torch.rand(100, 20, d_model)
    y = transformer(x)
    self.assertEqual(y.shape, (100, 20, d_model))
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
             activation="relu", normalize_before=False):
    super().__init__()
    attn_fn = linear_attention(d_model)
    d_key = d_query = d_model // nhead
    self.self_attn = AttentionLayer(attn_fn, d_model, nhead, d_key, d_query)
    self.multihead_attn = AttentionLayer(attn_fn, d_model, nhead, d_key, d_query)

    # Implementation of the feed-forward model
    self.linear1 = nn.Linear(d_model, dim_feedforward)
    self.dropout = nn.Dropout(dropout)
    self.linear2 = nn.Linear(dim_feedforward, d_model)

    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)

    self.activation = _get_activation_fn(activation)
    self.normalize_before = normalize_before
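# Usage sketch (assumptions: `linear_attention` is a project-local factory that
# returns a fast_transformers-compatible attention module, and the enclosing
# class, called `DecoderLayer` below purely as a placeholder, is constructed
# with the arguments defined above):
#
#     layer = DecoderLayer(d_model=256, nhead=8, dim_feedforward=2048,
#                          normalize_before=True)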
def test_clustered_attention_forward(self):
    d_model = 128
    n_heads = 4
    transformer = TransformerEncoder([
        TransformerEncoderLayer(
            AttentionLayer(ClusteredAttention(clusters=10), d_model, n_heads),
            d_model, n_heads)
        for i in range(6)
    ])
    transformer = transformer.to("cuda")
    x = torch.rand(100, 20, d_model).to("cuda")
    y = transformer(x)
    self.assertEqual(y.shape, (100, 20, d_model))
def test_topk_equals_length_attention(self):
    d_model = 32
    n_heads = 4
    improved_transformer = TransformerEncoder([
        TransformerEncoderLayer(
            AttentionLayer(
                ImprovedClusteredAttention(clusters=10, topk=20),
                d_model, n_heads),
            d_model, n_heads)
        for i in range(6)
    ])
    full_transformer = TransformerEncoder([
        TransformerEncoderLayer(
            AttentionLayer(FullAttention(), d_model, n_heads),
            d_model, n_heads)
        for i in range(6)
    ])
    full_transformer = full_transformer.to("cuda")
    improved_transformer = improved_transformer.to("cuda")

    # Share weights so the two models are directly comparable
    improved_transformer.load_state_dict(full_transformer.state_dict())
    improved_transformer.eval()
    full_transformer.eval()

    x = torch.rand(100, 20, d_model).to("cuda")
    y_full = full_transformer(x)
    y_improved = improved_transformer(x)
    self.assertLess(torch.max(torch.abs(y_improved - y_full)), 1e-4)
def test_reformer_attention_forward(self):
    d_model = 128
    n_heads = 4
    transformer = TransformerEncoder([
        TransformerEncoderLayer(
            AttentionLayer(
                ReformerAttention(
                    chunk_size=32,
                    rounds=4,
                    bits=8,
                    masked=False,
                ),
                d_model, n_heads),
            d_model, n_heads)
        for i in range(6)
    ])
    x = torch.rand(12, 128, d_model)
    y = transformer(x)
    self.assertEqual(y.shape, (12, 128, d_model))
def __init__(
    self,
    inner_attention,
    embedding_dim: int = 768,
    ffn_embedding_dim: int = 3072,
    num_attention_heads: int = 8,
    dropout: float = 0.1,
    attention_dropout: float = 0.1,
    activation_dropout: float = 0.1,
    activation_fn: str = "relu",
    layer_norm_first: bool = False,
) -> None:
    super().__init__()

    # Initialize parameters
    self.embedding_dim = embedding_dim
    self.dropout = dropout
    self.activation_dropout = activation_dropout

    # Initialize blocks
    self.activation_fn = utils.get_activation_fn(activation_fn)
    self.self_attn = AttentionLayer(
        inner_attention,
        self.embedding_dim,
        num_attention_heads,
    )

    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(self.activation_dropout)
    self.dropout3 = nn.Dropout(dropout)

    self.layer_norm_first = layer_norm_first

    # Layer norm associated with the self-attention layer
    self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
    self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
    self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

    # Layer norm associated with the position-wise feed-forward network
    self.final_layer_norm = LayerNorm(self.embedding_dim)
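# Usage sketch (assumption: the enclosing class is a fairseq-style sentence
# encoder layer, called `TransformerSentenceEncoderLayer` here only as a
# placeholder; its forward signature is not shown in this file). Any
# fast_transformers inner attention can be plugged in, e.g.:
#
#     from fast_transformers.attention import LinearAttention
#     layer = TransformerSentenceEncoderLayer(
#         inner_attention=LinearAttention(768 // 8),
#         embedding_dim=768,
#         ffn_embedding_dim=3072,
#         num_attention_heads=8,
#     )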
def __init__(self, n_layer, n_head, d_model, d_ff, dropout=0.1,
             activation='relu', favor_feature_dims=None):
    super(FastTransformerDecoder, self).__init__()
    self.n_layer = n_layer
    self.n_head = n_head
    self.d_model = d_model
    self.d_ff = d_ff
    self.dropout = dropout
    self.activation = activation
    self.favor_feature_dims = 2 * d_model // n_head \
        if favor_feature_dims is None else favor_feature_dims

    att_builder = AttentionBuilder.from_kwargs(
        query_dimensions=d_model // n_head,
        feature_map=Favor.factory(n_dims=self.favor_feature_dims)
    )

    # Kept as a plain list; the layers are also held by the
    # TransformerEncoderLayer modules created below.
    self.attention_layers = [
        AttentionLayer(att_builder.get("causal-linear"), d_model, n_head,
                       positional_encoder=None)
        for l in range(n_layer)
    ]

    self.decoder_layers = nn.ModuleList()
    for l in range(n_layer):
        self.decoder_layers.append(
            TransformerEncoderLayer(
                attention=self.attention_layers[l],
                d_model=d_model,
                d_ff=d_ff,
                dropout=dropout,
                activation=activation
            )
        )
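# Usage sketch (only the constructor arguments defined above are assumed; the
# forward interface lives elsewhere in this class and is not shown here):
#
#     decoder = FastTransformerDecoder(
#         n_layer=12, n_head=8, d_model=512, d_ff=2048, dropout=0.1
#     )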
def __init__(self, n_layer, n_head, d_model, d_ff,
             dropout=0.1, activation='relu', favor_feature_dims=None,
             spe_module=None, share_pe=False, share_spe_filter=False,
             use_gated_filter=True, spe_module_params=None):
    super(SPEFastTransformerDecoder, self).__init__()
    self.n_layer = n_layer
    self.n_head = n_head
    self.d_model = d_model
    self.d_ff = d_ff
    self.dropout = dropout
    self.activation = activation
    self.share_pe = share_pe
    self.use_gated_filter = use_gated_filter
    self.share_spe_filter = share_spe_filter
    self.spe_module = spe_module

    self._spe = None
    self._spe_filters = None

    if share_pe:
        # A single positional encoder shared by all layers
        self.spe = self.spe_module(
            num_heads=n_head,
            **(spe_module_params or {})
        )
        self._spe = n_layer * [self.spe]
    else:
        # One positional encoder per layer
        self.spe = nn.ModuleList([
            self.spe_module(
                num_heads=n_head,
                in_features=d_model // n_head,
                **(spe_module_params or {})
            )
            for _ in range(n_layer)
        ])
        self._spe = list(self.spe)

    if share_spe_filter:
        # A single SPE filter shared by all layers
        self.spe_filters = SPEFilter(
            code_shape=self._spe[0].code_shape,
            gated=use_gated_filter
        )
        self._spe_filters = n_layer * [self.spe_filters]
    else:
        # One SPE filter per layer
        self.spe_filters = nn.ModuleList([
            SPEFilter(code_shape=pe.code_shape, gated=use_gated_filter)
            for pe in self._spe
        ])
        self._spe_filters = list(self.spe_filters)

    self.favor_feature_dims = 2 * d_model // n_head \
        if favor_feature_dims is None else favor_feature_dims

    att_builder = AttentionBuilder.from_kwargs(
        query_dimensions=d_model // n_head,
        feature_map=Favor.factory(n_dims=self.favor_feature_dims)
    )

    # Kept as a plain list; the layers are also held by the
    # TransformerEncoderLayer modules created below.
    self.attention_layers = [
        AttentionLayer(
            att_builder.get("causal-linear"),
            d_model, n_head,
            positional_encoder=self._spe_filters[l].__call__
        )
        for l in range(n_layer)
    ]

    self.decoder_layers = nn.ModuleList()
    for l in range(n_layer):
        self.decoder_layers.append(
            TransformerEncoderLayer(
                attention=self.attention_layers[l],
                d_model=d_model,
                d_ff=d_ff,
                dropout=dropout,
                activation=activation
            )
        )
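# Usage sketch (assumption: `SineSPE` from the `spe` package is used as the
# stochastic positional encoder; any module with a compatible constructor and
# a `code_shape` attribute would work, and the keyword arguments below are
# illustrative, not prescribed by this class):
#
#     from spe import SineSPE
#     decoder = SPEFastTransformerDecoder(
#         n_layer=12, n_head=8, d_model=512, d_ff=2048,
#         spe_module=SineSPE,
#         spe_module_params=dict(num_realizations=64, num_sines=5),
#     )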
def __init__(self, n_layer, n_head, d_model, d_ff, dropout=0.1,
             activation='relu', share_pe=False, share_spe_filter=False):
    super(FastTransformerDecoder, self).__init__()
    self.n_layer = n_layer
    self.n_head = n_head
    self.d_model = d_model
    self.d_ff = d_ff
    self.dropout = dropout
    self.activation = activation
    self.share_pe = share_pe
    self.share_spe_filter = share_spe_filter

    self._spe = None
    self._spe_filters = None
    if 'positional_encoder' in self._cfg:
        make_pe = self._cfg['positional_encoder'].bind(num_heads=n_head)
        if share_pe:
            # Register as a module (only once!)
            self.spe = make_pe()
            self._spe = n_layer * [self.spe]
        else:
            # Make an SPE encoder for each layer and register them all
            self.spe = nn.ModuleList([make_pe() for _ in range(n_layer)])
            self._spe = list(self.spe)

        make_filter = self._cfg['spe_filter'].bind(spe.SPEFilter)
        if share_spe_filter:
            self.spe_filters = make_filter(
                code_shape=self._spe[0].code_shape)
            self._spe_filters = n_layer * [self.spe_filters]
        else:
            # Make a filter for each layer, register them
            self.spe_filters = nn.ModuleList([
                make_filter(code_shape=pe.code_shape)
                for pe in self._spe
            ])
            self._spe_filters = list(self.spe_filters)

    self.attention_layers = [
        AttentionLayer(
            self._cfg['attention'].configure(
                CausalLinearAttention,
                query_dimensions=d_model // n_head,
                feature_map=self._cfg['feature_map'].configure(
                    Favor.factory,
                    n_dims=d_model // n_head)),
            d_model,
            n_head,
            # Do not register as submodules of the layer
            positional_encoder=(self._spe_filters[l].__call__
                                if self._spe_filters else None))
        for l in range(n_layer)
    ]

    self.decoder_layers = nn.ModuleList()
    for l in range(n_layer):
        self.decoder_layers.append(
            TransformerEncoderLayer(
                attention=self.attention_layers[l],
                d_model=d_model,
                d_ff=d_ff,
                dropout=dropout,
                activation=activation))