Example #1
def _multiply_operators(
    hilbert, support_A: Tuple, A: Array, support_B: Tuple, B: Array, *, dtype
) -> Tuple[Tuple, Array]:
    """
    Returns the `Tuple[acting_on, Matrix]` representing the operator obtained by
    multiplying the two input operators A and B.
    """
    support_A = np.asarray(support_A)
    support_B = np.asarray(support_B)

    inters = np.intersect1d(support_A, support_B, return_indices=False)

    if support_A.size == support_B.size and np.array_equal(support_A, support_B):
        return tuple(support_A), A @ B
    elif inters.size == 0:
        # disjoint supports
        support = tuple(np.concatenate([support_A, support_B]))
        operator = np.kron(A, B)
        operator, support = _reorder_kronecker_product(hilbert, operator, support)
        return tuple(support), operator
    else:
        _support_A = list(support_A)
        _support_B = list(support_B)
        _A = A.copy()
        _B = B.copy()

        # pad _B with identities so that it also acts on every site in support_A
        supp_B_min = min(support_B)
        for site in support_A:
            if site not in support_B:
                I = np.eye(hilbert.shape[site], dtype=dtype)
                if site < supp_B_min:
                    _support_B = [site] + _support_B
                    _B = np.kron(I, _B)
                else:  # site > min(support_B); ordering is fixed by the reorder below
                    _support_B = _support_B + [site]
                    _B = np.kron(_B, I)

        supp_A_min = min(support_A)
        for site in support_B:
            if site not in support_A:
                I = np.eye(hilbert.shape[site], dtype=dtype)
                if site < supp_A_min:
                    _support_A = [site] + _support_A
                    _A = np.kron(I, _A)
                else:  # site > min(support_A); ordering is fixed by the reorder below
                    _support_A = _support_A + [site]
                    _A = np.kron(_A, I)

        # reorder
        _A, _support_A = _reorder_kronecker_product(hilbert, _A, _support_A)
        _B, _support_B = _reorder_kronecker_product(hilbert, _B, _support_B)

        if len(_support_A) == len(_support_B) and np.array_equal(
            _support_A, _support_B
        ):
            # after padding, both operators act on the same support, so we are
            # back to the first case (identical supports)
            return tuple(_support_A), _A @ _B
        else:
            raise ValueError(
                "Failed to bring the two operators onto a common support."
            )
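
The disjoint-support branch above builds the product operator as a Kronecker product over the joint support. A minimal, self-contained sketch of that idea in plain NumPy, using two single-site Pauli matrices on a hypothetical chain of 2-level sites; it deliberately bypasses the hilbert object and the _reorder_kronecker_product helper used in the function itself:

# Two operators acting on disjoint sites combine into a Kronecker product
# on the concatenated support (illustration only).
import numpy as np

sigma_x = np.array([[0.0, 1.0], [1.0, 0.0]])
sigma_z = np.array([[1.0, 0.0], [0.0, -1.0]])

support_A, A = (0,), sigma_x  # acts on site 0
support_B, B = (2,), sigma_z  # acts on site 2, disjoint from site 0

support = support_A + support_B  # (0, 2)
operator = np.kron(A, B)         # 4x4 matrix on the joint support

assert operator.shape == (4, 4)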
Example #2
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """

        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        x = x.reshape(-1, self.n_cells, self.sites_per_cell)
        x = x.transpose(0, 2, 1)
        x = x.reshape(-1, self.sites_per_cell, *self.shape)

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, self.n_cells * self.sites_per_cell),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.mask, 0)

        kernel = self.make_kernel(kernel)

        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:2], self.n_cells)

        kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
            *kernel.shape[:3], self.n_cells
        )

        x = lax.dot_general(x,
                            kernel, (((1, ), (2, )), ((2, ), (3, ))),
                            precision=self.precision)
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2).reshape(*x.shape[:2], -1)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x.dtype, dtype):
            return x
        else:
            return x.real
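
The fftn/ifftn pair above relies on the convolution theorem: a convolution over the translation group (a circular convolution) becomes an element-wise product in Fourier space. A minimal 1D sketch of that equivalence, unrelated to the module's actual shapes:

# Circular convolution computed directly and via FFT gives the same result.
import numpy as np

L = 6
x = np.random.default_rng(0).normal(size=L)
k = np.random.default_rng(1).normal(size=L)

# direct circular convolution: y[n] = sum_m x[m] * k[(n - m) mod L]
direct = np.array([sum(x[m] * k[(n - m) % L] for m in range(L)) for n in range(L)])

# same quantity via the convolution theorem
fft_based = np.fft.ifft(np.fft.fft(x) * np.fft.fft(k)).real

assert np.allclose(direct, fft_based)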
Example #3
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """
        in_features = x.shape[-2]

        x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:-1], *self.shape)

        if self.use_bias:
            bias = self.param(
                "bias", self.bias_init, (self.features,), self.param_dtype
            )
        else:
            bias = None

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_point * self.n_cells),
            self.param_dtype,
        )

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        x, kernel, bias = promote_dtype(x, kernel, bias, dtype=None)
        dtype = x.dtype

        # Convert the convolutional kernel of shape (features, in_features, n_symm)
        # to the expanded kernel of shape (features, in_features, n_point(in),
        # n_point(out), *shape) used in FFT-based group convolutions
        kernel = kernel[..., self.mapping]

        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

        kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
            *kernel.shape[:4], self.n_cells
        )

        x = lax.dot_general(
            x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
        )
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:2], -1)

        if self.use_bias:
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x.dtype, dtype):
            return x
        else:
            return x.real
Example #4
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last dimension.
        Args:
          x: The nd-array to be transformed.
        Returns:
          The transformed input.
        """
        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        x = x.reshape(-1, x.shape[1] * x.shape[2])

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.out_features, self.in_features, self.n_symm),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.mask, (0, 1))

        kernel = self.full_kernel(kernel)
        kernel = jnp.asarray(kernel, dtype)

        x = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 1, ), (0, )), ((), ())),
            precision=self.precision,
        )

        x = x.reshape(-1, self.out_features, self.n_symm)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.out_features, ),
                              self.dtype)
            bias = jnp.asarray(self.full_bias(bias), dtype)
            x += jnp.expand_dims(bias, (0, 2))

        return x
Example #5
    def __call__(self, inputs: Array) -> Array:
        """
        Applies a masked linear transformation to the inputs.

        Args:
          inputs: input data with dimensions (batch, length, features).

        Returns:
          The transformed data.
        """
        if inputs.ndim == 2:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)
        else:
            is_single_input = False

        batch, size, in_features = inputs.shape
        inputs = inputs.reshape((batch, size * in_features))

        if self.use_bias:
            bias = self.param(
                "bias", self.bias_init, (size, self.features), self.param_dtype
            )
        else:
            bias = None

        mask = jnp.ones((size, size), dtype=self.param_dtype)
        mask = jnp.triu(mask, self.exclusive)
        mask = jnp.kron(
            mask, jnp.ones((in_features, self.features), dtype=self.param_dtype)
        )

        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, mask),
            (size * in_features, size * self.features),
            self.param_dtype,
        )

        inputs, mask, kernel, bias = promote_dtype(
            inputs, mask, kernel, bias, dtype=None
        )

        y = lax.dot(inputs, mask * kernel, precision=self.precision)

        y = y.reshape((batch, size, self.features))

        if is_single_input:
            y = y.squeeze(axis=0)

        if self.use_bias:
            y = y + bias

        return y
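
The triu/kron construction above is what makes the layer autoregressive: with exclusive=True, output site i only receives contributions from input sites j < i. A small self-contained sketch with toy sizes (the variable names are illustrative, not the module's attributes):

# Site-level causal mask, broadcast over feature blocks with a Kronecker product.
import jax.numpy as jnp

size, in_features, features, exclusive = 4, 2, 3, True

mask = jnp.ones((size, size))
mask = jnp.triu(mask, int(exclusive))  # row j (input site) feeds column i (output site) only if i > j
mask = jnp.kron(mask, jnp.ones((in_features, features)))

# rows index (input site, input feature); columns index (output site, output feature)
assert mask.shape == (size * in_features, size * features)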
Example #6
def _to_int_vector(v: Array) -> str:
    try:
        v = __to_int_vector(v)
        return f"[{v[0]},{v[1]},{v[2]}]"
    except ValueError:
        # in hexagonal symmetry, the y coordinate often carries a factor of √3
        try:
            w = v.copy()
            w[1] /= 3**0.5
            w = __to_int_vector(w)
            return f"[{w[0]},{w[1]}√3,{w[2]}]"
        except ValueError:
            # just return a normalised v
            v = v / np.linalg.norm(v)
            return f"[{v[0]:.3f},{v[1]:.3f},{v[2]:.3f}]"
Example #7
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last dimension.
        Args:
          x: The nd-array to be transformed.
        Returns:
          The transformed input.
        """
        in_features = x.shape[-2]

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_symm),
            self.param_dtype,
        )

        if self.use_bias:
            bias = self.param(
                "bias", self.bias_init, (self.features,), self.param_dtype
            )
        else:
            bias = None

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        kernel, bias, x = promote_dtype(kernel, bias, x, dtype=None)

        # Converts the convolutional kernel of shape (features, in_features, n_symm)
        # to a full dense kernel of shape (features, in_features, n_symm, n_symm)
        # result[out, in, g, h] == kernel[out, in, g^{-1}h]
        # input dimensions are [in, g], output dimensions are [out, h]
        kernel = jnp.take(kernel, jnp.asarray(self.product_table), 2)

        x = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 2, x.ndim - 1), (1, 2)), ((), ())),
            precision=self.precision,
        )

        x = x.reshape(-1, self.features, self.n_symm)

        if self.use_bias:
            x += jnp.expand_dims(bias, 1)

        return x
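
The jnp.take call above expands the kernel through the group's product table, so that dense[out, in, g, h] == kernel[out, in, g⁻¹h]. A self-contained sketch for the cyclic group Z_n, where g⁻¹h = (h − g) mod n:

# Expand a (features, in_features, n_symm) kernel into the dense
# (features, in_features, n_symm, n_symm) kernel of a group convolution.
import jax.numpy as jnp

n = 4
g = jnp.arange(n)
product_table = (g[None, :] - g[:, None]) % n  # entry [g, h] = index of g⁻¹h = (h - g) mod n

kernel = jnp.arange(2 * 3 * n).reshape(2, 3, n)  # features=2, in_features=3, n_symm=n
dense = jnp.take(kernel, product_table, axis=2)  # shape (2, 3, n, n)

assert dense[1, 2, 1, 3] == kernel[1, 2, (3 - 1) % n]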
Example #8
    def __call__(self, x: Array) -> Array:
        """Applies the symmetrized linear transformation to the inputs along the last dimension.

        Args:
          x: The nd-array to be transformed.

        Returns:
          The transformed input.
        """
        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        kernel = self.param("kernel", self.kernel_init,
                            (self.features, self.n_sites), self.dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.mask, 0)

        kernel = self.full_kernel(kernel).reshape(-1, self.features,
                                                  self.n_symm)
        kernel = jnp.asarray(kernel, dtype)

        x = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 1, ), (0, )), ((), ())),
            precision=self.precision,
        )

        x = x.reshape(-1, self.features, self.n_symm)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(self.full_bias(bias), dtype)
            x += bias

        return x
Example #9
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    chunk_size: Optional[int] = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩
    divided by √n

    In a somewhat opaque way, this also internally splits all parameters into real
    parts in the 'real' and 'complex' modes (covering C→R, R&C→R, R&C→C and general C→C),
    so the resulting ΔOⱼₖ is only compatible with pytree vectors that have been
    split to real in the same way.

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used (default: True)
        chunk_size: an int specifying the size of the chunks in which the gradient is computed (default: None)

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at the samples σ, divided by √n;
            None
        else:
            the same pytree, but the entries for each parameter normalised to unit norm;
            pytree containing the norms that were divided out (same shape as params)

    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        jacobian_fun = dense_jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        # centered_jacobian_fun = compose(stack_jacobian, centered_jacobian_cplx)

        # avoid converting to complex and then back
        # by passing around the oks as a tuple of two pytrees representing the real and imag parts
        jacobian_fun = dense_jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        jacobian_fun = dense_jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or "holomorphic", got {}'.format(
                mode
            )
        )

    # Stored as contiguous real stacked on top of contiguous imaginary (SOA)
    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_reim(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    def gradf_fun(params, σ):
        gradf_dense = jacobian_fun(f, params, σ)
        return gradf_dense

    jacobians = nkjax.vmap_chunked(gradf_fun, in_axes=(None, 0), chunk_size=chunk_size)(
        params, samples
    )

    n_samp = samples.shape[0] * mpi.n_nodes
    centered_oks = subtract_mean(jacobians, axis=0) / np.sqrt(
        n_samp, dtype=jacobians.dtype
    )

    centered_oks = centered_oks.reshape(-1, centered_oks.shape[-1])

    if rescale_shift:
        return _rescale(centered_oks)
    else:
        return centered_oks, None
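
Stripped of the pytree, MPI and chunking machinery, the quantity returned above is the per-sample log-derivative matrix, centered over the samples and divided by √n. A minimal sketch for a toy real log-amplitude (a hypothetical model, not the Ansatz interface used above):

# ΔO_jk = ∂/∂p_k ln ψ(σ_j) - mean_j[...], divided by √n_samples.
import jax
import jax.numpy as jnp

def log_psi(p, sigma):
    return jnp.dot(p, sigma)

params = jnp.array([0.3, -0.7, 1.1])
samples = jnp.array([[1.0, -1.0, 1.0],
                     [1.0, 1.0, -1.0],
                     [-1.0, 1.0, 1.0],
                     [-1.0, -1.0, -1.0]])

oks = jax.vmap(jax.grad(log_psi), in_axes=(None, 0))(params, samples)  # O_jk
centered_oks = (oks - oks.mean(axis=0)) / jnp.sqrt(samples.shape[0])   # ΔO_jk / √n
print(centered_oks.shape)  # (4, 3)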
Example #10
def _reshape_inputs(model: FastARNNConv2D,
                    inputs: Array) -> Array:  # noqa: F811
    return inputs.reshape((inputs.shape[0], model.L, model.L))
Example #11
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    pdf=None,
    chunk_size: Optional[int] = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩
    divided by √n

    In a somewhat opaque way, this also internally splits all parameters into real
    parts in the 'real' and 'complex' modes (covering C→R, R&C→R, R&C→C and general C→C),
    so the resulting ΔOⱼₖ is only compatible with pytree vectors that have been
    split to real in the same way.

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used (default: True)
        pdf: |ψ(x)|^2 if exact optimization is being used else None
        chunk_size: an int specifying the size of the chunks the gradient should be computed in (default: None)

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at the samples σ, divided by √n;
            None
        else:
            the same pytree, but the entries for each parameter normalised to unit norm;
            pytree containing the norms that were divided out (same shape as params)

    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        # centered_jacobian_fun = compose(stack_jacobian, centered_jacobian_cplx)

        # avoid converting to complex and then back
        # by passing around the oks as a tuple of two pytrees representing the real and imag parts
        centered_jacobian_fun = compose(
            stack_jacobian_tuple,
            partial(centered_jacobian_cplx, _build_fn=lambda *x: x),
        )
        jacobian_fun = jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or "holomorphic", got {}'
            .format(mode))

    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_real(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    if pdf is None:
        centered_oks = _divide_by_sqrt_n_samp(
            centered_jacobian_fun(
                f,
                params,
                samples,
                chunk_size=chunk_size,
            ),
            samples,
        )
    else:
        oks = jacobian_fun(f, params, samples)
        oks_mean = jax.tree_map(partial(sum, axis=0),
                                _multiply_by_pdf(oks, pdf))
        centered_oks = jax.tree_map(lambda x, y: x - y, oks, oks_mean)

        centered_oks = _multiply_by_pdf(centered_oks, jnp.sqrt(pdf))
    if rescale_shift:
        return _rescale(centered_oks)
    else:
        return centered_oks, None
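
In the pdf branch above (used for exact optimisation), the mean ⟨Oₖ⟩ is an exact expectation over the full distribution rather than a sample average, and each centered row is weighted by √pdf. A dense, self-contained sketch of that weighting, bypassing the pytree helpers:

# pdf-weighted centering: ⟨O_k⟩ = Σ_x pdf(x) O_xk, rows weighted by √pdf(x),
# so centered_oks.T @ centered_oks is the pdf-weighted covariance of O (all real here).
import jax.numpy as jnp

oks = jnp.array([[1.0, 2.0],
                 [3.0, -1.0],
                 [0.0, 4.0]])        # O_xk over all configurations x
pdf = jnp.array([0.5, 0.3, 0.2])     # |ψ(x)|², normalised to 1

oks_mean = (pdf[:, None] * oks).sum(axis=0)
centered_oks = (oks - oks_mean) * jnp.sqrt(pdf)[:, None]

print(centered_oks.T @ centered_oks)  # (2, 2) covariance-like matrix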
Example #12
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """

        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        # TODO: Deprecated: Eventually remove and error if less than 3 dimensions
        # infer in_features and ensure input dimensions (batch, in_features,n_sites)
        if x.ndim < 3:
            old_shape = x.shape
            if x.ndim == 1:
                x = jnp.expand_dims(x, (0, 1))
            elif x.ndim == 2:
                x = jnp.expand_dims(x, 1)
            symm_input_warning(old_shape, x.shape, "DenseSymm")

        in_features = x.shape[1]

        x = x.reshape(*x.shape[:-1], self.n_cells, self.sites_per_cell)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:-1], *self.shape)

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_cells * self.sites_per_cell),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        # Converts the convolutional kernel of shape (features, in_features, n_sites)
        # to the expanded kernel of shape (features, in_features, sites_per_cell,
        # n_point, *shape) used in FFT-based group convolutions.
        kernel = kernel[..., self.mapping]

        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

        kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
            *kernel.shape[:4], self.n_cells
        )

        # TODO: the batch ordering should be revised: batch dimensions should
        # be leading
        x = lax.dot_general(x,
                            kernel, (((1, 2), (1, 2)), ((3, ), (4, ))),
                            precision=self.precision)
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2).reshape(*x.shape[:2], -1)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x.dtype, dtype):
            return x
        else:
            return x.real
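
The final can_cast check returns x.real when the promoted dtype is real; in that case the imaginary part left over by the fftn/ifftn round trip is pure floating-point noise. A tiny NumPy illustration of why discarding it is safe:

# For real inputs and a real kernel, the imaginary residue of the FFT round trip
# is at the level of machine precision.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 8))
k = rng.normal(size=8)

y = np.fft.ifft(np.fft.fft(x, axis=-1) * np.fft.fft(k), axis=-1)
print(np.abs(y.imag).max())  # ~1e-16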