Code Example #1
def test_basic():
    x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
    w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)
    b = mge.tensor(-1.0)

    gm = GradManager().attach([w, b])
    gm.record()

    p = F.matmul(x, w)
    y = p + b

    gm.backward(y)
    gm.release()  # release() is not necessary here
    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
    np.testing.assert_equal(b.grad.numpy(), [1])

    w.grad = None
    b.grad = None
    with gm:
        p = F.matmul(x, w)
        y = p + b
        gm.backward(y)

    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
    np.testing.assert_equal(b.grad.numpy(), [1])
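
These listings are extracted snippets and omit their import statements. A minimal preamble along the following lines (an assumption based on the names used in Example #1, not taken from the original file) is enough to run it:

import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager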
Code Example #2
def train(dataloader, args):
    writer = SummaryWriter("runs")
    net = Net()
    net.train()
    optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
    gm = GradManager().attach(net.parameters())

    epoch_length = len(dataloader)
    for epoch in range(args.epoch):
        for step, (batch_data, batch_label) in enumerate(dataloader):
            batch_label = batch_label.astype(np.int32)
            data, label = mge.tensor(batch_data), mge.tensor(batch_label)
            with gm:
                pred = net(data)
                loss = F.loss.cross_entropy(pred, label)
                gm.backward(loss)
            optimizer.step().clear_grad()

            if step % 50 == 0:
                print("epoch:{}, iter:{}, loss:{}".format(epoch + 1, step, float(loss)))  # noqa
            writer.add_scalar("loss", float(loss), epoch * epoch_length + step)
        if (epoch + 1) % 5 == 0:
            mge.save(
                net.state_dict(), os.path.join(gettempdir(), f"mnist_net_e{epoch + 1}.pkl"),
            )  # noqa
Code Example #3
File: test_functional.py  Project: mozre/MegEngine
 def test_dropout_with_shape(shape, rate):
     data = tensor(np.ones(shape, dtype=np.float32))
     gm = GradManager().attach([data])
     with gm:
         out = F.nn.dropout(data, rate, training=True)
         gm.backward(out, tensor(np.ones(shape, dtype=np.float32)))
         assert not out.numpy().all()
         np.testing.assert_allclose(out.numpy(), data.grad.numpy(), 1e-7,
                                    1e-7)
Code Example #4
 def jvp(inp, expr):
     with GradManager() as gm:
         with GradManager().attach([inp]) as gm2:
             oup = expr(inp)
             oup_grad = F.zeros_like(oup)
             gm.attach(oup_grad)
             gm2.backward(oup, oup_grad)
         gm.backward(inp.grad)
     return oup, oup_grad.grad
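
A hypothetical call sketch for the helper above (the tensor x and the choice of F.exp are illustrative, not from the original file):

x = mge.tensor(np.random.rand(4).astype("float32"))
oup, jv = jvp(x, F.exp)   # oup = exp(x); jv is d(x.grad)/d(oup_grad) from the outer backward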
Code Example #5
File: test_functional.py  Project: mozre/MegEngine
 def test_multiple_dropout(shape, rate):
     data = tensor(np.ones(shape, dtype=np.float32))
     gm = GradManager().attach([data])
     with gm:
         out1 = F.nn.dropout(data, rate, training=True)
         out2 = F.nn.dropout(out1, rate, training=True)
         out3 = F.nn.dropout(out2, rate, training=True)
         gm.backward(out3, tensor(np.ones(shape, dtype=np.float32)))
         np.testing.assert_allclose(out3.numpy(), data.grad.numpy(), 1e-7,
                                    1e-7)
Code Example #6
File: test_tracing.py  Project: mozre/MegEngine
def test_dump_backward_graph():
    x0 = tensor(np.random.randn(3, 4))
    x1 = tensor(np.random.randn(3, 4))

    gm = GradManager().attach(x0)

    @trace(symbolic=True, capture_as_const=True)
    def f(x0, x1):
        with gm:
            y = x0 * x1
            gm.backward(y, F.ones_like(y))
            dx0 = x0.grad
        return y, dx0

    y, dx0 = f(x0, x1)
    np.testing.assert_equal(dx0.numpy(), x1)

    file = io.BytesIO()
    f.dump(file, optimize_for_inference=False)
    file.seek(0)

    infer_cg = cgtools.GraphInference(file)
    results = list((infer_cg.run(x0, x1)).values())

    np.testing.assert_equal(results[0], y)
    np.testing.assert_equal(results[1], dx0)
Code Example #7
File: train.py  Project: lhaippp/GyroFlow
def train_and_evaluate(model, manager):
    rank = dist.get_rank()

    # reload weights from restore_file if specified
    if args.restore_file is not None:
        manager.load_checkpoints()

    world_size = dist.get_world_size()
    if world_size > 1:
        dist.bcast_list_(model.parameters())
        dist.bcast_list_(model.buffers())

    gm = GradManager().attach(
        model.parameters(),
        callbacks=dist.make_allreduce_cb("SUM") if world_size > 1 else None,
    )

    for epoch in range(manager.params.num_epochs):
        # compute number of batches in one epoch (one full pass over the training set)
        train(model, manager, gm)

        # Evaluate for one epoch on validation set
        evaluate(model, manager)

        # Save the best model weights according to params.major_metric
        if rank == 0:
            manager.check_best_save_last_checkpoints(latest_freq=5)
Code Example #8
def train(dataloader, net, opt):
    logger.info("***** Running training *****")
    logger.info("batch size = %d", args.train_batch_size)
    sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0

    gm = GradManager().attach(net.parameters())

    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
        input_ids, input_mask, segment_ids, label_ids = tuple(
            mge.tensor(t) for t in batch)
        batch_size = input_ids.shape[0]
        loss, logits, label_ids = net_train(input_ids,
                                            segment_ids,
                                            input_mask,
                                            label_ids,
                                            gm=gm,
                                            net=net)
        opt.step().clear_grad()
        sum_loss += loss.mean().item()
        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
        total_examples += batch_size
        total_steps += 1

    result = {
        "train_loss": sum_loss / total_steps,
        "train_accuracy": sum_accuracy / total_examples,
    }

    logger.info("***** Train results *****")
    for key in sorted(result.keys()):
        logger.info("%s = %s", key, str(result[key]))
Code Example #9
def test_output_copy_trace():
    class Simple(Module):
        def __init__(self):
            super().__init__()
            self.a = Parameter([1.0], dtype=np.float32)

        def forward(self, x):
            x = x * self.a
            # will result in a copy of the output during grad
            x = F.exp(x)
            return x

    ys = {False: [], True: []}

    for symbolic in [False, True]:
        net = Simple()
        gm = GradManager().attach(net.parameters())
        opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
        data = tensor(np.arange(4).reshape(2, 2), dtype="float32")

        @trace(symbolic=symbolic)
        def train_func(d):
            with gm:
                loss = net(d)
                gm.backward(loss)
                opt.step().clear_grad()
            return loss

        for i in range(3):
            y = train_func(data).numpy()
            ys[symbolic].append(y)

    for i in range(3):
        np.testing.assert_equal(ys[False][i], ys[True][i])
Code Example #10
    def worker():
        rank = dist.get_rank()
        size = dist.get_world_size()
        x = mge.tensor(np.random.randn(1, rank * 2 + 2), dtype=np.float32)
        m = M.Linear(rank * 2 + 2, rank * 2 + 4)
        gm = GradManager().attach(m.parameters())
        opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)

        def train_func(x):
            with gm:
                if rank != 0:
                    x = dist.functional.remote_recv(rank - 1,
                                                    shape=(1, rank * 2 + 2),
                                                    dtype=np.float32)
                y = m(x)
                if rank != size - 1:
                    dist.functional.remote_send(y, dest_rank=rank + 1)
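                    # the gradient of y flows back from rank + 1 via remote_send's backward,
                    # so backward() is called here without an explicit target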
                    gm.backward()
                else:
                    y = y.mean()
                    gm.backward(y)
                opt.step().clear_grad()

        train_funcs = [
            train_func,
            trace(symbolic=False)(train_func),
            trace(symbolic=True)(train_func),
        ]

        for func in train_funcs:
            for i in range(3):
                func(x)
Code Example #11
 def func():
     with GradManager().attach(m.parameters()) as gm:
         if dist.get_rank() == 0:
             y = m(x)
         else:
             y = x
         y = F.distributed.broadcast(y)
         gm.backward(y)
Code Example #12
 def func():
     with GradManager().attach(m.parameters()) as gm:
         y = m(x)
         y = F.distributed.reduce_sum(y)
         if dist.get_rank() == 0:
             loss = (2 * y + 1).mean()
             gm.backward(loss)
         else:
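             # ranks other than 0 have no local loss; their parameter gradients flow back
             # through the reduce_sum collective, so backward() takes no explicit target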
             gm.backward()
Code Example #13
def worker(args):
    current_network = import_from_file(args.file)

    model = current_network.Net(current_network.Cfg())
    model.train()

    if dist.get_rank() == 0:
        logger.info(get_config_info(model.cfg))
        logger.info(repr(model))

    params_with_grad = []
    for name, param in model.named_parameters():
        if "bottom_up.conv1" in name and model.cfg.backbone_freeze_at >= 1:
            continue
        if "bottom_up.layer1" in name and model.cfg.backbone_freeze_at >= 2:
            continue
        params_with_grad.append(param)

    opt = SGD(
        params_with_grad,
        lr=model.cfg.basic_lr * args.batch_size,
        momentum=model.cfg.momentum,
        weight_decay=model.cfg.weight_decay * dist.get_world_size(),
    )

    gm = GradManager()
    if dist.get_world_size() > 1:
        gm.attach(params_with_grad,
                  callbacks=[dist.make_allreduce_cb("SUM", dist.WORLD)])
    else:
        gm.attach(params_with_grad)

    if args.weight_file is not None:
        # model.backbone.bottom_up.load_state_dict(weights, strict=False)
        logger.info("Loading Base-Pretrain weights...")
        weights = mge.load(args.weight_file)
        weight_new = {k: v for k, v in weights.items() if 'pred_' not in k}
        model.load_state_dict(weight_new, strict=False)

    if dist.get_world_size() > 1:
        dist.bcast_list_(model.parameters(), dist.WORLD)  # sync parameters

    if dist.get_rank() == 0:
        logger.info("Prepare dataset")
    train_loader = iter(
        build_dataloader(args.batch_size, args.dataset_dir, model.cfg))

    for epoch in range(model.cfg.max_epoch):
        train_one_epoch(model, train_loader, opt, gm, epoch, args)
        if dist.get_rank() == 0:
            save_path = "logs/{}/epoch_{}.pkl".format(
                os.path.basename(args.file).split(".")[0], epoch)
            mge.save(
                {
                    "epoch": epoch,
                    "state_dict": model.state_dict()
                },
                save_path,
            )
            logger.info("dump weights to %s", save_path)
Code Example #14
def test_regression_1762():
    x = F.ones((10, 10, 3, 3))

    conv = M.Conv2d(10, 10, kernel_size=3, padding=1)

    t_shape = (1, 10, 1, 1)
    weight = mge.Parameter(np.ones(t_shape, dtype=np.float32))
    bias = mge.Parameter(np.zeros(t_shape, dtype=np.float32))

    gm = GradManager()
    gm.attach(list(conv.parameters()) + [weight, bias])

    with gm:
        out1 = conv(x)

        out2 = F.batch_norm(
            out1,
            None,
            None,
            weight,
            bias,
            training=True,
        )

        # The weird error only occurs when this op is placed after the BN call
        # The op type is not relevant
        loss = out1 + 1
        gm.backward(loss)
Code Example #15
def test_no_dependency():
    x = mge.tensor(3)

    w = mge.Parameter(1.0)
    w_no_dep = mge.Parameter(1.0)
    gm = GradManager()
    gm.attach(w)
    gm.attach(w_no_dep)

    with gm:
        out1 = x * w
        out2 = w_no_dep * out1
        gm.backward(out1.sum())

    assert w.grad is not None
    assert w_no_dep.grad is None
Code Example #16
def test_grad_manager_group():
    x_np = np.random.rand(10).astype("float32")
    x = mge.tensor(x_np)

    gm = GradManager().attach([x])
    gm2 = GradManager().attach([x])

    with gm | gm2:
        y = F.cos(x)
        gm.backward(y)
        gm2.backward(y)
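    # both backward() calls accumulate into x.grad, hence the factor of 2 below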
    np.testing.assert_almost_equal(x.grad.numpy(),
                                   -2 * np.sin(x_np),
                                   decimal=5)

    x.grad = None
Code Example #17
def test_attach_in_with_block():
    a = mge.Parameter([1.0])
    gm = GradManager()
    with gm:
        b = a * 3
        gm.attach(b)
        c = b + 1
        gm.backward(c)
    assert int(b.grad.numpy()) == 1
Code Example #18
def test_grad_manager_visibility_by_order():
    x_np = np.random.rand(10).astype("float32")
    x = mge.tensor(x_np)

    gm = GradManager().attach([x])
    gm2 = GradManager().attach([x])

    with gm2:
        with gm:
            y = F.cos(x)
            gm2.backward(y)
            np.testing.assert_almost_equal(x.grad.numpy(),
                                           -np.sin(x_np),
                                           decimal=5)
            gm.backward(x.grad)

    np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
Code Example #19
def test_dy():
    x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
    w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)
    b = mge.tensor(-1.0)

    gm = GradManager().attach([w, b])

    def get_grad(grad, dy, idx):
        if isinstance(dy, (list, tuple)):
            return np.array(grad) * dy[idx]
        else:
            return np.array(grad) * dy

    # dy's shape should be the same as y's
    dy = mge.tensor(2.5).reshape(1, 1)
    w.grad = None
    b.grad = None
    with gm:
        p = F.matmul(x, w)
        y = p + b
        gm.backward(y, dy=dy)

    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]] * dy.numpy())
    np.testing.assert_equal(b.grad.numpy(), [1] * dy.numpy())
Code Example #20
def test_attached_tensors():
    w1 = mge.Parameter(2.0)
    w2 = mge.Parameter(2.0)
    gm = GradManager()

    def check(expected):
        actual = gm.attached_tensors()
        assert len(expected) == len(actual)
        for exp, act in zip(expected, actual):
            assert exp is act

    gm.attach(w1)
    check([w1])
    gm.attach(w2)
    check([w1, w2])
    gm.attach(w1)
    check([w1, w2])
Code Example #21
def test_2nd_grad_with_custom_gradient():
    class MySin(Function):
        def forward(self, x):
            self.inp = x
            x = mge.Tensor(x.numpy())
            y = F.sin(x)
            return y

        def backward(self, dy):
            dx = F.cos(self.inp) * dy
            return dx

    class MyCos(Function):
        def forward(self, x):
            self.inp = x
            x = mge.Tensor(x.numpy())
            y = F.cos(x)
            return y

        def backward(self, dy):
            if dy is None:
                return None
            dx = -MySin()(self.inp) * dy
            return dx

    x_np = np.random.rand(10).astype("float32")
    x = mge.tensor(x_np)

    gm = GradManager().attach([x])
    gm2 = GradManager().attach([x])

    with gm:
        with gm2:
            y = MyCos()(x)
            gm2.backward(y)
        np.testing.assert_almost_equal(x.grad.numpy(),
                                       -np.sin(x_np),
                                       decimal=5)
        gm.backward(x.grad)
    np.testing.assert_almost_equal(x.grad.numpy(),
                                   -np.sin(x_np) - np.cos(x_np),
                                   decimal=5)
Code Example #22
def test_empty_grad_in_backward():
    x = mge.Parameter(F.full(100, 0.5))
    y = mge.Parameter(F.ones(100))

    gm = GradManager()
    gm.attach([x, y])

    with gm:
        z = F.where(x > 0.7, x, y)
        loss = z.sum()
        gm.backward(loss)
        assert np.all(x.grad.numpy() == 0)
        assert np.all(y.grad.numpy() == 1)
Code Example #23
def test_attach_temporary():
    w = mge.Parameter(2.0)
    gm = GradManager()
    gm.attach(w)

    def cb(x, g):
        assert x is ref()
        cb.called = True

    for i in range(3):
        with gm:
            cb.called = False
            x = mge.Tensor(i, dtype="float32")
            gm.attach(x, callbacks=cb)
            ref = weakref.ref(x)
            y = x * w
            gm.backward(y)
            assert cb.called
        del x
        assert ref() is None
Code Example #24
def test_elemwise_fuse_in_grad(trace_mode):
    w = Parameter(np.ones([4, 6]), dtype="float32")

    gm = GradManager().attach(w)
    opt = optim.SGD([w], lr=0.01, momentum=0.9, weight_decay=5e-4)

    # explicitly declare opt_level as 2
    @trace(symbolic=trace_mode, opt_level=2)
    def f():
        with gm:
            wm = F.sum(w**2, axis=1)**0.5
            loss = wm.mean()
            gm.backward(loss)
            opt.step().clear_grad()
        return loss

    for i in range(3):
        y = f()
        y.numpy()
Code Example #25
File: test_dtr.py  Project: mozre/MegEngine
def run_dtr_resnet1202():
    batch_size = 8
    resnet1202 = ResNet(BasicBlock, [200, 200, 200])
    opt = optim.SGD(resnet1202.parameters(),
                    lr=0.05,
                    momentum=0.9,
                    weight_decay=1e-4)
    gm = GradManager().attach(resnet1202.parameters())

    def train_func(data, label, *, net, gm):
        net.train()
        with gm:
            pred = net(data)
            loss = F.loss.cross_entropy(pred, label)
            gm.backward(loss)
        return pred, loss

    _, free_mem = mge.device.get_mem_status_bytes()
    tensor_mem = free_mem - (2**30)
    if tensor_mem > 0:
        x = np.ones((1, int(tensor_mem / 4)), dtype=np.float32)
    else:
        x = np.ones((1, ), dtype=np.float32)
    t = mge.tensor(x)

    mge.dtr.enable()
    mge.dtr.enable_sqrt_sampling = True

    data = np.random.randn(batch_size, 3, 32, 32).astype("float32")
    label = np.random.randint(0, 10, size=(batch_size, )).astype("int32")
    for _ in range(2):
        opt.clear_grad()
        _, loss = train_func(mge.tensor(data),
                             mge.tensor(label),
                             net=resnet1202,
                             gm=gm)
        opt.step()
        loss.item()

    t.numpy()
    mge.dtr.disable()
    mge._exit(0)
Code Example #26
    def f():
        gm = GradManager()
        scaler = GradScaler()

        x = mge.tensor(1.0)
        for _ in range(3):
            with gm:
                y = x + 1
                gm.attach(y)
                loss = y + 1
                scaler.backward(gm, loss, unscale_grad=False)
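            # with unscale_grad=False the accumulated grads stay multiplied by scale_factor
            # until unscale() is called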
            np.testing.assert_equal(y.grad.numpy(), scaler.scale_factor)
            scaler.unscale(gm.attached_tensors())
            np.testing.assert_equal(y.grad.numpy(), 1)
        # test that entries with a None grad are handled
        scaler.unscale(gm.attached_tensors())
Code Example #27
File: test_tracing.py  Project: wenming2014/MegEngine
def test_output_copy_trace():
    class Simple(Module):
        def __init__(self):
            super().__init__()
            self.a = Parameter([1.0], dtype=np.float32)

        def forward(self, x):
            x = x * self.a
            # will result in a copy of the output during grad
            x = F.exp(x)
            return x

    net = Simple()

    gm = GradManager().attach(net.parameters())
    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
    data = tensor(np.arange(4).reshape(2, 2), dtype="float32")

    @trace(symbolic=False)
    def train_f1(d):
        with gm:
            loss = net(d)
            gm.backward(loss)
            opt.step().clear_grad()
        return loss

    @trace(symbolic=True)
    def train_f2(d):
        with gm:
            loss = net(d)
            gm.backward(loss)
            opt.step().clear_grad()
        return loss

    for i in range(2):
        y1 = train_f1(data).numpy()
        y2 = train_f2(data).numpy()
        np.testing.assert_equal(y1, y2)
Code Example #28
File: train.py  Project: zzh7982/Models
def worker(args):
    current_network = import_from_file(args.file)

    model = current_network.Net(current_network.Cfg())
    model.train()

    if dist.get_rank() == 0:
        logger.info(get_config_info(model.cfg))
        logger.info(repr(model))

    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if "backbone" in name:
            backbone_params.append(param)
        else:
            head_params.append(param)

    opt = SGD(
        [
            {
                "params": backbone_params,
                "lr": model.cfg.learning_rate * 0.1
            },
            {
                "params": head_params
            },
        ],
        lr=model.cfg.learning_rate,
        momentum=model.cfg.momentum,
        weight_decay=model.cfg.weight_decay * dist.get_world_size(),
    )

    gm = GradManager()
    if dist.get_world_size() > 1:
        gm.attach(model.parameters(),
                  callbacks=[dist.make_allreduce_cb("SUM", dist.WORLD)])
    else:
        gm.attach(model.parameters())

    cur_epoch = 0
    if args.resume is not None:
        pretrained = mge.load(args.resume)
        cur_epoch = pretrained["epoch"] + 1
        model.load_state_dict(pretrained["state_dict"])
        opt.load_state_dict(pretrained["opt"])
        if dist.get_rank() == 0:
            logger.info("load success: epoch %d", cur_epoch)

    if dist.get_world_size() > 1:
        dist.bcast_list_(model.parameters(), dist.WORLD)  # sync parameters

    if dist.get_rank() == 0:
        logger.info("Prepare dataset")
    train_loader = iter(
        build_dataloader(model.cfg.batch_size, args.dataset_dir, model.cfg))

    for epoch in range(cur_epoch, model.cfg.max_epoch):
        train_one_epoch(model, train_loader, opt, gm, epoch)
        if dist.get_rank() == 0:
            save_path = "log-of-{}/epoch_{}.pkl".format(
                os.path.basename(args.file).split(".")[0], epoch)
            mge.save(
                {
                    "epoch": epoch,
                    "state_dict": model.state_dict(),
                    "opt": opt.state_dict()
                }, save_path)
            logger.info("dump weights to %s", save_path)
Code Example #29
def build_gradmanager(module):
    world_size = dist.get_world_size()
    gm = GradManager().attach(
        module.parameters(),
        callbacks=dist.make_allreduce_cb("SUM") if world_size > 1 else None)
    return gm
Code Example #30
File: train.py  Project: zzh7982/Models
def worker(master_ip, port, rank, world_size, args):
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip=master_ip,
            port=port,
            world_size=world_size,
            rank=rank,
            device=rank,
        )

    model_name = "{}_{}x{}".format(args.arch, cfg.input_shape[0], cfg.input_shape[1])
    save_dir = os.path.join(args.save, model_name)

    model = getattr(kpm, args.arch)()
    model.train()
    start_epoch = 0
    if args.resume is not None:
        file = mge.load(args.resume)
        model.load_state_dict(file["state_dict"])
        start_epoch = file["epoch"]

    optimizer = optim.Adam(
        model.parameters(), lr=cfg.initial_lr, weight_decay=cfg.weight_decay
    )

    gm = GradManager()
    if dist.get_world_size() > 1:
        gm.attach(
            model.parameters(), callbacks=[dist.make_allreduce_cb("SUM", dist.WORLD)],
        )
    else:
        gm.attach(model.parameters())

    if dist.get_world_size() > 1:
        dist.bcast_list_(model.parameters(), dist.WORLD)  # sync parameters

    # Build train datasets
    logger.info("preparing dataset..")
    ann_file = os.path.join(
        cfg.data_root, "annotations", "person_keypoints_train2017.json"
    )
    train_dataset = COCOJoints(
        cfg.data_root,
        ann_file,
        image_set="train2017",
        order=("image", "keypoints", "boxes", "info"),
    )
    logger.info("Num of Samples: {}".format(len(train_dataset)))
    train_sampler = data.RandomSampler(
        train_dataset, batch_size=cfg.batch_size, drop_last=True
    )

    transforms = [
        T.Normalize(mean=cfg.img_mean, std=cfg.img_std),
        RandomHorizontalFlip(0.5, keypoint_flip_order=cfg.keypoint_flip_order)
    ]

    if cfg.half_body_transform:
        transforms.append(
            HalfBodyTransform(
                cfg.upper_body_ids, cfg.lower_body_ids, cfg.prob_half_body
            )
        )
    if cfg.extend_boxes:
        transforms.append(
            ExtendBoxes(cfg.x_ext, cfg.y_ext, cfg.input_shape[1] / cfg.input_shape[0])
        )

    transforms += [
        RandomBoxAffine(
            degrees=cfg.rotate_range,
            scale=cfg.scale_range,
            output_shape=cfg.input_shape,
            rotate_prob=cfg.rotation_prob,
            scale_prob=cfg.scale_prob,
        )
    ]
    transforms += [T.ToMode()]

    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        num_workers=args.workers,
        transform=T.Compose(transforms=transforms, order=train_dataset.order,),
        collator=HeatmapCollator(
            cfg.input_shape,
            cfg.output_shape,
            cfg.keypoint_num,
            cfg.heat_thr,
            cfg.heat_kernels if args.multi_scale_supervision else cfg.heat_kernels[-1:],
            cfg.heat_range,
        ),
    )

    # Start training
    for epoch in range(start_epoch, cfg.epochs):
        loss = train(model, train_queue, optimizer, gm, epoch=epoch)
        logger.info("Epoch %d Train %.6f ", epoch, loss)

        if rank == 0 and epoch % cfg.save_freq == 0:  # save checkpoint
            mge.save(
                {"epoch": epoch + 1, "state_dict": model.state_dict()},
                os.path.join(save_dir, "epoch_{}.pkl".format(epoch)),
            )