Example #1
def test(net,
         test_data,
         batch_fn,
         data_source_needs_reset,
         metric,
         dtype,
         ctx,
         input_image_size,
         in_channels,
         calc_weight_count=False,
         calc_flops=False,
         calc_flops_only=True,
         extended_log=False):
    if not calc_flops_only:
        tic = time.time()
        validate(metric=metric,
                 net=net,
                 val_data=test_data,
                 batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset,
                 dtype=dtype,
                 ctx=ctx)
        accuracy_msg = report_accuracy(metric=metric,
                                       extended_log=extended_log)
        logging.info("Test: {}".format(accuracy_msg))
        logging.info("Time cost: {:.4f} sec".format(time.time() - tic))

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))
    if calc_flops:
        num_flops, num_macs, num_params = measure_model(
            net, in_channels, input_image_size, ctx[0])
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(
            stat_msg.format(params=num_params,
                            params_m=num_params / 1e6,
                            flops=num_flops,
                            flops_m=num_flops / 1e6,
                            flops2=num_flops / 2,
                            flops2_m=num_flops / 2 / 1e6,
                            macs=num_macs,
                            macs_m=num_macs / 1e6))
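The batch_fn argument above is expected to split a raw loader batch across the target contexts and return (data_list, labels_list). A minimal self-contained sketch of such a function, using gluon.utils.split_and_load with an illustrative dummy batch (the shapes are assumptions, not taken from the original code):

import mxnet as mx
from mxnet import gluon

def batch_fn(batch, ctx):
    # Split the (data, label) pair of one batch evenly across the given contexts.
    data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
    label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
    return data, label

ctx = [mx.cpu()]
# Illustrative batch: 8 RGB images of size 224x224 plus integer class labels.
batch = [mx.nd.zeros((8, 3, 224, 224)), mx.nd.zeros((8,))]
data_list, labels_list = batch_fn(batch, ctx)
print(len(data_list), data_list[0].shape)  # 1 (8, 3, 224, 224)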
Example #2
def train_net(batch_size, num_epochs, start_epoch1, train_data, val_data,
              batch_fn, data_source_needs_reset, dtype, net, teacher_net,
              discrim_net, trainer, lr_scheduler, lp_saver, log_interval,
              mixup, mixup_epoch_tail, label_smoothing, num_classes,
              grad_clip_value, batch_size_scale, val_metric, train_metric,
              loss_metrics, loss_func, discrim_loss_func, ctx):
    """
    Main procedure for training the model.

    Parameters:
    ----------
    batch_size : int
        Training batch size.
    num_epochs : int
        Number of training epochs.
    start_epoch1 : int
        Number of the starting epoch (1-based).
    train_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator (training subset).
    val_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator (validation subset).
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset the data source (if it is ImageRecordIter).
    dtype : str
        Base data type for tensors.
    net : HybridBlock
        Model.
    teacher_net : HybridBlock or None
        Teacher model.
    discrim_net : HybridBlock or None
        MEALv2 discriminator model.
    trainer : Trainer
        Trainer.
    lr_scheduler : LRScheduler
        Learning rate scheduler.
    lp_saver : TrainLogParamSaver
        Model/trainer state saver.
    log_interval : int
        Batch count period for logging.
    mixup : bool
        Whether to use mixup.
    mixup_epoch_tail : int
        Number of epochs without mixup at the end of training.
    label_smoothing : bool
        Whether to use label-smoothing.
    num_classes : int
        Number of model classes.
    grad_clip_value : float
        Threshold for gradient clipping.
    batch_size_scale : int
        Manual batch-size increasing factor.
    val_metric : EvalMetric
        Metric object instance (validation subset).
    train_metric : EvalMetric
        Metric object instance (training subset).
    loss_metrics : list of EvalMetric
        Metric object instances (loss values).
    loss_func : Loss
        Loss object instance.
    discrim_loss_func : Loss or None
        MEALv2 adversarial loss function.
    ctx : Context
        MXNet context.
    """
    if batch_size_scale != 1:
        for p in net.collect_params().values():
            p.grad_req = "add"

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    # loss_func = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=(not (mixup or label_smoothing)))

    assert (type(start_epoch1) == int)
    assert (start_epoch1 >= 1)
    if start_epoch1 > 1:
        logging.info("Start training from [Epoch {}]".format(start_epoch1))
        validate(metric=val_metric,
                 net=net,
                 val_data=val_data,
                 batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset,
                 dtype=dtype,
                 ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(
            start_epoch1 - 1, val_accuracy_msg))

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        train_epoch(epoch=epoch,
                    net=net,
                    teacher_net=teacher_net,
                    discrim_net=discrim_net,
                    train_metric=train_metric,
                    loss_metrics=loss_metrics,
                    train_data=train_data,
                    batch_fn=batch_fn,
                    data_source_needs_reset=data_source_needs_reset,
                    dtype=dtype,
                    ctx=ctx,
                    loss_func=loss_func,
                    discrim_loss_func=discrim_loss_func,
                    trainer=trainer,
                    lr_scheduler=lr_scheduler,
                    batch_size=batch_size,
                    log_interval=log_interval,
                    mixup=mixup,
                    mixup_epoch_tail=mixup_epoch_tail,
                    label_smoothing=label_smoothing,
                    num_classes=num_classes,
                    num_epochs=num_epochs,
                    grad_clip_value=grad_clip_value,
                    batch_size_scale=batch_size_scale)

        validate(metric=val_metric,
                 net=net,
                 val_data=val_data,
                 batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset,
                 dtype=dtype,
                 ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(
            epoch + 1, val_accuracy_msg))

        if lp_saver is not None:
            lp_saver_kwargs = {"net": net, "trainer": trainer}
            val_acc_values = val_metric.get()[1]
            train_acc_values = train_metric.get()[1]
            val_acc_values = val_acc_values if type(
                val_acc_values) == list else [val_acc_values]
            train_acc_values = train_acc_values if type(
                train_acc_values) == list else [train_acc_values]
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=(val_acc_values + train_acc_values +
                        [loss_metrics[0].get()[1], trainer.learning_rate]),
                **lp_saver_kwargs)

    logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic))
    if lp_saver is not None:
        opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind)
        logging.info("Best {}: {:.4f} at {} epoch".format(
            opt_metric_name, lp_saver.best_eval_metric_value,
            lp_saver.best_eval_metric_epoch))
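The batch_size_scale path above relies on gradient accumulation: with grad_req = "add" each backward pass adds into the stored gradients, and the optimizer step is taken once per group of batches with a correspondingly larger normalization. A minimal self-contained sketch of that pattern (the toy model, data and hyperparameters are assumptions for illustration only):

import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(2)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.1})
loss_func = gluon.loss.SoftmaxCrossEntropyLoss()

batch_size = 4
batch_size_scale = 2  # accumulate gradients over 2 consecutive batches

# Accumulate gradients instead of overwriting them on each backward pass.
for p in net.collect_params().values():
    p.grad_req = "add"

for i in range(4):
    x = mx.nd.random.uniform(shape=(batch_size, 8))
    y = mx.nd.array([0, 1, 0, 1])
    with autograd.record():
        loss = loss_func(net(x), y)
    loss.backward()
    if (i + 1) % batch_size_scale == 0:
        # One optimizer step for the whole accumulated group, then clear gradients.
        trainer.step(batch_size * batch_size_scale)
        for p in net.collect_params().values():
            p.zero_grad()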
Example #3
def train_epoch(epoch, net, teacher_net, discrim_net, train_metric,
                loss_metrics, train_data, batch_fn, data_source_needs_reset,
                dtype, ctx, loss_func, discrim_loss_func, trainer,
                lr_scheduler, batch_size, log_interval, mixup,
                mixup_epoch_tail, label_smoothing, num_classes, num_epochs,
                grad_clip_value, batch_size_scale):
    """
    Train the model for one particular epoch.

    Parameters:
    ----------
    epoch : int
        Epoch number.
    net : HybridBlock
        Model.
    teacher_net : HybridBlock or None
        Teacher model.
    discrim_net : HybridBlock or None
        MEALv2 discriminator model.
    train_metric : EvalMetric
        Metric object instance.
    loss_metrics : list of EvalMetric
        Metric object instances (loss values).
    train_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator.
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset the data source (if it is ImageRecordIter).
    dtype : str
        Base data type for tensors.
    ctx : Context
        MXNet context.
    loss_func : Loss
        Loss function.
    discrim_loss_func : Loss or None
        MEALv2 adversarial loss function.
    trainer : Trainer
        Trainer.
    lr_scheduler : LRScheduler
        Learning rate scheduler.
    batch_size : int
        Training batch size.
    log_interval : int
        Batch count period for logging.
    mixup : bool
        Whether to use mixup.
    mixup_epoch_tail : int
        Number of epochs without mixup at the end of training.
    label_smoothing : bool
        Whether to use label-smoothing.
    num_classes : int
        Number of model classes.
    num_epochs : int
        Number of training epochs.
    grad_clip_value : float
        Threshold for gradient clipping.
    batch_size_scale : int
        Manual batch-size increasing factor.

    """
    labels_list_inds = None
    batch_size_extend_count = 0
    tic = time.time()
    if data_source_needs_reset:
        train_data.reset()
    train_metric.reset()
    for m in loss_metrics:
        m.reset()

    i = 0
    btic = time.time()
    for i, batch in enumerate(train_data):
        data_list, labels_list = batch_fn(batch, ctx)

        labels_one_hot = False
        if teacher_net is not None:
            labels_list = [
                teacher_net(x.astype(dtype,
                                     copy=False)).softmax(axis=-1).mean(axis=1)
                for x in data_list
            ]
            labels_list_inds = [y.argmax(axis=-1) for y in labels_list]
            labels_one_hot = True

        if label_smoothing and (teacher_net is None):
            eta = 0.1
            on_value = 1 - eta + eta / num_classes
            off_value = eta / num_classes
            if not labels_one_hot:
                labels_list_inds = labels_list
                labels_list = [
                    y.one_hot(depth=num_classes,
                              on_value=on_value,
                              off_value=off_value) for y in labels_list
                ]
                labels_one_hot = True
        if mixup:
            if not labels_one_hot:
                labels_list_inds = labels_list
                labels_list = [
                    y.one_hot(depth=num_classes) for y in labels_list
                ]
                labels_one_hot = True
            if epoch < num_epochs - mixup_epoch_tail:
                alpha = 1
                lam = np.random.beta(alpha, alpha)
                data_list = [lam * x + (1 - lam) * x[::-1] for x in data_list]
                labels_list = [
                    lam * y + (1 - lam) * y[::-1] for y in labels_list
                ]

        with ag.record():
            outputs_list = [
                net(x.astype(dtype, copy=False)) for x in data_list
            ]
            loss_list = [
                loss_func(yhat, y.astype(dtype, copy=False))
                for yhat, y in zip(outputs_list, labels_list)
            ]

            if discrim_net is not None:
                d_pred_list = [
                    discrim_net(yhat.astype(dtype, copy=False).softmax())
                    for yhat in outputs_list
                ]
                d_label_list = [
                    discrim_net(y.astype(dtype, copy=False))
                    for y in labels_list
                ]
                d_loss_list = [
                    discrim_loss_func(yhat, y)
                    for yhat, y in zip(d_pred_list, d_label_list)
                ]
                loss_list = [z + dz for z, dz in zip(loss_list, d_loss_list)]

        for loss in loss_list:
            loss.backward()
        lr_scheduler.update(i, epoch)

        if grad_clip_value is not None:
            grads = [
                v.grad(ctx[0]) for v in net.collect_params().values()
                if v._grad is not None
            ]
            gluon.utils.clip_global_norm(grads, max_norm=grad_clip_value)

        if batch_size_scale == 1:
            trainer.step(batch_size)
        else:
            if (i + 1) % batch_size_scale == 0:
                batch_size_extend_count = 0
                trainer.step(batch_size * batch_size_scale)
                for p in net.collect_params().values():
                    p.zero_grad()
            else:
                batch_size_extend_count += 1

        train_metric.update(
            labels=(labels_list if not labels_one_hot else labels_list_inds),
            preds=outputs_list)
        loss_metrics[0].update(labels=None, preds=loss_list)
        if (discrim_net is not None) and (len(loss_metrics) > 1):
            loss_metrics[1].update(labels=None, preds=d_loss_list)

        if log_interval and not (i + 1) % log_interval:
            speed = batch_size * log_interval / (time.time() - btic)
            btic = time.time()
            train_accuracy_msg = report_accuracy(metric=train_metric)
            loss_accuracy_msg = report_accuracy(metric=loss_metrics[0])
            if (discrim_net is not None) and (len(loss_metrics) > 1):
                dloss_accuracy_msg = report_accuracy(metric=loss_metrics[1])
                logging.info(
                    "Epoch[{}] Batch [{}]\tSpeed: {:.2f} samples/sec\t{}\t{}\t{}\tlr={:.5f}"
                    .format(epoch + 1, i, speed, train_accuracy_msg,
                            loss_accuracy_msg, dloss_accuracy_msg,
                            trainer.learning_rate))
            else:
                logging.info(
                    "Epoch[{}] Batch [{}]\tSpeed: {:.2f} samples/sec\t{}\t{}\tlr={:.5f}"
                    .format(epoch + 1, i, speed, train_accuracy_msg,
                            loss_accuracy_msg, trainer.learning_rate))

    if (batch_size_scale != 1) and (batch_size_extend_count > 0):
        trainer.step(batch_size * batch_size_extend_count)
        for p in net.collect_params().values():
            p.zero_grad()

    throughput = int(batch_size * (i + 1) / (time.time() - tic))
    logging.info(
        "[Epoch {}] speed: {:.2f} samples/sec\ttime cost: {:.2f} sec".format(
            epoch + 1, throughput,
            time.time() - tic))

    train_accuracy_msg = report_accuracy(metric=train_metric)
    loss_accuracy_msg = report_accuracy(metric=loss_metrics[0])
    if (discrim_net is not None) and (len(loss_metrics) > 1):
        dloss_accuracy_msg = report_accuracy(metric=loss_metrics[1])
        logging.info("[Epoch {}] training: {}\t{}\t{}".format(
            epoch + 1, train_accuracy_msg, loss_accuracy_msg,
            dloss_accuracy_msg))
    else:
        logging.info("[Epoch {}] training: {}\t{}".format(
            epoch + 1, train_accuracy_msg, loss_accuracy_msg))
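When teacher_net is given above, hard labels are replaced by soft targets: per-teacher class probabilities are averaged into one distribution per sample, and the argmax of that distribution is kept for the accuracy metric. A self-contained sketch of that step, assuming (as the mean(axis=1) above implies) that the teacher output is stacked as (batch, num_teachers, num_classes):

import mxnet as mx

# Illustrative teacher output: 2 samples, 3 teacher models, 4 classes (raw scores).
teacher_out = mx.nd.random.uniform(shape=(2, 3, 4))

soft_labels = teacher_out.softmax(axis=-1).mean(axis=1)  # ensemble-averaged probabilities, shape (2, 4)
hard_labels = soft_labels.argmax(axis=-1)                # class indices used for the training metric
print(soft_labels.sum(axis=-1))                          # each row sums to 1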
Example #4
def calc_model_accuracy(net,
                        test_data,
                        batch_fn,
                        data_source_needs_reset,
                        metric,
                        dtype,
                        ctx,
                        input_image_size,
                        in_channels,
                        calc_weight_count=False,
                        calc_flops=False,
                        calc_flops_only=True,
                        extended_log=False):
    """
    Main test routine.

    Parameters:
    ----------
    net : HybridBlock
        Model.
    test_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator.
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if test_data is ImageRecordIter).
    metric : EvalMetric
        Metric object instance.
    dtype : str
        Base data type for tensors.
    ctx : Context
        MXNet context.
    input_image_size : tuple of 2 ints
        Spatial size of the expected input image.
    in_channels : int
        Number of input channels.
    calc_weight_count : bool, default False
        Whether to calculate count of weights.
    calc_flops : bool, default False
        Whether to calculate FLOPs.
    calc_flops_only : bool, default True
        Whether to only calculate FLOPs without testing.
    extended_log : bool, default False
        Whether to log more precise accuracy values.

    Returns:
    -------
    list of floats
        Accuracy values.
    """
    if not calc_flops_only:
        tic = time.time()
        validate(metric=metric,
                 net=net,
                 val_data=test_data,
                 batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset,
                 dtype=dtype,
                 ctx=ctx)
        accuracy_msg = report_accuracy(metric=metric,
                                       extended_log=extended_log)
        logging.info("Test: {}".format(accuracy_msg))
        logging.info("Time cost: {:.4f} sec".format(time.time() - tic))
        acc_values = metric.get()[1]
        acc_values = acc_values if type(acc_values) == list else [acc_values]
    else:
        acc_values = []

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))
    if calc_flops:
        num_flops, num_macs, num_params = measure_model(
            net, in_channels, input_image_size, ctx[0])
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(
            stat_msg.format(params=num_params,
                            params_m=num_params / 1e6,
                            flops=num_flops,
                            flops_m=num_flops / 1e6,
                            flops2=num_flops / 2,
                            flops2_m=num_flops / 2 / 1e6,
                            macs=num_macs,
                            macs_m=num_macs / 1e6))

    return acc_values
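calc_net_weight_count is defined elsewhere in the repository; as a rough, self-contained stand-in, the trainable-parameter count of a Gluon block can be obtained by summing the sizes of all parameters that receive gradients (the toy network below is purely illustrative):

import numpy as np
from mxnet import gluon

def count_weights(net):
    # Rough equivalent of calc_net_weight_count: total size of trainable parameters.
    return sum(int(np.prod(p.shape)) for p in net.collect_params().values()
               if p.grad_req != "null")

net = gluon.nn.HybridSequential()
net.add(gluon.nn.Dense(16, in_units=8))
net.add(gluon.nn.Dense(4, in_units=16))
net.initialize()
print(count_weights(net))  # (8*16 + 16) + (16*4 + 4) = 212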
Example #5
def train_net(batch_size, num_epochs, start_epoch1, train_data, val_data,
              batch_fn, data_source_needs_reset, dtype, net, trainer,
              lr_scheduler, lp_saver, log_interval, mixup, mixup_epoch_tail,
              label_smoothing, num_classes, grad_clip_value, batch_size_scale,
              val_metric, train_metric, ctx):

    if batch_size_scale != 1:
        for p in net.collect_params().values():
            p.grad_req = "add"

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    loss_func = gluon.loss.SoftmaxCrossEntropyLoss(
        sparse_label=(not (mixup or label_smoothing)))

    assert (type(start_epoch1) == int)
    assert (start_epoch1 >= 1)
    if start_epoch1 > 1:
        logging.info("Start training from [Epoch {}]".format(start_epoch1))
        validate(metric=val_metric,
                 net=net,
                 val_data=val_data,
                 batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset,
                 dtype=dtype,
                 ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(
            start_epoch1 - 1, val_accuracy_msg))

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        train_loss = train_epoch(
            epoch=epoch,
            net=net,
            train_metric=train_metric,
            train_data=train_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        validate(metric=val_metric,
                 net=net,
                 val_data=val_data,
                 batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset,
                 dtype=dtype,
                 ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(
            epoch + 1, val_accuracy_msg))

        if lp_saver is not None:
            lp_saver_kwargs = {"net": net, "trainer": trainer}
            val_acc_values = val_metric.get()[1]
            train_acc_values = train_metric.get()[1]
            val_acc_values = val_acc_values if type(
                val_acc_values) == list else [val_acc_values]
            train_acc_values = train_acc_values if type(
                train_acc_values) == list else [train_acc_values]
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=(val_acc_values + train_acc_values +
                        [train_loss, trainer.learning_rate]),
                **lp_saver_kwargs)

    logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic))
    if lp_saver is not None:
        opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind)
        logging.info("Best {}: {:.4f} at {} epoch".format(
            opt_metric_name, lp_saver.best_eval_metric_value,
            lp_saver.best_eval_metric_epoch))
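The loss above is created with sparse_label=False whenever mixup or label smoothing is enabled, because those paths turn class indices into dense target distributions. A small self-contained sketch of the two label formats accepted by SoftmaxCrossEntropyLoss (the logits below are illustrative):

import mxnet as mx
from mxnet import gluon

logits = mx.nd.array([[2.0, 0.5, 0.1],
                      [0.2, 1.5, 0.3]])

# Sparse labels: plain class indices.
sparse_loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=True)
print(sparse_loss(logits, mx.nd.array([0, 1])))

# Dense labels: one probability distribution per sample (one-hot, smoothed or mixed).
dense_loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
one_hot = mx.nd.array([0, 1]).one_hot(depth=3)
print(dense_loss(logits, one_hot))  # same values as the sparse case for pure one-hot labels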
Example #6
def train_epoch(epoch, net, train_metric, train_data, batch_fn,
                data_source_needs_reset, dtype, ctx, loss_func, trainer,
                lr_scheduler, batch_size, log_interval, mixup,
                mixup_epoch_tail, label_smoothing, num_classes, num_epochs,
                grad_clip_value, batch_size_scale):

    labels_list_inds = None
    batch_size_extend_count = 0
    tic = time.time()
    if data_source_needs_reset:
        train_data.reset()
    train_metric.reset()
    train_loss = 0.0

    btic = time.time()
    for i, batch in enumerate(train_data):
        data_list, labels_list = batch_fn(batch, ctx)

        if label_smoothing:
            eta = 0.1
            on_value = 1 - eta + eta / num_classes
            off_value = eta / num_classes
            labels_list_inds = labels_list
            labels_list = [
                Y.one_hot(depth=num_classes,
                          on_value=on_value,
                          off_value=off_value) for Y in labels_list
            ]
        if mixup:
            if not label_smoothing:
                labels_list_inds = labels_list
                labels_list = [
                    Y.one_hot(depth=num_classes) for Y in labels_list
                ]
            if epoch < num_epochs - mixup_epoch_tail:
                alpha = 1
                lam = np.random.beta(alpha, alpha)
                data_list = [lam * X + (1 - lam) * X[::-1] for X in data_list]
                labels_list = [
                    lam * Y + (1 - lam) * Y[::-1] for Y in labels_list
                ]

        with ag.record():
            outputs_list = [
                net(X.astype(dtype, copy=False)) for X in data_list
            ]
            loss_list = [
                loss_func(yhat, y.astype(dtype, copy=False))
                for yhat, y in zip(outputs_list, labels_list)
            ]
        for loss in loss_list:
            loss.backward()
        lr_scheduler.update(i, epoch)

        if grad_clip_value is not None:
            grads = [
                v.grad(ctx[0]) for v in net.collect_params().values()
                if v._grad is not None
            ]
            gluon.utils.clip_global_norm(grads, max_norm=grad_clip_value)

        if batch_size_scale == 1:
            trainer.step(batch_size)
        else:
            if (i + 1) % batch_size_scale == 0:
                batch_size_extend_count = 0
                trainer.step(batch_size * batch_size_scale)
                for p in net.collect_params().values():
                    p.zero_grad()
            else:
                batch_size_extend_count += 1

        train_loss += sum([loss.mean().asscalar()
                           for loss in loss_list]) / len(loss_list)

        train_metric.update(
            labels=(labels_list
                    if not (mixup or label_smoothing) else labels_list_inds),
            preds=outputs_list)

        if log_interval and not (i + 1) % log_interval:
            speed = batch_size * log_interval / (time.time() - btic)
            btic = time.time()
            train_accuracy_msg = report_accuracy(metric=train_metric)
            logging.info(
                "Epoch[{}] Batch [{}]\tSpeed: {:.2f} samples/sec\t{}\tlr={:.5f}"
                .format(epoch + 1, i, speed, train_accuracy_msg,
                        trainer.learning_rate))

    if (batch_size_scale != 1) and (batch_size_extend_count > 0):
        trainer.step(batch_size * batch_size_extend_count)
        for p in net.collect_params().values():
            p.zero_grad()

    throughput = int(batch_size * (i + 1) / (time.time() - tic))
    logging.info(
        "[Epoch {}] speed: {:.2f} samples/sec\ttime cost: {:.2f} sec".format(
            epoch + 1, throughput,
            time.time() - tic))

    train_loss /= (i + 1)
    train_accuracy_msg = report_accuracy(metric=train_metric)
    logging.info("[Epoch {}] training: {}\tloss={:.4f}".format(
        epoch + 1, train_accuracy_msg, train_loss))

    return train_loss
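A self-contained sketch of the label-smoothing targets built above: with eta = 0.1 the true class gets 1 - eta + eta / num_classes and every other class gets eta / num_classes, so each row still sums to 1 (num_classes = 4 is just an illustrative value):

import mxnet as mx

num_classes = 4
eta = 0.1
on_value = 1 - eta + eta / num_classes   # 0.925
off_value = eta / num_classes            # 0.025

labels = mx.nd.array([0, 2])
smoothed = labels.one_hot(depth=num_classes, on_value=on_value, off_value=off_value)
print(smoothed)
# [[0.925 0.025 0.025 0.025]
#  [0.025 0.025 0.925 0.025]]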
Example #7
def train_epoch(epoch, net, train_metric, train_data, batch_fn,
                data_source_needs_reset, dtype, ctx, loss_func, trainer,
                lr_scheduler, batch_size, log_interval, mixup,
                mixup_epoch_tail, label_smoothing, num_classes, num_epochs,
                grad_clip_value, batch_size_scale):
    """
    Train the model for one particular epoch.

    Parameters:
    ----------
    epoch : int
        Epoch number.
    net : HybridBlock
        Model.
    train_metric : EvalMetric
        Metric object instance.
    train_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator.
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset the data source (if it is ImageRecordIter).
    dtype : str
        Base data type for tensors.
    ctx : Context
        MXNet context.
    loss_func : Loss
        Loss function.
    trainer : Trainer
        Trainer.
    lr_scheduler : LRScheduler
        Learning rate scheduler.
    batch_size : int
        Training batch size.
    log_interval : int
        Batch count period for logging.
    mixup : bool
        Whether to use mixup.
    mixup_epoch_tail : int
        Number of epochs without mixup at the end of training.
    label_smoothing : bool
        Whether to use label-smoothing.
    num_classes : int
        Number of model classes.
    num_epochs : int
        Number of training epochs.
    grad_clip_value : float
        Threshold for gradient clipping.
    batch_size_scale : int
        Manual batch-size increasing factor.

    Returns
    -------
    float
        Loss value.
    """
    labels_list_inds = None
    batch_size_extend_count = 0
    tic = time.time()
    if data_source_needs_reset:
        train_data.reset()
    train_metric.reset()
    train_loss = 0.0

    btic = time.time()
    for i, batch in enumerate(train_data):
        data_list, labels_list = batch_fn(batch, ctx)
        # print("--> {}".format(labels_list[0][0].asscalar()))
        # import cv2
        # img = data_list[0][0].transpose((1, 2, 0)).asnumpy().copy()[:, :, [2, 1, 0]]
        # img -= img.min()
        # img *= 255 / img.max()
        # cv2.imshow(winname="img", mat=img.astype(np.uint8))
        # cv2.waitKey()

        if label_smoothing:
            eta = 0.1
            on_value = 1 - eta + eta / num_classes
            off_value = eta / num_classes
            labels_list_inds = labels_list
            labels_list = [
                Y.one_hot(depth=num_classes,
                          on_value=on_value,
                          off_value=off_value) for Y in labels_list
            ]
        if mixup:
            if not label_smoothing:
                labels_list_inds = labels_list
                labels_list = [
                    Y.one_hot(depth=num_classes) for Y in labels_list
                ]
            if epoch < num_epochs - mixup_epoch_tail:
                alpha = 1
                lam = np.random.beta(alpha, alpha)
                data_list = [lam * X + (1 - lam) * X[::-1] for X in data_list]
                labels_list = [
                    lam * Y + (1 - lam) * Y[::-1] for Y in labels_list
                ]

        with ag.record():
            outputs_list = [
                net(X.astype(dtype, copy=False)) for X in data_list
            ]
            loss_list = [
                loss_func(yhat, y.astype(dtype, copy=False))
                for yhat, y in zip(outputs_list, labels_list)
            ]
        for loss in loss_list:
            loss.backward()
        lr_scheduler.update(i, epoch)

        if grad_clip_value is not None:
            grads = [
                v.grad(ctx[0]) for v in net.collect_params().values()
                if v._grad is not None
            ]
            gluon.utils.clip_global_norm(grads, max_norm=grad_clip_value)

        if batch_size_scale == 1:
            trainer.step(batch_size)
        else:
            if (i + 1) % batch_size_scale == 0:
                batch_size_extend_count = 0
                trainer.step(batch_size * batch_size_scale)
                for p in net.collect_params().values():
                    p.zero_grad()
            else:
                batch_size_extend_count += 1

        train_loss += sum([loss.mean().asscalar()
                           for loss in loss_list]) / len(loss_list)

        train_metric.update(
            labels=(labels_list
                    if not (mixup or label_smoothing) else labels_list_inds),
            preds=outputs_list)

        if log_interval and not (i + 1) % log_interval:
            speed = batch_size * log_interval / (time.time() - btic)
            btic = time.time()
            train_accuracy_msg = report_accuracy(metric=train_metric)
            logging.info(
                "Epoch[{}] Batch [{}]\tSpeed: {:.2f} samples/sec\t{}\tlr={:.5f}"
                .format(epoch + 1, i, speed, train_accuracy_msg,
                        trainer.learning_rate))

    if (batch_size_scale != 1) and (batch_size_extend_count > 0):
        trainer.step(batch_size * batch_size_extend_count)
        for p in net.collect_params().values():
            p.zero_grad()

    throughput = int(batch_size * (i + 1) / (time.time() - tic))
    logging.info(
        "[Epoch {}] speed: {:.2f} samples/sec\ttime cost: {:.2f} sec".format(
            epoch + 1, throughput,
            time.time() - tic))

    train_loss /= (i + 1)
    train_accuracy_msg = report_accuracy(metric=train_metric)
    logging.info("[Epoch {}] training: {}\tloss={:.4f}".format(
        epoch + 1, train_accuracy_msg, train_loss))

    return train_loss
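A self-contained sketch of the mixup step used above: a mixing coefficient is drawn from Beta(alpha, alpha) and every sample is blended with the sample at the mirrored batch position, for both images and one-hot labels (the batch shapes below are illustrative):

import numpy as np
import mxnet as mx

num_classes = 3
alpha = 1
lam = np.random.beta(alpha, alpha)  # mixing coefficient in [0, 1]

X = mx.nd.random.uniform(shape=(4, 3, 8, 8))             # toy image batch
Y = mx.nd.array([0, 1, 2, 1]).one_hot(depth=num_classes)  # one-hot labels

# Blend each sample with its counterpart in the reversed batch order.
X_mix = lam * X + (1 - lam) * X[::-1]
Y_mix = lam * Y + (1 - lam) * Y[::-1]
print(lam)
print(Y_mix)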