# Let this logger pass records of every severity, from DEBUG upwards.
logger.setLevel(logging.DEBUG)


if __name__ == '__main__':
    # Options specific to this script. Other attributes read below (dataset,
    # target_network, num_classes, target_aug, target_size) are not declared
    # here -- presumably ConfigArgumentParser supplies them from the loaded
    # configuration; verify against the parser.
    parser = ConfigArgumentParser(conflict_handler='resolve')
    parser.add_argument('--test-batch', type=int, default=32)
    parser.add_argument('--tta', type=str, default='center')
    parser.add_argument('--deform', type=str, default='')
    parser.add_argument('--corrupt', type=str, default='')
    args = parser.parse_args()

    # Validate with an explicit exception instead of `assert`: assertions are
    # removed when Python runs with -O, which would let other datasets through.
    if args.dataset != 'imagenet':
        raise ValueError('Unsupported dataset=%s (only imagenet is supported)' % args.dataset)

    # Build the target network, call .eval() on it, and print the FLOPs the
    # profiler reports for a single zero-filled 3-channel square input whose
    # side is the configured target_size.
    model_target = get_model(
        args.target_network,
        gpus=[0],
        num_classes=args.num_classes,
        train_aug=args.target_aug,
    ).eval()
    profiler = Profiler(model_target)
    _probe_side = C.get()['target_size']
    _probe_input = torch.zeros((1, 3, _probe_side, _probe_side)).cuda()
    print('target network, FLOPs=', profiler.flops(_probe_input))

    # target_size divided by 0.875, rounded down. Presumably the side the image
    # is resized to before cropping to target_size -- confirm with the
    # transform pipeline further down.
    scaled_size = int(math.floor(args.target_size / 0.875))

    if args.deform != '':
        # --deform is parsed as "<type> <level>" separated by one space (e.g.
        # "rotate 30"); the unpacking raises ValueError unless the split yields
        # exactly two fields.
        deform_type, deform_level = args.deform.split(' ')

        # Each branch converts the level once and binds it as a lambda default,
        # so a malformed level fails here instead of on the first image, and
        # the transform does not depend on module-level names at call time.
        # NOTE(review): newer torchvision releases renamed rotate's `resample`
        # argument to `interpolation` -- confirm against the pinned version.
        if deform_type in ['rotate', 'rotation']:
            t = torchvision.transforms.Lambda(
                lambda img_orig, _angle=int(deform_level):
                torchvision.transforms.functional.rotate(img_orig, _angle, resample=PIL.Image.BICUBIC))
        elif deform_type == 'bright':
            t = torchvision.transforms.Lambda(
                lambda img_orig, _factor=float(deform_level):
                torchvision.transforms.functional.adjust_brightness(img_orig, _factor))
        elif deform_type == 'zoom':
            # Resize target is scaled_size multiplied by the zoom level.
            resize = int(scaled_size * float(deform_level))
            t = torchvision.transforms.Lambda(
                lambda img_orig, _size=resize:
                torchvision.transforms.functional.resize(img_orig, _size, interpolation=PIL.Image.BICUBIC))
        else:
            # Plain `else` so that an empty type (e.g. --deform " 5") is also
            # rejected instead of silently leaving `t` undefined.
            raise ValueError('Invalid Deformation=%s' % deform_type)
    else: