Beispiel #1
0
def validate(args):
    """Evaluate a classification model on a dataset and report accuracy.

    Creates the model named by ``args.model``, optionally loads a checkpoint,
    moves it to the GPU and runs it over the loader built from ``args.data``
    while tracking loss, top-1 and top-5 accuracy.

    Args:
        args: Namespace-like object carrying the options read below (model,
            checkpoint, AMP flags, dataset and loader settings). The fields
            ``pretrained`` and ``prefetcher`` are overwritten, and
            ``apex_amp`` / ``native_amp`` / ``num_classes`` may be set.

    Returns:
        OrderedDict with the keys ``top1``, ``top1_err``, ``top5``,
        ``top5_err``, ``param_count`` (in millions), ``img_size``,
        ``cropt_pct`` and ``interpolation``.
    """
    # Without a checkpoint, fall back to pretrained weights so there is
    # something worth validating.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # Pick the AMP backend requested via --amp; APEX is tried first here.
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not (args.apex_amp and args.native_amp), "Only one AMP mode should be set."
    # `suppress` acts as a do-nothing context manager when native AMP is off.
    amp_autocast = torch.cuda.amp.autocast if args.native_amp else suppress

    if args.legacy_jit:
        set_jit_legacy()

    # Build the network.
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), \
            'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    # Device placement and optional wrappers.
    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)
    if args.num_gpu > 1:
        gpu_ids = list(range(args.num_gpu))
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    dataset = create_dataset(
        root=args.data,
        name=args.dataset,
        split=args.split,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map,
    )

    # Optional boolean mask over class indices: True for every index listed
    # (one per line) in the valid-labels file.
    valid_labels = None
    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            allowed = {int(line.rstrip()) for line in f}
        valid_labels = [idx in allowed for idx in range(args.num_classes)]

    real_labels = None
    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)

    # Test-time pooling evaluates on the uncropped image.
    if test_time_pool:
        crop_pct = 1.0
    else:
        crop_pct = data_config['crop_pct']

    loader_kwargs = dict(
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
    )
    loader = create_loader(dataset, **loader_kwargs)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # One throw-away forward pass on random data so the first timed
        # batch is not skewed (matters when comparing torchscript vs. eager).
        warmup_shape = (args.batch_size, ) + data_config['input_size']
        warmup_batch = torch.randn(warmup_shape).cuda()
        if args.channels_last:
            warmup_batch = warmup_batch.contiguous(
                memory_format=torch.channels_last)
        model(warmup_batch)

        end = time.time()
        for batch_idx, (images, labels) in enumerate(loader):
            if args.no_prefetcher:
                # Without the prefetcher the batch is moved to the GPU here.
                labels = labels.cuda()
                images = images.cuda()
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(images)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, labels)

            if real_labels is not None:
                real_labels.add_result(output)

            # Accumulate loss and accuracy weighted by batch size.
            n_samples = images.size(0)
            acc1, acc5 = accuracy(output.detach(), labels, topk=(1, 5))
            losses.update(loss.item(), n_samples)
            top1.update(acc1.item(), n_samples)
            top5.update(acc5.item(), n_samples)

            # Per-batch wall-clock time.
            now = time.time()
            batch_time.update(now - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=n_samples / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    if real_labels is None:
        top1a = top1.avg
        top5a = top5.avg
    else:
        # In real-labels mode the final top-k values come from the
        # real-labels tracker instead of the running meters.
        top1a = real_labels.get_accuracy(k=1)
        top5a = real_labels.get_accuracy(k=5)

    results = OrderedDict(
        top1=round(top1a, 4),
        top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4),
        top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'],
    )

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
def validate(args):
    """Run a model over a dataset, saving raw outputs and optional metrics.

    Unlike a plain validation pass, this variant also works on data without
    ground-truth labels (``args.has_eval_label`` false): it then only collects
    the model outputs and reports dummy accuracy values of 0.

    Args:
        args: Namespace-like object carrying the options read below (model,
            checkpoint, dataset, loader/augmentation settings, ``topk``,
            ``has_eval_label``, ...). ``pretrained`` and ``prefetcher`` are
            overwritten on it.

    Returns:
        Tuple ``(results, prediction, true_label)`` where

        * ``results`` is an OrderedDict with keys ``top1``, ``top1_err``,
          ``topk``, ``topk_err``, ``param_count`` (in millions), ``img_size``,
          ``cropt_pct`` and ``interpolation``;
        * ``prediction`` is a numpy array of all model outputs stacked along
          axis 0 in loader order, or ``None`` if the loader yielded nothing;
        * ``true_label`` is a numpy array of all targets stacked along axis 0
          when ``args.has_eval_label`` is set, otherwise ``None``.
    """
    # Without a checkpoint, fall back to pretrained weights so there is
    # something worth validating.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    if args.legacy_jit:
        set_jit_legacy()

    # create model
    if 'inception' in args.model:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            aux_logits=True,  # inception models are built with the aux head
            in_chans=3,
            scriptable=args.torchscript)
    else:
        model = create_model(args.model,
                             pretrained=args.pretrained,
                             num_classes=args.num_classes,
                             in_chans=3,
                             scriptable=args.torchscript)

    # Optionally swap in a custom pooling + classifier head before the
    # checkpoint is loaded, so the checkpoint weights match the new head.
    if args.create_classifier_layerfc:
        model.global_pool, model.classifier = create_classifier_layerfc(
            model.num_features, model.num_classes)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    # The loss is only needed (and only used below) when gold labels exist.
    if args.has_eval_label:
        criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map,
                          args=args)

    # Optional boolean mask over class indices: True for every index listed
    # (one integer per line) in the valid-labels file.
    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    # Test-time pooling evaluates on the uncropped image.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']

    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        # 'blank' is default Image.BILINEAR https://github.com/rwightman/pytorch-image-models/blob/470220b1f4c61ad7deb16dbfb8917089e842cd2a/timm/data/transforms.py#L43
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        auto_augment=args.aa,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        args=args)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    topk = AverageMeter()

    # Per-batch arrays are collected in lists and concatenated once at the
    # end; concatenating onto a growing array every batch copies all previous
    # rows each time (quadratic in the number of batches).
    prediction_chunks = []
    true_label_chunks = []

    model.eval()
    with torch.no_grad():
        # One throw-away forward pass on random data so the first timed batch
        # is not skewed (matters when comparing torchscript vs. eager).
        warmup_batch = torch.randn((args.batch_size, ) +
                                   data_config['input_size']).cuda()
        model(warmup_batch)
        end = time.time()
        for batch_idx, (images, target) in enumerate(loader):

            if args.has_eval_label:
                true_label_chunks.append(target.cpu().data.numpy())

            if args.no_prefetcher:
                # Without the prefetcher the batch is moved to the GPU here.
                target = target.cuda()
                images = images.cuda()
                if args.fp16:
                    images = images.half()

            # compute output
            output = model(images)
            if isinstance(output, (tuple, list)):
                # Some models return (main output, aux output); keep the main.
                output = output[0]

            if valid_labels is not None:
                # Keep only the columns of the valid class indices.
                output = output[:, valid_labels]

            # Rows are kept in loader order.
            # NOTE(review): mapping rows back to files assumes the loader
            # does not shuffle — confirm against create_loader.
            prediction_chunks.append(output.cpu().data.numpy())

            if real_labels is not None:
                real_labels.add_result(output)

            if args.has_eval_label:
                # measure accuracy and record loss
                loss = criterion(output, target)
                acc1, acck = accuracy(output.data, target, topk=(1, args.topk))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1.item(), images.size(0))
                topk.update(acck.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.has_eval_label and (batch_idx % args.log_freq == 0):
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@topk: {topk.val:>7.3f} ({topk.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=images.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        topk=topk))

    # Stack the collected batches; stay None when nothing was collected so
    # callers see the same values as before.
    prediction = (np.concatenate(prediction_chunks, axis=0)
                  if prediction_chunks else None)
    true_label = (np.concatenate(true_label_chunks, axis=0)
                  if true_label_chunks else None)

    if not args.has_eval_label:
        # Dummy values: accuracy is unknown without ground-truth labels.
        top1a, topka = 0, 0
    elif real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a = real_labels.get_accuracy(k=1)
        topka = real_labels.get_accuracy(k=args.topk)
    else:
        top1a, topka = top1.avg, topk.avg

    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          topk=round(topka, 4),
                          topk_err=round(100 - topka, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@topk {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['topk'],
        results['topk_err']))

    return results, prediction, true_label
def validate(args):
    """Validate a model on the CPU with FX post-training quantization steps.

    Builds the model, prepares and converts a quantized copy of it
    (``args.quant_option``: ``'static'`` with calibration, or ``'dynamic'``),
    then measures loss / top-1 / top-5 accuracy over the evaluation loader.
    The model is never moved to the GPU in this variant and no AMP or
    DataParallel wrapping is applied.

    Args:
        args: Namespace-like object carrying the options read below (model,
            checkpoint, dataset and loader settings, ``quant_option``, ...).
            ``pretrained`` and ``prefetcher`` are overwritten, and
            ``num_classes`` is filled in from the model when it is ``None``.

    Returns:
        OrderedDict with the keys ``top1``, ``top1_err``, ``top5``,
        ``top5_err``, ``param_count`` (in millions), ``img_size``,
        ``cropt_pct`` and ``interpolation``.
    """
    # Without a checkpoint, fall back to pretrained weights so there is
    # something worth validating.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    criterion = nn.CrossEntropyLoss()

    dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # Separate dataset instance used to calibrate the quantization observers.
    # It is built from the same root/split as the evaluation dataset.
    calib_dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # Optional boolean mask over class indices: True for every index listed
    # (one integer per line) in the valid-labels file.
    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    # Test-time pooling evaluates on the uncropped image.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']

    # Both loaders share identical settings.
    # NOTE(review): the model stays on the CPU here; if the prefetcher moves
    # batches to the GPU it has to be disabled for this path — confirm
    # against create_loader.
    loader_kwargs = dict(
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)
    loader = create_loader(dataset, **loader_kwargs)
    calib_loader = create_loader(calib_dataset, **loader_kwargs)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    print('Start calibration of quantization observers before post-quantization')
    # The float model must be in eval mode before it is copied for
    # quantization/fusion and before it is evaluated.
    model.eval()

    quant_option = args.quant_option
    if quant_option not in ('static', 'dynamic'):
        # Honour what the message announces: fall back to static quantization.
        _logger.warning("Invalid quantization option. Set option to default(static)")
        quant_option = 'static'

    # Quantization always works on a copy so the float model stays untouched.
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()

    if quant_option == 'static':
        # post training static quantization
        qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
        # prepare: insert observers
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # calibrate: feed calibration batches through the observed model.
        # No loss/accuracy is recorded here so the evaluation meters below
        # only reflect the evaluation pass.
        with torch.no_grad():
            for calib_images, _ in calib_loader:
                if args.channels_last:
                    calib_images = calib_images.contiguous(memory_format=torch.channels_last)
                model_prepared(calib_images)
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)
    else:
        # post training dynamic/weight only quantization; no calibration
        # pass is needed for this mode.
        qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        model_quantized = quantize_fx.convert_fx(model_prepared)

    # fusion
    model_fused = quantize_fx.fuse_fx(copy.deepcopy(model))

    # NOTE(review): as in the original code, the metrics below are measured
    # on the fused float model; `model_quantized` is built above but not
    # evaluated. If the quantized accuracy is what should be reported, assign
    # `model = model_quantized` here instead — confirm the intent.
    model = model_fused

    with torch.no_grad():
        # One throw-away forward pass on random data so the first timed batch
        # is not skewed (matters when comparing torchscript vs. eager).
        warmup_batch = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
        if args.channels_last:
            warmup_batch = warmup_batch.contiguous(memory_format=torch.channels_last)
        model(warmup_batch)
        end = time.time()
        for batch_idx, (images, target) in enumerate(loader):
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            # compute output
            output = model(images)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=images.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
Beispiel #4
0
def validate(args):
    """Evaluate a model on clean inputs and on FGM / PGD adversarial inputs.

    For every batch the clean accuracy is measured, adversarial examples are
    generated with ``fast_gradient_method`` and ``projected_gradient_descent``
    (both with L-inf norm and budget ``args.eps``), and the accuracy on those
    examples is measured as well.

    Args:
        args: Namespace-like object carrying the options read below (model,
            checkpoint, AMP flags, dataset and loader settings, ``eps``, ...).
            ``pretrained`` and ``prefetcher`` are overwritten, ``apex_amp`` /
            ``native_amp`` may be set, and ``num_classes`` is filled in from
            the model when it is ``None``.

    Returns:
        OrderedDict with the keys ``top1``, ``top1_err``, ``top5``,
        ``top5_err``, ``top1_fgm_ae``, ``top5_fgm_ae``, ``top1_pgd_ae``,
        ``top5_pgd_ae``, ``param_count`` (in millions), ``img_size``,
        ``cropt_pct`` and ``interpolation``.

    Raises:
        NotImplementedError: If ``args.real_labels`` is set; real-labels mode
            is not supported together with adversarial evaluation.
    """
    _logger.info(f'\n\n ---------------EVALUATION {args.eps}------------------------------- \n\n')
    _logger.info("Argument parser collected the following arguments:")
    for arg in vars(args):
        _logger.info(f"    {arg}:{getattr(args, arg)}")
    _logger.info("\n")

    # Real-labels mode has not been adapted to the adversarial metrics. Fail
    # before the model is built instead of after the whole (expensive)
    # evaluation has already run.
    if args.real_labels:
        raise NotImplementedError

    # Without a checkpoint, fall back to pretrained weights so there is
    # something worth validating.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do-nothing context manager unless native AMP is on
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX or Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info(
        f'Model {args.model} created, param count: {param_count} ({(float(param_count)/(10.0**6)):.1f} M)'
    )

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data_dir, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # Optional boolean mask over class indices: True for every index listed
    # (one integer per line) in the valid-labels file.
    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    # Test-time pooling evaluates on the uncropped image.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_fgm_ae = AverageMeter()
    top5_fgm_ae = AverageMeter()
    top1_pgd_ae = AverageMeter()
    top5_pgd_ae = AverageMeter()

    model.eval()

    # The loop as a whole cannot run under no_grad because the attacks need
    # input gradients; the plain forward passes are wrapped individually so
    # they do not build autograd graphs.
    with torch.no_grad():
        # One throw-away forward pass on random data so the first timed batch
        # is not skewed (matters when comparing torchscript vs. eager).
        warmup_batch = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        if args.channels_last:
            warmup_batch = warmup_batch.contiguous(memory_format=torch.channels_last)
        model(warmup_batch)

    end = time.time()
    for batch_idx, (images, target) in enumerate(loader):
        if args.no_prefetcher:
            # Without the prefetcher the batch is moved to the GPU here.
            target = target.cuda()
            images = images.cuda()
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        # Clean forward pass and loss; no gradients are needed for these.
        with torch.no_grad():
            with amp_autocast():
                output = model(images)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

        # Generate adversarial examples for the current inputs (gradients
        # enabled here).
        input_fgm_ae = fast_gradient_method(
            model_fn=model,
            x=images,
            eps=args.eps,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        input_pgd_ae = projected_gradient_descent(
            model_fn=model,
            x=images,
            eps=args.eps,
            eps_iter=0.01,
            nb_iter=40,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        # Predict with the adversarial examples.
        # NOTE(review): unlike the clean output, these outputs are not
        # filtered by `valid_labels` (same as the original) — confirm.
        with torch.no_grad():
            with amp_autocast():
                output_fgm_ae = model(input_fgm_ae)
                output_pgd_ae = model(input_pgd_ae)

        # measure accuracy and record loss
        batch_size = images.size(0)
        acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(acc1.item(), batch_size)
        top5.update(acc5.item(), batch_size)

        acc1_fgm_ae, acc5_fgm_ae = accuracy(output_fgm_ae.detach(), target, topk=(1, 5))
        acc1_pgd_ae, acc5_pgd_ae = accuracy(output_pgd_ae.detach(), target, topk=(1, 5))
        top1_fgm_ae.update(acc1_fgm_ae.item(), batch_size)
        top5_fgm_ae.update(acc5_fgm_ae.item(), batch_size)
        top1_pgd_ae.update(acc1_pgd_ae.item(), batch_size)
        top5_pgd_ae.update(acc5_pgd_ae.item(), batch_size)

        # measure elapsed time (includes adversarial example generation)
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.log_freq == 0:
            _logger.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time,
                    rate_avg=batch_size / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

    top1a, top5a = top1.avg, top5.avg
    top1a_fgm_ae, top5a_fgm_ae = top1_fgm_ae.avg, top5_fgm_ae.avg
    top1a_pgd_ae, top5a_pgd_ae = top1_pgd_ae.avg, top5_pgd_ae.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        top1_fgm_ae=round(top1a_fgm_ae, 4),
        top5_fgm_ae=round(top5a_fgm_ae, 4),
        top1_pgd_ae=round(top1a_pgd_ae, 4),
        top5_pgd_ae=round(top5a_pgd_ae, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * [Regular] Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    _logger.info(' * [FGM Adversarial Attack] Acc@1 {:.3f}  Acc@5 {:.3f} '.format(
       results['top1_fgm_ae'], results['top5_fgm_ae']))
    _logger.info(' * [PGD Adversarial Attack] Acc@1 {:.3f}  Acc@5 {:.3f} '.format(
       results['top1_pgd_ae'], results['top5_pgd_ae']))

    return results
Beispiel #5
0
def _best_f1_threshold(targets, scores):
    """Sweep decision thresholds and return the one with the highest F1.

    Thresholds are tried from 1.000 down to 0.001 in steps of 0.001. A score
    is counted as a positive prediction when ``score >= threshold``. Only a
    strictly better F1 replaces the current best, so ties keep the highest
    threshold.

    Args:
        targets: numpy array of ground-truth labels.
        scores: numpy array of per-sample scores for the positive class.

    Returns:
        Tuple ``(best_threshold, best_f1)``; ``(1.0, 0.0)`` when no threshold
        yields an F1 above zero.
    """
    # reshape(-1) instead of squeeze(): squeeze() would collapse a
    # single-sample array to 0-d, which is not a valid label vector.
    targets = targets.reshape(-1)
    scores = scores.reshape(-1)

    best_th, best_f1 = 1.0, 0.0
    for step in range(1000, 0, -1):
        th = step * 0.001
        f1 = f1_score(targets, (scores >= th) * 1.0)
        if f1 > best_f1:
            best_th, best_f1 = th, f1
    return best_th, best_f1


def validate(args):
    """Evaluate a model on a validation set and report top-1 accuracy and F1.

    Builds the model, dataset and loader from ``args``, runs one pass over the
    loader, and after the last batch searches for the decision threshold on
    the sigmoid of logit column 1 that maximises F1. When
    ``args.temp_scaling`` is set, the collected logits are divided by a fixed
    temperature of 11 and the threshold search is repeated; the second result
    is the one reported.

    Args:
        args: parsed command-line namespace. Attributes read here include
            ``model``, ``pretrained``, ``checkpoint``, ``use_ema``,
            ``num_classes``, ``gp``, ``torchscript``, ``legacy_jit``, ``amp``,
            ``apex_amp``, ``native_amp``, ``channels_last``, ``num_gpu``,
            ``data``, ``fold_num``, ``class_map``, ``tf_preprocessing``,
            ``valid_labels``, ``real_labels``, ``batch_size``, ``workers``,
            ``pin_mem``, ``no_prefetcher``, ``no_test_pool``, ``log_freq`` and
            ``temp_scaling``. ``pretrained``, ``prefetcher``, the AMP flags
            and (when unset) ``num_classes`` are written back onto ``args``.

    Returns:
        OrderedDict with keys ``top1``, ``top1_err``, ``f1``, ``f1_err``,
        ``param_count`` (in millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         global_pool=args.gp,
                         scriptable=args.torchscript)
    # The valid-labels mask below needs a concrete class count; fall back to
    # the model's own when none was given on the command line.
    if args.num_classes is None and hasattr(model, 'num_classes'):
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          train_mode='val',
                          fold_num=args.fold_num,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map)

    if args.valid_labels:
        # File holds one integer class index per line; turn it into a boolean
        # mask over all classes for column selection on the model output.
        with open(args.valid_labels, 'r') as f:
            valid_ids = {int(line.rstrip()) for line in f}
        valid_labels = [i in valid_ids for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    last_idx = len(loader) - 1
    cuda = torch.device('cuda')
    sigmoid = nn.Sigmoid()

    # Bound up front so an empty loader still reaches the reporting code
    # instead of failing on an unbound name.
    best_f1 = 0.0
    best_th = 1.0
    logits = None
    targets = None

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        warmup = torch.randn((args.batch_size, ) +
                             data_config['input_size']).cuda()
        if args.channels_last:
            warmup = warmup.contiguous(memory_format=torch.channels_last)
        model(warmup)
        end = time.time()

        logits_list = []
        target_list = []

        for batch_idx, (images, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if args.no_prefetcher:
                target = target.cuda()
                images = images.cuda()
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(images)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, _ = accuracy(output.detach(), target, topk=(1, 1))

            logits_list.append(output)
            target_list.append(target)

            if last_batch:
                # All batches seen: run the threshold search on the full set.
                logits = torch.cat(logits_list).cuda()
                targets = torch.cat(target_list).cuda()
                scores = sigmoid(logits)[:, 1].cpu().numpy()
                best_th, best_f1 = _best_f1_threshold(
                    targets.cpu().numpy(), scores)

            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                # thresh/f1 stay at their initial 1.0/0.0 until the last batch.
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'thresh: {thresh:>7.4f}  '
                    'f1: {f1:>7.4f}'.format(batch_idx,
                                            len(loader),
                                            batch_time=batch_time,
                                            rate_avg=images.size(0) /
                                            batch_time.avg,
                                            loss=losses,
                                            top1=top1,
                                            thresh=best_th,
                                            f1=best_f1))

    print(best_th, best_f1)

    # Temperature scaling with a fixed temperature (no optimisation step):
    # rescale the collected logits and redo the threshold search.
    if args.temp_scaling and logits is not None:
        # Kept as a float32 tensor (not a Python scalar) so the division
        # promotes exactly as it did with the previous tensor-based code.
        temperature = torch.ones(1, device=cuda) * 11
        logits = logits / temperature.unsqueeze(1).expand(
            logits.size(0), logits.size(1))
        scores = sigmoid(logits)[:, 1].detach().cpu().numpy()
        best_th, best_f1 = _best_f1_threshold(targets.cpu().numpy(), scores)

        print(best_th, best_f1)

    if real_labels is not None:
        # real labels mode replaces the top-1 value at the end
        top1a = real_labels.get_accuracy(k=1)
    else:
        top1a = top1.avg
    # F1 is reported in both modes; it was previously left unbound in real
    # labels mode, which crashed when building the results.
    f1a = best_f1

    # NOTE(review): f1 presumably lies in [0, 1] (f1_score looks like
    # scikit-learn's), so `100 - f1` is not on the same scale as top1_err.
    # Left as-is because callers may read this key - confirm before changing.
    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          f1=f1a,
                          f1_err=round(100 - f1a, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) f1 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['f1'],
        results['f1_err']))

    return results