def call(self, x: np.ndarray) -> dict:
    """Run the encoder/decoder forward pass and draw a stochastic image.

    Args:
        x: Input batch; it is cast to float32 before being encoded.

    Returns:
        A dict with three entries:
        ``image`` -- a Bernoulli sample drawn from the decoder probabilities,
        ``logits`` -- the raw decoder output,
        ``det_image`` -- the sigmoid probabilities themselves (deterministic).
    """
    inputs = x.astype(jnp.float32)

    encoder = Encoder(self.hidden_size, self.latent_size)
    latent = encoder(inputs)

    decoder = Decoder(self.hidden_size, self.output_shape)
    decoded_logits = decoder(latent)

    probs = jax.nn.sigmoid(decoded_logits)
    # Element-wise Bernoulli draw with success probability `probs`.
    sample = jax.random.bernoulli(elegy.next_rng_key(), probs)

    return {"image": sample, "logits": decoded_logits, "det_image": probs}
def call(self, x):
    """Compute ``dot(x, w) + b`` while exercising elegy's side channels.

    Besides the affine output, this method requests an RNG key (without
    using it), increments the non-trainable int32 parameter ``n`` by one,
    and records a loss (sum of activations) and a metric (mean of
    activations).

    Args:
        x: Input array; its last dimension sets the rows of ``w``.

    Returns:
        The activations ``jnp.dot(x, w) + b``.
    """
    in_features = x.shape[-1]
    kernel = self.add_parameter(
        "w", [in_features, self.units], initializer=jnp.ones
    )
    bias = self.add_parameter("b", [self.units], initializer=jnp.ones)

    # The key is requested but intentionally not used; the call itself is
    # the point (keep it in this position so RNG consumption order is unchanged).
    elegy.next_rng_key()

    counter = self.add_parameter(
        "n", [], dtype=jnp.int32, initializer=jnp.zeros, trainable=False
    )
    self.update_parameter("n", counter + 1)

    out = jnp.dot(x, kernel) + bias
    elegy.add_loss("activation_sum", jnp.sum(out))
    elegy.add_metric("activation_mean", jnp.mean(out))
    return out
def call(self, x: np.ndarray) -> np.ndarray:
    """Encode ``x`` into a latent sample using the reparameterization trick.

    The input is flattened and passed through a ReLU hidden layer (recorded
    as the ``"relu"`` summary). Two linear heads then produce ``mean`` and
    ``log_stddev``. A KL-divergence loss on ``(mean, stddev)`` is registered
    under ``"kl_divergence"``. The returned sample is
    ``mean + stddev * eps``, where ``eps`` is standard-normal noise.

    Args:
        x: Input batch; it is flattened before the first linear layer.

    Returns:
        The latent sample ``z`` with the same shape as ``mean``.
    """
    flat = elegy.nn.Flatten()(x)
    hidden = jax.nn.relu(elegy.nn.Linear(self.hidden_size)(flat))
    elegy.add_summary("relu", hidden)

    mean = elegy.nn.Linear(self.latent_size, name="linear_mean")(hidden)
    log_stddev = elegy.nn.Linear(self.latent_size, name="linear_std")(hidden)
    # The head predicts log(stddev); exponentiate to get a positive scale.
    stddev = jnp.exp(log_stddev)

    kl_loss = KLDivergence(weight=2e-1)(mean, stddev)
    elegy.add_loss("kl_divergence", kl_loss)

    noise = jax.random.normal(elegy.next_rng_key(), mean.shape)
    return mean + stddev * noise