def test_subgraph_components(self):
    """
    Disabled test: would trace the meta graph of an unbuilt ApexAgent and print the
    sub-components and API methods of the "update_from_external_batch" sub-graph.

    Returns immediately until selective sub-graph fetching is implemented; the code
    below the early return is kept as the intended test body.
    """
    # TODO fix when we have built selective subgraph fetching correctly.
    return

    # Load the Pong Apex config and strip the Ray section so no Ray cluster is needed.
    config = config_from_path("configs/ray_apex_for_pong.json")
    config["execution_spec"].pop("ray_spec")
    env = OpenAIGymEnv("Pong-v0", frameskip=4)

    # auto_build=False: stop before the build so sub-graph fetching can be tested.
    apex = ApexAgent.from_spec(
        config,
        state_space=env.state_space,
        action_space=env.action_space,
        auto_build=False
    )

    # Run the preparation steps that normally precede the device-strategy build.
    executor = apex.graph_executor
    executor.init_execution()
    executor.setup_graph()

    # The meta graph must exist before any sub-graph can be traced.
    builder = apex.graph_builder
    builder.build_meta_graph(apex.input_spaces)
    update_graph = builder.get_subgraph("update_from_external_batch")

    print("Sub graph components:")
    print(update_graph.sub_components)
    print("Sub graph API: ")
    print(update_graph.api_methods)
def test_apex_compilation(self):
    """
    Compiles an ApexAgent for Pong without Ray, which makes debugging on Windows easier.
    """
    config = config_from_path("configs/ray_apex_for_pong.json")
    # Drop the Ray section so the agent is built locally.
    execution_spec = config["execution_spec"]
    execution_spec.pop("ray_spec")

    env = OpenAIGymEnv("Pong-v0", frameskip=4)
    spaces = dict(state_space=env.state_space, action_space=env.action_space)

    apex = ApexAgent.from_spec(config, **spaces)
    print("Compiled {}".format(apex))
def test_multi_gpu_apex_agent_compilation(self):
    """
    Tests if the multi gpu strategy can compile successfully on a multi gpu system, but
    also runs on a CPU-only system using fake-GPU logic for testing purposes.

    Logging is raised to DEBUG for the duration of this test only.
    """
    # Raise verbosity for this test, but restore the previous level afterwards so the
    # global DEBUG setting does not leak into tests that run later in the same process.
    # NOTE(review): assumes root_logger is a logging.Logger (has `.level`) — confirm.
    previous_level = root_logger.level
    root_logger.setLevel(DEBUG)
    try:
        agent_config = config_from_path("configs/multi_gpu_ray_apex_for_pong.json")
        # Build locally without Ray.
        agent_config["execution_spec"].pop("ray_spec")
        environment = OpenAIGymEnv("Pong-v0", frameskip=4)

        # Successful construction (which compiles the agent) is the assertion here.
        ApexAgent.from_spec(
            agent_config,
            state_space=environment.state_space,
            action_space=environment.action_space
        )
        print("Compiled Apex agent")
    finally:
        root_logger.setLevel(previous_level)
def test_multi_gpu_apex_agent_compilation(self):
    """
    Tests if the multi gpu strategy can compile successfully on a multi gpu system.

    THIS TEST REQUIRES A MULTI GPU SYSTEM.

    Logging is raised to DEBUG for the duration of this test only.
    """
    # Raise verbosity for this test, but restore the previous level afterwards so the
    # global DEBUG setting does not leak into tests that run later in the same process.
    # NOTE(review): assumes root_logger is a logging.Logger (has `.level`) — confirm.
    previous_level = root_logger.level
    root_logger.setLevel(DEBUG)
    try:
        agent_config = config_from_path("configs/multi_gpu_ray_apex_for_pong.json")
        # Build locally without Ray.
        agent_config["execution_spec"].pop("ray_spec")
        environment = OpenAIGymEnv("Pong-v0", frameskip=4)

        # Successful construction (which compiles the agent) is the assertion here.
        ApexAgent.from_spec(
            agent_config,
            state_space=environment.state_space,
            action_space=environment.action_space
        )
        print("Compiled Apex agent")
    finally:
        root_logger.setLevel(previous_level)
def test_apex_compilation(self):
    """
    Compiles an ApexAgent for Pong without Ray, which makes debugging on Windows easier.

    Under the pytorch backend the memory type is overridden to
    "mem_prioritized_replay" before building.
    """
    config = config_from_path("configs/ray_apex_for_pong.json")
    # Drop the Ray section so the agent is built locally.
    config["execution_spec"].pop("ray_spec")

    # TODO remove after unified.
    uses_pytorch = get_backend() == "pytorch"
    if uses_pytorch:
        config["memory_spec"]["type"] = "mem_prioritized_replay"

    env = OpenAIGymEnv("Pong-v0", frameskip=4)
    apex = ApexAgent.from_spec(
        config,
        state_space=env.state_space,
        action_space=env.action_space
    )
    print('Compiled apex agent')
def test_post_processing(self):
    """
    Times one `agent.post_process` call on a batch of random Pong transitions and
    prints the recorded component call chain (entries above 0.003 per the
    `print_call_chain` threshold argument).

    The agent is built from configs/ray_apex_for_pong.json with the
    "mem_prioritized_replay" memory and single-threaded CPU settings.
    """
    env = OpenAIGymEnv("Pong-v0", frameskip=4, max_num_noops=30, episodic_life=True)
    agent_config = config_from_path("configs/ray_apex_for_pong.json")
    # Test cpu settings for batching here.
    agent_config["memory_spec"]["type"] = "mem_prioritized_replay"
    agent_config["execution_spec"]["torch_num_threads"] = 1
    agent_config["execution_spec"]["OMP_NUM_THREADS"] = 1

    agent = ApexAgent.from_spec(
        agent_config,
        state_space=env.state_space,
        action_space=env.action_space
    )

    samples = 200
    rewards = np.random.random(size=samples)
    # Draw samples + 1 states in a single batch and slice it, so that
    # next_states[i] == states[i + 1] and every element has the same per-state shape.
    # (Appending a separately drawn `sample(1)` mixes a batched array into a list of
    # single states, which makes np.asarray ragged.)
    all_states = np.asarray(agent.preprocessed_state_space.sample(samples + 1))
    states = all_states[:-1]
    next_states = all_states[1:]
    actions = agent.action_space.sample(samples)
    terminals = np.zeros(samples, dtype=np.uint8)
    weights = np.ones_like(rewards)

    start = time.perf_counter()
    _, loss_per_item = agent.post_process(
        dict(
            states=states,
            actions=actions,
            rewards=rewards,
            terminals=terminals,
            next_states=next_states,
            importance_weights=weights
        )
    )
    print("post process time = {}".format(time.perf_counter() - start))

    profile = Component.call_times
    print_call_chain(profile, False, 0.003)