Example #1
def update_model(model_path):
    """
    Update the dumped model with test cases for new reference values.

    The model with pre-trained weights is trained for one iter with the test data attached.
    The loss and the updated net state dict are dumped.

    .. code-block:: python

        from test_correctness import update_model
        update_model('mnist_model_with_test.mge') # for gpu
        update_model('mnist_model_with_test_cpu.mge') # for cpu

    """
    net = MnistNet(has_bn=True)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    lr = checkpoint["sgd_lr"]
    opt = SGD(net.parameters(), lr=lr)

    data = tensor(dtype=np.float32)
    label = tensor(dtype=np.int32)
    data.set_value(checkpoint["data"])
    label.set_value(checkpoint["label"])

    opt.zero_grad()
    loss = train(data, label, net=net, opt=opt)
    opt.step()

    xpu_name = get_xpu_name()

    checkpoint.update(
        {"net_updated": net.state_dict(), "loss": loss.numpy(), "xpu": xpu_name}
    )
    mge.save(checkpoint, model_path)
Example #2
def test_sgd_simple():
    data, data_shape, label, label_shape = get_input()
    mlp = MLP()
    opt = SGD(mlp.parameters(), lr=0.01, weight_decay=0.1)
    for idx in range(3):
        data.set_value(np.random.random(data_shape).astype(np.float32))
        label.set_value(np.random.randint(0, 10, label_shape))
        pred = mlp(data)
        loss = F.square_loss(pred, label.reshape(-1, 1))
        if idx % 2:
            opt.zero_grad()
        else:
            mlp.zero_grad()
        opt.backward(loss)
        grads = TensorDict()
        orig_params = TensorDict()
        for param in mlp.parameters():
            grad = F.grad(loss, param, use_virtual_grad=False)
            assertTensorClose(grad.numpy(), param.grad.numpy())
            grads[param] = np.copy(grad.numpy())
            orig_params[param] = np.copy(param.numpy())
        opt.step()
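        # with lr=0.01 and weight_decay=0.1, one SGD step gives param * (1 - lr * wd) - lr * grad = param * 0.999 - grad * 0.01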
        for param in mlp.parameters():
            assertTensorClose(param.numpy(),
                              orig_params[param] * 0.999 - grads[param] * 0.01)
Example #3
def train(dataloader, args):
    writer = SummaryWriter("runs")
    net = Net()
    net.train()
    optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
    gm = GradManager().attach(net.parameters())

    epoch_length = len(dataloader)
    for epoch in range(args.epoch):
        for step, (batch_data, batch_label) in enumerate(dataloader):
            batch_label = batch_label.astype(np.int32)
            data, label = mge.tensor(batch_data), mge.tensor(batch_label)
            with gm:
                pred = net(data)
                loss = F.loss.cross_entropy(pred, label)
                gm.backward(loss)
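            # step() returns the optimizer itself, so clear_grad() can be chained to reset gradients for the next iteration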
            optimizer.step().clear_grad()

            if step % 50 == 0:
                print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss)))  # noqa
            writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
        if (epoch + 1) % 5 == 0:
            mge.save(
                net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"),
            )  # noqa
Example #4
def test_sgd_momentum_static():
    _, data_shape, _, label_shape = get_input()
    mlp = MLP()
    opt = SGD(mlp.parameters(), lr=0.01, momentum=0.9)

    @trace
    def f(data, label):
        pred = mlp(data)
        loss = F.square_loss(pred, label.reshape(-1, 1))
        opt.zero_grad()
        opt.backward(loss)

    slots = TensorDict()
    for param in mlp.parameters():
        slots[param] = np.zeros(param.shape).astype(np.float32)
    for _ in range(3):
        f(
            np.random.random(data_shape).astype(np.float32),
            np.random.randint(0, 10, label_shape).astype(np.int32),
        )
        orig_params = TensorDict()
        grads = TensorDict()
        for param in mlp.parameters():
            orig_params[param] = np.copy(param.numpy())
            grads[param] = np.copy(param.grad.numpy())
        opt.step()
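        # expected momentum update: slot = 0.9 * slot - lr * grad, then param = orig_param + slot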
        for param in mlp.parameters():
            slot = slots[param]
            orig_param = orig_params[param]
            slot *= 0.9
            slot -= param.grad.numpy() * 0.01
            assertTensorClose(param.numpy(), orig_param + slot)
Example #5
def test_update_lr():
    data, data_shape, label, label_shape = get_input()
    mlp = MLP()
    opt = SGD(mlp.parameters(), lr=0.01)
    pred = mlp(data)
    loss = F.square_loss(pred, label.reshape(-1, 1))
    opt.zero_grad()
    opt.backward(loss)
    opt.step()
    for group in opt.param_groups:
        group["lr"] += 0.02
    for _ in range(3):
        data.set_value(np.random.random(data_shape).astype(np.float32))
        label.set_value(np.random.randint(0, 10, label_shape))
        pred = mlp(data)
        loss = F.square_loss(pred, label.reshape(-1, 1))
        opt.zero_grad()
        opt.backward(loss)
        for param in mlp.parameters():
            grad = F.grad(loss, param, use_virtual_grad=False)
            assertTensorClose(grad.numpy(), param.grad.numpy())
        orig_params = []
        for param in mlp.parameters():
            orig_params.append(np.copy(param.numpy()))
        opt.step()
        for param, orig_param in zip(mlp.parameters(), orig_params):
            assertTensorClose(param.numpy(),
                              orig_param - param.grad.numpy() * 0.03)
Example #6
def test_compile_multi_times_static():
    return  # XXX: rewrite or remove this test
    with Graph() as cg:
        cg.set_option("eager_evaluation", False)
        data = Input("data", shape=(2, 28))
        label = Input("label", shape=(2, ), dtype=np.int32)

        mlp = MLP()
        opt = SGD(mlp.parameters(requires_grad=True), lr=0.01)

        pred0 = mlp(data)
        pred = F.softmax(pred0)
        loss = F.square_loss(pred, label.reshape(2, 1))
        opt.zero_grad()
        grads = opt.backward(loss)
        opt.step()

        f0 = compile(pred, None)
        f1 = compile([pred, loss], grads, copy=True)

        data = np.random.random((2, 28)).astype(np.float32)
        label = np.random.randint(0, 10, (2, )).astype(np.float32)
        out0 = f0(data=data)
        out1 = f1(data=data, label=label)
        assertTensorClose(out0[0], out1[0])

        _ = compile([pred, loss], grads, copy=False)
        with pytest.raises(mgb.MegBrainError):
            f0(data=data)
Example #7
def test_correctness_parampack():
    net1 = XORNet()
    net2 = XORNet()
    params1 = net1.parameters()
    params2 = net2.parameters()
    for param1, param2 in zip(params1, params2):
        param1.set_value(param2.numpy())
    net1 = ParamPack(net1,
                     nr_ignore_first=0,
                     max_size_per_group=10,
                     max_nr_params_per_group=100)
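    # net1's parameters are packed into grouped tensors; training should match the unpacked net2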
    opt1 = SGD(net1.parameters(requires_grad=True),
               lr=0.01,
               momentum=0.9,
               weight_decay=5e-4)

    opt2 = SGD(net2.parameters(requires_grad=True),
               lr=0.01,
               momentum=0.9,
               weight_decay=5e-4)

    @trace(symbolic=False)
    def train1(data, label):
        pred = net1(data)
        opt1.zero_grad()
        loss = cross_entropy_with_softmax(pred, label)
        opt1.backward(loss)
        return loss

    @trace(symbolic=False)
    def train2(data, label):
        pred = net2(data)
        opt2.zero_grad()
        loss = cross_entropy_with_softmax(pred, label)
        opt2.backward(loss)
        return loss

    @trace(symbolic=False)
    def infer1(data):
        return net1(data)

    @trace(symbolic=False)
    def infer2(data):
        return net2(data)

    train_dataset = minibatch_generator()

    for data, label in itertools.islice(train_dataset, 2000):
        train1(data, label)
        opt1.step()

        train2(data, label)
        opt2.step()

    data, _ = next(train_dataset)
    pred1 = infer1(data).numpy()
    pred2 = infer2(data).numpy()
    assert np.allclose(pred1, pred2)
Example #8
def run_train(
    model_path,
    use_jit,
    use_symbolic,
    sublinear_memory_config=None,
    max_err=None,
    use_adaptive_pooling=False,
):

    """
    Load the model with test cases and run the training for one iter.
    The loss and updated weights are compared with the reference values to verify correctness.

    Dump a new file with updated results by calling update_model
    if you think the test fails due to numerical rounding errors instead of bugs.
    Please think twice before you do so.

    """
    net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    lr = checkpoint["sgd_lr"]
    opt = SGD(net.parameters(), lr=lr)
    gm = ad.GradManager().attach(net.parameters())

    data = Tensor(checkpoint["data"], dtype=np.float32)
    label = Tensor(checkpoint["label"], dtype=np.int32)

    if max_err is None:
        max_err = 1e-5

    train_func = train
    if use_jit:
        train_func = jit.trace(
            train_func,
            symbolic=use_symbolic,
            sublinear_memory_config=sublinear_memory_config,
        )

    opt.clear_grad()
    loss = train_func(data, label, net, opt, gm)
    opt.step()

    np.testing.assert_allclose(loss.numpy(), checkpoint["loss"], atol=max_err)

    for param, param_ref in zip(
        net.state_dict().items(), checkpoint["net_updated"].items()
    ):
        assert param[0] == param_ref[0]
        if "bn" in param[0]:
            ref = param_ref[1].reshape(param[1].shape)
            np.testing.assert_allclose(param[1], ref, atol=max_err)
        else:
            np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
Example #9
def worker(master_ip, master_port, world_size, rank, dev, trace):
    import megengine.distributed as dist
    import megengine.functional as F
    from megengine import is_cuda_available
    from megengine import jit
    from megengine.module import Linear, Module
    from megengine.optimizer import SGD

    if not is_cuda_available():
        return

    class MLP(Module):
        def __init__(self):
            super().__init__()
            self.fc0 = Linear(3 * 224 * 224, 500)
            self.fc1 = Linear(500, 10)

        def forward(self, x):
            x = self.fc0(x)
            x = F.relu(x)
            x = self.fc1(x)
            return x

    dist.init_process_group(master_ip=master_ip,
                            master_port=master_port,
                            world_size=world_size,
                            rank=rank,
                            dev=dev)
    net = MLP()

    opt = SGD(net.parameters(requires_grad=True), lr=0.02)

    data = np.random.random((64, 3 * 224 * 224)).astype(np.float32)
    label = np.random.randint(0, 10, size=(64, )).astype(np.int32)

    jit.trace.enabled = trace

    @jit.trace()
    def train_func(data, label):
        pred = net(data)
        loss = F.cross_entropy_with_softmax(pred, label)
        opt.backward(loss)
        return loss

    for i in range(5):
        opt.zero_grad()
        loss = train_func(data, label)
        opt.step()
Example #10
def test_training_converge_with_swap_and_drop():
    _set_swap_flag(True)
    _set_drop_flag(True)
    old_buffer_length = get_option("buffer_length")
    set_option("buffer_length", 0)
    net = XORNet()
    opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    gm = ad.GradManager().attach(net.parameters())

    def train(data, label):
        with gm:
            pred = net(data)
            loss = F.nn.cross_entropy(pred, label)
            gm.backward(loss)
        return loss

    def infer(data):
        return net(data)

    train_dataset = minibatch_generator()
    losses = []

    for data, label in itertools.islice(train_dataset, 2000):
        data = Tensor(data, dtype=np.float32)
        label = Tensor(label, dtype=np.int32)
        opt.clear_grad()
        loss = train(data, label)
        opt.step()
        losses.append(loss.numpy())

    assert np.mean(
        losses[-100:]) < 0.1, "Final training Loss must be low enough"

    ngrid = 10
    x = np.linspace(-1.0, 1.0, ngrid)
    xx, yy = np.meshgrid(x, x)
    xx = xx.reshape((ngrid * ngrid, 1))
    yy = yy.reshape((ngrid * ngrid, 1))
    data = np.concatenate((xx, yy), axis=1).astype(np.float32)

    pred = infer(Tensor(data)).numpy()
    precision = calculate_precision(data, pred)
    assert precision == 1.0, "Test precision must be high enough, get {}".format(
        precision)

    _set_swap_flag(False)
    _set_drop_flag(False)
    set_option("buffer_length", old_buffer_length)
Example #11
def run_test(
    model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None,
):

    """
    Load the model with test cases and run the training for one iter.
    The loss and updated weights are compared with the reference values to verify correctness.

    Dump a new file with updated results by calling update_model
    if you think the test fails due to numerical rounding errors instead of bugs.
    Please think twice before you do so.

    """
    net = MnistNet(has_bn=True)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    lr = checkpoint["sgd_lr"]
    opt = SGD(net.parameters(), lr=lr)

    data = tensor(dtype=np.float32)
    label = tensor(dtype=np.int32)
    data.set_value(checkpoint["data"])
    label.set_value(checkpoint["label"])

    if max_err is None:
        max_err = 1e-5

    train_func = train
    if use_jit:
        train_func = jit.trace(
            train_func,
            symbolic=use_symbolic,
            sublinear_memory_config=sublinear_memory_config,
        )

    opt.zero_grad()
    loss = train_func(data, label, net=net, opt=opt)
    opt.step()

    assertTensorClose(loss.numpy(), checkpoint["loss"], max_err=max_err)

    for param, param_ref in zip(
        net.state_dict().items(), checkpoint["net_updated"].items()
    ):
        assert param[0] == param_ref[0]
        assertTensorClose(param[1], param_ref[1], max_err=max_err)
Example #12
def test_training_converge(test_traced_module):
    net = XORNet()
    if test_traced_module:
        inp = Tensor(np.random.random((14, 2)))
        net = trace_module(net, inp)
    opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    gm = ad.GradManager().attach(net.parameters())

    @trace(symbolic=False)
    def train(data, label):
        with gm:
            pred = net(data)
            loss = F.nn.cross_entropy(pred, label)
            gm.backward(loss)
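            # clip the global gradient norm to at most 0.2 before the optimizer step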
            optim.clip_grad_norm(net.parameters(), max_norm=0.2, ord=2.0)
        return loss

    def infer(data):
        return net(data)

    train_dataset = minibatch_generator()
    losses = []

    for data, label in itertools.islice(train_dataset, 2000):
        data = Tensor(data, dtype=np.float32)
        label = Tensor(label, dtype=np.int32)
        opt.clear_grad()
        loss = train(data, label)
        optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1)
        opt.step()
        losses.append(loss.numpy())
    assert (np.mean(losses[-100:]) <
            0.1), "Final training Loss must be low enough, get {}".format(
                np.mean(losses[-100:]))

    ngrid = 10
    x = np.linspace(-1.0, 1.0, ngrid)
    xx, yy = np.meshgrid(x, x)
    xx = xx.reshape((ngrid * ngrid, 1))
    yy = yy.reshape((ngrid * ngrid, 1))
    data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
    pred = infer(data)
    precision = calculate_precision(data.numpy(), pred.numpy())
    assert precision == 1.0, "Test precision must be high enough, get {}".format(
        precision)
Example #13
def test_static_graph_parampack():
    net = XORNet()
    net = ParamPack(net,
                    nr_ignore_first=0,
                    max_size_per_group=10,
                    max_nr_params_per_group=100)
    opt = SGD(net.parameters(requires_grad=True),
              lr=0.01,
              momentum=0.9,
              weight_decay=5e-4)

    @trace(symbolic=True)
    def train(data, label):
        pred = net(data)
        opt.zero_grad()
        loss = cross_entropy_with_softmax(pred, label)
        opt.backward(loss)
        return loss

    @trace(symbolic=True)
    def infer(data):
        return net(data)

    train_dataset = minibatch_generator()
    losses = []

    for data, label in itertools.islice(train_dataset, 2000):
        loss = train(data, label)
        loss = loss[0][0]
        opt.step()
        losses.append(loss.numpy())

    assert np.mean(
        losses[-100:]) < 0.1, "Final training Loss must be low enough"

    ngrid = 10
    x = np.linspace(-1.0, 1.0, ngrid)
    xx, yy = np.meshgrid(x, x)
    xx = xx.reshape((ngrid * ngrid, 1))
    yy = yy.reshape((ngrid * ngrid, 1))
    data = np.concatenate((xx, yy), axis=1).astype(np.float32)

    pred = infer(data).numpy()
    assert calculate_precision(
        data, pred) == 1.0, "Test precision must be high enough"
Example #14
def test_optimizer_serialization():
    data, data_shape, label, label_shape = get_input()
    mlp = MLP()
    opt = SGD(mlp.parameters(), lr=0.01, momentum=0.9)
    slots = TensorDict()
    for param in mlp.parameters():
        slots[param] = np.zeros(param.shape).astype(np.float32)

    pred = mlp(data)
    loss = F.square_loss(pred, label.reshape(-1, 1))
    opt.zero_grad()
    opt.backward(loss)
    opt.step()
    for param in mlp.parameters():
        slot = slots[param]
        slot *= 0.9
        slot -= param.grad.numpy() * 0.01

    with BytesIO() as fout:
        save(opt.state_dict(), fout)
        fout.seek(0)
        state_dict = load(fout)
        opt1 = SGD(mlp.parameters(), lr=0.02, momentum=0.8)
        opt1.load_state_dict(state_dict)
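        # loading the state dict restores lr=0.01 and the momentum buffers, overriding the lr=0.02 / momentum=0.8 passed above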

        data.set_value(np.random.random(data_shape).astype(np.float32))
        label.set_value(np.random.randint(0, 10, label_shape))
        pred = mlp(data)
        loss = F.square_loss(pred, label.reshape(-1, 1))
        opt1.zero_grad()
        opt1.backward(loss)
        orig_params = TensorDict()
        for param in mlp.parameters():
            orig_params[param] = np.copy(param.numpy())
        opt1.step()
        for param in mlp.parameters():
            orig_param = orig_params[param]
            slot = slots[param]
            slot *= 0.9
            slot -= param.grad.numpy() * 0.01
            assertTensorClose(param.numpy(), orig_param + slot)
Example #15
def update_model(model_path):
    """
    Update the dumped model with test cases for new reference values.

    The model with pre-trained weights is trained for one iter with the test data attached.
    The loss and the updated net state dict are dumped.

    .. code-block:: python

        from test_dp_correctness import update_model
        update_model('mnist_model_with_test.mge') # for gpu
        update_model('mnist_model_with_test_cpu.mge') # for cpu

    """
    net = MnistNet(has_bn=True)
    checkpoint = mge.load(model_path)
    net.load_state_dict(checkpoint["net_init"])
    lr = checkpoint["sgd_lr"]
    opt = SGD(net.parameters(), lr=lr)

    gm = ad.GradManager().attach(
        net.parameters(),
        callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)])
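    # the "MEAN" all-reduce callback averages gradients across all workers during backward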

    data = Tensor(checkpoint["data"], dtype=np.float32)
    label = Tensor(checkpoint["label"], dtype=np.int32)

    opt.clear_grad()
    loss = train(data, label, net=net, opt=opt)
    opt.step()

    xpu_name = get_xpu_name()

    checkpoint.update({
        "net_updated": net.state_dict(),
        "loss": loss.numpy(),
        "xpu": xpu_name
    })
    mge.serialization.save(checkpoint, model_path)
Example #16
def test_compile_multi_times_eager():
    return  # XXX: rewrite or remove this test
    data = Input("data", shape=(2, 28))
    label = Input("label", shape=(2, ), dtype=np.int32)

    mlp = MLP()
    opt = SGD(mlp.parameters(requires_grad=True), lr=0.01)

    pred0 = mlp(data)
    pred = F.softmax(pred0)
    loss = F.square_loss(pred, label.reshape(2, 1))
    opt.zero_grad()
    grads = opt.backward(loss)
    opt.step()

    f0 = compile(pred, None)
    f1 = compile([pred, loss], grads, copy=False)
    for _ in range(3):
        data = np.random.random((2, 28)).astype(np.float32)
        label = np.random.randint(0, 10, (2, )).astype(np.float32)
        out0 = f0(data=data)
        out1 = f1(data=data, label=label)
        assertTensorClose(out0[0], out1[0])
Example #17
def test_training_converge():
    net = XORNet()
    opt = SGD(net.parameters(requires_grad=True),
              lr=0.01,
              momentum=0.9,
              weight_decay=5e-4)

    @trace
    def train(data, label):
        pred = net(data)
        opt.zero_grad()
        loss = cross_entropy_with_softmax(pred, label)
        opt.backward(loss)
        return loss

    @trace
    def infer(data):
        return net(data)

    train_dataset = minibatch_generator()
    losses = []

    for data, label in itertools.islice(train_dataset, 2000):
        # opt.zero_grad()
        loss = train(data, label)
        loss = loss[0][0]
        opt.step()
        losses.append(loss.numpy())

    assert np.mean(
        losses[-100:]) < 0.1, "Final training Loss must be low enough"

    data, _ = next(train_dataset)
    pred = infer(data).numpy()
    assert calculate_precision(
        data, pred) > 0.95, "Test precision must be high enough"
Example #18
def test_release_memory():
    mnist_datasets = load_mnist_datasets()
    data_train, label_train = mnist_datasets["train"]

    batch_size = 15000
    data_shape = (batch_size, 1, 28, 28)
    label_shape = (batch_size, )

    data = nn.Input("data", shape=data_shape, dtype=np.float32)
    label = nn.Input("label",
                     shape=label_shape,
                     dtype=np.int32,
                     value=np.zeros(label_shape))

    net = MnistNet()
    opt = SGD(net.parameters(), lr=0.01)

    pred = F.softmax(net(data))
    loss = F.cross_entropy(pred, label)

    opt.zero_grad()
    opt.backward(loss)
    add_updates = opt.step()

    mge.graph._default_graph.get_default().clear_device_memory()

    f = mge.graph.compile(loss, add_updates)

    for _ in range(3):
        train_loss = 0.0
        for i in range(0, data_train.shape[0], batch_size):
            opt.zero_grad()
            data = data_train[i:i + batch_size, :, :, :]
            label = label_train[i:i + batch_size]
            loss = f(data=data, label=label)[0]
            train_loss += loss[0]
Example #19
def run_perf(
    batch_size=64,
    warm_up=True,
    dump_prof=None,
    opt_level=2,
    conv_fastrun=False,
    run_step=True,
    track_bn_stats=True,
    warm_up_iter=20,
    run_iter=100,
    num_gpu=None,
    device=0,
    server=None,
    port=None,
    scale_batch_size=False,
    eager=False,
):

    if conv_fastrun:
        set_conv_execution_strategy("PROFILE")

    if num_gpu:
        dist.init_process_group(server, port, num_gpu, device, device)
        if scale_batch_size:
            batch_size = batch_size // num_gpu
        print("Run with data parallel, batch size = {} per GPU".format(
            batch_size))

    data = tensor(np.random.randn(batch_size, 3, 224, 224).astype("float32"))
    label = tensor(np.random.randint(1000, size=[
        batch_size,
    ], dtype=np.int32))

    net = Resnet50(track_bn_stats=track_bn_stats)
    opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

    def train_func(data, label):
        logits = net(data)
        loss = F.cross_entropy_with_softmax(logits, label)

        if num_gpu:
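            # scale the loss so that the aggregated data-parallel gradient matches single-GPU training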
            loss = loss / num_gpu

        opt.zero_grad()
        opt.backward(loss)
        return loss

    train_func = trace(
        train_func,
        symbolic=(not eager),
        opt_level=opt_level,
        profiling=not (dump_prof is None),
    )

    if warm_up:
        print("Warm up ...")
        for _ in range(warm_up_iter):
            opt.zero_grad()
            train_func(data, label)
            if run_step:
                opt.step()
    print_gpu_usage()
    print("Running train ...")
    start = time.time()
    for _ in range(run_iter):
        opt.zero_grad()
        train_func(data, label)
        if run_step:
            opt.step()

    time_used = time.time() - start

    if dump_prof:
        with open(dump_prof, "w") as fout:
            json.dump(train_func.get_profile(), fout, indent=2)

    return time_used / run_iter
Example #20
import numpy as np

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager
from megengine.logger import get_logger
from megengine.optimizer import SGD

from network import LeNet
from data import train_dataloader

logger = get_logger(__name__)

net = LeNet()

optimizer = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

gm = GradManager().attach(net.parameters())

net.train()
total_epochs = 5
for epoch in range(total_epochs):
    total_loss = 0
    for step, (batch_data, batch_label) in enumerate(train_dataloader):
        batch_data = mge.tensor(batch_data)
        batch_label = mge.tensor(batch_label).astype(np.int32)

        with gm:
            pred = net(batch_data)
            loss = F.loss.cross_entropy(pred, batch_label)
            gm.backward(loss)
        optimizer.step().clear_grad()

        total_loss += loss.numpy().item()
        if step % 100 == 0:
            logger.info("epoch: {}, iter: {}, loss {}".format(
                epoch, step, total_loss / len(train_dataloader)))
    logger.info("epoch: {}, loss {}".format(epoch, total_loss /
                                            len(train_dataloader)))

mge.save(net.state_dict(), 'mnist_net.mge')