Example #1
    def clipped_objective_mean(self):
        def f(dist_inputs, values, returns, dones, rewards, actions,
              old_log_probs):
            """Clipped objective from the PPO algorithm."""
            del dones, rewards
            advantages = returns - values
            probs_ratio = rl_layers.ProbsRatio(
                dist_inputs,
                actions,
                old_log_probs,
                log_prob_fun=self._policy_dist.log_prob)
            # advantages are of the shape [128,1,1]
            # and probs_ratio are of the shape [128,1]
            advantages = advantages.squeeze(axis=2)
            clipped_objective = rl_layers.ClippedObjective(
                probs_ratio, advantages, epsilon=self._epsilon)
            return jnp.mean(clipped_objective)

        return tl.Fn('ClippedObjectiveMean', f)
Example #2
def PickLastTokenInPredict(mode='train'):
  """Picks the last token logits.

  Self-descriptive layer for picking the last token logits in predict mode
  for fast inference.

  Args:
    mode: the model mode (train, predict, ...)

  Returns:
    The last token logits.
  """

  def last_token(x):  # pylint: disable=invalid-name
    if mode == 'predict':
      return x[:, -1:, :]
    return x

  return tl.Fn('Pick last token in predict', last_token)
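A minimal usage sketch (not part of the original snippet; assumes trax and numpy are available): a layer built with tl.Fn can be called directly on an array, which makes the predict-mode behaviour easy to check.

import numpy as np

logits = np.arange(24, dtype=np.float32).reshape(2, 3, 4)    # (batch, length, vocab)

# Only 'predict' mode keeps just the final time step; other modes pass the input through.
print(PickLastTokenInPredict(mode='predict')(logits).shape)  # (2, 1, 4)
print(PickLastTokenInPredict(mode='train')(logits).shape)    # (2, 3, 4)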
Example #3
    def ppo_objective_mean(self):
        """PPO objective mean."""
        def f(dist_inputs, values, returns, dones, rewards, actions,
              old_log_probs):
            """Clipped objective from the PPO algorithm."""
            ppo_objective = rl_layers.PPOObjective(
                dist_inputs,
                values,
                returns,
                dones,
                rewards,
                actions,
                old_log_probs,
                log_prob_fun=self._policy_dist.log_prob,
                epsilon=self._epsilon,
                normalize_advantages=self._normalize_advantages)
            return jnp.mean(ppo_objective)

        return tl.Fn('PPOObjectiveMean', f)
Example #4
def SignificanceWeights(serializer, decay):
  """Multiplies a binary mask with a symbol significance mask."""
  def significance_weights(mask):
    # (repr,) -> (batch, length, repr)
    # significance = [0, 1, 2]
    significance = serializer.significance_map
    assert significance.shape[0] == mask.shape[2]
    # significance = batch_size * [0, 1, 2]
    significance = jnp.repeat(
        significance[np.newaxis, ...], repeats=mask.shape[0], axis=0)
    # significance = batch_size * [0, 1, 2] * mask.shape[1]
    significance = jnp.repeat(
        significance[..., jnp.newaxis], repeats=mask.shape[1], axis=2)
    # significance = batch_size *  mask.shape[1] * [0, 1, 2]
    significance = jnp.swapaxes(significance, 1, 2)
    assert significance.shape == mask.shape
    sig_weights = mask * decay ** significance
    return sig_weights
  return tl.Fn('SignificanceWeights', significance_weights)
Example #5
File: training.py Project: ixxxxu/trax
    def value_loss(self):
        """Value loss computed using smooth L1 loss or L2 loss."""
        def f(values, actions, returns, mask):
            ind_0, ind_1 = np.indices(actions.shape)
            # We calculate length using the shape of returns
            # and adequately remove a superfluous slice of values.
            # An analogous operation is done in value_batches_stream.
            length = returns.shape[1]
            values = values[:, :length, :]
            selected_values = values[ind_0, ind_1, actions]
            shapes.assert_same_shape(selected_values, returns)
            shapes.assert_same_shape(selected_values, mask)
            if self._smoothl1loss:
                return tl.SmoothL1Loss().forward(
                    (selected_values, returns, mask))
            else:
                return tl.L2Loss().forward((selected_values, returns, mask))

        return tl.Fn('ValueLoss', f)
Example #6
def _MaskOfRightShiftedArray(n_shifts=1, mode='train'):
    """Gives us the mask of a right shifted by n_shifts array."""
    def F(x):
        # TODO(afrozm): What to do in this case?
        if mode == 'predict':
            raise ValueError(
                'MaskOfRightShiftedArray not implemented for predict.')

        mask = x != 0

        if n_shifts == 0:
            return mask

        # Need to set (B, n_shifts, ...) section to True.
        trues_shape = (x.shape[0], n_shifts) + mask.shape[2:]
        trues = jnp.full(trues_shape, True)
        return jnp.concatenate([trues, mask[:, n_shifts:, ...]], axis=1)

    return tl.Fn(f'MaskOfRightShiftedArray({n_shifts})', F)
Example #7
def Deinterleave(x_size, y_size):
    """Layer that does the inverse of Interleave."""
    def deinterleave(inputs):
        reprs = inputs
        (batch_size, length) = reprs.shape[:2]
        shape_suffix = reprs.shape[2:]
        remainder_length = length % (x_size + y_size)
        if remainder_length > 0:
            remainder = reprs[:, None, -remainder_length:]
            reprs = reprs[:, :-remainder_length]
        reprs = jnp.reshape(reprs,
                            (batch_size, -1, x_size + y_size) + shape_suffix)
        x_reprs = reprs[:, :, :x_size]
        y_reprs = reprs[:, :, x_size:]
        if remainder_length > 0:
            x_reprs = jnp.concatenate((x_reprs, remainder), axis=1)
        return (x_reprs, y_reprs)

    return tl.Fn('Deinterleave', deinterleave, n_out=2)
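A minimal usage sketch (illustrative values, not part of the original snippet), assuming the Deinterleave defined above: six tokens interleaved as [x, x, y, x, x, y] are split back into the x-stream and the y-stream.

import numpy as np

layer = Deinterleave(x_size=2, y_size=1)
interleaved = np.array([[1, 2, 9, 3, 4, 8]])    # shape (1, 6): [x, x, y, x, x, y]
x_reprs, y_reprs = layer(interleaved)
print(x_reprs.shape, y_reprs.shape)             # (1, 2, 2) (1, 2, 1)
print(x_reprs)                                  # [[[1 2] [3 4]]]
print(y_reprs)                                  # [[[9] [8]]]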
Example #8
File: base_test.py Project: yunhaj47/trax
  def test_forward(self):
    layer = tl.Fn(
        'SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)

    x0 = np.array([1, 2, 3, 4, 5])
    x1 = np.array([10, 20, 30, 40, 50])

    y0, y1 = layer((x0, x1))
    self.assertEqual(y0.tolist(), [11, 22, 33, 44, 55])
    self.assertEqual(y1.tolist(), [10, 20, 30, 40, 50])

    y2, y3 = layer.forward((x0, x1))
    self.assertEqual(y2.tolist(), [11, 22, 33, 44, 55])
    self.assertEqual(y3.tolist(), [10, 20, 30, 40, 50])

    (y4, y5), state = layer.pure_fn((x0, x1), tl.EMPTY_WEIGHTS, tl.EMPTY_STATE,
                                    None)
    self.assertEqual(y4.tolist(), [11, 22, 33, 44, 55])
    self.assertEqual(y5.tolist(), [10, 20, 30, 40, 50])
    self.assertEqual(state, tl.EMPTY_STATE)
Example #9
    def a2c_objective_mean(self):
        """A2C objective mean."""
        def f(dist_inputs, values, returns, dones, rewards, actions,
              old_log_probs, mask):
            """A2C objective mean."""
            # TODO(henrykm): include dones, rewards
            del old_log_probs
            a2c_objective = rl_layers.A2CObjective(
                dist_inputs,
                values,
                returns,
                dones,
                rewards,
                actions,
                mask,
                log_prob_fun=self._policy_dist.log_prob,
                normalize_advantages=self._normalize_advantages)
            return jnp.mean(a2c_objective)

        return tl.Fn('A2CObjectiveMean', f, n_out=1)
Example #10
  def policy_loss(self, **unused_kwargs):
    """Policy loss."""
    def LossInput(dist_inputs, actions, advantages, old_dist_inputs, mask):  # pylint: disable=invalid-name
      """Calculates action log probabilities and normalizes advantages."""
      del old_dist_inputs
      advantages = self._preprocess_advantages(advantages)
      dist_inputs = jnp.broadcast_to(
          dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape
      )
      log_probs = self._policy_dist.log_prob(dist_inputs, actions)
      # (batch_size, n_samples, ...) -> (n_samples, batch_size, ...)
      advantages = jnp.swapaxes(advantages, 0, 1)
      mask = jnp.swapaxes(mask, 0, 1)
      return (log_probs, advantages, log_probs, mask)

    return tl.Serial(
        tl.Fn('LossInput', LossInput, n_out=4),
        # Policy loss is expected to consume
        # (log_probs, advantages, old_log_probs, mask).
        AWRLoss(beta=self._beta, w_max=self._w_max),  # pylint: disable=no-value-for-parameter
    )
Example #11
    def policy_loss(self, **unused_kwargs):
        """Policy loss."""
        def normalize(adv):
            return ((adv - jnp.mean(adv)) /
                    (jnp.std(adv) + self._advantage_normalization_epsilon))

        def LossInput(dist_inputs, actions, advantages, old_dist_inputs):  # pylint: disable=invalid-name
            """Calculates action log probabilities and normalizes advantages."""
            if self._advantage_normalization:
                advantages = normalize(advantages)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            old_log_probs = self._policy_dist.log_prob(old_dist_inputs,
                                                       actions)
            return (log_probs, advantages, old_log_probs)

        return tl.Serial(
            tl.Fn('LossInput', LossInput, n_out=3),
            # Policy loss is expected to consume
            # (log_probs, advantages, old_log_probs, mask).
            self.policy_loss_given_log_probs,
        )
Example #12
File: sparsity.py Project: piotrekp1/trax
def MultiplicativeModularSparseDense(sparsity, d_feature):
    """Returns a replacement of Dense layer which uses less parameters.

  The layer uses number of modules equal to `sparsity`. It is a combination of
  multiplicative dense and locally connected dense layers.

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into this
        number of modules.
    d_feature: Dimensionality of input and output tensor.
  """

    assert d_feature % sparsity == 0
    d_module = d_feature // sparsity

    return tl.Serial(
        # Weight below is used for per-head preprocessing of an embedding.
        tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                   shape=[sparsity, d_feature]),
        # Weight below is a kernel of multiplicative dense, shared across heads.
        tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]),
        # Weight below is a kernel of modular dense.
        tl.Weights(
            functools.partial(init.GlorotUniformInitializer(),
                              nonreceptive_dims=[0]),
            [sparsity, d_module, d_module]),
        # To save memory the per-head preprocessing and multiplying by
        # kernels is done in a single einsum.
        tl.Fn(
            'SparseDenseEinsum',
            (
                lambda kmod, kmult, multiplier, embeds:  # pylint: disable=g-long-lambda
                jnp.einsum('hxo,dx,hd,...d->...ho', kmod, kmult, multiplier,
                           embeds))),
        MergeLastTwoAxes(),
        # Weight below is bias after dense, per-head.
        tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
        tl.Add(),
    )
Example #13
def MultiplicativeSparseDense(sparsity, d_input, d_output=None):  # pylint: disable=invalid-name
    """Returns a replacement of Dense layer which uses less parameters.

  The layer uses number of modules equal to `sparsity`. It multiplies each
  dimension of the input tensor by a scalar specific to each dimension and each
  module separately; then it applies Dense(d_output/sparsity) to each module.
  Compared to standard dense layer, MultiplicativeSparseDense uses less
  parameters while still being able to express many interesting functions (for
  example a permutation).

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into this
        number of modules.
    d_input: Dimensionality of input tensor.
    d_output: Dimensionality of output tensor; by default equal to d_input.
  """

    assert d_output % sparsity == 0
    d_module = d_output // sparsity

    return tl.Serial(
        # Weight below is used for per-head preprocessing of an embedding.
        tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                   shape=[sparsity, d_input]),
        # Weight below is dense kernel, shared across heads.
        tl.Weights(init.GlorotUniformInitializer(), [d_input, d_module]),
        # To save memory the per-head preprocessing and multiplying by the
        # kernel is done in the same einsum.
        tl.Fn(
            'AttentionEinsum',
            (
                lambda kernel, multiplier, embeds:  # pylint: disable=g-long-lambda
                np.einsum('dx,hd,bld->blhx', kernel, multiplier, embeds))),
        MergeLastTwoAxes(),
        # Weight below is bias after dense, per-head.
        tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]),
        tl.Add(),
    )
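A back-of-the-envelope parameter count (illustrative sizes, not from the trax source) showing why this layer is cheaper than a plain Dense of the same width; the three terms correspond to the three tl.Weights declared above.

d_input = d_output = 512
sparsity = 16
d_module = d_output // sparsity                 # 32

dense_params = d_input * d_output + d_output    # kernel + bias
sparse_params = (sparsity * d_input             # per-head multipliers
                 + d_input * d_module           # shared dense kernel
                 + d_output)                    # per-head bias
print(dense_params, sparse_params)              # 262656 25088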
Example #14
def _StripFromConcatenateWithPadding():
    """Strips out the leading encoder tokens from the concatenated array."""
    def _StripEncToks(vec_ed, tok_e, tok_d):
        # pylint: disable=invalid-name
        B, L, H = vec_ed.shape
        L1 = tok_e.shape[1]
        L2 = tok_d.shape[1]
        # pylint: enable=invalid-name
        if L != L1 + L2:
            raise ValueError(
                f'Length from encoder-decoder vectors ({L}) does not'
                f' equal sum of lengths from encoder ({L1}) and decoder'
                f' ({L2}).')
        if tok_e.shape != (B, L1):
            raise ValueError(
                f'Shape of encoder tokens, {tok_e.shape}, does not'
                f' equal {(B, L1)}.')
        if tok_d.shape != (B, L2):
            raise ValueError(
                f'Shape of decoder tokens, {tok_d.shape}, does not'
                f' equal {(B, L2)}.')

        def _UpdateRow(x):
            # (L, H), (L1, H) & (L2, H)
            row_ed, row_e, _ = x
            mask_e = row_e != 0
            len_e = jnp.sum(mask_e, dtype=jnp.int32)
            # In `row_ed`, start where the encoder tokens/vecs end, i.e. at
            # index `len_e`, and pick up the (L2, H) tensor slice from there.
            zero = jnp.array(0,
                             dtype=len_e.dtype)  # avoid int32/int64 mismatch
            l2_np = jnp.array(L2, dtype=len_e.dtype)
            h_np = jnp.array(H, dtype=len_e.dtype)
            return jax.lax.dynamic_slice(row_ed, (len_e, zero), (l2_np, h_np))

        return jax.lax.map(_UpdateRow, [vec_ed, tok_e, tok_d])

    return tl.Fn('StripFromConcatenateWithPadding', _StripEncToks, n_out=1)
Example #15
def _Upsampler(total_pool_size, separate_cls):
    """Returns an upsampling layer for Funnel Transformer.

  Args:
    total_pool_size: The combined pool size of previously used funnel blocks.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
          embeddings of the first token (`cls` from BERT paper).
  """
    def _Upsample(short, long):
        if separate_cls:
            upsampled_short = jnp.concatenate(
                (short[:, :1, :], short[:, 1:, :].repeat(total_pool_size,
                                                         axis=1)),
                axis=1)
            return index_add(long,
                             (slice(None), slice(
                                 None, upsampled_short.shape[1]), slice(None)),
                             upsampled_short)
        else:
            upsampled_short = short.repeat(total_pool_size, axis=1)
            return long + upsampled_short

    return tl.Fn('Upsampler', _Upsample)
Example #16
def siamese(vocab_size, d_model=128):
    """Returns a Siamese model.

    Args:
        vocab_size (int): Length of the vocabulary.
        d_model (int, optional): Depth of the model. Defaults to 128.

    Returns:
        trax.layers.combinators.Parallel: A Siamese model.
    """
    def normalize(vec):  # normalizes the vectors to have L2 norm 1
        return vec / fastnp.sqrt(fastnp.sum(vec * vec, axis=-1, keepdims=True))

    s_processor = tl.Serial(
        tl.Embedding(vocab_size, d_model),  # Embedding layer
        tl.LSTM(d_model),  # LSTM layer
        tl.Mean(axis=1),  # Mean over columns
        tl.Fn('Normalize', normalize)  # Apply normalize function
    )  # Returns one vector of shape [batch_size, d_model].

    # Run on s1_tensor and s2_tensor in parallel.
    model = tl.Parallel(s_processor, s_processor)
    return model
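A minimal sketch of exercising the model (hypothetical vocabulary size, batch and sequence length; not part of the original snippet): after initializing weights for a concrete input signature, the two weight-sharing towers map each batch of token ids to an L2-normalized vector of depth d_model.

import numpy as np
from trax import shapes

model = siamese(vocab_size=500, d_model=128)
s1 = np.random.randint(1, 500, size=(4, 10))    # batch of 4 token-id sequences
s2 = np.random.randint(1, 500, size=(4, 10))
model.init(shapes.signature((s1, s2)))          # create weights for this input shape
v1, v2 = model((s1, s2))
print(v1.shape, v2.shape)                       # (4, 128) (4, 128)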
Example #17
def _StripFromConcatenateWithPadding():
    """Strip out the leading encoder tokens from the concatenated array."""
    def _StripEncToks(vec_ed, tok_e, tok_d):
        # pylint: disable=invalid-name
        B, L, H = vec_ed.shape
        L1 = tok_e.shape[1]
        L2 = tok_d.shape[1]
        # pylint: enable=invalid-name
        assert L == L1 + L2
        assert (B, L1) == tok_e.shape
        assert (B, L2) == tok_d.shape

        def _UpdateRow(x):
            # (L, H), (L1, H) & (L2, H)
            row_ed, row_e, _ = x
            mask_e = row_e != 0
            len_e = jnp.sum(mask_e, dtype=jnp.int32)
            # In `row_ed`, start where the encoder tokens/vecs end, i.e. at
            # index `len_e`, and pick up the (L2, H) tensor slice from there.
            return jax.lax.dynamic_slice(row_ed, (len_e, 0), (L2, H))

        return jax.lax.map(_UpdateRow, [vec_ed, tok_e, tok_d])

    return tl.Fn('StripFromConcatenateWithPadding', _StripEncToks, n_out=1)
Example #18
def _StripFromConcatenateWithPadding():
  """Strip out the leading encoder tokens from the concatenated array."""

  # Shapes: (L1+L2, H), (L1,) and (L2,)
  def F(vec_ed, tok_e, tok_d):
    mask_e = tok_e != 0
    # Actual length of encoder tokens <= L1
    len_e = jnp.sum(mask_e)
    # Padded length of decoder tokens, this is L2.
    L2 = tok_d.shape[0]  # pylint: disable=invalid-name

    # vec_ed is of type [eeedd00000], we will roll it len_e=3 in reverse.
    # This gives us [dd00000eee] and now we take only the first L2 elements.
    return jnp.roll(vec_ed, -len_e, axis=0)[:L2]

  # TODO(afrozm): Try to do this with sort_key_val instead of roll to get rid of
  # the vmap.
  def _F(vec_ed, tok_e, tok_d):
    return jax.vmap(F)(vec_ed, tok_e, tok_d)

  # We could have written `tl.Fn(..., jax.vmap(F), ...)` here but Trax needs the
  # top-level function (here: jax.vmap) to not have variable or named arguments,
  # so we need a thin wrapper.
  return tl.Fn('StripFromConcatenateWithPadding', _F, n_out=1)
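A plain-NumPy illustration of the roll trick described in the comments above (toy values, not from the trax source): a concatenated row [e e e d d 0 0] rolled left by the actual encoder length puts the decoder part at the front, and keeping the first L2 entries strips the encoder tokens.

import numpy as np

vec_ed = np.array([1, 2, 3, 10, 20, 0, 0])   # 3 encoder tokens, then decoder tokens + padding
len_e = 3                                     # actual (unpadded) encoder length
L2 = 4                                        # padded decoder length
print(np.roll(vec_ed, -len_e)[:L2])           # [10 20  0  0]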
Example #19
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval',
        ):
  """BERT (default hparams are for bert-base-uncased)."""
  layer_norm_eps = 1e-12
  d_head = d_model // n_heads

  word_embeddings = tl.Embedding(d_model, vocab_size)
  type_embeddings = tl.Embedding(d_model, type_vocab_size)
  position_embeddings = tl.PositionalEncoding(max_len, mode=mode)
  embeddings = [
      tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
      tl.Parallel(
          word_embeddings,
          type_embeddings,
          [tl.PaddingMask(),
           tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)]
      ),
      tl.Add(),
      position_embeddings,
      tl.LayerNorm(epsilon=layer_norm_eps),
  ]

  encoder = []
  for _ in range(n_layers):
    attn = tl.SelfAttention(n_heads=n_heads, d_qk=d_head, d_v=d_head,
                            bias=True, masked=True, mode=mode)
    feed_forward = [
        tl.Dense(d_ff),
        tl.Gelu(),
        tl.Dense(d_model)
    ]
    encoder += [
        tl.Select([0, 1, 1]),  # Save a copy of the mask
        tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=layer_norm_eps),
        tl.Residual(*feed_forward),
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]

  encoder += [tl.Select([0], n_in=2)]  # Drop the mask

  pooler = [
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  init_checkpoint = init_checkpoint if mode == 'train' else None
  bert = PretrainedBERT(
      embeddings + encoder + pooler, init_checkpoint=init_checkpoint)

  if head is not None:
    bert = tl.Serial(bert, head())

  return bert
Example #20
  def LogProb(self):  # pylint: disable=invalid-name
    """Builds a log probability layer for this distribution."""
    return tl.Fn('LogProb',
                 lambda inputs, point: self.log_prob(inputs, point))  # pylint: disable=unnecessary-lambda
Example #21
  def joint_loss(self):
    """Joint policy and value loss."""
    def f(dist_inputs, values, returns, dones, rewards,
          actions, old_log_probs, mask):
      """Definition of the Proximal Policy Optimization loss."""
      del mask  # TODO(lukaszkaiser): make PPO work with Transformer
      # We have dist_inputs of the shape float32[128,1,18]
      assert len(dist_inputs.shape) == 3, (
          f'dist_inputs.shape was {dist_inputs.shape} '
          f'but expected length of the tensor shape is 3')
      # values of the shape float32[128,1,1]
      # returns of the shape float32[128,1,1]
      # dones of the shape int32[128,1,1]
      # rewards of the shape float32[128,1,1]
      # and old_log_probs of the shape float32[128,1]
      assert values.shape == returns.shape, (
          f'values.shape was {values.shape} '
          f'returns.shape was {returns.shape}')
      assert values.shape == dones.shape, (
          f'values.shape was {values.shape} '
          f'dones.shape was {dones.shape}')
      assert rewards.shape == dones.shape, (
          f'rewards.shape was {rewards.shape} '
          f'dones.shape was {dones.shape}')
      assert returns.shape[0:2] == old_log_probs.shape, (
          f'returns.shape was {returns.shape} '
          f'old_log_probs.shape was {old_log_probs.shape}')

      # actions is a tensor of the shape int32[128,1] in the case
      # of discrete actions and float32[128,1,6] in the case of
      # half-cheetah and other continuous actions
      # actions agree with returns/values on the first two coordinates
      # meaning batch and time
      assert actions.shape[0:2] == returns.shape[0:2], (
          f'actions.shape was {actions.shape} and '
          f'returns.shape was {returns.shape}')

      ppo_objective = rl_layers.PPOObjective(
          dist_inputs, stop_gradient(values), returns, dones, rewards,
          actions, old_log_probs,
          log_prob_fun=self._policy_dist.log_prob,
          epsilon=self._epsilon,
          normalize_advantages=self._normalize_advantages)

      # we insist that ppo_objective is a vector of shape [128,1]
      assert len(ppo_objective.shape) == 2, (
          f'ppo_objective was {ppo_objective}')
      # which agrees with returns/values/actions on the first two coordinates
      assert ppo_objective.shape[0:2] == values.shape[0:2], (
          f'ppo_objective.shape was {ppo_objective.shape} and '
          f'values.shape was {values.shape}')

      entropy_loss = rl_layers.EntropyLoss(
          dist_inputs,
          distribution=self._policy_dist,
          coeff=self._entropy_coeff,
      )

      assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

      l2_value_loss = rl_layers.ValueLoss(
          values, returns, value_loss_coeff=self._value_loss_coeff)

      assert jnp.ndim(l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

      return -ppo_objective.mean() + l2_value_loss - entropy_loss

    return tl.Fn('PPOJointLoss', f)
Example #22
  def preferred_move(self):
    """Preferred move - the mean of selected moves."""
    def f(dist_inputs, values):
      del values
      return rl_layers.PreferredMove(dist_inputs, self._policy_dist.sample)
    return tl.Fn('PreferredMove', f)
Example #23
  def log_probs_mean(self):
    """Mean of log_probs aka dist_inputs."""
    def f(dist_inputs, values):
      del values
      return jnp.mean(dist_inputs)
    return tl.Fn('LogProbsMean', f)
Example #24
  def explained_variance(self):
    """Explained variance metric."""
    def f(dist_inputs, values, returns):
      del dist_inputs
      return rl_layers.ExplainedVariance(values, returns)
    return tl.Fn('ExplainedVariance', f)
Example #25
  def value_loss(self):
    """Value loss - so far generic for all A2C."""
    def f(dist_inputs, values, returns):
      del dist_inputs
      return rl_layers.ValueLoss(values, returns, self._value_loss_coeff)
    return tl.Fn('ValueLoss', f)
Example #26
  def advantage_norm(self):
    """Norm of advantages."""
    def f(dist_inputs, values, returns):
      del dist_inputs
      return jnp.linalg.norm(returns - values)
    return tl.Fn('AdvantageNorm', f)
Example #27
  def advantage_mean(self):
    """Mean of advantages."""
    def f(dist_inputs, values, returns):
      del dist_inputs
      return jnp.mean(returns - values)
    return tl.Fn('AdvantageMean', f)
Example #28
  def advantage_std(self):
    return tl.Serial([
        # (dist_inputs, advantages, old_dist_inputs, mask)
        tl.Select([1]),  # Select just the advantages.
        tl.Fn('AdvantageStd', lambda x: jnp.std(x)),  # pylint: disable=unnecessary-lambda
    ])
Example #29
File: policy_tasks.py Project: ixxxxu/trax
        def make_metric(aggregate_fn):  # pylint: disable=invalid-name
            def AdvantageMetric(policy_inputs, actions, advantages, mask):
                del policy_inputs, actions, mask
                return aggregate_fn(advantages)

            return tl.Fn('AdvantageMetric', AdvantageMetric)
Example #30
File: policy_tasks.py Project: ixxxxu/trax
    def entropy_metric(self):
        def Entropy(policy_inputs, actions, advantages, mask):
            del actions, advantages, mask
            return jnp.mean(self._policy_dist.entropy(policy_inputs))

        return tl.Fn('Entropy', Entropy)