Example #1
0
    def __init__(self,
                 output_folder,
                 list_file=None,
                 config_file=None,
                 tf_log_level=None,
                 discover_tomls=True):
        """
        Create the configuration, IO and history helpers for one model.

        The settings end up in ``self.cfg``; besides a config file they
        can also be adjusted afterwards by assigning to its attributes
        (e.g. ``orga.cfg.batchsize``).

        Parameters
        ----------
        output_folder : str
            Folder of this model. Passed on to the Configuration and the
            HistoryHandler, and used as the place to search for toml
            files that were not given explicitly.
        list_file : str, optional
            Path to a toml list file with the paths of the h5 files used
            for training and validation.
            If None and discover_tomls is True, 'list.toml' is searched
            for in output_folder.
        config_file : str, optional
            Path to a toml config file overriding the default settings.
            If None and discover_tomls is True, 'config.toml' is searched
            for in output_folder.
        tf_log_level : int/str, optional
            If given, its string form is written to the
            TF_CPP_MIN_LOG_LEVEL environment variable.
            0 = all messages are logged (default behavior).
            1 = INFO messages are not printed.
            2 = INFO and WARNING messages are not printed.
            3 = INFO, WARNING, and ERROR messages are not printed.
        discover_tomls : bool
            If False, never search output_folder for toml files, even if
            list_file or config_file is None [Default: True].

        """
        if tf_log_level is not None:
            os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(tf_log_level)

        if discover_tomls:
            # Only the toml files that were not passed in are looked up.
            if list_file is None:
                list_file = orcanet.misc.find_file(output_folder, "list.toml")
            if config_file is None:
                config_file = orcanet.misc.find_file(
                    output_folder, "config.toml")

        # Internal state that starts out empty.
        self.xs_mean = None
        self._auto_label_modifier = None
        self._stored_model = None
        self._strategy = None

        # The IO handler is built on top of the configuration object.
        self.cfg = Configuration(output_folder, list_file, config_file)
        self.io = IOHandler(self.cfg)
        self.history = HistoryHandler(output_folder)
Example #2
0
    def setUp(self):
        """
        Create a fresh temp folder and an IOHandler working inside it.

        The handler's get_learning_rate method is additionally exposed
        directly on the test case for convenience.
        """
        test_folder = os.path.dirname(__file__)
        self.temp_dir = os.path.join(test_folder, ".temp", "TestIOHandlerLR")
        os.makedirs(self.temp_dir)

        # Neither a list file nor a config file is used for these tests.
        self.io = IOHandler(Configuration(self.temp_dir, None, None))
        self.get_learning_rate = self.io.get_learning_rate
Example #3
0
    def setUp(self):
        """
        Build an IOHandler on the dummy model folder with batchsize 3.

        Its get_subfolder method is wrapped so that a request for the
        'predictions' subfolder returns self.pred_dir, while every other
        key is forwarded to the real implementation.
        """
        self.data_folder = os.path.join(os.path.dirname(__file__), "data")
        self.output_folder = self.data_folder + "/dummy_model"

        # list file is given explicitly, no config file
        cfg = Configuration(
            self.output_folder,
            self.data_folder + "/in_out_test_list.toml",
            None,
        )
        self.batchsize = 3
        cfg.batchsize = self.batchsize
        self.io = IOHandler(cfg)

        # keep a handle on the real method before it gets replaced
        real_get_subfolder = self.io.get_subfolder
        predictions_dir = self.pred_dir

        def redirect_predictions(key, create=False):
            if key != 'predictions':
                return real_get_subfolder(key, create)
            return predictions_dir

        self.io.get_subfolder = MagicMock(side_effect=redirect_predictions)
Example #4
0
class Organizer:
    """
    Core class for working with networks in OrcaNet.

    Attributes
    ----------
    cfg : orcanet.core.Configuration
        Contains all configurable options.
    io : orcanet.in_out.IOHandler
        Utility functions for accessing the info in cfg.
    history : orcanet.in_out.HistoryHandler
        For reading and plotting data from the log files created
        during training.

    """
    def __init__(self, output_folder,
                 list_file=None,
                 config_file=None,
                 tf_log_level=None,
                 discover_tomls=True):
        """
        Create the configuration, IO and history helpers for one model.

        The settings end up in ``self.cfg``; besides a config file they
        can also be adjusted afterwards by assigning to its attributes
        (e.g. ``orga.cfg.batchsize``).

        Parameters
        ----------
        output_folder : str
            Folder of this model. Passed on to the Configuration and the
            HistoryHandler, and used as the place to search for toml
            files that were not given explicitly.
        list_file : str, optional
            Path to a toml list file with the paths of the h5 files used
            for training and validation.
            If None and discover_tomls is True, 'list.toml' is searched
            for in output_folder.
        config_file : str, optional
            Path to a toml config file overriding the default settings.
            If None and discover_tomls is True, 'config.toml' is searched
            for in output_folder.
        tf_log_level : int/str, optional
            If given, its string form is written to the
            TF_CPP_MIN_LOG_LEVEL environment variable.
            0 = all messages are logged (default behavior).
            1 = INFO messages are not printed.
            2 = INFO and WARNING messages are not printed.
            3 = INFO, WARNING, and ERROR messages are not printed.
        discover_tomls : bool
            If False, never search output_folder for toml files, even if
            list_file or config_file is None [Default: True].

        """
        if tf_log_level is not None:
            os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(tf_log_level)

        if discover_tomls:
            # Only the toml files that were not passed in are looked up.
            if list_file is None:
                list_file = orcanet.misc.find_file(output_folder, "list.toml")
            if config_file is None:
                config_file = orcanet.misc.find_file(
                    output_folder, "config.toml")

        # Caches that are filled on demand: xs_mean by get_xs_mean(),
        # _auto_label_modifier by _set_up(), _stored_model by
        # train_and_validate() and _strategy by get_strategy().
        self.xs_mean = None
        self._auto_label_modifier = None
        self._stored_model = None
        self._strategy = None

        # The IO handler is built on top of the configuration object.
        self.cfg = Configuration(output_folder, list_file, config_file)
        self.io = IOHandler(self.cfg)
        self.history = HistoryHandler(output_folder)

    def train_and_validate(self, model=None, epochs=None, to_epoch=None):
        """
        Train a model and validate according to schedule.

        The various settings of this process can be controlled with the
        attributes of orca.cfg.
        The model will be trained on the given data, saved and validated.
        Logfiles of the training are saved in the output folder.
        Plots showing the training and validation history, as well as
        the weights and activations of the network are generated in
        the plots subfolder after every validation.
        The training can be resumed by executing this function again.

        Parameters
        ----------
        model : ks.models.Model or str, optional
            Compiled keras model to use for training. Required for the first
            epoch (the start of training).
            Can also be the path to a saved keras model, which will be laoded.
            If model is None, the most recent saved model will be
            loaded automatically to continue the training.
        epochs : int, optional
            How many epochs should be trained by running this function.
            None for infinite. This includes the current epoch in case it
            is not finished yet, i.e. 1 means complete the epoch if there
            are files left, otherwise do the next epoch.
        to_epoch : int, optional
            Train up to and including this epoch. Can not be used together with
            epochs.

        Returns
        -------
        model : ks.models.Model
            The trained keras model.

        """
        latest_epoch = self.io.get_latest_epoch()

        model = self._get_model(model, logging=False)
        self._stored_model = model

        # check if the validation is missing for the latest fileno
        if latest_epoch is not None:
            state = self.history.get_state()[-1]
            if state["is_validated"] is False and self.val_is_due(latest_epoch):
                self.validate()

        next_epoch = self.io.get_next_epoch(latest_epoch)
        n_train_files = self.io.get_no_of_files("train")

        if to_epoch is None:
            epochs_left = epochs
        else:
            if epochs is not None:
                raise ValueError("Can not give both 'epochs' and 'to_epoch'")
            if latest_epoch is None:
                epochs_left = to_epoch
            else:
                epochs_left = max(
                    0, to_epoch - self.io.get_next_epoch(latest_epoch)[0] + 1)

        trained_epochs = 0
        while epochs_left is None or trained_epochs < epochs_left:
            # Train on remaining files
            for file_no in range(next_epoch[1], n_train_files + 1):
                curr_epoch = (next_epoch[0], file_no)
                self.train(model)
                if self.val_is_due(curr_epoch):
                    self.validate()

            next_epoch = (next_epoch[0] + 1, 1)
            trained_epochs += 1

        self._stored_model = None
        return model

    def train(self, model=None):
        """
        Trains a model on the next file.

        The progress of the training is also logged and plotted: the
        result is written to the summary logger, printed via
        io.print_log, and the summary plot is updated afterwards.
        The trained model is saved to the path given by
        io.get_model_path for the trained (epoch, file number).

        Parameters
        ----------
        model : ks.models.Model or str, optional
            Compiled keras model to use for training. Required for the first
            epoch (the start of training).
            Can also be the path to a saved keras model, which will be loaded.
            If model is None, the most recent saved model will be
            loaded automatically to continue the training.

        Returns
        -------
        history : dict
            The history of the training on this file. A record of training
            loss values and metrics values.

        Raises
        ------
        FileExistsError
            If a file already exists at the path the model of this
            (epoch, file number) would be saved to.

        """
        # Create folder structure
        self.io.get_subfolder(create=True)
        latest_epoch = self.io.get_latest_epoch()

        model = self._get_model(model, logging=True)

        self._set_up(model, logging=True)

        # epoch about to be trained
        next_epoch = self.io.get_next_epoch(latest_epoch)
        next_epoch_float = self.io.get_epoch_float(*next_epoch)

        # No saved model yet, i.e. this is the very start of the training.
        if latest_epoch is None:
            self.io.check_connections(model)
            logging.log_start_training(self)

        model_path = self.io.get_model_path(*next_epoch)
        model_path_local = self.io.get_model_path(*next_epoch, local=True)
        # Refuse to overwrite an already saved model.
        if os.path.isfile(model_path):
            raise FileExistsError(
                "Can not train model in epoch {} file {}, this model has "
                "already been saved!".format(*next_epoch))

        smry_logger = logging.SummaryLogger(self, model)

        # Only touch the optimizer's learning rate if one is configured;
        # otherwise the model keeps whatever rate it currently has.
        if self.cfg.learning_rate is not None:
            tf.keras.backend.set_value(
                model.optimizer.lr, self.io.get_learning_rate(next_epoch)
            )

        files_dict = self.io.get_file("train", next_epoch[1])

        # Header of the log entry, underlined with dashes of equal length.
        line = "Training in epoch {} on file {}/{}".format(
            next_epoch[0], next_epoch[1], self.io.get_no_of_files("train"))
        self.io.print_log(line)
        self.io.print_log("-" * len(line))
        self.io.print_log("Learning rate is at {}".format(
            tf.keras.backend.get_value(model.optimizer.lr)))
        self.io.print_log('Inputs and files:')
        for input_name, input_file in files_dict.items():
            self.io.print_log("   {}: \t{}".format(input_name,
                                                   os.path.basename(
                                                       input_file)))

        # Time only the actual training, truncated to whole seconds.
        start_time = time.time()
        history = backend.train_model(self, model, next_epoch, batch_logger=True)
        elapsed_s = int(time.time() - start_time)

        # Save the model first, then record the result in the summary.
        model.save(model_path)
        smry_logger.write_line(
            next_epoch_float,
            tf.keras.backend.get_value(model.optimizer.lr),
            history_train=history,
        )

        self.io.print_log('Training results:')
        for metric_name, loss in history.items():
            self.io.print_log(f"   {metric_name}: \t{loss}")
        self.io.print_log(f"Elapsed time: {timedelta(seconds=elapsed_s)}")
        self.io.print_log(f"Saved model to: {model_path_local}\n")

        update_summary_plot(self)
        # Optionally delete saved models that are no longer needed.
        if self.cfg.cleanup_models:
            self.cleanup_models()

        return history

    def validate(self):
        """
        Validate the most recent saved model on all validation files.

        Will also log the progress, as well as update the summary plot and
        plot weights and activations of the model.

        Returns
        -------
        history : dict
            The history of the validation on all files. A record of validation
            loss values and metrics values.

        """
        latest_epoch = self.io.get_latest_epoch()
        if latest_epoch is None:
            raise ValueError("Can not validate: No saved model found")
        if self.history.get_state()[-1]["is_validated"] is True:
            raise ValueError("Can not validate in epoch {} file {}: "
                             "Has already been validated".format(*latest_epoch))

        if self._stored_model is None:
            model = self.load_saved_model(*latest_epoch)
        else:
            model = self._stored_model

        self._set_up(model, logging=True)

        epoch_float = self.io.get_epoch_float(*latest_epoch)
        smry_logger = logging.SummaryLogger(self, model)

        logging.log_start_validation(self)

        start_time = time.time()
        history = backend.validate_model(self, model)
        elapsed_s = int(time.time() - start_time)

        self.io.print_log('Validation results:')
        for metric_name, loss in history.items():
            self.io.print_log(f"   {metric_name}: \t{loss}")
        self.io.print_log(f"Elapsed time: {timedelta(seconds=elapsed_s)}\n")
        smry_logger.write_line(epoch_float, "n/a", history_val=history)

        update_summary_plot(self)

        if self.cfg.cleanup_models:
            self.cleanup_models()

        return history

    def predict(self, epoch=None, fileno=None, samples=None):
        """
        Make a prediction if it does not exist yet, and return its filepath.

        Load the model with the lowest validation loss, let it predict on
        all samples of the validation set
        in the toml list, and save this prediction together with all the
        y_values as h5 file(s) in the predictions subfolder.

        Parameters
        ----------
        epoch : int, optional
            Epoch of a model to load. Default: lowest val loss.
        fileno : int, optional
            File number of a model to load. Default: lowest val loss.
        samples : int, optional
            Don't use the full validation files, but just the given number
            of samples.

        Returns
        -------
        pred_filename : List
            List to the paths of all the prediction files.

        """
        if fileno is None and epoch is None:
            epoch, fileno = self.history.get_best_epoch_fileno()
            print(f"Automatically set epoch to epoch {epoch} file {fileno}.")
        elif fileno is None or epoch is None:
            raise ValueError(
                "Either both or none of epoch and fileno must be None")

        if self._check_if_pred_already_done(epoch, fileno):
            print("Prediction has already been done.")
            pred_filepaths = self.io.get_pred_files_list(epoch, fileno)

        else:
            if self._stored_model is None:
                model = self.load_saved_model(epoch, fileno, logging=False)
            else:
                model = self._stored_model
            self._set_up(model)

            start_time = time.time()
            backend.make_model_prediction(
                self, model, epoch, fileno, samples=samples)
            elapsed_s = int(time.time() - start_time)
            print('Finished predicting on all validation files.')
            print("Elapsed time: {}\n".format(timedelta(seconds=elapsed_s)))

            pred_filepaths = self.io.get_pred_files_list(epoch, fileno)

        return pred_filepaths

    def inference(self, epoch=None, fileno=None, as_generator=False):
        """
        Make an inference and return the filepaths.

        Load the model with the lowest validation loss, let
        it predict on all samples of all inference files
        in the toml list, and save these predictions as h5 files in the
        predictions subfolder. y values will only be added if they are in
        the input file, so this can be used on un-labeled data as well.

        Parameters
        ----------
        epoch : int, optional
            Epoch of a model to load [default: lowest val loss].
        fileno : int, optional
            File number of a model to load [default: lowest val loss].
        as_generator : bool
            If true, return a generator, which yields the output filename
            after the inference of each file.
            If false (default), do all files back to back.

        Returns
        -------
        filenames : list
            List to the paths of all created output files.

        """
        if fileno is None and epoch is None:
            epoch, fileno = self.history.get_best_epoch_fileno()
            print("Automatically set epoch to epoch {} file {}.".format(epoch, fileno))
        elif fileno is None or epoch is None:
            raise ValueError(
                "Either both or none of epoch and fileno must be None")

        if self._stored_model is None:
            model = self.load_saved_model(epoch, fileno, logging=False)
        else:
            model = self._stored_model
        self._set_up(model)

        filenames = []
        for files_dict in self.io.yield_files("inference"):
            # output filename is based on name of file in first input
            first_filename = os.path.basename(list(files_dict.values())[0])
            output_filename = "model_epoch_{}_file_{}_on_{}".format(
                epoch, fileno, first_filename)

            output_path = os.path.join(self.io.get_subfolder("inference"),
                                       output_filename)
            filenames.append(output_path)
            if os.path.exists(output_path):
                warnings.warn("Warning: {} exists already, skipping "
                              "file".format(output_filename))
                continue

            print(f'Working on file {first_filename}')
            start_time = time.time()
            backend.h5_inference(
                self, model, files_dict, output_path, use_def_label=False)
            elapsed_s = int(time.time() - start_time)
            print(f'Finished on file {first_filename} in {elapsed_s/60} min')
            if as_generator:
                yield output_path

        return filenames

    def cleanup_models(self):
        """
        Delete all models except for the the most recent one (to continue
        training), and the ones with the highest and lowest loss/metrics.

        """
        all_epochs = self.io.get_all_epochs()
        epochs_to_keep = {self.io.get_latest_epoch(), }
        try:
            for metric in self.history.get_metrics():
                epochs_to_keep.add(
                    self.history.get_best_epoch_fileno(
                        metric=f"val_{metric}", mini=True))
                epochs_to_keep.add(
                    self.history.get_best_epoch_fileno(
                        metric=f"val_{metric}", mini=False))
        except ValueError:
            # no best epoch exists
            pass

        for epoch in epochs_to_keep:
            if epoch not in all_epochs:
                warnings.warn(
                    f"ERROR: keeping_epoch {epoch} not in available epochs {all_epochs}, "
                    f"skipping clean-up of models!")
                return

        print("\nClean-up saved models:")
        for epoch in all_epochs:
            model_path = self.io.get_model_path(epoch[0], epoch[1])
            model_name = os.path.basename(model_path)
            if epoch in epochs_to_keep:
                print("Keeping model {}".format(model_name))
            else:
                print("Deleting model {}".format(model_name))
                os.remove(model_path)

    def _check_if_pred_already_done(self, epoch, fileno):
        """
        Checks if the prediction has already been done before.
        (-> predicted on all validation files)

        Returns
        -------
        pred_done : bool
            Boolean flag to specify if the prediction has
            already been fully done or not.

        """
        latest_pred_file_no = self.io.get_latest_prediction_file_no(epoch, fileno)
        total_no_of_val_files = self.io.get_no_of_files('val')

        if latest_pred_file_no is None:
            pred_done = False
        elif latest_pred_file_no == total_no_of_val_files:
            return True
        else:
            pred_done = False

        return pred_done

    def get_xs_mean(self, logging=False):
        """
        Set and return the zero center image for each list input.

        Requires the cfg.zero_center_folder to be set. If no existing
        image for the given input files is found in the folder, it will
        be calculated and saved by averaging over all samples in the
        train dataset.

        Parameters
        ----------
        logging : bool
            If true, the execution of this function will be logged into the
            full summary in the output folder if called for the first time.

        Returns
        -------
        dict
            Dict of numpy arrays that contains the mean_image of the x dataset
            (1 array per list input).
            Example format:
            { "input_A" : ndarray, "input_B" : ndarray }

        """
        if self.xs_mean is None:
            if self.cfg.zero_center_folder is None:
                raise ValueError("Can not calculate zero center: "
                                 "No zero center folder given")
            self.xs_mean = load_zero_center_data(self, logging=logging)
        return self.xs_mean

    def load_saved_model(self, epoch, fileno, logging=False):
        """
        Load a saved model.

        Parameters
        ----------
        epoch : int
            Epoch of the saved model. If both this and fileno are -1,
            load the most recent model.
        fileno : int
            Fileno of the saved model.
        logging : bool
            If True, will log this function call into the log.txt file.

        Returns
        -------
        model : keras model

        """
        path_of_model = self.io.get_model_path(epoch, fileno)
        path_loc = self.io.get_model_path(epoch, fileno, local=True)
        self.io.print_log("Loading saved model: " + path_loc, logging=logging)
        return self._load_model(path_of_model)

    def _get_model(self, model, logging=False):
        """ Load most recent saved model or use user model. """
        latest_epoch = self.io.get_latest_epoch()

        if latest_epoch is None:
            # new training, log info about model
            if model is None:
                raise ValueError("You need to provide a compiled keras model "
                                 "for the start of the training! (You gave None)")

            elif isinstance(model, str):
                # path to a saved model
                self.io.print_log("Loading model from " + model, logging=logging)
                model = self._load_model(model)

            if logging:
                self._save_as_json(model)
                model.summary(print_fn=self.io.print_log)

                try:
                    plots_folder = self.io.get_subfolder("plots", create=True)
                    tf.keras.utils.plot_model(
                        model, plots_folder + "/model_plot.png", show_shapes=True)
                except (ImportError, AttributeError) as e:
                    # TODO remove AttributeError once https://github.com/tensorflow/tensorflow/issues/38988 is fixed
                    warnings.warn("Can not plot model: " + str(e))

        else:
            # resuming training, load model if it is not given
            if model is None:
                model = self.load_saved_model(*latest_epoch, logging=logging)

            elif isinstance(model, str):
                # path to a saved model
                self.io.print_log("Loading model from " + model, logging=logging)
                model = self._load_model(model)

        return model

    def _load_model(self, filepath):
        """
        Load a keras model from filepath.

        Loading happens inside the scope of the strategy returned by
        get_strategy(), with the custom objects of the configuration.
        """
        strategy = self.get_strategy()
        with strategy.scope():
            return tf.keras.models.load_model(
                filepath, custom_objects=self.cfg.get_custom_objects())

    def _save_as_json(self, model):
        """ Save the architecture of a model as json to fixed path. """
        json_filename = "model_arch.json"

        json_string = model.to_json(indent=1)
        model_folder = self.io.get_subfolder("saved_models", create=True)
        with open(os.path.join(model_folder, json_filename), "w") as f:
            f.write(json_string)

    def _set_up(self, model, logging=False):
        """ Necessary setup for training, validating and predicting. """
        if self.cfg.get_list_file() is None:
            raise ValueError("No files specified. Need to load a toml "
                             "list file with pathes to h5 files first.")

        if self.cfg.label_modifier is None:
            self._auto_label_modifier = lib.label_modifiers.ColumnLabels(model)

        if self.cfg.zero_center_folder is not None:
            self.get_xs_mean(logging)

    def val_is_due(self, epoch=None):
        """
        True if validation is due on given epoch according to schedule.
        Does not check if it has been done already.

        """
        if epoch is None:
            epoch = self.io.get_latest_epoch()
        n_train_files = self.io.get_no_of_files("train")
        val_sched = (epoch[1] == n_train_files) or \
                    (self.cfg.validate_interval is not None and
                     epoch[1] % self.cfg.validate_interval == 0)
        return val_sched

    def get_strategy(self):
        """ Get the strategy for distributed training. """
        if self._strategy is None:
            if self.cfg.multi_gpu and len(
                    tf.config.list_physical_devices('GPU')) > 1:
                self._strategy = tf.distribute.MirroredStrategy()
                print(f'Number of GPUs: {self._strategy.num_replicas_in_sync}')
            else:
                self._strategy = tf.distribute.get_strategy()
        return self._strategy
Example #5
0
class TestIOHandler(TestCase):
    @classmethod
    def setUpClass(cls):
        """
        Create the temporary folders and dummy files shared by all tests.

        Makes the temp and predictions folders, touches two empty
        prediction files, changes the working directory to the temp folder
        (the previous one is kept in cls.init_dir) and passes the
        description of eight dummy files to save_dummy_h5py.

        """
        test_folder = os.path.dirname(__file__)
        cls.temp_dir = os.path.join(test_folder, ".temp", "test_in_out")
        cls.pred_dir = os.path.join(test_folder, ".temp", "test_in_out",
                                    "predictions")
        for new_dir in (cls.temp_dir, cls.pred_dir):
            os.mkdir(new_dir)

        # empty placeholder files in the predictions folder
        cls.pred_filepaths = [
            cls.pred_dir + '/pred_model_epoch_2_file_2_on_listname_val_file_1.h5',
            cls.pred_dir + '/pred_model_epoch_2_file_2_on_listname_val_file_2.h5',
        ]
        for pred_filepath in cls.pred_filepaths:
            Path(pred_filepath).touch()

        cls.init_dir = os.getcwd()
        os.chdir(cls.temp_dir)

        # parameters of the dummy data
        cls.n_bins = {'input_A': (2, 3), 'input_B': (2, 3)}
        cls.train_sizes = [30, 50]
        cls.val_sizes = [40, 60]
        cls.file_names = (
            "/input_A_train_1.h5",
            "/input_A_train_2.h5",
            "/input_B_train_1.h5",
            "/input_B_train_2.h5",
            "/input_A_val_1.h5",
            "/input_A_val_2.h5",
            "/input_B_val_1.h5",
            "/input_B_val_2.h5",
        )

        def file_specs(name_index, input_name, value_xs, value_ys, size):
            """ Keyword arguments for save_dummy_h5py describing one file. """
            return {
                "path": cls.temp_dir + cls.file_names[name_index],
                "shape": cls.n_bins[input_name],
                "value_xs": value_xs,
                "value_ys": value_ys,
                "size": size,
            }

        cls.train_A_file_1 = file_specs(0, "input_A", 1.1, 1.2, cls.train_sizes[0])
        cls.train_A_file_2 = file_specs(1, "input_A", 1.3, 1.4, cls.train_sizes[1])
        cls.train_B_file_1 = file_specs(2, "input_B", 2.1, 2.2, cls.train_sizes[0])
        cls.train_B_file_2 = file_specs(3, "input_B", 2.3, 2.4, cls.train_sizes[1])
        # NOTE(review): unlike the train files, both val files of an input
        # share the same values and the same size - confirm this is intended
        cls.val_A_file_1 = file_specs(4, "input_A", 3.1, 3.2, cls.val_sizes[0])
        cls.val_A_file_2 = file_specs(5, "input_A", 3.1, 3.2, cls.val_sizes[0])
        cls.val_B_file_1 = file_specs(6, "input_B", 4.1, 4.2, cls.val_sizes[1])
        cls.val_B_file_2 = file_specs(7, "input_B", 4.1, 4.2, cls.val_sizes[1])

        # keep what save_dummy_h5py returns for comparisons in the tests
        cls.train_A_file_1_ctnt = save_dummy_h5py(**cls.train_A_file_1)
        cls.train_A_file_2_ctnt = save_dummy_h5py(**cls.train_A_file_2)
        cls.train_B_file_1_ctnt = save_dummy_h5py(**cls.train_B_file_1)
        cls.train_B_file_2_ctnt = save_dummy_h5py(**cls.train_B_file_2)
        cls.val_A_file_1_ctnt = save_dummy_h5py(**cls.val_A_file_1)
        cls.val_A_file_2_ctnt = save_dummy_h5py(**cls.val_A_file_2)
        cls.val_B_file_1_ctnt = save_dummy_h5py(**cls.val_B_file_1)
        cls.val_B_file_2_ctnt = save_dummy_h5py(**cls.val_B_file_2)

    def setUp(self):
        """
        Build a fresh IOHandler on the dummy model folder for every test.

        The handler uses a batchsize of 3, and its get_subfolder is
        replaced so that 'predictions' points to the temporary
        predictions folder.

        """
        self.data_folder = os.path.join(os.path.dirname(__file__), "data")
        self.output_folder = self.data_folder + "/dummy_model"
        self.batchsize = 3

        cfg = Configuration(
            self.output_folder,
            self.data_folder + "/in_out_test_list.toml",
            None,
        )
        cfg.batchsize = self.batchsize
        self.io = IOHandler(cfg)

        # redirect only the 'predictions' key to the temp folder;
        # every other key is resolved by the real method
        real_get_subfolder = self.io.get_subfolder
        pred_dir = self.pred_dir

        def redirect_predictions(key, create=False):
            if key != 'predictions':
                return real_get_subfolder(key, create)
            return pred_dir

        self.io.get_subfolder = MagicMock(side_effect=redirect_predictions)

    @classmethod
    def tearDownClass(cls):
        """
        Remove the dummy files and the temp folder made in setUpClass.

        The working directory is restored and the temp folder is deleted
        even if removing one of the dummy files fails, so that an error
        here does not leave other tests running inside the temp folder.

        """
        dummy_files = (
            cls.train_A_file_1,
            cls.train_A_file_2,
            cls.train_B_file_1,
            cls.train_B_file_2,
            cls.val_A_file_1,
            cls.val_A_file_2,
            cls.val_B_file_1,
            cls.val_B_file_2,
        )
        try:
            for dummy_file in dummy_files:
                os.remove(dummy_file["path"])
        finally:
            # leave the temp folder before deleting it
            os.chdir(cls.init_dir)
            shutil.rmtree(cls.temp_dir)

    def test_copy_to_ssd(self):
        """
        With use_scratch_ssd set, get_local_files should give paths in
        the folder that the TMPDIR environment variable points to.

        """
        self.io.cfg.use_scratch_ssd = True
        scratch_dir = self.temp_dir + "/scratch"
        os.mkdir(scratch_dir)

        # remember the current TMPDIR (it is not defined in the gitrunner)
        previous_tmpdir = os.environ.get("TMPDIR")

        try:
            os.environ["TMPDIR"] = scratch_dir
            names = self.file_names

            expected_train = {
                "input_A": (scratch_dir + names[0], scratch_dir + names[1]),
                "input_B": (scratch_dir + names[2], scratch_dir + names[3]),
            }
            expected_val = {
                "input_A": (scratch_dir + names[4], scratch_dir + names[5]),
                "input_B": (scratch_dir + names[6], scratch_dir + names[7]),
            }

            self.assertDictEqual(
                expected_train, self.io.get_local_files("train"))
            self.assertDictEqual(
                expected_val, self.io.get_local_files("val"))

        finally:
            # put the env variable back to how it was
            if previous_tmpdir is None:
                os.environ.pop("TMPDIR")
            else:
                os.environ["TMPDIR"] = previous_tmpdir

            shutil.rmtree(scratch_dir)

    def test_check_connections_no_sample(self):
        """ check_connections should pass without setting a sample modifier. """
        out_shapes = {"out_A": 1, "out_B": 1}
        self.io.cfg.label_modifier = get_dummy_label_modifier(out_shapes.keys())

        dummy_model = build_dummy_model(self.n_bins, out_shapes)
        self.io.check_connections(dummy_model)

    def test_check_connections_ok_sample(self):
        """ check_connections should pass if the sample modifier gives both inputs. """
        out_shapes = {"out_A": 1, "out_B": 1}

        def sample_modifier(info_blob):
            xs = info_blob["x_values"]
            return {name: xs[name] for name in ("input_A", "input_B")}

        self.io.cfg.label_modifier = get_dummy_label_modifier(out_shapes.keys())
        self.io.cfg.sample_modifier = sample_modifier

        dummy_model = build_dummy_model(self.n_bins, out_shapes)
        self.io.check_connections(dummy_model)

    def test_check_connections_wrong_sample(self):
        """ check_connections should raise if the sample modifier gives only input_A. """
        out_shapes = {"out_A": 1, "out_B": 1}

        def sample_modifier(info_blob):
            # input_B is left out on purpose
            return {'input_A': info_blob["x_values"]["input_A"]}

        self.io.cfg.label_modifier = get_dummy_label_modifier(out_shapes.keys())
        self.io.cfg.sample_modifier = sample_modifier

        dummy_model = build_dummy_model(self.n_bins, out_shapes)
        with self.assertRaises(ValueError):
            self.io.check_connections(dummy_model)

    def test_check_connections_wrong_label(self):
        """ check_connections should raise if the label modifier covers only out_A. """
        out_shapes = {"out_A": 1, "out_B": 1}
        self.io.cfg.label_modifier = get_dummy_label_modifier(["out_A"])

        dummy_model = build_dummy_model(self.n_bins, out_shapes)
        with self.assertRaises(ValueError):
            self.io.check_connections(dummy_model)

    def test_check_connections_no_label(self):
        """ check_connections should raise for outputs out_A/out_B if no label modifier is set. """
        dummy_model = build_dummy_model(
            self.n_bins, {"out_A": 1, "out_B": 1})

        with self.assertRaises(ValueError):
            self.io.check_connections(dummy_model)

    def test_check_connections_auto_label(self):
        """ check_connections should pass for outputs mc_A/mc_B if no label modifier is set. """
        dummy_model = build_dummy_model(
            self.n_bins, {"mc_A": 1, "mc_B": 1})

        self.io.check_connections(dummy_model)

    def test_get_n_bins(self):
        """ get_n_bins should give the n_bins of the dummy files. """
        value = self.io.get_n_bins()
        # n_bins is a dict, so compare with assertEqual: assertSequenceEqual
        # indexes its arguments by position, which on a mismatch raises a
        # KeyError instead of reporting the difference
        self.assertEqual(value, self.n_bins)

    def test_get_input_shape(self):
        """
        get_input_shapes should give n_bins, and only the entry of
        input_A once a sample modifier returning just input_A is set.

        """
        value = self.io.get_input_shapes()
        # n_bins is a dict, so compare with assertEqual: assertSequenceEqual
        # indexes its arguments by position, which on a mismatch raises a
        # KeyError instead of reporting the difference
        self.assertEqual(value, self.n_bins)

        def sample_modifier(info_blob):
            x_values = info_blob["x_values"]
            return {'input_A': x_values["input_A"], }
        self.io.cfg.sample_modifier = sample_modifier

        value = self.io.get_input_shapes()
        self.assertEqual(value, {"input_A": self.n_bins["input_A"]})

    def test_get_file_sizes_train(self):
        """ get_file_sizes("train") should give the sizes of the dummy train files. """
        self.assertSequenceEqual(
            self.io.get_file_sizes("train"), self.train_sizes)

    def test_get_batch_xs(self):
        """ The x_values of get_batch should be the first batchsize samples of the first train files. """
        batch = self.io.get_batch()
        first_train_files = (
            ("input_A", self.train_A_file_1_ctnt),
            ("input_B", self.train_B_file_1_ctnt),
        )
        expected_xs = {
            input_name: content[0][:self.batchsize]
            for input_name, content in first_train_files
        }
        assert_dict_arrays_equal(batch["x_values"], expected_xs)

    def test_get_batch_ys(self):
        """ The y_values of get_batch should be the first batchsize entries of train file 1 of input_A. """
        value = self.io.get_batch()
        # only the second entry of the content of train file A 1 is
        # compared; the former dict with an unused input_B entry is gone
        target = self.train_A_file_1_ctnt[1][:self.batchsize]
        assert_equal_struc_array(value["y_values"], target)

    def test_get_all_epochs(self):
        """ get_all_epochs should give (1, 1), (1, 2) and (2, 1). """
        expected_epochs = [(1, 1), (1, 2), (2, 1)]
        self.assertSequenceEqual(self.io.get_all_epochs(), expected_epochs)

    def test_get_latest_epoch(self):
        """ get_latest_epoch should give (2, 1). """
        self.assertSequenceEqual(self.io.get_latest_epoch(), (2, 1))

    def test_get_latest_epoch_no_epochs(self):
        """ get_latest_epoch should give None for the output folder './missing/'. """
        self.io.cfg.output_folder = "./missing/"
        value = self.io.get_latest_epoch()
        # identity check: assertEqual would also accept any object that
        # merely compares equal to None
        self.assertIsNone(value)

    def test_get_next_epoch_none(self):
        """ get_next_epoch(None) should give (1, 1). """
        self.assertSequenceEqual(self.io.get_next_epoch(None), (1, 1))

    def test_get_next_epoch_1_1(self):
        """ get_next_epoch((1, 1)) should give (1, 2). """
        self.assertSequenceEqual(self.io.get_next_epoch((1, 1)), (1, 2))

    def test_get_next_epoch_1_2(self):
        """ get_next_epoch((1, 2)) should give (2, 1). """
        self.assertSequenceEqual(self.io.get_next_epoch((1, 2)), (2, 1))

    def test_get_previous_epoch_2_1(self):
        """ get_previous_epoch((2, 1)) should give (1, 2). """
        self.assertSequenceEqual(self.io.get_previous_epoch((2, 1)), (1, 2))

    def test_get_previous_epoch_1_2(self):
        """ get_previous_epoch((1, 2)) should give (1, 1). """
        self.assertSequenceEqual(self.io.get_previous_epoch((1, 2)), (1, 1))

    def test_get_model_path(self):
        """ get_model_path(1, 1) should point into saved_models of the output folder. """
        model_path = self.io.get_model_path(1, 1)
        self.assertEqual(
            model_path,
            self.output_folder + '/saved_models/model_epoch_1_file_1.h5',
        )

    def test_get_model_path_local(self):
        """ With local=True, get_model_path should give the path without the output folder. """
        self.assertEqual(
            self.io.get_model_path(1, 1, local=True),
            'saved_models/model_epoch_1_file_1.h5',
        )

    def test_get_model_path_latest(self):
        """ get_model_path(-1, -1) should give the model of epoch 2, file 1. """
        model_path = self.io.get_model_path(-1, -1)
        self.assertEqual(
            model_path,
            self.output_folder + '/saved_models/model_epoch_2_file_1.h5',
        )

    def test_get_model_path_latest_invalid(self):
        """ get_model_path should raise if only one of epoch and fileno is -1. """
        for epoch, fileno in ((1, -1), (-1, 1)):
            with self.assertRaises(ValueError):
                self.io.get_model_path(epoch, fileno)

    def test_get_pred_files_list(self):
        """ Without arguments, get_pred_files_list should give both dummy pred files. """
        self.assertEqual(self.io.get_pred_files_list(), self.pred_filepaths)

    def test_get_pred_files_list_epoch_given(self):
        """ For epoch=2, get_pred_files_list should give both dummy pred files. """
        self.assertEqual(
            self.io.get_pred_files_list(epoch=2), self.pred_filepaths)

    def test_get_pred_files_list_fileno_given(self):
        """ For fileno=2, get_pred_files_list should give both dummy pred files. """
        self.assertEqual(
            self.io.get_pred_files_list(fileno=2), self.pred_filepaths)

    def test_get_pred_files_list_epoch_fileno_given(self):
        """ For epoch=2 and fileno=2, get_pred_files_list should give both dummy pred files. """
        self.assertEqual(
            self.io.get_pred_files_list(epoch=2, fileno=2),
            self.pred_filepaths,
        )

    def test_get_pred_files_list_no_files(self):
        """ For epoch=3, get_pred_files_list should give an empty list. """
        self.assertEqual(self.io.get_pred_files_list(epoch=3), [])

    def test_get_latest_prediction_file_no(self):
        """ get_latest_prediction_file_no(2, 2) should give 2. """
        self.assertEqual(self.io.get_latest_prediction_file_no(2, 2), 2)

    def test_get_pred_path(self):
        """ get_pred_path(1, 2, 3) should give the matching file name in the predictions subfolder. """
        pred_path = self.io.get_pred_path(1, 2, 3)
        pred_file_name = '/pred_model_epoch_1_file_2_on_in_out_test_list_val_file_3.h5'
        self.assertEqual(
            pred_path, self.io.get_subfolder("predictions") + pred_file_name)

    def test_get_local_files_train(self):
        """ get_local_files("train") should give the two train files of each input. """
        expected_files = {
            'input_A': ('input_A_train_1.h5', 'input_A_train_2.h5'),
            'input_B': ('input_B_train_1.h5', 'input_B_train_2.h5'),
        }
        self.assertDictEqual(self.io.get_local_files("train"), expected_files)

    def test_get_local_files_val(self):
        """ get_local_files("val") should give the two val files of each input. """
        expected_files = {
            'input_A': ('input_A_val_1.h5', 'input_A_val_2.h5'),
            'input_B': ('input_B_val_1.h5', 'input_B_val_2.h5'),
        }
        self.assertDictEqual(self.io.get_local_files("val"), expected_files)

    def test_get_no_of_files_train(self):
        """ get_no_of_files("train") should give 2. """
        self.assertEqual(self.io.get_no_of_files("train"), 2)

    def test_get_no_of_files_val(self):
        """ get_no_of_files("val") should give 2. """
        self.assertEqual(self.io.get_no_of_files("val"), 2)

    def test_yield_files_train(self):
        """ yield_files("train") should yield exactly the two expected dicts of train files. """
        target = (
            {
                'input_A': 'input_A_train_1.h5',
                'input_B': 'input_B_train_1.h5',
            },
            {
                'input_A': 'input_A_train_2.h5',
                'input_B': 'input_B_train_2.h5',
            },
        )
        yielded = list(self.io.yield_files("train"))
        # check the number of items as well: only looping over the
        # generator would pass without any comparison if it yielded nothing
        self.assertEqual(len(yielded), len(target))
        for value, expected in zip(yielded, target):
            self.assertDictEqual(value, expected)

    def test_yield_files_val(self):
        """ yield_files("val") should yield exactly the two expected dicts of val files. """
        target = (
            {
                'input_A': 'input_A_val_1.h5',
                'input_B': 'input_B_val_1.h5',
            },
            {
                'input_A': 'input_A_val_2.h5',
                'input_B': 'input_B_val_2.h5',
            },
        )
        yielded = list(self.io.yield_files("val"))
        # check the number of items as well: only looping over the
        # generator would pass without any comparison if it yielded nothing
        self.assertEqual(len(yielded), len(target))
        for value, expected in zip(yielded, target):
            self.assertDictEqual(value, expected)

    def test_get_file_train(self):
        """ get_file("train", 2) should give the second train file of each input. """
        second_train_file = self.io.get_file("train", 2)
        self.assertDictEqual(
            second_train_file,
            {
                'input_A': 'input_A_train_2.h5',
                'input_B': 'input_B_train_2.h5',
            },
        )

    def test_get_file_val(self):
        """ get_file("val", 1) should give the first val file of each input. """
        first_val_file = self.io.get_file("val", 1)
        self.assertDictEqual(
            first_val_file,
            {
                'input_A': 'input_A_val_1.h5',
                'input_B': 'input_B_val_1.h5',
            },
        )