def test_dims_with_size(self):
    x = dims(3)
    assert len(x) == 3 and isinstance(x[0], Dim)

    class Foo:
        pass
    y = Foo()
    # dim names are inferred from the assignment target where possible;
    # otherwise generated names (d1, d2, ...) are used
    z, y.x, q = dims(3)
    assert str(z) == "z"
    assert str(y.x) == "d1"
    assert str(q) == "d2"
def test_time_mm_fuse(self):
    i, j, k = dims()
    A = torch.rand(3, 4)
    B = torch.rand(4, 5)

    for _ in range(10):
        r0 = A @ B

    for _ in range(10):
        a = A[i, k]
        b = B[k, j]
        r1 = (a * b).sum(k)

    with measure('pp'):
        for _ in range(10000):
            A @ B
    # magic_trace_stop_indicator()

    with measure('fc'):
        for _ in range(10000):
            (A[i, k] * B[k, j]).sum(k).order(i, j)

    with magic_trace('f.fxt'):
        for _ in range(10000):
            (A[i, k] * B[k, j]).sum(k).order(i, j)

    with magic_trace('p.fxt'):
        for _ in range(10000):
            A @ B
    # magic_trace_stop_indicator()

    assert torch.allclose(r1.order(i, j), r0)
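# `measure` and `magic_trace` are used above but not defined in this excerpt.
# Below is a minimal sketch of a compatible `measure` context manager, assuming
# it only needs to print a label and the elapsed wall-clock time; the project's
# real helper may differ.
from contextlib import contextmanager
import time

@contextmanager
def measure(what):
    begin = time.perf_counter()
    try:
        yield
    finally:
        print(f"{what}: {time.perf_counter() - begin:.6f} seconds")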
def test_with_dims_split(self):
    a = torch.arange(3 * 12).view(3, 12)
    i, j, k = dims()
    k.size = 4
    r = a[i, [j, k]]
    x = r.order(i, [j, k])
    self.assertTrue(torch.allclose(a, x))
def test_mm_fuse(self):
    i, j, k = dims()
    A = torch.rand(3, 4)
    B = torch.rand(4, 5)
    C = (A[i, k] * B[k, j]).sum(k).order(i, j)
    assert torch.allclose(C, A @ B)
def test_manual_stuff(self):
    A_ = torch.rand(3, 4)
    B_ = torch.rand(4, 5)
    i, j, k = dims()
    A = A_[i, k]
    B = B_[k, j]
    C = (A.expand(j) * B.expand(i)).sum(k)
    self.assertTrue(torch.allclose(C.order(i, j), torch.mm(A_, B_)))
    self.assertTrue(torch.allclose(torch.triu(A_, 0), triu(A_)))

    D_ = torch.randint(0, 3, (6,))
    d = dims()
    D = D_[d]
    A.index([i], [D]).order(k, d)
def test_mm(self):
    i, j, k, q = dims()
    a = torch.rand(3, 4)
    b = torch.rand(4, 5)
    a_ = a[i, k]
    b_ = b[k, j]
    q.size = 1
    r = (a_.expand(j, q) * b_.expand(i, q)).sum(k).order(q, i, j)
def test_dim_args(self):
    a = dimlists()
    assert isinstance(a, DimList)
    a = dims()
    b = dimlists()
    assert isinstance(a, Dim)
    assert isinstance(b, DimList)
    assert str(a) == 'a'
    a, b = dims(sizes=[3, 4])
    assert a.size == 3
    assert b.size == 4
    a = dims(sizes=[3])
    b = dimlists(sizes=[4])
    assert len(b) == 4
    a = dims()
    b = dimlists(sizes=[[4, 5]])
    assert b[0].size == 4
    assert b[1].size == 5
def test_index_placement(self):
    A = torch.rand(1, 2, 3, 4)
    i, j = dims(sizes=[2, 4])
    a = A[:, i + 0, :, j + 0]
    r = a.order(i, j)
    assert torch.allclose(A.permute(1, 3, 0, 2), r)
def test_functorch(self):
    A = torch.rand(3, 4, 5)
    B = torch.rand(3, 4, 5)
    C = torch.rand(5, 2)
    i, j = dims()
    AA = torch.mm(A[i], C)  # 3, 4, 2
    BB = torch.mm(B[j], C)  # 3, 4, 2
    assert list(torch.mm(AA.T, BB).order(i, j).shape) == [3, 3, 2, 2]
def test_inplace(self):
    # some embeddings table
    embeddings = torch.zeros(10, 3)
    # some sparse updates to the embeddings
    indices = torch.arange(2) + 1
    values = torch.rand(2, 3)
    i, n, f = dims()
    embeddings[indices[i], f] += values[i, f]
def test_index(self):
    A = torch.rand(3, 4)
    B = torch.rand(4, 5)
    i, j, k = dims()

    o, l = dims()
    o.size = 2
    r = A[i, k].index(k, [o, l])
    assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2))
    rr = r.index([o, l], k)
    assert torch.allclose(A, rr.order(i, k))

    z = dims()
    C = torch.arange(2)
    x = A[i, k].index(k, C[z]).order(i, z)
    assert torch.allclose(A[:, 0:2], x)

    C = torch.rand(3, 4, 5)
    ik = dims()
    assert torch.allclose(C.index((0, 2), ik).order(ik), C.permute(0, 2, 1).reshape(15, 4))
def test_network(self):
    if resnet18 is None:
        self.skipTest('no torchvision')
    rn = resnet18(norm_layer=lambda x: torch.nn.BatchNorm2d(x, track_running_stats=False))
    rn.train()
    img = torch.rand(1, 1, 2, 3, 224, 224)
    imgf = img.view(2, 3, 224, 224)

    i, j = dims()
    r = rn(img[i, j])
    r = r.order(i, j).view(2, 1000)
    r2 = rn(imgf)
    assert torch.allclose(r2, r, atol=1e-06)
def test_softmax_split(self):
    a = torch.rand(16)
    g, i = dims(sizes=[2, None])
    a2 = a[[i, g], ]

    m_b, _ = a2.max(i)
    f_b = torch.exp(a2 - m_b)
    l_b = f_b.sum(i)

    m, _ = m_b.max(g)
    c = torch.exp(m_b - m)
    f = (c * f_b).order((i, g))
    l = (c * l_b).sum(g)
    assert torch.allclose(f / l, torch.nn.functional.softmax(a, dim=0))
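# test_softmax_split relies on the standard two-pass softmax identity:
# exp(a - m) = exp(m_b - m) * exp(a - m_b), so each group only needs its local
# max m_b and local sum l_b before a single global rescale. A plain-view sketch
# of the same computation without first-class dims (the grouping here is
# illustrative, not identical to the test's [i, g] split):
def _softmax_split_reference(a, groups=2):
    a2 = a.view(groups, -1)                   # split the vector into groups
    m_b = a2.max(dim=1, keepdim=True).values  # local max per group
    f_b = torch.exp(a2 - m_b)
    l_b = f_b.sum(dim=1, keepdim=True)        # local normalizer per group
    c = torch.exp(m_b - m_b.max())            # rescale by the global max
    return (c * f_b).flatten() / (c * l_b).sum()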
def test_embed(self):
    embeddings = torch.rand(8, 32)
    ids = torch.tensor([1, 0, 3, 4])

    # slow but Pythonic
    values_ = torch.empty(4, 32)
    for batch in range(ids.size(0)):
        for feature in range(embeddings.size(1)):
            values_[batch, feature] = embeddings[ids[batch], feature]

    # with torchdim, a single indexing kernel
    batch, feature = dims(2)
    values = embeddings[ids[batch], feature].order(batch, feature)

    assert torch.allclose(values, values_)
def test_simple(self):
    i, j, k = dims()
    x = torch.rand(3, 4)
    z = x[i, j]
    (z + z + z + z)
    (z.order(i, j))
def test_dir(self):
    i, j = dims(sizes=[3, 3])
    dir(i <= j)
def test_stack(self):
    i, j, d = dims()
    A = torch.rand(4, 5)
    r = stack([A[i, j]], d, j)
def test_eq(self):
    i, j = dims(sizes=[3, 3])
    assert (i == j).sum((i, j)) == 3
def test_mask(self):
    a = torch.rand(5)
    i, j = dims(sizes=[a.size(0), a.size(0)])
    ((i >= j) * a[i]).sum(j).order(i)
def test_order(self):
    i, j = dims()
    A = torch.rand(3, 4, 5)
    assert torch.allclose(A[i].order(1, i), A.permute(2, 0, 1))
def test_max(self):
    ap = torch.rand(2, 3, 2)
    i, j, k = dims()
    a = ap[i, j, k]
    r, i0 = a.max(dim=k)
    self.assertTrue(torch.allclose(r.order(i, j), ap.max(2)[0]))
def test_compare_dims(self):
    i, j = dims()
    i.size = 3
    j.size = 4
    (i < j)
def test_diag(self):
    i = dims()
    A = torch.rand(4, 4)
    (A[i, i])
def test_hello(self):
    A = torch.rand(3, 4)
    B = torch.rand(4, 5)
    i, j, k = dims()

    # r = A[i]*4
    r = (A[i, k] * B[k, j]).sum(k).order(i, j)
    assert torch.allclose(r, A @ B)

    assert A.sum() == A[i].sum((0, i))
    assert A.sum() == A[i].sum((-1, i))
    assert torch.allclose(A.sum(), A[i].sum(0, keepdim=True).sum((0, i)))
    assert torch.allclose(A[i].std(i, True), A.std(0, True))
    assert torch.allclose(A[i, k].max(i)[0].order(k), A.max(0)[0])
    assert torch.allclose(A.sort(1)[0], A[i, k].sort(k)[0].order(i, k))
    # XXX - chunk changes the size of a dimension, has to take a new dimension...
    # assert torch.allclose(A.chunk(2, 1)[0], A[i, k].chunk(2, k)[0].order(i, k))
    assert torch.allclose(A[i].renorm(1, i, 7).order(i), A.renorm(1, 0, 7))

    kk = dims()
    # assert torch.allclose(torch.stack([A, A], 1), stack([A[i, k], A[i, k]], kk, k).order(i, kk, k))

    k2 = dims()
    # r = cat((A[i, k], A[i, k]), k, k2)
    # assert torch.allclose(torch.cat([A, A], 1), r.order(i, k2))
    # assert k2.size == 2 * k.size

    assert torch.allclose(A.expand(5, -1, -1), A[i, k].expand(j).order(j, i, k))

    z = dims()
    C = torch.arange(2)
    assert torch.allclose(A[:, 0:2], A[i, k].index(k, C[z]).order(i, z))

    o, l = dims()
    o.size = 2
    r = A[i, k].index(k, (o, l))
    assert torch.allclose(r.order(i, o, l), A.view(-1, 2, 2))
    rr = r.index((o, l), k)
    assert torch.allclose(A, rr.order(i, k))

    r = i + k - 1
    r2 = torch.arange(3)[:, None] + torch.arange(4)[None, :] - 1
    assert torch.allclose(r.order(i, k), r2)

    # test with ...
    assert torch.allclose(A.T, A[..., k].order(k))

    # test with dimlist
    a_, b_ = dimlists()
    assert torch.allclose(A[i, a_].order(*a_, i), A.T)
    # test with one bound dimlist
    assert torch.allclose(A[:, a_].order(*a_), A.T)
    # test with a dimlist that will end up empty
    assert torch.allclose(A[i, b_, k].order(i, k, *b_), A)

    # test with too few things
    (A[i] + i)
    assert torch.allclose((A[i] + i).order(i), A + torch.arange(3)[:, None])

    # test with too many elements
    try:
        A[1, ..., 1, 1]
        raise NotImplementedError()
    except IndexError:
        pass

    c, d = dims()
    c.size = 2
    assert torch.allclose(A[i, [c, d]].order(i, c, d), A.view(3, 2, 2))
    assert torch.allclose(A[c + 1, c + 0].order(c), A[torch.arange(2) + 1, torch.arange(2)])

    try:
        A[..., 3, ...]
        raise NotImplementedError()
    except DimensionBindError:
        pass

    C = torch.rand(4, 7)
    c_, x, y, z = dims()

    a, b, c = C.split((3, 3, 1), dim=1)
    s = dims()
    ref = C.split((3, 3, 1), dim=1)
    t = C[s, c_].split((x, y, z), dim=c_)
    for a, b, d in zip(ref, t, (x, y, z)):
        assert torch.allclose(a, b.order(s, d))

    D = torch.rand(3, 4, 5)
    assert torch.allclose(D.transpose(0, 1).flatten(1, 2), D[i, k, j].order((i, j)).order(k))

    r = [id(x) for x in torch.rand_like(A[i, k]).dims]
    assert id(i) in r and id(k) in r

    r = [id(x) for x in torch.nn.functional.dropout(A[i, k]).dims]
    assert id(i) in r and id(k) in r
def test_seg(self):
    A = torch.rand(3, 4)
    i, k = dims()
    i.size = 4
    k.size = 3
    r = i + k - 1
def forward(self, input):
    ci, co = dims()
    b = dimlists()
    result = (input[b, ci] * self.weight[co, ci]).sum(ci) + self.bias[co]
    return result.order(b, co)
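# A hedged sketch of a module this `forward` could belong to: the indexing
# self.weight[co, ci] implies an (out_features, in_features) weight, the same
# layout torch.nn.Linear uses. The class and parameter names are assumptions:
class LinearSketch(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.zeros(out_features))

    # `forward` as defined above; for any leading batch shape the result should
    # agree with torch.nn.functional.linear(input, self.weight, self.bias)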
def test_expand(self):
    A = torch.rand(3, 4)
    i = dims()
    assert list(A[i].expand(2, 4).order(i).size()) == [3, 2, 4]
def forward(
    self,
    hidden_states,
    past_key_value=None,
):
    # first run the encoding linear layers for q, k, v normally
    # the meaning of a linear layer is well understood, so no need to use explicit dimensions
    q = self.query(hidden_states)
    k = self.key(hidden_states)
    v = self.value(hidden_states)

    # introduce values that represent each dimension. dimensions are 'first class'
    # because they are actual python values introduced here
    batch, query_sequence, key_sequence, heads, features = dims()
    heads.size = self.num_attention_heads

    # bind the positional dimensions in k, q, and v against
    # our values. the sizes of each dimension are determined by this binding,
    # and when a dimension is used twice (e.g. batch), its size is checked
    # for consistency across both uses.
    # The group (heads, features) splits a single positional dimension apart
    # into two dimensions. Since heads.size * features.size == q.size(2)
    # and we specified heads.size, features.size is inferred here.
    q = q[batch, query_sequence, [heads, features]]
    k = k[batch, key_sequence, [heads, features]]
    v = v[batch, key_sequence, [heads, features]]

    # this option allows the model to attend not just to the elements of the
    # current sequence but also to previous elements and additional tokens.
    if past_key_value is not None:
        extended_key_sequence = dims()
        key_past = past_key_value[0][batch, heads, key_sequence, features]
        value_past = past_key_value[1][batch, heads, key_sequence, features]
        # cat introduces a new dimension, extended_key_sequence, because it is
        # twice as long as the original key_sequence
        k = cat([key_past, k], key_sequence, extended_key_sequence)
        v = cat([value_past, v], key_sequence, extended_key_sequence)
        # for the rest of the function, we will just use extended_key_sequence
        # in lieu of key_sequence
        key_sequence = extended_key_sequence

    # Take the dot product between "query" and "key" to get the raw attention scores.
    # The actual outer product and summation are explicitly represented here,
    # and like einsum, will be pattern matched to an efficient matrix-multiply op.
    attention_scores = (q * k).sum(features) / math.sqrt(features.size)

    # relative positional embeddings give a unique embedding based on the distance
    # between key and value tokens in the sequence, e.g.
    #   0  1  2  3
    #  -1  0  1  2
    #  -2 -1  0  1
    #  -3 -2 -1  0
    if self.position_embedding_type is not None:
        # the value of a dimension object when used as a tensor is the indices along
        # its dimension, so we can directly subtract the two dimensions to get a
        # 2D tensor of (query_sequence x key_sequence) with the distance between them
        distance = query_sequence - key_sequence

        assert key_sequence.size <= self.max_position_embeddings

        # we can then use that as an indirect index into the embedding table values
        # to look up the features for that index. this is just a `gather` primitive op.
        # The resulting tensor will have all the dimensions of embedding_idx
        # (query_sequence x key_sequence), plus all the dimensions of `embed` that
        # were not indirectly accessed (`embedding_range`).
        # this form of indirect indexing is more straightforward than either advanced
        # indexing or torch.gather, which both have a lot of dependencies on the
        # positions of indexing tensors.
        positional_embedding = self.distance_embedding.weight[
            self.max_position_embeddings - 1 + distance, features]

        if self.position_embedding_type == "relative_key":
            # these were einsum ops in the original positional code because they are
            # not easy to fit to existing matmul operators, even though they are
            # degenerate matmuls
            relative_position_scores = (q * positional_embedding).sum(features)
            attention_scores = attention_scores + relative_position_scores
        elif self.position_embedding_type == "relative_key_query":
            relative_position_scores_query = (q * positional_embedding).sum(features)
            relative_position_scores_key = (k * positional_embedding).sum(features)
            attention_scores = (attention_scores + relative_position_scores_query
                                + relative_position_scores_key)

    # Normalize the attention scores to probabilities.
    attention_probs = softmax(attention_scores, dim=key_sequence)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = torch.nn.functional.dropout(attention_probs, p=self.dropout_prob)

    # similarly, we can replace the matmul with a direct listing of the outer product,
    # which makes it clear we are weighting the values v across all keys with
    # the attention scores.
    context_layer = (attention_probs * v).sum(key_sequence)

    # finally, we convert back to a standard tensor by describing the layout of
    # dimensions. working in reverse to the binding above, the (heads, features)
    # group flattens the two dimensions into a single one.
    return context_layer.order(batch, query_sequence, [heads, features])
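# The `query_sequence - key_sequence` trick above can be sanity-checked in
# isolation: subtracting two bound dims broadcasts their index ranges into the
# banded offset matrix drawn in the comment (the size 4 here is arbitrary):
def _check_distance_trick():
    qs, ks = dims(sizes=[4, 4])
    d = (qs - ks).order(qs, ks)
    assert torch.allclose(d, torch.arange(4)[:, None] - torch.arange(4)[None, :])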
def triu(A):
    i, j = dims()
    a = A[i, j]
    zero = torch.tensor(0, dtype=torch.float)
    # XXX - torch.where is janky...
    return torch.where(i <= j, a, zero).order(i, j)
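# usage mirrors torch.triu(A, 0), as asserted in test_manual_stuff above:
#   A = torch.rand(4, 4)
#   assert torch.allclose(triu(A), torch.triu(A, 0))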