Beispiel #1
0
            def visit_If(self_t, if_node):
                """Flatten an if/elif chain into the concatenation of its bodies.

                Every branch yielded by ``u.ifs_in_elif_block`` first has its
                body rewritten in place by this transformer. The transformed
                bodies are then returned as one flat statement list, in branch
                order, with the branch conditions dropped. Whether a trailing
                ``else`` body is included depends on ``u.ifs_in_elif_block``,
                which is not visible here.
                """
                # Pass 1: recursively transform each branch body. The slice
                # assignment keeps the same list object attached to the node.
                for branch in u.ifs_in_elif_block(if_node):
                    branch.body[:] = self_t.visit_list(branch.body)

                # Pass 2: splice the already-transformed bodies together.
                return [stmt
                        for branch in u.ifs_in_elif_block(if_node)
                        for stmt in branch.body]
Beispiel #2
0
            def visit_If(self_t, if_node):
                """Rename per-branch variables, then transform the branch bodies.

                ``rename_if_else_block`` is called on the enclosing ``self``,
                not on the transformer ``self_t``. It runs first and produces
                the replacement statement list. Only after that are the
                (possibly substituted) branch bodies visited recursively. Each
                branch is reached through ``u.ifs_in_elif_block`` rather than
                by visiting ``orelse`` lists directly.
                """
                replacement = self.rename_if_else_block(if_node)

                for branch in u.ifs_in_elif_block(if_node):
                    visited = self_t.visit_list(branch.body)
                    # In-place update so the node keeps the same list object.
                    branch.body[:] = visited

                return replacement
Beispiel #3
0
    def vars_defined_in_all_cases(self, node):
        """Return the set of variable names defined on every execution path.

        A statement counts as a definition when ``u.is_set_to`` accepts it.
        Such a statement is presumably a ``name.set_to(...)`` call; the name
        is read from ``node.value.func.value.id``.

        Args:
            node: An ``ast.If`` (definitions are intersected across every
                branch yielded by ``u.ifs_in_elif_block``), a list of
                statements (definitions are unioned), or any other
                ``ast.AST`` node (a single statement).

        Returns:
            set: Names defined in all cases. This is a fresh set on each
            call, and is empty when nothing qualifies.

        Raises:
            TypeError: If ``node`` is neither an AST node nor a list.
                Previously such input fell through and raised
                ``UnboundLocalError``.
        """
        if isinstance(node, ast.If):
            # A name only counts if every branch of the chain defines it.
            vars_defined = None
            for if_node in u.ifs_in_elif_block(node):
                cur_vars_defined = self.vars_defined_in_all_cases(if_node.body)
                if vars_defined is None:
                    vars_defined = cur_vars_defined
                else:
                    # Safe to mutate: the recursive call returned a fresh set.
                    vars_defined &= cur_vars_defined
            # Guard against an empty chain, so callers always get a set.
            return vars_defined if vars_defined is not None else set()

        if isinstance(node, list):
            # Sequential statements: anything any of them defines is defined.
            vars_defined = set()
            for stmt_node in node:
                vars_defined |= self.vars_defined_in_all_cases(stmt_node)
            return vars_defined

        if isinstance(node, ast.AST):
            if u.is_set_to(node):
                return {node.value.func.value.id}
            return set()

        raise TypeError(
            "vars_defined_in_all_cases expects an ast.AST or a list, got %r"
            % type(node).__name__)
Beispiel #4
0
    def rename_if_else_block(self, if_node):
        """Give each branch its own copy of the variables every branch defines.

        For every name in ``vars_defined_in_all_cases(if_node)``, each branch
        body is rewritten through ``self.subs`` to use ``<var>_case<n>``.
        Here ``n`` is that branch's number from ``u.get_condition_rhs_num``.
        One merge statement per variable is then built with
        ``u.stmt_from_str``:

            <var>.set_to(tpt.weighted_sum([<var>_case<n>,...], <lhs>,
                                          scope='<var>_weighted_sum'))

        In that statement the per-case names are ordered by ascending ``n``.

        Returns:
            list: ``[if_node]`` followed by the merge statements, one per
            renamed variable.
        """
        vars_to_rename = self.vars_defined_in_all_cases(if_node)

        def case_name(var, case_num):
            # Per-branch variable name; the rename and merge steps must agree.
            return "%s_case%s" % (var, case_num)

        case_nums = []
        for branch in u.ifs_in_elif_block(if_node):
            # NOTE(review): only the lhs of the *last* branch survives this
            # loop and is used in the merge statements. This assumes every
            # branch compares the same expression -- confirm.
            condition_lhs = u.get_condition_lhs(branch)
            case_num = u.get_condition_rhs_num(branch)
            case_nums.append(case_num)
            renames = dict((var, case_name(var, case_num))
                           for var in vars_to_rename)
            branch.body = self.subs(branch.body, **renames)

        case_nums.sort()

        result_nodes = [if_node]
        for var in vars_to_rename:
            per_case = ",".join([case_name(var, c) for c in case_nums])
            merge_src = (
                "%s.set_to(tpt.weighted_sum([%s], %s, scope='%s_weighted_sum'))"
                % (var, per_case, condition_lhs, var))
            result_nodes.append(u.stmt_from_str(merge_src))

        return result_nodes