Example #1
    def setUpClass(cls):
        test_root = "./"
        TestErnieExtractFeature.test_output_dir = os.path.join(test_root, "output/test_dygraph_models/")
        if not os.path.isdir(TestErnieExtractFeature.test_output_dir):
            os.mkdir(TestErnieExtractFeature.test_output_dir)

        test_data_dir = os.path.join(test_root, "dataset/classification_data/toutiao_news")

        example_num = 5

        # Load data
        data_path = os.path.join(test_data_dir, "toutiao_news_shrink.txt")
        TestErnieExtractFeature.text_list, TestErnieExtractFeature.keywords_list = \
                get_attr_values(data_path, fetch_list=["text", "keywords"], encoding="utf-8")
        logging.info("data num = {}".format(len(TestErnieExtractFeature.text_list)))

        TestErnieExtractFeature.text_list = TestErnieExtractFeature.text_list[:100]
        TestErnieExtractFeature.keywords_list = TestErnieExtractFeature.keywords_list[:100]

        TestErnieExtractFeature.tokenizer = ErnieTokenizer.load("./dict/vocab.txt")
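        # transform() converts each raw text into its token id sequence under the loaded vocab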
        TestErnieExtractFeature.text_ids = TestErnieExtractFeature.tokenizer.transform(TestErnieExtractFeature.text_list)
        #logging.info("text_ids: {}".format(TestErnieExtractFeature.text_ids))

        logging.info(u"Data examples")
        for index, (text, token_ids) in enumerate(zip(
                TestErnieExtractFeature.text_list[:example_num],
                TestErnieExtractFeature.text_ids[:example_num],
                )):
            logging.info("example #{}:".format(index))
            logging.info("text: {}".format(text.encode("utf-8")))
            logging.info("token_ids: {}".format(token_ids))
Example #2
    def setUpClass(cls):
        test_root = "./"
        TestCluster.test_output_dir = os.path.join(test_root, "output/test_cluster/")
        if not os.path.isdir(TestCluster.test_output_dir):
            os.mkdir(TestCluster.test_output_dir)

        test_data_dir = os.path.join(test_root, "dataset/classification_data/toutiao_news")

        example_num = 5

        # Load data
        data_path = os.path.join(test_data_dir, "toutiao_news_shrink.txt")
        TestCluster.text_list, TestCluster.keywords_list = \
                get_attr_values(data_path, fetch_list=["text", "keywords"], encoding="utf-8")
        logging.info("data num = {}".format(len(TestCluster.text_list)))

        logging.info(u"Data examples")
        for index, text in enumerate(TestCluster.text_list[:example_num]):
            logging.info("example #{}:".format(index))
            logging.info("text: {}".format(text.encode("utf-8")))

        TestCluster.cluster_num = 15
        max_no_improvement = 10000
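        # mini_batch_kmeans presumably wraps sklearn's MiniBatchKMeans; max_no_improvement
        # is its early-stopping patience (consecutive mini-batches without improvement)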

        TestCluster.cluster_model = mini_batch_kmeans(n_clusters=TestCluster.cluster_num, max_no_improvement=max_no_improvement)
Example #3
    def setUpClass(cls):
        test_root = "./"
        TestLRModel.test_output_dir = os.path.join(test_root, "output")
        if not os.path.isdir(TestLRModel.test_output_dir):
            os.mkdir(TestLRModel.test_output_dir)

        test_data_dir = os.path.join(test_root, "dataset/classification_data/toutiao_news")
        TestLRModel.model_path = os.path.join(TestLRModel.test_output_dir, "lr_feature_weight.model")

        test_size = 0.2
        random_state = 1
        shuffle = True
        example_num = 5

        label_id_path = os.path.join(test_data_dir, "class_id.txt")
        TestLRModel.label_encoder = LabelEncoder(label_id_path, isFile=True)
        logging.info("label num = {}".format(TestLRModel.label_encoder.size()))

        # Load data
        data_path = os.path.join(test_data_dir, "toutiao_news_shrink.txt")
        label_list, text_list = get_attr_values(data_path, fetch_list=["label", "text"], encoding="utf-8")
        logging.info("data num = {}".format(len(label_list)))

        tokenizer_path = os.path.join(TestLRModel.test_output_dir, "lr_tokenizer.config")
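        # Reuse a previously saved tokenizer config if one exists; otherwise build it and cache it for later runs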
        if os.path.exists(tokenizer_path):
            TestLRModel.tokenizer = LRTokenizer.load(tokenizer_path)
        else:
            TestLRModel.tokenizer = LRTokenizer(
                    stopword_path="./dict/stopword_shrink.txt",
                    jieba_tmp_dir="./dict/jieba_tmp",
                    )
            TestLRModel.tokenizer.save(tokenizer_path)
        feature_list = TestLRModel.tokenizer.transform(text_list)

        label_ids = [TestLRModel.label_encoder.transform(label_name) for label_name in label_list]

        TestLRModel.train_text, TestLRModel.test_text, TestLRModel.train_x, \
            TestLRModel.test_x, TestLRModel.train_y, TestLRModel.test_y = \
            train_test_split(text_list, feature_list, label_ids,
                             test_size=test_size, random_state=random_state, shuffle=shuffle)
        logging.info("train num = {}".format(len(TestLRModel.train_y)))
        logging.info("test num = {}".format(len(TestLRModel.test_y)))

        logging.info("Data examples")
        for index, (label_id, text, feature) in enumerate(zip(
                TestLRModel.train_y[:example_num],
                TestLRModel.train_text[:example_num],
                TestLRModel.train_x[:example_num],
                )):
            label_name = TestLRModel.label_encoder.inverse_transform(label_id)
            logging.info("example #{}:".format(index))
            logging.info("label: {}".format(label_name))
            logging.info("text: {}".format(text))
            logging.info("feature: {}".format(feature))
Example #4
    def make_pairwise_data(cls, src_data_path, output_path, num_each_label_pair=100):
        """Build pairwise (anchor, pos, neg) training samples from the labeled
        source data and write them to output_path as tab-separated lines.
        """
        # Load source data
        text_list, label_list = \
                get_attr_values(src_data_path, fetch_list=["text", "label"], encoding="utf-8")
        logging.info("data num = {}".format(len(text_list)))

        # -------------------- Build training set -------------------------
        label_text_dict = defaultdict(set)
        for text, label in zip(text_list, label_list):
            label_text_dict[label].add(text)

        label_text_dict = {k: list(v) for k, v in label_text_dict.items()}

        label_text_num_list = [(label, len(texts)) for label, texts in label_text_dict.items()]
        label_text_num_list = sorted(label_text_num_list, key=lambda x: x[1], reverse=True)
        logging.info(u"\nText count per label:\n" + "\n".join(["%s = %d" % (label, text_num) for label, text_num in label_text_num_list]))

        def gen_pair_wise_data():
            # The toutiao corpus has 15 classes, with roughly 340 to 40k samples per class
            # Build the pairwise training samples
            for anchor_label, neg_label in permutations(label_text_dict.keys(), 2):
                # Each class takes a turn as the anchor; every other class serves as the neg class
                # From the neg class, randomly sample num_each_label_pair texts
                #neg_text_list = random.sample(label_text_dict[neg_label], min(num_each_label_pair, len(label_text_dict[neg_label])))

                # From the anchor class, randomly sample num_each_label_pair (anchor, pos) pairs
                anchor_pair_list = list(combinations(label_text_dict[anchor_label], 2))
                random.shuffle(anchor_pair_list)
                for index, (anchor_text, pos_text) in enumerate(anchor_pair_list):
                    if index == num_each_label_pair:
                        break
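                    # Draw one random negative text from the neg class for this (anchor, pos) pair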
                    neg_text = random.sample(label_text_dict[neg_label], 1)[0]

                    yield "\t".join([
                        anchor_text,
                        pos_text,
                        neg_text,
                        anchor_label,
                        neg_label,
                        ])

        with codecs.open(output_path, "w", "utf-8") as wf:
            wf.write("\t".join([
                "anchor",
                "pos",
                "neg",
                "pos_label",
                "neg_label",
                ]) + "\n")

            for text in gen_pair_wise_data():
                wf.write(text + "\n")
Example #5
    def setUpClass(cls):
        test_root = "./"
        TestDygraphModelsParallelized.test_output_dir = os.path.join(test_root, "output/test_dygraph_models/")
        if not os.path.isdir(TestDygraphModelsParallelized.test_output_dir):
            os.mkdir(TestDygraphModelsParallelized.test_output_dir)

        test_data_dir = os.path.join(test_root, "dataset/classification_data/toutiao_news")

        test_size = 0.15
        random_state = 1
        shuffle = True
        example_num = 5

        label_id_path = os.path.join(test_data_dir, "class_id.txt")
        TestDygraphModelsParallelized.label_encoder = LabelEncoder(label_id_path, isFile=True)
        logging.info("label num = {}".format(TestDygraphModelsParallelized.label_encoder.size()))

        # Load data
        data_path = os.path.join(test_data_dir, "toutiao_news_shrink.txt")
        text_list, label_list = \
                get_attr_values(data_path, fetch_list=["text", "label"], encoding="utf-8")
        logging.info("data num = {}".format(len(text_list)))

        TestDygraphModelsParallelized.tokenizer = ErnieTokenizer.load("./dict/vocab.txt")
        text_ids = TestDygraphModelsParallelized.tokenizer.transform(text_list)

        label_ids = [TestDygraphModelsParallelized.label_encoder.transform(label_name) for label_name in label_list]

        TestDygraphModelsParallelized.train_text, TestDygraphModelsParallelized.test_text, \
            train_x, test_x, train_y, test_y = \
            train_test_split(text_list, text_ids, label_ids,
                             test_size=test_size, random_state=random_state, shuffle=shuffle)
        logging.info("train num = {}".format(len(train_y)))
        logging.info("test num = {}".format(len(test_y)))

        TestDygraphModelsParallelized.train_data = list(zip(train_x, train_y))
        TestDygraphModelsParallelized.eval_data = list(zip(test_x, test_y))

        place = F.CUDAPlace(D.ParallelEnv().dev_id)
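        # prepare_context() (Paddle 1.x dygraph API) sets up the data-parallel strategy for this process's GPU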
        with D.guard(place):
            TestDygraphModelsParallelized.strategy = D.prepare_context()

        logging.info(u"Data examples")
        for index, (text, (token_ids, label_id)) in enumerate(zip(
                TestDygraphModelsParallelized.train_text[:example_num],
                TestDygraphModelsParallelized.train_data[:example_num],
                )):
            label_name = TestDygraphModelsParallelized.label_encoder.inverse_transform(label_id)
            logging.info("example #{}:".format(index))
            logging.info("label: {}".format(label_name))
            logging.info("text: {}".format(text.encode("utf-8")))
            logging.info("token_ids: {}".format(token_ids))
Example #6
def create_train_test_dataset(data_dir, tokenizer, label_encoder,
        test_size=0.15, shuffle=True, random_state=1):
    # Load data
    text_list, label_list = \
            get_attr_values(data_dir, fetch_list=["text", "label"], encoding="utf-8")
    logging.info("data num = {}".format(len(text_list)))

    token_ids, token_type_ids = zip(*[tokenizer.encode(text) for text in text_list])

    label_ids = [label_encoder.transform(label_name) for label_name in label_list]

    data_list = list(zip(token_ids, token_type_ids, label_ids))

    train_text, test_text, train_data_list, test_data_list = \
        train_test_split(text_list, data_list,
                         test_size=test_size, random_state=random_state, shuffle=shuffle)
    logging.info("train num = {}".format(len(train_data_list)))
    logging.info("test num = {}".format(len(test_data_list)))
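
    # Wrap each split in a Dataset and a DataLoader; collate_fn presumably pads/batches
    # the variable-length token id sequences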

    train_dataset = ClassificationDataset(train_data_list)
    test_dataset = ClassificationDataset(test_data_list)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=collate_fn,
        )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        collate_fn=collate_fn,
        )

    example_num = 5
    logging.info(u"Data examples")
    for index, (text, (token_ids, _, label_id)) in enumerate(zip(
            train_text[:example_num],
            train_data_list[:example_num],
            )):
        label_name = label_encoder.inverse_transform(label_id)
        logging.info("example #{}:".format(index))
        logging.info("label: {}".format(label_name))
        logging.info("text: {}".format(text.encode("utf-8")))
        logging.info("token_ids: {}".format(token_ids))

    return train_dataloader, test_dataloader
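
A minimal call sketch for the helper above (not taken from the source): the data path and
LabelEncoder setup follow the other examples, and my_tokenizer is a hypothetical tokenizer
instance that must expose encode(text) -> (token_ids, token_type_ids).

    label_encoder = LabelEncoder(
        "dataset/classification_data/toutiao_news/class_id.txt", isFile=True)
    train_loader, test_loader = create_train_test_dataset(
        "dataset/classification_data/toutiao_news/toutiao_news_shrink.txt",
        tokenizer=my_tokenizer,       # hypothetical; any tokenizer providing encode()
        label_encoder=label_encoder,
        test_size=0.15,
    )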
Example #7
    def setUpClass(cls):
        test_root = "./"
        TestTokenizers.test_output_dir = os.path.join(test_root, "output/test_tokenizers/")
        if not os.path.isdir(TestTokenizers.test_output_dir):
            os.mkdir(TestTokenizers.test_output_dir)

        test_data_dir = os.path.join(test_root, "dataset/classification_data/toutiao_news")
        example_num = 5

        # Load data
        data_path = os.path.join(test_data_dir, "toutiao_news.txt")
        TestTokenizers.text_list, TestTokenizers.keywords_list = \
                get_attr_values(data_path, fetch_list=["text", "keywords"], encoding="utf-8")
        logging.info("text num = {}".format(len(TestTokenizers.text_list)))

        logging.info("Data examples")
        for index, text in enumerate(TestTokenizers.text_list[:example_num]):
            logging.info("example #{}:".format(index))
            logging.info("text: {}".format(text))
Example #8
    def setUpClass(cls):
        test_root = "./"
        TestTextSimilarity.test_output_dir = os.path.join(test_root, "output/test_text_similarity/")
        if not os.path.isdir(TestTextSimilarity.test_output_dir):
            os.mkdir(TestTextSimilarity.test_output_dir)

        test_data_dir = os.path.join(test_root, "dataset/text_similarity/")
        if not os.path.isdir(test_data_dir):
            os.mkdir(test_data_dir)

        pair_wise_data_path = os.path.join(test_data_dir, "pair_wise_data.txt")
        src_data_path = os.path.join(test_root, "dataset/classification_data/toutiao_news/toutiao_news_shrink.txt")

        if not os.path.exists(pair_wise_data_path):
            TestTextSimilarity.make_pairwise_data(src_data_path, pair_wise_data_path)

        src_text_list, src_label_list = \
                get_attr_values(src_data_path, fetch_list=["text", "label"], encoding="utf-8")

        src_data = list(zip(src_text_list, src_label_list))
        random.shuffle(src_data)
        TestTextSimilarity.src_text_list, TestTextSimilarity.src_label_list = zip(*src_data[:2000])

        anchor_list, pos_list, neg_list = \
                get_attr_values(pair_wise_data_path, fetch_list=["anchor", "pos", "neg"], encoding="utf-8")
        logging.info("data num = {}".format(len(anchor_list)))

        test_size = 0.2
        random_state = 1
        shuffle = True
        example_num = 5

        TestTextSimilarity.tokenizer = ErnieTokenizer.load("./dict/vocab.txt")

        TestTextSimilarity.src_text_ids = TestTextSimilarity.tokenizer.transform(TestTextSimilarity.src_text_list)

        #anchor_list = anchor_list[:500]
        #pos_list = pos_list[:500]
        #neg_list = neg_list[:500]

        anchor_ids = TestTextSimilarity.tokenizer.transform(anchor_list)
        pos_ids = TestTextSimilarity.tokenizer.transform(pos_list)
        neg_ids = TestTextSimilarity.tokenizer.transform(neg_list)
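        # Group the raw texts as (anchor, pos, neg) triples; train_test_split below applies
        # the same index split to every sequence it receives, keeping texts and ids aligned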

        text_list = list(zip(anchor_list, pos_list, neg_list))

        TestTextSimilarity.train_text, TestTextSimilarity.test_text, \
            train_anchor, test_anchor, train_pos, test_pos, train_neg, test_neg = \
            train_test_split(text_list, anchor_ids, pos_ids, neg_ids,
                             test_size=test_size, random_state=random_state, shuffle=shuffle)
        logging.info("train num = {}".format(len(train_anchor)))
        logging.info("test num = {}".format(len(test_anchor)))

        TestTextSimilarity.train_data = list(zip(train_anchor, train_pos, train_neg))
        TestTextSimilarity.eval_data = list(zip(test_anchor, test_pos, test_neg))

        logging.info(u"Data examples")
        for index, ((anchor_text, pos_text, neg_text), (anchor_ids, pos_ids, neg_ids)) in enumerate(zip(
                TestTextSimilarity.train_text[:example_num],
                TestTextSimilarity.train_data[:example_num],
                )):
            logging.info("example #{}:".format(index))
            logging.info("anchor_text: {}".format(anchor_text.encode("utf-8")))
            logging.info("anchor_ids: {}".format(anchor_ids))
            logging.info("anchor_ids type: {}".format(type(anchor_ids)))
            logging.info("anchor_ids dtype: {}".format(anchor_ids.dtype))
            logging.info("pos_text: {}".format(pos_text.encode("utf-8")))
            logging.info("pos_ids: {}".format(pos_ids))
            logging.info("pos_ids type: {}".format(type(pos_ids)))
            logging.info("pos_ids dtype: {}".format(pos_ids.dtype))
            logging.info("neg_text: {}".format(neg_text.encode("utf-8")))
            logging.info("neg_ids: {}".format(neg_ids))
            logging.info("neg_ids type: {}".format(type(neg_ids)))
            logging.info("neg_ids dtype: {}".format(neg_ids.dtype))

        logging.info(u"src data examples")
        for index, (text_ids, text, label) in enumerate(zip(
            TestTextSimilarity.src_text_ids[:example_num],
            TestTextSimilarity.src_text_list[:example_num],
            TestTextSimilarity.src_label_list[:example_num],
            )):
            logging.info("example #{}:".format(index))
            logging.info("text: {}".format(text.encode("utf-8")))
            logging.info("label: {}".format(label.encode("utf-8")))
            logging.info("text_ids: {}".format(text_ids))