Example 1
    def test_dims_with_size(self):
        x = dims(3)
        assert len(x) == 3 and isinstance(x[0], Dim)

        class Foo:
            pass

        y = Foo()
        z, y.x, q = dims(3)
        assert str(z) == "z"
        assert str(y.x) == "d1"
        assert str(q) == "d2"
Example 2
    def test_time_mm_fuse(self):
        i, j, k = dims()
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)

        for _ in range(10):
            r0 = A @ B

        for _ in range(10):
            a = A[i, k]
            b = B[k, j]
            r1 = (a * b).sum(k)

        with measure('pp'):
            for _ in range(10000):
                A @ B
        # magic_trace_stop_indicator()

        with measure('fc'):
            for _ in range(10000):
                (A[i, k] * B[k, j]).sum(k).order(i, j)

        with magic_trace('f.fxt'):
            for _ in range(10000):
                (A[i, k] * B[k, j]).sum(k).order(i, j)

        with magic_trace('p.fxt'):
            for _ in range(10000):
                A @ B

        # magic_trace_stop_indicator()

        assert torch.allclose(r1.order(i, j), r0)
Example 3
 def test_with_dims_split(self):
     a = torch.arange(3 * 12).view(3, 12)
     i, j, k = dims()
     k.size = 4
     r = a[i, [j, k]]
     x = r.order(i, [j, k])
     self.assertTrue(torch.allclose(a, x))
Example 4
    def test_mm_fuse(self):
        i, j, k = dims()
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)

        C = (A[i, k] * B[k, j]).sum(k).order(i, j)
        assert torch.allclose(C, A @ B)
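For comparison, a minimal plain-PyTorch sketch (not part of the test suite) of the same contraction written with torch.einsum; the fused (A[i, k] * B[k, j]).sum(k) pattern above expresses the same reduction over k:

import torch

A = torch.rand(3, 4)
B = torch.rand(4, 5)
C_ref = torch.einsum("ik,kj->ij", A, B)  # sum over the shared k dimension
assert torch.allclose(C_ref, A @ B)
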
Example 5
    def test_manual_stuff(self):

        A_ = torch.rand(3, 4)
        B_ = torch.rand(4, 5)
        i, j, k = dims()
        A = A_[i, k]
        B = B_[k, j]
        C = (A.expand(j) * B.expand(i)).sum(k)
        self.assertTrue(torch.allclose(C.order(i, j), torch.mm(A_, B_)))
        self.assertTrue(torch.allclose(torch.triu(A_, 0), triu(A_)))

        D_ = torch.randint(0, 3, (6, ))
        d = dims()
        D = D_[d]

        A.index([i], [D]).order(k, d)
Example 6
 def test_mm(self):
     i, j, k, q = dims()
     a = torch.rand(3, 4)
     b = torch.rand(4, 5)
     a_ = a[i, k]
     b_ = b[k, j]
     q.size = 1
     r = (a_.expand(j, q) * b_.expand(i, q)).sum(k).order(q, i, j)
Example 7
 def test_dim_args(self):
     a = dimlists()
     assert isinstance(a, DimList)
     a = dims()
     b = dimlists()
     assert isinstance(a, Dim)
     assert isinstance(b, DimList)
     assert str(a) == 'a'
     a, b = dims(sizes=[3, 4])
     assert a.size == 3
     assert b.size == 4
     a = dims(sizes=[3])
     b = dimlists(sizes=[4])
     assert len(b) == 4
     a = dims()
     b = dimlists(sizes=[[4, 5]])
     assert b[0].size == 4
     assert b[1].size == 5
Example 8
    def test_index_placement(self):
        A = torch.rand(1, 2, 3, 4)

        i, j = dims(sizes=[2, 4])

        a = A[:, i + 0, :, j + 0]
        r = a.order(i, j)

        assert torch.allclose(A.permute(1, 3, 0, 2), r)
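A plain advanced-indexing sketch of the same placement, added here for illustration: because the two index tensors are separated by a slice, NumPy-style indexing moves their broadcasted dimensions to the front, which is exactly the permute(1, 3, 0, 2) layout checked above.

import torch

A = torch.rand(1, 2, 3, 4)
idx_i = torch.arange(2)[:, None]  # broadcasts against idx_j to shape (2, 4)
idx_j = torch.arange(4)[None, :]
ref = A[:, idx_i, :, idx_j]       # result shape (2, 4, 1, 3)
assert torch.allclose(ref, A.permute(1, 3, 0, 2))
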
Example 9
    def test_functorch(self):
        A = torch.rand(3, 4, 5)
        B = torch.rand(3, 4, 5)
        C = torch.rand(5, 2)

        i, j = dims()

        AA = torch.mm(A[i], C)  # 3, 4, 2
        BB = torch.mm(B[j], C)  # 3, 4, 2
        assert list(torch.mm(AA.T, BB).order(i, j).shape) == [3, 3, 2, 2]
Example 10
    def test_inplace(self):
        # some embeddings table
        embeddings = torch.zeros(10, 3)

        # some sparse updates to the embeddings
        indices = torch.arange(2) + 1
        values = torch.rand(2, 3)

        i, n, f = dims()

        embeddings[indices[i], f] += values[i, f]
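A hedged plain-PyTorch sketch of the same sparse update (not part of the test): because the indices here are unique, index_add_ performs an equivalent accumulation into the table.

import torch

embeddings = torch.zeros(10, 3)
indices = torch.arange(2) + 1
values = torch.rand(2, 3)
embeddings.index_add_(0, indices, values)  # rows 1 and 2 receive the updates
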
Example 11
    def test_index(self):
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)
        i, j, k = dims()

        o, l = dims()
        o.size = 2
        r = A[i, k].index(k, [o, l])
        assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2))
        rr = r.index([o, l], k)
        assert torch.allclose(A, rr.order(i, k))
        z = dims()
        C = torch.arange(2)
        x = A[i, k].index(k, C[z]).order(i, z)
        assert torch.allclose(A[:, 0:2], x)

        C = torch.rand(3, 4, 5)
        ik = dims()
        assert torch.allclose(
            C.index((0, 2), ik).order(ik),
            C.permute(0, 2, 1).reshape(15, 4))
Example 12
    def test_network(self):
        if resnet18 is None:
            self.skipTest('no torchvision')
        rn = resnet18(norm_layer=lambda x: torch.nn.BatchNorm2d(
            x, track_running_stats=False))
        rn.train()
        img = torch.rand(1, 1, 2, 3, 224, 224)
        imgf = img.view(2, 3, 224, 224)

        i, j = dims()
        r = rn(img[i, j])
        r = r.order(i, j).view(2, 1000)
        r2 = rn(imgf)
        assert torch.allclose(r2, r, atol=1e-06)
Example 13
    def test_softmax_split(self):
        a = torch.rand(16)
        g, i = dims(sizes=[2, None])
        a2 = a[[i, g], ]

        m_b, _ = a2.max(i)
        f_b = torch.exp(a2 - m_b)
        l_b = f_b.sum(i)

        m, _ = m_b.max(g)
        c = torch.exp(m_b - m)
        f = (c * f_b).order((i, g))
        l = (c * l_b).sum(g)
        assert torch.allclose(f / l, torch.nn.functional.softmax(a, dim=0))
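The same two-pass softmax written with plain views, as a sketch of what the test computes; the view layout below is an assumption about how the [i, g] group splits the flat dimension, and the identity exp(a - m) = exp(m_b - m) * exp(a - m_b) is what lets the per-group results be combined:

import torch

a = torch.rand(16)
a2 = a.view(8, 2)                # assume the split lays the 16 elements out as (i=8, g=2)
m_b = a2.max(dim=0).values       # per-group maxima
f_b = torch.exp(a2 - m_b)
l_b = f_b.sum(dim=0)             # per-group sums
m = m_b.max()                    # global max
c = torch.exp(m_b - m)           # per-group correction factors
f = (c * f_b).reshape(-1)
l = (c * l_b).sum()
assert torch.allclose(f / l, torch.nn.functional.softmax(a, dim=0))
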
Example 14
    def test_embed(self):

        embeddings = torch.rand(8, 32)
        ids = torch.tensor([1, 0, 3, 4])

        # slow but Pythonic
        values_ = torch.empty(4, 32)
        for batch in range(ids.size(0)):
            for feature in range(embeddings.size(1)):
                values_[batch, feature] = embeddings[ids[batch], feature]

        # with torchdim, single indexing kernel
        batch, feature = dims(2)
        values = embeddings[ids[batch], feature].order(batch, feature)

        assert torch.allclose(values, values_)
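As a further point of comparison (not in the original test), the built-in embedding functional performs the same single-kernel row gather:

import torch

embeddings = torch.rand(8, 32)
ids = torch.tensor([1, 0, 3, 4])
values_builtin = torch.nn.functional.embedding(ids, embeddings)  # select rows of `embeddings` by `ids`
assert torch.allclose(values_builtin, embeddings[ids])
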
Example 15
 def test_simple(self):
     i, j, k = dims()
     x = torch.rand(3, 4)
     z = x[i, j]
     (z + z + z + z)
     (z.order(i, j))
Example 16
 def test_dir(self):
     i, j = dims(sizes=[3, 3])
     dir(i <= j)
Example 17
 def test_stack(self):
     i, j, d = dims()
     A = torch.rand(4, 5)
     r = stack([A[i, j]], d, j)
Example 18
 def test_eq(self):
     i, j = dims(sizes=[3, 3])
     assert (i == j).sum((i, j)) == 3
Example 19
 def test_mask(self):
     a = torch.rand(5)
     i, j = dims(sizes=[a.size(0), a.size(0)])
     ((i >= j) * a[i]).sum(j).order(i)
Example 20
 def test_order(self):
     i, j = dims()
     A = torch.rand(3, 4, 5)
     assert torch.allclose(A[i].order(1, i), A.permute(2, 0, 1))
Example 21
 def test_max(self):
     ap = torch.rand(2, 3, 2)
     i, j, k = dims()
     a = ap[i, j, k]
     r, i0 = a.max(dim=k)
     self.assertTrue(torch.allclose(r.order(i, j), ap.max(2)[0]))
Example 22
 def test_compare_dims(self):
     i, j = dims()
     i.size = 3
     j.size = 4
     (i < j)
Example 23
 def test_diag(self):
     i = dims()
     A = torch.rand(4, 4)
     (A[i, i])
Example 24
    def test_hello(self):
        A = torch.rand(3, 4)
        B = torch.rand(4, 5)
        i, j, k = dims()

        # r = A[i]*4
        r = (A[i, k] * B[k, j]).sum(k).order(i, j)
        assert torch.allclose(r, A @ B)

        assert A.sum() == A[i].sum((0, i))
        assert A.sum() == A[i].sum((-1, i))

        assert torch.allclose(A.sum(), A[i].sum(0, keepdim=True).sum((0, i)))
        assert torch.allclose(A[i].std(i, True), A.std(0, True))

        assert torch.allclose(A[i, k].max(i)[0].order(k), A.max(0)[0])
        assert torch.allclose(A.sort(1)[0], A[i, k].sort(k)[0].order(i, k))
        # XXX - chunk changes the size of a dimension, has to take a new dimension...
        # assert torch.allclose(A.chunk(2,1)[0], A[i, k].chunk(2, k)[0].order(i, k))
        assert torch.allclose(A[i].renorm(1, i, 7).order(i), A.renorm(1, 0, 7))
        kk = dims()
        # assert torch.allclose( torch.stack([A, A], 1), stack([A[i,k], A[i, k]], kk, k).order(i, kk, k))

        k2 = dims()
        # r = cat((A[i, k], A[i,k]), k, k2)
        # assert torch.allclose(torch.cat([A, A], 1), r.order(i, k2))
        # assert k2.size == 2*k.size

        assert torch.allclose(A.expand(5, -1, -1),
                              A[i, k].expand(j).order(j, i, k))
        z = dims()
        C = torch.arange(2)
        assert torch.allclose(A[:, 0:2], A[i, k].index(k, C[z]).order(i, z))

        o, l = dims()
        o.size = 2
        r = A[i, k].index(k, (o, l))
        assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2))
        rr = r.index((o, l), k)
        assert torch.allclose(A, rr.order(i, k))

        r = i + k - 1
        r2 = torch.arange(3)[:, None] + torch.arange(4)[None, :] - 1
        assert torch.allclose(r.order(i, k), r2)

        # test with ...
        assert torch.allclose(A.T, A[..., k].order(k))

        # test with dimlist
        a_, b_ = dimlists()
        assert torch.allclose(A[i, a_].order(*a_, i), A.T)
        # test with one bound dimlist
        assert torch.allclose(A[:, a_].order(*a_), A.T)
        # test with a dimlist that will end up empty
        assert torch.allclose(A[i, b_, k].order(i, k, *b_), A)
        # test with too few things
        (A[i] + i)
        assert torch.allclose((A[i] + i).order(i),
                              A + torch.arange(3)[:, None])
        # test with too many elements
        try:
            A[1, ..., 1, 1]
            raise NotImplementedError()
        except IndexError:
            pass
        c, d = dims()
        c.size = 2
        assert torch.allclose(A[i, [c, d]].order(i, c, d), A.view(3, 2, 2))

        assert torch.allclose(A[c + 1, c + 0].order(c), A[torch.arange(2) + 1,
                                                          torch.arange(2)])
        try:
            A[..., 3, ...]
            raise NotImplementedError()
        except DimensionBindError:
            pass

        C = torch.rand(4, 7)
        c_, x, y, z = dims()

        a, b, c = C.split((3, 3, 1), dim=1)
        s = dims()
        ref = C.split((3, 3, 1), dim=1)
        t = C[s, c_].split((x, y, z), dim=c_)
        for a, b, d in zip(ref, t, (x, y, z)):
            assert torch.allclose(a, b.order(s, d))

        D = torch.rand(3, 4, 5)
        assert torch.allclose(
            D.transpose(0, 1).flatten(1, 2), D[i, k, j].order((i, j)).order(k))

        r = [id(x) for x in torch.rand_like(A[i, k]).dims]
        assert id(i) in r and id(k) in r
        r = [id(x) for x in torch.nn.functional.dropout(A[i, k]).dims]
        assert id(i) in r and id(k) in r
Example 25
 def test_seg(self):
     A = torch.rand(3, 4)
     i, k = dims()
     i.size = 4
     k.size = 3
     r = i + k - 1
Example 26
 def forward(self, input):
     ci, co = dims()
     b = dimlists()
     result = (input[b, ci] * self.weight[co, ci]).sum(ci) + self.bias[co]
     return result.order(b, co)
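A minimal plain-PyTorch sketch of what this forward computes, assuming weight has shape (out_features, in_features) and bias has shape (out_features,), matching the [co, ci] and [co] bindings above; the built-in functional gives the same result:

import torch

inp = torch.rand(2, 6)     # (batch, ci)
weight = torch.rand(3, 6)  # (co, ci)
bias = torch.rand(3)       # (co,)
ref = torch.nn.functional.linear(inp, weight, bias)
assert torch.allclose(ref, inp @ weight.t() + bias)
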
Example 27
 def test_expand(self):
     A = torch.rand(3, 4)
     i = dims()
     assert list(A[i].expand(2, 4).order(i).size()) == [3, 2, 4]
Example 28
    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        # first run the encoding linear layers for q, k, v normally
        # the meaning of a linear layer is well understood, so no need to use explicit dimensions
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        # introduce values that represent each dimension. dimensions are 'first class'
        # because they are actual python values introduced here
        batch, query_sequence, key_sequence, heads, features = dims()
        heads.size = self.num_attention_heads

        # bind the positional dimensions in k, q, and v against
        # our values. the sizes of each dimension are determined by this binding,
        # and when a dimension is used twice (e.g. batch), its size is checked
        # for consistency across both uses.
        # The group (heads, features) splits apart a single positional dimension
        # into two dimensions. Since heads.size*features.size == q.size(2)
        # and we specified heads.size, features.size is inferred here.
        q = q[batch, query_sequence, [heads, features]]
        k = k[batch, key_sequence, [heads, features]]
        v = v[batch, key_sequence, [heads, features]]

        # this option allows the model to attend not just to the elements of the current sequence
        # but also to previous elements as well as additional tokens.
        if past_key_value is not None:
            extended_key_sequence = dims()
            key_past = past_key_value[0][batch, heads, key_sequence, features]
            value_past = past_key_value[1][batch, heads, key_sequence,
                                           features]
            # cat introduces a new dimension extended_key_sequence, because it is twice as long
            # as the original key_sequence
            k = cat([key_past, k], key_sequence, extended_key_sequence)
            v = cat([value_past, v], key_sequence, extended_key_sequence)
            # for the rest of the function, we will just use extended_key_sequence in lieu of
            # key_sequence
            key_sequence = extended_key_sequence

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # The actual outer-product and summation are explicitly represented here,
        # and like einsum, will be pattern matched to an efficient matrix multiply op.
        attention_scores = (q * k).sum(features) / math.sqrt(features.size)

        # relative positional embeddings give a unique embedding based on the distance between
        # query and key tokens in the sequence, e.g.
        #  0  1  2  3
        # -1  0  1  2
        # -2 -1  0  1
        # -3 -2 -1  0
        if self.position_embedding_type is not None:
            # the value of a dimension object when used as a tensor is the indices along its dimension
            # so we can directly subtract the two dimensions to get a 2D tensor of (query_sequence x key_sequence)
            # with the distance between them
            distance = query_sequence - key_sequence

            assert key_sequence.size <= self.max_position_embeddings

            # we can then use that as an indirect index into the embedding table to look up the features for each distance.
            # this is just a `gather` primitive op. The resulting tensor will
            # have all the dimensions of the indexing expression (query_sequence x key_sequence),
            # plus all the dimensions of the embedding table that were not indirectly accessed (features).
            # this form of indirect indexing is more straightforward than either advanced indexing
            # or torch.gather, which both have a lot of dependencies on the positions of indexing tensors.

            positional_embedding = self.distance_embedding.weight[
                self.max_position_embeddings - 1 + distance, features]

            if self.position_embedding_type == "relative_key":
                # these were einsum ops in the positional code because they are not easy to fit
                # to existing matmul operators, even though they are degenerate matmuls
                relative_position_scores = (q *
                                            positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = (
                    q * positional_embedding).sum(features)
                relative_position_scores_key = (
                    k * positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = softmax(attention_scores, dim=key_sequence)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = torch.nn.functional.dropout(attention_probs,
                                                      p=self.dropout_prob)

        # similarly, we can replace the matmul with a direct listing of the outer product, which makes it clear
        # we are weighting the values v across all keys with the attention scores.
        context_layer = (attention_probs * v).sum(key_sequence)

        # finally, we convert back to a standard tensor by describing the layout of dimensions.
        # working in reverse of the dimension binding above, the (heads, features) group flattens the dimensions back into a single one.
        return context_layer.order(batch, query_sequence, [heads, features])
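Two small plain-tensor sketches of steps described in the comments above (illustrative shapes only, not taken from the model): the dot-product score (q * k).sum(features) corresponds to a batched contraction over the feature dimension, and subtracting two broadcast index ranges reproduces the query/key distance matrix.

import math
import torch

# score: contract query and key over the feature dimension
q = torch.rand(2, 8, 5, 16)  # (batch, heads, query_sequence, features)
k = torch.rand(2, 8, 7, 16)  # (batch, heads, key_sequence, features)
scores = torch.einsum("bhqf,bhkf->bhqk", q, k) / math.sqrt(q.size(-1))

# distance: analogous to query_sequence - key_sequence on Dim objects
distance = torch.arange(5)[:, None] - torch.arange(7)[None, :]  # (query_sequence, key_sequence)
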
Example 29
def triu(A):
    i, j = dims()
    a = A[i, j]
    zero = torch.tensor(0, dtype=torch.float)  # XXX - torch.where is janky...
    return torch.where(i <= j, a, zero).order(i, j)
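For comparison, a sketch of the same masking built with plain broadcasting instead of Dim objects:

import torch

A = torch.rand(4, 4)
rows = torch.arange(A.size(0))[:, None]
cols = torch.arange(A.size(1))[None, :]
ref = torch.where(rows <= cols, A, torch.tensor(0.0))
assert torch.allclose(ref, torch.triu(A, 0))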