예제 #1
0
    def predict_dataset(self, dataset, progress_bar=True, apply_preproc=True):
        """ Predict every sample of a dataset

        Parameters
        ----------
        dataset : Dataset
            Dataset to predict; ``dataset.samples()`` supplies the entries that
            are paired with the predictions
        progress_bar : bool, optional
            show (True) or hide (False) a progress bar; forwarded to
            ``predict_input_dataset``
        apply_preproc : bool, optional
            if False, ``None`` is handed to the input dataset instead of both
            ``self.data_preproc`` and ``self.text_postproc``

        Yields
        -------
        tuple of (PredictionResult, dict)
            a single prediction result together with the dataset entry
            (from ``dataset.samples()``) it belongs to
        """
        # NOTE(review): a RawDataSet is wrapped in StreamingInputDataset and every
        # other dataset in RawInputDataset; this pairing looks inverted at first
        # glance - confirm it is intentional.
        wrapper_cls = StreamingInputDataset if isinstance(dataset, RawDataSet) else RawInputDataset

        # apply_preproc switches data pre-processing and text post-processing together.
        data_proc = self.data_preproc if apply_preproc else None
        text_proc = self.text_postproc if apply_preproc else None
        input_dataset = wrapper_cls(dataset, data_proc, text_proc)

        results = self.predict_input_dataset(input_dataset, progress_bar)

        # Pair each result with its source entry; zip stops at the shorter side.
        yield from zip(results, dataset.samples())
예제 #2
0
    def predict_dataset(self, dataset, progress_bar=True):
        """ Predict a dataset batch-wise via ``predict_raw``

        Samples are streamed through a ``StreamingInputDataset`` that uses the
        data pre-processing and text post-processing of ``self.predictors[0]``,
        grouped into batches of ``self.batch_size`` and passed to
        ``predict_raw`` with ``apply_preproc=False``. The total wall-clock time
        is printed once the generator is exhausted.

        Parameters
        ----------
        dataset : Dataset
            dataset to predict
        progress_bar : bool, optional
            show a tqdm progress bar with one tick per batch

        Yields
        ------
        tuple
            ``(result, sample)``: one element of ``predict_raw``'s output for the
            batch and the matching entry of ``dataset.samples()``
        """
        start_time = time.time()
        with StreamingInputDataset(dataset, self.predictors[0].data_preproc, self.predictors[0].text_postproc, None,
                                   processes=self.processes,
                                   ) as input_dataset:
            def progress_bar_wrapper(iterable):
                # Optionally wrap the batch iterator in tqdm; total = number of batches.
                if progress_bar:
                    return tqdm(iterable, total=int(np.ceil(len(dataset) / self.batch_size)), desc="Prediction")
                return iterable

            def batched_data_params():
                # Group (index, image, params) triples into lists of self.batch_size;
                # the final batch may be shorter but is never empty.
                batch = []
                for data_idx, (image, _, params) in enumerate(input_dataset.generator(epochs=1)):
                    batch.append((data_idx, image, params))
                    if len(batch) == self.batch_size:
                        yield batch
                        batch = []

                if len(batch) > 0:
                    yield batch

            for batch in progress_bar_wrapper(batched_data_params()):
                sample_ids, batch_images, batch_params = zip(*batch)
                # Fetch the sample list once per batch instead of once per sample.
                # NOTE(review): assumes the generator yields samples in dataset
                # order, so the enumerate index is a valid samples() index.
                all_samples = dataset.samples()
                samples = [all_samples[i] for i in sample_ids]
                # input_dataset already applied the pre-processing, hence apply_preproc=False.
                prediction = self.predict_raw(batch_images, params=batch_params, progress_bar=False, apply_preproc=False)
                for result, sample in zip(prediction, samples):
                    yield result, sample

        print("Prediction of {} models took {}s".format(len(self.predictors), time.time() - start_time))
예제 #3
0
    def preload_gt(self, gt_dataset, progress_bar=False):
        """ Load and cache the ground-truth texts for reuse

        The texts of ``gt_dataset`` are streamed with a single process through
        ``self.text_preprocessor`` and stored in ``self.preloaded_gt``, so the
        same ground truth can be tested against many predictions.

        Parameters
        ----------
        gt_dataset : Dataset
            the ground truth
        progress_bar : bool, optional
            show a progress bar while loading
        """
        with StreamingInputDataset(gt_dataset, None, self.text_preprocessor,
                                   processes=1) as streamed_gt:
            text_stream = tqdm_wrapper(
                streamed_gt.generator(text_only=True),
                total=len(gt_dataset),
                progress_bar=progress_bar,
                desc="Loading GT",
            )
            # Each generated item is a 3-tuple; only the middle element (the text) is kept.
            self.preloaded_gt = [text for _, text, _ in text_stream]
예제 #4
0
    def predict_dataset(self, dataset, progress_bar=True):
        """ Predict a dataset batch-wise with every predictor in ``self.predictors``

        Samples are streamed through a ``StreamingInputDataset`` that uses the
        data pre-processing and text post-processing of ``self.predictors[0]``
        and grouped into batches of ``self.batch_size``. For each batch, every
        predictor gets its own ``RawInputDataset`` over the same images and runs
        ``predict_input_dataset`` on it. The total wall-clock time is printed
        once the generator is exhausted.

        Parameters
        ----------
        dataset : Dataset
            dataset to predict; its ``mode`` is passed on to the per-batch
            ``RawInputDataset`` objects
        progress_bar : bool, optional
            show a tqdm progress bar with one tick per batch

        Yields
        ------
        tuple
            ``(results, sample)`` where ``results`` holds one prediction per
            predictor (in ``self.predictors`` order) and ``sample`` is the
            matching entry of ``dataset.samples()``
        """
        start_time = time.time()
        with StreamingInputDataset(
                dataset,
                self.predictors[0].data_preproc,
                self.predictors[0].text_postproc,
                None,
                processes=self.processes,
        ) as input_dataset:

            def progress_bar_wrapper(iterable):
                # Optionally wrap the batch iterator in tqdm; total = number of batches.
                if progress_bar:
                    return tqdm(iterable,
                                total=int(
                                    np.ceil(len(dataset) / self.batch_size)),
                                desc="Prediction")
                return iterable

            def batched_data_params():
                # Group (index, image, params) triples into lists of self.batch_size;
                # the final batch may be shorter but is never empty.
                batch = []
                for data_idx, (image, _, params) in enumerate(
                        input_dataset.generator(epochs=1)):
                    batch.append((data_idx, image, params))
                    if len(batch) == self.batch_size:
                        yield batch
                        batch = []

                if len(batch) > 0:
                    yield batch

            for batch in progress_bar_wrapper(batched_data_params()):
                sample_ids, batch_images, batch_params = zip(*batch)
                # Fetch the sample list once per batch instead of once per sample.
                # NOTE(review): assumes the generator yields samples in dataset
                # order, so the enumerate index is a valid samples() index.
                all_samples = dataset.samples()
                samples = [all_samples[i] for i in sample_ids]
                current_mode = dataset.mode
                with ExitStack() as stack:
                    # One RawInputDataset per predictor over the same batch; the
                    # [None] * len(...) argument is a per-image placeholder
                    # (presumably the unknown texts). ExitStack closes them all.
                    raw_dataset = [
                        stack.enter_context(
                            RawInputDataset(
                                current_mode,
                                batch_images,
                                [None] * len(batch_images),
                                batch_params,
                            )) for _ in self.predictors
                    ]

                    # One predict_input_dataset output per predictor.
                    prediction = [
                        predictor.predict_input_dataset(ds, progress_bar=False)
                        for ds, predictor in zip(raw_dataset, self.predictors)
                    ]

                    # zip(*prediction) regroups the outputs per sample: one
                    # result from each predictor.
                    for result, sample in zip(zip(*prediction), samples):
                        yield result, sample

        print("Prediction of {} models took {}s".format(
            len(self.predictors),
            time.time() - start_time))
예제 #5
0
    def run(self,
            _sentinel=None,
            gt_dataset=None,
            pred_dataset=None,
            processes=1,
            progress_bar=False):
        """ evaluate on the given dataset

        Parameters
        ----------
        _sentinel : do not use
            Forces the use of keyword arguments for `gt_dataset` and
            `pred_dataset` for safety
        gt_dataset : Dataset, optional
            the ground truth; not read when ``self.preloaded_gt`` is non-empty
        pred_dataset : Dataset
            the prediction dataset (required)
        processes : int, optional
            the processes to use for preprocessing and evaluation
        progress_bar : bool, optional
            show a progress bar

        Returns
        -------
        evaluation dictionary
            whatever ``self.evaluate`` returns

        Raises
        ------
        Exception
            if any argument is passed positionally (``_sentinel`` is set)
        """
        # Compare against None instead of truthiness: a Dataset passed
        # positionally may be empty (len() == 0) and therefore falsy, which
        # would otherwise slip past this guard.
        if _sentinel is not None:
            raise Exception("You must call run by using parameter names.")

        def load_texts(dataset, desc):
            # Stream only the texts of `dataset` through self.text_preprocessor
            # and collect them into a list.
            with StreamingInputDataset(dataset, None, self.text_preprocessor,
                                       processes=processes) as input_dataset:
                return [
                    txt for _, txt, _ in tqdm_wrapper(
                        input_dataset.generator(text_only=True),
                        total=len(dataset),
                        progress_bar=progress_bar,
                        desc=desc,
                    )
                ]

        # NOTE(review): an empty preloaded GT list is falsy and falls through to
        # loading gt_dataset; confirm that is intended.
        if self.preloaded_gt:
            gt_data = self.preloaded_gt
        else:
            gt_data = load_texts(gt_dataset, "Loading GT")

        pred_data = load_texts(pred_dataset, "Loading Prediction")

        return self.evaluate(gt_data=gt_data,
                             pred_data=pred_data,
                             processes=processes,
                             progress_bar=progress_bar,
                             skip_empty_gt=self.skip_empty_gt)