Example #1
 def testMixedFloatInt(self):
     tree = [
         jnp.array([3], jnp.int32),
         jnp.array([[1., 2.], [3., 4.]], jnp.float32)
     ]
     raveled, unravel = flatten_util.ravel_pytree(tree)
     self.assertEqual(raveled.dtype,
                      jnp.promote_types(jnp.float32, jnp.int32))
     tree_ = unravel(raveled)
     self.assertAllClose(tree, tree_, atol=0., rtol=0.)
Example #2
 def testMixedIntBool(self):
     tree = [
         jnp.array([0], jnp.bool_),
         jnp.array([[1, 2], [3, 4]], jnp.int32)
     ]
     raveled, unravel = flatten_util.ravel_pytree(tree)
     self.assertEqual(raveled.dtype,
                      jnp.promote_types(jnp.bool_, jnp.int32))
     tree_ = unravel(raveled)
     self.assertAllClose(tree, tree_, atol=0., rtol=0.)
Example #3
 def testMixedFloatComplex(self):
     tree = [
         jnp.array([1.], jnp.float32),
         jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)
     ]
     raveled, unravel = flatten_util.ravel_pytree(tree)
     self.assertEqual(raveled.dtype,
                      jnp.promote_types(jnp.float32, jnp.complex64))
     tree_ = unravel(raveled)
     self.assertAllClose(tree, tree_, atol=0., rtol=0.)
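
The three tests above rely on standard type-promotion results; a minimal sketch of the dtype pairs involved (plain jax.numpy, outside the test class):

    import jax.numpy as jnp

    # Promotion results that the assertions above check against.
    print(jnp.promote_types(jnp.float32, jnp.int32))      # float32
    print(jnp.promote_types(jnp.bool_, jnp.int32))        # int32
    print(jnp.promote_types(jnp.float32, jnp.complex64))  # complex64
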
Example #4
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """
        in_features = x.shape[-2]

        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:-1], *self.shape)

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_point * self.n_cells),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        # Convert the convolutional kernel of shape (features, in_features, n_symm)
        # to the expanded kernel of shape (features, in_features, n_point(in),
        # n_point(out), *shape) used in FFT-based group convolutions
        kernel = kernel[..., self.mapping]

        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

        kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
            *kernel.shape[:4], self.n_cells
        )

        x = lax.dot_general(
            x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
        )
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:2], -1)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
            bias = jnp.asarray(bias, dtype)
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x, dtype):
            return x
        else:
            return x.real
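
A minimal sketch of the closing can_cast / x.real step above, outside the module (assuming real inputs and real float32 parameters, so the promoted dtype is real while the inverse FFT returns a complex array):

    import jax.numpy as jnp

    x = jnp.ones((2, 3), jnp.float32)
    dtype = jnp.promote_types(x.dtype, jnp.float32)  # parameters assumed to be float32
    y = jnp.fft.ifftn(jnp.fft.fftn(x))               # the FFT round trip yields complex64
    # complex64 cannot be cast safely to the real promoted dtype,
    # so only the real part is kept.
    y = y if jnp.can_cast(y.dtype, dtype) else y.real
    print(y.dtype)                                   # float32
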
Example #5
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x_in = jnp.asarray(x_in, dtype=dtype)

        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)
        kernel = kernel + kernel.T
        y = jnp.einsum("...i,ij,...j", x_in, kernel, x_in)

        return y
Example #6
def minkowski_distance_p(x, y, p=2):
    """
    Compute the pth power of the L**p distance between two arrays.

    For efficiency, this function computes the L**p distance but does
    not extract the pth root. If `p` is 1 or infinity, this is equal to
    the actual L**p distance.

    Parameters
    ----------
    x : (M, K) array_like
        Input array.
    y : (N, K) array_like
        Input array.
    p : float, 1 <= p <= infinity
        Which Minkowski p-norm to use.

    Examples
    --------
    >>> from scipy.spatial import minkowski_distance_p
    >>> minkowski_distance_p([[0,0],[0,0]], [[1,1],[0,1]])
    array([2., 1.])

    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Find smallest common datatype with float64 (return type of this function) - addresses #10262.
    # Don't just cast to float64 for complex input case.
    common_datatype = np.promote_types(np.promote_types(x.dtype, y.dtype), 'float64')

    # Make sure x and y are NumPy arrays of correct datatype.
    x = x.astype(common_datatype)
    y = y.astype(common_datatype)

    if p == np.inf:
        return np.amax(np.abs(y-x), axis=-1)
    elif p == 1:
        return np.sum(np.abs(y-x), axis=-1)
    else:
        return np.sum(np.abs(y-x)**p, axis=-1)
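
A minimal sketch of the promotion logic described in the comment above: the float64 floor keeps real inputs in double precision while complex inputs stay complex (see gh-10262):

    import numpy as np

    print(np.promote_types(np.promote_types(np.int32, np.float32), 'float64'))      # float64
    print(np.promote_types(np.promote_types(np.complex64, np.float32), 'float64'))  # complex128
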
Example #7
    def __call__(self, inputs: Array) -> Array:
        """Applies a convolution to the inputs.
        Args:
          inputs: input data with dimensions (batch, spatial_dims..., features).
        Returns:
          The convolved data.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        if isinstance(self.kernel_size, int):
            kernel_size = (self.kernel_size,)
        else:
            kernel_size = self.kernel_size

        is_single_input = False
        if inputs.ndim == len(kernel_size) + 1:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        strides = self.strides or (1,) * (inputs.ndim - 2)

        in_features = inputs.shape[-1]
        assert in_features % self.feature_group_count == 0
        kernel_shape = kernel_size + (
            in_features // self.feature_group_count,
            self.features,
        )
        kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
        kernel = jnp.asarray(kernel, dtype)

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
        y = lax.conv_general_dilated(
            inputs,
            kernel,
            strides,
            self.padding,
            lhs_dilation=self.input_dilation,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if is_single_input:
            y = jnp.squeeze(y, axis=0)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias
        return y
Example #8
    def __call__(self, inputs: Array) -> Array:
        """
        Applies a masked linear transformation to the inputs.

        Args:
          inputs: input data with dimensions (batch, length, features).

        Returns:
          The transformed data.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        is_single_input = False
        if inputs.ndim == 2:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        batch, size, in_features = inputs.shape
        inputs = inputs.reshape((batch, size * in_features))

        mask = jnp.ones((size, size), dtype=self.dtype)
        mask = jnp.triu(mask, self.exclusive)
        mask = jnp.kron(
            mask, jnp.ones((in_features, self.features), dtype=self.dtype))

        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, mask),
            (size * in_features, size * self.features),
            self.dtype,
        )
        mask = jnp.asarray(mask, dtype)
        kernel = jnp.asarray(kernel, dtype)

        y = lax.dot(inputs, mask * kernel, precision=self.precision)

        y = y.reshape((batch, size, self.features))

        if is_single_input:
            y = y.squeeze(axis=0)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (size, self.features),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias

        return y
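
A minimal sketch of the autoregressive mask built with triu and kron above, for tiny hypothetical sizes (output site i may only depend on input sites j <= i - exclusive):

    import jax.numpy as jnp

    size, in_features, features, exclusive = 3, 1, 2, 1
    mask = jnp.triu(jnp.ones((size, size)), exclusive)        # per-site causal mask
    mask = jnp.kron(mask, jnp.ones((in_features, features)))  # expand each entry to a feature block
    print(mask.shape)  # (size * in_features, size * features) == (3, 6)
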
Example #9
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """

        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        x = (x.reshape(-1, self.n_cells, self.sites_per_cell).transpose(
            0, 2, 1).reshape(-1, self.sites_per_cell, *self.shape))

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, self.n_cells * self.sites_per_cell),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.mask, 0)

        kernel = self.make_kernel(kernel)

        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:2], self.n_cells)

        kernel = jnp.fft.fftn(kernel,
                              s=self.shape).reshape(*kernel.shape[:3],
                                                    self.n_cells)

        x = lax.dot_general(x,
                            kernel, (((1, ), (2, )), ((2, ), (3, ))),
                            precision=self.precision)
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2).reshape(*x.shape[:2], -1)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x, dtype):
            return x
        else:
            return x.real
Example #10
    def __call__(self, inputs: Array) -> Array:
        """Applies a transposed convolution to the inputs. Behaviour mirrors of
        `jax.lax.conv_transpose`.
        Args:
          inputs: input data with dimensions (batch, spatial_dims..., features).
        Returns:
          The convolved data.
        """
        dtype = jnp.promote_types(self.dtype, inputs.dtype)

        inputs = jnp.asarray(inputs, dtype)

        if isinstance(self.kernel_size, int):
            kernel_size = (self.kernel_size, )
        else:
            kernel_size = self.kernel_size

        is_single_input = False
        if inputs.ndim == len(kernel_size) + 1:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        strides = self.strides or (1, ) * (inputs.ndim - 2)

        in_features = inputs.shape[-1]
        kernel_shape = kernel_size + (in_features, self.features)
        kernel = self.param("kernel", self.kernel_init, kernel_shape,
                            self.dtype)
        kernel = jnp.asarray(kernel, dtype)

        y = lax.conv_transpose(
            inputs,
            kernel,
            strides,
            self.padding,
            rhs_dilation=self.kernel_dilation,
            precision=self.precision,
        )

        if is_single_input:
            y = jnp.squeeze(y, axis=0)
        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias
        return y
Example #11
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last dimension.
        Args:
          x: The nd-array to be transformed.
        Returns:
          The transformed input.
        """
        in_features = x.shape[-2]

        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_symm),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        # Converts the convolutional kernel of shape (features, in_features, n_symm)
        # to a full dense kernel of shape (features, in_features, n_symm, n_symm)
        # result[out, in, g, h] == kernel[out, in, g^{-1}h]
        # input dimensions are [in, g], output dimensions are [out, h]
        kernel = jnp.take(kernel, jnp.asarray(self.product_table), 2)
        kernel = jnp.asarray(kernel, dtype)

        x = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 2, x.ndim - 1), (1, 2)), ((), ())),
            precision=self.precision,
        )

        x = x.reshape(-1, self.features, self.n_symm)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            x += jnp.expand_dims(bias, 1)

        return x
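
A minimal sketch of the jnp.take kernel expansion above, for a hypothetical cyclic group Z_3 in which product_table[g, h] = g^{-1} h = (h - g) mod 3:

    import jax.numpy as jnp

    n_symm = 3
    g = jnp.arange(n_symm)
    product_table = (g[None, :] - g[:, None]) % n_symm  # product_table[g, h] = (h - g) % 3
    kernel = jnp.arange(2 * 1 * n_symm, dtype=jnp.float32).reshape(2, 1, n_symm)
    dense = jnp.take(kernel, product_table, axis=2)     # (features, in_features, n_symm, n_symm)
    # dense[out, in, g, h] == kernel[out, in, (h - g) % n_symm]
    print(dense.shape)  # (2, 1, 3, 3)
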
Example #12
    def __init__(
        self,
        hilbert: AbstractHilbert,
        graph: AbstractGraph,
        h: float,
        J: float = 1.0,
        dtype: DType = None,
    ):
        r"""
        Constructs the Ising operator from a Hilbert space and a
        graph specifying the connectivity.

        Args:
            hilbert: Hilbert space the operator acts on.
            graph: The graph specifying the connectivity between the sites.
            h: The strength of the transverse field.
            J: The strength of the coupling. Default is 1.0.
            dtype: The dtype of the matrix elements.

        Examples:
            Constructs an ``Ising`` operator for a 1D system.

            >>> import netket as nk
            >>> g = nk.graph.Hypercube(length=20, n_dim=1, pbc=True)
            >>> hi = nk.hilbert.Spin(s=0.5, N=g.n_nodes)
            >>> op = nk.operator.Ising(h=1.321, hilbert=hi, J=0.5, graph=g)
            >>> print(op)
            Ising(J=0.5, h=1.321; dim=20)
        """
        assert (
            graph.n_nodes == hilbert.size
        ), "The size of the graph must match the hilbert space"

        super().__init__(hilbert)

        if dtype is None:
            dtype = jnp.promote_types(_dtype(h), _dtype(J))
        dtype = np.empty((), dtype=dtype).dtype
        self._dtype = dtype

        self._h = np.array(h, dtype=dtype)
        self._J = np.array(J, dtype=dtype)
        self._edges = np.asarray(
            [[u, v] for u, v in graph.edges()],
            dtype=np.intp,
        )
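
A minimal sketch of the dtype inference above; _dtype is a NetKet helper, approximated here with jnp.asarray(...).dtype:

    import numpy as np
    import jax.numpy as jnp

    h, J = 1.321, 0.5
    dtype = jnp.promote_types(jnp.asarray(h).dtype, jnp.asarray(J).dtype)
    dtype = np.empty((), dtype=dtype).dtype  # canonicalize to a concrete NumPy dtype instance
    print(dtype)  # float32 by default, float64 when jax_enable_x64 is set
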
Example #13
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """
        in_features = x.shape[-2]

        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        x = self.forward_ft(x)

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_symm),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        kernel = self.forward_ft(kernel)

        x = tuple(
            lax.dot_general(
                x[i], kernel[i], (((1, 4), (1, 3)), ((2,), (2,)))
            ).transpose(1, 3, 0, 2, 4)
            for i in range(len(x))
        )

        x = self.inverse_ft(x)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
            bias = jnp.asarray(bias, dtype)

            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x, dtype):
            return x
        else:
            return x.real
Example #14
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last dimension.
        Args:
          x: The nd-array to be transformed.
        Returns:
          The transformed input.
        """
        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        x = x.reshape(-1, x.shape[1] * x.shape[2])

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.out_features, self.in_features, self.n_symm),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.mask, (0, 1))

        kernel = self.full_kernel(kernel)
        kernel = jnp.asarray(kernel, dtype)

        x = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 1, ), (0, )), ((), ())),
            precision=self.precision,
        )

        x = x.reshape(-1, self.out_features, self.n_symm)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.out_features, ),
                              self.dtype)
            bias = jnp.asarray(self.full_bias(bias), dtype)
            x += jnp.expand_dims(bias, (0, 2))

        return x
Example #15
  def testFftn(self, inverse, real, shape, dtype, axes, s, norm):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    jnp_op = _get_fftn_func(jnp.fft, inverse, real)
    np_op = _get_fftn_func(np.fft, inverse, real)
    jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm)
    np_fn = lambda a: np_op(a, axes=axes, norm=norm) if axes is None or axes else a
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker)
    # Test gradient for differentiable types.
    if (config.x64_enabled and
        dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
      # TODO(skye): can we be more precise?
      tol = 0.15
      jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

    # check dtypes
    dtype = jnp_fn(rng(shape, dtype)).dtype
    expected_dtype = jnp.promote_types(float if inverse and real else complex, dtype)
    self.assertEqual(dtype, expected_dtype)
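
The dtype check above relies on JAX treating the Python builtins float and complex as weak types that adopt the other operand's precision; a minimal sketch:

    import jax.numpy as jnp

    print(jnp.promote_types(complex, jnp.float32))  # complex64
    print(jnp.promote_types(float, jnp.complex64))  # complex64
    print(jnp.promote_types(float, jnp.float32))    # float32 (NumPy's promote_types would give float64)
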
Example #16
    def __call__(self, x: Array) -> Array:
        """Applies the symmetrized linear transformation to the inputs along the last dimension.

        Args:
          x: The nd-array to be transformed.

        Returns:
          The transformed input.
        """
        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        kernel = self.param("kernel", self.kernel_init,
                            (self.features, self.n_sites), self.dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.mask, 0)

        kernel = self.full_kernel(kernel).reshape(-1, self.features,
                                                  self.n_symm)
        kernel = jnp.asarray(kernel, dtype)

        x = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 1, ), (0, )), ((), ())),
            precision=self.precision,
        )

        x = x.reshape(-1, self.features, self.n_symm)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(self.full_bias(bias), dtype)
            x += bias

        return x
Example #17
    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along multiple dimensions.
        Args:
          inputs: The nd-array to be transformed.
        Returns:
          The transformed input.
        """
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)
        batch_dims = _canonicalize_tuple(self.batch_dims)
        if batch_dims:
            max_dim = np.max(batch_dims)
            if set(batch_dims) != set(range(max_dim + 1)):
                raise ValueError(
                    "batch_dims %s must be consecutive leading "
                    "dimensions starting from 0." % str(batch_dims)
                )

        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        ndim = inputs.ndim
        n_batch_dims = len(batch_dims)
        axis = _normalize_axes(axis, ndim)
        batch_dims = _normalize_axes(batch_dims, ndim)
        n_axis, n_features = len(axis), len(features)

        def kernel_init_wrap(rng, shape, dtype=jnp.float64):
            size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
            flat_shape = (
                np.prod(shape[n_batch_dims : n_axis + n_batch_dims]),
                np.prod(shape[-n_features:]),
            )
            kernel = jnp.concatenate(
                [
                    self.kernel_init(rng, flat_shape, dtype)
                    for _ in range(size_batch_dims)
                ],
                axis=0,
            )
            return jnp.reshape(kernel, shape)

        batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
        kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
        kernel = self.param("kernel", kernel_init_wrap, batch_shape + kernel_shape)
        kernel = jnp.asarray(kernel, dtype)

        batch_ind = tuple(range(n_batch_dims))
        contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
        out = lax.dot_general(
            inputs,
            kernel,
            ((axis, contract_ind), (batch_dims, batch_ind)),
            precision=self.precision,
        )

        if self.use_bias:

            def bias_init_wrap(rng, shape, dtype=jnp.float64):
                size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
                flat_shape = (np.prod(shape[-n_features:]),)
                bias = jnp.concatenate(
                    [
                        self.bias_init(rng, flat_shape, dtype)
                        for _ in range(size_batch_dims)
                    ],
                    axis=0,
                )
                return jnp.reshape(bias, shape)

            bias = self.param("bias", bias_init_wrap, batch_shape + features)

            # Reshape bias for broadcast.
            expand_dims = sorted(set(range(inputs.ndim)) - set(axis) - set(batch_dims))
            for ax in expand_dims:
                bias = jnp.expand_dims(bias, ax)
            bias = jnp.asarray(bias, dtype)
            out = out + bias
        return out
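
A minimal sketch of the lax.dot_general call above, with hypothetical shapes: one leading batch dimension, one contracted axis, and one output feature axis:

    import jax.numpy as jnp
    from jax import lax

    inputs = jnp.ones((4, 5, 3))  # (batch, extra, axis)
    kernel = jnp.ones((4, 3, 7))  # (batch, axis, features)
    out = lax.dot_general(inputs, kernel, (((2,), (1,)), ((0,), (0,))))
    print(out.shape)              # (4, 5, 7): batch dims first, then the remaining lhs and rhs dims
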
Example #18
    def __call__(self, x: Array) -> Array:
        """Applies the symmetrized linear transformation to the inputs along the last dimension.

        Args:
          x: The nd-array to be transformed.

        Returns:
          The transformed input.
        """
        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)
        # infer in_features and ensure input dimensions (batch, in_features, n_sites)

        # TODO: Deprecated: Eventually remove and error if less than 3 dimensions
        if x.ndim < 3:
            old_shape = x.shape
            if x.ndim == 1:
                x = jnp.expand_dims(x, (0, 1))
            elif x.ndim == 2:
                x = jnp.expand_dims(x, 1)
            symm_input_warning(old_shape, x.shape, "DenseSymm")

        in_features = x.shape[1]

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_sites),
            self.dtype,
        )

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        # Converts the convolutional kernel of shape (self.features, in_features, n_sites)
        # to a full dense kernel of shape (self.features, in_features, n_symm, n_sites).
        # result[out, in, g, r] == kernel[out, in, g^{-1}r]
        kernel = jnp.take(kernel, jnp.asarray(self.symmetries), 2)
        kernel = jnp.asarray(kernel, dtype)

        # x is      (batches,       in_features,         n_sites)
        # kernel is (self.features, in_features, n_symm, n_sites)
        x = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 2, x.ndim - 1), (1, 3)), ((), ())),
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)

            # Convert symmetry-reduced bias of shape (features,) to the full bias of
            # shape (..., features, 1).
            bias = jnp.expand_dims(bias, 1)
            bias = jnp.asarray(bias, dtype)

            x += bias

        return x
Example #19
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """

        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        # TODO: Deprecated: Eventually remove and error if less than 3 dimensions
        # infer in_features and ensure input dimensions (batch, in_features, n_sites)
        if x.ndim < 3:
            old_shape = x.shape
            if x.ndim == 1:
                x = jnp.expand_dims(x, (0, 1))
            elif x.ndim == 2:
                x = jnp.expand_dims(x, 1)
            symm_input_warning(old_shape, x.shape, "DenseSymm")

        in_features = x.shape[1]

        x = x.reshape(*x.shape[:-1], self.n_cells, self.sites_per_cell)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:-1], *self.shape)

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_cells * self.sites_per_cell),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        # Converts the convolutional kernel of shape (features, in_features, n_sites)
        # to the expanded kernel of shape (features, in_features, sites_per_cell,
        # n_point, *shape) used in FFT-based group convolutions.
        kernel = kernel[..., self.mapping]

        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

        kernel = jnp.fft.fftn(kernel,
                              s=self.shape).reshape(*kernel.shape[:4],
                                                    self.n_cells)

        # TODO: the batch ordering should be revised: batch dimensions should
        # be leading
        x = lax.dot_general(x,
                            kernel, (((1, 2), (1, 2)), ((3, ), (4, ))),
                            precision=self.precision)
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2).reshape(*x.shape[:2], -1)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x, dtype):
            return x
        else:
            return x.real
Example #20
    def update_site(self, inputs: Array, index: int) -> Array:
        """
        Adds an input site into the cache, and applies the masked linear transformation to the cache.

        Args:
          inputs: an input site to be added into the cache with dimensions (batch, features).
          index: the index of the output site. The index of the input site should be `index - self.exclusive`.

        Returns:
          The output site with dimensions (batch, features).
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        is_single_input = False
        if inputs.ndim == 1:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        batch, in_features = inputs.shape
        size = self.size

        # Number of input sites that the output site at `index` depends on
        size_i = index + 1

        # Initialize the cache with zeros, and the RNG key is None
        # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
        _cache = self.variable("cache", "inputs", zeros, None,
                               (batch, size, in_features), inputs.dtype)

        initializing = self.is_mutable_collection("params")
        if not initializing:
            # Add the input site into the cache
            # To write the cache, use `_cache.value` on the left-hand side of the assignment
            _cache.value = lax.cond(
                index - self.exclusive >= 0,
                lambda _: _cache.value.at[:, index - self.exclusive, :].set(
                    inputs),
                lambda _: _cache.value,
                None,
            )

        cache = _cache.value
        cache = jnp.asarray(cache, dtype)

        cache_i = cache[:, :size_i, :]
        cache_i = cache_i.reshape((batch, size_i * in_features))

        # The construction of `mask` will be optimized to a constant by JIT
        mask = jnp.ones((size, size), dtype=self.dtype)
        mask = jnp.triu(mask, self.exclusive)
        mask = jnp.kron(
            mask, jnp.ones((in_features, self.features), dtype=self.dtype))

        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, mask),
            (size * in_features, size * self.features),
            self.dtype,
        )
        mask = jnp.asarray(mask, dtype)
        kernel = jnp.asarray(kernel, dtype)

        mask_i = mask.reshape((size, in_features, size, self.features))
        mask_i = mask_i[:size_i, :, index, :]
        mask_i = mask_i.reshape((size_i * in_features, self.features))

        kernel_i = kernel.reshape((size, in_features, size, self.features))
        kernel_i = kernel_i[:size_i, :, index, :]
        kernel_i = kernel_i.reshape((size_i * in_features, self.features))

        y_i = lax.dot(cache_i, mask_i * kernel_i, precision=self.precision)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (size, self.features),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)

            bias_i = bias[index, :]

            y_i = y_i + bias_i

        assert y_i.shape[1] == self.features

        if is_single_input:
            y_i = y_i.squeeze(axis=0)

        return y_i
Example #21
    def update_site(self, inputs: Array, index: int) -> Array:
        """
        Adds an input site into the cache, and applies the masked convolution to the cache.

        Args:
          inputs: an input site to be added into the cache with dimensions (batch, features).
          index: the index of the output site. The index of the input site should be `index - self.exclusive`.

        Returns:
          The next output site with dimensions (batch, features).
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        L = self.L
        index_w = index % L

        kernel_h, kernel_w = self.kernel_size
        dilation_h, dilation_w = self.kernel_dilation
        ones = (1, 1)

        is_single_input = False
        if inputs.ndim == 1:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        batch, in_features = inputs.shape
        assert in_features % self.feature_group_count == 0
        recep_h = (kernel_h - 1) * dilation_h + 1
        recep_w = (kernel_w - 1) * dilation_w + 1

        # Initialize the cache with zeros, and the RNG key is None
        # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
        _cache = self.variable(
            "cache",
            "inputs",
            zeros,
            None,
            (batch, recep_h, L, in_features),
            inputs.dtype,
        )

        initializing = self.is_mutable_collection("params")
        if not initializing:
            # Add the input site into the cache
            # To write the cache, use `_cache.value` on the left-hand side of the assignment

            inputs = jnp.expand_dims(inputs, axis=(1, 2))

            # Index of the input site in the width direction
            index_w_in = (index - self.exclusive) % L

            def _add(cache):
                # return cache.at[:, -1, index_w_in, :].set(inputs)
                return lax.dynamic_update_slice(cache, inputs,
                                                (0, -1, index_w_in, 0))

            def _shift(cache):
                return jnp.concatenate(
                    [
                        cache[:, 1:, :, :],
                        jnp.zeros(
                            (batch, 1, L, in_features), dtype=inputs.dtype),
                    ],
                    axis=1,
                )

            cache_new_row = lax.cond(
                index_w_in == 0,
                lambda _: _add(_shift(_cache.value)),
                lambda _: _shift(_add(_cache.value)),
                None,
            )

            cache_new = lax.cond(
                index_w == 0,
                lambda _: cache_new_row,
                lambda _: _add(_cache.value),
                None,
            )

            _cache.value = lax.cond(
                index - self.exclusive >= 0,
                lambda _: cache_new,
                lambda _: _cache.value,
                None,
            )

        cache = _cache.value
        cache = jnp.asarray(cache, dtype)

        kernel_shape = self.kernel_size + (
            in_features // self.feature_group_count,
            self.features,
        )
        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, self.mask),
            kernel_shape,
            self.dtype,
        )
        kernel = jnp.asarray(kernel, dtype)

        # Zero padding
        cache = jnp.pad(
            cache,
            (
                (0, 0),
                (0, 0),
                (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
                (0, 0),
            ),
        )

        # cache = cache[:, :, index_w : index_w + recep_w, :]
        cache = lax.dynamic_slice(cache, (0, 0, index_w, 0),
                                  (batch, recep_h, recep_w, in_features))

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(
            cache.shape)
        y_i = lax.conv_general_dilated(
            cache,
            kernel,
            window_strides=ones,
            padding="VALID",
            lhs_dilation=ones,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            y_i = y_i + bias

        y_i = y_i.squeeze(axis=(1, 2))

        if is_single_input:
            y_i = y_i.squeeze(axis=0)

        return y_i
Example #22
    def update_site(self, inputs: Array, index: int) -> Array:
        """
        Adds an input site into the cache, and applies the masked convolution to the cache.

        Args:
          inputs: an input site to be added into the cache with dimensions (batch, features).
          index: the index of the output site. The index of the input site should be `index - self.exclusive`.

        Returns:
          The next output site with dimensions (batch, features).
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        kernel_size = self.kernel_size - self.exclusive
        dilation = self.kernel_dilation

        is_single_input = False
        if inputs.ndim == 1:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        batch, in_features = inputs.shape
        assert in_features % self.feature_group_count == 0
        cache_size = kernel_size * dilation - (not self.exclusive) * (
            dilation - 1)

        # Initialize the cache with zeros, and the RNG key is None
        # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
        _cache = self.variable(
            "cache",
            "inputs",
            zeros,
            None,
            (batch, cache_size, in_features),
            inputs.dtype,
        )

        initializing = self.is_mutable_collection("params")
        if not initializing:
            # Add the input site into the cache
            # To write the cache, use `_cache.value` on the left-hand side of the assignment
            _cache.value = lax.cond(
                index - self.exclusive >= 0,
                lambda _: jnp.concatenate(
                    [_cache.value[:, 1:, :],
                     jnp.expand_dims(inputs, axis=1)],
                    axis=1),
                lambda _: _cache.value,
                None,
            )

        cache = _cache.value
        cache = jnp.asarray(cache, dtype)

        kernel_shape = (
            kernel_size,
            in_features // self.feature_group_count,
            self.features,
        )
        kernel = self.param("kernel", self.kernel_init, kernel_shape,
                            self.dtype)
        kernel = jnp.asarray(kernel, dtype)

        if self.exclusive and dilation > 1:
            cache = cache[:, :-(dilation - 1), :]

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(
            cache.shape)
        y_i = lax.conv_general_dilated(
            cache,
            kernel,
            window_strides=(1, ),
            padding="VALID",
            lhs_dilation=(1, ),
            rhs_dilation=(dilation, ),
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            y_i = y_i + bias

        y_i = y_i.squeeze(axis=1)

        if is_single_input:
            y_i = y_i.squeeze(axis=0)

        return y_i
Example #23
    def __call__(self, inputs: Array) -> Array:
        """
        Applies a masked convolution to the inputs.
        For a 1D convolution there is no explicit mask; it is enough to apply
        the appropriate padding.

        Args:
          inputs: input data with dimensions (batch, length, features).

        Returns:
          The convolved data.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        kernel_size = self.kernel_size - self.exclusive
        dilation = self.kernel_dilation

        is_single_input = False
        if inputs.ndim == 2:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        in_features = inputs.shape[-1]
        assert in_features % self.feature_group_count == 0
        kernel_shape = (
            kernel_size,
            in_features // self.feature_group_count,
            self.features,
        )

        kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
        kernel = jnp.asarray(kernel, dtype)

        if self.exclusive:
            inputs = inputs[:, :-dilation, :]

        # Zero padding
        y = jnp.pad(
            inputs,
            (
                (0, 0),
                ((kernel_size - (not self.exclusive)) * dilation, 0),
                (0, 0),
            ),
        )

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
        y = lax.conv_general_dilated(
            y,
            kernel,
            window_strides=(1,),
            padding="VALID",
            lhs_dilation=(1,),
            rhs_dilation=(dilation,),
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if is_single_input:
            y = y.squeeze(axis=0)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias

        return y
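
A minimal sketch of the padding arithmetic above, checking that the output length matches the input length in both the exclusive and non-exclusive cases (K is the full kernel size, d the dilation, L the input length):

    K, d, L = 3, 2, 10
    for exclusive in (False, True):
        k = K - exclusive                             # effective kernel width
        trimmed = L - d if exclusive else L           # exclusive convolutions drop the last `dilation` sites
        padded = trimmed + (k - (not exclusive)) * d  # left zero padding, as applied above
        out = padded - (k - 1) * d                    # "VALID" convolution with rhs_dilation = d
        print(out)                                    # 10 in both cases
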
Example #24
    def __call__(self, inputs: Array) -> Array:
        """
        Applies a masked convolution to the inputs.

        Args:
          inputs: input data with dimensions (batch, width, height, features).

        Returns:
          The convolved data.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        kernel_h, kernel_w = self.kernel_size
        dilation_h, dilation_w = self.kernel_dilation
        ones = (1, 1)

        is_single_input = False
        if inputs.ndim == 3:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        in_features = inputs.shape[-1]
        assert in_features % self.feature_group_count == 0
        kernel_shape = self.kernel_size + (
            in_features // self.feature_group_count,
            self.features,
        )

        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, self.mask),
            kernel_shape,
            self.dtype,
        )
        mask = jnp.asarray(self.mask, dtype)
        kernel = jnp.asarray(kernel, dtype)

        # Zero padding
        y = jnp.pad(
            inputs,
            (
                (0, 0),
                ((kernel_h - 1) * dilation_h, 0),
                (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
                (0, 0),
            ),
        )

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
        y = lax.conv_general_dilated(
            y,
            mask * kernel,
            window_strides=ones,
            padding="VALID",
            lhs_dilation=ones,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if is_single_input:
            y = y.squeeze(axis=0)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias

        return y