def check():
    a = jt.random((5, 1))
    b = jt.numpy_code(
        a.shape,
        a.dtype,
        [a],
        forward_code,
        [backward_code],
    )
    assert numpy.allclose(b.data, (a + a).data)
    da = jt.grad(b, a)
    one = numpy.ones(a.shape)
    assert numpy.allclose(da.data, one * 2.0)

def test_batchnorm_backward(self):
    mpi = jt.compile_extern.mpi
    data = np.random.rand(30, 3, 10, 10).astype("float32")
    global_x = jt.array(data)
    x = jt.array(data[mpi.world_rank() * 10:(mpi.world_rank() + 1) * 10, ...])
    bn1 = nn.BatchNorm(3, sync=True)
    bn2 = FakeMpiBatchNorm(3)
    y1 = bn1(x)
    y2 = bn2(x, global_x)
    gs1 = jt.grad(y1, bn1.parameters())
    gs2 = jt.grad(y2, bn2.parameters())
    assert np.allclose(y1.data, y2.data, atol=1e-5), \
        (mpi.world_rank(), y1.data, y2.data, y1.data - y2.data)
    for i in range(len(gs1)):
        assert np.allclose(gs1[i].data, gs2[i].data, rtol=1e-3), \
            (mpi.world_rank(), gs1[i].data, gs2[i].data, gs1[i].data - gs2[i].data)

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = jt.array(
        np.random.random((real_samples.size(0), 1, 1, 1)).astype(np.float32))
    # Get random interpolation between real and fake samples
    interpolates = alpha * real_samples + ((1 - alpha) * fake_samples)
    d_interpolates, _ = D(interpolates)
    # Get gradient w.r.t. interpolates
    gradients = jt.grad(d_interpolates, interpolates)
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

def test1(self):
    class MyFunc(Function):
        def execute(self, x):
            return x + 1

        def grad(self, grad):
            return grad - 2

    a = jt.ones(1)
    func = MyFunc()
    b = func(a)
    da = jt.grad(b, a)
    assert da.data == -1

def test_broadcast(self):
    data = jt.random([5, 5])
    if mpi.world_rank() == 0:
        x = data
    else:
        x = jt.zeros([5, 5])
    y = x.mpi_broadcast(0)
    assert np.allclose(y.data, data.data)
    g = jt.grad(y, x)
    if mpi.world_rank() == 0:
        assert np.allclose(g.data, np.ones([5, 5]) * n)
    else:
        assert np.allclose(g.data, np.zeros([5, 5]))

def test2(self):
    class MyFunc(Function):
        def execute(self, x):
            self.x = x
            return x + 1

        def grad(self, grad):
            return (grad - 2) * self.x

    a = jt.ones(1) * 10
    func = MyFunc()
    b = func(a)
    da = jt.grad(b, a)
    assert da.data == -10

def check(xshape, wshape, stride=1, padding=0, dilation=1):
    with jt.log_capture_scope(use_cuda=1, enable_tuner=1,
                              log_v=0, log_vprefix="op.cc=100"
                              ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = conv_oihw(x, w, stride, padding, dilation)
        mask = jt.random(y.shape)
        loss = mask * y
        dx, dw = jt.grad(loss, [x, w])
        jt.sync([y, loss, dx, dw])
    with jt.flag_scope(use_cuda=0, enable_tuner=0):
        cy = conv_oihw(x, w, stride, padding, dilation)
        closs = mask * cy
        cdx, cdw = jt.grad(closs, [x, w])
        jt.sync([cy, closs, cdx, cdw])
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)")
    assert len(logs) == 3 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data)
    assert np.allclose(dx.data, cdx.data)
    assert np.allclose(dw.data, cdw.data)

def test_cumprod_cpu(self):
    for i in range(1, 6):
        for j in range(i):
            x = np.random.rand(*((10,) * i))
            x_jt = jt.array(x)
            y_jt = jt.cumprod(x_jt, j).sqr()
            g_jt = jt.grad(y_jt.sum(), x_jt)
            x_tc = Variable(torch.from_numpy(x), requires_grad=True)
            y_tc = torch.cumprod(x_tc, j) ** 2
            y_tc.sum().backward()
            g_tc = x_tc.grad
            assert np.allclose(y_jt.numpy(), y_tc.data)
            assert np.allclose(g_jt.numpy(), g_tc.data)

def adam(model, loss, lr=3e-4, betas=[0.9, 0.999], eps=1e-8):
    ps = jt.find_vars(model)
    gs = jt.grad(loss, ps)
    with jt.var_scope('_'.join([model, 'adam']), unique=True):
        adam_step = jt.make_var([1], init=jt.zeros)
        adam_step += 1
        for p, g in zip(ps, gs):
            m = jt.make_var(p.shape, init=jt.zeros)
            v = jt.make_var(p.shape, init=jt.zeros)
            m.assign(betas[0] * m + (1 - betas[0]) * g)
            v.assign(betas[1] * v + (1 - betas[1]) * g * g)
            step_size = lr * jt.sqrt(1 - betas[1] ** adam_step) / (1 - betas[0] ** adam_step)
            p -= m * step_size / (jt.sqrt(v) + eps)

def test_all_reduce(self):
    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
                              log_v=1, log_vprefix="op.cc=100,exe=1000") as raw_log:
        x = jt.random([5, 5])
        y = x.mpi_all_reduce()
        assert np.allclose(y.data, (x * n).data)
        g = jt.grad(y, x)
        assert np.allclose(g.data, np.ones([5, 5]) * n)
    logs = find_log_with_re(
        raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)")
    assert len(logs) == 2, len(logs)

def test_grad_not_match_error(self):
    class MyFunc(Function):
        def execute(self, x, y):
            self.x = x
            self.y = y
            return x * y

        def grad(self, grad):
            return (grad - 2) * self.x

    a = jt.array(3.0)
    b = jt.array(4.0)
    func = MyFunc()
    c = func(a, b)
    expect_error(lambda: jt.grad(c, [a, b]))

def pre_step(self, loss):
    """Things that should be done before each step, such as calculating
    gradients and MPI synchronization.

    Example::

        class MyOptimizer(Optimizer):
            def step(self, loss):
                self.pre_step(loss)
                ...
    """
    # collect parameters; gradients are only computed for non-stop-grad params
    params = []
    params_has_grad = []
    for pg in self.param_groups:
        for p in pg['params']:
            params.append(p)
            if not p.is_stop_grad():
                params_has_grad.append(p)

    # get gradient
    grads = jt.grad(loss, params_has_grad)

    # sync grads and model if in mpi
    if jt.in_mpi:
        for g in grads:
            g.assign(g.mpi_all_reduce("mean"))
        if self.n_step % self.param_sync_iter == 0:
            for p in params:
                p.assign(p.mpi_broadcast())
    self.n_step += 1

    # set up grads in param_groups
    pid = 0
    for pg in self.param_groups:
        if "grads" not in pg:
            pg["grads"] = [
                jt.zeros_like(p).stop_grad().stop_fuse()
                for p in pg['params']
            ]
        pg_grads = pg["grads"]
        for i, p in enumerate(pg['params']):
            if not p.is_stop_grad():
                # accumulate grad and stop grad of grad
                g = grads[pid].stop_grad()
                if not self.__zero_grad:
                    g = g + pg_grads[i]
                pg_grads[i].update(g)
                pid += 1
    self.__zero_grad = False

def test_pinv(self):
    def check_pinv(a):
        w = anp.linalg.pinv(a)
        return w

    for i in range(50):
        x = jt.random((2, 2, 4, 3))
        c_a = x.data
        mx = jt.linalg.pinv(x)
        tx = check_pinv(c_a)
        np.allclose(mx.data, tx)
        jx = jt.grad(mx, x)
        check_grad = jacobian(check_pinv)
        gx = check_grad(c_a)
        np.allclose(gx, jx.data)

def test_matmul_grad(self):
    np.random.seed(0)
    for i in range(10):
        a = np.random.rand(2, 3).astype("float32")
        b = np.random.rand(3, 4).astype("float32")
        out, (da, db) = ngrad(lambda vars: np.matmul(vars[0], vars[1]).sum(), [a, b], 1e-1)
        ja = jt.array(a)
        jb = jt.array(b)
        jc = ja.matmul(jb)
        jda, jdb = jt.grad(jc, [ja, jb])
        assert (np.abs(da - jda.data) < 1e-5).all(), (da, jda.data, da - jda.data)
        assert (np.abs(db - jdb.data) < 1e-5).all(), (db - jdb.data)

def check(xshape, wshape, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), group=1):
    with jt.flag_scope(use_cuda=1):
        x = jt.random(xshape)
        w = jt.random(wshape)
        y2 = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation)
    with jt.flag_scope(use_cuda=1):
        # y = jt.cudnn.ops.cudnn_conv3d_backward_x(w, x, *y2.shape[2:], *stride, *padding, *dilation, group)
        y = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation)
        masky = jt.rand_like(y)
        dx, dw = jt.grad(masky * y, [x, w])
        dx2, dw2 = jt.grad(masky * y2, [x, w])
    np.testing.assert_allclose(y.data, y2.data, rtol=1e-6, atol=1e-4)
    np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-6, atol=1e-4)
    np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)

def check_backward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
    if nhwc:
        test_func = test_nhwc
    else:
        test_func = test_nchw
    if use_cuda == 1:
        op_name = "cudnn_conv"
    else:
        op_name = "mkl_conv"
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1, log_v=1,
                              log_vprefix="op.cc=100,exe=1000,conv_t=1000",
                              compile_options={"test": 244}) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation)
        loss = y.mean()
        dx, dw = jt.grad(loss, [x, w])
        jt.sync([y, loss, dx, dw])
    with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test": 233}):
        cy = test_func(x, w, stride, padding, dilation)
        closs = cy.mean()
        cdx, cdw = jt.grad(closs, [x, w])
        jt.sync([cy, closs, cdx, cdw])
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + op_name + ".*)")
    assert len(logs) == 3 and "oihw" in logs[0][0], (logs)
    assert np.allclose(y.data, cy.data, 1e-3)
    assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data)
    assert np.allclose(dx.data, cdx.data, 1e-3), \
        (dx.data, cdx.data, np.abs(cdx.data).max(), np.abs(dx.data - cdx.data).max())

def check_backward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc):
    assert nhwc == 0
    test_func = test_nchw
    # only check cudnn
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
                              log_v=10, log_vprefix="op.cc=100,conv_tuner=1000"
                              ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        y.sync()
        dx, dw = jt.grad(y, [x, w])
        jt.sync([y, dx, dw])
    with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test": 233}):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cdx, cdw = jt.grad(cy, [x, w])
        jt.sync([cy, cdx, cdw])
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)")
    assert len(logs) == 3
    assert np.allclose(y.data, cy.data)
    assert np.allclose(dw.data, cdw.data, 1e-3), \
        (dw.data, cdw.data, np.abs(dw.data - cdw.data).max())
    assert np.allclose(dx.data, cdx.data, 1e-3), \
        (dx.data, cdx.data, np.abs(dx.data - cdx.data).max())

def test_longest_dis_fuse(self):
    x = jt.array(np.random.rand(1, 3, 224, 224).astype(np.float32))
    loss = jt.sum(resnet_fake(x))
    ps = jt.find_vars('resnet_fake')
    gs = jt.grad(loss, ps)
    jt.sync(gs)
    # assert not alloc big tensor
    g = jt.dump_all_graphs()
    for s in g.nodes_info:
        if not s.startswith("Var"):
            continue
        shape = s.split("[")[1].split("]")[0].split(",")
        ptr = s.split("(")[1].split(")")[0].split(",")[-1]
        if ptr != '0':
            assert len(shape) <= 5, s

def test_reduce(self):
    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
                              log_v=1, log_vprefix="op.cc=100,exe=1000") as raw_log:
        x = jt.random([5, 5])
        y = x.mpi_reduce(root=0)
        y_ = y.data
        x_ = (x * n).data
        if mpi.world_rank() == 0:
            assert np.allclose(y_, x_)
        g = jt.grad(y, x)
        assert np.allclose(g.data, np.ones([5, 5]))
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_reduce.*)")
    assert len(logs) == 1, len(logs)

def test_multi_grads_none(self):
    class MyFunc(Function):
        def execute(self, x, y):
            self.x = x
            self.y = y
            return x * y

        def grad(self, grad):
            return (grad - 2) * self.y, None

    a = jt.array(3.0)
    b = jt.array(4.0)
    func = MyFunc()
    c = func(a, b)
    da, db = jt.grad(c, [a, b])
    assert da.data == -4
    assert db.data == 0

def test_multi_grads_multi_out(self):
    class MyFunc(Function):
        def execute(self, x, y):
            self.x = x
            self.y = y
            return x * y, x / y

        def grad(self, grad0, grad1):
            return grad0 * self.y, grad1 * self.x

    a = jt.array(3.0)
    b = jt.array(4.0)
    func = MyFunc()
    c, d = func(a, b)
    da, db = jt.grad(c + d * 3, [a, b])
    assert da.data == 4
    assert db.data == 9

def step(self, loss):
    ps = self.parameters
    gs = jt.grad(loss, ps)
    for p, g, v in zip(ps, gs, self.values):
        dp = p * self.weight_decay + g
        v.assign(self.momentum * v + dp * (1 - self.dampening))
        if self.nesterov:
            p -= (dp + self.momentum * v) * self.lr
        else:
            p -= v * self.lr
        # detach with the prev graph to reduce memory consumption
        p.detach_inplace()
    # sync all no grad parameters, such as
    # moving_mean and moving_var in batch_norm
    # sync such parameters to reduce memory consumption
    jt.sync(self.no_grad_parameters)

def test_erfinv(self):
    from scipy import special
    y = np.linspace(-1.0, 1.0, num=10)
    x = special.erfinv(y)
    y2 = jt.array(y)
    x2 = jt.erfinv(y2)
    np.testing.assert_allclose(x, x2.data, atol=1e-6, rtol=1e-6)

    y = np.linspace(-0.9, 0.9, num=10)
    x = special.erfinv(y)
    y2 = jt.array(y)
    x2 = jt.erfinv(y2)
    np.testing.assert_allclose(x, x2.data, atol=1e-6, rtol=1e-6)
    d = jt.grad(x2, y2)
    _, (dn,) = ngrad(lambda y: special.erfinv(y).sum(), [y], 1e-8)
    np.testing.assert_allclose(d.data, dn, atol=1e-6, rtol=1e-6)

def test_grad(self): ops = ["abs", "negative", "log", "exp", "sqrt"] a = [1.1, 2.2, 3.3, 4.4] for op in ops: if op == "abs": b = np.array(a + [ -1, ]) else: b = np.array(a) func = lambda x: eval(f"np.{op}(x[0]).sum()") x, (da, ) = ngrad(func, [b], 1e-8) ja = jt.array(b) jb = eval(f"jt.{op}(ja)") jda = jt.grad(jb, ja) assert (np.abs(jda.data - da) < 1e-5).all(), (jda.data, da, op)
def test():
    class MyFunc(Function):
        def execute(self, x, z, y):
            self.x = x
            self.y = y
            return x * y, "test", x / y

        def grad(self, grad0, _, grad1):
            assert _ is None
            res = (grad0 * self.y, None, grad1 * self.x)
            return res

    a = jt.array(3.0)
    b = jt.array(4.0)
    c, _, d = MyFunc()(a, "a", b)
    g = jt.grad(c + d * 3, [a, b])
    jt.sync(g)

def test():
    class MyFunc(Function):
        def execute(self, x, z, y):
            self.x = x.name("x")
            self.y = y.name("y")
            return x * y, "test", x / y

        def grad(self, grad0, _, grad1):
            assert _ is None
            res = (grad0 * self.y, None, grad1 * self.x)
            return res

    a = jt.array(3.0).name('a')
    b = jt.array(4.0).name('b')
    c, _, d = MyFunc()(a, "a", b)
    c.name('c'), d.name('d')
    g = jt.grad(c + d * 3, [a, b])

def test_cuda(self):
    a = jt.random([100000])
    b = jt.random([100000])
    c = jt.code(a.shape, a.dtype, [a, b],
        cuda_header='''
        namespace jittor {
        __global__ static void kernel1(@ARGS_DEF) {
            @PRECALC
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            int stride = blockDim.x * gridDim.x;
            for (; i<in0shape0; i+=stride)
                @out(i) = @in0(i)*@in1(i);
        }

        __global__ static void kernel2(@ARGS_DEF) {
            @PRECALC
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            int stride = blockDim.x * gridDim.x;
            for (; i<in0shape0; i+=stride)
                @out(i) = @dout(i)*@in1(i);
        }

        __global__ static void kernel3(@ARGS_DEF) {
            @PRECALC
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            int stride = blockDim.x * gridDim.x;
            for (; i<in0shape0; i+=stride)
                @out(i) = @dout(i)*@in0(i);
        }
        }
        ''',
        cuda_src='''
        kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
        ''',
        cuda_grad_src=['''
        kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
        ''', '''
        kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
        '''])
    da, db = jt.grad(c, [a, b])
    assert np.allclose(c.data, a.data * b.data), (c.data, a.data * b.data)
    assert np.allclose(da.data, b.data)
    assert np.allclose(db.data, a.data)

def test_conv_transpose_grad(self):
    N, H, W, C = 1, 5, 5, 2
    Kh, Kw, Kc = 3, 3, 2
    x = jt.random([N, H, W, C])
    w = jt.random([Kh, Kw, C, Kc])
    y, yy = conv_transpose(x, w)
    mask = jt.random(y.shape)
    loss = (y * mask).sum()
    dx, dw = jt.grad(loss, [x, w])
    jdx, jdw = jt.fetch_sync([dx, dw])
    check_fused(len(x.shape))
    nmask = mask.data
    _, (ndx, ndw) = ngrad(
        lambda args: (conv_transpose_naive(args[0], args[1]) * nmask).sum(),
        [np.float64(x.data), np.float64(w.data)], 1e-7)
    assert np.allclose(ndx, jdx), (ndx, jdx, ndx - jdx)
    assert np.allclose(ndw, jdw), (ndw, jdw)

def test_grad(self): ops = ["+", "-", "*", "/", "**"] np.random.seed(3) a = np.random.rand(10) b = np.random.rand(10) c = np.random.rand(10) for op in ops: func = lambda x: eval(f"((x[0]{op}x[1])*x[2]).sum()") x, grads = ngrad(func, [a, b, c], 1e-8) ja = jt.array(a).name("ja") jb = jt.array(b).name("jb") jc = jt.array(c).name("jc") jx = eval(f"(ja{op}jb)*jc") jgrads = jt.grad(jx, [ja, jb, jc]) for jd, nd in zip(jgrads, grads): assert (np.abs(jd.data - nd) < 1e-4).all(), f"\n{jd.data}\n{nd}"
def test_stop_fuse2(self):
    with jt.profile_scope() as report:
        a = jt.float32(0).stop_fuse()
        c = jt.float32(0).stop_fuse()
        bs = [c]
        for i in range(2000):
            b = jt.float32(i) * 2 * c
            bs.append(b)
            a += b
        a = a * 2
        dbs = jt.grad(a, bs)
        jt.sync(dbs + [a])
    for a in report[1:]:
        assert len(a[0].split("opkey")) < 8