Пример #1
0
def _list_declared_variables_dfs(node: FormatNode, *,
                                 counter: Dict[VarName, _CounterDecl],
                                 declared: Dict[VarName, VarDecl]) -> None:
    """
    Walk a format tree and record a :any:`VarDecl` in ``declared`` for every :any:`ItemNode` found.

    :param node: the subtree to visit
    :param counter: declarations of the loop counters enclosing ``node``, keyed by counter name
    :param declared: the declarations collected so far; this dict is updated in place
    :raises DeclaredVariablesError: if an item's name is already present in ``declared``
    """

    if isinstance(node, ItemNode):
        if node.name in declared:
            raise DeclaredVariablesError(
                f"the same variable appears twice in tree: {node.name}")
        dims = []
        bases = []
        depending = set()
        for index in node.indices:
            # ``dim`` becomes the index with every counter replaced by its
            # loop size, ``base`` the index with every counter replaced by 0.
            dim = index
            base = index
            for i, decl in counter.items():
                pattern = r'\b' + re.escape(i) + r'\b'
                # The size is wrapped in parentheses so that a compound size
                # such as "n + 1" keeps its meaning inside "2 * i".  A callable
                # replacement is used so that ``re`` inserts the text verbatim
                # instead of interpreting backslash escapes in it.
                size = f"({decl.size})"
                dim = Expr(re.sub(pattern, lambda _match, size=size: size, dim))
                base = Expr(re.sub(pattern, '0', base))
            # remember the already declared variables which the substituted index mentions
            depending.update(
                n for n in declared
                if re.search(r'\b' + re.escape(n) + r'\b', dim))
            dims.append(simplify(Expr(f"""{dim} - ({base})""")))
            bases.append(simplify(base))
        declared[node.name] = VarDecl(name=node.name,
                                      dims=dims,
                                      bases=bases,
                                      depending=depending,
                                      type=None)

    elif isinstance(node, NewlineNode):
        pass

    elif isinstance(node, SequenceNode):
        for item in node.items:
            _list_declared_variables_dfs(item,
                                         counter=counter,
                                         declared=declared)

    elif isinstance(node, LoopNode):
        # the loop size depends on every declared variable whose name occurs in it
        depending = {
            n
            for n in declared
            if re.search(r'\b' + re.escape(n) + r'\b', node.size)
        }
        decl = _CounterDecl(name=node.name,
                            size=node.size,
                            depending=depending)
        # The new counter is listed before the enclosing ones, so it is
        # substituted first when the body's indices are processed above; a
        # size that mentions an enclosing counter is then rewritten by the
        # later substitutions.
        _list_declared_variables_dfs(node.body,
                                     counter={
                                         node.name: decl,
                                         **counter
                                     },
                                     declared=declared)
Пример #2
0
def _get_variable_on_code(*, decl: VarDecl, indices: List[Expr],
                          decls: Dict[VarName, VarDecl]) -> str:
    """
    Build the subscripted expression that accesses a declared variable.

    Every index is shifted by the matching base of ``decl`` and simplified
    before it is appended as ``[...]``.

    :param decl: the declaration of the variable being accessed
    :param indices: the indices to rewrite, paired positionally with ``decl.bases``
    :param decls: not used by this function
    :return: the variable name followed by one ``[...]`` per index/base pair
    """
    prefix = str(decl.name)
    subscripts = [
        simplify(Expr(f"{position} - ({offset})"))
        for position, offset in zip(indices, decl.bases)
    ]
    return prefix + "".join(f"[{subscript}]" for subscript in subscripts)
Пример #3
0
    def test_subscripted(self) -> None:
        # subscripted terms: each index inside braces is expected to be simplified
        source = Expr(
            'a _ {n + 1 - n} + a _ {n + 1 - (n - 1)} + a _ {n + 1 - (n - 2)} + dots + a _ {n + 1 - 2} + a _ {n + 1 - 1}'
        )
        wanted = Expr('a_1 + a_2 + a_3 + a_n + a_{n - 1} + dots')

        self.assertEqual(simplify.simplify(source), wanted)
Пример #4
0
def extend_loop_node(a: FormatNode, b: FormatNode, *,
                     loop: LoopNode) -> Optional[FormatNode]:
    """
    Try to merge the node ``a`` with ``b``, a (part of the) body of ``loop``.

    The two trees must have the same shape.  Two items match when every index
    of ``a`` equals the corresponding index of ``b`` with the counter of
    ``loop`` replaced by ``-1``; the merged index is then the index of ``a``
    plus the counter.

    :param a: the node to merge
    :param b: the node it is compared against
    :param loop: the loop whose counter name is used for the substitution
    :return: the merged node, or ``None`` if the two nodes do not match
    """
    if isinstance(a, ItemNode) and isinstance(b, ItemNode):
        if a.name != b.name or len(a.indices) != len(b.indices):
            return None
        counter_pattern = r'\b' + re.escape(loop.name) + r'\b'
        merged_indices = []
        for head_index, body_index in zip(a.indices, b.indices):
            previous = Expr(re.sub(counter_pattern, '(-1)', body_index))
            if simplify(head_index) != simplify(previous):
                return None
            # NOTE(review): rebuilding the index as "head + counter" looks like
            # it assumes the counter occurs with coefficient 1 -- confirm.
            merged_indices.append(
                simplify(Expr(f"{head_index} + {loop.name}")))
        return ItemNode(name=a.name, indices=merged_indices)

    if isinstance(a, NewlineNode) and isinstance(b, NewlineNode):
        return NewlineNode()

    if isinstance(a, SequenceNode) and isinstance(b, SequenceNode):
        if len(a.items) != len(b.items):
            return None
        merged_items = []
        for left, right in zip(a.items, b.items):
            merged = extend_loop_node(left, right, loop=loop)
            if merged is None:
                # a single mismatching child makes the whole sequence unmergeable
                return None
            merged_items.append(merged)
        return SequenceNode(items=merged_items)

    if isinstance(a, LoopNode) and isinstance(b, LoopNode):
        if a.size != b.size or a.name != b.name:
            return None
        merged_body = extend_loop_node(a.body, b.body, loop=loop)
        if merged_body is None:
            return None
        return LoopNode(size=a.size, name=a.name, body=merged_body)

    # nodes of different kinds never match
    return None
Пример #5
0
def zip_nodes(a: FormatNode, b: FormatNode, *, name: VarName,
              size: Optional[Expr]) -> Tuple[FormatNode, Optional[Expr]]:
    """
    Zip the first node ``a`` and the last node ``b`` of a dots pair into one loop body.

    Indices which are equal in both nodes are kept.  Indices which differ
    are replaced by the index of ``a`` plus the loop counter ``name``, and
    their distance ``last - first + 1`` must be the same loop size everywhere.

    :param a: the node before the dots
    :param b: the node after the dots
    :param name: the name of the loop counter to add to differing indices
    :param size: the loop size found so far, or ``None`` if no differing index has been seen yet
    :return: the zipped node and the (possibly newly found) loop size
    :raises FormatStringParserError: if the nodes differ in shape or imply different loop sizes
    """

    def unmatched() -> FormatStringParserError:
        # built lazily: formatting the nodes is only needed on failure
        return FormatStringParserError(
            "semantics: unmatched dots pair: {} and {}".format(a, b))

    if isinstance(a, ItemNode) and isinstance(b, ItemNode):
        if a.name != b.name or len(a.indices) != len(b.indices):
            raise unmatched()
        indices = []
        for i, j in zip(a.indices, b.indices):
            if simplify(i) == simplify(j):
                indices.append(i)
                continue
            # The first index is parenthesized so that a compound expression
            # such as "n - 1" is subtracted as a whole.
            length = simplify(Expr(f"{j} - ({i}) + 1"))
            if size is None:
                size = length
            elif length != simplify(size):
                raise unmatched()
            indices.append(simplify(Expr(f"{i} + {name}")))
        return ItemNode(name=a.name, indices=indices), size

    elif isinstance(a, NewlineNode) and isinstance(b, NewlineNode):
        return NewlineNode(), size

    elif isinstance(a, SequenceNode) and isinstance(b, SequenceNode):
        if len(a.items) != len(b.items):
            raise unmatched()
        items = []
        for a_i, b_i in zip(a.items, b.items):
            # the size found in one child constrains all following children
            c_i, size = zip_nodes(a_i, b_i, name=name, size=size)
            items.append(c_i)
        return SequenceNode(items=items), size

    elif isinstance(a, LoopNode) and isinstance(b, LoopNode):
        if a.size != b.size or a.name != b.name:
            raise unmatched()
        c, size = zip_nodes(a.body, b.body, name=name, size=size)
        return LoopNode(size=a.size, name=a.name, body=c), size

    else:
        raise unmatched()
Пример #6
0
    def test_div(self) -> None:
        # two halves of n are expected to add up to n
        source = Expr('n / 2 + n / 2')
        wanted = Expr('n')

        self.assertEqual(simplify.simplify(source), wanted)
Пример #7
0
    def test_parens(self) -> None:
        # parenthesized products, including the juxtaposed "2 (x * y)"
        source = Expr('(x + 1) * (y + 1) - (2 (x * y) + 1)')
        wanted = Expr('x - x * y + y')

        self.assertEqual(simplify.simplify(source), wanted)
Пример #8
0
    def test_const(self) -> None:
        # constant factors are expected to be folded: 6 x - 5 x gives x
        source = Expr('2 * 3 * x - 5 * x')
        wanted = Expr('x')

        self.assertEqual(simplify.simplify(source), wanted)
Пример #9
0
    def test_simple(self) -> None:
        # the constants cancel and the two n are collected
        source = Expr('(n + 1) + (n - 1)')
        wanted = Expr('2 * n')

        self.assertEqual(simplify.simplify(source), wanted)
Пример #10
0
def _get_variable(*, decl: VarDecl, indices: Sequence[Expr]) -> str:
    """
    Build the subscripted expression that accesses a declared variable.

    :param decl: the declaration of the variable being accessed
    :param indices: the indices to rewrite, paired positionally with ``decl.bases``
    :return: the variable name followed by one ``[...]`` per index/base pair,
        where each subscript is the simplified ``index - (base)``
    """
    pieces = [str(decl.name)]
    for position, offset in zip(indices, decl.bases):
        shifted = simplify(Expr(f"{position} - ({offset})"))
        pieces.append(f"[{shifted}]")
    return "".join(pieces)
Пример #11
0
def analyze_parsed_node(node: ParserNode) -> FormatNode:
    """
    translates an internal representation :any:`ParserNode` to a result tree :any:`FormatNode`

    :param node: the parser node to translate
    :return: the translated tree; nested sequences are flattened and a sequence of length 1 is replaced by its only item
    :raises FormatStringParserError: if the two ends of a dots node cannot be zipped into a loop
    :raises AssertionError: if no loop counter name up to ``z`` is free, or if ``node`` is of an unknown type
    """

    if isinstance(node, ItemParserNode):
        indices = [simplify(index) for index in node.indices]
        return ItemNode(name=node.name, indices=indices)

    elif isinstance(node, NewlineParserNode):
        return NewlineNode()

    elif isinstance(node, SequenceParserNode):
        # ``items`` holds the finished nodes, ``que`` the nodes still to be processed
        items: List[FormatNode] = []
        que: List[FormatNode] = list(map(analyze_parsed_node, node.items))
        while que:
            item, *que = que
            if isinstance(item, SequenceNode):
                # flatten SequenceNode in SequenceNode
                que = item.items + que
            elif isinstance(item, LoopNode) and items:
                # merge FormatNode with LoopNode if possible
                if isinstance(
                        item.body,
                        SequenceNode) and len(items) >= len(item.body.items):
                    # compare as many preceding nodes as the loop body has items
                    items_init = items[:-len(item.body.items)]
                    items_tail: FormatNode = SequenceNode(
                        items=items[-len(item.body.items):])
                else:
                    items_init = items[:-1]
                    items_tail = items[-1]
                extended_body = extend_loop_node(items_tail,
                                                 item.body,
                                                 loop=item)
                if extended_body is not None:
                    # the preceding nodes became one more iteration of the loop
                    extended_loop: FormatNode = LoopNode(size=simplify(
                        Expr(f"""{item.size} + 1""")),
                                                         name=item.name,
                                                         body=extended_body)
                    items = items_init
                    # push the loop back so that it may absorb further preceding nodes
                    que = [extended_loop] + que
                else:
                    items.append(item)
            else:
                items.append(item)
        if len(items) == 1:
            # return the node directly if the length is 1
            return items[0]
        else:
            return SequenceNode(items=items)

    elif isinstance(node, DotsParserNode):
        a = analyze_parsed_node(node.first)
        b = analyze_parsed_node(node.last)

        # find the name of the new loop counter
        used_names = list_used_names(a) | list_used_names(b)
        name = VarName('i')
        while name in used_names:
            # Raised explicitly instead of asserted: an ``assert`` is removed
            # under ``python -O`` and the search would continue past 'z'.
            if name == VarName('z'):
                raise AssertionError(
                    "no unused name is left for the loop counter")
            name = VarName(chr(ord(name) + 1))

        # zip bodies
        c, size = zip_nodes(a, b, name=name, size=None)
        if size is None:
            raise FormatStringParserError(
                "semantics: unmatched dots pair: {} and {}".format(a, b))
        return LoopNode(size=size, name=name, body=c)

    else:
        # Raised explicitly instead of ``assert False`` so that an unknown
        # node type is never turned into a silent ``None`` under ``python -O``.
        raise AssertionError(f"unexpected parser node: {node!r}")