def validate(args):
    """Evaluate a pretrained model on the 'imagenet' dataset and print accuracy.

    Reads ``args.model``, ``args.no_jit``, ``args.data`` and
    ``args.batch_size``; ``vars(args)`` is also handed to
    ``resolve_data_config``. A progress line is printed every 20 batches
    (skipping batch 0) and a summary line once the loader is exhausted.

    Returns:
        dict: ``top1`` and ``top5`` accuracies as float percentages.
    """
    model, variables = create_model(
        args.model, pretrained=True, rng=jax.random.PRNGKey(0))
    print(f'Created {args.model} model. Validating...')

    def _forward(images, labels):
        # Closes over the model and its variables so only the batch is passed.
        return eval_forward(model, variables, images, labels)

    eval_step = _forward if args.no_jit else jax.jit(_forward)

    dataset = create_dataset('imagenet', args.data)
    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=8,
        crop_pct=data_config['crop_pct'])

    timer = AverageMeter()
    top1_hits = 0
    top5_hits = 0
    seen = 0
    start_time = last_tick = time.time()
    for step, (images, labels) in enumerate(loader):
        # Move axis 1 to the end before handing the batch to the model
        # (presumably NCHW -> NHWC; confirm against the loader's layout).
        images = images.numpy().transpose(0, 2, 3, 1)
        labels = labels.numpy()
        batch_size = images.shape[0]

        top1_count, top5_count = eval_step(images, labels)
        top1_hits += int(top1_count)
        top5_hits += int(top5_count)
        seen += batch_size

        timer.update(time.time() - last_tick)
        if step > 0 and step % 20 == 0:
            print(
                f'Test: [{step:>4d}/{len(loader)}] '
                f'Rate: {batch_size / timer.val:>5.2f}/s ({batch_size / timer.avg:>5.2f}/s) '
                f'Acc@1: {100 * top1_hits / seen:>7.3f} '
                f'Acc@5: {100 * top5_hits / seen:>7.3f}')
        last_tick = time.time()

    acc_1 = 100 * top1_hits / seen
    acc_5 = 100 * top5_hits / seen
    print(
        f'Validation complete. {seen / (last_tick - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return {'top1': float(acc_1), 'top5': float(acc_5)}
def main():
    """Command-line entry point that configures and runs a training job.

    Parses the arguments, sets up (optionally distributed) CUDA execution and
    mixed precision, builds the model, the optional ``model_KD`` returned by
    ``build_kd_model``, the optimizer, LR schedule, data loaders and losses,
    and then alternates ``train_one_epoch`` and ``validate`` for each
    scheduled epoch. On global rank 0 it also writes ``args.yaml``,
    ``summary.csv`` and checkpoints into the output directory. A
    ``KeyboardInterrupt`` ends the epoch loop quietly; the best metric, if
    one was recorded, is logged afterwards.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    args.prefetcher = not args.no_prefetcher

    # Distributed mode is inferred from the WORLD_SIZE environment variable.
    args.distributed = int(os.environ.get('WORLD_SIZE', 1)) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # Decide which AMP flavour (if any) is both requested and available.
    if args.amp:
        # plain `--amp` picks native AMP first and falls back to APEX
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    use_amp = None
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")
    # APEX SyncBN / DDP wrappers are only chosen when APEX AMP is active.
    use_apex_wrappers = has_apex and use_amp == 'apex'

    random_seed(args.seed, args.rank)

    model_KD = build_kd_model(args) if args.kd_model_path is not None else None

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        # FIXME handle model default vs config num_classes more elegantly
        args.num_classes = model.num_classes

    # Informational logging below is restricted to local rank 0.
    is_local_master = args.local_rank == 0
    if is_local_master:
        model_label = safe_model_name(args.model)
        n_params = sum([p.numel() for p in model.parameters()])
        _logger.info(f'Model {model_label} created, param count:{n_params}')

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=is_local_master)

    # Augmentation batch splits (contrastive loss / split BN).
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # Split BN keeps separate BN statistics per batch portion.
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # Move the model to the GPU, optionally in channels-last layout.
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # Synchronized BatchNorm for distributed training.
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if use_apex_wrappers:
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if is_local_master:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert use_amp != 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # AMP loss scaling and autocast context; `suppress` is the no-op default.
    amp_autocast = suppress
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if is_local_master:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if is_local_master:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    elif is_local_master:
        _logger.info('AMP not enabled. Training in float32.')

    # Optionally resume from a checkpoint.
    resume_epoch = None
    if args.resume:
        skip_opt_state = args.no_resume_opt
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if skip_opt_state else optimizer,
            loss_scaler=None if skip_opt_state else loss_scaler,
            log_info=is_local_master)

    # Exponential moving average of the weights. It is created after cuda()
    # and AMP setup but before the DDP wrapper is applied.
    model_ema = None
    if args.model_ema:
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # Wrap the model for distributed training (the EMA copy is not wrapped).
    if args.distributed:
        if use_apex_wrappers:
            if is_local_master:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if is_local_master:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model,
                              device_ids=[args.local_rank],
                              broadcast_buffers=not args.no_ddp_bb)

    # LR schedule and starting epoch; an explicit start epoch beats the
    # epoch recovered from the resumed checkpoint.
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if args.start_epoch is not None:
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    else:
        start_epoch = 0
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if is_local_master:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # Train and eval datasets.
    dataset_train = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.train_split,
        is_training=True,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size,
        repeats=args.epoch_repeats)
    dataset_eval = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.val_split,
        is_training=False,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size)

    # Mixup / cutmix: applied in the collate function when prefetching,
    # otherwise through `mixup_fn` handed to the training loop.
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = {
            'mixup_alpha': args.mixup,
            'cutmix_alpha': args.cutmix,
            'cutmix_minmax': args.cutmix_minmax,
            'prob': args.mixup_prob,
            'switch_prob': args.mixup_switch_prob,
            'mode': args.mixup_mode,
            'label_smoothing': args.smoothing,
            'num_classes': args.num_classes,
        }
        if args.prefetcher:
            # collate conflict (need to support deinterleaving in collate mixup)
            assert not num_aug_splits
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # Wrap the training dataset in the AugMix helper when splits are in use.
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # Data loaders with the augmentation pipeline.
    if args.no_aug or not args.train_interpolation:
        train_interpolation = data_config['interpolation']
    else:
        train_interpolation = args.train_interpolation
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # Training loss. With mixup active, smoothing is already part of the
    # soft targets produced by the mixup transform.
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing)
    elif mixup_active and args.bce_loss:
        train_loss_fn = BinaryCrossEntropy(
            target_threshold=args.bce_target_thresh)
    elif mixup_active:
        train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing and args.bce_loss:
        train_loss_fn = BinaryCrossEntropy(
            smoothing=args.smoothing,
            target_threshold=args.bce_target_thresh)
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # Checkpoint saver and eval-metric tracking (global rank 0 only).
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if args.rank == 0:
        exp_name = args.experiment or '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            safe_model_name(args.model),
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(args.output or './output/train', exp_name)
        # A loss metric improves downwards; any other metric improves upwards.
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(
            model=model,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing,
            max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                model_KD=model_KD)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size,
                              args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            # When an on-device EMA model exists, its metrics replace the
            # raw model's metrics for scheduling, summary and checkpointing.
            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                eval_metrics = validate(model_ema.module,
                                        loader_eval,
                                        validate_loss_fn,
                                        args,
                                        amp_autocast=amp_autocast,
                                        log_suffix=' (EMA)')

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir is not None:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None,
                               log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=eval_metrics[eval_metric])

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
def validate(args):
    """Evaluate one model/checkpoint on a dataset and return summary metrics.

    Builds the model named by ``args.model`` (optionally loading
    ``args.checkpoint``), creates the dataset and loader, and runs every
    batch through the model under ``torch.no_grad()``, tracking loss and
    top-1 / top-5 accuracy. Several fields of ``args`` are updated in place
    (``pretrained``, ``prefetcher``, the AMP flags, ``num_classes``).

    Returns:
        OrderedDict: ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not (args.apex_amp and args.native_amp), "Only one AMP mode should be set."
    # `suppress` acts as a do-nothing context when native AMP is off.
    amp_autocast = torch.cuda.amp.autocast if args.native_amp else suppress

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([p.numel() for p in model.parameters()])
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data,
        name=args.dataset,
        split=args.split,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map)

    # Optional boolean mask selecting which output columns are scored; the
    # file lists one integer class index per line.
    valid_labels = None
    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            kept_indices = {int(line.rstrip()) for line in f}
        valid_labels = [i in kept_indices for i in range(args.num_classes)]

    real_labels = None
    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)

    # With test-time pooling active the crop percentage is forced to 1.0.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # Warmup pass on random data so the first timed batch is less noisy
        # (matters when comparing torchscript vs non-scripted runs).
        images = torch.randn((args.batch_size, ) +
                             data_config['input_size']).cuda()
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)
        model(images)

        end = time.time()
        for batch_idx, (images, target) in enumerate(loader):
            if args.no_prefetcher:
                # Without the prefetcher the batch is moved to the GPU here.
                target = target.cuda()
                images = images.cuda()
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(images)
            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            n_samples = images.size(0)
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), n_samples)
            top1.update(acc1.item(), n_samples)
            top5.update(acc5.item(), n_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=images.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    if real_labels is None:
        top1a, top5a = top1.avg, top5.avg
    else:
        # real labels mode replaces topk values at the end
        top1a = real_labels.get_accuracy(k=1)
        top5a = real_labels.get_accuracy(k=5)

    results = OrderedDict(
        top1=round(top1a, 4),
        top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4),
        top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
def validate(args):
    """Evaluate a model on clean inputs and on FGM / PGD adversarial examples.

    Logs all parsed arguments, builds the model named by ``args.model``
    (optionally loading ``args.checkpoint``), and for every batch of the
    loader measures accuracy on the clean inputs and on adversarial inputs
    produced by ``fast_gradient_method`` and ``projected_gradient_descent``
    with an L-infinity budget of ``args.eps``. Several fields of ``args`` are
    updated in place (``pretrained``, ``prefetcher``, the AMP flags,
    ``num_classes``).

    Args:
        args: parsed command-line namespace.

    Returns:
        OrderedDict: clean metrics (``top1``, ``top1_err``, ``top5``,
        ``top5_err``), adversarial metrics (``top1_fgm_ae``, ``top5_fgm_ae``,
        ``top1_pgd_ae``, ``top5_pgd_ae``) and ``param_count`` (millions),
        ``img_size``, ``cropt_pct``, ``interpolation``.

    Raises:
        NotImplementedError: if ``args.real_labels`` is set; real-label
            scoring has not been adapted to the adversarial mode.
        AssertionError: if both APEX and native AMP end up enabled, or the
            model lacks ``num_classes`` when it was not given explicitly.
    """
    _logger.info(f'\n\n ---------------EVALUATION {args.eps}------------------------------- \n\n')
    _logger.info("Argument parser collected the following arguments:")
    for arg in vars(args):
        _logger.info(f" {arg}:{getattr(args, arg)}")
    _logger.info("\n")

    # Fail fast: this mode used to raise only after the whole (expensive)
    # evaluation had already run, discarding all of the work.
    if args.real_labels:
        raise NotImplementedError(
            'real-labels evaluation is not supported in adversarial examples mode')

    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX or Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(m.numel() for m in model.parameters())
    _logger.info(
        f'Model {args.model} created, param count: {param_count} ({(float(param_count)/(10.0**6)):.1f} M)'
    )

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data_dir,
        name=args.dataset,
        split=args.split,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map)

    # Optional boolean mask selecting which output columns are scored; the
    # file lists one integer class index per line.
    valid_labels = None
    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            kept_indices = {int(line.rstrip()) for line in f}
        valid_labels = [i in kept_indices for i in range(args.num_classes)]

    # With test-time pooling active the crop percentage is forced to 1.0.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_fgm_ae = AverageMeter()
    top5_fgm_ae = AverageMeter()
    top1_pgd_ae = AverageMeter()
    top5_pgd_ae = AverageMeter()

    model.eval()
    # Warmup pass on random data so the first timed batch is less noisy
    # (matters when comparing torchscript vs non-scripted runs). It needs no
    # gradients, so it runs under no_grad.
    with torch.no_grad():
        images = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)
        model(images)

    end = time.time()
    for batch_idx, (images, target) in enumerate(loader):
        if args.no_prefetcher:
            # Without the prefetcher the batch is moved to the GPU here.
            target = target.cuda()
            images = images.cuda()
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)

        # Clean forward pass. Only the attacks below need gradients, so no
        # autograd graph is built (and kept alive) for this pass.
        with torch.no_grad():
            with amp_autocast():
                output = model(images)
            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

        # Generate adversarial examples for the current inputs. This must
        # stay outside no_grad: the attacks differentiate through the model.
        input_fgm_ae = fast_gradient_method(
            model_fn=model,
            x=images,
            eps=args.eps,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        input_pgd_ae = projected_gradient_descent(
            model_fn=model,
            x=images,
            eps=args.eps,
            eps_iter=0.01,
            nb_iter=40,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )

        # Predict with the adversarial examples.
        with torch.no_grad():
            with amp_autocast():
                output_fgm_ae = model(input_fgm_ae)
                output_pgd_ae = model(input_pgd_ae)
            if valid_labels is not None:
                # Score adversarial logits over the same class subset as the
                # clean logits; otherwise `target` indexes a different class
                # space and the adversarial accuracies are meaningless.
                output_fgm_ae = output_fgm_ae[:, valid_labels]
                output_pgd_ae = output_pgd_ae[:, valid_labels]

        # measure accuracy and record loss
        n_samples = images.size(0)
        acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), n_samples)
        top1.update(acc1.item(), n_samples)
        top5.update(acc5.item(), n_samples)

        acc1_fgm_ae, acc5_fgm_ae = accuracy(output_fgm_ae.detach(), target, topk=(1, 5))
        acc1_pgd_ae, acc5_pgd_ae = accuracy(output_pgd_ae.detach(), target, topk=(1, 5))
        top1_fgm_ae.update(acc1_fgm_ae.item(), n_samples)
        top5_fgm_ae.update(acc5_fgm_ae.item(), n_samples)
        top1_pgd_ae.update(acc1_pgd_ae.item(), n_samples)
        top5_pgd_ae.update(acc5_pgd_ae.item(), n_samples)

        # measure elapsed time (includes the time spent generating attacks)
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.log_freq == 0:
            _logger.info(
                'Test: [{0:>4d}/{1}] '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time,
                    rate_avg=n_samples / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

    top1a, top5a = top1.avg, top5.avg
    top1a_fgm_ae, top5a_fgm_ae = top1_fgm_ae.avg, top5_fgm_ae.avg
    top1a_pgd_ae, top5a_pgd_ae = top1_pgd_ae.avg, top5_pgd_ae.avg

    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        top1_fgm_ae=round(top1a_fgm_ae, 4),
        top5_fgm_ae=round(top5a_fgm_ae, 4),
        top1_pgd_ae=round(top1a_pgd_ae, 4),
        top5_pgd_ae=round(top5a_pgd_ae, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * [Regular] Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    _logger.info(' * [FGM Adversarial Attack] Acc@1 {:.3f} Acc@5 {:.3f} '.format(
        results['top1_fgm_ae'], results['top5_fgm_ae']))
    _logger.info(' * [PGD Adversarial Attack] Acc@1 {:.3f} Acc@5 {:.3f} '.format(
        results['top1_pgd_ae'], results['top5_pgd_ae']))

    return results
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: _logger.info('Training with a single process on 1 GPUs.') assert args.rank >= 0 # resolve AMP arguments based on PyTorch / Apex availability use_amp = None if args.amp: # for backwards compat, `--amp` arg tries apex before native amp if has_apex: args.apex_amp = True elif has_native_amp: args.native_amp = True if args.apex_amp and has_apex: use_amp = 'apex' elif args.native_amp and has_native_amp: use_amp = 'native' elif args.apex_amp or args.native_amp: _logger.warning("Neither APEX or native Torch AMP is available, using float32. " "Install NVIDA apex or upgrade to PyTorch 1.6") torch.manual_seed(args.seed + args.rank) #################################################################################### # Start - SparseML optional load weights from SparseZoo #################################################################################### if args.initial_checkpoint == "zoo": # Load checkpoint from base weights associated with given SparseZoo recipe if args.sparseml_recipe.startswith("zoo:"): args.initial_checkpoint = Zoo.download_recipe_base_framework_files( args.sparseml_recipe, extensions=[".pth.tar", ".pth"] )[0] else: raise ValueError( "Attempting to load weights from SparseZoo recipe, but not given a " "SparseZoo recipe stub. 
When initial-checkpoint is set to 'zoo'. " "sparseml-recipe must start with 'zoo:' and be a SparseZoo model " f"stub. sparseml-recipe was set to {args.sparseml_recipe}" ) elif args.initial_checkpoint.startswith("zoo:"): # Load weights from a SparseZoo model stub zoo_model = Zoo.load_model_from_stub(args.initial_checkpoint) args.initial_checkpoint = zoo_model.download_framework_files(extensions=[".pth"]) #################################################################################### # End - SparseML optional load weights from SparseZoo #################################################################################### model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly if args.local_rank == 0: _logger.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits # enable split bn (separate bn stats per batch-portion) if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) # move model to GPU, enable channels last layout if set model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: assert not args.split_bn if has_apex and use_amp != 'native': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') if args.torchscript: assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) optimizer = create_optimizer(args, model) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() if args.local_rank == 0: _logger.info('Using NVIDIA APEX AMP. 
Training in mixed precision.') elif use_amp == 'native': amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() if args.local_rank == 0: _logger.info('Using native Torch AMP. Training in mixed precision.') else: if args.local_rank == 0: _logger.info('AMP not enabled. Training in float32.') # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( model, args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) # setup exponential moving average of model weights, SWA could be used here too model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEmaV2( model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) if args.resume: load_checkpoint(model_ema.module, args.resume, use_ema=True) # setup distributed training if args.distributed: if has_apex and use_amp != 'native': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) # create the train and eval datasets dataset_train = create_dataset( args.dataset, root=args.data_dir, 
split=args.train_split, is_training=True, batch_size=args.batch_size) dataset_eval = create_dataset( args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size) # setup mixup / cutmix collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_args = dict( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) else: mixup_fn = Mixup(**mixup_args) # wrap dataset in AugMix helper if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) # create data loaders w/ augmentation pipeiine train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, no_aug=args.no_aug, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader ) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, 
interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) # setup loss function if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() elif mixup_active: # smoothing is handled with mixup target transform train_loss_fn = SoftTargetCrossEntropy().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver( model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) #################################################################################### # Start SparseML Integration #################################################################################### sparseml_loggers = ( [PythonLogger(), TensorBoardLogger(log_path=output_dir)] if output_dir else None ) manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe) optimizer = ScheduledOptimizer( optimizer, model, manager, steps_per_epoch=len(loader_train), loggers=sparseml_loggers ) # override lr scheduler if recipe 
makes any LR updates if any("LearningRate" in str(modifier) for modifier in manager.modifiers): _logger.info("Disabling timm LR scheduler, managing LR using SparseML recipe") lr_scheduler = None if manager.max_epochs: _logger.info( f"Overriding max_epochs to {manager.max_epochs} from SparseML recipe" ) num_epochs = manager.max_epochs or num_epochs #################################################################################### # End SparseML Integration #################################################################################### if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) try: for epoch in range(start_epoch, num_epochs): if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): loader_train.sampler.set_epoch(epoch) train_metrics = train_one_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info("Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate( model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric 
save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) ################################################################################# # Start SparseML ONNX Export ################################################################################# if output_dir: _logger.info( f"training complete, exporting ONNX to {output_dir}/model.onnx" ) exporter = ModuleExporter(model, output_dir) exporter.export_onnx(torch.randn((1, *data_config["input_size"]))) ################################################################################# # End SparseML ONNX Export ################################################################################# except KeyboardInterrupt: pass if best_metric is not None: _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def validate(args):
    """Run post-training quantization (FX graph mode) and evaluate a model.

    The function builds the model named by ``args.model``, optionally loads a
    checkpoint, prepares it for quantization according to
    ``args.quant_option`` (``'static'`` or ``'dynamic'``), and then measures
    loss / top-1 / top-5 over the evaluation loader.

    No explicit device move is performed here: the ``.cuda()``, AMP and
    ``DataParallel`` paths of the script this was derived from are disabled.

    Args:
        args: parsed command-line namespace. ``args.pretrained``,
            ``args.prefetcher`` and (when unset) ``args.num_classes`` are
            written back onto it.

    Returns:
        OrderedDict with keys ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (in millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(m.numel() for m in model.parameters())
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    criterion = nn.CrossEntropyLoss()

    dataset_kwargs = dict(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)
    dataset = create_dataset(**dataset_kwargs)
    # separate dataset instance used only for calibrating quantization observers
    calib_dataset = create_dataset(**dataset_kwargs)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            # boolean mask over class indices, used to slice the logits
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader_kwargs = dict(
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)
    loader = create_loader(dataset, **loader_kwargs)
    # loader over the calibration dataset, same preprocessing as evaluation
    calib_loader = create_loader(calib_dataset, **loader_kwargs)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    print('Start calibration of quantization observers before post-quantization')
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()

    if args.quant_option == 'static':
        # post training static quantization
        qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
        # prepare: insert observers
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # calibrate: observers only record statistics when the *prepared*
        # model sees data, so forward the calibration batches through it.
        # Meters and real-label tracking are deliberately left untouched so
        # the final evaluation below is not contaminated.
        with torch.no_grad():
            for calib_batch, _ in calib_loader:
                if args.channels_last:
                    calib_batch = calib_batch.contiguous(memory_format=torch.channels_last)
                model_prepared(calib_batch)
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)
    elif args.quant_option == 'dynamic':
        # post training dynamic/weight only quantization
        qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
        # prepare
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # no calibration needed when we only have dynamic/weight-only quantization
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)
    else:
        # NOTE(review): the message promises a fallback to static quantization,
        # but no quantization is performed on this path -- confirm intent.
        _logger.warning("Invalid quantization option. Set option to default(static)")

    # fusion
    # NOTE(review): the model evaluated below is the fused float model, not
    # `model_quantized` -- confirm this is the intended evaluation target.
    model_fused = quantize_fx.fuse_fx(model_to_quantize)
    model = model_fused

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        batch = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
        if args.channels_last:
            batch = batch.contiguous(memory_format=torch.channels_last)
        model(batch)
        end = time.time()
        for batch_idx, (batch, target) in enumerate(loader):
            if args.channels_last:
                batch = batch.contiguous(memory_format=torch.channels_last)

            # compute output
            output = model(batch)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), batch.size(0))
            top1.update(acc1.item(), batch.size(0))
            top5.update(acc5.item(), batch.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=batch.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
def data_creator(config):
    """Create the train and eval data loaders described by ``config``.

    Args:
        config: mapping holding the parsed argument namespace under ``"args"``
            and the per-loader batch size under ``BATCH_SIZE``.

    Returns:
        Tuple ``(train_loader, eval_loader)``.
    """
    args = config["args"]

    train_dir = join(args.data_dir, "train")
    val_dir = join(args.data_dir, "val")
    if args.mock_data:
        util.mock_data(train_dir, val_dir)

    # todo: verbose should depend on rank
    data_config = resolve_data_config(vars(args), verbose=True)

    dataset_train = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.train_split,
        is_training=True,
        batch_size=args.batch_size)
    dataset_eval = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.val_split,
        is_training=False,
        batch_size=args.batch_size)

    # Mixup is applied in the collate step only when the prefetcher is active.
    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        # collate conflict (need to support deinterleaving in collate mixup)
        assert args.num_aug_splits == 0
        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    # Settings shared by both loaders.
    common_params = dict(
        input_size=data_config["input_size"],
        use_prefetcher=args.prefetcher,
        mean=data_config["mean"],
        std=data_config["std"],
        num_workers=1,
        distributed=args.distributed,
        pin_memory=args.pin_mem)

    train_params = dict(
        is_training=True,
        batch_size=config[BATCH_SIZE],
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        collate_fn=collate_fn,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        num_aug_splits=args.num_aug_splits)  # always 0 right now
    train_loader = create_loader(dataset_train, **train_params, **common_params)

    eval_params = dict(
        is_training=False,
        batch_size=args.validation_batch_size_multiplier * config[BATCH_SIZE],
        interpolation=data_config["interpolation"],
        crop_pct=data_config["crop_pct"])
    eval_loader = create_loader(dataset_eval, **eval_params, **common_params)

    return train_loader, eval_loader