def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
    """Layer construction function for a batch normalization layer.

    Args:
      axis: int or tuple of ints. The input axes reduced over when computing
        the batch mean and variance. The learned ``beta`` / ``gamma``
        parameters have one entry per input axis NOT listed here.
      epsilon: constant added to the variance before taking the square root.
      center: if True, a learned offset ``beta`` is added to the output.
        If False, ``beta`` is stored as ``()`` and never used.
      scale: if True, the output is multiplied by a learned ``gamma``.
        If False, ``gamma`` is stored as ``()`` and never used.
      beta_init: initializer called as ``beta_init(rng, shape)``.
      gamma_init: initializer called as ``gamma_init(rng, shape)``.

    Returns:
      An ``(init_fun, apply_fun)`` pair.
    """
    _beta_init = lambda rng, shape: beta_init(rng, shape) if center else ()
    _gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else ()
    axis = (axis,) if np.isscalar(axis) else axis

    def init_fun(rng, input_shape):
        # Parameter shape: the input shape with the reduced axes dropped.
        shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
        # NOTE(review): the same `rng` is passed to both initializers. If either
        # one is random, beta and gamma are drawn from the same key. Consider
        # splitting the key.
        beta, gamma = _beta_init(rng, shape), _gamma_init(rng, shape)
        return input_shape, (beta, gamma)

    def apply_fun(params, x, **kwargs):
        beta, gamma = params
        # TODO(phawkins): np.expand_dims should accept an axis tuple.
        # (https://github.com/numpy/numpy/issues/12290)
        # Re-insert the reduced axes as size-1 dims so beta/gamma broadcast
        # against x.
        ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
        mean = np.mean(x, axis, keepdims=True)
        # fastvar is defined elsewhere; presumably the variance over `axis`.
        var = fastvar(x, axis, keepdims=True)
        z = (x - mean) / np.sqrt(var + epsilon)
        # Index a parameter only where it is actually used. A disabled
        # parameter is the empty tuple (), and `()[ed]` raises TypeError.
        if center and scale:
            return gamma[ed] * z + beta[ed]
        if center:
            return z + beta[ed]
        if scale:
            return gamma[ed] * z
        return z

    return init_fun, apply_fun
def _batch_norm_new_params(input_shape, rng, axis=(0, 1, 2), center=True, scale=True, **kwargs): """Helper to initialize batch norm params.""" del rng, kwargs axis = (axis,) if np.isscalar(axis) else axis shape = tuple(d for i, d in enumerate(input_shape) if i not in axis) beta = np.zeros(shape, dtype='float32') if center else () gamma = np.ones(shape, dtype='float32') if scale else () return (beta, gamma)
def masked_mean(inputs, targets, weights, mask_id=None):
    """Weighted mean of the inputs, excluding where targets == mask_id.

    ``inputs``, ``targets`` and ``weights`` are parallel sequences. For the
    i-th triple ``(x, t, w)``, the weighted sum ``sum(x * w)`` is divided by
    ``len(inputs) * total_weight``. The per-element results are then added
    together.

    If ``mask_id`` is not None, each weight is multiplied by 0 wherever the
    corresponding target equals ``mask_id`` and by 1 elsewhere.

    The total weight is ``sum(w)`` for array weights. For a scalar weight it is
    ``t.size``, so the scalar's value does not enter the denominator. A scalar
    weight other than 1 therefore scales the result.

    Args:
      inputs: sequence of arrays; each is cast to float32.
      targets: sequence of arrays compared against ``mask_id``.
      weights: sequence of weights (scalars or arrays).
      mask_id: optional target value to exclude.

    Returns:
      The sum of the per-element weighted means.
    """
    as_float = [x.astype(np.float32) for x in inputs]
    # We assume all elements in the list contribute equally.
    # TODO(lukaszkaiser): remove this assumption (e.g., when masks differ).
    n_parts = len(as_float)
    if mask_id is not None:
        weights = [
            w * (1.0 - np.equal(t, mask_id).astype(np.float32))
            for t, w in zip(targets, weights)
        ]
    denominators = []
    for w, t in zip(weights, targets):
        if np.isscalar(w):
            denominators.append(np.float32(t.size))
        else:
            denominators.append(np.sum(w))
    return sum(
        np.sum(x * w) / (n_parts * d)
        for x, w, d in zip(as_float, weights, denominators)
    )
def new_parameters(self, input_shape, input_dtype, rng):
    """Helper to initialize batch norm params.

    Reads ``self._axis`` (an int or tuple of ints naming the reduced axes),
    ``self._center`` and ``self._scale``.

    Returns:
      A pair ``((beta, gamma), (running_mean, running_var, num_batches))``:
        - ``beta``: float32 zeros over the non-reduced dims, or ``()`` if not
          ``self._center``.
        - ``gamma``: float32 ones over the non-reduced dims, or ``()`` if not
          ``self._scale``.
        - ``running_mean`` / ``running_var``: float32 zeros / ones shaped like
          ``input_shape`` with every reduced dim replaced by 1.
        - ``num_batches``: an int32 scalar array holding 0.
    """
    del input_dtype, rng  # The initial values are constants.
    reduced = self._axis
    if np.isscalar(reduced):
        reduced = (reduced,)
    param_shape = tuple(
        size for dim, size in enumerate(input_shape) if dim not in reduced
    )
    if self._center:
        beta = np.zeros(param_shape, dtype='float32')
    else:
        beta = ()
    if self._scale:
        gamma = np.ones(param_shape, dtype='float32')
    else:
        gamma = ()
    # Keep reduced dims as size 1 so the statistics broadcast against inputs.
    stats_shape = tuple(
        1 if dim in reduced else size for dim, size in enumerate(input_shape)
    )
    running_stats = (
        np.zeros(stats_shape, dtype=np.float32),
        np.ones(stats_shape, dtype=np.float32),
        np.zeros((), dtype=np.int32),
    )
    return (beta, gamma), running_stats