Example #1
0
    def predict(self, epoch=None, fileno=None, samples=None):
        """
        Make a prediction if it does not exist yet, and return its filepaths.

        Load the model of the given epoch and file number (by default, the
        one with the lowest validation loss), let it predict on all samples
        of the validation files in the toml list, and save this prediction
        together with all the y_values as h5 file(s) in the predictions
        subfolder. If a prediction for this epoch/fileno has already been
        done, it is not recomputed.

        Parameters
        ----------
        epoch : int, optional
            Epoch of a model to load. Default: lowest val loss.
        fileno : int, optional
            File number of a model to load. Default: lowest val loss.
            Must be given if and only if ``epoch`` is given.
        samples : int, optional
            Don't use the full validation files, but just the given number
            of samples. Has no effect if the prediction already exists.

        Returns
        -------
        pred_filepaths : List
            Paths to all the prediction files of this epoch/fileno.

        Raises
        ------
        ValueError
            If exactly one of ``epoch`` and ``fileno`` is given.

        """
        if fileno is None and epoch is None:
            epoch, fileno = self.history.get_best_epoch_fileno()
            print(f"Automatically set epoch to epoch {epoch} file {fileno}.")
        elif fileno is None or epoch is None:
            raise ValueError(
                "Either both or none of epoch and fileno must be None")

        if self._check_if_pred_already_done(epoch, fileno):
            print("Prediction has already been done.")
        else:
            # NOTE(review): an already stored model is used as-is, even if it
            # was not loaded from this epoch/fileno -- confirm this is intended.
            if self._stored_model is None:
                model = self.load_saved_model(epoch, fileno, logging=False)
            else:
                model = self._stored_model
            self._set_up(model)

            # perf_counter is monotonic, so the measured duration can't be
            # skewed by adjustments of the system clock (unlike time.time).
            start_time = time.perf_counter()
            backend.make_model_prediction(
                self, model, epoch, fileno, samples=samples)
            elapsed_s = int(time.perf_counter() - start_time)
            print("Finished predicting on all validation files.")
            print(f"Elapsed time: {timedelta(seconds=elapsed_s)}\n")

        # Either way the prediction files now exist, so look them up once.
        return self.io.get_pred_files_list(epoch, fileno)
Example #2
0
    def test_predict(self):
        """
        Predict with make_model_prediction and check the written h5 file.

        The prediction file is read completely into memory and removed
        again. Its dataset names, their shapes, the dtype field names of
        the y_values and the contents of every dataset are then compared
        to the expected values.
        """
        # dummy values
        epoch, fileno = 1, 3

        try:
            make_model_prediction(self.orga, self.model, epoch, fileno)

            with h5py.File(self.pred_filepath, 'r') as h5_file:
                file_cntn = {
                    key: np.array(h5_file[key]) for key in h5_file.keys()
                }
        finally:
            # Only remove a file that was actually written; otherwise a
            # failing prediction would be masked by a FileNotFoundError.
            if os.path.exists(self.pred_filepath):
                os.remove(self.pred_filepath)

        target_datasets = [
            'label_mc_A', 'label_mc_B', 'pred_mc_A', 'pred_mc_B', 'y_values'
        ]
        target_shapes = [
            (500,), (500,), (500, 1), (500, 1), (500,)
        ]
        target_contents = [
            np.zeros(target_shapes[0]),
            np.ones(target_shapes[1]),
            np.ones(target_shapes[2]) * 18,
            np.ones(target_shapes[3]) * 18,
            self.train_A_file_1_ctnt[1],
        ]
        target_mc_names = ('mc_A', 'mc_B')

        datasets = list(file_cntn.keys())
        shapes = [file_cntn[key].shape for key in datasets]
        mc_dtype_names = file_cntn["y_values"].dtype.names

        self.assertSequenceEqual(datasets, target_datasets)
        self.assertSequenceEqual(shapes, target_shapes)
        self.assertSequenceEqual(mc_dtype_names, target_mc_names)
        for key, target in zip(target_datasets, target_contents):
            np.testing.assert_array_equal(file_cntn[key], target)