Example #1
def BatchNorm(x, params, axis=(0, 1, 2), epsilon=1e-5,
              center=True, scale=True, **unused_kwargs):
  """Layer construction function for a batch normalization layer."""
  mean = np.mean(x, axis, keepdims=True)
  # Fast but less numerically-stable variance calculation than np.var.
  m1 = np.mean(x**2, axis, keepdims=True)
  var = m1 - mean**2
  # x mustn't be onp.ndarray here; otherwise `x-mean` will call mean.__rsub__
  # with each element of x, resulting in an onp.ndarray with dtype `object`.
  z = (x - mean) / np.sqrt(var + epsilon).astype(x.dtype)

  # Expand the parameters to have the right axes.
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]

  # Return the z rescaled by the parameters if requested.
  if center and scale:
    ret = gamma * z + beta
  elif center:
    ret = z + beta
  elif scale:
    ret = gamma * z
  else:
    ret = z
  assert ret.dtype == x.dtype, ('The dtype of the output (%s) of batch norm is '
                                'not the same as the input (%s). Batch norm '
                                'should not change the dtype' %
                                (ret.dtype, x.dtype))
  return ret
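A minimal call sketch for the BatchNorm function above (not part of the original listing), assuming `np` is `jax.numpy` and an NHWC input; with the default axis=(0, 1, 2), the learned beta and gamma are per-channel vectors:

x = np.zeros((8, 32, 32, 16))                  # batch, height, width, channels
beta, gamma = np.zeros((16,)), np.ones((16,))  # one offset and one scale per channel
y = BatchNorm(x, (beta, gamma))                # same shape and dtype as x
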
Example #2
def BatchNorm(x,
              params,
              axis=(0, 1, 2),
              epsilon=1e-5,
              center=True,
              scale=True,
              **unused_kwargs):
    """Layer construction function for a batch normalization layer."""
    mean = np.mean(x, axis, keepdims=True)
    # Fast but less numerically-stable variance calculation than np.var.
    m1 = np.mean(x**2, axis, keepdims=True)
    var = m1 - mean**2
    z = (x - mean) / np.sqrt(var + epsilon)

    # Expand the parameters to have the right axes.
    beta, gamma = params
    # TODO(phawkins): np.expand_dims should accept an axis tuple.
    # (https://github.com/numpy/numpy/issues/12290)
    ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
    beta = beta[ed]
    gamma = gamma[ed]

    # Return the z rescaled by the parameters if requested.
    if center and scale:
        return gamma * z + beta
    if center:
        return z + beta
    if scale:
        return gamma * z
    return z
Example #3
 def apply_fun(params, inputs, **kwargs):
     del kwargs
     (scale, bias) = params
     mean = np.mean(inputs, axis=-1, keepdims=True)
     variance = np.mean((inputs - mean)**2, axis=-1, keepdims=True)
     norm_inputs = (inputs - mean) / np.sqrt(variance + epsilon)
     return norm_inputs * scale + bias
Example #4
    def call(self, x, params, state, **unused_kwargs):
        """Layer construction function for a batch normalization layer."""

        running_mean, running_var, num_batches = state

        if self._mode == 'train':
            mean = np.mean(x, self._axis, keepdims=True)
            # Fast but less numerically-stable variance calculation than np.var.
            m1 = np.mean(x**2, self._axis, keepdims=True)
            var = m1 - mean**2
            num_batches = num_batches + 1
            if self._momentum is None:
                # A simple average over all batches seen so far
                exponential_average_factor = 1.0 / num_batches
            else:
                exponential_average_factor = self._momentum

            def average(factor, new, old):
                return (factor * new + (1 - factor) * old).astype(old.dtype)

            running_mean = average(exponential_average_factor, mean,
                                   running_mean)
            running_var = average(exponential_average_factor, var, running_var)
            state = (running_mean, running_var, num_batches)
        else:
            mean = running_mean
            var = running_var

        z = (x - mean.astype(x.dtype)) / np.sqrt(var + self._epsilon).astype(
            x.dtype)

        # Expand the parameters to have the right axes.
        beta, gamma = params
        # TODO(phawkins): np.expand_dims should accept an axis tuple.
        # (https://github.com/numpy/numpy/issues/12290)
        ed = tuple(None if i in self._axis else slice(None)
                   for i in range(np.ndim(x)))
        beta = beta[ed]
        gamma = gamma[ed]

        # Return the z rescaled by the parameters if requested.
        if self._center and self._scale:
            output = gamma * z + beta
        elif self._center:
            output = z + beta
        elif self._scale:
            output = gamma * z
        else:
            output = z
        assert output.dtype == x.dtype, (
            'The dtype of the output (%s) of batch '
            'norm is not the same as the input (%s). '
            'Batch norm should not change the dtype' % (output.dtype, x.dtype))
        return output, state
Example #5
    def update(self, step, grads, params, slots, opt_params):
        updates = []
        learning_rate = opt_params["learning_rate"]
        beta1 = opt_params["beta1"]
        decay_rate = opt_params["decay_rate"]
        clipping_threshold = opt_params["clipping_threshold"]
        weight_decay_rate = opt_params["weight_decay_rate"]
        epsilon1 = opt_params["epsilon1"]
        epsilon2 = opt_params["epsilon2"]
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= np.maximum(np.sqrt(np.mean(params * params)),
                                       epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads + epsilon1
        if self._factored and len(params.shape) >= 2:
            v_row = slots.pop(0)
            v_col = slots.pop(0)
            new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-1)
            new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-2)
            updates.extend([new_v_row, new_v_col])
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            v = slots.pop(0)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            clipping_denom = (np.maximum(
                1.0,
                np.sqrt(np.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots.pop(0)
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_params = (1 - weight_decay_rate) * params - subtrahend
        # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
        return new_params.astype(params.dtype), updates
Example #6
def masked_mean(inputs, targets, mask_id=None):
  """Mean of the inputs but counting only those where targets != mask_id."""
  x = inputs.astype(np.float32)
  if mask_id is None:
    return np.mean(x)
  unmask = 1.0 - np.equal(targets, mask_id).astype(np.float32)
  return np.sum(x * unmask) / np.sum(unmask)
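A hedged usage sketch for the single-array masked_mean above, assuming `np` is `jax.numpy` (plain NumPy behaves the same) and 0 is the padding id:

targets = np.array([[1, 2, 0], [3, 0, 0]])
losses = np.array([[0.5, 0.2, 9.9], [0.1, 9.9, 9.9]])
# Only the three positions where targets != 0 count: (0.5 + 0.2 + 0.1) / 3
masked_mean(losses, targets, mask_id=0)   # ~0.2667
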
Example #7
 def combine(x):
     if len(x.shape) > 1:
         batch_size = x.shape[0] * x.shape[1]
         return np.reshape(x, [batch_size] + list(x.shape[2:]))
     # TODO(lukaszkaiser): is returning averages for scalars the right choice?
     # If it is only scalar, return the average.
     return np.mean(x, axis=0)
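A quick illustration of combine above (an assumed usage, with `np` as `jax.numpy`): it folds the first two dimensions, e.g. a device axis and a per-device batch axis, into a single batch axis, and falls back to averaging over axis 0 otherwise:

x = np.zeros((4, 8, 128))   # e.g. (n_devices, per_device_batch, features)
combine(x).shape            # (32, 128)
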
Example #8
    def update(self, step, grads, params, slots, opt_params):
        updates = []
        (learning_rate, beta1, decay_rate, clipping_threshold,
         weight_decay_rate, epsilon1, epsilon2) = opt_params
        decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
        update_scale = learning_rate
        if self._multiply_by_parameter_scale:
            update_scale *= np.maximum(np.sqrt(np.mean(params * params)),
                                       epsilon2)
        mixing_rate = 1.0 - decay_rate

        grads_sqr = grads * grads + epsilon1
        if self._factored and len(params.shape) >= 2:
            v_row = slots.pop(0)
            v_col = slots.pop(0)
            new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-1)
            new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr,
                                                                   axis=-2)
            updates.extend([new_v_row, new_v_col])
            row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
            row_factor = (new_v_row / row_col_mean)**-0.5
            col_factor = (new_v_col)**-0.5
            y = (grads * np.expand_dims(row_factor, axis=-1) *
                 np.expand_dims(col_factor, axis=-2))
        else:
            v = slots.pop(0)
            new_v = decay_rate * v + mixing_rate * grads_sqr
            updates.append(new_v)
            y = grads * (new_v)**-0.5

        if self._do_clipping:
            clipping_denom = (np.maximum(
                1.0,
                np.sqrt(np.mean(y * y)) / clipping_threshold))
            y /= clipping_denom

        subtrahend = update_scale * y
        if self._do_momentum:
            m = slots.pop(0)
            new_m = beta1 * m + (1.0 - beta1) * subtrahend
            subtrahend = new_m
            updates.append(new_m)

        new_params = (1 - weight_decay_rate) * params - subtrahend
        return new_params, updates
Example #9
  def update(self, i, g, x, state):
    updates = []
    decay_rate = self._decay_rate(i)
    update_scale = self._step_size(i)
    if self._multiply_by_parameter_scale:
      update_scale *= np.maximum(np.sqrt(np.mean(x * x)), self._epsilon2)
    mixing_rate = 1.0 - decay_rate

    g_sqr = g * g + self._epsilon1
    if self._factored and len(x.shape) >= 2:
      v_row = state.pop(0)
      v_col = state.pop(0)
      new_v_row = decay_rate * v_row + mixing_rate * np.mean(g_sqr, axis=-1)
      new_v_col = decay_rate * v_col + mixing_rate * np.mean(g_sqr, axis=-2)
      updates.extend([new_v_row, new_v_col])
      row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
      row_factor = (new_v_row / row_col_mean)**-0.5
      col_factor = (new_v_col)**-0.5
      y = (
          g * np.expand_dims(row_factor, axis=-1) *
          np.expand_dims(col_factor, axis=-2))
    else:
      v = state.pop(0)
      new_v = decay_rate * v + mixing_rate * g_sqr
      updates.append(new_v)
      y = g * (new_v)**-0.5

    if self._clipping_threshold is not None:
      clipping_denom = (
          np.maximum(1.0,
                     np.sqrt(np.mean(y * y)) / self._clipping_threshold))
      y /= clipping_denom

    subtrahend = update_scale * y
    if self._beta1:
      m = state.pop(0)
      new_m = self._beta1 * m + (1.0 - self._beta1) * subtrahend
      subtrahend = new_m
      updates.append(new_m)

    new_x = x - subtrahend
    return new_x, updates
Example #10
def masked_mean(inputs, targets, mask_id=None):
  """Mean of the inputs but counting only those where targets != mask_id."""
  inputs = [x.astype(np.float32) for x in inputs]
  # We assume all elements in the list contribute equally.
  # TODO(lukaszkaiser): remove this assumption (e.g., when masks differ).
  length = len(inputs)
  if mask_id is None:
    # TODO(lukaszkaiser): can we just divide the sum by length? XLA optimizes?
    return sum([np.mean(x) / length for x in inputs])
  unmask = [1.0 - np.equal(t, mask_id).astype(np.float32) for t in targets]
  return sum([np.sum(x * m) / (length * np.sum(m))
              for x, m in zip(inputs, unmask)])
Example #11
 def apply_fun(params, x, **kwargs):
   beta, gamma = params
   # TODO(phawkins): np.expand_dims should accept an axis tuple.
   # (https://github.com/numpy/numpy/issues/12290)
   ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
   beta = beta[ed]
   gamma = gamma[ed]
   mean, var = np.mean(x, axis, keepdims=True), fastvar(x, axis, keepdims=True)
   z = (x - mean) / np.sqrt(var + epsilon)
   if center and scale: return gamma * z + beta
   if center: return z + beta
   if scale: return gamma * z
   return z
Example #12
def Mean(x, params, axis=-1, keepdims=False, **kwargs):
    del params, kwargs
    return np.mean(x, axis=axis, keepdims=keepdims)
Example #13
def fastvar(x, axis, keepdims):
    """A fast but less numerically-stable variance calculation than np.var."""
    m1 = np.mean(x**2, axis, keepdims=keepdims)
    m2 = np.mean(x, axis, keepdims=keepdims)**2
    return m1 - m2
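As a quick sanity check (an assumed usage, not part of the original listing, with `np` as `jax.numpy` or NumPy), fastvar should agree with np.var up to floating-point error:

x = np.arange(12, dtype=np.float32).reshape(3, 4)
fastvar(x, axis=-1, keepdims=True)   # ~ np.var(x, axis=-1, keepdims=True)
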
Example #14
def crossentropy_loss(logpred, target):
    """Calculate crossentropy loss."""
    return -np.mean(
        np.sum(logpred * slax.one_hot(target, logpred.shape[-1]), axis=-1))
Example #15
def LayerNorm(x, params, epsilon=1e-6, **unused_kwargs):
    (scale, bias) = params
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean)**2, axis=-1, keepdims=True)
    norm_inputs = (x - mean) / np.sqrt(variance + epsilon)
    return norm_inputs * scale + bias
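A minimal usage sketch for the LayerNorm function above, assuming `np` is `jax.numpy`; params is a (scale, bias) pair whose shapes match the last axis of the input:

x = np.ones((2, 3, 8))
scale, bias = np.ones((8,)), np.zeros((8,))
y = LayerNorm(x, (scale, bias))   # normalized over the last axis, same shape as x
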
Example #16
 def make_unit_length(self, x, epsilon=1e-6):
     variance = np.mean(x**2, axis=-1, keepdims=True)
     norm_inputs = x / np.sqrt(variance + epsilon)
     return norm_inputs
Example #17
def Mean(x, axis=-1, keepdims=False, **unused_kwargs):
    return np.mean(x, axis=axis, keepdims=keepdims)