Example #1
def demo_net(config):
    import json
    from utils.evaluate import evaluate_coco
    import tqdm
    import os

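    # Build the RetinaNet-FPN detector described by the config and load the
    # evaluation weights from config.val.params_file.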
    ctx_list = [mx.gpu(x) for x in config.gpus]
    num_anchors = len(config.retinanet.network.SCALES) * len(
        config.retinanet.network.RATIOS)
    neck = PyramidNeckRetinaNet(
        feature_dim=config.network.fpn_neck_feature_dim)
    backbone = build_backbone(config,
                              neck=neck,
                              **config.network.BACKBONE.kwargs)
    net = RetinaNetFPNNet(backbone, config.dataset.NUM_CLASSES, num_anchors)
    net.collect_params().load(config.val.params_file)
    net.collect_params().reset_ctx(ctx_list)
    results = {}
    results["results"] = []
    for img_dir, _, names in os.walk(
            os.path.join(config.dataset.dataset_path, "val2017")):
        for name in tqdm.tqdm(names):
            one_img = {}
            one_img["filename"] = os.path.basename(name)
            one_img["rects"] = []
            preds = inference_one_image(config, net, ctx_list[0],
                                        os.path.join(img_dir, name))
            for i in range(len(preds)):
                one_rect = {}
                xmin, ymin, xmax, ymax = preds[i][:4]
                one_rect["xmin"] = int(np.round(xmin))
                one_rect["ymin"] = int(np.round(ymin))
                one_rect["xmax"] = int(np.round(xmax))
                one_rect["ymax"] = int(np.round(ymax))
                one_rect["confidence"] = float(preds[i][4])
                one_rect["label"] = int(preds[i][5])
                one_img["rects"].append(one_rect)
            results["results"].append(one_img)
    save_path = 'results.json'
    with open(save_path, "w") as f:
        json.dump(results, f)
    evaluate_coco(json_label=os.path.join(
        config.dataset.dataset_path, "annotations/instances_val2017.json"),
                  json_predict=save_path)
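
A minimal sketch of how this entry point might be driven, assuming the config is an
EasyDict-style object whose fields match the ones read above (gpus, val.params_file,
dataset.dataset_path, and so on); the YAML file name and loading code here are
illustrative, not part of the repository:

import yaml
from easydict import EasyDict

with open("configs/retinanet_r50_fpn_coco.yaml") as f:  # hypothetical config file
    config = EasyDict(yaml.safe_load(f))
demo_net(config)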
Example #2
def train_net(config):
    mx.random.seed(3)
    np.random.seed(3)
    try:
        import torch
        torch.random.manual_seed(3)
    except ImportError as e:
        logging.info("setting torch seed failed.")
        logging.exception(e)
    ctx_list = [mx.gpu(x) for x in config.gpus]
    num_anchors = len(config.retinanet.network.SCALES) * len(
        config.retinanet.network.RATIOS)
    neck = PyramidNeckRetinaNet(
        feature_dim=config.network.fpn_neck_feature_dim)
    backbone = build_backbone(config,
                              neck=neck,
                              **config.network.BACKBONE.kwargs)
    net = RetinaNetFPNNet(backbone, config.dataset.NUM_CLASSES, num_anchors)

    # Resume parameters.
    resume = None
    if resume is not None:
        params_coco = mx.nd.load(resume)
        for k in list(params_coco.keys()):
            params_coco[k.replace("arg:", "").replace("aux:",
                                                      "")] = params_coco.pop(k)
        params = net.collect_params()

        for k in params.keys():
            try:
                params[k]._load_init(params_coco[k.replace('resnet0_', '')],
                                     ctx=mx.cpu())
                print("success load {}".format(k))
            except Exception as e:
                logging.exception(e)

    if config.TRAIN.resume is not None:
        net.collect_params().load(config.TRAIN.resume)
        logging.info("loaded resume from {}".format(config.TRAIN.resume))

    # Initialize parameters
    params = net.collect_params()
    from utils.initializer import KaMingUniform
    for key in params.keys():
        if params[key]._data is None:
            default_init = (mx.init.Zero() if "bias" in key or "offset" in key
                            else KaMingUniform())
            default_init.set_verbosity(True)
            if params[key].init is not None and hasattr(
                    params[key].init, "set_verbosity"):
                params[key].init.set_verbosity(True)
                params[key].initialize(init=params[key].init,
                                       default_init=params[key].init)
            else:
                params[key].initialize(default_init=default_init)

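    # Move all parameters onto the training devices.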
    net.collect_params().reset_ctx(list(set(ctx_list)))

    if config.TRAIN.aspect_grouping:
        if config.dataset.dataset_type == "coco":
            from data.bbox.mscoco import COCODetection
            base_train_dataset = COCODetection(
                root=config.dataset.dataset_path,
                splits=("instances_train2017", ),
                h_flip=config.TRAIN.FLIP,
                transform=None,
                use_crowd=False,
                skip_empty=True,
                min_object_area=1)
        elif config.dataset.dataset_type == "voc":
            from data.bbox.voc import VOCDetection
            base_train_dataset = VOCDetection(root=config.dataset.dataset_path,
                                              splits=((2007, 'trainval'),
                                                      (2012, 'trainval')),
                                              preload_label=False)
        else:
            assert False
        train_dataset = AspectGroupingDataset(
            base_train_dataset,
            config,
            target_generator=RetinaNetTargetGenerator(config))
        train_loader = mx.gluon.data.DataLoader(dataset=train_dataset,
                                                batch_size=1,
                                                batchify_fn=batch_fn,
                                                num_workers=8,
                                                last_batch="discard",
                                                shuffle=True,
                                                thread_pool=False)
    else:
        assert False
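    # Collect the parameters that will actually be trained, skipping those whose
    # grad_req is already null and those matching the FIXED_PARAMS patterns.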
    params_all = net.collect_params()
    params_to_train = {}
    params_fixed_prefix = config.network.FIXED_PARAMS
    for p in params_all.keys():
        ignore = False
        if params_all[p].grad_req == "null" and "running" not in p:
            ignore = True
            logging.info(
                "ignore {} because its grad req is set to null.".format(p))
        if params_fixed_prefix is not None:
            import re
            for f in params_fixed_prefix:
                if re.match(f, str(p)) is not None:
                    ignore = True
                    params_all[p].grad_req = 'null'
                    logging.info(
                        "{} is ignored when training because it matches {}.".
                        format(p, f))
        if not ignore and params_all[p].grad_req != "null":
            params_to_train[p] = params_all[p]
    lr_steps = [len(train_loader) * int(x) for x in config.TRAIN.lr_step]
    logging.info(lr_steps)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(
        step=lr_steps,
        warmup_mode="constant",
        factor=.1,
        base_lr=config.TRAIN.lr,
        warmup_steps=config.TRAIN.warmup_step,
        warmup_begin_lr=config.TRAIN.warmup_lr)

    trainer = mx.gluon.Trainer(
        params_to_train,  # fix batchnorm, fix first stage, etc...
        'sgd',
        {
            'wd': config.TRAIN.wd,
            'momentum': config.TRAIN.momentum,
            'clip_gradient': None,
            'lr_scheduler': lr_scheduler
        })
    # trainer = mx.gluon.Trainer(
    #     params_to_train,  # fix batchnorm, fix first stage, etc...
    #     'adam', {"learning_rate": 1e-4})
    # Note: the trainer states must be loaded on the same set of GPU devices as
    # was used when they were saved.
    if config.TRAIN.trainer_resume is not None:
        trainer.load_states(config.TRAIN.trainer_resume)
        logging.info("loaded trainer states from {}.".format(
            config.TRAIN.trainer_resume))

    metric_loss_loc = mx.metric.Loss(name="loss_loc")
    metric_loss_cls = mx.metric.Loss(name="loss_cls")
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [metric_loss_loc, metric_loss_cls]:
        eval_metrics.add(child_metric)
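    # Warm-up: run one dummy forward/backward pass per device so the hybridized
    # graph allocates its memory before real training starts.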
    net.hybridize(static_alloc=True, static_shape=False)
    for ctx in ctx_list:
        pad = lambda x: int(np.ceil(x / 32) * 32)
        with ag.record():
            y_hat = net(
                mx.nd.random.randn(
                    config.TRAIN.batch_size // len(ctx_list),
                    int(pad(config.TRAIN.image_max_long_size + 64)),
                    int(pad(config.TRAIN.image_short_size + 64)),
                    3,
                    ctx=ctx))
            y_hats = []
            for x in y_hat:
                for xx in x:
                    y_hats.append(xx)
            ag.backward(y_hats)
            del x
            del xx
            del y_hat
            del y_hats
        net.collect_params().zero_grad()
    mx.nd.waitall()
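    # Main training loop.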
    for epoch in range(config.TRAIN.begin_epoch, config.TRAIN.end_epoch):

        mx.nd.waitall()
        for nbatch, data_batch in enumerate(
                tqdm.tqdm(train_loader, total=len(train_loader),
                          unit_scale=1)):
            data_list = mx.gluon.utils.split_and_load(data_batch[0][0],
                                                      ctx_list=ctx_list,
                                                      batch_axis=0)
            gt_bboxes_list = mx.gluon.utils.split_and_load(data_batch[0][1],
                                                           ctx_list=ctx_list,
                                                           batch_axis=0)
            losses_loc = []
            losses_cls = []
            num_pos_list = []
            for data, gt_bboxes in zip(data_list, gt_bboxes_list):
                # targets: (2, 86, num_anchors, h x w)
                with ag.record():
                    fpn_predictions = net(data)
                # Generate targets
                targets = []
                for stride, base_size, (loc_pred, cls_pred) in zip(
                        config.retinanet.network.FPN_STRIDES,
                        config.retinanet.network.BASE_SIZES, fpn_predictions):
                    op_kwargs = {
                        "stride": stride,
                        "base_size": base_size,
                        "negative_iou_threshold":
                        config.TRAIN.negative_iou_threshold,
                        "positive_iou_threshold":
                        config.TRAIN.positive_iou_threshold,
                        "ratios": config.retinanet.network.RATIOS,
                        "scales": config.retinanet.network.SCALES,
                        "bbox_norm_coef":
                        config.retinanet.network.bbox_norm_coef,
                    }
                    loc_targets, cls_targets, loc_masks, cls_masks = mobula.op.RetinaNetTargetGenerator(
                        data.detach(), loc_pred.detach(), cls_pred.detach(),
                        gt_bboxes.detach(), **op_kwargs)
                    targets.append(
                        [loc_targets, cls_targets, loc_masks, cls_masks])
                num_pos = mx.nd.ElementWiseSum(
                    *[x[2].sum() / 4 for x in targets])
                num_pos_list.append(num_pos)

                def smooth_l1(pred, target, beta):
                    diff = (pred - target).abs()
                    loss = mx.nd.where(diff < beta, 0.5 * diff * diff / beta,
                                       diff - 0.5 * beta)
                    return loss

                def l1_loss(pred, target):
                    diff = (pred - target).abs()
                    return diff

                with ag.record():
                    losses_loc_per_device = []
                    losses_cls_per_device = []
                    for (loc_pred,
                         cls_pred), (loc_targets, cls_targets, loc_masks,
                                     cls_masks) in zip(fpn_predictions,
                                                       targets):
                        # loss_loc = smooth_l1(loc_pred, loc_targets, beta=1.0/9) * loc_masks
                        loss_loc = l1_loss(loc_pred, loc_targets) * loc_masks
                        loss_cls = mobula.op.FocalLoss(
                            alpha=.25,
                            gamma=2,
                            logits=cls_pred,
                            targets=cls_targets.detach()) * cls_masks
                        losses_loc_per_device.append(loss_loc)
                        losses_cls_per_device.append(loss_cls)
                    loss_loc_sum_all_level = mx.nd.ElementWiseSum(
                        *[x.sum() for x in losses_loc_per_device])
                    loss_cls_sum_all_level = mx.nd.ElementWiseSum(
                        *[x.sum() for x in losses_cls_per_device])
                    losses_loc.append(loss_loc_sum_all_level)
                    losses_cls.append(loss_cls_sum_all_level)
            num_pos_per_batch_across_all_devices = sum(
                [x.sum().asscalar() for x in num_pos_list])
            with ag.record():
                for i in range(len(losses_loc)):
                    losses_loc[i] = (losses_loc[i] /
                                     num_pos_per_batch_across_all_devices)
                for i in range(len(losses_cls)):
                    losses_cls[i] = (losses_cls[i] /
                                     num_pos_per_batch_across_all_devices)

            ag.backward(losses_loc + losses_cls)
            # Since num_pos is the total number of positive anchors in the
            # mini-batch, the normalizing coefficient passed to trainer.step is 1.
            trainer.step(1)
            metric_loss_loc.update(
                None, mx.nd.array([sum([x.asscalar() for x in losses_loc])]))
            metric_loss_cls.update(
                None, mx.nd.array([sum([x.asscalar() for x in losses_cls])]))

            if trainer.optimizer.num_update % config.TRAIN.log_interval == 0:
                msg = "Epoch={},Step={},lr={}, ".format(
                    epoch, trainer.optimizer.num_update, trainer.learning_rate)
                msg += ','.join([
                    '{}={:.3f}'.format(w, v)
                    for w, v in zip(*eval_metrics.get())
                ])
                logging.info(msg)
                eval_metrics.reset()

            if trainer.optimizer.num_update % 5000 == 0:
                save_path = os.path.join(
                    config.TRAIN.log_path,
                    "{}-{}.params".format(epoch, trainer.optimizer.num_update))
                net.collect_params().save(save_path)
                logging.info("Saved checkpoint to {}".format(save_path))
                trainer_path = save_path + "-trainer.states"
                trainer.save_states(trainer_path)
        save_path = os.path.join(config.TRAIN.log_path,
                                 "{}.params".format(epoch))
        net.collect_params().save(save_path)
        logging.info("Saved checkpoint to {}".format(save_path))
        trainer_path = save_path + "-trainer.states"
        trainer.save_states(trainer_path)
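
For reference, the classification loss above is computed by mobula.op.FocalLoss with
alpha=0.25 and gamma=2. A minimal NDArray sketch of the standard sigmoid focal loss,
written here only to spell out what those two parameters control; this is an
illustration, not the repository's implementation:

import mxnet as mx

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # p_t is the predicted probability of the true class; alpha_t balances
    # positives against negatives and (1 - p_t)^gamma down-weights easy examples.
    p = mx.nd.sigmoid(logits)
    p_t = mx.nd.where(targets > 0, p, 1 - p)
    alpha_t = mx.nd.where(targets > 0,
                          alpha * mx.nd.ones_like(p),
                          (1 - alpha) * mx.nd.ones_like(p))
    return -alpha_t * ((1 - p_t) ** gamma) * mx.nd.log(mx.nd.maximum(p_t, 1e-12))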
Example #3
def train_net(config):
    mx.random.seed(3)
    np.random.seed(3)

    if config.TRAIN.USE_FP16:
        from mxnet.contrib import amp
        amp.init()
    if config.use_hvd:
        import horovod.mxnet as hvd

    ctx_list = [mx.gpu(x) for x in config.gpus]
    from utils.blocks import FrozenBatchNorm2d
    neck = PyramidNeckFCOS(feature_dim=config.network.fpn_neck_feature_dim)
    backbone = build_backbone(config,
                              neck=neck,
                              norm_layer=FrozenBatchNorm2d,
                              **config.network.BACKBONE.kwargs)
    net = FCOSFPNNet(backbone, config.dataset.NUM_CLASSES)

    # Resume parameters.
    resume = None
    if resume is not None:
        params_coco = mx.nd.load(resume)
        for k in list(params_coco.keys()):
            params_coco[k.replace("arg:", "").replace("aux:",
                                                      "")] = params_coco.pop(k)
        params = net.collect_params()

        for k in params.keys():
            try:
                params[k]._load_init(params_coco[k.replace('resnet0_', '')],
                                     ctx=mx.cpu())
                print("success load {}".format(k))
            except Exception as e:
                logging.exception(e)

    if config.TRAIN.resume is not None:
        net.collect_params().load(config.TRAIN.resume)
        logging.info("loaded resume from {}".format(config.TRAIN.resume))

    # Initialize parameters
    params = net.collect_params()
    from utils.initializer import KaMingUniform
    for key in params.keys():
        if params[key]._data is None:
            default_init = (mx.init.Zero() if "bias" in key or "offset" in key
                            else KaMingUniform())
            default_init.set_verbosity(True)
            if params[key].init is not None and hasattr(
                    params[key].init, "set_verbosity"):
                params[key].init.set_verbosity(True)
                params[key].initialize(init=params[key].init,
                                       default_init=params[key].init)
            else:
                params[key].initialize(default_init=default_init)
    params = net.collect_params()
    # for p_name, p in params.items():
    #     if p_name.endswith(('_bias')):
    #         p.wd_mult = 0
    #         p.lr_mult = 2
    #         logging.info("set wd_mult of {} to {}.".format(p_name, p.wd_mult))
    #         logging.info("set lr_mult of {} to {}.".format(p_name, p.lr_mult))

    net.collect_params().reset_ctx(list(set(ctx_list)))

    if config.dataset.dataset_type == "coco":
        from data.bbox.mscoco import COCODetection
        base_train_dataset = COCODetection(root=config.dataset.dataset_path,
                                           splits=("instances_train2017", ),
                                           h_flip=config.TRAIN.FLIP,
                                           transform=None,
                                           use_crowd=False)
    elif config.dataset.dataset_type == "voc":
        from data.bbox.voc import VOCDetection
        base_train_dataset = VOCDetection(root=config.dataset.dataset_path,
                                          splits=((2007, 'trainval'),
                                                  (2012, 'trainval')),
                                          preload_label=False)
    else:
        assert False
    train_dataset = AspectGroupingDataset(
        base_train_dataset,
        config,
        target_generator=FCOSTargetGenerator(config))

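    # Under Horovod, shard the dataset so that each local worker reads a disjoint
    # slice of the training data.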
    if config.use_hvd:

        class SplitDataset(object):
            def __init__(self, da, local_size, local_rank):
                self.da = da
                self.local_size = local_size
                self.local_rank = local_rank

            def __len__(self):
                return len(self.da) // self.local_size

            def __getitem__(self, idx):
                return self.da[idx * self.local_size + self.local_rank]

        train_dataset = SplitDataset(train_dataset,
                                     local_size=hvd.local_size(),
                                     local_rank=hvd.local_rank())

    train_loader = mx.gluon.data.DataLoader(dataset=train_dataset,
                                            batch_size=1,
                                            num_workers=8,
                                            last_batch="discard",
                                            shuffle=True,
                                            thread_pool=False,
                                            batchify_fn=batch_fn)

    params_all = net.collect_params()
    params_to_train = {}
    params_fixed_prefix = config.network.FIXED_PARAMS
    for p in params_all.keys():
        ignore = False
        if params_all[p].grad_req == "null" and "running" not in p:
            ignore = True
            logging.info(
                "ignore {} because its grad req is set to null.".format(p))
        if params_fixed_prefix is not None:
            import re
            for f in params_fixed_prefix:
                if re.match(f, str(p)) is not None:
                    ignore = True
                    params_all[p].grad_req = 'null'
                    logging.info(
                        "{} is ignored when training because it matches {}.".
                        format(p, f))
        if not ignore and params_all[p].grad_req != "null":
            params_to_train[p] = params_all[p]
    lr_steps = [len(train_loader) * int(x) for x in config.TRAIN.lr_step]
    logging.info(lr_steps)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(
        step=lr_steps,
        warmup_mode="constant",
        factor=.1,
        base_lr=config.TRAIN.lr,
        warmup_steps=config.TRAIN.warmup_step,
        warmup_begin_lr=config.TRAIN.warmup_lr)
    if config.use_hvd:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            params_to_train, 'sgd', {
                'wd': config.TRAIN.wd,
                'momentum': config.TRAIN.momentum,
                'clip_gradient': None,
                'lr_scheduler': lr_scheduler,
                'multi_precision': True,
            })
    else:
        trainer = mx.gluon.Trainer(
            params_to_train,  # fix batchnorm, fix first stage, etc...
            'sgd',
            {
                'wd': config.TRAIN.wd,
                'momentum': config.TRAIN.momentum,
                'clip_gradient': None,
                'lr_scheduler': lr_scheduler,
                'multi_precision': True,
            },
            update_on_kvstore=(False if config.TRAIN.USE_FP16 else None),
            kvstore=mx.kvstore.create('local'))
    if config.TRAIN.USE_FP16:
        amp.init_trainer(trainer)
    # trainer = mx.gluon.Trainer(
    #     params_to_train,  # fix batchnorm, fix first stage, etc...
    #     'adam', {"learning_rate": 4e-4})
    # Note: the trainer states must be loaded on the same set of GPU devices as
    # was used when they were saved.
    if config.TRAIN.trainer_resume is not None:
        trainer.load_states(config.TRAIN.trainer_resume)
        logging.info("loaded trainer states from {}.".format(
            config.TRAIN.trainer_resume))

    metric_loss_loc = mx.metric.Loss(name="loss_loc")
    metric_loss_cls = mx.metric.Loss(name="loss_cls")
    metric_loss_center = mx.metric.Loss(name="loss_center")
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [metric_loss_loc, metric_loss_cls, metric_loss_center]:
        eval_metrics.add(child_metric)

    net.hybridize(static_alloc=True, static_shape=False)
    for ctx in ctx_list:
        with ag.record():
            pad = lambda x: int(np.ceil(x / 32) * 32)
            _ = net(
                mx.nd.random.randn(
                    config.TRAIN.batch_size // len(ctx_list),
                    int(pad(config.TRAIN.image_max_long_size + 32)),
                    int(pad(config.TRAIN.image_short_size + 32)),
                    3,
                    ctx=ctx))
        ag.backward(_)
        del _
        net.collect_params().zero_grad()
    mx.nd.waitall()

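    # Main training loop; the epoch index is derived from the optimizer's update
    # count so that resumed runs continue from the correct epoch.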
    while trainer.optimizer.num_update <= config.TRAIN.end_epoch * len(
            train_loader):
        epoch = trainer.optimizer.num_update // len(train_loader)
        for data_batch in (tqdm.tqdm(train_loader) if not config.use_hvd
                           or hvd.local_rank() == 0 else train_loader):
            if config.use_hvd:
                data_list = [data_batch[0].as_in_context(ctx_list[0])]
                targets_list = [data_batch[1].as_in_context(ctx_list[0])]
            else:
                if isinstance(data_batch[0], mx.nd.NDArray):
                    data_list = mx.gluon.utils.split_and_load(
                        mx.nd.array(data_batch[0]),
                        ctx_list=ctx_list,
                        batch_axis=0)
                    targets_list = mx.gluon.utils.split_and_load(
                        mx.nd.array(data_batch[1]),
                        ctx_list=ctx_list,
                        batch_axis=0)
                else:
                    data_list = mx.gluon.utils.split_and_load(
                        mx.nd.array(data_batch[0][0]),
                        ctx_list=ctx_list,
                        batch_axis=0)
                    targets_list = mx.gluon.utils.split_and_load(
                        mx.nd.array(data_batch[0][1]),
                        ctx_list=ctx_list,
                        batch_axis=0)

            losses_loc = []
            losses_center_ness = []
            losses_cls = []

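            # Loss normalizers: the per-worker count of positive locations and the
            # per-worker sum of center-ness targets, each floored at one to avoid
            # division by zero.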
            n_workers = hvd.local_size() if config.use_hvd else len(ctx_list)
            num_pos = data_batch[0][1][:, 0].sum() / n_workers
            num_pos_denominator = mx.nd.maximum(num_pos,
                                                mx.nd.ones_like(num_pos))
            centerness_sum = data_batch[0][1][:, 5].sum() / n_workers
            centerness_sum_denominator = mx.nd.maximum(
                centerness_sum, mx.nd.ones_like(centerness_sum))

            with ag.record():
                for data, targets in zip(data_list, targets_list):
                    num_pos_denominator_ctx = num_pos_denominator.as_in_context(
                        data.context)
                    centerness_sum_denominator_ctx = centerness_sum_denominator.as_in_context(
                        data.context)
                    loc_preds, cls_preds = net(data)
                    iou_loss = mobula.op.IoULoss(loc_preds[:, :4],
                                                 targets[:, 1:5],
                                                 axis=1)
                    iou_loss = (iou_loss * targets[:, 5:6] /
                                centerness_sum_denominator_ctx)
                    # iou_loss = IoULoss()(loc_preds[:, :4].exp(), targets[:, 1:5]) * targets[:, 5] / centerness_sum_denominator_ctx
                    loss_center = (mobula.op.BCELoss(loc_preds[:, 4], targets[:, 5]) *
                                   targets[:, 0] / num_pos_denominator_ctx)
                    loss_cls = mobula.op.FocalLoss(
                        alpha=.25,
                        gamma=2,
                        logits=cls_preds,
                        targets=targets[:, 6:]) / num_pos_denominator_ctx
                    loss_total = (loss_center.sum() + iou_loss.sum() +
                                  loss_cls.sum())
                    if config.TRAIN.USE_FP16:
                        with amp.scale_loss(loss_total,
                                            trainer) as scaled_losses:
                            ag.backward(scaled_losses)
                    else:
                        loss_total.backward()
                    losses_loc.append(iou_loss)
                    losses_center_ness.append(loss_center)
                    losses_cls.append(loss_cls)

            trainer.step(n_workers)
            if not config.use_hvd or hvd.local_rank() == 0:
                for l in losses_loc:
                    metric_loss_loc.update(None, l.sum())
                for l in losses_center_ness:
                    metric_loss_center.update(None, l.sum())
                for l in losses_cls:
                    metric_loss_cls.update(None, l.sum())
                if trainer.optimizer.num_update % config.TRAIN.log_interval == 0:
                    msg = "Epoch={},Step={},lr={}, ".format(
                        epoch, trainer.optimizer.num_update,
                        trainer.learning_rate)
                    msg += ','.join([
                        '{}={:.3f}'.format(w, v)
                        for w, v in zip(*eval_metrics.get())
                    ])
                    logging.info(msg)
                    eval_metrics.reset()
                if trainer.optimizer.num_update % 5000 == 0:
                    save_path = os.path.join(
                        config.TRAIN.log_path,
                        "{}-{}.params".format(epoch,
                                              trainer.optimizer.num_update))
                    net.collect_params().save(save_path)
                    logging.info("Saved checkpoint to {}".format(save_path))
                    trainer_path = save_path + "-trainer.states"
                    trainer.save_states(trainer_path)

        if not config.use_hvd or hvd.local_rank() == 0:
            save_path = os.path.join(config.TRAIN.log_path,
                                     "{}.params".format(epoch))
            net.collect_params().save(save_path)
            logging.info("Saved checkpoint to {}".format(save_path))
            trainer_path = save_path + "-trainer.states"
            trainer.save_states(trainer_path)
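
The center-ness term above is computed by mobula.op.BCELoss on raw logits. Assuming
that op implements the usual "binary cross-entropy with logits" (an assumption; only
the call site is visible here), the element-wise formula it corresponds to can be
sketched as:

import mxnet as mx

def bce_with_logits(logits, targets):
    # Numerically stable form of -t*log(sigmoid(x)) - (1-t)*log(1-sigmoid(x)):
    # max(x, 0) - x*t + log(1 + exp(-|x|))
    return (mx.nd.relu(logits) - logits * targets +
            mx.nd.log(1 + mx.nd.exp(-mx.nd.abs(logits))))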