def main(argv):
    """Build the DAE/VAE/SCAN/SCAN-recombinator models, restore them, then train.

    Which training stages run is selected by ``flags.train_dae``,
    ``flags.train_vae``, ``flags.train_scan`` and ``flags.train_scan_recomb``.
    Each enabled stage (other than the DAE) is followed by its check routine.
    The TensorBoard writer and the TF session are always closed, even if a
    stage raises.

    Args:
        argv: Command-line arguments. Not used by this function (presumably
            supplied by the app runner -- confirm).
    """
    data_manager = DataManager()
    data_manager.prepare()

    dae = DAE()
    vae = VAE(dae, beta=flags.vae_beta)
    scan = SCAN(dae, vae, beta=flags.scan_beta, lambd=flags.scan_lambda)
    scan_recomb = SCANRecombinator(dae, vae, scan)

    dae_saver = CheckPointSaver(flags.checkpoint_dir, "dae", dae.get_vars())
    vae_saver = CheckPointSaver(flags.checkpoint_dir, "vae", vae.get_vars())
    scan_saver = CheckPointSaver(flags.checkpoint_dir, "scan", scan.get_vars())
    scan_recomb_saver = CheckPointSaver(flags.checkpoint_dir, "scan_recomb",
                                        scan_recomb.get_vars())

    sess = tf.Session()
    summary_writer = None
    try:
        # Initialize variables. This runs before the checkpoint loads below
        # (which presumably overwrite any restored variables -- confirm in
        # CheckPointSaver.load).
        init = tf.global_variables_initializer()
        sess.run(init)

        # For Tensorboard log
        summary_writer = tf.summary.FileWriter(flags.log_file, sess.graph)

        # Load from checkpoint
        dae_saver.load(sess)
        vae_saver.load(sess)
        scan_saver.load(sess)
        scan_recomb_saver.load(sess)

        # Train
        if flags.train_dae:
            train_dae(sess, dae, data_manager, dae_saver, summary_writer)

        if flags.train_vae:
            train_vae(sess, vae, data_manager, vae_saver, summary_writer)
            disentangle_check(sess, vae, data_manager)

        if flags.train_scan:
            train_scan(sess, scan, data_manager, scan_saver, summary_writer)
            sym2img_check(sess, scan, data_manager)
            img2sym_check(sess, scan, data_manager)

        if flags.train_scan_recomb:
            train_scan_recomb(sess, scan_recomb, data_manager,
                              scan_recomb_saver, summary_writer)
            recombination_check(sess, scan_recomb, data_manager)
    finally:
        # Close the writer so buffered summary events are flushed to disk,
        # and release the session even when a stage above raised.
        if summary_writer is not None:
            summary_writer.close()
        sess.close()
def main(argv):
    """Build the DAE/VAE/SCAN models, restore them, then train each in turn.

    Training runs DAE -> VAE (followed by ``disentangle_check``) -> SCAN
    (followed by ``sym2img_check`` and ``img2sym_check``). Checkpoints are read
    from ``CHECKPOINT_DIR`` and TensorBoard logs are written to ``LOG_FILE``.
    The TensorBoard writer and the TF session are always closed, even if a
    stage raises.

    Args:
        argv: Command-line arguments. Not used by this function (presumably
            supplied by the app runner -- confirm).
    """
    data_manager = DataManager()
    data_manager.prepare()

    dae = DAE()
    vae = VAE(dae)
    scan = SCAN(dae, vae)

    dae_saver = CheckPointSaver(CHECKPOINT_DIR, "dae", dae.get_vars())
    vae_saver = CheckPointSaver(CHECKPOINT_DIR, "vae", vae.get_vars())
    scan_saver = CheckPointSaver(CHECKPOINT_DIR, "scan", scan.get_vars())

    sess = tf.Session()
    summary_writer = None
    try:
        # Initialize variables. This runs before the checkpoint loads below
        # (which presumably overwrite any restored variables -- confirm in
        # CheckPointSaver.load).
        init = tf.global_variables_initializer()
        sess.run(init)

        # For Tensorboard log
        summary_writer = tf.summary.FileWriter(LOG_FILE, sess.graph)

        # Load from checkpoint
        dae_saver.load(sess)
        vae_saver.load(sess)
        scan_saver.load(sess)

        # Train
        train_dae(sess, dae, data_manager, dae_saver, summary_writer)
        train_vae(sess, vae, data_manager, vae_saver, summary_writer)
        disentangle_check(sess, vae, data_manager)
        train_scan(sess, scan, data_manager, scan_saver, summary_writer)
        sym2img_check(sess, scan, data_manager)
        img2sym_check(sess, scan, data_manager)
    finally:
        # Close the writer so buffered summary events are flushed to disk,
        # and release the session even when a stage above raised.
        if summary_writer is not None:
            summary_writer.close()
        sess.close()
def test_dae(self):
    """Check that a freshly built DAE reports the expected number of variables."""
    dae = DAE()
    # Named dae_vars rather than `vars` to avoid shadowing the builtin vars().
    dae_vars = dae.get_vars()
    # Check size of optimizing vars.
    # NOTE(review): the expected total is written as 10 + 10; what each half
    # corresponds to is not visible here -- confirm against DAE.get_vars().
    self.assertEqual(len(dae_vars), 10 + 10)