def sim_fn(g: Plotter):
    # Mountain car simulation step: updates the car's velocity and position from the
    # accelerate action and gravity, then writes the env reward / ended state.
    # Written as a convert fn: it is compiled to a computation graph, so only
    # constructs the compiler transforms (ternaries, if/else, g.* ops) are used.
    env = g.entity(components=[Action, State])
    car = g.entity(components=[Velocity, Position])

    # process car physics
    # compute velocity based on acceleration action & deceleration due to gravity
    acceleration, gravity, max_speed = 0.001, 0.0025, 0.07
    # apply acceleration based on accelerate action:
    # 0: Accelerate to the Left
    # 1: Don't accelerate
    # 2: Accelerate to the Right
    car[Velocity].x += (env[Action].accelerate - 1) * acceleration
    # apply gravity inverse to the mountain path used by the car
    # the mountain is defined by y = sin(3*x)
    # as such we apply gravity inversely using y = cos(3*x)
    # apply negative gravity as gravity works in the opposite direction of movement
    car[Velocity].x += g.cos(3 * car[Position].x) * (-gravity)
    car[Velocity].x = g.clip(car[Velocity].x, min_x=-max_speed, max_x=max_speed)

    # compute new position from current velocity
    min_position, max_position = -1.2, 0.6
    car[Position].x += car[Velocity].x
    car[Position].x = g.clip(car[Position].x, min_position, max_position)

    # collision: stop car when colliding with min_position
    if car[Position].x <= min_position:
        car[Velocity].x = 0.0

    # resolve simulation state: reward and simulation completion
    # reward and termination share the same goal threshold (x >= 0.5) so the car
    # is never rewarded for reaching the goal without the episode ending.
    env[State].reward = 0 if car[Position].x >= 0.5 else -1
    env[State].ended = True if car[Position].x >= 0.5 else False
def test_graph_plotter_retrieve_mutate_op():
    """Reading an attribute records a graph input; assigning one records an output."""
    entity_id = 1
    plotter = Plotter(
        entity_defs=[EntityDef(components=[Position], entity_id=entity_id)],
        component_defs=[Position],
    )
    person = plotter.entity(components=[Position])

    # copy Position.x into Position.y
    pos_x = person[Position].x
    person[Position].y = pos_x

    actual = plotter.graph()
    # the retrieve node must become a graph input ...
    retrieve_x = Node.Retrieve(
        retrieve_attr=AttributeRef(
            entity_id=entity_id,
            component=Position.name,
            attribute="x",
        )
    )
    # ... and the mutate node a graph output pointing at the retrieved value
    mutate_y = Node.Mutate(
        mutate_attr=AttributeRef(
            entity_id=entity_id,
            component=Position.name,
            attribute="y",
        ),
        to_node=pos_x.node,
    )
    assert actual == Graph(inputs=[retrieve_x], outputs=[mutate_y])
def ternary_fn(g: Plotter):
    # Convert fn that assigns the car's Position.x from a ternary on the env Clock.
    # The ternary expression appears to be the construct under compilation (see the
    # transform_ternary transform) — keep it as a ternary.
    car = g.entity(
        components=[
            Position,
        ]
    )
    env = g.entity(components=[Clock])
    # 20 once the clock passes 2000 ms, otherwise 10
    x_delta = 20 if env[Clock].tick_ms > 2000 else 10
    car[Position].x = x_delta
def init_fn(g: Plotter):
    # Convert fn that initializes the mountain car simulation state: car at rest at a
    # random x in [-0.6, -0.4] (range passed to g.random), zero reward, not ended,
    # and the "don't accelerate" action (1).
    car = g.entity(components=[Velocity, Position])
    car[Velocity].x = 0.0
    # NOTE(review): bounds inclusivity depends on g.random — confirm
    car[Position].x = g.random(-0.6, -0.4)
    env = g.entity(components=[Action, State])
    env[State].reward = 0
    env[State].ended = False
    # 1 == no acceleration (0 = left, 2 = right in the sim step)
    env[Action].accelerate = 1
def physics_sys(g: Plotter):
    # Convert fn: derives the car's velocity from its Movement rotation and speed,
    # then advances its Position by that velocity.
    # NOTE(review): Meta is requested but never read — presumably so the entity
    # lookup matches the car entity declared with these components; confirm.
    # compute velocity from car's rotation and speed
    car = g.entity(components=[Movement, Velocity, Position, Meta])
    # rotation
    # heading is (cos r, -sin r); the y sign is flipped, presumably for a y-down
    # screen coordinate system — TODO confirm. Whether rotation is in degrees or
    # radians depends on g.cos/g.sin (init/control code uses 90.0 / 10.0 steps).
    heading_x, heading_y = g.cos(
        car[Movement].rotation), -g.sin(car[Movement].rotation)
    # speed
    car[Velocity].x = car[Movement].speed * heading_x
    car[Velocity].y = car[Movement].speed * heading_y
    # update car position based on current velocity
    car[Position].x += car[Velocity].x
    car[Position].y += car[Velocity].y
def arithmetic_multiple_fn(g: Plotter):
    # Convert fn combining several arithmetic ops (unary minus, *, +) on retrieved
    # attributes to update the car's Position.x.
    ms_in_sec = int(1e3)
    env = g.entity(components=[Clock])
    car = g.entity(
        components=[
            Position,
            Velocity,
        ]
    )
    tick_ms = env[Clock].tick_ms
    neg_xps = -car[Velocity].x
    # NOTE(review): this multiplies milliseconds by 1000; if a ms -> s conversion
    # was intended it would be a division. Presumably only the op mix matters for
    # this fixture — confirm before relying on the units.
    x_delta = neg_xps * (tick_ms * ms_in_sec)
    car[Position].x = x_delta + car[Position].x
def test_graph_plotter_preserve_code_order():
    """Graph inputs and outputs follow the order attributes are first read / last written.

    Statement order below is the subject of the test and must not be rearranged.
    """
    g = Plotter(
        entity_defs=[
            EntityDef(components=[Position, Velocity], entity_id=1),
        ],
        component_defs=[Position, Velocity],
    )
    car = g.entity(components=[Position, Velocity])
    # since Velocity.x is used first, it should appear in inputs before Position.x
    car_velocity_x = car[Velocity].x
    car_pos_x = car[Position].x
    # since Velocity.x is modified last, it should appear in outputs after Position.x
    car[Position].x = 3
    car[Velocity].x = 2
    # re-read the attributes to obtain the nodes the mutations point to
    position_x = car[Position].x
    velocity_x = car[Velocity].x
    # compare YAML renderings of the two graphs
    assert (g.graph().yaml == Graph(
        inputs=[
            Node.Retrieve(retrieve_attr=AttributeRef(
                entity_id=car.id,
                component=Velocity.name,
                attribute="x",
            )),
            Node.Retrieve(retrieve_attr=AttributeRef(
                entity_id=car.id,
                component=Position.name,
                attribute="x",
            )),
        ],
        outputs=[
            Node.Mutate(
                mutate_attr=AttributeRef(
                    entity_id=car.id,
                    component=Position.name,
                    attribute="x",
                ),
                to_node=position_x.node,
            ),
            Node.Mutate(
                mutate_attr=AttributeRef(
                    entity_id=car.id,
                    component=Velocity.name,
                    attribute="x",
                ),
                to_node=velocity_x.node,
            ),
        ],
    ).yaml)
def test_transform_build_graph():
    """transform_build_graph() yields a module exposing a callable build_graph()."""

    def convert_fn(g: Plotter):
        pass

    def convert_fn_with_long_name(g: Plotter):
        pass

    def identity(fn):
        return fn

    @identity
    def convert_fn_with_annotation(g: Plotter):
        pass

    analyzers = [
        analyze_func,
        analyze_convert_fn,
    ]
    # plain, long-named and decorated convert fns must all be renamed
    cases = (convert_fn, convert_fn_with_long_name, convert_fn_with_annotation)
    for case_fn in cases:
        tree = parse_ast(case_fn)
        for analyze in analyzers:
            tree = analyze(tree)
        # run the transformed function, which is now called 'build_graph'
        module = load_ast_module(transform_build_graph(tree))
        module.build_graph(Plotter(()))
def init_sim(g: Plotter):
    # Convert fn that initializes the driving sim: every keyboard control released,
    # and a "beetle" car (id 512, version 2) at the origin, at rest, rotation 90.0.
    controls = g.entity(components=[Keyboard])
    controls[Keyboard].left = False
    controls[Keyboard].right = False
    controls[Keyboard].up = False
    controls[Keyboard].down = False
    car = g.entity(components=[Movement, Velocity, Position, Meta])
    car[Meta].name = "beetle"
    car[Meta].id = 512
    car[Meta].version = 2
    car[Movement].speed = 0.0
    # NOTE(review): 90.0 suggests degrees — confirm against g.cos/g.sin in physics
    car[Movement].rotation = 90.0
    car[Velocity].x = 0.0
    car[Velocity].y = 0.0
    car[Position].x = 0.0
    car[Position].y = 0.0
def ifelse_fn(g: Plotter):
    # Convert fn with an if / elif / else chain on the env Clock that updates the
    # car's Position from its Velocity capped by Speed.max_x.
    car = g.entity(
        components=[
            Position,
            Speed,
            Velocity,
        ]
    )
    env = g.entity(components=[Clock])
    if env[Clock].tick_ms > 2000:
        car[Position].x += g.min(car[Speed].max_x, 2 * car[Velocity].x)
        car[Position].y = car[Position].x + 2
    # NOTE(review): tick_ms > 5000 implies tick_ms > 2000, so this branch can never
    # be the taken one; presumably kept only to exercise elif lowering — confirm.
    elif env[Clock].tick_ms > 5000:
        car[Position].x += g.min(car[Speed].max_x, 5 * car[Velocity].x)
        car[Position].y = car[Position].x + 10
    else:
        # note: plain assignment here, unlike the += in the other branches
        car[Position].x = g.min(car[Speed].max_x, 1 * car[Velocity].x)
        car[Position].y = car[Position].x - 5
def control_sys(g: Plotter):
    # Convert fn: applies keyboard input to the car's Movement and clears each
    # consumed key. Steering changes rotation by steer_rate. Up/down change speed
    # by `acceleration`, clamped to [0.0, max_speed].
    controls = g.entity(components=[Keyboard])
    car = g.entity(components=[Movement, Velocity, Position, Meta])
    acceleration, max_speed, steer_rate = 5.0, 18.0, 10.0
    # steer car
    # left takes precedence over right when both are pressed
    if controls[Keyboard].left:
        car[Movement].rotation -= steer_rate
        controls[Keyboard].left = False
    elif controls[Keyboard].right:
        car[Movement].rotation += steer_rate
        controls[Keyboard].right = False
    # accelerate/slow down car
    # up takes precedence over down when both are pressed
    if controls[Keyboard].up:
        car[Movement].speed = g.min(car[Movement].speed + acceleration, max_speed)
        controls[Keyboard].up = False
    elif controls[Keyboard].down:
        car[Movement].speed = g.max(car[Movement].speed - acceleration, 0.0)
        controls[Keyboard].down = False
def test_transform_ternary():
    """A ternary expression is lowered into exactly one g.switch() call."""

    def ternary_fn(g: Plotter):
        int_ternary = 1 if True else 2

    tree = parse_ast(ternary_fn)
    for analyze in (analyze_func, analyze_convert_fn):
        tree = analyze(tree)
    module = load_ast_module(transform_build_graph(transform_ternary(tree)))

    spy = Mock(wraps=Plotter())
    module.build_graph(spy)
    # `1 if True else 2` must map onto switch(condition, true, false)
    spy.switch.assert_called_once_with(condition=True, true=1, false=2)
def test_graph_plotter_conditional_boolean():
    """Nested switch() nodes over one retrieved key yield one input and one output."""
    plotter = Plotter(
        entity_defs=[
            EntityDef(components=[Position], entity_id=1),
            EntityDef(components=[Keyboard], entity_id=2),
        ],
        component_defs=[Position, Keyboard],
    )
    env = plotter.entity(components=[Keyboard])
    car = plotter.entity(components=[Position])

    pressed = env[Keyboard].pressed
    # comparisons are built in the same order the nested call evaluated them
    is_left = pressed == "left"
    is_right = pressed == "right"
    # "left" => -1.0, else "right" => 1.0, else 0.0
    right_or_idle = plotter.switch(condition=is_right, true=1.0, false=0.0)
    car_pos_x = plotter.switch(condition=is_left, true=-1.0, false=right_or_idle)
    car[Position].x = car_pos_x

    actual = plotter.graph()
    key_input = Node.Retrieve(
        retrieve_attr=AttributeRef(
            entity_id=env.id,
            component=Keyboard.name,
            attribute="pressed",
        )
    )
    # the mutation must reference the outer switch node
    pos_output = Node.Mutate(
        mutate_attr=AttributeRef(
            entity_id=car.id,
            component=Position.name,
            attribute="x",
        ),
        to_node=car_pos_x.node,
    )
    assert actual == Graph(inputs=[key_input], outputs=[pos_output])
def arithmetic_fn(g: Plotter):
    # Convert fn performing a single augmented assignment (+=) with a constant on
    # the car's Position.x.
    car = g.entity(components=[Position])
    x_delta = 20
    car[Position].x += x_delta
def compile_graph(
    convert_fn: ConvertFn,
    entity_defs: List[EntityDef],
    component_defs: List[ComponentDef],
    # NOTE: the default lists below are shared across calls; they are only
    # iterated here and must never be mutated.
    preprocessors: List[Transform] = [
        preprocess_augassign,
    ],
    analyzers: List[Analyzer] = [
        analyze_parent,
        analyze_func,
        analyze_convert_fn,
        analyze_symbol,
        analyze_assign,
        resolve_symbol,
        analyze_block,
        analyze_activity,
    ],
    linters: List[Linter] = [],
    transforms: List[Transform] = [
        transform_build_graph,
        transform_ternary,
        transform_ifelse,
    ],
) -> Graph:
    """Compiles the given `convert_fn` into a computation Graph running the given sim.

    Globals can be used read only in the `convert_fn`. Writing to globals is not
    supported.

    Compiles by:
    1. converting the given `convert_fn` function to an AST;
    2. applying the given `preprocessors` transforms to preprocess the AST;
    3. applying the given `analyzers` to the AST to perform static analysis;
    4. linting the AST with the given `linters` to perform static checks;
    5. applying the given `transforms` to turn the AST into a function that plots
       the computational graph when run.

    Note: Even though both `preprocessors` and `transforms` are lists of
    `Transform`s, `preprocessors` are applied before any static analysis is done,
    while `transforms` are applied after static analysis. This lets `preprocessors`
    focus on rewriting the AST to make static analysis easier, while `transforms`
    focus on rewriting the AST to plot a computation graph.

    The transformed AST is converted back to source, which is imported to provide a
    compiled `build_graph` function that builds the graph on the given `Plotter`
    when called. The graph obtained from the `Plotter` is finally returned.

    Example:
        def car_pos_fn(g: Plotter):
            car = g.entity(components=[Position])
            env = g.entity(components=[Clock])
            x_delta = 20 if env[Clock].tick_ms > 2000 else 10
            car[Position].x = x_delta

        car_pos_graph = compile_graph(car_pos_fn, entity_defs, component_defs)
        # use compiled graph 'car_pos_graph' in code ...

    Args:
        convert_fn: Function containing the code that should be compiled into a
            computational graph. The target `convert_fn` should take one
            parameter: a `Plotter` instance that gives access to graphing-specific
            operations. Must be a plain Python function, not a callable class,
            method, classmethod or staticmethod.
        entity_defs: List of EntityDef representing the ECS entities available for
            use in `convert_fn` via the Plotter instance.
        component_defs: List of ComponentDef representing the ECS component types
            available for use in `convert_fn` via the Plotter instance.
        preprocessors: List of `Transform`s run sequentially to apply
            preprocessing transforms to the AST before any static analysis is
            done. Typically these transforms make static analysis easier by
            simplifying the AST.
        analyzers: List of `Analyzer`s run sequentially on the AST to perform
            static analysis. Analyzers can add attributes to AST nodes but must not
            modify the AST tree.
        linters: List of `Linter`s run sequentially on the AST to perform static
            checks on the convertibility of the AST. `Linter`s are expected to
            raise an exception when a check fails.
        transforms: List of `Transform`s run sequentially to transform the AST
            into a compiled function (in AST form) that builds the computation
            graph when called.

    Returns:
        The converted computational Graph as a `Graph`.

    Raises:
        Any exception raised by a linter, or by the generated graph-building code,
        is propagated. In the latter case, the path of the generated source is
        printed and the file is kept on disk for debugging.
    """
    # parse ast from function source
    ast = parse_ast(convert_fn)
    # apply preprocessors to apply preprocessing transforms on the AST
    for preprocessor in preprocessors:
        ast = preprocessor(ast)
    # apply analyzers to conduct static analysis
    for analyzer in analyzers:
        ast = analyzer(ast)
    # check that the AST can be converted by applying linters to check convertibility
    for linter in linters:
        linter(ast)
    # convert AST to computation graph by applying transforms
    for transform in transforms:
        ast = transform(ast)

    # load AST back as a module, keeping the source file around for error reports
    compiled, src_path = load_ast_module(ast, remove_src=False)
    # grab build_graph BEFORE merging globals: the merge writes into the compiled
    # module's namespace and would otherwise replace it if convert_fn's module
    # happens to define its own `build_graph` global.
    build_graph = compiled.build_graph
    # allow the use of global symbols visible to convert_fn during graph plotting
    build_graph.__globals__.update(convert_fn.__globals__)  # type: ignore

    # run build graph function with plotter to build final computation graph
    g = Plotter(entity_defs, component_defs)
    try:
        build_graph(g)
    except Exception:
        print(f"Compilation generated source code with errors: {src_path}")
        # bare raise re-raises the active exception with its original traceback
        raise
    # remove the intermediate source file generated by load_ast_module()
    os.remove(src_path)
    return g.graph()
def init_fn(g: Plotter):
    # Convert fn that initializes a car entity with constant Position (50, 25) and
    # Speed (1, 2).
    car = g.entity(components=[Position, Speed])
    car[Position].x = 50
    car[Position].y = 25
    car[Speed].x = 1
    car[Speed].y = 2
def test_graph_plotter_empty():
    """A Plotter that never retrieves or mutates anything produces an empty Graph."""
    empty_plotter = Plotter(entity_defs=[], component_defs=[])
    expected = Graph()
    assert empty_plotter.graph() == expected
def sys_fn(g: Plotter):
    # Convert fn that sets the car's Position.x to twice its Speed.x_neg attribute.
    # NOTE(review): `x_neg` must exist on the Speed component definition — confirm
    # whether this fixture is meant to hit an unknown-attribute path.
    car = g.entity(components=[Position, Speed])
    car[Position].x = 2 * car[Speed].x_neg
def test_transform_ifelse():
    """transform_ifelse() lowers if / elif / else assignments into g.switch() calls.

    Each nested convert fn below is parsed from source, analyzed, transformed and
    run against a Mock-wrapped Plotter. The test then asserts that every expected
    g.switch() call was made (assert_any_call, so call order and count are not
    checked).
    """
    def if_fn(g: Plotter):
        x, w = "str1", "str2"
        if True:
            x = w

    def ifelse_fn(g: Plotter):
        w, y = "str1", "str2"
        if True:
            x = w
            z = 1
        else:
            x = y
            z = 2

    def ifelse_elif_else_fn(g: Plotter):
        y, m, n = "str1", "str2", "str3"
        if True:
            x = y
            z = 1
        elif False:
            x = m
            z = 2
        else:
            x = n
            z = 3

    def ifelse_augassign_fn(g: Plotter):
        x = 1
        if True:
            x = x + 1
        else:
            x = x + 2

    def if_assign_condition_fn(g: Plotter):
        class A:
            b = True
        # test that the condition is evaluated immediately => True
        if A.b:
            A.b = True
        else:
            A.b = False

    req_analyzers = [
        analyze_func,
        analyze_convert_fn,
        analyze_assign,
        analyze_symbol,
        resolve_symbol,
        analyze_block,
        analyze_activity,
    ]
    # test case plotter => expected g.switch() call args
    # (this real Plotter only builds the nested switch() values that an elif chain
    # is expected to pass as the outer switch's `false` argument)
    g = Plotter()
    ifelse_fns = [
        (
            if_fn,
            [
                # if without else: the false branch keeps the prior value of x
                { "condition": True, "true": "str2", "false": "str1" },
            ],
        ),
        (
            ifelse_fn,
            [
                { "condition": True, "true": "str1", "false": "str2" },
                { "condition": True, "true": 1, "false": 2 },
            ],
        ),
        (
            ifelse_elif_else_fn,
            [
                # elif/else becomes a nested switch in the outer false branch
                {
                    "condition": True,
                    "true": "str1",
                    "false": g.switch(False, "str2", "str3"),
                },
                { "condition": True, "true": 1, "false": g.switch(False, 2, 3) },
                { "condition": False, "true": "str2", "false": "str3" },
                { "condition": False, "true": 2, "false": 3 },
            ],
        ),
        (
            ifelse_augassign_fn,
            [
                { "condition": True, "true": 2, "false": 3 },
            ],
        ),
        (
            if_assign_condition_fn,
            [
                { "condition": True, "true": True, "false": False },
            ],
        ),
    ]
    for fn, expected_switch_args in ifelse_fns:
        ast = parse_ast(fn)
        for analyzer in req_analyzers:
            ast = analyzer(ast)
        trans_ast = transform_build_graph(transform_ifelse(ast))
        mod = load_ast_module(trans_ast)
        # wrap a real Plotter so switch() still runs while its calls are recorded
        mock_g = Mock(wraps=Plotter())
        mod.build_graph(mock_g)
        for expected_arg in expected_switch_args:
            mock_g.switch.assert_any_call(**expected_arg)