def new_weights_and_state(self, input_signature):
  """Helper to initialize batch norm weights.

  Args:
    input_signature: An object whose `.shape` attribute gives the shape of
      this layer's input.

  Returns:
    A `(weights, state)` pair:
      - `weights` is `(beta, gamma)`; `beta`/`gamma` are float32 arrays
        shaped like the input with the normalized axes removed, or `()`
        when centering/scaling is disabled.
      - `state` is `(running_mean, running_var, n_batches)`; the running
        stats keep a size-1 slot along each normalized axis so they
        broadcast against the input, and `n_batches` is an int64 scalar.
  """
  axis = self._axis
  axis = (axis,) if np.isscalar(axis) else axis
  input_shape = input_signature.shape
  # Canonicalize axes so negative indices (e.g. -1 for the last dim) select
  # the intended dimensions; enumerate() below yields only non-negative
  # indices, so an un-normalized negative axis would match nothing.
  axis = tuple(a % len(input_shape) for a in axis)
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  beta = np.zeros(shape, dtype='float32') if self._center else ()
  gamma = np.ones(shape, dtype='float32') if self._scale else ()

  def get_stats_axis(i, d):
    # Normalized axes collapse to 1 so the stats broadcast over them.
    if i in axis:
      return 1
    else:
      return d

  stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape))
  running_mean = np.zeros(stats_shape, dtype=np.float32)
  running_var = np.ones(stats_shape, dtype=np.float32)
  n_batches = np.zeros((), dtype=np.int64)
  weights = (beta, gamma)
  state = (running_mean, running_var, n_batches)
  return weights, state
def init_weights_and_state(self, input_signature):
  """Initializes this layer's weights and state for batch normalization.

  Sets `self.weights` to `(beta, gamma)` (each `()` when centering/scaling
  is disabled) and `self.state` to `(running_mean, running_var, n_batches)`.
  """
  axes = self._axis
  if jnp.isscalar(axes):
    axes = (axes,)
  input_shape = input_signature.shape
  reduced_shape = tuple(
      dim for idx, dim in enumerate(input_shape) if idx not in axes)
  # TODO(jonni): Should beta and gamma match the dtype in the input signature?
  beta = jnp.zeros(reduced_shape, dtype='float32') if self._center else ()
  gamma = jnp.ones(reduced_shape, dtype='float32') if self._scale else ()
  # Running statistics keep a size-1 slot along each normalized axis so
  # they broadcast against the input.
  stats_shape = tuple(
      1 if idx in axes else dim for idx, dim in enumerate(input_shape))
  running_mean = jnp.zeros(stats_shape, dtype=jnp.float32)
  running_var = jnp.ones(stats_shape, dtype=jnp.float32)
  n_batches = jnp.zeros((), dtype=jnp.int64)
  self.weights = (beta, gamma)
  self.state = (running_mean, running_var, n_batches)