Example No. 1
def f(x):
    # Replicate x across accelerators: a multi-device put on the JAX backend,
    # otherwise just prepend an n_devices axis by broadcasting.
    if n_devices > 1 and math.backend_name() == 'jax':
        return _multi_device_put(x)
    elif n_devices > 1:
        return jnp.broadcast_to(x, (n_devices,) + x.shape)
    else:
        return x
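A minimal, self-contained sketch (not from the Trax source; the sizes are made up) of the broadcast in the elif branch above: the per-host batch gains a leading axis of size n_devices, one copy per accelerator.

import numpy as np

n_devices = 4                                    # hypothetical device count
x = np.ones((8, 16))                             # per-host batch of activations
replicated = np.broadcast_to(x, (n_devices,) + x.shape)
print(replicated.shape)                          # (4, 8, 16)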
Example No. 2
def RepresentationMask(mask, serializer, **unused_kwargs):
    """Upsamples a mask to cover the serialized representation."""
    # Trax enforces the mask to be of the same size as the target. Get rid of the
    # extra dimensions.
    mask = np.amax(mask, axis=tuple(range(2, mask.ndim)))
    return np.broadcast_to(mask[:, :, None],
                           mask.shape + (serializer.representation_length, ))
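A small shape-only sketch (assumed sizes, not from the source) of what RepresentationMask produces: a per-symbol mask of shape (batch, length) is expanded so it covers every position of the serialized representation.

import numpy as np

representation_length = 3                        # hypothetical serializer setting
mask = np.ones((2, 5))                           # (batch, length)
upsampled = np.broadcast_to(mask[:, :, None],
                            mask.shape + (representation_length,))
print(upsampled.shape)                           # (2, 5, 3)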
Example No. 3
    def policy_batches_stream(self):
        """Use the RLTask self._task to create inputs to the policy model."""
        # For now TD-0 estimation of the value. TODO(pkozakowski): Support others?
        for np_trajectory in self._task.trajectory_batch_stream(
                self._policy_batch_size,
                epochs=self._replay_epochs,
                max_slice_length=self._max_slice_length,
                include_final_state=False,
        ):
            (q_values, actions,
             act_log_probs) = self._run_value_model(np_trajectory.observations,
                                                    np_trajectory.dist_inputs)
            shapes.assert_same_shape(q_values, act_log_probs)

            # q_values shape: (batch_size, n_samples, length)
            if len(q_values.shape) != 3:
                raise ValueError(
                    'Q-values are expected to have shape [batch_size, ' +
                    'n_samples, length], got: %s' % str(q_values.shape))
            if q_values.shape[1] != self._q_value_n_samples:
                raise ValueError(
                    'Q-values dimension 1 should = n_samples, %d != %d' %
                    (q_values.shape[1], self._q_value_n_samples))
            if q_values.shape[0] != self._policy_batch_size:
                raise ValueError(
                    'Q-values dimension 0 should = policy batch size, ' +
                    '%d != %d' % (q_values.shape[0], self._policy_batch_size))

            mask = np_trajectory.mask
            mask = np.reshape(mask, [mask.shape[0], 1] + list(mask.shape[1:]))
            mask = jnp.broadcast_to(mask, q_values.shape)
            shapes.assert_same_shape(mask, q_values)

            yield (np_trajectory.observations, actions, q_values,
                   act_log_probs, mask)
Example No. 4
    def _run_value_model(self, observations, dist_inputs):
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            # Swapping the n_samples and batch_size axes, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        n_devices = math.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values,
                             axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
Example No. 5
    def forward_with_state(self,
                           inputs,
                           weights=layer_base.EMPTY_WEIGHTS,
                           state=layer_base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        depth = inputs.shape[-1]

        if self._mode == 'predict':
            emb = self._get_embeddings(t=state)
            emb = emb[:, np.newaxis, :]
            state = state + 1
        else:
            input_len = inputs.shape[-2]
            emb = self._get_embeddings(t=np.arange(input_len, dtype=np.int32))
            # Leave batch axis as 1 for broadcasting:
            emb = emb[np.newaxis, :, :]
            emb = np.broadcast_to(emb, inputs.shape[:-1] + (3, ))

        # Replace the last num_features channels of input.
        inputs = np.concatenate([inputs[..., :-self.num_features], emb], -1)
        if inputs.shape[-1] > depth:
            logging.warning('dropping feature(s): %d down to %d',
                            inputs.shape[-1], depth)
            inputs = inputs[..., -depth:]

        assert inputs.shape[-1] == depth, inputs.shape
        return inputs, state
Example No. 6
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        embs = []
        for ax_emb in weights:
            ax_emb = np.broadcast_to(ax_emb, (inputs.shape[0], ) +
                                     self._shape + (ax_emb.shape[-1], ))
            embs.append(ax_emb)
        emb = np.concatenate(embs, -1)

        if self._mode == 'predict':
            assert self._dropout == 0.0
            emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
            return inputs + emb[:, state, :][:, None, :], state + 1
        elif self._dropout == 0:
            return inputs + np.reshape(emb, inputs.shape), state
        else:
            noise_shape = list(emb.shape)
            for dim in self._dropout_broadcast_dims:
                noise_shape[dim] = 1
            keep_prob = 1.0 - self._dropout
            if math.backend_name() == 'jax':
                keep_prob = jax.lax.tie_in(
                    inputs, np.full((), keep_prob, dtype=inputs.dtype))
            keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
            multiplier = keep.astype(inputs.dtype) / keep_prob

            return inputs + np.reshape(emb * multiplier, inputs.shape), state
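A standalone numpy sketch (the names and sizes are assumptions) of the broadcasted-dropout trick in the final branch above: setting the noise shape to 1 along the broadcast dims makes one keep/drop decision shared across those dims, and dividing by keep_prob applies the usual inverted-dropout scaling.

import numpy as np

rng = np.random.default_rng(0)
emb = np.ones((2, 5, 8))                         # (batch, length, depth)
dropout, broadcast_dims = 0.2, (-2,)             # share the mask along the length axis
noise_shape = list(emb.shape)
for dim in broadcast_dims:
    noise_shape[dim] = 1                         # -> (2, 1, 8)
keep_prob = 1.0 - dropout
keep = rng.random(tuple(noise_shape)) < keep_prob
emb = emb * keep.astype(emb.dtype) / keep_prob   # mask broadcasts over the length axis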
Example No. 7
    def forward(self, inputs, weights):
        state = self.state
        depth = inputs.shape[-1]

        if self._mode == 'predict':
            emb = self._get_embeddings(t=state)
            emb = emb[:, jnp.newaxis, :]
            state = state + 1
        else:
            input_len = inputs.shape[-2]
            emb = self._get_embeddings(
                t=jnp.arange(input_len, dtype=jnp.int32))
            # Leave batch axis as 1 for broadcasting:
            emb = emb[jnp.newaxis, :, :]
            emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3, ))

        # Replace the last num_features channels of input.
        inputs = jnp.concatenate([inputs[..., :-self.num_features], emb], -1)
        if inputs.shape[-1] > depth:
            logging.warning('dropping feature(s): %d down to %d',
                            inputs.shape[-1], depth)
            inputs = inputs[..., -depth:]

        assert inputs.shape[-1] == depth, inputs.shape
        self.state = state
        return inputs
Example No. 8
def predict(x, weights, state):
    """Predict function jitted and parallelized as requested."""
    res, state = trax.layers.base._combine_devices(
        model_predict(
            trax.layers.base.reshape_by_device(x, n_devices), weights,
            state,
            np.broadcast_to(jax.random.PRNGKey(0)[None, :], (8, 2))))
    return res
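A minimal sketch (assuming 8 devices, matching the hard-coded 8 above) of the PRNG-key broadcast in the last argument: a single key of shape (2,) is tiled into one identical key per device.

import jax
import jax.numpy as jnp

n_devices = 8
key = jax.random.PRNGKey(0)                      # uint32 array of shape (2,)
per_device_keys = jnp.broadcast_to(key[None, :], (n_devices, 2))
print(per_device_keys.shape)                     # (8, 2)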
Example No. 9
def LossInput(dist_inputs, actions, advantages, old_dist_inputs, mask):  # pylint: disable=invalid-name
  """Calculates action log probabilities and normalizes advantages."""
  del old_dist_inputs
  advantages = self._preprocess_advantages(advantages)
  dist_inputs = jnp.broadcast_to(
      dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape
  )
  log_probs = self._policy_dist.log_prob(dist_inputs, actions)
  # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
  advantages = jnp.swapaxes(advantages, 0, 1)
  mask = jnp.swapaxes(mask, 0, 1)
  return (log_probs, advantages, log_probs, mask)
Example No. 10
        def LossInput(dist_inputs, actions, q_values, act_log_probs, mask):  # pylint: disable=invalid-name
            """Calculates action log probabilities and normalizes advantages."""
            # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
            q_values = jnp.swapaxes(q_values, 0, 1)
            mask = jnp.swapaxes(mask, 0, 1)
            actions = jnp.swapaxes(actions, 0, 1)
            act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)

            # TODO(pkozakowski,lukaszkaiser): Try max here, or reweighting?
            # Reweight: values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
            values = jnp.mean(q_values, axis=0)
            advantages = q_values - values  # Broadcasting values over n_samples
            advantages = self._preprocess_advantages(advantages)

            # Broadcast inputs and calculate log-probs
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            return (log_probs, advantages, act_log_probs, mask)
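A shape-only sketch (made-up sizes, plain numpy) of the two broadcasts in this LossInput: distribution parameters are tiled over n_samples, and the per-sample mean value broadcasts back over the samples when advantages are computed.

import numpy as np

n_samples, batch, length, n_inputs = 3, 4, 10, 6
dist_inputs = np.zeros((batch, length, n_inputs))
tiled = np.broadcast_to(dist_inputs, (n_samples,) + dist_inputs.shape)
print(tiled.shape)                               # (3, 4, 10, 6)

q_values = np.ones((n_samples, batch, length))   # already swapped to samples-first
values = np.mean(q_values, axis=0)               # (4, 10)
advantages = q_values - values                   # values broadcast over n_samples
print(advantages.shape)                          # (3, 4, 10)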
Example No. 11
    def _run_value_model(self, observations, dist_inputs):
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            if self._sample_all_discrete_actions:
                # Since we want to sample all actions, start by creating their list.
                act = np.arange(self._vocab_size)
                # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
                # Add extra dimensions so it's the same dimensionality as dist_inputs.
                act = jnp.reshape(act,
                                  [-1] + [1] * (len(dist_inputs.shape) - 1))
                # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            if self._sample_all_discrete_actions:
                actions = act + jnp.zeros(dist_inputs.shape[:-1],
                                          dtype=jnp.int32)
                actions = jnp.swapaxes(actions, 0, 1)
            # Swapping the n_samples and batch_size axes, so the input is split
            # between accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            if not self._sample_all_discrete_actions:
                actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        n_devices = math.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values,
                             axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
Example No. 12
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        embs = []
        for ax_emb in weights:
            ax_emb = np.broadcast_to(ax_emb, (inputs.shape[0], ) +
                                     self._shape + (ax_emb.shape[-1], ))
            embs.append(ax_emb)

        if self._mode == 'predict':
            assert self._dropout == 0.0
            emb = np.concatenate(embs, -1)
            emb = np.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
            emb = jax.lax.dynamic_slice_in_dim(emb,
                                               state,
                                               inputs.shape[1],
                                               axis=1)
            return inputs + emb, state + inputs.shape[1]
        elif self._dropout == 0:
            # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
            # leads to memory blow-up on TPU.
            # emb = np.concatenate(embs, -1)
            # return inputs + np.reshape(emb, inputs.shape), state
            return inputs + np.concatenate([
                np.reshape(emb, inputs.shape[:-1] + (emb.shape[-1], ))
                for emb in embs
            ], -1), state
        else:
            emb = np.concatenate(embs, -1)
            noise_shape = list(emb.shape)
            for dim in self._dropout_broadcast_dims:
                noise_shape[dim] = 1
            keep_prob = 1.0 - self._dropout
            if math.backend_name() == 'jax':
                keep_prob = jax.lax.tie_in(
                    inputs, np.full((), keep_prob, dtype=inputs.dtype))
            keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
            multiplier = keep.astype(inputs.dtype) / keep_prob

            return inputs + np.reshape(emb * multiplier, inputs.shape), state
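A compact sketch (hypothetical grid and widths) of the axial broadcast in the loop above: in axial positional encodings each per-axis embedding typically varies along a single axis of self._shape with size-1 dimensions elsewhere, so broadcasting expands it over the batch and the remaining axes before the pieces are concatenated along depth.

import numpy as np

batch, shape = 2, (4, 8)                         # assumed axial grid
ax_emb_rows = np.ones((1, 4, 1, 3))              # varies along the first axis, depth 3
ax_emb_cols = np.ones((1, 1, 8, 5))              # varies along the second axis, depth 5
embs = [np.broadcast_to(e, (batch,) + shape + (e.shape[-1],))
        for e in (ax_emb_rows, ax_emb_cols)]
emb = np.concatenate(embs, -1)
print(emb.shape)                                 # (2, 4, 8, 8)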
Example No. 13
  def policy_batches_stream(self):
    """Use the RLTask self._task to create inputs to the policy model."""
    # For now TD-0 estimation of the value. TODO(pkozakowski): Support others?
    for np_trajectory in self._task.trajectory_batch_stream(
        self._policy_batch_size,
        epochs=self._replay_epochs,
        max_slice_length=self._max_slice_length,
        include_final_state=False,
    ):
      (q_values, actions) = self._run_value_model(
          np_trajectory.observations, np_trajectory.dist_inputs
      )
      # TODO(pkozakowski): Try max here.
      values = jnp.mean(q_values, axis=0)

      if len(values.shape) != 2:
        raise ValueError('Values are expected to have shape ' +
                         '[batch_size, length], got: %s' % str(values.shape))
      if values.shape[0] != self._policy_batch_size:
        raise ValueError('Values first dimension should = policy batch size, ' +
                         '%d != %d' % (values.shape[0], self._policy_batch_size))

      # q_values shape: (n_samples, batch_size, length)
      # values shape: (batch_size, length)
      # Computing advantages by broadcasting over n_samples.
      advantages = q_values - values
      mask = jnp.broadcast_to(np_trajectory.mask, advantages.shape)

      shapes.assert_shape_equals(
          advantages, (self._q_value_n_samples,) + values.shape
      )
      shapes.assert_same_shape(mask, advantages)

      # Swapping the n_samples and batch_size axes, so the input is split
      # between accelerators along the batch_size axis.
      advantages = jnp.swapaxes(advantages, 0, 1)
      mask = jnp.swapaxes(mask, 0, 1)

      yield (np_trajectory.observations, actions, advantages, mask, mask)
Example No. 14
def SignificanceWeights(mask, serializer, decay, **unused_kwargs):
    """Multiplies a binary mask with a symbol significance mask."""
    # (repr,) -> (batch, length, repr)
    significance = serializer.significance_map[None, None]
    return mask * decay**np.broadcast_to(significance, mask.shape)
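A tiny numeric sketch (invented decay and significance map) of the weighting above: every representation position is discounted by decay raised to its significance.

import numpy as np

decay = 0.5
significance_map = np.array([0, 1, 2])           # hypothetical per-position significance, shape (repr,)
mask = np.ones((2, 4, 3))                        # (batch, length, repr)
weights = mask * decay ** np.broadcast_to(significance_map[None, None], mask.shape)
print(weights[0, 0])                             # [1.   0.5  0.25]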
Example No. 15
def significance_weights(mask):
    # (repr,) -> (batch, length, repr)
    significance = serializer.significance_map[None, None]
    return mask * decay**jnp.broadcast_to(significance, mask.shape)
Example No. 16
def representation_mask(mask):
    # Collapse any extra dimensions, then broadcast each remaining mask entry
    # over the length of the serialized representation.
    mask = jnp.amax(mask, axis=tuple(range(2, mask.ndim)))
    return jnp.broadcast_to(
        mask[:, :, None],
        mask.shape + (serializer.representation_length, ))