def testMixedFloatInt(self):
    tree = [jnp.array([3], jnp.int32), jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
    raveled, unravel = flatten_util.ravel_pytree(tree)
    self.assertEqual(raveled.dtype, jnp.promote_types(jnp.float32, jnp.int32))
    tree_ = unravel(raveled)
    self.assertAllClose(tree, tree_, atol=0., rtol=0.)
def testMixedIntBool(self):
    tree = [jnp.array([0], jnp.bool_), jnp.array([[1, 2], [3, 4]], jnp.int32)]
    raveled, unravel = flatten_util.ravel_pytree(tree)
    self.assertEqual(raveled.dtype, jnp.promote_types(jnp.bool_, jnp.int32))
    tree_ = unravel(raveled)
    self.assertAllClose(tree, tree_, atol=0., rtol=0.)
def testMixedFloatComplex(self):
    tree = [jnp.array([1.], jnp.float32), jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)]
    raveled, unravel = flatten_util.ravel_pytree(tree)
    self.assertEqual(raveled.dtype, jnp.promote_types(jnp.float32, jnp.complex64))
    tree_ = unravel(raveled)
    self.assertAllClose(tree, tree_, atol=0., rtol=0.)
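# A minimal standalone sketch (not part of the test suite above) illustrating the dtype
# promotion these tests assert: `ravel_pytree` flattens all leaves into a single 1D vector,
# so mixed-dtype leaves are first promoted to a common dtype via `jnp.promote_types`, and
# `unravel` casts each leaf back to its original dtype.
import jax.numpy as jnp
from jax import flatten_util

tree = [jnp.array([3], jnp.int32), jnp.array([[1.0, 2.0], [3.0, 4.0]], jnp.float32)]
raveled, unravel = flatten_util.ravel_pytree(tree)
print(raveled.dtype)                              # float32 == promote_types(float32, int32)
print([leaf.dtype for leaf in unravel(raveled)])  # leaves restored: [int32, float32]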
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements).
    """
    in_features = x.shape[-2]

    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:-1], *self.shape)

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_point * self.n_cells),
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    # Convert the convolutional kernel of shape (features, in_features, n_symm)
    # to the expanded kernel of shape (features, in_features, n_point(in),
    # n_point(out), *shape) used in FFT-based group convolutions
    kernel = kernel[..., self.mapping]

    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(*kernel.shape[:4], self.n_cells)

    x = lax.dot_general(
        x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
    )

    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)

    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:2], -1)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        x += jnp.expand_dims(bias, (0, 2))

    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real
def __call__(self, x_in: Array):
    nv = x_in.shape[-1]

    dtype = jnp.promote_types(x_in.dtype, self.dtype)
    x_in = jnp.asarray(x_in, dtype=dtype)

    kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)
    kernel = kernel + kernel.T

    y = jnp.einsum("...i,ij,...j", x_in, kernel, x_in)
    return y
def minkowski_distance_p(x, y, p=2):
    """
    Compute the pth power of the L**p distance between two arrays.

    For efficiency, this function computes the L**p distance but does
    not extract the pth root. If `p` is 1 or infinity, this is equal to
    the actual L**p distance.

    Parameters
    ----------
    x : (M, K) array_like
        Input array.
    y : (N, K) array_like
        Input array.
    p : float, 1 <= p <= infinity
        Which Minkowski p-norm to use.

    Examples
    --------
    >>> from scipy.spatial import minkowski_distance_p
    >>> minkowski_distance_p([[0,0],[0,0]], [[1,1],[0,1]])
    array([2., 1.])

    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Find the smallest common datatype with float64 (the return type of this
    # function) - addresses #10262.
    # Don't just cast to float64, to preserve the complex input case.
    common_datatype = np.promote_types(np.promote_types(x.dtype, y.dtype),
                                       'float64')

    # Make sure x and y are NumPy arrays of the correct datatype.
    x = x.astype(common_datatype)
    y = y.astype(common_datatype)

    if p == np.inf:
        return np.amax(np.abs(y - x), axis=-1)
    elif p == 1:
        return np.sum(np.abs(y - x), axis=-1)
    else:
        return np.sum(np.abs(y - x)**p, axis=-1)
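# A small sketch (not part of SciPy) of why the promotion above matters: casting straight to
# float64 would silently drop the imaginary part of complex input, whereas promoting with
# float64 keeps complex data complex and still upgrades integer input to floating point.
import numpy as np

print(np.promote_types(np.int64, 'float64'))          # float64
print(np.promote_types(np.complex64, 'float64'))      # complex128
print(minkowski_distance_p([[0, 0]], [[1 + 1j, 1]]))  # ~[3.]  (|1+1j|**2 + |1|**2)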
def __call__(self, inputs: Array) -> Array:
    """Applies a convolution to the inputs.

    Args:
        inputs: input data with dimensions (batch, spatial_dims..., features).

    Returns:
        The convolved data.
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    if isinstance(self.kernel_size, int):
        kernel_size = (self.kernel_size,)
    else:
        kernel_size = self.kernel_size

    is_single_input = False
    if inputs.ndim == len(kernel_size) + 1:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    strides = self.strides or (1,) * (inputs.ndim - 2)

    in_features = inputs.shape[-1]
    assert in_features % self.feature_group_count == 0
    kernel_shape = kernel_size + (
        in_features // self.feature_group_count,
        self.features,
    )
    kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
    kernel = jnp.asarray(kernel, dtype)

    dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
    y = lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        self.padding,
        lhs_dilation=self.input_dilation,
        rhs_dilation=self.kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=self.precision,
    )

    if is_single_input:
        y = jnp.squeeze(y, axis=0)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y = y + bias

    return y
def __call__(self, inputs: Array) -> Array:
    """
    Applies a masked linear transformation to the inputs.

    Args:
        inputs: input data with dimensions (batch, length, features).

    Returns:
        The transformed data.
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    is_single_input = False
    if inputs.ndim == 2:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    batch, size, in_features = inputs.shape
    inputs = inputs.reshape((batch, size * in_features))

    mask = jnp.ones((size, size), dtype=self.dtype)
    mask = jnp.triu(mask, self.exclusive)
    mask = jnp.kron(
        mask, jnp.ones((in_features, self.features), dtype=self.dtype)
    )

    kernel = self.param(
        "kernel",
        wrap_kernel_init(self.kernel_init, mask),
        (size * in_features, size * self.features),
        self.dtype,
    )

    mask = jnp.asarray(mask, dtype)
    kernel = jnp.asarray(kernel, dtype)

    y = lax.dot(inputs, mask * kernel, precision=self.precision)

    y = y.reshape((batch, size, self.features))

    if is_single_input:
        y = y.squeeze(axis=0)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (size, self.features), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y = y + bias

    return y
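# A standalone sketch (toy shapes, not part of the module above) of how the autoregressive
# mask is built: an upper-triangular (size, size) site mask is expanded by a Kronecker
# product so that each (in_features, features) block of the flattened kernel is either kept
# or zeroed as a whole, restricting which input sites each output site may depend on.
import jax.numpy as jnp

size, in_features, features, exclusive = 3, 2, 2, 1
site_mask = jnp.triu(jnp.ones((size, size)), exclusive)          # input-site -> output-site connectivity
full_mask = jnp.kron(site_mask, jnp.ones((in_features, features)))
print(site_mask)         # strictly upper triangular for exclusive=1
print(full_mask.shape)   # (size * in_features, size * features) == (6, 6)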
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements).
    """
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    x = (
        x.reshape(-1, self.n_cells, self.sites_per_cell)
        .transpose(0, 2, 1)
        .reshape(-1, self.sites_per_cell, *self.shape)
    )

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, self.n_cells * self.sites_per_cell),
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.mask, 0)

    kernel = self.make_kernel(kernel)

    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:2], self.n_cells)
    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(*kernel.shape[:3], self.n_cells)

    x = lax.dot_general(
        x, kernel, (((1,), (2,)), ((2,), (3,))), precision=self.precision
    )

    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)

    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    x = x.transpose(0, 1, 3, 2).reshape(*x.shape[:2], -1)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        x += jnp.expand_dims(bias, (0, 2))

    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real
def __call__(self, inputs: Array) -> Array:
    """Applies a transposed convolution to the inputs.

    Behaviour mirrors that of `jax.lax.conv_transpose`.

    Args:
        inputs: input data with dimensions (batch, spatial_dims..., features).

    Returns:
        The convolved data.
    """
    dtype = jnp.promote_types(self.dtype, inputs.dtype)
    inputs = jnp.asarray(inputs, dtype)

    if isinstance(self.kernel_size, int):
        kernel_size = (self.kernel_size,)
    else:
        kernel_size = self.kernel_size

    is_single_input = False
    if inputs.ndim == len(kernel_size) + 1:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    strides = self.strides or (1,) * (inputs.ndim - 2)

    in_features = inputs.shape[-1]
    kernel_shape = kernel_size + (in_features, self.features)
    kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
    kernel = jnp.asarray(kernel, dtype)

    y = lax.conv_transpose(
        inputs,
        kernel,
        strides,
        self.padding,
        rhs_dilation=self.kernel_dilation,
        precision=self.precision,
    )

    if is_single_input:
        y = jnp.squeeze(y, axis=0)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y = y + bias

    return y
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last dimension.

    Args:
        x: The nd-array to be transformed.

    Returns:
        The transformed input.
    """
    in_features = x.shape[-2]

    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_symm),
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    # Converts the convolutional kernel of shape (features, in_features, n_symm)
    # to a full dense kernel of shape (features, in_features, n_symm, n_symm)
    # result[out, in, g, h] == kernel[out, in, g^{-1}h]
    # input dimensions are [in, g], output dimensions are [out, h]
    kernel = jnp.take(kernel, jnp.asarray(self.product_table), 2)
    kernel = jnp.asarray(kernel, dtype)

    x = lax.dot_general(
        x,
        kernel,
        (((x.ndim - 2, x.ndim - 1), (1, 2)), ((), ())),
        precision=self.precision,
    )

    x = x.reshape(-1, self.features, self.n_symm)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        x += jnp.expand_dims(bias, 1)

    return x
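# A toy sketch (cyclic group of order 3, not NetKet's actual product table) of the kernel
# expansion above: `jnp.take(kernel, product_table, 2)` turns a kernel indexed by a single
# group element into one indexed by pairs (g, h), with entry kernel[..., g^{-1} h].
import jax.numpy as jnp

n_symm = 3
# product_table[g, h] = index of g^{-1} h for the cyclic group Z_3, where g.h = (g + h) mod 3
product_table = jnp.array([[(h - g) % n_symm for h in range(n_symm)] for g in range(n_symm)])
kernel = jnp.arange(n_symm, dtype=jnp.float32).reshape(1, 1, n_symm)  # (features, in_features, n_symm)
expanded = jnp.take(kernel, product_table, 2)
print(expanded.shape)  # (1, 1, 3, 3): full (features, in_features, n_symm, n_symm) kernel
print(expanded[0, 0])  # row g, column h holds kernel[0, 0, (h - g) % 3]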
def __init__(
    self,
    hilbert: AbstractHilbert,
    graph: AbstractGraph,
    h: float,
    J: float = 1.0,
    dtype: DType = None,
):
    r"""
    Constructs the Ising Operator from a Hilbert space and a graph
    specifying the connectivity.

    Args:
        hilbert: Hilbert space the operator acts on.
        graph: The graph specifying the connectivity of the system.
        h: The strength of the transverse field.
        J: The strength of the coupling. Default is 1.0.
        dtype: The dtype of the matrix elements.

    Examples:
        Constructs an ``Ising`` operator for a 1D system.

        >>> import netket as nk
        >>> g = nk.graph.Hypercube(length=20, n_dim=1, pbc=True)
        >>> hi = nk.hilbert.Spin(s=0.5, N=g.n_nodes)
        >>> op = nk.operator.Ising(h=1.321, hilbert=hi, J=0.5, graph=g)
        >>> print(op)
        Ising(J=0.5, h=1.321; dim=20)
    """
    assert (
        graph.n_nodes == hilbert.size
    ), "The size of the graph must match the hilbert space"

    super().__init__(hilbert)

    if dtype is None:
        dtype = jnp.promote_types(_dtype(h), _dtype(J))
    dtype = np.empty((), dtype=dtype).dtype
    self._dtype = dtype

    self._h = np.array(h, dtype=dtype)
    self._J = np.array(J, dtype=dtype)

    self._edges = np.asarray(
        [[u, v] for u, v in graph.edges()],
        dtype=np.intp,
    )
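# A small sketch (plain NumPy/JAX, outside the class above) of the default-dtype rule:
# when `dtype` is not given, the matrix-element dtype is the promotion of the dtypes of
# `h` and `J`, so a complex coupling makes the whole operator complex.
import numpy as np
import jax.numpy as jnp

print(jnp.promote_types(np.float64, np.float64))     # float64: real h, real J
print(jnp.promote_types(np.float64, np.complex128))  # complex128: real h, complex J
print(np.empty((), dtype=jnp.promote_types(np.float32, np.float64)).dtype)  # float64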
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements).
    """
    in_features = x.shape[-2]

    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    x = self.forward_ft(x)

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_symm),
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    kernel = self.forward_ft(kernel)

    x = tuple(
        lax.dot_general(
            x[i], kernel[i], (((1, 4), (1, 3)), ((2,), (2,)))
        ).transpose(1, 3, 0, 2, 4)
        for i in range(len(x))
    )

    x = self.inverse_ft(x)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        x += jnp.expand_dims(bias, (0, 2))

    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last dimension.

    Args:
        x: The nd-array to be transformed.

    Returns:
        The transformed input.
    """
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    x = x.reshape(-1, x.shape[1] * x.shape[2])

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.out_features, self.in_features, self.n_symm),
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.mask, (0, 1))

    kernel = self.full_kernel(kernel)
    kernel = jnp.asarray(kernel, dtype)

    x = lax.dot_general(
        x,
        kernel,
        (((x.ndim - 1,), (0,)), ((), ())),
        precision=self.precision,
    )

    x = x.reshape(-1, self.out_features, self.n_symm)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.out_features,), self.dtype)
        bias = jnp.asarray(self.full_bias(bias), dtype)
        x += jnp.expand_dims(bias, (0, 2))

    return x
def testFftn(self, inverse, real, shape, dtype, axes, s, norm):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    jnp_op = _get_fftn_func(jnp.fft, inverse, real)
    np_op = _get_fftn_func(np.fft, inverse, real)
    jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm)
    np_fn = lambda a: np_op(a, axes=axes, norm=norm) if axes is None or axes else a
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker)
    # Test gradient for differentiable types.
    if (config.x64_enabled and
            dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
        # TODO(skye): can we be more precise?
        tol = 0.15
        jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

    # check dtypes
    dtype = jnp_fn(rng(shape, dtype)).dtype
    expected_dtype = jnp.promote_types(float if inverse and real else complex, dtype)
    self.assertEqual(dtype, expected_dtype)
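# A brief sketch (outside the test) of the dtype check at the end of testFftn: the Python
# scalar types `complex` and `float` act as weak types under JAX's promotion rules, so they
# do not bump the precision of the transform's output dtype.
import jax.numpy as jnp

print(jnp.promote_types(complex, jnp.complex64))  # complex64: complex transforms of single precision
print(jnp.promote_types(float, jnp.float32))      # float32: inverse real transforms of single precision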
def __call__(self, x: Array) -> Array:
    """Applies the symmetrized linear transformation to the inputs along the last dimension.

    Args:
        x: The nd-array to be transformed.

    Returns:
        The transformed input.
    """
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    kernel = self.param(
        "kernel", self.kernel_init, (self.features, self.n_sites), self.dtype
    )

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.mask, 0)

    kernel = self.full_kernel(kernel).reshape(-1, self.features, self.n_symm)
    kernel = jnp.asarray(kernel, dtype)

    x = lax.dot_general(
        x,
        kernel,
        (((x.ndim - 1,), (0,)), ((), ())),
        precision=self.precision,
    )

    x = x.reshape(-1, self.features, self.n_symm)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(self.full_bias(bias), dtype)
        x += bias

    return x
def __call__(self, inputs: Array) -> Array:
    """Applies a linear transformation to the inputs along multiple dimensions.

    Args:
        inputs: The nd-array to be transformed.

    Returns:
        The transformed input.
    """
    features = _canonicalize_tuple(self.features)
    axis = _canonicalize_tuple(self.axis)
    batch_dims = _canonicalize_tuple(self.batch_dims)
    if batch_dims:
        max_dim = np.max(batch_dims)
        if set(batch_dims) != set(range(max_dim + 1)):
            raise ValueError(
                "batch_dims %s must be consecutive leading "
                "dimensions starting from 0." % str(batch_dims)
            )

    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    ndim = inputs.ndim
    n_batch_dims = len(batch_dims)
    axis = _normalize_axes(axis, ndim)
    batch_dims = _normalize_axes(batch_dims, ndim)
    n_axis, n_features = len(axis), len(features)

    def kernel_init_wrap(rng, shape, dtype=jnp.float64):
        size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
        flat_shape = (
            np.prod(shape[n_batch_dims : n_axis + n_batch_dims]),
            np.prod(shape[-n_features:]),
        )
        kernel = jnp.concatenate(
            [
                self.kernel_init(rng, flat_shape, dtype)
                for _ in range(size_batch_dims)
            ],
            axis=0,
        )
        return jnp.reshape(kernel, shape)

    batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
    kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
    kernel = self.param("kernel", kernel_init_wrap, batch_shape + kernel_shape)
    kernel = jnp.asarray(kernel, dtype)

    batch_ind = tuple(range(n_batch_dims))
    contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))

    out = lax.dot_general(
        inputs,
        kernel,
        ((axis, contract_ind), (batch_dims, batch_ind)),
        precision=self.precision,
    )

    if self.use_bias:

        def bias_init_wrap(rng, shape, dtype=jnp.float64):
            size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
            flat_shape = (np.prod(shape[-n_features:]),)
            bias = jnp.concatenate(
                [
                    self.bias_init(rng, flat_shape, dtype)
                    for _ in range(size_batch_dims)
                ],
                axis=0,
            )
            return jnp.reshape(bias, shape)

        bias = self.param("bias", bias_init_wrap, batch_shape + features)

        # Reshape bias for broadcast.
        expand_dims = sorted(set(range(inputs.ndim)) - set(axis) - set(batch_dims))
        for ax in expand_dims:
            bias = jnp.expand_dims(bias, ax)
        bias = jnp.asarray(bias, dtype)
        out = out + bias

    return out
def __call__(self, x: Array) -> Array:
    """Applies the symmetrized linear transformation to the inputs along the last dimension.

    Args:
        x: The nd-array to be transformed.

    Returns:
        The transformed input.
    """
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    # infer in_features and ensure input dimensions (batch, in_features, n_sites)
    # TODO: Deprecated: Eventually remove and error if less than 3 dimensions
    if x.ndim < 3:
        old_shape = x.shape
        if x.ndim == 1:
            x = jnp.expand_dims(x, (0, 1))
        elif x.ndim == 2:
            x = jnp.expand_dims(x, 1)
        symm_input_warning(old_shape, x.shape, "DenseSymm")

    in_features = x.shape[1]

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_sites),
        self.dtype,
    )

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    # Converts the convolutional kernel of shape (self.features, in_features, n_sites)
    # to a full dense kernel of shape (self.features, in_features, n_symm, n_sites).
    # result[out, in, g, r] == kernel[out, in, g^{-1}r]
    kernel = jnp.take(kernel, jnp.asarray(self.symmetries), 2)
    kernel = jnp.asarray(kernel, dtype)

    # x is (batches, in_features, n_sites)
    # kernel is (self.features, in_features, n_symm, n_sites)
    x = lax.dot_general(
        x,
        kernel,
        (((x.ndim - 2, x.ndim - 1), (1, 3)), ((), ())),
        precision=self.precision,
    )

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        # Convert symmetry-reduced bias of shape (features,) to the full bias of
        # shape (..., features, 1).
        bias = jnp.expand_dims(bias, 1)
        bias = jnp.asarray(bias, dtype)
        x += bias

    return x
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements).
    """
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    # TODO: Deprecated: Eventually remove and error if less than 3 dimensions
    # infer in_features and ensure input dimensions (batch, in_features, n_sites)
    if x.ndim < 3:
        old_shape = x.shape
        if x.ndim == 1:
            x = jnp.expand_dims(x, (0, 1))
        elif x.ndim == 2:
            x = jnp.expand_dims(x, 1)
        symm_input_warning(old_shape, x.shape, "DenseSymm")

    in_features = x.shape[1]

    x = x.reshape(*x.shape[:-1], self.n_cells, self.sites_per_cell)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:-1], *self.shape)

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_cells * self.sites_per_cell),
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    # Converts the convolutional kernel of shape (features, in_features, n_sites)
    # to the expanded kernel of shape (features, in_features, sites_per_cell,
    # n_point, *shape) used in FFT-based group convolutions.
    kernel = kernel[..., self.mapping]

    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(*kernel.shape[:4], self.n_cells)

    # TODO: the batch ordering should be revised: batch dimensions should
    # be leading
    x = lax.dot_general(
        x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
    )

    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)

    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    x = x.transpose(0, 1, 3, 2).reshape(*x.shape[:2], -1)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        x += jnp.expand_dims(bias, (0, 2))

    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real
def update_site(self, inputs: Array, index: int) -> Array:
    """
    Adds an input site into the cache, and applies the masked linear transformation to the cache.

    Args:
        inputs: an input site to be added into the cache with dimensions (batch, features).
        index: the index of the output site. The index of the input site should be
            `index - self.exclusive`.

    Returns:
        The output site with dimensions (batch, features).
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    is_single_input = False
    if inputs.ndim == 1:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    batch, in_features = inputs.shape
    size = self.size

    # Number of input sites that the output site at `index` depends on
    size_i = index + 1

    # Initialize the cache with zeros, and the RNG key is None
    # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
    _cache = self.variable(
        "cache", "inputs", zeros, None, (batch, size, in_features), inputs.dtype
    )

    initializing = self.is_mutable_collection("params")
    if not initializing:
        # Add the input site into the cache
        # To write the cache, use `_cache.value` as the left value of the assignment
        _cache.value = lax.cond(
            index - self.exclusive >= 0,
            lambda _: _cache.value.at[:, index - self.exclusive, :].set(inputs),
            lambda _: _cache.value,
            None,
        )

    cache = _cache.value
    cache = jnp.asarray(cache, dtype)

    cache_i = cache[:, :size_i, :]
    cache_i = cache_i.reshape((batch, size_i * in_features))

    # The construction of `mask` will be optimized to a constant by JIT
    mask = jnp.ones((size, size), dtype=self.dtype)
    mask = jnp.triu(mask, self.exclusive)
    mask = jnp.kron(
        mask, jnp.ones((in_features, self.features), dtype=self.dtype)
    )

    kernel = self.param(
        "kernel",
        wrap_kernel_init(self.kernel_init, mask),
        (size * in_features, size * self.features),
        self.dtype,
    )

    mask = jnp.asarray(mask, dtype)
    kernel = jnp.asarray(kernel, dtype)

    mask_i = mask.reshape((size, in_features, size, self.features))
    mask_i = mask_i[:size_i, :, index, :]
    mask_i = mask_i.reshape((size_i * in_features, self.features))

    kernel_i = kernel.reshape((size, in_features, size, self.features))
    kernel_i = kernel_i[:size_i, :, index, :]
    kernel_i = kernel_i.reshape((size_i * in_features, self.features))

    y_i = lax.dot(cache_i, mask_i * kernel_i, precision=self.precision)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (size, self.features), self.dtype)
        bias = jnp.asarray(bias, dtype)
        bias_i = bias[index, :]
        y_i = y_i + bias_i

    assert y_i.shape[1] == self.features

    if is_single_input:
        y_i = y_i.squeeze(axis=0)

    return y_i
def update_site(self, inputs: Array, index: int) -> Array:
    """
    Adds an input site into the cache, and applies the masked convolution to the cache.

    Args:
        inputs: an input site to be added into the cache with dimensions (batch, features).
        index: the index of the output site. The index of the input site should be
            `index - self.exclusive`.

    Returns:
        The next output site with dimensions (batch, features).
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    L = self.L
    index_w = index % L

    kernel_h, kernel_w = self.kernel_size
    dilation_h, dilation_w = self.kernel_dilation
    ones = (1, 1)

    is_single_input = False
    if inputs.ndim == 1:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    batch, in_features = inputs.shape
    assert in_features % self.feature_group_count == 0

    recep_h = (kernel_h - 1) * dilation_h + 1
    recep_w = (kernel_w - 1) * dilation_w + 1

    # Initialize the cache with zeros, and the RNG key is None
    # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
    _cache = self.variable(
        "cache",
        "inputs",
        zeros,
        None,
        (batch, recep_h, L, in_features),
        inputs.dtype,
    )

    initializing = self.is_mutable_collection("params")
    if not initializing:
        # Add the input site into the cache
        # To write the cache, use `_cache.value` as the left value of the assignment
        inputs = jnp.expand_dims(inputs, axis=(1, 2))

        # Index of the input site in the width direction
        index_w_in = (index - self.exclusive) % L

        def _add(cache):
            # return cache.at[:, -1, index_w_in, :].set(inputs)
            return lax.dynamic_update_slice(cache, inputs, (0, -1, index_w_in, 0))

        def _shift(cache):
            return jnp.concatenate(
                [
                    cache[:, 1:, :, :],
                    jnp.zeros((batch, 1, L, in_features), dtype=inputs.dtype),
                ],
                axis=1,
            )

        cache_new_row = lax.cond(
            index_w_in == 0,
            lambda _: _add(_shift(_cache.value)),
            lambda _: _shift(_add(_cache.value)),
            None,
        )

        cache_new = lax.cond(
            index_w == 0,
            lambda _: cache_new_row,
            lambda _: _add(_cache.value),
            None,
        )

        _cache.value = lax.cond(
            index - self.exclusive >= 0,
            lambda _: cache_new,
            lambda _: _cache.value,
            None,
        )

    cache = _cache.value
    cache = jnp.asarray(cache, dtype)

    kernel_shape = self.kernel_size + (
        in_features // self.feature_group_count,
        self.features,
    )
    kernel = self.param(
        "kernel",
        wrap_kernel_init(self.kernel_init, self.mask),
        kernel_shape,
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)

    # Zero padding
    cache = jnp.pad(
        cache,
        (
            (0, 0),
            (0, 0),
            (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
            (0, 0),
        ),
    )

    # cache = cache[:, :, index_w : index_w + recep_w, :]
    cache = lax.dynamic_slice(
        cache, (0, 0, index_w, 0), (batch, recep_h, recep_w, in_features)
    )

    dimension_numbers = flax.linen.linear._conv_dimension_numbers(cache.shape)
    y_i = lax.conv_general_dilated(
        cache,
        kernel,
        window_strides=ones,
        padding="VALID",
        lhs_dilation=ones,
        rhs_dilation=self.kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=self.precision,
    )

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y_i = y_i + bias

    y_i = y_i.squeeze(axis=(1, 2))

    if is_single_input:
        y_i = y_i.squeeze(axis=0)

    return y_i
def update_site(self, inputs: Array, index: int) -> Array:
    """
    Adds an input site into the cache, and applies the masked convolution to the cache.

    Args:
        inputs: an input site to be added into the cache with dimensions (batch, features).
        index: the index of the output site. The index of the input site should be
            `index - self.exclusive`.

    Returns:
        The next output site with dimensions (batch, features).
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    kernel_size = self.kernel_size - self.exclusive
    dilation = self.kernel_dilation

    is_single_input = False
    if inputs.ndim == 1:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    batch, in_features = inputs.shape
    assert in_features % self.feature_group_count == 0

    cache_size = kernel_size * dilation - (not self.exclusive) * (dilation - 1)

    # Initialize the cache with zeros, and the RNG key is None
    # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
    _cache = self.variable(
        "cache", "inputs", zeros, None, (batch, cache_size, in_features), inputs.dtype
    )

    initializing = self.is_mutable_collection("params")
    if not initializing:
        # Add the input site into the cache
        # To write the cache, use `_cache.value` as the left value of the assignment
        _cache.value = lax.cond(
            index - self.exclusive >= 0,
            lambda _: jnp.concatenate(
                [_cache.value[:, 1:, :], jnp.expand_dims(inputs, axis=1)], axis=1
            ),
            lambda _: _cache.value,
            None,
        )

    cache = _cache.value
    cache = jnp.asarray(cache, dtype)

    kernel_shape = (
        kernel_size,
        in_features // self.feature_group_count,
        self.features,
    )
    kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
    kernel = jnp.asarray(kernel, dtype)

    if self.exclusive and dilation > 1:
        cache = cache[:, : -(dilation - 1), :]

    dimension_numbers = flax.linen.linear._conv_dimension_numbers(cache.shape)
    y_i = lax.conv_general_dilated(
        cache,
        kernel,
        window_strides=(1,),
        padding="VALID",
        lhs_dilation=(1,),
        rhs_dilation=(dilation,),
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=self.precision,
    )

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y_i = y_i + bias

    y_i = y_i.squeeze(axis=1)

    if is_single_input:
        y_i = y_i.squeeze(axis=0)

    return y_i
def __call__(self, inputs: Array) -> Array:
    """
    Applies a masked convolution to the inputs.

    For 1D convolution, there is not really a mask. We only need to apply
    appropriate padding.

    Args:
        inputs: input data with dimensions (batch, length, features).

    Returns:
        The convolved data.
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    kernel_size = self.kernel_size - self.exclusive
    dilation = self.kernel_dilation

    is_single_input = False
    if inputs.ndim == 2:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    in_features = inputs.shape[-1]
    assert in_features % self.feature_group_count == 0
    kernel_shape = (
        kernel_size,
        in_features // self.feature_group_count,
        self.features,
    )
    kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
    kernel = jnp.asarray(kernel, dtype)

    if self.exclusive:
        inputs = inputs[:, :-dilation, :]

    # Zero padding
    y = jnp.pad(
        inputs,
        (
            (0, 0),
            ((kernel_size - (not self.exclusive)) * dilation, 0),
            (0, 0),
        ),
    )

    dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
    y = lax.conv_general_dilated(
        y,
        kernel,
        window_strides=(1,),
        padding="VALID",
        lhs_dilation=(1,),
        rhs_dilation=(dilation,),
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=self.precision,
    )

    if is_single_input:
        y = y.squeeze(axis=0)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y = y + bias

    return y
def __call__(self, inputs: Array) -> Array:
    """
    Applies a masked convolution to the inputs.

    Args:
        inputs: input data with dimensions (batch, width, height, features).

    Returns:
        The convolved data.
    """
    dtype = jnp.promote_types(inputs.dtype, self.dtype)
    inputs = jnp.asarray(inputs, dtype)

    kernel_h, kernel_w = self.kernel_size
    dilation_h, dilation_w = self.kernel_dilation
    ones = (1, 1)

    is_single_input = False
    if inputs.ndim == 3:
        is_single_input = True
        inputs = jnp.expand_dims(inputs, axis=0)

    in_features = inputs.shape[-1]
    assert in_features % self.feature_group_count == 0
    kernel_shape = self.kernel_size + (
        in_features // self.feature_group_count,
        self.features,
    )
    kernel = self.param(
        "kernel",
        wrap_kernel_init(self.kernel_init, self.mask),
        kernel_shape,
        self.dtype,
    )

    mask = jnp.asarray(self.mask, dtype)
    kernel = jnp.asarray(kernel, dtype)

    # Zero padding
    y = jnp.pad(
        inputs,
        (
            (0, 0),
            ((kernel_h - 1) * dilation_h, 0),
            (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
            (0, 0),
        ),
    )

    dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
    y = lax.conv_general_dilated(
        y,
        mask * kernel,
        window_strides=ones,
        padding="VALID",
        lhs_dilation=ones,
        rhs_dilation=self.kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=self.precision,
    )

    if is_single_input:
        y = y.squeeze(axis=0)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        bias = jnp.asarray(bias, dtype)
        y = y + bias

    return y