def test_experiment(experiment_cls, config, network_cls, len_test, **kwargs):
    assert issubclass(experiment_cls, BaseExperiment)
    exp = experiment_cls(config, network_cls, **kwargs)

    dset_test = DummyDataset(len_test)
    dmgr_test = BaseDataManager(dset_test, 16, 1, None)

    model = network_cls()
    return exp.test(model, dmgr_test,
                    config.nested_get("metrics", {}),
                    kwargs.get("metric_keys", None))
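# A minimal invocation sketch for the helper above, assuming the PyTorch
# backend; `config` is expected to be a DeliraConfig-like object supporting
# `nested_get`, and the `metric_keys` value is purely illustrative:
outputs, test_metrics = test_experiment(
    PyTorchExperiment, config, DummyNetwork, len_test=50,
    key_mapping={"x": "data"},
    metric_keys={"CE": ("pred", "label")})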
def test_experiment_run_torch(self):
    from delira.training import PyTorchExperiment
    from delira.data_loading import BaseDataManager

    for case in self._test_cases_torch:
        with self.subTest(case=case):
            (params, dataset_length_train, dataset_length_test,
             val_score_key, val_score_mode, network_cls) = case

            exp = PyTorchExperiment(params, network_cls,
                                    key_mapping={"x": "data"},
                                    val_score_key=val_score_key,
                                    val_score_mode=val_score_mode)

            dset_train = DummyDataset(dataset_length_train)
            dset_test = DummyDataset(dataset_length_test)

            dmgr_train = BaseDataManager(dset_train, 16, 2, None)
            dmgr_test = BaseDataManager(dset_test, 16, 1, None)

            exp.run(dmgr_train, dmgr_test)
def test_experiment(params, dataset_length_train, dataset_length_test):

    class DummyNetwork(ClassificationNetworkBasePyTorch):
        def __init__(self):
            super().__init__(32, 1)

        def forward(self, x):
            return self.module(x)

        @staticmethod
        def _build_model(in_channels, n_outputs):
            return torch.nn.Sequential(
                torch.nn.Linear(in_channels, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, n_outputs)
            )

    class DummyDataset(AbstractDataset):
        def __init__(self, length):
            super().__init__(None, None, None, None)
            self.length = length

        def __getitem__(self, index):
            return {"data": np.random.rand(1, 32),
                    "label": np.random.rand(1, 1)}

        def __len__(self):
            return self.length

    exp = PyTorchExperiment(params, DummyNetwork)

    dset_train = DummyDataset(dataset_length_train)
    dset_test = DummyDataset(dataset_length_test)

    dmgr_train = BaseDataManager(dset_train, 16, 4, None)
    dmgr_test = BaseDataManager(dset_test, 16, 1, None)

    exp.run(dmgr_train, dmgr_test)
def test_experiment_test_tf(self):
    from delira.training import TfExperiment
    from delira.data_loading import BaseDataManager

    for case in self._test_cases_tf:
        with self.subTest(case=case):
            (params, dataset_length_train, dataset_length_test,
             network_cls) = case

            exp = TfExperiment(params, network_cls,
                               key_mapping={"images": "data"})

            model = network_cls()
            dset_test = DummyDataset(dataset_length_test)
            dmgr_test = BaseDataManager(dset_test, 16, 1, None)

            exp.test(model, dmgr_test, params.nested_get("val_metrics"))
def t_experiment(self, raise_error=False, expected_msg=None):
    dummy_exp = DummyExperiment()
    dummy_exp = self.messenger_cls(dummy_exp, **self.messenger_kwargs)

    dset_test = DummyDataset(10)
    dmgr_test = BaseDataManager(dset_test, 2, 1, None)
    model = DummyNetwork()

    with self.assertLogs(logger, level='INFO') as cm:
        if raise_error:
            with self.assertRaises(RuntimeError):
                dummy_exp.test(model, dmgr_test, {}, raise_error=True)
        else:
            dummy_exp.test(model, dmgr_test, {}, raise_error=False)

        # if no message is expected, emit a sentinel so assertLogs
        # always has something to compare against
        if not expected_msg:
            logger.info("NoExpectedMessage")

    if not expected_msg:
        self.assertEqual(cm.output,
                         ["INFO:UnitTestMessenger:NoExpectedMessage"])
    else:
        self.assertEqual(cm.output, expected_msg)
def setUp(self):
    self.dset = DummyDataset(20)
    self.dmgr = BaseDataManager(self.dset, 4, 1, transforms=None)
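# Several tests above rely on a module-level DummyDataset fixture that is
# not shown in this excerpt. A sketch of what such a fixture could look
# like, with field shapes assumed from the inline variants defined later:
class DummyDataset(AbstractDataset):
    def __init__(self, length):
        super().__init__(None, None, None, None)
        self.length = length

    def __getitem__(self, index):
        # random sample with the keys the key_mapping and metrics expect
        return {"data": np.random.rand(32),
                "label": np.random.randint(0, 1, 1)}

    def __len__(self):
        return self.length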
def train_shapenet():
    """
    Trains a single shapenet with a config file from command line arguments

    See Also
    --------
    :class:`delira.training.PyTorchNetworkTrainer`

    """
    import logging
    import numpy as np
    import torch
    from shapedata.single_shape import SingleShapeDataset
    from delira.training import PyTorchNetworkTrainer
    from ..utils import Config
    from ..layer import HomogeneousShapeLayer
    from ..networks import SingleShapeNetwork
    from delira.logging import TrixiHandler
    from trixi.logger import PytorchVisdomLogger

    from delira.training.callbacks import ReduceLROnPlateauCallbackPyTorch
    from delira.data_loading import BaseDataManager, RandomSampler, \
        SequentialSampler
    import os
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str,
                        help="Path to configuration file")
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args()

    config = Config()
    config_dict = config(os.path.abspath(args.config))

    shapes = np.load(os.path.abspath(config_dict["layer"].pop("pca_path"))
                     )["shapes"][:config_dict["layer"].pop(
                         "num_shape_params") + 1]

    net = SingleShapeNetwork(
        HomogeneousShapeLayer,
        {"shapes": shapes, **config_dict["layer"]},
        img_size=config_dict["data"]["img_size"],
        **config_dict["network"])

    num_params = sum(param.numel() for param in net.parameters())

    if args.verbose:
        print("Number of Parameters: %d" % num_params)

    criterions = {"L1": torch.nn.L1Loss()}
    metrics = {"MSE": torch.nn.MSELoss()}

    mixed_prec = config_dict["training"].pop("mixed_prec", False)

    config_dict["training"]["save_path"] = os.path.abspath(
        config_dict["training"]["save_path"])

    trainer = PyTorchNetworkTrainer(
        net,
        criterions=criterions,
        metrics=metrics,
        lr_scheduler_cls=ReduceLROnPlateauCallbackPyTorch,
        lr_scheduler_params=config_dict["scheduler"],
        optimizer_cls=torch.optim.Adam,
        optimizer_params=config_dict["optimizer"],
        mixed_precision=mixed_prec,
        **config_dict["training"])

    if args.verbose:
        print(trainer.input_device)
        print("Load Data")

    dset_train = SingleShapeDataset(
        os.path.abspath(config_dict["data"]["train_path"]),
        config_dict["data"]["img_size"],
        config_dict["data"]["crop"],
        config_dict["data"]["landmark_extension_train"],
        cached=config_dict["data"]["cached"],
        rotate=config_dict["data"]["rotate_train"],
        random_offset=config_dict["data"]["offset_train"])

    if config_dict["data"]["test_path"]:
        dset_val = SingleShapeDataset(
            os.path.abspath(config_dict["data"]["test_path"]),
            config_dict["data"]["img_size"],
            config_dict["data"]["crop"],
            config_dict["data"]["landmark_extension_test"],
            cached=config_dict["data"]["cached"],
            rotate=config_dict["data"]["rotate_test"],
            random_offset=config_dict["data"]["offset_test"])
    else:
        dset_val = None

    mgr_train = BaseDataManager(
        dset_train,
        batch_size=config_dict["data"]["batch_size"],
        n_process_augmentation=config_dict["data"]["num_workers"],
        transforms=None,
        sampler_cls=RandomSampler)
    mgr_val = BaseDataManager(
        dset_val,
        batch_size=config_dict["data"]["batch_size"],
        n_process_augmentation=config_dict["data"]["num_workers"],
        transforms=None,
        sampler_cls=SequentialSampler)

    if args.verbose:
        print("Data loaded")

    if config_dict["logging"].pop("enable", False):
        logger_cls = PytorchVisdomLogger
        logging.basicConfig(
            level=logging.INFO,
            handlers=[TrixiHandler(logger_cls, **config_dict["logging"])])
    else:
        logging.basicConfig(level=logging.INFO,
                            handlers=[logging.NullHandler()])

    logger = logging.getLogger("Test Logger")
    logger.info("Start Training")

    if args.verbose:
        print("Start Training")

    trainer.train(config_dict["training"]["num_epochs"], mgr_train, mgr_val,
                  config_dict["training"]["val_score_key"],
                  val_score_mode='lowest')
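# The script above assumes a config file that parses into nested sections
# ("layer", "network", "data", "training", "optimizer", "scheduler",
# "logging"). A minimal sketch of such a configuration; all concrete values
# are hypothetical and chosen only for illustration:
example_config = {
    "layer": {"pca_path": "./pca.npz", "num_shape_params": 25},
    "network": {},
    "data": {
        "img_size": 224, "crop": 0.1, "batch_size": 16, "num_workers": 4,
        "cached": False,
        "train_path": "./data/train", "test_path": "./data/test",
        "landmark_extension_train": ".ljson",
        "landmark_extension_test": ".ljson",
        "rotate_train": 90, "rotate_test": 0,
        "offset_train": 10, "offset_test": 0,
    },
    "training": {"num_epochs": 100, "val_score_key": "MSE",
                 "save_path": "./checkpoints", "mixed_prec": False},
    "optimizer": {"lr": 1e-3},
    "scheduler": {"mode": "min"},
    "logging": {"enable": False},
}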
def test_experiment_kfold_torch(self):
    from delira.training import PyTorchExperiment
    from delira.data_loading import BaseDataManager
    from copy import deepcopy

    # all test cases
    for case in self._test_cases_torch:
        with self.subTest(case=case):
            (params, dataset_length_train, dataset_length_test,
             val_score_key, val_score_mode, network_cls) = case

            # both valid split_types plus one invalid one
            for split_type in ["random", "stratified", "error"]:
                with self.subTest(split_type=split_type):
                    if split_type == "error":
                        # must raise ValueError
                        with self.assertRaises(ValueError):
                            exp = PyTorchExperiment(
                                params, network_cls,
                                key_mapping={"x": "data"},
                                val_score_key=val_score_key,
                                val_score_mode=val_score_mode)

                            dset = DummyDataset(
                                dataset_length_test + dataset_length_train)
                            dmgr = BaseDataManager(dset, 16, 1, None)

                            exp.kfold(
                                dmgr, params.nested_get("val_metrics"),
                                shuffle=True, split_type=split_type,
                                num_splits=2)
                        continue

                    # check all types of validation data
                    for val_split in [0.2, None]:
                        with self.subTest(val_split=val_split):
                            # disable lr scheduling if no validation data
                            # is present
                            _params = deepcopy(params)
                            if val_split is None:
                                _params["fixed"]["training"][
                                    "lr_sched_cls"] = None

                            exp = PyTorchExperiment(
                                _params, network_cls,
                                key_mapping={"x": "data"},
                                val_score_key=val_score_key,
                                val_score_mode=val_score_mode)

                            dset = DummyDataset(
                                dataset_length_test + dataset_length_train)
                            dmgr = BaseDataManager(dset, 16, 1, None)

                            exp.kfold(
                                dmgr, params.nested_get("val_metrics"),
                                shuffle=True, split_type=split_type,
                                val_split=val_split, num_splits=2)
from batchgenerators.transforms import RandomCropTransform, Compose
from batchgenerators.transforms.spatial_transforms import ResizeTransform, \
    SpatialTransform

transforms = Compose([
    # SpatialTransform(patch_size=(1024, 1024), do_rotation=True,
    #                  patch_center_dist_from_border=1024,
    #                  border_mode_data='reflect', border_mode_seg='reflect',
    #                  angle_x=(args.rot_angle, args.rot_angle),
    #                  angle_y=(0, 0), angle_z=(0, 0),
    #                  do_elastic_deform=False, order_data=1, order_seg=1),
    ResizeTransform((int(args.resize_size), int(args.resize_size)), order=1),
    RandomCropTransform((params.nested_get("image_size"),
                         params.nested_get("image_size"))),
])

from delira.data_loading import BaseDataManager, SequentialSampler, \
    RandomSampler

manager_test = BaseDataManager(dataset_test,
                               params.nested_get("batch_size"),
                               transforms=transforms,
                               sampler_cls=SequentialSampler,
                               n_process_augmentation=1)

import warnings
warnings.simplefilter(
    "ignore", UserWarning)  # ignore UserWarnings raised by dependency code
warnings.simplefilter(
    "ignore", FutureWarning)  # ignore FutureWarnings raised by dependency code

### Setting path for loading best checkpoint
from delira.training.tf_trainer import tf_load_checkpoint
test_path = '/../checkpoint_best'
def run_experiment(cp: str, test=True) -> str:
    """
    Run classification experiment on patches
    Imports moved inside because of logging setups

    Parameters
    ----------
    cp : str
        path to config file
    test : bool
        test best model on test set

    Returns
    -------
    str
        path to experiment folder
    """
    # setup config
    ch = ConfigHandlerPyTorchDelira(cp)
    ch = feature_map_params(ch)

    if 'mixed_precision' not in ch or ch['mixed_precision'] is None:
        ch['mixed_precision'] = True

    if 'debug_delira' in ch and ch['debug_delira'] is not None:
        delira.set_debug_mode(ch['debug_delira'])
        print("Debug mode active: setting n_process_augmentation to 1!")
        ch['augment.n_process'] = 1

    dset_keys = ['train', 'val', 'test']

    losses = {'class_ce': torch.nn.CrossEntropyLoss()}
    train_metrics = {}
    val_metrics = {'CE': metric_wrapper_pytorch(torch.nn.CrossEntropyLoss())}
    test_metrics = {'CE': metric_wrapper_pytorch(torch.nn.CrossEntropyLoss())}

    #########################
    #   Setup Parameters    #
    #########################
    params_dict = ch.get_params(losses=losses,
                                train_metrics=train_metrics,
                                val_metrics=val_metrics,
                                add_self=ch['add_config_to_params'])
    params = Parameters(**params_dict)

    #################
    #   Setup IO    #
    #################
    load_sample = load_pickle
    load_fn = LoadPatches(load_fn=load_sample,
                          patch_size=ch['patch_size'],
                          **ch['data.load_patch'])

    datasets = {}
    for key in dset_keys:
        p = os.path.join(ch["data.path"], str(key))
        datasets[key] = BaseExtendCacheDataset(p, load_fn=load_fn,
                                               **ch['data.kwargs'])

    #############################
    #   Setup Transformations   #
    #############################
    base_transforms = [PopKeys("mapping")]

    train_transforms = []
    if ch['augment.mode']:
        logger.info("Training augmentation enabled.")
        train_transforms.append(
            SpatialTransform(patch_size=ch['patch_size'],
                             **ch['augment.kwargs']))
        train_transforms.append(MirrorTransform(axes=(0, 1)))

    process = ch['augment.n_process'] if 'augment.n_process' in ch else 1

    #########################
    #   Setup Datamanagers  #
    #########################
    datamanagers = {}
    for key in dset_keys:
        if key == 'train':
            trafos = base_transforms + train_transforms
            sampler = WeightedPrevalenceRandomSampler
        else:
            trafos = base_transforms
            sampler = SequentialSampler

        datamanagers[key] = BaseDataManager(
            data=datasets[key],
            batch_size=params.nested_get('batch_size'),
            n_process_augmentation=process,
            transforms=Compose(trafos),
            sampler_cls=sampler,
        )

    #############################
    #   Initialize Experiment   #
    #############################
    experiment = PyTorchExperiment(
        params=params,
        model_cls=ClassNetwork,
        name=ch['exp.name'],
        save_path=ch['exp.dir'],
        optim_builder=create_optims_default_pytorch,
        trainer_cls=PyTorchNetworkTrainer,
        mixed_precision=ch['mixed_precision'],
        mixed_precision_kwargs={'verbose': False},
        key_mapping={"input_batch": "data"},
        **ch['exp.kwargs'],
    )

    # save configurations
    ch.dump(os.path.join(experiment.save_path, 'config.json'))

    #################
    #   Training    #
    #################
    model = experiment.run(datamanagers['train'],
                           datamanagers['val'],
                           save_path_exp=experiment.save_path,
                           ch=ch,
                           metric_keys={'val_CE': ['pred', 'label']},
                           val_freq=1,
                           verbose=True)

    ################
    #   Testing    #
    ################
    if test and datamanagers['test'] is not None:
        # metrics and metric_keys are used differently than in the original
        # Delira implementation in order to support Evaluator
        # see mscl.training.predictor
        preds = experiment.test(
            network=model,
            test_data=datamanagers['test'],
            metrics=test_metrics,
            metric_keys={'CE': ['pred', 'label']},
            verbose=True,
        )

        softmax_fn = metric_wrapper_pytorch(
            partial(torch.nn.functional.softmax, dim=1))
        preds = softmax_fn(preds[0]['pred'])
        labels = [d['label'] for d in datasets['test']]

        fpr, tpr, thresholds = roc_curve(labels, preds[:, 1])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label='ROC (AUC = %0.2f)' % roc_auc)
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic example')
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(experiment.save_path, 'test_roc.pdf'))
        plt.close()

        preds = experiment.test(
            network=model,
            test_data=datamanagers['val'],
            metrics=test_metrics,
            metric_keys={'CE': ['pred', 'label']},
            verbose=True,
        )
        preds = softmax_fn(preds[0]['pred'])
        labels = [d['label'] for d in datasets['val']]

        fpr, tpr, thresholds = roc_curve(labels, preds[:, 1])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label='ROC (AUC = %0.2f)' % roc_auc)
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic example')
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(experiment.save_path, 'best_val_roc.pdf'))
        plt.close()

    return experiment.save_path
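# The two ROC blocks above are identical except for the data manager, the
# dataset and the output file name. A helper of this shape (hypothetical,
# not part of the original module) would remove the duplication:
def _plot_roc(experiment, model, datamanager, dataset, test_metrics,
              softmax_fn, save_name):
    # run inference, turn logits into probabilities, plot and save the ROC
    preds = experiment.test(network=model, test_data=datamanager,
                            metrics=test_metrics,
                            metric_keys={'CE': ['pred', 'label']},
                            verbose=True)
    probs = softmax_fn(preds[0]['pred'])
    labels = [d['label'] for d in dataset]

    fpr, tpr, _ = roc_curve(labels, probs[:, 1])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label='ROC (AUC = %0.2f)' % roc_auc)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc="lower right")
    plt.savefig(os.path.join(experiment.save_path, save_name))
    plt.close()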
transforms = Compose([
    # SpatialTransform(patch_size=(1024, 1024), do_rotation=True,
    #                  patch_center_dist_from_border=1024,
    #                  border_mode_data='reflect', border_mode_seg='reflect',
    #                  angle_x=(args.rot_angle, args.rot_angle),
    #                  angle_y=(0, 0), angle_z=(0, 0),
    #                  do_elastic_deform=False, order_data=1, order_seg=1),
    ResizeTransform((int(args.resize_size), int(args.resize_size)), order=1),
    RandomCropTransform(
        (params.nested_get("image_size"), params.nested_get("image_size"))),
])

### Data manager
from delira.data_loading import BaseDataManager, SequentialSampler, \
    RandomSampler

manager_train = BaseDataManager(dataset_train,
                                params.nested_get("batch_size"),
                                transforms=transforms,
                                sampler_cls=RandomSampler,
                                n_process_augmentation=4)

manager_val = BaseDataManager(dataset_val,
                              params.nested_get("val_batch_size"),
                              transforms=transforms,
                              sampler_cls=SequentialSampler,
                              n_process_augmentation=4)

import warnings
warnings.simplefilter(
    "ignore", UserWarning)  # ignore UserWarnings raised by dependency code
warnings.simplefilter(
    "ignore", FutureWarning)  # ignore FutureWarnings raised by dependency code
#####################
#   Augmentation    #
#####################
base_transforms = [ZeroMeanUnitVarianceTransform()]
train_transforms = [SpatialTransform(patch_size=(200, 200),
                                     random_crop=False)]

#####################
#   Datamanagers    #
#####################
manager_train = BaseDataManager(dataset_train,
                                params.nested_get("batch_size"),
                                transforms=Compose(
                                    base_transforms + train_transforms),
                                sampler_cls=RandomSampler,
                                n_process_augmentation=n_process_augmentation)

manager_val = BaseDataManager(dataset_val,
                              params.nested_get("batch_size"),
                              transforms=Compose(base_transforms),
                              sampler_cls=SequentialSampler,
                              n_process_augmentation=n_process_augmentation)

manager_test = BaseDataManager(dataset_test, 1,
                               transforms=Compose(base_transforms),
                               sampler_cls=SequentialSampler,
                               n_process_augmentation=n_process_augmentation)

logger.info("Init Experiment")
experiment = PyTorchExperiment(params,
def test_experiment(self):
    from delira.training import PyTorchExperiment, Parameters
    from delira.training.callbacks import ReduceLROnPlateauCallbackPyTorch
    from delira.models.classification import \
        ClassificationNetworkBasePyTorch
    from delira.data_loading import AbstractDataset, BaseDataManager
    import torch

    test_cases = [(
        Parameters(fixed_params={
            "model": {},
            "training": {
                "criterions": {"CE": torch.nn.CrossEntropyLoss()},
                "optimizer_cls": torch.optim.Adam,
                "optimizer_params": {"lr": 1e-3},
                "num_epochs": 2,
                "metrics": {},
                "lr_sched_cls": ReduceLROnPlateauCallbackPyTorch,
                "lr_sched_params": {}
            }
        }),
        500,
        50
    )]

    class DummyNetwork(ClassificationNetworkBasePyTorch):
        def __init__(self):
            super().__init__(32, 1)

        def forward(self, x):
            return self.module(x)

        @staticmethod
        def _build_model(in_channels, n_outputs):
            return torch.nn.Sequential(
                torch.nn.Linear(in_channels, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, n_outputs)
            )

        @staticmethod
        def prepare_batch(batch_dict, input_device, output_device):
            return {
                "data": torch.from_numpy(batch_dict["data"]).to(
                    input_device, torch.float),
                "label": torch.from_numpy(batch_dict["label"]).to(
                    output_device, torch.long)
            }

    class DummyDataset(AbstractDataset):
        def __init__(self, length):
            super().__init__(None, None, None, None)
            self.length = length

        def __getitem__(self, index):
            return {"data": np.random.rand(32),
                    "label": np.random.randint(0, 1, 1)}

        def __len__(self):
            return self.length

        def get_sample_from_index(self, index):
            return self.__getitem__(index)

    for case in test_cases:
        with self.subTest(case=case):
            params, dataset_length_train, dataset_length_test = case

            exp = PyTorchExperiment(params, DummyNetwork)

            dset_train = DummyDataset(dataset_length_train)
            dset_test = DummyDataset(dataset_length_test)

            dmgr_train = BaseDataManager(dset_train, 16, 4, None)
            dmgr_test = BaseDataManager(dset_test, 16, 1, None)

            net = exp.run(dmgr_train, dmgr_test)
            exp.test(params=params, network=net, datamgr_test=dmgr_test)

            exp.kfold(2, dmgr_train, num_splits=2)
            exp.stratified_kfold(2, dmgr_train, num_splits=2)
            exp.stratified_kfold_predict(2, dmgr_train, num_splits=2)
def test_experiment(self):
    from delira.training import TfExperiment, Parameters
    from delira.models.classification import ClassificationNetworkBaseTf
    from delira.data_loading import AbstractDataset, BaseDataManager
    import tensorflow as tf

    test_cases = [(
        Parameters(fixed_params={
            "model": {"in_channels": 32, "n_outputs": 1},
            "training": {
                "criterions": {"CE": tf.losses.softmax_cross_entropy},
                "optimizer_cls": tf.train.AdamOptimizer,
                "optimizer_params": {"learning_rate": 1e-3},
                "num_epochs": 2,
                "metrics": {},
                "lr_sched_cls": None,
                "lr_sched_params": {}
            }
        }),
        500,
        50
    )]

    class DummyNetwork(ClassificationNetworkBaseTf):
        def __init__(self):
            super().__init__(32, 1)
            self.model = self._build_model(1)

            images = tf.placeholder(shape=[None, 32], dtype=tf.float32)
            labels = tf.placeholder(shape=[None, 1], dtype=tf.float32)

            preds_train = self.model(images, training=True)
            preds_eval = self.model(images, training=False)

            self.inputs = [images, labels]
            self.outputs_train = [preds_train]
            self.outputs_eval = [preds_eval]

        @staticmethod
        def _build_model(n_outputs):
            return tf.keras.models.Sequential(layers=[
                tf.keras.layers.Dense(64, input_shape=(32,),
                                      bias_initializer='glorot_uniform'),
                tf.keras.layers.ReLU(),
                tf.keras.layers.Dense(n_outputs,
                                      bias_initializer='glorot_uniform')
            ])

    class DummyDataset(AbstractDataset):
        def __init__(self, length):
            super().__init__(None, None, None, None)
            self.length = length

        def __getitem__(self, index):
            return {"data": np.random.rand(32),
                    "label": np.random.randint(0, 1, 1)}

        def __len__(self):
            return self.length

        def get_sample_from_index(self, index):
            return self.__getitem__(index)

    for case in test_cases:
        with self.subTest(case=case):
            params, dataset_length_train, dataset_length_test = case

            exp = TfExperiment(params, DummyNetwork)

            dset_train = DummyDataset(dataset_length_train)
            dset_test = DummyDataset(dataset_length_test)

            dmgr_train = BaseDataManager(dset_train, 16, 4, None)
            dmgr_test = BaseDataManager(dset_test, 16, 1, None)

            net = exp.run(dmgr_train, dmgr_test)
            exp.test(params=params, network=net, datamgr_test=dmgr_test)

            exp.kfold(2, dmgr_train, num_splits=2)
            exp.stratified_kfold(2, dmgr_train, num_splits=2)
            exp.stratified_kfold_predict(2, dmgr_train, num_splits=2)
def kfold(self, paths, num_splits=5, shuffle=True, random_seed=None,
          valid_size=0.1, **kwargs):
    """
    Runs K-Fold cross-validation

    Parameters
    ----------
    paths : list
        paths to the data samples; splits are computed patient-wise
    num_splits : int
        number of splits for kfold
    shuffle : bool
        whether or not to shuffle indices for kfold
    random_seed : None or int
        random seed used to seed the kfold (if shuffle is true),
        pytorch and numpy
    valid_size : float, default: 0.1
        relative size of validation dataset in relation to training set
    """
    if random_seed is not None:
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)

    if "dataset_type" in kwargs and kwargs["dataset_type"] is not None:
        dataset_type = kwargs["dataset_type"]
        if dataset_type != "INbreast" and dataset_type != "DDSM":
            raise ValueError("Unknown dataset!")
    else:
        raise ValueError("No dataset type!")

    train_splits, _ = utils.kfold_patientwise(paths,
                                              dataset_type=dataset_type,
                                              num_splits=num_splits,
                                              shuffle=shuffle,
                                              random_state=random_seed)

    for i in range(len(train_splits)):
        train_paths, val_paths, _ = \
            utils.split_paths_patientwise(train_splits[i],
                                          dataset_type=dataset_type,
                                          train_size=1 - valid_size)

        dataset_train = self.dataset_cls(path_list=train_paths,
                                         **self.dataset_train_kwargs)
        dataset_valid = self.dataset_cls(path_list=val_paths,
                                         **self.dataset_val_kwargs)

        mgr_train = BaseDataManager(dataset_train,
                                    **self.datamgr_train_kwargs)
        mgr_valid = BaseDataManager(dataset_valid,
                                    **self.datamgr_val_kwargs)

        super().run(mgr_train, mgr_valid, fold=i, **kwargs)
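# A minimal usage sketch for the patient-wise kfold above. It assumes an
# experiment subclass providing `dataset_cls`, `dataset_train_kwargs`,
# `dataset_val_kwargs`, `datamgr_train_kwargs` and `datamgr_val_kwargs`;
# `MammographyExperiment` and `collect_sample_paths` are hypothetical names:
exp = MammographyExperiment(params, network_cls)
paths = collect_sample_paths("/path/to/data")
exp.kfold(paths, num_splits=5, shuffle=True, random_seed=42,
          valid_size=0.1, dataset_type="INbreast")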
def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
          num_splits=None, shuffle=False, random_seed=None,
          split_type="random", val_split=0.2, label_key="label",
          train_kwargs: dict = None, metric_keys: dict = None,
          test_kwargs: dict = None, config=None, verbose=False, **kwargs):
    """
    Performs a k-Fold cross-validation

    Parameters
    ----------
    data : :class:`BaseDataManager`
        the data to use for training(, validation) and testing. Will be
        split based on ``split_type`` and ``val_split``
    metrics : dict
        dictionary containing the metrics to evaluate during k-fold
    num_epochs : int or None
        number of epochs to train (if not given, will either be extracted
        from ``config``, ``self.config`` or ``self.n_epochs``)
    num_splits : int or None
        the number of splits to extract from ``data``.
        If None: uses a default of 10
    shuffle : bool
        whether to shuffle the data before splitting or not (implemented
        by index-shuffling rather than actual data-shuffling to retain
        the potentially lazy behavior of datasets)
    random_seed : int or None
        seed to seed numpy, the splitting functions and the used
        backend-framework
    split_type : str
        must be one of ['random', 'stratified']
        if 'random': uses random data splitting
        if 'stratified': uses stratified data splitting. Stratification
        will be based on ``label_key``
    val_split : float or None
        the fraction of the train data to use as validation set.
        If None: No validation will be done during training; only testing
        for each fold after the training is complete
    label_key : str
        the label to use for stratification. Will be ignored unless
        ``split_type`` is 'stratified'. Default: 'label'
    train_kwargs : dict or None
        kwargs to update the behavior of the :class:`BaseDataManager`
        containing the train data. If None: empty dict will be passed
    metric_keys : dict of tuples
        the batch_dict keys to use for each metric to calculate.
        Should contain a value for each key in ``metrics``.
        If no values are given for a key, per default ``pred`` and
        ``label`` will be used for metric calculation
    test_kwargs : dict or None
        kwargs to update the behavior of the :class:`BaseDataManager`
        containing the test and validation data.
        If None: empty dict will be passed
    config : :class:`DeliraConfig` or None
        the training and model parameters
        (will be merged with ``self.config``)
    verbose : bool
        verbosity
    **kwargs :
        additional keyword arguments

    Returns
    -------
    dict
        all predictions from all folds
    dict
        all metric values from all folds

    Raises
    ------
    ValueError
        if ``split_type`` is neither 'random', nor 'stratified'

    See Also
    --------

    * :class:`sklearn.model_selection.KFold`
      and :class:`sklearn.model_selection.ShuffleSplit`
      for random data-splitting

    * :class:`sklearn.model_selection.StratifiedKFold`
      and :class:`sklearn.model_selection.StratifiedShuffleSplit`
      for stratified data-splitting

    * :meth:`BaseDataManager.update_state_from_dict` for updating the
      data managers by kwargs

    * :meth:`BaseExperiment.run` for the training

    * :meth:`BaseExperiment.test` for the testing

    Notes
    -----
    using stratified splits may be slow during split-calculation, since
    each item must be loaded once to obtain the labels necessary for
    stratification.
""" # set number of splits if not specified if num_splits is None: num_splits = 10 logger.warning("num_splits not defined, using default value of \ 10 splits instead ") metrics_test, outputs = {}, {} split_idxs = list(range(len(data.dataset))) if train_kwargs is None: train_kwargs = {} if test_kwargs is None: test_kwargs = {} # switch between differnt kfold types if split_type == "random": split_cls = KFold val_split_cls = ShuffleSplit # split_labels are ignored for random splitting, set them to # split_idxs just ensures same length split_labels = split_idxs elif split_type == "stratified": split_cls = StratifiedKFold val_split_cls = StratifiedShuffleSplit # iterate over dataset to get labels for stratified splitting split_labels = [ data.dataset[_idx][label_key] for _idx in split_idxs ] else: raise ValueError("split_type must be one of " "['random', 'stratified'], but got: %s" % str(split_type)) fold = split_cls(n_splits=num_splits, shuffle=shuffle, random_state=random_seed) if random_seed is not None: np.random.seed(random_seed) # iterate over folds for idx, (train_idxs, test_idxs) in enumerate(fold.split(split_idxs, split_labels)): # extract data from single manager train_data = data.get_subset(train_idxs) test_data = data.get_subset(test_idxs) train_data.update_state_from_dict(copy.deepcopy(train_kwargs)) test_data.update_state_from_dict(copy.deepcopy(test_kwargs)) val_data = None if val_split is not None: if split_type == "random": # split_labels are ignored for random splitting, set them # to split_idxs just ensures same length train_labels = train_idxs elif split_type == "stratified": # iterate over dataset to get labels for stratified # splitting train_labels = [ train_data.dataset[_idx][label_key] for _idx in train_idxs ] else: raise ValueError("split_type must be one of " "['random', 'stratified'], but got: %s" % str(split_type)) _val_split = val_split_cls(n_splits=1, test_size=val_split, random_state=random_seed) for _train_idxs, _val_idxs in _val_split.split( train_idxs, train_labels): val_data = train_data.get_subset(_val_idxs) val_data.update_state_from_dict(copy.deepcopy(test_kwargs)) train_data = train_data.get_subset(_train_idxs) model = self.run(train_data=train_data, val_data=val_data, config=config, num_epochs=num_epochs, fold=idx, **kwargs) _outputs, _metrics_test = self.test(model, test_data, metrics=metrics, metric_keys=metric_keys, verbose=verbose) outputs[str(idx)] = _outputs metrics_test[str(idx)] = _metrics_test return outputs, metrics_test