from transformers import BertConfig, BertModel, DPRConfig, DPRQuestionEncoder, DPRReader

# load_states_from_checkpoint is a helper defined alongside these methods in the
# conversion script: it reads the original DPR checkpoint and exposes its raw
# state dict as .model_dict.


def load_dpr_model(self):
    # Load the reader weights from an original DPR checkpoint into a DPRReader.
    model = DPRReader(
        DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
    print("Loading DPR reader from {}".format(self.src_file))
    saved_state = load_states_from_checkpoint(self.src_file)
    state_dict = {}
    for key, value in saved_state.model_dict.items():
        if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"):
            key = "encoder.bert_model." + key[len("encoder."):]
        state_dict[key] = value
    model.span_predictor.load_state_dict(state_dict)
    return model


def load_dpr_model(self):
    # Load the question-encoder half of the DPR biencoder into a DPRQuestionEncoder.
    model = DPRQuestionEncoder(
        DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
    print("Loading DPR biencoder from {}".format(self.src_file))
    saved_state = load_states_from_checkpoint(self.src_file)
    encoder, prefix = model.question_encoder, "question_model."
    state_dict = {}
    for key, value in saved_state.model_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
            if not key.startswith("encode_proj."):
                key = "bert_model." + key
            state_dict[key] = value
    encoder.load_state_dict(state_dict)
    return model


def load_dpr_model(self):
    # Load encoder weights (keys prefixed with "model.") into a plain BertModel.
    model = BertModel(
        BertConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
    print("Loading DPR biencoder from {}".format(self.src_file))
    saved_state = load_states_from_checkpoint(self.src_file)
    encoder, prefix = model, "model."
    # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
    state_dict = {"embeddings.position_ids": model.embeddings.position_ids}
    for key, value in saved_state.model_dict.items():
        if key.startswith(prefix):
            # Strip the prefix; the remaining key is used unchanged.
            key = key[len(prefix):]
            state_dict[key] = value
    encoder.load_state_dict(state_dict)
    return model


def load_dpr_model(self):
    # Same reader conversion as the first method, but with the position_ids
    # buffer pre-seeded so the state dict matches newer transformers versions.
    model = DPRReader(
        DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
    print("Loading DPR reader from {}".format(self.src_file))
    saved_state = load_states_from_checkpoint(self.src_file)
    # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
    state_dict = {
        "encoder.bert_model.embeddings.position_ids":
        model.span_predictor.encoder.bert_model.embeddings.position_ids
    }
    for key, value in saved_state.model_dict.items():
        if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"):
            key = "encoder.bert_model." + key[len("encoder."):]
        state_dict[key] = value
    model.span_predictor.load_state_dict(state_dict)
    return model
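
The four loaders above are methods (note the self argument) of small checkpoint-wrapper classes in the conversion script, each handling one DPR component. A minimal driver sketch for how such a wrapper might be used, assuming it exposes src_file and one of these load_dpr_model methods; the function below and its argument names are illustrative, not part of the original script:

from pathlib import Path

def convert_checkpoint(dpr_state, dest_dir):
    # dpr_state: a wrapper object providing one of the load_dpr_model methods above.
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(exist_ok=True)
    model = dpr_state.load_dpr_model()      # build the transformers model from the raw checkpoint
    model.save_pretrained(dest_dir)         # write weights and config in transformers format
    type(model).from_pretrained(dest_dir)   # reload once as a sanity check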
Example #5
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertConfig

train_dataloader = DataLoader(
    train_dataset,  # The training samples.
    sampler=RandomSampler(train_dataset),  # Select batches randomly
    batch_size=batch_size  # Trains with this batch size.
)

# For validation the order doesn't matter, so we'll just read them sequentially.
validation_dataloader = DataLoader(
    val_dataset,  # The validation samples.
    sampler=SequentialSampler(val_dataset),  # Pull out batches sequentially.
    batch_size=batch_size  # Evaluate with this batch size.
)
"""setup model, optimizer"""
# configuration = BertConfig.from_pretrained('bert-base-cased')
# print(configuration)
# quit()
pretrain_config = BertConfig.get_config_dict('bert-base-cased')[0]
pretrain_config['entity_type_size'] = len(entity_type_dict)
pretrain_config['role_type_size'] = len(role_type_dict)
pretrain_config['class_size'] = len(event_type_dict)
pretrain_config['chunk_size_feed_forward'] = 0
pretrain_config['add_cross_attention'] = False
pretrain_config['use_return_dict'] = True
pretrain_config['output_hidden_states'] = True
pretrain_config['num_labels'] = len(event_type_dict)
print(len(event_type_dict))

configuration = BertConfig.from_dict(pretrain_config)
configuration.update(pretrain_config)

# print(configuration)
# quit()
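
Example #5 stops before the model itself is built; a minimal continuation sketch, using a plain BertModel purely as a stand-in for the custom event-extraction model (the real model class is not shown in this example):

from transformers import BertModel

# BertModel is only a stand-in here; the original script presumably wraps it
# in a task-specific head that reads the extra config fields.
model = BertModel.from_pretrained('bert-base-cased', config=configuration)

The extra keys added above (entity_type_size, role_type_size, class_size, ...) are stored as plain attributes on the config, so a custom head can read them back as, e.g., configuration.entity_type_size.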
Example #6
train_dataloader = DataLoader(
    train_dataset,  # The training samples.
    sampler=RandomSampler(train_dataset),  # Select batches randomly
    batch_size=batch_size  # Trains with this batch size.
)

# For validation the order doesn't matter, so we'll just read them sequentially.
validation_dataloader = DataLoader(
    val_dataset,  # The validation samples.
    sampler=SequentialSampler(val_dataset),  # Pull out batches sequentially.
    batch_size=batch_size  # Evaluate with this batch size.
)


"""setup model, optimizer"""
# configuration = BertConfig.from_pretrained('bert-base-cased')
# print(configuration)
# quit()
pretrain_config = BertConfig.get_config_dict(pretrained_model_name)[0]
# pretrain_config = BertConfig.get_config_dict('bert-large-cased-whole-word-masking')[0]
pretrain_config['entity_type_size'] = len(entity_type_dict)
# pretrain_config['event_type_size'] = len(event_type_dict)
pretrain_config['role_type_size'] = len(role_type_dict)
pretrain_config['class_size'] = len(event_type_dict)
pretrain_config['chunk_size_feed_forward'] = 0
pretrain_config['add_cross_attention'] = False
pretrain_config['use_return_dict'] = True
pretrain_config['output_hidden_states'] = True
pretrain_config['num_labels'] = len(event_type_dict)
print('event_type_dict length:', len(event_type_dict))


configuration = BertConfig.from_dict(pretrain_config)
configuration.update(pretrain_config)
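
Example #6 likewise ends before the "optimizer" half of its setup block; a hedged sketch of a typical continuation (the learning rate, epoch count, warmup setting, and the use of torch's AdamW are assumptions, not taken from the original):

import torch
from transformers import BertModel, get_linear_schedule_with_warmup

# As in the sketch after Example #5, BertModel stands in for the custom model.
model = BertModel.from_pretrained(pretrained_model_name, config=configuration)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8)

epochs = 4  # placeholder value
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=total_steps)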