Example #1
from mxnet import profiler

# enable_profiler, test_profile_task and test_profile_event are helper
# functions defined alongside this test in MXNet's profiler test module.
def test_profile_tune_pause_resume():
    enable_profiler()
    profiler.pause()
    # "test_profile_task" should *not* show up in tuning analysis
    test_profile_task()
    profiler.resume()
    # "test_profile_event" should show up in tuning analysis
    test_profile_event()
    profiler.pause()
    profiler.set_state('stop')
Example #2
from mxnet import profiler

def test_profile_tune_pause_resume():
    # Same test as Example #1, but with an explicit profiler output file.
    enable_profiler('test_profile_tune_pause_resume.json')
    profiler.pause()
    # "test_profile_task" should *not* show up in tuning analysis
    test_profile_task()
    profiler.resume()
    # "test_profile_event" should show up in tuning analysis
    test_profile_event()
    profiler.pause()
    profiler.set_state('stop')
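
Both examples above rely on helpers from MXNet's test suite. The same pause/resume pattern can be written with only the public profiler API; here is a minimal sketch (the output filename and the dummy mx.nd.ones workload are placeholders, not part of the original tests):

import mxnet as mx
from mxnet import profiler

# Configure and start the profiler, then bracket the region of interest
# with pause()/resume(): work done while paused is not recorded.
profiler.set_config(profile_all=True,
                    aggregate_stats=True,
                    filename='profile_output.json')
profiler.set_state('run')

profiler.pause()
mx.nd.ones((100, 100)).asnumpy()   # not recorded

profiler.resume()
mx.nd.ones((100, 100)).asnumpy()   # recorded

profiler.pause()
profiler.set_state('stop')
print(profiler.dumps())            # print the aggregate statistics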
Example #3
import time
import numpy as np
import numpy.testing as npt
import mxnet as mx
from mxnet import profiler

# `args`, `ctx`, `dtype` and the NumPy reference implementation `npy_ln_grad`
# are defined at module level in the original benchmark script.
def check_ln_speed(nbatch, nchannel, eps, nrepeat):
    fwd_check_eps = 1E-1 if dtype == np.float16 else 1E-4
    bwd_check_eps = 1E-1 if dtype == np.float16 else 1E-3
    B, C = nbatch, nchannel
    # Warm up: run a small dummy computation twice before timing.
    for _ in range(2):
        in_data = mx.nd.random.normal(shape=(B, C), ctx=ctx, dtype=dtype)
        out_data = in_data * in_data
        npy_out_data = out_data.asnumpy()
    mx.nd.waitall()
    fwd_time = 0
    bwd_time = 0
    if args.profile:
        profiler.set_state('run')
        profiler.pause()
    # The first iteration is a warm-up, excluded from timing and profiling.
    for i in range(nrepeat + 1):
        in_data = mx.nd.random.normal(shape=(B, C), ctx=ctx, dtype=dtype)
        ograd = mx.nd.random.normal(shape=(B, C), ctx=ctx, dtype=dtype)
        nd_gamma = mx.nd.ones(shape=(C, ), ctx=ctx, dtype=dtype)
        nd_beta = mx.nd.zeros(shape=(C, ), ctx=ctx, dtype=dtype)
        npy_in_data = in_data.asnumpy().astype(np.float64)
        gt_out = (npy_in_data - npy_in_data.mean(axis=-1, keepdims=True)) \
                 / np.sqrt(npy_in_data.var(axis=-1, keepdims=True) + eps)
        gt_in_data_grad, gt_gamma_grad, gt_beta_grad = \
            npy_ln_grad(npy_in_data, ograd.asnumpy().astype(np.float64), eps, nd_gamma.asnumpy().astype(np.float64))
        mx.nd.waitall()
        in_data.attach_grad()
        nd_gamma.attach_grad()
        nd_beta.attach_grad()
        # Force evaluation of gamma/beta so setup cost stays out of the
        # timed region below.
        _no_use = nd_gamma.asnumpy()
        _no_use = nd_beta.asnumpy()
        mx.nd.waitall()
        # Profile Forward + Backward
        with mx.autograd.record():
            mx.nd.waitall()
            if args.profile and i > 0:
                profiler.resume()
            start = time.time()
            out_data, mean_val, std_val = mx.nd.LayerNorm(in_data,
                                                          gamma=nd_gamma,
                                                          beta=nd_beta,
                                                          axis=-1,
                                                          eps=eps,
                                                          output_mean_var=True)
            out_data.wait_to_read()
            if i > 0:
                fwd_time += time.time() - start
            mx.nd.waitall()
            start = time.time()
            out_data.backward(ograd)
            mx.nd.waitall()
            if args.profile and i > 0:
                profiler.pause()
            if i > 0:
                bwd_time += time.time() - start
        mx_in_data_grad = in_data.grad.asnumpy()
        mx_gamma_grad = nd_gamma.grad.asnumpy()
        mx_beta_grad = nd_beta.grad.asnumpy()
        npt.assert_allclose(mean_val.asnumpy()[:, 0],
                            npy_in_data.mean(axis=-1).astype(dtype),
                            fwd_check_eps, fwd_check_eps)
        npt.assert_allclose(
            std_val.asnumpy()[:, 0],
            np.sqrt(npy_in_data.var(axis=-1) + eps).astype(dtype),
            fwd_check_eps, fwd_check_eps)
        npt.assert_allclose(out_data.asnumpy(), gt_out.astype(dtype),
                            fwd_check_eps, fwd_check_eps)
        # Use a separate index to avoid shadowing the outer loop variable i.
        for b in range(B):
            npt.assert_allclose(mx_in_data_grad[b, :],
                                gt_in_data_grad[b, :].astype(dtype),
                                fwd_check_eps, fwd_check_eps)
        npt.assert_allclose(mx_gamma_grad, gt_gamma_grad.astype(dtype),
                            bwd_check_eps, bwd_check_eps)
        npt.assert_allclose(mx_beta_grad, gt_beta_grad.astype(dtype),
                            bwd_check_eps, bwd_check_eps)
    if args.profile:
        profiler.set_state('stop')
    # Average per-iteration forward/backward time, in microseconds.
    return fwd_time / nrepeat * 1000000, bwd_time / nrepeat * 1000000
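
The benchmark above interleaves correctness checks with the timing; the profiling-relevant skeleton is just the pause/resume bracketing around the timed region, with the warm-up iteration excluded. Below is a minimal sketch of that pattern (the time_with_profiler helper and the workload callable are hypothetical names; profiler.set_config is assumed to have been called beforehand):

import time
import mxnet as mx
from mxnet import profiler

def time_with_profiler(workload, nrepeat, profile=False):
    """Time `workload` over nrepeat iterations, treating the first
    iteration as a warm-up; the profiler records only the timed region."""
    if profile:
        profiler.set_state('run')
        profiler.pause()
    total = 0.0
    for i in range(nrepeat + 1):
        mx.nd.waitall()               # drain pending asynchronous work
        if profile and i > 0:
            profiler.resume()
        start = time.time()
        workload()
        mx.nd.waitall()               # make sure the workload finished
        if profile and i > 0:
            profiler.pause()
        if i > 0:
            total += time.time() - start
    if profile:
        profiler.set_state('stop')
    return total / nrepeat * 1e6      # average microseconds per iteration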
Example #4
import logging
import time
from mxnet import autograd, profiler

# `model`, `trainer`, `loss_fn`, `metric`, `train_data`, `context` and `args`
# are set up earlier in the original training script. The snippet also begins
# mid-call; the enclosing call is assumed to be profiler.set_config(...), with
# any other configuration arguments truncated away:
profiler.set_config(filename="profile_mx_mnist.json")

# Train model
for epoch in range(args.epochs):
    tic = time.time()
    train_data.reset()
    metric.reset()
    for nbatch, batch in enumerate(train_data, start=1):
        # Start and pause profiling
        if nbatch == 100:
            if epoch == 0:
                profiler.set_state('run')
            else:
                profiler.resume()
        elif nbatch == 200:
            profiler.pause()

        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        with autograd.record():
            output = model(data.astype(args.dtype, copy=False))
            loss = loss_fn(output, label)
        loss.backward()
        trainer.step(args.batch_size)
        metric.update([label], [output])

        if nbatch % 100 == 0:
            name, acc = metric.get()
            logging.info('[Epoch %d Batch %d] Training: %s=%f' %
                         (epoch, nbatch, name, acc))
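
The snippet ends before the collected profile is written out. Assuming the profiler.set_config call above, a minimal sketch of the missing teardown after the training loop:

# After the training loop: stop profiling and flush the collected trace
# to the configured file (profile_mx_mnist.json).
profiler.set_state('stop')
profiler.dump()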