Example #1
0
    def test_basic_predict(self):
        """Check ResNet18 output shape and that pickled weights reproduce it.

        A ResNet18 model is initialized and run on a random batch. Its default
        parameters are then pickled to a temporary file, a second ResNet18 is
        built from that file, and both models must give matching predictions.
        """
        # FIXME: test succeeds if run alone or if run on the cpu-only version of jax
        # test fails with "DNN library is not found" if run on gpu with all other tests together

        model = elegy.Model(elegy.nets.resnet.ResNet18(), eager=True)
        assert isinstance(model.module, elegy.Module)

        x = np.random.random((2, 224, 224, 3)).astype(np.float32)

        model.init(x)
        y = model.predict(x)

        # update_modules results in a call to `set_default_parameters` for elegy Modules
        # it might be better to have the user call this explicitly to avoid potential OOM
        model.update_modules()

        # `shape` is a plain tuple, so it can be compared directly.
        assert y.shape == (2, 1000)

        # test loading weights from file
        with tempfile.TemporaryDirectory() as tempdir:
            pklpath = os.path.join(tempdir, "delete_me.pkl")
            # Close (and therefore flush) the file before it is read back below.
            with open(pklpath, "wb") as pklfile:
                pickle.dump(model.module.get_default_parameters(), pklfile)

            new_r18 = elegy.nets.resnet.ResNet18(weights=pklpath)
            y2 = elegy.Model(new_r18, eager=True).predict(x, initialize=True)

        assert np.allclose(y, y2, rtol=0.001)
Example #2
0
    def test_example(self):
        """Early stopping on ``loss`` with ``patience=3`` must end training after 7 epochs."""

        class MLP(elegy.Module):
            def __apply__(self, input):
                return hk.Sequential([hk.Linear(10)])(input)

        # This callback stops training once the monitored loss has shown no
        # improvement for three consecutive epochs.
        early_stopping = elegy.callbacks.EarlyStopping(monitor="loss",
                                                       patience=3)
        model = elegy.Model(
            module=MLP.defer(),
            loss=elegy.losses.MeanSquaredError(),
            optimizer=optix.rmsprop(0.01),
        )

        features = np.arange(100).reshape(5, 20).astype(np.float32)
        targets = np.zeros(5)
        history = model.fit(
            features,
            targets,
            epochs=10,
            batch_size=1,
            callbacks=[early_stopping],
            verbose=0,
        )

        # Only 7 epochs are run.
        recorded_losses = history.history["loss"]
        assert len(recorded_losses) == 7
Example #3
0
def main(batch_size: int = 64, k: int = 5, debug: bool = False):
    """Fit a mixture model with ``k`` components to synthetic data and plot it.

    Args:
        batch_size: Batch size used for the model summary and for training.
        k: Value forwarded to ``MixtureModel`` and ``visualize_model``.
        debug: Accepted but not used inside this function.
    """
    # Synthetic data: the inputs are a noisy, non-linear function of the targets.
    noise = np.random.normal(size=(3000, 1)).astype(np.float32)  # random noise
    y_train = np.random.uniform(-10.5, 10.5, (1, 3000)).astype(np.float32).T
    X_train = np.float32(
        np.sin(0.75 * y_train) * 7.0 + y_train * 0.5 + noise * 1.0)

    # Rescale both arrays by the magnitude of their respective maximum.
    x_scale = np.abs(X_train.max())
    X_train = X_train / x_scale
    y_scale = np.abs(y_train.max())
    y_train = y_train / y_scale

    visualize_data(X_train, y_train)

    model = elegy.Model(
        module=MixtureModel(k=k),
        loss=MixtureNLL(),
        optimizer=optax.adam(3e-4),
    )

    first_batch = X_train[:batch_size]
    model.summary(first_batch, depth=1)

    model.fit(x=X_train,
              y=y_train,
              epochs=500,
              batch_size=batch_size,
              shuffle=True)

    visualize_model(X_train, y_train, model, k)
Example #4
0
    def test_evaluate(self):
        """``evaluate`` must log the named loss, the named metric and ``loss``."""

        # NOTE: the class names below appear in the asserted log keys, so they
        # must not be renamed.
        class mse(eg.Loss):
            def call(self, target, preds):
                error = target - preds
                return jnp.mean(error**2)

        class mae(eg.Metric):
            value: eg.MetricState = eg.MetricState.node(
                default=jnp.array(0.0, jnp.float32))

            def update(self, target, preds):
                abs_error = jnp.abs(target - preds)
                return jnp.mean(abs_error)

            def compute(self) -> tp.Any:
                return self.value

        model = eg.Model(
            module=eg.Linear(1),
            loss=dict(a=mse()),
            metrics=dict(b=mae()),
            optimizer=optax.adamw(1e-3),
            eager=True,
        )

        features = np.random.uniform(size=(5, 2))
        targets = np.random.uniform(size=(5, 1))

        logs = model.evaluate(x=features, y=targets)

        for expected_key in ("a/mse_loss", "b/mae", "loss"):
            assert expected_key in logs
Example #5
0
    def test_cloudpickle(self):
        """A model round-tripped through cloudpickle must predict identically."""
        losses = [
            elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
            elegy.regularizers.GlobalL2(l=1e-4),
        ]
        model = elegy.Model(
            module=MLP(n1=3, n2=1),
            loss=losses,
            metrics=elegy.metrics.SparseCategoricalAccuracy(),
            optimizer=optax.adamw(1e-3),
            run_eagerly=True,
        )

        features = np.random.uniform(size=(5, 7, 7))
        labels = np.random.randint(10, size=(5, ))

        model.init(features, labels)
        expected = model.predict(features)

        # Serialize and immediately deserialize the whole model.
        restored = cloudpickle.loads(cloudpickle.dumps(model))

        # Copy the states of the original model onto the restored one.
        restored.states = model.states
        restored.initial_states = model.initial_states

        actual = restored.predict(features)
        assert np.all(expected == actual)
Example #6
0
    def test_evaluate(self):
        """Accuracy from ``evaluate`` must equal accuracy computed from ``predict``."""
        model = elegy.Model(
            module=MLP(n1=3, n2=1),
            loss=[
                elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
                elegy.regularizers.GlobalL2(l=1e-4),
            ],
            metrics=elegy.metrics.SparseCategoricalAccuracy(),
            optimizer=optax.adamw(1e-3),
            run_eagerly=True,
        )

        features = np.random.uniform(size=(5, 7, 7))
        labels = np.random.randint(10, size=(5, ))

        # Train for a single step; the returned history is not inspected.
        fit_options = dict(
            epochs=1,
            steps_per_epoch=1,
            batch_size=5,
            validation_data=(features, labels),
            shuffle=True,
            verbose=1,
        )
        model.fit(x=features, y=labels, **fit_options)

        logs = model.evaluate(features, labels)
        reported_accuracy = logs["sparse_categorical_accuracy"]

        predicted_classes = model.predict(features).argmax(-1)
        manual_accuracy = (predicted_classes == labels).mean()

        assert reported_accuracy == manual_accuracy
Example #7
0
    def test_lr_logging(self):
        """The scheduled learning rate must appear as ``lr`` in the fit history."""

        def constant_schedule(step, epoch):
            # Always return the same learning rate, whatever the step or epoch.
            return jnp.array(1e-3)

        optimizer = elegy.Optimizer(
            optax.adamw(1.0, b1=0.95),
            lr_schedule=constant_schedule,
        )
        model = elegy.Model(
            module=MLP(n1=3, n2=1),
            loss=elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=elegy.metrics.SparseCategoricalAccuracy(),
            optimizer=optimizer,
            run_eagerly=True,
        )

        features = np.random.uniform(size=(5, 7, 7))
        labels = np.random.randint(10, size=(5, ))

        history = model.fit(
            x=features,
            y=labels,
            epochs=1,
            steps_per_epoch=1,
            batch_size=5,
            validation_data=(features, labels),
            shuffle=True,
            verbose=0,
        )

        logged = history.history
        assert "lr" in logged
        assert np.allclose(logged["lr"], 1e-3)
Example #8
0
    def test_evaluate(self):
        """Accuracy from ``evaluate`` must equal accuracy computed from ``predict``."""
        model = eg.Model(
            module=MLP(dmid=3, dout=4),
            loss=[
                eg.losses.Crossentropy(),
                eg.regularizers.L2(l=1e-4),
            ],
            metrics=eg.metrics.Accuracy(),
            optimizer=optax.adamw(1e-3),
            eager=True,
        )

        features = np.random.uniform(size=(5, 2))
        targets = np.random.randint(4, size=(5, ))

        # Train for a single step; the returned history is not inspected.
        fit_options = dict(
            epochs=1,
            steps_per_epoch=1,
            batch_size=5,
            validation_data=(features, targets),
            shuffle=True,
            verbose=1,
        )
        model.fit(inputs=features, labels=targets, **fit_options)

        logs = model.evaluate(features, targets)
        reported_accuracy = logs["accuracy"]

        predicted_classes = model.predict(features).argmax(-1)
        manual_accuracy = (predicted_classes == targets).mean()

        assert reported_accuracy == manual_accuracy
Example #9
0
    def test_example_restore(self):
        """Early stopping with ``restore_best_weights=True`` must stop after 4 epochs."""

        class MLP(eg.Module):
            @eg.compact
            def __call__(self, x):
                hidden = eg.Linear(10)(x)
                # Block gradients through the layer output.
                return jax.lax.stop_gradient(hidden)

        # This callback stops training once the monitored loss has shown no
        # improvement for three consecutive epochs.
        early_stopping = eg.callbacks.EarlyStopping(
            monitor="loss",
            patience=3,
            restore_best_weights=True,
        )

        model = eg.Model(
            module=MLP(),
            loss=eg.losses.MeanSquaredError(),
            optimizer=optax.rmsprop(0.01),
        )
        history = model.fit(
            inputs=np.ones((5, 20)),
            labels=np.zeros((5, 10)),
            epochs=10,
            batch_size=1,
            callbacks=[early_stopping],
            verbose=0,
        )

        # Only 4 epochs are run.
        recorded_losses = history.history["loss"]
        assert len(recorded_losses) == 4
Example #10
0
    def test_on_model(self):
        """Smoke test: a ``Linear(2)`` model can run ``predict`` and ``evaluate``."""
        linear_model = elegy.Model(module=elegy.nn.Linear(2))
        batch = np.ones([3, 5])

        # The results are not inspected; both calls only need to complete.
        linear_model.predict(batch, initialize=True)
        linear_model.evaluate(batch)
Example #11
0
def main(logdir: str = "runs",
         steps_per_epoch: tp.Optional[int] = None,
         epochs: int = 10,
         batch_size: int = 32):
    """Train a CNN on MNIST with and without distributed mode and report results.

    The non-distributed configuration is run twice so that one-off warm-up
    costs do not distort the timing comparison. For every run the evaluation
    logs and the elapsed wall-clock time are printed; the accuracy of the last
    run for each configuration is printed at the end.

    Args:
        logdir: Base directory; a timestamped sub-path is computed from it
            (the resulting path is not used further in this function).
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Number of training epochs per run.
        batch_size: Batch size used for every run.
    """
    platform = jax.local_devices()[0].platform
    ndevices = len(jax.devices())
    print('devices ', jax.devices())
    print('platform ', platform)

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = dataset["train"]["image"][..., None]
    y_train = dataset["train"]["label"]
    X_test = dataset["test"]["image"][..., None]
    y_test = dataset["test"]["label"]

    # Keyed by the `distributed` flag, so the second non-distributed run
    # overwrites the first (warm-up) one.
    accuracies = {}
    # we run distributed=False twice to remove any initial warmup costs
    for distributed in [False, False, True]:
        print(f'Distributed training = {distributed}')
        start_time = time.time()

        model = eg.Model(module=CNN(),
                         loss=eg.losses.Crossentropy(),
                         metrics=eg.metrics.Accuracy(),
                         optimizer=optax.adam(1e-3),
                         seed=42)

        if distributed:
            model = model.distributed()
            bs = batch_size  #int(batch_size / ndevices)
        else:
            bs = batch_size

        #model.summary(X_train[:64], depth=1)

        history = model.fit(inputs=X_train,
                            labels=y_train,
                            epochs=epochs,
                            steps_per_epoch=steps_per_epoch,
                            batch_size=bs,
                            validation_data=(X_test, y_test),
                            shuffle=True,
                            verbose=3)

        ev = model.evaluate(x=X_test, y=y_test, verbose=1)
        print('eval ', ev)
        accuracies[distributed] = ev['accuracy']

        end_time = time.time()
        # Interpolate the elapsed seconds into the message; the previous code
        # passed a set literal as a second argument and printed e.g. "{12.3}".
        print(f'time taken {end_time - start_time}')

    print(accuracies)
Example #12
0
    def test_predict(self):
        """``predict`` on a ``Linear(1)`` model returns one output per input row."""
        linear_model = eg.Model(module=eg.Linear(1))

        features = np.random.uniform(size=(5, 2))
        # Labels are drawn but never passed to the model in this test.
        _unused_labels = np.random.randint(10, size=(5, 1))

        predictions = linear_model.predict(features)

        expected_shape = (5, 1)
        assert predictions.shape == expected_shape
Example #13
0
    def test_autodownload_pretrained_r50(self):
        """Check that ImageNet-pretrained ResNet50 classifies a sample photo.

        The image is downloaded, resized to 224x224 and scaled by 1/255, then
        passed through ``ResNet50(weights="imagenet")`` with JIT disabled. The
        arg-max of the prediction must be 245 (presumably the ImageNet index
        of the pictured breed; confirm against the label map).
        """
        fname, _ = urllib.request.urlretrieve(
            "https://upload.wikimedia.org/wikipedia/commons/e/e4/A_French_Bulldog.jpg"
        )
        try:
            # Close the image file as soon as its pixels are copied to an array.
            with PIL.Image.open(fname) as image:
                im = np.array(image.resize([224, 224])) / np.float32(255)
        finally:
            # Remove the temporary file that urlretrieve left behind.
            urllib.request.urlcleanup()

        r50 = elegy.nets.resnet.ResNet50(weights="imagenet")
        with jax.disable_jit():
            assert elegy.Model(r50).predict(im[np.newaxis]).argmax() == 245
Example #14
0
    def test_predict(self):
        """After ``init``, ``predict`` on ``Linear(1)`` returns one output per row."""
        linear_model = elegy.Model(module=elegy.nn.Linear(1))

        features = np.random.uniform(size=(5, 10))
        labels = np.random.randint(10, size=(5, 1))

        linear_model.init(x=features, y=labels)
        predictions = linear_model.predict(x=features)

        expected_shape = (5, 1)
        assert predictions.shape == expected_shape
Example #15
0
    def test_on_predict(self):
        """Smoke test: a batch-norm module can run ``predict`` and ``evaluate``."""

        class TestModule(elegy.Module):
            def call(self, x, training):
                batch_norm = elegy.nn.BatchNormalization()
                return batch_norm(x, training)

        wrapped = elegy.Model(module=TestModule())
        batch = jnp.ones([3, 5])

        # The results are not inspected; both calls only need to complete.
        wrapped.predict(batch)
        wrapped.evaluate(batch)
Example #16
0
def main(debug: bool = False, eager: bool = False):
    """Train a variational auto-encoder on binarized MNIST and plot reconstructions.

    Args:
        debug: When true, wait for a ``debugpy`` client on port 5678 first.
        eager: Forwarded to the model as ``run_eagerly``.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    # The second and fourth returned values are not needed here.
    X_train, _, X_test, _ = dataget.image.mnist(global_cache=True).get()

    # Now binarize data: every non-zero pixel becomes 1.0.
    X_train = (X_train > 0).astype(jnp.float32)
    X_test = (X_test > 0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    losses = [KLDivergence(), BinaryCrossEntropy(on="logits")]
    model = elegy.Model(
        module=VariationalAutoEncoder.defer(),
        loss=losses,
        optimizer=optix.adam(1e-3),
        run_eagerly=eager,
    )

    # Fit with datasets in memory
    history = model.fit(
        x=X_train,
        epochs=10,
        batch_size=64,
        steps_per_epoch=100,
        validation_data=(X_test, ),
        shuffle=True,
    )
    plot_history(history)

    # get random samples
    n_samples = 5
    idxs = np.random.randint(0, len(X_test), size=(n_samples, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results: inputs on the top row, model outputs on the bottom row
    plt.figure(figsize=(12, 12))
    for column in range(1, n_samples + 1):
        plt.subplot(2, 5, column)
        plt.imshow(x_sample[column - 1], cmap="gray")
        plt.subplot(2, 5, 5 + column)
        plt.imshow(y_pred["image"][column - 1], cmap="gray")

    plt.show()
Example #17
0
    def test_on_predict(self):
        """``predict`` through a dropout module must return the input unchanged."""

        class TestModule(elegy.Module):
            def call(self, x, training):
                dropout = elegy.nn.Dropout(0.5)
                return dropout(x, training)

        wrapped = elegy.Model(TestModule())
        batch = jnp.ones([3, 5])

        predictions = wrapped.predict(batch)
        # The evaluation logs are not inspected; the call only needs to complete.
        wrapped.evaluate(batch)

        assert jnp.all(predictions == batch)
Example #18
0
    def test_on_predict(self):
        """Smoke test: a ``Sequential`` model can run ``predict`` and ``evaluate``."""

        def build_layers():
            # Layers are created lazily, when Sequential calls this factory.
            return [
                elegy.nn.Flatten(),
                elegy.nn.Linear(5),
                jax.nn.relu,
                elegy.nn.Linear(2),
            ]

        sequential_model = elegy.Model(elegy.nn.Sequential(build_layers))
        batch = jnp.ones([3, 5])

        # The results are not inspected; both calls only need to complete.
        sequential_model.predict(batch)
        sequential_model.evaluate(batch)
Example #19
0
    def test_model_fit(self):
        """Smoke test: ``fit`` accepts ``DataLoader`` objects for train and validation."""
        dataset = DS0()
        loader_options = dict(batch_size=4, n_workers=4)
        train_loader = elegy.data.DataLoader(dataset, **loader_options)
        valid_loader = elegy.data.DataLoader(dataset, **loader_options)

        class Module(elegy.Module):
            def call(self, x):
                # Average over axes 1 and 2 before the linear layer.
                pooled = jnp.mean(x, axis=[1, 2])
                return elegy.nn.Linear(20)(pooled)

        model = elegy.Model(
            Module(),
            loss=elegy.losses.SparseCategoricalCrossentropy(),
            optimizer=optax.sgd(0.1),
        )
        model.fit(train_loader, validation_data=valid_loader, epochs=3)
Example #20
0
    def test_distributed_init(self):
        """A distributed model's parameters must carry a leading per-device axis."""
        n_devices = jax.device_count()
        batch_size = 5 * n_devices

        x = np.random.uniform(size=(batch_size, 1))
        # Targets are generated but not passed to ``init_on_batch`` below.
        target_noise = np.random.uniform(size=(batch_size, 2))
        y = 1.4 * x + 0.1 * target_noise

        local_model = eg.Model(
            eg.Linear(2),
            loss=[eg.losses.MeanSquaredError()],
        )
        model = local_model.distributed()

        model.init_on_batch(x)

        linear = model.module
        assert linear.kernel.shape == (n_devices, 1, 2)
        assert linear.bias.shape == (n_devices, 2)
Example #21
0
    def test_saved_model(self):
        """Exporting a model must write a TensorFlow SavedModel that can be loaded."""
        with TemporaryDirectory() as export_dir:
            model = eg.Model(module=eg.Linear(4))
            sample = np.random.uniform(size=(5, 6))

            # NOTE(review): bare attribute access with no call; looks like a
            # leftover. Confirm whether it is needed.
            model.merge

            model.saved_model(sample, export_dir, batch_size=[1, 2, 4, 8])

            listing = str(sh.ls(export_dir))
            for expected_entry in ("saved_model.pb", "variables"):
                assert expected_entry in listing

            # Loading only needs to succeed; the loaded object is not used.
            tf.saved_model.load(export_dir)
Example #22
0
    def test_saved_model(self):
        """Export must fail before ``init`` and afterwards produce a loadable SavedModel."""
        with TemporaryDirectory() as export_dir:
            model = elegy.Model(module=elegy.nn.Linear(4))
            sample = np.random.uniform(size=(5, 6))
            batch_sizes = [1, 2, 4, 8]

            # Exporting an uninitialized model is expected to raise.
            with pytest.raises(elegy.types.ModelNotInitialized):
                model.saved_model(sample, export_dir, batch_size=batch_sizes)

            model.init(sample)
            model.saved_model(sample, export_dir, batch_size=batch_sizes)

            listing = str(sh.ls(export_dir))
            for expected_entry in ("saved_model.pb", "variables"):
                assert expected_entry in listing

            # Loading only needs to succeed; the loaded object is not used.
            tf.saved_model.load(export_dir)
Example #23
0
    def test_saved_model_poly(self):
        """A model exported with ``batch_size=None`` must accept another batch size."""
        with TemporaryDirectory() as export_dir:
            model = eg.Model(module=eg.Linear(4))

            export_batch = np.random.uniform(size=(5, 6)).astype(np.float32)
            model.saved_model(export_batch, export_dir, batch_size=None)

            listing = str(sh.ls(export_dir))
            for expected_entry in ("saved_model.pb", "variables"):
                assert expected_entry in listing

            restored = tf.saved_model.load(export_dir)

            # change batch: call the restored model with 3 rows instead of 5
            smaller_batch = np.random.uniform(size=(3, 6)).astype(np.float32)
            outputs = restored(smaller_batch)

            assert outputs.shape == (3, 4)
Example #24
0
    def test_evaluate(self):
        """``evaluate`` must log the named loss, the named metric and ``loss``."""

        # NOTE: the function names below appear in the asserted log keys, so
        # they must not be renamed.
        def mse(y_true, y_pred):
            error = y_true - y_pred
            return jnp.mean(error**2)

        def mae(y_true, y_pred):
            abs_error = jnp.abs(y_true - y_pred)
            return jnp.mean(abs_error)

        model = elegy.Model(
            module=elegy.nn.Linear(1),
            loss=dict(a=mse),
            metrics=dict(b=mae),
            optimizer=optax.adamw(1e-3),
            run_eagerly=True,
        )

        features = np.random.uniform(size=(5, 10))
        targets = np.random.uniform(size=(5, 1))

        logs = model.evaluate(x=features, y=targets)

        for expected_key in ("a/mse_loss", "b/mae", "loss"):
            assert expected_key in logs
Example #25
0
    def test_cloudpickle(self):
        """A model saved to disk and loaded back must predict identically."""
        losses = [
            eg.losses.Crossentropy(),
            eg.regularizers.L2(1e-4),
        ]
        model = eg.Model(
            module=eg.Linear(10),
            loss=losses,
            metrics=eg.metrics.Accuracy(),
            optimizer=optax.adamw(1e-3),
            eager=True,
        )

        features = np.random.uniform(size=(5, 2))
        # Labels are drawn but never passed to the model in this test.
        _unused_labels = np.random.randint(10, size=(5, ))

        expected = model.predict(features)

        # Round-trip the model through a temporary directory.
        with TemporaryDirectory() as checkpoint_dir:
            model.save(checkpoint_dir)
            restored = eg.load(checkpoint_dir)

        actual = restored.predict(features)
        assert np.all(expected == actual)
Example #26
0
    def test_example(self):
        """Early stopping on ``loss`` with ``patience=3`` must end training after 4 epochs."""

        class MLP(elegy.Module):
            def call(self, x):
                hidden = elegy.nn.Linear(10)(x)
                # Block gradients through the layer output.
                return jax.lax.stop_gradient(hidden)

        # This callback stops training once the monitored loss has shown no
        # improvement for three consecutive epochs.
        early_stopping = elegy.callbacks.EarlyStopping(monitor="loss",
                                                       patience=3)
        model = elegy.Model(
            module=MLP(),
            loss=elegy.losses.MeanSquaredError(),
            optimizer=optax.rmsprop(0.01),
        )
        history = model.fit(
            x=np.ones((5, 20)),
            y=np.zeros((5, 10)),
            epochs=10,
            batch_size=1,
            callbacks=[early_stopping],
            verbose=0,
        )

        # Only 4 epochs are run.
        recorded_losses = history.history["loss"]
        assert len(recorded_losses) == 4
Example #27
0
    def test_summaries(self):
        """Check selected rows of the depth-1 summary for nested flax modules.

        Three nested ``linen`` modules (A contains B contains C) each declare
        one parameter and one ``states`` variable. The test asserts that
        specific substrings appear on specific lines of the rendered summary.
        """

        class ModuleC(linen.Module):
            @linen.compact
            @elegy.flax_summarize
            def __call__(self, x):
                self.param("c1", lambda _: jnp.ones([5]))
                self.variable("states", "c2", lambda: jnp.ones([6]))

                x = jax.nn.relu(x)
                elegy.flax_summary(self, "relu", jax.nn.relu, x)
                return x

        class ModuleB(linen.Module):
            @linen.compact
            @elegy.flax_summarize
            def __call__(self, x):
                self.param("b1", lambda _: jnp.ones([3]))
                self.variable("states", "b2", lambda: jnp.ones([4]))

                x = ModuleC()(x)

                x = jax.nn.relu(x)
                elegy.flax_summary(self, "relu", jax.nn.relu, x)
                return x

        class ModuleA(linen.Module):
            @linen.compact
            @elegy.flax_summarize
            def __call__(self, x):
                self.param("a1", lambda _: jnp.ones([1]))
                self.variable("states", "a2", lambda: jnp.ones([2]))

                x = ModuleB()(x)

                x = jax.nn.relu(x)
                elegy.flax_summary(self, "relu", jax.nn.relu, x)
                return x

        model = elegy.Model(ModuleA())
        model.init(x=jnp.ones([10, 2]))

        summary_text = model.summary(
            x=jnp.ones([10, 2]),
            depth=1,
            return_repr=True,
        )
        assert summary_text is not None

        lines = summary_text.split("\n")

        # (line index, substrings that must appear on that line), checked in
        # this order.
        expectations = [
            (7, ["ModuleB_0", "ModuleB", "(10, 2)", "8", "32 B", "10", "40 B"]),
            (9, ["relu", "(10, 2)"]),
            (11, ["*", "ModuleA", "(10, 2)", "1", "4 B", "2", "8 B"]),
            (13, ["9", "36 B", "12", "48 B"]),
            (16, ["21", "84 B"]),
        ]
        for line_index, fragments in expectations:
            for fragment in fragments:
                assert fragment in lines[line_index]
Example #28
0
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: tp.Optional[int] = None,
    epochs: int = 100,
    batch_size: int = 32,
    distributed: bool = False,
):
    """Train a CNN classifier on MNIST, evaluate it and plot nine predictions.

    Args:
        debug: When true, wait for a ``debugpy`` client on port 5678 first.
        eager: Forwarded to ``eg.Model`` as ``eager``.
        logdir: Base directory for TensorBoard logs; a timestamped
            sub-directory is used for this run.
        steps_per_epoch: Forwarded to ``model.fit``.
        epochs: Number of training epochs.
        batch_size: Training batch size.
        distributed: When true, train with ``model.distributed()``.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    # Append a trailing axis of length 1 to the stacked image arrays.
    X_train = np.stack(dataset["train"]["image"])[..., None]
    y_train = dataset["train"]["label"]
    X_test = np.stack(dataset["test"]["image"])[..., None]
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    model = eg.Model(
        module=CNN(),
        loss=eg.losses.Crossentropy(),
        metrics=eg.metrics.Accuracy(),
        optimizer=optax.adam(1e-3),
        eager=eager,
    )

    if distributed:
        model = model.distributed()

    # show model summary
    model.summary(X_train[:64], depth=1)

    history = model.fit(
        inputs=X_train,
        labels=y_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, y_test),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir)],
    )

    eg.utils.plot_history(history)

    print(model.evaluate(x=X_test, y=y_test))

    # get random samples; the upper bound follows the actual test-set size
    # instead of a hard-coded 10000 so indices can never go out of range
    idxs = np.random.randint(0, len(X_test), size=(9, ))
    x_sample = X_test[idxs]

    # get predictions
    model = model.local()
    y_pred = model.predict(x=x_sample)

    # plot results as a 3x3 grid, titled with the arg-max of each prediction
    plt.figure(figsize=(12, 12))
    for i in range(3):
        for j in range(3):
            k = 3 * i + j
            plt.subplot(3, 3, k + 1)

            plt.title(f"{np.argmax(y_pred[k])}")
            plt.imshow(x_sample[k], cmap="gray")

    plt.show()
Example #29
0
    def test_summaries(self):
        """Check selected rows of the depth-1 summary for nested elegy modules.

        Three nested modules (A contains B contains C) each add one trainable
        and one non-trainable parameter. The test asserts that specific
        substrings appear on specific lines of the rendered summary.
        """

        class ModuleC(elegy.Module):
            def call(self, x):
                self.add_parameter("c1", lambda: jnp.ones([5]))
                self.add_parameter("c2", lambda: jnp.ones([6]), trainable=False)

                x = jax.nn.relu(x)
                self.add_summary("relu", jax.nn.relu, x)
                return x

        class ModuleB(elegy.Module):
            def call(self, x):
                self.add_parameter("b1", lambda: jnp.ones([3]))
                self.add_parameter("b2", lambda: jnp.ones([4]), trainable=False)

                x = ModuleC()(x)

                x = jax.nn.relu(x)
                self.add_summary("relu", jax.nn.relu, x)
                return x

        class ModuleA(elegy.Module):
            def call(self, x):
                self.add_parameter("a1", lambda: jnp.ones([1]))
                self.add_parameter("a2", lambda: jnp.ones([2]), trainable=False)

                x = ModuleB()(x)

                x = jax.nn.relu(x)
                self.add_summary("relu", jax.nn.relu, x)
                return x

        model = elegy.Model(ModuleA())

        summary_text = model.summary(
            x=jnp.ones([10, 2]),
            depth=1,
            return_repr=True,
        )
        assert summary_text is not None

        lines = summary_text.split("\n")

        # (line index, substrings that must appear on that line), checked in
        # this order.
        expectations = [
            (7, ["module_b", "ModuleB", "(10, 2)", "8", "32 B", "10", "40 B"]),
            (9, ["relu", "(10, 2)"]),
            (11, ["*", "ModuleA", "(10, 2)", "1", "4 B", "2", "8 B"]),
            (13, ["21", "84 B"]),
            (14, ["9", "36 B"]),
            (15, ["12", "48 B"]),
        ]
        for line_index, fragments in expectations:
            for fragment in fragments:
                assert fragment in lines[line_index]
Example #30
0
def main(
    steps_per_epoch: tp.Optional[int] = None,
    batch_size: int = 32,
    epochs: int = 50,
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
):
    """Train a variational auto-encoder on MNIST and plot its outputs.

    After training, two figures are drawn: test images next to the model's
    ``det_image`` output for them, and images decoded from random latent
    vectors.

    Args:
        steps_per_epoch: Forwarded to ``model.fit``.
        batch_size: Training batch size.
        epochs: Number of training epochs.
        debug: When true, wait for a ``debugpy`` client on port 5678 first.
        eager: Forwarded to ``eg.Model`` as ``eager``.
        logdir: Base directory for TensorBoard logs; a timestamped
            sub-directory is used for this run.
    """
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8)
    X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8)

    # Scale the uint8 pixel values to floats in [0, 1] (no binarization here).
    X_train = (X_train / 255.0).astype(jnp.float32)
    X_test = (X_test / 255.0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = eg.Model(
        module=VariationalAutoEncoder(latent_size=LATENT_SIZE),
        loss=[BinaryCrossEntropy(from_logits=True, on="logits")],
        optimizer=optax.adam(1e-3),
        eager=eager,
    )
    assert model.module is not None

    model.summary(X_train[:64])

    # Fit with datasets in memory
    history = model.fit(
        inputs=X_train,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_test, ),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir)],
    )

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )

    eg.utils.plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot and save results: inputs on the top row, outputs on the bottom row
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred["det_image"][i], cmap="gray")
        # tbwriter.add_figure("VAE Example", figure, epochs)

    # sample: decode random latent vectors with the trained decoder
    model_decoder = eg.Model(model.module.decoder)

    z_samples = np.random.normal(size=(12, LATENT_SIZE))
    samples = model_decoder.predict(z_samples)
    # The decoder output is passed through a sigmoid before plotting.
    samples = jax.nn.sigmoid(samples)

    # plot and save results
    # with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
    figure = plt.figure(figsize=(5, 12))
    plt.title("Generative Samples")
    # Show ten distinct samples, one per cell of the 2x5 grid. The previous
    # indexing (samples[i] and samples[i + 1]) repeated images in neighbouring
    # cells.
    for i in range(10):
        plt.subplot(2, 5, i + 1)
        plt.imshow(samples[i], cmap="gray")
    # # tbwriter.add_figure("VAE Generative Example", figure, epochs)

    plt.show()