def test_mcts(self):
    """Smoke test: an MCTS agent runs two episodes on a fake environment.

    There are deliberately no assertions; the test passes as long as
    building the agent and running the environment loop raises nothing.
    """
    # A small fake discrete environment serves as the test bed.
    action_count = 5
    env = fakes.DiscreteEnvironment(
        num_actions=action_count,
        num_observations=10,
        obs_dtype=np.float32,
        episode_length=10,
    )
    env_spec = specs.make_environment_spec(env)

    # Network layers: flatten, then an MLP, then a policy/value head sized
    # by the number of action values in the environment spec.
    layers = [
        snt.Flatten(),
        snt.nets.MLP([50, 50]),
        networks.PolicyValueHead(env_spec.actions.num_values),
    ]
    policy_value_net = snt.Sequential(layers)

    # The search model is a simulator wrapped around the same environment.
    sim_model = simulator.Simulator(env)
    adam = snt.optimizers.Adam(1e-3)

    # Assemble the agent under test.
    agent_kwargs = dict(
        environment_spec=env_spec,
        network=policy_value_net,
        model=sim_model,
        optimizer=adam,
        n_step=1,
        discount=1.,
        replay_capacity=100,
        num_simulations=10,
        batch_size=10,
    )
    agent = mcts.MCTS(**agent_kwargs)

    # Drive the agent through the environment for two episodes.
    acme.EnvironmentLoop(env, agent).run(num_episodes=2)
def test_catch(self, policy_type: Text):
    """Checks the best root action found by MCTS on a two-row Catch game.

    Args:
      policy_type: 'bfs' selects `search.bfs`; any other value selects
        `search.puct`.
    """
    game = catch.Catch(rows=2, seed=1)
    num_actions = game.action_spec().num_values
    model = simulator.Simulator(game)

    def uniform_prior_zero_value(_):
        # Evaluation stub: uniform prior over actions and a value of zero.
        return (np.ones(num_actions) / num_actions, 0.)

    first_step = game.reset()
    model.reset()

    if policy_type == 'bfs':
        chosen_policy = search.bfs
    else:
        chosen_policy = search.puct

    root = search.mcts(
        observation=first_step.observation,
        model=model,
        search_policy=chosen_policy,
        evaluation=uniform_prior_zero_value,
        num_simulations=100,
        num_actions=num_actions,
    )

    # Pick the action whose child node has the highest value estimate.
    child_values = np.array([node.value for node in root.children.values()])
    best_action = search.argmax(child_values)

    # The expected action depends on how the paddle's x compares to the
    # ball's x: greater -> 0, equal -> 1, smaller -> 2.
    paddle_x = game._paddle_x
    ball_x = game._ball_x
    if paddle_x > ball_x:
        self.assertEqual(best_action, 0)
    elif paddle_x == ball_x:
        self.assertEqual(best_action, 1)
    elif paddle_x < ball_x:
        self.assertEqual(best_action, 2)
def test_checkpointing(self):
    """Verifies that loading a checkpoint restores the saved model state."""
    # Build a simulator model around a fresh Catch environment.
    sim = simulator.Simulator(catch.Catch())
    action_count = sim.action_spec().num_values
    sim.reset()

    # Snapshot the state right after the reset, then record the reference
    # transition obtained by taking action 1 from that state.
    sim.save_checkpoint()
    reference_step = sim.step(1)

    # Advance once more with a random action, roll back to the snapshot,
    # and confirm that action 1 reproduces the reference transition.
    latest = sim.step(np.random.randint(action_count))
    sim.load_checkpoint()
    self._check_equal(reference_step, sim.step(1))

    # Keep taking random actions until a terminal timestep is returned.
    while not latest.last():
        latest = sim.step(np.random.randint(action_count))

    # After the episode has ended the model must ask for a reset.
    self.assertTrue(sim.needs_reset)

    # Restoring the checkpoint must clear that flag again.
    sim.load_checkpoint()
    self.assertFalse(sim.needs_reset)

    # The restored state must still yield the reference transition.
    self._check_equal(reference_step, sim.step(1))
def test_simulator_fidelity(self):
    """Checks that the simulator reproduces the real environment's timesteps."""
    ground_truth = catch.Catch()
    action_count = ground_truth.action_spec().num_values
    sim = simulator.Simulator(ground_truth)

    episodes = 10
    for _ in range(episodes):
        # Both must agree on the first timestep of the episode.
        expected = ground_truth.reset()
        actual = sim.reset()
        self._check_equal(expected, actual)

        # Feed identical random actions to both until the real episode ends.
        while not expected.last():
            chosen = np.random.randint(action_count)
            expected = ground_truth.step(chosen)
            actual = sim.step(chosen)
            self._check_equal(expected, actual)
def test_checkpointing(self):
    """Checks that a restored checkpoint puts the model back in step with the env."""
    ground_truth = catch.Catch()
    action_count = ground_truth.action_spec().num_values
    sim = simulator.Simulator(ground_truth)

    # Bring environment and model to the same point: reset, then action 2.
    ground_truth.reset()
    sim.reset()
    ground_truth.step(2)
    latest = sim.step(2)

    # Snapshot the model, then run only the model with random actions until
    # it returns a terminal timestep.
    sim.save_checkpoint()
    while not latest.last():
        latest = sim.step(np.random.randint(action_count))

    # Roll the model back; the real environment was not stepped meanwhile.
    sim.load_checkpoint()

    # The same action applied to both must now give matching timesteps.
    expected = ground_truth.step(2)
    actual = sim.step(2)
    self._check_equal(expected, actual)
def make_env_and_model(
    bsuite_id: str,
    results_dir: str,
    overwrite: bool) -> Tuple[dm_env.Environment, models.Model]:
    """Builds a logged bsuite environment together with a model of it.

    The model is a `simulator.Simulator` of the raw environment when
    `FLAGS.simulator` is set, and an `mlp.MLPModel` otherwise.

    Args:
      bsuite_id: Identifier passed to `bsuite.load_from_id` and to the CSV
        logging wrapper.
      results_dir: Forwarded to `csv_logging.wrap_environment`.
      overwrite: Forwarded to `csv_logging.wrap_environment`.

    Returns:
      A pair `(environment, model)`; the environment is the raw bsuite
      environment wrapped with CSV logging and a single-precision wrapper.
    """
    base_env = bsuite.load_from_id(bsuite_id)

    # The model is built from the raw environment, before any wrapping.
    if not FLAGS.simulator:
        env_spec = specs.make_environment_spec(base_env)
        model = mlp.MLPModel(
            env_spec,
            replay_capacity=1000,
            batch_size=16,
            hidden_sizes=(50, ),
        )
    else:
        model = simulator.Simulator(base_env)  # pytype: disable=attribute-error

    logged_env = csv_logging.wrap_environment(
        base_env, bsuite_id, results_dir, overwrite)
    return wrappers.SinglePrecisionWrapper(logged_env), model
def test_simulator_fidelity(self):
    """Checks that the simulator tracks the real environment step by step.

    Also verifies that the model reports `needs_reset` at the start of each
    episode and clears it once `reset` has been called.
    """
    # A real environment and a simulator 'model' built from it.
    ground_truth = catch.Catch()
    sim = simulator.Simulator(ground_truth)
    action_count = ground_truth.action_spec().num_values

    episodes = 10
    for _ in range(episodes):
        expected = ground_truth.reset()

        # The model should want a reset before the episode starts and
        # stop wanting one right after being reset.
        self.assertTrue(sim.needs_reset)
        actual = sim.reset()
        self.assertFalse(sim.needs_reset)
        self._check_equal(expected, actual)

        # Apply identical random actions to both until the real episode ends.
        while not expected.last():
            chosen = np.random.randint(action_count)
            expected = ground_truth.step(chosen)
            actual = sim.step(chosen)
            self._check_equal(expected, actual)