def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding: truncates the fixed `positions` table
  to the length of `x` and broadcasts it over the batch dimension."""
  del kwargs
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  return pos

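
# A minimal standalone sketch (plain NumPy, not the trax backend `np`) of what
# the encoding above computes: the fixed `positions` table is truncated to the
# input length and broadcast across the batch. Shapes here are illustrative.
import numpy as onp

batch, length, d_feature = 2, 3, 4
x_example = onp.zeros((batch, length, d_feature))
positions_table = onp.arange(5 * d_feature, dtype=onp.float32).reshape(5, d_feature)

pos_example = positions_table[onp.newaxis, :length, :] + onp.zeros((batch, 1, 1))
assert pos_example.shape == (batch, length, d_feature)
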
def Initializer(shape, rng):
  """Loads pretrained embeddings from `path` (captured from the enclosing scope)."""
  del rng
  logging.info('Loading pretrained embeddings from %s', path)
  with tf.io.gfile.GFile(path, 'rb') as f:
    parameters = np.load(f)
  assert np.shape(parameters) == shape, (
      'Expected shape %s, got %s' % (shape, np.shape(parameters)))
  return parameters

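
# Hypothetical sketch of the factory pattern implied above: `path` is not an
# argument of `Initializer`, so it must be captured from an enclosing function
# such as the one sketched here. The name `PretrainedEmbeddingsFromFile` and
# the plain NumPy file handling are illustrative, not the library's API.
import numpy as onp

def PretrainedEmbeddingsFromFile(path):
  def InitializerSketch(shape, rng):
    del rng
    parameters = onp.load(path)  # Stand-in for tf.io.gfile.GFile + np.load.
    assert onp.shape(parameters) == shape
    return parameters
  return InitializerSketch
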
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None, **kwargs):
  if self._mode in ('train', 'eval'):
    x = inputs
    symbol_size = np.shape(x)[1]
    px = weights[:, :symbol_size, :]
    if self._dropout == 0:
      return (x + px, state)
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if backend.get_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            x, np.full((), keep_prob, dtype=x.dtype))
      keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return (x + px * multiplier, state)
  else:
    assert self._mode == 'predict'
    assert self._dropout == 0
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer stores the index of the current position
    # and increments it on each call -- that's how state is used and updated
    # below.
    return (inputs + np.expand_dims(weights[:, state, :], 1), state + 1)

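
# Minimal NumPy sketch of the 'predict' branch above: in fast inference the
# layer is called one position at a time, so `state` is just an integer index
# into the positional-encoding table that advances by one per call. The table
# and step inputs below are illustrative.
import numpy as onp

table = onp.arange(12, dtype=onp.float32).reshape(1, 4, 3)  # (1, max_len, d)
pos_state = 0
for step_input in onp.zeros((3, 1, 1, 3), dtype=onp.float32):  # three (1, 1, d) steps
  step_out = step_input + onp.expand_dims(table[:, pos_state, :], 1)
  pos_state += 1  # The next call uses the next position.
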
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if backend.get_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out

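
# Standalone NumPy sketch of the core computation in DotProductAttention:
# scaled dot-product scores, -1e9 masking, and a row-wise softmax (written
# here with a max shift instead of backend.logsumexp). It omits dropout and
# the jax.lax.tie_in workaround; shapes and the causal mask are illustrative.
import numpy as onp

def dot_product_attention_sketch(q, k, v, mask):
  depth = q.shape[-1]
  dots = onp.matmul(q, onp.swapaxes(k, -1, -2)) / onp.sqrt(depth)
  dots = onp.where(mask, dots, onp.full_like(dots, -1e9))
  dots = onp.exp(dots - onp.max(dots, axis=-1, keepdims=True))
  dots /= onp.sum(dots, axis=-1, keepdims=True)
  return onp.matmul(dots, v)

q = k = v = onp.ones((1, 4, 8))                     # (batch, seqlen, depth)
causal = onp.tril(onp.ones((1, 4, 4), dtype=bool))  # Lower-triangular mask.
out_example = dot_product_attention_sketch(q, k, v, causal)
assert out_example.shape == (1, 4, 8)
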
def forward(self, inputs, weights):
  gamma, beta, epsilon_l = weights

  epsilon = self._init_epsilon
  if epsilon_l is not base.EMPTY_WEIGHTS:
    epsilon += np.abs(epsilon_l[0])

  # Average the squared activations over the spatial axes only, omitting the
  # batch (B) and channel (C) axes.
  axis = tuple(range(1, len(np.shape(inputs)) - 1))
  nu2 = np.mean(inputs**2, axis=axis, keepdims=True)  # Shape (B, 1, 1, C).
  xhat = inputs / np.sqrt(nu2 + epsilon)              # Shape (B, W, H, C).

  return gamma * xhat + beta

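
# NumPy sketch of the normalization above for a (B, W, H, C) input: the mean
# of the squared activations is taken over the spatial axes only, so each
# batch example and channel is scaled independently. The gamma/beta/epsilon
# values here are illustrative constants, not learned weights.
import numpy as onp

inputs_example = onp.random.randn(2, 5, 5, 3).astype(onp.float32)  # (B, W, H, C)
gamma_c, beta_c, epsilon_c = 1.0, 0.0, 1e-6

nu2_example = onp.mean(inputs_example**2, axis=(1, 2), keepdims=True)  # (B, 1, 1, C)
xhat_example = inputs_example / onp.sqrt(nu2_example + epsilon_c)      # (B, W, H, C)
out_example = gamma_c * xhat_example + beta_c
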
def forward_with_state(self, x, weights, state, rng):
  """Pure transformer-style multi-headed attention.

  Args:
    x: inputs (q, k, v, mask)
    weights: parameters (none)
    state: layer state (unused)
    rng: random number generator

  Returns:
    Pure multi-headed attention result, and the mask.
  """
  del weights
  n_heads, dropout, mode = self._n_heads, self._dropout, self._mode
  q, k, v, mask = x
  d_feature = q.shape[-1]
  assert d_feature % n_heads == 0
  d_head = d_feature // n_heads
  nbatch = np.shape(q)[0]

  # nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head
  def SplitHeads(x):  # pylint: disable=invalid-name
    return np.transpose(
        np.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))

  # nbatch, n_heads, seqlen, d_head --> nbatch, seqlen, d_feature
  def JoinHeads(x):  # pylint: disable=invalid-name
    return np.reshape(
        np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))

  # Split heads, dot-product attention, rejoin heads.
  res = JoinHeads(
      DotProductAttention(
          SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,
          dropout=dropout, mode=mode, rng=rng))
  return (res, mask), state  # Keep the mask.

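
# NumPy sketch of the head split/join used above: (nbatch, seqlen, d_feature)
# is reshaped to (nbatch, n_heads, seqlen, d_head) and back, with
# d_feature == n_heads * d_head. Joining a split tensor recovers the input.
# The sizes below are illustrative.
import numpy as onp

nbatch_ex, seqlen_ex, n_heads_ex, d_head_ex = 2, 5, 4, 8
x_ex = onp.random.randn(nbatch_ex, seqlen_ex, n_heads_ex * d_head_ex)

split = onp.transpose(
    x_ex.reshape(nbatch_ex, -1, n_heads_ex, d_head_ex), (0, 2, 1, 3))
assert split.shape == (nbatch_ex, n_heads_ex, seqlen_ex, d_head_ex)

joined = onp.transpose(split, (0, 2, 1, 3)).reshape(
    nbatch_ex, -1, n_heads_ex * d_head_ex)
assert onp.allclose(joined, x_ex)
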
def forward_and_backward(self, inputs, ct, state=base.EMPTY_STATE,
                         new_state=base.EMPTY_STATE, rng=None, **kwargs):
  del state, new_state, kwargs
  query, key, value = inputs
  depth = np.shape(query)[-1]
  do_backprop = ct is not None
  # JAX uses the term cotangent (ct) to refer to gradient signals, and
  # vector-Jacobian product (vjp) for back-propagation through a layer.

  def make_mask(N, M, k):  # pylint: disable=invalid-name
    """Constructs a slice of the causal attention mask.

    Args:
      N: number of query positions
      M: number of key positions
      k: position of the initial query element

    Returns:
      N x M mask, where 1.0 indicates that attention is not allowed.
    """
    x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
    y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
    mask = jax.lax.lt(
        (jax.lax.broadcast_in_dim(
            x, shape=(N, M), broadcast_dimensions=(0,)) + k),
        jax.lax.broadcast(y, [N]))
    mask = jax.lax.convert_element_type(mask, np.float32)
    return mask

  def make_self_mask(N, M, k):  # pylint: disable=invalid-name
    """Masks out elements attending to self.

    Args:
      N: number of query positions
      M: number of key positions
      k: position of the initial query element

    Returns:
      N x M mask, where 1.0 indicates that attention is not allowed.
    """
    x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
    y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
    mask = jax.lax.eq(
        (jax.lax.broadcast_in_dim(
            x, shape=(N, M), broadcast_dimensions=(0,)) + k),
        jax.lax.broadcast(y, [N]))
    mask = jax.lax.convert_element_type(mask, np.float32)
    return mask

  def forward_slice(query_slice, q_loop_idx, key, value):  # pylint: disable=invalid-name
    """Forward pass for a subset of the query vectors."""
    if self._share_qk:
      key = self.make_unit_length(key)

    dots = np.matmul(
        query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth)

    # Causal masking.
    mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    if self._share_qk:
      self_mask = make_self_mask(dots.shape[-2], dots.shape[-1], q_loop_idx)
      dots = dots - 1e5 * self_mask

    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

    if self.dropout is not None and self.dropout > 0.0:
      # Dropout is broadcast across the batch+head dimension.
      dropout_shape = (1, dots.shape[-2], dots.shape[-1])
      slice_rng = jax.random.fold_in(rng, q_loop_idx)
      keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
      keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape)
      multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
      dots = dots * multiplier

    if self._hard_k > 0:
      top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
      top_k = jax.lax.stop_gradient(top_k)
      dots -= top_k[..., np.newaxis]  # Subtract so smaller weights become <= 0.
      dots = np.maximum(dots, 0)
      dots_sum = np.sum(dots, axis=-1, keepdims=True)
      dots /= dots_sum  # Re-normalize.

    out_slice = np.matmul(dots, value)
    return out_slice

  def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice):  # pylint: disable=invalid-name
    # Capture q_loop_idx to avoid computing gradients with respect to it.
    def forward_slice_with_q_loop_idx(query_slice, key, value):  # pylint: disable=invalid-name
      return forward_slice(query_slice, q_loop_idx, key, value)

    output_slice, vjpfun = jax.vjp(
        forward_slice_with_q_loop_idx, query_slice, key, value)
    return output_slice, vjpfun(ct_slice)

  q_loop_idx = np.zeros((), dtype=np.int32)
  q_loop_max = query.shape[-2]
  q_loop_stride = self._loop_stride
  if q_loop_max == 1:  # For abstract runs with unknown shapes.
    q_loop_stride = 1
  assert q_loop_max % q_loop_stride == 0, (
      'Stride must evenly divide the number of query elements.')

  out_accum = np.zeros_like(query)
  if do_backprop:
    query_ct_accum = np.zeros_like(query)
    key_ct_accum = np.zeros_like(key)
    value_ct_accum = np.zeros_like(value)
    init_vals = (q_loop_idx, out_accum,
                 query_ct_accum, key_ct_accum, value_ct_accum)
  else:
    init_vals = (q_loop_idx, out_accum)

  def cond_fun(vals):  # pylint: disable=invalid-name
    q_loop_idx = vals[0]
    return jax.lax.lt(q_loop_idx, q_loop_max)

  def body_fun(vals):  # pylint: disable=invalid-name
    """Compute a slice of the attention mechanism."""
    if do_backprop:
      (q_loop_idx, out_accum,
       query_ct_accum, key_ct_accum, value_ct_accum) = vals
    else:
      q_loop_idx, out_accum = vals

    query_slice = jax.lax.dynamic_slice_in_dim(
        query, q_loop_idx, q_loop_stride, axis=-2)

    if do_backprop:
      ct_slice = jax.lax.dynamic_slice_in_dim(
          ct, q_loop_idx, q_loop_stride, axis=-2)
      out_slice, partial_ct = forward_and_vjp_slice(
          query_slice, q_loop_idx, key, value, ct_slice)
      query_ct_accum = jax.lax.dynamic_update_slice_in_dim(
          query_ct_accum, partial_ct[0], q_loop_idx, axis=-2)
      key_ct_accum = key_ct_accum + partial_ct[1]
      value_ct_accum = value_ct_accum + partial_ct[2]
    else:
      out_slice = forward_slice(query_slice, q_loop_idx, key, value)

    out_accum = jax.lax.dynamic_update_slice_in_dim(
        out_accum, out_slice, q_loop_idx, axis=-2)
    q_loop_idx = q_loop_idx + q_loop_stride

    if do_backprop:
      return (q_loop_idx, out_accum,
              query_ct_accum, key_ct_accum, value_ct_accum)
    else:
      return (q_loop_idx, out_accum)

  final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals)

  if not do_backprop:
    return final_vals[1], None
  else:
    return final_vals[1], final_vals[2:]

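
# NumPy sketch of the chunking strategy implemented by the while_loop above:
# queries are processed `stride` rows at a time against the full key/value
# tensors, and the per-slice outputs are written back into an accumulator,
# which bounds the size of the attention matrix held in memory at once. This
# omits the causal mask, dropout, and the backward pass; the stride of 2 and
# tensor sizes are illustrative.
import numpy as onp

def attention_full(q, k, v):
  dots = onp.matmul(q, k.T) / onp.sqrt(q.shape[-1])
  dots = onp.exp(dots - onp.max(dots, axis=-1, keepdims=True))
  dots /= onp.sum(dots, axis=-1, keepdims=True)
  return onp.matmul(dots, v)

q_ex = onp.random.randn(8, 4)
k_ex = onp.random.randn(8, 4)
v_ex = onp.random.randn(8, 4)

stride = 2
out_accum_ex = onp.zeros_like(q_ex)
for idx in range(0, q_ex.shape[0], stride):
  # Each slice attends to all keys/values; softmax is row-wise, so the
  # slice-by-slice result matches the full computation.
  out_accum_ex[idx:idx + stride] = attention_full(q_ex[idx:idx + stride], k_ex, v_ex)

assert onp.allclose(out_accum_ex, attention_full(q_ex, k_ex, v_ex))
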