def train(FLAGS):
    """Train the global-VAE model (learner process).

    Builds the training graph, restores a full-model checkpoint when
    ``FLAGS.resume_fullmodel`` is set, then runs the TF1 Supervisor-managed
    training loop until ``FLAGS.max_training_hrs`` wall-clock hours elapse
    (or the supervisor requests a stop), emitting summaries every
    ``FLAGS.summaries_steps`` steps and checkpoints every
    ``FLAGS.ckpt_steps`` steps.

    Args:
        FLAGS: parsed flags/namespace; fields read here include
            checkpoint_dir, tex_dim, resume_fullmodel, fullmodel_ckpt,
            max_training_hrs, summaries_steps, ckpt_steps.

    Raises:
        FileNotFoundError: if resuming and the checkpoint index file for
            ``FLAGS.fullmodel_ckpt`` does not exist.
    """
    # learner: build the training graph and its summary ops
    graph = Train_Graph(FLAGS)
    graph.build()
    summary_op, latent_summary_op = Summary.collect_globalVAE_summary(graph, FLAGS)

    # define model saver
    with tf.name_scope("parameter_count"):
        total_parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
    save_vars = tf.global_variables()
    saver = tf.train.Saver(save_vars, max_to_keep=100)

    # One writer per latent dimension so per-dimension KL curves can be
    # overlaid in TensorBoard.
    latent_writers = [
        tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "latent" + str(m)))
        for m in range(FLAGS.tex_dim)]

    # saver=None / save_summaries_secs=0: checkpointing and summaries are
    # driven manually below for flexibility, not by the Supervisor.
    sv = tf.train.Supervisor(
        logdir=os.path.join(FLAGS.checkpoint_dir, "globalVAE_Sum"),
        saver=None,
        save_summaries_secs=0)

    with sv.managed_session() as sess:
        myprint("Number of total params: {0} \n".format(
            sess.run(total_parameter_count)))

        if FLAGS.resume_fullmodel:
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would let restore() fail with an opaque error.
            if not os.path.isfile(FLAGS.fullmodel_ckpt + '.index'):
                raise FileNotFoundError(
                    "Checkpoint index not found: {}.index".format(FLAGS.fullmodel_ckpt))
            saver.restore(sess, FLAGS.fullmodel_ckpt)
            myprint("Resumed training from model {}".format(FLAGS.fullmodel_ckpt))
            myprint("Start from step {}".format(sess.run(graph.global_step)))
            myprint("Save checkpoint in {}".format(FLAGS.checkpoint_dir))
            if os.path.dirname(FLAGS.fullmodel_ckpt) != FLAGS.checkpoint_dir:
                # Red-background ANSI warning: new checkpoints will land in a
                # different directory than the resumed model.
                print("\033[0;30;41m" + "Warning: checkpoint dir and fullmodel ckpt do not match" + "\033[0m")
        else:
            myprint("Train from scratch")
        myinput('Press enter to continue')

        start_time = time.time()
        step = sess.run(graph.global_step)
        progbar = Progbar(target=FLAGS.ckpt_steps)  # 100k

        # Wall-clock-bounded training loop.
        while (time.time() - start_time) < FLAGS.max_training_hrs * 3600:
            if sv.should_stop():
                break

            fetches = {"global_step_inc": graph.incr_global_step,
                       "step": graph.global_step,
                       "train_op": graph.train_ops}
            # Fetch loss/summaries only on summary steps to keep other
            # iterations cheap.
            if step % FLAGS.summaries_steps == 0:
                fetches["Loss"] = graph.loss
                fetches["kl_dim"] = graph.latent_loss_dim  # per-dimension KL
                fetches['summary'] = summary_op

            results = sess.run(fetches)
            progbar.update(step % FLAGS.ckpt_steps)

            if step % FLAGS.summaries_steps == 0:
                print(" Step:%3dk time:%4.4fmin Loss%4.2f "
                      % (step / 1000, (time.time() - start_time) / 60, results['Loss']))
                sv.summary_writer.add_summary(results['summary'], step)
                # Re-run the latent summary once per dimension, feeding each
                # dimension's KL value, and route it to that dimension's writer.
                for m in range(FLAGS.tex_dim):
                    kl_summary = sess.run(
                        latent_summary_op,
                        feed_dict={graph.kl_var: results['kl_dim'][m]})
                    latent_writers[m].add_summary(kl_summary, step)

            if step % FLAGS.ckpt_steps == 0:
                saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'model'),
                           global_step=step)
                # Reset the progress bar for the next checkpoint interval.
                progbar = Progbar(target=FLAGS.ckpt_steps)

            # The fetched step reflects the increment done by this run call.
            step = results['step']

        myprint("Training completed")