Ejemplo n.º 1
0
    def test_catch(self, policy_type: Text):
        """Checks which root action MCTS prefers on a two-row Catch game.

        A simulator of the environment is searched for 100 simulations using
        either `search.bfs` (when `policy_type` is 'bfs') or `search.puct`
        (any other value). The evaluation function returns a uniform vector
        over the actions paired with 0.

        The best action must be 0 when the paddle's x is greater than the
        ball's x, 1 when they are equal and 2 when it is smaller.
        NOTE(review): presumably 0/1/2 mean left/stay/right -- confirm
        against the Catch environment.

        Args:
            policy_type: name of the search policy to exercise.
        """
        game = catch.Catch(rows=2)
        action_count = game.action_spec().num_values
        sim = simulator.Simulator(game)

        def uniform_evaluation(_):
            # The input is ignored: every action gets the same weight.
            return np.ones(action_count) / action_count, 0.

        first_step = game.reset()
        sim.reset()

        if policy_type == 'bfs':
            policy = search.bfs
        else:
            policy = search.puct

        root = search.mcts(
            observation=first_step.observation,
            model=sim,
            search_policy=policy,
            evaluation=uniform_evaluation,
            num_simulations=100,
            num_actions=action_count,
        )

        child_values = np.array(
            [child.value for child in root.children.values()])
        chosen_action = search.argmax(child_values)

        # The test reads the environment's private state to learn where the
        # paddle sits relative to the ball.
        paddle_x, ball_x = game._paddle_x, game._ball_x
        if paddle_x > ball_x:
            self.assertEqual(chosen_action, 0)
        elif paddle_x == ball_x:
            self.assertEqual(chosen_action, 1)
        elif paddle_x < ball_x:
            self.assertEqual(chosen_action, 2)
Ejemplo n.º 2
0
  def test_mcts(self):
    """Smoke-tests the MCTS agent on a fake discrete environment.

    There are no assertions: the test passes as long as constructing the agent
    and running two episodes of the environment loop raises no error.
    """
    action_count = 5

    # A fake environment stands in for a real task.
    fake_env = fakes.DiscreteEnvironment(
        num_actions=action_count,
        num_observations=10,
        obs_dtype=np.float32,
        episode_length=10)
    env_spec = specs.make_environment_spec(fake_env)

    layers = [
        snt.Flatten(),
        snt.nets.MLP([50, 50]),
        networks.PolicyValueHead(env_spec.actions.num_values),
    ]
    policy_value_net = snt.Sequential(layers)
    env_model = simulator.Simulator(fake_env)
    adam = snt.optimizers.Adam(1e-3)

    # Build the agent under test.
    mcts_agent = mcts.MCTS(
        environment_spec=env_spec,
        network=policy_value_net,
        model=env_model,
        optimizer=adam,
        n_step=1,
        discount=1.,
        replay_capacity=100,
        num_simulations=10,
        batch_size=10)

    # Running without an exception is the only success criterion.
    acme.EnvironmentLoop(fake_env, mcts_agent).run(num_episodes=2)
Ejemplo n.º 3
0
def make_env_and_model() -> Tuple[dm_env.Environment, models.Model]:
  """Build the bsuite 'catch' environment and a model of it.

  The model is constructed from the raw environment, before any wrapper is
  applied: a simulator when `FLAGS.simulator` is set, otherwise an MLP model
  (the learned alternative) built from the environment's spec.

  Returns:
    An `(environment, model)` tuple, where the environment is wrapped in
    `SinglePrecisionWrapper`.
  """
  raw_env = bsuite.load('catch', kwargs={})

  if not FLAGS.simulator:
    env_spec = specs.make_environment_spec(raw_env)
    model = mlp.MLPModel(
        env_spec,
        replay_capacity=1000,
        batch_size=16,
        hidden_sizes=(50,),
    )
  else:
    model = simulator.Simulator(raw_env)  # pytype: disable=attribute-error

  return wrappers.SinglePrecisionWrapper(raw_env), model
Ejemplo n.º 4
0
    def test_simulator_fidelity(self):
        """Tests whether the simulator matches the ground truth.

        For 10 episodes, the environment and its simulator are reset and then
        driven with the same randomly sampled actions until the environment's
        episode ends. The two timesteps are compared after the reset and after
        every step.
        """
        real_env = catch.Catch()
        action_count = real_env.action_spec().num_values
        sim = simulator.Simulator(real_env)

        for _episode in range(10):
            expected = real_env.reset()
            actual = sim.reset()
            self._check_equal(expected, actual)

            while True:
                if expected.last():
                    break
                # Feed the identical action to both sides.
                move = np.random.randint(action_count)
                expected = real_env.step(move)
                actual = sim.step(move)
                self._check_equal(expected, actual)
Ejemplo n.º 5
0
    def test_checkpointing(self):
        """Tests that loading a checkpoint undoes later simulator steps.

        The environment and simulator both take action 2 once, and the
        simulator is checkpointed. The simulator alone is then stepped with
        random actions until its episode ends, and the checkpoint is loaded.
        A further action 2 on both must yield equal timesteps.
        """
        real_env = catch.Catch()
        action_count = real_env.action_spec().num_values
        sim = simulator.Simulator(real_env)

        real_env.reset()
        sim.reset()

        # Advance both by one step, then snapshot the simulator.
        real_env.step(2)
        rollout_step = sim.step(2)
        sim.save_checkpoint()

        # Only the simulator moves on; the real environment stays put.
        while True:
            if rollout_step.last():
                break
            rollout_step = sim.step(np.random.randint(action_count))
        sim.load_checkpoint()

        # The environment is stepped first, then the simulator.
        self._check_equal(real_env.step(2), sim.step(2))
Ejemplo n.º 6
0
def make_env_and_model(
        bsuite_id: str, results_dir: str,
        overwrite: bool) -> Tuple[dm_env.Environment, models.Model]:
    """Build a logged bsuite environment and a model of it.

    The model is constructed from the raw environment, before any wrapper is
    applied: a simulator when `FLAGS.simulator` is set, otherwise an MLP
    model (the learned alternative) built from the environment's spec.

    Args:
        bsuite_id: identifier passed to `bsuite.load_from_id` and to the CSV
            logging wrapper.
        results_dir: forwarded to `csv_logging.wrap_environment`.
        overwrite: forwarded to `csv_logging.wrap_environment`.

    Returns:
        An `(environment, model)` tuple, where the environment is the raw
        one wrapped for CSV logging and then in `SinglePrecisionWrapper`.
    """
    base_env = bsuite.load_from_id(bsuite_id)

    if not FLAGS.simulator:
        env_spec = specs.make_environment_spec(base_env)
        model = mlp.MLPModel(
            env_spec,
            replay_capacity=1000,
            batch_size=16,
            hidden_sizes=(50, ),
        )
    else:
        model = simulator.Simulator(base_env)  # pytype: disable=attribute-error

    # Logging wraps the raw environment first; the precision wrapper goes on
    # the outside.
    logged_env = csv_logging.wrap_environment(base_env, bsuite_id,
                                              results_dir, overwrite)
    return wrappers.SinglePrecisionWrapper(logged_env), model