Example #1
    def __init__(self, boundaries, scheduler_settings, name=None):
        """MultiLearningRateScheduler

        The instantiated scheduler is a 1-arg callable that computes the
        multi learning rate schedule when passed the current optimizer step.

        Parameters
        ----------
        boundaries
            A list of `Tensor`s or `int`s or `float`s with strictly
            increasing entries, and with all elements having the same type as
            the optimizer step.
        scheduler_settings : list of dict
            A list of scheduler settings that specify the learning rate
            schedules to use for the intervals defined by `boundaries`.
            It should have one more element than `boundaries`, and all
            schedulers should return the same type.
            Each scheduler_setting dict should contain the following:
                'full_class_string': str
                    The full class string of the scheduler class
                'settings': dict
                    A dictionary of arguments that are passed on to
                    the scheduler class
        name
            A string. Optional name of the operation. Defaults to
            'MultiLearningRateScheduler'.
        """

        super(MultiLearningRateScheduler, self).__init__()

        if len(boundaries) != len(scheduler_settings) - 1:
            raise ValueError(
              "The length of boundaries should be 1 less than the length "
              "of scheduler_settings")

        # create schedulers
        schedulers = []
        for settings in scheduler_settings:
            scheduler_class = misc.load_class(settings['full_class_string'])
            scheduler = scheduler_class(**settings['settings'])
            schedulers.append(scheduler)

        if name is None:
            name = 'MultiLearningRateScheduler'

        self.boundaries = tf.convert_to_tensor(boundaries)
        self.scheduler_settings = scheduler_settings
        self.schedulers = schedulers
        self.name = name
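
A minimal usage sketch (hypothetical values; the scheduler class strings
and their settings depend on the project configuration):

    scheduler = MultiLearningRateScheduler(
        boundaries=[10000],
        scheduler_settings=[
            {
                'full_class_string':
                    'tensorflow.keras.optimizers.schedules.ExponentialDecay',
                'settings': {
                    'initial_learning_rate': 1e-3,
                    'decay_steps': 1000,
                    'decay_rate': 0.96,
                },
            },
            {
                'full_class_string':
                    'tensorflow.keras.optimizers.schedules.ExponentialDecay',
                'settings': {
                    'initial_learning_rate': 1e-4,
                    'decay_steps': 1000,
                    'decay_rate': 0.9,
                },
            },
        ],
    )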
Example #2
    def _get_label_meta_data(self):
        """Loads labels from a sample file to obtain label meta data.
        """
        class_string = 'dnn_reco.modules.data.labels.{}.{}'.format(
            self._config['data_handler_label_file'],
            self._config['data_handler_label_name'],
        )
        label_reader = misc.load_class(class_string)
        labels, label_names = label_reader(self.test_input_data[0],
                                           self._config)

        self.label_names = label_names
        self.label_name_dict = {n: i for i, n in enumerate(label_names)}
        self.label_shape = list(labels.shape[1:])
        self.num_labels = int(np.prod(self.label_shape))
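
All of these examples resolve classes and functions from fully qualified
class strings via misc.load_class. The project's actual helper may differ,
but a minimal importlib-based sketch of such a loader looks like:

    import importlib

    def load_class(full_class_string):
        """Resolve a 'package.module.ClassName' string to the class."""
        module_path, class_name = full_class_string.rsplit('.', 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)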
Example #3
    def _build_model(self):
        """Build neural network architecture.
        """
        class_string = 'dnn_reco.modules.models.{}.{}'.format(
                                self.config['model_file'],
                                self.config['model_name'],
                                )
        nn_model = misc.load_class(class_string)

        print('\n----------------------')
        print('Now Building Model ...')
        print('----------------------\n')

        y_pred_trafo, y_unc_trafo, model_vars_pred, model_vars_unc = nn_model(
                                        is_training=self.is_training,
                                        config=self.config,
                                        data_handler=self.data_handler,
                                        data_transformer=self.data_transformer,
                                        shared_objects=self.shared_objects)

        # transform back
        y_pred = self.data_transformer.inverse_transform(y_pred_trafo,
                                                         data_type='label')
        y_unc = self.data_transformer.inverse_transform(y_unc_trafo,
                                                        data_type='label',
                                                        bias_correction=False)

        self.shared_objects['y_pred_trafo'] = y_pred_trafo
        self.shared_objects['y_unc_trafo'] = y_unc_trafo
        self.shared_objects['y_pred'] = y_pred
        self.shared_objects['y_unc'] = y_unc
        self.shared_objects['model_vars_pred'] = model_vars_pred
        self.shared_objects['model_vars_unc'] = model_vars_unc
        self.shared_objects['model_vars'] = model_vars_pred + model_vars_unc

        y_pred_list = tf.unstack(self.shared_objects['y_pred'], axis=1)
        for i, name in enumerate(self.data_handler.label_names):
            tf.compat.v1.summary.histogram('y_pred_' + name, y_pred_list[i])

        # count number of trainable parameters
        print('Number of free parameters in NN model: {}\n'.format(
                    self.count_parameters(self.shared_objects['model_vars'])))

        # create saver
        self.saver = tf.compat.v1.train.Saver(self.shared_objects['model_vars'])
Example #4
    def _get_misc_meta_data(self):
        """Loads misc data from a sample file to obtain misc meta data.
        """
        class_string = 'dnn_reco.modules.data.misc.{}.{}'.format(
            self._config['data_handler_misc_file'],
            self._config['data_handler_misc_name'],
        )
        misc_reader = misc.load_class(class_string)
        misc_data, misc_names = misc_reader(self.test_input_data[0],
                                            self._config)

        self.misc_names = misc_names
        self.misc_name_dict = {n: i for i, n in enumerate(misc_names)}
        if misc_data is None:
            self.misc_data_exists = False
            self.misc_shape = None
            self.num_misc = 0
        else:
            self.misc_data_exists = True
            self.misc_shape = list(misc_data.shape[1:])
            self.num_misc = int(np.prod(self.misc_shape))
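
The label and misc readers loaded in Examples #2 and #4 share the same
contract: called with a sample file and the configuration, they return a
(data, names) tuple, where data has shape [batch, ...] and names lists one
name per column (misc readers may return None for data). A hypothetical
reader satisfying this contract (the 'label_key' config entry is made up
for illustration):

    import pandas as pd

    def dummy_label_reader(input_data, config, label_names=None,
                           *args, **kwargs):
        """Hypothetical reader: returns (labels, label_names)."""
        with pd.HDFStore(input_data, mode='r') as f:
            df = f[config['label_key']]  # 'label_key' is hypothetical
        if label_names is None:
            label_names = list(df.columns)
        return df[label_names].to_numpy(), label_names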
Example #5
    def _create_event_weights(self):
        """Create event weights
        """
        if ('event_weight_file' in self.config and
                self.config['event_weight_file'] is not None):

            # get event weight function
            class_string = 'dnn_reco.modules.data.event_weights.{}.{}'.format(
                self.config['event_weight_file'],
                self.config['event_weight_name'])
            event_weight_function = misc.load_class(class_string)

            # compute event weights
            self.shared_objects['event_weights'] = event_weight_function(
                                    config=self.config,
                                    data_handler=self.data_handler,
                                    data_transformer=self.data_transformer,
                                    shared_objects=self.shared_objects)

            shape = self.shared_objects['event_weights'].get_shape().as_list()
            assert len(shape) == 2 and shape[1] == 1, \
                'Expected shape [-1, 1] but got {!r}'.format(shape)
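
Given the assertion above, any event weight function is expected to return
a tensor of shape [batch, 1]: one weight per event. A hypothetical module
satisfying this contract (weighting every event equally, and assuming
shared_objects['y_true'] holds the label tensor, as in Example #6):

    import tensorflow as tf

    def dummy_event_weights(config, data_handler, data_transformer,
                            shared_objects, *args, **kwargs):
        """Hypothetical event weight function: uniform weights.

        Returns a tensor of shape [batch, 1] as required by the
        assertion above.
        """
        # the slice keeps the static shape [None, 1] so the assert passes
        return tf.ones_like(shared_objects['y_true'][:, :1])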
Example #6
    def fit(self, num_training_iterations, train_data_generator,
            val_data_generator,
            evaluation_methods=None,
            *args, **kwargs):
        """Trains the NN model with the data provided by the data iterators.

        Parameters
        ----------
        num_training_iterations : int
            The number of training iterations to perform.
        train_data_generator : generator object
            A python generator object which generates batches of training data.
        val_data_generator : generator object
            A python generator object which generates batches of validation
            data.
        evaluation_methods : None, optional
            Currently unused.
        *args
            Variable length argument list.
        **kwargs
            Arbitrary keyword arguments.

        Raises
        ------
        ValueError
            If the model has not been compiled prior to calling fit.
        """
        if not self._model_is_compiled:
            raise ValueError(
                        'Model must be compiled prior to call of fit method')

        # training operations to run
        train_ops = {'optimizer_{:03d}'.format(i): opt for i, opt in
                     enumerate(self.shared_objects['optimizer_ops'])}

        # add parameters and op if label weights are to be updated
        if self.config['label_update_weights']:

            label_weight_n = 0.
            label_weight_mean = np.zeros(self.data_handler.label_shape)
            label_weight_M2 = np.zeros(self.data_handler.label_shape)

            train_ops['mse_values_trafo'] = \
                self.shared_objects['mse_values_trafo']

        # add op if tukey scaling is to be applied
        if self.config['label_scale_tukey']:
            train_ops['y_diff_trafo'] = self.shared_objects['y_diff_trafo']

        # ----------------
        # training loop
        # ----------------
        start_time = timeit.default_timer()
        for i in range(num_training_iterations):

            feed_dict = self._feed_placeholders(train_data_generator,
                                                is_validation=False)
            train_result = self.sess.run(train_ops,
                                         feed_dict=feed_dict)

            # -------------------------------------
            # calculate variables for Tukey scaling
            # -------------------------------------
            if self.config['label_scale_tukey']:
                batch_median_abs_dev = np.median(
                                np.abs(train_result['y_diff_trafo']), axis=0)

                # assign new label weight updates
                feed_dict_assign = {
                    self.shared_objects['new_median_abs_dev_values']:
                        np.clip(batch_median_abs_dev, 1e-6, float('inf'))}

                self.sess.run(
                    self.shared_objects['assign_new_median_abs_dev_values'],
                    feed_dict=feed_dict_assign)

            # --------------------------------------------
            # calculate online variables for label weights
            # --------------------------------------------
            if self.config['label_update_weights']:
                mse_values_trafo = train_result['mse_values_trafo']
                mse_values_trafo[~self.shared_objects['non_zero_mask']] = 1.

                if np.isfinite(mse_values_trafo).all():
                    label_weight_n += 1
                    delta = mse_values_trafo - label_weight_mean
                    label_weight_mean += delta / label_weight_n
                    delta2 = mse_values_trafo - label_weight_mean
                    label_weight_M2 += delta * delta2
                else:
                    misc.print_warning('Found NaNs: {}'.format(
                                       mse_values_trafo))
                    for j, name in enumerate(self.data_handler.label_names):
                        print(name, mse_values_trafo[j])

                if not np.isfinite(label_weight_mean).all():
                    for j, name in enumerate(self.data_handler.label_names):
                        print('weight', name, label_weight_mean[j])

                if not np.isfinite(mse_values_trafo).all():
                    raise ValueError('FOUND NANS!')

                # every n steps: update label_weights
                if i % self.config['validation_frequency'] == 0:
                    new_weights = 1.0 / (np.sqrt(
                                    np.abs(label_weight_mean) + 1e-6) + 1e-3)
                    new_weights[new_weights < 1] = 1
                    new_weights *= self.shared_objects['label_weight_config']

                    # assign new label weight updates
                    feed_dict_assign = {
                        self.shared_objects['new_label_weight_values']:
                            new_weights}

                    self.sess.run(
                            self.shared_objects['assign_new_label_weights'],
                            feed_dict=feed_dict_assign)

                    # reset values
                    label_weight_n = 0.
                    label_weight_mean = np.zeros(self.data_handler.label_shape)
                    label_weight_M2 = np.zeros(self.data_handler.label_shape)

            # ----------------
            # validate performance
            # ----------------
            if i % self.config['validation_frequency'] == 0:

                updated_weights = self.sess.run(
                                        self.shared_objects['label_weights'])

                eval_dict = {
                    'merged_summary': self._merged_summary,
                    'weights': self.shared_objects['label_weights_benchmark'],
                    'rmse_trafo': self.shared_objects['rmse_values_trafo'],
                    'y_pred': self.shared_objects['y_pred'],
                    'y_unc': self.shared_objects['y_unc'],
                    'y_pred_trafo': self.shared_objects['y_pred_trafo'],
                    'y_unc_trafo': self.shared_objects['y_unc_trafo'],
                }
                result_msg = ''
                for k, loss in sorted(
                        self.shared_objects['label_loss_dict'].items()):
                    if k.startswith('loss_sum_'):
                        eval_dict[k] = loss
                        result_msg += k + ': {' + k + ':2.3f} '

                # -------------------------------------
                # Test performance on training data
                # -------------------------------------
                feed_dict_train = self._feed_placeholders(train_data_generator,
                                                          is_validation=True)
                results_train = self.sess.run(eval_dict,
                                              feed_dict=feed_dict_train)

                # -------------------------------------
                # Test performance on validation data
                # -------------------------------------
                feed_dict_val = self._feed_placeholders(val_data_generator,
                                                        is_validation=True)
                results_val = self.sess.run(eval_dict, feed_dict=feed_dict_val)
                y_true_train = feed_dict_train[self.shared_objects['y_true']]
                y_true_val = feed_dict_val[self.shared_objects['y_true']]
                y_true_trafo_train = self.data_transformer.transform(
                                        y_true_train, data_type='label')
                y_true_trafo_val = self.data_transformer.transform(
                                        y_true_val, data_type='label')

                self._train_writer.add_summary(
                                        results_train['merged_summary'], i)
                self._val_writer.add_summary(results_val['merged_summary'], i)
                msg = 'Step: {:08d}, Runtime: {:2.2f}s, Benchmark: {:3.3f}'
                print(msg.format(i, timeit.default_timer() - start_time,
                                 np.sum(updated_weights)))
                print('\t[Train]      '+result_msg.format(**results_train))
                print('\t[Validation] '+result_msg.format(**results_val))

                # print info for each label
                for name, index in sorted(
                        self.data_handler.label_name_dict.items()):
                    if updated_weights[index] > 0:

                        unc_pull_train = np.std(
                            (results_train['y_pred_trafo'][:, index]
                             - y_true_trafo_train[:, index]) /
                            results_train['y_unc_trafo'][:, index], ddof=1)
                        unc_pull_val = np.std(
                            (results_val['y_pred_trafo'][:, index]
                             - y_true_trafo_val[:, index]) /
                            results_val['y_unc_trafo'][:, index], ddof=1)

                        msg = '\tweight: {weight:2.3f},'
                        msg += ' train: {train:2.3f} [{unc_pull_train:1.2f}],'
                        msg += ' val: {val:2.3f} [{unc_pull_val:2.2f}] [{name}'
                        msg += ', mean: {mean_train:2.3f} {mean_val:2.3f}]'
                        print(msg.format(
                            weight=updated_weights[index],
                            train=results_train['rmse_trafo'][index],
                            val=results_val['rmse_trafo'][index],
                            name=name,
                            mean_train=np.mean(y_true_train[:, index]),
                            mean_val=np.mean(y_true_val[:, index]),
                            unc_pull_train=unc_pull_train,
                            unc_pull_val=unc_pull_val,
                            ))

                # Call user defined evaluation method
                if self.config['evaluation_file'] is not None:
                    class_string = 'dnn_reco.modules.evaluation.{}.{}'.format(
                                self.config['evaluation_file'],
                                self.config['evaluation_name'],
                                )
                    eval_func = misc.load_class(class_string)
                    eval_func(feed_dict_train=feed_dict_train,
                              feed_dict_val=feed_dict_val,
                              results_train=results_train,
                              results_val=results_val,
                              config=self.config,
                              data_handler=self.data_handler,
                              data_transformer=self.data_transformer,
                              shared_objects=self.shared_objects)

            # ----------------
            # save models
            # ----------------
            if i % self.config['save_frequency'] == 0:
                if self.config['model_save_model']:
                    self._save_training_config(i)
                    self.saver.save(
                            sess=self.sess,
                            global_step=self._step_offset + i,
                            save_path=self.config['model_checkpoint_path'])
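
The label_weight_n / label_weight_mean / label_weight_M2 triplet in the
training loop above is Welford's online algorithm for running means and
variances, accumulated per label between weight updates. A standalone
sketch of one update step (note that only the running mean feeds the
weight update shown above; M2 is accumulated but not used in this
excerpt):

    def welford_update(n, mean, M2, x):
        """One step of Welford's online mean/variance update."""
        n += 1
        delta = x - mean
        mean += delta / n
        delta2 = x - mean
        M2 += delta * delta2
        return n, mean, M2

    # after k updates: `mean` holds the running mean of the observed
    # values and M2 / (k - 1) their unbiased sample variance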
Example #7
    def _get_optimizers_and_loss(self):
        """Get optimizers and loss terms as defined in config.

        Raises
        ------
        ValueError
            If the shape of a label loss does not match the label shape
            of the data handler.
        """
        optimizer_dict = dict(self.config['model_optimizer_dict'])

        # create empty list to hold tensorflow optimizer operations
        optimizer_ops = []

        # create empty dictionary to hold loss values
        self.shared_objects['label_loss_dict'] = {}

        # create each optimizer
        for name, opt_config in sorted(optimizer_dict.items()):

            # sanity check: make sure loss file and name have same length
            if isinstance(opt_config['loss_file'], str):
                assert isinstance(opt_config['loss_name'], str)
                opt_config['loss_file'] = [opt_config['loss_file']]
                opt_config['loss_name'] = [opt_config['loss_name']]

            assert len(opt_config['loss_file']) == len(opt_config['loss_name'])

            # aggregate over all defined loss functions
            label_loss = None
            # use loop names that do not shadow the optimizer name,
            # which is needed for the logging keys further below
            for loss_file, loss_name in zip(opt_config['loss_file'],
                                            opt_config['loss_name']):

                # get loss function
                class_string = 'dnn_reco.modules.loss.{}.{}'.format(
                    loss_file, loss_name)
                loss_function = misc.load_class(class_string)

                # compute loss
                label_loss_i = loss_function(
                                        config=self.config,
                                        data_handler=self.data_handler,
                                        data_transformer=self.data_transformer,
                                        shared_objects=self.shared_objects)

                # sanity check: make sure loss has expected shape
                loss_shape = label_loss_i.get_shape().as_list()
                if loss_shape != self.data_handler.label_shape:
                    error_msg = 'Shape of label loss {!r} does not match {!r}'
                    raise ValueError(error_msg.format(
                                                loss_shape,
                                                self.data_handler.label_shape))

                # accumulate loss terms
                if label_loss is None:
                    label_loss = label_loss_i
                else:
                    label_loss += label_loss_i

            # weight label_losses
            # use nested where trick to avoid NaNs:
            # https://stackoverflow.com/questions/33712178/tensorflow-nan-bug
            label_loss_safe = tf.where(self.shared_objects['non_zero_mask'],
                                       label_loss, tf.zeros_like(label_loss))
            weighted_label_loss = tf.where(
                        self.shared_objects['non_zero_mask'],
                        label_loss_safe * self.shared_objects['label_weights'],
                        tf.zeros_like(label_loss))
            weighted_loss_sum = tf.reduce_sum(input_tensor=weighted_label_loss)

            # create learning rate schedule if learning rate is a dict
            optimizer_settings = dict(opt_config['optimizer_settings'])
            if 'learning_rate' in optimizer_settings:
                if isinstance(optimizer_settings['learning_rate'], dict):

                    # assume that the learning rate dictionary defines a
                    # schedule of learning rates
                    # In this case the dictionary must have the following keys:
                    #   full_class_string: str
                    #       The full class string of the scheduler class to use
                    #   settings: dict
                    #       keyword arguments that are passed on to the
                    #       scheduler class.
                    lr_cfg = optimizer_settings.pop('learning_rate')
                    scheduler_class = misc.load_class(
                        lr_cfg['full_class_string'])
                    scheduler = scheduler_class(**lr_cfg['settings'])
                    optimizer_settings['learning_rate'] = scheduler

            # get optimizer
            # check for old-style (tf < 2) optimizers in tf.train
            try:
                optimizer_cls = getattr(tf.train, opt_config['optimizer'])
            except AttributeError:
                optimizer_cls = getattr(tf.optimizers, opt_config['optimizer'])
            optimizer = optimizer_cls(**optimizer_settings)

            # get variable list
            if isinstance(opt_config['vars'], str):
                opt_config['vars'] = [opt_config['vars']]

            var_list = []
            for var_name in opt_config['vars']:
                var_list.extend(self.shared_objects['model_vars_' + var_name])

            # apply regularization
            if opt_config['l1_regularization'] > 0. or \
                    opt_config['l2_regularization'] > 0.:

                reg_loss = 0.

                # apply regularization
                if opt_config['l1_regularization'] > 0.:
                    reg_loss += tf.add_n(
                        [tf.reduce_sum(tf.abs(v)) for v in var_list])

                if opt_config['l2_regularization'] > 0.:
                    reg_loss += tf.add_n(
                        [tf.reduce_sum(v**2) for v in var_list])

                total_loss = weighted_loss_sum + reg_loss

            else:
                total_loss = weighted_loss_sum

            # logging
            self.shared_objects['label_loss_dict'].update({
                'loss_label_' + name: weighted_label_loss,
                'loss_sum_' + name: weighted_loss_sum,
                'loss_sum_total_' + name: total_loss,
            })

            tf.compat.v1.summary.histogram(
                'loss_label_' + name, weighted_label_loss)
            tf.compat.v1.summary.scalar('loss_sum_' + name, weighted_loss_sum)
            tf.compat.v1.summary.scalar('loss_sum_total_' + name, total_loss)

            # get gradients
            # compatibility mode for old and new tensorflow versions
            try:
                gvs = optimizer.compute_gradients(
                    total_loss, var_list=var_list)
            except AttributeError:
                gradients = tf.gradients(total_loss, var_list)
                gvs = zip(gradients, var_list)

            # remove NaNs in gradients and replace these with zeros
            remove_nan_gradients = opt_config.get('remove_nan_gradients',
                                                  False)
            if remove_nan_gradients:
                gvs = [(tf.where(
                            tf.math.is_nan(grad), tf.zeros_like(grad), grad),
                        var) for grad, var in gvs if grad is not None]

            clip_gradients_value = opt_config.get('clip_gradients_value',
                                                  None)
            if clip_gradients_value is not None:
                gradients, variables = zip(*gvs)
                gradients, _ = tf.clip_by_global_norm(gradients,
                                                      clip_gradients_value)
                capped_gvs = zip(gradients, variables)
            else:
                capped_gvs = gvs
            optimizer_ops.append(optimizer.apply_gradients(capped_gvs))

        self.shared_objects['optimizer_ops'] = optimizer_ops
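
For reference, a hypothetical entry of config['model_optimizer_dict']
covering the keys read above (values are illustrative placeholders, not
project defaults):

    model_optimizer_dict = {
        'main': {
            # looked up in tf.train first, then tf.optimizers
            'optimizer': 'AdamOptimizer',
            'optimizer_settings': {'learning_rate': 1e-3},
            # selects shared_objects['model_vars_pred'] + ..._unc
            'vars': ['pred', 'unc'],
            'loss_file': 'default_loss',  # hypothetical module name
            'loss_name': 'mse',           # hypothetical function name
            'l1_regularization': 0.,
            'l2_regularization': 0.,
            # optional keys:
            'remove_nan_gradients': False,
            'clip_gradients_value': None,
        },
    }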
Example #8
    def read_icecube_data(self,
                          input_data,
                          nan_fill_value=None,
                          init_values=0.,
                          verbose=False):
        """Read IceCube hdf5 data files

        Parameters
        ----------
        input_data : str
            Path to input data file.
        nan_fill_value : float, optional
            Fill value for NaN entries in the loaded labels and misc data.
            If None, events containing NaN values are discarded from the
            batch instead of being filled.
        init_values : float, optional
            The x_ic78 and x_deepcore arrays are initialized with these
            values via: np.ones(shape) * np.array(init_values)
        verbose : bool, optional
            Print out additional information on runtimes for loading and
            processing of files.

        Returns
        -------
        x_ic78 : numpy.ndarray
            DOM input data of main IceCube array.
            shape: [batch_size, 10, 10, 60, num_bins]
        x_deepcore : numpy.ndarray
            DOM input data of DeepCore array.
            shape: [batch_size, 8, 60, num_bins]
        labels : numpy.ndarray
            Labels.
            shape: [batch_size] + label_shape
        misc : numpy.ndarray
            Misc variables.
            shape: [batch_size] + misc_shape

        Raises
        ------
        ValueError
            If the DataHandler is not set up, if the event headers of bin
            values and bin indices do not match, or if NaN values are
            found in the DOM input data while nan_fill_value is set.
        """
        if not self.is_setup:
            raise ValueError('DataHandler needs to be set up first!')

        start_time = timeit.default_timer()

        try:
            with pd.HDFStore(input_data, mode='r') as f:
                bin_values = f[self._config['data_handler_bin_values_name']]
                bin_indices = f[self._config['data_handler_bin_indices_name']]
                _time_range = f[self._config['data_handler_time_offset_name']]

        except Exception as e:
            print(e)
            print('Skipping file: {}'.format(input_data))
            return None

        time_range_start = _time_range['value']

        # create dictionary mapping event ids to row indices
        size = len(_time_range['Event'])
        eventIDDict = {}
        for row in _time_range.iterrows():
            # key: the first four header columns, which identify the event
            eventIDDict[(row[1][0], row[1][1], row[1][2], row[1][3])] = row[0]

        # Create arrays for input data
        x_ic78 = np.ones(
            [size, 10, 10, 60, self.num_bins],
            dtype=self._config['np_float_precision'],
        ) * np.array(init_values)
        x_deepcore = np.ones(
            [size, 8, 60, self.num_bins],
            dtype=self._config['np_float_precision'],
        ) * np.array(init_values)

        # ------------------
        # get DOM input data
        # ------------------
        for value_row, index_row in zip(bin_values.itertuples(),
                                        bin_indices.itertuples()):
            if value_row[1:5] != index_row[1:5]:
                raise ValueError(
                    'Event headers do not match! HDF5 version error?')
            string = index_row[6]
            dom = index_row[7] - 1
            index = eventIDDict[(index_row[1:5])]
            if string > 78:
                # deep core
                x_deepcore[index, string - 78 - 1, dom, index_row[10]] = \
                    value_row[10]
            else:
                # IC78
                a, b = self._get_indices_from_string(string)
                # Center of Detector is a,b = 0,0
                # a goes from -4 to 5
                # b goes from -5 to 4
                x_ic78[index, a + 4, b + 5, dom, index_row[10]] = value_row[10]

        # --------------
        # read in labels
        # --------------
        class_string = 'dnn_reco.modules.data.labels.{}.{}'.format(
            self._config['data_handler_label_file'],
            self._config['data_handler_label_name'],
        )
        label_reader = misc.load_class(class_string)
        labels, _ = label_reader(input_data,
                                 self._config,
                                 label_names=self.label_names)
        assert list(labels.shape) == [size] + self.label_shape

        # perform label smoothing if provided in config
        if 'label_pid_smooth_labels' in self._config:
            smoothing = self._config['label_pid_smooth_labels']
            if smoothing is not None:
                for key, i in self.label_name_dict.items():
                    if key in self._config['label_pid_keys']:
                        assert ((labels[:, i] >= 0.).all()
                                and (labels[:, i] <= 1.).all()), \
                            'Values outside of [0, 1] for {!r}'.format(key)
                        labels[:, i] = \
                            labels[:, i] * (1 - smoothing) + smoothing / 2.

        # -------------------
        # read in misc values
        # -------------------
        class_string = 'dnn_reco.modules.data.misc.{}.{}'.format(
            self._config['data_handler_misc_file'],
            self._config['data_handler_misc_name'],
        )
        misc_reader = misc.load_class(class_string)
        misc_data, _ = misc_reader(input_data,
                                   self._config,
                                   misc_names=self.misc_names)
        if self.misc_data_exists:
            assert list(misc_data.shape) == [size, self.num_misc]

        # -------------
        # filter events
        # -------------
        class_string = 'dnn_reco.modules.data.filter.{}.{}'.format(
            self._config['data_handler_filter_file'],
            self._config['data_handler_filter_name'],
        )
        filter_func = misc.load_class(class_string)
        mask = filter_func(self, input_data, self._config, x_ic78, x_deepcore,
                           labels, misc_data, time_range_start)

        # mask out events not passing filter:
        x_ic78 = x_ic78[mask]
        x_deepcore = x_deepcore[mask]
        labels = labels[mask]
        if self.misc_data_exists:
            misc_data = misc_data[mask]
        time_range_start = time_range_start[mask]

        # ---------------
        # Fix time offset
        # ---------------
        if self.relative_time_keys:

            # fix misc relative time variables
            for i, name in enumerate(self.misc_names):
                if name in self.relative_time_keys:
                    misc_data[:, i] -= time_range_start

            # fix relative time labels
            for i, name in enumerate(self.label_names):
                if name in self.relative_time_keys:
                    labels[:, i] -= time_range_start

        # --------------------------
        # fill nan values if desired
        # --------------------------
        if nan_fill_value is None:
            mask = np.isfinite(np.sum(x_ic78, axis=(1, 2, 3, 4)))
            mask = np.logical_and(
                mask, np.isfinite(np.sum(x_deepcore, axis=(1, 2, 3))))
            mask = np.logical_and(
                mask,
                np.isfinite(np.sum(labels, axis=tuple(range(1, labels.ndim)))))
            if not mask.all():
                misc.print_warning('Found NaNs. ' +
                                   'Removing {} events from batch of {}.'.
                                   format(len(mask) - np.sum(mask), len(mask)))
                misc.print_warning(
                    'NaN-free x_ic78: {}, x_deepcore: {}, labels: {}'.format(
                        np.isfinite(x_ic78).all(),
                        np.isfinite(x_deepcore).all(),
                        np.isfinite(labels).all(),
                    ))

                x_ic78 = x_ic78[mask]
                x_deepcore = x_deepcore[mask]
                labels = labels[mask]
                if self.misc_data_exists:
                    misc_data = misc_data[mask]
        else:

            # Raise Error if NaNs found in input data.
            # This should never be the case!
            mask = np.isfinite(np.sum(x_ic78, axis=(1, 2, 3, 4)))
            mask = np.logical_and(
                mask, np.isfinite(np.sum(x_deepcore, axis=(1, 2, 3))))
            if not mask.all():
                raise ValueError('Found NaN values in input data!')

            # Fixing NaNs in labels and misc data is ok, but warn about this
            mask = np.isfinite(
                np.sum(labels, axis=tuple(range(1, labels.ndim))))
            if self.misc_data_exists:
                mask = np.logical_and(
                    mask,
                    np.isfinite(
                        np.sum(misc_data, axis=tuple(range(1,
                                                           misc_data.ndim)))))
            if not mask.all():
                misc.print_warning('Found NaNs in labels and/or misc data. ' +
                                   'Replacing NaNs in {} events'.format(
                                       len(mask) - np.sum(mask)))
            labels[~np.isfinite(labels)] = nan_fill_value
            if self.misc_data_exists:
                misc_data[~np.isfinite(misc_data)] = nan_fill_value
        # --------------------------

        if verbose:
            final_time = timeit.default_timer() - start_time
            print("=== Time needed to process Data: {:5.3f} seconds ==".format(
                final_time))

        return x_ic78, x_deepcore, labels, misc_data
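
The filter module loaded in the event-filter step is expected to return a
boolean mask over the batch, given how the mask is applied above. A
hypothetical pass-through filter matching the call signature:

    import numpy as np

    def dummy_filter(data_handler, input_data, config, x_ic78, x_deepcore,
                     labels, misc_data, time_range_start, *args, **kwargs):
        """Hypothetical filter: keep all events in the batch."""
        return np.ones(len(x_ic78), dtype=bool)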