def test_execute(self):
    """Execute four programs one at a time and check the state after each.

    After every step the test checks the execution history, the set of
    references held in the environment, the rendering of the newest value
    and that the context is still ``[None]``.
    """
    # NOTE(review): the blank rows in the expected renderings below are one
    # character wide while the drawn rows are three characters wide;
    # whitespace inside these literals may have been collapsed -- confirm
    # against the real output of show().
    small_rect = Rectangle(1, 1)
    wide_rect = Rectangle(3, 1)
    difference = Difference(Reference(0), Reference(1))
    union = Union(Rectangle(1, 1), Reference(2))

    interpreter = Interpreter(3, 3, 1, False)
    state = interpreter.create_state([None])

    # Step 0: a 1x1 rectangle.
    state = interpreter.execute(small_rect, state)
    assert state.history == [small_rect]
    assert set(state.environment.keys()) == {Reference(0)}
    assert state.type_environment[Reference(0)] == "Rectangle"
    assert show(state.environment[Reference(0)][0]) == " \n # \n \n"
    assert state.context == [None]

    # Step 1: a 3x1 rectangle.
    state = interpreter.execute(wide_rect, state)
    assert state.history == [small_rect, wide_rect]
    assert set(state.environment.keys()) == {Reference(0), Reference(1)}
    assert show(state.environment[Reference(1)][0]) == " \n###\n \n"
    assert state.context == [None]

    # Step 2: difference of the two previous results.
    state = interpreter.execute(difference, state)
    assert state.history == [small_rect, wide_rect, difference]
    assert set(state.environment.keys()) == {
        Reference(0), Reference(1), Reference(2),
    }
    assert show(state.environment[Reference(2)][0]) == " \n# #\n \n"
    assert state.context == [None]

    # Step 3: union of a fresh rectangle with the difference.
    state = interpreter.execute(union, state)
    assert state.history == [small_rect, wide_rect, difference, union]
    assert set(state.environment.keys()) == {
        Reference(0), Reference(1), Reference(2), Reference(3),
    }
    assert show(state.environment[Reference(3)][0]) == " \n###\n \n"
    assert state.context == [None]
def get_samples(dataset: Dataset, parser: Parser[csgAST],
                reference: bool = False) -> Samples:
    """Collect the rules, node types and tokens used by example programs.

    Args:
        dataset: Provides ``size_candidates``, ``length_candidates`` and
            ``degree_candidates``, from which the tokens are built.
        parser: Its ``parse`` method turns each example program into an AST;
            programs for which it returns ``None`` are skipped.
        reference: When true, the example programs point at other values
            through ``Reference(...)`` instead of nesting them directly.

    Returns:
        ``Samples(rules, node_types, tokens)``. Rules and node types are
        listed in first-seen order; tokens are de-duplicated and sorted.
    """
    token_sources = (
        ("size", dataset.size_candidates),
        ("length", dataset.length_candidates),
        ("degree", dataset.degree_candidates),
    )
    tokens = [(kind, value)
              for kind, values in token_sources
              for value in values]

    if reference:
        # TODO use expander
        programs = [
            Circle(1),
            Rectangle(1, 2),
            Translation(1, 1, Reference(0)),
            Rotation(45, Reference(1)),
            Union(Reference(0), Reference(1)),
            Difference(Reference(0), Reference(1)),
        ]
    else:
        programs = [
            Circle(1),
            Rectangle(1, 2),
            Translation(1, 1, Circle(1)),
            Rotation(45, Circle(1)),
            Union(Circle(1), Circle(1)),
            Difference(Circle(1), Circle(1)),
        ]

    rules: List[Rule] = []
    node_types = []
    seen_rules = set()
    seen_node_types = set()

    def remember_node_type(node_type):
        # Keep first-seen order while avoiding duplicates.
        if node_type not in seen_node_types:
            node_types.append(node_type)
            seen_node_types.add(node_type)

    for program in programs:
        ast = parser.parse(program)
        if ast is None:
            continue
        for action in ActionSequence.create(ast).action_sequence:
            if not isinstance(action, ApplyRule):
                continue
            rule = action.rule
            if isinstance(rule, CloseVariadicFieldRule):
                continue
            if rule not in seen_rules:
                rules.append(rule)
                seen_rules.add(rule)
            remember_node_type(rule.parent)
            for _, child in rule.children:
                remember_node_type(child)

    return Samples(rules, node_types, sorted(set(tokens)))
def test_execute_with_multiple_inputs(self):
    """A state created from two inputs stores two values per reference."""
    interpreter = Interpreter(3, 3, 1, False)
    initial_state = interpreter.create_state([None, None])
    program = Rectangle(1, 1)

    state = interpreter.execute(program, initial_state)

    values = state.environment[Reference(0)]
    assert len(values) == 2
def test_draw_same_objects(self):
    """Two equal rectangles are tracked as separate references.

    After a rotation of ``Reference(0)`` is executed, the environment is
    expected to hold only ``Reference(1)`` and ``Reference(2)``.
    """
    first_rect = Rectangle(1, 1)
    second_rect = Rectangle(1, 1)
    rotated_first = Rotation(180, Reference(0))

    interpreter = Interpreter(3, 3, 1, True)
    state = interpreter.create_state([None])

    state = interpreter.execute(first_rect, state)
    assert set(state.environment.keys()) == {Reference(0)}

    state = interpreter.execute(second_rect, state)
    assert set(state.environment.keys()) == {Reference(0), Reference(1)}

    state = interpreter.execute(rotated_first, state)
    assert set(state.environment.keys()) == {Reference(1), Reference(2)}
def test_delete_used_variable(self):
    """References consumed by a later program leave the environment.

    The expected key set after each step is listed next to the program
    that is executed at that step.
    """
    steps = [
        (Rectangle(1, 1), {Reference(0)}),
        (Rectangle(3, 1), {Reference(0), Reference(1)}),
        (Difference(Reference(0), Reference(1)), {Reference(2)}),
        (Union(Rectangle(1, 1), Reference(2)), {Reference(3)}),
    ]

    interpreter = Interpreter(3, 3, 1, True)
    state = interpreter.create_state([None])
    for program, expected_keys in steps:
        state = interpreter.execute(program, state)
        assert set(state.environment.keys()) == expected_keys
def sample_ast(self, rng: np.random.RandomState, n_object: int) -> AST:
    """Randomly build one CSG program out of ``n_object`` primitive shapes.

    ``n_object`` leaf types are drawn from ``self.leaf_candidates``,
    ``n_object - 1`` operators from ``self.branch_candidates`` and a random
    number of additional operators from ``self.node_candidates``. Operators
    are then applied in random order to randomly chosen objects until only
    one object is left or no operator remains.

    Args:
        rng: Source of all random choices.
        n_object: Number of primitive shapes to start from; must be >= 1.

    Returns:
        The first object remaining after the operators have been applied.

    Raises:
        ValueError: If ``n_object`` is smaller than 1.
        Exception: If a sampled leaf or operator name is not supported.
    """
    if n_object < 1:
        # Without this guard the failure surfaces later as an opaque numpy
        # error about a negative sample size.
        raise ValueError(f"n_object must be at least 1: {n_object}")

    # Primitive shapes, keyed by their creation index.
    objects: Dict[int, AST] = {}
    leaf_types = rng.choice(self.leaf_candidates, n_object)
    for i, leaf_type in enumerate(leaf_types):
        if leaf_type == "Circle":
            objects[i] = Circle(rng.choice(self.size_candidates))
        elif leaf_type == "Rectangle":
            objects[i] = Rectangle(rng.choice(self.size_candidates),
                                   rng.choice(self.size_candidates))
        else:
            raise Exception(f"Invalid type: {leaf_type}")

    # Pool of pending operators: first those drawn from branch_candidates,
    # then a random number of extra ones drawn from node_candidates.
    ops = {}
    for i, op_type in enumerate(
            rng.choice(self.branch_candidates, n_object - 1)):
        ops[i] = op_type
    n_node = rng.randint(0, len(ops) + 2)
    n_branch = len(ops)
    for i, op_type in enumerate(rng.choice(self.node_candidates, n_node)):
        ops[i + n_branch] = op_type

    while len(objects) > 1 and len(ops) != 0:
        op_key = rng.choice(list(ops.keys()))
        op = ops.pop(op_key)
        obj0_key = rng.choice(list(objects.keys()))
        obj0 = objects.pop(obj0_key)
        if op == "Translation":
            objects[obj0_key] = Translation(
                rng.choice(self.length_candidates),
                rng.choice(self.length_candidates),
                obj0)
        elif op == "Rotation":
            objects[obj0_key] = Rotation(
                rng.choice(self.degree_candidates),
                obj0)
        else:
            # Every other operator consumes a second operand; the result is
            # stored under the first operand's key.
            obj1_key = rng.choice(list(objects.keys()))
            obj1 = objects.pop(obj1_key)
            if op == "Union":
                objects[obj0_key] = Union(obj0, obj1)
            elif op == "Difference":
                objects[obj0_key] = Difference(obj0, obj1)
            else:
                # Report the operator that was rejected, not the leftover
                # loop variable from the sampling loops above.
                raise Exception(f"Invalid type: {op}")
    return list(objects.values())[0]
def unparse(self, code: AST) -> Optional[csgAST]:
    """Convert a generic AST back into a CSG program.

    Args:
        code: A ``Node`` whose type name is one of ``Circle``,
            ``Rectangle``, ``Translation``, ``Rotation``, ``Union`` or
            ``Difference``, or a ``Leaf``.

    Returns:
        The rebuilt CSG object. For a ``Leaf`` its value is returned as is.
        ``None`` is returned when ``code`` is neither a ``Node`` nor a
        ``Leaf``, when the node type is not one of the names above, when a
        field does not hold the expected ``Leaf``/``AST``, or when a child
        cannot be unparsed.
    """
    if not isinstance(code, Node):
        if isinstance(code, Leaf):
            return cast(csgAST, code.value)
        return None

    fields = {field.name: field.value for field in code.fields}
    type_name = code.get_type_name()

    if type_name == "Circle":
        if not isinstance(fields["r"], Leaf):
            return None
        return Circle(fields["r"].value)

    if type_name == "Rectangle":
        if not isinstance(fields["w"], Leaf) or \
                not isinstance(fields["h"], Leaf):
            return None
        return Rectangle(fields["w"].value, fields["h"].value)

    if type_name in ("Translation", "Rotation"):
        # Both unary operators rebuild their child first and only then
        # look at their own parameters.
        if not isinstance(fields["child"], AST):
            return None
        child = self.unparse(fields["child"])
        if child is None:
            return None
        if type_name == "Translation":
            if not isinstance(fields["x"], Leaf) or \
                    not isinstance(fields["y"], Leaf):
                return None
            return Translation(fields["x"].value, fields["y"].value, child)
        if not isinstance(fields["theta"], Leaf):
            return None
        return Rotation(fields["theta"].value, child)

    if type_name in ("Union", "Difference"):
        if not isinstance(fields["a"], AST):
            return None
        if not isinstance(fields["b"], AST):
            return None
        lhs = self.unparse(fields["a"])
        if lhs is None:
            return None
        rhs = self.unparse(fields["b"])
        if rhs is None:
            return None
        if type_name == "Union":
            return Union(lhs, rhs)
        return Difference(lhs, rhs)

    return None
def pretrain(self, output_dir):
    """Run supervised training on a small fixed set of CSG programs.

    Args:
        output_dir: Forwarded to ``train_supervised`` as its second
            positional argument.

    Returns:
        A tuple ``(encoder, train_dataset)`` where ``train_dataset`` is the
        dataset after the ``AddTestCases`` transform has been attached.
    """
    dataset = Dataset(4, 1, 2, 1, 45, seed=0)

    ground_truth_programs = [
        Circle(1),
        Rectangle(1, 2),
        Rectangle(1, 1),
        Rotation(45, Rectangle(1, 1)),
        Translation(1, 1, Rectangle(1, 1)),
        Difference(Circle(1), Circle(1)),
        Union(Rectangle(1, 2), Circle(1)),
        Difference(Rectangle(1, 1), Circle(1)),
    ]
    train_dataset = ListDataset([
        Environment({"ground_truth": program}, set(["ground_truth"]))
        for program in ground_truth_programs
    ])

    with tempfile.TemporaryDirectory() as tmpdir:
        interpreter = self.interpreter()
        add_test_cases = Apply(
            module=AddTestCases(interpreter),
            in_keys=["ground_truth"],
            out_key="test_cases",
            is_out_supervision=False,
        )
        train_dataset = data_transform(train_dataset, add_test_cases)
        encoder = self.prepare_encoder(dataset, Parser())

        collate = Collate(
            test_case_tensor=CollateOptions(False, 0, 0),
            variables_tensor=CollateOptions(True, 0, 0),
            previous_actions=CollateOptions(True, 0, -1),
            hidden_state=CollateOptions(False, 0, 0),
            state=CollateOptions(False, 0, 0),
            ground_truth_actions=CollateOptions(True, 0, -1),
        )
        to_episode = Map(self.to_episode(encoder, interpreter))
        transform = Map(self.transform(encoder, interpreter, Parser()))
        collate_fn = Sequence(OrderedDict([
            ("to_episode", to_episode),
            ("flatten", Flatten()),
            ("transform", transform),
            ("collate", collate.collate),
        ]))

        model = self.prepare_model(encoder)
        optimizer = self.prepare_optimizer(model)

        action_sequence_loss = Apply(
            module=Loss(reduction="sum"),
            in_keys=[
                "rule_probs",
                "token_probs",
                "reference_probs",
                "ground_truth_actions",
            ],
            out_key="action_sequence_loss",
        )
        # divided by batch_size
        normalize = Apply(
            [("action_sequence_loss", "lhs")],
            "loss",
            mlprogram.nn.Function(Div()),
            constants={"rhs": 1})
        loss_fn = torch.nn.Sequential(OrderedDict([
            ("loss", action_sequence_loss),
            ("normalize", normalize),
            ("pick", mlprogram.nn.Function(Pick("loss"))),
        ]))

        train_supervised(
            tmpdir, output_dir,
            train_dataset, model, optimizer,
            loss_fn,
            None, "score",
            collate_fn,
            1, Epoch(100),
            evaluation_interval=Epoch(10),
            snapshot_interval=Epoch(100)
        )
    return encoder, train_dataset
def test_difference(self):
    """Evaluate a Difference of two rectangles and compare its rendering."""
    # NOTE(review): the blank rows of the expected string are one character
    # wide while the middle row is three; whitespace in this literal may
    # have been collapsed -- confirm against the real output of show().
    expected = " \n# #\n \n"
    interpreter = Interpreter(3, 3, 1, False)
    program = Difference(Rectangle(1, 1), Rectangle(3, 1))

    canvas = interpreter.eval(program, [None])[0]

    assert expected == show(canvas)
def test_union(self):
    """Evaluate a Union of a 3x1 and a 1x3 rectangle and compare rendering."""
    expected = " # \n###\n # \n"
    interpreter = Interpreter(3, 3, 1, False)
    program = Union(Rectangle(3, 1), Rectangle(1, 3))

    canvas = interpreter.eval(program, [None])[0]

    assert expected == show(canvas)
def test_rotation(self):
    """Evaluate a 45-degree Rotation of a 4x1 rectangle and compare rendering."""
    # NOTE(review): the rows of the expected string have different widths;
    # whitespace in this literal may have been collapsed -- confirm against
    # the real output of show().
    expected = " #\n # \n# \n"
    interpreter = Interpreter(3, 3, 1, False)
    program = Rotation(45, Rectangle(4, 1))

    canvas = interpreter.eval(program, [None])[0]

    assert expected == show(canvas)
def test_translation(self):
    """Evaluate a Translation of a 1x3 rectangle and compare its rendering."""
    # NOTE(review): the rows of the expected string have different widths;
    # whitespace in this literal may have been collapsed -- confirm against
    # the real output of show().
    expected = " #\n #\n #\n \n \n"
    interpreter = Interpreter(5, 5, 1, False)
    program = Translation(2, 1, Rectangle(1, 3))

    canvas = interpreter.eval(program, [None])[0]

    assert expected == show(canvas)
def test_rectangle(self):
    """Evaluate a 1x3 Rectangle and compare its rendering."""
    expected = " # \n # \n # \n"
    interpreter = Interpreter(3, 3, 1, False)
    program = Rectangle(1, 3)

    canvas = interpreter.eval(program, [None])[0]

    assert expected == show(canvas)
def test_unparse_rectangle(self, parser):
    """Parsing and then unparsing a Rectangle gives back an equal Rectangle."""
    ast = parser.parse(Rectangle(1, 2))
    round_tripped = parser.unparse(ast)
    assert Rectangle(1, 2) == round_tripped
def test_parse_rectangle(self, parser):
    """A Rectangle parses to a node with ``w`` and ``h`` size leaves."""
    width_field = Field("w", "size", Leaf("size", 1))
    height_field = Field("h", "size", Leaf("size", 2))
    expected = Node("Rectangle", [width_field, height_field])

    assert expected == parser.parse(Rectangle(1, 2))