def test_graph_plotter_retrieve_mutate_op():
    """Plotter.graph() lists attribute reads as inputs and writes as outputs."""
    entity_id = 1
    plotter = Plotter(
        entity_defs=[EntityDef(components=[Position], entity_id=1)],
        component_defs=[Position],
    )
    person = plotter.entity(components=[Position])

    # read position.x and assign it to position.y:
    # x should become a graph input, y a graph output
    pos_x = person[Position].x
    person[Position].y = pos_x

    def attr_ref(attribute):
        return AttributeRef(
            entity_id=entity_id,
            component=Position.name,
            attribute=attribute,
        )

    expected = Graph(
        inputs=[Node.Retrieve(retrieve_attr=attr_ref("x"))],
        outputs=[
            Node.Mutate(
                mutate_attr=attr_ref("y"),
                # the mutation must point at the node produced by reading x
                to_node=pos_x.node,
            ),
        ],
    )
    assert plotter.graph() == expected
def wrap_const(val: Any):
    """Return `val` as a Constant graph node.

    A `Node` whose `op` oneof is already set to `const_op` is returned
    unchanged. Any other value (including a Node with a different op) is
    converted with `wrap()` and placed in a new `Node.Const`.

    Args:
        val: Native value, or an existing constant Node, to wrap.

    Returns:
        A Node with its `const_op` field set.
    """
    already_const = isinstance(val, Node) and val.WhichOneof("op") == "const_op"
    if not already_const:
        val = Node(const_op=Node.Const(held_value=wrap(val)))
    return val
def random(self, low: Any, high: Any) -> GraphNode:
    """Create a Random node that evaluates to a random float in [`low`, `high`].

    Args:
        low: Expression for the lower bound of the random number (inclusive).
        high: Expression for the upper bound of the random number (inclusive).

    Returns:
        GraphNode wrapping a Node whose `random_op` is set from the bounds.
    """
    low_node = GraphNode.wrap(low).node
    high_node = GraphNode.wrap(high).node
    random_op = Node.Random(low=low_node, high=high_node)
    return GraphNode(node=Node(random_op=random_op))
def test_graph_plotter_preserve_code_order():
    """graph() lists inputs in first-read order and outputs in last-write order."""
    plotter = Plotter(
        entity_defs=[EntityDef(components=[Position, Velocity], entity_id=1)],
        component_defs=[Position, Velocity],
    )
    car = plotter.entity(components=[Position, Velocity])

    # Velocity.x is read before Position.x, so it should be the first input.
    # These reads exist only for their side effect of registering inputs.
    _velocity_read = car[Velocity].x
    _position_read = car[Position].x

    # Velocity.x is written after Position.x, so it should be the last output.
    car[Position].x = 3
    car[Velocity].x = 2
    final_position_x = car[Position].x
    final_velocity_x = car[Velocity].x

    def ref(component):
        return AttributeRef(
            entity_id=car.id,
            component=component.name,
            attribute="x",
        )

    expected = Graph(
        inputs=[
            Node.Retrieve(retrieve_attr=ref(Velocity)),
            Node.Retrieve(retrieve_attr=ref(Position)),
        ],
        outputs=[
            Node.Mutate(mutate_attr=ref(Position), to_node=final_position_x.node),
            Node.Mutate(mutate_attr=ref(Velocity), to_node=final_velocity_x.node),
        ],
    )
    assert plotter.graph().yaml == expected.yaml
def test_graph_ecs_entity_update_input_outputs():
    """use_input_outputs() propagates the shared input/output dicts to components.

    Reads and writes on any of the entity's components should be recorded in
    the dicts passed to use_input_outputs().
    """
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    car = GraphEntity.from_def(
        entity_def=EntityDef(components=[Position, Speed], entity_id=entity_id),
        component_defs=[Position, Speed],
    )
    car.use_input_outputs(inputs, outputs)

    # get/set should propagate retrieve/mutate nodes to inputs and outputs
    car_pos_x = car[Position].x
    car[Position].x = 1
    car_speed_x = car[Speed].x
    car[Speed].x = 2

    pos_attr_ref = AttributeRef(
        entity_id=entity_id, component=Position.name, attribute="x")
    speed_attr_ref = AttributeRef(
        entity_id=entity_id, component=Speed.name, attribute="x")

    # Compare the wrapped protobuf nodes via `.node`: GraphNode overloads `==`
    # to build a graph node rather than return a bool, so comparing GraphNode
    # objects directly would make these asserts vacuous.
    pos_expected_input = Node(retrieve_op=Node.Retrieve(
        retrieve_attr=pos_attr_ref))
    pos_expected_output = Node(mutate_op=Node.Mutate(
        mutate_attr=pos_attr_ref, to_node=wrap_const(1)))
    assert inputs[to_str_attr(pos_attr_ref)].node == pos_expected_input
    assert outputs[to_str_attr(pos_attr_ref)].node == pos_expected_output

    speed_expected_input = Node(retrieve_op=Node.Retrieve(
        retrieve_attr=speed_attr_ref))
    speed_expected_output = Node(mutate_op=Node.Mutate(
        mutate_attr=speed_attr_ref, to_node=wrap_const(2)))
    assert inputs[to_str_attr(speed_attr_ref)].node == speed_expected_input
    assert outputs[to_str_attr(speed_attr_ref)].node == speed_expected_output
def from_attr(cls, entity_id: int, component: str, name: str):
    """Create a GraphNode that retrieves the specified ECS attribute.

    Args:
        entity_id: ID of the entity that owns the attribute.
        component: Name of the component holding the attribute.
        name: Name of the attribute to retrieve.

    Returns:
        GraphNode wrapping a Node whose `retrieve_op` targets the attribute.
    """
    attr_ref = AttributeRef(
        entity_id=entity_id,
        component=component,
        attribute=name,
    )
    # Node.Retrieve is only the op message; it must be set as the
    # `retrieve_op` field of a Node (as done everywhere else retrieve nodes
    # are built) before it can be wrapped as a graph node.
    return GraphNode.wrap(
        Node(retrieve_op=Node.Retrieve(retrieve_attr=attr_ref)))
def test_graph_ecs_component_set_attr_native_value():
    """Assigning a native value records a Mutate to a Constant node as output."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # setting an attribute to a native value should record the expected output node
    position.y = 3

    attr_ref = AttributeRef(
        entity_id=entity_id,
        component=Position.name,
        attribute="y",
    )
    expected_node = Node(mutate_op=Node.Mutate(
        mutate_attr=attr_ref,
        to_node=wrap_const(3),
    ))
    # compare the protobuf node: GraphNode's `==` builds a graph node instead
    # of returning a bool, so comparing the GraphNode itself is always truthy
    assert outputs[to_str_attr(attr_ref)].node == expected_node
def get_attr(self, name: str) -> GraphNode:
    """Return a GraphNode for reading attribute `name` of this component.

    If the attribute was already assigned via set_attr() (it has an entry in
    `self._outputs`), the node it was assigned to is returned so the graph
    built so far for that attribute is reused. Otherwise a new Retrieve node
    is created, recorded in `self._inputs`, and returned.
    """
    attr_ref = AttributeRef(
        entity_id=self._entity_id,
        component=self._name,
        attribute=str(name),
    )
    key = to_str_attr(attr_ref)

    if key not in self._outputs:
        # first read of an unassigned attribute: record it as a graph input
        retrieve = GraphNode(node=Node(retrieve_op=Node.Retrieve(
            retrieve_attr=attr_ref)))
        self._inputs[key] = retrieve
        return retrieve

    # attribute was assigned earlier: reuse the expression it was set to
    assigned_to = self._outputs[key].node.mutate_op.to_node
    return GraphNode(assigned_to)
def test_graph_ecs_component_get_attr():
    """Reading an attribute yields a Retrieve node that is recorded once in inputs."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    pos_x = position.x
    x_ref = AttributeRef(entity_id=entity_id, component=Position.name, attribute="x")
    expected_retrieve = Node(retrieve_op=Node.Retrieve(retrieve_attr=x_ref))

    # the returned GraphNode wraps a Retrieve of position.x ...
    assert pos_x.node == expected_retrieve
    # ... and the component records that retrieve in the shared inputs dict
    assert inputs[to_str_attr(x_ref)].node == expected_retrieve

    # reading the same attribute again must not add a second input entry
    _repeat_read = position.x
    assert len(inputs) == 1
def test_graph_plotter_conditional_boolean():
    """A nested switch() on a keyboard attribute is recorded as an output."""
    plotter = Plotter(
        entity_defs=[
            EntityDef(components=[Position], entity_id=1),
            EntityDef(components=[Keyboard], entity_id=2),
        ],
        component_defs=[Position, Keyboard],
    )
    env = plotter.entity(components=[Keyboard])
    car = plotter.entity(components=[Position])

    key_pressed = env[Keyboard].pressed
    # -1.0 when "left", 1.0 when "right", 0.0 otherwise
    car_pos_x = plotter.switch(
        condition=key_pressed == "left",
        true=-1.0,
        false=plotter.switch(
            condition=key_pressed == "right",
            true=1.0,
            false=0.0,
        ),
    )
    car[Position].x = car_pos_x

    pressed_ref = AttributeRef(
        entity_id=env.id,
        component=Keyboard.name,
        attribute="pressed",
    )
    car_x_ref = AttributeRef(
        entity_id=car.id,
        component=Position.name,
        attribute="x",
    )
    expected = Graph(
        inputs=[Node.Retrieve(retrieve_attr=pressed_ref)],
        outputs=[
            # the mutation must record the switch expression it was assigned
            Node.Mutate(mutate_attr=car_x_ref, to_node=car_pos_x.node),
        ],
    )
    assert plotter.graph() == expected
def set_attr(self, name: str, value: Any):
    """Record an assignment of `value` to attribute `name` as an output node.

    `value` is wrapped with GraphNode.wrap(). Assigning an attribute's own
    Retrieve node back to it (``component.attr = component.attr``) is ignored.
    Otherwise a Mutate node is stored in `self._outputs`, replacing any earlier
    assignment to the same attribute, and moved to the end so `_outputs`
    follows the order of the most recent assignments.
    """
    target = GraphNode.wrap(value).node
    attr_ref = AttributeRef(
        entity_id=self._entity_id,
        component=self._name,
        attribute=name,
    )

    is_self_assignment = (
        target.WhichOneof("op") == "retrieve_op"
        and target.retrieve_op.retrieve_attr.SerializeToString()
        == attr_ref.SerializeToString()
    )
    if is_self_assignment:
        return

    key = to_str_attr(attr_ref)
    self._outputs[key] = GraphNode(node=Node(mutate_op=Node.Mutate(
        mutate_attr=attr_ref,
        to_node=target,
    )))
    # re-assignment keeps the dict slot, so explicitly move it to the end
    self._outputs.move_to_end(key)
def test_graph_ecs_component_aug_assign_node():
    """Augmented assignment records the attribute as both an input and an output."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # `position.y += 30` reads position.y (input) and writes it back (output)
    position.y += 30

    attr_ref = AttributeRef(
        entity_id=entity_id,
        component=Position.name,
        attribute="y",
    )
    expected_input = Node(retrieve_op=Node.Retrieve(retrieve_attr=attr_ref))
    expected_output = Node(mutate_op=Node.Mutate(
        mutate_attr=attr_ref,
        to_node=Node(add_op=Node.Add(
            x=expected_input,
            y=wrap_const(30),
        )),
    ))
    # compare protobuf nodes via `.node`: GraphNode's `==` builds a graph node
    # rather than returning a bool, so a direct comparison would be vacuous
    assert len(inputs) == 1
    assert inputs[to_str_attr(attr_ref)].node == expected_input
    assert len(outputs) == 1
    assert outputs[to_str_attr(attr_ref)].node == expected_output
def test_graph_ecs_component_set_attr_node():
    """Assigning a GraphNode records a Mutate to that node; the last write wins."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    pos_x = position.x
    # this first assignment should be superseded by the next one
    position.y = 10
    position.y = pos_x

    y_ref = AttributeRef(entity_id=entity_id, component=Position.name, attribute="y")
    expected_mutate = Node(mutate_op=Node.Mutate(mutate_attr=y_ref, to_node=pos_x.node))
    assert outputs[to_str_attr(y_ref)].node == expected_mutate

    # only the latest definition of position.y is kept as an output
    assert len(outputs) == 1
def switch(self, condition: Any, true: Any, false: Any) -> GraphNode:
    """Create a conditional Switch node that evaluates based on `condition`.

    The switch evaluates to `true` if `condition` is true, `false` otherwise.
    Each argument is wrapped with GraphNode.wrap() first.

    Args:
        condition: Defines the condition. Should evaluate to true or false.
        true: The switch evaluates to this expression if `condition` is true.
        false: The switch evaluates to this expression if `condition` is false.

    Returns:
        GraphNode wrapping the Switch node.
    """
    condition_node = GraphNode.wrap(condition).node
    true_node = GraphNode.wrap(true).node
    false_node = GraphNode.wrap(false).node
    switch_op = Node.Switch(
        condition_node=condition_node,
        true_node=true_node,
        false_node=false_node,
    )
    return GraphNode.wrap(Node(switch_op=switch_op))
def cos(self, x: Any) -> GraphNode:
    """Create a Cos operation node (`cos_op`) over `x`.

    Args:
        x: Operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the Cos node.
    """
    operand = GraphNode.wrap(x).node
    return GraphNode(node=Node(cos_op=Node.Cos(x=operand)))
def mod(self, x: Any, y: Any) -> GraphNode:
    """Create a Mod operation node (`mod_op`) with operands `x` and `y`.

    Args:
        x: First operand; wrapped with GraphNode.wrap().
        y: Second operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the Mod node.
    """
    lhs = GraphNode.wrap(x).node
    rhs = GraphNode.wrap(y).node
    return GraphNode(node=Node(mod_op=Node.Mod(x=lhs, y=rhs)))
def sin(self, x: Any) -> GraphNode:
    """Create a Sin operation node (`sin_op`) over `x`.

    Args:
        x: Operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the Sin node.
    """
    operand = GraphNode.wrap(x).node
    return GraphNode(node=Node(sin_op=Node.Sin(x=operand)))
def ceil(self, x: Any) -> GraphNode:
    """Create a Ceil operation node (`ceil_op`) over `x`.

    Args:
        x: Operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the Ceil node.
    """
    operand = GraphNode.wrap(x).node
    return GraphNode(node=Node(ceil_op=Node.Ceil(x=operand)))
def pow(self, x: Any, y: Any) -> GraphNode:
    """Create a Pow operation node (`pow_op`) with operands `x` and `y`.

    Args:
        x: First operand; wrapped with GraphNode.wrap().
        y: Second operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the Pow node.
    """
    lhs = GraphNode.wrap(x).node
    rhs = GraphNode.wrap(y).node
    return GraphNode(node=Node(pow_op=Node.Pow(x=lhs, y=rhs)))
def abs(self, x: Any) -> GraphNode:
    """Create an Abs operation node (`abs_op`) over `x`.

    Args:
        x: Operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the Abs node.
    """
    x = GraphNode.wrap(x)
    # Node.Abs expects the underlying protobuf Node, not the GraphNode wrapper
    # (consistent with every other unary op builder).
    return GraphNode(node=Node(abs_op=Node.Abs(x=x.node)))
def floor(self, x: Any) -> GraphNode:
    """Create a Floor operation node (`floor_op`) over `x`.

    Args:
        x: Operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the Floor node.
    """
    operand = GraphNode.wrap(x).node
    return GraphNode(node=Node(floor_op=Node.Floor(x=operand)))
def min(self, x: Any, y: Any) -> GraphNode:
    """Create a Min operation node (`min_op`) with operands `x` and `y`.

    Args:
        x: First operand; wrapped with GraphNode.wrap().
        y: Second operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the Min node.
    """
    lhs = GraphNode.wrap(x).node
    rhs = GraphNode.wrap(y).node
    return GraphNode(node=Node(min_op=Node.Min(x=lhs, y=rhs)))
def __rmul__(self, other: Any):
    """Reflected `*`: build a Mul node for `other * self`.

    The wrapped `other` becomes operand `x` and `self` operand `y`, keeping
    the original left-to-right operand order.
    """
    wrap = type(self).wrap
    left = wrap(other)
    product = Node.Mul(x=left.node, y=self.node)
    return wrap(Node(mul_op=product))
def arctan(self, x: Any) -> GraphNode:
    """Create an ArcTan operation node (`arctan_op`) over `x`.

    Args:
        x: Operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the ArcTan node.
    """
    operand = GraphNode.wrap(x).node
    return GraphNode(node=Node(arctan_op=Node.ArcTan(x=operand)))
def arccos(self, x: Any) -> GraphNode:
    """Create an ArcCos operation node (`arccos_op`) over `x`.

    Args:
        x: Operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the ArcCos node.
    """
    operand = GraphNode.wrap(x).node
    return GraphNode(node=Node(arccos_op=Node.ArcCos(x=operand)))
def __radd__(self, other: Any):
    """Reflected `+`: build an Add node for `other + self`.

    The wrapped `other` becomes operand `x` and `self` operand `y`, keeping
    the original left-to-right operand order.
    """
    wrap = type(self).wrap
    left = wrap(other)
    total = Node.Add(x=left.node, y=self.node)
    return wrap(Node(add_op=total))
def tan(self, x: Any) -> GraphNode:
    """Create a Tan operation node (`tan_op`) over `x`.

    Args:
        x: Operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the Tan node.
    """
    operand = GraphNode.wrap(x).node
    return GraphNode(node=Node(tan_op=Node.Tan(x=operand)))
def max(self, x: Any, y: Any) -> GraphNode:
    """Create a Max operation node (`max_op`) with operands `x` and `y`.

    Args:
        x: First operand; wrapped with GraphNode.wrap().
        y: Second operand; wrapped with GraphNode.wrap().

    Returns:
        GraphNode wrapping the Max node.
    """
    lhs = GraphNode.wrap(x).node
    rhs = GraphNode.wrap(y).node
    return GraphNode(node=Node(max_op=Node.Max(x=lhs, y=rhs)))
def __ge__(self, other: Any):
    """Build a `self >= other` node as Or(self > other, self == other).

    `other` is wrapped first; the greater-than and equality sub-nodes are
    produced by this object's own `__gt__` and `__eq__`.
    """
    wrap = type(self).wrap
    rhs_node = wrap(other).node
    greater = self.__gt__(rhs_node).node
    equal = self.__eq__(rhs_node).node
    return wrap(Node(or_op=Node.Or(x=greater, y=equal)))
def __rsub__(self, other: Any):
    """Reflected `-`: build a Sub node for `other - self`.

    The wrapped `other` becomes operand `x` (minuend) and `self` operand `y`
    (subtrahend), keeping the original left-to-right operand order.
    """
    wrap = type(self).wrap
    left = wrap(other)
    difference = Node.Sub(x=left.node, y=self.node)
    return wrap(Node(sub_op=difference))