Example #1
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """
        in_features = x.shape[-2]

        x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:-1], *self.shape)

        if self.use_bias:
            bias = self.param(
                "bias", self.bias_init, (self.features,), self.param_dtype
            )
        else:
            bias = None

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_point * self.n_cells),
            self.param_dtype,
        )

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        x, kernel, bias = promote_dtype(x, kernel, bias, dtype=None)
        dtype = x.dtype

        # Convert the convolutional kernel of shape (features, in_features, n_symm)
        # to the expanded kernel of shape (features, in_features, n_point(in),
        # n_point(out), *shape) used in FFT-based group convolutions
        kernel = kernel[..., self.mapping]

        # The FFT acts on the trailing lattice axes (prod(self.shape) == n_cells),
        # so the convolution over translations becomes a per-Fourier-mode product
        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

        kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
            *kernel.shape[:4], self.n_cells
        )

        # Contract the input's (in_features, n_point) axes with the matching kernel
        # axes, batching over the Fourier (n_cells) axis; dot_general places the
        # batched axis first, hence the transpose back below
        x = lax.dot_general(
            x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
        )
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:2], -1)

        if self.use_bias:
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x, dtype):
            return x
        else:
            return x.real
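
The FFT steps in the example above rely on the convolution theorem: a convolution over the translation group becomes an elementwise product per Fourier mode. Below is a minimal standalone sketch of that idea for a one-dimensional cyclic group (illustrative sizes and names only, not part of the module above), checking the FFT route against the direct sum.

import jax.numpy as jnp

n = 6
x = jnp.arange(n, dtype=jnp.float32)   # signal over the n group elements
w = jnp.linspace(0.0, 1.0, n)          # kernel over the n group elements

# Direct group convolution: y[g] = sum_h w[(g - h) mod n] * x[h]
y_direct = jnp.stack(
    [sum(w[(g - h) % n] * x[h] for h in range(n)) for g in range(n)]
)

# FFT route: transform, multiply mode by mode, transform back
y_fft = jnp.fft.ifft(jnp.fft.fft(x) * jnp.fft.fft(w)).real

assert jnp.allclose(y_direct, y_fft, atol=1e-4)
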
Example #2
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """

        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        x = x.reshape(*x.shape[:-1], self.n_cells, self.n_point)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:-1], *self.shape)

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (
                self.out_features,
                self.in_features,
                self.n_point * self.n_cells,
            ),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.mask, (0, 1))

        kernel = self.make_kernel(kernel)

        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

        kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
            *kernel.shape[:4], self.n_cells
        )

        x = lax.dot_general(
            x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
        )
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:2], -1)

        if self.use_bias:
            bias = self.param(
                "bias", self.bias_init, (self.out_features,), self.dtype
            )
            bias = jnp.asarray(bias, dtype)
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x, dtype):
            return x
        else:
            return x.real
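
The make_kernel call in this variant presumably plays the same role as kernel[..., self.mapping] in the other two examples: expanding the flat kernel over group elements into the layout expected by the FFT-based convolution. The sketch below (made-up shapes, a stand-in mapping with no symmetry meaning) only illustrates the indexing mechanism: indexing the last axis with an integer array replaces that axis by the index array's axes.

import jax.numpy as jnp

# Made-up sizes: 6 cells x 2 point-group elements = 12 group elements / sites
out_features, in_features = 4, 3
n_cells, n_point, shape = 6, 2, (2, 3)
n_symm = n_cells * n_point

kernel = jnp.arange(out_features * in_features * n_symm, dtype=jnp.float32)
kernel = kernel.reshape(out_features, in_features, n_symm)

# Stand-in for the mapping: any integer array of shape (n_point, n_point, *shape)
# with entries in [0, n_symm); the real one encodes the space-group structure.
mapping = jnp.arange(n_point * n_point * 2 * 3).reshape(
    n_point, n_point, *shape
) % n_symm

expanded = kernel[..., mapping]
print(expanded.shape)   # (4, 3, 2, 2, 2, 3): (out, in, n_point, n_point, *shape)
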
Example #3
    def __call__(self, x: Array) -> Array:
        """Applies the equivariant transform to the inputs along the last two
        dimensions (-2: features, -1: group elements)
        """

        dtype = jnp.promote_types(x.dtype, self.dtype)
        x = jnp.asarray(x, dtype)

        # TODO: Deprecated: eventually remove and raise an error if fewer than 3 dimensions.
        # Infer in_features and ensure input dimensions (batch, in_features, n_sites)
        if x.ndim < 3:
            old_shape = x.shape
            if x.ndim == 1:
                x = jnp.expand_dims(x, (0, 1))
            elif x.ndim == 2:
                x = jnp.expand_dims(x, 1)
            symm_input_warning(old_shape, x.shape, "DenseSymm")

        in_features = x.shape[1]

        x = x.reshape(*x.shape[:-1], self.n_cells, self.sites_per_cell)
        x = x.transpose(0, 1, 3, 2)
        x = x.reshape(*x.shape[:-1], *self.shape)

        kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.features, in_features, self.n_cells * self.sites_per_cell),
            self.dtype,
        )

        kernel = jnp.asarray(kernel, dtype)

        if self.mask is not None:
            kernel = kernel * jnp.expand_dims(self.scaled_mask, (0, 1))

        # Converts the convolutional kernel of shape (features, in_features, n_sites)
        # to the expanded kernel of shape (features, in_features, sites_per_cell,
        # n_point, *shape) used in FFT-based group convolutions.
        kernel = kernel[..., self.mapping]

        x = jnp.fft.fftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)

        kernel = jnp.fft.fftn(kernel, s=self.shape).reshape(
            *kernel.shape[:4], self.n_cells
        )

        # TODO: the batch ordering should be revised: batch dimensions should
        # be leading
        x = lax.dot_general(
            x, kernel, (((1, 2), (1, 2)), ((3,), (4,))), precision=self.precision
        )
        x = x.transpose(1, 2, 3, 0)
        x = x.reshape(*x.shape[:3], *self.shape)

        x = jnp.fft.ifftn(x, s=self.shape).reshape(*x.shape[:3], self.n_cells)
        x = x.transpose(0, 1, 3, 2).reshape(*x.shape[:2], -1)

        if self.use_bias:
            bias = self.param(
                "bias", self.bias_init, (self.features,), self.dtype
            )
            bias = jnp.asarray(bias, dtype)
            x += jnp.expand_dims(bias, (0, 2))

        if jnp.can_cast(x, dtype):
            return x
        else:
            return x.real
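
Finally, a self-contained check (made-up sizes, not taken from the module) that the reshape/transpose at the top of each __call__, which brings the input into the (batch, features, n_point, *shape) layout used by the FFT, is exactly undone by the inverse rearrangement applied after the inverse FFT.

import jax.numpy as jnp

# Made-up sizes: a (3, 4) translation lattice and 2 point-group elements per cell
shape = (3, 4)
n_cells = 3 * 4
n_point = 2
batch, features = 5, 7

x = jnp.arange(batch * features * n_cells * n_point, dtype=jnp.float32)
x = x.reshape(batch, features, n_cells * n_point)

# Forward rearrangement: (batch, features, n_symm) -> (batch, features, n_point, *shape)
y = x.reshape(*x.shape[:-1], n_cells, n_point)
y = y.transpose(0, 1, 3, 2)
y = y.reshape(*y.shape[:-1], *shape)

# Inverse rearrangement, as applied after the inverse FFT
z = y.reshape(*y.shape[:3], n_cells)
z = z.transpose(0, 1, 3, 2)
z = z.reshape(*z.shape[:2], -1)

assert (z == x).all()
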