def discriminator_step(self, x_real: jnp.ndarray, states: elegy.States):
    # sample latent vectors and generate a batch of fake images
    z = jax.random.normal(states.rng.next(), (len(x_real), 128))
    x_fake = self.generator.apply(states.g_params, states.g_states)(z)[0]

    def d_loss_fn(d_params, states, x_real, x_fake):
        # score real and fake images with the critic
        y_real, d_params, d_states = self.discriminator.apply(
            d_params, states.d_states
        )(x_real)
        y_fake, d_params, d_states = self.discriminator.apply(
            d_params, d_states
        )(x_fake)

        # Wasserstein critic loss plus gradient penalty
        loss = -y_real.mean() + y_fake.mean()
        gp = gradient_penalty(
            x_real,
            x_fake,
            self.discriminator.apply(d_params, d_states),
            states.rng.next(),
        )
        loss = loss + gp

        return loss, (gp, states.update_known(**locals()))

    # differentiate only w.r.t. the critic parameters
    (d_loss, (gp, states)), d_grads = jax.value_and_grad(d_loss_fn, has_aux=True)(
        states.d_params, states, x_real, x_fake
    )

    # apply the optimizer update
    d_grads, d_opt_states = self.d_optimizer.update(
        d_grads, states.d_opt_states, states.d_params
    )
    d_params = optax.apply_updates(states.d_params, d_grads)

    return d_loss, gp, states.update_known(**locals())
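The step above relies on a `gradient_penalty` helper that is not shown. Below is a minimal sketch of what it could look like, following the standard WGAN-GP penalty (unit gradient norm on samples interpolated between real and fake); the signature, the `weight` coefficient, and the name `apply_fn` are assumptions, not necessarily the original helper.

import jax
import jax.numpy as jnp


def gradient_penalty(x_real, x_fake, apply_fn, rng_key, weight: float = 10.0):
    # ASSUMPTION: apply_fn(x) returns (scores, params, states), matching
    # how self.discriminator.apply(...) is called in the steps above.
    eps_shape = (len(x_real),) + (1,) * (x_real.ndim - 1)
    eps = jax.random.uniform(rng_key, eps_shape)
    x_hat = eps * x_real + (1.0 - eps) * x_fake

    # summing the scores gives per-sample input gradients in one grad call,
    # assuming the critic scores each sample independently
    def score_sum(x):
        return apply_fn(x)[0].sum()

    grads = jax.grad(score_sum)(x_hat)

    # penalize deviation of each sample's gradient norm from 1
    norms = jnp.sqrt(jnp.sum(grads**2, axis=tuple(range(1, grads.ndim))) + 1e-12)
    return weight * jnp.mean((norms - 1.0) ** 2)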
def test_step(
    self,
    x,
    y_true,
    net_params,
    states: elegy.States,
    initializing: bool,
    rng: elegy.RNGSeq,
):
    # flatten + scale
    x = jnp.reshape(x, (x.shape[0], -1)) / 255

    # model
    if initializing:
        w = jax.random.uniform(
            rng.next(), shape=[np.prod(x.shape[1:]), 10], minval=-1, maxval=1
        )
        b = jax.random.uniform(rng.next(), shape=[1], minval=-1, maxval=1)
        net_params = (w, b)

    w, b = net_params
    logits = jnp.dot(x, w) + b

    # crossentropy loss
    labels = jax.nn.one_hot(y_true, 10)
    loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))

    # metrics
    logs = dict(
        acc=jnp.mean(jnp.argmax(logits, axis=-1) == y_true),
        loss=loss,
    )

    return loss, logs, states.update(net_params=net_params)
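Because the parameters are created and threaded by hand, this `test_step` needs no `Module` at all. A hedged usage sketch, assuming the step lives on an `elegy.Model` subclass and that the Keras-style `evaluate` entry point is used; the constructor arguments and data are illustrative, not from the original example.

import elegy
import numpy as np
import optax


class LinearClassifier(elegy.Model):
    def test_step(self, x, y_true, net_params, states, initializing, rng):
        ...  # body exactly as shown above


# ASSUMPTION: optimizer choice and input data are placeholders
model = LinearClassifier(optimizer=optax.rmsprop(1e-3))

x = np.random.uniform(0, 255, size=(64, 28, 28))  # MNIST-shaped inputs
y = np.random.randint(0, 10, size=(64,))

logs = model.evaluate(x=x, y=y, batch_size=32)  # -> {'acc': ..., 'loss': ...}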
def test_step(
    self,
    x,
    states: elegy.States,
    training,
    initializing,
):
    # run pred_step while capturing losses registered via hooks
    with elegy.hooks.context(losses=True):
        logits, states = elegy.inject_dependencies(self.pred_step)(
            x=x,
            states=states,
            initializing=initializing,
            training=training,
        )
        aux_losses = elegy.hooks.get_losses()

    rng: elegy.RNGSeq = states.rng

    # crossentropy loss + kl
    kl_loss = aux_losses["kl_divergence_loss"]
    bce = elegy.losses.binary_crossentropy(x, logits, from_logits=True).mean()
    loss = bce + kl_loss

    # metrics
    logs = dict(
        loss=loss,
        kl_loss=kl_loss,
        bce_loss=bce,
    )

    if initializing:
        loss_metrics_fn = self.loss_metrics.init(rng=rng)
    else:
        loss_metrics_fn = self.loss_metrics.apply(None, states.metrics_states)

    logs, _, metrics_states = loss_metrics_fn(logs)
    states = states.update(metrics_states=metrics_states)

    return (loss, logs, states)
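The `kl_divergence_loss` retrieved above must be registered somewhere inside the modules that `pred_step` runs. A hedged sketch of that module side, assuming `elegy.hooks.add_loss` is the write counterpart of `elegy.hooks.get_losses` and that it appends a `_loss` suffix to the given name (which the key above suggests); the layer itself is illustrative.

import elegy
import jax.numpy as jnp


class LatentLayer(elegy.Module):
    def call(self, x):
        mean = elegy.nn.Linear(128)(x)
        log_std = elegy.nn.Linear(128)(x)

        # KL(N(mean, exp(log_std)) || N(0, 1)), averaged over the batch
        kl = 0.5 * jnp.mean(jnp.exp(2 * log_std) + mean**2 - 1.0 - 2 * log_std)

        # ASSUMPTION: registered under "kl_divergence_loss", the key
        # test_step reads back from elegy.hooks.get_losses()
        elegy.hooks.add_loss("kl_divergence", kl)

        return mean  # sketch: use the mean as the latent code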
def test_step(
    self,
    x,
    net_params,
    states: elegy.States,
    net_states,
    rng,
    metrics_states,
    training,
    initializing,
    mode,
):
    logits, states, aux_losses, _, _ = self.call_pred_step(
        x=x,
        mode=mode,
        states=states,
        initializing=initializing,
    )

    # crossentropy loss + kl
    kl_loss = aux_losses["kl_divergence_loss"]
    bce = elegy.losses.binary_crossentropy(x, logits, from_logits=True).mean()
    loss = bce + kl_loss

    # metrics
    logs = dict(
        loss=loss,
        kl_loss=kl_loss,
        bce_loss=bce,
    )

    if initializing:
        loss_metrics_fn = self.loss_metrics.init(rng=rng)
    else:
        loss_metrics_fn = self.loss_metrics.apply(metrics_states)

    logs, metrics_states = loss_metrics_fn(logs)
    states = states.update(metrics_states=metrics_states)

    return (loss, logs, states)
def generator_step(self, batch_size: int, states: elegy.States):
    # sample latent vectors
    z = jax.random.normal(states.rng.next(), (batch_size, 128))

    def g_loss_fn(g_params, states, z):
        # generate fake images and score them with the critic
        x_fake, g_params, g_states = self.generator.apply(
            g_params, states.g_states
        )(z)
        y_fake_scores = self.discriminator.apply(
            states.d_params, states.d_states
        )(x_fake)[0]

        # the generator tries to maximize the critic's score
        loss = -y_fake_scores.mean()
        return loss, states.update_known(**locals())

    (g_loss, states), g_grads = jax.value_and_grad(g_loss_fn, has_aux=True)(
        states.g_params, states, z
    )

    # apply the optimizer update
    g_grads, g_opt_states = self.g_optimizer.update(
        g_grads, states.g_opt_states, states.g_params
    )
    g_params = optax.apply_updates(states.g_params, g_grads)

    return g_loss, states.update_known(**locals())
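A hedged sketch of how the two steps above can be driven from a single `train_step` on the same model; real WGAN-GP training usually runs several critic updates per generator update, and the log names here are assumptions.

def train_step(self, x, states: elegy.States):
    # update the critic first, then the generator
    d_loss, gp, states = self.discriminator_step(x, states)
    g_loss, states = self.generator_step(len(x), states)

    logs = dict(d_loss=d_loss, g_loss=g_loss, gradient_penalty=gp)
    return logs, states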
def init_step(self, x, states: elegy.States):
    return states.update(a=x.shape)
def init_step(self, x, y_true, states: elegy.States):
    return states.update(a=x.shape, b=y_true.shape)
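Fields written onto `states` during `init_step` persist on the model, so later steps can read them back. A minimal hedged sketch; the subclass name and the `pred_step` body are illustrative assumptions.

import elegy


class ShapeAwareModel(elegy.Model):
    def init_step(self, x, y_true, states: elegy.States):
        # runs once, on the first batch the model sees
        return states.update(a=x.shape, b=y_true.shape)

    def pred_step(self, x, states: elegy.States):
        # fields stored during init_step are available as attributes
        input_shape = states.a
        ...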