def test_rectified(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3), optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer, lr=1e-3))
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3))
    _test_model(optimizer, dict(lr=1e-3))
def test_lookahead_radam(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3), optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer))
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-4))
def _init_trainer(self):
    if self._optimizer is None:
        if self._img_cls_cfg.pretrained and not self._custom_net \
                and (self._train_cfg.transfer_lr_mult != 1 or self._train_cfg.output_lr_mult != 1):
            # adjust feature/last_fc learning rate multiplier in optimizer
            self._logger.debug(
                f'Reduce network lr multiplier to {self._train_cfg.transfer_lr_mult}, while keep '
                + f'last FC layer lr_mult to {self._train_cfg.output_lr_mult}')
            optim_kwargs = optimizer_kwargs(cfg=self._cfg)
            optim_kwargs['feature_lr_mult'] = self._cfg.train.transfer_lr_mult
            optim_kwargs['fc_lr_mult'] = self._cfg.train.output_lr_mult
            # pass the adjusted kwargs (the original code rebuilt optimizer_kwargs here,
            # silently dropping the lr multipliers set above)
            self._optimizer = create_optimizer_v2(self.net, **optim_kwargs)
        else:
            self._optimizer = create_optimizer_v2(self.net, **optimizer_kwargs(cfg=self._cfg))
    self._init_loss_scaler()
    self._lr_scheduler, self.epochs = create_scheduler(self._cfg, self._optimizer)
    self._lr_scheduler.step(self.start_epoch, self.epoch)
def _init_trainer(self):
    if self._optimizer is None:
        self._optimizer = create_optimizer_v2(
            self.net, **optimizer_kwargs(cfg=self._cfg))
    self._init_loss_scaler()
    self._lr_scheduler, self.epochs = create_scheduler(
        self._cfg, self._optimizer)
    self._lr_scheduler.step(self.start_epoch, self.epoch)
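# Hedged sketch (not the library's actual implementation): optimizer_kwargs(cfg=...)
# used throughout these snippets is assumed to flatten a config / argparse namespace
# into the keyword arguments accepted by create_optimizer_v2, roughly along these lines.
def _optimizer_kwargs_sketch(cfg):
    kwargs = dict(opt=cfg.opt, lr=cfg.lr, weight_decay=cfg.weight_decay, momentum=cfg.momentum)
    if getattr(cfg, 'opt_eps', None) is not None:
        kwargs['eps'] = cfg.opt_eps
    if getattr(cfg, 'opt_betas', None) is not None:
        kwargs['betas'] = cfg.opt_betas
    return kwargs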
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
    super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
    self.model.train()
    self.loss = nn.CrossEntropyLoss().to(self.device)
    self.target_shape = tuple()
    self.optimizer = create_optimizer_v2(
        self.model,
        opt=kwargs.pop('opt', 'sgd'),
        lr=kwargs.pop('lr', 1e-4))
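# Hedged, self-contained illustration (not part of the original runner): a single
# training step using an optimizer built the same way as above. The toy model,
# batch, and synthetic target are assumptions for illustration only; the import
# path below is the usual timm location for create_optimizer_v2.
import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2


def _example_train_step():
    model = nn.Linear(8, 2)
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = create_optimizer_v2(model, opt='sgd', lr=1e-4)  # mirrors the defaults above
    inputs = torch.randn(4, 8)
    target = torch.zeros(4, dtype=torch.long)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), target)
    loss.backward()
    optimizer.step()
    return loss.item()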
def test_adafactor(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=3e-3), optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=3e-3), optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias), optimizer))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            [weight, bias], optimizer, lr=1e-3, weight_decay=1))
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=5e-2))
    _test_model(optimizer, dict(lr=5e-2))
def test_sgd(optimizer):
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict(weight, bias, lr=1e-2), optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-2), optimizer, lr=1e-3))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-2), optimizer))
    # _test_basic_cases(
    #     lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
    #     [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]
    # )
    # _test_basic_cases(
    #     lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
    #     [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")]
    # )
    # _test_basic_cases(
    #     lambda weight, bias: optimizer([weight, bias], lr=1e-3),
    #     [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")]
    # )
    # _test_basic_cases(
    #     lambda weight, bias: optimizer([weight, bias], lr=1e-3),
    #     [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
    #      lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4)]
    # )
    # _test_basic_cases(
    #     lambda weight, bias: optimizer([weight, bias], lr=1e-3),
    #     [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
    #      lambda opt: ReduceLROnPlateau(opt)]
    # )
    # _test_basic_cases(
    #     lambda weight, bias: optimizer([weight, bias], lr=1e-3),
    #     [lambda opt: StepLR(opt, gamma=0.99, step_size=10),
    #      lambda opt: ExponentialLR(opt, gamma=0.99),
    #      lambda opt: ReduceLROnPlateau(opt)]
    # )
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            [weight, bias], optimizer, lr=3e-3, momentum=1))
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2(
            [weight, bias], optimizer, lr=3e-3, momentum=1, weight_decay=.1))
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3))
    _test_model(optimizer, dict(lr=1e-3))
def _test_model(optimizer, params, device=torch.device('cpu')):
    weight = torch.tensor(
        [[-0.2109, -0.4976], [-0.1413, -0.3420], [-0.2524, 0.6976]],
        device=device, requires_grad=True)
    bias = torch.tensor([-0.1085, -0.2979, 0.6892], device=device, requires_grad=True)
    weight2 = torch.tensor([[-0.0508, -0.3941, -0.2843]], device=device, requires_grad=True)
    bias2 = torch.tensor([-0.0711], device=device, requires_grad=True)
    input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], device=device).reshape(3, 2)

    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Sigmoid(),
        torch.nn.Linear(3, 1),
        torch.nn.Sigmoid())
    model.to(device)

    pretrained_dict = model.state_dict()
    pretrained_dict['0.weight'] = weight
    pretrained_dict['0.bias'] = bias
    pretrained_dict['2.weight'] = weight2
    pretrained_dict['2.bias'] = bias2
    model.load_state_dict(pretrained_dict)

    optimizer = create_optimizer_v2(model, opt=optimizer, **params)

    prev_loss = float('inf')
    for i in range(20):
        optimizer.zero_grad()
        output = model(input)
        loss = output.sum()
        loss.backward()
        loss = loss.item()
        assert loss < prev_loss
        prev_loss = loss
        optimizer.step()
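# NOTE: illustrative sketch, not the repo's actual helpers. The tests above call
# _build_params_dict / _build_params_dict_single, whose definitions are not shown in
# this excerpt; a plausible implementation builds torch-style parameter groups so that
# per-group kwargs (e.g. lr=3e-3) override the optimizer-level defaults.
def _build_params_dict(weight, bias, **kwargs):
    # one default group plus one group carrying the extra kwargs
    return [{'params': [weight]}, dict(params=[bias], **kwargs)]


def _build_params_dict_single(weight, bias, **kwargs):
    # a single parameter group with the extra kwargs applied to everything
    return [dict(params=[weight, bias], **kwargs)]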
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX nor native Torch AMP is available, using float32. "
            "Install NVIDIA apex or upgrade to PyTorch 1.6")

    random_seed(args.seed, args.rank)

    model_KD = None
    if args.kd_model_path is not None:
        model_KD = build_kd_model(args)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), \
            'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
        class_map=args.class_map, download=args.dataset_download,
        batch_size=args.batch_size, repeats=args.epoch_repeats)
    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
        class_map=args.class_map, download=args.dataset_download,
        batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if args.rank == 0:
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing,
            max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema,
                mixup_fn=mixup_fn, model_KD=model_KD)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args,
                    amp_autocast=amp_autocast, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir is not None:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def main():
    # wandb.init(project=args.experiment, config=args)
    setup_default_logging()
    args, args_text = _parse_args()

    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if torch.cuda.is_available():
        device = torch.device("cuda")
        _logger.info(f'GPU: {torch.cuda.get_device_name(0)}')
    else:
        device = torch.device("cpu")
    args.world_size = 1
    args.rank = 0  # global rank
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None

    random_seed(args.seed, args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), \
            'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    model = determine_layer(model, args.finetune)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0

    model.to(device)

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    # amp_autocast = suppress  # do nothing
    loss_scaler = None

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    # dataset_train = create_dataset(
    #     args.dataset,
    #     root=args.data_dir, split=args.train_split, is_training=True,
    #     batch_size=args.batch_size, repeats=args.epoch_repeats)
    # dataset_eval = create_dataset(
    #     args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

    # # setup mixup / cutmix
    # collate_fn = None
    # mixup_fn = None

    # create data loaders w/ augmentation pipeline
    if 'skin' in args.experiment:
        _logger.info('Loading Dataset')
        img_df = pd.read_csv(args.csv_path)  # csv directory
        img_names, labels = list(img_df['image_name']), list(img_df['diagnosis'])
        img_index = list(range(len(img_names)))
        train_valid_index, _, train_valid_labels, _ = train_test_split(
            img_index, labels, test_size=0.2, shuffle=True, stratify=labels, random_state=args.seed)
        kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=args.seed)
        for fold_index, (train_index, valid_index) in enumerate(
                kf.split(train_valid_index, train_valid_labels)):
            train_index = train_index
            valid_index = valid_index
            break
        _logger.info(f'augmentation : {args.augment}')
        train_df = img_df[img_df.index.isin(train_index)].reset_index(drop=True)
        train_dataset = SkinDataset(
            data_dir=args.data_dir, df=train_df,
            transform=get_skin_transforms(augment=args.augment, args=args))  # file directory
        valid_df = img_df[img_df.index.isin(valid_index)].reset_index(drop=True)
        valid_dataset = SkinDataset(
            data_dir=args.data_dir, df=valid_df,
            transform=get_skin_transforms(augment='none', args=args))  # file directory
        _logger.info('Load Sampler & Loader')
        print(len(train_index), len(valid_index))
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=False)
        valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=False)
    elif 'lung' in args.experiment:
        _logger.info('Loading Dataset')
        img_df = pd.read_csv(args.csv_path)  # csv directory
        img_names, labels = list(img_df['image_link']), list(img_df['label'])
        img_index = list(range(len(img_names)))
        _logger.info(f'augmentation : {args.augment}')
        train_df = img_df[img_df['tvt'] == 'train'].reset_index(drop=True)
        train_dataset = LungDataset(
            df=train_df, transform=get_lung_transforms(augment=args.augment, args=args))  # file directory
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size)
        valid_df = img_df[img_df['tvt'] == 'valid'].reset_index(drop=True)
        valid_dataset = LungDataset(
            df=valid_df, transform=get_lung_transforms(augment='none', args=args))  # file directory
        valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size)
        _logger.info('Load Sampler & Loader')
        # log dataset sizes (the original passed two ints to _logger.info, which is not valid % formatting)
        _logger.info(f'{len(train_dataset)} {len(valid_dataset)}')
        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
        valid_loader = DataLoader(valid_dataset, shuffle=True, batch_size=args.batch_size)

    # setup loss function
    if args.loss == 'focal':
        train_loss_fn = FocalLoss().to(device)
    else:
        if args.smoothing:
            _logger.info('default loss is LabelSmoothing')
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).to(device)
        else:
            train_loss_fn = nn.CrossEntropyLoss().to(device)
    validate_loss_fn = nn.CrossEntropyLoss().to(device)

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    early_stopping = EarlyStopping(patience=args.early_patience, delta=args.early_value, verbose=True)
    if args.rank == 0:
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            safe_model_name(args.model),
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing,
            max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            train_metrics = train_one_epoch(
                epoch, model, train_loader, optimizer, train_loss_fn, device, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir)

            eval_metrics = validate(model, valid_loader, validate_loss_fn, device, args)

            if lr_scheduler is not None:
                # step LR for next epoch
                if args.sched == 'step':
                    lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
                else:
                    lr_scheduler.step(epoch)

            early_stopping(eval_metric, eval_metrics[eval_metric], model)
            if early_stopping.early_stop:
                _logger.info('Early Stop')
                break

            if output_dir is not None:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX nor native Torch AMP is available, using float32. "
            "Install NVIDIA apex or upgrade to PyTorch 1.6")

    random_seed(args.seed, args.rank)

    if args.fuser:
        set_jit_fuser(args.fuser)

    data_splits = get_data_splits_by_name(
        dataset_name=args.dataset_name,
        data_root=args.data_dir,
        batch_size=args.batch_size,
    )
    loader_train, loader_eval = data_splits['train'], data_splits['test']

    model_wrapper_fn = MODEL_WRAPPER_REGISTRY.get(
        model_name=args.model.lower(), dataset_name=args.pretraining_original_dataset)
    model = model_wrapper_fn(
        pretrained=args.pretrained, progress=True,
        num_classes=len(loader_train.dataset.classes))

    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if args.rank == 0:
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing,
            max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args,
                    amp_autocast=amp_autocast, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir is not None:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))