def _layer_norm_weights(input_signature): """Helper: create layer norm parameters.""" features = input_signature.shape[-1] scale = np.ones(features) bias = np.zeros(features) weights = (scale, bias) return weights
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
    """Apply dot-product attention to a ``(q, k, v)`` triple.

    In ``'train'``/``'eval'`` mode a lower-triangular boolean mask of shape
    ``(1, L, L)`` is built, where ``L = q.shape[-2]``. In ``'predict'`` mode
    the state is first updated with ``_fast_inference_update_state``, and the
    keys, values and mask are then taken from that updated state.

    Args:
      inputs: A ``(q, k, v)`` triple.
      weights: Unused; discarded immediately.
      state: Layer state; replaced only in ``'predict'`` mode.
      rng: Forwarded to ``DotProductAttention`` together with the layer's
        dropout rate.
      **kwargs: Unused.

    Returns:
      ``(result, state)``, where ``result`` is the value returned by
      ``DotProductAttention`` and ``state`` is the (possibly updated) state.
    """
    del weights
    q, k, v = inputs
    if self._mode not in ('train', 'eval'):
        assert self._mode == 'predict'
        state = _fast_inference_update_state(inputs, state)
        # Keys, values and mask are read back out of the updated state.
        (k, v, mask, _) = state
    else:
        seq_len = q.shape[-2]
        # Not all backends define np.tril. However, using onp.tril is
        # inefficient in that it creates a large global constant.
        # TODO(kitaev): try to find an alternative that works across all
        # backends.
        xp = np if backend.get_name() == 'jax' else onp
        mask = xp.tril(xp.ones((1, seq_len, seq_len), dtype=onp.bool_), k=0)
    output = DotProductAttention(q, k, v, mask, dropout=self._dropout,
                                 mode=self._mode, rng=rng)
    return output, state
def _layer_norm_params_and_state(input_shape, input_dtype, rng): """Helper: create layer norm parameters.""" del input_dtype, rng features = input_shape[-1] scale = np.ones(features) bias = np.zeros(features) params = (scale, bias) return params, ()
def new_weights(self, input_signature):
    """Create the initial ``(gamma, beta, epsilon_weights)`` weights.

    Args:
      input_signature: Object whose ``shape`` attribute is the input shape;
        its last dimension is taken as the channel count.

    Returns:
      A 3-tuple:
        - ``gamma``: float32 ones of shape ``(channels,)``.
        - ``beta``: float32 zeros of shape ``(channels,)``.
        - ``epsilon_weights``: ``(self._init_learnt_epsilon,)`` when
          ``self._learn_epsilon`` is truthy, otherwise ``base.EMPTY_WEIGHTS``.
    """
    # Usually a (B, W, H, C) input; only the trailing channel axis is used.
    channels = input_signature.shape[-1]
    gamma = np.ones((channels,), dtype=np.float32)
    beta = np.zeros((channels,), dtype=np.float32)
    epsilon_weights = base.EMPTY_WEIGHTS
    if self._learn_epsilon:
        # Epsilon is trainable: expose its initial value as a weight.
        epsilon_weights = (self._init_learnt_epsilon,)
    return gamma, beta, epsilon_weights
def new_weights_and_state(self, input_signature):
    """Helper to initialize batch norm weights and running statistics.

    The normalization axes come from ``self._axis``, which may be a single
    int or a sequence of ints. Negative axes count from the end, as usual.

    Args:
      input_signature: Object whose ``shape`` attribute is the input shape.

    Returns:
      A ``(weights, state)`` pair.

      ``weights`` is ``(beta, gamma)``. Each is shaped like the input with
      the normalization axes removed:
        - ``beta`` is float32 zeros if ``self._center`` is truthy, else ``()``.
        - ``gamma`` is float32 ones if ``self._scale`` is truthy, else ``()``.

      ``state`` is ``(running_mean, running_var, n_batches)``:
        - ``running_mean`` is float32 zeros and ``running_var`` float32 ones,
          both shaped like the input but with size 1 on every normalization
          axis.
        - ``n_batches`` is an int64 scalar counter starting at 0.
    """
    input_shape = input_signature.shape
    ndim = len(input_shape)
    axis = self._axis
    axis = (axis,) if np.isscalar(axis) else axis
    # Map negative axes (e.g. -1 for the last dimension) to their
    # non-negative index. Without this they never match the indices from
    # enumerate() below and would be silently ignored.
    axis = tuple(a + ndim if a < 0 else a for a in axis)

    shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
    beta = np.zeros(shape, dtype='float32') if self._center else ()
    gamma = np.ones(shape, dtype='float32') if self._scale else ()

    # Keep every dimension, but collapse the normalization axes to size 1.
    stats_shape = tuple(1 if i in axis else d
                        for i, d in enumerate(input_shape))
    running_mean = np.zeros(stats_shape, dtype=np.float32)
    running_var = np.ones(stats_shape, dtype=np.float32)
    n_batches = np.zeros((), dtype=np.int64)

    weights = (beta, gamma)
    state = (running_mean, running_var, n_batches)
    return weights, state