# Example 1
def test_graph_ecs_entity_update_input_outputs():
    """Check use_input_outputs() propagates the input and output dicts to
    the entity's components.

    Attribute reads made through the entity should be recorded as Retrieve
    nodes in ``inputs`` and attribute writes as Mutate nodes in ``outputs``.
    """
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    car = GraphEntity.from_def(
        entity_def=EntityDef(components=[Position, Speed],
                             entity_id=entity_id),
        component_defs=[Position, Speed],
    )
    car.use_input_outputs(inputs, outputs)

    # get/set should propagate retrieve/mutate to inputs and outputs.
    # The reads are performed only for their side effect of recording the
    # retrieve in inputs.
    car_pos_x = car[Position].x
    car[Position].x = 1
    car_speed_x = car[Speed].x
    # mutate Speed.x (not Position.x) so the Speed output below is recorded
    car[Speed].x = 2

    pos_attr_ref = AttributeRef(entity_id=entity_id,
                                component=Position.name,
                                attribute="x")
    # must reference the Speed component, not Position
    speed_attr_ref = AttributeRef(entity_id=entity_id,
                                  component=Speed.name,
                                  attribute="x")

    pos_expected_input = Node(retrieve_op=Node.Retrieve(
        retrieve_attr=pos_attr_ref))
    pos_expected_output = Node(
        mutate_op=Node.Mutate(mutate_attr=pos_attr_ref, to_node=wrap_const(1)))
    assert inputs[to_str_attr(pos_attr_ref)] == pos_expected_input
    # outputs record the mutation, not the retrieve
    assert outputs[to_str_attr(pos_attr_ref)] == pos_expected_output

    speed_expected_input = Node(retrieve_op=Node.Retrieve(
        retrieve_attr=speed_attr_ref))
    speed_expected_output = Node(mutate_op=Node.Mutate(
        mutate_attr=speed_attr_ref, to_node=wrap_const(2)))
    assert inputs[to_str_attr(speed_attr_ref)] == speed_expected_input
    assert outputs[to_str_attr(speed_attr_ref)] == speed_expected_output
# Example 2
def test_graph_plotter_retrieve_mutate_op():
    """Copying Position.x into Position.y should make the retrieve of x the
    graph's input and the mutate of y (to the retrieved node) its output."""
    entity_id = 1
    plotter = Plotter(
        entity_defs=[EntityDef(components=[Position], entity_id=1)],
        component_defs=[Position],
    )
    person = plotter.entity(components=[Position])

    pos_x = person[Position].x
    person[Position].y = pos_x

    actual = plotter.graph()

    x_ref = AttributeRef(entity_id=entity_id,
                         component=Position.name,
                         attribute="x")
    y_ref = AttributeRef(entity_id=entity_id,
                         component=Position.name,
                         attribute="y")
    expected = Graph(
        inputs=[Node.Retrieve(retrieve_attr=x_ref)],
        # the mutation must record the node that was assigned to y
        outputs=[Node.Mutate(mutate_attr=y_ref, to_node=pos_x.node)],
    )
    assert actual == expected
# Example 3
def test_graph_plotter_preserve_code_order():
    """Graph inputs and outputs should follow the order in which attributes
    are first read and last written in the code."""
    plotter = Plotter(
        entity_defs=[EntityDef(components=[Position, Velocity], entity_id=1)],
        component_defs=[Position, Velocity],
    )
    car = plotter.entity(components=[Position, Velocity])

    # reads: Velocity.x is retrieved before Position.x, so it should come
    # first in the graph inputs
    car_velocity_x = car[Velocity].x
    car_pos_x = car[Position].x
    # writes: Velocity.x is mutated after Position.x, so it should come
    # last in the graph outputs
    car[Position].x = 3
    car[Velocity].x = 2

    # nodes expected as the targets of the mutations below
    position_x = car[Position].x
    velocity_x = car[Velocity].x

    actual_yaml = plotter.graph().yaml

    def x_of(component):
        # reference to attribute "x" of the given component on the car
        return AttributeRef(entity_id=car.id,
                            component=component.name,
                            attribute="x")

    expected = Graph(
        inputs=[
            Node.Retrieve(retrieve_attr=x_of(Velocity)),
            Node.Retrieve(retrieve_attr=x_of(Position)),
        ],
        outputs=[
            Node.Mutate(mutate_attr=x_of(Position), to_node=position_x.node),
            Node.Mutate(mutate_attr=x_of(Velocity), to_node=velocity_x.node),
        ],
    )
    assert actual_yaml == expected.yaml
# Example 4
def test_graph_ecs_component_aug_assign_node():
    """Check that an augmented assignment records the attribute both as a
    retrieve (input) and as a mutate to an Add node (output)."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # check augmented assignment flags the attribute (position.y) as both
    # input (it is read) and output (the sum is written back)
    position.y += 30

    attr_ref = AttributeRef(
        entity_id=entity_id,
        component=Position.name,
        attribute="y",
    )
    expected_input = Node(retrieve_op=Node.Retrieve(retrieve_attr=attr_ref))
    expected_output = Node(mutate_op=Node.Mutate(
        mutate_attr=attr_ref,
        to_node=Node(add_op=Node.Add(
            x=expected_input,
            y=wrap_const(30),
        )),
    ))
    assert len(inputs) == 1
    assert inputs[to_str_attr(attr_ref)] == expected_input
    assert len(outputs) == 1
    assert outputs[to_str_attr(attr_ref)] == expected_output
# Example 5
def test_client_set_attr(client, sim_def, attr_ref, attr_val):
    """Check client.set_attr() sets a valid attribute and rejects bad input.

    Args:
        client: client under test (presumably a pytest fixture).
        sim_def: simulation definition whose ``name`` identifies the sim.
        attr_ref: AttributeRef of an existing attribute.
        attr_val: valid value to set the attribute to.
    """
    # setting an existing attribute to a valid value should succeed
    client.set_attr(sim_def.name, attr_ref, attr_val)

    # test not found error handling: unknown attribute -> LookupError
    has_not_found_error = False
    try:
        client.set_attr(
            sim_name=sim_def.name,
            attr_ref=AttributeRef(
                entity_id=1,
                component="not",
                attribute="found",
            ),
            value=attr_val,
        )
    except LookupError:
        has_not_found_error = True
    assert has_not_found_error

    # test invalid value error handling: an empty Value -> ValueError.
    # The flag must start False, otherwise this check can never fail.
    has_invalid_error = False
    invalid_val = Value()
    try:
        client.set_attr(
            sim_name=sim_def.name,
            attr_ref=attr_ref,
            value=invalid_val,
        )
    except ValueError:
        has_invalid_error = True
    assert has_invalid_error
# Example 6
 def check_client_set_attr(x):
     """Assert the mocked client's last set_attr call wrote wrap(x) to
     attribute "y" of component "position" on entity 1."""
     expected_ref = AttributeRef(entity_id=1,
                                 component="position",
                                 attribute="y")
     expected_value = wrap(x)
     mock_client.set_attr.assert_called_with(
         sim_name=sim_name,
         attr_ref=expected_ref,
         value=expected_value,
     )
# Example 7
 def from_attr(cls, entity_id: int, component: str, name: str):
     """Create a GraphNode that retrieves the specified ECS attribute.

     Args:
         entity_id: ID of the entity owning the attribute.
         component: name of the component holding the attribute.
         name: name of the attribute to retrieve.
     Returns:
         GraphNode wrapping a Node whose op is a retrieve of the attribute.
     """
     attr_ref = AttributeRef(
         entity_id=entity_id,
         component=component,
         attribute=name,
     )
     # wrap the Retrieve op in a full Node (retrieve_op) rather than passing
     # the bare Node.Retrieve message: GraphNode.node is expected to be a
     # Node whose "op" oneof can be inspected.
     return GraphNode(node=Node(retrieve_op=Node.Retrieve(
         retrieve_attr=attr_ref)))
# Example 8
def test_graph_plotter_conditional_boolean():
    """A nested switch on the pressed key should make Keyboard.pressed the
    only graph input and the mutate of Position.x the only output."""
    plotter = Plotter(
        entity_defs=[
            EntityDef(components=[Position], entity_id=1),
            EntityDef(components=[Keyboard], entity_id=2),
        ],
        component_defs=[Position, Keyboard],
    )
    env = plotter.entity(components=[Keyboard])
    car = plotter.entity(components=[Position])

    # x becomes -1.0 on "left", 1.0 on "right" and 0.0 otherwise
    key_pressed = env[Keyboard].pressed
    is_left = key_pressed == "left"
    is_right = key_pressed == "right"
    not_left = plotter.switch(condition=is_right, true=1.0, false=0.0)
    car_pos_x = plotter.switch(condition=is_left, true=-1.0, false=not_left)
    car[Position].x = car_pos_x

    actual = plotter.graph()

    pressed_ref = AttributeRef(entity_id=env.id,
                               component=Keyboard.name,
                               attribute="pressed")
    pos_x_ref = AttributeRef(entity_id=car.id,
                             component=Position.name,
                             attribute="x")
    expected = Graph(
        inputs=[Node.Retrieve(retrieve_attr=pressed_ref)],
        # the mutation must record the switch node assigned to x
        outputs=[Node.Mutate(mutate_attr=pos_x_ref, to_node=car_pos_x.node)],
    )
    assert actual == expected
# Example 9
 def set_attr(self, name: str, value: Any):
     """Set attribute *name* of component ``self._name`` on entity
     ``self._entity_id`` to *value* on the engine via the client."""
     target = AttributeRef(
         entity_id=self._entity_id,
         component=self._name,
         attribute=name,
     )
     wrapped = wrap(value)
     self._client.set_attr(
         sim_name=self._sim_name,
         attr_ref=target,
         value=wrapped,
     )
# Example 10
 def set_attr(self, name: str, value: Any):
     """Record assigning *value* to attribute *name* as a graph output.

     The value is wrapped as a GraphNode and recorded as a Mutate node in
     ``self._outputs``, keyed by the attribute's string form. A later
     assignment to the same attribute replaces the earlier record and moves
     it to the end, so outputs follow execution order. Assigning an
     attribute to a plain retrieve of itself is ignored.
     """
     wrapped = GraphNode.wrap(value)
     attr_ref = AttributeRef(
         entity_id=self._entity_id,
         component=self._name,
         attribute=name,
     )
     new_node = wrapped.node
     # ignore attribute self assignments (ie component.attr = component.attr)
     is_retrieve = new_node.WhichOneof("op") == "retrieve_op"
     if is_retrieve:
         source = new_node.retrieve_op.retrieve_attr.SerializeToString()
         if source == attr_ref.SerializeToString():
             return

     key = to_str_attr(attr_ref)
     mutate = Node.Mutate(mutate_attr=attr_ref, to_node=new_node)
     self._outputs[key] = GraphNode(node=Node(mutate_op=mutate))
     # re-assigning an existing key keeps its old position; move it to the
     # end so the outputs reflect the order of execution
     self._outputs.move_to_end(key)
# Example 11
 def get_attr(self, name: str) -> Any:
     """Fetch attribute *name* of component ``self._name`` on entity
     ``self._entity_id`` from the engine and return it passed through
     unwrap()."""
     target = AttributeRef(
         entity_id=self._entity_id,
         component=self._name,
         attribute=name,
     )
     raw = self._client.get_attr(sim_name=self._sim_name, attr_ref=target)
     return unwrap(raw)
# Example 12
def test_build_proto():
    """An AttributeRef should survive a serialize/parse round trip."""
    original = AttributeRef(entity_id=24, component="Sprite2D")
    payload = original.SerializeToString()

    restored = AttributeRef()
    restored.ParseFromString(payload)
    assert original == restored
# Example 13
    def get_attr(self, name: str) -> GraphNode:
        """Return a GraphNode for attribute *name* of this component.

        If the attribute was already assigned (it has an entry in
        ``self._outputs``), return the node it was assigned to, so the graph
        already built for it is reused. Otherwise record a Retrieve node for
        the attribute in ``self._inputs`` and return it.
        """
        attr_ref = AttributeRef(
            entity_id=self._entity_id,
            component=self._name,
            attribute=str(name),
        )
        key = to_str_attr(attr_ref)
        if key not in self._outputs:
            # record the attribute retrieve operation as an input graph node
            retrieve = GraphNode(node=Node(retrieve_op=Node.Retrieve(
                retrieve_attr=attr_ref)))
            self._inputs[key] = retrieve
            return retrieve

        # defined by an earlier set_attr(): reuse the node it was set to
        assigned = self._outputs[key].node.mutate_op.to_node
        return GraphNode(assigned)
# Example 14
def test_graph_ecs_component_set_attr_native_value():
    """Assigning a native value should record a Mutate output whose target
    node is that value wrapped as a constant."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    position.y = 3

    y_ref = AttributeRef(entity_id=entity_id,
                         component=Position.name,
                         attribute="y")
    mutate = Node.Mutate(mutate_attr=y_ref, to_node=wrap_const(3))
    expected_node = Node(mutate_op=mutate)
    assert outputs[to_str_attr(y_ref)] == expected_node
# Example 15
def test_client_get_attr(client, sim_def, attr_ref, attr_val):
    """client.get_attr() should return the stored value of an existing
    attribute and raise LookupError for an unknown one."""
    assert client.get_attr(sim_def.name, attr_ref) == attr_val

    # test not found error handling: unknown attribute -> LookupError
    missing_ref = AttributeRef(
        entity_id=1,
        component="not",
        attribute="found",
    )
    try:
        client.get_attr(sim_name=sim_def.name, attr_ref=missing_ref)
    except LookupError:
        raised_not_found = True
    else:
        raised_not_found = False
    assert raised_not_found
# Example 16
def test_graph_ecs_component_get_attr():
    """Check that reading an attribute returns and records a Retrieve node,
    and that repeated reads of the same attribute are recorded only once."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # check that getting an attribute from a component returns a GraphNode
    # wrapping a Retrieve node that retrieves the attribute
    pos_x = position.x

    attr_ref = AttributeRef(entity_id=entity_id,
                            component=Position.name,
                            attribute="x")
    expected_node = Node(retrieve_op=Node.Retrieve(retrieve_attr=attr_ref))
    assert pos_x.node == expected_node
    # check that the component records the retrieve in its inputs
    assert inputs[to_str_attr(attr_ref)].node == expected_node
    # check that retrieving the same attribute (x) again yields the same
    # retrieve but only records it once
    pos_x_again = position.x
    assert pos_x_again.node == expected_node
    assert len(inputs) == 1
# Example 17
def test_graph_ecs_component_set_attr_node():
    """Assigning an attribute to a node should record a Mutate to that node,
    and only the last of several assignments should be kept."""
    entity_id, inputs, outputs = 1, OrderedDict(), OrderedDict()
    position = GraphComponent.from_def(entity_id, Position)
    position.use_input_outputs(inputs, outputs)

    # assign y twice: first to a constant, then to the node read from x
    pos_x = position.x
    position.y = 10
    position.y = pos_x

    y_ref = AttributeRef(entity_id=entity_id,
                         component=Position.name,
                         attribute="y")
    mutate = Node.Mutate(mutate_attr=y_ref, to_node=pos_x.node)
    expected_node = Node(mutate_op=mutate)
    assert outputs[to_str_attr(y_ref)].node == expected_node
    # the first definition (y = 10) is replaced by the redefinition, so a
    # single output remains
    assert len(outputs) == 1
# Example 18
def attr_ref():
    """Return a reference to attribute "x" of component "position" on
    entity 1."""
    ref = AttributeRef(
        entity_id=1,
        component="position",
        attribute="x",
    )
    return ref