Exemplo n.º 1
0
    def __init__(self,
                 is_training: bool,
                 dataset: Dataset,
                 vocabulary_collate,
                 token_vocabulary_dir: str,
                 label_vocabulary_dir: str,
                 is_build_token_vocabulary: bool,
                 pretrained_word_embedding_loader: PretrainedWordEmbeddingLoader):
        """
        Vocabulary builder.

        In training mode the token and label vocabularies are built from ``dataset``
        and, when a target directory is given, saved there. In non-training mode the
        vocabularies previously saved to those directories are loaded instead.

        :param is_training: whether this is a training run. Training (re)builds the
                            vocabularies from ``dataset``; any other run reuses the
                            ones saved by an earlier training run.
        :param dataset: dataset scanned for tokens and labels (only read when
                        ``is_training`` is True).
        :param vocabulary_collate: collate function for the DataLoader; every batch it
                                   returns must be a dict with "tokens" and
                                   "sequence_labels" entries.
        :param token_vocabulary_dir: directory the token vocabulary is saved to
                                     (training) or loaded from (non-training);
                                     falsy means "don't save" / "don't load".
        :param label_vocabulary_dir: directory the label vocabulary is saved to
                                     (training) or loaded from (non-training);
                                     falsy means "don't save" / "don't load".
        :param is_build_token_vocabulary: whether to build a token vocabulary at all.
                                          Unnecessary when e.g. BERT or another model
                                          supplies the pretrained embedding.
        :param pretrained_word_embedding_loader: pretrained word-embedding loader, or
                                                 None. When given, the token vocabulary
                                                 is wrapped in (training) / reloaded as
                                                 (non-training) a ``PretrainedVocabulary``.
        """

        super().__init__(is_training=is_training)

        token_vocabulary = None
        label_vocabulary = None

        if is_training:
            data_loader = DataLoader(dataset=dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     num_workers=0,
                                     collate_fn=vocabulary_collate)
            batch_tokens = list()
            batch_sequence_labels = list()

            # Collect every token sequence and label sequence in the dataset.
            for collate_dict in data_loader:
                batch_tokens.extend(collate_dict["tokens"])
                batch_sequence_labels.extend(collate_dict["sequence_labels"])

            if is_build_token_vocabulary:
                token_vocabulary = Vocabulary(tokens=batch_tokens,
                                              padding=Vocabulary.PADDING,
                                              unk=Vocabulary.UNK,
                                              special_first=True)

                if pretrained_word_embedding_loader is not None:
                    token_vocabulary = \
                        PretrainedVocabulary(vocabulary=token_vocabulary,
                                             pretrained_word_embedding_loader=pretrained_word_embedding_loader)

                if token_vocabulary_dir:
                    token_vocabulary.save_to_file(token_vocabulary_dir)

            label_vocabulary = LabelVocabulary(labels=batch_sequence_labels,
                                               padding=LabelVocabulary.PADDING)

            if label_vocabulary_dir:
                label_vocabulary.save_to_file(label_vocabulary_dir)
        else:
            if is_build_token_vocabulary and token_vocabulary_dir:
                if pretrained_word_embedding_loader is not None:
                    # Training saves a PretrainedVocabulary to this directory when a
                    # loader is supplied, so read it back with the matching class
                    # rather than plain Vocabulary.
                    # NOTE(review): assumes callers pass the same loader-or-None
                    # choice here as they did for the training run — confirm.
                    token_vocabulary = PretrainedVocabulary.from_file(token_vocabulary_dir)
                else:
                    token_vocabulary = Vocabulary.from_file(token_vocabulary_dir)

            if label_vocabulary_dir:
                label_vocabulary = LabelVocabulary.from_file(label_vocabulary_dir)

        self.token_vocabulary = token_vocabulary
        self.label_vocabulary = label_vocabulary
Exemplo n.º 2
0
    def __call__(self, config: Dict, train_type: int):
        """
        Prepare vocabularies, assemble the event model and run (or resume) training.

        :param config: training configuration. Keys read: "serialize_dir",
                       "vocabulary_dir", "pretrained_embedding_file_path",
                       "word_embedding_dim", "pretrained_embedding_max_size",
                       "fine_tuning", "train_dataset_file_path",
                       "validation_dataset_file_path", "epoch", "batch_size".
        :param train_type: ``Train.NEW_TRAIN`` deletes and recreates
                           ``serialize_dir`` and ``vocabulary_dir``, rebuilds the
                           vocabularies from the training set and calls
                           ``trainer.train``; ``Train.RECOVERY_TRAIN`` reloads the
                           saved vocabularies and calls ``trainer.recovery_train``.
        :raises ValueError: if ``train_type`` is neither of the above.
        """
        # Read every config key up front so that a missing key fails before a
        # NEW_TRAIN run has deleted the output directories below.
        serialize_dir = config["serialize_dir"]
        vocabulary_dir = config["vocabulary_dir"]
        pretrained_embedding_file_path = config["pretrained_embedding_file_path"]
        word_embedding_dim = config["word_embedding_dim"]
        pretrained_embedding_max_size = config["pretrained_embedding_max_size"]
        is_fine_tuning = config["fine_tuning"]
        train_dataset_file_path = config["train_dataset_file_path"]
        validation_dataset_file_path = config["validation_dataset_file_path"]
        num_epoch = config["epoch"]
        batch_size = config["batch_size"]

        if train_type not in (Train.NEW_TRAIN, Train.RECOVERY_TRAIN):
            # Explicit raise: `assert False` is stripped under `python -O`, which
            # would let an invalid train_type fall through to the recovery path.
            raise ValueError(f"train_type: {train_type} error!")

        is_new_train = train_type == Train.NEW_TRAIN

        word_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "word_vocabulary")
        event_type_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "event_type_vocabulary")
        entity_tag_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "entity_tag_vocabulary")

        if is_new_train:
            # Delete both trees first and only then create them, so that nesting one
            # directory inside the other can't delete a directory just created.
            for directory in (serialize_dir, vocabulary_dir):
                if os.path.isdir(directory):
                    shutil.rmtree(directory)

            for directory in (serialize_dir, vocabulary_dir,
                              word_vocab_dir, event_type_vocab_dir, entity_tag_vocab_dir):
                os.makedirs(directory, exist_ok=True)

            (pretrained_word_vocabulary,
             event_type_vocabulary,
             entity_tag_vocabulary) = self._build_event_vocabularies(
                train_dataset_file_path=train_dataset_file_path,
                word_embedding_dim=word_embedding_dim,
                pretrained_embedding_file_path=pretrained_embedding_file_path,
                pretrained_embedding_max_size=pretrained_embedding_max_size,
                word_vocab_dir=word_vocab_dir,
                event_type_vocab_dir=event_type_vocab_dir,
                entity_tag_vocab_dir=entity_tag_vocab_dir)
        else:
            # Reload each vocabulary with the same class it was built and saved as.
            pretrained_word_vocabulary = PretrainedVocabulary.from_file(word_vocab_dir)
            event_type_vocabulary = Vocabulary.from_file(event_type_vocab_dir)
            entity_tag_vocabulary = LabelVocabulary.from_file(entity_tag_vocab_dir)

        model = EventModel(alpha=0.5,
                           activate_score=True,
                           sentence_vocab=pretrained_word_vocabulary,
                           sentence_embedding_dim=word_embedding_dim,
                           entity_tag_vocab=entity_tag_vocabulary,
                           entity_tag_embedding_dim=50,
                           event_type_vocab=event_type_vocabulary,
                           event_type_embedding_dim=300,
                           lstm_hidden_size=300,
                           lstm_encoder_num_layer=1,
                           lstm_encoder_droupout=0.4)

        trainer = Trainer(
            serialize_dir=serialize_dir,
            num_epoch=num_epoch,
            model=model,
            loss=EventLoss(),
            optimizer_factory=EventOptimizerFactory(is_fine_tuning=is_fine_tuning),
            metrics=EventF1MetricAdapter(event_type_vocabulary=event_type_vocabulary),
            patient=10,
            num_check_point_keep=5,
            devices=None
        )

        train_dataset = EventDataset(dataset_file_path=train_dataset_file_path,
                                     event_type_vocabulary=event_type_vocabulary)
        validation_dataset = EventDataset(dataset_file_path=validation_dataset_file_path,
                                          event_type_vocabulary=event_type_vocabulary)

        event_collate = EventCollate(word_vocabulary=pretrained_word_vocabulary,
                                     event_type_vocabulary=event_type_vocabulary,
                                     entity_tag_vocabulary=entity_tag_vocabulary,
                                     sentence_max_len=512)
        # NOTE(review): no `shuffle` is passed, so the training loader uses
        # DataLoader's default (shuffle=False) — confirm this is intended.
        train_data_loader = DataLoader(dataset=train_dataset,
                                       batch_size=batch_size,
                                       num_workers=0,
                                       collate_fn=event_collate)

        validation_data_loader = DataLoader(dataset=validation_dataset,
                                            batch_size=batch_size,
                                            num_workers=0,
                                            collate_fn=event_collate)

        if is_new_train:
            trainer.train(train_data_loader=train_data_loader,
                          validation_data_loader=validation_data_loader)
        else:
            trainer.recovery_train(train_data_loader=train_data_loader,
                                   validation_data_loader=validation_data_loader)

    def _build_event_vocabularies(self,
                                  train_dataset_file_path: str,
                                  word_embedding_dim: int,
                                  pretrained_embedding_file_path: str,
                                  pretrained_embedding_max_size: int,
                                  word_vocab_dir: str,
                                  event_type_vocab_dir: str,
                                  entity_tag_vocab_dir: str):
        """
        Build the word, event-type and entity-tag vocabularies from the training set,
        saving each to its directory.

        :param train_dataset_file_path: training set read through ``ACEDataset``.
        :param word_embedding_dim: dimension passed to the GloVe loader.
        :param pretrained_embedding_file_path: GloVe file path.
        :param pretrained_embedding_max_size: ``max_size`` passed to the GloVe loader.
        :param word_vocab_dir: where the pretrained word vocabulary is saved.
        :param event_type_vocab_dir: where the event-type vocabulary is saved.
        :param entity_tag_vocab_dir: where the entity-tag vocabulary is saved.
        :return: (pretrained_word_vocabulary, event_type_vocabulary,
                 entity_tag_vocabulary)
        """
        ace_dataset = ACEDataset(train_dataset_file_path)
        vocab_data_loader = DataLoader(dataset=ace_dataset,
                                       batch_size=10,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=EventVocabularyCollate())

        tokens: List[List[str]] = list()
        event_types: List[List[str]] = list()
        entity_tags: List[List[str]] = list()

        for collate_dict in vocab_data_loader:
            tokens.extend(collate_dict["tokens"])
            event_types.extend(collate_dict["event_types"])
            entity_tags.extend(collate_dict["entity_tags"])

        word_vocabulary = Vocabulary(tokens=tokens,
                                     padding=Vocabulary.PADDING,
                                     unk=Vocabulary.UNK,
                                     special_first=True)

        glove_loader = GloveLoader(embedding_dim=word_embedding_dim,
                                   pretrained_file_path=pretrained_embedding_file_path,
                                   max_size=pretrained_embedding_max_size)

        pretrained_word_vocabulary = PretrainedVocabulary(vocabulary=word_vocabulary,
                                                          pretrained_word_embedding_loader=glove_loader)
        pretrained_word_vocabulary.save_to_file(word_vocab_dir)

        # NOTE(review): padding="" and unk="Negative" presumably mean "no padding
        # token" and "unseen event types map to the Negative (no-event) class" —
        # confirm against Vocabulary's handling of these arguments.
        event_type_vocabulary = Vocabulary(tokens=event_types,
                                           padding="",
                                           unk="Negative",
                                           special_first=True)
        event_type_vocabulary.save_to_file(event_type_vocab_dir)

        entity_tag_vocabulary = LabelVocabulary(labels=entity_tags,
                                                padding=LabelVocabulary.PADDING)
        entity_tag_vocabulary.save_to_file(entity_tag_vocab_dir)

        return pretrained_word_vocabulary, event_type_vocabulary, entity_tag_vocabulary