Example 1
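A zero-argument training callback: it checkpoints the weights whenever validation loss hits a new minimum, builds a VariationalPosterior over held-out samples, and returns latent statistics plus latent-traversal and prior-sampled image grids. The names `vae`, `x_samples`, `y_samples` and `z_samples` are captured from the enclosing scope.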
def callback():
    # checkpoint whenever validation loss reaches a new minimum
    losses = get_current_trainer().valid_loss
    if losses[-1] <= np.min(losses):
        vae.save_weights(overwrite=True)
    # posterior
    vp = VariationalPosterior(model=vae,
                              inputs=x_samples,
                              groundtruth=GroundTruth(y_samples),
                              n_samples=1000)
    qz = as_tuple(vp.latents)
    # latent posterior statistics
    mean = tf.reduce_mean(qz[0].mean(), axis=0)
    std = tf.reduce_mean(qz[0].stddev(), axis=0)
    # traverse the 20 lowest-variance (most informative) latent units
    images = np.concatenate([
        vp.traverse(i, min_val=-3, max_val=3, num=21,
                    mode='linear').outputs[0].mean().numpy()
        for i in np.argsort(std)[:20]
    ])
    image_traverse = to_image(images,
                              grids=(20, int(images.shape[0] / 20)))
    # images decoded from prior samples
    px = as_tuple(vae.decode(z_samples, training=False))
    image_sampled = to_image(px[0].mean().numpy(), grids=(4, 4))
    return dict(mean=mean,
                std=std,
                traverse=image_traverse,
                sampled=image_sampled)
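A closure like this is passed straight to the trainer. A minimal wiring sketch, assuming the same fit signature used in Example 4 (the dataset names and schedule values here are placeholders, not from the original script):

vae.fit(train=train_dataset,   # placeholder dataset objects
        valid=valid_dataset,
        valid_freq=1000,       # placeholder schedule values
        max_iter=10000,
        callback=callback)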
Example 2
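A variant of the previous callback that also logs a reconstruction grid, per-unit decoder weight sums, aggregated gradient norms from `vae.last_metrics`, and the number of collapsed "noise" units (posterior std above 0.9).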
def callback():
    # checkpoint whenever validation loss reaches a new minimum
    losses = get_current_trainer().valid_loss
    if losses[-1] <= np.min(losses):
        vae.save_weights(overwrite=True)
    # reconstruction
    px, _ = vae(x_samples, training=True)
    image_reconstructed = to_image(as_tuple(px)[0].mean().numpy(),
                                   grids=(4, 4))
    # variational posterior for latent statistics and traversal
    vp = VariationalPosterior(model=vae,
                              inputs=x_samples,
                              groundtruth=GroundTruth(y_samples),
                              n_samples=1000)
    # latent statistics and per-unit decoder weight sums
    mean = tf.reduce_mean(vp.latents.mean(), axis=0)
    std = tf.reduce_mean(vp.latents.stddev(), axis=0)
    w_d = tf.reduce_sum(vae.decoder.trainable_variables[0], axis=-1)
    image_latents = plot_latent_units(mean, std, w_d)
    # traverse the 20 lowest-variance (most informative) latent units
    images = np.concatenate([
        vp.traverse(i,
                    min_val=-2,
                    max_val=2,
                    num=21,
                    n_samples=1,
                    mode='linear').outputs[0].mean().numpy()
        for i in np.argsort(std)[:20]
    ])
    image_traverse = to_image(images,
                              grids=(20, int(images.shape[0] / 20)))
    # images decoded from prior samples
    px = as_tuple(vae.decode(z_samples, training=False))
    image_sampled = to_image(px[0].mean().numpy(), grids=(4, 4))
    # aggregate the gradient norms logged by the trainer
    all_grads = [(k, v) for k, v in vae.last_metrics.items()
                 if 'grad/' in k]
    encoder_grad = 0
    decoder_grad = 0
    latents_grad = 0
    if len(all_grads) > 0:
        encoder_grad = sum(v for k, v in all_grads if 'Encoder' in k)
        decoder_grad = sum(v for k, v in all_grads if 'Decoder' in k)
        latents_grad = sum(v for k, v in all_grads if 'Latents' in k)
    # a unit whose posterior std stays near the prior's 1.0 is a "noise" unit
    return dict(mean=mean,
                std=std,
                w_decode=w_d,
                encoder_grad=encoder_grad,
                decoder_grad=decoder_grad,
                latents_grad=latents_grad,
                noise_units=np.sum(std > 0.9),
                reconstructed=image_reconstructed,
                traverse=image_traverse,
                sampled=image_sampled,
                latents=image_latents)
Example 3
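A parameterized callback for AmortizedLDA: it prints the model configuration, saves the weights once training has finished, and reports the top topics and the test perplexity. Example 4 shows how it is bound to its arguments with `partial`.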
def callback(
    vae: AmortizedLDA,
    test: tf.data.Dataset,
    vocabulary: Dict[int, str],
):
    print(f"[{type(vae).__name__}]{vae.lda.posterior}-{vae.lda.distribution}")
    # save the final weights once training has finished
    if not get_current_trainer().is_training:
        vae.save_weights(overwrite=True)
    return dict(topics=vae.get_topics_string(vocabulary, n_topics=20),
                perplexity=vae.perplexity(test, verbose=False))
Example 4
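An end-to-end training entry point: it picks a dataset from the config, prepares the output directory, then builds and fits one of several models (amortized LDA, a Dirichlet VAE, BetaVAE, FactorVAE, a two-stage LDA, or scikit-learn's online-EM LatentDirichletAllocation).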
def main(cfg):
    save_to_yaml(cfg)
    if cfg.ds == 'news5':
        ds = Newsgroup5()
    elif cfg.ds == 'news20':
        ds = Newsgroup20()
    elif cfg.ds == 'news20clean':
        ds = Newsgroup20_clean()
    elif cfg.ds == 'cortex':
        ds = Cortex()
    elif cfg.ds == 'lkm':
        ds = LeukemiaATAC()
    else:
        raise NotImplementedError(f"No support for dataset: {cfg.ds}")
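    # NOTE: batch_size, learning_rate and valid_freq are not defined in this
    # snippet; they are presumably module-level settings of the original script.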
    train = ds.create_dataset(batch_size=batch_size,
                              partition='train',
                              drop_remainder=True)
    valid = ds.create_dataset(batch_size=batch_size, partition='valid')
    test = ds.create_dataset(batch_size=batch_size, partition='test')
    n_words = ds.vocabulary_size
    vocabulary = ds.vocabulary
    ######## prepare the path
    output_dir = get_output_dir()
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    model_path = os.path.join(output_dir, 'model')
    if cfg.override:
        clean_folder(output_dir, verbose=True)
    ######## prepare all layers
    lda = LatentDirichletDecoder(
        posterior=cfg.posterior,
        distribution=cfg.distribution,
        n_words=n_words,
        n_topics=cfg.n_topics,
        warmup=cfg.warmup,
    )
    fit_kw = dict(train=train,
                  valid=valid,
                  max_iter=cfg.n_iter,
                  optimizer='adam',
                  learning_rate=learning_rate,
                  batch_size=batch_size,
                  valid_freq=valid_freq,
                  compile_graph=True,
                  logdir=output_dir,
                  skip_fitted=True)
    output_dist = RandomVariable(
        n_words,
        cfg.distribution,
        projection=True,
        preactivation='softmax' if cfg.distribution == 'onehot' else 'linear',
        kwargs=dict(probs_input=True) if cfg.distribution == 'onehot' else {},
        name="Words")
    latent_dist = RandomVariable(cfg.n_topics,
                                 'diag',
                                 projection=True,
                                 name="Latents")
    ######## AmortizedLDA
    if cfg.model == 'lda':
        vae = AmortizedLDA(lda=lda,
                           encoder=NetworkConfig([300, 300, 300],
                                                 name='Encoder'),
                           decoder='identity',
                           latents='identity',
                           path=model_path)
        vae.fit(callback=partial(callback,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
                **fit_kw)
    ######## VDA - Variational Dirichlet Autoencoder
    elif cfg.model == 'vda':
        vae = BetaVAE(
            beta=cfg.beta,
            encoder=NetworkConfig([300, 150], name='Encoder'),
            decoder=NetworkConfig([150, 300], name='Decoder'),
            latents=RandomVariable(cfg.n_topics,
                                   'dirichlet',
                                   projection=True,
                                   prior=None,
                                   name="Topics"),
            outputs=output_dist,
            # important: the MCMC estimate of the Dirichlet KL is very
            # unstable, so use the analytic KL instead
            analytic=True,
            path=model_path)
        callback1(vae, test, vocabulary)
        kw = dict(fit_kw)
        del kw['optimizer']
        vae.fit(optimizer=tf.optimizers.Adam(learning_rate=1e-4),
                callback=partial(callback1,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
                **kw)
    ######## VAE
    elif cfg.model == 'bvae':
        vae = BetaVAE(beta=cfg.beta,
                      encoder=NetworkConfig([300, 150], name='Encoder'),
                      decoder=NetworkConfig([150, 300], name='Decoder'),
                      latents=latent_dist,
                      outputs=output_dist,
                      path=model_path)
        kw = dict(fit_kw)
        del kw['optimizer']
        vae.fit(optimizer=tf.optimizers.Adam(learning_rate=1e-4), **kw)
    ######## FactorVAE
    elif cfg.model == 'fvae':
        vae = FactorVAE(gamma=6.0,
                        beta=cfg.beta,
                        encoder=NetworkConfig([300, 150], name='Encoder'),
                        decoder=NetworkConfig([150, 300], name='Decoder'),
                        latents=latent_dist,
                        outputs=output_dist,
                        path=model_path)
        kw = dict(fit_kw)
        del kw['optimizer']
        vae.fit(optimizer=[
            tf.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999),
            tf.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)
        ],
                **kw)
    ######## TwoStageLDA
    elif cfg.model == 'lda2':
        vae0_iter = 15000
        vae0 = BetaVAE(
            beta=10.0,
            encoder=NetworkConfig(units=[300], name='Encoder'),
            decoder=NetworkConfig(units=[300, 300], name='Decoder'),
            outputs=DenseDistribution(
                (n_words, ),
                posterior='nb',
                #  posterior_kwargs=dict(probs_input=True),
                # activation='softmax',
                name="Words"),
            latents=RandomVariable(cfg.n_topics,
                                   'diag',
                                   projection=True,
                                   name="Latents"),
            input_shape=(n_words, ),
            path=model_path + '_vae0')
        vae0.fit(callback=lambda: None
                 if get_current_trainer().is_training else vae0.save_weights(),
                 **dict(fit_kw,
                        logdir=output_dir + "_vae0",
                        max_iter=vae0_iter,
                        learning_rate=learning_rate,
                        track_gradients=True))
        vae = TwoStageLDA(lda=lda,
                          encoder=vae0.encoder,
                          decoder=vae0.decoder,
                          latents=vae0.latent_layers,
                          warmup=cfg.warmup - vae0_iter,
                          path=model_path)
        vae.fit(callback=partial(callback,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
                **dict(fit_kw,
                       max_iter=cfg.n_iter - vae0_iter,
                       track_gradients=True))
    ######## EM-LDA
    elif cfg.model == 'em':
        # create the writer up-front so the final summary below also works
        # when the model is loaded from disk
        writer = tf.summary.create_file_writer(output_dir)
        n_iter = 0
        if os.path.exists(model_path):
            with open(model_path, 'rb') as f:
                lda = pickle.load(f)
        else:
            lda = LatentDirichletAllocation(n_components=cfg.n_topics,
                                            doc_topic_prior=0.7,
                                            learning_method='online',
                                            verbose=True,
                                            n_jobs=4,
                                            random_state=1)
            with writer.as_default():
                prog = tqdm(train.repeat(-1), desc="Fitting LDA")
                for n_iter, x in enumerate(prog):
                    lda.partial_fit(x)
                    if n_iter % 500 == 0:
                        text = get_topics_text(lda, vocabulary)
                        perp = lda.perplexity(test)
                        tf.summary.text("topics", text, n_iter)
                        tf.summary.scalar("perplexity", perp, n_iter)
                        prog.write(f"[#{n_iter}]Perplexity: {perp:.2f}")
                        prog.write("\n".join(text))
                    if n_iter >= 20000:
                        break
            with open(model_path, 'wb') as f:
                pickle.dump(lda, f)
        # final evaluation
        text = get_topics_text(lda, vocabulary)
        final_score = lda.perplexity(test)
        with writer.as_default():
            tf.summary.scalar("perplexity", final_score, step=n_iter + 1)
        print("Perplexity:", final_score)
        print("\n".join(text))
Example 5
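A callback that takes the model and evaluation arrays explicitly: it computes the test log-likelihood batch-by-batch, plots latent means/stds against the decoder's input weights on twin axes, and renders a reconstruction grid. The `vi` and `vs` aliases come from the surrounding script's imports.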
def callback(vae: vi.VariationalAutoencoder, x: np.ndarray, y: np.ndarray):
    trainer = get_current_trainer()
    px, qz, X_i = [], [], []
    # run the model over the evaluation set in batches of 64
    for x_i in tf.data.Dataset.from_tensor_slices(x).batch(64):
        px_i, qz_i = vae(x_i, training=False)
        px.append(px_i)
        qz.append(qz_i)
        X_i.append(x_i)
    # test log-likelihood
    llk_test = tf.reduce_mean(
        tf.concat([p.log_prob(x_i) for p, x_i in zip(px, X_i)], axis=0))
    # latents
    qz_mean = tf.reduce_mean(tf.concat([q.mean() for q in qz], axis=0), axis=0)
    qz_std = tf.reduce_mean(tf.concat([q.stddev() for q in qz], axis=0),
                            axis=0)
    w = tf.reduce_sum(vae.decoder.trainable_variables[0], axis=1)
    # plot the latent statistics and their decoder weights
    fig = plt.figure(figsize=(6, 4), dpi=200)
    ax = plt.gca()
    l1 = ax.plot(qz_mean,
                 label='mean',
                 linewidth=1.0,
                 linestyle='--',
                 marker='o',
                 markersize=4,
                 color='r',
                 alpha=0.5)
    l2 = ax.plot(qz_std,
                 label='std',
                 linewidth=1.0,
                 linestyle='--',
                 marker='o',
                 markersize=4,
                 color='g',
                 alpha=0.5)
    ax1 = ax.twinx()
    l3 = ax1.plot(w,
                  label='weight',
                  linewidth=1.0,
                  linestyle='--',
                  marker='o',
                  markersize=4,
                  color='b',
                  alpha=0.5)
    lines = l1 + l2 + l3
    labs = [l.get_label() for l in lines]
    ax.grid(True)
    ax.legend(lines, labs)
    img_qz = vs.plot_to_image(fig)
    # reconstruction: plot the posterior means of an arbitrary batch (64 images)
    img = px[10].mean().numpy()
    if img.shape[-1] == 1:
        img = np.squeeze(img, axis=-1)
    fig = plt.figure(figsize=(8, 8), dpi=120)
    vs.plot_images(img, grids=(8, 8))
    img_reconstructed = vs.plot_to_image(fig)
    # latents traverse
    # TODO
    return dict(llk_test=llk_test,
                qz_mean=qz_mean,
                qz_std=qz_std,
                w_decoder=w,
                reconstructed=img_reconstructed,
                latents=img_qz)
Example 6
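A convolutional variant of the previous callback: it additionally projects the latent means to 2-D with `fast_umap` and aggregates encoder/decoder gradient norms from the trainer's metrics. Here `vae`, `decoder`, `x_test` and `y_test` come from the enclosing scope.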
def callback():
    trainer = get_current_trainer()
    x, y = x_test[:1000], y_test[:1000]
    px, qz = vae(x, training=False)
    # latent statistics
    qz_mean = tf.reduce_mean(qz.mean(), axis=0)
    qz_std = tf.reduce_mean(qz.stddev(), axis=0)
    w = tf.reduce_sum(decoder.trainable_variables[0], axis=(0, 1, 2))
    # plot the latent statistics and their decoder weights
    fig = plt.figure(figsize=(6, 4), dpi=200)
    ax = plt.gca()
    l1 = ax.plot(qz_mean,
                 label='mean',
                 linewidth=1.0,
                 linestyle='--',
                 marker='o',
                 markersize=4,
                 color='r',
                 alpha=0.5)
    l2 = ax.plot(qz_std,
                 label='std',
                 linewidth=1.0,
                 linestyle='--',
                 marker='o',
                 markersize=4,
                 color='g',
                 alpha=0.5)
    ax1 = ax.twinx()
    l3 = ax1.plot(w,
                  label='weight',
                  linewidth=1.0,
                  linestyle='--',
                  marker='o',
                  markersize=4,
                  color='b',
                  alpha=0.5)
    lines = l1 + l2 + l3
    labs = [l.get_label() for l in lines]
    ax.grid(True)
    ax.legend(lines, labs)
    img_qz = vs.plot_to_image(fig)
    # reconstruction
    fig = plt.figure(figsize=(5, 5), dpi=120)
    vs.plot_images(np.squeeze(px.mean().numpy()[:25], axis=-1),
                   grids=(5, 5))
    img_res = vs.plot_to_image(fig)
    # UMAP projection of the latent means, colored by label
    fig = plt.figure(figsize=(5, 5), dpi=200)
    z = fast_umap(qz.mean().numpy())
    vs.plot_scatter(z, color=y, size=12.0, alpha=0.4)
    img_umap = vs.plot_to_image(fig)
    # aggregate the gradient norms logged by the trainer
    grads = [(k, v) for k, v in trainer.last_metrics.items()
             if '_grad/' in k]
    encoder_grad = sum(v for k, v in grads if 'Encoder' in k)
    decoder_grad = sum(v for k, v in grads if 'Decoder' in k)
    return dict(reconstruct=img_res,
                umap=img_umap,
                latents=img_qz,
                qz_mean=qz_mean,
                qz_std=qz_std,
                w_decoder=w,
                llk_test=tf.reduce_mean(px.log_prob(x)),
                encoder_grad=encoder_grad,
                decoder_grad=decoder_grad)