def __init__(self, args, training=True):
    super(RunnerWrapper, self).__init__()

    # Store general options
    self.args = args
    self.training = training

    # Read names lists from the args
    self.load_names(args)

    # Initialize classes for the networks
    nets_names = self.nets_names_test
    if self.training:
        nets_names += self.nets_names_train
    nets_names = list(set(nets_names))

    self.nets = nn.ModuleDict()
    for net_name in sorted(nets_names):
        self.nets[net_name] = importlib.import_module(f'networks.{net_name}').NetworkWrapper(args)
        if args.num_gpus > 1:
            # Apex is only needed for multi-gpu training
            from apex import parallel
            self.nets[net_name] = parallel.convert_syncbn_model(self.nets[net_name])

    # Set nets that are not training into eval mode
    for net_name in self.nets.keys():
        if net_name not in self.nets_names_to_train:
            self.nets[net_name].eval()

    # Initialize classes for the losses
    if self.training:
        losses_names = list(set(self.losses_names_train + self.losses_names_test))
        self.losses = nn.ModuleDict()
        for loss_name in sorted(losses_names):
            self.losses[loss_name] = importlib.import_module(f'losses.{loss_name}').LossWrapper(args)

    # Spectral norm
    if args.spn_layers:
        spn_layers = utils.parse_str_to_list(args.spn_layers, sep=',')
        spn_nets_names = utils.parse_str_to_list(args.spn_networks, sep=',')

        for net_name in spn_nets_names:
            self.nets[net_name].apply(
                lambda module: utils.spectral_norm(module, apply_to=spn_layers, eps=args.eps))

        # Remove spectral norm in modules in exceptions
        spn_exceptions = utils.parse_str_to_list(args.spn_exceptions, sep=',')
        for full_module_name in spn_exceptions:
            if not full_module_name:
                continue
            parts = full_module_name.split('.')
            # Get the module that needs to be changed
            module = self.nets[parts[0]]
            for part in parts[1:]:
                module = getattr(module, part)
            module.apply(utils.remove_spectral_norm)

    # Weight averaging
    if args.wgv_mode != 'none':
        # Apply weight averaging only for networks that are being trained
        for net_name in self.nets_names_to_train:
            self.nets[net_name].apply(
                lambda module: utils.weight_averaging(module, mode=args.wgv_mode, momentum=args.wgv_momentum))

    # Check which networks are being trained and put the rest into the eval mode
    for net_name in self.nets.keys():
        if net_name not in self.nets_names_to_train:
            self.nets[net_name].eval()

    # Set the same batchnorm momentum across all modules
    if self.training:
        self.apply(lambda module: utils.set_batchnorm_momentum(module, args.bn_momentum))

    # Store a history of losses and images for visualization
    self.losses_history = {
        True: {},   # self.training = True
        False: {}}  # self.training = False
def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=8, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='localization_')
    arg('--data-dir', type=str, default="/home/selim/datasets/xview/train")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=1)
    arg("--local_rank", default=0, type=int)
    arg("--opt-level", default='O0', type=str)
    arg("--predictions", default="../oof_preds", type=str)
    arg("--test_every", type=int, default=1)
    args = parser.parse_args()

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True
    conf = load_config(args.config)
    model = models.__dict__[conf['network']](seg_classes=conf['num_classes'], backbone_arch=conf['encoder'])
    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)

    mask_loss_function = losses.__dict__[conf["mask_loss"]["type"]](**conf["mask_loss"]["params"]).cuda()
    loss_functions = {"mask_loss": mask_loss_function}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)

    dice_best = 0
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = XviewSingleDataset(mode="train",
                                    fold=args.fold,
                                    data_path=args.data_dir,
                                    folds_csv=args.folds_csv,
                                    transforms=create_train_transforms(conf['input']),
                                    multiplier=conf["data_multiplier"],
                                    normalize=conf["input"].get("normalize", None))
    data_val = XviewSingleDataset(mode="val",
                                  fold=args.fold,
                                  data_path=args.data_dir,
                                  folds_csv=args.folds_csv,
                                  transforms=create_val_transforms(conf['input']),
                                  normalize=conf["input"].get("normalize", None))

    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)

    train_data_loader = DataLoader(data_train,
                                   batch_size=batch_size,
                                   num_workers=args.workers,
                                   shuffle=train_sampler is None,
                                   sampler=train_sampler,
                                   pin_memory=False,
                                   drop_last=True)
    val_batch_size = 1
    val_data_loader = DataLoader(data_val,
                                 batch_size=val_batch_size,
                                 num_workers=args.workers,
                                 shuffle=False,
                                 pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' + args.prefix + conf['encoder'])

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            if conf['optimizer'].get('zero_decoder', False):
                for key in state_dict.copy().keys():
                    if key.startswith("module.final"):
                        del state_dict[key]
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    dice_best = checkpoint.get('dice_best', 0)
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(args.prefix, conf['network'], conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()

    for epoch in range(start_epoch, conf['optimizer']['schedule']['epochs']):
        if train_sampler:
            train_sampler.set_epoch(epoch)
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder_stages.eval()
            for p in model.module.encoder_stages.parameters():
                p.requires_grad = False
        else:
            print("Unfreezing encoder!!!")
            model.module.encoder_stages.train()
            for p in model.module.encoder_stages.parameters():
                p.requires_grad = True

        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                    train_data_loader, summary_writer, conf, args.local_rank)

        model = model.eval()
        if args.local_rank == 0:
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'dice_best': dice_best,
                },
                args.output_dir + '/' + snapshot_name + "_last")
            if epoch % args.test_every == 0:
                preds_dir = os.path.join(args.predictions, snapshot_name)
                dice_best = evaluate_val(args, val_data_loader, dice_best, model,
                                         snapshot_name=snapshot_name,
                                         current_epoch=current_epoch,
                                         optimizer=optimizer,
                                         summary_writer=summary_writer,
                                         predictions_dir=preds_dir)
        current_epoch += 1
def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', default=True)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True
    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
    model = model.cuda()
    teacher_model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
    teacher_model = teacher_model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)

    ohem = conf.get("ohem_samples", None)
    reduction = "mean"
    if ohem:
        reduction = "none"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}

    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    teacher_optimizer, _ = create_optimizer(conf['optimizer'], model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = DeepFakeClassifierDataset_kn(mode="train",
                                              oversample_real=not args.no_oversample,
                                              fold=args.fold,
                                              padding_part=args.padding_part,
                                              hardcore=not args.no_hardcore,
                                              crops_dir=args.crops_dir,
                                              data_path=args.data_dir,
                                              label_smoothing=args.label_smoothing,
                                              folds_csv=args.folds_csv,
                                              # transforms=create_train_transforms(conf["size"]),
                                              normalize=conf.get("normalize", None))
    # data_val = DeepFakeClassifierDataset_kn(mode="val",
    #                                         fold=args.fold,
    #                                         padding_part=args.padding_part,
    #                                         crops_dir=args.crops_dir,
    #                                         data_path=args.data_dir,
    #                                         folds_csv=args.folds_csv,
    #                                         transforms=create_val_transforms(conf["size"]),
    #                                         normalize=conf.get("normalize", None))
    data_train_student = DeepFakeClassifierDataset_kn(mode="train",
                                                      oversample_real=not args.no_oversample,
                                                      fold=args.fold,
                                                      padding_part=args.padding_part,
                                                      hardcore=not args.no_hardcore,
                                                      crops_dir=args.crops_dir,
                                                      data_path=args.data_dir,
                                                      label_smoothing=args.label_smoothing,
                                                      folds_csv=args.folds_csv,
                                                      transforms=create_train_transforms(conf["size"]),
                                                      normalize=conf.get("normalize", None))
    data_val_student = DeepFakeClassifierDataset_kn(mode="val",
                                                    fold=args.fold,
                                                    padding_part=args.padding_part,
                                                    crops_dir=args.crops_dir,
                                                    data_path=args.data_dir,
                                                    folds_csv=args.folds_csv,
                                                    transforms=create_val_transforms(conf["size"]),
                                                    normalize=conf.get("normalize", None))
    # val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
    #                              pin_memory=False)
    val_data_loader_student = DataLoader(data_val_student,
                                         batch_size=batch_size * 2,
                                         num_workers=args.workers,
                                         shuffle=False,
                                         pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            teacher_model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    bce_best = checkpoint.get('bce_best', 0)
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})".format(
                args.resume, checkpoint['epoch'], checkpoint['bce_best']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')
        teacher_model, _ = amp.initialize(teacher_model, teacher_optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    teacher_model = DataParallel(teacher_model).cuda()
    teacher_model.eval()

    # register each block, in order to extract the blocks' feature maps
    # (both models are wrapped above, so the encoder is reached through .module)
    for name, block in model.module.encoder.blocks.named_children():
        block.register_forward_hook(hook_function)
    for name, block in teacher_model.module.encoder.blocks.named_children():
        block.register_forward_hook(teacher_hook_function)

    data_val_student.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
            train_sampler.set_epoch(epoch)
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train,
                                       batch_size=batch_size,
                                       num_workers=args.workers,
                                       shuffle=train_sampler is None,
                                       sampler=train_sampler,
                                       pin_memory=False,
                                       drop_last=True)
        train_data_loader_student = DataLoader(data_train_student,
                                               batch_size=batch_size,
                                               num_workers=args.workers,
                                               shuffle=train_sampler is None,
                                               sampler=train_sampler,
                                               pin_memory=False,
                                               drop_last=True)

        train_epoch(current_epoch, loss_functions, model, teacher_model, optimizer, scheduler,
                    train_data_loader, train_data_loader_student, summary_writer, conf,
                    args.local_rank, args.only_changed_frames)

        model = model.eval()
        if args.local_rank == 0:
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                },
                args.output_dir + '/' + snapshot_name + "_last")
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce_best,
                },
                args.output_dir + snapshot_name + "_{}".format(current_epoch))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args, val_data_loader_student, bce_best, model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1
loss_fn = OHEMLoss(ignore_index=255, numel_frac=0.05)
loss_fn = loss_fn.cuda()

scheduler = CosineAnnealingScheduler(
    optimizer, 'lr',
    args.learning_rate, 1e-6,
    cycle_size=args.epochs * len(train_loader),
)
scheduler = create_lr_scheduler_with_warmup(scheduler, 0, args.learning_rate, 1000)

model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

if args.distributed:
    model = convert_syncbn_model(model)
    model = DistributedDataParallel(model)

trainer = create_segmentation_trainer(
    model, optimizer, loss_fn,
    device=device,
    use_f16=True,
)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

evaluator = create_segmentation_evaluator(
    model,
    device=device,
    num_classes=19,
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            sampler=train_sampler)
else:
    # load the dataset using structure DataLoader (part of torch.utils.data)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

# instantiate Generator(nn.Module) and load in cpu/gpu
generator = Generator().to(device)

## DEFINE CHECKPOINT
# checkpoints are used during training to save a model (model parameters I suppose)
# here we are only testing the pre-trained model, thus we load (torch.load) the model
if args.distributed:
    g_checkpoint = torch.load(args.checkpoint_path,
                              map_location=lambda storage, loc: storage.cuda(args.local_rank))
    generator = parallel.DistributedDataParallel(generator)
    generator = parallel.convert_syncbn_model(generator)
else:
    g_checkpoint = torch.load(args.checkpoint_path)

generator.load_state_dict(g_checkpoint['model_state_dict'], strict=False)
step = g_checkpoint['step']
alpha = g_checkpoint['alpha']
iteration = g_checkpoint['iteration']
print('pre-trained model is loaded step:%d, alpha:%d iteration:%d' % (step, alpha, iteration))

MSE_Loss = nn.MSELoss()

# notify all layers that you are in eval mode instead of training mode
generator.eval()

test(dataloader, generator, MSE_Loss, step, alpha)
def main(opt):
    """
    Trains SRVP and saves the resulting model.

    Parameters
    ----------
    opt : helper.DotDict
        Contains the training configuration.
    """
    ##################################################################################################################
    # Setup
    ##################################################################################################################
    # Device handling (CPU, GPU, multi GPU)
    if opt.device is None:
        device = torch.device('cpu')
        opt.n_gpu = 0
    else:
        opt.n_gpu = len(opt.device)
        os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.device[opt.local_rank])
        device = torch.device('cuda:0')
        torch.cuda.set_device(0)
        # In the case of multi GPU: sets up distributed training
        if opt.n_gpu > 1 or opt.local_rank > 0:
            torch.distributed.init_process_group(backend='nccl')
            # Since we are in distributed mode, divide batch size by the number of GPUs
            assert opt.batch_size % opt.n_gpu == 0
            opt.batch_size = opt.batch_size // opt.n_gpu
    # Seed
    if opt.seed is None:
        opt.seed = random.randint(1, 10000)
    else:
        assert isinstance(opt.seed, int) and opt.seed > 0
    print(f'Learning on {opt.n_gpu} GPU(s) (seed: {opt.seed})')
    random.seed(opt.seed)
    np.random.seed(opt.seed + opt.local_rank)
    torch.manual_seed(opt.seed)
    # cuDNN
    if opt.n_gpu > 1 or opt.local_rank > 0:
        assert torch.backends.cudnn.enabled
        cudnn.deterministic = True
    # Mixed-precision training
    if opt.torch_amp and not torch_amp_imported:
        raise ImportError('Mixed-precision not supported by this PyTorch version, upgrade PyTorch or use Apex')
    if opt.apex_amp and not apex_amp_imported:
        raise ImportError('Apex not installed (https://github.com/NVIDIA/apex)')

    ##################################################################################################################
    # Data
    ##################################################################################################################
    print('Loading data...')
    # Load data
    dataset = data.load_dataset(opt, True)
    trainset = dataset.get_fold('train')
    valset = dataset.get_fold('val')
    # Change validation sequence length, if specified
    if opt.seq_len_test is not None:
        valset.change_seq_len(opt.seq_len_test)

    # Handle random seed for dataloader workers
    def worker_init_fn(worker_id):
        np.random.seed((opt.seed + itr + opt.local_rank * opt.n_workers + worker_id) % (2**32 - 1))

    # Dataloader
    sampler = None
    shuffle = True
    if opt.n_gpu > 1:
        # Let the distributed sampler shuffle for the distributed case
        sampler = torch.utils.data.distributed.DistributedSampler(trainset)
        shuffle = False
    train_loader = DataLoader(trainset, batch_size=opt.batch_size, collate_fn=data.collate_fn, sampler=sampler,
                              num_workers=opt.n_workers, shuffle=shuffle, drop_last=True, pin_memory=True,
                              worker_init_fn=worker_init_fn)
    val_loader = DataLoader(valset, batch_size=opt.batch_size_test, collate_fn=data.collate_fn,
                            num_workers=opt.n_workers, shuffle=True, drop_last=True, pin_memory=True,
                            worker_init_fn=worker_init_fn) if opt.local_rank == 0 else None

    ##################################################################################################################
    # Model
    ##################################################################################################################
    # Build model
    print('Building model...')
    model = srvp.StochasticLatentResidualVideoPredictor(opt.nx, opt.nc, opt.nf, opt.nhx, opt.ny, opt.nz,
                                                        opt.skipco, opt.nt_inf, opt.nh_inf, opt.nlayers_inf,
                                                        opt.nh_res, opt.nlayers_res, opt.archi)
    model.init(res_gain=opt.res_gain)
    # Make the batch norms in the model synchronized in the distributed case
    if opt.n_gpu > 1:
        if opt.apex_amp:
            from apex.parallel import convert_syncbn_model
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.to(device)

    ##################################################################################################################
    # Optimizer
    ##################################################################################################################
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    opt.n_iter = opt.lr_scheduling_burnin + opt.lr_scheduling_n_iter
    lr_sch_n_iter = opt.lr_scheduling_n_iter
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda i: max(0, (lr_sch_n_iter - i) / lr_sch_n_iter))

    ##################################################################################################################
    # Automatic Mixed Precision
    ##################################################################################################################
    scaler = None
    if opt.torch_amp:
        scaler = torch_amp.GradScaler()
    if opt.apex_amp:
        model, optimizer = apex_amp.initialize(model, optimizer, opt_level=opt.amp_opt_lvl,
                                               keep_batchnorm_fp32=opt.keep_batchnorm_fp32,
                                               verbosity=opt.apex_verbose)

    ##################################################################################################################
    # Multi GPU
    ##################################################################################################################
    if opt.n_gpu > 1:
        if opt.apex_amp:
            from apex.parallel import DistributedDataParallel
            forward_fn = DistributedDataParallel(model)
        else:
            forward_fn = torch.nn.parallel.DistributedDataParallel(model)
    else:
        forward_fn = model

    ##################################################################################################################
    # Training
    ##################################################################################################################
    cudnn.benchmark = True  # Activate benchmarks to select the fastest algorithms
    assert opt.n_iter > 0
    itr = 0
    finished = False
    # Progress bar
    if opt.local_rank == 0:
        pb = tqdm(total=opt.n_iter, ncols=0)
    # Current and best model evaluation metric (lower is better)
    val_metric = None
    best_val_metric = None
    try:
        while not finished:
            if sampler is not None:
                sampler.set_epoch(opt.seed + itr)
            # -------- TRAIN --------
            for batch in train_loader:
                # Stop when the given number of optimization steps have been done
                if itr >= opt.n_iter:
                    finished = True
                    status_code = 0
                    break
                itr += 1
                model.train()
                # Optimization step on batch
                # Allow PyTorch's mixed-precision computations if required while ensuring retrocompatibility
                with (torch_amp.autocast() if opt.torch_amp else nullcontext()):
                    loss, nll, kl_y_0, kl_z = train(forward_fn, optimizer, scaler, batch, device, opt)
                # Learning rate scheduling
                if itr >= opt.lr_scheduling_burnin:
                    lr_scheduler.step()
                # Evaluation and model saving are performed on the process with local rank zero
                if opt.local_rank == 0:
                    # Evaluation
                    if itr % opt.val_interval == 0:
                        model.eval()
                        val_metric = evaluate(forward_fn, val_loader, device, opt)
                        if best_val_metric is None or best_val_metric > val_metric:
                            best_val_metric = val_metric
                            torch.save(model.state_dict(), os.path.join(opt.save_path, 'model_best.pt'))
                    # Checkpointing
                    if opt.chkpt_interval is not None and itr % opt.chkpt_interval == 0:
                        torch.save(model.state_dict(), os.path.join(opt.save_path, f'model_{itr}.pt'))
                # Progress bar
                if opt.local_rank == 0:
                    pb.set_postfix({'loss': loss, 'nll': nll, 'kl_y_0': kl_y_0, 'kl_z': kl_z,
                                    'val_metric': val_metric, 'best_val_metric': best_val_metric},
                                   refresh=False)
                    pb.update()
    except KeyboardInterrupt:
        status_code = 130
    if opt.local_rank == 0:
        pb.close()
    # Save model
    print('Saving...')
    if opt.local_rank == 0:
        torch.save(model.state_dict(), os.path.join(opt.save_path, 'model.pt'))
    print('Done')
    return status_code
def main():
    setup_default_logging()
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logging.info('Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        logging.error('Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args,
                                        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                                        use_amp=use_amp, model_ema=model_ema)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                ema_eval_metrics = validate(model_ema.ema, loader_eval, validate_loss_fn, args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])

            update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, args, epoch=epoch + 1, model_ema=model_ema, metric=save_metric)
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train():
    if args.local_rank == 0:
        logger.info('Initializing....')

    cudnn.enable = True
    cudnn.benchmark = True
    # torch.manual_seed(1)
    # torch.cuda.manual_seed(1)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    if args.local_rank == 0:
        write_config_into_log(cfg)

    if args.local_rank == 0:
        logger.info('Building model......')
    if cfg.pretrained:
        model = make_model(cfg)
        model.load_param(cfg)
        if args.local_rank == 0:
            logger.info('Loaded pretrained model from {0}'.format(cfg.pretrained))
    else:
        model = make_model(cfg)

    if args.sync_bn:
        if args.local_rank == 0:
            logging.info("using apex synced BN")
        model = convert_syncbn_model(model)

    model.cuda()

    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = torch.nn.DataParallel(model)

    optimizer = torch.optim.Adam(
        [{'params': model.module.base.parameters(), 'lr': cfg.get_lr(0)[0]},
         {'params': model.module.classifiers.parameters(), 'lr': cfg.get_lr(0)[1]}],
        weight_decay=cfg.weight_decay)

    celoss = nn.CrossEntropyLoss().cuda()
    softloss = SoftLoss()
    sp_kd_loss = SP_KD_Loss()
    criterions = [celoss, softloss, sp_kd_loss]

    cfg.batch_size = cfg.batch_size // args.world_size
    cfg.num_workers = cfg.num_workers // args.world_size
    train_loader, val_loader = make_dataloader(cfg)

    if args.local_rank == 0:
        logger.info('Begin training......')
    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train_one_epoch(train_loader, val_loader, model, criterions, optimizer, epoch, cfg)

        total_acc = test(cfg, val_loader, model, epoch)
        if args.local_rank == 0:
            with open(cfg.test_log, 'a+') as f:
                f.write('Epoch {0}: Acc is {1:.4f}\n'.format(epoch, total_acc))
            torch.save(obj=model.state_dict(),
                       f=os.path.join(cfg.snapshot_dir, 'ep{}_acc{:.4f}.pth'.format(epoch, total_acc)))
            logger.info('Model saved')
def main():
    global n_eval_epoch

    ## dataloader
    dataset_train = ImageNet(datapth, mode='train', cropsize=cropsize)
    sampler_train = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True)
    batch_sampler_train = torch.utils.data.sampler.BatchSampler(sampler_train, batchsize, drop_last=True)
    dl_train = DataLoader(dataset_train,
                          batch_sampler=batch_sampler_train,
                          num_workers=num_workers,
                          pin_memory=True)
    dataset_eval = ImageNet(datapth, mode='val', cropsize=cropsize)
    sampler_val = torch.utils.data.distributed.DistributedSampler(dataset_eval, shuffle=False)
    batch_sampler_val = torch.utils.data.sampler.BatchSampler(sampler_val, batchsize * 2, drop_last=False)
    dl_eval = DataLoader(dataset_eval,
                         batch_sampler=batch_sampler_val,
                         num_workers=4,
                         pin_memory=True)
    n_iters_per_epoch = len(dataset_train) // n_gpus // batchsize
    n_iters = n_epoches * n_iters_per_epoch

    ## model
    # model = EfficientNet(model_type, n_classes)
    model = build_model(**model_args)
    ## sync bn
    # if use_sync_bn: model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    init_model_weights(model)
    model.cuda()
    if use_sync_bn:
        model = parallel.convert_syncbn_model(model)
    crit = nn.CrossEntropyLoss()
    # crit = LabelSmoothSoftmaxCEV3(lb_smooth)
    # crit = SoftmaxCrossEntropyV2()

    ## optimizer
    optim = set_optimizer(model, lr, opt_wd, momentum, nesterov=nesterov)

    ## apex
    model, optim = amp.initialize(model, optim, opt_level=fp16_level)

    ## ema
    ema = EMA(model, ema_alpha)

    ## ddp training
    model = parallel.DistributedDataParallel(model, delay_allreduce=True)
    # local_rank = dist.get_rank()
    # model = nn.parallel.DistributedDataParallel(
    #     model, device_ids=[local_rank, ], output_device=local_rank
    # )

    ## log meters
    time_meter = TimeMeter(n_iters)
    loss_meter = AvgMeter()
    logger = logging.getLogger()

    # for mixup
    label_encoder = OnehotEncoder(n_classes=model_args['n_classes'], lb_smooth=lb_smooth)
    mixuper = MixUper(mixup_alpha, mixup=mixup)

    ## train loop
    for e in range(n_epoches):
        sampler_train.set_epoch(e)
        model.train()
        for idx, (im, lb) in enumerate(dl_train):
            im, lb = im.cuda(), lb.cuda()
            # lb = label_encoder(lb)
            # im, lb = mixuper(im, lb)
            optim.zero_grad()
            logits = model(im)
            loss = crit(logits, lb)  # + cal_l2_loss(model, weight_decay)
            # loss.backward()
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
            optim.step()
            torch.cuda.synchronize()
            ema.update_params()
            time_meter.update()
            loss_meter.update(loss.item())
            if (idx + 1) % 200 == 0:
                t_intv, eta = time_meter.get()
                lr_log = scheduler.get_lr_ratio() * lr
                msg = 'epoch: {}, iter: {}, lr: {:.4f}, loss: {:.4f}, time: {:.2f}, eta: {}'.format(
                    e + 1, idx + 1, lr_log, loss_meter.get()[0], t_intv, eta)
                logger.info(msg)
        scheduler.step()
        torch.cuda.empty_cache()
        if (e + 1) % n_eval_epoch == 0:
            if e > 50:
                n_eval_epoch = 5
            logger.info('evaluating...')
            acc_1, acc_5, acc_1_ema, acc_5_ema = evaluate(ema, dl_eval)
            msg = 'epoch: {}, naive_acc1: {:.4}, naive_acc5: {:.4}, ema_acc1: {:.4}, ema_acc5: {:.4}'.format(
                e + 1, acc_1, acc_5, acc_1_ema, acc_5_ema)
            logger.info(msg)

    if dist.is_initialized() and dist.get_rank() == 0:
        torch.save(model.module.state_dict(), './res/model_final.pth')
        torch.save(ema.ema_model.state_dict(), './res/model_final_ema.pth')
def _init_network(self, **kwargs):
    load_only = kwargs.get('load_only', False)
    if not self.num_class and self._problem_type != REGRESSION:
        raise ValueError('This is a classification problem and we are not able to create network when '
                         '`num_class` is unknown. It should be inferred from dataset or resumed from saved states.')
    assert len(self.classes) == self.num_class

    # Disable syncBatchNorm as it's only supported on DDP
    if self._train_cfg.sync_bn:
        self._logger.info('Disable Sync batch norm as it is not supported for now.')
        update_cfg(self._cfg, {'train': {'sync_bn': False}})

    # ctx
    self.found_gpu = False
    valid_gpus = []
    if self._cfg.gpus:
        valid_gpus = self._torch_validate_gpus(self._cfg.gpus)
        self.found_gpu = True
        if not valid_gpus:
            self.found_gpu = False
            self._logger.warning(
                'No gpu detected, fallback to cpu. You can ignore this warning if this is intended.')
        elif len(valid_gpus) != len(self._cfg.gpus):
            self._logger.warning(
                f'Loaded on gpu({valid_gpus}), different from gpu({self._cfg.gpus}).')
    self.ctx = [torch.device(f'cuda:{gid}') for gid in valid_gpus] if self.found_gpu else [torch.device('cpu')]
    self.valid_gpus = valid_gpus

    if not self.found_gpu and self.use_amp:
        self.use_amp = None
        self._logger.warning('Training on cpu. AMP disabled.')
        update_cfg(self._cfg, {'misc': {'amp': False, 'apex_amp': False, 'native_amp': False}})
    if not self.found_gpu and self._misc_cfg.prefetcher:
        self._logger.warning('Training on cpu. Prefetcher disabled.')
        update_cfg(self._cfg, {'misc': {'prefetcher': False}})
        self._logger.warning('Training on cpu. SyncBatchNorm disabled.')
        update_cfg(self._cfg, {'train': {'sync_bn': False}})

    random_seed(self._misc_cfg.seed)

    if not self.net:
        self.net = create_model(
            self._img_cls_cfg.model,
            pretrained=self._img_cls_cfg.pretrained and not load_only,
            num_classes=max(self.num_class, 1),
            global_pool=self._img_cls_cfg.global_pool_type,
            drop_rate=self._augmentation_cfg.drop,
            drop_path_rate=self._augmentation_cfg.drop_path,
            drop_block_rate=self._augmentation_cfg.drop_block,
            bn_momentum=self._train_cfg.bn_momentum,
            bn_eps=self._train_cfg.bn_eps,
            scriptable=self._misc_cfg.torchscript)
        self._logger.info(
            f'Model {safe_model_name(self._img_cls_cfg.model)} created, param count: '
            f'{sum([m.numel() for m in self.net.parameters()])}')
    else:
        self._logger.info('Use user provided model. Neglect model in config.')
        out_features = list(self.net.children())[-1].out_features
        if self._problem_type != REGRESSION:
            assert out_features == self.num_class, \
                f'Custom model out_feature {out_features} != num_class {self.num_class}.'
        else:
            assert out_features == 1, \
                f'Regression problem expects num_out_feature == 1, got {out_features} instead.'

    resolve_data_config(self._cfg, model=self.net)
    self.net = self.net.to(self.ctx[0])

    # setup synchronized BatchNorm
    if self._train_cfg.sync_bn:
        if has_apex and self.use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            self.net = convert_syncbn_model(self.net)
        else:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        self._logger.info(
            'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
            'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if self._misc_cfg.torchscript:
        assert not self.use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not self._train_cfg.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        self.net = torch.jit.script(self.net)
def main():
    args, cfg = parse_config_args('nni.cream.supernet')

    # resolve logging
    output_dir = os.path.join(cfg.SAVE_PATH,
                              "{}-{}".format(datetime.date.today().strftime('%m%d'), cfg.MODEL))
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, "train.log"))
    else:
        logger = None

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process %d with %d GPUs.', args.local_rank, cfg.NUM_GPU)

    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # generate supernet
    model, sta_num, resolution = gen_supernet(
        flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM,
        flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP,
        resunit=cfg.SUPERNET.RESUNIT,
        dil_conv=cfg.SUPERNET.DIL_CONV,
        slice=cfg.SUPERNET.SLICE,
        verbose=cfg.VERBOSE,
        logger=logger)

    # number of choice blocks in supernet
    choice_num = len(model.blocks[7])
    if args.local_rank == 0:
        logger.info('Supernet created, param count: %d', (sum([m.numel() for m in model.parameters()])))
        logger.info('resolution: %d', (resolution))
        logger.info('choice number: %d', (choice_num))

    # initialize flops look-up table
    model_est = FlopsEst(model)
    flops_dict, flops_fixed = model_est.flops_dict, model_est.flops_fixed

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if cfg.AUTO_RESUME:
        optimizer_state, resume_epoch = resume_checkpoint(model, cfg.RESUME_PATH)

    # create optimizer and resume from checkpoint
    optimizer = create_optimizer_supernet(cfg, model, USE_APEX)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state['optimizer'])
    model = model.cuda()

    # convert model to distributed mode
    if cfg.BATCHNORM.SYNC_BN:
        try:
            if USE_APEX:
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logger.info('Converted model to use Synchronized BatchNorm.')
        except Exception as exception:
            logger.info('Failed to enable Synchronized BatchNorm. '
                        'Install Apex or Torch >= 1.1 with Exception %s', exception)
    if USE_APEX:
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
        # can use device str in Torch >= 1.1
        model = DDP(model, device_ids=[args.local_rank])

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_supernet_scheduler(cfg, optimizer)

    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logger.info('Scheduled epochs: %d', num_epochs)

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir):
        logger.info('Training folder does not exist at: %s', train_dir)
        sys.exit()
    dataset_train = Dataset(train_dir)
    loader_train = create_loader(dataset_train,
                                 input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
                                 batch_size=cfg.DATASET.BATCH_SIZE,
                                 is_training=True,
                                 use_prefetcher=True,
                                 re_prob=cfg.AUGMENTATION.RE_PROB,
                                 re_mode=cfg.AUGMENTATION.RE_MODE,
                                 color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
                                 interpolation='random',
                                 num_workers=cfg.WORKERS,
                                 distributed=True,
                                 collate_fn=None,
                                 crop_pct=DEFAULT_CROP_PCT,
                                 mean=IMAGENET_DEFAULT_MEAN,
                                 std=IMAGENET_DEFAULT_STD)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.isdir(eval_dir):
        logger.info('Validation folder does not exist at: %s', eval_dir)
        sys.exit()
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(dataset_eval,
                                input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
                                batch_size=4 * cfg.DATASET.BATCH_SIZE,
                                is_training=False,
                                use_prefetcher=True,
                                num_workers=cfg.WORKERS,
                                distributed=True,
                                crop_pct=DEFAULT_CROP_PCT,
                                mean=IMAGENET_DEFAULT_MEAN,
                                std=IMAGENET_DEFAULT_STD,
                                interpolation=cfg.DATASET.INTERPOLATION)

    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    mutator = RandomMutator(model)

    trainer = CreamSupernetTrainer(model, train_loss_fn, validate_loss_fn,
                                   optimizer, num_epochs, loader_train, loader_eval,
                                   mutator=mutator,
                                   batch_size=cfg.DATASET.BATCH_SIZE,
                                   log_frequency=cfg.LOG_INTERVAL,
                                   meta_sta_epoch=cfg.SUPERNET.META_STA_EPOCH,
                                   update_iter=cfg.SUPERNET.UPDATE_ITER,
                                   slices=cfg.SUPERNET.SLICE,
                                   pool_size=cfg.SUPERNET.POOL_SIZE,
                                   pick_method=cfg.SUPERNET.PICK_METHOD,
                                   choice_num=choice_num,
                                   sta_num=sta_num,
                                   acc_gap=cfg.ACC_GAP,
                                   flops_dict=flops_dict,
                                   flops_fixed=flops_fixed,
                                   local_rank=args.local_rank,
                                   callbacks=[LRSchedulerCallback(lr_scheduler),
                                              ModelCheckpoint(output_dir)])

    trainer.train()
def main():
    logger = None
    output_dir = ''
    setup_default_logging()
    args = parser.parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        import random
        port = random.randint(0, 50000)
        torch.distributed.init_process_group(
            backend='nccl', init_method='env://'
        )  # tcp://127.0.0.1:{}'.format(port), rank=args.local_rank, world_size=8)
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model, sta_num, size_factor = _gen_supernet_large(
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        resunit=args.resunit,
        dil_conv=args.dil_conv,
        slice=args.slice)

    if args.local_rank == 0:
        print("Model Searched Using FLOPs {}".format(size_factor * 32))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    if args.local_rank == 0:
        '''output dir'''
        output_base = args.output if args.output else './experiments'
        exp_name = args.model
        output_dir = get_outdir(output_base, 'search', exp_name)
        log_file = os.path.join(output_dir, "search.log")
        logger = get_logger(log_file)
    if args.local_rank == 0:
        logger.info(args)

    choice_num = 6
    if args.resunit:
        choice_num += 1
    if args.dil_conv:
        choice_num += 2
    if args.local_rank == 0:
        logger.info("Choice_num: {}".format(choice_num))

    model_est = LatencyEst(model)

    if os.path.exists(args.initial_checkpoint):
        load_checkpoint(model, args.initial_checkpoint)

    if args.local_rank == 0:
        logger.info('Model %s created, param count: %d' %
                    (args.model, sum([m.numel() for m in model.parameters()])))

    # data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if args.resume:
        optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer_supernet(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state['optimizer'])

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logger.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logger.info('Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            model = DDP(model, device_ids=[args.local_rank],
                        find_unused_parameters=True)  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)

    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logger.info('Scheduled epochs: {}'.format(num_epochs))

    if args.tiny:
        from lib.dataset.tiny_imagenet import get_newimagenet
        [loader_train, loader_eval], [train_sampler, test_sampler] = get_newimagenet(args.data, args.batch_size)
    else:
        train_dir = os.path.join(args.data, 'train')
        if not os.path.exists(train_dir):
            logger.error('Training folder does not exist at: {}'.format(train_dir))
            exit(1)
        dataset_train = Dataset(train_dir)
        collate_fn = None
        loader_train = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            batch_size=args.batch_size,
            is_training=True,
            use_prefetcher=args.prefetcher,
            re_prob=args.reprob,
            re_mode=args.remode,
            color_jitter=args.color_jitter,
            interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            collate_fn=collate_fn,
        )

        eval_dir = os.path.join(args.data, 'val')
        if not os.path.isdir(eval_dir):
            logger.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
        dataset_eval = Dataset(eval_dir)
        loader_eval = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=4 * args.batch_size,
            is_training=False,
            use_prefetcher=args.prefetcher,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
        )

    if args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    best_children_pool = []
    if args.local_rank == 0:
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                if args.tiny:
                    train_sampler.set_epoch(epoch)
                else:
                    loader_train.sampler.set_epoch(epoch)

            '''2020.10.19 large_model=True !'''
            train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args,
                                        CHOICE_NUM=choice_num, lr_scheduler=lr_scheduler, saver=saver,
                                        output_dir=output_dir, logger=logger, val_loader=loader_eval,
                                        use_amp=use_amp, model_ema=model_ema, est=model_est,
                                        sta_num=sta_num, large_model=True)

            # eval_metrics = OrderedDict([('loss', 0.0), ('prec1', 0.0), ('prec5', 0.0)])
            eval_metrics = validate(model, loader_eval, validate_loss_fn, args,
                                    CHOICE_NUM=choice_num, sta_num=sta_num)

            update_summary(epoch, train_metrics, eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric)
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
_, loader = get_train_loader(args.data_root, args.num_src, total_steps, args.batch_size,
                             {
                                 'interval_scale': args.interval_scale,
                                 'resize_width': resize_width,
                                 'resize_height': resize_height,
                                 'crop_width': crop_width,
                                 'crop_height': crop_height
                             },
                             num_workers=args.num_workers)

model = Model()
model.cuda()
model = apex_parallel.convert_syncbn_model(model)
print('Number of model parameters: {}'.format(
    sum([p.data.nelement() for p in model.parameters() if p.requires_grad])))
compute_loss = Loss()
model = nn.DataParallel(model)

if args.load_path is None:
    for m in model.modules():
        if any([isinstance(m, T) for T in [nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d]]):
            if m.weight.requires_grad:
                nn.init.xavier_uniform_(m.weight)
def main(fold_i=0, data_=None, train_index=None, val_index=None):
    setup_default_logging()
    args, args_text = _parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    best_score = 0.0
    args.output = args.output + 'fold_' + str(fold_i)
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        if fold_i == 0:
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-6)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(model,
                               decay=args.model_ema_decay,
                               device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)
    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(20))

    ## create DataLoader
    train_data = data_.iloc[train_index, :].reset_index(drop=True)
    dataset_train = RiaddDataSet(image_ids=train_data, baseImgPath=args.data)
    val_data = data_.iloc[val_index, :].reset_index(drop=True)
    dataset_eval = RiaddDataSet(image_ids=val_data, baseImgPath=args.data)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    train_trans = get_riadd_train_transforms(args)
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        transform=train_trans)

    valid_trans = get_riadd_valid_transforms(args)
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
        transform=valid_trans)

    # # setup loss function
    # if args.jsd:
    #     assert num_aug_splits > 1  # JSD only valid with aug splits set
    #     train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    # elif mixup_active:
    #     # smoothing is handled with mixup target transform
    #     train_loss_fn = SoftTargetCrossEntropy().cuda()
    # elif args.smoothing:
    #     train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    # else:
    #     train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.BCEWithLogitsLoss().cuda()
    train_loss_fn = nn.BCEWithLogitsLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    vis = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(model=model, optimizer=optimizer, args=args,
                                model_ema=model_ema, amp_scaler=loss_scaler,
                                checkpoint_dir=output_dir, recovery_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
        vis = Visualizer(env=args.output)

    try:
        for epoch in range(0, args.epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) # predictions = gather_tensor(eval_metrics) predictions, valid_label = gather_predict_label(eval_metrics, args) valid_label = valid_label.detach().cpu().numpy() predictions = predictions.detach().cpu().numpy() score, scores = get_score(valid_label, predictions) ##visdom if vis is not None: vis.plot_curves({'None': epoch}, iters=epoch, title='None', xlabel='iters', ylabel='None') vis.plot_curves( {'learing rate': optimizer.param_groups[0]['lr']}, iters=epoch, title='lr', xlabel='iters', ylabel='learing rate') vis.plot_curves({'train loss': float(train_metrics['loss'])}, iters=epoch, title='train loss', xlabel='iters', ylabel='train loss') vis.plot_curves({'val loss': float(eval_metrics['loss'])}, iters=epoch, title='val loss', xlabel='iters', ylabel='val loss') vis.plot_curves({'val score': float(score)}, iters=epoch, title='val score', xlabel='iters', ylabel='val score') if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch # lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) lr_scheduler.step(epoch + 1, score) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None and score > best_score: # save proper checkpoint with eval metric best_score = score save_metric = best_score best_metric, best_epoch = saver.save_checkpoint( epoch, metric=save_metric) del model del optimizer torch.cuda.empty_cache() except KeyboardInterrupt: pass
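# A minimal sketch of the "save only when the validation score improves" logic used in the
# loop above, written with plain `torch.save` instead of timm's CheckpointSaver. The path
# and the `score` argument are placeholders for illustration.
import torch


def maybe_save_best(model, optimizer, epoch, score, best_score, path='best.pth'):
    # Keep the best score seen so far and persist a checkpoint only when it improves.
    if score > best_score:
        best_score = score
        torch.save({'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_score': best_score}, path)
    return best_score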
def main(): args, args_text = _parse_args() if args.local_rank == 0: args.local_rank = local_rank print("rank:{0},word_size:{1},dist_url:{2}".format( local_rank, word_size, dist_url)) if args.model_selection == 470: arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1], [3, 3, 3, 3], [3, 3, 3, 3], [0]] arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_se0.25'], # stage 1, 112x112 in [ 'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25' ], # stage 2, 56x56 in [ 'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s1_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25' ], # stage 3, 28x28 in [ 'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r2_k3_s1_e4_c80_se0.25' ], # stage 4, 14x14in [ 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25' ], # stage 5, 14x14in [ 'ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25' ], # stage 6, 7x7 in ['cn_r1_k1_s1_c320_se0.25'], ] args.img_size = 224 elif args.model_selection == 42: arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]] arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_se0.25'], # stage 1, 112x112 in ['ir_r1_k3_s2_e4_c24_se0.25'], # stage 2, 56x56 in ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25'], # stage 3, 28x28 in ['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e6_c80_se0.25'], # stage 4, 14x14in [ 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25' ], # stage 5, 14x14in ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25'], # stage 6, 7x7 in ['cn_r1_k1_s1_c320_se0.25'], ] args.img_size = 96 elif args.model_selection == 14: arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]] arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_se0.25'], # stage 1, 112x112 in ['ir_r1_k3_s2_e4_c24_se0.25'], # stage 2, 56x56 in ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k3_s2_e4_c40_se0.25'], # stage 3, 28x28 in ['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e4_c80_se0.25'], # stage 4, 14x14in ['ir_r1_k3_s1_e6_c96_se0.25'], # stage 5, 14x14in ['ir_r1_k5_s2_e6_c192_se0.25'], # stage 6, 7x7 in ['cn_r1_k1_s1_c320_se0.25'], ] args.img_size = 64 elif args.model_selection == 112: arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]] arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_se0.25'], # stage 1, 112x112 in ['ir_r1_k3_s2_e4_c24_se0.25'], # stage 2, 56x56 in ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k3_s2_e4_c40_se0.25'], # stage 3, 28x28 in ['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e6_c80_se0.25'], # stage 4, 14x14in [ 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25' ], # stage 5, 14x14in ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25'], # stage 6, 7x7 in ['cn_r1_k1_s1_c320_se0.25'], ] args.img_size = 160 elif args.model_selection == 285: arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_se0.25'], # stage 1, 112x112 in ['ir_r1_k3_s2_e4_c24_se0.25'], # stage 2, 56x56 in ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25'], # stage 3, 28x28 in [ 'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s2_e6_c80_se0.25' ], # stage 4, 14x14in [ 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25' ], # stage 5, 14x14in [ 
'ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25' ], # stage 6, 7x7 in ['cn_r1_k1_s1_c320_se0.25'], ] args.img_size = 224 elif args.model_selection == 600: arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3], [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]] arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_e1_c16_se0.25'], # stage 1, 112x112 in [ 'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s2_e4_c24_se0.25' ], # stage 2, 56x56 in [ 'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25' ], # stage 3, 28x28 in [ 'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25' ], # stage 4, 14x14in [ 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25' ], # stage 5, 14x14in [ 'ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25' ], # stage 6, 7x7 in ['cn_r1_k1_s1_c320_se0.25'], ] args.img_size = 224 else: raise ValueError("Not Supported!") model = _gen_childnet(arch_list, arch_def, num_classes=args.num_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, global_pool=args.gp, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, pool_bn=args.pool_bn, zero_gamma=args.zero_gamma) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './experiments' exp_name = '-'.join([ args.name, datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'retrain', exp_name) logger = get_logger(os.path.join(output_dir, 'retrain.log')) writer = SummaryWriter(os.path.join(output_dir, 'runs')) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'config.yaml'), 'w') as f: f.write(args_text) else: writer = None logger = None args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1 and args.local_rank == 0: logger.warning( 'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.' ) args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank args.distributed = True if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) # torch.distributed.init_process_group(backend='nccl', init_method='env://') torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.local_rank == 0: if args.distributed: logger.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logger.info('Training with a single process on %d GPUs.' 
% args.num_gpu) seed = args.seed torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False if args.local_rank == 0: scope(model, input_size=(3, 224, 224)) if os.path.exists(args.initial_checkpoint): load_checkpoint(model, args.initial_checkpoint) if args.local_rank == 0: logger.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) if args.num_gpu > 1: if args.amp: if args.local_rank == 0: logger.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.' ) args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: model.cuda() optimizer = create_optimizer(args, model) use_amp = False if has_apex and args.amp: model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True if args.local_rank == 0: logger.info('NVIDIA APEX {}. ' ' {}.'.format('installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) # optionally resume from a checkpoint resume_state = {} resume_epoch = None if args.resume: resume_state, resume_epoch = resume_checkpoint(model, args.resume) if resume_state and not args.no_resume_opt: if 'optimizer' in resume_state: if args.local_rank == 0: logging.info('Restoring Optimizer state from checkpoint') optimizer.load_state_dict(resume_state['optimizer']) if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if args.local_rank == 0: logging.info('Restoring NVIDIA AMP state from checkpoint') amp.load_state_dict(resume_state['amp']) del resume_state model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume=args.resume) if args.distributed: if args.sync_bn: assert not args.split_bn try: if has_apex: model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( model) if args.local_rank == 0: logger.info( 'Converted model to use Synchronized BatchNorm.') except Exception as e: if args.local_rank == 0: logger.error( 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1' ) if has_apex: model = DDP(model, delay_allreduce=True) else: if args.local_rank == 0: logger.info( "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP." 
) model = DDP(model, device_ids=[args.local_rank ]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir) and args.local_rank == 0: logger.error('Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) eval_dir = os.path.join(args.data, 'val') if not os.path.exists(eval_dir) and args.local_rank == 0: logger.error( 'Validation folder does not exist at: {}'.format(eval_dir)) exit(1) dataset_eval = Dataset(eval_dir) loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=0, interpolation=args.train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=None, pin_memory=args.pin_mem, ) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) if args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = train_loss_fn lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: logger.info('Scheduled epochs: {}'.format(num_epochs)) try: best_record = 0 best_ep = 0 for epoch in range(start_epoch, num_epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp, model_ema=model_ema, logger=logger, writer=writer) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: logging.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(epoch, model, loader_eval, validate_loss_fn, args, logger=logger, writer=writer) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(epoch, model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)', logger=logger, writer=writer) eval_metrics = ema_eval_metrics if lr_scheduler is not None: lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, epoch=epoch, model_ema=model_ema, 
                    metric=save_metric, use_amp=use_amp)
            if best_record < eval_metrics[eval_metric]:
                best_record = eval_metrics[eval_metric]
                best_ep = epoch
            if args.local_rank == 0:
                logger.info('*** Best metric: {0} (epoch {1})'.format(
                    best_record, best_ep))
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
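# A rough sketch of the resume logic used above: restore model weights, then optionally the
# optimizer (and Apex AMP) state when not starting a fresh run. The checkpoint path is a
# placeholder; the apex import is optional and guarded, since Apex may not be installed.
import torch


def resume_from_checkpoint(model, optimizer, path, no_resume_opt=False):
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    if not no_resume_opt and 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])
    try:
        from apex import amp  # only present when NVIDIA Apex is installed
        if 'amp' in state and hasattr(amp, 'load_state_dict'):
            amp.load_state_dict(state['amp'])
    except ImportError:
        pass
    return state.get('epoch', 0)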
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logging.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) if args.binarizable: Model_binary_patch(model) if args.local_rank == 0: logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) if args.num_gpu > 1: if args.amp: logging.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.') args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: model.cuda() optimizer = create_optimizer(args, model) use_amp = False if has_apex and args.amp: print('Using amp.') model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True else: print('Do NOT use amp.') if args.local_rank == 0: logging.info('NVIDIA APEX {}. AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) # optionally resume from a checkpoint resume_state = {} resume_epoch = None if args.resume: resume_state, resume_epoch = resume_checkpoint(model, args.resume) if resume_state and not args.no_resume_opt: if 'optimizer' in resume_state: if args.local_rank == 0: logging.info('Restoring Optimizer state from checkpoint') optimizer.load_state_dict(resume_state['optimizer']) if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if args.local_rank == 0: logging.info('Restoring NVIDIA AMP state from checkpoint') amp.load_state_dict(resume_state['amp']) resume_state = None if args.freeze_binary: Model_freeze_binary(model) if args.distributed: if args.sync_bn: try: if has_apex: model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: logging.info('Converted model to use Synchronized BatchNorm.') except Exception as e: logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') if has_apex: model = DDP(model, delay_allreduce=True) else: if args.local_rank == 0: logging.info("Using torch DistributedDataParallel. 
Install NVIDIA Apex for Apex DDP.") model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) # start_epoch = 0 # if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if args.reset_lr_scheduler is not None: lr_scheduler.base_values = len(lr_scheduler.base_values)*[args.reset_lr_scheduler] lr_scheduler.step(start_epoch) if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) # Using pruner to get sparse weights if args.prune: pruner = Pruner(model, 0, 100, 0.75) else: pruner = None dataset_train = torchvision.datasets.CIFAR100(root='~/Downloads/CIFAR100', train=True, download=True) collate_fn = None if args.prefetcher and args.mixup > 0: collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) loader_train = create_loader_CIFAR100( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, rand_erase_prob=args.reprob, rand_erase_mode=args.remode, rand_erase_count=args.recount, color_jitter=args.color_jitter, auto_augment=args.aa, interpolation='random', mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, is_clean_data=args.clean_train, ) dataset_eval = torchvision.datasets.CIFAR100(root='~/Downloads/CIFAR100', train=False, download=True) loader_eval = create_loader_CIFAR100( dataset_eval, input_size=data_config['input_size'], batch_size=4 * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, ) if args.mixup > 0.: # smoothing is handled with mixup label transform train_loss_fn = SoftTargetCrossEntropy(multiplier=args.softmax_multiplier).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = train_loss_fn eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None saver_last_10_epochs = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False os.makedirs(output_dir+'/Top') os.makedirs(output_dir+'/Last') saver = CheckpointSaver(checkpoint_dir=output_dir + '/Top', decreasing=decreasing, max_history=10) # Save the results of the top 10 epochs saver_last_10_epochs = CheckpointSaver(checkpoint_dir=output_dir + '/Last', decreasing=decreasing, max_history=10) # Save the results of the last 10 epochs with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) f.write('==============================') f.write(model.__str__()) tensorboard_writer = SummaryWriter(output_dir) try: for epoch in range(start_epoch, num_epochs): global alpha alpha = get_alpha(epoch, args) if args.distributed: 
loader_train.sampler.set_epoch(epoch) if pruner: pruner.on_epoch_begin(epoch) # pruning train_metrics = train_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp, tensorboard_writer=tensorboard_writer, pruner = pruner) if pruner: pruner.print_statistics() eval_metrics = validate(model, loader_eval, validate_loss_fn, args, tensorboard_writer=tensorboard_writer, epoch=epoch) if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, epoch=epoch, metric=save_metric, use_amp=use_amp) if saver_last_10_epochs is not None: # save the checkpoint in the last 5 epochs _, _ = saver_last_10_epochs.save_checkpoint( model, optimizer, args, epoch=epoch, metric=epoch, use_amp=use_amp) except KeyboardInterrupt: pass if best_metric is not None: logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) logging.info('The checkpoint of the last epoch is: \n') logging.info(saver_last_10_epochs.checkpoint_files[0][0])
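# A small self-contained sketch of input mixup, roughly what FastCollateMixup and
# SoftTargetCrossEntropy implement in timm. This is a generic illustration, not the exact
# collate-based implementation wired into the loader above.
import numpy as np
import torch


def mixup_batch(x, y, alpha=0.2):
    # Blend each sample with a randomly permuted partner; return both label sets so the
    # loss can be mixed with the same coefficient.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    perm = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1.0 - lam) * x[perm]
    return mixed_x, y, y[perm], lam


def mixup_loss(criterion, logits, y_a, y_b, lam):
    return lam * criterion(logits, y_a) + (1.0 - lam) * criterion(logits, y_b)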
def train(): print("local_rank:", args.local_rank) cudnn.benchmark = True if args.deterministic: cudnn.benchmark = False cudnn.deterministic = True torch.manual_seed(args.local_rank) torch.set_printoptions(precision=10) torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group( backend='nccl', init_method='env://', ) torch.manual_seed(0) if not args.eval_net: train_ds = dataset_desc.Dataset('train') train_sampler = torch.utils.data.distributed.DistributedSampler( train_ds) train_loader = torch.utils.data.DataLoader( train_ds, batch_size=config.mini_batch_size, shuffle=False, drop_last=True, num_workers=4, sampler=train_sampler, pin_memory=True) val_ds = dataset_desc.Dataset('test') val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds) val_loader = torch.utils.data.DataLoader( val_ds, batch_size=config.val_mini_batch_size, shuffle=False, drop_last=False, num_workers=4, sampler=val_sampler) else: test_ds = dataset_desc.Dataset('test') test_loader = torch.utils.data.DataLoader( test_ds, batch_size=config.test_mini_batch_size, shuffle=False, num_workers=20) rndla_cfg = ConfigRandLA model = FFB6D(n_classes=config.n_objects, n_pts=config.n_sample_points, rndla_cfg=rndla_cfg, n_kps=config.n_keypoints) model = convert_syncbn_model(model) device = torch.device('cuda:{}'.format(args.local_rank)) print('local_rank:', args.local_rank) model.to(device) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) opt_level = args.opt_level model, optimizer = amp.initialize( model, optimizer, opt_level=opt_level, ) model = nn.DataParallel(model) # default value it = -1 # for the initialize value of `LambdaLR` and `BNMomentumScheduler` best_loss = 1e10 start_epoch = 1 # load status from checkpoint if args.checkpoint is not None: checkpoint_status = load_checkpoint(model, optimizer, filename=args.checkpoint[:-8]) if checkpoint_status is not None: it, start_epoch, best_loss = checkpoint_status if args.eval_net: assert checkpoint_status is not None, "Failed loadding model." 
if not args.eval_net: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) clr_div = 6 lr_scheduler = CyclicLR( optimizer, base_lr=1e-5, max_lr=1e-3, cycle_momentum=False, step_size_up=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus, step_size_down=config.n_total_epoch * train_ds.minibatch_per_epoch // clr_div // args.gpus, mode='triangular') else: lr_scheduler = None bnm_lmbd = lambda it: max( args.bn_momentum * args.bn_decay** (int(it * config.mini_batch_size / args.decay_step)), bnm_clip, ) bnm_scheduler = pt_utils.BNMomentumScheduler(model, bn_lambda=bnm_lmbd, last_epoch=it) it = max(it, 0) # for the initialize value of `trainer.train` if args.eval_net: model_fn = model_fn_decorator( FocalLoss(gamma=2), OFLoss(), args.test, ) else: model_fn = model_fn_decorator( FocalLoss(gamma=2).to(device), OFLoss().to(device), args.test, ) checkpoint_fd = config.log_model_dir trainer = Trainer( model, model_fn, optimizer, checkpoint_name=os.path.join(checkpoint_fd, "FFB6D"), best_name=os.path.join(checkpoint_fd, "FFB6D_best"), lr_scheduler=lr_scheduler, bnm_scheduler=bnm_scheduler, ) if args.eval_net: start = time.time() val_loss, res = trainer.eval_epoch(test_loader, is_test=True, test_pose=args.test_pose) end = time.time() print("\nUse time: ", end - start, 's') else: trainer.train(it, start_epoch, config.n_total_epoch, train_loader, None, val_loader, best_loss=best_loss, tot_iter=config.n_total_epoch * train_ds.minibatch_per_epoch // args.gpus, clr_div=clr_div) if start_epoch == config.n_total_epoch: _ = trainer.eval_epoch(val_loader)
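# A compact sketch of the BatchNorm-momentum decay driven by `bnm_lmbd` and
# `BNMomentumScheduler` above: walk the model and overwrite `momentum` on every BatchNorm
# layer as training progresses. The default constants here are illustrative, not the
# repository's configured values.
import torch.nn as nn


def set_bn_momentum(model, it, base=0.9, decay=0.5, decay_step=200000, batch_size=24, clip=1e-2):
    momentum = max(base * decay ** int(it * batch_size / decay_step), clip)
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.momentum = momentum
    return momentum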
def main(): args.can_print = (args.distributed and args.local_rank == 0) or (not args.distributed) log_out_dir = f'{RESULT_DIR}/logs/{args.out_dir}' os.makedirs(log_out_dir, exist_ok=True) if args.can_print: log = Logger() log.open(f'{log_out_dir}/log.train.txt', mode='a') else: log = None model_out_dir = f'{RESULT_DIR}/models/{args.out_dir}' if args.can_print: log.write( f'>> Creating directory if it does not exist:\n>> {model_out_dir}\n' ) os.makedirs(model_out_dir, exist_ok=True) # set cuda visible device os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id if args.distributed: torch.cuda.set_device(args.local_rank) # set random seeds torch.manual_seed(0) torch.cuda.manual_seed_all(0) np.random.seed(0) model_params = {} model_params['architecture'] = args.arch model_params['num_classes'] = args.num_classes model_params['in_channels'] = args.in_channels model_params['can_print'] = args.can_print model = init_network(model_params) # move network to gpu if args.distributed: dist.init_process_group(backend='nccl', init_method='env://') model = convert_syncbn_model(model) model.cuda() if args.distributed: model = DistributedDataParallel(model, delay_allreduce=True) else: model = DataParallel(model) # define loss function (criterion) try: criterion = eval(args.loss)().cuda() except: raise RuntimeError(f'Loss {args.loss} not available!') start_epoch = 0 best_score = 0 best_epoch = 0 # define scheduler try: scheduler = eval(args.scheduler)(model) except: raise RuntimeError(f'Scheduler {args.scheduler} not available!') # optionally resume from a checkpoint reset_epoch = True pretrained_file = None if args.model_file: reset_epoch = True pretrained_file = args.model_file if args.resume: reset_epoch = False pretrained_file = f'{model_out_dir}/{args.resume}' if pretrained_file and os.path.isfile(pretrained_file): # load checkpoint weights and update model and optimizer if args.can_print: log.write(f'>> Loading checkpoint:\n>> {pretrained_file}\n') checkpoint = torch.load(pretrained_file) if not reset_epoch: start_epoch = checkpoint['epoch'] best_epoch = checkpoint['best_epoch'] best_score = checkpoint['best_score'] model.module.load_state_dict(checkpoint['state_dict']) if args.can_print: if reset_epoch: log.write(f'>>>> loaded checkpoint:\n>>>> {pretrained_file}\n') else: log.write( f'>>>> loaded checkpoint:\n>>>> {pretrained_file} (epoch {checkpoint["epoch"]:.2f})\n' ) else: if args.can_print: log.write(f'>> No checkpoint found at {pretrained_file}\n') # Data loading code train_transform = eval(f'train_multi_augment{args.aug_version}') train_split_file = f'{DATA_DIR}/split/{args.split_type}/random_train_cv0.csv' valid_split_file = f'{DATA_DIR}/split/{args.split_type}/random_valid_cv0.csv' train_dataset = RetrievalDataset( args, train_split_file, transform=train_transform, data_type='train', ) valid_dataset = RetrievalDataset( args, valid_split_file, transform=None, data_type='valid', ) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset) valid_sampler = torch.utils.data.distributed.DistributedSampler( valid_dataset) else: train_sampler = RandomSampler(train_dataset) valid_sampler = SequentialSampler(valid_dataset) train_loader = DataLoader( train_dataset, sampler=train_sampler, batch_size=args.batch_size, drop_last=True, num_workers=args.workers, pin_memory=True, collate_fn=image_collate, ) valid_loader = DataLoader( valid_dataset, sampler=valid_sampler, batch_size=args.batch_size, drop_last=False, num_workers=args.workers, pin_memory=True, 
        collate_fn=image_collate,
    )
    train(args, train_loader, valid_loader, model, criterion, scheduler, log,
          best_epoch, best_score, start_epoch, model_out_dir)
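# A minimal sketch of the sampler selection above: a DistributedSampler when running with
# multiple processes, plain random/sequential samplers otherwise. The training loop should
# call `sampler.set_epoch(epoch)` each epoch so shuffling differs across epochs.
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


def build_loader(dataset, batch_size, distributed, is_train, workers=4):
    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train)
    else:
        sampler = RandomSampler(dataset) if is_train else SequentialSampler(dataset)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size,
                      drop_last=is_train, num_workers=workers, pin_memory=True)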
def main(): args, cfg = parse_config_args('super net training') # resolve logging output_dir = os.path.join( cfg.SAVE_PATH, "{}-{}".format(datetime.date.today().strftime('%m%d'), cfg.MODEL)) if args.local_rank == 0: logger = get_logger(os.path.join(output_dir, "train.log")) else: logger = None # initialize distributed parameters torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') if args.local_rank == 0: logger.info('Training on Process %d with %d GPUs.', args.local_rank, cfg.NUM_GPU) # fix random seeds torch.manual_seed(cfg.SEED) torch.cuda.manual_seed_all(cfg.SEED) np.random.seed(cfg.SEED) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # generate supernet model, sta_num, resolution = gen_supernet( flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM, flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM, num_classes=cfg.DATASET.NUM_CLASSES, drop_rate=cfg.NET.DROPOUT_RATE, global_pool=cfg.NET.GP, resunit=cfg.SUPERNET.RESUNIT, dil_conv=cfg.SUPERNET.DIL_CONV, slice=cfg.SUPERNET.SLICE, verbose=cfg.VERBOSE, logger=logger) # initialize meta matching networks MetaMN = MetaMatchingNetwork(cfg) # number of choice blocks in supernet choice_num = len(model.blocks[1][0]) if args.local_rank == 0: logger.info('Supernet created, param count: %d', (sum([m.numel() for m in model.parameters()]))) logger.info('resolution: %d', (resolution)) logger.info('choice number: %d', (choice_num)) #initialize prioritized board prioritized_board = PrioritizedBoard(cfg, CHOICE_NUM=choice_num, sta_num=sta_num) # initialize flops look-up table model_est = FlopsEst(model) # optionally resume from a checkpoint optimizer_state = None resume_epoch = None if cfg.AUTO_RESUME: optimizer_state, resume_epoch = resume_checkpoint( model, cfg.RESUME_PATH) # create optimizer and resume from checkpoint optimizer = create_optimizer_supernet(cfg, model, USE_APEX) if optimizer_state is not None: optimizer.load_state_dict(optimizer_state['optimizer']) model = model.cuda() # convert model to distributed mode if cfg.BATCHNORM.SYNC_BN: try: if USE_APEX: model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: logger.info('Converted model to use Synchronized BatchNorm.') except Exception as exception: logger.info( 'Failed to enable Synchronized BatchNorm. ' 'Install Apex or Torch >= 1.1 with Exception %s', exception) if USE_APEX: model = DDP(model, delay_allreduce=True) else: if args.local_rank == 0: logger.info( "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP." 
) # can use device str in Torch >= 1.1 model = DDP(model, device_ids=[args.local_rank]) # create learning rate scheduler lr_scheduler, num_epochs = create_supernet_scheduler(cfg, optimizer) start_epoch = resume_epoch if resume_epoch is not None else 0 if start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: logger.info('Scheduled epochs: %d', num_epochs) # imagenet train dataset train_dir = os.path.join(cfg.DATA_DIR, 'train') if not os.path.exists(train_dir): logger.info('Training folder does not exist at: %s', train_dir) sys.exit() dataset_train = Dataset(train_dir) loader_train = create_loader(dataset_train, input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE), batch_size=cfg.DATASET.BATCH_SIZE, is_training=True, use_prefetcher=True, re_prob=cfg.AUGMENTATION.RE_PROB, re_mode=cfg.AUGMENTATION.RE_MODE, color_jitter=cfg.AUGMENTATION.COLOR_JITTER, interpolation='random', num_workers=cfg.WORKERS, distributed=True, collate_fn=None, crop_pct=DEFAULT_CROP_PCT, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD) # imagenet validation dataset eval_dir = os.path.join(cfg.DATA_DIR, 'val') if not os.path.isdir(eval_dir): logger.info('Validation folder does not exist at: %s', eval_dir) sys.exit() dataset_eval = Dataset(eval_dir) loader_eval = create_loader(dataset_eval, input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE), batch_size=4 * cfg.DATASET.BATCH_SIZE, is_training=False, use_prefetcher=True, num_workers=cfg.WORKERS, distributed=True, crop_pct=DEFAULT_CROP_PCT, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, interpolation=cfg.DATASET.INTERPOLATION) # whether to use label smoothing if cfg.AUGMENTATION.SMOOTHING > 0.: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=cfg.AUGMENTATION.SMOOTHING).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = train_loss_fn # initialize training parameters eval_metric = cfg.EVAL_METRICS best_metric, best_epoch, saver, best_children_pool = None, None, None, [] if args.local_rank == 0: decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) # training scheme try: for epoch in range(start_epoch, num_epochs): loader_train.sampler.set_epoch(epoch) # train one epoch train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, prioritized_board, MetaMN, cfg, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, logger=logger, est=model_est, local_rank=args.local_rank) # evaluate one epoch eval_metrics = validate(model, loader_eval, validate_loss_fn, prioritized_board, MetaMN, cfg, local_rank=args.local_rank, logger=logger) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, cfg, epoch=epoch, metric=save_metric) except KeyboardInterrupt: pass
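# A bare-bones version of the `update_summary` call above: append one row of train/eval
# metrics per epoch to summary.csv, writing the header only once. This mirrors timm's
# helper but is a simplified stand-in, shown for illustration.
import csv
from collections import OrderedDict


def update_summary_csv(epoch, train_metrics, eval_metrics, filename, write_header=False):
    row = OrderedDict(epoch=epoch)
    row.update([('train_' + k, v) for k, v in train_metrics.items()])
    row.update([('eval_' + k, v) for k, v in eval_metrics.items()])
    with open(filename, mode='a') as f:
        writer = csv.DictWriter(f, fieldnames=row.keys())
        if write_header:
            writer.writeheader()
        writer.writerow(row)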
scheduler = CosineAnnealingScheduler(
    optimizer, 'lr',
    args.learning_rate, args.learning_rate / 1000,
    args.epochs * len(train_loader),
)
scheduler = create_lr_scheduler_with_warmup(scheduler, 0, args.learning_rate, 1000)
(teacher, student), optimizer = amp.initialize([teacher, student], optimizer, opt_level="O2")
if args.distributed:
    student = convert_syncbn_model(student)
    teacher = DistributedDataParallel(teacher)
    student = DistributedDataParallel(student)


def create_segmentation_distillation_trainer(student, teacher, optimizer,
                                             supervised_loss_fn, distillation_loss_fn,
                                             device, use_f16=True, non_blocking=True):
    from ignite.engine import Engine, Events, _prepare_batch
    from ignite.metrics import RunningAverage, Loss
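# A sketch of a distillation loss the (truncated) trainer factory above could plug in as
# `distillation_loss_fn`: the supervised loss on ground truth plus a temperature-scaled KL
# term between student and teacher logits. `temperature` and `alpha` are illustrative
# hyper-parameters, not values taken from this codebase.
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, target, supervised_loss_fn,
                      temperature=4.0, alpha=0.5):
    soft = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),
                    F.softmax(teacher_logits.detach() / temperature, dim=1),
                    reduction='batchmean') * temperature * temperature
    hard = supervised_loss_fn(student_logits, target)
    return alpha * soft + (1.0 - alpha) * hard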
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() _logger.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: _logger.info('Training with a single process on 1 GPUs.') assert args.rank >= 0 # resolve AMP arguments based on PyTorch / Apex availability use_amp = None if args.amp: # `--amp` chooses native amp before apex (APEX ver not actively maintained) if has_native_amp: args.native_amp = True elif has_apex: args.apex_amp = True if args.apex_amp and has_apex: use_amp = 'apex' elif args.native_amp and has_native_amp: use_amp = 'native' elif args.apex_amp or args.native_amp: _logger.warning( "Neither APEX or native Torch AMP is available, using float32. " "Install NVIDA apex or upgrade to PyTorch 1.6") random_seed(args.seed, args.rank) model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint) if args.num_classes is None: assert hasattr( model, 'num_classes' ), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly if args.local_rank == 0: _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}' ) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits # enable split bn (separate bn stats per batch-portion) if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) # move model to GPU, enable channels last layout if set model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: assert not args.split_bn if has_apex and use_amp != 'native': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' 
) if args.torchscript: assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() if args.local_rank == 0: _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() if args.local_rank == 0: _logger.info( 'Using native Torch AMP. Training in mixed precision.') else: if args.local_rank == 0: _logger.info('AMP not enabled. Training in float32.') # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( model, args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) # setup exponential moving average of model weights, SWA could be used here too model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEmaV2( model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) if args.resume: load_checkpoint(model_ema.module, args.resume, use_ema=True) # setup distributed training if args.distributed: if has_apex and use_amp != 'native': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") model = NativeDDP(model, device_ids=[ args.local_rank ]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) # create the train and eval datasets dataset_train = create_dataset(args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size, repeats=args.epoch_repeats) dataset_eval = create_dataset(args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size) # setup mixup / cutmix collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None if mixup_active: mixup_args = dict(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) else: mixup_fn = Mixup(**mixup_args) # wrap dataset in AugMix helper if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) # create data loaders w/ augmentation pipeiine train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, no_aug=args.no_aug, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) # setup loss function if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() elif mixup_active: # smoothing is handled with mixup target transform train_loss_fn = SoftTargetCrossEntropy().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing).cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: if args.experiment: exp_name = args.experiment else: exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), safe_model_name(args.model), str(data_config['input_size'][-1]) ]) output_dir = get_outdir( args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in range(start_epoch, num_epochs): if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): loader_train.sampler.set_epoch(epoch) train_metrics = train_one_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, 
model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( epoch, metric=save_metric) except KeyboardInterrupt: pass if best_metric is not None: _logger.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
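# A minimal sketch of the native-AMP path configured above: autocast the forward pass and
# scale the loss with GradScaler. This is the bare PyTorch equivalent of what timm's
# NativeScaler wraps, reduced to a single training step for illustration.
import torch


def amp_train_step(model, batch, target, loss_fn, optimizer, scaler):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(batch)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()


# usage: scaler = torch.cuda.amp.GradScaler(); call amp_train_step(...) once per batch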
def __init__(self, serialization_dir, params, model, loss, alphabet, local_rank=0, world_size=1, sync_bn=False, opt_level='O0', keep_batchnorm_fp32=None, loss_scale=1.0): self.alphabet = alphabet self.clip_grad_norm = params.get('clip_grad_norm', None) self.clip_grad_value = params.get('clip_grad_value', None) self.warmup_epochs = params.get('warmup_epochs', 0) self.label_smoothing = params.get('label_smoothing', 0.0) self.local_rank = local_rank self.loss = loss.cuda() self.model = model self.best_monitor = float('inf') self.monitor = params.get('monitor', 'loss') self.serialization_dir = serialization_dir self.distributed = world_size > 1 self.world_size = world_size self.epoch = 0 self.num_epochs = 0 self.start_epoch = 0 self.start_iteration = 0 self.start_time = 0 self.iterations_per_epoch = None self.time_since_last = time.time() self.save_every = params.get('save_every', 60 * 10) # 10 minutes if sync_bn: logger.info('Using Apex `sync_bn`') self.model = convert_syncbn_model(self.model) self.model = model.cuda() # Setup optimizer parameters = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad] self.optimizer = optimizers.from_params(params.pop("optimizer"), parameters, world_size=self.world_size) self.model, self.optimizer = amp.initialize( self.model, self.optimizer, opt_level=opt_level, keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale=loss_scale) # Setup lr scheduler scheduler_params = params.pop('lr_scheduler', None) self.lr_scheduler = None if scheduler_params: self.lr_scheduler = lr_schedulers.from_params( scheduler_params, self.optimizer) self.base_lrs = list( map(lambda group: group['initial_lr'], self.optimizer.param_groups)) # Setup metrics metrics_params = params.pop('metrics', []) if 'loss' not in metrics_params: metrics_params = ['loss'] + metrics_params # Initializing history self.metrics = {} for phase in ['train', 'val']: self.metrics[phase] = metrics.from_params(metrics_params, alphabet=alphabet) if self.distributed: self.model = DistributedDataParallel(self.model, delay_allreduce=True)
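# A short sketch of the epoch-based warmup implied by `warmup_epochs` and `base_lrs` in the
# trainer above: linearly scale every param group's learning rate from zero to its base
# value over the first `warmup_epochs` epochs, then let the regular scheduler take over.
# This is an assumed illustration of the pattern, not the trainer's exact implementation.
def apply_warmup(optimizer, base_lrs, epoch, warmup_epochs):
    if warmup_epochs <= 0 or epoch >= warmup_epochs:
        return
    scale = float(epoch + 1) / float(warmup_epochs)
    for group, base_lr in zip(optimizer.param_groups, base_lrs):
        group['lr'] = base_lr * scale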
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: logging.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logging.info('Training with a single process on 1 GPU.') torch.manual_seed(args.seed + args.rank) # create model config = get_efficientdet_config(args.model) config.redundant_bias = args.redundant_bias # redundant conv + BN bias layers (True to match official models) model = EfficientDet(config) model = DetBenchTrain(model, config) # FIXME create model factory, pretrained zoo # model = create_model( # args.model, # pretrained=args.pretrained, # num_classes=args.num_classes, # drop_rate=args.drop, # drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path # drop_path_rate=args.drop_path, # drop_block_rate=args.drop_block, # global_pool=args.gp, # bn_tf=args.bn_tf, # bn_momentum=args.bn_momentum, # bn_eps=args.bn_eps, # checkpoint_path=args.initial_checkpoint) if args.local_rank == 0: logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) model.cuda() optimizer = create_optimizer(args, model) use_amp = False if has_apex and args.amp: model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True if args.local_rank == 0: logging.info('NVIDIA APEX {}. AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) # optionally resume from a checkpoint resume_state = {} resume_epoch = None if args.resume: resume_state, resume_epoch = resume_checkpoint(_unwrap_bench(model), args.resume) if resume_state and not args.no_resume_opt: if 'optimizer' in resume_state: if args.local_rank == 0: logging.info('Restoring Optimizer state from checkpoint') optimizer.load_state_dict(resume_state['optimizer']) if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if args.local_rank == 0: logging.info('Restoring NVIDIA AMP state from checkpoint') amp.load_state_dict(resume_state['amp']) del resume_state model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '') #resume=args.resume) # FIXME bit of a mess with bench if args.resume: load_checkpoint(_unwrap_bench(model_ema), args.resume, use_ema=True) if args.distributed: if args.sync_bn: try: if has_apex: model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( model) if args.local_rank == 0: logging.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' ) except Exception as e: logging.error( 'Failed to enable Synchronized BatchNorm. 
Install Apex or Torch >= 1.1' ) if has_apex: model = DDP(model, delay_allreduce=True) else: if args.local_rank == 0: logging.info( "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP." ) model = DDP(model, device_ids=[args.local_rank ]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) train_anno_set = 'train2017' train_annotation_path = os.path.join(args.data, 'annotations', f'instances_{train_anno_set}.json') train_image_dir = train_anno_set dataset_train = CocoDetection(os.path.join(args.data, train_image_dir), train_annotation_path) # FIXME cutmix/mixup worth investigating? # collate_fn = None # if args.prefetcher and args.mixup > 0: # collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) loader_train = create_loader( dataset_train, input_size=config.image_size, batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, #re_prob=args.reprob, # FIXME add back various augmentations #re_mode=args.remode, #re_count=args.recount, #re_split=args.resplit, #color_jitter=args.color_jitter, #auto_augment=args.aa, interpolation=args.train_interpolation, #mean=data_config['mean'], #std=data_config['std'], num_workers=args.workers, distributed=args.distributed, #collate_fn=collate_fn, pin_mem=args.pin_mem, ) train_anno_set = 'val2017' train_annotation_path = os.path.join(args.data, 'annotations', f'instances_{train_anno_set}.json') train_image_dir = train_anno_set dataset_eval = CocoDetection(os.path.join(args.data, train_image_dir), train_annotation_path) loader_eval = create_loader( dataset_eval, input_size=config.image_size, batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=args.interpolation, #mean=data_config['mean'], #std=data_config['std'], num_workers=args.workers, #distributed=args.distributed, pin_mem=args.pin_mem, ) eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join( [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in range(start_epoch, num_epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, model, loader_train, optimizer, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp, model_ema=model_ema) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: logging.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, args) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(model_ema.ema, loader_eval, args, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( _unwrap_bench(model), optimizer, args, epoch=epoch, model_ema=_unwrap_bench(model_ema), metric=save_metric, use_amp=use_amp) except KeyboardInterrupt: pass if best_metric is not None: logging.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
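Both training loops call timm's distribute_bn after each epoch when --dist-bn is set. The following rough, standalone sketch shows what the 'reduce' and 'broadcast' modes do to BatchNorm buffers, assuming a process group is initialized and every rank holds an identical module tree; it mirrors the idea, not timm's exact code.

import torch
import torch.distributed as dist
import torch.nn as nn

def sync_bn_stats(model, world_size, reduce=False):
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            for buf in (module.running_mean, module.running_var):
                if reduce:
                    # average the running statistics over all ranks
                    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
                    buf /= world_size
                else:
                    # copy rank 0's statistics to every other rank
                    dist.broadcast(buf, src=0)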
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: _logger.warning( 'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.' ) args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: _logger.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: _logger.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) # ! build model if 'inception' in args.model: model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, aux_logits=True, # ! add aux loss global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) else: model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) # ! add more layer to classifier layer if args.create_classifier_layerfc: model.global_pool, model.classifier = create_classifier_layerfc( model.num_features, model.num_classes) if args.local_rank == 0: _logger.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) use_amp = None if args.amp: # for backwards compat, `--amp` arg tries apex before native amp if has_apex: args.apex_amp = True elif has_native_amp: args.native_amp = True if args.apex_amp and has_apex: use_amp = 'apex' elif args.native_amp and has_native_amp: use_amp = 'native' elif args.apex_amp or args.native_amp: _logger.warning( "Neither APEX or native Torch AMP is available, using float32. " "Install NVIDA apex or upgrade to PyTorch 1.6") if args.num_gpu > 1: if use_amp == 'apex': _logger.warning( 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.' ) use_amp = None model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() assert not args.channels_last, "Channels last not supported with DP, use DDP." else: model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) optimizer = create_optimizer(args, model) # ! optimizer if args.classification_layer_name is not None: # ! add our own classification layer... doesn't really improve very much though... 
args.classification_layer_name = args.classification_layer_name.strip( ).split() print('classification_layer_name {}'.format( args.classification_layer_name)) optimizer = create_optimizer( args, model, filter_bias_and_bn=args.filter_bias_and_bn, classification_layer_name=args.classification_layer_name) amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() if args.local_rank == 0: _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() if args.local_rank == 0: _logger.info( 'Using native Torch AMP. Training in mixed precision.') else: if args.local_rank == 0: _logger.info('AMP not enabled. Training in float32.') if args.num_gpu > 1: # ! these lines used to be above @optimizer, but we move it here, so that we override apex warning. if args.amp: _logger.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.' ) _logger.warning('... we will ignore this ... .') # args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: model.cuda() # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( model, args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume=args.resume) if args.distributed: if args.sync_bn: assert not args.split_bn try: if has_apex and use_amp != 'native': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' ) except Exception as e: _logger.error( 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1' ) if has_apex and use_amp != 'native': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") model = NativeDDP(model, device_ids=[ args.local_rank ]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): _logger.error( 'Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir, args=args) collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None if mixup_active: mixup_args = dict(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) else: mixup_fn = Mixup(**mixup_args) if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, no_aug=args.no_aug, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, # ! see file auto_augment.py num_aug_splits=num_aug_splits, interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader, args=args) eval_dir = os.path.join(args.data, 'val') if not os.path.isdir(eval_dir): eval_dir = os.path.join(args.data, 'validation') if not os.path.isdir(eval_dir): _logger.error( 'Validation folder does not exist at: {}'.format(eval_dir)) exit(1) dataset_eval = Dataset(eval_dir, args=args) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, # ! so we can eval faster is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, args=args) # add weighted loss for each label class if args.weighted_cross_entropy is None: weighted_cross_entropy = torch.ones(args.num_classes) else: weighted_cross_entropy = args.weighted_cross_entropy.strip().split() weighted_cross_entropy = torch.FloatTensor( [float(w) for w in weighted_cross_entropy]) # string if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() elif mixup_active: # smoothing is handled with mixup target transform train_loss_fn = SoftTargetCrossEntropy().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing).cuda() else: train_loss_fn = nn.CrossEntropyLoss( weight=weighted_cross_entropy).cuda() if args.weighted_cross_entropy_eval: validate_loss_fn = nn.CrossEntropyLoss( weight=weighted_cross_entropy).cuda() else: validate_loss_fn = nn.CrossEntropyLoss().cuda( ) # ! eval as usual ? 
weight=weighted_cross_entropy eval_metric = args.eval_metric best_metric = None best_epoch = 0 saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in range(start_epoch, num_epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( epoch, metric=save_metric) # early stop if epoch - best_epoch > args.early_stop_counter: _logger.info( '*** Best metric: {0} (epoch {1}) (current epoch {2}'. format(best_metric, best_epoch, epoch)) break # ! exit except KeyboardInterrupt: pass if best_metric is not None: _logger.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
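This variant adds per-class weights to the cross-entropy losses, parsed from a space-separated string argument. A minimal sketch of that parsing and loss construction; the helper name and example weights are illustrative, not taken from the script.

import torch
import torch.nn as nn

def build_weighted_ce(weight_str, num_classes):
    if weight_str is None:
        weights = torch.ones(num_classes)
    else:
        weights = torch.tensor([float(w) for w in weight_str.split()])
        assert weights.numel() == num_classes, 'expected one weight per class'
    return nn.CrossEntropyLoss(weight=weights)

loss_fn = build_weighted_ce('1.0 2.5 0.5', num_classes=3)
print(loss_fn(torch.randn(4, 3), torch.randint(0, 3, (4,))))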
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: logging.warning( 'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.' ) args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: logging.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logging.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) if args.local_rank == 0: logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) if args.num_gpu > 1: if args.amp: logging.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.' ) args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: model.cuda() optimizer = create_optimizer(args, model) use_amp = False if has_apex and args.amp: model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True if args.local_rank == 0: logging.info('NVIDIA APEX {}. 
AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) # optionally resume from a checkpoint resume_state = {} resume_epoch = None if args.resume: resume_state, resume_epoch = resume_checkpoint(model, args.resume) if resume_state and not args.no_resume_opt: if 'optimizer' in resume_state: if args.local_rank == 0: logging.info('Restoring Optimizer state from checkpoint') optimizer.load_state_dict(resume_state['optimizer']) if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if args.local_rank == 0: logging.info('Restoring NVIDIA AMP state from checkpoint') amp.load_state_dict(resume_state['amp']) del resume_state model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume=args.resume) if args.distributed: if args.sync_bn: assert not args.split_bn try: if has_apex: model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( model) if args.local_rank == 0: logging.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' ) except Exception as e: logging.error( 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1' ) if has_apex: model = DDP(model, delay_allreduce=True) else: if args.local_rank == 0: logging.info( "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP." ) model = DDP(model, device_ids=[args.local_rank ]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): logging.error( 'Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) collate_fn = None if args.prefetcher and args.mixup > 0: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, interpolation=args.train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader) eval_dir = os.path.join(args.data, 'val') if not os.path.isdir(eval_dir): eval_dir = os.path.join(args.data, 'validation') if not os.path.isdir(eval_dir): logging.error( 'Validation folder does not exist at: {}'.format(eval_dir)) exit(1) dataset_eval = 
Dataset(eval_dir) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.mixup > 0.: # smoothing is handled with mixup label transform train_loss_fn = SoftTargetCrossEntropy().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = train_loss_fn eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in range(start_epoch, num_epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp, model_ema=model_ema) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: logging.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp) except KeyboardInterrupt: pass if best_metric is not None: logging.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
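When neither JSD nor mixup is active but --smoothing is set, this script trains with timm's LabelSmoothingCrossEntropy. Below is a plain-PyTorch sketch of the standard label-smoothing formulation it corresponds to, written from the textbook definition rather than copied from timm.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingCE(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        logp = F.log_softmax(logits, dim=-1)
        nll = -logp.gather(1, target.unsqueeze(1)).squeeze(1)  # true-class term
        smooth = -logp.mean(dim=-1)                            # uniform term
        return ((1.0 - self.smoothing) * nll + self.smoothing * smooth).mean()

loss = LabelSmoothingCE(0.1)(torch.randn(4, 10), torch.randint(0, 10, (4,)))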
def set_syncbn(net):
    if has_apex:
        net = parallel.convert_syncbn_model(net)
    else:
        net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    return net
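Usage sketch for the helper above; the conversion itself works on any module tree, but SyncBatchNorm layers need an initialized process group once the model runs in training mode.

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
net = set_syncbn(net)  # BatchNorm2d modules are swapped for sync variants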
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: logging.warning( 'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.' ) args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 DistributedManager.set_args(args) if args.distributed: logging.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logging.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) model = create_model(args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) if args.initial_checkpoint_pruned: try: data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) model2 = load_module_from_ckpt( model, args.initial_checkpoint_pruned, input_size=data_config['input_size'][1]) logging.info("New pruned model adapted from the checkpoint") except Exception as e: raise RuntimeError(e) else: model2 = model if args.local_rank == 0: logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model2.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) if args.num_gpu > 1: model2 = nn.DataParallel(model2, device_ids=list(range(args.num_gpu))).cuda() else: model2.cuda() use_amp = False if args.distributed: model2 = nn.parallel.distributed.DistributedDataParallel( model2, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): logging.error( 'Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) collate_fn = None if args.prefetcher and args.mixup > 0: collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, interpolation=args.train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, ) eval_dir = os.path.join(args.data, 'val') if not 
os.path.isdir(eval_dir): eval_dir = os.path.join(args.data, 'validation') if not os.path.isdir(eval_dir): eval_dir = os.path.join(args.data, 'test') if not os.path.isdir(eval_dir): logging.error( 'Validation folder does not exist at: {}'.format(eval_dir)) exit(1) test_dir = os.path.join(args.data, 'test') if not os.path.isdir(test_dir): test_dir = os.path.join(args.data, 'validation') if not os.path.isdir(test_dir): test_dir = os.path.join(args.data, 'val') if not os.path.isdir(test_dir): logging.error( 'Test folder does not exist at: {}'.format(test_dir)) exit(1) dataset_eval = Dataset(eval_dir) if args.prune_test: dataset_test = Dataset(test_dir) else: dataset_test = Dataset(train_dir) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) len_loader = int( len(loader_eval) * (4 * args.batch_size) / args.batch_size_prune) if args.prune_test: len_loader = None if args.prune: loader_p = create_loader( dataset_test, input_size=data_config['input_size'], batch_size=args.batch_size_prune, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) if 'resnet' in model2.__class__.__name__.lower() or ( hasattr(model2, 'module') and 'resnet' in model2.module.__class__.__name__.lower()): list_channel_to_prune = compute_num_channels_per_layer_taylor( model2, data_config['input_size'], loader_p, pruning_ratio=args.pruning_ratio, taylor_file=args.taylor_file, local_rank=args.local_rank, len_data_loader=len_loader, prune_skip=args.prune_skip, taylor_abs=args.taylor_abs, prune_conv1=args.prune_conv1, use_time=args.use_time, distributed=args.distributed) new_net = redesign_module_resnet( model2, list_channel_to_prune, use_amp=use_amp, distributed=args.distributed, local_rank=args.local_rank, input_size=data_config['input_size'][1]) else: list_channel_to_prune = compute_num_channels_per_layer_taylor( model2, data_config['input_size'], loader_p, pruning_ratio=args.pruning_ratio, taylor_file=args.taylor_file, local_rank=args.local_rank, len_data_loader=len_loader, prune_pwl=not args.no_pwl, taylor_abs=args.taylor_abs, use_se=not args.use_eca, use_time=args.use_time, distributed=args.distributed) new_net = redesign_module_efnet( model2, list_channel_to_prune, use_amp=use_amp, distributed=args.distributed, local_rank=args.local_rank, input_size=data_config['input_size'][1], use_se=not args.use_eca) new_net.train() model.train() if isinstance(model, nn.DataParallel) or isinstance(model, DDP): model = model.module else: model = model.cuda() co_mod = build_co_train_model( model, new_net.module.cpu() if hasattr(new_net, 'module') else new_net, gamma=args.gamma_knowledge, only_last=args.only_last, progressive_IKD_factor=args.progressive_IKD_factor) optimizer = create_optimizer(args, co_mod) del model del new_net gc.collect() torch.cuda.empty_cache() if args.num_gpu > 1: if args.amp: logging.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.' 
) args.amp = False co_mod = nn.DataParallel(co_mod, device_ids=list(range( args.num_gpu))).cuda() else: co_mod = co_mod.cuda() use_amp = False if has_apex and args.amp: co_mod, optimizer = amp.initialize(co_mod, optimizer, opt_level='O1') use_amp = True if args.local_rank == 0: logging.info('NVIDIA APEX {}. AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) if args.distributed: if args.sync_bn: try: if has_apex and use_amp: co_mod = convert_syncbn_model(co_mod) else: co_mod = torch.nn.SyncBatchNorm.convert_sync_batchnorm( co_mod) if args.local_rank == 0: logging.info( 'Converted model to use Synchronized BatchNorm.') except Exception as e: logging.error( 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1' ) if has_apex and use_amp: co_mod = DDP(co_mod, delay_allreduce=False) else: if args.local_rank == 0 and use_amp: logging.info( "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP." ) co_mod = nn.parallel.distributed.DistributedDataParallel( co_mod, device_ids=[args.local_rank ]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP co_mod.train() lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.mixup > 0.: # smoothing is handled with mixup label transform train_loss_fn = SoftTargetCrossEntropy().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = train_loss_fn eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: if args.local_rank == 0: logging.info(f'First validation') co_mod.eval() eval_metrics = validate(co_mod, loader_eval, validate_loss_fn, args) if args.local_rank == 0: logging.info(f'Prec@top1 : {eval_metrics["prec1"]}') co_mod.train() for epoch in range(start_epoch, num_epochs): torch.cuda.empty_cache() if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, co_mod, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp, model_ema=None) torch.cuda.empty_cache() eval_metrics = validate(co_mod, loader_eval, validate_loss_fn, args) if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save 
proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( co_mod, optimizer, args, epoch=epoch, model_ema=None, metric=save_metric, use_amp=use_amp) except KeyboardInterrupt: pass if best_metric is not None: logging.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
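The pruning path above wraps the original network and the redesigned one into a co-training model weighted by gamma_knowledge. As a point of reference, here is a generic knowledge-distillation loss sketch in that spirit; the temperature and weighting are illustrative and not the repository's exact recipe.

import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, target, gamma=0.5, T=4.0):
    # Soft term: match the teacher's temperature-scaled distribution.
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction='batchmean') * (T * T)
    # Hard term: ordinary cross entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, target)
    return gamma * soft + (1.0 - gamma) * hard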
def main(): setup_default_logging() args, args_text = _parse_args() args.pretrained_backbone = not args.no_pretrained_backbone args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.device = 'cuda:' + str(args.cuda_device) args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: logging.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logging.info('Training with a single process on 1 GPU.') use_amp = None if args.amp: # for backwards compat, `--amp` arg tries apex before native amp if has_apex: args.apex_amp = True elif has_native_amp: args.native_amp = True else: logging.warning( "Neither APEX or native Torch AMP is available, using float32. " "Install NVIDA apex or upgrade to PyTorch 1.6.") if args.apex_amp: if has_apex: use_amp = 'apex' else: logging.warning( "APEX AMP not available, using float32. Install NVIDA apex") elif args.native_amp: if has_native_amp: use_amp = 'native' else: logging.warning( "Native AMP not available, using float32. Upgrade to PyTorch 1.6." ) torch.manual_seed(args.seed + args.rank) with set_layer_config(scriptable=args.torchscript): model = create_model( args.model, bench_task='train', num_classes=args.num_classes, pretrained=args.pretrained, pretrained_backbone=args.pretrained_backbone, redundant_bias=args.redundant_bias, label_smoothing=args.smoothing, legacy_focal=args.legacy_focal, jit_loss=args.jit_loss, soft_nms=args.soft_nms, bench_labeler=args.bench_labeler, checkpoint_path=args.initial_checkpoint, ) model_config = model.config # grab before we obscure with DP/DDP wrappers if args.local_rank == 0: logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) if args.distributed and args.sync_bn: if has_apex and use_amp != 'native': model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: logging.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' ) if args.torchscript: assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model, force native amp with `--native-amp` flag' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model. Use `--dist-bn reduce` instead of `--sync-bn`' model = torch.jit.script(model) optimizer = create_optimizer(args, model) amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() if args.local_rank == 0: logging.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() if args.local_rank == 0: logging.info( 'Using native Torch AMP. Training in mixed precision.') else: if args.local_rank == 0: logging.info('AMP not enabled. 
Training in float32.') # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( unwrap_bench(model), args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEmaV2(model, decay=args.model_ema_decay) if args.resume: load_checkpoint(unwrap_bench(model_ema), args.resume, use_ema=True) if args.distributed: if has_apex and use_amp != 'native': if args.local_rank == 0: logging.info("Using apex DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: logging.info("Using torch DistributedDataParallel.") model = NativeDDP(model, device_ids=[args.device]) # NOTE: EMA model does not need to be wrapped by DDP... if model_ema is not None and not args.resume: # ...but it is a good idea to sync EMA copy of weights # NOTE: ModelEma init could be moved after DDP wrapper if using PyTorch DDP, not Apex. model_ema.set(model) lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) loader_train, loader_eval, evaluator = create_datasets_and_loaders( args, model_config) if model_config.num_classes < loader_train.dataset.parser.max_label: logging.error( f'Model {model_config.num_classes} has fewer classes than dataset {loader_train.dataset.parser.max_label}.' ) exit(1) if model_config.num_classes > loader_train.dataset.parser.max_label: logging.warning( f'Model {model_config.num_classes} has more classes than dataset {loader_train.dataset.parser.max_label}.' 
) eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join( [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(model, optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, decreasing=decreasing, unwrap_fn=unwrap_bench) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in range(start_epoch, num_epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, model, loader_train, optimizer, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: logging.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') # the overhead of evaluating with coco style datasets is fairly high, so just ema or non, not both if model_ema is not None: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model_ema.module, loader_eval, args, evaluator, log_suffix=' (EMA)') else: eval_metrics = validate(model, loader_eval, args, evaluator) if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) if saver is not None: update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) # save proper checkpoint with eval metric best_metric, best_epoch = saver.save_checkpoint( epoch=epoch, metric=eval_metrics[eval_metric]) except KeyboardInterrupt: pass if best_metric is not None: logging.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
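This script keeps a ModelEmaV2 copy of the weights and, when EMA is enabled, validates that copy instead of the raw model. A minimal sketch of the weight-EMA idea follows, assuming the shadow copy lives on the same device; the decay value is illustrative.

import copy
import torch

class SimpleEma:
    def __init__(self, model, decay=0.9998):
        self.ema = copy.deepcopy(model).eval()
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # Exponential moving average of floating-point tensors; copy the rest.
        for ema_t, t in zip(self.ema.state_dict().values(),
                            model.state_dict().values()):
            if ema_t.dtype.is_floating_point:
                ema_t.mul_(self.decay).add_(t.detach(), alpha=1.0 - self.decay)
            else:
                ema_t.copy_(t)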
def main(): import os args, args_text = _parse_args() eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = 'train' if args.gate_train: exp_name += '-dynamic' if args.slim_train: exp_name += '-slimmable' exp_name += '-{}'.format(args.model) exp_info = '-'.join( [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model]) output_dir = get_outdir(output_base, exp_name, exp_info) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) setup_default_logging(outdir=output_dir, local_rank=args.local_rank) torch.backends.cudnn.benchmark = True args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: logging.warning( 'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.' ) args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) # torch.distributed.init_process_group(backend='nccl', # init_method='tcp://127.0.0.1:23334', # rank=args.local_rank, # world_size=int(os.environ['WORLD_SIZE'])) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: logging.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logging.info('Training with a single process on %d GPUs.' % args.num_gpu) # --------- random seed ----------- random.seed(args.seed) # TODO: do we need same seed on all GPU? np.random.seed(args.seed) torch.manual_seed(args.seed) # torch.manual_seed(args.seed + args.rank) model = create_model(args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) if args.local_rank == 0: logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) if args.num_gpu > 1: if args.amp: logging.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.' 
        )
        args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    if args.train_mode == 'gate':
        optimizer = create_optimizer(args, model.get_gate())
    else:
        optimizer = create_optimizer(args, model)

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            checkpoint_path=args.resume,
            optimizer=optimizer if not args.no_resume_opt else None,
            log_info=args.local_rank == 0,
            strict=False)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume,
                             log_info=args.local_rank == 0,
                             resume_strict=False)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            # can use device str in Torch >= 1.1
            model = DDP(model,
                        device_ids=[args.local_rank],
                        find_unused_parameters=True)
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # ------------- data --------------
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )

    loader_bn = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # ------------- loss_fn --------------
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
    if args.ieb:
        distill_loss_fn = SoftTargetCrossEntropy().cuda()
    else:
        distill_loss_fn = None

    if args.local_rank == 0:
        model_profiling(model, 224, 224, 1, 3, use_cuda=True, verbose=True)
    else:
        model_profiling(model, 224, 224, 1, 3, use_cuda=True, verbose=False)

    if not args.test_mode:
        # start training
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)
            train_metrics = OrderedDict([('loss', 0.)])

            # train
            if args.gate_train:
                train_metrics = train_epoch_slim_gate(
                    epoch, model, loader_train, optimizer, train_loss_fn, args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step)
            else:
                train_metrics = train_epoch_slim(
                    epoch, model, loader_train, optimizer,
                    loss_fn=train_loss_fn,
                    distill_loss_fn=distill_loss_fn,
                    args=args,
                    lr_scheduler=lr_scheduler,
                    saver=saver,
                    output_dir=output_dir,
                    use_amp=use_amp,
                    model_ema=model_ema,
                    optimizer_step=args.optimizer_step,
                )
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                torch.cuda.synchronize()
                if args.local_rank == 0:
                    logging.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            # eval
            if args.gate_train:
                eval_metrics = [
                    validate_gate(model, loader_eval, validate_loss_fn, args)
                ]
                if model_ema is not None and not args.model_ema_force_cpu:
                    ema_eval_metrics = [
                        validate_gate(model_ema.ema, loader_eval,
                                      validate_loss_fn, args,
                                      log_suffix='(EMA)')
                    ]
                    eval_metrics = ema_eval_metrics
            else:
                if epoch % 10 == 0 and epoch != 0:
                    eval_sample_list = ['smallest', 'largest', 'uniform']
                else:
                    eval_sample_list = ['smallest', 'largest']
                eval_metrics = [
                    validate_slim(model, loader_eval, validate_loss_fn, args,
                                  model_mode=model_mode)
                    for model_mode in eval_sample_list
                ]
                if model_ema is not None and not args.model_ema_force_cpu:
                    ema_eval_metrics = [
                        validate_slim(model_ema.ema, loader_eval,
                                      validate_loss_fn, args,
                                      log_suffix='(EMA)',
                                      model_mode=model_mode)
                        for model_mode in eval_sample_list
                    ]
                    eval_metrics = ema_eval_metrics
            if isinstance(eval_metrics, list):
                eval_metrics = eval_metrics[0]

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # save
            update_summary(epoch, train_metrics, eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)
            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)

        # end training
        if best_metric is not None:
            logging.info('*** Best metric: {0} (epoch {1})'.format(
                best_metric, best_epoch))

    # test
    eval_metrics = []

    # reset bn
    if args.reset_bn:
        if args.local_rank == 0:
            logging.info("Recalibrating BatchNorm statistics...")
        if model_ema is not None and not args.model_ema_force_cpu:
            model_list = [model, model_ema.ema]
        else:
            model_list = [model]
        for model_ in model_list:
            for layer in model_.modules():
                if isinstance(layer, nn.BatchNorm2d) or \
                        isinstance(layer, nn.SyncBatchNorm) or \
                        (has_apex and isinstance(layer, apex.parallel.SyncBatchNorm)):
                    layer.reset_running_stats()
            model_.train()
            with torch.no_grad():
                for batch_idx, (input, target) in enumerate(loader_bn):
                    for choice in range(args.num_choice):
                        if args.slim_train:
                            if hasattr(model_, 'module'):
                                model_.module.set_mode('uniform', choice=choice)
                            else:
                                model_.set_mode('uniform', choice=choice)
                        model_(input)
                    if batch_idx % 1000 == 0 and batch_idx != 0:
                        break
            if args.local_rank == 0:
                logging.info("Finish recalibrating BatchNorm statistics.")
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                torch.cuda.synchronize()
                if args.local_rank == 0:
                    logging.info("Distributing BatchNorm running means and vars")
                distribute_bn(model_, args.world_size, args.dist_bn == 'reduce')

    # dynamic
    if args.gate_train:
        eval_metrics = [
            validate_gate(model, loader_eval, validate_loss_fn, args)
        ]
        if model_ema is not None and not args.model_ema_force_cpu:
            ema_eval_metrics = [
                validate_gate(model_ema.ema, loader_eval, validate_loss_fn,
                              args, log_suffix='(EMA)')
            ]
            eval_metrics = ema_eval_metrics
    # supernet
    for choice in range(args.num_choice):
        eval_metrics.append(
            validate_slim(model, loader_eval, validate_loss_fn, args,
                          model_mode=choice))
    if model_ema is not None and not args.model_ema_force_cpu:
        for choice in range(args.num_choice):
            eval_metrics.append(
                validate_slim(model_ema.ema, loader_eval, validate_loss_fn,
                              args, log_suffix='(EMA)', model_mode=choice))
    if args.local_rank == 0:
        best_metric, best_epoch = saver.save_checkpoint(model, optimizer, args,
                                                        epoch=0,
                                                        model_ema=model_ema,
                                                        metric=eval_metrics[0],
                                                        use_amp=use_amp)
    if args.local_rank == 0:
        print('Test results of the last epoch:\n', eval_metrics)
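# The test-time branch above recalibrates BatchNorm running statistics before
# evaluation: it resets each BN layer's running mean/var and then forwards a
# limited number of training batches in train() mode under torch.no_grad(), so
# only the BN buffers are updated. The helper below is a minimal,
# self-contained sketch of that recipe for a generic model; `recalibrate_bn`
# and `num_batches` are illustrative names, not part of the script above.
import torch
import torch.nn as nn


def recalibrate_bn(model, loader, num_batches=100, device='cuda'):
    """Reset and re-estimate BatchNorm running statistics (illustrative sketch)."""
    for layer in model.modules():
        if isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)):
            layer.reset_running_stats()
    model.train()  # BN only updates its running buffers in train mode
    with torch.no_grad():  # forward passes only, no gradients or optimizer step
        for batch_idx, (images, _) in enumerate(loader):
            model(images.to(device))
            if batch_idx + 1 >= num_batches:
                break
    return model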
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    if args.distributed and args.num_gpu > 1:
        logging.warning(
            'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
        args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' % args.num_gpu)

    torch.manual_seed(args.seed + args.rank)

    net_config = json.load(open(args.model_config))
    model = NSGANetV2.build_from_config(net_config,
                                        drop_connect_rate=args.drop_path)
    init = torch.load(args.initial_checkpoint, map_location='cpu')['state_dict']
    model.load_state_dict(init)

    if args.reset_classifier:
        NSGANetV2.reset_classifier(model,
                                   last_channel=model.classifier.in_features,
                                   n_classes=args.num_classes,
                                   dropout_rate=args.drop)

    # create a dummy model to get the model cfg just for running this timm training script
    dummy_model = create_model('efficientnet_b0')

    # add a teacher model
    teacher = None  # keep `teacher` defined even when no teacher is requested
    if args.teacher:
        # using the supernet at full scale to supervise the training of the subnets
        # this is taken from https://github.com/mit-han-lab/once-for-all/blob/
        # 4decacf9d85dbc948a902c6dc34ea032d711e0a9/train_ofa_net.py#L125
        from evaluator import OFAEvaluator
        supernet = OFAEvaluator(n_classes=1000, model_path=args.teacher)
        teacher, _ = supernet.sample({
            'ks': [7] * 20,
            'e': [6] * 20,
            'd': [4] * 5
        })
        # as alternative, you could simply use a pretrained model from timm, e.g. "tf_efficientnet_b1",
        # "mobilenetv3_large_100", etc. See https://rwightman.github.io/pytorch-image-models/results/ for full list.
        # import timm
        # teacher = timm.create_model(args.teacher, pretrained=True)
        args.kd_ratio = 1.0
        print("teacher model loaded")
    else:
        args.kd_ratio = 0.0

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=dummy_model,
                                      verbose=args.local_rank == 0)
    if args.img_size is not None:
        # override the input image resolution
        data_config['input_size'] = (3, args.img_size, args.img_size)

    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
        if args.teacher:
            teacher = nn.DataParallel(teacher,
                                      device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()
        if args.teacher:
            teacher.cuda()

    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            assert not args.split_bn
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
        if has_apex:
            model = DDP(model, delay_allreduce=True)
            if args.teacher:
                teacher = DDP(teacher, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            # can use device str in Torch >= 1.1
            model = DDP(model, device_ids=[args.local_rank])
            if args.teacher:
                teacher = DDP(teacher, device_ids=[args.local_rank])
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        logging.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=args.train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits,
                                        smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch, model, loader_train, optimizer,
                                        train_loss_fn, args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema,
                                        teacher=teacher)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.ema, loader_eval,
                                            validate_loss_fn, args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch, train_metrics, eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
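# train_epoch() above receives the optional `teacher` together with
# args.kd_ratio (1.0 when a teacher is loaded, 0.0 otherwise). The actual loss
# is defined elsewhere in the repository; the snippet below is only a hedged
# sketch of the once-for-all-style knowledge-distillation step the comments
# point at (https://github.com/mit-han-lab/once-for-all): soft targets from
# the frozen teacher are combined with the hard-label loss. `kd_loss_step` and
# its arguments are illustrative names, and the exact weighting used by the
# referenced code may differ from the mix shown here.
import torch
import torch.nn.functional as F


def kd_loss_step(student_logits, targets, teacher, images, kd_ratio, hard_loss_fn):
    """Combine the hard-label loss with a soft-target KD term (illustrative sketch)."""
    hard_loss = hard_loss_fn(student_logits, targets)
    if teacher is None or kd_ratio == 0.0:
        return hard_loss
    with torch.no_grad():  # the teacher is frozen; no gradients flow through it
        soft_labels = F.softmax(teacher(images), dim=1)
    # cross-entropy of the student against the teacher's soft distribution
    soft_loss = torch.mean(
        torch.sum(-soft_labels * F.log_softmax(student_logits, dim=1), dim=1))
    # one common way to mix the two terms; kd_ratio = 1.0 uses only soft targets
    return kd_ratio * soft_loss + (1.0 - kd_ratio) * hard_loss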