def verify_task_matches_fake_datasets(task,
                                      use_cached,
                                      token_preprocessed=False,
                                      splits=("train", "validation"),
                                      num_shards=None):
    """Assert all splits for both tokenized datasets are correct."""
    for split in splits:
        get_dataset = functools.partial(task.get_dataset,
                                        _SEQUENCE_LENGTH,
                                        split,
                                        use_cached=use_cached,
                                        shuffle=False)
        if num_shards:
            # Rebuild the full dataset from its shards so the sharded read
            # path is exercised and can still be compared to the fake data.
            ds = get_dataset(
                shard_info=dataset_providers.ShardInfo(0, num_shards))
            for i in range(1, num_shards):
                ds = ds.concatenate(
                    get_dataset(
                        shard_info=dataset_providers.ShardInfo(i, num_shards)))
        else:
            ds = get_dataset()
        _assert_compare_to_fake_dataset(
            ds,
            split,
            task.output_features,
            token_preprocessed=token_preprocessed,
        )
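
A minimal usage sketch, not part of the original suite: the task name
"cached_task" and the shard count are assumptions, but the call shape matches
the helper's signature above.

# Hypothetical call site: look up a registered Task and verify every split,
# reading the fake cached data back through two shards.
task = dataset_providers.TaskRegistry.get("cached_task")  # name is an assumption
verify_task_matches_fake_datasets(
    task, use_cached=True, token_preprocessed=False, num_shards=2)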
Example #2
  def test_get_dataset_enc_dec_sharded_and_packed(self):
    mixture_or_task_name = "enc_dec_sharded_and_packed"
    x = [{"inputs": [7, 8], "targets": [3, 9]},
         {"inputs": [8, 4], "targets": [4]},
         {"inputs": [5, 6, 7], "targets": [6]}]
    dtypes = {"inputs": tf.int64, "targets": tf.int64}
    ds = create_default_dataset(x, output_types=dtypes)
    dataset_fn = lambda split, shuffle_files: ds
    register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

    task_feature_lengths = {"inputs": 7, "targets": 5}
    converter = feature_converters.EncDecFeatureConverter(pack=True)
    shard_info = dataset_providers.ShardInfo(index=0, num_shards=2)
    output_ds = feature_converters.get_dataset(
        mixture_or_task_name=mixture_or_task_name,
        task_feature_lengths=task_feature_lengths,
        dataset_split="train",
        shuffle=False,
        feature_converter=converter,
        shard_info=shard_info)

    # Packing should happen after sharding: shard 0 of 2 keeps examples 0
    # and 2, which are then packed together as segments 1 and 2.
    expected = {
        "encoder_input_token": [7, 8, 1, 5, 6, 7, 1],
        "encoder_segment_id": [1, 1, 1, 2, 2, 2, 2],
        "encoder_position": [0, 1, 2, 0, 1, 2, 3],
        "decoder_target_token": [3, 9, 1, 6, 1],
        "decoder_input_token": [0, 3, 9, 0, 6],
        "decoder_loss_weight": [1, 1, 1, 1, 1],
        "decoder_segment_id": [1, 1, 1, 2, 2],
        "decoder_position": [0, 1, 2, 0, 1],
    }
    expected_dtypes = {feat: tf.int32 for feat in expected.keys()}
    assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes)
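
The shard-then-pack ordering asserted above is consistent with the
index-modulo semantics of tf.data.Dataset.shard: shard 0 of 2 keeps every
second example starting at index 0. A standalone sketch in plain TensorFlow
(independent of seqio's actual sharding mechanism, which is an assumption
here):

import tensorflow as tf

ds = tf.data.Dataset.range(3)  # stand-ins for the three raw examples above
shard0 = ds.shard(num_shards=2, index=0)
print(list(shard0.as_numpy_iterator()))  # [0, 2]: examples 0 and 2 survive,
                                         # and only those two get packed.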
Example #3
  def test_get_dataset_enc_dec_sharded(self):
    mixture_or_task_name = "enc_dec_sharded"
    x = [{"inputs": [7, 8, 5, 6, 9, 4, 3], "targets": [3, 9]},
         {"inputs": [8, 4], "targets": [4]},
         {"inputs": [5, 6, 7], "targets": [6, 5]}]
    dtypes = {"inputs": tf.int64, "targets": tf.int64}
    ds = create_default_dataset(x, output_types=dtypes)
    dataset_fn = lambda split, shuffle_files: ds
    register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

    task_feature_lengths = {"inputs": 7, "targets": 5}
    converter = feature_converters.EncDecFeatureConverter(pack=False)
    shard_info = dataset_providers.ShardInfo(index=0, num_shards=2)
    output_ds = feature_converters.get_dataset(
        mixture_or_task_name=mixture_or_task_name,
        task_feature_lengths=task_feature_lengths,
        dataset_split="train",
        shuffle=False,
        feature_converter=converter,
        shard_info=shard_info)

    # Example index 1 falls in shard 1, so it must not appear in this
    # shard-0 dataset.
    expected = [{
        "encoder_input_token": [7, 8, 5, 6, 9, 4, 1],
        "decoder_target_token": [3, 9, 1, 0, 0],
        "decoder_input_token": [0, 3, 9, 1, 0],
        "decoder_loss_weight": [1, 1, 1, 0, 0],
    }, {
        "encoder_input_token": [5, 6, 7, 1, 0, 0, 0],
        "decoder_target_token": [6, 5, 1, 0, 0],
        "decoder_input_token": [0, 6, 5, 1, 0],
        "decoder_loss_weight": [1, 1, 1, 0, 0],
    }]
    expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()}
    assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes)
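
The padding convention visible in both expected outputs (EOS id 1 appended,
zeros padded out to task_feature_lengths, loss weights masking the padding)
can be reproduced in isolation. pad_with_eos below is a hypothetical helper
written for illustration, not a seqio API:

import tensorflow as tf

def pad_with_eos(tokens, length, eos_id=1):
  # Append EOS, zero-pad to `length`, and mask the padding in the loss,
  # mirroring the "decoder_target_token" / "decoder_loss_weight" rows above.
  tokens = tf.concat([tokens, [eos_id]], axis=0)
  pad_len = length - int(tf.shape(tokens)[0])
  padded = tf.pad(tokens, [[0, pad_len]])
  loss_weight = tf.cast(tf.not_equal(padded, 0), tf.int32)
  return padded, loss_weight

targets, weights = pad_with_eos(tf.constant([3, 9], tf.int32), 5)
print(targets.numpy())  # [3 9 1 0 0] == "decoder_target_token" of example 0
print(weights.numpy())  # [1 1 1 0 0] == "decoder_loss_weight" of example 0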