示例#1
0
    def evaluate_ppl(model: RenamingModel,
                     dataset: Dataset,
                     config: Dict,
                     predicate: Any = None):
        """Compute the perplexity of ``model`` on ``dataset``.

        Perplexity is ``exp(-mean(log_prob))`` over every example whose
        ``test_meta`` satisfies ``predicate``; the per-example log-probability
        is read from the model output's ``'batch_log_prob'`` entry.

        Args:
            model: model to evaluate. It is put in eval mode for the duration
                of the call and, if it was training on entry, switched back to
                training mode afterwards -- also when an exception escapes.
            dataset: dataset whose ``batch_iterator`` yields evaluation batches.
            config: experiment config; ``config['train']`` supplies
                ``batch_size``, ``num_readers`` and ``num_batchers``.
            predicate: optional filter called with each example's
                ``test_meta``; by default every example is counted.

        Returns:
            The perplexity, as computed by ``np.exp``.

        Raises:
            ZeroDivisionError: if no example satisfied ``predicate``, since
                perplexity is undefined over an empty set.
        """
        if predicate is None:

            def predicate(_):
                return True

        train_config = config['train']
        data_iter = dataset.batch_iterator(
            batch_size=train_config['batch_size'],
            train=False,
            progress=True,
            return_examples=False,
            return_prediction_target=True,
            config=model.config,
            num_readers=train_config['num_readers'],
            num_batchers=train_config['num_batchers'])

        was_training = model.training
        model.eval()
        cum_log_probs = 0.
        cum_num_examples = 0
        try:
            with torch.no_grad():
                for batch in data_iter:
                    td = batch.tensor_dict
                    nn_util.to(td, model.device)
                    result = model(td, td['prediction_target'])
                    # One host transfer per batch rather than per example.
                    log_probs = result['batch_log_prob'].cpu().tolist()
                    for e_id, test_meta in enumerate(td['test_meta']):
                        if predicate(test_meta):
                            cum_log_probs += log_probs[e_id]
                            cum_num_examples += 1

            if cum_num_examples == 0:
                raise ZeroDivisionError(
                    'no example satisfied the predicate; '
                    'perplexity is undefined')
            ppl = np.exp(-cum_log_probs / cum_num_examples)
        finally:
            # Restore the caller's mode even if evaluation failed midway.
            if was_training:
                model.train()

        return ppl
示例#2
0
    def predict(self, source_asts: List[AbstractSyntaxTree]):
        """
        Given a batch of ASTs, greedily predict a new name for each variable.

        The decoder scores all variables of the batch as one packed tensor,
        one row per variable in AST order and then ``ast.variables`` order;
        the arg-max token of each row becomes that variable's new name.

        Returns:
            One dict per AST mapping each old variable name to
            ``{'new_name': ..., 'prob': ...}``. ``prob`` is the maximum of the
            decoder's row (presumably a log-probability, given the decoder
            output is named ``packed_var_name_log_probs`` -- confirm). A
            prediction of ``SAME_VARIABLE_TOKEN`` keeps the original name.
        """

        tensor_dict = self.batcher.to_tensor_dict(source_asts=source_asts)
        nn_util.to(tensor_dict, self.device)
        context_encoding = self.encoder(tensor_dict)
        # (prediction_size, tgt_vocab_size)
        packed_var_name_log_probs = self.decoder(context_encoding)
        best_var_name_log_probs, best_var_name_ids = torch.max(
            packed_var_name_log_probs, dim=-1)

        # Copy the results to the host once, instead of one ``.item()`` call
        # (and therefore one device synchronisation) per variable.
        best_log_probs = best_var_name_log_probs.tolist()
        best_ids = best_var_name_ids.tolist()
        id2word = self.vocab.target.id2word

        variable_rename_results = []
        pred_node_ptr = 0
        for ast in source_asts:
            variable_rename_result = dict()
            for var_name in ast.variables:
                new_var_name = id2word[best_ids[pred_node_ptr]]
                if new_var_name == SAME_VARIABLE_TOKEN:
                    new_var_name = var_name

                variable_rename_result[var_name] = {
                    'new_name': new_var_name,
                    'prob': best_log_probs[pred_node_ptr]
                }

                pred_node_ptr += 1

            variable_rename_results.append(variable_rename_result)

        # Every packed prediction row must belong to exactly one variable.
        assert pred_node_ptr == packed_var_name_log_probs.size(0)

        return variable_rename_results
示例#3
0
    def process_batch(self, batch):
        """Move one batch onto ``self.device`` with the dtypes the model expects.

        The layout of ``batch`` depends on ``self.p.hybrid``:
        non-hybrid batches are
        ``(X, Y, src_mask, target_mask, src_target_maps, body_in_train)``;
        hybrid batches additionally carry ``variable_ids`` (after
        ``src_mask``) and ``graph_input`` (last). ``X``, ``Y`` and
        ``variable_ids`` are cast to ``long``, the masks to ``float``;
        ``src_target_maps`` and ``body_in_train`` pass through unchanged.
        The returned tuple has the same layout as the input.
        """
        if self.p.hybrid:
            (X, Y, src_mask, variable_ids, target_mask, src_target_maps,
             body_in_train, graph_input) = batch
        else:
            X, Y, src_mask, target_mask, src_target_maps, body_in_train = batch

        device = self.device
        X, Y = X.long().to(device), Y.long().to(device)
        src_mask = src_mask.float().to(device)
        target_mask = target_mask.float().to(device)

        if not self.p.hybrid:
            return (X, Y, src_mask, target_mask, src_target_maps,
                    body_in_train)

        variable_ids = variable_ids.long().to(device)
        # NOTE(review): the return value of ``nn_util.to`` is kept here,
        # while other call sites use it purely for its in-place effect --
        # confirm it returns the moved object rather than None.
        graph_input = nn_util.to(graph_input, device)
        return (X, Y, src_mask, variable_ids, target_mask, src_target_maps,
                body_in_train, graph_input)
    def predict(self, examples: List[Example], encoder: Encoder) -> List[Dict]:
        """Beam-search a new name for every variable of each example.

        Names are decoded sub-token by sub-token. A hypothesis keeps one flat
        ``variable_list`` of target ids in which each variable's name is
        closed by ``END_OF_VARIABLE_TOKEN``; ``variable_ptr`` counts the names
        closed so far and selects the variable encoding fed to the decoder at
        the next step. A hypothesis is complete once it has closed a name for
        every variable in ``example.ast.variables``.

        Args:
            examples: batch of examples to rename.
            encoder: module producing the context encoding of the batch.

        Returns:
            One dict per example mapping each old variable name to
            ``{'new_name': ..., 'prob': ...}``, taken from the highest-scoring
            complete hypothesis. ``prob`` is that hypothesis' cumulative
            log-softmax score, shared by all of its variables. If no
            hypothesis completed within ``max_prediction_time_step`` steps,
            every variable keeps its name with ``prob`` 0.
        """
        batch_size = len(examples)
        beam_size = self.config['beam_size']
        same_variable_id = self.vocab.target[SAME_VARIABLE_TOKEN]
        end_of_variable_id = self.vocab.target[END_OF_VARIABLE_TOKEN]

        # Number of names a hypothesis of each example has to produce.
        variable_nums = [len(example.ast.variables) for example in examples]

        # One beam per example, seeded with a single empty hypothesis.
        beams = OrderedDict((ast_id, [self.Hypothesis([], 0, 0.)])
                            for ast_id in range(batch_size))
        hyp_scores_tm1 = torch.zeros(len(beams), device=self.device)
        completed_hyps = [[] for _ in range(batch_size)]
        tgt_vocab_size = len(self.vocab.target)

        tensor_dict = self.batcher.to_tensor_dict(examples)
        nn_util.to(tensor_dict, self.device)

        context_encoding = encoder(tensor_dict)
        h_tm1 = self.get_init_state(context_encoding)

        # Note that we are using the `restoration_indices` from `context_encoding`, which is the word-level restoration index
        # (batch_size, variable_master_node_num, encoding_size)
        variable_encoding = context_encoding['variable_encoding']
        # (batch_size, encoding_size)
        variable_name_embed_tm1 = att_tm1 = torch.zeros(
            batch_size, self.lstm_cell.hidden_size, device=self.device)

        max_prediction_time_step = self.config['max_prediction_time_step']
        for t in range(0, max_prediction_time_step):
            # (total_live_hyp_num, encoding_size): encoding of the variable
            # each live hypothesis is currently naming.
            if t > 0:
                variable_encoding_t = variable_encoding[hyp_ast_ids_t,
                                                        hyp_variable_ptrs_t]
            else:
                variable_encoding_t = variable_encoding[:, 0]

            if self.config['input_feed']:
                x = torch.cat(
                    [variable_encoding_t, variable_name_embed_tm1, att_tm1],
                    dim=-1)
            else:
                x = torch.cat([variable_encoding_t, variable_name_embed_tm1],
                              dim=-1)

            h_t, q_t, _ = self.rnn_step(x, h_tm1, context_encoding)

            # (total_live_hyp_num, vocab_size)
            hyp_var_name_scores_t = torch.log_softmax(self.state2names(q_t),
                                                      dim=-1)

            # Cumulative score of every one-token extension of every live
            # hypothesis.
            cont_cand_hyp_scores = hyp_scores_tm1.unsqueeze(
                -1) + hyp_var_name_scores_t

            new_beams = OrderedDict()
            live_beam_ids = []
            new_hyp_scores = []
            live_prev_hyp_ids = []
            new_hyp_var_name_ids = []
            new_hyp_ast_ids = []
            new_hyp_variable_ptrs = []
            is_same_variable_mask = []
            beam_start_hyp_pos = 0
            for beam_id, (ast_id, beam) in enumerate(beams.items()):
                beam_end_hyp_pos = beam_start_hyp_pos + len(beam)
                # (live_beam_size, vocab_size)
                beam_cont_cand_hyp_scores = cont_cand_hyp_scores[
                    beam_start_hyp_pos:beam_end_hyp_pos]
                # Slots taken by completed hypotheses are not re-expanded.
                cont_beam_size = beam_size - len(completed_hyps[ast_id])
                beam_new_hyp_scores, beam_new_hyp_positions = torch.topk(
                    beam_cont_cand_hyp_scores.view(-1),
                    k=cont_beam_size,
                    dim=-1)

                # (cont_beam_size)
                # Positions index the flattened (live_beam_size, vocab_size)
                # matrix. Use floor division: in current PyTorch ``/`` is true
                # division even for integer tensors, and a float id could not
                # index ``beam`` below.
                beam_prev_hyp_ids = beam_new_hyp_positions // tgt_vocab_size
                beam_hyp_var_name_ids = beam_new_hyp_positions % tgt_vocab_size

                _prev_hyp_ids = beam_prev_hyp_ids.cpu()
                _hyp_var_name_ids = beam_hyp_var_name_ids.cpu()
                _new_hyp_scores = beam_new_hyp_scores.cpu()

                for i in range(cont_beam_size):
                    prev_hyp_id = _prev_hyp_ids[i].item()
                    prev_hyp = beam[prev_hyp_id]
                    hyp_var_name_id = _hyp_var_name_ids[i].item()
                    new_hyp_score = _new_hyp_scores[i].item()

                    variable_ptr = prev_hyp.variable_ptr
                    if hyp_var_name_id == end_of_variable_id:
                        variable_ptr += 1

                        # remove empty cases (a name with no sub-tokens)
                        if len(prev_hyp.variable_list
                               ) == 0 or prev_hyp.variable_list[
                                   -1] == end_of_variable_id:
                            continue

                    new_hyp = self.Hypothesis(
                        variable_list=list(prev_hyp.variable_list) +
                        [hyp_var_name_id],
                        variable_ptr=variable_ptr,
                        score=new_hyp_score)

                    if variable_ptr == variable_nums[ast_id]:
                        completed_hyps[ast_id].append(new_hyp)
                    else:
                        new_beams.setdefault(ast_id, []).append(new_hyp)
                        live_beam_ids.append(beam_id)
                        new_hyp_scores.append(new_hyp_score)
                        live_prev_hyp_ids.append(beam_start_hyp_pos +
                                                 prev_hyp_id)
                        new_hyp_var_name_ids.append(hyp_var_name_id)
                        new_hyp_ast_ids.append(ast_id)
                        new_hyp_variable_ptrs.append(variable_ptr)
                        # 0 when this step moved on to a new variable.
                        is_same_variable_mask.append(
                            1. if prev_hyp.variable_ptr ==
                            variable_ptr else 0.)

                beam_start_hyp_pos = beam_end_hyp_pos

            if not live_beam_ids:
                break

            hyp_scores_tm1 = torch.tensor(new_hyp_scores, device=self.device)
            h_tm1 = (h_t[0][live_prev_hyp_ids], h_t[1][live_prev_hyp_ids])
            att_tm1 = q_t[live_prev_hyp_ids]

            variable_name_embed_tm1 = self.state2names.weight[
                new_hyp_var_name_ids]
            hyp_ast_ids_t = new_hyp_ast_ids
            hyp_variable_ptrs_t = new_hyp_variable_ptrs

            beams = new_beams

            if self.independent_prediction_for_each_variable:
                # Reset the recurrent inputs of hypotheses that just started
                # a new variable.
                is_same_variable_mask = torch.tensor(
                    is_same_variable_mask,
                    device=self.device,
                    dtype=torch.float).unsqueeze(-1)
                h_tm1 = (h_tm1[0] * is_same_variable_mask,
                         h_tm1[1] * is_same_variable_mask)
                att_tm1 = att_tm1 * is_same_variable_mask
                variable_name_embed_tm1 = variable_name_embed_tm1 * is_same_variable_mask

        variable_rename_results = []
        for i, hyps in enumerate(completed_hyps):
            variable_rename_result = dict()
            ast = examples[i].ast
            hyps = sorted(hyps, key=lambda hyp: -hyp.score)

            if not hyps:
                # return identity renamings
                print(
                    f'Failed to found a hypothesis for function {ast.compilation_unit}',
                    file=sys.stderr)
                for old_name in ast.variables:
                    variable_rename_result[old_name] = {
                        'new_name': old_name,
                        'prob': 0.
                    }
            else:
                top_hyp = hyps[0]
                sub_token_ptr = 0
                for old_name in ast.variables:
                    # Slice out this variable's sub-tokens, up to and
                    # including its closing end-of-variable token.
                    sub_token_begin = sub_token_ptr
                    while top_hyp.variable_list[
                            sub_token_ptr] != end_of_variable_id:
                        sub_token_ptr += 1
                    sub_token_ptr += 1  # point to first sub-token of next variable
                    sub_token_end = sub_token_ptr

                    var_name_token_ids = top_hyp.variable_list[
                        sub_token_begin:sub_token_end]  # include ending </s>
                    if var_name_token_ids == [
                            same_variable_id, end_of_variable_id
                    ]:
                        new_var_name = old_name
                    else:
                        new_var_name = self.vocab.target.subtoken_model.decode_ids(
                            var_name_token_ids)

                    variable_rename_result[old_name] = {
                        'new_name': new_var_name,
                        'prob': top_hyp.score
                    }

            variable_rename_results.append(variable_rename_result)

        return variable_rename_results
    def predict(self, examples: List[Example], encoder: Encoder) -> List[Any]:
        """Beam-search new variable names and return every complete hypothesis.

        Names are decoded sub-token by sub-token. A hypothesis'
        ``variable_list`` is a list of per-variable sub-token lists, each
        closed by ``END_OF_VARIABLE_TOKEN``, followed by the (possibly empty)
        list currently being decoded; ``variable_ptr`` counts the closed
        names. The decoder attends over per-hypothesis rows of
        ``self.get_attention_memory(context_encoding)``.

        When ``config['remove_duplicates_in_prediction']`` is set, an
        extension that closes a name equal to an earlier name of the same
        hypothesis (other than the "keep original name" sequence) is dropped.

        Args:
            examples: batch of examples; ``example.ast.variables`` lists the
                variables to rename, in decoding order.
            encoder: module producing the context encoding of the batch.

        Returns:
            One list per example. If the example has complete hypotheses, the
            list holds one rename dict per hypothesis, ordered by descending
            cumulative log-softmax score; each dict maps an old variable name
            to ``{'new_name': ..., 'prob': <hypothesis score>}``. Otherwise
            the list holds a single identity-renaming dict with ``prob`` 0.
        """
        batch_size = len(examples)
        beam_size = self.config['beam_size']
        same_variable_id = self.vocab.target[SAME_VARIABLE_TOKEN]
        end_of_variable_id = self.vocab.target[END_OF_VARIABLE_TOKEN]
        remove_duplicate = self.config['remove_duplicates_in_prediction']
        # Decoded name meaning "keep the original name".
        same_variable_name = [same_variable_id, end_of_variable_id]

        # Number of names a hypothesis of each example has to produce.
        variable_nums = [len(example.ast.variables) for example in examples]

        # One beam per example, seeded with a hypothesis holding one empty name.
        beams = OrderedDict((ast_id, [self.Hypothesis([[]], 0, 0.)])
                            for ast_id in range(batch_size))
        hyp_scores_tm1 = torch.zeros(len(beams), device=self.device)
        completed_hyps = [[] for _ in range(batch_size)]
        tgt_vocab_size = len(self.vocab.target)

        tensor_dict = self.batcher.to_tensor_dict(examples)
        nn_util.to(tensor_dict, self.device)

        context_encoding = encoder(tensor_dict)
        # prepare tensors for attention
        attention_memory = self.get_attention_memory(context_encoding)
        context_encoding_t = attention_memory

        h_tm1 = self.get_init_state(context_encoding)

        # Note that we are using the `restoration_indices` from `context_encoding`, which is the word-level restoration index
        # (batch_size, variable_master_node_num, encoding_size)
        variable_encoding = context_encoding['variable_encoding']
        # (batch_size, encoding_size)
        variable_name_embed_tm1 = att_tm1 = torch.zeros(
            batch_size, self.lstm_cell.hidden_size, device=self.device)

        max_prediction_time_step = self.config['max_prediction_time_step']
        for t in range(0, max_prediction_time_step):
            # (total_live_hyp_num, encoding_size): encoding of the variable
            # each live hypothesis is currently naming.
            if t > 0:
                variable_encoding_t = variable_encoding[hyp_ast_ids_t,
                                                        hyp_variable_ptrs_t]
            else:
                variable_encoding_t = variable_encoding[:, 0]

            if self.config['input_feed']:
                x = torch.cat(
                    [variable_encoding_t, variable_name_embed_tm1, att_tm1],
                    dim=-1)
            else:
                x = torch.cat([variable_encoding_t, variable_name_embed_tm1],
                              dim=-1)

            h_t, q_t, _ = self.rnn_step(x, h_tm1, context_encoding_t)

            # (total_live_hyp_num, vocab_size)
            hyp_var_name_scores_t = torch.log_softmax(self.state2names(q_t),
                                                      dim=-1)

            # Cumulative score of every one-token extension of every live
            # hypothesis.
            cont_cand_hyp_scores = hyp_scores_tm1.unsqueeze(
                -1) + hyp_var_name_scores_t

            new_beams = OrderedDict()
            live_beam_ids = []
            new_hyp_scores = []
            live_prev_hyp_ids = []
            new_hyp_var_name_ids = []
            new_hyp_ast_ids = []
            new_hyp_variable_ptrs = []
            is_same_variable_mask = []
            beam_start_hyp_pos = 0
            for beam_id, (ast_id, beam) in enumerate(beams.items()):
                beam_end_hyp_pos = beam_start_hyp_pos + len(beam)
                # (live_beam_size, vocab_size)
                beam_cont_cand_hyp_scores = cont_cand_hyp_scores[
                    beam_start_hyp_pos:beam_end_hyp_pos]
                # Slots taken by completed hypotheses are not re-expanded.
                cont_beam_size = beam_size - len(completed_hyps[ast_id])
                # Take `len(beam)` more candidates to account for possible duplicates
                k = min(beam_cont_cand_hyp_scores.numel(),
                        cont_beam_size + len(beam))
                beam_new_hyp_scores, beam_new_hyp_positions = torch.topk(
                    beam_cont_cand_hyp_scores.view(-1), k=k, dim=-1)

                # (k)
                # Positions index the flattened (live_beam_size, vocab_size)
                # matrix. Use floor division: in current PyTorch ``/`` is true
                # division even for integer tensors, and a float id could not
                # index ``beam`` below.
                beam_prev_hyp_ids = beam_new_hyp_positions // tgt_vocab_size
                beam_hyp_var_name_ids = beam_new_hyp_positions % tgt_vocab_size

                _prev_hyp_ids = beam_prev_hyp_ids.cpu()
                _hyp_var_name_ids = beam_hyp_var_name_ids.cpu()
                _new_hyp_scores = beam_new_hyp_scores.cpu()

                # Number of candidates accepted for this beam so far.
                beam_cnt = 0
                for i in range(len(beam_new_hyp_positions)):
                    prev_hyp_id = _prev_hyp_ids[i].item()
                    prev_hyp = beam[prev_hyp_id]
                    hyp_var_name_id = _hyp_var_name_ids[i].item()
                    new_hyp_score = _new_hyp_scores[i].item()

                    variable_ptr = prev_hyp.variable_ptr
                    # Copy the outer list and the name being extended; earlier
                    # closed names are never mutated, so they can be shared.
                    new_variable_list = list(prev_hyp.variable_list)
                    new_variable_list[-1] = list(new_variable_list[-1] +
                                                 [hyp_var_name_id])

                    if hyp_var_name_id == end_of_variable_id:
                        # remove empty cases
                        if new_variable_list[-1] == [end_of_variable_id]:
                            continue

                        if remove_duplicate:
                            last_pred = new_variable_list[-1]
                            if any(x == last_pred
                                   for x in new_variable_list[:-1]
                                   if x != same_variable_name):
                                continue

                        variable_ptr += 1
                        new_variable_list.append([])

                    beam_cnt += 1
                    new_hyp = self.Hypothesis(variable_list=new_variable_list,
                                              variable_ptr=variable_ptr,
                                              score=new_hyp_score)

                    if variable_ptr == variable_nums[ast_id]:
                        completed_hyps[ast_id].append(new_hyp)
                    else:
                        new_beams.setdefault(ast_id, []).append(new_hyp)
                        live_beam_ids.append(beam_id)
                        new_hyp_scores.append(new_hyp_score)
                        live_prev_hyp_ids.append(beam_start_hyp_pos +
                                                 prev_hyp_id)
                        new_hyp_var_name_ids.append(hyp_var_name_id)
                        new_hyp_ast_ids.append(ast_id)
                        new_hyp_variable_ptrs.append(variable_ptr)
                        # 0 when this step moved on to a new variable.
                        is_same_variable_mask.append(
                            1. if prev_hyp.variable_ptr ==
                            variable_ptr else 0.)

                    if beam_cnt >= cont_beam_size:
                        break

                beam_start_hyp_pos = beam_end_hyp_pos

            if not live_beam_ids:
                break

            hyp_scores_tm1 = torch.tensor(new_hyp_scores, device=self.device)
            h_tm1 = (h_t[0][live_prev_hyp_ids], h_t[1][live_prev_hyp_ids])
            att_tm1 = q_t[live_prev_hyp_ids]

            if self.config['tie_embedding']:
                variable_name_embed_tm1 = self.state2names.weight[
                    new_hyp_var_name_ids]
            else:
                variable_name_embed_tm1 = self.var_name_embed.weight[
                    new_hyp_var_name_ids]

            hyp_ast_ids_t = new_hyp_ast_ids
            hyp_variable_ptrs_t = new_hyp_variable_ptrs

            beams = new_beams

            # (total_hyp_num, max_tree_size, node_encoding_size)
            context_encoding_t = dict(
                attention_key=attention_memory['attention_key']
                [hyp_ast_ids_t],
                attention_value=attention_memory['attention_value']
                [hyp_ast_ids_t],
                attention_value_mask=attention_memory[
                    'attention_value_mask'][hyp_ast_ids_t])

            if self.independent_prediction_for_each_variable:
                # Reset the recurrent inputs of hypotheses that just started
                # a new variable.
                is_same_variable_mask = torch.tensor(
                    is_same_variable_mask,
                    device=self.device,
                    dtype=torch.float).unsqueeze(-1)
                h_tm1 = (h_tm1[0] * is_same_variable_mask,
                         h_tm1[1] * is_same_variable_mask)
                att_tm1 = att_tm1 * is_same_variable_mask
                variable_name_embed_tm1 = variable_name_embed_tm1 * is_same_variable_mask

        variable_rename_results = []
        for i, hyps in enumerate(completed_hyps):
            ast = examples[i].ast
            hyps = sorted(hyps, key=lambda hyp: -hyp.score)

            if not hyps:
                # return identity renamings
                print(
                    f'Failed to find a hypothesis for function {ast.compilation_unit}',
                    file=sys.stderr)
                variable_rename_result = dict()
                for old_name in ast.variables:
                    variable_rename_result[old_name] = {
                        'new_name': old_name,
                        'prob': 0.
                    }

                example_rename_results = [variable_rename_result]
            else:
                example_rename_results = []

                for hyp in hyps:
                    variable_rename_result = dict()
                    for var_id, old_name in enumerate(ast.variables):
                        var_name_token_ids = hyp.variable_list[var_id]
                        if var_name_token_ids == same_variable_name:
                            new_var_name = old_name
                        else:
                            new_var_name = self.vocab.target.subtoken_model.decode_ids(
                                var_name_token_ids)

                        variable_rename_result[old_name] = {
                            'new_name': new_var_name,
                            'prob': hyp.score
                        }

                    example_rename_results.append(variable_rename_result)

            variable_rename_results.append(example_rename_results)

        return variable_rename_results
示例#6
0
    def predict(self, source_asts: List[AbstractSyntaxTree],
                encoder: Encoder) -> List[Dict]:
        """Beam-search one whole-word name per variable for a batch of ASTs.

        Decoding step ``t`` predicts the name of variable ``t`` of every live
        AST, so an AST's hypotheses are complete after ``len(ast.variables)``
        steps. After the first step every live beam holds exactly
        ``beam_size`` hypotheses, so the search runs on dense
        ``(live_beam_num, beam_size, ...)`` tensors.

        Args:
            source_asts: batch of ASTs; ``ast.variables`` lists the variables
                to rename, in decoding order.
            encoder: module producing the context encoding of the batch.

        Returns:
            One dict per AST mapping each old variable name to
            ``{'new_name': ..., 'prob': ...}`` from the best hypothesis
            (``prob`` is its cumulative log-softmax score).
            ``SAME_VARIABLE_TOKEN`` predictions keep the original name. An AST
            for which no hypothesis completed (e.g. one with no variables)
            gets identity renamings with ``prob`` 0.
        """
        beam_size = self.config['beam_size']
        unk_replace = self.config['unk_replace']

        variable_nums = [len(ast.variables) for ast in source_asts]

        # One beam per AST, seeded with a single empty hypothesis.
        beams = OrderedDict((ast_id, [self.Hypothesis([], 0.)])
                            for ast_id, _ in enumerate(source_asts))
        hyp_scores_tm1 = torch.zeros(len(beams), 1, device=self.device)
        completed_hyps = [[] for _ in source_asts]
        tgt_vocab_size = len(self.vocab.target)

        tensor_dict = self.batcher.to_tensor_dict(source_asts=source_asts)
        nn_util.to(tensor_dict, self.device)

        context_encoding = encoder(tensor_dict)
        h_tm1 = self.get_init_state(context_encoding)

        # (prediction_node_num, encoding_size)
        variable_master_node_encoding = context_encoding[
            'variable_master_node_encoding']
        encoding_size = variable_master_node_encoding.size(1)
        # (batch_size, max_prediction_node_num)
        variable_master_node_restoration_indices = context_encoding[
            'variable_encoding_restoration_indices']

        # (batch_size, max_prediction_node_num, encoding_size)
        variable_master_node_encoding = variable_master_node_encoding[
            variable_master_node_restoration_indices]
        variable_encoding_t = variable_master_node_encoding[:, 0]
        # (batch_size, encoding_size)
        variable_name_embed_tm1 = att_tm1 = torch.zeros(len(source_asts),
                                                        encoding_size,
                                                        device=self.device)

        max_prediction_node_num = variable_master_node_encoding.size(1)

        for t in range(0, max_prediction_node_num):
            # Before the first expansion each beam holds a single hypothesis.
            live_beam_size = beam_size if t > 0 else 1
            live_tree_ids = [ast_id for ast_id in beams]

            if t > 0:
                # Encoding of variable t, repeated for each hypothesis of a beam.
                variable_encoding_t = variable_master_node_encoding[live_tree_ids][:, t]\
                    .unsqueeze(1).expand(-1, beam_size, -1).contiguous().view(-1, encoding_size)

            if self.config['input_feed']:
                x = torch.cat(
                    [variable_encoding_t, variable_name_embed_tm1, att_tm1],
                    dim=-1)
            else:
                x = torch.cat([variable_encoding_t, variable_name_embed_tm1],
                              dim=-1)

            h_t, q_t, _ = self.rnn_step(x, h_tm1, context_encoding)

            # (live_beam_num, beam_size, encoding_size)
            q_t_by_beam = q_t.view(len(beams), -1, q_t.size(-1))
            # (live_beam_num, beam_size, vocab_size)
            hyp_var_name_scores_t = torch.log_softmax(
                self.state2names(q_t_by_beam), dim=-1)

            if unk_replace:
                # Never predict the unknown-word token.
                hyp_var_name_scores_t[:, :,
                                      self.vocab.target['<unk>']] = float(
                                          '-inf')

            cont_cand_hyp_scores = hyp_scores_tm1.unsqueeze(
                -1) + hyp_var_name_scores_t

            # (live_beam_num, beam_size)
            new_hyp_scores, new_hyp_position_list = torch.topk(
                cont_cand_hyp_scores.view(len(beams), -1), k=beam_size, dim=-1)

            # (live_beam_num, beam_size)
            # Positions index the flattened (live_beam_size, vocab_size)
            # matrix of each beam. Use floor division: in current PyTorch
            # ``/`` is true division even for integer tensors, and float ids
            # could neither index ``beam`` nor the hidden states below.
            prev_hyp_ids = new_hyp_position_list // tgt_vocab_size
            hyp_var_name_ids = new_hyp_position_list % tgt_vocab_size

            # move these tensors to cpu for fast indexing
            _prev_hyp_ids = prev_hyp_ids.cpu()
            _hyp_var_name_ids = hyp_var_name_ids.cpu()
            _new_hyp_scores = new_hyp_scores.cpu()

            new_beams = OrderedDict()
            live_beam_ids = []
            for beam_id, (ast_id, beam) in enumerate(beams.items()):
                new_hyps = []
                for i in range(beam_size):
                    prev_hyp_id = _prev_hyp_ids[beam_id, i].item()
                    prev_hyp = beam[prev_hyp_id]
                    hyp_var_name_id = _hyp_var_name_ids[beam_id, i].item()
                    new_hyp_score = _new_hyp_scores[beam_id, i].item()

                    new_hyp = self.Hypothesis(
                        variable_list=list(prev_hyp.variable_list) +
                        [hyp_var_name_id],
                        score=new_hyp_score)
                    new_hyps.append(new_hyp)
                if t + 1 == variable_nums[ast_id]:
                    completed_hyps[ast_id] = new_hyps
                else:
                    new_beams[ast_id] = new_hyps
                    live_beam_ids.append(beam_id)

            if t < max_prediction_node_num - 1:
                # (live_beam_num, beam_size, *)
                # Convert per-beam hypothesis ids into row indices of the
                # flattened (live_beam_num * live_beam_size) state tensors.
                prev_hyp_ids = (torch.arange(len(beams)).to(self.device) *
                                live_beam_size).unsqueeze(-1) + prev_hyp_ids
                live_prev_hyp_ids = prev_hyp_ids[live_beam_ids].view(-1)
                live_beam_ids = torch.tensor(live_beam_ids, device=self.device)

                hyp_scores_tm1 = new_hyp_scores[live_beam_ids]
                h_tm1 = (h_t[0][live_prev_hyp_ids], h_t[1][live_prev_hyp_ids])
                att_tm1 = q_t[live_prev_hyp_ids]

                # (live_beam_num * beam_size)
                live_hyp_var_name_ids = hyp_var_name_ids[live_beam_ids].view(
                    -1)
                # (live_beam_num * beam_size, embed_size)
                variable_name_embed_tm1 = self.state2names.weight[
                    live_hyp_var_name_ids]

                beams = new_beams

        variable_rename_results = []
        for i, hyps in enumerate(completed_hyps):
            ast = source_asts[i]
            variable_rename_result = dict()

            if not hyps:
                # No hypothesis completed for this AST: keep every name
                # instead of failing on ``hyps[0]``.
                for old_name in ast.variables:
                    variable_rename_result[old_name] = {
                        'new_name': old_name,
                        'prob': 0.
                    }
                variable_rename_results.append(variable_rename_result)
                continue

            hyps = sorted(hyps, key=lambda hyp: -hyp.score)
            top_hyp = hyps[0]
            for old_name, var_name_id in zip(ast.variables,
                                             top_hyp.variable_list):
                new_var_name = self.vocab.target.id2word[var_name_id]
                if new_var_name == SAME_VARIABLE_TOKEN:
                    new_var_name = old_name

                variable_rename_result[old_name] = {
                    'new_name': new_var_name,
                    'prob': top_hyp.score
                }

            variable_rename_results.append(variable_rename_result)

        return variable_rename_results
示例#7
0
def train(args):
    work_dir = args['--work-dir']
    config = json.loads(_jsonnet.evaluate_file(args['CONFIG_FILE']))
    config['work_dir'] = work_dir

    if not os.path.exists(work_dir):
        print(f'creating work dir [{work_dir}]', file=sys.stderr)
        os.makedirs(work_dir)

    if args['--extra-config']:
        extra_config = args['--extra-config']
        extra_config = json.loads(extra_config)
        config = util.update(config, extra_config)

    json.dump(config,
              open(os.path.join(work_dir, 'config.json'), 'w'),
              indent=2)

    model = RenamingModel.build(config)
    config = model.config
    model.train()

    if args['--cuda']:
        model = model.cuda()

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=0.001)
    nn_util.glorot_init(params)

    # set the padding index for embedding layers to zeros
    # model.encoder.var_node_name_embedding.weight[0].fill_(0.)

    train_set = Dataset(config['data']['train_file'])
    dev_set = Dataset(config['data']['dev_file'])
    batch_size = config['train']['batch_size']

    print(f'Training set size {len(train_set)}, dev set size {len(dev_set)}',
          file=sys.stderr)

    # training loop
    train_iter = epoch = cum_examples = 0
    log_every = config['train']['log_every']
    evaluate_every_nepoch = config['train']['evaluate_every_nepoch']
    max_epoch = config['train']['max_epoch']
    max_patience = config['train']['patience']
    cum_loss = 0.
    patience = 0.
    t_log = time.time()

    history_accs = []
    while True:
        # load training dataset, which is a collection of ASTs and maps of gold-standard renamings
        train_set_iter = train_set.batch_iterator(
            batch_size=batch_size,
            return_examples=False,
            config=config,
            progress=True,
            train=True,
            num_readers=config['train']['num_readers'],
            num_batchers=config['train']['num_batchers'])
        epoch += 1

        for batch in train_set_iter:
            train_iter += 1
            optimizer.zero_grad()

            # t1 = time.time()
            nn_util.to(batch.tensor_dict, model.device)
            # print(f'[Learner] {time.time() - t1}s took for moving tensors to device', file=sys.stderr)

            # t1 = time.time()
            result = model(batch.tensor_dict,
                           batch.tensor_dict['prediction_target'])
            # print(f'[Learner] batch {train_iter}, {batch.size} examples took {time.time() - t1:4f}s', file=sys.stderr)

            loss = -result['batch_log_prob'].mean()

            cum_loss += loss.item() * batch.size
            cum_examples += batch.size

            loss.backward()

            # clip gradient
            grad_norm = torch.nn.utils.clip_grad_norm_(params, 5.)

            optimizer.step()
            del loss

            if train_iter % log_every == 0:
                print(
                    f'[Learner] train_iter={train_iter} avg. loss={cum_loss / cum_examples}, '
                    f'{cum_examples} examples ({cum_examples / (time.time() - t_log)} examples/s)',
                    file=sys.stderr)

                cum_loss = cum_examples = 0.
                t_log = time.time()

        print(f'[Learner] Epoch {epoch} finished', file=sys.stderr)

        if epoch % evaluate_every_nepoch == 0:
            print(f'[Learner] Perform evaluation', file=sys.stderr)
            t1 = time.time()
            # ppl = Evaluator.evaluate_ppl(model, dev_set, config, predicate=lambda e: not e['function_body_in_train'])
            eval_results = Evaluator.decode_and_evaluate(
                model, dev_set, config)
            # print(f'[Learner] Evaluation result ppl={ppl} (took {time.time() - t1}s)', file=sys.stderr)
            print(
                f'[Learner] Evaluation result {eval_results} (took {time.time() - t1}s)',
                file=sys.stderr)
            dev_metric = eval_results['func_body_not_in_train_acc']['accuracy']
            # dev_metric = -ppl
            if len(history_accs) == 0 or dev_metric > max(history_accs):
                patience = 0
                model_save_path = os.path.join(work_dir, f'model.bin')
                model.save(model_save_path)
                print(
                    f'[Learner] Saved currently the best model to {model_save_path}',
                    file=sys.stderr)
            else:
                patience += 1
                if patience == max_patience:
                    print(
                        f'[Learner] Reached max patience {max_patience}, exiting...',
                        file=sys.stderr)
                    patience = 0
                    exit()

            history_accs.append(dev_metric)

        if epoch == max_epoch:
            print(f'[Learner] Reached max epoch', file=sys.stderr)
            exit()

        t1 = time.time()