def main():

    # Function to build the PASCAL VOC train/val iterators for a given rank
    def get_voc_iterator(rank, num_workers, net, num_shards):
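        """Build sharded PASCAL VOC train/val DataLoaders for one worker.

        Tries to fetch the VOC tarballs from the --s3bucket bucket first and
        falls back to the public mirrors via download_voc() on any failure.
        Returns the train loader, the val loader, and the VOC07 mAP metric.
        """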
        data_dir = "data-%d" % rank
        try:
            s3_client = boto3.client('s3')
            for file in [
                    'VOCtrainval_06-Nov-2007.tar', 'VOCtest_06-Nov-2007.tar',
                    'VOCtrainval_11-May-2012.tar'
            ]:
                s3_client.download_file(args.s3bucket, f'voc_tars/{file}',
                                        f'/opt/ml/code/{file}')
                with tarfile.open(f'/opt/ml/code/{file}') as tar:
                    tar.extractall(path=f'/opt/ml/code/{data_dir}')
        except Exception as e:
            print(f'S3 download failed ({e}); downloading from source')
            download_voc(data_dir)

        batch_size = args.batch_size

        # The Gluon DataLoader below could be replaced with
        # mx.io.ImageDetRecordIter, but that requires packing the data into
        # RecordIO format first.
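        # A minimal sketch of that alternative (hypothetical paths; assumes
        # the tarballs were repacked with MXNet's im2rec.py tooling):
        # train_iter = mx.io.ImageDetRecordIter(
        #     path_imgrec='%s/train.rec' % data_dir,
        #     batch_size=batch_size,
        #     data_shape=(3, 600, 600),
        #     shuffle=True,
        #     num_parts=hvd.size(),
        #     part_index=hvd.rank())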

        train_dataset = gdata.VOCDetection(
            root=f'/opt/ml/code/data-{rank}/VOCdevkit/',
            splits=[(2007, 'trainval'), (2012, 'trainval')])
        val_dataset = gdata.VOCDetection(
            root=f'/opt/ml/code/data-{rank}/VOCdevkit/',
            splits=[(2007, 'test')])
        val_metric = VOC07MApMetric(iou_thresh=0.5,
                                    class_names=val_dataset.classes)
        im_aspect_ratio = [1.] * len(train_dataset)
        train_bfn = FasterRCNNTrainBatchify(net)
        train_sampler = gluoncv.nn.sampler.SplitSortedBucketSampler(
            im_aspect_ratio,
            batch_size,
            num_parts=hvd.size() if args.horovod else 1,
            part_index=hvd.rank() if args.horovod else 0,
            shuffle=True)
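        # SplitSortedBucketSampler buckets samples of similar aspect ratio
        # into the same batch and shards batches across Horovod workers via
        # num_parts/part_index; with im_aspect_ratio stubbed to all 1.0 the
        # bucketing is effectively a no-op and only the sharding matters.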
        # multi_stage=True caused issues in this setup, so it is disabled here
        train_iter = mx.gluon.data.DataLoader(train_dataset.transform(
            FasterRCNNDefaultTrainTransform(net.short,
                                            net.max_size,
                                            net,
                                            ashape=net.ashape,
                                            multi_stage=False)),
                                              batch_sampler=train_sampler,
                                              batchify_fn=train_bfn,
                                              num_workers=num_workers)

        val_bfn = Tuple(*[Append() for _ in range(3)])
        short = net.short[-1] if isinstance(net.short,
                                            (tuple, list)) else net.short
        # validation use 1 sample per device
        val_iter = mx.gluon.data.DataLoader(val_dataset.transform(
            FasterRCNNDefaultValTransform(short, net.max_size)),
                                            batch_size=num_shards,
                                            shuffle=False,
                                            batchify_fn=val_bfn,
                                            last_batch='keep',
                                            num_workers=num_workers)

        return train_iter, val_iter, val_metric

    # Function to define neural network
    def conv_nets(model_name):
        net = model_zoo.get_model(model_name, pretrained_base=False)
        return net

    def evaluate(net, val_data, ctx, eval_metric, args):
        """Test on validation dataset."""
        clipper = gcv.nn.bbox.BBoxClipToImage()
        eval_metric.reset()
        if not args.disable_hybridization:
            # the input format differs from training, so re-hybridization is needed.
            net.hybridize(static_alloc=args.static_alloc)
        for batch in val_data:
            batch = split_and_load(batch, ctx_list=ctx)
            det_bboxes = []
            det_ids = []
            det_scores = []
            gt_bboxes = []
            gt_ids = []
            gt_difficults = []
            for x, y, im_scale in zip(*batch):
                # get prediction results
                ids, scores, bboxes = net(x)
                det_ids.append(ids)
                det_scores.append(scores)
                # clip to image size
                det_bboxes.append(clipper(bboxes, x))
                # rescale to original resolution
                im_scale = im_scale.reshape((-1)).asscalar()
                det_bboxes[-1] *= im_scale
                # split ground truths
                gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
                gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
                gt_bboxes[-1] *= im_scale
                gt_difficults.append(
                    y.slice_axis(axis=-1, begin=5, end=6
                                 ) if y.shape[-1] > 5 else None)

            # update metric
            for det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff in zip(
                    det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids,
                    gt_difficults):
                eval_metric.update(det_bbox, det_id, det_score, gt_bbox, gt_id,
                                   gt_diff)
        return eval_metric.get()

    # Initialize Horovod
    hvd.init()

    # Horovod: pin context to local rank
    if args.horovod:
        ctx = [mx.gpu(hvd.local_rank())]
    else:
        ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
        ctx = ctx if ctx else [mx.cpu()]
    # `ctx` is the per-process device list used for data splitting; `context`
    # is the single device used to initialize parameters.
    context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(
        hvd.local_rank())
    num_workers = hvd.size()  # reused below as the DataLoader worker count

    # Build model
    model = conv_nets(args.model_name)
    model.cast(args.dtype)
    model.hybridize()

    # Initialize parameters
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type="in",
                                 magnitude=2)
    model.initialize(initializer, ctx=context)

    # Create optimizer
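    # Horovod convention: scale the base learning rate linearly with the
    # number of workers, since the effective global batch grows with
    # hvd.size() (the "linear scaling rule").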
    optimizer_params = {
        'momentum': args.momentum,
        'learning_rate': args.lr * hvd.size()
    }
    opt = mx.optimizer.create('sgd', **optimizer_params)

    # Load training and validation data (plus the VOC07 mAP metric)
    train_data, val_data, eval_metric = get_voc_iterator(
        hvd.rank(), num_workers, model, len(ctx))

    # Horovod: fetch and broadcast parameters
    params = model.collect_params()
    hvd.broadcast_parameters(params, root_rank=0)

    # Horovod: create DistributedTrainer, a subclass of gluon.Trainer
    trainer = hvd.DistributedTrainer(params, opt)
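    # DistributedTrainer allreduce-averages gradients across workers before
    # each update, so every rank applies an identical step.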

    # Create the Faster R-CNN loss functions and training metrics
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(
        from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(
        rho=args.rpn_smoothl1_rho)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(
        rho=args.rcnn_smoothl1_rho)  # == smoothl1
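    # HuberLoss with threshold rho reproduces SmoothL1: quadratic for
    # residuals below rho, linear above it.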
    metrics = [
        mx.metric.Loss('RPN_Conf'),
        mx.metric.Loss('RPN_SmoothL1'),
        mx.metric.Loss('RCNN_CrossEntropy'),
        mx.metric.Loss('RCNN_SmoothL1'),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [
        rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric
    ]
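    # `metrics` tracks the four raw training losses; `metrics2` adds accuracy
    # and L1 diagnostics for the RPN and RCNN heads.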


    # Global training timing
    if hvd.rank() == 0:
        global_tic = time.time()

    # Set up the logger and best-mAP tracker used by the training loop
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    best_map = [0]

    # Train model (loop adapted from GluonCV's train_faster_rcnn.py)
    # lr decay policy (compute once; recreating lr_steps inside the epoch
    # loop would re-apply the decay every epoch)
    lr_decay = float(args.lr_decay)
    lr_steps = sorted(
        [float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid integer division
    for epoch in range(args.epochs):
        # this simplifies dealing with all of the loss functions
        rcnn_task = ForwardBackwardTask(model,
                                        trainer,
                                        rpn_cls_loss,
                                        rpn_box_loss,
                                        rcnn_cls_loss,
                                        rcnn_box_loss,
                                        mix_ratio=1.0,
                                        amp_enabled=args.amp)
        executor = Parallel(args.executor_threads,
                            rcnn_task) if not args.horovod else None
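        # Parallel pipelines one forward/backward worker per shard; under
        # Horovod each process drives a single GPU, so no executor is needed.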
        mix_ratio = 1.0
        if not args.disable_hybridization:
            model.hybridize(static_alloc=args.static_alloc)
        if args.mixup:
            # TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise
            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
            mix_ratio = 0.5
            if epoch >= args.epochs - args.no_mixup_epochs:
                train_data._dataset._data.set_mixup(None)
                mix_ratio = 1.0
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(
                epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio
        for i, batch in enumerate(train_data):
            if epoch == 0 and lr_warmup and i <= lr_warmup:
                # warm up: ramp the learning rate from a fraction of base_lr
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup,
                                                  args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.
                            format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(
                batch, ctx_list=ctx
            )  # does split and load function, creates a batch per device
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    # the executor path is skipped only under Horovod, where
                    # each process owns exactly one context, so shard 0 is
                    # the whole per-process batch
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(args.batch_size)

            # update metrics
            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join([
                    '{}={:.3f}'.format(*metric.get())
                    for metric in metrics + metrics2
                ])
                logger.info(
                    '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.
                    format(
                        epoch, i, args.log_interval * args.batch_size /
                        (time.time() - btic), msg))
                btic = time.time()

        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(
                ['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reducing the validation frequency to save time
                map_name, mean_ap = evaluate(model, val_data, ctx, eval_metric,
                                             args)
                val_msg = '\n'.join(
                    ['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(
                    epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(model, logger, best_map, current_map, epoch,
                        args.save_interval, args.save_prefix)

    if hvd.rank() == 0:
        global_training_time = time.time() - global_tic
        print("Global elapsed time on training: {}".format(global_training_time))

def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args):
    """Training pipeline"""
    args.kv_store = "device" if (args.amp and "nccl" in args.kv_store) else args.kv_store
    kv = mx.kvstore.create(args.kv_store)
    net.collect_params().setattr("grad_req", "null")
    net.collect_train_params().setattr("grad_req", "write")
    optimizer_params = {"learning_rate": args.lr, "wd": args.wd, "momentum": args.momentum}
    if args.amp:
        optimizer_params["multi_precision"] = True
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
        )
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv,
        )

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(",") if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    # TODO(zhreshold) losses?
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.0)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(rho=1.0)  # == smoothl1
    metrics = [
        mx.metric.Loss("RPN_Conf"),
        mx.metric.Loss("RPN_SmoothL1"),
        mx.metric.Loss("RCNN_CrossEntropy"),
        mx.metric.Loss("RCNN_SmoothL1"),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric]

    logger.info(args)

    if args.verbose:
        logger.info("Trainable parameters:")
        logger.info(net.collect_train_params().keys())
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        rcnn_task = ForwardBackwardTask(
            net,
            trainer,
            rpn_cls_loss,
            rpn_box_loss,
            rcnn_cls_loss,
            rcnn_box_loss,
            mix_ratio=1.0,
            amp_enabled=args.amp,
        )
        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
        mix_ratio = 1.0
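        # hybridize() caches a static computation graph for faster execution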
        net.hybridize()

        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio
        for i, batch in enumerate(train_data):
            if epoch == 0 and lr_warmup and i <= lr_warmup:
                # adjust based on the fraction of warmup iterations completed
                new_lr = base_lr * get_lr_at_iter(
                    i / lr_warmup, args.lr_warmup_factor / args.num_gpus
                )
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            "[Epoch 0 Iteration {}] Set learning rate to {}".format(i, new_lr)
                        )
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
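            # step(batch_size) normalizes the accumulated gradients by
            # 1/batch_size before applying the SGD update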
            trainer.step(batch_size)

            # update metrics
            if (
                (not args.horovod or hvd.rank() == 0)
                and args.log_interval
                and not (i + 1) % args.log_interval
            ):
                msg = ",".join(["{}={:.3f}".format(*metric.get()) for metric in metrics + metrics2])
                logger.info(
                    "[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}".format(
                        epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg
                    )
                )
                btic = time.time()

        if (not args.horovod) or hvd.rank() == 0:
            msg = ",".join(["{}={:.3f}".format(*metric.get()) for metric in metrics])
            logger.info(
                "[Epoch {}] Training cost: {:.3f}, {}".format(epoch, (time.time() - tic), msg)
            )
            if not (epoch + 1) % args.val_interval:
                # consider reducing the validation frequency to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
                val_msg = "\n".join(["{}={}".format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info("[Epoch {}] Validation: \n{}".format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.0
            save_params(
                net,
                logger,
                best_map,
                current_map,
                epoch,
                args.save_interval,
                os.path.join(args.sm_save, args.save_prefix),
                args,
            )