Example #1
def test_experiment(experiment_cls, config, network_cls, len_test, **kwargs):
    assert issubclass(experiment_cls, BaseExperiment)

    exp = experiment_cls(config, network_cls, **kwargs)

    dset_test = DummyDataset(len_test)
    dmgr_test = BaseDataManager(dset_test, 16, 1, None)

    model = network_cls()

    return exp.test(model, dmgr_test, config.nested_get("metrics", {}),
                    kwargs.get("metric_keys", None))
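
A hedged sketch of how this helper might be invoked. The Parameters layout mirrors Example #13, and DummyNetwork/DummyDataset stand for the dummy classes defined inline in Examples #3 and #13; all values are illustrative.

# Illustrative only: parameter layout follows Example #13.
import torch
from delira.training import PyTorchExperiment, Parameters

params = Parameters(fixed_params={
    "model": {},
    "training": {
        "criterions": {"CE": torch.nn.CrossEntropyLoss()},
        "optimizer_cls": torch.optim.Adam,
        "optimizer_params": {"lr": 1e-3},
        "num_epochs": 2,
        "metrics": {},  # looked up via config.nested_get("metrics", {})
        "lr_sched_cls": None,
        "lr_sched_params": {}
    }})

# run the helper against a dummy network with 50 test samples
outputs = test_experiment(PyTorchExperiment, params, DummyNetwork,
                          len_test=50)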
Example #2
    def test_experiment_run_torch(self):

        from delira.training import PyTorchExperiment
        from delira.data_loading import BaseDataManager

        for case in self._test_cases_torch:
            with self.subTest(case=case):

                (params, dataset_length_train, dataset_length_test,
                 val_score_key, val_score_mode, network_cls) = case

                exp = PyTorchExperiment(params, network_cls,
                                        key_mapping={"x": "data"},
                                        val_score_key=val_score_key,
                                        val_score_mode=val_score_mode)

                dset_train = DummyDataset(dataset_length_train)
                dset_test = DummyDataset(dataset_length_test)

                dmgr_train = BaseDataManager(dset_train, 16, 2, None)
                dmgr_test = BaseDataManager(dset_test, 16, 1, None)

                exp.run(dmgr_train, dmgr_test)
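
For reference, a sketch of what a single entry of self._test_cases_torch might contain, derived from the tuple unpacking above; the concrete values are placeholders.

# Hypothetical test-case tuple matching the unpacking in the loop above
# (assigned inside setUp).
self._test_cases_torch = [
    (params,        # a delira Parameters object (cf. Example #13)
     500,           # dataset_length_train
     50,            # dataset_length_test
     "val_CE",      # val_score_key (placeholder metric name)
     "lowest",      # val_score_mode
     DummyNetwork)  # network_cls
]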
Example #3
def test_experiment(params, dataset_length_train, dataset_length_test):
    class DummyNetwork(ClassificationNetworkBasePyTorch):
        def __init__(self):
            super().__init__(32, 1)

        def forward(self, x):
            return self.module(x)

        @staticmethod
        def _build_model(in_channels, n_outputs):
            return torch.nn.Sequential(torch.nn.Linear(in_channels, 64),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(64, n_outputs))

    class DummyDataset(AbstractDataset):
        def __init__(self, length):
            super().__init__(None, None, None, None)
            self.length = length

        def __getitem__(self, index):
            return {
                "data": np.random.rand(1, 32),
                "label": np.random.rand(1, 1)
            }

        def __len__(self):
            return self.length

    exp = PyTorchExperiment(params, DummyNetwork)
    dset_train = DummyDataset(dataset_length_train)
    dset_test = DummyDataset(dataset_length_test)

    dmgr_train = BaseDataManager(dset_train, 16, 4, None)
    dmgr_test = BaseDataManager(dset_test, 16, 1, None)

    exp.run(dmgr_train, dmgr_test)
Example #4
    def test_experiment_test_tf(self):
        from delira.training import TfExperiment
        from delira.data_loading import BaseDataManager

        for case in self._test_cases_tf:
            with self.subTest(case=case):
                (params, dataset_length_train, dataset_length_test,
                 network_cls) = case

                exp = TfExperiment(params, network_cls,
                                   key_mapping={"images": "data"},
                                   )

                model = network_cls()

                dset_test = DummyDataset(dataset_length_test)
                dmgr_test = BaseDataManager(dset_test, 16, 1, None)

                exp.test(model, dmgr_test, params.nested_get("val_metrics"))
Example #5
    def t_experiment(self, raise_error=False, expected_msg=None):
        dummy_exp = DummyExperiment()
        dummy_exp = self.messenger_cls(dummy_exp, **self.messenger_kwargs)

        dset_test = DummyDataset(10)
        dmgr_test = BaseDataManager(dset_test, 2, 1, None)

        model = DummyNetwork()

        with self.assertLogs(logger, level='INFO') as cm:
            if raise_error:
                with self.assertRaises(RuntimeError):
                    dummy_exp.test(model, dmgr_test, {}, raise_error=True)
            else:
                dummy_exp.test(model, dmgr_test, {}, raise_error=False)

            if not expected_msg:  # covers both None and empty
                logger.info("NoExpectedMessage")

        if not expected_msg:  # covers both None and empty
            self.assertEqual(cm.output,
                             ["INFO:UnitTestMessenger:NoExpectedMessage"])
        else:
            self.assertEqual(cm.output, expected_msg)
Example #6
    def setUp(self):

        self.dset = DummyDataset(20)
        self.dmgr = BaseDataManager(self.dset, 4, 1, transforms=None)
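
Several of these tests reference DummyDataset without defining it; a minimal sketch, modeled on the inline definitions in Examples #3 and #13:

import numpy as np
from delira.data_loading import AbstractDataset

class DummyDataset(AbstractDataset):
    """Dummy dataset yielding a random feature vector and label per item."""

    def __init__(self, length):
        super().__init__(None, None, None, None)
        self.length = length

    def __getitem__(self, index):
        return {"data": np.random.rand(32),
                "label": np.random.randint(0, 1, 1)}

    def __len__(self):
        return self.length

    def get_sample_from_index(self, index):
        return self.__getitem__(index)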
Example #7
def train_shapenet():
    """
    Trains a single shapenet with config file from comandline arguments

    See Also
    --------
    :class:`delira.training.PyTorchNetworkTrainer`
    
    """
    import logging
    import numpy as np
    import torch
    from shapedata.single_shape import SingleShapeDataset
    from delira.training import PyTorchNetworkTrainer
    from ..utils import Config
    from ..layer import HomogeneousShapeLayer
    from ..networks import SingleShapeNetwork
    from delira.logging import TrixiHandler
    from trixi.logger import PytorchVisdomLogger
    from delira.training.callbacks import ReduceLROnPlateauCallbackPyTorch
    from delira.data_loading import BaseDataManager, RandomSampler, \
        SequentialSampler
    import os
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-c",
                        "--config",
                        type=str,
                        help="Path to configuration file")
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args()
    config = Config()

    config_dict = config(os.path.abspath(args.config))

    shapes = np.load(
        os.path.abspath(config_dict["layer"].pop("pca_path")))["shapes"]
    shapes = shapes[:config_dict["layer"].pop("num_shape_params") + 1]

    # layer_cls = HomogeneousShapeLayer

    net = SingleShapeNetwork(
        HomogeneousShapeLayer,
        {"shapes": shapes, **config_dict["layer"]},
        img_size=config_dict["data"]["img_size"],
        **config_dict["network"])

    num_params = sum(param.numel() for param in net.parameters())

    if args.verbose:
        print("Number of Parameters: %d" % num_params)

    criterions = {"L1": torch.nn.L1Loss()}
    metrics = {"MSE": torch.nn.MSELoss()}

    mixed_prec = config_dict["training"].pop("mixed_prec", False)

    config_dict["training"]["save_path"] = os.path.abspath(
        config_dict["training"]["save_path"])

    trainer = PyTorchNetworkTrainer(
        net,
        criterions=criterions,
        metrics=metrics,
        lr_scheduler_cls=ReduceLROnPlateauCallbackPyTorch,
        lr_scheduler_params=config_dict["scheduler"],
        optimizer_cls=torch.optim.Adam,
        optimizer_params=config_dict["optimizer"],
        mixed_precision=mixed_prec,
        **config_dict["training"])

    if args.verbose:
        print(trainer.input_device)

        print("Load Data")
    dset_train = SingleShapeDataset(
        os.path.abspath(config_dict["data"]["train_path"]),
        config_dict["data"]["img_size"],
        config_dict["data"]["crop"],
        config_dict["data"]["landmark_extension_train"],
        cached=config_dict["data"]["cached"],
        rotate=config_dict["data"]["rotate_train"],
        random_offset=config_dict["data"]["offset_train"])

    if config_dict["data"]["test_path"]:
        dset_val = SingleShapeDataset(
            os.path.abspath(config_dict["data"]["test_path"]),
            config_dict["data"]["img_size"],
            config_dict["data"]["crop"],
            config_dict["data"]["landmark_extension_test"],
            cached=config_dict["data"]["cached"],
            rotate=config_dict["data"]["rotate_test"],
            random_offset=config_dict["data"]["offset_test"])

    else:
        dset_val = None

    mgr_train = BaseDataManager(
        dset_train,
        batch_size=config_dict["data"]["batch_size"],
        n_process_augmentation=config_dict["data"]["num_workers"],
        transforms=None,
        sampler_cls=RandomSampler)
    mgr_val = BaseDataManager(
        dset_val,
        batch_size=config_dict["data"]["batch_size"],
        n_process_augmentation=config_dict["data"]["num_workers"],
        transforms=None,
        sampler_cls=SequentialSampler)

    if args.verbose:
        print("Data loaded")
    if config_dict["logging"].pop("enable", False):
        logger_cls = PytorchVisdomLogger

        logging.basicConfig(
            level=logging.INFO,
            handlers=[TrixiHandler(logger_cls, **config_dict["logging"])])

    else:
        logging.basicConfig(level=logging.INFO,
                            handlers=[logging.NullHandler()])

    logger = logging.getLogger("Test Logger")
    logger.info("Start Training")

    if args.verbose:
        print("Start Training")

    trainer.train(config_dict["training"]["num_epochs"],
                  mgr_train,
                  mgr_val,
                  config_dict["training"]["val_score_key"],
                  val_score_mode='lowest')
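
The script above reads a nested config; the on-disk format depends on the project's Config class, but a dict view of the keys it actually accesses might look like the following sketch. All values are placeholders.

# Keys taken from the accesses in train_shapenet(); values are illustrative.
config_dict = {
    "layer": {"pca_path": "shapes.npz",      # popped, loaded via np.load
              "num_shape_params": 25},       # popped; remaining keys go to
                                             # the shape layer
    "network": {},                           # extra SingleShapeNetwork kwargs
    "data": {"img_size": 224, "crop": 0.1,
             "batch_size": 8, "num_workers": 4,
             "train_path": "data/train", "test_path": "data/test",
             "landmark_extension_train": ".txt",
             "landmark_extension_test": ".txt",
             "cached": False,
             "rotate_train": 90, "rotate_test": 0,
             "offset_train": 10, "offset_test": 0},
    "optimizer": {"lr": 1e-4},
    "scheduler": {"mode": "min"},
    "training": {"save_path": "./checkpoints", "num_epochs": 100,
                 "val_score_key": "val_L1", "mixed_prec": False},
    "logging": {"enable": False},            # extra keys go to the logger
}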
Example #8
    def test_experiment_kfold_torch(self):
        from delira.training import PyTorchExperiment
        from delira.data_loading import BaseDataManager
        from copy import deepcopy

        # all test cases
        for case in self._test_cases_torch:
            with self.subTest(case=case):
                (params, dataset_length_train,
                 dataset_length_test, val_score_key,
                 val_score_mode, network_cls) = case

                # all three split types
                for split_type in ["random", "stratified", "error"]:
                    with self.subTest(split_type=split_type):
                        if split_type == "error":
                            # must raise ValueError
                            with self.assertRaises(ValueError):
                                exp = PyTorchExperiment(
                                    params, network_cls,
                                    key_mapping={"x": "data"},
                                    val_score_key=val_score_key,
                                    val_score_mode=val_score_mode)

                                dset = DummyDataset(
                                    dataset_length_test + dataset_length_train)

                                dmgr = BaseDataManager(dset, 16, 1, None)
                                exp.kfold(
                                    dmgr,
                                    params.nested_get("val_metrics"),
                                    shuffle=True,
                                    split_type=split_type,
                                    num_splits=2)

                            continue

                        # check all types of validation data
                        for val_split in [0.2, None]:
                            with self.subTest(val_split=val_split):

                                # disable lr scheduling if no validation data
                                # is present
                                _params = deepcopy(params)
                                if val_split is None:
                                    _params["fixed"]["training"]["lr_sched_cls"] = None
                                exp = PyTorchExperiment(
                                    _params, network_cls,
                                    key_mapping={"x": "data"},
                                    val_score_key=val_score_key,
                                    val_score_mode=val_score_mode)

                                dset = DummyDataset(
                                    dataset_length_test + dataset_length_train)

                                dmgr = BaseDataManager(dset, 16, 1, None)
                                exp.kfold(
                                    dmgr,
                                    params.nested_get("val_metrics"),
                                    shuffle=True,
                                    split_type=split_type,
                                    val_split=val_split,
                                    num_splits=2)
Example #9
from batchgenerators.transforms import RandomCropTransform, Compose
from batchgenerators.transforms.spatial_transforms import ResizeTransform, \
    SpatialTransform


transforms = Compose([
    # SpatialTransform(patch_size=(1024, 1024), do_rotation=True,
    #                  patch_center_dist_from_border=1024,
    #                  border_mode_data='reflect', border_mode_seg='reflect',
    #                  angle_x=(args.rot_angle, args.rot_angle),
    #                  angle_y=(0, 0), angle_z=(0, 0),
    #                  do_elastic_deform=False, order_data=1, order_seg=1),
    ResizeTransform((int(args.resize_size), int(args.resize_size)), order=1),
    RandomCropTransform((params.nested_get("image_size"),
                         params.nested_get("image_size"))),
])


from delira.data_loading import BaseDataManager, SequentialSampler, RandomSampler
manager_test = BaseDataManager(dataset_test, params.nested_get("batch_size"),
                               transforms=transforms,
                               sampler_cls=SequentialSampler,
                               n_process_augmentation=1)

import warnings
warnings.simplefilter("ignore", UserWarning) # ignore UserWarnings raised by dependency code
warnings.simplefilter("ignore", FutureWarning) # ignore FutureWarnings raised by dependency code




### Setting path for loading best checkpoint

from delira.training.tf_trainer import tf_load_checkpoint

test_path = '/../checkpoint_best'
Example #10
def run_experiment(cp: str, test=True) -> str:
    """
    Run classification experiment on patches
    Imports are done inside the function because of the logging setup

    Parameters
    ----------
    cp : str
        path to the config file
    test : bool
        test best model on test set

    Returns
    -------
    str
        path to experiment folder
    """
    # setup config
    ch = ConfigHandlerPyTorchDelira(cp)
    ch = feature_map_params(ch)

    if 'mixed_precision' not in ch or ch['mixed_precision'] is None:
        ch['mixed_precision'] = True
    if 'debug_delira' in ch and ch['debug_delira'] is not None:
        delira.set_debug_mode(ch['debug_delira'])
        print("Debug mode active: settings n_process_augmentation to 1!")
        ch['augment.n_process'] = 1

    dset_keys = ['train', 'val', 'test']

    losses = {'class_ce': torch.nn.CrossEntropyLoss()}
    train_metrics = {}
    val_metrics = {'CE': metric_wrapper_pytorch(torch.nn.CrossEntropyLoss())}
    test_metrics = {'CE': metric_wrapper_pytorch(torch.nn.CrossEntropyLoss())}

    #########################
    #   Setup Parameters    #
    #########################
    params_dict = ch.get_params(losses=losses,
                                train_metrics=train_metrics,
                                val_metrics=val_metrics,
                                add_self=ch['add_config_to_params'])
    params = Parameters(**params_dict)

    #################
    #   Setup IO    #
    #################
    # setup io
    load_sample = load_pickle
    load_fn = LoadPatches(load_fn=load_sample,
                          patch_size=ch['patch_size'],
                          **ch['data.load_patch'])

    datasets = {}
    for key in dset_keys:
        p = os.path.join(ch["data.path"], str(key))

        datasets[key] = BaseExtendCacheDataset(p,
                                               load_fn=load_fn,
                                               **ch['data.kwargs'])

    #############################
    #   Setup Transformations   #
    #############################
    base_transforms = []
    base_transforms.append(PopKeys("mapping"))

    train_transforms = []
    if ch['augment.mode']:
        logger.info("Training augmentation enabled.")
        train_transforms.append(
            SpatialTransform(patch_size=ch['patch_size'],
                             **ch['augment.kwargs']))
        train_transforms.append(MirrorTransform(axes=(0, 1)))
    process = ch['augment.n_process'] if 'augment.n_process' in ch else 1

    #########################
    #   Setup Datamanagers  #
    #########################
    datamanagers = {}
    for key in dset_keys:
        if key == 'train':
            trafos = base_transforms + train_transforms
            sampler = WeightedPrevalenceRandomSampler
        else:
            trafos = base_transforms
            sampler = SequentialSampler

        datamanagers[key] = BaseDataManager(
            data=datasets[key],
            batch_size=params.nested_get('batch_size'),
            n_process_augmentation=process,
            transforms=Compose(trafos),
            sampler_cls=sampler,
        )

    #############################
    #   Initialize Experiment   #
    #############################
    experiment = \
        PyTorchExperiment(
            params=params,
            model_cls=ClassNetwork,
            name=ch['exp.name'],
            save_path=ch['exp.dir'],
            optim_builder=create_optims_default_pytorch,
            trainer_cls=PyTorchNetworkTrainer,
            mixed_precision=ch['mixed_precision'],
            mixed_precision_kwargs={'verbose': False},
            key_mapping={"input_batch": "data"},
            **ch['exp.kwargs'],
        )

    # save configurations
    ch.dump(os.path.join(experiment.save_path, 'config.json'))

    #################
    #   Training    #
    #################
    model = experiment.run(datamanagers['train'],
                           datamanagers['val'],
                           save_path_exp=experiment.save_path,
                           ch=ch,
                           metric_keys={'val_CE': ['pred', 'label']},
                           val_freq=1,
                           verbose=True)
    ################
    #   Testing    #
    ################
    if test and datamanagers['test'] is not None:
        # metrics and metric_keys are used differently than in original
        # Delira implementation in order to support Evaluator
        # see mscl.training.predictor
        preds = experiment.test(
            network=model,
            test_data=datamanagers['test'],
            metrics=test_metrics,
            metric_keys={'CE': ['pred', 'label']},
            verbose=True,
        )

        softmax_fn = metric_wrapper_pytorch(
            partial(torch.nn.functional.softmax, dim=1))
        preds = softmax_fn(preds[0]['pred'])
        labels = [d['label'] for d in datasets['test']]
        fpr, tpr, thresholds = roc_curve(labels, preds[:, 1])
        roc_auc = auc(fpr, tpr)

        plt.plot(fpr, tpr, label='ROC (AUC = %0.2f)' % roc_auc)
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic example')
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(experiment.save_path, 'test_roc.pdf'))
        plt.close()

        preds = experiment.test(
            network=model,
            test_data=datamanagers['val'],
            metrics=test_metrics,
            metric_keys={'CE': ['pred', 'label']},
            verbose=True,
        )

        preds = softmax_fn(preds[0]['pred'])
        labels = [d['label'] for d in datasets['val']]
        fpr, tpr, thresholds = roc_curve(labels, preds[:, 1])
        roc_auc = auc(fpr, tpr)

        plt.plot(fpr, tpr, label='ROC (AUC = %0.2f)' % roc_auc)
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic example')
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(experiment.save_path, 'best_val_roc.pdf'))
        plt.close()

    return experiment.save_path
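
A minimal hedged invocation of the function above; the config path is a placeholder for a file understood by ConfigHandlerPyTorchDelira.

# Hypothetical call; 'config.json' is a placeholder path.
save_path = run_experiment('config.json', test=True)
print("Experiment folder:", save_path)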
Example #11
transforms = Compose([
    # SpatialTransform(patch_size=(1024, 1024), do_rotation=True,
    #                  patch_center_dist_from_border=1024,
    #                  border_mode_data='reflect', border_mode_seg='reflect',
    #                  angle_x=(args.rot_angle, args.rot_angle),
    #                  angle_y=(0, 0), angle_z=(0, 0),
    #                  do_elastic_deform=False, order_data=1, order_seg=1),
    ResizeTransform((int(args.resize_size), int(args.resize_size)), order=1),
    RandomCropTransform(
        (params.nested_get("image_size"), params.nested_get("image_size"))),
])

### Data manager

from delira.data_loading import BaseDataManager, SequentialSampler, RandomSampler

manager_train = BaseDataManager(dataset_train,
                                params.nested_get("batch_size"),
                                transforms=transforms,
                                sampler_cls=RandomSampler,
                                n_process_augmentation=4)

manager_val = BaseDataManager(dataset_val,
                              params.nested_get("val_batch_size"),
                              transforms=transforms,
                              sampler_cls=SequentialSampler,
                              n_process_augmentation=4)

import warnings
warnings.simplefilter(
    "ignore", UserWarning)  # ignore UserWarnings raised by dependency code
warnings.simplefilter(
    "ignore", FutureWarning)  # ignore FutureWarnings raised by dependency code
Example #12
#####################
#   Augmentation    #
#####################
base_transforms = [ZeroMeanUnitVarianceTransform()]
train_transforms = [SpatialTransform(patch_size=(200, 200),
                                     random_crop=False)]

#####################
#   Datamanagers    #
#####################
manager_train = BaseDataManager(dataset_train, params.nested_get("batch_size"),
                                transforms=Compose(
                                    base_transforms + train_transforms),
                                sampler_cls=RandomSampler,
                                n_process_augmentation=n_process_augmentation)

manager_val = BaseDataManager(dataset_val, params.nested_get("batch_size"),
                              transforms=Compose(base_transforms),
                              sampler_cls=SequentialSampler,
                              n_process_augmentation=n_process_augmentation)

manager_test = BaseDataManager(dataset_test, 1,
                               transforms=Compose(base_transforms),
                               sampler_cls=SequentialSampler,
                               n_process_augmentation=n_process_augmentation)

logger.info("Init Experiment")
experiment = PyTorchExperiment(params,
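
The snippet breaks off mid-call; based on the constructor usage in Examples #2, #8, and #10, the construction might continue like this sketch. The network class, experiment name, save path, and score key are placeholders.

# Hedged completion, not the original code.
experiment = PyTorchExperiment(params,
                               network_cls,  # placeholder model class
                               name="example_experiment",
                               save_path="./experiments",
                               key_mapping={"x": "data"},
                               val_score_key="val_CE",
                               val_score_mode="lowest")
model = experiment.run(manager_train, manager_val)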
Example #13
    def test_experiment(self):

        from delira.training import PyTorchExperiment, Parameters
        from delira.training.callbacks import ReduceLROnPlateauCallbackPyTorch
        from delira.models.classification import ClassificationNetworkBasePyTorch
        from delira.data_loading import AbstractDataset, BaseDataManager
        import torch

        test_cases = [(Parameters(
            fixed_params={
                "model": {},
                "training": {
                    "criterions": {
                        "CE": torch.nn.CrossEntropyLoss()
                    },
                    "optimizer_cls": torch.optim.Adam,
                    "optimizer_params": {
                        "lr": 1e-3
                    },
                    "num_epochs": 2,
                    "metrics": {},
                    "lr_sched_cls": ReduceLROnPlateauCallbackPyTorch,
                    "lr_sched_params": {}
                }
            }), 500, 50)]

        class DummyNetwork(ClassificationNetworkBasePyTorch):
            def __init__(self):
                super().__init__(32, 1)

            def forward(self, x):
                return self.module(x)

            @staticmethod
            def _build_model(in_channels, n_outputs):
                return torch.nn.Sequential(torch.nn.Linear(in_channels, 64),
                                           torch.nn.ReLU(),
                                           torch.nn.Linear(64, n_outputs))

            @staticmethod
            def prepare_batch(batch_dict, input_device, output_device):
                return {
                    "data": torch.from_numpy(batch_dict["data"]).to(
                        input_device, torch.float),
                    "label": torch.from_numpy(batch_dict["label"]).to(
                        output_device, torch.long)
                }

        class DummyDataset(AbstractDataset):
            def __init__(self, length):
                super().__init__(None, None, None, None)
                self.length = length

            def __getitem__(self, index):
                return {
                    "data": np.random.rand(32),
                    "label": np.random.randint(0, 1, 1)
                }

            def __len__(self):
                return self.length

            def get_sample_from_index(self, index):
                return self.__getitem__(index)

        for case in test_cases:
            with self.subTest(case=case):

                params, dataset_length_train, dataset_length_test = case

                exp = PyTorchExperiment(params, DummyNetwork)
                dset_train = DummyDataset(dataset_length_train)
                dset_test = DummyDataset(dataset_length_test)

                dmgr_train = BaseDataManager(dset_train, 16, 4, None)
                dmgr_test = BaseDataManager(dset_test, 16, 1, None)

                net = exp.run(dmgr_train, dmgr_test)
                exp.test(
                    params=params,
                    network=net,
                    datamgr_test=dmgr_test,
                )

                exp.kfold(2, dmgr_train, num_splits=2)
                exp.stratified_kfold(2, dmgr_train, num_splits=2)
                exp.stratified_kfold_predict(2, dmgr_train, num_splits=2)
Example #14
    def test_experiment(self):

        from delira.training import TfExperiment, Parameters
        from delira.models.classification import ClassificationNetworkBaseTf
        from delira.data_loading import AbstractDataset, BaseDataManager
        import tensorflow as tf

        test_cases = [(Parameters(
            fixed_params={
                "model": {
                    'in_channels': 32,
                    'n_outputs': 1
                },
                "training": {
                    "criterions": {
                        "CE": tf.losses.softmax_cross_entropy
                    },
                    "optimizer_cls": tf.train.AdamOptimizer,
                    "optimizer_params": {
                        "learning_rate": 1e-3
                    },
                    "num_epochs": 2,
                    "metrics": {},
                    "lr_sched_cls": None,
                    "lr_sched_params": {}
                }
            }), 500, 50)]

        class DummyNetwork(ClassificationNetworkBaseTf):
            def __init__(self):
                super().__init__(32, 1)
                self.model = self._build_model(1)

                images = tf.placeholder(shape=[None, 32], dtype=tf.float32)
                labels = tf.placeholder(shape=[None, 1], dtype=tf.float32)

                preds_train = self.model(images, training=True)
                preds_eval = self.model(images, training=False)

                self.inputs = [images, labels]
                self.outputs_train = [preds_train]
                self.outputs_eval = [preds_eval]

            @staticmethod
            def _build_model(n_outputs):
                return tf.keras.models.Sequential(layers=[
                    tf.keras.layers.Dense(64,
                                          input_shape=(32, ),
                                          bias_initializer='glorot_uniform'),
                    tf.keras.layers.ReLU(),
                    tf.keras.layers.Dense(n_outputs,
                                          bias_initializer='glorot_uniform')
                ])

        class DummyDataset(AbstractDataset):
            def __init__(self, length):
                super().__init__(None, None, None, None)
                self.length = length

            def __getitem__(self, index):
                return {
                    "data": np.random.rand(32),
                    "label": np.random.randint(0, 1, 1)
                }

            def __len__(self):
                return self.length

            def get_sample_from_index(self, index):
                return self.__getitem__(index)

        for case in test_cases:
            with self.subTest(case=case):

                params, dataset_length_train, dataset_length_test = case

                exp = TfExperiment(params, DummyNetwork)
                dset_train = DummyDataset(dataset_length_train)
                dset_test = DummyDataset(dataset_length_test)

                dmgr_train = BaseDataManager(dset_train, 16, 4, None)
                dmgr_test = BaseDataManager(dset_test, 16, 1, None)

                net = exp.run(dmgr_train, dmgr_test)
                exp.test(
                    params=params,
                    network=net,
                    datamgr_test=dmgr_test,
                )

                exp.kfold(2, dmgr_train, num_splits=2)
                exp.stratified_kfold(2, dmgr_train, num_splits=2)
                exp.stratified_kfold_predict(2, dmgr_train, num_splits=2)
Example #15
    def kfold(self,
              paths,
              num_splits=5,
              shuffle=True,
              random_seed=None,
              valid_size=0.1,
              **kwargs):
        """
        Runs K-Fold Crossvalidation

        Parameters
        ----------
        num_epochs: int
            number of epochs to train the model
        data: str
            path to root dir
        num_splits: None or int
            number of splits for kfold
            if None: len(data) splits will be validated
        shuffle: bool
            whether or not to shuffle indices for kfold
        random_seed: None or int
            random seed used to seed the kfold (if shuffle is true),
            pytorch and numpy
        valid_size : float, default: 0.1
            relative size of validation dataset in relation to training set
        """

        if random_seed is not None:
            torch.manual_seed(random_seed)
        np.random.seed(random_seed)

        if "dataset_type" in kwargs and kwargs["dataset_type"] is not None:
            dataset_type = kwargs["dataset_type"]
            if dataset_type != "INbreast" and dataset_type != "DDSM":
                raise ValueError("Unknown dataset!")
        else:
            raise ValueError("No dataset type!")

        train_splits, _ = utils.kfold_patientwise(paths,
                                                  dataset_type=dataset_type,
                                                  num_splits=num_splits,
                                                  shuffle=shuffle,
                                                  random_state=random_seed)

        for i in range(len(train_splits)):
            train_paths, val_paths, _ = \
                utils.split_paths_patientwise(train_splits[i],
                                              dataset_type=dataset_type,
                                              train_size=1 - valid_size)

            dataset_train = self.dataset_cls(path_list=train_paths,
                                             **self.dataset_train_kwargs)

            dataset_valid = self.dataset_cls(path_list=val_paths,
                                             **self.dataset_val_kwargs)

            mgr_train = BaseDataManager(dataset_train,
                                        **self.datamgr_train_kwargs)

            mgr_valid = BaseDataManager(dataset_valid,
                                        **self.datamgr_val_kwargs)

            super().run(mgr_train, mgr_valid, fold=i, **kwargs)
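
A hedged usage sketch; exp stands for an instance of the experiment class defining this kfold, and image_paths is a placeholder list of sample paths.

# dataset_type is validated inside kfold(): "INbreast" or "DDSM".
exp.kfold(paths=image_paths,
          num_splits=5,
          shuffle=True,
          random_seed=42,
          valid_size=0.1,
          dataset_type="INbreast")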
Example #16
    def kfold(self,
              data: BaseDataManager,
              metrics: dict,
              num_epochs=None,
              num_splits=None,
              shuffle=False,
              random_seed=None,
              split_type="random",
              val_split=0.2,
              label_key="label",
              train_kwargs: dict = None,
              metric_keys: dict = None,
              test_kwargs: dict = None,
              config=None,
              verbose=False,
              **kwargs):
        """
        Performs a k-Fold cross-validation

        Parameters
        ----------
        data : :class:`BaseDataManager`
            the data to use for training, (optionally) validation and
            testing; will be split based on ``split_type`` and ``val_split``
        metrics : dict
            dictionary containing the metrics to evaluate during k-fold
        num_epochs : int or None
            number of epochs to train (if not given, will either be extracted
            from ``config``, ``self.config`` or ``self.n_epochs``)
        num_splits : int or None
            the number of splits to extract from ``data``.
            If None: uses a default of 10
        shuffle : bool
            whether to shuffle the data before splitting or not (implemented by
            index-shuffling rather than actual data-shuffling to retain
            potentially lazy-behavior of datasets)
        random_seed : int or None
            seed for numpy, the splitting functions and the used
            backend-framework
        split_type : str
            must be one of ['random', 'stratified']
            if 'random': uses random data splitting
            if 'stratified': uses stratified data splitting. Stratification
            will be based on ``label_key``
        val_split : float or None
            the fraction of the train data to use as validation set. If None:
            No validation will be done during training; only testing for each
            fold after the training is complete
        label_key : str
            the label to use for stratification. Will be ignored unless
            ``split_type`` is 'stratified'. Default: 'label'
        train_kwargs : dict or None
            kwargs to update the behavior of the :class:`BaseDataManager`
            containing the train data. If None: empty dict will be passed
        metric_keys : dict of tuples
            the batch_dict keys to use for each metric to calculate.
            Should contain a value for each key in ``metrics``.
            If no values are given for a key, per default ``pred`` and
            ``label`` will be used for metric calculation
        test_kwargs : dict or None
            kwargs to update the behavior of the :class:`BaseDataManager`
            containing the test and validation data.
            If None: empty dict will be passed
        config : :class:`DeliraConfig` or None
            the training and model parameters
            (will be merged with ``self.config``)
        verbose : bool
            verbosity
        **kwargs :
            additional keyword arguments

        Returns
        -------
        dict
            all predictions from all folds
        dict
            all metric values from all folds

        Raises
        ------
        ValueError
            if ``split_type`` is neither 'random', nor 'stratified'

        See Also
        --------

        * :class:`sklearn.model_selection.KFold` and
          :class:`sklearn.model_selection.ShuffleSplit` for random
          data-splitting

        * :class:`sklearn.model_selection.StratifiedKFold` and
          :class:`sklearn.model_selection.StratifiedShuffleSplit` for
          stratified data-splitting

        * :meth:`BaseDataManager.update_state_from_dict` for updating the
          data managers by kwargs

        * :meth:`BaseExperiment.run` for the training

        * :meth:`BaseExperiment.test` for the testing

        Notes
        -----
        Using stratified splits may be slow during split-calculation, since
        each item must be loaded once to obtain the labels necessary for
        stratification.

        """

        # set number of splits if not specified
        if num_splits is None:
            num_splits = 10
            logger.warning("num_splits not defined, using default value of \
                                    10 splits instead ")

        metrics_test, outputs = {}, {}
        split_idxs = list(range(len(data.dataset)))

        if train_kwargs is None:
            train_kwargs = {}
        if test_kwargs is None:
            test_kwargs = {}

        # switch between different kfold types
        if split_type == "random":
            split_cls = KFold
            val_split_cls = ShuffleSplit
            # split_labels are ignored for random splitting; setting them to
            # split_idxs just ensures the same length
            split_labels = split_idxs
        elif split_type == "stratified":
            split_cls = StratifiedKFold
            val_split_cls = StratifiedShuffleSplit
            # iterate over dataset to get labels for stratified splitting
            split_labels = [
                data.dataset[_idx][label_key] for _idx in split_idxs
            ]
        else:
            raise ValueError("split_type must be one of "
                             "['random', 'stratified'], but got: %s" %
                             str(split_type))

        fold = split_cls(n_splits=num_splits,
                         shuffle=shuffle,
                         random_state=random_seed)

        if random_seed is not None:
            np.random.seed(random_seed)

        # iterate over folds
        for idx, (train_idxs,
                  test_idxs) in enumerate(fold.split(split_idxs,
                                                     split_labels)):

            # extract data from single manager
            train_data = data.get_subset(train_idxs)
            test_data = data.get_subset(test_idxs)

            train_data.update_state_from_dict(copy.deepcopy(train_kwargs))
            test_data.update_state_from_dict(copy.deepcopy(test_kwargs))

            val_data = None
            if val_split is not None:
                if split_type == "random":
                    # split_labels are ignored for random splitting; setting
                    # train_labels to train_idxs just ensures the same length
                    train_labels = train_idxs
                elif split_type == "stratified":
                    # iterate over dataset to get labels for stratified
                    # splitting
                    train_labels = [
                        train_data.dataset[_idx][label_key]
                        for _idx in train_idxs
                    ]
                else:
                    raise ValueError("split_type must be one of "
                                     "['random', 'stratified'], but got: %s" %
                                     str(split_type))

                _val_split = val_split_cls(n_splits=1,
                                           test_size=val_split,
                                           random_state=random_seed)

                for _train_idxs, _val_idxs in _val_split.split(
                        train_idxs, train_labels):
                    val_data = train_data.get_subset(_val_idxs)
                    val_data.update_state_from_dict(copy.deepcopy(test_kwargs))

                    train_data = train_data.get_subset(_train_idxs)

            model = self.run(train_data=train_data,
                             val_data=val_data,
                             config=config,
                             num_epochs=num_epochs,
                             fold=idx,
                             **kwargs)

            _outputs, _metrics_test = self.test(model,
                                                test_data,
                                                metrics=metrics,
                                                metric_keys=metric_keys,
                                                verbose=verbose)

            outputs[str(idx)] = _outputs
            metrics_test[str(idx)] = _metrics_test

        return outputs, metrics_test
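
A hedged usage sketch for this kfold; exp is the owning experiment, dmgr a BaseDataManager over the full dataset, and the metric function is a placeholder.

# 5-fold stratified CV with a 20% validation split per fold; the metric
# entry is a placeholder callable taking (pred, label).
outputs, fold_metrics = exp.kfold(
    data=dmgr,
    metrics={"CE": my_ce_metric},
    num_splits=5,
    shuffle=True,
    random_seed=42,
    split_type="stratified",   # stratifies on batch key "label"
    val_split=0.2)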