def evaluate_ppl(model: RenamingModel, dataset: Dataset, config: Dict, predicate: Any = None):
    # An example counts toward the perplexity only if `predicate(test_meta)` holds;
    # by default, every example counts.
    if predicate is None:
        def predicate(_):
            return True

    eval_batch_size = config['train']['batch_size']
    num_readers = config['train']['num_readers']
    num_batchers = config['train']['num_batchers']
    data_iter = dataset.batch_iterator(batch_size=eval_batch_size,
                                       train=False, progress=True,
                                       return_examples=False,
                                       return_prediction_target=True,
                                       config=model.config,
                                       num_readers=num_readers,
                                       num_batchers=num_batchers)

    was_training = model.training
    model.eval()
    cum_log_probs = 0.
    cum_num_examples = 0

    with torch.no_grad():
        for batch in data_iter:
            td = batch.tensor_dict
            nn_util.to(td, model.device)
            result = model(td, td['prediction_target'])
            log_probs = result['batch_log_prob'].cpu().tolist()
            for e_id, test_meta in enumerate(td['test_meta']):
                if predicate(test_meta):
                    cum_log_probs += log_probs[e_id]
                    cum_num_examples += 1

    # Perplexity is the exponentiated average negative log-likelihood.
    ppl = np.exp(-cum_log_probs / cum_num_examples)

    if was_training:
        model.train()

    return ppl
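# Usage sketch (hypothetical; `model`, `dev_set`, and `config` are assumed to be
# built as in `train()` below). The `predicate` hook restricts the perplexity
# estimate to a slice of the dev set, e.g. functions whose bodies never appeared
# in the training split:
def _demo_eval_unseen_ppl(model, dev_set, config):
    return evaluate_ppl(model, dev_set, config,
                        predicate=lambda meta: not meta['function_body_in_train'])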
def predict(self, source_asts: List[AbstractSyntaxTree]):
    """Given a batch of ASTs, predict their new variable names"""
    tensor_dict = self.batcher.to_tensor_dict(source_asts=source_asts)
    nn_util.to(tensor_dict, self.device)
    context_encoding = self.encoder(tensor_dict)
    # (prediction_size, tgt_vocab_size)
    packed_var_name_log_probs = self.decoder(context_encoding)
    best_var_name_log_probs, best_var_name_ids = torch.max(packed_var_name_log_probs, dim=-1)

    variable_rename_results = []
    pred_node_ptr = 0
    for ast_id, ast in enumerate(source_asts):
        variable_rename_result = dict()
        for var_name in ast.variables:
            var_name_prob = best_var_name_log_probs[pred_node_ptr].item()
            token_id = best_var_name_ids[pred_node_ptr].item()
            new_var_name = self.vocab.target.id2word[token_id]

            if new_var_name == SAME_VARIABLE_TOKEN:
                new_var_name = var_name

            variable_rename_result[var_name] = {'new_name': new_var_name,
                                                'prob': var_name_prob}
            pred_node_ptr += 1

        variable_rename_results.append(variable_rename_result)

    assert pred_node_ptr == packed_var_name_log_probs.size(0)

    return variable_rename_results
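# Usage sketch (hypothetical driver; `load_asts` is a stand-in for whatever
# yields `AbstractSyntaxTree` objects in the pipeline). Greedy decoding picks
# the arg-max name per variable, so each entry carries a single log-probability:
def _demo_greedy_rename(model, load_asts, json_lines):
    for rename_map in model.predict(load_asts(json_lines)):
        for old_name, entry in rename_map.items():
            print(old_name, '->', entry['new_name'], entry['prob'])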
def process_batch(self, batch):
    if self.p.hybrid:
        X, Y, src_mask, variable_ids, target_mask, src_target_maps, body_in_train, graph_input = batch
    else:
        X, Y, src_mask, target_mask, src_target_maps, body_in_train = batch

    X = X.long().to(self.device)
    Y = Y.long().to(self.device)
    src_mask = src_mask.float().to(self.device)
    target_mask = target_mask.float().to(self.device)

    if self.p.hybrid:
        variable_ids = variable_ids.long().to(self.device)
        graph_input = nn_util.to(graph_input, self.device)
        return (X, Y, src_mask, variable_ids, target_mask, src_target_maps, body_in_train, graph_input)
    else:
        return (X, Y, src_mask, target_mask, src_target_maps, body_in_train)
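# A minimal sketch of what `nn_util.to` is assumed to do in the snippets above:
# recursively move (possibly nested) tensor containers onto a device. This is a
# hypothetical re-implementation, not the repo's actual helper:
import torch

def to_device(obj, device):
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(v, device) for v in obj)
    return obj  # non-tensor leaves (ints, strings, ...) pass through unchanged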
def predict(self, examples: List[Example], encoder: Encoder) -> List[Dict]:
    batch_size = len(examples)
    beam_size = self.config['beam_size']
    same_variable_id = self.vocab.target[SAME_VARIABLE_TOKEN]
    end_of_variable_id = self.vocab.target[END_OF_VARIABLE_TOKEN]

    variable_nums = []
    for ast_id, example in enumerate(examples):
        variable_nums.append(len(example.ast.variables))

    beams = OrderedDict((ast_id, [self.Hypothesis([], 0, 0.)])
                        for ast_id in range(batch_size))
    hyp_scores_tm1 = torch.zeros(len(beams), device=self.device)
    completed_hyps = [[] for _ in range(batch_size)]
    tgt_vocab_size = len(self.vocab.target)

    tensor_dict = self.batcher.to_tensor_dict(examples)
    nn_util.to(tensor_dict, self.device)
    context_encoding = encoder(tensor_dict)
    h_tm1 = h_0 = self.get_init_state(context_encoding)

    # Note that we are using the `restoration_indices` from `context_encoding`,
    # which is the word-level restoration index.
    # (batch_size, variable_master_node_num, encoding_size)
    variable_encoding = context_encoding['variable_encoding']
    # (batch_size, encoding_size)
    variable_name_embed_tm1 = att_tm1 = torch.zeros(batch_size, self.lstm_cell.hidden_size,
                                                    device=self.device)

    max_prediction_time_step = self.config['max_prediction_time_step']
    for t in range(0, max_prediction_time_step):
        # (total_live_hyp_num, encoding_size)
        if t > 0:
            variable_encoding_t = variable_encoding[hyp_ast_ids_t, hyp_variable_ptrs_t]
        else:
            variable_encoding_t = variable_encoding[:, 0]

        if self.config['input_feed']:
            x = torch.cat([variable_encoding_t, variable_name_embed_tm1, att_tm1], dim=-1)
        else:
            x = torch.cat([variable_encoding_t, variable_name_embed_tm1], dim=-1)

        h_t, q_t, alpha_t = self.rnn_step(x, h_tm1, context_encoding)

        # (total_live_hyp_num, vocab_size)
        hyp_var_name_scores_t = torch.log_softmax(self.state2names(q_t), dim=-1)
        cont_cand_hyp_scores = hyp_scores_tm1.unsqueeze(-1) + hyp_var_name_scores_t

        new_beams = OrderedDict()
        live_beam_ids = []
        new_hyp_scores = []
        live_prev_hyp_ids = []
        new_hyp_var_name_ids = []
        new_hyp_ast_ids = []
        new_hyp_variable_ptrs = []
        is_same_variable_mask = []
        beam_start_hyp_pos = 0
        for beam_id, (ast_id, beam) in enumerate(beams.items()):
            beam_end_hyp_pos = beam_start_hyp_pos + len(beam)
            # (live_beam_size, vocab_size)
            beam_cont_cand_hyp_scores = cont_cand_hyp_scores[beam_start_hyp_pos:beam_end_hyp_pos]
            cont_beam_size = beam_size - len(completed_hyps[ast_id])
            beam_new_hyp_scores, beam_new_hyp_positions = torch.topk(
                beam_cont_cand_hyp_scores.view(-1), k=cont_beam_size, dim=-1)

            # `topk` ran over the flattened (hyp, token) score matrix, so integer
            # division and modulo recover the hypothesis to extend and the token.
            # (cont_beam_size)
            beam_prev_hyp_ids = beam_new_hyp_positions // tgt_vocab_size
            beam_hyp_var_name_ids = beam_new_hyp_positions % tgt_vocab_size

            _prev_hyp_ids = beam_prev_hyp_ids.cpu()
            _hyp_var_name_ids = beam_hyp_var_name_ids.cpu()
            _new_hyp_scores = beam_new_hyp_scores.cpu()

            for i in range(cont_beam_size):
                prev_hyp_id = _prev_hyp_ids[i].item()
                prev_hyp = beam[prev_hyp_id]
                hyp_var_name_id = _hyp_var_name_ids[i].item()
                new_hyp_score = _new_hyp_scores[i].item()

                variable_ptr = prev_hyp.variable_ptr
                if hyp_var_name_id == end_of_variable_id:
                    variable_ptr += 1

                    # remove empty cases
                    if len(prev_hyp.variable_list) == 0 or prev_hyp.variable_list[-1] == end_of_variable_id:
                        continue

                new_hyp = self.Hypothesis(variable_list=list(prev_hyp.variable_list) + [hyp_var_name_id],
                                          variable_ptr=variable_ptr,
                                          score=new_hyp_score)

                if variable_ptr == variable_nums[ast_id]:
                    completed_hyps[ast_id].append(new_hyp)
                else:
                    new_beams.setdefault(ast_id, []).append(new_hyp)
                    live_beam_ids.append(beam_id)
                    new_hyp_scores.append(new_hyp_score)
                    live_prev_hyp_ids.append(beam_start_hyp_pos + prev_hyp_id)
                    new_hyp_var_name_ids.append(hyp_var_name_id)
                    new_hyp_ast_ids.append(ast_id)
                    new_hyp_variable_ptrs.append(variable_ptr)
                    is_same_variable_mask.append(1. if prev_hyp.variable_ptr == variable_ptr else 0.)

            beam_start_hyp_pos = beam_end_hyp_pos

        if live_beam_ids:
            hyp_scores_tm1 = torch.tensor(new_hyp_scores, device=self.device)
            h_tm1 = (h_t[0][live_prev_hyp_ids], h_t[1][live_prev_hyp_ids])
            att_tm1 = q_t[live_prev_hyp_ids]
            variable_name_embed_tm1 = self.state2names.weight[new_hyp_var_name_ids]
            hyp_ast_ids_t = new_hyp_ast_ids
            hyp_variable_ptrs_t = new_hyp_variable_ptrs

            beams = new_beams

            if self.independent_prediction_for_each_variable:
                # Zero out the recurrent state when moving on to a new variable,
                # so each variable is predicted independently.
                is_same_variable_mask = torch.tensor(is_same_variable_mask,
                                                     device=self.device,
                                                     dtype=torch.float).unsqueeze(-1)
                h_tm1 = (h_tm1[0] * is_same_variable_mask, h_tm1[1] * is_same_variable_mask)
                att_tm1 = att_tm1 * is_same_variable_mask
                variable_name_embed_tm1 = variable_name_embed_tm1 * is_same_variable_mask
        else:
            break

    variable_rename_results = []
    for i, hyps in enumerate(completed_hyps):
        variable_rename_result = dict()
        ast = examples[i].ast
        hyps = sorted(hyps, key=lambda hyp: -hyp.score)

        if not hyps:
            # return identity renamings
            print(f'Failed to find a hypothesis for function {ast.compilation_unit}', file=sys.stderr)
            for old_name in ast.variables:
                variable_rename_result[old_name] = {'new_name': old_name, 'prob': 0.}
        else:
            top_hyp = hyps[0]
            sub_token_ptr = 0
            for old_name in ast.variables:
                sub_token_begin = sub_token_ptr
                while top_hyp.variable_list[sub_token_ptr] != end_of_variable_id:
                    sub_token_ptr += 1
                sub_token_ptr += 1  # point to first sub-token of next variable
                sub_token_end = sub_token_ptr

                var_name_token_ids = top_hyp.variable_list[sub_token_begin:sub_token_end]  # include ending </s>
                if var_name_token_ids == [same_variable_id, end_of_variable_id]:
                    new_var_name = old_name
                else:
                    new_var_name = self.vocab.target.subtoken_model.decode_ids(var_name_token_ids)

                variable_rename_result[old_name] = {'new_name': new_var_name,
                                                    'prob': top_hyp.score}

        variable_rename_results.append(variable_rename_result)

    return variable_rename_results
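# The `// tgt_vocab_size` and `% tgt_vocab_size` pair above works because `topk`
# runs over the flattened (live_hyp_num * vocab_size) score matrix: each flat
# position jointly encodes which hypothesis to extend and which token extends it.
# A toy illustration with made-up numbers:
import torch

def _demo_flat_topk(vocab_size=5):
    scores = torch.tensor([[0.1, 0.9, 0.2, 0.0, 0.3],   # hypothesis 0
                           [0.8, 0.1, 0.7, 0.2, 0.0]])  # hypothesis 1
    _, flat_pos = torch.topk(scores.view(-1), k=3)
    prev_hyp_ids = flat_pos // vocab_size  # -> tensor([0, 1, 1])
    token_ids = flat_pos % vocab_size      # -> tensor([1, 0, 2])
    return prev_hyp_ids, token_ids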
def predict(self, examples: List[Example], encoder: Encoder) -> List[Any]:
    batch_size = len(examples)
    beam_size = self.config['beam_size']
    same_variable_id = self.vocab.target[SAME_VARIABLE_TOKEN]
    end_of_variable_id = self.vocab.target[END_OF_VARIABLE_TOKEN]
    remove_duplicate = self.config['remove_duplicates_in_prediction']

    variable_nums = []
    for ast_id, example in enumerate(examples):
        variable_nums.append(len(example.ast.variables))

    beams = OrderedDict((ast_id, [self.Hypothesis([[]], 0, 0.)])
                        for ast_id in range(batch_size))
    hyp_scores_tm1 = torch.zeros(len(beams), device=self.device)
    completed_hyps = [[] for _ in range(batch_size)]
    tgt_vocab_size = len(self.vocab.target)

    tensor_dict = self.batcher.to_tensor_dict(examples)
    nn_util.to(tensor_dict, self.device)
    context_encoding = encoder(tensor_dict)

    # prepare tensors for attention
    attention_memory = self.get_attention_memory(context_encoding)
    context_encoding_t = attention_memory

    h_tm1 = h_0 = self.get_init_state(context_encoding)

    # Note that we are using the `restoration_indices` from `context_encoding`,
    # which is the word-level restoration index.
    # (batch_size, variable_master_node_num, encoding_size)
    variable_encoding = context_encoding['variable_encoding']
    # (batch_size, encoding_size)
    variable_name_embed_tm1 = att_tm1 = torch.zeros(batch_size, self.lstm_cell.hidden_size,
                                                    device=self.device)

    max_prediction_time_step = self.config['max_prediction_time_step']
    for t in range(0, max_prediction_time_step):
        # (total_live_hyp_num, encoding_size)
        if t > 0:
            variable_encoding_t = variable_encoding[hyp_ast_ids_t, hyp_variable_ptrs_t]
        else:
            variable_encoding_t = variable_encoding[:, 0]

        if self.config['input_feed']:
            x = torch.cat([variable_encoding_t, variable_name_embed_tm1, att_tm1], dim=-1)
        else:
            x = torch.cat([variable_encoding_t, variable_name_embed_tm1], dim=-1)

        h_t, q_t, alpha_t = self.rnn_step(x, h_tm1, context_encoding_t)
        # (A disabled experimental variant here instead re-ran a transformer decoder
        # over the cumulative variable-name embeddings to produce `q_t`.)

        # (total_live_hyp_num, vocab_size)
        hyp_var_name_scores_t = torch.log_softmax(self.state2names(q_t), dim=-1)
        cont_cand_hyp_scores = hyp_scores_tm1.unsqueeze(-1) + hyp_var_name_scores_t

        new_beams = OrderedDict()
        live_beam_ids = []
        new_hyp_scores = []
        live_prev_hyp_ids = []
        new_hyp_var_name_ids = []
        new_hyp_ast_ids = []
        new_hyp_variable_ptrs = []
        is_same_variable_mask = []
        beam_start_hyp_pos = 0
        for beam_id, (ast_id, beam) in enumerate(beams.items()):
            beam_end_hyp_pos = beam_start_hyp_pos + len(beam)
            # (live_beam_size, vocab_size)
            beam_cont_cand_hyp_scores = cont_cand_hyp_scores[beam_start_hyp_pos:beam_end_hyp_pos]
            cont_beam_size = beam_size - len(completed_hyps[ast_id])
            # Take `len(beam)` more candidates to account for possible duplicates
            k = min(beam_cont_cand_hyp_scores.numel(), cont_beam_size + len(beam))
            beam_new_hyp_scores, beam_new_hyp_positions = torch.topk(
                beam_cont_cand_hyp_scores.view(-1), k=k, dim=-1)

            # (cont_beam_size)
            beam_prev_hyp_ids = beam_new_hyp_positions // tgt_vocab_size
            beam_hyp_var_name_ids = beam_new_hyp_positions % tgt_vocab_size

            _prev_hyp_ids = beam_prev_hyp_ids.cpu()
            _hyp_var_name_ids = beam_hyp_var_name_ids.cpu()
            _new_hyp_scores = beam_new_hyp_scores.cpu()

            beam_cnt = 0
            for i in range(len(beam_new_hyp_positions)):
                prev_hyp_id = _prev_hyp_ids[i].item()
                prev_hyp = beam[prev_hyp_id]
                hyp_var_name_id = _hyp_var_name_ids[i].item()
                new_hyp_score = _new_hyp_scores[i].item()

                variable_ptr = prev_hyp.variable_ptr
                new_variable_list = list(prev_hyp.variable_list)
                new_variable_list[-1] = list(new_variable_list[-1] + [hyp_var_name_id])

                if hyp_var_name_id == end_of_variable_id:
                    # remove empty cases
                    if new_variable_list[-1] == [end_of_variable_id]:
                        continue

                    if remove_duplicate:
                        # skip a hypothesis if it reuses a name already predicted for
                        # another variable (identity predictions are exempt)
                        last_pred = new_variable_list[-1]
                        if any(x == last_pred
                               for x in new_variable_list[:-1]
                               if x != [same_variable_id, end_of_variable_id]):
                            continue

                    variable_ptr += 1
                    new_variable_list.append([])

                beam_cnt += 1
                new_hyp = self.Hypothesis(variable_list=new_variable_list,
                                          variable_ptr=variable_ptr,
                                          score=new_hyp_score)

                if variable_ptr == variable_nums[ast_id]:
                    completed_hyps[ast_id].append(new_hyp)
                else:
                    new_beams.setdefault(ast_id, []).append(new_hyp)
                    live_beam_ids.append(beam_id)
                    new_hyp_scores.append(new_hyp_score)
                    live_prev_hyp_ids.append(beam_start_hyp_pos + prev_hyp_id)
                    new_hyp_var_name_ids.append(hyp_var_name_id)
                    new_hyp_ast_ids.append(ast_id)
                    new_hyp_variable_ptrs.append(variable_ptr)
                    is_same_variable_mask.append(1. if prev_hyp.variable_ptr == variable_ptr else 0.)

                if beam_cnt >= cont_beam_size:
                    break

            beam_start_hyp_pos = beam_end_hyp_pos

        if live_beam_ids:
            hyp_scores_tm1 = torch.tensor(new_hyp_scores, device=self.device)
            h_tm1 = (h_t[0][live_prev_hyp_ids], h_t[1][live_prev_hyp_ids])
            att_tm1 = q_t[live_prev_hyp_ids]

            if self.config['tie_embedding']:
                variable_name_embed_tm1 = self.state2names.weight[new_hyp_var_name_ids]
            else:
                variable_name_embed_tm1 = self.var_name_embed.weight[new_hyp_var_name_ids]

            hyp_ast_ids_t = new_hyp_ast_ids
            hyp_variable_ptrs_t = new_hyp_variable_ptrs

            beams = new_beams

            # (total_hyp_num, max_tree_size, node_encoding_size)
            context_encoding_t = dict(
                attention_key=attention_memory['attention_key'][hyp_ast_ids_t],
                attention_value=attention_memory['attention_value'][hyp_ast_ids_t],
                attention_value_mask=attention_memory['attention_value_mask'][hyp_ast_ids_t])

            if self.independent_prediction_for_each_variable:
                is_same_variable_mask = torch.tensor(is_same_variable_mask,
                                                     device=self.device,
                                                     dtype=torch.float).unsqueeze(-1)
                h_tm1 = (h_tm1[0] * is_same_variable_mask, h_tm1[1] * is_same_variable_mask)
                att_tm1 = att_tm1 * is_same_variable_mask
                variable_name_embed_tm1 = variable_name_embed_tm1 * is_same_variable_mask
        else:
            break

    variable_rename_results = []
    for i, hyps in enumerate(completed_hyps):
        ast = examples[i].ast
        hyps = sorted(hyps, key=lambda hyp: -hyp.score)

        if not hyps:
            # return identity renamings
            print(f'Failed to find a hypothesis for function {ast.compilation_unit}', file=sys.stderr)
            variable_rename_result = dict()
            for old_name in ast.variables:
                variable_rename_result[old_name] = {'new_name': old_name, 'prob': 0.}
            example_rename_results = [variable_rename_result]
        else:
            # Unlike the decoder above, each hypothesis here already stores one
            # sub-token list per variable, so no end-of-variable scan is needed.
            example_rename_results = []
            for hyp in hyps:
                variable_rename_result = dict()
                for var_id, old_name in enumerate(ast.variables):
                    var_name_token_ids = hyp.variable_list[var_id]
                    if var_name_token_ids == [same_variable_id, end_of_variable_id]:
                        new_var_name = old_name
                    else:
                        new_var_name = self.vocab.target.subtoken_model.decode_ids(var_name_token_ids)

                    variable_rename_result[old_name] = {'new_name': new_var_name,
                                                        'prob': hyp.score}

                example_rename_results.append(variable_rename_result)

        variable_rename_results.append(example_rename_results)

    return variable_rename_results
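# `subtoken_model.decode_ids` above is assumed to be a SentencePiece processor
# that joins sub-token ids back into a full identifier. A standalone sketch of
# that step (the vocabulary model path is hypothetical):
import sentencepiece as spm

def _demo_decode_subtokens(var_name_token_ids):
    sp = spm.SentencePieceProcessor()
    sp.load('target_vocab.model')             # assumed vocabulary model file
    return sp.decode_ids(var_name_token_ids)  # e.g. sub-token ids -> 'file_size'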
def predict(self, source_asts: List[AbstractSyntaxTree], encoder: Encoder) -> List[Dict]:
    beam_size = self.config['beam_size']
    unk_replace = self.config['unk_replace']

    variable_nums = []
    for ast_id, ast in enumerate(source_asts):
        variable_nums.append(len(ast.variables))

    beams = OrderedDict((ast_id, [self.Hypothesis([], 0.)])
                        for ast_id, ast in enumerate(source_asts))
    hyp_scores_tm1 = torch.zeros(len(beams), 1, device=self.device)
    completed_hyps = [[] for _ in source_asts]
    tgt_vocab_size = len(self.vocab.target)

    tensor_dict = self.batcher.to_tensor_dict(source_asts=source_asts)
    nn_util.to(tensor_dict, self.device)
    context_encoding = encoder(tensor_dict)
    h_tm1 = h_0 = self.get_init_state(context_encoding)

    # (prediction_node_num, encoding_size)
    variable_master_node_encoding = context_encoding['variable_master_node_encoding']
    encoding_size = variable_master_node_encoding.size(1)
    # (batch_size, max_prediction_node_num)
    variable_master_node_restoration_indices = context_encoding['variable_encoding_restoration_indices']

    # (batch_size, max_prediction_node_num, encoding_size)
    variable_master_node_encoding = variable_master_node_encoding[variable_master_node_restoration_indices]
    variable_encoding_t = variable_master_node_encoding[:, 0]
    # (batch_size, encoding_size)
    variable_name_embed_tm1 = att_tm1 = torch.zeros(len(source_asts), encoding_size,
                                                    device=self.device)

    max_prediction_node_num = variable_master_node_encoding.size(1)
    for t in range(0, max_prediction_node_num):
        live_beam_size = beam_size if t > 0 else 1
        live_tree_ids = [ast_id for ast_id in beams]

        if t > 0:
            variable_encoding_t = variable_master_node_encoding[live_tree_ids][:, t] \
                .unsqueeze(1).expand(-1, beam_size, -1).contiguous().view(-1, encoding_size)

        if self.config['input_feed']:
            x = torch.cat([variable_encoding_t, variable_name_embed_tm1, att_tm1], dim=-1)
        else:
            x = torch.cat([variable_encoding_t, variable_name_embed_tm1], dim=-1)

        h_t, q_t, alpha_t = self.rnn_step(x, h_tm1, context_encoding)

        # (live_beam_num, beam_size, encoding_size)
        q_t_by_beam = q_t.view(len(beams), -1, q_t.size(-1))
        # (live_beam_num, beam_size, vocab_size)
        hyp_var_name_scores_t = torch.log_softmax(self.state2names(q_t_by_beam), dim=-1)

        if unk_replace:
            hyp_var_name_scores_t[:, :, self.vocab.target['<unk>']] = float('-inf')

        cont_cand_hyp_scores = hyp_scores_tm1.unsqueeze(-1) + hyp_var_name_scores_t

        # (live_beam_num, beam_size)
        new_hyp_scores, new_hyp_position_list = torch.topk(
            cont_cand_hyp_scores.view(len(beams), -1), k=beam_size, dim=-1)

        # (live_beam_num, beam_size)
        prev_hyp_ids = new_hyp_position_list // tgt_vocab_size
        hyp_var_name_ids = new_hyp_position_list % tgt_vocab_size

        # move these tensors to cpu for fast indexing
        _prev_hyp_ids = prev_hyp_ids.cpu()
        _hyp_var_name_ids = hyp_var_name_ids.cpu()
        _new_hyp_scores = new_hyp_scores.cpu()

        new_beams = OrderedDict()
        live_beam_ids = []
        for beam_id, (ast_id, beam) in enumerate(beams.items()):
            new_hyps = []
            for i in range(beam_size):
                prev_hyp_id = _prev_hyp_ids[beam_id, i].item()
                prev_hyp = beam[prev_hyp_id]
                hyp_var_name_id = _hyp_var_name_ids[beam_id, i].item()
                new_hyp_score = _new_hyp_scores[beam_id, i].item()

                new_hyp = self.Hypothesis(variable_list=list(prev_hyp.variable_list) + [hyp_var_name_id],
                                          score=new_hyp_score)
                new_hyps.append(new_hyp)

            if t + 1 == variable_nums[ast_id]:
                completed_hyps[ast_id] = new_hyps
            else:
                new_beams[ast_id] = new_hyps
                live_beam_ids.append(beam_id)

        if t < max_prediction_node_num - 1:
            # (live_beam_num, beam_size, *)
            prev_hyp_ids = (torch.arange(len(beams)).to(self.device) * live_beam_size).unsqueeze(-1) + prev_hyp_ids
            live_prev_hyp_ids = prev_hyp_ids[live_beam_ids].view(-1)
            live_beam_ids = torch.tensor(live_beam_ids, device=self.device)

            hyp_scores_tm1 = new_hyp_scores[live_beam_ids]
            h_tm1 = (h_t[0][live_prev_hyp_ids], h_t[1][live_prev_hyp_ids])
            att_tm1 = q_t[live_prev_hyp_ids]

            # (live_beam_num * beam_size)
            live_hyp_var_name_ids = hyp_var_name_ids[live_beam_ids].view(-1)
            # (live_beam_num * beam_size, embed_size)
            variable_name_embed_tm1 = self.state2names.weight[live_hyp_var_name_ids]

            beams = new_beams

    variable_rename_results = []
    for i, hyps in enumerate(completed_hyps):
        ast = source_asts[i]
        hyps = sorted(hyps, key=lambda hyp: -hyp.score)
        top_hyp = hyps[0]
        variable_rename_result = dict()
        for old_name, var_name_id in zip(ast.variables, top_hyp.variable_list):
            new_var_name = self.vocab.target.id2word[var_name_id]
            if new_var_name == SAME_VARIABLE_TOKEN:
                new_var_name = old_name

            variable_rename_result[old_name] = {'new_name': new_var_name,
                                                'prob': top_hyp.score}

        variable_rename_results.append(variable_rename_result)

    return variable_rename_results
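# Usage sketch (hypothetical): beam decoding returns one rename map per AST,
# scored by the full hypothesis log-probability. `decoder`, `encoder`, and
# `asts` are assumed to be constructed elsewhere in the pipeline:
def _demo_beam_rename(decoder, encoder, asts):
    for ast, rename_map in zip(asts, decoder.predict(asts, encoder)):
        best = {old: entry['new_name'] for old, entry in rename_map.items()}
        print(ast.compilation_unit, best)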
def train(args):
    work_dir = args['--work-dir']
    config = json.loads(_jsonnet.evaluate_file(args['CONFIG_FILE']))
    config['work_dir'] = work_dir

    if not os.path.exists(work_dir):
        print(f'creating work dir [{work_dir}]', file=sys.stderr)
        os.makedirs(work_dir)

    if args['--extra-config']:
        extra_config = json.loads(args['--extra-config'])
        config = util.update(config, extra_config)

    json.dump(config, open(os.path.join(work_dir, 'config.json'), 'w'), indent=2)

    model = RenamingModel.build(config)
    config = model.config
    model.train()

    if args['--cuda']:
        model = model.cuda()

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=0.001)
    nn_util.glorot_init(params)

    # set the padding index for embedding layers to zeros
    # model.encoder.var_node_name_embedding.weight[0].fill_(0.)

    train_set = Dataset(config['data']['train_file'])
    dev_set = Dataset(config['data']['dev_file'])
    batch_size = config['train']['batch_size']

    print(f'Training set size {len(train_set)}, dev set size {len(dev_set)}', file=sys.stderr)

    # training loop
    train_iter = epoch = cum_examples = 0
    log_every = config['train']['log_every']
    evaluate_every_nepoch = config['train']['evaluate_every_nepoch']
    max_epoch = config['train']['max_epoch']
    max_patience = config['train']['patience']
    cum_loss = 0.
    patience = 0
    t_log = time.time()

    history_accs = []
    while True:
        # load training dataset, which is a collection of ASTs and maps of gold-standard renamings
        train_set_iter = train_set.batch_iterator(batch_size=batch_size,
                                                  return_examples=False,
                                                  config=config, progress=True, train=True,
                                                  num_readers=config['train']['num_readers'],
                                                  num_batchers=config['train']['num_batchers'])
        epoch += 1

        for batch in train_set_iter:
            train_iter += 1
            optimizer.zero_grad()

            nn_util.to(batch.tensor_dict, model.device)
            result = model(batch.tensor_dict, batch.tensor_dict['prediction_target'])

            # maximize the log-likelihood of the gold renamings
            loss = -result['batch_log_prob'].mean()

            cum_loss += loss.item() * batch.size
            cum_examples += batch.size

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(params, 5.)

            optimizer.step()
            del loss

            if train_iter % log_every == 0:
                print(f'[Learner] train_iter={train_iter} avg. loss={cum_loss / cum_examples}, '
                      f'{cum_examples} examples ({cum_examples / (time.time() - t_log)} examples/s)',
                      file=sys.stderr)

                cum_loss = cum_examples = 0.
                t_log = time.time()

        print(f'[Learner] Epoch {epoch} finished', file=sys.stderr)

        if epoch % evaluate_every_nepoch == 0:
            print(f'[Learner] Perform evaluation', file=sys.stderr)
            t1 = time.time()
            # ppl = Evaluator.evaluate_ppl(model, dev_set, config,
            #                              predicate=lambda e: not e['function_body_in_train'])
            eval_results = Evaluator.decode_and_evaluate(model, dev_set, config)
            # print(f'[Learner] Evaluation result ppl={ppl} (took {time.time() - t1}s)', file=sys.stderr)
            print(f'[Learner] Evaluation result {eval_results} (took {time.time() - t1}s)', file=sys.stderr)

            # early stopping tracks accuracy on functions whose bodies were not seen in training
            dev_metric = eval_results['func_body_not_in_train_acc']['accuracy']
            # dev_metric = -ppl
            if len(history_accs) == 0 or dev_metric > max(history_accs):
                patience = 0
                model_save_path = os.path.join(work_dir, f'model.bin')
                model.save(model_save_path)
                print(f'[Learner] Saved currently the best model to {model_save_path}', file=sys.stderr)
            else:
                patience += 1
                if patience == max_patience:
                    print(f'[Learner] Reached max patience {max_patience}, exiting...', file=sys.stderr)
                    patience = 0
                    exit()

            history_accs.append(dev_metric)

        if epoch == max_epoch:
            print(f'[Learner] Reached max epoch', file=sys.stderr)
            exit()
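# A hypothetical entry point matching the docopt-style argument names accessed
# in `train()` (`--work-dir`, `--extra-config`, `--cuda`, `CONFIG_FILE`); the
# repo's real CLI wiring may differ:
if __name__ == '__main__':
    from docopt import docopt

    usage = '''Usage:
    train.py [--cuda] [--work-dir=<dir>] [--extra-config=<json>] CONFIG_FILE'''
    train(docopt(usage))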