def __init__(self, project, model, early_stop: bool = True, runner_name=None, train_test_val_indices=None):
    self.task_flow = project.get_full_flow()
    self.default_criterion = self.task_flow.get_loss()
    self.default_callbacks = self.default_criterion.catalyst_callbacks()
    self.default_optimizer = partial(optim.AdamW, lr=1e-4)
    self.default_scheduler = ReduceLROnPlateau
    self.project_dir: Path = project.project_dir
    self.project_dir.mkdir(exist_ok=True)
    runner_name = f'{self.task_flow.get_name()}_{time()}' if runner_name is None else runner_name
    self.default_logdir = f'./logdir_{runner_name}'
    if early_stop:
        self.default_callbacks.append(EarlyStoppingCallback(patience=5))
    if train_test_val_indices is None:
        (self.project_dir / self.default_logdir).mkdir(exist_ok=True)
        train_test_val_indices = project_split(project.df, self.project_dir / self.default_logdir)
    else:
        save_split(self.project_dir / self.default_logdir, train_test_val_indices)
    self.train_test_val_indices = train_test_val_indices
    self.tensor_loggers = project.converters.tensorboard_converters
    converters_file = self.project_dir / self.default_logdir / 'converters.pkl'
    if converters_file.exists():
        project.converters.load_state_dict(torch.load(converters_file))
    else:
        torch.save(project.converters.state_dict(), converters_file)
    super().__init__(model=model)
def __init__(self, project, model, early_stop: bool = True, balance_dataparallel_memory: bool = False,
             runner_name=None, train_test_val_indices=None):
    self.task_flow = project.get_full_flow()
    self.default_criterion = self.task_flow.get_loss()
    self.balance_dataparallel_memory = balance_dataparallel_memory
    self.default_callbacks = []
    if self.balance_dataparallel_memory:
        self.default_callbacks.append(ReplaceGatherCallback(self.task_flow))
    self.default_callbacks.extend(self.default_criterion.catalyst_callbacks())
    self.default_optimizer = partial(optim.AdamW, lr=1e-4)
    self.default_scheduler = ReduceLROnPlateau
    self.project_dir: Path = project.project_dir
    self.project_dir.mkdir(exist_ok=True)
    runner_name = f'{self.task_flow.get_name()}_{time()}' if runner_name is None else runner_name
    self.default_logdir = f'./logdir_{runner_name}'
    if early_stop:
        self.default_callbacks.append(EarlyStoppingCallback(patience=5))
    (self.project_dir / self.default_logdir).mkdir(exist_ok=True)
    if train_test_val_indices is None:
        train_test_val_indices = project_split(project.df, self.project_dir / self.default_logdir)
    else:
        save_split(self.project_dir / self.default_logdir, train_test_val_indices)
    self.train_test_val_indices = train_test_val_indices
    self.tensor_loggers = project.converters.tensorboard_converters
    super().__init__(model=model)
def torch_train(self, loaders, model, optimizer, loss_func, scheduler, config):
    self.config = config
    self.model = model
    self.optimizer = optimizer
    self.loss_func = loss_func
    self.scheduler = scheduler
    self.loader_key = list(loaders)[0]
    self.metric_key = 'loss'
    self.import_from_config()
    if 'cuda' in str(self.device):
        self.optimizer_to(optimizer, self.device)
    # Check whether the logdir already exists; delete it if it does.
    self.check_logdir()
    if self.loader_key != 'train':
        warnings.warn(
            "WARNING: loader to be used for early-stop callback is '%s'. "
            "You can define it manually in /lib/estimator/pytorch_estimator.torch_train"
            % (self.loader_key))
    model = self.model
    torch.cuda.empty_cache()
    if self.ddp:
        self.engine = None
    else:
        self.engine = DeviceEngine(self.device)
    self.print_info()
    self.runner.train(
        model=model,
        criterion=self.loss_func,
        optimizer=self.optimizer,
        scheduler=self.scheduler,
        loaders=loaders,
        logdir=self.config.logdir,
        num_epochs=self.config.n_epochs,
        callbacks=[
            EarlyStoppingCallback(patience=self.config.patience,
                                  min_delta=self.config.min_delta,
                                  loader_key=self.loader_key,
                                  metric_key=self.metric_key,
                                  minimize=True),
            SchedulerCallback(loader_key=self.loader_key,
                              metric_key=self.metric_key),
            SkipCheckpointCallback(logdir=self.config.logdir),
        ],
        verbose=False,
        check=False,
        engine=self.engine,
        ddp=self.ddp,
    )
    self.config.parameters['model - device'] = str(self.runner.device)
    self.model_metrics['final epoch'] = self.runner.stage_epoch_step
    for key, value in self.runner.epoch_metrics.items():
        self.model_metrics[key] = value
    with open('model_details.txt', 'w') as file:
        file.write('%s\n\n%s\n\n%s' % (str(self.runner.model),
                                       str(self.runner.optimizer),
                                       str(self.runner.scheduler)))
    return model
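# A hypothetical, minimal config object for torch_train above, shown only to
# illustrate the attributes the method reads (logdir, n_epochs, patience,
# min_delta, and a parameters dict it writes to). The real project config is
# built elsewhere and carries more fields; field names and defaults here are
# assumptions, not the project's actual config schema.
from dataclasses import dataclass, field


@dataclass
class TrainConfigSketch:
    logdir: str = "./logdir"        # passed to runner.train and SkipCheckpointCallback
    n_epochs: int = 50              # number of training epochs
    patience: int = 10              # early-stopping patience
    min_delta: float = 1e-4         # minimal metric improvement to reset patience
    parameters: dict = field(default_factory=dict)  # populated with run metadata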
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')
    parser.add_argument('--swa', action='store_true')
    parser.add_argument('--show', action='store_true')
    parser.add_argument('--use-idrid', action='store_true')
    parser.add_argument('--use-messidor', action='store_true')
    parser.add_argument('--use-aptos2015', action='store_true')
    parser.add_argument('--use-aptos2019', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--coarse', action='store_true')
    parser.add_argument('-acc', '--accumulation-steps', type=int, default=1, help='Number of batches to process')
    parser.add_argument('-dd', '--data-dir', type=str, default='data', help='Data directory')
    parser.add_argument('-m', '--model', type=str, default='resnet18_gap', help='')
    parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e', '--epochs', type=int, default=100, help='Epoch to run')
    parser.add_argument('-es', '--early-stopping', type=int, default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-f', '--fold', action='append', type=int, default=None)
    parser.add_argument('-ft', '--fine-tune', default=0, type=int)
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4, help='Initial learning rate')
    parser.add_argument('--criterion-reg', type=str, default=None, nargs='+', help='Criterion')
    parser.add_argument('--criterion-ord', type=str, default=None, nargs='+', help='Criterion')
    parser.add_argument('--criterion-cls', type=str, default=['ce'], nargs='+', help='Criterion')
    parser.add_argument('-l1', type=float, default=0, help='L1 regularization loss')
    parser.add_argument('-l2', type=float, default=0, help='L2 regularization loss')
    parser.add_argument('-o', '--optimizer', default='Adam', help='Name of the optimizer')
    parser.add_argument('-p', '--preprocessing', default=None, help='Preprocessing method')
    parser.add_argument('-c', '--checkpoint', type=str, default=None,
                        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w', '--workers', default=multiprocessing.cpu_count(), type=int, help='Num workers')
    parser.add_argument('-a', '--augmentations', default='medium', type=str, help='')
    parser.add_argument('-tta', '--tta', default=None, type=str, help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-t', '--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-s', '--scheduler', default='multistep', type=str, help='')
    parser.add_argument('--size', default=512, type=int, help='Image size for training & inference')
    parser.add_argument('-wd', '--weight-decay', default=0, type=float, help='L2 weight decay')
    parser.add_argument('-wds', '--weight-decay-step', default=None, type=float,
                        help='L2 weight decay step to add after each epoch')
    parser.add_argument('-d', '--dropout', default=0.0, type=float, help='Dropout before head layer')
    parser.add_argument('--warmup', default=0, type=int,
                        help='Number of warmup epochs with 0.1 of the initial LR and frozen encoder')
    parser.add_argument('-x', '--experiment', default=None, type=str, help='Experiment name')
    args = parser.parse_args()

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    l1 = args.l1
    l2 = args.l2
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    fine_tune = args.fine_tune
    criterion_reg_name = args.criterion_reg
    criterion_cls_name = args.criterion_cls
    criterion_ord_name = args.criterion_ord
    folds = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    use_swa = args.swa
    show_batches = args.show
    scheduler_name = args.scheduler
    verbose = args.verbose
    weight_decay = args.weight_decay
    use_idrid = args.use_idrid
    use_messidor = args.use_messidor
    use_aptos2015 = args.use_aptos2015
    use_aptos2019 = args.use_aptos2019
    warmup = args.warmup
    dropout = args.dropout
    use_unsupervised = False
    experiment = args.experiment
    preprocessing = args.preprocessing
    weight_decay_step = args.weight_decay_step
    coarse_grading = args.coarse
    class_names = get_class_names(coarse_grading)

    assert use_aptos2015 or use_aptos2019 or use_idrid or use_messidor

    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()

    if folds is None or len(folds) == 0:
        folds = [None]

    for fold in folds:
        torch.cuda.empty_cache()
        checkpoint_prefix = f'{model_name}_{args.size}_{augmentations}'
        if preprocessing is not None:
            checkpoint_prefix += f'_{preprocessing}'
        if use_aptos2019:
            checkpoint_prefix += '_aptos2019'
        if use_aptos2015:
            checkpoint_prefix += '_aptos2015'
        if use_messidor:
            checkpoint_prefix += '_messidor'
        if use_idrid:
            checkpoint_prefix += '_idrid'
        if coarse_grading:
            checkpoint_prefix += '_coarse'
        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'
        checkpoint_prefix += f'_{random_name}'

        if experiment is not None:
            checkpoint_prefix = experiment

        directory_prefix = f'{current_time}/{checkpoint_prefix}'
        log_dir = os.path.join('runs', directory_prefix)
        os.makedirs(log_dir, exist_ok=False)

        config_fname = os.path.join(log_dir, f'{checkpoint_prefix}.json')
        with open(config_fname, 'w') as f:
            train_session_args = vars(args)
            f.write(json.dumps(train_session_args, indent=2))

        set_manual_seed(args.seed)
        num_classes = len(class_names)
        model = get_model(model_name, num_classes=num_classes, dropout=dropout).cuda()

        if args.transfer:
            transfer_checkpoint = fs.auto_file(args.transfer)
            print("Transferring weights from model checkpoint", transfer_checkpoint)
            checkpoint = load_checkpoint(transfer_checkpoint)
            pretrained_dict = checkpoint['model_state_dict']
            for name, value in pretrained_dict.items():
                try:
                    model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
                except Exception as e:
                    print(e)
            report_checkpoint(checkpoint)

        if args.checkpoint:
            checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        train_ds, valid_ds, train_sizes = get_datasets(data_dir=data_dir,
                                                       use_aptos2019=use_aptos2019,
                                                       use_aptos2015=use_aptos2015,
                                                       use_idrid=use_idrid,
                                                       use_messidor=use_messidor,
                                                       use_unsupervised=False,
                                                       coarse_grading=coarse_grading,
                                                       image_size=image_size,
                                                       augmentation=augmentations,
                                                       preprocessing=preprocessing,
                                                       target_dtype=int,
                                                       fold=fold,
                                                       folds=4)

        train_loader, valid_loader = get_dataloaders(train_ds, valid_ds,
                                                     batch_size=batch_size,
                                                     num_workers=num_workers,
                                                     train_sizes=train_sizes,
                                                     balance=balance,
                                                     balance_datasets=balance_datasets,
                                                     balance_unlabeled=False)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        print('Datasets         :', data_dir)
        print('  Train size     :', len(train_loader), len(train_loader.dataset))
        print('  Valid size     :', len(valid_loader), len(valid_loader.dataset))
        print('  Aptos 2019     :', use_aptos2019)
        print('  Aptos 2015     :', use_aptos2015)
        print('  IDRID          :', use_idrid)
        print('  Messidor       :', use_messidor)
        print('Train session    :', directory_prefix)
        print('  FP16 mode      :', fp16)
        print('  Fast mode      :', fast)
        print('  Mixup          :', mixup)
        print('  Balance cls.   :', balance)
        print('  Balance ds.    :', balance_datasets)
        print('  Warmup epoch   :', warmup)
        print('  Train epochs   :', num_epochs)
        print('  Fine-tune ephs :', fine_tune)
        print('  Workers        :', num_workers)
        print('  Fold           :', fold)
        print('  Log dir        :', log_dir)
        print('  Augmentations  :', augmentations)
        print('Model            :', model_name)
        print('  Parameters     :', count_parameters(model))
        print('  Image size     :', image_size)
        print('  Dropout        :', dropout)
        print('  Classes        :', class_names, num_classes)
        print('Optimizer        :', optimizer_name)
        print('  Learning rate  :', learning_rate)
        print('  Batch size     :', batch_size)
        print('  Criterion (cls):', criterion_cls_name)
        print('  Criterion (reg):', criterion_reg_name)
        print('  Criterion (ord):', criterion_ord_name)
        print('  Scheduler      :', scheduler_name)
        print('  Weight decay   :', weight_decay, weight_decay_step)
        print('  L1 reg.        :', l1)
        print('  L2 reg.        :', l2)
        print('  Early stopping :', early_stopping)

        # model training
        callbacks = []
        criterions = {}
        main_metric = 'cls/kappa'

        if criterion_reg_name is not None:
            cb, crits = get_reg_callbacks(criterion_reg_name, class_names=class_names, show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_ord_name is not None:
            cb, crits = get_ord_callbacks(criterion_ord_name, class_names=class_names, show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_cls_name is not None:
            cb, crits = get_cls_callbacks(criterion_cls_name,
                                          num_classes=num_classes,
                                          num_epochs=num_epochs,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if l1 > 0:
            callbacks += [LPRegularizationCallback(start_wd=l1, end_wd=l1, schedule=None, prefix='l1', p=1)]

        if l2 > 0:
            callbacks += [LPRegularizationCallback(start_wd=l2, end_wd=l2, schedule=None, prefix='l2', p=2)]

        callbacks += [CustomOptimizerCallback()]

        runner = SupervisedRunner(input_key='image')

        # Pretrain/warmup
        if warmup:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer('Adam', get_optimizable_parameters(model),
                                      learning_rate=learning_rate * 0.1)
            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=None,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'warmup'),
                         num_epochs=warmup,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})
            del optimizer

        # Main train
        if num_epochs:
            set_trainable(model.encoder, True, False)
            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate,
                                      weight_decay=weight_decay)
            if use_swa:
                from torchcontrib.optim import SWA
                optimizer = SWA(optimizer, swa_start=len(train_loader), swa_freq=512)

            scheduler = get_scheduler(scheduler_name,
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=num_epochs,
                                      batches_in_epoch=len(train_loader))

            # Callbacks specific to the main stage only are added to a copy of the callbacks list
            main_stage_callbacks = callbacks
            if early_stopping:
                es_callback = EarlyStoppingCallback(early_stopping,
                                                    min_delta=1e-4,
                                                    metric=main_metric,
                                                    minimize=False)
                main_stage_callbacks = callbacks + [es_callback]

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=main_stage_callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'main'),
                         num_epochs=num_epochs,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})
            del optimizer, scheduler

            best_checkpoint = os.path.join(log_dir, 'main', 'checkpoints', 'best.pth')
            model_checkpoint = os.path.join(log_dir, 'main', 'checkpoints', f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)

            # Restoring best model from checkpoint
            checkpoint = load_checkpoint(best_checkpoint)
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        # Stage 3 - Fine tuning
        if fine_tune:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer(optimizer_name, get_optimizable_parameters(model),
                                      learning_rate=learning_rate)
            scheduler = get_scheduler('multistep', optimizer,
                                      lr=learning_rate,
                                      num_epochs=fine_tune,
                                      batches_in_epoch=len(train_loader))

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'finetune'),
                         num_epochs=fine_tune,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            best_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints', 'best.pth')
            model_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints', f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)
loaders = OrderedDict()
loaders["train"] = train_loader
loaders["valid"] = valid_loader

num_epochs = 50
logdir = "/var/data/deepfake/" + experiment_name

runner = SupervisedRunner()
runner.train(fp16=False,
             model=model,
             criterion=criterion,
             optimizer=optimizer,
             loaders=loaders,
             logdir=logdir,
             scheduler=scheduler,
             num_epochs=num_epochs,
             callbacks=[
                 MultiMetricCallback(metric_fn=catalyst_roc_auc,
                                     prefix='rocauc',
                                     input_key="targets",
                                     output_key="logits",
                                     list_args=['_']),
                 MultiMetricCallback(metric_fn=catalyst_logloss,
                                     prefix='logloss',
                                     input_key="targets",
                                     output_key="logits",
                                     list_args=['_']),
                 EarlyStoppingCallback(patience=10, min_delta=0.01)
             ],
             verbose=True)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('-dd', '--data-dir', type=str, default='data',
                        help='Data directory for INRIA satellite dataset')
    parser.add_argument('-m', '--model', type=str, default='cls_resnet18', help='')
    parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e', '--epochs', type=int, default=100, help='Epoch to run')
    parser.add_argument('-es', '--early-stopping', type=int, default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-fe', '--freeze-encoder', action='store_true')
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4, help='Initial learning rate')
    parser.add_argument('-l', '--criterion', type=str, default='bce', help='Criterion')
    parser.add_argument('-o', '--optimizer', default='Adam', help='Name of the optimizer')
    parser.add_argument('-c', '--checkpoint', type=str, default=None,
                        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w', '--workers', default=multiprocessing.cpu_count(), type=int, help='Num workers')
    parser.add_argument('-a', '--augmentations', default='hard', type=str, help='')
    parser.add_argument('-tta', '--tta', default=None, type=str, help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-tm', '--train-mode', default='random', type=str, help='')
    parser.add_argument('-rm', '--run-mode', default='fit_predict', type=str, help='')
    parser.add_argument('--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')
    args = parser.parse_args()

    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    run_mode = args.run_mode
    log_dir = None
    fp16 = args.fp16
    freeze_encoder = args.freeze_encoder

    run_train = run_mode == 'fit_predict' or run_mode == 'fit'
    run_predict = run_mode == 'fit_predict' or run_mode == 'predict'

    model = maybe_cuda(get_model(model_name, num_classes=1))

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint['model_state_dict']
        for name, value in pretrained_dict.items():
            try:
                model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
            except Exception as e:
                print(e)

    checkpoint = None
    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        checkpoint_epoch = checkpoint['epoch']
        print('Loaded model weights from:', args.checkpoint)
        print('Epoch          :', checkpoint_epoch)
        print('Metrics (Train):',
              'f1  :', checkpoint['epoch_metrics']['train']['f1_score'],
              'loss:', checkpoint['epoch_metrics']['train']['loss'])
        print('Metrics (Valid):',
              'f1  :', checkpoint['epoch_metrics']['valid']['f1_score'],
              'loss:', checkpoint['epoch_metrics']['valid']['loss'])

        log_dir = os.path.dirname(os.path.dirname(fs.auto_file(args.checkpoint)))

    if run_train:
        if freeze_encoder:
            set_trainable(model.encoder, trainable=False, freeze_bn=True)

        criterion = get_loss(args.criterion)
        parameters = get_optimizable_parameters(model)
        optimizer = get_optimizer(optimizer_name, parameters, learning_rate)

        if checkpoint is not None:
            try:
                unpack_checkpoint(checkpoint, optimizer=optimizer)
                print('Restored optimizer state from checkpoint')
            except Exception as e:
                print('Failed to restore optimizer state from checkpoint', e)

        train_loader, valid_loader = get_dataloaders(data_dir=data_dir,
                                                     batch_size=batch_size,
                                                     num_workers=num_workers,
                                                     image_size=image_size,
                                                     augmentation=augmentations,
                                                     fast=fast)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        current_time = datetime.now().strftime('%b%d_%H_%M')
        prefix = f'adversarial/{args.model}/{current_time}_{args.criterion}'

        if fp16:
            prefix += '_fp16'
        if fast:
            prefix += '_fast'

        log_dir = os.path.join('runs', prefix)
        os.makedirs(log_dir, exist_ok=False)

        scheduler = MultiStepLR(optimizer, milestones=[10, 30, 50, 70, 90], gamma=0.5)

        print('Train session    :', prefix)
        print('\tFP16 mode      :', fp16)
        print('\tFast mode      :', args.fast)
        print('\tTrain mode     :', train_mode)
        print('\tEpochs         :', num_epochs)
        print('\tEarly stopping :', early_stopping)
        print('\tWorkers        :', num_workers)
        print('\tData dir       :', data_dir)
        print('\tLog dir        :', log_dir)
        print('\tAugmentations  :', augmentations)
        print('\tTrain size     :', len(train_loader), len(train_loader.dataset))
        print('\tValid size     :', len(valid_loader), len(valid_loader.dataset))
        print('Model            :', model_name)
        print('\tParameters     :', count_parameters(model))
        print('\tImage size     :', image_size)
        print('\tFreeze encoder :', freeze_encoder)
        print('Optimizer        :', optimizer_name)
        print('\tLearning rate  :', learning_rate)
        print('\tBatch size     :', batch_size)
        print('\tCriterion      :', args.criterion)

        # model training
        visualization_fn = partial(draw_classification_predictions, class_names=['Train', 'Test'])

        callbacks = [
            F1ScoreCallback(),
            AUCCallback(),
            ShowPolarBatchesCallback(visualization_fn, metric='f1_score', minimize=False),
        ]

        if early_stopping:
            callbacks += [EarlyStoppingCallback(early_stopping, metric='auc', minimize=False)]

        runner = SupervisedRunner(input_key='image')
        runner.train(fp16=fp16,
                     model=model,
                     criterion=criterion,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     callbacks=callbacks,
                     loaders=loaders,
                     logdir=log_dir,
                     num_epochs=num_epochs,
                     verbose=True,
                     main_metric='auc',
                     minimize_metric=False,
                     state_kwargs={"cmd_args": vars(args)})

    if run_predict and not fast:
        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = load_checkpoint(fs.auto_file('best.pth', where=log_dir))
        unpack_checkpoint(best_checkpoint, model=model)

        model.eval()

        train_csv = pd.read_csv(os.path.join(data_dir, 'train.csv'))
        train_csv['id_code'] = train_csv['id_code'].apply(
            lambda x: os.path.join(data_dir, 'train_images', f'{x}.png'))

        test_ds = RetinopathyDataset(train_csv['id_code'], None, get_test_aug(image_size), target_as_array=True)
        test_dl = DataLoader(test_ds, batch_size, pin_memory=True, num_workers=num_workers)

        test_ids = []
        test_preds = []

        # Run inference without tracking gradients
        with torch.no_grad():
            for batch in tqdm(test_dl, desc='Inference'):
                input = batch['image'].cuda()
                outputs = model(input)
                predictions = to_numpy(outputs['logits'].sigmoid().squeeze(1))
                test_ids.extend(batch['image_id'])
                test_preds.extend(predictions)

        df = pd.DataFrame.from_dict({'id_code': test_ids, 'is_test': test_preds})
        df.to_csv(os.path.join(log_dir, 'test_in_train.csv'), index=None)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 8], gamma=0.3)

# model runner
runner = SupervisedRunner()

# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=[EarlyStoppingCallback(patience=2, min_delta=0.01)],
    logdir=logdir,
    num_epochs=num_epochs,
    check=True,
)

# In[ ]:

# utils.plot_metrics(logdir=logdir, metrics=["loss", "_base/lr"])

# # Setup 4 - training with additional metrics

# In[ ]:

from catalyst.runners import SupervisedRunner
from catalyst.dl import EarlyStoppingCallback, AccuracyCallback
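# A minimal sketch of how the "Setup 4" training call could look with the extra
# metric imported above. It reuses model, criterion, optimizer, scheduler,
# loaders, logdir, and num_epochs from the earlier cells and assumes the
# pre-21 Catalyst API used in this notebook, where AccuracyCallback defaults to
# input_key="targets" and output_key="logits"; adjust the callback arguments to
# your Catalyst version.
runner = SupervisedRunner()
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=[
        AccuracyCallback(),  # logs top-1 accuracy for each loader
        EarlyStoppingCallback(patience=2, min_delta=0.01),
    ],
    logdir=logdir,
    num_epochs=num_epochs,
    verbose=True,
)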
def train(self):
    # model = {"model": self.model}
    # criterion = {"criterion": nn.CrossEntropyLoss()}
    # optimizer = {"optimizer": self.optimizer}
    callbacks = [
        # dl.CriterionCallback(
        #     input_key="logits",
        #     target_key="targets",
        #     metric_key="loss",
        #     criterion_key="criterion",
        # ),
        # dl.OptimizerCallback(
        #     model_key="model",
        #     optimizer_key="optimizer",
        #     metric_key="loss"
        # ),
        EarlyStoppingCallback(patience=15,
                              metric_key="loss",
                              loader_key="valid",
                              minimize=True,
                              min_delta=0),
        AccuracyCallback(num_classes=2, input_key="logits", target_key="targets"),
        AUCCallback(input_key="logits", target_key="targets"),
        # CheckpointCallback(
        #     "./logs", loader_key="valid", metric_key="loss", minimize=True, save_n_best=3,
        #     # load_on_stage_start={"model": "best"},
        #     load_on_stage_end={"model": "best"}
        # ),
    ]
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode="min")

    train_dataset = TensorDataset(self.tr_eps, self.tr_labels)
    val_dataset = TensorDataset(self.val_eps, self.val_labels)
    test_dataset = TensorDataset(self.tst_eps, self.test_labels)

    runner = CustomRunner("./logs")

    v_bs = self.val_eps.shape[0]
    t_bs = self.tst_eps.shape[0]
    loaders = {
        "train": DataLoader(train_dataset, batch_size=self.batch_size, num_workers=0, shuffle=True),
        "valid": DataLoader(val_dataset, batch_size=v_bs, num_workers=0, shuffle=True),
    }

    if self.complete_arc:
        if self.PT in ["milc", "two-loss-milc"]:
            if self.exp in ["UFPT", "FPT"]:
                model_dict = torch.load(
                    os.path.join(self.oldpath, "best_full" + ".pth"),
                    map_location=self.device,
                )
                model_dict = model_dict["model_state_dict"]
                print("Complete Arch Loaded")
                self.model.load_state_dict(model_dict)

    # num_features = 2
    # model training
    # train_loader_param = {"batch_size": 64,
    #                       "shuffle": True,
    #                       }
    # val_loader_param = {"batch_size": 32,
    #                     "shuffle": True,
    #                     }
    # loaders_params = {"train": train_loader_param,
    #                   "valid": val_loader_param}
    # datasets = {
    #     "batch_size": 64,
    #     "num_workers": 1,
    #     "loaders_params": loaders_params,
    #     "get_datasets_fn": self.datasets_fn,
    #     "num_features": num_features,
    # },
    runner.train(
        model=self.model,
        optimizer=self.optimizer,
        # criterion=criterion,
        scheduler=scheduler,
        loaders=loaders,
        valid_loader='valid',
        callbacks=callbacks,
        logdir="./logs",
        num_epochs=self.epochs,
        verbose=True,
        load_best_on_end=True,
        valid_metric="loss",
        minimize_valid_metric=True,
    )

    # Evaluate on a single test batch that covers the whole test set
    loader = DataLoader(test_dataset, batch_size=t_bs, num_workers=1, shuffle=True)
    (
        self.test_accuracy,
        self.test_auc,
        self.test_loss,
    ) = runner.predict_batch(next(iter(loader)))
def main(cfg: DictConfig):
    cwd = Path(get_original_cwd())

    # Overwrite config if continuing training from a checkpoint
    resume_cfg = None
    if "resume" in cfg:
        cfg_path = cwd / cfg.resume / ".hydra/config.yaml"
        print(f"Continue from: {cfg.resume}")
        # Overwrite everything except device
        # TODO config merger (perhaps continue training with the same optimizer but other lrs?)
        resume_cfg = OmegaConf.load(cfg_path)
        cfg.model = resume_cfg.model
        if cfg.train.num_epochs == 0:
            cfg.data.scale_factor = resume_cfg.data.scale_factor
        OmegaConf.save(cfg, ".hydra/config.yaml")

    print(OmegaConf.to_yaml(cfg))

    device = set_device_id(cfg.device)
    set_seed(cfg.seed, device=device)

    # Augmentations
    if cfg.data.aug == "auto":
        transforms = albu.load(cwd / "autoalbument/autoconfig.json")
    else:
        transforms = D.get_training_augmentations()

    if OmegaConf.is_missing(cfg.model, "convert_bottleneck"):
        cfg.model.convert_bottleneck = (0, 0, 0)

    # Model
    print(f"Setup model {cfg.model.arch} {cfg.model.encoder_name} "
          f"convert_bn={cfg.model.convert_bn} "
          f"convert_bottleneck={cfg.model.convert_bottleneck} ")
    model = get_segmentation_model(
        arch=cfg.model.arch,
        encoder_name=cfg.model.encoder_name,
        encoder_weights=cfg.model.encoder_weights,
        classes=1,
        convert_bn=cfg.model.convert_bn,
        convert_bottleneck=cfg.model.convert_bottleneck,
        # decoder_attention_type="scse",  # TODO to config
    )
    model = model.to(device)
    model.train()
    print(model)

    # Optimization
    # Reduce LR for the pretrained encoder
    layerwise_params = {
        "encoder*": dict(lr=cfg.optim.lr_encoder, weight_decay=cfg.optim.wd_encoder)
    }
    model_params = cutils.process_model_params(model, layerwise_params=layerwise_params)

    # Select optimizer
    optimizer = get_optimizer(
        name=cfg.optim.name,
        model_params=model_params,
        lr=cfg.optim.lr,
        wd=cfg.optim.wd,
        lookahead=cfg.optim.lookahead,
    )

    criterion = {
        "dice": DiceLoss(),
        # "dice": SoftDiceLoss(mode="binary", smooth=1e-7),
        "iou": IoULoss(),
        "bce": nn.BCEWithLogitsLoss(),
        "lovasz": LovaszLossBinary(),
        "focal_tversky": FocalTverskyLoss(eps=1e-7, alpha=0.7, gamma=0.75),
    }

    # Load states if resuming training
    if "resume" in cfg:
        checkpoint_path = cwd / cfg.resume / cfg.train.logdir / "checkpoints/best_full.pth"
        if checkpoint_path.exists():
            print(f"\nLoading checkpoint {str(checkpoint_path)}")
            checkpoint = cutils.load_checkpoint(checkpoint_path)
            cutils.unpack_checkpoint(
                checkpoint=checkpoint,
                model=model,
                optimizer=optimizer if resume_cfg.optim.name == cfg.optim.name else None,
                criterion=criterion,
            )
        else:
            raise ValueError("Nothing to resume, checkpoint missing")

    # We may only want to validate the resumed model; in that case the training routine is skipped
    best_th = 0.5

    stats = None
    if cfg.data.stats:
        print(f"Use statistics from file: {cfg.data.stats}")
        stats = cwd / cfg.data.stats

    if cfg.train.num_epochs is not None:
        callbacks = [
            # Each criterion is calculated separately.
            CriterionCallback(input_key="mask", prefix="loss_dice", criterion_key="dice"),
            CriterionCallback(input_key="mask", prefix="loss_iou", criterion_key="iou"),
            CriterionCallback(input_key="mask", prefix="loss_bce", criterion_key="bce"),
            CriterionCallback(input_key="mask", prefix="loss_lovasz", criterion_key="lovasz"),
            CriterionCallback(input_key="mask", prefix="loss_focal_tversky", criterion_key="focal_tversky"),
            # And only then we aggregate everything into one loss.
            MetricAggregationCallback(
                prefix="loss",
                mode="weighted_sum",  # can be "sum", "weighted_sum" or "mean"
                # because we want a weighted sum, we need to add a scale for each loss
                metrics={
                    "loss_dice": cfg.loss.dice,
                    "loss_iou": cfg.loss.iou,
                    "loss_bce": cfg.loss.bce,
                    "loss_lovasz": cfg.loss.lovasz,
                    "loss_focal_tversky": cfg.loss.focal_tversky,
                },
            ),
            # metrics
            DiceCallback(input_key="mask"),
            IouCallback(input_key="mask"),
            # gradient accumulation
            OptimizerCallback(accumulation_steps=cfg.optim.accumulate),
            # scheduler
            SchedulerCallback(reduced_metric="loss_dice", mode=cfg.scheduler.mode),
            # early stopping
            EarlyStoppingCallback(**cfg.scheduler.early_stopping, minimize=False),
            # TODO WandbLogger works poorly with multistage right now
            WandbLogger(project=cfg.project, config=dict(cfg)),
            # CheckpointCallback(save_n_best=cfg.checkpoint.save_n_best),
        ]

        # Training
        runner = SupervisedRunner(device=device, input_key="image", input_target_key="mask")

        # TODO Scheduler does not work now, every stage restarts from base lr
        scheduler_warm_restart = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[1, 2],
            gamma=10,
        )

        for i, (size, num_epochs) in enumerate(zip(cfg.data.sizes, cfg.train.num_epochs)):
            scale = size / 1024
            print(f"Training stage {i}, scale {scale}, size {size}, epochs {num_epochs}")

            # Datasets
            (
                train_ds,
                valid_ds,
                train_images,
                val_images,
            ) = D.get_train_valid_datasets_from_path(
                # path=(cwd / cfg.data.path),
                path=(cwd / f"data/hubmap-{size}x{size}/"),
                train_ids=cfg.data.train_ids,
                valid_ids=cfg.data.valid_ids,
                seed=cfg.seed,
                valid_split=cfg.data.valid_split,
                mean=cfg.data.mean,
                std=cfg.data.std,
                transforms=transforms,
                stats=stats,
            )

            train_bs = int(cfg.loader.train_bs / (scale ** 2))
            valid_bs = int(cfg.loader.valid_bs / (scale ** 2))
            print(
                f"train: {len(train_ds)}; bs {train_bs}",
                f"valid: {len(valid_ds)}, bs {valid_bs}",
            )

            # Data loaders
            data_loaders = D.get_data_loaders(
                train_ds=train_ds,
                valid_ds=valid_ds,
                train_bs=train_bs,
                valid_bs=valid_bs,
                num_workers=cfg.loader.num_workers,
            )

            # Select scheduler
            scheduler = get_scheduler(
                name=cfg.scheduler.type,
                optimizer=optimizer,
                num_epochs=num_epochs * (len(data_loaders["train"]) if cfg.scheduler.mode == "batch" else 1),
                eta_min=scheduler_warm_restart.get_last_lr()[0] / cfg.scheduler.eta_min_factor,
                plateau=cfg.scheduler.plateau,
            )

            runner.train(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                callbacks=callbacks,
                logdir=cfg.train.logdir,
                loaders=data_loaders,
                num_epochs=num_epochs,
                verbose=True,
                main_metric=cfg.train.main_metric,
                load_best_on_end=True,
                minimize_metric=False,
                check=cfg.check,
                fp16=dict(amp=cfg.amp),
            )

            # Set new initial LR for optimizer after restart
            scheduler_warm_restart.step()
            print(f"New LR for warm restart {scheduler_warm_restart.get_last_lr()[0]}")

            # Find optimal threshold for dice score
            model.eval()
            best_th, dices = find_dice_threshold(model, data_loaders["valid"])
            print("Best dice threshold", best_th, np.max(dices[1]))
            np.save(f"dices_{size}.npy", dices)
    else:
        print("Validation only")
        # Datasets
        size = cfg.data.sizes[-1]
        train_ds, valid_ds = D.get_train_valid_datasets_from_path(
            # path=(cwd / cfg.data.path),
            path=(cwd / f"data/hubmap-{size}x{size}/"),
            train_ids=cfg.data.train_ids,
            valid_ids=cfg.data.valid_ids,
            seed=cfg.seed,
            valid_split=cfg.data.valid_split,
            mean=cfg.data.mean,
            std=cfg.data.std,
            transforms=transforms,
            stats=stats,
        )
        train_bs = int(cfg.loader.train_bs / (cfg.data.scale_factor ** 2))
        valid_bs = int(cfg.loader.valid_bs / (cfg.data.scale_factor ** 2))
        print(
            f"train: {len(train_ds)}; bs {train_bs}",
            f"valid: {len(valid_ds)}, bs {valid_bs}",
        )

        # Data loaders
        data_loaders = D.get_data_loaders(
            train_ds=train_ds,
            valid_ds=valid_ds,
            train_bs=train_bs,
            valid_bs=valid_bs,
            num_workers=cfg.loader.num_workers,
        )

        # Find optimal threshold for dice score
        model.eval()
        best_th, dices = find_dice_threshold(model, data_loaders["valid"])
        print("Best dice threshold", best_th, np.max(dices[1]))
        np.save("dices_val.npy", dices)

    # # Load best checkpoint
    # checkpoint_path = Path(cfg.train.logdir) / "checkpoints/best.pth"
    # if checkpoint_path.exists():
    #     print(f"\nLoading checkpoint {str(checkpoint_path)}")
    #     state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[
    #         "model_state_dict"
    #     ]
    #     model.load_state_dict(state_dict)
    #     del state_dict
    # model = model.to(device)

    # Load config for updating with threshold and metric
    # (otherwise loading does not work)
    cfg = OmegaConf.load(".hydra/config.yaml")
    cfg.threshold = float(best_th)

    # Evaluate on full-size images if valid_ids is non-empty
    df_train = pd.read_csv(cwd / "data/train.csv")
    df_train = {r["id"]: r["encoding"] for r in df_train.to_dict(orient="records")}

    dices = []
    unique_ids = sorted(
        set(str(p).split("/")[-1].split("_")[0]
            for p in (cwd / cfg.data.path / "train").iterdir()))
    size = cfg.data.sizes[-1]
    scale = size / 1024
    for image_id in cfg.data.valid_ids:
        image_name = unique_ids[image_id]
        print(f"\nValidate for {image_name}")

        rle_pred, shape = inference_one(
            image_path=(cwd / f"data/train/{image_name}.tiff"),
            target_path=Path("."),
            cfg=cfg,
            model=model,
            scale_factor=scale,
            tile_size=cfg.data.tile_size,
            tile_step=cfg.data.tile_step,
            threshold=best_th,
            save_raw=False,
            tta_mode=None,
            weight="pyramid",
            device=device,
            filter_crops="tissue",
            stats=stats,
        )

        print("Predict", shape)
        pred = rle_decode(rle_pred["predicted"], shape)
        mask = rle_decode(df_train[image_name], shape)
        assert pred.shape == mask.shape, f"pred {pred.shape}, mask {mask.shape}"
        assert pred.shape == shape, f"pred {pred.shape}, expected {shape}"
        dices.append(
            dice(
                torch.from_numpy(pred).type(torch.uint8),
                torch.from_numpy(mask).type(torch.uint8),
                threshold=None,
                activation="none",
            ))

    print("Full image dice:", np.mean(dices))
    OmegaConf.save(cfg, ".hydra/config.yaml")
    return