Example #1
0
    def rewrite(self, tree: ast.AST, env: SymbolTable,
                metadata: tp.MutableMapping) -> PASS_ARGS_T:
        """Repeatedly factor the most frequent expression out of ``tree``.

        Each round counts the expressions in ``tree`` with ``ExprCounter``
        (configured by ``self.elim_calls``) and picks the most frequent one.
        If it occurs at least ``self.min_freq`` times, ``ExprSaver`` saves its
        first occurrence to a variable named ``<prefix><n>`` and replaces
        later occurrences with references to that variable. The loop stops
        when nothing was counted or the most frequent expression is below
        ``self.min_freq``.

        Args:
            tree: the AST to transform.
            env: symbol table passed to ``gen_free_prefix`` when choosing the
                variable-name prefix (seeded with ``self.cse_prefix``).
            metadata: pass metadata; returned unchanged.

        Returns:
            ``(tree, env, metadata)`` with ``tree`` rewritten.
        """
        prefix = gen_free_prefix(tree, env, self.cse_prefix)
        c = 0
        while True:
            # Count all the expressions in the current tree.
            counter = ExprCounter(self.elim_calls)
            counter.visit(tree)

            # Nothing left to eliminate.
            if not counter.cses:
                break

            # most_common(1) finds the maximum in one pass instead of sorting
            # every counted expression each round. On ties it still returns
            # the first-inserted entry, exactly like most_common()[0].
            # (Assumes counter.cses is a collections.Counter — it exposes
            # most_common.)
            expr, freq = counter.cses.most_common(1)[0]
            if freq < self.min_freq:
                break

            expr = mutable(expr)

            # Save the first occurrence of the expression to a fresh variable
            # and replace later occurrences with references to that variable.
            saver = ExprSaver(expr, prefix + repr(c))
            c += 1
            tree = saver.visit(tree)

        return tree, env, metadata
Example #2
0
def test_gen_free_prefix():
    """gen_free_prefix must pick prefixes that avoid names bound in the tree."""
    # The module binds P, P0 and P1 at top level (plus attribute P.P5).
    src = '''
class P:
    P5 = 1
    def __init__(self): self.y = 0
def P0():
    return P.P5
P1 = P0()
'''
    module = ast.parse(src)
    symbols = SymbolTable({}, {})

    # (extra positional args to gen_free_prefix, expected prefix)
    expectations = (
        ((), '__auto_prefix_0'),
        (('P',), 'P2'),
    )
    for extra_args, expected in expectations:
        assert gen_free_prefix(module, symbols, *extra_args) == expected
Example #3
0
    def _make_name(self, name):
        """Return a new SSA name for ``name`` and record it.

        The first time ``name`` is seen, a format template ``<prefix>{}`` is
        built from ``gen_free_prefix(self.scope, self.env, f'{name}_')`` and
        cached in ``self.name_formats``. Every call then formats that template
        with the current counter in ``self.name_idx[name]`` and increments the
        counter. It records the result as the current binding of ``name`` in
        ``self.name_table`` and the reverse mapping in
        ``self.original_names``.
        """
        formats = self.name_formats
        if name not in formats:
            base = gen_free_prefix(self.scope, self.env, f'{name}_')
            formats[name] = base + '{}'

        index = self.name_idx[name]
        fresh = formats[name].format(index)
        self.name_idx[name] = index + 1

        self.name_table[name] = fresh
        self.original_names[fresh] = name
        return fresh
Example #4
0
    def visit_FunctionDef(self, node: cst.FunctionDef) -> tp.Optional[bool]:
        """Claim the first function visited as ``self.scope``.

        The parent class's ``visit_FunctionDef`` is always called first. On
        the first call (while ``self.scope`` is ``None``):

        - ``node`` becomes the scope;
        - a prefix starting from ``'__'`` is obtained from ``gen_free_prefix``;
        - ``self.attr_format`` and ``self.return_format`` are derived from
          that prefix;
        - ``True`` is returned.

        Every later call returns ``False`` to prevent recursion into inner
        functions.
        """
        super().visit_FunctionDef(node)

        if self.scope is not None:
            # Scope already claimed: this is an inner function, so do not
            # recurse into it.
            return False

        self.scope = node
        free = gen_free_prefix(node, self.env, '__')
        self.attr_format = free + '_final_{}_{}_{}'
        self.return_format = free + '_return_{}'
        return True
Example #5
0
    def rewrite(self,
            tree: ast.AST,
            env: SymbolTable,
            metadata: tp.MutableMapping) -> PASS_ARGS_T:
        """Convert a function to SSA form, routing attribute writes through locals.

        Steps performed here:

        1. Reject anything that is not an ``ast.FunctionDef``.
        2. For every attribute target ``<name>.<attr>`` written in the
           function (as found by ``collect_targets``):

           - allocate a fresh local name from ``gen_free_name``, seeded with
             ``'<name>_<attr>'``;
           - replace reads and writes of the attribute with that local;
           - prepend an assignment that loads the attribute's initial value
             into it.

        3. Run ``SSATransformer``, giving it a return-name prefix obtained
           from ``gen_free_prefix`` seeded with ``self.return_prefix``.
        4. Append assignments that write each local's final value back to its
           attribute.
        5. Append a single ``return`` whose value ``_fold_conditions`` builds
           from the returns the SSA visitor collected.

        Args:
            tree: the function to transform; must be an ``ast.FunctionDef``.
            env: symbol table passed to the name generators and to
                ``SSATransformer``.
            metadata: pass metadata; returned unchanged.

        Returns:
            ``(tree, env, metadata)`` with ``tree`` rewritten.

        Raises:
            TypeError: if ``tree`` is not an ``ast.FunctionDef``.
            NotImplementedError: if a written attribute's base is not a plain
                ``Name`` (e.g. ``a.b.c = ...``).
        """
        if not isinstance(tree, ast.FunctionDef):
            raise TypeError('ssa should only be applied to functions')

        # Going to use this in an assert later but need to get the info
        # before any transformation happens.
        # NR is checked at the end: if the SSA visitor collects no returns,
        # the original body must satisfy _never_returns.
        NR = _never_returns(tree.body)

        # Find all attributes that are written
        targets = collect_targets(tree, ast.Attribute)
        replacer = AttrReplacer({})
        # Assignments that load each attribute's initial value into its local.
        init_reads = []
        # generated local name (str) -> attribute node it stands for
        id_to_attr = {}
        # immutable attribute node -> generated ast.Name
        attr_to_name = {}

        for t in targets:
            # Immutable copy of the target, used as a dict key so that repeated
            # writes to the same attribute share one generated name.
            i_t = immutable(t)
            if not isinstance(t.value, ast.Name):
                raise NotImplementedError(f'Only supports writing attributes '
                                          f'of Name not {type(t.value)}')
            elif i_t not in attr_to_name:
                # First write to this attribute: allocate a fresh local.
                name = ast.Name(
                        id=gen_free_name(tree, env, '_'.join((t.value.id, t.attr))),
                        ctx=ast.Store())
                # store the mapping of names to attrs
                id_to_attr[name.id] = t
                attr_to_name[i_t] = name
                # replace writes to the attr with writes to the name
                replacer.add_replacement(t, name)
                # replace reads of the attr with reads of the name
                replacer.add_replacement(_flip_ctx(t), _flip_ctx(name))

                # read the init value: `<local> = <name>.<attr>`
                # ast.Assign gained the optional `type_comment` field in
                # Python 3.8, so it is set explicitly there.
                if sys.version_info < (3, 8):
                    assign = ast.Assign(targets=[deepcopy(name)], value=_flip_ctx(t))
                else:
                    assign = ast.Assign(
                            targets=[deepcopy(name)], value=_flip_ctx(t),
                            type_comment=None)
                init_reads.append(immutable(assign))
            else:
                # Attribute already mapped: reuse the same local name.
                name = attr_to_name[i_t]
                replacer.add_replacement(t, name)
                replacer.add_replacement(_flip_ctx(t), _flip_ctx(name))



        # Replace references to the attr with the name generated above
        tree = replacer.visit(tree)

        # insert initial reads
        tree.body = [mutable(r) for r in init_reads] + tree.body

        # Perform ssa.
        # r_name is a free prefix handed to the SSA visitor, presumably for
        # the names it generates for return values — TODO confirm.
        r_name = gen_free_prefix(tree, env, self.return_prefix)
        visitor = SSATransformer(env, r_name, id_to_attr.keys(), self.strict)
        tree = visitor.visit(tree)

        # Insert the write-backs to the attrs.
        # attr_states maps each generated local name to the conditions the SSA
        # visitor collected for it. When conditions exist, they are folded into
        # a single value expression. Otherwise the attribute gets the local's
        # current SSA name from name_table.
        # NOTE(review): the exact shape of the conditions is defined by
        # SSATransformer / _fold_conditions, not visible here.
        for name, conditons in visitor.attr_states.items():
            if conditons:
                tree.body.append(
                    ast.Assign(
                        targets=[deepcopy(id_to_attr[name])],
                        value=_fold_conditions(conditons)
                    )
                )
            else:
                tree.body.append(
                    ast.Assign(
                        targets=[deepcopy(id_to_attr[name])],
                        value=ast.Name(visitor.name_table[name], ast.Load())
                    )
                )

        # insert the single return built from all collected returns
        if visitor.returns:
            tree.body.append(
                ast.Return(
                    value=_fold_conditions(visitor.returns)
                )
            )
        else:
            # No return was collected; the original body must never return.
            assert NR
        return tree, env, metadata
Example #6
0
    def rewrite(self, original_tree: cst.FunctionDef, env: SymbolTable,
                metadata: tp.MutableMapping) -> PASS_ARGS_T:
        """Convert a libcst ``FunctionDef`` to SSA form and record a symbol table.

        Pipeline, in order:

        1. Resolve source positions of ``original_tree`` (used at the end for
           the symbol table).
        2. Rewrite ``elif cond:`` into ``else: if cond:`` (``ElifToElse``).
        3. Replace every written attribute ``<name>.<attr>`` with a local named
           from ``attr_format``, and prepend reads of the attribute values.
        4. Name the branch conditions (``NameTests``), using a free prefix
           seeded with ``'_cond'``.
        5. Convert to single-return form (``SingleReturn``) and append its
           ``tail`` statements (final writes / return) after the body.
        6. Run ``SSATransformer``.
        7. Build a symbol table with ``GenerateSymbolTable`` and append
           ``(type(self), symbol_table)`` to ``metadata['SYMBOL-TABLE']``.

        After each transforming stage, ``node_tracking_table`` is updated via
        that stage's ``trace_origins``.

        Args:
            original_tree: the function to transform.
            env: symbol table passed to the prefix generators, ``SingleReturn``
                and ``SSATransformer``.
            metadata: pass metadata; a ``'SYMBOL-TABLE'`` entry is appended.

        Returns:
            ``(tree, env, metadata)`` with ``tree`` rewritten.

        Raises:
            TypeError: if ``original_tree`` is not a ``cst.FunctionDef``.
            NotImplementedError: if a written attribute's base is not a plain
                ``Name``.
        """
        if not isinstance(original_tree, cst.FunctionDef):
            raise TypeError('ssa must be run on a FunctionDef')

        # resolve position information necessary for generating symbol table
        wrapper = _wrap(to_module(original_tree))
        pos_info = wrapper.resolve(PositionProvider)

        # convert `elif cond:` to `else: if cond:`
        # (simplifies ssa logic)
        transformer = with_tracking(ElifToElse)()
        tree = original_tree.visit(transformer)

        # Node tracking table relating original nodes and generated nodes.
        # Each later stage refreshes it with `trace_origins`.
        # NOTE(review): the previous comments listed both directions
        # ("original node -> generated nodes" and "generated node -> original
        # nodes"); confirm which direction `with_tracking` actually stores.
        node_tracking_table = transformer.node_tracking_table

        # Collect every attribute that is assigned to in the function.
        wrapper = _wrap(to_module(tree))
        writter_attr_visitor = WrittenAttrs()
        wrapper.visit(writter_attr_visitor)

        replacer = with_tracking(AttrReplacer)()
        # Template for generated locals: '<free prefix>_<name>_<attr>'.
        attr_format = gen_free_prefix(tree, env, '_attr') + '_{}_{}'
        # Statements loading each attribute's initial value into its local.
        init_reads = []
        # generated local name -> normalized attribute node
        names_to_attr = {}
        # DeepNode wrappers of attributes already handled (dedup by structure,
        # presumably — DeepNode is defined elsewhere).
        seen = set()

        for written_attr in writter_attr_visitor.written_attrs:
            d_attr = DeepNode(written_attr)
            if d_attr in seen:
                continue
            if not isinstance(written_attr.value, cst.Name):
                raise NotImplementedError(
                    'writing non name nodes is not supported')

            seen.add(d_attr)

            attr_name = attr_format.format(
                written_attr.value.value,
                written_attr.attr.value,
            )

            # using normal node instead of original node
            # is safe as parenthesis don't matter:
            #  (name).attr == (name.attr) == name.attr
            norm = d_attr.normal_node
            names_to_attr[attr_name] = norm
            name = cst.Name(attr_name)
            replacer.add_replacement(written_attr, name)
            # `<local> = <name>.<attr>`, inserted at the top of the body below
            read = to_stmt(make_assign(name, norm))
            init_reads.append(read)

        # Replace references to attr with the name generated above
        tree = tree.visit(replacer)

        node_tracking_table = replacer.trace_origins(node_tracking_table)

        # Rewrite conditions to be ssa (each test is bound to a generated
        # name using a free '_cond' prefix)
        cond_prefix = gen_free_prefix(tree, env, '_cond')
        wrapper = _wrap(tree)
        name_tests = NameTests(cond_prefix)
        tree = wrapper.visit(name_tests)

        node_tracking_table = name_tests.trace_origins(node_tracking_table)

        # Transform to single return format
        wrapper = _wrap(tree)
        single_return = SingleReturn(env, names_to_attr, self.strict)
        tree = wrapper.visit(single_return)

        node_tracking_table = single_return.trace_origins(node_tracking_table)

        # insert the initial reads / final writes / return
        # (final writes and the return come from `single_return.tail`)
        body = tree.body
        body = body.with_changes(body=(*init_reads, *body.body,
                                       *single_return.tail))
        tree = tree.with_changes(body=body)

        # perform ssa
        # Expression contexts (load/store) are resolved on a fresh module
        # wrapper and handed to the SSA transformer.
        wrapper = _wrap(to_module(tree))
        ctxs = wrapper.resolve(ExpressionContextProvider)
        # These names were constructed in such a way that they are
        # guaranteed to be ssa and shouldn't be touched by the
        # transformer
        final_names = single_return.added_names | name_tests.added_names
        ssa_transformer = SSATransformer(env,
                                         ctxs,
                                         final_names,
                                         single_return.returning_blocks,
                                         strict=self.strict)
        tree = tree.visit(ssa_transformer)

        node_tracking_table = ssa_transformer.trace_origins(
            node_tracking_table)

        # libcst check that every node field holds a value of the right type
        tree.validate_types_deep()
        # generate symbol table, restricted to the original function's lines
        start_ln = pos_info[original_tree].start.line
        end_ln = pos_info[original_tree].end.line
        visitor = GenerateSymbolTable(
            node_tracking_table,
            ssa_transformer.original_names,
            pos_info,
            start_ln,
            end_ln,
        )

        tree.visit(visitor)
        metadata.setdefault('SYMBOL-TABLE', list()).append(
            (type(self), visitor.symbol_table))
        return tree, env, metadata