def data_creator(config):
    """Build the training and evaluation data loaders.

    Args:
        config: Mapping holding the parsed arguments under ``"args"`` and the
            training batch size under the ``BATCH_SIZE`` key.

    Returns:
        tuple: ``(train_loader, eval_loader)`` as returned by ``create_loader``.
    """
    # torch.manual_seed(args.seed + torch.distributed.get_rank())
    args = config["args"]
    train_dir = join(args.data, "train")
    val_dir = join(args.data, "val")

    if args.mock_data:
        util.mock_data(train_dir, val_dir)

    # todo: verbose should depend on rank
    data_config = resolve_data_config(vars(args), verbose=True)

    dataset_train = Dataset(train_dir)
    dataset_eval = Dataset(val_dir)

    collate_fn = None
    wants_mixup_collate = args.prefetcher and args.mixup > 0
    if wants_mixup_collate:
        # collate conflict (need to support deinterleaving in collate mixup)
        assert args.num_aug_splits == 0
        collate_fn = FastCollateMixup(
            args.mixup, args.smoothing, args.num_classes)

    # Keyword arguments shared by the train and eval loaders.
    shared_kwargs = {
        "input_size": data_config["input_size"],
        "use_prefetcher": args.prefetcher,
        "mean": data_config["mean"],
        "std": data_config["std"],
        "num_workers": 1,
        "distributed": args.distributed,
        "pin_memory": args.pin_mem,
    }

    train_kwargs = {
        "is_training": True,
        "batch_size": config[BATCH_SIZE],
        "re_prob": args.reprob,
        "re_mode": args.remode,
        "re_count": args.recount,
        "re_split": args.resplit,
        "collate_fn": collate_fn,
        "color_jitter": args.color_jitter,
        "auto_augment": args.aa,
        "interpolation": args.train_interpolation,
        "num_aug_splits": args.num_aug_splits,  # always 0 right now
    }
    train_loader = create_loader(
        dataset_train, **train_kwargs, **shared_kwargs)

    eval_kwargs = {
        "is_training": False,
        "batch_size":
            args.validation_batch_size_multiplier * config[BATCH_SIZE],
        "interpolation": data_config["interpolation"],
        "crop_pct": data_config["crop_pct"],
    }
    eval_loader = create_loader(dataset_eval, **eval_kwargs, **shared_kwargs)

    return train_loader, eval_loader
def _evaluate(self, val_data, metric_name=None):
    """Evaluate ``self.net`` on ``val_data``.

    Args:
        val_data: Validation data; it is converted with ``to_torch()`` before
            being handed to ``create_loader``.
        metric_name: Forwarded unchanged to ``self.validate``.

    Returns:
        Whatever ``self.validate`` returns.
    """
    # MSE for regression problems, cross-entropy for everything else.
    if self._problem_type == REGRESSION:
        loss_cls = nn.MSELoss
    else:
        loss_cls = nn.CrossEntropyLoss
    validate_loss_fn = loss_cls().to(self.ctx[0])

    torch_val_data = val_data.to_torch()

    data_cfg = self._data_cfg
    misc_cfg = self._misc_cfg
    loader_kwargs = {
        'input_size': data_cfg.input_size,
        'batch_size':
            data_cfg.validation_batch_size_multiplier * self._train_cfg.batch_size,
        'is_training': False,
        'use_prefetcher': misc_cfg.prefetcher,
        'interpolation': data_cfg.interpolation,
        'mean': data_cfg.mean,
        'std': data_cfg.std,
        'num_workers': misc_cfg.num_workers,
        'distributed': False,
        'crop_pct': data_cfg.crop_pct,
        'pin_memory': misc_cfg.pin_mem,
    }
    val_loader = create_loader(torch_val_data, **loader_kwargs)

    return self.validate(
        self.net,
        val_loader,
        validate_loss_fn,
        amp_autocast=self._amp_autocast,
        metric_name=metric_name,
    )
def _init_dataloader(self):
    """Init dataloader from timm.

    Sets ``self.train_loader`` and ``self.valid_loader`` from the ``train``
    and ``val`` sub-folders of ``self.cfg.dataset.data_dir``.
    """
    dataset_cfg = self.cfg.dataset

    if self.horovod:
        # Only local rank 0 copies the data; every rank then waits in join().
        if hvd.local_rank() == 0:
            FileOps.copy_folder(dataset_cfg.remote_data_dir,
                                dataset_cfg.data_dir)
        hvd.join()

    args = self.cfg.dataset
    dataset_train = Dataset(os.path.join(dataset_cfg.data_dir, 'train'))

    if self.horovod:
        world_size, rank = hvd.size(), hvd.rank()
    else:
        world_size, rank = None, None

    def shared_kwargs():
        """Return the keyword arguments common to both loaders."""
        return {
            'input_size': tuple(args.input_size),
            'use_prefetcher': self.cfg.prefetcher,
            'mean': tuple(args.mean),
            'std': tuple(args.std),
            'num_workers': args.workers,
            'distributed': self.horovod,
            'world_size': world_size,
            'rank': rank,
        }

    self.train_loader = create_loader(
        dataset_train,
        batch_size=args.batch_size,
        is_training=True,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        rand_erase_count=args.recount,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation='random',
        **shared_kwargs())

    dataset_eval = Dataset(os.path.join(dataset_cfg.data_dir, 'val'))
    self.valid_loader = create_loader(
        dataset_eval,
        batch_size=4 * args.batch_size,
        is_training=False,
        interpolation=args.interpolation,
        **shared_kwargs())
def validate(args):
    """Validate a pretrained model and report top-1 / top-5 accuracy.

    Args:
        args: Parsed arguments; this function reads ``model``, ``no_jit``,
            ``data`` and ``batch_size`` and passes ``vars(args)`` to
            ``resolve_data_config``.

    Returns:
        dict: ``{'top1': float, 'top5': float}`` accuracies in percent.
    """
    rng = jax.random.PRNGKey(0)
    model, variables = create_model(args.model, pretrained=True, rng=rng)
    print(f'Created {args.model} model. Validating...')

    def _forward(images, labels):
        return eval_forward(model, variables, images, labels)

    eval_step = _forward if args.no_jit else jax.jit(_forward)

    # A path ending in ".tar" that is an existing file is read as a tar dataset.
    is_tar_file = (os.path.splitext(args.data)[1] == '.tar'
                   and os.path.isfile(args.data))
    dataset = DatasetTar(args.data) if is_tar_file else Dataset(args.data)

    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=8,
        crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1 = correct_top5 = 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
        # Move the channel axis last (axes 0, 2, 3, 1) before the forward pass.
        images = images.numpy().transpose(0, 2, 3, 1)
        labels = labels.numpy()

        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += top1_count
        correct_top5 += top5_count
        num_in_batch = images.shape[0]
        total_examples += num_in_batch

        batch_time.update(time.time() - prev_time)
        if batch_index > 0 and batch_index % 20 == 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}] '
                f'Rate: {num_in_batch / batch_time.val:>5.2f}/s ({num_in_batch / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    elapsed = prev_time - start_time
    print(
        f'Validation complete. {total_examples / elapsed:>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return {'top1': float(acc_1), 'top5': float(acc_5)}
def validate(args):
    """Validate a pretrained model on the dataset at ``args.data``.

    Args:
        args: Parsed arguments; this function reads ``model``, ``data`` and
            ``batch_size`` and passes ``vars(args)`` to
            ``resolve_data_config``.

    Returns:
        dict: ``{'top1': float, 'top5': float}`` accuracies in percent.
    """
    model = create_model(args.model, pretrained=True)
    print(f'Created {args.model} model. Validating...')

    def _forward(images, labels):
        return eval_forward(model, images, labels)

    eval_step = objax.Jit(_forward, model.vars())

    dataset = create_dataset('imagenet', args.data)
    data_config = resolve_data_config(vars(args), model=model)
    loader_kwargs = {
        'input_size': data_config['input_size'],
        'batch_size': args.batch_size,
        'use_prefetcher': False,
        'interpolation': data_config['interpolation'],
        'mean': data_config['mean'],
        'std': data_config['std'],
        'num_workers': 8,
        'crop_pct': data_config['crop_pct'],
    }
    loader = create_loader(dataset, **loader_kwargs)

    batch_time = AverageMeter()
    correct_top1 = 0
    correct_top5 = 0
    total_examples = 0
    start_time = time.time()
    prev_time = start_time
    for batch_index, (images, labels) in enumerate(loader):
        images = images.numpy()
        labels = labels.numpy()

        counts = eval_step(images, labels)
        top1_count, top5_count = counts
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        should_log = batch_index > 0 and batch_index % 20 == 0
        if should_log:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
          f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return {'top1': float(acc_1), 'top5': float(acc_5)}
def _predict_feature(self, x, **kwargs):
    """Extract network features for image paths, a DataFrame or a tensor.

    Args:
        x: One of

            * ``str`` - a single image path,
            * ``list`` / ``tuple`` of ``str`` - image paths,
            * ``pandas.DataFrame`` with an ``image`` column of paths,
            * ``torch.Tensor`` - an input batch passed to the network as-is.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        pandas.DataFrame: For path inputs, one row per image with the
        flattened feature in ``image_feature`` and the path in ``image``
        (DataFrame inputs keep their index). An empty list/tuple yields an
        empty frame with those two columns. For tensor input, a single row
        whose ``image_feature`` cell holds the output of ``forward_features``.

    Raises:
        ValueError: If ``x`` has an unsupported type.
        AssertionError: If a DataFrame lacks the ``image`` column, or the
            first element of a list/tuple is not a string.
    """
    def extract(batch):
        # The network may be wrapped so that ``forward_features`` lives on
        # ``.module``. Checking with hasattr (rather than catching
        # AttributeError around the call) keeps errors raised *inside* the
        # forward pass from being misread as "wrapped network".
        net = self.net
        if not hasattr(net, 'forward_features'):
            net = net.module
        return net.forward_features(batch.to(self.ctx[0]))

    if isinstance(x, str):
        return self._predict_feature((x,), **kwargs)

    if isinstance(x, pd.DataFrame):
        assert 'image' in x.columns, "Expect column `image` for input images"
        df = self._predict_feature(tuple(x['image']), **kwargs)
        df = df.set_index(x.index)
        df['image'] = x['image']
        return df

    if isinstance(x, (list, tuple)):
        if not x:
            # Nothing to process; avoid IndexError on x[0] below.
            return pd.DataFrame(columns=['image_feature', 'image'])
        assert isinstance(x[0], str), "expect image paths in list/tuple input"
        loader = create_loader(
            ImageListDataset(x),
            input_size=self._data_cfg.input_size,
            batch_size=self._train_cfg.batch_size,
            use_prefetcher=self._misc_cfg.prefetcher,
            interpolation=self._data_cfg.interpolation,
            mean=self._data_cfg.mean,
            std=self._data_cfg.std,
            num_workers=self._misc_cfg.num_workers,
            crop_pct=self._data_cfg.crop_pct
        )
        self.net.eval()
        results = []
        with torch.no_grad():
            for batch, _ in loader:
                for feature in extract(batch):
                    results.append(
                        {'image_feature': feature.cpu().numpy().flatten()})
        df = pd.DataFrame(results)
        df['image'] = x
        return df

    if not isinstance(x, torch.Tensor):
        raise ValueError('Input is not supported: {}'.format(type(x)))

    # Tensor input: same eval mode and wrapper handling as the path branch.
    self.net.eval()
    with torch.no_grad():
        feature = extract(x)
    return pd.DataFrame([{'image_feature': feature}])
def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512,
             valid_size=None, n_worker=32, resize_scale=0.08,
             distort_color=None, image_size=224, tf_preprocessing=False,
             num_replicas=None, rank=None, use_prefetcher=False,
             pin_memory=False, fp16=False):
    """Create the test loader over ``<save_path>/val`` and store it in ``self.test``.

    Args:
        save_path: Dataset root; the ``val`` sub-folder is loaded.
        test_batch_size: Batch size of the test loader.
        n_worker: Number of loader workers.
        image_size: Passed to ``create_loader`` as ``input_size``.
        tf_preprocessing: Passed to the dataset as ``load_bytes`` and to
            ``create_loader`` as ``tf_preprocessing``.
        use_prefetcher: Forwarded to ``create_loader``.
        pin_memory: Forwarded to ``create_loader``.
        fp16: Forwarded to ``create_loader``.
        train_batch_size, valid_size, resize_scale, distort_color,
        num_replicas, rank: Not used by this constructor.
    """
    # NOTE: this silences warnings process-wide, not only for this object.
    warnings.filterwarnings('ignore')

    dataset = Dataset(os.path.join(save_path, "val"),
                      load_bytes=tf_preprocessing)

    # A throw-away model is created only to resolve the default data config
    # (interpolation, mean, std, crop_pct).
    dummy_model = create_model('efficientnet_b0')
    data_config = resolve_data_config({}, model=dummy_model)

    test_loader = create_loader(
        dataset,
        input_size=image_size,
        batch_size=test_batch_size,
        use_prefetcher=use_prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=n_worker,
        crop_pct=data_config['crop_pct'],
        pin_memory=pin_memory,
        fp16=fp16,
        # Keep the loader consistent with the dataset: the dataset is asked
        # to load raw bytes whenever this flag is set, so the loader must
        # see the same flag instead of a hard-coded None.
        tf_preprocessing=tf_preprocessing)

    self.test = test_loader
def main():
    """Train a timm model under a SparseML recipe and export it to ONNX.

    Steps, in order: parse arguments, set up (optional) distributed and AMP
    state, optionally resolve the initial checkpoint from SparseZoo, build
    the model, optimizer, EMA, datasets and loaders, wrap the optimizer with
    the SparseML manager, run the train/validate loop and finally export the
    model to ONNX when an output directory exists (``local_rank == 0``).
    A ``KeyboardInterrupt`` during training is swallowed so the best metric
    can still be reported.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    torch.manual_seed(args.seed + args.rank)

    ####################################################################################
    # Start - SparseML optional load weights from SparseZoo
    ####################################################################################
    if args.initial_checkpoint == "zoo":
        # Load checkpoint from base weights associated with given SparseZoo recipe
        if args.sparseml_recipe.startswith("zoo:"):
            args.initial_checkpoint = Zoo.download_recipe_base_framework_files(
                args.sparseml_recipe,
                extensions=[".pth.tar", ".pth"]
            )[0]
        else:
            raise ValueError(
                "Attempting to load weights from SparseZoo recipe, but not given a "
                "SparseZoo recipe stub. When initial-checkpoint is set to 'zoo'. "
                "sparseml-recipe must start with 'zoo:' and be a SparseZoo model "
                f"stub. sparseml-recipe was set to {args.sparseml_recipe}"
            )
    elif args.initial_checkpoint and args.initial_checkpoint.startswith("zoo:"):
        # Load weights from a SparseZoo model stub
        zoo_model = Zoo.load_model_from_stub(args.initial_checkpoint)
        framework_files = zoo_model.download_framework_files(extensions=[".pth"])
        # `checkpoint_path` below needs a single path, as in the recipe branch
        # above which takes element [0]; do the same here instead of passing
        # the whole collection of downloaded files.
        if not isinstance(framework_files, str):
            if not framework_files:
                raise ValueError(
                    "No '.pth' framework file found for SparseZoo stub "
                    f"{args.initial_checkpoint}"
                )
            framework_files = framework_files[0]
        args.initial_checkpoint = framework_files
    ####################################################################################
    # End - SparseML optional load weights from SparseZoo
    ####################################################################################

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # A loss is better when smaller; any other metric is treated as
        # better when larger.
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing,
            max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    ####################################################################################
    # Start SparseML Integration
    ####################################################################################
    sparseml_loggers = (
        [PythonLogger(), TensorBoardLogger(log_path=output_dir)]
        if output_dir
        else None
    )
    manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe)
    optimizer = ScheduledOptimizer(
        optimizer,
        model,
        manager,
        steps_per_epoch=len(loader_train),
        loggers=sparseml_loggers
    )
    # override lr scheduler if recipe makes any LR updates
    if any("LearningRate" in str(modifier) for modifier in manager.modifiers):
        _logger.info("Disabling timm LR scheduler, managing LR using SparseML recipe")
        lr_scheduler = None
    if manager.max_epochs:
        _logger.info(
            f"Overriding max_epochs to {manager.max_epochs} from SparseML recipe"
        )
        num_epochs = manager.max_epochs
    ####################################################################################
    # End SparseML Integration
    ####################################################################################

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args,
                    amp_autocast=amp_autocast, log_suffix=' (EMA)')
                # When EMA is evaluated, its metrics replace the raw model's.
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

        #################################################################################
        # Start SparseML ONNX Export
        #################################################################################
        if output_dir:
            _logger.info(
                f"training complete, exporting ONNX to {output_dir}/model.onnx"
            )
            exporter = ModuleExporter(model, output_dir)
            exporter.export_onnx(torch.randn((1, *data_config["input_size"])))
        #################################################################################
        # End SparseML ONNX Export
        #################################################################################

    except KeyboardInterrupt:
        # Deliberate: allow Ctrl-C to stop training and still report below.
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def main():
    """Run inference over ``args.data`` and write top-k class ids to CSV.

    The output file ``<args.output_dir>/topk_ids.csv`` has one line per
    image: the file's base name followed by its ``k`` highest-scoring class
    indices, where ``k = min(args.topk, args.num_classes)``.
    """
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint)

    _logger.info('Model %s created, param count: %d' %
                 (args.model, sum([m.numel() for m in model.parameters()])))

    config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, config)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    loader = create_loader(
        ImageDataset(args.data),
        input_size=config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=config['interpolation'],
        mean=config['mean'],
        std=config['std'],
        num_workers=args.workers,
        crop_pct=1.0 if test_time_pool else config['crop_pct'])

    model.eval()

    k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
    topk_ids = []
    with torch.no_grad():
        for batch_idx, (batch, _) in enumerate(loader):
            batch = batch.cuda()
            labels = model(batch)
            topk = labels.topk(k)[1]
            topk_ids.append(topk.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time))

    # Keep the result two-dimensional (images x k). Squeezing would collapse
    # the array when k == 1 or when there is a single image, and the row-wise
    # write below would then fail.
    if topk_ids:
        topk_ids = np.concatenate(topk_ids, axis=0)
    else:
        topk_ids = np.empty((0, k), dtype=np.int64)

    with open(os.path.join(args.output_dir, 'topk_ids.csv'), 'w') as out_file:
        filenames = loader.dataset.filenames(basename=True)
        for filename, label in zip(filenames, topk_ids):
            # Write every one of the k ids instead of assuming exactly five.
            out_file.write('{0},{1}\n'.format(
                filename, ','.join(str(v) for v in label)))
def main():
    """Build a selected child network, load its checkpoint and validate it.

    The architecture and image size are chosen by ``cfg.NET.SELECTION``. The
    EMA weights are evaluated when ``cfg.NET.EMA.USE`` is set; otherwise the
    resumed model itself is evaluated.

    Raises:
        ValueError: If ``cfg.NET.SELECTION`` is not a supported selection.
        SystemExit: If the validation folder does not exist.
    """
    args, cfg = parse_config_args('child net testing')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH,
        "{}-{}".format(datetime.date.today().strftime('%m%d'), cfg.MODEL))

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, 'test.log'))
        writer = SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None

    # retrain model selection: selection id -> (arch_list, image size)
    selections = {
        470: ([[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1],
               [3, 3, 3, 3], [3, 3, 3, 3], [0]], 224),
        42: ([[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]], 96),
        14: ([[0], [3], [3, 3], [3, 3], [3], [3], [0]], 64),
        112: ([[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]], 160),
        285: ([[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]], 224),
        600: ([[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
               [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]], 224),
    }
    if cfg.NET.SELECTION not in selections:
        raise ValueError("Model Test Selection is not Supported!")
    arch_list, cfg.DATASET.IMAGE_SIZE = selections[cfg.NET.SELECTION]

    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
    choice_block_pool = [
        'ir_r1_k3_s2_e4_c24_se0.25',
        'ir_r1_k5_s2_e4_c40_se0.25',
        'ir_r1_k3_s2_e6_c80_se0.25',
        'ir_r1_k3_s1_e6_c96_se0.25',
        'ir_r1_k3_s2_e6_c192_se0.25'
    ]
    # Stage idx repeats its block once per entry of arch_list[idx + 1].
    arch_def = (
        [[stem[0]]]
        + [[block] * len(arch_list[idx + 1])
           for idx, block in enumerate(choice_block_pool)]
        + [[stem[1]]]
    )

    # generate childnet
    model = gen_childnet(
        arch_list,
        arch_def,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP)

    if args.local_rank == 0:
        macs, params = get_model_flops_params(
            model,
            input_size=(1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
        logger.info('[Model-{}] Flops: {} Params: {}'.format(
            cfg.NET.SELECTION, macs, params))

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info("Training on Process {} with {} GPUs.".format(
            args.local_rank, cfg.NUM_GPU))

    # resume model from checkpoint
    assert cfg.AUTO_RESUME is True and os.path.exists(cfg.RESUME_PATH)
    _, __ = resume_checkpoint(model, cfg.RESUME_PATH)

    model = model.cuda()

    model_ema = None
    if cfg.NET.EMA.USE:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=cfg.NET.EMA.DECAY,
            device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
            resume=cfg.RESUME_PATH)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.exists(eval_dir):
        # Every rank must stop here; only rank 0 owns a logger.
        if args.local_rank == 0:
            logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
        raise SystemExit(1)

    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        num_workers=cfg.WORKERS,
        distributed=True,
        pin_memory=cfg.DATASET.PIN_MEM,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD)

    # Test the model-EMA weights when EMA is enabled. Without EMA there is no
    # ``model_ema`` object, so evaluate the resumed model instead of failing
    # on ``None.ema``.
    if model_ema is not None:
        eval_model, log_suffix = model_ema.ema, '_EMA'
    else:
        eval_model, log_suffix = model, ''

    validate_loss_fn = nn.CrossEntropyLoss().cuda()
    validate(0, eval_model, loader_eval, validate_loss_fn, cfg,
             log_suffix=log_suffix, logger=logger, writer=writer,
             local_rank=args.local_rank)
def main():
    """Retrain a searched child network on ImageNet, validating every epoch.

    Reads the run configuration through ``parse_config_args``, builds the
    child architecture selected by ``cfg.NET.SELECTION``, optionally resumes
    from ``cfg.RESUME_PATH``, wraps the model for distributed training when
    ``cfg.NUM_GPU > 1`` and then alternates ``train_epoch`` / ``validate``.
    When an EMA model is kept on the GPU, its validation metrics replace the
    raw model's metrics for scheduling and reporting.

    Logging and TensorBoard output are only created on local rank 0; on every
    other rank ``logger`` and ``writer`` are ``None``.

    Raises:
        ValueError: if ``cfg.NET.SELECTION`` is not a supported model.
    """
    args, cfg = parse_config_args('child net training')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH,
        "{}-{}".format(datetime.date.today().strftime('%m%d'), cfg.MODEL))

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, 'retrain.log'))
        writer = SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None

    # retrain model selection: each entry fixes the per-stage block choices
    # and the input resolution the model is trained at
    if cfg.NET.SELECTION == 481:
        arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1, 1],
                     [3, 3, 3, 3], [3, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 43:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 96
    elif cfg.NET.SELECTION == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        cfg.DATASET.IMAGE_SIZE = 64
    elif cfg.NET.SELECTION == 114:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 160
    elif cfg.NET.SELECTION == 287:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3],
                     [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 604:
        arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
                     [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    else:
        raise ValueError("Model Retrain Selection is not Supported!")

    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
    choice_block_pool = [
        'ir_r1_k3_s2_e4_c24_se0.25',
        'ir_r1_k5_s2_e4_c40_se0.25',
        'ir_r1_k3_s2_e6_c80_se0.25',
        'ir_r1_k3_s1_e6_c96_se0.25',
        'ir_r1_k5_s2_e6_c192_se0.25'
    ]
    # Stage i repeats its pool entry once per choice listed in arch_list[i + 1];
    # the two stem definitions bracket the searched stages.
    arch_def = [[stem[0]]] + [
        [block_def] * len(arch_list[idx + 1])
        for idx, block_def in enumerate(choice_block_pool)
    ] + [[stem[1]]]

    # generate childnet
    model = gen_childnet(arch_list,
                         arch_def,
                         num_classes=cfg.DATASET.NUM_CLASSES,
                         drop_rate=cfg.NET.DROPOUT_RATE,
                         global_pool=cfg.NET.GP)

    # initialize training parameters
    eval_metric = cfg.EVAL_METRICS
    # NOTE(review): ``saver`` is never instantiated in this function, so the
    # checkpoint-saving branch below never runs and ``best_metric`` stays
    # None -- confirm whether a checkpoint saver was meant to be created here.
    best_metric, best_epoch, saver = None, None, None

    # initialize distributed parameters
    distributed = cfg.NUM_GPU > 1
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process {} with {} GPUs.'.format(
            args.local_rank, cfg.NUM_GPU))

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # get parameters and FLOPs of model
    if args.local_rank == 0:
        macs, params = get_model_flops_params(
            model,
            input_size=(1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
        logger.info('[Model-{}] Flops: {} Params: {}'.format(
            cfg.NET.SELECTION, macs, params))

    # create optimizer
    model = model.cuda()
    optimizer = create_optimizer(cfg, model)

    # optionally resume from a checkpoint
    resume_state, resume_epoch = {}, None
    if cfg.AUTO_RESUME:
        resume_state, resume_epoch = resume_checkpoint(model, cfg.RESUME_PATH)
        optimizer.load_state_dict(resume_state['optimizer'])
    del resume_state

    # The EMA copy is created before the SyncBN / DDP wrappers are applied.
    model_ema = None
    if cfg.NET.EMA.USE:
        model_ema = ModelEma(
            model,
            decay=cfg.NET.EMA.DECAY,
            device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
            resume=cfg.RESUME_PATH if cfg.AUTO_RESUME else None)

    if distributed:
        if cfg.BATCHNORM.SYNC_BN:
            try:
                if HAS_APEX:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logger.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                # Best effort: training continues without SyncBN.
                if args.local_rank == 0:
                    logger.error(
                        'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1 with exception {}'
                        .format(e))
        if HAS_APEX:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logger.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            # can use device str in Torch >= 1.1
            model = DDP(model, device_ids=[args.local_rank])

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir):
        # Every rank must stop here; only rank 0 owns a logger.
        if args.local_rank == 0:
            logger.error(
                'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)
    loader_train = create_loader(dataset_train,
                                 input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                             cfg.DATASET.IMAGE_SIZE),
                                 batch_size=cfg.DATASET.BATCH_SIZE,
                                 is_training=True,
                                 color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
                                 auto_augment=cfg.AUGMENTATION.AA,
                                 num_aug_splits=0,
                                 crop_pct=DEFAULT_CROP_PCT,
                                 mean=IMAGENET_DEFAULT_MEAN,
                                 std=IMAGENET_DEFAULT_STD,
                                 num_workers=cfg.WORKERS,
                                 distributed=distributed,
                                 collate_fn=None,
                                 pin_memory=cfg.DATASET.PIN_MEM,
                                 interpolation='random',
                                 re_mode=cfg.AUGMENTATION.RE_MODE,
                                 re_prob=cfg.AUGMENTATION.RE_PROB)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.exists(eval_dir):
        if args.local_rank == 0:
            logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        interpolation='bicubic',
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=cfg.WORKERS,
        distributed=distributed,
        pin_memory=cfg.DATASET.PIN_MEM)

    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_scheduler(cfg, optimizer)
    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logger.info('Scheduled epochs: {}'.format(num_epochs))

    # ``best_metric`` never changes while ``saver`` is None, so it cannot be
    # used to decide when the CSV header is due; track that explicitly.
    summary_header_written = False

    try:
        best_record, best_ep = 0, 0
        for epoch in range(start_epoch, num_epochs):
            if distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        cfg,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        model_ema=model_ema,
                                        logger=logger,
                                        writer=writer,
                                        local_rank=args.local_rank)

            eval_metrics = validate(epoch,
                                    model,
                                    loader_eval,
                                    validate_loss_fn,
                                    cfg,
                                    logger=logger,
                                    writer=writer,
                                    local_rank=args.local_rank)

            if model_ema is not None and not cfg.NET.EMA.FORCE_CPU:
                ema_eval_metrics = validate(epoch,
                                            model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            cfg,
                                            log_suffix='_EMA',
                                            logger=logger,
                                            writer=writer,
                                            local_rank=args.local_rank)
                # EMA metrics drive the scheduler and the reported best score.
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # All ranks share ``output_dir``; write the summary from rank 0
            # only so rows are not duplicated, and emit the header once.
            if args.local_rank == 0:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=not summary_header_written)
                summary_header_written = True

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    cfg,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric)

            if best_record < eval_metrics[eval_metric]:
                best_record = eval_metrics[eval_metric]
                best_ep = epoch

            if args.local_rank == 0:
                logger.info('*** Best metric: {0} (epoch {1})'.format(
                    best_record, best_ep))

    except KeyboardInterrupt:
        # Allow a manual stop without a traceback.
        pass

    # ``logger`` exists on rank 0 only.
    if best_metric is not None and args.local_rank == 0:
        logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
def main():
    """Train and validate a model selected on the command line.

    Parses the module-level ``parser``, configures single-process or
    distributed (one GPU per process, NCCL) execution, builds the model named
    by ``args.model``, optionally resumes from ``args.resume``, and then
    alternates ``train_epoch`` / ``validate`` for the scheduled epochs. When
    an EMA model is kept on the GPU, its validation metrics replace the raw
    model's metrics for scheduling and checkpoint selection.

    The output directory, checkpoint saver and summary CSV are owned by local
    rank 0; other ranks keep ``output_dir == ''`` and ``saver is None``.
    """
    get_logger("./")
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher

    # WORLD_SIZE is read from the environment; more than one process means
    # distributed mode.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    if args.distributed and args.num_gpu > 1:
        logging.warning(
            'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
        )
        args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    logging.info("Exponential : {}".format(args.model_ema_decay))
    logging.info("Color Jitter : {}".format(args.color_jitter))
    logging.info("Model EMA Decay : {}".format(args.model_ema_decay))

    # Offset the seed by the global rank so each process differs.
    torch.manual_seed(args.seed + args.rank)

    # SECURITY NOTE(review): ``eval`` executes the command-line string as
    # Python to look up the model factory. Acceptable only while the CLI is
    # trusted; consider an explicit name -> factory mapping.
    model = eval(args.model)(config_path=args.config_path,
                             target_flops=args.target_flops,
                             num_classes=args.num_classes,
                             bn_momentum=args.bn_momentum,
                             activation=args.activation,
                             se=args.se)

    if os.path.exists(args.initial_checkpoint):
        load_checkpoint(model, args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if args.resume:
        optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    logging.info(args.weight_decay)
    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state["optimizer"])

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception:
                # Best effort: training continues without SyncBN, but keep the
                # traceback so the cause is not lost.
                logging.exception(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            # can use device str in Torch >= 1.1
            model = DDP(model, device_ids=[args.local_rank])
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    if args.lmdb:
        train_dir = os.path.join(args.data, 'train_lmdb', 'train.lmdb')
        dataset_train = ImageFolderLMDB(train_dir, None, None)
    else:
        train_dir = os.path.join(args.data, 'train')
        dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        color_jitter=args.color_jitter,
        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    if args.lmdb:
        eval_dir = os.path.join(args.data, 'test_lmdb', 'test.lmdb')
        dataset_eval = ImageFolderLMDB(eval_dir, None, None)
    else:
        eval_dir = os.path.join(args.data, 'val')
        dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # A loss is better when it goes down; every other metric when it
        # goes up.
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            eval_metrics = validate(model, loader_eval, validate_loss_fn,
                                    args)

            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                # EMA metrics drive the scheduler and checkpoint selection.
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # Only the rank that owns an output directory writes the summary;
            # other ranks have output_dir == '' and would otherwise append a
            # header-per-epoch 'summary.csv' to the current directory.
            if output_dir:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        # Allow a manual stop without a traceback.
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
def train_imagenet_dq():
    """Train and validate an ImageNet model on an XLA device.

    Parses the module-level ``parser``, builds the model with
    ``create_model`` on ``xm.xla_device()``, optionally builds a teacher
    model for knowledge distillation (``args.KD_train``), creates the
    train/val loaders and then alternates the nested ``train_epoch`` and
    ``validate`` helpers for the scheduled epochs.

    Both nested helpers iterate their loader as ``(batch_idx, (input,
    target))`` pairs, which is what ``dp.ParallelLoader(...)
    .per_device_loader(device)`` is used for in the driver loop, so every
    loader is wrapped that way and its length is passed as ``loader_len``.

    The output directory, checkpoint saver and summary CSV are owned by local
    rank 0; other ranks keep ``output_dir == ''`` and ``saver is None``.
    """
    setup_default_logging()
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    if args.distributed and args.num_gpu > 1:
        logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
        args.num_gpu = 1

    args.device = xm.xla_device()
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    device = xm.xla_device()
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        drop_connect_rate=0.2,
        checkpoint_path=args.initial_checkpoint,
        args=args).to(device)

    flops, params = get_model_complexity_info(
        model, (3, 224, 224), as_strings=True, print_per_layer_stat=args.display_info)
    print('Flops: ' + flops)
    print('Params: ' + params)

    # Bound unconditionally so the training call below can always pass it.
    teacher_model = None
    if args.KD_train:
        teacher_model = create_model(
            "efficientnet_b7_dq",
            pretrained=True,
            num_classes=args.num_classes,
            drop_rate=args.drop,
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            drop_connect_rate=0.2,
            checkpoint_path=args.initial_checkpoint,
            args=args)
        flops_teacher, params_teacher = get_model_complexity_info(
            teacher_model, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
        print("Using KD training...")
        print("FLOPs of teacher model: ", flops_teacher)
        print("Params of teacher model: ", params_teacher)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)

    # NOTE(review): this overrides the rank-dependent seed set above with a
    # fixed value. It is placed at function level here; confirm it was not
    # meant to run only when resuming.
    torch.manual_seed(42)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
            args.amp = False
        # The student model is already on the XLA device; only the teacher is
        # wrapped here.
        if args.KD_train:
            teacher_model = nn.DataParallel(teacher_model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        # NOTE(review): the teacher is moved to CUDA while the student lives
        # on the XLA device -- confirm the intended device for KD training.
        if args.KD_train:
            teacher_model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        # The EMA wrapper is built around a second, freshly created model
        # instance rather than around ``model`` itself.
        model_e = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            drop_rate=args.drop,
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            drop_connect_rate=0.2,
            checkpoint_path=args.initial_checkpoint,
            args=args).to(device)
        model_ema = ModelEma(
            model_e,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logging.info('Converted model to use Synchronized BatchNorm.')
            except Exception:
                # Best effort: training continues without SyncBN, but keep the
                # traceback so the cause is not lost.
                logging.exception('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
    if args.auto_augment:
        print('using auto data augumentation...')
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation='bicubic',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        use_auto_aug=args.auto_augment,
        use_mixcut=args.mixcut,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        logging.error('Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy()
        validate_loss_fn = nn.CrossEntropyLoss()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
        validate_loss_fn = nn.CrossEntropyLoss()
    else:
        train_loss_fn = nn.CrossEntropyLoss()
        validate_loss_fn = train_loss_fn
    if args.KD_train:
        # Distillation replaces the training criterion; the hard-label term
        # is added separately inside train_epoch.
        train_loss_fn = nn.KLDivLoss(reduction='batchmean')

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # A loss is better when it goes down; every other metric when it
        # goes up.
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

    def train_epoch(
            epoch, model, loader, optimizer, loss_fn, args,
            lr_scheduler=None, saver=None, output_dir='', use_amp=False,
            model_ema=None, teacher_model=None, loader_len=0):
        """Run one training epoch and return ``OrderedDict([('loss', avg)])``.

        ``loader`` must yield ``(batch_idx, (input, target))`` pairs and
        ``loader_len`` must be the number of batches it will yield; it is
        used to detect the last batch and to count scheduler updates.
        """
        # NOTE(review): ``loader.mixup_enabled`` is read on whatever loader
        # is passed in; confirm the per-device loader exposes it.
        if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
            if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
                loader.mixup_enabled = False

        batch_time_m = AverageMeter()
        data_time_m = AverageMeter()
        losses_m = AverageMeter()

        model.train()
        if args.KD_train:
            teacher_model.eval()

        end = time.time()
        last_idx = loader_len - 1
        num_updates = epoch * loader_len
        for batch_idx, (input, target) in loader:
            last_batch = batch_idx == last_idx
            data_time_m.update(time.time() - end)
            if not args.prefetcher:
                if args.mixup > 0.:
                    lam = 1.
                    if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
                        lam = np.random.beta(args.mixup, args.mixup)
                    input.mul_(lam).add_(1 - lam, input.flip(0))
                    target = mixup_target(target, args.num_classes, lam, args.smoothing)

            r = np.random.rand(1)
            if args.beta > 0 and r < args.cutmix_prob:
                # generate mixed sample
                lam = np.random.beta(args.beta, args.beta)
                rand_index = torch.randperm(input.size()[0])
                target_a = target
                target_b = target[rand_index]
                bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
                input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
                # adjust lambda to exactly match pixel ratio
                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
                # compute output
                input_var = torch.autograd.Variable(input, requires_grad=True)
                target_a_var = torch.autograd.Variable(target_a)
                target_b_var = torch.autograd.Variable(target_b)
                output = model(input_var)
                loss = loss_fn(output, target_a_var) * lam + loss_fn(output, target_b_var) * (1. - lam)
            else:
                # NOTE KD Train is exclusive with mixcut, FIX it later
                output = model(input)
                if args.KD_train:
                    # Run the teacher over the batch in ``teacher_step``
                    # equal slices and concatenate the results.
                    teacher_outputs_tmp = []
                    assert(input.shape[0] % args.teacher_step == 0)
                    step_size = int(input.shape[0] // args.teacher_step)
                    with torch.no_grad():
                        for k in range(0, int(input.shape[0]), step_size):
                            input_tmp = input[k:k + step_size, :, :, :]
                            teacher_outputs_tmp.append(teacher_model(input_tmp))
                    teacher_outputs = torch.cat(teacher_outputs_tmp)
                    alpha = args.KD_alpha
                    T = args.KD_temperature
                    loss = loss_fn(F.log_softmax(output / T, dim=1),
                                   F.softmax(teacher_outputs / T, dim=1)) * (alpha * T * T) + \
                        F.cross_entropy(output, target) * (1. - alpha)
                else:
                    loss = loss_fn(output, target)

            if not args.distributed:
                losses_m.update(loss.item(), input.size(0))

            optimizer.zero_grad()
            if use_amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            # The optimizer step goes through the XLA helper instead of
            # optimizer.step().
            xm.optimizer_step(optimizer)

            if model_ema is not None:
                model_ema.update(model)
            num_updates += 1

            batch_time_m.update(time.time() - end)
            if last_batch or batch_idx % args.log_interval == 0:
                lrl = [param_group['lr'] for param_group in optimizer.param_groups]
                lr = sum(lrl) / len(lrl)

                if args.distributed:
                    reduced_loss = reduce_tensor(loss.data, args.world_size)
                    losses_m.update(reduced_loss.item(), input.size(0))

                if args.local_rank == 0:
                    logging.info(
                        'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                        'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'LR: {lr:.3e} '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                            epoch,
                            batch_idx, loader_len,
                            # max() avoids dividing by zero for a one-batch loader
                            100. * batch_idx / max(last_idx, 1),
                            loss=losses_m,
                            batch_time=batch_time_m,
                            rate=input.size(0) * args.world_size / batch_time_m.val,
                            rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                            lr=lr,
                            data_time=data_time_m))

                    if args.save_images and output_dir:
                        torchvision.utils.save_image(
                            input,
                            os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                            padding=0,
                            normalize=True)

            if saver is not None and args.recovery_interval and (
                    last_batch or (batch_idx + 1) % args.recovery_interval == 0):
                save_epoch = epoch + 1 if last_batch else epoch
                saver.save_recovery(
                    model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)

            if lr_scheduler is not None:
                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

            end = time.time()

        return OrderedDict([('loss', losses_m.avg)])

    def validate(model, loader, loss_fn, args, log_suffix='', loader_len=0):
        """Evaluate ``model`` and return average loss, prec1 and prec5.

        ``loader`` must yield ``(batch_idx, (input, target))`` pairs and
        ``loader_len`` must be the number of batches it will yield.
        """
        batch_time_m = AverageMeter()
        losses_m = AverageMeter()
        prec1_m = AverageMeter()
        prec5_m = AverageMeter()

        model.eval()

        end = time.time()
        last_idx = loader_len - 1
        with torch.no_grad():
            for batch_idx, (input, target) in loader:
                last_batch = batch_idx == last_idx

                output = model(input)
                if isinstance(output, (tuple, list)):
                    output = output[0]

                # augmentation reduction
                reduce_factor = args.tta
                if reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]

                loss = loss_fn(output, target)
                prec1, prec5 = accuracy(output, target, topk=(1, 5))

                if args.distributed:
                    reduced_loss = reduce_tensor(loss.data, args.world_size)
                    prec1 = reduce_tensor(prec1, args.world_size)
                    prec5 = reduce_tensor(prec5, args.world_size)
                else:
                    reduced_loss = loss.data

                losses_m.update(reduced_loss.item(), input.size(0))
                prec1_m.update(prec1.item(), output.size(0))
                prec5_m.update(prec5.item(), output.size(0))

                batch_time_m.update(time.time() - end)
                end = time.time()
                if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                    log_name = 'Test' + log_suffix
                    logging.info(
                        '{0}: [{1:>4d}/{2}] '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                        'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                            log_name, batch_idx, last_idx,
                            batch_time=batch_time_m, loss=losses_m,
                            top1=prec1_m, top5=prec5_m))

        metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])

        return metrics

    try:
        for epoch in range(start_epoch, num_epochs):
            loader_len = len(loader_train)
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            # Always feed train_epoch through the per-device loader: it
            # unpacks (batch_idx, (input, target)) and needs loader_len, for
            # KD training as well as plain training.
            train_para_loader = dp.ParallelLoader(loader_train, [device])
            train_metrics = train_epoch(
                epoch, model, train_para_loader.per_device_loader(device), optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                use_amp=use_amp, model_ema=model_ema, teacher_model=teacher_model,
                loader_len=loader_len)

            loader_len_val = len(loader_eval)
            para_loader = dp.ParallelLoader(loader_eval, [device])
            eval_metrics = validate(model, para_loader.per_device_loader(device), validate_loss_fn, args,
                                    loader_len=loader_len_val)

            if model_ema is not None and not args.model_ema_force_cpu:
                # A fresh per-device loader is built for the EMA pass, for
                # the same reason as above.
                ema_para_loader = dp.ParallelLoader(loader_eval, [device])
                ema_eval_metrics = validate(
                    model_ema.ema, ema_para_loader.per_device_loader(device), validate_loss_fn, args,
                    log_suffix=' (EMA)', loader_len=loader_len_val)
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            # Only the rank that owns an output directory writes the summary;
            # other ranks have output_dir == '' and would otherwise append a
            # header-per-epoch 'summary.csv' to the current directory.
            if output_dir:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, args,
                    epoch=epoch + 1,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        # Allow a manual stop without a traceback.
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def validate(args):
    """Evaluate a model after FX graph-mode post-training quantization.

    Builds the model and the data loaders from ``args``, quantizes a deep
    copy of the model with ``torch.quantization.quantize_fx`` and measures
    loss / top-1 / top-5 of the *quantized* model on the evaluation loader.
    The model is never moved to a GPU in this function.

    Quantization modes (``args.quant_option``):
      * ``'static'``  -- observers are inserted and calibrated by running
        the calibration loader through the prepared model, then converted.
      * ``'dynamic'`` -- dynamic / weight-only quantization, no calibration.
      * anything else -- a warning is logged and ``'static'`` is used, as
        the warning message states.

    Args:
        args: parsed command-line namespace. ``args.pretrained``,
            ``args.prefetcher`` and (when it is ``None``)
            ``args.num_classes`` are updated in place.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    # Parameter count is taken from the float model, before quantization.
    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    criterion = nn.CrossEntropyLoss()

    dataset_kwargs = dict(
        root=args.data,
        name=args.dataset,
        split=args.split,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map)
    dataset = create_dataset(**dataset_kwargs)
    # separate dataset object used only for post-training-quantization calibration
    calib_dataset = create_dataset(**dataset_kwargs)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader_kwargs = dict(
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)
    loader = create_loader(dataset, **loader_kwargs)
    # loader that feeds the observers during static calibration
    calib_loader = create_loader(calib_dataset, **loader_kwargs)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    print('Start calibration of quantization observers before post-quantization')
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()

    quant_option = args.quant_option
    if quant_option not in ('static', 'dynamic'):
        _logger.warning("Invalid quantization option. Set option to default(static)")
        quant_option = 'static'

    if quant_option == 'static':
        # post training static quantization
        qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # Calibrate: only forward passes are needed so the observers can
        # record activation ranges. The evaluation meters are deliberately
        # left untouched here so calibration cannot skew reported metrics.
        with torch.no_grad():
            for images, _ in calib_loader:
                if args.channels_last:
                    images = images.contiguous(memory_format=torch.channels_last)
                model_prepared(images)
    else:
        # post training dynamic/weight only quantization: no calibration needed
        qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)

    # Evaluate the quantized model from here on.
    # NOTE(review): the converted model runs on the active
    # torch.backends.quantized.engine -- confirm it matches the 'qnnpack'
    # qconfig used above on the target machine.
    model = quantize_fx.convert_fx(model_prepared)

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        images = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
        if args.channels_last:
            images = images.contiguous(memory_format=torch.channels_last)
        model(images)
        end = time.time()
        for batch_idx, (images, target) in enumerate(loader):
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            # compute output
            output = model(images)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=images.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
def validate(args):
    """Run one evaluation pass on CUDA and return summary metrics.

    The model, dataset and loader are all built from ``args``. Mixed
    precision is enabled through APEX or native autocast when ``args.amp``
    (or the explicit ``apex_amp`` / ``native_amp`` flags) request it.

    Args:
        args: parsed command-line namespace. ``args.pretrained``,
            ``args.prefetcher``, the AMP flags and (when it is ``None``)
            ``args.num_classes`` are updated in place.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # Fall back to pretrained weights when no checkpoint was supplied.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # `--amp` prefers APEX and falls back to native AMP.
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    # `suppress` acts as a do-nothing context manager when native AMP is off.
    amp_autocast = torch.cuda.amp.autocast if args.native_amp else suppress

    if args.legacy_jit:
        set_jit_legacy()

    # ---- model ----
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript,
    )
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    _logger.info('Model %s created, param count: %d', args.model, param_count)

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        gpu_ids = list(range(args.num_gpu))
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)

    criterion = nn.CrossEntropyLoss().cuda()

    # ---- data ----
    dataset = create_dataset(
        root=args.data,
        name=args.dataset,
        split=args.split,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map,
    )

    # Optional boolean mask selecting which logit columns are kept.
    valid_labels = None
    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            kept_ids = {int(line.rstrip()) for line in f}
            valid_labels = [idx in kept_ids for idx in range(args.num_classes)]

    real_labels = None
    if args.real_labels:
        real_labels = RealLabelsImagenet(
            dataset.filenames(basename=True), real_json=args.real_labels)

    if test_time_pool:
        crop_pct = 1.0
    else:
        crop_pct = data_config['crop_pct']

    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
    )

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # ---- evaluation ----
    model.eval()
    with torch.no_grad():
        # One dummy forward pass so the first timed batch is not an outlier.
        warmup = torch.randn((args.batch_size, ) + data_config['input_size']).cuda()
        if args.channels_last:
            warmup = warmup.contiguous(memory_format=torch.channels_last)
        model(warmup)

        tick = time.time()
        for step, (images, labels) in enumerate(loader):
            if args.no_prefetcher:
                labels = labels.cuda()
                images = images.cuda()
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(images)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, labels)

            if real_labels is not None:
                real_labels.add_result(output)

            # accumulate loss / accuracy weighted by batch size
            n = images.size(0)
            acc1, acc5 = accuracy(output.detach(), labels, topk=(1, 5))
            losses.update(loss.item(), n)
            top1.update(acc1.item(), n)
            top5.update(acc5.item(), n)

            now = time.time()
            batch_time.update(now - tick)
            tick = time.time()

            if step % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        step,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=images.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    if real_labels is None:
        top1a, top5a = top1.avg, top5.avg
    else:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)

    results = OrderedDict(
        top1=round(top1a, 4),
        top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4),
        top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'],
    )

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
filenames = [os.path.splitext(f)[0] for f in dataset.filenames()] # get appropriate transform for model's default pretrained config data_config = resolve_data_config(m['args'], model=model, verbose=True) test_time_pool = False if m['ttp']: model, test_time_pool = apply_test_time_pool(model, data_config) data_config['crop_pct'] = 1.0 batch_size = m['batch_size'] loader = create_loader( dataset, input_size=data_config['input_size'], batch_size=batch_size, use_prefetcher=True, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=6, crop_pct=data_config['crop_pct'], pin_memory=True) evaluator = ImageNetEvaluator( root=DATA_ROOT, model_name=m['paper_model_name'], paper_arxiv_id=m['paper_arxiv_id'], model_description=m.get('model_description', None), ) model.cuda() model.eval() with torch.no_grad():
def main():
    """Training entry point.

    Steps, in order:
      1. Parse arguments and set up (optionally distributed, NCCL) CUDA
         execution.
      2. Resolve which AMP implementation to use (``'apex'``, ``'native'``
         or none).
      3. Create the model, optionally convert it to split-BN / SyncBN /
         TorchScript, create the optimizer, AMP scaler, EMA copy and DDP
         wrapper.
      4. Build the train, eval and calibration data loaders and the loss
         functions.
      5. On local rank 0 only, create the output directory, snapshot the
         working directory and create the checkpoint saver.
      6. Run the epoch loop (train -> optional calibration pass -> eval ->
         LR step -> summary / checkpoint). With ``args.eval_only`` a single
         evaluation pass is run and the loop stops.

    ``KeyboardInterrupt`` during the loop is swallowed so that the best
    metric can still be reported.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")

    # per-rank seed so each process draws a different random stream
    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        use_cos_reg=args.cos_reg_component > 0,
        checkpoint_path=args.initial_checkpoint)

    # NOTE(review): `input` and `size_for_madd` are only referenced by the
    # commented-out complexity report below and are otherwise unused in this
    # function; `input` also shadows the builtin.
    with torch.cuda.device(0):
        input = torch.randn(1, 3, 224, 224)
        size_for_madd = 224 if args.img_size is None else args.img_size
        # flops, params = get_model_complexity_info(model, (3, size_for_madd, size_for_madd), as_strings=True, print_per_layer_stat=True)
        # print("=>Flops: " + flops)
        # print("=>Params: " + params)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[
                args.local_rank
            ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    # NOTE(review): the folder checks below run even when args.use_lmdb is
    # set, and the LMDB locations are hard-coded relative paths.
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    if args.use_lmdb:
        dataset_train = ImageFolderLMDB('../dataset_lmdb/train')
    else:
        dataset_train = Dataset(train_dir)
    # dataset_train = Dataset(train_dir)

    # fall back from 'val' to 'validation' before giving up
    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    if args.use_lmdb:
        dataset_eval = ImageFolderLMDB('../dataset_lmdb/val')
    else:
        dataset_eval = Dataset(eval_dir)
    # dataset_eval = Dataset(eval_dir)

    # setup mixup / cutmix
    # With the prefetcher, mixing is done in the collate function; otherwise
    # a Mixup callable is handed to train_epoch via `mixup_fn`.
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        repeated_aug=args.use_repeated_aug,
        world_size=args.world_size,
        rank=args.rank)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # Calibration loader: iterates the *training* dataset with
    # is_training=False, no_aug=True and no mixup collate; it is consumed by
    # validate(..., use_bn_calibration=True) in the epoch loop.
    loader_cali = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.cali_batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        no_aug=True,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=None,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        repeated_aug=args.use_repeated_aug,
        world_size=args.world_size,
        rank=args.rank)

    # setup loss function
    # NOTE(review): args.use_cos_reg_component is only assigned in the mixup
    # branch; confirm nothing reads it when JSD / smoothing / plain CE is used.
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        if args.cos_reg_component > 0:
            args.use_cos_reg_component = True
            train_loss_fn = SoftTargetCrossEntropyCosReg(
                n_comn=args.cos_reg_component).cuda()
        else:
            train_loss_fn = SoftTargetCrossEntropy().cuda()
            args.use_cos_reg_component = False
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # snapshot the current working directory next to the results
        code_dir = get_outdir(output_dir, 'code')
        copy_tree(os.getcwd(), code_dir)
        # a lower value is better only when the tracked metric is the loss
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            if not args.eval_only:
                train_metrics = train_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            amp_autocast=amp_autocast,
                                            loss_scaler=loss_scaler,
                                            model_ema=model_ema,
                                            mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            # Optional calibration pass over loader_cali; its metrics are
            # discarded.
            # NOTE(review): presumably args.max_iter bounds this pass --
            # confirm in validate().
            if args.max_iter > 0:
                _ = validate(model,
                             loader_cali,
                             validate_loss_fn,
                             args,
                             amp_autocast=amp_autocast,
                             use_bn_calibration=True)

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            # When an EMA model is evaluated, its metrics replace the raw
            # model's metrics for LR scheduling, summary and checkpointing.
            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # train_metrics only exists when training actually ran
            if not args.eval_only:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)

            # eval-only mode: one evaluation pass, then stop
            if args.eval_only:
                break

    except KeyboardInterrupt:
        # allow Ctrl-C to end training and still report the best metric
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
def main():
    """Training entry point with optional knowledge-distillation (KD).

    Steps, in order:
      1. Parse arguments and set up single-process (optionally
         multi-GPU ``nn.DataParallel``) or distributed (NCCL) execution.
      2. Create the student model and print its FLOPs / parameter counts;
         with ``args.KD_train`` also create an ``efficientnet_b7_dq``
         teacher.
      3. Optionally resume, move models to GPU, create the optimizer,
         APEX AMP, EMA copy, SyncBN and DDP wrapper.
      4. Build train / eval loaders and loss functions (``KLDivLoss``
         replaces the training loss when ``args.KD_train`` is set).
      5. Run the epoch loop: train -> eval -> LR step -> summary ->
         checkpoint.

    ``KeyboardInterrupt`` during the loop is swallowed so that the best
    metric can still be reported.
    """
    setup_default_logging()
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    # per-rank seed so each process draws a different random stream
    torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         drop_connect_rate=0.2,
                         checkpoint_path=args.initial_checkpoint,
                         args=args)

    # complexity report is computed for a fixed 3x224x224 input
    flops, params = get_model_complexity_info(
        model, (3, 224, 224),
        as_strings=True,
        print_per_layer_stat=args.display_info)
    print('Flops:  ' + flops)
    print('Params: ' + params)

    if args.KD_train:
        # NOTE(review): the teacher is created with pretrained=True *and* the
        # same checkpoint_path as the student -- confirm this is intended.
        teacher_model = create_model("efficientnet_b7_dq",
                                     pretrained=True,
                                     num_classes=args.num_classes,
                                     drop_rate=args.drop,
                                     global_pool=args.gp,
                                     bn_tf=args.bn_tf,
                                     bn_momentum=args.bn_momentum,
                                     bn_eps=args.bn_eps,
                                     drop_connect_rate=0.2,
                                     checkpoint_path=args.initial_checkpoint,
                                     args=args)

        flops_teacher, params_teacher = get_model_complexity_info(
            teacher_model, (3, 224, 224),
            as_strings=True,
            print_per_layer_stat=False)

        print("Using KD training...")
        print("FLOPs of teacher model: ", flops_teacher)
        print("Params of teacher model: ", params_teacher)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(model,
                                      args,
                                      verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(
            model, args.resume, args.start_epoch)
    # import pdb;pdb.set_trace()
    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
        if args.KD_train:
            teacher_model = nn.DataParallel(teacher_model,
                                            device_ids=list(range(
                                                args.num_gpu))).cuda()
    else:
        model.cuda()
        if args.KD_train:
            teacher_model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    # AMP is only available through APEX in this script
    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        # import pdb; pdb.set_trace()
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                # NOTE(review): the exception is logged without its details
                # (`e` is unused) and training continues without SyncBN.
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    # NOTE(review): no `is not None` guard here, unlike the check inside the
    # epoch loop below -- confirm create_scheduler never returns None.
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    # with the prefetcher, mixup is applied in the collate function
    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)
    if args.auto_augment:
        print('using auto data augumentation...')
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation=
        'bicubic',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        use_auto_aug=args.auto_augment,
        use_mixcut=args.mixcut,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        logging.error(
            'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)

    # evaluation uses a batch four times the training batch size
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
    # KD overrides whichever training loss was selected above; the
    # validation loss is left unchanged.
    if args.KD_train:
        train_loss_fn = nn.KLDivLoss(reduction='batchmean').cuda()

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # a lower value is better only when the tracked metric is the loss
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
    try:
        # import pdb;pdb.set_trace()
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            # import pdb; pdb.set_trace()
            # the teacher is only passed to train_epoch in KD mode
            if args.KD_train:
                train_metrics = train_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            use_amp=use_amp,
                                            model_ema=model_ema,
                                            teacher_model=teacher_model)
            else:
                train_metrics = train_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            use_amp=use_amp,
                                            model_ema=model_ema)

            # def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32,
            #              overrides=None, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
            #              per_channel_wts=False, model_activation_stats=None, fp16=False, clip_n_stds=None,
            #              scale_approx_mult_bits=None):
            # import distiller
            # import pdb; pdb.set_trace()
            # quantizer = quantization.PostTrainLinearQuantizer.from_args(model, args)
            # quantizer.prepare_model(distiller.get_dummy_input(input_shape=model.input_shape))
            # quantizer = distiller.quantization.PostTrainLinearQuantizer(model, bits_activations=8, bits_parameters=8)
            # quantizer.prepare_model()

            # distiller.utils.assign_layer_fq_names(model)
            # # msglogger.info("Generating quantization calibration stats based on {0} users".format(args.qe_calibration))
            # collector = distiller.data_loggers.QuantCalibrationStatsCollector(model)
            # with collector_context(collector):
            #     eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
            #     # Here call your model evaluation function, making sure to execute only
            #     # the portion of the dataset specified by the qe_calibration argument
            # yaml_path = './dir/quantization_stats.yaml'
            # collector.save(yaml_path)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            # When an EMA model is evaluated, its metrics replace the raw
            # model's metrics for LR scheduling, summary and checkpointing.
            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch + 1,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        # allow Ctrl-C to end training and still report the best metric
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
def validate(args):
    """Evaluate a classification model on the dataset found at ``args.data``.

    The model named by ``args.model`` is created (and ``args.checkpoint`` is
    loaded into it when given), moved to the GPU, and run once over the data
    under ``torch.no_grad()`` while loss and top-1 / top-5 precision are
    accumulated.

    Args:
        args: Parsed options. Attributes read here are ``model``,
            ``num_classes``, ``pretrained``, ``checkpoint``, ``use_ema``,
            ``num_gpu``, ``data``, ``tf_preprocessing``, ``batch_size``,
            ``workers`` and ``log_freq``. ``args.pretrained`` is overwritten
            in place.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``
        (rounded to 3 decimals) and ``param_count`` (in millions, rounded to
        2 decimals).
    """
    # With no checkpoint to load, fall back to pretrained weights so that
    # there is something meaningful to validate.
    args.pretrained = args.pretrained or not args.checkpoint

    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
    )
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(model, args)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        gpu_ids = list(range(args.num_gpu))
        model = torch.nn.DataParallel(model, device_ids=gpu_ids).cuda()
    else:
        model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)
    loader_kwargs = dict(
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        # Test-time pooling sees the full image, so no centre-crop shrink.
        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
        tf_preprocessing=args.tf_preprocessing,
    )
    loader = create_loader(dataset, **loader_kwargs)

    timer = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    progress_fmt = (
        'Test: [{0:>4d}/{1}] '
        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
        'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
        'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'
    )

    model.eval()
    tick = time.time()
    with torch.no_grad():
        for step, (images, labels) in enumerate(loader):
            labels = labels.cuda()
            images = images.cuda()

            # Forward pass and loss.
            logits = model(images)
            loss = criterion(logits, labels)

            # Fold this batch into the running statistics.
            prec1, prec5 = accuracy(logits.data, labels, topk=(1, 5))
            n_samples = images.size(0)
            loss_meter.update(loss.item(), n_samples)
            top1_meter.update(prec1.item(), n_samples)
            top5_meter.update(prec5.item(), n_samples)

            # Wall-clock time spent on this batch.
            timer.update(time.time() - tick)
            tick = time.time()

            if step % args.log_freq != 0:
                continue
            logging.info(progress_fmt.format(
                step,
                len(loader),
                batch_time=timer,
                rate_avg=n_samples / timer.avg,
                loss=loss_meter,
                top1=top1_meter,
                top5=top5_meter,
            ))

    top1_avg = top1_meter.avg
    top5_avg = top5_meter.avg
    results = OrderedDict([
        ('top1', round(top1_avg, 3)),
        ('top1_err', round(100 - top1_avg, 3)),
        ('top5', round(top5_avg, 3)),
        ('top5_err', round(100 - top5_avg, 3)),
        ('param_count', round(param_count / 1e6, 2)),
    ])

    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'],
        results['top5'], results['top5_err']))

    return results
def main():
    """Train a slimmable / dynamically gated network, then test each sub-network.

    Stages, in execution order:

    1. Parse arguments; on local rank 0 create the output directory, the
       checkpoint saver and dump ``args.yaml``.
    2. Configure single-process or distributed (NCCL, ``env://``) execution
       and seed ``random``, NumPy and torch.
    3. Build the model, optionally resume it, and create the optimizer for
       the parameter group selected by ``args.train_mode``.
    4. Set up Apex AMP, EMA weights, SyncBN and DDP.
    5. Build the train / BN-recalibration / eval loaders and loss functions.
    6. Unless ``args.test_mode`` is set, run the epoch loop (train, evaluate,
       step the LR scheduler, write the summary, checkpoint).
    7. Evaluate every sub-network ``choice`` in ``range(args.num_choice)``.
       For all but the first and last choice the BatchNorm running statistics
       are reset and re-estimated on ``loader_bn`` first.

    Raises:
        ValueError: if ``args.train_mode`` is not one of ``'se'``, ``'bn'``,
            ``'all'`` or ``'gate'``.
        SystemExit: with status 1 when the training or validation folder
            does not exist.
    """
    import os

    args, args_text = _parse_args()
    is_local_master = args.local_rank == 0

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if is_local_master:
        output_base = args.output if args.output else './output'
        exp_name = 'train'
        if args.gate_train:
            exp_name += '-dynamic'
        if args.slim_train:
            exp_name += '-slimmable'
        exp_name += '-{}'.format(args.model)
        exp_info = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, exp_name, exp_info)
        # A loss is better when smaller; every other metric when larger.
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
    setup_default_logging(outdir=output_dir, local_rank=args.local_rank)

    torch.backends.cudnn.benchmark = True

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    if args.distributed and args.num_gpu > 1:
        logging.warning(
            'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
        )
        args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.'
                     % args.num_gpu)

    # --------- random seed -----------
    # TODO: do we need same seed on all GPU?
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         drop_path_rate=args.drop_path,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)

    if is_local_master:
        logging.info('Model %s created, param count: %d' %
                     (args.model,
                      sum(m.numel() for m in model.parameters())))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=is_local_master)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    # Pick the parameter group to optimise. An unknown mode used to fall
    # through silently and crash later with a NameError on `optimizer`.
    if args.train_mode == 'se':
        optimizer = create_optimizer(args, model.get_se())
    elif args.train_mode == 'bn':
        optimizer = create_optimizer(args, model.get_bn())
    elif args.train_mode == 'all':
        optimizer = create_optimizer(args, model)
    elif args.train_mode == 'gate':
        optimizer = create_optimizer(args, model.get_gate())
    else:
        raise ValueError(
            "Unknown train_mode {!r}; expected 'se', 'bn', 'all' or 'gate'"
            .format(args.train_mode))

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if is_local_master:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    if resume_state and not args.no_resume_opt:
        # ----------- Load Optimizer ---------
        if 'optimizer' in resume_state:
            if is_local_master:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if is_local_master:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if is_local_master:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception:
                # Best effort: keep training with regular BN, but record the
                # actual failure (traceback) instead of discarding it.
                logging.exception(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if is_local_master:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            # can use device str in Torch >= 1.1
            model = DDP(model,
                        device_ids=[args.local_rank],
                        find_unused_parameters=True)
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if is_local_master:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # ------------- data --------------
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        # `exit` is injected by the `site` module and may be missing;
        # raising SystemExit has the same effect everywhere.
        raise SystemExit(1)
    dataset_train = Dataset(train_dir)

    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train,
                                      num_splits=num_aug_splits)

    # `loader_train` and `loader_bn` share every setting except batch size,
    # so the common keyword arguments are defined once.
    train_loader_kwargs = dict(
        input_size=data_config['input_size'],
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=None,
        pin_memory=args.pin_mem,
    )
    eval_batch_size = args.validation_batch_size_multiplier * args.batch_size

    loader_train = create_loader(dataset_train,
                                 batch_size=args.batch_size,
                                 **train_loader_kwargs)
    # Training-mode loader used after training to re-estimate BN statistics.
    loader_bn = create_loader(dataset_train,
                              batch_size=eval_batch_size,
                              **train_loader_kwargs)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            raise SystemExit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=eval_batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # ------------- loss_fn --------------
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    if args.ieb:
        distill_loss_fn = SoftTargetCrossEntropy().cuda()
    else:
        distill_loss_fn = None

    # Every rank runs the profiling call; only local rank 0 is verbose.
    model_profiling(model, 224, 224, 1, 3,
                    use_cuda=True,
                    verbose=is_local_master)

    if not args.test_mode:
        # start training
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            # train
            if args.gate_train:
                train_metrics = train_epoch_slim_gate(
                    epoch,
                    model,
                    loader_train,
                    optimizer,
                    train_loss_fn,
                    args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step)
            else:
                train_metrics = train_epoch_slim(
                    epoch,
                    model,
                    loader_train,
                    optimizer,
                    loss_fn=train_loss_fn,
                    distill_loss_fn=distill_loss_fn,
                    args=args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step,
                )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if is_local_master:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size,
                              args.dist_bn == 'reduce')

            # eval
            if args.gate_train:
                eval_sample_list = ['dynamic']
            elif epoch % 10 == 0 and epoch != 0:
                # Every 10th epoch (except epoch 0) also evaluate 'uniform'.
                eval_sample_list = ['smallest', 'largest', 'uniform']
            else:
                eval_sample_list = ['smallest', 'largest']

            eval_metrics = [
                validate_slim(model,
                              loader_eval,
                              validate_loss_fn,
                              args,
                              model_mode=model_mode)
                for model_mode in eval_sample_list
            ]

            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = [
                    validate_slim(model_ema.ema,
                                  loader_eval,
                                  validate_loss_fn,
                                  args,
                                  model_mode=model_mode)
                    for model_mode in eval_sample_list
                ]
                eval_metrics = ema_eval_metrics

            # Only the first evaluated mode drives the scheduler, the
            # summary file and checkpoint selection.
            if isinstance(eval_metrics, list):
                eval_metrics = eval_metrics[0]

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # save
            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)
        # end training
        if best_metric is not None:
            logging.info('*** Best metric: {0} (epoch {1})'.format(
                best_metric, best_epoch))

    # test
    bn_types = (nn.BatchNorm2d, nn.SyncBatchNorm)
    if has_apex:
        bn_types += (apex.parallel.SyncBatchNorm,)

    eval_metrics = []
    for choice in range(args.num_choice):
        # reset bn if not smallest or largest
        if choice != 0 and choice != args.num_choice - 1:
            for layer in model.modules():
                if isinstance(layer, bn_types):
                    layer.reset_running_stats()
            # Train mode so that forward passes update the running stats.
            model.train()
            with torch.no_grad():
                for batch_idx, (input, _) in enumerate(loader_bn):
                    if args.slim_train:
                        if hasattr(model, 'module'):
                            model.module.set_mode('uniform', choice=choice)
                        else:
                            model.set_mode('uniform', choice=choice)
                    model(input)

                    # Stop once batch index 1000 has been processed.
                    if batch_idx % 1000 == 0 and batch_idx != 0:
                        print('Subnet {} : reset bn for {} steps'.format(
                            choice, batch_idx))
                        break
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if is_local_master:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size,
                              args.dist_bn == 'reduce')
        eval_metrics.append(
            validate_slim(model,
                          loader_eval,
                          validate_loss_fn,
                          args,
                          model_mode=choice))
    if is_local_master:
        print('Test results of the last epoch:\n', eval_metrics)
def validate(args):
    """Measure clean and adversarial (FGM and PGD) accuracy of a model.

    For every batch the model is evaluated on the unmodified inputs and on
    adversarial examples produced by ``fast_gradient_method`` and
    ``projected_gradient_descent`` (both L-inf, budget ``args.eps``; PGD uses
    40 iterations with step size 0.01). Top-1 / top-5 accuracy is tracked
    separately for the three cases.

    Args:
        args: Parsed options. ``args.pretrained`` and ``args.prefetcher`` are
            overwritten; ``args.native_amp`` / ``args.apex_amp`` may be set
            when ``args.amp`` is given, and ``args.num_classes`` is filled in
            from the model when it is ``None``.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``top1_fgm_ae``, ``top5_fgm_ae``, ``top1_pgd_ae``, ``top5_pgd_ae``,
        ``param_count``, ``img_size``, ``cropt_pct`` and ``interpolation``.

    Raises:
        NotImplementedError: when ``args.real_labels`` is set. Real-label
            scoring is not implemented for adversarial examples; the error
            is raised only after the evaluation loop has finished.
        AssertionError: when both AMP modes are requested, or when
            ``args.num_classes`` is ``None`` and the model has no
            ``num_classes`` attribute.
    """
    _logger.info(f'\n\n ---------------EVALUATION {args.eps}------------------------------- \n\n')
    _logger.info("Argument parser collected the following arguments:")
    for arg in vars(args):
        _logger.info(f" {arg}:{getattr(args, arg)}")
    _logger.info("\n")

    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX or Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(m.numel() for m in model.parameters())
    _logger.info(
        f'Model {args.model} created, param count: {param_count} ({(float(param_count)/(10.0**6)):.1f} M)'
    )

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      use_test_size=True,
                                      verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model,
                                                     data_config,
                                                     use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data_dir,
        name=args.dataset,
        split=args.split,
        load_bytes=args.tf_preprocessing,
        class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            # Boolean mask over all classes; used to select logit columns.
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_fgm_ae = AverageMeter()
    top5_fgm_ae = AverageMeter()
    top1_pgd_ae = AverageMeter()
    top5_pgd_ae = AverageMeter()

    model.eval()
    # The loop as a whole cannot sit under torch.no_grad(): the attack calls
    # need gradients. Passes that only feed metrics are wrapped individually.

    # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
    input = torch.randn((args.batch_size,) +
                        tuple(data_config['input_size'])).cuda()
    if args.channels_last:
        input = input.contiguous(memory_format=torch.channels_last)
    with torch.no_grad():
        model(input)

    end = time.time()
    for batch_idx, (input, target) in enumerate(loader):
        if args.no_prefetcher:
            target = target.cuda()
            input = input.cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        # Clean forward pass. `output` and `loss` are only read for metrics
        # below, so no autograd graph is built for them.
        with torch.no_grad():
            with amp_autocast():
                output = model(input)
            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

        if real_labels is not None:
            real_labels.add_result(output)

        # Generate adversarial examples for current inputs
        input_fgm_ae = fast_gradient_method(
            model_fn=model,
            x=input,
            eps=args.eps,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        input_pgd_ae = projected_gradient_descent(
            model_fn=model,
            x=input,
            eps=args.eps,
            eps_iter=0.01,
            nb_iter=40,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )

        # Predict with Adversarial Examples
        with torch.no_grad(), amp_autocast():
            output_fgm_ae = model(input_fgm_ae)
            output_pgd_ae = model(input_pgd_ae)
        if valid_labels is not None:
            # Apply the same class mask as for the clean logits, so that
            # all three outputs are scored against `target` in the same
            # class index space.
            output_fgm_ae = output_fgm_ae[:, valid_labels]
            output_pgd_ae = output_pgd_ae[:, valid_labels]

        # measure accuracy and record loss
        batch_size = input.size(0)
        acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(acc1.item(), batch_size)
        top5.update(acc5.item(), batch_size)

        acc1_fgm_ae, acc5_fgm_ae = accuracy(output_fgm_ae.detach(),
                                            target,
                                            topk=(1, 5))
        acc1_pgd_ae, acc5_pgd_ae = accuracy(output_pgd_ae.detach(),
                                            target,
                                            topk=(1, 5))
        top1_fgm_ae.update(acc1_fgm_ae.item(), batch_size)
        top5_fgm_ae.update(acc5_fgm_ae.item(), batch_size)
        top1_pgd_ae.update(acc1_pgd_ae.item(), batch_size)
        top5_pgd_ae.update(acc5_pgd_ae.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.log_freq == 0:
            _logger.info(
                'Test: [{0:>4d}/{1}] '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx,
                    len(loader),
                    batch_time=batch_time,
                    rate_avg=batch_size / batch_time.avg,
                    loss=losses,
                    top1=top1,
                    top5=top5))

    if real_labels is not None:
        # TODO: real-labels mode is not adapted to the adversarial metrics.
        raise NotImplementedError
    top1a, top5a = top1.avg, top5.avg
    top1a_fgm_ae, top5a_fgm_ae = top1_fgm_ae.avg, top5_fgm_ae.avg
    top1a_pgd_ae, top5a_pgd_ae = top1_pgd_ae.avg, top5_pgd_ae.avg

    results = OrderedDict(
        top1=round(top1a, 4),
        top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4),
        top5_err=round(100 - top5a, 4),
        top1_fgm_ae=round(top1a_fgm_ae, 4),
        top5_fgm_ae=round(top5a_fgm_ae, 4),
        top1_pgd_ae=round(top1a_pgd_ae, 4),
        top5_pgd_ae=round(top5a_pgd_ae, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        # Key spelling ("cropt_pct") is kept as-is: it is part of the
        # returned result schema.
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * [Regular] Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))
    _logger.info(' * [FGM Adversarial Attack] Acc@1 {:.3f} Acc@5 {:.3f} '.format(
        results['top1_fgm_ae'], results['top5_fgm_ae']))
    _logger.info(' * [PGD Adversarial Attack] Acc@1 {:.3f} Acc@5 {:.3f} '.format(
        results['top1_pgd_ae'], results['top5_pgd_ae']))

    return results
def main():
    """Run the full training loop, optionally with a knowledge-distillation model.

    Stages, in execution order: logging / argument / wandb setup,
    distributed initialisation, AMP mode resolution, seeding, optional KD
    model construction (when ``args.kd_model_path`` is set), model creation,
    split-BN / SyncBN / TorchScript conversion, optimizer and AMP scaler
    setup, checkpoint resume, EMA setup, DDP wrapping, LR schedule, datasets
    and loaders (with optional mixup / cutmix / AugMix), loss selection,
    checkpoint saver setup (global rank 0 only) and the epoch loop.

    A ``KeyboardInterrupt`` raised inside the epoch loop is swallowed, so the
    best-metric line is still logged after an interrupted run.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    args.prefetcher = not args.no_prefetcher
    # Distributed mode is inferred from the WORLD_SIZE environment variable.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    # use_amp ends up as 'apex', 'native' or None (float32 training).
    use_amp = None
    if args.amp:
        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        # An AMP mode was requested but its backend is not available.
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")

    random_seed(args.seed, args.rank)

    # Optional knowledge-distillation model; passed to train_one_epoch below.
    model_KD = None
    if args.kd_model_path is not None:
        model_KD = build_kd_model(args)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}'
        )

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    # With --no-resume-opt, optimizer and loss-scaler state are not passed in.
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # The EMA model is created after cuda() and AMP setup and before the
        # DDP wrapper. NOTE(review): in this function it is also created
        # after the SyncBN conversion above.
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model,
                              device_ids=[args.local_rank],
                              broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    dataset_train = create_dataset(args.dataset,
                                   root=args.data_dir,
                                   split=args.train_split,
                                   is_training=True,
                                   class_map=args.class_map,
                                   download=args.dataset_download,
                                   batch_size=args.batch_size,
                                   repeats=args.epoch_repeats)
    dataset_eval = create_dataset(args.dataset,
                                  root=args.data_dir,
                                  split=args.val_split,
                                  is_training=False,
                                  class_map=args.class_map,
                                  download=args.dataset_download,
                                  batch_size=args.batch_size)

    # setup mixup / cutmix
    # With the prefetcher, mixing is done in the collate function; otherwise
    # a Mixup callable is handed to train_one_epoch.
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train,
                                      num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    # Fall back to the model's default interpolation when augmentation is
    # disabled or no training interpolation was given.
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    # Priority: JSD loss, then mixup/cutmix soft targets, then label
    # smoothing, then plain cross-entropy.
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                smoothing=args.smoothing,
                target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(
                smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    # Only global rank 0 creates the output directory and saver; on other
    # ranks output_dir and saver stay None.
    if args.rank == 0:
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(
            args.output if args.output else './output/train', exp_name)
        # A loss metric improves when it decreases.
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing,
                                max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            amp_autocast=amp_autocast,
                                            loss_scaler=loss_scaler,
                                            model_ema=model_ema,
                                            mixup_fn=mixup_fn,
                                            model_KD=model_KD)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size,
                              args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            # When an on-GPU EMA model exists, its metrics replace the raw
            # model's metrics for scheduling, summary and checkpointing.
            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir is not None:
                # The CSV header is written only until a first best metric
                # has been recorded.
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None,
                               log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)

    except KeyboardInterrupt:
        # Deliberate: an interrupt ends training but still reports the best
        # metric below.
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
def validate(args):
    """Run a model over ``args.data`` and return metrics plus raw outputs.

    The model, dataset and loader are all built from ``args``. Every batch's
    output is kept so the caller can post-process predictions; loss and
    accuracy are only computed when ``args.has_eval_label`` is set.

    Side effects: overwrites ``args.pretrained`` and ``args.prefetcher``, and
    moves the model (and, without the prefetcher, the batches) to CUDA.

    Returns:
        tuple: ``(results, prediction, true_label)`` where

        * ``results`` is an ``OrderedDict`` with ``top1``, ``top1_err``,
          ``topk``, ``topk_err``, ``param_count``, ``img_size``,
          ``cropt_pct`` and ``interpolation``. Accuracies are reported as 0
          when ``args.has_eval_label`` is false.
        * ``prediction`` is the per-batch outputs concatenated along axis 0,
          or ``None`` if the loader produced no batches.
        * ``true_label`` is the per-batch targets concatenated along axis 0,
          or ``None`` when no labels were collected.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    if 'inception' in args.model:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            aux_logits=True,  # ! add aux loss
            in_chans=3,
            scriptable=args.torchscript)
    else:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            in_chans=3,
            scriptable=args.torchscript)

    # ! add more layer to classifier layer
    if args.create_classifier_layerfc:
        model.global_pool, model.classifier = create_classifier_layerfc(
            model.num_features, model.num_classes)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    if args.has_eval_label:
        criterion = nn.CrossEntropyLoss().cuda()  # ! don't have gold label

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map,
                          args=args)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:  # @valid_labels is index numbering
            valid_labels = {int(line.rstrip()) for line in f}
            # Boolean mask over class indices, used to slice the output columns.
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        # 'blank' is default Image.BILINEAR
        # https://github.com/rwightman/pytorch-image-models/blob/470220b1f4c61ad7deb16dbfb8917089e842cd2a/timm/data/transforms.py#L43
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        auto_augment=args.aa,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        args=args)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    topk = AverageMeter()

    # Per-batch arrays are gathered in lists and joined once after the loop;
    # concatenating inside the loop would re-copy all earlier batches each time.
    prediction_chunks = []  # ! need to save output
    true_label_chunks = []

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) + data_config['input_size']).cuda()
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            # ! not have real label
            if args.has_eval_label:
                # ! just save true labels anyway... why not
                true_label_chunks.append(target.cpu().data.numpy())

            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]  # ! some model returns both loss + aux loss

            if valid_labels is not None:
                output = output[:, valid_labels]  # ! keep only valid labels ? good to eval by class.

            # ! are names of files also sorted ?
            prediction_chunks.append(output.cpu().data.numpy())  # batchsize x label

            if real_labels is not None:
                real_labels.add_result(output)

            if args.has_eval_label:
                # measure accuracy and record loss
                loss = criterion(output, target)  # ! don't have gold standard on testset
                acc1, acc5 = accuracy(output.data, target, topk=(1, args.topk))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                topk.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.has_eval_label and (batch_idx % args.log_freq == 0):
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@topk: {topk.val:>7.3f} ({topk.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        topk=topk))

    # Keep the original contract: None when nothing was collected.
    prediction = np.concatenate(prediction_chunks, axis=0) if prediction_chunks else None
    true_label = np.concatenate(true_label_chunks, axis=0) if true_label_chunks else None

    if not args.has_eval_label:
        top1a, topka = 0, 0  # just dummy, because we don't know ground labels
    elif real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, topka = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=args.topk)
    else:
        top1a, topka = top1.avg, topk.avg

    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          topk=round(topka, 4),
                          topk_err=round(100 - topka, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          cropt_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@topk {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['topk'], results['topk_err']))

    return results, prediction, true_label
def validate(args):
    """Evaluate the model described by ``args`` on ``args.data``.

    Builds the model (optionally loading ``args.checkpoint``), wraps the
    dataset in a loader, then accumulates loss and accuracy over every batch.

    Side effects: overwrites ``args.pretrained`` and ``args.prefetcher``, and
    moves the model and loss to CUDA.

    Returns:
        OrderedDict: ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (in millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # Without a checkpoint there is nothing to evaluate, so use pretrained weights.
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([p.numel() for p in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = amp.initialize(model.cuda(), opt_level='O1') if args.amp else model.cuda()

    if args.num_gpu > 1:
        gpu_ids = list(range(args.num_gpu))
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)

    criterion = nn.CrossEntropyLoss().cuda()

    # A path ending in .tar that is a regular file is read as a tar archive;
    # anything else goes through the regular Dataset.
    is_tar_archive = os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data)
    dataset_cls = DatasetTar if is_tar_archive else Dataset
    dataset = dataset_cls(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for step, (images, labels) in enumerate(loader):
            if args.no_prefetcher:
                labels = labels.cuda()
                images = images.cuda()
                if args.fp16:
                    images = images.half()

            # compute output
            output = model(images)
            loss = criterion(output, labels)

            # measure accuracy and record loss
            # NOTE(review): the second k is 2, yet the value is reported as
            # top5 / Acc@5 below. Confirm whether this is intentional.
            n_samples = images.size(0)
            prec_first, prec_second = accuracy(output.data, labels, topk=(1, 2))
            loss_meter.update(loss.item(), n_samples)
            top1_meter.update(prec_first.item(), n_samples)
            top5_meter.update(prec_second.item(), n_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if step % args.log_freq != 0:
                continue
            logging.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    step,
                    len(loader),
                    batch_time=batch_time,
                    rate_avg=n_samples / batch_time.avg,
                    loss=loss_meter,
                    top1=top1_meter,
                    top5=top5_meter))

    results = OrderedDict()
    results['top1'] = round(top1_meter.avg, 4)
    results['top1_err'] = round(100 - top1_meter.avg, 4)
    results['top5'] = round(top5_meter.avg, 4)
    results['top5_err'] = round(100 - top5_meter.avg, 4)
    results['param_count'] = round(param_count / 1e6, 2)
    results['img_size'] = data_config['input_size'][-1]
    results['cropt_pct'] = crop_pct
    results['interpolation'] = data_config['interpolation']

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
def main():
    """Entry point for distributed supernet training.

    Parses the config, initialises the NCCL process group, builds the
    supernet with its optimizer, scheduler and data loaders, then alternates
    ``train_epoch`` and ``validate`` for every scheduled epoch.

    Only the process with ``args.local_rank == 0`` owns a logger and a
    checkpoint saver; on every other rank ``logger`` is ``None``, so all
    logging here is guarded. Exits with status 1 if the train or val folder
    is missing. A ``KeyboardInterrupt`` during training is swallowed so the
    run ends quietly.
    """
    args, cfg = parse_config_args('super net training')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH,
        "{}-{}".format(datetime.date.today().strftime('%m%d'), cfg.MODEL))
    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, "train.log"))
    else:
        logger = None

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process %d with %d GPUs.',
                    args.local_rank, cfg.NUM_GPU)

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # generate supernet
    model, sta_num, resolution = gen_supernet(
        flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM,
        flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP,
        resunit=cfg.SUPERNET.RESUNIT,
        dil_conv=cfg.SUPERNET.DIL_CONV,
        slice=cfg.SUPERNET.SLICE,
        verbose=cfg.VERBOSE,
        logger=logger)

    # initialize meta matching networks
    MetaMN = MetaMatchingNetwork(cfg)

    # number of choice blocks in supernet
    choice_num = len(model.blocks[1][0])
    if args.local_rank == 0:
        logger.info('Supernet created, param count: %d',
                    (sum([m.numel() for m in model.parameters()])))
        logger.info('resolution: %d', (resolution))
        logger.info('choice number: %d', (choice_num))

    # initialize prioritized board
    prioritized_board = PrioritizedBoard(cfg, CHOICE_NUM=choice_num, sta_num=sta_num)

    # initialize flops look-up table
    model_est = FlopsEst(model)

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if cfg.AUTO_RESUME:
        optimizer_state, resume_epoch = resume_checkpoint(model, cfg.RESUME_PATH)

    # create optimizer and resume from checkpoint
    optimizer = create_optimizer_supernet(cfg, model, USE_APEX)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state['optimizer'])
    model = model.cuda()

    # convert model to distributed mode
    if cfg.BATCHNORM.SYNC_BN:
        try:
            if USE_APEX:
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logger.info('Converted model to use Synchronized BatchNorm.')
        except Exception as exception:
            # logger is None off rank 0; calling it there would raise
            # AttributeError from inside this best-effort handler.
            if logger is not None:
                logger.info(
                    'Failed to enable Synchronized BatchNorm. '
                    'Install Apex or Torch >= 1.1 with Exception %s', exception)
    if USE_APEX:
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logger.info(
                "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
        # can use device str in Torch >= 1.1
        model = DDP(model, device_ids=[args.local_rank])

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_supernet_scheduler(cfg, optimizer)

    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logger.info('Scheduled epochs: %d', num_epochs)

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir):
        if logger is not None:
            logger.info('Training folder does not exist at: %s', train_dir)
        # Non-zero status so launchers see the failure.
        sys.exit(1)

    dataset_train = Dataset(train_dir)
    loader_train = create_loader(
        dataset_train,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.BATCH_SIZE,
        is_training=True,
        use_prefetcher=True,
        re_prob=cfg.AUGMENTATION.RE_PROB,
        re_mode=cfg.AUGMENTATION.RE_MODE,
        color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
        interpolation='random',
        num_workers=cfg.WORKERS,
        distributed=True,
        collate_fn=None,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.isdir(eval_dir):
        if logger is not None:
            logger.info('Validation folder does not exist at: %s', eval_dir)
        sys.exit(1)

    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=4 * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        use_prefetcher=True,
        num_workers=cfg.WORKERS,
        distributed=True,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        interpolation=cfg.DATASET.INTERPOLATION)

    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    # initialize training parameters
    eval_metric = cfg.EVAL_METRICS
    best_metric, best_epoch, saver = None, None, None
    if args.local_rank == 0:
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

    # training scheme
    try:
        for epoch in range(start_epoch, num_epochs):
            loader_train.sampler.set_epoch(epoch)

            # train one epoch
            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn,
                prioritized_board, MetaMN, cfg,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                logger=logger,
                est=model_est,
                local_rank=args.local_rank)

            # evaluate one epoch
            eval_metrics = validate(
                model, loader_eval, validate_loss_fn,
                prioritized_board, MetaMN, cfg,
                local_rank=args.local_rank,
                logger=logger)

            # Only the checkpointing rank writes the summary: every rank
            # targets the same CSV, and off rank 0 best_metric stays None,
            # which would repeat the header each epoch.
            if args.local_rank == 0:
                update_summary(epoch, train_metrics, eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, cfg, epoch=epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
def train_imagenet():
    """Train and evaluate an image classifier on an XLA device.

    Builds the model (and optionally an EMA copy) from ``FLAGS``, creates
    either synthetic or on-disk loaders, then runs ``FLAGS.epochs`` rounds of
    training followed by evaluation.

    Returns:
        float: accuracy (percent) from the last evaluation that ran. When an
        EMA model is enabled this is the EMA model's accuracy. ``0.0`` if
        ``FLAGS.epochs`` is 0.
    """
    torch.manual_seed(42)

    device = xm.xla_device()
    model = create_model(
        FLAGS.model,
        pretrained=FLAGS.pretrained,
        num_classes=FLAGS.num_classes,
        drop_rate=FLAGS.drop,
        global_pool=FLAGS.gp,
        bn_tf=FLAGS.bn_tf,
        bn_momentum=FLAGS.bn_momentum,
        bn_eps=FLAGS.bn_eps,
        drop_connect_rate=0.2,
        checkpoint_path=FLAGS.initial_checkpoint,
        args=FLAGS).to(device)

    model_ema = None
    if FLAGS.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        # The EMA wrapper is given its own separately constructed model.
        model_e = create_model(
            FLAGS.model,
            pretrained=FLAGS.pretrained,
            num_classes=FLAGS.num_classes,
            drop_rate=FLAGS.drop,
            global_pool=FLAGS.gp,
            bn_tf=FLAGS.bn_tf,
            bn_momentum=FLAGS.bn_momentum,
            bn_eps=FLAGS.bn_eps,
            drop_connect_rate=0.2,
            checkpoint_path=FLAGS.initial_checkpoint,
            args=FLAGS).to(device)
        model_ema = ModelEma(
            model_e,
            decay=FLAGS.model_ema_decay,
            device='cpu' if FLAGS.model_ema_force_cpu else '',
            resume=FLAGS.resume)

    print('==> Preparing data..')
    img_dim = 224
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dir = os.path.join(FLAGS.data, 'train')
        data_config = resolve_data_config(model, FLAGS, verbose=FLAGS.local_rank == 0)

        dataset_train = Dataset(train_dir)

        collate_fn = None
        if not FLAGS.no_prefetcher and FLAGS.mixup > 0:
            collate_fn = FastCollateMixup(FLAGS.mixup, FLAGS.smoothing, FLAGS.num_classes)

        train_loader = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=True,
            use_prefetcher=not FLAGS.no_prefetcher,
            rand_erase_prob=FLAGS.reprob,
            rand_erase_mode=FLAGS.remode,
            interpolation='bicubic',  # FIXME cleanly resolve this? data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
            collate_fn=collate_fn,
            use_auto_aug=FLAGS.auto_augment,
            use_mixcut=FLAGS.mixcut,
        )

        eval_dir = os.path.join(FLAGS.data, 'val')
        # Number of training *samples*, matching the fake-data branch. The
        # steps-per-epoch computation below divides by the global batch size,
        # so a batch count (len(train_loader)) would be divided twice.
        train_dataset_len = len(dataset_train)
        if not os.path.isdir(eval_dir):
            logging.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
        dataset_eval = Dataset(eval_dir)

        # NOTE(review): this reads FLAGS.prefetcher while the train loader
        # uses `not FLAGS.no_prefetcher` - confirm both flags exist.
        test_loader = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=FLAGS.batch_size,
            is_training=False,
            use_prefetcher=FLAGS.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=FLAGS.workers,
            distributed=FLAGS.distributed,
        )

    writer = None
    start_epoch = 0
    if FLAGS.output and xm.is_master_ordinal():
        writer = SummaryWriter(log_dir=FLAGS.output)

    # NOTE(review): lowercase `flags` (here and for the smoothing value below)
    # is not defined in this function; presumably a module-level alias of
    # FLAGS - confirm.
    optimizer = create_optimizer(flags, model)
    lr_scheduler, num_epochs = create_scheduler(flags, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * xm.xrt_world_size())
    # The scheduler created above is replaced by this per-step wrapper.
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)

    train_loss_fn = LabelSmoothingCrossEntropy(smoothing=flags.smoothing)
    validate_loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        """Run one optimisation pass over ``loader``."""
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = train_loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if model_ema is not None:
                model_ema.update(model)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader, net):
        """Return the top-1 accuracy (percent) of ``net`` over ``loader``."""
        total_samples = 0
        correct = 0
        net.eval()
        for x, (data, target) in loader:
            output = net(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.epochs + 1):
        para_loader = dp.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))

        para_loader = dp.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device), model)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))

        if model_ema is not None:
            # The loader above has already been iterated once; build a fresh
            # one instead of relying on it being re-iterable.
            # NOTE(review): model_ema is used directly as a callable module
            # with .eval() - confirm ModelEma supports that here.
            ema_para_loader = dp.ParallelLoader(test_loader, [device])
            accuracy = test_loop_fn(ema_para_loader.per_device_loader(device), model_ema)
            print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))

        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch)

        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy
def main():
    """Entry point for single-process or distributed training.

    Parses arguments, sets up optional distributed mode and mixed precision,
    builds the model, optimizer, scheduler and data loaders, then trains and
    validates once per epoch. Checkpoints and ``summary.csv`` are written by
    the ``args.local_rank == 0`` process only. A ``KeyboardInterrupt`` ends
    training early without raising.

    Exits with status 1 if the training or validation folder is missing.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            _logger.warning(
                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    if args.num_gpu > 1:
        if use_amp == 'apex':
            _logger.warning(
                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
            use_amp = None
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
        assert not args.channels_last, "Channels last not supported with DP, use DDP."
    else:
        model.cuda()
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)

    optimizer = create_optimizer(args, model)

    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex and use_amp != 'native':
                    # Apex SyncBN preferred unless native amp is activated
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    _logger.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception:
                # Still best-effort, but keep the traceback so the cause of
                # the failure is visible in the log.
                _logger.exception('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.ema, loader_eval, validate_loss_fn, args,
                    amp_autocast=amp_autocast, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # output_dir is only created on local rank 0; elsewhere it is ''
            # and the summary would land in the current working directory.
            if args.local_rank == 0:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        message = '*** Best metric: <{0:.2f}>, epoch: <{1}>, path: <{2}> ***'\
            .format(best_metric, best_epoch, output_dir)
        _logger.info(message)
        print(message)
def main():
    """Train and validate a classification model from command-line arguments.

    Parses ``parser`` arguments, configures single-process or distributed
    (NCCL, one GPU per process) execution, then builds the model, optimizer,
    optional Apex AMP / EMA / DDP wrappers, LR scheduler and the train/eval
    loaders before running the epoch loop. On local rank 0 a timestamped
    output directory is created, a per-epoch summary CSV is written and
    checkpoints are saved through ``CheckpointSaver``.

    Raises:
        SystemExit: with status 1 when the ``train`` or ``validation`` folder
            under ``args.data`` does not exist.

    A ``KeyboardInterrupt`` during training ends the loop quietly; the best
    metric seen so far (if any) is still printed.
    """
    args = parser.parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            print(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        print(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        print('Training with a single process on %d GPUs.' % args.num_gpu)

    # Offset the seed by the global rank so each process gets its own stream.
    torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    print('Model %s created, param count: %d' %
          (args.model, sum(m.numel() for m in model.parameters())))

    data_config = resolve_data_config(model,
                                      args,
                                      verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(
            model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            print(
                'Warning: AMP does not work well with nn.DataParallel, disabling. '
                'Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        if args.distributed and args.sync_bn and has_apex:
            model = convert_syncbn_model(model)
        model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
        print('AMP enabled')
    else:
        use_amp = False
        print('AMP disabled')

    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    if args.distributed:
        model = DDP(model, delay_allreduce=True)
        if model_ema is not None and not args.model_ema_force_cpu:
            # must also distribute EMA model to allow validation
            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
            model_ema.ema_has_module = True

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    # The epoch loop below tolerates a missing scheduler, so the resume path
    # must tolerate it as well instead of failing on ``None.step``.
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        print('Scheduled epochs: ', num_epochs)

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        print('Error: training folder does not exist at: %s' % train_dir)
        # ``exit`` is a helper injected by the ``site`` module and may be
        # absent; raising SystemExit has the same effect without relying on it.
        raise SystemExit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        print('Error: validation folder does not exist at: %s' % eval_dir)
        raise SystemExit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # A loss is better when lower; every other metric is better when higher.
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            eval_metrics = validate(model, loader_eval, validate_loss_fn,
                                    args)

            if model_ema is not None and not args.model_ema_force_cpu:
                # When an on-device EMA model exists its metrics replace the
                # raw model's for scheduling, summary and checkpointing.
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            # Only local rank 0 owns an output directory. Other ranks would
            # otherwise append to a stray ``summary.csv`` in the working
            # directory, rewriting the header every epoch.
            if args.local_rank == 0:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch + 1,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        # Allow Ctrl-C to stop training while still reporting the best result.
        pass
    if best_metric is not None:
        print('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
def main():
    """Train and validate a classification model from command-line arguments.

    Sets up logging and parses arguments, configures single-process or
    distributed (NCCL, one GPU per process) execution, then builds the model
    (optionally with split BN, DataParallel, Apex AMP, EMA, SyncBN and DDP),
    the optimizer, LR scheduler and the train/eval loaders before running the
    epoch loop. On local rank 0 a timestamped output directory is created,
    the argument text is saved to ``args.yaml``, a per-epoch summary CSV is
    written and checkpoints are saved through ``CheckpointSaver``.

    Raises:
        SystemExit: with status 1 when the training folder, or both the
            ``val`` and ``validation`` folders, are missing under
            ``args.data``.

    A ``KeyboardInterrupt`` during training ends the loop quietly; the best
    metric seen so far (if any) is still logged.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    # Offset the seed by the global rank so each process gets its own stream.
    torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         drop_connect_rate=args.drop_connect,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum(m.numel() for m in model.parameters())))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    # Drop the reference to the checkpoint state once it has been consumed.
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception:
                # Best effort: training continues without SyncBN. The
                # traceback is logged so the real cause is not lost.
                logging.exception(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model, device_ids=[args.local_rank
                                           ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        # ``exit`` is a helper injected by the ``site`` module and may be
        # absent; raising SystemExit has the same effect without relying on it.
        raise SystemExit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train,
                                      num_splits=num_aug_splits)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )

    # Prefer ``val``; fall back to ``validation`` before giving up.
    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            raise SystemExit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # A loss is better when lower; every other metric is better when higher.
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size,
                              args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn,
                                    args)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                # When an on-device EMA model exists its metrics replace the
                # raw model's for scheduling, summary and checkpointing.
                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # Only local rank 0 owns an output directory. Other ranks would
            # otherwise append to a stray ``summary.csv`` in the working
            # directory, rewriting the header every epoch.
            if args.local_rank == 0:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        # Allow Ctrl-C to stop training while still reporting the best result.
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
def main():
    """Benchmark inference speed and power draw of a model on CIFAR-100.

    Creates the model named by the command-line arguments, runs it over the
    CIFAR-100 test split under ``torch.no_grad()`` and times every forward
    pass except the first. A local HTTP service is POSTed to before
    (``/start``) and after (``/stop``) the loop; the difference between the
    reported ``load`` and ``idle`` values is stored as the power figure.
    Results are written to ``<model>_fps_cifar.yaml`` in the directory that
    contains ``args.checkpoint`` (the working directory if there is none).

    Raises:
        ZeroDivisionError: if fewer than two batches were run, so no forward
            pass was timed.
    """
    start_endpoint = "http://localhost:3000/start"
    stop_endpoint = "http://localhost:3000/stop"

    setup_default_logging()
    args = parser.parse_args()

    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint
    # Results are stored next to the checkpoint; an empty or missing
    # checkpoint path resolves to '' (the current working directory).
    output_dir = os.path.dirname(args.checkpoint or '')

    # create model
    model = create_model(args.model,
                         num_classes=args.num_classes,
                         in_chans=3,
                         pretrained=args.pretrained,
                         checkpoint_path=args.checkpoint)

    logging.info('Model %s created, param count: %d' %
                 (args.model, sum(m.numel() for m in model.parameters())))

    # config = resolve_data_config(vars(args), model=model)
    # model, test_time_pool = apply_test_time_pool(model, config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    dataset_eval = torchvision.datasets.CIFAR100(root='./data',
                                                 train=False,
                                                 download=True)
    data_config = resolve_data_config(vars(args), model=model)
    # Override the resolved normalisation with the CIFAR-100 statistics:
    # CIFAR_100_MEAN = (0.5071, 0.4865, 0.4409)
    # CIFAR_100_STD = (0.2673, 0.2564, 0.2762)
    data_config['mean'] = (0.5071, 0.4865, 0.4409)
    data_config['std'] = (0.2673, 0.2564, 0.2762)

    loader = create_loader(dataset_eval,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           is_training=False,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=data_config['crop_pct'])

    model.eval()

    batch_time = AverageMeter()
    with torch.no_grad():
        idle_power = requests.post(url=start_endpoint)
        idle_json = idle_power.json()

        for batch_idx, (images, _) in enumerate(loader):
            images = images.cuda()
            # CUDA work is launched asynchronously. Without synchronising
            # before and after the forward pass, the timer would measure only
            # the launch, not the computation.
            torch.cuda.synchronize()
            tstart = time.time()
            model(images)
            torch.cuda.synchronize()
            tend = time.time()

            # The first batch is excluded from the average (warm-up).
            if batch_idx != 0:
                batch_time.update(tend - tstart)

            if batch_idx % args.log_freq == 0:
                print(
                    'Predict: [{0}/{1}] Time {batch_time.val:.6f} ({batch_time.avg:.6f})'
                    .format(batch_idx, len(loader), batch_time=batch_time),
                    end='\r')

        load_power = requests.post(url=stop_endpoint)
        load_json = load_power.json()

    if not batch_time.avg:
        # Same exception type the bare division would raise, but with a
        # message that explains what went wrong.
        raise ZeroDivisionError(
            'no forward pass was timed; at least two batches are required '
            'to compute FPS')

    # NOTE(review): this is the number of *batches* per second. It equals
    # frames per second only when args.batch_size == 1 -- confirm intent.
    fps = 1 / batch_time.avg
    inference_power = float(load_json['load']) - float(idle_json['idle'])
    stats = [{'FPS': [float(fps)]}, {'Total_Power': [float(inference_power)]}]

    with open(
            os.path.join(output_dir,
                         '{}_fps_cifar.yaml'.format(args.model)), 'w') as f:
        yaml.safe_dump(stats, f)