Code Example #1
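  # Helpers such as create_default_dataset, register_dummy_task, assert_dataset,
  # and _assert_compare_to_fake_dataset are assumed to come from the surrounding
  # test utilities, which are not shown in this excerpt.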
  def verify_task_matches_fake_datasets(  # pylint:disable=dangerous-default-value
      self,
      task_name,
      use_cached,
      token_preprocessed=False,
      splits=("train", "validation"),
      sequence_length=_DEFAULT_SEQUENCE_LENGTH,
      num_shards=None):
     """Assert all splits for both tokenized datasets are correct."""
     task = TaskRegistry.get(task_name)
     for split in splits:
         get_dataset = functools.partial(task.get_dataset,
                                         sequence_length,
                                         split,
                                         use_cached=use_cached,
                                         shuffle=False)
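      # When num_shards is set, read every shard individually and concatenate
      # them so the reassembled dataset can be compared against the full fake
      # dataset.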
      if num_shards:
        ds = get_dataset(
            shard_info=dataset_providers.ShardInfo(0, num_shards))
        for i in range(1, num_shards):
          ds = ds.concatenate(
              get_dataset(shard_info=dataset_providers.ShardInfo(
                  i, num_shards)))
      else:
        ds = get_dataset()
      _assert_compare_to_fake_dataset(
          ds,
          split,
          task.output_features,
          sequence_length,
          token_preprocessed=token_preprocessed,
      )

  def test_get_dataset_enc_dec_sharded_and_packed(self):
    mixture_or_task_name = "enc_dec_sharded_and_packed"
    x = [{"inputs": [7, 8], "targets": [3, 9]},
         {"inputs": [8, 4], "targets": [4]},
         {"inputs": [5, 6, 7], "targets": [6]}]
    ds = create_default_dataset(x)
    dataset_fn = lambda split, shuffle_files: ds
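    # Register the in-memory dataset as a dummy Task so that
    # dataset_providers.get_dataset can look it up by name below.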
    register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

    task_feature_lengths = {"inputs": 7, "targets": 5}
    converter = feature_converters.EncDecFeatureConverter(pack=True)
    shard_info = dataset_providers.ShardInfo(index=0, num_shards=2)
    output_ds = dataset_providers.get_dataset(
        mixture_or_task_name=mixture_or_task_name,
        task_feature_lengths=task_feature_lengths,
        dataset_split="train",
        shuffle=False,
        feature_converter=converter,
        shard_info=shard_info)

    # Packing should be done after the sharding.
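    # Shard 0 of 2 keeps examples 0 and 2 of the dummy dataset; packing then
    # merges them into a single packed example, distinguished by segment ids.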
    expected = {
        "encoder_input_tokens": [7, 8, 1, 5, 6, 7, 1],
        "encoder_segment_ids": [1, 1, 1, 2, 2, 2, 2],
        "encoder_positions": [0, 1, 2, 0, 1, 2, 3],
        "decoder_target_tokens": [3, 9, 1, 6, 1],
        "decoder_input_tokens": [0, 3, 9, 0, 6],
        "decoder_loss_weights": [1, 1, 1, 1, 1],
        "decoder_segment_ids": [1, 1, 1, 2, 2],
        "decoder_positions": [0, 1, 2, 0, 1],
    }
    expected_dtypes = {feat: tf.int32 for feat in expected.keys()}
    assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes)

  def test_get_dataset_enc_dec_sharded(self):
    mixture_or_task_name = "enc_dec_sharded"
    x = [{"inputs": [7, 8, 5, 6, 9, 4, 3], "targets": [3, 9]},
         {"inputs": [8, 4], "targets": [4]},
         {"inputs": [5, 6, 7], "targets": [6, 5]}]
    ds = create_default_dataset(x)
    dataset_fn = lambda split, shuffle_files: ds
    register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

    task_feature_lengths = {"inputs": 7, "targets": 5}
    converter = feature_converters.EncDecFeatureConverter(pack=False)
    shard_info = dataset_providers.ShardInfo(index=0, num_shards=2)
    output_ds = dataset_providers.get_dataset(
        mixture_or_task_name=mixture_or_task_name,
        task_feature_lengths=task_feature_lengths,
        dataset_split="train",
        shuffle=False,
        feature_converter=converter,
        shard_info=shard_info)

    # Example index 1 should not be present in the sharded dataset.
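    # Shard 0 of 2 keeps examples 0 and 2; with pack=False each stays a
    # separate example, trimmed/padded to the feature lengths with an EOS id
    # of 1 appended, as the expected features below show.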
    expected = [{
        "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1],
        "decoder_target_tokens": [3, 9, 1, 0, 0],
        "decoder_input_tokens": [0, 3, 9, 1, 0],
        "decoder_loss_weights": [1, 1, 1, 0, 0],
    }, {
        "encoder_input_tokens": [5, 6, 7, 1, 0, 0, 0],
        "decoder_target_tokens": [6, 5, 1, 0, 0],
        "decoder_input_tokens": [0, 6, 5, 1, 0],
        "decoder_loss_weights": [1, 1, 1, 0, 0],
    }]
    expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()}
    assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes)