def test_integration_with_cartpole():
    """Smoke test: DeterministicMCTSAgent solves CartPole with a dummy network."""
    cartpole = envs.CartPole()
    mcts_agent = agents.DeterministicMCTSAgent(
        action_space=cartpole.action_space,
        n_passes=2,
    )
    solved = testing.run_with_dummy_network(mcts_agent.solve(cartpole))
    # A nonzero batch dimension means at least one transition was collected.
    assert solved.transition_batch.observation.shape[0]  # pylint: disable=no-member
def test_integration_with_cartpole():
    """Smoke test: DeterministicMCTSAgent solves CartPole via dummy predictions."""
    cartpole = envs.CartPole()
    mcts_agent = agents.DeterministicMCTSAgent(n_passes=2)
    signature = mcts_agent.network_signature(
        cartpole.observation_space, cartpole.action_space
    )
    solved = testing.run_with_dummy_network_prediction(
        mcts_agent.solve(cartpole), signature
    )
    # A nonzero batch dimension means at least one transition was collected.
    assert solved.transition_batch.observation.shape[0]  # pylint: disable=no-member
def test_integration_with_cartpole():
    """Smoke test: ShootingAgent solves CartPole with a dummy network response."""
    # Set up.
    cartpole = envs.CartPole()
    shooting_agent = agents.ShootingAgent(n_rollouts=1)

    # Run.
    solved = testing.run_with_dummy_network_response(shooting_agent.solve(cartpole))

    # Test: at least one transition should have been collected.
    assert solved.transition_batch.observation.shape[0]  # pylint: disable=no-member
def test_integration_with_cartpole():
    """Smoke test: ShootingAgent solves CartPole without coroutine suspensions."""
    # Set up.
    cartpole = envs.CartPole()
    shooting_agent = agents.ShootingAgent(
        action_space=cartpole.action_space, n_rollouts=1
    )

    # Run.
    solved = testing.run_without_suspensions(shooting_agent.solve(cartpole))

    # Test: at least one transition should have been collected.
    assert solved.transition_batch.observation.shape[0]  # pylint: disable=no-member
def test_integration_with_cartpole():
    """Smoke test: StochasticMCTSAgent with rollout leaf rating on CartPole."""
    cartpole = envs.CartPole()
    rollout_rater = functools.partial(
        agents.stochastic_mcts.RolloutNewLeafRater,
        rollout_time_limit=2,
    )
    stochastic_agent = agents.StochasticMCTSAgent(
        n_passes=2,
        new_leaf_rater_class=rollout_rater,
    )
    solved = testing.run_without_suspensions(stochastic_agent.solve(cartpole))
    # A nonzero batch dimension means at least one transition was collected.
    assert solved.transition_batch.observation.shape[0]  # pylint: disable=no-member
def test_integration_with_cartpole(graph_mode):
    """Smoke test: StochasticMCTSAgent (rollout leaf rating) in the given graph mode."""
    cartpole = envs.CartPole()
    rollout_rater = functools.partial(
        agents.stochastic_mcts.rate_new_leaves_with_rollouts,
        rollout_time_limit=2,
    )
    stochastic_agent = agents.StochasticMCTSAgent(
        action_space=cartpole.action_space,
        n_passes=2,
        rate_new_leaves_fn=rollout_rater,
        graph_mode=graph_mode,
    )
    solved = testing.run_without_suspensions(stochastic_agent.solve(cartpole))
    # A nonzero batch dimension means at least one transition was collected.
    assert solved.transition_batch.observation.shape[0]  # pylint: disable=no-member
def test_act_doesnt_change_env_state():
    """ShootingAgent.act() must leave the environment's state untouched."""
    # Set up: reset both the env and the agent on the initial observation.
    cartpole = envs.CartPole()
    shooting_agent = agents.ShootingAgent(n_rollouts=10)
    obs = cartpole.reset()
    testing.run_with_dummy_network_response(shooting_agent.reset(cartpole, obs))

    # Run: snapshot the env state around a single act() call.
    snapshot_before = cartpole.clone_state()
    testing.run_without_suspensions(shooting_agent.act(obs))
    snapshot_after = cartpole.clone_state()

    # Test: planning rollouts must not have mutated the real environment.
    np.testing.assert_equal(snapshot_before, snapshot_after)
def test_batch_steppers_different_requests_single_batch(batch_stepper_cls):
    """One batch where agents issue three different request types concurrently."""
    params = 'params'
    n_envs = 9
    n_requests = 100
    agent_net_fn = functools.partial(
        _TestPredictNetwork, request_type=RequestType.AGENT_PREDICTION)
    model_net_fn = functools.partial(
        _TestPredictNetwork, request_type=RequestType.MODEL_PREDICTION)
    stepper = batch_stepper_cls(
        env_class=envs.CartPole,
        agent_class=RandomAgent,  # Placeholder - replaced below.
        network_fn=agent_net_fn,  # Placeholder - replaced below.
        model_network_fn=model_net_fn,  # Placeholder - replaced below.
        n_envs=n_envs,
        output_dir=None,
    )
    # Three equal groups of agents, each yielding a different request type.
    group_size = n_envs // 3
    agent_factories = (
        lambda: _TestAgentYieldingNetworkRequest(
            agent_net_fn, params, n_requests),
        lambda: _TestAgentYieldingAgentPredictRequest(n_requests),
        lambda: _TestAgentYieldingModelPredictRequest(n_requests),
    )
    replacement = [
        (envs.CartPole(), make_agent())
        for make_agent in agent_factories
        for _ in range(group_size)
    ]
    # Envs and agents have to be swapped in manually, because
    # LocalBatchStepper does not support running different agent types
    # simultaneously.
    stepper._envs_and_agents = replacement  # pylint: disable=protected-access
    stepper.run_episode_batch(params)
def test_act_doesnt_change_env_state(graph_mode, rate_new_leaves_fn):
    """StochasticMCTSAgent.act() must leave the environment's state untouched."""
    # Set up: reset both the env and the agent on the initial observation.
    cartpole = envs.CartPole()
    stochastic_agent = agents.StochasticMCTSAgent(
        action_space=cartpole.action_space,
        n_passes=2,
        rate_new_leaves_fn=rate_new_leaves_fn,
        graph_mode=graph_mode,
    )
    obs = cartpole.reset()
    testing.run_without_suspensions(stochastic_agent.reset(cartpole, obs))

    # Run: snapshot the env state around a single act() call.
    snapshot_before = cartpole.clone_state()
    testing.run_with_dummy_network(stochastic_agent.act(obs))
    snapshot_after = cartpole.clone_state()

    # Test: MCTS planning must not have mutated the real environment.
    np.testing.assert_equal(snapshot_before, snapshot_after)
def test_act_doesnt_change_env_state(new_leaf_rater_class):
    """StochasticMCTSAgent.act() must leave the environment's state untouched."""
    # Set up: reset both the env and the agent on the initial observation.
    cartpole = envs.CartPole()
    stochastic_agent = agents.StochasticMCTSAgent(
        n_passes=2,
        new_leaf_rater_class=new_leaf_rater_class,
    )
    obs = cartpole.reset()
    testing.run_without_suspensions(stochastic_agent.reset(cartpole, obs))

    # Run: snapshot the env state around a single act() call.
    snapshot_before = cartpole.clone_state()
    signature = stochastic_agent.network_signature(
        cartpole.observation_space, cartpole.action_space
    )
    testing.run_with_dummy_network_prediction(
        stochastic_agent.act(obs), signature
    )
    snapshot_after = cartpole.clone_state()

    # Test: MCTS planning must not have mutated the real environment.
    np.testing.assert_equal(snapshot_before, snapshot_after)