Example #1
0
    def __init__(self,
                 project,
                 model,
                 early_stop: bool = True,
                 runner_name=None,
                 train_test_val_indices=None):
        """Set up training defaults for *project* and hand *model* to the base class.

        Configures the default criterion/callbacks/optimizer/scheduler from the
        project's task flow, creates the run's log directory, creates or saves
        the train/test/val split, and persists (or restores) converter state.

        :param project: project object exposing ``get_full_flow()``, ``df``,
            ``converters`` and ``project_dir`` (assumed interface — defined
            elsewhere in the codebase).
        :param model: model instance forwarded to ``super().__init__``.
        :param early_stop: when True, append an ``EarlyStoppingCallback``
            with patience 5 to the default callbacks.
        :param runner_name: optional run name used for the log directory;
            defaults to ``<task-flow name>_<timestamp>``.
        :param train_test_val_indices: pre-computed split indices. When None a
            fresh split is created via ``project_split``; otherwise the given
            split is persisted with ``save_split``.
        """
        self.task_flow = project.get_full_flow()

        self.default_criterion = self.task_flow.get_loss()
        self.default_callbacks = self.default_criterion.catalyst_callbacks()
        self.default_optimizer = partial(optim.AdamW, lr=1e-4)
        self.default_scheduler = ReduceLROnPlateau
        self.project_dir: Path = project.project_dir
        self.project_dir.mkdir(exist_ok=True)
        runner_name = f'{self.task_flow.get_name()}_{time()}' if runner_name is None else runner_name
        self.default_logdir = f'./logdir_{runner_name}'

        if early_stop:
            self.default_callbacks.append(EarlyStoppingCallback(patience=5))

        # BUGFIX: create the log directory unconditionally. Previously it was
        # created only in the ``train_test_val_indices is None`` branch, so
        # ``save_split`` and ``torch.save`` below could fail with a missing
        # directory when the caller supplied its own split indices.
        (self.project_dir / self.default_logdir).mkdir(exist_ok=True)
        if train_test_val_indices is None:
            train_test_val_indices = project_split(
                project.df, self.project_dir / self.default_logdir)
        else:
            save_split(self.project_dir / self.default_logdir,
                       train_test_val_indices)
        self.train_test_val_indices = train_test_val_indices
        self.tensor_loggers = project.converters.tensorboard_converters
        # Restore converter state from a previous run if present; otherwise
        # persist the current state so later runs are reproducible.
        converters_file = self.project_dir / self.default_logdir / 'converters.pkl'
        if converters_file.exists():
            project.converters.load_state_dict(torch.load(converters_file))
        else:
            torch.save(project.converters.state_dict(), converters_file)
        super().__init__(model=model)
Example #2
0
    def __init__(self,
                 project,
                 model,
                 early_stop: bool = True,
                 balance_dataparallel_memory: bool = False,
                 runner_name=None,
                 train_test_val_indices=None):
        """Set up training defaults for *project* and hand *model* to the base class.

        Derives criterion, callbacks, optimizer and scheduler defaults from the
        project's task flow, prepares the log directory, and creates or saves
        the train/test/val split.

        :param project: project object exposing ``get_full_flow()``, ``df``,
            ``converters`` and ``project_dir`` (assumed interface — defined
            elsewhere in the codebase).
        :param model: model instance forwarded to ``super().__init__``.
        :param early_stop: when True, append an ``EarlyStoppingCallback``
            with patience 5.
        :param balance_dataparallel_memory: when True, prepend a
            ``ReplaceGatherCallback`` for the task flow.
        :param runner_name: optional run name for the log directory; defaults
            to ``<task-flow name>_<timestamp>``.
        :param train_test_val_indices: pre-computed split indices; created via
            ``project_split`` when None, otherwise persisted via ``save_split``.
        """
        self.task_flow = project.get_full_flow()

        self.default_criterion = self.task_flow.get_loss()
        self.balance_dataparallel_memory = balance_dataparallel_memory

        # Build the default callback list: the optional gather replacement
        # first, then whatever callbacks the criterion itself requires.
        callbacks = []
        if self.balance_dataparallel_memory:
            callbacks.append(ReplaceGatherCallback(self.task_flow))
        callbacks.extend(self.default_criterion.catalyst_callbacks())
        self.default_callbacks = callbacks

        self.default_optimizer = partial(optim.AdamW, lr=1e-4)
        self.default_scheduler = ReduceLROnPlateau
        self.project_dir: Path = project.project_dir
        self.project_dir.mkdir(exist_ok=True)
        if runner_name is None:
            runner_name = f'{self.task_flow.get_name()}_{time()}'
        self.default_logdir = f'./logdir_{runner_name}'

        if early_stop:
            self.default_callbacks.append(EarlyStoppingCallback(patience=5))

        logdir_path = self.project_dir / self.default_logdir
        logdir_path.mkdir(exist_ok=True)
        if train_test_val_indices is None:
            train_test_val_indices = project_split(project.df, logdir_path)
        else:
            save_split(logdir_path, train_test_val_indices)
        self.train_test_val_indices = train_test_val_indices
        self.tensor_loggers = project.converters.tensorboard_converters
        super().__init__(model=model)
    def torch_train(self, loaders, model, optimizer, loss_func, scheduler,
                    config):
        """Train *model* with catalyst's runner and return it.

        Wires the supplied training components onto ``self``, prepares the
        device/engine, runs ``self.runner.train`` with early-stopping and
        scheduler callbacks, then records epoch metrics and dumps a textual
        summary of model/optimizer/scheduler to ``model_details.txt``.

        :param loaders: ordered mapping of loader name -> DataLoader; the
            first key drives early stopping and the scheduler.
        :param model: model to train (trained in place).
        :param optimizer: optimizer instance.
        :param loss_func: criterion passed to the runner.
        :param scheduler: LR scheduler passed to the runner.
        :param config: run configuration (``logdir``, ``n_epochs``,
            ``patience``, ``min_delta``, ``parameters`` — assumed interface).
        :return: the trained model.
        """
        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.scheduler = scheduler
        # First loader in insertion order is the one monitored for loss.
        self.loader_key = next(iter(loaders))
        self.metric_key = 'loss'
        self.import_from_config()

        if 'cuda' in str(self.device):
            self.optimizer_to(optimizer, self.device)

        # Checks if logdir exists - deletes it if yes.
        self.check_logdir()

        if self.loader_key != 'train':
            warnings.warn(
                "WARNING: loader to be used for early-stop callback is '%s'. You can define it manually in /lib/estimator/pytorch_estimator.torch_train"
                % (self.loader_key))

        model = self.model

        torch.cuda.empty_cache()

        # Under DDP catalyst chooses its own engine; otherwise pin the device.
        self.engine = None if self.ddp else DeviceEngine(self.device)

        self.print_info()

        run_callbacks = [
            EarlyStoppingCallback(patience=self.config.patience,
                                  min_delta=self.config.min_delta,
                                  loader_key=self.loader_key,
                                  metric_key=self.metric_key,
                                  minimize=True),
            SchedulerCallback(loader_key=self.loader_key,
                              metric_key=self.metric_key),
            SkipCheckpointCallback(logdir=self.config.logdir),
        ]
        self.runner.train(model=model,
                          criterion=self.loss_func,
                          optimizer=self.optimizer,
                          scheduler=self.scheduler,
                          loaders=loaders,
                          logdir=self.config.logdir,
                          num_epochs=self.config.n_epochs,
                          callbacks=run_callbacks,
                          verbose=False,
                          check=False,
                          engine=self.engine,
                          ddp=self.ddp)

        # Collect run metadata and per-epoch metrics for reporting.
        self.config.parameters['model - device'] = str(self.runner.device)
        self.model_metrics['final epoch'] = self.runner.stage_epoch_step
        for metric_name, metric_value in self.runner.epoch_metrics.items():
            self.model_metrics[metric_name] = metric_value

        with open('model_details.txt', 'w') as file:
            file.write('%s\n\n%s\n\n%s' %
                       (str(self.runner.model), str(self.runner.optimizer),
                        str(self.runner.scheduler)))

        return model
def main():
    """End-to-end training entry point for a retina-image classifier.

    Parses the CLI, assembles datasets/loaders/model, then runs up to three
    catalyst stages per fold: optional warmup with a frozen encoder, main
    training (with optional SWA and early stopping), and optional fine-tuning.
    The best checkpoint of each stage is cleaned and restored before the next.
    """
    # ---- CLI definition ---------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')
    parser.add_argument('--swa', action='store_true')
    parser.add_argument('--show', action='store_true')
    parser.add_argument('--use-idrid', action='store_true')
    parser.add_argument('--use-messidor', action='store_true')
    parser.add_argument('--use-aptos2015', action='store_true')
    parser.add_argument('--use-aptos2019', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--coarse', action='store_true')
    parser.add_argument('-acc',
                        '--accumulation-steps',
                        type=int,
                        default=1,
                        help='Number of batches to process')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='data',
                        help='Data directory')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='resnet18_gap',
                        help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=100,
                        help='Epoch to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-f',
                        '--fold',
                        action='append',
                        type=int,
                        default=None)
    parser.add_argument('-ft', '--fine-tune', default=0, type=int)
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--criterion-reg',
                        type=str,
                        default=None,
                        nargs='+',
                        help='Criterion')
    parser.add_argument('--criterion-ord',
                        type=str,
                        default=None,
                        nargs='+',
                        help='Criterion')
    parser.add_argument('--criterion-cls',
                        type=str,
                        default=['ce'],
                        nargs='+',
                        help='Criterion')
    parser.add_argument('-l1',
                        type=float,
                        default=0,
                        help='L1 regularization loss')
    parser.add_argument('-l2',
                        type=float,
                        default=0,
                        help='L2 regularization loss')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument('-p',
                        '--preprocessing',
                        default=None,
                        help='Preprocessing method')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w',
                        '--workers',
                        default=multiprocessing.cpu_count(),
                        type=int,
                        help='Num workers')
    parser.add_argument('-a',
                        '--augmentations',
                        default='medium',
                        type=str,
                        help='')
    parser.add_argument('-tta',
                        '--tta',
                        default=None,
                        type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-t', '--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-s',
                        '--scheduler',
                        default='multistep',
                        type=str,
                        help='')
    parser.add_argument('--size',
                        default=512,
                        type=int,
                        help='Image size for training & inference')
    parser.add_argument('-wd',
                        '--weight-decay',
                        default=0,
                        type=float,
                        help='L2 weight decay')
    parser.add_argument('-wds',
                        '--weight-decay-step',
                        default=None,
                        type=float,
                        help='L2 weight decay step to add after each epoch')
    parser.add_argument('-d',
                        '--dropout',
                        default=0.0,
                        type=float,
                        help='Dropout before head layer')
    parser.add_argument(
        '--warmup',
        default=0,
        type=int,
        help=
        'Number of warmup epochs with 0.1 of the initial LR and frozed encoder'
    )
    # NOTE(review): this help text appears copy-pasted from --dropout above;
    # -x/--experiment actually overrides the checkpoint prefix (see below).
    # Confirm intent before changing it.
    parser.add_argument('-x',
                        '--experiment',
                        default=None,
                        type=str,
                        help='Dropout before head layer')

    args = parser.parse_args()

    # ---- Unpack CLI arguments into locals ---------------------------------
    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    l1 = args.l1
    l2 = args.l2
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    fine_tune = args.fine_tune
    criterion_reg_name = args.criterion_reg
    criterion_cls_name = args.criterion_cls
    criterion_ord_name = args.criterion_ord
    folds = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    use_swa = args.swa
    show_batches = args.show
    scheduler_name = args.scheduler
    verbose = args.verbose
    weight_decay = args.weight_decay
    use_idrid = args.use_idrid
    use_messidor = args.use_messidor
    use_aptos2015 = args.use_aptos2015
    use_aptos2019 = args.use_aptos2019
    warmup = args.warmup
    dropout = args.dropout
    # Hard-coded off: unsupervised data is never used by this script.
    use_unsupervised = False
    experiment = args.experiment
    preprocessing = args.preprocessing
    weight_decay_step = args.weight_decay_step
    coarse_grading = args.coarse
    class_names = get_class_names(coarse_grading)

    # At least one dataset source must be selected.
    assert use_aptos2015 or use_aptos2019 or use_idrid or use_messidor

    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()

    # No folds given -> single run on the default (un-folded) split.
    if folds is None or len(folds) == 0:
        folds = [None]

    # ---- Per-fold training loop -------------------------------------------
    for fold in folds:
        torch.cuda.empty_cache()
        # Checkpoint prefix encodes model, size, augmentations and the
        # selected dataset sources so runs are self-describing on disk.
        checkpoint_prefix = f'{model_name}_{args.size}_{augmentations}'

        if preprocessing is not None:
            checkpoint_prefix += f'_{preprocessing}'
        if use_aptos2019:
            checkpoint_prefix += '_aptos2019'
        if use_aptos2015:
            checkpoint_prefix += '_aptos2015'
        if use_messidor:
            checkpoint_prefix += '_messidor'
        if use_idrid:
            checkpoint_prefix += '_idrid'
        if coarse_grading:
            checkpoint_prefix += '_coarse'

        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'

        checkpoint_prefix += f'_{random_name}'

        # An explicit experiment name overrides the generated prefix entirely.
        if experiment is not None:
            checkpoint_prefix = experiment

        directory_prefix = f'{current_time}/{checkpoint_prefix}'
        log_dir = os.path.join('runs', directory_prefix)
        # exist_ok=False: fail loudly rather than overwrite a previous run.
        os.makedirs(log_dir, exist_ok=False)

        # Persist the full CLI configuration next to the run's logs.
        config_fname = os.path.join(log_dir, f'{checkpoint_prefix}.json')
        with open(config_fname, 'w') as f:
            train_session_args = vars(args)
            f.write(json.dumps(train_session_args, indent=2))

        set_manual_seed(args.seed)
        num_classes = len(class_names)
        model = get_model(model_name, num_classes=num_classes,
                          dropout=dropout).cuda()

        # Optional transfer learning: copy weights tensor-by-tensor so that
        # incompatible layers are skipped (errors printed, not raised).
        if args.transfer:
            transfer_checkpoint = fs.auto_file(args.transfer)
            print("Transfering weights from model checkpoint",
                  transfer_checkpoint)
            checkpoint = load_checkpoint(transfer_checkpoint)
            pretrained_dict = checkpoint['model_state_dict']

            for name, value in pretrained_dict.items():
                try:
                    model.load_state_dict(collections.OrderedDict([(name,
                                                                    value)]),
                                          strict=False)
                except Exception as e:
                    print(e)

            report_checkpoint(checkpoint)

        # Optional full-checkpoint restore (exact weights).
        if args.checkpoint:
            checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        # ---- Data -----------------------------------------------------
        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_aptos2019=use_aptos2019,
            use_aptos2015=use_aptos2015,
            use_idrid=use_idrid,
            use_messidor=use_messidor,
            use_unsupervised=False,
            coarse_grading=coarse_grading,
            image_size=image_size,
            augmentation=augmentations,
            preprocessing=preprocessing,
            target_dtype=int,
            fold=fold,
            folds=4)

        train_loader, valid_loader = get_dataloaders(
            train_ds,
            valid_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            train_sizes=train_sizes,
            balance=balance,
            balance_datasets=balance_datasets,
            balance_unlabeled=False)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        # Human-readable summary of the run configuration.
        print('Datasets         :', data_dir)
        print('  Train size     :', len(train_loader),
              len(train_loader.dataset))
        print('  Valid size     :', len(valid_loader),
              len(valid_loader.dataset))
        print('  Aptos 2019     :', use_aptos2019)
        print('  Aptos 2015     :', use_aptos2015)
        print('  IDRID          :', use_idrid)
        print('  Messidor       :', use_messidor)
        print('Train session    :', directory_prefix)
        print('  FP16 mode      :', fp16)
        print('  Fast mode      :', fast)
        print('  Mixup          :', mixup)
        print('  Balance cls.   :', balance)
        print('  Balance ds.    :', balance_datasets)
        print('  Warmup epoch   :', warmup)
        print('  Train epochs   :', num_epochs)
        print('  Fine-tune ephs :', fine_tune)
        print('  Workers        :', num_workers)
        print('  Fold           :', fold)
        print('  Log dir        :', log_dir)
        print('  Augmentations  :', augmentations)
        print('Model            :', model_name)
        print('  Parameters     :', count_parameters(model))
        print('  Image size     :', image_size)
        print('  Dropout        :', dropout)
        print('  Classes        :', class_names, num_classes)
        print('Optimizer        :', optimizer_name)
        print('  Learning rate  :', learning_rate)
        print('  Batch size     :', batch_size)
        print('  Criterion (cls):', criterion_cls_name)
        print('  Criterion (reg):', criterion_reg_name)
        print('  Criterion (ord):', criterion_ord_name)
        print('  Scheduler      :', scheduler_name)
        print('  Weight decay   :', weight_decay, weight_decay_step)
        print('  L1 reg.        :', l1)
        print('  L2 reg.        :', l2)
        print('  Early stopping :', early_stopping)

        # model training
        # Callbacks and criterions are accumulated per head type
        # (regression / ordinal / classification).
        callbacks = []
        criterions = {}

        main_metric = 'cls/kappa'
        if criterion_reg_name is not None:
            cb, crits = get_reg_callbacks(criterion_reg_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_ord_name is not None:
            cb, crits = get_ord_callbacks(criterion_ord_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_cls_name is not None:
            cb, crits = get_cls_callbacks(criterion_cls_name,
                                          num_classes=num_classes,
                                          num_epochs=num_epochs,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if l1 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l1,
                                         end_wd=l1,
                                         schedule=None,
                                         prefix='l1',
                                         p=1)
            ]

        if l2 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l2,
                                         end_wd=l2,
                                         schedule=None,
                                         prefix='l2',
                                         p=2)
            ]

        callbacks += [CustomOptimizerCallback()]

        runner = SupervisedRunner(input_key='image')

        # Pretrain/warmup
        # Stage 1: encoder frozen, 10% of the base LR.
        if warmup:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer('Adam',
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate * 0.1)

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=None,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'warmup'),
                         num_epochs=warmup,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer

        # Main train
        # Stage 2: full model trainable, chosen optimizer/scheduler.
        if num_epochs:
            set_trainable(model.encoder, True, False)

            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate,
                                      weight_decay=weight_decay)

            if use_swa:
                # Stochastic Weight Averaging wraps the base optimizer.
                from torchcontrib.optim import SWA
                optimizer = SWA(optimizer,
                                swa_start=len(train_loader),
                                swa_freq=512)

            scheduler = get_scheduler(scheduler_name,
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=num_epochs,
                                      batches_in_epoch=len(train_loader))

            # Additional callbacks that specific to main stage only added here to copy of callbacks
            main_stage_callbacks = callbacks
            if early_stopping:
                es_callback = EarlyStoppingCallback(early_stopping,
                                                    min_delta=1e-4,
                                                    metric=main_metric,
                                                    minimize=False)
                main_stage_callbacks = callbacks + [es_callback]

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=main_stage_callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'main'),
                         num_epochs=num_epochs,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer, scheduler

            best_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)

            # Restoring best model from checkpoint
            checkpoint = load_checkpoint(best_checkpoint)
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        # Stage 3 - Fine tuning
        # Encoder frozen again; fixed multistep schedule.
        if fine_tune:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate)
            scheduler = get_scheduler('multistep',
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=fine_tune,
                                      batches_in_epoch=len(train_loader))

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'finetune'),
                         num_epochs=fine_tune,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            best_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)
Example #5
0
# Assemble loaders in the order catalyst expects: train first, then valid.
loaders = OrderedDict([("train", train_loader), ("valid", valid_loader)])

num_epochs = 50
logdir = "/var/data/deepfake/" + experiment_name
runner = SupervisedRunner()

# Track ROC-AUC and log-loss on the logits, and stop after 10 epochs
# without at least a 0.01 improvement.
training_callbacks = [
    MultiMetricCallback(metric_fn=catalyst_roc_auc,
                        prefix='rocauc',
                        input_key="targets",
                        output_key="logits",
                        list_args=['_']),
    MultiMetricCallback(metric_fn=catalyst_logloss,
                        prefix='logloss',
                        input_key="targets",
                        output_key="logits",
                        list_args=['_']),
    EarlyStoppingCallback(patience=10, min_delta=0.01),
]

runner.train(fp16=False,
             model=model,
             criterion=criterion,
             optimizer=optimizer,
             loaders=loaders,
             logdir=logdir,
             scheduler=scheduler,
             num_epochs=num_epochs,
             callbacks=training_callbacks,
             verbose=True)
def main():
    """CLI entry point: train and/or evaluate an adversarial train-vs-test
    image classifier.

    Parses command-line arguments, builds the model, optionally transfers
    or resumes weights from a checkpoint, trains with Catalyst's
    SupervisedRunner, and finally scores every training image with the
    best checkpoint, writing ``test_in_train.csv`` to the log directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='data',
                        help='Data directory for INRIA sattelite dataset')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='cls_resnet18',
                        help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=100,
                        help='Epoch to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-fe', '--freeze-encoder', action='store_true')
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('-l',
                        '--criterion',
                        type=str,
                        default='bce',
                        help='Criterion')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w',
                        '--workers',
                        default=multiprocessing.cpu_count(),
                        type=int,
                        help='Num workers')
    parser.add_argument('-a',
                        '--augmentations',
                        default='hard',
                        type=str,
                        help='')
    parser.add_argument('-tta',
                        '--tta',
                        default=None,
                        type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-tm',
                        '--train-mode',
                        default='random',
                        type=str,
                        help='')
    parser.add_argument('-rm',
                        '--run-mode',
                        default='fit_predict',
                        type=str,
                        help='')
    parser.add_argument('--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    run_mode = args.run_mode
    log_dir = None
    fp16 = args.fp16
    freeze_encoder = args.freeze_encoder

    # 'fit_predict' runs both phases; 'fit' / 'predict' select just one.
    run_train = run_mode == 'fit_predict' or run_mode == 'fit'
    run_predict = run_mode == 'fit_predict' or run_mode == 'predict'

    model = maybe_cuda(get_model(model_name, num_classes=1))

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint['model_state_dict']

        # Copy weights one tensor at a time with strict=False so that
        # incompatible layers are skipped (errors are printed, not fatal).
        for name, value in pretrained_dict.items():
            try:
                model.load_state_dict(collections.OrderedDict([(name, value)]),
                                      strict=False)
            except Exception as e:
                print(e)

    checkpoint = None
    if args.checkpoint:
        # Resume: restore model weights and report the checkpoint's metrics.
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        checkpoint_epoch = checkpoint['epoch']
        print('Loaded model weights from:', args.checkpoint)
        print('Epoch                    :', checkpoint_epoch)
        print('Metrics (Train):', 'f1  :',
              checkpoint['epoch_metrics']['train']['f1_score'], 'loss:',
              checkpoint['epoch_metrics']['train']['loss'])
        print('Metrics (Valid):', 'f1  :',
              checkpoint['epoch_metrics']['valid']['f1_score'], 'loss:',
              checkpoint['epoch_metrics']['valid']['loss'])

        # Reuse the checkpoint's run directory (two levels above the file).
        log_dir = os.path.dirname(
            os.path.dirname(fs.auto_file(args.checkpoint)))

    if run_train:

        if freeze_encoder:
            # Train the head only; also freeze encoder batch-norm statistics.
            set_trainable(model.encoder, trainable=False, freeze_bn=True)

        criterion = get_loss(args.criterion)
        parameters = get_optimizable_parameters(model)
        optimizer = get_optimizer(optimizer_name, parameters, learning_rate)

        if checkpoint is not None:
            # Best-effort restore of the optimizer state; a mismatch (e.g.
            # different optimizer type) is reported but not fatal.
            try:
                unpack_checkpoint(checkpoint, optimizer=optimizer)
                print('Restored optimizer state from checkpoint')
            except Exception as e:
                print('Failed to restore optimizer state from checkpoint', e)

        train_loader, valid_loader = get_dataloaders(
            data_dir=data_dir,
            batch_size=batch_size,
            num_workers=num_workers,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        # Unique run prefix: model / timestamp / criterion (+ mode suffixes).
        current_time = datetime.now().strftime('%b%d_%H_%M')
        prefix = f'adversarial/{args.model}/{current_time}_{args.criterion}'

        if fp16:
            prefix += '_fp16'

        if fast:
            prefix += '_fast'

        log_dir = os.path.join('runs', prefix)
        # Fail loudly if the run directory already exists.
        os.makedirs(log_dir, exist_ok=False)

        # Halve the LR at the given epoch milestones.
        scheduler = MultiStepLR(optimizer,
                                milestones=[10, 30, 50, 70, 90],
                                gamma=0.5)

        print('Train session    :', prefix)
        print('\tFP16 mode      :', fp16)
        print('\tFast mode      :', args.fast)
        print('\tTrain mode     :', train_mode)
        print('\tEpochs         :', num_epochs)
        print('\tEarly stopping :', early_stopping)
        print('\tWorkers        :', num_workers)
        print('\tData dir       :', data_dir)
        print('\tLog dir        :', log_dir)
        print('\tAugmentations  :', augmentations)
        print('\tTrain size     :', len(train_loader),
              len(train_loader.dataset))
        print('\tValid size     :', len(valid_loader),
              len(valid_loader.dataset))
        print('Model            :', model_name)
        print('\tParameters     :', count_parameters(model))
        print('\tImage size     :', image_size)
        print('\tFreeze encoder :', freeze_encoder)
        print('Optimizer        :', optimizer_name)
        print('\tLearning rate  :', learning_rate)
        print('\tBatch size     :', batch_size)
        print('\tCriterion      :', args.criterion)

        # model training
        visualization_fn = partial(draw_classification_predictions,
                                   class_names=['Train', 'Test'])

        callbacks = [
            F1ScoreCallback(),
            AUCCallback(),
            ShowPolarBatchesCallback(visualization_fn,
                                     metric='f1_score',
                                     minimize=False),
        ]

        if early_stopping:
            callbacks += [
                EarlyStoppingCallback(early_stopping,
                                      metric='auc',
                                      minimize=False)
            ]

        runner = SupervisedRunner(input_key='image')
        runner.train(fp16=fp16,
                     model=model,
                     criterion=criterion,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     callbacks=callbacks,
                     loaders=loaders,
                     logdir=log_dir,
                     num_epochs=num_epochs,
                     verbose=True,
                     main_metric='auc',
                     minimize_metric=False,
                     state_kwargs={"cmd_args": vars(args)})

    if run_predict and not fast:
        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = load_checkpoint(
            fs.auto_file('best.pth', where=log_dir))
        unpack_checkpoint(best_checkpoint, model=model)

        model.eval()

        train_csv = pd.read_csv(os.path.join(data_dir, 'train.csv'))
        train_csv['id_code'] = train_csv['id_code'].apply(
            lambda x: os.path.join(data_dir, 'train_images', f'{x}.png'))
        test_ds = RetinopathyDataset(train_csv['id_code'],
                                     None,
                                     get_test_aug(image_size),
                                     target_as_array=True)
        test_dl = DataLoader(test_ds,
                             batch_size,
                             pin_memory=True,
                             num_workers=num_workers)

        test_ids = []
        test_preds = []

        # BUGFIX: the original called `torch.no_grad()` as a bare statement,
        # which creates and discards the context manager without disabling
        # autograd. Wrap the inference loop in the context manager instead.
        with torch.no_grad():
            for batch in tqdm(test_dl, desc='Inference'):
                # Renamed from `input`, which shadowed the builtin.
                images = batch['image'].cuda()
                outputs = model(images)
                predictions = to_numpy(outputs['logits'].sigmoid().squeeze(1))
                test_ids.extend(batch['image_id'])
                test_preds.extend(predictions)

        df = pd.DataFrame.from_dict({
            'id_code': test_ids,
            'is_test': test_preds
        })
        # `index=False` is the documented way to omit the index column
        # (the original passed `index=None`).
        df.to_csv(os.path.join(log_dir, 'test_in_train.csv'), index=False)
Exemple #7
0
# Optimize all model parameters with Adam; MultiStepLR scales the learning
# rate by 0.3 after epochs 3 and 8.
optimizer = torch.optim.Adam(model.parameters())
lr_milestones = [3, 8]
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=lr_milestones,
                                                 gamma=0.3)

# Stop early after 2 epochs without a >= 0.01 improvement.
stop_callbacks = [EarlyStoppingCallback(patience=2, min_delta=0.01)]

# Catalyst runner; check=True performs a quick sanity pass over the pipeline.
runner = SupervisedRunner()
runner.train(model=model,
             criterion=criterion,
             optimizer=optimizer,
             scheduler=scheduler,
             loaders=loaders,
             callbacks=stop_callbacks,
             logdir=logdir,
             num_epochs=num_epochs,
             check=True)

# In[ ]:

# utils.plot_metrics(logdir=logdir, metrics=["loss", "_base/lr"])

# # Setup 4 - training with additional metrics

# In[ ]:

from catalyst.runners import SupervisedRunner
from catalyst.dl import EarlyStoppingCallback, AccuracyCallback
    def train(self):
        """Fit ``self.model`` with Catalyst and evaluate it on the test set.

        Builds train/valid loaders from the episode tensors, trains with
        early stopping on validation loss, reloads the best checkpoint at
        the end (``load_best_on_end=True``), and stores test accuracy, AUC
        and loss on ``self`` via ``runner.predict_batch``.
        """
        # model = {"model": self.model}
        # criterion = {"criterion": nn.CrossEntropyLoss()}
        # optimizer = {"optimizer": self.optimizer}
        callbacks = [
            # dl.CriterionCallback(
            #     input_key="logits",
            #     target_key="targets",
            #     metric_key="loss",
            #     criterion_key="criterion",
            # ),
            # dl.OptimizerCallback(
            #     model_key="model",
            #     optimizer_key="optimizer",
            #     metric_key="loss"
            # ),
            # Stop once validation loss has not improved for 15 epochs.
            EarlyStoppingCallback(patience=15,
                                  metric_key="loss",
                                  loader_key="valid",
                                  minimize=True,
                                  min_delta=0),
            AccuracyCallback(num_classes=2,
                             input_key="logits",
                             target_key="targets"),
            AUCCallback(input_key="logits", target_key="targets"),
            #     CheckpointCallback(
            #     "./logs", loader_key="valid", metric_key="loss", minimize=True, save_n_best=3,
            #     # load_on_stage_start={"model": "best"},
            #     load_on_stage_end={"model": "best"}
            # ),
        ]

        # Reduce the learning rate when the monitored metric plateaus.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                               mode="min")
        train_dataset = TensorDataset(self.tr_eps, self.tr_labels)
        val_dataset = TensorDataset(self.val_eps, self.val_labels)
        test_dataset = TensorDataset(self.tst_eps, self.test_labels)
        runner = CustomRunner("./logs")
        # Validation and test are each evaluated as a single full-size batch.
        v_bs = self.val_eps.shape[0]
        t_bs = self.tst_eps.shape[0]
        loaders = {
            "train":
            DataLoader(
                train_dataset,
                batch_size=self.batch_size,
                num_workers=0,
                shuffle=True,
            ),
            "valid":
            DataLoader(
                val_dataset,
                batch_size=v_bs,
                num_workers=0,
                shuffle=True,
            ),
        }

        # Optionally warm-start the complete architecture from a MILC
        # pre-training checkpoint (UFPT/FPT experiment modes only).
        if self.complete_arc == True:
            if self.PT in ["milc", "two-loss-milc"]:
                if self.exp in ["UFPT", "FPT"]:
                    model_dict = torch.load(
                        os.path.join(self.oldpath, "best_full" + ".pth"),
                        map_location=self.device,
                    )
                    model_dict = model_dict["model_state_dict"]
                    print("Complete Arch Loaded")
                    self.model.load_state_dict(model_dict)
        # num_features=2
        # model training
        # train_loader_param = {"batch_size": 64,
        #                       "shuffle":True,
        #                       }
        # val_loader_param = {"batch_size": 32,
        #                       "shuffle": True,
        #                       }

        # loaders_params = {"train" : train_loader_param,
        #                   "valid": val_loader_param}

        # datasets = {
        #               "batch_size": 64,
        #               "num_workers": 1,
        #               "loaders_params": loaders_params,
        #               "get_datasets_fn": self.datasets_fn,
        #               "num_features": num_features,

        #          },

        runner.train(
            model=self.model,
            optimizer=self.optimizer,
            # criterion=criterion,
            scheduler=scheduler,
            loaders=loaders,
            valid_loader='valid',
            callbacks=callbacks,
            logdir="./logs",
            num_epochs=self.epochs,
            verbose=True,
            load_best_on_end=True,
            valid_metric="loss",
            minimize_valid_metric=True,
        )

        # NOTE(review): `loader` is a 1-tuple, so `next(iter(loader))` yields
        # the DataLoader object itself rather than a batch of tensors —
        # confirm CustomRunner.predict_batch really expects a loader here.
        loader = (DataLoader(test_dataset,
                             batch_size=t_bs,
                             num_workers=1,
                             shuffle=True), )

        (
            self.test_accuracy,
            self.test_auc,
            self.test_loss,
        ) = runner.predict_batch(next(iter(loader)))
Exemple #9
0
def main(cfg: DictConfig):
    """Hydra entry point: train and validate a binary segmentation model.

    When ``cfg.train.num_epochs`` is set, trains in progressive-resizing
    stages (one per entry of ``cfg.data.sizes``), otherwise runs validation
    only. After training, searches the best dice threshold, stores it in
    the saved Hydra config, and evaluates full-size validation images by
    tiled inference against the RLE ground truth.

    Side effects: writes checkpoints/logs under ``cfg.train.logdir``,
    updates ``.hydra/config.yaml``, and saves dice arrays as ``.npy``.
    """

    cwd = Path(get_original_cwd())

    # overwrite config if continue training from checkpoint
    resume_cfg = None
    if "resume" in cfg:
        cfg_path = cwd / cfg.resume / ".hydra/config.yaml"
        print(f"Continue from: {cfg.resume}")
        # Overwrite everything except device
        # TODO config merger (perhaps continue training with the same optimizer but other lrs?)
        resume_cfg = OmegaConf.load(cfg_path)
        cfg.model = resume_cfg.model
        if cfg.train.num_epochs == 0:
            cfg.data.scale_factor = resume_cfg.data.scale_factor
        OmegaConf.save(cfg, ".hydra/config.yaml")

    print(OmegaConf.to_yaml(cfg))

    device = set_device_id(cfg.device)
    set_seed(cfg.seed, device=device)

    # Augmentations
    if cfg.data.aug == "auto":
        transforms = albu.load(cwd / "autoalbument/autoconfig.json")
    else:
        transforms = D.get_training_augmentations()

    # Backfill a default for configs written before this option existed.
    if OmegaConf.is_missing(cfg.model, "convert_bottleneck"):
        cfg.model.convert_bottleneck = (0, 0, 0)

    # Model
    print(f"Setup model {cfg.model.arch} {cfg.model.encoder_name} "
          f"convert_bn={cfg.model.convert_bn} "
          f"convert_bottleneck={cfg.model.convert_bottleneck} ")
    model = get_segmentation_model(
        arch=cfg.model.arch,
        encoder_name=cfg.model.encoder_name,
        encoder_weights=cfg.model.encoder_weights,
        classes=1,
        convert_bn=cfg.model.convert_bn,
        convert_bottleneck=cfg.model.convert_bottleneck,
        # decoder_attention_type="scse",  # TODO to config
    )
    model = model.to(device)
    model.train()
    print(model)

    # Optimization
    # Reduce LR for pretrained encoder
    layerwise_params = {
        "encoder*":
        dict(lr=cfg.optim.lr_encoder, weight_decay=cfg.optim.wd_encoder)
    }
    model_params = cutils.process_model_params(
        model, layerwise_params=layerwise_params)

    # Select optimizer
    optimizer = get_optimizer(
        name=cfg.optim.name,
        model_params=model_params,
        lr=cfg.optim.lr,
        wd=cfg.optim.wd,
        lookahead=cfg.optim.lookahead,
    )

    # Each entry becomes a separately-weighted CriterionCallback below.
    criterion = {
        "dice": DiceLoss(),
        # "dice": SoftDiceLoss(mode="binary", smooth=1e-7),
        "iou": IoULoss(),
        "bce": nn.BCEWithLogitsLoss(),
        "lovasz": LovaszLossBinary(),
        "focal_tversky": FocalTverskyLoss(eps=1e-7, alpha=0.7, gamma=0.75),
    }

    # Load states if resuming training
    if "resume" in cfg:
        checkpoint_path = (cwd / cfg.resume / cfg.train.logdir /
                           "checkpoints/best_full.pth")
        if checkpoint_path.exists():
            print(f"\nLoading checkpoint {str(checkpoint_path)}")
            checkpoint = cutils.load_checkpoint(checkpoint_path)
            cutils.unpack_checkpoint(
                checkpoint=checkpoint,
                model=model,
                # Optimizer state is only portable between identical optimizers.
                optimizer=optimizer
                if resume_cfg.optim.name == cfg.optim.name else None,
                criterion=criterion,
            )
        else:
            raise ValueError("Nothing to resume, checkpoint missing")

    # We could only want to validate resume, in this case skip training routine
    best_th = 0.5

    stats = None
    if cfg.data.stats:
        print(f"Use statistics from file: {cfg.data.stats}")
        stats = cwd / cfg.data.stats

    if cfg.train.num_epochs is not None:
        callbacks = [
            # Each criterion is calculated separately.
            CriterionCallback(input_key="mask",
                              prefix="loss_dice",
                              criterion_key="dice"),
            CriterionCallback(input_key="mask",
                              prefix="loss_iou",
                              criterion_key="iou"),
            CriterionCallback(input_key="mask",
                              prefix="loss_bce",
                              criterion_key="bce"),
            CriterionCallback(input_key="mask",
                              prefix="loss_lovasz",
                              criterion_key="lovasz"),
            CriterionCallback(
                input_key="mask",
                prefix="loss_focal_tversky",
                criterion_key="focal_tversky",
            ),
            # And only then we aggregate everything into one loss.
            MetricAggregationCallback(
                prefix="loss",
                mode="weighted_sum",  # can be "sum", "weighted_sum" or "mean"
                # because we want weighted sum, we need to add scale for each loss
                metrics={
                    "loss_dice": cfg.loss.dice,
                    "loss_iou": cfg.loss.iou,
                    "loss_bce": cfg.loss.bce,
                    "loss_lovasz": cfg.loss.lovasz,
                    "loss_focal_tversky": cfg.loss.focal_tversky,
                },
            ),
            # metrics
            DiceCallback(input_key="mask"),
            IouCallback(input_key="mask"),
            # gradient accumulation
            OptimizerCallback(accumulation_steps=cfg.optim.accumulate),
            # early stopping
            SchedulerCallback(reduced_metric="loss_dice",
                              mode=cfg.scheduler.mode),
            EarlyStoppingCallback(**cfg.scheduler.early_stopping,
                                  minimize=False),
            # TODO WandbLogger works poorly with multistage right now
            WandbLogger(project=cfg.project, config=dict(cfg)),
            # CheckpointCallback(save_n_best=cfg.checkpoint.save_n_best),
        ]

        # Training
        runner = SupervisedRunner(device=device,
                                  input_key="image",
                                  input_target_key="mask")

        # TODO Scheduler does not work now, every stage restarts from base lr
        # NOTE: gamma=10 *multiplies* the LR at each milestone — this raises
        # the base LR between stages (warm restart), it does not decay it.
        scheduler_warm_restart = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[1, 2],
            gamma=10,
        )

        # One training stage per (tile size, epoch count) pair.
        for i, (size, num_epochs) in enumerate(
                zip(cfg.data.sizes, cfg.train.num_epochs)):
            scale = size / 1024
            print(
                f"Training stage {i}, scale {scale}, size {size}, epochs {num_epochs}"
            )

            # Datasets
            (
                train_ds,
                valid_ds,
                train_images,
                val_images,
            ) = D.get_train_valid_datasets_from_path(
                # path=(cwd / cfg.data.path),
                path=(cwd / f"data/hubmap-{size}x{size}/"),
                train_ids=cfg.data.train_ids,
                valid_ids=cfg.data.valid_ids,
                seed=cfg.seed,
                valid_split=cfg.data.valid_split,
                mean=cfg.data.mean,
                std=cfg.data.std,
                transforms=transforms,
                stats=stats,
            )

            # Keep per-batch pixel count roughly constant across stages.
            train_bs = int(cfg.loader.train_bs / (scale**2))
            valid_bs = int(cfg.loader.valid_bs / (scale**2))
            print(
                f"train: {len(train_ds)}; bs {train_bs}",
                f"valid: {len(valid_ds)}, bs {valid_bs}",
            )

            # Data loaders
            data_loaders = D.get_data_loaders(
                train_ds=train_ds,
                valid_ds=valid_ds,
                train_bs=train_bs,
                valid_bs=valid_bs,
                num_workers=cfg.loader.num_workers,
            )

            # Select scheduler
            scheduler = get_scheduler(
                name=cfg.scheduler.type,
                optimizer=optimizer,
                # Per-batch schedulers need the total step count, not epochs.
                num_epochs=num_epochs * (len(data_loaders["train"]) if
                                         cfg.scheduler.mode == "batch" else 1),
                eta_min=scheduler_warm_restart.get_last_lr()[0] /
                cfg.scheduler.eta_min_factor,
                plateau=cfg.scheduler.plateau,
            )

            runner.train(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                callbacks=callbacks,
                logdir=cfg.train.logdir,
                loaders=data_loaders,
                num_epochs=num_epochs,
                verbose=True,
                main_metric=cfg.train.main_metric,
                load_best_on_end=True,
                minimize_metric=False,
                check=cfg.check,
                fp16=dict(amp=cfg.amp),
            )

            # Set new initial LR for optimizer after restart
            scheduler_warm_restart.step()
            print(
                f"New LR for warm restart {scheduler_warm_restart.get_last_lr()[0]}"
            )

            # Find optimal threshold for dice score
            model.eval()
            best_th, dices = find_dice_threshold(model, data_loaders["valid"])
            print("Best dice threshold", best_th, np.max(dices[1]))
            np.save(f"dices_{size}.npy", dices)
    else:
        print("Validation only")
        # Datasets
        size = cfg.data.sizes[-1]
        train_ds, valid_ds = D.get_train_valid_datasets_from_path(
            # path=(cwd / cfg.data.path),
            path=(cwd / f"data/hubmap-{size}x{size}/"),
            train_ids=cfg.data.train_ids,
            valid_ids=cfg.data.valid_ids,
            seed=cfg.seed,
            valid_split=cfg.data.valid_split,
            mean=cfg.data.mean,
            std=cfg.data.std,
            transforms=transforms,
            stats=stats,
        )

        train_bs = int(cfg.loader.train_bs / (cfg.data.scale_factor**2))
        valid_bs = int(cfg.loader.valid_bs / (cfg.data.scale_factor**2))
        print(
            f"train: {len(train_ds)}; bs {train_bs}",
            f"valid: {len(valid_ds)}, bs {valid_bs}",
        )

        # Data loaders
        data_loaders = D.get_data_loaders(
            train_ds=train_ds,
            valid_ds=valid_ds,
            train_bs=train_bs,
            valid_bs=valid_bs,
            num_workers=cfg.loader.num_workers,
        )

        # Find optimal threshold for dice score
        model.eval()
        best_th, dices = find_dice_threshold(model, data_loaders["valid"])
        print("Best dice threshold", best_th, np.max(dices[1]))
        np.save(f"dices_val.npy", dices)

    #
    # # Load best checkpoint
    # checkpoint_path = Path(cfg.train.logdir) / "checkpoints/best.pth"
    # if checkpoint_path.exists():
    #     print(f"\nLoading checkpoint {str(checkpoint_path)}")
    #     state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[
    #         "model_state_dict"
    #     ]
    #     model.load_state_dict(state_dict)
    #     del state_dict
    # model = model.to(device)
    # Load config for updating with threshold and metric
    # (otherwise loading do not work)
    cfg = OmegaConf.load(".hydra/config.yaml")
    cfg.threshold = float(best_th)

    # Evaluate on full-size image if valid_ids is non-empty
    # NOTE(review): orient="record" was deprecated and later removed from
    # pandas; the canonical spelling is orient="records" — verify against
    # the pinned pandas version before upgrading.
    df_train = pd.read_csv(cwd / "data/train.csv")
    df_train = {
        r["id"]: r["encoding"]
        for r in df_train.to_dict(orient="record")
    }
    dices = []
    # Image ids are the filename prefix before the first underscore.
    unique_ids = sorted(
        set(
            str(p).split("/")[-1].split("_")[0]
            for p in (cwd / cfg.data.path / "train").iterdir()))
    size = cfg.data.sizes[-1]
    scale = size / 1024
    for image_id in cfg.data.valid_ids:
        image_name = unique_ids[image_id]
        print(f"\nValidate for {image_name}")

        # Tiled full-image inference; returns an RLE-encoded prediction.
        rle_pred, shape = inference_one(
            image_path=(cwd / f"data/train/{image_name}.tiff"),
            target_path=Path("."),
            cfg=cfg,
            model=model,
            scale_factor=scale,
            tile_size=cfg.data.tile_size,
            tile_step=cfg.data.tile_step,
            threshold=best_th,
            save_raw=False,
            tta_mode=None,
            weight="pyramid",
            device=device,
            filter_crops="tissue",
            stats=stats,
        )

        print("Predict", shape)
        pred = rle_decode(rle_pred["predicted"], shape)
        mask = rle_decode(df_train[image_name], shape)
        assert pred.shape == mask.shape, f"pred {pred.shape}, mask {mask.shape}"
        assert pred.shape == shape, f"pred {pred.shape}, expected {shape}"

        dices.append(
            dice(
                torch.from_numpy(pred).type(torch.uint8),
                torch.from_numpy(mask).type(torch.uint8),
                threshold=None,
                activation="none",
            ))
    print("Full image dice:", np.mean(dices))
    # Persist the tuned threshold back into the run's Hydra config.
    OmegaConf.save(cfg, ".hydra/config.yaml")
    return