Example #1
import os
import javac_parser

# Token categories whose source text varies per occurrence; they have no
# fixed spelling, so they are skipped below.
VARIABLE_TOKEN_TYPES = {
    'IDENTIFIER', 'INTLITERAL', 'DOUBLELITERAL', 'CHARLITERAL',
    'FLOATLITERAL', 'STRINGLITERAL', 'LONGLITERAL', 'ERROR',
}


def token_to_source_dict(rng, data_dir):
    token_to_source = {}
    java = javac_parser.Java()
    dir_list = os.listdir(data_dir)

    rng.shuffle(dir_list)

    count = 0
    for file_event in dir_list:
        if count >= 5000:
            break
        if count % 1000 == 0:
            print(count)

        fail_file_path = os.path.join(data_dir, file_event, '0')
        with open(fail_file_path) as fail:
            fail_code = fail.read()
        # Map each fixed token category (e.g. 'SEMI') to its source text.
        for token in java.lex(fail_code):
            if token[0] not in VARIABLE_TOKEN_TYPES:
                token_to_source[token[0]] = token[1]
        count += 1

    print(token_to_source)
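
For reference, a minimal sketch of what java.lex() yields; it assumes the tuple layout these examples rely on (token category at index 0, source text at index 1):

import javac_parser

java = javac_parser.Java()
for token in java.lex('int x = 1;'):
    # e.g. ('INT', 'int'), ('IDENTIFIER', 'x'), ('INTLITERAL', '1'), ...
    print(token[0], repr(token[1]))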
Example #2
    def run(self):
        java = javac_parser.Java()
        i = 0
        while True:
            # q is a shared work queue of (file_id, fail_event_id,
            # success_event_id) triples; done is a shared counter.
            l = q.get()
            i = done.add(1)
            if i % 100 == 0:
                sys.stdout.write(".")
                sys.stdout.flush()
            file_id = l[0]
            fail_event_id = l[1]
            success_event_id = l[2]
            os.system("mkdir ./java_data/%s_%s" % (fail_event_id, success_event_id))
            os.system("/tools/nccb/bin/print-source-state %s %s > ./java_data/%s_%s/0" % (file_id, fail_event_id, fail_event_id, success_event_id))
            os.system("/tools/nccb/bin/print-source-state %s %s > ./java_data/%s_%s/1" % (file_id, success_event_id, fail_event_id, success_event_id))

            q.task_done()
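
The worker above shells out with os.system() and string interpolation. A hedged alternative sketch using subprocess.run and os.makedirs, assuming the same /tools/nccb/bin/print-source-state binary and output layout:

import os
import subprocess

def dump_source_states(file_id, fail_event_id, success_event_id):
    out_dir = "./java_data/%s_%s" % (fail_event_id, success_event_id)
    os.makedirs(out_dir, exist_ok=True)
    # '0' holds the failing state, '1' the succeeding state, as above.
    for event_id, name in ((fail_event_id, "0"), (success_event_id, "1")):
        with open(os.path.join(out_dir, name), "w") as out:
            subprocess.run(["/tools/nccb/bin/print-source-state",
                            str(file_id), str(event_id)],
                           stdout=out, check=True)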
Example #3
    def java(self):
        """
        Lazily start up the Java server. This decreases the chances of things
        going horribly wrong when two separate processes initialize
        the Java language instance around the same time.
        """
        if not hasattr(self, '_java_server'):
            self._java_server = javac_parser.Java()

            # Py4j usually crashes as Python is cleaning up after exit(), so
            # decrement the server's reference count to lessen the chance of
            # that happening.
            @atexit.register
            def remove_reference():
                del self._java_server

        return self._java_server
Example #4
def classify_error(file_path_list):
    java = javac_parser.Java()
    error_dict = {}
    first_error_dict = {}
    count_dict = {}
    count = 0

    for path in file_path_list:
        count += 1
        with open(path, 'r') as f:
            source_code = f.read()
            try:
                diagnostic_messages = java.check_syntax(source_code)
            except Exception:
                continue
            if len(diagnostic_messages) == 0:
                continue

            # count the distribution of the number of error messages per file
            message_count = 0
            for message in diagnostic_messages:
                if message[0] == 'ERROR':
                    message_count += 1
            if message_count in count_dict:
                count_dict[message_count] += 1
            else:
                count_dict[message_count] = 1

            # TODO: refactor this code

            # only consider the first error message
            if diagnostic_messages[0][0] == 'ERROR':
                if diagnostic_messages[0][1] in first_error_dict:
                    first_error_dict[diagnostic_messages[0][1]] += 1
                else:
                    first_error_dict[diagnostic_messages[0][1]] = 1

            # count all the error messages
            for message in diagnostic_messages:
                if message[0] == 'ERROR':
                    if message[1] in error_dict:
                        error_dict[message[1]] += 1
                    else:
                        error_dict[message[1]] = 1

    return error_dict, first_error_dict, count_dict
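
A usage sketch for classify_error(); the glob pattern is hypothetical:

import glob

error_dict, first_error_dict, count_dict = classify_error(
    glob.glob('./src_files/*.java'))
print('distinct error codes:', len(error_dict))
print('files per error count:', sorted(count_dict.items()))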
Example #5
def send_syntax_check(filepath):
    java = javac_parser.Java()
    with open(filepath, 'r') as f:
        code = f.read()
    errors = java.check_syntax(code)

    # Report only the first ERROR-severity diagnostic, if any.
    for err in errors:
        if err[0].upper() == "ERROR":
            err_msg = err[2]
            line = err[3]
            col = err[4]
            char_ix = err[5]  # unused here; kept to document the tuple layout

            s = f'L{line} "{err_msg}" (col {col})'
            send_message(s)
            return

    send_message(text="OK")
Example #6
    def run(self):
        java = javac_parser.Java()
        tff, tfp = tempfile.mkstemp()
        os.close(tff)
        while True:
            l = q.get()
            i = done.add(1)
            if i % 100 == 0:
                sys.stdout.write(".")
                sys.stdout.flush()
            file_id = l[0]
            fail_event_id = l[1]
            success_event_id = l[2]
            rtn = os.system("/tools/nccb/bin/print-source-state %s %s > %s" %
                            (file_id, fail_event_id, tfp))
            if rtn != 0:
                q.task_done()
                continue
            # Skip files that the java7 lexer cannot tokenize.
            rtn = os.system(
                "../error_recovery_experiment/blackbox/grmtools/target/release/lrlex ../error_recovery_experiment/blackbox/grammars/java7/java.l %s > /dev/null 2> /dev/null"
                % tfp)
            if rtn != 0:
                q.task_done()
                continue

            out = subprocess.check_output(
                ["../error_recovery_experiment/runner/java_parser_none", tfp])
            # check_output() returns bytes on Python 3; decode before matching.
            if "Parsed successfully" in out.decode():
                with open(tfp, mode='r') as code:
                    content = code.read()
                    errors = java.check_syntax(content)
                    if len(errors) > 0:
                        with open('a.txt', mode='a') as f:
                            f.write(file_id)
                            f.write('\n')

                            if errors[0][2] == "= expected":
                                f.write(errors[0][3])
                                f.write(content)
                            f.write(errors[0][2])
                            f.write('\n')
            q.task_done()
Example #7
def find_code_by_error(error_type):
    java = javac_parser.Java()
    src_files_dir = './src_files/'
    for i in os.listdir(src_files_dir):
        with open(src_files_dir + i, 'r') as f:
            source_code = f.read()
            try:
                diagnostic_messages = java.check_syntax(source_code)
            except Exception:
                # Print the name of any file that crashes the checker.
                print(i)
            else:
                if len(diagnostic_messages) == 0:
                    continue
                # print(diagnostic_messages)
                if diagnostic_messages[0][1] == error_type:
                    # print(source_code)
                    # print('----------------------------------------------')
                    # The original left the format argument empty; the file
                    # currently being checked is the obvious candidate.
                    os.system(
                        '/Users/zhouyang/Downloads/error_recovery_experiment/runner/java_parser_none %s'
                        % (src_files_dir + i))
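
A hypothetical call; the error type is whatever check_syntax() reports at index 1, e.g. a javac diagnostic key such as 'compiler.err.expected':

find_code_by_error('compiler.err.expected')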
Example #8
    def __init__(self,
                 tl_dict,
                 seed,
                 correct_fix_reward=1.0,
                 step_penalty=-0.01,
                 top_down_movement=False,
                 reject_spurious_edits=False,
                 compilation_error_store=None,
                 single_delete=True,
                 actions=None,
                 sparse_rewards=True):
        self.rng = np.random.RandomState(seed)
        self.java = javac_parser.Java()

        # self.action_tokens = list('''{}();,.''')
        # Changed action_tokens here to token-category names:
        self.action_tokens = [
            'LPAREN', 'RPAREN', 'COMMA', 'SEMI', 'LBRACE', 'RBRACE', 'DOT'
        ]

        # The four movement actions.
        self.actions = ([] if top_down_movement else ['move_up', 'move_left']
                        ) + ['move_down', 'move_right']

        # Initialize the token-editing actions.
        '''
        Actions originally defined by RLAssist:
        ['move_down', 'move_right', 'insert{', 'insert}', 'insert(', 'insert)', 'insert;',
        'insert,', 'delete', 'replace;with,', 'replace,with;', 'replace.with;', 'replace;)with);']
        '''

        for each in self.action_tokens:
            # self.actions += ['insert' + each] if each != '.' else []
            # Changed to the 'insert TOKEN' format.
            self.actions += ['insert' + ' ' + each] if each != 'DOT' else []
            # Why can this not be '.'?
            # These should all end up as 'delete'.
            self.actions += ['delete' + ' ' +
                             each] if not single_delete else []

        self.actions += ['delete'] if single_delete else []
        self.actions += [
            'replace;with,', 'replace,with;', 'replace.with;',
            'replace;)with);'
        ]

        self.replacement_action_to_action_sequence_map = {
            'replace;)with);':
            ['delete SEMI RPAREN', 'move_right', 'insert SEMI'],
            'replace;with,': ['delete SEMI', 'insert COMMA'],
            'replace,with;': ['delete COMMA', 'insert SEMI'],
            'replace.with;': ['delete DOT', 'insert SEMI']
        }
        self.action_sequence_to_replacement_action_map = {}
        for replacement_action, action_seq in self.replacement_action_to_action_sequence_map.items(
        ):
            self.action_sequence_to_replacement_action_map[''.join(
                action_seq)] = replacement_action

        if actions is not None:
            raise NotImplementedError()

        self.tl_dict = tl_dict
        self.rev_tl_dict = get_rev_dict(tl_dict)
        self.new_line = tl_dict['-new-line-']
        self.pad = tl_dict['_pad_']

        self.cursor = tl_dict['EOF']
        self.normalized_ids_tl_dict, self.id_token_vecs = self.get_normalized_ids_dict(
            tl_dict)
        assert self.cursor == self.normalized_ids_tl_dict['EOF']
        assert self.new_line == self.normalized_ids_tl_dict['-new-line-']
        assert self.pad == self.normalized_ids_tl_dict['_pad_']

        # Restricts which tokens can be modified.
        self.mutables = [self.tl_dict[each] for each in self.action_tokens]

        self.top_down_movement = top_down_movement
        self.reject_spurious_edits = reject_spurious_edits
        self.compilation_error_store = compilation_error_store
        self.single_delete = single_delete

        self.correct_fix_reward = correct_fix_reward
        if sparse_rewards:
            self.step_penalty = step_penalty
            self.edit_penalty = 0.0
            self.error_resolution_reward = 2 * abs(step_penalty) - (
                abs(step_penalty) / 10)
        else:
            self.step_penalty = step_penalty / 2
            self.edit_penalty = self.step_penalty * 5
            self.error_resolution_reward = 2 * (abs(self.edit_penalty) -
                                                abs(self.step_penalty))
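
The action list above is assembled from flags; a self-contained sketch of just that construction for the default flags (a reconstruction for illustration, not part of the original class):

def build_actions(top_down_movement=False, single_delete=True):
    action_tokens = ['LPAREN', 'RPAREN', 'COMMA', 'SEMI', 'LBRACE', 'RBRACE', 'DOT']
    actions = ([] if top_down_movement else ['move_up', 'move_left']) \
        + ['move_down', 'move_right']
    for each in action_tokens:
        actions += ['insert ' + each] if each != 'DOT' else []
        actions += ['delete ' + each] if not single_delete else []
    actions += ['delete'] if single_delete else []
    actions += ['replace;with,', 'replace,with;', 'replace.with;', 'replace;)with);']
    return actions

print(build_actions())  # movement actions, then insert/delete/replace edits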
Example #9
def post_analysis(before_after_program):
    java = javac_parser.Java()
    repair_rate_data = {
        '0~50': [0, 0, 0, 0, 0],
        '50~100': [0, 0, 0, 0, 0],
        '100~200': [0, 0, 0, 0, 0],
        '200~450': [0, 0, 0, 0, 0],
        '450~1000': [0, 0, 0, 0, 0]
    }

    complete_rate_data = {
        '< 4': [0, 0, 0, 0, 0],
        '4~15': [0, 0, 0, 0, 0],
        '15~50': [0, 0, 0, 0, 0],
        '50~100': [0, 0, 0, 0, 0],
        '100~200': [0, 0, 0, 0, 0],
        '> 200': [0, 0, 0, 0, 0]
    }

    original_total_error_num = 0
    repaired_total_error_num = 0

    original_err_msg_list = []
    repaired_err_msg_list = []

    for original_program, repaired_program, token_num, line_num in before_after_program:
        for key in repair_rate_data.keys():
            left = int(key.split('~')[0])
            right = int(key.split('~')[1])
            if left <= token_num <= right:
                repair_rate_data[key][0] += 1

                # use javac_parser.check_syntax() to get all the parse errors
                original_err = java.check_syntax(original_program)

                # get the number of parse errors
                original_error_num = len(original_err)

                # store the errors
                original_err_msg_list.append(original_err)

                # add errors to total
                original_total_error_num += original_error_num

                # do the same thing for repaired programs
                after_err = java.check_syntax(repaired_program)
                repaired_error_num = len(after_err)
                repaired_err_msg_list.append(after_err)
                repaired_total_error_num += repaired_error_num

                # Judge repair type: complete, partial, cascading or dummy (no change)
                is_complete_repair = False
                if repaired_error_num == 0:
                    is_complete_repair = True

                if is_complete_repair:
                    repair_rate_data[key][1] += 1

                is_partial_repair = False
                if original_error_num > repaired_error_num and repaired_error_num > 0:
                    is_partial_repair = True

                if is_partial_repair:
                    repair_rate_data[key][2] += 1

                introduce_cascading_error = False
                if original_error_num < repaired_error_num:
                    introduce_cascading_error = True

                if introduce_cascading_error:
                    repair_rate_data[key][3] += 1

                dummy_repair = False
                if original_error_num == repaired_error_num:
                    dummy_repair = True
                if dummy_repair:
                    repair_rate_data[key][4] += 1

        # store data by lines of code

    tb = pt.PrettyTable()
    tb.field_names = [
        'Token Length Range', 'Number of Files', 'Complete Repair',
        'Partial Repair', 'Cascading Repair', 'Dummy Repair'
    ]
    for key in repair_rate_data.keys():
        tb.add_row([key] + repair_rate_data[key])
    # tb.add_row(['Other', other])
    # tb.add_row(['Total', len(dataset)])

    print(tb)

    # print(complete_rate_data)

    # print('original_total_error_num: %d' % original_total_error_num)
    # print('repaired_total_error_num: %d' % repaired_total_error_num)
    original_err_distribution = count_err_message(original_err_msg_list)
    repaired_err_distribution = count_err_message(repaired_err_msg_list)
    analyze_perf_by_err_type(original_err_distribution,
                             repaired_err_distribution)
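
The four repair categories tallied above reduce to comparing error counts before and after; a small helper (a sketch, not in the original code) makes the classification explicit:

def classify_repair(original_error_num, repaired_error_num):
    # Mirrors the bookkeeping in post_analysis().
    if repaired_error_num == 0:
        return 'complete'
    if repaired_error_num < original_error_num:
        return 'partial'
    if repaired_error_num > original_error_num:
        return 'cascading'
    return 'dummy'

assert classify_repair(3, 0) == 'complete'
assert classify_repair(3, 1) == 'partial'
assert classify_repair(1, 2) == 'cascading'
assert classify_repair(2, 2) == 'dummy'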
Example #10
if args.task == 'typo':
    normalize_names = True
    fix_kind = 'replace'
else:
    assert args.task == 'ids'
    normalize_names = False
    fix_kind = 'insert'

times = []
counts = []
token_lens = []

print('test data length:', len(test_dataset))

java = javac_parser.Java()

before_after_program = []

count = 0

for problem_id, test_programs in test_dataset.items():
    if count % 10000 == 0:
        print('%d programs evaluated' % count)
    count += 1
    sequences_of_programs[problem_id] = {}
    fixes_suggested_by_network[problem_id] = {}
    start = time.time()

    token_seq = test_programs[0][0]
    token_len = len(token_seq.split())
Example #11
def extract_var_names(nodes, bodies, lang):

    if lang == "python":
        pass
    elif lang == "java":
        pass
        import javac_parser
        java = javac_parser.Java()
    else:
        raise ValueError("Valid languages: python, java")

    id_offset = nodes["id"].max() + 1
    bodies = bodies[['id', 'body']].dropna(axis=0)

    if lang == "java":
        nodes = read_nodes(working_directory)
        names = nodes['serialized_name'].apply(
            lambda x: x.split("___")[0].split("."))
        not_local = set()
        for name in names:
            for n in name:
                not_local.add(n)

    variable_names = dict()
    func_var_pairs = []

    for body_ind, (ind,
                   row) in custom_tqdm(enumerate(bodies.iterrows()),
                                       message="Extracting variable names",
                                       total=len(bodies)):
        variables = []
        try:
            if lang == "python":
                tree = ast.parse(row['body'].strip())
                variables.extend([
                    n.id for n in ast.walk(tree) if type(n).__name__ == "Name"
                ])
            elif lang == "java":
                lines = row['body'].strip()  #.split("\n")
                tokens = java.lex(lines)
                variables = [
                    name for tok_type, name, _, _, _ in tokens
                    if tok_type == "IDENTIFIER" and name not in not_local
                ]
            else:
                continue
        except SyntaxError:  # thrown by ast
            continue

        for v in set(variables):
            if v not in variable_names:
                variable_names[v] = id_offset
                id_offset += 1

            func_var_pairs.append((row['id'], v))

        # print(f"\r{body_ind}/{len(bodies)}", end="")
    # print(" " * 30, end ="\r")

    if func_var_pairs:
        counter = Counter(map(lambda x: x[1], func_var_pairs))
        pp = []
        for func, var in func_var_pairs:
            if counter[var] > 1:
                pp.append({'src': func, 'dst': var})
        pairs = pd.DataFrame(pp)
        return pairs
    else:
        return None
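
The Java branch boils down to lexing the body and keeping IDENTIFIER tokens that are not known global names; a self-contained sketch of just that step:

import javac_parser

java = javac_parser.Java()
body = 'int total = base + offset;'
identifiers = [text for kind, text, *_ in java.lex(body) if kind == 'IDENTIFIER']
print(identifiers)  # expected: ['total', 'base', 'offset']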
Example #12
def repair_rate_analysis(result):
    java = javac_parser.Java()
    repair_rate_data = {
        '0~50': [0, 0, 0, 0, 0],
        '50~100': [0, 0, 0, 0, 0],
        '100~200': [0, 0, 0, 0, 0],
        '200~450': [0, 0, 0, 0, 0],
        '450~1000': [0, 0, 0, 0, 0]
    }

    original_total_error_num = 0
    repaired_total_error_num = 0

    original_err_msg_list = []
    repaired_err_msg_list = []

    too_large_count = 0
    zero_count = 0

    for token_num, org_error_list, repair_error_list in result:
        if token_num > 1000:
            too_large_count += 1
        if token_num == 0:
            zero_count += 1
        # The counts should be correct now.
        for key in repair_rate_data.keys():
            left = int(key.split('~')[0])
            right = int(key.split('~')[1])
            if left < token_num <= right:
                repair_rate_data[key][0] += 1

                # get the number of parse errors
                original_error_num = len(org_error_list)

                # store the errors
                original_err_msg_list.append(org_error_list)

                # add errors to total
                original_total_error_num += original_error_num

                # do the same thing for repaired programs
                repaired_error_num = len(repair_error_list)
                repaired_err_msg_list.append(repair_error_list)
                repaired_total_error_num += repaired_error_num

                # Judge repair type: complete, partial, cascading or dummy (no change)
                is_complete_repair = False
                if original_error_num > 0 and repaired_error_num == 0:
                    # The repaired version has zero errors.
                    is_complete_repair = True

                if is_complete_repair:
                    repair_rate_data[key][1] += 1

                is_partial_repair = False
                if original_error_num > repaired_error_num and repaired_error_num > 0:
                    is_partial_repair = True

                if is_partial_repair:
                    repair_rate_data[key][2] += 1

                introduce_cascading_error = False
                if original_error_num < repaired_error_num:
                    introduce_cascading_error = True

                if introduce_cascading_error:
                    repair_rate_data[key][3] += 1

                dummy_repair = False
                if original_error_num == repaired_error_num:
                    # The error counts before and after are the same (and nonzero).
                    dummy_repair = True
                if dummy_repair:
                    repair_rate_data[key][4] += 1

        # store data by lines of code
    print('> 1000: ', too_large_count)
    print('Zero Count: ', zero_count)

    tb = pt.PrettyTable()
    tb.field_names = [
        'Token Length Range', 'Number of Files', 'Complete Repair',
        'Partial Repair', 'Cascading Repair', 'Dummy Repair'
    ]
    for key in repair_rate_data.keys():
        tb.add_row([key] + repair_rate_data[key])

    print(tb)
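
A usage sketch with hypothetical data; repair_rate_analysis() only looks at the length of each error list, so placeholder strings suffice:

result = [
    (120, ['e1', 'e2'], []),          # complete repair
    (120, ['e1', 'e2'], ['e1']),      # partial repair
    (40, ['e1'], ['e1', 'e2']),       # cascading errors introduced
]
repair_rate_analysis(result)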