Пример #1
0
    def evaluate(_sentinel=None, gt_data=None, pred_data=None, processes=1, progress_bar=False, skip_empty_gt=False):
        """Evaluate predicted lines against the ground truth, line by line.

        Parameters
        ----------
        _sentinel : do not use
            Leading placeholder so that a positional call does not silently
            bind data to `gt_data`; always pass `gt_data` and `pred_data` by
            keyword, for safety. Its value is not checked here.
        gt_data : sized iterable
            the ground truth; required despite the ``None`` default
        pred_data : sized iterable
            the predictions, paired with `gt_data` by position
        processes : int, optional
            the number of processes to use for the evaluation
        progress_bar : bool, optional
            show a progress bar
        skip_empty_gt : bool, optional
            skip gt text lines that are empty

        Returns
        -------
        evaluation dictionary produced by ``Evaluator.evaluate_single_list``

        Raises
        ------
        TypeError
            if `gt_data` or `pred_data` is not given
        Exception
            if `gt_data` and `pred_data` differ in length
        """
        # Name the missing argument instead of failing inside len(None); the
        # exception type stays TypeError, as it was before.
        if gt_data is None or pred_data is None:
            raise TypeError("evaluate() requires both gt_data and pred_data (pass them by keyword)")
        if len(gt_data) != len(pred_data):
            raise Exception("Mismatch in gt and pred files count: {} vs {}".format(len(gt_data), len(pred_data)))

        # evaluate single lines; gt and pred are paired by position
        jobs = [{'gt': gt, 'pred': pred, 'skip_empty_gt': skip_empty_gt} for gt, pred in zip(gt_data, pred_data)]
        out = parallel_map(Evaluator.evaluate_single_args, jobs,
                           processes=processes, progress_bar=progress_bar, desc="Evaluation")

        return Evaluator.evaluate_single_list(out, True)
Пример #2
0
    def preload(
        self,
        samples: List[Sample],
        num_processes=1,
        drop_invalid=True,
        progress_bar=False,
    ) -> Iterable[Sample]:
        """Return `samples` together with their augmented variants.

        The number of augmentations comes from ``self.data_aug_params.to_abs()``.
        If it is 0, `samples` is returned unchanged. Otherwise
        ``self.multi_augment`` is applied to every sample via ``parallel_map``
        (with ``include_non_augmented=True``), and the per-sample results are
        flattened into one list.

        Parameters
        ----------
        samples : list of Sample
            the samples to augment
        num_processes : int, optional
            number of processes passed to ``parallel_map``
        drop_invalid : bool, optional
            accepted for interface compatibility; not used by this method
        progress_bar : bool, optional
            show a progress bar

        Returns
        -------
        `samples` itself when no augmentation is configured, otherwise a new
        flat list of samples
        """
        from itertools import chain  # local import keeps this change self-contained

        n_augmentation = self.data_aug_params.to_abs()  # real number of augmentations
        if n_augmentation == 0:
            return samples

        apply_fn = partial(
            self.multi_augment,
            n_augmentations=n_augmentation,
            include_non_augmented=True,
        )
        augmented_samples = parallel_map(
            apply_fn,
            samples,
            desc="Augmenting data",
            processes=num_processes,
            progress_bar=progress_bar,
        )
        # Flatten in linear time; ``sum(chunks, [])`` re-copies the growing
        # accumulator on every addition and is quadratic in the sample count.
        return list(chain.from_iterable(augmented_samples))
Пример #3
0
    def augment_datas(self,
                      datas,
                      gt_txts,
                      n_augmentations,
                      processes=1,
                      progress_bar=False):
        """Append `n_augmentations` augmented copies of every (data, text) pair.

        Each pair is augmented by ``self.augment_data_tuple`` via ``parallel_map``.

        Parameters
        ----------
        datas : sequence
            the input data; must support ``len``, ``+`` and ``type(datas)()``
        gt_txts : sequence
            the ground-truth texts, paired with `datas` by position
        n_augmentations : int
            number of augmentations per pair, >= 0
        processes : int, optional
            number of processes passed to ``parallel_map``
        progress_bar : bool, optional
            show a progress bar

        Returns
        -------
        tuple ``(datas + augmented_datas, gt_txts + augmented_txts)``. The
        inputs are returned unchanged when `n_augmentations` is 0.

        Raises
        ------
        ValueError
            if `n_augmentations` is not an int >= 0, or if `datas` and
            `gt_txts` differ in length
        """
        # Check the type first: comparing a non-number with 0 would raise a
        # TypeError before the intended ValueError could be reported.
        if not isinstance(n_augmentations, int) or n_augmentations < 0:
            raise ValueError("Number of augmentation must be an integer >= 0")

        if n_augmentations == 0:
            return datas, gt_txts

        # zip() would silently drop unpaired items and desynchronise the outputs.
        if len(datas) != len(gt_txts):
            raise ValueError("Mismatch in data and gt text count: {} vs {}".format(len(datas), len(gt_txts)))

        jobs = [(data, txt, n_augmentations) for data, txt in zip(datas, gt_txts)]
        out = parallel_map(self.augment_data_tuple,
                           jobs,
                           desc="Augmentation",
                           processes=processes,
                           progress_bar=progress_bar)
        # Each accumulator matches the container type it is finally added to.
        out_d, out_t = type(datas)(), type(gt_txts)()
        for d, t in out:
            out_d += d
            out_t += t

        return datas + out_d, gt_txts + out_t
Пример #4
0
    def evaluate(self, *, gt_data: Dict[str, str], pred_data: Dict[str, str]):
        """Evaluate predictions against the ground truth, matched by sample ID.

        If ``self.params.non_existing_pred_as_empty`` is set, ground-truth IDs
        without a prediction are evaluated against an empty string. Predictions
        whose ID has no ground truth are dropped.

        Parameters
        ----------
        gt_data : Dict[str, str]
            the ground truth, keyed by sample ID
        pred_data : Dict[str, str]
            the predictions, keyed by sample ID

        Returns
        -------
        evaluation dictionary produced by ``Evaluator.evaluate_single_list``

        Raises
        ------
        ValueError
            if ``non_existing_pred_as_empty`` is set and no ID could be matched
            (this includes an empty `gt_data`), or if the ground-truth keys are
            not unique
        Exception
            if the ID sets of `gt_data` and `pred_data` differ
        """
        if self.params.non_existing_pred_as_empty:
            n_empty = 0
            mapped_pred_data = {}
            for sample_id in gt_data.keys():
                if sample_id in pred_data:
                    mapped_pred_data[sample_id] = pred_data[sample_id]
                else:
                    mapped_pred_data[sample_id] = ""
                    n_empty += 1
            logger.info(
                f"{n_empty}/{len(gt_data)} lines could not be matched during the evaluation."
            )
            if n_empty == len(gt_data):
                raise ValueError(
                    f"No lines could be matched by their ID. First 10 gt ids "
                    f"{list(gt_data.keys())[:10]}, first 10 pred ids {list(pred_data.keys())[:10]}"
                )
            pred_data = mapped_pred_data

        gt_ids, pred_ids = set(gt_data.keys()), set(pred_data.keys())
        # Defensive guard; a plain dict can never have duplicate keys.
        if len(gt_ids) != len(gt_data):
            raise ValueError("Non unique keys in ground truth data.")
        if gt_ids != pred_ids:
            raise Exception(
                f"Mismatch in gt and pred. Samples could not be matched by ID. "
                f"GT without PRED: {gt_ids.difference(pred_ids)}. "
                f"PRED without GT: {pred_ids.difference(gt_ids)}")

        # Iterate gt_data (insertion order) instead of the set, whose order
        # for str keys changes between runs, so the evaluation is reproducible.
        gt_pred = [(gt_data[s_id], pred_data[s_id]) for s_id in gt_data]
        # evaluate single lines
        out = parallel_map(
            Evaluator.evaluate_single_args,
            [{
                "gt": gt,
                "pred": pred,
                "skip_empty_gt": self.params.skip_empty_gt
            } for gt, pred in gt_pred],
            processes=self.params.setup.num_processes,
            progress_bar=self.params.progress_bar,
            desc="Evaluation",
        )

        return Evaluator.evaluate_single_list(out, True)
Пример #5
0
 def apply_on_samples(
     self,
     samples: Iterable[Sample],
     num_processes=1,
     drop_invalid=True,
     progress_bar=False,
 ) -> Iterator[Sample]:
     """Run ``self.apply_on_sample`` on every sample through ``parallel_map``.

     When `drop_invalid` is true, the mapped samples are passed through a
     lazy ``filter`` with ``self.is_valid_sample``. Otherwise the result of
     ``parallel_map`` is returned directly.
     """
     label = f"Applying data processor {self.__class__.__name__}"
     processed = parallel_map(
         self.apply_on_sample,
         samples,
         processes=num_processes,
         progress_bar=progress_bar,
         desc=label,
     )
     if not drop_invalid:
         return processed
     return filter(self.is_valid_sample, processed)
Пример #6
0
def _print_ler_table(predictions, n_lines_list):
    """Print one CSV line per run: line count, per-fold LER, mean, std, voted LER."""
    # The header takes the fold count from the first run and assumes that
    # 'voted' is the only non-fold key of ``prediction_map["full"]``.
    n_folds = len(predictions[0]["full"]) - 1
    fold_columns = ",".join(str(fold) for fold in range(n_folds))
    print(f"lines,{fold_columns},avg,std,voted")

    for prediction_map, n_lines in zip(predictions, n_lines_list):
        prediction = prediction_map["full"]
        fold_lers = [pred["eval"]["avg_ler"] for fold, pred in prediction.items() if fold != "voted"]
        cells = [f"{n_lines}"]
        cells += [f"{ler}" for ler in fold_lers]
        cells += [f"{np.mean(fold_lers)}", f"{np.std(fold_lers)}"]
        cells.append(f"{prediction['voted']['eval']['avg_ler']}")
        print(",".join(cells))


def _print_all_confusions(predictions, n_lines_list, n_confusions):
    """Print the confusions of every entry of every run via ``print_confusions``."""
    for prediction_map, n_lines in zip(predictions, n_lines_list):
        prediction = prediction_map["full"]
        print("")
        print("CONFUSIONS (lines = {})".format(n_lines))
        print("==========")
        print()

        for fold, pred in prediction.items():
            print("FOLD {}".format(fold))
            print_confusions(pred['eval'], n_confusions)


def _xlsx_entry(prefix, prediction_map, pred):
    """Build one ``write_xlsx`` entry for a single fold or voter prediction."""
    return {
        "prefix": prefix,
        "results": pred['eval'],
        "gt_files": prediction_map['gt_txts'],
        "gts": prediction_map['gt'],
        "preds": pred['data'],
    }


def _collect_xlsx_data(predictions, n_lines_list):
    """Collect the ``write_xlsx`` entries for all folds and voters of all runs."""
    data_list = []
    for prediction_map, n_lines in zip(predictions, n_lines_list):
        prediction = prediction_map["full"]
        for fold, pred in prediction.items():
            data_list.append(_xlsx_entry("L{} - Fold{}".format(n_lines, fold), prediction_map, pred))

        # NOTE(review): the CSV table reads prediction['voted'], while this
        # loop expects separate voter keys; a missing key raises KeyError
        # here. Confirm which keys run_for_single_line produces.
        for voter in ['sequence_voter', 'confidence_voter_default_ctc']:
            data_list.append(_xlsx_entry("L{} - {}".format(n_lines, voter[:3]), prediction_map, prediction[voter]))
    return data_list


def main():
    """Command-line entry point.

    Parses the arguments and calls ``run_for_single_line`` once per value of
    ``--n_lines``, in parallel. It then prints the label error rates as CSV,
    optionally prints confusions, and optionally writes an xlsx file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_dir", type=str, required=True,
                        help="The base directory where to store all working files")
    parser.add_argument("--eval_files", type=str, nargs="+", required=True,
                        help="All files that shall be used for evaluation")
    parser.add_argument("--n_lines", type=int, default=[-1], nargs="+",
                        help="Optional argument to specify the number of lines (images) used for training. "
                             "On default, all available lines will be used.")
    parser.add_argument("--run", type=str, default=None,
                        help="An optional command that will receive the train calls. Useful e.g. when using a resource "
                             "manager such as slurm.")

    parser.add_argument("--skip_train", action="store_true",
                        help="Skip the cross fold training")
    parser.add_argument("--skip_eval", action="store_true",
                        help="Skip the cross fold evaluation")
    parser.add_argument("--verbose", action="store_true",
                        help="Verbose output")
    parser.add_argument("--n_confusions", type=int, default=0,
                        help="Only print n most common confusions. Defaults to 0, use -1 for all.")
    parser.add_argument("--xlsx_output", type=str,
                        help="Optionally write a xlsx file with the evaluation results")

    setup_train_args(parser, omit=["early_stopping_best_model_output_dir", "output_dir"])

    args = parser.parse_args()

    args.base_dir = os.path.abspath(os.path.expanduser(args.base_dir))

    # `seed` is presumably registered by setup_train_args (not defined above).
    np.random.seed(args.seed)
    random.seed(args.seed)

    # run for all lines: one shallow copy of the arguments per requested line count
    single_args = [copy.copy(args) for _ in args.n_lines]
    for s_args, n_lines in zip(single_args, args.n_lines):
        s_args.n_lines = n_lines

    # thread pool with one worker per line count
    predictions = list(parallel_map(run_for_single_line, single_args, progress_bar=False,
                                    processes=len(single_args), use_thread_pool=True))

    # output predictions as csv:
    _print_ler_table(predictions, args.n_lines)

    if args.n_confusions != 0:
        _print_all_confusions(predictions, args.n_lines, args.n_confusions)

    if args.xlsx_output:
        write_xlsx(args.xlsx_output, _collect_xlsx_data(predictions, args.n_lines))