Ejemplo n.º 1
0
  def test_table_values(self, row_updates, expected_updates,
                        min_consolidation_fraction):
    """Checks the numeric values added to a small Name/Number/Date table.

    Args:
      row_updates: Iterable of (row_index, col_index, text) triples; each one
        overwrites the text of a cell before the table is annotated.
      expected_updates: Iterable of (col_index, values_by_row) pairs. For the
        given column every expected numeric value is cleared and then
        re-populated only for the row indexes present in `values_by_row`.
      min_consolidation_fraction: Forwarded unchanged to
        `number_annotation_utils.add_numeric_table_values`.
    """
    table_text = """
          columns {
            text: 'Name'
          }
          columns {
            text: 'Number'
          }
          columns {
            text: 'Date'
          }
          rows {
            cells {
              text: 'A'
            }
            cells {
              text: '1'
            }
            cells {
              text: 'August 2014'
            }
          }
          rows {
            cells {
              text: 'B'
            }
            cells {
              text: '2'
            }
            cells {
              text: 'July 7'
            }
          }
          rows {
            cells {
              text: 'C'
            }
            cells {
              text: '3'
            }
            cells {
              text: 'March 17, 2015'
            }
          }
    """
    expected_table = text_format.Parse(table_text, interaction_pb2.Table())

    # Per-case text overrides go in before the table is copied and annotated,
    # so both the expected and the actual table see them.
    for row, col, new_text in row_updates:
      expected_table.rows[row].cells[col].text = new_text

    actual_table = interaction_pb2.Table()
    actual_table.CopyFrom(expected_table)
    number_annotation_utils.add_numeric_table_values(
        actual_table, min_consolidation_fraction=min_consolidation_fraction)

    # Baseline expectation: (row, column, numeric value) for the 'Number' and
    # 'Date' columns of the unmodified table.
    baseline_values = (
        (0, 1, _number(1)),
        (1, 1, _number(2)),
        (2, 1, _number(3)),
        (0, 2, _date(year=2014, month=8)),
        (1, 2, _date(month=7, day=7)),
        (2, 2, _date(year=2015, month=3, day=17)),
    )
    for row, col, value in baseline_values:
      expected_table.rows[row].cells[col].numeric_value.CopyFrom(value)

    # Case-specific expectations replace the baseline for a whole column.
    for col, values_by_row in expected_updates:
      for row, table_row in enumerate(expected_table.rows):
        cell = table_row.cells[col]
        cell.ClearField('numeric_value')
        if row in values_by_row:
          cell.numeric_value.CopyFrom(values_by_row[row])

    self.assertEqual(expected_table, actual_table)
Ejemplo n.º 2
0
 def test_convert_with_negative_tables(self):
     """Converts an interaction plus one negative table and checks features.

     Builds a 3-column table ('table_0') with a single question and a
     2-column negative table ('table_1'), annotates both with numeric
     values, runs them through `ToRetrievalTensorflowExample.convert` and
     compares every extracted feature with its expected value.
     """
     max_seq_length = 12

     def make_cells(texts):
         # One Cell proto per text, in order.
         return [interaction_pb2.Cell(text=text) for text in texts]

     def make_table(table_id, header, body):
         # Table proto with the given header texts and rows of cell texts.
         return interaction_pb2.Table(
             columns=make_cells(header),
             rows=[
                 interaction_pb2.Cells(cells=make_cells(row_texts))
                 for row_texts in body
             ],
             table_id=table_id,
         )

     with tempfile.TemporaryDirectory() as input_dir:
         vocab_file = os.path.join(input_dir, 'vocab.txt')
         _create_vocab(vocab_file, range(10))
         config = tf_example_utils.RetrievalConversionConfig(
             vocab_file=vocab_file,
             max_seq_length=max_seq_length,
             max_column_id=max_seq_length,
             max_row_id=max_seq_length,
             strip_column_names=False,
         )
         converter = tf_example_utils.ToRetrievalTensorflowExample(
             config=config)

         question = interaction_pb2.Question(id='id', original_text='2')
         interaction = interaction_pb2.Interaction(
             table=make_table(
                 'table_0',
                 ['A', 'B', 'C'],
                 [
                     ['0 6', '4 7', '5 6'],
                     ['1 7', '3 6', '5 5'],
                 ],
             ),
             questions=[question],
         )
         number_annotation_utils.add_numeric_values(interaction)

         n_table = make_table(
             'table_1',
             ['A', 'B'],
             [
                 ['0 6', '4 7'],
                 ['1 7', '3 6'],
             ],
         )
         number_annotation_utils.add_numeric_table_values(n_table)

         n_example = _NegativeRetrievalExample()
         n_example.table.CopyFrom(n_table)
         n_example.score = -82.0
         n_example.rank = 1

         example = converter.convert(interaction, 0, n_example)
         logging.info(example)

         # Integer features, checked in this exact order.
         expected_int_features = (
             ('input_ids', [
                 2, 5, 3, 1, 1, 1, 6, 10, 11, 7, 9, 11,
                 2, 5, 3, 1, 1, 6, 10, 7, 9, 0, 0, 0,
             ]),
             ('row_ids', [
                 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2,
                 0, 0, 0, 0, 0, 1, 1, 2, 2, 0, 0, 0,
             ]),
             ('column_ids', [
                 0, 0, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3,
                 0, 0, 0, 1, 2, 1, 2, 1, 2, 0, 0, 0,
             ]),
             ('segment_ids', [
                 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
             ]),
             ('input_mask', [
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
             ]),
             ('inv_column_ranks', [
                 0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 2, 1,
                 0, 0, 0, 0, 0, 2, 1, 1, 2, 0, 0, 0,
             ]),
             ('column_ranks', [
                 0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 1,
                 0, 0, 0, 0, 0, 1, 2, 2, 1, 0, 0, 0,
             ]),
             ('numeric_relations', [
                 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             ]),
             ('table_id_hash', [911224864, 1294380046]),
         )
         for feature_name, expected in expected_int_features:
             self.assertEqual(
                 _get_int_feature(example, feature_name), expected)

         # Float features; positions without a numeric value are expected
         # to come back as the string 'nan'.
         expected_float_features = (
             ('numeric_values', [
                 'nan', 'nan', 'nan', 'nan', 'nan', 'nan',
                 0.0, 4.0, 5.0, 1.0, 3.0, 5.0,
                 'nan', 'nan', 'nan', 'nan', 'nan',
                 0.0, 4.0, 1.0, 3.0,
                 'nan', 'nan', 'nan',
             ]),
             ('numeric_values_scale', [1.0] * 24),
         )
         for feature_name, expected in expected_float_features:
             self.assertEqual(
                 _get_float_feature(example, feature_name), expected)

         decoded_table_ids = [
             raw_id.decode('utf-8')
             for raw_id in _get_byte_feature(example, 'table_id')
         ]
         self.assertEqual(decoded_table_ids, ['table_0', 'table_1'])