def train_model(create_model, args, classification=False, image=False, trainer_fun=None):
  # The save path should be a directory, not a pickle file
  assert not args.save_path.endswith(".pickle")

  init_key = random.PRNGKey(args.init_key_seed)
  train_key = random.PRNGKey(args.train_key_seed)
  eval_key = random.PRNGKey(args.eval_key_seed)

  train_ds, get_test_ds = get_dataset(args.dataset,
                                      args.batch_size,
                                      args.n_batches,
                                      args.test_batch_size,
                                      args.test_n_batches,
                                      quantize_bits=args.quantize_bits,
                                      classification=classification,
                                      label_keep_percent=args.label_keep_percent,
                                      random_label_percent=args.random_label_percent)

  # Pull a doubly-batched chunk and initialize the flow on a single batch of it
  doubly_batched_inputs = next(train_ds)
  inputs = jax.tree_map(lambda x: x[0], doubly_batched_inputs)

  flow = nux.Flow(create_model, init_key, inputs, batch_axes=(0,))
  print("n_params", flow.n_params)

  # Make sure that the save_path folder exists
  pathlib.Path(args.save_path).mkdir(parents=True, exist_ok=True)

  trainer = initialize_trainer(flow,
                               clip=args.clip,
                               lr=args.lr,
                               warmup=args.warmup,
                               cosine_decay_steps=args.cosine_decay_steps,
                               save_path=args.save_path,
                               retrain=args.retrain,
                               classification=classification,
                               image=image,
                               trainer_fun=trainer_fun)

  return train(train_key,
               eval_key,
               trainer,
               train_ds,
               get_test_ds,
               max_iters=args.max_iters,
               save_path=args.save_path,
               eval_interval=args.eval_interval,
               image=image,
               classification=classification)
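# Hedged usage sketch (not from the original file): train_model only reads
# attributes off `args`, so an argparse.Namespace works.  All values below are
# hypothetical placeholders, and `create_my_model` stands in for a real model
# constructor.
from argparse import Namespace

example_args = Namespace(save_path="results/run0",  # a directory, not a .pickle file
                         init_key_seed=0, train_key_seed=1, eval_key_seed=2,
                         dataset="cifar10", batch_size=64, n_batches=1000,
                         test_batch_size=64, test_n_batches=100,
                         quantize_bits=8, label_keep_percent=1.0,
                         random_label_percent=0.0, clip=5.0, lr=1e-3,
                         warmup=1000, cosine_decay_steps=100000, retrain=False,
                         max_iters=100000, eval_interval=1000)
# train_model(create_my_model, example_args, image=True)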
def evaluate_image_model(create_model, args, classification=False):
  assert not args.save_path.endswith(".pickle")

  init_key = random.PRNGKey(args.init_key_seed)
  train_key = random.PRNGKey(args.train_key_seed)
  eval_key = random.PRNGKey(args.eval_key_seed)

  train_ds, get_test_ds = get_dataset(args.dataset,
                                      args.batch_size,
                                      args.n_batches,
                                      args.test_batch_size,
                                      args.test_n_batches,
                                      quantize_bits=args.quantize_bits,
                                      classification=classification,
                                      label_keep_percent=1.0,
                                      random_label_percent=0.0)

  doubly_batched_inputs = next(train_ds)
  inputs = {"x": doubly_batched_inputs["x"][0]}
  if "y" in doubly_batched_inputs:
    inputs["y"] = doubly_batched_inputs["y"][0]

  flow = nux.Flow(create_model, init_key, inputs, batch_axes=(0,))
  print("n_params", flow.n_params)

  trainer = initialize_trainer(flow,
                               clip=args.clip,
                               lr=args.lr,
                               warmup=args.warmup,
                               cosine_decay_steps=args.cosine_decay_steps,
                               save_path=args.save_path,
                               retrain=args.retrain,
                               train_args=args.train_args,
                               classification=classification)

  # Generate reconstructions.  Perturb the latents with unit Gaussian noise
  # before inverting the flow.
  outputs = flow.apply(init_key, inputs, is_training=False)
  outputs["x"] += random.normal(init_key, outputs["x"].shape)
  reconstr = flow.reconstruct(init_key, outputs, generate_image=True)

  # Plot the reconstructions
  fig, axes = plt.subplots(4, 12)
  axes = axes.ravel()
  for i, ax in enumerate(axes[:8]):
    ax.imshow(reconstr["image"][i].squeeze())

  # Generate some interpolations between pairs of latent vectors
  interp = jax.vmap(partial(util.spherical_interpolation, N=4))(outputs["x"][:4], outputs["x"][4:8])
  interp = interp.reshape((-1,) + flow.latent_shape)
  interpolations = flow.reconstruct(init_key, {"x": interp}, generate_image=True)
  for i, ax in enumerate(axes[8:16]):
    ax.imshow(interpolations["image"][i].squeeze())

  # Generate samples
  samples = flow.sample(eval_key, n_samples=axes.size - 16, generate_image=True)
  for i, ax in enumerate(axes[16:]):
    ax.imshow(samples["image"][i].squeeze())

  plt.show()

  # Evaluate the test set in bits per dimension
  test_losses = sorted(trainer.test_losses.items(), key=lambda x: x[0])
  test_losses = jnp.array(test_losses)

  test_ds = get_test_ds()
  res = trainer.evaluate_test(eval_key, test_ds, bits_per_dim=True)
  print("test", trainer.summarize_losses_and_aux(res))
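# Hedged sketch (new code, not from this repo) of what a spherical
# interpolation helper like util.spherical_interpolation(a, b, N=...) might
# compute: N points along the great-circle arc between two latent vectors.
# The real helper's implementation is not shown here, so treat this purely as
# illustration; it assumes a and b are not (anti-)parallel.
import jax.numpy as jnp

def slerp(a, b, N):
  a_hat = a / jnp.linalg.norm(a)
  b_hat = b / jnp.linalg.norm(b)
  # Angle between the (normalized) endpoints
  omega = jnp.arccos(jnp.clip(jnp.dot(a_hat, b_hat), -1.0, 1.0))
  t = jnp.linspace(0.0, 1.0, N)[:, None]
  return (jnp.sin((1 - t) * omega) * a + jnp.sin(t * omega) * b) / jnp.sin(omega)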
def block():
  return nux.sequential(Dense(), ShiftScale())

def create_fun(should_repeat=True, n_repeats=2):
  if should_repeat:
    repeated = repeat(block, n_repeats=n_repeats)
  else:
    repeated = nux.sequential(*[block() for _ in range(n_repeats)])
  return repeated
  # return sequential(ShiftScale(),
  #                   sequential(repeated),
  #                   Scale(0.2))

rng = random.PRNGKey(1)
x = random.normal(rng, (10, 7, 3))
inputs = {"x": x[0]}

flow = nux.Flow(create_fun, rng, inputs, batch_axes=(0,))
outputs1 = flow.apply(rng, inputs)
outputs2 = flow.apply(rng, inputs, no_scan=True)

doubly_batched_inputs = {"x": x}

trainer = nux.MaximumLikelihoodTrainer(flow)
trainer.grad_step(rng, inputs)
trainer.grad_step_for_loop(rng, doubly_batched_inputs)
trainer.grad_step_scan_loop(rng, doubly_batched_inputs)

rng = random.PRNGKey(1)
x = random.normal(rng, (10, 3))
inputs = {"x": x}

flow = nux.Flow(partial(create_fun, should_repeat=False),
                rng, inputs, batch_axes=(0,))
def evaluate_2d_model(create_model, args, classification=False):
  assert not args.save_path.endswith(".pickle")

  init_key = random.PRNGKey(args.init_key_seed)
  train_key = random.PRNGKey(args.train_key_seed)
  eval_key = random.PRNGKey(args.eval_key_seed)

  train_ds, get_test_ds = get_dataset(args.dataset,
                                      args.batch_size,
                                      args.n_batches,
                                      args.test_batch_size,
                                      args.test_n_batches,
                                      quantize_bits=args.quantize_bits,
                                      classification=classification,
                                      label_keep_percent=1.0,
                                      random_label_percent=0.0)

  doubly_batched_inputs = next(train_ds)
  inputs = {"x": doubly_batched_inputs["x"][0]}
  if "y" in doubly_batched_inputs:
    inputs["y"] = doubly_batched_inputs["y"][0]

  flow = nux.Flow(create_model, init_key, inputs, batch_axes=(0,))
  outputs = flow.apply(init_key, inputs)
  print("n_params", flow.n_params)

  trainer = initialize_trainer(flow,
                               clip=args.clip,
                               lr=args.lr,
                               warmup=args.warmup,
                               cosine_decay_steps=args.cosine_decay_steps,
                               save_path=args.save_path,
                               retrain=args.retrain,
                               train_args=args.train_args,
                               classification=classification)

  test_losses = sorted(trainer.test_losses.items(), key=lambda x: x[0])
  test_losses = jnp.array(test_losses)

  test_ds = get_test_ds()
  res = trainer.evaluate_test(eval_key, test_ds)
  print("test", trainer.summarize_losses_and_aux(res))

  # Plot samples
  samples = flow.sample(eval_key, n_samples=5000, manifold_sample=True)

  # Find the spread of the data
  data = doubly_batched_inputs["x"].reshape((-1, 2))
  (xmin, ymin), (xmax, ymax) = data.min(axis=0), data.max(axis=0)
  xspread, yspread = xmax - xmin, ymax - ymin
  xmin -= 0.25 * xspread
  xmax += 0.25 * xspread
  ymin -= 0.25 * yspread
  ymax += 0.25 * yspread

  # Plot the learned samples against the true samples and also a density plot.
  # The density on a grid is estimated by averaging p(x) over
  # n_importance_samples stochastic forward passes, computed in log space.
  if "prediction" in samples:
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(28, 7))
    ax1.scatter(*data.T)
    ax1.set_title("True Samples")
    ax2.scatter(*samples["x"].T, alpha=0.2, s=3, c=samples["prediction"])
    ax2.set_title("Learned Samples")
    ax1.set_xlim(xmin, xmax)
    ax1.set_ylim(ymin, ymax)
    ax2.set_xlim(xmin, xmax)
    ax2.set_ylim(ymin, ymax)

    n_importance_samples = 100
    x_range, y_range = jnp.linspace(xmin, xmax, 100), jnp.linspace(ymin, ymax, 100)
    X, Y = jnp.meshgrid(x_range, y_range)
    XY = jnp.dstack([X, Y]).reshape((-1, 2))
    XY = jnp.broadcast_to(XY[None, ...], (n_importance_samples,) + XY.shape)

    outputs = flow.scan_apply(eval_key, {"x": XY})
    outputs["log_px"] = jax.scipy.special.logsumexp(outputs["log_px"], axis=0) - jnp.log(n_importance_samples)
    outputs["prediction"] = jnp.mean(outputs["prediction"], axis=0)

    Z = jnp.exp(outputs["log_px"])
    ax3.contourf(X, Y, Z.reshape(X.shape))
    ax4.contourf(X, Y, outputs["prediction"].reshape(X.shape))
    plt.show()
  else:
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(21, 7))
    ax1.scatter(*data.T)
    ax1.set_title("True Samples")
    ax2.scatter(*samples["x"].T, alpha=0.2, s=3)
    ax2.set_title("Learned Samples")
    ax1.set_xlim(xmin, xmax)
    ax1.set_ylim(ymin, ymax)
    ax2.set_xlim(xmin, xmax)
    ax2.set_ylim(ymin, ymax)

    n_importance_samples = 100
    x_range, y_range = jnp.linspace(xmin, xmax, 100), jnp.linspace(ymin, ymax, 100)
    X, Y = jnp.meshgrid(x_range, y_range)
    XY = jnp.dstack([X, Y]).reshape((-1, 2))
    XY = jnp.broadcast_to(XY[None, ...], (n_importance_samples,) + XY.shape)

    outputs = flow.scan_apply(eval_key, {"x": XY})
    outputs["log_px"] = jax.scipy.special.logsumexp(outputs["log_px"], axis=0) - jnp.log(n_importance_samples)

    Z = jnp.exp(outputs["log_px"])
    ax3.contourf(X, Y, Z.reshape(X.shape))
    plt.show()
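# A tiny standalone check (new code, not from the original file) of the
# identity used above for averaging densities in log space:
#   mean_i p_i = exp(logsumexp(log p_i) - log N)
import jax.numpy as jnp
from jax.scipy.special import logsumexp

p = jnp.array([0.1, 0.2, 0.4])
avg = jnp.exp(logsumexp(jnp.log(p), axis=0) - jnp.log(p.shape[0]))
assert jnp.allclose(avg, p.mean())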
def create_fun():
  #                squeeze_excite=False,
  #                zero_init=False)
  flat_flow = nux.sequential(PaddingMultiscaleAndChannel(n_squeeze=2,
                                                         output_channel=1,
                                                         create_network=create_network),
                             nux.UnitGaussianPrior())
  return flat_flow

rng = random.PRNGKey(1)
# x = random.normal(rng, (10, 8))
x = random.normal(rng, (10, 8, 8, 3))
inputs = {"x": x}

flow = nux.Flow(create_fun, rng, inputs, batch_axes=(0,))
print(f"flow.n_params: {flow.n_params}")

# Negative log likelihood of the flow, written against the raw apply function
def loss(params, state, key, inputs):
  outputs, _ = flow._apply_fun(params, state, key, inputs)
  log_px = outputs.get("log_pz", 0.0) + outputs.get("log_det", 0.0)
  return -log_px.mean()

outputs = flow.scan_apply(rng, inputs)
samples = flow.sample(rng, n_samples=4)

trainer = nux.MaximumLikelihoodTrainer(flow)
trainer.grad_step(rng, inputs)
trainer.grad_step_for_loop(rng, inputs)
trainer.grad_step_scan_loop(rng, inputs)
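# Hedged usage sketch for the hand-rolled `loss` above: assuming the
# initialized Flow exposes `params` and `state` attributes (an assumption
# consistent with the `_apply_fun(params, state, key, inputs)` signature used
# in `loss`), gradients could be taken directly with jax.grad:
#
# grads = jax.grad(loss)(flow.params, flow.state, rng, inputs)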