Example no. 1
def vocabulary(ace_dataset):
    vocab_collate_fn = EventVocabularyCollate()
    data_loader = DataLoader(ace_dataset, collate_fn=vocab_collate_fn)

    event_types_list: List[List[str]] = list()
    tokens_list: List[List[str]] = list()
    entity_tags_list: List[List[str]] = list()

    for collate_dict in data_loader:
        event_types_list.extend(collate_dict["event_types"])
        tokens_list.extend(collate_dict["tokens"])
        entity_tags_list.extend(collate_dict["entity_tags"])

    negative_event_type = "Negative"
    event_type_vocab = Vocabulary(tokens=event_types_list,
                                  unk=negative_event_type,
                                  padding="",
                                  special_first=True)

    word_vocab = Vocabulary(tokens=tokens_list,
                            unk=Vocabulary.UNK,
                            padding=Vocabulary.PADDING,
                            special_first=True)

    entity_tag_vocab = LabelVocabulary(entity_tags_list,
                                       padding=LabelVocabulary.PADDING)
    return {
        "event_type_vocab": event_type_vocab,
        "word_vocab": word_vocab,
        "entity_tag_vocab": entity_tag_vocab
    }
Example no. 2
def event_type_vocabulary():
    event_types = [["A", "B", "C"], ["A", "B"], ["A"]]

    vocabulary = Vocabulary(tokens=event_types,
                            padding="",
                            unk="Negative",
                            special_first=True)

    ASSERT.assertEqual(4, vocabulary.size)
    ASSERT.assertEqual(0, vocabulary.index(vocabulary.unk))
    ASSERT.assertEqual(1, vocabulary.index("A"))
    ASSERT.assertEqual(2, vocabulary.index("B"))
    ASSERT.assertEqual(3, vocabulary.index("C"))

    return vocabulary
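
A hedged follow-up to this example (not part of the original listing): by the same behavior that Example no. 14 below checks with "哈哈", an event type that was never added should resolve to the unk token "Negative". The event type "D" below is purely hypothetical.

def test_event_type_vocabulary_unk_fallback():
    # minimal sketch, assuming the same Vocabulary API used above
    vocabulary = Vocabulary(tokens=[["A", "B", "C"]],
                            padding="",
                            unk="Negative",
                            special_first=True)

    # "D" was never added, so it should map to the index of the unk token
    ASSERT.assertEqual(vocabulary.index("D"), vocabulary.index(vocabulary.unk))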
Example no. 3
def vocabulary(
        conll2003_dataset
) -> Dict[str, Union[Vocabulary, PretrainedVocabulary]]:
    data_loader = DataLoader(dataset=conll2003_dataset,
                             batch_size=2,
                             shuffle=False,
                             num_workers=0,
                             collate_fn=VocabularyCollate())

    batch_tokens = list()
    batch_sequence_labels = list()

    for collate_dict in data_loader:
        batch_tokens.extend(collate_dict["tokens"])
        batch_sequence_labels.extend(collate_dict["sequence_labels"])

    token_vocabulary = Vocabulary(tokens=batch_tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                       padding=LabelVocabulary.PADDING)
    return {
        "token_vocabulary": token_vocabulary,
        "label_vocabulary": label_vocabulary
    }
Example no. 4
    def __init__(self,
                 character_pretrained_vocabulary: PretrainedVocabulary,
                 gaz_word_pretrained_vocabulary: PretrainedVocabulary):
        """
        Initialization.
        :param character_pretrained_vocabulary: pretrained character vocabulary
        :param gaz_word_pretrained_vocabulary: pretrained gaz word vocabulary
        """

        assert character_pretrained_vocabulary.embedding_dim == gaz_word_pretrained_vocabulary.embedding_dim, \
            "character_pretrained_vocabulary and gaz_word_pretrained_vocabulary must have the same embedding dimension"

        char_embedding_dict = self.__token_embedding_dict(character_pretrained_vocabulary)
        gaz_word_embedding_dict = self.__token_embedding_dict(gaz_word_pretrained_vocabulary)

        # materialize the key views before the update below merges the two dicts
        tokens = [list(char_embedding_dict.keys()), list(gaz_word_embedding_dict.keys())]
        char_embedding_dict.update(gaz_word_embedding_dict)

        embedding_dict = char_embedding_dict

        vocabulary = Vocabulary(tokens=tokens,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

        super().__init__(vocabulary=vocabulary, pretrained_word_embedding_loader=None)

        self._embedding_dim = character_pretrained_vocabulary.embedding_dim
        self._init_embedding_matrix(vocabulary=self._vocabulary,
                                    embedding_dict=embedding_dict,
                                    embedding_dim=self._embedding_dim)
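
The private helper __token_embedding_dict is not shown in this listing. The sketch below is only a guess at its intent (collecting a token-to-embedding mapping from a pretrained vocabulary); the attribute embedding_matrix and the exact signature are hypothetical, not the library's confirmed API.

    def __token_embedding_dict(self, pretrained_vocabulary: PretrainedVocabulary) -> Dict[str, "torch.Tensor"]:
        # hypothetical sketch: pair every token in the vocabulary with its embedding row
        embedding_dict = dict()
        for index in range(pretrained_vocabulary.size):
            token = pretrained_vocabulary.token(index)
            embedding_dict[token] = pretrained_vocabulary.embedding_matrix[index]  # hypothetical attribute
        return embedding_dict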
Example no. 5
def test_vocabulary_speical_first():
    """
    Test vocabulary with special_first=True.
    :return:
    """
    batch_tokens = [["我", "和", "你"], ["在", "我"]]
    vocabulary = Vocabulary(batch_tokens,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=True,
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 6)

    ASSERT.assertEqual(vocabulary.padding, vocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.unk, vocabulary.UNK)
    ASSERT.assertEqual(vocabulary.index(vocabulary.padding), 0)
    ASSERT.assertEqual(vocabulary.index(vocabulary.unk), 1)
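
A hedged continuation of the same test: with special_first=True the two special tokens occupy indices 0 and 1, so the content tokens should start at index 2. Example no. 14 below shows the analogous pattern with four special tokens, where "我" ends up at index 4.

    # continuation sketch: the most frequent token "我" should follow the two specials
    ASSERT.assertEqual(vocabulary.index("我"), 2)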
Example no. 6
def test_vocabulary():
    """

    :return:
    """

    batch_tokens = [["我", "和", "你"], ["在", "我"]]
    vocabulary = Vocabulary(batch_tokens,
                            padding="",
                            unk="",
                            special_first=True,
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 4)

    ASSERT.assertTrue(not vocabulary.padding)
    ASSERT.assertTrue(not vocabulary.unk)

    ASSERT.assertEqual(vocabulary.index("我"), 0)
    ASSERT.assertEqual(vocabulary.index("和"), 1)
Example no. 7
    def __init__(self, dataset_file_path: str,
                 event_type_vocabulary: Vocabulary):
        """
        初始化 ACE Event Dataset
        :param dataset_file_path: 数据集的文件路基
        """
        super().__init__()
        self._ace_dataset = ACEDataset(dataset_file_path=dataset_file_path)

        self._instances: List[Instance] = list()

        for ori_instance in self._ace_dataset:

            ori_event_types = ori_instance["event_types"]

            ori_event_type_set = None

            if ori_event_types is not None:  # during actual prediction ori_event_types is None
                # for training and validation: if a sentence has no event, the pair <sentence, unk> gets label = 1
                ori_event_type_set = set(ori_event_types)

                if len(ori_event_type_set) == 0:
                    ori_event_type_set.add(event_type_vocabulary.unk)

            for index in range(event_type_vocabulary.size):
                # iterate over all labels, building a pair <sentence, event type> as one sample
                event_type = event_type_vocabulary.token(index)

                instance = Instance()

                instance["sentence"] = ori_instance["sentence"]

                instance["entity_tag"] = ori_instance["entity_tag"]

                instance["event_type"] = event_type
                instance["metadata"] = ori_instance["metadata"]

                if ori_event_type_set is not None:
                    if event_type in ori_event_type_set:
                        instance["label"] = 1
                    else:
                        instance["label"] = 0
                else:
                    # this branch is for actual prediction, where no label is set
                    pass

                self._instances.append(instance)
Example no. 8
    def build_vocabulary(self):
        training_dataset_file_path = self.config["training_dataset_file_path"]

        dataset = ACSASemEvalDataset(
            dataset_file_path=training_dataset_file_path)

        collate_fn = VocabularyCollate()
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=10,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=collate_fn)

        tokens = list()
        categories = list()
        labels = list()
        for collate_dict in data_loader:
            tokens.extend(collate_dict["tokens"])
            categories.extend(collate_dict["categories"])
            labels.extend(collate_dict["labels"])

        token_vocabulary = Vocabulary(tokens=tokens,
                                      padding=Vocabulary.PADDING,
                                      unk=Vocabulary.UNK,
                                      special_first=True)

        if not self.config["debug"]:
            pretrained_file_path = self.config["pretrained_file_path"]
            pretrained_loader = GloveLoader(
                embedding_dim=300, pretrained_file_path=pretrained_file_path)
            pretrained_token_vocabulary = PretrainedVocabulary(
                vocabulary=token_vocabulary,
                pretrained_word_embedding_loader=pretrained_loader)

            token_vocabulary = pretrained_token_vocabulary

        category_vocabulary = LabelVocabulary(labels=categories, padding=None)
        label_vocabulary = LabelVocabulary(labels=labels, padding=None)

        return {
            "token_vocabulary": token_vocabulary,
            "category_vocabulary": category_vocabulary,
            "label_vocabulary": label_vocabulary
        }
Example no. 9
    def __init__(self, event_type_vocabulary: Vocabulary):
        """
        Initialization.
        :param event_type_vocabulary: event type vocabulary
        """
        super().__init__()

        self._event_type_f1: Dict[str, LabelF1Metric] = dict()

        for index in range(0, event_type_vocabulary.size):
            event_type = event_type_vocabulary.token(index)

            if event_type != event_type_vocabulary.unk:
                self._event_type_f1[event_type] = LabelF1Metric(
                    labels=[1], label_vocabulary=None)

        self._event_type_f1[EventF1MetricAdapter.__OVERALL] = LabelF1Metric(
            labels=[1], label_vocabulary=None)
        self._event_type_vocabulary = event_type_vocabulary
Example no. 10
def pretrained_vocabulary():
    """
    Build the pretrained vocabulary.
    """
    pretrained_file_path = "data/easytext/tests/pretrained/word_embedding_sample.3d.txt"
    pretrained_file_path = os.path.join(ROOT_PATH, pretrained_file_path)

    glove_loader = GloveLoader(embedding_dim=3,
                               pretrained_file_path=pretrained_file_path)

    tokens = [["我"], ["美丽"]]

    vocab = Vocabulary(tokens=tokens,
                       padding=Vocabulary.PADDING,
                       unk=Vocabulary.UNK,
                       special_first=True)

    pretrained_vocab = PretrainedVocabulary(
        vocabulary=vocab, pretrained_word_embedding_loader=glove_loader)
    return pretrained_vocab
Example no. 11
    def __init__(
            self, is_training: bool, dataset: Dataset, gaz_vocabulary_dir: str,
            gaz_pretrained_word_embedding_loader: PretrainedWordEmbeddingLoader
    ):
        """
        Build the gaz vocabulary.
        :param is_training: whether we are currently in the training phase
        :param dataset: the dataset
        :param gaz_vocabulary_dir: directory where the gaz vocabulary is stored
        :param gaz_pretrained_word_embedding_loader: loader for the gaz pretrained word embeddings
        """

        super().__init__(is_training=is_training)

        # in principle the gazetteer should support persistence; that is not done here
        gazetteer = Gazetteer(gaz_pretrained_word_embedding_loader=
                              gaz_pretrained_word_embedding_loader)

        if is_training:
            gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetteer)

            data_loader = DataLoader(dataset=dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     num_workers=0,
                                     collate_fn=gaz_vocabulary_collate)

            gaz_words = list()

            for batch_gaz_words in data_loader:
                gaz_words.extend(batch_gaz_words)

            gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                        padding=Vocabulary.PADDING,
                                        unk=Vocabulary.UNK,
                                        special_first=True)

            gaz_vocabulary = PretrainedVocabulary(
                vocabulary=gaz_vocabulary,
                pretrained_word_embedding_loader=
                gaz_pretrained_word_embedding_loader)

            gaz_vocabulary.save_to_file(gaz_vocabulary_dir)
        else:

            gaz_vocabulary = Vocabulary.from_file(gaz_vocabulary_dir)

        self.gaz_vocabulary = gaz_vocabulary
        self.gazetteer = gazetteer
Example no. 12
    def build_vocabulary(self, dataset: Dataset):

        data_loader = DataLoader(dataset=dataset,
                                 batch_size=100,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=VocabularyCollate())
        batch_tokens = list()
        batch_sequence_labels = list()

        for collate_dict in data_loader:
            batch_tokens.extend(collate_dict["tokens"])
            batch_sequence_labels.extend(collate_dict["sequence_labels"])

        token_vocabulary = Vocabulary(tokens=batch_tokens,
                                      padding=Vocabulary.PADDING,
                                      unk=Vocabulary.UNK,
                                      special_first=True)
        model_name = self.config["model_name"]

        if model_name in {ModelName.NER_V2, ModelName.NER_V3}:
            pretrained_word_embedding_file_path = self.config[
                "pretrained_word_embedding_file_path"]
            glove_loader = GloveLoader(
                embedding_dim=100,
                pretrained_file_path=pretrained_word_embedding_file_path)

            token_vocabulary = PretrainedVocabulary(
                vocabulary=token_vocabulary,
                pretrained_word_embedding_loader=glove_loader)

        label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                           padding=LabelVocabulary.PADDING)

        return {
            "token_vocabulary": token_vocabulary,
            "label_vocabulary": label_vocabulary
        }
Example no. 13
def test_speical_last():
    batch_tokens = [["我", "和", "你"], ["在", "我"]]
    vocabulary = Vocabulary(batch_tokens,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=False,
                            other_special_tokens=["<Start>", "<End>"],
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 8)

    ASSERT.assertEqual(vocabulary.padding, vocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.unk, vocabulary.UNK)
    ASSERT.assertEqual(vocabulary.index(vocabulary.padding), 3 + 1)
    ASSERT.assertEqual(vocabulary.index(vocabulary.unk), 3 + 2)
    ASSERT.assertEqual(vocabulary.index("<Start>"), 3 + 3)
    ASSERT.assertEqual(vocabulary.index("<End>"), 3 + 4)
Example no. 14
def test_save_and_load():
    """
    Test saving and loading the vocabulary.
    :return:
    """
    batch_tokens = [["我", "和", "你"], ["在", "我"], ["newline\nnewline"]]
    vocabulary = Vocabulary(batch_tokens,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=True,
                            other_special_tokens=["<Start>", "<End>"],
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 9)

    ASSERT.assertEqual(vocabulary.padding, vocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.unk, vocabulary.UNK)
    ASSERT.assertEqual(vocabulary.index(vocabulary.padding), 0)
    ASSERT.assertEqual(vocabulary.index(vocabulary.unk), 1)
    ASSERT.assertEqual(vocabulary.index("<Start>"), 2)
    ASSERT.assertEqual(vocabulary.index("<End>"), 3)
    ASSERT.assertEqual(vocabulary.index("我"), 4)
    ASSERT.assertEqual(vocabulary.index("newline\nnewline"), 8)
    ASSERT.assertEqual(vocabulary.index("哈哈"),
                       vocabulary.index(vocabulary.unk))

    vocab_dir = os.path.join(ROOT_PATH, "data/easytext/tests")

    if not os.path.isdir(vocab_dir):
        os.makedirs(vocab_dir, exist_ok=True)

    vocabulary.save_to_file(vocab_dir)

    loaded_vocab = Vocabulary.from_file(directory=vocab_dir)

    ASSERT.assertEqual(loaded_vocab.size, 9)

    ASSERT.assertEqual(loaded_vocab.padding, vocabulary.PADDING)
    ASSERT.assertEqual(loaded_vocab.unk, vocabulary.UNK)
    ASSERT.assertEqual(loaded_vocab.index(vocabulary.padding), 0)
    ASSERT.assertEqual(loaded_vocab.index(vocabulary.unk), 1)
    ASSERT.assertEqual(loaded_vocab.index("<Start>"), 2)
    ASSERT.assertEqual(loaded_vocab.index("<End>"), 3)
    ASSERT.assertEqual(vocabulary.index("我"), 4)
    ASSERT.assertEqual(vocabulary.index("newline\nnewline"), 8)
    ASSERT.assertEqual(vocabulary.index("哈哈"),
                       vocabulary.index(vocabulary.unk))
Example no. 15
    def __init__(self,
                 is_training: bool,
                 dataset: Dataset,
                 vocabulary_collate,
                 token_vocabulary_dir: str,
                 label_vocabulary_dir: str,
                 is_build_token_vocabulary: bool,
                 pretrained_word_embedding_loader: PretrainedWordEmbeddingLoader):
        """
        词汇表构建器
        :param is_training: 因为在 train 和 非 train, 词汇表的构建行为有所不同;
                            如果是 train, 则一般需要重新构建; 而对于 非train, 使用先前构建好的即可。
        :param dataset: 数据集
        :param vocabulary_collate: 词汇表 collate
        :param token_vocabulary_dir: token vocabulary 存放目录
        :param label_vocabulary_dir: label vocabulary 存放目录
        :param is_build_token_vocabulary: 是否构建 token vocabulary, 因为在使用 Bert 或者 其他模型作为预训练的 embedding,
                                          则没有必要构建 token vocabulary.
        :param pretrained_word_embedding_loader: 预训练词汇表
        """

        super().__init__(is_training=is_training)

        token_vocabulary = None
        label_vocabulary = None

        if is_training:
            data_loader = DataLoader(dataset=dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     num_workers=0,
                                     collate_fn=vocabulary_collate)
            batch_tokens = list()
            batch_sequence_labels = list()

            for collate_dict in data_loader:
                batch_tokens.extend(collate_dict["tokens"])
                batch_sequence_labels.extend(collate_dict["sequence_labels"])

            if is_build_token_vocabulary:
                token_vocabulary = Vocabulary(tokens=batch_tokens,
                                              padding=Vocabulary.PADDING,
                                              unk=Vocabulary.UNK,
                                              special_first=True)

                if pretrained_word_embedding_loader is not None:
                    token_vocabulary = \
                        PretrainedVocabulary(vocabulary=token_vocabulary,
                                             pretrained_word_embedding_loader=pretrained_word_embedding_loader)

                if token_vocabulary_dir:
                    token_vocabulary.save_to_file(token_vocabulary_dir)

            label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                               padding=LabelVocabulary.PADDING)

            if label_vocabulary_dir:
                label_vocabulary.save_to_file(label_vocabulary_dir)
        else:
            if is_build_token_vocabulary and token_vocabulary_dir:
                token_vocabulary = Vocabulary.from_file(token_vocabulary_dir)

            if label_vocabulary_dir:
                label_vocabulary = LabelVocabulary.from_file(label_vocabulary_dir)

        self.token_vocabulary = token_vocabulary
        self.label_vocabulary = label_vocabulary
Example no. 16
def test_bilstm_gat_model_collate(lattice_ner_demo_dataset,
                                  gaz_pretrained_embedding_loader):
    """
    测试 bilstm gat model collate
    :return:
    """
    # 仅仅取前两个作为测试
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    bilstm_gat_model_collate = BiLstmGATModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)

    model_inputs = bilstm_gat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    t_graph_0 = model_inputs.model_inputs["t_graph"][0]
    c_graph_0 = model_inputs.model_inputs["c_graph"][0]
    l_graph_0 = model_inputs.model_inputs["l_graph"][0]

    expect_t_graph_tensor = torch.tensor(expect_t_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_t_graph_tensor, t_graph_0))

    expect_c_graph_tensor = torch.tensor(expect_c_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_c_graph_tensor, c_graph_0))

    expect_l_graph_tensor = torch.tensor(expect_l_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_l_graph_tensor, l_graph_0))

    gaz_words_indices = model_inputs.model_inputs["gaz_words"]

    ASSERT.assertEqual((2, 11), gaz_words_indices.size())

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # 陈元呼吁加强国际合作推动世界经济发展
    expect_squeeze_gaz_words_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]

    sequeeze_gaz_words_0 = metadata_0["sequeeze_gaz_words"]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, sequeeze_gaz_words_0)

    expect_squeeze_gaz_words_indices_0 = torch.tensor(
        [gaz_vocabulary.index(word) for word in expect_squeeze_gaz_words_0],
        dtype=torch.long)

    ASSERT.assertTrue(
        is_tensor_equal(expect_squeeze_gaz_words_indices_0,
                        gaz_words_indices[0]))
Example no. 17
def test_flat_model_collate(lattice_ner_demo_dataset,
                            character_pretrained_embedding_loader,
                            gaz_pretrained_embedding_loader):
    """
    测试 flat model collate
    :return:
    """
    # 仅仅取前两个作为测试
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    characters = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    character_vocabulary = Vocabulary(tokens=characters,
                                      padding=Vocabulary.PADDING,
                                      unk=Vocabulary.UNK,
                                      special_first=True)
    character_vocabulary = PretrainedVocabulary(
        vocabulary=character_vocabulary,
        pretrained_word_embedding_loader=character_pretrained_embedding_loader)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    flat_vocabulary = FlatPretrainedVocabulary(
        character_pretrained_vocabulary=character_vocabulary,
        gaz_word_pretrained_vocabulary=gaz_vocabulary)

    flat_model_collate = FLATModelCollate(token_vocabulary=flat_vocabulary,
                                          gazetter=gazetter,
                                          label_vocabulary=label_vocabulary)

    model_inputs = flat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    sentence = "陈元呼吁加强国际合作推动世界经济发展"

    # 陈元呼吁加强国际合作推动世界经济发展
    expect_squeeze_gaz_words_0 = [
        "陈元", "呼吁", "吁加", "加强", "强国", "国际", "合作", "推动", "世界", "经济", "发展"
    ]

    squeeze_gaz_words_0 = metadata_0["squeeze_gaz_words"]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, squeeze_gaz_words_0)

    expect_tokens = [character
                     for character in sentence] + expect_squeeze_gaz_words_0

    tokens = metadata_0["tokens"]

    ASSERT.assertListEqual(expect_tokens, tokens)

    character_pos_begin = [index for index in range(len(sentence))]
    character_pos_end = [index for index in range(len(sentence))]

    squeeze_gaz_words_begin = list()
    squeeze_gaz_words_end = list()

    for squeeze_gaz_word in squeeze_gaz_words_0:
        index = sentence.find(squeeze_gaz_word)

        squeeze_gaz_words_begin.append(index)
        squeeze_gaz_words_end.append(index + len(squeeze_gaz_word) - 1)

    pos_begin = model_inputs.model_inputs["pos_begin"][0]
    pos_end = model_inputs.model_inputs["pos_end"][0]

    expect_pos_begin = character_pos_begin + squeeze_gaz_words_begin
    expect_pos_begin += [0] * (pos_begin.size(0) - len(expect_pos_begin))
    expect_pos_begin = torch.tensor(expect_pos_begin)

    expect_pos_end = character_pos_end + squeeze_gaz_words_end
    expect_pos_end += [0] * (pos_end.size(0) - len(expect_pos_end))
    expect_pos_end = torch.tensor(expect_pos_end)

    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_begin, pos_begin))
    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_end, pos_end))

    expect_character_length = len(sentence)
    expect_squeeze_gaz_word_length = len(expect_squeeze_gaz_words_0)

    character_length = model_inputs.model_inputs["sequence_length"][0]
    squeeze_word_length = model_inputs.model_inputs["squeeze_gaz_word_length"][
        0]

    ASSERT.assertEqual(expect_character_length, character_length.item())
    ASSERT.assertEqual(expect_squeeze_gaz_word_length,
                       squeeze_word_length.item())
Example no. 18
    def __call__(self, config: Dict, train_type: int):
        serialize_dir = config["serialize_dir"]
        vocabulary_dir = config["vocabulary_dir"]
        pretrained_embedding_file_path = config["pretrained_embedding_file_path"]
        word_embedding_dim = config["word_embedding_dim"]
        pretrained_embedding_max_size = config["pretrained_embedding_max_size"]
        is_fine_tuning = config["fine_tuning"]

        word_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "word_vocabulary")
        event_type_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "event_type_vocabulary")
        entity_tag_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "entity_tag_vocabulary")

        if train_type == Train.NEW_TRAIN:

            if os.path.isdir(serialize_dir):
                shutil.rmtree(serialize_dir)

            os.makedirs(serialize_dir)

            if os.path.isdir(vocabulary_dir):
                shutil.rmtree(vocabulary_dir)
            os.makedirs(vocabulary_dir)
            os.makedirs(word_vocab_dir)
            os.makedirs(event_type_vocab_dir)
            os.makedirs(entity_tag_vocab_dir)

        elif train_type == Train.RECOVERY_TRAIN:
            pass
        else:
            assert False, f"train_type: {train_type} error!"

        train_dataset_file_path = config["train_dataset_file_path"]
        validation_dataset_file_path = config["validation_dataset_file_path"]

        num_epoch = config["epoch"]
        batch_size = config["batch_size"]

        if train_type == Train.NEW_TRAIN:
            # build the vocabularies
            ace_dataset = ACEDataset(train_dataset_file_path)
            vocab_data_loader = DataLoader(dataset=ace_dataset,
                                           batch_size=10,
                                           shuffle=False, num_workers=0,
                                           collate_fn=EventVocabularyCollate())

            tokens: List[List[str]] = list()
            event_types: List[List[str]] = list()
            entity_tags: List[List[str]] = list()

            for collate_dict in vocab_data_loader:
                tokens.extend(collate_dict["tokens"])
                event_types.extend(collate_dict["event_types"])
                entity_tags.extend(collate_dict["entity_tags"])

            word_vocabulary = Vocabulary(tokens=tokens,
                                         padding=Vocabulary.PADDING,
                                         unk=Vocabulary.UNK,
                                         special_first=True)

            glove_loader = GloveLoader(embedding_dim=word_embedding_dim,
                                       pretrained_file_path=pretrained_embedding_file_path,
                                       max_size=pretrained_embedding_max_size)

            pretrained_word_vocabulary = PretrainedVocabulary(vocabulary=word_vocabulary,
                                                              pretrained_word_embedding_loader=glove_loader)

            pretrained_word_vocabulary.save_to_file(word_vocab_dir)

            event_type_vocabulary = Vocabulary(tokens=event_types,
                                               padding="",
                                               unk="Negative",
                                               special_first=True)
            event_type_vocabulary.save_to_file(event_type_vocab_dir)

            entity_tag_vocabulary = LabelVocabulary(labels=entity_tags,
                                                    padding=LabelVocabulary.PADDING)
            entity_tag_vocabulary.save_to_file(entity_tag_vocab_dir)
        else:
            pretrained_word_vocabulary = PretrainedVocabulary.from_file(word_vocab_dir)
            event_type_vocabulary = Vocabulary.from_file(event_type_vocab_dir)
            entity_tag_vocabulary = LabelVocabulary.from_file(entity_tag_vocab_dir)

        model = EventModel(alpha=0.5,
                           activate_score=True,
                           sentence_vocab=pretrained_word_vocabulary,
                           sentence_embedding_dim=word_embedding_dim,
                           entity_tag_vocab=entity_tag_vocabulary,
                           entity_tag_embedding_dim=50,
                           event_type_vocab=event_type_vocabulary,
                           event_type_embedding_dim=300,
                           lstm_hidden_size=300,
                           lstm_encoder_num_layer=1,
                           lstm_encoder_droupout=0.4)

        trainer = Trainer(
            serialize_dir=serialize_dir,
            num_epoch=num_epoch,
            model=model,
            loss=EventLoss(),
            optimizer_factory=EventOptimizerFactory(is_fine_tuning=is_fine_tuning),
            metrics=EventF1MetricAdapter(event_type_vocabulary=event_type_vocabulary),
            patient=10,
            num_check_point_keep=5,
            devices=None
        )

        train_dataset = EventDataset(dataset_file_path=train_dataset_file_path,
                                     event_type_vocabulary=event_type_vocabulary)
        validation_dataset = EventDataset(dataset_file_path=validation_dataset_file_path,
                                          event_type_vocabulary=event_type_vocabulary)

        event_collate = EventCollate(word_vocabulary=pretrained_word_vocabulary,
                                     event_type_vocabulary=event_type_vocabulary,
                                     entity_tag_vocabulary=entity_tag_vocabulary,
                                     sentence_max_len=512)
        train_data_loader = DataLoader(dataset=train_dataset,
                                       batch_size=batch_size,
                                       num_workers=0,
                                       collate_fn=event_collate)

        validation_data_loader = DataLoader(dataset=validation_dataset,
                                            batch_size=batch_size,
                                            num_workers=0,
                                            collate_fn=event_collate)

        if train_type == Train.NEW_TRAIN:
            trainer.train(train_data_loader=train_data_loader,
                          validation_data_loader=validation_data_loader)
        else:
            trainer.recovery_train(train_data_loader=train_data_loader,
                                   validation_data_loader=validation_data_loader)
Example no. 19
def test_gaz_model_collate(lattice_ner_demo_dataset,
                           gaz_pretrained_embedding_loader):
    # use only the first two instances for the test
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()

    collate_result = vocabulary_collate(batch_instances)

    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)

    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)

    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    lattice_model_collate = LatticeModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)

    model_inputs = lattice_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # 陈元呼吁加强国际合作推动世界经济发展
    expect_gaz_words_0 = [["陈元"], [], ["呼吁"], ["吁加"], ["加强"], ["强国"], ["国际"],
                          [], ["合作"], [], ["推动"], [], ["世界"], [], ["经济"], [],
                          ["发展"], []]

    gaz_words_0 = metadata_0["gaz_words"]

    ASSERT.assertListEqual(expect_gaz_words_0, gaz_words_0)

    gaz_list_0 = model_inputs.model_inputs["gaz_list"][0]

    expect_gaz_list_0 = list()

    for expect_gaz_word in expect_gaz_words_0:

        if len(expect_gaz_word) > 0:
            indices = [gaz_vocabulary.index(word) for word in expect_gaz_word]
            lengths = [len(word) for word in expect_gaz_word]

            expect_gaz_list_0.append([indices, lengths])

        else:
            expect_gaz_list_0.append([])

    logging.debug(
        f"expect_gaz_list_0: {json2str(expect_gaz_list_0)}\n gaz_list_0:{json2str(gaz_list_0)}"
    )
    ASSERT.assertListEqual(expect_gaz_list_0, gaz_list_0)

    tokens_0 = model_inputs.model_inputs["tokens"]
    ASSERT.assertEqual((2, 19), tokens_0.size())
    sequence_label_0 = model_inputs.labels
    ASSERT.assertEqual((2, 19), sequence_label_0.size())

    # 新华社华盛顿4月28日电(记者翟景升)
    expect_gaz_word_1 = [
        ["新华社", "新华"],  # 新
        ["华社"],  # 华
        ["社华"],  # 社
        ["华盛顿", "华盛"],  # 华
        ["盛顿"],  # 盛
        [],  # 顿
        [],  # 4
        [],  # 月
        [],  # 2
        [],  # 8
        [],  # 日
        [],  # 电
        [],  # (
        ["记者"],  # 记
        [],  # 者
        ["翟景升", "翟景"],  # 翟
        ["景升"],  # 景
        [],  # 升
        []
    ]  # )

    metadata_1 = model_inputs.model_inputs["metadata"][1]
    gaz_words_1 = metadata_1["gaz_words"]

    ASSERT.assertListEqual(expect_gaz_word_1, gaz_words_1)

    expect_gaz_list_1 = list()

    for expect_gaz_word in expect_gaz_word_1:

        if len(expect_gaz_word) > 0:
            indices = [gaz_vocabulary.index(word) for word in expect_gaz_word]
            lengths = [len(word) for word in expect_gaz_word]

            expect_gaz_list_1.append([indices, lengths])

        else:
            expect_gaz_list_1.append([])

    gaz_list_1 = model_inputs.model_inputs["gaz_list"][1]
    ASSERT.assertListEqual(expect_gaz_list_1, gaz_list_1)