def __init__(self,
             is_training: bool,
             dataset: Dataset,
             gaz_vocabulary_dir: str,
             gaz_pretrained_word_embedding_loader: PretrainedWordEmbeddingLoader):
    """
    Build the gaz vocabulary.
    :param is_training: whether we are currently in the training phase
    :param dataset: the dataset
    :param gaz_vocabulary_dir: directory in which the gaz vocabulary is stored
    :param gaz_pretrained_word_embedding_loader: loader for the pretrained gaz word embeddings
    """
    super().__init__(is_training=is_training)

    # In principle the gazetteer should support persistence; that is not implemented here.
    gazetteer = Gazetteer(gaz_pretrained_word_embedding_loader=gaz_pretrained_word_embedding_loader)

    if is_training:
        gaz_vocabulary_collate = GazVocabularyCollate(gazetteer=gazetteer)

        data_loader = DataLoader(dataset=dataset,
                                 batch_size=100,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=gaz_vocabulary_collate)

        gaz_words = list()

        for batch_gaz_words in data_loader:
            gaz_words.extend(batch_gaz_words)

        gaz_vocabulary = Vocabulary(tokens=gaz_words,
                                    padding=Vocabulary.PADDING,
                                    unk=Vocabulary.UNK,
                                    special_first=True)

        # Wrap the vocabulary so that it carries the pretrained gaz embeddings.
        gaz_vocabulary = PretrainedVocabulary(
            vocabulary=gaz_vocabulary,
            pretrained_word_embedding_loader=gaz_pretrained_word_embedding_loader)

        gaz_vocabulary.save_to_file(gaz_vocabulary_dir)
    else:
        # Outside training, reload the vocabulary built during training.
        gaz_vocabulary = Vocabulary.from_file(gaz_vocabulary_dir)

    self.gaz_vocabulary = gaz_vocabulary
    self.gazetteer = gazetteer
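# A minimal usage sketch, assuming the __init__ above belongs to a builder
# class named GazVocabularyBuilder; the class name, the dataset variable, and
# the paths below are assumptions, not part of the snippet above.

gaz_loader = GloveLoader(embedding_dim=50,
                         pretrained_file_path="data/gaz_embedding.txt",  # hypothetical path
                         max_size=None)
builder = GazVocabularyBuilder(is_training=True,  # hypothetical class name
                               dataset=train_dataset,  # any Dataset instance
                               gaz_vocabulary_dir="data/vocabulary/gaz",  # hypothetical path
                               gaz_pretrained_word_embedding_loader=gaz_loader)
gaz_vocabulary = builder.gaz_vocabulary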
def test_save_and_load():
    """
    Test saving and loading a vocabulary.
    :return:
    """
    batch_tokens = [["我", "和", "你"], ["在", "我"], ["newline\nnewline"]]

    vocabulary = Vocabulary(batch_tokens,
                            padding=Vocabulary.PADDING,
                            unk=Vocabulary.UNK,
                            special_first=True,
                            other_special_tokens=["<Start>", "<End>"],
                            min_frequency=1,
                            max_size=None)

    # size = 4 special tokens (padding, unk, <Start>, <End>) + 5 distinct tokens.
    ASSERT.assertEqual(vocabulary.size, 9)
    ASSERT.assertEqual(vocabulary.padding, vocabulary.PADDING)
    ASSERT.assertEqual(vocabulary.unk, vocabulary.UNK)
    ASSERT.assertEqual(vocabulary.index(vocabulary.padding), 0)
    ASSERT.assertEqual(vocabulary.index(vocabulary.unk), 1)
    ASSERT.assertEqual(vocabulary.index("<Start>"), 2)
    ASSERT.assertEqual(vocabulary.index("<End>"), 3)
    # "我" occurs twice, so it takes the first non-special index.
    ASSERT.assertEqual(vocabulary.index("我"), 4)
    ASSERT.assertEqual(vocabulary.index("newline\nnewline"), 8)
    # An out-of-vocabulary token maps to the unk index.
    ASSERT.assertEqual(vocabulary.index("哈哈"), vocabulary.index(vocabulary.unk))

    vocab_dir = os.path.join(ROOT_PATH, "data/easytext/tests")

    if not os.path.isdir(vocab_dir):
        os.makedirs(vocab_dir, exist_ok=True)

    vocabulary.save_to_file(vocab_dir)

    loaded_vocab = Vocabulary.from_file(directory=vocab_dir)

    # The loaded vocabulary must reproduce the original exactly.
    ASSERT.assertEqual(loaded_vocab.size, 9)
    ASSERT.assertEqual(loaded_vocab.padding, vocabulary.PADDING)
    ASSERT.assertEqual(loaded_vocab.unk, vocabulary.UNK)
    ASSERT.assertEqual(loaded_vocab.index(vocabulary.padding), 0)
    ASSERT.assertEqual(loaded_vocab.index(vocabulary.unk), 1)
    ASSERT.assertEqual(loaded_vocab.index("<Start>"), 2)
    ASSERT.assertEqual(loaded_vocab.index("<End>"), 3)
    ASSERT.assertEqual(loaded_vocab.index("我"), 4)
    ASSERT.assertEqual(loaded_vocab.index("newline\nnewline"), 8)
    ASSERT.assertEqual(loaded_vocab.index("哈哈"), loaded_vocab.index(loaded_vocab.unk))
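# A small round-trip sketch using only the calls exercised by the test above;
# the demo directory and tokens are hypothetical placeholders.

import os

demo_dir = "data/easytext/demo"  # hypothetical directory
os.makedirs(demo_dir, exist_ok=True)

demo_vocab = Vocabulary([["hello", "world"], ["hello"]],
                        padding=Vocabulary.PADDING,
                        unk=Vocabulary.UNK,
                        special_first=True)
demo_vocab.save_to_file(demo_dir)

reloaded = Vocabulary.from_file(directory=demo_dir)
assert reloaded.index("hello") == demo_vocab.index("hello")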
def __init__(self,
             is_training: bool,
             dataset: Dataset,
             vocabulary_collate,
             token_vocabulary_dir: str,
             label_vocabulary_dir: str,
             is_build_token_vocabulary: bool,
             pretrained_word_embedding_loader: PretrainedWordEmbeddingLoader):
    """
    Vocabulary builder.
    :param is_training: the vocabulary is built differently for training and non-training
                        runs: during training it generally has to be rebuilt, while
                        otherwise the previously built vocabulary is reused.
    :param dataset: the dataset
    :param vocabulary_collate: the vocabulary collate function
    :param token_vocabulary_dir: directory in which the token vocabulary is stored
    :param label_vocabulary_dir: directory in which the label vocabulary is stored
    :param is_build_token_vocabulary: whether to build the token vocabulary; when Bert or
                                      another pretrained model supplies the embedding,
                                      there is no need to build one.
    :param pretrained_word_embedding_loader: loader for the pretrained word embeddings
    """
    super().__init__(is_training=is_training)

    token_vocabulary = None
    label_vocabulary = None

    if is_training:
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=100,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=vocabulary_collate)

        batch_tokens = list()
        batch_sequence_labels = list()

        for collate_dict in data_loader:
            batch_tokens.extend(collate_dict["tokens"])
            batch_sequence_labels.extend(collate_dict["sequence_labels"])

        if is_build_token_vocabulary:
            token_vocabulary = Vocabulary(tokens=batch_tokens,
                                          padding=Vocabulary.PADDING,
                                          unk=Vocabulary.UNK,
                                          special_first=True)

            if pretrained_word_embedding_loader is not None:
                token_vocabulary = PretrainedVocabulary(
                    vocabulary=token_vocabulary,
                    pretrained_word_embedding_loader=pretrained_word_embedding_loader)

            if token_vocabulary_dir:
                token_vocabulary.save_to_file(token_vocabulary_dir)

        label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                           padding=LabelVocabulary.PADDING)

        if label_vocabulary_dir:
            label_vocabulary.save_to_file(label_vocabulary_dir)
    else:
        if is_build_token_vocabulary and token_vocabulary_dir:
            token_vocabulary = Vocabulary.from_file(token_vocabulary_dir)

        if label_vocabulary_dir:
            label_vocabulary = LabelVocabulary.from_file(label_vocabulary_dir)

    self.token_vocabulary = token_vocabulary
    self.label_vocabulary = label_vocabulary
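# A minimal usage sketch, assuming the __init__ above belongs to a class named
# VocabularyBuilder; the class name, the dataset and collate variables, and
# the directories below are assumptions. In training mode the builder derives
# both vocabularies from the dataset and persists them; a later run with
# is_training=False reloads them from the same directories.

builder = VocabularyBuilder(is_training=True,  # hypothetical class name
                            dataset=train_dataset,  # any Dataset instance
                            vocabulary_collate=my_vocabulary_collate,  # hypothetical collate
                            token_vocabulary_dir="data/vocabulary/token",  # hypothetical path
                            label_vocabulary_dir="data/vocabulary/label",  # hypothetical path
                            is_build_token_vocabulary=True,
                            pretrained_word_embedding_loader=None)
token_vocabulary = builder.token_vocabulary
label_vocabulary = builder.label_vocabulary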
def __call__(self, config: Dict, train_type: int):
    serialize_dir = config["serialize_dir"]
    vocabulary_dir = config["vocabulary_dir"]
    pretrained_embedding_file_path = config["pretrained_embedding_file_path"]
    word_embedding_dim = config["word_embedding_dim"]
    pretrained_embedding_max_size = config["pretrained_embedding_max_size"]
    is_fine_tuning = config["fine_tuning"]

    word_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "word_vocabulary")
    event_type_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "event_type_vocabulary")
    entity_tag_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "entity_tag_vocabulary")

    if train_type == Train.NEW_TRAIN:
        # A fresh run starts from clean serialize and vocabulary directories.
        if os.path.isdir(serialize_dir):
            shutil.rmtree(serialize_dir)
        os.makedirs(serialize_dir)

        if os.path.isdir(vocabulary_dir):
            shutil.rmtree(vocabulary_dir)
        os.makedirs(vocabulary_dir)
        os.makedirs(word_vocab_dir)
        os.makedirs(event_type_vocab_dir)
        os.makedirs(entity_tag_vocab_dir)
    elif train_type == Train.RECOVERY_TRAIN:
        pass
    else:
        assert False, f"train_type: {train_type} error!"

    train_dataset_file_path = config["train_dataset_file_path"]
    validation_dataset_file_path = config["validation_dataset_file_path"]

    num_epoch = config["epoch"]
    batch_size = config["batch_size"]

    if train_type == Train.NEW_TRAIN:
        # Build the vocabularies.
        ace_dataset = ACEDataset(train_dataset_file_path)
        vocab_data_loader = DataLoader(dataset=ace_dataset,
                                       batch_size=10,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=EventVocabularyCollate())

        tokens: List[List[str]] = list()
        event_types: List[List[str]] = list()
        entity_tags: List[List[str]] = list()

        for collate_dict in vocab_data_loader:
            tokens.extend(collate_dict["tokens"])
            event_types.extend(collate_dict["event_types"])
            entity_tags.extend(collate_dict["entity_tags"])

        word_vocabulary = Vocabulary(tokens=tokens,
                                     padding=Vocabulary.PADDING,
                                     unk=Vocabulary.UNK,
                                     special_first=True)

        glove_loader = GloveLoader(embedding_dim=word_embedding_dim,
                                   pretrained_file_path=pretrained_embedding_file_path,
                                   max_size=pretrained_embedding_max_size)

        pretrained_word_vocabulary = PretrainedVocabulary(
            vocabulary=word_vocabulary,
            pretrained_word_embedding_loader=glove_loader)
        pretrained_word_vocabulary.save_to_file(word_vocab_dir)

        event_type_vocabulary = Vocabulary(tokens=event_types,
                                           padding="",
                                           unk="Negative",
                                           special_first=True)
        event_type_vocabulary.save_to_file(event_type_vocab_dir)

        entity_tag_vocabulary = LabelVocabulary(labels=entity_tags,
                                                padding=LabelVocabulary.PADDING)
        entity_tag_vocabulary.save_to_file(entity_tag_vocab_dir)
    else:
        pretrained_word_vocabulary = PretrainedVocabulary.from_file(word_vocab_dir)
        event_type_vocabulary = Vocabulary.from_file(event_type_vocab_dir)
        # The entity tag vocabulary was saved as a LabelVocabulary above, so it
        # must be loaded as one as well.
        entity_tag_vocabulary = LabelVocabulary.from_file(entity_tag_vocab_dir)

    model = EventModel(alpha=0.5,
                       activate_score=True,
                       sentence_vocab=pretrained_word_vocabulary,
                       sentence_embedding_dim=word_embedding_dim,
                       entity_tag_vocab=entity_tag_vocabulary,
                       entity_tag_embedding_dim=50,
                       event_type_vocab=event_type_vocabulary,
                       event_type_embedding_dim=300,
                       lstm_hidden_size=300,
                       lstm_encoder_num_layer=1,
                       lstm_encoder_droupout=0.4)

    trainer = Trainer(serialize_dir=serialize_dir,
                      num_epoch=num_epoch,
                      model=model,
                      loss=EventLoss(),
                      optimizer_factory=EventOptimizerFactory(is_fine_tuning=is_fine_tuning),
                      metrics=EventF1MetricAdapter(event_type_vocabulary=event_type_vocabulary),
                      patient=10,
                      num_check_point_keep=5,
                      devices=None)

    train_dataset = EventDataset(dataset_file_path=train_dataset_file_path,
                                 event_type_vocabulary=event_type_vocabulary)
    validation_dataset = EventDataset(dataset_file_path=validation_dataset_file_path,
                                      event_type_vocabulary=event_type_vocabulary)

    event_collate = EventCollate(word_vocabulary=pretrained_word_vocabulary,
                                 event_type_vocabulary=event_type_vocabulary,
                                 entity_tag_vocabulary=entity_tag_vocabulary,
                                 sentence_max_len=512)

    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_size=batch_size,
                                   num_workers=0,
                                   collate_fn=event_collate)
    validation_data_loader = DataLoader(dataset=validation_dataset,
                                        batch_size=batch_size,
                                        num_workers=0,
                                        collate_fn=event_collate)

    if train_type == Train.NEW_TRAIN:
        trainer.train(train_data_loader=train_data_loader,
                      validation_data_loader=validation_data_loader)
    else:
        trainer.recovery_train(train_data_loader=train_data_loader,
                               validation_data_loader=validation_data_loader)
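# A minimal invocation sketch for the __call__ above, assuming the enclosing
# class is named Train and takes no constructor arguments; every path and
# hyperparameter value below is a hypothetical placeholder. The config keys
# match exactly the ones read at the top of __call__.

config = {
    "serialize_dir": "data/event/serialize",                      # hypothetical path
    "vocabulary_dir": "data/event",                               # hypothetical path
    "pretrained_embedding_file_path": "data/glove.6B.300d.txt",   # hypothetical path
    "word_embedding_dim": 300,
    "pretrained_embedding_max_size": None,
    "fine_tuning": True,
    "train_dataset_file_path": "data/ace/train.json",             # hypothetical path
    "validation_dataset_file_path": "data/ace/dev.json",          # hypothetical path
    "epoch": 100,
    "batch_size": 32,
}

Train()(config=config, train_type=Train.NEW_TRAIN)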