示例#1
0
def test_logsumexp_cplx(a, b):
    """Compare ``logsumexp_cplx`` with a directly computed complex log-sum-exp.

    The reference value is ``log(complex(sum_i b_i * exp(a_i)))`` over the
    first two entries of ``a`` (and ``b`` when given). ``b=None`` means
    unweighted. The result must be complex and agree with the reference to
    ``atol=1e-8``.

    NOTE(review): only indices 0 and 1 enter the reference, so ``a``/``b``
    presumably have exactly two entries (supplied by a parametrization not
    visible here) -- confirm.
    """
    a = jnp.asarray(a)
    if b is None:
        summed = jnp.exp(a[0]) + jnp.exp(a[1])
    else:
        b = jnp.asarray(b)
        summed = jnp.exp(a[0]) * b[0] + jnp.exp(a[1]) * b[1]
    # complex() forces the complex branch of log, so a negative weighted
    # sum yields an imaginary part of pi instead of NaN.
    expected = jnp.log(complex(summed))

    result = logsumexp_cplx(a, b=b)

    assert jnp.iscomplexobj(result)
    assert_allclose(result, expected, atol=1e-8)
示例#2
0
    def __call__(self, x):
        """Evaluate the network on ``x`` and on its sign-flipped copy ``-x``.

        Both inputs pass through ``dense_symm``. Each of the remaining
        ``layers - 1`` layers activates both branches and then mixes them:

            x'      = (L(x)      + L_flip(x_flip)) / sqrt(2)
            x_flip' = (L(x_flip) + L_flip(x))      / sqrt(2)

        Here ``L`` is ``equivariant_layers[layer]`` and ``L_flip`` is
        ``equivariant_layers_flip[layer]``. The two branches are then
        concatenated on the last axis and passed through
        ``output_activation``. Finally they are reduced with a log-sum-exp
        over the last two axes, weighted by ``characters``. With
        ``parity == 1`` the flipped half reuses the characters; for any
        other value it uses the negated characters.

        Args:
            x: input array. A feature axis is inserted at position -2 when
                ``x.ndim < 3``. NOTE(review): presumably ``(batch, sites)`` in
                that case -- confirm against callers.

        Returns:
            The reduced array from ``logsumexp_cplx`` (``complex_output``) or
            ``logsumexp``. When ``equal_amplitudes`` is set, only
            ``1j * imag(...)`` of it is returned.
        """
        if x.ndim < 3:
            x = jnp.expand_dims(x, -2)  # add a feature dimension

        x_flip = self.dense_symm(-1 * x)
        x = self.dense_symm(x)

        for layer in range(self.layers - 1):
            x = self.activation(x)
            x_flip = self.activation(x_flip)

            # Both updates must read the pre-update x and x_flip, hence the
            # temporary before rebinding x.
            x_new = (
                self.equivariant_layers[layer](x)
                + self.equivariant_layers_flip[layer](x_flip)
            ) / np.sqrt(2)
            x_flip = (
                self.equivariant_layers[layer](x_flip)
                + self.equivariant_layers_flip[layer](x)
            ) / np.sqrt(2)
            # JAX arrays are immutable, so rebinding is safe; copying would
            # only allocate a duplicate buffer.
            x = x_new

        x = jnp.concatenate((x, x_flip), -1)

        x = self.output_activation(x)

        chars = jnp.asarray(self.characters)
        # The second half of the concatenated axis is the flipped branch:
        # same characters for parity 1, opposite sign otherwise.
        flip_chars = chars if self.parity == 1 else -1 * chars
        # Leading singleton axes so the weights broadcast against the last
        # axis of x.
        par_chars = jnp.expand_dims(
            jnp.concatenate((chars, flip_chars), 0), (0, 1)
        )

        if self.complex_output:
            x = logsumexp_cplx(x, axis=(-2, -1), b=par_chars)
        else:
            x = logsumexp(x, axis=(-2, -1), b=par_chars)

        if self.equal_amplitudes:
            return 1j * jnp.imag(x)
        return x
示例#3
0
    def __call__(self, x):
        """Apply the symmetric network to ``x`` and pool with the characters.

        Steps: ``dense_symm``, then ``layers - 1`` rounds of
        ``activation`` followed by ``equivariant_layers[i]``, then
        ``output_activation``. The result is a log-sum-exp over the last two
        axes, weighted by ``characters``. ``logsumexp_cplx`` is used when
        ``complex_output`` is set, otherwise ``logsumexp``.

        Args:
            x: input array. A feature axis is inserted at position -2 when
                ``x.ndim < 3``. NOTE(review): presumably ``(batch, sites)`` in
                that case -- confirm against callers.

        Returns:
            The pooled array, or ``1j * imag(pooled)`` when
            ``equal_amplitudes`` is set.
        """
        if x.ndim < 3:
            # Insert a singleton feature axis just before the last axis.
            x = jnp.expand_dims(x, -2)

        hidden = self.dense_symm(x)
        for idx in range(self.layers - 1):
            hidden = self.equivariant_layers[idx](self.activation(hidden))
        hidden = self.output_activation(hidden)

        reducer = logsumexp_cplx if self.complex_output else logsumexp
        pooled = reducer(hidden, axis=(-2, -1), b=jnp.asarray(self.characters))

        return 1j * jnp.imag(pooled) if self.equal_amplitudes else pooled