def vocabulary(ace_dataset):
    vocab_collate_fn = EventVocabularyCollate()
    data_loader = DataLoader(ace_dataset, collate_fn=vocab_collate_fn)

    event_types_list: List[List[str]] = list()
    tokens_list: List[List[str]] = list()
    entity_tags_list: List[List[str]] = list()

    for collate_dict in data_loader:
        event_types_list.extend(collate_dict["event_types"])
        tokens_list.extend(collate_dict["tokens"])
        entity_tags_list.extend(collate_dict["entity_tags"])

    negative_event_type = "Negative"
    event_type_vocab = Vocabulary(tokens=event_types_list,
                                  unk=negative_event_type,
                                  padding="",
                                  special_first=True)
    word_vocab = Vocabulary(tokens=tokens_list,
                            unk=Vocabulary.UNK,
                            padding=Vocabulary.PADDING,
                            special_first=True)
    entity_tag_vocab = LabelVocabulary(entity_tags_list,
                                       padding=LabelVocabulary.PADDING)

    return {"event_type_vocab": event_type_vocab,
            "word_vocab": word_vocab,
            "entity_tag_vocab": entity_tag_vocab}
def event_type_vocabulary():
    event_types = [["A", "B", "C"], ["A", "B"], ["A"]]
    vocabulary = Vocabulary(tokens=event_types,
                            padding="",
                            unk="Negative",
                            special_first=True)

    ASSERT.assertEqual(4, vocabulary.size)
    ASSERT.assertEqual(0, vocabulary.index(vocabulary.unk))
    ASSERT.assertEqual(1, vocabulary.index("A"))
    ASSERT.assertEqual(2, vocabulary.index("B"))
    ASSERT.assertEqual(3, vocabulary.index("C"))
    return vocabulary
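# A minimal additional sketch (not from the original test suite) showing the unk fallback
# of the fixture above: any event type that was never seen during construction is mapped
# to the "Negative" index. The event type "D" is a hypothetical unseen label.
def test_event_type_vocabulary_unk_fallback(event_type_vocabulary):
    ASSERT.assertEqual(event_type_vocabulary.index("D"),
                       event_type_vocabulary.index(event_type_vocabulary.unk))
    ASSERT.assertEqual(0, event_type_vocabulary.index("D"))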
def vocabulary(conll2003_dataset) -> Dict[str, Union[Vocabulary, PretrainedVocabulary]]:
    data_loader = DataLoader(dataset=conll2003_dataset,
                             batch_size=2,
                             shuffle=False,
                             num_workers=0,
                             collate_fn=VocabularyCollate())

    batch_tokens = list()
    batch_sequence_labels = list()

    for collate_dict in data_loader:
        batch_tokens.extend(collate_dict["tokens"])
        batch_sequence_labels.extend(collate_dict["sequence_labels"])

    token_vocabulary = Vocabulary(tokens=batch_tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)
    label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                       padding=LabelVocabulary.PADDING)

    return {"token_vocabulary": token_vocabulary,
            "label_vocabulary": label_vocabulary}
def __init__(self,
             character_pretrained_vocabulary: PretrainedVocabulary,
             gaz_word_pretrained_vocabulary: PretrainedVocabulary):
    """
    Initialize by merging the character vocabulary and the gaz-word vocabulary
    into a single pretrained vocabulary.
    :param character_pretrained_vocabulary: pretrained character vocabulary
    :param gaz_word_pretrained_vocabulary: pretrained gaz-word vocabulary
    """
    assert character_pretrained_vocabulary.embedding_dim == gaz_word_pretrained_vocabulary.embedding_dim, \
        "character_pretrained_vocabulary and gaz_word_pretrained_vocabulary must have the same embedding dimension"

    char_embedding_dict = self.__token_embedding_dict(character_pretrained_vocabulary)
    gaz_word_embedding_dict = self.__token_embedding_dict(gaz_word_pretrained_vocabulary)

    tokens = [char_embedding_dict.keys(), gaz_word_embedding_dict.keys()]

    char_embedding_dict.update(gaz_word_embedding_dict)
    embedding_dict = char_embedding_dict

    vocabulary = Vocabulary(tokens=tokens,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=True)

    super().__init__(vocabulary=vocabulary, pretrained_word_embedding_loader=None)

    self._embedding_dim = character_pretrained_vocabulary.embedding_dim
    self._init_embedding_matrix(vocabulary=self._vocabulary,
                                embedding_dict=embedding_dict,
                                embedding_dim=self._embedding_dim)
def test_vocabulary_special_first():
    """
    Test vocabulary with special_first=True.
    :return:
    """
    batch_tokens = [["我", "和", "你"], ["在", "我"]]
    vocabulary = Vocabulary(batch_tokens,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=True,
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 6)
    ASSERT.assertEqual(vocabulary.padding, vocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.unk, vocabulary.UNK)
    ASSERT.assertEqual(vocabulary.index(vocabulary.padding), 0)
    ASSERT.assertEqual(vocabulary.index(vocabulary.unk), 1)
def test_vocabulary():
    """
    Test vocabulary when padding and unk are disabled (empty strings).
    :return:
    """
    batch_tokens = [["我", "和", "你"], ["在", "我"]]
    vocabulary = Vocabulary(batch_tokens,
                            padding="",
                            unk="",
                            special_first=True,
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 4)
    ASSERT.assertTrue(not vocabulary.padding)
    ASSERT.assertTrue(not vocabulary.unk)
    ASSERT.assertEqual(vocabulary.index("我"), 0)
    ASSERT.assertEqual(vocabulary.index("和"), 1)
def __init__(self, dataset_file_path: str, event_type_vocabulary: Vocabulary):
    """
    Initialize the ACE Event Dataset.
    :param dataset_file_path: file path of the dataset
    :param event_type_vocabulary: event type vocabulary
    """
    super().__init__()
    self._ace_dataset = ACEDataset(dataset_file_path=dataset_file_path)
    self._instances: List[Instance] = list()

    for ori_instance in self._ace_dataset:
        ori_event_types = ori_instance["event_types"]

        ori_event_type_set = None

        if ori_event_types is not None:
            # At actual prediction time ori_event_types is None.
            # This branch is for training and validation, because for the pair <sentence, unk>, label = 1.
            ori_event_type_set = set(ori_event_types)

            if len(ori_event_type_set) == 0:
                ori_event_type_set.add(event_type_vocabulary.unk)

        for index in range(event_type_vocabulary.size):
            # Iterate over all labels and build <sentence, event_type> pairs as samples.
            event_type = event_type_vocabulary.token(index)

            instance = Instance()
            instance["sentence"] = ori_instance["sentence"]
            instance["entity_tag"] = ori_instance["entity_tag"]
            instance["event_type"] = event_type
            instance["metadata"] = ori_instance["metadata"]

            if ori_event_type_set is not None:
                if event_type in ori_event_type_set:
                    instance["label"] = 1
                else:
                    instance["label"] = 0
            else:
                # For actual prediction no label is set.
                pass

            self._instances.append(instance)
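# An illustrative sketch (plain Python, not part of the dataset class above) of the
# <sentence, event_type> expansion it performs. The event type inventory, the sentences,
# and the expand() helper are all hypothetical; only the pairing/labeling logic mirrors
# the constructor: every sentence is paired with every event type, annotated types get
# label 1, and a sentence without annotations is paired positively with "Negative".
event_types = ["Negative", "Attack", "Die", "Meet"]  # assumed vocabulary content, unk = "Negative"

def expand(sentence, ori_event_types):
    positive = set(ori_event_types) or {"Negative"}  # empty annotations fall back to unk
    return [{"sentence": sentence,
             "event_type": event_type,
             "label": 1 if event_type in positive else 0}
            for event_type in event_types]

expand("Troops attacked the city.", ["Attack"])  # label 1 only for the "Attack" pair
expand("Nothing happened today.", [])            # label 1 only for the "Negative" pair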
def build_vocabulary(self):
    training_dataset_file_path = self.config["training_dataset_file_path"]
    dataset = ACSASemEvalDataset(dataset_file_path=training_dataset_file_path)

    collate_fn = VocabularyCollate()
    data_loader = DataLoader(dataset=dataset,
                             batch_size=10,
                             shuffle=False,
                             num_workers=0,
                             collate_fn=collate_fn)

    tokens = list()
    categories = list()
    labels = list()

    for collate_dict in data_loader:
        tokens.extend(collate_dict["tokens"])
        categories.extend(collate_dict["categories"])
        labels.extend(collate_dict["labels"])

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    if not self.config["debug"]:
        pretrained_file_path = self.config["pretrained_file_path"]
        pretrained_loader = GloveLoader(embedding_dim=300,
                                        pretrained_file_path=pretrained_file_path)
        pretrained_token_vocabulary = PretrainedVocabulary(
            vocabulary=token_vocabulary,
            pretrained_word_embedding_loader=pretrained_loader)
        token_vocabulary = pretrained_token_vocabulary

    category_vocabulary = LabelVocabulary(labels=categories, padding=None)
    label_vocabulary = LabelVocabulary(labels=labels, padding=None)

    return {"token_vocabulary": token_vocabulary,
            "category_vocabulary": category_vocabulary,
            "label_vocabulary": label_vocabulary}
def __init__(self, event_type_vocabulary: Vocabulary):
    """
    Initialize.
    :param event_type_vocabulary: event type vocabulary
    """
    super().__init__()
    self._event_type_f1: Dict[str, LabelF1Metric] = dict()

    for index in range(0, event_type_vocabulary.size):
        event_type = event_type_vocabulary.token(index)

        if event_type != event_type_vocabulary.unk:
            self._event_type_f1[event_type] = LabelF1Metric(labels=[1],
                                                            label_vocabulary=None)

    self._event_type_f1[EventF1MetricAdapter.__OVERALL] = LabelF1Metric(labels=[1],
                                                                        label_vocabulary=None)
    self._event_type_vocabulary = event_type_vocabulary
def pretrained_vocabulary():
    """
    Build a pretrained vocabulary for testing.
    """
    pretrained_file_path = "data/easytext/tests/pretrained/word_embedding_sample.3d.txt"
    pretrained_file_path = os.path.join(ROOT_PATH, pretrained_file_path)
    glove_loader = GloveLoader(embedding_dim=3,
                               pretrained_file_path=pretrained_file_path)

    tokens = [["我"], ["美丽"]]
    vocab = Vocabulary(tokens=tokens,
                       padding=Vocabulary.PADDING,
                       unk=Vocabulary.UNK,
                       special_first=True)

    pretrained_vocab = PretrainedVocabulary(
        vocabulary=vocab,
        pretrained_word_embedding_loader=glove_loader)
    return pretrained_vocab
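# A minimal usage sketch for the fixture above (not from the original tests). It assumes
# the function is registered as a pytest fixture and that PretrainedVocabulary delegates
# index() to the wrapped Vocabulary; the out-of-vocabulary token "不存在" is hypothetical.
def test_pretrained_vocabulary_basic(pretrained_vocabulary):
    # The embedding dimension comes from the 3-dimensional sample embedding file.
    ASSERT.assertEqual(3, pretrained_vocabulary.embedding_dim)
    # An unseen token falls back to the index of the UNK token.
    ASSERT.assertEqual(pretrained_vocabulary.index(Vocabulary.UNK),
                       pretrained_vocabulary.index("不存在"))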
def __init__(self,
             is_training: bool,
             dataset: Dataset,
             gaz_vocabulary_dir: str,
             gaz_pretrained_word_embedding_loader: PretrainedWordEmbeddingLoader):
    """
    Build the gaz vocabulary.
    :param is_training: whether we are currently in the training phase
    :param dataset: dataset
    :param gaz_vocabulary_dir: directory in which the gaz vocabulary is stored
    :param gaz_pretrained_word_embedding_loader: loader for the pretrained gaz word embeddings
    """
    super().__init__(is_training=is_training)

    # In theory the gazetteer should also be persisted; that is not done here.
    gazetteer = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_word_embedding_loader)

    if is_training:
        gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetteer)

        data_loader = DataLoader(dataset=dataset,
                                 batch_size=100,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=gaz_vocabulary_collate)

        gaz_words = list()

        for batch_gaz_words in data_loader:
            gaz_words.extend(batch_gaz_words)

        gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                    padding=Vocabulary.PADDING,
                                    unk=Vocabulary.UNK,
                                    special_first=True)
        gaz_vocabulary = PretrainedVocabulary(
            vocabulary=gaz_vocabulary,
            pretrained_word_embedding_loader=gaz_pretrained_word_embedding_loader)

        gaz_vocabulary.save_to_file(gaz_vocabulary_dir)
    else:
        gaz_vocabulary = Vocabulary.from_file(gaz_vocabulary_dir)

    self.gaz_vocabulary = gaz_vocabulary
    self.gazetteer = gazetteer
def build_vocabulary(self, dataset: Dataset):
    data_loader = DataLoader(dataset=dataset,
                             batch_size=100,
                             shuffle=False,
                             num_workers=0,
                             collate_fn=VocabularyCollate())
    batch_tokens = list()
    batch_sequence_labels = list()

    for collate_dict in data_loader:
        batch_tokens.extend(collate_dict["tokens"])
        batch_sequence_labels.extend(collate_dict["sequence_labels"])

    token_vocabulary = Vocabulary(tokens=batch_tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)

    model_name = self.config["model_name"]

    if model_name in {ModelName.NER_V2, ModelName.NER_V3}:
        pretrained_word_embedding_file_path = self.config["pretrained_word_embedding_file_path"]
        glove_loader = GloveLoader(
            embedding_dim=100,
            pretrained_file_path=pretrained_word_embedding_file_path)
        token_vocabulary = PretrainedVocabulary(
            vocabulary=token_vocabulary,
            pretrained_word_embedding_loader=glove_loader)

    label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                       padding=LabelVocabulary.PADDING)

    return {"token_vocabulary": token_vocabulary,
            "label_vocabulary": label_vocabulary}
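# A short downstream sketch (not part of build_vocabulary) of how the returned
# vocabularies are typically consumed: tokens and BIO labels are mapped to indices
# before batching. The sentence and tag sequence are hypothetical CoNLL-style data,
# and token_vocabulary / label_vocabulary refer to the objects returned above.
tokens = ["EU", "rejects", "German", "call"]
sequence_labels = ["B-ORG", "O", "B-MISC", "O"]

token_indices = [token_vocabulary.index(token) for token in tokens]
label_indices = [label_vocabulary.index(label) for label in sequence_labels]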
def test_special_last():
    batch_tokens = [["我", "和", "你"], ["在", "我"]]
    vocabulary = Vocabulary(batch_tokens,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=False,
                            other_special_tokens=["<Start>", "<End>"],
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 8)
    ASSERT.assertEqual(vocabulary.padding, vocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.unk, vocabulary.UNK)

    # With special_first=False the 4 distinct tokens come first, then the special tokens.
    ASSERT.assertEqual(vocabulary.index(vocabulary.padding), 3 + 1)
    ASSERT.assertEqual(vocabulary.index(vocabulary.unk), 3 + 2)
    ASSERT.assertEqual(vocabulary.index("<Start>"), 3 + 3)
    ASSERT.assertEqual(vocabulary.index("<End>"), 3 + 4)
def test_save_and_load():
    """
    Test saving and loading a vocabulary.
    :return:
    """
    batch_tokens = [["我", "和", "你"], ["在", "我"], ["newline\nnewline"]]
    vocabulary = Vocabulary(batch_tokens,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=True,
                            other_special_tokens=["<Start>", "<End>"],
                            min_frequency=1,
                            max_size=None)

    ASSERT.assertEqual(vocabulary.size, 9)
    ASSERT.assertEqual(vocabulary.padding, vocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.unk, vocabulary.UNK)
    ASSERT.assertEqual(vocabulary.index(vocabulary.padding), 0)
    ASSERT.assertEqual(vocabulary.index(vocabulary.unk), 1)
    ASSERT.assertEqual(vocabulary.index("<Start>"), 2)
    ASSERT.assertEqual(vocabulary.index("<End>"), 3)
    ASSERT.assertEqual(vocabulary.index("我"), 4)
    ASSERT.assertEqual(vocabulary.index("newline\nnewline"), 8)
    ASSERT.assertEqual(vocabulary.index("哈哈"), vocabulary.index(vocabulary.unk))

    vocab_dir = os.path.join(ROOT_PATH, "data/easytext/tests")

    if not os.path.isdir(vocab_dir):
        os.makedirs(vocab_dir, exist_ok=True)

    vocabulary.save_to_file(vocab_dir)

    loaded_vocab = Vocabulary.from_file(directory=vocab_dir)

    # The loaded vocabulary must match the vocabulary that was saved.
    ASSERT.assertEqual(loaded_vocab.size, 9)
    ASSERT.assertEqual(loaded_vocab.padding, vocabulary.PADDING)
    ASSERT.assertEqual(loaded_vocab.unk, vocabulary.UNK)
    ASSERT.assertEqual(loaded_vocab.index(vocabulary.padding), 0)
    ASSERT.assertEqual(loaded_vocab.index(vocabulary.unk), 1)
    ASSERT.assertEqual(loaded_vocab.index("<Start>"), 2)
    ASSERT.assertEqual(loaded_vocab.index("<End>"), 3)
    ASSERT.assertEqual(loaded_vocab.index("我"), 4)
    ASSERT.assertEqual(loaded_vocab.index("newline\nnewline"), 8)
    ASSERT.assertEqual(loaded_vocab.index("哈哈"), loaded_vocab.index(loaded_vocab.unk))
def __init__(self,
             is_training: bool,
             dataset: Dataset,
             vocabulary_collate,
             token_vocabulary_dir: str,
             label_vocabulary_dir: str,
             is_build_token_vocabulary: bool,
             pretrained_word_embedding_loader: PretrainedWordEmbeddingLoader):
    """
    Vocabulary builder.
    :param is_training: vocabulary construction differs between training and non-training;
                        during training the vocabulary usually has to be rebuilt, while
                        outside training the previously built vocabulary is reused.
    :param dataset: dataset
    :param vocabulary_collate: vocabulary collate function
    :param token_vocabulary_dir: directory in which the token vocabulary is stored
    :param label_vocabulary_dir: directory in which the label vocabulary is stored
    :param is_build_token_vocabulary: whether to build the token vocabulary; when Bert or
                                      another pretrained model provides the embeddings,
                                      there is no need to build a token vocabulary.
    :param pretrained_word_embedding_loader: pretrained word embedding loader
    """
    super().__init__(is_training=is_training)

    token_vocabulary = None
    label_vocabulary = None

    if is_training:
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=100,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=vocabulary_collate)
        batch_tokens = list()
        batch_sequence_labels = list()

        for collate_dict in data_loader:
            batch_tokens.extend(collate_dict["tokens"])
            batch_sequence_labels.extend(collate_dict["sequence_labels"])

        if is_build_token_vocabulary:
            token_vocabulary = Vocabulary(tokens=batch_tokens,
                                          padding=Vocabulary.PADDING,
                                          unk=Vocabulary.UNK,
                                          special_first=True)

            if pretrained_word_embedding_loader is not None:
                token_vocabulary = PretrainedVocabulary(
                    vocabulary=token_vocabulary,
                    pretrained_word_embedding_loader=pretrained_word_embedding_loader)

            if token_vocabulary_dir:
                token_vocabulary.save_to_file(token_vocabulary_dir)

        label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                           padding=LabelVocabulary.PADDING)

        if label_vocabulary_dir:
            label_vocabulary.save_to_file(label_vocabulary_dir)
    else:
        if is_build_token_vocabulary and token_vocabulary_dir:
            token_vocabulary = Vocabulary.from_file(token_vocabulary_dir)

        if label_vocabulary_dir:
            label_vocabulary = LabelVocabulary.from_file(label_vocabulary_dir)

    self.token_vocabulary = token_vocabulary
    self.label_vocabulary = label_vocabulary
def test_bilstm_gat_model_collate(lattice_ner_demo_dataset, gaz_pretrained_embedding_loader):
    """
    Test the BiLSTM GAT model collate.
    :return:
    """
    # Only take the first two instances for testing.
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()
    collate_result = vocabulary_collate(batch_instances)

    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)
    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)
    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)
    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    bilstm_gat_model_collate = BiLstmGATModelCollate(
        token_vocabulary=token_vocabulary,
        gazetter=gazetter,
        gaz_vocabulary=gaz_vocabulary,
        label_vocabulary=label_vocabulary)

    model_inputs = bilstm_gat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    t_graph_0 = model_inputs.model_inputs["t_graph"][0]
    c_graph_0 = model_inputs.model_inputs["c_graph"][0]
    l_graph_0 = model_inputs.model_inputs["l_graph"][0]

    expect_t_graph_tensor = torch.tensor(expect_t_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_t_graph_tensor, t_graph_0))

    expect_c_graph_tensor = torch.tensor(expect_c_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_c_graph_tensor, c_graph_0))

    expect_l_graph_tensor = torch.tensor(expect_l_graph, dtype=torch.uint8)
    ASSERT.assertTrue(is_tensor_equal(expect_l_graph_tensor, l_graph_0))

    gaz_words_indices = model_inputs.model_inputs["gaz_words"]
    ASSERT.assertEqual((2, 11), gaz_words_indices.size())

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # Sentence: 陈元呼吁加强国际合作推动世界经济发展
    expect_squeeze_gaz_words_0 = ["陈元", "呼吁", "吁加", "加强", "强国", "国际",
                                  "合作", "推动", "世界", "经济", "发展"]

    sequeeze_gaz_words_0 = metadata_0["sequeeze_gaz_words"]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, sequeeze_gaz_words_0)

    expect_squeeze_gaz_words_indices_0 = torch.tensor(
        [gaz_vocabulary.index(word) for word in expect_squeeze_gaz_words_0],
        dtype=torch.long)
    ASSERT.assertTrue(
        is_tensor_equal(expect_squeeze_gaz_words_indices_0, gaz_words_indices[0]))
def test_flat_model_collate(lattice_ner_demo_dataset,
                            character_pretrained_embedding_loader,
                            gaz_pretrained_embedding_loader):
    """
    Test the FLAT model collate.
    :return:
    """
    # Only take the first two instances for testing.
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()
    collate_result = vocabulary_collate(batch_instances)

    characters = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    character_vocabulary = Vocabulary(tokens=characters,
                                      padding=Vocabulary.PADDING,
                                      unk=Vocabulary.UNK,
                                      special_first=True)
    character_vocabulary = PretrainedVocabulary(
        vocabulary=character_vocabulary,
        pretrained_word_embedding_loader=character_pretrained_embedding_loader)

    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)
    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)
    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    flat_vocabulary = FlatPretrainedVocabulary(
        character_pretrained_vocabulary=character_vocabulary,
        gaz_word_pretrained_vocabulary=gaz_vocabulary)

    flat_model_collate = FLATModelCollate(token_vocabulary=flat_vocabulary,
                                          gazetter=gazetter,
                                          label_vocabulary=label_vocabulary)

    model_inputs = flat_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    sentence = "陈元呼吁加强国际合作推动世界经济发展"

    expect_squeeze_gaz_words_0 = ["陈元", "呼吁", "吁加", "加强", "强国", "国际",
                                  "合作", "推动", "世界", "经济", "发展"]

    squeeze_gaz_words_0 = metadata_0["squeeze_gaz_words"]

    ASSERT.assertListEqual(expect_squeeze_gaz_words_0, squeeze_gaz_words_0)

    expect_tokens = [character for character in sentence] + expect_squeeze_gaz_words_0
    tokens = metadata_0["tokens"]

    ASSERT.assertListEqual(expect_tokens, tokens)

    character_pos_begin = [index for index in range(len(sentence))]
    character_pos_end = [index for index in range(len(sentence))]

    squeeze_gaz_words_begin = list()
    squeeze_gaz_words_end = list()

    for squeeze_gaz_word in squeeze_gaz_words_0:
        index = sentence.find(squeeze_gaz_word)
        squeeze_gaz_words_begin.append(index)
        squeeze_gaz_words_end.append(index + len(squeeze_gaz_word) - 1)

    pos_begin = model_inputs.model_inputs["pos_begin"][0]
    pos_end = model_inputs.model_inputs["pos_end"][0]

    expect_pos_begin = character_pos_begin + squeeze_gaz_words_begin
    expect_pos_begin += [0] * (pos_begin.size(0) - len(expect_pos_begin))
    expect_pos_begin = torch.tensor(expect_pos_begin)

    expect_pos_end = character_pos_end + squeeze_gaz_words_end
    expect_pos_end += [0] * (pos_end.size(0) - len(expect_pos_end))
    expect_pos_end = torch.tensor(expect_pos_end)

    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_begin, pos_begin))
    ASSERT.assertTrue(tensor_util.is_tensor_equal(expect_pos_end, pos_end))

    expect_character_length = len(sentence)
    expect_squeeze_gaz_word_length = len(expect_squeeze_gaz_words_0)

    character_length = model_inputs.model_inputs["sequence_length"][0]
    squeeze_word_length = model_inputs.model_inputs["squeeze_gaz_word_length"][0]

    ASSERT.assertEqual(expect_character_length, character_length.item())
    ASSERT.assertEqual(expect_squeeze_gaz_word_length, squeeze_word_length.item())
def __call__(self, config: Dict, train_type: int):
    serialize_dir = config["serialize_dir"]
    vocabulary_dir = config["vocabulary_dir"]
    pretrained_embedding_file_path = config["pretrained_embedding_file_path"]
    word_embedding_dim = config["word_embedding_dim"]
    pretrained_embedding_max_size = config["pretrained_embedding_max_size"]
    is_fine_tuning = config["fine_tuning"]

    word_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "word_vocabulary")
    event_type_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "event_type_vocabulary")
    entity_tag_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "entity_tag_vocabulary")

    if train_type == Train.NEW_TRAIN:
        if os.path.isdir(serialize_dir):
            shutil.rmtree(serialize_dir)
        os.makedirs(serialize_dir)

        if os.path.isdir(vocabulary_dir):
            shutil.rmtree(vocabulary_dir)
        os.makedirs(vocabulary_dir)
        os.makedirs(word_vocab_dir)
        os.makedirs(event_type_vocab_dir)
        os.makedirs(entity_tag_vocab_dir)
    elif train_type == Train.RECOVERY_TRAIN:
        pass
    else:
        assert False, f"train_type: {train_type} error!"

    train_dataset_file_path = config["train_dataset_file_path"]
    validation_dataset_file_path = config["validation_dataset_file_path"]

    num_epoch = config["epoch"]
    batch_size = config["batch_size"]

    if train_type == Train.NEW_TRAIN:
        # Build the vocabularies.
        ace_dataset = ACEDataset(train_dataset_file_path)
        vocab_data_loader = DataLoader(dataset=ace_dataset,
                                       batch_size=10,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=EventVocabularyCollate())

        tokens: List[List[str]] = list()
        event_types: List[List[str]] = list()
        entity_tags: List[List[str]] = list()

        for collate_dict in vocab_data_loader:
            tokens.extend(collate_dict["tokens"])
            event_types.extend(collate_dict["event_types"])
            entity_tags.extend(collate_dict["entity_tags"])

        word_vocabulary = Vocabulary(tokens=tokens,
                                     padding=Vocabulary.PADDING,
                                     unk=Vocabulary.UNK,
                                     special_first=True)

        glove_loader = GloveLoader(embedding_dim=word_embedding_dim,
                                   pretrained_file_path=pretrained_embedding_file_path,
                                   max_size=pretrained_embedding_max_size)

        pretrained_word_vocabulary = PretrainedVocabulary(
            vocabulary=word_vocabulary,
            pretrained_word_embedding_loader=glove_loader)
        pretrained_word_vocabulary.save_to_file(word_vocab_dir)

        event_type_vocabulary = Vocabulary(tokens=event_types,
                                           padding="",
                                           unk="Negative",
                                           special_first=True)
        event_type_vocabulary.save_to_file(event_type_vocab_dir)

        entity_tag_vocabulary = LabelVocabulary(labels=entity_tags,
                                                padding=LabelVocabulary.PADDING)
        entity_tag_vocabulary.save_to_file(entity_tag_vocab_dir)
    else:
        pretrained_word_vocabulary = PretrainedVocabulary.from_file(word_vocab_dir)
        event_type_vocabulary = Vocabulary.from_file(event_type_vocab_dir)
        entity_tag_vocabulary = LabelVocabulary.from_file(entity_tag_vocab_dir)

    model = EventModel(alpha=0.5,
                       activate_score=True,
                       sentence_vocab=pretrained_word_vocabulary,
                       sentence_embedding_dim=word_embedding_dim,
                       entity_tag_vocab=entity_tag_vocabulary,
                       entity_tag_embedding_dim=50,
                       event_type_vocab=event_type_vocabulary,
                       event_type_embedding_dim=300,
                       lstm_hidden_size=300,
                       lstm_encoder_num_layer=1,
                       lstm_encoder_droupout=0.4)

    trainer = Trainer(serialize_dir=serialize_dir,
                      num_epoch=num_epoch,
                      model=model,
                      loss=EventLoss(),
                      optimizer_factory=EventOptimizerFactory(is_fine_tuning=is_fine_tuning),
                      metrics=EventF1MetricAdapter(event_type_vocabulary=event_type_vocabulary),
                      patient=10,
                      num_check_point_keep=5,
                      devices=None)

    train_dataset = EventDataset(dataset_file_path=train_dataset_file_path,
                                 event_type_vocabulary=event_type_vocabulary)
    validation_dataset = EventDataset(dataset_file_path=validation_dataset_file_path,
                                      event_type_vocabulary=event_type_vocabulary)

    event_collate = EventCollate(word_vocabulary=pretrained_word_vocabulary,
                                 event_type_vocabulary=event_type_vocabulary,
                                 entity_tag_vocabulary=entity_tag_vocabulary,
                                 sentence_max_len=512)
    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_size=batch_size,
                                   num_workers=0,
                                   collate_fn=event_collate)
    validation_data_loader = DataLoader(dataset=validation_dataset,
                                        batch_size=batch_size,
                                        num_workers=0,
                                        collate_fn=event_collate)

    if train_type == Train.NEW_TRAIN:
        trainer.train(train_data_loader=train_data_loader,
                      validation_data_loader=validation_data_loader)
    else:
        trainer.recovery_train(train_data_loader=train_data_loader,
                               validation_data_loader=validation_data_loader)
def test_gaz_model_collate(lattice_ner_demo_dataset, gaz_pretrained_embedding_loader):
    # Only take the first two instances for testing.
    batch_instances = lattice_ner_demo_dataset[0:2]

    vocabulary_collate = VocabularyCollate()
    collate_result = vocabulary_collate(batch_instances)

    tokens = collate_result["tokens"]
    sequence_label = collate_result["sequence_labels"]

    token_vocabulary = Vocabulary(tokens=tokens,
                                  padding=Vocabulary.PADDING,
                                  unk=Vocabulary.UNK,
                                  special_first=True)
    label_vocabulary = LabelVocabulary(labels=sequence_label,
                                       padding=LabelVocabulary.PADDING)

    gazetter = Gazetteer(
        gaz_pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetter)
    gaz_words = gaz_vocabulary_collate(batch_instances)

    gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                padding=Vocabulary.PADDING,
                                unk=Vocabulary.UNK,
                                special_first=True)
    gaz_vocabulary = PretrainedVocabulary(
        vocabulary=gaz_vocabulary,
        pretrained_word_embedding_loader=gaz_pretrained_embedding_loader)

    lattice_model_collate = LatticeModelCollate(token_vocabulary=token_vocabulary,
                                                gazetter=gazetter,
                                                gaz_vocabulary=gaz_vocabulary,
                                                label_vocabulary=label_vocabulary)

    model_inputs = lattice_model_collate(batch_instances)

    logging.debug(json2str(model_inputs.model_inputs["metadata"]))

    metadata_0 = model_inputs.model_inputs["metadata"][0]

    # Sentence: 陈元呼吁加强国际合作推动世界经济发展
    expect_gaz_words_0 = [["陈元"], [], ["呼吁"], ["吁加"], ["加强"], ["强国"],
                          ["国际"], [], ["合作"], [], ["推动"], [], ["世界"], [],
                          ["经济"], [], ["发展"], []]
    gaz_words_0 = metadata_0["gaz_words"]

    ASSERT.assertListEqual(expect_gaz_words_0, gaz_words_0)

    gaz_list_0 = model_inputs.model_inputs["gaz_list"][0]

    expect_gaz_list_0 = list()

    for expect_gaz_word in expect_gaz_words_0:
        if len(expect_gaz_word) > 0:
            indices = [gaz_vocabulary.index(word) for word in expect_gaz_word]
            lengths = [len(word) for word in expect_gaz_word]
            expect_gaz_list_0.append([indices, lengths])
        else:
            expect_gaz_list_0.append([])

    logging.debug(
        f"expect_gaz_list_0: {json2str(expect_gaz_list_0)}\n gaz_list_0:{json2str(gaz_list_0)}")

    ASSERT.assertListEqual(expect_gaz_list_0, gaz_list_0)

    tokens_0 = model_inputs.model_inputs["tokens"]
    ASSERT.assertEqual((2, 19), tokens_0.size())

    sequence_label_0 = model_inputs.labels
    ASSERT.assertEqual((2, 19), sequence_label_0.size())

    # Sentence: 新华社华盛顿4月28日电(记者翟景升)
    expect_gaz_word_1 = [["新华社", "新华"],  # 新
                         ["华社"],            # 华
                         ["社华"],            # 社
                         ["华盛顿", "华盛"],  # 华
                         ["盛顿"],            # 盛
                         [],                  # 顿
                         [],                  # 4
                         [],                  # 月
                         [],                  # 2
                         [],                  # 8
                         [],                  # 日
                         [],                  # 电
                         [],                  # (
                         ["记者"],            # 记
                         [],                  # 者
                         ["翟景升", "翟景"],  # 翟
                         ["景升"],            # 景
                         [],                  # 升
                         []]                  # )

    metadata_1 = model_inputs.model_inputs["metadata"][1]
    gaz_words_1 = metadata_1["gaz_words"]

    ASSERT.assertListEqual(expect_gaz_word_1, gaz_words_1)

    expect_gaz_list_1 = list()

    for expect_gaz_word in expect_gaz_word_1:
        if len(expect_gaz_word) > 0:
            indices = [gaz_vocabulary.index(word) for word in expect_gaz_word]
            lengths = [len(word) for word in expect_gaz_word]
            expect_gaz_list_1.append([indices, lengths])
        else:
            expect_gaz_list_1.append([])

    gaz_list_1 = model_inputs.model_inputs["gaz_list"][1]

    ASSERT.assertListEqual(expect_gaz_list_1, gaz_list_1)