from functools import partial
import os
import pickle
from typing import Dict

import numpy as np
import tensorflow as tf
from sklearn.decomposition import LatentDirichletAllocation
from tqdm import tqdm

# Project-specific imports. The module paths below are assumptions based on
# the odin-ai layout; adjust them to the local version of the library.
from odin.bay import DistributionDense, RVconf
from odin.bay.vi import (AmortizedLDA, BetaVAE, FactorVAE,
                         LatentDirichletDecoder, TwoStageLDA)
from odin.exp import get_current_trainer, get_output_dir, save_to_yaml
from odin.fuel import (Cortex, LeukemiaATAC, Newsgroup5, Newsgroup20,
                       Newsgroup20_clean)
from odin.networks import NetConf
from odin.utils import clean_folder

# Training hyper-parameters. These names were referenced but never defined in
# the original snippet; the values below are placeholders.
batch_size = 128
learning_rate = 1e-3
valid_freq = 2500


def callback(
    vae: AmortizedLDA,
    test: tf.data.Dataset,
    vocabulary: Dict[int, str],
):
  print(f"[{type(vae).__name__}]{vae.lda.posterior}-{vae.lda.distribution}")
  # end of training: persist the weights before the trainer exits
  if not get_current_trainer().is_training:
    vae.save_weights(overwrite=True)
  return dict(topics=vae.get_topics_string(vocabulary, n_topics=20),
              perplexity=vae.perplexity(test, verbose=False))
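# `callback1` is used by the VDA/VAE/FactorVAE branches in `main` but its
# definition is missing from this snippet. The sketch below is a hedged
# reconstruction: it only mirrors the checkpointing logic of `callback` above
# and avoids the LDA-specific attributes that plain VAEs do not have.
def callback1(vae, test, vocabulary):
  print(f"[{type(vae).__name__}] validation end")
  # save the weights once training has finished
  if not get_current_trainer().is_training:
    vae.save_weights(overwrite=True)


# `get_topics_text` is called by the EM-LDA branch but is also missing from
# this snippet. This assumed implementation renders the top-weighted words of
# each topic from the components matrix of a fitted sklearn
# LatentDirichletAllocation model.
def get_topics_text(components, vocabulary, top_n=10):
  texts = []
  for topic_idx, topic in enumerate(components):
    # indices of the highest-weight words in this topic
    top = np.argsort(topic)[::-1][:top_n]
    texts.append(f"Topic#{topic_idx}: " +
                 ", ".join(vocabulary[int(i)] for i in top))
  return texts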
def main(cfg):
  save_to_yaml(cfg)
  ######## load the dataset
  if cfg.ds == 'news5':
    ds = Newsgroup5()
  elif cfg.ds == 'news20':
    ds = Newsgroup20()
  elif cfg.ds == 'news20clean':
    ds = Newsgroup20_clean()
  elif cfg.ds == 'cortex':
    ds = Cortex()
  elif cfg.ds == 'lkm':
    ds = LeukemiaATAC()
  else:
    raise NotImplementedError(f"No support for dataset: {cfg.ds}")
  train = ds.create_dataset(batch_size=batch_size,
                            partition='train',
                            drop_remainder=True)
  valid = ds.create_dataset(batch_size=batch_size, partition='valid')
  test = ds.create_dataset(batch_size=batch_size, partition='test')
  n_words = ds.vocabulary_size
  vocabulary = ds.vocabulary
  ######## prepare the path
  output_dir = get_output_dir()
  if not os.path.exists(output_dir):
    os.makedirs(output_dir)
  model_path = os.path.join(output_dir, 'model')
  if cfg.override:
    clean_folder(output_dir, verbose=True)
  ######## preparing all layers
  lda = LatentDirichletDecoder(
      posterior=cfg.posterior,
      distribution=cfg.distribution,
      n_words=n_words,
      n_topics=cfg.n_topics,
      warmup=cfg.warmup,
  )
  fit_kw = dict(train=train,
                valid=valid,
                max_iter=cfg.n_iter,
                optimizer='adam',
                learning_rate=learning_rate,
                batch_size=batch_size,
                valid_freq=valid_freq,
                compile_graph=True,
                logdir=output_dir,
                skip_fitted=True)
  output_dist = RVconf(
      n_words,
      cfg.distribution,
      projection=True,
      preactivation='softmax' if cfg.distribution == 'onehot' else 'linear',
      kwargs=dict(probs_input=True) if cfg.distribution == 'onehot' else {},
      name="Words")
  latent_dist = RVconf(cfg.n_topics, 'mvndiag', projection=True,
                       name="Latents")
  ######## AmortizedLDA
  if cfg.model == 'lda':
    vae = AmortizedLDA(lda=lda,
                       encoder=NetConf([300, 300, 300], name='Encoder'),
                       decoder='identity',
                       latents='identity',
                       path=model_path)
    vae.fit(on_valid_end=partial(callback,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
            **fit_kw)
  ######## VDA - Variational Dirichlet Autoencoder
  elif cfg.model == 'vda':
    vae = BetaVAE(
        beta=cfg.beta,
        encoder=NetConf([300, 150], name='Encoder'),
        decoder=NetConf([150, 300], name='Decoder'),
        latents=RVconf(cfg.n_topics,
                       'dirichlet',
                       projection=True,
                       prior=None,
                       name="Topics"),
        outputs=output_dist,
        # important: the MCMC estimate of the Dirichlet KL is very unstable
        analytic=True,
        path=model_path,
        name="VDA")
    vae.fit(on_valid_end=partial(callback1,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
            **dict(fit_kw,
                   valid_freq=1000,
                   optimizer=tf.optimizers.Adam(learning_rate=1e-4)))
  ######## VAE
  elif cfg.model == 'vae':
    vae = BetaVAE(beta=cfg.beta,
                  encoder=NetConf([300, 300], name='Encoder'),
                  decoder=NetConf([300], name='Decoder'),
                  latents=latent_dist,
                  outputs=output_dist,
                  path=model_path,
                  name="VAE")
    callback1(vae, test, vocabulary)
    vae.fit(on_valid_end=partial(callback1,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
            **dict(fit_kw,
                   valid_freq=1000,
                   optimizer=tf.optimizers.Adam(learning_rate=1e-4)))
  ######## FactorVAE
  elif cfg.model == 'fvae':
    vae = FactorVAE(gamma=6.0,
                    beta=cfg.beta,
                    encoder=NetConf([300, 150], name='Encoder'),
                    decoder=NetConf([150, 300], name='Decoder'),
                    latents=latent_dist,
                    outputs=output_dist,
                    path=model_path)
    vae.fit(on_valid_end=partial(callback1,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
            **dict(
                fit_kw,
                valid_freq=1000,
                # separate optimizers for the VAE and the discriminator
                optimizer=[
                    tf.optimizers.Adam(learning_rate=1e-4,
                                       beta_1=0.9,
                                       beta_2=0.999),
                    tf.optimizers.Adam(learning_rate=1e-4,
                                       beta_1=0.5,
                                       beta_2=0.9)
                ]))
  ######## TwoStageLDA
  elif cfg.model == 'lda2':
    vae0_iter = 10000
    # stage 1: pretrain a plain VAE on the bag-of-words input
    vae0 = BetaVAE(beta=1.0,
                   encoder=NetConf(units=[300], name='Encoder'),
                   decoder=NetConf(units=[300, 300], name='Decoder'),
                   outputs=DistributionDense(
                       (n_words,),
                       posterior='onehot',
                       posterior_kwargs=dict(probs_input=True),
                       activation='softmax',
                       name="Words"),
                   latents=RVconf(cfg.n_topics,
                                  'mvndiag',
                                  projection=True,
                                  name="Latents"),
                   input_shape=(n_words,),
                   path=model_path + '_vae0')
    vae0.fit(on_valid_end=lambda: None
             if get_current_trainer().is_training else vae0.save_weights(),
             **dict(fit_kw,
                    logdir=output_dir + "_vae0",
                    max_iter=vae0_iter,
                    track_gradients=False))
    # stage 2: reuse the pretrained encoder/decoder for the topic model
    vae = TwoStageLDA(lda=lda,
                      encoder=vae0.encoder,
                      decoder=vae0.decoder,
                      latents=vae0.latent_layers,
                      warmup=cfg.warmup - vae0_iter,
                      path=model_path)
    vae.fit(on_valid_end=partial(callback,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
            **dict(fit_kw,
                   max_iter=cfg.n_iter - vae0_iter,
                   track_gradients=False))
  ######## EM-LDA (online scikit-learn baseline)
  elif cfg.model == 'em':
    writer = tf.summary.create_file_writer(output_dir)
    n_iter = 0
    # materialize the test set once: sklearn's perplexity expects a dense
    # document-word matrix, not a tf.data.Dataset (this assumes the dataset
    # yields dense bag-of-words count tensors)
    x_test = np.concatenate([x.numpy() for x in test], axis=0)
    if os.path.exists(model_path):
      with open(model_path, 'rb') as f:
        lda = pickle.load(f)
    else:
      lda = LatentDirichletAllocation(n_components=cfg.n_topics,
                                      doc_topic_prior=0.7,
                                      learning_method='online',
                                      verbose=True,
                                      n_jobs=4,
                                      random_state=1)
      with writer.as_default():
        prog = tqdm(train.repeat(-1), desc="Fitting LDA")
        for n_iter, x in enumerate(prog):
          lda.partial_fit(x)
          if n_iter % 500 == 0:
            text = get_topics_text(lda.components_, vocabulary)
            perp = lda.perplexity(x_test)
            tf.summary.text("topics", text, step=n_iter)
            tf.summary.scalar("perplexity", perp, step=n_iter)
            prog.write(f"[#{n_iter}]Perplexity: {perp:.2f}")
            prog.write("\n".join(text))
          if n_iter >= 20000:
            break
      with open(model_path, 'wb') as f:
        pickle.dump(lda, f)
    # final evaluation
    text = get_topics_text(lda.components_, vocabulary)
    final_score = lda.perplexity(x_test)
    with writer.as_default():
      tf.summary.scalar("perplexity", final_score, step=n_iter + 1)
    print("Perplexity:", final_score)
    print("\n".join(text))
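# --- hypothetical entry point ----------------------------------------------
# A minimal sketch of how `main` might be invoked when the script is run
# directly. The flag names match exactly the fields that `main` reads from
# `cfg`; the default values are illustrative assumptions, not the settings of
# the original experiments.
if __name__ == '__main__':
  import argparse
  parser = argparse.ArgumentParser()
  parser.add_argument('--ds', default='news20',
                      choices=['news5', 'news20', 'news20clean',
                               'cortex', 'lkm'])
  parser.add_argument('--model', default='lda',
                      choices=['lda', 'vda', 'vae', 'fvae', 'lda2', 'em'])
  parser.add_argument('--posterior', default='dirichlet')
  parser.add_argument('--distribution', default='onehot')
  parser.add_argument('--n_topics', type=int, default=20)
  parser.add_argument('--warmup', type=int, default=10000)
  parser.add_argument('--n_iter', type=int, default=30000)
  parser.add_argument('--beta', type=float, default=1.0)
  parser.add_argument('--override', action='store_true')
  main(parser.parse_args())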