Example #1
def input_fn(
    name,
    file_patterns,
    data_format,
    compression_type,
    is_training,
    max_seq_length,
    max_predictions_per_seq,
    params,
):
    """Returns an input_fn compatible with the tf.estimator API."""
    parse_example_fn = table_dataset.parse_table_examples(
        max_seq_length=max_seq_length,
        max_predictions_per_seq=max_predictions_per_seq,
        task_type=table_dataset.TableTask.PRETRAINING,
        add_aggregation_function_id=False,
        add_classification_labels=False,
        add_answer=False,
        include_id=False,
        add_candidate_answers=False,
        max_num_candidates=0,
        params=params)
    ds = dataset.read_dataset(parse_example_fn,
                              name=name,
                              file_patterns=file_patterns,
                              data_format=data_format,
                              compression_type=compression_type,
                              is_training=is_training,
                              params=params)
    return ds
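
A minimal usage sketch for the pretraining reader above. All concrete values (file pattern, sequence lengths, estimator setup) are placeholders, assuming a TPUEstimator-style setup in which the estimator invokes the bound function with a `params` keyword carrying the batch size:

import functools

# Hypothetical wiring; paths and sizes are assumed, not taken from the example.
train_input_fn = functools.partial(
    input_fn,
    name="train",
    file_patterns=["/tmp/pretrain/train.tfrecord*"],  # placeholder pattern
    data_format="tfrecord",
    compression_type="",
    is_training=True,
    max_seq_length=128,
    max_predictions_per_seq=20,
)
# The estimator supplies the remaining `params` argument when it calls the
# input_fn, e.g.:
# estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)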
Example #2
from typing import Iterable, Text


def input_fn(name: Text, file_patterns: Iterable[Text], data_format: Text,
             is_training: bool, max_seq_length: int, include_id: bool,
             compression_type: Text, use_mined_negatives: bool, params):
    """Returns an input_fn compatible with the tf.estimator API."""
    task_type = (table_dataset.TableTask.RETRIEVAL_NEGATIVES
                 if use_mined_negatives else table_dataset.TableTask.RETRIEVAL)
    parse_example_fn = table_dataset.parse_table_examples(
        max_seq_length=max_seq_length,
        max_predictions_per_seq=None,
        task_type=task_type,
        add_aggregation_function_id=False,
        add_classification_labels=False,
        add_answer=False,
        include_id=include_id,
        add_candidate_answers=False,
        max_num_candidates=0,
        params=params)
    ds = dataset.read_dataset(
        parse_example_fn,
        name=name,
        file_patterns=file_patterns,
        data_format=data_format,
        is_training=is_training,
        compression_type=compression_type,
        params=dict(params, max_eval_count=None),
    )
    return ds
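
Relative to Example #1, the moving parts here are the task-type switch (mined negatives select TableTask.RETRIEVAL_NEGATIVES) and the params copy: `dict(params, max_eval_count=None)` clones `params` with the eval cap cleared, presumably so evaluation reads the full dataset (assuming `read_dataset` treats `max_eval_count` as a truncation limit). A hedged sketch of an eval-time call, with all concrete values assumed:

import functools

retrieval_eval_input_fn = functools.partial(
    input_fn,
    name="eval",
    file_patterns=["/tmp/retrieval/eval.tfrecord*"],  # placeholder pattern
    data_format="tfrecord",
    is_training=False,
    max_seq_length=128,
    include_id=True,           # keep example ids so predictions can be scored
    compression_type="",
    use_mined_negatives=True,  # parse TableTask.RETRIEVAL_NEGATIVES features
)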
Example #3
def create_random_dataset(num_examples, batch_size, repeat, generator_kwargs):
    """Creates a dataset out of random examples.

  Args:
    num_examples: Number of examples to generate.
    batch_size: Batch size.
    repeat: Whether to repeat the examples forever.
    generator_kwargs: dict of arguments for create_random_example.

  Returns:
    A tf.data.Dataset with parsed examples.
  """
    examples = []
    for _ in range(num_examples):
        example = make_tf_example(create_random_example(**generator_kwargs))
        examples.append(example.SerializeToString())

    dataset = tf.data.Dataset.from_tensor_slices(examples)
    if repeat:
        dataset = dataset.repeat()

    parse_fn = table_dataset.parse_table_examples(
        max_seq_length=generator_kwargs["max_seq_length"],
        max_predictions_per_seq=generator_kwargs["max_predictions_per_seq"],
        task_type=generator_kwargs["task_type"],
        add_aggregation_function_id=generator_kwargs[
            "add_aggregation_function_id"],
        add_classification_labels=generator_kwargs[
            "add_classification_labels"],
        add_answer=generator_kwargs["add_answer"],
        include_id=generator_kwargs["include_id"],
        add_candidate_answers=generator_kwargs["add_candidate_answers"],
        max_num_candidates=generator_kwargs["max_num_candidates"],
        params={"batch_size": batch_size})
    dataset = dataset.map(parse_fn)
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
    return dataset
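
A sketch of invoking the helper. The generator_kwargs keys mirror both what create_random_dataset reads directly and the create_random_example call in the test below; the concrete values are assumptions:

generator_kwargs = dict(
    max_seq_length=128,
    max_predictions_per_seq=20,
    task_type=table_dataset.TableTask.PRETRAINING,  # assumes table_dataset in scope
    add_aggregation_function_id=False,
    add_classification_labels=False,
    add_answer=False,
    include_id=False,
    vocab_size=10,
    segment_vocab_size=3,
    num_columns=3,
    num_rows=2,
    add_candidate_answers=False,
    max_num_candidates=0,
)
ds = create_random_dataset(
    num_examples=8,
    batch_size=4,
    repeat=False,  # finite dataset: two batches of four parsed examples
    generator_kwargs=generator_kwargs,
)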
Example #4
    def test_parse_table_examples(self, max_seq_length,
                                  max_predictions_per_seq, task_type,
                                  add_aggregation_function_id, add_answer,
                                  include_id, add_candidate_answers,
                                  add_classification_labels):
        logging.info("Setting random seed to 42")
        np.random.seed(42)
        max_num_candidates = 10
        values = table_dataset_test_utils.create_random_example(
            max_seq_length,
            max_predictions_per_seq,
            task_type,
            add_aggregation_function_id,
            add_classification_labels,
            add_answer,
            include_id,
            vocab_size=10,
            segment_vocab_size=3,
            num_columns=3,
            num_rows=2,
            add_candidate_answers=add_candidate_answers,
            max_num_candidates=max_num_candidates)
        example = table_dataset_test_utils.make_tf_example(values)

        params = {}
        parse_fn = table_dataset.parse_table_examples(
            max_seq_length=max_seq_length,
            max_predictions_per_seq=max_predictions_per_seq,
            task_type=task_type,
            add_aggregation_function_id=add_aggregation_function_id,
            add_classification_labels=add_classification_labels,
            add_answer=add_answer,
            include_id=include_id,
            add_candidate_answers=add_candidate_answers,
            max_num_candidates=max_num_candidates,
            params=params,
        )
        features = parse_fn(example.SerializeToString())

        with self.cached_session() as sess:
            features_vals = sess.run(features)

        for key in values:
            if key == "can_indexes":
                continue
            if values[key].dtype in (np.float32, np.int32):
                np.testing.assert_almost_equal(features_vals[key], values[key])
            else:  # Handle feature as string.
                np.testing.assert_equal(features_vals[key], values[key])

        if add_candidate_answers:
            self.assertEqual(features_vals["can_label_ids"].dtype, np.int32)
            self.assertAllEqual(features_vals["can_label_ids"].shape,
                                [max_num_candidates, max_seq_length])

            # The total number of label_ids set to 1 must match the total number
            # of indices.
            num_indices = len(values["can_indexes"])
            self.assertEqual(features_vals["can_label_ids"].sum(), num_indices)

            # Check that the correct indices are set to 1.
            cand_id = 0
            cand_start = 0
            for i in range(len(values["can_indexes"])):
                while i - cand_start >= values["can_sizes"][cand_id]:
                    cand_id += 1
                    cand_start = i
                token_id = values["can_indexes"][i]
                self.assertEqual(
                    features_vals["can_label_ids"][cand_id, token_id], 1)