def DotProductAttention(query, key, value, mask, dropout, mode, rng):
  """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
  depth = np.shape(query)[-1]
  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if math.backend_name() == 'jax':
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
  # Softmax.
  dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
  out = np.matmul(dots, value)
  return out
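# Hedged usage sketch, not part of the original module: it assumes the file's
# usual backend imports are in scope (`np` as the Trax math-backend numpy and
# `math` as the Trax math module, which the function above already requires).
# The helper name below is illustrative only.
def _dot_product_attention_example():
  """Runs DotProductAttention on toy tensors with dropout and mask disabled."""
  import jax
  q = k = v = np.ones((2, 4, 10, 16))  # (nbatch, n_heads, seqlen, d_head)
  rng = jax.random.PRNGKey(0)          # unused because dropout == 0.0
  out = DotProductAttention(q, k, v, mask=None,
                            dropout=0.0, mode='eval', rng=rng)
  assert out.shape == (2, 4, 10, 16)
  return out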
def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                       state=base.EMPTY_STATE, rng=None):
  if self._mode in ('train', 'eval'):
    x = inputs
    symbol_size = jnp.shape(x)[1]
    px = weights[:, :symbol_size, :]
    if self._dropout == 0:
      return (x + px, state)
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if math.backend_name() == 'jax':
        keep_prob = jax.lax.tie_in(x, jnp.full((), keep_prob, dtype=x.dtype))
      keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return (x + px * multiplier, state)
  else:
    assert self._mode == 'predict'
    assert self._dropout == 0
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    if inputs.shape[1] == 1:
      return (inputs + jnp.expand_dims(weights[0, state, :], 1), state + 1)
    else:
      emb = []
      for i in range(inputs.shape[0]):
        emb.append(jax.lax.dynamic_slice_in_dim(
            weights[0], state[i], inputs.shape[1], axis=0))
      return inputs + jnp.stack(emb, 0), state + inputs.shape[1]
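# Hedged illustration, not part of the original layer: in the 'predict' branch
# above, `state` appears to hold one position index per batch element (the
# multi-token branch indexes `state[i]`), and each single-token call adds the
# matching row of the positional table and advances the counter. A standalone
# sketch of that bookkeeping with plain jax.numpy; the helper name is made up.
def _single_token_positional_step(token_emb, table, state):
  """token_emb: (batch, 1, d); table: (max_len, d); state: (batch,) int32."""
  import jax.numpy as jnp
  out = token_emb + jnp.expand_dims(table[state, :], 1)  # (batch, 1, d)
  return out, state + 1                                   # mirrors `state + 1`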
def forward_with_state(self, x, weights, state, rng):
  """Pure transformer-style multi-headed attention.

  Args:
    x: inputs (q, k, v, mask)
    weights: parameters (none)
    state: parameters (none)
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Pure Multi-headed attention result, and the mask.
  """
  del weights
  n_heads, dropout, mode = self._n_heads, self._dropout, self._mode
  q, k, v, mask = x
  d_feature = q.shape[-1]
  assert d_feature % n_heads == 0
  d_head = d_feature // n_heads
  nbatch = jnp.shape(q)[0]

  # nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head
  def SplitHeads(x):
    return jnp.transpose(
        jnp.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))

  # nbatch, n_heads, seqlen, d_head --> nbatch, seqlen, d_feature
  def JoinHeads(x):  # pylint: disable=invalid-name
    return jnp.reshape(
        jnp.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))

  # Split heads, dot-product attention, rejoin heads.
  res = JoinHeads(
      DotProductAttention(
          SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,
          dropout=dropout, mode=mode, rng=rng))
  return (res, mask), state  # Keep the mask.
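# Hedged shape check, not part of the original layer: SplitHeads/JoinHeads
# above are a reshape/transpose round trip. The standalone sketch below
# re-creates both with plain jax.numpy on made-up sizes and checks that
# joining undoes splitting.
def _split_join_roundtrip_example():
  import jax.numpy as jnp
  nbatch, seqlen, n_heads, d_head = 2, 5, 4, 8
  x = jnp.arange(nbatch * seqlen * n_heads * d_head, dtype=jnp.float32)
  x = x.reshape(nbatch, seqlen, n_heads * d_head)
  split = jnp.transpose(x.reshape(nbatch, -1, n_heads, d_head), (0, 2, 1, 3))
  joined = jnp.reshape(
      jnp.transpose(split, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))
  assert jnp.array_equal(joined, x)
  return joined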
def forward(self, inputs, weights):
  gamma, beta, epsilon_l = weights

  epsilon = self._init_epsilon
  if epsilon_l is not base.EMPTY_WEIGHTS:
    epsilon += np.abs(epsilon_l[0])

  # Omit B and C
  axis = tuple(range(1, len(np.shape(inputs)) - 1))
  # (B, 1, 1, C)
  nu2 = np.mean(inputs**2, axis=axis, keepdims=True)
  # (B, W, H, C)
  xhat = inputs / np.sqrt(nu2 + epsilon)
  return gamma * xhat + beta
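# Hedged numeric sketch, not part of the original layer: the forward pass
# above normalizes by the root mean square over the spatial axes (omitting
# batch and channels), then applies a learned scale and shift. A toy
# (B, W, H, C) version with plain jax.numpy and a fixed epsilon:
def _spatial_rms_norm_example():
  import jax.numpy as jnp
  inputs = jnp.ones((2, 4, 4, 3))                      # (B, W, H, C)
  axis = tuple(range(1, inputs.ndim - 1))              # omit B and C
  nu2 = jnp.mean(inputs**2, axis=axis, keepdims=True)  # (B, 1, 1, C)
  xhat = inputs / jnp.sqrt(nu2 + 1e-6)                 # (B, W, H, C)
  return xhat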
def forward(self, inputs):
  if self._mode != 'predict':
    x = inputs
    symbol_size = jnp.shape(x)[1]
    px = self.weights[:, :symbol_size, :]
    if self._dropout == 0:
      return x + px
    else:
      noise_shape = list(px.shape)
      for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
      keep_prob = 1.0 - self._dropout
      if math.backend_name() == 'jax':
        keep_prob = jax.lax.tie_in(
            x, jnp.full((), keep_prob, dtype=x.dtype))
      keep = math.random.bernoulli(self.rng, keep_prob, tuple(noise_shape))
      multiplier = keep.astype(x.dtype) / keep_prob
      return x + px * multiplier
  else:
    if self._dropout != 0:
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')

    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    state = self.state
    if inputs.shape[1] == 1:
      self.state = state + 1
      return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
    else:
      emb = []
      for i in range(inputs.shape[0]):
        emb.append(
            jax.lax.dynamic_slice_in_dim(self.weights[0], state[i],
                                         inputs.shape[1], axis=0))
      self.state = state + inputs.shape[1]
      return inputs + jnp.stack(emb, 0)
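# Hedged illustration, not part of the original layer: the multi-token
# 'predict' branch above lets each batch element resume from its own stored
# position, slicing the embedding table per example with dynamic_slice_in_dim.
# A standalone sketch with plain JAX; the helper name is made up.
def _batched_predict_slice_example():
  import jax
  import jax.numpy as jnp
  table = jnp.arange(10 * 3, dtype=jnp.float32).reshape(10, 3)  # (max_len, d)
  state = jnp.array([0, 4])        # per-example start positions
  seq_len = 2
  emb = [jax.lax.dynamic_slice_in_dim(table, state[i], seq_len, axis=0)
         for i in range(state.shape[0])]
  return jnp.stack(emb, 0)         # (batch, seq_len, d)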
def forward_with_state(self, x, weights=layer_base.EMPTY_WEIGHTS,
                       state=layer_base.EMPTY_STATE, rng=None, **kwargs):
  length = np.shape(x)[1]
  max_pos = self._base**self._n_digits
  assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
  positions = np.arange(0, length)
  if self._mode == 'train':
    positions += jax.random.randint(rng, (), 0, max_pos - length)
  pos_embeddings = []
  cur_positions = positions
  for i in range(self._n_digits):
    cur_indices = np.mod(cur_positions, self._base)
    cur_positions //= self._base
    pos_embeddings.append(np.take(weights[i], cur_indices, axis=0))
  embeddings = np.concatenate(pos_embeddings, axis=-1)
  return (x + embeddings[None, :, :], state)
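# Hedged worked example, not part of the original layer: the loop above writes
# each position as `n_digits` base-`base` digits (least significant first) and
# embeds every digit separately. The digit decomposition alone, with plain
# jax.numpy:
def _fixed_base_digits_example(base=8, n_digits=3):
  import jax.numpy as jnp
  positions = jnp.arange(0, 5)
  digits = []
  cur = positions
  for _ in range(n_digits):
    digits.append(jnp.mod(cur, base))  # current least-significant digit
    cur = cur // base
  return jnp.stack(digits, axis=-1)    # (length, n_digits)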
def NewPositionalEncoding(x, positions=None):
  """Implements new positional encoding."""
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  return pos
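# Hedged usage sketch, not part of the original module: it assumes `np` is the
# module's Trax math-backend numpy (jax.numpy under the JAX backend). Given a
# (max_length, d) table of positions, the function above trims it to the input
# length and broadcasts it across the batch dimension.
def _new_positional_encoding_example():
  x = np.zeros((2, 3, 4))                                       # (batch, length, d)
  positions = np.arange(6 * 4, dtype=np.float32).reshape(6, 4)  # (max_length, d)
  pos = NewPositionalEncoding(x, positions=positions)
  assert np.shape(pos) == (2, 3, 4)
  return pos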