Example #1
    def _run_value_model(self, observations, dist_inputs):
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            # Swapping the n_samples and batch_size axes, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        n_devices = math.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values,
                             axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
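The q-value branch above juggles a sample axis and a batch axis. A minimal sketch of the same shape manipulation, with made-up dimensions, shows why the swapaxes call is needed: the batch axis must come first so the accelerated model splits work along it.

    import jax.numpy as jnp

    batch, seq_len, n_inputs, n_samples = 2, 3, 4, 5
    dist_inputs = jnp.zeros((batch, seq_len, n_inputs))
    tiled = jnp.broadcast_to(dist_inputs, (n_samples,) + dist_inputs.shape)
    print(tiled.shape)   # (5, 2, 3, 4) -- n_samples leads, batch is second
    tiled = jnp.swapaxes(tiled, 0, 1)
    print(tiled.shape)   # (2, 5, 3, 4) -- batch leads, so devices split the batch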
Example #2
 def _run_value_model(self, obs):
     """Runs value model."""
     n_devices = fastmath.device_count()
     weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
     state = tl.for_n_devices(self._value_eval_model.state, n_devices)
     rng = self._value_eval_model.rng
     # TODO(henrykm): the line below fails on TPU with the error
     # ValueError: Number of devices (8) does not evenly divide batch size (1).
     obs_batch = obs.shape[0]
     if n_devices > obs_batch:
         obs = jnp.repeat(obs, int(n_devices / obs_batch), axis=0)
     values, _ = self._value_eval_jit(obs, weights, state, rng)
     values = values[:obs_batch]
     values *= self._value_network_scale
     return values
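The repeat-then-slice workaround above only helps when the batch size divides n_devices. A standalone sketch of the pattern, with toy shapes and a stand-in for the model call:

    import jax.numpy as jnp

    n_devices = 8
    obs = jnp.zeros((1, 10))                 # batch of 1 observation vector
    obs_batch = obs.shape[0]
    if n_devices > obs_batch:
        # Pad the batch by repetition so it splits evenly across devices;
        # this assumes obs_batch divides n_devices.
        obs = jnp.repeat(obs, n_devices // obs_batch, axis=0)
    print(obs.shape)                          # (8, 10)
    values = obs.sum(axis=-1, keepdims=True)  # stand-in for the value model
    values = values[:obs_batch]               # keep only the original rows
    print(values.shape)                       # (1, 1)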
Example #3
  def __init__(self, loss_layer, optimizer, n_devices=None):
    self._loss_layer = loss_layer
    self._optimizer = optimizer
    self._n_devices = n_devices or fastmath.device_count()

    # optimizer slots and opt_params may need to be replicated
    self._slots, self._opt_params = tl.for_n_devices(
        (self._optimizer.slots, self._optimizer.opt_params), self._n_devices)

    # accelerated version of loss layer to replicate weights and state
    self._accelerated_loss_layer = tl.Accelerate(
        loss_layer, n_devices=n_devices)

    # Signature:
    # (batch, weights, state, rng) -> ((loss, state), gradients)
    self._forward_and_backward_fn = (
        fastmath.value_and_grad(
            loss_layer.pure_fn,
            argnums=1,  # arg1 of pure_fn: weights
            has_aux=True))  # return (loss, state), gradients

    # Signature:
    # (weights, slots), step, opt_params, batch, state, rng ->
    # (weights, slots), state, stats
    self._accelerated_update_fn = (
        _accelerate_update_fn(
            self._forward_and_backward_fn,
            self._optimizer,
            n_devices=self._n_devices,
            accelerate=True,
        )
    )
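The forward-and-backward function built here follows JAX's value_and_grad with has_aux=True. A tiny self-contained check of that signature, with plain jax standing in for fastmath and toy scalars instead of a real loss layer:

    import jax

    def pure_fn(batch, weights, state, rng):
        # Toy stand-in for loss_layer.pure_fn: returns (loss, new_state).
        loss = weights * batch
        return loss, state

    fwd_bwd = jax.value_and_grad(pure_fn, argnums=1, has_aux=True)
    (loss, state), grads = fwd_bwd(2.0, 3.0, {}, None)
    print(loss, grads)   # 6.0 2.0 -- gradient taken w.r.t. weights (argnums=1)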
Example #4
    def _run_value_model(self, observations, dist_inputs):
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            if self._sample_all_discrete_actions:
                # Since we want to sample all actions, start by creating their list.
                act = np.arange(self._vocab_size)
                # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
                # Add extra dimensions so it's the same dimensionality as dist_inputs.
                act = jnp.reshape(act,
                                  [-1] + [1] * (len(dist_inputs.shape) - 1))
                # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            if self._sample_all_discrete_actions:
                actions = act + jnp.zeros(dist_inputs.shape[:-1],
                                          dtype=jnp.int32)
                actions = jnp.swapaxes(actions, 0, 1)
            # Swapping the n_samples and batch_size axes, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            if not self._sample_all_discrete_actions:
                actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        n_devices = fastmath.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values,
                             axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
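The sample-all-discrete-actions path builds an actions tensor holding every action at every position. A shape-only sketch with toy dimensions; note that n_samples must equal vocab_size here for the broadcast to line up:

    import numpy as np
    import jax.numpy as jnp

    vocab_size = 5                            # doubles as n_samples here
    batch, seq_len = 2, 3
    dist_inputs = jnp.zeros((batch, seq_len, vocab_size))
    act = jnp.reshape(np.arange(vocab_size),
                      [-1] + [1] * (len(dist_inputs.shape) - 1))
    print(act.shape)                          # (5, 1, 1)
    tiled = jnp.broadcast_to(dist_inputs, (vocab_size,) + dist_inputs.shape)
    actions = act + jnp.zeros(tiled.shape[:-1], dtype=jnp.int32)
    print(actions.shape)                      # (5, 2, 3) -- every action everywhere
    print(jnp.swapaxes(actions, 0, 1).shape)  # (2, 5, 3) -- batch axis first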
Example #5
    def one_step(self, batch, rng, step=0, learning_rate=None):
        """Updates loss layer weights/state and optimizer slots by running one step.

    Args:
      batch: Batch of data to use for optimization.
      rng: Random number generator to use for running this step.
      step: Which step of the training are we running.
      learning_rate: Learning rate to use instead of the default one.

    Returns:
      Tuple (loss, stats) with new values from one step
      of training, where stats are current optimizer statistics.
    """
        # Update the learning rate if needed.
        if learning_rate is not None:
            self._opt_params['learning_rate'] = tl.for_n_devices(
                learning_rate, self._n_devices)

        # The batch needs to be split across the local devices -- the difference
        # between _for_n_devices and _reshape_by_device is that the latter splits
        # the batch dim into batch // n_devices, while _for_n_devices
        # broadcasts/replicates to an n_devices dimension.
        if self._n_devices > 1:
            batch = tl.reshape_by_device(batch, self._n_devices)

        # A separate rng needs to be created for each device.
        if self._n_devices > 1:
            rng = jnp.stack(fastmath.random.split(rng, self._n_devices))

        weights = self._accelerated_loss_layer.weights
        state = self._accelerated_loss_layer.state
        if logging.vlog_is_on(1) and ((step & step - 1) == 0):
            # Prints every power of two, if debugging is enabled.
            logging.info('step[%d]', step)
            logging.info('opt_params[%s]', self._opt_params)
            logging.info('slots[%s]', self._slots)
            logging.info('weights[%s]', weights)
            logging.info('state[%s]', state)

        # NOTE: stats is a replicated dictionary of key to jnp arrays.
        (new_weights,
         new_slots), new_state, stats = self._accelerated_update_fn(
             (weights, self._slots), step, self._opt_params, batch, state, rng)

        if logging.vlog_is_on(1) and ((step & step - 1) == 0):
            logging.info('updated weights[%s]', new_weights)
            logging.info('stats[%s]', stats)

        self._accelerated_loss_layer.weights = new_weights
        self._accelerated_loss_layer.state = new_state
        self._slots = new_slots
        self._optimizer.slots = self._unreplicate(self._slots)
        return stats['loss'], stats
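Hypothetical usage of one_step in a training loop. Trainer, loss_layer, optimizer and train_batches are assumed names for this class and its inputs, not names taken from the snippet; the fresh-key-per-step discipline matches the docstring's note that rng is for "running this step".

    from trax import fastmath

    # `Trainer`, `loss_layer`, `optimizer`, `train_batches` are placeholders.
    trainer = Trainer(loss_layer, optimizer)
    rng = fastmath.random.get_prng(0)
    for step, batch in enumerate(train_batches, start=1):
        rng, step_rng = fastmath.random.split(rng, 2)  # fresh key per step
        loss, stats = trainer.one_step(batch, step_rng, step=step)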
Example #6
 def _run_value_model(self, obs, use_eval_model=True):
   """Runs value model."""
   n_devices = fastmath.device_count()
   if use_eval_model:
     weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
     state = tl.for_n_devices(self._value_eval_model.state, n_devices)
     rng = self._value_eval_model.rng
   else:
     # TODO(henrykm): this strange-looking solution addresses the problem that
     # value_batches_stream calls _run_value_model _once_ before
     # the trainer is initialized.
     try:
       weights = tl.for_n_devices(self._value_trainer.model_weights, n_devices)
       state = tl.for_n_devices(self._value_trainer.model_state, n_devices)
       rng = self._value_trainer._rng  # pylint: disable=protected-access
     except AttributeError:
       weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
       state = tl.for_n_devices(self._value_eval_model.state, n_devices)
       rng = self._value_eval_model.rng
   # TODO(henrykm): the line below fails on TPU with the error
   # ValueError: Number of devices (8) does not evenly divide batch size (1).
   obs_batch = obs.shape[0]
   if n_devices > obs_batch:
     obs = jnp.repeat(obs, int(n_devices / obs_batch), axis=0)
   values, _ = self._value_eval_jit(obs, weights, state, rng)
   values = values[:obs_batch]
   values *= self._value_network_scale
   return values
Example #7
    def one_step(self, batch, rng, step=0, learning_rate=None):
        """Runs one training step, to update model and optimizer parameters.

    Args:
      batch: Batch of labeled training data.
      rng: Single-use random number generator (JAX PRNG key).
      step: Training step number.
      learning_rate: Learning rate for the optimizer; if None, use optimizer's
          default learning rate.

    Returns:
      Tuple of (loss, optimizer_stats), with the newly computed loss and
      updated stats as reported by the optimizer.
    """
        if learning_rate is not None:
            self._opt_params['learning_rate'] = tl.for_n_devices(
                learning_rate, self._n_devices)

        # Split the batch across devices (batch_dim --> batch_dim // n_devices)
        # and create new rng's 1-1 with devices.
        if self._n_devices > 1:
            batch = tl.reshape_by_device(batch, self._n_devices)
            rng = jnp.stack(fastmath.random.split(rng, self._n_devices))

        weights = self._accelerated_model_with_loss.weights
        state = self._accelerated_model_with_loss.state
        if logging.vlog_is_on(1) and ((step & step - 1) == 0):
            # Prints every power of two, if debugging is enabled.
            logging.info('step[%d]', step)
            logging.info('opt_params[%s]', self._opt_params)
            logging.info('slots[%s]', self._slots)
            logging.info('weights[%s]', weights)
            logging.info('state[%s]', state)

        # NOTE: stats is a replicated dictionary of key to jnp arrays.
        (new_weights,
         new_slots), new_state, stats = self._accelerated_update_fn(
             (weights, self._slots), step, self._opt_params, batch, state, rng)

        if logging.vlog_is_on(1) and ((step & step - 1) == 0):
            logging.info('updated weights[%s]', new_weights)
            logging.info('stats[%s]', stats)

        self._accelerated_model_with_loss.weights = new_weights
        self._accelerated_model_with_loss.state = new_state
        self._slots = new_slots
        self._optimizer.slots = self._unreplicate(self._slots)
        return stats['loss'], stats
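The logging guard `(step & step - 1) == 0` used above fires only at power-of-two steps. A quick check of the bit trick; note that `-` binds tighter than `&`, so no extra parentheses are needed:

    for step in range(1, 17):
        if (step & step - 1) == 0:
            print(step)   # 1, 2, 4, 8, 16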
Example #8
    def __init__(self,
                 model_with_loss,
                 optimizer,
                 n_devices=None,
                 adasum=False):
        self._model_with_loss = model_with_loss
        self._optimizer = optimizer
        self._n_devices = n_devices or fastmath.local_device_count()
        self._adasum = adasum

        # optimizer slots and opt_params may need to be replicated
        self._slots, self._opt_params = tl.on_cpu(
            tl.for_n_devices(
                (self._optimizer.slots, self._optimizer.opt_params),
                self._n_devices))

        # accelerated version of model+loss to replicate weights and state
        self._accelerated_model_with_loss = tl.Accelerate(model_with_loss,
                                                          n_devices=n_devices)

        # Signature:
        # (batch, weights, state, rng) -> ((loss, state), gradients)
        self._forward_and_backward_fn = (
            fastmath.value_and_grad(
                model_with_loss.pure_fn,
                argnums=1,  # arg1 of pure_fn: weights
                has_aux=True))  # return (loss, state), gradients

        # Signature:
        # (weights, slots), step, opt_params, batch, state, rng ->
        # (weights, slots), state, stats
        self._accelerated_update_fn = (_accelerate_update_fn(
            self._forward_and_backward_fn,
            self._optimizer,
            n_devices=self._n_devices,
            accelerate=True,
            adasum=self._adasum))
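The on_cpu wrapper here parks the replicated slots and opt_params in host memory until the update runs. A minimal sketch of that replicate-then-park step on a toy pytree, assuming 2 local devices:

    from trax import layers as tl
    from trax.fastmath import numpy as jnp

    slots = {'m': jnp.zeros(3), 'v': jnp.zeros(3)}   # toy optimizer slots
    replicated = tl.for_n_devices(slots, 2)          # adds a leading device axis
    replicated = tl.on_cpu(replicated)               # keep the copy on the host
    print(replicated['m'].shape)                     # (2, 3)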
Example #9
 def _for_n_devices(self, x):
     """Replicates/broadcasts `x` for n devices if `self.n_devicess > 1`."""
     return tl.for_n_devices(x, self.n_devices)  # pylint: disable=protected-access
Example #10
  def one_step(self, batch, rng, step=0, learning_rate=None):
    """Updates layers weights/state and optimizers slots by running one step.

    Args:
      batch: Batch of data to use for optimization.
      rng: Random number generator to use for running this step.
      step: Which step of the training are we running.
      learning_rate: Learning rate to use instead of the default one.

    Returns:
      Tuple (loss, stats) with new values from one step
      of training, where stats are all optimizer statistics.
    """
    # Update the learning rate if needed.
    if learning_rate is not None:
      self._replicated_loss_opt_params['learning_rate'] = tl.for_n_devices(
          learning_rate, self._n_devices)
      for (std_op, rev_ops) in self._replicated_opt_params:
        std_op['learning_rate'] = tl.for_n_devices(
            learning_rate, self._n_devices)
        for op in rev_ops:
          op['learning_rate'] = tl.for_n_devices(
              learning_rate, self._n_devices)

    # The batch needs to be split across the local devices -- the difference
    # between _for_n_devices and _reshape_by_device is that the latter splits
    # the batch dim into batch // n_devices, while _for_n_devices
    # broadcasts/replicates to an n_devices dimension.
    if self._n_devices > 1:
      batch = tl.reshape_by_device(batch, self._n_devices)
      step = jnp.repeat(step, self._n_devices)

    # Create separate rng for each device and layer.
    if self._n_devices == 1:
      rngs = fastmath.random.split(rng, self._n_layers)
    else:
      # Splitting by device first to be identical with default trainer.
      per_device_rng = fastmath.random.split(rng, self._n_devices)
      per_device_rngs = [
          fastmath.random.split(r, self._n_layers) for r in per_device_rng]
      rngs = [jnp.stack([r[i] for r in per_device_rngs])
              for i in range(self._n_layers)]
    # Group rngs by layer blocks.
    rng_blocks, rng_i = [], 0
    for _, rev_layers in self._blocks:
      l = len(rev_layers)
      rng_blocks.append((rngs[rng_i], rngs[rng_i + 1: rng_i + l + 1]))
      rng_i += l + 1

    # Run the layers forward up to the loss layer.
    stack = batch
    block_inputs_states = []
    for i, (std_layer, rev_layers) in enumerate(self._blocks):
      acc_std_layer_fn, acc_rev_layer_fns = self._accelerated_layer_fns[i]
      std_rng, rev_rngs = rng_blocks[i]
      # Run the standard layer.
      stack, std_inputs, std_state = self._run_forward_standard(
          stack, std_layer, acc_std_layer_fn, std_rng)

      # Run the reversible layers and collect old and new states.
      stack, rev_old_states, rev_new_states = self._run_forward_reversible(
          stack, rev_layers, acc_rev_layer_fns, rev_rngs)
      block_inputs_states.append(
          ((std_inputs, std_state), (rev_old_states, rev_new_states)))

    # Run the loss layer forward and backward with optimizer update.
    loss_state = self._replicate(self._loss_layer.state)
    loss_inputs = cb.inputs_from_stack(stack, self._loss_layer.n_in)
    loss_stats, grad_stack = self._run_backward_standard(
        None, step, self._loss_layer, loss_inputs,
        loss_state, self._loss_fbo, rngs[-1], self._loss_opt,
        self._replicated_loss_opt_params)
    stats = [loss_stats]

    # Run the layers backward and run optimizer updates.
    for i in range(len(self._blocks) - 1, -1, -1):
      std_layer, rev_layers = self._blocks[i]
      (std_inputs, std_state), (rev_old_states,
                                rev_new_states) = block_inputs_states[i]
      std_fbo, rev_fbos = self._fbos[i]
      std_opt, rev_opts = self._optimizers[i]
      std_rng, rev_rngs = rng_blocks[i]
      repl_std_opt_params, repl_rev_opts_params = self._replicated_opt_params[i]

      # Run reversible layers backward with optimizer update.
      stack, grad_stack, new_stats = self._run_backward_reversible(
          stack, grad_stack, step, rev_layers, rev_fbos, rev_old_states,
          rev_new_states, rev_rngs, rev_opts, repl_rev_opts_params)
      stats.extend(new_stats)

      # Run the standard layer forward-and-backward pass and optimizer update.
      std_layer_stats, grad_stack = self._run_backward_standard(
          grad_stack, step, std_layer, std_inputs, std_state, std_fbo, std_rng,
          std_opt, repl_std_opt_params)
      stack = cb.outputs_onto_stack(  # Put layer inputs on the stack.
          std_inputs, stack, std_layer.n_out)
      stats.append(std_layer_stats)

    # Join stats from different optimizers into one.
    joint_stats = {}
    for i, stat in enumerate(reversed(stats)):
      for k, v in stat.items():
        joint_stats[f'layer{i}/' + k] = v
    return stats[0]['loss'], joint_stats
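The per-device, per-layer rng construction above ends up with one stacked key array per layer. A standalone sketch with plain jax (which fastmath wraps) and toy sizes:

    import jax
    import jax.numpy as jnp

    n_devices, n_layers = 2, 3
    rng = jax.random.PRNGKey(0)
    # Split by device first (to match the default trainer), then by layer.
    per_device_rng = jax.random.split(rng, n_devices)
    per_device_rngs = [jax.random.split(r, n_layers) for r in per_device_rng]
    rngs = [jnp.stack([r[i] for r in per_device_rngs])
            for i in range(n_layers)]
    print(len(rngs), rngs[0].shape)   # 3 (2, 2): per layer, one key per device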
Example #11
 def _replicate(self, x):
   if self._n_devices > 1:
     return tl.for_n_devices(x, self._n_devices)
   return tl.on_accelerator(x)
Example #12
 def slots(self, slots):
     """Sets the slots of the optimizers and this class (replicated)."""
     self._optimizer.slots = slots
     self._slots = tl.on_cpu(tl.for_n_devices(slots, self._n_devices))
Example #13
    def one_step(self, batch, rng, step=0, learning_rate=None):
        """Updates layers weights/state and optimizers slots by running one step.

    Args:
      batch: Batch of data to use for optimization.
      rng: Random number generator to use for running this step.
      step: Which step of the training are we running.
      learning_rate: Learning rate to use instead of the default one.

    Returns:
      Tuple (loss, stats) with new values from one step
      of training, where stats are all optimizer statistics.
    """
        # Update the learning rate if needed.
        if learning_rate is not None:
            for op in self._replicated_opt_params:
                op['learning_rate'] = tl.for_n_devices(learning_rate,
                                                       self._n_devices)

        # The batch needs to be split across the local devices -- the difference
        # between _for_n_devices and _reshape_by_device is that the latter splits
        # the batch dim into batch // n_devices, while _for_n_devices
        # broadcasts/replicates to an n_devices dimension.
        if self._n_devices > 1:
            batch = tl.reshape_by_device(batch, self._n_devices)
            step = jnp.repeat(step, self._n_devices)

        # Separate rng needs to be created for each device.
        if self._n_devices == 1:
            rngs = fastmath.random.split(rng, len(self._reversible_layers) + 2)
        else:
            # Splitting by device first to be identical with default trainer.
            per_device_rng = fastmath.random.split(rng, self._n_devices)
            per_device_rngs = [
                fastmath.random.split(r,
                                      len(self._reversible_layers) + 2)
                for r in per_device_rng
            ]
            rngs = [
                jnp.stack([r[i] for r in per_device_rngs])
                for i in range(len(self._reversible_layers) + 2)
            ]

        # Run the layers forward up to the loss layer.
        stack = batch

        # Run the first layer.
        first_layer_inputs = _inputs_from_stack(self._first_layer, stack)
        first_layer_weights = self._replicate(self._first_layer.weights)
        first_layer_state = self._replicate(self._first_layer.state)
        outputs, first_layer_new_state = self._accelerated_first_layer_fn(
            first_layer_inputs, first_layer_weights, first_layer_state,
            rngs[0])
        stack = _outputs_onto_stack(self._first_layer, outputs, stack)

        # Run the reversible layers and collect old and new states.
        old_states, new_states = [], []
        for i, layer in enumerate(self._reversible_layers):
            weights = self._replicate(
                layer.weights)  # also copies cpu -> accelerator
            state = self._replicate(layer.state)
            old_states.append(state)
            inputs = _inputs_from_stack(layer, stack)
            outputs, new_state = self._accelerated_reversible_layers_fns[i](
                inputs, weights, state, rngs[i + 1])
            stack = _outputs_onto_stack(layer, outputs, stack)
            new_states.append(new_state)

        # Run the loss layer forward and backward with optimizer update.
        loss_weights = self._replicate(self._loss_layer.weights)
        loss_state = self._replicate(self._loss_layer.state)
        loss_inputs = _inputs_from_stack(self._loss_layer, stack)
        loss_slots = self._replicate(self._optimizers[-1].slots)
        new_weights, new_state, new_slots, grad_stack, loss_stats = self._loss_fbo(
            loss_inputs, loss_weights, loss_state, loss_slots,
            self._replicated_opt_params[-1], rngs[-1], step)
        stats = [loss_stats]
        self._loss_layer.weights = self._unreplicate(
            new_weights)  # acceler. -> cpu
        self._loss_layer.state = self._unreplicate(new_state)
        self._optimizers[-1].slots = self._unreplicate(new_slots)

        # Run reversible layers backward with optimizer update.
        counter = -1
        for layer, reverse_and_fbo, old_state, new_state, rng in reversed(
                list(
                    zip(self._reversible_layers, self._reverse_and_fbos,
                        old_states, new_states, rngs[1:-1]))):
            counter -= 1
            # We are running backwards and reversing, so we get *outputs* from stack.
            outputs = _inputs_from_stack(layer, stack, layer.n_out)
            grads = _inputs_from_stack(layer, grad_stack, layer.n_out)
            slots = self._replicate(self._optimizers[counter].slots)
            opt_params = self._replicated_opt_params[counter]
            weights = self._replicate(layer.weights)  # cpu -> accelerator
            new_weights, new_slots, inputs, grads, layer_stats = reverse_and_fbo(
                outputs, weights, old_state, new_state, slots, opt_params, rng,
                step, grads)
            layer.weights = self._unreplicate(
                new_weights)  # accelerator -> cpu
            layer.state = self._unreplicate(new_state)
            self._optimizers[counter].slots = self._unreplicate(new_slots)
            stats.append(layer_stats)
            stack = _outputs_onto_stack(layer, inputs, stack, layer.n_out,
                                        layer.n_in)
            grad_stack = _outputs_onto_stack(layer, grads, grad_stack,
                                             layer.n_out, layer.n_in)

        # Run the first layer forward-and-backward pass and optimizer update.
        grads = _inputs_from_stack(self._first_layer, grad_stack,
                                   self._first_layer.n_out)
        slots = self._replicate(self._optimizers[0].slots)
        new_weights, new_state, new_slots, first_layer_stats = self._first_fbo(
            first_layer_inputs, first_layer_weights, first_layer_new_state,
            slots, self._replicated_opt_params[0], rngs[0], step, grads)
        stats.append(first_layer_stats)
        self._first_layer.weights = self._unreplicate(new_weights)
        self._first_layer.state = self._unreplicate(new_state)
        self._optimizers[0].slots = self._unreplicate(new_slots)

        return stats[0]['loss'], stats
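The backward loop above can discard intermediate activations because each reversible layer reconstructs its inputs from its outputs. A toy illustration of that property, using a simplified residual swap rather than trax's actual reversible layers:

    import jax.numpy as jnp

    def f(x):                        # any deterministic sublayer
        return jnp.tanh(x)

    def forward(x1, x2):
        return x1 + f(x2), x2

    def reverse(y1, y2):             # recover the inputs exactly
        return y1 - f(y2), y2

    x1, x2 = jnp.ones(3), jnp.arange(3.0)
    y1, y2 = forward(x1, x2)
    r1, r2 = reverse(y1, y2)
    print(jnp.allclose(r1, x1), jnp.allclose(r2, x2))   # True True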