Example 1
        def check(xshape, wshape, stride=1, padding=0, dilation=1):
            with jt.log_capture_scope(
                    use_cuda=1,
                    enable_tuner=1,
                    log_v=1,
                    log_vprefix="op.cc=100,exe=1000") as raw_log:
                x = jt.random(xshape)
                w = jt.random(wshape)
                y = conv(x, w, stride, padding)
                mask = jt.random(y.shape)
                loss = mask * y
                dx, dw = jt.grad(loss, [x, w])
                jt.sync([y, loss, dx, dw])

            # fails when enable_tuner=1; something may be wrong with mkl_conv_backward_x.
            with jt.flag_scope(use_cuda=0, enable_tuner=0):
                cy = conv(x, w, stride, padding)
                closs = mask * cy
                cdx, cdw = jt.grad(closs, [x, w])
                jt.sync([cy, closs, cdx, cdw])
            logs = find_log_with_re(raw_log,
                                    "(Jit op key (not )?found: cudnn_conv.*)")
            assert len(logs) == 3 and "oihw" in logs[0][0], logs
            assert np.allclose(y.data, cy.data)
            assert np.allclose(dx.data, cdx.data, 1e-2)
            assert np.allclose(dw.data, cdw.data, 1e-2)
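Note: these snippets are excerpts and omit their shared setup. A minimal preamble is sketched below; the import path of find_log_with_re and the conv stand-in (which simply delegates to the built-in NCHW/OIHW conv2d) are assumptions, not the original test module's definitions.

import numpy as np
import jittor as jt
from jittor import nn
# Assumption: find_log_with_re ships with Jittor's test utilities.
from jittor.test.test_log import find_log_with_re

def conv(x, w, stride=1, padding=0, dilation=1):
    # Stand-in for the hand-written convolution used by the original tests.
    return nn.conv2d(x, w, None, stride, padding, dilation)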
Example 2
def check_backward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
    if nhwc:
        test_func = test_nhwc
    else:
        test_func = test_nchw
    if use_cuda == 1:
        op_name = "cudnn_conv"
    else:
        op_name = "mkl_conv"

    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
        log_v=1, log_vprefix="op.cc=1000,exe=1000,conv_t=1000", compile_options={"test":244}
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation)
        loss = y.mean()
        dx, dw = jt.grad(loss, [x, w])
        jt.sync([y, loss, dx, dw])
    with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test":233}):
        cy = test_func(x, w, stride, padding, dilation)
        closs = cy.mean()
        cdx, cdw = jt.grad(closs, [x, w])
        jt.sync([cy, closs, cdx, cdw])
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + op_name + ".*)")
    assert len(logs) == 3 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data, 1e-3)
    assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data)
    assert np.allclose(dx.data, cdx.data, 1e-3), (dx.data, cdx.data, np.abs(cdx.data).max(), np.abs(dx.data - cdx.data).max())
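Note: an illustrative invocation of check_backward follows; the shapes (NCHW input, OIHW weights) are chosen for this sketch and are not the original parametrization.

check_backward([4, 8, 28, 28], [16, 8, 3, 3],
               stride=1, padding=1, dilation=1, use_cuda=0, nhwc=False)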
Example 3
def check_forward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
    if nhwc:
        test_func = test_nhwc
    else:
        test_func = test_nchw
    if use_cuda == 1:
        op_name = "cudnn_conv"
    else:
        op_name = "mkl_conv"
    with jt.log_capture_scope(use_cuda=use_cuda,
                              enable_tuner=1,
                              log_v=0,
                              log_vprefix="op.cc=100,conv_tuner=1000",
                              compile_options={"test": 266}) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation)
        y.sync()
    with jt.flag_scope(use_cuda=0,
                       enable_tuner=0,
                       compile_options={"test": 255}):
        cy = test_func(x, w, stride, padding, dilation)
        cy.sync()
    logs = find_log_with_re(raw_log,
                            "(Jit op key (not )?found: " + op_name + ".*)")
    assert len(logs) == 1 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data)
Example 4
def check_backward(xshape, wshape, stride, padding, dilation, groups, use_cuda,
                   nhwc):
    assert nhwc == 0
    test_func = test_nchw

    # only check cudnn
    with jt.log_capture_scope(
            use_cuda=use_cuda,
            enable_tuner=1,
            log_v=10,
            log_vprefix="op.cc=100,conv_tuner=1000") as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        y.sync()
        dx, dw = jt.grad(y, [x, w])
        jt.sync([y, dx, dw])
    with jt.flag_scope(use_cuda=0,
                       enable_tuner=0,
                       compile_options={"test": 233}):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cdx, cdw = jt.grad(cy, [x, w])
        jt.sync([cy, cdx, cdw])

    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)")
    assert len(logs) == 3
    assert np.allclose(y.data, cy.data)
    assert np.allclose(dw.data,
                       cdw.data, 1e-3), (dw.data, cdw.data,
                                         np.abs(dw.data - cdw.data).max())
    assert np.allclose(dx.data,
                       cdx.data, 1e-3), (dx.data, cdx.data,
                                         np.abs(dx.data - cdx.data).max())
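Note: this grouped variant takes an extra groups argument; an illustrative call follows, with shapes chosen so the 8 input channels split into 4 groups of 2 (again an assumption, not the original test's parameters).

check_backward([4, 8, 16, 16], [8, 2, 3, 3],
               stride=1, padding=1, dilation=1, groups=4, use_cuda=1, nhwc=0)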
Example 5
def check(xshape, wshape, stride=1, padding=0, dilation=1):
    with jt.log_capture_scope(use_cuda=1, enable_tuner=1,
        log_v=0, log_vprefix="op.cc=100"
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = conv_oihw(x, w, stride, padding, dilation)
        y.sync()
    with jt.flag_scope(use_cuda=0, enable_tuner=1):
        cy = conv_oihw(x, w, stride, padding, dilation)
        cy.sync()
    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)")
    assert len(logs) == 1 and "oihw" in logs[0][0], logs
    assert np.allclose(y.data, cy.data), np.abs(y.data - cy.data).max()
Example 6
    def test_all_reduce(self):
        with jt.log_capture_scope(enable_tuner=1,
                                  log_silent=1,
                                  log_v=1,
                                  log_vprefix="op.cc=100,exe=1000") as raw_log:
            x = jt.random([5, 5])
            y = x.mpi_all_reduce()
            assert np.allclose(y.data, (x * n).data)
            g = jt.grad(y, x)
            assert np.allclose(g.data, np.ones([5, 5]) * n)

        logs = find_log_with_re(
            raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)")
        assert len(logs) == 2, len(logs)
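Note: the MPI snippets rely on names defined elsewhere in the test module. A minimal setup, assuming Jittor was built with MPI support so that jt.mpi is available, could be:

import numpy as np
import jittor as jt
mpi = jt.mpi            # assumption: present when Jittor is compiled with MPI
n = mpi.world_size()    # number of ranks; an all-reduce of x should equal x * n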
Example 7
    def test_reduce(self):
        with jt.log_capture_scope(enable_tuner=1,
                                  log_silent=1,
                                  log_v=1,
                                  log_vprefix="op.cc=100,exe=1000") as raw_log:
            x = jt.random([5, 5])
            y = x.mpi_reduce(root=0)
            y_ = y.data
            x_ = (x * n).data
            if mpi.world_rank() == 0:
                assert np.allclose(y_, x_)
            g = jt.grad(y, x)
            assert np.allclose(g.data, np.ones([5, 5]))
        logs = find_log_with_re(raw_log,
                                "(Jit op key (not )?found: nccl_reduce.*)")
        assert len(logs) == 1, len(logs)
Example 8
def check_forward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc):
    assert nhwc == 0
    test_func = test_nchw

    # only check cudnn
    with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1,
        log_v=10, log_vprefix="op.cc=100,conv_tuner=1000"
    ) as raw_log:
        x = jt.random(xshape)
        w = jt.random(wshape)
        y = test_func(x, w, stride, padding, dilation, groups)
        y.sync()
    with jt.flag_scope(use_cuda=0, enable_tuner=0):
        cy = test_func(x, w, stride, padding, dilation, groups)
        cy.sync()

    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)")
    assert len(logs) == 1
    assert np.allclose(y.data, cy.data)
Example 9
    def test_broadcast(self):
        with jt.log_capture_scope(enable_tuner=1,
                                  log_silent=1,
                                  log_v=1,
                                  log_vprefix="op.cc=100,exe=1000") as raw_log:
            data = jt.random([5, 5])
            if mpi.world_rank() == 0:
                x = data
            else:
                x = jt.zeros([5, 5])
            y = x.mpi_broadcast(0)
            assert np.allclose(y.data, data.data)
            g = jt.grad(y.sum(), x)
            g_ = g.data
            if mpi.world_rank() == 0:
                assert np.allclose(g_, np.ones([5, 5]) * n)
        logs = find_log_with_re(raw_log,
                                "(Jit op key (not )?found: nccl_broadcast.*)")
        assert len(logs) == 1, len(logs)
Example 10
def check(data_shape, weights_shape, stride=1, dilation=1):
    N, C, H, W = data_shape
    i, o, h, w = weights_shape
    img = np.random.rand(N, C, H, W).astype("float32")
    weights = np.random.rand(i, o, h, w).astype("float32")
    m1 = jt.nn.ConvTranspose(i, o, h, stride=stride, dilation=dilation, bias=False)
    m2 = torch.nn.ConvTranspose2d(i, o, h, stride=stride, dilation=dilation, bias=False)
    m1.weight.data = weights
    m2.weight.data = torch.Tensor(weights)
    x = jt.array(img)
    # out1 = m1(x)
    out1 = jt.nn.conv_transpose2d(x, m1.weight, stride=stride, dilation=dilation, bias=False)
    mask = jt.random(out1.shape)
    out1 = out1 * mask
    tx = torch.Tensor(img)
    tx.requires_grad = True
    out2 = m2(tx) * torch.Tensor(mask.data)
    with jt.log_capture_scope(
            log_silent=1,
            log_vprefix="var_re=0,conv=0,op.cc=100") as logs:
        assert np.allclose(out1.data, out2.data)
        dx, dw = jt.grad(out1, [x, m1.weight])
        jt.sync([dx, dw])
        out2.sum().backward()
        assert np.allclose(dw.data, m2.weight.grad.numpy(), 1e-3)
        assert np.allclose(dx.data, tx.grad.numpy())
    assert len(find_log_with_re(logs, "conv")) == 3
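Note: the ConvTranspose check compares Jittor against PyTorch, so it also needs torch imported. An illustrative invocation (shapes are assumptions; transposed-convolution weights are laid out as in_channels x out_channels x kH x kW):

import torch
check([2, 3, 8, 8], [3, 5, 3, 3], stride=2, dilation=1)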
Example 11
    def test_resnet(self):
        self.setup_seed(1)
        loss_list = []
        acc_list = []
        mnist_net = MnistNet()
        global prev
        prev = time.time()
        SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum,
                     self.weight_decay)
        self.train_loader.endless = True

        for data, target in self.train_loader:
            batch_id = self.train_loader.batch_id
            epoch_id = self.train_loader.epoch_id

            # train step
            with jt.log_capture_scope(
                    log_silent=1,
                    log_v=1,
                    log_vprefix="op.cc=100,exe=10",
            ) as logs:
                output = mnist_net(data)
                loss = nn.cross_entropy_loss(output, target)
                SGD.step(loss)

                def callback(epoch_id, batch_id, loss, output, target):
                    # print train info
                    global prev
                    pred = np.argmax(output, axis=1)
                    acc = np.mean(target == pred)
                    loss_list.append(loss[0])
                    acc_list.append(acc)
                    print(
                        'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
                        .format(epoch_id, batch_id, 600, 1. * batch_id / 6.0,
                                loss[0], acc,
                                time.time() - prev))
                    # prev = time.time()

                jt.fetch(epoch_id, batch_id, loss, output, target, callback)

            log_conv = find_log_with_re(
                logs, "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
            log_matmul = find_log_with_re(
                logs, "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
            if batch_id > 2:
                assert len(log_conv) == 59 and len(log_matmul) == 6, (
                    len(log_conv), len(log_matmul))

            mem_used = (jt.flags.stat_allocator_total_alloc_byte
                        - jt.flags.stat_allocator_total_free_byte)
            # assert mem_used < 4e9, mem_used
            # TODO: why bigger?
            assert mem_used < 5.6e9, mem_used
            # example log:
            # Train Epoch: 0 [0/100 (0%)]     Loss: 2.352903  Acc: 0.110000
            # Train Epoch: 0 [1/100 (1%)]     Loss: 2.840830  Acc: 0.080000
            # Train Epoch: 0 [2/100 (2%)]     Loss: 3.473594  Acc: 0.100000
            # Train Epoch: 0 [3/100 (3%)]     Loss: 3.131615  Acc: 0.200000
            # Train Epoch: 0 [4/100 (4%)]     Loss: 2.524094  Acc: 0.230000
            # Train Epoch: 0 [5/100 (5%)]     Loss: 7.780025  Acc: 0.080000
            # Train Epoch: 0 [6/100 (6%)]     Loss: 3.890721  Acc: 0.160000
            # Train Epoch: 0 [7/100 (7%)]     Loss: 6.370137  Acc: 0.140000
            # Train Epoch: 0 [8/100 (8%)]     Loss: 11.390827 Acc: 0.150000
            # Train Epoch: 0 [9/100 (9%)]     Loss: 21.598564 Acc: 0.080000
            # Train Epoch: 0 [10/100 (10%)]   Loss: 23.369165 Acc: 0.130000
            # Train Epoch: 0 [20/100 (20%)]   Loss: 4.804510  Acc: 0.100000
            # Train Epoch: 0 [30/100 (30%)]   Loss: 3.393924  Acc: 0.110000
            # Train Epoch: 0 [40/100 (40%)]   Loss: 2.286762  Acc: 0.130000
            # Train Epoch: 0 [50/100 (50%)]   Loss: 2.055014  Acc: 0.290000

            if jt.in_mpi:
                assert jt.core.number_of_lived_vars(
                ) < 8100, jt.core.number_of_lived_vars()
            else:
                assert jt.core.number_of_lived_vars(
                ) < 7000, jt.core.number_of_lived_vars()
            if self.train_loader.epoch_id >= 2:
                break

        jt.sync_all(True)
        assert np.mean(loss_list[-50:]) < 0.5
        assert np.mean(acc_list[-50:]) > 0.8