Example No. 1
 def check():
     a = jt.random((5, 1))
     b = jt.numpy_code(
         a.shape,
         a.dtype,
         [a],
         forward_code,
         [backward_code],
     )
     assert numpy.allclose(b.data, (a + a).data)
     da = jt.grad(b, a)
     one = numpy.ones(a.shape)
     assert numpy.allclose(da.data, one * 2.0)
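The callbacks forward_code and backward_code are defined outside this snippet. A minimal sketch consistent with the assertions above (b equals a + a, so the gradient is a constant 2), assuming the numpy_code convention of callbacks that receive (np, data):

 def forward_code(np, data):
     a = data["inputs"][0]
     b = data["outputs"][0]
     np.add(a, a, out=b)          # b = a + a

 def backward_code(np, data):
     dout = data["dout"]          # incoming gradient of b
     out = data["outputs"][0]     # gradient to fill in for a
     np.copyto(out, dout * 2.0)   # d(a + a)/da = 2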
Example No. 2
    def test_batchnorm_backward(self):
        mpi = jt.compile_extern.mpi
        data = np.random.rand(30, 3, 10, 10).astype("float32")
        global_x = jt.array(data)
        x = jt.array(data[mpi.world_rank() * 10:(mpi.world_rank() + 1) * 10,
                          ...])

        bn1 = nn.BatchNorm(3, sync=True)
        bn2 = FakeMpiBatchNorm(3)
        y1 = bn1(x)
        y2 = bn2(x, global_x)
        gs1 = jt.grad(y1, bn1.parameters())
        gs2 = jt.grad(y2, bn2.parameters())

        assert np.allclose(y1.data, y2.data,
                           atol=1e-5), (mpi.world_rank(), y1.data, y2.data,
                                        y1.data - y2.data)
        for i in range(len(gs1)):
            assert np.allclose(gs1[i].data, gs2[i].data,
                               rtol=1e-3), (mpi.world_rank(), gs1[i].data,
                                            gs2[i].data,
                                            gs1[i].data - gs2[i].data)
Example No. 3
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = jt.array(
        np.random.random((real_samples.size(0), 1, 1, 1)).astype(np.float32))
    # Get random interpolation between real and fake samples
    interpolates = alpha * real_samples + ((1 - alpha) * fake_samples)
    d_interpolates, _ = D(interpolates)
    # Get gradient w.r.t. interpolates
    gradients = jt.grad(d_interpolates, interpolates)
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    return gradient_penalty
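A hypothetical usage sketch for the critic's training step (D, real_imgs, fake_imgs, lambda_gp, and d_optim are assumed names, not part of the example):

# hypothetical WGAN-GP critic update; only compute_gradient_penalty comes from the example
gp = compute_gradient_penalty(D, real_imgs, fake_imgs.detach())
d_real, _ = D(real_imgs)
d_fake, _ = D(fake_imgs.detach())
d_loss = -d_real.mean() + d_fake.mean() + lambda_gp * gp
d_optim.step(d_loss)  # Jittor optimizers accept the loss directly (see the step() example further below)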
Example No. 4
    def test1(self):
        class MyFunc(Function):
            def execute(self, x):
                return x + 1

            def grad(self, grad):
                return grad - 2

        a = jt.ones(1)
        func = MyFunc()
        b = func(a)
        da = jt.grad(b, a)
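        # the true derivative of x + 1 is 1, but jt.grad uses the custom grad: 1 - 2 = -1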
        assert da.data == -1
Example No. 5
 def test_broadcast(self):
     data = jt.random([5, 5])
     if mpi.world_rank() == 0:
         x = data
     else:
         x = jt.zeros([5, 5])
     y = x.mpi_broadcast(0)
     assert np.allclose(y.data, data.data)
     g = jt.grad(y, x)
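     # broadcast's backward reduces the gradients from every rank onto the root,
     # so rank 0 expects n * ones (n is the MPI world size, defined outside this snippet)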
     if mpi.world_rank() == 0:
         assert np.allclose(g.data, np.ones([5, 5]) * n)
     else:
         assert np.allclose(g.data, np.zeros([5, 5]))
Example No. 6
    def test2(self):
        class MyFunc(Function):
            def execute(self, x):
                self.x = x
                return x+1

            def grad(self, grad):
                return (grad-2) * self.x
        a = jt.ones(1) * 10
        func = MyFunc()
        b = func(a)
        da = jt.grad(b, a)
        assert da.data == -10
Example No. 7
        def check(xshape, wshape, stride=1, padding=0, dilation=1):
            with jt.log_capture_scope(use_cuda=1, enable_tuner=1,
                log_v=0, log_vprefix="op.cc=100"
            ) as raw_log:
                x = jt.random(xshape)
                w = jt.random(wshape)
                y = conv_oihw(x, w, stride, padding, dilation)
                mask = jt.random(y.shape)
                loss = mask * y
                dx, dw = jt.grad(loss, [x, w])
                jt.sync([y, loss, dx, dw])

            with jt.flag_scope(use_cuda=0, enable_tuner=0):
                cy = conv_oihw(x, w, stride, padding, dilation)
                closs = mask * cy
                cdx, cdw = jt.grad(closs, [x, w])
                jt.sync([cy, closs, cdx, cdw])
            logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)")
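            # 3 matches expected: the forward conv and the two backward convs (for dx and dw) should each use a cudnn op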
            assert len(logs)==3 and "oihw" in logs[0][0], logs
            assert np.allclose(y.data, cy.data)
            assert np.allclose(dx.data, cdx.data)
            assert np.allclose(dw.data, cdw.data)
Example No. 8
 def test_cumprod_cpu(self):
     for i in range(1,6):
         for j in range(i):
             x = np.random.rand(*((10,)*i))
             x_jt = jt.array(x)
             y_jt = jt.cumprod(x_jt, j).sqr()
             g_jt = jt.grad(y_jt.sum(), x_jt)
             x_tc = Variable(torch.from_numpy(x), requires_grad=True)
             y_tc = torch.cumprod(x_tc, j)**2
             y_tc.sum().backward()
             g_tc = x_tc.grad
             assert np.allclose(y_jt.numpy(), y_tc.data)
             assert np.allclose(g_jt.numpy(), g_tc.data)
Example No. 9
def adam(model, loss, lr=3e-4, betas=[0.9, 0.999], eps=1e-8):
    ps = jt.find_vars(model)
    gs = jt.grad(loss, ps)
    with jt.var_scope('_'.join([model, 'adam']), unique=True):
        adam_step = jt.make_var([1], init=jt.zeros)
        adam_step += 1
        for p,g in zip(ps,gs):
            m = jt.make_var(p.shape, init=jt.zeros)
            v = jt.make_var(p.shape, init=jt.zeros)
            
            m.assign(betas[0] * m + (1-betas[0]) * g)
            v.assign(betas[1] * v + (1-betas[1]) * g * g)
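            # Adam bias correction folded into the step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t)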
            step_size = lr * jt.sqrt(1-betas[1]**adam_step) / (1-betas[0] ** adam_step)
            p -= m * step_size / (jt.sqrt(v) + eps)
Example No. 10
    def test_all_reduce(self):
        with jt.log_capture_scope(enable_tuner=1,
                                  log_silent=1,
                                  log_v=1,
                                  log_vprefix="op.cc=100,exe=1000") as raw_log:
            x = jt.random([5, 5])
            y = x.mpi_all_reduce()
            assert np.allclose(y.data, (x * n).data)
            g = jt.grad(y, x)
            assert np.allclose(g.data, np.ones([5, 5]) * n)

        logs = find_log_with_re(
            raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)")
        assert len(logs) == 2, len(logs)
Example No. 11
    def test_grad_not_match_error(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y

            def grad(self, grad):
                return (grad-2) * self.x
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c = func(a, b)
        expect_error(lambda: jt.grad(c, [a, b]))
Example No. 12
    def pre_step(self, loss):
        """ something should be done before step, such as calc gradients, mpi sync, and so on.

        Example::

            class MyOptimizer(Optimizer):
                def step(self, loss):
                    self.pre_step(loss)
                    ...
        """
        # clean prev grads
        params = []
        params_has_grad = []
        for pg in self.param_groups:
            for p in pg['params']:
                params.append(p)
                if not p.is_stop_grad():
                    params_has_grad.append(p)

        # get gradient
        grads = jt.grad(loss, params_has_grad)

        # sync grads and model if in mpi
        if jt.in_mpi:
            for g in grads:
                g.assign(g.mpi_all_reduce("mean"))
            if self.n_step % self.param_sync_iter == 0:
                for p in params:
                    p.assign(p.mpi_broadcast())
        self.n_step += 1

        # set up grads in param_groups
        pid = 0
        for pg in self.param_groups:
            if "grads" not in pg:
                pg["grads"] = [
                    jt.zeros_like(p).stop_grad().stop_fuse()
                    for p in pg['params']
                ]
            pg_grads = pg["grads"]
            for i, p in enumerate(pg['params']):
                if not p.is_stop_grad():
                    # accumulate grad and stop grad of grad
                    g = grads[pid].stop_grad()
                    if not self.__zero_grad:
                        g = g + pg_grads[i]
                    pg_grads[i].update(g)
                    pid += 1
        self.__zero_grad = False
Example No. 13
    def test_pinv(self):
        def check_pinv(a):
            w = anp.linalg.pinv(a)
            return w

        for i in range(50):
            x = jt.random((2, 2, 4, 3))
            c_a = x.data
            mx = jt.linalg.pinv(x)
            tx = check_pinv(c_a)
            assert np.allclose(mx.data, tx)
            jx = jt.grad(mx, x)
            check_grad = jacobian(check_pinv)
            gx = check_grad(c_a)
            np.allclose(gx, jx.data)
Example No. 14
 def test_matmul_grad(self):
     np.random.seed(0)
     for i in range(10):
         a = np.random.rand(2, 3).astype("float32")
         b = np.random.rand(3, 4).astype("float32")
         out, (da,
               db) = ngrad(lambda vars: np.matmul(vars[0], vars[1]).sum(),
                           [a, b], 1e-1)
         ja = jt.array(a)
         jb = jt.array(b)
         jc = ja.matmul(jb)
         jda, jdb = jt.grad(jc, [ja, jb])
         assert (np.abs(da - jda.data) < 1e-5).all(), (da, jda.data,
                                                       da - jda.data)
         assert (np.abs(db - jdb.data) < 1e-5).all(), (db - jdb.data)
Example No. 15
        def check(xshape,
                  wshape,
                  stride=(1, 1, 1),
                  padding=(0, 0, 0),
                  dilation=(1, 1, 1),
                  group=1):
            with jt.flag_scope(use_cuda=1):
                x = jt.random(xshape)
                w = jt.random(wshape)

            y2 = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group,
                                        dilation)

            with jt.flag_scope(use_cuda=1):
                # y = jt.cudnn.ops.cudnn_conv3d_backward_x(w, x, *y2.shape[2:], *stride, *padding, *dilation, group)
                y = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0,
                                           group, dilation)
                masky = jt.rand_like(y)
                dx, dw = jt.grad(masky * y, [x, w])

            dx2, dw2 = jt.grad(masky * y2, [x, w])
            np.testing.assert_allclose(y.data, y2.data, rtol=1e-6, atol=1e-4)
            np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-6, atol=1e-4)
            np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)
Example No. 16
def check_backward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
    if nhwc:
        test_func = test_nhwc
    else:
        test_func = test_nchw
    if use_cuda == 1:
        op_name = "cudnn_conv"
    else:
        op_name = "mkl_conv"

    with jt.log_capture_scope(use_cuda=use_cuda,
                              enable_tuner=1,
                              log_v=1,
                              log_vprefix="op.cc=100,exe=1000,conv_t=1000",
                              compile_options={"test": 244}) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation)
        loss = y.mean()
        dx, dw = jt.grad(loss, [x, w])
        jt.sync([y, loss, dx, dw])
    with jt.flag_scope(use_cuda=0,
                       enable_tuner=0,
                       compile_options={"test": 233}):
        cy = test_func(x, w, stride, padding, dilation)
        closs = cy.mean()
        cdx, cdw = jt.grad(closs, [x, w])
        jt.sync([cy, closs, cdx, cdw])
    logs = find_log_with_re(raw_log,
                            "(Jit op key (not )?found: " + op_name + ".*)")
    assert len(logs) == 3 and "oihw" in logs[0][0], (logs)
    assert np.allclose(y.data, cy.data, 1e-3)
    assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data)
    assert np.allclose(dx.data, cdx.data,
                       1e-3), (dx.data, cdx.data, np.abs(cdx.data).max(),
                               np.abs(dx.data - cdx.data).max())
Example No. 17
def check_backward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc):
    assert nhwc == 0
    test_func = test_nchw

    # only check cudnn
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
        log_v=10, log_vprefix="op.cc=100,conv_tuner=1000"
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        y.sync()
        dx, dw = jt.grad(y, [x, w])
        jt.sync([y, dx, dw])
    with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test":233}):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cdx, cdw = jt.grad(cy, [x, w])
        jt.sync([cy, cdx, cdw])

    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)")
    assert len(logs)==3
    assert np.allclose(y.data, cy.data)
    assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data, np.abs(dw.data - cdw.data).max())
    assert np.allclose(dx.data, cdx.data, 1e-3), (dx.data, cdx.data, np.abs(dx.data - cdx.data).max())
Example No. 18
 def test_longest_dis_fuse(self):
     x = jt.array(np.random.rand(1, 3, 224, 224).astype(np.float32))
     loss = jt.sum(resnet_fake(x))
     ps = jt.find_vars('resnet_fake')
     gs = jt.grad(loss, ps)
     jt.sync(gs)
     # assert that no big tensor is allocated
     g = jt.dump_all_graphs()
     for s in g.nodes_info:
         if not s.startswith("Var"):
             continue
         shape = s.split("[")[1].split("]")[0].split(",")
         ptr = s.split("(")[1].split(")")[0].split(",")[-1]
         if ptr != '0':
             assert len(shape) <= 5, s
Example No. 19
 def test_reduce(self):
     with jt.log_capture_scope(enable_tuner=1,
                               log_silent=1,
                               log_v=1,
                               log_vprefix="op.cc=100,exe=1000") as raw_log:
         x = jt.random([5, 5])
         y = x.mpi_reduce(root=0)
         y_ = y.data
         x_ = (x * n).data
         if mpi.world_rank() == 0:
             assert np.allclose(y_, x_)
         g = jt.grad(y, x)
         assert np.allclose(g.data, np.ones([5, 5]))
     logs = find_log_with_re(raw_log,
                             "(Jit op key (not )?found: nccl_reduce.*)")
     assert len(logs) == 1, len(logs)
Example No. 20
    def test_multi_grads_none(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y

            def grad(self, grad):
                return (grad-2) * self.y, None
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c = func(a, b)
        da, db = jt.grad(c, [a, b])
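        # returning None for an input yields a zero gradient for it: da = (1 - 2) * 4 = -4, db = 0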
        assert da.data == -4
        assert db.data == 0
Example No. 21
    def test_multi_grads_multi_out(self):
        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            def grad(self, grad0, grad1):
                return grad0 * self.y, grad1 * self.x
        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc()
        c,d = func(a, b)
        da, db = jt.grad(c+d*3, [a, b])
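        # upstream grads are 1 for c and 3 for d, so da = grad0 * y = 4 and db = grad1 * x = 9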
        assert da.data == 4
        assert db.data == 9
Example No. 22
 def step(self, loss):
     ps = self.parameters
     gs = jt.grad(loss, ps)
     for p, g, v in zip(ps, gs, self.values):
         dp = p * self.weight_decay + g
         v.assign(self.momentum * v + dp * (1 - self.dampening))
         if self.nesterov:
             p -= (dp + self.momentum * v) * self.lr
         else:
             p -= v * self.lr
         # detach from the previous graph to reduce memory consumption
         p.detach_inplace()
     # sync all no-grad parameters, such as
     # moving_mean and moving_var in batch_norm,
     # to reduce memory consumption
     jt.sync(self.no_grad_parameters)
Example No. 23
    def test_erfinv(self):
        from scipy import special
        y = np.linspace(-1.0, 1.0, num=10)
        x = special.erfinv(y)
        y2 = jt.array(y)
        x2 = jt.erfinv(y2)
        np.testing.assert_allclose(x, x2.data)

        y = np.linspace(-0.9, 0.9, num=10)
        x = special.erfinv(y)
        y2 = jt.array(y)
        x2 = jt.erfinv(y2)
        np.testing.assert_allclose(x, x2.data)
        d = jt.grad(x2, y2)
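        # analytically d/dy erfinv(y) = sqrt(pi)/2 * exp(erfinv(y)^2); here jt.grad is checked against a numerical gradient instead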
        _, (dn, ) = ngrad(lambda y: special.erfinv(y).sum(), [y], 1e-8)
        np.testing.assert_allclose(d.data, dn, atol=1e-6, rtol=1e-6)
Example No. 24
 def test_grad(self):
     ops = ["abs", "negative", "log", "exp", "sqrt"]
     a = [1.1, 2.2, 3.3, 4.4]
     for op in ops:
         if op == "abs":
             b = np.array(a + [
                 -1,
             ])
         else:
             b = np.array(a)
         func = lambda x: eval(f"np.{op}(x[0]).sum()")
         x, (da, ) = ngrad(func, [b], 1e-8)
         ja = jt.array(b)
         jb = eval(f"jt.{op}(ja)")
         jda = jt.grad(jb, ja)
         assert (np.abs(jda.data - da) < 1e-5).all(), (jda.data, da, op)
Example No. 25
        def test():
            class MyFunc(Function):
                def execute(self, x, z, y):
                    self.x = x
                    self.y = y
                    return x*y, "test", x/y

                def grad(self, grad0, _, grad1):
                    assert _ is None
                    res = (grad0 * self.y, None, grad1 * self.x)
                    return res
            a = jt.array(3.0)
            b = jt.array(4.0)
            c,_,d = MyFunc()(a, "a", b)
            g = jt.grad(c+d*3, [a, b])
            jt.sync(g)
Example No. 26
        def test():
            class MyFunc(Function):
                def execute(self, x, z, y):
                    self.x = x.name("x")
                    self.y = y.name("y")
                    return x*y, "test", x/y

                def grad(self, grad0, _, grad1):
                    assert _ is None
                    res = (grad0 * self.y, None, grad1 * self.x)
                    return res
            a = jt.array(3.0).name('a')
            b = jt.array(4.0).name('b')
            c,_,d = MyFunc()(a, "a", b)
            c.name('c'), d.name('d')
            g = jt.grad(c+d*3, [a, b])
Example No. 27
    def test_cuda(self):
        a = jt.random([100000])
        b = jt.random([100000])
        c = jt.code(a.shape,
                    a.dtype, [a, b],
                    cuda_header='''
            namespace jittor {
            __global__ static void kernel1(@ARGS_DEF) {
                @PRECALC
                int i = threadIdx.x + blockIdx.x * blockDim.x;
                int stride = blockDim.x * gridDim.x;
                for (; i<in0shape0; i+=stride)
                    @out(i) = @in0(i)*@in1(i);
            }
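
            // the two kernels below implement the gradients wired up via cuda_grad_src:
            // kernel2 computes dout * in1 (the gradient w.r.t. a), kernel3 computes dout * in0 (the gradient w.r.t. b)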

            __global__ static void kernel2(@ARGS_DEF) {
                @PRECALC
                int i = threadIdx.x + blockIdx.x * blockDim.x;
                int stride = blockDim.x * gridDim.x;
                for (; i<in0shape0; i+=stride)
                    @out(i) = @dout(i)*@in1(i);
            }

            __global__ static void kernel3(@ARGS_DEF) {
                @PRECALC
                int i = threadIdx.x + blockIdx.x * blockDim.x;
                int stride = blockDim.x * gridDim.x;
                for (; i<in0shape0; i+=stride)
                    @out(i) = @dout(i)*@in0(i);
            }

            }
            ''',
                    cuda_src='''
                kernel1<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
            ''',
                    cuda_grad_src=[
                        '''
                kernel2<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
            ''', '''
                kernel3<<<(in0shape0-1)/1024+1, 1024>>>(@ARGS);
            '''
                    ])
        da, db = jt.grad(c, [a, b])
        assert np.allclose(c.data, a.data * b.data), (c.data, a.data * b.data)
        assert np.allclose(da.data, b.data)
        assert np.allclose(db.data, a.data)
Example No. 28
 def test_conv_transpose_grad(self):
     N,H,W,C = 1,5,5,2
     Kh, Kw, Kc = 3, 3, 2
     x = jt.random([N,H,W,C])
     w = jt.random([Kh,Kw,C,Kc])
     y, yy = conv_transpose(x, w)
     mask = jt.random(y.shape)
     loss = (y*mask).sum()
     dx, dw = jt.grad(loss, [x, w])
     jdx, jdw = jt.fetch_sync([dx, dw])
     check_fused(len(x.shape))
     nmask = mask.data
     _, (ndx, ndw) = ngrad(lambda args: \
         (conv_transpose_naive(args[0], args[1])*nmask).sum(),
         [np.float64(x.data), np.float64(w.data)], 1e-7)
     assert np.allclose(ndx, jdx), (ndx, jdx, ndx-jdx)
     assert np.allclose(ndw, jdw), (ndw, jdw)
Example No. 29
 def test_grad(self):
     ops = ["+", "-", "*", "/", "**"]
     np.random.seed(3)
     a = np.random.rand(10)
     b = np.random.rand(10)
     c = np.random.rand(10)
     for op in ops:
         func = lambda x: eval(f"((x[0]{op}x[1])*x[2]).sum()")
         x, grads = ngrad(func, [a, b, c], 1e-8)
         ja = jt.array(a).name("ja")
         jb = jt.array(b).name("jb")
         jc = jt.array(c).name("jc")
         jx = eval(f"(ja{op}jb)*jc")
         jgrads = jt.grad(jx, [ja, jb, jc])
         for jd, nd in zip(jgrads, grads):
             assert (np.abs(jd.data - nd) <
                     1e-4).all(), f"\n{jd.data}\n{nd}"
Example No. 30
    def test_stop_fuse2(self):
        with jt.profile_scope() as report:
            a = jt.float32(0).stop_fuse()
            c = jt.float32(0).stop_fuse()
            bs = [c]
            for i in range(2000):
                b = jt.float32(i) * 2 * c
                bs.append(b)
                a += b

            a = a * 2

            dbs = jt.grad(a, bs)
            jt.sync(dbs + [a])

        for a in report[1:]:
            assert len(a[0].split("opkey")) < 8