Example #1
def callback(
    vae: AmortizedLDA,
    test: tf.data.Dataset,
    vocabulary: Dict[int, str],
):
    print(f"[{type(vae).__name__}]{vae.lda.posterior}-{vae.lda.distribution}")
    # end of training
    if not get_current_trainer().is_training:
        vae.save_weights(overwrite=True)
    return dict(topics=vae.get_topics_string(vocabulary, n_topics=20),
                perplexity=vae.perplexity(test, verbose=False))
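Both examples rely on imports that the snippets omit. A minimal sketch of what they would look like follows; the standard-library, TensorFlow, scikit-learn, and tqdm lines are standard, while the odin-ai module paths are assumptions for illustration:

import os
import pickle
from functools import partial
from typing import Dict

import tensorflow as tf
from sklearn.decomposition import LatentDirichletAllocation
from tqdm import tqdm

# The remaining names come from the odin-ai codebase; the module paths
# below are assumptions and may not match the installed version.
from odin.bay import DistributionDense, RVconf
from odin.bay.vi import (AmortizedLDA, BetaVAE, FactorVAE,
                         LatentDirichletDecoder, TwoStageLDA)
from odin.exp import get_current_trainer, get_output_dir, save_to_yaml
from odin.fuel import (Cortex, LeukemiaATAC, Newsgroup5, Newsgroup20,
                       Newsgroup20_clean)
from odin.networks import NetConf
# clean_folder and get_topics_text are project-local helpers in the
# original script and are not reproduced here.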
Example #2
def main(cfg):
    save_to_yaml(cfg)
    if cfg.ds == 'news5':
        ds = Newsgroup5()
    elif cfg.ds == 'news20':
        ds = Newsgroup20()
    elif cfg.ds == 'news20clean':
        ds = Newsgroup20_clean()
    elif cfg.ds == 'cortex':
        ds = Cortex()
    elif cfg.ds == 'lkm':
        ds = LeukemiaATAC()
    else:
        raise NotImplementedError(f"No support for dataset: {cfg.ds}")
    # batch_size, learning_rate, and valid_freq are assumed to be
    # module-level constants defined elsewhere in the original script
    train = ds.create_dataset(batch_size=batch_size,
                              partition='train',
                              drop_remainder=True)
    valid = ds.create_dataset(batch_size=batch_size, partition='valid')
    test = ds.create_dataset(batch_size=batch_size, partition='test')
    n_words = ds.vocabulary_size
    vocabulary = ds.vocabulary
    ######## prepare the path
    output_dir = get_output_dir()
    os.makedirs(output_dir, exist_ok=True)
    model_path = os.path.join(output_dir, 'model')
    if cfg.override:
        clean_folder(output_dir, verbose=True)
    ######## prepare all layers
    lda = LatentDirichletDecoder(
        posterior=cfg.posterior,
        distribution=cfg.distribution,
        n_words=n_words,
        n_topics=cfg.n_topics,
        warmup=cfg.warmup,
    )
    fit_kw = dict(train=train,
                  valid=valid,
                  max_iter=cfg.n_iter,
                  optimizer='adam',
                  learning_rate=learning_rate,
                  batch_size=batch_size,
                  valid_freq=valid_freq,
                  compile_graph=True,
                  logdir=output_dir,
                  skip_fitted=True)
    output_dist = RVconf(
        n_words,
        cfg.distribution,
        projection=True,
        preactivation='softmax' if cfg.distribution == 'onehot' else 'linear',
        kwargs=dict(probs_input=True) if cfg.distribution == 'onehot' else {},
        name="Words")
    latent_dist = RVconf(cfg.n_topics,
                         'mvndiag',
                         projection=True,
                         name="Latents")
    ######## AmortizedLDA
    if cfg.model == 'lda':
        vae = AmortizedLDA(lda=lda,
                           encoder=NetConf([300, 300, 300], name='Encoder'),
                           decoder='identity',
                           latents='identity',
                           path=model_path)
        vae.fit(on_valid_end=partial(callback,
                                     vae=vae,
                                     test=test,
                                     vocabulary=vocabulary),
                **fit_kw)
    ######## VDA - Variational Dirichlet Autoencoder
    elif cfg.model == 'vda':
        vae = BetaVAE(
            beta=cfg.beta,
            encoder=NetConf([300, 150], name='Encoder'),
            decoder=NetConf([150, 300], name='Decoder'),
            latents=RVconf(cfg.n_topics,
                           'dirichlet',
                           projection=True,
                           prior=None,
                           name="Topics"),
            outputs=output_dist,
            # important: the MCMC estimate of the KL for the Dirichlet is
            # very unstable, so use the analytic KL instead
            analytic=True,
            path=model_path,
            name="VDA")
        vae.fit(on_valid_end=partial(callback,
                                     vae=vae,
                                     test=test,
                                     vocabulary=vocabulary),
                **dict(fit_kw,
                       valid_freq=1000,
                       optimizer=tf.optimizers.Adam(learning_rate=1e-4)))
    ######## VAE
    elif cfg.model == 'vae':
        vae = BetaVAE(beta=cfg.beta,
                      encoder=NetConf([300, 300], name='Encoder'),
                      decoder=NetConf([300], name='Decoder'),
                      latents=latent_dist,
                      outputs=output_dist,
                      path=model_path,
                      name="VAE")
        # run one evaluation before training starts
        callback(vae, test, vocabulary)
        vae.fit(on_valid_end=partial(callback,
                                     vae=vae,
                                     test=test,
                                     vocabulary=vocabulary),
                **dict(fit_kw,
                       valid_freq=1000,
                       optimizer=tf.optimizers.Adam(learning_rate=1e-4)))
    ######## factorVAE
    elif cfg.model == 'fvae':
        vae = FactorVAE(gamma=6.0,
                        beta=cfg.beta,
                        encoder=NetConf([300, 150], name='Encoder'),
                        decoder=NetConf([150, 300], name='Decoder'),
                        latents=latent_dist,
                        outputs=output_dist,
                        path=model_path)
        vae.fit(on_valid_end=partial(callback,
                                     vae=vae,
                                     test=test,
                                     vocabulary=vocabulary),
                **dict(fit_kw,
                       valid_freq=1000,
                       optimizer=[
                           tf.optimizers.Adam(learning_rate=1e-4,
                                              beta_1=0.9,
                                              beta_2=0.999),
                           tf.optimizers.Adam(learning_rate=1e-4,
                                              beta_1=0.5,
                                              beta_2=0.9)
                       ]))
    ######## TwoStageLDA
    elif cfg.model == 'lda2':
        vae0_iter = 10000
        vae0 = BetaVAE(beta=1.0,
                       encoder=NetConf(units=[300], name='Encoder'),
                       decoder=NetConf(units=[300, 300], name='Decoder'),
                       outputs=DistributionDense(
                           (n_words, ),
                           posterior='onehot',
                           posterior_kwargs=dict(probs_input=True),
                           activation='softmax',
                           name="Words"),
                       latents=RVconf(cfg.n_topics,
                                      'mvndiag',
                                      projection=True,
                                      name="Latents"),
                       input_shape=(n_words, ),
                       path=model_path + '_vae0')
        # save vae0 weights once its training has finished
        vae0.fit(on_valid_end=lambda: (None if get_current_trainer().is_training
                                       else vae0.save_weights()),
                 **dict(fit_kw,
                        logdir=output_dir + "_vae0",
                        max_iter=vae0_iter,
                        learning_rate=learning_rate,
                        track_gradients=False))
        vae = TwoStageLDA(lda=lda,
                          encoder=vae0.encoder,
                          decoder=vae0.decoder,
                          latents=vae0.latent_layers,
                          warmup=cfg.warmup - vae0_iter,
                          path=model_path)
        vae.fit(on_valid_end=partial(callback,
                                     vae=vae,
                                     test=test,
                                     vocabulary=vocabulary),
                **dict(fit_kw,
                       max_iter=cfg.n_iter - vae0_iter,
                       track_gradients=False))
    ######## EM-LDA
    elif cfg.model == 'em':
        n_iter = 0  # default step in case a fitted model is loaded from disk
        writer = tf.summary.create_file_writer(output_dir)
        if os.path.exists(model_path):
            with open(model_path, 'rb') as f:
                lda = pickle.load(f)
        else:
            lda = LatentDirichletAllocation(n_components=cfg.n_topics,
                                            doc_topic_prior=0.7,
                                            learning_method='online',
                                            verbose=True,
                                            n_jobs=4,
                                            random_state=1)
            with writer.as_default():
                prog = tqdm(train.repeat(-1), desc="Fitting LDA")
                for n_iter, x in enumerate(prog):
                    lda.partial_fit(x)
                    if n_iter % 500 == 0:
                        text = get_topics_text(lda.components_, vocabulary)
                        perp = lda.perplexity(test)
                        tf.summary.text("topics", text, n_iter)
                        tf.summary.scalar("perplexity", perp, n_iter)
                        prog.write(f"[#{n_iter}]Perplexity: {perp:.2f}")
                        prog.write("\n".join(text))
                    if n_iter >= 20000:
                        break
            with open(model_path, 'wb') as f:
                pickle.dump(lda, f)
        # final evaluation
        text = get_topics_text(lda.components_, vocabulary)
        final_score = lda.perplexity(test)
        with writer.as_default():
            tf.summary.scalar("perplexity", final_score, step=n_iter + 1)
        print("Perplexity:", final_score)
        print("\n".join(text))