def main():
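    # Enable the built-in profiler (profiling all components, with aggregated
    # statistics) before any training work is dispatched.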
    if opt.builtin_profiler > 0:
        profiler.set_config(profile_all=True, aggregate_stats=True)
        profiler.set_state('run')
    if opt.mode == 'symbolic':
        data = mx.sym.var('data')
        out = net(data)
        softmax = mx.sym.SoftmaxOutput(out, name='softmax')
        mod = mx.mod.Module(softmax, context=[mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()])
        kv = mx.kv.create(opt.kvstore)
        train_data, val_data = get_data_iters(dataset, batch_size, kv.num_workers, kv.rank)
        mod.fit(train_data,
                eval_data=val_data,
                num_epoch=opt.epochs,
                kvstore=kv,
                batch_end_callback=mx.callback.Speedometer(batch_size, max(1, opt.log_interval)),
                epoch_end_callback=mx.callback.do_checkpoint('image-classifier-%s' % opt.model),
                optimizer='sgd',
                optimizer_params={'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum, 'multi_precision': True},
                initializer=mx.init.Xavier(magnitude=2))
        mod.save_params('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
    else:
        if opt.mode == 'hybrid':
            net.hybridize()
        train(opt, context)
    if opt.builtin_profiler > 0:
        profiler.set_state('stop')
        print(profiler.dumps())
Example #2
def test_continuous_profile_and_instant_marker():
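    # Run several profiled iterations, marking each with an instant marker and
    # verifying that continuous dumping keeps growing the output file.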
    enable_profiler(True, True, True)
    python_domain = profiler.Domain('PythonDomain::test_continuous_profile')
    last_file_size = 0
    for i in range(5):
        profiler.Marker(python_domain, "StartIteration-" + str(i)).mark('process')
        print("{}...".format(i))
        test_profile_event(False)
        test_profile_counter(False)
        profiler.dump(False)
        # File size should keep increasing
        new_file_size = os.path.getsize("test_profile.json")
        assert new_file_size >= last_file_size
        last_file_size = new_file_size
    profiler.dump(False)
    debug_str = profiler.dumps()
    assert len(debug_str) > 0
    print(debug_str)
    profiler.set_state('stop')
Example #3
def train(
    args,
    model,
    train_sampler,
    valid_samplers=None,
    rank=0,
    rel_parts=None,
    barrier=None,
):
    assert args.num_proc <= 1, "MXNet KGE does not support multi-process training yet"
    assert not args.rel_part, \
        "No need for relation partition in single process for MXNet KGE"
    logs = []

    for arg in vars(args):
        logging.info("{:20}:{}".format(arg, getattr(args, arg)))

    if len(args.gpu) > 0:
        gpu_id = (args.gpu[rank % len(args.gpu)]
                  if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0])
    else:
        gpu_id = -1

    if args.strict_rel_part:
        model.prepare_relation(mx.gpu(gpu_id))

    if mxprofiler:
        from mxnet import profiler

        profiler.set_config(
            profile_all=True,
            aggregate_stats=True,
            continuous_dump=True,
            filename="profile_output.json",
        )
    start = time.time()
    for step in range(0, args.max_step):
        pos_g, neg_g = next(train_sampler)
        args.step = step
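        # Start profiling at step 1 so the first (warm-up) step is excluded.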
        if step == 1 and mxprofiler:
            profiler.set_state("run")
        with mx.autograd.record():
            loss, log = model.forward(pos_g, neg_g, gpu_id)
        loss.backward()
        logs.append(log)
        model.update(gpu_id)

        if step % args.log_interval == 0:
            for k in logs[0].keys():
                v = sum(l[k] for l in logs) / len(logs)
                print("[Train]({}/{}) average {}: {}".format(
                    step, args.max_step, k, v))
            logs = []
            print(time.time() - start)
            start = time.time()

        if (args.valid and step % args.eval_interval == 0 and step > 1
                and valid_samplers is not None):
            start = time.time()
            test(args, model, valid_samplers, mode="Valid")
            print("test:", time.time() - start)
    if args.strict_rel_part:
        model.writeback_relation(rank, rel_parts)
    if mxprofiler:
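        # Drain all pending asynchronous work so the profile captures complete timings.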
        nd.waitall()
        profiler.set_state("stop")
        profiler.dump()
        print(profiler.dumps())
    # clear cache
    logs = []
Example #4
    # NOTE: this snippet begins mid-call in the original; the Module
    # constructor head below is an assumed completion for a symbol `net`.
    mod = mx.mod.Module(symbol=net,
                        context=mx.gpu(0),
                        data_names=['data'],
                        label_names=['softmax_label'])

    # allocate memory given the input data and label shapes
    mod.bind(data_shapes=train_iter.provide_data,
             label_shapes=train_iter.provide_label)
    # initialize parameters by uniform random numbers
    mod.init_params(initializer=mx.init.Uniform(scale=.1))
    # use SGD with learning rate 0.1 to train
    mod.init_optimizer(optimizer='sgd',
                       optimizer_params={'learning_rate': 0.1})
    # use accuracy as the metric
    metric = mx.metric.create('acc')

    train_iter.reset()
    metric.reset()

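    # Profile a single training batch: start the profiler, run one
    # forward/backward/update cycle, then stop and print the statistics.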
    profiler.set_state('run')
    for batch in train_iter:
        mod.forward(batch, is_train=True)  # compute predictions
        mod.update_metric(metric,
                          batch.label)  # accumulate prediction accuracy
        mod.backward()  # compute gradients
        mod.update()  # update parameters
        break

    mx.nd.waitall()
    profiler.set_state('stop')
    print(profiler.dumps())
Example #5
def custom_operator_profiling_multiple_custom_ops(seed, mode, file_name):
    class MyAdd(mx.operator.CustomOp):
        def forward(self, is_train, req, in_data, out_data, aux):
            self.assign(out_data[0], req[0], in_data[0] + 1)

        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
            self.assign(in_grad[0], req[0], out_grad[0])

    @mx.operator.register('MyAdd1')
    class MyAdd1Prop(mx.operator.CustomOpProp):
        def __init__(self):
            super(MyAdd1Prop, self).__init__(need_top_grad=True)

        def list_arguments(self):
            return ['data']

        def list_outputs(self):
            return ['output']

        def infer_shape(self, in_shape):
            # inputs, outputs, aux
            return [in_shape[0]], [in_shape[0]], []

        def create_operator(self, ctx, shapes, dtypes):
            return MyAdd()

    @mx.operator.register('MyAdd2')
    class MyAdd2Prop(mx.operator.CustomOpProp):
        def __init__(self):
            super(MyAdd2Prop, self).__init__(need_top_grad=True)

        def list_arguments(self):
            return ['data']

        def list_outputs(self):
            return ['output']

        def infer_shape(self, in_shape):
            # inputs, outputs, aux
            return [in_shape[0]], [in_shape[0]], []

        def create_operator(self, ctx, shapes, dtypes):
            return MyAdd()

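    # Start the profiler (continuous dump + aggregate stats) and clear any
    # previously aggregated statistics before running the custom ops.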
    enable_profiler(profile_filename=file_name, run=True, continuous_dump=True,
                    aggregate_stats=True)
    # clear aggregate stats
    profiler.dumps(reset=True)
    inp = mx.nd.zeros(shape=(100, 100))
    if mode == 'imperative':
        y = mx.nd.Custom(inp, op_type='MyAdd1')
        z = mx.nd.Custom(inp, op_type='MyAdd2')
    elif mode == 'symbolic':
        a = mx.symbol.Variable('a')
        b = mx.symbol.Custom(data=a, op_type='MyAdd1')
        c = mx.symbol.Custom(data=a, op_type='MyAdd2')
        y = b.bind(mx.cpu(), {'a': inp})
        z = c.bind(mx.cpu(), {'a': inp})
        yy = y.forward()
        zz = z.forward()
    mx.nd.waitall()
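    # Dump without marking profiling as finished, then verify the aggregate
    # stats contain per-custom-op entries.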
    profiler.dump(False)
    debug_str = profiler.dumps(format='json')
    check_custom_operator_profiling_multiple_custom_ops_output(debug_str)
    profiler.set_state('stop')
Example #6
def test_custom_operator_profiling_multiple_custom_ops_imperative(
        seed=None, mode='imperative', file_name=None):
    class MyAdd(mx.operator.CustomOp):
        def forward(self, is_train, req, in_data, out_data, aux):
            self.assign(out_data[0], req[0], in_data[0] + 1)

        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
            self.assign(in_grad[0], req[0], out_grad[0])

    @mx.operator.register('MyAdd1')
    class MyAdd1Prop(mx.operator.CustomOpProp):
        def __init__(self):
            super(MyAdd1Prop, self).__init__(need_top_grad=True)

        def list_arguments(self):
            return ['data']

        def list_outputs(self):
            return ['output']

        def infer_shape(self, in_shape):
            # inputs, outputs, aux
            return [in_shape[0]], [in_shape[0]], []

        def create_operator(self, ctx, shapes, dtypes):
            return MyAdd()

    @mx.operator.register('MyAdd2')
    class MyAdd2Prop(mx.operator.CustomOpProp):
        def __init__(self):
            super(MyAdd2Prop, self).__init__(need_top_grad=True)

        def list_arguments(self):
            return ['data']

        def list_outputs(self):
            return ['output']

        def infer_shape(self, in_shape):
            # inputs, outputs, aux
            return [in_shape[0]], [in_shape[0]], []

        def create_operator(self, ctx, shapes, dtypes):
            return MyAdd()

    if file_name is None:
        file_name = 'test_custom_operator_profiling_multiple_custom_ops_imperative.json'
    enable_profiler(profile_filename=file_name, run=True, continuous_dump=True,
                    aggregate_stats=True)
    inp = mx.nd.zeros(shape=(100, 100))
    if mode == 'imperative':
        x = inp + 1
        y = mx.nd.Custom(inp, op_type='MyAdd1')
        z = mx.nd.Custom(inp, op_type='MyAdd2')
    elif mode == 'symbolic':
        a = mx.symbol.Variable('a')
        b = a + 1
        c = mx.symbol.Custom(data=a, op_type='MyAdd1')
        d = mx.symbol.Custom(data=a, op_type='MyAdd2')
        b.bind(mx.cpu(), {'a': inp}).forward()
        c.bind(mx.cpu(), {'a': inp}).forward()
        d.bind(mx.cpu(), {'a': inp}).forward()
    mx.nd.waitall()
    profiler.dump(False)
    debug_str = profiler.dumps(format='json')
    target_dict = json.loads(debug_str)
    '''
    We call _plus_scalar within MyAdd1 and MyAdd2 as well as outside both
    custom operators, so the aggregate stats should contain three different
    kinds of _plus_scalar under the domains "Custom Operator" and "operator".
    '''
    assert 'Time' in target_dict and 'Custom Operator' in target_dict['Time'] \
        and 'MyAdd1::pure_python' in target_dict['Time']['Custom Operator'] \
        and 'MyAdd2::pure_python' in target_dict['Time']['Custom Operator'] \
        and 'MyAdd1::_plus_scalar' in target_dict['Time']['Custom Operator'] \
        and 'MyAdd2::_plus_scalar' in target_dict['Time']['Custom Operator'] \
        and '_plus_scalar' not in target_dict['Time']['Custom Operator'] \
        and 'operator' in target_dict['Time'] \
        and '_plus_scalar' in target_dict['Time']['operator']
    profiler.set_state('stop')
Example #7
def _save_profile(self):
    if self._profile:
        print(profiler.dumps())
        profiler.dump()
Example #8
def train(args):
    np.random.seed(args.seed)
    if args.gpu:
        ctx = [mx.gpu(0)]
    else:
        ctx = [mx.cpu(0)]
    if args.dataset == "Sony":
        out_channels = 12
        scale = 2
    else:
        out_channels = 27
        scale = 3

    # load data
    train_transform = utils.Compose([
        utils.RandomCrop(args.patch_size, scale),
        utils.RandomFlipLeftRight(),
        utils.RandomFlipTopBottom(),
        utils.RandomTranspose(),
        utils.ToTensor(),
    ])
    train_dataset = data.MyDataset(args.dataset,
                                   "train",
                                   transform=train_transform)
    val_transform = utils.Compose([utils.ToTensor()])
    val_dataset = data.MyDataset(args.dataset, "val", transform=val_transform)
    train_loader = gluon.data.DataLoader(train_dataset,
                                         shuffle=True,
                                         batch_size=args.batch_size,
                                         last_batch='rollover')
    val_loader = gluon.data.DataLoader(val_dataset,
                                       batch_size=1,
                                       last_batch='discard')
    unet = net.UNet(out_channels, scale)
    unet.initialize(init=initializer.Xavier(), ctx=ctx)

    # optimizer and loss
    trainer = gluon.Trainer(unet.collect_params(), 'adam',
                            {'learning_rate': args.lr})
    l1_loss = gluon.loss.L1Loss()

    print "Start training now.."
    for i in range(args.epochs):
        total_loss = 0
        count = 0
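        # Start profiling at the top of the epoch; it is stopped (and the
        # batch loop exited) right after the first batch below.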
        profiler.set_state('run')
        for batch_id, (img, gt) in enumerate(train_loader):
            batch_size = img.shape[0]
            count += batch_size
            img_list = gluon.utils.split_and_load(img[0], ctx)
            gt_list = gluon.utils.split_and_load(gt[0], ctx)
            with autograd.record():
                preds = [unet(x) for x in img_list]
                losses = []
                for ii in range(len(preds)):
                    loss = l1_loss(gt_list[ii], preds[ii])
                    losses.append(loss)
            for loss in losses:
                loss.backward()
            total_loss += sum([l.sum().asscalar() for l in losses])
            avg_loss = total_loss / count
            trainer.step(batch_size)
            metric.update(gt_list, preds)  # `metric` is presumably defined at module level
            F.waitall()  # F is presumably mx.nd in this snippet
            profiler.set_state('stop')
            print(profiler.dumps())
            break  # profile only the first batch; the rest of the loop body never runs
            gt_save = gt_list[0]
            output_save = preds[0]

            if (batch_id + 1) % 100 == 0:
                message = "Epoch {}: [{}/{}]: l1_loss: {:.4f}".format(
                    i + 1, count, len(train_dataset), avg_loss)
                print(message)
        temp = F.concat(gt_save, output_save, dim=3)
        temp = temp.asnumpy().reshape(temp.shape[2], temp.shape[3], 3)
        scipy.misc.toimage(temp * 255,
                           high=255,
                           low=0,
                           cmin=0,
                           cmax=255,
                           mode='RGB').save(args.save_model_dir +
                                            '%04d_%05d_00_train.jpg' %
                                            (i + 1, count))

        # evaluate
        batches = 0
        avg_psnr = 0.
        for img, gt in val_loader:
            batches += 1
            imgs = gluon.utils.split_and_load(img[0], ctx)
            label = gluon.utils.split_and_load(gt[0], ctx)
            outputs = []
            for x in imgs:
                outputs.append(unet(x))
            metric.update(label, outputs)
            avg_psnr += 10 * math.log10(1 / metric.get()[1])
            metric.reset()
        avg_psnr /= batches
        print('Epoch {}: validation avg psnr: {:.3f}'.format(i + 1, avg_psnr))

        # save model
        if (i + 1) % args.save_freq == 0:
            save_model_filename = "Epoch_" + str(i + 1) + ".params"
            save_model_path = os.path.join(args.save_model_dir,
                                           save_model_filename)
            unet.save_params(save_model_path)
            print("\nCheckpoint, trained model saved at", save_model_path)

    # save model
    save_model_filename = "Final_Epoch_" + str(i + 1) + ".params"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    unet.save_params(save_model_path)
    print("\nCheckpoint, trained model saved at", save_model_path)
Example #9
    def run(self):
        # Helper methods
        def get_random_lot(data_loader):
            return next(iter(data_loader))

        # Data importing, pre-processing, and loading
        (num_training_examples, num_testing_examples, train_data_lot_iterator,
         train_data_eval_iterator, test_data) = self._load_data()
        # parameters calculated from loaded data
        self._num_training_examples = num_training_examples
        self._num_testing_examples = num_testing_examples
        self._hyperparams[
            'sample_fraction'] = self._lot_size / num_training_examples
        rounds_per_epoch = round(num_training_examples / self._lot_size)

        # Set up privacy accountant
        accountant = rdp_acct.anaRDPacct()  # dpacct.anaCGFAcct()
        eps_sequence = []

        # Network structure creation
        self._create_network_params()

        # Loss function
        loss_func = self._get_loss_func()

        # Optimization procedure
        trainer = self._optimizer(self._hyperparams, self._net, self._params,
                                  loss_func, self._model_ctx, accountant)

        # begin profiling if enabled
        if self._enable_mxnet_profiling:
            from mxnet import profiler
            profiler.set_config(profile_all=True,
                                aggregate_stats=True,
                                filename='profile_output.json')
            profiler.set_state('run')

        # Training sequence
        rounds = round(self._epochs * rounds_per_epoch)
        loss_sequence = []
        current_epoch_loss = mx.nd.zeros(1, ctx=self._model_ctx)
        for t in range(1, rounds + 1):
            if self._verbose and self._print_epoch_status:
                # show current epoch progress
                epoch_number = 1 + (t - 1) // rounds_per_epoch
                epoch_progress = 1 + (t - 1) % rounds_per_epoch
                printProgressBar(
                    epoch_progress,
                    rounds_per_epoch,
                    prefix='Epoch {} progress:'.format(epoch_number),
                    length=50)

            if self._run_training:
                # prepare random lot of data for DPSGD step
                data, labels = get_random_lot(train_data_lot_iterator)
                data = data.as_in_context(self._model_ctx).reshape(
                    (-1, 1, self._input_layer))
                labels = labels.as_in_context(self._model_ctx)
            else:
                data, labels = [], []

            # perform DPSGD step
            lot_mean_loss = trainer.step(
                data,
                labels,
                accumulate_privacy=self._accumulate_privacy,
                run_training=self._run_training)

            loss_sequence.append(lot_mean_loss)
            current_epoch_loss += lot_mean_loss

            # no need to continue running training if NaNs are present
            if not np.isfinite(lot_mean_loss):
                self._run_training = False
                if self._verbose: print("NaN loss on round {}.".format(t))
            if self._params_not_finite():
                self._run_training = False
                if self._verbose:
                    print("Non-finite parameters on round {}.".format(t))

            if self._accumulate_privacy and self._debugging:
                eps_sequence.append(accountant.get_eps(self._fixed_delta))

            # print some stats after an "epoch"
            if t % rounds_per_epoch == 0:
                if self._verbose:
                    print("Epoch {}  (round {})  complete.".format(
                        t / rounds_per_epoch, t))
                    if self._run_training:
                        print("mean epoch loss: {}".format(
                            current_epoch_loss.asscalar() * self._lot_size /
                            self._num_training_examples))
                        if self._compute_epoch_accuracy:
                            print("training accuracy: {}".format(
                                self._evaluate_accuracy(
                                    train_data_eval_iterator)))
                            print("testing accuracy: {}".format(
                                self._evaluate_accuracy(test_data)))
                    if self._accumulate_privacy and self._debugging:
                        print("eps used: {}\n".format(eps_sequence[-1]))
                    print()
                current_epoch_loss = mx.nd.zeros(1, ctx=self._model_ctx)

        # end profiling if enabled
        if self._enable_mxnet_profiling:
            mx.nd.waitall()
            profiler.set_state('stop')
            print(profiler.dumps())

        # Make sure we don't report a bogus number
        if self._accumulate_privacy:
            final_eps = accountant.get_eps(self._fixed_delta)
        else:
            final_eps = -1

        test_accuracy = self._evaluate_accuracy(test_data)

        if self._save_plots or self._debugging:
            self._create_and_save_plots(t, eps_sequence, loss_sequence,
                                        final_eps, test_accuracy)

        return final_eps, test_accuracy
Example #10
def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
    """Training pipeline"""
    print("rank:{}, training...".format(
        kv.rank)) if "perseus" in args.kv_store else None

    if args.profiler == "1":
        # profiler config
        profiler.set_config(profile_all=True,
                            aggregate_stats=True,
                            continuous_dump=True,
                            filename='profile_output_{}.json'.format(
                                kv.rank if "perseus" in args.kv_store else 0))
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'momentum': args.momentum
    }
    if args.amp:
        optimizer_params['multi_precision'] = True
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params)
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params,
            update_on_kvstore=None,
            kvstore=kv)

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted(
        [float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    # TODO(zhreshold) losses?
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(
        from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(
        rho=args.rpn_smoothl1_rho)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(
        rho=args.rcnn_smoothl1_rho)  # == smoothl1
    metrics = [
        mx.metric.Loss('RPN_Conf'),
        mx.metric.Loss('RPN_SmoothL1'),
        mx.metric.Loss('RCNN_CrossEntropy'),
        mx.metric.Loss('RCNN_SmoothL1'),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [
        rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric
    ]

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    if args.custom_model:
        logger.info(
            'Custom model enabled. Expert Only!! Currently non-FPN model is not supported!!'
            ' Default setting is for MS-COCO.')
    logger.info(args)

    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        mix_ratio = 1.0
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        rcnn_task = ForwardBackwardTask(net,
                                        trainer,
                                        rpn_cls_loss,
                                        rpn_box_loss,
                                        rcnn_cls_loss,
                                        rcnn_box_loss,
                                        mix_ratio=1.0)
        if "perseus" in args.kv_store:
            args.executor_threads = 1
        executor = Parallel(args.executor_threads, rcnn_task) if (
            not args.horovod and "perseus" not in args.kv_store) else None
        if args.mixup:
            # TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise
            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
            mix_ratio = 0.5
            if epoch >= args.epochs - args.no_mixup_epochs:
                train_data._dataset._data.set_mixup(None)
                mix_ratio = 1.0
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(
                epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio

        if args.profiler == "1":
            # start profiling for this epoch
            profiler.set_state('run')

        for i, batch in enumerate(train_data):
            if epoch == 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup,
                                                  args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.
                            format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])

                if (not args.horovod) or hvd.rank() == 0:
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])

            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])

            trainer.step(batch_size)

            # update metrics
            if ((not args.horovod) or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join([
                    '{}={:.3f}'.format(*metric.get())
                    for metric in metrics + metrics2
                ])
                cur_rank = kv.rank if "perseus" in args.kv_store else 0
                logger.info(
                    'rank:{}, [Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'
                    .format(
                        cur_rank, epoch, i,
                        args.log_interval * batch_size / (time.time() - btic),
                        msg))
                btic = time.time()
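            # After 100 batches, stop profiling, print the collected stats,
            # and end the epoch early.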
            if i >= 100 and args.profiler == "1":
                profiler.set_state('stop')
                print(profiler.dumps())
                break

        if ((not args.horovod) and ("perseus" not in args.kv_store)) or (
                args.horovod and hvd.rank() == 0) or (
                    ("perseus" in args.kv_store) and kv.rank == 0):  # perseus
            msg = ','.join(
                ['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric,
                                             args)
                val_msg = '\n'.join(
                    ['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(
                    epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, logger, best_map, current_map, epoch,
                        args.save_interval, args.save_prefix)
        mx.nd.waitall()