Example #1
import logging
import time

import mxnet as mx

# calc_net_weight_count and measure_model are helpers assumed to be defined
# elsewhere in the surrounding module; a sketch of validate1 and batch_fn
# follows after this example.


def test(net,
         val_data,
         batch_fn,
         data_source_needs_reset,
         dtype,
         ctx,
         input_image_size,
         in_channels,
         calc_weight_count=False,
         calc_flops=False,
         calc_flops_only=True,
         extended_log=False):
    # Run full validation unless only FLOPs/params are being counted.
    if not calc_flops_only:
        accuracy_metric = mx.metric.Accuracy()
        tic = time.time()
        err_val = validate1(accuracy_metric=accuracy_metric,
                            net=net,
                            val_data=val_data,
                            batch_fn=batch_fn,
                            data_source_needs_reset=data_source_needs_reset,
                            dtype=dtype,
                            ctx=ctx)
        if extended_log:
            logging.info('Test: err={err:.4f} ({err})'.format(err=err_val))
        else:
            logging.info('Test: err={err:.4f}'.format(err=err_val))
        logging.info('Time cost: {:.4f} sec'.format(time.time() - tic))

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))
    if calc_flops:
        num_flops, num_macs, num_params = measure_model(
            net, in_channels, input_image_size, ctx[0])
        # Sanity check: both parameter-counting paths should agree.
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(
            stat_msg.format(params=num_params,
                            params_m=num_params / 1e6,
                            flops=num_flops,
                            flops_m=num_flops / 1e6,
                            flops2=num_flops / 2,
                            flops2_m=num_flops / 2 / 1e6,
                            macs=num_macs,
                            macs_m=num_macs / 1e6))
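
Both examples delegate to a batch_fn and a validate1 helper that the snippets themselves do not define. As a rough illustration only, here is a minimal sketch of what such helpers might look like for a Gluon data pipeline; the names and signatures are assumptions and may differ from the original codebase.

import mxnet as mx
from mxnet import gluon


def batch_fn(batch, ctx):
    # Split a (data, label) batch evenly across the available contexts.
    data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
    label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
    return data, label


def validate1(accuracy_metric, net, val_data, batch_fn,
              data_source_needs_reset, dtype, ctx):
    # Iterator-style data sources must be reset before a fresh pass.
    if data_source_needs_reset:
        val_data.reset()
    accuracy_metric.reset()
    for batch in val_data:
        data_list, labels_list = batch_fn(batch, ctx)
        outputs_list = [net(x.astype(dtype, copy=False)) for x in data_list]
        accuracy_metric.update(labels_list, outputs_list)
    _, accuracy = accuracy_metric.get()
    # test() and train_net() treat the returned value as an error rate.
    return 1.0 - accuracy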
Example #2
import logging
import time

import mxnet as mx
from mxnet import gluon

# validate1 and train_epoch are helpers assumed to be defined elsewhere in
# the surrounding module.


def train_net(batch_size, num_epochs, start_epoch1, train_data, val_data,
              batch_fn, data_source_needs_reset, dtype, net, trainer,
              lr_scheduler, lp_saver, log_interval, mixup, mixup_epoch_tail,
              label_smoothing, num_classes, grad_clip_value, batch_size_scale,
              ctx):

    # Mixup and label smoothing are mutually exclusive here.
    assert not (mixup and label_smoothing)

    # Gradient accumulation: sum gradients across sub-batches instead of
    # overwriting them (used when emulating a larger effective batch size).
    if batch_size_scale != 1:
        for p in net.collect_params().values():
            p.grad_req = 'add'

    # Allow a single context to be passed instead of a list.
    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    acc_metric_val = mx.metric.Accuracy()
    acc_metric_train = mx.metric.Accuracy()

    # Mixup and label smoothing produce dense (soft) targets, so sparse
    # labels are only used when neither is enabled.
    loss_func = gluon.loss.SoftmaxCrossEntropyLoss(
        sparse_label=(not (mixup or label_smoothing)))

    assert isinstance(start_epoch1, int)
    assert start_epoch1 >= 1
    if start_epoch1 > 1:
        logging.info('Start training from [Epoch {}]'.format(start_epoch1))
        err_val = validate1(accuracy_metric=acc_metric_val,
                            net=net,
                            val_data=val_data,
                            batch_fn=batch_fn,
                            data_source_needs_reset=data_source_needs_reset,
                            dtype=dtype,
                            ctx=ctx)
        logging.info('[Epoch {}] validation: err={:.4f}'.format(
            start_epoch1 - 1, err_val))

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        err_train, train_loss = train_epoch(
            epoch=epoch,
            net=net,
            acc_metric_train=acc_metric_train,
            train_data=train_data,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        err_val = validate1(accuracy_metric=acc_metric_val,
                            net=net,
                            val_data=val_data,
                            batch_fn=batch_fn,
                            data_source_needs_reset=data_source_needs_reset,
                            dtype=dtype,
                            ctx=ctx)

        logging.info('[Epoch {}] validation: err={:.4f}'.format(
            epoch + 1, err_val))

        if lp_saver is not None:
            lp_saver_kwargs = {'net': net, 'trainer': trainer}
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=[err_val, err_train, train_loss, trainer.learning_rate],
                **lp_saver_kwargs)

    logging.info('Total time cost: {:.2f} sec'.format(time.time() - gtic))
    if lp_saver is not None:
        logging.info('Best err: {:.4f} at epoch {}'.format(
            lp_saver.best_eval_metric_value, lp_saver.best_eval_metric_epoch))
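
Setting grad_req to 'add' above makes Gluon accumulate gradients across backward passes instead of overwriting them, which is how batch_size_scale emulates a larger effective batch. The snippet does not show train_epoch, but a hypothetical sketch of the corresponding update step, with all names assumed, might look like this:

from mxnet import autograd

for i, batch in enumerate(train_data):
    data_list, labels_list = batch_fn(batch, ctx)
    with autograd.record():
        losses = [loss_func(net(x), y)
                  for x, y in zip(data_list, labels_list)]
    for loss in losses:
        loss.backward()
    if (i + 1) % batch_size_scale == 0:
        # Step with the full effective batch size, then clear the
        # accumulated gradients manually ('add' does not auto-zero them).
        trainer.step(batch_size * batch_size_scale)
        for p in net.collect_params().values():
            p.zero_grad()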