def test_caching(self):
    """Check cache reuse and invalidation of the ``csv`` loader.

    Loading the same file twice must yield the same cache file and
    fingerprint. Passing explicit ``features`` or rewriting the file with
    different content must yield a different cache file and fingerprint.
    """
    n_rows = 10
    features = Features({"foo": Value("string"), "bar": Value("string")})

    def write_table(path, cells):
        # Use \n for newline. Windows automatically adds the \r when writing the file
        # see https://docs.python.org/3/library/os.html#os.linesep
        # n_rows + 1 identical lines are written; the assertions below expect
        # n_rows rows back.
        content = "\n".join(",".join(cells) for _ in range(n_rows + 1))
        # The context manager guarantees the data is flushed and the handle
        # closed before load_dataset reads the file.
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, "table.csv")
        write_table(csv_path, ["foo", "bar"])

        # First load: record the reference cache file and fingerprint.
        ds = load_dataset(
            "csv",
            data_files=csv_path,
            cache_dir=tmp_dir,
            split="train",
            keep_in_memory=False,
        )
        data_file = ds.cache_files[0]["filename"]
        fingerprint = ds._fingerprint
        self.assertEqual(len(ds), n_rows)
        del ds

        # Same file, same arguments: the cache must be reused.
        ds = load_dataset(
            "csv",
            data_files=csv_path,
            cache_dir=tmp_dir,
            split="train",
            keep_in_memory=False,
        )
        self.assertEqual(ds.cache_files[0]["filename"], data_file)
        self.assertEqual(ds._fingerprint, fingerprint)
        del ds

        # Same file, explicit features: a different cache entry is expected.
        ds = load_dataset(
            "csv",
            data_files=csv_path,
            cache_dir=tmp_dir,
            split="train",
            features=features,
            keep_in_memory=False,
        )
        self.assertNotEqual(ds.cache_files[0]["filename"], data_file)
        self.assertNotEqual(ds._fingerprint, fingerprint)
        del ds

        # Changed file content: a different cache entry is expected.
        write_table(csv_path, ["Foo", "Bar"])
        ds = load_dataset(
            "csv",
            data_files=csv_path,
            cache_dir=tmp_dir,
            split="train",
            keep_in_memory=False,
        )
        self.assertNotEqual(ds.cache_files[0]["filename"], data_file)
        self.assertNotEqual(ds._fingerprint, fingerprint)
        self.assertEqual(len(ds), n_rows)
        del ds
def test_features(self):
    """Check that the ``csv`` loader applies user-supplied ``features``.

    A table with ``n_cols`` columns is loaded once per candidate feature
    type; each time the resulting dataset must have ``n_rows`` rows and
    report exactly the features that were requested.
    """
    n_rows = 10
    n_cols = 3

    def get_features(feature_type):
        # Same feature type for every column; columns are named "0".."n_cols-1".
        return Features({str(i): feature_type for i in range(n_cols)})

    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, "table.csv")
        content = "\n".join(
            ",".join(str(i) for i in range(n_cols)) for _ in range(n_rows + 1)
        )
        # Close the file before loading so the content is fully on disk.
        with open(csv_path, "w", encoding="utf-8") as f:
            f.write(content)

        candidate_types = [
            Value("float64"),
            Value("int8"),
            ClassLabel(num_classes=n_cols),
        ]
        for feature_type in candidate_types:
            features = get_features(feature_type)
            ds = load_dataset(
                "csv",
                data_files=csv_path,
                cache_dir=tmp_dir,
                split="train",
                features=features,
            )
            self.assertEqual(len(ds), n_rows)
            self.assertDictEqual(ds.features, features)
            del ds
def main(
    logdir: str = "runs",
    steps_per_epoch: tp.Optional[int] = None,
    epochs: int = 10,
    batch_size: int = 32,
):
    """Train the MNIST CNN in local and distributed mode and compare them.

    The model is trained three times (local, local, distributed); the local
    run is repeated so that the first run absorbs any warm-up cost. After
    each run the evaluation result and the elapsed wall-clock time are
    printed, and at the end the accuracy per mode is printed.

    Args:
        logdir: Accepted for command-line compatibility; not used here.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit`` unchanged in both modes.
    """
    platform = jax.local_devices()[0].platform
    print('devices ', jax.devices())
    print('platform ', platform)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    # Append a trailing axis of size 1 to the image arrays.
    X_train = dataset["train"]["image"][..., None]
    y_train = dataset["train"]["label"]
    X_test = dataset["test"]["image"][..., None]
    y_test = dataset["test"]["label"]

    # Keyed by the `distributed` flag; the second local run overwrites the first.
    accuracies = {}

    # we run distributed=False twice to remove any initial warmup costs
    for distributed in [False, False, True]:
        print(f'Distributed training = {distributed}')
        start_time = time.time()

        model = eg.Model(
            module=CNN(),
            loss=eg.losses.Crossentropy(),
            metrics=eg.metrics.Accuracy(),
            optimizer=optax.adam(1e-3),
            seed=42,
        )
        if distributed:
            model = model.distributed()

        # The same batch size is used in both modes (a per-device split was
        # left disabled in the original code).
        model.fit(
            inputs=X_train,
            labels=y_train,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            batch_size=batch_size,
            validation_data=(X_test, y_test),
            shuffle=True,
            verbose=3,
        )

        ev = model.evaluate(x=X_test, y=y_test, verbose=1)
        print('eval ', ev)
        accuracies[distributed] = ev['accuracy']

        end_time = time.time()
        # Print the elapsed seconds as a number (previously a set literal
        # such as "{12.3}" was printed).
        print('time taken ', end_time - start_time)

    print(accuracies)
def test_sep(self):
    """Check that the ``sep`` argument of the ``csv`` loader is honoured.

    A comma-separated and a tab-separated table are each loaded with their
    matching separator and must yield ``n_cols`` columns. Loading the
    comma-separated table with a tab separator must yield a single column.
    """
    n_rows = 10
    n_cols = 3

    def write_table(path, sep):
        # n_rows + 1 identical lines; the assertions below expect n_rows rows.
        content = "\n".join(
            sep.join(str(i) for i in range(n_cols)) for _ in range(n_rows + 1)
        )
        # Close the handle so the file is complete before it is loaded.
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    with tempfile.TemporaryDirectory() as tmp_dir:
        comma_path = os.path.join(tmp_dir, "table_comma.csv")
        tab_path = os.path.join(tmp_dir, "table_tab.csv")
        write_table(comma_path, ",")
        write_table(tab_path, "\t")

        # Comma file, comma separator.
        ds = load_dataset(
            "csv",
            data_files=comma_path,
            cache_dir=tmp_dir,
            split="train",
            sep=",",
        )
        self.assertEqual(len(ds), n_rows)
        self.assertEqual(len(ds.column_names), n_cols)
        del ds

        # Tab file, tab separator.
        ds = load_dataset(
            "csv",
            data_files=tab_path,
            cache_dir=tmp_dir,
            split="train",
            sep="\t",
        )
        self.assertEqual(len(ds), n_rows)
        self.assertEqual(len(ds.column_names), n_cols)
        del ds

        # Comma file, tab separator: each line is one unsplit column.
        ds = load_dataset(
            "csv",
            data_files=comma_path,
            cache_dir=tmp_dir,
            split="train",
            sep="\t",
        )
        self.assertEqual(len(ds), n_rows)
        self.assertEqual(len(ds.column_names), 1)
        del ds
def test_caching(self):
    """Check cache reuse and invalidation of the ``text`` loader.

    Loading the same file twice must yield the same cache file and
    fingerprint; rewriting the file with different content must yield a
    different cache file and fingerprint.
    """
    n_samples = 10

    def write_lines(path, word):
        # Use \n for newline. Windows automatically adds the \r when writing the file
        # see https://docs.python.org/3/library/os.html#os.linesep
        content = "\n".join(word for _ in range(n_samples))
        # The context manager guarantees the data is flushed and the handle
        # closed before load_dataset reads the file.
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    with tempfile.TemporaryDirectory() as tmp_dir:
        text_path = os.path.join(tmp_dir, "text.txt")
        write_lines(text_path, "foo")

        # First load: record the reference cache file and fingerprint.
        ds = load_dataset(
            "text",
            data_files=text_path,
            cache_dir=tmp_dir,
            split="train",
            keep_in_memory=False,
        )
        data_file = ds.cache_files[0]["filename"]
        fingerprint = ds._fingerprint
        self.assertEqual(len(ds), n_samples)
        del ds

        # Same file, same arguments: the cache must be reused.
        ds = load_dataset(
            "text",
            data_files=text_path,
            cache_dir=tmp_dir,
            split="train",
            keep_in_memory=False,
        )
        self.assertEqual(ds.cache_files[0]["filename"], data_file)
        self.assertEqual(ds._fingerprint, fingerprint)
        del ds

        # Changed file content: a different cache entry is expected.
        write_lines(text_path, "bar")
        ds = load_dataset(
            "text",
            data_files=text_path,
            cache_dir=tmp_dir,
            split="train",
            keep_in_memory=False,
        )
        self.assertNotEqual(ds.cache_files[0]["filename"], data_file)
        self.assertNotEqual(ds._fingerprint, fingerprint)
        self.assertEqual(len(ds), n_samples)
        del ds
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: int = 200,
    epochs: int = 100,
    batch_size: int = 64,
):
    """Fit ``Model`` on MNIST and plot the training history.

    Args:
        debug: When true, start a debugpy server on port 5678 and block
            until a client attaches.
        eager: Forwarded to the ``Model`` constructor.
        logdir: Base directory for TensorBoard logs; a timestamped
            sub-directory is used for this run.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    # One log directory per run, named after the start time.
    run_stamp = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, run_stamp)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.stack(dataset["train"]["image"])
    y_train = dataset["train"]["label"]
    X_test = np.stack(dataset["test"]["image"])
    y_test = dataset["test"]["label"]

    # Report the shape and dtype of every array.
    for label, array in (
        ("X_train:", X_train),
        ("y_train:", y_train),
        ("X_test:", X_test),
        ("y_test:", y_test),
    ):
        print(label, array.shape, array.dtype)

    model = Model(
        features_out=10,
        optimizer=optax.adam(1e-3),
        eager=eager,
    )

    tensorboard = eg.callbacks.TensorBoard(logdir=logdir)
    history = model.fit(
        inputs=X_train,
        labels=y_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
        shuffle=True,
        callbacks=[tensorboard],
    )

    eg.utils.plot_history(history)
def __init__(self, training: bool = True):
    """Load MNIST and keep the images and labels of one split.

    Args:
        training: If true, use the ``"train"`` split; otherwise use the
            ``"test"`` split.

    Sets ``self.x`` to the stacked image array and ``self.y`` to the labels
    of the chosen split.
    """
    dataset = load_dataset("mnist")
    dataset.set_format("np")

    # Only materialise the split that is actually kept, instead of stacking
    # both image arrays and discarding one.
    split = dataset["train" if training else "test"]
    self.x = np.stack(split["image"])
    self.y = split["label"]
def test_load_real_dataset(self, dataset_name):
    """Load a local dataset script with its first config and check every split is non-empty.

    Args:
        dataset_name: Directory name under ``./datasets/`` of the dataset
            to load.
    """
    path = "./datasets/" + dataset_name
    dataset_module = dataset_module_factory(
        path, download_config=DownloadConfig(local_files_only=True)
    )
    builder_cls = import_main_class(dataset_module.module_path)
    # Use the first declared config, or no config name when none is declared.
    name = builder_cls.BUILDER_CONFIGS[0].name if builder_cls.BUILDER_CONFIGS else None

    with tempfile.TemporaryDirectory() as temp_cache_dir:
        # NOTE(review): a sibling test passes DownloadMode.FORCE_REDOWNLOAD;
        # confirm which enum name this file's `datasets` version provides.
        dataset = load_dataset(
            path,
            name=name,
            cache_dir=temp_cache_dir,
            download_mode=GenerateMode.FORCE_REDOWNLOAD,
        )
        for split in dataset.keys():
            # assertGreater reports the actual length on failure.
            self.assertGreater(
                len(dataset[split]), 0, msg=f"split {split!r} is empty"
            )
        # Drop the reference before the temporary cache directory is removed.
        del dataset
def test_load_real_dataset_all_configs(self, dataset_name):
    """Load a local dataset script with every config and check every split is non-empty.

    Args:
        dataset_name: Directory name under ``./datasets/`` of the dataset
            to load.
    """
    path = "./datasets/" + dataset_name
    dataset_module = dataset_module_factory(
        path, download_config=DownloadConfig(local_files_only=True)
    )
    builder_cls = import_main_class(dataset_module.module_path)
    # Every declared config name, or a single unnamed config when none is declared.
    config_names = [config.name for config in builder_cls.BUILDER_CONFIGS] or [None]

    for name in config_names:
        # A fresh cache directory per config.
        with tempfile.TemporaryDirectory() as temp_cache_dir:
            dataset = load_dataset(
                path,
                name=name,
                cache_dir=temp_cache_dir,
                download_mode=DownloadMode.FORCE_REDOWNLOAD,
            )
            for split in dataset.keys():
                # assertGreater reports the actual length on failure.
                self.assertGreater(
                    len(dataset[split]),
                    0,
                    msg=f"config {name!r}, split {split!r} is empty",
                )
            # Drop the reference before the temporary cache directory is removed.
            del dataset
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: tp.Optional[int] = None,
    epochs: int = 100,
    batch_size: int = 32,
    distributed: bool = False,
):
    """Train a CNN on MNIST, evaluate it and plot nine random test predictions.

    Args:
        debug: When true, start a debugpy server on port 5678 and block
            until a client attaches.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory for TensorBoard logs; a timestamped
            sub-directory is used for this run.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``.
        distributed: When true, train with ``model.distributed()``.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    # Append a trailing axis of size 1 to the stacked image arrays.
    X_train = np.stack(dataset["train"]["image"])[..., None]
    y_train = dataset["train"]["label"]
    X_test = np.stack(dataset["test"]["image"])[..., None]
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    model = eg.Model(
        module=CNN(),
        loss=eg.losses.Crossentropy(),
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adam(1e-3),
        eager=eager,
    )

    if distributed:
        model = model.distributed()

    # show model summary
    model.summary(X_train[:64], depth=1)

    history = model.fit(
        inputs=X_train,
        labels=y_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir)],
    )

    eg.utils.plot_history(history)

    print(model.evaluate(x=X_test, y=y_test))

    # get random samples; bound by the actual test-set size rather than a
    # hard-coded 10000 so the indices are always valid
    idxs = np.random.randint(0, len(X_test), size=(9,))
    x_sample = X_test[idxs]

    # get predictions (switch back to a local model first)
    model = model.local()
    y_pred = model.predict(x=x_sample)

    # plot results on a 3x3 grid, titled with the arg-max of each prediction
    plt.figure(figsize=(12, 12))
    for k in range(9):
        plt.subplot(3, 3, k + 1)
        plt.title(f"{np.argmax(y_pred[k])}")
        plt.imshow(x_sample[k], cmap="gray")

    plt.show()
def main(
    steps_per_epoch: tp.Optional[int] = None,
    batch_size: int = 32,
    epochs: int = 50,
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
):
    """Train a variational auto-encoder on MNIST and plot reconstructions and samples.

    Args:
        steps_per_epoch: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        debug: When true, start a debugpy server on port 5678 and block
            until a client attaches.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory for TensorBoard logs; a timestamped
            sub-directory is used for this run.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8)
    X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8)

    # Scale the uint8 pixel values to floats in [0, 1] (this is a rescale,
    # not a binarization).
    X_train = (X_train / 255.0).astype(jnp.float32)
    X_test = (X_test / 255.0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = eg.Model(
        module=VariationalAutoEncoder(latent_size=LATENT_SIZE),
        loss=[BinaryCrossEntropy(from_logits=True, on="logits")],
        optimizer=optax.adam(1e-3),
        eager=eager,
    )
    assert model.module is not None

    model.summary(X_train[:64])

    # Fit with datasets in memory; no labels are passed.
    history = model.fit(
        inputs=X_train,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_test,),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir)],
    )

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )

    eg.utils.plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot and save results: inputs on the top row, reconstructions below
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred["det_image"][i], cmap="gray")
        # tbwriter.add_figure("VAE Example", figure, epochs)

    # sample: decode random latent vectors with the trained decoder
    model_decoder = eg.Model(model.module.decoder)

    z_samples = np.random.normal(size=(12, LATENT_SIZE))
    samples = model_decoder.predict(z_samples)
    samples = jax.nn.sigmoid(samples)

    # plot and save results
    # with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
    figure = plt.figure(figsize=(5, 12))
    plt.title("Generative Samples")
    # Show ten distinct samples, one per grid cell. Previously cells showed
    # samples[i] and samples[i + 1], so several images appeared twice.
    for k in range(10):
        plt.subplot(2, 5, k + 1)
        plt.imshow(samples[k], cmap="gray")
    # tbwriter.add_figure("VAE Generative Example", figure, epochs)

    plt.show()
def distributed_load_dataset(args):
    """Call ``load_dataset`` from a single packed argument.

    Args:
        args: A 3-item sequence ``(data_name, tmp_dir, datafiles)``; the
            items are passed to ``load_dataset`` as the dataset name,
            ``cache_dir`` and ``data_files`` respectively.

    Returns:
        Whatever ``load_dataset`` returns.
    """
    name, cache_dir, files = args
    return load_dataset(name, cache_dir=cache_dir, data_files=files)
def main(
    steps_per_epoch: int = 200,
    batch_size: int = 64,
    epochs: int = 50,
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
):
    """Train a Haiku-based VAE on binarized MNIST and plot five reconstructions.

    Args:
        steps_per_epoch: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``; also the number of rows
            passed to ``model.summary``.
        epochs: Forwarded to ``model.fit``.
        debug: When true, start a debugpy server on port 5678 and block
            until a client attaches.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory for TensorBoard logs; a timestamped
            sub-directory is used for this run.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    run_stamp = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, run_stamp)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    raw_train = np.stack(dataset["train"]["image"])
    raw_test = np.stack(dataset["test"]["image"])

    # Binarize: every non-zero pixel becomes 1.0, everything else 0.0.
    X_train = (raw_train > 0).astype(jnp.float32)
    X_test = (raw_test > 0).astype(jnp.float32)

    for label, array in (("X_train:", X_train), ("X_test:", X_test)):
        print(label, array.shape, array.dtype)

    def forward(x: jnp.ndarray):
        vae = VAE(latent_size=LATENT_SIZE)
        return vae(x)

    losses = [
        BinaryCrossEntropy(on="logits"),
        KL(weight=0.1),
    ]
    model = eg.Model(
        module=hk.transform_with_state(forward),
        loss=losses,
        optimizer=optax.adam(1e-3),
        eager=eager,
    )

    model.summary(X_train[:batch_size])

    # Train on the in-memory arrays; no labels are passed.
    history = model.fit(
        inputs=X_train,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_test,),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir)],
    )

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )

    eg.utils.plot_history(history)

    # Pick five random test images.
    idxs = np.random.randint(0, len(X_test), size=(5,))
    x_sample = X_test[idxs]

    # Reconstruct them and squash the logits with a sigmoid.
    preds = model.predict(x=x_sample)
    y_pred = jax.nn.sigmoid(preds["logits"])

    # Top row: inputs. Bottom row: reconstructions.
    plt.figure(figsize=(12, 12))
    for col in range(5):
        plt.subplot(2, 5, col + 1)
        plt.imshow(x_sample[col], cmap="gray")
        plt.subplot(2, 5, col + 6)
        plt.imshow(y_pred[col], cmap="gray")

    plt.show()
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: int = 200,
    epochs: int = 100,
    batch_size: int = 64,
):
    """Train a CNN on MNIST through PyTorch data loaders, save, reload and evaluate it.

    Args:
        debug: When true, start a debugpy server on port 5678 and block
            until a client attaches.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory for TensorBoard logs; a timestamped
            sub-directory is used for this run.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        batch_size: Batch size of both data loaders.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    # Append a trailing axis of size 1 to the stacked image arrays.
    X_train = np.stack(dataset["train"]["image"])[..., None]
    y_train = dataset["train"]["label"]
    X_test = np.stack(dataset["test"]["image"])[..., None]
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    model = eg.Model(
        module=CNN(),
        loss=eg.losses.Crossentropy(),
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adam(1e-3),
        eager=eager,
    )

    # show summary
    model.summary(X_train[:64])

    # Wrap the numpy arrays in torch datasets; only the training loader shuffles.
    train_dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test))
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    history = model.fit(
        train_dataloader,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=test_dataloader,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir)],
    )

    eg.utils.plot_history(history)

    # Round-trip the model through disk before evaluating it.
    model.save("models/conv")
    model = eg.load("models/conv")

    print(model.evaluate(x=X_test, y=y_test))

    # get random samples; bound by the actual test-set size rather than a
    # hard-coded 10000 so the indices are always valid
    idxs = np.random.randint(0, len(X_test), size=(9,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results on a 3x3 grid, titled with the arg-max of each prediction
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for k in range(9):
            plt.subplot(3, 3, k + 1)
            plt.title(f"{np.argmax(y_pred[k])}")
            plt.imshow(x_sample[k], cmap="gray")
        # tbwriter.add_figure("Conv classifier", figure, 100)

    plt.show()
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: int = 200,
    epochs: int = 100,
    batch_size: int = 64,
):
    """Train an MLP auto-encoder on MNIST and plot five inputs next to their outputs.

    Args:
        debug: When true, start a debugpy server on port 5678 and block
            until a client attaches.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory for TensorBoard logs; a timestamped
            sub-directory is used for this run.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.stack(dataset["train"]["image"])
    X_test = np.stack(dataset["test"]["image"])

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = eg.Model(
        module=MLP(n1=256, n2=64),
        loss=MeanSquaredError(),
        optimizer=optax.rmsprop(0.001),
        eager=eager,
    )

    model.summary(X_train[:64])

    # Notice we are not passing `y`
    history = model.fit(
        inputs=X_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test,),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir, update_freq=300)],
    )

    eg.utils.plot_history(history)

    # get random samples; bound by the actual test-set size rather than a
    # hard-coded 10000 so the indices are always valid
    idxs = np.random.randint(0, len(X_test), size=(5,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results: inputs on the top row, model outputs on the bottom row.
    # The writer is opened but nothing is logged to it in this function.
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred[i], cmap="gray")

    plt.show()
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: int = 200,
    batch_size: int = 64,
    epochs: int = 100,
):
    """Train an MLP classifier on MNIST and plot nine random test predictions.

    Args:
        debug: When true, start a debugpy server on port 5678 and block
            until a client attaches.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory for TensorBoard logs; a timestamped
            sub-directory is used for this run.
        steps_per_epoch: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.stack(dataset["train"]["image"])
    y_train = dataset["train"]["label"]
    X_test = np.stack(dataset["test"]["image"])
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    model = eg.Model(
        module=MLP(n1=300, n2=100),
        # Cross-entropy plus an L2 regularization term.
        loss=[
            eg.losses.Crossentropy(),
            eg.regularizers.L2(l=1e-4),
        ],
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adamw(1e-3),
        eager=eager,
    )

    model.summary(X_train[:64])

    history = model.fit(
        inputs=X_train,
        labels=y_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir)],
    )

    eg.utils.plot_history(history)

    # get random samples; bound by the actual test-set size rather than a
    # hard-coded 10000 so the indices are always valid
    idxs = np.random.randint(0, len(X_test), size=(9,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results on a 3x3 grid, titled with the arg-max of each prediction
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for k in range(9):
            plt.subplot(3, 3, k + 1)
            plt.title(f"{np.argmax(y_pred[k])}")
            plt.imshow(x_sample[k], cmap="gray")
        # tbwriter.add_figure("Predictions", figure, 100)

    plt.show()

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )