Example #1
    def update(self, step, grads, params, slots, opt_params):
        updates = []
        learning_rate = opt_params["learning_rate"]
        beta1 = opt_params["beta1"]
        decay_rate = opt_params["decay_rate"]
        clipping_threshold = opt_params["clipping_threshold"]
        weight_decay_rate = opt_params["weight_decay_rate"]
        epsilon1 = opt_params["epsilon1"]
        epsilon2 = opt_params["epsilon2"]
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= np.maximum(np.sqrt(np.mean(params * params)),
                                       epsilon2)
        mixing_rate = 1.0 - decay_rate

        # Squared gradients (plus epsilon1 for numerical stability).
        grads_sqr = grads * grads + epsilon1
        if self._factored and len(params.shape) >= 2:
            # Factored estimate: track row and column means of the squared
            # gradients instead of the full matrix.
            v_row = slots.pop(0)
            v_col = slots.pop(0)
            new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-1)
            new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-2)
            updates.extend([new_v_row, new_v_col])
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            # Unfactored case: keep the full squared-gradient running average.
            v = slots.pop(0)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            clipping_denom = (np.maximum(
                1.0,
                np.sqrt(np.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots.pop(0)
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_params = (1 - weight_decay_rate) * params - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_params.astype(params.dtype), updates
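For intuition, here is a minimal standalone sketch of the factored second-moment estimate used in the branch above, in plain NumPy. The helper name factored_rms and its default epsilon1 are illustrative, not part of the original class:

import numpy as np

def factored_rms(grads, v_row, v_col, decay_rate, epsilon1=1e-30):
    # Approximate the full squared-gradient accumulator by its row/column means.
    grads_sqr = grads * grads + epsilon1
    new_v_row = decay_rate * v_row + (1.0 - decay_rate) * np.mean(grads_sqr, axis=-1)
    new_v_col = decay_rate * v_col + (1.0 - decay_rate) * np.mean(grads_sqr, axis=-2)
    row_factor = (new_v_row / np.mean(new_v_row, axis=-1, keepdims=True)) ** -0.5
    col_factor = new_v_col ** -0.5
    # Rescale the gradients, as in the factored branch of update() above.
    y = grads * row_factor[..., np.newaxis] * col_factor[..., np.newaxis, :]
    return y, new_v_row, new_v_col

g = np.ones((3, 4))
y, v_row, v_col = factored_rms(g, np.zeros(3), np.zeros(4), decay_rate=0.8)
print(y.shape)  # (3, 4)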
Example #2
def Softmax5Branches(x_list, n_branches=2, **unused_kwargs):
    """Softmax-weighted combination of branches.

    The input x_list is a list of interleaved weights and embeddings of the
    form w_1 e_1 ... w_n e_n (followed by an optional rest that is preserved).

    Args:
      x_list: the input weights and embeddings.
      n_branches: how many (weight, embedding) pairs at the start of the list
        to combine; only n_branches == 5 is currently supported (see the
        assert below).

    Returns:
      softmax(w) * e for the joint weights w and embeddings e.
    """
    assert n_branches == 5
    softmax_activations = [x_list[2 * i] for i in range(n_branches)]
    max_sa = softmax_activations[0]
    for x in softmax_activations:
        max_sa = np.maximum(max_sa, x)
    softmax_activations = [x - max_sa for x in softmax_activations]
    softmax_activations = [np.exp(x) for x in softmax_activations]
    sum_sa = sum(softmax_activations)
    softmax_activations = [x / sum_sa for x in softmax_activations]
    res = sum([
        x_list[2 * i + 1] * softmax_activations[i] for i in range(n_branches)
    ])
    return res
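A hedged usage sketch for the function above (not from the original source; it treats np as plain NumPy): build a list w_1 e_1 ... w_5 e_5 and compare the result with an explicit softmax-weighted sum.

import numpy as np

weights = [np.array(float(i)) for i in range(5)]          # scalar weights w_i
embeddings = [np.full((3,), float(i)) for i in range(5)]  # embeddings e_i
x_list = [v for pair in zip(weights, embeddings) for v in pair]

out = Softmax5Branches(x_list, n_branches=5)

w = np.array([float(i) for i in range(5)])
softmax_w = np.exp(w - w.max()) / np.exp(w - w.max()).sum()
expected = sum(softmax_w[i] * embeddings[i] for i in range(5))
print(np.allclose(out, expected))  # True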
Example #3
 def learning_rate(step):  # pylint: disable=invalid-name
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == 'constant':
             ret *= constant
         elif name == 'linear_warmup':
             ret *= np.minimum(1.0, step / warmup_steps)
         elif name == 'rsqrt_decay':
             ret /= np.sqrt(np.maximum(step, warmup_steps))
         elif name == 'rsqrt_normalized_decay':
             ret *= np.sqrt(warmup_steps)
             ret /= np.sqrt(np.maximum(step, warmup_steps))
         elif name == 'decay_every':
             ret *= (decay_factor**(step // steps_per_decay))
         elif name == 'cosine_decay':
             progress = np.maximum(0.0, (step - warmup_steps) /
                                   float(steps_per_cycle))
             ret *= np.maximum(
                 0.0, 0.5 * (1.0 + np.cos(np.pi * (progress % 1.0))))
         else:
             raise ValueError('Unknown factor %s.' % name)
     ret = np.asarray(ret, dtype=np.float32)
     return {'learning_rate': ret}
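The function above closes over factors, constant, warmup_steps and friends defined in an enclosing scope. Below is a self-contained sketch of how such a factor-based schedule behaves; the make_schedule wrapper and its default values are illustrative only:

import numpy as np

def make_schedule(factors='constant * linear_warmup * rsqrt_decay',
                  constant=0.1, warmup_steps=1000):
    factor_names = [f.strip() for f in factors.split('*')]

    def learning_rate(step):
        ret = 1.0
        for name in factor_names:
            if name == 'constant':
                ret *= constant
            elif name == 'linear_warmup':
                ret *= np.minimum(1.0, step / warmup_steps)
            elif name == 'rsqrt_decay':
                ret /= np.sqrt(np.maximum(step, warmup_steps))
            else:
                raise ValueError('Unknown factor %s.' % name)
        return np.asarray(ret, dtype=np.float32)

    return learning_rate

lr = make_schedule()
print(lr(1), lr(1000), lr(4000))  # linear warmup, then ~1/sqrt(step) decay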
Example #4
 def learning_rate(step):  # pylint: disable=invalid-name
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == "constant":
             ret *= constant
         elif name == "linear_warmup":
             ret *= np.minimum(1.0, step / warmup_steps)
         elif name == "rsqrt_decay":
             ret /= np.sqrt(np.maximum(step, warmup_steps))
         elif name == "decay_every":
             ret *= (decay_factor**(step // steps_per_decay))
         else:
             raise ValueError("Unknown factor %s." % name)
     ret = np.asarray(ret, dtype=np.float32)
     return {"learning_rate": ret}
Example #5
        def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
            """Forward pass for a subset of the query vectors."""
            if self._share_qk:
                key = self.make_unit_length(key)

            dots = np.matmul(query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

            # Causal masking
            mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
            dots = dots - 1e9 * mask

            # Mask out attention to self except when no other targets are available.
            if self._share_qk:
                self_mask = make_self_mask(dots.shape[-2], dots.shape[-1],
                                           q_loop_idx)
                dots = dots - 1e5 * self_mask

            # Softmax.
            dots = np.exp(dots -
                          backend.logsumexp(dots, axis=-1, keepdims=True))

            if self.dropout is not None and self.dropout > 0.0:
                # Dropout is broadcast across the batch+head dimension
                dropout_shape = (1, dots.shape[-2], dots.shape[-1])
                slice_rng = jax.random.fold_in(rng, q_loop_idx)
                keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
                keep = backend.random.bernoulli(slice_rng, keep_prob,
                                                dropout_shape)
                multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(
                    keep, keep_prob)
                dots = dots * multiplier

            if self._hard_k > 0:
                # Keep only the top-k attention weights per query position.
                top_k = np.sort(dots)[..., -self._hard_k]  # The k-th largest weight.
                top_k = jax.lax.stop_gradient(top_k)
                dots -= top_k[..., np.newaxis]  # Weights below the k-th become <= 0.
                dots = np.maximum(dots, 0)
                dots_sum = np.sum(dots, axis=-1, keepdims=True)
                dots /= dots_sum  # Re-normalize.

            out_slice = np.matmul(dots, value)
            return out_slice
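For intuition, a small standalone sketch of the masking-then-stable-softmax pattern from the slice above, in plain NumPy and without the dropout and hard-k parts. The upper-triangular mask here is a simple stand-in for the original make_mask helper:

import numpy as np

def causal_softmax(query, key, depth):
    # Scaled dot-product scores.
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    # Causal mask: 1 above the diagonal, so future positions get a -1e9 penalty.
    mask = np.triu(np.ones((dots.shape[-2], dots.shape[-1])), k=1)
    dots = dots - 1e9 * mask
    # Numerically stable softmax via exp(x - logsumexp(x)).
    m = np.max(dots, axis=-1, keepdims=True)
    logsumexp = m + np.log(np.sum(np.exp(dots - m), axis=-1, keepdims=True))
    return np.exp(dots - logsumexp)

q = np.random.randn(4, 8)
k = np.random.randn(4, 8)
print(causal_softmax(q, k, depth=8).sum(axis=-1))  # each row sums to ~1.0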
Example #6
def Softmax5Branches(x_list, **unused_kwargs):
  """Softmax-weighted average of queries.

  The input x_list is a list of weights and embedded queries of the form
  w_1 ... w_n q_1 ... q_n. The q_1 ... q_n are kept and the result is appended.

  Args:
    x_list: the input weights and embeddings.

  Returns:
    the weighted average of q_1 ... q_n according to softmax(w).
  """
  n_branches = 5
  softmax_activations = x_list[:n_branches]
  max_sa = softmax_activations[0]
  for x in softmax_activations:
    max_sa = np.maximum(max_sa, x)
  softmax_activations = [x - max_sa for x in softmax_activations]
  softmax_activations = [np.exp(x) for x in softmax_activations]
  sum_sa = sum(softmax_activations)
  softmax_activations = [x / sum_sa for x in softmax_activations]
  res = sum([x_list[i + n_branches] * softmax_activations[i]
             for i in range(n_branches)])
  return res
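The same kind of check as for Example #2, adapted to this layout where the five weights come first and the five queries follow (illustrative only, with np as plain NumPy):

import numpy as np

weights = [np.array(float(i)) for i in range(5)]
queries = [np.full((3,), float(i)) for i in range(5)]
out = Softmax5Branches(weights + queries)

w = np.array([float(i) for i in range(5)])
softmax_w = np.exp(w - w.max()) / np.exp(w - w.max()).sum()
print(np.allclose(out, sum(softmax_w[i] * queries[i] for i in range(5))))  # True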
Example #7
 def forward(self, inputs, weights):
   threshold = weights[0]
   return np.maximum(inputs, threshold)
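A tiny illustration of what this forward pass computes, outside the layer class; the input and threshold values are made up for the example:

import numpy as np

inputs = np.array([-2.0, 0.5, 3.0])
weights = (np.array(1.0),)              # weights[0] is the learned threshold
print(np.maximum(inputs, weights[0]))   # [1. 1. 3.] -- clipped from below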
Example #8
def HardTanh(x, **unused_kwargs):
  """Linear approximation to tanh."""
  return np.maximum(-1, np.minimum(1, x))
Example #9
def HardSigmoid(x, **unused_kwargs):
  """Linear approximation to sigmoid."""
  return np.maximum(0, np.minimum(1, (1 + x)))
Example #10
def ParametricRelu(x, a=1., **unused_kwargs):
  return np.maximum(a * x, np.zeros_like(x))
Example #11
def Relu(x, **unused_kwargs):
  return np.maximum(x, np.zeros_like(x))
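A quick standalone check of the clipping behaviour shared by Examples #8 through #11, using plain NumPy versions of the expressions above:

import numpy as np

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(np.maximum(x, np.zeros_like(x)))      # Relu:        [0.  0.  0.  0.5 2. ]
print(np.maximum(-1, np.minimum(1, x)))     # HardTanh:    [-1. -0.5 0.  0.5 1. ]
print(np.maximum(0, np.minimum(1, 1 + x)))  # HardSigmoid: [0.  0.5 1.  1.  1. ]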