Beispiel #1
0
    def test_mcts(self):
        """Smoke-test the MCTS agent by running it in an environment loop.

        There are deliberately no assertions: the test passes as long as
        building the agent and running two episodes raises no error.
        """
        # A small fake discrete environment to drive the agent.
        num_actions = 5
        environment = fakes.DiscreteEnvironment(
            num_actions=num_actions,
            num_observations=10,
            obs_dtype=np.float32,
            episode_length=10,
        )
        spec = specs.make_environment_spec(environment)

        # Policy/value network: flatten -> MLP torso -> policy-value head.
        layers = [
            snt.Flatten(),
            snt.nets.MLP([50, 50]),
            networks.PolicyValueHead(spec.actions.num_values),
        ]
        network = snt.Sequential(layers)

        # The search model is a simulator built from the same environment.
        model = simulator.Simulator(environment)
        optimizer = snt.optimizers.Adam(1e-3)

        # Assemble the agent under test.
        agent = mcts.MCTS(
            environment_spec=spec,
            network=network,
            model=model,
            optimizer=optimizer,
            n_step=1,
            discount=1.,
            replay_capacity=100,
            num_simulations=10,
            batch_size=10,
        )

        # Run the agent for a couple of episodes.
        acme.EnvironmentLoop(environment, agent).run(num_episodes=2)
Beispiel #2
0
    def test_catch(self, policy_type: Text):
        """Check the action MCTS prefers on a two-row Catch board.

        The search runs with a uniform prior and a zero value estimate.
        The best root child must be action 0 when the paddle's x is greater
        than the ball's x, action 1 when they are equal, and action 2 when
        the paddle's x is smaller.

        Args:
            policy_type: 'bfs' selects `search.bfs`; any other value selects
                `search.puct`.
        """
        env = catch.Catch(rows=2, seed=1)
        num_actions = env.action_spec().num_values
        model = simulator.Simulator(env)

        def eval_fn(_):
            # Uniform prior over the actions and a value of zero.
            return (np.ones(num_actions) / num_actions, 0.)

        timestep = env.reset()
        model.reset()

        if policy_type == 'bfs':
            search_policy = search.bfs
        else:
            search_policy = search.puct

        root = search.mcts(
            observation=timestep.observation,
            model=model,
            search_policy=search_policy,
            evaluation=eval_fn,
            num_simulations=100,
            num_actions=num_actions,
        )

        child_values = [child.value for child in root.children.values()]
        best_action = search.argmax(np.array(child_values))

        # The expected action depends only on the paddle/ball x positions.
        paddle_x, ball_x = env._paddle_x, env._ball_x
        if paddle_x > ball_x:
            self.assertEqual(best_action, 0)
        elif paddle_x == ball_x:
            self.assertEqual(best_action, 1)
        elif paddle_x < ball_x:
            self.assertEqual(best_action, 2)
Beispiel #3
0
    def test_checkpointing(self):
        """Verify that loading a checkpoint restores the saved model state."""
        # Build a simulator model on top of a Catch environment.
        model = simulator.Simulator(catch.Catch())
        num_actions = model.action_spec().num_values

        model.reset()

        # Snapshot the state right after the reset.
        model.save_checkpoint()

        # Reference transition: taking action 1 from the checkpointed state.
        expected_timestep = model.step(1)

        # Wander off with one random action, then roll back to the checkpoint.
        rollout_timestep = model.step(np.random.randint(num_actions))
        model.load_checkpoint()
        replayed_timestep = model.step(1)
        self._check_equal(expected_timestep, replayed_timestep)

        # Keep taking random actions until a last timestep is observed.
        while not rollout_timestep.last():
            random_action = np.random.randint(num_actions)
            rollout_timestep = model.step(random_action)

        # At this point the model should be asking for a reset.
        self.assertTrue(model.needs_reset)

        # Restoring the checkpoint should clear that flag.
        model.load_checkpoint()
        self.assertFalse(model.needs_reset)

        # Stepping from the restored state must reproduce the reference.
        replayed_timestep = model.step(1)
        self._check_equal(expected_timestep, replayed_timestep)
Beispiel #4
0
    def test_simulator_fidelity(self):
        """Check that the simulator reproduces the real environment's timesteps.

        For ten episodes, the environment and the simulator are reset and
        then stepped with the same random actions; every pair of timesteps
        must compare equal.
        """
        env = catch.Catch()
        num_actions = env.action_spec().num_values
        model = simulator.Simulator(env)

        num_episodes = 10
        for _ in range(num_episodes):
            # Both must agree on the first timestep of the episode.
            env_ts = env.reset()
            sim_ts = model.reset()
            self._check_equal(env_ts, sim_ts)

            # Feed identical random actions to both until the episode ends.
            while not env_ts.last():
                action = np.random.randint(num_actions)
                env_ts = env.step(action)
                sim_ts = model.step(action)
                self._check_equal(env_ts, sim_ts)
Beispiel #5
0
    def test_checkpointing(self):
        """Check that a restored simulator matches the real environment again.

        The model is checkpointed after one step, run to the end of the
        episode with random actions, and restored; its next step must then
        equal the environment's next step.
        """
        env = catch.Catch()
        num_actions = env.action_spec().num_values
        model = simulator.Simulator(env)

        env.reset()
        model.reset()

        # Advance both by the same action, then snapshot the model.
        env.step(2)
        sim_ts = model.step(2)
        model.save_checkpoint()

        # Run only the model to the end of its episode.
        while not sim_ts.last():
            random_action = np.random.randint(num_actions)
            sim_ts = model.step(random_action)

        # Roll the model back to the snapshot.
        model.load_checkpoint()

        # The same action applied to both must now give equal timesteps.
        env_ts = env.step(2)
        sim_ts = model.step(2)
        self._check_equal(env_ts, sim_ts)
Beispiel #6
0
def make_env_and_model(
        bsuite_id: str, results_dir: str,
        overwrite: bool) -> Tuple[dm_env.Environment, models.Model]:
    """Build a bsuite environment and a matching model for planning.

    The model is a `simulator.Simulator` of the raw environment when
    `FLAGS.simulator` is set, and an `mlp.MLPModel` otherwise.

    Args:
        bsuite_id: Identifier passed to `bsuite.load_from_id` and to the CSV
            logging wrapper.
        results_dir: Forwarded to `csv_logging.wrap_environment`.
        overwrite: Forwarded to `csv_logging.wrap_environment`.

    Returns:
        A tuple `(environment, model)`, where `environment` is the raw
        environment wrapped with CSV logging and `SinglePrecisionWrapper`.
    """
    raw_env = bsuite.load_from_id(bsuite_id)

    # The model is built from the raw environment, before any wrapping.
    if not FLAGS.simulator:
        env_spec = specs.make_environment_spec(raw_env)
        model = mlp.MLPModel(
            env_spec,
            replay_capacity=1000,
            batch_size=16,
            hidden_sizes=(50, ),
        )
    else:
        model = simulator.Simulator(raw_env)  # pytype: disable=attribute-error

    logged_env = csv_logging.wrap_environment(
        raw_env, bsuite_id, results_dir, overwrite)
    return wrappers.SinglePrecisionWrapper(logged_env), model
Beispiel #7
0
    def test_simulator_fidelity(self):
        """Check that the simulator tracks the environment step for step.

        Besides comparing timesteps, this checks that `needs_reset` is true
        before each model reset and false right after it.
        """
        # The ground-truth environment and a simulator model built from it.
        env = catch.Catch()
        model = simulator.Simulator(env)

        num_actions = env.action_spec().num_values
        num_episodes = 10
        for _ in range(num_episodes):
            env_ts = env.reset()

            # The model must ask for a reset until it is actually reset.
            self.assertTrue(model.needs_reset)
            sim_ts = model.reset()
            self.assertFalse(model.needs_reset)
            self._check_equal(env_ts, sim_ts)

            # Drive both with the same random actions until the episode ends.
            while not env_ts.last():
                action = np.random.randint(num_actions)
                env_ts = env.step(action)
                sim_ts = model.step(action)
                self._check_equal(env_ts, sim_ts)