Example #1
    def test(cls, cfg, model):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
            logger.info("Prepare testing set")
            try:
                data_loader, evaluator = cls.build_evaluator(cfg, dataset_name)
            except NotImplementedError:
                logger.warning(
                    "No evaluator found. Implement its `build_evaluator` method."
                )
                results[dataset_name] = {}
                continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i

        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results)
            print_csv_format(results)

        if len(results) == 1: results = list(results.values())[0]

        return results
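Usage note: the classmethod above follows the detectron2/fastreid `DefaultTrainer.test` pattern. Below is a minimal, hedged sketch of how such a trainer is typically driven in an evaluation-only run, written against stock detectron2 because its API is stable and documented; note that detectron2's `build_evaluator(cfg, dataset_name)` hook returns only an evaluator, whereas most snippets on this page return `(data_loader, evaluator)` from the same hook. The config path and the `COCOEvaluator` choice are illustrative assumptions, not part of the example above.

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator


class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name):
        # Stock detectron2 hook: returns an evaluator only (the fastreid-style
        # snippets on this page return (data_loader, evaluator) instead).
        return COCOEvaluator(dataset_name, output_dir=cfg.OUTPUT_DIR)


if __name__ == "__main__":
    cfg = get_cfg()
    cfg.merge_from_file("configs/my_experiment.yaml")  # hypothetical config file
    cfg.freeze()

    model = Trainer.build_model(cfg)                      # build the architecture only
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained weights
    results = Trainer.test(cfg, model)                    # dict of metrics per test dataset
    print(results)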
Example #2
def do_test(cfg, model):
    logger = logging.getLogger(__name__)
    results = OrderedDict()
    for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
        logger.info("Prepare testing set")
        try:
            data_loader, evaluator = get_evaluator(cfg, dataset_name)
        except NotImplementedError:
            logger.warning(
                "No evaluator found. Implement its `build_evaluator` method.")
            results[dataset_name] = {}
            continue
        results_i = inference_on_dataset(model,
                                         data_loader,
                                         evaluator,
                                         flip_test=cfg.TEST.FLIP_ENABLED)
        results[dataset_name] = results_i

    if comm.is_main_process():
        assert isinstance(
            results, dict
        ), "Evaluator must return a dict on the main process. Got {} instead.".format(
            results)
        print_csv_format(results)

    if len(results) == 1: results = list(results.values())[0]

    return results
Example #3
    def test(cls, cfg, model, evaluators=None):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TESTS`.
        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]

        if evaluators is not None:
            assert len(cfg.DATASETS.TESTS) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TESTS), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
            logger.info("Prepare testing set")
            data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, num_query)
                except NotImplementedError:
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i

        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results
            )
            print_csv_format(results)

        if len(results) == 1: results = list(results.values())[0]

        return results
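Usage note for the `evaluators=` branch: callers can skip `build_evaluator` entirely by passing one pre-built evaluator per test dataset, in the same order, so the length assertion at the top of the method holds. A hedged sketch against stock detectron2, where `cfg.DATASETS.TEST` is the tuple of test dataset names; the config path and evaluator choice are illustrative assumptions.

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator

cfg = get_cfg()
cfg.merge_from_file("configs/my_experiment.yaml")  # hypothetical config file
cfg.freeze()

model = DefaultTrainer.build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained weights

# One evaluator per test dataset, in the same order as the loop above iterates them.
evaluators = [COCOEvaluator(name, output_dir=cfg.OUTPUT_DIR) for name in cfg.DATASETS.TEST]
results = DefaultTrainer.test(cfg, model, evaluators=evaluators)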
Example #4
    def test(self, cfg, model):
        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
            results_i = inference_on_dataset(
                model, self.test_data_loader[dataset_name],
                self.evaluator[dataset_name])
            results[dataset_name] = results_i
        self.eval_results = results
        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results)
            print_csv_format(results)
        if len(results) == 1: results = list(results.values())[0]
        return results
Example #5
    def test(cls, cfg, model):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger('fastreid')

        results = OrderedDict()
        dataset_name = cfg.DATASETS.TGT

        logger.info("Prepare testing set")
        try:
            data_loader, evaluator = cls.build_evaluator(cfg, dataset_name)
        except NotImplementedError:
            logger.warning(
                "No evaluator found. Implement its `build_evaluator` method.")
            results[dataset_name] = {}
            return results  # no data loader/evaluator to run inference with
        results_i = inference_on_dataset(model,
                                         data_loader,
                                         evaluator,
                                         flip_test=cfg.TEST.FLIP.ENABLED)
        results[dataset_name] = results_i

        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results)
            logger.info("Evaluation results for {} in csv format:".format(
                dataset_name))
            results_i['dataset'] = dataset_name
            print_csv_format(results_i)

        # if len(results) == 1:
        #     results = list(results.values())[0]

        return results