def evaluate(args: Arguments):
    """Evaluate a trained BetaGammaVAE on the MNIST test set.

    Returns a dict with the hyper-parameters, the training step, the mean
    test log-likelihood and a batch of reconstructed images, or ``None``
    when no checkpoint exists at the model path.
    """
    np.random.seed(1)
    tf.random.set_seed(1)
    path, model_path = get_path(args)
    # no checkpoint on disk -> nothing to evaluate
    if not os.path.exists(model_path + '.index'):
        return None
    ds = MNIST()
    model = BetaGammaVAE(
        **get_dense_networks(args),
        gamma=float(args.gamma),
        beta=float(args.beta),
        name=f'Z{args.zdim}B{args.beta}G{args.gamma}'.replace('.', ''))
    model.build(ds.full_shape)
    model.load_weights(model_path, raise_notfound=True, verbose=True)
    #
    test = ds.create_dataset('test', batch_size=32)
    # reconstruct a single batch for visualization
    for batch in test.take(1):
        px, _ = model(batch, training=False)
    image = prepare_images(px.mean().numpy(), True)[0]
    # mean log-likelihood over 200 test batches
    llk = tf.reduce_mean(
        tf.concat([model(xb, training=False)[0].log_prob(xb)
                   for xb in test.take(200)], 0)).numpy()
    return dict(beta=args.beta,
                gamma=args.gamma,
                zdim=args.zdim,
                finetune=args.finetune,
                step=model.step.numpy(),
                llk=llk,
                image=image)
def training(job: Job):
    """Fit a BetaGammaVAE on MNIST for a single hyper-parameter `job`.

    When OVERWRITE is set, every artifact from a previous run is wiped
    first; otherwise the job is skipped if results already exist.
    """
    np.random.seed(1)
    tf.random.set_seed(1)
    path = get_path(job)
    existing = glob.glob(f'{path}*')
    if OVERWRITE:
        # wipe everything produced by a previous run
        for entry in existing:
            remover = shutil.rmtree if os.path.isdir(entry) else os.remove
            remover(entry)
            print('Remove:', entry)
        os.makedirs(path)
    elif len(existing) > 1:
        # results already present and we are not overwriting -> skip
        print('Skip training:', job)
        return
    ds = MNIST()
    train = ds.create_dataset('train', batch_size=32)
    vae = BetaGammaVAE(beta=job.beta, gamma=job.gamma, **networks(job.zdim))
    vae.build(ds.full_shape)
    vae.fit(train,
            learning_rate=1e-3,
            max_iter=80000,
            logdir=path,
            skip_fitted=True)
    vae.save_weights(path, overwrite=True)
def main(cfg: dict):
    """Train one VAE per (posterior, prior, activation) combination on MNIST.

    NOTE(review): relies on module-level names not visible in this chunk
    (`batch_size`, `posteriors_info`, `encoder`, `decoder`, `encoded_size`,
    `max_iter`, `callback`) -- confirm they are defined at import time.
    """
    assert cfg.vae is not None, 'model is not provided in the configuration'
    outdir = get_output_dir()
    # load dataset
    ds = MNIST()
    shape = ds.shape
    train = ds.create_dataset(partition='train',
                              batch_size=batch_size,
                              shuffle=2000,
                              drop_remainder=True)
    valid = ds.create_dataset(partition='valid', batch_size=batch_size)
    x_test, y_test = ds.numpy(partition='test',
                              shuffle=None,
                              label_percent=True)
    # map one-hot test labels back to their string names
    y_test = ds.labels[np.argmax(y_test, axis=-1)]
    # create the model: one instance (and output folder) per combination
    vae_class = vi.get_vae(cfg.vae)
    for i, (posterior, prior,
            activation) in enumerate(product(*posteriors_info)):
        name = f"{posterior}_{prior.name}_{activation}"
        path = os.path.join(outdir, name)
        if not os.path.exists(path):
            os.makedirs(path)
        model_path = os.path.join(path, 'model')
        vae = vae_class(encoder=encoder.create_network(),
                        decoder=decoder.create_network(),
                        observation=vi.RVconf(shape,
                                              'bernoulli',
                                              projection=True,
                                              name='Image'),
                        latents=vi.RVconf(
                            encoded_size,
                            posterior,
                            projection=True,
                            prior=prior,
                            kwargs=dict(scale_activation=activation),
                            name='Latents'),
                        analytic=False,
                        path=model_path,
                        name=name)
        vae.build((None, ) + shape)
        # resume from an existing checkpoint when available
        vae.load_weights()
        # skip_fitted=True: training is a no-op if the model already reached
        # max_iter; save the (possibly updated) weights afterwards
        vae.fit(train=train,
                valid=valid,
                max_iter=max_iter,
                valid_freq=1000,
                compile_graph=True,
                skip_fitted=True,
                on_valid_end=partial(callback, vae=vae, x=x_test, y=y_test),
                logdir=path,
                track_gradients=True).save_weights()
def test_vae_y(args):
    """Train/evaluate a label-only BetaGammaVAE (digits -> digits) on MNIST.

    `args` is a (gamma, beta) pair; returns a dict with accuracy,
    log-likelihood, KL divergence, active units and the hyper-parameters.
    """
    gamma, beta = args
    ds = MNIST()
    # label-only datasets: train yields y, valid yields (y, y) pairs
    train_y = ds.create_dataset('train', label_percent=1.0).map(lambda x, y: y)
    valid_y = ds.create_dataset('valid',
                                label_percent=1.0).map(lambda x, y: (y, y))
    basedir = os.path.join(root_path, 'vaey')
    save_path = os.path.join(basedir, f'{gamma}_{beta}')
    logdir = os.path.join(basedir, f'{gamma}_{beta}_log')

    def _mlp(name):
        # two 256-unit ReLU layers
        return Sequential([Dense(256, 'relu'), Dense(256, 'relu')], name=name)

    vae_y = BetaGammaVAE(encoder=_mlp('Encoder'),
                         decoder=_mlp('Decoder'),
                         latents=MVNDiagLatents(10),
                         observation=DistributionDense([10],
                                                       posterior='onehot',
                                                       projection=True,
                                                       name='Digits'),
                         gamma=gamma,
                         beta=beta)
    vae_y.build((None, 10))
    vae_y.load_weights(save_path)
    vae_y.fit(train_y, max_iter=20000, logdir=logdir, skip_fitted=True)
    vae_y.save_weights(save_path)
    # evaluation and figures
    gym = DisentanglementGym(model=vae_y, valid=valid_y)
    with gym.run_model(partition='valid'):
        y_true = np.argmax(gym.y_true, -1)
        y_pred = np.argmax(gym.px_z[0].mode(), -1)
        results = dict(acc=accuracy_score(y_true, y_pred),
                       llk=gym.log_likelihood()[0],
                       kl=gym.kl_divergence()[0],
                       au=gym.active_units()[0],
                       gamma=gamma,
                       beta=beta)
        gym.plot_correlation()
        gym.plot_latents_stats()
        gym.plot_latents_tsne()
        gym.save_figures(save_path + '.pdf', verbose=True)
    return results
def load_vae_eval(job: Job):
    """Restore a frozen, trained BetaGammaVAE for evaluation.

    Returns (dataset, vae) on success, or (None, None) when no checkpoint
    is found at the job's path.
    """
    np.random.seed(1)
    tf.random.set_seed(1)
    ds = MNIST()
    weights_path = get_path(job)
    vae = BetaGammaVAE(beta=job.beta, gamma=job.gamma, **networks(job.zdim))
    vae.build(ds.full_shape)
    # evaluation only: freeze every parameter
    vae.trainable = False
    try:
        vae.load_weights(weights_path, verbose=False, raise_notfound=True)
    except FileNotFoundError:
        return None, None
    return ds, vae
def train(args: Arguments):
    """Train a BetaGammaVAE on MNIST, checkpointing on the best validation
    log-likelihood.

    With ``args.finetune`` the run is two-stage: full training, then the
    decoder/observation weights are reset and fine-tuned with the encoder
    and latents frozen.
    """
    np.random.seed(1)
    tf.random.set_seed(1)
    path, model_path = get_path(args)
    ds = MNIST()
    model = BetaGammaVAE(
        **get_dense_networks(args),
        gamma=float(args.gamma),
        beta=float(args.beta),
        name=f'Z{args.zdim}B{args.beta}G{args.gamma}'.replace(
            '.', ''))
    model.build(ds.full_shape)
    print(model)
    optim1 = tf.optimizers.Adam(learning_rate=5e-4)  # main training
    optim2 = tf.optimizers.Adam(learning_rate=1e-4)  # fine-tuning
    # === 0. helper
    # best_llk = [best validation llk so far, step at which it was reached]
    best_llk = [-np.inf, 0]
    valid = ds.create_dataset('valid')

    def callback():
        # mean validation log-likelihood over 100 batches
        llk = tf.reduce_mean(
            tf.concat([model(x)[0].log_prob(x) for x in valid.take(100)],
                      0)).numpy()
        # keep only the best-performing checkpoint
        if llk > best_llk[0]:
            best_llk[0] = llk
            best_llk[1] = model.step.numpy()
            model.trainer.print('*Save weights at:', model_path)
            model.save_weights(model_path, overwrite=True)
        model.trainer.print(
            f'Current:{llk:.2f} Best:{best_llk[0]:.2f} Step:{int(best_llk[1])}'
        )
        # print shapes of internal ('_'-prefixed) training metrics
        for k, v in model.last_train_metrics.items():
            if '_' == k[0]:
                print(k, v.shape)

    # === 1. training
    train_kw = dict(on_valid_end=callback,
                    valid_interval=30,
                    track_gradients=False)

    def train_ds():
        # fresh training dataset per fit() call
        return ds.create_dataset('train', batch_size=BS)

    ## two-stage training
    if args.finetune:
        # remember the freshly-initialized generative-half weights
        initial_weights = [
            model.decoder.get_weights(),
            model.observation.get_weights()
        ]
        model.fit(train_ds(),
                  max_iter=MAX_ITER // 2,
                  optimizer=optim1,
                  **train_kw)
        # reset decoder/observation and freeze the inference network
        model.decoder.set_weights(initial_weights[0])
        model.observation.set_weights(initial_weights[1])
        model.encoder.trainable = False
        model.latents.trainable = False
        print('Fine-tuning .....')
        # NOTE(review): max_iter here is presumably a cumulative step target
        # (MAX_ITER//2 already done + MAX_ITER//4 of fine-tuning) -- confirm
        # against the trainer's step-counting semantics.
        model.fit(train_ds(),
                  max_iter=MAX_ITER // 2 + MAX_ITER // 4,
                  optimizer=optim2,
                  **train_kw)
    ## full training
    else:
        model.fit(train_ds(),
                  max_iter=MAX_ITER,
                  optimizer=optim1,
                  **train_kw)
# Two-stage VAE experiment on MNIST: load (or train) the model, then
# collect latent statistics for the analysis further down the script.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from odin.bay import TwoStageVAE, plot_latent_stats
from odin.fuel import MNIST
from odin.networks import get_networks
from odin import visual as vs
from tqdm import tqdm
from odin.ml import fast_tsne

ds = MNIST()
train = ds.create_dataset('train', batch_size=32)
# labelled validation batches with fixed size (for plotting grids)
valid = ds.create_dataset('valid',
                          batch_size=36,
                          label_percent=1.0,
                          drop_remainder=True)
vae = TwoStageVAE(**get_networks(ds.name))
vae.build(ds.full_shape)
# manual toggle: True loads the pre-trained checkpoint, False trains from
# scratch, saves, and exits
if True:
    vae.load_weights('/tmp/twostagevae', verbose=True, raise_notfound=True)
else:
    vae.fit(train, learning_rate=1e-3, max_iter=300000)
    vae.save_weights('/tmp/twostagevae')
    exit()
# accumulators filled later in the script
# NOTE(review): the meaning of U and Z_hat (second-stage latents vs.
# reconstructed latents?) depends on code past this chunk -- confirm there.
Z = []
U = []
Z_hat = []
Y = []
def main(cfg: dict):
    """Train a beta-VAE on an image dataset, comparing the raw TFP model API
    (``cfg.vae == 'tfp'``) against the odin VAE API.

    NOTE(review): relies on module-level names not visible in this chunk
    (`batch_size`, `encoded_size`, `learning_rate`, `max_iter`,
    `create_encoder`, `create_decoder`, `fast_umap`, ...) -- confirm they
    are defined at import time.
    """
    assert cfg.vae is not None and cfg.beta is not None, \
        f'Invalid arguments: {cfg}'
    # dataset selection
    if cfg.ds == 'bmnist':
        ds = BinarizedMNIST()
    elif cfg.ds == 'mnist':
        ds = MNIST()
    elif cfg.ds == 'fmnist':
        ds = FashionMNIST()
    else:
        raise NotImplementedError(f'No support for dataset with name={cfg.ds}')
    input_shape = ds.shape
    train = ds.create_dataset(partition='train', batch_size=batch_size)
    valid = ds.create_dataset(partition='valid', batch_size=batch_size)
    x_test, y_test = ds.numpy(partition='test',
                              batch_size=batch_size,
                              shuffle=1000,
                              label_percent=1.0)
    # map one-hot labels back to class names
    y_test = ds.labels[np.argmax(y_test, axis=-1)]
    ## create the prior and the network
    pz = tfd.Sample(tfd.Normal(loc=0, scale=1), sample_shape=encoded_size)
    z_samples = pz.sample(16)
    encoder = create_encoder(input_shape)
    decoder = create_decoder()
    ## create the model
    # tfp model API: plain Keras model returning (reconstruction, latents)
    if cfg.vae == 'tfp':
        encoder.append(tfpl.MultivariateNormalTriL(encoded_size))
        encoder = keras.Sequential(encoder, name='encoder')
        decoder.append(tfpl.IndependentBernoulli(input_shape))
        decoder = keras.Sequential(decoder, name="decoder")
        vae = keras.Model(
            inputs=encoder.inputs,
            outputs=[decoder(encoder.outputs[0]), encoder.outputs[0]],
            name='tfp_vae')
    # odin model API
    else:
        encoder = keras.Sequential(encoder, name='encoder')
        decoder = keras.Sequential(decoder, name="decoder")
        vae = get_vae(cfg.vae)(
            encoder=encoder,
            decoder=decoder,
            # latents=tfpl.MultivariateNormalTriL(encoded_size),
            latents=RVconf(event_shape=(encoded_size, ),
                           posterior='mvntril',
                           projection=False,
                           name="latents"),
            observation=RVconf(event_shape=input_shape,
                               posterior="bernoulli",
                               projection=False,
                               name="image"),
            name=f'odin_{cfg.vae}')
    ### training the model
    vae.build(input_shape=(None, ) + input_shape)
    params = vae.trainable_variables
    opt = tf.optimizers.Adam(learning_rate=learning_rate)

    def optimize(x, training=None):
        # Manual beta-ELBO: single-sample MC estimate of the KL plus the
        # reconstruction NLL; gradients applied only when training.
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            if training:
                tape.watch(params)
            px, qz = vae(x, training=training)
            z = qz._value()
            kl = tf.reduce_mean(qz.log_prob(z) - pz.log_prob(z), axis=-1)
            nll = -tf.reduce_mean(px.log_prob(x), axis=-1)
            loss = nll + cfg.beta * kl
        if training:
            grads = tape.gradient(loss, params)
            grads_params = [
                (g, p) for g, p in zip(grads, params) if g is not None
            ]
            opt.apply_gradients(grads_params)
            # NOTE(review): grads_params holds (gradient, parameter) pairs,
            # so `for p, g` binds p=gradient and g=parameter -- the reported
            # norm is the parameter's, keyed by the gradient tensor's name.
            # The unpack order looks swapped; confirm intended behavior.
            grads = {
                f'_grad/{p.name}': tf.linalg.norm(g) for p, g in grads_params
            }
        else:
            grads = dict()
        return loss, dict(nll=nll, kl=kl, **grads)

    def callback():
        """Validation-time diagnostics: latent statistics, reconstructions,
        a UMAP of the latent space, and encoder/decoder gradient norms."""
        trainer = get_current_trainer()
        x, y = x_test[:1000], y_test[:1000]
        px, qz = vae(x, training=False)
        # latents: per-dimension posterior mean/std averaged over the batch
        qz_mean = tf.reduce_mean(qz.mean(), axis=0)
        qz_std = tf.reduce_mean(qz.stddev(), axis=0)
        # NOTE(review): assumes the decoder's first variable is a rank-4
        # kernel (axes 0-2 reduced, last axis kept) -- confirm its shape.
        w = tf.reduce_sum(decoder.trainable_variables[0], axis=(0, 1, 2))
        # plot the latents and its weights
        fig = plt.figure(figsize=(6, 4), dpi=200)
        ax = plt.gca()
        l1 = ax.plot(qz_mean,
                     label='mean',
                     linewidth=1.0,
                     linestyle='--',
                     marker='o',
                     markersize=4,
                     color='r',
                     alpha=0.5)
        l2 = ax.plot(qz_std,
                     label='std',
                     linewidth=1.0,
                     linestyle='--',
                     marker='o',
                     markersize=4,
                     color='g',
                     alpha=0.5)
        # second y-axis so the weight scale does not dwarf the stats
        ax1 = ax.twinx()
        l3 = ax1.plot(w,
                      label='weight',
                      linewidth=1.0,
                      linestyle='--',
                      marker='o',
                      markersize=4,
                      color='b',
                      alpha=0.5)
        lines = l1 + l2 + l3
        labs = [l.get_label() for l in lines]
        ax.grid(True)
        ax.legend(lines, labs)
        img_qz = vs.plot_to_image(fig)
        # reconstruction
        fig = plt.figure(figsize=(5, 5), dpi=120)
        vs.plot_images(np.squeeze(px.mean().numpy()[:25], axis=-1),
                       grids=(5, 5))
        img_res = vs.plot_to_image(fig)
        # latents scatter (2D projection)
        fig = plt.figure(figsize=(5, 5), dpi=200)
        z = fast_umap(qz.mean().numpy())
        vs.plot_scatter(z, color=y, size=12.0, alpha=0.4)
        img_umap = vs.plot_to_image(fig)
        # gradients: aggregate the '_grad/' metrics logged by optimize()
        grads = [(k, v) for k, v in trainer.last_train_metrics.items()
                 if '_grad/' in k]
        encoder_grad = sum(v for k, v in grads if 'Encoder' in k)
        decoder_grad = sum(v for k, v in grads if 'Decoder' in k)
        return dict(reconstruct=img_res,
                    umap=img_umap,
                    latents=img_qz,
                    qz_mean=qz_mean,
                    qz_std=qz_std,
                    w_decoder=w,
                    llk_test=tf.reduce_mean(px.log_prob(x)),
                    encoder_grad=encoder_grad,
                    decoder_grad=decoder_grad)

    ### Create trainer and fit
    trainer = Trainer(logdir=get_output_dir())
    trainer.fit(train_ds=train.repeat(-1),
                optimize=optimize,
                valid_ds=valid,
                max_iter=max_iter,
                compile_graph=True,
                log_tag=f'{cfg.vae}_{cfg.beta}',
                on_valid_end=callback,
                valid_freq=1000)
# Conditional M2 VAE (semi-supervised) experiment on MNIST.
# NOTE(review): `os`, `tf` and `np` are used below but not imported in this
# chunk -- presumably imported earlier in the file; confirm.
from odin import backend as bk
from odin.bay.vi.autoencoder import (ConditionalM2VAE, FactorDiscriminator,
                                     ImageNet, create_image_autoencoder)
from odin.bay.vi.utils import marginalize_categorical_labels
from odin.fuel import MNIST, STL10, CelebA, LegoFaces, Shapes3D, dSprites
from odin.networks import (ConditionalEmbedding, ConditionalProjection,
                           RepeaterEmbedding, SkipConnection,
                           get_conditional_embedding, skip_connect)

# silence TF logging and avoid grabbing all GPU memory up-front
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
tf.random.set_seed(1)
np.random.seed(1)

ds = MNIST()
# semi-supervised: only 50% of the training labels are kept
train = ds.create_dataset(partition='train', inc_labels=0.5)
test = ds.create_dataset(partition='test', inc_labels=True)
# autoencoder input is (28, 28, 2) -- presumably image plus a label
# channel -- while the reconstruction target is the (28, 28, 1) image;
# confirm against create_image_autoencoder.
encoder, decoder = create_image_autoencoder(image_shape=(28, 28, 1),
                                            input_shape=(28, 28, 2),
                                            center0=True,
                                            latent_shape=20)
vae = ConditionalM2VAE(encoder=encoder,
                       decoder=decoder,
                       conditional_embedding='embed',
                       alpha=0.1 * 10)
vae.fit(train, compile_graph=True, epochs=-1, max_iter=8000, sample_shape=())
x = vae.sample_data(16)
y = vae.sample_labels(16)
# NOTE(review): statement continues in the next chunk of the file
m = tf.cast(
# =========================================================================== if args.ds == "cortex": sc = Cortex() elif args.ds == "embryo": sc = HumanEmbryos() elif args.ds == "pbmc5k": sc = PBMC('5k') elif args.ds == "pbmc10k": sc = PBMC('10k') elif args.ds == "news20": sc = Newsgroup20() elif args.ds == "news5": sc = Newsgroup5() elif args.ds == "mnist": raise NotImplementedError sc = MNIST() else: raise NotImplementedError(args.ds) shape = sc.shape train = sc.create_dataset(batch_size=batch_size, partition='train') valid = sc.create_dataset(batch_size=batch_size, partition='valid') test = sc.create_dataset(batch_size=batch_size, partition='test') # concat both valid and test to final evaluation set X_test = [] y_test = [] for x, y in sc.create_dataset(batch_size=batch_size, partition='valid', inc_labels=True).concatenate( sc.create_dataset(batch_size=batch_size, partition='test',