Example no. 1
    def __init__(self,
                 direct_rl_learner_factory: Callable[
                     [Any, Iterator[reverb.ReplaySample]], acme.Learner],
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optax.GradientTransformation,
                 rnd_network: rnd_networks.RNDNetworks,
                 rng_key: jnp.ndarray,
                 grad_updates_per_batch: int,
                 is_sequence_based: bool,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        self._is_sequence_based = is_sequence_based

        target_key, predictor_key = jax.random.split(rng_key)
        target_params = rnd_network.target.init(target_key)
        predictor_params = rnd_network.predictor.init(predictor_key)
        optimizer_state = optimizer.init(predictor_params)

        self._state = RNDTrainingState(optimizer_state=optimizer_state,
                                       params=predictor_params,
                                       target_params=target_params,
                                       steps=0)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        loss = functools.partial(rnd_loss, networks=rnd_network)
        self._update = functools.partial(rnd_update_step,
                                         loss_fn=loss,
                                         optimizer=optimizer)
        self._update = utils.process_multiple_batches(self._update,
                                                      grad_updates_per_batch)
        self._update = jax.jit(self._update)

        self._get_reward = jax.jit(
            functools.partial(rnd_networks.compute_rnd_reward,
                              networks=rnd_network))

        # Generator expression that works the same as an iterator.
        # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions
        updated_iterator = (self._process_sample(sample)
                            for sample in iterator)

        self._direct_rl_learner = direct_rl_learner_factory(
            rnd_network.direct_rl_networks, updated_iterator)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
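A minimal, self-contained sketch (toy names, not Acme code) of the generator-expression pattern used above: each sample is transformed lazily as the downstream learner pulls from the wrapped iterator.

def _process_sample(sample):
    # Stand-in for the learner's _process_sample: any per-sample transformation.
    return sample * 2

raw_iterator = iter(range(5))
updated_iterator = (_process_sample(sample) for sample in raw_iterator)
assert next(updated_iterator) == 0
assert next(updated_iterator) == 2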
Example no. 2
            #         jax.tree_map(lambda x: jnp.std(x, axis=0),
            #                      transitions.next_observation)))

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray)

        # Iterator on demonstration transitions.
        self._iterator = iterator

        # update_step = utils.process_multiple_batches(update_step,
        #                                              num_sgd_steps_per_step)
        self._update_step_in_initial_bc_iters = utils.process_multiple_batches(
            lambda x, y: update_step(x, y, True), num_sgd_steps_per_step)
        self._update_step_rest = utils.process_multiple_batches(
            lambda x, y: update_step(x, y, False), num_sgd_steps_per_step)

        # Use the JIT compiler.
        self._update_step = jax.jit(update_step)

        def make_initial_state(key):
            """Initialises the training state (parameters and optimiser state)."""
            key_policy, key_q, key = jax.random.split(key, 3)
            devices = jax.local_devices()

            policy_params = networks.policy_network.init(key_policy)
            policy_optimizer_state = policy_optimizer.init(policy_params)
            policy_params = jax.device_put_replicated(policy_params, devices)
            policy_optimizer_state = jax.device_put_replicated(
                policy_optimizer_state, devices)
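A hedged sketch of jax.device_put_replicated as used in the fragment above; the parameter dictionary is illustrative, and the call also works with a single local device.

import jax
import jax.numpy as jnp

devices = jax.local_devices()
params = {'w': jnp.ones((2, 2))}
replicated = jax.device_put_replicated(params, devices)
# Each leaf gains a leading device axis: replicated['w'].shape == (len(devices), 2, 2).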
Example no. 3
  def __init__(self,
               networks: value_dice_networks.ValueDiceNetworks,
               policy_optimizer: optax.GradientTransformation,
               nu_optimizer: optax.GradientTransformation,
               discount: float,
               rng: jnp.ndarray,
               iterator_replay: Iterator[reverb.ReplaySample],
               iterator_demonstrations: Iterator[types.Transition],
               alpha: float = 0.05,
               policy_reg_scale: float = 1e-4,
               nu_reg_scale: float = 10.0,
               num_sgd_steps_per_step: int = 1,
               counter: Optional[counting.Counter] = None,
               logger: Optional[loggers.Logger] = None):

    rng, policy_key, nu_key = jax.random.split(rng, 3)
    policy_init_params = networks.policy_network.init(policy_key)
    policy_optimizer_state = policy_optimizer.init(policy_init_params)

    nu_init_params = networks.nu_network.init(nu_key)
    nu_optimizer_state = nu_optimizer.init(nu_init_params)

    def compute_losses(
        policy_params: networks_lib.Params,
        nu_params: networks_lib.Params,
        key: jnp.ndarray,
        replay_o_tm1: types.NestedArray,
        replay_a_tm1: types.NestedArray,
        replay_o_t: types.NestedArray,
        demo_o_tm1: types.NestedArray,
        demo_a_tm1: types.NestedArray,
        demo_o_t: types.NestedArray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
      # TODO(damienv, hussenot): what to do with the discounts?

      def policy(obs, key):
        dist_params = networks.policy_network.apply(policy_params, obs)
        return networks.sample(dist_params, key)

      key1, key2, key3, key4 = jax.random.split(key, 4)

      # Predicted actions.
      demo_o_t0 = demo_o_tm1
      policy_demo_a_t0 = policy(demo_o_t0, key1)
      policy_demo_a_t = policy(demo_o_t, key2)
      policy_replay_a_t = policy(replay_o_t, key3)

      replay_a_tm1 = networks.encode_action(replay_a_tm1)
      demo_a_tm1 = networks.encode_action(demo_a_tm1)
      policy_demo_a_t0 = networks.encode_action(policy_demo_a_t0)
      policy_demo_a_t = networks.encode_action(policy_demo_a_t)
      policy_replay_a_t = networks.encode_action(policy_replay_a_t)

      # "Value function" nu over the expert states.
      nu_demo_t0 = networks.nu_network.apply(nu_params, demo_o_t0,
                                             policy_demo_a_t0)
      nu_demo_tm1 = networks.nu_network.apply(nu_params, demo_o_tm1, demo_a_tm1)
      nu_demo_t = networks.nu_network.apply(nu_params, demo_o_t,
                                            policy_demo_a_t)
      nu_demo_diff = nu_demo_tm1 - discount * nu_demo_t

      # "Value function" nu over the replay buffer states.
      nu_replay_tm1 = networks.nu_network.apply(nu_params, replay_o_tm1,
                                                replay_a_tm1)
      nu_replay_t = networks.nu_network.apply(nu_params, replay_o_t,
                                              policy_replay_a_t)
      nu_replay_diff = nu_replay_tm1 - discount * nu_replay_t

      # Linear part of the loss.
      linear_loss_demo = jnp.mean(nu_demo_t0 * (1.0 - discount))
      linear_loss_rb = jnp.mean(nu_replay_diff)
      linear_loss = (linear_loss_demo * (1 - alpha) + linear_loss_rb * alpha)

      # Non linear part of the loss.
      nu_replay_demo_diff = jnp.concatenate([nu_demo_diff, nu_replay_diff],
                                            axis=0)
      replay_demo_weights = jnp.concatenate([
          jnp.ones_like(nu_demo_diff) * (1 - alpha),
          jnp.ones_like(nu_replay_diff) * alpha
      ],
                                            axis=0)
      replay_demo_weights /= jnp.mean(replay_demo_weights)
      non_linear_loss = jnp.sum(
          jax.lax.stop_gradient(
              utils.weighted_softmax(nu_replay_demo_diff, replay_demo_weights,
                                     axis=0)) *
          nu_replay_demo_diff)

      # Final loss.
      loss = (non_linear_loss - linear_loss)

      # Regularized policy loss.
      if policy_reg_scale > 0.:
        policy_reg = _orthogonal_regularization_loss(policy_params)
      else:
        policy_reg = 0.

      # Gradient penalty on nu.
      if nu_reg_scale > 0.0:
        batch_size = demo_o_tm1.shape[0]
        c = jax.random.uniform(key4, shape=(batch_size,))
        shape_o = [
            dim if i == 0 else 1 for i, dim in enumerate(replay_o_tm1.shape)
        ]
        shape_a = [
            dim if i == 0 else 1 for i, dim in enumerate(replay_a_tm1.shape)
        ]
        c_o = jnp.reshape(c, shape_o)
        c_a = jnp.reshape(c, shape_a)
        mixed_o_tm1 = c_o * demo_o_tm1 + (1 - c_o) * replay_o_tm1
        mixed_a_tm1 = c_a * demo_a_tm1 + (1 - c_a) * replay_a_tm1
        mixed_o_t = c_o * demo_o_t + (1 - c_o) * replay_o_t
        mixed_policy_a_t = c_a * policy_demo_a_t + (1 - c_a) * policy_replay_a_t
        mixed_o = jnp.concatenate([mixed_o_tm1, mixed_o_t], axis=0)
        mixed_a = jnp.concatenate([mixed_a_tm1, mixed_policy_a_t], axis=0)

        def sum_nu(o, a):
          return jnp.sum(networks.nu_network.apply(nu_params, o, a))

        nu_grad_o_fn = jax.grad(sum_nu, argnums=0)
        nu_grad_a_fn = jax.grad(sum_nu, argnums=1)
        nu_grad_o = nu_grad_o_fn(mixed_o, mixed_a)
        nu_grad_a = nu_grad_a_fn(mixed_o, mixed_a)
        nu_grad = jnp.concatenate([
            jnp.reshape(nu_grad_o, [batch_size, -1]),
            jnp.reshape(nu_grad_a, [batch_size, -1])], axis=-1)
        # TODO(damienv, hussenot): check for the need of eps
        # (like in the original value dice code).
        nu_grad_penalty = jnp.mean(
            jnp.square(
                jnp.linalg.norm(nu_grad + 1e-8, axis=-1, keepdims=True) - 1))
      else:
        nu_grad_penalty = 0.0

      policy_loss = -loss + policy_reg_scale * policy_reg
      nu_loss = loss + nu_reg_scale * nu_grad_penalty

      return policy_loss, nu_loss

    def sgd_step(
        state: TrainingState,
        data: Tuple[types.Transition, types.Transition]
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      replay_transitions, demo_transitions = data
      key, key_loss = jax.random.split(state.key)
      compute_losses_with_input = functools.partial(
          compute_losses,
          replay_o_tm1=replay_transitions.observation,
          replay_a_tm1=replay_transitions.action,
          replay_o_t=replay_transitions.next_observation,
          demo_o_tm1=demo_transitions.observation,
          demo_a_tm1=demo_transitions.action,
          demo_o_t=demo_transitions.next_observation,
          key=key_loss)
      (policy_loss_value, nu_loss_value), vjpfun = jax.vjp(
          compute_losses_with_input,
          state.policy_params, state.nu_params)
      policy_gradients, _ = vjpfun((1.0, 0.0))
      _, nu_gradients = vjpfun((0.0, 1.0))

      # Update optimizers.
      policy_update, policy_optimizer_state = policy_optimizer.update(
          policy_gradients, state.policy_optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, policy_update)

      nu_update, nu_optimizer_state = nu_optimizer.update(
          nu_gradients, state.nu_optimizer_state)
      nu_params = optax.apply_updates(state.nu_params, nu_update)

      new_state = TrainingState(
          policy_optimizer_state=policy_optimizer_state,
          policy_params=policy_params,
          nu_optimizer_state=nu_optimizer_state,
          nu_params=nu_params,
          key=key,
          steps=state.steps + 1,
      )

      metrics = {
          'policy_loss': policy_loss_value,
          'nu_loss': nu_loss_value,
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner',
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray,
        steps_key=self._counter.get_steps_key())

    # Iterator on demonstration transitions.
    self._iterator_demonstrations = iterator_demonstrations
    self._iterator_replay = iterator_replay

    self._sgd_step = jax.jit(utils.process_multiple_batches(
        sgd_step, num_sgd_steps_per_step))

    # Create initial state.
    self._state = TrainingState(
        policy_optimizer_state=policy_optimizer_state,
        policy_params=policy_init_params,
        nu_optimizer_state=nu_optimizer_state,
        nu_params=nu_init_params,
        key=rng,
        steps=0,
    )

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
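A toy illustration (made-up two_losses function, not the ValueDICE objective) of the jax.vjp pattern in sgd_step above: a single forward pass returns a pullback, and seeding it with (1.0, 0.0) or (0.0, 1.0) extracts the gradient of each loss separately.

import jax
import jax.numpy as jnp

def two_losses(params):
    return jnp.sum(params ** 2), jnp.sum(params ** 3)

params = jnp.array([1.0, 2.0])
(loss_a, loss_b), vjp_fn = jax.vjp(two_losses, params)
grad_a, = vjp_fn((1.0, 0.0))  # Gradient of loss_a alone: 2 * params.
grad_b, = vjp_fn((0.0, 1.0))  # Gradient of loss_b alone: 3 * params ** 2.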
Example no. 4
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 loss_fn: LossFn,
                 optimizer: optax.GradientTransformation,
                 data_iterator: Iterator[reverb.ReplaySample],
                 target_update_period: int,
                 random_key: networks_lib.PRNGKey,
                 replay_client: Optional[reverb.Client] = None,
                 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 num_sgd_steps_per_step: int = 1):
        """Initialize the SGD learner."""
        self.network = network

        # Internalize the loss_fn with network.
        self._loss = jax.jit(functools.partial(loss_fn, self.network))

        # SGD performs the loss, optimizer update and periodic target net update.
        def sgd_step(
                state: TrainingState,
                batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
            next_rng_key, rng_key = jax.random.split(state.rng_key)
            # Implements one SGD step of the loss and updates the training state.
            (loss, extra), grads = jax.value_and_grad(
                self._loss, has_aux=True)(state.params, state.target_params,
                                          batch, rng_key)
            extra.metrics.update({'total_loss': loss})

            # Apply the optimizer updates
            updates, new_opt_state = optimizer.update(grads, state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 target_update_period)
            new_training_state = TrainingState(new_params, target_params,
                                               new_opt_state, steps,
                                               next_rng_key)
            return new_training_state, extra

        def postprocess_aux(extra: LossExtra) -> LossExtra:
            reverb_update = jax.tree_map(
                lambda a: jnp.reshape(a, (-1, *a.shape[2:])),
                extra.reverb_update)
            return extra._replace(metrics=jax.tree_map(jnp.mean,
                                                       extra.metrics),
                                  reverb_update=reverb_update)

        self._num_sgd_steps_per_step = num_sgd_steps_per_step
        sgd_step = utils.process_multiple_batches(sgd_step,
                                                  num_sgd_steps_per_step,
                                                  postprocess_aux)
        self._sgd_step = jax.jit(sgd_step)

        # Internalise agent components
        self._data_iterator = utils.prefetch(data_iterator)
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None

        # Initialize the network parameters
        key_params, key_target, key_state = jax.random.split(random_key, 3)
        initial_params = self.network.init(key_params)
        initial_target_params = self.network.init(key_target)
        self._state = TrainingState(
            params=initial_params,
            target_params=initial_target_params,
            opt_state=optimizer.init(initial_params),
            steps=0,
            rng_key=key_state,
        )

        # Update replay priorities
        def update_priorities(reverb_update: ReverbUpdate) -> None:
            if replay_client is None:
                return
            keys, priorities = tree.map_structure(
                utils.fetch_devicearray,
                (reverb_update.keys, reverb_update.priorities))
            replay_client.mutate_priorities(table=replay_table_name,
                                            updates=dict(zip(keys,
                                                             priorities)))

        self._replay_client = replay_client
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)
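A small sketch (toy loss_fn, illustrative names) of jax.value_and_grad with has_aux=True as used in sgd_step above: the loss returns a (scalar, extras) pair and only the scalar is differentiated.

import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    error = params * batch - 1.0
    extra = {'mean_error': jnp.mean(error)}  # Auxiliary metrics, not differentiated.
    return jnp.mean(error ** 2), extra

(loss, extra), grads = jax.value_and_grad(loss_fn, has_aux=True)(2.0, jnp.arange(3.0))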
Example no. 5
    def __init__(self,
                 batch_size: int,
                 networks: CQLNetworks,
                 random_key: networks_lib.PRNGKey,
                 demonstrations: Iterator[types.Transition],
                 policy_optimizer: optax.GradientTransformation,
                 critic_optimizer: optax.GradientTransformation,
                 tau: float = 0.005,
                 fixed_cql_coefficient: Optional[float] = None,
                 cql_lagrange_threshold: Optional[float] = None,
                 cql_num_samples: int = 10,
                 num_sgd_steps_per_step: int = 1,
                 reward_scale: float = 1.0,
                 discount: float = 0.99,
                 fixed_entropy_coefficient: Optional[float] = None,
                 target_entropy: Optional[float] = 0,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes the CQL learner.

    Args:
      batch_size: batch size.
      networks: CQL networks.
      random_key: a key for random number generation.
      demonstrations: an iterator over training data.
      policy_optimizer: the policy optimizer.
      critic_optimizer: the Q-function optimizer.
      tau: target smoothing coefficient.
      fixed_cql_coefficient: the value for cql coefficient. If None, an adaptive
        coefficient will be used.
      cql_lagrange_threshold: a threshold that controls the adaptive loss for
        the cql coefficient.
      cql_num_samples: number of samples used to compute logsumexp(Q) via
        importance sampling.
      num_sgd_steps_per_step: how many gradient updates to perform per batch.
        The batch is split into this many smaller batches, so batch_size should
        be a multiple of num_sgd_steps_per_step.
      reward_scale: reward scale.
      discount: discount to use for TD updates.
      fixed_entropy_coefficient: coefficient applied to the entropy bonus. If
        None, an adaptive coefficient will be used.
      target_entropy: target entropy used when the adaptive entropy bonus is
        enabled.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
    """
        adaptive_entropy_coefficient = fixed_entropy_coefficient is None
        action_spec = networks.environment_specs.actions
        if adaptive_entropy_coefficient:
            # sac_alpha is the temperature parameter that determines the relative
            # importance of the entropy term versus the reward.
            log_sac_alpha = jnp.asarray(0., dtype=jnp.float32)
            alpha_optimizer = optax.adam(learning_rate=3e-4)
            alpha_optimizer_state = alpha_optimizer.init(log_sac_alpha)
        else:
            if target_entropy:
                raise ValueError('target_entropy should not be set when '
                                 'fixed_entropy_coefficient is provided')

        adaptive_cql_coefficient = fixed_cql_coefficient is None
        if adaptive_cql_coefficient:
            log_cql_alpha = jnp.asarray(0., dtype=jnp.float32)
            cql_optimizer = optax.adam(learning_rate=3e-4)
            cql_optimizer_state = cql_optimizer.init(log_cql_alpha)
        else:
            if cql_lagrange_threshold:
                raise ValueError(
                    'cql_lagrange_threshold should not be set when '
                    'fixed_cql_coefficient is provided')

        def alpha_loss(log_sac_alpha: jnp.ndarray,
                       policy_params: networks_lib.Params,
                       transitions: types.Transition,
                       key: jnp.ndarray) -> jnp.ndarray:
            """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
            dist_params = networks.policy_network.apply(
                policy_params, transitions.observation)
            action = networks.sample(dist_params, key)
            log_prob = networks.log_prob(dist_params, action)
            sac_alpha = jnp.exp(log_sac_alpha)
            sac_alpha_loss = sac_alpha * jax.lax.stop_gradient(-log_prob -
                                                               target_entropy)
            return jnp.mean(sac_alpha_loss)

        def sac_critic_loss(q_old_action: jnp.ndarray,
                            policy_params: networks_lib.Params,
                            target_critic_params: networks_lib.Params,
                            sac_alpha: jnp.ndarray,
                            transitions: types.Transition,
                            key: networks_lib.PRNGKey) -> jnp.ndarray:
            """Computes the SAC part of the loss."""
            next_dist_params = networks.policy_network.apply(
                policy_params, transitions.next_observation)
            next_action = networks.sample(next_dist_params, key)
            next_log_prob = networks.log_prob(next_dist_params, next_action)
            next_q = networks.critic_network.apply(
                target_critic_params, transitions.next_observation,
                next_action)
            next_v = jnp.min(next_q, axis=-1) - sac_alpha * next_log_prob
            target_q = jax.lax.stop_gradient(
                transitions.reward * reward_scale +
                transitions.discount * discount * next_v)
            return jnp.mean(
                jnp.square(q_old_action - jnp.expand_dims(target_q, -1)))

        def batched_critic(actions: jnp.ndarray,
                           critic_params: networks_lib.Params,
                           observation: jnp.ndarray) -> jnp.ndarray:
            """Applies the critic network to a batch of sampled actions."""
            actions = jax.lax.stop_gradient(actions)
            tiled_actions = jnp.reshape(actions,
                                        (batch_size * cql_num_samples, -1))
            tiled_states = jnp.tile(observation, [cql_num_samples, 1])
            tiled_q = networks.critic_network.apply(critic_params,
                                                    tiled_states,
                                                    tiled_actions)
            return jnp.reshape(tiled_q, (cql_num_samples, batch_size, -1))

        def cql_critic_loss(q_old_action: jnp.ndarray,
                            critic_params: networks_lib.Params,
                            policy_params: networks_lib.Params,
                            transitions: types.Transition,
                            key: networks_lib.PRNGKey) -> jnp.ndarray:
            """Computes the CQL part of the loss."""
            # The CQL part of the loss is
            #     logsumexp(Q(s,·)) - Q(s,a),
            # where s is the current state and a the action in the dataset (so
            # Q(s, a) is simply q_old_action).
            # We need to estimate logsumexp(Q). This is done with importance sampling
            # (IS). This function implements the unlabeled equation page 29, Appx. F,
            # in https://arxiv.org/abs/2006.04779.
            # Here, IS is done with the uniform distribution and the policy in the
            # current state s. In their implementation, the authors also add the
            # policy at the next state s':
            # https://github.com/aviralkumar2907/CQL/blob/master/d4rl/rlkit/torch/sac/cql.py,
            # (l. 233-236).

            key_policy, key_policy_next, key_uniform = jax.random.split(key, 3)

            def sampled_q(obs, key):
                actions, log_probs = apply_and_sample_n(
                    key, networks, policy_params, obs, cql_num_samples)
                return batched_critic(
                    actions, critic_params,
                    transitions.observation) - jax.lax.stop_gradient(
                        jnp.expand_dims(log_probs, -1))

            # Sample wrt policy in s
            sampled_q_from_policy = sampled_q(transitions.observation,
                                              key_policy)

            # Sample wrt policy in s'
            sampled_q_from_policy_next = sampled_q(
                transitions.next_observation, key_policy_next)

            # Sample wrt uniform
            actions_uniform = jax.random.uniform(
                key_uniform, (cql_num_samples, batch_size) + action_spec.shape,
                minval=action_spec.minimum,
                maxval=action_spec.maximum)
            log_prob_uniform = -jnp.sum(
                jnp.log(action_spec.maximum - action_spec.minimum))
            sampled_q_from_uniform = (batched_critic(
                actions_uniform, critic_params, transitions.observation) -
                                      log_prob_uniform)

            # Combine the samplings
            combined = jnp.concatenate(
                (sampled_q_from_uniform, sampled_q_from_policy,
                 sampled_q_from_policy_next),
                axis=0)
            lse_q = jax.nn.logsumexp(combined,
                                     axis=0,
                                     b=1. / (3 * cql_num_samples))

            return jnp.mean(lse_q - q_old_action)

        def critic_loss(critic_params: networks_lib.Params,
                        policy_params: networks_lib.Params,
                        target_critic_params: networks_lib.Params,
                        sac_alpha: jnp.ndarray, cql_alpha: jnp.ndarray,
                        transitions: types.Transition,
                        key: networks_lib.PRNGKey) -> jnp.ndarray:
            """Computes the full critic loss."""
            key_cql, key_sac = jax.random.split(key, 2)
            q_old_action = networks.critic_network.apply(
                critic_params, transitions.observation, transitions.action)
            cql_loss = cql_critic_loss(q_old_action, critic_params,
                                       policy_params, transitions, key_cql)
            sac_loss = sac_critic_loss(q_old_action, policy_params,
                                       target_critic_params, sac_alpha,
                                       transitions, key_sac)
            return cql_alpha * cql_loss + sac_loss

        def cql_lagrange_loss(log_cql_alpha: jnp.ndarray,
                              critic_params: networks_lib.Params,
                              policy_params: networks_lib.Params,
                              transitions: types.Transition,
                              key: jnp.ndarray) -> jnp.ndarray:
            """Computes the loss that optimizes the cql coefficient."""
            cql_alpha = jnp.exp(log_cql_alpha)
            q_old_action = networks.critic_network.apply(
                critic_params, transitions.observation, transitions.action)
            return -cql_alpha * (cql_critic_loss(
                q_old_action, critic_params, policy_params, transitions, key) -
                                 cql_lagrange_threshold)

        def actor_loss(policy_params: networks_lib.Params,
                       critic_params: networks_lib.Params,
                       sac_alpha: jnp.ndarray, transitions: types.Transition,
                       key: jnp.ndarray) -> jnp.ndarray:
            """Computes the loss for the policy."""
            dist_params = networks.policy_network.apply(
                policy_params, transitions.observation)
            action = networks.sample(dist_params, key)
            log_prob = networks.log_prob(dist_params, action)
            q_action = networks.critic_network.apply(critic_params,
                                                     transitions.observation,
                                                     action)
            min_q = jnp.min(q_action, axis=-1)
            return jnp.mean(sac_alpha * log_prob - min_q)

        alpha_grad = jax.value_and_grad(alpha_loss)
        cql_lagrange_grad = jax.value_and_grad(cql_lagrange_loss)
        critic_grad = jax.value_and_grad(critic_loss)
        actor_grad = jax.value_and_grad(actor_loss)

        def update_step(
            state: TrainingState,
            rb_transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            key, key_alpha, key_critic, key_actor = jax.random.split(
                state.key, 4)

            if adaptive_entropy_coefficient:
                alpha_loss, alpha_grads = alpha_grad(state.log_sac_alpha,
                                                     state.policy_params,
                                                     rb_transitions, key_alpha)
                sac_alpha = jnp.exp(state.log_sac_alpha)
            else:
                sac_alpha = fixed_entropy_coefficient

            if adaptive_cql_coefficient:
                cql_lagrange_loss, cql_lagrange_grads = cql_lagrange_grad(
                    state.log_cql_alpha, state.critic_params,
                    state.policy_params, rb_transitions, key_critic)
                cql_lagrange_grads = jnp.clip(cql_lagrange_grads,
                                              -_CQL_GRAD_CLIPPING_VALUE,
                                              _CQL_GRAD_CLIPPING_VALUE)
                cql_alpha = jnp.exp(state.log_cql_alpha)
                cql_alpha = jnp.clip(cql_alpha,
                                     a_min=0.,
                                     a_max=_CQL_COEFFICIENT_MAX_VALUE)
            else:
                cql_alpha = fixed_cql_coefficient

            critic_loss, critic_grads = critic_grad(state.critic_params,
                                                    state.policy_params,
                                                    state.target_critic_params,
                                                    sac_alpha, cql_alpha,
                                                    rb_transitions, key_critic)
            actor_loss, actor_grads = actor_grad(state.policy_params,
                                                 state.critic_params,
                                                 sac_alpha, rb_transitions,
                                                 key_actor)

            # Apply policy gradients
            actor_update, policy_optimizer_state = policy_optimizer.update(
                actor_grads, state.policy_optimizer_state)
            policy_params = optax.apply_updates(state.policy_params,
                                                actor_update)

            # Apply critic gradients
            critic_update, critic_optimizer_state = critic_optimizer.update(
                critic_grads, state.critic_optimizer_state)
            critic_params = optax.apply_updates(state.critic_params,
                                                critic_update)

            new_target_critic_params = jax.tree_multimap(
                lambda x, y: x * (1 - tau) + y * tau,
                state.target_critic_params, critic_params)

            metrics = {
                'critic_loss': critic_loss,
                'actor_loss': actor_loss,
            }

            new_state = TrainingState(
                policy_optimizer_state=policy_optimizer_state,
                critic_optimizer_state=critic_optimizer_state,
                policy_params=policy_params,
                critic_params=critic_params,
                target_critic_params=new_target_critic_params,
                key=key,
                steps=state.steps + 1,
            )
            if adaptive_entropy_coefficient:
                # Apply sac_alpha gradients
                alpha_update, alpha_optimizer_state = alpha_optimizer.update(
                    alpha_grads, state.alpha_optimizer_state)
                log_sac_alpha = optax.apply_updates(state.log_sac_alpha,
                                                    alpha_update)
                metrics.update({
                    'alpha_loss': alpha_loss,
                    'sac_alpha': jnp.exp(log_sac_alpha),
                })
                new_state = new_state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    log_sac_alpha=log_sac_alpha)

            if adaptive_cql_coefficient:
                # Apply cql coeff gradients
                cql_update, cql_optimizer_state = cql_optimizer.update(
                    cql_lagrange_grads, state.cql_optimizer_state)
                log_cql_alpha = optax.apply_updates(state.log_cql_alpha,
                                                    cql_update)
                metrics.update({
                    'cql_lagrange_loss': cql_lagrange_loss,
                    'cql_alpha': jnp.exp(log_cql_alpha),
                })
                new_state = new_state._replace(
                    cql_optimizer_state=cql_optimizer_state,
                    log_cql_alpha=log_cql_alpha)

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray)

        # Iterator on demonstration transitions.
        self._demonstrations = demonstrations

        # Use the JIT compiler.
        self._update_step = utils.process_multiple_batches(
            update_step, num_sgd_steps_per_step)
        self._update_step = jax.jit(self._update_step)

        # Create initial state.
        key_policy, key_q, training_state_key = jax.random.split(random_key, 3)
        del random_key
        policy_params = networks.policy_network.init(key_policy)
        policy_optimizer_state = policy_optimizer.init(policy_params)
        critic_params = networks.critic_network.init(key_q)
        critic_optimizer_state = critic_optimizer.init(critic_params)

        self._state = TrainingState(
            policy_optimizer_state=policy_optimizer_state,
            critic_optimizer_state=critic_optimizer_state,
            policy_params=policy_params,
            critic_params=critic_params,
            target_critic_params=critic_params,
            key=training_state_key,
            steps=0)

        if adaptive_entropy_coefficient:
            self._state = self._state._replace(
                alpha_optimizer_state=alpha_optimizer_state,
                log_sac_alpha=log_sac_alpha)
        if adaptive_cql_coefficient:
            self._state = self._state._replace(
                cql_optimizer_state=cql_optimizer_state,
                log_cql_alpha=log_cql_alpha)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
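A toy sketch of the weighted logsumexp used in cql_critic_loss above; the zero array merely stands in for the concatenated importance-sampled Q-values.

import jax
import jax.numpy as jnp

cql_num_samples, batch_size = 10, 4
combined = jnp.zeros((3 * cql_num_samples, batch_size, 1))  # Placeholder sampled Q-values.
# b=1/(3*N) averages over the three groups of samples (uniform, policy at s, policy at s').
lse_q = jax.nn.logsumexp(combined, axis=0, b=1. / (3 * cql_num_samples))  # Shape (4, 1).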
Example no. 6
    def __init__(self, networks, rng, policy_optimizer, q_optimizer, iterator,
                 counter, logger, obs_to_goal, config):
        """Initialize the Contrastive RL learner.

    Args:
      networks: Contrastive RL networks.
      rng: a key for random number generation.
      policy_optimizer: the policy optimizer.
      q_optimizer: the Q-function optimizer.
      iterator: an iterator over training data.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      obs_to_goal: a function for extracting the goal coordinates.
      config: the experiment config.
    """
        if config.add_mc_to_td:
            assert config.use_td
        adaptive_entropy_coefficient = config.entropy_coefficient is None
        self._num_sgd_steps_per_step = config.num_sgd_steps_per_step
        self._obs_dim = config.obs_dim
        self._use_td = config.use_td
        if adaptive_entropy_coefficient:
            # alpha is the temperature parameter that determines the relative
            # importance of the entropy term versus the reward.
            log_alpha = jnp.asarray(0., dtype=jnp.float32)
            alpha_optimizer = optax.adam(learning_rate=3e-4)
            alpha_optimizer_state = alpha_optimizer.init(log_alpha)
        else:
            if config.target_entropy:
                raise ValueError('target_entropy should not be set when '
                                 'entropy_coefficient is provided')

        def alpha_loss(log_alpha, policy_params, transitions, key):
            """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
            dist_params = networks.policy_network.apply(
                policy_params, transitions.observation)
            action = networks.sample(dist_params, key)
            log_prob = networks.log_prob(dist_params, action)
            alpha = jnp.exp(log_alpha)
            alpha_loss = alpha * jax.lax.stop_gradient(-log_prob -
                                                       config.target_entropy)
            return jnp.mean(alpha_loss)

        def critic_loss(q_params, policy_params, target_q_params, transitions,
                        key):
            batch_size = transitions.observation.shape[0]
            # Note: we might be able to speed up the computation for some of the
            # baselines by making a single network that returns all the values. This
            # avoids computing some of the underlying representations multiple times.
            if config.use_td:
                # For TD learning, the diagonal elements are the immediate next state.
                s, g = jnp.split(transitions.observation, [config.obs_dim],
                                 axis=1)
                next_s, _ = jnp.split(transitions.next_observation,
                                      [config.obs_dim],
                                      axis=1)
                if config.add_mc_to_td:
                    next_fraction = (1 - config.discount) / (
                        (1 - config.discount) + 1)
                    num_next = int(batch_size * next_fraction)
                    new_g = jnp.concatenate([
                        obs_to_goal(next_s[:num_next]),
                        g[num_next:],
                    ],
                                            axis=0)
                else:
                    new_g = obs_to_goal(next_s)
                obs = jnp.concatenate([s, new_g], axis=1)
                transitions = transitions._replace(observation=obs)
            I = jnp.eye(batch_size)  # pylint: disable=invalid-name
            logits = networks.q_network.apply(q_params,
                                              transitions.observation,
                                              transitions.action)

            if config.use_td:
                # Make sure to use the twin Q trick.
                assert len(logits.shape) == 3

                # We evaluate the next-state Q function using random goals
                s, g = jnp.split(transitions.observation, [config.obs_dim],
                                 axis=1)
                del s
                next_s = transitions.next_observation[:, :config.obs_dim]
                goal_indices = jnp.roll(
                    jnp.arange(batch_size, dtype=jnp.int32), -1)
                g = g[goal_indices]
                transitions = transitions._replace(
                    next_observation=jnp.concatenate([next_s, g], axis=1))
                next_dist_params = networks.policy_network.apply(
                    policy_params, transitions.next_observation)
                next_action = networks.sample(next_dist_params, key)
                next_q = networks.q_network.apply(
                    target_q_params, transitions.next_observation,
                    next_action)  # This outputs logits.
                next_q = jax.nn.sigmoid(next_q)
                next_v = jnp.min(next_q, axis=-1)
                next_v = jax.lax.stop_gradient(next_v)
                next_v = jnp.diag(next_v)
                # diag(logits) are predictions for future states.
                # diag(next_q) are predictions for random states, which correspond to
                # the predictions logits[range(B), goal_indices].
                # So, the only thing that's meaningful for next_q is the diagonal. Off
                # diagonal entries are meaningless and shouldn't be used.
                w = next_v / (1 - next_v)
                w_clipping = 20.0
                w = jnp.clip(w, 0, w_clipping)
                # (B, B, 2) --> (B, 2), computes diagonal of each twin Q.
                pos_logits = jax.vmap(jnp.diag, -1, -1)(logits)
                loss_pos = optax.sigmoid_binary_cross_entropy(
                    logits=pos_logits, labels=1)  # [B, 2]

                neg_logits = logits[jnp.arange(batch_size), goal_indices]
                loss_neg1 = w[:, None] * optax.sigmoid_binary_cross_entropy(
                    logits=neg_logits, labels=1)  # [B, 2]
                loss_neg2 = optax.sigmoid_binary_cross_entropy(
                    logits=neg_logits, labels=0)  # [B, 2]

                if config.add_mc_to_td:
                    loss = ((1 + (1 - config.discount)) * loss_pos +
                            config.discount * loss_neg1 + 2 * loss_neg2)
                else:
                    loss = ((1 - config.discount) * loss_pos +
                            config.discount * loss_neg1 + loss_neg2)
                # Take the mean here so that we can compute the accuracy.
                logits = jnp.mean(logits, axis=-1)

            else:  # For the MC losses.

                def loss_fn(_logits):  # pylint: disable=invalid-name
                    if config.use_cpc:
                        return (optax.softmax_cross_entropy(logits=_logits,
                                                            labels=I) +
                                0.01 * jax.nn.logsumexp(_logits, axis=1)**2)
                    else:
                        return optax.sigmoid_binary_cross_entropy(
                            logits=_logits, labels=I)

                if len(logits.shape) == 3:  # twin q
                    # loss.shape = [.., num_q]
                    loss = jax.vmap(loss_fn, in_axes=2, out_axes=-1)(logits)
                    loss = jnp.mean(loss, axis=-1)
                    # Take the mean here so that we can compute the accuracy.
                    logits = jnp.mean(logits, axis=-1)
                else:
                    loss = loss_fn(logits)

            loss = jnp.mean(loss)
            correct = (jnp.argmax(logits, axis=1) == jnp.argmax(I, axis=1))
            logits_pos = jnp.sum(logits * I) / jnp.sum(I)
            logits_neg = jnp.sum(logits * (1 - I)) / jnp.sum(1 - I)
            if len(logits.shape) == 3:
                logsumexp = jax.nn.logsumexp(logits[:, :, 0], axis=1)**2
            else:
                logsumexp = jax.nn.logsumexp(logits, axis=1)**2
            metrics = {
                'binary_accuracy': jnp.mean((logits > 0) == I),
                'categorical_accuracy': jnp.mean(correct),
                'logits_pos': logits_pos,
                'logits_neg': logits_neg,
                'logsumexp': logsumexp.mean(),
            }

            return loss, metrics

        def actor_loss(
            policy_params,
            q_params,
            alpha,
            transitions,
            key,
        ):
            obs = transitions.observation
            if config.use_gcbc:
                dist_params = networks.policy_network.apply(policy_params, obs)
                log_prob = networks.log_prob(dist_params, transitions.action)
                actor_loss = -1.0 * jnp.mean(log_prob)
            else:
                state = obs[:, :config.obs_dim]
                goal = obs[:, config.obs_dim:]

                if config.random_goals == 0.0:
                    new_state = state
                    new_goal = goal
                elif config.random_goals == 0.5:
                    new_state = jnp.concatenate([state, state], axis=0)
                    new_goal = jnp.concatenate(
                        [goal, jnp.roll(goal, 1, axis=0)], axis=0)
                else:
                    assert config.random_goals == 1.0
                    new_state = state
                    new_goal = jnp.roll(goal, 1, axis=0)

                new_obs = jnp.concatenate([new_state, new_goal], axis=1)
                dist_params = networks.policy_network.apply(
                    policy_params, new_obs)
                action = networks.sample(dist_params, key)
                log_prob = networks.log_prob(dist_params, action)
                q_action = networks.q_network.apply(q_params, new_obs, action)
                if len(q_action.shape) == 3:  # twin q trick
                    assert q_action.shape[2] == 2
                    q_action = jnp.mean(q_action, axis=-1)
                actor_loss = alpha * log_prob - jnp.diag(q_action)

            return jnp.mean(actor_loss)

        alpha_grad = jax.value_and_grad(alpha_loss)
        critic_grad = jax.value_and_grad(critic_loss, has_aux=True)
        actor_grad = jax.value_and_grad(actor_loss)

        def update_step(
            state,
            transitions,
        ):

            key, key_alpha, key_critic, key_actor = jax.random.split(
                state.key, 4)
            if adaptive_entropy_coefficient:
                alpha_loss, alpha_grads = alpha_grad(state.alpha_params,
                                                     state.policy_params,
                                                     transitions, key_alpha)
                alpha = jnp.exp(state.alpha_params)
            else:
                alpha = config.entropy_coefficient

            if not config.use_gcbc:
                (critic_loss, critic_metrics), critic_grads = critic_grad(
                    state.q_params, state.policy_params, state.target_q_params,
                    transitions, key_critic)

            actor_loss, actor_grads = actor_grad(state.policy_params,
                                                 state.q_params, alpha,
                                                 transitions, key_actor)

            # Apply policy gradients
            actor_update, policy_optimizer_state = policy_optimizer.update(
                actor_grads, state.policy_optimizer_state)
            policy_params = optax.apply_updates(state.policy_params,
                                                actor_update)

            # Apply critic gradients
            if config.use_gcbc:
                metrics = {}
                critic_loss = 0.0
                q_params = state.q_params
                q_optimizer_state = state.q_optimizer_state
                new_target_q_params = state.target_q_params
            else:
                critic_update, q_optimizer_state = q_optimizer.update(
                    critic_grads, state.q_optimizer_state)

                q_params = optax.apply_updates(state.q_params, critic_update)

                new_target_q_params = jax.tree_multimap(
                    lambda x, y: x * (1 - config.tau) + y * config.tau,
                    state.target_q_params, q_params)
                metrics = critic_metrics

            metrics.update({
                'critic_loss': critic_loss,
                'actor_loss': actor_loss,
            })

            new_state = TrainingState(
                policy_optimizer_state=policy_optimizer_state,
                q_optimizer_state=q_optimizer_state,
                policy_params=policy_params,
                q_params=q_params,
                target_q_params=new_target_q_params,
                key=key,
            )
            if adaptive_entropy_coefficient:
                # Apply alpha gradients
                alpha_update, alpha_optimizer_state = alpha_optimizer.update(
                    alpha_grads, state.alpha_optimizer_state)
                alpha_params = optax.apply_updates(state.alpha_params,
                                                   alpha_update)
                metrics.update({
                    'alpha_loss': alpha_loss,
                    'alpha': jnp.exp(alpha_params),
                })
                new_state = new_state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    alpha_params=alpha_params)

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            time_delta=10.0)

        # Iterator over training transitions.
        self._iterator = iterator

        update_step = utils.process_multiple_batches(
            update_step, config.num_sgd_steps_per_step)
        # Use the JIT compiler.
        if config.jit:
            self._update_step = jax.jit(update_step)
        else:
            self._update_step = update_step

        def make_initial_state(key):
            """Initialises the training state (parameters and optimiser state)."""
            key_policy, key_q, key = jax.random.split(key, 3)

            policy_params = networks.policy_network.init(key_policy)
            policy_optimizer_state = policy_optimizer.init(policy_params)

            q_params = networks.q_network.init(key_q)
            q_optimizer_state = q_optimizer.init(q_params)

            state = TrainingState(
                policy_optimizer_state=policy_optimizer_state,
                q_optimizer_state=q_optimizer_state,
                policy_params=policy_params,
                q_params=q_params,
                target_q_params=q_params,
                key=key)

            if adaptive_entropy_coefficient:
                state = state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    alpha_params=log_alpha)
            return state

        # Create initial state.
        self._state = make_initial_state(rng)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online
        # and fill the replay buffer.
        self._timestamp = None
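A toy sketch (random logits, illustrative only) of the contrastive classification objective above: a [B, B] logit matrix is scored against the identity matrix, so same-index pairs act as positives and all other pairs as negatives.

import jax
import jax.numpy as jnp
import optax

batch_size = 4
logits = jax.random.normal(jax.random.PRNGKey(0), (batch_size, batch_size))
labels = jnp.eye(batch_size)  # Positives on the diagonal, negatives elsewhere.
loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits=logits, labels=labels))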
Example no. 7
    def __init__(self,
                 counter: counting.Counter,
                 direct_rl_learner_factory: Callable[
                     [Iterator[reverb.ReplaySample]], acme.Learner],
                 loss_fn: losses.Loss,
                 iterator: Iterator[AILSample],
                 discriminator_optimizer: optax.GradientTransformation,
                 ail_network: ail_networks.AILNetworks,
                 discriminator_key: networks_lib.PRNGKey,
                 is_sequence_based: bool,
                 num_sgd_steps_per_step: int = 1,
                 policy_variable_name: Optional[str] = None,
                 logger: Optional[loggers.Logger] = None):
        """AIL Learner.

    Args:
      counter: Counter.
      direct_rl_learner_factory: Function that creates the direct RL learner
        when passed a replay sample iterator.
      loss_fn: Discriminator loss.
      iterator: Iterator that returns AILSamples.
      discriminator_optimizer: Discriminator optax optimizer.
      ail_network: AIL networks.
      discriminator_key: RNG key.
      is_sequence_based: If True, the direct RL algorithm uses the SequenceAdder
        data format; otherwise the learner assumes it uses
        NStepTransitionAdder.
      num_sgd_steps_per_step: Number of discriminator gradient updates per step.
      policy_variable_name: The name of the policy variable to retrieve
        direct_rl policy parameters.
      logger: Logger.
    """
        self._is_sequence_based = is_sequence_based

        state_key, networks_key = jax.random.split(discriminator_key)

        # Generator expression that works the same as an iterator.
        # https://pymbook.readthedocs.io/en/latest/igd.html#generator-expressions
        iterator, direct_rl_iterator = itertools.tee(iterator)
        direct_rl_iterator = (self._process_sample(sample.direct_sample)
                              for sample in direct_rl_iterator)
        self._direct_rl_learner = direct_rl_learner_factory(direct_rl_iterator)

        self._iterator = iterator

        if policy_variable_name is not None:

            def get_policy_params():
                return self._direct_rl_learner.get_variables(
                    [policy_variable_name])[0]

            self._get_policy_params = get_policy_params

        else:
            self._get_policy_params = lambda: None

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray)

        # Use the JIT compiler.
        self._update_step = functools.partial(
            ail_update_step,
            optimizer=discriminator_optimizer,
            ail_network=ail_network,
            loss_fn=loss_fn)
        self._update_step = utils.process_multiple_batches(
            self._update_step, num_sgd_steps_per_step)
        self._update_step = jax.jit(self._update_step)

        discriminator_params, discriminator_state = (
            ail_network.discriminator_network.init(networks_key))
        self._state = DiscriminatorTrainingState(
            optimizer_state=discriminator_optimizer.init(discriminator_params),
            discriminator_params=discriminator_params,
            discriminator_state=discriminator_state,
            policy_params=self._get_policy_params(),
            key=state_key,
            steps=0,
        )

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None

        self._get_reward = jax.jit(
            functools.partial(ail_networks.compute_ail_reward,
                              networks=ail_network))
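A self-contained sketch of the itertools.tee pattern above: the incoming sample iterator is duplicated so the discriminator and the direct RL learner each consume an independent copy.

import itertools

samples = iter(range(3))
discriminator_iterator, direct_rl_iterator = itertools.tee(samples)
assert list(discriminator_iterator) == [0, 1, 2]
assert list(direct_rl_iterator) == [0, 1, 2]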
Example no. 8
  def __init__(self,
               policy_network: networks_lib.FeedForwardNetwork,
               critic_network: networks_lib.FeedForwardNetwork,
               random_key: networks_lib.PRNGKey,
               discount: float,
               target_update_period: int,
               iterator: Iterator[reverb.ReplaySample],
               policy_optimizer: Optional[optax.GradientTransformation] = None,
               critic_optimizer: Optional[optax.GradientTransformation] = None,
               clipping: bool = True,
               counter: Optional[counting.Counter] = None,
               logger: Optional[loggers.Logger] = None,
               jit: bool = True,
               num_sgd_steps_per_step: int = 1):

    def critic_mean(
        critic_params: networks_lib.Params,
        observation: types.NestedArray,
        action: types.NestedArray,
    ) -> jnp.ndarray:
      # We add a batch dimension to make sure the batch concat in critic_network
      # works correctly.
      observation = utils.add_batch_dim(observation)
      action = utils.add_batch_dim(action)
      # Computes the mean action-value estimate.
      logits, atoms = critic_network.apply(critic_params, observation, action)
      logits = utils.squeeze_batch_dim(logits)
      probabilities = jax.nn.softmax(logits)
      return jnp.sum(probabilities * atoms, axis=-1)

    def policy_loss(
        policy_params: networks_lib.Params,
        critic_params: networks_lib.Params,
        o_t: types.NestedArray,
    ) -> jnp.ndarray:
      # Computes the deterministic policy gradient (DPG) loss.
      dpg_a_t = policy_network.apply(policy_params, o_t)
      grad_critic = jax.vmap(
          jax.grad(critic_mean, argnums=2), in_axes=(None, 0, 0))
      dq_da = grad_critic(critic_params, o_t, dpg_a_t)
      dqda_clipping = 1. if clipping else None
      batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0, None))
      loss = batch_dpg_learning(dpg_a_t, dq_da, dqda_clipping)
      return jnp.mean(loss)

    def critic_loss(
        critic_params: networks_lib.Params,
        state: TrainingState,
        transition: types.Transition,
    ):
      # Computes the distributional critic loss.
      q_tm1, atoms_tm1 = critic_network.apply(critic_params,
                                              transition.observation,
                                              transition.action)
      a = policy_network.apply(state.target_policy_params,
                               transition.next_observation)
      q_t, atoms_t = critic_network.apply(state.target_critic_params,
                                          transition.next_observation, a)
      batch_td_learning = jax.vmap(
          rlax.categorical_td_learning, in_axes=(None, 0, 0, 0, None, 0))
      loss = batch_td_learning(atoms_tm1, q_tm1, transition.reward,
                               discount * transition.discount, atoms_t, q_t)
      return jnp.mean(loss)

    def sgd_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

      # TODO(jaslanides): Use a shared forward pass for efficiency.
      policy_loss_and_grad = jax.value_and_grad(policy_loss)
      critic_loss_and_grad = jax.value_and_grad(critic_loss)

      # Compute losses and their gradients.
      policy_loss_value, policy_gradients = policy_loss_and_grad(
          state.policy_params, state.critic_params,
          transitions.next_observation)
      critic_loss_value, critic_gradients = critic_loss_and_grad(
          state.critic_params, state, transitions)

      # Get optimizer updates and state.
      policy_updates, policy_opt_state = policy_optimizer.update(  # pytype: disable=attribute-error
          policy_gradients, state.policy_opt_state)
      critic_updates, critic_opt_state = critic_optimizer.update(  # pytype: disable=attribute-error
          critic_gradients, state.critic_opt_state)

      # Apply optimizer updates to parameters.
      policy_params = optax.apply_updates(state.policy_params, policy_updates)
      critic_params = optax.apply_updates(state.critic_params, critic_updates)

      steps = state.steps + 1

      # Periodically update target networks.
      target_policy_params, target_critic_params = optax.periodic_update(
          (policy_params, critic_params),
          (state.target_policy_params, state.target_critic_params), steps,
          self._target_update_period)

      new_state = TrainingState(
          policy_params=policy_params,
          critic_params=critic_params,
          target_policy_params=target_policy_params,
          target_critic_params=target_critic_params,
          policy_opt_state=policy_opt_state,
          critic_opt_state=critic_opt_state,
          steps=steps,
      )

      metrics = {
          'policy_loss': policy_loss_value,
          'critic_loss': critic_loss_value,
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner',
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray,
        steps_key=self._counter.get_steps_key())

    # Necessary to track when to update target networks.
    self._target_update_period = target_update_period

    # Create prefetching dataset iterator.
    self._iterator = iterator

    # Maybe use the JIT compiler.
    sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step)
    self._sgd_step = jax.jit(sgd_step) if jit else sgd_step

    # Create the network parameters and copy into the target network parameters.
    key_policy, key_critic = jax.random.split(random_key)
    initial_policy_params = policy_network.init(key_policy)
    initial_critic_params = critic_network.init(key_critic)
    initial_target_policy_params = initial_policy_params
    initial_target_critic_params = initial_critic_params

    # Create optimizers if they aren't given.
    critic_optimizer = critic_optimizer or optax.adam(1e-4)
    policy_optimizer = policy_optimizer or optax.adam(1e-4)

    # Initialize optimizers.
    initial_policy_opt_state = policy_optimizer.init(initial_policy_params)  # pytype: disable=attribute-error
    initial_critic_opt_state = critic_optimizer.init(initial_critic_params)  # pytype: disable=attribute-error

    # Create initial state.
    self._state = TrainingState(
        policy_params=initial_policy_params,
        target_policy_params=initial_target_policy_params,
        critic_params=initial_critic_params,
        target_critic_params=initial_target_critic_params,
        policy_opt_state=initial_policy_opt_state,
        critic_opt_state=initial_critic_opt_state,
        steps=0,
    )

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
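
The critic in this learner is distributional: `critic_mean` recovers a scalar action-value by taking the expectation of the categorical distribution, i.e. softmax(logits) weighted by the atom locations. A toy sketch of that reduction with made-up atoms and logits:

import jax
import jax.numpy as jnp

# Hypothetical categorical value distribution: 5 atoms evenly spaced on [-2, 2].
atoms = jnp.linspace(-2.0, 2.0, num=5)
logits = jnp.array([0.1, 0.2, 1.5, 0.2, 0.1])

probabilities = jax.nn.softmax(logits)
mean_value = jnp.sum(probabilities * atoms, axis=-1)
print(mean_value)  # 0.0 for this symmetric toy distribution
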
Esempio n. 9
0
    def __init__(self,
                 network: networks_lib.FeedForwardNetwork,
                 random_key: networks_lib.PRNGKey,
                 loss_fn: losses.Loss,
                 optimizer: optax.GradientTransformation,
                 prefetching_iterator: Iterator[types.Transition],
                 num_sgd_steps_per_step: int,
                 loss_has_aux: bool = False,
                 logger: Optional[loggers.Logger] = None,
                 counter: Optional[counting.Counter] = None):
        """Behavior Cloning Learner.

    Args:
      network: Networks with signature for apply: (params, obs, is_training,
        key) -> jnp.ndarray and for init: (rng, is_training) -> params
      random_key: RNG key.
      loss_fn: BC loss to use.
      optimizer: Optax optimizer.
      prefetching_iterator: A sharded prefetching iterator, as returned by
        `acme.jax.utils.sharded_prefetch`. Please see the documentation for
        `sharded_prefetch` for more details.
      num_sgd_steps_per_step: Number of gradient updates per step.
      loss_has_aux: Whether the loss function returns auxiliary metrics as a
        second argument.
      logger: Logger.
      counter: Counter.
    """
        def sgd_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            loss_and_grad = jax.value_and_grad(loss_fn,
                                               argnums=1,
                                               has_aux=loss_has_aux)

            # Compute losses and their gradients.
            key, key_input = jax.random.split(state.key)
            loss_result, gradients = loss_and_grad(network.apply,
                                                   state.policy_params,
                                                   key_input, transitions)

            # Combine the gradient across all devices (by taking their mean).
            gradients = jax.lax.pmean(gradients, axis_name=_PMAP_AXIS_NAME)

            # Compute and combine metrics across all devices.
            metrics = _create_loss_metrics(loss_has_aux, loss_result,
                                           gradients)
            metrics = jax.lax.pmean(metrics, axis_name=_PMAP_AXIS_NAME)

            policy_update, optimizer_state = optimizer.update(
                gradients, state.optimizer_state, state.policy_params)
            policy_params = optax.apply_updates(state.policy_params,
                                                policy_update)

            new_state = TrainingState(
                optimizer_state=optimizer_state,
                policy_params=policy_params,
                key=key,
                steps=state.steps + 1,
            )

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter(prefix='learner')
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Split the input batch to `num_sgd_steps_per_step` minibatches in order
        # to achieve better performance on accelerators.
        sgd_step = utils.process_multiple_batches(sgd_step,
                                                  num_sgd_steps_per_step)
        self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)

        random_key, init_key = jax.random.split(random_key)
        policy_params = network.init(init_key)
        optimizer_state = optimizer.init(policy_params)

        # Create initial state.
        state = TrainingState(
            optimizer_state=optimizer_state,
            policy_params=policy_params,
            key=random_key,
            steps=0,
        )
        self._state = utils.replicate_in_all_devices(state)

        self._timestamp = None

        self._prefetching_iterator = prefetching_iterator
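
A toy sketch of the pmap/pmean pattern this learner relies on: the step function is mapped over a named device axis and gradients are averaged across devices with `jax.lax.pmean`. The quadratic loss and data are made up; it runs on however many local devices are available, including a single one.

import jax
import jax.numpy as jnp

AXIS = 'devices'

def loss_fn(params, batch):
    return jnp.mean((params * batch) ** 2)

def step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Combine the gradient across all devices (by taking their mean).
    grads = jax.lax.pmean(grads, axis_name=AXIS)
    return params - 0.1 * grads

num_devices = jax.local_device_count()
params = jnp.ones(())  # toy scalar "parameters"
replicated_params = jax.device_put_replicated(params, jax.local_devices())
batches = jnp.arange(num_devices * 4, dtype=jnp.float32).reshape(num_devices, 4)

p_step = jax.pmap(step, axis_name=AXIS)
new_params = p_step(replicated_params, batches)
print(new_params.shape)  # (num_devices,): one identical copy per device
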
Esempio n. 10
0
        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray)

        # Iterator on demonstration transitions.
        self._iterator = iterator

        # update_step = utils.process_multiple_batches(update_step,
        #                                              num_sgd_steps_per_step)
        # # Use the JIT compiler.
        # self._update_step = jax.jit(update_step)

        self._update_step_in_initial_bc_iters = jax.jit(
            utils.process_multiple_batches(
                lambda x, y: _full_update_step(x, y, True),
                num_sgd_steps_per_step))
        self._update_step_rest = jax.jit(
            utils.process_multiple_batches(
                lambda x, y: _full_update_step(x, y, False),
                num_sgd_steps_per_step))

        def make_initial_state(key):
            """Initialises the training state (parameters and optimiser state)."""
            key_policy, key_q, key = jax.random.split(key, 3)

            policy_params = networks.policy_network.init(key_policy)
            policy_optimizer_state = policy_optimizer.init(policy_params)

            q_params = networks.q_network.init(key_q)
            q_optimizer_state = q_optimizer.init(q_params)
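
The two jitted update functions above bake a plain Python boolean into their closures, so each variant is traced and compiled once instead of passing the flag through as a traced argument. A minimal sketch of that pattern with a toy update function (the learning rates are arbitrary):

import jax
import jax.numpy as jnp

def update(params, batch, in_initial_bc_iters):
    # The flag is a Python bool, so the branch is resolved at trace time.
    if in_initial_bc_iters:
        return params - 0.1 * jnp.mean(batch)
    return params - 0.01 * jnp.mean(batch)

# Two specialised, separately compiled variants.
update_bc = jax.jit(lambda p, b: update(p, b, True))
update_rest = jax.jit(lambda p, b: update(p, b, False))

params = jnp.zeros(())
batch = jnp.ones((8,))
print(update_bc(params, batch), update_rest(params, batch))  # -0.1 -0.01
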
Esempio n. 11
0
    def __init__(self,
                 networks: sac_networks.SACNetworks,
                 rng: jnp.ndarray,
                 iterator: Iterator[reverb.ReplaySample],
                 policy_optimizer: optax.GradientTransformation,
                 q_optimizer: optax.GradientTransformation,
                 tau: float = 0.005,
                 reward_scale: float = 1.0,
                 discount: float = 0.99,
                 entropy_coefficient: Optional[float] = None,
                 target_entropy: float = 0,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 num_sgd_steps_per_step: int = 1):
        """Initialize the SAC learner.

    Args:
      networks: SAC networks
      rng: a key for random number generation.
      iterator: an iterator over training data.
      policy_optimizer: the policy optimizer.
      q_optimizer: the Q-function optimizer.
      tau: target smoothing coefficient.
      reward_scale: reward scale.
      discount: discount to use for TD updates.
      entropy_coefficient: coefficient applied to the entropy bonus. If None, an
        adaptive coefficient will be used.
      target_entropy: Used to normalize entropy. Only used when
        entropy_coefficient is None.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'.
    """
        adaptive_entropy_coefficient = entropy_coefficient is None
        if adaptive_entropy_coefficient:
            # alpha is the temperature parameter that determines the relative
            # importance of the entropy term versus the reward.
            log_alpha = jnp.asarray(0., dtype=jnp.float32)
            alpha_optimizer = optax.adam(learning_rate=3e-4)
            alpha_optimizer_state = alpha_optimizer.init(log_alpha)
        else:
            if target_entropy:
                raise ValueError('target_entropy should not be set when '
                                 'entropy_coefficient is provided')

        def alpha_loss(log_alpha: jnp.ndarray,
                       policy_params: networks_lib.Params,
                       transitions: types.Transition,
                       key: networks_lib.PRNGKey) -> jnp.ndarray:
            """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
            dist_params = networks.policy_network.apply(
                policy_params, transitions.observation)
            action = networks.sample(dist_params, key)
            log_prob = networks.log_prob(dist_params, action)
            alpha = jnp.exp(log_alpha)
            alpha_loss = alpha * jax.lax.stop_gradient(-log_prob -
                                                       target_entropy)
            return jnp.mean(alpha_loss)

        def critic_loss(q_params: networks_lib.Params,
                        policy_params: networks_lib.Params,
                        target_q_params: networks_lib.Params,
                        alpha: jnp.ndarray, transitions: types.Transition,
                        key: networks_lib.PRNGKey) -> jnp.ndarray:
            q_old_action = networks.q_network.apply(q_params,
                                                    transitions.observation,
                                                    transitions.action)
            next_dist_params = networks.policy_network.apply(
                policy_params, transitions.next_observation)
            next_action = networks.sample(next_dist_params, key)
            next_log_prob = networks.log_prob(next_dist_params, next_action)
            next_q = networks.q_network.apply(target_q_params,
                                              transitions.next_observation,
                                              next_action)
            next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob
            target_q = jax.lax.stop_gradient(
                transitions.reward * reward_scale +
                transitions.discount * discount * next_v)
            q_error = q_old_action - jnp.expand_dims(target_q, -1)
            q_loss = 0.5 * jnp.mean(jnp.square(q_error))
            return q_loss

        def actor_loss(policy_params: networks_lib.Params,
                       q_params: networks_lib.Params, alpha: jnp.ndarray,
                       transitions: types.Transition,
                       key: networks_lib.PRNGKey) -> jnp.ndarray:
            dist_params = networks.policy_network.apply(
                policy_params, transitions.observation)
            action = networks.sample(dist_params, key)
            log_prob = networks.log_prob(dist_params, action)
            q_action = networks.q_network.apply(q_params,
                                                transitions.observation,
                                                action)
            min_q = jnp.min(q_action, axis=-1)
            actor_loss = alpha * log_prob - min_q
            return jnp.mean(actor_loss)

        alpha_grad = jax.value_and_grad(alpha_loss)
        critic_grad = jax.value_and_grad(critic_loss)
        actor_grad = jax.value_and_grad(actor_loss)

        def update_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            key, key_alpha, key_critic, key_actor = jax.random.split(
                state.key, 4)
            if adaptive_entropy_coefficient:
                alpha_loss, alpha_grads = alpha_grad(state.alpha_params,
                                                     state.policy_params,
                                                     transitions, key_alpha)
                alpha = jnp.exp(state.alpha_params)
            else:
                alpha = entropy_coefficient
            critic_loss, critic_grads = critic_grad(state.q_params,
                                                    state.policy_params,
                                                    state.target_q_params,
                                                    alpha, transitions,
                                                    key_critic)
            actor_loss, actor_grads = actor_grad(state.policy_params,
                                                 state.q_params, alpha,
                                                 transitions, key_actor)

            # Apply policy gradients
            actor_update, policy_optimizer_state = policy_optimizer.update(
                actor_grads, state.policy_optimizer_state)
            policy_params = optax.apply_updates(state.policy_params,
                                                actor_update)

            # Apply critic gradients
            critic_update, q_optimizer_state = q_optimizer.update(
                critic_grads, state.q_optimizer_state)
            q_params = optax.apply_updates(state.q_params, critic_update)

            new_target_q_params = jax.tree_map(
                lambda x, y: x * (1 - tau) + y * tau, state.target_q_params,
                q_params)

            metrics = {
                'critic_loss': critic_loss,
                'actor_loss': actor_loss,
            }

            new_state = TrainingState(
                policy_optimizer_state=policy_optimizer_state,
                q_optimizer_state=q_optimizer_state,
                policy_params=policy_params,
                q_params=q_params,
                target_q_params=new_target_q_params,
                key=key,
            )
            if adaptive_entropy_coefficient:
                # Apply alpha gradients
                alpha_update, alpha_optimizer_state = alpha_optimizer.update(
                    alpha_grads, state.alpha_optimizer_state)
                alpha_params = optax.apply_updates(state.alpha_params,
                                                   alpha_update)
                metrics.update({
                    'alpha_loss': alpha_loss,
                    'alpha': jnp.exp(alpha_params),
                })
                new_state = new_state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    alpha_params=alpha_params)

            metrics['rewards_mean'] = jnp.mean(
                jnp.abs(jnp.mean(transitions.reward, axis=0)))
            metrics['rewards_std'] = jnp.std(transitions.reward, axis=0)

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Iterator on demonstration transitions.
        self._iterator = iterator

        update_step = utils.process_multiple_batches(update_step,
                                                     num_sgd_steps_per_step)
        # Use the JIT compiler.
        self._update_step = jax.jit(update_step)

        def make_initial_state(key: networks_lib.PRNGKey) -> TrainingState:
            """Initialises the training state (parameters and optimiser state)."""
            key_policy, key_q, key = jax.random.split(key, 3)

            policy_params = networks.policy_network.init(key_policy)
            policy_optimizer_state = policy_optimizer.init(policy_params)

            q_params = networks.q_network.init(key_q)
            q_optimizer_state = q_optimizer.init(q_params)

            state = TrainingState(
                policy_optimizer_state=policy_optimizer_state,
                q_optimizer_state=q_optimizer_state,
                policy_params=policy_params,
                q_params=q_params,
                target_q_params=q_params,
                key=key)

            if adaptive_entropy_coefficient:
                state = state._replace(
                    alpha_optimizer_state=alpha_optimizer_state,
                    alpha_params=log_alpha)
            return state

        # Create initial state.
        self._state = make_initial_state(rng)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
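
A minimal sketch of the soft target update performed above: the target critic parameters track the online parameters with smoothing coefficient `tau` via a tree map. The parameter pytrees are toys; `jax.tree_util.tree_map` is used here and is equivalent to the `jax.tree_map` call in the learner.

import jax
import jax.numpy as jnp

tau = 0.005

online_params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))}
target_params = {'w': jnp.zeros((2, 2)), 'b': jnp.zeros((2,))}

new_target_params = jax.tree_util.tree_map(
    lambda target, online: target * (1 - tau) + online * tau,
    target_params, online_params)

print(new_target_params['w'][0, 0])  # 0.005: targets move slowly toward online
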
Esempio n. 12
0
  def __init__(
      self,
      networks,
      rng,
      iterator,
      policy_lr = 1e-4,
      loss_type = 'MLE', # or MSE
      regularize_entropy = False,
      entropy_regularization_weight = 1.0,
      use_img_encoder = False,
      img_encoder_params_ckpt_path = '',
      counter = None,
      logger = None,
      num_sgd_steps_per_step = 1):
    """Initialize the BC learner.

    Args:
      networks: BC networks
      rng: a key for random number generation.
      iterator: an iterator over training data.
      policy_lr: learning rate for the policy.
      loss_type: which BC loss to use, either 'MLE' or 'MSE'.
      regularize_entropy: whether to regularize the entropy of the policy.
      entropy_regularization_weight: weight for entropy regularization.
      use_img_encoder: whether to preprocess the image part of the observation
        using a pretrained encoder.
      img_encoder_params_ckpt_path: path to checkpoint for image encoder params
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'.
    """
    assert loss_type in ['MLE', 'MSE'], 'Invalid BC loss type!'
    num_devices = len(jax.devices())
    self._num_sgd_steps_per_step = num_sgd_steps_per_step
    self._use_img_encoder = use_img_encoder
    policy_optimizer = optax.adam(learning_rate=policy_lr)

    def actor_loss(
        policy_params,
        transitions,
        key,
        img_encoder_params):
      obs = transitions.observation
      acts = transitions.action

      if use_img_encoder:
        img = obs['state_image']
        dense = obs['state_dense']
        obs = dict(
            state_image=networks.img_encoder.apply(img_encoder_params, img),
            state_dense=dense,)

      dist = networks.policy_network.apply(policy_params, obs)
      if loss_type == 'MLE':
        log_probs = networks.log_prob(dist, acts)
        loss = -1. * jnp.mean(log_probs)
      else:
        acts_mode = dist.mode()
        mse = jnp.sum((acts_mode - acts)**2, axis=-1)
        loss = 0.5 * jnp.mean(mse)

      total_loss = loss
      entropy_term = 0.
      if regularize_entropy:
        sample_acts = networks.sample(dist, key)
        sample_log_probs = networks.log_prob(dist, sample_acts)
        entropy_term = jnp.mean(sample_log_probs)
        total_loss = total_loss + entropy_regularization_weight * entropy_term

      return total_loss, (loss, entropy_term)

    actor_loss_and_grad = jax.value_and_grad(actor_loss, has_aux=True)

    def actor_update_step(
        policy_params,
        optim_state,
        transitions,
        key,
        img_encoder_params):
      (total_loss, (bc_loss_term, entropy_term)), actor_grad = actor_loss_and_grad(
          policy_params,
          transitions,
          key,
          img_encoder_params)
      actor_grad = jax.lax.pmean(actor_grad, 'across_devices')

      policy_update, optim_state = policy_optimizer.update(actor_grad, optim_state)
      policy_params = optax.apply_updates(policy_params, policy_update)

      return policy_params, optim_state, total_loss, bc_loss_term, entropy_term

    pmapped_actor_update_step = jax.pmap(
        actor_update_step,
        axis_name='across_devices',
        in_axes=0,
        out_axes=0)


    def _full_update_step(
        state,
        transitions,
    ):
      """The unjitted version of the full update step."""

      metrics = OrderedDict()

      key = state.key

      # actor update step
      def reshape_for_devices(t):
        rest_t_shape = list(t.shape[1:])
        new_shape = [num_devices, t.shape[0]//num_devices,] + rest_t_shape
        return jnp.reshape(t, new_shape)
      transitions = jax.tree_map(reshape_for_devices, transitions)
      sub_keys = jax.random.split(key, num_devices + 1)
      key = sub_keys[0]
      sub_keys = sub_keys[1:]

      new_policy_params, new_policy_optimizer_state, total_loss, bc_loss_term, entropy_term = pmapped_actor_update_step(
          state.policy_params,
          state.policy_optimizer_state,
          transitions,
          sub_keys,
          state.img_encoder_params)
      metrics['total_actor_loss'] = jnp.mean(total_loss)
      metrics['BC_loss'] = jnp.mean(bc_loss_term)
      metrics['entropy_loss'] = jnp.mean(entropy_term)


      # create new state
      new_state = TrainingState(
          policy_optimizer_state=new_policy_optimizer_state,
          policy_params=new_policy_params,
          key=key,
          img_encoder_params=state.img_encoder_params)

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray)

    # Iterator on demonstration transitions.
    self._iterator = iterator

    self._update_step = utils.process_multiple_batches(
        _full_update_step,
        num_sgd_steps_per_step)


    def make_initial_state(key):
      """"""
      # policy stuff
      key, sub_key = jax.random.split(key)
      policy_params = networks.policy_network.init(sub_key)
      policy_optimizer_state = policy_optimizer.init(policy_params)

      devices = jax.local_devices()
      replicated_policy_params = jax.device_put_replicated(
          policy_params, devices)
      replicated_optim_state = jax.device_put_replicated(
          policy_optimizer_state, devices)

      if use_img_encoder:
        """
        Load pretrained img_encoder_params and do:
        replicated_img_encoder_params = jax.device_put_replicated(
            img_encoder_params, devices)
        """
        class EncoderTrainingState(NamedTuple):
          encoder_params: hk.Params
        img_encoder_params = {}
        replicated_img_encoder_params = img_encoder_params
        raise NotImplementedError('Need to load a checkpoint.')
      else:
        img_encoder_params = {}
        replicated_img_encoder_params = img_encoder_params

      state = TrainingState(
          policy_optimizer_state=replicated_optim_state,
          policy_params=replicated_policy_params,
          key=key,
          img_encoder_params=replicated_img_encoder_params)
      return state

    # Create initial state.
    self._state = make_initial_state(rng)

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
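
A small sketch of the `reshape_for_devices` step above: a flat batch is given a leading device axis of size `num_devices` so it can be fed to the pmapped update, under the assumption that the batch size divides evenly. Toy arrays only.

import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()
batch_size = num_devices * 4

transitions = {
    'observation': jnp.zeros((batch_size, 3)),
    'action': jnp.zeros((batch_size, 2)),
}

def reshape_for_devices(t):
    # [B, ...] -> [num_devices, B // num_devices, ...]
    return jnp.reshape(t, (num_devices, t.shape[0] // num_devices) + t.shape[1:])

sharded = jax.tree_util.tree_map(reshape_for_devices, transitions)
print(sharded['observation'].shape)  # (num_devices, 4, 3)
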
Esempio n. 13
0
  def __init__(self,
               network,
               random_key,
               temperature,
               num_actions,
               optimizer,
               demonstrations,
               num_sgd_steps_per_step,
               logger = None,
               counter = None):

    def aqualoss(params, transitions, key):
      predicted_actions = network.apply(
          params,
          transitions.observation,
          is_training=True,
          rngs={'dropout': key})
      predicted_actions = jnp.squeeze(predicted_actions)

      action_distances = jnp.sum(
          jnp.square(predicted_actions -
                     jnp.expand_dims(transitions.action, axis=-1)),
          axis=0)

      # Soft minimum over the action distances (temperature-smoothed logsumexp).
      softmin_action_distances = temperature * (
          jax.nn.logsumexp(-action_distances / temperature)
          - jnp.log(num_actions)
      )
      loss = - softmin_action_distances
      return loss

    def batch_aqualoss(params, transitions, key):
      batched_aqualoss = jax.vmap(aqualoss, in_axes=(None, 0, None), out_axes=0)
      return jnp.mean(batched_aqualoss(params, transitions, key))

    def sgd_step(
        state,
        transitions,
    ):

      loss_and_grad = jax.value_and_grad(batch_aqualoss, argnums=0)

      # Compute losses and their gradients.
      loss_key, random_key = jax.random.split(state.random_key)
      loss_value, gradients = loss_and_grad(state.encoder_params, transitions,
                                            loss_key)

      update, optimizer_state = optimizer.update(
          gradients, state.optimizer_state, params=state.encoder_params)
      encoder_params = optax.apply_updates(state.encoder_params, update)

      new_state = PretrainingState(
          optimizer_state=optimizer_state,
          encoder_params=encoder_params,
          random_key=random_key,
          steps=state.steps + 1,
      )

      metrics = {
          'encoder_loss': loss_value,
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter(prefix='encoder')
    self._logger = logger or loggers.make_default_logger(
        'encoder', asynchronous=True, serialize_fn=utils.fetch_devicearray)

    # Iterator on demonstration transitions.
    self._demonstrations = demonstrations

    # Use the JIT compiler.
    self._sgd_step = utils.process_multiple_batches(
        sgd_step, num_sgd_steps_per_step)
    self._sgd_step = jax.jit(self._sgd_step)

    self._num_actions = num_actions

    encoder_params = network.init(random_key)
    optimizer_state = optimizer.init(encoder_params)

    # Create initial state.
    self._state = PretrainingState(
        optimizer_state=optimizer_state,
        encoder_params=encoder_params,
        random_key=random_key,
        steps=0,
    )

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
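
A quick check of the temperature-weighted softmin used in `aqualoss` above: as the temperature shrinks, the smoothed minimum of the action distances approaches the hard minimum. The distances below are made up.

import jax
import jax.numpy as jnp

def soft_min_distance(distances, temperature, num_actions):
    # temperature * (logsumexp(-d / T) - log(N)) tends to -min(d) as T -> 0;
    # negating it, as the loss above does, gives a smooth minimum distance.
    smoothed = temperature * (
        jax.nn.logsumexp(-distances / temperature) - jnp.log(num_actions))
    return -smoothed

distances = jnp.array([4.0, 1.0, 9.0, 16.0])
for temperature in (1.0, 0.1, 0.01):
    print(float(soft_min_distance(distances, temperature, distances.shape[0])))
# Approaches min(distances) == 1.0 as the temperature decreases.
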
Esempio n. 14
0
    def __init__(self,
                 networks: CRRNetworks,
                 random_key: networks_lib.PRNGKey,
                 discount: float,
                 target_update_period: int,
                 policy_loss_coeff_fn: PolicyLossCoeff,
                 iterator: Iterator[types.Transition],
                 policy_optimizer: optax.GradientTransformation,
                 critic_optimizer: optax.GradientTransformation,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 grad_updates_per_batch: int = 1,
                 use_sarsa_target: bool = False):
        """Initializes the CRR learner.

    Args:
      networks: CRR networks.
      random_key: a key for random number generation.
      discount: discount to use for TD updates.
      target_update_period: period to update target's parameters.
      policy_loss_coeff_fn: set the loss function for the policy.
      iterator: an iterator over training data.
      policy_optimizer: the policy optimizer.
      critic_optimizer: the Q-function optimizer.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      grad_updates_per_batch: how many gradient updates given a sampled batch.
      use_sarsa_target: compute on-policy target using iterator's actions rather
        than sampled actions.
        Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf).
        When set to `True`, `target_policy_params` are unused.
    """

        critic_network = networks.critic_network
        policy_network = networks.policy_network

        def policy_loss(
            policy_params: networks_lib.Params,
            critic_params: networks_lib.Params,
            transition: types.Transition,
            key: networks_lib.PRNGKey,
        ) -> jnp.ndarray:
            # Compute the loss coefficients.
            coeff = policy_loss_coeff_fn(networks, policy_params,
                                         critic_params, transition, key)
            coeff = jax.lax.stop_gradient(coeff)
            # Return the weighted loss.
            dist_params = policy_network.apply(policy_params,
                                               transition.observation)
            logp_action = networks.log_prob(dist_params, transition.action)
            return -jnp.mean(logp_action * coeff)

        def critic_loss(
            critic_params: networks_lib.Params,
            target_policy_params: networks_lib.Params,
            target_critic_params: networks_lib.Params,
            transition: types.Transition,
            key: networks_lib.PRNGKey,
        ):
            # Sample the next action.
            if use_sarsa_target:
                # TODO(b/222674779): use N-steps Trajectories to get the next actions.
                assert 'next_action' in transition.extras, (
                    'next actions should be given as extras for one step RL.')
                next_action = transition.extras['next_action']
            else:
                next_dist_params = policy_network.apply(
                    target_policy_params, transition.next_observation)
                next_action = networks.sample(next_dist_params, key)
            # Calculate the value of the next state and action.
            next_q = critic_network.apply(target_critic_params,
                                          transition.next_observation,
                                          next_action)
            target_q = transition.reward + transition.discount * discount * next_q
            target_q = jax.lax.stop_gradient(target_q)

            q = critic_network.apply(critic_params, transition.observation,
                                     transition.action)
            q_error = q - target_q
            # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
            # TODO(sertan): Replace with a distributional critic. CRR paper states
            # that this may perform better.
            return 0.5 * jnp.mean(jnp.square(q_error))

        policy_loss_and_grad = jax.value_and_grad(policy_loss)
        critic_loss_and_grad = jax.value_and_grad(critic_loss)

        def sgd_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            key, key_policy, key_critic = jax.random.split(state.key, 3)

            # Compute losses and their gradients.
            policy_loss_value, policy_gradients = policy_loss_and_grad(
                state.policy_params, state.critic_params, transitions,
                key_policy)
            critic_loss_value, critic_gradients = critic_loss_and_grad(
                state.critic_params, state.target_policy_params,
                state.target_critic_params, transitions, key_critic)

            # Get optimizer updates and state.
            policy_updates, policy_opt_state = policy_optimizer.update(
                policy_gradients, state.policy_opt_state)
            critic_updates, critic_opt_state = critic_optimizer.update(
                critic_gradients, state.critic_opt_state)

            # Apply optimizer updates to parameters.
            policy_params = optax.apply_updates(state.policy_params,
                                                policy_updates)
            critic_params = optax.apply_updates(state.critic_params,
                                                critic_updates)

            steps = state.steps + 1

            # Periodically update target networks.
            target_policy_params, target_critic_params = optax.periodic_update(
                (policy_params, critic_params),
                (state.target_policy_params, state.target_critic_params),
                steps, target_update_period)

            new_state = TrainingState(
                policy_params=policy_params,
                target_policy_params=target_policy_params,
                critic_params=critic_params,
                target_critic_params=target_critic_params,
                policy_opt_state=policy_opt_state,
                critic_opt_state=critic_opt_state,
                steps=steps,
                key=key,
            )

            metrics = {
                'policy_loss': policy_loss_value,
                'critic_loss': critic_loss_value,
            }

            return new_state, metrics

        sgd_step = utils.process_multiple_batches(sgd_step,
                                                  grad_updates_per_batch)
        self._sgd_step = jax.jit(sgd_step)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Create prefetching dataset iterator.
        self._iterator = iterator

        # Create the network parameters and copy into the target network parameters.
        key, key_policy, key_critic = jax.random.split(random_key, 3)
        initial_policy_params = policy_network.init(key_policy)
        initial_critic_params = critic_network.init(key_critic)
        initial_target_policy_params = initial_policy_params
        initial_target_critic_params = initial_critic_params

        # Initialize optimizers.
        initial_policy_opt_state = policy_optimizer.init(initial_policy_params)
        initial_critic_opt_state = critic_optimizer.init(initial_critic_params)

        # Create initial state.
        self._state = TrainingState(
            policy_params=initial_policy_params,
            target_policy_params=initial_target_policy_params,
            critic_params=initial_critic_params,
            target_critic_params=initial_target_critic_params,
            policy_opt_state=initial_policy_opt_state,
            critic_opt_state=initial_critic_opt_state,
            steps=0,
            key=key,
        )

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
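
The loss coefficient itself comes from the `policy_loss_coeff_fn` passed to the learner, so its exact form is not shown here. As a hedged illustration only, the two weightings proposed in the CRR paper are an exponential advantage weight and a binary indicator; a standalone sketch of both, with made-up Q and baseline values:

import jax.numpy as jnp

def crr_exp_weight(q_sa, v_s, beta=1.0, max_weight=20.0):
    # Exponential advantage weighting: w = min(exp((Q(s, a) - V(s)) / beta), cap).
    advantage = q_sa - v_s
    return jnp.minimum(jnp.exp(advantage / beta), max_weight)

def crr_binary_weight(q_sa, v_s):
    # Indicator variant: keep an action only if it beats the baseline value.
    return (q_sa > v_s).astype(jnp.float32)

q_sa = jnp.array([1.2, 0.3, -0.5])
v_s = jnp.array([0.5, 0.5, 0.5])
print(crr_exp_weight(q_sa, v_s))     # [~2.01, ~0.82, ~0.37]
print(crr_binary_weight(q_sa, v_s))  # [1., 0., 0.]
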
Esempio n. 15
0
  def __init__(self,
               network: networks_lib.FeedForwardNetwork,
               random_key: networks_lib.PRNGKey,
               loss_fn: losses.Loss,
               optimizer: optax.GradientTransformation,
               demonstrations: Iterator[types.Transition],
               num_sgd_steps_per_step: int,
               logger: Optional[loggers.Logger] = None,
               counter: Optional[counting.Counter] = None):
    def sgd_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

      loss_and_grad = jax.value_and_grad(loss_fn, argnums=1)

      # Compute losses and their gradients.
      key, key_input = jax.random.split(state.key)
      loss_value, gradients = loss_and_grad(network.apply, state.policy_params,
                                            key_input, transitions)

      policy_update, optimizer_state = optimizer.update(gradients,
                                                        state.optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, policy_update)

      new_state = TrainingState(
          optimizer_state=optimizer_state,
          policy_params=policy_params,
          key=key,
          steps=state.steps + 1,
      )
      metrics = {
          'loss': loss_value,
          'gradient_norm': optax.global_norm(gradients)
      }

      return new_state, metrics

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter(prefix='learner')
    self._logger = logger or loggers.make_default_logger(
        'learner', asynchronous=True, serialize_fn=utils.fetch_devicearray)

    # Iterator on demonstration transitions.
    self._demonstrations = demonstrations

    # Split the input batch to `num_sgd_steps_per_step` minibatches in order
    # to achieve better performance on accelerators.
    self._sgd_step = jax.jit(utils.process_multiple_batches(
        sgd_step, num_sgd_steps_per_step))

    random_key, init_key = jax.random.split(random_key)
    policy_params = network.init(init_key)
    optimizer_state = optimizer.init(policy_params)

    # Create initial state.
    self._state = TrainingState(
        optimizer_state=optimizer_state,
        policy_params=policy_params,
        key=random_key,
        steps=0,
    )

    self._timestamp = None
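
Like several of the learners here, this one wraps its step function with `utils.process_multiple_batches(...)` so that one sampled super-batch yields several SGD updates. A rough, non-Acme sketch of the idea using `jax.lax.scan`: split the batch into `num_steps` minibatches along the leading axis (assumed to divide evenly) and fold the step function over them.

import jax
import jax.numpy as jnp

def step(params, minibatch):
    # Toy SGD step on a quadratic loss; returns (new_params, metrics).
    loss = lambda p: jnp.mean((p - minibatch) ** 2)
    grads = jax.grad(loss)(params)
    return params - 0.1 * grads, {'loss': loss(params)}

def process_multiple_batches(step_fn, num_steps):
    def wrapped(params, batch):
        # [num_steps * B, ...] -> [num_steps, B, ...]
        minibatches = jnp.reshape(batch, (num_steps, -1) + batch.shape[1:])
        params, metrics = jax.lax.scan(step_fn, params, minibatches)
        # Keep only the metrics of the last minibatch so the output shape is fixed.
        return params, jax.tree_util.tree_map(lambda x: x[-1], metrics)
    return wrapped

params = jnp.zeros(())
batch = jnp.arange(8, dtype=jnp.float32)
update = jax.jit(process_multiple_batches(step, num_steps=4))
print(update(params, batch))
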
Esempio n. 16
0
    def __init__(self,
                 networks: td3_networks.TD3Networks,
                 random_key: networks_lib.PRNGKey,
                 discount: float,
                 iterator: Iterator[reverb.ReplaySample],
                 policy_optimizer: optax.GradientTransformation,
                 critic_optimizer: optax.GradientTransformation,
                 twin_critic_optimizer: optax.GradientTransformation,
                 delay: int = 2,
                 target_sigma: float = 0.2,
                 noise_clip: float = 0.5,
                 tau: float = 0.005,
                 use_sarsa_target: bool = False,
                 bc_alpha: Optional[float] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 num_sgd_steps_per_step: int = 1):
        """Initializes the TD3 learner.

    Args:
      networks: TD3 networks.
      random_key: a key for random number generation.
      discount: discount to use for TD updates
      iterator: an iterator over training data.
      policy_optimizer: the policy optimizer.
      critic_optimizer: the Q-function optimizer.
      twin_critic_optimizer: the twin Q-function optimizer.
      delay: number of critic updates per policy update (see TD3);
        delay=2 means 2 updates of the critic for 1 policy update.
      target_sigma: std of zero mean Gaussian added to the action of
        the next_state, for critic evaluation (reducing overestimation bias).
      noise_clip: hard constraint on target noise.
      tau: target parameters smoothing coefficient.
      use_sarsa_target: compute on-policy target using iterator's actions rather
        than sampled actions.
        Useful for 1-step offline RL (https://arxiv.org/pdf/2106.08909.pdf).
        When set to `True`, `target_policy_params` are unused.
        This only works when the learner is used as an offline algorithm.
        I.e. TD3Builder does not support adding the SARSA target to the replay
        buffer.
      bc_alpha: if not None, implements TD3+BC.
        See comments in TD3Config.bc_alpha for details.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      num_sgd_steps_per_step: number of sgd steps to perform per learner 'step'.
    """
        def policy_loss(
            policy_params: networks_lib.Params,
            critic_params: networks_lib.Params,
            transition: types.NestedArray,
        ) -> jnp.ndarray:
            # Computes the deterministic policy gradient (DPG) loss.
            action = networks.policy_network.apply(policy_params,
                                                   transition.observation)
            grad_critic = jax.vmap(jax.grad(networks.critic_network.apply,
                                            argnums=2),
                                   in_axes=(None, 0, 0))
            dq_da = grad_critic(critic_params, transition.observation, action)
            batch_dpg_learning = jax.vmap(rlax.dpg_loss, in_axes=(0, 0))
            loss = jnp.mean(batch_dpg_learning(action, dq_da))
            if bc_alpha is not None:
                # BC regularization for offline RL
                q_sa = networks.critic_network.apply(critic_params,
                                                     transition.observation,
                                                     action)
                bc_factor = jax.lax.stop_gradient(bc_alpha /
                                                  jnp.mean(jnp.abs(q_sa)))
                loss += jnp.mean(
                    jnp.square(action - transition.action)) / bc_factor
            return loss

        def critic_loss(
            critic_params: networks_lib.Params,
            state: TrainingState,
            transition: types.Transition,
            random_key: jnp.ndarray,
        ):
            # Computes the critic loss.
            q_tm1 = networks.critic_network.apply(critic_params,
                                                  transition.observation,
                                                  transition.action)

            if use_sarsa_target:
                # TODO(b/222674779): use N-steps Trajectories to get the next actions.
                assert 'next_action' in transition.extras, (
                    'next actions should be given as extras for one step RL.')
                action = transition.extras['next_action']
            else:
                action = networks.policy_network.apply(
                    state.target_policy_params, transition.next_observation)
                action = networks.add_policy_noise(action, random_key,
                                                   target_sigma, noise_clip)

            q_t = networks.critic_network.apply(state.target_critic_params,
                                                transition.next_observation,
                                                action)
            twin_q_t = networks.twin_critic_network.apply(
                state.target_twin_critic_params, transition.next_observation,
                action)

            q_t = jnp.minimum(q_t, twin_q_t)

            target_q_tm1 = transition.reward + discount * transition.discount * q_t
            td_error = jax.lax.stop_gradient(target_q_tm1) - q_tm1

            return jnp.mean(jnp.square(td_error))

        def update_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            random_key, key_critic, key_twin = jax.random.split(
                state.random_key, 3)

            # Updates on the critic: compute the gradients, and update using
            # Polyak averaging.
            critic_loss_and_grad = jax.value_and_grad(critic_loss)
            critic_loss_value, critic_gradients = critic_loss_and_grad(
                state.critic_params, state, transitions, key_critic)
            critic_updates, critic_opt_state = critic_optimizer.update(
                critic_gradients, state.critic_opt_state)
            critic_params = optax.apply_updates(state.critic_params,
                                                critic_updates)
            # In the original authors' implementation the critic target update is
            # delayed in the same way as the policy update, which we found
            # empirically to perform slightly worse.
            target_critic_params = optax.incremental_update(
                new_tensors=critic_params,
                old_tensors=state.target_critic_params,
                step_size=tau)

            # Updates on the twin critic: compute the gradients, and update using
            # Polyak averaging.
            twin_critic_loss_value, twin_critic_gradients = critic_loss_and_grad(
                state.twin_critic_params, state, transitions, key_twin)
            twin_critic_updates, twin_critic_opt_state = twin_critic_optimizer.update(
                twin_critic_gradients, state.twin_critic_opt_state)
            twin_critic_params = optax.apply_updates(state.twin_critic_params,
                                                     twin_critic_updates)
            # In the original authors' implementation the twin critic target update
            # is delayed in the same way as the policy update, which we found
            # empirically to perform slightly worse.
            target_twin_critic_params = optax.incremental_update(
                new_tensors=twin_critic_params,
                old_tensors=state.target_twin_critic_params,
                step_size=tau)

            # Updates on the policy: compute the gradients, and update using
            # Polyak averaging (if delay enabled, the update might not be applied).
            policy_loss_and_grad = jax.value_and_grad(policy_loss)
            policy_loss_value, policy_gradients = policy_loss_and_grad(
                state.policy_params, state.critic_params, transitions)

            def update_policy_step():
                policy_updates, policy_opt_state = policy_optimizer.update(
                    policy_gradients, state.policy_opt_state)
                policy_params = optax.apply_updates(state.policy_params,
                                                    policy_updates)
                target_policy_params = optax.incremental_update(
                    new_tensors=policy_params,
                    old_tensors=state.target_policy_params,
                    step_size=tau)
                return policy_params, target_policy_params, policy_opt_state

            # The update on the policy is applied every `delay` steps.
            current_policy_state = (state.policy_params,
                                    state.target_policy_params,
                                    state.policy_opt_state)
            policy_params, target_policy_params, policy_opt_state = jax.lax.cond(
                state.steps % delay == 0,
                lambda _: update_policy_step(),
                lambda _: current_policy_state,
                operand=None)

            steps = state.steps + 1

            new_state = TrainingState(
                policy_params=policy_params,
                critic_params=critic_params,
                twin_critic_params=twin_critic_params,
                target_policy_params=target_policy_params,
                target_critic_params=target_critic_params,
                target_twin_critic_params=target_twin_critic_params,
                policy_opt_state=policy_opt_state,
                critic_opt_state=critic_opt_state,
                twin_critic_opt_state=twin_critic_opt_state,
                steps=steps,
                random_key=random_key,
            )

            metrics = {
                'policy_loss': policy_loss_value,
                'critic_loss': critic_loss_value,
                'twin_critic_loss': twin_critic_loss_value,
            }

            return new_state, metrics

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Create prefetching dataset iterator.
        self._iterator = iterator

        # Faster sgd step
        update_step = utils.process_multiple_batches(update_step,
                                                     num_sgd_steps_per_step)
        # Use the JIT compiler.
        self._update_step = jax.jit(update_step)

        (key_init_policy, key_init_critic, key_init_twin,
         key_state) = jax.random.split(random_key, 4)
        # Create the network parameters and copy into the target network parameters.
        initial_policy_params = networks.policy_network.init(key_init_policy)
        initial_critic_params = networks.critic_network.init(key_init_critic)
        initial_twin_critic_params = networks.twin_critic_network.init(
            key_init_twin)

        initial_target_policy_params = initial_policy_params
        initial_target_critic_params = initial_critic_params
        initial_target_twin_critic_params = initial_twin_critic_params

        # Initialize optimizers.
        initial_policy_opt_state = policy_optimizer.init(initial_policy_params)
        initial_critic_opt_state = critic_optimizer.init(initial_critic_params)
        initial_twin_critic_opt_state = twin_critic_optimizer.init(
            initial_twin_critic_params)

        # Create initial state.
        self._state = TrainingState(
            policy_params=initial_policy_params,
            target_policy_params=initial_target_policy_params,
            critic_params=initial_critic_params,
            twin_critic_params=initial_twin_critic_params,
            target_critic_params=initial_target_critic_params,
            target_twin_critic_params=initial_target_twin_critic_params,
            policy_opt_state=initial_policy_opt_state,
            critic_opt_state=initial_critic_opt_state,
            twin_critic_opt_state=initial_twin_critic_opt_state,
            steps=0,
            random_key=key_state)

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
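
A small sketch of the target policy smoothing performed by `networks.add_policy_noise` above: clipped zero-mean Gaussian noise is added to the target action, which is then clipped back into the action range. The [-1, 1] bounds are an assumption for illustration; the real network helper would use the environment's action spec.

import jax
import jax.numpy as jnp

def add_policy_noise(action, key, target_sigma=0.2, noise_clip=0.5):
    # Clipped Gaussian noise, as in TD3's target policy smoothing.
    noise = jax.random.normal(key, shape=action.shape) * target_sigma
    noise = jnp.clip(noise, -noise_clip, noise_clip)
    # Assumed action bounds of [-1, 1].
    return jnp.clip(action + noise, -1.0, 1.0)

key = jax.random.PRNGKey(0)
action = jnp.array([0.9, -0.2, 0.0])
print(add_policy_noise(action, key))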