Example #1
    def forward(self, x, meta):
        # Positional encoding for every position, repeated across the batch.
        ar = torch.arange(self.seq_len).type_as(x)
        relative_pos = self.positional_encoder(ar).unsqueeze(0).repeat(
            [x.shape[0], 1, 1])

        # Causal mask plus a per-sample length mask inferred from padding
        # (positions whose first feature is non-positive count as padding).
        att_mask = TriangularCausalMask(self.seq_len, self.get_device())
        seq_lengths = (x[..., 0] > 0).sum(dim=1).long()
        seq_mask = LengthMask(seq_lengths, max_len=self.seq_len)

        # Project the input and scale by sqrt(d_model).
        x = self.model_projection(x) * math.sqrt(self.project_dimension)

        x = x + relative_pos
        regress_embeddings = self.transformer(x, att_mask, seq_mask)
        pred = self.output_projection(regress_embeddings)

        return pred
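
As a point of reference, here is a minimal, self-contained sketch of the two masks this forward pass constructs. The classes come from fast_transformers.masking; the sizes are made up for illustration:

import torch
from fast_transformers.masking import LengthMask, TriangularCausalMask

batch, seq_len = 2, 5                       # illustrative sizes only

# Causal mask: position i may only attend to positions <= i.
att_mask = TriangularCausalMask(seq_len)

# Length mask: how many leading positions are valid in each sample.
lengths = torch.tensor([5, 3])
seq_mask = LengthMask(lengths, max_len=seq_len)

print(seq_mask.bool_matrix)                 # (2, 5); True marks valid tokens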
Example #2
    def forward(self, x, lengths=None, attn_kwargs=None):
        # Causal mask over the full sequence length.
        attn_mask = TriangularCausalMask(x.size(1), device=x.device)

        # Optional per-sample length mask for padded batches.
        if lengths is not None:
            length_mask = LengthMask(lengths, device=x.device)
        else:
            length_mask = None

        # Copy so the caller's kwargs dict is never mutated.
        attn_kwargs = dict(attn_kwargs) if attn_kwargs else {}

        out = x
        for i in range(self.n_layer):
            out = self.decoder_layers[i](out,
                                         attn_mask=attn_mask,
                                         length_mask=length_mask,
                                         attn_kwargs=attn_kwargs)

        return out
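
A hypothetical invocation of a decoder stack like this one, assuming batch-first inputs of shape (batch, seq, d_model); the module instance itself is elided:

import torch
from fast_transformers.masking import LengthMask

x = torch.randn(3, 64, 512)                 # (batch, seq, d_model)
lengths = torch.tensor([64, 40, 12])        # valid tokens per sample

# out = decoder(x, lengths=lengths)         # hypothetical module instance
mask = LengthMask(lengths, device=x.device)
print(mask.bool_matrix.shape)               # torch.Size([3, 64])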
Example #3
    def extract_features(self, x, padding_mask=None):
        # Zero out padded frames before the positional convolution.
        if padding_mask is not None:
            x[padding_mask] = 0

        # Convolutional positional encoding added as a residual.
        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # Turn the boolean padding mask (True = padded) into a LengthMask
        # over the valid frames of each sample.
        layer_results = []
        length_mask = None
        if padding_mask is not None and torch.any(padding_mask):
            input_lengths = (~padding_mask).sum(-1)
            length_mask = LengthMask(lengths=input_lengths.long())

        n_frozen = self.n_freeze
        attention_mask = None
        for i, layer in enumerate(self.layers):
            if i >= self.n_active:
                break
            # LayerDrop: during training, randomly skip non-frozen layers.
            dropout_probability = np.random.random()
            if n_frozen > 0 and i <= n_frozen:
                x = layer(x, length_mask=length_mask)
                layer_results.append(x)
            elif not self.training or dropout_probability > self.layerdrop:
                if self.gradient_checkpoint:
                    x = checkpoint(layer, x, attention_mask, length_mask)
                else:
                    x = layer(x, length_mask=length_mask)
                layer_results.append(x)
        return x
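
The padding-mask-to-LengthMask conversion above can be sketched in isolation; the True-means-padding convention is assumed from wav2vec-style models:

import torch
from fast_transformers.masking import LengthMask

# True marks padded frames, as in the example above.
padding_mask = torch.tensor([[False, False, True],
                             [False, False, False]])
input_lengths = (~padding_mask).sum(-1)     # tensor([2, 3])
length_mask = LengthMask(lengths=input_lengths.long())
print(length_mask.bool_matrix)              # True at valid positions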
Example #4
    def __init__(self,
                 num_feats,
                 num_output_points,
                 lstm_layers,
                 n_layers,
                 n_heads,
                 hidden_dim,
                 ff_dim,
                 tf_depth=3,
                 dropout=0.15):
        super(SetTransformer, self).__init__()

        self.num_feats = num_feats
        self.k = num_output_points

        def dup(x):
            # Promote a single int to a (pre-pooling, post-pooling) pair.
            return (x, x) if isinstance(x, int) else x

        self.lstm_layers = lstm_layers
        self.n_layers = dup(n_layers)
        self.n_heads = dup(n_heads)
        self.hidden_dim = dup(hidden_dim)
        self.ff_dim = dup(ff_dim)
        self.tf_depth = dup(tf_depth)

        self.d_model = [self.hidden_dim[i] * self.n_heads[i] for i in [0, 1]]

        encoder_builder_pre = TransformerEncoderBuilder.from_kwargs(
            n_layers=self.n_layers[0],
            n_heads=self.n_heads[0],
            query_dimensions=self.hidden_dim[0],
            value_dimensions=self.hidden_dim[0],
            feed_forward_dimensions=self.ff_dim[0],
            attention_type='linear',
            dropout=dropout)

        encoder_builder_post = TransformerEncoderBuilder.from_kwargs(
            n_layers=self.n_layers[1],
            n_heads=self.n_heads[1],
            query_dimensions=self.hidden_dim[1],
            value_dimensions=self.hidden_dim[1],
            feed_forward_dimensions=self.ff_dim[1],
            attention_type='linear',
            dropout=dropout)

        self.seeds = nn.Parameter(torch.normal(0, 1,
                                               (self.k, self.d_model[0])))
        self.encoder_pre = encoder_builder_pre.get()
        self.encoder_post = encoder_builder_post.get()

        self.initial_ff = nn.Linear(self.num_feats, self.d_model[0])
        # self.pos_encoding = PositionalEncoding(self.d_model[0], dropout=dropout)
        self.lstm = nn.LSTM(self.d_model[0],
                            self.d_model[0],
                            2,
                            batch_first=True,
                            bidirectional=False)
        self.attn_pooling = AttentionLayer(LinearAttention(self.d_model[0]),
                                           self.d_model[0], self.n_heads[0])
        self.final_ff = nn.Linear(self.d_model[1], self.num_feats)

        # Initialize masks with placeholder sizes; their contents are
        # irrelevant since these masks are effectively all-valid anyway.
        self.mask = FullMask(N=self.k, M=5)
        self.kl_mask = LengthMask(torch.full((5,), 5, dtype=torch.long))
        self.ql_mask = LengthMask(torch.full((5,), self.k, dtype=torch.long))
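
For the placeholder masks at the end, a small sketch of what FullMask and an all-valid LengthMask look like; the sizes are arbitrary:

import torch
from fast_transformers.masking import FullMask, LengthMask

# FullMask(N, M): every one of N queries may attend to all M keys.
mask = FullMask(N=4, M=5)
print(mask.bool_matrix.shape)               # torch.Size([4, 5]), all True

# A LengthMask whose lengths equal max_len is likewise all-valid.
kl_mask = LengthMask(torch.full((2,), 5, dtype=torch.long))
print(kl_mask.bool_matrix.all())            # tensor(True)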
Example #5
    def test_correctness_masked(self):
        # N=batch, H=heads, L=queries, S=keys, E=feature dim, k=top-k keys;
        # C, I, B are clustering parameters. Most sizes are re-randomized
        # on every loop iteration below.
        N = 12
        H = 6
        L = 1000
        S = 1000
        E = 32
        k = 32
        C = 100
        I = 10
        B = 32
        for exp in range(30):
            N = np.random.randint(1, 6)
            H = np.random.randint(1, 8)
            C = np.random.randint(10, 500)
            L = np.random.randint(C, 2000)
            E = np.random.randint(10, 128)
            S = np.random.randint(100, 1000)
            k = np.random.randint(10, 64)

            if os.getenv("VERBOSE_TESTS", ""):
                print(("Testing Masked: N H L S E C k: "
                       "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))

            Q = torch.randn(N, H, L, E).to(self.device)
            K = torch.randn(N, H, S, E).to(self.device)

            # Random valid lengths in [C, L]; force the first sample to be
            # full length so at least one sequence is unmasked.
            lengths = np.random.randint(C, L + 1, N)
            lengths = torch.tensor(lengths, dtype=torch.int32).to(self.device)
            lengths[0] = L
            query_lengths = LengthMask(lengths, max_len=L)
            groups, counts = cluster_queries(Q, lengths, C, I, B)

            sorted_g, sorted_gi = torch.sort(groups.view(N * H, -1), dim=-1)
            sorted_rev_gi = torch.argsort(sorted_gi, dim=-1)

            q_offset = torch.arange(N * H, device=Q.device).unsqueeze(-1) * L
            q_flat = (sorted_gi + q_offset).reshape(-1)
            s_queries = Q.reshape(-1, E).index_select(0,
                                                      q_flat).view(N, H, L, E)
            Q_grouped = aggregate(s_queries, sorted_g.view(N, H, L),
                                  1 / counts.float())

            QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
            _, topk = torch.topk(QK, k, dim=-1)
            topk = topk.contiguous()
            topk_broadcast = broadcast(
                topk.float(), groups,
                torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, k), device=Q.device))
            weights_sorted = torch.rand(N, H, L,
                                        k).to(self.device).requires_grad_(True)
            weights_sorted.retain_grad()

            q_rev_flat = (sorted_rev_gi + q_offset).reshape(-1)
            weights = torch.clone(
                weights_sorted.reshape(-1, k).index_select(0, q_rev_flat).view(
                    N, H, L, k))
            weights.retain_grad()

            values = torch.randn(N, H, S,
                                 E).to(self.device).requires_grad_(True)
            self._zero_grad(weights, values)
            values_selected = values[
                torch.arange(N).view(N, 1, 1, 1).to(self.device),
                torch.arange(H).view(1, H, 1, 1).to(self.device),
                topk_broadcast.long()]
            output = (weights.unsqueeze(-1) * values_selected).sum(-2)
            # Zero the outputs of padded (invalid) query positions.
            output = output * query_lengths.float_matrix[:, None, :, None]
            output.sum().backward()
            grad = [torch.clone(weights.grad), torch.clone(values.grad)]

            self._zero_grad(weights_sorted, values)
            self._zero_grad(weights, values)

            output_hat_sorted = clustered_sparse_weighted_average(
                weights_sorted, values, topk, groups, counts)
            output_hat = output_hat_sorted.reshape(-1, E).index_select(
                0, q_rev_flat).view(N, H, L, E)

            self.assertLess(torch.abs(output - output_hat).max(), 1e-4)
            output_hat.sum().backward()
            weights_grad_sorted = torch.clone(weights_sorted.grad)
            weights_grad = torch.clone(
                weights_grad_sorted.reshape(-1, k).index_select(
                    0, q_rev_flat).view(N, H, L, k))
            grad_hat = [weights_grad, torch.clone(values.grad)]
            for g1, g2 in zip(grad, grad_hat):
                self.assertLess(torch.abs(g1 - g2).max(), 1e-3)
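
The masking step in the middle of this test relies on LengthMask.float_matrix, which is 1.0 at valid positions and 0.0 at padded ones; a tiny sketch:

import torch
from fast_transformers.masking import LengthMask

lm = LengthMask(torch.tensor([3, 1]), max_len=4)
out = torch.ones(2, 2, 4, 8)                # (N, H, L, E)
out = out * lm.float_matrix[:, None, :, None]
print(out[1, 0, :, 0])                      # tensor([1., 0., 0., 0.])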
Example #6
    def test_masked(self):
        # LocalAttention should produce no NaNs when the key/query lengths
        # are shorter than the sequence.
        att = LocalAttention(16, softmax_temp=1)
        q, k, v, m1, m2, m3 = self._get_inputs(N=3, L=64, S=64, D=32)
        m2 = m3 = LengthMask(torch.tensor([8, 16, 64], dtype=torch.long))
        v_hat = att(q, k, v, m1, m2, m3)
        self.assertFalse(torch.any(torch.isnan(v_hat)))
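
A self-contained variant of the same smoke test, with input shapes assumed to be (batch, length, heads, dims) as in fast-transformers' attention modules:

import torch
from fast_transformers.attention import LocalAttention
from fast_transformers.masking import FullMask, LengthMask

N, L, H, E = 3, 64, 4, 8                    # assumed shapes for illustration
q = torch.randn(N, L, H, E)
k = torch.randn(N, L, H, E)
v = torch.randn(N, L, H, E)
attn_mask = FullMask(L)                     # no extra attention constraints
lengths = LengthMask(torch.tensor([8, 16, 64]), max_len=L)

att = LocalAttention(16, softmax_temp=1)
out = att(q, k, v, attn_mask, lengths, lengths)
print(torch.any(torch.isnan(out)))          # tensor(False)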
Example #7
    def test_masked_simple_grad(self):
        # Compare sparse_product (outputs and gradients) against a dense
        # reference that gathers the top-k entries of the full QK^T.
        N = 4
        H = 2
        L = 100
        E = 64
        S = 100
        k = 32
        C = 5
        I = 5
        B = 16

        for i in range(30):
            C = np.random.randint(10, 500)
            L = np.random.randint(C, 2000)
            E = np.random.randint(10, 128)
            S = np.random.randint(100, 1000)
            k = np.random.randint(10, 64)

            if os.getenv("VERBOSE_TESTS", ""):
                print(("Testing Masked: N H L S E C k: "
                       "{} {} {} {} {} {} {}").format(N, H, L, S, E, C, k))

            Q = torch.randn(N, H, L, E).to(self.device).requires_grad_(True)
            K = torch.randn(N, H, S, E).to(self.device).requires_grad_(True)

            lengths = np.random.randint(C, L+1, N)
            lengths = torch.tensor(lengths, dtype=torch.int32).to(self.device)
            query_lengths = LengthMask(
                lengths,
                max_len=L
            )
            groups, counts = cluster_queries(Q, lengths, C, I, B)
            Q_grouped = aggregate(Q, groups, 1/counts.float())
            QK = torch.einsum("nhle,nhse->nhls", Q_grouped, K)
            _, topk = torch.topk(QK, k, dim=-1)
            topk = topk.contiguous()
            topk_broadcast = broadcast(
                topk.float(),
                groups,
                torch.ones_like(counts, dtype=torch.float32),
                torch.zeros((N, H, L, k), device=Q.device)
            )

            self._zero_grad(Q, K)
            QK_full = torch.einsum("nhle,nhse->nhls", Q, K)
            QK_selected = QK_full[
                torch.arange(N).view(N, 1, 1, 1).to(self.device),
                torch.arange(H).view(1, H, 1, 1).to(self.device),
                torch.arange(L).view(1, 1, L, 1).to(self.device),
                topk_broadcast.long()
            ]

            QK_selected = QK_selected * query_lengths.float_matrix[:, None, :, None]
            QK_selected.sum().backward()
            grad = [torch.clone(Q.grad), torch.clone(K.grad)]

            self._zero_grad(Q, K)
            QK_selected_hat = sparse_product(
                Q, K, groups, topk, counts, lengths
            )
            QK_selected_hat.sum().backward()
            grad_hat = [torch.clone(Q.grad), torch.clone(K.grad)]
            self.assertLess(
                torch.abs(QK_selected - QK_selected_hat).max(),
                1e-4
            )
            for g1, g2 in zip(grad, grad_hat):
                self.assertLess(
                    torch.abs(g1 - g2).max(),
                    1e-3
                )
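
The dense reference in both gradient tests is just a batched, per-head dot product; in isolation:

import torch

N, H, L, S, E = 2, 2, 4, 6, 8               # illustrative sizes
Q = torch.randn(N, H, L, E)
K = torch.randn(N, H, S, E)

# One attention score per (query, key) pair, per batch and head.
QK = torch.einsum("nhle,nhse->nhls", Q, K)
print(QK.shape)                             # torch.Size([2, 2, 4, 6])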