def test_r2d2(self):
    """Smoke-test: run an R2D2 agent for a few episodes on a fake environment.

    There are deliberately no assertions; the test passes as long as building
    the agent and stepping the environment loop raises no error.
    """
    # Fake environment to exercise the agent against.
    # TODO(b/152596848): Allow R2D2 to deal with integer observations.
    fake_env = fakes.DiscreteEnvironment(
        num_actions=5,
        num_observations=10,
        obs_shape=(10, 4),
        obs_dtype=np.float32,
        episode_length=10,
    )
    env_spec = specs.make_environment_spec(fake_env)

    # Build the agent around a simple network sized from the action spec.
    simple_network = SimpleNetwork(env_spec.actions)
    r2d2_agent = r2d2.R2D2(
        environment_spec=env_spec,
        network=simple_network,
        batch_size=10,
        samples_per_insert=2,
        min_replay_size=10,
        burn_in_length=2,
        trace_length=6,
        replay_period=4,
        checkpoint=False,
    )

    # Drive the environment loop; completing without an exception is the
    # whole success criterion.
    acme.EnvironmentLoop(fake_env, r2d2_agent).run(num_episodes=5)
def get_env_agent():
    """Build the environment and an R2D2 agent configured for it.

    Returns:
        A tuple ``(env_acme, agent)``:
          env_acme: the environment returned by ``make_environmment()``
            (documented as an
            acme.wrappers.observation_action_reward.ObservationActionRewardWrapper).
          agent: the R2D2 agent (documented as
            acme.agents.tf.r2d2.agent.R2D2) built from that environment's
            spec.
    """
    # Environment and the spec derived from it.
    environment = make_environmment()
    spec = acme.make_environment_spec(environment)

    # Atari-style R2D2 network with one output per discrete action, wrapped
    # in the agent.
    agent = r2d2.R2D2(
        environment_spec=spec,
        network=networks.R2D2AtariNetwork(spec.actions.num_values),
        burn_in_length=2,
        trace_length=6,
        replay_period=4,
    )

    return environment, agent
# between the learner and the actor variable_server = reverb.Server(tables=[ reverb.Table(name='variable_server', sampler=reverb.selectors.Fifo(), remover=reverb.selectors.Fifo(), max_size=5, rate_limiter=reverb.rate_limiters.MinSize(1)), ]) variable_server_address = f'localhost:{variable_server.port}' variable_client = RemoteVariableClient.remote('variable_server', variable_server_address) agent = r2d2.R2D2(env_spec, network, burn_in_length=40, trace_length=39, replay_period=40, batch_size=64, target_update_period=2500) replay_server_address = agent._actor._adder._client.server_address variable_client.add.remote(agent._learner._variables) remote_processes = [] epsilon = 0.4 alpha = 7 num_actors = 8 for i in range(num_actors): actor_epsilon = pow(epsilon, 1 + i / (num_actors - 1) * alpha) remote_actor = RemoteRecurrentActor.remote( actor_id=i,