Example #1
    def discriminator_step(self, x_real: jnp.ndarray, states: elegy.States):
        # sample latent noise and generate a batch of fake images
        z = jax.random.normal(states.rng.next(), (len(x_real), 128))
        x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0]

        def d_loss_fn(d_params, states, x_real, x_fake):
            # score real and fake batches with the critic
            y_real, d_params, d_states = self.discriminator.apply(
                d_params, states.d_states)(x_real)
            y_fake, d_params, d_states = self.discriminator.apply(
                d_params, d_states)(x_fake)
            # Wasserstein critic loss
            loss = -y_real.mean() + y_fake.mean()
            gp = gradient_penalty(
                x_real,
                x_fake,
                self.discriminator.apply(d_params, d_states),
                states.rng.next(),
            )
            loss = loss + gp
            return loss, (gp, states.update_known(**locals()))

        # differentiate w.r.t. the critic parameters and apply the optax update
        (d_loss, (gp, states)), d_grads = jax.value_and_grad(
            d_loss_fn, has_aux=True)(states.d_params, states, x_real, x_fake)
        d_grads, d_opt_states = self.d_optimizer.update(
            d_grads, states.d_opt_states, states.d_params)
        d_params = optax.apply_updates(states.d_params, d_grads)

        return d_loss, gp, states.update_known(**locals())
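
The gradient_penalty helper used in this step is not shown on this page. Below is a minimal free-standing sketch of what it might look like, assuming 4-D NHWC image batches, the standard WGAN-GP penalty weight of 10, and that apply_fn is the partially-applied critic from the snippet above (which returns a (scores, params, states) tuple, hence the [0]):

import jax
import jax.numpy as jnp

def gradient_penalty(x_real, x_fake, apply_fn, rng, weight=10.0):
    # interpolate between real and fake samples (assumes NHWC image batches)
    eps = jax.random.uniform(rng, (len(x_real), 1, 1, 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake

    def total_score(x):
        # apply_fn returns (scores, params, states); summing the scores lets a
        # single jax.grad call produce per-sample input gradients
        return apply_fn(x)[0].sum()

    grads = jax.grad(total_score)(x_hat)
    norms = jnp.sqrt(jnp.sum(grads ** 2, axis=(1, 2, 3)) + 1e-12)
    return weight * jnp.mean((norms - 1.0) ** 2)
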
Example #2
    def test_step(
        self,
        x,
        y_true,
        net_params,
        states: elegy.States,
        initializing: bool,
        rng: elegy.RNGSeq,
    ):
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # model
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10], minval=-1, maxval=1
            )
            b = jax.random.uniform(rng.next(), shape=[1], minval=-1, maxval=1)
            net_params = (w, b)

        w, b = net_params
        logits = jnp.dot(x, w) + b

        # crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))

        # metrics
        logs = dict(
            acc=jnp.mean(jnp.argmax(logits, axis=-1) == y_true),
            loss=loss,
        )

        return loss, logs, states.update(net_params=net_params)
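
The hand-written cross-entropy above (one-hot labels times log-softmax) can also be expressed with optax, which is already used elsewhere on this page. A small sketch under that assumption; softmax_xent is a hypothetical helper name:

import jax
import optax

def softmax_xent(logits, y_true, num_classes=10):
    # computes the same value as the manual one-hot * log_softmax reduction above
    labels = jax.nn.one_hot(y_true, num_classes)
    return optax.softmax_cross_entropy(logits, labels).mean()
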
Example #3
    def test_step(
        self,
        x,
        states: elegy.States,
        training,
        initializing,
    ):
        # run pred_step while capturing auxiliary losses registered via elegy hooks
        with elegy.hooks.context(losses=True):
            logits, states = elegy.inject_dependencies(self.pred_step)(
                x=x,
                states=states,
                initializing=initializing,
                training=training,
            )
            aux_losses = elegy.hooks.get_losses()

        rng: elegy.RNGSeq = states.rng

        # crossentropy loss + kl
        kl_loss = aux_losses["kl_divergence_loss"]
        bce = elegy.losses.binary_crossentropy(x, logits,
                                               from_logits=True).mean()
        loss = bce + kl_loss

        # metrics
        logs = dict(
            loss=loss,
            kl_loss=kl_loss,
            bce_loss=bce,
        )

        # build the metrics function: initialize on the first pass, otherwise
        # apply it with the accumulated metrics state
        if initializing:
            loss_metrics_fn = self.loss_metrics.init(rng=rng)
        else:
            loss_metrics_fn = self.loss_metrics.apply(None,
                                                      states.metrics_states)

        logs, _, metrics_states = loss_metrics_fn(logs)

        states = states.update(metrics_states=metrics_states)

        return (loss, logs, states)
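
The kl_divergence_loss entry is produced inside pred_step (registered through elegy's loss hooks) and only read here. For reference, a sketch of the closed-form KL term that a diagonal-Gaussian VAE encoder would typically register; mean and log_var stand for the encoder outputs and are names assumed for this sketch:

import jax.numpy as jnp

def kl_divergence(mean, log_var):
    # KL( N(mean, diag(exp(log_var))) || N(0, I) ), averaged over the batch
    return jnp.mean(
        -0.5 * jnp.sum(1.0 + log_var - mean ** 2 - jnp.exp(log_var), axis=-1)
    )
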
Example #4
    def test_step(
        self,
        x,
        net_params,
        states: elegy.States,
        net_states,
        rng,
        metrics_states,
        training,
        initializing,
        mode,
    ):
        # run pred_step via call_pred_step; auxiliary losses are returned directly
        logits, states, aux_losses, _, _ = self.call_pred_step(
            x=x,
            mode=mode,
            states=states,
            initializing=initializing,
        )

        # crossentropy loss + kl
        kl_loss = aux_losses["kl_divergence_loss"]
        bce = elegy.losses.binary_crossentropy(x, logits,
                                               from_logits=True).mean()
        loss = bce + kl_loss

        # metrics
        logs = dict(
            loss=loss,
            kl_loss=kl_loss,
            bce_loss=bce,
        )

        if initializing:
            loss_metrics_fn = self.loss_metrics.init(rng=rng)
        else:
            loss_metrics_fn = self.loss_metrics.apply(metrics_states)

        logs, metrics_states = loss_metrics_fn(logs)

        states = states.update(metrics_states=metrics_states)

        return (loss, logs, states)
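
Examples #3 and #4 reach the prediction step differently: #3 calls pred_step through elegy.inject_dependencies and collects auxiliary losses from the hooks context, while #4 uses call_pred_step, which returns them directly. A hypothetical pred_step that either variant could call, following the .init / .apply pattern used in the other snippets on this page; the self.module, net_params and net_states names are assumptions of this sketch, and the KL term would be registered inside the module itself:

    def pred_step(self, x, states: elegy.States, initializing: bool, training: bool):
        if initializing:
            # first call: create parameters and module state from scratch
            logits, net_params, net_states = self.module.init(rng=states.rng)(x)
        else:
            logits, net_params, net_states = self.module.apply(
                states.net_params, states.net_states)(x)
        return logits, states.update(net_params=net_params, net_states=net_states)
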
Example #5
    def generator_step(self, batch_size: int, states: elegy.States):
        z = jax.random.normal(states.rng.next(), (batch_size, 128))

        def g_loss_fn(g_params, states, z):
            # generate fake images and score them with the (fixed) critic
            x_fake, g_params, g_states = self.generator.apply(
                g_params, states.g_states)(z)
            y_fake_scores = self.discriminator.apply(
                states.d_params, states.d_states)(x_fake)[0]
            # WGAN generator loss: raise the critic's score on generated samples
            loss = -y_fake_scores.mean()
            return loss, states.update_known(**locals())

        # differentiate w.r.t. the generator parameters and apply the optax update
        (g_loss, states), g_grads = jax.value_and_grad(g_loss_fn, has_aux=True)(
            states.g_params, states, z)
        g_grads, g_opt_states = self.g_optimizer.update(
            g_grads, states.g_opt_states, states.g_params)
        g_params = optax.apply_updates(states.g_params, g_grads)

        return g_loss, states.update_known(**locals())
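
A hypothetical driver that ties Example #1 and Example #5 together, alternating several critic updates with one generator update as in WGAN-GP training. The n_critic attribute and the (logs, states) return convention are assumptions of this sketch:

    def train_step(self, x, states: elegy.States):
        # several critic updates per generator update, as in WGAN-GP
        for _ in range(self.n_critic):
            d_loss, gp, states = self.discriminator_step(x, states)

        g_loss, states = self.generator_step(len(x), states)

        logs = dict(d_loss=d_loss, gp=gp, g_loss=g_loss)
        return logs, states
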
Example #6
    def init_step(self, x, states: elegy.States):
        return states.update(a=x.shape)
Example #7
    def init_step(self, x, y_true, states: elegy.States):
        return states.update(a=x.shape, b=y_true.shape)
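
Fields stored with states.update during init_step are later readable as attributes on the States object passed to the other steps. A tiny hypothetical follow-up to Example #7 (the placeholder loss and the jnp import are assumptions, as in the other snippets):

    def test_step(self, x, y_true, states: elegy.States):
        # the shapes recorded in init_step can be read back as attributes
        assert states.a == x.shape
        assert states.b == y_true.shape
        loss = jnp.array(0.0)  # placeholder loss for illustration
        return loss, dict(loss=loss), states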