def test_basic_predict(self):
    # FIXME: this test succeeds if run alone or on the cpu-only version of
    # jax; it fails with "DNN library is not found" if run on gpu together
    # with all the other tests.
    model = elegy.Model(elegy.nets.resnet.ResNet18(), eager=True)
    assert isinstance(model.module, elegy.Module)

    x = np.random.random((2, 224, 224, 3)).astype(np.float32)
    model.init(x)
    y = model.predict(x)

    # update_modules results in a call to `set_default_parameters` for elegy
    # Modules; it might be better to have the user call this explicitly to
    # avoid potential OOM.
    model.update_modules()
    assert jnp.all(y.shape == (2, 1000))

    # test loading weights from file
    with tempfile.TemporaryDirectory() as tempdir:
        pklpath = os.path.join(tempdir, "delete_me.pkl")
        with open(pklpath, "wb") as f:
            f.write(pickle.dumps(model.module.get_default_parameters()))

        new_r18 = elegy.nets.resnet.ResNet18(weights=pklpath)
        y2 = elegy.Model(new_r18, eager=True).predict(x, initialize=True)
        assert np.allclose(y, y2, rtol=0.001)
def test_example(self):
    class MLP(elegy.Module):
        def __apply__(self, input):
            mlp = hk.Sequential([
                hk.Linear(10),
            ])
            return mlp(input)

    # This callback stops training when the monitored loss shows no
    # improvement for three consecutive epochs.
    callback = elegy.callbacks.EarlyStopping(monitor="loss", patience=3)

    model = elegy.Model(
        module=MLP.defer(),
        loss=elegy.losses.MeanSquaredError(),
        optimizer=optix.rmsprop(0.01),
    )

    history = model.fit(
        np.arange(100).reshape(5, 20).astype(np.float32),
        np.zeros(5),
        epochs=10,
        batch_size=1,
        callbacks=[callback],
        verbose=0,
    )
    assert len(history.history["loss"]) == 7  # only 7 of the 10 epochs run
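# A sketch of a stricter stopping rule with the same callback, assuming
# EarlyStopping mirrors the Keras signature (min_delta is an assumption here;
# restore_best_weights does appear in other tests in this suite):
strict_callback = elegy.callbacks.EarlyStopping(
    monitor="loss",
    patience=3,
    min_delta=1e-4,  # require at least this much improvement per epoch
    restore_best_weights=True,
)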
def main(batch_size: int = 64, k: int = 5, debug: bool = False):
    noise = np.float32(np.random.normal(size=(3000, 1)))  # random noise
    y_train = np.float32(np.random.uniform(-10.5, 10.5, (1, 3000))).T
    X_train = np.float32(np.sin(0.75 * y_train) * 7.0 + y_train * 0.5 + noise * 1.0)
    X_train = X_train / np.abs(X_train.max())
    y_train = y_train / np.abs(y_train.max())

    visualize_data(X_train, y_train)

    model = elegy.Model(
        module=MixtureModel(k=k),
        loss=MixtureNLL(),
        optimizer=optax.adam(3e-4),
    )
    model.summary(X_train[:batch_size], depth=1)

    model.fit(
        x=X_train,
        y=y_train,
        epochs=500,
        batch_size=batch_size,
        shuffle=True,
    )

    visualize_model(X_train, y_train, model, k)
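# `MixtureModel` and `MixtureNLL` are defined elsewhere in this example. A
# minimal sketch of a Gaussian-mixture negative log-likelihood, assuming the
# module outputs per-component weights and means with a fixed unit scale, and
# assuming the elegy.Loss call(y_true, y_pred) interface used elsewhere in
# this repo (the packed-output layout below is illustrative, not the actual
# API):
class MixtureNLLSketch(elegy.Loss):
    def call(self, y_true, y_pred):
        probs, means = y_pred  # assumed: (..., k) component weights and means
        # log N(y | mu_k, 1) per component
        log_normal = -0.5 * (y_true - means) ** 2 - 0.5 * jnp.log(2 * jnp.pi)
        # log sum_k pi_k N(y | mu_k, 1)
        log_mix = jax.scipy.special.logsumexp(jnp.log(probs) + log_normal, axis=-1)
        return -jnp.mean(log_mix)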
def test_evaluate(self):
    class mse(eg.Loss):
        def call(self, target, preds):
            return jnp.mean((target - preds) ** 2)

    class mae(eg.Metric):
        value: eg.MetricState = eg.MetricState.node(
            default=jnp.array(0.0, jnp.float32)
        )

        def update(self, target, preds):
            # store the batch mean so that compute() returns it
            self.value = jnp.mean(jnp.abs(target - preds))

        def compute(self) -> tp.Any:
            return self.value

    model = eg.Model(
        module=eg.Linear(1),
        loss=dict(a=mse()),
        metrics=dict(b=mae()),
        optimizer=optax.adamw(1e-3),
        eager=True,
    )

    X = np.random.uniform(size=(5, 2))
    y = np.random.uniform(size=(5, 1))

    logs = model.evaluate(x=X, y=y)
    # log keys are derived from the class names and the dict keys
    assert "a/mse_loss" in logs
    assert "b/mae" in logs
    assert "loss" in logs
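# The mae above reports only the last batch's mean. A sketch of a
# running-mean variant under the same eg.Metric API shown in the test (the
# total/count bookkeeping is an illustrative assumption, not the library's
# built-in implementation):
class running_mae(eg.Metric):
    total: eg.MetricState = eg.MetricState.node(default=jnp.array(0.0, jnp.float32))
    count: eg.MetricState = eg.MetricState.node(default=jnp.array(0.0, jnp.float32))

    def update(self, target, preds):
        # accumulate absolute error and element count across batches
        self.total = self.total + jnp.sum(jnp.abs(target - preds))
        self.count = self.count + preds.size

    def compute(self) -> tp.Any:
        return self.total / self.count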
def test_cloudpickle(self):
    model = elegy.Model(
        module=MLP(n1=3, n2=1),
        loss=[
            elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
            elegy.regularizers.GlobalL2(l=1e-4),
        ],
        metrics=elegy.metrics.SparseCategoricalAccuracy(),
        optimizer=optax.adamw(1e-3),
        run_eagerly=True,
    )

    X = np.random.uniform(size=(5, 7, 7))
    y = np.random.randint(10, size=(5,))

    model.init(X, y)
    y0 = model.predict(X)

    model_pkl = cloudpickle.dumps(model)
    newmodel = cloudpickle.loads(model_pkl)
    newmodel.states = model.states
    newmodel.initial_states = model.initial_states

    y1 = newmodel.predict(X)
    assert np.all(y0 == y1)
def test_evaluate(self):
    model = elegy.Model(
        module=MLP(n1=3, n2=1),
        loss=[
            elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
            elegy.regularizers.GlobalL2(l=1e-4),
        ],
        metrics=elegy.metrics.SparseCategoricalAccuracy(),
        optimizer=optax.adamw(1e-3),
        run_eagerly=True,
    )

    X = np.random.uniform(size=(5, 7, 7))
    y = np.random.randint(10, size=(5,))

    history = model.fit(
        x=X,
        y=y,
        epochs=1,
        steps_per_epoch=1,
        batch_size=5,
        validation_data=(X, y),
        shuffle=True,
        verbose=1,
    )

    logs = model.evaluate(X, y)

    eval_acc = logs["sparse_categorical_accuracy"]
    predict_acc = (model.predict(X).argmax(-1) == y).mean()
    assert eval_acc == predict_acc
def test_lr_logging(self):
    model = elegy.Model(
        module=MLP(n1=3, n2=1),
        loss=elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=elegy.metrics.SparseCategoricalAccuracy(),
        optimizer=elegy.Optimizer(
            optax.adamw(1.0, b1=0.95),
            lr_schedule=lambda step, epoch: jnp.array(1e-3),
        ),
        run_eagerly=True,
    )

    X = np.random.uniform(size=(5, 7, 7))
    y = np.random.randint(10, size=(5,))

    history = model.fit(
        x=X,
        y=y,
        epochs=1,
        steps_per_epoch=1,
        batch_size=5,
        validation_data=(X, y),
        shuffle=True,
        verbose=0,
    )

    assert "lr" in history.history
    assert np.allclose(history.history["lr"], 1e-3)
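# The lr_schedule above logs a constant 1e-3. A sketch of wiring a decaying
# schedule through the same hook; optax.exponential_decay is a standard optax
# helper, the elegy.Optimizer usage mirrors the test above, and the function
# name is hypothetical:
def make_decaying_optimizer():
    schedule = optax.exponential_decay(
        init_value=1e-3, transition_steps=100, decay_rate=0.9
    )
    # base lr of 1.0 so the logged/effective lr is exactly schedule(step)
    return elegy.Optimizer(
        optax.adamw(1.0, b1=0.95),
        lr_schedule=lambda step, epoch: schedule(step),
    )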
def test_evaluate(self):
    model = eg.Model(
        module=MLP(dmid=3, dout=4),
        loss=[
            eg.losses.Crossentropy(),
            eg.regularizers.L2(l=1e-4),
        ],
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adamw(1e-3),
        eager=True,
    )

    X = np.random.uniform(size=(5, 2))
    y = np.random.randint(4, size=(5,))

    history = model.fit(
        inputs=X,
        labels=y,
        epochs=1,
        steps_per_epoch=1,
        batch_size=5,
        validation_data=(X, y),
        shuffle=True,
        verbose=1,
    )

    logs = model.evaluate(X, y)

    eval_acc = logs["accuracy"]
    predict_acc = (model.predict(X).argmax(-1) == y).mean()
    assert eval_acc == predict_acc
def test_example_restore(self):
    class MLP(eg.Module):
        @eg.compact
        def __call__(self, x):
            x = eg.Linear(10)(x)
            x = jax.lax.stop_gradient(x)
            return x

    model = eg.Model(
        module=MLP(),
        loss=eg.losses.MeanSquaredError(),
        optimizer=optax.rmsprop(0.01),
    )

    history = model.fit(
        inputs=np.ones((5, 20)),
        labels=np.zeros((5, 10)),
        epochs=10,
        batch_size=1,
        callbacks=[
            # stop training after three consecutive epochs without
            # improvement in the loss, restoring the best weights
            eg.callbacks.EarlyStopping(
                monitor="loss", patience=3, restore_best_weights=True
            )
        ],
        verbose=0,
    )
    assert len(history.history["loss"]) == 4  # only 4 of the 10 epochs run
def test_on_model(self):
    model = elegy.Model(module=elegy.nn.Linear(2))

    x = np.ones([3, 5])
    y_pred = model.predict(x, initialize=True)
    logs = model.evaluate(x)
def main(
    logdir: str = "runs",
    steps_per_epoch: tp.Optional[int] = None,
    epochs: int = 10,
    batch_size: int = 32,
):
    platform = jax.local_devices()[0].platform
    ndevices = len(jax.devices())
    print("devices ", jax.devices())
    print("platform ", platform)

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = dataset["train"]["image"][..., None]
    y_train = dataset["train"]["label"]
    X_test = dataset["test"]["image"][..., None]
    y_test = dataset["test"]["label"]

    accuracies = {}

    # run distributed=False twice so the first (warmup) run does not skew
    # the timing comparison
    for distributed in [False, False, True]:
        print(f"Distributed training = {distributed}")
        start_time = time.time()

        model = eg.Model(
            module=CNN(),
            loss=eg.losses.Crossentropy(),
            metrics=eg.metrics.Accuracy(),
            optimizer=optax.adam(1e-3),
            seed=42,
        )
        if distributed:
            model = model.distributed()

        # keep the global batch size constant; alternatively use
        # int(batch_size / ndevices) for a constant per-device batch
        bs = batch_size

        # model.summary(X_train[:64], depth=1)

        history = model.fit(
            inputs=X_train,
            labels=y_train,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            batch_size=bs,
            validation_data=(X_test, y_test),
            shuffle=True,
            verbose=3,
        )

        ev = model.evaluate(x=X_test, y=y_test, verbose=1)
        print("eval ", ev)
        accuracies[distributed] = ev["accuracy"]

        end_time = time.time()
        print(f"time taken {end_time - start_time}")

    print(accuracies)
def test_predict(self):
    model = eg.Model(module=eg.Linear(1))

    X = np.random.uniform(size=(5, 2))
    y = np.random.randint(10, size=(5, 1))

    y_pred = model.predict(X)
    assert y_pred.shape == (5, 1)
def test_autodownload_pretrained_r50(self):
    fname, _ = urllib.request.urlretrieve(
        "https://upload.wikimedia.org/wikipedia/commons/e/e4/A_French_Bulldog.jpg"
    )
    im = np.array(PIL.Image.open(fname).resize([224, 224])) / np.float32(255)

    r50 = elegy.nets.resnet.ResNet50(weights="imagenet")
    with jax.disable_jit():
        # ImageNet class 245 is "French bulldog"
        assert elegy.Model(r50).predict(im[np.newaxis]).argmax() == 245
def test_predict(self):
    model = elegy.Model(module=elegy.nn.Linear(1))

    X = np.random.uniform(size=(5, 10))
    y = np.random.randint(10, size=(5, 1))

    model.init(x=X, y=y)
    y_pred = model.predict(x=X)
    assert y_pred.shape == (5, 1)
def test_on_predict(self):
    class TestModule(elegy.Module):
        def call(self, x, training):
            return elegy.nn.BatchNormalization()(x, training)

    model = elegy.Model(module=TestModule())

    x = jnp.ones([3, 5])
    y_pred = model.predict(x)
    logs = model.evaluate(x)
def main(debug: bool = False, eager: bool = False):
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    X_train, _1, X_test, _2 = dataget.image.mnist(global_cache=True).get()

    # binarize the data
    X_train = (X_train > 0).astype(jnp.float32)
    X_test = (X_test > 0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = elegy.Model(
        module=VariationalAutoEncoder.defer(),
        loss=[KLDivergence(), BinaryCrossEntropy(on="logits")],
        optimizer=optix.adam(1e-3),
        run_eagerly=eager,
    )

    epochs = 10

    # fit with datasets in memory
    history = model.fit(
        x=X_train,
        epochs=epochs,
        batch_size=64,
        steps_per_epoch=100,
        validation_data=(X_test,),
        shuffle=True,
    )
    plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results
    plt.figure(figsize=(12, 12))
    for i in range(5):
        plt.subplot(2, 5, i + 1)
        plt.imshow(x_sample[i], cmap="gray")
        plt.subplot(2, 5, 5 + i + 1)
        plt.imshow(y_pred["image"][i], cmap="gray")

    plt.show()
def test_on_predict(self):
    class TestModule(elegy.Module):
        def call(self, x, training):
            return elegy.nn.Dropout(0.5)(x, training)

    model = elegy.Model(TestModule())

    x = jnp.ones([3, 5])
    y_pred = model.predict(x)
    logs = model.evaluate(x)

    # dropout is disabled at inference time, so predict is the identity here
    assert jnp.all(y_pred == x)
def test_on_predict(self):
    model = elegy.Model(
        elegy.nn.Sequential(lambda: [
            elegy.nn.Flatten(),
            elegy.nn.Linear(5),
            jax.nn.relu,
            elegy.nn.Linear(2),
        ])
    )

    x = jnp.ones([3, 5])
    y_pred = model.predict(x)
    logs = model.evaluate(x)
def test_model_fit(self):
    ds = DS0()
    loader_train = elegy.data.DataLoader(ds, batch_size=4, n_workers=4)
    loader_valid = elegy.data.DataLoader(ds, batch_size=4, n_workers=4)

    class Module(elegy.Module):
        def call(self, x):
            x = jnp.mean(x, axis=[1, 2])
            x = elegy.nn.Linear(20)(x)
            return x

    model = elegy.Model(
        Module(),
        loss=elegy.losses.SparseCategoricalCrossentropy(),
        optimizer=optax.sgd(0.1),
    )
    model.fit(loader_train, validation_data=loader_valid, epochs=3)
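# DS0 is defined elsewhere in this test file. A minimal sketch of a
# compatible dataset, assuming elegy.data.Dataset follows the map-style
# __len__/__getitem__ protocol (the class name and shapes are illustrative;
# shapes are chosen to match the Module above: (H, W, C) images reduced over
# axes 1, 2, with integer labels in [0, 20)):
class DS0Sketch(elegy.data.Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, i):
        return np.random.uniform(size=(8, 8, 3)), np.random.randint(20)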
def test_distributed_init(self):
    n_devices = jax.device_count()
    batch_size = 5 * n_devices

    x = np.random.uniform(size=(batch_size, 1))
    y = 1.4 * x + 0.1 * np.random.uniform(size=(batch_size, 2))

    model = eg.Model(
        eg.Linear(2),
        loss=[eg.losses.MeanSquaredError()],
    )
    model = model.distributed()

    model.init_on_batch(x)

    assert model.module.kernel.shape == (n_devices, 1, 2)
    assert model.module.bias.shape == (n_devices, 2)
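# Under .distributed() every parameter gains a leading device axis, which is
# what the shape asserts above check. A sketch of going back to a
# single-device view, assuming the .local() helper used elsewhere in these
# examples returns an unreplicated copy (the test name and the post-local
# shapes are assumptions):
def test_distributed_to_local_sketch(self):
    model = eg.Model(eg.Linear(2)).distributed()
    model.init_on_batch(np.random.uniform(size=(jax.device_count() * 5, 1)))

    local_model = model.local()
    # leading device axis assumed to be dropped
    assert local_model.module.kernel.shape == (1, 2)
    assert local_model.module.bias.shape == (2,)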
def test_saved_model(self):
    with TemporaryDirectory() as model_dir:
        model = eg.Model(module=eg.Linear(4))

        x = np.random.uniform(size=(5, 6))

        model.saved_model(x, model_dir, batch_size=[1, 2, 4, 8])

        output = str(sh.ls(model_dir))
        assert "saved_model.pb" in output
        assert "variables" in output

        # the export loads back as a regular TF SavedModel
        saved_model = tf.saved_model.load(model_dir)
        assert saved_model is not None
def test_saved_model(self):
    with TemporaryDirectory() as model_dir:
        model = elegy.Model(module=elegy.nn.Linear(4))

        x = np.random.uniform(size=(5, 6))

        # exporting before init must fail
        with pytest.raises(elegy.types.ModelNotInitialized):
            model.saved_model(x, model_dir, batch_size=[1, 2, 4, 8])

        model.init(x)
        model.saved_model(x, model_dir, batch_size=[1, 2, 4, 8])

        output = str(sh.ls(model_dir))
        assert "saved_model.pb" in output
        assert "variables" in output

        saved_model = tf.saved_model.load(model_dir)
        assert saved_model is not None
def test_saved_model_poly(self):
    with TemporaryDirectory() as model_dir:
        model = eg.Model(module=eg.Linear(4))

        x = np.random.uniform(size=(5, 6)).astype(np.float32)

        model.saved_model(x, model_dir, batch_size=None)

        output = str(sh.ls(model_dir))
        assert "saved_model.pb" in output
        assert "variables" in output

        saved_model = tf.saved_model.load(model_dir)

        # change the batch size
        x = np.random.uniform(size=(3, 6)).astype(np.float32)
        y = saved_model(x)
        assert y.shape == (3, 4)
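# batch_size=None above exports a polymorphic batch dimension, whereas the
# batch_size=[1, 2, 4, 8] variants in the earlier tests trace only those
# concrete sizes. A sketch exercising several batch sizes against a
# polymorphic export, reusing the setup from the test above (the test name
# is hypothetical):
def test_saved_model_poly_batches_sketch(self):
    with TemporaryDirectory() as model_dir:
        eg.Model(module=eg.Linear(4)).saved_model(
            np.random.uniform(size=(5, 6)).astype(np.float32),
            model_dir,
            batch_size=None,
        )
        loaded = tf.saved_model.load(model_dir)
        # any leading size should work against the polymorphic signature
        for n in (1, 2, 8):
            out = loaded(np.random.uniform(size=(n, 6)).astype(np.float32))
            assert out.shape == (n, 4)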
def test_evaluate(self):
    def mse(y_true, y_pred):
        return jnp.mean((y_true - y_pred) ** 2)

    def mae(y_true, y_pred):
        return jnp.mean(jnp.abs(y_true - y_pred))

    model = elegy.Model(
        module=elegy.nn.Linear(1),
        loss=dict(a=mse),
        metrics=dict(b=mae),
        optimizer=optax.adamw(1e-3),
        run_eagerly=True,
    )

    X = np.random.uniform(size=(5, 10))
    y = np.random.uniform(size=(5, 1))

    logs = model.evaluate(x=X, y=y)
    assert "a/mse_loss" in logs
    assert "b/mae" in logs
    assert "loss" in logs
def test_cloudpickle(self):
    model = eg.Model(
        module=eg.Linear(10),
        loss=[
            eg.losses.Crossentropy(),
            eg.regularizers.L2(1e-4),
        ],
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adamw(1e-3),
        eager=True,
    )

    X = np.random.uniform(size=(5, 2))
    y = np.random.randint(10, size=(5,))

    y0 = model.predict(X)

    with TemporaryDirectory() as model_dir:
        model.save(model_dir)
        newmodel = eg.load(model_dir)

    y1 = newmodel.predict(X)
    assert np.all(y0 == y1)
def test_example(self):
    class MLP(elegy.Module):
        def call(self, x):
            x = elegy.nn.Linear(10)(x)
            x = jax.lax.stop_gradient(x)
            return x

    # This callback stops training when the monitored loss shows no
    # improvement for three consecutive epochs.
    callback = elegy.callbacks.EarlyStopping(monitor="loss", patience=3)

    model = elegy.Model(
        module=MLP(),
        loss=elegy.losses.MeanSquaredError(),
        optimizer=optax.rmsprop(0.01),
    )

    history = model.fit(
        x=np.ones((5, 20)),
        y=np.zeros((5, 10)),
        epochs=10,
        batch_size=1,
        callbacks=[callback],
        verbose=0,
    )
    assert len(history.history["loss"]) == 4  # only 4 of the 10 epochs run
def test_summaries(self):
    class ModuleC(linen.Module):
        @linen.compact
        @elegy.flax_summarize
        def __call__(self, x):
            c1 = self.param("c1", lambda _: jnp.ones([5]))
            c2 = self.variable("states", "c2", lambda: jnp.ones([6]))
            x = jax.nn.relu(x)
            elegy.flax_summary(self, "relu", jax.nn.relu, x)
            return x

    class ModuleB(linen.Module):
        @linen.compact
        @elegy.flax_summarize
        def __call__(self, x):
            b1 = self.param("b1", lambda _: jnp.ones([3]))
            b2 = self.variable("states", "b2", lambda: jnp.ones([4]))
            x = ModuleC()(x)
            x = jax.nn.relu(x)
            elegy.flax_summary(self, "relu", jax.nn.relu, x)
            return x

    class ModuleA(linen.Module):
        @linen.compact
        @elegy.flax_summarize
        def __call__(self, x):
            a1 = self.param("a1", lambda _: jnp.ones([1]))
            a2 = self.variable("states", "a2", lambda: jnp.ones([2]))
            x = ModuleB()(x)
            x = jax.nn.relu(x)
            elegy.flax_summary(self, "relu", jax.nn.relu, x)
            return x

    model = elegy.Model(ModuleA())
    model.init(x=jnp.ones([10, 2]))

    summary_text = model.summary(x=jnp.ones([10, 2]), depth=1, return_repr=True)
    assert summary_text is not None
    lines = summary_text.split("\n")

    assert "ModuleB_0" in lines[7]
    assert "ModuleB" in lines[7]
    assert "(10, 2)" in lines[7]
    assert "8" in lines[7]
    assert "32 B" in lines[7]
    assert "10" in lines[7]
    assert "40 B" in lines[7]

    assert "relu" in lines[9]
    assert "(10, 2)" in lines[9]

    assert "*" in lines[11]
    assert "ModuleA" in lines[11]
    assert "(10, 2)" in lines[11]
    assert "1" in lines[11]
    assert "4 B" in lines[11]
    assert "2" in lines[11]
    assert "8 B" in lines[11]

    assert "9" in lines[13]
    assert "36 B" in lines[13]
    assert "12" in lines[13]
    assert "48 B" in lines[13]

    assert "21" in lines[16]
    assert "84 B" in lines[16]
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: tp.Optional[int] = None,
    epochs: int = 100,
    batch_size: int = 32,
    distributed: bool = False,
):
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.stack(dataset["train"]["image"])[..., None]
    y_train = dataset["train"]["label"]
    X_test = np.stack(dataset["test"]["image"])[..., None]
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    model = eg.Model(
        module=CNN(),
        loss=eg.losses.Crossentropy(),
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adam(1e-3),
        eager=eager,
    )

    if distributed:
        model = model.distributed()

    # show model summary
    model.summary(X_train[:64], depth=1)

    history = model.fit(
        inputs=X_train,
        labels=y_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir)],
    )

    eg.utils.plot_history(history)

    print(model.evaluate(x=X_test, y=y_test))

    # get random samples
    idxs = np.random.randint(0, 10000, size=(9,))
    x_sample = X_test[idxs]

    # gather the (possibly replicated) model back onto a single device
    # before predicting
    model = model.local()

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results
    figure = plt.figure(figsize=(12, 12))
    for i in range(3):
        for j in range(3):
            k = 3 * i + j
            plt.subplot(3, 3, k + 1)
            plt.title(f"{np.argmax(y_pred[k])}")
            plt.imshow(x_sample[k], cmap="gray")

    plt.show()
def test_summaries(self):
    class ModuleC(elegy.Module):
        def call(self, x):
            c1 = self.add_parameter("c1", lambda: jnp.ones([5]))
            c2 = self.add_parameter("c2", lambda: jnp.ones([6]), trainable=False)
            x = jax.nn.relu(x)
            self.add_summary("relu", jax.nn.relu, x)
            return x

    class ModuleB(elegy.Module):
        def call(self, x):
            b1 = self.add_parameter("b1", lambda: jnp.ones([3]))
            b2 = self.add_parameter("b2", lambda: jnp.ones([4]), trainable=False)
            x = ModuleC()(x)
            x = jax.nn.relu(x)
            self.add_summary("relu", jax.nn.relu, x)
            return x

    class ModuleA(elegy.Module):
        def call(self, x):
            a1 = self.add_parameter("a1", lambda: jnp.ones([1]))
            a2 = self.add_parameter("a2", lambda: jnp.ones([2]), trainable=False)
            x = ModuleB()(x)
            x = jax.nn.relu(x)
            self.add_summary("relu", jax.nn.relu, x)
            return x

    model = elegy.Model(ModuleA())

    summary_text = model.summary(x=jnp.ones([10, 2]), depth=1, return_repr=True)
    assert summary_text is not None
    lines = summary_text.split("\n")

    assert "module_b" in lines[7]
    assert "ModuleB" in lines[7]
    assert "(10, 2)" in lines[7]
    assert "8" in lines[7]
    assert "32 B" in lines[7]
    assert "10" in lines[7]
    assert "40 B" in lines[7]

    assert "relu" in lines[9]
    assert "(10, 2)" in lines[9]

    assert "*" in lines[11]
    assert "ModuleA" in lines[11]
    assert "(10, 2)" in lines[11]
    assert "1" in lines[11]
    assert "4 B" in lines[11]
    assert "2" in lines[11]
    assert "8 B" in lines[11]

    assert "21" in lines[13]
    assert "84 B" in lines[13]

    assert "9" in lines[14]
    assert "36 B" in lines[14]

    assert "12" in lines[15]
    assert "48 B" in lines[15]
def main(
    steps_per_epoch: tp.Optional[int] = None,
    batch_size: int = 32,
    epochs: int = 50,
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
):
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8)
    X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8)

    # normalize the data to [0, 1]
    X_train = (X_train / 255.0).astype(jnp.float32)
    X_test = (X_test / 255.0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = eg.Model(
        module=VariationalAutoEncoder(latent_size=LATENT_SIZE),
        loss=[BinaryCrossEntropy(from_logits=True, on="logits")],
        optimizer=optax.adam(1e-3),
        eager=eager,
    )
    assert model.module is not None

    model.summary(X_train[:64])

    # fit with datasets in memory
    history = model.fit(
        inputs=X_train,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_test,),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir)],
    )

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )

    eg.utils.plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot and save reconstruction results
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred["det_image"][i], cmap="gray")
        # tbwriter.add_figure("VAE Example", figure, epochs)

    # sample from the prior through the decoder
    model_decoder = eg.Model(model.module.decoder)

    z_samples = np.random.normal(size=(12, LATENT_SIZE))
    samples = model_decoder.predict(z_samples)
    samples = jax.nn.sigmoid(samples)

    # plot and save generative results
    # with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
    figure = plt.figure(figsize=(5, 12))
    plt.title("Generative Samples")
    for i in range(5):
        plt.subplot(2, 5, 2 * i + 1)
        plt.imshow(samples[i], cmap="gray")
        plt.subplot(2, 5, 2 * i + 2)
        plt.imshow(samples[i + 1], cmap="gray")
    # tbwriter.add_figure("VAE Generative Example", figure, epochs)

    plt.show()