def _do_custom_gradients(self, x, weights, state, rng):
  """Calls this layer for a forward pass, but with custom gradients.

  The forward pass is `self.forward(y, weights)`; its VJP is
  `self.backward`. `self._weights`, `self._state` and `self._rng` are saved
  before and restored after every forward evaluation (also when it raises),
  so tracing does not leave stray values on the layer.

  Args:
    x: Layer input.
    weights: Layer weights; gradients are computed w.r.t. `x` and `weights`.
    state: Layer state at the start of the call; handed to `self.backward`.
    rng: Handed to `self.backward`.

  Returns:
    `(output, new_state)`, where `new_state` is the value of `self._state`
    observed right after `self.forward`, with gradients stopped.
  """
  assert math.backend_name() == 'jax', (
      'Custom gradients are only supported in JAX for now.')

  # `vjpfun` below runs during the backward pass, i.e. after this method has
  # returned, so it must read the start state through a name that is never
  # rebound in this scope.
  start_state = state

  def _forward_restoring_attrs(y, weights):
    """Runs self.forward; returns (output, new_state); restores attributes."""
    old_weights, old_state, old_rng = self._weights, self._state, self._rng
    try:
      output = self.forward(y, weights)
      return output, self._state
    finally:
      self._weights, self._state, self._rng = old_weights, old_state, old_rng

  # See this link for how custom transformations are defined in JAX:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
  # NOTE(review): jax.custom_transforms / jax.defvjp_all are JAX's legacy
  # custom-derivative API (superseded by jax.custom_vjp in newer releases);
  # confirm the pinned JAX version still provides them.
  @jax.custom_transforms
  def _do_forward(y, weights):
    return _forward_restoring_attrs(y, weights)

  # This is the custom gradient (vector-jacobian product in JAX) function.
  # For the exact specification of this custom transformation see this link:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
  def do_forward_vjp(y, weights):
    """Custom gradient (vjp) function."""
    output, new_state = _forward_restoring_attrs(y, weights)

    def vjpfun(grad):
      grad = grad[0]  # Ignore dummy gradient wrt state.
      return self.backward(y, output, grad, weights, start_state, new_state,
                           rng)

    return (output, new_state), vjpfun

  jax.defvjp_all(_do_forward, do_forward_vjp)
  output, new_state = _do_forward(x, weights)
  return output, jax.lax.stop_gradient(new_state)
def make_ito_integrate(flat_f, flat_g, ts, dt, bm, method='milstein'):
  """Builds a differentiable Ito integrator over flat states and parameters.

  `flat_f` and `flat_g` each take *flat* states and parameters and return
  *flat* outputs.

  Args:
    flat_f: Passed to `ito_integrate` and `vjp_ito_integrate` as the first
      function (`f=`).
    flat_g: Passed as the second function (`g=`); also used to build the
      jitted Milstein helpers `flat_g_prod` and `flat_gdg`.
    ts: Time points. The backward pass walks consecutive pairs
      `(ts[i - 1], ts[i])`.
    dt: Step size, forwarded unchanged.
    bm: Forwarded unchanged as `bm=` (presumably a Brownian-motion sample --
      confirm against `ito_integrate`).
    method: Integration scheme name, forwarded unchanged.

  Returns:
    `ito_integrate_(flat_y0, flat_args)`, which returns the solution at every
    time in `ts` (shape `(T, batch_size * D)` per the comments below). It is
    differentiable w.r.t. both arguments via the adjoint loop in `_bwd`.
  """
  # Make fast jitted helper functions for Milstein correction.
  @jax.jit
  def flat_g_prod(flat_y, t, args, noise):
    return flat_g(flat_y, t, args) * noise

  flat_gdg = jax.jit(make_gdg_prod(flat_g_prod))

  def _solve(flat_y0, flat_args):
    """Forward solve shared by the primal and the VJP-forward passes."""
    return ito_integrate(flat_f, flat_g, flat_y0, ts, bm, dt, args=flat_args,
                         g_prod=flat_g_prod, gdg=flat_gdg,
                         method=method)  # (T, batch_size * D)

  @jax.custom_vjp
  def ito_integrate_(flat_y0, flat_args):
    return _solve(flat_y0, flat_args)

  def _fwd(flat_y0, flat_args):
    ans = _solve(flat_y0, flat_args)
    # custom_vjp residuals must be arrays (not closures), so keep exactly what
    # the backward pass reads.
    return ans, (ans, flat_args)

  def _bwd(residuals, cotan):
    ans, flat_args = residuals
    T, _ = cotan.shape
    v_flat_y, v_flat_args = cotan[-1, :], np.zeros_like(flat_args)
    # Walk the time grid backwards, one interval (ts[i - 1], ts[i]) at a time.
    for i in range(T - 1, 0, -1):
      ts_local = np.array([ts[i - 1], ts[i]])
      _, (v_flat_y, v_flat_args) = vjp_ito_integrate(
          v_yt=v_flat_y, v_argst=v_flat_args, yt=ans[i, :], f=flat_f,
          g=flat_g, ts=ts_local, bm=bm, dt=dt, args=flat_args, method=method)
      # The output at ts[i - 1] also receives its own cotangent directly.
      v_flat_y = v_flat_y + cotan[i - 1, :]
    return v_flat_y, v_flat_args

  ito_integrate_.defvjp(_fwd, _bwd)
  return ito_integrate_
def test_custom_transforms_vjp_nones(self):
  """Smoke test: `grad` through a `defvjp_all` rule must not crash.

  The custom VJP for `solve` returns an explicit all-zero cotangent
  (`b * 0.`) for `b`. The test only evaluates the gradient and prints it;
  it asserts nothing.
  """
  # Issue raised by jsnoek@ and jumper@.
  @jax.custom_transforms
  def solve(a, b):
    return np.dot(np.linalg.inv(a), b)

  def solve_vjp(a, b):
    x = solve(a, b)

    def vjp(x_tangent):
      # One cotangent per input of `solve`: (a, b); `b` gets zeros.
      dx = np.dot(solve(a, x_tangent), x.T)
      out = (dx, b * 0.)
      return out

    return x, vjp

  jax.defvjp_all(solve, solve_vjp)
  # `grad` differentiates w.r.t. the first argument (`a`) only by default.
  gf = grad(lambda a, b: np.sum(solve(a, b)))

  n = 3
  a_in = np.linspace(0, 1, n)[:, None]
  # Rank-1 PSD matrix plus 0.1 * I, so `inv(a)` is well defined.
  a = np.dot(a_in, a_in.T) + np.eye(n) * 0.1
  real_x = onp.random.RandomState(0).randn(n)
  b = np.dot(a + np.eye(a.shape[0]), real_x)
  print(gf(a, b))  # doesn't crash
def _custom_gradient(f):
  """Jax implementation of tf.custom_gradient.

  Args:
    f: Function returning `(value, vjp)`, where `vjp(cts_out)` returns the
      cotangent(s) for the positional inputs of `f`. The cotangents may be a
      single value, a list or a tuple.

  Returns:
    A function returning `value`. Under JAX its gradient is given by `vjp`.
    In the NumPy backend custom gradients are ignored and only `value` is
    returned. Under JAX, pass inputs positionally: JAX's custom-derivative
    wrappers cannot map `**kwargs` onto argument positions.
  """
  if not JAX_MODE:
    # Numpy backend ignores custom gradients, so we do too.
    return lambda *args, **kwargs: f(*args, **kwargs)[0]

  import jax  # pylint: disable=g-import-not-at-top

  def f_(*args, **kwargs):
    value, vjp = f(*args, **kwargs)

    def vjp_(cts_in_out):
      cts_in = vjp(cts_in_out)
      # JAX expects one cotangent per positional input, as a tuple.
      if isinstance(cts_in, list):
        cts_in = tuple(cts_in)
      elif not isinstance(cts_in, tuple):
        cts_in = (cts_in,)
      return cts_in

    return value, vjp_

  # jax.custom_gradient mirrors tf.custom_gradient's (value, vjp) contract and
  # replaces the removed jax.custom_transforms / jax.defvjp_all pair.
  return jax.custom_gradient(f_)
def _custom_gradient(f):
  """JAX implementation of tf.custom_gradient.

  Args:
    f: Function returning `(value, vjp)`, where `vjp(cts_out)` returns the
      cotangent(s) for the positional inputs of `f`. The cotangents may be a
      single value, a list or a tuple.

  Returns:
    A function returning `value`. Under JAX its gradient is given by `vjp`.
    In the NumPy backend custom gradients are ignored and only `value` is
    returned. Under JAX, pass inputs positionally: JAX's custom-derivative
    wrappers cannot map `**kwargs` onto argument positions.
  """
  if not JAX_MODE:
    # Numpy backend ignores custom gradients, so we do too.
    return lambda *args, **kwargs: f(*args, **kwargs)[0]

  def f_(*args, **kwargs):
    value, vjp = f(*args, **kwargs)

    def vjp_(cts_out):
      cts_in = vjp(cts_out)
      # JAX expects one cotangent per positional input, as a tuple.
      if isinstance(cts_in, list):
        cts_in = tuple(cts_in)
      elif not isinstance(cts_in, tuple):
        cts_in = (cts_in,)
      return cts_in

    return value, vjp_

  # jax.custom_gradient mirrors tf.custom_gradient's (value, vjp) contract and
  # replaces the removed jax.custom_transforms / jax.defvjp_all pair.
  return jax.custom_gradient(f_)
def __call__(self, x, params=(), **kwargs):
  """Runs the parent layer with a gradient supplied by `inverse_and_vjp`.

  The forward pass is the parent class's `__call__` (reached through
  `super(ReversibleLayerMixin, self)`). The backward pass calls
  `self.inverse_and_vjp(output, ct, params, **kwargs)` and returns its second
  result as the input cotangent. `vjpfun` references `output`, not `x`
  (presumably so the input can be recomputed from the output instead of
  stored -- confirm).

  Raises:
    NotImplementedError: if `params` is the default `()` while
      `self._params` is already set (parameter sharing).
  """
  assert backend.get_name() == 'jax', (
      'Reversible layers are only supported in JAX')

  if params is () and self._params:  # pylint: disable=literal-comparison
    # TODO(kitaev): Figure out why parameter sharing doesn't work (if this
    # explicit error isn't thrown, a jax tracer error occurs instead)
    raise NotImplementedError(
        'Parameter sharing between reversible layers is not implemented.'
    )

  # NOTE(review): jax.custom_transforms / jax.defvjp_all are JAX's legacy
  # custom-derivative API; confirm the pinned JAX version still has them.
  # `kwargs` is passed to do_call as a positional dict argument; the value
  # returned by vjpfun must match the structure (x, params, kwargs) -- confirm
  # that inverse_and_vjp's second result has that structure.
  @jax.custom_transforms
  def do_call(x, params, kwargs):
    # Two-argument super() is needed: these are nested functions, where the
    # zero-argument form is not available.
    return super(ReversibleLayerMixin, self).__call__(x, params, **kwargs)

  def do_call_vjp(x, params, kwargs):
    output = super(ReversibleLayerMixin, self).__call__(x, params, **kwargs)

    def vjpfun(ct):
      _, input_ct = self.inverse_and_vjp(output, ct, params, **kwargs)
      return input_ct

    return output, vjpfun

  jax.defvjp_all(do_call, do_call_vjp)
  return do_call(x, params, kwargs)
def build_odeint(ofunc, rtol=1.4e-8, atol=1.4e-8):
  """Return `f(y0, t, args) = odeint(ofunc(y, t, *args), y0, t, args)`.

  Given the function ofunc(y, t, *args), return the jitted function
  `f(y0, t, args) = odeint(ofunc(y, t, *args), y0, t, args)` with
  the VJP of `f` defined using `vjp_odeint`, where:

    `y0` is the initial condition of the ODE integration,
    `t` is the time course of the integration, and
    `*args` are all other arguments to `ofunc`.

  Note: this relies on JAX's legacy `jax.custom_transforms` /
  `jax.defvjp_all` API. Confirm that the pinned JAX version still provides
  it; newer releases use `jax.custom_vjp`.

  Args:
    ofunc: The function to be wrapped into an ODE integration.
    rtol: relative local error tolerance for solver.
    atol: absolute local error tolerance for solver.

  Returns:
    `f(y0, t, args) = odeint(ofunc(y, t, *args), y0, t, args)`
  """
  # Primal computation: plain `odeint` with the configured tolerances.
  ct_odeint = jax.custom_transforms(
      lambda y0, t, *args: odeint(ofunc, y0, t, *args, rtol=rtol, atol=atol))

  # VJP rule: per the defvjp_all contract, `vjp_odeint` must return the
  # solution together with a vector-Jacobian-product function.
  v = lambda y0, t, *args: vjp_odeint(ofunc, y0, t, *args, rtol=rtol,
                                      atol=atol)
  jax.defvjp_all(ct_odeint, v)

  return jax.jit(ct_odeint)
def permute_via_sort(val, keys, inverse_keys, axis=0):
  """Permutation helper for LSH attention.

  Reorders `val` along `axis` by sorting it with `keys`. The gradient is
  computed with a second sort keyed by `inverse_keys` instead of the gather or
  scatter JAX would derive.

  Args:
    val: Array to permute; must have the same shape as `keys`, as required by
      `jax.lax.sort_key_val`.
    keys: Sort keys defining the permutation; no gradient flows through them.
    inverse_keys: Keys whose sort undoes the permutation defined by `keys`;
      used only in the backward pass.
    axis: Axis along which to sort.

  Returns:
    `val` sorted along `axis` by `keys`.
  """
  # The key arrays are passed to the custom-VJP function as explicit
  # arguments (with no cotangent) instead of being closed over. Closing over
  # values was the suspected cause of failures with an earlier custom_vjp
  # version (https://github.com/google/jax/issues/2676).
  keys = jax.lax.stop_gradient(keys)
  inverse_keys = jax.lax.stop_gradient(inverse_keys)

  def _sort_by(sort_keys, values):
    # On TPU, sorting scalars by key is faster than a gather.
    _, permuted = jax.lax.sort_key_val(sort_keys, values, dimension=axis)
    return permuted

  @jax.custom_vjp
  def permute(val, keys, inverse_keys):
    del inverse_keys  # Only needed by the backward pass.
    return _sort_by(keys, val)

  def permute_fwd(val, keys, inverse_keys):
    return _sort_by(keys, val), inverse_keys

  def permute_bwd(inverse_keys, permuted_grad):
    # `None` marks "no cotangent" for the two key arrays.
    return _sort_by(inverse_keys, permuted_grad), None, None

  permute.defvjp(permute_fwd, permute_bwd)
  return permute(val, keys, inverse_keys)
def permute_via_gather(val, permutation, inverse_permutation, axis=0):
  """Permutation helper for LSH attention.

  Gathers `val` along `axis` with `permutation`. The gradient is another
  gather with `inverse_permutation`, not the scatter JAX would synthesize.

  Args:
    val: Array to permute.
    permutation: Integer indices along `axis`; no gradient flows through it.
    inverse_permutation: Indices undoing `permutation`, so that
      `take(take(v, permutation), inverse_permutation) == v`. Used only in the
      backward pass.
    axis: Axis along which to permute.

  Returns:
    `jnp.take(val, permutation, axis=axis)`.
  """
  # The index arrays are passed to the custom-VJP function as explicit
  # arguments (with no cotangent) instead of being closed over. Closing over
  # values was the suspected cause of failures with an earlier custom_vjp
  # version (https://github.com/google/jax/issues/2676), which only showed up
  # in some configurations (e.g. use_python_loop = True,
  # num_parallel_heads = 1).
  permutation = jax.lax.stop_gradient(permutation)
  inverse_permutation = jax.lax.stop_gradient(inverse_permutation)

  @jax.custom_vjp
  def permute(val, permutation, inverse_permutation):
    del inverse_permutation  # Only needed by the backward pass.
    return jnp.take(val, permutation, axis=axis)

  def permute_fwd(val, permutation, inverse_permutation):
    return jnp.take(val, permutation, axis=axis), inverse_permutation

  def permute_bwd(inverse_permutation, permuted_grad):
    # JAX autodiff would synthesize a scatter operation because it doesn't
    # know that the indices are a permutation. However on TPU, gathers are
    # faster than scatters (at least in the regime the LSH attention uses).
    # `None` marks "no cotangent" for the two index arrays.
    return (jnp.take(permuted_grad, inverse_permutation, axis=axis), None,
            None)

  permute.defvjp(permute_fwd, permute_bwd)
  return permute(val, permutation, inverse_permutation)
def _do_custom_gradients(self, x, weights, state, **kwargs):
  """Calls this layer for a forward pass, but with custom gradients.

  The forward pass is `self.forward_with_state`; its VJP is `self.backward`.
  During the forward pass a fresh dict is installed as `Layer._STASH_IN`
  (only if no stash is active). During the backward pass the same dict is
  exposed as `Layer._STASH_OUT`.

  Args:
    x: Layer input.
    weights: Layer weights; gradients are taken w.r.t. `x` and `weights`.
    state: Layer state; must have no leaves (asserted).
    **kwargs: Forwarded to `forward_with_state` and `backward`; no gradients
      are computed for them.

  Returns:
    `(output, state)` with the (trivial) input `state` passed through.
  """
  assert backend.get_name() == 'jax', (
      'Custom gradients are only supported in JAX for now.')

  # TODO(wangpeng): JAX doesn't support custom grads for functions with
  #   auxiliary output yet (https://github.com/google/jax/issues/844). Will
  #   remove the constraints on state below when this feature is added to
  #   JAX.

  assert not jax.tree_util.tree_leaves(state), (
      'Custom gradients require trivial start state. Got %s' % str(state))

  def check_end_state(output_state):
    output, end_state = output_state
    assert not jax.tree_util.tree_leaves(end_state), (
        'Custom gradients require trivial end state. Got %s' % str(end_state))
    return output

  # See this link for how custom transformations are defined in JAX:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
  # Note that we capture the kwargs and don't calculate gradients wrt. them.
  # NOTE(review): jax.custom_transforms / jax.defvjp_all are JAX's legacy
  # custom-derivative API; confirm the pinned JAX version still has them.
  @jax.custom_transforms
  def _do_forward(y, weights):
    return check_end_state(
        self.forward_with_state(y, weights=weights, state=state, **kwargs))

  # This is the custom gradient (vector-jacobian product in JAX) function.
  # For the exact specification of this custom transformation see this link:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
  def do_forward_vjp(y, weights):
    """Custom gradient (vjp) function."""
    # Install a new stash only when none is active; otherwise `stash` stays
    # None and the active one is left in place.
    stash = None
    if Layer._STASH_IN is None:
      Layer._STASH_IN = stash = {}
    try:
      output = check_end_state(
          self.forward_with_state(y, weights=weights, state=state, **kwargs))
    finally:
      # Uninstall even if the forward pass raises; otherwise the class-level
      # slot would stay occupied for all later calls.
      if stash is not None:
        Layer._STASH_IN = None

    def vjpfun(grad):
      assert Layer._STASH_OUT is None
      Layer._STASH_OUT = stash
      try:
        return self.backward(y, output, grad, weights, state, **kwargs)
      finally:
        # Clear even if backward raises, so the assert above does not fire
        # spuriously on the next backward pass.
        Layer._STASH_OUT = None

    return output, vjpfun

  jax.defvjp_all(_do_forward, do_forward_vjp)
  return _do_forward(x, weights), state
def __call__(self, x, params=(), state=(), **kwargs):
  """Calls the layer, optionally through a custom-gradient wrapper.

  If `params` is `()`, the cached `self._params` are used; otherwise the given
  params are cached. Without `self.has_custom_grad` this returns
  `self.call(...)` directly. With a custom gradient, the forward pass runs
  through a `jax.custom_transforms` function whose VJP is
  `self.custom_grad`. In that case both the incoming and the produced state
  must be `()`.

  Returns:
    Whatever `self.call` returns, or `(output, ())` on the custom-gradient
    path.

  Raises:
    LayerError: wraps any exception raised inside (including the state
      assertions below).
  """
  try:
    # If params are nothing, we may be reusing this layer.
    # Use the cached parameters to calculate the value.
    # Note: to make sure jit tracers can decide this branch in python we
    #   use "params is ()" instead of, e.g., "not params" or "params == ()".
    if params is ():  # pylint: disable=literal-comparison
      params = self._params
    else:
      # In this case, we're called for the first time: cache parameters.
      self._params = params

    if not self.has_custom_grad:
      return self.call(x, params=params, state=state, **kwargs)

    # TODO(wangpeng): JAX doesn't support custom grads for functions with
    #   auxiliary output yet (https://github.com/google/jax/issues/844). Will
    #   remove the constraints on state below when this feature is added to
    #   JAX.

    assert state is (), (  # pylint: disable=literal-comparison
        'Custom gradients require trivial start state. Got %s' % str(state))

    def check_end_state(output_state):
      # Unpacks (output, state) and insists the produced state is still ().
      output, state = output_state
      assert state is (), (  # pylint: disable=literal-comparison
          'Custom gradients require trivial end state. Got %s' % str(state))
      return output

    # See this link for how custom transformations are defined in JAX:
    # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
    # Note that we capture the kwargs and don't calculate gradients wrt. them.
    # NOTE(review): this is JAX's legacy custom-derivative API; confirm the
    # pinned JAX version still provides it.
    @jax.custom_transforms
    def do_call(y, params):
      return check_end_state(
          self.call(y, params=params, state=(), **kwargs))

    # This is the custom gradient (vector-jacobian product in JAX) function.
    # For the exact specification of this custom transformation see this link:
    # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
    def do_call_vjp(y, params):
      output = check_end_state(
          self.call(y, params=params, state=(), **kwargs))

      def vjpfun(grad):
        return self.custom_grad(y, output, grad, params, **kwargs)

      return output, vjpfun

    jax.defvjp_all(do_call, do_call_vjp)
    return do_call(x, params), ()

  except Exception:
    # Re-raise as LayerError with the layer name and input shapes attached.
    name, trace = self.__class__.__name__, _short_traceback()
    raise LayerError(name, 'call', self._caller, shapes(x), trace)
def permute_via_gather(val, permutation, inverse_permutation, axis=0):
  """Permutation helper for LSH attention.

  Gathers `val` along `axis` with `permutation`. The gradient is another
  gather with `inverse_permutation`, not the scatter JAX would synthesize.

  Args:
    val: Array to permute.
    permutation: Integer indices along `axis`; no gradient flows through it.
    inverse_permutation: Indices undoing `permutation`; used only in the
      backward pass.
    axis: Axis along which to permute.

  Returns:
    `np.take(val, permutation, axis=axis)`.
  """
  # The index arrays are passed to the custom-VJP function as explicit
  # arguments (with no cotangent) rather than closed over, so the rule does
  # not capture traced values from the enclosing scope.
  permutation = jax.lax.stop_gradient(permutation)
  inverse_permutation = jax.lax.stop_gradient(inverse_permutation)

  @jax.custom_vjp
  def permute(val, permutation, inverse_permutation):
    del inverse_permutation  # Only needed by the backward pass.
    return np.take(val, permutation, axis=axis)

  def permute_fwd(val, permutation, inverse_permutation):
    return np.take(val, permutation, axis=axis), inverse_permutation

  def permute_bwd(inverse_permutation, permuted_grad):
    # JAX autodiff would synthesize a scatter operation because it doesn't
    # know that the indices are a permutation. However on TPU, gathers are
    # faster than scatters (at least in the regime the LSH attention uses).
    # `None` marks "no cotangent" for the two index arrays.
    return (np.take(permuted_grad, inverse_permutation, axis=axis), None,
            None)

  permute.defvjp(permute_fwd, permute_bwd)
  return permute(val, permutation, inverse_permutation)
def __call__(self, x, params=(), **kwargs):
  """Calls the layer, optionally through a custom-gradient wrapper.

  If `params` is `()`, the cached `self._params` are used. Without
  `self.has_custom_grad` this returns `self.call(...)` directly. Otherwise
  the call goes through a `jax.custom_transforms` function whose VJP is
  `self.custom_grad`.

  Raises:
    LayerError: wraps any exception raised inside (including the backend
      assertion).
  """
  try:
    # If params are nothing, we may be reusing this layer.
    # Use the cached parameters to calculate the value.
    # Note: to make sure jit tracers can decide this branch in python we
    #   use "params is ()" instead of, e.g., "not params" or "params == ()".
    if params is ():  # pylint: disable=literal-comparison
      params = self._params
    # Cache the parameters. When params were given this stores them; when
    # params was () it just re-stores the already-cached value.
    self._params = params

    if not self.has_custom_grad:
      return self.call(x, params=params, **kwargs)

    # Custom gradients part.
    assert backend.get_name() == 'jax', (
        'Custom gradients are only supported in JAX for now.')

    # See this link for how custom transformations are defined in JAX:
    # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
    # NOTE(review): this is JAX's legacy custom-derivative API; confirm the
    # pinned JAX version still provides it.
    @jax.custom_transforms
    def do_call(y, params, kwargs):
      return self.call(y, params=params, **kwargs)

    # This is the custom gradient (vector-jacobian product in JAX) function.
    # For the exact specification of this custom transformation see this link:
    # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
    # Note that we make arguments positional to allow gradients wrt. them.
    def do_call_vjp(y, params, kwargs):
      output = self.call(y, params=params, **kwargs)

      def vjpfun(grad):
        return self.custom_grad(y, output, grad, params, **kwargs)

      return output, vjpfun

    jax.defvjp_all(do_call, do_call_vjp)
    return do_call(x, params, kwargs)

  except Exception:
    # Re-raise as LayerError with the layer name and input shapes attached.
    name, trace = self.__class__.__name__, _short_traceback()
    raise LayerError(name, 'call', self._caller, shapes(x), trace)
def _do_custom_gradients(self, x, weights, state, **kwargs):
  """Calls this layer for a forward pass, but with custom gradients.

  The forward pass is `self.forward_with_state`; its VJP is `self.backward`.

  Args:
    x: Layer input.
    weights: Layer weights; gradients are taken w.r.t. `x` and `weights`.
    state: Layer state at the start of the call; also handed to
      `self.backward` as the start state.
    **kwargs: Forwarded to `forward_with_state` and `backward`; no gradients
      are computed for them.

  Returns:
    `(output, new_state)` from `forward_with_state`, with gradients stopped
    on `new_state`.
  """
  assert backend.get_name() == 'jax', (
      'Custom gradients are only supported in JAX for now.')

  # See this link for how custom transformations are defined in JAX:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
  # Note that we capture the kwargs and don't calculate gradients wrt. them.
  # NOTE(review): this is JAX's legacy custom-derivative API; confirm the
  # pinned JAX version still provides it.
  @jax.custom_transforms
  def _do_forward(y, weights):
    return self.forward_with_state(y, weights=weights, state=state, **kwargs)

  # This is the custom gradient (vector-jacobian product in JAX) function.
  # For the exact specification of this custom transformation see this link:
  # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
  def do_forward_vjp(y, weights):
    """Custom gradient (vjp) function."""
    output, new_state = self.forward_with_state(
        y, weights=weights, state=state, **kwargs)

    def vjpfun(grad):
      grad = grad[0]  # Ignore dummy gradient wrt state.
      return self.backward(y, output, grad, weights, state, new_state,
                           **kwargs)

    # The primal outputs here must equal what _do_forward returns: the state
    # produced by forward_with_state, not the incoming one.
    return (output, new_state), vjpfun

  jax.defvjp_all(_do_forward, do_forward_vjp)
  # Do not rebind `state`: vjpfun closes over it and only runs later, during
  # the backward pass, where it must still see the start state.
  output, new_state = _do_forward(x, weights)
  return output, jax.lax.stop_gradient(new_state)
x[(x < up_limit) & (x > low_limit)] = np.log( softplus(x[(x < up_limit) & (x > low_limit)])) #x[x<low_limit] = x[x<low_limit] return x def safe_logsoftplus_vjp(ans, x, low_limit=-30): x_shape = x.shape operator = np.ones(x.shape) operator[x > low_limit] = 1 / ( (1 + np.exp(-x[x > low_limit])) * softplus(x[x > low_limit])) return lambda g: np.full(x_shape, g) * operator #return lambda g: np.full(x_shape, g) * 1/ ((1+np.exp(-x))*safe_softplus(x)) jax.defvjp_all(safe_logsoftplus, safe_logsoftplus_vjp) def make_cov(N, rh, len_sc): M1 = np.array([np.arange(N)]) - np.transpose(np.array([np.arange(N)])) K = rh * np.exp(-(np.square(M1) / (2 * np.square(len_sc)))) return K def bbvi(logprob, N, num_samples): """Implements http://arxiv.org/abs/1401.0118, and uses the local reparameterization trick from http://arxiv.org/abs/1506.02557 Structure of function taken from: https://github.com/HIPS/autograd/blob/master/examples/black_box_svi.py inputs:
def forward_unbatched(self, x, *, weights, state, update_state):
  """LSH self-attention forward pass for a single (unbatched) example.

  Projects `x` with `w_q` and `w_v`. Positions are then sorted by
  `(bucket, position)` and `attend` runs on the sorted sequence. The result
  is unsorted. When `self.n_hashes > 1`, the per-hash-round outputs are
  combined with weights `softmax(logits)` over the hash-round axis. Finally
  the output is projected with `w_o`.

  Args:
    x: Input for one example; `x.shape[0]` is used as the sequence length
      (presumably `(seqlen, d_model)` -- confirm).
    weights: `(w_q, w_v, w_o)` projection matrices.
    state: `(buckets, rng)`. When `update_state` is False both are reused.
      When True only the rng is read, and new buckets are computed with
      `self.hash_vectors`.
    update_state: Whether to recompute the hash buckets.

  Returns:
    `(out, state)`. `state` is `(new_buckets, derived_rng)` if
    `update_state`, otherwise the input state unchanged.
  """
  w_q, w_v, w_o = weights

  q = np.matmul(x, w_q)
  v = np.matmul(x, w_v)

  if update_state:
    _, old_rng = state
    rng = jax.random.fold_in(old_rng, 0)
    hash_rng = jax.random.fold_in(rng, 1)
    buckets = self.hash_vectors(q, hash_rng)
    state = (buckets, rng)
  else:
    buckets, rng = state

  # Both branches reach this point with the same `rng` for a given stored
  # state, so a recomputation with update_state=False reproduces the same
  # dropout randomness.
  rng = jax.random.fold_in(rng, 2)

  seqlen = x.shape[0]
  assert int(buckets.shape[0]) == self.n_hashes * seqlen

  ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
  # Composite key: bucket id first, then position within the sequence.
  buckets_and_t = seqlen * buckets + (ticker % seqlen)
  buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

  # Hash-based sort ("s" at the start of variable names means "sorted")
  sbuckets_and_t, sticker = jax.lax.sort_key_val(
      buckets_and_t, ticker, dimension=-1)
  _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
  sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
  sticker = jax.lax.stop_gradient(sticker)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  st = (sticker % seqlen)
  sq = np.take(q, st, axis=0)
  sv = np.take(v, st, axis=0)

  mask_fn = functools.partial(
      mask_self_attention, causal=self.causal, exclude_self=True)
  q_info = st
  so, slogits = attend(
      sq, k=None, v=sv,
      q_chunk_len=self.chunk_len,
      n_chunks_before=self.n_chunks_before,
      n_chunks_after=self.n_chunks_after,
      mask_fn=mask_fn, q_info=q_info,
      dropout=self.attention_dropout, rng=rng,
  )

  def unsort_for_output_impl(so, slogits):
    o = np.take(so, undo_sort, axis=0)
    # Sorting is considerably faster than gather, but first we need to get the
    # XLA compiler to abandon the idea of fusing this sort with the input sort
    # (which introduces a computation cycle and leads to a crash).
    # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
    sticker_ = sticker + jax.lax.convert_element_type(
        slogits[0] > 0, sticker.dtype)
    _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
    return o, logits

  def unsort_for_output_vjp(so, slogits):
    """Custom gradient for unsort_for_output."""
    so = jax.lax.stop_gradient(so)
    slogits = jax.lax.stop_gradient(slogits)
    o, logits = unsort_for_output_impl(so, slogits)

    def vjpfun(o_logits_grads):
      # Re-sort the incoming gradients back into hash-sorted order.
      so_grad = np.take(o_logits_grads[0], sticker, axis=0)
      # TODO(kitaev): this exists to match the forward pass, but I'm not sure
      # if it's actually required.
      buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
          o_logits_grads[1][0] > 0, buckets_and_t.dtype)
      _, slogits_grad = jax.lax.sort_key_val(
          buckets_and_t_, o_logits_grads[1], dimension=-1)
      return (so_grad, slogits_grad)

    return (o, logits), vjpfun

  unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
  jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
  # NOTE(review): the line below calls unsort_for_output_impl, not
  # unsort_for_output, so the custom VJP registered above is never used and
  # JAX differentiates the impl directly. Confirm whether this is intentional.
  o, logits = unsort_for_output_impl(so, slogits)

  if self.n_hashes > 1:
    # Combine hash rounds: softmax over the round axis of the per-round
    # logsumexp values, then a weighted sum of the per-round outputs.
    o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
    logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
    probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
    o = np.sum(o * probs, axis=0)

  assert o.shape == (seqlen, w_v.shape[-1])
  out = np.matmul(o, w_o)
  return out, state
def single_call(self, qk, v, buckets, rng=None):
  """LSH attention for one sequence, with queries and keys sharing `qk`.

  Positions are sorted by `(bucket, position)` and split into
  `self.n_hashes * self.n_bins` chunks. Each chunk attends to itself and to
  the previous chunk. Several masks are applied:
    - causal masking;
    - a soft self-mask;
    - an optional cross-bucket mask;
    - optional de-duplication across hash rounds;
    - an optional top-k mask.
  Outputs are then unsorted and combined across hash rounds.

  Args:
    qk: Shared query/key vectors. `qk.shape[-2]` is the sequence length and
      rows are gathered with `np.take(..., axis=0)`, so presumably
      `(seqlen, d)` -- confirm.
    v: Value vectors, gathered the same way; the output has `v.shape`.
    buckets: Hash bucket ids for all rounds, length `self.n_hashes * seqlen`
      (asserted).
    rng: Used only for attention dropout when `self._dropout > 0`.

  Returns:
    Attention output with the same shape as `v` (asserted).
  """
  # We use the same vector as both a query and a key.
  seqlen = qk.shape[-2]
  assert int(buckets.shape[0]) == self.n_hashes * seqlen

  ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
  # Composite key: bucket id first, then position within the sequence.
  buckets_and_t = seqlen * buckets + (ticker % seqlen)
  buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

  # Hash-based sort ("s" at the start of variable names means "sorted")
  sbuckets_and_t, sticker = jax.lax.sort_key_val(
      buckets_and_t, ticker, dimension=-1)
  _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
  sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
  sticker = jax.lax.stop_gradient(sticker)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  st = (sticker % seqlen)
  sqk = np.take(qk, st, axis=0)
  sv = np.take(v, st, axis=0)

  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
  bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
  bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
  bq_buckets = bkv_buckets = np.reshape(
      sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  bq = bqk
  bk = self.make_unit_length(bqk)

  # Allow each chunk to attend within itself, and also one chunk back. Chunk
  # boundaries might occur in the middle of a sequence of items from the
  # same bucket, so this increases the chances of attending to relevant items.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  def look_one_back(x):
    # Appends the previous chunk (cyclically, along axis 0) to each chunk
    # along axis 1.
    if len(x.shape) == 2:
      x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
    else:
      x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
    return np.concatenate([x, x_extra], axis=1)

  bk = look_one_back(bk)
  bv = look_one_back(bv)
  bkv_t = look_one_back(bkv_t)
  bkv_buckets = look_one_back(bkv_buckets)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  # (The 1e5 penalty is smaller than the other masks, so self-attention
  # remains possible when everything else is masked.)
  self_mask = jax.lax.convert_element_type(
      jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]),
      np.float32)
  dots = dots - 1e5 * self_mask

  # Mask out attention to other hash buckets.
  if not self._attend_across_buckets:
    bucket_mask = jax.lax.convert_element_type(
        jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
        np.float32)
    dots = dots - 1e7 * bucket_mask

  # Don't double-count query-key pairs across multiple rounds of hashing.
  # There are two possible strategies here. (1) The default is to count how
  # many times a query-key pair is repeated, and to lower its log-prob
  # correspondingly at each repetition. (2) When hard_k is set, the code
  # instead masks all but the first occurence of each query-key pair.
  # TODO(kitaev): is one strategy faster or more numerically stable?
  if not self._allow_duplicate_attention:
    # Chunk index of each position in every hash round (locs1) and the chunk
    # after it (locs2, which sees it via look_one_back).
    locs1 = undo_sort // bq_t.shape[-1]
    locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
    if not self._attend_across_buckets:
      locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
      locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
    locs = np.moveaxis(np.concatenate([
        np.reshape(locs1, (self.n_hashes, seqlen)),
        np.reshape(locs2, (self.n_hashes, seqlen)),
    ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
    slocs = np.take(locs, st, axis=0)
    b_locs = np.reshape(
        slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
    # Queries always use the primary location (based on locs1).
    b_locs1 = b_locs[:, :, None, :self.n_hashes]
    if self._hard_k > 0:
      range_n_hashes = jax.lax.tie_in(b_locs, np.arange(self.n_hashes))
      nouse_locs = (range_n_hashes[:, None] > range_n_hashes[None, :])
      nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
      nouse_locs = np.reshape(
          np.broadcast_to(nouse_locs[:, None, :],
                          (self.n_hashes, self.n_bins, self.n_hashes)),
          (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
      b_locs1 = b_locs1 * nouse_locs
    bq_locs = np.broadcast_to(
        b_locs1,
        b_locs.shape[:2] + (2, self.n_hashes))
    bq_locs = np.reshape(bq_locs, b_locs.shape)
    bkv_locs = look_one_back(b_locs)

    # For every (query, key) pair: number of shared locations.
    dup_counts = np.sum(
        jax.lax.convert_element_type(
            jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
            np.float32),
        axis=-1)
    assert dup_counts.shape == dots.shape
    if self._hard_k > 0:
      dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
    else:
      dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

  # Each query only attends to the top k most relevant keys.
  if self._hard_k > 0:
    b_top_dots = np.sort(dots)[..., -self._hard_k:]  # Get the top k dots.
    b_top_dots = jax.lax.stop_gradient(b_top_dots)
    s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
    top_dots = np.take(s_top_dots, undo_sort, axis=0)

    # Merge candidates from all hash rounds and take the k-th largest per
    # position as the threshold.
    merged_top_dots = np.moveaxis(
        np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1)
    merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

    dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
    # It's possible to compute the partition function at this point, but right
    # now this codepath isn't set up for backprop, and there might also be
    # issues computing it this way if two dot-products are exactly equal.

    sdots_thresh = dots_thresh[st]
    bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
    bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

    top_k_mask = jax.lax.convert_element_type(
        dots < bdots_thresh[..., None], np.float32)
    dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

  # Softmax.
  dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if self._dropout > 0.0:
    # Dropout is broadcast across the bin dimension
    dropout_shape = (1, dots.shape[-2], dots.shape[-1])
    keep_prob = jax.lax.tie_in(dots, 1.0 - self._dropout)
    keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  bo = np.matmul(dots, bv)
  so = np.reshape(bo, (-1, bo.shape[-1]))
  slogits = np.reshape(dots_logsumexp, (-1,))

  def unsort_for_output_impl(so, slogits):
    o = np.take(so, undo_sort, axis=0)
    # Sorting is considerably faster than gather, but first we need to get the
    # XLA compiler to abandon the idea of fusing this sort with the input sort
    # (which introduces a computation cycle and leads to a crash).
    # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
    sticker_ = sticker + jax.lax.convert_element_type(
        slogits[0] > 0, sticker.dtype)
    _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
    return o, logits

  def unsort_for_output_vjp(so, slogits):
    """Custom gradient for unsort_for_output."""
    so = jax.lax.stop_gradient(so)
    slogits = jax.lax.stop_gradient(slogits)
    o, logits = unsort_for_output_impl(so, slogits)

    def vjpfun(o_logits_grads):
      # Re-sort the incoming gradients back into hash-sorted order.
      so_grad = np.take(o_logits_grads[0], sticker, axis=0)
      # TODO(kitaev): this exists to match the forward pass, but I'm not sure
      # if it's actually required.
      buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
          o_logits_grads[1][0] > 0, buckets_and_t.dtype)
      _, slogits_grad = jax.lax.sort_key_val(
          buckets_and_t_, o_logits_grads[1], dimension=-1)
      return (so_grad, slogits_grad)

    return (o, logits), vjpfun

  unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
  jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
  # NOTE(review): the line below calls unsort_for_output_impl, not
  # unsort_for_output, so the custom VJP registered above is never used and
  # JAX differentiates the impl directly. Confirm whether this is intentional.
  o, logits = unsort_for_output_impl(so, slogits)

  if self.n_hashes == 1:
    out = o
  else:
    # Combine hash rounds: softmax over the round axis of the per-round
    # logsumexp values, then a weighted sum of the per-round outputs.
    o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
    logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
    probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
    out = np.sum(o * probs, axis=0)

  assert out.shape == v.shape
  return out
def _custom_grad(f_vjp, f_original):
  """Returns `f_original` with its VJP overridden by `f_vjp`.

  Args:
    f_vjp: Function taking the same positional arguments as `f_original`.
      It returns `(output, vjp_fn)`, where `vjp_fn(cotangent)` returns a
      tuple with one cotangent per positional input.
    f_original: Function used to compute the primal output.

  Returns:
    A `jax.custom_vjp` function. It evaluates `f_original`, and under
    differentiation it runs `f_vjp` and back-propagates through its
    `vjp_fn`.
  """
  # jax.custom_gradient turns the closure-returning f_vjp into a function JAX
  # can differentiate. jax.vjp of it yields a pullback that is a pytree, so
  # it can be carried as a custom_vjp residual (a bare Python closure cannot).
  f_vjp_differentiable = jax.custom_gradient(f_vjp)

  @jax.custom_vjp
  def f_(*args):
    return f_original(*args)

  def fwd(*args):
    return jax.vjp(f_vjp_differentiable, *args)

  def bwd(pullback, cotangent):
    return pullback(cotangent)

  f_.defvjp(fwd, bwd)
  return f_
def call(self, inputs, params=(), state=(), rng=None, **kwargs):
  """Batched causal LSH attention with a shared query/key projection.

  The inputs are tiled once per hash round and hashed into bins. Within each
  round they are sorted by `(bin, position)` and chunked. Each chunk attends
  to itself and to the previous chunk. Outputs are then unsorted and combined
  across rounds, weighted by `softmax(logits)` over the round axis.

  Args:
    inputs: `(qk, k, v)`. The middle (key) element is ignored because `qk`
      serves as both query and key.
    params: Unused.
    state: Returned unchanged.
    rng: Passed to `self.hash_vectors`.
    **kwargs: Unused.

  Returns:
    `(out, state)`, where `out` has the shape of `inputs[2]` (asserted).
  """
  del params, kwargs
  # We use the same vector as both a query and a key. For now we haven't
  # adjusted any of the surrounding code, so we still get a separate "key"
  # input that we ignore.
  qk, _, v = inputs
  seqlen = qk.shape[-2]

  # qk/v are n_hashes*n_batch*n_heads, seqlen, d_head
  # TODO(kitaev): is it faster to fuse this tiling into gather/scatter ops?
  qk = np.tile(qk, (self.n_hashes, 1, 1))
  v = np.tile(v, (self.n_hashes, 1, 1))

  # bins are n_hashes*n_batch*n_heads, seqlen
  # They specify which hash bucket the query/key/value vectors fall in.
  bins = self.hash_vectors(qk, rng=rng)

  # joint_t is n_hashes*n_batch*n_heads, seqlen
  joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
  joint_t = np.reshape(joint_t, (1, seqlen))
  joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

  # The composite key below must stay below 2**31.
  assert int(
      (self.n_buckets_per_bin * self.n_bins + 1) * seqlen
  ) < 2**31, (
      'Potential 32-bit integer overflow; please double-check the code.')
  joint_bins_and_t = seqlen * bins + joint_t

  def chunk_scalars(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1))

  def chunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(
        x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

  def unchunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

  # Sort everything by bin number, with a secondary sort by time
  # (variables starting with "s" are sorted)
  _, sjoint_t = jax.lax.sort_key_val(
      joint_bins_and_t, joint_t, dimension=-1)
  _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
  # TODO(kitaev): why does jax flag integer indices as differentiable?
  # If we don't call stop_gradient here, custom gradients below won't work
  # because the primitive functions close over "differentiable" variables.
  sjoint_t = jax.lax.stop_gradient(sjoint_t)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  # The backward pass of gather is in general a scatter operation, but we know
  # we're dealing with permutations so we use gather for the backward pass
  # too. This custom gradient should be about 2x faster than having jax infer
  # one that uses scatter ops instead.
  def permute_impl(vecs):
    assert len(vecs.shape) == 3
    return np.take_along_axis(vecs, sjoint_t[:, :, None], axis=-2)

  def unpermute_impl(vecs):
    assert len(vecs.shape) == 3
    return np.take_along_axis(vecs, undo_sort[:, :, None], axis=-2)

  # NOTE(review): jax.custom_transforms / jax.defvjp_all are JAX's legacy
  # custom-derivative API; confirm the pinned JAX version still has them.
  @jax.custom_transforms
  def permute(vecs):
    return permute_impl(vecs)

  def permute_vjp(vecs):
    out_vecs = permute_impl(vecs)

    def vjpfun(grad):
      # The gradient of a permutation is the inverse permutation.
      return (unpermute_impl(grad),)

    return out_vecs, vjpfun

  @jax.custom_transforms
  def unpermute(vecs):
    return unpermute_impl(vecs)

  def unpermute_vjp(vecs):
    out_vecs = unpermute_impl(vecs)

    def vjpfun(grad):
      return (permute_impl(grad),)

    return out_vecs, vjpfun

  jax.defvjp_all(permute, permute_vjp)
  jax.defvjp_all(unpermute, unpermute_vjp)

  sqk = permute(qk)
  sv = permute(v)

  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = bkv_t = chunk_scalars(sjoint_t)
  bqk = chunk_vectors(sqk)
  bv = chunk_vectors(sv)

  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  bq = bqk
  bk = self.make_unit_length(bqk)

  # Allow each chunk to attend within itself, and also one chunk back. Chunk
  # boundaries might occur in the middle of a sequence of items from the
  # same bin, so this increases the chances of attending to relevant items.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
  bk = np.concatenate([bk, bk_extra], axis=2)
  bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
  bv = np.concatenate([bv, bv_extra], axis=2)
  bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]], axis=1)
  bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  # broadcasted_eye over axes (2, 3) marks the query-vs-same-chunk-key
  # diagonal; the extra look-back keys sit past it and are not marked.
  self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
  self_mask = jax.lax.tie_in(dots, self_mask)
  dots = dots - 32 * self_mask

  # Softmax.
  dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if self._hard_k > 0:
    top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
    top_k = jax.lax.stop_gradient(top_k)
    dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
    dots = np.maximum(dots, 0)
    dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Sum to re-normalize.
    dots_logsumexp += np.log(dots_sum)  # Add it to the weight.
    dots /= dots_sum  # Re-normalize.

  bo = np.matmul(dots, bv)
  so = unchunk_vectors(bo)
  slogits = unchunk_vectors(dots_logsumexp)

  o = unpermute(so)
  logits = unpermute(slogits)

  # Combine hash rounds: softmax over the round axis of the per-round
  # logsumexp values, then a weighted sum of the per-round outputs.
  o = np.reshape(o, (self.n_hashes, -1, seqlen, o.shape[-1]))
  logits = np.reshape(logits, (self.n_hashes, -1, seqlen, 1))
  probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
  out = np.sum(o * probs, axis=0)
  assert out.shape == inputs[2].shape

  return out, state
@jax.custom_transforms def _minimum_(x, y, name=None): # pylint: disable=unused-argument return np.minimum(x, y) # TF and Jax have differing behavior # when the inputs to maximum/minimum are equal. # This custom transforms rule # modifies Jax to match TF's behavior. def _maximum_vjp(x, y): out_primals = _maximum_(x, y) def vjp(g): gx = g * np.where(x >= y, np.ones_like(x), np.zeros_like(x)) return (gx.astype(x.dtype), (g - gx).astype(y.dtype)) return out_primals, vjp jax.defvjp_all(_maximum_, _maximum_vjp) def _minimum_vjp(x, y): out_primals = _minimum_(x, y) def vjp(g): gx = g * np.where(x <= y, np.ones_like(x), np.zeros_like(x)) return (gx.astype(x.dtype), (g - gx).astype(y.dtype)) return out_primals, vjp jax.defvjp_all(_minimum_, _minimum_vjp) # Need to wrap in a function because custom_transforms # returns an object, not a function # which breaks docstring wrapping def _promote_dtypes(x, y): # Need to explicitly promote types