def dev(config, bert_config, dev_path, id2rel, tokenizer, output_path=None):
    dev_data = json.load(open(dev_path))
    for sent in dev_data:
        data.to_tuple(sent)
    with torch.no_grad():
        Bert_model = BertModel(bert_config).to(device).eval()
        submodel = sub_model(config).to(device).eval()
        objmodel = obj_model(config).to(device).eval()
        state = torch.load(
            os.path.join(config.output_dir, config.load_model_name))
        Bert_model.load_state_dict(state['bert_state_dict'])
        submodel.load_state_dict(state['subject_state_dict'])
        objmodel.load_state_dict(state['object_state_dict'])
        precision, recall, f1 = utils.metric(Bert_model, submodel, objmodel,
                                             dev_data, id2rel, tokenizer,
                                             output_path=output_path)
        logger.info('precision: %.4f' % precision)
        logger.info('recall: %.4f' % recall)
        logger.info('F1: %.4f' % f1)
def init_bert_model_with_teacher(
    student: BertModel,
    teacher: BertModel,
    layers_to_transfer: List[int] = None,
) -> BertModel:
    """Initialize student model with teacher layers.

    Args:
        student (BertModel): Student model.
        teacher (BertModel): Teacher model.
        layers_to_transfer (List[int], optional): Defines which layers will be
            transferred. If None, the last layers of the teacher are used.
            Defaults to None.

    Returns:
        BertModel: Student model initialized with the selected teacher layers.
    """
    teacher_hidden_size = teacher.config.hidden_size
    student_hidden_size = student.config.hidden_size
    if teacher_hidden_size != student_hidden_size:
        raise Exception("Teacher and student hidden size should be the same")
    teacher_layers_num = teacher.config.num_hidden_layers
    student_layers_num = student.config.num_hidden_layers
    if layers_to_transfer is None:
        layers_to_transfer = list(
            range(teacher_layers_num - student_layers_num, teacher_layers_num))
    prefix_teacher = list(teacher.state_dict().keys())[0].split(".")[0]
    prefix_student = list(student.state_dict().keys())[0].split(".")[0]
    student_sd = _extract_layers(
        teacher_model=teacher,
        layers=layers_to_transfer,
    )
    student.load_state_dict(student_sd)
    return student
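# _extract_layers is referenced above but not defined in this snippet. The
# sketch below is a plausible reconstruction under the assumption that both
# models follow the standard HuggingFace BertModel key layout; it is not the
# original helper. It remaps the selected teacher encoder layers onto layers
# 0..N-1 of the student and copies everything outside the encoder stack as-is.
def _extract_layers(teacher_model: BertModel, layers: List[int]) -> dict:
    teacher_sd = teacher_model.state_dict()
    student_sd = {}
    for key, value in teacher_sd.items():
        if key.startswith("encoder.layer."):
            parts = key.split(".")
            layer_idx = int(parts[2])
            if layer_idx in layers:
                # teacher layer `layer_idx` becomes student layer `layers.index(layer_idx)`
                parts[2] = str(layers.index(layer_idx))
                student_sd[".".join(parts)] = value
        else:
            # embeddings, pooler, etc. are copied unchanged
            student_sd[key] = value
    return student_sd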
def get_kobert_model(model_file, vocab_file, ctx="cpu"):
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    bertmodel.load_state_dict(torch.load(model_file), strict=False)
    device = torch.device(ctx)
    bertmodel.to(device)
    bertmodel.eval()
    vocab_b_obj = nlp.vocab.BERTVocab.from_json(open(vocab_file, 'rt').read())
    return bertmodel, vocab_b_obj
def get_kobert_model(model_file, vocab_file, ctx="cpu"):
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    bertmodel.load_state_dict(torch.load(model_file))
    device = torch.device(ctx)
    bertmodel.to(device)
    bertmodel.eval()
    vocab_b_obj = nlp.vocab.BERTVocab.from_sentencepiece(vocab_file,
                                                         padding_token='[PAD]')
    return bertmodel, vocab_b_obj
def run(pretrained_model, out_dir, num_layers=3):
    os.makedirs(out_dir, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
    model = BertModel.from_pretrained(pretrained_model, return_dict=True)
    small_config = copy.deepcopy(model.config)
    small_config.num_hidden_layers = num_layers
    small_model = BertModel(small_config)
    small_model.load_state_dict(model.state_dict(), strict=False)
    tokenizer.save_pretrained(out_dir)
    small_model.save_pretrained(out_dir)
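# Hypothetical usage (not part of the original snippet): with strict=False,
# encoder layers 0..num_layers-1 of the small model receive the matching
# weights from the full checkpoint and the remaining layers of the full model
# are reported as unexpected keys and silently dropped.
# run("bert-base-uncased", "./bert-3layer", num_layers=3)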
class BioBert(nn.Module):
    def __init__(self, num_labels, config, state_dict):
        super().__init__()
        self.bert = BertModel(config)
        self.bert.load_state_dict(state_dict, strict=False)
        self.dropout = nn.Dropout(p=0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask):
        # https://huggingface.co/transformers/model_doc/bert.html#bertmodel
        # last_hidden_state: sequence of hidden states at the output of the last layer of the model.
        # pooler_output: last-layer hidden state of the first token of the sequence (classification
        # token), further processed by a Linear layer and a Tanh activation function.
        last_hidden_state, pooler_output = self.bert(
            input_ids=input_ids, attention_mask=attention_mask)
        output = self.dropout(pooler_output)
        out = self.classifier(output)
        return out
class BioBertNER(nn.Module):
    def __init__(self, num_labels, config, state_dict):
        super().__init__()
        self.bert = BertModel(config)
        self.bert.load_state_dict(state_dict, strict=False)
        self.dropout = nn.Dropout(p=0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask):
        encoded_layer, pooled_output = self.bert(input_ids=input_ids,
                                                 attention_mask=attention_mask)
        enlayer = encoded_layer[-1]
        enlayer = self.dropout(enlayer)
        outlayer = self.classifier(enlayer)
        pooled_output = self.dropout(pooled_output)
        out = self.classifier(pooled_output)
        return out, outlayer
def get_bert(BERT_PT_PATH, bert_type, do_lower_case, no_pretraining):
    bert_config_file = os.path.join(BERT_PT_PATH,
                                    f'bert_config_{bert_type}.json')
    vocab_file = os.path.join(BERT_PT_PATH, f'vocab_{bert_type}.txt')
    init_checkpoint = os.path.join(BERT_PT_PATH,
                                   f'pytorch_model_{bert_type}.bin')

    bert_config = BertConfig.from_json_file(bert_config_file)
    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                           do_lower_case=do_lower_case)
    bert_config.print_status()

    model_bert = BertModel(bert_config)
    if no_pretraining:
        pass
    else:
        model_bert.load_state_dict(
            torch.load(init_checkpoint, map_location='cpu'))
        print("Load pre-trained parameters.")
    model_bert.to(device)

    return model_bert, tokenizer, bert_config
def _load_bert(self, bert_config_path: str, bert_model_path: str):
    bert_config = BertConfig.from_json_file(bert_config_path)
    model = BertModel(bert_config)
    if self.cuda:
        model_states = torch.load(bert_model_path)
    else:
        model_states = torch.load(bert_model_path, map_location='cpu')
    # Fix model_states: strip the "bert." prefix, drop the "cls" head, and
    # rename the legacy beta/gamma LayerNorm parameters to bias/weight.
    for k in list(model_states.keys()):
        if k.startswith("bert."):
            model_states[k[5:]] = model_states.pop(k)
            k = k[5:]
        elif k.startswith("cls"):
            model_states.pop(k)
            continue
        if k.endswith("beta"):
            model_states[k[:-4] + "bias"] = model_states.pop(k)
        if k.endswith("gamma"):
            model_states[k[:-5] + "weight"] = model_states.pop(k)
    model.load_state_dict(model_states)
    if self.cuda:
        model.cuda()
    model.eval()
    return model
def get_pretrained_model(path, logger, args=None):
    logger.info('load pretrained model in {}'.format(path))
    bert_tokenizer = BertTokenizer.from_pretrained(path)
    if args is None or args.hidden_layers == 12:
        bert_config = BertConfig.from_pretrained(path)
        bert_model = BertModel.from_pretrained(path)
    else:
        logger.info('load {} layers bert'.format(args.hidden_layers))
        bert_config = BertConfig.from_pretrained(
            path, num_hidden_layers=args.hidden_layers)
        bert_model = BertModel(bert_config)
        model_param_list = [p[0] for p in bert_model.named_parameters()]
        load_dict = torch.load(os.path.join(path, 'pytorch_model.bin'))
        new_load_dict = {}
        for k, v in load_dict.items():
            k = k.replace('bert.', '')
            if k in model_param_list:
                new_load_dict[k] = v
        new_load_dict['embeddings.position_ids'] = torch.tensor(
            [i for i in range(512)]).unsqueeze(dim=0)
        bert_model.load_state_dict(new_load_dict)
    logger.info('load complete')
    return bert_config, bert_tokenizer, bert_model
class RenamingModelHybrid(nn.Module):
    def __init__(self, vocab, top_k, config, device):
        super(RenamingModelHybrid, self).__init__()
        self.vocab = vocab
        self.top_k = top_k
        self.source_vocab_size = len(self.vocab.source_tokens) + 1
        self.graph_encoder = GraphASTEncoder.build(
            config['encoder']['graph_encoder'])
        self.graph_emb_size = config['encoder']['graph_encoder']['gnn'][
            'hidden_size']
        self.emb_size = 256

        state_dict = torch.load(
            'saved_checkpoints/bert_2604/bert_pretrained_epoch_23_batch_140000.pth',
            map_location=device)
        keys_to_delete = [
            "cls.predictions.bias", "cls.predictions.transform.dense.weight",
            "cls.predictions.transform.dense.bias",
            "cls.predictions.transform.LayerNorm.weight",
            "cls.predictions.transform.LayerNorm.bias",
            "cls.predictions.decoder.weight", "cls.predictions.decoder.bias",
            "cls.seq_relationship.weight", "cls.seq_relationship.bias"
        ]
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict['model'].items():
            if k in keys_to_delete:
                continue
            name = k[5:]  # remove `bert.`
            new_state_dict[name] = v
        bert_config = BertConfig(vocab_size=self.source_vocab_size,
                                 max_position_embeddings=512,
                                 num_hidden_layers=6,
                                 hidden_size=self.emb_size,
                                 num_attention_heads=4)
        self.bert_encoder = BertModel(bert_config)
        self.bert_encoder.load_state_dict(new_state_dict)

        self.target_vocab_size = len(self.vocab.all_subtokens) + 1
        bert_config = BertConfig(vocab_size=self.target_vocab_size,
                                 max_position_embeddings=1000,
                                 num_hidden_layers=6,
                                 hidden_size=self.emb_size,
                                 num_attention_heads=4,
                                 is_decoder=True)
        self.bert_decoder = BertModel(bert_config)
        state_dict = torch.load(
            'saved_checkpoints/bert_0905/bert_decoder_epoch_19_batch_220000.pth',
            map_location=device)
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict['model'].items():
            if k in keys_to_delete:
                continue
            if 'crossattention' in k:
                continue
            name = k[5:]  # remove `bert.`
            new_state_dict[name] = v
        for key in new_state_dict:
            self.bert_decoder.state_dict()[key].copy_(new_state_dict[key])

        self.enc_graph_map = nn.Linear(self.emb_size + self.graph_emb_size,
                                       self.emb_size)
        self.fc_final = nn.Linear(self.emb_size, self.target_vocab_size)
        self.fc_final.weight.data = state_dict['model'][
            'cls.predictions.decoder.weight']

    def forward(self, src_tokens, src_mask, variable_ids, target_tokens,
                graph_input):
        encoder_attention_mask = torch.ones_like(src_tokens).float().to(
            src_tokens.device)
        encoder_attention_mask[src_tokens == PAD_ID] = 0.0
        assert torch.max(src_tokens) < self.source_vocab_size
        assert torch.min(src_tokens) >= 0
        assert torch.max(target_tokens) < self.target_vocab_size
        assert torch.min(target_tokens) >= 0
        encoder_output = self.bert_encoder(
            input_ids=src_tokens, attention_mask=encoder_attention_mask)[0]
        graph_output = self.graph_encoder(graph_input)
        variable_emb = graph_output['variable_encoding']
        graph_embedding = torch.gather(
            variable_emb, 1, variable_ids.unsqueeze(2).repeat(
                1, 1, variable_emb.shape[2])) * src_mask.unsqueeze(2)
        full_enc_output = self.enc_graph_map(
            torch.cat((encoder_output, graph_embedding), dim=2))
        decoder_attention_mask = torch.ones_like(target_tokens).float().to(
            target_tokens.device)
        decoder_attention_mask[target_tokens == PAD_ID] = 0.0
        decoder_output = self.bert_decoder(
            input_ids=target_tokens,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=full_enc_output,
            encoder_attention_mask=encoder_attention_mask)[0]
        predictions = self.fc_final(decoder_output)
        return predictions

    def predict(self, src_tokens, src_mask, variable_ids, graph_input,
                approx=False):
        end_token = self.vocab.all_subtokens.word2id['</s>']
        start_token = self.vocab.all_subtokens.word2id['<s>']
        batch_size = src_tokens.shape[0]
        encoder_attention_mask = torch.ones_like(src_tokens).float().to(
            src_tokens.device)
        encoder_attention_mask[src_tokens == PAD_ID] = 0.0
        assert torch.max(src_tokens) < self.source_vocab_size
        assert torch.min(src_tokens) >= 0
        encoder_output = self.bert_encoder(
            input_ids=src_tokens, attention_mask=encoder_attention_mask)[0]
        graph_output = self.graph_encoder(graph_input)
        variable_emb = graph_output['variable_encoding']
        graph_embedding = torch.gather(
            variable_emb, 1, variable_ids.unsqueeze(2).repeat(
                1, 1, variable_emb.shape[2])) * src_mask.unsqueeze(2)
        full_enc_output = self.enc_graph_map(
            torch.cat((encoder_output, graph_embedding), dim=2))
        source_vocab_to_target = {
            self.vocab.source_tokens.word2id[t]:
            self.vocab.all_subtokens.word2id[t]
            for t in self.vocab.source_tokens.word2id.keys()
        }
        src_target_maps = []
        confidences = []
        for i in range(batch_size):
            if src_tokens[i][0] != start_token:
                input_sequence = torch.zeros(src_tokens.shape[1] + 1).to(
                    src_tokens.device)
                input_mask = torch.zeros(src_mask.shape[1] + 1).to(
                    src_mask.device)
                input_sequence[1:] = src_tokens[i]
                input_mask[1:] = src_mask[i]
            else:
                input_sequence = src_tokens[i]
                input_mask = src_mask[i]
            num_vars = int(input_mask.sum())
            seq_len = torch.sum((input_sequence != PAD_ID).long())
            generated_seqs = torch.zeros(1, min(
                seq_len + 10 * num_vars, 1000)).long().to(src_tokens.device)
            source_marker = 0
            gen_markers = torch.LongTensor([0]).to(generated_seqs.device)
            prior_probs = torch.FloatTensor([0]).to(generated_seqs.device)
            candidate_maps = [{}]
            for _ in range(num_vars):
                # Filling up the known (non-identifier) tokens
                while source_marker < seq_len and input_mask[
                        source_marker] != 1:
                    token = input_sequence[source_marker]
                    values = source_vocab_to_target[token.item(
                    )] * torch.ones_like(gen_markers).to(generated_seqs.device)
                    generated_seqs = torch.scatter(generated_seqs, 1,
                                                   gen_markers.unsqueeze(1),
                                                   values.unsqueeze(1))
                    source_marker += 1
                    gen_markers += 1
                if source_marker >= seq_len:
                    break
                curr_var = input_sequence[source_marker].item()
                if curr_var in candidate_maps[0]:
                    if approx is True:
                        source_marker += 1
                        continue
                    # If we've seen this variable before, just use the previous
                    # predictions and update the scores.
                    # Note - it's enough to check candidate_maps[0] because if it
                    # is in the first map, it is in all of them.
                    orig_markers = gen_markers.clone()
                    for j in range(len(candidate_maps)):
                        pred = candidate_maps[j][curr_var]
                        generated_seqs[j][gen_markers[j]:gen_markers[j] +
                                          len(pred)] = torch.LongTensor(
                                              pred).to(generated_seqs.device)
                        gen_markers[j] += len(pred)
                    decoder_attention_mask = torch.ones_like(
                        generated_seqs).float().to(generated_seqs.device)
                    decoder_attention_mask[generated_seqs == PAD_ID] = 0.0
                    decoder_output = self.bert_decoder(
                        input_ids=generated_seqs,
                        attention_mask=decoder_attention_mask,
                        encoder_hidden_states=full_enc_output[i].unsqueeze(0),
                        encoder_attention_mask=encoder_attention_mask[i].unsqueeze(0))[0]
                    probabilities = F.log_softmax(
                        self.fc_final(decoder_output), dim=-1)
                    # Add up the scores of the token at the __next__ time step
                    scores = torch.zeros(generated_seqs.shape[0]).to(
                        generated_seqs.device)
                    active = torch.ones(generated_seqs.shape[0]).long().to(
                        generated_seqs.device)
                    temp_markers = orig_markers
                    while torch.sum(active) != 0:
                        position_probs = torch.gather(
                            probabilities, 1,
                            (temp_markers - 1).reshape(-1, 1, 1).repeat(
                                1, 1, probabilities.shape[2])).squeeze(1)
                        curr_tokens = torch.gather(generated_seqs, 1,
                                                   temp_markers.unsqueeze(1))
                        tok_probs = torch.gather(position_probs, 1,
                                                 curr_tokens).squeeze(1)
                        tok_probs *= active
                        scores += tok_probs
                        active *= (temp_markers != (gen_markers - 1)).long()
                        temp_markers += active
                    # Update the prior probabilities
                    prior_probs = prior_probs + scores
                else:
                    # You encounter a new variable which hasn't been seen before.
                    # Generate <beam_width> possibilities for its name.
                    generated_seqs, gen_markers, prior_probs, candidate_maps = self.beam_search(
                        generated_seqs,
                        gen_markers,
                        prior_probs,
                        candidate_maps,
                        curr_var,
                        full_enc_output[i].unsqueeze(0),
                        encoder_attention_mask[i].unsqueeze(0),
                        beam_width=5,
                        top_k=self.top_k)
                source_marker += 1
            final_ind = torch.argmax(prior_probs)
            confidence = torch.max(prior_probs).item()
            src_target_map = candidate_maps[final_ind]
            src_target_maps.append(src_target_map)
            confidences.append(confidence)
        return src_target_maps, confidences

    def beam_search(self, generated_seqs, gen_markers, prior_probs,
                    candidate_maps, curr_var, full_enc_output,
                    encoder_attention_mask, beam_width=5, top_k=10):
        if generated_seqs.shape[0] * beam_width < top_k:
            beam_width = top_k
        active = torch.ones_like(gen_markers).to(gen_markers.device)
        beam_alpha = 0.7
        end_token = self.vocab.all_subtokens.word2id['</s>']
        candidate_maps = candidate_maps
        orig_markers = gen_markers.clone()
        for _ in range(10):  # Predict at most 10 subtokens
            decoder_attention_mask = torch.ones_like(
                generated_seqs).float().to(generated_seqs.device)
            decoder_attention_mask[generated_seqs == PAD_ID] = 0.0
            decoder_output = self.bert_decoder(
                input_ids=generated_seqs,
                attention_mask=decoder_attention_mask,
                encoder_hidden_states=full_enc_output,
                encoder_attention_mask=encoder_attention_mask)[0]
            probabilities = F.log_softmax(self.fc_final(decoder_output),
                                          dim=-1)
            # Gather the predictions at the current markers.
            # (gen_markers - 1) because prediction happens one step ahead.
            probabilities = torch.gather(
                probabilities, 1, (gen_markers - 1).reshape(-1, 1, 1).repeat(
                    1, 1, probabilities.shape[2])).squeeze(1)
            probs, preds = probabilities.sort(dim=-1, descending=True)
            probs *= active.unsqueeze(1)  # Set log prob of non-active ones to 0
            # Set preds of non-active ones to the end token (i.e., remain unchanged)
            preds[active == 0] = end_token
            # Repeat active ones only once. Repeat the rest beam_width no. of times.
            filter_mask = torch.ones(
                (preds.shape[0], beam_width)).long().to(preds.device)
            filter_mask *= active.unsqueeze(1)
            filter_mask[:, 0][active == 0] = 1
            filter_mask = filter_mask.reshape(-1)
            preds = preds[:, :beam_width].reshape(-1)[filter_mask == 1]
            probs = probs[:, :beam_width].reshape(-1)[filter_mask == 1]
            generated_seqs = torch.repeat_interleave(generated_seqs,
                                                     beam_width,
                                                     dim=0)[filter_mask == 1]
            orig_markers = torch.repeat_interleave(orig_markers,
                                                   beam_width,
                                                   dim=0)[filter_mask == 1]
            gen_markers = torch.repeat_interleave(gen_markers,
                                                  beam_width,
                                                  dim=0)[filter_mask == 1]
            active = torch.repeat_interleave(active, beam_width,
                                             dim=0)[filter_mask == 1]
            prior_probs = torch.repeat_interleave(prior_probs,
                                                  beam_width,
                                                  dim=0)[filter_mask == 1]
            candidate_maps = [
                item.copy() for item in candidate_maps
                for _ in range(beam_width)
            ]
            candidate_maps = [
                candidate_maps[i] for i in range(len(candidate_maps))
                if filter_mask[i] == 1
            ]
            generated_seqs.scatter_(1, gen_markers.unsqueeze(1),
                                    preds.unsqueeze(1))
            # lengths = (gen_markers - gen_marker + 1).float()
            # penalties = torch.pow(5 + lengths, beam_alpha) / math.pow(6, beam_alpha)
            penalties = torch.ones_like(probs).to(probs.device)
            updated_probs = probs + prior_probs
            sort_inds = (updated_probs / penalties).argsort(descending=True)
            updated_probs = updated_probs[sort_inds]
            prior_probs = updated_probs[:top_k]
            new_preds = preds[sort_inds[:top_k]]
            generated_seqs = generated_seqs[sort_inds[:top_k]]
            gen_markers = gen_markers[sort_inds[:top_k]]
            active = active[sort_inds[:top_k]]
            orig_markers = orig_markers[sort_inds[:top_k]]
            candidate_maps = [
                candidate_maps[ind.item()] for ind in sort_inds[:top_k]
            ]
            active = active * (new_preds != end_token).long()
            gen_markers += active
            if torch.sum(active) == 0:
                break
        # gen_markers are pointing at the end_token. Move them one ahead.
        gen_markers += 1
        assert generated_seqs.shape[0] == top_k
        for i in range(top_k):
            candidate_maps[i][curr_var] = generated_seqs[i][
                orig_markers[i]:gen_markers[i]].cpu().tolist()
        return generated_seqs, gen_markers, prior_probs, candidate_maps
class SequentialEncoder(Encoder):
    def __init__(self, config):
        super().__init__()
        self.vocab = vocab = Vocab.load(config['vocab_file'])
        self.src_word_embed = nn.Embedding(len(vocab.source_tokens),
                                           config['source_embedding_size'])
        self.config = config
        self.decoder_cell_init = nn.Linear(config['source_encoding_size'],
                                           config['decoder_hidden_size'])

        if self.config['transformer'] == 'none':
            dropout = config['dropout']
            self.lstm_encoder = nn.LSTM(input_size=self.src_word_embed.embedding_dim,
                                        hidden_size=config['source_encoding_size'] // 2,
                                        num_layers=config['num_layers'],
                                        batch_first=True,
                                        bidirectional=True,
                                        dropout=dropout)
            self.dropout = nn.Dropout(dropout)
        elif self.config['transformer'] == 'bert':
            self.vocab_size = len(self.vocab.source_tokens) + 1
            state_dict = torch.load('saved_checkpoints/bert_2604/bert_pretrained_epoch_23_batch_140000.pth')
            keys_to_delete = ["cls.predictions.bias",
                              "cls.predictions.transform.dense.weight",
                              "cls.predictions.transform.dense.bias",
                              "cls.predictions.transform.LayerNorm.weight",
                              "cls.predictions.transform.LayerNorm.bias",
                              "cls.predictions.decoder.weight",
                              "cls.predictions.decoder.bias",
                              "cls.seq_relationship.weight",
                              "cls.seq_relationship.bias"]
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict['model'].items():
                if k in keys_to_delete:
                    continue
                name = k[5:]  # remove `bert.`
                new_state_dict[name] = v
            bert_config = BertConfig(vocab_size=self.vocab_size,
                                     max_position_embeddings=512,
                                     num_hidden_layers=6,
                                     hidden_size=256,
                                     num_attention_heads=4)
            self.bert_model = BertModel(bert_config)
            self.bert_model.load_state_dict(new_state_dict)
        elif self.config['transformer'] == 'xlnet':
            self.vocab_size = len(self.vocab.source_tokens) + 1
            state_dict = torch.load('saved_checkpoints/xlnet_2704/xlnet1_pretrained_epoch_13_iter_500000.pth')
            keys_to_delete = ["lm_loss.weight", "lm_loss.bias"]
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict['model'].items():
                if k in keys_to_delete:
                    continue
                if k[:12] == 'transformer.':
                    name = k[12:]
                else:
                    name = k
                new_state_dict[name] = v
            xlnet_config = XLNetConfig(vocab_size=self.vocab_size, d_model=256, n_layer=12)
            self.xlnet_model = XLNetModel(xlnet_config)
            self.xlnet_model.load_state_dict(new_state_dict)
        else:
            print("Error! Unknown transformer type '{}'".format(self.config['transformer']))

    @property
    def device(self):
        return self.src_word_embed.weight.device

    @classmethod
    def default_params(cls):
        return {
            'source_encoding_size': 256,
            'decoder_hidden_size': 128,
            'source_embedding_size': 128,
            'vocab_file': None,
            'num_layers': 1
        }

    @classmethod
    def build(cls, config):
        params = util.update(SequentialEncoder.default_params(), config)
        return cls(params)

    def forward(self, tensor_dict: Dict[str, torch.Tensor]):
        if self.config['transformer'] == 'bert':
            code_token_encoding, code_token_mask = self.encode_bert(tensor_dict['src_code_tokens'])
        elif self.config['transformer'] == 'xlnet':
            code_token_encoding, code_token_mask = self.encode_xlnet(tensor_dict['src_code_tokens'])
        elif self.config['transformer'] == 'none':
            code_token_encoding, code_token_mask, (last_states, last_cells) = self.encode_sequence(tensor_dict['src_code_tokens'])
        else:
            print("Error! Unknown transformer type '{}'".format(self.config['transformer']))

        # (batch_size, max_variable_mention_num)
        # variable_mention_positions = tensor_dict['variable_position']
        variable_mention_mask = tensor_dict['variable_mention_mask']
        variable_mention_to_variable_id = tensor_dict['variable_mention_to_variable_id']

        # (batch_size, max_variable_num)
        variable_encoding_mask = tensor_dict['variable_encoding_mask']
        variable_mention_num = tensor_dict['variable_mention_num']

        # # (batch_size, max_variable_mention_num, encoding_size)
        # variable_mention_encoding = torch.gather(code_token_encoding, 1, variable_mention_positions.unsqueeze(-1).expand(-1, -1, code_token_encoding.size(-1))) * variable_mention_positions_mask
        max_time_step = variable_mention_to_variable_id.size(1)
        variable_num = variable_mention_num.size(1)
        encoding_size = code_token_encoding.size(-1)

        variable_mention_encoding = code_token_encoding * variable_mention_mask.unsqueeze(-1)
        variable_encoding = torch.zeros(tensor_dict['batch_size'], variable_num, encoding_size, device=self.device)
        variable_encoding.scatter_add_(1,
                                       variable_mention_to_variable_id.unsqueeze(-1).expand(-1, -1, encoding_size),
                                       variable_mention_encoding) * variable_encoding_mask.unsqueeze(-1)
        variable_encoding = variable_encoding / (variable_mention_num + (1. - variable_encoding_mask) * nn_util.SMALL_NUMBER).unsqueeze(-1)

        if self.config['transformer'] == 'bert' or self.config['transformer'] == 'xlnet':
            context_encoding = dict(
                variable_encoding=variable_encoding,
                code_token_encoding=code_token_encoding,
                code_token_mask=code_token_mask
            )
        else:
            context_encoding = dict(
                variable_encoding=variable_encoding,
                code_token_encoding=code_token_encoding,
                code_token_mask=code_token_mask,
                last_states=last_states,
                last_cells=last_cells
            )

        context_encoding.update(tensor_dict)
        return context_encoding

    def encode_xlnet(self, input_ids):
        attention_mask = torch.ones_like(input_ids).float()
        attention_mask[input_ids == PAD_ID] = 0.0
        assert torch.max(input_ids) < self.vocab_size
        assert torch.min(input_ids) >= 0
        if torch.cuda.is_available():
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
        outputs = self.xlnet_model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs[0], attention_mask

    def encode_bert(self, input_ids):
        attention_mask = torch.ones_like(input_ids).float()
        attention_mask[input_ids == PAD_ID] = 0.0
        assert torch.max(input_ids) < self.vocab_size
        assert torch.min(input_ids) >= 0
        if torch.cuda.is_available():
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs[0], attention_mask

    def encode_sequence(self, code_sequence):
        # (batch_size, max_code_length)
        # code_sequence = tensor_dict['src_code_tokens']

        # (batch_size, max_code_length, embed_size)
        code_token_embedding = self.src_word_embed(code_sequence)

        # (batch_size, max_code_length)
        code_token_mask = torch.ne(code_sequence, PAD_ID).float()
        # (batch_size)
        code_sequence_length = code_token_mask.sum(dim=-1).long()

        sorted_seqs, sorted_seq_lens, restoration_indices, sorting_indices = nn_util.sort_batch_by_length(code_token_embedding, code_sequence_length)

        packed_question_embedding = pack_padded_sequence(sorted_seqs, sorted_seq_lens.data.tolist(), batch_first=True)

        sorted_encodings, (last_states, last_cells) = self.lstm_encoder(packed_question_embedding)
        sorted_encodings, _ = pad_packed_sequence(sorted_encodings, batch_first=True)

        # apply dropout to the last layer
        # (batch_size, seq_len, hidden_size * 2)
        sorted_encodings = self.dropout(sorted_encodings)

        # (batch_size, question_len, hidden_size * 2)
        restored_encodings = sorted_encodings.index_select(dim=0, index=restoration_indices)

        # (num_layers, direction_num, batch_size, hidden_size)
        last_states = last_states.view(self.lstm_encoder.num_layers, 2, -1, self.lstm_encoder.hidden_size)
        last_states = last_states.index_select(dim=2, index=restoration_indices)
        last_cells = last_cells.view(self.lstm_encoder.num_layers, 2, -1, self.lstm_encoder.hidden_size)
        last_cells = last_cells.index_select(dim=2, index=restoration_indices)

        return restored_encodings, code_token_mask, (last_states, last_cells)

    @classmethod
    def to_tensor_dict(cls, examples: List[Example], next_examples=None, flips=None) -> Dict[str, torch.Tensor]:
        if next_examples is not None:
            max_time_step = max(e.source_seq_length + n.source_seq_length for e, n in zip(examples, next_examples))
        else:
            max_time_step = max(e.source_seq_length for e in examples)
        input = np.zeros((len(examples), max_time_step), dtype=np.int64)
        if next_examples is not None:
            seq_mask = torch.zeros((len(examples), max_time_step), dtype=torch.long)
        else:
            seq_mask = None

        variable_mention_to_variable_id = torch.zeros(len(examples), max_time_step, dtype=torch.long)
        variable_mention_mask = torch.zeros(len(examples), max_time_step)
        variable_mention_num = torch.zeros(len(examples), max(len(e.ast.variables) for e in examples))
        variable_encoding_mask = torch.zeros(variable_mention_num.size())

        for e_id, example in enumerate(examples):
            sub_tokens = example.sub_tokens
            input[e_id, :len(sub_tokens)] = example.sub_token_ids
            if next_examples is not None:
                next_example = next_examples[e_id]
                next_tokens = next_example.sub_tokens
                input[e_id, len(sub_tokens):len(sub_tokens) + len(next_tokens)] = next_example.sub_token_ids
                seq_mask[e_id, len(sub_tokens):] = 1
                # seq_mask[e_id, len(sub_tokens):len(sub_tokens)+len(next_tokens)] = 1

            variable_position_map = dict()
            var_name_to_id = {name: i for i, name in enumerate(example.ast.variables)}
            for i, sub_token in enumerate(sub_tokens):
                if sub_token.startswith('@@') and sub_token.endswith('@@'):
                    old_var_name = sub_token[2:-2]
                    if old_var_name in var_name_to_id:  # sometimes there are strings like `@@@@`
                        var_id = var_name_to_id[old_var_name]
                        variable_mention_to_variable_id[e_id, i] = var_id
                        variable_mention_mask[e_id, i] = 1.
                        variable_position_map.setdefault(old_var_name, []).append(i)

            for var_id, var_name in enumerate(example.ast.variables):
                try:
                    var_pos = variable_position_map[var_name]
                    variable_mention_num[e_id, var_id] = len(var_pos)
                except KeyError:
                    variable_mention_num[e_id, var_id] = 1
                    print(example.binary_file, f'variable [{var_name}] not found', file=sys.stderr)

            variable_encoding_mask[e_id, :len(example.ast.variables)] = 1.

        batch_dict = dict(src_code_tokens=torch.from_numpy(input),
                          variable_mention_to_variable_id=variable_mention_to_variable_id,
                          variable_mention_mask=variable_mention_mask,
                          variable_mention_num=variable_mention_num,
                          variable_encoding_mask=variable_encoding_mask,
                          batch_size=len(examples))
        if next_examples is not None:
            batch_dict['next_seq_mask'] = seq_mask
            batch_dict['next_sentence_label'] = torch.LongTensor(flips)
        return batch_dict

    def get_decoder_init_state(self, context_encoder, config=None):
        if 'last_cells' not in context_encoder:
            if self.config['init_decoder']:
                dec_init_cell = self.decoder_cell_init(torch.mean(context_encoder['code_token_encoding'], dim=1))
                dec_init_state = torch.tanh(dec_init_cell)
            else:
                dec_init_cell = dec_init_state = None
        elif 'last_cells' in context_encoder:
            fwd_last_layer_cell = context_encoder['last_cells'][-1, 0]
            bak_last_layer_cell = context_encoder['last_cells'][-1, 1]
            dec_init_cell = self.decoder_cell_init(torch.cat([fwd_last_layer_cell, bak_last_layer_cell], dim=-1))
            dec_init_state = torch.tanh(dec_init_cell)
        return dec_init_state, dec_init_cell

    def get_attention_memory(self, context_encoding, att_target='terminal_nodes'):
        assert att_target == 'terminal_nodes'
        memory = context_encoding['code_token_encoding']
        mask = context_encoding['code_token_mask']
        return memory, mask
def load_bert(bert_path, device):
    bert_config_path = os.path.join(bert_path, 'config.json')
    bert = BertModel(BertConfig(**load_json(bert_config_path))).to(device)
    bert_model_path = os.path.join(bert_path, 'model.bin')
    bert.load_state_dict(clean_state_dict(torch.load(bert_model_path)))
    return bert
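# clean_state_dict is not defined in this snippet. The helper below is a
# minimal sketch of what such a function typically does, assumed rather than
# taken from the original code: drop prediction-head weights and strip the
# leading `bert.` prefix so the remaining keys line up with a bare BertModel.
def clean_state_dict(state_dict):
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith("cls."):
            continue  # skip the MLM / NSP head
        if key.startswith("bert."):
            key = key[len("bert."):]
        cleaned[key] = value
    return cleaned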
def main():
    parser = argparse.ArgumentParser()
    # 1. Paths to the training and test data
    parser.add_argument("--data_dir", default='./data/cluener', type=str,
                        help="Path to data.")
    parser.add_argument("--type_description",
                        default='./data/cluener/type_des.json', type=str,
                        help="Path to data.")
    # 2. Paths to the pretrained model
    parser.add_argument("--vocab_file", default="./data/pretrain/vocab.txt",
                        type=str, help="Init vocab to resume training from.")
    parser.add_argument("--config_path", default="./data/pretrain/config.json",
                        type=str, help="Init config to resume training from.")
    parser.add_argument("--init_checkpoint",
                        default="./data/pretrain/pytorch_model.bin", type=str,
                        help="Init checkpoint to resume training from.")
    # 3. Checkpoint saving and loading
    parser.add_argument("--save_path", default="./check_points/", type=str,
                        help="Path to save checkpoints.")
    parser.add_argument("--load_path", default=None, type=str,
                        help="Path to load checkpoints.")
    # Training and evaluation parameters
    parser.add_argument("--do_train", default=True, type=bool,
                        help="Whether to perform training.")
    parser.add_argument("--do_eval", default=True, type=bool,
                        help="Whether to perform evaluation on the dev set.")
    parser.add_argument("--do_predict", default=False, type=bool,
                        help="Whether to perform prediction on the test set.")
    parser.add_argument("--do_adv", default=True, type=bool)
    parser.add_argument("--epochs", default=10, type=int,
                        help="Number of epochs for fine-tuning.")
    parser.add_argument("--train_batch_size", default=8, type=int,
                        help="Total examples' number in batch for training.")
    parser.add_argument("--eval_batch_size", default=1, type=int,
                        help="Total examples' number in batch for eval.")
    parser.add_argument("--max_seq_len", default=300, type=int,
                        help="Number of words of the longest sequence.")
    parser.add_argument("--learning_rate", default=1e-5, type=float,
                        help="Learning rate used to train with warmup.")
    parser.add_argument(
        "--warmup_proportion", default=0.01, type=float,
        help="Proportion of training to perform linear learning rate warmup "
             "for. E.g., 0.1 = 10% of training.")
    parser.add_argument("--use_cuda", type=bool, default=True,
                        help="whether to use cuda")
    parser.add_argument("--log_steps", type=int, default=20,
                        help="The steps interval to print loss.")
    parser.add_argument("--eval_step", type=int, default=1000,
                        help="The steps interval to run evaluation.")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    args = parser.parse_args()

    if args.use_cuda:
        device = torch.device("cuda")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cpu")
        n_gpu = 0
    logger.info("device: {}, n_gpu: {}".format(device, n_gpu))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    model_path_postfix = ''
    if args.do_adv:
        model_path_postfix += '_adv'
    args.save_path = os.path.join(args.save_path, 'ner' + model_path_postfix)
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    bert_tokenizer = util.CNerTokenizer.from_pretrained(args.vocab_file)
    bert_config = BertConfig.from_pretrained(args.config_path)
    type2description = json.load(open(args.type_description))

    # Load the datasets
    train_dataset = None
    eval_dataset = None
    if args.do_train:
        logger.info("loading train dataset")
        train_dataset = data_helper.NER_dataset(
            os.path.join(args.data_dir, 'train.json'), bert_tokenizer,
            args.max_seq_len, type2description)
    if args.do_eval:
        logger.info("loading eval dataset")
        eval_dataset = data_helper.NER_dataset(
            os.path.join(args.data_dir, 'dev.json'), bert_tokenizer,
            args.max_seq_len, type2description, shuffle=False)
    if args.do_predict:
        logger.info("loading test dataset")
        test_dataset = data_helper.NER_dataset(
            os.path.join(args.data_dir, 'test.json'), bert_tokenizer,
            args.max_seq_len, type2description, shuffle=False)

    if args.do_train:
        logging.info("Start training!")
        train_helper.train(bert_tokenizer, bert_config, args, train_dataset,
                           eval_dataset)

    if not args.do_train and args.do_eval:
        logging.info("Start evaluating!")
        bert_model = BertModel(config=bert_config)
        span_model = span_type.EntitySpan(config=bert_config)
        state = torch.load(args.load_path)
        bert_model.load_state_dict(state['bert_state_dict'])
        span_model.load_state_dict(state['span_state_dict'])
        logging.info("Checkpoint: %s has been loaded!" % (args.load_path))
        if args.use_cuda:
            bert_model.cuda()
            span_model.cuda()
        model_list = [bert_model, span_model]
        train_helper.evaluate(args, eval_dataset, model_list)

    if args.do_predict:
        logging.info("Start predicting!")
        bert_model = BertModel(config=bert_config)
        span_model = span_type.EntitySpan(config=bert_config)
        state = torch.load(args.load_path)
        bert_model.load_state_dict(state['bert_state_dict'])
        span_model.load_state_dict(state['span_state_dict'])
        logging.info("Checkpoint: %s has been loaded!" % (args.load_path))
        if args.use_cuda:
            bert_model.cuda()
            span_model.cuda()
        model_list = [bert_model, span_model]
        predict_res = train_helper.predict(args, test_dataset, model_list)
bert_config = {
    # (entries preceding 'hidden_act' are truncated in the source)
    'hidden_act': 'gelu',
    'hidden_dropout_prob': 0.1,
    'hidden_size': 768,
    'initializer_range': 0.02,
    'intermediate_size': 3072,
    'max_position_embeddings': 512,
    'num_attention_heads': 12,
    'num_hidden_layers': 12,
    'type_vocab_size': 2,
    'vocab_size': 8002
}

if __name__ == "__main__":
    ctx = "cpu"
    # kobert
    kobert_model_file = "./kobert_resources/pytorch_kobert_2439f391a6.params"
    kobert_vocab_file = "./kobert_resources/kobert_news_wiki_ko_cased-ae5711deb3.spiece"
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    bertmodel.load_state_dict(torch.load(kobert_model_file))
    device = torch.device(ctx)
    bertmodel.to(device)
    # bertmodel.eval()

    # for name, param in bertmodel.named_parameters():
    #     print(name, param.shape)

    for name, param in bertmodel.named_parameters():
        if param.requires_grad:
            print(name, param.shape)
class NERPredict(IPredict):
    '''
    Constructor: initializes the predictor.
    use_gpu: whether to use the GPU
    bert_config_file_name: path to the BERT model config file
    vocab_file_name: path to the vocabulary file
    tags_file_name: path to the tag list file
    bert_model_path: path to load the BERT model from
    lstm_crf_model_path: path to load the CRF model from
    hidden_dim: CRF hidden layer size
    '''

    def __init__(self, use_gpu, bert_config_file_name, vocab_file_name,
                 tags_file_name, bert_model_path, lstm_crf_model_path,
                 hidden_dim):
        self.use_gpu = use_gpu
        self.data_manager_init(vocab_file_name, tags_file_name)
        self.tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
        self.model_init(hidden_dim, bert_config_file_name, bert_model_path,
                        lstm_crf_model_path)

    def data_manager_init(self, vocab_file_name, tags_file_name):
        tags_list = BERTDataManager.ReadTagsList(tags_file_name)
        tags_list = [tags_list]
        self.dm = BERTDataManager(tags_list=tags_list,
                                  vocab_file_name=vocab_file_name)

    def model_init(self, hidden_dim, bert_config_file_name, bert_model_path,
                   lstm_crf_model_path):
        config = BertConfig.from_json_file(bert_config_file_name)
        self.model = BertModel(config)
        bert_dict = torch.load(bert_model_path).module.state_dict()
        self.model.load_state_dict(bert_dict)
        self.birnncrf = torch.load(lstm_crf_model_path)
        self.model.eval()
        self.birnncrf.eval()

    def data_process(self, sentences):
        result = []
        pad_tag = '[PAD]'
        if type(sentences) == str:
            sentences = [sentences]
        max_len = 0
        for sentence in sentences:
            encode = self.tokenizer.encode(sentence, add_special_tokens=True)
            result.append(encode)
            if max_len < len(encode):
                max_len = len(encode)
        for i, sentence in enumerate(result):
            remain = max_len - len(sentence)
            for _ in range(remain):
                result[i].append(self.dm.wordToIdx(pad_tag))
        return torch.tensor(result)

    def pred(self, sentences):
        sentences = self.data_process(sentences)
        if torch.cuda.is_available() and self.use_gpu:
            self.model.cuda()
            self.birnncrf.cuda()
            sentences = sentences.cuda()
        outputs = self.model(input_ids=sentences,
                             attention_mask=sentences.gt(0))
        hidden_states = outputs[0]
        scores, tags = self.birnncrf(hidden_states, sentences.gt(0))
        final_tags = []
        decode_sentences = []
        for item in tags:
            final_tags.append([self.dm.idx_to_tag[tag] for tag in item])
        for item in sentences.tolist():
            decode_sentences.append(self.tokenizer.decode(item))
        return (scores, tags, final_tags, decode_sentences)

    def __call__(self, sentences):
        return self.pred(sentences)
def __init__(self):
    self.src_tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    self.tgt_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    self.tgt_tokenizer.bos_token = '<s>'
    self.tgt_tokenizer.eos_token = '</s>'

    # hidden_size and intermediate_size are both w.r.t. all the attention heads.
    # Should be divisible by num_attention_heads
    encoder_config = BertConfig(vocab_size=self.src_tokenizer.vocab_size,
                                hidden_size=config.hidden_size,
                                num_hidden_layers=config.num_hidden_layers,
                                num_attention_heads=config.num_attention_heads,
                                intermediate_size=config.intermediate_size,
                                hidden_act=config.hidden_act,
                                hidden_dropout_prob=config.dropout_prob,
                                attention_probs_dropout_prob=config.dropout_prob,
                                max_position_embeddings=512,
                                type_vocab_size=2,
                                initializer_range=0.02,
                                layer_norm_eps=1e-12)

    decoder_config = BertConfig(vocab_size=self.tgt_tokenizer.vocab_size,
                                hidden_size=config.hidden_size,
                                num_hidden_layers=config.num_hidden_layers,
                                num_attention_heads=config.num_attention_heads,
                                intermediate_size=config.intermediate_size,
                                hidden_act=config.hidden_act,
                                hidden_dropout_prob=config.dropout_prob,
                                attention_probs_dropout_prob=config.dropout_prob,
                                max_position_embeddings=512,
                                type_vocab_size=2,
                                initializer_range=0.02,
                                layer_norm_eps=1e-12,
                                is_decoder=True)

    # Create encoder and decoder embedding layers.
    encoder_embeddings = torch.nn.Embedding(self.src_tokenizer.vocab_size,
                                            config.hidden_size,
                                            padding_idx=self.src_tokenizer.pad_token_id)
    decoder_embeddings = torch.nn.Embedding(self.tgt_tokenizer.vocab_size,
                                            config.hidden_size,
                                            padding_idx=self.tgt_tokenizer.pad_token_id)

    encoder = BertModel(encoder_config)
    encoder.set_input_embeddings(encoder_embeddings.cpu())
    decoder = BertForMaskedLM(decoder_config)
    decoder.set_input_embeddings(decoder_embeddings.cpu())

    input_dirs = config.model_output_dirs
    suffix = "pytorch_model.bin"
    decoderPath = os.path.join(input_dirs['decoder'], suffix)
    encoderPath = os.path.join(input_dirs['encoder'], suffix)
    decoder_state_dict = torch.load(decoderPath)
    encoder_state_dict = torch.load(encoderPath)
    decoder.load_state_dict(decoder_state_dict)
    encoder.load_state_dict(encoder_state_dict)

    self.model = TranslationModel(encoder, decoder, None, None,
                                  self.tgt_tokenizer, config)
    self.model.cpu()
    # model.eval()
    self.model.encoder.eval()
    self.model.decoder.eval()
# get paths for model weights store/load
if args.exp_name is not None:
    args.checkpoint_dir = '%s/%s/%s' % (SAVE_DIR, args.dataset, args.exp_name)
else:
    args.checkpoint_dir = '%s/%s/%s' % (SAVE_DIR, args.dataset, args.config_name)
print('checkpoints dir : ', args.checkpoint_dir)
if not os.path.isdir(args.checkpoint_dir):
    os.makedirs(args.checkpoint_dir)

args.start_epoch = 0
if args.resume:
    # for testing, one has to load models from a certain path
    if args.iter != -1:
        resume_file = get_assigned_file(args.checkpoint_dir, args.iter)
    else:
        resume_file = get_resume_file(args.checkpoint_dir)
    if resume_file is not None:
        print('Resume file is: ', resume_file)
        tmp = torch.load(resume_file)
        start_epoch = tmp['epoch'] + 1
        projection.load_state_dict(tmp['projection'])
        model.load_state_dict(tmp['feature'])
    else:
        raise Exception('Resume file not found')

train(data_loader, model, projection, args)