def BatchNorm(x, params, axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              **unused_kwargs):
  """Layer construction function for a batch normalization layer.

  Normalizes `x` to zero mean and unit variance over `axis`, then optionally
  rescales/shifts with the learned parameters.

  Args:
    x: Input array to normalize.
    params: Pair `(beta, gamma)`; each is indexed so it broadcasts against the
      non-normalized axes of `x`.
    axis: Tuple of axes to compute the moments over.
    epsilon: Small constant added to the variance for numerical stability.
    center: If True, add `beta` to the normalized value.
    scale: If True, multiply the normalized value by `gamma`.
    **unused_kwargs: Accepted and ignored.

  Returns:
    The normalized (and optionally scaled/shifted) array, same shape as `x`.
  """
  mean = np.mean(x, axis, keepdims=True)
  # Fast but less numerically-stable variance calculation than np.var.
  m1 = np.mean(x**2, axis, keepdims=True)
  var = m1 - mean**2
  # Cast the denominator to x.dtype so normalization cannot silently change
  # the output dtype — this matches the dtype-preserving sibling variants in
  # this file (which assert `ret.dtype == x.dtype`). For plain numpy float
  # inputs this cast is a no-op.
  z = (x - mean) / np.sqrt(var + epsilon).astype(x.dtype)
  # Expand the parameters to have the right axes.
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]
  # Return the z rescaled by the parameters if requested.
  if center and scale:
    return gamma * z + beta
  if center:
    return z + beta
  if scale:
    return gamma * z
  return z
def BatchNorm(x, params, axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              **unused_kwargs):
  """Layer construction function for a batch normalization layer."""
  beta, gamma = params
  batch_mean = np.mean(x, axis, keepdims=True)
  # Fast but less numerically-stable variance: E[x^2] - E[x]^2 (vs np.var).
  second_moment = np.mean(x**2, axis, keepdims=True)
  batch_var = second_moment - batch_mean**2
  # x mustn't be onp.ndarray here; otherwise `x - batch_mean` will call
  # mean.__rsub__ with each element of x, resulting in an onp.ndarray with
  # dtype `object`. The cast keeps the output dtype equal to the input dtype.
  normed = (x - batch_mean) / np.sqrt(batch_var + epsilon).astype(x.dtype)
  # np.expand_dims doesn't accept an axis tuple
  # (https://github.com/numpy/numpy/issues/12290), so build an indexing tuple
  # that inserts singleton dims on the normalized axes by hand.
  expand = tuple(None if d in axis else slice(None) for d in range(np.ndim(x)))
  beta, gamma = beta[expand], gamma[expand]
  # Rescale/shift by the parameters if requested.
  if scale:
    ret = gamma * normed + beta if center else gamma * normed
  else:
    ret = normed + beta if center else normed
  assert ret.dtype == x.dtype, ('The dtype of the output (%s) of batch norm is '
                                'not the same as the input (%s). Batch norm '
                                'should not change the dtype' % (ret.dtype,
                                                                 x.dtype))
  return ret
def call(self, x, params, state, **unused_kwargs):
  """Layer construction function for a batch normalization layer.

  In 'train' mode, normalizes with the current batch's moments and folds them
  into the running statistics carried in `state`; otherwise normalizes with
  the stored running statistics and leaves `state` untouched.

  Returns:
    A pair `(output, state)` where `output` has the same shape and dtype as
    `x` and `state` is `(running_mean, running_var, num_batches)`.
  """
  avg_mean, avg_var, seen = state
  if self._mode == 'train':
    mean = np.mean(x, self._axis, keepdims=True)
    # Fast but less numerically-stable variance than np.var: E[x^2] - E[x]^2.
    var = np.mean(x**2, self._axis, keepdims=True) - mean**2
    seen = seen + 1
    # Cumulative average over all batches seen so far when no momentum is
    # configured; exponential moving average otherwise.
    if self._momentum is None:
      factor = 1.0 / seen
    else:
      factor = self._momentum

    def blend(new, old):
      # Keep running statistics in their original dtype.
      return (factor * new + (1 - factor) * old).astype(old.dtype)

    avg_mean = blend(mean, avg_mean)
    avg_var = blend(var, avg_var)
    state = (avg_mean, avg_var, seen)
  else:
    mean, var = avg_mean, avg_var
  z = (x - mean.astype(x.dtype)) / np.sqrt(var + self._epsilon).astype(x.dtype)
  # np.expand_dims doesn't accept an axis tuple
  # (https://github.com/numpy/numpy/issues/12290), so index with None/slice to
  # give the parameters broadcastable shapes.
  beta, gamma = params
  expand = tuple(None if d in self._axis else slice(None)
                 for d in range(np.ndim(x)))
  beta, gamma = beta[expand], gamma[expand]
  # Rescale/shift by the parameters if requested.
  if self._center:
    output = gamma * z + beta if self._scale else z + beta
  else:
    output = gamma * z if self._scale else z
  assert output.dtype == x.dtype, (
      'The dtype of the output (%s) of batch '
      'norm is not the same as the input (%s). '
      'Batch norm should not change the dtype' % (output.dtype, x.dtype))
  return output, state
def apply_fun(params, x, **kwargs):
  """Applies batch normalization to `x` using the closed-over settings.

  Relies on `axis`, `epsilon`, `center`, `scale` and `fastvar` from the
  enclosing scope; `params` is the `(beta, gamma)` pair.
  """
  beta, gamma = params
  # np.expand_dims doesn't accept an axis tuple
  # (https://github.com/numpy/numpy/issues/12290), so build an indexing tuple
  # that inserts singleton dims on the normalized axes by hand.
  expand = tuple(None if d in axis else slice(None) for d in range(np.ndim(x)))
  beta, gamma = beta[expand], gamma[expand]
  mean = np.mean(x, axis, keepdims=True)
  var = fastvar(x, axis, keepdims=True)
  z = (x - mean) / np.sqrt(var + epsilon)
  # Rescale/shift by the parameters if requested (guard-clause form).
  if not center and not scale:
    return z
  if not scale:
    return z + beta
  if not center:
    return gamma * z
  return gamma * z + beta