def test_logsumexp_cplx(a, b):
    """Compare ``logsumexp_cplx`` with a direct evaluation of log(sum(b * exp(a))).

    The reference sums over every element of ``exp(a)`` (weighted elementwise
    by ``b`` when it is given) instead of only ``a[0]`` and ``a[1]``, so it
    stays correct for parametrizations of any length. The sum is cast to a
    Python ``complex`` before taking the log so that a negative or complex
    weighted sum still yields a finite complex logarithm.

    NOTE(review): this assumes ``logsumexp_cplx`` reduces over all elements
    when no ``axis`` is passed — confirm against its signature.
    """
    a = jnp.asarray(a)
    terms = jnp.exp(a)
    if b is not None:
        b = jnp.asarray(b)
        terms = terms * b
    # Summing the whole array (rather than indexing two entries) keeps the
    # reference valid for inputs with more than two elements.
    expected = jnp.log(complex(jnp.sum(terms)))

    c = logsumexp_cplx(a, b=b)
    assert jnp.iscomplexobj(c)
    assert_allclose(c, expected, atol=1e-8)
def __call__(self, x):
    """Evaluate the parity-symmetrized network on ``x``.

    Two streams are propagated side by side: ``direct``, built from ``x``,
    and ``flipped``, built from the sign-flipped input ``-1 * x``. In each of
    the ``self.layers - 1`` hidden rounds both streams are first passed
    through ``self.activation``. Each new stream is then the sum of
    ``equivariant_layers[i]`` applied to itself and
    ``equivariant_layers_flip[i]`` applied to the other stream, divided by
    ``sqrt(2)``.

    After the loop the two streams are concatenated along the last axis and
    passed through ``self.output_activation``. They are then reduced over the
    last two axes with ``logsumexp_cplx`` (when ``self.complex_output``) or
    ``logsumexp``. The weights ``b`` are the characters for the direct half.
    For the flipped half they are ``+characters`` when ``self.parity == 1``
    and ``-characters`` otherwise.

    NOTE(review): if ``logsumexp`` here is ``jax.scipy.special.logsumexp``,
    real output with the negated characters yields NaN wherever the weighted
    sum is negative — confirm that configuration is expected.

    Args:
        x: input array. A feature axis is inserted at position -2 when it has
            fewer than 3 dimensions.

    Returns:
        The reduced result, or ``1j * imag(result)`` when
        ``self.equal_amplitudes`` is set.
    """
    if x.ndim < 3:
        x = jnp.expand_dims(x, -2)  # add a feature dimension

    flipped = self.dense_symm(-1 * x)
    direct = self.dense_symm(x)

    scale = np.sqrt(2)
    for idx in range(self.layers - 1):
        direct = self.activation(direct)
        flipped = self.activation(flipped)
        keep_layer = self.equivariant_layers[idx]
        swap_layer = self.equivariant_layers_flip[idx]
        # Both updates must read the activations of the previous round, so
        # compute both before rebinding either stream.
        next_direct = (keep_layer(direct) + swap_layer(flipped)) / scale
        flipped = (keep_layer(flipped) + swap_layer(direct)) / scale
        direct = jnp.array(next_direct, copy=True)

    merged = self.output_activation(jnp.concatenate((direct, flipped), -1))

    chars = jnp.array(self.characters)
    flip_chars = chars if self.parity == 1 else -1 * chars
    # Prepend two singleton axes so the weights broadcast along the last
    # axis of ``merged``.
    par_chars = jnp.expand_dims(jnp.concatenate((chars, flip_chars), 0), (0, 1))

    reduce_fn = logsumexp_cplx if self.complex_output else logsumexp
    result = reduce_fn(merged, axis=(-2, -1), b=par_chars)
    return 1j * jnp.imag(result) if self.equal_amplitudes else result
def __call__(self, x):
    """Evaluate the network on ``x`` and reduce over the last two axes.

    Steps:
    1. ``self.dense_symm`` is applied.
    2. ``self.layers - 1`` rounds follow, each applying ``self.activation``
       and then ``self.equivariant_layers[i]``.
    3. ``self.output_activation`` is applied.
    4. The result is reduced over axes ``(-2, -1)`` with ``logsumexp_cplx``
       (when ``self.complex_output``) or ``logsumexp``, using
       ``self.characters`` as the weights ``b``.

    NOTE(review): if ``logsumexp`` here is ``jax.scipy.special.logsumexp``,
    real output with negative characters yields NaN wherever the weighted
    sum is negative — confirm that configuration is expected.

    Args:
        x: input array. A feature axis is inserted at position -2 when it has
            fewer than 3 dimensions.

    Returns:
        The reduced result, or ``1j * imag(result)`` when
        ``self.equal_amplitudes`` is set.
    """
    if x.ndim < 3:
        x = jnp.expand_dims(x, -2)  # add a feature dimension

    hidden = self.dense_symm(x)
    for idx in range(self.layers - 1):
        hidden = self.equivariant_layers[idx](self.activation(hidden))
    hidden = self.output_activation(hidden)

    reduce_fn = logsumexp_cplx if self.complex_output else logsumexp
    result = reduce_fn(hidden, axis=(-2, -1), b=jnp.asarray(self.characters))

    if not self.equal_amplitudes:
        return result
    return 1j * jnp.imag(result)