def load_dpr_model(self):
    model = DPRReader(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
    print("Loading DPR reader from {}".format(self.src_file))
    saved_state = load_states_from_checkpoint(self.src_file)
    state_dict = {}
    for key, value in saved_state.model_dict.items():
        # Nest plain BERT weights under "encoder.bert_model."; projection weights keep their name.
        if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"):
            key = "encoder.bert_model." + key[len("encoder."):]
        state_dict[key] = value
    model.span_predictor.load_state_dict(state_dict)
    return model
def load_dpr_model(self):
    model = DPRQuestionEncoder(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
    print("Loading DPR biencoder from {}".format(self.src_file))
    saved_state = load_states_from_checkpoint(self.src_file)
    encoder, prefix = model.question_encoder, "question_model."
    state_dict = {}
    for key, value in saved_state.model_dict.items():
        if key.startswith(prefix):
            # Strip the biencoder prefix and nest plain BERT weights under "bert_model.".
            key = key[len(prefix):]
            if not key.startswith("encode_proj."):
                key = "bert_model." + key
            state_dict[key] = value
    encoder.load_state_dict(state_dict)
    return model
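# Hedged usage sketch (not part of the original snippet): assuming the converted
# DPRQuestionEncoder was saved with save_pretrained() to a hypothetical local
# directory "./dpr_question_encoder", this shows the standard transformers API
# for embedding a question with it. The tokenizer is loaded from the public
# facebook checkpoint, since the conversion above only produces model weights.
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
question_encoder = DPRQuestionEncoder.from_pretrained("./dpr_question_encoder")  # assumed save dir
inputs = tokenizer("who wrote the origin of species?", return_tensors="pt")
embedding = question_encoder(**inputs).pooler_output  # (1, hidden_size) dense question vector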
def load_dpr_model(self):
    model = BertModel(BertConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
    print("Loading DPR biencoder from {}".format(self.src_file))
    saved_state = load_states_from_checkpoint(self.src_file)
    encoder, prefix = model, "model."
    # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
    state_dict = {"embeddings.position_ids": model.embeddings.position_ids}
    for key, value in saved_state.model_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
            if not key.startswith("encode_proj."):
                key = "" + key  # no extra prefix needed when loading into a bare BertModel
            state_dict[key] = value
    encoder.load_state_dict(state_dict)
    return model
def load_dpr_model(self):
    model = DPRReader(DPRConfig(**BertConfig.get_config_dict("bert-base-uncased")[0]))
    print("Loading DPR reader from {}".format(self.src_file))
    saved_state = load_states_from_checkpoint(self.src_file)
    # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
    state_dict = {
        "encoder.bert_model.embeddings.position_ids": model.span_predictor.encoder.bert_model.embeddings.position_ids
    }
    for key, value in saved_state.model_dict.items():
        if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"):
            key = "encoder.bert_model." + key[len("encoder."):]
        state_dict[key] = value
    model.span_predictor.load_state_dict(state_dict)
    return model
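# Hedged driver sketch (not from the original): in the upstream conversion
# script these load_dpr_model() variants are methods of small state-holder
# classes, and the converted model is written out with the standard
# save_pretrained() API. The names "dpr_state" and "dest_dir" are assumptions.
model = dpr_state.load_dpr_model()
model.save_pretrained(dest_dir)
model.from_pretrained(dest_dir)  # sanity check: the saved weights reload cleanly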
train_dataloader = DataLoader(
    train_dataset,  # The training samples.
    sampler=RandomSampler(train_dataset),  # Select batches randomly.
    batch_size=batch_size  # Trains with this batch size.
)

# For validation the order doesn't matter, so we'll just read them sequentially.
validation_dataloader = DataLoader(
    val_dataset,  # The validation samples.
    sampler=SequentialSampler(val_dataset),  # Pull out batches sequentially.
    batch_size=batch_size  # Evaluate with this batch size.
)

"""setup model, optimizer"""
# configuration = BertConfig.from_pretrained('bert-base-cased')
# print(configuration)
# quit()
pretrain_config = BertConfig.get_config_dict('bert-base-cased')[0]
pretrain_config['entity_type_size'] = len(entity_type_dict)
pretrain_config['role_type_size'] = len(role_type_dict)
pretrain_config['class_size'] = len(event_type_dict)
pretrain_config['chunk_size_feed_forward'] = 0
pretrain_config['add_cross_attention'] = False
pretrain_config['use_return_dict'] = True
pretrain_config['output_hidden_states'] = True
pretrain_config['num_labels'] = len(event_type_dict)
print('event_type_dict length:', len(event_type_dict))
configuration = BertConfig.from_dict(pretrain_config)
configuration.update(pretrain_config)
# print(configuration)
# quit()
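# Hedged sketch (not in the original): BertConfig.from_dict() keeps unknown keys
# as plain attributes, which is how the custom sizes injected above can be read
# back inside a model. "EventExtractor" is a hypothetical consumer class.
import torch

class EventExtractor(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        # Custom fields set via pretrain_config are available as config attributes.
        self.entity_embed = torch.nn.Embedding(config.entity_type_size, config.hidden_size)
        self.classifier = torch.nn.Linear(config.hidden_size, config.class_size)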
train_dataloader = DataLoader(
    train_dataset,  # The training samples.
    sampler=RandomSampler(train_dataset),  # Select batches randomly.
    batch_size=batch_size  # Trains with this batch size.
)

# For validation the order doesn't matter, so we'll just read them sequentially.
validation_dataloader = DataLoader(
    val_dataset,  # The validation samples.
    sampler=SequentialSampler(val_dataset),  # Pull out batches sequentially.
    batch_size=batch_size  # Evaluate with this batch size.
)

"""setup model, optimizer"""
# configuration = BertConfig.from_pretrained('bert-base-cased')
# print(configuration)
# quit()
pretrain_config = BertConfig.get_config_dict(pretrained_model_name)[0]
# pretrain_config = BertConfig.get_config_dict('bert-large-cased-whole-word-masking')[0]
pretrain_config['entity_type_size'] = len(entity_type_dict)
# pretrain_config['event_type_size'] = len(event_type_dict)
pretrain_config['role_type_size'] = len(role_type_dict)
pretrain_config['class_size'] = len(event_type_dict)
pretrain_config['chunk_size_feed_forward'] = 0
pretrain_config['add_cross_attention'] = False
pretrain_config['use_return_dict'] = True
pretrain_config['output_hidden_states'] = True
pretrain_config['num_labels'] = len(event_type_dict)
print('event_type_dict length:', len(event_type_dict))
configuration = BertConfig.from_dict(pretrain_config)
configuration.update(pretrain_config)
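# Hedged continuation sketch (not in the original): completing the
# "setup model, optimizer" step above with a standard AdamW optimizer and a
# linear warmup schedule. "model" and the hyperparameters are assumptions.
import torch
from transformers import get_linear_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8)
total_steps = len(train_dataloader) * num_epochs  # num_epochs is assumed
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)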