def test_graph_compile_arithmetic_multiple():
    """Compile a system that combines unary negation with several binary ops.

    The graph returned by compile_graph is checked with assert_graph against
    the "expected_graph_arithmetic_multiple.yaml" fixture.
    """

    # NOTE(review): this function is handed to compile_graph as convert_fn;
    # it presumably gets converted from its source, so keep its statements
    # unchanged (comments only) unless the YAML fixture is updated too.
    def arithmetic_multiple_fn(g: Plotter):
        ms_in_sec = int(1e3)
        env = g.entity(components=[Clock])
        car = g.entity(
            components=[
                Position,
                Velocity,
            ]
        )
        tick_ms = env[Clock].tick_ms
        neg_xps = -car[Velocity].x
        # NOTE(review): multiplying a millisecond value by 1000 does not
        # convert it to seconds (that would be a division). Left as-is since
        # the expected YAML fixture presumably encodes this expression — confirm.
        x_delta = neg_xps * (tick_ms * ms_in_sec)
        car[Position].x = x_delta + car[Position].x

    # Entity 1 carries Position+Velocity (the "car"), entity 2 carries Clock
    # (the "env"), mirroring the component sets requested via g.entity() above.
    actual_graph = compile_graph(
        convert_fn=arithmetic_multiple_fn,
        component_defs=[Position, Velocity, Clock],
        entity_defs=[
            EntityDef(components=[Position, Velocity], entity_id=1),
            EntityDef(components=[Clock], entity_id=2),
        ],
    )
    assert_graph(actual_graph, "expected_graph_arithmetic_multiple.yaml")
def test_graph_compile_ifelse():
    """Compile a system whose body branches with if / elif / else.

    The graph returned by compile_graph is checked with assert_graph against
    the "expected_graph_ifelse.yaml" fixture.
    """

    # NOTE(review): this function is handed to compile_graph as convert_fn;
    # it presumably gets converted from its source (the branches test a
    # component attribute), so keep its statements unchanged (comments only)
    # unless the YAML fixture is updated too.
    def ifelse_fn(g: Plotter):
        car = g.entity(
            components=[
                Position,
                Speed,
                Velocity,
            ]
        )
        env = g.entity(components=[Clock])
        # NOTE(review): tick_ms > 5000 implies tick_ms > 2000, so the elif
        # branch below can never be taken under ordinary if/elif semantics.
        # Kept as-is because the expected YAML fixture presumably encodes it.
        if env[Clock].tick_ms > 2000:
            car[Position].x += g.min(car[Speed].max_x, 2 * car[Velocity].x)
            car[Position].y = car[Position].x + 2
        elif env[Clock].tick_ms > 5000:
            car[Position].x += g.min(car[Speed].max_x, 5 * car[Velocity].x)
            car[Position].y = car[Position].x + 10
        else:
            # Plain assignment here, unlike the += used in the other branches.
            car[Position].x = g.min(car[Speed].max_x, 1 * car[Velocity].x)
            car[Position].y = car[Position].x - 5

    # Entity 1 carries Position/Velocity/Speed (the "car"), entity 2 carries
    # Clock (the "env").
    actual_graph = compile_graph(
        convert_fn=ifelse_fn,
        component_defs=[Position, Clock, Velocity, Speed],
        entity_defs=[
            EntityDef(components=[Position, Velocity, Speed], entity_id=1),
            EntityDef(components=[Clock], entity_id=2),
        ],
    )
    assert_graph(actual_graph, "expected_graph_ifelse.yaml")
def build(self, include_graphs: bool = True) -> sim_pb2.SimulationDef:
    """Build a `bento.proto.sim_pb2.SimulationDef` Proto from this Simulation.

    Args:
        include_graphs: Whether to compile & include graphs in the returned
            Proto. This requires an id to be set on each entity, as entity
            ids are required for graph compilation to work. When False, no
            systems are emitted and the init graph is an empty `Graph()`.
    Returns:
        The `bento.proto.sim_pb2.SimulationDef` Proto equivalent of this
        Simulation.
    """
    system_defs, init_graph = [], Graph()
    if include_graphs:

        def compile_fn(fn):
            # compile_graph takes (convert_fn, entity_defs, component_defs).
            return compile_graph(fn, self.entity_defs, self.component_defs)

        # One SystemDef per (system fn, system id) pair registered on this sim.
        system_defs = [
            SystemDef(graph=compile_fn(fn), system_id=system_id)
            for fn, system_id in self.system_fns
        ]
        # The init graph stays empty when no init function was registered.
        if self.init_fn is not None:
            init_graph = compile_fn(self.init_fn)

    return sim_pb2.SimulationDef(
        name=self.name,
        entities=[e.proto for e in self.entity_defs],
        components=[c.proto for c in self.component_defs],
        systems=[s.proto for s in system_defs],
        init_graph=init_graph.proto,
    )
def test_graph_compile_empty():
    """A no-op convert fn with no entities or components yields an empty Graph()."""

    def empty_fn(g: Plotter):
        pass

    assert (
        compile_graph(
            convert_fn=empty_fn,
            component_defs=[],
            entity_defs=[],
        )
        == Graph()
    )
def test_graph_compile_arithmetic():
    """Compile a system that adds a Python int constant to a component attribute.

    The graph returned by compile_graph is checked with assert_graph against
    the "expected_graph_arithmetic.yaml" fixture.
    """

    # NOTE(review): this function is handed to compile_graph as convert_fn;
    # keep its statements unchanged (comments only) unless the YAML fixture
    # is updated too.
    def arithmetic_fn(g: Plotter):
        car = g.entity(components=[Position])
        x_delta = 20
        # Augmented assignment on a component attribute.
        car[Position].x += x_delta

    actual_graph = compile_graph(
        convert_fn=arithmetic_fn,
        component_defs=[Position],
        entity_defs=[EntityDef(components=[Position], entity_id=1)],
    )
    assert_graph(actual_graph, "expected_graph_arithmetic.yaml")
def test_graph_compile_ternary():
    """Compile a system that picks a value with a conditional (ternary) expression.

    The graph returned by compile_graph is checked with assert_graph against
    the "expected_graph_ternary.yaml" fixture.
    """

    # NOTE(review): this function is handed to compile_graph as convert_fn;
    # keep its statements unchanged (comments only) unless the YAML fixture
    # is updated too.
    def ternary_fn(g: Plotter):
        car = g.entity(
            components=[
                Position,
            ]
        )
        env = g.entity(components=[Clock])
        # The condition reads a component attribute, so both arms matter.
        x_delta = 20 if env[Clock].tick_ms > 2000 else 10
        car[Position].x = x_delta

    # Entity 1 carries Position (the "car"), entity 2 carries Clock (the "env").
    actual_graph = compile_graph(
        convert_fn=ternary_fn,
        component_defs=[Position, Clock],
        entity_defs=[
            EntityDef(components=[Position], entity_id=1),
            EntityDef(components=[Clock], entity_id=2),
        ],
    )
    assert_graph(actual_graph, "expected_graph_ternary.yaml")
def init_graph(component_defs, entity_defs):
    """Return the graph compiled from `init_fn` for the given definitions.

    `init_fn` is resolved from the enclosing module scope.
    """
    return compile_graph(
        convert_fn=init_fn,
        entity_defs=entity_defs,
        component_defs=component_defs,
    )
def system_defs(component_defs, entity_defs):
    """Return a one-element list holding `sys_fn`'s compiled graph as system 1.

    `sys_fn` is resolved from the enclosing module scope.
    """
    sys_graph = compile_graph(
        convert_fn=sys_fn,
        entity_defs=entity_defs,
        component_defs=component_defs,
    )
    return [SystemDef(graph=sys_graph, system_id=1)]