def init(self, x):
  shape = x.shape
  state = []
  if self._factored and len(shape) >= 2:
    v_row = np.zeros(shape[:-1], dtype=np.float32)
    v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
    state.extend([v_row, v_col])
  else:
    v = np.zeros_like(x)
    state.append(v)
  if self._beta1:
    m = np.zeros_like(x)
    state.append(m)
  return state
def init(self, params):
  shape = params.shape
  slots = []
  if self._factored and len(shape) >= 2:
    v_row = np.zeros(shape[:-1], dtype=np.float32)
    v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
    slots.extend([v_row, v_col])
  else:
    v = np.zeros_like(params)
    slots.append(v)
  if self._do_momentum:
    m = np.zeros_like(params)
    slots.append(m)
  return slots
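# A minimal standalone sketch (plain NumPy, with an assumed weight shape
# chosen for illustration) of why the factored branch above saves memory:
# for a 2-D weight it keeps one row vector and one column vector of
# second-moment statistics instead of a full matrix.
import numpy as np

shape = (1024, 512)                                          # hypothetical weight shape
v_row = np.zeros(shape[:-1], dtype=np.float32)               # shape (1024,)
v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)  # shape (512,)
assert v_row.size + v_col.size == 1536                       # vs. 524288 unfactored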
def LayerNormParams(input_shape, input_dtype, rng, epsilon=1e-6):
  """Helper: create layer norm parameters."""
  del input_dtype, rng, epsilon
  features = input_shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  return (scale, bias)
def EncoderDecoderMask(x, **unused_kwargs):
  """Makes encoder-decoder mask from decoder input and a padding mask."""
  decoder_input, padding_mask = x
  padding_mask = np.reshape(
      padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
  # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
  return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
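# Hedged shape walk-through of the broadcast above, using assumed toy
# sizes (batch 2, decoder length 5, encoder length 7) and plain NumPy.
import numpy as np

decoder_input = np.zeros((2, 5, 8))   # [batch, decoder-len, d_feature]
padding_mask = np.ones((2, 7))        # [batch, encoder-len]
padding_mask = np.reshape(
    padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
mask = padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
assert mask.shape == (2, 1, 5, 7)     # [batch, 1 for heads, dec-len, enc-len]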
def test_computes_mean_with_weights(self, backend_name):
  with backend.use_backend(backend_name):
    inputs = [np.array([1, 2, 3])]
    targets = [np.zeros(3)]
    weights = [np.array([3, 1, 0])]
    mean = trax.masked_mean(inputs, targets, weights)
    onp.testing.assert_allclose(mean, 1.25)
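# The expected constant in the test above follows from the usual
# weighted-mean formula: (1*3 + 2*1 + 3*0) / (3 + 1 + 0) = 5 / 4 = 1.25.
# A quick standalone check in plain NumPy:
import numpy as np

inputs = np.array([1, 2, 3])
weights = np.array([3, 1, 0])
assert np.sum(inputs * weights) / np.sum(weights) == 1.25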
def _layer_norm_params(input_shape, input_dtype, rng):
  """Helper: create layer norm parameters."""
  del input_dtype, rng
  features = input_shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  return (scale, bias)
def _layer_norm_new_params(input_shape, rng, epsilon=1e-6):  # pylint: disable=invalid-name
  """Helper: create layer norm parameters."""
  del rng, epsilon
  features = input_shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  return (scale, bias)
def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding."""
  del kwargs
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  res = np.concatenate([x, pos], axis=2)
  return res
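# A small usage sketch of the encoding above, with made-up sizes:
# 2-dimensional position features are broadcast over the batch and
# appended to each timestep along the feature axis.
import numpy as np

x = np.zeros((4, 10, 16))                                   # [batch, length, depth]
positions = np.arange(60, dtype=np.float32).reshape(30, 2)  # [max-len, pos-depth]
pos = positions[np.newaxis, :x.shape[1], :]                 # [1, length, pos-depth]
pos = pos + np.zeros((x.shape[0], 1, 1))                    # broadcast on batch
res = np.concatenate([x, pos], axis=2)
assert res.shape == (4, 10, 18)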
def _batch_norm_new_params(input_shape, rng, axis=(0, 1, 2),
                           center=True, scale=True, **kwargs):
  """Helper to initialize batch norm params."""
  del rng, kwargs
  axis = (axis,) if np.isscalar(axis) else axis
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  beta = np.zeros(shape, dtype='float32') if center else ()
  gamma = np.ones(shape, dtype='float32') if scale else ()
  return (beta, gamma)
def new_parameters(self, input_shape, input_dtype, rng):
  """Helper to initialize batch norm params."""
  del input_dtype, rng
  axis = self._axis
  axis = (axis,) if np.isscalar(axis) else axis
  shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
  beta = np.zeros(shape, dtype='float32') if self._center else ()
  gamma = np.ones(shape, dtype='float32') if self._scale else ()
  def get_stats_axis(i, d):
    if i in axis:
      return 1
    else:
      return d
  stats_shape = tuple(
      get_stats_axis(i, d) for i, d in enumerate(input_shape))
  running_mean = np.zeros(stats_shape, dtype=np.float32)
  running_var = np.ones(stats_shape, dtype=np.float32)
  num_batches = np.zeros((), dtype=np.int32)
  return (beta, gamma), (running_mean, running_var, num_batches)
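# Shape walk-through for the initializer above under an assumed NHWC
# input of (8, 32, 32, 64) with axis=(0, 1, 2): beta/gamma are
# per-channel, while the running statistics keep singleton reduced axes
# so they broadcast against the activations.
input_shape = (8, 32, 32, 64)
axis = (0, 1, 2)
param_shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
stats_shape = tuple(1 if i in axis else d for i, d in enumerate(input_shape))
assert param_shape == (64,)
assert stats_shape == (1, 1, 1, 64)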
def _fast_inference_init_state(input_shapes, input_dtypes, buffer_length):
  """Initializes state of a causal attention layer for fast inference."""
  ((batch_size, _, _), _, _) = input_shapes
  def init_buffer(shape, dtype):
    (_, _, depth) = shape
    return np.zeros((batch_size, buffer_length, depth), dtype=dtype)
  (_, k, v) = tuple(
      init_buffer(shape, dtype)
      for (shape, dtype) in zip(input_shapes, input_dtypes)
  )
  mask = np.zeros((batch_size, 1, buffer_length))
  index = 0
  state = (k, v, mask, index)
  return state
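# A hypothetical invocation of the initializer above, assuming three
# (query, key, value) input shapes of depth 64, batch size 8, and a
# decode buffer of 128 positions. The query buffer is discarded; only
# keys, values, a padding mask, and a write index are kept as state.
import numpy as np

input_shapes = ((8, 1, 64), (8, 1, 64), (8, 1, 64))
input_dtypes = (np.float32, np.float32, np.float32)
k, v, mask, index = _fast_inference_init_state(
    input_shapes, input_dtypes, buffer_length=128)
assert k.shape == v.shape == (8, 128, 64)
assert mask.shape == (8, 1, 128) and index == 0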
def init(self, x):
  vs = [np.zeros(sz, dtype=x.dtype) for sz in x.shape]
  return (np.zeros_like(x), vs)
def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
  # This is the core of the memory-efficient attention implementation, where
  # we use the jax.lax.while_loop primitive to compute attention for a small
  # set of query positions at a time. Note how in the backwards pass, we
  # compute both the forward direction (to recover the previous layer's
  # activations) and the backward direction simultaneously. This allows us to
  # only use a single loop, where the inner portion of the loop does a slice
  # of the forward+backward joint computation. Unfortunately we have had to
  # introduce a large number of wrapper classes (including
  # ReversibleAttentionHalfResidual and ApplyAttentionWrapper) for the sole
  # purpose of connecting this implementation of forward_and_vjp with the
  # core backprop implementation.
  query, key, value = inputs
  depth = np.shape(query)[-1]
  do_backprop = ct is not None

  def make_mask(N, M, k):
    x = np.arange(N, dtype=np.int32)
    y = np.arange(M, dtype=np.int32)
    mask = jax.lax.lt(
        (jax.lax.broadcast_in_dim(
            x, shape=(N, M), broadcast_dimensions=(0,)) + k),
        jax.lax.broadcast(y, [N]))
    mask = jax.lax.convert_element_type(mask, np.float32)
    return mask

  def forward_slice(query_slice, q_loop_idx, key, value):
    """Forward pass for a subset of the query vectors."""
    dots = np.matmul(
        query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    # Causal masking.
    mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e9 * mask
    # Softmax.
    dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
    dots = dots / dots.sum(axis=-1, keepdims=True)
    out_slice = np.matmul(dots, value)
    return out_slice

  def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice):
    output_slice, vjpfun = jax.vjp(
        forward_slice, query_slice, q_loop_idx, key, value)
    return output_slice, vjpfun(ct_slice)

  q_loop_idx = np.zeros((), dtype=np.int32)
  q_loop_max = query.shape[2]
  q_loop_stride = self._loop_stride
  assert q_loop_max % q_loop_stride == 0, (
      'Stride must evenly divide the number of query elements.')

  out_accum = np.zeros_like(query)
  if do_backprop:
    query_ct_accum = np.zeros_like(query)
    key_ct_accum = np.zeros_like(key)
    value_ct_accum = np.zeros_like(value)
    init_vals = (q_loop_idx, out_accum,
                 query_ct_accum, key_ct_accum, value_ct_accum)
  else:
    init_vals = (q_loop_idx, out_accum)

  def cond_fun(vals):
    q_loop_idx = vals[0]
    return jax.lax.lt(q_loop_idx, q_loop_max)

  def body_fun(vals):
    """Compute a slice of the attention mechanism."""
    if do_backprop:
      (q_loop_idx, out_accum,
       query_ct_accum, key_ct_accum, value_ct_accum) = vals
    else:
      q_loop_idx, out_accum = vals

    query_slice = jax.lax.dynamic_slice_in_dim(
        query, q_loop_idx, q_loop_stride, axis=2)

    if do_backprop:
      ct_slice = jax.lax.dynamic_slice_in_dim(
          ct, q_loop_idx, q_loop_stride, axis=2)
      out_slice, partial_ct = forward_and_vjp_slice(
          query_slice, q_loop_idx, key, value, ct_slice)
      query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
          query_ct_accum, partial_ct[0], q_loop_idx, axis=2)
      # Ignore partial_ct[1], which is wrt the loop idx.
      key_ct_accum = key_ct_accum + partial_ct[2]
      value_ct_accum = value_ct_accum + partial_ct[3]
    else:
      out_slice = forward_slice(query_slice, q_loop_idx, key, value)

    out_accum = jax.lax.dynamic_update_slice_in_dim(
        out_accum, out_slice, q_loop_idx, axis=2)
    q_loop_idx = q_loop_idx + q_loop_stride

    if do_backprop:
      return (q_loop_idx, out_accum,
              query_ct_accum, key_ct_accum, value_ct_accum)
    else:
      return (q_loop_idx, out_accum)

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)
  if not do_backprop:
    return final_vals[1], None
  else:
    return final_vals[1], final_vals[2:]
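# Standalone NumPy rendering (an illustrative re-derivation, not part of
# the original code) of make_mask above: for a slice of N query positions
# starting at global offset k, entry (i, j) is 1.0 exactly when key
# position j lies in the future of query position k + i.
import numpy as np

def make_mask_np(N, M, k):
  x = np.arange(N, dtype=np.int32)
  y = np.arange(M, dtype=np.int32)
  return (x[:, None] + k < y[None, :]).astype(np.float32)

print(make_mask_np(2, 4, k=1))
# [[0. 0. 1. 1.]
#  [0. 0. 0. 1.]]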
def init(self, params):
  vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
  return (np.zeros_like(params), vs)
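# Illustration (assumed shape) of the per-axis accumulators created
# above, reminiscent of SM3-style memory-efficient optimizers: one
# vector per tensor dimension instead of a full-size statistics tensor.
import numpy as np

params = np.zeros((64, 32))
m = np.zeros_like(params)
vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
assert [v.shape for v in vs] == [(64,), (32,)]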
def init_buffer(shape, dtype):
  (_, _, depth) = shape
  return np.zeros((batch_size, buffer_length, depth), dtype=dtype)
def call_and_grad(self, inputs, ct, rng=None, **kwargs):
  del kwargs
  query, key, value = inputs
  depth = np.shape(query)[-1]
  do_backprop = ct is not None
  # jax uses the term cotangent (ct) to refer to gradient signals, and
  # vector-Jacobian product (vjp) for back-propagation through a layer.

  def make_mask(N, M, k):  # pylint: disable=invalid-name
    """Constructs a slice of the causal attention mask.

    Args:
      N: number of query positions
      M: number of key positions
      k: position of the initial query element

    Returns:
      N x M mask, where 1.0 indicates that attention is not allowed.
    """
    x = np.arange(N, dtype=np.int32)
    y = np.arange(M, dtype=np.int32)
    mask = jax.lax.lt(
        (jax.lax.broadcast_in_dim(
            x, shape=(N, M), broadcast_dimensions=(0,)) + k),
        jax.lax.broadcast(y, [N]))
    mask = jax.lax.convert_element_type(mask, np.float32)
    return mask

  def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
    """Forward pass for a subset of the query vectors."""
    dots = np.matmul(
        query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    # Causal masking.
    mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e9 * mask
    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    if self.dropout is not None and self.dropout > 0.0:
      # Dropout is broadcast across the batch+head dimension.
      dropout_shape = (1, dots.shape[-2], dots.shape[-1])
      slice_rng = jax.random.fold_in(rng, q_loop_idx)
      keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
      keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
      multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
      dots = dots * multiplier
    out_slice = np.matmul(dots, value)
    return out_slice

  def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice):  # pylint: disable=invalid-name
    # Capture q_loop_idx so that no gradients are computed with respect
    # to it.
    def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
      return forward_slice(query_slice, q_loop_idx, key, value)

    output_slice, vjpfun = jax.vjp(
        forward_slice_with_q_loop_idx, query_slice, key, value)
    return output_slice, vjpfun(ct_slice)

  q_loop_idx = np.zeros((), dtype=np.int32)
  q_loop_max = query.shape[-2]
  q_loop_stride = self._loop_stride
  assert q_loop_max % q_loop_stride == 0, (
      'Stride must evenly divide the number of query elements.')

  out_accum = np.zeros_like(query)
  if do_backprop:
    query_ct_accum = np.zeros_like(query)
    key_ct_accum = np.zeros_like(key)
    value_ct_accum = np.zeros_like(value)
    init_vals = (q_loop_idx, out_accum,
                 query_ct_accum, key_ct_accum, value_ct_accum)
  else:
    init_vals = (q_loop_idx, out_accum)

  def cond_fun(vals):  # pylint: disable=invalid-name
    q_loop_idx = vals[0]
    return jax.lax.lt(q_loop_idx, q_loop_max)

  def body_fun(vals):  # pylint: disable=invalid-name
    """Compute a slice of the attention mechanism."""
    if do_backprop:
      (q_loop_idx, out_accum,
       query_ct_accum, key_ct_accum, value_ct_accum) = vals
    else:
      q_loop_idx, out_accum = vals

    query_slice = jax.lax.dynamic_slice_in_dim(
        query, q_loop_idx, q_loop_stride, axis=-2)

    if do_backprop:
      ct_slice = jax.lax.dynamic_slice_in_dim(
          ct, q_loop_idx, q_loop_stride, axis=-2)
      out_slice, partial_ct = forward_and_vjp_slice(
          query_slice, q_loop_idx, key, value, ct_slice)
      query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
          query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
      key_ct_accum = key_ct_accum + partial_ct[1]
      value_ct_accum = value_ct_accum + partial_ct[2]
    else:
      out_slice = forward_slice(query_slice, q_loop_idx, key, value)

    out_accum = jax.lax.dynamic_update_slice_in_dim(
        out_accum, out_slice, q_loop_idx, axis=-2)
    q_loop_idx = q_loop_idx + q_loop_stride

    if do_backprop:
      return (q_loop_idx, out_accum,
              query_ct_accum, key_ct_accum, value_ct_accum)
    else:
      return (q_loop_idx, out_accum)

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)
  if not do_backprop:
    return final_vals[1], None
  else:
    return final_vals[1], final_vals[2:]
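# A quick check of the softmax identity used in forward_slice above:
# exp(x - logsumexp(x)) is already normalized, so unlike the earlier
# variant no explicit division by the partition sum is needed.
# (scipy.special.logsumexp stands in for backend.logsumexp here.)
import numpy as np
from scipy.special import logsumexp

x = np.array([1.0, 2.0, 3.0])
assert np.allclose(np.exp(x - logsumexp(x)), np.exp(x) / np.exp(x).sum())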
def init_fun(_, input_shape):
  features = input_shape[-1]
  scale = np.ones(features)
  bias = np.zeros(features)
  return input_shape, (scale, bias)
def batch_call_and_or_grad(self, qk, v, ct=None, return_output=True,
                           rng=None):
  assert return_output or ct is not None, 'No work to perform!'
  # pylint: disable=protected-access
  stash_buckets = (return_output and ct is None
                   and base.Layer._STASH_IN is not None)
  if return_output and ct is not None and base.Layer._STASH_OUT is not None:
    buckets = base.Layer._STASH_OUT.pop(self)
  else:
    buckets = None
  # pylint: enable=protected-access

  # The approach here is to perform attention for one batch element and head
  # at a time. Note that there is absolutely no interaction across examples
  # or heads: this layer has no parameters, and hashing patterns are also
  # different across examples/heads. As a result, batching doesn't give any
  # performance gains except in the case of accelerator under-utilization.
  # We assume that hash-based attention will be applied primarily to long
  # sequences, where unbatched attention for a single head has sufficient
  # computation to fill up the accelerator.

  batch_loop_idx = np.zeros((), dtype=np.int32)
  batch_loop_max = qk.shape[0]

  init_vals = (batch_loop_idx,)
  if return_output:
    out_accum = np.zeros_like(qk)
    init_vals = init_vals + (out_accum,)
  if stash_buckets:
    buckets_accum = np.zeros(
        [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32)
    init_vals = init_vals + (buckets_accum,)
  if ct is not None:
    qk_ct_accum = np.zeros_like(qk)
    v_ct_accum = np.zeros_like(v)
    init_vals = init_vals + (qk_ct_accum, v_ct_accum)

  def cond_fun(vals):
    batch_loop_idx = vals[0]
    return jax.lax.lt(batch_loop_idx, batch_loop_max)

  def body_fun(vals):
    """Performs attention for a single batch element and head."""
    batch_loop_idx = vals[0]
    if self._prng is None:
      hash_rng = jax.random.fold_in(rng, batch_loop_idx)
    else:
      # TODO(kitaev): Maybe use the same RNG across examples (but not heads)?
      hash_rng = jax.random.fold_in(self._prng, batch_loop_idx)
    qk_slice = jax.lax.dynamic_index_in_dim(
        qk, batch_loop_idx, axis=0, keepdims=False)
    v_slice = jax.lax.dynamic_index_in_dim(
        v, batch_loop_idx, axis=0, keepdims=False)

    if buckets is None:
      buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng)
    else:
      buckets_slice = jax.lax.dynamic_index_in_dim(
          buckets, batch_loop_idx, axis=0, keepdims=False)

    if ct is None:
      out_slice = self.single_call(
          qk_slice, v_slice, buckets_slice, hash_rng=hash_rng)
    else:
      def _do_single_call(qk_slice, v_slice):
        return self.single_call(
            qk_slice, v_slice, buckets_slice, hash_rng=hash_rng)
      ct_slice = jax.lax.dynamic_index_in_dim(
          ct, batch_loop_idx, axis=0, keepdims=False)
      out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice)
      qk_ct_slice, v_ct_slice = vjpfun(ct_slice)

    new_vals = (batch_loop_idx + 1,)
    if return_output:
      out_accum = vals[1]
      out_accum = jax.lax.dynamic_update_index_in_dim(
          out_accum, out_slice, batch_loop_idx, axis=0)
      new_vals = new_vals + (out_accum,)
    if stash_buckets:
      buckets_accum = vals[2]
      buckets_accum = jax.lax.dynamic_update_index_in_dim(
          buckets_accum, buckets_slice, batch_loop_idx, axis=0)
      new_vals = new_vals + (buckets_accum,)
    if ct is not None:
      qk_ct_accum, v_ct_accum = vals[-2:]
      qk_ct_accum = jax.lax.dynamic_update_index_in_dim(
          qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0)
      v_ct_accum = jax.lax.dynamic_update_index_in_dim(
          v_ct_accum, v_ct_slice, batch_loop_idx, axis=0)
      new_vals = new_vals + (qk_ct_accum, v_ct_accum)
    return new_vals

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

  if return_output:
    out = final_vals[1]
  else:
    out = None

  if stash_buckets:
    base.Layer._STASH_IN[self] = final_vals[2]  # pylint: disable=protected-access

  if ct is not None:
    input_ct = final_vals[-2:]
  else:
    input_ct = None

  return out, input_ct
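# A toy reduction (assumed sizes, stand-in computation) of the batching
# pattern above: a jax.lax.while_loop visits one batch element per
# iteration, reads it with dynamic_index_in_dim, and writes the result
# back into an accumulator with dynamic_update_index_in_dim.
import jax
import jax.numpy as jnp

def per_example(x):
  return x * 2.0  # stand-in for single_call

def batched(xs):
  def cond_fun(vals):
    return jax.lax.lt(vals[0], xs.shape[0])
  def body_fun(vals):
    i, acc = vals
    x = jax.lax.dynamic_index_in_dim(xs, i, axis=0, keepdims=False)
    acc = jax.lax.dynamic_update_index_in_dim(acc, per_example(x), i, axis=0)
    return (i + 1, acc)
  _, out = jax.lax.while_loop(
      cond_fun, body_fun, (jnp.zeros((), jnp.int32), jnp.zeros_like(xs)))
  return out

print(batched(jnp.arange(6.0).reshape(3, 2)))  # doubles each row in turn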
  return init


def glorot(out_dim=0, in_dim=1, scale=onp.sqrt(2)):
  """An initializer function for random Glorot-scaled coefficients."""
  def init(rng, shape):
    fan_in, fan_out = shape[in_dim], shape[out_dim]
    size = onp.prod(onp.delete(shape, [in_dim, out_dim]))
    std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
    return (std * backend.random.normal(rng, shape)).astype('float32')
  return init


zeros = lambda rng, shape: np.zeros(shape, dtype='float32')
ones = lambda rng, shape: np.ones(shape, dtype='float32')


# Layers.
#
# Each layer constructor function returns an (init_fun, apply_fun) pair, where
#   init_fun: takes an input shape and returns an (output_shape, params) pair,
#   apply_fun: takes params, inputs, and an rng key and applies the layer.


def Dense(out_dim, W_init=glorot(), b_init=randn()):
  """Layer constructor function for a dense (fully-connected) layer."""
  def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    w, b = W_init(rng, (input_shape[-1], out_dim)), b_init(rng, (out_dim,))
    return output_shape, (w, b)
  def apply_fun(params, inputs, **kwargs):
    w, b = params
    return np.dot(inputs, w) + b
  return init_fun, apply_fun
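# A standalone check (assumed shape, plain NumPy) of the Glorot scale
# computed above: for a rank-2 weight of shape (256, 512) there are no
# extra dimensions, so size == 1 and std = sqrt(2) / sqrt((256+512)/2).
import numpy as onp

shape = (256, 512)
fan_in, fan_out = shape[1], shape[0]         # in_dim=1, out_dim=0 defaults
size = onp.prod(onp.delete(shape, [1, 0]))   # 1.0 for a rank-2 shape
std = onp.sqrt(2) / onp.sqrt((fan_in + fan_out) / 2. * size)
assert abs(std - 0.0722) < 1e-3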
def dummy_loss_fn(params):
  inputs = (np.zeros(input_shape[0], dtype=np.int32),) * 2
  output = model(inputs, params=params, state=state, rng=rng)
  dummy_loss = backend.numpy.sum(output[0])
  return dummy_loss
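# Hedged usage sketch: a dummy loss like the one above is typically
# differentiated only to exercise the model's backward pass, e.g. (with
# a JAX backend, and `params` as in the surrounding code):
#
#   grads = jax.grad(dummy_loss_fn)(params)
#
# The loss value itself (a plain sum of the first output) is irrelevant.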