def verify_task_matches_fake_datasets(task,
                                      use_cached,
                                      token_preprocessed=False,
                                      splits=("train", "validation"),
                                      num_shards=None):
  """Assert all splits for both tokenized datasets are correct."""
  for split in splits:
    get_dataset = functools.partial(
        task.get_dataset, _SEQUENCE_LENGTH, split,
        use_cached=use_cached, shuffle=False)
    if num_shards:
      ds = get_dataset(shard_info=dataset_providers.ShardInfo(0, num_shards))
      for i in range(1, num_shards):
        ds = ds.concatenate(
            get_dataset(
                shard_info=dataset_providers.ShardInfo(i, num_shards)))
    else:
      ds = get_dataset()
    _assert_compare_to_fake_dataset(
        ds,
        split,
        task.output_features,
        token_preprocessed=token_preprocessed,
    )
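# A minimal, self-contained sketch (not part of the original suite) of the
# shard-then-concatenate pattern used above, expressed with plain
# `tf.data.Dataset.shard` rather than `task.get_dataset(shard_info=...)`:
# concatenating all shards in index order yields every element of the
# original dataset exactly once, in shard-major order. The helper name is
# hypothetical and exists only for illustration.
def _concat_shards_sketch(make_ds, num_shards):
  ds = make_ds().shard(num_shards, 0)
  for i in range(1, num_shards):
    ds = ds.concatenate(make_ds().shard(num_shards, i))
  return ds

# Elements 0..5 sharded two ways come back as [0, 2, 4, 1, 3, 5].
assert [int(x) for x in _concat_shards_sketch(
    lambda: tf.data.Dataset.range(6), 2)] == [0, 2, 4, 1, 3, 5]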
def test_get_dataset_enc_dec_sharded_and_packed(self):
  mixture_or_task_name = "enc_dec_sharded_and_packed"
  x = [{"inputs": [7, 8], "targets": [3, 9]},
       {"inputs": [8, 4], "targets": [4]},
       {"inputs": [5, 6, 7], "targets": [6]}]
  dtypes = {"inputs": tf.int64, "targets": tf.int64}
  ds = create_default_dataset(x, output_types=dtypes)
  dataset_fn = lambda split, shuffle_files: ds
  register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

  task_feature_lengths = {"inputs": 7, "targets": 5}
  converter = feature_converters.EncDecFeatureConverter(pack=True)
  shard_info = dataset_providers.ShardInfo(index=0, num_shards=2)
  output_ds = feature_converters.get_dataset(
      mixture_or_task_name=mixture_or_task_name,
      task_feature_lengths=task_feature_lengths,
      dataset_split="train",
      shuffle=False,
      feature_converter=converter,
      shard_info=shard_info)

  # Packing should be done after the sharding.
  expected = {
      "encoder_input_token": [7, 8, 1, 5, 6, 7, 1],
      "encoder_segment_id": [1, 1, 1, 2, 2, 2, 2],
      "encoder_position": [0, 1, 2, 0, 1, 2, 3],
      "decoder_target_token": [3, 9, 1, 6, 1],
      "decoder_input_token": [0, 3, 9, 0, 6],
      "decoder_loss_weight": [1, 1, 1, 1, 1],
      "decoder_segment_id": [1, 1, 1, 2, 2],
      "decoder_position": [0, 1, 2, 0, 1],
  }
  expected_dtypes = {feat: tf.int32 for feat in expected.keys()}
  assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes)
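# Worked check (illustrative sketch, not from the original suite): with
# `num_shards=2` and `index=0`, the shard keeps examples 0 and 2. After EOS
# (id 1) is appended, the encoder sequences [7, 8, 1] and [5, 6, 7, 1] pack
# into exactly 7 positions, which is why `encoder_input_token` above needs no
# padding. The hypothetical helper below recomputes the packed segment ids
# and within-segment positions from the per-segment lengths.
def test_packed_segment_ids_and_positions_sketch(self):
  def pack_positions(lengths):
    segment_ids, positions = [], []
    for seg, n in enumerate(lengths, start=1):
      segment_ids += [seg] * n
      positions += list(range(n))
    return segment_ids, positions

  # Encoder: segments of length 3 and 4; decoder: segments of length 3 and 2.
  self.assertEqual(pack_positions([3, 4]),
                   ([1, 1, 1, 2, 2, 2, 2], [0, 1, 2, 0, 1, 2, 3]))
  self.assertEqual(pack_positions([3, 2]),
                   ([1, 1, 1, 2, 2], [0, 1, 2, 0, 1]))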
def test_get_dataset_enc_dec_sharded(self):
  mixture_or_task_name = "enc_dec_sharded"
  x = [{"inputs": [7, 8, 5, 6, 9, 4, 3], "targets": [3, 9]},
       {"inputs": [8, 4], "targets": [4]},
       {"inputs": [5, 6, 7], "targets": [6, 5]}]
  dtypes = {"inputs": tf.int64, "targets": tf.int64}
  ds = create_default_dataset(x, output_types=dtypes)
  dataset_fn = lambda split, shuffle_files: ds
  register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

  task_feature_lengths = {"inputs": 7, "targets": 5}
  converter = feature_converters.EncDecFeatureConverter(pack=False)
  shard_info = dataset_providers.ShardInfo(index=0, num_shards=2)
  output_ds = feature_converters.get_dataset(
      mixture_or_task_name=mixture_or_task_name,
      task_feature_lengths=task_feature_lengths,
      dataset_split="train",
      shuffle=False,
      feature_converter=converter,
      shard_info=shard_info)

  # Example index 1 should not be present in the sharded dataset.
  expected = [{
      "encoder_input_token": [7, 8, 5, 6, 9, 4, 1],
      "decoder_target_token": [3, 9, 1, 0, 0],
      "decoder_input_token": [0, 3, 9, 1, 0],
      "decoder_loss_weight": [1, 1, 1, 0, 0],
  }, {
      "encoder_input_token": [5, 6, 7, 1, 0, 0, 0],
      "decoder_target_token": [6, 5, 1, 0, 0],
      "decoder_input_token": [0, 6, 5, 1, 0],
      "decoder_loss_weight": [1, 1, 1, 0, 0],
  }]
  expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()}
  assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes)
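# Worked check (illustrative sketch, not from the original suite): without
# packing, each feature is trimmed to leave room for EOS (id 1) and then
# zero-padded to the task feature length, e.g. inputs [7, 8, 5, 6, 9, 4, 3]
# at length 7 become [7, 8, 5, 6, 9, 4, 1]. The helper below is a
# hypothetical reimplementation of that rule, consistent with the expected
# values above.
def test_pad_with_eos_sketch(self):
  def pad_with_eos(tokens, length, eos_id=1):
    seq = list(tokens)[:length - 1] + [eos_id]
    return seq + [0] * (length - len(seq))

  self.assertEqual(pad_with_eos([7, 8, 5, 6, 9, 4, 3], 7),
                   [7, 8, 5, 6, 9, 4, 1])
  self.assertEqual(pad_with_eos([3, 9], 5), [3, 9, 1, 0, 0])
  self.assertEqual(pad_with_eos([5, 6, 7], 7), [5, 6, 7, 1, 0, 0, 0])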