def main(
    steps_per_epoch: int = 200,
    batch_size: int = 64,
    epochs: int = 50,
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
):
    """Train a VariationalAutoEncoder on binarized MNIST and visualize it.

    Trains with elegy and plots the training history. It then shows
    reconstructions of 5 random test images, and finally decodes random
    latent vectors with the trained decoder to display generated digits.

    Args:
        steps_per_epoch: Number of batches per training epoch.
        batch_size: Training batch size.
        epochs: Number of training epochs.
        debug: If True, wait for a debugpy client on port 5678 before running.
        eager: Forwarded to ``elegy.Model`` as ``run_eagerly``.
        logdir: Root log directory; TensorBoard output goes to a
            timestamped subdirectory of it.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    # Labels are not needed for an autoencoder.
    X_train, _, X_test, _ = dataget.image.mnist(global_cache=True).get()

    # Binarize pixels to {0.0, 1.0} so they can serve as Bernoulli targets
    # for the binary cross-entropy loss below.
    X_train = (X_train > 0).astype(jnp.float32)
    X_test = (X_test > 0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    vae = VariationalAutoEncoder(latent_size=LATENT_SIZE)

    model = elegy.Model(
        module=vae,
        loss=[BinaryCrossEntropy(from_logits=True, on="logits")],
        optimizer=optax.adam(1e-3),
        run_eagerly=eager,
    )

    model.summary(X_train[:64])

    # Fit with datasets in memory. The input is also the target, so
    # validation data only carries inputs.
    history = model.fit(
        x=X_train,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_test,),
        shuffle=True,
        callbacks=[TensorBoard(logdir)],
    )

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )

    elegy.utils.plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # Plot originals (top row) against reconstructions (bottom row).
    # NOTE(review): figure logging to TensorBoard is disabled, so tbwriter
    # and figure are currently unused.
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred["det_image"][i], cmap="gray")
        # tbwriter.add_figure("VAE Example", figure, epochs)

    plt.show()

    # call update_modules to enable parameter transfer
    # for now only Elegy Modules support this
    model.update_modules()

    # Reuse the trained decoder on its own to generate new digits.
    model_decoder = elegy.Model(vae.decoder)

    z_samples = np.random.normal(size=(12, LATENT_SIZE))
    samples = model_decoder.predict(z_samples, initialize=True)
    # The decoder outputs logits; map them to pixel intensities in (0, 1).
    samples = jax.nn.sigmoid(samples)

    # Show 10 distinct generated samples in a 2x5 grid. Each cell k gets
    # samples[k]; the previous indexing repeated neighbouring samples.
    figure = plt.figure(figsize=(5, 12))
    # suptitle titles the whole figure; plt.title would instead create a
    # stray full-figure axes underneath the subplots.
    figure.suptitle("Generative Samples")
    for k in range(10):
        plt.subplot(2, 5, k + 1)
        plt.imshow(samples[k], cmap="gray")

    plt.show()
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: int = 200,
    epochs: int = 100,
    batch_size: int = 64,
):
    """Train a small CNN digit classifier on MNIST and show predictions.

    Trains with elegy and plots the training history. It then saves the model
    to ``models/conv``, reloads it, evaluates it on the test set, and shows a
    3x3 grid of random test images titled with the predicted class.

    Args:
        debug: If True, wait for a debugpy client on port 5678 before running.
        eager: Forwarded to ``elegy.Model`` as ``run_eagerly``.
        logdir: Root log directory; TensorBoard output goes to a
            timestamped subdirectory of it.
        steps_per_epoch: Number of batches per training epoch.
        epochs: Number of training epochs.
        batch_size: Training batch size.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    X_train, y_train, X_test, y_test = dataget.image.mnist(global_cache=True).get()

    # Append a trailing singleton channel axis (presumably required by
    # elegy.nn.Conv2D — TODO confirm expected layout).
    X_train = X_train[..., None]
    X_test = X_test[..., None]

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    class CNN(elegy.Module):
        def call(self, image: jnp.ndarray, training: bool):
            @elegy.to_module
            def ConvBlock(x, units, kernel, stride=1):
                x = elegy.nn.Conv2D(units, kernel, stride=stride, padding="same")(x)
                x = elegy.nn.BatchNormalization()(x, training)
                x = elegy.nn.Dropout(0.2)(x, training)
                return jax.nn.relu(x)

            # Scale pixel values from [0, 255] to [0, 1].
            x: jnp.ndarray = image.astype(jnp.float32) / 255.0

            # base
            x = ConvBlock()(x, 32, [3, 3])
            x = ConvBlock()(x, 64, [3, 3], stride=2)
            x = ConvBlock()(x, 64, [3, 3], stride=2)
            x = ConvBlock()(x, 128, [3, 3], stride=2)

            # Global average pooling over axes 1 and 2 (assumed to be the
            # spatial H, W axes). A tuple axis is used because newer JAX
            # treats `axis` as a static, hashable argument.
            x = jnp.mean(x, axis=(1, 2))

            # Dense projection to 10 class logits (equivalent to a 1x1 conv
            # applied after pooling).
            x = elegy.nn.Linear(10)(x)
            return x

    model = elegy.Model(
        module=CNN(),
        loss=elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=elegy.metrics.SparseCategoricalAccuracy(),
        optimizer=optax.adam(1e-3),
        run_eagerly=eager,
    )

    # show model summary
    model.summary(X_train[:64], depth=1)

    history = model.fit(
        x=X_train,
        y=y_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
        shuffle=True,
        callbacks=[TensorBoard(logdir=logdir)],
    )

    elegy.utils.plot_history(history)

    # Round-trip the model through disk before evaluating it.
    model.save("models/conv")
    model = elegy.load("models/conv")

    print(model.evaluate(x=X_test, y=y_test))

    # Draw 9 random test images; bound by the actual test-set size rather
    # than a hard-coded count.
    idxs = np.random.randint(0, len(X_test), size=(9,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # Plot a 3x3 grid; each cell is titled with the argmax predicted class.
    # NOTE(review): figure logging to TensorBoard is disabled, so tbwriter
    # and figure are currently unused.
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for k in range(9):
            plt.subplot(3, 3, k + 1)
            plt.title(f"{np.argmax(y_pred[k])}")
            plt.imshow(x_sample[k], cmap="gray")
        # tbwriter.add_figure("Conv classifier", figure, epochs)

    plt.show()
def main(
    steps_per_epoch: int = 200,
    epochs: int = 50,
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    batch_size: int = 64,
):
    """Train a VAE on binarized MNIST with a hand-written loss and plot results.

    The loss is binary cross-entropy on the reconstruction logits plus a
    KL-divergence term weighted by 0.2. After training, the training history
    is plotted, and 5 random test images are shown above their
    reconstructions.

    Args:
        steps_per_epoch: Number of batches per training epoch.
        epochs: Number of training epochs.
        debug: If True, wait for a debugpy client on port 5678 before running.
        eager: Forwarded to ``elegy.Model`` as ``run_eagerly``.
        logdir: Root log directory; TensorBoard output goes to a
            timestamped subdirectory of it.
        batch_size: Training batch size (previously hard-coded to 64, which
            remains the default).
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    # Labels are not needed for an autoencoder.
    X_train, _, X_test, _ = dataget.image.mnist(global_cache=True).get()

    # Binarize pixels to {0.0, 1.0} so they can serve as targets for the
    # binary cross-entropy term of the loss.
    X_train = (X_train > 0).astype(jnp.float32)
    X_test = (X_test > 0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    vae = VAE(latent_size=LATENT_SIZE)

    def loss(x, y_pred):
        # The module's output is unpacked as (logits, mean, stddev).
        logits, mean, stddev = y_pred
        ce_loss = elegy.losses.binary_crossentropy(x, logits, from_logits=True).mean()
        # KL regularizer, down-weighted by a fixed factor of 0.2.
        kl_loss = 2e-1 * kl_divergence(mean, stddev)
        return ce_loss + kl_loss

    model = elegy.Model(
        module=vae,
        loss=loss,
        optimizer=optax.adam(1e-3),
        run_eagerly=eager,
    )

    # Fit with datasets in memory. The input is also the target, so
    # validation data only carries inputs.
    history = model.fit(
        x=X_train,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_test,),
        shuffle=True,
        callbacks=[TensorBoard(logdir)],
    )

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )

    elegy.utils.plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5,))
    x_sample = X_test[idxs]

    # The model returns logits; convert them to pixel intensities in (0, 1).
    logits, mean, stddev = model.predict(x=x_sample)
    y_pred = jax.nn.sigmoid(logits)

    # Plot originals (top row) against reconstructions (bottom row).
    # NOTE(review): figure logging to TensorBoard is disabled, so tbwriter
    # and figure are currently unused.
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred[i], cmap="gray")
        # tbwriter.add_figure("VAE Example", figure, epochs)

    plt.show()