def _ibp_dot_general(lhs: PrimitiveInput, rhs: PrimitiveInput,
                     **kwargs) -> PrimitiveInput:
  """Propagation of IBP bounds through a general dot product.

  We don't know whether the bound is on the left or the right hand side, but we
  expect that one operand is a bound and the other is a constant/parameter.

  Args:
    lhs: First input to the dot primitive.
    rhs: Second input to the dot primitive.
    **kwargs: Dict with the parameters of the general dot product.
  Returns:
    out_bounds: IntervalBound on the output of the dot product.
  """
  if (isinstance(lhs, bound_propagation.Bound) !=
      isinstance(rhs, bound_propagation.Bound)):
    lhses = _decompose_affine_argument(lhs)
    rhses = _decompose_affine_argument(rhs)

    forward_mean = lax.dot_general(lhses[1], rhses[1], **kwargs)
    forward_range = lax.dot_general(lhses[0], rhses[0], **kwargs)

    out_lb = forward_mean - forward_range
    out_ub = forward_mean + forward_range
    return IntervalBound(out_lb, out_ub)
  elif (not isinstance(lhs, bound_propagation.Bound) and
        not isinstance(rhs, bound_propagation.Bound)):
    # Both are arrays, so we can simply pass them through.
    return lax.dot_general(lhs, rhs, **kwargs)
  else:
    raise ValueError('BoundPropagation through general dot product '
                     'is not supported when both inputs are bounds.')
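As background for the decomposition above, here is a minimal sketch (hypothetical helper, not the library's API) of interval bound propagation through a plain matmul: the interval is split into a midpoint and a radius, the midpoint goes through the weights as-is, and the radius goes through their absolute value.

# Hedged sketch (not the library's API): interval bound propagation through an
# ordinary matmul, using a midpoint/radius decomposition of the bounded operand.
import jax.numpy as jnp
from jax import lax

def ibp_matmul_sketch(lb, ub, w):
    # Midpoint and radius of the input interval.
    mid = (ub + lb) / 2.0
    rad = (ub - lb) / 2.0
    # Contract the last axis of the input with the first axis of the constant
    # weights, with no batch dimensions (as in the Dense layers below).
    dims = (((lb.ndim - 1,), (0,)), ((), ()))
    out_mid = lax.dot_general(mid, w, dims)
    # The radius propagates through the absolute value of the constant operand.
    out_rad = lax.dot_general(rad, jnp.abs(w), dims)
    return out_mid - out_rad, out_mid + out_rad

lb = jnp.array([[-1.0, 0.0]])
ub = jnp.array([[1.0, 2.0]])
w = jnp.array([[1.0, -2.0], [3.0, 0.5]])
out_lb, out_ub = ibp_matmul_sketch(lb, ub, w)  # elementwise out_lb <= x @ w <= out_ub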
def testDotGeneral(self):
  R = onp.random.RandomState(0).randn

  x = R(10, 3, 4, 5)
  y = R(10, 3, 5, 6)
  fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
  ans = vmap(fun)(x, y)
  expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))])
  self.assertAllClose(ans, expected, check_dtypes=True)

  x = R(3, 4, 10, 5)
  y = R(3, 10, 5, 6)
  fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
  ans = vmap(fun, in_axes=(2, 1))(x, y)
  fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
  expected = onp.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
  self.assertAllClose(ans, expected, check_dtypes=True)

  x = R(3, 4, 5, 10)
  y = R(3, 5, 6)
  fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
  ans = vmap(fun, in_axes=(3, None))(x, y)
  fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
  expected = onp.stack([fun(x[..., i], y) for i in range(10)])
  self.assertAllClose(ans, expected, check_dtypes=True)

  x = R(3, 4, 5)
  y = R(3, 5, 10, 6)
  fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
  ans = vmap(fun, in_axes=(None, 2))(x, y)
  fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
  expected = onp.stack([fun(x, y[..., i, :]) for i in range(10)])
  self.assertAllClose(ans, expected, check_dtypes=True)
def __call__(self, x: Array) -> Array:
    """Applies the symmetrized linear transformation to the inputs along the
    last dimension.

    Args:
      x: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    # Infer in_features and ensure input dimensions (batch, in_features, n_sites).
    # TODO: Deprecated: Eventually remove and error if less than 3 dimensions
    if x.ndim < 3:
        old_shape = x.shape
        if x.ndim == 1:
            x = jnp.expand_dims(x, (0, 1))
        elif x.ndim == 2:
            x = jnp.expand_dims(x, 1)
        symm_input_warning(old_shape, x.shape, "DenseSymm")

    in_features = x.shape[1]

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_sites),
        self.dtype,
    )

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    # Converts the convolutional kernel of shape (self.features, in_features, n_sites)
    # to a full dense kernel of shape (self.features, in_features, n_symm, n_sites).
    # result[out, in, g, r] == kernel[out, in, g^{-1}r]
    kernel = jnp.take(kernel, jnp.asarray(self.symmetries), 2)
    kernel = jnp.asarray(kernel, dtype)

    # x is (batches, in_features, n_sites)
    # kernel is (self.features, in_features, n_symm, n_sites)
    x = lax.dot_general(
        x,
        kernel,
        (((x.ndim - 2, x.ndim - 1), (1, 3)), ((), ())),
        precision=self.precision,
    )

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        # Convert symmetry-reduced bias of shape (features,) to the full bias of
        # shape (..., features, 1).
        bias = jnp.expand_dims(bias, 1)
        bias = jnp.asarray(bias, dtype)
        x += bias

    return x
def apply(self,
          inputs,
          features,
          bias=True,
          dtype=jnp.float32,
          precision=None,
          kernel_init=default_kernel_init,
          bias_init=initializers.zeros):
  """Applies a linear transformation to the inputs along the last dimension.

  Args:
    inputs: The nd-array to be transformed.
    features: the number of output features.
    bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: float32).
    precision: numerical precision of the computation; see `jax.lax.Precision`
      for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
  Returns:
    The transformed input.
  """
  inputs = jnp.asarray(inputs, dtype)
  kernel = self.param('kernel', (inputs.shape[-1], features), kernel_init)
  kernel = jnp.asarray(kernel, dtype)
  y = lax.dot_general(inputs, kernel,
                      (((inputs.ndim - 1,), (0,)), ((), ())),
                      precision=precision)
  if bias:
    bias = self.param('bias', (features,), bias_init)
    bias = jnp.asarray(bias, dtype)
    y = y + bias
  return y
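For illustration (arbitrary shapes), the dimension_numbers used by this Dense-style layer, (((inputs.ndim - 1,), (0,)), ((), ())), contract the last input axis with the first kernel axis and declare no batch dimensions, which is the same contraction an einsum over that axis performs.

# Illustrative check (arbitrary shapes): the Dense-style dot_general above is
# the same contraction as an einsum over the last input axis.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(24.0).reshape(2, 3, 4)   # (batch, seq, features_in)
w = jnp.arange(20.0).reshape(4, 5)      # (features_in, features_out)

y_dot = lax.dot_general(x, w, (((x.ndim - 1,), (0,)), ((), ())))
y_ein = jnp.einsum('...d,df->...f', x, w)
assert jnp.allclose(y_dot, y_ein)       # both have shape (2, 3, 5)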
def __call__(self, inputs, kernel):
  inputs = jnp.asarray(inputs, self.dtype)
  kernel = jnp.asarray(kernel, self.dtype)
  y = lax.dot_general(inputs, kernel,
                      (((inputs.ndim - 1,), (0,)), ((), ())))
  bias = jnp.asarray(self.bias, self.dtype)
  y = y + bias
  return y
def test_vmap_after(self):
  batch = 4
  qy_size = 128
  db_size = 1024
  feature_dim = 32
  k = 10
  rng = jtu.rand_default(self.rng())
  qy = rng([qy_size, feature_dim, batch], np.float32)
  db = rng([db_size, feature_dim, batch], np.float32)
  recall = 0.95

  # Create ground truth
  gt_scores = lax.dot_general(qy, db, (([1], [1]), ([2], [2])))
  _, gt_args = lax.top_k(gt_scores, k)
  gt_args = lax.transpose(gt_args, [2, 0, 1])
  gt_args = lax.reshape(gt_args, [qy_size * batch, k])

  # Test target
  def approx_max_k(qy, db):
    scores = qy @ db.transpose()
    return lax.approx_max_k(scores, k)

  _, ann_args = jax.vmap(approx_max_k, (2, 2))(qy, db)
  ann_args = lax.transpose(ann_args, [2, 0, 1])
  ann_args = lax.reshape(ann_args, [qy_size * batch, k])
  ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
  self.assertGreater(ann_recall, recall)
def dot_general_dependency_rule(outstart, outcount, lhs, rhs,
                                dimension_numbers, precision):
  if not is_ones(outcount):
    raise NotImplementedError
  outshape = outcount.shape
  outslices = list(zip(outstart, outshape))
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_other_out_dims = list(
      range(len(lhs_batch), len(lhs.shape) - len(lhs_contracting)))
  rhs_other_out_dims = list(
      range(len(rhs_batch) + len(lhs_other_out_dims), len(outshape)))
  lhs_outstart, lhs_outshape = unzip2(
      [outslices[d] for d in list(lhs_batch) + lhs_other_out_dims])
  (lhs_box,), (lhs_count,), _ = reduce_dependency_rule(None)(
      lhs_outstart, Ones(lhs_outshape), lhs, axes=lhs_contracting)
  rhs_outstart, rhs_outshape = unzip2(
      [outslices[d] for d in list(rhs_batch) + rhs_other_out_dims])
  (rhs_box,), (rhs_count,), _ = reduce_dependency_rule(None)(
      rhs_outstart, Ones(rhs_outshape), rhs, axes=rhs_contracting)
  incounts = [
      materialize(lhs_count) * prod(np.take(outshape, rhs_other_out_dims))
      if isinstance(lhs, LazyArray) else None,
      materialize(rhs_count) * prod(np.take(outshape, lhs_other_out_dims))
      if isinstance(rhs, LazyArray) else None
  ]
  return ([lhs_box, rhs_box], incounts,
          lambda *inslices: lax.dot_general(
              *inslices, dimension_numbers, precision))
def generalized_kernel_feature_creator(data, projection_matrix, batch_dims_t,
                                       precision, kernel_fn, kernel_epsilon,
                                       normalize_data):
  """Constructs kernel features for fast generalized attention.

  Args:
    data: input for which features are computed
    projection_matrix: matrix used to compute features
    batch_dims_t: tuple of batch dimensions
    precision: precision parameter
    kernel_fn: kernel function used
    kernel_epsilon: additive positive term added to every feature for numerical
      stability
    normalize_data: predicate indicating whether data should be normalized.

  Returns:
    Random features for fast generalized attention.
  """
  if normalize_data:
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
  else:
    data_normalizer = 1.0
  if projection_matrix is None:
    return kernel_fn(data_normalizer * data) + kernel_epsilon
  else:
    data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

    data_dash = lax.dot_general(
        data_normalizer * data,
        data_thick_random_matrix,
        (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
         (batch_dims_t, batch_dims_t)),
        precision=precision)
    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime
def testPdotVJPSystematic(self, lhs_shape, rhs_shape, pdot_spec,
                          axis_resources, mesh_data):
  rng = jtu.rand_default(self.rng())
  lhs = rng(lhs_shape, np.float32)
  rhs = rng(rhs_shape, np.float32)

  expected_out, ref_vjp = jax.vjp(
      lambda x, y: lax.dot_general(x, y, pdot_spec.dot_general_dim_nums),
      lhs, rhs)
  out_bar = rng(expected_out.shape, np.float32)
  expected_lhs, expected_rhs = ref_vjp(out_bar)

  def pdot_fun(x, y, out_bar):
    pdot = partial(jax.lax.pdot,
                   axis_name=pdot_spec.contract_names,
                   pos_batch=pdot_spec.pos_batch_after_mapping,
                   pos_contract=pdot_spec.pos_contract_after_mapping)
    _, pdot_vjp = jax.vjp(pdot, x, y)
    return pdot_vjp(out_bar)

  fun = xmap(pdot_fun,
             in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes,
                      [*pdot_spec.batch_names, ...]],
             out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
             axis_resources=axis_resources)

  with with_mesh(mesh_data):
    lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)

  tol = 1e-1 if jtu.device_under_test() == "tpu" else None
  self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False,
                      atol=tol, rtol=tol)
  self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False,
                      atol=tol, rtol=tol)
def testPdotSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources,
                       mesh_data):
  rng = jtu.rand_default(self.rng())
  lhs = rng(lhs_shape, np.float32)
  rhs = rng(rhs_shape, np.float32)

  def pdot_fun(x, y):
    # print(f'pdot(x:{x.aval.str_short()}, y:{y.aval.str_short()},\n'
    #       f'     axis_name={contract_names},\n'
    #       f'     pos_contract={spec.pos_contract_after_mapping}\n'
    #       f'     pos_batch={spec.pos_batch_after_mapping})')
    return jax.lax.pdot(x, y, axis_name=pdot_spec.contract_names,
                        pos_batch=pdot_spec.pos_batch_after_mapping,
                        pos_contract=pdot_spec.pos_contract_after_mapping)

  fun = xmap(pdot_fun,
             in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes],
             out_axes=[*pdot_spec.batch_names, ...],
             axis_resources=axis_resources)

  with with_mesh(mesh_data):
    result = fun(lhs, rhs)

  expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
  tol = 1e-1 if jtu.device_under_test() == "tpu" else None
  self.assertAllClose(result, expected, check_dtypes=False,
                      atol=tol, rtol=tol)
def nonnegative_softmax_kernel_feature_creator(data,
                                               projection_matrix,
                                               attention_dims_t,
                                               batch_dims_t,
                                               precision,
                                               is_query,
                                               normalize_data=True,
                                               eps=0.0001):
  """Constructs nonnegative kernel features for fast softmax attention.

  Args:
    data: input for which features are computed
    projection_matrix: random matrix used to compute features
    attention_dims_t: tuple of attention dimensions
    batch_dims_t: tuple of batch dimensions
    precision: precision parameter
    is_query: predicate indicating whether input data corresponds to queries or
      keys
    normalize_data: predicate indicating whether data should be normalized
    eps: numerical stabilizer.

  Returns:
    Random features for fast softmax attention.
  """
  if normalize_data:
    # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
    # w_norm = w * data_normalizer for w in {q,k}.
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
  else:
    data_normalizer = 1.0
  ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
  data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
  data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

  data_dash = lax.dot_general(
      data_normalizer * data,
      data_thick_random_matrix,
      (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
       (batch_dims_t, batch_dims_t)),
      precision=precision)

  diag_data = jnp.square(data)
  diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
  diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
  diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)

  last_dims_t = (len(data_dash.shape) - 1,)
  if is_query:
    data_dash = ratio * (
        jnp.exp(data_dash - diag_data -
                jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps)
  else:
    data_dash = ratio * (
        jnp.exp(data_dash - diag_data -
                jnp.max(data_dash, axis=last_dims_t + attention_dims_t,
                        keepdims=True)) + eps)

  return data_dash
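The normalization identity quoted in the comment can be checked numerically (arbitrary small vectors): with c = d**(-1/4), exp(q k^T / sqrt(d)) equals exp((c q)(c k)^T), since c**2 = 1/sqrt(d).

# Quick numeric check of the normalization identity in the comment above.
import jax.numpy as jnp

d = 8
q = jnp.linspace(-1.0, 1.0, d)
k = jnp.linspace(0.5, -0.5, d)
c = 1.0 / jnp.sqrt(jnp.sqrt(d))      # the data_normalizer used above
lhs = jnp.exp(q @ k / jnp.sqrt(d))
rhs = jnp.exp((c * q) @ (c * k))
assert jnp.allclose(lhs, rhs)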
def __call__(self, inputs: Array) -> Array:
    """Applies a linear transformation to the inputs along the last dimension.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    kernel = self.param(
        "kernel", self.kernel_init, (inputs.shape[-1], self.features), self.dtype
    )
    kernel = jnp.asarray(kernel, dtype)

    y = lax.dot_general(
        inputs,
        kernel,
        (((inputs.ndim - 1,), (0,)), ((), ())),
        precision=self.precision,
    )

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y = y + bias

    return y
def __call__(self, query, key, value, bias=None, dtype=jnp.float32):
  assert key.ndim == query.ndim
  assert key.ndim == value.ndim

  n = query.ndim
  attn_weights = lax.dot_general(
      query, key,
      (((n - 1,), (n - 1,)), ((), ())))
  if bias is not None:
    attn_weights += bias
  attn_weights = self.attn_module()(attn_weights)
  attn_weights = attn_weights.astype(dtype)

  contract_dims = (
      tuple(range(n - 1, attn_weights.ndim)),
      tuple(range(0, n - 1)))
  y = lax.dot_general(attn_weights, value, (contract_dims, ((), ())))
  return y
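For intuition (hypothetical rank-2 shapes, with a plain softmax standing in for attn_module), the two dot_general calls above reduce to the familiar softmax(q @ k.T) @ v.

# Small illustration: the attention contractions above on rank-2 inputs.
import jax
import jax.numpy as jnp
from jax import lax

q = jnp.ones((3, 4))                 # (q_len, d)
k = jnp.ones((5, 4)) * 0.5           # (k_len, d)
v = jnp.arange(10.0).reshape(5, 2)   # (k_len, d_v)

n = q.ndim
logits = lax.dot_general(q, k, (((n - 1,), (n - 1,)), ((), ())))   # (3, 5)
weights = jax.nn.softmax(logits, axis=-1)
contract = (tuple(range(n - 1, weights.ndim)), tuple(range(0, n - 1)))
y = lax.dot_general(weights, v, (contract, ((), ())))              # (3, 2)
assert jnp.allclose(y, jax.nn.softmax(q @ k.T, axis=-1) @ v)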
def __call__(self, x, kernel):
  y = lax.dot_general(
      x,
      kernel,
      (((x.ndim - 1,), (0,)), ((), ())),
      precision=self.precision,
  )
  return y + self.bias
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements).
    """
    in_features = x.shape[-2]

    x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:-1], *self.shape)

    if self.use_bias:
        bias = self.param(
            "bias", self.bias_init, (self.features,), self.param_dtype
        )
    else:
        bias = None

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_point * self.n_cells),
        self.param_dtype,
    )

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    x, kernel, bias = promote_dtype(x, kernel, bias, dtype=None)
    dtype = x.dtype

    # Convert the convolutional kernel of shape (features, in_features, n_symm)
    # to the expanded kernel of shape (features, in_features, n_point(in),
    # n_point(out), *shape) used in FFT-based group convolutions.
    kernel = kernel[..., self.mapping]

    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
        *kernel.shape[:4], self.n_cells
    )

    x = lax.dot_general(
        x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
    )

    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)

    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:2], -1)

    if self.use_bias:
        x += jnp.expand_dims(bias, (0, 2))

    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real
def __call__(self, inputs):
  kernel = self.param('kernel', self.kernel_init,
                      (inputs.shape[-1], self.features))
  y = lax.dot_general(inputs, kernel,
                      (((inputs.ndim - 1,), (0,)), ((), ())))
  if self.use_bias:
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
  return y
def _cov_full_batch_diag_spatial(x1: np.ndarray,
                                 x2: np.ndarray,
                                 batch_axis: int,
                                 channel_axis: int) -> np.ndarray:
  diag_axes = tuple(i for i in range(x1.ndim)
                    if i != batch_axis and i != channel_axis)
  ret = lax.dot_general(x1, x2,
                        (((channel_axis,), (channel_axis,)),
                         (diag_axes, diag_axes)))
  ret = np.moveaxis(ret, (-2, -1), (0, 1))
  return ret
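An illustrative check under an assumed NHWC layout (batch_axis=0, channel_axis=3): the call contracts channels, keeps the spatial axes as shared batch axes, and the moveaxis then places them behind the two batch dimensions, matching an einsum with a diagonal over H and W.

# Hedged check of the contraction above for an assumed NHWC layout.
import jax.numpy as jnp
from jax import lax

x1 = jnp.arange(2 * 4 * 4 * 3.0).reshape(2, 4, 4, 3)   # (N1, H, W, C)
x2 = jnp.arange(5 * 4 * 4 * 3.0).reshape(5, 4, 4, 3)   # (N2, H, W, C)

channel_axis = 3
diag_axes = (1, 2)  # every axis that is neither batch nor channel
ret = lax.dot_general(x1, x2,
                      (((channel_axis,), (channel_axis,)),
                       (diag_axes, diag_axes)))          # (H, W, N1, N2)
ret = jnp.moveaxis(ret, (-2, -1), (0, 1))                # (N1, N2, H, W)
assert jnp.allclose(ret, jnp.einsum('nhwc,mhwc->nmhw', x1, x2))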
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements).
    """
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:-1], *self.shape)

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (
            self.out_features,
            self.in_features,
            self.n_point * self.n_cells,
        ),
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.mask, (0, 1))

    kernel = self.make_kernel(kernel)

    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
        *kernel.shape[:4], self.n_cells
    )

    x = lax.dot_general(
        x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
    )

    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)

    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:2], -1)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.out_features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        x += jnp.expand_dims(bias, (0, 2))

    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real
def parallel_topk(qy, db, db_offset):
  scores = lax.dot_general(qy, db, (([1], [1]), ([], [])))
  ann_vals, ann_args = lax.approx_min_k(
      scores,
      k=k,
      reduction_dimension=1,
      recall_target=recall,
      reduction_input_size_override=db_size,
      aggregate_to_topk=False)
  return (ann_vals, ann_args + db_offset)
def __call__(self, inputs):
  inputs = jnp.asarray(inputs, self.dtype)
  kernel = self.param("kernel", jax.nn.initializers.normal(stddev=0.02),
                      (self.features, inputs.shape[-1]))
  kernel = jnp.asarray(kernel.transpose(), self.dtype)
  y = lax.dot_general(inputs, kernel,
                      (((inputs.ndim - 1,), (0,)), ((), ())),
                      precision=self.precision)
  if self.use_bias:
    bias = self.param("bias", jax.nn.initializers.zeros, (self.features,))
    bias = jnp.asarray(bias, self.dtype)
    y = y + bias
  return y
def __call__(self, inputs):
  kernel = self.param(
      "kernel", self.kernel_init, (self.n_tasks, inputs.shape[-1], self.features)
  )
  y = lax.dot_general(
      inputs, kernel, dimension_numbers=(((2,), (1,)), ((0,), (0,)))
  )
  bias = self.param("bias", self.bias_init, (self.n_tasks, 1, self.features))
  y = y + bias
  return y
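As an illustrative check (hypothetical shapes), dimension_numbers=(((2,), (1,)), ((0,), (0,))) treats axis 0 as a shared per-task batch axis and contracts the feature axis, i.e. a batched matmul.

# Illustrative check: the per-task contraction above matches a batched einsum.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(2 * 3 * 4.0).reshape(2, 3, 4)   # (n_tasks, batch, features_in)
w = jnp.arange(2 * 4 * 5.0).reshape(2, 4, 5)   # (n_tasks, features_in, features_out)

y_dot = lax.dot_general(x, w, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
y_ein = jnp.einsum('tbd,tdf->tbf', x, w)
assert jnp.allclose(y_dot, y_ein)   # shape (2, 3, 5)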
def __call__(self, inputs: Array) -> Array:
  """Applies a linear transformation to the inputs along multiple dimensions.

  Args:
    inputs: The nd-array to be transformed.

  Returns:
    The transformed input.
  """
  inputs = jnp.asarray(inputs, self.dtype)

  ndim = inputs.ndim
  n_batch_dims = len(self.batch_dims)
  axis = _normalize_axes(self.axis, ndim)
  batch_dims = _normalize_axes(self.batch_dims, ndim)
  n_axis, n_features = len(axis), len(self.features)

  def kernel_init_wrap(rng, shape, dtype=jnp.float32):
    size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
    flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]),
                  np.prod(shape[-n_features:]),)
    kernel = jnp.concatenate([self.kernel_init(rng, flat_shape, dtype)
                              for _ in range(size_batch_dims)], axis=0)
    return jnp.reshape(kernel, shape)

  batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
  kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + self.features
  kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape)
  kernel = jnp.asarray(kernel, self.dtype)

  batch_ind = tuple(range(n_batch_dims))
  contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
  out = lax.dot_general(inputs,
                        kernel,
                        ((axis, contract_ind), (batch_dims, batch_ind)),
                        precision=self.precision)

  if self.use_bias:
    def bias_init_wrap(rng, shape, dtype=jnp.float32):
      size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
      flat_shape = (np.prod(shape[-n_features:]),)
      bias = jnp.concatenate([self.bias_init(rng, flat_shape, dtype)
                              for _ in range(size_batch_dims)], axis=0)
      return jnp.reshape(bias, shape)

    bias = self.param('bias', bias_init_wrap, batch_shape + self.features)

    # Reshape bias for broadcast.
    expand_dims = sorted(
        set(range(inputs.ndim)) - set(axis) - set(batch_dims))
    for ax in expand_dims:
      bias = jnp.expand_dims(bias, ax)
    bias = jnp.asarray(bias, self.dtype)
    out = out + bias
  return out
def dot_general_int(ops):
  lhs_, rhs_ = ops
  input_dtype = lhs_.dtype
  lhs_int = lhs_.astype(jnp.int8)
  rhs_int = rhs_.astype(jnp.int8)
  return lax.dot_general(
      lhs_int,
      rhs_int,
      dimension_numbers=dimension_numbers,
      precision=dot_precision,
      preferred_element_type=jnp.int32).astype(input_dtype)
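A self-contained sketch of the same pattern (names and shapes here are hypothetical, not from the snippet above): cast both operands to int8, ask XLA to accumulate in int32 via preferred_element_type, then cast back to the input dtype.

# Hedged, self-contained sketch: int8 operands with int32 accumulation.
import jax.numpy as jnp
from jax import lax

def int8_matmul_sketch(lhs, rhs):
    dims = (((lhs.ndim - 1,), (0,)), ((), ()))  # ordinary matmul contraction
    out = lax.dot_general(
        lhs.astype(jnp.int8),
        rhs.astype(jnp.int8),
        dimension_numbers=dims,
        preferred_element_type=jnp.int32,  # accumulate in int32, not int8
    )
    return out.astype(lhs.dtype)

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([[5.0, 6.0], [7.0, 8.0]])
print(int8_matmul_sketch(a, b))  # matches a @ b since the values are small integers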
def __call__(self, x, kernel):
  x = jnp.asarray(x, self.dtype)
  kernel = jnp.asarray(kernel, self.dtype)
  y = lax.dot_general(
      x,
      kernel,
      (((x.ndim - 1,), (0,)), ((), ())),
      precision=self.precision,
  )
  bias = jnp.asarray(self.bias, self.dtype)
  return y + bias
def low_rank_projection(inputs, kernel, precision):
  """Low-rank projection."""
  input_dim = inputs.shape[1]
  # This kernel/parameter relies on the sequence length.
  kernel = kernel[:input_dim, :]
  inputs = inputs.transpose((0, 3, 2, 1))
  y = lax.dot_general(inputs, kernel,
                      (((inputs.ndim - 1,), (0,)), ((), ())),
                      precision=precision)
  y = y.transpose((0, 3, 2, 1))
  return y
def sincos_softmax_kernel_feature_creator(data,
                                          projection_matrix,
                                          attention_dims_t,
                                          batch_dims_t,
                                          precision,
                                          normalize_data=True):
  """Constructs kernel sin-cos features for fast softmax attention.

  Args:
    data: input for which features are computed
    projection_matrix: random matrix used to compute features
    attention_dims_t: tuple of attention dimensions
    batch_dims_t: tuple of batch dimensions
    precision: precision parameter
    normalize_data: predicate indicating whether data should be normalized

  Returns:
    Random features for fast softmax attention.
  """
  if normalize_data:
    # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
    # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
  else:
    data_normalizer = 1.0
  ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
  data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
  data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

  data_dash = lax.dot_general(
      data_normalizer * data,
      data_thick_random_matrix,
      (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
       (batch_dims_t, batch_dims_t)),
      precision=precision,
  )
  data_dash_cos = ratio * jnp.cos(data_dash)
  data_dash_sin = ratio * jnp.sin(data_dash)
  data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1)

  # Constructing D_data and data^{'}
  diag_data = jnp.square(data)
  diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
  diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
  diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)

  # Additional renormalization for numerical stability
  data_renormalizer = jnp.max(diag_data, attention_dims_t, keepdims=True)
  diag_data -= data_renormalizer
  diag_data = jnp.exp(diag_data)
  data_prime = data_dash * diag_data
  return data_prime
def __call__(self, inputs):
  inputs = jnp.asarray(inputs, self.dtype)
  kernel = self.param('kernel', self.kernel_init,
                      (inputs.shape[-1], self.features))
  kernel = jnp.asarray(kernel, self.dtype)
  y = lax.dot_general(inputs, kernel,
                      (((inputs.ndim - 1,), (0,)), ((), ())),
                      precision=self.precision)
  if self.use_bias:
    bias = self.param('bias', self.bias_init, (self.features,))
    bias = jnp.asarray(bias, self.dtype)
    y = y + bias
  return y
def _dot_product_attention(scope: Scope,
                           query: Array,
                           key: Array,
                           value: Array,
                           bias: Optional[Array] = None,
                           attn_fn: Callable = softmax_attn,
                           dtype=jnp.float32):
  assert key.ndim == query.ndim
  assert key.ndim == value.ndim

  n = query.ndim
  attn_weights = lax.dot_general(
      query, key,
      (((n - 1,), (n - 1,)), ((), ())))
  if bias is not None:
    attn_weights += bias
  attn_weights = attn_fn(scope, attn_weights)
  attn_weights = attn_weights.astype(dtype)

  contract_dims = (
      tuple(range(n - 1, attn_weights.ndim)),
      tuple(range(0, n - 1)))
  y = lax.dot_general(attn_weights, value, (contract_dims, ((), ())))
  return y
def dot_general(lhs: np.ndarray,
                rhs: np.ndarray,
                contracting_dims: Axes,
                batch_dims: Axes,
                precision=None) -> np.ndarray:
  """`jax.lax.dot_general` with preserved dims order and shared lhs / rhs dims.

  Precisely, returns `jax.lax.dot_general(lhs, rhs, dimension_numbers)` where
  `dimension_numbers == ((contracting_dims, contracting_dims),
                         (batch_dims, batch_dims))`,
  but allows arbitrary dimension order and preserves it in the output. See XLA's
  `DotGeneral <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`.

  Args:
    lhs: np.ndarray.
    rhs: np.ndarray.
    contracting_dims: contracting dimensions.
    batch_dims: batch dimensions.
    precision: Optional. Either `None`, which means the default precision for
      the backend, or a `Precision` enum value.

  Returns:
    Dot product result with preserved dimension order.
  """
  contracting_dims = canonicalize_axis(contracting_dims, lhs)
  batch_dims = canonicalize_axis(batch_dims, lhs)

  n_batch_dims = len(batch_dims)
  leading_batch_dims = range(n_batch_dims)

  lhs = np.moveaxis(lhs, batch_dims, leading_batch_dims)
  if rhs is None:
    rhs = lhs
  else:
    rhs = np.moveaxis(rhs, batch_dims, leading_batch_dims)

  shifted_contracting_dims = [
      i + sum(1 if i < b else 0 for b in batch_dims)
      for i in contracting_dims
  ]

  dimension_numbers = ((shifted_contracting_dims, shifted_contracting_dims),
                       (leading_batch_dims, leading_batch_dims))

  prod = lax.dot_general(lhs, rhs, dimension_numbers, precision)
  prod = zip_axes(prod, n_batch_dims)

  res_batch_dims = get_res_batch_dims(contracting_dims, batch_dims)
  prod = np.moveaxis(prod, leading_batch_dims, res_batch_dims)
  return prod
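As background for the moveaxis bookkeeping above (illustrative shapes): a raw lax.dot_general always places batch dimensions first in its output, followed by the remaining lhs dimensions and then the remaining rhs dimensions, no matter where the batch axis sat in the inputs.

# Illustration of the raw lax.dot_general output layout that the wrapper above
# works around: batch dims come first even when they were last in the inputs.
import jax.numpy as jnp
from jax import lax

lhs = jnp.zeros((3, 4, 7))  # (m, contract, batch)
rhs = jnp.zeros((5, 4, 7))  # (n, contract, batch)
out = lax.dot_general(lhs, rhs, (((1,), (1,)), ((2,), (2,))))
print(out.shape)  # (7, 3, 5): batch first, then remaining lhs dims, then rhs dims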
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last dimension.

    Args:
      x: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
    in_features = x.shape[-2]

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_symm),
        self.param_dtype,
    )

    if self.use_bias:
        bias = self.param(
            "bias", self.bias_init, (self.features,), self.param_dtype
        )
    else:
        bias = None

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    kernel, bias, x = promote_dtype(kernel, bias, x, dtype=None)

    # Converts the convolutional kernel of shape (features, in_features, n_symm)
    # to a full dense kernel of shape (features, in_features, n_symm, n_symm).
    # result[out, in, g, h] == kernel[out, in, g^{-1}h]
    # Input dimensions are [in, g], output dimensions are [out, h].
    kernel = jnp.take(kernel, jnp.asarray(self.product_table), 2)

    x = lax.dot_general(
        x,
        kernel,
        (((x.ndim - 2, x.ndim - 1), (1, 2)), ((), ())),
        precision=self.precision,
    )

    x = x.reshape(-1, self.features, self.n_symm)

    if self.use_bias:
        x += jnp.expand_dims(bias, 1)

    return x
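A toy illustration (hypothetical cyclic group Z_3, tiny shapes, not NetKet's API) of the jnp.take expansion above: indexing the symmetry-reduced kernel with a product table of g^{-1}h entries produces the full (n_symm, n_symm) kernel.

# Toy sketch: expanding a symmetry-reduced kernel with a Z_3 product table.
import jax.numpy as jnp

n_symm = 3
# product_table[g, h] = index of g^{-1} h in the cyclic group Z_3.
product_table = (jnp.arange(n_symm)[None, :] - jnp.arange(n_symm)[:, None]) % n_symm

kernel = jnp.array([10.0, 20.0, 30.0]).reshape(1, 1, n_symm)  # (out, in, n_symm)
full = jnp.take(kernel, product_table, 2)                     # (out, in, n_symm, n_symm)
assert full[0, 0, 1, 2] == kernel[0, 0, (2 - 1) % n_symm]     # full[o, i, g, h] == kernel[o, i, g^{-1}h]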