Ejemplo n.º 1
0
def sim_fn(g: Plotter):
    # Single simulation step of a mountain-car style environment, written
    # against the Plotter `g` (entity access, g.cos, g.clip) — presumably as a
    # convert_fn for the graph compiler. Only `#` comments are used here since
    # such a body is parsed as an AST, where a docstring would add a statement.
    env = g.entity(components=[Action, State])
    car = g.entity(components=[Velocity, Position])

    # process car physics
    # compute velocity based on acceleration action & deceleration due to gravity
    acceleration, gravity, max_speed = 0.001, 0.0025, 0.07
    # apply acceleration based on accelerate action:
    # 0: Accelerate to the Left
    # 1: Don't accelerate
    # 2: Accelerate to the Right
    car[Velocity].x += (env[Action].accelerate - 1) * acceleration
    # apply gravity inverse to the mountain path used by the car
    # the mountain is defined by y = sin(3*x)
    # as such we apply gravity inversely using y = cos(3*x)
    # apply negative gravity as gravity works in the opposite direction of movement
    car[Velocity].x += g.cos(3 * car[Position].x) * (-gravity)
    car[Velocity].x = g.clip(car[Velocity].x,
                             min_x=-max_speed,
                             max_x=max_speed)

    # compute new position from current velocity
    min_position, max_position = -1.2, 0.6
    car[Position].x += car[Velocity].x
    car[Position].x = g.clip(car[Position].x, min_position, max_position)

    # collision: stop car when colliding with min_position
    if car[Position].x <= min_position:
        car[Velocity].x = 0.0

    # resolve simulation state: reward and simulation completion
    # reward and ended share one goal threshold (>=) so the step that reaches
    # the goal is both rewarded and terminal; previously `ended` used `>`,
    # leaving x == goal_position rewarded but not ended.
    goal_position = 0.5
    env[State].reward = 0 if car[Position].x >= goal_position else -1
    env[State].ended = True if car[Position].x >= goal_position else False
Ejemplo n.º 2
0
def test_graph_plotter_retrieve_mutate_op():
    """Reading an attribute becomes a graph input; writing one becomes a graph output."""
    entity_id = 1
    g = Plotter(
        entity_defs=[EntityDef(components=[Position], entity_id=entity_id)],
        component_defs=[Position],
    )
    person = g.entity(components=[Position])

    # copy Position.x into Position.y
    retrieved_x = person[Position].x
    person[Position].y = retrieved_x

    def position_ref(attribute):
        # reference to `attribute` of the single entity's Position component
        return AttributeRef(
            entity_id=entity_id,
            component=Position.name,
            attribute=attribute,
        )

    actual = g.graph()
    expected = Graph(
        inputs=[Node.Retrieve(retrieve_attr=position_ref("x"))],
        outputs=[
            # the mutation must record the retrieved node as its assigned value
            Node.Mutate(mutate_attr=position_ref("y"), to_node=retrieved_x.node),
        ],
    )
    assert actual == expected
Ejemplo n.º 3
0
 def ternary_fn(g: Plotter):
     # Fixture: sets the car's Position.x to 20 when the env Clock's tick_ms
     # exceeds 2000, otherwise to 10, via a conditional (ternary) expression.
     # Comments only (no docstring): the body is presumably parsed into an AST
     # by the graph compiler, where a docstring would add a statement node.
     car = g.entity(
         components=[
             Position,
         ]
     )
     env = g.entity(components=[Clock])
     x_delta = 20 if env[Clock].tick_ms > 2000 else 10
     car[Position].x = x_delta
Ejemplo n.º 4
0
def init_fn(g: Plotter):
    # Initial state: the car starts with zero velocity at an x position drawn
    # from g.random(-0.6, -0.4); the env starts with reward 0, not ended, and
    # accelerate = 1 (presumably the "don't accelerate" action — confirm
    # against the matching sim step).
    # Statement order is kept as-is since graph outputs follow code order.
    car = g.entity(components=[Velocity, Position])
    car[Velocity].x = 0.0
    car[Position].x = g.random(-0.6, -0.4)

    env = g.entity(components=[Action, State])
    env[State].reward = 0
    env[State].ended = False
    env[Action].accelerate = 1
Ejemplo n.º 5
0
    def physics_sys(g: Plotter):
        # Sets the car's Velocity from its Movement speed along the heading
        # given by Movement.rotation, then advances Position by that Velocity.
        # NOTE(review): Meta is requested but not read here — presumably the
        # component list must match the car entity's definition; confirm.
        # compute velocity from car's rotation and speed
        car = g.entity(components=[Movement, Velocity, Position, Meta])
        # rotation
        # heading_y is negated, presumably because screen y grows downward —
        # confirm against the renderer.
        # NOTE(review): rotation is passed to g.cos/g.sin unconverted; confirm
        # the angle unit they expect (degrees vs radians) matches how rotation
        # is stored.
        heading_x, heading_y = g.cos(
            car[Movement].rotation), -g.sin(car[Movement].rotation)
        # speed
        car[Velocity].x = car[Movement].speed * heading_x
        car[Velocity].y = car[Movement].speed * heading_y

        # update car position based on current velocity
        car[Position].x += car[Velocity].x
        car[Position].y += car[Velocity].y
Ejemplo n.º 6
0
 def arithmetic_multiple_fn(g: Plotter):
     # Fixture chaining several arithmetic operations (unary minus,
     # multiplication, addition) over retrieved attributes and constants,
     # then writing the result to the car's Position.x.
     ms_in_sec = int(1e3)
     env = g.entity(components=[Clock])
     car = g.entity(
         components=[
             Position,
             Velocity,
         ]
     )
     tick_ms = env[Clock].tick_ms
     neg_xps = -car[Velocity].x
     # NOTE(review): tick_ms * ms_in_sec scales up by 1000, whereas converting
     # ms to seconds would divide; likely irrelevant for a fixture that only
     # exercises arithmetic — confirm before relying on the units.
     x_delta = neg_xps * (tick_ms * ms_in_sec)
     car[Position].x = x_delta + car[Position].x
Ejemplo n.º 7
0
def test_graph_plotter_preserve_code_order():
    """Inputs follow first-retrieval order and outputs follow mutation order in code."""
    g = Plotter(
        entity_defs=[EntityDef(components=[Position, Velocity], entity_id=1)],
        component_defs=[Position, Velocity],
    )
    car = g.entity(components=[Position, Velocity])

    def car_ref(component, attribute="x"):
        # reference to `attribute` of the given component on the car entity
        return AttributeRef(
            entity_id=car.id,
            component=component.name,
            attribute=attribute,
        )

    # Velocity.x is read before Position.x, so its Retrieve node must lead the inputs
    _ = car[Velocity].x
    _ = car[Position].x
    # Velocity.x is written after Position.x, so its Mutate node must trail the outputs
    car[Position].x = 3
    car[Velocity].x = 2

    # reading after the writes gives the nodes each mutation is expected to point to
    final_pos_x = car[Position].x
    final_vel_x = car[Velocity].x

    actual_yaml = g.graph().yaml
    expected = Graph(
        inputs=[
            Node.Retrieve(retrieve_attr=car_ref(Velocity)),
            Node.Retrieve(retrieve_attr=car_ref(Position)),
        ],
        outputs=[
            Node.Mutate(mutate_attr=car_ref(Position), to_node=final_pos_x.node),
            Node.Mutate(mutate_attr=car_ref(Velocity), to_node=final_vel_x.node),
        ],
    )
    assert actual_yaml == expected.yaml
Ejemplo n.º 8
0
def test_transform_build_graph():
    """transform_build_graph turns each convert fn into a runnable `build_graph`.

    Covers a plain function, a long-named function and a decorated function.
    """

    def convert_fn(g: Plotter):
        pass

    def convert_fn_with_long_name(g: Plotter):
        pass

    # no-op decorator, used only to check decorated convert fns are handled
    def identity(fn):
        return fn

    @identity
    def convert_fn_with_annotation(g: Plotter):
        pass

    # convert fn test cases
    convert_fns = [
        convert_fn,
        convert_fn_with_long_name,
        convert_fn_with_annotation,
    ]

    req_analyzers = [
        analyze_func,
        analyze_convert_fn,
    ]

    # distinct loop name so the `convert_fn` test case above is not rebound
    for case_fn in convert_fns:
        tree = parse_ast(case_fn)
        for analyzer in req_analyzers:
            tree = analyzer(tree)
        trans_ast = transform_build_graph(tree)

        # try running the transformed function renamed to 'build_graph'
        mod = load_ast_module(trans_ast)
        mod.build_graph(Plotter())
Ejemplo n.º 9
0
    def init_sim(g: Plotter):
        # Initial state: every Keyboard control released, and a car with Meta
        # name "beetle" (id 512, version 2) at rest at (0.0, 0.0) with speed
        # 0.0 and rotation 90.0.
        # Statement order is kept as-is since graph outputs follow code order.
        controls = g.entity(components=[Keyboard])
        controls[Keyboard].left = False
        controls[Keyboard].right = False
        controls[Keyboard].up = False
        controls[Keyboard].down = False

        car = g.entity(components=[Movement, Velocity, Position, Meta])
        car[Meta].name = "beetle"
        car[Meta].id = 512
        car[Meta].version = 2
        car[Movement].speed = 0.0
        car[Movement].rotation = 90.0
        car[Velocity].x = 0.0
        car[Velocity].y = 0.0
        car[Position].x = 0.0
        car[Position].y = 0.0
Ejemplo n.º 10
0
 def ifelse_fn(g: Plotter):
     # Fixture exercising an if/elif/else chain whose branches update the
     # car's Position from its Speed and Velocity depending on the env Clock.
     # NOTE(review): the `elif tick_ms > 5000` branch is unreachable, since
     # any tick_ms > 5000 already satisfies the first `> 2000` test. Left as-is
     # because this looks like a compiler fixture whose expected graph would
     # change if reordered — confirm intent.
     # NOTE(review): the else branch assigns Position.x (`=`) while the other
     # branches accumulate (`+=`); presumably deliberate for coverage.
     car = g.entity(
         components=[
             Position,
             Speed,
             Velocity,
         ]
     )
     env = g.entity(components=[Clock])
     if env[Clock].tick_ms > 2000:
         car[Position].x += g.min(car[Speed].max_x, 2 * car[Velocity].x)
         car[Position].y = car[Position].x + 2
     elif env[Clock].tick_ms > 5000:
         car[Position].x += g.min(car[Speed].max_x, 5 * car[Velocity].x)
         car[Position].y = car[Position].x + 10
     else:
         car[Position].x = g.min(car[Speed].max_x, 1 * car[Velocity].x)
         car[Position].y = car[Position].x - 5
Ejemplo n.º 11
0
    def control_sys(g: Plotter):
        # Applies keyboard input to the car's Movement. Left/right change
        # rotation by steer_rate (left wins if both are set). Up/down change
        # speed by acceleration (up wins if both are set); up is capped at
        # max_speed and down is floored at 0.0. Each key that is handled is
        # reset to False, so the input is consumed.
        controls = g.entity(components=[Keyboard])
        car = g.entity(components=[Movement, Velocity, Position, Meta])
        acceleration, max_speed, steer_rate = 5.0, 18.0, 10.0

        # steer car
        if controls[Keyboard].left:
            car[Movement].rotation -= steer_rate
            controls[Keyboard].left = False
        elif controls[Keyboard].right:
            car[Movement].rotation += steer_rate
            controls[Keyboard].right = False

        # accelerate/slow down car
        if controls[Keyboard].up:
            car[Movement].speed = g.min(car[Movement].speed + acceleration,
                                        max_speed)
            controls[Keyboard].up = False
        elif controls[Keyboard].down:
            car[Movement].speed = g.max(car[Movement].speed - acceleration,
                                        0.0)
            controls[Keyboard].down = False
Ejemplo n.º 12
0
def test_transform_ternary():
    """transform_ternary rewrites a conditional expression into one g.switch() call."""

    def ternary_fn(g: Plotter):
        int_ternary = 1 if True else 2

    tree = parse_ast(ternary_fn)
    for analyze in (analyze_func, analyze_convert_fn):
        tree = analyze(tree)
    trans_ast = transform_build_graph(transform_ternary(tree))

    module = load_ast_module(trans_ast)
    # run against a Plotter wrapped in a Mock so the switch() call is recorded
    spy_g = Mock(wraps=Plotter())
    module.build_graph(spy_g)
    spy_g.switch.assert_called_once_with(condition=True, true=1, false=2)
Ejemplo n.º 13
0
def test_graph_plotter_conditional_boolean():
    """A nested g.switch() on a retrieved key yields one Retrieve input and one Mutate output."""
    g = Plotter(
        entity_defs=[
            EntityDef(components=[Position], entity_id=1),
            EntityDef(components=[Keyboard], entity_id=2),
        ],
        component_defs=[Position, Keyboard],
    )
    env = g.entity(components=[Keyboard])
    car = g.entity(components=[Position])

    pressed_key = env[Keyboard].pressed
    # "left" -> -1.0, "right" -> 1.0, anything else -> 0.0
    x_from_key = g.switch(
        condition=pressed_key == "left",
        true=-1.0,
        false=g.switch(
            condition=pressed_key == "right",
            true=1.0,
            false=0.0,
        ),
    )
    car[Position].x = x_from_key

    actual = g.graph()
    expected_input = Node.Retrieve(retrieve_attr=AttributeRef(
        entity_id=env.id,
        component=Keyboard.name,
        attribute="pressed",
    ))
    expected_output = Node.Mutate(
        mutate_attr=AttributeRef(
            entity_id=car.id,
            component=Position.name,
            attribute="x",
        ),
        # the mutation must record the outer switch node as its assigned value
        to_node=x_from_key.node,
    )
    assert actual == Graph(inputs=[expected_input], outputs=[expected_output])
Ejemplo n.º 14
0
 def arithmetic_fn(g: Plotter):
     # Fixture: adds a constant (20) to the car's Position.x via augmented
     # assignment.
     car = g.entity(components=[Position])
     x_delta = 20
     car[Position].x += x_delta
Ejemplo n.º 15
0
def compile_graph(
    convert_fn: ConvertFn,
    entity_defs: List[EntityDef] = [],
    component_defs: List[ComponentDef] = [],
    preprocessors: List[Transform] = [
        preprocess_augassign,
    ],
    analyzers: List[Analyzer] = [
        analyze_parent,
        analyze_func,
        analyze_convert_fn,
        analyze_symbol,
        analyze_assign,
        resolve_symbol,
        analyze_block,
        analyze_activity,
    ],
    linters: List[Linter] = [],
    transforms: List[Transform] = [
        transform_build_graph,
        transform_ternary,
        transform_ifelse,
    ],
) -> Graph:
    """Compiles the given `convert_fn` into a computation Graph running the given sim.

    Globals can be used read only in the `convert_fn`. Writing to globals is not supported.

    Compilation works as follows:
    - convert the given `convert_fn` function to an AST;
    - apply the given `preprocessors` transforms to preprocess the AST;
    - apply the given `analyzers` to the AST to perform static analysis;
    - lint the AST with the given `linters` to perform static checks;
    - apply the given `transforms` to turn the AST into a function that
      plots the computational graph when run.

    Note:
        Both `preprocessors` and `transforms` are lists of `Transform`s.
        `preprocessors` are applied before any static analysis is done, while
        `transforms` are applied after it. This lets `preprocessors` focus on
        simplifying the AST for static analysis, and lets `transforms` focus on
        turning the AST into code that plots a computation graph.

    The transformed AST is converted back to source, which is imported to
    provide a compiled function that builds the graph using the given
    `Plotter` when called. The graph obtained from the `Plotter` is returned.

    Example:
        def car_pos_fn(g: Plotter):
            car = g.entity(
                components=[
                    "position",
                ]
            )
            env = g.entity(components=["clock"])
            x_delta = 20 if env["clock"].tick_ms > 2000 else 10
            car["position"].x = x_delta

        car_pos_graph = compile_graph(car_pos_fn, entity_defs, component_defs)
        # use compiled graph 'car_pos_graph' in code ...

    Args:
        convert_fn: Function containing the code that should be compiled into a computational graph.
            The target `convert_fn` should take one parameter: a `Plotter` instance, which
            gives access to graphing-specific operations. Must be a plain Python
            function, not a callable class, method, classmethod or staticmethod.

        entity_defs: List of EntityDef representing the ECS entities available for
            use in `convert_fn` via the Plotter instance. Defaults to no entities.

        component_defs: List of ComponentDef representing the ECS component types
            available for use in `convert_fn` via the Plotter instance.
            Defaults to no component types.

        preprocessors: List of `Transform`s that are run sequentially to apply
            preprocessing transforms to the AST before any static analysis is done.
            Typically these AST transforms make static analysis easier by simplifying the AST.

        analyzers: List of `Analyzer`s that are run sequentially on the AST to perform
            static analysis. Analyzers can add attributes to AST nodes but must not
            modify the AST tree.

        linters: List of `Linter`s that are run sequentially on the AST to perform
            static checks on the convertibility of the AST. `Linter`s are expected
            to raise an exception when a check fails.

        transforms: List of `Transform`s that are run sequentially to transform the
            AST into a compiled function (in AST form) that builds the computation
            graph when called.

        The default lists are shared between calls; they are only iterated
        (or copied) here and never mutated.

    Returns:
        The converted computational Graph as a `Graph`.

    Raises:
        Any exception raised while running the compiled graph-building function
        is re-raised unchanged. Before re-raising, the path of the generated
        source is printed, and that source file is left on disk for inspection.
    """

    # parse ast from function source
    tree = parse_ast(convert_fn)

    # apply preprocessors to apply preprocessing transforms on the AST
    for preprocessor in preprocessors:
        tree = preprocessor(tree)
    # apply analyzers to conduct static analysis
    for analyzer in analyzers:
        tree = analyzer(tree)
    # check that the AST can be converted by applying linters to check convertibility
    for linter in linters:
        linter(tree)
    # convert AST to computation graph by applying transforms
    for transform in transforms:
        tree = transform(tree)

    # load AST back as a module, keeping the generated source for error reports
    compiled, src_path = load_ast_module(tree, remove_src=False)
    # Bind build_graph BEFORE merging globals: a function's __globals__ is its
    # module's namespace, so a `build_graph` name among convert_fn's globals
    # would otherwise replace the compiled function on `compiled`.
    build_graph = compiled.build_graph
    # allow the use of globals symbols with respect to convert_fn function
    # to be used during graph plotting
    build_graph.__globals__.update(convert_fn.__globals__)  # type: ignore

    # run build graph function with plotter to build final computation graph;
    # copies keep the shared default lists safe from mutation via the Plotter
    g = Plotter(list(entity_defs), list(component_defs))
    try:
        build_graph(g)
    except Exception:
        # the generated source is deliberately left on disk for debugging
        print(f"Compilation generated source code with errors: {src_path}")
        raise

    # remove the intermediate source file generated by load_ast_module()
    os.remove(src_path)
    return g.graph()
Ejemplo n.º 16
0
def init_fn(g: Plotter):
    # Initial state: the car at Position (50, 25) with Speed (1, 2).
    car = g.entity(components=[Position, Speed])
    car[Position].x = 50
    car[Position].y = 25
    car[Speed].x = 1
    car[Speed].y = 2
Ejemplo n.º 17
0
def test_graph_plotter_empty():
    """A Plotter with no entities, no components and no operations plots an empty Graph."""
    empty_plotter = Plotter(entity_defs=[], component_defs=[])
    plotted = empty_plotter.graph()
    assert plotted == Graph()
Ejemplo n.º 18
0
def sys_fn(g: Plotter):
    # Sets the car's Position.x to twice its Speed.x_neg attribute.
    car = g.entity(components=[Position, Speed])
    car[Position].x = 2 * car[Speed].x_neg
Ejemplo n.º 19
0
def test_transform_ifelse():
    """Checks that transform_ifelse turns if/elif/else assignments into g.switch() calls.

    Each fixture function below is parsed, analyzed and transformed. It is then
    run against a Mock wrapping a Plotter, and every expected set of switch()
    keyword arguments must appear among the recorded calls.
    """
    # The fixture bodies below are parsed from source, so only `#` comments
    # are added to them; this keeps their ASTs unchanged.

    # if without else: the false branch keeps x's prior value ("str1")
    def if_fn(g: Plotter):
        x, w = "str1", "str2"
        if True:
            x = w

    # if/else: one switch per variable assigned in the branches (x and z)
    def ifelse_fn(g: Plotter):
        w, y = "str1", "str2"
        if True:
            x = w
            z = 1
        else:
            x = y
            z = 2

    # if/elif/else: the elif is expected as a nested switch in the false branch
    def ifelse_elif_else_fn(g: Plotter):
        y, m, n = "str1", "str2", "str3"
        if True:
            x = y
            z = 1
        elif False:
            x = m
            z = 2
        else:
            x = n
            z = 3

    # branches read x before reassigning it; expected values are 1+1 and 1+2
    def ifelse_augassign_fn(g: Plotter):
        x = 1
        if True:
            x = x + 1
        else:
            x = x + 2

    # branches assign to an attribute that the condition itself reads
    def if_assign_condition_fn(g: Plotter):
        class A:
            b = True

        # test that the condition is evaluated immediately => True
        if A.b:
            A.b = True
        else:
            A.b = False

    req_analyzers = [
        analyze_func,
        analyze_convert_fn,
        analyze_assign,
        analyze_symbol,
        resolve_symbol,
        analyze_block,
        analyze_activity,
    ]

    # test case plotter => expected g.switch() call args
    # `g` is used only to build the nested switch values expected for elif
    g = Plotter()
    ifelse_fns = [
        (
            if_fn,
            [
                {
                    "condition": True,
                    "true": "str2",
                    "false": "str1"
                },
            ],
        ),
        (
            ifelse_fn,
            [
                {
                    "condition": True,
                    "true": "str1",
                    "false": "str2"
                },
                {
                    "condition": True,
                    "true": 1,
                    "false": 2
                },
            ],
        ),
        (
            ifelse_elif_else_fn,
            [
                {
                    "condition": True,
                    "true": "str1",
                    "false": g.switch(False, "str2", "str3"),
                },
                {
                    "condition": True,
                    "true": 1,
                    "false": g.switch(False, 2, 3)
                },
                {
                    "condition": False,
                    "true": "str2",
                    "false": "str3"
                },
                {
                    "condition": False,
                    "true": 2,
                    "false": 3
                },
            ],
        ),
        (
            ifelse_augassign_fn,
            [
                {
                    "condition": True,
                    "true": 2,
                    "false": 3
                },
            ],
        ),
        (
            if_assign_condition_fn,
            [
                {
                    "condition": True,
                    "true": True,
                    "false": False
                },
            ],
        ),
    ]

    for fn, expected_switch_args in ifelse_fns:
        ast = parse_ast(fn)
        for analyzer in req_analyzers:
            ast = analyzer(ast)
        trans_ast = transform_build_graph(transform_ifelse(ast))

        mod = load_ast_module(trans_ast)
        mock_g = Mock(wraps=Plotter())
        mod.build_graph(mock_g)

        # assert_any_call: call order and additional switch() calls are not checked
        for expected_arg in expected_switch_args:
            mock_g.switch.assert_any_call(**expected_arg)