Exemplo n.º 1
0
    def testMixedInputter(self):
        """Check a MixedInputter combining word and char-conv embedders.

        Two inputters (a WordEmbedder of size 10 and a CharConvEmbedder built
        with ``(10, 5)``) are merged by a ConcatReducer. The mixed inputter
        must expose exactly one output, and the transformed output for the
        one-line dataset "hello world !" must have shape [1, 3, 15].
        """
        word_vocab = self._makeTextFile(
            "vocab.txt", ["the", "world", "hello", "toto"])
        char_vocab = self._makeTextFile(
            "vocab_alt.txt", ["h", "e", "l", "w", "o"])
        corpus = self._makeTextFile("data.txt", ["hello world !"])

        embedders = [
            text_inputter.WordEmbedder(embedding_size=10),
            text_inputter.CharConvEmbedder(10, 5),
        ]
        mixed_inputter = inputter.MixedInputter(
            embedders, reducer=reducer.ConcatReducer())
        # Concatenation reduces the two inputter outputs to a single one.
        self.assertEqual(mixed_inputter.num_outputs, 1)

        # Per-feature shapes handed to _makeDataset (all dimensions dynamic).
        feature_shapes = {
            "char_ids": [None, None, None],
            "ids": [None, None],
            "length": [None],
        }
        # NOTE(review): "1_"/"2_" prefixes presumably select the first and
        # second sub-inputter respectively -- confirm against MixedInputter.
        vocabularies = {
            "1_vocabulary": word_vocab,
            "2_vocabulary": char_vocab,
        }
        _, transformed = self._makeDataset(
            mixed_inputter,
            corpus,
            data_config=vocabularies,
            shapes=feature_shapes,
        )
        # 1 sentence, 3 tokens; 15 is presumably 10 word dims + 5 char dims.
        self.assertAllEqual([1, 3, 15], transformed.shape)
Exemplo n.º 2
0
    def testMixedInputter(self):
        """Check a MixedInputter in graph mode (TF1 session).

        A WordEmbedder and a CharConvEmbedder, whose vocabularies are looked
        up through the "vocabulary_file_1"/"vocabulary_file_2" metadata keys,
        are concatenated by a ConcatReducer. After initializing tables and
        variables, the evaluated transformed output must have shape [1, 3, 15].
        """
        word_vocab = self._makeTextFile(
            "vocab.txt", ["the", "world", "hello", "toto"])
        char_vocab = self._makeTextFile(
            "vocab_alt.txt", ["h", "e", "l", "w", "o"])
        corpus = self._makeTextFile("data.txt", ["hello world !"])

        word_embedder = text_inputter.WordEmbedder(
            "vocabulary_file_1", embedding_size=10)
        char_embedder = text_inputter.CharConvEmbedder(
            "vocabulary_file_2", 10, 5)
        mixed_inputter = inputter.MixedInputter(
            [word_embedder, char_embedder], reducer=reducer.ConcatReducer())
        self.assertEqual(mixed_inputter.num_outputs, 1)

        # Maps the metadata keys given to the embedders to real files.
        metadata = {
            "vocabulary_file_1": word_vocab,
            "vocabulary_file_2": char_vocab,
        }
        shapes = {
            "char_ids": [None, None, None],
            "ids": [None, None],
            "length": [None],
        }
        features, transformed = self._makeDataset(
            mixed_inputter, corpus, metadata=metadata, shapes=shapes)

        with self.test_session() as sess:
            # Initialize lookup tables and variables before evaluating.
            sess.run(tf.tables_initializer())
            sess.run(tf.global_variables_initializer())
            _, transformed_value = sess.run([features, transformed])
            self.assertAllEqual([1, 3, 15], transformed_value.shape)
Exemplo n.º 3
0
    def testMixedInputter(self):
        """Check MixedInputter features, serving inputs and output shape.

        Writes a word vocabulary, a character vocabulary and a one-line data
        file into the test's temporary directory. It then builds a
        MixedInputter (WordEmbedder + CharConvEmbedder, concatenated) and
        checks that:
          * the serving input receiver exposes "ids" and "char_ids";
          * the evaluated features contain "ids" and "char_ids" but neither
            "raw" nor "tokens";
          * the transformed output has shape [1, 3, 15].
        """
        import os

        # These paths were previously used without being defined (NameError);
        # create them under the per-test temporary directory.
        temp_dir = self.get_temp_dir()
        vocab_file = os.path.join(temp_dir, "vocab.txt")
        vocab_alt_file = os.path.join(temp_dir, "vocab_alt.txt")
        data_file = os.path.join(temp_dir, "data.txt")

        with open(vocab_file, "w") as vocab:
            vocab.write("the\n" "world\n" "hello\n" "toto\n")
        with open(vocab_alt_file, "w") as vocab_alt:
            vocab_alt.write("h\n" "e\n" "l\n" "w\n" "o\n")
        # Distinct handle name so it is not confused with the `data` features
        # returned by _first_element below.
        with open(data_file, "w") as data_out:
            data_out.write("hello world !\n")

        mixed_inputter = inputter.MixedInputter(
            [
                text_inputter.WordEmbedder("vocabulary_file_1",
                                           embedding_size=10),
                text_inputter.CharConvEmbedder("vocabulary_file_2", 10, 5)
            ],
            reducer=reducer.ConcatReducer())

        data, transformed = _first_element(mixed_inputter, data_file, {
            "vocabulary_file_1": vocab_file,
            "vocabulary_file_2": vocab_alt_file
        })

        # Serving inputs must carry the id features of both sub-inputters.
        input_receiver = mixed_inputter.get_serving_input_receiver()
        self.assertIn("ids", input_receiver.features)
        self.assertIn("char_ids", input_receiver.features)

        with self.test_session() as sess:
            sess.run(tf.tables_initializer())
            sess.run(tf.global_variables_initializer())
            data, transformed = sess.run([data, transformed])
            # Only id features are expected; string features are not present.
            self.assertNotIn("raw", data)
            self.assertNotIn("tokens", data)
            self.assertIn("ids", data)
            self.assertIn("char_ids", data)
            self.assertAllEqual([1, 3, 15], transformed.shape)
Exemplo n.º 4
0
    def testMixedInputter(self):
        """Check a MixedInputter built from files in the test temp directory.

        The vocabularies and the one-line data file are written by hand. A
        WordEmbedder and a CharConvEmbedder are then concatenated, and the
        test checks that "tokens" is absent from the dataset features and
        that the evaluated transformed output has shape [1, 3, 15].
        """
        tmp_dir = self.get_temp_dir()
        vocab_file = os.path.join(tmp_dir, "vocab.txt")
        vocab_alt_file = os.path.join(tmp_dir, "vocab_alt.txt")
        data_file = os.path.join(tmp_dir, "data.txt")

        # Written in this order: word vocab, char vocab, then the corpus.
        file_contents = [
            (vocab_file, "the\n" "world\n" "hello\n" "toto\n"),
            (vocab_alt_file, "h\n" "e\n" "l\n" "w\n" "o\n"),
            (data_file, "hello world !\n"),
        ]
        for path, text in file_contents:
            with open(path, "w") as handle:
                handle.write(text)

        word_embedder = text_inputter.WordEmbedder(
            "vocabulary_file_1", embedding_size=10)
        char_embedder = text_inputter.CharConvEmbedder(
            "vocabulary_file_2", 10, 5)
        mixed_inputter = inputter.MixedInputter(
            [word_embedder, char_embedder], reducer=reducer.ConcatReducer())

        metadata = {
            "vocabulary_file_1": vocab_file,
            "vocabulary_file_2": vocab_alt_file,
        }
        shapes = {
            "char_ids": [None, None, None],
            "ids": [None, None],
            "length": [None],
        }
        features, transformed = self._makeDataset(
            mixed_inputter, data_file, metadata=metadata, shapes=shapes)

        self.assertNotIn("tokens", features)

        with self.test_session() as sess:
            # Initialize lookup tables and variables before evaluating.
            sess.run(tf.tables_initializer())
            sess.run(tf.global_variables_initializer())
            _, transformed_value = sess.run([features, transformed])
            self.assertAllEqual([1, 3, 15], transformed_value.shape)