Ejemplo n.º 1
0
def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args):
    """Training pipeline.

    Builds an SGD trainer over ``net.collect_train_params()`` (a
    ``hvd.DistributedTrainer`` when ``args.horovod`` is set, otherwise a
    ``gluon.Trainer``) and trains from ``args.start_epoch`` up to
    ``args.epochs``.  Every epoch applies the step decay configured by
    ``args.lr_decay_epoch``, runs each batch through ``ForwardBackwardTask``,
    updates the loss / accuracy metrics and logs throughput.  After the epoch
    it calls ``validate`` every ``args.val_interval`` epochs and
    ``save_params`` otherwise.

    Parameters
    ----------
    net
        Network to train.  All parameters get ``grad_req='null'`` except those
        returned by ``collect_train_params()``.
    train_data
        Sized iterable of batches; each batch is passed to ``split_and_load``.
    val_data
        Forwarded unchanged to ``validate``.
    eval_metric
        Forwarded unchanged to ``validate``.
    batch_size
        Value handed to ``trainer.step``.
    ctx
        List of contexts given to ``split_and_load``; one forward/backward
        result is collected per entry.
    logger
        Object whose ``info`` method receives all progress messages.
    args
        Option namespace (learning rate, decay schedule, AMP / Horovod flags,
        logging and checkpoint settings, ...).  ``args.kv_store`` is
        overwritten in place.
    """
    # With AMP enabled an 'nccl' kvstore is replaced by 'device'.
    args.kv_store = 'device' if (args.amp and 'nccl' in args.kv_store) else args.kv_store
    kv = mx.kvstore.create(args.kv_store)
    # Freeze everything first, then re-enable gradients for the trainable subset only.
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    # Bias parameters are excluded from weight decay.
    for k, v in net.collect_params('.*bias').items():
        v.wd_mult = 0.0
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum, }
    if args.clip_gradient > 0.0:
        optimizer_params['clip_gradient'] = args.clip_gradient
    if args.amp:
        optimizer_params['multi_precision'] = True
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params
        )
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv)

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=args.rpn_smoothl1_rho)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(rho=args.rcnn_smoothl1_rho)  # == smoothl1
    rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    # Running averages of the five loss terms, in the order the task returns them.
    metrics = [mx.metric.Loss('RPN_Conf'),
               mx.metric.Loss('RPN_SmoothL1'),
               mx.metric.Loss('RCNN_CrossEntropy'),
               mx.metric.Loss('RCNN_SmoothL1'),
               mx.metric.Loss('RCNN_Mask')]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    rcnn_mask_metric = MaskAccMetric()
    rcnn_fgmask_metric = MaskFGAccMetric()
    # Auxiliary metrics fed with (label, pred) pairs taken from the task result.
    metrics2 = [rpn_acc_metric, rpn_bbox_metric,
                rcnn_acc_metric, rcnn_bbox_metric,
                rcnn_mask_metric, rcnn_fgmask_metric]
    async_eval_processes = []
    logger.info(args)

    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    base_lr = trainer.learning_rate
    for epoch in range(args.start_epoch, args.epochs):
        rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
                                        rcnn_box_loss, rcnn_mask_loss, args.amp)
        # Under Horovod the task is called directly instead of through Parallel.
        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        # Apply every decay step whose epoch has been reached (also covers resumed runs).
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        # Per-epoch throughput samples; previously never initialised, which
        # raised NameError on the first logged interval.
        speed = []
        num_batches = len(train_data)
        train_data_iter = iter(train_data)
        next_data_batch = next(train_data_iter)
        next_data_batch = split_and_load(next_data_batch, ctx_list=ctx)
        for i in range(num_batches):
            batch = next_data_batch
            global_iter = i + epoch * num_batches
            # lr_warmup <= 0 means "no warm-up"; the explicit > 0 test avoids a
            # division by zero when it is exactly 0.
            if lr_warmup > 0 and global_iter <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(global_iter / lr_warmup,
                                                  args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    # log_interval == 0 disables logging (same convention as below).
                    if args.log_interval and i % args.log_interval == 0:
                        logger.info('[Epoch {} Iteration {}] Set learning rate to {}'
                                    .format(epoch, i, new_lr))
                    trainer.set_learning_rate(new_lr)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    # result holds the loss terms first, then the auxiliary metric inputs.
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            try:
                # prefetch next batch
                next_data_batch = next(train_data_iter)
                next_data_batch = split_and_load(next_data_batch, ctx_list=ctx)
            except StopIteration:
                # Last batch of the epoch: nothing left to prefetch.
                pass

            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)
            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics + metrics2])
                batch_speed = args.log_interval * args.batch_size / (time.time() - btic)
                speed.append(batch_speed)
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
                    epoch, i, batch_speed, msg))
                btic = time.time()
        # No interval may have been logged this epoch (logging disabled, or fewer
        # batches than log_interval); report NaN instead of failing on an unbound name.
        avg_batch_speed = sum(speed) / len(speed) if speed else float('nan')
        # validate and save params
        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, Speed: {:.3f} samples/sec, {}'.format(
                epoch, (time.time() - tic), avg_batch_speed, msg))
        if not (epoch + 1) % args.val_interval:
            # consider reduce the frequency of validation to save time
            validate(net, val_data, async_eval_processes, ctx, eval_metric, logger, epoch, best_map,
                     args)
        elif (not args.horovod) or hvd.rank() == 0:
            current_map = 0.
            save_params(net, logger, best_map, current_map, epoch, args.save_interval,
                        args.save_prefix)
    # Wait for any evaluation workers registered in async_eval_processes.
    for thread in async_eval_processes:
        thread.join()
Ejemplo n.º 2
0
def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
    """Training pipeline.

    Builds an SGD trainer over ``net.collect_train_params()`` (a
    ``hvd.DistributedTrainer`` when ``args.horovod`` is set, otherwise a
    ``gluon.Trainer``), configures a root logger that also writes to
    ``args.save_prefix + '_train.log'``, and trains from ``args.start_epoch``
    up to ``args.epochs``.  On rank 0 (or without Horovod) each epoch ends with
    an optional validation pass on both ``val_data`` and ``train_data`` and a
    call to ``save_params``.

    Parameters
    ----------
    net
        Network to train.  All parameters get ``grad_req='null'`` except those
        returned by ``collect_train_params()``.
    train_data
        Iterable of batches; each batch is passed to ``split_and_load``.  When
        ``args.mixup`` is set, ``train_data._dataset._data.set_mixup`` is called.
    val_data
        Forwarded to ``validate``.
    eval_metric
        Forwarded to ``validate``.
    batch_size
        Value handed to ``trainer.step``.
    ctx
        List of contexts given to ``split_and_load``; one forward/backward
        result is collected per entry.
    args
        Option namespace (learning rate, decay schedule, AMP / Horovod / mixup
        flags, logging and checkpoint settings, ...).
    """
    kv = mx.kvstore.create(args.kv_store)
    # Freeze everything first, then re-enable gradients for the trainable subset only.
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'momentum': args.momentum
    }
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(
            ),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params)
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(
            ),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv)

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted(
        [float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    # TODO(zhreshold) losses?
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(
        from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
    # Running averages of the four loss terms, in the order the task returns them.
    metrics = [
        mx.metric.Loss('RPN_Conf'),
        mx.metric.Loss('RPN_SmoothL1'),
        mx.metric.Loss('RCNN_CrossEntropy'),
        mx.metric.Loss('RCNN_SmoothL1'),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    # Auxiliary metrics fed with (label, pred) pairs taken from the task result.
    metrics2 = [
        rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric
    ]

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        mix_ratio = 1.0
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        rcnn_task = ForwardBackwardTask(net,
                                        trainer,
                                        rpn_cls_loss,
                                        rpn_box_loss,
                                        rcnn_cls_loss,
                                        rcnn_box_loss,
                                        mix_ratio=1.0)
        # Under Horovod the task is called directly instead of through Parallel.
        executor = Parallel(args.executor_threads,
                            rcnn_task) if not args.horovod else None
        if args.mixup:
            # TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise
            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
            mix_ratio = 0.5
            # Mixup is switched off again for the last args.no_mixup_epochs epochs.
            if epoch >= args.epochs - args.no_mixup_epochs:
                train_data._dataset._data.set_mixup(None)
                mix_ratio = 1.0
        # Apply every decay step whose epoch has been reached (also covers resumed runs).
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(
                epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio
        print(len(train_data))
        for i, batch in enumerate(train_data):
            # Warm-up only runs in epoch 0; the explicit > 0 test avoids a
            # division by zero when lr_warmup is exactly 0.
            if epoch == 0 and lr_warmup > 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup,
                                                  args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    # log_interval == 0 disables logging (same convention as below).
                    if args.log_interval and i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.
                            format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    # result holds the loss terms first, then the auxiliary metric inputs.
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)

            # update metrics
            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join([
                    '{}={:.3f}'.format(*metric.get())
                    for metric in metrics + metrics2
                ])
                logger.info(
                    '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.
                    format(
                        epoch, i, args.log_interval * args.batch_size /
                        (time.time() - btic), msg))
                btic = time.time()

        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(
                ['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric,
                                             args)
                map_name_train, mean_ap_train = validate(
                    net, train_data, ctx, eval_metric, args)
                if isinstance(map_name, list):
                    val_msg = '\n'.join([
                        '{}={}'.format(k, v)
                        for k, v in zip(map_name, mean_ap)
                    ])
                    train_msg = '\n'.join([
                        '{}={}'.format(k, v)
                        for k, v in zip(map_name_train, mean_ap_train)
                    ])
                    # The last entry is used as the score handed to save_params.
                    current_map = float(mean_ap[-1])
                else:
                    val_msg = '{}={}'.format(map_name, mean_ap)
                    train_msg = '{}={}'.format(map_name_train, mean_ap_train)
                    current_map = mean_ap
                logger.info('[Epoch {}] Validation: {}'.format(epoch, val_msg))
                logger.info('[Epoch {}] Train: {}'.format(epoch, train_msg))
            else:
                current_map = 0.
            save_params(net, logger, best_map, current_map, epoch,
                        args.save_interval,
                        os.path.join(args.model_dir, 'fastrcnn'))
        # executor is None under Horovod; the unconditional call used to raise
        # AttributeError at the end of the first epoch.
        if executor is not None:
            executor.__del__()
Ejemplo n.º 3
0
    # Learning-rate schedule: decay factor, ascending decay epochs, warm-up length.
    lr_decay = float(args.lr_decay)
    # Blank entries (e.g. from a trailing comma) are dropped before conversion.
    lr_steps = sorted(map(float, filter(str.strip, args.lr_decay_epoch.split(','))))
    lr_warmup = float(args.lr_warmup)  # float, so later divisions are true divisions

    # TODO(zhreshold) losses?
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
    # One running-average tracker per loss term.
    metrics = [mx.metric.Loss(title) for title in
               ('RPN_Conf', 'RPN_SmoothL1', 'RCNN_CrossEntropy', 'RCNN_SmoothL1')]

    # Auxiliary accuracy / L1 metrics for the RPN and RCNN heads.
    rpn_acc_metric, rpn_bbox_metric = RPNAccMetric(), RPNL1LossMetric()
    rcnn_acc_metric, rcnn_bbox_metric = RCNNAccMetric(), RCNNL1LossMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric]

    # Root logger at INFO level, mirrored into "<save_prefix>_train.log".
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir:
        # Only create the directory when the prefix actually contains one.
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
Ejemplo n.º 4
0
def get_faster_rcnn_metrics():
    """Create the metric objects used while training Faster R-CNN.

    Returns
    -------
    tuple of (list, list)
        The first list holds four ``mx.metric.Loss`` trackers named
        ``RPN_Conf``, ``RPN_SmoothL1``, ``RCNN_CrossEntropy`` and
        ``RCNN_SmoothL1``; the second holds one instance each of
        ``RPNAccMetric``, ``RPNL1LossMetric``, ``RCNNAccMetric`` and
        ``RCNNL1LossMetric``.
    """
    loss_names = ('RPN_Conf', 'RPN_SmoothL1', 'RCNN_CrossEntropy', 'RCNN_SmoothL1')
    loss_metrics = [mx.metric.Loss(title) for title in loss_names]
    aux_metrics = [RPNAccMetric(), RPNL1LossMetric(), RCNNAccMetric(), RCNNL1LossMetric()]
    return loss_metrics, aux_metrics
Ejemplo n.º 5
0
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline.

    Builds an SGD trainer over ``net.collect_train_params()`` (a
    ``hvd.DistributedTrainer`` when ``args.horovod`` is set, otherwise a
    ``gluon.Trainer``), configures a root logger that also writes to
    ``args.save_prefix + '_train.log'``, and trains from ``args.start_epoch``
    up to ``args.epochs``.  The RPN, RCNN and mask losses are computed inline
    for every per-context shard and back-propagated together.  On rank 0 (or
    without Horovod) each epoch ends with an optional ``validate`` call and a
    call to ``save_params``.

    Parameters
    ----------
    net
        Network to train.  It is called as ``net(data, gt_box)`` and must also
        provide ``target_generator`` and ``mask_target``.
    train_data
        Iterable of batches; each batch is passed to ``split_and_load``.
    val_data
        Forwarded to ``validate``.
    eval_metric
        Forwarded to ``validate``.
    ctx
        List of contexts given to ``split_and_load``.
    args
        Option namespace (learning rate, decay schedule, AMP / Horovod flags,
        logging and checkpoint settings, ...).
    """
    # Freeze everything first, then re-enable gradients for the trainable subset only.
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    # Parameters matching '.*beta' or '.*bias' are excluded from weight decay.
    for k, v in net.collect_params('.*beta|.*bias').items():
        v.wd_mult = 0.0

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
                        net.collect_train_params(), # fix batchnorm, fix first stage, etc...
                        'sgd',
                        {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum})
    else:
        trainer = gluon.Trainer(
                    net.collect_train_params(), # fix batchnorm, fix first stage, etc...
                    'sgd',
                    {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum},
                    update_on_kvstore=(False if args.amp else None))

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
    rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    # Running averages of the five loss terms; indices match metric_losses below.
    metrics = [mx.metric.Loss('RPN_Conf'),
               mx.metric.Loss('RPN_SmoothL1'),
               mx.metric.Loss('RCNN_CrossEntropy'),
               mx.metric.Loss('RCNN_SmoothL1'),
               mx.metric.Loss('RCNN_Mask')]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    rcnn_mask_metric = MaskAccMetric()
    rcnn_fgmask_metric = MaskFGAccMetric()
    # Auxiliary metrics; indices match add_losses below.
    metrics2 = [rpn_acc_metric, rpn_bbox_metric,
                rcnn_acc_metric, rcnn_bbox_metric,
                rcnn_mask_metric, rcnn_fgmask_metric]

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        # Apply every decay step whose epoch has been reached (also covers resumed runs).
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        base_lr = trainer.learning_rate
        for i, batch in enumerate(train_data):
            # Warm-up only runs in epoch 0; the explicit > 0 test avoids a
            # division by zero when lr_warmup is exactly 0.
            if epoch == 0 and lr_warmup > 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup)
                if new_lr != trainer.learning_rate:
                    # log_interval == 0 disables logging (same convention as below).
                    if args.log_interval and i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            # Length of the first field of the split batch; presumably one entry
            # per context -- TODO confirm against split_and_load.
            batch_size = len(batch[0])
            losses = []
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            with autograd.record():
                for data, label, gt_mask, rpn_cls_targets, rpn_box_targets, rpn_box_masks in zip(
                        *batch):
                    # The last axis of label holds 4 box coordinates followed by the class id.
                    gt_label = label[:, :, 4:5]
                    gt_box = label[:, :, :4]
                    cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors = net(
                        data, gt_box)
                    # losses of rpn
                    rpn_score = rpn_score.squeeze(axis=-1)
                    # Counts every target >= 0, i.e. all entries not masked out below.
                    num_rpn_pos = (rpn_cls_targets >= 0).sum()
                    rpn_loss1 = rpn_cls_loss(rpn_score, rpn_cls_targets,
                                             rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
                    rpn_loss2 = rpn_box_loss(rpn_box, rpn_box_targets,
                                             rpn_box_masks) * rpn_box.size / num_rpn_pos
                    # rpn overall loss, use sum rather than average
                    rpn_loss = rpn_loss1 + rpn_loss2
                    # generate targets for rcnn
                    cls_targets, box_targets, box_masks = net.target_generator(roi, samples,
                                                                               matches, gt_label,
                                                                               gt_box)
                    # losses of rcnn
                    num_rcnn_pos = (cls_targets >= 0).sum()
                    rcnn_loss1 = rcnn_cls_loss(cls_pred, cls_targets,
                                               cls_targets >= 0) * cls_targets.size / \
                                 cls_targets.shape[0] / num_rcnn_pos
                    rcnn_loss2 = rcnn_box_loss(box_pred, box_targets, box_masks) * box_pred.size / \
                                 box_pred.shape[0] / num_rcnn_pos
                    rcnn_loss = rcnn_loss1 + rcnn_loss2
                    # generate targets for mask
                    mask_targets, mask_masks = net.mask_target(roi, gt_mask, matches, cls_targets)
                    # loss of mask
                    mask_loss = rcnn_mask_loss(mask_pred, mask_targets, mask_masks) * \
                                mask_targets.size / mask_targets.shape[0] / mask_masks.sum()
                    # overall losses
                    losses.append(rpn_loss.sum() + rcnn_loss.sum() + mask_loss.sum())
                    if (not args.horovod or hvd.rank() == 0):
                        metric_losses[0].append(rpn_loss1.sum())
                        metric_losses[1].append(rpn_loss2.sum())
                        metric_losses[2].append(rcnn_loss1.sum())
                        metric_losses[3].append(rcnn_loss2.sum())
                        metric_losses[4].append(mask_loss.sum())
                        # Each entry is [labels, preds] for the matching metric in metrics2.
                        add_losses[0].append([[rpn_cls_targets, rpn_cls_targets >= 0], [rpn_score]])
                        add_losses[1].append([[rpn_box_targets, rpn_box_masks], [rpn_box]])
                        add_losses[2].append([[cls_targets], [cls_pred]])
                        add_losses[3].append([[box_targets, box_masks], [box_pred]])
                        add_losses[4].append([[mask_targets, mask_masks], [mask_pred]])
                        add_losses[5].append([[mask_targets, mask_masks], [mask_pred]])
                if args.amp:
                    # Back-propagate through the AMP-scaled losses.
                    with amp.scale_loss(losses, trainer) as scaled_losses:
                        autograd.backward(scaled_losses)
                else:
                    autograd.backward(losses)
                if (not args.horovod or hvd.rank() == 0):
                    for metric, record in zip(metrics, metric_losses):
                        metric.update(0, record)
                    for metric, records in zip(metrics2, add_losses):
                        for pred in records:
                            metric.update(pred[0], pred[1])
            trainer.step(batch_size)
            # update metrics
            if (not args.horovod or hvd.rank() == 0) and args.log_interval and not (i + 1) % args.log_interval:
                msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics + metrics2])
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
                    epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg))
                btic = time.time()
        # validate and save params
        if (not args.horovod or hvd.rank() == 0):
            msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                # The last entry is used as the score handed to save_params.
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, logger, best_map, current_map, epoch, args.save_interval, args.save_prefix)
def main():

    # Function to get the Pascal VOC train/validation data loaders for a given rank
    def get_voc_iterator(rank, num_workers, net, num_shards):
        """Build the Pascal VOC training and validation data loaders.

        The VOC archives are fetched from ``args.s3bucket`` and unpacked into
        the per-rank data directory; if any step of that fails,
        ``download_voc`` is called for the same directory instead.

        Args:
            rank: Worker rank, used to name the ``data-<rank>`` directory.
            num_workers: Forwarded as ``num_workers`` to both data loaders.
            net: Detection network; its ``short``, ``max_size`` and ``ashape``
                attributes configure the transforms, and the network itself
                is handed to the train transform and batchify function.
            num_shards: Batch size of the validation loader.

        Returns:
            Tuple ``(train_iter, val_iter)`` of ``mx.gluon.data.DataLoader``.
        """
        data_dir = "data-%d" % rank
        try:
            s3_client = boto3.client('s3')
            for tar_name in (
                    'VOCtrainval_06-Nov-2007.tar', 'VOCtest_06-Nov-2007.tar',
                    'VOCtrainval_11-May-2012.tar'
            ):
                local_tar = f'/opt/ml/code/{tar_name}'
                s3_client.download_file(args.s3bucket, f'voc_tars/{tar_name}',
                                        local_tar)
                # Unpack into the same directory the download_voc() fallback
                # uses.
                # NOTE(review): assumes each archive holds a top-level
                # VOCdevkit/ folder and that data_dir resolves to
                # /opt/ml/code/data-<rank> - confirm.
                with tarfile.open(local_tar) as tar:
                    tar.extractall(path=data_dir)
        except Exception as exc:
            # Best effort: any S3 or archive failure falls back to the public
            # download instead of aborting the job.
            print('S3 fetch failed: {!r}'.format(exc))
            print('downloading from source')
            download_voc(data_dir)

        batch_size = args.batch_size

        # might want to replace with mx.io.ImageDetRecordIter, this means you
        # need data in RecordIO format

        voc_root = f'/opt/ml/code/data-{rank}/VOCdevkit/'
        train_dataset = gdata.VOCDetection(
            root=voc_root,
            splits=[(2007, 'trainval'), (2012, 'trainval')])
        val_dataset = gdata.VOCDetection(
            root=voc_root,
            splits=[(2007, 'test')])
        # NOTE(review): val_metric is built but neither used nor returned here;
        # it looks like it was meant to reach the caller - confirm.
        val_metric = VOC07MApMetric(iou_thresh=0.5,
                                    class_names=val_dataset.classes)
        # Every image is given aspect ratio 1.0 for the bucket sampler.
        im_aspect_ratio = [1.] * len(train_dataset)
        train_bfn = FasterRCNNTrainBatchify(net)
        train_sampler = gluoncv.nn.sampler.SplitSortedBucketSampler(
            im_aspect_ratio,
            batch_size,
            num_parts=hvd.size() if args.horovod else 1,
            part_index=hvd.rank() if args.horovod else 0,
            shuffle=True)
        # had issue with multi_stage=True
        train_transform = FasterRCNNDefaultTrainTransform(net.short,
                                                          net.max_size,
                                                          net,
                                                          ashape=net.ashape,
                                                          multi_stage=False)
        train_iter = mx.gluon.data.DataLoader(
            train_dataset.transform(train_transform),
            batch_sampler=train_sampler,
            batchify_fn=train_bfn,
            num_workers=num_workers)

        val_bfn = Tuple(*[Append() for _ in range(3)])
        short = net.short[-1] if isinstance(net.short,
                                            (tuple, list)) else net.short
        # validation use 1 sample per device
        val_iter = mx.gluon.data.DataLoader(
            val_dataset.transform(
                FasterRCNNDefaultValTransform(short, net.max_size)),
            batch_size=num_shards,
            shuffle=False,
            batchify_fn=val_bfn,
            last_batch='keep',
            num_workers=num_workers)

        return train_iter, val_iter

    # Function to define neural network
    def conv_nets(model_name):
        """Fetch ``model_name`` from the model zoo with ``pretrained_base=False``."""
        return model_zoo.get_model(model_name, pretrained_base=False)

    def evaluate(net, val_data, ctx, eval_metric, args):
        """Run ``net`` over ``val_data`` and return ``eval_metric.get()``.

        The metric is reset first. For every batch the per-device predictions
        are gathered before the metric is updated with them.
        """
        box_clipper = gcv.nn.bbox.BBoxClipToImage()
        eval_metric.reset()
        if not args.disable_hybridization:
            # input format is different than training, thus rehybridization
            # is needed.
            net.hybridize(static_alloc=args.static_alloc)
        for batch in val_data:
            shards = split_and_load(batch, ctx_list=ctx)
            # One (boxes, ids, scores, gt_boxes, gt_ids, gt_difficult) record
            # per device shard, in the order eval_metric.update expects.
            records = []
            for image, label, scale in zip(*shards):
                # get prediction results
                class_ids, confidences, raw_boxes = net(image)
                # clip to image size
                pred_boxes = box_clipper(raw_boxes, image)
                # rescale to original resolution
                scale = scale.reshape((-1)).asscalar()
                pred_boxes *= scale
                # split ground truths
                true_ids = label.slice_axis(axis=-1, begin=4, end=5)
                true_boxes = label.slice_axis(axis=-1, begin=0, end=4)
                true_boxes *= scale
                if label.shape[-1] > 5:
                    hard_flags = label.slice_axis(axis=-1, begin=5, end=6)
                else:
                    hard_flags = None
                records.append((pred_boxes, class_ids, confidences,
                                true_boxes, true_ids, hard_flags))

            # update metric
            for record in records:
                eval_metric.update(*record)
        return eval_metric.get()

    # Driver section of main(): set up Horovod, the model, the trainer and the
    # data loaders, then run the Faster R-CNN epoch loop.
    # NOTE(review): batch_size, logger, validate, eval_metric and best_map are
    # used below but are not bound anywhere in main(); they have to come from
    # module scope - confirm.

    # Initialize Horovod
    hvd.init()

    # Horovod: pin context to local rank
    if args.horovod:
        ctx = [mx.gpu(hvd.local_rank())]
    else:
        ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
        ctx = ctx if ctx else [mx.cpu()]
    context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(
        hvd.local_rank())
    num_workers = hvd.size()

    # Build model
    model = conv_nets(args.model_name)
    model.cast(args.dtype)
    model.hybridize()

    # Initialize parameters
    initializer = mx.init.Xavier(rnd_type='gaussian',
                                 factor_type="in",
                                 magnitude=2)
    model.initialize(initializer, ctx=context)

    # Create optimizer (learning rate is scaled by the number of workers)
    optimizer_params = {
        'momentum': args.momentum,
        'learning_rate': args.lr * hvd.size()
    }
    opt = mx.optimizer.create('sgd', **optimizer_params)

    # Load training and validation data
    train_data, val_data = get_voc_iterator(hvd.rank(), num_workers, model,
                                            len(ctx))

    # Horovod: fetch and broadcast parameters
    params = model.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)

    # Horovod: create DistributedTrainer, a subclass of gluon.Trainer
    trainer = hvd.DistributedTrainer(params, opt)

    # Loss functions consumed by ForwardBackwardTask
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(
        from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(
        rho=args.rpn_smoothl1_rho)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(
        rho=args.rcnn_smoothl1_rho)  # == smoothl1
    metrics = [
        mx.metric.Loss('RPN_Conf'),
        mx.metric.Loss('RPN_SmoothL1'),
        mx.metric.Loss('RCNN_CrossEntropy'),
        mx.metric.Loss('RCNN_SmoothL1'),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [
        rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric
    ]

    # lr decay policy. Parsed once, outside the epoch loop: milestones are
    # popped from lr_steps as they are reached, and rebuilding the list every
    # epoch would re-apply the decay on each later epoch.
    lr_decay = float(args.lr_decay)
    lr_steps = sorted(
        [float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    # train from train_faster_rcnn.py
    for epoch in range(args.epochs):
        # this simplifies dealing with all of the loss functions
        rcnn_task = ForwardBackwardTask(model,
                                        trainer,
                                        rpn_cls_loss,
                                        rpn_box_loss,
                                        rcnn_cls_loss,
                                        rcnn_box_loss,
                                        mix_ratio=1.0,
                                        amp_enabled=args.amp)
        executor = Parallel(args.executor_threads,
                            rcnn_task) if not args.horovod else None
        mix_ratio = 1.0
        if not args.disable_hybridization:
            model.hybridize(static_alloc=args.static_alloc)
        if args.mixup:
            # TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise
            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
            mix_ratio = 0.5
            if epoch >= args.epochs - args.no_mixup_epochs:
                train_data._dataset._data.set_mixup(None)
                mix_ratio = 1.0
        # Apply every decay milestone that this epoch has reached
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(
                epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio
        for i, batch in enumerate(train_data):
            # Learning-rate warm-up during the first epoch. Skipped entirely
            # when lr_warmup is 0 so new_lr is never read before assignment.
            if epoch == 0 and lr_warmup != 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup,
                                                  args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    if args.log_interval and i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.
                            format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(
                batch, ctx_list=ctx
            )  # does split and load function, creates a batch per device
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for _ in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    # result holds the loss records first, then the records
                    # for the extra metrics
                    for k, bucket in enumerate(metric_losses):
                        bucket.append(result[k])
                    for k, bucket in enumerate(add_losses):
                        bucket.append(result[len(metric_losses) + k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)

            # update metrics
            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join([
                    '{}={:.3f}'.format(*metric.get())
                    for metric in metrics + metrics2
                ])
                logger.info(
                    '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.
                    format(
                        epoch, i, args.log_interval * args.batch_size /
                        (time.time() - btic), msg))
                btic = time.time()

        # End-of-epoch summary, validation and checkpointing (rank 0 only
        # under Horovod)
        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(
                ['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(model, val_data, ctx, eval_metric,
                                             args)
                val_msg = '\n'.join(
                    ['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(
                    epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(model, logger, best_map, current_map, epoch,
                        args.save_interval, args.save_prefix)
def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args):
    """Training pipeline.

    Args:
        net: Network to train; only the parameters returned by
            ``net.collect_train_params()`` receive gradients.
        train_data: Iterable of training batches; each batch is split across
            ``ctx`` with ``split_and_load``.
        val_data: Validation data, forwarded to ``validate``.
        eval_metric: Evaluation metric, forwarded to ``validate``.
        batch_size: Value passed to ``trainer.step``.
        ctx: List of device contexts.
        logger: Logger used for progress messages.
        args: Run configuration. ``args.kv_store`` is overwritten with
            ``"device"`` when AMP is on and the configured store contains
            ``"nccl"``.
    """
    args.kv_store = "device" if (args.amp and "nccl" in args.kv_store) else args.kv_store
    kv = mx.kvstore.create(args.kv_store)
    # Freeze everything, then re-enable gradients for the trainable subset
    net.collect_params().setattr("grad_req", "null")
    net.collect_train_params().setattr("grad_req", "write")
    optimizer_params = {"learning_rate": args.lr, "wd": args.wd, "momentum": args.momentum}
    if args.amp:
        optimizer_params["multi_precision"] = True
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
        )
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv,
        )

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(",") if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    # TODO(zhreshold) losses?
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.0)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(rho=1.0)  # == smoothl1
    metrics = [
        mx.metric.Loss("RPN_Conf"),
        mx.metric.Loss("RPN_SmoothL1"),
        mx.metric.Loss("RCNN_CrossEntropy"),
        mx.metric.Loss("RCNN_SmoothL1"),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric]

    logger.info(args)

    if args.verbose:
        logger.info("Trainable parameters:")
        logger.info(net.collect_train_params().keys())
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    # Single-element list so save_params can update the best score in place
    # NOTE(review): in-place update is assumed from the list wrapper - confirm.
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        rcnn_task = ForwardBackwardTask(
            net,
            trainer,
            rpn_cls_loss,
            rpn_box_loss,
            rcnn_cls_loss,
            rcnn_box_loss,
            mix_ratio=1.0,
            amp_enabled=args.amp,
        )
        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
        mix_ratio = 1.0
        net.hybridize()

        # Apply every decay milestone that this epoch has reached
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio
        for i, batch in enumerate(train_data):
            # Warm-up during epoch 0. A warm-up length of 0 means "no
            # warm-up" and is skipped, which also avoids dividing by zero.
            if epoch == 0 and lr_warmup != 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(
                    i / lr_warmup, args.lr_warmup_factor / args.num_gpus
                )
                if new_lr != trainer.learning_rate:
                    if args.log_interval and i % args.log_interval == 0:
                        logger.info(
                            "[Epoch 0 Iteration {}] Set learning rate to {}".format(i, new_lr)
                        )
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for _ in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    # result holds the loss records first, then the records
                    # for the extra metrics
                    for k, bucket in enumerate(metric_losses):
                        bucket.append(result[k])
                    for k, bucket in enumerate(add_losses):
                        bucket.append(result[len(metric_losses) + k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)

            # update metrics
            if (
                (not args.horovod or hvd.rank() == 0)
                and args.log_interval
                and not (i + 1) % args.log_interval
            ):
                msg = ",".join(["{}={:.3f}".format(*metric.get()) for metric in metrics + metrics2])
                logger.info(
                    "[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}".format(
                        epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg
                    )
                )
                btic = time.time()

        # End-of-epoch summary, validation and checkpointing (rank 0 only
        # under Horovod)
        if (not args.horovod) or hvd.rank() == 0:
            msg = ",".join(["{}={:.3f}".format(*metric.get()) for metric in metrics])
            logger.info(
                "[Epoch {}] Training cost: {:.3f}, {}".format(epoch, (time.time() - tic), msg)
            )
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
                val_msg = "\n".join(["{}={}".format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info("[Epoch {}] Validation: \n{}".format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.0
            save_params(
                net,
                logger,
                best_map,
                current_map,
                epoch,
                args.save_interval,
                os.path.join(args.sm_save, args.save_prefix),
                args,
            )
Ejemplo n.º 8
0
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Train ``net`` with SGD, logging to ``<args.save_prefix>_train.log``.

    Args:
        net: Network to train; it is moved to ``ctx`` and only the parameters
            returned by ``net.collect_train_params()`` receive gradients.
        train_data: Iterable of training batches; each batch is split across
            ``ctx`` with ``split_and_load``.
        val_data: Validation data, forwarded to ``validate``.
        eval_metric: Evaluation metric, forwarded to ``validate``.
        ctx: List of device contexts.
        args: Run configuration (learning-rate schedule, intervals, paths,
            horovod/amp/mixup switches).
    """
    # collect_params / collect_train_params are methods: they must be called
    # before reset_ctx / setattr can be used on the result.
    net.collect_params().reset_ctx(ctx)
    kv = mx.kvstore.create(args.kv_store)
    # Freeze everything, then re-enable gradients for the trainable subset
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(net.collect_train_params(), 'sgd',
                                         {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum})
    else:
        trainer = gluon.Trainer(net.collect_train_params(), 'sgd',
                                {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum},
                                update_on_kvstore=(False if args.amp else None), kvstore=kv)
    if args.amp:
        amp.init_trainer(trainer)

    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)

    # losses: these four are handed to rcnn_task below
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()

    # One Loss metric per loss above, in the same order; each is fed through
    # metric.update(0, record).
    metrics = [mx.metric.Loss('RPN_Conf'),
               mx.metric.Loss('RPN_SmoothL1'),
               mx.metric.Loss('RCNN_CrossEntropy'),
               mx.metric.Loss('RCNN_SmoothL1')]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric]

    # logger set_up: root logger at INFO, mirrored to a file next to the
    # checkpoints
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    # Single-element list so save_params can update the best score in place
    # NOTE(review): in-place update is assumed from the list wrapper - confirm.
    best_map = [0]

    for epoch in range(args.start_epoch, args.epochs):
        mix_ratio = 1.0
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)

        rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss,
                                        rcnn_cls_loss, rcnn_box_loss, mix_ratio=1.0)
        # Parallel runs rcnn_task on the data pushed with put(); a single
        # thread is used under horovod, otherwise args.executor_threads.
        executor = Parallel(1 if args.horovod else args.executor_threads, rcnn_task)

        if args.mixup:
            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
            mix_ratio = 0.5
            if epoch >= args.epochs - args.no_mixup_epochs:
                train_data._dataset._data.set_mixup(None)
                mix_ratio = 1.0

        # adjust the learning rate for every milestone this epoch has reached
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info('[Epoch {}] Set learning rate to {}'.format(epoch, new_lr))

        for metric in metrics:
            metric.reset()

        tic = time.time()  # start time of the whole epoch
        btic = time.time()  # start time of the current logging window
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio

        for i, batch in enumerate(train_data):
            # Warm-up during epoch 0. A warm-up length of 0 means "no
            # warm-up" and is skipped, which also avoids dividing by zero.
            if epoch == 0 and lr_warmup != 0 and i <= lr_warmup:
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup)
                if new_lr != trainer.learning_rate:
                    if args.log_interval and i % args.log_interval == 0:
                        logger.info('[Epoch 0 Iteration {}] Set learning rate to {}'.format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            batch_size = len(batch[0])
            metric_losses = [[] for _ in metrics]  # one bucket per entry of metrics
            add_losses = [[] for _ in metrics2]  # one bucket per entry of metrics2

            for data in zip(*batch):
                executor.put(data)

            for _ in range(len(ctx)):
                result = executor.get()
                if (not args.horovod) or hvd.rank() == 0:
                    # result holds the loss records first, then the records
                    # for the extra metrics
                    for k, bucket in enumerate(metric_losses):
                        bucket.append(result[k])
                    for k, bucket in enumerate(add_losses):
                        bucket.append(result[len(metric_losses) + k])

            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    # update(label, preds): pred[0] is passed as the labels,
                    # pred[1] as the predictions
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)

            if (not args.horovod or hvd.rank() == 0) and args.log_interval and not (i + 1) % args.log_interval:
                msg = ','.join(['{} ={:.3f}'.format(*metric.get()) for metric in metrics + metrics2])
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
                    epoch, i, args.batch_size * args.log_interval / (time.time() - btic), msg))
                btic = time.time()

        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(['{} ={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if (epoch % args.val_interval == 0) or (args.save_interval and epoch % args.save_interval == 0):
                # Validate on every val_interval-th or save_interval-th epoch
                # so that a current_map is available for save_params.
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
                val_msg = "\n".join('{}={}'.format(k, v) for k, v in zip(map_name, mean_ap))
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])  # the last entry of mean_ap is the mAP
            else:
                current_map = 0
            save_params(net, logger, best_map, current_map, epoch, args.save_interval, args.save_prefix)
        # Explicitly tear down this epoch's executor before the next one is
        # created. NOTE(review): presumably stops its worker threads - confirm.
        executor.__del__()
Ejemplo n.º 9
0
def train(net, train_loader, val_loader, eval_metric, ctx, cfg):
    """Run the RPN/RCNN training loop with periodic logging and validation.

    For every epoch this builds a ``ForwardBackwardTask`` wrapped in a
    ``Parallel`` executor, applies the step-wise learning-rate decay, iterates
    over ``train_loader`` (with linear warm-up during epoch 0), updates the
    loss/accuracy metrics, steps the trainer, and finally validates and calls
    ``save_params``.

    Args:
        net: Network exposing ``collect_params``, ``collect_train_params``
            and ``hybridize``.
        train_loader: Iterable of training batches; each batch is handed to
            ``split_and_load``.
        val_loader: Validation data forwarded to ``validate``.
        eval_metric: Evaluation metric forwarded to ``validate``.
        ctx: List of devices. One executor result is collected per entry and
            ``len(ctx)`` scales the effective batch size.
        cfg: Nested configuration mapping. Reads ``cfg["verbose"]``,
            ``cfg["dataset"]["batch_size_per_device"]`` and, under
            ``cfg["train"]``: ``kv_store``, ``lr``, ``wd``, ``momentum``,
            ``lr_decay``, ``lr_decay_epoch``, ``lr_warmup_iteration``,
            ``lr_warmup_factor``, ``rpn_smoothl1_rho``, ``rcnn_smoothl1_rho``,
            ``save_prefix``, ``start_epoch``, ``epochs``,
            ``executor_threads``, ``disable_hybridization``, ``static_alloc``,
            ``log_interval``, ``val_interval`` and ``save_interval``.

    Notes:
        * A non-positive ``lr_warmup_iteration`` disables warm-up (previously
          a value of exactly 0 raised ``ZeroDivisionError`` on the first
          batch).
        * A falsy ``log_interval`` disables all periodic logging, including
          the warm-up learning-rate message.
        * A falsy ``val_interval`` disables validation; ``save_params`` is
          then always called with ``current_map == 0.``.
    """
    train_cfg = cfg["train"]

    kv = mx.kvstore.create(train_cfg["kv_store"])
    net.collect_params().setattr('grad_req', 'null')
    # The train_pattern that selects the trainable parameters is given to the
    # network when it is constructed.
    net.collect_train_params().setattr('grad_req', 'write')
    optimizer_params = {
        "learning_rate": train_cfg["lr"],
        "wd": train_cfg["wd"],
        "momentum": train_cfg["momentum"],
    }
    trainer = gluon.Trainer(net.collect_train_params(),
                            "sgd",
                            optimizer_params,
                            kvstore=kv)

    lr_decay = float(train_cfg["lr_decay"])
    lr_steps = sorted([float(ls) for ls in train_cfg["lr_decay_epoch"]])
    lr_warmup = float(train_cfg["lr_warmup_iteration"])
    # Warm-up divides by lr_warmup, so it only makes sense for positive values.
    warmup_enabled = lr_warmup > 0

    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(
        from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(
        rho=train_cfg["rpn_smoothl1_rho"])
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(
        rho=train_cfg["rcnn_smoothl1_rho"])

    # Running averages of the raw loss values returned by the task.
    metrics = [
        mx.metric.Loss("RPN_Conf"),
        mx.metric.Loss("RPN_SmoothL1"),
        mx.metric.Loss("RCNN_CrossEntropy"),
        mx.metric.Loss("RCNN_SmoothL1"),
        mx.metric.Loss('RPN_GT_Recall'),
    ]

    # Metrics updated with (label, pred) pairs returned by the task.
    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [
        rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric
    ]

    # Wall-clock timers for the three phases of every iteration.
    data_prepare_time = Record("data_prepare_time")
    data_distributed_time = Record("data_distributed_time")
    net_forward_backward_time = Record("net_forward_backward_time")
    timers = [
        data_prepare_time, data_distributed_time, net_forward_backward_time
    ]

    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = train_cfg["save_prefix"] + "_train.log"
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(cfg)

    if cfg["verbose"]:
        logger.info("Trainable parameters:")
        logger.info(net.collect_train_params().keys())
    logger.info("Start training from [Epoch {}]".format(
        train_cfg['start_epoch']))

    log_interval = train_cfg["log_interval"]
    val_interval = train_cfg["val_interval"]
    num_devices = len(ctx)
    batch_size = cfg["dataset"]["batch_size_per_device"] * num_devices

    best_map = [0]
    for epoch in range(train_cfg["start_epoch"], train_cfg["epochs"]):
        rcnn_task = ForwardBackwardTask(net,
                                        trainer,
                                        rpn_cls_loss,
                                        rpn_box_loss,
                                        rcnn_cls_loss,
                                        rcnn_box_loss,
                                        mix_ratio=1.0,
                                        amp_enabled=None)
        # Original author's notes on Parallel (translated; behaviour of
        # Parallel itself is defined elsewhere and not verified here):
        # - The computation is executed by multiple threads; each thread
        #   handles part of the data, which may live on any GPU, and the
        #   number of threads is independent of the number of GPUs.
        # - In multi-GPU training forward-backward is invoked once per
        #   device. Without Parallel every call waits for its result; with
        #   Parallel the calls do not wait, so they run concurrently (i.e.
        #   background threads do the work if Parallel is configured so that
        #   the main thread does not compute).
        # - With executor_threads set to 1: the first executor_threads calls
        #   to put() are computed on the main thread and later put() calls on
        #   background threads, so that the model is initialised on the main
        #   thread during the first iteration. Calling it executor_threads
        #   times is presumably to make sure the model initialises correctly
        #   on every device (unconfirmed).
        executor = Parallel(train_cfg["executor_threads"], rcnn_task)
        mix_ratio = 1.0
        if not train_cfg["disable_hybridization"]:
            net.hybridize(static_alloc=train_cfg["static_alloc"])

        # Apply every decay step whose epoch boundary has been reached.
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(
                epoch, new_lr))

        for metric in metrics + metrics2:
            metric.reset()
        for timer in timers:
            timer.reset()

        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio

        before_data_prepare_point = time.time()
        for i, batch in enumerate(train_loader):
            data_prepare_time.update(None,
                                     time.time() - before_data_prepare_point)

            # Learning-rate warm-up during the first iterations of epoch 0.
            if warmup_enabled and epoch == 0 and i <= lr_warmup:
                new_lr = base_lr * get_lr_at_iter(
                    i / lr_warmup, train_cfg["lr_warmup_factor"])
                if (new_lr != trainer.learning_rate and log_interval
                        and i % log_interval == 0):
                    logger.info(
                        "[Epoch 0 Iteration {}] Set learning rate to {}".
                        format(i, new_lr))
                trainer.set_learning_rate(new_lr)

            before_data_distributed_point = time.time()
            # img, label, cls_targets, box_targets, box_masks = batch
            batch = split_and_load(batch, ctx_list=ctx)  # distribute the data
            data_distributed_time.update(
                None,
                time.time() - before_data_distributed_point)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]

            before_net_forward_backward_point = time.time()
            for data in zip(*batch):
                executor.put(data)
            # Collect one result per device. Each result holds the values for
            # `metrics` first, followed by the (label, pred) pairs for
            # `metrics2`.
            for _ in range(num_devices):
                result = executor.get()
                for k, losses in enumerate(metric_losses):
                    losses.append(result[k])
                for k, records in enumerate(add_losses,
                                            start=len(metric_losses)):
                    records.append(result[k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)  # feed all collected losses at once
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)
            net_forward_backward_time.update(
                None,
                time.time() - before_net_forward_backward_point)

            if log_interval and not (i + 1) % log_interval:
                msg = ",".join([
                    "{}={:.3f}".format(*metric.get())
                    for metric in metrics + metrics2
                ])
                logger.info(
                    "[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}".
                    format(epoch, i,
                           log_interval * batch_size / (time.time() - btic),
                           msg))
                time_msg = ",".join(
                    ["{}={}".format(*timer.get()) for timer in timers])
                logger.info("[Epoch {}][Batch {}], {}".format(
                    epoch, i, time_msg))
                btic = time.time()

            before_data_prepare_point = time.time()

        msg = ",".join(
            ["{}={:.3f}".format(*metric.get()) for metric in metrics])
        logger.info("[Epoch {}] Training cost: {:.3f}, {}".format(
            epoch, (time.time() - tic), msg))
        if val_interval and not (epoch + 1) % val_interval:
            map_name, mean_ap = validate(net, val_loader, ctx, eval_metric,
                                         cfg)
            val_msg = "\n".join(
                ["{}={}".format(k, v) for k, v in zip(map_name, mean_ap)])
            logger.info("[Epoch {}] Validation: \n{}".format(epoch, val_msg))
            # The last entry of mean_ap is used as the score for save_params.
            current_map = float(mean_ap[-1])
        else:
            current_map = 0.
        save_params(net, logger, best_map, current_map, epoch,
                    train_cfg["save_interval"], train_cfg["save_prefix"])