コード例 #1
0
    def test_r2d2(self):
        """Smoke test: R2D2 runs a few episodes on a fake environment.

        There are no assertions; the test passes as long as constructing the
        agent and running the environment loop raises no exception.
        """
        # TODO(b/152596848): Allow R2D2 to deal with integer observations.
        # Until then, the fake environment emits float32 observations.
        fake_env = fakes.DiscreteEnvironment(
            num_actions=5,
            num_observations=10,
            obs_shape=(10, 4),
            obs_dtype=np.float32,
            episode_length=10,
        )
        env_spec = specs.make_environment_spec(fake_env)

        # Small batch / replay / sequence sizes; checkpointing is disabled.
        r2d2_agent = r2d2.R2D2(
            environment_spec=env_spec,
            network=SimpleNetwork(env_spec.actions),
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10,
            burn_in_length=2,
            trace_length=6,
            replay_period=4,
            checkpoint=False,
        )

        # Drive the agent for five episodes; any error fails the test.
        acme.EnvironmentLoop(fake_env, r2d2_agent).run(num_episodes=5)
コード例 #2
0
def get_env_agent():
    """Build the Acme environment and an R2D2 agent configured for it.

    Returns:
        A two-element tuple ``(env_acme, agent)``:
            env_acme (acme.wrappers.observation_action_reward.
                ObservationActionRewardWrapper): the wrapped environment.
            agent (acme.agents.tf.r2d2.agent.R2D2): the R2D2 agent.
    """
    environment = make_environmment()
    environment_spec = acme.make_environment_spec(environment)

    # The Atari R2D2 network is built from the size of the discrete action set.
    num_actions = environment_spec.actions.num_values
    r2d2_agent = r2d2.R2D2(
        environment_spec=environment_spec,
        network=networks.R2D2AtariNetwork(num_actions),
        burn_in_length=2,
        trace_length=6,
        replay_period=4,
    )
    return environment, r2d2_agent
コード例 #3
0
# Reverb server with a single 'variable_server' table, used to pass network
# variables between the learner and the actor(s).
# The table samples and removes items in FIFO order, holds at most 5 items,
# and its MinSize(1) rate limiter blocks sampling until one item is present.
variable_server = reverb.Server(tables=[
    reverb.Table(name='variable_server',
                 sampler=reverb.selectors.Fifo(),
                 remover=reverb.selectors.Fifo(),
                 max_size=5,
                 rate_limiter=reverb.rate_limiters.MinSize(1)),
])
variable_server_address = f'localhost:{variable_server.port}'
# Remote (presumably Ray, given `.remote`) client bound to the table name and
# server address above; its read/write semantics are defined elsewhere.
variable_client = RemoteVariableClient.remote('variable_server',
                                              variable_server_address)

# Local R2D2 agent; its replay server and learner variables are shared with
# the remote actors created below.
agent = r2d2.R2D2(env_spec,
                  network,
                  burn_in_length=40,
                  trace_length=39,
                  replay_period=40,
                  batch_size=64,
                  target_update_period=2500)

# NOTE(review): these two lines reach into private agent internals
# (_actor._adder._client, _learner._variables), which may change between
# Acme versions.
replay_server_address = agent._actor._adder._client.server_address
# Push the learner's current variables through the variable client; the
# returned remote handle is not waited on here.
variable_client.add.remote(agent._learner._variables)

# Presumably collects handles of the remote actors started below.
remote_processes = []
# Per-actor exploration schedule used by the loop below: actor i gets
# epsilon ** (1 + i / (num_actors - 1) * alpha). That expression divides by
# (num_actors - 1), so num_actors must be greater than 1.
epsilon = 0.4
alpha = 7
num_actors = 8
for i in range(num_actors):
    actor_epsilon = pow(epsilon, 1 + i / (num_actors - 1) * alpha)
    remote_actor = RemoteRecurrentActor.remote(
        actor_id=i,