def test_default_english(self):
  """Test dataset and export inputs with the default (English) config."""
  config = utils.load_config(self.config_file)
  max_len = config["model"]["net"]["structure"]["max_len"]
  class_num = config["data"]["task"]["classes"]["num_classes"]
  task = TextClsTask(config, utils.TRAIN)

  # test offline data
  data = task.dataset()
  self.assertTrue("input_x_dict" in data and
                  "input_x" in data["input_x_dict"])
  self.assertTrue("input_y_dict" in data and
                  "input_y" in data["input_y_dict"])

  with self.session() as sess:
    sess.run(data["iterator"].initializer)
    res = sess.run(
        [data["input_x_dict"]["input_x"], data["input_y_dict"]["input_y"]])
    logging.debug(res[0][0])
    logging.debug(res[1][0])
    self.assertEqual(np.shape(res[0]), (32, max_len))
    self.assertEqual(np.shape(res[1]), (32, class_num))

  # test online data
  export_inputs = task.export_inputs()
  self.assertTrue("export_inputs" in export_inputs and
                  "input_sentence" in export_inputs["export_inputs"])
  input_sentence = export_inputs["export_inputs"]["input_sentence"]
  input_x = export_inputs["model_inputs"]["input_x"]

  with self.session() as sess:
    res = sess.run(input_x, feed_dict={input_sentence: ["All is well."]})
    logging.debug(res[0])
    self.assertEqual(np.shape(res[0]), (max_len,))
def test_chinese_char(self):
  """Test dataset and export inputs with Chinese char-level data."""
  config = utils.load_config(self.config_file)
  max_len = config["model"]["net"]["structure"]["max_len"]
  class_num = config["data"]["task"]["classes"]["num_classes"]

  data_config = config["data"]
  task_config = data_config["task"]
  task_config["language"] = "chinese"
  task_config["split_by_space"] = False
  task_config["use_word"] = False
  data_config["train"]["paths"] = [
      "egs/mock_text_cls_data/text_cls/v1/data/train.split_by_char.txt"
  ]
  data_config["eval"]["paths"] = [
      "egs/mock_text_cls_data/text_cls/v1/data/eval.split_by_char.txt"
  ]
  data_config["infer"]["paths"] = [
      "egs/mock_text_cls_data/text_cls/v1/data/test.split_by_char.txt"
  ]
  task_config["text_vocab"] = \
      "egs/mock_text_cls_data/text_cls/v1/data/text_vocab.split_by_char.txt"
  task_config["need_shuffle"] = False
  config["model"]["split_token"] = ""
  task_config["preparer"]["reuse"] = False

  task = TextClsTask(config, utils.TRAIN)

  # test offline data
  data = task.dataset()
  self.assertTrue("input_x_dict" in data and
                  "input_x" in data["input_x_dict"])
  self.assertTrue("input_y_dict" in data and
                  "input_y" in data["input_y_dict"])

  with self.session() as sess:
    sess.run(data["iterator"].initializer, feed_dict=data["init_feed_dict"])
    res = sess.run([
        data["input_x_dict"]["input_x"], data["input_y_dict"]["input_y"],
        data["input_x_len"]
    ])
    logging.debug(res[0][0])
    logging.debug(res[1][0])
    self.assertAllEqual(res[0][0][:5], [2, 3, 4, 0, 0])
    self.assertEqual(np.shape(res[0]), (32, max_len))
    self.assertEqual(np.shape(res[1]), (32, class_num))
    self.assertEqual(np.shape(res[2]), (32,))

  # test online data
  export_inputs = task.export_inputs()
  self.assertTrue("export_inputs" in export_inputs and
                  "input_sentence" in export_inputs["export_inputs"])
  input_sentence = export_inputs["export_inputs"]["input_sentence"]
  input_x = export_inputs["model_inputs"]["input_x"]

  with self.session() as sess:
    res = sess.run(input_x, feed_dict={input_sentence: ["都挺好"]})
    logging.debug(res[0][:5])
    logging.debug(np.shape(res[0]))
    self.assertEqual(np.shape(res[0]), (max_len,))
    self.assertAllEqual(res[0][:5], [2, 3, 4, 0, 0])
def test_chinese_word(self):
  """Test dataset and export inputs with Chinese word-level data."""
  config = utils.load_config(self.config_file)
  class_num = config["data"]["task"]["classes"]["num_classes"]

  data_config = config["data"]
  task_config = data_config["task"]
  task_config["language"] = "chinese"
  task_config["split_by_space"] = False
  task_config["use_word"] = True
  data_config["train"]["paths"] = [
      "egs/mock_text_cls_data/text_cls/v1/data/train.chinese_word.txt"
  ]
  data_config["eval"]["paths"] = [
      "egs/mock_text_cls_data/text_cls/v1/data/eval.chinese_word.txt"
  ]
  data_config["infer"]["paths"] = [
      "egs/mock_text_cls_data/text_cls/v1/data/test.chinese_word.txt"
  ]
  task_config["text_vocab"] = \
      "egs/mock_text_cls_data/text_cls/v1/data/text_vocab.chinese_word.txt"
  task_config["need_shuffle"] = False
  config["model"]["split_token"] = ""
  task_config["preparer"]["reuse"] = False

  task = TextClsTask(config, utils.TRAIN)

  # test offline data
  data = task.dataset()
  self.assertTrue("input_x_dict" in data and
                  "input_x" in data["input_x_dict"])
  self.assertTrue("input_y_dict" in data and
                  "input_y" in data["input_y_dict"])

  with self.cached_session(use_gpu=False, force_gpu=False) as sess:
    sess.run(data["iterator"].initializer)
    res = sess.run(
        [data["input_x_dict"]["input_x"], data["input_y_dict"]["input_y"]])
    logging.debug(res[0][0])
    logging.debug(res[1][0])
    self.assertAllEqual(res[0][0][:5], [2, 0, 0, 0, 0])
    self.assertEqual(np.shape(res[1]), (32, class_num))

  # test online data
  export_inputs = task.export_inputs()
  self.assertTrue("export_inputs" in export_inputs and
                  "input_sentence" in export_inputs["export_inputs"])
  input_sentence = export_inputs["export_inputs"]["input_sentence"]
  input_x = export_inputs["model_inputs"]["input_x"]
  shape_op = tf.shape(input_x)

  with self.cached_session(use_gpu=False, force_gpu=False) as sess:
    res, shape_res = sess.run([input_x, shape_op],
                              feed_dict={input_sentence: ["我很愤怒"]})
    logging.debug(res[0])
    logging.debug(np.shape(res[0]))
    logging.debug(f"shape: {shape_res}")
    self.assertAllEqual(shape_res, [1, 1024])
    self.assertAllEqual(res[0][:5], [4, 5, 0, 0, 0])
def test_english(self):
  """Test dataset and export inputs with English data and text cleaning."""
  config = utils.load_config(self.config_file)
  class_num = config["data"]["task"]["classes"]["num_classes"]

  data_config = config["data"]
  task_config = data_config["task"]
  task_config["language"] = "english"
  task_config["split_by_space"] = True
  task_config["clean_english"] = True
  data_config["train"]["paths"] = [
      "egs/mock_text_cls_data/text_cls/v1/data/train.english.txt"
  ]
  data_config["eval"]["paths"] = [
      "egs/mock_text_cls_data/text_cls/v1/data/eval.english.txt"
  ]
  data_config["infer"]["paths"] = [
      "egs/mock_text_cls_data/text_cls/v1/data/test.english.txt"
  ]
  task_config["text_vocab"] = \
      "egs/mock_text_cls_data/text_cls/v1/data/text_vocab.english.txt"
  task_config["need_shuffle"] = False
  config["model"]["split_token"] = ""
  task_config["preparer"]["reuse"] = False

  task = TextClsTask(config, utils.TRAIN)

  # test offline data
  data = task.dataset()
  self.assertTrue("input_x_dict" in data and
                  "input_x" in data["input_x_dict"])
  self.assertTrue("input_y_dict" in data and
                  "input_y" in data["input_y_dict"])

  with self.cached_session(use_gpu=False, force_gpu=False) as sess:
    sess.run(data["iterator"].initializer)
    res = sess.run(
        [data["input_x_dict"]["input_x"], data["input_y_dict"]["input_y"]])
    logging.debug(res[0][0][:5])
    logging.debug(res[1][0][:5])
    self.assertAllEqual(res[0][0][:5], [3, 4, 5, 0, 0])
    self.assertEqual(np.shape(res[1]), (32, class_num))

  # test online data
  export_inputs = task.export_inputs()
  self.assertTrue("export_inputs" in export_inputs and
                  "input_sentence" in export_inputs["export_inputs"])
  input_sentence = export_inputs["export_inputs"]["input_sentence"]
  input_x = export_inputs["model_inputs"]["input_x"]

  with self.cached_session(use_gpu=False, force_gpu=False) as sess:
    res = sess.run(input_x, feed_dict={input_sentence: ["All is well."]})
    logging.debug(res[0][:5])
    self.assertAllEqual(res[0][:5], [3, 4, 5, 0, 0])
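# Note: the language-specific tests above repeat almost identical config
# overrides. A minimal sketch of a shared helper is given below, assuming the
# same config schema used above; the method name `_override_task_config` and
# its parameters are hypothetical, not an existing API in this file or in the
# TextClsTask implementation.
def _override_task_config(self, config, language, use_word, split_by_space,
                          data_suffix, clean_english=False):
  """Hypothetical helper: apply per-language data/task overrides (sketch)."""
  data_config = config["data"]
  task_config = data_config["task"]
  task_config["language"] = language
  task_config["split_by_space"] = split_by_space
  task_config["use_word"] = use_word
  if clean_english:
    task_config["clean_english"] = True
  # Mock data files follow the naming scheme used by the tests above,
  # e.g. train.english.txt / text_vocab.english.txt.
  prefix = "egs/mock_text_cls_data/text_cls/v1/data"
  data_config["train"]["paths"] = ["%s/train.%s.txt" % (prefix, data_suffix)]
  data_config["eval"]["paths"] = ["%s/eval.%s.txt" % (prefix, data_suffix)]
  data_config["infer"]["paths"] = ["%s/test.%s.txt" % (prefix, data_suffix)]
  task_config["text_vocab"] = "%s/text_vocab.%s.txt" % (prefix, data_suffix)
  task_config["need_shuffle"] = False
  config["model"]["split_token"] = ""
  task_config["preparer"]["reuse"] = False
  return config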