Example #1
0
def BatchNorm(x,
              params,
              axis=(0, 1, 2),
              epsilon=1e-5,
              center=True,
              scale=True,
              **unused_kwargs):
    """Batch-normalize `x` over `axis`, then shift/scale it by `params`.

    Args:
        x: input array to normalize.
        params: pair `(beta, gamma)` of shift and scale parameters, indexed
            along the non-reduced axes of `x`.
        axis: axes over which the batch statistics are computed.
        epsilon: small constant added to the variance for numerical safety.
        center: if True, add `beta` to the normalized result.
        scale: if True, multiply the normalized result by `gamma`.
        **unused_kwargs: ignored; accepted for interface compatibility.

    Returns:
        The normalized (and optionally shifted/scaled) array.
    """
    # E[x^2] - E[x]^2 is faster than np.var but less numerically stable.
    mu = np.mean(x, axis, keepdims=True)
    second_moment = np.mean(x**2, axis, keepdims=True)
    normalized = (x - mu) / np.sqrt(second_moment - mu**2 + epsilon)

    beta, gamma = params
    # np.expand_dims does not accept an axis tuple
    # (https://github.com/numpy/numpy/issues/12290), so build an indexing
    # expression that inserts a singleton dimension at every reduced axis.
    expander = tuple(
        None if dim in axis else slice(None) for dim in range(np.ndim(x)))
    beta, gamma = beta[expander], gamma[expander]

    # Apply the requested affine transform to the normalized values.
    if not (center or scale):
        return normalized
    if not scale:
        return normalized + beta
    if not center:
        return gamma * normalized
    return gamma * normalized + beta
Example #2
0
def BatchNorm(x, params, axis=(0, 1, 2), epsilon=1e-5,
              center=True, scale=True, **unused_kwargs):
  """Batch-normalize `x` over `axis`, preserving its dtype.

  Args:
    x: input array to normalize; the output keeps `x.dtype`.
    params: pair `(beta, gamma)` of shift and scale parameters, indexed
      along the non-reduced axes of `x`.
    axis: axes over which the batch statistics are computed.
    epsilon: small constant added to the variance for numerical safety.
    center: if True, add `beta` to the normalized result.
    scale: if True, multiply the normalized result by `gamma`.
    **unused_kwargs: ignored; accepted for interface compatibility.

  Returns:
    The normalized (and optionally shifted/scaled) array, same dtype as `x`.
  """
  mu = np.mean(x, axis, keepdims=True)
  # E[x^2] - E[x]^2 is faster than np.var but less numerically stable.
  second_moment = np.mean(x**2, axis, keepdims=True)
  # x mustn't be onp.ndarray here; otherwise `x-mean` will call mean.__rsub__
  # with each element of x, resulting in an onp.ndarray with dtype `object`.
  # The cast keeps the division from up-casting the result's dtype.
  z = (x - mu) / np.sqrt(second_moment - mu**2 + epsilon).astype(x.dtype)

  beta, gamma = params
  # np.expand_dims does not accept an axis tuple
  # (https://github.com/numpy/numpy/issues/12290), so build an indexing
  # expression that inserts a singleton dimension at every reduced axis.
  expander = tuple(
      None if dim in axis else slice(None) for dim in range(np.ndim(x)))
  beta, gamma = beta[expander], gamma[expander]

  # Apply the requested affine transform to the normalized values.
  if center:
    ret = gamma * z + beta if scale else z + beta
  else:
    ret = gamma * z if scale else z
  assert ret.dtype == x.dtype, ('The dtype of the output (%s) of batch norm is '
                                'not the same as the input (%s). Batch norm '
                                'should not change the dtype' %
                                (ret.dtype, x.dtype))
  return ret
Example #3
0
    def call(self, x, params, state, **unused_kwargs):
        """Batch-normalize `x`, maintaining running statistics in `state`.

        In 'train' mode the batch mean/variance are computed from `x` and
        folded into the running statistics; in any other mode the stored
        running statistics are used instead and `state` is returned unchanged.

        Args:
          x: input array to normalize; the output keeps `x.dtype`.
          params: pair `(beta, gamma)` of shift and scale parameters, indexed
            along the axes not in `self._axis`.
          state: triple `(running_mean, running_var, num_batches)` carried
            across calls.
          **unused_kwargs: ignored; accepted for interface compatibility.

        Returns:
          A pair `(output, state)`: the normalized array and the (possibly
          updated) statistics triple.
        """

        running_mean, running_var, num_batches = state

        if self._mode == 'train':
            mean = np.mean(x, self._axis, keepdims=True)
            # Fast but less numerically-stable variance calculation than np.var.
            m1 = np.mean(x**2, self._axis, keepdims=True)
            var = m1 - mean**2
            num_batches = num_batches + 1
            if self._momentum is None:
                # A simple average over all batches seen so far
                exponential_average_factor = 1.0 / num_batches
            else:
                exponential_average_factor = self._momentum

            def average(factor, new, old):
                # Exponential moving average; cast back so the running
                # statistics keep their original dtype.
                return (factor * new + (1 - factor) * old).astype(old.dtype)

            running_mean = average(exponential_average_factor, mean,
                                   running_mean)
            running_var = average(exponential_average_factor, var, running_var)
            state = (running_mean, running_var, num_batches)
        else:
            # Eval/inference: normalize with the stored running statistics.
            mean = running_mean
            var = running_var

        # Casts keep the arithmetic (and hence the output) in x's dtype even
        # when the running statistics are stored at a different precision.
        z = (x - mean.astype(x.dtype)) / np.sqrt(var + self._epsilon).astype(
            x.dtype)

        # Expand the parameters to have the right axes.
        beta, gamma = params
        # TODO(phawkins): np.expand_dims should accept an axis tuple.
        # (https://github.com/numpy/numpy/issues/12290)
        ed = tuple(None if i in self._axis else slice(None)
                   for i in range(np.ndim(x)))
        beta = beta[ed]
        gamma = gamma[ed]

        # Return the z rescaled by the parameters if requested.
        if self._center and self._scale:
            output = gamma * z + beta
        elif self._center:
            output = z + beta
        elif self._scale:
            output = gamma * z
        else:
            output = z
        assert output.dtype == x.dtype, (
            'The dtype of the output (%s) of batch '
            'norm is not the same as the input (%s). '
            'Batch norm should not change the dtype' % (output.dtype, x.dtype))
        return output, state
Example #4
0
 def apply_fun(params, x, **kwargs):
   """Normalize `x` with batch statistics, then shift/scale by `params`.

   Uses `axis`, `epsilon`, `center`, `scale` and `fastvar` from the
   enclosing scope.
   """
   beta, gamma = params
   # np.expand_dims does not accept an axis tuple
   # (https://github.com/numpy/numpy/issues/12290), so build an indexing
   # expression that inserts a singleton dimension at every reduced axis.
   expander = tuple(
       None if dim in axis else slice(None) for dim in range(np.ndim(x)))
   beta, gamma = beta[expander], gamma[expander]
   mu = np.mean(x, axis, keepdims=True)
   sigma2 = fastvar(x, axis, keepdims=True)
   z = (x - mu) / np.sqrt(sigma2 + epsilon)
   # Apply the requested affine transform to the normalized values.
   if not (center or scale):
     return z
   if not scale:
     return z + beta
   if not center:
     return gamma * z
   return gamma * z + beta