Ejemplo n.º 1
0
 def test_interaction_duplicate_column_name(self):
     """Test we don't crash when seeing ambiguous column names.

     The table has two columns that are both named 'Name' and hold identical
     cell values, so any reference to "name" is ambiguous. Synthesization is
     run with 20 different random seeds; the test only checks that no
     exception is raised.
     """
     config = synthesize_entablement.SynthesizationConfig(attempts=10)
     interaction = interaction_pb2.Interaction(
         id='i_id',
         table=interaction_pb2.Table(
             table_id='t_id',
             columns=[
                 interaction_pb2.Cell(text='Name'),
                 interaction_pb2.Cell(text='Name'),
                 interaction_pb2.Cell(text='Height')
             ],
             rows=[
                 interaction_pb2.Cells(cells=[
                     interaction_pb2.Cell(text='Peter'),
                     interaction_pb2.Cell(text='Peter'),
                     interaction_pb2.Cell(text='100')
                 ]),
                 interaction_pb2.Cells(cells=[
                     interaction_pb2.Cell(text='Bob'),
                     interaction_pb2.Cell(text='Bob'),
                     interaction_pb2.Cell(text='150')
                 ]),
                 interaction_pb2.Cells(cells=[
                     interaction_pb2.Cell(text='Tina'),
                     interaction_pb2.Cell(text='Tina'),
                     interaction_pb2.Cell(text='200')
                 ]),
             ]),
         questions=[])
     for seed in range(20):
         rng = np.random.RandomState(seed)
         # Materialize the result: if synthesize_from_interaction is (or ever
         # becomes) a generator, a bare call would never execute its body and
         # this test would pass without exercising anything.
         list(
             synthesize_entablement.synthesize_from_interaction(
                 config, rng, interaction, synthesize_entablement.Counter()))
Ejemplo n.º 2
0
    def test_end_to_end(self, runner_type, add_example_conversion):
        """Runs the whole intermediate-pretraining pipeline on test data.

        Args:
          runner_type: Beam runner type, forwarded to `beam_runner.run_type`.
          add_example_conversion: If True, a conversion config is passed to the
            pipeline, the split outputs are read back as `tf.train.Example`s,
            and `interactions.tfrecord` is additionally checked for
            `Interaction` records.
        """
        mode = intermediate_pretrain_utils.Mode.ALL
        prob_count_aggregation = 0.2
        use_fake_table = False
        add_opposite_table = False
        drop_without_support_rate = 0.0

        with tempfile.TemporaryDirectory() as temp_dir:
            # Named like in the binary's `main` to keep it distinct from the
            # synthesization config passed to `build_pipeline` below.
            conversion_config = None
            if add_example_conversion:
                vocab_path = os.path.join(temp_dir, "vocab.txt")
                _create_vocab(
                    list(_RESERVED_SYMBOLS) + ["released"], vocab_path)
                conversion_config = tf_example_utils.ClassifierConversionConfig(
                    vocab_file=vocab_path,
                    max_seq_length=32,
                    max_column_id=32,
                    max_row_id=32,
                    strip_column_names=False,
                )

            pipeline = intermediate_pretrain_utils.build_pipeline(
                mode=mode,
                config=synthesize_entablement.SynthesizationConfig(
                    prob_count_aggregation=prob_count_aggregation),
                use_fake_table=use_fake_table,
                add_opposite_table=add_opposite_table,
                drop_without_support_rate=drop_without_support_rate,
                input_file=os.path.join(self._test_dir,
                                        "pretrain_interactions.txtpb"),
                output_dir=temp_dir,
                output_suffix=".tfrecord",
                num_splits=3,
                conversion_config=conversion_config,
            )

            beam_runner.run_type(pipeline, runner_type).wait_until_finish()

            # With conversion enabled the split files are read as
            # tf.train.Examples; otherwise as raw interactions.
            message_type = (
                tf.train.Example
                if add_example_conversion else interaction_pb2.Interaction)

            for name in ("train", "test"):
                records = list(
                    _read_record(
                        os.path.join(temp_dir, f"{name}.tfrecord"),
                        message_type,
                    ))
                self.assertNotEmpty(records, msg=f"{name}.tfrecord is empty")

            if add_example_conversion:
                # The pipeline is expected to also write the unconverted
                # interactions next to the converted examples.
                interactions = list(
                    _read_record(
                        os.path.join(temp_dir, "interactions.tfrecord"),
                        interaction_pb2.Interaction,
                    ))
                self.assertNotEmpty(
                    interactions, msg="interactions.tfrecord is empty")
Ejemplo n.º 3
0
def main(unused_argv):
    """Builds the intermediate-pretraining pipeline from flags and runs it.

    Example conversion is only configured when `--convert_to_examples` is set;
    otherwise `build_pipeline` receives `conversion_config=None`. Blocks until
    the Beam run finishes.
    """
    del unused_argv

    synthesis_config = synthesize_entablement.SynthesizationConfig(
        prob_count_aggregation=FLAGS.prob_count_aggregation)

    if FLAGS.convert_to_examples:
        vocab_file = FLAGS.vocab_file
        seq_length = FLAGS.max_seq_length
        # Column and row ids are capped at the sequence length as well
        # (presumably because a sequence cannot reference more distinct
        # ids than it has tokens).
        conversion_config = tf_example_utils.ClassifierConversionConfig(
            vocab_file=vocab_file,
            max_seq_length=seq_length,
            max_column_id=seq_length,
            max_row_id=seq_length,
            strip_column_names=False,
        )
    else:
        conversion_config = None

    pipeline = intermediate_pretrain_utils.build_pipeline(
        mode=FLAGS.mode,
        config=synthesis_config,
        use_fake_table=FLAGS.use_fake_table,
        add_opposite_table=FLAGS.add_opposite_table,
        drop_without_support_rate=FLAGS.drop_without_support_rate,
        input_file=FLAGS.input_file,
        output_dir=FLAGS.output_dir,
        output_suffix=FLAGS.output_suffix,
        conversion_config=conversion_config,
    )
    run_result = beam_runner.run(pipeline)
    run_result.wait_until_finish()
Ejemplo n.º 4
0
    def test_interaction(self, prob_count_aggregation):
        """Checks statements synthesized from a small 3x3 table.

        Runs synthesization with 100 random seeds and verifies that:
          * the set of statement texts with class_index 1 equals the set with
            class_index 0 (and no other class index occurs),
          * a fixed set of known statements is among the positives,
          * the counter recorded enough successes, aggregations and
            comparators.

        Args:
          prob_count_aggregation: Forwarded to `SynthesizationConfig`. The value
            1.0 selects the count-focused expectations with looser thresholds.
        """
        config = synthesize_entablement.SynthesizationConfig(
            prob_count_aggregation=prob_count_aggregation, attempts=10)
        interaction = interaction_pb2.Interaction(
            id='i_id',
            table=interaction_pb2.Table(
                table_id='t_id',
                columns=[
                    interaction_pb2.Cell(text='Name'),
                    interaction_pb2.Cell(text='Height'),
                    interaction_pb2.Cell(text='Age')
                ],
                rows=[
                    interaction_pb2.Cells(cells=[
                        interaction_pb2.Cell(text='Peter'),
                        interaction_pb2.Cell(text='100'),
                        interaction_pb2.Cell(text='15')
                    ]),
                    interaction_pb2.Cells(cells=[
                        interaction_pb2.Cell(text='Bob'),
                        interaction_pb2.Cell(text='150'),
                        interaction_pb2.Cell(text='15')
                    ]),
                    interaction_pb2.Cells(cells=[
                        interaction_pb2.Cell(text='Tina'),
                        interaction_pb2.Cell(text='200'),
                        interaction_pb2.Cell(text='17')
                    ]),
                ]),
            questions=[])

        pos_statements = set()
        neg_statements = set()

        counter = TestCounter()

        for seed in range(100):
            rng = np.random.RandomState(seed)
            interactions = synthesize_entablement.synthesize_from_interaction(
                config, rng, interaction, counter)
            for new_interaction in interactions:
                question = new_interaction.questions[0]
                if question.answer.class_index == 1:
                    pos_statements.add(question.text)
                else:
                    # TestCase assertion instead of a bare `assert`: it is not
                    # stripped under `python -O` and reports the actual value.
                    self.assertEqual(question.answer.class_index, 0)
                    neg_statements.add(question.text)
        # Log before asserting so the statements are visible on failure.
        logging.info('pos_statements: %s', pos_statements)
        self.assertEqual(neg_statements, pos_statements)

        counts = counter.get_counts()
        logging.info('counts: %s', counts)

        is_count_test = prob_count_aggregation == 1.0
        # The count-focused run is expected to yield far fewer distinct
        # statements (>10 vs >100 below), so comparator thresholds are relaxed.
        min_comparator_count = 1 if is_count_test else 10

        if is_count_test:
            self.assertGreater(len(pos_statements), 10)
            expected_statements = {
                '1 is less than the count when age is 15 and height is greater than 100',
                '1 is less than the count when height is less than 200 and age is 15',
                '1 is the count when height is greater than 100 and age is less than 17',
                '2 is the count when age is less than 17 and height is less than 200',
            }
        else:
            self.assertGreater(len(pos_statements), 100)
            expected_statements = {
                '0 is the range of age when height is greater than 100',
                '100 is less than the last height when height is less than 200',
                '125 is greater than height when name is peter',
                '15 is age when height is less than 150',
                '15 is the last age when height is less than 200',
                '150 is the last height when age is 15',
                '175 is the average height when age is less than 17',
                '200 is greater than the greatest height when age is less than 17',
                '250 is less than the total height when age is 15',
                '30 is less than the total age when height is greater than 100',
                'bob is name when age is greater than 15',
                'bob is the first name when age is 15 and height is less than 200',
                'peter is name when age is 15 and height is less than 150',
                'the average height when age is 15 is less than 175',
                'the first height when height is greater than 100 is 150',
                'the first height when height is less than 200 is 150',
                'the first name when age is 15 is name when name is peter',
                'the greatest height when age is 15 is less than 200',
                'the last age when height is greater than 100 is greater than 15',
                'the last name when age is 15 is bob',
                'the last name when age is less than 17 is peter',
                'the last name when height is greater than 100 is bob',
                'the last name when height is less than 200 is bob',
                'the lowest height when age is 15 is 150',
                'the range of age when height is greater than 100 is greater than 0',
                'the range of height when age is 15 is 100',
                'tina is name when age is greater than 15 and height is 200',
                'tina is the first name when age is 15',
                'tina is the last name when age is 15',
                'tina is the last name when height is greater than 100',
            }

        for statement in expected_statements:
            self.assertIn(statement, pos_statements)

        for name in ['pos', 'neg']:
            if is_count_test:
                self.assertGreater(counts[f'{name}: Synthesization success'],
                                   10)
                self.assertGreater(counts[f'{name}: Select: COUNT'], 10)
            else:
                self.assertEqual(counts[f'{name}: Synthesization success'],
                                 100)
                for aggregation in Aggregation:
                    self.assertGreater(
                        counts[f'{name}: Select: {aggregation.name}'], 0)
            for comparator in Comparator:
                self.assertGreater(
                    counts[f'{name}: Comparator {comparator.name}'],
                    min_comparator_count)
                self.assertGreater(
                    counts[f'{name}: where: Comparator {comparator.name}'],
                    min_comparator_count)