Example #1
0
    def leave_Return(self, original_node: cst.Return,
                     updated_node: cst.Return) -> cst.FlattenSentinel:
        """Replace a ``return`` with assignments that record its results.

        For every entry ``name -> attr`` in ``self.names_to_attr`` a fresh
        name (from ``self.attr_format``) is assigned the current value of the
        local ``name``. The returned expression is assigned to a fresh name
        from ``self.return_format``. Each fresh name is paired with this
        return's condition (from ``IncrementalConditionProvider``) and
        appended to ``self.attr_states[name]`` / ``self.returns``. Those lists
        are presumably folded later by ``leave_FunctionDef``. Every generated
        name is also recorded in ``self.added_names``.

        Returns:
            A ``cst.FlattenSentinel`` holding the generated assignments, which
            replaces the original ``return`` small statement. (The previous
            annotation, ``cst.RemovalSentinel``, did not match this value.)
        """
        assert self.return_format is not None
        assert self.attr_format is not None

        assignments = []
        # Presumably the condition under which this return is reached.
        cond = self.get_metadata(IncrementalConditionProvider, original_node)

        for name, attr in self.names_to_attr.items():
            assert isinstance(attr.value, cst.Name)
            state = self.attr_states.setdefault(name, [])
            # len(state) grows with every snapshot of this attribute, so
            # successive snapshots get distinct indices in the fresh name.
            attr_name = cst.Name(value=self.attr_format.format(
                attr.value.value, attr.attr.value, len(state)))
            self.added_names.add(attr_name.value)
            state.append((cond, attr_name))
            assignments.append(make_assign(attr_name, cst.Name(name)))

        r_name = cst.Name(value=self.return_format.format(len(self.returns)))
        self.added_names.add(r_name.value)
        self.returns.append((cond, r_name))

        # A bare ``return`` yields None.
        if updated_node.value is None:
            r_val = cst.Name(value='None')
        else:
            r_val = updated_node.value
        assignments.append(make_assign(r_name, r_val))

        return cst.FlattenSentinel(assignments)
Example #2
0
        def _mux_name(name, t_name, f_name):
            """Merge two branch versions of ``name`` into one new name.

            Builds ``<fresh> = <t_name> if <new_test> else <f_name>`` where
            ``<fresh>`` comes from ``self._make_name(name)``. The assignment
            is registered in ``self.name_assignments`` and wrapped into a
            statement. The statement is then tracked as derived from both
            branch assignments and the enclosing ``if`` (``original_node``).
            ``new_test`` and ``original_node`` come from the enclosing scope.

            Returns the generated statement.
            """
            muxed = self._make_name(name)
            selector = cst.IfExp(
                test=new_test,
                body=cst.Name(t_name),
                orelse=cst.Name(f_name),
            )
            assign = make_assign(cst.Name(muxed), selector)
            self.name_assignments[muxed] = assign
            stmt = to_stmt(assign)

            assert isinstance(original_node, cst.If)
            # Both inputs must be previously recorded definitions.
            t_def = self.name_assignments[t_name]
            assert isinstance(t_def, (cst.Assign, cst.Param))
            f_def = self.name_assignments[f_name]
            assert isinstance(f_def, (cst.Assign, cst.Param))

            origins = (t_def, f_def, original_node)
            self.track_with_children(origins, stmt)
            # NOTE(review): `.i` looks like the inverse (generated -> original)
            # view of the tracking table — confirm.
            assert assign in self.node_tracking_table.i
            return stmt
Example #3
0
    def leave_FunctionDef(self, original_node: cst.FunctionDef,
                          updated_node: cst.FunctionDef) -> cst.FunctionDef:
        """Emit the final attribute write-backs and the single return.

        This only acts when leaving ``self.scope``. Any other ``FunctionDef``
        is returned unchanged. The generated statements are appended to
        ``self.tail`` rather than to ``updated_node``'s body. The caller is
        expected to splice ``self.tail`` into the function body.

        For each ``name -> attr`` in ``self.names_to_attr``, the recorded
        ``(conditions, name)`` states plus a final default entry (the local
        ``name`` itself) are folded into one expression. That expression is
        then assigned back to ``attr``. If any returns were recorded in
        ``self.returns``, they are folded into one ``return`` statement.

        Raises:
            SyntaxError: if folding the return conditions raises
                ``IncompleteGaurdError``, i.e. the function cannot be shown
                to return on every path.
        """
        if original_node is not self.scope:
            return updated_node

        tail = self.tail
        for name, attr in self.names_to_attr.items():
            # Build a new list instead of appending to the one stored in
            # self.attr_states, so that state is not mutated as a side effect.
            # The trailing ``([], name)`` entry is the default writeback of
            # the initial value (an empty condition list).
            state = [*self.attr_states.get(name, ()), ([], cst.Name(name))]
            attr_val = _fold_conditions(_simplify_gaurds(state), self.strict)
            tail.append(to_stmt(make_assign(attr, attr_val)))

        if self.returns:
            try:
                return_val = _fold_conditions(
                    _simplify_gaurds(self.returns), self.strict)
            except IncompleteGaurdError:
                raise SyntaxError(
                    'Cannot prove function always returns') from None
            tail.append(cst.SimpleStatementLine(
                [cst.Return(value=return_val)]))

        return updated_node
Example #4
0
 def leave_If(
     self,
     original_node: cst.If,
     updated_node: cst.If,
 ) -> cst.If:
     c_name = cst.Name(value=self.format.format(len(self.added_names)))
     self.added_names.add(c_name.value)
     assign = to_stmt(make_assign(c_name, updated_node.test))
     final_node = updated_node.with_changes(test=c_name)
     return cst.FlattenSentinel([assign, final_node])
Example #5
0
    def rewrite(self, original_tree: cst.FunctionDef, env: SymbolTable,
                metadata: tp.MutableMapping) -> PASS_ARGS_T:
        """Convert ``original_tree`` into SSA form.

        The function is rewritten in a sequence of passes. After each pass
        ``node_tracking_table`` is extended via ``trace_origins``, so that
        generated nodes can later be related back to the original ones:

        1. ``ElifToElse``: ``elif cond:`` becomes ``else: if cond:``.
        2. Every attribute written in the body (found by ``WrittenAttrs``)
           is replaced by a generated local name. An initial read
           ``<name> = <obj>.<attr>`` is prepended to the body for each one.
        3. ``NameTests``: rewrites conditions to be SSA. The generated names
           are collected in its ``added_names``.
        4. ``SingleReturn``: transforms the body to a single-return format.
           Its ``tail`` (final attribute writes and the return) is appended
           to the body.
        5. ``SSATransformer``: performs the SSA conversion itself. The names
           generated by steps 3 and 4 are passed to it as already final.

        Finally a symbol table is built with ``GenerateSymbolTable`` and
        appended to ``metadata['SYMBOL-TABLE']`` as
        ``(type(self), symbol_table)``.

        Args:
            original_tree: the function to rewrite; must be a
                ``cst.FunctionDef``.
            env: symbol table passed to ``gen_free_prefix`` when choosing
                generated-name prefixes, and to the transformers.
            metadata: pass metadata; mutated in place (see above).

        Returns:
            ``(tree, env, metadata)``, where ``tree`` is the rewritten
            function.

        Raises:
            TypeError: if ``original_tree`` is not a ``cst.FunctionDef``.
            NotImplementedError: if an attribute is written on something
                other than a plain name (e.g. ``a.b.c = ...``).
        """
        if not isinstance(original_tree, cst.FunctionDef):
            raise TypeError('ssa must be run on a FunctionDef')

        # resolve position information necessary for generating symbol table
        wrapper = _wrap(to_module(original_tree))
        pos_info = wrapper.resolve(PositionProvider)

        # convert `elif cond:` to `else: if cond:`
        # (simplifies ssa logic)
        transformer = with_tracking(ElifToElse)()
        tree = original_tree.visit(transformer)

        # node_tracking_table maps original node -> generated nodes;
        # its `.i` attribute presumably holds the inverse view
        # (generated node -> original nodes).
        node_tracking_table = transformer.node_tracking_table

        # Collect every attribute that is written to in the function.
        wrapper = _wrap(to_module(tree))
        writter_attr_visitor = WrittenAttrs()
        wrapper.visit(writter_attr_visitor)

        replacer = with_tracking(AttrReplacer)()
        # Generated attribute names: <free prefix>_<object name>_<attr name>
        attr_format = gen_free_prefix(tree, env, '_attr') + '_{}_{}'
        init_reads = []
        names_to_attr = {}
        # Deduplicate structurally equal attribute nodes via DeepNode.
        seen = set()

        for written_attr in writter_attr_visitor.written_attrs:
            d_attr = DeepNode(written_attr)
            if d_attr in seen:
                continue
            if not isinstance(written_attr.value, cst.Name):
                raise NotImplementedError(
                    'writing non name nodes is not supported')

            seen.add(d_attr)

            attr_name = attr_format.format(
                written_attr.value.value,
                written_attr.attr.value,
            )

            # using normal node instead of original node
            # is safe as parentheses don't matter:
            #  (name).attr == (name.attr) == name.attr
            norm = d_attr.normal_node
            names_to_attr[attr_name] = norm
            name = cst.Name(attr_name)
            replacer.add_replacement(written_attr, name)
            # `<attr_name> = <obj>.<attr>`, prepended to the body below.
            read = to_stmt(make_assign(name, norm))
            init_reads.append(read)

        # Replace references to attr with the name generated above
        tree = tree.visit(replacer)

        node_tracking_table = replacer.trace_origins(node_tracking_table)

        # Rewrite conditions to be ssa
        cond_prefix = gen_free_prefix(tree, env, '_cond')
        wrapper = _wrap(tree)
        name_tests = NameTests(cond_prefix)
        tree = wrapper.visit(name_tests)

        node_tracking_table = name_tests.trace_origins(node_tracking_table)

        # Transform to single return format
        wrapper = _wrap(tree)
        single_return = SingleReturn(env, names_to_attr, self.strict)
        tree = wrapper.visit(single_return)

        node_tracking_table = single_return.trace_origins(node_tracking_table)

        # insert the initial reads / final writes / return
        body = tree.body
        body = body.with_changes(body=(*init_reads, *body.body,
                                       *single_return.tail))
        tree = tree.with_changes(body=body)

        # perform ssa
        wrapper = _wrap(to_module(tree))
        ctxs = wrapper.resolve(ExpressionContextProvider)
        # These names were constructed in such a way that they are
        # guaranteed to be ssa and shouldn't be touched by the
        # transformer
        final_names = single_return.added_names | name_tests.added_names
        ssa_transformer = SSATransformer(env,
                                         ctxs,
                                         final_names,
                                         single_return.returning_blocks,
                                         strict=self.strict)
        tree = tree.visit(ssa_transformer)

        node_tracking_table = ssa_transformer.trace_origins(
            node_tracking_table)

        # libcst runtime type check over the whole generated tree.
        tree.validate_types_deep()
        # generate symbol table
        # Line span of the ORIGINAL function (positions resolved up front).
        start_ln = pos_info[original_tree].start.line
        end_ln = pos_info[original_tree].end.line
        visitor = GenerateSymbolTable(
            node_tracking_table,
            ssa_transformer.original_names,
            pos_info,
            start_ln,
            end_ln,
        )

        tree.visit(visitor)
        metadata.setdefault('SYMBOL-TABLE', list()).append(
            (type(self), visitor.symbol_table))
        return tree, env, metadata