Example 1: make_actor with optional Q-filter ('all_q') and image-encoder parameters
 def make_actor(self,
                random_key,
                policy_network,
                adder=None,
                variable_source=None,
                force_eval_with_q_filter=False):
     assert variable_source is not None
     params_to_get = ['policy']
     if self._config.eval_with_q_filter or force_eval_with_q_filter:
         params_to_get.append('all_q')
     if self._config.use_img_encoder:
         params_to_get.append('img_encoder')
     return actors.GenericActor(
         actor=policy_network,
         random_key=random_key,
         # Inference happens on CPU, so it's better to move variables there.
         variable_client=variable_utils.VariableClient(
             variable_source, params_to_get, device='cpu'),
         adder=adder,
     )
Example 2: make_actor that optionally fetches image-encoder parameters
 def make_actor(
     self,
     random_key,
     policy_network,
     adder=None,
     variable_source=None):
   assert variable_source is not None
   params_to_get = ['policy']
   if self._config.use_img_encoder:
     params_to_get.append('img_encoder')
   return actors.GenericActor(
       actor=policy_network,
       random_key=random_key,
       # Inference happens on CPU, so it's better to move variables there too.
       variable_client=variable_utils.VariableClient(
           variable_source, params_to_get, device='cpu'),
       adder=adder,
   )
Example 3: offline TD3 training loop with a CPU-backed evaluator
def main(_):
    key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(key, 2)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset with next_actions extra.
    transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                        FLAGS.num_demonstrations)
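    # Batch consecutive transitions in overlapping pairs (size=2, shift=1) so
    # that each transition can be paired with the action that follows it.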
    double_transitions = rlds.transformations.batch(transitions,
                                                    size=2,
                                                    shift=1,
                                                    drop_remainder=True)
    transitions = double_transitions.map(_add_next_action_extras)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = td3.make_networks(environment_spec)

    # Create the learner.
    learner = td3.TD3Learner(
        networks=networks,
        random_key=key_learner,
        discount=FLAGS.discount,
        iterator=demonstrations,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        twin_critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        use_sarsa_target=FLAGS.use_sarsa_target,
        bc_alpha=FLAGS.bc_alpha,
        num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        del key
        return networks.policy_network.apply(params, observation)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Example 4: make_actor with optional observation normalization and per-episode variable updates
    def make_actor(
        self,
        random_key: networks_lib.PRNGKey,
        policy: Tuple[str, networks_lib.FeedForwardNetwork],
        environment_spec: specs.EnvironmentSpec,
        variable_source: Optional[core.VariableSource] = None,
        adder: Optional[adders.Adder] = None,
    ) -> acme.Actor:
        del environment_spec
        assert variable_source is not None

        kname, policy = policy

        normalization_apply_fn = (running_statistics.normalize
                                  if self._config.normalize_observations else
                                  (lambda a, b: a))
        policy_to_run = get_policy(policy, normalization_apply_fn)

        actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
            policy_to_run)
        variable_client = variable_utils.VariableClient(variable_source,
                                                        kname,
                                                        device='cpu')
        return actors.GenericActor(actor_core,
                                   random_key,
                                   variable_client,
                                   adder,
                                   backend='cpu',
                                   per_episode_update=True)
Example 5: evaluator factory that wraps a GenericActor in an EnvironmentLoop
    def actor_evaluator(
        random_key: networks_lib.PRNGKey,
        variable_source: core.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""
        # Create the actor loading the weights from variable source.
        actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
            evaluator_network)
        # Inference happens on CPU, so it's better to move variables there too.
        variable_client = variable_utils.VariableClient(variable_source,
                                                        'policy',
                                                        device='cpu')
        actor = actors.GenericActor(actor_core,
                                    random_key,
                                    variable_client,
                                    backend='cpu')

        # Logger.
        logger = loggers.make_default_logger('evaluator',
                                             steps_key='evaluator_steps')

        # Create the environment.
        environment = environment_factory(False)

        # Create the counter.
        counter = counting.Counter(counter, 'evaluator')

        # Create the run loop and return it.
        return environment_loop.EnvironmentLoop(
            environment,
            actor,
            counter,
            logger,
        )
Example 6: MBOP actor built from an actor core and a multi-key variable client
def make_actor(actor_core: ActorCore,
               random_key: networks_lib.PRNGKey,
               variable_source: core.VariableSource,
               adder: Optional[adders.Adder] = None) -> core.Actor:
    """Creates an MBOP actor from an actor core.

  Args:
    actor_core: An MBOP actor core.
    random_key: JAX Random key.
    variable_source: The source to get networks parameters from.
    adder: An adder to add experiences to. The `extras` of the adder holds the
      state of the recurrent policy. If `has_extras=True` then the `extras` part
      returned from the recurrent policy is appended to the state before added
      to the adder.

  Returns:
    A recurrent actor.
  """
    variable_client = variable_utils.VariableClient(client=variable_source,
                                                    key=[
                                                        'world_model-policy',
                                                        'policy_prior-policy',
                                                        'n_step_return-policy'
                                                    ])

    return actors.GenericActor(actor_core,
                               random_key,
                               variable_client,
                               adder,
                               backend=None)
Example 7: behavioral cloning training loop with an epsilon-greedy evaluator
def main(_):
    # Create an environment and grab the spec.
    environment = bc_utils.make_environment()
    environment_spec = specs.make_environment_spec(environment)

    # Unwrap the environment to get the demonstrations.
    dataset = bc_utils.make_demonstrations(environment.environment,
                                           FLAGS.batch_size)
    dataset = dataset.as_numpy_iterator()

    # Create the networks to optimize.
    network = bc_utils.make_network(environment_spec)

    key = jax.random.PRNGKey(FLAGS.seed)
    key, key1 = jax.random.split(key, 2)

    def logp_fn(logits, actions):
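        # Log-probability of the taken actions: pick each action's logit and
        # subtract the log normalizer (i.e. compute the log-softmax).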
        logits_actions = jnp.sum(jax.nn.one_hot(actions, logits.shape[-1]) *
                                 logits,
                                 axis=-1)
        logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
        return logits_actions

    loss_fn = bc.logp(logp_fn=logp_fn)

    learner = bc.BCLearner(network=network,
                           random_key=key1,
                           loss_fn=loss_fn,
                           optimizer=optax.adam(FLAGS.learning_rate),
                           demonstrations=dataset,
                           num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        dist_params = network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
            key, dist_params)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Example 8: offline CQL training loop with a CPU-backed evaluator
def main(_):
    key = jax.random.PRNGKey(FLAGS.seed)
    key_demonstrations, key_learner = jax.random.split(key, 2)

    # Create an environment and grab the spec.
    environment = gym_helpers.make_environment(task=FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Get a demonstrations dataset.
    transitions_iterator = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                                 FLAGS.num_demonstrations)
    demonstrations = tfds.JaxInMemoryRandomSampleIterator(
        transitions_iterator,
        key=key_demonstrations,
        batch_size=FLAGS.batch_size)

    # Create the networks to optimize.
    networks = cql.make_networks(environment_spec)

    # Create the learner.
    learner = cql.CQLLearner(
        batch_size=FLAGS.batch_size,
        networks=networks,
        random_key=key_learner,
        policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
        critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
        fixed_cql_coefficient=FLAGS.fixed_cql_coefficient,
        cql_lagrange_threshold=FLAGS.cql_lagrange_threshold,
        demonstrations=demonstrations,
        num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        dist_params = networks.policy_network.apply(params, observation)
        return networks.sample_eval(dist_params, key)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Example 9: DQN agent constructor wiring replay, learner and actor together
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      network: networks_lib.FeedForwardNetwork,
      config: dqn_config.DQNConfig,
  ):
    """Initialize the agent."""
    # Data is communicated via reverb replay.
    reverb_replay = replay.make_reverb_prioritized_nstep_replay(
        environment_spec=environment_spec,
        n_step=config.n_step,
        batch_size=config.batch_size,
        max_replay_size=config.max_replay_size,
        min_replay_size=config.min_replay_size,
        priority_exponent=config.priority_exponent,
        discount=config.discount,
    )
    self._server = reverb_replay.server

    optimizer = optax.chain(
        optax.clip_by_global_norm(config.max_gradient_norm),
        optax.adam(config.learning_rate),
    )
    key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))
    # The learner updates the parameters (and initializes them).
    loss_fn = losses.PrioritizedDoubleQLearning(
        discount=config.discount,
        importance_sampling_exponent=config.importance_sampling_exponent,
    )
    learner = learning_lib.SGDLearner(
        network=network,
        loss_fn=loss_fn,
        data_iterator=reverb_replay.data_iterator,
        optimizer=optimizer,
        target_update_period=config.target_update_period,
        random_key=key_learner,
        replay_client=reverb_replay.client,
    )

    # The actor selects actions according to the policy.
    assert not isinstance(config.epsilon, Sequence)
    def policy(params: networks_lib.Params, key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
      action_values = network.apply(params, observation)
      return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)
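
    # Wrap the policy in a batched actor core and read its parameters from the
    # learner.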
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
    variable_client = variable_utils.VariableClient(learner, '')
    actor = actors.GenericActor(
        actor_core, key_actor, variable_client, reverb_replay.adder)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(config.batch_size, config.min_replay_size),
        observations_per_step=config.batch_size / config.samples_per_insert,
    )
Example 10: minimal make_actor for a feed-forward policy
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy_network,
     adder: Optional[adders.Adder] = None,
     variable_source: Optional[core.VariableSource] = None) -> acme.Actor:
   assert variable_source is not None
   actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
       policy_network)
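   # Inference happens on CPU, so it's better to move variables there too.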
   variable_client = variable_utils.VariableClient(variable_source, 'policy',
                                                   device='cpu')
   return actors.GenericActor(
       actor_core, random_key, variable_client, adder, backend='cpu')
Example 11: make_actor for a feed-forward policy without an adder
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy: actor_core_lib.FeedForwardPolicy,
     environment_spec: specs.EnvironmentSpec,
     variable_source: Optional[core.VariableSource] = None,
 ) -> core.Actor:
   del environment_spec
   assert variable_source is not None
   actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
   variable_client = variable_utils.VariableClient(
       variable_source, 'policy', device='cpu')
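   # No adder is passed, so this actor's experience is not written anywhere
   # (an evaluation-style actor).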
   return actors.GenericActor(
       actor_core, random_key, variable_client, backend='cpu')
Example 12: make_actor for a policy with extras and a configurable variable update period
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy_network,
     adder: Optional[adders.Adder] = None,
     variable_source: Optional[core.VariableSource] = None,
 ) -> core.Actor:
   assert variable_source is not None
   actor = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
       policy_network)
   variable_client = variable_utils.VariableClient(
       variable_source,
       'network',
       device='cpu',
       update_period=self._config.variable_update_period)
   return actors.GenericActor(
       actor, random_key, variable_client, adder, backend='cpu')
Example 13: recurrent (LSTM) GenericActor exercised in an environment loop test
    def test_recurrent(self):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)
        output_size = env_spec.actions.num_values
        obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
        rng = hk.PRNGSequence(1)

        @_transform_without_rng
        def network(inputs: jnp.ndarray, state: hk.LSTMState):
            return hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])(inputs, state)

        @_transform_without_rng
        def initial_state(batch_size: Optional[int] = None):
            network = hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])
            return network.initial_state(batch_size)

        initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
        params = network.init(next(rng), obs, initial_state)

        def policy(
                params: jnp.ndarray, key: jnp.ndarray,
                observation: jnp.ndarray,
                core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
            del key  # Unused for test-case deterministic policy.
            action_values, core_state = network.apply(params, observation,
                                                      core_state)
            actions = jnp.argmax(action_values, axis=-1)
            return actions, core_state

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        actor_core = actor_core_lib.batched_recurrent_to_actor_core(
            policy, initial_state)
        actor = actors.GenericActor(actor_core, jax.random.PRNGKey(1),
                                    variable_client)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Example 14: make_actor for a feed-forward policy with an adder
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy_network,
     adder: Optional[adders.Adder] = None,
     variable_source: Optional[core.VariableSource] = None,
 ) -> core.Actor:
     assert variable_source is not None
     actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
         policy_network)
     # Inference happens on CPU, so it's better to move variables there too.
     variable_client = variable_utils.VariableClient(variable_source,
                                                     'policy',
                                                     device='cpu')
     return actors.GenericActor(actor_core,
                                random_key,
                                variable_client,
                                adder,
                                backend='cpu')
Example 15: R2D2 make_actor with a periodically updated variable client
    def make_actor(
        self,
        random_key: networks_lib.PRNGKey,
        policy: r2d2_actor.R2D2Policy,
        environment_spec: specs.EnvironmentSpec,
        variable_source: Optional[core.VariableSource] = None,
        adder: Optional[adders.Adder] = None,
    ) -> acme.Actor:
        del environment_spec
        # Create variable client.
        variable_client = variable_utils.VariableClient(
            variable_source,
            key='actor_variables',
            update_period=self._config.variable_update_period)

        return actors.GenericActor(policy,
                                   random_key,
                                   variable_client,
                                   adder,
                                   backend='cpu')
Example 16: DQN make_actor with alternating exploration epsilons
 def make_actor(
     self,
     random_key: networks_lib.PRNGKey,
     policy_network: dqn_actor.EpsilonPolicy,
     adder: Optional[adders.Adder] = None,
     variable_source: Optional[core.VariableSource] = None,
 ) -> core.Actor:
     assert variable_source is not None
     # Inference happens on CPU, so it's better to move variables there too.
     variable_client = variable_utils.VariableClient(variable_source,
                                                     '',
                                                     device='cpu')
      epsilon = self._config.epsilon
      epsilons = epsilon if isinstance(epsilon, Sequence) else (epsilon,)
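      # The actor core alternates between the given exploration epsilons.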
     actor_core = dqn_actor.alternating_epsilons_actor_core(
         policy_network, epsilons=epsilons)
     return actors.GenericActor(actor=actor_core,
                                random_key=random_key,
                                variable_client=variable_client,
                                adder=adder,
                                backend=self._actor_backend)
Example 17: R2D2 make_actor that initializes the recurrent state and syncs variables before acting
    def make_actor(
            self,
            random_key: networks_lib.PRNGKey,
            policy_network,
            adder: Optional[adders.Adder] = None,
            variable_source: Optional[core.VariableSource] = None
    ) -> acme.Actor:

        # Create variable client.
        variable_client = variable_utils.VariableClient(
            variable_source,
            key='actor_variables',
            update_period=self._config.variable_update_period)

        # TODO(b/186613827) move this to
        # - the actor __init__ function - this is a good place if it is specific
        #   for R2D2.
        # - the EnvironmentLoop - this is a good place if it potentially applies
        #   for all actors.
        #
        # Make sure not to use a random policy after checkpoint restoration by
        # assigning variables before running the environment loop.
        variable_client.update_and_wait()

        initial_state_key1, initial_state_key2, random_key = jax.random.split(
            random_key, 3)
        actor_initial_state_params = self._networks.initial_state.init(
            initial_state_key1, 1)
        actor_initial_state = self._networks.initial_state.apply(
            actor_initial_state_params, initial_state_key2, 1)

        actor_core = r2d2_actor.get_actor_core(policy_network,
                                               actor_initial_state,
                                               self._config.num_epsilons)
        return actors.GenericActor(actor_core,
                                   random_key,
                                   variable_client,
                                   adder,
                                   backend='cpu')
Example 18: feed-forward GenericActor test, with and without policy extras
    def test_feedforward(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)

        def policy(inputs: jnp.ndarray):
            action_values = hk.Sequential([
                hk.Flatten(),
                hk.Linear(env_spec.actions.num_values),
            ])(inputs)
            action = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return action, (action_values, )
            else:
                return action

        policy = hk.transform(policy)

        rng = hk.PRNGSequence(1)
        dummy_obs = utils.add_batch_dim(utils.zeros_like(
            env_spec.observations))
        params = policy.init(next(rng), dummy_obs)

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        if has_extras:
            actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
                policy.apply)
        else:
            actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
                policy.apply)
        actor = actors.GenericActor(actor_core,
                                    random_key=jax.random.PRNGKey(1),
                                    variable_client=variable_client)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)