def _layer_norm_weights(input_signature, **unused_kwargs):
  """Builds the (scale, bias) parameter pair for a layer-norm layer.

  Args:
    input_signature: Object exposing `.shape` and `.dtype`; the size of its
      last axis is the number of features that get a scale and a bias.
    **unused_kwargs: Ignored.

  Returns:
    A tuple `(scale, bias)`: a vector of ones and a vector of zeros, each of
    length `input_signature.shape[-1]` and dtype `input_signature.dtype`.
  """
  num_features = input_signature.shape[-1]
  dtype = input_signature.dtype
  return (np.ones(num_features, dtype=dtype),
          np.zeros(num_features, dtype=dtype))
def _layer_norm_weights(input_signature, **unused_kwargs):
  """Creates the (scale, bias) parameters for layer normalization.

  Args:
    input_signature: Object with a `.shape` (and usually a `.dtype`)
      describing the layer input; the last axis gives the feature count.
    **unused_kwargs: Accepted and ignored.

  Returns:
    A tuple `(scale, bias)` of 1-D arrays of length `shape[-1]`, holding ones
    and zeros respectively. They use the input's dtype when the signature
    has one, otherwise the array library's default dtype.
  """
  features = input_signature.shape[-1]
  # Match the input dtype so the parameters do not silently differ in
  # precision from the activations; signatures without a dtype keep the
  # library default, as before.
  dtype = getattr(input_signature, 'dtype', None)
  scale = np.ones(features, dtype=dtype)
  bias = np.zeros(features, dtype=dtype)
  return scale, bias
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  """Applies `DotProductAttention` to the inputs `(q, k, v)`.

  In 'train' and 'eval' modes a lower-triangular boolean mask of shape
  `(1, L, L)`, with `L = q.shape[-2]`, is built here. In 'predict' mode the
  keys, values and mask are taken from the state returned by
  `_fast_inference_update_state`.

  Args:
    inputs: Tuple `(q, k, v)`.
    weights: Unused.
    state: Layer state; only read and replaced in 'predict' mode.
    rng: Random key forwarded to `DotProductAttention` together with the
      dropout rate.
    **kwargs: Ignored.

  Returns:
    A pair `(res, state)`: the attention output and the (possibly updated)
    state.

  Raises:
    ValueError: If `self._mode` is not 'train', 'eval' or 'predict'.
  """
  del weights, kwargs
  q, k, v = inputs
  if self._mode in ('train', 'eval'):
    mask_size = q.shape[-2]
    # Not all backends define np.tril. However, using onp.tril is inefficient
    # in that it creates a large global constant. TODO(kitaev): try to find an
    # alternative that works across all backends.
    if math.backend_name() == 'jax':
      mask = np.tril(
          np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
    else:
      mask = onp.tril(
          onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
  elif self._mode == 'predict':
    state = _fast_inference_update_state(inputs, state)
    (k, v, mask, _) = state
  else:
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let an unknown mode take the predict path
    # and overwrite the state.
    raise ValueError(
        'Unsupported mode {!r}; expected one of train, eval, predict.'.format(
            self._mode))
  res = DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
  return res, state
def forward_with_state(self, inputs, weights, state, rng):
  """Applies `DotProductAttention` to the triple `(q, k, v)`.

  When `self._mode` is 'predict', keys, values and the mask are taken from
  the state produced by `_fast_inference_update_state`. In every other mode
  a lower-triangular boolean mask of shape `(1, L, L)` is built, where
  `L = q.shape[-2]`.

  Args:
    inputs: Tuple `(q, k, v)`.
    weights: Ignored.
    state: Layer state; read and replaced only in 'predict' mode.
    rng: Random key passed through to `DotProductAttention`.

  Returns:
    A pair `(result, state)`.
  """
  del weights
  q, k, v = inputs
  if self._mode != 'predict':
    length = q.shape[-2]
    mask_shape = (1, length, length)
    # Not all backends define jnp.tril. However, using np.tril is inefficient
    # in that it creates a large global constant. TODO(kitaev): try to find an
    # alternative that works across all backends.
    lib = jnp if math.backend_name() == 'jax' else np
    mask = lib.tril(lib.ones(mask_shape, dtype=np.bool_), k=0)
  else:
    state = _fast_inference_update_state(inputs, state)
    k, v, mask, _ = state
  result = DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
  return result, state
def new_weights_and_state(self, input_signature):
  """Helper to initialize batch norm weights and state.

  `beta` (zeros) and `gamma` (ones) have the input shape with the dimensions
  in `self._axis` removed. Each is `()` instead when `self._center` or
  `self._scale`, respectively, is false. `running_mean` (zeros) and
  `running_var` (ones) keep every input dimension, but with size 1 along
  `self._axis`. `n_batches` is an int64 scalar zero.

  Args:
    input_signature: Object with a `.shape` describing the layer input.

  Returns:
    A pair `(weights, state)` with `weights = (beta, gamma)` and
    `state = (running_mean, running_var, n_batches)`.
  """
  input_shape = input_signature.shape
  ndim = len(input_shape)
  axis = self._axis
  axis = (axis,) if np.isscalar(axis) else axis
  # Normalize negative axes (e.g. -1) to non-negative indices. Otherwise they
  # match no position in `input_shape`, and the parameter/statistics shapes
  # silently keep the dimensions that should be reduced.
  axis = tuple(a + ndim if a < 0 else a for a in axis)
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  beta = np.zeros(shape, dtype='float32') if self._center else ()
  gamma = np.ones(shape, dtype='float32') if self._scale else ()
  stats_shape = tuple(1 if i in axis else d for i, d in enumerate(input_shape))
  running_mean = np.zeros(stats_shape, dtype=np.float32)
  running_var = np.ones(stats_shape, dtype=np.float32)
  n_batches = np.zeros((), dtype=np.int64)
  weights = (beta, gamma)
  state = (running_mean, running_var, n_batches)
  return weights, state
def new_weights(self, input_signature):
  """Creates the per-channel parameters for this normalization layer.

  Args:
    input_signature: Object with a `.shape`; its last dimension is taken as
      the channel count (the input is usually (B, W, H, C)).

  Returns:
    A tuple `(gamma, beta, epsilon_l)`:
      - `gamma`: float32 ones of shape `(C,)`.
      - `beta`: float32 zeros of shape `(C,)`.
      - `epsilon_l`: `(self._init_learnt_epsilon,)` when `self._learn_epsilon`
        is truthy, otherwise `base.EMPTY_WEIGHTS`.
  """
  channels = input_signature.shape[-1]
  gamma = np.ones((channels,), dtype=np.float32)
  beta = np.zeros((channels,), dtype=np.float32)
  if self._learn_epsilon:
    epsilon_l = (self._init_learnt_epsilon,)
  else:
    epsilon_l = base.EMPTY_WEIGHTS
  return gamma, beta, epsilon_l
def init_weights_and_state(self, input_signature):
  """Helper to initialize batch norm weights and state.

  Sets `self.weights = (beta, gamma)` and
  `self.state = (running_mean, running_var, n_batches)`.

  `beta` (zeros) and `gamma` (ones) have the input shape with the dimensions
  in `self._axis` removed. Each is `()` instead when `self._center` or
  `self._scale`, respectively, is false. `running_mean` (zeros) and
  `running_var` (ones) keep every input dimension, but with size 1 along
  `self._axis`. `n_batches` is a scalar zero.

  Args:
    input_signature: Object with a `.shape` describing the layer input.
  """
  input_shape = input_signature.shape
  ndim = len(input_shape)
  axis = self._axis
  axis = (axis,) if jnp.isscalar(axis) else axis
  # Normalize negative axes (e.g. -1) to non-negative indices. Otherwise they
  # match no position in `input_shape`, and the parameter/statistics shapes
  # silently keep the dimensions that should be reduced.
  axis = tuple(a + ndim if a < 0 else a for a in axis)
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  # TODO(jonni): Should beta and gamma match the dtype in the input signature?
  beta = jnp.zeros(shape, dtype='float32') if self._center else ()
  gamma = jnp.ones(shape, dtype='float32') if self._scale else ()
  stats_shape = tuple(1 if i in axis else d for i, d in enumerate(input_shape))
  running_mean = jnp.zeros(stats_shape, dtype=jnp.float32)
  running_var = jnp.ones(stats_shape, dtype=jnp.float32)
  n_batches = jnp.zeros((), dtype=jnp.int64)
  self.weights = (beta, gamma)
  self.state = (running_mean, running_var, n_batches)
def new_weights(self, input_signature):
  """Returns the `(scale, bias)` parameters for layer normalization.

  Both are 1-D arrays whose length is the last dimension of
  `input_signature.shape` and whose dtype is `input_signature.dtype`.
  `scale` is all ones and `bias` is all zeros.
  """
  n_features = input_signature.shape[-1]
  dtype = input_signature.dtype
  ones = jnp.ones(n_features, dtype=dtype)
  zeros = jnp.zeros(n_features, dtype=dtype)
  return ones, zeros
def init_weights_and_state(self, input_signature):
  """Sets fixed-shape weights and an input-shaped state on the layer.

  `self.weights` becomes a (2, 3) array of zeros, and `self.state` becomes
  an array of ones with shape `input_signature.shape`.
  """
  zero_weights = jnp.zeros((2, 3))
  self.weights = zero_weights
  self.state = jnp.ones(input_signature.shape)
def init_weights_and_state(self, input_signature):
  """Initializes the layer-norm parameters stored on `self.weights`.

  `self.weights` becomes the tuple `(scale, bias)`: ones and zeros of length
  `input_signature.shape[-1]`, both in `input_signature.dtype`. This method
  does not touch `self.state`.
  """
  depth = input_signature.shape[-1]
  dtype = input_signature.dtype
  self.weights = (jnp.ones(depth, dtype=dtype), jnp.zeros(depth, dtype=dtype))