Esempio n. 1
0
    def test_english(self):
        """Check TextClsTask dataset and export inputs with the default config.

        Offline: the TRAIN-mode dataset exposes ``input_x_dict["input_x"]``
        and ``input_y_dict["input_y"]``, and the first batch has shapes
        ``(32, max_len)`` and ``(32, class_num)``.

        Online: feeding one English sentence to the exported
        ``input_sentence`` placeholder yields an ``input_x`` row of length
        ``max_len``.
        """
        config = utils.load_config(self.config_file)
        max_len = config["model"]["net"]["structure"]["max_len"]
        class_num = config["data"]["task"]["classes"]["num_classes"]
        # NOTE(review): 32 is assumed to be the batch size configured in
        # self.config_file — TODO confirm and read it from the config instead.
        batch_size = 32
        task = TextClsTask(config, utils.TRAIN)

        # test offline data
        data = task.dataset()
        # assertIn names the missing key on failure, unlike
        # assertTrue(a in b and c in d), which only reports "False is not true".
        self.assertIn("input_x_dict", data)
        self.assertIn("input_x", data["input_x_dict"])
        self.assertIn("input_y_dict", data)
        self.assertIn("input_y", data["input_y_dict"])
        with self.session() as sess:
            sess.run(data["iterator"].initializer)
            res = sess.run([
                data["input_x_dict"]["input_x"],
                data["input_y_dict"]["input_y"]
            ])
            logging.debug(res[0][0])
            logging.debug(res[1][0])
            self.assertEqual(np.shape(res[0]), (batch_size, max_len))
            self.assertEqual(np.shape(res[1]), (batch_size, class_num))

        # test online data
        export_inputs = task.export_inputs()
        self.assertIn("export_inputs", export_inputs)
        self.assertIn("input_sentence", export_inputs["export_inputs"])
        input_sentence = export_inputs["export_inputs"]["input_sentence"]
        input_x = export_inputs["model_inputs"]["input_x"]
        with self.session() as sess:
            res = sess.run(input_x,
                           feed_dict={input_sentence: ["All is well."]})
            logging.debug(res[0])
            # One fed sentence -> one row of max_len token ids.
            self.assertEqual(np.shape(res[0]), (max_len, ))
Esempio n. 2
0
    def test_chinese_char(self):
        """Check TextClsTask on Chinese text split into single characters.

        The config is switched to ``language="chinese"`` with
        ``split_by_space=False`` and ``use_word=False``. It points at the
        ``*.split_by_char.txt`` mock data and vocab.

        Offline: the first batch has shapes ``(32, max_len)``,
        ``(32, class_num)`` and ``(32,)`` for ``input_x_len``, and the first
        sample starts with ids ``[2, 3, 4, 0, 0]``.

        Online: "都挺好" maps to a ``max_len`` row starting with the same ids.
        """
        config = utils.load_config(self.config_file)
        max_len = config["model"]["net"]["structure"]["max_len"]
        class_num = config["data"]["task"]["classes"]["num_classes"]
        data_config = config["data"]
        task_config = data_config["task"]
        task_config["language"] = "chinese"
        task_config["split_by_space"] = False
        task_config["use_word"] = False
        data_config["train"]["paths"] = [
            "egs/mock_text_cls_data/text_cls/v1/data/train.split_by_char.txt"
        ]
        data_config["eval"]["paths"] = [
            "egs/mock_text_cls_data/text_cls/v1/data/eval.split_by_char.txt"
        ]
        data_config["infer"]["paths"] = [
            "egs/mock_text_cls_data/text_cls/v1/data/test.split_by_char.txt"
        ]
        task_config[
            "text_vocab"] = "egs/mock_text_cls_data/text_cls/v1/data/text_vocab.split_by_char.txt"
        # Presumably disabled so the first sample checked below is stable.
        task_config["need_shuffle"] = False
        config["model"]["split_token"] = ""
        # NOTE(review): presumably forces data preparation (e.g. the vocab)
        # to be redone for this modified config instead of reused — confirm.
        task_config["preparer"]["reuse"] = False

        task = TextClsTask(config, utils.TRAIN)

        # test offline data
        data = task.dataset()
        # assertIn names the missing key on failure, unlike
        # assertTrue(a in b and c in d), which only reports "False is not true".
        self.assertIn("input_x_dict", data)
        self.assertIn("input_x", data["input_x_dict"])
        self.assertIn("input_y_dict", data)
        self.assertIn("input_y", data["input_y_dict"])
        with self.session() as sess:
            sess.run(data["iterator"].initializer,
                     feed_dict=data["init_feed_dict"])
            res = sess.run([
                data["input_x_dict"]["input_x"],
                data["input_y_dict"]["input_y"], data["input_x_len"]
            ])
            logging.debug(res[0][0])
            logging.debug(res[1][0])
            # NOTE(review): trailing zeros presumably are padding ids.
            self.assertAllEqual(res[0][0][:5], [2, 3, 4, 0, 0])
            self.assertEqual(np.shape(res[0]), (32, max_len))
            self.assertEqual(np.shape(res[1]), (32, class_num))
            self.assertEqual(np.shape(res[2]), (32, ))

        # test online data
        export_inputs = task.export_inputs()
        self.assertIn("export_inputs", export_inputs)
        self.assertIn("input_sentence", export_inputs["export_inputs"])
        input_sentence = export_inputs["export_inputs"]["input_sentence"]
        input_x = export_inputs["model_inputs"]["input_x"]

        with self.session() as sess:
            res = sess.run(input_x, feed_dict={input_sentence: ["都挺好"]})
            logging.debug(res[0][:5])
            logging.debug(np.shape(res[0]))
            self.assertEqual(np.shape(res[0]), (max_len, ))
            # The online path must produce the same ids as the offline check.
            self.assertAllEqual(res[0][:5], [2, 3, 4, 0, 0])
Esempio n. 3
0
    def test_chinese_word(self):
        """Check TextClsTask on Chinese text tokenized into words.

        The config is switched to ``language="chinese"`` with
        ``split_by_space=False`` and ``use_word=True``. It points at the
        ``*.chinese_word.txt`` mock data and vocab.

        Offline: the first sample starts with ids ``[2, 0, 0, 0, 0]`` and the
        label batch has shape ``(32, class_num)``.

        Online: "我很愤怒" yields an ``input_x`` tensor of shape ``[1, 1024]``
        starting with ids ``[4, 5, 0, 0, 0]``.
        """
        config = utils.load_config(self.config_file)
        class_num = config["data"]["task"]["classes"]["num_classes"]
        data_config = config["data"]
        task_config = data_config["task"]
        task_config["language"] = "chinese"
        task_config["split_by_space"] = False
        task_config["use_word"] = True
        data_config["train"]["paths"] = \
          ["egs/mock_text_cls_data/text_cls/v1/data/train.chinese_word.txt"]
        data_config["eval"]["paths"] = \
          ["egs/mock_text_cls_data/text_cls/v1/data/eval.chinese_word.txt"]
        data_config["infer"]["paths"] = \
          ["egs/mock_text_cls_data/text_cls/v1/data/test.chinese_word.txt"]
        task_config[
            "text_vocab"] = "egs/mock_text_cls_data/text_cls/v1/data/text_vocab.chinese_word.txt"
        # Presumably disabled so the first sample checked below is stable.
        task_config["need_shuffle"] = False
        config["model"]["split_token"] = ""
        # NOTE(review): presumably forces data preparation to be redone for
        # this modified config instead of reused — confirm.
        task_config["preparer"]["reuse"] = False

        task = TextClsTask(config, utils.TRAIN)

        # test offline data
        data = task.dataset()
        # assertIn names the missing key on failure, unlike
        # assertTrue(a in b and c in d), which only reports "False is not true".
        self.assertIn("input_x_dict", data)
        self.assertIn("input_x", data["input_x_dict"])
        self.assertIn("input_y_dict", data)
        self.assertIn("input_y", data["input_y_dict"])
        with self.cached_session(use_gpu=False, force_gpu=False) as sess:
            sess.run(data["iterator"].initializer)
            res = sess.run([
                data["input_x_dict"]["input_x"],
                data["input_y_dict"]["input_y"]
            ])
            logging.debug(res[0][0])
            logging.debug(res[1][0])
            self.assertAllEqual(res[0][0][:5], [2, 0, 0, 0, 0])
            self.assertEqual(np.shape(res[1]), (32, class_num))

        # test online data
        export_inputs = task.export_inputs()
        self.assertIn("export_inputs", export_inputs)
        self.assertIn("input_sentence", export_inputs["export_inputs"])
        input_sentence = export_inputs["export_inputs"]["input_sentence"]
        input_x = export_inputs["model_inputs"]["input_x"]
        # Evaluate the dynamic shape alongside the values.
        shape_op = tf.shape(input_x)

        with self.cached_session(use_gpu=False, force_gpu=False) as sess:
            res, shape_res = sess.run([input_x, shape_op],
                                      feed_dict={input_sentence: ["我很愤怒"]})
            logging.debug(res[0])
            logging.debug(np.shape(res[0]))
            logging.debug(f"shape: {shape_res}")
            # NOTE(review): 1024 is hard-coded, while sibling tests read
            # max_len from the config. Presumably it equals the configured
            # padded length — TODO confirm and derive it from the config.
            self.assertAllEqual(shape_res, [1, 1024])
            self.assertAllEqual(res[0][:5], [4, 5, 0, 0, 0])
Esempio n. 4
0
    def test_english(self):
        """Check TextClsTask on space-split English text with cleaning on.

        The config is switched to ``language="english"`` with
        ``split_by_space=True`` and ``clean_english=True``. It points at the
        ``*.english.txt`` mock data and vocab.

        Offline: the first sample starts with ids ``[3, 4, 5, 0, 0]`` and the
        label batch has shape ``(32, class_num)``.

        Online: "All is well." maps to a row starting with the same ids.
        """
        config = utils.load_config(self.config_file)
        class_num = config["data"]["task"]["classes"]["num_classes"]
        task_config = config["data"]["task"]
        task_config["language"] = "english"
        task_config["split_by_space"] = True
        task_config["clean_english"] = True
        data_config = config["data"]
        data_config["train"]["paths"] = [
            "egs/mock_text_cls_data/text_cls/v1/data/train.english.txt"
        ]
        data_config["eval"]["paths"] = [
            "egs/mock_text_cls_data/text_cls/v1/data/eval.english.txt"
        ]
        data_config["infer"]["paths"] = [
            "egs/mock_text_cls_data/text_cls/v1/data/test.english.txt"
        ]
        task_config[
            "text_vocab"] = "egs/mock_text_cls_data/text_cls/v1/data/text_vocab.english.txt"
        # Presumably disabled so the first sample checked below is stable.
        task_config["need_shuffle"] = False
        config["model"]["split_token"] = ""
        # NOTE(review): presumably forces data preparation to be redone for
        # this modified config instead of reused — confirm.
        task_config["preparer"]["reuse"] = False

        task = TextClsTask(config, utils.TRAIN)

        # test offline data
        data = task.dataset()
        # assertIn names the missing key on failure, unlike
        # assertTrue(a in b and c in d), which only reports "False is not true".
        self.assertIn("input_x_dict", data)
        self.assertIn("input_x", data["input_x_dict"])
        self.assertIn("input_y_dict", data)
        self.assertIn("input_y", data["input_y_dict"])
        with self.cached_session(use_gpu=False, force_gpu=False) as sess:
            sess.run(data["iterator"].initializer)
            res = sess.run([
                data["input_x_dict"]["input_x"],
                data["input_y_dict"]["input_y"]
            ])
            logging.debug(res[0][0][:5])
            logging.debug(res[1][0][:5])
            self.assertAllEqual(res[0][0][:5], [3, 4, 5, 0, 0])
            self.assertEqual(np.shape(res[1]), (32, class_num))

        # test online data
        export_inputs = task.export_inputs()
        self.assertIn("export_inputs", export_inputs)
        self.assertIn("input_sentence", export_inputs["export_inputs"])
        input_sentence = export_inputs["export_inputs"]["input_sentence"]
        input_x = export_inputs["model_inputs"]["input_x"]
        with self.cached_session(use_gpu=False, force_gpu=False) as sess:
            res = sess.run(input_x,
                           feed_dict={input_sentence: ["All is well."]})
            logging.debug(res[0][:5])
            # The online path must produce the same ids as the offline check.
            self.assertAllEqual(res[0][:5], [3, 4, 5, 0, 0])