def check(xshape, wshape, stride=1, padding=0, dilation=1):
    with jt.log_capture_scope(
        use_cuda=1, enable_tuner=1,
        log_v=1, log_vprefix="op.cc=100,exe=1000"
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = conv(x, w, stride, padding)
        mask = jt.random(y.shape)
        loss = mask * y
        dx, dw = jt.grad(loss, [x, w])
        jt.sync([y, loss, dx, dw])
    # fails when enable_tuner=1, something wrong with mkl_conv_backward_x maybe.
    with jt.flag_scope(use_cuda=0, enable_tuner=0):
        cy = conv(x, w, stride, padding)
        closs = mask * cy
        cdx, cdw = jt.grad(closs, [x, w])
        jt.sync([cy, closs, cdx, cdw])
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)")
    assert len(logs) == 3 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data)
    assert np.allclose(dx.data, cdx.data, 1e-2)
    assert np.allclose(dw.data, cdw.data, 1e-2)
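# A hedged usage sketch, not taken from the original suite: the "oihw" assertion
# above suggests `check` expects an NCHW input and OIHW weights, so a call would
# presumably look like the one below. The shapes, stride and padding are
# illustrative assumptions, and `example_cudnn_check` is a hypothetical wrapper.
def example_cudnn_check():
    # 10 NCHW images with 3 channels, 5 output channels, 3x3 kernel
    check([10, 3, 100, 100], [5, 3, 3, 3], stride=2, padding=1)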
def check_backward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
    if nhwc:
        test_func = test_nhwc
    else:
        test_func = test_nchw
    if use_cuda == 1:
        op_name = "cudnn_conv"
    else:
        op_name = "mkl_conv"
    with jt.log_capture_scope(
        use_cuda=use_cuda, enable_tuner=1,
        log_v=1, log_vprefix="op.cc=1000,exe=1000,conv_t=1000",
        compile_options={"test": 244}
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation)
        loss = y.mean()
        dx, dw = jt.grad(loss, [x, w])
        jt.sync([y, loss, dx, dw])
    with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test": 233}):
        cy = test_func(x, w, stride, padding, dilation)
        closs = cy.mean()
        cdx, cdw = jt.grad(closs, [x, w])
        jt.sync([cy, closs, cdx, cdw])
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + op_name + ".*)")
    assert len(logs) == 3 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data, 1e-3)
    assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data)
    assert np.allclose(dx.data, cdx.data, 1e-3), \
        (dx.data, cdx.data, np.abs(cdx.data).max(), np.abs(dx.data - cdx.data).max())
def check_forward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
    if nhwc:
        test_func = test_nhwc
    else:
        test_func = test_nchw
    if use_cuda == 1:
        op_name = "cudnn_conv"
    else:
        op_name = "mkl_conv"
    with jt.log_capture_scope(
        use_cuda=use_cuda, enable_tuner=1,
        log_v=0, log_vprefix="op.cc=100,conv_tuner=1000",
        compile_options={"test": 266}
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation)
        y.sync()
    with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test": 255}):
        cy = test_func(x, w, stride, padding, dilation)
        cy.sync()
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + op_name + ".*)")
    assert len(logs) == 1 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data)
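# Every check in this section relies on the `find_log_with_re` helper. Below is a
# minimal sketch of the behaviour the assertions appear to assume: each captured
# entry exposes its message text, and the pattern's capture groups are collected
# via re.findall, so `logs[0][0]` is the first group of the first match. This is
# an illustrative assumption, not the helper's actual implementation; the dict
# key "msg" is also assumed.
import re

def find_log_with_re_sketch(raw_log, pattern):
    found = []
    for entry in raw_log:
        msg = entry["msg"] if isinstance(entry, dict) else str(entry)
        found += re.findall(pattern, msg)
    return found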
def check_backward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc):
    assert nhwc == 0
    test_func = test_nchw  # only check cudnn
    with jt.log_capture_scope(
        use_cuda=use_cuda, enable_tuner=1,
        log_v=10, log_vprefix="op.cc=100,conv_tuner=1000"
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        y.sync()
        dx, dw = jt.grad(y, [x, w])
        jt.sync([y, dx, dw])
    with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test": 233}):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cdx, cdw = jt.grad(cy, [x, w])
        jt.sync([cy, cdx, cdw])
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)")
    assert len(logs) == 3
    assert np.allclose(y.data, cy.data)
    assert np.allclose(dw.data, cdw.data, 1e-3), \
        (dw.data, cdw.data, np.abs(dw.data - cdw.data).max())
    assert np.allclose(dx.data, cdx.data, 1e-3), \
        (dx.data, cdx.data, np.abs(dx.data - cdx.data).max())
def check(xshape, wshape, stride=1, padding=0, dilation=1):
    with jt.log_capture_scope(
        use_cuda=1, enable_tuner=1,
        log_v=0, log_vprefix="op.cc=100"
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = conv_oihw(x, w, stride, padding, dilation)
        y.sync()
    with jt.flag_scope(use_cuda=0, enable_tuner=1):
        cy = conv_oihw(x, w, stride, padding, dilation)
        cy.sync()
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)")
    assert len(logs) == 1 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data), np.abs(y.data - cy.data).max()
def test_all_reduce(self):
    with jt.log_capture_scope(
        enable_tuner=1, log_silent=1,
        log_v=1, log_vprefix="op.cc=100,exe=1000"
    ) as raw_log:
        x = jt.random([5, 5])
        y = x.mpi_all_reduce()
        assert np.allclose(y.data, (x * n).data)
        g = jt.grad(y, x)
        assert np.allclose(g.data, np.ones([5, 5]) * n)
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)")
    assert len(logs) == 2, len(logs)
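# Note: `n` is not defined in this excerpt. Since mpi_all_reduce sums one tensor
# per rank, it presumably holds the number of MPI processes; a likely definition
# (an assumption, not taken from the source) would be:
n = mpi.world_size()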
def test_reduce(self):
    with jt.log_capture_scope(
        enable_tuner=1, log_silent=1,
        log_v=1, log_vprefix="op.cc=100,exe=1000"
    ) as raw_log:
        x = jt.random([5, 5])
        y = x.mpi_reduce(root=0)
        y_ = y.data
        x_ = (x * n).data
        if mpi.world_rank() == 0:
            assert np.allclose(y_, x_)
        g = jt.grad(y, x)
        assert np.allclose(g.data, np.ones([5, 5]))
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_reduce.*)")
    assert len(logs) == 1, len(logs)
def check_forward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc):
    assert nhwc == 0
    test_func = test_nchw  # only check cudnn
    with jt.log_capture_scope(
        use_cuda=use_cuda, enable_tuner=1,
        log_v=10, log_vprefix="op.cc=100,conv_tuner=1000"
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        y.sync()
    with jt.flag_scope(use_cuda=0, enable_tuner=0):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cy.sync()
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)")
    assert len(logs) == 1
    assert np.allclose(y.data, cy.data)
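# A hedged usage sketch for the grouped checks, not from the original suite: the
# weight shape is presumably (out_channels, in_channels // groups, kh, kw), so a
# call could look like the following. The shapes and group count are illustrative
# assumptions, and `example_group_conv_check` is a hypothetical wrapper.
def example_group_conv_check():
    check_forward([4, 8, 20, 20], [8, 4, 3, 3], 1, 1, 1, groups=2, use_cuda=1, nhwc=0)
    check_backward([4, 8, 20, 20], [8, 4, 3, 3], 1, 1, 1, groups=2, use_cuda=1, nhwc=0)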
def test_broadcast(self):
    with jt.log_capture_scope(
        enable_tuner=1, log_silent=1,
        log_v=1, log_vprefix="op.cc=100,exe=1000"
    ) as raw_log:
        data = jt.random([5, 5])
        if mpi.world_rank() == 0:
            x = data
        else:
            x = jt.zeros([5, 5])
        y = x.mpi_broadcast(0)
        assert np.allclose(y.data, data.data)
        g = jt.grad(y.sum(), x)
        g_ = g.data
        if mpi.world_rank() == 0:
            assert np.allclose(g_, np.ones([5, 5]) * n)
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_broadcast.*)")
    assert len(logs) == 1, len(logs)
def check(data_shape, weights_shape, stride=1, dilation=1):
    N, C, H, W = data_shape
    i, o, h, w = weights_shape
    img = np.random.rand(N, C, H, W).astype("float32")
    weights = np.random.rand(i, o, h, w).astype("float32")
    m1 = jt.nn.ConvTranspose(i, o, h, stride=stride, dilation=dilation, bias=False)
    m2 = torch.nn.ConvTranspose2d(i, o, h, stride=stride, dilation=dilation, bias=False)
    m1.weight.data = weights
    m2.weight.data = torch.Tensor(weights)
    x = jt.array(img)
    # out1 = m1(x)
    out1 = jt.nn.conv_transpose2d(x, m1.weight, stride=stride, dilation=dilation, bias=False)
    mask = jt.random(out1.shape)
    out1 = out1 * mask
    tx = torch.Tensor(img)
    tx.requires_grad = True
    out2 = m2(tx) * torch.Tensor(mask.data)
    with jt.log_capture_scope(
        log_silent=1,
        log_vprefix="var_re=0,conv=0,op.cc=100"
    ) as logs:
        assert np.allclose(out1.data, out2.data)
        dx, dw = jt.grad(out1, [x, m1.weight])
        jt.sync([dx, dw])
        out2.sum().backward()
        assert np.allclose(dw.data, m2.weight.grad.numpy(), 1e-3)
        assert np.allclose(dx.data, tx.grad.numpy())
    assert len(find_log_with_re(logs, "conv")) == 3
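# A hedged usage sketch, not from the original file: given ConvTranspose(i, o, h),
# the weight shape is (in_channels, out_channels, kh, kw) and the input's channel
# count must equal in_channels. The values below are illustrative assumptions and
# `example_conv_transpose_check` is a hypothetical wrapper.
def example_conv_transpose_check():
    # 2 images with 3 channels; transpose conv to 5 channels with a 3x3 kernel
    check((2, 3, 10, 10), (3, 5, 3, 3), stride=2, dilation=1)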
def test_resnet(self):
    self.setup_seed(1)
    loss_list = []
    acc_list = []
    mnist_net = MnistNet()
    global prev
    prev = time.time()
    SGD = nn.SGD(mnist_net.parameters(), self.learning_rate,
                 self.momentum, self.weight_decay)
    self.train_loader.endless = True

    for data, target in self.train_loader:
        batch_id = self.train_loader.batch_id
        epoch_id = self.train_loader.epoch_id

        # train step
        with jt.log_capture_scope(
            log_silent=1,
            log_v=1, log_vprefix="op.cc=100,exe=10",
        ) as logs:
            output = mnist_net(data)
            loss = nn.cross_entropy_loss(output, target)
            SGD.step(loss)

            def callback(epoch_id, batch_id, loss, output, target):
                # print train info
                global prev
                pred = np.argmax(output, axis=1)
                acc = np.mean(target == pred)
                loss_list.append(loss[0])
                acc_list.append(acc)
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
                      .format(epoch_id, batch_id, 600, 1. * batch_id / 6.0,
                              loss[0], acc, time.time() - prev))
                # prev = time.time()
            jt.fetch(epoch_id, batch_id, loss, output, target, callback)

        log_conv = find_log_with_re(
            logs, "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
        log_matmul = find_log_with_re(
            logs, "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
        if batch_id > 2:
            assert len(log_conv) == 59 and len(log_matmul) == 6, (
                len(log_conv), len(log_matmul))

        mem_used = jt.flags.stat_allocator_total_alloc_byte \
            - jt.flags.stat_allocator_total_free_byte
        # assert mem_used < 4e9, mem_used
        # TODO: why bigger?
        assert mem_used < 5.6e9, mem_used

        # example log:
        # Train Epoch: 0 [0/100 (0%)]     Loss: 2.352903  Acc: 0.110000
        # Train Epoch: 0 [1/100 (1%)]     Loss: 2.840830  Acc: 0.080000
        # Train Epoch: 0 [2/100 (2%)]     Loss: 3.473594  Acc: 0.100000
        # Train Epoch: 0 [3/100 (3%)]     Loss: 3.131615  Acc: 0.200000
        # Train Epoch: 0 [4/100 (4%)]     Loss: 2.524094  Acc: 0.230000
        # Train Epoch: 0 [5/100 (5%)]     Loss: 7.780025  Acc: 0.080000
        # Train Epoch: 0 [6/100 (6%)]     Loss: 3.890721  Acc: 0.160000
        # Train Epoch: 0 [7/100 (7%)]     Loss: 6.370137  Acc: 0.140000
        # Train Epoch: 0 [8/100 (8%)]     Loss: 11.390827 Acc: 0.150000
        # Train Epoch: 0 [9/100 (9%)]     Loss: 21.598564 Acc: 0.080000
        # Train Epoch: 0 [10/100 (10%)]   Loss: 23.369165 Acc: 0.130000
        # Train Epoch: 0 [20/100 (20%)]   Loss: 4.804510  Acc: 0.100000
        # Train Epoch: 0 [30/100 (30%)]   Loss: 3.393924  Acc: 0.110000
        # Train Epoch: 0 [40/100 (40%)]   Loss: 2.286762  Acc: 0.130000
        # Train Epoch: 0 [50/100 (50%)]   Loss: 2.055014  Acc: 0.290000

        if jt.in_mpi:
            assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars()
        else:
            assert jt.core.number_of_lived_vars() < 7000, jt.core.number_of_lived_vars()
        if self.train_loader.epoch_id >= 2:
            break

    jt.sync_all(True)
    assert np.mean(loss_list[-50:]) < 0.5
    assert np.mean(acc_list[-50:]) > 0.8
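# `setup_seed` is not shown in this excerpt; it presumably pins every RNG so the
# final loss/accuracy thresholds stay reproducible. A minimal sketch under that
# assumption (the function name and body are illustrative, not the original code):
import random

def setup_seed_sketch(seed):
    random.seed(seed)         # Python RNG
    np.random.seed(seed)      # NumPy RNG used for the metric computations
    jt.set_global_seed(seed)  # Jittor ops such as jt.random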