def pred_step(self, x, net_states, initializing):
    if initializing:
        states = elegy.States(net_states=0)
    else:
        states = elegy.States(net_states=net_states + 1)

    return elegy.PredStep.simple(x + 1.0, states)
def test_step(self, metrics_states, initializing):
    return elegy.TestStep(
        loss=0.1,
        logs=dict(loss=1.0),
        states=elegy.States(metrics_states=0)
        if initializing
        else elegy.States(metrics_states=metrics_states + 1),
    )
def train_step(self, states, initializing):
    return elegy.TrainStep(
        logs=dict(loss=2.0),
        states=elegy.States(optimizer_states=0)
        if initializing
        else elegy.States(optimizer_states=states.optimizer_states + 1),
    )
def pred_step(self, x, initializing, states):
    if initializing:
        states = elegy.States(net_states=0)
    else:
        states = elegy.States(net_states=states.net_states + 1)

    return elegy.PredStep(x + 1.0, states)
def pred_step(self, x, states, initializing):
    if initializing:
        states = elegy.States(net_states=0)
    else:
        states = elegy.States(net_states=states.net_states + 1)

    return elegy.PredStep(
        y_pred=1,
        states=states,
    )
def train_step(self, x, optimizer_states, initializing):
    if initializing:
        states = elegy.States(optimizer_states=0)
    else:
        states = elegy.States(optimizer_states=optimizer_states + 1)

    return elegy.TrainStep(
        logs=dict(loss=jnp.sum(x)),
        states=states,
    )
def test_step(self, x, metrics_states, initializing):
    if initializing:
        states = elegy.States(metrics_states=0)
    else:
        states = elegy.States(metrics_states=metrics_states + 1)

    return elegy.TestStep(
        loss=0.1,
        logs=dict(loss=jnp.sum(x)),
        states=states,
    )
def states_step(self):
    return elegy.States(
        net_params=None,
        net_states=None,
        metrics_states=None,
        optimizer_states=None,
    )
def pred_step(self):
    nonlocal N
    N = N + 1
    return elegy.PredStep.simple(
        y_pred=None,
        states=elegy.States(net_params=1, net_states=2),
    )
def train_step(self, x, y_true, net_params, optimizer_states, initializing, rng):
    def loss_fn(net_params, x, y_true):
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # model
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10], minval=-1, maxval=1
            )
            b = jax.random.uniform(rng.next(), shape=[1], minval=-1, maxval=1)
            net_params = (w, b)

        w, b = net_params
        logits = jnp.dot(x, w) + b

        # crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        sample_loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
        loss = jnp.mean(sample_loss)

        # metrics
        logs = dict(
            accuracy=jnp.mean(jnp.argmax(logits, axis=-1) == y_true),
            loss=loss,
        )

        return loss, (logs, net_params)

    # train
    (loss, (logs, net_params)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        net_params,  # gradients target
        x,
        y_true,
    )

    if initializing:
        optimizer_states = self.optim.init(net_params)
    else:
        grads, optimizer_states = self.optim.update(
            grads, optimizer_states, net_params
        )
        net_params = optax.apply_updates(net_params, grads)

    return logs, elegy.States(
        net_params=net_params,
        optimizer_states=optimizer_states,
    )
def init_step(self, x):
    rng = elegy.KeySeq(0)

    gx, g_params, g_states = self.generator.init(rng=rng)(x)
    dx, d_params, d_states = self.discriminator.init(rng=rng)(gx)

    g_optimizer_states = self.g_optimizer.init(g_params)
    d_optimizer_states = self.d_optimizer.init(d_params)

    return elegy.States(
        g_states=g_states,
        d_states=d_states,
        g_params=g_params,
        d_params=d_params,
        g_opt_states=g_optimizer_states,
        d_opt_states=d_optimizer_states,
        rng=rng,
        step=0,
    )
def main(
    steps_per_epoch: int = 200,
    batch_size: int = 64,
    epochs: int = 50,
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
):
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    X_train, _1, X_test, _2 = dataget.image.mnist(global_cache=True).get()

    # Now binarize data
    X_train = (X_train > 0).astype(jnp.float32)
    X_test = (X_test > 0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = VariationalAutoEncoder(
        latent_size=LATENT_SIZE,
        optimizer=optax.adam(1e-3),
        run_eagerly=eager,
    )

    # Fit with datasets in memory
    history = model.fit(
        x=X_train,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_test,),
        shuffle=True,
        callbacks=[TensorBoard(logdir)],
    )

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )

    elegy.utils.plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)
    y_pred = jax.nn.sigmoid(y_pred)

    # plot and save results
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred[i], cmap="gray")
        # tbwriter.add_figure("VAE Example", figure, epochs)

    plt.show()

    # sample
    model_decoder = elegy.Model(
        model.decoder,
        states=elegy.States(
            net_params=model.states.net_params[1],
            net_states=model.states.net_states[1],
        ),
        initialized=True,
    )

    z_samples = np.random.normal(size=(12, LATENT_SIZE))
    logits = model_decoder.predict(z_samples)
    samples = jax.nn.sigmoid(logits)

    # plot and save results
    # with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
    figure = plt.figure(figsize=(5, 12))
    plt.title("Generative Samples")
    for i in range(5):
        plt.subplot(2, 5, 2 * i + 1)
        plt.imshow(samples[i], cmap="gray")
        plt.subplot(2, 5, 2 * i + 2)
        plt.imshow(samples[i + 1], cmap="gray")
    # tbwriter.add_figure("VAE Generative Example", figure, epochs)

    plt.show()