def input_fn(
    name: Text,
    file_patterns: Iterable[Text],
    data_format: Text,
    compression_type: Text,
    is_training: bool,
    max_seq_length: int,
    max_predictions_per_seq: int,
    params,
):
  """Returns an input_fn compatible with the tf.estimator API.

  Builds a parser for pretraining examples (masked-LM style: no aggregation
  ids, no classification labels, no answers, no candidates) and wires it into
  the shared dataset reader.

  Args:
    name: Name of the dataset (used by the reader for logging/scoping).
    file_patterns: Glob patterns of the input files.
    data_format: Format of the serialized examples (e.g. "tfrecord").
    compression_type: Compression of the input files (e.g. "GZIP" or "").
    is_training: Whether the dataset is used for training (enables shuffling
      and repetition inside the reader).
    max_seq_length: Maximum number of tokens per example.
    max_predictions_per_seq: Maximum number of masked-LM predictions.
    params: Dict of hyperparameters, forwarded to the parser and reader
      (must contain "batch_size").

  Returns:
    A tf.data.Dataset of parsed pretraining examples.
  """
  parse_example_fn = table_dataset.parse_table_examples(
      max_seq_length=max_seq_length,
      max_predictions_per_seq=max_predictions_per_seq,
      task_type=table_dataset.TableTask.PRETRAINING,
      # Pretraining examples carry none of the fine-tuning targets.
      add_aggregation_function_id=False,
      add_classification_labels=False,
      add_answer=False,
      include_id=False,
      add_candidate_answers=False,
      max_num_candidates=0,
      params=params)
  ds = dataset.read_dataset(
      parse_example_fn,
      name=name,
      file_patterns=file_patterns,
      data_format=data_format,
      compression_type=compression_type,
      is_training=is_training,
      params=params)
  return ds
def input_fn(name: Text, file_patterns: Iterable[Text], data_format: Text,
             is_training: bool, max_seq_length: int, include_id: bool,
             compression_type: Text, use_mined_negatives: bool, params):
  """Returns an input_fn compatible with the tf.estimator API.

  Parses table-retrieval examples (optionally with mined hard negatives)
  and feeds them through the shared dataset reader.
  """
  # Mined negatives change the parsing task: each example then bundles
  # negative tables alongside the positive one.
  if use_mined_negatives:
    retrieval_task = table_dataset.TableTask.RETRIEVAL_NEGATIVES
  else:
    retrieval_task = table_dataset.TableTask.RETRIEVAL
  parser = table_dataset.parse_table_examples(
      max_seq_length=max_seq_length,
      max_predictions_per_seq=None,
      task_type=retrieval_task,
      add_aggregation_function_id=False,
      add_classification_labels=False,
      add_answer=False,
      include_id=include_id,
      add_candidate_answers=False,
      max_num_candidates=0,
      params=params)
  # Disable eval truncation for retrieval by overriding max_eval_count;
  # the original params dict is left untouched.
  reader_params = dict(params, max_eval_count=None)
  return dataset.read_dataset(
      parser,
      name=name,
      file_patterns=file_patterns,
      data_format=data_format,
      is_training=is_training,
      compression_type=compression_type,
      params=reader_params,
  )
def create_random_dataset(num_examples, batch_size, repeat, generator_kwargs):
  """Creates a dataset out of random examples.

  Args:
    num_examples: Number of examples to generate.
    batch_size: Batch size.
    repeat: Whether to repeat the examples forever.
    generator_kwargs: dict of arguments for create_random_example.

  Returns:
    A tf.data.Dataset with parsed examples.
  """
  # Materialize serialized random examples up front.
  serialized = [
      make_tf_example(create_random_example(**generator_kwargs))
      .SerializeToString() for _ in range(num_examples)
  ]
  ds = tf.data.Dataset.from_tensor_slices(serialized)
  if repeat:
    ds = ds.repeat()
  # The parser is configured from the same kwargs that generated the
  # examples, so parsing round-trips every feature.
  parse_fn = table_dataset.parse_table_examples(
      max_seq_length=generator_kwargs["max_seq_length"],
      max_predictions_per_seq=generator_kwargs["max_predictions_per_seq"],
      task_type=generator_kwargs["task_type"],
      add_aggregation_function_id=generator_kwargs[
          "add_aggregation_function_id"],
      add_classification_labels=generator_kwargs[
          "add_classification_labels"],
      add_answer=generator_kwargs["add_answer"],
      include_id=generator_kwargs["include_id"],
      add_candidate_answers=generator_kwargs["add_candidate_answers"],
      max_num_candidates=generator_kwargs["max_num_candidates"],
      params={"batch_size": batch_size})
  ds = ds.map(parse_fn)
  return ds.batch(batch_size=batch_size, drop_remainder=True)
def test_parse_table_examples(self, max_seq_length, max_predictions_per_seq,
                              task_type, add_aggregation_function_id,
                              add_answer, include_id, add_candidate_answers,
                              add_classification_labels):
  """Round-trips a random example through the parser and checks every feature."""
  logging.info("Setting random seed to 42")
  np.random.seed(42)
  max_num_candidates = 10
  # Generate a random example whose raw feature values we keep as ground truth.
  values = table_dataset_test_utils.create_random_example(
      max_seq_length,
      max_predictions_per_seq,
      task_type,
      add_aggregation_function_id,
      add_classification_labels,
      add_answer,
      include_id,
      vocab_size=10,
      segment_vocab_size=3,
      num_columns=3,
      num_rows=2,
      add_candidate_answers=add_candidate_answers,
      max_num_candidates=max_num_candidates)
  example = table_dataset_test_utils.make_tf_example(values)
  params = {}
  # Build the parser with the exact same configuration used for generation.
  parse_fn = table_dataset.parse_table_examples(
      max_seq_length=max_seq_length,
      max_predictions_per_seq=max_predictions_per_seq,
      task_type=task_type,
      add_aggregation_function_id=add_aggregation_function_id,
      add_classification_labels=add_classification_labels,
      add_answer=add_answer,
      include_id=include_id,
      add_candidate_answers=add_candidate_answers,
      max_num_candidates=max_num_candidates,
      params=params,
  )
  features = parse_fn(example.SerializeToString())
  with self.cached_session() as sess:
    features_vals = sess.run(features)
  # Every generated feature must survive parsing unchanged.
  for value in values:
    if value == "can_indexes":
      # "can_indexes" is consumed by the parser to build "can_label_ids"
      # and is not emitted as a feature itself; checked separately below.
      continue
    if values[value].dtype == np.float32 or values[value].dtype == np.int32:
      np.testing.assert_almost_equal(features_vals[value], values[value])
    else:
      # Handle feature as string.
      np.testing.assert_equal(features_vals[value], values[value])
  if add_candidate_answers:
    self.assertEqual(features_vals["can_label_ids"].dtype, np.int32)
    self.assertAllEqual(features_vals["can_label_ids"].shape,
                        [max_num_candidates, max_seq_length])
    # The total number of label_ids set to 1 must match the total number
    # of indices.
    num_indices = len(values["can_indexes"])
    self.assertEqual(features_vals["can_label_ids"].sum(), num_indices)
    # Check that the correct indices are set to 1.
    # "can_indexes" is a flat concatenation of per-candidate token indices;
    # "can_sizes"[c] tells how many entries belong to candidate c. Walk the
    # flat list, advancing cand_id whenever the current candidate's span of
    # entries (starting at cand_start) is exhausted.
    cand_id = 0
    cand_start = 0
    for i in range(len(values["can_indexes"])):
      while i - cand_start >= values["can_sizes"][cand_id]:
        cand_id += 1
        cand_start = i
      token_id = values["can_indexes"][i]
      self.assertEqual(features_vals["can_label_ids"][cand_id, token_id], 1)