def call(self, x: np.ndarray) -> np.ndarray:
    """Map ``x`` to a random latent sample ``z = mean + stddev * noise``.

    The input goes through ``Flatten``, ``Linear(self.hidden_size)`` and a
    ReLU; the ReLU output is recorded as the ``"relu"`` summary. Two further
    ``Linear(self.latent_size)`` heads produce the mean and the log of the
    standard deviation. A ``KLDivergence(weight=2e-1)`` term evaluated on
    ``(mean, stddev)`` is registered as the ``"kl_divergence"`` loss.

    Args:
        x: Input array.

    Returns:
        ``mean + stddev * noise``, where ``noise`` is drawn with
        ``jax.random.normal`` using a key from ``elegy.next_rng_key()``.
    """
    # Layers are built and applied strictly in the original order; the
    # framework may derive module names / parameter paths from that order.
    flatten = elegy.nn.Flatten()
    hidden = flatten(x)
    hidden_layer = elegy.nn.Linear(self.hidden_size)
    hidden = hidden_layer(hidden)
    hidden = jax.nn.relu(hidden)
    elegy.add_summary("relu", hidden)

    mean_head = elegy.nn.Linear(self.latent_size, name="linear_mean")
    mean = mean_head(hidden)
    log_stddev_head = elegy.nn.Linear(self.latent_size, name="linear_std")
    log_stddev = log_stddev_head(hidden)
    # The second head predicts log(stddev); exponentiate to get stddev.
    stddev = jnp.exp(log_stddev)

    kl_term = KLDivergence(weight=2e-1)(mean, stddev)
    elegy.add_loss("kl_divergence", kl_term)

    # The RNG key is requested only after the loss has been registered,
    # matching the original statement order.
    noise = jax.random.normal(elegy.next_rng_key(), mean.shape)
    return mean + stddev * noise
def call(self, x):
    """Apply ``jnp.dot(x, w) + b`` and record bookkeeping values.

    Besides computing the output, each invocation:

    * reads the state ``"n"`` and writes back ``n + 1``;
    * registers ``jnp.sum(output)`` as the ``"activation_sum"`` loss;
    * registers ``jnp.mean(output)`` as the ``"activation_mean"`` metric.

    Args:
        x: Input array; its last axis sets the first dimension of ``w``.

    Returns:
        ``jnp.dot(x, w) + b``.
    """
    # NOTE(review): parameters are fetched through the module-level
    # ``elegy.get_parameter`` while the counter uses ``self.get_state`` /
    # ``self.set_state``, and the dtypes mix ``jnp.float32`` with
    # ``np.int32`` -- confirm this mix is intended for the targeted API.
    in_features = x.shape[-1]
    weights = elegy.get_parameter(
        "w", [in_features, self.units], jnp.float32, jnp.ones
    )
    bias = elegy.get_parameter("b", [self.units], jnp.float32, jnp.ones)

    # Scalar (shape ``[]``) counter bumped by one on every call.
    call_count = self.get_state("n", [], np.int32, jnp.zeros)
    self.set_state("n", call_count + 1)

    output = jnp.dot(x, weights) + bias

    elegy.add_loss("activation_sum", jnp.sum(output))
    elegy.add_metric("activation_mean", jnp.mean(output))
    return output
def call(self, x):
    """Apply ``jnp.dot(x, w) + b`` and record bookkeeping values.

    ``w`` and ``b`` are requested via ``self.add_parameter`` with
    ``initializer=jnp.ones``. A third entry ``"n"`` (``jnp.int32``,
    ``initializer=jnp.zeros``, ``trainable=False``) is replaced by ``n + 1``
    through ``self.update_parameter`` on every call. The sum of the output is
    registered as the ``"activation_sum"`` loss and its mean as the
    ``"activation_mean"`` metric.

    Args:
        x: Input array; its last axis sets the first dimension of ``w``.

    Returns:
        ``jnp.dot(x, w) + b``.
    """
    in_features = x.shape[-1]
    weights = self.add_parameter(
        "w", [in_features, self.units], initializer=jnp.ones
    )
    bias = self.add_parameter("b", [self.units], initializer=jnp.ones)

    # Non-trainable scalar (shape ``[]``) counter, bumped by one per call.
    call_count = self.add_parameter(
        "n",
        [],
        dtype=jnp.int32,
        initializer=jnp.zeros,
        trainable=False,
    )
    self.update_parameter("n", call_count + 1)

    output = jnp.dot(x, weights) + bias

    elegy.add_loss("activation_sum", jnp.sum(output))
    elegy.add_metric("activation_mean", jnp.mean(output))
    return output