def testBatchAutotuneDatasetMultiSource(self):
        """Autotune-mode training batches are fully padded to the maximum lengths."""
        vocab_path = self._makeTextFile("vocab.txt", ["1", "2", "3", "4"])
        data_path = self._makeTextFile("data.txt", ["hello world !"])
        # Two parallel source streams, each later capped at its own maximum length.
        sources = inputter.ParallelInputter(
            [
                text_inputter.WordEmbedder(embedding_size=10),
                text_inputter.WordEmbedder(embedding_size=10),
            ]
        )
        targets = text_inputter.WordEmbedder(embedding_size=10)
        targets.set_decoder_mode(mark_start=True, mark_end=True)
        example_inputter = inputter.ExampleInputter(sources, targets)
        example_inputter.initialize(
            {
                "source_1_vocabulary": vocab_path,
                "source_2_vocabulary": vocab_path,
                "target_vocabulary": vocab_path,
            }
        )

        dataset = example_inputter.make_training_dataset(
            [data_path, data_path],
            data_path,
            batch_size=1024,
            batch_type="tokens",
            maximum_features_length=[100, 110],
            maximum_labels_length=120,
            batch_autotune_mode=True,
        )

        features, labels = next(iter(dataset))
        # In autotune mode the first batch has the worst-case static shape:
        # 8 examples padded to each configured maximum length.
        for name, shape in (("inputter_0_ids", [8, 100]), ("inputter_1_ids", [8, 110])):
            self.assertListEqual(features[name].shape.as_list(), shape)
        self.assertListEqual(labels["ids"].shape.as_list(), [8, 120])
        self.assertListEqual(labels["ids_out"].shape.as_list(), [8, 120])
 def testWeightedDataset(self):
     """Weighted multi-file training datasets validate their arguments."""
     vocab_file = self._makeTextFile("vocab.txt", ["the", "world", "hello", "toto"])
     data_file = self._makeTextFile("data.txt", ["hello world !"])
     example_inputter = inputter.ExampleInputter(
         text_inputter.WordEmbedder(embedding_size=10),
         text_inputter.WordEmbedder(embedding_size=10),
     )
     example_inputter.initialize(
         {"source_vocabulary": vocab_file, "target_vocabulary": vocab_file}
     )
     # Mismatched counts of feature vs. label files are rejected.
     with self.assertRaisesRegex(ValueError, "same number"):
         example_inputter.make_training_dataset(
             [data_file, data_file], [data_file], batch_size=16
         )
     # A weights list must have one entry per file pair.
     with self.assertRaisesRegex(ValueError, "expected to match"):
         example_inputter.make_training_dataset(
             [data_file, data_file],
             [data_file, data_file],
             batch_size=16,
             weights=[0.5],
         )
     # Both valid configurations build a dataset: implicit weights and explicit ones.
     dataset = example_inputter.make_training_dataset(
         [data_file, data_file], [data_file, data_file], batch_size=16
     )
     self.assertIsInstance(dataset, tf.data.Dataset)
     dataset = example_inputter.make_training_dataset(
         [data_file, data_file],
         [data_file, data_file],
         batch_size=16,
         weights=[0.2, 0.8],
     )
     self.assertIsInstance(dataset, tf.data.Dataset)
    def testExampleInputterFiltering(self):
        """Training pairs exceeding the configured length limits are dropped."""
        vocab_file = self._makeTextFile("vocab.txt", ["a", "b", "c", "d"])
        features_file = self._makeTextFile(
            "features.txt", ["a b c d", "a b c", "a a c", "a"]
        )
        labels_file = self._makeTextFile(
            "labels.txt", ["a b c d", "a", "a a c d d", ""]
        )

        example_inputter = inputter.ExampleInputter(
            text_inputter.WordEmbedder(embedding_size=10),
            text_inputter.WordEmbedder(embedding_size=10),
        )
        example_inputter.initialize(
            {"source_vocabulary": vocab_file, "target_vocabulary": vocab_file}
        )

        dataset = example_inputter.make_training_dataset(
            features_file,
            labels_file,
            batch_size=1,
            maximum_features_length=3,
            maximum_labels_length=4,
            single_pass=True,
        )
        # Only ("a b c", "a") fits both limits; the last pair has an empty
        # label line and is presumably filtered as well (only 1 example survives).
        examples = list(iter(dataset))
        self.assertLen(examples, 1)
        features, labels = examples[0]
        self.assertAllEqual(features["ids"], [[0, 1, 2]])
        self.assertAllEqual(labels["ids"], [[0]])
    def testExampleInputter(self):
        """An example inputter yields parallel (features, labels) structures."""
        vocab_file = self._makeTextFile("vocab.txt", ["the", "world", "hello", "toto"])
        data_file = self._makeTextFile("data.txt", ["hello world !"])

        example_inputter = inputter.ExampleInputter(
            text_inputter.WordEmbedder(embedding_size=10),
            text_inputter.WordEmbedder(embedding_size=10),
        )
        # One output per side: source features and target labels.
        self.assertEqual(example_inputter.num_outputs, 2)

        features, transformed = self._makeDataset(
            example_inputter,
            [data_file, data_file],
            data_config={
                "source_vocabulary": vocab_file,
                "target_vocabulary": vocab_file,
            },
        )

        self.assertIsInstance(features, tuple)
        self.assertEqual(len(features), 2)
        self.assertEqual(len(transformed), 2)
        source_features, label_features = features
        # Both sides expose the standard text fields.
        for structure in (source_features, label_features):
            for field in ("ids", "length", "tokens"):
                self.assertIn(field, structure)
# Example #5 (scrape artifact: "Beispiel #5" separator and vote count "0")
 def testExampleInputterAsset(self):
     """Asset keys passed to each inputter are resolved from the data config.

     Fix: the target inputter was constructed with the same asset key as the
     source ("vocabulary_file_1"), which left the "vocabulary_file_2" entry of
     the data config below unused. The target now uses "vocabulary_file_2" so
     both config entries are consumed as intended.
     """
     vocab_file = self._makeTextFile("vocab.txt", ["the", "world", "hello", "toto"])
     source_inputter = text_inputter.WordEmbedder("vocabulary_file_1", embedding_size=10)
     # Distinct key for the target side; both keys map to the same vocab file.
     target_inputter = text_inputter.WordEmbedder("vocabulary_file_2", embedding_size=10)
     example_inputter = inputter.ExampleInputter(source_inputter, target_inputter)
     example_inputter.initialize({
         "vocabulary_file_1": vocab_file,
         "vocabulary_file_2": vocab_file,
         "source_tokenization": {"mode": "conservative"}
     })
     # The "source_tokenization" entry configures the source-side tokenizer.
     self.assertIsInstance(source_inputter.tokenizer, tokenizers.OpenNMTTokenizer)
# Example #6 (scrape artifact: "Beispiel #6" separator and vote count "0")
 def testExampleInputterAsset(self):
     """Tokenizer configuration from the data config is exported as an asset."""
     vocab_file = self._makeTextFile("vocab.txt", ["the", "world", "hello", "toto"])
     source_inputter = text_inputter.WordEmbedder(embedding_size=10)
     target_inputter = text_inputter.WordEmbedder(embedding_size=10)
     example_inputter = inputter.ExampleInputter(source_inputter, target_inputter)
     config = {
         "source_vocabulary": vocab_file,
         "target_vocabulary": vocab_file,
         "source_tokenization": {"mode": "conservative"},
     }
     example_inputter.initialize(config)
     # The tokenization entry selects an OpenNMT tokenizer for the source side.
     self.assertIsInstance(source_inputter.tokenizer, tokenizers.OpenNMTTokenizer)
     # Exporting assets writes the source tokenizer configuration to disk.
     asset_dir = self.get_temp_dir()
     example_inputter.export_assets(asset_dir)
     self.assertIn("source_tokenizer_config.yml", set(os.listdir(asset_dir)))