def main():
    args, cfg = parse_config_args('super net training')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH,
        "{}-{}".format(datetime.date.today().strftime('%m%d'), cfg.MODEL))
    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, "train.log"))
    else:
        logger = None

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process %d with %d GPUs.', args.local_rank,
                    cfg.NUM_GPU)

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # generate supernet
    model, sta_num, resolution = gen_supernet(
        flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM,
        flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP,
        resunit=cfg.SUPERNET.RESUNIT,
        dil_conv=cfg.SUPERNET.DIL_CONV,
        slice=cfg.SUPERNET.SLICE,
        verbose=cfg.VERBOSE,
        logger=logger)

    # initialize meta matching networks
    MetaMN = MetaMatchingNetwork(cfg)

    # number of choice blocks in supernet
    choice_num = len(model.blocks[1][0])
    if args.local_rank == 0:
        logger.info('Supernet created, param count: %d',
                    (sum([m.numel() for m in model.parameters()])))
        logger.info('resolution: %d', (resolution))
        logger.info('choice number: %d', (choice_num))

    # initialize prioritized board
    prioritized_board = PrioritizedBoard(cfg,
                                         CHOICE_NUM=choice_num,
                                         sta_num=sta_num)

    # initialize flops look-up table
    model_est = FlopsEst(model)

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if cfg.AUTO_RESUME:
        optimizer_state, resume_epoch = resume_checkpoint(
            model, cfg.RESUME_PATH)

    # create optimizer and resume from checkpoint
    optimizer = create_optimizer_supernet(cfg, model, USE_APEX)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state['optimizer'])
    model = model.cuda()

    # convert model to distributed mode
    if cfg.BATCHNORM.SYNC_BN:
        try:
            if USE_APEX:
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logger.info('Converted model to use Synchronized BatchNorm.')
        except Exception as exception:
            logger.info(
                'Failed to enable Synchronized BatchNorm. '
                'Install Apex or Torch >= 1.1 with Exception %s', exception)
    if USE_APEX:
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logger.info(
                "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
            )
        # can use device str in Torch >= 1.1
        model = DDP(model, device_ids=[args.local_rank])

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_supernet_scheduler(cfg, optimizer)

    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logger.info('Scheduled epochs: %d', num_epochs)

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir):
        logger.info('Training folder does not exist at: %s', train_dir)
        sys.exit()

    dataset_train = Dataset(train_dir)
    loader_train = create_loader(dataset_train,
                                 input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                             cfg.DATASET.IMAGE_SIZE),
                                 batch_size=cfg.DATASET.BATCH_SIZE,
                                 is_training=True,
                                 use_prefetcher=True,
                                 re_prob=cfg.AUGMENTATION.RE_PROB,
                                 re_mode=cfg.AUGMENTATION.RE_MODE,
                                 color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
                                 interpolation='random',
                                 num_workers=cfg.WORKERS,
                                 distributed=True,
                                 collate_fn=None,
                                 crop_pct=DEFAULT_CROP_PCT,
                                 mean=IMAGENET_DEFAULT_MEAN,
                                 std=IMAGENET_DEFAULT_STD)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.isdir(eval_dir):
        logger.info('Validation folder does not exist at: %s', eval_dir)
        sys.exit()
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(dataset_eval,
                                input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                            cfg.DATASET.IMAGE_SIZE),
                                batch_size=4 * cfg.DATASET.BATCH_SIZE,
                                is_training=False,
                                use_prefetcher=True,
                                num_workers=cfg.WORKERS,
                                distributed=True,
                                crop_pct=DEFAULT_CROP_PCT,
                                mean=IMAGENET_DEFAULT_MEAN,
                                std=IMAGENET_DEFAULT_STD,
                                interpolation=cfg.DATASET.INTERPOLATION)

    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    # initialize training parameters
    eval_metric = cfg.EVAL_METRICS
    best_metric, best_epoch, saver, best_children_pool = None, None, None, []
    if args.local_rank == 0:
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)

    # training scheme
    try:
        for epoch in range(start_epoch, num_epochs):
            loader_train.sampler.set_epoch(epoch)

            # train one epoch
            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        prioritized_board,
                                        MetaMN,
                                        cfg,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        logger=logger,
                                        est=model_est,
                                        local_rank=args.local_rank)

            # evaluate one epoch
            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    prioritized_board,
                                    MetaMN,
                                    cfg,
                                    local_rank=args.local_rank,
                                    logger=logger)

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, cfg, epoch=epoch, metric=save_metric)
    except KeyboardInterrupt:
        pass
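# A minimal entry-point sketch (an assumption, not taken from the listing above):
# these scripts are typically launched one process per GPU, e.g. via
# `python -m torch.distributed.launch --nproc_per_node=<num_gpus> <this_script>.py`,
# which supplies the --local_rank argument and the WORLD_SIZE environment variable
# consumed by the env:// process-group initialization above.
if __name__ == '__main__':
    main()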
def main():
    import os
    args, args_text = _parse_args()

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = 'train'
        if args.gate_train:
            exp_name += '-dynamic'
        if args.slim_train:
            exp_name += '-slimmable'
        exp_name += '-{}'.format(args.model)
        exp_info = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, exp_name, exp_info)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
    setup_default_logging(outdir=output_dir, local_rank=args.local_rank)

    torch.backends.cudnn.benchmark = True

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        # torch.distributed.init_process_group(backend='nccl',
        #                                      init_method='tcp://127.0.0.1:23334',
        #                                      rank=args.local_rank,
        #                                      world_size=int(os.environ['WORLD_SIZE']))
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    # --------- random seed -----------
    random.seed(args.seed)  # TODO: do we need same seed on all GPU?
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # torch.manual_seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         drop_path_rate=args.drop_path,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model,
                      sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    if args.train_mode == 'se':
        optimizer = create_optimizer(args, model.get_se())
    elif args.train_mode == 'bn':
        optimizer = create_optimizer(args, model.get_bn())
    elif args.train_mode == 'all':
        optimizer = create_optimizer(args, model)
    elif args.train_mode == 'gate':
        optimizer = create_optimizer(args, model.get_gate())

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    if resume_state and not args.no_resume_opt:
        # ----------- Load Optimizer ---------
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank],
                        find_unused_parameters=True
                        )  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # ------------- data --------------
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train,
                                      num_splits=num_aug_splits)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )
    loader_bn = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # ------------- loss_fn --------------
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
    if args.ieb:
        distill_loss_fn = SoftTargetCrossEntropy().cuda()
    else:
        distill_loss_fn = None

    if args.local_rank == 0:
        model_profiling(model, 224, 224, 1, 3, use_cuda=True, verbose=True)
    else:
        model_profiling(model, 224, 224, 1, 3, use_cuda=True, verbose=False)

    if not args.test_mode:
        # start training
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            train_metrics = OrderedDict([('loss', 0.)])

            # train
            if args.gate_train:
                train_metrics = train_epoch_slim_gate(
                    epoch,
                    model,
                    loader_train,
                    optimizer,
                    train_loss_fn,
                    args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step)
            else:
                train_metrics = train_epoch_slim(
                    epoch,
                    model,
                    loader_train,
                    optimizer,
                    loss_fn=train_loss_fn,
                    distill_loss_fn=distill_loss_fn,
                    args=args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step,
                )
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size,
                              args.dist_bn == 'reduce')

            # eval
            if args.gate_train:
                eval_sample_list = ['dynamic']
            else:
                if epoch % 10 == 0 and epoch != 0:
                    eval_sample_list = ['smallest', 'largest', 'uniform']
                else:
                    eval_sample_list = ['smallest', 'largest']
            eval_metrics = [
                validate_slim(model,
                              loader_eval,
                              validate_loss_fn,
                              args,
                              model_mode=model_mode)
                for model_mode in eval_sample_list
            ]
            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = [
                    validate_slim(model_ema.ema,
                                  loader_eval,
                                  validate_loss_fn,
                                  args,
                                  model_mode=model_mode)
                    for model_mode in eval_sample_list
                ]
                eval_metrics = ema_eval_metrics
            if isinstance(eval_metrics, list):
                eval_metrics = eval_metrics[0]

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # save
            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)
            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)
        # end training
        if best_metric is not None:
            logging.info('*** Best metric: {0} (epoch {1})'.format(
                best_metric, best_epoch))

    # test
    eval_metrics = []
    for choice in range(args.num_choice):
        # reset bn if not smallest or largest
        if choice != 0 and choice != args.num_choice - 1:
            for layer in model.modules():
                if isinstance(layer, nn.BatchNorm2d) or \
                        isinstance(layer, nn.SyncBatchNorm) or \
                        (has_apex and isinstance(layer, apex.parallel.SyncBatchNorm)):
                    layer.reset_running_stats()
            model.train()
            with torch.no_grad():
                for batch_idx, (input, target) in enumerate(loader_bn):
                    if args.slim_train:
                        if hasattr(model, 'module'):
                            model.module.set_mode('uniform', choice=choice)
                        else:
                            model.set_mode('uniform', choice=choice)
                        model(input)
                    if batch_idx % 1000 == 0 and batch_idx != 0:
                        print('Subnet {} : reset bn for {} steps'.format(
                            choice, batch_idx))
                        break
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size,
                              args.dist_bn == 'reduce')
        eval_metrics.append(
            validate_slim(model,
                          loader_eval,
                          validate_loss_fn,
                          args,
                          model_mode=choice))
    if args.local_rank == 0:
        print('Test results of the last epoch:\n', eval_metrics)
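# Illustrative sketch only (an assumption for clarity, not the timm distribute_bn
# implementation used above): conceptually, the 'reduce' mode averages BatchNorm
# running statistics across processes so every rank evaluates with the same stats.
import torch.distributed as dist
import torch.nn as nn


def reduce_bn_stats_sketch(module, world_size):
    # Sum each BN layer's running mean/var over all ranks, then average.
    for m in module.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            dist.all_reduce(m.running_mean, op=dist.ReduceOp.SUM)
            m.running_mean /= world_size
            dist.all_reduce(m.running_var, op=dist.ReduceOp.SUM)
            m.running_var /= world_size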
def main():
    args, cfg = parse_config_args('child net training')

    # resolve logging
    output_dir = os.path.join(
        cfg.SAVE_PATH,
        "{}-{}".format(datetime.date.today().strftime('%m%d'), cfg.MODEL))
    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, 'retrain.log'))
        writer = SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None

    # retrain model selection
    if cfg.NET.SELECTION == 481:
        arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1, 1],
                     [3, 3, 3, 3], [3, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 43:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 96
    elif cfg.NET.SELECTION == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        cfg.DATASET.IMAGE_SIZE = 64
    elif cfg.NET.SELECTION == 114:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 160
    elif cfg.NET.SELECTION == 287:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3],
                     [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 604:
        arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
                     [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    else:
        raise ValueError("Model Retrain Selection is not Supported!")

    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
    choice_block_pool = [
        'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
        'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
        'ir_r1_k5_s2_e6_c192_se0.25'
    ]
    arch_def = [[stem[0]]] + [[
        choice_block_pool[idx]
        for repeat_times in range(len(arch_list[idx + 1]))
    ] for idx in range(len(choice_block_pool))] + [[stem[1]]]

    # generate childnet
    model = gen_childnet(arch_list,
                         arch_def,
                         num_classes=cfg.DATASET.NUM_CLASSES,
                         drop_rate=cfg.NET.DROPOUT_RATE,
                         global_pool=cfg.NET.GP)

    # initialize training parameters
    eval_metric = cfg.EVAL_METRICS
    best_metric, best_epoch, saver = None, None, None

    # initialize distributed parameters
    distributed = cfg.NUM_GPU > 1
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process {} with {} GPUs.'.format(
            args.local_rank, cfg.NUM_GPU))

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # get parameters and FLOPs of model
    if args.local_rank == 0:
        macs, params = get_model_flops_params(
            model,
            input_size=(1, 3, cfg.DATASET.IMAGE_SIZE,
                        cfg.DATASET.IMAGE_SIZE))
        logger.info('[Model-{}] Flops: {} Params: {}'.format(
            cfg.NET.SELECTION, macs, params))

    # create optimizer
    model = model.cuda()
    optimizer = create_optimizer(cfg, model)

    # optionally resume from a checkpoint
    resume_state, resume_epoch = {}, None
    if cfg.AUTO_RESUME:
        resume_state, resume_epoch = resume_checkpoint(model, cfg.RESUME_PATH)
        optimizer.load_state_dict(resume_state['optimizer'])
        del resume_state

    model_ema = None
    if cfg.NET.EMA.USE:
        model_ema = ModelEma(
            model,
            decay=cfg.NET.EMA.DECAY,
            device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
            resume=cfg.RESUME_PATH if cfg.AUTO_RESUME else None)

    if distributed:
        if cfg.BATCHNORM.SYNC_BN:
            try:
                if HAS_APEX:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logger.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                if args.local_rank == 0:
                    logger.error(
                        'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1 with exception {}'
                        .format(e))
        if HAS_APEX:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logger.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            # can use device str in Torch >= 1.1
            model = DDP(model, device_ids=[args.local_rank])

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir) and args.local_rank == 0:
        logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)

    dataset_train = Dataset(train_dir)
    loader_train = create_loader(dataset_train,
                                 input_size=(3, cfg.DATASET.IMAGE_SIZE,
                                             cfg.DATASET.IMAGE_SIZE),
                                 batch_size=cfg.DATASET.BATCH_SIZE,
                                 is_training=True,
                                 color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
                                 auto_augment=cfg.AUGMENTATION.AA,
                                 num_aug_splits=0,
                                 crop_pct=DEFAULT_CROP_PCT,
                                 mean=IMAGENET_DEFAULT_MEAN,
                                 std=IMAGENET_DEFAULT_STD,
                                 num_workers=cfg.WORKERS,
                                 distributed=distributed,
                                 collate_fn=None,
                                 pin_memory=cfg.DATASET.PIN_MEM,
                                 interpolation='random',
                                 re_mode=cfg.AUGMENTATION.RE_MODE,
                                 re_prob=cfg.AUGMENTATION.RE_PROB)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.exists(eval_dir) and args.local_rank == 0:
        logger.error(
            'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        interpolation='bicubic',
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=cfg.WORKERS,
        distributed=distributed,
        pin_memory=cfg.DATASET.PIN_MEM)

    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_scheduler(cfg, optimizer)
    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        best_record, best_ep = 0, 0
        for epoch in range(start_epoch, num_epochs):
            if distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        cfg,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        model_ema=model_ema,
                                        logger=logger,
                                        writer=writer,
                                        local_rank=args.local_rank)

            eval_metrics = validate(epoch,
                                    model,
                                    loader_eval,
                                    validate_loss_fn,
                                    cfg,
                                    logger=logger,
                                    writer=writer,
                                    local_rank=args.local_rank)

            if model_ema is not None and not cfg.NET.EMA.FORCE_CPU:
                ema_eval_metrics = validate(epoch,
                                            model_ema.ema,
                                            loader_eval,
                                            validate_loss_fn,
                                            cfg,
                                            log_suffix='_EMA',
                                            logger=logger,
                                            writer=writer,
                                            local_rank=args.local_rank)
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    cfg,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric)

            if best_record < eval_metrics[eval_metric]:
                best_record = eval_metrics[eval_metric]
                best_ep = epoch

            if args.local_rank == 0:
                logger.info('*** Best metric: {0} (epoch {1})'.format(
                    best_record, best_ep))
    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
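# Conceptual sketch of the label-smoothing cross-entropy selected above when
# cfg.AUGMENTATION.SMOOTHING > 0 (illustrative only, not the timm
# LabelSmoothingCrossEntropy implementation).
import torch
import torch.nn.functional as F


def label_smoothing_ce_sketch(logits, target, smoothing=0.1):
    # Mix the usual NLL term with a uniform-distribution term weighted by `smoothing`.
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
    uniform = -log_probs.mean(dim=-1)
    return ((1.0 - smoothing) * nll + smoothing * uniform).mean()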
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX nor native Torch AMP is available, using float32. "
            "Install NVIDIA apex or upgrade to PyTorch 1.6")

    random_seed(args.seed, args.rank)

    if args.fuser:
        set_jit_fuser(args.fuser)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        # FIXME handle model default vs config num_classes more elegantly
        args.num_classes = model.num_classes

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}'
        )

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info(
                'Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model,
                              device_ids=[args.local_rank],
                              broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    dataset_train = create_dataset(args.dataset,
                                   root=args.data_dir,
                                   split=args.train_split,
                                   is_training=True,
                                   class_map=args.class_map,
                                   download=args.dataset_download,
                                   batch_size=args.batch_size,
                                   repeats=args.epoch_repeats)
    dataset_eval = create_dataset(args.dataset,
                                  root=args.data_dir,
                                  split=args.val_split,
                                  is_training=False,
                                  class_map=args.class_map,
                                  download=args.dataset_download,
                                  batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                smoothing=args.smoothing,
                target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(
                smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if args.rank == 0:
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(
            args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model=model,
                                optimizer=optimizer,
                                args=args,
                                model_ema=model_ema,
                                amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir,
                                recovery_dir=output_dir,
                                decreasing=decreasing,
                                max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(epoch,
                                            model,
                                            loader_train,
                                            optimizer,
                                            train_loss_fn,
                                            args,
                                            lr_scheduler=lr_scheduler,
                                            saver=saver,
                                            output_dir=output_dir,
                                            amp_autocast=amp_autocast,
                                            loss_scaler=loss_scaler,
                                            model_ema=model_ema,
                                            mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module,
                                            loader_eval,
                                            validate_loss_fn,
                                            args,
                                            amp_autocast=amp_autocast,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir is not None:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None,
                               log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
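# Conceptual sketch of the exponential moving average update that ModelEmaV2
# applies after each optimizer step (illustrative only, not the timm code;
# `ema_model` is assumed to be a deep copy of `model` kept for evaluation).
import torch


@torch.no_grad()
def ema_update_sketch(ema_model, model, decay=0.9998):
    for ema_v, model_v in zip(ema_model.state_dict().values(),
                              model.state_dict().values()):
        if ema_v.dtype.is_floating_point:
            # ema <- decay * ema + (1 - decay) * current
            ema_v.mul_(decay).add_(model_v, alpha=1.0 - decay)
        else:
            # integer buffers (e.g. num_batches_tracked) are copied directly
            ema_v.copy_(model_v)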