Ejemplo n.º 1
0
def test_graph_plotter_retrieve_mutate_op():
    """Reading one attribute and assigning it to another yields a one-in/one-out Graph."""
    entity_id = 1
    g = Plotter(
        entity_defs=[EntityDef(components=[Position], entity_id=1)],
        component_defs=[Position],
    )
    person = g.entity(components=[Position])
    pos_x = person[Position].x
    person[Position].y = pos_x

    actual = g.graph()
    x_ref = AttributeRef(entity_id=entity_id, component=Position.name, attribute="x")
    y_ref = AttributeRef(entity_id=entity_id, component=Position.name, attribute="y")
    # the retrieve of x must be the graph input; the mutate of y the graph output,
    # and that mutation must point at the node produced by reading x
    expected = Graph(
        inputs=[Node.Retrieve(retrieve_attr=x_ref)],
        outputs=[Node.Mutate(mutate_attr=y_ref, to_node=pos_x.node)],
    )
    assert actual == expected
Ejemplo n.º 2
0
def wrap_const(val: Any):
    """Return ``val`` as a constant graph Node.

    Args:
        val: Native value to wrap. A Node whose ``op`` oneof is already set
            to ``const_op`` is passed through untouched.
    Returns:
        ``val`` itself if it is already a constant Node, otherwise a new Node
        whose ``const_op`` holds ``wrap(val)``.
    """
    already_const = isinstance(val, Node) and val.WhichOneof("op") == "const_op"
    if not already_const:
        return Node(const_op=Node.Const(held_value=wrap(val)))
    return val
Ejemplo n.º 3
0
    def random(self, low: Any, high: Any) -> GraphNode:
        """Create a Random Node evaluating to a random float in [`low`, `high`].

        Args:
            low: Expression for the (inclusive) lower bound of the generated number.
            high: Expression for the (inclusive) upper bound of the generated number.

        Returns:
            GraphNode wrapping the Random Node built from `low` and `high`.
        """
        low_node = GraphNode.wrap(low).node
        high_node = GraphNode.wrap(high).node
        random_op = Node.Random(low=low_node, high=high_node)
        return GraphNode(node=Node(random_op=random_op))
Ejemplo n.º 4
0
def test_graph_plotter_preserve_code_order():
    """Graph inputs/outputs follow the order attributes are first read / last written."""
    g = Plotter(
        entity_defs=[EntityDef(components=[Position, Velocity], entity_id=1)],
        component_defs=[Position, Velocity],
    )
    car = g.entity(components=[Position, Velocity])

    # reads: Velocity.x is read before Position.x, so it must be the first input
    car_velocity_x = car[Velocity].x
    car_pos_x = car[Position].x
    # writes: Velocity.x is assigned after Position.x, so it must be the last output
    car[Position].x = 3
    car[Velocity].x = 2

    # read the attributes back to obtain the nodes their assignments point at
    position_x = car[Position].x
    velocity_x = car[Velocity].x

    actual_yaml = g.graph().yaml

    def x_ref(component):
        # AttributeRef for the "x" attribute of `component` on the car entity
        return AttributeRef(entity_id=car.id, component=component.name, attribute="x")

    expected = Graph(
        inputs=[
            Node.Retrieve(retrieve_attr=x_ref(Velocity)),
            Node.Retrieve(retrieve_attr=x_ref(Position)),
        ],
        outputs=[
            Node.Mutate(mutate_attr=x_ref(Position), to_node=position_x.node),
            Node.Mutate(mutate_attr=x_ref(Velocity), to_node=velocity_x.node),
        ],
    )
    assert actual_yaml == expected.yaml
Ejemplo n.º 5
0
def test_graph_ecs_entity_update_input_outputs():
    """use_input_outputs() propagates the inputs/outputs dicts to every component."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    car = GraphEntity.from_def(
        entity_def=EntityDef(components=[Position, Speed], entity_id=entity_id),
        component_defs=[Position, Speed],
    )
    car.use_input_outputs(inputs, outputs)

    # get/set on each component should record retrieve/mutate nodes in the
    # shared inputs/outputs dicts (the reads are kept for that side effect)
    car_pos_x = car[Position].x
    car[Position].x = 1
    car_speed_x = car[Speed].x
    car[Speed].x = 2

    pos_attr_ref = AttributeRef(entity_id=entity_id,
                                component=Position.name,
                                attribute="x")
    speed_attr_ref = AttributeRef(entity_id=entity_id,
                                  component=Speed.name,
                                  attribute="x")

    # inputs/outputs hold GraphNode wrappers. GraphNode overloads == to build an
    # equality node rather than return a bool, so compare the wrapped .node protos.
    pos_expected_input = Node(retrieve_op=Node.Retrieve(
        retrieve_attr=pos_attr_ref))
    pos_expected_output = Node(
        mutate_op=Node.Mutate(mutate_attr=pos_attr_ref, to_node=wrap_const(1)))
    assert inputs[to_str_attr(pos_attr_ref)].node == pos_expected_input
    assert outputs[to_str_attr(pos_attr_ref)].node == pos_expected_output

    speed_expected_input = Node(retrieve_op=Node.Retrieve(
        retrieve_attr=speed_attr_ref))
    speed_expected_output = Node(mutate_op=Node.Mutate(
        mutate_attr=speed_attr_ref, to_node=wrap_const(2)))
    assert inputs[to_str_attr(speed_attr_ref)].node == speed_expected_input
    assert outputs[to_str_attr(speed_attr_ref)].node == speed_expected_output
Ejemplo n.º 6
0
 def from_attr(cls, entity_id: int, component: str, name: str):
     """Create a GraphNode that retrieves the specified ECS attribute.

     Args:
         entity_id: ID of the entity that owns the attribute.
         component: Name of the component that holds the attribute.
         name: Name of the attribute to retrieve.
     Returns:
         GraphNode wrapping a Node whose ``retrieve_op`` references the attribute.
     """
     # A bare Node.Retrieve message is not a Node, so wrap it in
     # Node(retrieve_op=...) before handing it to GraphNode.wrap(); otherwise
     # the Retrieve message is presumably treated as a native value to wrap.
     return GraphNode.wrap(
         Node(retrieve_op=Node.Retrieve(retrieve_attr=AttributeRef(
             entity_id=entity_id,
             component=component,
             attribute=name,
         ))))
Ejemplo n.º 7
0
def test_graph_ecs_component_set_attr_native_value():
    """Assigning a native value records a Mutate node pointing at a Const node."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # check that setting an attribute to a native value records the expected output node
    position.y = 3

    attr_ref = AttributeRef(
        entity_id=entity_id,
        component=Position.name,
        attribute="y",
    )
    expected_node = Node(mutate_op=Node.Mutate(
        mutate_attr=attr_ref,
        to_node=wrap_const(3),
    ))
    # outputs holds GraphNode wrappers whose overloaded == builds a graph node
    # instead of a bool, so compare the wrapped proto instead
    assert outputs[to_str_attr(attr_ref)].node == expected_node
Ejemplo n.º 8
0
    def get_attr(self, name: str) -> GraphNode:
        """Return a GraphNode for this component's attribute ``name``.

        If the attribute already has an output recorded in ``self._outputs``
        (from an earlier set_attr()), the node it was assigned is returned so
        the graph already built for it is preserved. Otherwise a Retrieve node
        is recorded in ``self._inputs`` under the attribute's key and returned.
        """
        attr_ref = AttributeRef(
            entity_id=self._entity_id,
            component=self._name,
            attribute=str(name),
        )
        attr_key = to_str_attr(attr_ref)
        if attr_key in self._outputs:
            assigned = self._outputs[attr_key].node.mutate_op.to_node
            return GraphNode(assigned)

        retrieve = Node(retrieve_op=Node.Retrieve(retrieve_attr=attr_ref))
        get_op = GraphNode(node=retrieve)
        self._inputs[attr_key] = get_op
        return get_op
Ejemplo n.º 9
0
def test_graph_ecs_component_get_attr():
    """Reading a component attribute returns and records a Retrieve node."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # check that getting an attribute from a component returns a GraphNode
    # wrapping a Retrieve node that retrieves the attribute
    pos_x = position.x

    attr_ref = AttributeRef(entity_id=entity_id,
                            component=Position.name,
                            attribute="x")
    expected_node = Node(retrieve_op=Node.Retrieve(retrieve_attr=attr_ref))
    assert pos_x.node == expected_node
    # check that the component records the retrieve in inputs
    assert inputs[to_str_attr(attr_ref)].node == expected_node
    # check that retrieving the same attribute again yields the same Retrieve
    # node and only records it once
    pos_x_again = position.x
    assert pos_x_again.node == expected_node
    assert len(inputs) == 1
Ejemplo n.º 10
0
def test_graph_plotter_conditional_boolean():
    """A nested switch on a keyboard attribute records one input and one output."""
    g = Plotter(
        entity_defs=[
            EntityDef(components=[Position], entity_id=1),
            EntityDef(components=[Keyboard], entity_id=2),
        ],
        component_defs=[Position, Keyboard],
    )
    env = g.entity(components=[Keyboard])
    car = g.entity(components=[Position])

    key_pressed = env[Keyboard].pressed
    # build the conditions in the same order the nested call evaluated them
    is_left = key_pressed == "left"
    is_right = key_pressed == "right"
    right_or_idle = g.switch(condition=is_right, true=1.0, false=0.0)
    car_pos_x = g.switch(condition=is_left, true=-1.0, false=right_or_idle)
    car[Position].x = car_pos_x

    actual = g.graph()
    pressed_ref = AttributeRef(
        entity_id=env.id,
        component=Keyboard.name,
        attribute="pressed",
    )
    pos_x_ref = AttributeRef(
        entity_id=car.id,
        component=Position.name,
        attribute="x",
    )
    # the mutation node must record the switch expression as its assignment
    expected = Graph(
        inputs=[Node.Retrieve(retrieve_attr=pressed_ref)],
        outputs=[Node.Mutate(mutate_attr=pos_x_ref, to_node=car_pos_x.node)],
    )
    assert actual == expected
Ejemplo n.º 11
0
 def set_attr(self, name: str, value: Any):
     """Record ``<component>.<name> = value`` as a Mutate output node.

     Self assignments (``component.attr = component.attr``, i.e. a Retrieve of
     this same attribute) are ignored. Otherwise the Mutate node replaces any
     earlier entry for the attribute in ``self._outputs``. The entry is then
     moved to the end so ``_outputs`` reflects the order of the last writes.
     """
     value = GraphNode.wrap(value)
     attr_ref = AttributeRef(
         entity_id=self._entity_id,
         component=self._name,
         attribute=name,
     )
     value_node = value.node
     is_self_assign = (
         value_node.WhichOneof("op") == "retrieve_op"
         and value_node.retrieve_op.retrieve_attr.SerializeToString()
         == attr_ref.SerializeToString()
     )
     if is_self_assign:
         return
     attr_key = to_str_attr(attr_ref)
     mutate = Node.Mutate(mutate_attr=attr_ref, to_node=value_node)
     self._outputs[attr_key] = GraphNode(node=Node(mutate_op=mutate))
     # re-assigning an existing key keeps its old position, hence the move
     self._outputs.move_to_end(attr_key)
Ejemplo n.º 12
0
def test_graph_ecs_component_aug_assign_node():
    """Augmented assignment records the attribute as both input and output."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # check augmented assignment flags the attribute (position.y) as both input and output
    position.y += 30

    attr_ref = AttributeRef(
        entity_id=entity_id,
        component=Position.name,
        attribute="y",
    )
    expected_input = Node(retrieve_op=Node.Retrieve(retrieve_attr=attr_ref))
    expected_output = Node(mutate_op=Node.Mutate(
        mutate_attr=attr_ref,
        to_node=Node(add_op=Node.Add(
            x=expected_input,
            y=wrap_const(30),
        )),
    ))
    # inputs/outputs hold GraphNode wrappers whose overloaded == builds a graph
    # node instead of a bool, so compare the wrapped protos
    assert len(inputs) == 1
    assert inputs[to_str_attr(attr_ref)].node == expected_input
    assert len(outputs) == 1
    assert outputs[to_str_attr(attr_ref)].node == expected_output
Ejemplo n.º 13
0
def test_graph_ecs_component_set_attr_node():
    """Assigning a GraphNode records a Mutate pointing at that GraphNode's node."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    pos_x = position.x
    # assign y twice: only the final assignment (y = x) should be kept
    position.y = 10
    position.y = pos_x

    y_ref = AttributeRef(entity_id=entity_id, component=Position.name, attribute="y")
    recorded = outputs[to_str_attr(y_ref)].node
    assert recorded == Node(mutate_op=Node.Mutate(mutate_attr=y_ref, to_node=pos_x.node))
    # the overwritten first definition must not leave a second output entry
    assert len(outputs) == 1
Ejemplo n.º 14
0
    def switch(self, condition: Any, true: Any, false: Any) -> GraphNode:
        """Create a conditional Switch Node that selects a branch by ``condition``.

        The switch evaluates to ``true`` if ``condition`` is true and to
        ``false`` otherwise.

        Args:
            condition: Expression defining the condition; should evaluate to true or false.
            true: Expression the Switch Node evaluates to when ``condition`` is true.
            false: Expression the Switch Node evaluates to when ``condition`` is false.
        Returns:
            GraphNode wrapping the Switch Node.
        """
        cond_node = GraphNode.wrap(condition).node
        true_node = GraphNode.wrap(true).node
        false_node = GraphNode.wrap(false).node
        switch_op = Node.Switch(
            condition_node=cond_node,
            true_node=true_node,
            false_node=false_node,
        )
        return GraphNode.wrap(Node(switch_op=switch_op))
Ejemplo n.º 15
0
 def cos(self, x: Any) -> GraphNode:
     """Create a Cos Node whose operand is ``x`` (wrapped via GraphNode.wrap())."""
     operand = GraphNode.wrap(x).node
     return GraphNode(node=Node(cos_op=Node.Cos(x=operand)))
Ejemplo n.º 16
0
 def mod(self, x: Any, y: Any) -> GraphNode:
     """Create a Mod Node over ``x`` and ``y`` (both wrapped via GraphNode.wrap())."""
     mod_op = Node.Mod(x=GraphNode.wrap(x).node, y=GraphNode.wrap(y).node)
     return GraphNode(node=Node(mod_op=mod_op))
Ejemplo n.º 17
0
 def sin(self, x: Any) -> GraphNode:
     """Create a Sin Node whose operand is ``x`` (wrapped via GraphNode.wrap())."""
     operand = GraphNode.wrap(x).node
     return GraphNode(node=Node(sin_op=Node.Sin(x=operand)))
Ejemplo n.º 18
0
 def ceil(self, x: Any) -> GraphNode:
     """Create a Ceil Node whose operand is ``x`` (wrapped via GraphNode.wrap())."""
     operand = GraphNode.wrap(x).node
     return GraphNode(node=Node(ceil_op=Node.Ceil(x=operand)))
Ejemplo n.º 19
0
 def pow(self, x: Any, y: Any) -> GraphNode:
     """Create a Pow Node over ``x`` and ``y`` (both wrapped via GraphNode.wrap())."""
     pow_op = Node.Pow(x=GraphNode.wrap(x).node, y=GraphNode.wrap(y).node)
     return GraphNode(node=Node(pow_op=pow_op))
Ejemplo n.º 20
0
 def abs(self, x: Any) -> GraphNode:
     """Create an Abs Node whose operand is ``x``.

     Args:
         x: Expression used as the operand; wrapped via GraphNode.wrap().
     Returns:
         GraphNode wrapping the Abs Node.
     """
     x = GraphNode.wrap(x)
     # Node.Abs takes the wrapped proto Node (x.node), like every other unary
     # op; protobuf rejects the GraphNode wrapper itself as a field value.
     return GraphNode(node=Node(abs_op=Node.Abs(x=x.node)))
Ejemplo n.º 21
0
 def floor(self, x: Any) -> GraphNode:
     """Create a Floor Node whose operand is ``x`` (wrapped via GraphNode.wrap())."""
     operand = GraphNode.wrap(x).node
     return GraphNode(node=Node(floor_op=Node.Floor(x=operand)))
Ejemplo n.º 22
0
 def min(self, x: Any, y: Any) -> GraphNode:
     """Create a Min Node over ``x`` and ``y`` (both wrapped via GraphNode.wrap())."""
     min_op = Node.Min(x=GraphNode.wrap(x).node, y=GraphNode.wrap(y).node)
     return GraphNode(node=Node(min_op=min_op))
Ejemplo n.º 23
0
 def __rmul__(self, other: Any):
     """Reflected ``other * self``: a Mul node with ``other`` as x and ``self`` as y."""
     node_type = type(self)
     lhs = node_type.wrap(other)
     return node_type.wrap(Node(mul_op=Node.Mul(x=lhs.node, y=self.node)))
Ejemplo n.º 24
0
 def arctan(self, x: Any) -> GraphNode:
     """Create an ArcTan Node whose operand is ``x`` (wrapped via GraphNode.wrap())."""
     operand = GraphNode.wrap(x).node
     return GraphNode(node=Node(arctan_op=Node.ArcTan(x=operand)))
Ejemplo n.º 25
0
 def arccos(self, x: Any) -> GraphNode:
     """Create an ArcCos Node whose operand is ``x`` (wrapped via GraphNode.wrap())."""
     operand = GraphNode.wrap(x).node
     return GraphNode(node=Node(arccos_op=Node.ArcCos(x=operand)))
Ejemplo n.º 26
0
 def __radd__(self, other: Any):
     """Reflected ``other + self``: an Add node with ``other`` as x and ``self`` as y."""
     node_type = type(self)
     lhs = node_type.wrap(other)
     return node_type.wrap(Node(add_op=Node.Add(x=lhs.node, y=self.node)))
Ejemplo n.º 27
0
 def tan(self, x: Any) -> GraphNode:
     """Create a Tan Node whose operand is ``x`` (wrapped via GraphNode.wrap())."""
     operand = GraphNode.wrap(x).node
     return GraphNode(node=Node(tan_op=Node.Tan(x=operand)))
Ejemplo n.º 28
0
 def max(self, x: Any, y: Any) -> GraphNode:
     """Create a Max Node over ``x`` and ``y`` (both wrapped via GraphNode.wrap())."""
     max_op = Node.Max(x=GraphNode.wrap(x).node, y=GraphNode.wrap(y).node)
     return GraphNode(node=Node(max_op=max_op))
Ejemplo n.º 29
0
 def __ge__(self, other: Any):
     """``self >= other`` built as an Or node of ``self > other`` and ``self == other``."""
     node_type = type(self)
     other_node = node_type.wrap(other).node
     greater = self.__gt__(other_node).node
     equal = self.__eq__(other_node).node
     return node_type.wrap(Node(or_op=Node.Or(x=greater, y=equal)))
Ejemplo n.º 30
0
 def __rsub__(self, other: Any):
     """Reflected ``other - self``: a Sub node with ``other`` as x and ``self`` as y."""
     node_type = type(self)
     lhs = node_type.wrap(other)
     return node_type.wrap(Node(sub_op=Node.Sub(x=lhs.node, y=self.node)))