Example #1
0
def BatchNorm(axis=(0, 1, 2),
              epsilon=1e-5,
              center=True,
              scale=True,
              beta_init=zeros,
              gamma_init=ones):
    """Layer construction function for a batch normalization layer.

    Args:
      axis: int or tuple of ints; the input dimensions over which mean and
        variance are computed. These dimensions get no entry in beta/gamma.
      epsilon: small constant added to the variance before the square root.
      center: if True, add a learned offset ``beta`` after normalizing.
      scale: if True, multiply by a learned factor ``gamma`` after
        normalizing.
      beta_init: initializer called as ``beta_init(rng, shape)``.
      gamma_init: initializer called as ``gamma_init(rng, shape)``.

    Returns:
      An ``(init_fun, apply_fun)`` pair. ``init_fun(rng, input_shape)``
      returns ``(input_shape, (beta, gamma))``; a disabled parameter is the
      empty tuple ``()``.
    """
    _beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
    _gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
    axis = (axis, ) if np.isscalar(axis) else axis

    def init_fun(rng, input_shape):
        # Parameters are created only for the dimensions NOT normalized over.
        shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
        beta, gamma = _beta_init(rng, shape), _gamma_init(rng, shape)
        return input_shape, (beta, gamma)

    def apply_fun(params, x, **kwargs):
        beta, gamma = params
        # TODO(phawkins): np.expand_dims should accept an axis tuple.
        # (https://github.com/numpy/numpy/issues/12290)
        # Index that inserts a size-1 dimension (None) at every normalized
        # axis, so the reduced-shape beta/gamma line up with x.
        ed = tuple(None if i in axis else slice(None)
                   for i in range(np.ndim(x)))
        mean = np.mean(x, axis, keepdims=True)
        var = fastvar(x, axis, keepdims=True)
        z = (x - mean) / np.sqrt(var + epsilon)
        # A disabled beta/gamma is the empty tuple (), which cannot be indexed
        # with `ed` (TypeError), so only expand the parameters actually used.
        if center and scale: return gamma[ed] * z + beta[ed]
        if center: return z + beta[ed]
        if scale: return gamma[ed] * z
        return z

    return init_fun, apply_fun
Example #2
0
def _batch_norm_new_params(input_shape, rng, axis=(0, 1, 2),
                           center=True, scale=True, **kwargs):
  """Helper to initialize batch norm params."""
  del rng, kwargs
  axis = (axis,) if np.isscalar(axis) else axis
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  beta = np.zeros(shape, dtype='float32') if center else ()
  gamma = np.ones(shape, dtype='float32') if scale else ()
  return (beta, gamma)
Example #3
0
def masked_mean(inputs, targets, weights, mask_id=None):
  """Weighted mean of the inputs, excluding where targets == mask_id.

  Each list entry k contributes
  ``sum(inputs[k] * weights[k]) / (len(inputs) * norm_k)``, where ``norm_k``
  is ``targets[k].size`` when the (possibly masked) weight is a scalar and
  ``sum(weights[k])`` otherwise. The contributions are summed.

  Args:
    inputs: list of arrays; each is cast to float32.
    targets: list of arrays paired with ``inputs``. Their ``size`` is used,
      and their values are used when ``mask_id`` is given.
    weights: list of scalar or array weights paired with ``inputs``.
    mask_id: if not None, positions where the target equals this value get
      weight zero.

  Returns:
    The summed per-entry contributions described above.
  """
  as_float = [x.astype(np.float32) for x in inputs]
  # Every list entry contributes equally, regardless of its element count.
  # TODO(lukaszkaiser): remove this assumption (e.g., when masks differ).
  count = len(as_float)
  if mask_id is not None:
    # Zero the weight wherever the target equals the mask id.
    weights = [w * (1.0 - np.equal(t, mask_id).astype(np.float32))
               for w, t in zip(weights, targets)]
  total = 0
  for x, t, w in zip(as_float, targets, weights):
    # NOTE(review): with a scalar weight the denominator is the element count,
    # not w * count, so a scalar w != 1 scales the result — confirm intended.
    norm = np.float32(t.size) if np.isscalar(w) else np.sum(w)
    total = total + np.sum(x * w) / (count * norm)
  return total
    def new_parameters(self, input_shape, input_dtype, rng):
        """Helper to initialize batch norm params.

        Args:
          input_shape: shape of the layer input.
          input_dtype: unused.
          rng: unused.

        Returns:
          ``((beta, gamma), (running_mean, running_var, num_batches))``.
          ``beta`` (float32 zeros) and ``gamma`` (float32 ones) cover the
          dimensions not in ``self._axis``. Each is ``()`` when
          ``self._center`` / ``self._scale`` is false. The running statistics
          are float32 arrays with the input's rank, with every ``self._axis``
          dimension set to 1: zeros for the mean, ones for the variance.
          ``num_batches`` is an int32 scalar zero.
        """
        del input_dtype, rng  # Initialization here is deterministic.
        axes = self._axis
        if np.isscalar(axes):
            axes = (axes, )
        # Learned parameters exist only for the dimensions not normalized over.
        param_shape = tuple(
            size for dim, size in enumerate(input_shape) if dim not in axes)
        beta = np.zeros(param_shape, dtype='float32') if self._center else ()
        gamma = np.ones(param_shape, dtype='float32') if self._scale else ()
        # Running statistics keep the input's rank; normalized dims become 1.
        stats_shape = tuple(1 if dim in axes else size
                            for dim, size in enumerate(input_shape))
        running_mean = np.zeros(stats_shape, dtype=np.float32)
        running_var = np.ones(stats_shape, dtype=np.float32)
        num_batches = np.zeros((), dtype=np.int32)
        params = (beta, gamma)
        state = (running_mean, running_var, num_batches)
        return params, state