def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements).

    Args:
        x: input array whose last axis enumerates group elements and whose
           second-to-last axis enumerates input features.

    Returns:
        The transformed array, with ``self.features`` output features; cast
        back to real if the input/parameter dtype was real.
    """
    in_features = x.shape[-2]
    # Split the flattened group axis into (translation cell, point-group
    # element), move the point-group axis before the cell axis, then expand
    # the cell axis into the lattice dimensions for the FFT.
    x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:-1], *self.shape)
    if self.use_bias:
        bias = self.param(
            "bias", self.bias_init, (self.features,), self.param_dtype
        )
    else:
        bias = None
    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_point * self.n_cells),
        self.param_dtype,
    )
    if self.mask is not None:
        # Broadcast the mask over the (features, in_features) leading axes.
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))
    # Bring input, kernel and bias to a common computation dtype.
    x, kernel, bias = promote_dtype(x, kernel, bias, dtype=None)
    # Remember the pre-FFT dtype so a real input can be returned as real.
    dtype = x.dtype
    # Convert the convolutional kernel of shape (features, in_features, n_symm)
    # to the expanded kernel of shape (features, in_features, n_point(in),
    # n_point(out), *shape) used in FFT-based group convolutions
    kernel = kernel[..., self.mapping]
    # Forward FFT over the lattice axes; re-flatten them into one cell axis.
    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
        *kernel.shape[:4], self.n_cells
    )
    # Contract (in_features, n_point) and batch over the cell (frequency) axis.
    x = lax.dot_general(
        x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
    )
    # Move the batch axis back to the front, undo the FFT flattening.
    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)
    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    # Restore the (cell, point) ordering and flatten back to one group axis.
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:2], -1)
    if self.use_bias:
        # One bias per output feature, broadcast over batch and group axes.
        x += jnp.expand_dims(bias, (0, 2))
    # The iFFT yields a complex array; drop the (numerically zero) imaginary
    # part when the caller's dtype was real.
    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements).

    Args:
        x: input array whose last axis enumerates group elements and whose
           second-to-last axis enumerates input features.

    Returns:
        The transformed array with ``self.out_features`` features; cast back
        to real if the computation dtype was real.
    """
    # Promote input to the common dtype of input and parameters.
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)
    # Split the flattened group axis into (translation cell, point-group
    # element), move the point-group axis forward, and expand the cell axis
    # into the lattice dimensions for the FFT.
    x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:-1], *self.shape)
    kernel = self.param(
        "kernel",
        self.kernel_init,
        (
            self.out_features,
            self.in_features,
            self.n_point * self.n_cells,
        ),
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)
    if self.mask is not None:
        # Broadcast the mask over the (out_features, in_features) axes.
        kernel = kernel * jnp.expand_dims(self.mask, (0, 1))
    # Expand the flat kernel into the form used by the FFT group convolution.
    kernel = self.make_kernel(kernel)
    # Forward FFT over the lattice axes; re-flatten them into one cell axis.
    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(*kernel.shape[:4], self.n_cells)
    # Contract (in_features, n_point) and batch over the cell (frequency) axis.
    x = lax.dot_general(x, kernel, (((1, 2), (1, 2)), ((3, ), (4, ))), precision=self.precision)
    # Move the batch axis back to the front and invert the FFT flattening.
    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)
    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    # Restore the (cell, point) ordering and flatten back to one group axis.
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:2], -1)
    if self.use_bias:
        # One bias per output feature, broadcast over batch and group axes.
        bias = self.param("bias", self.bias_init, (self.out_features, ), self.dtype)
        bias = jnp.asarray(bias, dtype)
        x += jnp.expand_dims(bias, (0, 2))
    # The iFFT yields a complex array; return the real part when the
    # computation dtype was real.
    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last dimension.

    Args:
        x: The nd-array to be transformed.

    Returns:
        The transformed input.
    """
    # Work in the common dtype of the input and the layer parameters.
    calc_dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, calc_dtype)

    # Collapse (in_features, n_symm) into a single axis to be contracted.
    x = x.reshape(-1, x.shape[1] * x.shape[2])

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.out_features, self.in_features, self.n_symm),
        self.dtype,
    )
    kernel = jnp.asarray(kernel, calc_dtype)
    if self.mask is not None:
        # Broadcast the mask over the (out_features, in_features) axes.
        kernel = kernel * jnp.expand_dims(self.mask, (0, 1))
    # Expand the reduced kernel to the full symmetrized kernel.
    kernel = jnp.asarray(self.full_kernel(kernel), calc_dtype)

    # Contract the flattened input axis with the kernel's leading axis.
    x = lax.dot_general(
        x,
        kernel,
        (((x.ndim - 1, ), (0, )), ((), ())),
        precision=self.precision,
    )
    x = x.reshape(-1, self.out_features, self.n_symm)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.out_features, ), self.dtype)
        # One bias per output feature, broadcast over batch and symmetry axes.
        full_bias = jnp.asarray(self.full_bias(bias), calc_dtype)
        x += jnp.expand_dims(full_bias, (0, 2))
    return x
def __call__(self, inputs: Array) -> Array:
    """
    Applies a masked linear transformation to the inputs.

    Args:
      inputs: input data with dimensions (batch, length, features).

    Returns:
      The transformed data.
    """
    # Promote a single unbatched sample to a batch of one; remember to
    # squeeze the batch axis back off at the end.
    squeeze_batch = inputs.ndim == 2
    if squeeze_batch:
        inputs = jnp.expand_dims(inputs, axis=0)

    batch, size, in_features = inputs.shape
    inputs = inputs.reshape((batch, size * in_features))

    bias = (
        self.param("bias", self.bias_init, (size, self.features), self.param_dtype)
        if self.use_bias
        else None
    )

    # Autoregressive site mask: upper-triangular over (site_in, site_out),
    # expanded over the feature dimensions via a Kronecker product.
    site_mask = jnp.triu(jnp.ones((size, size), dtype=self.param_dtype), self.exclusive)
    mask = jnp.kron(
        site_mask, jnp.ones((in_features, self.features), dtype=self.param_dtype)
    )

    kernel = self.param(
        "kernel",
        wrap_kernel_init(self.kernel_init, mask),
        (size * in_features, size * self.features),
        self.param_dtype,
    )

    # Bring everything to a common computation dtype.
    inputs, mask, kernel, bias = promote_dtype(
        inputs, mask, kernel, bias, dtype=None
    )

    # Zero out the disallowed connections before the matrix product.
    y = lax.dot(inputs, mask * kernel, precision=self.precision)
    y = y.reshape((batch, size, self.features))
    if squeeze_batch:
        y = y.squeeze(axis=0)
    if self.use_bias:
        y = y + bias
    return y
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last dimension.

    Args:
        x: The nd-array to be transformed; its last two axes are
           (in_features, n_symm).

    Returns:
        The transformed input of shape (batch, features, n_symm).
    """
    in_features = x.shape[-2]
    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_symm),
        self.param_dtype,
    )
    if self.use_bias:
        bias = self.param(
            "bias", self.bias_init, (self.features,), self.param_dtype
        )
    else:
        bias = None
    if self.mask is not None:
        # Broadcast the mask over the (features, in_features) leading axes.
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))
    # Bring kernel, bias and input to a common computation dtype.
    kernel, bias, x = promote_dtype(kernel, bias, x, dtype=None)
    # Converts the convolutional kernel of shape (features, in_features, n_symm)
    # to a full dense kernel of shape (features, in_features, n_symm, n_symm)
    # result[out, in, g, h] == kernel[out, in, g^{-1}h]
    # input dimensions are [in, g], output dimensions are [out, h]
    kernel = jnp.take(kernel, jnp.asarray(self.product_table), 2)
    # Contract the input's (in_features, group) axes with the kernel's
    # (in_features, g) axes, leaving (batch, features, n_symm).
    x = lax.dot_general(
        x,
        kernel,
        (((x.ndim - 2, x.ndim - 1), (1, 2)), ((), ())),
        precision=self.precision,
    )
    x = x.reshape(-1, self.features, self.n_symm)

    if self.use_bias:
        # One bias per output feature, broadcast over the symmetry axis.
        x += jnp.expand_dims(bias, 1)
    return x
def __call__(self, x: Array) -> Array:
    """Applies the symmetrized linear transformation to the inputs along the
    last dimension.

    Args:
        x: The nd-array to be transformed.

    Returns:
        The transformed input.
    """
    # Work in the common dtype of the input and the layer parameters.
    calc_dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, calc_dtype)

    kernel = self.param(
        "kernel", self.kernel_init, (self.features, self.n_sites), self.dtype
    )
    if self.mask is not None:
        # Broadcast the mask over the features axis.
        kernel = kernel * jnp.expand_dims(self.mask, 0)

    # Expand the reduced kernel to one column per symmetry operation.
    symmetrized = self.full_kernel(kernel).reshape(-1, self.features, self.n_symm)
    symmetrized = jnp.asarray(symmetrized, calc_dtype)

    # Contract the input's site axis with the kernel's leading axis.
    x = lax.dot_general(
        x,
        symmetrized,
        (((x.ndim - 1, ), (0, )), ((), ())),
        precision=self.precision,
    )
    x = x.reshape(-1, self.features, self.n_symm)

    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features, ), self.dtype)
        x += jnp.asarray(self.full_bias(bias), calc_dtype)
    return x
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    chunk_size: Optional[int] = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩ divided by √n

    In a somewhat intransparent way this also internally splits all parameters
    to real in the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and
    general C→C) resulting in the respective ΔOⱼₖ which is only compatible
    with split-to-real pytree vectors

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used (default: True)
        chunk_size: an int specifying the size of the chunks the gradient should
            be computed in (default: None)

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at
            the samples σ, divided by √n;
            None
        else:
            the same pytree, but the entries for each parameter normalised to
            unit norm;
            pytree containing the norms that were divided out (same shape as params)
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        jacobian_fun = dense_jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        # centered_jacobian_fun = compose(stack_jacobian, centered_jacobian_cplx)

        # avoid converting to complex and then back
        # by passing around the oks as a tuple of two pytrees representing the real and imag parts
        jacobian_fun = dense_jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        jacobian_fun = dense_jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or "holomorphic", got {}'.format(
                mode
            )
        )

    # Stored as contiguous real stacked on top of contiguous imaginary (SOA)
    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_reim(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    # Dense per-sample jacobian of the (possibly split-to-real) forward pass.
    def gradf_fun(params, σ):
        gradf_dense = jacobian_fun(f, params, σ)
        return gradf_dense

    # vmap over the samples, in chunks of chunk_size to bound peak memory.
    jacobians = nkjax.vmap_chunked(gradf_fun, in_axes=(None, 0), chunk_size=chunk_size)(
        params, samples
    )

    # Total sample count across all MPI ranks.
    n_samp = samples.shape[0] * mpi.n_nodes
    # Center the jacobian (subtract the mean row) and normalise by √n.
    centered_oks = subtract_mean(jacobians, axis=0) / np.sqrt(
        n_samp, dtype=jacobians.dtype
    )

    centered_oks = centered_oks.reshape(-1, centered_oks.shape[-1])

    if rescale_shift:
        return _rescale(centered_oks)
    else:
        return centered_oks, None
def _reshape_inputs(model: FastARNNConv2D, inputs: Array) -> Array:  # noqa: F811
    """Reshape flattened configurations into a square (batch, L, L) lattice."""
    batch_size = inputs.shape[0]
    side = model.L
    return inputs.reshape((batch_size, side, side))
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    pdf=None,
    chunk_size: int = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩ divided by √n

    In a somewhat intransparent way this also internally splits all parameters
    to real in the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and
    general C→C) resulting in the respective ΔOⱼₖ which is only compatible
    with split-to-real pytree vectors

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used (default: True)
        pdf: |ψ(x)|^2 if exact optimization is being used else None
        chunk_size: an int specifying the size of the chunks the gradient should
            be computed in (default: None)

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at
            the samples σ, divided by √n;
            None
        else:
            the same pytree, but the entries for each parameter normalised to
            unit norm;
            pytree containing the norms that were divided out (same shape as params)
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        # centered_jacobian_fun = compose(stack_jacobian, centered_jacobian_cplx)

        # avoid converting to complex and then back
        # by passing around the oks as a tuple of two pytrees representing the real and imag parts
        centered_jacobian_fun = compose(
            stack_jacobian_tuple,
            partial(centered_jacobian_cplx, _build_fn=lambda *x: x),
        )
        jacobian_fun = jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or "holomorphic", got {}'.format(
                mode
            )
        )

    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_real(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    if pdf is None:
        # Monte-Carlo estimate: center by the empirical mean and divide by √n.
        centered_oks = _divide_by_sqrt_n_samp(
            centered_jacobian_fun(
                f,
                params,
                samples,
                chunk_size=chunk_size,
            ),
            samples,
        )
    else:
        # Exact (full-summation) optimization: center by the pdf-weighted mean.
        oks = jacobian_fun(f, params, samples)
        # BUGFIX: the builtin `sum` has no `axis` keyword; the per-leaf
        # reduction over the sample axis must use `jnp.sum`.
        oks_mean = jax.tree_map(partial(jnp.sum, axis=0), _multiply_by_pdf(oks, pdf))
        centered_oks = jax.tree_map(lambda x, y: x - y, oks, oks_mean)

        # Weight each centered row by √pdf so that O†O gives the pdf-weighted
        # covariance.
        centered_oks = _multiply_by_pdf(centered_oks, jnp.sqrt(pdf))
    if rescale_shift:
        return _rescale(centered_oks)
    else:
        return centered_oks, None
def __call__(self, x: Array) -> Array:
    """Applies the equivariant transform to the inputs along the last two
    dimensions (-2: features, -1: group elements).

    Args:
        x: input array of shape (batch, in_features, n_sites); 1D/2D inputs
           are accepted for backward compatibility but trigger a warning.

    Returns:
        The transformed array with ``self.features`` output features; cast
        back to real if the computation dtype was real.
    """
    # Promote input to the common dtype of input and parameters.
    dtype = jnp.promote_types(x.dtype, self.dtype)
    x = jnp.asarray(x, dtype)

    # TODO: Deprecated: Eventually remove and error if less than 3 dimensions
    # infer in_features and ensure input dimensions (batch, in_features,n_sites)
    if x.ndim < 3:
        old_shape = x.shape
        if x.ndim == 1:
            x = jnp.expand_dims(x, (0, 1))
        elif x.ndim == 2:
            x = jnp.expand_dims(x, 1)
        symm_input_warning(old_shape, x.shape, "DenseSymm")

    in_features = x.shape[1]

    # Split the site axis into (translation cell, site within cell), move the
    # within-cell axis forward, and expand the cell axis into the lattice
    # dimensions for the FFT.
    x = x.reshape(*x.shape[:-1], self.n_cells, self.sites_per_cell)
    x = x.transpose(0, 1, 3, 2)
    x = x.reshape(*x.shape[:-1], *self.shape)

    kernel = self.param(
        "kernel",
        self.kernel_init,
        (self.features, in_features, self.n_cells * self.sites_per_cell),
        self.dtype,
    )
    kernel = jnp.asarray(kernel, dtype)

    if self.mask is not None:
        # Broadcast the mask over the (features, in_features) leading axes.
        kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

    # Converts the convolutional kernel of shape (features, in_features, n_sites)
    # to the expanded kernel of shape (features, in_features, sites_per_cell,
    # n_point, *shape) used in FFT-based group convolutions.
    kernel = kernel[..., self.mapping]

    # Forward FFT over the lattice axes; re-flatten them into one cell axis.
    x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(*kernel.shape[:4], self.n_cells)

    # TODO: the batch ordering should be revised: batch dimensions should
    # be leading
    # Contract (in_features, sites_per_cell) and batch over the cell
    # (frequency) axis.
    x = lax.dot_general(x, kernel, (((1, 2), (1, 2)), ((3, ), (4, ))), precision=self.precision)

    # Move the batch axis back to the front and invert the FFT flattening.
    x = x.transpose(1, 2, 3, 0)
    x = x.reshape(*x.shape[:3], *self.shape)
    x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
    # Restore the (cell, point) ordering and flatten back to one group axis.
    x = x.transpose(0, 1, 3, 2).reshape(*x.shape[:2], -1)

    if self.use_bias:
        # One bias per output feature, broadcast over batch and group axes.
        bias = self.param("bias", self.bias_init, (self.features, ), self.dtype)
        bias = jnp.asarray(bias, dtype)
        x += jnp.expand_dims(bias, (0, 2))

    # The iFFT yields a complex array; return the real part when the
    # computation dtype was real.
    if jnp.can_cast(x, dtype):
        return x
    else:
        return x.real