## Variant 1: grid-search over latent posteriors/priors with the odin VAE API.
# NOTE: assumes module-level names from the surrounding script: os, np (numpy),
# partial (functools), product (itertools), the odin helpers `vi` and
# `get_output_dir`, the dataset/network builders `MNIST`, `encoder`, `decoder`,
# the constants `batch_size`, `encoded_size`, `max_iter`, `posteriors_info`,
# and the validation `callback`.
def main(cfg: dict):
  assert cfg.vae is not None, 'model is not provided in the configuration'
  outdir = get_output_dir()
  # load dataset
  ds = MNIST()
  shape = ds.shape
  train = ds.create_dataset(partition='train',
                            batch_size=batch_size,
                            shuffle=2000,
                            drop_remainder=True)
  valid = ds.create_dataset(partition='valid', batch_size=batch_size)
  x_test, y_test = ds.numpy(partition='test', shuffle=None, label_percent=True)
  y_test = ds.labels[np.argmax(y_test, axis=-1)]
  # create the model: one VAE per (posterior, prior, scale-activation) triplet
  vae_class = vi.get_vae(cfg.vae)
  for i, (posterior, prior, activation) in enumerate(product(*posteriors_info)):
    name = f'{posterior}_{prior.name}_{activation}'
    path = os.path.join(outdir, name)
    if not os.path.exists(path):
      os.makedirs(path)
    model_path = os.path.join(path, 'model')
    vae = vae_class(
        encoder=encoder.create_network(),
        decoder=decoder.create_network(),
        observation=vi.RVconf(shape, 'bernoulli', projection=True,
                              name='Image'),
        latents=vi.RVconf(encoded_size,
                          posterior,
                          projection=True,
                          prior=prior,
                          kwargs=dict(scale_activation=activation),
                          name='Latents'),
        analytic=False,  # Monte Carlo KL instead of the closed form
        path=model_path,
        name=name)
    vae.build((None,) + shape)
    vae.load_weights()  # resume from an existing checkpoint if one exists
    vae.fit(train=train,
            valid=valid,
            max_iter=max_iter,
            valid_freq=1000,
            compile_graph=True,
            skip_fitted=True,
            on_valid_end=partial(callback, vae=vae, x=x_test, y=y_test),
            logdir=path,
            track_gradients=True).save_weights()
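

# The `analytic=False` flag above requests a Monte Carlo estimate of the KL
# term rather than its closed form. A minimal, self-contained TFP sketch of the
# two estimators (names below are illustrative, not part of this script):
import tensorflow_probability as tfp

tfd = tfp.distributions


def kl_estimators():
  qz = tfd.MultivariateNormalDiag(loc=[0.5, -0.5], scale_diag=[1.2, 0.8])
  pz = tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 1.])
  # exact KL, available only when a closed-form formula is registered
  kl_closed_form = tfd.kl_divergence(qz, pz)
  # single-sample Monte Carlo estimate, usable for any posterior/prior pair
  z = qz.sample()
  kl_monte_carlo = qz.log_prob(z) - pz.log_prob(z)
  return kl_closed_form, kl_monte_carlo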
## Variant 2: manual training loop with an explicit beta-VAE objective,
## comparing the plain tfp/keras model API against the odin model API.
# NOTE: assumes module-level names from the surrounding script: np, tf, keras,
# plt (matplotlib.pyplot), tfd/tfpl (tensorflow_probability distributions and
# layers), the odin helpers `vs`, `get_vae`, `RVmeta`, `Trainer`,
# `get_current_trainer`, `get_output_dir`, `fast_umap`, the network builders
# `create_encoder`/`create_decoder`, and the constants `batch_size`,
# `encoded_size`, `max_iter`.
def main(cfg: dict):
  assert cfg.vae is not None and cfg.beta is not None, \
      f'Invalid arguments: {cfg}'
  if cfg.ds == 'mnist':
    ds = MNIST()
  elif cfg.ds == 'fmnist':
    ds = FashionMNIST()
  else:
    raise NotImplementedError(f'No support for dataset with name={cfg.ds}')
  input_shape = ds.shape
  train = ds.create_dataset(partition='train', batch_size=batch_size)
  valid = ds.create_dataset(partition='valid', batch_size=batch_size)
  x_test, y_test = ds.numpy(partition='test',
                            batch_size=batch_size,
                            shuffle=1000,
                            inc_labels=1.0)
  y_test = ds.labels[np.argmax(y_test, axis=-1)]
  ## create the prior and the networks
  pz = tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=encoded_size)
  z_samples = pz.sample(16)
  encoder = create_encoder(input_shape)
  decoder = create_decoder()
  ## create the model
  # tfp model API: a plain keras model returning (reconstruction, posterior)
  if cfg.vae == 'tfp':
    encoder.append(tfpl.MultivariateNormalTriL(encoded_size))
    encoder = keras.Sequential(encoder, name='Encoder')
    decoder.append(tfpl.IndependentBernoulli(input_shape))
    decoder = keras.Sequential(decoder, name='Decoder')
    vae = keras.Model(
        inputs=encoder.inputs,
        outputs=[decoder(encoder.outputs[0]), encoder.outputs[0]],
        name='VAE')
  # odin model API
  else:
    encoder = keras.Sequential(encoder, name='Encoder')
    decoder = keras.Sequential(decoder, name='Decoder')
    vae = get_vae(cfg.vae)(
        encoder=encoder,
        decoder=decoder,
        # latents=tfpl.MultivariateNormalTriL(encoded_size),
        latents=RVmeta(event_shape=(encoded_size,),
                       posterior='mvntril',
                       projection=False,
                       name='Latent'),
        observation=RVmeta(event_shape=input_shape,
                           posterior='bernoulli',
                           projection=False,
                           name='Image'),
    )
  ### training the model
  vae.build(input_shape=(None,) + input_shape)
  params = vae.trainable_variables
  opt = tf.optimizers.Adam(learning_rate=1e-3)

  def optimize(x, training=None):
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      if training:
        tape.watch(params)
      px, qz = vae(x, training=training)
      z = qz._value()  # the sample drawn by the tensor-coercible posterior
      # single-sample Monte Carlo KL and the beta-VAE objective
      kl = tf.reduce_mean(qz.log_prob(z) - pz.log_prob(z), axis=-1)
      nll = -tf.reduce_mean(px.log_prob(x), axis=-1)
      loss = nll + cfg.beta * kl
    if training:
      grads = tape.gradient(loss, params)
      grads_params = [(g, p) for g, p in zip(grads, params) if g is not None]
      opt.apply_gradients(grads_params)
      # log the gradient norm per parameter (keyed by the parameter's name)
      grads = {f'_grad/{p.name}': tf.linalg.norm(g) for g, p in grads_params}
    else:
      grads = dict()
    return loss, dict(nll=nll, kl=kl, **grads)

  def callback():
    trainer = get_current_trainer()
    x, y = x_test[:1000], y_test[:1000]
    px, qz = vae(x, training=False)
    # latent statistics and the decoder's input weights, per latent unit
    qz_mean = tf.reduce_mean(qz.mean(), axis=0)
    qz_std = tf.reduce_mean(qz.stddev(), axis=0)
    w = tf.reduce_sum(decoder.trainable_variables[0], axis=(0, 1, 2))
    # plot the latents and their weights
    fig = plt.figure(figsize=(6, 4), dpi=200)
    ax = plt.gca()
    l1 = ax.plot(qz_mean, label='mean', linewidth=1.0, linestyle='--',
                 marker='o', markersize=4, color='r', alpha=0.5)
    l2 = ax.plot(qz_std, label='std', linewidth=1.0, linestyle='--',
                 marker='o', markersize=4, color='g', alpha=0.5)
    ax1 = ax.twinx()
    l3 = ax1.plot(w, label='weight', linewidth=1.0, linestyle='--',
                  marker='o', markersize=4, color='b', alpha=0.5)
    lines = l1 + l2 + l3
    labs = [l.get_label() for l in lines]
    ax.grid(True)
    ax.legend(lines, labs)
    img_qz = vs.plot_to_image(fig)
    # reconstruction
    fig = plt.figure(figsize=(5, 5), dpi=120)
    vs.plot_images(np.squeeze(px.mean().numpy()[:25], axis=-1), grids=(5, 5))
    img_res = vs.plot_to_image(fig)
    # UMAP of the latent means
    fig = plt.figure(figsize=(5, 5), dpi=200)
    z = fast_umap(qz.mean().numpy())
    vs.plot_scatter(z, color=y, size=12.0, alpha=0.4)
    img_umap = vs.plot_to_image(fig)
    # per-network gradient norms collected by `optimize`
    grads = [(k, v) for k, v in trainer.last_metrics.items() if '_grad/' in k]
    encoder_grad = sum(v for k, v in grads if 'Encoder' in k)
    decoder_grad = sum(v for k, v in grads if 'Decoder' in k)
    return dict(reconstruct=img_res,
                umap=img_umap,
                latents=img_qz,
                qz_mean=qz_mean,
                qz_std=qz_std,
                w_decoder=w,
                llk_test=tf.reduce_mean(px.log_prob(x)),
                encoder_grad=encoder_grad,
                decoder_grad=decoder_grad)

  ### fit the network
  trainer = Trainer(logdir=get_output_dir())
  trainer.fit(train_ds=train.repeat(-1),
              optimize=optimize,
              valid_ds=valid,
              max_iter=max_iter,
              compile_graph=True,
              log_tag=f'{cfg.vae}_{cfg.beta}',
              callback=callback,
              valid_freq=1000)
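

# A minimal driver sketch, assuming `cfg` is any object with attribute access
# (e.g. an argparse.Namespace or an omegaconf DictConfig); the real scripts
# are presumably launched through the project's own config machinery, so the
# values below are illustrative only.
from argparse import Namespace

if __name__ == '__main__':
  main(Namespace(vae='tfp', beta=1.0, ds='mnist'))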