Пример #1
0
    def test_caching(self):
        """Check when the CSV loader reuses its cache and when it rebuilds it.

        Reloading an unchanged file must yield the same cache file and
        fingerprint. Passing explicit ``features`` or rewriting the file with
        different content must yield a different cache file and fingerprint.
        """
        n_rows = 10

        features = Features({"foo": Value("string"), "bar": Value("string")})

        with tempfile.TemporaryDirectory() as tmp_dir:
            csv_path = os.path.join(tmp_dir, "table.csv")

            def write_table(cells):
                # Use \n for newline. Windows automatically adds the \r when writing the file
                # see https://docs.python.org/3/library/os.html#os.linesep
                # The context manager guarantees the file is flushed and closed
                # before it is loaded. n_rows + 1 identical lines are written;
                # the assertions below expect n_rows rows to be loaded.
                with open(csv_path, "w", encoding="utf-8") as csv_file:
                    csv_file.write("\n".join(",".join(cells)
                                             for _ in range(n_rows + 1)))

            def load(**extra_kwargs):
                return load_dataset(
                    "csv",
                    data_files=csv_path,
                    cache_dir=tmp_dir,
                    split="train",
                    keep_in_memory=False,
                    **extra_kwargs,
                )

            # First load: remember where the data was cached.
            write_table(["foo", "bar"])
            ds = load()
            data_file = ds.cache_files[0]["filename"]
            fingerprint = ds._fingerprint
            self.assertEqual(len(ds), n_rows)
            del ds

            # Same file, same arguments: the cache must be reused.
            ds = load()
            self.assertEqual(ds.cache_files[0]["filename"], data_file)
            self.assertEqual(ds._fingerprint, fingerprint)
            del ds

            # Same file, explicit features: a different cache entry is expected.
            ds = load(features=features)
            self.assertNotEqual(ds.cache_files[0]["filename"], data_file)
            self.assertNotEqual(ds._fingerprint, fingerprint)
            del ds

            # Different file content: a different cache entry is expected.
            write_table(["Foo", "Bar"])
            ds = load()
            self.assertNotEqual(ds.cache_files[0]["filename"], data_file)
            self.assertNotEqual(ds._fingerprint, fingerprint)
            self.assertEqual(len(ds), n_rows)
            del ds
Пример #2
0
    def test_features(self):
        """Check that explicitly requested features are applied to a loaded CSV.

        The same table is loaded once per feature type; each time the dataset
        must have ``n_rows`` rows and report exactly the requested features.
        """
        n_rows = 10
        n_cols = 3

        def get_features(feature_type):
            # One column per index, all sharing the same feature type.
            return Features({str(i): feature_type for i in range(n_cols)})

        with tempfile.TemporaryDirectory() as tmp_dir:
            csv_path = os.path.join(tmp_dir, "table.csv")
            row = ",".join([str(i) for i in range(n_cols)])
            # Close the file before loading so its content is fully on disk.
            with open(csv_path, "w", encoding="utf-8") as csv_file:
                csv_file.write("\n".join(row for _ in range(n_rows + 1)))

            for feature_type in [
                    Value("float64"),
                    Value("int8"),
                    ClassLabel(num_classes=n_cols)
            ]:
                features = get_features(feature_type)
                ds = load_dataset(
                    "csv",
                    data_files=csv_path,
                    cache_dir=tmp_dir,
                    split="train",
                    features=features,
                )
                self.assertEqual(len(ds), n_rows)
                self.assertDictEqual(ds.features, features)
                del ds
Пример #3
0
def main(logdir: str = "runs",
         steps_per_epoch: tp.Optional[int] = None,
         epochs: int = 10,
         batch_size: int = 32):
    """Train a CNN on MNIST with and without ``model.distributed()``.

    The model is trained three times (non-distributed twice, then
    distributed); each run prints its evaluation result and elapsed time, and
    the final accuracies are printed keyed by the ``distributed`` flag.

    Args:
        logdir: Base directory; a timestamped sub-path is computed from it.
            NOTE(review): the resulting path is not used further in this
            function.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit`` for every run.
    """

    platform = jax.local_devices()[0].platform
    ndevices = len(jax.devices())
    print('devices ', jax.devices())
    print('platform ', platform)

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = dataset["train"]["image"][..., None]
    y_train = dataset["train"]["label"]
    X_test = dataset["test"]["image"][..., None]
    y_test = dataset["test"]["label"]

    accuracies = {}
    # we run distributed=False twice to remove any initial warmup costs
    for distributed in [False, False, True]:
        print(f'Distributed training = {distributed}')
        start_time = time.time()

        model = eg.Model(module=CNN(),
                         loss=eg.losses.Crossentropy(),
                         metrics=eg.metrics.Accuracy(),
                         optimizer=optax.adam(1e-3),
                         seed=42)

        if distributed:
            model = model.distributed()
            bs = batch_size  #int(batch_size / ndevices)
        else:
            bs = batch_size

        #model.summary(X_train[:64], depth=1)

        history = model.fit(inputs=X_train,
                            labels=y_train,
                            epochs=epochs,
                            steps_per_epoch=steps_per_epoch,
                            batch_size=bs,
                            validation_data=(X_test, y_test),
                            shuffle=True,
                            verbose=3)

        ev = model.evaluate(x=X_test, y=y_test, verbose=1)
        print('eval ', ev)
        # The second non-distributed run overwrites the warm-up run's entry.
        accuracies[distributed] = ev['accuracy']

        end_time = time.time()
        # Print the elapsed seconds as a number (previously wrapped in a set
        # literal, which printed e.g. "{1.23}").
        print('time taken ', end_time - start_time)

    print(accuracies)
Пример #4
0
    def test_sep(self):
        """Check that the ``sep`` argument controls how CSV columns are split.

        A comma-separated and a tab-separated table are each loaded with the
        matching separator and must yield ``n_cols`` columns; loading the
        comma-separated table with a tab separator must yield one column.
        """
        n_rows = 10
        n_cols = 3

        with tempfile.TemporaryDirectory() as tmp_dir:
            comma_path = os.path.join(tmp_dir, "table_comma.csv")
            tab_path = os.path.join(tmp_dir, "table_tab.csv")
            cells = [str(i) for i in range(n_cols)]

            # Write both tables inside context managers so they are flushed
            # and closed before being loaded.
            with open(comma_path, "w", encoding="utf-8") as comma_file:
                comma_file.write("\n".join(",".join(cells)
                                           for _ in range(n_rows + 1)))
            with open(tab_path, "w", encoding="utf-8") as tab_file:
                tab_file.write("\n".join("\t".join(cells)
                                         for _ in range(n_rows + 1)))

            # Comma-separated file read with a comma separator.
            ds = load_dataset(
                "csv",
                data_files=comma_path,
                cache_dir=tmp_dir,
                split="train",
                sep=",",
            )
            self.assertEqual(len(ds), n_rows)
            self.assertEqual(len(ds.column_names), n_cols)
            del ds

            # Tab-separated file read with a tab separator.
            ds = load_dataset(
                "csv",
                data_files=tab_path,
                cache_dir=tmp_dir,
                split="train",
                sep="\t",
            )
            self.assertEqual(len(ds), n_rows)
            self.assertEqual(len(ds.column_names), n_cols)
            del ds

            # Mismatched separator: every line ends up in a single column.
            ds = load_dataset(
                "csv",
                data_files=comma_path,
                cache_dir=tmp_dir,
                split="train",
                sep="\t",
            )
            self.assertEqual(len(ds), n_rows)
            self.assertEqual(len(ds.column_names), 1)
            del ds
Пример #5
0
    def test_caching(self):
        """Check when the text loader reuses its cache and when it rebuilds it.

        Reloading an unchanged file must yield the same cache file and
        fingerprint; rewriting the file with different content must yield a
        different cache file and fingerprint.
        """
        n_samples = 10
        with tempfile.TemporaryDirectory() as tmp_dir:
            text_path = os.path.join(tmp_dir, "text.txt")

            def write_lines(line):
                # Use \n for newline. Windows automatically adds the \r when writing the file
                # see https://docs.python.org/3/library/os.html#os.linesep
                # The context manager guarantees the file is flushed and closed
                # before it is loaded.
                with open(text_path, "w", encoding="utf-8") as text_file:
                    text_file.write("\n".join(line
                                              for _ in range(n_samples)))

            def load():
                return load_dataset(
                    "text",
                    data_files=text_path,
                    cache_dir=tmp_dir,
                    split="train",
                    keep_in_memory=False,
                )

            # First load: remember where the data was cached.
            write_lines("foo")
            ds = load()
            data_file = ds.cache_files[0]["filename"]
            fingerprint = ds._fingerprint
            self.assertEqual(len(ds), n_samples)
            del ds

            # Same file, same arguments: the cache must be reused.
            ds = load()
            self.assertEqual(ds.cache_files[0]["filename"], data_file)
            self.assertEqual(ds._fingerprint, fingerprint)
            del ds

            # Different file content: a different cache entry is expected.
            write_lines("bar")
            ds = load()
            self.assertNotEqual(ds.cache_files[0]["filename"], data_file)
            self.assertNotEqual(ds._fingerprint, fingerprint)
            self.assertEqual(len(ds), n_samples)
            del ds
Пример #6
0
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: int = 200,
    epochs: int = 100,
    batch_size: int = 64,
):
    """Fit ``Model`` on MNIST and plot the training history.

    Args:
        debug: When true, wait for a debugpy client on port 5678 first.
        eager: Forwarded to ``Model``.
        logdir: Base directory; TensorBoard logs go to a timestamped
            sub-directory of it.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    # Each run logs into its own timestamped directory.
    run_stamp = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, run_stamp)

    mnist = load_dataset("mnist")
    mnist.set_format("np")
    X_train = np.stack(mnist["train"]["image"])
    y_train = mnist["train"]["label"]
    X_test = np.stack(mnist["test"]["image"])
    y_test = mnist["test"]["label"]

    # Report shape and dtype of every array that goes into training.
    for label, array in (
        ("X_train:", X_train),
        ("y_train:", y_train),
        ("X_test:", X_test),
        ("y_test:", y_test),
    ):
        print(label, array.shape, array.dtype)

    model = Model(features_out=10, optimizer=optax.adam(1e-3), eager=eager)

    tensorboard = eg.callbacks.TensorBoard(logdir=logdir)
    history = model.fit(
        inputs=X_train,
        labels=y_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
        shuffle=True,
        callbacks=[tensorboard],
    )

    eg.utils.plot_history(history)
Пример #7
0
    def __init__(self, training: bool = True):
        """Load MNIST and keep one split as ``self.x`` / ``self.y``.

        Args:
            training: Keep the train split when true, the test split otherwise.
        """
        mnist = load_dataset("mnist")
        mnist.set_format("np")

        # Both splits are materialised regardless of which one is kept.
        train_images = np.stack(mnist["train"]["image"])
        train_labels = mnist["train"]["label"]
        test_images = np.stack(mnist["test"]["image"])
        test_labels = mnist["test"]["label"]

        if not training:
            self.x, self.y = test_images, test_labels
        else:
            self.x, self.y = train_images, train_labels
Пример #8
0
 def test_load_real_dataset(self, dataset_name):
     """Load a local dataset script with its first config and check every split is non-empty."""
     path = "./datasets/" + dataset_name
     offline_config = DownloadConfig(local_files_only=True)
     dataset_module = dataset_module_factory(path,
                                             download_config=offline_config)
     builder_cls = import_main_class(dataset_module.module_path)

     # Use the first declared config, or no config name when none is declared.
     if builder_cls.BUILDER_CONFIGS:
         name = builder_cls.BUILDER_CONFIGS[0].name
     else:
         name = None

     with tempfile.TemporaryDirectory() as temp_cache_dir:
         load_kwargs = dict(
             name=name,
             cache_dir=temp_cache_dir,
             download_mode=GenerateMode.FORCE_REDOWNLOAD,
         )
         dataset = load_dataset(path, **load_kwargs)
         for split_name in dataset.keys():
             self.assertTrue(len(dataset[split_name]) > 0)
         del dataset
Пример #9
0
 def test_load_real_dataset_all_configs(self, dataset_name):
     """Load a local dataset script once per declared config and check every split is non-empty."""
     path = "./datasets/" + dataset_name
     offline_config = DownloadConfig(local_files_only=True)
     dataset_module = dataset_module_factory(path,
                                             download_config=offline_config)
     builder_cls = import_main_class(dataset_module.module_path)

     # Fall back to a single unnamed config when the builder declares none.
     if len(builder_cls.BUILDER_CONFIGS) > 0:
         config_names = [cfg.name for cfg in builder_cls.BUILDER_CONFIGS]
     else:
         config_names = [None]

     for config_name in config_names:
         # A fresh cache directory per config.
         with tempfile.TemporaryDirectory() as temp_cache_dir:
             dataset = load_dataset(
                 path,
                 name=config_name,
                 cache_dir=temp_cache_dir,
                 download_mode=DownloadMode.FORCE_REDOWNLOAD)
             for split_name in dataset.keys():
                 self.assertTrue(len(dataset[split_name]) > 0)
             del dataset
Пример #10
0
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: tp.Optional[int] = None,
    epochs: int = 100,
    batch_size: int = 32,
    distributed: bool = False,
):
    """Train a CNN classifier on MNIST, evaluate it and plot nine predictions.

    Args:
        debug: When true, wait for a debugpy client on port 5678 first.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory; TensorBoard logs go to a timestamped
            sub-directory of it.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``.
        distributed: When true, train with ``model.distributed()``.
    """

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    # A trailing axis is appended to the stacked images.
    X_train = np.stack(dataset["train"]["image"])[..., None]
    y_train = dataset["train"]["label"]
    X_test = np.stack(dataset["test"]["image"])[..., None]
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    model = eg.Model(
        module=CNN(),
        loss=eg.losses.Crossentropy(),
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adam(1e-3),
        eager=eager,
    )

    if distributed:
        model = model.distributed()

    # show model summary
    model.summary(X_train[:64], depth=1)

    history = model.fit(
        inputs=X_train,
        labels=y_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir)],
    )

    eg.utils.plot_history(history)

    print(model.evaluate(x=X_test, y=y_test))

    # get random samples; bound the indices by the actual test-set size
    # instead of a hard-coded count
    idxs = np.random.randint(0, len(X_test), size=(9, ))
    x_sample = X_test[idxs]

    # get predictions
    model = model.local()
    y_pred = model.predict(x=x_sample)

    # plot results on a 3x3 grid, titled with the arg-max of each prediction
    plt.figure(figsize=(12, 12))
    for i in range(3):
        for j in range(3):
            k = 3 * i + j
            plt.subplot(3, 3, k + 1)

            plt.title(f"{np.argmax(y_pred[k])}")
            plt.imshow(x_sample[k], cmap="gray")

    plt.show()
Пример #11
0
def main(
    steps_per_epoch: tp.Optional[int] = None,
    batch_size: int = 32,
    epochs: int = 50,
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
):
    """Train a variational auto-encoder on MNIST and plot its outputs.

    After training, five test images are plotted next to the model's
    ``"det_image"`` predictions, and ten images decoded from random latent
    vectors are plotted in a second figure.

    Args:
        steps_per_epoch: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        debug: When true, wait for a debugpy client on port 5678 first.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory; TensorBoard logs go to a timestamped
            sub-directory of it.
    """

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8)
    X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8)

    # Scale the uint8 pixel values to floats in [0, 1]
    X_train = (X_train / 255.0).astype(jnp.float32)
    X_test = (X_test / 255.0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = eg.Model(
        module=VariationalAutoEncoder(latent_size=LATENT_SIZE),
        loss=[BinaryCrossEntropy(from_logits=True, on="logits")],
        optimizer=optax.adam(1e-3),
        eager=eager,
    )
    assert model.module is not None

    model.summary(X_train[:64])

    # Fit with datasets in memory
    history = model.fit(
        inputs=X_train,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_test, ),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir)],
    )

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )

    eg.utils.plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot and save results: inputs on the top row, predictions below
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred["det_image"][i], cmap="gray")
        # tbwriter.add_figure("VAE Example", figure, epochs)

    # sample
    model_decoder = eg.Model(model.module.decoder)

    z_samples = np.random.normal(size=(12, LATENT_SIZE))
    samples = model_decoder.predict(z_samples)
    samples = jax.nn.sigmoid(samples)

    # plot and save results
    # with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
    figure = plt.figure(figsize=(5, 12))
    plt.title("Generative Samples")
    for i in range(5):
        # Each iteration fills two adjacent cells of the 2x5 grid with two
        # distinct samples (2*i and 2*i + 1), so no sample is shown twice.
        plt.subplot(2, 5, 2 * i + 1)
        plt.imshow(samples[2 * i], cmap="gray")
        plt.subplot(2, 5, 2 * i + 2)
        plt.imshow(samples[2 * i + 1], cmap="gray")
    # # tbwriter.add_figure("VAE Generative Example", figure, epochs)

    plt.show()
Пример #12
0
def distributed_load_dataset(args):
    """Load a dataset from a packed ``(data_name, tmp_dir, datafiles)`` argument.

    ``tmp_dir`` is used as the cache directory and ``datafiles`` as the data
    files of the ``load_dataset`` call.
    """
    name, cache_dir, files = args
    return load_dataset(name, cache_dir=cache_dir, data_files=files)
Пример #13
0
def main(
    steps_per_epoch: int = 200,
    batch_size: int = 64,
    epochs: int = 50,
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
):
    """Train a Haiku-based VAE on binarized MNIST and plot reconstructions.

    Args:
        steps_per_epoch: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``; also sizes the summary batch.
        epochs: Forwarded to ``model.fit``.
        debug: When true, wait for a debugpy client on port 5678 first.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory; TensorBoard logs go to a timestamped
            sub-directory of it.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    # Each run logs into its own timestamped directory.
    run_stamp = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, run_stamp)

    mnist = load_dataset("mnist")
    mnist.set_format("np")
    raw_train = np.stack(mnist["train"]["image"])
    raw_test = np.stack(mnist["test"]["image"])

    # Binarize: every non-zero pixel becomes 1.0, everything else 0.0.
    X_train = (raw_train > 0).astype(jnp.float32)
    X_test = (raw_test > 0).astype(jnp.float32)

    for label, array in (("X_train:", X_train), ("X_test:", X_test)):
        print(label, array.shape, array.dtype)

    def forward(x: jnp.ndarray):
        vae = VAE(latent_size=LATENT_SIZE)
        return vae(x)

    losses = [
        BinaryCrossEntropy(on="logits"),
        KL(weight=0.1),
    ]
    model = eg.Model(
        module=hk.transform_with_state(forward),
        loss=losses,
        optimizer=optax.adam(1e-3),
        eager=eager,
    )

    model.summary(X_train[:batch_size])

    # The model is fitted on the inputs alone (no labels are passed).
    history = model.fit(
        inputs=X_train,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_test, ),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir)],
    )

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )

    eg.utils.plot_history(history)

    # Pick five random test images and reconstruct them.
    n_shown = 5
    idxs = np.random.randint(0, len(X_test), size=(n_shown, ))
    x_sample = X_test[idxs]

    preds = model.predict(x=x_sample)
    y_pred = jax.nn.sigmoid(preds["logits"])

    # Top row: inputs. Bottom row: reconstructions.
    plt.figure(figsize=(12, 12))
    for col in range(n_shown):
        plt.subplot(2, 5, col + 1)
        plt.imshow(x_sample[col], cmap="gray")
        plt.subplot(2, 5, 5 + col + 1)
        plt.imshow(y_pred[col], cmap="gray")

    plt.show()
Пример #14
0
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: int = 200,
    epochs: int = 100,
    batch_size: int = 64,
):
    """Train a CNN on MNIST from torch ``DataLoader``s, then save, reload and evaluate it.

    Args:
        debug: When true, wait for a debugpy client on port 5678 first.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory; TensorBoard logs go to a timestamped
            sub-directory of it.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        batch_size: Batch size of both data loaders.
    """

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    # A trailing axis is appended to the stacked images.
    X_train = np.stack(dataset["train"]["image"])[..., None]
    y_train = dataset["train"]["label"]
    X_test = np.stack(dataset["test"]["image"])[..., None]
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    model = eg.Model(
        module=CNN(),
        loss=eg.losses.Crossentropy(),
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adam(1e-3),
        eager=eager,
    )

    # show summary
    model.summary(X_train[:64])

    # Only the training loader shuffles.
    train_dataset = TensorDataset(torch.from_numpy(X_train),
                                  torch.from_numpy(y_train))
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True)
    test_dataset = TensorDataset(torch.from_numpy(X_test),
                                 torch.from_numpy(y_test))
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    history = model.fit(
        train_dataloader,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=test_dataloader,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir)],
    )

    eg.utils.plot_history(history)

    # Save and reload; evaluation and prediction below use the reloaded model.
    model.save("models/conv")

    model = eg.load("models/conv")

    print(model.evaluate(x=X_test, y=y_test))

    # get random samples; bound the indices by the actual test-set size
    # instead of a hard-coded count
    idxs = np.random.randint(0, len(X_test), size=(9, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results on a 3x3 grid, titled with the arg-max of each prediction
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(3):
            for j in range(3):
                k = 3 * i + j
                plt.subplot(3, 3, k + 1)

                plt.title(f"{np.argmax(y_pred[k])}")
                plt.imshow(x_sample[k], cmap="gray")
        # tbwriter.add_figure("Conv classifier", figure, 100)

    plt.show()
Пример #15
0
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: int = 200,
    epochs: int = 100,
    batch_size: int = 64,
):
    """Train an MLP on MNIST images with a mean-squared-error loss and plot its outputs.

    No labels are passed to ``fit``. After training, five random test images
    are plotted above the model's predictions for them.

    Args:
        debug: When true, wait for a debugpy client on port 5678 first.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory; TensorBoard logs go to a timestamped
            sub-directory of it.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``.
    """

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.stack(dataset["train"]["image"])
    X_test = np.stack(dataset["test"]["image"])

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = eg.Model(
        module=MLP(n1=256, n2=64),
        loss=MeanSquaredError(),
        optimizer=optax.rmsprop(0.001),
        eager=eager,
    )

    model.summary(X_train[:64])

    # Notice we are not passing `y`
    history = model.fit(
        inputs=X_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, ),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir, update_freq=300)],
    )

    eg.utils.plot_history(history)

    # get random samples; bound the indices by the actual test-set size
    # instead of a hard-coded count
    idxs = np.random.randint(0, len(X_test), size=(5, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot and save results: inputs on the top row, predictions below
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:

        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred[i], cmap="gray")

    plt.show()
Пример #16
0
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: int = 200,
    batch_size: int = 64,
    epochs: int = 100,
):
    """Train an MLP classifier on MNIST and plot nine predictions.

    The loss combines cross-entropy with an L2 regularizer; accuracy is
    tracked as the metric.

    Args:
        debug: When true, wait for a debugpy client on port 5678 first.
        eager: Forwarded to ``eg.Model``.
        logdir: Base directory; TensorBoard logs go to a timestamped
            sub-directory of it.
        steps_per_epoch: Forwarded to ``model.fit``.
        batch_size: Forwarded to ``model.fit``.
        epochs: Forwarded to ``model.fit``.
    """

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.stack(dataset["train"]["image"])
    y_train = dataset["train"]["label"]
    X_test = np.stack(dataset["test"]["image"])
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    model = eg.Model(
        module=MLP(n1=300, n2=100),
        loss=[
            eg.losses.Crossentropy(),
            eg.regularizers.L2(l=1e-4),
        ],
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adamw(1e-3),
        eager=eager,
    )

    model.summary(X_train[:64])

    history = model.fit(
        inputs=X_train,
        labels=y_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir)],
    )

    eg.utils.plot_history(history)

    # get random samples; bound the indices by the actual test-set size
    # instead of a hard-coded count
    idxs = np.random.randint(0, len(X_test), size=(9,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot and save results on a 3x3 grid, titled with the arg-max of each prediction
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(3):
            for j in range(3):
                k = 3 * i + j
                plt.subplot(3, 3, k + 1)
                plt.title(f"{np.argmax(y_pred[k])}")
                plt.imshow(x_sample[k], cmap="gray")
        # tbwriter.add_figure("Predictions", figure, 100)

    plt.show()

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )