Пример #1
0
            msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
                val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.
            save_params(net, logger, best_map, current_map, epoch, args.save_interval,
                        args.save_prefix)
        executor.__del__()


if __name__ == '__main__':
    import sys

    sys.setrecursionlimit(1100)
    args = parse_args()
    # fix seed for mxnet, numpy and python builtin random generator.
    gutils.random.seed(args.seed)

    if args.amp:
        amp.init()

    # training contexts
    if args.horovod:
        ctx = [mx.gpu(hvd.local_rank())]
    else:
Пример #2
0
def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
    """Training pipeline.

    Builds an SGD trainer (``hvd.DistributedTrainer`` when ``args.horovod``
    is set, otherwise ``gluon.Trainer`` backed by a kvstore) and, for every
    epoch in ``range(args.start_epoch, args.epochs)``, runs forward/backward
    over ``train_data``, logs the loss metrics, periodically evaluates on
    both ``val_data`` and ``train_data`` and finally calls ``save_params``.

    Parameters
    ----------
    net
        Network exposing ``collect_params``, ``collect_train_params`` and
        ``hybridize``.
    train_data
        Iterable of training batches; must support ``len()``.
    val_data
        Data handed to ``validate`` for evaluation.
    eval_metric
        Metric object handed to ``validate``.
    batch_size
        Value passed to ``trainer.step`` after each batch.
    ctx
        List of contexts the batch is split across.
    args
        Parsed command-line options (learning-rate schedule, logging,
        horovod/amp/mixup switches, save locations, ...).
    """
    kv = mx.kvstore.create(args.kv_store)
    # Disable gradients everywhere, then re-enable them only for the
    # parameters reported as trainable by the network.
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'momentum': args.momentum
    }
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(
            ),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params)
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(
            ),  # fix batchnorm, fix first stage, etc...
            'sgd',
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv)

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted(
        [float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    # TODO(zhreshold) losses?
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(
        from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
    metrics = [
        mx.metric.Loss('RPN_Conf'),
        mx.metric.Loss('RPN_SmoothL1'),
        mx.metric.Loss('RCNN_CrossEntropy'),
        mx.metric.Loss('RCNN_SmoothL1'),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [
        rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric
    ]

    # set up logger: root logger at INFO, mirrored to <save_prefix>_train.log
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    # One-element list, handed to save_params together with current_map.
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        mix_ratio = 1.0
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)
        rcnn_task = ForwardBackwardTask(net,
                                        trainer,
                                        rpn_cls_loss,
                                        rpn_box_loss,
                                        rcnn_cls_loss,
                                        rcnn_box_loss,
                                        mix_ratio=1.0)
        # With horovod there is no executor; forward_backward is called
        # directly on the task further down.
        executor = Parallel(args.executor_threads,
                            rcnn_task) if not args.horovod else None
        if args.mixup:
            # TODO(zhreshold) only support evenly mixup now, target generator needs to be modified otherwise
            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
            mix_ratio = 0.5
            if epoch >= args.epochs - args.no_mixup_epochs:
                train_data._dataset._data.set_mixup(None)
                mix_ratio = 1.0
        # Apply every decay step whose epoch threshold has been reached.
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(
                epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()   # start of the epoch
        btic = time.time()  # start of the current logging window
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio
        logger.info('[Epoch {}] Number of batches: {}'.format(
            epoch, len(train_data)))
        for i, batch in enumerate(train_data):
            if epoch == 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup,
                                                  args.lr_warmup_factor)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            '[Epoch 0 Iteration {}] Set learning rate to {}'.
                            format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for _ in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    # The leading entries of result feed `metrics`, the
                    # following ones feed `metrics2`.
                    for k, bucket in enumerate(metric_losses):
                        bucket.append(result[k])
                    offset = len(metric_losses)
                    for k, bucket in enumerate(add_losses):
                        bucket.append(result[offset + k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)

            # update metrics
            if (not args.horovod or hvd.rank() == 0) and args.log_interval \
                    and not (i + 1) % args.log_interval:
                msg = ','.join([
                    '{}={:.3f}'.format(*metric.get())
                    for metric in metrics + metrics2
                ])
                logger.info(
                    '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.
                    format(
                        epoch, i, args.log_interval * args.batch_size /
                        (time.time() - btic), msg))
                btic = time.time()

        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(
                ['{}={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric,
                                             args)
                map_name_train, mean_ap_train = validate(
                    net, train_data, ctx, eval_metric, args)
                if isinstance(map_name, list):
                    val_msg = '\n'.join([
                        '{}={}'.format(k, v)
                        for k, v in zip(map_name, mean_ap)
                    ])
                    train_msg = '\n'.join([
                        '{}={}'.format(k, v)
                        for k, v in zip(map_name_train, mean_ap_train)
                    ])
                    current_map = float(mean_ap[-1])
                else:
                    val_msg = '{}={}'.format(map_name, mean_ap)
                    train_msg = '{}={}'.format(map_name_train, mean_ap_train)
                    current_map = mean_ap
                logger.info('[Epoch {}] Validation: {}'.format(epoch, val_msg))
                logger.info('[Epoch {}] Train: {}'.format(epoch, train_msg))
            else:
                current_map = 0.
            save_params(net, logger, best_map, current_map, epoch,
                        args.save_interval,
                        os.path.join(args.model_dir, 'fastrcnn'))
        # executor is None in horovod mode; calling __del__ on it would raise
        # AttributeError at the end of the first epoch.
        if executor is not None:
            executor.__del__()
Пример #3
0
def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline.

    Moves the network parameters to ``ctx``, builds an SGD trainer
    (``hvd.DistributedTrainer`` when ``args.horovod`` is set, otherwise
    ``gluon.Trainer`` backed by a kvstore) and, for every epoch in
    ``range(args.start_epoch, args.epochs)``, runs forward/backward over
    ``train_data`` through a ``Parallel`` executor, logs the metrics,
    periodically evaluates on ``val_data`` and calls ``save_params``.

    Parameters
    ----------
    net
        Network exposing ``collect_params``, ``collect_train_params`` and
        ``hybridize``.
    train_data
        Iterable of training batches.
    val_data
        Data handed to ``validate`` for evaluation.
    eval_metric
        Metric object handed to ``validate``.
    ctx
        List of contexts the parameters and batches are placed on.
    args
        Parsed command-line options (learning-rate schedule, logging,
        horovod/amp/mixup switches, save locations, ...).
    """
    # collect_params / collect_train_params are methods and must be called
    # before chaining reset_ctx / setattr on their result.
    net.collect_params().reset_ctx(ctx)
    kv = mx.kvstore.create(args.kv_store)
    # Disable gradients everywhere, then re-enable them only for the
    # parameters reported as trainable by the network.
    net.collect_params().setattr('grad_req', 'null')
    net.collect_train_params().setattr('grad_req', 'write')
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(net.collect_train_params(), 'sgd',
                                         {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum})
    else:
        trainer = gluon.Trainer(net.collect_train_params(), 'sgd',
                                {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum},
                                update_on_kvstore=(False if args.amp else None), kvstore=kv)
    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_warmup = float(args.lr_warmup)

    # Losses: the four below are handed to rcnn_task.
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss()

    metrics = [mx.metric.Loss('RPN_Conf'),
               mx.metric.Loss('RPN_SmoothL1'),
               mx.metric.Loss('RCNN_CrossEntropy'),
               mx.metric.Loss('RCNN_SmoothL1')]
    # metrics: [rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss]
    # metric_losses: [[rpn_cls_loss], [rpn_box_loss], [rcnn_cls_loss], [rcnn_box_loss]]
    # each one is updated with metric.update(0, record)

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric]

    # set up logger: root logger at INFO, mirrored to <save_prefix>_train.log
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    if args.verbose:
        logger.info('Trainable parameters:')
        logger.info(net.collect_train_params().keys())
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    # One-element list, handed to save_params together with current_map.
    best_map = [0]

    for epoch in range(args.start_epoch, args.epochs):
        mix_ratio = 1.0
        if not args.disable_hybridization:
            net.hybridize(static_alloc=args.static_alloc)

        rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss,
                                        rcnn_cls_loss, rcnn_box_loss, mix_ratio=1.0)
        # Parallel wraps rcnn_task; per-context data is fed with put() and
        # the results are read back with get() in the batch loop below.
        executor = Parallel(1 if args.horovod else args.executor_threads, rcnn_task)

        if args.mixup:
            train_data._dataset._data.set_mixup(np.random.uniform, 0.5, 0.5)
            mix_ratio = 0.5
            if epoch >= args.epochs - args.no_mixup_epochs:
                train_data._dataset._data.set_mixup(None)
                mix_ratio = 1.0

        # Adjust the learning rate: apply every decay step whose epoch
        # threshold has been reached.
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info('[Epoch {}] Set learning rate to {}'.format(epoch, new_lr))

        for metric in metrics:
            metric.reset()

        tic = time.time()  # start of the epoch
        btic = time.time()  # start of the current logging window
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio

        for i, batch in enumerate(train_data):
            if epoch == 0 and i <= lr_warmup:
                new_lr = base_lr * get_lr_at_iter(i / lr_warmup)
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info('[Epoch 0 Iteration {}] Set learning rate to {}'.format(i, new_lr))
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            batch_size = len(batch[0])
            metric_losses = [[] for _ in metrics]  # one bucket per entry of metrics
            add_losses = [[] for _ in metrics2]  # one bucket per entry of metrics2

            for data in zip(*batch):
                executor.put(data)

            for _ in range(len(ctx)):
                result = executor.get()
                if (not args.horovod) or hvd.rank() == 0:
                    # The leading entries of result feed `metrics`, the
                    # following ones feed `metrics2`.
                    for k, bucket in enumerate(metric_losses):
                        bucket.append(result[k])
                    offset = len(metric_losses)
                    for k, bucket in enumerate(add_losses):
                        bucket.append(result[offset + k])

            for metric, record in zip(metrics, metric_losses):
                # metrics: [rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss]
                # metric_losses: [[rpn_cls_loss], [rpn_box_loss], [rcnn_cls_loss], [rcnn_box_loss]]
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                # metrics2 = [rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric]
                # add_losses: [[rpn_acc_metric], [rpn_bbox_metric], [rcnn_acc_metric], [rcnn_bbox_metric]]
                # e.g. an rpn_acc_metric record: [[rpn_label, rpn_weight], [rpn_cls_logits]]
                for pred in records:
                    # update(label, preds)
                    # label: [rpn_label, rpn_weight]
                    # preds: [rpn_cls_logits]
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)

            if (not args.horovod or hvd.rank() == 0) and args.log_interval and not (i + 1) % args.log_interval:
                msg = ','.join(['{} ={:.3f}'.format(*metric.get()) for metric in metrics + metrics2])
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
                    epoch, i, args.batch_size * args.log_interval / (time.time() - btic), msg))
                btic = time.time()

        if (not args.horovod) or hvd.rank() == 0:
            msg = ','.join(['{} ={:.3f}'.format(*metric.get()) for metric in metrics])
            logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
                epoch, (time.time() - tic), msg))
            if (epoch % args.val_interval == 0) or (args.save_interval and epoch % args.save_interval == 0):
                # Every args.val_interval or args.save_interval epochs,
                # evaluate on the validation set to obtain current_map.
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
                val_msg = "\n".join('{}={}'.format(k, v) for k, v in zip(map_name, mean_ap))
                logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
                current_map = float(mean_ap[-1])  # last entry of mean_ap is taken as the mAP
            else:
                current_map = 0
            save_params(net, logger, best_map, current_map, epoch, args.save_interval, args.save_prefix)
        executor.__del__()