def train(): global max_f1 with tf.Session(config=util.gpu_config()) as session: session.run(tf.global_variables_initializer()) model.start_enqueue_thread(session) accumulated_loss = 0.0 ckpt = tf.train.get_checkpoint_state(log_dir) if ckpt and ckpt.model_checkpoint_path: print("Restoring from: {}".format(ckpt.model_checkpoint_path)) saver.restore(session, ckpt.model_checkpoint_path) initial_time = time.time() while True: tf_loss, tf_global_step, _ = session.run( [model.loss, model.global_step, model.train_op]) accumulated_loss += tf_loss if tf_global_step % report_frequency == 0: total_time = time.time() - initial_time steps_per_second = tf_global_step / total_time average_loss = accumulated_loss / report_frequency print("[{}] loss={:.2f}, steps/s={:.2f}".format( tf_global_step, average_loss, steps_per_second)) writer.add_summary( util.make_summary({"loss": average_loss}), tf_global_step) accumulated_loss = 0.0 if tf_global_step % save_frequency == 0: saver.save(session, os.path.join(log_dir, "model"), global_step=tf_global_step) if tf_global_step % eval_frequency == 0: eval_summary, eval_f1 = model.evaluate(session) if eval_f1 > max_f1: max_f1 = eval_f1 util.copy_checkpoint( os.path.join(log_dir, "model-{}".format(tf_global_step)), os.path.join(log_dir, "model.max.ckpt")) writer.add_summary(eval_summary, tf_global_step) writer.add_summary( util.make_summary({"max_eval_f1": max_f1}), tf_global_step) print("[{}] evaL_f1={:.2f}, max_f1={:.2f}".format( tf_global_step, eval_f1, max_f1))
) writer.add_summary(util.make_summary({"loss": average_loss}), tf_global_step) accumulated_loss = 0.0 if tf_global_step == 1 or tf_global_step % eval_frequency == 0: eval_summary, eval_f1 = model.evaluate(session) _ = session.run(model.update_max_f1) saver.save(session, os.path.join(log_dir, "model"), global_step=tf_global_step) if eval_f1 > max_f1: max_f1 = eval_f1 util.copy_checkpoint( os.path.join(log_dir, "model-{}".format(tf_global_step)), os.path.join(log_dir, "model.max.ckpt")) writer.add_summary(eval_summary, tf_global_step) writer.add_summary(util.make_summary({"max_eval_f1": max_f1}), tf_global_step) print( f"[{tf_global_step}] evaL_f1={eval_f1:.2f}, max_f1={max_f1:.2f}" ) if tf_global_step >= config['max_step']: print('Training finishes due to reaching max steps') break
def main(): config = util.initialize_from_env() report_frequency = config["report_frequency"] eval_frequency = config["eval_frequency"] model = util.get_model(config) saver = tf.train.Saver() log_dir = config["log_dir"] max_steps = config['num_epochs'] * config['num_docs'] writer = tf.summary.FileWriter(log_dir, flush_secs=20) max_f1 = 0 mode = 'w' with tf.Session() as session: session.run(tf.global_variables_initializer()) model.start_enqueue_thread(session) accumulated_loss = 0.0 initial_step = 0 ckpt = tf.train.get_checkpoint_state(log_dir) if ckpt and ckpt.model_checkpoint_path: print("Restoring from: {}".format(ckpt.model_checkpoint_path)) saver.restore(session, ckpt.model_checkpoint_path) mode = 'a' initial_step = int( os.path.basename(ckpt.model_checkpoint_path).split('-')[1]) fh = logging.FileHandler(os.path.join(log_dir, 'stdout.log'), mode=mode) fh.setFormatter(logging.Formatter(format)) logger.addHandler(fh) initial_time = time.time() while True: tf_loss, tf_global_step, _ = session.run( [model.loss, model.global_step, model.train_op]) accumulated_loss += tf_loss # print('tf global_step', tf_global_step) if tf_global_step % report_frequency == 0: steps_per_second = (tf_global_step - initial_step) / ( time.time() - initial_time) average_loss = accumulated_loss / report_frequency logger.info("[{}] loss={:.2f}, steps/s={:.2f}".format( tf_global_step, average_loss, steps_per_second)) writer.add_summary(util.make_summary({"loss": average_loss}), tf_global_step) accumulated_loss = 0.0 if tf_global_step % eval_frequency == 0: eval_summary, eval_f1 = model.evaluate(session) if eval_f1 > max_f1: max_f1 = eval_f1 saver.save(session, os.path.join(log_dir, "model"), global_step=tf_global_step) util.copy_checkpoint( os.path.join(log_dir, "model-{}".format(tf_global_step)), os.path.join(log_dir, "model.max.ckpt")) writer.add_summary(eval_summary, tf_global_step) writer.add_summary(util.make_summary({"max_eval_f1": max_f1}), tf_global_step) logger.info("[{}] evaL_f1={:.4f}, max_f1={:.4f}".format( tf_global_step, eval_f1, max_f1)) if tf_global_step > max_steps: break