def test_unexpand(self, expander):
    """Check that `unexpand` folds a reference-based sequence into one tree."""
    # Each case pairs a sequence of statements with the single tree that
    # `unexpand` is expected to return for it.
    cases = [
        (
            [Circle(1)],
            Circle(1),
        ),
        (
            [Rotation(1, Circle(1))],
            Rotation(1, Circle(1)),
        ),
        (
            [Circle(1), Difference(Circle(1), Reference(0))],
            Difference(Circle(1), Circle(1)),
        ),
        (
            [
                Circle(1),
                Rotation(1, Reference(0)),
                Difference(Circle(1), Reference(1)),
            ],
            Difference(Circle(1), Rotation(1, Circle(1))),
        ),
    ]
    for sequence, expected_tree in cases:
        assert expander.unexpand(sequence) == expected_tree
def test_execute_with_multiple_inputs(self):
    """Check that one value is stored per context entry of the state."""
    interpreter = Interpreter(3, 3, 1, False)
    # The state is created from two (empty) inputs.
    initial_state = interpreter.create_state([None, None])
    result_state = interpreter.execute(Rectangle(1, 1), initial_state)
    stored_values = result_state.environment[Reference(0)]
    assert len(stored_values) == 2
def execute(self, code: AST,
            state: BatchedState[AST, np.ndarray, str, None]) \
        -> BatchedState[AST, np.ndarray, str, None]:
    """Run one statement on a clone of `state` and return the clone.

    The statement is appended to the cloned history and bound to
    `Reference(i)`, where `i` is its index in that history. Its type name
    is stored in `type_environment`, and its value (the rendering of the
    evaluated, unexpanded history) is stored in `environment`, repeated
    once per entry of `state.context`.

    When `self.delete_used_reference` is set, every `Reference` that
    occurs inside `code` is removed from both environments afterwards.
    A reference that is not present in the environment is reported with
    a warning and skipped, instead of raising `KeyError` on deletion.

    Args:
        code: The statement to execute.
        state: The state to start from; the work is done on
            `state.clone()`.

    Returns:
        The updated clone.
    """
    next_state = cast(BatchedState[AST, np.ndarray, str, None],
                      state.clone())
    next_state.history.append(code)
    ref = Reference(len(next_state.history) - 1)
    next_state.type_environment[ref] = code.type_name()

    rendered = self._render(
        self._cached_eval(self._expander.unexpand(next_state.history)))
    # The same rendered value is shared by every context entry.
    next_state.environment[ref] = [rendered for _ in state.context]

    if self.delete_used_reference:
        used_references = set()

        def _visit(node: AST):
            """Collect every Reference reachable from `node`."""
            if isinstance(node, (Circle, Rectangle)):
                return
            if isinstance(node, (Rotation, Translation)):
                _visit(node.child)
                return
            if isinstance(node, (Union, Difference)):
                _visit(node.a)
                _visit(node.b)
                return
            if isinstance(node, Reference):
                used_references.add(node)
                if node not in next_state.environment:
                    logger.warning(
                        f"reference {node} is not found in environment")

        _visit(code)
        for used in used_references:
            # Guard the deletions: a dangling reference has already been
            # warned about above and must not abort the execution.
            if used in next_state.environment:
                del next_state.environment[used]
            if used in next_state.type_environment:
                del next_state.type_environment[used]
    return next_state
def get_samples(dataset: Dataset, parser: Parser[csgAST],
                reference: bool = False) -> Samples:
    """Collect rules, node types and tokens from a fixed set of programs.

    Args:
        dataset: Supplies `size_candidates`, `length_candidates` and
            `degree_candidates`, which become the tokens.
        parser: Used to parse each example program into an AST.
        reference: When true, the example programs use `Reference`
            operands instead of nested shapes.

    Returns:
        `Samples(rules, node_types, tokens)`; rules and node types are in
        first-seen order, tokens are de-duplicated and sorted.
    """
    token_sources = (
        ("size", dataset.size_candidates),
        ("length", dataset.length_candidates),
        ("degree", dataset.degree_candidates),
    )
    tokens = [(kind, value)
              for kind, candidates in token_sources
              for value in candidates]

    if reference:
        # TODO use expander
        programs = [
            Circle(1),
            Rectangle(1, 2),
            Translation(1, 1, Reference(0)),
            Rotation(45, Reference(1)),
            Union(Reference(0), Reference(1)),
            Difference(Reference(0), Reference(1)),
        ]
    else:
        programs = [
            Circle(1),
            Rectangle(1, 2),
            Translation(1, 1, Circle(1)),
            Rotation(45, Circle(1)),
            Union(Circle(1), Circle(1)),
            Difference(Circle(1), Circle(1)),
        ]

    rules: List[Rule] = []
    node_types = []
    seen_rules = set()
    seen_types = set()

    def _register_type(node_type):
        """Record `node_type` the first time it is encountered."""
        if node_type not in seen_types:
            node_types.append(node_type)
            seen_types.add(node_type)

    for program in programs:
        ast = parser.parse(program)
        if ast is None:
            continue
        for action in ActionSequence.create(ast).action_sequence:
            if not isinstance(action, ApplyRule):
                continue
            rule = action.rule
            if isinstance(rule, CloseVariadicFieldRule):
                continue
            if rule not in seen_rules:
                rules.append(rule)
                seen_rules.add(rule)
            _register_type(rule.parent)
            for _, child in rule.children:
                _register_type(child)

    return Samples(rules, node_types, sorted(set(tokens)))
def test_expand(self, expander):
    """Check that `expand` flattens a tree into a reference-based sequence."""
    # Each case pairs a tree with the sequence `expand` should produce.
    cases = [
        (
            Circle(1),
            [Circle(1)],
        ),
        (
            Rotation(1, Circle(1)),
            [Circle(1), Rotation(1, Reference(0))],
        ),
        (
            Difference(Circle(1), Circle(1)),
            [
                Circle(1),
                Circle(1),
                Difference(Reference(0), Reference(1)),
            ],
        ),
        (
            Difference(Circle(1), Rotation(1, Circle(1))),
            [
                Circle(1),
                Circle(1),
                Rotation(1, Reference(1)),
                Difference(Reference(0), Reference(2)),
            ],
        ),
    ]
    for tree, expected_sequence in cases:
        assert expander.expand(tree) == expected_sequence
def test_delete_used_variable(self):
    """Check which references remain in the environment after each step.

    The interpreter is built with `True` as its fourth argument
    (presumably `delete_used_reference` - confirm against the
    constructor).
    """
    interpreter = Interpreter(3, 3, 1, True)
    state = interpreter.create_state([None])

    # (statement to execute, references expected in the environment after)
    steps = [
        (Rectangle(1, 1), {Reference(0)}),
        (Rectangle(3, 1), {Reference(0), Reference(1)}),
        (Difference(Reference(0), Reference(1)), {Reference(2)}),
        (Union(Rectangle(1, 1), Reference(2)), {Reference(3)}),
    ]
    for statement, remaining in steps:
        state = interpreter.execute(statement, state)
        assert set(state.environment.keys()) == remaining
def test_draw_same_objects(self):
    """Check that using one of two equal shapes leaves the other in place."""
    interpreter = Interpreter(3, 3, 1, True)
    state = interpreter.create_state([None])

    # Statements 0 and 1 are equal rectangles; statement 2 only refers to
    # statement 0, so Reference(1) is expected to stay in the environment.
    steps = [
        (Rectangle(1, 1), {Reference(0)}),
        (Rectangle(1, 1), {Reference(0), Reference(1)}),
        (Rotation(180, Reference(0)), {Reference(1), Reference(2)}),
    ]
    for statement, remaining in steps:
        state = interpreter.execute(statement, state)
        assert set(state.environment.keys()) == remaining
def test_execute(self):
    """Check history, environment, rendering and context after each step."""
    programs = [
        Rectangle(1, 1),
        Rectangle(3, 1),
        Difference(Reference(0), Reference(1)),
        Union(Rectangle(1, 1), Reference(2)),
    ]
    # Expected output of `show` for the value bound by each statement.
    pictures = [
        " \n # \n \n",
        " \n###\n \n",
        " \n# #\n \n",
        " \n###\n \n",
    ]

    interpreter = Interpreter(3, 3, 1, False)
    state = interpreter.create_state([None])

    for index, (program, picture) in enumerate(zip(programs, pictures)):
        state = interpreter.execute(program, state)
        newest = Reference(index)

        assert state.history == programs[:index + 1]
        # No reference is dropped: all bindings so far are still present.
        assert set(state.environment.keys()) == \
            {Reference(i) for i in range(index + 1)}
        if index == 0:
            assert state.type_environment[newest] == "Rectangle"
        assert show(state.environment[newest][0]) == picture
        assert state.context == [None]
def test_unparse_difference(self, parser):
    """Check that a Difference survives a parse/unparse round trip."""
    expected = Difference(Reference(0), Reference(1))
    parsed = parser.parse(Difference(Reference(0), Reference(1)))
    assert expected == parser.unparse(parsed)
def test_unparse_union(self, parser):
    """Check that a Union survives a parse/unparse round trip."""
    expected = Union(Reference(0), Reference(1))
    parsed = parser.parse(Union(Reference(0), Reference(1)))
    assert expected == parser.unparse(parsed)
def test_unparse_rotation(self, parser):
    """Check that a Rotation survives a parse/unparse round trip."""
    expected = Rotation(45, Reference(0))
    parsed = parser.parse(Rotation(45, Reference(0)))
    assert expected == parser.unparse(parsed)
def test_unparse_translation(self, parser):
    """Check that a Translation survives a parse/unparse round trip."""
    expected = Translation(1, 2, Reference(0))
    parsed = parser.parse(Translation(1, 2, Reference(0)))
    assert expected == parser.unparse(parsed)
def test_parse_difference(self, parser):
    """Check the AST produced when parsing a Difference of two references."""
    operand_a = Field("a", "CSG", Leaf("CSG", Reference(0)))
    operand_b = Field("b", "CSG", Leaf("CSG", Reference(1)))
    expected = Node("Difference", [operand_a, operand_b])
    assert expected == parser.parse(Difference(Reference(0), Reference(1)))
def test_parse_union(self, parser):
    """Check the AST produced when parsing a Union of two references."""
    operand_a = Field("a", "CSG", Leaf("CSG", Reference(0)))
    operand_b = Field("b", "CSG", Leaf("CSG", Reference(1)))
    expected = Node("Union", [operand_a, operand_b])
    assert expected == parser.parse(Union(Reference(0), Reference(1)))
def test_parse_rotation(self, parser):
    """Check the AST produced when parsing a Rotation of a reference."""
    angle = Field("theta", "degree", Leaf("degree", 45))
    target = Field("child", "CSG", Leaf("CSG", Reference(0)))
    expected = Node("Rotation", [angle, target])
    assert expected == parser.parse(Rotation(45, Reference(0)))
def test_parse_translation(self, parser):
    """Check the AST produced when parsing a Translation of a reference."""
    offset_x = Field("x", "length", Leaf("length", 1))
    offset_y = Field("y", "length", Leaf("length", 2))
    target = Field("child", "CSG", Leaf("CSG", Reference(0)))
    expected = Node("Translation", [offset_x, offset_y, target])
    assert expected == parser.parse(Translation(1, 2, Reference(0)))