Example 1
    def call(self, x: np.ndarray) -> dict:
        x = x.astype(jnp.float32)

        # Encode the input into a latent code, then decode it back to logits.
        z = Encoder(self.hidden_size, self.latent_size)(x)
        logits = Decoder(self.hidden_size, self.output_shape)(z)

        # Sample a binary image from the Bernoulli distribution defined by the
        # decoder logits; `p` is kept as the deterministic (expected) image.
        p = jax.nn.sigmoid(logits)
        image = jax.random.bernoulli(elegy.next_rng_key(), p)

        return dict(image=image, logits=logits, det_image=p)
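The method above reads self.hidden_size, self.latent_size and self.output_shape, so it only makes sense inside a module whose constructor stores them. Below is a minimal sketch of such a wrapper, assuming the elegy.Module base class used elsewhere in these examples; the class name and constructor arguments are inferred from the snippet, not taken from the source.

import numpy as np
import elegy


class VariationalAutoEncoder(elegy.Module):
    # Hypothetical wrapper: stores the hyperparameters that call() reads.
    def __init__(self, hidden_size: int, latent_size: int, output_shape):
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.output_shape = output_shape

    def call(self, x: np.ndarray) -> dict:
        ...  # body exactly as in Example 1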
Example 2
    def call(self, x):
        # Declare trainable parameters; they are created on the first call and
        # reused on subsequent calls.
        w = self.add_parameter("w", [x.shape[-1], self.units],
                               initializer=jnp.ones)
        b = self.add_parameter("b", [self.units], initializer=jnp.ones)

        # Request an RNG key from Elegy's hook system (unused here, shown only
        # to illustrate the hook).
        elegy.next_rng_key()

        # A non-trainable state variable counting how often the module is called.
        n = self.add_parameter("n", [],
                               dtype=jnp.int32,
                               initializer=jnp.zeros,
                               trainable=False)
        self.update_parameter("n", n + 1)

        y = jnp.dot(x, w) + b

        # Hook-based side outputs: an auxiliary loss term and a logged metric.
        elegy.add_loss("activation_sum", jnp.sum(y))
        elegy.add_metric("activation_mean", jnp.mean(y))

        return y
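Since both initializers are jnp.ones, on its first call this layer computes y[i, j] = sum_k x[i, k] + 1 for every output unit j. The following JAX-only sketch reproduces that forward computation, leaving out the Elegy hooks (which need a surrounding Elegy context) and using made-up shapes and values for illustration.

import jax.numpy as jnp

units = 3
x = jnp.array([[1.0, 2.0, 4.0]])   # one sample with three features

# What add_parameter produces on the first call with initializer=jnp.ones.
w = jnp.ones([x.shape[-1], units])
b = jnp.ones([units])

y = jnp.dot(x, w) + b
print(y)                            # [[8. 8. 8.]]  since 1 + 2 + 4 + 1 = 8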
Example 3
    def call(self, x: np.ndarray) -> np.ndarray:
        x = elegy.nn.Flatten()(x)
        x = elegy.nn.Linear(self.hidden_size)(x)
        x = jax.nn.relu(x)
        elegy.add_summary("relu", x)

        # Predict the parameters of the approximate posterior q(z | x).
        mean = elegy.nn.Linear(self.latent_size, name="linear_mean")(x)
        log_stddev = elegy.nn.Linear(self.latent_size, name="linear_std")(x)
        stddev = jnp.exp(log_stddev)

        # Regularize the posterior towards the prior with a weighted KL term.
        elegy.add_loss("kl_divergence", KLDivergence(weight=2e-1)(mean, stddev))

        # Reparameterization trick: z = mean + stddev * eps, with eps ~ N(0, 1).
        z = mean + stddev * jax.random.normal(elegy.next_rng_key(), mean.shape)

        return z
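KLDivergence is used but not defined in the snippet. One plausible definition is the analytic KL divergence between the diagonal Gaussian q(z | x) = N(mean, stddev) and a standard normal prior, scaled by weight. The sketch below writes it as a plain callable so it does not commit to any particular Elegy base class; it is an assumption, not the source's implementation.

import jax.numpy as jnp


class KLDivergence:
    # Hypothetical stand-in for the KL term used above: analytic KL between
    # N(mean, std) and N(0, 1), averaged over the batch and scaled by `weight`.
    def __init__(self, weight: float = 1.0):
        self.weight = weight

    def __call__(self, mean: jnp.ndarray, std: jnp.ndarray) -> jnp.ndarray:
        kl = 0.5 * jnp.sum(std ** 2 + mean ** 2 - 1.0 - 2.0 * jnp.log(std), axis=-1)
        return self.weight * jnp.mean(kl)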