def _init_loss_scaler(self):
    # setup automatic mixed-precision (AMP) loss scaling and op casting
    self._amp_autocast = suppress  # do nothing
    self._loss_scaler = None
    if self.use_amp == 'apex':
        self.net, self._optimizer = amp.initialize(self.net, self._optimizer, opt_level='O1')
        self._loss_scaler = ApexScaler()
        self._logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif self.use_amp == 'native':
        self._amp_autocast = torch.cuda.amp.autocast
        self._loss_scaler = NativeScaler()
        self._logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        self._logger.info('AMP not enabled. Training in float32.')
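
# For context: the autocast context and loss scaler configured above are
# consumed together in the training step. A minimal, self-contained sketch of
# that pattern (assumes timm-style scaler semantics, where calling the scaler
# runs backward + optimizer.step(); `train_step` is illustrative, not part of
# this class's API):
def train_step(net, optimizer, loss_fn, inputs, targets,
               amp_autocast=suppress, loss_scaler=None):
    with amp_autocast():
        # eligible ops run in reduced precision under autocast
        loss = loss_fn(net(inputs), targets)
    optimizer.zero_grad()
    if loss_scaler is not None:
        # the scaler scales the loss, unscales grads, and steps the optimizer
        loss_scaler(loss, optimizer)
    else:
        loss.backward()
        optimizer.step()
    return loss.detach()
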
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. '
                     'Process %d, total %d.' % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
                        "Install NVIDIA apex or upgrade to PyTorch 1.6")

    torch.manual_seed(args.seed + args.rank)

    ####################################################################################
    # Start - SparseML optional load weights from SparseZoo
    ####################################################################################
    if args.initial_checkpoint == "zoo":
        # Load checkpoint from base weights associated with given SparseZoo recipe
        if args.sparseml_recipe.startswith("zoo:"):
            args.initial_checkpoint = Zoo.download_recipe_base_framework_files(
                args.sparseml_recipe, extensions=[".pth.tar", ".pth"]
            )[0]
        else:
            raise ValueError(
                "Attempting to load weights from SparseZoo recipe, but not given a "
                "SparseZoo recipe stub. When initial-checkpoint is set to 'zoo', "
                "sparseml-recipe must start with 'zoo:' and be a SparseZoo model "
                f"stub. sparseml-recipe was set to {args.sparseml_recipe}"
            )
    elif args.initial_checkpoint.startswith("zoo:"):
        # Load weights from a SparseZoo model stub
        zoo_model = Zoo.load_model_from_stub(args.initial_checkpoint)
        args.initial_checkpoint = zoo_model.download_framework_files(extensions=[".pth"])
    ####################################################################################
    # End - SparseML optional load weights from SparseZoo
    ####################################################################################

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
        batch_size=args.batch_size)
    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
        batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing,
            max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    ####################################################################################
    # Start SparseML Integration
    ####################################################################################
    sparseml_loggers = (
        [PythonLogger(), TensorBoardLogger(log_path=output_dir)]
        if output_dir
        else None
    )
    manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe)
    optimizer = ScheduledOptimizer(
        optimizer,
        model,
        manager,
        steps_per_epoch=len(loader_train),
        loggers=sparseml_loggers
    )
    # override lr scheduler if recipe makes any LR updates
    if any("LearningRate" in str(modifier) for modifier in manager.modifiers):
        _logger.info("Disabling timm LR scheduler, managing LR using SparseML recipe")
        lr_scheduler = None
    if manager.max_epochs:
        _logger.info(
            f"Overriding max_epochs to {manager.max_epochs} from SparseML recipe"
        )
    num_epochs = manager.max_epochs or num_epochs
    ####################################################################################
    # End SparseML Integration
    ####################################################################################

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler,
                model_ema=model_ema, mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args,
                    amp_autocast=amp_autocast, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

        #################################################################################
        # Start SparseML ONNX Export
        #################################################################################
        if output_dir:
            _logger.info(
                f"training complete, exporting ONNX to {output_dir}/model.onnx"
            )
            exporter = ModuleExporter(model, output_dir)
            exporter.export_onnx(torch.randn((1, *data_config["input_size"])))
        #################################################################################
        # End SparseML ONNX Export
        #################################################################################
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
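
# A quick post-export sanity check can be useful after the SparseML ONNX
# export above. This sketch uses the `onnx` package directly; the helper is
# illustrative (not part of the script), and `output_dir` / the "model.onnx"
# file name mirror the log message in main().
def check_exported_onnx(output_dir):
    import onnx
    model_path = os.path.join(output_dir, "model.onnx")
    onnx_model = onnx.load(model_path)
    onnx.checker.check_model(onnx_model)  # raises if the graph is malformed
    dims = onnx_model.graph.input[0].type.tensor_type.shape.dim
    print("ONNX input shape:", [d.dim_value for d in dims])
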
def main(args):
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=int(1.5 * args.batch_size),
        shuffle=False, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
    )
    # TODO: finetuning

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    criterion = LabelSmoothingCrossEntropy()

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print("Start training")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn)

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'model_ema': get_state_dict(model_ema),
                    'args': args,
                }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
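
# The linear LR scaling used above follows the common "scale the base LR by
# the global batch size relative to a reference of 512" convention. A worked
# sketch (the helper name and numbers are illustrative):
def linear_scale_lr(base_lr, batch_size, world_size, ref_batch=512.0):
    return base_lr * batch_size * world_size / ref_batch

# e.g. base LR 5e-4 at batch 128 per process on 8 processes:
# global batch 1024 -> scaled LR 1e-3
print(linear_scale_lr(5e-4, 128, 8))  # 0.001
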
def main(args):
    utils.init_distributed_mode(args)

    # disable any harsh augmentation in case of self-supervised training
    if args.training_mode == 'SSL':
        print("NOTE: Smoothing, Mixup, CutMix, and AutoAugment will be disabled in case of self-supervised training")
        args.smoothing = args.reprob = args.recount = args.mixup = args.cutmix = 0.0
        args.aa = ''

    if args.SiT_LinearEvaluation == 1:
        print("Warning: Linear Evaluation should be set to 0 during SSL training - changing SiT_LinearEvaluation to 0")
        args.SiT_LinearEvaluation = 0

    utils.print_args(args)
    device = torch.device(args.device)

    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    print("Loading dataset ....")
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)

    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=True, collate_fn=collate_fn)

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size), num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False, collate_fn=collate_fn)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        representation_size=args.representation_size,
        drop_block_rate=None,
        training_mode=args.training_mode)

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')
        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in ['rot_head.weight', 'rot_head.bias', 'contrastive_head.weight', 'contrastive_head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(num_patches ** 0.5)
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    # Freeze the backbone in case of linear evaluation
    if args.SiT_LinearEvaluation == 1:
        requires_grad(model, False)
        model.rot_head.weight.requires_grad = True
        model.rot_head.bias.requires_grad = True
        model.contrastive_head.weight.requires_grad = True
        model.contrastive_head.bias.requires_grad = True

        if args.representation_size is not None:
            model.pre_logits_rot.fc.weight.requires_grad = True
            model.pre_logits_rot.fc.bias.requires_grad = True
            model.pre_logits_contrastive.fc.weight.requires_grad = True
            model.pre_logits_contrastive.fc.bias.requires_grad = True

    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.training_mode == 'SSL':
        criterion = MTL_loss(args.device, args.batch_size)
    elif args.training_mode == 'finetune' and args.mixup > 0.:
        criterion = SoftTargetCrossEntropy()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate_SSL(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        if args.training_mode == 'SSL':
            train_stats = train_SSL(
                model, criterion, data_loader_train,
                optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn)
        else:
            train_stats = train_finetune(
                model, criterion, data_loader_train,
                optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn)

        lr_scheduler.step(epoch)

        if epoch % args.validate_every == 0:
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

            if args.training_mode == 'SSL':
                test_stats = evaluate_SSL(data_loader_val, model, device, epoch, args.output_dir)
            else:
                test_stats = evaluate_finetune(data_loader_val, model, device)
            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")

            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

            if args.output_dir and utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
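
# Linear evaluation above relies on freezing everything except the heads
# (`requires_grad(model, False)` is assumed to be a repo helper that toggles
# the flag on all parameters). A small self-contained check of that pattern,
# on a toy model rather than SiT itself:
def trainable_parameter_names(module):
    return [n for n, p in module.named_parameters() if p.requires_grad]

def _demo_linear_eval_freeze():
    toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
    for p in toy.parameters():
        p.requires_grad = False      # freeze the "backbone"
    for p in toy[-1].parameters():
        p.requires_grad = True       # keep only the "head" trainable
    print(trainable_parameter_names(toy))  # ['1.weight', '1.bias']
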
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. '
            'Process %d, total %d.' % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX nor native Torch AMP is available, using float32. "
            "Install NVIDIA apex or upgrade to PyTorch 1.6")

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        use_cos_reg=args.cos_reg_component > 0,
        checkpoint_path=args.initial_checkpoint)

    with torch.cuda.device(0):
        input = torch.randn(1, 3, 224, 224)
        size_for_madd = 224 if args.img_size is None else args.img_size
        # flops, params = get_model_complexity_info(model, (3, size_for_madd, size_for_madd),
        #                                           as_strings=True, print_per_layer_stat=True)
        # print("=>Flops: " + flops)
        # print("=>Params: " + params)

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    if args.use_lmdb:
        dataset_train = ImageFolderLMDB('../dataset_lmdb/train')
    else:
        dataset_train = Dataset(train_dir)
    # dataset_train = Dataset(train_dir)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    if args.use_lmdb:
        dataset_eval = ImageFolderLMDB('../dataset_lmdb/val')
    else:
        dataset_eval = Dataset(eval_dir)
    # dataset_eval = Dataset(eval_dir)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        repeated_aug=args.use_repeated_aug,
        world_size=args.world_size,
        rank=args.rank)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    loader_cali = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.cali_batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        no_aug=True,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=None,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        repeated_aug=args.use_repeated_aug,
        world_size=args.world_size,
        rank=args.rank)

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        if args.cos_reg_component > 0:
            args.use_cos_reg_component = True
            train_loss_fn = SoftTargetCrossEntropyCosReg(n_comn=args.cos_reg_component).cuda()
        else:
            train_loss_fn = SoftTargetCrossEntropy().cuda()
            args.use_cos_reg_component = False
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        code_dir = get_outdir(output_dir, 'code')
        copy_tree(os.getcwd(), code_dir)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            if not args.eval_only:
                train_metrics = train_epoch(
                    epoch, model, loader_train, optimizer, train_loss_fn, args,
                    lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                    amp_autocast=amp_autocast, loss_scaler=loss_scaler,
                    model_ema=model_ema, mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            if args.max_iter > 0:
                _ = validate(model, loader_cali, validate_loss_fn, args,
                             amp_autocast=amp_autocast, use_bn_calibration=True)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args,
                    amp_autocast=amp_autocast, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if not args.eval_only:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

            if args.eval_only:
                break
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
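
# The `use_bn_calibration=True` pass over `loader_cali` above re-estimates
# BatchNorm statistics before evaluation. A generic sketch of what such a
# calibration pass typically does; this is an assumption about the repo's
# validate() behavior, not its actual code.
@torch.no_grad()
def recalibrate_bn(model, loader, num_batches=100):
    # reset BN running stats, then re-estimate them with forward passes
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
    model.train()  # BN updates running stats only in train mode
    for i, (inputs, _) in enumerate(loader):
        if i >= num_batches:
            break
        model(inputs.cuda())
    model.eval()
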
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. '
                     'Process %d, total %d.' % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    if args.control_amp == 'amp':
        args.amp = True
    elif args.control_amp == 'apex':
        args.apex_amp = True
    elif args.control_amp == 'native':
        args.native_amp = True

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
                        "Install NVIDIA apex or upgrade to PyTorch 1.6")

    _logger.info(
        '====================\n\n'
        'Actfun: {}\n'
        'LR: {}\n'
        'Epochs: {}\n'
        'p: {}\n'
        'k: {}\n'
        'g: {}\n'
        'Extra channel multiplier: {}\n'
        'AMP: {}\n'
        'Weight Init: {}\n'
        '\n===================='.format(args.actfun, args.lr, args.epochs, args.p, args.k, args.g,
                                        args.extra_channel_mult, use_amp, args.weight_init))

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        actfun=args.actfun,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        p=args.p, k=args.k, g=args.g,
        extra_channel_mult=args.extra_channel_mult,
        weight_init_name=args.weight_init,
        partial_ho_actfun=args.partial_ho_actfun
    )

    if args.tl:
        if args.data == 'caltech101' and not os.path.exists('caltech101'):
            dir_root = r'101_ObjectCategories'
            dir_new = r'caltech101'
            dir_new_train = os.path.join(dir_new, 'train')
            dir_new_val = os.path.join(dir_new, 'val')
            dir_new_test = os.path.join(dir_new, 'test')
            if not os.path.exists(dir_new):
                os.mkdir(dir_new)
                os.mkdir(dir_new_train)
                os.mkdir(dir_new_val)
                os.mkdir(dir_new_test)
            for dir2 in os.listdir(dir_root):
                if dir2 != 'BACKGROUND_Google':
                    curr_path = os.path.join(dir_root, dir2)
                    new_path_train = os.path.join(dir_new_train, dir2)
                    new_path_val = os.path.join(dir_new_val, dir2)
                    new_path_test = os.path.join(dir_new_test, dir2)
                    if not os.path.exists(new_path_train):
                        os.mkdir(new_path_train)
                    if not os.path.exists(new_path_val):
                        os.mkdir(new_path_val)
                    if not os.path.exists(new_path_test):
                        os.mkdir(new_path_test)

                    train_upper = int(0.8 * len(os.listdir(curr_path)))
                    val_upper = int(0.9 * len(os.listdir(curr_path)))
                    curr_files_all = os.listdir(curr_path)
                    curr_files_train = curr_files_all[:train_upper]
                    curr_files_val = curr_files_all[train_upper:val_upper]
                    curr_files_test = curr_files_all[val_upper:]
                    for file in curr_files_train:
                        copyfile(os.path.join(curr_path, file), os.path.join(new_path_train, file))
                    for file in curr_files_val:
                        copyfile(os.path.join(curr_path, file), os.path.join(new_path_val, file))
                    for file in curr_files_test:
                        copyfile(os.path.join(curr_path, file), os.path.join(new_path_test, file))
            time.sleep(5)

    if args.tl:
        pre_model = create_model(
            args.model,
            pretrained=True,
            actfun='swish',
            num_classes=args.num_classes,
            drop_rate=args.drop,
            drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
            drop_path_rate=args.drop_path,
            drop_block_rate=args.drop_block,
            global_pool=args.gp,
            bn_tf=args.bn_tf,
            bn_momentum=args.bn_momentum,
            bn_eps=args.bn_eps,
            scriptable=args.torchscript,
            checkpoint_path=args.initial_checkpoint,
            p=args.p, k=args.k, g=args.g,
            extra_channel_mult=args.extra_channel_mult,
            weight_init_name=args.weight_init,
            partial_ho_actfun=args.partial_ho_actfun
        )
        model = MLP.MLP(actfun=args.actfun,
                        input_dim=1280,
                        output_dim=args.num_classes,
                        k=args.k, p=args.p, g=args.g,
                        num_params=400_000,
                        permute_type='shuffle')
        pre_model_layers = list(pre_model.children())
        pre_model = torch.nn.Sequential(*pre_model_layers[:-1])
    else:
        pre_model = None

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.tl:
        pre_model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    if args.tl:
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
    else:
        optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    if args.local_rank == 0:
        _logger.info('\n--------------------\nModel:\n' + repr(model) + '--------------------')

    # optionally resume from a checkpoint
    resume_epoch = None
    resume_path = os.path.join(args.resume, 'recover.pth.tar')
    if args.resume and os.path.exists(resume_path):
        resume_epoch = resume_checkpoint(
            model, resume_path,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    cp_loaded = None
    resume_epoch = None
    checkname = 'recover'
    if args.actfun != 'swish':
        checkname = '{}_'.format(args.actfun) + checkname
    check_path = os.path.join(args.check_path, checkname) + '.pth'
    loader = None
    if os.path.isfile(check_path):
        loader = check_path
    elif args.load_path != '' and os.path.isfile(args.load_path):
        loader = args.load_path
    if loader is not None:
        cp_loaded = torch.load(loader)
        model.load_state_dict(cp_loaded['model'])
        optimizer.load_state_dict(cp_loaded['optimizer'])
        resume_epoch = cp_loaded['epoch']
        model.cuda()
        loss_scaler.load_state_dict(cp_loaded['amp'])
        if args.channels_last:
            model = model.to(memory_format=torch.channels_last)
        _logger.info('============ LOADED CHECKPOINT: Epoch {}'.format(resume_epoch))

    model_raw = model

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume and os.path.exists(resume_path):
            load_checkpoint(model_ema.module, args.resume, use_ema=True)
        if cp_loaded is not None:
            model_ema.load_state_dict(cp_loaded['model_ema'])

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # create the train and eval datasets
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer, dataset_train)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if cp_loaded is not None:
        lr_scheduler.load_state_dict(cp_loaded['scheduler'])

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=args.resume, decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    fieldnames = ['seed', 'weight_init', 'actfun', 'epoch', 'max_lr', 'lr',
                  'train_loss', 'eval_loss', 'eval_acc1', 'eval_acc5', 'ema']
    filename = 'output'
    if args.actfun != 'swish':
        filename = '{}_'.format(args.actfun) + filename
    outfile_path = os.path.join(args.output, filename) + '.csv'
    if not os.path.exists(outfile_path):
        with open(outfile_path, mode='w') as out_file:
            writer = csv.DictWriter(out_file, fieldnames=fieldnames, lineterminator='\n')
            writer.writeheader()

    try:
        for epoch in range(start_epoch, num_epochs):
            if os.path.exists(args.check_path):
                amp_loss = None
                if use_amp == 'native':
                    amp_loss = loss_scaler.state_dict()
                elif use_amp == 'apex':
                    amp_loss = amp.state_dict()
                if model_ema is not None:
                    ema_save = model_ema.state_dict()
                else:
                    ema_save = None
                torch.save({'model': model_raw.state_dict(),
                            'model_ema': ema_save,
                            'optimizer': optimizer.state_dict(),
                            'scheduler': lr_scheduler.state_dict(),
                            'epoch': epoch,
                            'amp': amp_loss
                            }, check_path)
                _logger.info('============ SAVED CHECKPOINT: Epoch {}'.format(epoch))

            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler,
                model_ema=model_ema, mixup_fn=mixup_fn, pre_model=pre_model)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args,
                                    amp_autocast=amp_autocast, pre_model=pre_model)

            with open(outfile_path, mode='a') as out_file:
                writer = csv.DictWriter(out_file, fieldnames=fieldnames, lineterminator='\n')
                writer.writerow({'seed': args.seed,
                                 'actfun': args.actfun,
                                 'epoch': epoch,
                                 'lr': train_metrics['lr'],
                                 'train_loss': train_metrics['loss'],
                                 'eval_loss': eval_metrics['loss'],
                                 'eval_acc1': eval_metrics['top1'],
                                 'eval_acc5': eval_metrics['top5'],
                                 'ema': False
                                 })

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args,
                    amp_autocast=amp_autocast, log_suffix=' (EMA)', pre_model=pre_model)
                eval_metrics = ema_eval_metrics

                with open(outfile_path, mode='a') as out_file:
                    writer = csv.DictWriter(out_file, fieldnames=fieldnames, lineterminator='\n')
                    writer.writerow({'seed': args.seed,
                                     'weight_init': args.weight_init,
                                     'actfun': args.actfun,
                                     'epoch': epoch,
                                     'max_lr': args.lr,
                                     'lr': train_metrics['lr'],
                                     'train_loss': train_metrics['loss'],
                                     'eval_loss': eval_metrics['loss'],
                                     'eval_acc1': eval_metrics['top1'],
                                     'eval_acc5': eval_metrics['top5'],
                                     'ema': True
                                     })

            if lr_scheduler is not None and args.sched != 'onecycle':
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                args.seed, epoch, args.lr, args.epochs, args.batch_size, args.actfun,
                train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
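
# The per-epoch CSV written above can be post-processed with the stdlib; a
# small illustrative sketch (not part of the training script) that reports
# the best top-1 row, using the same field names as `fieldnames` above:
def best_top1_row(csv_path):
    with open(csv_path) as f:
        rows = [r for r in csv.DictReader(f) if r['eval_acc1']]
    # rows store numbers as strings, so compare as floats
    return max(rows, key=lambda r: float(r['eval_acc1']))
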
def main(args):
    utils.my_init_distributed_mode(args)
    print(args)

    # if args.distillation_type != 'none' and args.finetune and not args.eval:
    #     raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # if True:  # args.distributed:
    #     num_tasks = utils.get_world_size()
    #     global_rank = utils.get_rank()
    #     if args.repeated_aug:
    #         sampler_train = RASampler(
    #             dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    #         )
    #     else:
    #         sampler_train = torch.utils.data.DistributedSampler(
    #             dataset_train,
    #             # num_replicas=num_tasks,
    #             num_replicas=0,
    #             rank=global_rank, shuffle=True
    #         )
    #     if args.dist_eval:
    #         if len(dataset_val) % num_tasks != 0:
    #             print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
    #                   'This will slightly alter validation results as extra duplicate entries are added to achieve '
    #                   'equal num of samples per-process.')
    #         sampler_val = torch.utils.data.DistributedSampler(
    #             dataset_val,
    #             # num_replicas=num_tasks,
    #             num_replicas=0,
    #             rank=global_rank, shuffle=False)
    #     else:
    #         sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    # else:
    #     sampler_train = torch.utils.data.RandomSampler(dataset_train)
    #     sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    #
    # data_loader_train = torch.utils.data.DataLoader(
    #     dataset_train, sampler=sampler_train,
    #     batch_size=args.batch_size,
    #     num_workers=args.num_workers,
    #     pin_memory=args.pin_mem,
    #     drop_last=True,
    # )
    #
    # data_loader_val = torch.utils.data.DataLoader(
    #     dataset_val, sampler=sampler_val,
    #     batch_size=int(1.5 * args.batch_size),
    #     num_workers=args.num_workers,
    #     pin_memory=args.pin_mem,
    #     drop_last=False
    # )

    if args.distributed:
        if args.cache_mode:
            sampler_train = samplers.NodeDistributedSampler(dataset_train)
            sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
        else:
            sampler_train = samplers.DistributedSampler(dataset_train)
            sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   num_workers=args.num_workers, pin_memory=True)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, num_workers=args.num_workers, pin_memory=True)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )
    model_without_ddp = model

    # # there are bugs
    # if args.finetune:
    #     if args.finetune.startswith('https'):
    #         checkpoint = torch.hub.load_state_dict_from_url(
    #             args.finetune, map_location='cpu', check_hash=True)
    #     else:
    #         checkpoint = torch.load(args.finetune, map_location='cpu')
    #
    #     checkpoint_model = checkpoint['model']
    #
    #     state_dict = model.state_dict()
    #     for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
    #         if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
    #             print(f"Removing key {k} from pretrained checkpoint")
    #             del checkpoint_model[k]
    #
    #     _ = model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    # if args.model_ema:
    #     # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    #     model_ema = ModelEma(
    #         model,
    #         decay=args.model_ema_decay,
    #         device='cpu' if args.model_ema_force_cpu else '',
    #         resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    criterion = LabelSmoothingCrossEntropy()

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    criterion = DistillationLoss(criterion, None, 'none', 0, 0)

    output_dir = Path(args.output_dir)

    # for finetune
    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.finetune, map_location='cpu', check_hash=True)
        else:
            if not os.path.exists(args.finetune):
                checkpoint = None
                print('NOTICE:' + args.finetune + ' does not exist!')
            else:
                checkpoint = torch.load(args.finetune, map_location='cpu')

        if checkpoint is not None:
            if 'model' in checkpoint:
                check_model = checkpoint['model']
            else:
                check_model = checkpoint
            missing_keys = model_without_ddp.load_state_dict(check_model, strict=False).missing_keys
            skip_keys = model_without_ddp.no_weight_decay()

            # create optimizer manually
            param_dicts = [
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n in missing_keys and n not in skip_keys
                    ],
                    "lr": args.lr,
                    'weight_decay': args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n in missing_keys and n in skip_keys
                    ],
                    "lr": args.lr,
                    'weight_decay': 0,
                },
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n not in missing_keys and n not in skip_keys
                    ],
                    "lr": args.lr * args.fine_factor,
                    'weight_decay': args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n not
in missing_keys and n in skip_keys ], "lr": args.lr * args.fine_factor, 'weight_decay': 0, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) # if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: # optimizer.load_state_dict(checkpoint['optimizer']) # lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) # args.start_epoch = checkpoint['epoch'] + 1 # # if args.model_ema: # # utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) # if 'scaler' in checkpoint: # loss_scaler.load_state_dict(checkpoint['scaler']) print('finetune from' + args.finetune) # for debug # lr_scheduler.step(10) # lr_scheduler.step(100) # lr_scheduler.step(200) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: if not os.path.exists(args.resume): checkpoint = None print('NOTICE:' + args.resume + ' does not exist!') else: checkpoint = torch.load(args.resume, map_location='cpu') if checkpoint is not None: if 'model' in checkpoint: model_without_ddp.load_state_dict(checkpoint['model']) else: model_without_ddp.load_state_dict(checkpoint) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 # if args.model_ema: # utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) print('resume from' + args.resume) if args.eval: test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) return print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 max_epoch_dp_warm_up = 100 if 'pvt_tiny' in args.model or 'pvt_small' in args.model: max_epoch_dp_warm_up = 0 if args.start_epoch < max_epoch_dp_warm_up: model_without_ddp.reset_drop_path(0.0) for epoch in range(args.start_epoch, args.epochs): if args.fp32_resume and epoch > args.start_epoch + 1: args.fp32_resume = False loss_scaler._scaler = torch.cuda.amp.GradScaler( enabled=not args.fp32_resume) if epoch == max_epoch_dp_warm_up: model_without_ddp.reset_drop_path(args.drop_path) if epoch < args.warmup_epochs: optimizer.param_groups[2]['lr'] = 0 optimizer.param_groups[3]['lr'] = 0 if args.distributed: # data_loader_train.sampler.set_epoch(epoch) sampler_train.set_epoch(epoch) train_stats = my_train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, # set_training_mode=args.finetune == '', # keep in eval mode during finetuning fp32=args.fp32_resume) lr_scheduler.step(epoch) if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, # 'model_ema': get_state_dict(model_ema), 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) 
max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
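################################################################################
# Illustrative sketch: the mains above rescale the base learning rate by the
# effective global batch size (the linear scaling rule, with 512 as the
# reference batch). A standalone version; the function name is hypothetical.
################################################################################
def linear_scaled_lr(base_lr, batch_size, world_size, base_batch=512):
    """Scale a base LR (tuned for `base_batch`) to the global batch size."""
    return base_lr * batch_size * world_size / base_batch

# doubling the global batch doubles the LR:
assert linear_scaled_lr(5e-4, 128, 8) == 2 * 5e-4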
def main(args): utils.init_distributed_mode(args) print(args) if args.distillation_type != 'none' and args.finetune and not args.eval: raise NotImplementedError( "Finetuning with distillation not yet supported") device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) # random.seed(seed) cudnn.benchmark = True dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) dataset_val, _ = build_dataset(is_train=False, args=args) if True: # args.distributed: num_tasks = utils.get_world_size() global_rank = utils.get_rank() if args.repeated_aug: sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) else: sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) if args.dist_eval: if len(dataset_val) % num_tasks != 0: print( 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=int( 1.5 * args.batch_size), num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False) mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None if mixup_active: mixup_fn = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) print(f"Creating model: {args.model}") model = create_model( args.model, pretrained=False, num_classes=args.nb_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, drop_block_rate=None, ) if args.finetune: if args.finetune.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.finetune, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.finetune, map_location='cpu') checkpoint_model = checkpoint['model'] state_dict = model.state_dict() for k in [ 'head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias' ]: if k in checkpoint_model and checkpoint_model[ k].shape != state_dict[k].shape: print(f"Removing key {k} from pretrained checkpoint") del checkpoint_model[k] # interpolate position embedding pos_embed_checkpoint = checkpoint_model['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int( (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) # height (== width) for the new position embedding new_size = int(num_patches**0.5) # class_token and dist_token are kept unchanged extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) checkpoint_model['pos_embed'] = new_pos_embed model.load_state_dict(checkpoint_model, strict=False) model.to(device) model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume='') model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size( ) / 512.0 args.lr = linear_scaled_lr optimizer = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) criterion = LabelSmoothingCrossEntropy() if args.mixup > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() teacher_model = None if args.distillation_type != 'none': assert args.teacher_path, 'need to specify teacher-path when using distillation' print(f"Creating teacher model: {args.teacher_model}") teacher_model = create_model( args.teacher_model, pretrained=False, num_classes=args.nb_classes, global_pool='avg', ) if args.teacher_path.startswith('https'): checkpoint = 
torch.hub.load_state_dict_from_url(args.teacher_path, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.teacher_path, map_location='cpu') teacher_model.load_state_dict(checkpoint['model']) teacher_model.to(device) teacher_model.eval() # wrap the criterion in our custom DistillationLoss, which # just dispatches to the original criterion if args.distillation_type is 'none' criterion = DistillationLoss(criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau) output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) if args.eval: test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) return print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, set_training_mode=args.finetune == '' # keep in eval mode during finetuning ) lr_scheduler.step(epoch) if args.output_dir: checkpoint_paths = [output_dir / ('checkpoint_%04d.pth' % (epoch))] for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'model_ema': get_state_dict(model_ema), 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) if not args.train_without_eval: test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } else: log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
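################################################################################
# Illustrative sketch: the finetuning branch in the main() above resizes the
# checkpoint's position embedding to a new patch grid with bicubic
# interpolation, keeping the class/dist tokens unchanged. The same computation
# factored into a standalone helper (the name is hypothetical):
################################################################################
import torch
import torch.nn.functional as F


def interpolate_pos_embed(pos_embed, num_extra_tokens, new_size):
    """pos_embed: (1, num_extra_tokens + old_size**2, dim) -> new grid size."""
    dim = pos_embed.shape[-1]
    extra_tokens = pos_embed[:, :num_extra_tokens]  # class/dist tokens, kept as-is
    grid = pos_embed[:, num_extra_tokens:]
    old_size = int(grid.shape[1] ** 0.5)
    grid = grid.reshape(1, old_size, old_size, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_size, new_size), mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    return torch.cat((extra_tokens, grid), dim=1)

# e.g. a 14x14 grid (224px inputs, patch 16) resized for 384px inputs (24x24):
_out = interpolate_pos_embed(torch.randn(1, 1 + 14 * 14, 768), num_extra_tokens=1, new_size=24)
assert _out.shape == (1, 1 + 24 * 24, 768)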
class TorchImageClassificationEstimator(BaseEstimator):
    """Torch Estimator implementation for Image Classification.

    Parameters
    ----------
    config : dict
        Config in nested dict.
    logger : logging.Logger
        Optional logger for this estimator, can be `None` when default setting is used.
    reporter : callable
        The reporter for metric checkpointing.
    net : torch.nn.Module
        The custom network. If defined, the model name in config will be ignored so your
        custom network will be used for training rather than pulling it from model zoo.
    """
    def __init__(self, config, logger=None, reporter=None, net=None, optimizer=None, problem_type=None):
        super().__init__(config, logger=logger, reporter=reporter, name=None)
        if problem_type is None:
            problem_type = MULTICLASS
        self._problem_type = problem_type
        self._feature_net = None
        self._custom_net = False
        self._img_cls_cfg = self._cfg.img_cls
        self._data_cfg = self._cfg.data
        self._optimizer_cfg = self._cfg.optimizer
        self._train_cfg = self._cfg.train
        self._augmentation_cfg = self._cfg.augmentation
        self._model_ema_cfg = self._cfg.model_ema
        self._misc_cfg = self._cfg.misc

        # resolve AMP arguments based on PyTorch / Apex availability
        self.use_amp = None
        if self._misc_cfg.amp:
            # `amp` chooses native amp before apex (APEX ver not actively maintained)
            if self._misc_cfg.native_amp and has_native_amp:
                self.use_amp = 'native'
            elif self._misc_cfg.apex_amp and has_apex:
                self.use_amp = 'apex'
            elif self._misc_cfg.apex_amp or self._misc_cfg.native_amp:
                self._logger.warning('Neither APEX nor native Torch AMP is available, using float32. '
                                     'Install NVIDIA apex or upgrade to PyTorch 1.6')

        # FIXME: will provided model conflict with config provided?
        if net is not None:
            assert isinstance(net, nn.Module), f"given custom network {type(net)}, `torch.nn` expected"
            try:
                net.to('cpu')
                self._custom_net = True
            except ValueError:
                pass
        self.net = net
        if optimizer is not None:
            self._logger.warning('Custom optimizer object not supported. Will follow the config instead.')
        self._optimizer = None

    def _fit(self, train_data, val_data, time_limit=math.inf):
        tic = time.time()
        self._cp_name = ''
        self._best_acc = -float('inf')
        self.epochs = self._train_cfg.epochs
        self.epoch = 0
        self.start_epoch = self._train_cfg.start_epoch
        self._time_elapsed = 0
        if max(self.start_epoch, self.epoch) >= self.epochs:
            return {'time': self._time_elapsed}
        self._init_trainer()
        self._init_model_ema()
        self._time_elapsed += time.time() - tic
        return self._resume_fit(train_data, val_data, time_limit=time_limit)

    def _resume_fit(self, train_data, val_data, time_limit=math.inf):
        tic = time.time()
        # TODO: regression not implemented
        if self._problem_type != REGRESSION and (not self.classes or not self.num_class):
            raise ValueError('This is a classification problem and we are not able to determine classes of dataset')
        if max(self.start_epoch, self.epoch) >= self.epochs:
            return {'time': self._time_elapsed}
        # wrap DP if possible
        if self.found_gpu:
            self.net = torch.nn.DataParallel(self.net, device_ids=[int(i) for i in self.valid_gpus])
        self.net = self.net.to(self.ctx[0])
        # prepare dataset
        train_dataset = train_data.to_torch()
        val_dataset = val_data.to_torch()
        # setup mixup / cutmix
        self._collate_fn = None
        self._mixup_fn = None
        self.mixup_active = self._augmentation_cfg.mixup > 0 or self._augmentation_cfg.cutmix > 0.
or self._augmentation_cfg.cutmix_minmax is not None if self.mixup_active: mixup_args = dict( mixup_alpha=self._augmentation_cfg.mixup, cutmix_alpha=self._augmentation_cfg.cutmix, cutmix_minmax=self._augmentation_cfg.cutmix_minmax, prob=self._augmentation_cfg.mixup_prob, switch_prob=self._augmentation_cfg.mixup_switch_prob, mode=self._augmentation_cfg.mixup_mode, label_smoothing=self._augmentation_cfg.smoothing, num_classes=self.num_class) if self._misc_cfg.prefetcher: self._collate_fn = FastCollateMixup(**mixup_args) else: self._mixup_fn = Mixup(**mixup_args) # create data loaders w/ augmentation pipeiine train_interpolation = self._augmentation_cfg.train_interpolation if self._augmentation_cfg.no_aug or not train_interpolation: train_interpolation = self._data_cfg.interpolation train_loader = create_loader( train_dataset, input_size=self._data_cfg.input_size, batch_size=self._train_cfg.batch_size, is_training=True, use_prefetcher=self._misc_cfg.prefetcher, no_aug=self._augmentation_cfg.no_aug, scale=self._augmentation_cfg.scale, ratio=self._augmentation_cfg.ratio, hflip=self._augmentation_cfg.hflip, vflip=self._augmentation_cfg.vflip, color_jitter=self._augmentation_cfg.color_jitter, auto_augment=self._augmentation_cfg.auto_augment, interpolation=train_interpolation, mean=self._data_cfg.mean, std=self._data_cfg.std, num_workers=self._misc_cfg.num_workers, distributed=False, collate_fn=self._collate_fn, pin_memory=self._misc_cfg.pin_mem, use_multi_epochs_loader=self._misc_cfg.use_multi_epochs_loader ) val_loader = create_loader( val_dataset, input_size=self._data_cfg.input_size, batch_size=self._data_cfg.validation_batch_size_multiplier * self._train_cfg.batch_size, is_training=False, use_prefetcher=self._misc_cfg.prefetcher, interpolation=self._data_cfg.interpolation, mean=self._data_cfg.mean, std=self._data_cfg.std, num_workers=self._misc_cfg.num_workers, distributed=False, crop_pct=self._data_cfg.crop_pct, pin_memory=self._misc_cfg.pin_mem, ) self._time_elapsed += time.time() - tic return self._train_loop(train_loader, val_loader, time_limit=time_limit) def _train_loop(self, train_loader, val_loader, time_limit=math.inf): start_tic = time.time() # setup loss function if self.mixup_active: # smoothing is handled with mixup target transform train_loss_fn = SoftTargetCrossEntropy() elif self._augmentation_cfg.smoothing: train_loss_fn = LabelSmoothingCrossEntropy(smoothing=self._augmentation_cfg.smoothing) else: train_loss_fn = nn.CrossEntropyLoss() validate_loss_fn = nn.CrossEntropyLoss() train_loss_fn = train_loss_fn.to(self.ctx[0]) validate_loss_fn = validate_loss_fn.to(self.ctx[0]) eval_metric = self._misc_cfg.eval_metric if self._problem_type == REGRESSION: train_loss_fn = nn.MSELoss() validate_loss_fn = nn.MSELoss() eval_metric = 'rmse' early_stopper = EarlyStopperOnPlateau( patience=self._train_cfg.early_stop_patience, min_delta=self._train_cfg.early_stop_min_delta, baseline_value=self._train_cfg.early_stop_baseline, max_value=self._train_cfg.early_stop_max_value) self._logger.info('Start training from [Epoch %d]', max(self._train_cfg.start_epoch, self.epoch)) self._time_elapsed += time.time() - start_tic for self.epoch in range(max(self.start_epoch, self.epoch), self.epochs): epoch = self.epoch if self._best_acc >= 1.0: self._logger.info('[Epoch {}] Early stopping as acc is reaching 1.0'.format(epoch)) break should_stop, stop_message = early_stopper.get_early_stop_advice() if should_stop: self._logger.info('[Epoch {}] '.format(epoch) + stop_message) break train_metrics = 
self.train_one_epoch( epoch, self.net, train_loader, self._optimizer, train_loss_fn, lr_scheduler=self._lr_scheduler, output_dir=self._logdir, amp_autocast=self._amp_autocast, loss_scaler=self._loss_scaler, model_ema=self._model_ema, mixup_fn=self._mixup_fn, time_limit=time_limit) # reaching time limit, exit early if train_metrics['time_limit']: self._logger.warning(f'`time_limit={time_limit}` reached, exit early...') return {'train_acc': train_metrics['train_acc'], 'valid_acc': self._best_acc, 'time': self._time_elapsed, 'checkpoint': self._cp_name} post_tic = time.time() eval_metrics = self.validate(self.net, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast) if self._model_ema is not None and not self._model_ema_cfg.model_ema_force_cpu: ema_eval_metrics = self.validate( self._model_ema.module, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast) eval_metrics = ema_eval_metrics if self._problem_type == REGRESSION: val_acc = eval_metrics['rmse'] if self._reporter: self._reporter(epoch=epoch, acc_reward=-val_acc) early_stopper.update(-val_acc) if -val_acc > self._best_acc: self._cp_name = os.path.join(self._logdir, _BEST_CHECKPOINT_FILE) self._logger.info('[Epoch %d] Current best rmse: %f vs previous %f, saved to %s', self.epoch, val_acc, -self._best_acc, self._cp_name) self.save(self._cp_name) self._best_acc = -val_acc else: val_acc = eval_metrics['top1'] if self._reporter: self._reporter(epoch=epoch, acc_reward=val_acc) early_stopper.update(val_acc) if val_acc > self._best_acc: self._cp_name = os.path.join(self._logdir, _BEST_CHECKPOINT_FILE) self._logger.info('[Epoch %d] Current best top-1: %f vs previous %f, saved to %s', self.epoch, val_acc, self._best_acc, self._cp_name) self.save(self._cp_name) self._best_acc = val_acc if self._lr_scheduler is not None: # step LR for next epoch self._lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) self._time_elapsed += time.time() - post_tic if 'accuracy' in train_metrics: return {'train_acc': train_metrics['accuracy'], 'valid_acc': self._best_acc, 'time': self._time_elapsed, 'checkpoint': self._cp_name} # rmse else: if self._problem_type == REGRESSION: return {'train_score': train_metrics['rmse'], 'valid_score': -self._best_acc, 'time': self._time_elapsed, 'checkpoint': self._cp_name} # mixup else: return {'train_score': train_metrics['rmse'], 'valid_acc': self._best_acc, 'time': self._time_elapsed, 'checkpoint': self._cp_name} def train_one_epoch( self, epoch, net, loader, optimizer, loss_fn, lr_scheduler=None, output_dir=None, amp_autocast=suppress, loss_scaler=None, model_ema=None, mixup_fn=None, time_limit=math.inf): start_tic = time.time() if self._augmentation_cfg.mixup_off_epoch and epoch >= self._augmentation_cfg.mixup_off_epoch: if self._misc_cfg.prefetcher and loader.mixup_enabled: loader.mixup_enabled = False elif mixup_fn is not None: mixup_fn.mixup_enabled = False second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order losses_m = AverageMeter() train_metric_score_m = AverageMeter() net.train() num_updates = epoch * len(loader) self._time_elapsed += time.time() - start_tic tic = time.time() last_tic = time.time() train_metric_name = 'accuracy' batch_idx = 0 for batch_idx, (input, target) in enumerate(loader): b_tic = time.time() if self._time_elapsed > time_limit: return {'train_acc': train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': True} if self._problem_type == REGRESSION: target = target.to(torch.float32) if not self._misc_cfg.prefetcher: # prefetcher would 
move data to cuda by default input, target = input.to(self.ctx[0]), target.to(self.ctx[0]) if mixup_fn is not None: input, target = mixup_fn(input, target) with amp_autocast(): output = net(input) if self._problem_type == REGRESSION: output = output.flatten() loss = loss_fn(output, target) if self._problem_type == REGRESSION: train_metric_name = 'rmse' train_metric_score = rmse(output, target) else: if output.shape == target.shape: train_metric_name = 'rmse' train_metric_score = rmse(output, target) else: train_metric_score = accuracy(output, target)[0] / 100 losses_m.update(loss.item(), input.size(0)) train_metric_score_m.update(train_metric_score.item(), output.size(0)) optimizer.zero_grad() if loss_scaler is not None: loss_scaler( loss, optimizer, clip_grad=self._optimizer_cfg.clip_grad, clip_mode=self._optimizer_cfg.clip_mode, parameters=model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode), create_graph=second_order) else: loss.backward(create_graph=second_order) if self._optimizer_cfg.clip_grad is not None: dispatch_clip_grad( model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode), value=self._optimizer_cfg.clip_grad, mode=self._optimizer_cfg.clip_mode) optimizer.step() if model_ema is not None: model_ema.update(net) if self.found_gpu: torch.cuda.synchronize() num_updates += 1 if (batch_idx+1) % self._misc_cfg.log_interval == 0: lrl = [param_group['lr'] for param_group in optimizer.param_groups] lr = sum(lrl) / len(lrl) self._logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f', epoch, batch_idx, self._train_cfg.batch_size*self._misc_cfg.log_interval/(time.time()-last_tic), train_metric_name, train_metric_score_m.avg, lr) last_tic = time.time() if self._misc_cfg.save_images and output_dir: torchvision.utils.save_image( input, os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), padding=0, normalize=True) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) self._time_elapsed += time.time() - b_tic throughput = int(self._train_cfg.batch_size * batch_idx / (time.time() - tic)) self._logger.info('[Epoch %d] training: %s=%f', epoch, train_metric_name, train_metric_score_m.avg) self._logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f', epoch, throughput, time.time()-tic) end_time = time.time() if hasattr(optimizer, 'sync_lookahead'): optimizer.sync_lookahead() self._time_elapsed += time.time() - end_time return {train_metric_name: train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': False} def validate(self, net, loader, loss_fn, amp_autocast=suppress, metric_name=None): losses_m = AverageMeter() top1_m = AverageMeter() top5_m = AverageMeter() rmse_m = AverageMeter() net.eval() with torch.no_grad(): for batch_idx, (input, target) in enumerate(loader): if not self._misc_cfg.prefetcher: input = input.to(self.ctx[0]) target = target.to(self.ctx[0]) with amp_autocast(): output = net(input) if self._problem_type == REGRESSION: output = output.flatten() if isinstance(output, (tuple, list)): output = output[0] if self._problem_type == REGRESSION: if metric_name: assert metric_name == 'rmse', f'{metric_name} metric not supported for regression.' 
val_metric_score = rmse(output, target) else: val_metric_score = accuracy(output, target, topk=(1, min(5, self.num_class))) # augmentation reduction reduce_factor = self._misc_cfg.tta if self._problem_type != REGRESSION and reduce_factor > 1: output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) target = target[0:target.size(0):reduce_factor] loss = loss_fn(output, target) reduced_loss = loss.data if self.found_gpu: torch.cuda.synchronize() losses_m.update(reduced_loss.item(), input.size(0)) if self._problem_type == REGRESSION: rmse_score = val_metric_score rmse_m.update(rmse_score.item(), output.size(0)) else: acc1, acc5 = val_metric_score acc1 /= 100 acc5 /= 100 top1_m.update(acc1.item(), output.size(0)) top5_m.update(acc5.item(), output.size(0)) if self._problem_type == REGRESSION: self._logger.info('[Epoch %d] validation: rmse=%f', self.epoch, rmse_m.avg) return {'loss': losses_m.avg, 'rmse': rmse_m.avg} else: self._logger.info('[Epoch %d] validation: top1=%f top5=%f', self.epoch, top1_m.avg, top5_m.avg) return {'loss': losses_m.avg, 'top1': top1_m.avg, 'top5': top5_m.avg} def _init_network(self, **kwargs): load_only = kwargs.get('load_only', False) if not self.num_class and self._problem_type != REGRESSION: raise ValueError('This is a classification problem and we are not able to create network when `num_class` is unknown. \ It should be inferred from dataset or resumed from saved states.') assert len(self.classes) == self.num_class # Disable syncBatchNorm as it's only supported on DDP if self._train_cfg.sync_bn: self._logger.info( 'Disable Sync batch norm as it is not supported for now.') update_cfg(self._cfg, {'train': {'sync_bn': False}}) # ctx self.found_gpu = False valid_gpus = [] if self._cfg.gpus: valid_gpus = self._torch_validate_gpus(self._cfg.gpus) self.found_gpu = True if not valid_gpus: self.found_gpu = False self._logger.warning( 'No gpu detected, fallback to cpu. You can ignore this warning if this is intended.') elif len(valid_gpus) != len(self._cfg.gpus): self._logger.warning( f'Loaded on gpu({valid_gpus}), different from gpu({self._cfg.gpus}).') self.ctx = [torch.device(f'cuda:{gid}') for gid in valid_gpus] if self.found_gpu else [torch.device('cpu')] self.valid_gpus = valid_gpus if not self.found_gpu and self.use_amp: self.use_amp = None self._logger.warning('Training on cpu. AMP disabled.') update_cfg(self._cfg, {'misc': {'amp': False, 'apex_amp': False, 'native_amp': False}}) if not self.found_gpu and self._misc_cfg.prefetcher: self._logger.warning( 'Training on cpu. Prefetcher disabled.') update_cfg(self._cfg, {'misc': {'prefetcher': False}}) self._logger.warning( 'Training on cpu. SyncBatchNorm disabled.') update_cfg(self._cfg, {'train': {'sync_bn': False}}) random_seed(self._misc_cfg.seed) if not self.net: self.net = create_model( self._img_cls_cfg.model, pretrained=self._img_cls_cfg.pretrained and not load_only, num_classes=max(self.num_class, 1), global_pool=self._img_cls_cfg.global_pool_type, drop_rate=self._augmentation_cfg.drop, drop_path_rate=self._augmentation_cfg.drop_path, drop_block_rate=self._augmentation_cfg.drop_block, bn_momentum=self._train_cfg.bn_momentum, bn_eps=self._train_cfg.bn_eps, scriptable=self._misc_cfg.torchscript ) self._logger.info(f'Model {safe_model_name(self._img_cls_cfg.model)} created, param count: \ {sum([m.numel() for m in self.net.parameters()])}') else: self._logger.info(f'Use user provided model. 
Neglect model in config.') out_features = list(self.net.children())[-1].out_features if self._problem_type != REGRESSION: assert out_features == self.num_class, f'Custom model out_feature {out_features} != num_class {self.num_class}.' else: assert out_features == 1, f'Regression problem expects num_out_feature == 1, got {out_features} instead.' resolve_data_config(self._cfg, model=self.net) self.net = self.net.to(self.ctx[0]) # setup synchronized BatchNorm if self._train_cfg.sync_bn: if has_apex and self.use_amp != 'native': # Apex SyncBN preferred unless native amp is activated self.net = convert_syncbn_model(self.net) else: self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net) self._logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') if self._misc_cfg.torchscript: assert not self.use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not self._train_cfg.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' self.net = torch.jit.script(self.net) def _init_trainer(self): if self._optimizer is None: if self._img_cls_cfg.pretrained and not self._custom_net \ and (self._train_cfg.transfer_lr_mult != 1 or self._train_cfg.output_lr_mult != 1): # adjust feature/last_fc learning rate multiplier in optimizer self._logger.debug(f'Reduce network lr multiplier to {self._train_cfg.transfer_lr_mult}, while keep ' + f'last FC layer lr_mult to {self._train_cfg.output_lr_mult}') optim_kwargs = optimizer_kwargs(cfg=self._cfg) optim_kwargs['feature_lr_mult'] = self._cfg.train.transfer_lr_mult optim_kwargs['fc_lr_mult'] = self._cfg.train.output_lr_mult self._optimizer = create_optimizer_v2(self.net, **optimizer_kwargs(cfg=self._cfg)) else: self._optimizer = create_optimizer_v2(self.net, **optimizer_kwargs(cfg=self._cfg)) self._init_loss_scaler() self._lr_scheduler, self.epochs = create_scheduler(self._cfg, self._optimizer) self._lr_scheduler.step(self.start_epoch, self.epoch) def _init_loss_scaler(self): # setup automatic mixed-precision (AMP) loss scaling and op casting self._amp_autocast = suppress # do nothing self._loss_scaler = None if self.use_amp == 'apex': self.net, self._optimizer = amp.initialize(self.net, self._optimizer, opt_level='O1') self._loss_scaler = ApexScaler() self._logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif self.use_amp == 'native': self._amp_autocast = torch.cuda.amp.autocast self._loss_scaler = NativeScaler() self._logger.info('Using native Torch AMP. Training in mixed precision.') else: self._logger.info('AMP not enabled. 
Training in float32.') def _init_model_ema(self): # Disable for now if self._model_ema_cfg.model_ema: self._logger.info('Disable EMA as it is not supported for now.') update_cfg(self._cfg, {'model_ema': {'model_ema': False}}) # setup exponential moving average of model weights, SWA could be used here too self._model_ema = None if self._model_ema_cfg.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper self._model_ema = ModelEmaV2( self.net, decay=self._model_ema_cfg.model_ema_decay, device='cpu' if self._model_ema_cfg.model_ema_force_cpu else None) def evaluate(self, val_data, metric_name=None): return self._evaluate(val_data, metric_name) def _evaluate(self, val_data, metric_name=None): if self._problem_type == REGRESSION: validate_loss_fn = nn.MSELoss() else: validate_loss_fn = nn.CrossEntropyLoss() validate_loss_fn = validate_loss_fn.to(self.ctx[0]) val_data = val_data.to_torch() val_loader = create_loader( val_data, input_size=self._data_cfg.input_size, batch_size=self._data_cfg.validation_batch_size_multiplier * self._train_cfg.batch_size, is_training=False, use_prefetcher=self._misc_cfg.prefetcher, interpolation=self._data_cfg.interpolation, mean=self._data_cfg.mean, std=self._data_cfg.std, num_workers=self._misc_cfg.num_workers, distributed=False, crop_pct=self._data_cfg.crop_pct, pin_memory=self._misc_cfg.pin_mem, ) return self.validate(self.net, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast, metric_name=metric_name) def _predict(self, x, **kwargs): with_proba = kwargs.get('with_proba', False) if with_proba and self._problem_type not in [MULTICLASS, BINARY]: raise AssertionError('with_proba is only supported for classification problems. Please use predict instead.') if isinstance(x, str): return self._predict((x,), **kwargs).drop(columns=['image'], errors='ignore') elif isinstance(x, pd.DataFrame): assert 'image' in x.columns, "Expect column `image` for input images" df = self._predict(tuple(x['image']), **kwargs) return df.reset_index(drop=True) elif isinstance(x, (list, tuple)): loader = create_loader( ImageListDataset(x), input_size=self._data_cfg.input_size, batch_size=self._train_cfg.batch_size, use_prefetcher=self._misc_cfg.prefetcher, interpolation=self._data_cfg.interpolation, mean=self._data_cfg.mean, std=self._data_cfg.std, num_workers=self._misc_cfg.num_workers, crop_pct=self._data_cfg.crop_pct ) self.net.eval() topk = min(5, self.num_class) results = [] idx = 0 with torch.no_grad(): for input, _ in loader: input = input.to(self.ctx[0]) labels = self.net(input) for l in labels: if self._problem_type in [MULTICLASS, BINARY]: probs = nn.functional.softmax(l, dim=0).cpu().numpy().flatten() if with_proba: results.append({'image_proba': probs.tolist(), 'image': x[idx]}) else: topk_inds = l.topk(topk)[1].cpu().numpy().flatten() results.extend([{'class': self.classes[topk_inds[k]], 'score': probs[topk_inds[k]], 'id': topk_inds[k], 'image': x[idx]} for k in range(topk)]) else: results.append({'prediction': l.cpu().numpy().flatten(), 'image': x[idx]}) idx += 1 return pd.DataFrame(results) elif not isinstance(x, torch.Tensor): raise ValueError('Input is not supported: {}'.format(type(x))) assert len(x.shape) == 4 and x.shape[1] == 3, f"Expect input to be (n, 3, h, w), given {x.shape}" with torch.no_grad(): input = x.to(self.ctx[0]) label = self.net(input) if self._problem_type in [MULTICLASS, BINARY]: topk = min(5, self.num_class) probs = nn.functional.softmax(label, dim=1).cpu().numpy().flatten() topk_inds 
= label.topk(topk)[1].cpu().numpy().flatten() if with_proba: df = pd.DataFrame([{'image_proba': probs.tolist()}]) else: df = pd.DataFrame([{'class': self.classes[topk_inds[k]], 'score': probs[topk_inds[k]], 'id': topk_inds[k]} for k in range(topk)]) else: df = pd.DataFrame([{'prediction': label.cpu().numpy().flatten()}]) return df def _predict_feature(self, x, **kwargs): if isinstance(x, str): return self._predict_feature((x,)) elif isinstance(x, pd.DataFrame): assert 'image' in x.columns, "Expect column `image` for input images" df = self._predict_feature(tuple(x['image'])) df = df.set_index(x.index) df['image'] = x['image'] return df elif isinstance(x, (list, tuple)): assert isinstance(x[0], str), "expect image paths in list/tuple input" loader = create_loader( ImageListDataset(x), input_size=self._data_cfg.input_size, batch_size=self._train_cfg.batch_size, use_prefetcher=self._misc_cfg.prefetcher, interpolation=self._data_cfg.interpolation, mean=self._data_cfg.mean, std=self._data_cfg.std, num_workers=self._misc_cfg.num_workers, crop_pct=self._data_cfg.crop_pct ) self.net.eval() results = [] with torch.no_grad(): for input, _ in loader: input = input.to(self.ctx[0]) try: features = self.net.forward_features(input) except AttributeError: features = self.net.module.forward_features(input) for f in features: f = f.cpu().numpy().flatten() results.append({'image_feature': f}) df = pd.DataFrame(results) df['image'] = x return df elif not isinstance(x, torch.Tensor): raise ValueError('Input is not supported: {}'.format(type(x))) with torch.no_grad(): input = x.to(self.ctx[0]) feature = self.net.forward_features(input) result = [{'image_feature': feature}] df = pd.DataFrame(result) return df def _reconstruct_state_dict(self, state_dict): new_state_dict = {} for k, v in state_dict.items(): name = k[7:] if k.startswith('module') else k new_state_dict[name] = v return new_state_dict def save(self, filename): d = dict() current_states = self.__dict__.copy() if self.net: if not self._custom_net: if isinstance(self.net, torch.nn.DataParallel): d['model_state_dict'] = get_state_dict(self.net.module, unwrap_model) else: d['model_state_dict'] = get_state_dict(self.net, unwrap_model) else: net_pickle = pickle.dumps(self.net) d['net_pickle'] = net_pickle self.net = None if self._optimizer: d['optimizer_state_dict'] = self._optimizer.state_dict() self._optimizer = None if hasattr(self, '_loss_scaler') and self._loss_scaler: d[self._loss_scaler.state_dict_key] = self._loss_scaler.state_dict() d['_loss_scaler_state_dict_key'] = self._loss_scaler.state_dict_key if self._model_ema: d['ema_state_dict'] = get_state_dict(self._model_ema, unwrap_model) self._model_ema = None self._logger = None self._reporter = None d['estimator'] = self torch.save(d, filename) self.__dict__.update(current_states) @classmethod def load(cls, filename, ctx='auto'): d = torch.load(filename, map_location=torch.device('cpu')) est = d.pop('estimator') # logger est._logger = logging.getLogger(cls.__name__) est._logger.setLevel(logging.ERROR) try: fh = logging.FileHandler(est._log_file) est._logger.addHandler(fh) #pylint: disable=bare-except except: pass model_state_dict = d.get('model_state_dict', None) net_pickle = d.get('net_pickle', None) if model_state_dict: est._init_network(load_only=True) net_state_dict = est._reconstruct_state_dict(model_state_dict) if isinstance(est.net, torch.nn.DataParallel): est.net.module.load_state_dict(net_state_dict) else: est.net.load_state_dict(net_state_dict) elif net_pickle: est.net = 
pickle.loads(net_pickle)
        optimizer_state_dict = d.get('optimizer_state_dict', None)
        if optimizer_state_dict:
            est._init_trainer()
            est._optimizer.load_state_dict(optimizer_state_dict)
        if hasattr(est, '_loss_scaler') and est._loss_scaler:
            # key under which save() stored the loss scaler's state dict
            loss_scaler_state_dict_key = d.get('_loss_scaler_state_dict_key')
            loss_scaler_dict = d.get(loss_scaler_state_dict_key, None)
            if loss_scaler_dict:
                est._loss_scaler.load_state_dict(loss_scaler_dict)
        ema_state_dict = d.get('ema_state_dict', None)
        est._init_model_ema()
        if ema_state_dict:
            ema_state_dict = est._reconstruct_state_dict(ema_state_dict)
            if isinstance(est.net, torch.nn.DataParallel):
                est._model_ema.module.module.load_state_dict(ema_state_dict)
            else:
                est._model_ema.module.load_state_dict(ema_state_dict)
        new_ctx = _suggest_load_context(est.net, ctx, est.ctx)
        est.reset_ctx(new_ctx)
        est._logger.setLevel(logging.INFO)
        return est

    # pylint: disable=redefined-outer-name, reimported
    def __getstate__(self):
        d = self.__dict__.copy()
        try:
            import torch
            net = d.pop('net', None)
            model_ema = d.pop('_model_ema', None)
            optimizer = d.pop('_optimizer', None)
            loss_scaler = d.pop('_loss_scaler', None)
            save_state = {}
            if net is not None:
                if not self._custom_net:
                    if isinstance(net, torch.nn.DataParallel):
                        save_state['state_dict'] = get_state_dict(net.module, unwrap_model)
                    else:
                        save_state['state_dict'] = get_state_dict(net, unwrap_model)
                else:
                    net_pickle = pickle.dumps(net)
                    save_state['net_pickle'] = net_pickle
            if optimizer is not None:
                save_state['optimizer'] = optimizer.state_dict()
            if loss_scaler is not None:
                save_state[loss_scaler.state_dict_key] = loss_scaler.state_dict()
            if model_ema is not None:
                save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model)
        except ImportError:
            pass
        d['save_state'] = save_state
        d['_logger'] = None
        d['_reporter'] = None
        return d

    def __setstate__(self, state):
        save_state = state.pop('save_state', None)
        self.__dict__.update(state)
        # logger
        self._logger = logging.getLogger(state.get('_name', self.__class__.__name__))
        self._logger.setLevel(logging.ERROR)
        try:
            fh = logging.FileHandler(self._log_file)
            self._logger.addHandler(fh)
        #pylint: disable=bare-except
        except:
            pass
        if not save_state:
            self.net = None
            self._optimizer = None
            self._logger.setLevel(logging.INFO)
            return
        try:
            import torch
            self.net = None
            self._optimizer = None
            if self._custom_net:
                if save_state.get('net_pickle', None):
                    self.net = pickle.loads(save_state['net_pickle'])
            else:
                if save_state.get('state_dict', None):
                    self._init_network(load_only=True)
                    net_state_dict = self._reconstruct_state_dict(save_state['state_dict'])
                    if isinstance(self.net, torch.nn.DataParallel):
                        self.net.module.load_state_dict(net_state_dict)
                    else:
                        self.net.load_state_dict(net_state_dict)
            if save_state.get('optimizer', None):
                self._init_trainer()
                self._optimizer.load_state_dict(save_state['optimizer'])
            if hasattr(self, '_loss_scaler') and self._loss_scaler and self._loss_scaler.state_dict_key in save_state:
                loss_scaler_dict = save_state[self._loss_scaler.state_dict_key]
                self._loss_scaler.load_state_dict(loss_scaler_dict)
            if save_state.get('state_dict_ema', None):
                self._init_model_ema()
                model_ema_dict = save_state.get('state_dict_ema')
                model_ema_dict = self._reconstruct_state_dict(model_ema_dict)
                if isinstance(self.net, torch.nn.DataParallel):
                    self._model_ema.module.module.load_state_dict(model_ema_dict)
                else:
                    self._model_ema.module.load_state_dict(model_ema_dict)
        except ImportError:
            pass
        self._logger.setLevel(logging.INFO)
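################################################################################
# Illustrative sketch: `_reconstruct_state_dict` above strips the 'module.'
# prefix that torch.nn.DataParallel adds to parameter names, so a checkpoint
# saved from a wrapped model loads into an unwrapped one. A standalone
# equivalent (the function name is hypothetical):
################################################################################
def strip_module_prefix(state_dict):
    return {(k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items()}

assert strip_module_prefix({'module.fc.weight': 1, 'fc.bias': 2}) == {'fc.weight': 1, 'fc.bias': 2}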
def main(fold_i=0, data_=None, train_index=None, val_index=None):
    setup_default_logging()
    args, args_text = _parse_args()
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    best_score = 0.0
    args.output = args.output + 'fold_' + str(fold_i)
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        if fold_i == 0:
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
                        "Install NVIDIA apex or upgrade to PyTorch 1.6")

    torch.manual_seed(args.seed + args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model = nn.DataParallel(model)
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-6)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)
    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    ## create DataLoader
    train_trans = get_riadd_train_transforms(args)
    valid_trans = get_riadd_valid_transforms(args)
    train_data = data_.iloc[train_index, :].reset_index(drop=True)
    dataset_train = RiaddDataSet(image_ids=train_data, baseImgPath=args.data)
    val_data = data_.iloc[val_index, :].reset_index(drop=True)
    dataset_eval = RiaddDataSet(image_ids=val_data, baseImgPath=args.data)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0.
or args.cutmix_minmax is not None if mixup_active: mixup_args = dict(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) else: mixup_fn = Mixup(**mixup_args) # wrap dataset in AugMix helper if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) # create data loaders w/ augmentation pipeiine train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] train_trans = get_riadd_train_transforms(args) loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, no_aug=args.no_aug, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader, transform=train_trans) valid_trans = get_riadd_valid_transforms(args) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, transform=valid_trans) # # setup loss function # if args.jsd: # assert num_aug_splits > 1 # JSD only valid with aug splits set # train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() # elif mixup_active: # # smoothing is handled with mixup target transform # train_loss_fn = SoftTargetCrossEntropy().cuda() # elif args.smoothing: # train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() # else: # train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.BCEWithLogitsLoss().cuda() train_loss_fn = nn.BCEWithLogitsLoss().cuda() # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None vis = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) vis = Visualizer(env=args.output) try: for epoch in range(0, args.epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, 
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                                      amp_autocast=amp_autocast, loss_scaler=loss_scaler,
                                      model_ema=model_ema, mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args,
                                    amp_autocast=amp_autocast)
            score, scores = get_score(eval_metrics['valid_label'], eval_metrics['predictions'])

            ## visdom
            if vis is not None:
                vis.plot_curves({'None': epoch}, iters=epoch, title='None',
                                xlabel='iters', ylabel='None')
                vis.plot_curves({'learning rate': optimizer.param_groups[0]['lr']},
                                iters=epoch, title='lr', xlabel='iters', ylabel='learning rate')
                vis.plot_curves({'train loss': float(train_metrics['loss'])},
                                iters=epoch, title='train loss', xlabel='iters', ylabel='train loss')
                vis.plot_curves({'val loss': float(eval_metrics['loss'])},
                                iters=epoch, title='val loss', xlabel='iters', ylabel='val loss')
                vis.plot_curves({'val score': float(score)}, iters=epoch,
                                title='val score', xlabel='iters', ylabel='val score')

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(model_ema.module, loader_eval, validate_loss_fn,
                                            args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                # lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
                lr_scheduler.step(epoch + 1, score)

            update_summary(epoch, train_metrics, eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None and score > best_score:
                # save proper checkpoint with eval metric
                best_score = score
                save_metric = best_score
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

        del model
        del optimizer
        torch.cuda.empty_cache()
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
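################################################################################
# Illustrative sketch: the fold loop above checkpoints only when the
# validation score improves, and CheckpointSaver treats 'loss' as a decreasing
# metric. A minimal tracker with the same comparison convention (the class
# name is hypothetical):
################################################################################
class BestMetricTracker:
    def __init__(self, decreasing=False):
        self.decreasing = decreasing
        self.best = None
        self.best_epoch = None

    def update(self, epoch, value):
        """Record `value` and return True when it improves on the best so far."""
        improved = (self.best is None
                    or (value < self.best if self.decreasing else value > self.best))
        if improved:
            self.best, self.best_epoch = value, epoch
        return improved

_tracker = BestMetricTracker(decreasing=False)
assert _tracker.update(0, 0.71) and not _tracker.update(1, 0.69)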
def main(args): utils.init_distributed_mode(args) update_config_from_file(args.cfg) print(args) args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) # random.seed(seed) cudnn.benchmark = True dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) dataset_val, _ = build_dataset(is_train=False, args=args) if args.distributed: num_tasks = utils.get_world_size() global_rank = utils.get_rank() if args.dist_eval: if len(dataset_val) % num_tasks != 0: print( 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=int( 2 * args.batch_size), sampler=sampler_val, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False) mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_fn = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) print(f"Creating S3-Transformer") model = SSSTransformer(img_size=args.input_size, patch_size=args.patch_size, num_classes=args.nb_classes, embed_dim=cfg.EMBED_DIM, depths=cfg.DEPTHS, num_heads=cfg.NUM_HEADS, window_size=cfg.WINDOW_SIZE, mlp_ratio=cfg.MLP_RATIO, qkv_bias=True, drop_rate=args.drop, drop_path_rate=args.drop_path, patch_norm=True) model.to(device) model_ema = None model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size( ) / 512.0 args.lr = linear_scaled_lr optimizer = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) output_dir = Path(args.output_dir) if not output_dir.exists(): output_dir.mkdir(parents=True) # save config for later experiments with open(output_dir / "config.yaml", 'w') as f: f.write(args_text) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if args.eval: test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) return
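# This script and several below rescale the base learning rate linearly with the global batch
# size (the "linear scaling rule"): lr = base_lr * batch_size * world_size / 512. A small worked
# example, assuming base_lr=5e-4, per-GPU batch 128, and 8 processes (values are illustrative):
base_lr = 5e-4
per_gpu_batch, world_size = 128, 8
linear_scaled_lr = base_lr * per_gpu_batch * world_size / 512.0
print(linear_scaled_lr)  # 5e-4 * 1024 / 512 = 1e-3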
def main(args): utils.init_distributed_mode(args) print(args) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) # random.seed(seed) cudnn.benchmark = True dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) dataset_val, _ = build_dataset(is_train=False, args=args) if True: # args.distributed: num_tasks = utils.get_world_size() global_rank = utils.get_rank() sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) if args.dist_eval: if len(dataset_val) % num_tasks != 0: print( 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=int( 1.0 * args.batch_size), num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False) mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_fn = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) print(f"Creating model: {args.model}") # model = create_model( # args.model, # pretrained=False, # num_classes=args.nb_classes, # drop_rate=args.drop, # drop_path_rate=args.drop_path, # drop_block_rate=None, # ) model = getattr(SwinTransformer, args.model)(num_classes=args.nb_classes, drop_path_rate=args.drop_path) model.to(device) model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size( ) / 512.0 args.lr = linear_scaled_lr linear_scaled_warmup_lr = args.warmup_lr * args.batch_size * utils.get_world_size( ) / 512.0 args.warmup_lr = linear_scaled_warmup_lr optimizer = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) # criterion = LabelSmoothingCrossEntropy() if args.mixup > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 
'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) if args.eval: test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) return print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) lr_scheduler.step(epoch + 1) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, mixup_fn, set_training_mode=True # True enables train mode; pass False to keep BN/dropout in eval mode during finetuning ) if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
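# The resume path above restores optimizer/scheduler/scaler state alongside the weights so
# training continues exactly where it stopped. A self-contained sketch of the same save/resume
# round trip on a toy model (names and file path here are illustrative, not this repo's API):
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled so the sketch also runs on CPU

torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(), 'epoch': 3}, 'checkpoint.pth')

checkpoint = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scaler.load_state_dict(checkpoint['scaler'])
start_epoch = checkpoint['epoch'] + 1  # resume from the epoch after the saved one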
def main(args): utils.init_distributed_mode(args) update_config_from_file(args.cfg) print(args) args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) dataset_val, _ = build_dataset(is_train=False, args=args) if args.distributed: num_tasks = utils.get_world_size() global_rank = utils.get_rank() if args.repeated_aug: sampler_train = RASampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True ) else: sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True ) if args.dist_eval: if len(dataset_val) % num_tasks != 0: print( 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) sampler_train = torch.utils.data.RandomSampler(dataset_train) data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader( dataset_val, batch_size=args.batch_size // 2, sampler=sampler_val, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False ) mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None if mixup_active: mixup_fn = Mixup( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) print("Creating SuperVisionTransformer") print(cfg) model = Vision_TransformerSuper(img_size=args.input_size, patch_size=args.patch_size, embed_dim=cfg.SUPERNET.EMBED_DIM, depth=cfg.SUPERNET.DEPTH, num_heads=cfg.SUPERNET.NUM_HEADS,mlp_ratio=cfg.SUPERNET.MLP_RATIO, qkv_bias=True, drop_rate=args.drop, drop_path_rate=args.drop_path, gp=args.gp, num_classes=args.nb_classes, max_relative_position=args.max_relative_position, relative_position=args.relative_position, change_qkv=args.change_qkv, abs_pos=not args.no_abs_pos) choices = {'num_heads': cfg.SEARCH_SPACE.NUM_HEADS, 'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO, 'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM , 'depth': cfg.SEARCH_SPACE.DEPTH} model.to(device) if args.teacher_model: teacher_model = create_model( args.teacher_model, pretrained=True, num_classes=args.nb_classes, ) teacher_model.to(device) teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: teacher_model = None teacher_loss = None model_ema = None model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 args.lr = linear_scaled_lr optimizer = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) if args.mixup > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() output_dir = Path(args.output_dir) if not output_dir.exists(): output_dir.mkdir(parents=True) # save config for later experiments with open(file=output_dir / "config.yaml", mode='w') as f: f.write(args_text) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) retrain_config = None if args.mode == 'retrain' and "RETRAIN" in cfg: retrain_config = {'layer_num': cfg.RETRAIN.DEPTH, 'embed_dim': [cfg.RETRAIN.EMBED_DIM]*cfg.RETRAIN.DEPTH, 'num_heads': cfg.RETRAIN.NUM_HEADS,'mlp_ratio': cfg.RETRAIN.MLP_RATIO} trainer = AFSupernetTrainer( model, criterion, data_loader_train, data_loader_val, optimizer, device, args.epochs, loss_scaler, args.clip_grad, model_ema, mixup_fn, args.amp, teacher_model, teacher_loss,choices, args.mode, retrain_config, 0., output_dir, lr_scheduler, ) if args.eval: trainer._validate_one_epoch(-1) return trainer.fit()
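# `choices` above enumerates the supernet search space, and retrain_config pins one
# architecture. A hedged sketch of how a subnet configuration of the same shape could be
# sampled from such a space (illustrative only; the actual sampling lives inside the trainer):
import random

choices = {'num_heads': [3, 4], 'mlp_ratio': [3.0, 3.5, 4.0],
           'embed_dim': [192, 216, 240], 'depth': [12, 13, 14]}  # assumed example values

def sample_config(choices):
    depth = random.choice(choices['depth'])
    return {
        'layer_num': depth,
        'embed_dim': [random.choice(choices['embed_dim'])] * depth,  # uniform across layers
        'num_heads': [random.choice(choices['num_heads']) for _ in range(depth)],
        'mlp_ratio': [random.choice(choices['mlp_ratio']) for _ in range(depth)],
    }

print(sample_config(choices))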
def main(args): print(args) device = torch.device(args.device) # fix seed for reproducibility print("Setting random seed") random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) cudnn.benchmark = True cudnn.deterministic = True data_directory = args.data_path print("Loading data") train_dataset = ColonCancerDataset(data_directory, train=True, seed=args.seed) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True) val_dataset = ColonCancerDataset(data_directory, train=False, seed=args.seed) val_loader = DataLoader(val_dataset, batch_size=int(1.5 * args.batch_size), shuffle=False, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False) print(f"Creating model: {args.model}") efficient_transformer = Performer(dim=384, depth=12, heads=6, causal=True) model = ViT(image_size=500, patch_size=25, num_classes=2, dim=384, transformer=efficient_transformer) # TODO fix create model function and files # model = create_model( # args.model, # pretrained=False, # num_classes=2, # drop_rate=args.drop, # drop_path_rate=args.drop_path, # drop_block_rate=None, # ) model.to(device) model_without_ddp = model n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Number of params: {n_parameters}") linear_scaled_lr = args.lr * args.batch_size / 512.0 args.lr = linear_scaled_lr optimiser = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimiser) criterion = LabelSmoothingCrossEntropy() output_dir = Path(args.output_dir) wandb.watch(model, criterion, log='all', log_freq=10) print(f"Starting training for {args.epochs} epochs") start_time = time.time() for epoch in tqdm(range(args.start_epoch, args.epochs + 1)): train_loss, train_metrics = train_one_epoch(model, criterion, train_loader, optimiser, device) lr_scheduler.step(epoch) # TODO add in resuming training val_loss, val_metrics = evaluate(val_loader, model, device) if args.output_dir: checkpoint_paths = [output_dir / "checkpoint.pth"] for checkpoint_path in checkpoint_paths: save( { "model": model_without_ddp.state_dict(), "optimiser": optimiser.state_dict(), "lr_scheduler": lr_scheduler.state_dict(), "epoch": epoch, "scaler": loss_scaler.state_dict(), "args": args, }, checkpoint_path) wandb.log({ "epoch": epoch, "train loss": train_loss, "val loss": val_loss, "train acc": train_metrics["accuracy"], "train f1": train_metrics["f1 score"], "train prec": train_metrics["precision"], "train recall": train_metrics["recall"], "val acc": val_metrics["accuracy"], "val f1": val_metrics["f1 score"], "val prec": val_metrics["precision"], "val recall": val_metrics["recall"] }) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print("Training time {}".format(total_time_str))
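# The seeding block above can be collapsed into one helper. Note that cudnn.benchmark=True and
# cudnn.deterministic=True (as set above) pull in opposite directions -- benchmark autotunes
# kernels nondeterministically -- so fully reproducible runs usually disable benchmark.
# A minimal sketch:
import random
import numpy as np
import torch

def seed_everything(seed: int, deterministic: bool = True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # also seeds CUDA RNGs on recent PyTorch
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic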
def main(args): utils.init_distributed_mode(args) update_config_from_file(args.cfg) print(args) args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) # random.seed(seed) cudnn.benchmark = True dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) dataset_val, _ = build_dataset(is_train=False, args=args) if args.distributed: num_tasks = utils.get_world_size() global_rank = utils.get_rank() if args.repeated_aug: sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) else: sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) if args.dist_eval: if len(dataset_val) % num_tasks != 0: print( 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) sampler_train = torch.utils.data.RandomSampler(dataset_train) data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=int( 2 * args.batch_size), sampler=sampler_val, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False) mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None if mixup_active: mixup_fn = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) print(f"Creating SuperVisionTransformer") print(cfg) model = Vision_TransformerSuper( img_size=args.input_size, patch_size=args.patch_size, embed_dim=cfg.SUPERNET.EMBED_DIM, depth=cfg.SUPERNET.DEPTH, num_heads=cfg.SUPERNET.NUM_HEADS, mlp_ratio=cfg.SUPERNET.MLP_RATIO, qkv_bias=True, drop_rate=args.drop, drop_path_rate=args.drop_path, gp=args.gp, num_classes=args.nb_classes, max_relative_position=args.max_relative_position, relative_position=args.relative_position, change_qkv=args.change_qkv, abs_pos=not args.no_abs_pos) choices = { 'num_heads': cfg.SEARCH_SPACE.NUM_HEADS, 'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO, 'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM, 'depth': cfg.SEARCH_SPACE.DEPTH } model.to(device) if args.teacher_model: teacher_model = create_model( args.teacher_model, pretrained=True, num_classes=args.nb_classes, ) teacher_model.to(device) teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: teacher_model = None teacher_loss = None model_ema = None model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size( ) / 512.0 args.lr = linear_scaled_lr optimizer = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) # criterion = LabelSmoothingCrossEntropy() if args.mixup > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() output_dir = Path(args.output_dir) if not output_dir.exists(): output_dir.mkdir(parents=True) # save config for later experiments with open(output_dir / "config.yaml", 'w') as f: f.write(args_text) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) retrain_config = None if args.mode == 'retrain' and "RETRAIN" in cfg: retrain_config = { 'layer_num': cfg.RETRAIN.DEPTH, 'embed_dim': [cfg.RETRAIN.EMBED_DIM] * cfg.RETRAIN.DEPTH, 'num_heads': cfg.RETRAIN.NUM_HEADS, 'mlp_ratio': cfg.RETRAIN.MLP_RATIO } if args.eval: print(retrain_config) test_stats = evaluate(data_loader_val, model, device, mode=args.mode, retrain_config=retrain_config) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) return print("Start training") start_time = time.time() 
max_accuracy = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, amp=args.amp, teacher_model=teacher_model, teach_loss=teacher_loss, choices=choices, mode=args.mode, retrain_config=retrain_config, ) lr_scheduler.step(epoch) if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, # 'model_ema': get_state_dict(model_ema), 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) test_stats = evaluate(data_loader_val, model, device, amp=args.amp, choices=choices, mode=args.mode, retrain_config=retrain_config) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
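# train_one_epoch above receives teacher_model/teach_loss for knowledge distillation. For
# reference, the textbook soft-distillation objective (temperature-scaled KL divergence mixed
# with the hard loss); this is a generic sketch, not necessarily the exact loss used here:
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, hard_loss, alpha=0.5, tau=1.0):
    # KL(student || teacher) on temperature-softened distributions, scaled by tau^2
    kd = F.kl_div(F.log_softmax(student_logits / tau, dim=1),
                  F.softmax(teacher_logits / tau, dim=1),
                  reduction='batchmean') * (tau * tau)
    return (1 - alpha) * hard_loss + alpha * kd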
def main(): setup_default_logging() args, args_text = _parse_args() if args.log_wandb: if has_wandb: wandb.init(project=args.experiment, config=args) else: _logger.warning( "You've requested to log metrics to wandb but package not found. " "Metrics not being logged to wandb, try `pip install wandb`") args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() _logger.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: _logger.info('Training with a single process on 1 GPU.') assert args.rank >= 0 # resolve AMP arguments based on PyTorch / Apex availability use_amp = None if args.amp: # `--amp` chooses native amp before apex (APEX ver not actively maintained) if has_native_amp: args.native_amp = True elif has_apex: args.apex_amp = True if args.apex_amp and has_apex: use_amp = 'apex' elif args.native_amp and has_native_amp: use_amp = 'native' elif args.apex_amp or args.native_amp: _logger.warning( "Neither APEX nor native Torch AMP is available, using float32. " "Install NVIDIA apex or upgrade to PyTorch 1.6") random_seed(args.seed, args.rank) model_KD = None if args.kd_model_path is not None: model_KD = build_kd_model(args) model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint) if args.num_classes is None: assert hasattr( model, 'num_classes' ), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly if args.local_rank == 0: _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}' ) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits # enable split bn (separate bn stats per batch-portion) if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) # move model to GPU, enable channels last layout if set model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: assert not args.split_bn if has_apex and use_amp == 'apex': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. 
WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' ) if args.torchscript: assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() if args.local_rank == 0: _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() if args.local_rank == 0: _logger.info( 'Using native Torch AMP. Training in mixed precision.') else: if args.local_rank == 0: _logger.info('AMP not enabled. Training in float32.') # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( model, args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) # setup exponential moving average of model weights, SWA could be used here too model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEmaV2( model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) if args.resume: load_checkpoint(model_ema.module, args.resume, use_ema=True) # setup distributed training if args.distributed: if has_apex and use_amp == 'apex': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) # create the train and eval datasets dataset_train = create_dataset(args.dataset, root=args.data_dir, split=args.train_split, is_training=True, class_map=args.class_map, download=args.dataset_download, batch_size=args.batch_size, repeats=args.epoch_repeats) dataset_eval = create_dataset(args.dataset, root=args.data_dir, split=args.val_split, is_training=False, class_map=args.class_map, download=args.dataset_download, batch_size=args.batch_size) # setup mixup / cutmix collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None if mixup_active: mixup_args = dict(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) else: mixup_fn = Mixup(**mixup_args) # wrap dataset in AugMix helper if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) # create data loaders w/ augmentation pipeline train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, no_aug=args.no_aug, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_repeats=args.aug_repeats, num_aug_splits=num_aug_splits, interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader, worker_seeding=args.worker_seeding, ) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size or args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) # setup loss function if args.jsd_loss: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) elif mixup_active: # smoothing is handled with mixup target transform which outputs sparse, soft targets if args.bce_loss: train_loss_fn = BinaryCrossEntropy( target_threshold=args.bce_target_thresh) else: train_loss_fn = SoftTargetCrossEntropy() elif args.smoothing: if args.bce_loss: train_loss_fn = BinaryCrossEntropy( smoothing=args.smoothing, target_threshold=args.bce_target_thresh) else: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing) else: train_loss_fn = nn.CrossEntropyLoss() train_loss_fn = train_loss_fn.cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = None if args.rank == 0: if args.experiment: exp_name = args.experiment else: exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), safe_model_name(args.model), str(data_config['input_size'][-1]) ]) output_dir = get_outdir( args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in 
range(start_epoch, num_epochs): if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): loader_train.sampler.set_epoch(epoch) train_metrics = train_one_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, model_KD=model_KD) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) if output_dir is not None: update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( epoch, metric=save_metric) except KeyboardInterrupt: pass if best_metric is not None: _logger.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
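# The ' (EMA)' validation above evaluates a shadow copy of the weights maintained by
# ModelEmaV2. The core update is a per-parameter exponential moving average; a minimal sketch
# of that rule (timm's implementation adds device handling on top of this):
import copy
import torch

class SimpleEma:
    def __init__(self, model, decay=0.9998):
        self.module = copy.deepcopy(model).eval()  # shadow copy, never trained directly
        self.decay = decay

    @torch.no_grad()
    def update(self, model):
        for ema_p, p in zip(self.module.state_dict().values(),
                            model.state_dict().values()):
            if ema_p.dtype.is_floating_point:
                ema_p.mul_(self.decay).add_(p, alpha=1 - self.decay)
            else:
                ema_p.copy_(p)  # integer buffers (e.g. num_batches_tracked) are copied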
def main( log_dir, dataset, im_size, crop_size, window_size, window_stride, backbone, decoder, optimizer, scheduler, weight_decay, dropout, drop_path, batch_size, epochs, learning_rate, normalization, eval_freq, amp, resume, ): # start distributed mode ptu.set_gpu_mode(True) distributed.init_process() # set up configuration cfg = config.load_config() model_cfg = cfg["model"][backbone] dataset_cfg = cfg["dataset"][dataset] if "mask_transformer" in decoder: decoder_cfg = cfg["decoder"]["mask_transformer"] else: decoder_cfg = cfg["decoder"][decoder] # model config if not im_size: im_size = dataset_cfg["im_size"] if not crop_size: crop_size = dataset_cfg.get("crop_size", im_size) if not window_size: window_size = dataset_cfg.get("window_size", im_size) if not window_stride: window_stride = dataset_cfg.get("window_stride", im_size) model_cfg["image_size"] = (crop_size, crop_size) model_cfg["backbone"] = backbone model_cfg["dropout"] = dropout model_cfg["drop_path_rate"] = drop_path decoder_cfg["name"] = decoder model_cfg["decoder"] = decoder_cfg # dataset config world_batch_size = dataset_cfg["batch_size"] num_epochs = dataset_cfg["epochs"] lr = dataset_cfg["learning_rate"] if batch_size: world_batch_size = batch_size if epochs: num_epochs = epochs if learning_rate: lr = learning_rate if eval_freq is None: eval_freq = dataset_cfg.get("eval_freq", 1) if normalization: model_cfg["normalization"] = normalization # experiment config batch_size = world_batch_size // ptu.world_size variant = dict( world_batch_size=world_batch_size, version="normal", resume=resume, dataset_kwargs=dict( dataset=dataset, image_size=im_size, crop_size=crop_size, batch_size=batch_size, normalization=model_cfg["normalization"], split="train", num_workers=10, ), algorithm_kwargs=dict( batch_size=batch_size, start_epoch=0, num_epochs=num_epochs, eval_freq=eval_freq, ), optimizer_kwargs=dict( opt=optimizer, lr=lr, weight_decay=weight_decay, momentum=0.9, clip_grad=None, sched=scheduler, epochs=num_epochs, min_lr=1e-5, poly_power=0.9, poly_step_size=1, ), net_kwargs=model_cfg, amp=amp, log_dir=log_dir, inference_kwargs=dict( im_size=im_size, window_size=window_size, window_stride=window_stride, ), ) log_dir = Path(log_dir) log_dir.mkdir(parents=True, exist_ok=True) checkpoint_path = log_dir / "checkpoint.pth" # dataset dataset_kwargs = variant["dataset_kwargs"] train_loader = create_dataset(dataset_kwargs) val_kwargs = dataset_kwargs.copy() val_kwargs["split"] = "val" val_kwargs["batch_size"] = 1 val_kwargs["crop"] = False val_loader = create_dataset(val_kwargs) n_cls = train_loader.unwrapped.n_cls # model net_kwargs = variant["net_kwargs"] net_kwargs["n_cls"] = n_cls model = create_segmenter(net_kwargs) model.to(ptu.device) # optimizer optimizer_kwargs = variant["optimizer_kwargs"] optimizer_kwargs["iter_max"] = len(train_loader) * optimizer_kwargs["epochs"] optimizer_kwargs["iter_warmup"] = 0.0 opt_args = argparse.Namespace() opt_vars = vars(opt_args) for k, v in optimizer_kwargs.items(): opt_vars[k] = v optimizer = create_optimizer(opt_args, model) lr_scheduler = create_scheduler(opt_args, optimizer) num_iterations = 0 amp_autocast = suppress loss_scaler = None if amp: amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() # resume if resume and checkpoint_path.exists(): print(f"Resuming training from checkpoint: {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location="cpu") model.load_state_dict(checkpoint["model"]) optimizer.load_state_dict(checkpoint["optimizer"]) if loss_scaler and 
"loss_scaler" in checkpoint: loss_scaler.load_state_dict(checkpoint["loss_scaler"]) lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) variant["algorithm_kwargs"]["start_epoch"] = checkpoint["epoch"] + 1 else: sync_model(log_dir, model) if ptu.distributed: model = DDP(model, device_ids=[ptu.device], find_unused_parameters=True) # save config variant_str = yaml.dump(variant) print(f"Configuration:\n{variant_str}") variant["net_kwargs"] = net_kwargs variant["dataset_kwargs"] = dataset_kwargs log_dir.mkdir(parents=True, exist_ok=True) with open(log_dir / "variant.yml", "w") as f: f.write(variant_str) # train start_epoch = variant["algorithm_kwargs"]["start_epoch"] num_epochs = variant["algorithm_kwargs"]["num_epochs"] eval_freq = variant["algorithm_kwargs"]["eval_freq"] model_without_ddp = model if hasattr(model, "module"): model_without_ddp = model.module val_seg_gt = val_loader.dataset.get_gt_seg_maps() print(f"Train dataset length: {len(train_loader.dataset)}") print(f"Val dataset length: {len(val_loader.dataset)}") print(f"Encoder parameters: {num_params(model_without_ddp.encoder)}") print(f"Decoder parameters: {num_params(model_without_ddp.decoder)}") for epoch in range(start_epoch, num_epochs): # train for one epoch train_logger = train_one_epoch( model, train_loader, optimizer, lr_scheduler, epoch, amp_autocast, loss_scaler, ) # save checkpoint if ptu.dist_rank == 0: snapshot = dict( model=model_without_ddp.state_dict(), optimizer=optimizer.state_dict(), n_cls=model_without_ddp.n_cls, lr_scheduler=lr_scheduler.state_dict(), ) if loss_scaler is not None: snapshot["loss_scaler"] = loss_scaler.state_dict() snapshot["epoch"] = epoch torch.save(snapshot, checkpoint_path) # evaluate eval_epoch = epoch % eval_freq == 0 or epoch == num_epochs - 1 if eval_epoch: eval_logger = evaluate( model, val_loader, val_seg_gt, window_size, window_stride, amp_autocast, ) print(f"Stats [{epoch}]:", eval_logger, flush=True) print("") # log stats if ptu.dist_rank == 0: train_stats = { k: meter.global_avg for k, meter in train_logger.meters.items() } val_stats = {} if eval_epoch: val_stats = { k: meter.global_avg for k, meter in eval_logger.meters.items() } log_stats = { **{f"train_{k}": v for k, v in train_stats.items()}, **{f"val_{k}": v for k, v in val_stats.items()}, "epoch": epoch, "num_updates": (epoch + 1) * len(train_loader), } with open(log_dir / "log.txt", "a") as f: f.write(json.dumps(log_stats) + "\n") distributed.barrier() distributed.destroy_process() sys.exit(1)
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: _logger.warning( 'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: _logger.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) if args.local_rank == 0: _logger.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) use_amp = None if args.amp: # for backwards compat, `--amp` arg tries apex before native amp if has_apex: args.apex_amp = True elif has_native_amp: args.native_amp = True if args.apex_amp and has_apex: use_amp = 'apex' elif args.native_amp and has_native_amp: use_amp = 'native' elif args.apex_amp or args.native_amp: _logger.warning("Neither APEX nor native Torch AMP is available, using float32. " "Install NVIDIA apex or upgrade to PyTorch 1.6") if args.num_gpu > 1: if use_amp == 'apex': _logger.warning( 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') use_amp = None model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() assert not args.channels_last, "Channels last not supported with DP, use DDP." else: model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) optimizer = create_optimizer(args, model) amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() if args.local_rank == 0: _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() if args.local_rank == 0: _logger.info('Using native Torch AMP. Training in mixed precision.') else: if args.local_rank == 0: _logger.info('AMP not enabled. 
Training in float32.') # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( model, args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma( model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume=args.resume) if args.distributed: if args.sync_bn: assert not args.split_bn try: if has_apex and use_amp != 'native': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') except Exception as e: _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') if has_apex and use_amp != 'native': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): _logger.error('Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None if mixup_active: mixup_args = dict( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) else: mixup_fn = Mixup(**mixup_args) if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, no_aug=args.no_aug, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader ) eval_dir = os.path.join(args.data, 'val') if not os.path.isdir(eval_dir): eval_dir = os.path.join(args.data, 'validation') if not os.path.isdir(eval_dir): _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) exit(1) dataset_eval = Dataset(eval_dir) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() elif mixup_active: # smoothing is handled with mixup target transform train_loss_fn = SoftTargetCrossEntropy().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver( model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in range(start_epoch, num_epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, 
model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info("Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate( model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) # if saver.cmp(best_metric, save_metric): # _logger.info(f"Metric is no longer improving [BEST: {best_metric}, CURRENT: {save_metric}]" # f"\nFinishing training process") # if epoch > 15: # break except KeyboardInterrupt: pass if best_metric is not None: message = '*** Best metric: <{0:.2f}>, epoch: <{1}>, path: <{2}> ***'\ .format(best_metric, best_epoch, output_dir) _logger.info(message) print(message)
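# distribute_bn above synchronizes BatchNorm running statistics across processes once per
# epoch, either broadcasting rank 0's stats or averaging everyone's ('reduce'). The core of
# the 'reduce' path, sketched (assumes an initialized process group; timm's version also
# implements the broadcast mode):
import torch
import torch.distributed as dist

def reduce_bn_stats(model, world_size):
    for bn in model.modules():
        if isinstance(bn, torch.nn.modules.batchnorm._BatchNorm):
            if bn.running_mean is None:  # track_running_stats=False has nothing to sync
                continue
            for stat in (bn.running_mean, bn.running_var):
                dist.all_reduce(stat, op=dist.ReduceOp.SUM)
                stat /= world_size  # average across processes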
def main(): setup_default_logging() args, args_text = _parse_args() if args.log_wandb: if has_wandb: wandb.init(project=args.experiment, config=args) else: _logger.warning( "You've requested to log metrics to wandb but package not found. " "Metrics not being logged to wandb, try `pip install wandb`") args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() _logger.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: _logger.info('Training with a single process on 1 GPU.') assert args.rank >= 0 # resolve AMP arguments based on PyTorch / Apex availability use_amp = None if args.amp: # `--amp` chooses native amp before apex (APEX ver not actively maintained) if has_native_amp: args.native_amp = True elif has_apex: args.apex_amp = True if args.apex_amp and has_apex: use_amp = 'apex' elif args.native_amp and has_native_amp: use_amp = 'native' elif args.apex_amp or args.native_amp: _logger.warning( "Neither APEX nor native Torch AMP is available, using float32. " "Install NVIDIA apex or upgrade to PyTorch 1.6") random_seed(args.seed, args.rank) if args.fuser: set_jit_fuser(args.fuser) data_splits = get_data_splits_by_name( dataset_name=args.dataset_name, data_root=args.data_dir, batch_size=args.batch_size, ) loader_train, loader_eval = data_splits['train'], data_splits['test'] model_wrapper_fn = MODEL_WRAPPER_REGISTRY.get( model_name=args.model.lower(), dataset_name=args.pretraining_original_dataset) model = model_wrapper_fn(pretrained=args.pretrained, progress=True, num_classes=len(loader_train.dataset.classes)) if args.local_rank == 0: _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}' ) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits # enable split bn (separate bn stats per batch-portion) if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) # move model to GPU, enable channels last layout if set model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: assert not args.split_bn if has_apex and use_amp == 'apex': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' 
) optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() if args.local_rank == 0: _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() if args.local_rank == 0: _logger.info( 'Using native Torch AMP. Training in mixed precision.') else: if args.local_rank == 0: _logger.info('AMP not enabled. Training in float32.') # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( model, args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) # setup exponential moving average of model weights, SWA could be used here too model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEmaV2( model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) if args.resume: load_checkpoint(model_ema.module, args.resume, use_ema=True) # setup distributed training if args.distributed: if has_apex and use_amp == 'apex': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) # setup loss function if args.jsd_loss: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) elif args.smoothing: if args.bce_loss: train_loss_fn = BinaryCrossEntropy( smoothing=args.smoothing, target_threshold=args.bce_target_thresh) else: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing) else: train_loss_fn = nn.CrossEntropyLoss() train_loss_fn = train_loss_fn.cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = None if args.rank == 0: if args.experiment: exp_name = args.experiment else: exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), safe_model_name(args.model), str(data_config['input_size'][-1]) ]) output_dir = get_outdir( args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, 
max_history=args.checkpoint_hist) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in range(start_epoch, num_epochs): if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): loader_train.sampler.set_epoch(epoch) train_metrics = train_one_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) if output_dir is not None: update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( epoch, metric=save_metric) except KeyboardInterrupt: pass if best_metric is not None: _logger.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
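# update_summary above appends one row per epoch to summary.csv, writing the header only on
# the first call. A minimal sketch of that pattern with the standard library (an illustrative
# stand-in for timm's helper, not its implementation):
import csv
import os

def write_summary_row(filename, epoch, train_metrics, eval_metrics):
    row = {'epoch': epoch}
    row.update({f'train_{k}': v for k, v in train_metrics.items()})
    row.update({f'eval_{k}': v for k, v in eval_metrics.items()})
    write_header = not os.path.exists(filename)  # header only on the first append
    with open(filename, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(row.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(row)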