Example #1
0
def train_with_bc(make_demonstrations: Callable[[int],
                                                Iterator[types.Transition]],
                  networks: networks_lib.FeedForwardNetwork,
                  loss: losses.Loss,
                  num_steps: int = 100000,
                  batch_size: int = 256,
                  learning_rate: float = 1e-4,
                  seed: int = 0) -> networks_lib.Params:
    """Trains the given network with BC and returns the params.

    Args:
      make_demonstrations: A function (batch_size) -> iterator with
        demonstrations to be imitated.
      networks: Network taking (params, obs, is_training, key) as input.
      loss: BC loss to use.
      num_steps: Number of learner steps to run.
      batch_size: Batch size requested from `make_demonstrations`
        (defaults to 256, the previously hard-coded value).
      learning_rate: Learning rate of the Adam optimizer (defaults to 1e-4).
      seed: Seed of the PRNG key handed to the learner (defaults to 0).

    Returns:
      The trained network params, i.e. the first element returned by the
      learner's `get_variables(['policy'])`.
    """
    demonstration_iterator = make_demonstrations(batch_size)
    # Prefetch batches ahead of the learner, one thread per local device.
    prefetching_iterator = utils.sharded_prefetch(
        demonstration_iterator,
        buffer_size=2,
        num_threads=jax.local_device_count())

    learner = learning.BCLearner(network=networks,
                                 random_key=jax.random.PRNGKey(seed),
                                 loss_fn=loss,
                                 prefetching_iterator=prefetching_iterator,
                                 optimizer=optax.adam(learning_rate),
                                 num_sgd_steps_per_step=1)

    # Train the agent.
    for _ in range(num_steps):
        learner.step()

    return learner.get_variables(['policy'])[0]
Example #2
0
  def __init__(self,
               networks: mbop_networks.MBOPNetworks,
               losses: mbop_losses.MBOPLosses,
               iterator: Iterator[types.Transition],
               rng_key: jax_types.PRNGKey,
               logger_fn: LoggerFn,
               make_world_model_learner: MakeWorldModelLearner,
               make_policy_prior_learner: MakePolicyPriorLearner,
               make_n_step_return_learner: MakeNStepReturnLearner,
               counter: Optional[counting.Counter] = None):
    """Builds the three MBOP sub-learners on top of one shared data stream.

    Args:
      networks: Holds the world-model, policy-prior and n-step-return
        networks, one per sub-learner.
      losses: Holds the matching loss for each of those networks.
      iterator: Iterator of time-batched transitions used to train the
        networks; every sub-learner receives every element.
      rng_key: Random key, split three ways across the sub-learners.
      logger_fn: Logger factory; called here as `logger_fn('', 'steps')` and
        also forwarded to each sub-learner factory.
      make_world_model_learner: Factory for the world-model learner.
      make_policy_prior_learner: Factory for the policy-prior learner.
      make_n_step_return_learner: Factory for the n-step-return learner.
      counter: Parent counter; a fresh `counting.Counter()` is used when none
        is given.
    """
    # Book-keeping shared by the combined learner.
    self._counter = counter or counting.Counter()
    self._logger = logger_fn('', 'steps')

    # Duplicate one prefetched stream three ways rather than splitting it, so
    # that each sub-learner trains on all of the data (sample efficiency).
    prefetched = utils.sharded_prefetch(iterator)
    wm_data, prior_data, n_step_data = itertools.tee(prefetched, 3)

    wm_key, prior_key, n_step_key = jax.random.split(rng_key, 3)

    self._world_model = make_world_model_learner(
        logger_fn, self._counter, wm_key, wm_data,
        networks.world_model_network, losses.world_model_loss)
    self._policy_prior = make_policy_prior_learner(
        logger_fn, self._counter, prior_key, prior_data,
        networks.policy_prior_network, losses.policy_prior_loss)
    self._n_step_return = make_n_step_return_learner(
        logger_fn, self._counter, n_step_key, n_step_data,
        networks.n_step_return_network, losses.n_step_return_loss)

    # Timestamps are only recorded from the first learning step onward so that
    # "warmup" time is not reported.
    self._timestamp = None
    self._learners = dict(
        world_model=self._world_model,
        policy_prior=self._policy_prior,
        n_step_return=self._n_step_return,
    )
Example #3
0
def main(_):
  """Trains a BC agent on demonstrations and evaluates it periodically.

  Runs forever: after every `FLAGS.evaluate_every` learner steps, the current
  'policy' variables are evaluated for `FLAGS.evaluation_episodes` episodes
  with an epsilon-greedy actor.
  """
  # Create an environment and grab the spec.
  environment = bc_utils.make_environment()
  environment_spec = specs.make_environment_spec(environment)

  # Unwrap the environment to get the demonstrations.
  dataset = bc_utils.make_demonstrations(environment.environment,
                                         FLAGS.batch_size)
  dataset = dataset.as_numpy_iterator()

  # Create the networks to optimize.
  network = bc_utils.make_network(environment_spec)

  key = jax.random.PRNGKey(FLAGS.seed)
  # `key1` seeds the learner; `key` is kept for the evaluation actor.
  key, key1 = jax.random.split(key, 2)

  def logp_fn(logits, actions):
    """Log-probability of `actions` under a softmax over `logits`."""
    # Pick each taken action's logit with a one-hot mask, then normalise by
    # subtracting the log-sum-exp over the action axis.
    logits_actions = jnp.sum(
        jax.nn.one_hot(actions, logits.shape[-1]) * logits, axis=-1)
    logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
    return logits_actions

  loss_fn = bc.logp(logp_fn=logp_fn)

  learner = bc.BCLearner(
      network=network,
      random_key=key1,
      loss_fn=loss_fn,
      optimizer=optax.adam(FLAGS.learning_rate),
      prefetching_iterator=utils.sharded_prefetch(dataset),
      num_sgd_steps_per_step=1)

  # Annotated with `jnp.ndarray` rather than `jnp.DeviceArray`, which newer
  # JAX releases removed; without postponed annotations these are evaluated
  # every time `main` defines this function.
  def evaluator_network(params: hk.Params, key: jnp.ndarray,
                        observation: jnp.ndarray) -> jnp.ndarray:
    """Samples an epsilon-greedy action from the network's output."""
    dist_params = network.apply(params, observation)
    return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
        key, dist_params)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  # The actor pulls the learner's 'policy' variables onto the CPU.
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop: alternate training and evaluation forever.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
Example #4
0
  def make_learner(
      self,
      random_key: networks_lib.PRNGKey,
      networks: networks_lib.FeedForwardNetwork,
      dataset: Iterator[types.Transition],
      logger_fn: loggers.LoggerFactory,
      environment_spec: specs.EnvironmentSpec,
      *,
      counter: Optional[counting.Counter] = None,
  ) -> core.Learner:
    """Builds a `learning.BCLearner` from this builder's loss and config.

    Args:
      random_key: PRNG key passed to the learner as `random_key`.
      networks: The network to train.
      dataset: Iterator of transitions; wrapped with `utils.sharded_prefetch`
        before being handed to the learner.
      logger_fn: Logger factory, called with 'learner'.
      environment_spec: Unused.
      counter: Optional counter forwarded to the learner.

    Returns:
      The constructed learner.
    """
    del environment_spec  # Not needed to build this learner.

    optimizer = optax.adam(learning_rate=self._config.learning_rate)
    prefetching_iterator = utils.sharded_prefetch(dataset)
    learner_logger = logger_fn('learner')

    return learning.BCLearner(
        network=networks,
        random_key=random_key,
        loss_fn=self._loss_fn,
        optimizer=optimizer,
        prefetching_iterator=prefetching_iterator,
        num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
        loss_has_aux=self._loss_has_aux,
        logger=learner_logger,
        counter=counter)
Example #5
0
  def __init__(
      self,
      obs_spec: specs.Array,
      unroll_fn: networks_lib.PolicyValueRNN,
      initial_state_fn: Callable[[], hk.LSTMState],
      iterator: Iterator[reverb.ReplaySample],
      optimizer: optax.GradientTransformation,
      random_key: networks_lib.PRNGKey,
      discount: float = 0.99,
      entropy_cost: float = 0.,
      baseline_cost: float = 1.,
      max_abs_reward: float = np.inf,
      counter: Optional[counting.Counter] = None,
      logger: Optional[loggers.Logger] = None,
      devices: Optional[Sequence[jax.xla.Device]] = None,
      prefetch_size: int = 2,
      num_prefetch_threads: Optional[int] = None,
  ):
    """Initializes the IMPALA learner.

    Args:
      obs_spec: Observation spec; a zero observation built from it (with an
        added leading dimension) is used to initialise the parameters.
      unroll_fn: Un-transformed function unrolling the network over a
        sequence; transformed here with Haiku.
      initial_state_fn: Un-transformed function returning the initial core
        state; transformed here with Haiku.
      iterator: Iterator of replay samples.
      optimizer: Optax transformation applied to the pmap-averaged gradients.
      random_key: Key used to initialise the parameters.
      discount: Forwarded to `losses.impala_loss`.
      entropy_cost: Forwarded to `losses.impala_loss`.
      baseline_cost: Forwarded to `losses.impala_loss`.
      max_abs_reward: Forwarded to `losses.impala_loss`.
      counter: Optional counter; a new `counting.Counter()` otherwise.
      logger: Optional logger; defaults to
        `loggers.make_default_logger('learner')`.
      devices: Devices the state is replicated on, data is prefetched to and
        `sgd_step` is pmapped over; defaults to `jax.local_devices()`.
      prefetch_size: Buffer size for `utils.sharded_prefetch`.
      num_prefetch_threads: Number of prefetch threads; defaults to the number
        of devices used.
    """

    self._devices = devices or jax.local_devices()

    # Transform into pure functions.
    # NOTE(review): `apply_rng=True` is deprecated in recent Haiku releases;
    # confirm the pinned Haiku version before removing it.
    unroll_fn = hk.without_apply_rng(hk.transform(unroll_fn, apply_rng=True))
    initial_state_fn = hk.without_apply_rng(
        hk.transform(initial_state_fn, apply_rng=True))

    loss_fn = losses.impala_loss(
        unroll_fn,
        discount=discount,
        max_abs_reward=max_abs_reward,
        baseline_cost=baseline_cost,
        entropy_cost=entropy_cost)

    @jax.jit
    def sgd_step(
        state: TrainingState, sample: reverb.ReplaySample
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      """Computes an SGD step, returning new state and metrics for logging."""

      # Compute gradients.
      grad_fn = jax.value_and_grad(loss_fn)
      loss_value, gradients = grad_fn(state.params, sample)

      # Average gradients over pmap replicas before optimizer update.
      gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

      # Apply updates.
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      metrics = {
          'loss': loss_value,
      }

      new_state = TrainingState(params=new_params, opt_state=new_opt_state)

      return new_state, metrics

    def make_initial_state(key: jnp.ndarray) -> TrainingState:
      """Initialises the training state (parameters and optimiser state)."""
      dummy_obs = utils.zeros_like(obs_spec)
      dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
      initial_state = initial_state_fn.apply(None)
      initial_params = unroll_fn.init(key, dummy_obs, initial_state)
      initial_opt_state = optimizer.init(initial_params)
      return TrainingState(params=initial_params, opt_state=initial_opt_state)

    # Initialise training state (parameters and optimiser state).
    state = make_initial_state(random_key)
    self._state = utils.replicate_in_all_devices(state, self._devices)

    if num_prefetch_threads is None:
      num_prefetch_threads = len(self._devices)
    # Prefetch onto the same resolved device list that the state is
    # replicated on and that `sgd_step` is pmapped over.
    self._prefetched_iterator = utils.sharded_prefetch(
        iterator,
        buffer_size=prefetch_size,
        devices=self._devices,
        num_threads=num_prefetch_threads,
    )

    self._sgd_step = jax.pmap(
        sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices)

    # Set up logging/counting.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')
Example #6
0
    def __init__(self,
                 unroll: networks_lib.FeedForwardNetwork,
                 initial_state: networks_lib.FeedForwardNetwork,
                 batch_size: int,
                 random_key: networks_lib.PRNGKey,
                 burn_in_length: int,
                 discount: float,
                 importance_sampling_exponent: float,
                 max_priority_weight: float,
                 target_update_period: int,
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optax.GradientTransformation,
                 bootstrap_n: int = 5,
                 tx_pair: rlax.TxPair = rlax.SIGNED_HYPERBOLIC_PAIR,
                 clip_rewards: bool = False,
                 max_abs_reward: float = 1.,
                 use_core_state: bool = True,
                 prefetch_size: int = 2,
                 replay_client: Optional[reverb.Client] = None,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes the learner.

        Args:
          unroll: Recurrent network; `apply(params, key, observations,
            core_state)` returns a pair whose first element is used as
            Q-values and whose second is the updated core state.
          initial_state: Network producing the initial core state, called as
            `init(key, batch_size)` and `apply(params, key, batch_size)`.
          batch_size: Batch size passed to `initial_state`.
          random_key: Random key for initialisation and the learner's PRNG
            state.
          burn_in_length: Number of leading timesteps used only to warm up the
            core state (the loss is computed on the remaining steps); 0
            disables burn-in.
          discount: Multiplies the per-step discounts stored in the data.
          importance_sampling_exponent: Exponent applied to the inverse
            sampling probabilities before they are normalised by their max.
          max_priority_weight: Mixing weight between the max and the mean
            absolute TD error when computing new priorities.
          target_update_period: Period, in learner steps, of the target
            parameter update done with `rlax.periodic_update`.
          iterator: Iterator of replay samples.
          optimizer: Optax transformation applied to the pmap-averaged
            gradients.
          bootstrap_n: `n` for `rlax.transformed_n_step_q_learning`.
          tx_pair: Value transformation pair for the N-step targets.
          clip_rewards: Whether to clip rewards to
            [-max_abs_reward, max_abs_reward].
          max_abs_reward: Clip bound used when `clip_rewards` is set.
          use_core_state: If True, start unrolls from the core state stored in
            `extras['core_state']`; otherwise from the initial core state.
          prefetch_size: Buffer size for on-device prefetching.
          replay_client: Reverb client used to write back priorities;
            `update_priorities` calls it unconditionally, so it must be set
            whenever priorities are written back.
          counter: Optional counter; a new `counting.Counter()` otherwise.
          logger: Optional logger; an asynchronous default 'learner' logger
            otherwise.
        """

        random_key, key_initial_1, key_initial_2 = jax.random.split(
            random_key, 3)
        initial_state_params = initial_state.init(key_initial_1, batch_size)
        # Concrete initial core state; `initial_state` is the network that
        # produces it.
        initial_core_state = initial_state.apply(initial_state_params,
                                                 key_initial_2, batch_size)

        def loss(
            params: networks_lib.Params, target_params: networks_lib.Params,
            key_grad: networks_lib.PRNGKey, sample: reverb.ReplaySample
        ) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Computes mean transformed N-step loss for a batch of sequences.

            Returns:
              A `(mean_loss, priorities)` pair, where `priorities` holds one
              value per sequence in the batch (used as `has_aux` output).
            """

            # Convert sample data to sequence-major format [T, B, ...].
            data = utils.batch_to_sequence(sample.data)

            # Get core state & warm it up on observations for a burn-in period.
            if use_core_state:
                # Replay the core state stored at the first timestep.
                online_state = jax.tree_util.tree_map(
                    lambda x: x[0], data.extras['core_state'])
            else:
                online_state = initial_core_state
            target_state = online_state

            # Maybe burn the core state in.
            if burn_in_length:
                burn_obs = jax.tree_util.tree_map(
                    lambda x: x[:burn_in_length], data.observation)
                key_grad, key1, key2 = jax.random.split(key_grad, 3)
                _, online_state = unroll.apply(params, key1, burn_obs,
                                               online_state)
                _, target_state = unroll.apply(target_params, key2, burn_obs,
                                               target_state)

            # Only get data to learn on from after the end of the burn in period.
            data = jax.tree_util.tree_map(lambda seq: seq[burn_in_length:],
                                          data)

            # Unroll on sequences to get online and target Q-Values.
            key1, key2 = jax.random.split(key_grad)
            online_q, _ = unroll.apply(params, key1, data.observation,
                                       online_state)
            target_q, _ = unroll.apply(target_params, key2, data.observation,
                                       target_state)

            # Get value-selector actions from online Q-values for double Q-learning.
            selector_actions = jnp.argmax(online_q, axis=-1)
            # Preprocess discounts & rewards.
            discounts = (data.discount * discount).astype(online_q.dtype)
            rewards = data.reward
            if clip_rewards:
                rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)
            rewards = rewards.astype(online_q.dtype)

            # Get N-step transformed TD error and loss (vmapped over the batch
            # axis, which is axis 1 in sequence-major data).
            batch_td_error_fn = jax.vmap(
                functools.partial(rlax.transformed_n_step_q_learning,
                                  n=bootstrap_n,
                                  tx_pair=tx_pair),
                in_axes=1,
                out_axes=1)
            # TODO(b/183945808): when this bug is fixed, truncations of actions,
            # rewards, and discounts will no longer be necessary.
            batch_td_error = batch_td_error_fn(online_q[:-1], data.action[:-1],
                                               target_q[1:],
                                               selector_actions[1:],
                                               rewards[:-1], discounts[:-1])
            batch_loss = 0.5 * jnp.square(batch_td_error).sum(axis=0)

            # Importance weighting.
            probs = sample.info.probability
            importance_weights = (1. / (probs + 1e-6)).astype(online_q.dtype)
            importance_weights **= importance_sampling_exponent
            importance_weights /= jnp.max(importance_weights)
            mean_loss = jnp.mean(importance_weights * batch_loss)

            # Calculate priorities as a mixture of max and mean sequence errors.
            abs_td_error = jnp.abs(batch_td_error).astype(online_q.dtype)
            max_priority = max_priority_weight * jnp.max(abs_td_error, axis=0)
            mean_priority = (1 - max_priority_weight) * jnp.mean(abs_td_error,
                                                                 axis=0)
            priorities = (max_priority + mean_priority)

            return mean_loss, priorities

        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
            """Performs an update step, averaging over pmap replicas.

            Returns:
              The new training state, the priorities computed by `loss`, and a
              metrics dict holding the loss value.
            """

            # Compute loss and gradients.
            grad_fn = jax.value_and_grad(loss, has_aux=True)
            key, key_grad = jax.random.split(state.random_key)
            (loss_value, priorities), gradients = grad_fn(
                state.params, state.target_params, key_grad, samples)

            # Average gradients over pmap replicas before optimizer update.
            gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

            # Apply optimizer updates.
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 self._target_update_period)

            new_state = TrainingState(params=new_params,
                                      target_params=target_params,
                                      opt_state=new_opt_state,
                                      steps=steps,
                                      random_key=key)
            return new_state, priorities, {'loss': loss_value}

        def update_priorities(keys_and_priorities: Tuple[jnp.ndarray,
                                                         jnp.ndarray]):
            """Writes new priorities for the given keys to the replay table.

            Executed through the `AsyncExecutor` created below.
            """
            keys, priorities = keys_and_priorities
            keys, priorities = tree.map_structure(
                # Fetch array and combine device and batch dimensions.
                lambda x: utils.fetch_devicearray(x).reshape(
                    (-1, ) + x.shape[2:]),
                (keys, priorities))
            replay_client.mutate_priorities(  # pytype: disable=attribute-error
                table=adders.DEFAULT_PRIORITY_TABLE,
                updates=dict(zip(keys, priorities)))

        # Internalise components, hyperparameters, logger, counter, and methods.
        self._replay_client = replay_client
        self._target_update_period = target_update_period
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            time_delta=1.)

        self._sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)

        # Initialise and internalise training state (parameters/optimiser state).
        random_key, key_init = jax.random.split(random_key)
        initial_params = unroll.init(key_init, initial_core_state)
        opt_state = optimizer.init(initial_params)

        state = TrainingState(params=initial_params,
                              target_params=initial_params,
                              opt_state=opt_state,
                              steps=jnp.array(0),
                              random_key=random_key)
        # Replicate parameters.
        self._state = utils.replicate_in_all_devices(state)

        # Shard multiple inputs with on-device prefetching.
        # We split samples in two outputs, the keys which need to be kept on-host
        # since int64 arrays are not supported in TPUs, and the entire sample
        # separately so it can be sent to the sgd_step method.
        def split_sample(
                sample: reverb.ReplaySample) -> utils.PrefetchingSplit:
            return utils.PrefetchingSplit(host=sample.info.key, device=sample)

        self._prefetched_iterator = utils.sharded_prefetch(
            iterator,
            buffer_size=prefetch_size,
            num_threads=jax.local_device_count(),
            split_fn=split_sample)
Example #7
0
    def __init__(
        self,
        networks: impala_networks.IMPALANetworks,
        iterator: Iterator[reverb.ReplaySample],
        optimizer: optax.GradientTransformation,
        random_key: networks_lib.PRNGKey,
        discount: float = 0.99,
        entropy_cost: float = 0.,
        baseline_cost: float = 1.,
        max_abs_reward: float = np.inf,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
        devices: Optional[Sequence[jax.xla.Device]] = None,
        prefetch_size: int = 2,
        num_prefetch_threads: Optional[int] = None,
    ):
        """Sets up the IMPALA loss, replicated state, data pipeline and pmap.

        Args:
          networks: IMPALA networks providing `unroll_fn`, `unroll_init_fn`,
            `initial_state_init_fn` and `initial_state_fn`.
          iterator: Iterator of replay samples to learn from.
          optimizer: Optax transformation applied to the averaged gradients.
          random_key: Key used to initialise the parameters.
          discount: Forwarded to `losses.impala_loss`.
          entropy_cost: Forwarded to `losses.impala_loss`.
          baseline_cost: Forwarded to `losses.impala_loss`.
          max_abs_reward: Forwarded to `losses.impala_loss`.
          counter: Optional counter; a new `counting.Counter()` otherwise.
          logger: Optional logger; defaults to
            `loggers.make_default_logger('learner')`.
          devices: Devices `sgd_step` is pmapped over; defaults to this
            process's local devices. Only the ones local to this process are
            used for state replication and prefetching.
          prefetch_size: Buffer size for `utils.sharded_prefetch`.
          num_prefetch_threads: Number of prefetch threads; defaults to the
            number of local devices in use.
        """
        host_devices = jax.local_devices()
        pid = jax.process_index()
        logging.info('Learner process id: %s. Devices passed: %s', pid,
                     devices)
        logging.info('Learner process id: %s. Local devices from JAX API: %s',
                     pid, host_devices)
        self._devices = devices or host_devices
        # Subset of the requested devices that belong to this process.
        self._local_devices = [
            dev for dev in self._devices if dev in host_devices
        ]

        impala_loss = losses.impala_loss(
            networks.unroll_fn,
            discount=discount,
            max_abs_reward=max_abs_reward,
            baseline_cost=baseline_cost,
            entropy_cost=entropy_cost)
        loss_and_grads = jax.value_and_grad(impala_loss)

        @jax.jit
        def sgd_step(
            state: TrainingState, sample: reverb.ReplaySample
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
            """One SGD step; returns the new state and logging metrics."""
            loss_value, grads = loss_and_grads(state.params, sample)
            # Average across pmap replicas before the optimiser sees them.
            grads = jax.lax.pmean(grads, _PMAP_AXIS_NAME)
            updates, opt_state = optimizer.update(grads, state.opt_state)
            params = optax.apply_updates(state.params, updates)
            stats = {
                'loss': loss_value,
                'param_norm': optax.global_norm(params),
                'param_updates_norm': optax.global_norm(updates),
            }
            return TrainingState(params=params, opt_state=opt_state), stats

        def make_initial_state(key: jnp.ndarray) -> TrainingState:
            """Creates fresh parameters and the matching optimiser state."""
            params_key, core_key = jax.random.split(key)
            # Parameters do not depend on the batch size, so the initial core
            # state is built without a batch dimension.
            core_params = networks.initial_state_init_fn(core_key)
            # TODO(jferret): training the initial-state params is not yet
            # supported.
            core_state = networks.initial_state_fn(core_params)
            params = networks.unroll_init_fn(params_key, core_state)
            return TrainingState(params=params,
                                 opt_state=optimizer.init(params))

        # Build the initial state once and replicate it on the local devices.
        self._state = utils.replicate_in_all_devices(
            make_initial_state(random_key), self._local_devices)

        thread_count = (len(self._local_devices)
                        if num_prefetch_threads is None else
                        num_prefetch_threads)
        self._prefetched_iterator = utils.sharded_prefetch(
            iterator,
            buffer_size=prefetch_size,
            devices=self._local_devices,
            num_threads=thread_count,
        )

        self._sgd_step = jax.pmap(
            sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices)

        # Logging and counting.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')