def testBatchAutotuneDatasetMultiSource(self):
    """Checks that batch autotune mode produces fixed-shape batches for a
    multi-source training dataset (shapes asserted below)."""
    vocab_path = self._makeTextFile("vocab.txt", ["1", "2", "3", "4"])
    data_path = self._makeTextFile("data.txt", ["hello world !"])
    sources = inputter.ParallelInputter(
        [
            text_inputter.WordEmbedder(embedding_size=10),
            text_inputter.WordEmbedder(embedding_size=10),
        ]
    )
    targets = text_inputter.WordEmbedder(embedding_size=10)
    targets.set_decoder_mode(mark_start=True, mark_end=True)
    example_inputter = inputter.ExampleInputter(sources, targets)
    example_inputter.initialize(
        {
            "source_1_vocabulary": vocab_path,
            "source_2_vocabulary": vocab_path,
            "target_vocabulary": vocab_path,
        }
    )
    dataset = example_inputter.make_training_dataset(
        [data_path, data_path],
        data_path,
        batch_size=1024,
        batch_type="tokens",
        maximum_features_length=[100, 110],
        maximum_labels_length=120,
        batch_autotune_mode=True,
    )
    features, labels = next(iter(dataset))
    # Each input is expected to be padded to its own maximum length,
    # with a static batch dimension.
    self.assertListEqual(features["inputter_0_ids"].shape.as_list(), [8, 100])
    self.assertListEqual(features["inputter_1_ids"].shape.as_list(), [8, 110])
    self.assertListEqual(labels["ids"].shape.as_list(), [8, 120])
    self.assertListEqual(labels["ids_out"].shape.as_list(), [8, 120])
def testWeightedDataset(self):
    """Checks argument validation and construction of weighted multi-corpus
    training datasets."""
    vocab_path = self._makeTextFile("vocab.txt", ["the", "world", "hello", "toto"])
    data_path = self._makeTextFile("data.txt", ["hello world !"])
    example_inputter = inputter.ExampleInputter(
        text_inputter.WordEmbedder(embedding_size=10),
        text_inputter.WordEmbedder(embedding_size=10),
    )
    example_inputter.initialize(
        {"source_vocabulary": vocab_path, "target_vocabulary": vocab_path}
    )
    # A mismatched number of source and target files is rejected.
    with self.assertRaisesRegex(ValueError, "same number"):
        example_inputter.make_training_dataset(
            [data_path, data_path], [data_path], batch_size=16
        )
    # The weights list must have one entry per parallel corpus.
    with self.assertRaisesRegex(ValueError, "expected to match"):
        example_inputter.make_training_dataset(
            [data_path, data_path],
            [data_path, data_path],
            batch_size=16,
            weights=[0.5],
        )
    # Valid without explicit weights.
    dataset = example_inputter.make_training_dataset(
        [data_path, data_path], [data_path, data_path], batch_size=16
    )
    self.assertIsInstance(dataset, tf.data.Dataset)
    # Valid with explicit per-corpus weights.
    dataset = example_inputter.make_training_dataset(
        [data_path, data_path],
        [data_path, data_path],
        batch_size=16,
        weights=[0.2, 0.8],
    )
    self.assertIsInstance(dataset, tf.data.Dataset)
def testExampleInputterFiltering(self):
    """Checks that training examples exceeding the maximum source or target
    length (or with an empty side) are filtered out."""
    vocab_path = self._makeTextFile("vocab.txt", ["a", "b", "c", "d"])
    features_path = self._makeTextFile(
        "features.txt", ["a b c d", "a b c", "a a c", "a"]
    )
    labels_path = self._makeTextFile(
        "labels.txt", ["a b c d", "a", "a a c d d", ""]
    )
    example_inputter = inputter.ExampleInputter(
        text_inputter.WordEmbedder(embedding_size=10),
        text_inputter.WordEmbedder(embedding_size=10),
    )
    example_inputter.initialize(
        {"source_vocabulary": vocab_path, "target_vocabulary": vocab_path}
    )
    dataset = example_inputter.make_training_dataset(
        features_path,
        labels_path,
        batch_size=1,
        maximum_features_length=3,
        maximum_labels_length=4,
        single_pass=True,
    )
    examples = list(iter(dataset))
    # Only the second line pair fits within both length limits.
    self.assertLen(examples, 1)
    features, labels = examples[0]
    self.assertAllEqual(features["ids"], [[0, 1, 2]])
    self.assertAllEqual(labels["ids"], [[0]])
def testExampleInputter(self):
    """Checks that an ExampleInputter yields (features, labels) pairs with the
    expected text fields on both sides."""
    vocab_path = self._makeTextFile("vocab.txt", ["the", "world", "hello", "toto"])
    data_path = self._makeTextFile("data.txt", ["hello world !"])
    example_inputter = inputter.ExampleInputter(
        text_inputter.WordEmbedder(embedding_size=10),
        text_inputter.WordEmbedder(embedding_size=10),
    )
    self.assertEqual(example_inputter.num_outputs, 2)
    features, transformed = self._makeDataset(
        example_inputter,
        [data_path, data_path],
        data_config={
            "source_vocabulary": vocab_path,
            "target_vocabulary": vocab_path,
        },
    )
    self.assertIsInstance(features, tuple)
    self.assertEqual(len(features), 2)
    self.assertEqual(len(transformed), 2)
    # Both the features and labels dicts carry the standard text fields.
    for side in features:
        for field in ("ids", "length", "tokens"):
            self.assertIn(field, side)
def testExampleInputterSharedAsset(self):
    """Checks that source and target inputters can share a single vocabulary
    asset key and that a source tokenization config is honored.

    Renamed from ``testExampleInputterAsset``: the original name collided with
    the other ``testExampleInputterAsset`` defined later in this class, so
    Python's class-body semantics silently shadowed this method and it never
    ran under test discovery.
    """
    vocab_file = self._makeTextFile("vocab.txt", ["the", "world", "hello", "toto"])
    # Both sides reference the same asset key "vocabulary_file_1".
    source_inputter = text_inputter.WordEmbedder("vocabulary_file_1", embedding_size=10)
    target_inputter = text_inputter.WordEmbedder("vocabulary_file_1", embedding_size=10)
    example_inputter = inputter.ExampleInputter(source_inputter, target_inputter)
    example_inputter.initialize({
        # NOTE(review): "vocabulary_file_2" is not referenced by either
        # inputter above — presumably leftover; confirm it can be dropped.
        "vocabulary_file_1": vocab_file,
        "vocabulary_file_2": vocab_file,
        "source_tokenization": {"mode": "conservative"}
    })
    # The tokenization config should install an OpenNMT tokenizer on the source side.
    self.assertIsInstance(source_inputter.tokenizer, tokenizers.OpenNMTTokenizer)
def testExampleInputterAsset(self):
    """Checks that a source tokenization config installs an OpenNMT tokenizer
    and that its configuration is exported as an asset."""
    vocab_path = self._makeTextFile("vocab.txt", ["the", "world", "hello", "toto"])
    example_inputter = inputter.ExampleInputter(
        text_inputter.WordEmbedder(embedding_size=10),
        text_inputter.WordEmbedder(embedding_size=10),
    )
    example_inputter.initialize(
        {
            "source_vocabulary": vocab_path,
            "target_vocabulary": vocab_path,
            "source_tokenization": {"mode": "conservative"},
        }
    )
    source_inputter = example_inputter.features_inputter
    self.assertIsInstance(source_inputter.tokenizer, tokenizers.OpenNMTTokenizer)
    # Exporting assets should write the tokenizer configuration file.
    asset_dir = self.get_temp_dir()
    example_inputter.export_assets(asset_dir)
    self.assertIn("source_tokenizer_config.yml", set(os.listdir(asset_dir)))