Example No. 1
def main(_):
  env = helpers.make_environment(FLAGS.level)
  env_spec = acme.make_environment_spec(env)
  network = networks.DQNAtariNetwork(env_spec.actions.num_values)

  agent = dqn.DQN(env_spec, network)

  loop = acme.EnvironmentLoop(env, agent)
  loop.run(FLAGS.num_episodes)
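This snippet reads like the body of an absl-style launcher, so FLAGS.level and FLAGS.num_episodes are presumably defined at module scope and main is handed to app.run. A minimal sketch of that assumed scaffolding (flag names taken from the snippet, default values invented purely for illustration):

# Hypothetical scaffolding assumed around Example No. 1.
from absl import app
from absl import flags

flags.DEFINE_string('level', 'PongNoFrameskip-v4', 'Which Atari level to load.')
flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes to run for.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)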
Example No. 2
def get_env_agent():
    """Create env and agent.

    Returns:
        env_acme (acme.wrappers.observation_action_reward.
            ObservationActionRewardWrapper).

        agent (acme.agents.tf.dqn.agent.DQN).
    """
    # Get environment
    env_acme = make_environment()
    env_spec = acme.make_environment_spec(env_acme)

    # Create agent and network
    network = networks.DQNAtariNetwork(env_spec.actions.num_values)
    agent = dqn.DQN(env_spec, network, checkpoint_subpath="./acme")

    return env_acme, agent
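A short usage sketch for the returned pair, following the same EnvironmentLoop pattern as Example No. 1; the episode count here is arbitrary:

# Sketch: drive the agent returned by get_env_agent with Acme's environment loop.
env_acme, agent = get_env_agent()
loop = acme.EnvironmentLoop(env_acme, agent)
loop.run(num_episodes=100)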
Example No. 3
    def test_atari(self):
        """Tests that the agent can run for some steps without crashing."""
        env_factory = lambda x: fakes.fake_atari_wrapped()
        net_factory = lambda spec: networks.DQNAtariNetwork(spec.num_values)

        agent = dqn.DistributedDQN(
            environment_factory=env_factory,
            network_factory=net_factory,
            num_actors=2,
            batch_size=32,
            min_replay_size=32,
            max_replay_size=1000,
        )
        program = agent.build()

        # Stop the learner's background run loop so the test can drive it manually.
        (learner_node,) = program.groups['learner']
        learner_node.disable_run()

        # Launch the Launchpad program in multi-threaded test mode.
        lp.launch(program, launch_type='test_mt')

        # Dereference the learner handle and take a few learner steps directly.
        learner: acme.Learner = learner_node.create_handle().dereference()

        for _ in range(5):
            learner.step()
Example No. 4
def make_network(action_spec: specs.DiscreteArray) -> snt.Module:
    return snt.Sequential([
        lambda x: tf.image.convert_image_dtype(x, tf.float32),
        networks.DQNAtariNetwork(action_spec.num_values)
    ])
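make_network only constructs the Sonnet module; its variables are normally created against an environment spec before use. A small sketch under that assumption, mirroring the tf2_utils.create_variables call that appears further down this page (the helper name init_network and its env_spec argument are illustrative, not part of the original example):

# Illustrative helper: build the network and create its variables from a spec.
def init_network(env_spec):
    net = make_network(env_spec.actions)
    tf2_utils.create_variables(net, [env_spec.observations])
    return net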
Example No. 5
from acme_dist_toolkit.remote_actors import RemoteFeedForwardActor
from acme_dist_toolkit.remote_environments import create_env_fns
from acme_dist_toolkit.remote_variable_client import RemoteVariableClient

ray.init()

# Configure GPUs: enable memory growth and make only the first GPU visible to TF.
gpus = tf.config.list_physical_devices(device_type='GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(device=gpu, enable=True)
tf.config.set_visible_devices(devices=gpus[0], device_type='GPU')

# Create network
env = create_env_fns['atari']('PongNoFrameskip-v4')
env_spec = acme.make_environment_spec(env)
network = networks.DQNAtariNetwork(env_spec.actions.num_values)
tf2_utils.create_variables(network, [env_spec.observations])

# Create a Reverb table that acts as a variable server for sharing parameters
# between the learner and the actor
variable_server = reverb.Server(tables=[
  reverb.Table(
    name='variable_server',
    sampler=reverb.selectors.Fifo(),
    remover=reverb.selectors.Fifo(),
    max_size=20,
    rate_limiter=reverb.rate_limiters.MinSize(1)),
])
variable_server_address = f'localhost:{variable_server.port}'
variable_client = RemoteVariableClient.remote('variable_server',
                                              variable_server_address)
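The snippet above only sets up the Reverb table and the Ray-side client; how acme_dist_toolkit actually publishes parameters is not shown. Purely as an illustration of the underlying Reverb API (and not necessarily the toolkit's own interface), a learner could push its current weights into the table like this:

# Illustration only: publish the network's weights to the 'variable_server' table.
publisher = reverb.Client(variable_server_address)
publisher.insert([v.numpy() for v in network.variables],
                 priorities={'variable_server': 1.0})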
Example No. 6
def DQNAtariActorNetwork(num_actions: int, epsilon: tf.Variable):
    network = networks.DQNAtariNetwork(num_actions)
    return snt.Sequential([
        network,
        lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(),
    ])
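A hedged sketch of how this actor-side network might be wired up, assuming an env_spec like the ones in the earlier examples and Acme's standard TF feed-forward actor:

# Sketch: build the epsilon-greedy policy network and wrap it in a TF actor.
from acme.agents.tf import actors

epsilon = tf.Variable(0.05, trainable=False)
policy_network = DQNAtariActorNetwork(env_spec.actions.num_values, epsilon)
tf2_utils.create_variables(policy_network, [env_spec.observations])
actor = actors.FeedForwardActor(policy_network)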