def test_get_dataset_enc_dec_sharded_and_packed(self):
  mixture_or_task_name = "enc_dec_sharded_and_packed"
  x = [{"inputs": [7, 8], "targets": [3, 9]},
       {"inputs": [8, 4], "targets": [4]},
       {"inputs": [5, 6, 7], "targets": [6]}]
  dtypes = {"inputs": tf.int64, "targets": tf.int64}
  ds = create_default_dataset(x, output_types=dtypes)
  dataset_fn = lambda split, shuffle_files: ds
  register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

  task_feature_lengths = {"inputs": 7, "targets": 5}
  converter = feature_converters.EncDecFeatureConverter(pack=True)
  shard_info = dataset_providers.ShardInfo(index=0, num_shards=2)
  output_ds = feature_converters.get_dataset(
      mixture_or_task_name=mixture_or_task_name,
      task_feature_lengths=task_feature_lengths,
      dataset_split="train",
      shuffle=False,
      feature_converter=converter,
      shard_info=shard_info)

  # Packing should be done after sharding.
  expected = {
      "encoder_input_tokens": [7, 8, 1, 5, 6, 7, 1],
      "encoder_segment_ids": [1, 1, 1, 2, 2, 2, 2],
      "encoder_positions": [0, 1, 2, 0, 1, 2, 3],
      "decoder_target_tokens": [3, 9, 1, 6, 1],
      "decoder_input_tokens": [0, 3, 9, 0, 6],
      "decoder_loss_weights": [1, 1, 1, 1, 1],
      "decoder_segment_ids": [1, 1, 1, 2, 2],
      "decoder_positions": [0, 1, 2, 0, 1],
  }
  expected_dtypes = {feat: tf.int32 for feat in expected.keys()}
  assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes)
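# A toy illustration of the sharding assumed in the test above (an
# assumption: index-based sharding where shard `index` keeps examples whose
# position satisfies position % num_shards == index; not the library's
# implementation). It shows why examples 0 and 2 of `x` survive shard 0 and
# then pack together into a single length-7 encoder sequence.
def toy_shard(examples, index, num_shards):
  """Hypothetical helper, for intuition only."""
  return [ex for i, ex in enumerate(examples) if i % num_shards == index]

# toy_shard(x, index=0, num_shards=2) keeps x[0] and x[2].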
def test_encoder_decoder_pretokenized_field(self):
  x = [{
      "inputs": [7, 8, 5, 1],
      "targets": [3, 9, 1],
      "targets_pretokenized": "abc"
  }, {
      "inputs": [8, 4, 9, 3, 1],
      "targets": [4, 1],
      "targets_pretokenized": "def"
  }]
  types = {
      "inputs": tf.int32,
      "targets": tf.int32,
      "targets_pretokenized": tf.string
  }
  shapes = {"inputs": [None], "targets": [None], "targets_pretokenized": []}
  ds = tf.data.Dataset.from_generator(
      lambda: x, output_types=types, output_shapes=shapes)

  task_feature_lengths = {"inputs": 10, "targets": 7}
  converter = feature_converters.EncDecFeatureConverter(pack=True)
  # Check that the converter does not raise an error even though
  # targets_pretokenized is present in the dataset but not in
  # task_feature_lengths.
  converter(ds, task_feature_lengths)
def test_get_dataset_enc_dec_sharded(self):
  mixture_or_task_name = "enc_dec_sharded"
  x = [{"inputs": [7, 8, 5, 6, 9, 4, 3], "targets": [3, 9]},
       {"inputs": [8, 4], "targets": [4]},
       {"inputs": [5, 6, 7], "targets": [6, 5]}]
  dtypes = {"inputs": tf.int64, "targets": tf.int64}
  ds = create_default_dataset(x, output_types=dtypes)
  dataset_fn = lambda split, shuffle_files: ds
  register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

  task_feature_lengths = {"inputs": 7, "targets": 5}
  converter = feature_converters.EncDecFeatureConverter(pack=False)
  shard_info = dataset_providers.ShardInfo(index=0, num_shards=2)
  output_ds = feature_converters.get_dataset(
      mixture_or_task_name=mixture_or_task_name,
      task_feature_lengths=task_feature_lengths,
      dataset_split="train",
      shuffle=False,
      feature_converter=converter,
      shard_info=shard_info)

  # Example index 1 should not be present in the sharded dataset.
  expected = [{
      "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1],
      "decoder_target_tokens": [3, 9, 1, 0, 0],
      "decoder_input_tokens": [0, 3, 9, 1, 0],
      "decoder_loss_weights": [1, 1, 1, 0, 0],
  }, {
      "encoder_input_tokens": [5, 6, 7, 1, 0, 0, 0],
      "decoder_target_tokens": [6, 5, 1, 0, 0],
      "decoder_input_tokens": [0, 6, 5, 1, 0],
      "decoder_loss_weights": [1, 1, 1, 0, 0],
  }]
  expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()}
  assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes)
def test_encoder_decoder_packed_long_sequences(self):
  x = [{"inputs": [7, 8, 5, 6, 9, 4, 1], "targets": [3, 9, 1]},
       {"inputs": [8, 4, 9, 3, 5, 1], "targets": [4, 1]}]
  ds = create_default_dataset(x)
  task_feature_lengths = {"inputs": 7, "targets": 3}

  converter = feature_converters.EncDecFeatureConverter(pack=True)
  converted_ds = converter(ds, task_feature_lengths)

  # Corner case: packing is enabled, but the examples are too long relative
  # to task_feature_lengths for any two of them to be packed together. We
  # should still get the *_segment_ids and *_positions fields.
  expected = [{
      "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1],
      "encoder_segment_ids": [1, 1, 1, 1, 1, 1, 1],
      "encoder_positions": [0, 1, 2, 3, 4, 5, 6],
      "decoder_target_tokens": [3, 9, 1],
      "decoder_input_tokens": [0, 3, 9],
      "decoder_loss_weights": [1, 1, 1],
      "decoder_segment_ids": [1, 1, 1],
      "decoder_positions": [0, 1, 2],
  }, {
      "encoder_input_tokens": [8, 4, 9, 3, 5, 1, 0],
      "encoder_segment_ids": [1, 1, 1, 1, 1, 1, 0],
      "encoder_positions": [0, 1, 2, 3, 4, 5, 0],
      "decoder_target_tokens": [4, 1, 0],
      "decoder_input_tokens": [0, 4, 0],
      "decoder_loss_weights": [1, 1, 0],
      "decoder_segment_ids": [1, 1, 0],
      "decoder_positions": [0, 1, 0],
  }]
  assert_dataset(converted_ds, expected)
def test_encoder_decoder_extra_long_inputs(self):
  x = [{"inputs": [9, 4, 3, 8, 4, 5, 1], "targets": [3, 9, 4, 7, 8, 1]}]
  ds = create_default_dataset(x)
  task_feature_lengths = {"inputs": 5, "targets": 8}
  expected_msg = (
      r".*Feature \\'inputs\\' has length not less than or equal to the "
      r"expected length of 5 during input_validation.*")
  with self.assertRaisesRegex(tf.errors.InvalidArgumentError, expected_msg):
    converter = feature_converters.EncDecFeatureConverter(pack=False)
    converted_ds = converter(ds, task_feature_lengths)
    list(converted_ds.as_numpy_iterator())
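# A minimal sketch of the length validation exercised above (an assumption,
# not the library's implementation): each feature must already fit within its
# declared task feature length, otherwise an InvalidArgumentError is raised
# with the message matched by `expected_msg`.
def toy_validate_lengths(example, task_feature_lengths,
                         error_label="input_validation"):
  """Hypothetical validator mirroring the asserted error message."""
  for feat, expected_len in task_feature_lengths.items():
    tf.debugging.assert_less_equal(
        tf.size(example[feat]), expected_len,
        message=f"Feature '{feat}' has length not less than or equal to the "
                f"expected length of {expected_len} during {error_label}")
  return example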
def test_get_dataset_enc_dec_packed(self):
  mixture_or_task_name = "enc_dec_packed"
  x = [{"inputs": [7, 8, 5, 6, 9, 4, 3], "targets": [3, 9]},
       {"inputs": [8, 4], "targets": [4]},
       {"inputs": [5, 6, 7], "targets": [6, 5]}]
  ds = create_default_dataset(x)
  dataset_fn = lambda split, shuffle_files: ds
  register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

  task_feature_lengths = {"inputs": 7, "targets": 5}
  converter = feature_converters.EncDecFeatureConverter(pack=True)
  output_ds = feature_converters.get_dataset(
      mixture_or_task_name=mixture_or_task_name,
      task_feature_lengths=task_feature_lengths,
      dataset_split="train",
      shuffle=False,
      feature_converter=converter)
  expected = [
      {
          # Example 1 is trimmed.
          "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1],
          "encoder_segment_ids": [1, 1, 1, 1, 1, 1, 1],
          "encoder_positions": [0, 1, 2, 3, 4, 5, 6],
          "decoder_target_tokens": [3, 9, 1, 0, 0],
          "decoder_input_tokens": [0, 3, 9, 0, 0],
          "decoder_loss_weights": [1, 1, 1, 0, 0],
          "decoder_segment_ids": [1, 1, 1, 0, 0],
          "decoder_positions": [0, 1, 2, 0, 0],
      },
      {
          # Examples 2 and 3 are packed together.
          "encoder_input_tokens": [8, 4, 1, 5, 6, 7, 1],
          "encoder_segment_ids": [1, 1, 1, 2, 2, 2, 2],
          "encoder_positions": [0, 1, 2, 0, 1, 2, 3],
          "decoder_target_tokens": [4, 1, 6, 5, 1],
          "decoder_input_tokens": [0, 4, 0, 6, 5],
          "decoder_loss_weights": [1, 1, 1, 1, 1],
          "decoder_segment_ids": [1, 1, 2, 2, 2],
          "decoder_positions": [0, 1, 0, 1, 2],
      },
  ]
  expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()}
  assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes)
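# A toy sketch of the trim-then-append-EOS behavior behind "Example 1 is
# trimmed" above (an assumption about the dummy task's preprocessing, not the
# library's code): the raw 7-token input is cut to 6 tokens so that EOS
# (id 1) still fits within the task feature length of 7.
def toy_trim_and_append_eos(tokens, length, eos_id=1):
  """Hypothetical helper, for intuition only."""
  return list(tokens)[:length - 1] + [eos_id]

# toy_trim_and_append_eos([7, 8, 5, 6, 9, 4, 3], 7) == [7, 8, 5, 6, 9, 4, 1]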
def test_encoder_decoder_targets_max_length(self):
  x = [{"inputs": [9, 4, 3, 8, 1], "targets": [3, 9, 4, 5, 1]}]
  ds = create_default_dataset(x)
  task_feature_lengths = {"inputs": 5, "targets": 5}

  converter = feature_converters.EncDecFeatureConverter(pack=False)
  converted_ds = converter(ds, task_feature_lengths)

  expected = {
      "encoder_input_tokens": [9, 4, 3, 8, 1],
      "decoder_target_tokens": [3, 9, 4, 5, 1],
      "decoder_input_tokens": [0, 3, 9, 4, 5],
      "decoder_loss_weights": [1, 1, 1, 1, 1],
  }
  assert_dataset(converted_ds, expected)
def test_encoder_decoder_unpacked(self):
  x = [{"inputs": [9, 4, 3, 8, 1], "targets": [3, 9, 4, 1]}]
  ds = create_default_dataset(x)
  task_feature_lengths = {"inputs": 7, "targets": 5}

  converter = feature_converters.EncDecFeatureConverter(pack=False)
  converted_ds = converter(ds, task_feature_lengths)

  expected = {
      "encoder_input_tokens": [9, 4, 3, 8, 1, 0, 0],
      "decoder_target_tokens": [3, 9, 4, 1, 0],
      # mtf.transformer.autoregressive_inputs does not zero out the last eos
      # when the data is not packed. This test mimics that behavior.
      "decoder_input_tokens": [0, 3, 9, 4, 1],
      "decoder_loss_weights": [1, 1, 1, 1, 0],
  }
  assert_dataset(converted_ds, expected)
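# The decoder inputs above are the targets shifted right by one for teacher
# forcing, analogous to mtf.transformer.autoregressive_inputs. A minimal
# unpacked sketch (illustrative only, not the library's implementation):
def toy_shift_right(targets, bos_id=0):
  """Prepends BOS and drops the final position; the last EOS survives as an
  input because nothing zeroes it out in the unpacked case."""
  return [bos_id] + list(targets[:-1])

# toy_shift_right([3, 9, 4, 1, 0]) == [0, 3, 9, 4, 1]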
def test_encoder_decoder_packed(self):
  x = [{"inputs": [7, 8, 5, 1], "targets": [3, 9, 1]},
       {"inputs": [8, 4, 9, 3, 1], "targets": [4, 1]}]
  ds = create_default_dataset(x)
  task_feature_lengths = {"inputs": 10, "targets": 7}

  converter = feature_converters.EncDecFeatureConverter(pack=True)
  converted_ds = converter(ds, task_feature_lengths)
  expected = {
      "encoder_input_tokens": [7, 8, 5, 1, 8, 4, 9, 3, 1, 0],
      "encoder_segment_ids": [1, 1, 1, 1, 2, 2, 2, 2, 2, 0],
      "encoder_positions": [0, 1, 2, 3, 0, 1, 2, 3, 4, 0],
      "decoder_target_tokens": [3, 9, 1, 4, 1, 0, 0],
      "decoder_input_tokens": [0, 3, 9, 0, 4, 0, 0],
      "decoder_loss_weights": [1, 1, 1, 1, 1, 0, 0],
      "decoder_segment_ids": [1, 1, 1, 2, 2, 0, 0],
      "decoder_positions": [0, 1, 2, 0, 1, 0, 0],
  }
  assert_dataset(converted_ds, expected)
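# A toy sketch of how packing produces the *_segment_ids and *_positions
# fields above (a simplification for intuition, not the library's
# implementation): examples are concatenated while they fit, each token
# recording which example it came from and its offset within that example.
def toy_pack(sequences, length):
  """Hypothetical greedy packer; pads all three outputs to `length`."""
  tokens, segment_ids, positions = [], [], []
  for segment, seq in enumerate(sequences, start=1):
    if len(tokens) + len(seq) > length:
      break
    tokens.extend(seq)
    segment_ids.extend([segment] * len(seq))
    positions.extend(range(len(seq)))
  pad = [0] * (length - len(tokens))
  return tokens + pad, segment_ids + pad, positions + pad

# toy_pack([[7, 8, 5, 1], [8, 4, 9, 3, 1]], 10) reproduces the expected
# encoder_input_tokens, encoder_segment_ids, and encoder_positions above.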
def test_get_dataset_both_train_and_validation_splits(self):
  mixture_or_task_name = "both_train_and_validation_splits"
  x_train = [{"inputs": [7, 8, 5, 6, 9, 4, 3], "targets": [3, 9]}]
  x_val = [{"inputs": [8, 4], "targets": [4]}]
  datasets = {
      "train": create_default_dataset(x_train),
      "validation": create_default_dataset(x_val)
  }
  dataset_fn = lambda split, shuffle_files: datasets[split]
  register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

  task_feature_lengths = {"inputs": 7, "targets": 5}
  output_ds = {}
  for split in ["train", "validation"]:
    converter = feature_converters.EncDecFeatureConverter(pack=False)
    output_ds[split] = feature_converters.get_dataset(
        mixture_or_task_name=mixture_or_task_name,
        task_feature_lengths=task_feature_lengths,
        dataset_split=split,
        shuffle=False,
        feature_converter=converter)

  expected_train = {
      "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1],
      "decoder_target_tokens": [3, 9, 1, 0, 0],
      "decoder_input_tokens": [0, 3, 9, 1, 0],
      "decoder_loss_weights": [1, 1, 1, 0, 0],
  }
  expected_val = {
      "encoder_input_tokens": [8, 4, 1, 0, 0, 0, 0],
      "decoder_target_tokens": [4, 1, 0, 0, 0],
      "decoder_input_tokens": [0, 4, 1, 0, 0],
      "decoder_loss_weights": [1, 1, 0, 0, 0],
  }
  expected_dtypes = {feat: tf.int32 for feat in expected_train.keys()}
  assert_dataset(
      output_ds["train"], expected_train, expected_dtypes=expected_dtypes)
  assert_dataset(
      output_ds["validation"], expected_val, expected_dtypes=expected_dtypes)
def test_get_dataset_enc_dec_unpacked(self):
  mixture_or_task_name = "enc_dec_unpacked"
  x = [{"inputs": [7, 8, 5, 6, 9, 4, 3], "targets": [3, 9]},
       {"inputs": [8, 4], "targets": [4]},
       {"inputs": [5, 6, 7], "targets": [6, 5]}]
  dtypes = {"inputs": tf.int64, "targets": tf.int64}
  ds = create_default_dataset(x, output_types=dtypes)
  dataset_fn = lambda split, shuffle_files: ds
  register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn)

  task_feature_lengths = {"inputs": 7, "targets": 5}
  converter = feature_converters.EncDecFeatureConverter(pack=False)
  output_ds = feature_converters.get_dataset(
      mixture_or_task_name=mixture_or_task_name,
      task_feature_lengths=task_feature_lengths,
      dataset_split="train",
      shuffle=False,
      feature_converter=converter)

  expected = [{
      "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1],
      "decoder_target_tokens": [3, 9, 1, 0, 0],
      "decoder_input_tokens": [0, 3, 9, 1, 0],
      "decoder_loss_weights": [1, 1, 1, 0, 0],
  }, {
      "encoder_input_tokens": [8, 4, 1, 0, 0, 0, 0],
      "decoder_target_tokens": [4, 1, 0, 0, 0],
      "decoder_input_tokens": [0, 4, 1, 0, 0],
      "decoder_loss_weights": [1, 1, 0, 0, 0],
  }, {
      "encoder_input_tokens": [5, 6, 7, 1, 0, 0, 0],
      "decoder_target_tokens": [6, 5, 1, 0, 0],
      "decoder_input_tokens": [0, 6, 5, 1, 0],
      "decoder_loss_weights": [1, 1, 1, 0, 0],
  }]
  expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()}
  assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes)
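# The helpers used throughout these tests (create_default_dataset,
# register_dummy_task, assert_dataset) are assumed to come from this
# codebase's test utilities. A minimal sketch of create_default_dataset,
# consistent with how test_encoder_decoder_pretokenized_field builds its
# dataset by hand (an assumption, not the actual helper):
def sketch_create_default_dataset(examples, output_types=None):
  """Builds a tf.data.Dataset of variable-length features from dicts."""
  output_types = output_types or {k: tf.int32 for k in examples[0]}
  output_shapes = {k: [None] for k in output_types}
  return tf.data.Dataset.from_generator(
      lambda: examples, output_types=output_types, output_shapes=output_shapes)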