Example #1
    def update(self, step, grads, weights, slots, opt_params):
        updates = []
        learning_rate = opt_params['learning_rate']
        beta1 = opt_params['beta1']
        decay_rate = opt_params['decay_rate']
        clipping_threshold = opt_params['clipping_threshold']
        weight_decay_rate = opt_params['weight_decay_rate']
        weight_decay_n_steps = opt_params['weight_decay_n_steps']
        weight_decay_rate = jnp.where(
            # if weight_decay_n_steps == 0, ignore it
            weight_decay_n_steps < 1,
            weight_decay_rate,
            (weight_decay_rate *
             jnp.maximum(weight_decay_n_steps - step, 0.0) /
             jnp.maximum(weight_decay_n_steps, 0.0)))
        epsilon1 = opt_params['epsilon1']
        epsilon2 = opt_params['epsilon2']
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= jnp.maximum(jnp.sqrt(jnp.mean(weights * weights)),
                                        epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads
        if self._factored and len(weights.shape) >= 2:
            v_row = slots.pop(0)
            v_col = slots.pop(0)
            new_v_row = (decay_rate * v_row +
                         mixing_rate * jnp.mean(grads_sqr, axis=-1))
            new_v_col = (decay_rate * v_col +
                         mixing_rate * jnp.mean(grads_sqr, axis=-2))
            updates.extend([new_v_row, new_v_col])
            row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (row_mean / (new_v_row + epsilon1))**0.5
            col_factor = (new_v_col + epsilon1)**-0.5
            y = (grads * jnp.expand_dims(row_factor, axis=-1) *
                 jnp.expand_dims(col_factor, axis=-2))
        else:
            v = slots.pop(0)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v + epsilon1)**-0.5

        if self._do_clipping:
            clipping_denom = (jnp.maximum(
                1.0,
                jnp.sqrt(jnp.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots.pop(0)
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_weights = (1 - weight_decay_rate) * weights - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_weights.astype(weights.dtype), updates
Example #2
def mean_or_pmean(n_devices, x, axis=None):
  """jnp.mean or pmean.

  `x` is a distributed value. Directly calling jnp.mean on `x` would stack
  x's components into one large array and then run jnp.mean on it. In TF,
  stacking `x` introduces a D2H copy, so for the TF backend we use a
  collective (pmean) instead of calling jnp.mean directly.

  Args:
    n_devices: number of devices.
    x: a distributed array.
    axis: the axis to reduce. Can only be 0 or None.

  Returns:
    A local array.
  """
  if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1:
    if axis not in (None, 0):
      raise ValueError('axis can only be None or 0')
    x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
    if axis is None:
      x = jnp.mean(x)
    return x
  else:
    return jnp.mean(x, axis=axis)
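
A minimal usage sketch for mean_or_pmean (the 4-device setup and toy input are
hypothetical; with the default JAX backend the call falls through to the plain
jnp.mean branch):

    # Sketch only: assumes Trax is installed and mean_or_pmean is in scope.
    from trax.fastmath import numpy as jnp

    x = jnp.arange(8.0).reshape((4, 2))  # pretend: one row per device
    print(mean_or_pmean(4, x, axis=0))   # per-column mean -> shape (2,)
    print(mean_or_pmean(4, x))           # mean over everything -> 3.5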
Example #3
 def forward(self, x):
     scale, bias = self.weights
     mean = jnp.mean(x, axis=-1, keepdims=True)
     centered = x - mean
     variance = jnp.mean(centered * centered, axis=-1, keepdims=True)
     norm_inputs = centered / jnp.sqrt(variance + self._epsilon)
     return norm_inputs * scale + bias
Example #4
 def forward(self, x):
     scale, bias = self.weights
     mean = jnp.mean(x, axis=-1, keepdims=True)
     sub = x - mean
     variance = jnp.mean(sub * sub, axis=-1, keepdims=True)
     norm_inputs = sub / jnp.sqrt(variance + self._epsilon)
     return norm_inputs * scale + bias
Example #5
 def f(preds, values, returns, actions, mask):
   advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
   logps = self._policy_dist.log_prob(preds, actions)
   awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
       (logps, advantages, jnp.zeros_like(logps), mask))
   l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
   return awr_loss + l2_value_loss
Example #6
    def train_step(self, batch):
        """Run one training step and update self._opt_state."""
        # Calculate the current optimizer parameters.
        opt_param_updates = self._for_n_devices(
            {'learning_rate': np.array(self.learning_rate)})
        opt_state = self._opt_state
        opt_state.opt_params.update(opt_param_updates)

        # Run the update.
        weights, slots, opt_params = opt_state
        (weights,
         slots), stat, self._model_state, self._rngs = self._jit_update_fn(
             (weights, slots), self._step, opt_params, batch,
             self._model_state, self._rngs)
        self._opt_state = opt_state._replace(weights=weights, slots=slots)
        if self._should_log_now():
            for name, value in stat.items():
                # TODO(afrozm): value is a scalar, but sometimes JAX is crashing here
                # with a device put array error complaining that it should be an array.
                # On multiple devices, take the mean.
                scalar_value = np.mean(np.array(value))
                self._train_sw.scalar('training/' + name,
                                      scalar_value,
                                      step=self._step)
        self._step += 1
Example #7
    def train_step(self, batch):
        """Run one training step and update self._opt_state."""
        # Calculate the current optimizer parameters.
        # TODO(pkozakowski): Optimizer parameters get polluted with model state,
        # which doesn't break anything but is weird. Filter it out.
        opt_param_updates = self._for_n_devices(
            fastmath.nested_map(np.array, self.nontrainable_params))
        opt_state = self._opt_state
        opt_state.opt_params.update(opt_param_updates)

        # Run the update.
        weights, slots, opt_params = opt_state
        (weights,
         slots), stat, self._model_state, self._rngs = self._jit_update_fn(
             (weights, slots), self._step, opt_params, batch,
             self._model_state, self._rngs)
        self._opt_state = opt_state._replace(weights=weights, slots=slots)
        if self._should_log_now():
            for name, value in stat.items():
                scalar_value = np.mean(
                    value)  # On multiple devices, take the mean.
                self._train_sw.scalar('training/' + name,
                                      scalar_value,
                                      step=self._step)
        self._step += 1
Example #8
    def test_custom_zero_grad(self, backend):
        class IdWithZeroGrad(tl.Layer):
            def forward(self, x):
                return x

            @property
            def has_backward(self):
                return True

            def backward(self, inputs, output, grad, weights, state, new_state,
                         rng):
                return (jnp.zeros_like(grad), ())

        with fastmath.use_backend(backend):
            layer = IdWithZeroGrad()
            rng = fastmath.random.get_prng(0)
            input_signature = shapes.ShapeDtype((9, 17))
            random_input = fastmath.random.uniform(rng,
                                                   input_signature.shape,
                                                   minval=-1.0,
                                                   maxval=1.0)
            layer.init(input_signature)
            f = lambda x: jnp.mean(layer(x))
            grad = fastmath.grad(f)(random_input)
            self.assertEqual(grad.shape, (9, 17))  # Gradient for each input.
            self.assertEqual(sum(sum(grad * grad)), 0.0)  # Each one is 0.
Example #9
 def _preprocess_advantages(self, advantages):
   if self._advantage_normalization:
     advantages = (
         (advantages - jnp.mean(advantages)) /
         (jnp.std(advantages) + self._advantage_normalization_epsilon)
     )
   return advantages
Example #10
def PPOObjective(dist_inputs, values, returns, dones, rewards, actions,
                 old_log_probs, log_prob_fun, epsilon, normalize_advantages):
    """PPO Objective."""
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape float32[128,1,1]
    # rewards of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and old_log_probs of the shape float32[128,1]
    returns = returns.squeeze(axis=2)
    values = values.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert returns.shape == old_log_probs.shape, (
        f'returns.shape was {returns.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    assert probs_ratio.shape == old_log_probs.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert old_log_probs.shape == advantages.shape, (
        f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    assert unclipped_objective.shape == advantages.shape, (
        f'unclipped_objective.shape was {unclipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    assert clipped_objective.shape == advantages.shape, (
        f'clipped_objective.shape was {clipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    assert ppo_objective.shape == advantages.shape, (
        f'ppo_objective.shape was {ppo_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    return ppo_objective
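
The jnp.where lines above emulate NumPy's boolean-index assignment, which JAX
arrays do not support in place. A self-contained sketch of that trick with toy
values (not part of the Trax API):

    from trax.fastmath import numpy as jnp

    returns = jnp.array([1.0, 2.0, 3.0])
    rewards = jnp.array([10.0, 20.0, 30.0])
    dones = jnp.array([False, True, False])

    # NumPy-style in-place update `returns[dones] = rewards[dones]`
    # becomes a functional jnp.where:
    returns = jnp.where(dones, rewards, returns)
    print(returns)  # [ 1. 20.  3.]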
Example #11
 def ClipFraction(dist_inputs, actions, old_log_probs):
     """Probability Ratio Mean from the PPO algorithm."""
     probs_ratio = rl_layers.ProbsRatio(
         dist_inputs,
         actions,
         old_log_probs,
         log_prob_fun=self._policy_dist.log_prob)
     return jnp.mean(jnp.abs(probs_ratio - 1) > self._epsilon)
Example #12
 def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
   """Clipped objective from the PPO algorithm."""
   ppo_objective = rl_layers.PPOObjective(
       dist_inputs, values, returns, dones, rewards, actions, old_log_probs,
       log_prob_fun=self._policy_dist.log_prob,
       epsilon=self._epsilon,
       normalize_advantages=self._normalize_advantages)
   return jnp.mean(ppo_objective)
Example #13
 def f(model_output, targets):  # pylint: disable=invalid-name
   beta2 = beta ** 2
   predictions = jnp.argmax(model_output, axis=-1)
   n_categories = model_output.shape[-1]
   f_scores = jnp.empty(0)
   for k in range(initial_category_index, n_categories):
     _, _, _, precision, recall = _precision_recall(predictions, targets, k)
     f_scores = jnp.append(f_scores, _f_score(precision, recall, beta2))
   return jnp.mean(f_scores)
Example #14
def ApproximateKLDivergence(dist_inputs, actions, old_log_probs, log_prob_fun):
    """Probability Ratio from the PPO algorithm."""
    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    assert new_log_probs.shape == old_log_probs.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')
    # 0.5 * E[(log p_new - log p_old)^2], the standard approximation.
    approximate_kl_divergence = 0.5 * jnp.mean(
        (new_log_probs - old_log_probs) ** 2)
    return approximate_kl_divergence
Example #15
 def _aggregate_values(self, values, aggregate_max, act_log_probs):
     if self._q_value:
         if aggregate_max:
             values = jnp.max(values, axis=1)
         elif self._sample_all_discrete_actions:
             values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
         else:
             values = jnp.mean(values, axis=1)
     return np.array(values)  # Move the values to CPU.
Example #16
 def calculate_weights(self, advantages):
     """Calculates advantage-based weights for log loss in policy training."""
     if self._advantage_normalization:
         # Normalize advantages.
         advantages -= jnp.mean(advantages)
         advantage_std = jnp.std(advantages)
         advantages /= advantage_std + self._advantage_normalization_epsilon
     weights = self._weight_fn(advantages)
     assert weights.shape == advantages.shape
     return weights
Example #17
 def predict(x, weights, state, rng):
     """Predict function JIT-compiled and parallelized as requested."""
     res, state = _combine_devices(
         model_predict(reshape_by_device(x, n_devices), weights, state,
                       jnp.stack(fastmath.random.split(rng, n_devices))))
     if do_mean:
         return fastmath.nested_map(lambda y: jnp.mean(y, axis=0),
                                    res), state
     else:
         return res, state
Example #18
 def f(dist_inputs, values, returns, dones, rewards,
       actions, old_log_probs, mask):
   """A2C objective mean."""
   # TODO(henrykm): include dones, rewards
   del old_log_probs
   a2c_objective = rl_layers.A2CObjective(
       dist_inputs, values, returns, dones, rewards, actions, mask,
       log_prob_fun=self._policy_dist.log_prob,
       normalize_advantages=self._normalize_advantages)
   return jnp.mean(a2c_objective)
Example #19
 def f(values, weights):  # pylint: disable=invalid-name
     # This function assumes weights are 0 or 1.
     # Then compute 1: not-correct, 0: correct or masked
     not_correct = (1.0 - values) * weights
     axis_to_sum = list(range(1, len(not_correct.shape)))
     # Summing not-correct on all axes but batch. We're summing 0s and 1s,
     # so the sum is 0 if it's all 0 and >=1 in all other cases.
     not_correct_seq = jnp.sum(not_correct, axis=axis_to_sum)
     # Sequence is correct if not_correct_seq is 0, reverting here.
     correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq)
     return jnp.mean(correct_seq)  # Mean over batch.
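
A toy check of the per-sequence accuracy logic above (made-up values and
weights; weights mark non-padding positions):

    from trax.fastmath import numpy as jnp

    values = jnp.array([[1.0, 1.0, 0.0],    # one wrong token
                        [1.0, 1.0, 1.0]])   # fully correct sequence
    weights = jnp.array([[1.0, 1.0, 1.0],
                         [1.0, 1.0, 0.0]])  # last position is padding
    not_correct = (1.0 - values) * weights
    not_correct_seq = jnp.sum(not_correct, axis=1)
    correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq)
    print(jnp.mean(correct_seq))  # 0.5 -- one of the two sequences is correct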
Example #20
def mean_or_pmean(n_devices, x, axis=None):
  """Computes the mean of a distributed value ``x``.

  Args:
    n_devices: Number of devices.
    x: Distributed array.
    axis: Axis along which to compute means; can only be ``0`` or ``None``.

  Returns:
    A local array.
  """
  if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1:
    if axis not in (None, 0):
      raise ValueError('axis can only be None or 0')
    x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
    if axis is None:
      x = jnp.mean(x)
    return x
  else:
    return jnp.mean(x, axis=axis)
Example #21
 def fn(dist_inputs, actions, q_values, act_log_probs, mask):
     del dist_inputs, actions, mask
     q_values = jnp.swapaxes(q_values, 0, 1)
     act_log_probs = jnp.swapaxes(act_log_probs, 0, 1)
     if self._sample_all_discrete_actions:
         values = jnp.sum(q_values * jnp.exp(act_log_probs), axis=0)
     else:
         values = jnp.mean(q_values, axis=0)
     advantages = q_values - values  # Broadcasting values over n_samples
     if preprocess:
         advantages = self._preprocess_advantages(advantages)
     return advantages
Example #22
def A2CObjective(dist_inputs, values, returns, dones, rewards, actions, mask,
                 log_prob_fun, normalize_advantages):
    """Definition of the Advantage Actor Critic (A2C) loss."""
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and mask of the shape float32[128,1]
    # We have to squeeze values and returns, because we are going to compute
    # (returns - values) * new_log_probs * mask, and all of these must have
    # the same shape.
    values = values.squeeze(axis=2)
    returns = returns.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert values.shape == mask.shape, (
        f'values.shape was {values.shape} and mask.shape was {mask.shape}')
    assert returns.shape[0] == dist_inputs.shape[0], (
        f'returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was '
        f'{dist_inputs.shape[0]}')

    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    assert new_log_probs.shape == mask.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and mask.shape was '
        f'{mask.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert new_log_probs.shape == advantages.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    # One motivation for the squeezes and assertions above is to avoid
    # [128,1] * [128,1,1] * [128] multiplications in the definition of the
    # A2C objective: we insist on matching shapes.
    a2c_objective = -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask)
    return a2c_objective
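
The last line is a mask-weighted mean, so padded positions do not dilute the
objective. A minimal sketch with toy values (not Trax code):

    from trax.fastmath import numpy as jnp

    x = jnp.array([2.0, 4.0, 100.0])
    mask = jnp.array([1.0, 1.0, 0.0])         # third entry is padding
    print(jnp.sum(x * mask) / jnp.sum(mask))  # 3.0, padding ignored
    print(jnp.mean(x))                        # ~35.3, padding included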
Example #23
 def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
   """Clipped objective from the PPO algorithm."""
   del dones, rewards
   advantages = returns - values
   probs_ratio = rl_layers.ProbsRatio(
       dist_inputs, actions, old_log_probs,
       log_prob_fun=self._policy_dist.log_prob)
   # advantages are of the shape [128,1,1]
   # and probs_ratio are of the shape [128,1]
   advantages = advantages.squeeze(axis=2)
   clipped_objective = rl_layers.ClippedObjective(
       probs_ratio, advantages, epsilon=self._epsilon)
   return jnp.mean(clipped_objective)
Example #24
    def _aggregate_values(self, values, aggregate, act_log_probs):
        # Normalize the Q-values before aggregation, so it can adapt to the
        # scale of the returns. This does not affect mean and max aggregation.
        scale = 1
        epsilon = 1e-5
        if self._q_value_normalization == 'std':
            scale = jnp.std(values) + epsilon
        elif self._q_value_normalization == 'abs':
            scale = jnp.mean(jnp.abs(values - jnp.mean(values))) + epsilon
        values /= scale

        temp = self._q_value_temperature
        if self._q_value:
            assert values.shape[:2] == (self._value_batch_size,
                                        self._q_value_n_samples)
            if aggregate == 'max':
                # max_a Q(s, a)
                values = jnp.max(values, axis=1)
            elif aggregate == 'softmax':
                # sum_a (Q(s, a) * w(s, a))
                # where w(s, .) = softmax (Q(s, .) / T)
                weights = tl.Softmax(axis=1)(values / temp)
                values = jnp.sum(values * weights, axis=1)
            elif aggregate == 'logsumexp':
                # log(mean_a exp(Q(s, a) / T)) * T
                n = values.shape[1]
                values = (fastmath.logsumexp(values / temp, axis=1) -
                          jnp.log(n)) * temp
            else:
                assert aggregate == 'mean'
                # mean_a Q(s, a)
                if self._sample_all_discrete_actions:
                    values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
                else:
                    values = jnp.mean(values, axis=1)

        # Re-scale the Q-values after aggregation.
        values *= scale
        return np.array(values)  # Move the values to CPU.
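
A toy check of the 'logsumexp' aggregation identity noted in the comments,
log(mean_a exp(Q(s, a) / T)) * T == (logsumexp(Q(s, .) / T) - log(n)) * T
(made-up Q-values and temperature):

    from trax import fastmath
    from trax.fastmath import numpy as jnp

    q = jnp.array([[1.0, 2.0, 3.0]])  # one state, three sampled actions
    temp = 2.0
    n = q.shape[1]
    lhs = jnp.log(jnp.mean(jnp.exp(q / temp), axis=1)) * temp
    rhs = (fastmath.logsumexp(q / temp, axis=1) - jnp.log(n)) * temp
    print(lhs, rhs)  # both print the same value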
Example #25
    def f(new_log_probs, advantages, old_log_probs, mask):
      # new_log_probs of the shape float32[128,1]
      # advantages of the shape int32[128,1]
      # old_log_probs of the shape int32[128,1]
      # mask of the shape int32[128,1]
      if new_log_probs.shape != advantages.shape:
        raise ValueError('New log-probs and advantages shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           advantages.shape))
      if new_log_probs.shape != old_log_probs.shape:
        raise ValueError('New log-probs and old log-probs shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           old_log_probs.shape))
      if new_log_probs.shape != mask.shape:
        raise ValueError('New log-probs and mask shapes should be the same'
                         ', %s != %s' % (new_log_probs.shape, mask.shape))

      # The ratio between new_probs and old_probs expressed
      # using log_probs and exponentiation.
      probs_ratio = jnp.exp(new_log_probs - old_log_probs)
      if advantages.shape != probs_ratio.shape:
        raise ValueError('Advantages and probs-ratio shapes '
                         'should be the same, %s != %s' % (advantages.shape,
                                                           probs_ratio.shape))
      unclipped_objective = probs_ratio * advantages
      clipped_objective = jnp.clip(probs_ratio,
                                   1 - self._epsilon,
                                   1 + self._epsilon) * advantages

      if unclipped_objective.shape != probs_ratio.shape:
        raise ValueError('unclipped_objective and probs_ratio shapes '
                         'should be the same, %s != %s' % (
                             unclipped_objective.shape,
                             probs_ratio.shape))

      ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)

      if ppo_objective.shape != mask.shape:
        raise ValueError('ppo_objective and mask shapes '
                         'should be the same, %s != %s' % (
                             ppo_objective.shape,
                             mask.shape))

      ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask)
      entropy_vec = self._policy_dist.entropy(
          new_log_probs) * self._entropy_coeff
      entropy_loss = jnp.mean(entropy_vec)
      combined_loss = ppo_loss - entropy_loss

      return combined_loss
Example #26
    def forward(self, inputs):
        gamma, beta, epsilon_l = self.weights

        epsilon = self._init_epsilon
        if epsilon_l is not base.EMPTY_WEIGHTS:
            epsilon += jnp.abs(epsilon_l[0])

        # Omit B and C
        axis = tuple(range(1, len(jnp.shape(inputs)) - 1))
        # (B, 1, 1, C)
        nu2 = jnp.mean(inputs**2, axis=axis, keepdims=True)
        # (B, W, H, C)
        xhat = inputs / jnp.sqrt(nu2 + epsilon)

        return gamma * xhat + beta
Example #27
 def policy_metrics(self):
   metrics = {
       'policy_loss': self.policy_loss,
       'advantage_mean': tl.Serial(
           self._policy_inputs_to_advantages(False),
           tl.Fn('Mean', lambda x: jnp.mean(x))  # pylint: disable=unnecessary-lambda
       ),
       'advantage_std': tl.Serial(
           self._policy_inputs_to_advantages(False),
           tl.Fn('Std', lambda x: jnp.std(x))  # pylint: disable=unnecessary-lambda
       )
   }
   metrics.update(awr_metrics(
       self._beta, preprocess_layer=self._policy_inputs_to_advantages(True)))
   return metrics
Example #28
def Mean(axis=-1, keepdims=False):
    """Returns a layer that computes mean values using one tensor axis.

    `Mean` uses one tensor axis to form groups of values and replaces each
    group with the mean value of that group. The resulting values can either
    remain in their own size 1 axis (`keepdims=True`), or that axis can be
    removed from the overall tensor (default `keepdims=False`), lowering the
    rank of the tensor by one.

    Args:
      axis: Axis along which values are grouped for computing a mean.
      keepdims: If `True`, keep the resulting size 1 axis as a separate
        tensor axis; else, remove that axis.
    """
    return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))
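
A usage sketch for the Mean layer (toy input; assumes the usual
`from trax import layers as tl` alias):

    from trax import layers as tl
    from trax.fastmath import numpy as jnp

    x = jnp.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])
    print(tl.Mean(axis=-1)(x))                 # [2. 5.] -- rank lowered by one
    print(tl.Mean(axis=-1, keepdims=True)(x))  # [[2.] [5.]] -- size-1 axis kept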
Example #29
    def f(model_output, targets):  # pylint: disable=invalid-name
        def non_nan(x):  # pylint: disable=invalid-name
            return jnp.where(jnp.isnan(x), 0., x)

        beta2 = beta**2
        predictions = jnp.argmax(model_output, axis=-1)
        n_categories = model_output.shape[-1]
        f_scores = jnp.empty(0)
        for k in range(initial_category_index, n_categories):
            n_correct = sum((predictions == k) & (targets == k))
            precision = non_nan(n_correct / sum(predictions == k))
            recall = non_nan(n_correct / sum(targets == k))
            f_score = non_nan((beta2 + 1) * (precision * recall) /
                              ((beta2 * precision) + recall))
            f_scores = jnp.append(f_scores, f_score)
        return jnp.mean(f_scores)
Example #30
  def train_step(self, batch):
    """Run one training step and update self._opt_state."""
    # Calculate the current optimizer parameters.
    opt_param_updates = self._for_n_devices(
        {'learning_rate': np.array(self.learning_rate)})
    opt_state = self._opt_state
    opt_state.opt_params.update(opt_param_updates)

    # Run the update.
    weights, slots, opt_params = opt_state
    (weights, slots), stat, self._model_state, self._rngs = self._jit_update_fn(
        (weights, slots), self._step, opt_params, batch,
        self._model_state, self._rngs)
    self._opt_state = opt_state._replace(weights=weights, slots=slots)
    if self._should_log_now():
      for name, value in stat.items():
        scalar_value = np.mean(value)  # On multiple devices, take the mean.
        self._train_sw.scalar('training/' + name, scalar_value, step=self._step)
    self._step += 1