def test_graph_ecs_entity_update_input_outputs():
    """Check use_input_outputs() propagates the input/output dicts to every component.

    Reading an attribute through the entity should record a Retrieve node in
    ``inputs``; assigning one should record a Mutate node in ``outputs``.
    """
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    car = GraphEntity.from_def(
        entity_def=EntityDef(components=[Position, Speed], entity_id=entity_id),
        component_defs=[Position, Speed],
    )
    car.use_input_outputs(inputs, outputs)

    # get/set on each component should record retrieve/mutate ops in inputs/outputs.
    # Reads happen before writes so that they are recorded as Retrieve inputs.
    car_pos_x = car[Position].x
    car[Position].x = 1
    car_speed_x = car[Speed].x
    car[Speed].x = 2

    pos_attr_ref = AttributeRef(entity_id=entity_id, component=Position.name, attribute="x")
    speed_attr_ref = AttributeRef(entity_id=entity_id, component=Speed.name, attribute="x")

    # Compare the recorded proto (.node) rather than the GraphNode itself: GraphNode's ==
    # appears to build a graph node (e.g. `key_pressed == "left"` used as a switch
    # condition), which would make a direct comparison always truthy.
    pos_expected_input = Node(retrieve_op=Node.Retrieve(retrieve_attr=pos_attr_ref))
    pos_expected_output = Node(
        mutate_op=Node.Mutate(mutate_attr=pos_attr_ref, to_node=wrap_const(1)))
    assert inputs[to_str_attr(pos_attr_ref)].node == pos_expected_input
    assert outputs[to_str_attr(pos_attr_ref)].node == pos_expected_output

    speed_expected_input = Node(retrieve_op=Node.Retrieve(retrieve_attr=speed_attr_ref))
    speed_expected_output = Node(
        mutate_op=Node.Mutate(mutate_attr=speed_attr_ref, to_node=wrap_const(2)))
    assert inputs[to_str_attr(speed_attr_ref)].node == speed_expected_input
    assert outputs[to_str_attr(speed_attr_ref)].node == speed_expected_output
def test_graph_plotter_retrieve_mutate_op():
    """Check the Plotter records retrieves as graph inputs and mutates as outputs."""
    entity_id = 1
    g = Plotter(
        entity_defs=[EntityDef(components=[Position], entity_id=1)],
        component_defs=[Position],
    )
    person = g.entity(components=[Position])

    pos_x = person[Position].x
    person[Position].y = pos_x

    def position_ref(attribute):
        # AttributeRef for the given Position attribute of the test entity
        return AttributeRef(
            entity_id=entity_id,
            component=Position.name,
            attribute=attribute,
        )

    actual = g.graph()
    expected = Graph(
        inputs=[Node.Retrieve(retrieve_attr=position_ref("x"))],
        outputs=[
            # the mutation node must record the assignment of pos_x's node
            Node.Mutate(mutate_attr=position_ref("y"), to_node=pos_x.node),
        ],
    )
    assert actual == expected
def test_graph_plotter_preserve_code_order():
    """Check graph inputs/outputs follow first-read and last-write order in the code."""
    g = Plotter(
        entity_defs=[EntityDef(components=[Position, Velocity], entity_id=1)],
        component_defs=[Position, Velocity],
    )
    car = g.entity(components=[Position, Velocity])

    # Velocity.x is read first, so it should precede Position.x in inputs
    car_velocity_x = car[Velocity].x
    car_pos_x = car[Position].x
    # Velocity.x is written last, so it should follow Position.x in outputs
    car[Position].x = 3
    car[Velocity].x = 2
    position_x = car[Position].x
    velocity_x = car[Velocity].x

    def x_ref(component):
        # AttributeRef for the "x" attribute of the given component on the car
        return AttributeRef(
            entity_id=car.id,
            component=component.name,
            attribute="x",
        )

    actual_yaml = g.graph().yaml
    expected_yaml = Graph(
        inputs=[
            Node.Retrieve(retrieve_attr=x_ref(Velocity)),
            Node.Retrieve(retrieve_attr=x_ref(Position)),
        ],
        outputs=[
            Node.Mutate(mutate_attr=x_ref(Position), to_node=position_x.node),
            Node.Mutate(mutate_attr=x_ref(Velocity), to_node=velocity_x.node),
        ],
    ).yaml
    assert actual_yaml == expected_yaml
def test_graph_ecs_component_aug_assign_node():
    """Check augmented assignment records the attribute as both input and output."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # augmented assignment reads position.y (an input) and writes it back (an output)
    position.y += 30

    attr_ref = AttributeRef(
        entity_id=entity_id,
        component=Position.name,
        attribute="y",
    )
    expected_input = Node(retrieve_op=Node.Retrieve(retrieve_attr=attr_ref))
    expected_output = Node(mutate_op=Node.Mutate(
        mutate_attr=attr_ref,
        to_node=Node(add_op=Node.Add(
            x=expected_input,
            y=wrap_const(30),
        )),
    ))
    # Compare the recorded proto (.node) rather than the GraphNode itself: GraphNode's ==
    # appears to build a graph node, which would make a direct comparison always truthy.
    assert len(inputs) == 1
    assert inputs[to_str_attr(attr_ref)].node == expected_input
    assert len(outputs) == 1
    assert outputs[to_str_attr(attr_ref)].node == expected_output
def test_client_set_attr(client, sim_def, attr_ref, attr_val):
    """Check set_attr() succeeds and raises on unknown attributes and invalid values."""
    client.set_attr(sim_def.name, attr_ref, attr_val)

    # setting a nonexistent attribute should raise LookupError
    has_not_found_error = False
    try:
        client.set_attr(
            sim_name=sim_def.name,
            attr_ref=AttributeRef(
                entity_id=1,
                component="not",
                attribute="found",
            ),
            value=attr_val,
        )
    except LookupError:
        has_not_found_error = True
    assert has_not_found_error

    # setting a default-constructed (invalid) Value should raise ValueError.
    # The flag must start False, otherwise this assertion can never fail.
    has_invalid_error = False
    try:
        client.set_attr(
            sim_name=sim_def.name,
            attr_ref=attr_ref,
            value=Value(),
        )
    except ValueError:
        has_invalid_error = True
    assert has_invalid_error
def check_client_set_attr(x):
    """Assert the most recent mock_client.set_attr call set entity 1's position.y to x."""
    expected_ref = AttributeRef(entity_id=1, component="position", attribute="y")
    mock_client.set_attr.assert_called_with(
        sim_name=sim_name,
        attr_ref=expected_ref,
        value=wrap(x),
    )
def from_attr(cls, entity_id: int, component: str, name: str):
    """Create a GraphNode that retrieves the specified ECS attribute.

    Args:
        entity_id: id of the entity owning the attribute.
        component: name of the component holding the attribute.
        name: name of the attribute to retrieve.

    Returns:
        The result of GraphNode.wrap() applied to a Node.Retrieve of the attribute.
    """
    target = AttributeRef(
        entity_id=entity_id,
        component=component,
        attribute=name,
    )
    # NOTE(review): this wraps a bare Node.Retrieve, whereas GraphComponent.get_attr
    # builds GraphNode(node=Node(retrieve_op=...)). Confirm GraphNode.wrap handles it.
    return GraphNode.wrap(Node.Retrieve(retrieve_attr=target))
def test_graph_plotter_conditional_boolean():
    """Check a nested g.switch() expression is recorded as the attribute's mutation."""
    g = Plotter(
        entity_defs=[
            EntityDef(components=[Position], entity_id=1),
            EntityDef(components=[Keyboard], entity_id=2),
        ],
        component_defs=[Position, Keyboard],
    )
    env = g.entity(components=[Keyboard])
    car = g.entity(components=[Position])

    key_pressed = env[Keyboard].pressed
    # "left" -> -1.0, "right" -> 1.0, anything else -> 0.0
    is_left = key_pressed == "left"
    otherwise = g.switch(
        condition=key_pressed == "right",
        true=1.0,
        false=0.0,
    )
    car_pos_x = g.switch(condition=is_left, true=-1.0, false=otherwise)
    car[Position].x = car_pos_x

    pressed_ref = AttributeRef(
        entity_id=env.id,
        component=Keyboard.name,
        attribute="pressed",
    )
    pos_x_ref = AttributeRef(
        entity_id=car.id,
        component=Position.name,
        attribute="x",
    )
    actual = g.graph()
    expected = Graph(
        inputs=[Node.Retrieve(retrieve_attr=pressed_ref)],
        outputs=[
            # the mutation node must record the nested switch assigned to car_pos_x
            Node.Mutate(mutate_attr=pos_x_ref, to_node=car_pos_x.node),
        ],
    )
    assert actual == expected
def set_attr(self, name: str, value: Any):
    """Set attribute `name` of this component to `value` on the engine.

    `value` is converted with wrap() before being sent via the client.
    """
    target = AttributeRef(
        entity_id=self._entity_id,
        component=self._name,
        attribute=name,
    )
    self._client.set_attr(
        sim_name=self._sim_name,
        attr_ref=target,
        value=wrap(value),
    )
def set_attr(self, name: str, value: Any):
    """Record an assignment to attribute `name` as a Mutate graph node in the outputs.

    `value` is passed through GraphNode.wrap(), so native values and GraphNodes are
    both accepted. Self-assignments (component.attr = component.attr) are ignored.
    Re-assigning an attribute replaces its earlier output entry and moves it to the
    end, so outputs follow the order of each attribute's last assignment.
    """
    value = GraphNode.wrap(value)
    attr_ref = AttributeRef(
        entity_id=self._entity_id,
        component=self._name,
        # str() matches get_attr(), so reads and writes build identical refs
        attribute=str(name),
    )
    # Ignore self assignments, i.e. the value is just a retrieve of this same attribute.
    # Compare messages with == rather than serialized bytes: protobuf serialization
    # is not guaranteed to be canonical.
    if (value.node.WhichOneof("op") == "retrieve_op"
            and value.node.retrieve_op.retrieve_attr == attr_ref):
        return

    key = to_str_attr(attr_ref)
    # record the attribute set/mutate operation as an output graph node
    self._outputs[key] = GraphNode(node=Node(mutate_op=Node.Mutate(
        mutate_attr=attr_ref,
        to_node=value.node,
    )))
    # re-assigning an existing key keeps its old position, so move it to the end
    # to preserve execution order
    self._outputs.move_to_end(key)
def get_attr(self, name: str) -> Any:
    """Fetch attribute `name` of this component from the engine.

    Returns the client's response passed through unwrap().
    """
    target = AttributeRef(
        entity_id=self._entity_id,
        component=self._name,
        attribute=name,
    )
    response = self._client.get_attr(sim_name=self._sim_name, attr_ref=target)
    return unwrap(response)
def test_build_proto():
    """Check an AttributeRef survives a serialize/parse round trip unchanged."""
    original = AttributeRef(entity_id=24, component="Sprite2D")
    decoded = AttributeRef()
    decoded.ParseFromString(original.SerializeToString())
    assert original == decoded
def get_attr(self, name: str) -> GraphNode:
    """Return a GraphNode representing attribute `name` of this component.

    If the attribute was already assigned via set_attr(), the graph built by that
    assignment is returned, so later reads see the updated definition. Otherwise a
    Retrieve node is recorded in the inputs and returned.
    """
    attr_ref = AttributeRef(
        entity_id=self._entity_id,
        component=self._name,
        attribute=str(name),
    )
    key = to_str_attr(attr_ref)

    if key not in self._outputs:
        # not yet written: record the retrieve as an input graph node
        retrieve = GraphNode(node=Node(retrieve_op=Node.Retrieve(
            retrieve_attr=attr_ref)))
        self._inputs[key] = retrieve
        return retrieve

    # already written: reuse the graph built for that assignment
    return GraphNode(self._outputs[key].node.mutate_op.to_node)
def test_graph_ecs_component_set_attr_native_value():
    """Check assigning a native value records a Mutate node to that constant."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # setting an attribute to a native value should record the expected output node
    position.y = 3

    attr_ref = AttributeRef(
        entity_id=entity_id,
        component=Position.name,
        attribute="y",
    )
    expected_node = Node(mutate_op=Node.Mutate(
        mutate_attr=attr_ref,
        to_node=wrap_const(3),
    ))
    # Compare the recorded proto (.node) rather than the GraphNode itself: GraphNode's ==
    # appears to build a graph node, which would make a direct comparison always truthy.
    assert len(outputs) == 1
    assert outputs[to_str_attr(attr_ref)].node == expected_node
def test_client_get_attr(client, sim_def, attr_ref, attr_val):
    """Check get_attr() returns attr_val and raises LookupError for a missing attribute."""
    assert client.get_attr(sim_def.name, attr_ref) == attr_val

    # requesting a nonexistent attribute should raise LookupError
    missing_ref = AttributeRef(entity_id=1, component="not", attribute="found")
    try:
        client.get_attr(sim_name=sim_def.name, attr_ref=missing_ref)
    except LookupError:
        raised_not_found = True
    else:
        raised_not_found = False
    assert raised_not_found
def test_graph_ecs_component_get_attr():
    """Check reading an attribute returns a Retrieve node and records it once in inputs."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # getting an attribute from a component returns a GraphNode
    # wrapping a Retrieve node that retrieves the attribute
    pos_x = position.x
    attr_ref = AttributeRef(entity_id=entity_id, component=Position.name, attribute="x")
    expected_node = Node(retrieve_op=Node.Retrieve(retrieve_attr=attr_ref))
    assert pos_x.node == expected_node
    # the component records the retrieve in its inputs (and reads produce no outputs)
    assert inputs[to_str_attr(attr_ref)].node == expected_node
    assert not outputs

    # retrieving the same attribute again yields the same node but records it only once
    pos_x_again = position.x
    assert pos_x_again.node == expected_node
    assert len(inputs) == 1
def test_graph_ecs_component_set_attr_node():
    """Check assigning a GraphNode records a Mutate to it, and the last write wins."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    pos_x = position.x
    # the first definition should be superseded by the second one
    position.y = 10
    position.y = pos_x

    y_ref = AttributeRef(
        entity_id=entity_id,
        component=Position.name,
        attribute="y",
    )
    want = Node(mutate_op=Node.Mutate(mutate_attr=y_ref, to_node=pos_x.node))
    assert outputs[to_str_attr(y_ref)].node == want
    # y was redefined, so only a single output entry should remain
    assert len(outputs) == 1
def attr_ref():
    """Provide the AttributeRef for the "x" attribute of entity 1's "position" component."""
    ref = AttributeRef(component="position", attribute="x", entity_id=1)
    return ref