コード例 #1
0
ファイル: build.py プロジェクト: microsoft/AutoML
def build_dataset(is_train, config):
    """Build the dataset for one split and report its number of classes.

    Args:
        is_train: selects the ``'train'`` split when truthy, ``'val'`` otherwise.
        config: configuration object; the fields read here are
            ``config.DATA.DATASET``, ``LOAD_TAR``, ``ZIP_MODE``, ``DATA_PATH``
            and ``CACHE_MODE``.

    Returns:
        A ``(dataset, nb_classes)`` tuple; ``nb_classes`` is always 1000.

    Raises:
        NotImplementedError: if ``config.DATA.DATASET`` is not ``'imagenet'``.
    """
    transform = build_transform(is_train, config)

    data_cfg = config.DATA
    if data_cfg.DATASET != 'imagenet':
        raise NotImplementedError("We only support ImageNet Now.")

    split = 'train' if is_train else 'val'

    if data_cfg.LOAD_TAR:
        # Whole split stored as a single "<split>.tar" archive.
        tar_path = os.path.join(data_cfg.DATA_PATH, f'{split}.tar')
        dataset = DatasetTar(tar_path, transform=transform)
    elif data_cfg.ZIP_MODE:
        # Zipped split: "<split>_map.txt" annotation file + "<split>.zip@/" prefix.
        # The configured cache mode only applies to training; validation
        # always uses 'part'.
        if is_train:
            cache_mode = data_cfg.CACHE_MODE
        else:
            cache_mode = 'part'
        dataset = CachedImageFolder(
            data_cfg.DATA_PATH,
            split + "_map.txt",
            split + ".zip@/",
            transform,
            cache_mode=cache_mode)
    else:
        # Plain directory layout under "<DATA_PATH>/<split>".
        split_dir = os.path.join(data_cfg.DATA_PATH, split)
        dataset = datasets.ImageFolder(split_dir, transform=transform)

    return dataset, 1000
コード例 #2
0
def validate(args):
    """Evaluate a pretrained model on ``args.data`` and report accuracy.

    Reads ``args.model``, ``args.no_jit``, ``args.data`` and
    ``args.batch_size``; preprocessing settings come from
    ``resolve_data_config``.

    Returns:
        dict with ``top1`` and ``top5`` accuracy (in percent) as Python floats.
    """
    model, variables = create_model(
        args.model, pretrained=True, rng=jax.random.PRNGKey(0))
    print(f'Created {args.model} model. Validating...')

    def forward(images, labels):
        return eval_forward(model, variables, images, labels)

    # JIT-compile the evaluation step unless explicitly disabled.
    eval_step = forward if args.no_jit else jax.jit(forward)

    # A path ending in ".tar" that is an existing file is read as an archive;
    # anything else goes through the regular Dataset class.
    is_tar_file = (os.path.splitext(args.data)[1] == '.tar'
                   and os.path.isfile(args.data))
    dataset = DatasetTar(args.data) if is_tar_file else Dataset(args.data)

    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=8,
        crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1 = correct_top5 = 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, batch in enumerate(loader):
        images, labels = batch
        # Reorder axes to (0, 2, 3, 1) -- presumably NCHW -> NHWC for the
        # JAX model; TODO confirm against the loader's output layout.
        images = images.numpy().transpose(0, 2, 3, 1)
        labels = labels.numpy()

        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += top1_count
        correct_top5 += top5_count
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        # Progress line every 20 batches, skipping the very first one.
        if batch_index > 0 and batch_index % 20 == 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}]  '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(
        f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return {'top1': float(acc_1), 'top5': float(acc_5)}
コード例 #3
0
def validate(args):
    """Validate a model on ``args.data`` and return summary metrics.

    The function mutates ``args`` (sets ``pretrained`` and ``prefetcher``),
    builds the model and data loader, runs one pass without gradients and
    logs running statistics every ``args.log_freq`` batches.

    Returns:
        OrderedDict with keys ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (in millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # Without a checkpoint there is nothing else to evaluate, so fall back
    # to pretrained weights.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # Build the network and optionally restore weights.
    model = create_model(args.model,
                         num_classes=args.num_classes,
                         in_chans=3,
                         pretrained=args.pretrained)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    # Move to GPU first; AMP initialisation (if requested) wraps the GPU model.
    model = model.cuda()
    if args.amp:
        model = amp.initialize(model, opt_level='O1')

    if args.num_gpu > 1:
        gpu_ids = list(range(args.num_gpu))
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)

    criterion = nn.CrossEntropyLoss().cuda()

    # An existing "*.tar" file is read as an archive, anything else through
    # the regular Dataset class.
    is_tar_file = (os.path.splitext(args.data)[1] == '.tar'
                   and os.path.isfile(args.data))
    dataset_cls = DatasetTar if is_tar_file else Dataset
    dataset = dataset_cls(args.data,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map)

    # Test-time pooling evaluates on the uncropped image.
    if test_time_pool:
        crop_pct = 1.0
    else:
        crop_pct = data_config['crop_pct']

    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    tick = time.time()
    with torch.no_grad():
        for step, (images, labels) in enumerate(loader):
            # Manual device transfer is only done when the prefetcher is off.
            if args.no_prefetcher:
                labels = labels.cuda()
                images = images.cuda()
                if args.fp16:
                    images = images.half()

            logits = model(images)
            loss = criterion(logits, labels)

            # NOTE(review): the second metric is computed with k=2 although
            # it is stored and logged under the "top5"/"Acc@5" name --
            # confirm this is intentional.
            acc1, acc5 = accuracy(logits.data, labels, topk=(1, 2))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(acc1.item(), n)
            top5.update(acc5.item(), n)

            # Per-batch wall-clock time.
            now = time.time()
            batch_time.update(now - tick)
            tick = time.time()

            if step % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        step, len(loader), batch_time=batch_time,
                        rate_avg=images.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict([
        ('top1', round(top1.avg, 4)),
        ('top1_err', round(100 - top1.avg, 4)),
        ('top5', round(top5.avg, 4)),
        ('top5_err', round(100 - top5.avg, 4)),
        ('param_count', round(param_count / 1e6, 2)),
        ('img_size', data_config['input_size'][-1]),
        ('cropt_pct', crop_pct),
        ('interpolation', data_config['interpolation']),
    ])

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
コード例 #4
0
def validate(args):
    """Run inference over ``args.data``, collecting metrics and raw outputs.

    The function mutates ``args`` (sets ``pretrained`` and ``prefetcher``),
    builds the model and loader, and runs one pass without gradients. Loss
    and accuracy are only computed when ``args.has_eval_label`` is set; the
    model outputs are always collected.

    Returns:
        A tuple ``(results, prediction, true_label)``:

        * ``results`` -- OrderedDict with ``top1``, ``top1_err``, ``topk``,
          ``topk_err``, ``param_count`` (in millions), ``img_size``,
          ``cropt_pct`` and ``interpolation``. The accuracy entries are 0
          (and the errors 100) when ``args.has_eval_label`` is false.
        * ``prediction`` -- numpy array with the per-batch model outputs
          stacked along axis 0, or ``None`` if the loader yielded no batch.
        * ``true_label`` -- numpy array with the stacked targets when
          ``args.has_eval_label`` is set, otherwise ``None``.
    """
    # Without a checkpoint there is nothing else to evaluate, so fall back
    # to pretrained weights.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model_kwargs = dict(pretrained=args.pretrained,
                        num_classes=args.num_classes,
                        in_chans=3,
                        scriptable=args.torchscript)
    if 'inception' in args.model:
        # Models with "inception" in their name are built with aux logits.
        model_kwargs['aux_logits'] = True
    model = create_model(args.model, **model_kwargs)

    # Optionally replace the pooling/classifier pair with a custom head.
    if args.create_classifier_layerfc:
        model.global_pool, model.classifier = create_classifier_layerfc(
            model.num_features, model.num_classes)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    # The loss is only needed (and only used) when gold labels exist.
    if args.has_eval_label:
        criterion = nn.CrossEntropyLoss().cuda()

    # An existing "*.tar" file is read as an archive, anything else through
    # the regular Dataset class.
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map,
                          args=args)

    if args.valid_labels:
        # The file lists one class index per line; turn it into a boolean
        # mask over all classes for column selection on the outputs.
        with open(args.valid_labels, 'r') as f:
            valid_indices = {int(line.rstrip()) for line in f}
        valid_labels = [i in valid_indices for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    # Test-time pooling evaluates on the uncropped image.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']

    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        # 'blank' is default Image.BILINEAR, see
        # https://github.com/rwightman/pytorch-image-models/blob/470220b1f4c61ad7deb16dbfb8917089e842cd2a/timm/data/transforms.py#L43
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        auto_augment=args.aa,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        args=args)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    topk = AverageMeter()

    # Per-batch arrays are collected in lists and joined once after the
    # loop; concatenating onto a growing array every batch would copy all
    # previous results each time (quadratic in the number of batches).
    prediction_chunks = []
    true_label_chunks = []

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) +
                            data_config['input_size']).cuda()
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):

            if args.has_eval_label:
                true_label_chunks.append(target.cpu().data.numpy())

            # Manual device transfer is only done when the prefetcher is off.
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            if isinstance(output, (tuple, list)):
                # Some models return several outputs; only the first is used.
                output = output[0]

            if valid_labels is not None:
                # Keep only the columns of the valid classes.
                output = output[:, valid_labels]

            # Rows are stored in loader order (batch x label).
            prediction_chunks.append(output.cpu().data.numpy())

            if real_labels is not None:
                real_labels.add_result(output)

            if args.has_eval_label:
                # measure accuracy and record loss
                loss = criterion(output, target)
                acc1, acck = accuracy(output.data, target, topk=(1, args.topk))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                topk.update(acck.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.has_eval_label and (batch_idx % args.log_freq == 0):
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@topk: {topk.val:>7.3f} ({topk.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        topk=topk))

    # Join the collected batches; None signals "nothing collected".
    prediction = (np.concatenate(prediction_chunks, axis=0)
                  if prediction_chunks else None)
    true_label = (np.concatenate(true_label_chunks, axis=0)
                  if true_label_chunks else None)

    if not args.has_eval_label:
        top1a, topka = 0, 0  # just dummy, because we don't know ground labels
    elif real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a = real_labels.get_accuracy(k=1)
        topka = real_labels.get_accuracy(k=args.topk)
    else:
        top1a, topka = top1.avg, topk.avg

    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          topk=round(topka, 4),
                          topk_err=round(100 - topka, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@topk {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['topk'],
        results['topk_err']))

    return results, prediction, true_label
コード例 #5
0
def validate(args):
    """Validate a model on ``args.data`` and return summary metrics.

    With ``args.neuron`` set, the model is loaded via ``torch.jit.load``
    from ``args.checkpoint`` and neither the model, the criterion nor the
    batches are moved to CUDA here. Otherwise the model is created with
    ``create_model`` and runs on CUDA, optionally with APEX or native AMP.

    The function mutates ``args`` (``pretrained``, ``prefetcher`` and, when
    ``args.amp`` is set, ``apex_amp`` / ``native_amp``).

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (in millions; 0 in neuron mode, where it is not
        computed), ``img_size``, ``cropt_pct`` and ``interpolation``.

    Raises:
        AssertionError: if both ``args.apex_amp`` and ``args.native_amp``
            are set.
    """
    # Without a checkpoint there is nothing else to evaluate, so fall back
    # to pretrained weights.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        # Prefer APEX when available, then native AMP, else stay in FP32.
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    if args.neuron:
        model = torch.jit.load(args.checkpoint)
        # The parameter count is not computed for the pre-compiled module.
        # It must still be defined because the results dict reads it
        # unconditionally (previously a NameError in neuron mode).
        param_count = 0
    else:
        model = create_model(args.model,
                             pretrained=args.pretrained,
                             num_classes=args.num_classes,
                             in_chans=3,
                             global_pool=args.gp,
                             scriptable=args.torchscript)

        if args.checkpoint:
            load_checkpoint(model, args.checkpoint, args.use_ema)

        param_count = sum([m.numel() for m in model.parameters()])
        _logger.info('Model %s created, param count: %d' %
                     (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if not args.neuron:
        model = model.cuda()
        if args.apex_amp:
            model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1 and not args.neuron:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss()
    if not args.neuron:
        criterion = criterion.cuda()

    # An existing "*.tar" file is read as an archive, anything else through
    # the regular Dataset class.
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map)

    if args.valid_labels:
        # The file lists one class index per line; turn it into a boolean
        # mask over all classes for column selection on the outputs.
        with open(args.valid_labels, 'r') as f:
            valid_indices = {int(line.rstrip()) for line in f}
        valid_labels = [i in valid_indices for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    # Test-time pooling evaluates on the uncropped image.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) + data_config['input_size'])
        if not args.neuron:
            input = input.cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            # Manual device transfer only without the prefetcher, and never
            # in neuron mode.
            if args.no_prefetcher and not args.neuron:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            if not args.neuron:
                with amp_autocast():
                    output = model(input)
            else:
                output = model(input)

            if valid_labels is not None:
                # Keep only the columns of the valid classes.
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a = real_labels.get_accuracy(k=1)
        top5a = real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          top5=round(top5a, 4),
                          top5_err=round(100 - top5a, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
コード例 #6
0
# Pick the directory holding the validation tarball based on is_server().
if is_server():
    DATA_ROOT = './.data/vision/imagenet'
else:
    # local settings
    DATA_ROOT = './'
DATA_FILENAME = 'ILSVRC2012_img_val.tar'
# Full path of the tar archive that every model below is evaluated on.
TAR_PATH = os.path.join(DATA_ROOT, DATA_FILENAME)

# Each entry `m` of model_list is read as a mapping with the keys 'model',
# 'paper_model_name', 'args', 'ttp' and 'batch_size'.
for m in model_list:
    model_name = m['model']
    # create model from name
    model = create_model(model_name, pretrained=True)
    # NOTE: the comprehension below reuses the name `m`. On Python 3 the
    # comprehension variable is local to the comprehension, so `m` on the
    # following lines is still the model_list entry.
    param_count = sum([m.numel() for m in model.parameters()])
    print('Model %s, %s created. Param count: %d' % (model_name, m['paper_model_name'], param_count))

    # The dataset is re-created from the same tarball for every model.
    dataset = DatasetTar(TAR_PATH)
    # File names with their extension stripped.
    filenames = [os.path.splitext(f)[0] for f in dataset.filenames()]

    # get appropriate transform for model's default pretrained config
    data_config = resolve_data_config(m['args'], model=model, verbose=True)
    test_time_pool = False
    if m['ttp']:
        # With test-time pooling requested, wrap the model and force
        # crop_pct to 1.0.
        model, test_time_pool = apply_test_time_pool(model, data_config)
        data_config['crop_pct'] = 1.0

    batch_size = m['batch_size']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=batch_size,
        use_prefetcher=True,
コード例 #7
0
def validate(args):
    """Validate a model and search for the F1-optimal decision threshold.

    Runs one pass over ``args.data`` on CUDA, tracking loss and top-1
    accuracy. After the last batch, all logits are joined, the sigmoid of
    the class-1 logit is taken as score, and thresholds 1.000 .. 0.001 are
    swept to find the one with the best F1. With ``args.temp_scaling`` the
    sweep is repeated on logits divided by a fixed temperature of 11 and
    that result replaces the first one.

    The function mutates ``args`` (``pretrained``, ``prefetcher`` and, when
    ``args.amp`` is set, ``apex_amp`` / ``native_amp``).

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``f1``, ``f1_err``,
        ``param_count`` (in millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.

    Raises:
        AssertionError: if both ``args.apex_amp`` and ``args.native_amp``
            are set.
    """
    # Without a checkpoint there is nothing else to evaluate, so fall back
    # to pretrained weights.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        # Prefer APEX when available, then native AMP, else stay in FP32.
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         global_pool=args.gp,
                         scriptable=args.torchscript)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    # An existing "*.tar" file is read as an archive; anything else goes
    # through Dataset in 'val' mode for the requested fold.
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          train_mode='val',
                          fold_num=args.fold_num,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map)

    if args.valid_labels:
        # The file lists one class index per line; turn it into a boolean
        # mask over all classes for column selection on the outputs.
        with open(args.valid_labels, 'r') as f:
            valid_indices = {int(line.rstrip()) for line in f}
        valid_labels = [i in valid_indices for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    # Test-time pooling evaluates on the uncropped image.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    last_idx = len(loader) - 1
    cuda = torch.device('cuda')
    sigmoid = nn.Sigmoid()

    # Defined up front so the code after the loop is well-defined even if
    # the loader yields no batch. Until the last batch has been processed
    # these hold the "nothing found" defaults, which is also what the
    # per-batch log lines show.
    best_f1 = 0.0
    best_th = 1.0
    logits = None
    targets = None

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) +
                            data_config['input_size']).cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()

        logits_list = []
        target_list = []

        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            # Manual device transfer is only done when the prefetcher is off.
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(input)

            if valid_labels is not None:
                # Keep only the columns of the valid classes.
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, _ = accuracy(output.detach(), target, topk=(1, 1))

            logits_list.append(output)
            target_list.append(target)

            if last_batch:
                # All outputs are available now: join them and run the
                # threshold sweep on the sigmoid of the class-1 logit.
                logits = torch.cat(logits_list).cuda()
                targets = torch.cat(target_list).cuda()

                targets_cpu = targets.cpu().numpy()
                scores = sigmoid(logits)[:, 1].cpu().numpy()
                best_th, best_f1 = _best_f1_threshold(targets_cpu, scores)

            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'thresh: {thresh:>7.4f}  '
                    'f1: {f1:>7.4f}'.format(batch_idx,
                                            len(loader),
                                            batch_time=batch_time,
                                            rate_avg=input.size(0) /
                                            batch_time.avg,
                                            loss=losses,
                                            top1=top1,
                                            thresh=best_th,
                                            f1=best_f1))

    print(best_th, best_f1)

    # Temperature scaling: a fixed temperature of 11 is applied. (A variant
    # that learned the temperature with LBFGS and reported NLL/ECE was
    # already disabled in the original code and is not reproduced here.)
    if args.temp_scaling and logits is not None:
        temperature = torch.full((1, 1), 11.0, device=cuda)
        logits = logits / temperature  # broadcasts over all rows and columns

        targets_cpu = targets.cpu().numpy()
        scores = sigmoid(logits)[:, 1].detach().cpu().numpy()
        best_th, best_f1 = _best_f1_threshold(targets_cpu, scores)

        print(best_th, best_f1)

    if real_labels is not None:
        # real labels mode replaces the top-1 value at the end
        top1a = real_labels.get_accuracy(k=1)
    else:
        top1a = top1.avg
    # The F1 value is reported in both modes; previously it was left
    # undefined in real-labels mode, which crashed when building results.
    f1a = best_f1

    # NOTE(review): f1_err is computed as 100 - f1 as in the original; if
    # f1_score returns a value in [0, 1] the scale looks inconsistent --
    # confirm before relying on it.
    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          f1=f1a,
                          f1_err=round(100 - f1a, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) f1 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['f1'],
        results['f1_err']))

    return results


def _best_f1_threshold(targets, scores):
    """Return ``(threshold, f1)`` for the threshold with the highest F1.

    Thresholds 1.000, 0.999, ..., 0.001 are tried in that order; a score
    counts as positive when ``score >= threshold``. Only a strictly better
    F1 replaces the current best, so ties keep the higher threshold. If no
    threshold yields an F1 above 0, ``(1.0, 0.0)`` is returned.
    """
    best_f1 = 0.0
    best_th = 1.0
    flat_targets = targets.squeeze()
    for step in range(1000, 0, -1):
        th = step * 0.001
        predicted = (scores >= th) * 1.0
        f1 = f1_score(flat_targets, predicted.squeeze())
        if f1 > best_f1:
            best_f1 = f1
            best_th = th
    return best_th, best_f1
コード例 #8
0
def validate(args):
    """Evaluate a classifier on ``args.data`` and log precision statistics.

    Builds the model named by ``args.model`` (optionally loading
    ``args.checkpoint``), runs it over the whole dataset with gradients
    disabled and accumulates the loss, top-1/top-5 precision and a 40x40
    confusion matrix.

    Side effects: ``results.csv`` and ``results-parent.csv`` are (re)created in
    the current working directory. The running top-1 average is appended to
    ``results.csv`` after every batch; the running parent-class precision is
    appended to ``results-parent.csv`` only when ``args.hier_classify`` is set.

    Args:
        args: Parsed command-line namespace. ``args.pretrained`` and
            ``args.prefetcher`` are overwritten in place.

    Returns:
        OrderedDict with the keys ``top1``, ``top1_err``, ``top5``,
        ``top5_err``, ``param_count`` (in millions), ``img_size``,
        ``cropt_pct`` and ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    model = create_model(args.model,
                         num_classes=args.num_classes,
                         in_chans=3,
                         pretrained=args.pretrained)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(
                                          args.num_gpu))).cuda()
    else:
        model = model.cuda()

    if args.fp16:
        model = model.half()

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           fp16=args.fp16,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # The confusion matrix is hard-wired to 40 classes.
    # NOTE(review): assumes cal_confusions returns a 40x40 count matrix for
    # the given labels -- confirm against its definition.
    c_matrix = np.zeros((40, 40), dtype=int)
    labels = np.arange(0, 40, 1)

    # Fine-grained class indices grouped by parent class. Built once here
    # rather than on every batch; the name of the batch index is no longer
    # reused as a comprehension variable.
    parent_groups = (
        list(range(0, 6)),
        list(range(6, 14)),
        list(range(14, 37)),
        list(range(37, 40)),
    )

    model.eval()
    end = time.time()
    # The context managers guarantee that both CSV files are flushed and
    # closed even when evaluation raises part-way through the loader.
    with torch.no_grad(), \
            open('results.csv', 'w') as cf, \
            open('results-parent.csv', 'w') as cv:
        writer = csv.writer(cf)
        writer_2 = csv.writer(cv)
        for i, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))
            c_matrix += cal_confusions(output, target, labels=labels)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            writer.writerow([i, round(top1.avg, 4)])
            # Compute the parent-class (coarse) classification precision:
            # a sample counts as correct when prediction and target fall in
            # the same parent group, i.e. inside a diagonal block of the
            # accumulated confusion matrix.
            if args.hier_classify:
                corrects = 0.
                for group in parent_groups:
                    corrects += c_matrix[group][:, group].sum()
                parent_precision = corrects / c_matrix.sum()

                writer_2.writerow([i, round(parent_precision, 4)])
                logging.info('parent precision: {}'.format(parent_precision))

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    results = OrderedDict(top1=round(top1.avg, 4),
                          top1_err=round(100 - top1.avg, 4),
                          top5=round(top5.avg, 4),
                          top5_err=round(100 - top5.avg, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    logging.info('confusion_matrix: \n {}'.format(c_matrix))
    # Overall precision recomputed from the confusion matrix: trace / total.
    logging.info('precision by confusion matrix: \n {}'.format(
        truediv(np.sum(np.diag(c_matrix)), np.sum(np.sum(c_matrix, axis=1)))))

    return results
コード例 #9
0
def validate(args):
    """Score a model on the dataset at ``args.data`` and report precision.

    The model is produced by evaluating the expression in ``args.model`` and
    calling the result with no arguments; weights from ``args.checkpoint`` are
    loaded when one is given. Loss and top-1/top-5 precision are averaged over
    the whole loader and progress is logged every ``args.log_freq`` batches.

    Args:
        args: Parsed command-line namespace. ``args.pretrained`` is forced to
            ``False`` and ``args.prefetcher`` to ``True`` in place.

    Returns:
        OrderedDict with the keys ``top1``, ``top1_err``, ``top5``,
        ``top5_err``, ``param_count`` (in millions), ``img_size``,
        ``cropt_pct`` and ``interpolation``.
    """
    # Overwrite the flags on the shared namespace before anything reads them.
    args.pretrained = False
    args.prefetcher = True

    # NOTE(review): eval() executes whatever expression args.model holds, so
    # this is only safe while the value comes from a trusted command line.
    model = eval(args.model)()

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, False)

    param_count = sum(p.numel() for p in model.parameters())
    logging.info(
        'Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    # Wrap for multi-GPU first when requested, then move to the GPU once.
    if args.num_gpu > 1:
        gpu_ids = list(range(args.num_gpu))
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    # A path that ends in ".tar" and exists as a file is read as an archive.
    is_tar_archive = (os.path.splitext(args.data)[1] == '.tar'
                      and os.path.isfile(args.data))
    dataset_cls = DatasetTar if is_tar_archive else Dataset
    dataset = dataset_cls(args.data, load_bytes=False)

    # Test-time pooling consumes the full image, so no centre crop is applied.
    if test_time_pool:
        crop_pct = 1.0
    else:
        crop_pct = data_config['crop_pct']

    loader_options = dict(
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
    )
    loader = create_loader(dataset, **loader_options)

    timer = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    progress_template = (
        'Test: [{0:>4d}/{1}]  '
        'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
        'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
        'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})')

    model.eval()
    end = time.time()
    with torch.no_grad():
        # NOTE(review): batches are not moved to the GPU here; presumably the
        # prefetching loader already delivers CUDA tensors -- confirm.
        for batch_idx, (images, target) in enumerate(loader):
            n_samples = images.size(0)

            # Forward pass and loss.
            output = model(images)
            loss = criterion(output, target)

            # Update the running loss / precision averages.
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            loss_meter.update(loss.item(), n_samples)
            top1_meter.update(prec1.item(), n_samples)
            top5_meter.update(prec5.item(), n_samples)

            # Time spent on this batch, data loading included.
            timer.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq != 0:
                continue
            logging.info(
                progress_template.format(
                    batch_idx,
                    len(loader),
                    batch_time=timer,
                    rate_avg=n_samples / timer.avg,
                    loss=loss_meter,
                    top1=top1_meter,
                    top5=top5_meter))

    top1_avg = top1_meter.avg
    top5_avg = top5_meter.avg
    results = OrderedDict([
        ('top1', round(top1_avg, 4)),
        ('top1_err', round(100 - top1_avg, 4)),
        ('top5', round(top5_avg, 4)),
        ('top5_err', round(100 - top5_avg, 4)),
        ('param_count', round(param_count / 1e6, 2)),
        ('img_size', data_config['input_size'][-1]),
        ('cropt_pct', crop_pct),
        ('interpolation', data_config['interpolation']),
    ])

    summary = ' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err'])
    logging.info(summary)

    return results
コード例 #10
0
def main(args):
    """Train and/or evaluate an image classifier as configured by ``args``.

    Sets up data loading, the model, optional EMA weights, optimizer, LR
    scheduler and loss (with optional mixup/cutmix soft targets and knowledge
    distillation). It optionally restores a fine-tuning or resume checkpoint,
    then either runs one evaluation pass (``args.eval``) or the training loop,
    which checkpoints, evaluates and appends to ``log.txt`` once per epoch.

    Args:
        args: Parsed command-line namespace. ``args.nb_classes``, ``args.lr``
            and (on resume) ``args.start_epoch`` are updated in place.

    Raises:
        NotImplementedError: If distillation and fine-tuning are requested
            together outside evaluation mode.
    """
    utils.init_distributed_mode(args)

    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so that every process
    # draws a different random stream)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    if args.load_tar:
        train_dir = os.path.join(args.data_path, 'train.tar')
        train_transform = build_transform(True, args)
        dataset_train = DatasetTar(train_dir, transform=train_transform)
        args.nb_classes = 1000
        val_transform = build_transform(False, args)
        eval_dir = os.path.join(args.data_path, 'val.tar')
        dataset_val = DatasetTar(eval_dir, transform=val_transform)
    else:
        dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
        dataset_val, _ = build_dataset(is_train=False, args=args)

    # Distributed-style samplers are used unconditionally: the former
    # ``args.distributed`` switch was hard-wired to True, which made its
    # RandomSampler fallback unreachable.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.finetune, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop classifier heads whose shape does not match the new model
        # (e.g. a different number of classes).
        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling rule: scale the base LR by the global batch size / 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    # The loss must match the targets that mixup_fn produces. mixup_fn is
    # created whenever mixup OR cutmix is enabled, so the soft-target loss is
    # selected on the same condition (previously only ``args.mixup > 0`` was
    # checked, leaving cutmix-only runs with a hard-label loss).
    if mixup_active:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        print(f"Creating teacher model: {args.teacher_model}")
        # teacher_pretrained is True when args.teacher_path is empty
        teacher_pretrained = not bool(args.teacher_path)
        teacher_model = create_model(
            args.teacher_model,
            pretrained=teacher_pretrained,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        if not teacher_pretrained:
            if args.teacher_path.startswith('https'):
                checkpoint = torch.hub.load_state_dict_from_url(
                    args.teacher_path, map_location='cpu', check_hash=True)
            else:
                checkpoint = torch.load(args.teacher_path, map_location='cpu')
            teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is 'none'
    criterion = DistillationLoss(
        criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
    )

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # Training state is restored only when the checkpoint carries all of
        # it and we are not merely evaluating.
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # The train sampler is always a distributed-style sampler (see above),
        # so it must be told the epoch in every mode; otherwise it reuses the
        # same shuffled order on every epoch.
        data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            set_training_mode=args.finetune == ''  # keep in eval mode during finetuning
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    # model_ema is None when EMA is disabled.
                    # NOTE(review): presumably get_state_dict cannot unwrap
                    # None, hence the guard -- confirm against the helper.
                    'model_ema': get_state_dict(model_ema) if model_ema is not None else None,
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
コード例 #11
0
ファイル: validate.py プロジェクト: Tiantian-Han/HardCoReNAS
def validate(args):
    """Measure latency and accuracy of a model on the dataset at ``args.data``.

    The number of classes is taken from the dataset, the model is built (and
    optionally converted, weight-normalised and/or BN-fused), its latency is
    measured on ``args.target_device``, and it is then evaluated over the
    whole loader with gradients disabled.

    Args:
        args: Parsed command-line namespace. ``args.pretrained``,
            ``args.prefetcher`` and ``args.num_classes`` are overwritten in
            place.

    Returns:
        OrderedDict with the keys ``top1``, ``top1_err``, ``top5``,
        ``top5_err``, ``param_count`` (in millions), ``img_size``,
        ``cropt_pct`` and ``interpolation``. ``top5`` holds top-k precision
        with ``k = min(5, args.num_classes)``.
    """
    # Fall back to pretrained weights when no checkpoint is supplied.
    args.pretrained = args.pretrained or (not args.checkpoint)
    args.prefetcher = not args.no_prefetcher
    # A path that ends in ".tar" and exists as a file is read as an archive.
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map)
    logging.info(f'Validation data has {len(dataset)} images')
    # The classifier size follows the dataset, overriding any value in args.
    args.num_classes = len(dataset.class_to_idx)
    logging.info(f'setting num classes to {args.num_classes}')

    # create model
    # mobilenet_string is passed as '' when the model is going to be converted
    # to a mobilenet below (the string is then given to the conversion call).
    model = create_model(args.model,
                         num_classes=args.num_classes,
                         in_chans=3,
                         pretrained=args.pretrained,
                         scriptable=args.torchscript,
                         resnet_structure=args.resnet_structure,
                         resnet_block=args.resnet_block,
                         heaviest_network=args.heaviest_network,
                         use_kernel_3=args.use_kernel_3,
                         exp_r=args.exp_r,
                         depth=args.depth,
                         reduced_exp_ratio=args.reduced_exp_ratio,
                         use_dedicated_pwl_se=args.use_dedicated_pwl_se,
                         multipath_sampling=args.multipath_sampling,
                         force_sync_gpu=args.force_sync_gpu,
                         mobilenet_string=args.mobilenet_string
                         if not args.transform_model_to_mobilenet else '',
                         no_swish=args.no_swish,
                         use_swish=args.use_swish)
    data_config = resolve_data_config(vars(args), model=model)
    if args.checkpoint:
        # NOTE(review): the third positional argument is presumably the
        # "use EMA weights" flag -- confirm against load_checkpoint.
        load_checkpoint(model, args.checkpoint, True, strict=True)

    if 'mobilenasnet' in args.model and args.transform_model_to_mobilenet:
        # Read the expected latency from the original model before it is
        # replaced by its converted counterpart.
        model.eval()
        expected_latency = model.extract_expected_latency(
            file_name=args.lut_filename,
            batch_size=args.lut_measure_batch_size,
            iterations=args.repeat_measure,
            target=args.target_device)
        model.eval()
        model2, string_model = transform_model_to_mobilenet(
            model, mobilenet_string=args.mobilenet_string)
        del model
        model = model2
        model.eval()
        # The expected latency is multiplied by 1e3 for the [ms] display.
        print('Model converted. Expected latency: {:0.2f}[ms]'.format(
            expected_latency * 1e3))

    elif args.normalize_weights:
        # NOTE(review): this looks like folding the ImageNet mean/std input
        # normalisation into the stem: conv_stem weights are divided by std
        # and a compensating term built from mean/std is added to the bn1
        # bias -- confirm the intent with the author.
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        std = torch.tensor(IMAGENET_DEFAULT_STD).unsqueeze(0).unsqueeze(
            -1).unsqueeze(-1)
        mean = torch.tensor(IMAGENET_DEFAULT_MEAN).unsqueeze(0).unsqueeze(
            -1).unsqueeze(-1)
        W = model.conv_stem.weight.data
        bnw = model.bn1.weight.data
        bnb = model.bn1.bias.data
        model.conv_stem.weight.data = W / std
        bias = -bnw.data * (W.sum(dim=[-1, -2]) @ (mean / std).squeeze()) / (
            torch.sqrt(model.bn1.running_var + model.bn1.eps))
        model.bn1.bias.data = bnb + bias

    if args.fuse_bn:
        model = fuse_bn(model)

    # Latency measurement. For the 'gpu' and default targets measure_time is
    # called twice and only the second result is kept; the first call is
    # presumably a warm-up run.
    if args.target_device == 'gpu':
        measure_time(model, batch_size=64, target='gpu')
        t = measure_time(model, batch_size=64, target='gpu')

    elif args.target_device == 'onnx':
        t = measure_time_onnx(model)

    else:
        measure_time(model)
        t = measure_time(model)

    param_count = sum([m.numel() for m in model.parameters()])
    flops = compute_flops(model, data_config['input_size'])
    # flops is divided by 1e9 and t multiplied by 1e3 (shown as [ms]).
    logging.info(
        'Model {} created, param count: {}, flops: {}, Measured latency ({}): {:0.2f}[ms]'
        .format(args.model, param_count, flops / 1e9, args.target_device,
                t * 1e3))

    # Resolve the data config again now that the model may have been replaced.
    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')

    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    # Test-time pooling uses the full image, so no centre crop is applied.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        squish=args.squish,
    )

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    model.cuda()
    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) +
                            data_config['input_size']).cuda()
        model(input)
        end = time.time()
        for i, (input, target) in enumerate(loader):
            # Restart the clock once the first batch has been delivered, so
            # the time spent producing it is not counted in batch_time.
            if i == 0:
                end = time.time()

            # Without the prefetcher the batch is moved to the GPU here.
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()

            if args.amp:
                input = input.half()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            # k is capped at the class count so top-k is still defined for
            # datasets with fewer than 5 classes.
            k = min(5, args.num_classes)
            acc1, acc5 = accuracy(output.data, target, topk=(1, k))

            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    # Error rates are reported as 100 minus the corresponding accuracy.
    results = OrderedDict(top1=round(top1.avg, 4),
                          top1_err=round(100 - top1.avg, 4),
                          top5=round(top5.avg, 4),
                          top5_err=round(100 - top5.avg, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results