def _multiply_operators(
    hilbert, support_A: Tuple, A: Array, support_B: Tuple, B: Array, *, dtype
) -> Tuple[Tuple, Array]:
    """
    Returns the `Tuple[acting_on, Matrix]` representing the operator obtained
    by multiplying the two input operators A and B (i.e. ``A @ B``).

    Three cases are handled:

    * identical supports (same sites in the same order): the plain matrix
      product ``A @ B``;
    * disjoint supports: the Kronecker product ``A ⊗ B``, passed through
      ``_reorder_kronecker_product``;
    * partially overlapping supports: each operator is padded with identity
      matrices on the sites it does not act on, both are passed through
      ``_reorder_kronecker_product`` and then multiplied.

    Args:
        hilbert: The Hilbert space; ``hilbert.shape[site]`` is used as the
            local dimension of ``site`` when building identity matrices.
        support_A: The sites ``A`` acts on.
        A: The matrix of ``A`` on ``support_A``.
        support_B: The sites ``B`` acts on.
        B: The matrix of ``B`` on ``support_B``.
        dtype: The dtype of the identity matrices used for padding.

    Returns:
        A tuple ``(support, matrix)`` with the sites the product acts on and
        its matrix.

    Raises:
        ValueError: If, after padding and reordering, the two operators still
            do not act on the same sites in the same order.
    """
    support_A = np.asarray(support_A)
    support_B = np.asarray(support_B)

    # Same sites in the same order (np.array_equal also compares shapes).
    if np.array_equal(support_A, support_B):
        return tuple(support_A), A @ B

    if np.intersect1d(support_A, support_B).size == 0:
        # disjoint supports
        support = tuple(np.concatenate([support_A, support_B]))
        operator = np.kron(A, B)
        operator, support = _reorder_kronecker_product(hilbert, operator, support)
        return tuple(support), operator

    def _pad_with_identities(op, support, other_support):
        """Pad `op` with identities on the sites of `other_support` missing
        from `support`, then reorder it; returns ``(matrix, sites)``."""
        support_min = min(support)
        padded_support = list(support)
        for site in other_support:
            if site in support:
                continue
            eye = np.eye(hilbert.shape[site], dtype=dtype)
            if site < support_min:
                # prepend: the new factor goes on the left of the product
                padded_support = [site] + padded_support
                op = np.kron(eye, op)
            else:
                # append: this also covers sites lying between existing ones;
                # _reorder_kronecker_product puts them in place afterwards
                padded_support.append(site)
                op = np.kron(op, eye)
        return _reorder_kronecker_product(hilbert, op, padded_support)

    # np.kron always builds new arrays, so A and B are never modified and no
    # defensive copies are needed.
    _A, _support_A = _pad_with_identities(A, support_A, support_B)
    _B, _support_B = _pad_with_identities(B, support_B, support_A)

    if not np.array_equal(_support_A, _support_B):
        raise ValueError(
            "Could not bring the operators to a common support: got "
            f"{tuple(_support_A)} and {tuple(_support_B)}."
        )
    # both operators now act on the same sites in the same order
    return tuple(_support_A), _A @ _B
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements)

    The convolution over the translation grid ``self.shape`` is computed in
    Fourier space: input and kernel are FFT'd over the grid, contracted
    mode-by-mode with ``lax.dot_general`` and transformed back.

    Args:
        x: The input array. NOTE(review): it is reshaped below to
            ``(-1, self.n_cells, self.sites_per_cell)``, which folds all
            leading axes (including any feature axis) into one — confirm this
            matches the axis convention stated above.

    Returns:
        The transformed array with three axes, presumably
        ``(batch, self.features, -1)`` (the bias broadcast assumes axis 1 is
        the feature axis). If the complex FFT result cannot be cast to the
        computation dtype, only its real part is returned.
    """
    # computation dtype: promotion of the input dtype and the module dtype
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    # split the trailing elements into (n_cells, sites_per_cell) blocks, move
    # the sites_per_cell axis in front and lay the cells out on the grid:
    # result is (batch, sites_per_cell, *self.shape)
    x = (x.reshape(-1, self.n_cells, self.sites_per_cell).transpose(
        0, 2, 1).reshape(-1, self.sites_per_cell, *self.shape))

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, self.n_cells * self.sites_per_cell),
        self.dtype,
    )

    kernel = jnp.asarray(kernel, dtype)

    if self.mask is not None:
        # the mask is broadcast over the leading (feature) axis
        kernel = kernel * jnp.expand_dims(self.mask, 0)

    # expand the compact kernel; the reshape of its FFT below assumes the
    # result has 3 leading axes followed by the self.shape grid — TODO confirm
    kernel = self.make_kernel(kernel)

    # FFT over the last len(self.shape) axes (the cell grid), then flatten the
    # grid into one n_cells axis; x.shape[:2] still refers to the pre-FFT x
    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:2], self.n_cells)

    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(*kernel.shape[:3], self.n_cells)

    # contract x axis 1 (sites_per_cell) with kernel axis 2 and batch over the
    # Fourier-mode axes (x axis 2, kernel axis 3); dot_general puts the batch
    # axis first: (n_cells, batch, kernel axis 0, kernel axis 1)
    x = lax.dot_general(x, kernel, (((1, ), (2, )), ((2, ), (3, ))), precision=self.precision)
    # move the Fourier-mode axis last and restore the grid shape
    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)

    # inverse FFT over the grid, then order the trailing axes as
    # (cell, kernel axis 1) and flatten them into one axis
    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    x = x.transpose(0, 1, 3, 2).reshape(*x.shape[:2], -1)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features, ), self.dtype)
        bias = jnp.asarray(bias, dtype)
        # one bias per axis-1 entry, broadcast over axes 0 and 2
        x += jnp.expand_dims(bias, (0, 2))

    # the FFT round trip yields a complex array: keep it only if that is
    # compatible with the computation dtype, otherwise drop the imaginary part
    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements)

    The group convolution is computed in Fourier space over the translation
    grid ``self.shape``: the group-element axis is split into
    ``(n_cells, n_point)``, input and kernel are FFT'd over the cell grid,
    contracted mode-by-mode with ``lax.dot_general`` and transformed back.

    Args:
        x: The input array; the ``transpose(0, 1, 3, 2)`` below requires it
            to be 3-dimensional, ``(batch, in_features, n_cells * n_point)``.

    Returns:
        An array of shape ``(batch, self.features, n_cells * n_point)``. If
        the complex FFT result cannot be cast to the promoted input dtype,
        only its real part is returned.
    """
    in_features = x.shape[-2]

    # (batch, in_features, n_cells * n_point)
    #   -> (batch, in_features, n_point, *self.shape)
    x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:-1], *self.shape)

    if self.use_bias:
        bias = self.param(
            "bias", self.bias_init, (self.features,), self.param_dtype
        )
    else:
        bias = None

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_point * self.n_cells),
        self.param_dtype,
    )

    if self.mask is not None:
        # the mask is broadcast over the two feature axes
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    # bring input, kernel and bias to a common dtype (bias may be None)
    x, kernel, bias = promote_dtype(x, kernel, bias, dtype=None)
    dtype = x.dtype

    # Convert the convolutional kernel of shape (features, in_features, n_symm)
    # to the expanded kernel of shape (features, in_features, n_point(in),
    # n_point(out), *shape) used in FFT-based group convolutions
    kernel = kernel[..., self.mapping]

    # FFT over the last len(self.shape) axes (the cell grid), then flatten the
    # grid into one n_cells axis; x.shape[:3] still refers to the pre-FFT x
    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
        *kernel.shape[:4], self.n_cells
    )

    # contract (in_features, n_point(in)) and batch over the Fourier modes;
    # dot_general puts the batch axis first:
    # (n_cells, batch, features, n_point(out))
    x = lax.dot_general(
        x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
    )
    # move the Fourier-mode axis last and restore the grid shape
    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)

    # inverse FFT, then flatten (n_cells, n_point) back into the group axis,
    # matching the split performed on the input
    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:2], -1)

    if self.use_bias:
        # one bias per output feature, broadcast over batch and group axes
        x += jnp.expand_dims(bias, (0, 2))

    # keep the complex result only if it is compatible with the input dtype
    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real
def __call__(self, x: Array) -> Array:
    """Apply the group-equivariant dense map to ``x``.

    Axes 1 and 2 of ``x`` (input features and group elements) are flattened
    together and multiplied with the dense kernel produced by
    ``self.full_kernel`` from the compact parameter of shape
    ``(out_features, in_features, n_symm)``.

    Args:
        x: The nd-array to be transformed.

    Returns:
        The transformed input, of shape ``(-1, out_features, n_symm)``.
    """
    compute_dtype = jnp.promote_types(x.dtype, self.dtype)
    inputs = jnp.asarray(x, compute_dtype)
    inputs = inputs.reshape(-1, inputs.shape[1] * inputs.shape[2])

    weights = jnp.asarray(
        self.param(
            "kernel",
            self.kernel_init,
            (self.out_features, self.in_features, self.n_symm),
            self.dtype,
        ),
        compute_dtype,
    )
    if self.mask is not None:
        # mask is shared by every (out, in) feature pair
        weights = weights * jnp.expand_dims(self.mask, (0, 1))

    # full_kernel may change the dtype, hence the second cast
    dense_weights = jnp.asarray(self.full_kernel(weights), compute_dtype)

    contraction = ((inputs.ndim - 1,), (0,))
    out = lax.dot_general(
        inputs,
        dense_weights,
        (contraction, ((), ())),
        precision=self.precision,
    )
    out = out.reshape(-1, self.out_features, self.n_symm)

    if not self.use_bias:
        return out

    raw_bias = self.param("bias", self.bias_init, (self.out_features,), self.dtype)
    full_bias = jnp.asarray(self.full_bias(raw_bias), compute_dtype)
    return out + jnp.expand_dims(full_bias, (0, 2))
def __call__(self, inputs: Array) -> Array:
    """
    Apply an autoregressively masked dense layer.

    The kernel is multiplied by a block upper-triangular mask, so output
    site ``j`` only depends on input sites ``i`` with ``i <= j - exclusive``.

    Args:
      inputs: input data with dimensions (batch, length, features), or
        (length, features) for a single unbatched input.

    Returns:
      The transformed data, of shape (batch, length, self.features), or
      (length, self.features) for an unbatched input.
    """
    is_single_input = inputs.ndim == 2
    if is_single_input:
        inputs = jnp.expand_dims(inputs, axis=0)

    batch, size, in_features = inputs.shape
    inputs = inputs.reshape((batch, size * in_features))

    # bias is created before the kernel, as in the original parameter order
    bias = (
        self.param("bias", self.bias_init, (size, self.features), self.param_dtype)
        if self.use_bias
        else None
    )

    site_mask = jnp.triu(jnp.ones((size, size), dtype=self.param_dtype), self.exclusive)
    feature_block = jnp.ones((in_features, self.features), dtype=self.param_dtype)
    mask = jnp.kron(site_mask, feature_block)

    kernel = self.param(
        "kernel",
        wrap_kernel_init(self.kernel_init, mask),
        (size * in_features, size * self.features),
        self.param_dtype,
    )

    inputs, mask, kernel, bias = promote_dtype(inputs, mask, kernel, bias, dtype=None)

    y = lax.dot(inputs, mask * kernel, precision=self.precision).reshape(
        (batch, size, self.features)
    )
    if is_single_input:
        y = y.squeeze(axis=0)

    return y + bias if self.use_bias else y
def _to_int_vector(v: Array) -> str:
    """Format the 3-vector ``v`` as a compact string.

    Three representations are tried in order:

    1. ``v`` converted by ``__to_int_vector``, printed as ``[a,b,c]``;
    2. ``v`` with its y component divided by √3 (hexagonal symmetries often
       produce such vectors), printed as ``[a,b√3,c]``;
    3. ``v`` normalised to unit length, printed with three decimals.

    A ``ValueError`` from ``__to_int_vector`` moves on to the next option.
    """
    try:
        ints = __to_int_vector(v)
    except ValueError:
        pass
    else:
        return f"[{ints[0]},{ints[1]},{ints[2]}]"

    # in hexagonal symmetry, you often get a √3 in the x/y coordinate
    try:
        scaled = v.copy()
        scaled[1] /= 3**0.5
        ints = __to_int_vector(scaled)
    except ValueError:
        pass
    else:
        return f"[{ints[0]},{ints[1]}√3,{ints[2]}]"

    # fall back to the normalised vector
    unit = v / np.linalg.norm(v)
    return f"[{unit[0]:.3f},{unit[1]:.3f},{unit[2]:.3f}]"
def __call__(self, x: Array) -> Array:
    """Apply the group-equivariant dense map over the last two axes of ``x``.

    The compact kernel of shape ``(features, in_features, n_symm)`` is expanded
    with ``self.product_table`` into a dense kernel of shape
    ``(features, in_features, n_symm, n_symm)``, which is contracted with the
    ``(in_features, n_symm)`` trailing axes of ``x``.

    Args:
        x: The nd-array to be transformed; axis -2 holds the input features
            and axis -1 the group elements.

    Returns:
        The transformed input, of shape ``(-1, features, n_symm)``.
    """
    in_features = x.shape[-2]

    # kernel is created before the bias, as in the original parameter order
    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_symm),
        self.param_dtype,
    )
    bias = (
        self.param("bias", self.bias_init, (self.features,), self.param_dtype)
        if self.use_bias
        else None
    )

    if self.mask is not None:
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    kernel, bias, x = promote_dtype(kernel, bias, x, dtype=None)

    # result[out, in, g, h] == kernel[out, in, g^{-1}h]
    # input dimensions are [in, g], output dimensions are [out, h]
    dense_kernel = jnp.take(kernel, jnp.asarray(self.product_table), 2)

    lhs_axes = (x.ndim - 2, x.ndim - 1)
    y = lax.dot_general(
        x,
        dense_kernel,
        ((lhs_axes, (1, 2)), ((), ())),
        precision=self.precision,
    )
    y = y.reshape(-1, self.features, self.n_symm)

    if self.use_bias:
        # one bias per output feature, broadcast over the group axis
        y += jnp.expand_dims(bias, 1)
    return y
def __call__(self, x: Array) -> Array:
    """Apply the symmetrized dense map along the last axis of ``x``.

    The compact kernel of shape ``(features, n_sites)`` is expanded by
    ``self.full_kernel`` and reshaped to ``(-1, features, n_symm)``; the last
    axis of ``x`` is contracted with its first axis.

    Args:
        x: The nd-array to be transformed.

    Returns:
        The transformed input, of shape ``(-1, features, n_symm)``.
    """
    compute_dtype = jnp.promote_types(x.dtype, self.dtype)
    inputs = jnp.asarray(x, compute_dtype)

    weights = self.param(
        "kernel", self.kernel_init, (self.features, self.n_sites), self.dtype
    )
    if self.mask is not None:
        # mask is shared by every output feature
        weights = weights * jnp.expand_dims(self.mask, 0)

    dense_weights = self.full_kernel(weights).reshape(-1, self.features, self.n_symm)
    dense_weights = jnp.asarray(dense_weights, compute_dtype)

    contraction = ((inputs.ndim - 1,), (0,))
    out = lax.dot_general(
        inputs,
        dense_weights,
        (contraction, ((), ())),
        precision=self.precision,
    )
    out = out.reshape(-1, self.features, self.n_symm)

    if self.use_bias:
        raw_bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
        out = out + jnp.asarray(self.full_bias(raw_bias), compute_dtype)
    return out
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    chunk_size: int = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩
    divided by √n

    In a somewhat intransparent way this also internally splits all parameters to real
    in the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and general C→C) resulting
    in the respective ΔOⱼₖ which is only compatible with split-to-real pytree vectors

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model; ``None`` is
            treated as an empty state
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used
        chunk_size: an int specifying the size of the chunks the gradient should
            be computed in (default: None)

    Returns:
        if not rescale_shift:
            the centered jacobian of ln Ψ evaluated at the samples σ, divided by √n
            and flattened to a 2D array whose last axis is kept; None
        else:
            the result of ``_rescale`` on that array (presumably the entries for
            each parameter normalised to unit norm, and the norms that were
            divided out)

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported modes.
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # `{**None}` would raise a TypeError, so treat a missing state as empty
    if model_state is None:
        model_state = {}

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        # convert C→R and R&C→R to R→R
        split_complex_params = True
        jacobian_fun = dense_jacobian_real_holo
    elif mode == "complex":
        # convert C→C and R&C→C to R→C
        split_complex_params = True
        jacobian_fun = dense_jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        jacobian_fun = dense_jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or "holomorphic", got {}'.format(
                mode
            )
        )

    # Stored as contiguous real stacked on top of contiguous imaginary (SOA)
    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_reim(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    def gradf_fun(params, σ):
        return jacobian_fun(f, params, σ)

    # per-sample jacobians: params are shared (in_axes None), samples are
    # mapped over axis 0, `chunk_size` samples at a time
    jacobians = nkjax.vmap_chunked(gradf_fun, in_axes=(None, 0), chunk_size=chunk_size)(
        params, samples
    )

    # total number of samples over all MPI processes, assuming every process
    # holds the same number of local samples
    n_samp = samples.shape[0] * mpi.n_nodes
    centered_oks = subtract_mean(jacobians, axis=0) / np.sqrt(
        n_samp, dtype=jacobians.dtype
    )
    centered_oks = centered_oks.reshape(-1, centered_oks.shape[-1])

    if rescale_shift:
        return _rescale(centered_oks)
    else:
        return centered_oks, None
def _reshape_inputs(model: FastARNNConv2D, inputs: Array) -> Array:  # noqa: F811
    """Reshape ``inputs`` to ``(inputs.shape[0], model.L, model.L)`` for ``FastARNNConv2D``."""
    n_batch = inputs.shape[0]
    side = model.L
    return inputs.reshape((n_batch, side, side))
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    pdf=None,
    chunk_size: int = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩
    divided by √n

    In a somewhat intransparent way this also internally splits all parameters to real
    in the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and general C→C) resulting
    in the respective ΔOⱼₖ which is only compatible with split-to-real pytree vectors

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model; ``None`` is
            treated as an empty state
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used
        pdf: |ψ(x)|^2 if exact optimization is being used else None
        chunk_size: an int specifying the size of the chunks the gradient should
            be computed in (default: None); only used when ``pdf`` is None

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at the
            samples σ, divided by √n (or weighted by √pdf when ``pdf`` is given); None
        else:
            the same pytree, but the entries for each parameter normalised to unit norm;
            pytree containing the norms that were divided out (same shape as params)

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported modes.
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # `{**None}` would raise a TypeError, so treat a missing state as empty
    if model_state is None:
        model_state = {}

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        # convert C→R and R&C→R to R→R
        split_complex_params = True
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    elif mode == "complex":
        # convert C→C and R&C→C to R→C
        split_complex_params = True
        # avoid converting to complex and then back by passing around the oks
        # as a tuple of two pytrees representing the real and imag parts
        centered_jacobian_fun = compose(
            stack_jacobian_tuple,
            partial(centered_jacobian_cplx, _build_fn=lambda *x: x),
        )
        jacobian_fun = jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or "holomorphic", got {}'
            .format(mode))

    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_real(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    if pdf is None:
        # sampled case: centre and divide by √n
        centered_oks = _divide_by_sqrt_n_samp(
            centered_jacobian_fun(
                f,
                params,
                samples,
                chunk_size=chunk_size,
            ),
            samples,
        )
    else:
        # exact case: the mean is the pdf-weighted sum over samples, and the
        # centred jacobians are weighted by √pdf.
        # NOTE(review): `sum` is called with `axis=`, so it must be an array
        # sum bound at module level (the builtin would raise) — confirm import.
        oks = jacobian_fun(f, params, samples)
        oks_mean = jax.tree_util.tree_map(
            partial(sum, axis=0), _multiply_by_pdf(oks, pdf)
        )
        centered_oks = jax.tree_util.tree_map(lambda x, y: x - y, oks, oks_mean)
        centered_oks = _multiply_by_pdf(centered_oks, jnp.sqrt(pdf))

    if rescale_shift:
        return _rescale(centered_oks)
    else:
        return centered_oks, None
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements)

    The convolution over the translation grid ``self.shape`` is computed in
    Fourier space: the trailing axis is split into
    ``(n_cells, sites_per_cell)``, input and kernel are FFT'd over the cell
    grid, contracted mode-by-mode with ``lax.dot_general`` and transformed back.

    Args:
        x: The input array, ``(batch, in_features, n_sites)``. 1D and 2D
            inputs are still accepted (deprecated) and are expanded to 3D with
            a warning.

    Returns:
        An array of shape ``(batch, self.features, -1)``. If the complex FFT
        result cannot be cast to the computation dtype, only its real part is
        returned.
    """
    # computation dtype: promotion of the input dtype and the module dtype
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)
    # TODO: Deprecated: Eventually remove and error if less than 3 dimensions
    # infer in_features and ensure input dimensions (batch, in_features,n_sites)
    if x.ndim < 3:
        old_shape = x.shape
        if x.ndim == 1:
            # (n_sites,) -> (1, 1, n_sites)
            x = jnp.expand_dims(x, (0, 1))
        elif x.ndim == 2:
            # (batch, n_sites) -> (batch, 1, n_sites)
            x = jnp.expand_dims(x, 1)
        symm_input_warning(old_shape, x.shape, "DenseSymm")

    in_features = x.shape[1]

    # (batch, in_features, n_cells * sites_per_cell)
    #   -> (batch, in_features, sites_per_cell, *self.shape)
    x = x.reshape(*x.shape[:-1], self.n_cells, self.sites_per_cell)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:-1], *self.shape)

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_cells * self.sites_per_cell),
        self.dtype,
    )

    kernel = jnp.asarray(kernel, dtype)

    if self.mask is not None:
        # the mask is broadcast over the two feature axes
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    # Converts the convolutional kernel of shape (features, in_features, n_sites)
    # to the expanded kernel of shape (features, in_features, sites_per_cell,
    # n_point, *shape) used in FFT-based group convolutions.
    kernel = kernel[..., self.mapping]

    # FFT over the last len(self.shape) axes (the cell grid), then flatten the
    # grid into one n_cells axis; x.shape[:3] still refers to the pre-FFT x
    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(*kernel.shape[:4], self.n_cells)

    # contract (in_features, sites_per_cell) and batch over the Fourier modes;
    # dot_general puts the batch axis first:
    # (n_cells, batch, features, n_point)
    # TODO: the batch ordering should be revised: batch dimensions should
    # be leading
    x = lax.dot_general(x, kernel, (((1, 2), (1, 2)), ((3, ), (4, ))), precision=self.precision)
    # move the Fourier-mode axis last and restore the grid shape
    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)

    # inverse FFT, then flatten (n_cells, n_point) into one trailing axis
    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    x = x.transpose(0, 1, 3, 2).reshape(*x.shape[:2], -1)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features, ), self.dtype)
        bias = jnp.asarray(bias, dtype)
        # one bias per output feature, broadcast over batch and group axes
        x += jnp.expand_dims(bias, (0, 2))

    # keep the complex result only if it is compatible with the computation
    # dtype, otherwise drop the imaginary part
    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real