# Shared imports assumed by the examples below. nux is the NuX normalizing-flow
# library; get_dataset, initialize_trainer, train, and util come from the
# surrounding research repository.
import pathlib
from functools import partial

import jax
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import nux


def train_model(create_model,
                args,
                classification=False,
                image=False,
                trainer_fun=None):
    assert not args.save_path.endswith(".pickle")

    init_key = random.PRNGKey(args.init_key_seed)
    train_key = random.PRNGKey(args.train_key_seed)
    eval_key = random.PRNGKey(args.eval_key_seed)

    train_ds, get_test_ds = get_dataset(
        args.dataset,
        args.batch_size,
        args.n_batches,
        args.test_batch_size,
        args.test_n_batches,
        quantize_bits=args.quantize_bits,
        classification=classification,
        label_keep_percent=args.label_keep_percent,
        random_label_percent=args.random_label_percent)

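    # The dataset iterator yields "doubly batched" arrays of shape
    # (n_batches, batch_size, ...); one inner batch suffices to build the flow.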
    doubly_batched_inputs = next(train_ds)

    inputs = jax.tree_map(lambda x: x[0], doubly_batched_inputs)
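    # batch_axes=(0,) marks the leading axis of each input as the batch
    # dimension during initialization.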
    flow = nux.Flow(create_model, init_key, inputs, batch_axes=(0, ))

    print("n_params", flow.n_params)

    # Make sure that the save_path folder exists
    pathlib.Path(args.save_path).mkdir(parents=True, exist_ok=True)

    trainer = initialize_trainer(flow,
                                 clip=args.clip,
                                 lr=args.lr,
                                 warmup=args.warmup,
                                 cosine_decay_steps=args.cosine_decay_steps,
                                 save_path=args.save_path,
                                 retrain=args.retrain,
                                 classification=classification,
                                 image=image,
                                 trainer_fun=trainer_fun)

    return train(train_key,
                 eval_key,
                 trainer,
                 train_ds,
                 get_test_ds,
                 max_iters=args.max_iters,
                 save_path=args.save_path,
                 eval_interval=args.eval_interval,
                 image=image,
                 classification=classification)
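
# A minimal sketch of driving train_model. Every value below is hypothetical;
# the attribute names simply mirror what the function reads, and any nux flow
# can stand in for create_model.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        save_path="runs/example",  # a directory, per the assert above
        init_key_seed=0, train_key_seed=1, eval_key_seed=2,
        dataset="cifar10", batch_size=64, n_batches=1000,
        test_batch_size=64, test_n_batches=10, quantize_bits=8,
        label_keep_percent=1.0, random_label_percent=0.0,
        clip=5.0, lr=1e-3, warmup=1000, cosine_decay_steps=None,
        retrain=False, max_iters=10000, eval_interval=500)

    def create_model():
        # Any nux flow works here; a bare unit Gaussian prior is the simplest.
        return nux.UnitGaussianPrior()

    train_model(create_model, args, image=True)
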
Example #2

def evaluate_image_model(create_model, args, classification=False):
    assert not args.save_path.endswith(".pickle")

    init_key = random.PRNGKey(args.init_key_seed)
    train_key = random.PRNGKey(args.train_key_seed)
    eval_key = random.PRNGKey(args.eval_key_seed)

    train_ds, get_test_ds = get_dataset(args.dataset,
                                        args.batch_size,
                                        args.n_batches,
                                        args.test_batch_size,
                                        args.test_n_batches,
                                        quantize_bits=args.quantize_bits,
                                        classification=classification,
                                        label_keep_percent=1.0,
                                        random_label_percent=0.0)

    doubly_batched_inputs = next(train_ds)
    inputs = {"x": doubly_batched_inputs["x"][0]}

    if "y" in doubly_batched_inputs:
        inputs["y"] = doubly_batched_inputs["y"][0]

    flow = nux.Flow(create_model, init_key, inputs, batch_axes=(0, ))
    print("n_params", flow.n_params)

    # Evaluate the test set
    trainer = initialize_trainer(flow,
                                 clip=args.clip,
                                 lr=args.lr,
                                 warmup=args.warmup,
                                 cosine_decay_steps=args.cosine_decay_steps,
                                 save_path=args.save_path,
                                 retrain=args.retrain,
                                 train_args=args.train_args,
                                 classification=classification)

    # Generate reconstructions
    outputs = flow.apply(init_key, inputs, is_training=False)
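    # Perturb the encoded latents with unit Gaussian noise so the decoded
    # reconstructions vary around the inputs instead of inverting exactly.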
    outputs["x"] += random.normal(init_key, outputs["x"].shape)
    reconstr = flow.reconstruct(init_key, outputs, generate_image=True)

    # Plot the reconstructions
    fig, axes = plt.subplots(4, 12)
    axes = axes.ravel()
    for i, ax in enumerate(axes[:8]):
        ax.imshow(reconstr["image"][i].squeeze())

    # Generate some interpolations
    interp = jax.vmap(partial(util.spherical_interpolation,
                              N=4))(outputs["x"][:4], outputs["x"][4:8])
    interp = interp.reshape((-1, ) + flow.latent_shape)
    interpolations = flow.reconstruct(init_key, {"x": interp},
                                      generate_image=True)
    for i, ax in enumerate(axes[8:16]):
        ax.imshow(interpolations["image"][i].squeeze())

    # Generate samples
    samples = flow.sample(eval_key,
                          n_samples=axes.size - 16,
                          generate_image=True)
    for i, ax in enumerate(axes[16:]):
        ax.imshow(samples["image"][i].squeeze())

    plt.show()


    test_losses = sorted(trainer.test_losses.items(), key=lambda x: x[0])
    test_losses = jnp.array(test_losses)

    test_ds = get_test_ds()
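    # bits_per_dim=True rescales the negative log-likelihood to base-2 bits
    # per dimension, the usual comparison metric for image density models.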
    res = trainer.evaluate_test(eval_key, test_ds, bits_per_dim=True)
    print("test", trainer.summarize_losses_and_aux(res))

Example #3

    # (Source snippet truncated above; `block` evidently constructs one flow
    # step, with Dense and ShiftScale presumably imported from nux.)
    def block():
        return nux.sequential(Dense(), ShiftScale())

    def create_fun(should_repeat=True, n_repeats=2):
        if should_repeat:
            repeated = repeat(block, n_repeats=n_repeats)
        else:
            repeated = nux.sequential(*[block() for _ in range(n_repeats)])
        return repeated

    rng = random.PRNGKey(1)
    x = random.normal(rng, (10, 7, 3))
    inputs = {"x": x[0]}
    flow = nux.Flow(create_fun, rng, inputs, batch_axes=(0, ))

    outputs1 = flow.apply(rng, inputs)
    outputs2 = flow.apply(rng, inputs, no_scan=True)  # unroll the repeats instead of scanning; results should match

    doubly_batched_inputs = {"x": x}
    trainer = nux.MaximumLikelihoodTrainer(flow)

    trainer.grad_step(rng, inputs)
    trainer.grad_step_for_loop(rng, doubly_batched_inputs)
    trainer.grad_step_scan_loop(rng, doubly_batched_inputs)

    rng = random.PRNGKey(1)
    x = random.normal(rng, (10, 3))
    inputs = {"x": x}
    flow = nux.Flow(partial(create_fun, should_repeat=False),
                    rng, inputs, batch_axes=(0, ))  # source truncated here; completed to match the call above

Example #4

def evaluate_2d_model(create_model, args, classification=False):
    assert not args.save_path.endswith(".pickle")

    init_key = random.PRNGKey(args.init_key_seed)
    train_key = random.PRNGKey(args.train_key_seed)
    eval_key = random.PRNGKey(args.eval_key_seed)

    train_ds, get_test_ds = get_dataset(args.dataset,
                                        args.batch_size,
                                        args.n_batches,
                                        args.test_batch_size,
                                        args.test_n_batches,
                                        quantize_bits=args.quantize_bits,
                                        classification=classification,
                                        label_keep_percent=1.0,
                                        random_label_percent=0.0)

    doubly_batched_inputs = next(train_ds)
    inputs = {"x": doubly_batched_inputs["x"][0]}

    if "y" in doubly_batched_inputs:
        inputs["y"] = doubly_batched_inputs["y"][0]

    flow = nux.Flow(create_model, init_key, inputs, batch_axes=(0, ))

    outputs = flow.apply(init_key, inputs)

    print("n_params", flow.n_params)

    trainer = initialize_trainer(flow,
                                 clip=args.clip,
                                 lr=args.lr,
                                 warmup=args.warmup,
                                 cosine_decay_steps=args.cosine_decay_steps,
                                 save_path=args.save_path,
                                 retrain=args.retrain,
                                 train_args=args.train_args,
                                 classification=classification)

    test_losses = sorted(trainer.test_losses.items(), key=lambda x: x[0])
    test_losses = jnp.array(test_losses)

    test_ds = get_test_ds()
    res = trainer.evaluate_test(eval_key, test_ds)
    print("test", trainer.summarize_losses_and_aux(res))

    # Plot samples
    samples = flow.sample(eval_key, n_samples=5000, manifold_sample=True)

    # Find the spread of the data
    data = doubly_batched_inputs["x"].reshape((-1, 2))
    (xmin, ymin), (xmax, ymax) = data.min(axis=0), data.max(axis=0)
    xspread, yspread = xmax - xmin, ymax - ymin
    xmin -= 0.25 * xspread
    xmax += 0.25 * xspread
    ymin -= 0.25 * yspread
    ymax += 0.25 * yspread

    # Plot the learned samples against the true samples, plus a density plot
    if "prediction" in samples:
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(28, 7))
        ax1.scatter(*data.T)
        ax1.set_title("True Samples")
        ax2.scatter(*samples["x"].T, alpha=0.2, s=3, c=samples["prediction"])
        ax2.set_title("Learned Samples")
        ax1.set_xlim(xmin, xmax)
        ax1.set_ylim(ymin, ymax)
        ax2.set_xlim(xmin, xmax)
        ax2.set_ylim(ymin, ymax)

        n_importance_samples = 100
        x_range = jnp.linspace(xmin, xmax, 100)
        y_range = jnp.linspace(ymin, ymax, 100)
        X, Y = jnp.meshgrid(x_range, y_range)
        XY = jnp.dstack([X, Y]).reshape((-1, 2))
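        # Tile the grid over a leading axis of importance samples, then average
        # the stochastic densities in log space (logsumexp minus log N) to get
        # a Monte Carlo estimate of log p(x) at every grid point.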
        XY = jnp.broadcast_to(XY[None, ...],
                              (n_importance_samples, ) + XY.shape)
        outputs = flow.scan_apply(eval_key, {"x": XY})
        outputs["log_px"] = jax.scipy.special.logsumexp(
            outputs["log_px"], axis=0) - jnp.log(n_importance_samples)
        outputs["prediction"] = jnp.mean(outputs["prediction"], axis=0)

        Z = jnp.exp(outputs["log_px"])
        ax3.contourf(X, Y, Z.reshape(X.shape))
        ax4.contourf(X, Y, outputs["prediction"].reshape(X.shape))
        plt.show()
    else:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(21, 7))
        ax1.scatter(*data.T)
        ax1.set_title("True Samples")
        ax2.scatter(*samples["x"].T, alpha=0.2, s=3)
        ax2.set_title("Learned Samples")
        ax1.set_xlim(xmin, xmax)
        ax1.set_ylim(ymin, ymax)
        ax2.set_xlim(xmin, xmax)
        ax2.set_ylim(ymin, ymax)

        n_importance_samples = 100
        x_range = jnp.linspace(xmin, xmax, 100)
        y_range = jnp.linspace(ymin, ymax, 100)
        X, Y = jnp.meshgrid(x_range, y_range)
        XY = jnp.dstack([X, Y]).reshape((-1, 2))
        XY = jnp.broadcast_to(XY[None, ...],
                              (n_importance_samples, ) + XY.shape)
        outputs = flow.scan_apply(eval_key, {"x": XY})
        outputs["log_px"] = jax.scipy.special.logsumexp(
            outputs["log_px"], axis=0) - jnp.log(n_importance_samples)

        Z = jnp.exp(outputs["log_px"])
        ax3.contourf(X, Y, Z.reshape(X.shape))
        plt.show()

Example #5

    # (Source snippet truncated: the definition of create_network and the
    # opening of create_fun are missing above this point.)
    def create_fun():

        flat_flow = nux.sequential(
            PaddingMultiscaleAndChannel(n_squeeze=2,
                                        output_channel=1,
                                        create_network=create_network),
            nux.UnitGaussianPrior())
        return flat_flow

    rng = random.PRNGKey(1)
    x = random.normal(rng, (10, 8, 8, 3))

    inputs = {"x": x}
    flow = nux.Flow(create_fun, rng, inputs, batch_axes=(0, ))
    print(f"flow.n_params: {flow.n_params}")

    def loss(params, state, key, inputs):
        # Change of variables: log p(x) = log p(z) + log|det J|, so the loss
        # is the mean negative log-likelihood over the batch.
        outputs, _ = flow._apply_fun(params, state, key, inputs)
        log_px = outputs.get("log_pz", 0.0) + outputs.get("log_det", 0.0)
        return -log_px.mean()

    outputs = flow.scan_apply(rng, inputs)
    samples = flow.sample(rng, n_samples=4)
    trainer = nux.MaximumLikelihoodTrainer(flow)

    trainer.grad_step(rng, inputs)
    trainer.grad_step_for_loop(rng, inputs)
    trainer.grad_step_scan_loop(rng, inputs)
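
    # The three gradient-step calls above differ in batching: grad_step
    # consumes a single batch, while the *_for_loop and *_scan_loop variants
    # iterate over an extra leading axis (a Python loop versus jax.lax.scan);
    # compare the doubly-batched calls in Example #3.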