def main(args): utils.init_distributed_mode(args) print(args) if args.distillation_type != 'none' and args.finetune and not args.eval: raise NotImplementedError( "Finetuning with distillation not yet supported") device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) # random.seed(seed) cudnn.benchmark = True dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) dataset_val, _ = build_dataset(is_train=False, args=args) if True: # args.distributed: num_tasks = utils.get_world_size() global_rank = utils.get_rank() if args.repeated_aug: sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) else: sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) if args.dist_eval: if len(dataset_val) % num_tasks != 0: print( 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=int( 1.5 * args.batch_size), num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False) mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_fn = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) print(f"Creating model: {args.model}") model = create_model( args.model, pretrained=False, num_classes=args.nb_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, drop_block_rate=None, ) if args.finetune: if args.finetune.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.finetune, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.finetune, map_location='cpu') checkpoint_model = checkpoint['model'] state_dict = model.state_dict() for k in [ 'head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias' ]: if k in checkpoint_model and checkpoint_model[ k].shape != state_dict[k].shape: print(f"Removing key {k} from pretrained checkpoint") del checkpoint_model[k] # interpolate position embedding pos_embed_checkpoint = checkpoint_model['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int( (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) # height (== width) for the new position embedding new_size = int(num_patches**0.5) # class_token and dist_token are kept unchanged extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) checkpoint_model['pos_embed'] = new_pos_embed model.load_state_dict(checkpoint_model, strict=False) model.to(device) model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume='') model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size( ) / 512.0 args.lr = linear_scaled_lr optimizer = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) criterion = LabelSmoothingCrossEntropy() if args.mixup > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() teacher_model = None if args.distillation_type != 'none': assert args.teacher_path, 'need to specify teacher-path when using distillation' print(f"Creating teacher model: {args.teacher_model}") teacher_model = create_model( args.teacher_model, pretrained=False, num_classes=args.nb_classes, global_pool='avg', ) if args.teacher_path.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.teacher_path, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.teacher_path, map_location='cpu') teacher_model.load_state_dict(checkpoint['model']) teacher_model.to(device) teacher_model.eval() # wrap the criterion in our custom DistillationLoss, which # just dispatches to the original criterion if args.distillation_type is 'none' criterion = DistillationLoss(criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau) output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) if args.eval: test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) return print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) train_stats = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, set_training_mode=args.finetune == '' # keep in eval mode during finetuning ) lr_scheduler.step(epoch) if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'model_ema': get_state_dict(model_ema), 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
def main(): setup_default_logging() args, args_text = _parse_args() if args.log_wandb: if has_wandb: wandb.init(project=args.experiment, config=args) else: _logger.warning("You've requested to log metrics to wandb but package not found. " "Metrics not being logged to wandb, try `pip install wandb`") args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: _logger.info('Training with a single process on 1 GPUs.') assert args.rank >= 0 # resolve AMP arguments based on PyTorch / Apex availability use_amp = None if args.amp: # `--amp` chooses native amp before apex (APEX ver not actively maintained) if has_native_amp: args.native_amp = True elif has_apex: args.apex_amp = True if args.apex_amp and has_apex: use_amp = 'apex' elif args.native_amp and has_native_amp: use_amp = 'native' elif args.apex_amp or args.native_amp: _logger.warning("Neither APEX or native Torch AMP is available, using float32. " "Install NVIDA apex or upgrade to PyTorch 1.6") random_seed(args.seed, args.rank) model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly if args.local_rank == 0: _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits # enable split bn (separate bn stats per batch-portion) if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) # move model to GPU, enable channels last layout if set model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: assert not args.split_bn if has_apex and use_amp != 'native': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') if args.torchscript: assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() if args.local_rank == 0: _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() if args.local_rank == 0: _logger.info('Using native Torch AMP. Training in mixed precision.') else: if args.local_rank == 0: _logger.info('AMP not enabled. Training in float32.') # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( model, args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) # setup exponential moving average of model weights, SWA could be used here too model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEmaV2( model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) if args.resume: load_checkpoint(model_ema.module, args.resume, use_ema=True) # setup distributed training if args.distributed: if has_apex and use_amp != 'native': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) # create the train and eval datasets dataset_train = create_dataset( args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size, repeats=args.epoch_repeats) dataset_eval = create_dataset( args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size) # setup mixup / cutmix collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_args = dict( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) else: mixup_fn = Mixup(**mixup_args) # wrap dataset in AugMix helper if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) # create data loaders w/ augmentation pipeiine train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, no_aug=args.no_aug, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader ) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) # setup loss function if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() elif mixup_active: # smoothing is handled with mixup target transform train_loss_fn = SoftTargetCrossEntropy().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = None if args.rank == 0: if args.experiment: exp_name = args.experiment else: exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), safe_model_name(args.model), str(data_config['input_size'][-1]) ]) output_dir = get_outdir(args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver( model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in range(start_epoch, num_epochs): if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): loader_train.sampler.set_epoch(epoch) train_metrics = train_one_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info("Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate( model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) if output_dir is not None: update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) except KeyboardInterrupt: pass if best_metric is not None: _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def _train_loop(self, train_loader, val_loader, time_limit=math.inf): start_tic = time.time() # setup loss function if self.mixup_active: # smoothing is handled with mixup target transform train_loss_fn = SoftTargetCrossEntropy() elif self._augmentation_cfg.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=self._augmentation_cfg.smoothing) else: train_loss_fn = nn.CrossEntropyLoss() validate_loss_fn = nn.CrossEntropyLoss() train_loss_fn = train_loss_fn.to(self.ctx[0]) validate_loss_fn = validate_loss_fn.to(self.ctx[0]) eval_metric = self._misc_cfg.eval_metric if self._problem_type == REGRESSION: train_loss_fn = nn.MSELoss() validate_loss_fn = nn.MSELoss() eval_metric = 'rmse' early_stopper = EarlyStopperOnPlateau( patience=self._train_cfg.early_stop_patience, min_delta=self._train_cfg.early_stop_min_delta, baseline_value=self._train_cfg.early_stop_baseline, max_value=self._train_cfg.early_stop_max_value) self._logger.info('Start training from [Epoch %d]', max(self._train_cfg.start_epoch, self.epoch)) self._time_elapsed += time.time() - start_tic for self.epoch in range(max(self.start_epoch, self.epoch), self.epochs): epoch = self.epoch if self._best_acc >= 1.0: self._logger.info( '[Epoch {}] Early stopping as acc is reaching 1.0'.format( epoch)) break should_stop, stop_message = early_stopper.get_early_stop_advice() if should_stop: self._logger.info('[Epoch {}] '.format(epoch) + stop_message) break train_metrics = self.train_one_epoch( epoch, self.net, train_loader, self._optimizer, train_loss_fn, lr_scheduler=self._lr_scheduler, output_dir=self._logdir, amp_autocast=self._amp_autocast, loss_scaler=self._loss_scaler, model_ema=self._model_ema, mixup_fn=self._mixup_fn, time_limit=time_limit) # reaching time limit, exit early if train_metrics['time_limit']: self._logger.warning( f'`time_limit={time_limit}` reached, exit early...') return { 'train_acc': train_metrics['train_acc'], 'valid_acc': self._best_acc, 'time': self._time_elapsed, 'checkpoint': self._cp_name } post_tic = time.time() eval_metrics = self.validate(self.net, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast) if self._model_ema is not None and not self._model_ema_cfg.model_ema_force_cpu: ema_eval_metrics = self.validate( self._model_ema.module, val_loader, validate_loss_fn, amp_autocast=self._amp_autocast) eval_metrics = ema_eval_metrics if self._problem_type == REGRESSION: val_acc = eval_metrics['rmse'] if self._reporter: self._reporter(epoch=epoch, acc_reward=-val_acc) early_stopper.update(-val_acc) if -val_acc > self._best_acc: self._cp_name = os.path.join(self._logdir, _BEST_CHECKPOINT_FILE) self._logger.info( '[Epoch %d] Current best rmse: %f vs previous %f, saved to %s', self.epoch, val_acc, -self._best_acc, self._cp_name) self.save(self._cp_name) self._best_acc = -val_acc else: val_acc = eval_metrics['top1'] if self._reporter: self._reporter(epoch=epoch, acc_reward=val_acc) early_stopper.update(val_acc) if val_acc > self._best_acc: self._cp_name = os.path.join(self._logdir, _BEST_CHECKPOINT_FILE) self._logger.info( '[Epoch %d] Current best top-1: %f vs previous %f, saved to %s', self.epoch, val_acc, self._best_acc, self._cp_name) self.save(self._cp_name) self._best_acc = val_acc if self._lr_scheduler is not None: # step LR for next epoch self._lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) self._time_elapsed += time.time() - post_tic if 'accuracy' in train_metrics: return { 'train_acc': train_metrics['accuracy'], 'valid_acc': self._best_acc, 'time': self._time_elapsed, 'checkpoint': self._cp_name } # rmse else: if self._problem_type == REGRESSION: return { 'train_score': train_metrics['rmse'], 'valid_score': -self._best_acc, 'time': self._time_elapsed, 'checkpoint': self._cp_name } # mixup else: return { 'train_score': train_metrics['rmse'], 'valid_acc': self._best_acc, 'time': self._time_elapsed, 'checkpoint': self._cp_name }
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: _logger.warning( 'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.' ) args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: _logger.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: _logger.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) if args.local_rank == 0: _logger.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) use_amp = None if args.amp: # for backwards compat, `--amp` arg tries apex before native amp if has_apex: args.apex_amp = True elif has_native_amp: args.native_amp = True if args.apex_amp and has_apex: use_amp = 'apex' elif args.native_amp and has_native_amp: use_amp = 'native' elif args.apex_amp or args.native_amp: _logger.warning( "Neither APEX or native Torch AMP is available, using float32. " "Install NVIDA apex or upgrade to PyTorch 1.6") if args.num_gpu > 1: if use_amp == 'apex': _logger.warning( 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.' ) use_amp = None model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() assert not args.channels_last, "Channels last not supported with DP, use DDP." else: model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) optimizer = create_optimizer(args, model) amp_autocast = suppress # do nothing loss_scaler = None if use_amp == 'apex': model, optimizer = amp.initialize(model, optimizer, opt_level='O1') loss_scaler = ApexScaler() if args.local_rank == 0: _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': amp_autocast = torch.cuda.amp.autocast loss_scaler = NativeScaler() if args.local_rank == 0: _logger.info( 'Using native Torch AMP. Training in mixed precision.') else: if args.local_rank == 0: _logger.info('AMP not enabled. Training in float32.') # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( model, args.resume, optimizer=None if args.no_resume_opt else optimizer, loss_scaler=None if args.no_resume_opt else loss_scaler, log_info=args.local_rank == 0) model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume=args.resume) if args.distributed: if args.sync_bn: assert not args.split_bn try: if has_apex and use_amp != 'native': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' ) except Exception as e: _logger.error( 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1' ) if has_apex and use_amp != 'native': # Apex DDP preferred unless native amp is activated if args.local_rank == 0: _logger.info("Using NVIDIA APEX DistributedDataParallel.") model = ApexDDP(model, delay_allreduce=True) else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") model = NativeDDP(model, device_ids=[ args.local_rank ]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: _logger.info('Scheduled epochs: {}'.format(num_epochs)) train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): _logger.error( 'Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) collate_fn = None mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_args = dict(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) if args.prefetcher: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) else: mixup_fn = Mixup(**mixup_args) if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, no_aug=args.no_aug, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader) eval_dir = os.path.join(args.data, 'val') if not os.path.isdir(eval_dir): eval_dir = os.path.join(args.data, 'validation') if not os.path.isdir(eval_dir): _logger.error( 'Validation folder does not exist at: {}'.format(eval_dir)) exit(1) dataset_eval = Dataset(eval_dir) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() elif mixup_active: # smoothing is handled with mixup target transform train_loss_fn = SoftTargetCrossEntropy().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing).cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() eval_metric = args.eval_metric best_metric = None best_epoch = None if args.eval_checkpoint: # evaluate the model load_checkpoint(model, args.eval_checkpoint, args.model_ema) val_metrics = validate(model, loader_eval, validate_loss_fn, args) print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%") return saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: # train the model for epoch in range(start_epoch, num_epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( epoch, metric=save_metric) except KeyboardInterrupt: pass if best_metric is not None: _logger.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: logging.warning( 'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.' ) args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: logging.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logging.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) if args.local_rank == 0: logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) if args.num_gpu > 1: if args.amp: logging.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.' ) args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: model.cuda() optimizer = create_optimizer(args, model) use_amp = False if has_apex and args.amp: model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True if args.local_rank == 0: logging.info('NVIDIA APEX {}. AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) # optionally resume from a checkpoint resume_state = {} resume_epoch = None if args.resume: resume_state, resume_epoch = resume_checkpoint(model, args.resume) if resume_state and not args.no_resume_opt: if 'optimizer' in resume_state: if args.local_rank == 0: logging.info('Restoring Optimizer state from checkpoint') optimizer.load_state_dict(resume_state['optimizer']) if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if args.local_rank == 0: logging.info('Restoring NVIDIA AMP state from checkpoint') amp.load_state_dict(resume_state['amp']) del resume_state model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume=args.resume) if args.distributed: if args.sync_bn: assert not args.split_bn try: if has_apex: model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( model) if args.local_rank == 0: logging.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' ) except Exception as e: logging.error( 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1' ) if has_apex: model = DDP(model, delay_allreduce=True) else: if args.local_rank == 0: logging.info( "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP." ) model = DDP(model, device_ids=[args.local_rank ]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): logging.error( 'Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) collate_fn = None if args.prefetcher and args.mixup > 0: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, interpolation=args.train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, ) eval_dir = os.path.join(args.data, 'val') if not os.path.isdir(eval_dir): eval_dir = os.path.join(args.data, 'validation') if not os.path.isdir(eval_dir): logging.error( 'Validation folder does not exist at: {}'.format(eval_dir)) exit(1) dataset_eval = Dataset(eval_dir) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.mixup > 0.: # smoothing is handled with mixup label transform train_loss_fn = SoftTargetCrossEntropy().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = train_loss_fn eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in range(start_epoch, num_epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp, model_ema=model_ema) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: logging.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp) except KeyboardInterrupt: pass if best_metric is not None: logging.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
def main(config): dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") model = build_model(config) model.cuda() logger.info(str(model)) optimizer = build_optimizer(config, model) if config.AMP_OPT_LEVEL != "O0": model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"number of params: {n_parameters}") if hasattr(model_without_ddp, 'flops'): flops = model_without_ddp.flops() logger.info(f"number of GFLOPs: {flops / 1e9}") lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) if config.AUG.MIXUP > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif config.MODEL.LABEL_SMOOTHING > 0.: criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) else: criterion = torch.nn.CrossEntropyLoss() max_accuracy = 0.0 if config.TRAIN.AUTO_RESUME: resume_file = auto_resume_helper(config.OUTPUT) if resume_file: if config.MODEL.RESUME: logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") config.defrost() config.MODEL.RESUME = resume_file config.freeze() logger.info(f'auto resuming from {resume_file}') else: logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') if config.MODEL.RESUME: max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") if config.EVAL_MODE: return if config.THROUGHPUT_MODE: throughput(data_loader_val, model, logger) return logger.info("Start training") start_time = time.time() for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): data_loader_train.sampler.set_epoch(epoch) train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler) if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger) acc1, acc5, loss = validate(config, data_loader_val, model) logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") max_accuracy = max(max_accuracy, acc1) logger.info(f'Max accuracy: {max_accuracy:.2f}%') total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) logger.info('Training time {}'.format(total_time_str))
def main(args): utils.init_distributed_mode(args) print(args) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) # random.seed(seed) cudnn.benchmark = True dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) dataset_val, _ = build_dataset(is_train=False, args=args) if True: # args.distributed: num_tasks = utils.get_world_size() global_rank = utils.get_rank() if args.repeated_aug: sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) else: sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True) if args.dist_eval: if len(dataset_val) % num_tasks != 0: print( 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=int( 1.5 * args.batch_size), num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False) mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_fn = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) print(f"Creating model: {args.model}") model = create_model( args.model, pretrained=False, num_classes=args.nb_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, drop_block_rate=None, ) # TODO: finetuning model.to(device) model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume='') model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size( ) / 512.0 args.lr = linear_scaled_lr optimizer = create_optimizer(args, model) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) criterion = LabelSmoothingCrossEntropy() if args.mixup > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) if args.eval: test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) return print("Start training") start_time = time.time() max_accuracy = 0.0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn) lr_scheduler.step(epoch) if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'model_ema': get_state_dict(model_ema), 'args': args, }, checkpoint_path) test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
def main(): setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: logging.warning( 'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.' ) args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: logging.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logging.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) net_config = json.load(open(args.model_config)) model = NSGANetV2.build_from_config(net_config, drop_connect_rate=args.drop_path) init = torch.load(args.initial_checkpoint, map_location='cpu')['state_dict'] model.load_state_dict(init) if args.reset_classifier: NSGANetV2.reset_classifier(model, last_channel=model.classifier.in_features, n_classes=args.num_classes, dropout_rate=args.drop) # create a dummy model to get the model cfg just for running this timm training script dummy_model = create_model('efficientnet_b0') # add a teacher model if args.teacher: # using the supernet at full scale to supervise the training of the subnets # this is taken from https://github.com/mit-han-lab/once-for-all/blob/ # 4decacf9d85dbc948a902c6dc34ea032d711e0a9/train_ofa_net.py#L125 from evaluator import OFAEvaluator supernet = OFAEvaluator(n_classes=1000, model_path=args.teacher) teacher, _ = supernet.sample({ 'ks': [7] * 20, 'e': [6] * 20, 'd': [4] * 5 }) # as alternative, you could simply use a pretrained model from timm, e.g. "tf_efficientnet_b1", # "mobilenetv3_large_100", etc. See https://rwightman.github.io/pytorch-image-models/results/ for full list. # import timm # teacher = timm.create_model(args.teacher, pretrained=True) args.kd_ratio = 1.0 print("teacher model loaded") else: args.kd_ratio = 0.0 if args.local_rank == 0: logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=dummy_model, verbose=args.local_rank == 0) if args.img_size is not None: data_config['input_size'] = (3, args.img_size, args.img_size ) # override the input image resolution num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) if args.num_gpu > 1: if args.amp: logging.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.' ) args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() if args.teacher: teacher = nn.DataParallel(teacher, device_ids=list(range( args.num_gpu))).cuda() else: model.cuda() if args.teacher: teacher.cuda() optimizer = create_optimizer(args, model) use_amp = False if has_apex and args.amp: model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True if args.local_rank == 0: logging.info('NVIDIA APEX {}. AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) # optionally resume from a checkpoint resume_state = {} resume_epoch = None if args.resume: resume_state, resume_epoch = resume_checkpoint(model, args.resume) if resume_state and not args.no_resume_opt: if 'optimizer' in resume_state: if args.local_rank == 0: logging.info('Restoring Optimizer state from checkpoint') optimizer.load_state_dict(resume_state['optimizer']) if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if args.local_rank == 0: logging.info('Restoring NVIDIA AMP state from checkpoint') amp.load_state_dict(resume_state['amp']) del resume_state model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume=args.resume) if args.distributed: if args.sync_bn: assert not args.split_bn try: if has_apex: model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( model) if args.local_rank == 0: logging.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' ) except Exception as e: logging.error( 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1' ) if has_apex: model = DDP(model, delay_allreduce=True) if args.teacher: teacher = DDP(teacher, delay_allreduce=True) else: if args.local_rank == 0: logging.info( "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP." ) model = DDP(model, device_ids=[args.local_rank ]) # can use device str in Torch >= 1.1 if teacher: teacher = DDP(teacher, device_ids=[args.local_rank]) # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): logging.error( 'Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) collate_fn = None if args.prefetcher and args.mixup > 0: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, interpolation=args.train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, ) eval_dir = os.path.join(args.data, 'val') if not os.path.isdir(eval_dir): eval_dir = os.path.join(args.data, 'validation') if not os.path.isdir(eval_dir): logging.error( 'Validation folder does not exist at: {}'.format(eval_dir)) exit(1) dataset_eval = Dataset(eval_dir) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.mixup > 0.: # smoothing is handled with mixup label transform train_loss_fn = SoftTargetCrossEntropy().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = train_loss_fn eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) try: for epoch in range(start_epoch, num_epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp, model_ema=model_ema, teacher=teacher) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: logging.info( "Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate(model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp) except KeyboardInterrupt: pass if best_metric is not None: logging.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
def main(): setup_default_logging() args = parser.parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: logging.warning( 'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.' ) args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: logging.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logging.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) model = create_model(args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) if args.local_rank == 0: logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # optionally resume from a checkpoint optimizer_state = None resume_epoch = None if args.resume: optimizer_state, resume_epoch = resume_checkpoint(model, args.resume) if args.num_gpu > 1: if args.amp: logging.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.' ) args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: model.cuda() optimizer = create_optimizer(args, model) if optimizer_state is not None: optimizer.load_state_dict(optimizer_state) use_amp = False if has_apex and args.amp: model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True if args.local_rank == 0: logging.info('NVIDIA APEX {}. AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper model_ema = ModelEma(model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else '', resume=args.resume) if args.distributed: if args.sync_bn: try: if has_apex: model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( model) if args.local_rank == 0: logging.info( 'Converted model to use Synchronized BatchNorm.') except Exception as e: logging.error( 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1' ) if has_apex: model = DDP(model, delay_allreduce=True) else: if args.local_rank == 0: logging.info( "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP." ) model = DDP(model, device_ids=[args.local_rank ]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) train_dir = os.path.join(args.data, 'train') if not os.path.exists(train_dir): logging.error( 'Training folder does not exist at: {}'.format(train_dir)) exit(1) dataset_train = Dataset(train_dir) collate_fn = None if args.prefetcher and args.mixup > 0: collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, rand_erase_prob=args.reprob, rand_erase_mode=args.remode, color_jitter=args.color_jitter, interpolation= 'random', # FIXME cleanly resolve this? data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, ) eval_dir = os.path.join(args.data, 'validation') if not os.path.isdir(eval_dir): logging.error( 'Validation folder does not exist at: {}'.format(eval_dir)) exit(1) dataset_eval = Dataset(eval_dir) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], batch_size=4 * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, ) if args.mixup > 0.: # smoothing is handled with mixup label transform train_loss_fn = SoftTargetCrossEntropy().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = train_loss_fn eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) try: for epoch in range(start_epoch, num_epochs): if args.distributed: loader_train.sampler.set_epoch(epoch) train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp, model_ema=model_ema) eval_metrics = validate(model, loader_eval, validate_loss_fn, args) if model_ema is not None and not args.model_ema_force_cpu: ema_eval_metrics = validate(model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric) except KeyboardInterrupt: pass if best_metric is not None: logging.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
def main(args): utils.init_distributed_mode(args) update_config_from_file(args.cfg) print(args) args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) dataset_val, _ = build_dataset(is_train=False, args=args) if args.distributed: num_tasks = utils.get_world_size() global_rank = utils.get_rank() if args.repeated_aug: sampler_train = RASampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True ) else: sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True ) if args.dist_eval: if len(dataset_val) % num_tasks != 0: print( 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'equal num of samples per-process.') sampler_val = torch.utils.data.DistributedSampler( dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) else: sampler_val = torch.utils.data.SequentialSampler(dataset_val) sampler_train = torch.utils.data.RandomSampler(dataset_train) data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader( dataset_val, batch_size=args.batch_size // 2, sampler=sampler_val, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False ) mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_fn = Mixup( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) print("Creating SuperVisionTransformer") print(cfg) model = Vision_TransformerSuper(img_size=args.input_size, patch_size=args.patch_size, embed_dim=cfg.SUPERNET.EMBED_DIM, depth=cfg.SUPERNET.DEPTH, num_heads=cfg.SUPERNET.NUM_HEADS,mlp_ratio=cfg.SUPERNET.MLP_RATIO, qkv_bias=True, drop_rate=args.drop, drop_path_rate=args.drop_path, gp=args.gp, num_classes=args.nb_classes, max_relative_position=args.max_relative_position, relative_position=args.relative_position, change_qkv=args.change_qkv, abs_pos=not args.no_abs_pos) choices = {'num_heads': cfg.SEARCH_SPACE.NUM_HEADS, 'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO, 'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM , 'depth': cfg.SEARCH_SPACE.DEPTH} model.to(device) if args.teacher_model: teacher_model = create_model( args.teacher_model, pretrained=True, num_classes=args.nb_classes, ) teacher_model.to(device) teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: teacher_model = None teacher_loss = None model_ema = None model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 args.lr = linear_scaled_lr optimizer = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) if args.mixup > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() output_dir = Path(args.output_dir) if not output_dir.exists(): output_dir.mkdir(parents=True) # save config for later experiments with open(file=output_dir / "config.yaml", mode='w') as f: f.write(args_text) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) retrain_config = None if args.mode == 'retrain' and "RETRAIN" in cfg: retrain_config = {'layer_num': cfg.RETRAIN.DEPTH, 'embed_dim': [cfg.RETRAIN.EMBED_DIM]*cfg.RETRAIN.DEPTH, 'num_heads': cfg.RETRAIN.NUM_HEADS,'mlp_ratio': cfg.RETRAIN.MLP_RATIO} trainer = AFSupernetTrainer( model, criterion, data_loader_train, data_loader_val, optimizer, device, args.epochs, loss_scaler, args.clip_grad, model_ema, mixup_fn, args.amp, teacher_model, teacher_loss,choices, args.mode, retrain_config, 0., output_dir, lr_scheduler, ) if args.eval: trainer._validate_one_epoch(-1) return trainer.fit()
def main(): global args setup_default_logging() args, args_text = _parse_args() args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.distributed and args.num_gpu > 1: logging.warning( 'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.' ) args.num_gpu = 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() assert args.rank >= 0 if args.distributed: logging.info( 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' % (args.rank, args.world_size)) else: logging.info('Training with a single process on %d GPUs.' % args.num_gpu) torch.manual_seed(args.seed + args.rank) np.random.seed(args.seed + args.rank) model = create_model(args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, checkpoint_path=args.initial_checkpoint) if args.binarizable: Model_binary_patch(model) if args.local_rank == 0: logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) if args.num_gpu > 1: if args.amp: logging.warning( 'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.' ) args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: model.cuda() optimizer = create_optimizer(args, model) use_amp = False if has_apex and args.amp: print('Using amp with --opt-level {}.'.format(args.opt_level)) model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level) use_amp = True else: print('Do NOT use amp.') if args.local_rank == 0: logging.info('NVIDIA APEX {}. AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) # optionally resume from a checkpoint resume_state = {} resume_epoch = None if args.resume: resume_state, resume_epoch = resume_checkpoint(model, args.resume) if resume_state and not args.no_resume_opt: if 'optimizer' in resume_state: if args.local_rank == 0: logging.info('Restoring Optimizer state from checkpoint') optimizer.load_state_dict(resume_state['optimizer']) if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: if args.local_rank == 0: logging.info('Restoring NVIDIA AMP state from checkpoint') amp.load_state_dict(resume_state['amp']) resume_state = None if args.freeze_binary: Model_freeze_binary(model) if args.distributed: if args.sync_bn: try: if has_apex: model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( model) if args.local_rank == 0: logging.info( 'Converted model to use Synchronized BatchNorm.') except Exception as e: logging.error( 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1' ) if has_apex: model = DDP(model, delay_allreduce=True) else: if args.local_rank == 0: logging.info( "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP." ) model = DDP(model, device_ids=[args.local_rank ]) # can use device str in Torch >= 1.1 # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) print(num_epochs) # start_epoch = 0 # if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch if args.reset_lr_scheduler is not None: lr_scheduler.base_values = len( lr_scheduler.base_values) * [args.reset_lr_scheduler] lr_scheduler.step(start_epoch) if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) # Using pruner to get sparse weights if args.prune: if args.pruner == 'V1': pruner = Pruner(model, 0, 100) elif args.pruner == 'V2': pruner = Pruner_mixed(model, 0, 100) else: pruner = None dataset_train = torchvision.datasets.CIFAR100(root='~/Downloads/CIFAR100', train=True, download=True) collate_fn = None if args.prefetcher and args.mixup > 0: collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) loader_train = create_loader_CIFAR100( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, rand_erase_prob=args.reprob, rand_erase_mode=args.remode, rand_erase_count=args.recount, color_jitter=args.color_jitter, auto_augment=args.aa, interpolation='random', mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, is_clean_data=args.clean_train, ) dataset_eval = torchvision.datasets.CIFAR100(root='~/Downloads/CIFAR100', train=False, download=True) loader_eval = create_loader_CIFAR100( dataset_eval, input_size=data_config['input_size'], batch_size=4 * args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, ) if args.mixup > 0.: # smoothing is handled with mixup label transform train_loss_fn = SoftTargetCrossEntropy( multiplier=args.softmax_multiplier).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy( smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = train_loss_fn eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None saver_last_10_epochs = None output_dir = '' if args.local_rank == 0: output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False os.makedirs(output_dir + '/Top') os.makedirs(output_dir + '/Last') saver = CheckpointSaver( checkpoint_dir=output_dir + '/Top', decreasing=decreasing, max_history=10) # Save the results of the top 10 epochs saver_last_10_epochs = CheckpointSaver( checkpoint_dir=output_dir + '/Last', decreasing=decreasing, max_history=10) # Save the results of the last 10 epochs with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) f.write('==============================\n') f.write(model.__str__()) if pruner: f.write('\n Sparsity \n') #f.write(pruner.threshold_dict.__str__()) f.write('\n pruner.start_epoch={}, pruner.end_epoch={}'.format( pruner.start_epoch, pruner.end_epoch)) tensorboard_writer = SummaryWriter(output_dir) try: for epoch in range(start_epoch, num_epochs): global alpha alpha = get_alpha(epoch, args) if args.distributed: loader_train.sampler.set_epoch(epoch) if pruner: pruner.on_epoch_begin(epoch) # pruning train_metrics = train_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp, tensorboard_writer=tensorboard_writer, pruner=pruner) if pruner: pruner.print_statistics() eval_metrics = validate(model, loader_eval, validate_loss_fn, args, tensorboard_writer=tensorboard_writer, epoch=epoch) if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, epoch=epoch, metric=save_metric, use_amp=use_amp) if saver_last_10_epochs is not None: # save the checkpoint in the last 5 epochs _, _ = saver_last_10_epochs.save_checkpoint(model, optimizer, args, epoch=epoch, metric=epoch, use_amp=use_amp) except KeyboardInterrupt: pass if best_metric is not None: logging.info('*** Best metric: {0} (epoch {1})'.format( best_metric, best_epoch))
def main(args): utils.my_init_distributed_mode(args) print(args) # if args.distillation_type != 'none' and args.finetune and not args.eval: # raise NotImplementedError("Finetuning with distillation not yet supported") device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) # random.seed(seed) cudnn.benchmark = True dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) dataset_val, _ = build_dataset(is_train=False, args=args) # if True: # args.distributed: # num_tasks = utils.get_world_size() # global_rank = utils.get_rank() # if args.repeated_aug: # sampler_train = RASampler( # dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True # ) # else: # sampler_train = torch.utils.data.DistributedSampler( # dataset_train, # # num_replicas=num_tasks, # num_replicas=0, # rank=global_rank, shuffle=True # ) # if args.dist_eval: # if len(dataset_val) % num_tasks != 0: # print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' # 'This will slightly alter validation results as extra duplicate entries are added to achieve ' # 'equal num of samples per-process.') # sampler_val = torch.utils.data.DistributedSampler( # dataset_val, # # num_replicas=num_tasks, # num_replicas=0, # rank=global_rank, shuffle=False) # else: # sampler_val = torch.utils.data.SequentialSampler(dataset_val) # else: # sampler_train = torch.utils.data.RandomSampler(dataset_train) # sampler_val = torch.utils.data.SequentialSampler(dataset_val) # # data_loader_train = torch.utils.data.DataLoader( # dataset_train, sampler=sampler_train, # batch_size=args.batch_size, # num_workers=args.num_workers, # pin_memory=args.pin_mem, # drop_last=True, # ) # # data_loader_val = torch.utils.data.DataLoader( # dataset_val, sampler=sampler_val, # batch_size=int(1.5 * args.batch_size), # num_workers=args.num_workers, # pin_memory=args.pin_mem, # drop_last=False # ) # if args.distributed: if args.cache_mode: sampler_train = samplers.NodeDistributedSampler(dataset_train) sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False) else: sampler_train = samplers.DistributedSampler(dataset_train) sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, num_workers=args.num_workers, pin_memory=True) data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, drop_last=False, num_workers=args.num_workers, pin_memory=True) mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_fn = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.nb_classes) print(f"Creating model: {args.model}") model = create_model( args.model, pretrained=False, num_classes=args.nb_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, drop_block_rate=None, ) model_without_ddp = model # # there are bugs # if args.finetune: # if args.finetune.startswith('https'): # checkpoint = torch.hub.load_state_dict_from_url( # args.finetune, map_location='cpu', check_hash=True) # else: # checkpoint = torch.load(args.finetune, map_location='cpu') # # checkpoint_model = checkpoint['model'] # # state_dict = model.state_dict() # for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: # if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: # print(f"Removing key {k} from pretrained checkpoint") # del checkpoint_model[k] # # _ = model.load_state_dict(checkpoint_model, strict=False) model.to(device) model_ema = None # if args.model_ema: # # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper # model_ema = ModelEma( # model, # decay=args.model_ema_decay, # device='cpu' if args.model_ema_force_cpu else '', # resume='') model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print('number of params:', n_parameters) linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size( ) / 512.0 args.lr = linear_scaled_lr optimizer = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) criterion = LabelSmoothingCrossEntropy() if args.mixup > 0.: # smoothing is handled with mixup label transform criterion = SoftTargetCrossEntropy() elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: criterion = torch.nn.CrossEntropyLoss() criterion = DistillationLoss(criterion, None, 'none', 0, 0) output_dir = Path(args.output_dir) # for finetune if args.finetune: if args.finetune.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.finetune, map_location='cpu', check_hash=True) else: if not os.path.exists(args.finetune): checkpoint = None print('NOTICE:' + args.finetune + ' does not exist!') else: checkpoint = torch.load(args.finetune, map_location='cpu') if checkpoint is not None: if 'model' in checkpoint: check_model = checkpoint['model'] else: check_model = checkpoint missing_keys = model_without_ddp.load_state_dict( check_model, strict=False).missing_keys skip_keys = model_without_ddp.no_weight_decay() # create optimizer manually param_dicts = [ { "params": [ p for n, p in model_without_ddp.named_parameters() if n in missing_keys and n not in skip_keys ], "lr": args.lr, 'weight_decay': args.weight_decay, }, { "params": [ p for n, p in model_without_ddp.named_parameters() if n in missing_keys and n in skip_keys ], "lr": args.lr, 'weight_decay': 0, }, { "params": [ p for n, p in model_without_ddp.named_parameters() if n not in missing_keys and n not in skip_keys ], "lr": args.lr * args.fine_factor, 'weight_decay': args.weight_decay, }, { "params": [ p for n, p in model_without_ddp.named_parameters() if n not in missing_keys and n in skip_keys ], "lr": args.lr * args.fine_factor, 'weight_decay': 0, }, ] optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer) # if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: # optimizer.load_state_dict(checkpoint['optimizer']) # lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) # args.start_epoch = checkpoint['epoch'] + 1 # # if args.model_ema: # # utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) # if 'scaler' in checkpoint: # loss_scaler.load_state_dict(checkpoint['scaler']) print('finetune from' + args.finetune) # for debug # lr_scheduler.step(10) # lr_scheduler.step(100) # lr_scheduler.step(200) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) else: if not os.path.exists(args.resume): checkpoint = None print('NOTICE:' + args.resume + ' does not exist!') else: checkpoint = torch.load(args.resume, map_location='cpu') if checkpoint is not None: if 'model' in checkpoint: model_without_ddp.load_state_dict(checkpoint['model']) else: model_without_ddp.load_state_dict(checkpoint) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 # if args.model_ema: # utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) print('resume from' + args.resume) if args.eval: test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) return print(f"Start training for {args.epochs} epochs") start_time = time.time() max_accuracy = 0.0 max_epoch_dp_warm_up = 100 if 'pvt_tiny' in args.model or 'pvt_small' in args.model: max_epoch_dp_warm_up = 0 if args.start_epoch < max_epoch_dp_warm_up: model_without_ddp.reset_drop_path(0.0) for epoch in range(args.start_epoch, args.epochs): if args.fp32_resume and epoch > args.start_epoch + 1: args.fp32_resume = False loss_scaler._scaler = torch.cuda.amp.GradScaler( enabled=not args.fp32_resume) if epoch == max_epoch_dp_warm_up: model_without_ddp.reset_drop_path(args.drop_path) if epoch < args.warmup_epochs: optimizer.param_groups[2]['lr'] = 0 optimizer.param_groups[3]['lr'] = 0 if args.distributed: # data_loader_train.sampler.set_epoch(epoch) sampler_train.set_epoch(epoch) train_stats = my_train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, # set_training_mode=args.finetune == '', # keep in eval mode during finetuning fp32=args.fp32_resume) lr_scheduler.step(epoch) if args.output_dir: checkpoint_paths = [output_dir / 'checkpoint.pth'] for checkpoint_path in checkpoint_paths: utils.save_on_master( { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, # 'model_ema': get_state_dict(model_ema), 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) test_stats = evaluate(data_loader_val, model, device) print( f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" ) max_accuracy = max(max_accuracy, test_stats["acc1"]) print(f'Max accuracy: {max_accuracy:.2f}%') log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))