Ejemplo n.º 1
0
def test(net,
         val_data,
         batch_fn,
         use_rec,
         dtype,
         ctx,
         calc_weight_count=False,
         extended_log=False):
    """
    Run a single validation pass and log top-1/top-5 errors plus elapsed time.

    Parameters:
    ----------
    net
        Model; handed to `validate` and, optionally, to `calc_net_weight_count`.
    val_data
        Validation data source, forwarded to `validate`.
    batch_fn
        Batch-splitting function, forwarded to `validate`.
    use_rec
        Forwarded unchanged to `validate` (presumably marks a record-iterator
        data source -- confirm against `validate`).
    dtype
        Forwarded to `validate`.
    ctx
        Forwarded to `validate`.
    calc_weight_count : bool, default False
        Whether to log the value returned by `calc_net_weight_count(net)`.
    extended_log : bool, default False
        Whether to log each error a second time with default formatting
        next to the 4-decimal value.
    """
    top1_metric = mx.metric.Accuracy()
    top5_metric = mx.metric.TopKAccuracy(5)

    started_at = time.time()
    err_top1_val, err_top5_val = validate(
        acc_top1=top1_metric,
        acc_top5=top5_metric,
        net=net,
        val_data=val_data,
        batch_fn=batch_fn,
        use_rec=use_rec,
        dtype=dtype,
        ctx=ctx)

    if calc_weight_count:
        logging.info('Model: {} trainable parameters'.format(
            calc_net_weight_count(net)))

    # Both templates take the same keyword arguments, so only the template
    # itself depends on the verbosity flag.
    if extended_log:
        report_template = \
            'Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})'
    else:
        report_template = 'Test: err-top1={top1:.4f}\terr-top5={top5:.4f}'
    logging.info(report_template.format(top1=err_top1_val, top5=err_top5_val))

    # The measured interval starts before validation and ends here, so it
    # also includes the optional weight counting and the logging above.
    logging.info('Time cost: {:.4f} sec'.format(time.time() - started_at))
Ejemplo n.º 2
0
def test(net,
         test_data,
         batch_fn,
         data_source_needs_reset,
         metric,
         dtype,
         ctx,
         input_image_size,
         in_channels,
         calc_weight_count=False,
         calc_flops=False,
         calc_flops_only=True,
         extended_log=False):
    """
    Optionally evaluate a model, then optionally log its size/complexity.

    Parameters:
    ----------
    net
        Model; passed to `validate`, `calc_net_weight_count` and
        `measure_model`.
    test_data
        Data source, forwarded to `validate` as `val_data`.
    batch_fn
        Batch-splitting function, forwarded to `validate`.
    data_source_needs_reset
        Forwarded to `validate`.
    metric
        Metric object updated by `validate` and summarized by
        `report_accuracy`.
    dtype
        Forwarded to `validate`.
    ctx
        Forwarded to `validate`; its first element is given to
        `measure_model`, so it must be indexable when `calc_flops` is set.
    input_image_size
        Forwarded to `measure_model`.
    in_channels
        Forwarded to `measure_model`.
    calc_weight_count : bool, default False
        Whether to call `calc_net_weight_count`. The count is logged only
        when `calc_flops` is false; otherwise it is only cross-checked.
    calc_flops : bool, default False
        Whether to call `measure_model` and log params/FLOPs/MACs.
    calc_flops_only : bool, default True
        When true, the validation pass is skipped entirely.
    extended_log : bool, default False
        Forwarded to `report_accuracy`.
    """
    if not calc_flops_only:
        started_at = time.time()
        validate(
            metric=metric,
            net=net,
            val_data=test_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        summary = report_accuracy(metric=metric, extended_log=extended_log)
        logging.info("Test: {}".format(summary))
        logging.info("Time cost: {:.4f} sec".format(time.time() - started_at))

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))

    if not calc_flops:
        return

    num_flops, num_macs, num_params = measure_model(
        net, in_channels, input_image_size, ctx[0])
    if calc_weight_count:
        # The two independent parameter counts are expected to agree.
        assert weight_count == num_params

    half_flops = num_flops / 2
    stats = {
        "params": num_params,
        "params_m": num_params / 1e6,
        "flops": num_flops,
        "flops_m": num_flops / 1e6,
        "flops2": half_flops,
        "flops2_m": half_flops / 1e6,
        "macs": num_macs,
        "macs_m": num_macs / 1e6,
    }
    stat_template = (
        "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M),"
        " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)")
    logging.info(stat_template.format(**stats))
Ejemplo n.º 3
0
def test(net,
         val_data,
         batch_fn,
         data_source_needs_reset,
         dtype,
         ctx,
         input_image_size,
         in_channels,
         calc_weight_count=False,
         calc_flops=False,
         calc_flops_only=True,
         extended_log=False):
    """
    Optionally measure top-1/top-5 errors, then optionally log model size.

    Parameters:
    ----------
    net
        Model; passed to `validate`, `calc_net_weight_count` and
        `measure_model`.
    val_data
        Validation data source, forwarded to `validate`.
    batch_fn
        Batch-splitting function, forwarded to `validate`.
    data_source_needs_reset
        Forwarded to `validate`.
    dtype
        Forwarded to `validate`.
    ctx
        Forwarded to `validate`; its first element is given to
        `measure_model`, so it must be indexable when `calc_flops` is set.
    input_image_size
        Forwarded to `measure_model`.
    in_channels
        Forwarded to `measure_model`.
    calc_weight_count : bool, default False
        Whether to call `calc_net_weight_count`. The count is logged only
        when `calc_flops` is false; otherwise it is only cross-checked.
    calc_flops : bool, default False
        Whether to call `measure_model` and log params/FLOPs/MACs.
    calc_flops_only : bool, default True
        When true, the validation pass is skipped entirely.
    extended_log : bool, default False
        Whether to log each error a second time with default formatting
        next to the 4-decimal value.
    """
    if not calc_flops_only:
        top1_metric = mx.metric.Accuracy()
        top5_metric = mx.metric.TopKAccuracy(5)
        started_at = time.time()
        err_top1_val, err_top5_val = validate(
            acc_top1=top1_metric,
            acc_top5=top5_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)

        # Same keyword arguments either way; only the template differs.
        if extended_log:
            report_template = \
                'Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})'
        else:
            report_template = 'Test: err-top1={top1:.4f}\terr-top5={top5:.4f}'
        logging.info(
            report_template.format(top1=err_top1_val, top5=err_top5_val))
        logging.info('Time cost: {:.4f} sec'.format(time.time() - started_at))

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))

    if not calc_flops:
        return

    num_flops, num_macs, num_params = measure_model(
        net, in_channels, input_image_size, ctx[0])
    if calc_weight_count:
        # The two independent parameter counts are expected to agree.
        assert weight_count == num_params

    half_flops = num_flops / 2
    stats = {
        "params": num_params,
        "params_m": num_params / 1e6,
        "flops": num_flops,
        "flops_m": num_flops / 1e6,
        "flops2": half_flops,
        "flops2_m": half_flops / 1e6,
        "macs": num_macs,
        "macs_m": num_macs / 1e6,
    }
    stat_template = (
        "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M),"
        " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)")
    logging.info(stat_template.format(**stats))
Ejemplo n.º 4
0
def train_net(batch_size, num_epochs, start_epoch1, train_data, val_data,
              batch_fn, data_source_needs_reset, dtype, net, teacher_net,
              discrim_net, trainer, lr_scheduler, lp_saver, log_interval,
              mixup, mixup_epoch_tail, label_smoothing, num_classes,
              grad_clip_value, batch_size_scale, val_metric, train_metric,
              loss_metrics, loss_func, discrim_loss_func, ctx):
    """
    Train a model epoch by epoch, validating and checkpointing after each one.

    When resuming (`start_epoch1 > 1`) a validation pass is run and logged
    before the first training epoch.

    Parameters:
    ----------
    batch_size : int
        Training batch size.
    num_epochs : int
        Number of training epochs (exclusive upper bound of the 0-based
        epoch loop).
    start_epoch1 : int
        1-based number of the first epoch to train; must be an `int` >= 1.
    train_data : DataLoader or ImageRecordIter
        Training data source.
    val_data : DataLoader or ImageRecordIter
        Validation data source.
    batch_fn : func
        Function for splitting data after extraction from the data source.
    data_source_needs_reset : bool
        Forwarded to `train_epoch` and `validate`.
    dtype : str
        Base data type for tensors.
    net : HybridBlock
        Model being trained.
    teacher_net : HybridBlock or None
        Teacher model, forwarded to `train_epoch`.
    discrim_net : HybridBlock or None
        MEALv2 discriminator model, forwarded to `train_epoch`.
    trainer : Trainer
        Trainer; its `learning_rate` is reported to `lp_saver`.
    lr_scheduler : LRScheduler
        Learning rate scheduler.
    lp_saver : TrainLogParamSaver or None
        Model/trainer state saver; when None nothing is saved and no
        "best" summary is logged.
    log_interval : int
        Batch count period for logging.
    mixup : bool
        Whether to use mixup.
    mixup_epoch_tail : int
        Number of epochs without mixup at the end of training.
    label_smoothing : bool
        Whether to use label-smoothing.
    num_classes : int
        Number of model classes.
    grad_clip_value : float
        Threshold for gradient clipping.
    batch_size_scale : int
        Manual batch-size increasing factor; any value other than 1 switches
        every parameter of `net` to `grad_req = "add"`.
    val_metric : EvalMetric
        Metric object instance (validation subset).
    train_metric : EvalMetric
        Metric object instance (training subset).
    loss_metrics : list of EvalMetric
        Loss metric instances; the value of the first one is reported to
        `lp_saver`.
    loss_func : Loss
        Loss object instance.
    discrim_loss_func : Loss or None
        MEALv2 adversarial loss function.
    ctx : Context or list of Context
        MXNet context(s); a single `mx.Context` is wrapped into a list.
    """
    if batch_size_scale != 1:
        for param in net.collect_params().values():
            param.grad_req = "add"

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    def _validate_and_log(epoch1):
        # Evaluate on the validation subset and log the result for `epoch1`.
        validate(
            metric=val_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        logging.info("[Epoch {}] validation: {}".format(
            epoch1, report_accuracy(metric=val_metric)))

    def _as_list(values):
        # A metric may report a single value or a list of them; normalize.
        return values if type(values) is list else [values]

    assert type(start_epoch1) is int
    assert start_epoch1 >= 1
    if start_epoch1 > 1:
        logging.info("Start training from [Epoch {}]".format(start_epoch1))
        _validate_and_log(start_epoch1 - 1)

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        epoch1 = epoch + 1

        train_epoch(
            epoch=epoch,
            net=net,
            teacher_net=teacher_net,
            discrim_net=discrim_net,
            train_metric=train_metric,
            loss_metrics=loss_metrics,
            train_data=train_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            discrim_loss_func=discrim_loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        _validate_and_log(epoch1)

        if lp_saver is None:
            continue

        val_acc_values = _as_list(val_metric.get()[1])
        train_acc_values = _as_list(train_metric.get()[1])
        # Reported layout: validation values, training values, loss, LR.
        epoch_stats = val_acc_values + train_acc_values + [
            loss_metrics[0].get()[1], trainer.learning_rate]
        lp_saver.epoch_test_end_callback(
            epoch1=epoch1,
            params=epoch_stats,
            net=net,
            trainer=trainer)

    logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic))
    if lp_saver is not None:
        opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind)
        logging.info("Best {}: {:.4f} at {} epoch".format(
            opt_metric_name, lp_saver.best_eval_metric_value,
            lp_saver.best_eval_metric_epoch))
Ejemplo n.º 5
0
def calc_model_accuracy(net,
                        test_data,
                        batch_fn,
                        data_source_needs_reset,
                        metric,
                        dtype,
                        ctx,
                        input_image_size,
                        in_channels,
                        calc_weight_count=False,
                        calc_flops=False,
                        calc_flops_only=True,
                        extended_log=False):
    """
    Evaluate a model and/or log its parameter count and complexity.

    Parameters:
    ----------
    net : HybridBlock
        Model.
    test_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator.
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if test_data is ImageRecordIter).
    metric : EvalMetric
        Metric object instance.
    dtype : str
        Base data type for tensors.
    ctx : list of Context
        MXNet contexts; the first one is used for complexity measurement.
    input_image_size : tuple of 2 ints
        Spatial size of the expected input image.
    in_channels : int
        Number of input channels.
    calc_weight_count : bool, default False
        Whether to count weights. The count is logged only when
        `calc_flops` is false; otherwise it is only cross-checked against
        the parameter count from `measure_model`.
    calc_flops : bool, default False
        Whether to measure and log params/FLOPs/MACs.
    calc_flops_only : bool, default True
        Whether to skip the accuracy evaluation.
    extended_log : bool, default False
        Whether to log more precise accuracy values.

    Returns:
    -------
    list
        Accuracy values taken from `metric.get()[1]` (wrapped in a list if
        a single value); empty when the evaluation was skipped.
    """
    if calc_flops_only:
        acc_values = []
    else:
        started_at = time.time()
        validate(
            metric=metric,
            net=net,
            val_data=test_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        summary = report_accuracy(metric=metric, extended_log=extended_log)
        logging.info("Test: {}".format(summary))
        logging.info("Time cost: {:.4f} sec".format(time.time() - started_at))
        acc_values = metric.get()[1]
        if type(acc_values) is not list:
            acc_values = [acc_values]

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))

    if not calc_flops:
        return acc_values

    num_flops, num_macs, num_params = measure_model(
        net, in_channels, input_image_size, ctx[0])
    if calc_weight_count:
        # The two independent parameter counts are expected to agree.
        assert weight_count == num_params

    half_flops = num_flops / 2
    stats = {
        "params": num_params,
        "params_m": num_params / 1e6,
        "flops": num_flops,
        "flops_m": num_flops / 1e6,
        "flops2": half_flops,
        "flops2_m": half_flops / 1e6,
        "macs": num_macs,
        "macs_m": num_macs / 1e6,
    }
    stat_template = (
        "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M),"
        " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)")
    logging.info(stat_template.format(**stats))

    return acc_values
Ejemplo n.º 6
0
def train_net(batch_size, num_epochs, start_epoch1, train_data, val_data,
              batch_fn, data_source_needs_reset, dtype, net, trainer,
              lr_scheduler, lp_saver, log_interval, mixup, mixup_epoch_tail,
              label_smoothing, num_classes, grad_clip_value, batch_size_scale,
              val_metric, train_metric, ctx):
    """
    Train a model epoch by epoch, validating and checkpointing after each one.

    A softmax cross-entropy loss is built here; it uses sparse labels only
    when neither mixup nor label smoothing is enabled. When resuming
    (`start_epoch1 > 1`) a validation pass is run and logged first.

    Parameters:
    ----------
    batch_size
        Training batch size, forwarded to `train_epoch`.
    num_epochs : int
        Exclusive upper bound of the 0-based epoch loop.
    start_epoch1 : int
        1-based number of the first epoch to train; must be an `int` >= 1.
    train_data
        Training data source, forwarded to `train_epoch`.
    val_data
        Validation data source, forwarded to `validate`.
    batch_fn
        Batch-splitting function, forwarded to `train_epoch` and `validate`.
    data_source_needs_reset
        Forwarded to `train_epoch` and `validate`.
    dtype
        Forwarded to `train_epoch` and `validate`.
    net
        Model being trained.
    trainer
        Trainer; its `learning_rate` is reported to `lp_saver`.
    lr_scheduler
        Forwarded to `train_epoch`.
    lp_saver
        State saver or None; when None nothing is saved and no "best"
        summary is logged.
    log_interval
        Forwarded to `train_epoch`.
    mixup
        Whether mixup is used; also disables sparse labels in the loss.
    mixup_epoch_tail
        Forwarded to `train_epoch`.
    label_smoothing
        Whether label smoothing is used; also disables sparse labels in the
        loss.
    num_classes
        Forwarded to `train_epoch`.
    grad_clip_value
        Forwarded to `train_epoch`.
    batch_size_scale
        Any value other than 1 switches every parameter of `net` to
        `grad_req = "add"`.
    val_metric
        Metric object for the validation subset.
    train_metric
        Metric object for the training subset.
    ctx
        MXNet context(s); a single `mx.Context` is wrapped into a list.
    """
    if batch_size_scale != 1:
        for param in net.collect_params().values():
            param.grad_req = "add"

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    # Mixup and label smoothing both produce dense label vectors.
    use_sparse_labels = not (mixup or label_smoothing)
    loss_func = gluon.loss.SoftmaxCrossEntropyLoss(
        sparse_label=use_sparse_labels)

    def _validate_and_log(epoch1):
        # Evaluate on the validation subset and log the result for `epoch1`.
        validate(
            metric=val_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        logging.info("[Epoch {}] validation: {}".format(
            epoch1, report_accuracy(metric=val_metric)))

    def _as_list(values):
        # A metric may report a single value or a list of them; normalize.
        return values if type(values) is list else [values]

    assert type(start_epoch1) is int
    assert start_epoch1 >= 1
    if start_epoch1 > 1:
        logging.info("Start training from [Epoch {}]".format(start_epoch1))
        _validate_and_log(start_epoch1 - 1)

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        epoch1 = epoch + 1

        train_loss = train_epoch(
            epoch=epoch,
            net=net,
            train_metric=train_metric,
            train_data=train_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        _validate_and_log(epoch1)

        if lp_saver is None:
            continue

        val_acc_values = _as_list(val_metric.get()[1])
        train_acc_values = _as_list(train_metric.get()[1])
        # Reported layout: validation values, training values, loss, LR.
        epoch_stats = val_acc_values + train_acc_values + [
            train_loss, trainer.learning_rate]
        lp_saver.epoch_test_end_callback(
            epoch1=epoch1,
            params=epoch_stats,
            net=net,
            trainer=trainer)

    logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic))
    if lp_saver is not None:
        opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind)
        logging.info("Best {}: {:.4f} at {} epoch".format(
            opt_metric_name, lp_saver.best_eval_metric_value,
            lp_saver.best_eval_metric_epoch))
Ejemplo n.º 7
0
def train_net(batch_size,
              num_epochs,
              start_epoch1,
              train_data,
              val_data,
              batch_fn,
              data_source_needs_reset,
              dtype,
              net,
              trainer,
              lr_scheduler,
              lp_saver,
              log_interval,
              mixup,
              mixup_epoch_tail,
              label_smoothing,
              num_classes,
              grad_clip_value,
              batch_size_scale,
              ctx):
    """
    Train a model epoch by epoch, tracking top-1/top-5 validation errors.

    The accuracy metrics and a softmax cross-entropy loss are created here;
    the loss uses sparse labels only when neither mixup nor label smoothing
    is enabled. When resuming (`start_epoch1 > 1`) a validation pass is run
    and logged before the first training epoch.

    Parameters:
    ----------
    batch_size
        Training batch size, forwarded to `train_epoch`.
    num_epochs : int
        Exclusive upper bound of the 0-based epoch loop.
    start_epoch1 : int
        1-based number of the first epoch to train; must be an `int` >= 1.
    train_data
        Training data source, forwarded to `train_epoch`.
    val_data
        Validation data source, forwarded to `validate`.
    batch_fn
        Batch-splitting function, forwarded to `train_epoch` and `validate`.
    data_source_needs_reset
        Forwarded to `train_epoch` and `validate`.
    dtype
        Forwarded to `train_epoch` and `validate`.
    net
        Model being trained.
    trainer
        Trainer; its `learning_rate` is reported to `lp_saver`.
    lr_scheduler
        Forwarded to `train_epoch`.
    lp_saver
        State saver or None; when None nothing is saved and no "best"
        summary is logged.
    log_interval
        Forwarded to `train_epoch`.
    mixup
        Whether mixup is used; must not be combined with `label_smoothing`.
    mixup_epoch_tail
        Forwarded to `train_epoch`.
    label_smoothing
        Whether label smoothing is used; must not be combined with `mixup`.
    num_classes
        Forwarded to `train_epoch`.
    grad_clip_value
        Forwarded to `train_epoch`.
    batch_size_scale
        Any value other than 1 switches every parameter of `net` to
        `grad_req = 'add'`.
    ctx
        MXNet context(s); a single `mx.Context` is wrapped into a list.
    """
    # The two label-softening options are mutually exclusive here.
    assert (not mixup) or (not label_smoothing)

    if batch_size_scale != 1:
        for param in net.collect_params().values():
            param.grad_req = 'add'

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    acc_top1_val = mx.metric.Accuracy()
    acc_top5_val = mx.metric.TopKAccuracy(5)
    acc_top1_train = mx.metric.Accuracy()

    use_sparse_labels = not (mixup or label_smoothing)
    loss_func = gluon.loss.SoftmaxCrossEntropyLoss(
        sparse_label=use_sparse_labels)

    def _validate_and_log(epoch1):
        # Evaluate on the validation subset, log both errors for `epoch1`
        # and hand them back to the caller.
        top1_err, top5_err = validate(
            acc_top1=acc_top1_val,
            acc_top5=acc_top5_val,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        logging.info(
            '[Epoch {}] validation: err-top1={:.4f}\terr-top5={:.4f}'.format(
                epoch1, top1_err, top5_err))
        return top1_err, top5_err

    assert type(start_epoch1) is int
    assert start_epoch1 >= 1
    if start_epoch1 > 1:
        logging.info('Start training from [Epoch {}]'.format(start_epoch1))
        _validate_and_log(start_epoch1 - 1)

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        epoch1 = epoch + 1

        err_top1_train, train_loss = train_epoch(
            epoch=epoch,
            net=net,
            acc_top1_train=acc_top1_train,
            train_data=train_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        err_top1_val, err_top5_val = _validate_and_log(epoch1)

        if lp_saver is None:
            continue

        # Reported layout: val top-1, train top-1, val top-5, loss, LR.
        epoch_stats = [
            err_top1_val,
            err_top1_train,
            err_top5_val,
            train_loss,
            trainer.learning_rate,
        ]
        lp_saver.epoch_test_end_callback(
            epoch1=epoch1,
            params=epoch_stats,
            net=net,
            trainer=trainer)

    logging.info('Total time cost: {:.2f} sec'.format(time.time() - gtic))
    if lp_saver is not None:
        logging.info('Best err-top5: {:.4f} at {} epoch'.format(
            lp_saver.best_eval_metric_value, lp_saver.best_eval_metric_epoch))