def forward(self, x, meta):
    """Project inputs, add positional encodings, and run the causal transformer.

    NOTE(review): `meta` is accepted but never used in this method — confirm
    whether callers rely on passing it.
    """
    batch_size = x.shape[0]

    # Positional embedding for every timestep, tiled across the batch.
    positions = torch.arange(self.seq_len).float().type_as(x)
    pos_embed = self.positional_encoder(positions).unsqueeze(0)
    pos_embed = pos_embed.repeat([batch_size, 1, 1])

    # Lower-triangular causal mask plus a per-sample length mask.
    # Assumes a non-zero first feature marks a valid timestep — TODO confirm.
    causal_mask = TriangularCausalMask(self.seq_len, self.get_device())
    valid_lengths = (x[..., 0] > 0).sum(dim=1).long()
    length_mask = LengthMask(valid_lengths, max_len=self.seq_len)

    # Scale the projection by sqrt(d_model), then add positions.
    hidden = self.model_projection(x) * math.sqrt(self.project_dimension)
    hidden = hidden + pos_embed

    embeddings = self.transformer(hidden, causal_mask, length_mask)
    return self.output_projection(embeddings)
def forward(self, x, lengths=None, attn_kwargs=None):
    """Run `x` through `self.n_layer` causal decoder layers.

    Args:
        x: input of shape (batch, seq, ...) — seq length taken from dim 1.
        lengths: optional per-sample valid lengths; when given, a LengthMask
            is built, otherwise no length masking is applied.
        attn_kwargs: optional extra kwargs forwarded to every layer.
    """
    causal_mask = TriangularCausalMask(x.size(1), device=x.device)
    length_mask = None if lengths is None else LengthMask(lengths, device=x.device)
    # Shallow-copy so layer-side mutation cannot leak back to the caller.
    layer_kwargs = dict(attn_kwargs) if attn_kwargs else {}

    hidden = x
    for layer_idx in range(self.n_layer):
        hidden = self.decoder_layers[layer_idx](
            hidden,
            attn_mask=causal_mask,
            length_mask=length_mask,
            attn_kwargs=layer_kwargs,
        )
    return hidden
def extract_features(self, x, padding_mask=None):
    """Run `x` through the (possibly partially frozen) transformer stack.

    Args:
        x: input feature tensor. NOTE: when `padding_mask` is given, padded
            positions are zeroed IN PLACE, mutating the caller's tensor.
        padding_mask: boolean mask — presumably True at padded positions,
            given the `(~padding_mask).sum(-1)` length computation below.

    Returns:
        Output of the last active layer.
    """
    if padding_mask is not None:
        # Zero padded timesteps before the positional convolution.
        x[padding_mask] = 0
    # Convolutional positional encoding added as a residual.
    x_conv = self.pos_conv(x.transpose(1, 2))
    x_conv = x_conv.transpose(1, 2)
    x += x_conv
    if not self.layer_norm_first:
        x = self.layer_norm(x)
    x = F.dropout(x, p=self.dropout, training=self.training)
    # NOTE(review): layer_results is accumulated but never returned or read.
    layer_results = []
    length_mask = None
    if padding_mask is not None and torch.any(padding_mask):
        # Per-sample valid lengths = count of non-padded positions.
        input_lengths = (~padding_mask).sum(-1)
        length_mask = LengthMask(lengths=input_lengths.long())
    n_frozen = self.n_freeze
    attention_mask = None
    for i, layer in enumerate(self.layers):
        # Only the first `n_active` layers run at all.
        if i >= self.n_active:
            break
        # Drawn on every iteration, so the RNG stream advances identically
        # whether or not this layer ends up being dropped.
        dropout_probability = np.random.random()
        if i <= n_frozen and n_frozen > 0:
            # Frozen layers: always executed, no LayerDrop, no checkpointing.
            x = layer(x, length_mask=length_mask)
            layer_results.append(x)
        else:
            # LayerDrop: during training, skip this layer with probability
            # `self.layerdrop`; at eval time every layer runs.
            if not self.training or (dropout_probability > self.layerdrop):
                if self.gradient_checkpoint:
                    # Activation checkpointing trades compute for memory.
                    x = checkpoint(layer, x, attention_mask, length_mask)
                else:
                    x = layer(x, length_mask=length_mask)
                layer_results.append(x)
    return x
def __init__(self, num_feats, num_output_points, lstm_layers, n_layers,
             n_heads, hidden_dim, ff_dim, tf_depth=3, dropout=0.15):
    """Set transformer: pre-encoder, LSTM, attention pooling, post-encoder.

    Args:
        num_feats: dimensionality of each input element.
        num_output_points: number of learned seed vectors (k) the attention
            pooling attends from.
        lstm_layers: stored on the instance. NOTE(review): the nn.LSTM below
            is hard-coded to 2 layers and ignores this value — confirm intent.
        n_layers, n_heads, hidden_dim, ff_dim, tf_depth: each either a single
            int applied to both the pre- and post-encoder, or a (pre, post)
            pair.
        dropout: dropout rate used by both encoders.
    """
    super().__init__()
    self.num_feats = num_feats
    self.k = num_output_points

    def dup(x):
        # Broadcast a scalar setting to the (pre, post) encoder pair.
        # isinstance (not `type(x) == int`) so int subclasses are accepted.
        return (x, x) if isinstance(x, int) else x

    self.lstm_layers = lstm_layers
    self.n_layers = dup(n_layers)
    self.n_heads = dup(n_heads)
    self.hidden_dim = dup(hidden_dim)
    self.ff_dim = dup(ff_dim)
    self.tf_depth = dup(tf_depth)

    # Model width of each encoder = heads * per-head dimension.
    self.d_model = [self.hidden_dim[i] * self.n_heads[i] for i in (0, 1)]

    encoder_builder_pre = TransformerEncoderBuilder.from_kwargs(
        n_layers=self.n_layers[0],
        n_heads=self.n_heads[0],
        query_dimensions=self.hidden_dim[0],
        value_dimensions=self.hidden_dim[0],
        feed_forward_dimensions=self.ff_dim[0],
        attention_type='linear',
        dropout=dropout)
    encoder_builder_post = TransformerEncoderBuilder.from_kwargs(
        n_layers=self.n_layers[1],
        n_heads=self.n_heads[1],
        query_dimensions=self.hidden_dim[1],
        value_dimensions=self.hidden_dim[1],
        feed_forward_dimensions=self.ff_dim[1],
        attention_type='linear',
        dropout=dropout)

    # k learned seed vectors for the attention-pooling step.
    self.seeds = nn.Parameter(torch.normal(0, 1, (self.k, self.d_model[0])))

    self.encoder_pre = encoder_builder_pre.get()
    self.encoder_post = encoder_builder_post.get()
    self.initial_ff = nn.Linear(self.num_feats, self.d_model[0])
    self.lstm = nn.LSTM(self.d_model[0], self.d_model[0], 2,
                        batch_first=True, bidirectional=False)
    self.attn_pooling = AttentionLayer(LinearAttention(self.d_model[0]),
                                       self.d_model[0], self.n_heads[0])
    self.final_ff = nn.Linear(self.d_model[1], self.num_feats)

    # Init masks to meaningless values; doesn't matter what — these are all
    # empty anyway and are rebuilt with real sizes before use.
    self.mask = FullMask(N=self.k, M=5)
    self.kl_mask = LengthMask(torch.ones(5) * 5)
    self.ql_mask = LengthMask(torch.ones(5) * self.k)
def test_correctness_masked(self):
    """Compare clustered_sparse_weighted_average (fwd + grads) against a
    dense reference implementation, over 30 random masked problem sizes."""
    N = 12
    H = 6
    L = 1000
    S = 1000
    E = 32
    k = 32
    C = 100
    I = 10
    B = 32
    for exp in range(30):
        # Random problem dimensions: batch, heads, clusters, query length,
        # feature dim, key length, top-k.
        N = np.random.randint(1, 6)
        H = np.random.randint(1, 8)
        C = np.random.randint(10, 500)
        L = np.random.randint(C, 2000)
        E = np.random.randint(10, 128)
        S = np.random.randint(100, 1000)
        k = np.random.randint(10, 64)
        if os.getenv("VERBOSE_TESTS", ""):
            print(("Testing Masked: N H L S E C k: "
                   "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))
        Q = torch.randn(N, H, L, E).to(self.device)
        K = torch.randn(N, H, S, E).to(self.device)
        # Per-sample lengths in [C, L]; force one sample to full length so
        # max_len is actually reached.
        lengths = np.random.randint(C, L + 1, N)
        lengths = torch.tensor(lengths, dtype=torch.int32).to(self.device)
        lengths[0] = L
        query_lengths = LengthMask(lengths, max_len=L)
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        # Sort queries by cluster id; keep the inverse permutation so results
        # can be scattered back to the original query order.
        sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
        sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)
        q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
        q_flat = (sorted_gi + q_offset).reshape(-1)
        s_queries = Q.reshape(-1, E).index_select(0, q_flat).view(N, H, L, E)
        # Cluster centroids: mean of each group's queries.
        Q_grouped = aggregate(s_queries, sorted_g.view(N, H, L),
                              1 / counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        # Broadcast each cluster's top-k key indices to its member queries.
        topk_broadcast = broadcast(
            topk.float(), groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device))
        weights_sorted = torch.rand(N, H, L, k).to(
            self.device).requires_grad_(True)
        weights_sorted.retain_grad()
        q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
        # Same weights permuted back to original query order, for the
        # dense reference path.
        weights = torch.clone(
            weights_sorted.reshape(-1, k).index_select(
                0, q_rev_flat).view(N, H, L, k))
        weights.retain_grad()
        values = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        self._zero_grad(weights, values)
        # Dense reference: gather the selected values and take the weighted
        # average, masking out padded queries.
        values_selected = values[
            torch.arange(N).view(N, 1, 1, 1).to(self.device),
            torch.arange(H).view(1, H, 1, 1).to(self.device),
            topk_broadcast.long()]
        output = (weights.unsqueeze(-1) * values_selected).sum(-2)
        output = output * query_lengths.float_matrix[:, None, :, None]
        output.sum().backward()
        grad = [torch.clone(weights.grad), torch.clone(values.grad)]
        self._zero_grad(weights_sorted, values)
        self._zero_grad(weights, values)
        # Fast path under test, then un-permute to original query order.
        output_hat_sorted = clustered_sparse_weighted_average(
            weights_sorted, values, topk, groups, counts)
        output_hat = output_hat_sorted.reshape(-1, E).index_select(
            0, q_rev_flat).view(N, H, L, E)
        self.assertLess(torch.abs(output - output_hat).max(), 1e-4)
        output_hat.sum().backward()
        weights_grad_sorted = torch.clone(weights_sorted.grad)
        # Permute the weight gradients back before comparing.
        weights_grad = torch.clone(
            weights_grad_sorted.reshape(-1, k).index_select(
                0, q_rev_flat).view(N, H, L, k))
        grad_hat = [weights_grad, torch.clone(values.grad)]
        for g1, g2 in zip(grad, grad_hat):
            self.assertLess(torch.abs(g1 - g2).max(), 1e-3)
def test_masked(self):
    """Local attention with uneven per-sample lengths must not produce NaNs."""
    attention = LocalAttention(16, softmax_temp=1)
    q, k, v, attn_mask, _, _ = self._get_inputs(N=3, L=64, S=64, D=32)
    # Override both length masks with ragged sequence lengths.
    lengths = LengthMask(torch.tensor([8, 16, 64], dtype=torch.long))
    out = attention(q, k, v, attn_mask, lengths, lengths)
    self.assertFalse(torch.any(torch.isnan(out)))
def test_masked_simple_grad(self):
    """Compare sparse_product (forward + Q/K gradients) against a dense
    gather-based reference, over 30 random masked problem sizes."""
    N = 4
    H = 2
    L = 100
    E = 64
    S = 100
    k = 32
    C = 5
    I = 5
    B = 16
    for i in range(30):
        # Random problem dimensions: clusters, query length, feature dim,
        # key length, top-k (N and H stay fixed).
        C = np.random.randint(10, 500)
        L = np.random.randint(C, 2000)
        E = np.random.randint(10, 128)
        S = np.random.randint(100, 1000)
        k = np.random.randint(10, 64)
        if os.getenv("VERBOSE_TESTS", ""):
            print(("Testing Masked: N H L S E C k: "
                   "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))
        Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
        K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)
        # Per-sample valid query lengths in [C, L].
        lengths = np.random.randint(C, L+1, N)
        lengths = torch.tensor(lengths, dtype=torch.int32).to(self.device)
        query_lengths = LengthMask(
            lengths,
            max_len=L
        )
        groups, counts = cluster_queries(Q, lengths, C, I, B)
        # Cluster centroids: mean of each group's queries.
        Q_grouped = aggregate(Q, groups, 1/counts.float())
        QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
        _, topk = torch.topk(QK, k, dim=-1)
        topk = topk.contiguous()
        # Broadcast each cluster's top-k key indices to its member queries.
        topk_broadcast = broadcast(
            topk.float(),
            groups,
            torch.ones_like(counts, dtype=torch.float32),
            torch.zeros((N, H, L, k), device=Q.device)
        )
        self._zero_grad(Q, K)
        # Dense reference: full QK product, then gather only the selected
        # entries and mask out padded queries.
        QK_full = torch.einsum("nhle,nhse->nhls", Q, K)
        QK_selected = QK_full[
            torch.arange(N).view(N, 1, 1, 1).to(self.device),
            torch.arange(H).view(1, H, 1, 1).to(self.device),
            torch.arange(L).view(1, 1, L, 1).to(self.device),
            topk_broadcast.long()
        ]
        QK_selected = QK_selected * query_lengths.float_matrix[:, None, :, None]
        QK_selected.sum().backward()
        grad = [torch.clone(Q.grad), torch.clone(K.grad)]
        self._zero_grad(Q, K)
        # Fast path under test.
        QK_selected_hat = sparse_product(
            Q, K, groups, topk, counts, lengths
        )
        QK_selected_hat.sum().backward()
        grad_hat = [torch.clone(Q.grad), torch.clone(K.grad)]
        self.assertLess(
            torch.abs(QK_selected - QK_selected_hat).max(),
            1e-4
        )
        for g1, g2 in zip(grad, grad_hat):
            self.assertLess(
                torch.abs(g1 - g2).max(),
                1e-3
            )