def test(net, val_data, use_cuda, input_image_size, in_channels, calc_weight_count=False, calc_flops=False, calc_flops_only=True, extended_log=False): if not calc_flops_only: accuracy_metric = AverageMeter() tic = time.time() err_val = validate1(accuracy_metric=accuracy_metric, net=net, val_data=val_data, use_cuda=use_cuda) if extended_log: logging.info('Test: err={err:.4f} ({err})'.format(err=err_val)) else: logging.info('Test: err={err:.4f}'.format(err=err_val)) logging.info('Time cost: {:.4f} sec'.format(time.time() - tic)) if calc_weight_count: weight_count = calc_net_weight_count(net) if not calc_flops: logging.info('Model: {} trainable parameters'.format(weight_count)) if calc_flops: num_flops, num_macs, num_params = measure_model( net, in_channels, input_image_size) assert (not calc_weight_count) or (weight_count == num_params) stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \ " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)" logging.info( stat_msg.format(params=num_params, params_m=num_params / 1e6, flops=num_flops, flops_m=num_flops / 1e6, flops2=num_flops / 2, flops2_m=num_flops / 2 / 1e6, macs=num_macs, macs_m=num_macs / 1e6))
def train_net(batch_size,
              num_epochs,
              start_epoch1,
              train_data,
              val_data,
              net,
              optimizer,
              lr_scheduler,
              lp_saver,
              log_interval,
              use_cuda):
    """
    Train a network for the 1-based epochs ``start_epoch1 .. num_epochs``, validating after each one.

    Parameters:
    ----------
    batch_size : int
        Forwarded to ``train_epoch``.
    num_epochs : int
        1-based index of the last epoch to train.
    start_epoch1 : int
        1-based index of the first epoch to train. When greater than 1 (resuming), the model is
        validated once before training and the result is logged as epoch ``start_epoch1 - 1``.
    train_data : object
        Training data, forwarded to ``train_epoch``.
    val_data : object
        Validation data, forwarded to ``validate1``.
    net : object
        Model being trained; its ``state_dict()`` is included in the state passed to ``lp_saver``.
    optimizer : object
        Optimizer; its ``state_dict()`` and ``param_groups[0]['lr']`` are passed to ``lp_saver``.
    lr_scheduler : object
        Learning-rate scheduler; ``step()`` is called once at the end of every epoch.
    lp_saver : object or None
        If not None, ``epoch_test_end_callback`` is called after every epoch and its
        ``best_eval_metric_value`` / ``best_eval_metric_epoch`` are logged when training ends.
    log_interval : int
        Forwarded to ``train_epoch``.
    use_cuda : bool
        Moves the loss module to CUDA when True; also forwarded to ``train_epoch``/``validate1``.

    Raises:
    ------
    TypeError
        If ``start_epoch1`` is not exactly of type ``int`` (``bool`` is rejected too).
    ValueError
        If ``start_epoch1`` is less than 1.
    """
    # Validate before constructing anything, with real exceptions: ``assert`` is stripped under -O.
    if type(start_epoch1) is not int:
        raise TypeError("start_epoch1 must be an int, got {}".format(type(start_epoch1).__name__))
    if start_epoch1 < 1:
        raise ValueError("start_epoch1 must be >= 1, got {}".format(start_epoch1))

    acc_metric_val = AverageMeter()
    acc_metric_train = AverageMeter()

    loss_func = nn.CrossEntropyLoss()
    if use_cuda:
        loss_func = loss_func.cuda()

    if start_epoch1 > 1:
        # Resuming: report the quality of the loaded model before training continues.
        logging.info('Start training from [Epoch {}]'.format(start_epoch1))
        err_val = validate1(
            accuracy_metric=acc_metric_val,
            net=net,
            val_data=val_data,
            use_cuda=use_cuda)
        logging.info('[Epoch {}] validation: err={:.4f}'.format(start_epoch1 - 1, err_val))

    gtic = time.time()
    # ``epoch`` is 0-based; ``epoch + 1`` is the 1-based number used in logs and checkpoints.
    for epoch in range(start_epoch1 - 1, num_epochs):
        err_train, train_loss = train_epoch(
            epoch,
            acc_metric_train,
            net,
            train_data,
            use_cuda,
            loss_func,
            optimizer,
            batch_size,
            log_interval)

        err_val = validate1(
            accuracy_metric=acc_metric_val,
            net=net,
            val_data=val_data,
            use_cuda=use_cuda)
        logging.info('[Epoch {}] validation: err={:.4f}'.format(epoch + 1, err_val))

        if lp_saver is not None:
            state = {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            # The LR is read before the scheduler steps, so it is the LR used during this epoch.
            lp_saver.epoch_test_end_callback(
                epoch1=(epoch + 1),
                params=[
                    err_val,
                    err_train,
                    train_loss,
                    optimizer.param_groups[0]['lr'],
                ],
                state=state)

        # Since PyTorch 1.1.0 the scheduler must step AFTER the epoch's optimizer updates; stepping
        # it first (as before) skips the first value of the schedule.
        # NOTE(review): assumes the caller built ``lr_scheduler`` so that the optimizer's current LR
        # is the one for epoch ``start_epoch1`` (e.g. ``last_epoch`` set for resume) — confirm.
        lr_scheduler.step()

    logging.info('Total time cost: {:.2f} sec'.format(time.time() - gtic))
    if lp_saver is not None:
        logging.info('Best err: {:.4f} at {} epoch'.format(
            lp_saver.best_eval_metric_value, lp_saver.best_eval_metric_epoch))