Example #1
def pred_step(self, x, net_states, initializing):
    # Create the counter state on the init pass, increment it afterwards.
    if initializing:
        states = elegy.States(net_states=0)
    else:
        states = elegy.States(net_states=net_states + 1)

    return elegy.PredStep.simple(x + 1.0, states)
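
For context, here is a hedged sketch of how an override like this is typically driven. The Counter class name and the input array are illustrative; it assumes elegy.Model can be subclassed without passing a module, which is how these low-level examples appear to be used:

import numpy as np
import elegy

class Counter(elegy.Model):
    def pred_step(self, x, net_states, initializing):
        # same counter logic as in Example #1
        if initializing:
            states = elegy.States(net_states=0)
        else:
            states = elegy.States(net_states=net_states + 1)
        return elegy.PredStep.simple(x + 1.0, states)

model = Counter()
y = model.predict(np.zeros(4))  # first call takes the initializing branch
y = model.predict(np.zeros(4))  # later calls increment net_states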
Example #2
def test_step(self, metrics_states, initializing):
    # Reset the metrics state on the init pass, otherwise advance it.
    if initializing:
        states = elegy.States(metrics_states=0)
    else:
        states = elegy.States(metrics_states=metrics_states + 1)

    return elegy.TestStep(
        loss=0.1,
        logs=dict(loss=1.0),
        states=states,
    )
Example #3
def train_step(self, states, initializing):
    # Reset the optimizer state on the init pass, otherwise advance it.
    if initializing:
        states = elegy.States(optimizer_states=0)
    else:
        states = elegy.States(optimizer_states=states.optimizer_states + 1)

    return elegy.TrainStep(
        logs=dict(loss=2.0),
        states=states,
    )
Example #4
def pred_step(self, x, initializing, states):
    # net_states acts as a step counter threaded through successive calls.
    if initializing:
        states = elegy.States(net_states=0)
    else:
        states = elegy.States(net_states=states.net_states + 1)

    return elegy.PredStep(x + 1.0, states)
Example #5
def pred_step(self, x, states, initializing):
    # As above, but returning PredStep with explicit keyword arguments.
    if initializing:
        states = elegy.States(net_states=0)
    else:
        states = elegy.States(net_states=states.net_states + 1)

    return elegy.PredStep(
        y_pred=1,
        states=states,
    )
Example #6
def train_step(self, x, optimizer_states, initializing):
    # Counter-style optimizer state: created at init, incremented per step.
    if initializing:
        states = elegy.States(optimizer_states=0)
    else:
        states = elegy.States(optimizer_states=optimizer_states + 1)

    return elegy.TrainStep(
        logs=dict(loss=jnp.sum(x)),
        states=states,
    )
Example #7
def test_step(self, x, metrics_states, initializing):
    # Same counter pattern, this time for the metrics state.
    if initializing:
        states = elegy.States(metrics_states=0)
    else:
        states = elegy.States(metrics_states=metrics_states + 1)

    return elegy.TestStep(
        loss=0.1,
        logs=dict(loss=jnp.sum(x)),
        states=states,
    )
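
As a companion, a sketch of running a custom test_step through evaluation. TestModel and the input are made-up names, and it assumes elegy.Model exposes a Keras-style evaluate alongside the fit/predict calls seen in Example #12:

import numpy as np
import jax.numpy as jnp
import elegy

class TestModel(elegy.Model):
    def test_step(self, x, metrics_states, initializing):
        # counter state as in Example #7
        if initializing:
            states = elegy.States(metrics_states=0)
        else:
            states = elegy.States(metrics_states=metrics_states + 1)
        return elegy.TestStep(
            loss=0.1,
            logs=dict(loss=jnp.sum(x)),
            states=states,
        )

logs = TestModel().evaluate(x=np.ones((8, 3)))  # presumably returns the logs dict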
Example #8
def states_step(self):
    # Declare every state field this model manages, initially empty.
    return elegy.States(
        net_params=None,
        net_states=None,
        metrics_states=None,
        optimizer_states=None,
    )
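
states_step declares the full set of state fields before any is populated. Judging by the other examples, elegy.States behaves like an immutable namespace whose fields are readable as attributes; a minimal sketch of that assumption:

import elegy

states = elegy.States(net_states=0, net_params=None)
states.net_states  # -> 0, attribute access as used in Examples #3-#5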
Example #9
def pred_step(self):
    # N is a call counter captured from the enclosing (test) scope.
    nonlocal N
    N = N + 1

    return elegy.PredStep.simple(
        y_pred=None,
        states=elegy.States(net_params=1, net_states=2),
    )
Example #10
def train_step(self, x, y_true, net_params, optimizer_states, initializing,
               rng):
    def loss_fn(net_params, x, y_true):
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # model: parameters are created on the init pass
        if initializing:
            w = jax.random.uniform(rng.next(),
                                   shape=[np.prod(x.shape[1:]), 10],
                                   minval=-1,
                                   maxval=1)
            b = jax.random.uniform(rng.next(),
                                   shape=[1],
                                   minval=-1,
                                   maxval=1)
            net_params = (w, b)

        w, b = net_params
        logits = jnp.dot(x, w) + b

        # crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        sample_loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
        loss = jnp.mean(sample_loss)

        # metrics
        logs = dict(
            accuracy=jnp.mean(jnp.argmax(logits, axis=-1) == y_true),
            loss=loss,
        )

        return loss, (logs, net_params)

    # train: gradients are taken w.r.t. net_params, the first argument
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logs, net_params)), grads = grad_fn(net_params, x, y_true)

    if initializing:
        optimizer_states = self.optim.init(net_params)
    else:
        grads, optimizer_states = self.optim.update(
            grads, optimizer_states, net_params)
        net_params = optax.apply_updates(net_params, grads)

    return logs, elegy.States(
        net_params=net_params,
        optimizer_states=optimizer_states,
    )
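
A hedged sketch of driving this fully manual train_step. The LogisticRegression wrapper and the training data are illustrative; the step above reads self.optim as an optax optimizer stored on the model and rng as a key sequence supplied by elegy:

import numpy as np
import optax
import elegy

class LogisticRegression(elegy.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.optim = optax.adam(1e-3)  # consumed by train_step above

    # train_step exactly as in Example #10 ...

model = LogisticRegression()
x = np.random.uniform(size=(32, 28, 28))
y = np.random.randint(0, 10, size=(32,))
model.fit(x=x, y=y, epochs=1, batch_size=16)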
Example #11
def init_step(self, x):
    # Initialize both networks, then their optimizers, from one key sequence.
    rng = elegy.KeySeq(0)
    gx, g_params, g_states = self.generator.init(rng=rng)(x)
    dx, d_params, d_states = self.discriminator.init(rng=rng)(gx)

    g_optimizer_states = self.g_optimizer.init(g_params)
    d_optimizer_states = self.d_optimizer.init(d_params)

    return elegy.States(
        g_states=g_states,
        d_states=d_states,
        g_params=g_params,
        d_params=d_params,
        g_opt_states=g_optimizer_states,
        d_opt_states=d_optimizer_states,
        rng=rng,
        step=0,
    )
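
Side note on the key handling: taken together with Example #10's rng.next() calls, elegy.KeySeq appears to wrap a JAX PRNG key and hand out a fresh subkey per call, which is why a single sequence can initialize both networks. A minimal sketch of that assumption:

import elegy

rng = elegy.KeySeq(0)
k1 = rng.next()  # a fresh PRNG key
k2 = rng.next()  # a different key on each call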
Example #12
# From elegy's VAE example; relies on module-level imports (os, datetime,
# numpy as np, jax, jax.numpy as jnp, optax, dataget, matplotlib.pyplot as
# plt, elegy, a TensorBoard callback and a SummaryWriter), plus the
# VariationalAutoEncoder model and LATENT_SIZE defined earlier in the file.
def main(
    steps_per_epoch: int = 200,
    batch_size: int = 64,
    epochs: int = 50,
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
):

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    X_train, _1, X_test, _2 = dataget.image.mnist(global_cache=True).get()
    # Now binarize data
    X_train = (X_train > 0).astype(jnp.float32)
    X_test = (X_test > 0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = VariationalAutoEncoder(
        latent_size=LATENT_SIZE,
        optimizer=optax.adam(1e-3),
        run_eagerly=eager,
    )

    # Fit with datasets in memory
    history = model.fit(
        x=X_train,
        epochs=epochs,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        validation_data=(X_test, ),
        shuffle=True,
        callbacks=[TensorBoard(logdir)],
    )

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )

    elegy.utils.plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)
    y_pred = jax.nn.sigmoid(y_pred)

    # plot and save results
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            # top row: inputs, bottom row: reconstructions
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred[i], cmap="gray")
        # tbwriter.add_figure("VAE Example", figure, epochs)

    plt.show()

    # sample
    model_decoder = elegy.Model(
        model.decoder,
        states=elegy.States(
            net_params=model.states.net_params[1],
            net_states=model.states.net_states[1],
        ),
        initialized=True,
    )

    z_samples = np.random.normal(size=(12, LATENT_SIZE))
    logits = model_decoder.predict(z_samples)
    samples = jax.nn.sigmoid(logits)

    # plot and save results
    # with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
    figure = plt.figure(figsize=(5, 12))
    plt.suptitle("Generative Samples")  # figure-level title, not tied to one subplot
    for i in range(5):
        # plot ten distinct generated samples
        plt.subplot(2, 5, 2 * i + 1)
        plt.imshow(samples[2 * i], cmap="gray")
        plt.subplot(2, 5, 2 * i + 2)
        plt.imshow(samples[2 * i + 1], cmap="gray")
    # tbwriter.add_figure("VAE Generative Example", figure, epochs)

    plt.show()