def _list_declared_variables_dfs(node: FormatNode, *, counter: Dict[VarName, _CounterDecl], declared: Dict[VarName, VarDecl]) -> None:
    """
    Collect into ``declared`` the variables that appear in the format tree ``node``.

    For each index of an :any:`ItemNode` two expressions are derived by textual substitution of the enclosing loop counters:
    ``dim`` is the index with every counter replaced by the size of its loop, and ``base`` is the index with every counter replaced by ``0``.
    The stored dimension is ``simplify(dim - (base))`` and the stored base is ``simplify(base)``.

    :param node: the (sub)tree to visit
    :param counter: the counters of the enclosing loops; a newly entered loop is put in front of the already known ones
    :param declared: the result mapping; it is updated in place
    :raises DeclaredVariablesError: if the same variable name is found twice in the tree
    """

    if isinstance(node, ItemNode):
        if node.name in declared:
            raise DeclaredVariablesError(f"the same variable appears twice in tree: {node.name}")
        dims = []
        bases = []
        depending = set()
        for index in node.indices:
            dim = index
            base = index
            for i, decl in counter.items():
                pattern = r'\b' + re.escape(i) + r'\b'
                # The size must be parenthesized: replacing `i` in `2 * i` by `n + 1` has to give `2 * (n + 1)`, not `2 * n + 1`.
                # A callable replacement is used so that the size text is inserted verbatim, without backslash-escape processing.
                dim = Expr(re.sub(pattern, lambda _match, size=decl.size: f"({size})", dim))
                base = Expr(re.sub(pattern, '0', base))
            # Variables declared earlier whose names occur in the dimension are recorded as dependencies.
            for n in declared:
                if re.search(r'\b' + re.escape(n) + r'\b', dim):
                    depending.add(n)
            dims.append(simplify(Expr(f"""{dim} - ({base})""")))
            bases.append(simplify(base))
        declared[node.name] = VarDecl(name=node.name, dims=dims, bases=bases, depending=depending, type=None)

    elif isinstance(node, NewlineNode):
        # A newline declares nothing.
        pass

    elif isinstance(node, SequenceNode):
        for item in node.items:
            _list_declared_variables_dfs(item, counter=counter, declared=declared)

    elif isinstance(node, LoopNode):
        # The loop size may mention variables declared before the loop.
        depending = set()
        for n in declared:
            if re.search(r'\b' + re.escape(n) + r'\b', node.size):
                depending.add(n)
        decl = _CounterDecl(name=node.name, size=node.size, depending=depending)
        # The new counter goes first, so it is substituted before the outer counters
        # and any outer counter mentioned by its size is rewritten afterwards.
        _list_declared_variables_dfs(node.body, counter={node.name: decl, **counter}, declared=declared)
def _get_variable_on_code(*, decl: VarDecl, indices: List[Expr], decls: Dict[VarName, VarDecl]) -> str:
    """
    Build the text that refers to one element of the variable ``decl``.

    Every index is shifted by the corresponding entry of ``decl.bases`` and simplified, then appended to the variable name as ``[...]``.

    :param decl: the declaration of the variable
    :param indices: the index expressions; they are paired with ``decl.bases`` by ``zip``, so extra entries on either side are ignored
    :param decls: not used by this function
    :return: the variable name followed by one bracketed subscript per index
    """

    subscripts = []
    for index, base in zip(indices, decl.bases):
        offset = simplify(Expr(f"{index} - ({base})"))
        subscripts.append(f"[{offset}]")
    return str(decl.name) + "".join(subscripts)
def test_subscripted(self) -> None:
    """Check ``simplify.simplify`` on a sum of subscripted terms whose subscripts are themselves expressions."""
    source = Expr(
        'a _ {n + 1 - n} + a _ {n + 1 - (n - 1)} + a _ {n + 1 - (n - 2)}'
        ' + dots + a _ {n + 1 - 2} + a _ {n + 1 - 1}'
    )
    self.assertEqual(simplify.simplify(source), Expr('a_1 + a_2 + a_3 + a_n + a_{n - 1} + dots'))
def extend_loop_node(a: FormatNode, b: FormatNode, *, loop: LoopNode) -> Optional[FormatNode]:
    """
    Try to absorb the node ``a`` into a loop as one additional iteration placed before its first one.

    ``b`` is (a part of) the body of ``loop``. ``a`` can be absorbed when it has the same shape as ``b`` and
    each of its indices equals the corresponding index of ``b`` evaluated at counter value ``-1``.

    :param a: the node that precedes the loop
    :param b: the loop body, or a sub-node of it during recursion
    :param loop: the loop whose counter name is looked for in the indices of ``b``
    :return: the body for the loop extended by one iteration, or :any:`None` if ``a`` does not fit
    """

    if isinstance(a, ItemNode) and isinstance(b, ItemNode):
        if a.name != b.name or len(a.indices) != len(b.indices):
            return None
        counter = r'\b' + re.escape(loop.name) + r'\b'
        indices = []
        for i, j in zip(a.indices, b.indices):
            # `a` must be what the body would have been one step before the first iteration.
            decr_j = Expr(re.sub(counter, '(-1)', j))
            if simplify(i) != simplify(decr_j):
                return None
            # The extended loop starts one step earlier, so the old body at counter k is the new body at k + 1.
            # Substituting `(k - 1)` keeps an index that does not mention the counter unchanged,
            # and equals the former `i + k` whenever the counter occurs with coefficient one.
            shifted = Expr(re.sub(counter, f"({loop.name} - 1)", j))
            indices.append(simplify(shifted))
        return ItemNode(name=a.name, indices=indices)

    elif isinstance(a, NewlineNode) and isinstance(b, NewlineNode):
        return NewlineNode()

    elif isinstance(a, SequenceNode) and isinstance(b, SequenceNode):
        if len(a.items) != len(b.items):
            return None
        items = []
        for a_i, b_i in zip(a.items, b.items):
            c_i = extend_loop_node(a_i, b_i, loop=loop)
            if c_i is None:
                return None
            items.append(c_i)
        return SequenceNode(items=items)

    elif isinstance(a, LoopNode) and isinstance(b, LoopNode):
        # Nested loops must agree on size and counter name; only their bodies are merged.
        if a.size != b.size or a.name != b.name:
            return None
        c = extend_loop_node(a.body, b.body, loop=loop)
        if c is None:
            return None
        return LoopNode(size=a.size, name=a.name, body=c)

    else:
        # Nodes of different kinds never match.
        return None
def zip_nodes(a: FormatNode, b: FormatNode, *, name: VarName, size: Optional[Expr]) -> Tuple[FormatNode, Optional[Expr]]:
    """
    Merge the first node ``a`` and the last node ``b`` of a dots pair into the body of a loop.

    Indices that are equal in ``a`` and ``b`` are kept as they are. An index that differs becomes ``i + name``,
    where ``i`` is the index in ``a``, and fixes the loop size to ``j - (i) + 1``, where ``j`` is the index in ``b``.
    All differing indices must give the same size.

    :param a: the node before the dots
    :param b: the node after the dots
    :param name: the name of the loop counter to introduce
    :param size: the loop size found so far, or :any:`None` if no differing index has been seen yet
    :return: the merged node and the (possibly updated) size
    :raises FormatStringParserError: if the two nodes have different shapes or imply different sizes
    """

    if isinstance(a, ItemNode) and isinstance(b, ItemNode):
        if a.name != b.name or len(a.indices) != len(b.indices):
            raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
        indices = []
        for i, j in zip(a.indices, b.indices):
            if simplify(i) == simplify(j):
                indices.append(i)
                continue
            # `i` must be parenthesized: for i = `k + 1` the length is `j - (k + 1) + 1`, not `j - k + 1 + 1`.
            length = simplify(Expr(f"""{j} - ({i}) + 1"""))
            if size is None:
                size = length
            elif length != simplify(size):
                raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
            indices.append(simplify(Expr(f"{i} + {name}")))
        return ItemNode(name=a.name, indices=indices), size

    elif isinstance(a, NewlineNode) and isinstance(b, NewlineNode):
        return NewlineNode(), size

    elif isinstance(a, SequenceNode) and isinstance(b, SequenceNode):
        if len(a.items) != len(b.items):
            raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
        items = []
        for a_i, b_i in zip(a.items, b.items):
            # The size found in one item is passed on, so later items are checked against it.
            c_i, size = zip_nodes(a_i, b_i, name=name, size=size)
            items.append(c_i)
        return SequenceNode(items=items), size

    elif isinstance(a, LoopNode) and isinstance(b, LoopNode):
        if a.size != b.size or a.name != b.name:
            raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
        c, size = zip_nodes(a.body, b.body, name=name, size=size)
        return LoopNode(size=a.size, name=a.name, body=c), size

    else:
        raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
def test_div(self) -> None:
    """Check that ``simplify.simplify`` turns ``n / 2 + n / 2`` into ``n``."""
    source = Expr('n / 2 + n / 2')
    self.assertEqual(simplify.simplify(source), Expr('n'))
def test_parens(self) -> None:
    """Check ``simplify.simplify`` on an expression with parenthesized factors and a parenthesized subtrahend."""
    source = Expr('(x + 1) * (y + 1) - (2 (x * y) + 1)')
    self.assertEqual(simplify.simplify(source), Expr('x - x * y + y'))
def test_const(self) -> None:
    """Check that ``simplify.simplify`` folds constant coefficients: ``2 * 3 * x - 5 * x`` becomes ``x``."""
    source = Expr('2 * 3 * x - 5 * x')
    self.assertEqual(simplify.simplify(source), Expr('x'))
def test_simple(self) -> None:
    """Check that ``simplify.simplify`` turns ``(n + 1) + (n - 1)`` into ``2 * n``."""
    source = Expr('(n + 1) + (n - 1)')
    self.assertEqual(simplify.simplify(source), Expr('2 * n'))
def _get_variable(*, decl: VarDecl, indices: Sequence[Expr]) -> str:
    """
    Build the text that refers to one element of the variable ``decl``.

    :param decl: the declaration of the variable
    :param indices: the index expressions; they are paired with ``decl.bases`` by ``zip``, so extra entries on either side are ignored
    :return: the variable name followed by ``[index - (base)]``, simplified, for each index
    """

    offsets = [simplify(Expr(f"{index} - ({base})")) for index, base in zip(indices, decl.bases)]
    return str(decl.name) + "".join(f"[{offset}]" for offset in offsets)
def analyze_parsed_node(node: ParserNode) -> FormatNode:
    """
    translates an internal representation :any:`ParserNode` to a result tree :any:`FormatNode`

    Items get their indices simplified, nested sequences are flattened, the items just before a loop are merged into
    that loop when they look like one more iteration of it, and a dots pair becomes a loop with a fresh counter name.

    :param node: the parsed node to translate
    :return: the translated tree; a sequence that ends up with exactly one item is returned as that item
    :raises FormatStringParserError: if a dots pair cannot be matched, or no counter name from ``i`` to ``z`` is unused
    """

    if isinstance(node, ItemParserNode):
        indices = [simplify(index) for index in node.indices]
        return ItemNode(name=node.name, indices=indices)

    elif isinstance(node, NewlineParserNode):
        return NewlineNode()

    elif isinstance(node, SequenceParserNode):
        items: List[FormatNode] = []
        que: List[FormatNode] = [analyze_parsed_node(child) for child in node.items]
        while que:
            item = que.pop(0)
            if isinstance(item, SequenceNode):
                # flatten SequenceNode in SequenceNode
                que = item.items + que

            elif isinstance(item, LoopNode) and items:
                # merge FormatNode with LoopNode if possible
                if isinstance(item.body, SequenceNode) and len(items) >= len(item.body.items):
                    # The loop body is a sequence: compare it with as many preceding items as it has.
                    items_init = items[:-len(item.body.items)]
                    items_tail: FormatNode = SequenceNode(items=items[-len(item.body.items):])
                else:
                    items_init = items[:-1]
                    items_tail = items[-1]
                extended_body = extend_loop_node(items_tail, item.body, loop=item)
                if extended_body is not None:
                    extended_loop: FormatNode = LoopNode(size=simplify(Expr(f"""{item.size} + 1""")), name=item.name, body=extended_body)
                    # The extended loop is pushed back to the queue, so that it may absorb further preceding items.
                    items = items_init
                    que = [extended_loop] + que
                else:
                    items.append(item)

            else:
                items.append(item)

        if len(items) == 1:
            # return the node directly if the length is 1
            return items[0]
        else:
            return SequenceNode(items=items)

    elif isinstance(node, DotsParserNode):
        a = analyze_parsed_node(node.first)
        b = analyze_parsed_node(node.last)

        # find the name of the new loop counter
        used_names = list_used_names(a) | list_used_names(b)
        name = VarName('i')
        while name in used_names:
            # This depends on the input, so it must not be an `assert`: asserts are removed under `python -O`,
            # and the search would then go on to non-letter characters after `z`.
            if name == VarName('z'):
                raise FormatStringParserError("semantics: no unused loop counter name is left: {} and {}".format(a, b))
            name = VarName(chr(ord(name) + 1))

        # zip bodies
        c, size = zip_nodes(a, b, name=name, size=None)
        if size is None:
            # No index differs between the two ends, so there is nothing to iterate over.
            raise FormatStringParserError("semantics: unmatched dots pair: {} and {}".format(a, b))
        return LoopNode(size=size, name=name, body=c)

    else:
        # Raised explicitly (rather than `assert False`) so that an unknown node never falls through and returns None.
        raise AssertionError("unexpected parser node: {!r}".format(node))