def test_cls_models(model_name):
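    """Build a classification model, run a dummy forward pass and check the output head sizes."""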
    model = get_model(model_name=model_name, num_classes=4).eval()
    print(model_name, count_parameters(model))
    x = torch.rand((1, 3, 224, 224))
    output = model(x)
    assert output['logits'].size(1) == 4
    assert output['features'].size(1) == model.features_size

def model_fn(model_dir):
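    """Model-loading hook (SageMaker-style `model_fn`): restore the checkpointed model from
    `model_dir`, optionally wrap it with softmax and TTA, and move it to the available device."""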
    model_path = path.join(model_dir,
                           checkpoint_fname)  # '/opt/ml/model/model.pth'

    # load_checkpoint handles torch.load(model_path, map_location=lambda storage, loc: storage) internally
    checkpoint = load_checkpoint(model_path)
    params = checkpoint['checkpoint_data']['cmd_args']

    # Hard-coded model name; set to None to fall back to the model stored in the checkpoint args
    model_name = 'seresnext50d_gap'

    if model_name is None:
        model_name = params['model']

    coarse_grading = params.get('coarse', False)

    # apply_softmax and tta are not part of the model_fn(model_dir) signature, so they are
    # defined here (assumed defaults; tta falls back to the value stored in the checkpoint args)
    apply_softmax = True
    tta = params.get('tta', None)

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    CLASS_NAMES = get_class_names(coarse_grading=coarse_grading)
    num_classes = len(CLASS_NAMES)
    model = get_model(model_name, pretrained=False, num_classes=num_classes)
    unpack_checkpoint(checkpoint, model=model)
    report_checkpoint(checkpoint)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model = model.eval()

    if apply_softmax:
        model = nn.Sequential(model, ApplySoftmaxToLogits())

    if tta == 'flip' or tta == 'fliplr':
        model = FlipLRMultiheadTTA(model)

    if tta == 'flip4':
        model = Flip4MultiheadTTA(model)

    if tta == 'fliplr_ms':
        model = MultiscaleFlipLRMultiheadTTA(model)

    # Wrap with DataParallel when several GPUs are available (model is already on the right device)
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        model = nn.DataParallel(
            model, device_ids=list(range(torch.cuda.device_count())))

    return model
def test_inference():
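    """Run inference on a couple of sample images, before and after adding synthetic microaneurysms."""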
    model_checkpoint = '../pretrained/seresnext50_gap_512_medium_aptos2019_idrid_fold0_hopeful_easley.pth'
    checkpoint = torch.load(model_checkpoint)
    model_name = checkpoint['checkpoint_data']['cmd_args']['model']

    num_classes = len(get_class_names())
    model = get_model(model_name, pretrained=False, num_classes=num_classes)
    model.load_state_dict(checkpoint['model_state_dict'])

    for image_fname in [
            # '4_left.png',
            # '35_left.png',
            '44_right.png',
            '68_right.png',
            # '92_left.png'
    ]:
        transform = get_test_transform(image_size=(512, 512), crop_black=True)

        image = cv2.imread(image_fname)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image_transformed = transform(image=image)['image']
        image_transformed = tensor_from_rgb_image(image_transformed).unsqueeze(
            0)

        with torch.no_grad():
            model = model.eval().cuda()
            predictions = model(image_transformed.cuda())
            print(predictions['logits'].softmax(dim=1))
            print(predictions['regression'])

        add_mild_dr = AddMicroaneurisms(p=1)
        data = add_mild_dr(image=image, diagnosis=0)
        image_transformed = transform(image=data['image'])['image']
        image_transformed = tensor_from_rgb_image(image_transformed).unsqueeze(
            0)

        with torch.no_grad():
            model = model.eval().cuda()
            predictions = model(image_transformed.cuda())
            print(predictions['logits'].softmax(dim=1))
            print(predictions['regression'])
def main():
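    """Training entry point: parse CLI arguments and run the warmup, main and fine-tune stages for each fold."""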
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')
    parser.add_argument('--swa', action='store_true')
    parser.add_argument('--show', action='store_true')
    parser.add_argument('--use-idrid', action='store_true')
    parser.add_argument('--use-messidor', action='store_true')
    parser.add_argument('--use-aptos2015', action='store_true')
    parser.add_argument('--use-aptos2019', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--coarse', action='store_true')
    parser.add_argument('-acc',
                        '--accumulation-steps',
                        type=int,
                        default=1,
                        help='Number of batches to accumulate gradients over')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='data',
                        help='Data directory')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='resnet18_gap',
                        help='Model name')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=100,
                        help='Number of epochs to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-f',
                        '--fold',
                        action='append',
                        type=int,
                        default=None)
    parser.add_argument('-ft', '--fine-tune', default=0, type=int)
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--criterion-reg',
                        type=str,
                        default=None,
                        nargs='+',
                        help='Regression criterion(s)')
    parser.add_argument('--criterion-ord',
                        type=str,
                        default=None,
                        nargs='+',
                        help='Ordinal criterion(s)')
    parser.add_argument('--criterion-cls',
                        type=str,
                        default=['ce'],
                        nargs='+',
                        help='Classification criterion(s)')
    parser.add_argument('-l1',
                        type=float,
                        default=0,
                        help='L1 regularization coefficient')
    parser.add_argument('-l2',
                        type=float,
                        default=0,
                        help='L2 regularization coefficient')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument('-p',
                        '--preprocessing',
                        default=None,
                        help='Preprocessing method')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w',
                        '--workers',
                        default=multiprocessing.cpu_count(),
                        type=int,
                        help='Num workers')
    parser.add_argument('-a',
                        '--augmentations',
                        default='medium',
                        type=str,
                        help='Augmentation level')
    parser.add_argument('-tta',
                        '--tta',
                        default=None,
                        type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-t',
                        '--transfer',
                        default=None,
                        type=str,
                        help='Checkpoint to transfer model weights from')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-s',
                        '--scheduler',
                        default='multistep',
                        type=str,
                        help='Learning rate scheduler')
    parser.add_argument('--size',
                        default=512,
                        type=int,
                        help='Image size for training & inference')
    parser.add_argument('-wd',
                        '--weight-decay',
                        default=0,
                        type=float,
                        help='L2 weight decay')
    parser.add_argument('-wds',
                        '--weight-decay-step',
                        default=None,
                        type=float,
                        help='L2 weight decay step to add after each epoch')
    parser.add_argument('-d',
                        '--dropout',
                        default=0.0,
                        type=float,
                        help='Dropout before head layer')
    parser.add_argument(
        '--warmup',
        default=0,
        type=int,
        help=
        'Number of warmup epochs with 0.1 of the initial LR and a frozen encoder'
    )
    parser.add_argument('-x',
                        '--experiment',
                        default=None,
                        type=str,
                        help='Experiment name to use as the checkpoint prefix')

    args = parser.parse_args()

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    l1 = args.l1
    l2 = args.l2
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    fine_tune = args.fine_tune
    criterion_reg_name = args.criterion_reg
    criterion_cls_name = args.criterion_cls
    criterion_ord_name = args.criterion_ord
    folds = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    use_swa = args.swa
    show_batches = args.show
    scheduler_name = args.scheduler
    verbose = args.verbose
    weight_decay = args.weight_decay
    use_idrid = args.use_idrid
    use_messidor = args.use_messidor
    use_aptos2015 = args.use_aptos2015
    use_aptos2019 = args.use_aptos2019
    warmup = args.warmup
    dropout = args.dropout
    use_unsupervised = False
    experiment = args.experiment
    preprocessing = args.preprocessing
    weight_decay_step = args.weight_decay_step
    coarse_grading = args.coarse
    class_names = get_class_names(coarse_grading)

    assert use_aptos2015 or use_aptos2019 or use_idrid or use_messidor

    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()

    if folds is None or len(folds) == 0:
        folds = [None]

    for fold in folds:
        torch.cuda.empty_cache()
        checkpoint_prefix = f'{model_name}_{args.size}_{augmentations}'

        if preprocessing is not None:
            checkpoint_prefix += f'_{preprocessing}'
        if use_aptos2019:
            checkpoint_prefix += '_aptos2019'
        if use_aptos2015:
            checkpoint_prefix += '_aptos2015'
        if use_messidor:
            checkpoint_prefix += '_messidor'
        if use_idrid:
            checkpoint_prefix += '_idrid'
        if coarse_grading:
            checkpoint_prefix += '_coarse'

        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'

        checkpoint_prefix += f'_{random_name}'

        if experiment is not None:
            checkpoint_prefix = experiment

        directory_prefix = f'{current_time}/{checkpoint_prefix}'
        log_dir = os.path.join('runs', directory_prefix)
        os.makedirs(log_dir, exist_ok=False)

        config_fname = os.path.join(log_dir, f'{checkpoint_prefix}.json')
        with open(config_fname, 'w') as f:
            train_session_args = vars(args)
            f.write(json.dumps(train_session_args, indent=2))

        set_manual_seed(args.seed)
        num_classes = len(class_names)
        model = get_model(model_name, num_classes=num_classes,
                          dropout=dropout).cuda()

        if args.transfer:
            transfer_checkpoint = fs.auto_file(args.transfer)
            print("Transfering weights from model checkpoint",
                  transfer_checkpoint)
            checkpoint = load_checkpoint(transfer_checkpoint)
            pretrained_dict = checkpoint['model_state_dict']

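            # Copy weights tensor by tensor so that mismatched keys/shapes are skipped instead of aborting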
            for name, value in pretrained_dict.items():
                try:
                    model.load_state_dict(collections.OrderedDict([(name,
                                                                    value)]),
                                          strict=False)
                except Exception as e:
                    print(e)

            report_checkpoint(checkpoint)

        if args.checkpoint:
            checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_aptos2019=use_aptos2019,
            use_aptos2015=use_aptos2015,
            use_idrid=use_idrid,
            use_messidor=use_messidor,
            use_unsupervised=False,
            coarse_grading=coarse_grading,
            image_size=image_size,
            augmentation=augmentations,
            preprocessing=preprocessing,
            target_dtype=int,
            fold=fold,
            folds=4)

        train_loader, valid_loader = get_dataloaders(
            train_ds,
            valid_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            train_sizes=train_sizes,
            balance=balance,
            balance_datasets=balance_datasets,
            balance_unlabeled=False)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        print('Datasets         :', data_dir)
        print('  Train size     :', len(train_loader),
              len(train_loader.dataset))
        print('  Valid size     :', len(valid_loader),
              len(valid_loader.dataset))
        print('  Aptos 2019     :', use_aptos2019)
        print('  Aptos 2015     :', use_aptos2015)
        print('  IDRID          :', use_idrid)
        print('  Messidor       :', use_messidor)
        print('Train session    :', directory_prefix)
        print('  FP16 mode      :', fp16)
        print('  Fast mode      :', fast)
        print('  Mixup          :', mixup)
        print('  Balance cls.   :', balance)
        print('  Balance ds.    :', balance_datasets)
        print('  Warmup epoch   :', warmup)
        print('  Train epochs   :', num_epochs)
        print('  Fine-tune ephs :', fine_tune)
        print('  Workers        :', num_workers)
        print('  Fold           :', fold)
        print('  Log dir        :', log_dir)
        print('  Augmentations  :', augmentations)
        print('Model            :', model_name)
        print('  Parameters     :', count_parameters(model))
        print('  Image size     :', image_size)
        print('  Dropout        :', dropout)
        print('  Classes        :', class_names, num_classes)
        print('Optimizer        :', optimizer_name)
        print('  Learning rate  :', learning_rate)
        print('  Batch size     :', batch_size)
        print('  Criterion (cls):', criterion_cls_name)
        print('  Criterion (reg):', criterion_reg_name)
        print('  Criterion (ord):', criterion_ord_name)
        print('  Scheduler      :', scheduler_name)
        print('  Weight decay   :', weight_decay, weight_decay_step)
        print('  L1 reg.        :', l1)
        print('  L2 reg.        :', l2)
        print('  Early stopping :', early_stopping)

        # model training
        callbacks = []
        criterions = {}

        main_metric = 'cls/kappa'
        if criterion_reg_name is not None:
            cb, crits = get_reg_callbacks(criterion_reg_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_ord_name is not None:
            cb, crits = get_ord_callbacks(criterion_ord_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_cls_name is not None:
            cb, crits = get_cls_callbacks(criterion_cls_name,
                                          num_classes=num_classes,
                                          num_epochs=num_epochs,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if l1 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l1,
                                         end_wd=l1,
                                         schedule=None,
                                         prefix='l1',
                                         p=1)
            ]

        if l2 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l2,
                                         end_wd=l2,
                                         schedule=None,
                                         prefix='l2',
                                         p=2)
            ]

        callbacks += [CustomOptimizerCallback()]

        runner = SupervisedRunner(input_key='image')

        # Pretrain/warmup
        if warmup:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer('Adam',
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate * 0.1)

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=None,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'warmup'),
                         num_epochs=warmup,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer

        # Main train
        if num_epochs:
            set_trainable(model.encoder, True, False)

            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate,
                                      weight_decay=weight_decay)

            if use_swa:
                from torchcontrib.optim import SWA
                optimizer = SWA(optimizer,
                                swa_start=len(train_loader),
                                swa_freq=512)

            scheduler = get_scheduler(scheduler_name,
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=num_epochs,
                                      batches_in_epoch=len(train_loader))

            # Callbacks specific to the main stage are added to a copy of the base callbacks
            main_stage_callbacks = callbacks
            if early_stopping:
                es_callback = EarlyStoppingCallback(early_stopping,
                                                    min_delta=1e-4,
                                                    metric=main_metric,
                                                    minimize=False)
                main_stage_callbacks = callbacks + [es_callback]

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=main_stage_callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'main'),
                         num_epochs=num_epochs,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer, scheduler

            best_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)

            # Restoring best model from checkpoint
            checkpoint = load_checkpoint(best_checkpoint)
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        # Stage 3 - Fine tuning
        if fine_tune:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate)
            scheduler = get_scheduler('multistep',
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=fine_tune,
                                      batches_in_epoch=len(train_loader))

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'finetune'),
                         num_epochs=fine_tune,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            best_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)
def run_models_inference_via_dataset(model_checkpoints: List[str],
                                     dataset: RetinopathyDataset,
                                     batch_size=1,
                                     coarse_grading=False,
                                     tta=None,
                                     need_features=True,
                                     apply_softmax=True,
                                     workers=None) -> List[pd.DataFrame]:
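    """Run inference with several checkpoints over one dataset; returns one prediction DataFrame per model."""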
    if workers is None:
        workers = multiprocessing.cpu_count()

    models = []
    models_predictions = []

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load models
    for model_checkpoint in model_checkpoints:
        checkpoint = torch.load(model_checkpoint)

        model_name = checkpoint['checkpoint_data']['cmd_args']['model']

        print(model_checkpoint, model_name)
        report_checkpoint(checkpoint)

        num_classes = len(get_class_names(coarse_grading=coarse_grading))
        model = get_model(model_name,
                          pretrained=False,
                          num_classes=num_classes)
        model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        del checkpoint

        if apply_softmax:
            model = nn.Sequential(model, ApplySoftmaxToLogits())

        if tta == 'flip' or tta == 'fliplr':
            model = FlipLRMultiheadTTA(model)

        if tta == 'flip4':
            model = Flip4MultiheadTTA(model)

        if tta == 'fliplr_ms':
            model = MultiscaleFlipLRMultiheadTTA(model)

        model = model.cuda()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(
                model, device_ids=list(range(torch.cuda.device_count())))
        model = model.eval()

        models.append(model)
        models_predictions.append(defaultdict(list))

    data_loader = DataLoader(dataset,
                             batch_size,
                             pin_memory=True,
                             num_workers=workers)

    for batch in tqdm(data_loader):
        input = batch['image'].cuda(non_blocking=True)

        for model, predictions in zip(models, models_predictions):
            # Run the forward pass without building the autograd graph
            with torch.no_grad():
                outputs = model(input)

            predictions['image_id'].extend(batch['image_id'])
            if 'targets' in batch:
                predictions['diagnosis'].extend(
                    to_numpy(batch['targets']).tolist())

            predictions['logits'].extend(to_numpy(outputs['logits']).tolist())
            predictions['regression'].extend(
                to_numpy(outputs['regression']).tolist())
            predictions['ordinal'].extend(
                to_numpy(outputs['ordinal']).tolist())
            if need_features:
                predictions['features'].extend(
                    to_numpy(outputs['features']).tolist())

    models_predictions = [
        pd.DataFrame.from_dict(p) for p in models_predictions
    ]

    del data_loader, models
    return models_predictions
def run_model_inference_via_dataset(dataset: RetinopathyDataset,
                                    checkpoint,
                                    params,
                                    model_name=None,
                                    batch_size=None,
                                    tta=None,
                                    need_features=True,
                                    apply_softmax=True,
                                    workers=None) -> pd.DataFrame:
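    """Run inference with a single checkpoint over a dataset and return the predictions as a DataFrame."""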
    if workers is None:
        workers = multiprocessing.cpu_count()

    report_checkpoint(checkpoint)

    if model_name is None:
        model_name = params['model']

    if batch_size is None:
        batch_size = params.get('batch_size', 1)

    coarse_grading = params.get('coarse', False)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    num_classes = len(get_class_names(coarse_grading=coarse_grading))
    model = get_model(model_name, pretrained=False, num_classes=num_classes)
    model.load_state_dict(checkpoint['model_state_dict'], strict=True)

    if apply_softmax:
        model = nn.Sequential(model, ApplySoftmaxToLogits())

    if tta == 'flip' or tta == 'fliplr':
        model = FlipLRMultiheadTTA(model)

    if tta == 'flip4':
        model = Flip4MultiheadTTA(model)

    if tta == 'fliplr_ms':
        model = MultiscaleFlipLRMultiheadTTA(model)

    with torch.no_grad():
        model = model.cuda()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(
                model, device_ids=list(range(torch.cuda.device_count())))
        model = model.eval()

        data_loader = DataLoader(dataset,
                                 batch_size,
                                 pin_memory=True,
                                 num_workers=workers)

        predictions = defaultdict(list)

        for batch in tqdm(data_loader):
            input = batch['image'].cuda(non_blocking=True)
            outputs = model(input)

            predictions['image_id'].extend(batch['image_id'])
            if 'targets' in batch:
                predictions['diagnosis'].extend(
                    to_numpy(batch['targets']).tolist())

            predictions['logits'].extend(to_numpy(outputs['logits']).tolist())
            predictions['regression'].extend(
                to_numpy(outputs['regression']).tolist())
            predictions['ordinal'].extend(
                to_numpy(outputs['ordinal']).tolist())
            if need_features:
                predictions['features'].extend(
                    to_numpy(outputs['features']).tolist())

        predictions = pd.DataFrame.from_dict(predictions)

    del data_loader, model
    return predictions
def reg_resnet50_rms(**kwargs):
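    """Factory for the `reg_resnet50_rms` model with 5 output classes."""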
    return get_model('reg_resnet50_rms', num_classes=5, **kwargs)
def evaluate_generalization(checkpoints, num_folds=4):
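    """Cross-evaluate each per-dataset checkpoint on every dataset/fold and pickle the predictions."""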
    num_datasets = len(checkpoints)
    # kappa_matrix = np.zeros((num_datasets, num_datasets), dtype=np.float32)
    class_names = list(checkpoints.keys())

    # results = {}

    for dataset_trained_on, checkpoints_per_fold in checkpoints.items():
        # For each dataset trained on

        for fold_trained_on, checkpoint_file in enumerate(
                checkpoints_per_fold):
            # For each checkpoint
            if checkpoint_file is None:
                continue

            # Load model
            checkpoint = torch.load(checkpoint_file)
            model_name = checkpoint['checkpoint_data']['cmd_args']['model']
            batch_size = 16  # checkpoint['checkpoint_data']['cmd_args']['batch_size']
            num_classes = len(get_class_names())
            model = get_model(model_name,
                              pretrained=False,
                              num_classes=num_classes)
            model.load_state_dict(checkpoint['model_state_dict'])
            model = model.eval().cuda()
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(
                    model, device_ids=list(range(torch.cuda.device_count())))

            for dataset_index, dataset_validate_on in enumerate(class_names):
                # For each available dataset

                for fold_validate_on in range(num_folds):
                    _, valid_ds, _ = get_datasets(
                        use_aptos2015=dataset_validate_on == 'aptos2015',
                        use_aptos2019=dataset_validate_on == 'aptos2019',
                        use_messidor=dataset_validate_on == 'messidor',
                        use_idrid=dataset_validate_on == 'idrid',
                        fold=fold_validate_on,
                        folds=num_folds)

                    data_loader = DataLoader(valid_ds,
                                             batch_size *
                                             torch.cuda.device_count(),
                                             pin_memory=True,
                                             num_workers=8)

                    predictions = defaultdict(list)
                    for batch in tqdm(
                            data_loader,
                            desc=
                            f'Evaluating {dataset_validate_on} fold {fold_validate_on} on {checkpoint_file}'
                    ):
                        input = batch['image'].cuda(non_blocking=True)
                        # Run the forward pass without building the autograd graph
                        with torch.no_grad():
                            outputs = model(input)
                        logits = to_numpy(outputs['logits'].softmax(dim=1))
                        regression = to_numpy(outputs['regression'])
                        features = to_numpy(outputs['features'])

                        predictions['image_id'].extend(batch['image_id'])
                        predictions['diagnosis_true'].extend(
                            to_numpy(batch['targets']))
                        predictions['logits'].extend(logits)
                        predictions['regression'].extend(regression)
                        predictions['features'].extend(features)

                    pickle_name = id_from_fname(
                        checkpoint_file
                    ) + f'_on_{dataset_validate_on}_fold{fold_validate_on}.pkl'

                    df = pd.DataFrame.from_dict(predictions)
                    df.to_pickle(pickle_name)
def main():
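    """Adversarial-validation entry point: train a train-vs-test classifier and score training images by test-likeness."""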
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='data',
                        help='Data directory')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='cls_resnet18',
                        help='Model name')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=100,
                        help='Number of epochs to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-fe', '--freeze-encoder', action='store_true')
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('-l',
                        '--criterion',
                        type=str,
                        default='bce',
                        help='Criterion')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w',
                        '--workers',
                        default=multiprocessing.cpu_count(),
                        type=int,
                        help='Num workers')
    parser.add_argument('-a',
                        '--augmentations',
                        default='hard',
                        type=str,
                        help='Augmentation level')
    parser.add_argument('-tta',
                        '--tta',
                        default=None,
                        type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-tm',
                        '--train-mode',
                        default='random',
                        type=str,
                        help='')
    parser.add_argument('-rm',
                        '--run-mode',
                        default='fit_predict',
                        type=str,
                        help='Run mode: fit, predict or fit_predict')
    parser.add_argument('--transfer',
                        default=None,
                        type=str,
                        help='Checkpoint to transfer model weights from')
    parser.add_argument('--fp16', action='store_true')

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    run_mode = args.run_mode
    log_dir = None
    fp16 = args.fp16
    freeze_encoder = args.freeze_encoder

    run_train = run_mode == 'fit_predict' or run_mode == 'fit'
    run_predict = run_mode == 'fit_predict' or run_mode == 'predict'

    model = maybe_cuda(get_model(model_name, num_classes=1))

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint['model_state_dict']

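        # Copy weights tensor by tensor so that mismatched keys/shapes are skipped instead of aborting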
        for name, value in pretrained_dict.items():
            try:
                model.load_state_dict(collections.OrderedDict([(name, value)]),
                                      strict=False)
            except Exception as e:
                print(e)

    checkpoint = None
    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        checkpoint_epoch = checkpoint['epoch']
        print('Loaded model weights from:', args.checkpoint)
        print('Epoch                    :', checkpoint_epoch)
        print('Metrics (Train):', 'f1  :',
              checkpoint['epoch_metrics']['train']['f1_score'], 'loss:',
              checkpoint['epoch_metrics']['train']['loss'])
        print('Metrics (Valid):', 'f1  :',
              checkpoint['epoch_metrics']['valid']['f1_score'], 'loss:',
              checkpoint['epoch_metrics']['valid']['loss'])

        log_dir = os.path.dirname(
            os.path.dirname(fs.auto_file(args.checkpoint)))

    if run_train:

        if freeze_encoder:
            set_trainable(model.encoder, trainable=False, freeze_bn=True)

        criterion = get_loss(args.criterion)
        parameters = get_optimizable_parameters(model)
        optimizer = get_optimizer(optimizer_name, parameters, learning_rate)

        if checkpoint is not None:
            try:
                unpack_checkpoint(checkpoint, optimizer=optimizer)
                print('Restored optimizer state from checkpoint')
            except Exception as e:
                print('Failed to restore optimizer state from checkpoint', e)

        train_loader, valid_loader = get_dataloaders(
            data_dir=data_dir,
            batch_size=batch_size,
            num_workers=num_workers,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        current_time = datetime.now().strftime('%b%d_%H_%M')
        prefix = f'adversarial/{args.model}/{current_time}_{args.criterion}'

        if fp16:
            prefix += '_fp16'

        if fast:
            prefix += '_fast'

        log_dir = os.path.join('runs', prefix)
        os.makedirs(log_dir, exist_ok=False)

        scheduler = MultiStepLR(optimizer,
                                milestones=[10, 30, 50, 70, 90],
                                gamma=0.5)

        print('Train session    :', prefix)
        print('\tFP16 mode      :', fp16)
        print('\tFast mode      :', args.fast)
        print('\tTrain mode     :', train_mode)
        print('\tEpochs         :', num_epochs)
        print('\tEarly stopping :', early_stopping)
        print('\tWorkers        :', num_workers)
        print('\tData dir       :', data_dir)
        print('\tLog dir        :', log_dir)
        print('\tAugmentations  :', augmentations)
        print('\tTrain size     :', len(train_loader),
              len(train_loader.dataset))
        print('\tValid size     :', len(valid_loader),
              len(valid_loader.dataset))
        print('Model            :', model_name)
        print('\tParameters     :', count_parameters(model))
        print('\tImage size     :', image_size)
        print('\tFreeze encoder :', freeze_encoder)
        print('Optimizer        :', optimizer_name)
        print('\tLearning rate  :', learning_rate)
        print('\tBatch size     :', batch_size)
        print('\tCriterion      :', args.criterion)

        # model training
        visualization_fn = partial(draw_classification_predictions,
                                   class_names=['Train', 'Test'])

        callbacks = [
            F1ScoreCallback(),
            AUCCallback(),
            ShowPolarBatchesCallback(visualization_fn,
                                     metric='f1_score',
                                     minimize=False),
        ]

        if early_stopping:
            callbacks += [
                EarlyStoppingCallback(early_stopping,
                                      metric='auc',
                                      minimize=False)
            ]

        runner = SupervisedRunner(input_key='image')
        runner.train(fp16=fp16,
                     model=model,
                     criterion=criterion,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     callbacks=callbacks,
                     loaders=loaders,
                     logdir=log_dir,
                     num_epochs=num_epochs,
                     verbose=True,
                     main_metric='auc',
                     minimize_metric=False,
                     state_kwargs={"cmd_args": vars(args)})

    if run_predict and not fast:
        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = load_checkpoint(
            fs.auto_file('best.pth', where=log_dir))
        unpack_checkpoint(best_checkpoint, model=model)

        model.eval()

        train_csv = pd.read_csv(os.path.join(data_dir, 'train.csv'))
        train_csv['id_code'] = train_csv['id_code'].apply(
            lambda x: os.path.join(data_dir, 'train_images', f'{x}.png'))
        test_ds = RetinopathyDataset(train_csv['id_code'],
                                     None,
                                     get_test_aug(image_size),
                                     target_as_array=True)
        test_dl = DataLoader(test_ds,
                             batch_size,
                             pin_memory=True,
                             num_workers=num_workers)

        test_ids = []
        test_preds = []

        for batch in tqdm(test_dl, desc='Inference'):
            input = batch['image'].cuda()
            # Run the forward pass under no_grad (a bare torch.no_grad() call has no effect)
            with torch.no_grad():
                outputs = model(input)
            predictions = to_numpy(outputs['logits'].sigmoid().squeeze(1))
            test_ids.extend(batch['image_id'])
            test_preds.extend(predictions)

        df = pd.DataFrame.from_dict({
            'id_code': test_ids,
            'is_test': test_preds
        })
        df.to_csv(os.path.join(log_dir, 'test_in_train.csv'), index=False)