Пример #1
0
def build_dataset(args):
    """Build the ImageNet train and validation data loaders.

    The training pipeline depends on the architecture: ``resnet18`` and
    ``resnet34`` use the baseline augmentation (random resized crop and
    horizontal flip); every other architecture additionally applies colour
    jitter. All pipelines normalise with the same BGR statistics and emit
    channel-first images.

    Args:
        args: parsed command-line namespace; reads ``data`` (dataset root),
            ``batch_size`` (training batch size), ``arch`` (model name) and
            ``workers`` (number of loader worker processes).

    Returns:
        ``(train_dataloader, valid_dataloader)``. The training loader draws
        random batches forever (``data.Infinite``) and drops the last partial
        batch; the validation loader walks the validation set sequentially in
        batches of 100 and keeps the last partial batch.
    """
    # Per-channel statistics in BGR order. Defined once so the three
    # pipelines below cannot drift apart.
    bgr_mean = [103.530, 116.280, 123.675]
    bgr_std = [57.375, 57.120, 58.395]

    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset,
                           batch_size=args.batch_size,
                           drop_last=True))
    # Baseline augmentation for small models; larger models additionally get
    # colour jitter (the "Facebook" augmentation recipe).
    train_transforms = [T.RandomResizedCrop(224), T.RandomHorizontalFlip()]
    if args.arch not in ("resnet18", "resnet34"):
        train_transforms.append(
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4))
    train_transforms += [
        T.Normalize(mean=bgr_mean, std=bgr_std),
        T.ToMode("CHW"),
    ]
    train_dataloader = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose(train_transforms),
        num_workers=args.workers,
    )

    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(valid_dataset,
                                           batch_size=100,
                                           drop_last=False)
    valid_dataloader = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=bgr_mean, std=bgr_std),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )
    return train_dataloader, valid_dataloader
Пример #2
0
def build_dataset(args):
    """Build the SIDD training and validation data loaders for one process.

    Global batch sizes are split evenly across ``args.ngpus`` processes:
    training uses ``args.batch_size // args.ngpus`` samples per process and
    validation uses ``4 // args.ngpus``. The training set is sized to
    ``args.batch_size * args.steps_per_epoch`` samples and sampled forever.

    Args:
        args: namespace providing ``data``, ``batch_size``,
            ``steps_per_epoch``, ``ngpus`` and ``workers``.

    Returns:
        ``(train_dataloader, valid_dataloader)``.

    Raises:
        ValueError: if ``args.ngpus`` is smaller than 1, or if either
            per-process batch size would round down below 1.
    """
    # Validate explicitly instead of with ``assert``: assertions are removed
    # under ``python -O`` and a zero batch size would then slip through.
    if args.ngpus < 1:
        raise ValueError("ngpus must be >= 1, got {}".format(args.ngpus))
    train_batch_size = args.batch_size // args.ngpus
    valid_batch_size = 4 // args.ngpus
    if train_batch_size < 1:
        raise ValueError(
            "batch_size ({}) is too small to split across {} GPUs".format(
                args.batch_size, args.ngpus))
    if valid_batch_size < 1:
        raise ValueError(
            "validation batch of 4 cannot be split across {} GPUs".format(
                args.ngpus))

    train_dataset = SIDDData(args.data, length=args.batch_size*args.steps_per_epoch)
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset, batch_size=train_batch_size, drop_last=True)
    )
    train_dataloader = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        num_workers=args.workers,
    )
    valid_dataset = SIDDValData(args.data)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=valid_batch_size, drop_last=False
    )
    valid_dataloader = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        num_workers=args.workers,
    )
    return train_dataloader, valid_dataloader
Пример #3
0
def worker(rank, world_size, args):
    """Train ``args.arch`` on ImageNet inside one (possibly distributed) process.

    Runs SGD for steps ``step_start .. args.steps`` (inclusive) with a
    learning rate that decays linearly to zero at ``args.steps``. Rank 0
    logs training metrics every ``args.report_freq`` steps and writes a
    checkpoint every 10000 steps and at the end. Every rank takes part in
    validation every 50000 steps and once after training, because the
    validation graph all-reduces its metrics in distributed mode.

    Args:
        rank: index of this process; also used as its device id.
        world_size: total number of training processes.
        args: parsed command-line options. Reads ``save``, ``arch``,
            ``model``, ``learning_rate``, ``momentum``, ``weight_decay``,
            ``data``, ``batch_size``, ``workers``, ``steps`` and
            ``report_freq`` here (``infer`` may read more).
    """
    # pylint: disable=too-many-statements
    save_dir = os.path.join(args.save, args.arch)
    # Create the output directory before the log file or any checkpoint is
    # written into it.
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    model = getattr(M, args.arch)()
    step_start = 0
    if args.model:
        logger.info("load weights from %s", args.model)
        model.load_state_dict(mge.load(args.model))
        # Resume from the step encoded in the checkpoint file name
        # ("checkpoint-XXXXXX.pkl"). Only the base name is parsed so that a
        # '-' in a directory component cannot break the lookup.
        ckpt_name = os.path.basename(args.model)
        step_start = int(ckpt_name.rsplit("-", 1)[-1].split(".")[0])

    optimizer = optim.SGD(
        get_parameters(model),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "train_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "train_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "train_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset,
                           batch_size=args.batch_size,
                           drop_last=True))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )
    # The loop below pulls batches with next(), which needs an iterator
    # rather than the DataLoader object itself.
    train_queue = iter(train_queue)

    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(valid_dataset,
                                           batch_size=100,
                                           drop_last=False)
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    for step in range(step_start, args.steps + 1):
        # Linear learning rate decay, reaching zero at args.steps.
        decay = 1 - float(step) / args.steps if step < args.steps else 0
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.learning_rate * decay

        image, label = next(train_queue)
        time_data = time.time() - t
        image = image.astype("float32")
        label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        time_iter = time.time() - t
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logger.info(
                "TRAIN Iter %06d: lr = %f,\tloss = %f,\twc_loss = 1,\tTop-1 err = %f,\tTop-5 err = %f,\tdata_time = %f,\ttrain_time = %f,\tremain_hours=%f",
                step,
                args.learning_rate * decay,
                float(objs.__str__().split()[1]),
                1 - float(top1.__str__().split()[1]) / 100,
                1 - float(top5.__str__().split()[1]) / 100,
                time_data,
                time_iter - time_data,
                time_iter * (args.steps - step) / 3600,
            )
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step % 10000 == 0 and rank == 0 and step != 0:
            logger.info("SAVING %06d", step)
            mge.save(
                model.state_dict(),
                os.path.join(save_dir, "checkpoint-{:06d}.pkl".format(step)),
            )
        if step % 50000 == 0 and step != 0:
            valid_loss, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info(
                "TEST Iter %06d: loss = %f,\tTop-1 err = %f,\tTop-5 err = %f",
                step, valid_loss, 1 - valid_acc / 100, 1 - valid_acc5 / 100)

    # The loop always finishes at args.steps; using it directly keeps the
    # final checkpoint name defined even if the loop body never ran.
    final_step = args.steps
    # Only rank 0 writes checkpoints (as inside the loop): every rank writing
    # the same file at once could corrupt it.
    if rank == 0:
        mge.save(model.state_dict(),
                 os.path.join(save_dir, "checkpoint-{:06d}.pkl".format(final_step)))
    # All ranks must run validation because valid_func all-reduces metrics.
    valid_loss, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST Iter %06d: loss=%f,\tTop-1 err = %f,\tTop-5 err = %f",
                final_step, valid_loss, 1 - valid_acc / 100, 1 - valid_acc5 / 100)
Пример #4
0
def worker(rank, world_size, args):
    """Fine-tune ``args.arch`` on ImageNet inside one (possibly distributed) process.

    Hyper-parameters come from ``config.get_finetune_config(args.arch)``; the
    learning rate is multiplied by ``world_size``. Unless ``args.mode`` is
    ``"normal"``, ``Q.quantize_qat`` is applied to the model before optional
    weights from ``args.checkpoint`` are loaded (non-strictly). Rank 0 logs
    metrics every ``args.report_freq`` steps and overwrites
    ``checkpoint.pkl`` every 10000 steps; all ranks validate every 10000
    steps and once after training. Rank 0 finally writes
    ``checkpoint-final.pkl``.

    Args:
        rank: index of this process; also used as its device id.
        world_size: total number of training processes.
        args: parsed command-line options (``mode``, ``arch``, ``save``,
            ``checkpoint``, ``data``, ``workers``, ``report_freq``; ``infer``
            may read more).

    Raises:
        ValueError: if ``args.mode == "quantized"`` (an inference-only mode),
            or on the first step if ``cfg.SCHEDULER`` is neither "Linear" nor
            "Multistep".
    """
    # pylint: disable=too-many-statements

    # Reject the inference-only mode up front, before initialising the
    # process group, creating directories or loading weights.
    if args.mode == "quantized":
        raise ValueError("mode = quantized only used during inference")

    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_finetune_config(args.arch)

    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
    total_batch_size = cfg.BATCH_SIZE * world_size
    # 1280000 approximates the number of ImageNet training images.
    steps_per_epoch = 1280000 // total_batch_size
    total_steps = steps_per_epoch * cfg.EPOCHS

    if args.mode != "normal":
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        # Accept both {"state_dict": ...} wrappers and bare state dicts.
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    optimizer = optim.SGD(
        get_parameters(model, cfg),
        lr=cfg.LEARNING_RATE,
        momentum=cfg.MOMENTUM,
    )

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "train_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "train_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "train_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset,
                           batch_size=cfg.BATCH_SIZE,
                           drop_last=True))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            cfg.COLOR_JITTOR,
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )
    train_queue = iter(train_queue)
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(valid_dataset,
                                           batch_size=100,
                                           drop_last=False)
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )

    def adjust_learning_rate(step, epoch):
        """Set and return the learning rate for ``step`` / ``epoch``.

        "Linear" decays to zero over ``total_steps``; "Multistep" multiplies
        by ``cfg.SCHEDULER_GAMMA`` once per milestone in
        ``cfg.SCHEDULER_STEPS`` that is <= ``epoch``.
        """
        learning_rate = cfg.LEARNING_RATE
        if cfg.SCHEDULER == "Linear":
            learning_rate *= 1 - float(step) / total_steps
        elif cfg.SCHEDULER == "Multistep":
            learning_rate *= cfg.SCHEDULER_GAMMA**bisect.bisect_right(
                cfg.SCHEDULER_STEPS, epoch)
        else:
            raise ValueError(cfg.SCHEDULER)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    for step in range(0, total_steps):
        # Update the learning rate according to cfg.SCHEDULER.
        epoch = step // steps_per_epoch
        learning_rate = adjust_learning_rate(step, epoch)

        image, label = next(train_queue)
        image = image.astype("float32")
        label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logger.info("TRAIN e%d %06d %f %s %s %s %s", epoch, step,
                        learning_rate, objs, top1, top5, total_time)
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        # Skip step 0: the weights have not been trained yet.
        if step != 0 and step % 10000 == 0 and rank == 0:
            logger.info("SAVING %06d", step)
            mge.save(
                {
                    "step": step,
                    "state_dict": model.state_dict()
                },
                os.path.join(save_dir, "checkpoint.pkl"),
            )
        if step % 10000 == 0 and step != 0:
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    # Only rank 0 writes checkpoints (as inside the loop): every rank writing
    # the same file at once could corrupt it.
    if rank == 0:
        mge.save({
            "step": step,
            "state_dict": model.state_dict()
        }, os.path.join(save_dir, "checkpoint-final.pkl"))
    # All ranks must run validation because valid_func all-reduces metrics.
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
Пример #5
0
def worker(world_size, args):
    """Train ``args.arch`` on ImageNet inside one process of a launched job.

    The process group is not initialised here; the rank is read from
    ``dist.get_rank()``. Hyper-parameters come from
    ``config.get_config(args.arch)`` and the learning rate is multiplied by
    ``world_size``. Unless ``args.mode`` is ``"normal"``, ``quantize_qat``
    is applied to the model. With several processes, parameters are
    broadcast once and gradients are averaged through the GradManager
    callback. Rank 0 logs every ``args.report_freq`` steps and overwrites
    ``checkpoint.pkl`` every 10000 steps; all ranks validate every 10000
    steps and once after training, and rank 0 writes
    ``checkpoint-final.pkl``.

    Args:
        world_size: total number of training processes.
        args: parsed command-line options (``mode``, ``arch``, ``save``,
            ``data``, ``workers``, ``report_freq``; ``infer`` may read more).

    Raises:
        ValueError: on the first step if ``cfg.SCHEDULER`` is neither
            "Linear" nor "Multistep".
    """
    # pylint: disable=too-many-statements

    rank = dist.get_rank()
    if world_size > 1:
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_config(args.arch)

    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
    total_batch_size = cfg.BATCH_SIZE * world_size
    # 1280000 approximates the number of ImageNet training images.
    steps_per_epoch = 1280000 // total_batch_size
    total_steps = steps_per_epoch * cfg.EPOCHS

    if args.mode != "normal":
        quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

    if world_size > 1:
        # Sync parameters
        dist.bcast_list_(model.parameters(), dist.WORLD)

    # Autodiff gradient manager
    gm = autodiff.GradManager().attach(
        model.parameters(),
        callbacks=dist.make_allreduce_cb("MEAN") if world_size > 1 else None,
    )

    optimizer = optim.SGD(
        get_parameters(model, cfg),
        lr=cfg.LEARNING_RATE,
        momentum=cfg.MOMENTUM,
    )

    # Define train and valid graph
    def train_func(image, label):
        with gm:
            model.train()
            logits = model(image)
            loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
            acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
            gm.backward(loss)
            optimizer.step().clear_grad()
        return loss, acc1, acc5

    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset,
                           batch_size=cfg.BATCH_SIZE,
                           drop_last=True))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            cfg.COLOR_JITTOR,
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )
    train_queue = iter(train_queue)
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(valid_dataset,
                                           batch_size=100,
                                           drop_last=False)
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=128),
            T.ToMode("CHW")
        ]),
        num_workers=args.workers,
    )

    def adjust_learning_rate(step, epoch):
        """Set and return the learning rate for ``step`` / ``epoch``.

        "Linear" decays to zero over ``total_steps``; "Multistep" multiplies
        by ``cfg.SCHEDULER_GAMMA`` once per milestone in
        ``cfg.SCHEDULER_STEPS`` that is <= ``epoch``.
        """
        learning_rate = cfg.LEARNING_RATE
        if cfg.SCHEDULER == "Linear":
            learning_rate *= 1 - float(step) / total_steps
        elif cfg.SCHEDULER == "Multistep":
            learning_rate *= cfg.SCHEDULER_GAMMA**bisect.bisect_right(
                cfg.SCHEDULER_STEPS, epoch)
        else:
            raise ValueError(cfg.SCHEDULER)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    for step in range(0, total_steps):
        # Update the learning rate according to cfg.SCHEDULER.
        epoch = step // steps_per_epoch
        learning_rate = adjust_learning_rate(step, epoch)

        image, label = next(train_queue)
        image = mge.tensor(image, dtype="float32")
        label = mge.tensor(label, dtype="int32")

        n = image.shape[0]

        loss, acc1, acc5 = train_func(image, label)

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logger.info(
                "TRAIN e%d %06d %f %s %s %s %s",
                epoch,
                step,
                learning_rate,
                objs,
                top1,
                top5,
                total_time,
            )
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step != 0 and step % 10000 == 0 and rank == 0:
            logger.info("SAVING %06d", step)
            mge.save(
                {
                    "step": step,
                    "state_dict": model.state_dict()
                },
                os.path.join(save_dir, "checkpoint.pkl"),
            )
        if step % 10000 == 0 and step != 0:
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    # Only rank 0 writes checkpoints (as inside the loop): every rank writing
    # the same file at once could corrupt it.
    if rank == 0:
        mge.save(
            {
                "step": step,
                "state_dict": model.state_dict()
            },
            os.path.join(save_dir, "checkpoint-final.pkl"),
        )
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
Пример #6
0
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import megengine.data as data
import megengine.data.transform as T
import megengine.optimizer as optim

import megengine_mimicry as mmc
import megengine_mimicry.nets.dcgan.dcgan_cifar as dcgan

# CIFAR-10 training data: endless random batches of 64 images (the last
# partial batch of each pass is dropped), pixel values divided by 255 and
# laid out channel-first.
dataset = mmc.datasets.load_dataset(root=None, name='cifar10')
dataloader = data.DataLoader(
    dataset,
    sampler=data.Infinite(
        data.RandomSampler(dataset, batch_size=64, drop_last=True)
    ),
    transform=T.Compose([T.Normalize(std=255), T.ToMode("CHW")]),
    num_workers=4,
)

# DCGAN generator / discriminator for CIFAR, each optimised with Adam
# (learning rate 2e-4, betas (0.0, 0.9)).
netG = dcgan.DCGANGeneratorCIFAR()
netD = dcgan.DCGANDiscriminatorCIFAR()
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))

# Directory for this example's training output.
LOG_DIR = "./log/dcgan_example"

trainer = mmc.training.Trainer(netD=netD,
                               netG=netG,
Пример #7
0
def worker(rank, gpu_num, args):
    """Train ``network.Network`` on CrowdHuman in one process of a ``gpu_num``-process job.

    Sets MegEngine graph-option environment variables, starts the
    distributed server on rank 0 and, when ``gpu_num > 1``, joins the
    process group. Builds the model and an SGD optimizer whose learning rate
    scales with the global batch size. It optionally loads pretrained
    backbone weights into ``model.resnet50`` (after removing the ``fc.*``
    entries) and/or resumes from ``args.resume_weights``. It then trains
    epoch by epoch via ``train_one_epoch``. After each epoch rank 0 saves
    ``{"epoch": ..., "state_dict": ...}`` to
    ``cfg.model_dir/epoch-<epoch_id + 1>.pkl``.

    Args:
        rank: index of this process; also used as its device id and sampler
            shard.
        gpu_num: total number of processes (one per GPU).
        args: parsed command-line options; reads ``port`` and
            ``resume_weights`` here.

    NOTE(review): relies on module-level names defined outside this view
    (``cfg``, ``allreduce_cb``, ``network``, ``CrowdHuman``,
    ``train_one_epoch``).
    """
    # Configure MegEngine through environment variables: enable the sublinear
    # memory optimisation and set its genetic-search iteration count.
    # NOTE(review): exact meaning of MGB_CUDA_RESERVE_MEMORY is defined by
    # MegEngine, not visible here.
    os.environ["MGB_COMP_GRAPH_OPT"] = "enable_sublinear_memory_opt=1;seq_opt.enable_seq_comp_node_opt=0"
    os.environ["MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER"] = '10'
    os.environ['MGB_CUDA_RESERVE_MEMORY'] = '1'
    # Rank 0 (the master) hosts the distributed server on args.port.
    # NOTE(review): this also happens when gpu_num == 1.

    dist_port = args.port
    if rank == 0:
        dist.Server(port=dist_port)
    if gpu_num> 1:

        dist.init_process_group(
            master_ip="localhost",
            port=dist_port,
            world_size=gpu_num,
            rank=rank,
            device=rank,
        )
        logger.info("Init process group for gpu%d done", rank)

    model = network.Network()
    # NOTE(review): the optimizer receives only requires_grad parameters,
    # while the GradManager below attaches all model.parameters(); confirm
    # this is intended.
    params = model.parameters(requires_grad=True)
    model.train()

    # Autodiff gradient manager; allreduce_cb is defined elsewhere and is
    # presumably the cross-process gradient reduction callback (it is passed
    # even when gpu_num == 1 — verify it handles that case).
    gm = autodiff.GradManager().attach(
        model.parameters(),
        callbacks=allreduce_cb,
    )

    # Base learning rate is scaled linearly with the global batch size
    # (gpu_num * batch_per_gpu).
    opt = optim.SGD(
        params,
        lr=cfg.basic_lr * gpu_num * cfg.batch_per_gpu,
        momentum=cfg.momentum,
        weight_decay=cfg.weight_decay,
    )

    # Load pretrained backbone weights, dropping the classifier head
    # (fc.weight / fc.bias) entries first.
    if cfg.pretrain_weight is not None:
        weights = mge.load(cfg.pretrain_weight)
        del weights['fc.weight']
        del weights['fc.bias']
        model.resnet50.load_state_dict(weights)

    # Resume: continue from the epoch after the one stored in the checkpoint.
    # NOTE(review): `assert` is stripped under `python -O`; a ValueError or
    # FileNotFoundError would be a more reliable check.
    start_epoch = 0
    if args.resume_weights is not None:
        assert osp.exists(args.resume_weights)
        model_file = args.resume_weights
        print('Loading {} to initialize FPN...'.format(model_file))
        model_dict = mge.load(model_file)
        start_epoch, weights = model_dict['epoch'] + 1, model_dict['state_dict']
        model.load_state_dict(weights, strict=False)
    
    logger.info("Prepare dataset")
    # train_loader = dataset.train_dataset(rank)

    # Each process samples its own shard (world_size=gpu_num, rank=rank) in
    # endless random batches of cfg.batch_per_gpu; the dataset object also
    # serves as the batch collator.
    train_dataset = CrowdHuman(cfg, if_train=True)
    train_sampler = data.Infinite(data.RandomSampler(
        train_dataset, batch_size = cfg.batch_per_gpu, drop_last=True,
        world_size = gpu_num, rank = rank,))
    train_loader = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        collator = train_dataset,
        num_workers=4,
    )
    
    train_loader = iter(train_loader)
    logger.info("Training...")
    for epoch_id in range(start_epoch, cfg.max_epoch):
        # Step decay: multiply by lr_decay_rate once for every milestone in
        # cfg.lr_decay_sates that is <= epoch_id (bisect_right).
        # NOTE(review): "sates" looks like a typo of "stages" but must match
        # the attribute name used by cfg.
        for param_group in opt.param_groups:
            param_group["lr"] = (
                cfg.basic_lr * gpu_num * cfg.batch_per_gpu
                * (cfg.lr_decay_rate ** bisect.bisect_right(cfg.lr_decay_sates, epoch_id))
            )

        # Steps per epoch so that all processes together consume about
        # cfg.nr_images_epoch images.
        max_steps = cfg.nr_images_epoch // (cfg.batch_per_gpu * gpu_num)
        train_one_epoch(model, gm, train_loader, opt, max_steps, rank, epoch_id, gpu_num)
        if rank == 0:
            save_path = osp.join(cfg.model_dir, 'epoch-{}.pkl'.format(epoch_id + 1))
            state_dict = model.state_dict()
            # Drop entries whose names start with 'inputs.' before saving
            # (presumably input placeholders rather than learned weights).
            names = [k for k, _ in state_dict.items()]
            for name in names:
                if name.startswith('inputs.'):
                    del state_dict[name]

            mge.save(
                {"epoch": epoch_id, "state_dict": state_dict}, save_path,
            )
            logger.info("dump weights to %s", save_path)