Example #1
 def make_actor(
     self,
     random_key,
     policy_network,
     adder = None,
     variable_source = None):
   assert variable_source is not None
   if self._config.use_img_encoder:
     return actors.GenericActor(
         actor=policy_network,
         random_key=random_key,
         # Inference happens on CPU, so it's better to move variables there too.
         variable_client=variable_utils.VariableClient(
             variable_source, ['policy', 'img_encoder'], device='cpu'),
         adder=adder,
     )
   else:
     return actors.GenericActor(
         actor=policy_network,
         random_key=random_key,
         # Inference happens on CPU, so it's better to move variables there too.
         variable_client=variable_utils.VariableClient(
             variable_source, ['policy',], device='cpu'),
         adder=adder,
     )
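All of the snippets on this page share one pattern: build a variable_utils.VariableClient that points at a VariableSource (typically a learner), name the variable collections to fetch, and hand the client to an actor. The standalone sketch below isolates just the client; it assumes Acme's acme.testing.fakes.VariableSource test helper and uses illustrative parameter values, so treat it as a sketch rather than a definitive recipe.

# Minimal sketch of the shared VariableClient pattern (assumes the
# acme.testing.fakes test helper; parameter values are illustrative only).
import jax.numpy as jnp
from acme.jax import variable_utils
from acme.testing import fakes

# A fake source that serves its variables under the default 'policy' key.
variable_source = fakes.VariableSource({'w': jnp.zeros(3)})
variable_client = variable_utils.VariableClient(
    variable_source, key='policy', device='cpu')

variable_client.update_and_wait()   # blocking fetch of the latest variables
print(variable_client.params)       # the fetched parameters, kept on CPU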
Example #2
 def make_actor(self,
                random_key,
                policy_network,
                adder=None,
                variable_source=None,
                force_eval_with_q_filter=False):
     assert variable_source is not None
     if self._config.eval_with_q_filter or force_eval_with_q_filter:
         params_to_get = ['policy', 'all_q']
         if self._config.use_img_encoder:
             params_to_get.append('img_encoder')
         return actors.GenericActor(
             actor=policy_network,
             random_key=random_key,
             # Inference happens on CPU, so it's better to move variables there.
             variable_client=variable_utils.VariableClient(variable_source,
                                                           params_to_get,
                                                           device='cpu'),
             adder=adder,
         )
     else:
         params_to_get = ['policy']
         if self._config.use_img_encoder:
             params_to_get.append('img_encoder')
         return actors.GenericActor(
             actor=policy_network,
             random_key=random_key,
             # Inference happens on CPU, so it's better to move variables there.
             variable_client=variable_utils.VariableClient(variable_source,
                                                           params_to_get,
                                                           device='cpu'),
             adder=adder,
         )
Example #3
  def test_feedforward(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

    def policy(inputs: jnp.ndarray):
      return hk.Sequential([
          hk.Flatten(),
          hk.Linear(env_spec.actions.num_values),
          lambda x: jnp.argmax(x, axis=-1),
      ])(inputs)

    policy = hk.transform(policy, apply_rng=True)

    rng = hk.PRNGSequence(1)
    dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    params = policy.init(next(rng), dummy_obs)

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.FeedForwardActor(
        policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
Example #4
def main(_):
    key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(key, 2)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset with next_actions extra.
    transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                        FLAGS.num_demonstrations)
    double_transitions = rlds.transformations.batch(transitions,
                                                    size=2,
                                                    shift=1,
                                                    drop_remainder=True)
    transitions = double_transitions.map(_add_next_action_extras)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = td3.make_networks(environment_spec)

    # Create the learner.
    learner = td3.TD3Learner(
        networks=networks,
        random_key=key_learner,
        discount=FLAGS.discount,
        iterator=demonstrations,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        use_sarsa_target=FLAGS.use_sarsa_target,
        bc_alpha=FLAGS.bc_alpha,
        num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        del key
        return networks.policy_network.apply(params, observation)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Example #5
    def actor_evaluator(
        random_key: networks_lib.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""
        # Create the actor loading the weights from variable source.
        actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
            evaluator_network)
        # Inference happens on CPU, so it's better to move variables there too.
        variable_client = variable_utils.VariableClient(variable_source,
                                                        'policy',
                                                        device='cpu')
        actor = actors.GenericActor(actor_core,
                                    random_key,
                                    variable_client,
                                    backend='cpu')

        # Logger.
        logger = loggers.make_default_logger('evaluator',
                                             steps_key='evaluator_steps')

        # Create the environment.
        environment = environment_factory(False)

        # Create the counter.
        counter = counting.Counter(counter, 'evaluator')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(
            environment,
            actor,
            counter,
            logger,
        )
Example #6
def make_actor(actor_core: ActorCore,
               random_key: networks_lib.PRNGKey,
               variable_source: core.VariableSource,
               adder: Optional[adders.Adder] = None) -> core.Actor:
    """Creates an MBOP actor from an actor core.

    Args:
      actor_core: An MBOP actor core.
      random_key: JAX random key.
      variable_source: The source to get network parameters from.
      adder: An adder to add experiences to. The `extras` of the adder holds
        the state of the recurrent policy. If `has_extras=True`, the `extras`
        part returned from the recurrent policy is appended to the state
        before being added to the adder.

    Returns:
      A recurrent actor.
    """
    variable_client = variable_utils.VariableClient(client=variable_source,
                                                    key=[
                                                        'world_model-policy',
                                                        'policy_prior-policy',
                                                        'n_step_return-policy'
                                                    ])

    return actors.GenericActor(actor_core,
                               random_key,
                               variable_client,
                               adder,
                               backend=None)
Example #7
    def make_actor(
        self,
        random_key: networks_lib.PRNGKey,
        policy: Tuple[str, networks_lib.FeedForwardNetwork],
        environment_spec: specs.EnvironmentSpec,
        variable_source: Optional[core.VariableSource] = None,
        adder: Optional[adders.Adder] = None,
    ) -> acme.Actor:
        del environment_spec
        assert variable_source is not None

        kname, policy = policy

        normalization_apply_fn = (running_statistics.normalize
                                  if self._config.normalize_observations else
                                  (lambda a, b: a))
        policy_to_run = get_policy(policy, normalization_apply_fn)

        actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
            policy_to_run)
        variable_client = variable_utils.VariableClient(variable_source,
                                                        kname,
                                                        device='cpu')
        return actors.GenericActor(actor_core,
                                   random_key,
                                   variable_client,
                                   adder,
                                   backend='cpu',
                                   per_episode_update=True)
Example #8
    def make_actor(
        self,
        random_key,
        policy,
        environment_spec,
        variable_source=None,
        adder=None,
    ):
        assert variable_source is not None

        wrapped_actor = self._rl_agent.make_actor(random_key,
                                                  policy.discrete_policy,
                                                  environment_spec, adder,
                                                  variable_source)
        return actor.AquademActor(
            wrapped_actor=wrapped_actor,
            policy=policy.aquadem_policy,
            # Inference happens on CPU, so it's better to move variables there too.
            variable_client=variable_utils.VariableClient(
                variable_source,
                'aquadem_encoder',
                device='cpu',
                update_period=1000000000),  # never update what does not change
            adder=adder,
        )
Example #9
    def actor_evaluator(
        variable_source: core.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""
        # Create the actor loading the weights from variable source.
        actor = actors.FeedForwardActor(
            policy=evaluator_network,
            random_key=random_key,
            # Inference happens on CPU, so it's better to move variables there too.
            variable_client=variable_utils.VariableClient(variable_source,
                                                          'policy',
                                                          device='cpu'))

        # Logger.
        logger = loggers.make_default_logger('evaluator',
                                             steps_key='evaluator_steps')

        # Create the environment.
        environment = environment_factory(False)

        # Create the counter.
        counter = counting.Counter(counter, 'evaluator')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(
            environment,
            actor,
            counter,
            logger,
        )
Example #10
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks_lib.FeedForwardNetwork,
        config: DQNConfig,
    ):
        """Initialize the agent."""
        # Data is communicated via reverb replay.
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(
            environment_spec=environment_spec,
            n_step=config.n_step,
            batch_size=config.batch_size,
            max_replay_size=config.max_replay_size,
            min_replay_size=config.min_replay_size,
            priority_exponent=config.priority_exponent,
            discount=config.discount,
        )
        self._server = reverb_replay.server

        optimizer = optax.chain(
            optax.clip_by_global_norm(config.max_gradient_norm),
            optax.adam(config.learning_rate),
        )
        key_learner, key_actor = jax.random.split(
            jax.random.PRNGKey(config.seed))
        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            random_key=key_learner,
            optimizer=optimizer,
            discount=config.discount,
            importance_sampling_exponent=config.importance_sampling_exponent,
            target_update_period=config.target_update_period,
            iterator=reverb_replay.data_iterator,
            replay_client=reverb_replay.client,
        )

        # The actor selects actions according to the policy.
        def policy(params: networks_lib.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(config.epsilon).sample(
                key, action_values)

        actor = actors.FeedForwardActor(
            policy=policy,
            rng=hk.PRNGSequence(key_actor),
            variable_client=variable_utils.VariableClient(learner, ''),
            adder=reverb_replay.adder)

        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=max(config.batch_size, config.min_replay_size),
            observations_per_step=config.batch_size /
            config.samples_per_insert,
        )
Example #11
def main(_):
    # Create an environment and grab the spec.
    environment = bc_utils.make_environment()
    environment_spec = specs.make_environment_spec(environment)

    # Unwrap the environment to get the demonstrations.
    dataset = bc_utils.make_demonstrations(environment.environment,
                                           FLAGS.batch_size)
    dataset = dataset.as_numpy_iterator()

    # Create the networks to optimize.
    network = bc_utils.make_network(environment_spec)

    key = jax.random.PRNGKey(FLAGS.seed)
    key, key1 = jax.random.split(key, 2)

    def logp_fn(logits, actions):
        logits_actions = jnp.sum(jax.nn.one_hot(actions, logits.shape[-1]) *
                                 logits,
                                 axis=-1)
        logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
        return logits_actions

    loss_fn = bc.logp(logp_fn=logp_fn)

    learner = bc.BCLearner(network=network,
                           random_key=key1,
                           loss_fn=loss_fn,
                           optimizer=optax.adam(FLAGS.learning_rate),
                           demonstrations=dataset,
                           num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        dist_params = network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
            key, dist_params)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Example #12
 def test_update(self):
     init_fn, _ = hk.without_apply_rng(hk.transform(dummy_network))
     params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32)))
     variable_source = fakes.VariableSource(params)
     variable_client = variable_utils.VariableClient(variable_source,
                                                     key='policy')
     variable_client.update_and_wait()
     tree.map_structure(np.testing.assert_array_equal,
                        variable_client.params, params)
Example #13
def main(_):
    key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(key, 2)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset.
    transitions_iterator = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                                 FLAGS.num_demonstrations)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions_iterator,
        key=key_demonstrations,
        batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = cql.make_networks(environment_spec)

    # Create the learner.
    learner = cql.CQLLearner(
        batch_size=FLAGS.batch_size,
        networks=networks,
        random_key=key_learner,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        fixed_cql_coefficient=FLAGS.fixed_cql_coefficient,
        cql_lagrange_threshold=FLAGS.cql_lagrange_threshold,
        demonstrations=demonstrations,
        num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        dist_params = networks.policy_network.apply(params, observation)
        return networks.sample_eval(dist_params, key)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Example #14
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy_network,
     adder: Optional[adders.Adder] = None,
     variable_source: Optional[core.VariableSource] = None) -> acme.Actor:
   assert variable_source is not None
   actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
       policy_network)
   variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                   device='cpu')
   return actors.GenericActor(
       actor_core, random_key, variable_client, adder, backend='cpu')
Example #15
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy: actor_core_lib.FeedForwardPolicy,
     environment_spec: specs.EnvironmentSpec,
     variable_source: Optional[core.VariableSource] = None,
 ) -> core.Actor:
   del environment_spec
   assert variable_source is not None
   actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
   variable_client = variable_utils.VariableClient(
       variable_source, 'policy', device='cpu')
   return actors.GenericActor(
       actor_core, random_key, variable_client, backend='cpu')
Example #16
    def test_multiple_keys(self):
        init_fn, _ = hk.without_apply_rng(hk.transform(dummy_network))
        params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32)))
        steps = jnp.zeros(shape=1)
        variables = {'network': params, 'steps': steps}
        variable_source = fakes.VariableSource(variables,
                                               use_default_key=False)
        variable_client = variable_utils.VariableClient(
            variable_source, key=['network', 'steps'])
        variable_client.update_and_wait()

        tree.map_structure(np.testing.assert_array_equal,
                           variable_client.params[0], params)
        tree.map_structure(np.testing.assert_array_equal,
                           variable_client.params[1], steps)
Example #17
 def __init__(self,
              wrapped_actor: core.Actor,
              variable_source: core.VariableSource,
              max_abs_observation: Optional[float],
              update_period: int = 1,
              backend: Optional[str] = None):
     self._wrapped_actor = wrapped_actor
     self._variable_client = variable_utils.VariableClient(
         variable_source,
         key=_NORMALIZATION_VARIABLES,
         update_period=update_period,
         device=backend)
     self._apply_normalization = jax.jit(functools.partial(
         running_statistics.normalize, max_abs_value=max_abs_observation),
                                         backend=backend)
Example #18
    def test_recurrent(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)
        output_size = env_spec.actions.num_values
        obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
        rng = hk.PRNGSequence(1)

        @_transform_without_rng
        def network(inputs: jnp.ndarray, state: hk.LSTMState):
            return hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])(inputs, state)

        @_transform_without_rng
        def initial_state(batch_size: Optional[int] = None):
            network = hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])
            return network.initial_state(batch_size)

        initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
        params = network.init(next(rng), obs, initial_state)

        def policy(
                params: jnp.ndarray, key: jnp.ndarray,
                observation: jnp.ndarray,
                core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
            del key  # Unused for test-case deterministic policy.
            action_values, core_state = network.apply(params, observation,
                                                      core_state)
            actions = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return (actions, (action_values, )), core_state
            else:
                return actions, core_state

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        actor = actors.RecurrentActor(policy,
                                      jax.random.PRNGKey(1),
                                      initial_state,
                                      variable_client,
                                      has_extras=has_extras)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Example #19
 def make_actor(
     self,
     policy_network,
     adder: Optional[adders.Adder] = None,
     variable_source: Optional[core.VariableSource] = None,
 ) -> core.Actor:
   assert variable_source is not None
   key, self._random_key = jax.random.split(self._random_key)
   return actors.FeedForwardActor(
       policy=policy_network,
       random_key=key,
       # Inference happens on CPU, so it's better to move variables there too.
       variable_client=variable_utils.VariableClient(variable_source, '',
                                                     device='cpu'),
       adder=adder,
   )
Example #20
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy_network,
     adder: Optional[adders.Adder] = None,
     variable_source: Optional[core.VariableSource] = None,
 ) -> core.Actor:
   assert variable_source is not None
   actor = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
       policy_network)
   variable_client = variable_utils.VariableClient(
       variable_source,
       'network',
       device='cpu',
       update_period=self._config.variable_update_period)
   return actors.GenericActor(
       actor, random_key, variable_client, adder, backend='cpu')
Example #21
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy_network,
     adder: Optional[adders.Adder] = None,
     variable_source: Optional[core.VariableSource] = None,
 ) -> core.Actor:
     assert variable_source is not None
     actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
         policy_network)
     # Inference happens on CPU, so it's better to move variables there too.
     variable_client = variable_utils.VariableClient(variable_source,
                                                     'policy',
                                                     device='cpu')
     return actors.GenericActor(actor_core,
                                random_key,
                                variable_client,
                                adder,
                                backend='cpu')
Example #22
 def make_actor(
         self,
         random_key: networks_lib.PRNGKey,
         policy_network,
         adder: Optional[adders.Adder] = None,
         variable_source: Optional[core.VariableSource] = None
 ) -> acme.Actor:
     variable_client = variable_utils.VariableClient(client=variable_source,
                                                     key='network',
                                                     update_period=1000,
                                                     device='cpu')
     return acting.IMPALAActor(
         forward_fn=policy_network.forward_fn,
         initial_state_init_fn=policy_network.initial_state_init_fn,
         initial_state_fn=policy_network.initial_state_fn,
         variable_client=variable_client,
         adder=adder,
         rng=hk.PRNGSequence(random_key),
     )
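Several of the builders on this page, for instance Examples #8 and #22, also pass update_period so that the client does not contact the variable source on every call; the very large period in Example #8 is used for variables that never change. The sketch below shows that usage in isolation, under the assumption that update() only triggers a real fetch once per update_period calls while update_and_wait() always fetches; the exact semantics may differ between Acme versions.

# Throttled parameter fetching via update_period (assumed semantics; the
# fake source and the numbers below are illustrative only).
import jax.numpy as jnp
from acme.jax import variable_utils
from acme.testing import fakes

variable_source = fakes.VariableSource({'w': jnp.zeros(3)})
variable_client = variable_utils.VariableClient(
    variable_source,
    key='policy',
    update_period=1000,   # at most one real fetch per 1000 update() calls
    device='cpu')
variable_client.update_and_wait()   # initial synchronous fetch

for _ in range(5000):
    variable_client.update()        # cheap most of the time
    _ = variable_client.params      # returns the last fetched values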
Example #23
 def make_actor(self,
                random_key,
                policy_network,
                adder=None,
                variable_source=None):
     assert variable_source is not None
     actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
         policy_network)
     variable_client = variable_utils.VariableClient(variable_source,
                                                     'policy',
                                                     device='cpu')
     if self._config.use_random_actor:
         ACTOR = contrastive_utils.InitiallyRandomActor  # pylint: disable=invalid-name
     else:
         ACTOR = actors.GenericActor  # pylint: disable=invalid-name
     return ACTOR(actor_core,
                  random_key,
                  variable_client,
                  adder,
                  backend='cpu')
Example #24
    def make_actor(
        self,
        random_key: networks_lib.PRNGKey,
        policy: r2d2_actor.R2D2Policy,
        environment_spec: specs.EnvironmentSpec,
        variable_source: Optional[core.VariableSource] = None,
        adder: Optional[adders.Adder] = None,
    ) -> acme.Actor:
        del environment_spec
        # Create variable client.
        variable_client = variable_utils.VariableClient(
            variable_source,
            key='actor_variables',
            update_period=self._config.variable_update_period)

        return actors.GenericActor(policy,
                                   random_key,
                                   variable_client,
                                   adder,
                                   backend='cpu')
Example #25
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy_network: dqn_actor.EpsilonPolicy,
     adder: Optional[adders.Adder] = None,
     variable_source: Optional[core.VariableSource] = None,
 ) -> core.Actor:
     assert variable_source is not None
     # Inference happens on CPU, so it's better to move variables there too.
     variable_client = variable_utils.VariableClient(variable_source,
                                                     '',
                                                     device='cpu')
     epsilon = self._config.epsilon
      epsilons = epsilon if isinstance(epsilon, Sequence) else (epsilon,)
     actor_core = dqn_actor.alternating_epsilons_actor_core(
         policy_network, epsilons=epsilons)
     return actors.GenericActor(actor=actor_core,
                                random_key=random_key,
                                variable_client=variable_client,
                                adder=adder,
                                backend=self._actor_backend)
Example #26
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy: impala_networks.IMPALANetworks,
     environment_spec: specs.EnvironmentSpec,
     variable_source: Optional[core.VariableSource] = None,
     adder: Optional[adders.Adder] = None,
 ) -> acme.Actor:
   del environment_spec
   variable_client = variable_utils.VariableClient(
       client=variable_source,
       key='network',
       update_period=self._config.variable_update_period,
       device='cpu')
   return acting.IMPALAActor(
       forward_fn=policy.forward_fn,
       initial_state_fn=policy.initial_state_fn,
       variable_client=variable_client,
       adder=adder,
       rng=hk.PRNGSequence(random_key),
   )
Example #27
    def make_actor(
            self,
            random_key: networks_lib.PRNGKey,
            policy_network,
            adder: Optional[adders.Adder] = None,
            variable_source: Optional[core.VariableSource] = None
    ) -> acme.Actor:

        # Create variable client.
        variable_client = variable_utils.VariableClient(
            variable_source,
            key='actor_variables',
            update_period=self._config.variable_update_period)

        # TODO(b/186613827) move this to
        # - the actor __init__ function - this is a good place if it is specific
        #   for R2D2.
        # - the EnvironmentLoop - this is a good place if it potentially applies
        #   for all actors.
        #
        # Make sure not to use a random policy after checkpoint restoration by
        # assigning variables before running the environment loop.
        variable_client.update_and_wait()

        initial_state_key1, initial_state_key2, random_key = jax.random.split(
            random_key, 3)
        actor_initial_state_params = self._networks.initial_state.init(
            initial_state_key1, 1)
        actor_initial_state = self._networks.initial_state.apply(
            actor_initial_state_params, initial_state_key2, 1)

        actor_core = r2d2_actor.get_actor_core(policy_network,
                                               actor_initial_state,
                                               self._config.num_epsilons)
        return actors.GenericActor(actor_core,
                                   random_key,
                                   variable_client,
                                   adder,
                                   backend='cpu')
Example #28
    def test_feedforward(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)

        def policy(inputs: jnp.ndarray):
            action_values = hk.Sequential([
                hk.Flatten(),
                hk.Linear(env_spec.actions.num_values),
            ])(inputs)
            action = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return action, (action_values, )
            else:
                return action

        policy = hk.transform(policy)

        rng = hk.PRNGSequence(1)
        dummy_obs = utils.add_batch_dim(utils.zeros_like(
            env_spec.observations))
        params = policy.init(next(rng), dummy_obs)

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        if has_extras:
            actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
                policy.apply)
        else:
            actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
                policy.apply)
        actor = actors.GenericActor(actor_core,
                                    random_key=jax.random.PRNGKey(1),
                                    variable_client=variable_client)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Example #29
  def test_recurrent(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)
    output_size = env_spec.actions.num_values
    obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    rng = hk.PRNGSequence(1)

    @hk.transform
    def network(inputs: jnp.ndarray, state: hk.LSTMState):
      return hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])(inputs, state)

    @hk.transform
    def initial_state(batch_size: int):
      network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
      return network.initial_state(batch_size)

    initial_state = initial_state.apply(initial_state.init(next(rng), 1), 1)
    params = network.init(next(rng), obs, initial_state)

    def policy(
        params: jnp.ndarray,
        key: jnp.ndarray,
        observation: jnp.ndarray,
        core_state: hk.LSTMState
    ) -> Tuple[jnp.ndarray, hk.LSTMState]:
      del key  # Unused for test-case deterministic policy.
      action_values, core_state = network.apply(params, observation, core_state)
      return jnp.argmax(action_values, axis=-1), core_state

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.RecurrentActor(
        policy, hk.PRNGSequence(1), initial_state, variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
Example #30
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: hk.Transformed,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: float = 0.,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        seed: int = 1,
    ):
        """Initialize the agent."""

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=adders.NStepTransitionAdder.signature(
                environment_spec=environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            server_address=address,
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        def policy(params: hk.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(epsilon).sample(key, action_values)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            rng=hk.PRNGSequence(seed),
            optimizer=optax.adam(learning_rate),
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            target_update_period=target_update_period,
            iterator=dataset.as_numpy_iterator(),
            replay_client=reverb.Client(address),
        )

        variable_client = variable_utils.VariableClient(learner, '')

        actor = actors.FeedForwardActor(policy=policy,
                                        rng=hk.PRNGSequence(seed),
                                        variable_client=variable_client,
                                        adder=adder)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)