def test_linear(self):
    """Check ``Tensor.linear`` against ``torch.nn.Linear`` for 2-D and 3-D inputs."""
    # Sizes are read by the helper below through its closure.
    BS, T, in_dim, out_dim = 4, 2, 8, 16

    def _test_linear(x):
        # tinygrad side: uniform weight, zero bias
        weight = Tensor.uniform(in_dim, out_dim)
        bias = Tensor.zeros(out_dim)
        z = x.linear(weight, bias)

        # torch side: same parameters; the weight is copied in transposed
        with torch.no_grad():
            torch_layer = torch.nn.Linear(in_dim, out_dim).eval()
            torch_layer.weight[:] = torch.tensor(weight.data.T, dtype=torch.float32)
            torch_layer.bias[:] = torch.tensor(bias.data, dtype=torch.float32)
            torch_x = torch.tensor(x.cpu().data, dtype=torch.float32)
            torch_z = torch_layer(torch_x)

        # both outputs must agree within tolerance
        np.testing.assert_allclose(z.data, torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

    # The second shape exercises an input with an extra dimension.
    for shape in ((BS, in_dim), (BS, T, in_dim)):
        _test_linear(Tensor.randn(*shape))
def test_batchnorm2d(self, training=False):
    """Compare tinygrad ``BatchNorm2D`` with ``torch.nn.BatchNorm2d`` on a random input.

    Both layers receive the same weight, bias and running statistics; the
    forward outputs and the running mean are then required to agree.
    """
    sz = 4
    # Attributes copied from the tinygrad layer into the torch layer, in order.
    copied_attrs = ("weight", "bias", "running_mean", "running_var")

    def assert_stat_matches(name):
        # Compare one running statistic between the two layers.
        np.testing.assert_allclose(
            getattr(bn, name).data,
            getattr(tbn, name).detach().numpy(),
            rtol=1e-5,
        )

    # tinygrad layer with random parameters and running statistics
    bn = BatchNorm2D(sz, eps=1e-5, training=training, track_running_stats=training)
    for attr in copied_attrs:
        setattr(bn, attr, Tensor.randn(sz))
    # replace negative draws with zero in the running variance
    bn.running_var.data[bn.running_var.data < 0] = 0

    # torch layer carrying the same values
    with torch.no_grad():
        tbn = torch.nn.BatchNorm2d(sz).eval()
        tbn.training = training
        for attr in copied_attrs:
            getattr(tbn, attr)[:] = torch.tensor(getattr(bn, attr).data)

    # the copied statistics must match before the forward pass
    assert_stat_matches("running_mean")
    assert_stat_matches("running_var")

    # run the same random input through both layers (tinygrad first)
    inn = Tensor.randn(2, sz, 3, 3)
    outt = bn(inn)
    toutt = tbn(torch.tensor(inn.cpu().data))

    # outputs must be close, and the running mean must still agree afterwards
    np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=5e-5)
    assert_stat_matches("running_mean")
def test_mnist(self):
    """Time an MNIST-style convnet forward/backward pass in torch and in tinygrad.

    Prints the average per-iteration forward and backward wall-clock time for
    both frameworks together with profiler output. Nothing is asserted.

    MKL-DNN is switched off in torch for the duration of the test; the
    previous value of the flag is restored on exit, including on failure.

    Network reference: https://keras.io/examples/vision/mnist_convnet/
    """
    conv = 3
    inter_chan, out_chan = 32, 64
    cnt = 5  # number of timed iterations per framework

    # The flag is process-wide, so remember it and put it back afterwards
    # instead of leaking the change into whatever runs next.
    mkldnn_was_enabled = torch.backends.mkldnn.enabled
    torch.backends.mkldnn.enabled = False
    try:
        # ****** torch baseline *******
        c1 = torch.randn(inter_chan, 1, conv, conv, requires_grad=True)
        c2 = torch.randn(out_chan, inter_chan, conv, conv, requires_grad=True)
        l1 = torch.randn(out_chan * 5 * 5, 10, requires_grad=True)

        c2d = torch.nn.functional.conv2d
        mp = torch.nn.MaxPool2d((2, 2))
        lsm = torch.nn.LogSoftmax(dim=1)

        with torch.autograd.profiler.profile(record_shapes=True) as tprof:
            fpt, bpt = 0.0, 0.0
            for _ in range(cnt):
                et0 = time.time()
                x = torch.randn(128, 1, 28, 28, requires_grad=True)
                x = mp(c2d(x, c1).relu())
                x = mp(c2d(x, c2).relu())
                x = x.reshape(x.shape[0], -1)
                out = lsm(x.matmul(l1))
                out = out.mean()
                et1 = time.time()
                out.backward()
                et2 = time.time()
                fpt += et1 - et0
                bpt += et2 - et1

        # average milliseconds per iteration
        fpt_baseline = fpt * 1000 / cnt
        bpt_baseline = bpt * 1000 / cnt
        print("torch forward pass: %.3f ms" % fpt_baseline)
        print("torch backward pass: %.3f ms" % bpt_baseline)
        print(tprof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

        # ****** tinygrad compare *******
        # reuse the torch weights so both frameworks work on the same values
        c1 = Tensor(c1.detach().numpy())
        c2 = Tensor(c2.detach().numpy())
        l1 = Tensor(l1.detach().numpy())

        fpt, bpt = 0.0, 0.0
        # One extra pass: iteration 0 is not timed, and profiling starts after it.
        for i in range(1 + cnt):
            et0 = time.time()
            x = Tensor.randn(128, 1, 28, 28)
            # NOTE(review): the first pooling stage is avg_pool2d here, while the
            # torch baseline uses MaxPool2d for both stages - confirm intentional.
            x = x.conv2d(c1).relu().avg_pool2d()
            x = x.conv2d(c2).relu().max_pool2d()
            x = x.reshape(Tensor(np.array((x.shape[0], -1))))
            out = x.dot(l1).logsoftmax()
            out = out.mean()
            et1 = time.time()
            out.backward()
            et2 = time.time()
            if i == 0:
                pr = start_profile()
            else:
                fpt += et1 - et0
                bpt += et2 - et1

        stop_profile(pr, sort='time')
        fpt = fpt * 1000 / cnt
        bpt = bpt * 1000 / cnt
        print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt / fpt_baseline, fpt_baseline))
        print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt / bpt_baseline, bpt_baseline))
    finally:
        torch.backends.mkldnn.enabled = mkldnn_was_enabled
def profile_conv(bs, chans, conv, cnt=100):
    """Run ``cnt`` forward ``conv2d`` calls on a fixed input, for profiling.

    Args:
        bs: size of the first dimension of the zero-filled ``(bs, 1, 28, 28)``
            input (presumably the batch size).
        chans: size of the first dimension of the random kernel.
        conv: size used for both trailing dimensions of the kernel.
        cnt: how many times ``conv2d`` is called.
    """
    img = Tensor.zeros(bs, 1, 28, 28)
    kernel = Tensor.randn(chans, 1, conv, conv)
    for _ in range(cnt):
        # the result is not used; only the call itself matters here
        out = img.conv2d(kernel)
def __init__(self, in_size: int, out_size: int) -> None:
    """Create the layer parameters.

    Args:
        in_size: first dimension of the weight tensor.
        out_size: second dimension of the weight tensor and of the bias.
    """
    weight_shape = (in_size, out_size)
    bias_shape = (1, out_size)
    # random weight, zero bias
    self.weight = Tensor.randn(*weight_shape)
    self.bias = Tensor.zeros(*bias_shape)