Example 1
    def update(self, step, grads, weights, slots, opt_params):
        updates = []
        learning_rate = opt_params['learning_rate']
        beta1 = opt_params['beta1']
        decay_rate = opt_params['decay_rate']
        clipping_threshold = opt_params['clipping_threshold']
        weight_decay_rate = opt_params['weight_decay_rate']
        weight_decay_n_steps = opt_params['weight_decay_n_steps']
        # If weight_decay_n_steps == 0, ignore it and keep the flat rate;
        # otherwise anneal the rate linearly to 0 over weight_decay_n_steps.
        weight_decay_rate = jnp.where(
            weight_decay_n_steps < 1,
            weight_decay_rate,
            (weight_decay_rate *
             jnp.maximum(weight_decay_n_steps - step, 0.0) /
             jnp.maximum(weight_decay_n_steps, 0.0)))
        epsilon1 = opt_params['epsilon1']
        epsilon2 = opt_params['epsilon2']
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= jnp.maximum(jnp.sqrt(jnp.mean(weights * weights)),
                                        epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads
        if self._factored and len(weights.shape) >= 2:
            v_row = slots.pop(0)
            v_col = slots.pop(0)
            new_v_row = (decay_rate * v_row +
                         mixing_rate * jnp.mean(grads_sqr, axis=-1))
            new_v_col = (decay_rate * v_col +
                         mixing_rate * jnp.mean(grads_sqr, axis=-2))
            updates.extend([new_v_row, new_v_col])
            row_mean = jnp.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (row_mean / (new_v_row + epsilon1))**0.5
            col_factor = (new_v_col + epsilon1)**-0.5
            y = (grads * jnp.expand_dims(row_factor, axis=-1) *
                 jnp.expand_dims(col_factor, axis=-2))
        else:
            v = slots.pop(0)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v + epsilon1)**-0.5

        if self._do_clipping:
            clipping_denom = (jnp.maximum(
                1.0,
                jnp.sqrt(jnp.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots.pop(0)
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_weights = (1 - weight_decay_rate) * weights - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_weights.astype(weights.dtype), updates
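
The factored branch above replaces the full per-element second-moment estimate with a per-row and a per-column mean. A minimal sketch of that approximation, with an assumed 2x3 gradient (not taken from the example above):

import jax.numpy as jnp

# Hedged sketch of Adafactor's factored second-moment estimate for a 2-D
# weight matrix; the values below are illustrative assumptions.
grads = jnp.array([[0.1, -0.2, 0.3],
                   [0.4,  0.0, -0.1]])
grads_sqr = grads * grads
v_row = jnp.mean(grads_sqr, axis=-1)   # shape (2,), per-row mean
v_col = jnp.mean(grads_sqr, axis=-2)   # shape (3,), per-column mean
# Reconstruct an approximate per-element second moment from the two factors,
# mirroring the row_factor / col_factor math in the update above.
approx = (jnp.expand_dims(v_row, -1) * jnp.expand_dims(v_col, -2)
          / jnp.mean(v_row))
print(approx.shape)  # (2, 3), stored as O(rows + cols) slots, not O(rows * cols)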
Example 2
 def learning_rate(step):
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == 'constant':
             ret *= constant
         elif name == 'linear_warmup':
             ret *= jnp.minimum(1.0, step / warmup_steps)
         elif name == 'rsqrt_decay':
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'rsqrt_normalized_decay':
             ret *= jnp.sqrt(warmup_steps)
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'decay_every':
             ret *= (decay_factor**(step // steps_per_decay))
         elif name == 'cosine_decay':
             progress = jnp.maximum(0.0, (step - warmup_steps) /
                                    float(steps_per_cycle))
             ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
         else:
             raise ValueError('Unknown factor %s.' % name)
     # TODO(henrykm): return float(jnp.max(minimum, ret)) would be
     # better but causes TypeError: 'numpy.float64' object cannot
     # be interpreted as an integer
     if ret <= minimum:
         return minimum
     return ret
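
Here `factors`, `constant`, `warmup_steps`, and the other hyperparameters are closed over from an enclosing factory that is not shown. A minimal sketch of how the composed schedule 'constant * linear_warmup * rsqrt_decay' behaves, with assumed hyperparameter values:

import jax.numpy as jnp

# Assumed values, not taken from the snippet above.
constant, warmup_steps = 0.1, 1000.0

def lr(step):
    ret = constant
    ret *= jnp.minimum(1.0, step / warmup_steps)      # linear_warmup
    ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))  # rsqrt_decay
    return ret

print(lr(10.0))      # during warmup: 0.1 * 0.01 / sqrt(1000)
print(lr(1000.0))    # peak: 0.1 / sqrt(1000) ~= 0.00316
print(lr(16000.0))   # decayed: 0.1 / sqrt(16000) ~= 0.00079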
Example 3
  def forward(self, inputs):
    """Executes this layer as part of a forward pass through the model.

    Args:
      inputs: Tensor.

    Returns:
      Tensor of same shape and dtype as the input.
    """
    threshold = self.weights
    return jnp.maximum(inputs, threshold)
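
The layer's single weight acts as a learned floor on every activation. A small sketch of the same computation with an assumed threshold value:

import jax.numpy as jnp

threshold = jnp.array(0.5)                   # assumed learned weight
inputs = jnp.array([-1.0, 0.2, 0.5, 2.0])
print(jnp.maximum(inputs, threshold))        # [0.5 0.5 0.5 2. ]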
Example 4
def HardTanh():
  r"""Returns a layer that computes a linear approximation to `Tanh`.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          -1 & \text{if}\ x \leq -1, \\
          x  & \text{if}\ -1 < x < 1, \\
          1  & \text{otherwise}.
      \end{array} \right.
  """
  return Fn('HardTanh', lambda x: jnp.maximum(-1, jnp.minimum(1, x)))
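
Applying the layer clips activations to [-1, 1]; a minimal usage sketch, assuming `trax.layers` is available as `tl`:

import numpy as np
from trax import layers as tl

layer = tl.HardTanh()
print(layer(np.array([-2.0, -0.5, 0.5, 2.0])))  # [-1.  -0.5  0.5  1. ]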
Example 5
def HardSigmoid():
  r"""Returns a layer that computes a linear approximation to `Sigmoid`.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0 & \text{if}\ x \leq 0, \\
          x & \text{if}\ 0 < x < 1, \\
          1 & \text{otherwise}.
      \end{array} \right.
  """
  return Fn('HardSigmoid', lambda x: jnp.maximum(0, jnp.minimum(1, (1 + x))))
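
Note that the lambda clips `1 + x` rather than `x`, so the linear region actually sits on [-1, 0]. A small sketch, again assuming `trax.layers` is available as `tl`:

import numpy as np
from trax import layers as tl

layer = tl.HardSigmoid()
print(layer(np.array([-2.0, -0.5, 0.0, 0.5])))  # [0.   0.5  1.   1. ]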
Example 6
def TripletLossFn(v1, v2, margin=0.25):
    """Custom Loss function.

    Args:
        v1 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q1.
        v2 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q2.
        margin (float, optional): Desired margin. Defaults to 0.25.

    Returns:
        jax.interpreters.xla.DeviceArray: Triplet Loss.
    """
    ### START CODE HERE (Replace instances of 'None' with your code) ###

    # use fastnp to take the dot product of the two batches (don't forget to transpose the second argument)
    scores = fastnp.dot(v1, fastnp.transpose(v2))  # pairwise cosine sim
    # calculate new batch size
    batch_size = len(scores)
    # use fastnp to grab all positive `diagonal` entries in `scores`
    positive = fastnp.diagonal(scores)  # the positive ones (duplicates)
    # multiply `fastnp.eye(batch_size)` with 2.0 and subtract it out of `scores`
    # so the diagonal can never win the row-wise max below
    negative_without_positive = scores - 2.0 * fastnp.eye(batch_size)
    # take the row by row `max` of `negative_without_positive`.
    # Hint: negative_without_positive.max(axis = [?])
    closest_negative = negative_without_positive.max(axis=1)
    # subtract `fastnp.eye(batch_size)` out of 1.0 and do element-wise multiplication with `scores`
    negative_zero_on_duplicate = (1.0 - fastnp.eye(batch_size)) * scores
    # use `fastnp.sum` on `negative_zero_on_duplicate` for `axis=1` and divide it by `(batch_size - 1)`
    mean_negative = fastnp.sum(negative_zero_on_duplicate,
                               axis=1) / (batch_size - 1)
    # compute `fastnp.maximum` among 0.0 and `A`
    # A = subtract `positive` from `margin` and add `closest_negative`
    triplet_loss1 = fastnp.maximum((margin - positive + closest_negative), 0.0)
    # compute `fastnp.maximum` among 0.0 and `B`
    # B = subtract `positive` from `margin` and add `mean_negative`
    triplet_loss2 = fastnp.maximum((margin - positive + mean_negative), 0.0)
    # add the two losses together and take the `fastnp.mean` of it
    triplet_loss = fastnp.mean(triplet_loss1 + triplet_loss2)

    ### END CODE HERE ###

    return triplet_loss
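
A minimal usage sketch for the function above, assuming `fastnp` is `trax.fastmath.numpy` and that the two batches are row-normalized so the dot products are cosine similarities:

from trax import fastmath

fastnp = fastmath.numpy  # assumption: this is the `fastnp` used above

# Two tiny batches of unit-norm vectors; row i of v1 pairs with row i of v2.
v1 = fastnp.array([[1.0, 0.0], [0.0, 1.0]])
v2 = fastnp.array([[0.8, 0.6], [0.6, 0.8]])
print(TripletLossFn(v1, v2, margin=0.25))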
Example 7
 def learning_rate(step):
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == 'constant':
             ret *= constant
         elif name == 'linear_warmup':
             ret *= jnp.minimum(1.0, step / warmup_steps)
         elif name == 'rsqrt_decay':
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'rsqrt_normalized_decay':
             ret *= jnp.sqrt(warmup_steps)
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'decay_every':
             ret *= (decay_factor**(step // steps_per_decay))
         elif name == 'cosine_decay':
             progress = jnp.maximum(0.0, (step - warmup_steps) /
                                    float(steps_per_cycle))
             ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
         else:
             raise ValueError('Unknown factor %s.' % name)
     return float(ret)
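
For the 'cosine_decay' branch, a quick numeric check of the factor in isolation, with assumed values for warmup_steps and steps_per_cycle:

import jax.numpy as jnp

warmup_steps, steps_per_cycle = 1000.0, 10000.0  # assumed values
step = 6000.0
progress = jnp.maximum(0.0, (step - warmup_steps) / steps_per_cycle)  # 0.5
factor = 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))
print(factor)  # ~0.5 halfway through the cosine cycle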
Example 8
def ParametricRelu(a=1.):
  r"""Returns a layer that computes a ReLU function with the given slope.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0  & \text{if}\ x \leq 0, \\
          ax & \text{otherwise}.
      \end{array} \right.

  Args:
    a: Slope of line for positive inputs.
  """
  return Fn('ParametricRelu', lambda x: jnp.maximum(a * x, jnp.zeros_like(x)))
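
A small usage sketch showing the effect of the slope parameter, assuming `trax.layers` is available as `tl`:

import numpy as np
from trax import layers as tl

layer = tl.ParametricRelu(a=2.0)
print(layer(np.array([-3.0, 0.0, 3.0])))  # [0. 0. 6.]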
Example 9
def _category_cross_entropy(  # pylint: disable=invalid-name
        model_output, targets, label_smoothing, cutoff):
    """Computes category cross entropy with label smoothing."""
    n_categories = model_output.shape[-1]
    target_distributions = core.one_hot(targets, n_categories)
    if label_smoothing:
        if label_smoothing < 0. or label_smoothing > 1.:
            raise ValueError(
                f'Arg label_smoothing ({label_smoothing}) must be between 0 and 1.'
            )
        target_distributions *= (1. - label_smoothing)
        target_distributions += label_smoothing / n_categories
    model_log_distributions = core.log_softmax(model_output)
    cross_ent = -jnp.sum(target_distributions * model_log_distributions,
                         axis=-1)
    if cutoff > 0.0:
        return jnp.maximum(cross_ent, cutoff) - cutoff
    else:
        return cross_ent
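
The cutoff branch zeroes out per-example losses that are already below the cutoff while shifting the rest down by the same amount, keeping the loss continuous. A small numeric sketch with assumed values:

import jax.numpy as jnp

cross_ent = jnp.array([0.1, 0.3, 2.0])  # assumed per-example losses
cutoff = 0.3
print(jnp.maximum(cross_ent, cutoff) - cutoff)  # [0.  0.  1.7]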
Example 10
  def test_forward(self):
    layer = tl.Fn(
        'SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)

    x0 = np.array([1, 2, 3, 4, 5])
    x1 = np.array([10, 20, 30, 40, 50])

    y0, y1 = layer((x0, x1))
    self.assertEqual(y0.tolist(), [11, 22, 33, 44, 55])
    self.assertEqual(y1.tolist(), [10, 20, 30, 40, 50])

    y2, y3 = layer.forward((x0, x1))
    self.assertEqual(y2.tolist(), [11, 22, 33, 44, 55])
    self.assertEqual(y3.tolist(), [10, 20, 30, 40, 50])

    (y4, y5), state = layer.pure_fn((x0, x1), tl.EMPTY_WEIGHTS, tl.EMPTY_STATE,
                                    None)
    self.assertEqual(y4.tolist(), [11, 22, 33, 44, 55])
    self.assertEqual(y5.tolist(), [10, 20, 30, 40, 50])
    self.assertEqual(state, tl.EMPTY_STATE)
Example 11
  def forward(self, xs):
    self._validate_forward_inputs(xs)
    (step, layers_state) = self.state
    # Get N+1 rngs, N for running layers and one extra.
    rngs = _split_rngs(self.rng, self._n_layers + 1)
    rng0, rngs = rngs[0], rngs[1:]
    if not self.sublayers:  # No-op: leave args unchanged.
      self.state = (step + 1, layers_state)
      return xs

    # Prepare the stack and do some safety checks as in the parent class.
    stack = xs
    new_state = []
    n_layers = self._n_layers
    weights = self.weights
    if n_layers != 1 and len(weights) != n_layers:
      raise ValueError('number of weights ({}) not equal to number of layers '
                       '({})'.format(len(weights), n_layers))
    if n_layers != 1 and len(layers_state) != n_layers:
      raise ValueError('length of state ({}) not equal to number of layers '
                       '({})'.format(len(layers_state), n_layers))

    # TODO(chowdhery): try different strategies, also try running not all
    # layers backwards by using fastmath.stop_gradient where needed.

    # Calculate how many layers to run forward.
    if self._mode == 'train':
      # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after
      w_steps = float(self._skipping_warmup_steps)
      f_step = step.astype(jnp.float32)
      warmup = jnp.maximum(0.0, (w_steps - f_step) / w_steps)
      # low is the minimum number of layers to *not* skip, from n_layers to 0
      low = warmup * float(n_layers)
      # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
      # because (high - n_layers) / high is the probability we're not skipping
      # (after warmup); so high - n_layers = high - high * skip_fraction
      high = float(n_layers) / self._skip_fraction
      # We want the same rng0 on all cores.
      if fastmath.device_count() > 1:
        rng0 = fastmath.psum(rng0, 'batch')
      n_forward_layers = random.uniform(rng0, (), jnp.float32, low, high)
    else:
      n_forward_layers = float(n_layers)
    # Run layers skipping after a certain number.
    cur_layer_idx = 0.0
    for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
      inputs = _inputs_from_stack(layer, stack)
      def CondF(t):
        return layer.pure_fn(t[0], t[1], t[2], t[3])  # pylint: disable=cell-var-from-loop
      def PassF(t):
        return t[0], t[2]
      outputs, s = fastmath.cond(
          fastmath.lt(cur_layer_idx, n_forward_layers),
          CondF,
          PassF,
          (inputs, p, s, rng)
      )
      stack = _outputs_onto_stack(layer, outputs, stack)
      new_state.append(s)
      cur_layer_idx += 1.0
    self.state = (step + 1, new_state)
    return stack
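
The choice high = n_layers / skip_fraction makes the post-warmup probability that no layer is skipped equal to 1 - skip_fraction, as the comment in the code argues. A quick empirical check with assumed values:

import jax
import jax.numpy as jnp

# Hedged sketch; n_layers and skip_fraction below are assumed values.
n_layers, skip_fraction = 6.0, 0.1
high = n_layers / skip_fraction  # 60.0
n_forward = jax.random.uniform(jax.random.PRNGKey(0), (100000,),
                               jnp.float32, 0.0, high)
# Fraction of draws where every layer runs (no skipping at all):
print(jnp.mean(n_forward > n_layers))  # ~0.9 == 1 - skip_fraction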