def validate(args):
    """Evaluate a timm classification model and return an accuracy summary.

    The model named by ``args.model`` is created (pretrained weights are used
    when no checkpoint is supplied), moved to CUDA and optionally wrapped for
    TorchScript, APEX AMP, channels-last memory format and ``DataParallel``.
    A single pass over ``args.split`` of ``args.data`` then records
    cross-entropy loss and top-1 / top-5 accuracy.

    ``args`` is mutated in place: ``pretrained``, ``prefetcher``, the AMP
    flags, and ``num_classes`` when it was not provided.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (millions), ``img_size``, ``cropt_pct`` (key name
        kept as-is for compatibility) and ``interpolation``.
    """
    # Without a checkpoint, fall back to pretrained weights so there is always
    # something meaningful to validate.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # AMP backend selection: APEX takes precedence over native torch AMP.
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    # `suppress()` with no arguments is a do-nothing context manager.
    amp_autocast = torch.cuda.amp.autocast if args.native_amp else suppress

    if args.legacy_jit:
        set_jit_legacy()

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), \
            'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        # NOTE(review): optimized_execution() returns a context manager; calling
        # it without `with` has no effect.
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)
    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data,
        name=args.dataset,
        split=args.split,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map,
    )

    # Optional file of class indices (one per line): the logits are reduced to
    # this subset before loss/accuracy are computed.
    valid_labels = None
    if args.valid_labels:
        with open(args.valid_labels, 'r') as label_file:
            kept_ids = {int(row.rstrip()) for row in label_file}
        valid_labels = [idx in kept_ids for idx in range(args.num_classes)]

    real_labels = None
    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)

    # With test-time pooling the whole image is evaluated, so no center crop.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
    )

    batch_time, losses, top1, top5 = (AverageMeter() for _ in range(4))

    model.eval()
    with torch.no_grad():
        # Warmup pass: reduces variability of the first batch time, especially
        # when comparing TorchScript against eager mode.
        images = torch.randn((args.batch_size, ) + data_config['input_size']).cuda()
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)
        model(images)

        tick = time.time()
        for batch_idx, (images, target) in enumerate(loader):
            if args.no_prefetcher:
                # The prefetcher normally performs the host->GPU copy.
                target = target.cuda()
                images = images.cuda()
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(images)
            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            n_imgs = images.size(0)
            losses.update(loss.item(), n_imgs)
            top1.update(acc1.item(), n_imgs)
            top5.update(acc5.item(), n_imgs)

            batch_time.update(time.time() - tick)
            tick = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=images.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5,
                    ))

    if real_labels is None:
        top1a, top5a = top1.avg, top5.avg
    else:
        # ReaL labels replace the running top-k averages at the end.
        top1a = real_labels.get_accuracy(k=1)
        top5a = real_labels.get_accuracy(k=5)

    results = OrderedDict(
        top1=round(top1a, 4),
        top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4),
        top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'],
    )

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
def validate(args):
    """Run inference with a timm model and optionally score it against labels.

    Besides the accuracy summary, this variant returns the raw model outputs
    for every image so callers can post-process them.

    Model-building extras driven by ``args``:
        * models whose name contains ``'inception'`` are created with
          ``aux_logits=True``; when the model returns a tuple/list, only its
          first element is scored.
        * ``args.create_classifier_layerfc`` replaces ``global_pool`` and
          ``classifier`` with the pair built by ``create_classifier_layerfc``.

    When ``args.has_eval_label`` is false, the loader's targets are not scored
    and the accuracy fields are placeholders (0 accuracy / 100 error).

    Returns:
        tuple ``(results, prediction, true_label)``:
            results -- OrderedDict with top1/topk accuracy and error, parameter
                count (millions), image size, ``cropt_pct`` and interpolation.
            prediction -- numpy array of model outputs stacked over all batches
                in loader order (after ``valid_labels`` filtering), or None if
                the loader yielded no batches.
            true_label -- numpy array of loader targets in the same order, or
                None when ``args.has_eval_label`` is false or no batch was seen.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    if args.legacy_jit:
        set_jit_legacy()

    # create model
    if 'inception' in args.model:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            aux_logits=True,  # also build the auxiliary classifier head
            in_chans=3,
            scriptable=args.torchscript)
    else:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            in_chans=3,
            scriptable=args.torchscript)

    # Optionally swap in a deeper classifier head.
    if args.create_classifier_layerfc:
        model.global_pool, model.classifier = create_classifier_layerfc(
            model.num_features, model.num_classes)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        # NOTE(review): optimized_execution() returns a context manager; calling
        # it without `with` has no effect.
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    if args.has_eval_label:
        # Only needed when gold labels exist (e.g. not on an unlabeled test set).
        criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map, args=args)

    if args.valid_labels:
        # The file lists class *indices* (one per line) to keep in the logits.
        with open(args.valid_labels, 'r') as f:
            kept_ids = {int(line.rstrip()) for line in f}
        valid_labels = [i in kept_ids for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        # timm's default ('blank') interpolation resolves to bilinear.
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        auto_augment=args.aa,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        args=args)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    topk = AverageMeter()

    # Per-batch outputs/labels are collected in lists and concatenated once at
    # the end; re-concatenating a growing array every batch copies O(n^2) data.
    prediction_chunks = []
    label_chunks = []

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        images = torch.randn((args.batch_size, ) + data_config['input_size']).cuda()
        model(images)
        end = time.time()
        for batch_idx, (images, target) in enumerate(loader):
            if args.has_eval_label:
                # Keep the ground truth aligned with `prediction` for the caller.
                label_chunks.append(target.cpu().numpy())

            if args.no_prefetcher:
                target = target.cuda()
                images = images.cuda()
            if args.fp16:
                # NOTE(review): the model itself is not converted to half here;
                # presumably this flag is meant to be combined with --amp. Confirm.
                images = images.half()

            output = model(images)
            if isinstance(output, (tuple, list)):
                # e.g. inception with aux_logits returns (logits, aux_logits).
                output = output[0]

            if valid_labels is not None:
                output = output[:, valid_labels]

            # NOTE(review): row order follows the loader; confirm it matches the
            # file order the caller assumes.
            prediction_chunks.append(output.cpu().numpy())

            if real_labels is not None:
                real_labels.add_result(output)

            if args.has_eval_label:
                # measure accuracy and record loss
                loss = criterion(output, target)
                acc1, acck = accuracy(output.detach(), target, topk=(1, args.topk))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1.item(), images.size(0))
                topk.update(acck.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.has_eval_label and batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@topk: {topk.val:>7.3f} ({topk.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=images.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        topk=topk))

    prediction = np.concatenate(prediction_chunks, axis=0) if prediction_chunks else None
    true_label = np.concatenate(label_chunks, axis=0) if label_chunks else None

    if not args.has_eval_label:
        top1a, topka = 0, 0  # placeholders: no ground-truth labels to score against
    elif real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, topka = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=args.topk)
    else:
        top1a, topka = top1.avg, topk.avg

    results = OrderedDict(
        top1=round(top1a, 4),
        top1_err=round(100 - top1a, 4),
        topk=round(topka, 4),
        topk_err=round(100 - topka, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@topk {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['topk'], results['topk_err']))

    return results, prediction, true_label
def validate(args):
    """Post-training-quantize a timm model with FX graph mode and evaluate it on CPU.

    The float model described by ``args`` is built (optionally from a
    checkpoint), quantized according to ``args.quant_option``, and the
    *quantized* model is then evaluated for loss and top-1 / top-5 accuracy on
    ``args.split`` of ``args.data``.

    ``args.quant_option``:
        * ``'static'``  -- ``prepare_fx`` inserts observers, which are calibrated
          on every batch of a separate loader over the same split, then
          ``convert_fx``.
        * ``'dynamic'`` -- dynamic / weight-only quantization; no calibration.
        * anything else -- a warning is logged and ``'static'`` is used.

    Everything runs on CPU, so the CUDA prefetcher is always disabled.
    ``args`` is mutated in place (``pretrained``, ``prefetcher`` and, when
    unset, ``num_classes``).

    Returns:
        OrderedDict with top1/top5 accuracy and error of the quantized model,
        the float model's parameter count in millions, image size,
        ``cropt_pct`` and interpolation.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    # Quantized FX models execute on CPU, but timm's prefetching loader moves
    # every batch to CUDA, so it must stay off here.
    if not args.no_prefetcher:
        _logger.info('Disabling the CUDA prefetcher: quantized evaluation runs on CPU.')
    args.prefetcher = False

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    # Parameter count of the float model, before quantization.
    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    criterion = nn.CrossEntropyLoss()

    def make_dataset():
        return create_dataset(
            root=args.data, name=args.dataset, split=args.split,
            load_bytes=args.tf_preprocessing, class_map=args.class_map)

    dataset = make_dataset()

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            kept_ids = {int(line.rstrip()) for line in f}
        valid_labels = [i in kept_ids for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']

    def make_loader(ds):
        # Evaluation and calibration loaders share the exact same preprocessing.
        return create_loader(
            ds,
            input_size=data_config['input_size'],
            batch_size=args.batch_size,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            crop_pct=crop_pct,
            pin_memory=args.pin_mem,
            tf_preprocessing=args.tf_preprocessing)

    loader = make_loader(dataset)

    quant_option = args.quant_option
    if quant_option not in ('static', 'dynamic'):
        _logger.warning("Invalid quantization option. Set option to default(static)")
        quant_option = 'static'

    # prepare_fx requires an eval-mode model (observers, BN folding).
    model.eval()
    if quant_option == 'static':
        # post training static quantization
        # NOTE(review): this qconfig targets the qnnpack backend, but
        # torch.backends.quantized.engine is not set here; confirm the engine.
        qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
        model_prepared = quantize_fx.prepare_fx(model, qconfig_dict)

        # Calibrate the observers on a separate loader over the same split.
        # This pass only collects activation statistics; it does not touch the
        # evaluation meters or real_labels.
        print('Start calibration of quantization observers before post-quantization')
        calib_loader = make_loader(make_dataset())
        with torch.no_grad():
            for calib_images, _ in calib_loader:
                if args.channels_last:
                    calib_images = calib_images.contiguous(memory_format=torch.channels_last)
                model_prepared(calib_images)
    else:
        # post training dynamic / weight-only quantization: no calibration needed
        qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
        model_prepared = quantize_fx.prepare_fx(model, qconfig_dict)

    # Evaluate the quantized model from here on.
    model = quantize_fx.convert_fx(model_prepared)
    model.eval()

    if args.torchscript:
        # Scripted only after quantization: FX symbolic tracing (used by
        # prepare_fx) cannot trace a TorchScript module.
        model = torch.jit.script(model)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        images = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)
        model(images)
        end = time.time()
        for batch_idx, (images, target) in enumerate(loader):
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            # compute output
            output = model(images)
            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=images.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg

    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
def validate(args):
    """Evaluate a timm model on clean inputs and on FGM / PGD adversarial examples.

    For every batch, adversarial inputs are generated with
    ``fast_gradient_method`` and ``projected_gradient_descent`` (L-inf, radius
    ``args.eps``). Top-1 / top-5 accuracy is recorded for the clean inputs and
    for both attacks.

    ``args`` is mutated in place (``pretrained``, ``prefetcher``, AMP flags and,
    when unset, ``num_classes``).

    Returns:
        OrderedDict with clean top1/top5 accuracy and error, FGM and PGD
        top1/top5 accuracy, parameter count (millions), image size,
        ``cropt_pct`` and interpolation.

    Raises:
        NotImplementedError: if ``args.real_labels`` is set. ReaL-label scoring
            is not implemented for adversarial evaluation. Raised before any
            model or data is built.
    """
    _logger.info(f'\n\n ---------------EVALUATION {args.eps}------------------------------- \n\n')
    _logger.info("Argument parser collected the following arguments:")
    for arg in vars(args):
        _logger.info(f" {arg}:{getattr(args, arg)}")
    _logger.info("\n")

    if args.real_labels:
        # Fail fast instead of after a full (expensive) adversarial pass.
        raise NotImplementedError('ReaL labels are not supported in adversarial evaluation mode.')

    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX or Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info(
        f'Model {args.model} created, param count: {param_count} ({(float(param_count)/(10.0**6)):.1f} M)'
    )

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        # NOTE(review): optimized_execution() returns a context manager; calling
        # it without `with` has no effect.
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data_dir, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            kept_ids = {int(line.rstrip()) for line in f}
        valid_labels = [i in kept_ids for i in range(args.num_classes)]
    else:
        valid_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_fgm_ae = AverageMeter()
    top5_fgm_ae = AverageMeter()
    top1_pgd_ae = AverageMeter()
    top5_pgd_ae = AverageMeter()

    model.eval()
    # The attacks only need gradients w.r.t. the inputs. Freezing the weights
    # avoids computing and accumulating parameter gradients on every attack step.
    # The model is local to this function, so this has no effect on callers.
    for param in model.parameters():
        param.requires_grad_(False)

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        images = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)
        model(images)

    end = time.time()
    for batch_idx, (images, target) in enumerate(loader):
        if args.no_prefetcher:
            target = target.cuda()
            images = images.cuda()
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        # Clean prediction: no autograd graph needed.
        with torch.no_grad():
            with amp_autocast():
                output = model(images)
            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

        # Generate adversarial examples for the current inputs. These calls need
        # autograd, so they run outside no_grad.
        # NOTE(review): the attacks see the full (unfiltered) logits; confirm
        # that is intended when valid_labels is set.
        input_fgm_ae = fast_gradient_method(
            model_fn=model,
            x=images,
            eps=args.eps,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        input_pgd_ae = projected_gradient_descent(
            model_fn=model,
            x=images,
            eps=args.eps,
            eps_iter=0.01,
            nb_iter=40,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )

        # Predict on the adversarial examples, in the same label space as the
        # clean output.
        with torch.no_grad():
            with amp_autocast():
                output_fgm_ae = model(input_fgm_ae)
                output_pgd_ae = model(input_pgd_ae)
            if valid_labels is not None:
                output_fgm_ae = output_fgm_ae[:, valid_labels]
                output_pgd_ae = output_pgd_ae[:, valid_labels]

        # measure accuracy and record loss
        n_imgs = images.size(0)
        acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), n_imgs)
        top1.update(acc1.item(), n_imgs)
        top5.update(acc5.item(), n_imgs)

        acc1_fgm_ae, acc5_fgm_ae = accuracy(output_fgm_ae.detach(), target, topk=(1, 5))
        acc1_pgd_ae, acc5_pgd_ae = accuracy(output_pgd_ae.detach(), target, topk=(1, 5))
        top1_fgm_ae.update(acc1_fgm_ae.item(), n_imgs)
        top5_fgm_ae.update(acc5_fgm_ae.item(), n_imgs)
        top1_pgd_ae.update(acc1_pgd_ae.item(), n_imgs)
        top5_pgd_ae.update(acc5_pgd_ae.item(), n_imgs)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.log_freq == 0:
            _logger.info(
                'Test: [{0:>4d}/{1}] '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time,
                    rate_avg=images.size(0) / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

    top1a, top5a = top1.avg, top5.avg
    top1a_fgm_ae, top5a_fgm_ae = top1_fgm_ae.avg, top5_fgm_ae.avg
    top1a_pgd_ae, top5a_pgd_ae = top1_pgd_ae.avg, top5_pgd_ae.avg

    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        top1_fgm_ae=round(top1a_fgm_ae, 4), top5_fgm_ae=round(top5a_fgm_ae, 4),
        top1_pgd_ae=round(top1a_pgd_ae, 4), top5_pgd_ae=round(top5a_pgd_ae, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * [Regular] Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    _logger.info(' * [FGM Adversarial Attack] Acc@1 {:.3f} Acc@5 {:.3f} '.format(
        results['top1_fgm_ae'], results['top5_fgm_ae']))
    _logger.info(' * [PGD Adversarial Attack] Acc@1 {:.3f} Acc@5 {:.3f} '.format(
        results['top1_pgd_ae'], results['top5_pgd_ae']))

    return results
def _best_f1_threshold(targets, scores, num_steps=1000):
    """Grid-search the decision threshold that maximizes binary F1.

    Thresholds ``i / num_steps`` for ``i = num_steps .. 1`` are tried from high
    to low. A sample is predicted positive when ``score >= threshold``. Because
    only a strictly better F1 replaces the current best, ties keep the higher
    threshold.

    Args:
        targets: numpy array of ground-truth binary labels.
        scores: numpy array of positive-class scores, aligned with ``targets``.
        num_steps: number of thresholds to try (default 1000, step 0.001).

    Returns:
        ``(best_threshold, best_f1)``, or ``(1.0, 0.0)`` when no threshold
        gives a positive F1.
    """
    step = 1.0 / num_steps
    targets = targets.squeeze()  # loop-invariant
    best_th, best_f1 = 1.0, 0.0
    for i in range(num_steps, 0, -1):
        th = i * step
        pred = (scores >= th) * 1.0
        f1 = f1_score(targets, pred.squeeze())
        if f1 > best_f1:
            best_f1, best_th = f1, th
    return best_th, best_f1


def validate(args):
    """Evaluate a classifier, reporting top-1 accuracy and the best F1 threshold.

    After a full pass over the validation fold, every logit is collected and
    ``sigmoid(logit of class 1)`` is used as the positive-class score. A
    threshold grid search then picks the best F1. With ``args.temp_scaling``,
    the search is repeated on logits divided by a fixed temperature.

    ``args`` is mutated in place (``pretrained``, ``prefetcher``, AMP flags).

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``f1`` (in [0, 1]),
        ``f1_err`` (``1 - f1``), parameter count (millions), image size,
        ``cropt_pct`` and interpolation.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         global_pool=args.gp,
                         scriptable=args.torchscript)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        # NOTE(review): optimized_execution() returns a context manager; calling
        # it without `with` has no effect.
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          train_mode='val',
                          fold_num=args.fold_num,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            kept_ids = {int(line.rstrip()) for line in f}
        valid_labels = [i in kept_ids for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    last_idx = len(loader) - 1
    logits_list = []
    target_list = []
    logits = targets = None
    # Reported in the progress log until the search runs on the last batch;
    # also the result when the loader is empty.
    best_th, best_f1 = 1.0, 0.0
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        images = torch.randn((args.batch_size, ) + data_config['input_size']).cuda()
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)
        model(images)
        end = time.time()
        for batch_idx, (images, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                images = images.cuda()
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(images)
            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, _ = accuracy(output.detach(), target, topk=(1, 1))
            logits_list.append(output)
            target_list.append(target)

            if batch_idx == last_idx:
                # All outputs collected: search the F1-optimal threshold once.
                logits = torch.cat(logits_list).cuda()
                targets = torch.cat(target_list).cuda()
                # NOTE(review): the score is sigmoid of the class-1 logit alone,
                # not a softmax probability. Confirm this is intended.
                best_th, best_f1 = _best_f1_threshold(
                    targets.cpu().numpy(),
                    torch.sigmoid(logits)[:, 1].cpu().numpy())

            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'thresh: {thresh:>7.4f} '
                    'f1: {f1:>7.4f}'.format(batch_idx,
                                            len(loader),
                                            batch_time=batch_time,
                                            rate_avg=images.size(0) / batch_time.avg,
                                            loss=losses,
                                            top1=top1,
                                            thresh=best_th,
                                            f1=best_f1))
        print(best_th, best_f1)

    if args.temp_scaling and logits is not None:
        # Fixed temperature (not fitted). `.float()` matches the promotion that
        # dividing by a float32 temperature tensor performed for fp16 logits.
        temperature = 11.0
        scaled_logits = logits.float() / temperature
        best_th, best_f1 = _best_f1_threshold(
            targets.cpu().numpy(),
            torch.sigmoid(scaled_logits)[:, 1].detach().cpu().numpy())
        print(best_th, best_f1)

    if real_labels is not None:
        # real labels mode replaces the top-1 value at the end
        top1a = real_labels.get_accuracy(k=1)
    else:
        top1a = top1.avg
    f1a = best_f1

    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          f1=f1a,
                          # F1 is in [0, 1] (not a percentage), so its error is 1 - F1.
                          f1_err=round(1.0 - f1a, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) f1 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['f1'], results['f1_err']))

    return results