def main(**config):
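    """Compute and log descriptive statistics for the selected dataset.

    Runs a Describer over both the 'trainval' and 'test' subsets.
    """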
    config['path_logs'] = os.path.join(config['path_experiment_root'],
                                       f"logs_{config['dataset']}_describe")

    os.makedirs(config['path_logs'], exist_ok=True)

    logging_fh = logging.FileHandler(
        os.path.join(config['path_logs'], 'main.log'))
    logging_fh.setLevel(logging.DEBUG)
    logger.addHandler(logging_fh)

    # Collect the available and specified sources
    sources = sources_from_path(
        path_data_root=config['path_data_root'],
        selection=config['dataset'],
        with_folds=False,
        seed_trainval_test=config['seed_trainval_test'])

    if config['dataset'] == 'oai_imo':
        from rocaseg.datasets import DatasetOAIiMoSagittal2d as DatasetSagittal2d
    elif config['dataset'] == 'okoa':
        from rocaseg.datasets import DatasetOKOASagittal2d as DatasetSagittal2d
    elif config['dataset'] == 'maknee':
        from rocaseg.datasets import DatasetMAKNEESagittal2d as DatasetSagittal2d
    else:
        raise ValueError(f"Unknown dataset: {config['dataset']}")

    for subset in ('trainval', 'test'):
        df = sources[config['dataset']][f"{subset}_df"]

        dataset = DatasetSagittal2d(df_meta=df,
                                    mask_mode=config['mask_mode'],
                                    name=subset,
                                    sample_mode=config['sample_mode'],
                                    transforms=[
                                        PercentileClippingAndToFloat(
                                            cut_min=10, cut_max=99),
                                        ToTensor()
                                    ])

        loader = DataLoader(dataset,
                            batch_size=config['batch_size'],
                            shuffle=False,
                            num_workers=config['num_workers'],
                            pin_memory=True,
                            drop_last=False)
        describer = Describer(config=config)

        describer.run(loader)
        dataset.describe()
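

# ---------------------------------------------------------------------------
# Cross-validation training entry point (presumably a separate script from
# the dataset-describe entry point above)
# ---------------------------------------------------------------------------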
def main(**config):
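    """Train the segmentation model on the requested cross-validation folds."""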
    config['path_data_root'] = os.path.abspath(config['path_data_root'])
    config['path_experiment_root'] = os.path.abspath(
        config['path_experiment_root'])

    config['path_weights'] = os.path.join(config['path_experiment_root'],
                                          'weights')
    config['path_logs'] = os.path.join(config['path_experiment_root'],
                                       'logs_train')
    os.makedirs(config['path_weights'], exist_ok=True)
    os.makedirs(config['path_logs'], exist_ok=True)

    logging_fh = logging.FileHandler(
        os.path.join(config['path_logs'], f"main_{config['fold_idx']}.log"))
    logging_fh.setLevel(logging.DEBUG)
    logger.addHandler(logging_fh)

    # Collect the available and specified sources
    sources = sources_from_path(
        path_data_root=config['path_data_root'],
        selection=('oai_imo', 'okoa', 'maknee'),
        with_folds=True,
        fold_num=config['fold_num'],
        seed_trainval_test=config['seed_trainval_test'])

    # Build a list of folds to run on
    if config['fold_idx'] == -1:
        fold_idcs = list(range(config['fold_num']))
    else:
        fold_idcs = [config['fold_idx']]
    fold_idcs = [i for i in fold_idcs if i not in config['fold_idx_ignore']]

    # Train each fold separately
    fold_scores = dict()

    # Use straightforward fold allocation strategy
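    # (zip pairs the k-th (train, val) split of each source, so all three
    # datasets advance through the folds together)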
    folds = list(
        zip(sources['oai_imo']['trainval_folds'],
            sources['okoa']['trainval_folds'],
            sources['maknee']['trainval_folds']))

    for fold_idx, idcs_subsets in enumerate(folds):
        if fold_idx not in fold_idcs:
            continue
        logger.info(f'Training fold {fold_idx}')

        for n, (train_idcs, val_idcs) in zip(('oai_imo', 'okoa', 'maknee'),
                                             idcs_subsets):
            sources[n]['train_idcs'] = train_idcs
            sources[n]['val_idcs'] = val_idcs
            sources[n]['train_df'] = sources[n]['trainval_df'].iloc[train_idcs]
            sources[n]['val_df'] = sources[n]['trainval_df'].iloc[val_idcs]

        for n, s in sources.items():
            logger.info(f"Made {n} train-val split; train/val sizes: "
                        f"{len(s['train_df'])}/{len(s['val_df'])}")

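        # The per-source datasets below share two transform pipelines (this
        # assumes the Dataset classes treat the passed list as read-only).
        # Note that the OAI iMorphics intensity statistics are used to
        # normalize all three sources at training time.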
        transforms_train = [
            PercentileClippingAndToFloat(cut_min=10, cut_max=99),
            CenterCrop(height=300, width=300),
            HorizontalFlip(prob=.5),
            GammaCorrection(gamma_range=(0.5, 1.5), prob=.5),
            OneOf([
                DualCompose([
                    Scale(ratio_range=(0.7, 0.8), prob=1.),
                    Scale(ratio_range=(1.5, 1.6), prob=1.),
                ]),
                NoTransform()
            ]),
            Crop(output_size=(300, 300)),
            BilateralFilter(d=5, sigma_color=50, sigma_space=50, prob=.3),
            Normalize(mean=0.252699, std=0.251142),
            ToTensor(),
        ]
        transforms_val = [
            PercentileClippingAndToFloat(cut_min=10, cut_max=99),
            CenterCrop(height=300, width=300),
            Normalize(mean=0.252699, std=0.251142),
            ToTensor(),
        ]

        datasets = defaultdict(dict)

        datasets['oai_imo']['train'] = DatasetOAIiMoSagittal2d(
            df_meta=sources['oai_imo']['train_df'],
            mask_mode=config['mask_mode'],
            sample_mode=config['sample_mode'],
            transforms=transforms_train)
        datasets['okoa']['train'] = DatasetOKOASagittal2d(
            df_meta=sources['okoa']['train_df'],
            mask_mode='background_femoral_unitibial',
            sample_mode=config['sample_mode'],
            transforms=transforms_train)
        datasets['maknee']['train'] = DatasetMAKNEESagittal2d(
            df_meta=sources['maknee']['train_df'],
            mask_mode='',
            sample_mode=config['sample_mode'],
            transforms=transforms_train)
        datasets['oai_imo']['val'] = DatasetOAIiMoSagittal2d(
            df_meta=sources['oai_imo']['val_df'],
            mask_mode=config['mask_mode'],
            sample_mode=config['sample_mode'],
            transforms=transforms_val)
        datasets['okoa']['val'] = DatasetOKOASagittal2d(
            df_meta=sources['okoa']['val_df'],
            mask_mode='background_femoral_unitibial',
            sample_mode=config['sample_mode'],
            transforms=transforms_val)
        datasets['maknee']['val'] = DatasetMAKNEESagittal2d(
            df_meta=sources['maknee']['val_df'],
            mask_mode='',
            sample_mode=config['sample_mode'],
            transforms=transforms_val)

        loaders = defaultdict(dict)

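        # Each per-source loader uses half of the configured batch size,
        # presumably so that ModelTrainer can combine samples from two
        # sources into one full-size batch during joint training.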
        for n in ('oai_imo', 'okoa', 'maknee'):
            loaders[n]['train'] = DataLoader(
                datasets[n]['train'],
                batch_size=config['batch_size'] // 2,
                shuffle=True,
                num_workers=config['num_workers'],
                drop_last=True)
            loaders[n]['val'] = DataLoader(
                datasets[n]['val'],
                batch_size=config['batch_size'] // 2,
                shuffle=False,
                num_workers=config['num_workers'],
                drop_last=True)

        trainer = ModelTrainer(config=config, fold_idx=fold_idx)

        (metrics_train, fnames_train,
         metrics_val, fnames_val) = trainer.fit(loaders=loaders)

        fold_scores[fold_idx] = (metrics_val['datasetw']['dice_score_oai'],
                                 metrics_val['datasetw']['dice_score_okoa'])
        trainer.tensorboard.close()
    logger.info(f'Fold scores:\n{repr(fold_scores)}')
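

# ---------------------------------------------------------------------------
# Inference/evaluation entry point (presumably a separate script)
# ---------------------------------------------------------------------------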
def main(**config):
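    """Run the trained per-fold models on the selected subset and,
    optionally, merge the per-fold predictions."""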
    config['path_data_root'] = os.path.abspath(config['path_data_root'])
    config['path_experiment_root'] = os.path.abspath(config['path_experiment_root'])

    config['path_weights'] = os.path.join(config['path_experiment_root'], 'weights')
    if not os.path.exists(config['path_weights']):
        raise ValueError('{} does not exist'.format(config['path_weights']))

    config['path_predicts'] = os.path.join(
        config['path_experiment_root'], f"predicts_{config['dataset']}_test")
    config['path_logs'] = os.path.join(
        config['path_experiment_root'], f"logs_{config['dataset']}_test")

    os.makedirs(config['path_predicts'], exist_ok=True)
    os.makedirs(config['path_logs'], exist_ok=True)

    logging_fh = logging.FileHandler(
        os.path.join(config['path_logs'], 'main.log'))
    logging_fh.setLevel(logging.DEBUG)
    logger.addHandler(logging_fh)

    # Collect the available and specified sources
    sources = sources_from_path(path_data_root=config['path_data_root'],
                                selection=config['dataset'],
                                with_folds=True,
                                seed_trainval_test=config['seed_trainval_test'])

    # Select the subset for evaluation
    if config['subset'] == 'test':
        logger.warning('Using the regular trainval-test split')
    elif config['subset'] == 'all':
        logger.warning('Using data selection: full dataset')
        for s in sources:
            sources[s]['test_df'] = sources[s]['sel_df']
            logger.info(f"Selected number of samples: {len(sources[s]['test_df'])}")
    else:
        raise ValueError(f"Unknown subset: {config['subset']}")

    if config['dataset'] == 'oai_imo':
        from rocaseg.datasets import DatasetOAIiMoSagittal2d as DatasetSagittal2d
    elif config['dataset'] == 'okoa':
        from rocaseg.datasets import DatasetOKOASagittal2d as DatasetSagittal2d
    elif config['dataset'] == 'maknee':
        from rocaseg.datasets import DatasetMAKNEESagittal2d as DatasetSagittal2d
    else:
        raise ValueError(f"Unknown dataset: {config['dataset']}")

    # Configure dataset-dependent transforms
    fn_crop = CenterCrop(height=300, width=300)
    if config['dataset'] == 'oai_imo':
        fn_norm = Normalize(mean=0.252699, std=0.251142)
        fn_unnorm = UnNormalize(mean=0.252699, std=0.251142)
    elif config['dataset'] == 'okoa':
        fn_norm = Normalize(mean=0.232454, std=0.236259)
        fn_unnorm = UnNormalize(mean=0.232454, std=0.236259)
    else:
        msg = f"No transforms defined for dataset: {config['dataset']}"
        raise NotImplementedError(msg)
    dict_fns = {'crop': fn_crop, 'norm': fn_norm, 'unnorm': fn_unnorm}
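    # `unnorm` inverts the intensity normalization; it is presumably used by
    # merge_predictions() when rendering prediction overlays.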

    dataset_test = DatasetSagittal2d(
        df_meta=sources[config['dataset']]['test_df'], mask_mode=config['mask_mode'],
        name=config['dataset'], sample_mode=config['sample_mode'],
        transforms=[
            PercentileClippingAndToFloat(cut_min=10, cut_max=99),
            fn_crop,
            fn_norm,
            ToTensor()
        ])
    loader_test = DataLoader(dataset_test,
                             batch_size=config['batch_size'],
                             shuffle=False,
                             num_workers=config['num_workers'],
                             drop_last=False)

    # Build a list of folds to run on
    if config['fold_idx'] == -1:
        fold_idcs = list(range(config['fold_num']))
    else:
        fold_idcs = [config['fold_idx']]
    fold_idcs = [i for i in fold_idcs if i not in config['fold_idx_ignore']]

    # Run inference without tracking gradients
    with torch.no_grad():
        if config['predict_folds']:
            predict_folds(config=config, loader=loader_test, fold_idcs=fold_idcs)

        if config['merge_predictions']:
            merge_predictions(config=config, source=sources[config['dataset']],
                              loader=loader_test, dict_fns=dict_fns,
                              save_plots=config['save_plots'], remove_foldw=False,
                              convert_to_nifti=True)
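

# A minimal invocation sketch for the evaluation entry point above. Every
# value below is a hypothetical placeholder (the real scripts most likely
# parse these settings from the command line); only the config keys are
# taken from the code.
if __name__ == '__main__':
    main(path_data_root='data',
         path_experiment_root='experiments/baseline',
         dataset='okoa',
         subset='test',
         mask_mode='background_femoral_unitibial',
         sample_mode='x_y',  # hypothetical value
         batch_size=32,
         num_workers=4,
         seed_trainval_test=0,
         fold_num=5,
         fold_idx=-1,  # -1 means: run on all folds
         fold_idx_ignore=(),
         predict_folds=True,
         merge_predictions=True,
         save_plots=False)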