Exemplo n.º 1
0
    def test_eval_duplicate_query_ids(self):
        """Checks precision@k on a small hand-built retrieval fixture.

        Two queries and eight tables with 4-dimensional vectors are indexed,
        `process_predictions` is run with `_NUM_NEIGHBORS` set to 6, and the
        returned metrics are expected to be precision@1 == 0.5 and
        precision@5 == 1.0.

        NOTE(review): the test name mentions duplicate query ids, but the two
        queries built here use distinct ids ('1' and '2') -- confirm the
        intended scenario against the original test suite.
        """
        # Local import: only needed to build the scratch file path below.
        import os

        # Module-level knob read by the retrieval code; this assignment is not
        # undone at the end of the test.
        eval_table_retriever_utils._NUM_NEIGHBORS = 6

        queries = [
            eval_table_retriever_utils.QueryExample(
                query_id='1',
                table_ids=['b', 'c'],
                query=np.array([1, 0, 0, 0]),
            ),
            eval_table_retriever_utils.QueryExample(
                query_id='2',
                table_ids=['b'],
                query=np.array([0, 1, 0, 0]),
            ),
        ]

        # (table_id, vector) pairs; kept as data so the fixture is easy to scan.
        table_vectors = [
            ('a', [1, 0, 0, 0]),
            ('b', [0.5, 1, 0, 0]),
            ('c', [0, 0.5, 0.5, 1]),
            ('d', [0, 0, 0, 0]),
            ('e', [0, 0, 1, 1]),
            ('f', [0, 0, 0, 1]),
            ('g', [0, 0, 0, 1]),
            ('h', [0, 0, 0, 1]),
        ]
        tables = [
            eval_table_retriever_utils.TableExample(table_id=table_id,
                                                    table=np.array(vector))
            for table_id, vector in table_vectors
        ]

        index = eval_table_retriever_utils.build_table_index(tables)

        # tempfile.mktemp() is deprecated and race-prone, and it left the
        # results file behind. A private temporary directory gives a safe path
        # and is removed automatically when the block exits.
        with tempfile.TemporaryDirectory() as tmp_dir:
            results_file_path = os.path.join(tmp_dir, 'retrieval_results')
            precision_at_k = eval_table_retriever_utils.process_predictions(
                queries,
                tables,
                index,
                retrieval_results_file_path=results_file_path)

        self.assertEqual(precision_at_k, {
            'precision_at_1': 0.5,
            'precision_at_5': 1.0
        })
Exemplo n.º 2
0
 def test_retrieve(self):
     """Checks the similarities and neighbor ids returned by `_retrieve`.

     The synthetic queries are searched against an index built from the
     synthetic tables with `_NUM_NEIGHBORS` set to 2; the flattened outputs
     are compared against hard-coded expectations.
     """
     # Expected values, flattened across all queries.
     expected_similarities = [1, 0.5, 1, 0.5, 1, 0.5, 1, 1]
     expected_neighbors = [0, 1, 1, 2, 4, 2, 4, 7]

     queries, tables = self._generate_sythetic_data()
     # Module-level knob; must be set before the retrieval call below.
     eval_table_retriever_utils._NUM_NEIGHBORS = 2
     table_index = eval_table_retriever_utils.build_table_index(tables)

     retrieved = eval_table_retriever_utils._retrieve(queries, table_index)
     similarities, neighbors = retrieved

     actual_similarities = similarities.flatten().tolist()
     self.assertSequenceEqual(actual_similarities, expected_similarities)

     actual_neighbors = neighbors.flatten().tolist()
     self.assertSequenceEqual(actual_neighbors, expected_neighbors)
Exemplo n.º 3
0
    def test_eval_process_predictions(self):
        """Checks precision@k and the per-table scores written to disk.

        Runs `process_predictions` on the synthetic queries and tables with
        `_NUM_NEIGHBORS` set to 6, expects precision@1 == 0.5 and
        precision@5 == 0.75, then reloads the written results file and checks
        that table 'a' carries a score of -1.0 for query '1'.
        """
        # Local import: only needed to build the scratch file path below.
        import os

        queries, tables = self._generate_sythetic_data()
        # Module-level knob read by the retrieval code; this assignment is not
        # undone at the end of the test.
        eval_table_retriever_utils._NUM_NEIGHBORS = 6
        index = eval_table_retriever_utils.build_table_index(tables)

        # tempfile.mktemp() is deprecated and race-prone, and it left the
        # results file behind. Everything that touches the file stays inside
        # this block, because the directory is deleted on exit.
        with tempfile.TemporaryDirectory() as tmp_dir:
            results_file_path = os.path.join(tmp_dir, 'retrieval_results')
            precision_at_k = eval_table_retriever_utils.process_predictions(
                queries,
                tables,
                index,
                retrieval_results_file_path=results_file_path)
            self.assertEqual(precision_at_k, {
                'precision_at_1': 0.5,
                'precision_at_5': 0.75
            })

            results = self._load_results_from_file(results_file_path)
            # NOTE(review): the score check only fires when query '1' has an
            # entry for table 'a'; if none exists it is silently skipped.
            for result in results:
                if result['query_id'] != '1':
                    continue
                for table in result['table_scores']:
                    if table['table_id'] != 'a':
                        continue
                    self.assertEqual(table['score'], -1.0)