def create_one_code_res(input_seq_name, vocabulary):
    from c_parser.ast_parser import parse_ast_code_graph
    code_graph = parse_ast_code_graph(input_seq_name)
    # +2 accounts for the begin and end tokens wrapped around the sequence.
    input_length = code_graph.graph_length + 2
    in_seq, graph = code_graph.graph
    begin_id = vocabulary.word_to_id(vocabulary.begin_tokens[0])
    end_id = vocabulary.word_to_id(vocabulary.end_tokens[0])
    input_seq = [begin_id] + [vocabulary.word_to_id(t)
                              for t in in_seq] + [end_id]
    # Shift node indices by 1 (position 0 is the begin token) and emit both
    # directions of every edge so the graph is effectively undirected.
    adj = [[a + 1, b + 1] for a, b, _ in graph] + [[b + 1, a + 1]
                                                   for a, b, _ in graph]
    return input_seq, input_length, adj
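

# A minimal sketch (not part of the original pipeline) of the adjacency
# construction used above: node indices shift by 1 to leave room for the
# begin token, and each edge appears in both directions. The toy `graph`
# triples below are hypothetical.
def _demo_adjacency():
    graph = [(0, 1, 'ast'), (1, 2, 'seq')]  # (src, dst, edge_type)
    adj = [[a + 1, b + 1] for a, b, _ in graph] + [[b + 1, a + 1]
                                                   for a, b, _ in graph]
    return adj  # -> [[1, 2], [2, 3], [2, 1], [3, 2]]
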

def parse_ast_node(input_seq_name, vocabulary):
    # input_seq_name is the decoded token-name list, e.g.
    # [vocabulary.id_to_word(token_id) for token_id in one_final_output]
    from c_parser.ast_parser import parse_ast_code_graph
    code_graph = parse_ast_code_graph(input_seq_name)
    input_length = code_graph.graph_length + 2
    in_seq, graph = code_graph.graph
    begin_id = vocabulary.word_to_id(vocabulary.begin_tokens[0])
    end_id = vocabulary.word_to_id(vocabulary.end_tokens[0])
    input_seq = [begin_id] + [vocabulary.word_to_id(t)
                              for t in in_seq] + [end_id]
    adj = [[a + 1, b + 1] for a, b, _ in graph] + [[b + 1, a + 1]
                                                   for a, b, _ in graph]
    return input_seq_name, input_seq, adj, input_length


def generate_one_graph_input(input_token_names, vocabulary, tokenize_fn):
    from c_parser.ast_parser import parse_ast_code_graph
    begin_token = vocabulary.begin_tokens[0]
    end_token = vocabulary.end_tokens[0]
    if tokenize_fn(' '.join(input_token_names), print_info=False) is None:
        # The code fails to tokenize, so no AST graph can be built: fall back
        # to the flat token sequence with a single placeholder edge.
        input_seq = [begin_token
                     ] + input_token_names + ['<Delimiter>', end_token]
        input_length = len(input_seq)
        adj = [[0, 1]]
        is_effect = False
    else:
        code_graph, is_effect = parse_ast_code_graph(input_token_names)
        input_length = code_graph.graph_length + 2
        in_seq, graph = code_graph.graph
        input_seq = [begin_token] + in_seq + [end_token]
        adj = [[a + 1, b + 1] for a, b, _ in graph] + [[b + 1, a + 1]
                                                       for a, b, _ in graph]
    return input_seq, input_length, adj, is_effect
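

# A minimal usage sketch (hypothetical, not from the original code base) that
# exercises the fallback branch of generate_one_graph_input. `_StubVocabulary`
# and `_failing_tokenize` are illustrative stand-ins only.
class _StubVocabulary:
    begin_tokens = ['<BEGIN>', '<InnerBegin>']
    end_tokens = ['<END>', '<InnerEnd>']


def _failing_tokenize(code, print_info=False):
    return None  # simulate code that cannot be tokenized


def _demo_generate_one_graph_input():
    seq, length, adj, is_effect = generate_one_graph_input(
        ['int', 'main', '(', ')'], _StubVocabulary(), _failing_tokenize)
    # seq == ['<BEGIN>', 'int', 'main', '(', ')', '<Delimiter>', '<END>']
    # length == 7, adj == [[0, 1]], is_effect is False
    return seq, length, adj, is_effect
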
    def _get_raw_sample(self, row):
        # Select one error variant for this row (the first if only_first is set).
        row.select_random_i(only_first=self.only_first)
        sample = {}
        sample['id'] = row['id']
        sample['includes'] = row['includes']
        sample['input_seq'] = row['error_token_id_list']
        # Strip the begin/end markers from the token-name list.
        sample['input_seq_name'] = row['error_token_name_list'][1:-1]
        sample['input_length'] = len(sample['input_seq'])
        sample['copy_length'] = sample['input_length']
        # Placeholder; replaced by the real adjacency list when use_ast is set.
        sample['adj'] = 0

        inner_begin_id = self.vocabulary.word_to_id(
            self.vocabulary.begin_tokens[1])
        inner_end_id = self.vocabulary.word_to_id(
            self.vocabulary.end_tokens[1])
        if not self.do_multi_step_sample:
            sample['target'] = [inner_begin_id
                                ] + row['sample_ac_id_list'] + [inner_end_id]

            sample['is_copy_target'] = row['is_copy_list'] + [0]
            sample['copy_target'] = row['copy_pos_list'] + [-1]

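            # sample_mask lists the full-vocabulary ids this sample may emit;
            # sample_mask_dict re-indexes them into a dense local id space.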
            sample_mask = sorted(row['sample_mask_list'] + [inner_end_id])
            sample_mask_dict = {v: i for i, v in enumerate(sample_mask)}
            sample['compatible_tokens'] = [
                sample_mask for i in range(len(sample['is_copy_target']))
            ]
            sample['compatible_tokens_length'] = [
                len(one) for one in sample['compatible_tokens']
            ]

            sample['sample_target'] = row['sample_ac_id_list'] + [inner_end_id]
            sample['sample_target'] = [
                t if c == 0 else -1 for c, t in zip(sample['is_copy_target'],
                                                    sample['sample_target'])
            ]
            sample['sample_small_target'] = [
                sample_mask_dict[t] if c == 0 else -1 for c, t in zip(
                    sample['is_copy_target'], sample['sample_target'])
            ]
            sample['sample_outputs_length'] = len(sample['sample_target'])

            sample['full_output_target'] = row['target_ac_token_id_list'][1:-1]

            sample['final_output'] = row['ac_code_ids']
            sample['final_output_name'] = row['ac_code_name_with_labels'][1:-1]
            sample['p1_target'] = row['error_pos_list'][0]
            sample['p2_target'] = row['error_pos_list'][1]
            sample['error_pos_list'] = row['error_pos_list']

            sample['distance'] = row['distance']
        else:
            pass

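        # When use_ast is set, rebuild input_seq, input_length and adj from
        # the parsed code graph instead of the flat token sequence.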
        if self.use_ast:
            code_graph = parse_ast_code_graph(sample['input_seq_name'])
            sample['input_length'] = code_graph.graph_length + 2
            in_seq, graph = code_graph.graph
            begin_id = self.vocabulary.word_to_id(
                self.vocabulary.begin_tokens[0])
            end_id = self.vocabulary.word_to_id(self.vocabulary.end_tokens[0])
            sample['input_seq'] = [begin_id] + [
                self.vocabulary.word_to_id(t) for t in in_seq
            ] + [end_id]
            sample['adj'] = [[a + 1, b + 1]
                             for a, b, _ in graph] + [[b + 1, a + 1]
                                                      for a, b, _ in graph]

        return sample
    @staticmethod
    def test_parse_ast_code_graph(seq_name):
        # Helper: graph length of a token-name sequence (begin/end stripped).
        try:
            return parse_ast_code_graph(seq_name[1:-1]).graph_length
        except Exception:
            # Sentinel length returned when the code cannot be parsed.
            return 710
    def _get_raw_sample(self, row):
        sample = {}
        sample['id'] = row['id']
        sample['includes'] = row['includes']
        if not self.is_flatten and self.do_multi_step_sample:
            sample['input_seq'] = row['error_token_id_list'][0]
            sample['input_seq_name'] = row['error_token_name_list'][0][1:-1]
            sample['input_length'] = len(sample['input_seq'])
        elif not self.is_flatten and not self.do_multi_step_sample:
            sample['input_seq'] = row['error_token_id_list']
            sample['input_seq_name'] = [
                r[1:-1] for r in row['error_token_name_list']
            ]
            sample['input_length'] = [len(ids) for ids in sample['input_seq']]
        else:
            sample['input_seq'] = row['error_token_id_list']
            sample['input_seq_name'] = row['error_token_name_list'][1:-1]
            sample['input_length'] = len(sample['input_seq'])
        sample['copy_length'] = sample['input_length']
        sample['last_input_seq_name'] = sample['input_seq_name']

        inner_begin_id = self.vocabulary.word_to_id(
            self.vocabulary.begin_tokens[1])
        inner_end_id = self.vocabulary.word_to_id(
            self.vocabulary.end_tokens[1])
        if not self.do_multi_step_sample:

            if not self.is_flatten:
                sample['is_copy_target'] = [
                    one + [0] for one in row['is_copy_list']
                ]
                sample['copy_target'] = [
                    one + [-1] for one in row['copy_pos_list']
                ]

                sample['sample_target'] = [
                    one + [inner_end_id] for one in row['sample_ac_id_list']
                ]
                sample['sample_outputs_length'] = [
                    len(ids) for ids in sample['sample_target']
                ]
                sample['target'] = [
                    one + [inner_end_id] for one in row['sample_ac_id_list']
                ]

                # Unzip the per-sample [start, end] pairs into two
                # parallel sequences.
                error_start_pos_list, error_end_pos_list = zip(
                    *row['error_pos_list'])
                sample['p1_target'] = error_start_pos_list
                sample['p2_target'] = error_end_pos_list
                sample['error_pos_list'] = row['error_pos_list']

                sample['compatible_tokens'] = [
                    row['sample_mask_list'] + [inner_end_id]
                    for i in range(len(sample['sample_target']))
                ]
                sample['compatible_tokens_length'] = [
                    len(one) for one in sample['compatible_tokens']
                ]

                sample['distance'] = row['distance']
                sample['adj'] = 0
            else:
                sample['target'] = [
                    inner_begin_id
                ] + row['sample_ac_id_list'] + [inner_end_id]

                sample['is_copy_target'] = row['is_copy_list'] + [0]
                sample['copy_target'] = row['copy_pos_list'] + [-1]

                sample_mask = sorted(row['sample_mask_list'] + [inner_end_id])
                sample_mask_dict = {v: i for i, v in enumerate(sample_mask)}
                sample['compatible_tokens'] = [
                    sample_mask for i in range(len(sample['is_copy_target']))
                ]
                sample['compatible_tokens_length'] = [
                    len(one) for one in sample['compatible_tokens']
                ]

                sample['sample_target'] = row['sample_ac_id_list'] + [
                    inner_end_id
                ]
                sample['sample_target'] = [
                    t if c == 0 else -1 for c, t in zip(
                        sample['is_copy_target'], sample['sample_target'])
                ]
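                # only_sample mode folds each copied position's actual input
                # token back into sample_target and disables the copy path.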
                if self.only_sample:
                    sample['sample_target'] = [
                        st if c == 0 else sample['input_seq'][t] for c, t, st
                        in zip(sample['is_copy_target'], sample['copy_target'],
                               sample['sample_target'])
                    ]
                    sample['is_copy_target'] = [0] * len(
                        sample['is_copy_target'])
                sample['sample_small_target'] = [
                    sample_mask_dict[t] if c == 0 else -1 for c, t in zip(
                        sample['is_copy_target'], sample['sample_target'])
                ]
                sample['sample_outputs_length'] = len(sample['sample_target'])

                sample['full_output_target'] = row['target_ac_token_id_list'][
                    1:-1]

                sample['final_output'] = row['ac_code_ids']
                sample['p1_target'] = row['error_pos_list'][0]
                sample['p2_target'] = row['error_pos_list'][1]
                sample['error_pos_list'] = row['error_pos_list']

                sample['distance'] = row['distance']
                sample['adj'] = 0
        else:
            sample['copy_length'] = sample['input_length']
            sample['error_count'] = row['distance']
            sample['adj'] = 0

        if self.use_ast:
            code_graph = parse_ast_code_graph(sample['input_seq_name'])
            sample['input_length'] = code_graph.graph_length + 2
            in_seq, graph = code_graph.graph
            begin_id = self.vocabulary.word_to_id(
                self.vocabulary.begin_tokens[0])
            end_id = self.vocabulary.word_to_id(self.vocabulary.end_tokens[0])
            sample['input_seq'] = [begin_id] + [
                self.vocabulary.word_to_id(t) for t in in_seq
            ] + [end_id]
            sample['adj'] = [[a + 1, b + 1]
                             for a, b, _ in graph] + [[b + 1, a + 1]
                                                      for a, b, _ in graph]

        return sample
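

# A minimal sketch (toy ids, not from the original code base) of the
# sample-mask trick used in _get_raw_sample: the allowed full-vocabulary ids
# are sorted and re-indexed so the decoder predicts over a small dense space.
def _demo_sample_mask_mapping():
    inner_end_id = 3
    sample_mask_list = [17, 5, 42]  # toy full-vocabulary ids
    sample_mask = sorted(sample_mask_list + [inner_end_id])  # [3, 5, 17, 42]
    sample_mask_dict = {v: i for i, v in enumerate(sample_mask)}
    # Full-vocab targets become dense local targets; copied positions stay -1.
    is_copy_target = [0, 1, 0]
    sample_target = [17, -1, 42]
    small_target = [sample_mask_dict[t] if c == 0 else -1
                    for c, t in zip(is_copy_target, sample_target)]
    return small_target  # -> [2, -1, 3]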