Example #1
    # VAE-style encoder (assumes the enclosing elegy.Module class and the
    # usual imports: elegy, jax, jax.numpy as jnp, numpy as np; the
    # KLDivergence loss module is defined elsewhere in the source).
    def call(self, x: np.ndarray) -> np.ndarray:
        x = elegy.nn.Flatten()(x)
        x = elegy.nn.Linear(self.hidden_size)(x)
        x = jax.nn.relu(x)
        elegy.add_summary("relu", x)  # record the activation as a summary

        # project to the mean and log standard deviation of a diagonal Gaussian
        mean = elegy.nn.Linear(self.latent_size, name="linear_mean")(x)
        log_stddev = elegy.nn.Linear(self.latent_size, name="linear_std")(x)
        stddev = jnp.exp(log_stddev)  # exponentiate so stddev is always positive

        # register the KL penalty as an auxiliary loss via elegy's hooks
        elegy.add_loss("kl_divergence", KLDivergence(weight=2e-1)(mean, stddev))

        # reparameterization trick: z = mean + stddev * eps, with eps ~ N(0, 1)
        z = mean + stddev * jax.random.normal(elegy.next_rng_key(), mean.shape)

        return z
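
The KLDivergence module used above is not shown in this example. As a rough sketch only, a custom elegy loss computing the closed-form KL divergence between N(mean, stddev^2) and a standard normal prior might look like the following; the class body is an assumption, not the original definition:

    import jax.numpy as jnp
    import elegy

    class KLDivergence(elegy.Loss):
        # Closed-form KL(N(mean, stddev^2) || N(0, 1)), averaged over the
        # latent dimension; the base Loss class applies the `weight` factor.
        def call(self, mean: jnp.ndarray, stddev: jnp.ndarray) -> jnp.ndarray:
            return -0.5 * jnp.mean(
                1.0 + 2.0 * jnp.log(stddev)
                - jnp.square(mean) - jnp.square(stddev),
                axis=-1,
            )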
Example #2
    # Linear layer built from elegy's low-level hooks (assumes the enclosing
    # elegy.Module class with a `units` attribute, plus the usual imports:
    # elegy, numpy as np, jax.numpy as jnp).
    def call(self, x):
        # trainable parameters, created on the first call and reused afterwards
        w = elegy.get_parameter("w", [x.shape[-1], self.units], jnp.float32,
                                jnp.ones)
        b = elegy.get_parameter("b", [self.units], jnp.float32, jnp.ones)

        # non-trainable state: a counter tracking how often the layer is called
        n = self.get_state("n", [], np.int32, jnp.zeros)
        self.set_state("n", n + 1)

        y = jnp.dot(x, w) + b

        # auxiliary loss and metric registered through hooks
        elegy.add_loss("activation_sum", jnp.sum(y))
        elegy.add_metric("activation_mean", jnp.mean(y))

        return y
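
For intuition, this is the forward computation the layer performs, written in plain jax.numpy with the same ones-initialization. This is a standalone sketch; the function name and test shapes are illustrative, not from the original:

    import jax.numpy as jnp

    def forward(x: jnp.ndarray, units: int) -> jnp.ndarray:
        # same math as the hooks example: y = x @ w + b, with ones-initialized w, b
        w = jnp.ones([x.shape[-1], units])
        b = jnp.ones([units])
        return jnp.dot(x, w) + b

    y = forward(jnp.ones([2, 3]), units=4)
    print(y)  # every entry is 3 * 1 + 1 = 4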
Example #3
    # The same layer as Example #2, written against the newer Module API:
    # add_parameter with trainable=False replaces get_state/set_state
    # (assumes the enclosing elegy.Module class with a `units` attribute).
    def call(self, x):
        # trainable parameters, ones-initialized on the first call
        w = self.add_parameter("w", [x.shape[-1], self.units],
                               initializer=jnp.ones)
        b = self.add_parameter("b", [self.units], initializer=jnp.ones)

        # non-trainable call counter, stored alongside the parameters
        n = self.add_parameter("n", [],
                               dtype=jnp.int32,
                               initializer=jnp.zeros,
                               trainable=False)
        self.update_parameter("n", n + 1)

        y = jnp.dot(x, w) + b

        # auxiliary loss and metric registered through hooks
        elegy.add_loss("activation_sum", jnp.sum(y))
        elegy.add_metric("activation_mean", jnp.mean(y))

        return y
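
To give the snippet its missing context, here is a minimal end-to-end sketch of how a module like this could be wrapped and trained. The class name CountingLinear, the MeanSquaredError loss, the optax optimizer, and the random data are all assumptions about the surrounding code, and it presumes an elegy version where Model accepts an optax optimizer:

    import numpy as np
    import jax.numpy as jnp
    import elegy
    import optax

    class CountingLinear(elegy.Module):
        # hypothetical wrapper class; `units` is assumed to be set here
        def __init__(self, units: int):
            super().__init__()
            self.units = units

        def call(self, x):
            # body as in Example #3 above
            w = self.add_parameter("w", [x.shape[-1], self.units],
                                   initializer=jnp.ones)
            b = self.add_parameter("b", [self.units], initializer=jnp.ones)
            n = self.add_parameter("n", [], dtype=jnp.int32,
                                   initializer=jnp.zeros, trainable=False)
            self.update_parameter("n", n + 1)
            y = jnp.dot(x, w) + b
            elegy.add_loss("activation_sum", jnp.sum(y))
            elegy.add_metric("activation_mean", jnp.mean(y))
            return y

    model = elegy.Model(
        module=CountingLinear(units=1),
        loss=elegy.losses.MeanSquaredError(),
        optimizer=optax.adam(1e-3),
    )
    X = np.random.uniform(size=(64, 3)).astype(np.float32)
    y = np.random.uniform(size=(64, 1)).astype(np.float32)
    model.fit(x=X, y=y, epochs=2, batch_size=32)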