def _layer_norm_params(input_shape, input_dtype, rng):
  """Helper: create layer norm parameters."""
  del input_dtype, rng
  features = input_shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  return (scale, bias)

def LayerNormParams(input_shape, input_dtype, rng, epsilon=1e-6):
  """Helper: create layer norm parameters."""
  del input_dtype, rng, epsilon
  features = input_shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  return (scale, bias)

def _layer_norm_new_params(input_shape, rng, epsilon=1e-6):  # pylint: disable=invalid-name
  """Helper: create layer norm parameters."""
  del rng, epsilon
  features = input_shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  return (scale, bias)

def EncoderDecoderMask(x, **unused_kwargs):
  """Make encoder-decoder mask from a padding mask and decoder input."""
  (padding_mask, decoder_input) = x
  padding_mask = np.reshape(
      padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
  # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
  # Add zeros (not ones) so the broadcast fixes the shape without unmasking
  # the padded positions.
  return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))

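# Hedged usage sketch (not part of the original source): checks only the
# broadcast shape documented in the comment above. Assumes `np` is the
# backend numpy already imported in this module.
def _encoder_decoder_mask_shape_example():
  padding_mask = np.array([[1., 1., 1., 0., 0.]])  # [batch=1, encoder-len=5]
  decoder_input = np.zeros((1, 3, 8))  # [batch=1, decoder-len=3, depth=8]
  mask = EncoderDecoderMask((padding_mask, decoder_input))
  # Broadcasting against zeros fixes the shape but keeps the padding pattern.
  assert mask.shape == (1, 1, 3, 5)
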
def call(self, inputs, params=(), rng=None, **kwargs):
  del params
  q, k, v = inputs
  mask_size = q.shape[-2]
  mask = np.tril(np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
  res = tl.DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
  return res

def _batch_norm_new_params(input_shape, rng, axis=(0, 1, 2),
                           center=True, scale=True, **kwargs):
  """Helper to initialize batch norm params."""
  del rng, kwargs
  axis = (axis,) if np.isscalar(axis) else axis
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  beta = np.zeros(shape, dtype='float32') if center else ()
  gamma = np.ones(shape, dtype='float32') if scale else ()
  return (beta, gamma)

def new_parameters(self, input_shape, input_dtype, rng):
  """Helper to initialize batch norm params."""
  del input_dtype, rng
  axis = self._axis
  axis = (axis,) if np.isscalar(axis) else axis
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  beta = np.zeros(shape, dtype='float32') if self._center else ()
  gamma = np.ones(shape, dtype='float32') if self._scale else ()

  def get_stats_axis(i, d):
    if i in axis:
      return 1
    else:
      return d

  stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape))
  running_mean = np.zeros(stats_shape, dtype=np.float32)
  running_var = np.ones(stats_shape, dtype=np.float32)
  num_batches = np.zeros((), dtype=np.int32)
  return (beta, gamma), (running_mean, running_var, num_batches)

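# Worked shape example (a sketch, not from the original source): re-derives
# the shapes computed by the batch norm initializers above for an NHWC input
# normalized over the batch and spatial axes.
def _batch_norm_shapes_example():
  input_shape = (8, 32, 32, 64)  # [batch, height, width, channels]
  axis = (0, 1, 2)
  param_shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  stats_shape = tuple(1 if i in axis else d for i, d in enumerate(input_shape))
  assert param_shape == (64,)  # beta/gamma hold one value per channel.
  assert stats_shape == (1, 1, 1, 64)  # Running stats broadcast over N, H, W.
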
def call(self, inputs, params=(), state=(), rng=None, **kwargs):
  del params
  q, k, v = inputs
  mask_size = q.shape[-2]
  # Not all backends define np.tril. However, using onp.tril is inefficient
  # in that it creates a large global constant. TODO(kitaev): try to find an
  # alternative that works across all backends.
  if backend.get_name() == 'jax':
    mask = np.tril(np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
  else:
    mask = onp.tril(onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
  res = DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
  return res, state

def ChunkedAttentionSelector(x, params, selector=None, **kwargs):
  """Select which chunks to attend to in chunked attention.

  Args:
    x: inputs, a list of elements of the form (q, k, v), mask for each chunk.
    params: parameters (unused).
    selector: a function from chunk_number -> list of chunk numbers that says
      which other chunks should be appended to the given one (previous if
      None).
    **kwargs: unused other arguments.

  Returns:
    a list of elements of the form (q, k', v', mask') where k', v' and mask'
    are concatenations of k, v and identity-extended masks from selected
    chunks.
  """
  del params, kwargs
  selector = selector or (lambda x: [] if x < 1 else [x - 1])
  triples, masks = zip(*x)
  (queries, keys, values) = zip(*triples)
  result = []
  for i in range(len(x)):
    selected = selector(i)
    # Since keys and values are [batch, length, depth] we concatenate on
    # axis=1. We also always include the current key or value at the end.
    new_key_list = [keys[j] for j in selected]
    new_key = np.concatenate(new_key_list + [keys[i]], axis=1)
    new_value = np.concatenate(
        [values[j] for j in selected] + [values[i]], axis=1)
    # Masks are (1, query-len, key-len) so we concatenate on axis=2.
    new_mask_shapes = [(1, queries[i].shape[1], key.shape[1])
                       for key in new_key_list]
    cur_mask = masks[i]
    # Masks are all-1 for the added chunks (no masking).
    new_mask_list = [np.ones(s, dtype=cur_mask.dtype) for s in new_mask_shapes]
    # We still use the current (often causal) mask for the final chunk.
    new_mask = np.concatenate(new_mask_list + [cur_mask], axis=2)
    result.append((queries[i], new_key, new_value, new_mask))
  return tuple(result)

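# Usage sketch (not from the original source) of the selector convention used
# above: the default attends to the immediately preceding chunk, while a
# "full past" variant attends to every earlier chunk.
def _selector_convention_example():
  default_selector = lambda i: [] if i < 1 else [i - 1]
  full_past_selector = lambda i: list(range(i))
  assert default_selector(0) == []
  assert default_selector(3) == [2]
  assert full_past_selector(3) == [0, 1, 2]
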
def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
  del params
  q, k, v = inputs
  if self._mode in ('train', 'eval'):
    mask_size = q.shape[-2]
    # Not all backends define np.tril. However, using onp.tril is inefficient
    # in that it creates a large global constant. TODO(kitaev): try to find
    # an alternative that works across all backends.
    if backend.get_name() == 'jax':
      mask = np.tril(
          np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
    else:
      mask = onp.tril(
          onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
  else:
    assert self._mode == 'predict'
    state = _fast_inference_update_state(inputs, state)
    (k, v, mask, _) = state
  res = DotProductAttention(
      q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
  return res, state

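# Minimal sketch (not from the original source) of the causal mask built in
# train/eval mode above: row i of the lower-triangular mask lets query
# position i attend only to key positions j <= i.
def _causal_mask_example():
  mask = onp.tril(onp.ones((1, 3, 3), dtype=onp.bool_), k=0)
  assert mask.tolist() == [[[True, False, False],
                            [True, True, False],
                            [True, True, True]]]
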
def rescale(outputs, inputs):
  # `dims`, `spatial_strides` and `padding` come from the enclosing pooling
  # layer's scope. Reducing over an all-ones array counts how many inputs
  # each window actually covers.
  one = np.ones(inputs.shape[1:-1], dtype=inputs.dtype)
  window_sizes = lax.reduce_window(one, 0., lax.add, dims, spatial_strides,
                                   padding)
  return outputs / window_sizes[..., np.newaxis]

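# Hedged sketch (not from the original source) of the counting trick above,
# run with jax.lax directly: with 'SAME' padding, edge windows cover fewer
# inputs, so dividing pooled sums by these per-position counts turns a sum
# pool into a correct average pool.
def _window_counts_example():
  import jax.numpy as jnp
  from jax import lax as jax_lax
  counts = jax_lax.reduce_window(jnp.ones((5,)), 0., jax_lax.add,
                                 window_dimensions=(3,), window_strides=(1,),
                                 padding='SAME')
  assert counts.tolist() == [2., 3., 3., 3., 2.]
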
  return init


def glorot(out_dim=0, in_dim=1, scale=onp.sqrt(2)):
  """An initializer function for random Glorot-scaled coefficients."""
  def init(rng, shape):
    fan_in, fan_out = shape[in_dim], shape[out_dim]
    size = onp.prod(onp.delete(shape, [in_dim, out_dim]))
    std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
    return (std * backend.random.normal(rng, shape)).astype('float32')
  return init


zeros = lambda rng, shape: np.zeros(shape, dtype='float32')
ones = lambda rng, shape: np.ones(shape, dtype='float32')


# Layers.
# Each layer constructor function returns an (init_fun, apply_fun) pair, where
#   init_fun: takes an rng key and an input shape and returns an
#     (output_shape, params) pair,
#   apply_fun: takes params, inputs, and an rng key and applies the layer.


def Dense(out_dim, W_init=glorot(), b_init=randn()):
  """Layer constructor function for a dense (fully-connected) layer."""
  def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    w, b = W_init(rng, (input_shape[-1], out_dim)), b_init(rng, (out_dim,))
    return output_shape, (w, b)

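# Hedged usage sketch (not from the original source), following the
# (init_fun, apply_fun) convention described in the comment above. It assumes
# Dense returns that pair (its return statement falls outside this excerpt)
# and that apply_fun has the signature apply_fun(params, inputs, rng=None).
def _dense_usage_example(rng):
  init_fun, apply_fun = Dense(16)
  output_shape, params = init_fun(rng, (8, 32))  # Batch of 8, 32 features.
  assert output_shape == (8, 16)
  return apply_fun(params, np.zeros((8, 32)))
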
def init_fun(_, input_shape):
  features = input_shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  return input_shape, (scale, bias)