# Script entry point. NOTE: this statement sequence continues past the visible
# excerpt -- it ends on a dangling `else:` whose body is not shown, and with the
# indentation lost it is unclear whether that `else:` pairs with the outer
# `if args.deform != ''` or with the inner `deform_type` if/elif chain.
#
# What the visible part does, in order:
#   1. Sets `logger` to DEBUG level.
#   2. When run as `__main__`, builds a ConfigArgumentParser
#      (conflict_handler='resolve') and adds four options:
#        --test-batch (int, default 32), --tta (str, default 'center'),
#        --deform (str, default ''), --corrupt (str, default '').
#      Other attributes read below (dataset, target_network, num_classes,
#      target_aug, target_size) are not declared here -- presumably supplied by
#      ConfigArgumentParser / the config; verify.
#   3. Asserts `args.dataset == 'imagenet'`.
#      NOTE(review): `assert` is removed under `python -O`, so this check is
#      not a reliable input validation.
#   4. Builds the target model with get_model(..., gpus=[0]) and calls .eval().
#   5. Prints profiler.flops(...) for a zero tensor of shape
#      (1, 3, target_size, target_size) moved with .cuda(); the size comes from
#      C.get()['target_size'].
#   6. scaled_size = int(floor(args.target_size / 0.875)).
#      NOTE(review): step 5 reads target_size from C.get() while this reads it
#      from args -- presumably the same value; confirm.
#   7. If --deform is non-empty it is split on a single space and unpacked into
#      exactly two parts (deform_type, deform_level); any other number of parts
#      makes the unpacking raise ValueError. A transform `t` is then chosen:
#        'rotate' / 'rotation' -> functional.rotate(img, int(deform_level),
#                                 resample=PIL.Image.BICUBIC)
#        'bright'              -> functional.adjust_brightness(img,
#                                 float(deform_level))
#        'zoom'                -> functional.resize(img,
#                                 int(scaled_size * float(deform_level)),
#                                 interpolation=PIL.Image.BICUBIC)
#        any other non-empty type -> ValueError('Invalid Deformation=...')
#      NOTE(review): an empty deform_type (e.g. --deform " 3") matches no
#      branch shown here, so `t` is not assigned by the visible code.
#      NOTE(review): the `resample=` keyword of functional.rotate appears to be
#      deprecated in newer torchvision releases in favour of `interpolation=`
#      -- confirm against the pinned torchvision version.
logger.setLevel(logging.DEBUG) if __name__ == '__main__': parser = ConfigArgumentParser(conflict_handler='resolve') parser.add_argument('--test-batch', type=int, default=32) parser.add_argument('--tta', type=str, default='center') parser.add_argument('--deform', type=str, default='') parser.add_argument('--corrupt', type=str, default='') args = parser.parse_args() assert args.dataset == 'imagenet' model_target = get_model(args.target_network, gpus=[0], num_classes=args.num_classes, train_aug=args.target_aug).eval() profiler = Profiler(model_target) print('target network, FLOPs=', profiler.flops(torch.zeros((1, 3, C.get()['target_size'], C.get()['target_size'])).cuda(), )) scaled_size = int(math.floor(args.target_size / 0.875)) if args.deform != '': deform_type, deform_level = args.deform.split(' ') if deform_type in ['rotate', 'rotation']: t = torchvision.transforms.Lambda(lambda img_orig: torchvision.transforms.functional.rotate(img_orig, int(deform_level), resample=PIL.Image.BICUBIC)) elif deform_type == 'bright': t = torchvision.transforms.Lambda(lambda img_orig: torchvision.transforms.functional.adjust_brightness(img_orig, float(deform_level))) elif deform_type == 'zoom': resize = int(scaled_size * float(deform_level)) t = torchvision.transforms.Lambda(lambda img_orig: torchvision.transforms.functional.resize(img_orig, resize, interpolation=PIL.Image.BICUBIC)) elif deform_type: raise ValueError('Invalid Deformation=%s' % deform_type) else: