Пример #1
0
def devectorize(vector, dictionary, reverse=False):
    """Turn a sequence of token ids back into a space-separated token string.

    Ids are mapped to tokens through the reverse of ``dictionary`` (built by
    ``get_rev_dict``). ``'_pad_'`` tokens are always dropped; ``'_eos_'``
    tokens are dropped only when ``reverse`` is true, in which case the
    surviving tokens are also emitted in reverse order. The remaining tokens
    are passed through ``filter_minus_one`` and joined with single spaces.

    Raises:
        EmptyFixException: if no tokens survive the filtering.
    """
    id_to_token = get_rev_dict(dictionary)
    dropped = ('_pad_', '_eos_') if reverse else ('_pad_',)

    tokens = [tok for tok in (id_to_token[idx] for idx in vector)
              if tok not in dropped]
    if reverse:
        tokens.reverse()

    if not tokens:
        raise EmptyFixException(
            'Empty vector: {} = {} passed in devectorize'.format(
                vector, [id_to_token[idx] for idx in vector]))

    return ' '.join(filter_minus_one(tokens))
def save_dictionaries(destination, tldict):
    """Save ``tldict`` and its reverse mapping to ``<destination>/all_dicts.npy``.

    The pair ``(tldict, get_rev_dict(tldict))`` is written with ``np.save``.
    NOTE(review): the payload holds Python dicts, so reading it back
    presumably needs ``np.load(..., allow_pickle=True)`` — confirm with the loader.
    """
    both_dicts = (tldict, get_rev_dict(tldict))
    out_path = os.path.join(destination, 'all_dicts.npy')
    np.save(out_path, both_dicts)
Пример #3
0
    def __init__(self,
                 tl_dict,
                 seed,
                 correct_fix_reward=1.0,
                 step_penalty=-0.01,
                 top_down_movement=False,
                 reject_spurious_edits=False,
                 compilation_error_store=None,
                 single_delete=True,
                 actions=None,
                 sparse_rewards=True):
        """Set up the action space, vocabulary handles and reward scheme.

        Edit actions are expressed over token-kind names ('LPAREN', 'SEMI', ...)
        instead of the single characters '{}();,.' of the original RLAssist setup.

        Args:
            tl_dict: token -> id vocabulary. Must contain '-new-line-', '_pad_',
                'EOF' and every name in ``self.action_tokens``.
            seed: seed for ``self.rng`` (a ``np.random.RandomState``).
            correct_fix_reward: stored as ``self.correct_fix_reward``.
            step_penalty: base per-step penalty; halved when ``sparse_rewards``
                is false.
            top_down_movement: when true, 'move_up' and 'move_left' are omitted.
            reject_spurious_edits: stored as-is.
            compilation_error_store: stored as-is.
            single_delete: when true a single generic 'delete' action is exposed;
                otherwise one 'delete <TOKEN>' action per action token.
            actions: custom action lists are not supported; must be None.
            sparse_rewards: selects between the two reward schemes below.

        Raises:
            NotImplementedError: if ``actions`` is not None.
        """
        self.rng = np.random.RandomState(seed)
        self.java = javac_parser.Java()

        # Token-kind names replace the former character list list('{}();,.').
        self.action_tokens = [
            'LPAREN', 'RPAREN', 'COMMA', 'SEMI', 'LBRACE', 'RBRACE', 'DOT'
        ]

        # Movement actions: all four directions, or only down/right when
        # top-down movement is requested.
        self.actions = ([] if top_down_movement else ['move_up', 'move_left']
                        ) + ['move_down', 'move_right']

        # Token edit actions use the 'insert TOKEN' / 'delete TOKEN' format.
        # For reference, the action list originally defined by RLAssist was:
        #   ['move_down', 'move_right', 'insert{', 'insert}', 'insert(',
        #    'insert)', 'insert;', 'insert,', 'delete', 'replace;with,',
        #    'replace,with;', 'replace.with;', 'replace;)with);']
        for each in self.action_tokens:
            # No 'insert DOT' action, mirroring the original which never
            # inserted '.'. NOTE(review): the original author questioned why
            # '.' is excluded here; with single_delete every per-token delete
            # collapses into the one generic 'delete' action.
            if each != 'DOT':
                self.actions.append('insert ' + each)
            if not single_delete:
                self.actions.append('delete ' + each)

        if single_delete:
            self.actions.append('delete')
        self.actions += [
            'replace;with,', 'replace,with;', 'replace.with;',
            'replace;)with);'
        ]

        # Each compound replacement action expands to a sequence of primitive
        # actions. Token names must match ``self.action_tokens`` ('RPAREN').
        self.replacement_action_to_action_sequence_map = {
            'replace;)with);':
            ['delete SEMI RPAREN', 'move_right', 'insert SEMI'],
            'replace;with,': ['delete SEMI', 'insert COMMA'],
            'replace,with;': ['delete COMMA', 'insert SEMI'],
            'replace.with;': ['delete DOT', 'insert SEMI']
        }
        # Reverse lookup keyed by the plain concatenation of the sequence.
        self.action_sequence_to_replacement_action_map = dict(
            (''.join(action_seq), replacement_action)
            for replacement_action, action_seq in
            self.replacement_action_to_action_sequence_map.items())

        if actions is not None:
            raise NotImplementedError()

        self.tl_dict = tl_dict
        self.rev_tl_dict = get_rev_dict(tl_dict)
        self.new_line = tl_dict['-new-line-']
        self.pad = tl_dict['_pad_']

        # The id of the 'EOF' token is used as the cursor marker.
        self.cursor = tl_dict['EOF']
        self.normalized_ids_tl_dict, self.id_token_vecs = self.get_normalized_ids_dict(
            tl_dict)
        # Special token ids must be identical in the normalized vocabulary.
        assert self.cursor == self.normalized_ids_tl_dict['EOF']
        assert self.new_line == self.normalized_ids_tl_dict['-new-line-']
        assert self.pad == self.normalized_ids_tl_dict['_pad_']

        # Ids of the only tokens the agent is allowed to edit.
        self.mutables = [self.tl_dict[each] for each in self.action_tokens]

        self.top_down_movement = top_down_movement
        self.reject_spurious_edits = reject_spurious_edits
        self.compilation_error_store = compilation_error_store
        self.single_delete = single_delete

        self.correct_fix_reward = correct_fix_reward
        if sparse_rewards:
            # Edits are free; resolving an error earns 1.9 * |step_penalty|.
            self.step_penalty = step_penalty
            self.edit_penalty = 0.0
            self.error_resolution_reward = 2 * abs(step_penalty) - (
                abs(step_penalty) / 10)
        else:
            # NOTE(review): under Python 2 an int step_penalty would floor here.
            self.step_penalty = step_penalty / 2
            self.edit_penalty = self.step_penalty * 5
            self.error_resolution_reward = 2 * (abs(self.edit_penalty) -
                                                abs(self.step_penalty))
def undeclare_variable(rng,
                       old_program,
                       program_string,
                       deleted_ids,
                       name_dict=None,
                       print_debug_messages=False):
    """Delete the declaration of one randomly chosen identifier from a tokenized program.

    Candidate identifiers are the distinct ``_<id>_`` tokens of
    ``program_string``, tried in an order shuffled by ``rng``. For each one,
    the lines that are neither inside a struct/union/enum body nor global are
    scanned in shuffled order for a declaration of that identifier; the first
    hit is removed from its line. The removed declaration is returned as a
    "fix" to be inserted at the first line containing ``{`` at or after the
    enclosing function signature, which is located by searching
    ``old_program`` upward from the mutated line.

    Args:
        rng: object with a ``shuffle`` method (e.g. ``np.random.RandomState``).
        old_program: tokenized program used to locate the enclosing function
            signature; must have as many lines as ``program_string``.
        program_string: tokenized program to mutate.
        deleted_ids: not used by this function.
        name_dict: optional identifier dictionary, used only for debug output.
        print_debug_messages: print diagnostic output when true.

    Returns:
        ``(recomposed_program, fix, fix_line)``: the mutated program with each
        line prefixed by its space-separated line-number digits and ``~``; the
        fix string ``'_<insertion>_ <digits> ~ <declaration>'``; and the line
        index at which the fix is to be inserted.

    Raises:
        NothingToMutateException: no removable declaration was found.
        FailedToMutateException: no enclosing function signature was found, or
            no ``{`` at or after it.
    """
    rev_name_dict = get_rev_dict(name_dict) if name_dict is not None else None

    orig_lines = get_lines(program_string)
    old_lines = get_lines(old_program)

    open_brace = r'_<op>_\{'
    close_brace = r'_<op>_\}'

    # Lines inside struct/union/enum bodies are never mutated.
    struct_header = r'_<keyword>_(?:struct|union|enum) _<id>_\d+@ _<op>_\{'
    struct_lines = []
    structs_deep = 0
    for i, line in enumerate(orig_lines):
        opens = len(re.findall(open_brace, line))
        closes = len(re.findall(close_brace, line))
        if re.search(struct_header, line):
            # Also count closing braces so a one-line definition such as
            # `struct s { ... } ;` does not leave the depth stuck above zero.
            structs_deep += opens - closes
        elif structs_deep > 0:
            structs_deep += opens - closes
            assert structs_deep >= 0, str(structs_deep) + " " + line
            struct_lines.append(i)

    # Brace-free lines at nesting depth 0 are global and never mutated.
    global_lines = []
    brackets_deep = 0
    for i, line in enumerate(orig_lines):
        opens = len(re.findall(open_brace, line))
        closes = len(re.findall(close_brace, line))
        if opens or closes:
            brackets_deep += opens - closes
            assert brackets_deep >= 0, str(brackets_deep) + " " + line
        elif brackets_deep == 0:
            global_lines.append(i)

    if print_debug_messages:
        print('Ignoring lines: %s' % (struct_lines,))
        print('Ignoring lines: %s' % (global_lines,))
        for line_no in sorted(set(struct_lines + global_lines)):
            print('- %s' % (orig_lines[line_no],))

    # Distinct identifier tokens, in first-occurrence order (before shuffling).
    variables = []
    seen = set()
    for token in program_string.split():
        if '_<id>_' in token and token not in seen:
            seen.add(token)
            variables.append(token)

    assert len(orig_lines) == len(old_lines)

    def report(label, regex, line_no):
        # Debug helper: show the candidate line as source and the regex hits.
        print('On line %d: %s' % (line_no, tokens_to_source(
            orig_lines[line_no], name_dict, clang_format=True)))
        print('%s %s' % (label, re.findall(regex, orig_lines[line_no])))

    candidate_lines = set(range(len(orig_lines))) - set(struct_lines +
                                                          global_lines)
    declaration = None
    declaration_pos = None
    done = False

    rng.shuffle(variables)

    for to_undeclare in variables:
        if print_debug_messages:
            shown = (rev_name_dict[to_undeclare]
                     if rev_name_dict is not None else to_undeclare)
            print('Looking for: %s ...' % (shown,))

        shuffled_lines = list(candidate_lines)
        rng.shuffle(shuffled_lines)

        # `type x = init ;` -- declaration with an initializer, used alone.
        regex_alone_use = r'(_<keyword>_(?:struct|enum|union) _<id>_\d+@|_<type>_\w+)((?: _<op>_\*)* %s(?: _<op>_\[(?: [^\]]+)? _<op>_\])*)(?: _<op>_= [^,;]+)(?: _<op>_;)' % to_undeclare
        # `type x ;` -- plain declaration alone on the statement.
        regex_alone = r'((?:_<keyword>_(?:struct|enum|union) _<id>_\d+@|_<type>_\w+)(?: _<op>_\*)* %s(?: _<op>_\[(?: [^\]]+)? _<op>_\])* _<op>_;)' % to_undeclare
        # `type x , y , ... ;` -- x is the first declarator of a group.
        regex_group_leader = r'((?:_<keyword>_(?:struct|enum|union) _<id>_\d+@|_<type>_\w+)(?: _<op>_\*)*)( %s(?: _<op>_\[(?: [^\]]+)? _<op>_\])*)(?: _<op>_= [^,;]+)?( _<op>_,)(?:(?: _<op>_\*)* _<id>_\d+@(?: _<op>_\[(?: [^\]]+)? _<op>_\])*(?: _<op>_= [^,;]+)? _<op>_,)*(?:(?: _<op>_\*)* _<id>_\d+@(?: _<op>_\[(?: [^\]]+)? _<op>_\])*(?: _<op>_= [^,;]+)? _<op>_;)' % to_undeclare
        # `type a , x , ... ;` -- x is a non-leading declarator of a group.
        regex_group = r'(_<keyword>_(?:struct|enum|union) _<id>_\d+@|_<type>_\w+)(?: _<op>_\*)* _<id>_\d+@(?: _<op>_\[(?: [^\]]+)? _<op>_\])*(?: _<op>_= [^,;]+)?(?: _<op>_,(?: _<op>_\*)* _<id>_\d+@(?: _<op>_\[(?: [^\]]+)? _<op>_\])*(?: _<op>_= [^,;]+)?)*( _<op>_,(?: _<op>_\*)* %s(?: _<op>_\[(?: [^\]]+)? _<op>_\])*(?: _<op>_= [^,;]+)?)(?: _<op>_,(?: _<op>_\*)* _<id>_\d+@(?: _<op>_\[(?: [^\]]+)? _<op>_\])*(?: _<op>_= [^,;]+)?)*(?: _<op>_;)' % to_undeclare

        for i in shuffled_lines:
            line = orig_lines[i]

            if len(re.findall(regex_alone_use, line)) == 1:
                if print_debug_messages:
                    report('Found Alone use', regex_alone_use, i)
                m = re.search(regex_alone_use, line)
                declaration = line[m.start(1):m.end(2)] + ' _<op>_;'
                declaration_pos = i
                # Drop only the type (and its trailing space); keep the assignment.
                orig_lines[i] = line[:m.start(1)] + line[m.end(1) + 1:]
                done = True
                break

            if len(re.findall(regex_alone, line)) == 1:
                if print_debug_messages:
                    report('Found Alone', regex_alone, i)
                m = re.search(regex_alone, line)
                declaration = line[m.start(1):m.end(1)]
                declaration_pos = i
                orig_lines[i] = line[:m.start(1)] + line[m.end(1) + 1:]
                done = True
                break

            elif len(re.findall(regex_group, line)) == 1:
                if print_debug_messages:
                    report('Found Group', regex_group, i)
                m = re.search(regex_group, line)
                # [8:] strips the leading ' _<op>_,' from the declarator.
                declaration = line[m.start(1):m.end(1)] + line[
                    m.start(2):m.end(2)][8:] + ' _<op>_;'
                declaration_pos = i

                try:
                    end_of_declr = declaration.index('_<op>_=')
                except ValueError:
                    pass
                else:
                    # Drop the initializer but keep the terminating ';' so the
                    # fix is a complete declaration.
                    declaration = declaration[:end_of_declr] + '_<op>_;'

                orig_lines[i] = line[:m.start(2) + 1] + line[m.end(2) + 1:]
                done = True
                break

            elif len(re.findall(regex_group_leader, line)) == 1:
                if print_debug_messages:
                    report('Found Group Leader', regex_group_leader, i)
                m = re.search(regex_group_leader, line)
                declaration = line[m.start(1):m.end(2)] + ' _<op>_;'
                declaration_pos = i
                orig_lines[i] = line[:m.start(2) + 1] + line[m.end(3) + 1:]
                done = True
                break

        if done:
            break

    if not done:
        # Failed to find something to undeclare
        raise NothingToMutateException

    # Find the enclosing function signature by walking upward (line 0
    # included), then the first line at or after it that contains '{'.
    fn_regex = r'(?:_<keyword>_(?:struct|union|enum) _<id>_\d+@|_<type>_\w+|_<keyword>_void)(?: _<op>_\*)* (?:_<id>_\d+@|_<APIcall>_main) _<op>_\('
    fix_line = None
    found_signature = False

    assert declaration_pos is not None
    for i in range(declaration_pos, -1, -1):
        if len(re.findall(fn_regex, old_lines[i])) == 1:
            found_signature = True
            for j in range(i, len(old_lines)):
                if re.search(open_brace, old_lines[j]):
                    fix_line = j
                    break
            break

    if not found_signature:
        raise FailedToMutateException
    if fix_line is None:
        # Couldn't find { after function definition
        raise FailedToMutateException

    fix = '_<insertion>_ ' + ' '.join(str(fix_line)) + ' ~ ' + declaration

    # Remove the mutated line entirely if nothing but whitespace is left.
    if not orig_lines[declaration_pos].strip():
        del orig_lines[declaration_pos]

    recomposed_program = ''.join(
        ' '.join(str(i)) + ' ~ ' + line + ' '
        for i, line in enumerate(orig_lines))

    return recomposed_program, fix, fix_line
Пример #5
0
    def __init__(self,
                 tl_dict,
                 seed,
                 correct_fix_reward=1.0,
                 step_penalty=-0.01,
                 top_down_movement=False,
                 reject_spurious_edits=False,
                 compilation_error_store=None,
                 single_delete=True,
                 actions=None,
                 sparse_rewards=True):
        """Set up the action space, vocabulary handles and reward scheme.

        Edit actions operate on the single-character tokens '{}();,.', looked
        up in ``tl_dict`` with an ``'_<op>_'`` prefix.

        Args:
            tl_dict: token -> id vocabulary. Must contain '-new-line-', '_pad_',
                '_eos_' and '_<op>_' + each action token.
            seed: seed for ``self.rng`` (a ``np.random.RandomState``).
            correct_fix_reward: stored as ``self.correct_fix_reward``.
            step_penalty: base per-step penalty; halved when ``sparse_rewards``
                is false.
            top_down_movement: when true, 'move_up' and 'move_left' are omitted.
            reject_spurious_edits: stored as-is.
            compilation_error_store: stored as-is.
            single_delete: when true a single generic 'delete' action is exposed;
                otherwise one 'delete<char>' action per action token.
            actions: custom action lists are not supported; must be None.
            sparse_rewards: selects between the two reward schemes below.

        Raises:
            NotImplementedError: if ``actions`` is not None.
        """
        self.rng = np.random.RandomState(seed)
        self.action_tokens = ['{', '}', '(', ')', ';', ',', '.']

        # Movement: down/right always; up/left only without top-down movement.
        moves = ['move_down', 'move_right']
        if not top_down_movement:
            moves = ['move_up', 'move_left'] + moves

        # Per-token edits: no insertion of '.', per-token deletes only when
        # a single generic delete is not requested.
        token_edits = []
        for tok in self.action_tokens:
            if tok != '.':
                token_edits.append('insert' + tok)
            if not single_delete:
                token_edits.append('delete' + tok)
        if single_delete:
            token_edits.append('delete')

        self.actions = moves + token_edits + [
            'replace;with,', 'replace,with;', 'replace.with;',
            'replace;)with);'
        ]

        # Compound replacement -> primitive action sequence, plus the reverse
        # lookup keyed by the concatenated sequence.
        self.replacement_action_to_action_sequence_map = {
            'replace;)with);': ['delete;)', 'move_right', 'insert;'],
            'replace;with,': ['delete;', 'insert,'],
            'replace,with;': ['delete,', 'insert;'],
            'replace.with;': ['delete.', 'insert;'],
        }
        self.action_sequence_to_replacement_action_map = dict(
            (''.join(seq), name)
            for name, seq in
            self.replacement_action_to_action_sequence_map.items())

        if actions is not None:
            raise NotImplementedError()

        self.tl_dict = tl_dict
        self.rev_tl_dict = get_rev_dict(tl_dict)
        self.new_line = tl_dict['-new-line-']
        self.pad = tl_dict['_pad_']
        # The '_eos_' token id doubles as the cursor marker.
        self.cursor = tl_dict['_eos_']
        normalized = self.get_normalized_ids_dict(tl_dict)
        self.normalized_ids_tl_dict, self.id_token_vecs = normalized

        # Special token ids must be identical in the normalized vocabulary.
        for special, token_id in (('_eos_', self.cursor),
                                  ('-new-line-', self.new_line),
                                  ('_pad_', self.pad)):
            assert token_id == self.normalized_ids_tl_dict[special]

        # Ids of the only tokens the agent is allowed to edit.
        self.mutables = [
            self.tl_dict['_<op>_' + tok] for tok in self.action_tokens
        ]

        self.top_down_movement = top_down_movement
        self.reject_spurious_edits = reject_spurious_edits
        self.compilation_error_store = compilation_error_store
        self.single_delete = single_delete

        self.correct_fix_reward = correct_fix_reward
        if not sparse_rewards:
            # Dense scheme: halved step penalty, edits cost 5 steps.
            self.step_penalty = step_penalty / 2
            self.edit_penalty = self.step_penalty * 5
            self.error_resolution_reward = 2 * (abs(self.edit_penalty) -
                                                abs(self.step_penalty))
        else:
            # Sparse scheme: edits are free.
            magnitude = abs(step_penalty)
            self.step_penalty = step_penalty
            self.edit_penalty = 0.0
            self.error_resolution_reward = 2 * magnitude - (magnitude / 10)