def test(net, val_data, data_source_needs_reset, dtype, ctx, input_image_size, in_channels,
         calc_weight_count=False, calc_flops=False, calc_flops_only=True, extended_log=False):
    """
    Evaluate a network on the validation set and/or report its complexity.

    Parameters
    ----------
    net : HybridBlock
        Model to evaluate.
    val_data : DataLoader
        Validation data source.
    data_source_needs_reset : bool
        Whether the data iterator must be reset before use (passed through to validate1).
    dtype : str
        Data type for evaluation (e.g. 'float32').
    ctx : Context or list of Context
        Device context(s); ctx[0] is used for FLOPs measurement.
    input_image_size : tuple of int
        Spatial size of the model input (used by measure_model).
    in_channels : int
        Number of input channels (used by measure_model).
    calc_weight_count : bool, default False
        Whether to count trainable parameters.
    calc_flops : bool, default False
        Whether to measure FLOPs/MACs/params via measure_model.
    calc_flops_only : bool, default True
        If True, skip the accuracy evaluation entirely.
    extended_log : bool, default False
        If True, also log the raw (full-precision) error value.
    """
    if not calc_flops_only:
        accuracy_metric = mx.metric.Accuracy()
        tic = time.time()
        err_val = validate1(
            accuracy_metric=accuracy_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        if extended_log:
            # Fixed: the message previously ended with '({err}))' -- an unbalanced
            # extra ')' in the log output.
            logging.info('Test: err={err:.4f} ({err})'.format(err=err_val))
        else:
            logging.info('Test: err={err:.4f}'.format(err=err_val))
        logging.info('Time cost: {:.4f} sec'.format(time.time() - tic))
    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))
    if calc_flops:
        num_flops, num_macs, num_params = measure_model(net, in_channels, input_image_size, ctx[0])
        # Sanity check: the analytic parameter count must agree with the direct count.
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(
            stat_msg.format(
                params=num_params,
                params_m=num_params / 1e6,
                flops=num_flops,
                flops_m=num_flops / 1e6,
                flops2=num_flops / 2,
                flops2_m=num_flops / 2 / 1e6,
                macs=num_macs,
                macs_m=num_macs / 1e6))
def train_net(batch_size, num_epochs, start_epoch1, train_data, val_data, data_source_needs_reset,
              dtype, net, trainer, lr_scheduler, lp_saver, log_interval, mixup, mixup_epoch_tail,
              label_smoothing, num_classes, grad_clip_value, batch_size_scale, ctx):
    """
    Run the main training loop: one train_epoch + validate1 pass per epoch,
    with optional checkpointing via lp_saver.

    Parameters
    ----------
    batch_size : int
        Per-iteration batch size.
    num_epochs : int
        Total number of epochs (1-based upper bound).
    start_epoch1 : int
        1-based epoch to start from (>1 means resuming).
    train_data, val_data : DataLoader
        Training / validation data sources.
    data_source_needs_reset : bool
        Whether the data iterators must be reset each pass.
    dtype : str
        Data type for training (e.g. 'float32').
    net : HybridBlock
        Model to train.
    trainer : gluon.Trainer
        Optimizer wrapper.
    lr_scheduler : LRScheduler
        Learning-rate scheduler (passed through to train_epoch).
    lp_saver : object or None
        Checkpoint saver; its epoch_test_end_callback is invoked after each epoch.
    log_interval : int
        Iterations between progress logs (passed through to train_epoch).
    mixup : bool
        Enable mixup augmentation (mutually exclusive with label_smoothing).
    mixup_epoch_tail : int
        Number of final epochs with mixup disabled.
    label_smoothing : bool
        Enable label smoothing (mutually exclusive with mixup).
    num_classes : int
        Number of target classes.
    grad_clip_value : float or None
        Gradient clipping threshold.
    batch_size_scale : int
        Gradient-accumulation factor; if != 1, gradients are accumulated ('add').
    ctx : Context or list of Context
        Device context(s); a single Context is wrapped into a list.
    """
    # Mixup and label smoothing both densify the labels; only one may be active.
    assert (not (mixup and label_smoothing))
    if batch_size_scale != 1:
        # Accumulate gradients across iterations instead of overwriting them.
        for p in net.collect_params().values():
            p.grad_req = 'add'
    if isinstance(ctx, mx.Context):
        ctx = [ctx]
    acc_metric_val = mx.metric.Accuracy()
    acc_metric_train = mx.metric.Accuracy()
    # Dense (non-sparse) labels are required when mixup or label smoothing is on.
    loss_func = gluon.loss.SoftmaxCrossEntropyLoss(
        sparse_label=(not (mixup or label_smoothing)))
    # Fixed: replaced non-idiomatic `type(start_epoch1) == int` with isinstance.
    assert isinstance(start_epoch1, int)
    assert (start_epoch1 >= 1)
    if start_epoch1 > 1:
        # Resuming: report the baseline validation error before training continues.
        logging.info('Start training from [Epoch {}]'.format(start_epoch1))
        err_val = validate1(
            accuracy_metric=acc_metric_val,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        logging.info('[Epoch {}] validation: err={:.4f}'.format(
            start_epoch1 - 1, err_val))
    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        err_train, train_loss = train_epoch(
            epoch=epoch,
            net=net,
            acc_metric_train=acc_metric_train,
            train_data=train_data,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)
        err_val = validate1(
            accuracy_metric=acc_metric_val,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        logging.info('[Epoch {}] validation: err={:.4f}'.format(
            epoch + 1, err_val))
        if lp_saver is not None:
            lp_saver_kwargs = {'net': net, 'trainer': trainer}
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=[err_val, err_train, train_loss, trainer.learning_rate],
                **lp_saver_kwargs)
    logging.info('Total time cost: {:.2f} sec'.format(time.time() - gtic))
    if lp_saver is not None:
        logging.info('Best err: {:.4f} at {} epoch'.format(
            lp_saver.best_eval_metric_value, lp_saver.best_eval_metric_epoch))