def check(self, model, std_log):
    x = jt.random([100, 64, 128, 128])
    with jt.log_capture_scope(
        log_silent=1,
        log_v=0,
        log_vprefix="atomic_tuner_pass.cc=100",
    ) as logs:
        y = model(x).numpy()
    with jt.log_capture_scope(
        log_v=0,
        exclude_pass="******",
    ) as logs2:
        y_std = model(x).numpy()
    err = np.max(y_std - y) / (np.mean(y_std) + 1e-6)
    assert err < 1e-5
    log_move = find_log_with_re(logs, "atomictuner: move .* to loop .*")
    assert len(log_move) == len(std_log), (len(log_move), len(std_log))
    for st in log_move:
        sidx = -1
        for j in range(len(std_log)):
            if std_log[j] == st:
                sidx = j
                break
        assert sidx != -1
        std_log[sidx] = "matched"

def test_backward_once_cuda(self):
    with jt.flag_scope(use_cuda=1):
        np.random.seed(0)
        jt.set_seed(3)
        model = Model2()
        n = 1
        batch_size = 50

        def get_data(n):
            for i in range(n):
                x = np.random.rand(batch_size, 1)
                y = x * x
                yield jt.float32(x), jt.float32(y)

        for i, (x, y) in enumerate(get_data(n)):
            pred_y = model(x).name("pred_y")
            with jt.log_capture_scope(log_v=0, log_vprefix="op.cc=100") as logs:
                jt.sync_all()
            logs = find_log_with_re(
                logs, "Jit op key (not )?found: (cublas)_matmul.*")
            assert (len(logs) == 1)
            with jt.log_capture_scope(
                    log_silent=1,
                    log_v=0, log_vprefix="op.cc=100,exe=1000") as logs_b:
                gs = jt.grad(pred_y, x)
                gs2 = jt.grad(pred_y, model.linear1.weight)
                jt.sync_all()
            logs_b = find_log_with_re(
                logs_b, "Jit op key (not )?found: (cublas)_matmul.*")
            assert len(logs_b) == 2, len(logs_b)
        jt.clean()

def check(a):
    perms = list(permutations(range(a.ndim))) + [None]
    for perm in perms:
        with jt.log_capture_scope(log_silent=1,
                                  log_v=0, log_vprefix="op.cc=100") as raw_log:
            if perm:
                x = np.transpose(a, perm)
                y = jt.transpose(a, perm).data
            else:
                x = np.transpose(a)
                y = jt.transpose(a).data
        self.assertEqual(x.shape, y.shape)
        logs = find_log_with_re(
            raw_log, "(Jit op key (not )?found: " + "cutt_transpose" + ".*)")
        if perm is None:
            continue
        last = -1
        in_order = True
        for i in range(len(perm)):
            if a.shape[perm[i]] == 1:
                continue
            if last != -1 and last > perm[i]:
                in_order = False
                break
            last = perm[i]
        if not in_order:
            assert len(logs) == 1
        assert (x == y).all(), f"\n{x}\n{y}"

def check(self, h, w, cs, rs, pa, rtp, dim):
    a = jt.random([h, w])
    a.data
    with jt.log_capture_scope(
            log_v=0,
            log_vprefix="tuner_manager=100",
            # this value is used to force recompilation
            compile_options={"test_reduce_tuner": 1}) as logs:
        amean = jt.mean(a, dims=[dim], keepdims=1)
        a2mean = jt.mean(a * a, dims=[dim], keepdims=1)
        norm_aa = (a - amean.broadcast_var(a)) / (
            jt.sqrt(a2mean - amean * amean).broadcast_var(a))
        norm_aa.data
    logs = find_log_with_re(
        logs,
        "Run tuner reduce: confidence\\((20)\\) candidates\\((.*)\\)$")
    assert len(logs) == 1, logs
    assert logs[0][0] == "20", "confidence of reduce should be 20"
    candidates = simple_parser(logs[0][1])
    assert candidates == {
        "order0": [0],
        "order1": [1],
        "order2": [0],
        "split1": [2048],
    }

def check_backward(xshape, wshape, stride, padding, dilation, groups,
                   use_cuda, nhwc):
    assert nhwc == 0
    test_func = test_nchw  # only check cudnn
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
                              log_v=10,
                              log_vprefix="conv_tuner.cc=1000") as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        dx, dw = jt.grad(y, [x, w])
        jt.sync([y, dx, dw])
    with jt.flag_scope(use_cuda=0, enable_tuner=0,
                       compile_options={"test": 233}):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cdx, cdw = jt.grad(cy, [x, w])
        jt.sync([cy, cdx, cdw])
    assert np.allclose(y.data, cy.data)
    assert np.allclose(dw.data, cdw.data, 1e-3), (
        dw.data, cdw.data, np.abs(dw.data - cdw.data).max())
    assert np.allclose(dx.data, cdx.data, 1e-3), (
        dx.data, cdx.data, np.abs(dx.data - cdx.data).max())

def check_backward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
    if nhwc:
        test_func = test_nhwc
    else:
        test_func = test_nchw
    if use_cuda == 1:
        op_name = "cudnn_conv"
    else:
        op_name = "mkl_conv"
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
                              log_v=1,
                              log_vprefix="op.cc=1000,exe=1000,conv_t=1000",
                              compile_options={"test": 244}) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation)
        loss = y.mean()
        dx, dw = jt.grad(loss, [x, w])
        jt.sync([y, loss, dx, dw])
    with jt.flag_scope(use_cuda=0, enable_tuner=0,
                       compile_options={"test": 233}):
        cy = test_func(x, w, stride, padding, dilation)
        closs = cy.mean()
        cdx, cdw = jt.grad(closs, [x, w])
        jt.sync([cy, closs, cdx, cdw])
    logs = find_log_with_re(raw_log,
                            "(Jit op key (not )?found: " + op_name + ".*)")
    assert len(logs) == 3 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data, 1e-3)
    assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data)
    assert np.allclose(dx.data, cdx.data, 1e-3), (
        dx.data, cdx.data, np.abs(cdx.data).max(),
        np.abs(dx.data - cdx.data).max())

def check(xshape, wshape, stride=1, padding=0, dilation=1):
    with jt.log_capture_scope(
            use_cuda=1, enable_tuner=1,
            log_v=1, log_vprefix="op.cc=100,exe=1000") as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = conv(x, w, stride, padding)
        mask = jt.random(y.shape)
        loss = mask * y
        dx, dw = jt.grad(loss, [x, w])
        jt.sync([y, loss, dx, dw])
    # fails when enable_tuner=1, something wrong with mkl_conv_backward_x maybe.
    with jt.flag_scope(use_cuda=0, enable_tuner=0):
        cy = conv(x, w, stride, padding)
        closs = mask * cy
        cdx, cdw = jt.grad(closs, [x, w])
        jt.sync([cy, closs, cdx, cdw])
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)")
    assert len(logs) == 3 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data)
    assert np.allclose(dx.data, cdx.data, 1e-2)
    assert np.allclose(dw.data, cdw.data, 1e-2)

def check(xshape, wshape, stride, pad):
    a = np.random.rand(*xshape).astype(np.float32)
    b = np.random.rand(*wshape).astype(np.float32)
    c = jt.mkl_ops.mkl_conv(a, b, stride, stride, pad, pad, 1, 1,
                            xformat="acdb", wformat="hwio").data
    a_jt = jt.array(a)
    b_jt = jt.array(b)
    with jt.flag_scope(enable_tuner=0,
                       compile_options={"test_mkl_conv": uid[0]}):
        c_jt = conv_nhwc_hwio(a_jt, b_jt, stride, pad).data
    with jt.log_capture_scope(
            enable_tuner=1,
            compile_options={"test_mkl_conv": uid[0] + 1},
            log_v=0,
            log_vprefix="tuner_manager=100,conv_tuner=1000",
    ) as raw_logs:
        c_jt_tune = conv_nhwc_hwio(a_jt, b_jt, stride, pad).data
    uid[0] += 2
    assert np.max(c_jt - c) < 1e-4 and np.max(c_jt_tune - c) < 1e-4
    logs = find_log_with_re(
        raw_logs,
        "Run tuner conv: confidence\\((.*)\\) candidates\\((.*)\\)$")
    assert len(logs) == 1, raw_logs
    assert logs[0][0] == '20'
    assert simple_parser(logs[0][1]) == {'relay0': [1, 0]}

def check(shape, slices, i_to_vs="", i_to_o="", o_shape=""):
    # print(slices)
    x = jt.random(shape)
    with jt.log_capture_scope(log_vprefix="getitem=999") as logs:
        a = x.getitem(slices)
        a.sync()
    b = x.data[slices]
    bshape = b.shape if len(b.shape) else (1,)
    assert a.shape == bshape, (a.shape, bshape)
    s = logs[-1]['msg']
    assert "i_to_vs: " + i_to_vs in s
    assert "i_to_o: " + i_to_o in s
    assert "o_shape: " + o_shape in s
    aa = a.numpy()
    assert (aa == b).all(), (aa, b)
    y = x.numpy()
    v = jt.random(a.shape)
    z = x.setitem(slices, v)
    y[slices] = v.data
    assert (z.data == y).all(), (z.data, y, v.data, x.data)

    # test_setitem broadcast
    adim = len(a.shape)
    for mask in range(1 << adim):
        new_shape = list(a.shape)
        for i in range(adim):
            if (mask >> i) & 1:
                new_shape[i] = 1
        y = x.numpy()
        v = jt.random(new_shape)
        z = x.setitem(slices, v)
        y[slices] = v.data
        assert (z.data == y).all(), (z.data, y, v.data, x.data)

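# A hypothetical way to drive the getitem/setitem check above (these calls are
# not part of the original file): the slices are illustrative, and the
# i_to_vs / i_to_o / o_shape arguments are left at their defaults because the
# exact layout of the getitem log message is an internal detail.
def _example_getitem_cases():
    check([4, 5], (slice(1, 3),))
    check([6, 7], (slice(None), slice(0, 6, 2)))
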
def check_cub_argsort(shape, dim, descending=False):
    with jt.log_capture_scope(
            log_silent=1,
            log_v=0, log_vprefix="op.cc=100"
    ) as raw_log:
        x = jt.random(shape)
        y, y_key = jt.argsort(x, dim=dim, descending=descending)
        v = []
        for i in range(len(shape)):
            if i == dim:
                v.append(y)
            else:
                v.append(jt.index(shape, dim=i))
        yk = jt.reindex(x, v)
        yk_ = yk.data
        y_key_ = y_key.data
    logs = find_log_with_re(
        raw_log, "(Jit op key (not )?found: " + "cub_argsort" + ".*)")
    assert len(logs) == 1
    x__ = x.data
    if descending:
        x__ = -x__
    yk__ = np.sort(x__, axis=dim)
    if descending:
        yk__ = -yk__
    assert np.allclose(y_key_, yk__)
    assert np.allclose(yk_, yk__)

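# Hypothetical call sites for check_cub_argsort (not from the original file);
# they assume a CUDA build in which jt.argsort dispatches to the cub_argsort
# kernel, which is what the log assertion inside the check relies on.
def _example_argsort_cases():
    with jt.flag_scope(use_cuda=1):
        check_cub_argsort([5, 7], 0)
        check_cub_argsort([5, 7], 1, descending=True)
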
def test_forward(self):
    a = np.random.rand(1, 3, 224, 224).astype(np.float32)
    b = np.random.rand(64, 3, 7, 7).astype(np.float32)
    c = jt.mkl_ops.mkl_conv(a, b, 2, 2, 3, 3).data
    a_jt = jt.array(a)
    b_jt = jt.array(b)
    with jt.flag_scope(enable_tuner=0,
                       compile_options={"test_mkl_conv": 1}):
        c_jt = conv(a_jt, b_jt, 3, 2).data
    with jt.log_capture_scope(
            enable_tuner=1,
            compile_options={"test_mkl_conv": 2},
            log_v=0,
            log_vprefix="tuner_manager=100,conv_tuner=1000",
    ) as raw_logs:
        c_jt_tune = conv(a_jt, b_jt, 3, 2).data
    assert np.max(c_jt - c) < 1e-4 and np.max(c_jt_tune - c) < 1e-4
    logs = find_log_with_re(
        raw_logs,
        "Run tuner conv: confidence\\((.*)\\) candidates\\((.*)\\)$")
    assert len(logs) == 1
    assert logs[0][0] == '20'
    assert simple_parser(logs[0][1]) == {'relay0': [1, 0]}

def test(self):
    a = jt.ones((8, 8, 8))
    a.data
    with jt.log_capture_scope(log_v=0, log_vprefix="tuner_manager=100") as logs:
        b = a + a
        b.data
    logs = find_log_with_re(
        logs,
        "Run tuner reorder: confidence\\((.*)\\) candidates\\((.*)\\)$")
    assert len(logs) == 1
    assert logs[0][0] == "1", "confidence of reorder should be 1"
    candidates = simple_parser(logs[0][1])
    assert candidates == {
        "order0": [0],
        "order1": [0, 1],
        "order2": [0, 1, 2],
    }

def test_matmul2(s1, s2, t1, t2, dtype='float32'):
    if (not t1) and (not t2):
        tp = 0
    if t1 and (not t2):
        tp = 1
    if (not t1) and t2:
        tp = 2
    if dtype.startswith('float'):
        a = jt.random(s1, dtype=dtype)
        b = jt.random(s2, dtype=dtype)
    else:
        a = jt.random(s1)
        b = jt.random(s2)
        a = (a * 2000 - 1000).cast(dtype)
        b = (b * 2000 - 1000).cast(dtype)
    c = matmul2(a, b, tp)
    if t1:
        a_ = a.data.transpose()
    else:
        a_ = a.data
    if t2:
        b_ = b.data.transpose()
    else:
        b_ = b.data
    c_ = np.matmul(a_, b_)
    with jt.log_capture_scope(log_v=0, log_vprefix="op.cc=100") as logs:
        c__ = c.data
    assert np.allclose(c_, c__)
    logs = find_log_with_re(
        logs, "Jit op key (not )?found: (mkl)|(cublas)_matmul.*")
    if dtype.startswith('float'):
        if jt.flags.use_cuda or dtype == 'float32':
            assert (len(logs) == 1)

def test_searcher(self):
    a = jt.ones((80, 80, 80))
    a.data
    global gid
    gid += 1
    with jt.log_capture_scope(log_v=0,
                              log_vprefix="jit_searcher=1000",
                              jit_search_kernel=1,
                              compile_options={
                                  "compile_shape": 1,
                                  "test_reorder_tuner": gid
                              }) as logs:
        b = a + a
        b.data
    ls = find_log_with_re(logs, "Choices")
    assert len(ls) == 6, (ls, logs)
    ls = find_log_with_re(logs, "Best choices\\(.*\\): (.*)$")
    assert len(ls) == 1
    best = simple_parser(ls[0])
    assert best == {
        "compile_shape": 1,
        "order0": 0,
        "order1": 0,
        "order2": 0,
        "test_reorder_tuner": gid,
    }

def check_forward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
    if nhwc:
        test_func = test_nhwc
    else:
        test_func = test_nchw
    if use_cuda == 1:
        op_name = "cudnn_conv"
    else:
        op_name = "mkl_conv"
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
                              log_v=0,
                              log_vprefix="op.cc=100,conv_tuner=1000",
                              compile_options={"test": 266}) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation)
        y.sync()
    with jt.flag_scope(use_cuda=0, enable_tuner=0,
                       compile_options={"test": 255}):
        cy = test_func(x, w, stride, padding, dilation)
        cy.sync()
    logs = find_log_with_re(raw_log,
                            "(Jit op key (not )?found: " + op_name + ".*)")
    assert len(logs) == 1 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data)

def test_float64(self):
    jt.set_seed(3)
    with jt.log_capture_scope(log_silent=1,
                              log_v=0, log_vprefix="op.cc=100") as raw_log:
        t = jt.random([5, 5], dtype='float64')
        t.data
    logs = find_log_with_re(
        raw_log, "(Jit op key (not )?found: " + "curand_random" + ".*)")
    assert len(logs) == 1

def test_matmul(s1, s2):
    a = jt.random(s1)
    b = jt.random(s2)
    c = jt.nn.matmul(a, b)
    c_ = np.matmul(a.data, b.data)
    with jt.log_capture_scope(log_v=0, log_vprefix="op.cc=100") as logs:
        c__ = c.data
    assert np.allclose(c_, c__)
    logs = find_log_with_re(
        logs, "Jit op key (not )?found: (mkl)|(cublas)_matmul.*")
    assert (len(logs) == 1)

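# A hypothetical driver for test_matmul (not part of the original suite). The
# shape pairs are illustrative; the log assertion inside test_matmul assumes a
# build where 2-D matmul is relayed to an mkl_matmul or cublas_matmul kernel.
def _example_matmul_cases():
    for s1, s2 in ([2, 3], [3, 4]), ([10, 10], [10, 1]):
        test_matmul(s1, s2)
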
def check(self, model, std_log):
    x = jt.random([100, 64, 128, 128])
    with jt.log_capture_scope(
        # log_silent=1,
        log_v=0,
        log_vprefix="atomic=100,data=100",
    ) as logs:
        y = model(x).numpy()
    with jt.log_capture_scope(
            log_v=0,
            exclude_pass="******",
            # new options to force recompile
            compile_options={"test_atomic_tuner": 1}) as logs2:
        y_std = model(x).numpy()
    err = np.max(y_std - y) / (np.mean(y_std) + 1e-6)
    assert err < 1e-5, err
    log_move = find_log_with_re(logs, "atomictuner: move .* to loop .*")
    assert len(log_move) == len(std_log), (len(log_move), len(std_log))
    assert sorted(log_move) == sorted(std_log)

def test_with_split(self):
    a = jt.ones((8, 8, 8))
    a.data
    global gid
    gid += 1
    with jt.log_capture_scope(log_v=0, log_vprefix="tuner_manager=100",
                              compile_options={
                                  "split0": 4,
                                  "split1": 4,
                                  "split2": 4,
                                  "test_reorder_tuner": gid
                              }) as logs:
        b = a + a
        b.data
    logs = find_log_with_re(
        logs,
        "Run tuner reorder: confidence\\((.*)\\) candidates\\((.*)\\)$")
    assert len(logs) == 1
    assert logs[0][0] == "1", "confidence of reorder should be 1"
    candidates = simple_parser(logs[0][1])
    assert candidates == {
        "order0": [0],
        "order1": [0, 1],
        "order2": [0, 1, 2],
        "order3": [0, 1, 2],
        "order4": [0, 1, 2],
        "order5": [0, 1, 2],
    }, candidates

def test_all_reduce(self):
    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
                              log_v=1,
                              log_vprefix="op.cc=100,exe=1000") as raw_log:
        x = jt.random([5, 5])
        y = x.mpi_all_reduce()
        assert np.allclose(y.data, (x * n).data)
        g = jt.grad(y, x)
        assert np.allclose(g.data, np.ones([5, 5]) * n)
    logs = find_log_with_re(
        raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)")
    assert len(logs) == 2, len(logs)

def check(xshape, wshape, stride=1, padding=0, dilation=1):
    with jt.log_capture_scope(use_cuda=1, enable_tuner=1,
                              log_v=0, log_vprefix="op.cc=100") as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = conv_oihw(x, w, stride, padding, dilation)
        y.sync()
    with jt.flag_scope(use_cuda=0, enable_tuner=1):
        cy = conv_oihw(x, w, stride, padding, dilation)
        cy.sync()
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)")
    assert len(logs) == 1 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data), np.abs(y.data - cy.data).max()

def test_vgg(self):
    self.setup_seed(1)
    loss_list = []
    acc_list = []
    mnist_net = MnistNet()
    SGD = nn.SGD(mnist_net.parameters(), self.learning_rate,
                 self.momentum, self.weight_decay)
    for batch_idx, (data, target) in enumerate(self.train_loader):
        output = mnist_net(data)
        loss = nn.cross_entropy_loss(output, target)
        # train step
        with jt.log_capture_scope(
            log_silent=1,
            log_v=1,
            log_vprefix="op.cc=100,exe=10",
        ) as logs:
            SGD.step(loss)

            def callback(loss, output, target, batch_idx):
                # print train info
                pred = np.argmax(output, axis=1)
                acc = np.sum(target == pred) / self.batch_size
                loss_list.append(loss[0])
                acc_list.append(acc)
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}'
                      .format(0, batch_idx, 100, 1. * batch_idx, loss[0], acc))

            jt.fetch(batch_idx, loss, output, target, callback)
        log_conv = find_log_with_re(
            logs, "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
        log_matmul = find_log_with_re(
            logs, "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
        if batch_idx:
            assert len(log_conv) == 38 and len(log_matmul) == 12, (
                len(log_conv), len(log_matmul))
        mem_used = (jt.flags.stat_allocator_total_alloc_byte
                    - jt.flags.stat_allocator_total_free_byte)
        assert mem_used < 11e9, mem_used
        assert jt.core.number_of_lived_vars() < 3500
        if np.mean(loss_list[-50:]) < 0.2:
            break
    assert np.mean(loss_list[-50:]) < 0.2

def test_reduce(self):
    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
                              log_v=1,
                              log_vprefix="op.cc=100,exe=1000") as raw_log:
        x = jt.random([5, 5])
        y = x.mpi_reduce(root=0)
        y_ = y.data
        x_ = (x * n).data
        if mpi.world_rank() == 0:
            assert np.allclose(y_, x_)
        g = jt.grad(y, x)
        assert np.allclose(g.data, np.ones([5, 5]))
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_reduce.*)")
    assert len(logs) == 1, len(logs)

def test(shape, op1, op2):
    n = 753.1
    a = jt.random(shape)
    b = jt.random(shape)
    c = op1(a, n)
    d = op2(c, b)
    with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs:
        d__ = d.data
    logs = find_log_with_re(
        logs, "Jit (fused )?op key (not )?found: \\[opkey0:array\\[T:float32")
    assert (len(logs) == 1), logs
    a_ = a.data
    b_ = b.data
    d_ = op2(op1(a_, n), b_)
    assert (np.allclose(d_, d__, atol=1e-4))

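# A hypothetical invocation of the fused-op test above (not from the original
# file). operator.add and operator.mul accept both jt.Var and numpy operands,
# so they can serve as op1/op2; the shape is illustrative, and the opkey
# assertion inside test assumes the scalar constant is materialized as a
# float32 array op.
def _example_fused_case():
    import operator
    test([16, 16], operator.add, operator.mul)
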
def test_backward_nhwc_hwio(self):
    n, c, H, W = 2, 3, 5, 5
    o, i, h, w = 4, c, 3, 3
    a = np.random.rand(n, H, W, c).astype(np.float32)
    b = np.random.rand(h, w, i, o).astype(np.float32)
    da = np.random.rand(n, H, W, o).astype(np.float32)
    dx = jt.mkl_ops.mkl_conv_backward_x(
        b, da, H, W, 1, 1, 1, xformat="acdb", wformat="hwio", yformat="acdb").data
    dw = jt.mkl_ops.mkl_conv_backward_w(
        a, da, h, w, 1, 1, 1, xformat="acdb", wformat="hwio", yformat="acdb").data
    a_jt = jt.array(a)
    b_jt = jt.array(b)
    with jt.flag_scope(enable_tuner=0,
                       compile_options={"test_mkl_conv": 1}):
        c_jt = conv_nhwc_hwio(a_jt, b_jt, 1, 1) * da
        gs = jt.grad(c_jt, [a_jt, b_jt])
        gs.append(c_jt)
        jt.fetch_sync(gs)
        dx_jt = gs[0].data
        dw_jt = gs[1].data
    with jt.log_capture_scope(
            log_v=10,
            log_vprefix="tuner_manager=100,var_relay=100",
            enable_tuner=1,
            compile_options={"test_mkl_conv": 2}) as rawlogs:
        gs_tune = jt.grad(c_jt, [a_jt, b_jt])
        jt.fetch_sync(gs_tune)
        dx_jt_tune = gs_tune[0].data
        dw_jt_tune = gs_tune[1].data
    logs = find_log_with_re(
        rawlogs,
        "Run tuner conv: confidence\\((20)\\) candidates\\((.*)\\)$")
    assert len(logs) == 2
    assert logs[0][0] == "20", "confidence of conv tuner should be 20"
    candidates = simple_parser(logs[0][1])
    assert candidates == {"relay0": [1, 0]}, candidates
    # assert candidates == {"relay0":[1,0],"relay1":[1,0]}, candidates
    logs = find_log_with_re(rawlogs, r"get_relay_src([\s\S]*)")
    assert len(logs) == 2
    assert "@relay_op" in logs[0]
    assert "@relay_op" in logs[1]
    assert np.max(dx_jt_tune - dx) < 1e-5 and np.max(dw_jt_tune - dw) < 1e-5
    assert np.max(dx_jt - dx) < 1e-5 and np.max(dw_jt - dw) < 1e-5

def check(self, h, w, cs, rs, pa, rtp, dim):
    a = jt.random([h, w])
    a.sync()
    with jt.log_capture_scope(
            log_v=0,
            log_vprefix="tuner_manager=100",
            # this value is used to force recompilation
            compile_options={"test_new_fused_op": 1}) as logs:
        amean = jt.mean(a, dims=[dim], keepdims=1)
        a2mean = jt.mean(a * a, dims=[dim], keepdims=1)
        norm_aa = (a - amean.broadcast_var(a)) / (
            jt.sqrt(a2mean - amean * amean).broadcast_var(a))
        norm_aa.sync()
    logs = find_log_with_re(
        logs,
        "Run tuner reduce: confidence\\((.*)\\) candidates\\((.*)\\)$")
    assert len(logs) == 3, logs

def test_broadcast(self):
    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
                              log_v=1,
                              log_vprefix="op.cc=100,exe=1000") as raw_log:
        data = jt.random([5, 5])
        if mpi.world_rank() == 0:
            x = data
        else:
            x = jt.zeros([5, 5])
        y = x.mpi_broadcast(0)
        assert np.allclose(y.data, data.data)
        g = jt.grad(y.sum(), x)
        g_ = g.data
        if mpi.world_rank() == 0:
            assert np.allclose(g_, np.ones([5, 5]) * n)
    logs = find_log_with_re(raw_log,
                            "(Jit op key (not )?found: nccl_broadcast.*)")
    assert len(logs) == 1, len(logs)

def check_forward(xshape, wshape, stride, padding, dilation, groups,
                  use_cuda, nhwc):
    assert nhwc == 0
    test_func = test_nchw  # only check cudnn
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
                              log_v=10,
                              log_vprefix="op.cc=100,conv_tuner=1000") as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        y.sync()
    with jt.flag_scope(use_cuda=0, enable_tuner=0):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cy.sync()
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)")
    assert len(logs) == 1
    assert np.allclose(y.data, cy.data)

def check_forward(xshape, wshape, stride, padding, dilation, groups,
                  use_cuda, nhwc):
    assert nhwc == 0
    test_func = test_nchw  # only check cudnn
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
                              log_v=10,
                              log_vprefix="conv_tuner.cc=1000") as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        y.sync()
    with jt.flag_scope(use_cuda=0, enable_tuner=0):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cy.sync()
    assert np.allclose(y.data, cy.data)

def check(data_shape, weights_shape, stride=1, dilation=1):
    N, C, H, W = data_shape
    i, o, h, w = weights_shape
    img = np.random.rand(N, C, H, W).astype("float32")
    weights = np.random.rand(i, o, h, w).astype("float32")
    m1 = jt.nn.ConvTranspose(i, o, h, stride=stride, dilation=dilation, bias=False)
    m2 = torch.nn.ConvTranspose2d(i, o, h, stride=stride, dilation=dilation, bias=False)
    m1.weight.data = weights
    m2.weight.data = torch.Tensor(weights)
    x = jt.array(img)
    # out1 = m1(x)
    out1 = jt.nn.conv_transpose2d(x, m1.weight, stride=stride,
                                  dilation=dilation, bias=False)
    mask = jt.random(out1.shape)
    out1 = out1 * mask
    tx = torch.Tensor(img)
    tx.requires_grad = True
    out2 = m2(tx) * torch.Tensor(mask.data)
    with jt.log_capture_scope(
            log_silent=1,
            log_vprefix="var_re=0,conv=0,op.cc=100") as logs:
        assert np.allclose(out1.data, out2.data)
        dx, dw = jt.grad(out1, [x, m1.weight])
        jt.sync([dx, dw])
        out2.sum().backward()
        assert np.allclose(dw.data, m2.weight.grad.numpy(), 1e-3)
        assert np.allclose(dx.data, tx.grad.numpy())
    assert len(find_log_with_re(logs, "conv")) == 3