import time
import logging

import mxnet as mx
from mxnet import gluon

# validate, report_accuracy, calc_net_weight_count, train_epoch, measure_model
# and get_metric_name are project-level helpers defined elsewhere in this repo.


def test(net,
         val_data,
         batch_fn,
         use_rec,
         dtype,
         ctx,
         calc_weight_count=False,
         extended_log=False):
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    tic = time.time()
    err_top1_val, err_top5_val = validate(
        acc_top1=acc_top1,
        acc_top5=acc_top5,
        net=net,
        val_data=val_data,
        batch_fn=batch_fn,
        use_rec=use_rec,
        dtype=dtype,
        ctx=ctx)
    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        logging.info('Model: {} trainable parameters'.format(weight_count))
    if extended_log:
        logging.info(
            'Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})'.format(
                top1=err_top1_val, top5=err_top5_val))
    else:
        logging.info('Test: err-top1={top1:.4f}\terr-top5={top5:.4f}'.format(
            top1=err_top1_val, top5=err_top5_val))
    logging.info('Time cost: {:.4f} sec'.format(time.time() - tic))
def test(net,
         test_data,
         batch_fn,
         data_source_needs_reset,
         metric,
         dtype,
         ctx,
         input_image_size,
         in_channels,
         calc_weight_count=False,
         calc_flops=False,
         calc_flops_only=True,
         extended_log=False):
    if not calc_flops_only:
        tic = time.time()
        validate(
            metric=metric,
            net=net,
            val_data=test_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        accuracy_msg = report_accuracy(
            metric=metric,
            extended_log=extended_log)
        logging.info("Test: {}".format(accuracy_msg))
        logging.info("Time cost: {:.4f} sec".format(time.time() - tic))

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))
    if calc_flops:
        num_flops, num_macs, num_params = measure_model(net, in_channels, input_image_size, ctx[0])
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(stat_msg.format(
            params=num_params, params_m=num_params / 1e6,
            flops=num_flops, flops_m=num_flops / 1e6,
            flops2=num_flops / 2, flops2_m=num_flops / 2 / 1e6,
            macs=num_macs, macs_m=num_macs / 1e6))
def test(net,
         val_data,
         batch_fn,
         data_source_needs_reset,
         dtype,
         ctx,
         input_image_size,
         in_channels,
         calc_weight_count=False,
         calc_flops=False,
         calc_flops_only=True,
         extended_log=False):
    if not calc_flops_only:
        acc_top1 = mx.metric.Accuracy()
        acc_top5 = mx.metric.TopKAccuracy(5)
        tic = time.time()
        err_top1_val, err_top5_val = validate(
            acc_top1=acc_top1,
            acc_top5=acc_top5,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        if extended_log:
            logging.info(
                'Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})'.format(
                    top1=err_top1_val, top5=err_top5_val))
        else:
            logging.info('Test: err-top1={top1:.4f}\terr-top5={top5:.4f}'.format(
                top1=err_top1_val, top5=err_top5_val))
        logging.info('Time cost: {:.4f} sec'.format(time.time() - tic))

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))
    if calc_flops:
        num_flops, num_macs, num_params = measure_model(net, in_channels, input_image_size, ctx[0])
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(stat_msg.format(
            params=num_params, params_m=num_params / 1e6,
            flops=num_flops, flops_m=num_flops / 1e6,
            flops2=num_flops / 2, flops2_m=num_flops / 2 / 1e6,
            macs=num_macs, macs_m=num_macs / 1e6))
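# Hedged usage sketch for the FLOPs-aware test() variant directly above. Only
# the test() signature comes from this file: `load_model` and `get_val_data`
# are hypothetical stand-ins for the project's model/data helpers, and the
# 224x224 RGB input size is just an ImageNet-style assumption.
def _example_test_usage():
    ctx = [mx.gpu(0)] if mx.context.num_gpus() > 0 else [mx.cpu()]
    net = load_model("resnet18", ctx=ctx)  # hypothetical model loader
    val_data, batch_fn = get_val_data(batch_size=100, ctx=ctx)  # hypothetical data helpers
    test(net=net,
         val_data=val_data,
         batch_fn=batch_fn,
         data_source_needs_reset=False,
         dtype="float32",
         ctx=ctx,
         input_image_size=(224, 224),
         in_channels=3,
         calc_weight_count=True,
         calc_flops=True,
         calc_flops_only=False,
         extended_log=True)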
def train_net(batch_size,
              num_epochs,
              start_epoch1,
              train_data,
              val_data,
              batch_fn,
              data_source_needs_reset,
              dtype,
              net,
              teacher_net,
              discrim_net,
              trainer,
              lr_scheduler,
              lp_saver,
              log_interval,
              mixup,
              mixup_epoch_tail,
              label_smoothing,
              num_classes,
              grad_clip_value,
              batch_size_scale,
              val_metric,
              train_metric,
              loss_metrics,
              loss_func,
              discrim_loss_func,
              ctx):
    """
    Main procedure for training model.

    Parameters:
    ----------
    batch_size : int
        Training batch size.
    num_epochs : int
        Number of training epochs.
    start_epoch1 : int
        Number of starting epoch (1-based).
    train_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator (training subset).
    val_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator (validation subset).
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if test_data is ImageRecordIter).
    dtype : str
        Base data type for tensors.
    net : HybridBlock
        Model.
    teacher_net : HybridBlock or None
        Teacher model.
    discrim_net : HybridBlock or None
        MEALv2 discriminator model.
    trainer : Trainer
        Trainer.
    lr_scheduler : LRScheduler
        Learning rate scheduler.
    lp_saver : TrainLogParamSaver
        Model/trainer state saver.
    log_interval : int
        Batch count period for logging.
    mixup : bool
        Whether to use mixup.
    mixup_epoch_tail : int
        Number of epochs without mixup at the end of training.
    label_smoothing : bool
        Whether to use label-smoothing.
    num_classes : int
        Number of model classes.
    grad_clip_value : float
        Threshold for gradient clipping.
    batch_size_scale : int
        Manual batch-size increasing factor.
    val_metric : EvalMetric
        Metric object instance (validation subset).
    train_metric : EvalMetric
        Metric object instance (training subset).
    loss_metrics : list of EvalMetric
        Metric object instances (loss values).
    loss_func : Loss
        Loss object instance.
    discrim_loss_func : Loss or None
        MEALv2 adversarial loss function.
    ctx : Context
        MXNet context.
    """
    if batch_size_scale != 1:
        for p in net.collect_params().values():
            p.grad_req = "add"

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    # loss_func = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=(not (mixup or label_smoothing)))

    assert (type(start_epoch1) == int)
    assert (start_epoch1 >= 1)
    if start_epoch1 > 1:
        logging.info("Start training from [Epoch {}]".format(start_epoch1))
        validate(
            metric=val_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(start_epoch1 - 1, val_accuracy_msg))

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        train_epoch(
            epoch=epoch,
            net=net,
            teacher_net=teacher_net,
            discrim_net=discrim_net,
            train_metric=train_metric,
            loss_metrics=loss_metrics,
            train_data=train_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            discrim_loss_func=discrim_loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        validate(
            metric=val_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(epoch + 1, val_accuracy_msg))

        if lp_saver is not None:
            lp_saver_kwargs = {"net": net, "trainer": trainer}
            val_acc_values = val_metric.get()[1]
            train_acc_values = train_metric.get()[1]
            val_acc_values = val_acc_values if type(val_acc_values) == list else [val_acc_values]
            train_acc_values = train_acc_values if type(train_acc_values) == list else [train_acc_values]
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=(val_acc_values + train_acc_values +
                        [loss_metrics[0].get()[1], trainer.learning_rate]),
                **lp_saver_kwargs)

    logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic))
    if lp_saver is not None:
        opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind)
        logging.info("Best {}: {:.4f} at {} epoch".format(
            opt_metric_name, lp_saver.best_eval_metric_value, lp_saver.best_eval_metric_epoch))
def calc_model_accuracy(net,
                        test_data,
                        batch_fn,
                        data_source_needs_reset,
                        metric,
                        dtype,
                        ctx,
                        input_image_size,
                        in_channels,
                        calc_weight_count=False,
                        calc_flops=False,
                        calc_flops_only=True,
                        extended_log=False):
    """
    Main test routine.

    Parameters:
    ----------
    net : HybridBlock
        Model.
    test_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator.
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if test_data is ImageRecordIter).
    metric : EvalMetric
        Metric object instance.
    dtype : str
        Base data type for tensors.
    ctx : Context
        MXNet context.
    input_image_size : tuple of 2 ints
        Spatial size of the expected input image.
    in_channels : int
        Number of input channels.
    calc_weight_count : bool, default False
        Whether to calculate count of weights.
    calc_flops : bool, default False
        Whether to calculate FLOPs.
    calc_flops_only : bool, default True
        Whether to only calculate FLOPs without testing.
    extended_log : bool, default False
        Whether to log more precise accuracy values.

    Returns:
    -------
    list of floats
        Accuracy values.
    """
    if not calc_flops_only:
        tic = time.time()
        validate(
            metric=metric,
            net=net,
            val_data=test_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        accuracy_msg = report_accuracy(
            metric=metric,
            extended_log=extended_log)
        logging.info("Test: {}".format(accuracy_msg))
        logging.info("Time cost: {:.4f} sec".format(time.time() - tic))
        acc_values = metric.get()[1]
        acc_values = acc_values if type(acc_values) == list else [acc_values]
    else:
        acc_values = []

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))
    if calc_flops:
        num_flops, num_macs, num_params = measure_model(net, in_channels, input_image_size, ctx[0])
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(stat_msg.format(
            params=num_params, params_m=num_params / 1e6,
            flops=num_flops, flops_m=num_flops / 1e6,
            flops2=num_flops / 2, flops2_m=num_flops / 2 / 1e6,
            macs=num_macs, macs_m=num_macs / 1e6))

    return acc_values
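# Sketch of calling calc_model_accuracy() above with a composite top-1/top-5
# metric; the metric choice is this example's assumption (the routine only
# requires an EvalMetric). Returns the raw accuracy values on the test set.
def _example_calc_model_accuracy(net, test_data, batch_fn, ctx):
    metric = mx.metric.CompositeEvalMetric()
    metric.add(mx.metric.Accuracy())
    metric.add(mx.metric.TopKAccuracy(5))
    return calc_model_accuracy(net=net,
                               test_data=test_data,
                               batch_fn=batch_fn,
                               data_source_needs_reset=False,
                               metric=metric,
                               dtype="float32",
                               ctx=ctx,
                               input_image_size=(224, 224),
                               in_channels=3,
                               calc_weight_count=True,
                               calc_flops=False,
                               calc_flops_only=False,
                               extended_log=True)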
def train_net(batch_size,
              num_epochs,
              start_epoch1,
              train_data,
              val_data,
              batch_fn,
              data_source_needs_reset,
              dtype,
              net,
              trainer,
              lr_scheduler,
              lp_saver,
              log_interval,
              mixup,
              mixup_epoch_tail,
              label_smoothing,
              num_classes,
              grad_clip_value,
              batch_size_scale,
              val_metric,
              train_metric,
              ctx):
    if batch_size_scale != 1:
        for p in net.collect_params().values():
            p.grad_req = "add"

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    loss_func = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=(not (mixup or label_smoothing)))

    assert (type(start_epoch1) == int)
    assert (start_epoch1 >= 1)
    if start_epoch1 > 1:
        logging.info("Start training from [Epoch {}]".format(start_epoch1))
        validate(
            metric=val_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(start_epoch1 - 1, val_accuracy_msg))

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        train_loss = train_epoch(
            epoch=epoch,
            net=net,
            train_metric=train_metric,
            train_data=train_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        validate(
            metric=val_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        val_accuracy_msg = report_accuracy(metric=val_metric)
        logging.info("[Epoch {}] validation: {}".format(epoch + 1, val_accuracy_msg))

        if lp_saver is not None:
            lp_saver_kwargs = {"net": net, "trainer": trainer}
            val_acc_values = val_metric.get()[1]
            train_acc_values = train_metric.get()[1]
            val_acc_values = val_acc_values if type(val_acc_values) == list else [val_acc_values]
            train_acc_values = train_acc_values if type(train_acc_values) == list else [train_acc_values]
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=(val_acc_values + train_acc_values + [train_loss, trainer.learning_rate]),
                **lp_saver_kwargs)

    logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic))
    if lp_saver is not None:
        opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind)
        logging.info("Best {}: {:.4f} at {} epoch".format(
            opt_metric_name, lp_saver.best_eval_metric_value, lp_saver.best_eval_metric_epoch))
def train_net(batch_size,
              num_epochs,
              start_epoch1,
              train_data,
              val_data,
              batch_fn,
              data_source_needs_reset,
              dtype,
              net,
              trainer,
              lr_scheduler,
              lp_saver,
              log_interval,
              mixup,
              mixup_epoch_tail,
              label_smoothing,
              num_classes,
              grad_clip_value,
              batch_size_scale,
              ctx):
    assert (not (mixup and label_smoothing))

    if batch_size_scale != 1:
        for p in net.collect_params().values():
            p.grad_req = 'add'

    if isinstance(ctx, mx.Context):
        ctx = [ctx]

    acc_top1_val = mx.metric.Accuracy()
    acc_top5_val = mx.metric.TopKAccuracy(5)
    acc_top1_train = mx.metric.Accuracy()

    loss_func = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=(not (mixup or label_smoothing)))

    assert (type(start_epoch1) == int)
    assert (start_epoch1 >= 1)
    if start_epoch1 > 1:
        logging.info('Start training from [Epoch {}]'.format(start_epoch1))
        err_top1_val, err_top5_val = validate(
            acc_top1=acc_top1_val,
            acc_top5=acc_top5_val,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        logging.info('[Epoch {}] validation: err-top1={:.4f}\terr-top5={:.4f}'.format(
            start_epoch1 - 1, err_top1_val, err_top5_val))

    gtic = time.time()
    for epoch in range(start_epoch1 - 1, num_epochs):
        err_top1_train, train_loss = train_epoch(
            epoch=epoch,
            net=net,
            acc_top1_train=acc_top1_train,
            train_data=train_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx,
            loss_func=loss_func,
            trainer=trainer,
            lr_scheduler=lr_scheduler,
            batch_size=batch_size,
            log_interval=log_interval,
            mixup=mixup,
            mixup_epoch_tail=mixup_epoch_tail,
            label_smoothing=label_smoothing,
            num_classes=num_classes,
            num_epochs=num_epochs,
            grad_clip_value=grad_clip_value,
            batch_size_scale=batch_size_scale)

        err_top1_val, err_top5_val = validate(
            acc_top1=acc_top1_val,
            acc_top5=acc_top5_val,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)

        logging.info('[Epoch {}] validation: err-top1={:.4f}\terr-top5={:.4f}'.format(
            epoch + 1, err_top1_val, err_top5_val))

        if lp_saver is not None:
            lp_saver_kwargs = {'net': net, 'trainer': trainer}
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=[err_top1_val, err_top1_train, err_top5_val, train_loss,
                        trainer.learning_rate],
                **lp_saver_kwargs)

    logging.info('Total time cost: {:.2f} sec'.format(time.time() - gtic))
    if lp_saver is not None:
        logging.info('Best err-top5: {:.4f} at {} epoch'.format(
            lp_saver.best_eval_metric_value, lp_saver.best_eval_metric_epoch))
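# validate() itself is defined elsewhere in the project. As a rough, assumed
# sketch of the contract the metric-based validate() calls above rely on
# (rewind the source if needed, reset the metric, run the net over every
# batch, accumulate the metric); the loop body mirrors the usual Gluon
# multi-device evaluation pattern and is not taken from this file:
def _validate_sketch(metric, net, val_data, batch_fn, data_source_needs_reset, dtype, ctx):
    if data_source_needs_reset:
        val_data.reset()  # ImageRecordIter-style sources must be rewound
    metric.reset()
    for batch in val_data:
        data_list, labels_list = batch_fn(batch, ctx)  # split batch across devices
        outputs_list = [net(x.astype(dtype, copy=False)) for x in data_list]
        metric.update(labels_list, outputs_list)
    return metric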