def _save_log(self, tf_summary_value, eval_name, value, step, epoch, key):
    """Log one evaluation metric and queue it as a TensorBoard scalar.

    Args:
        tf_summary_value: list-like collection that receives a new
            ``tf.Summary.Value``; it is mutated in place.
        eval_name: metric name, used in the log record and summary tag.
        value: metric value, logged and stored as ``simple_value``.
        step: step number recorded in the log entry.
        epoch: epoch number recorded in the log entry.
        key: label placed in the log record and the summary tag
            (presumably the evaluated split name — confirm with callers).
    """
    record = log_dict.EvalLog(step, epoch, key, eval_name, value)
    self.logger.info(record)
    tag = "Evaluation/{}/{}".format(key, eval_name)
    tf_summary_value.append(tf.Summary.Value(tag=tag, simple_value=value))
def log_b3_eval(logger, step, epoch, data, eval_results):
    """Log the four B3 scores in ``eval_results`` through ``logger``.

    Entries are read positionally: index 0 is logged as B3-F1, 1 as
    B3-F0.5, 2 as B3-Recall and 3 as B3-Precision. Anything past index 3
    is ignored. A sequence shorter than four raises ``IndexError`` after
    the available entries have been logged.

    Args:
        logger: object whose ``info`` method receives each ``EvalLog``.
        step: step number recorded in every log entry.
        epoch: epoch number recorded in every log entry.
        data: label recorded in every log entry (presumably the split name).
        eval_results: indexable sequence of at least four scores.
    """
    metric_names = ('B3-F1', 'B3-F0.5', 'B3-Recall', 'B3-Precision')
    for idx, metric in enumerate(metric_names):
        logger.info(
            log_dict.EvalLog(step, epoch, data, metric, eval_results[idx]))
def train_model(self, sess: tf.Session, saver, writer):
    """Run the training loop over ``self.data.train_iter``.

    For every batch this runs ``self.train_op`` and writes one summary to
    ``writer``. The summary is the merged ``tools.TrainSummaries``
    collection every ``skip_step`` steps, and only the total-loss summary
    otherwise.

    Every ``eval_step_interval`` global steps, ``self.summary_evaluation``
    is called on "valid" and then "test". If the smaller of the two returned
    values (capped at 1.0) is below 1e-9, training stops early.

    When ``data_iter.epoch`` changes, the loss summed over the batches since
    the last change is logged and the accumulator is reset.

    Args:
        sess: session used to run the training and global-step ops.
        saver: currently unused; checkpoint saving is disabled.
        writer: summary writer for training summaries; also passed to
            ``self.summary_evaluation``.
    """
    opts = self.options
    g = self.graph  # tensorflow.python.framework.ops.Graph
    # Only needed by the disabled checkpoint call further down.
    checkpoint_dir = os.path.join(opts.output_dir, "checkpoints/checkpoint")

    # ----------- Restore Graph Ops ----------------------------
    train_op = self.train_op
    loss_tensor = self.loss
    global_step = tf.train.get_global_step(g)

    # Merge training summaries
    all_train_summary = tf.summary.merge_all(tools.TrainSummaries)
    total_loss_summary = g.get_tensor_by_name("loss_summary/oie_loss:0")

    total_loss = 0.
    step_value = 0
    data_iter = self.data.train_iter
    placeholders = self.placeholders

    # Full summaries about ten times per pass (assuming len(data_iter)
    # counts examples); never less often than every step.
    skip_step = max(int(len(data_iter) / opts.batch_size / 10), 1)
    # Clamp to >= 1: for small datasets / small eval_epoch_interval the raw
    # value can be 0, and ``step_value % 0`` below would raise
    # ZeroDivisionError on the very first batch.
    eval_step_interval = max(
        int(len(data_iter) * self.eval_epoch_interval) // opts.batch_size, 1)

    # The option may be absent, or the string 'None', to disable the
    # alpha schedule.
    alpha_change = getattr(opts, 'alpha_change', 'None') != 'None'
    if alpha_change:
        with tf.variable_scope('entropy', reuse=True):
            alpha_tensor = tf.get_variable('alpha')
        # NOTE(review): the schedule spans opts.epochs * len(data_iter) and is
        # queried with data_iter.current_idx — presumably a global example
        # index; verify against the iterator.
        alpha_change_func = self.get_weight_change_func(
            opts.alpha_init, opts.alpha_final, opts.alpha_change,
            opts.epochs * len(data_iter))

    for data_batch in data_iter:
        feed_dict = update_feed_dict(placeholders, data_batch)
        if alpha_change:
            alpha = alpha_change_func(data_iter.current_idx)
            # Never feed a negative weight.
            feed_dict[alpha_tensor] = max(alpha, 0)

        summary_op = (all_train_summary
                      if (step_value + 1) % skip_step == 0
                      else total_loss_summary)
        _, loss_value, step_value, summary = sess.run(
            [train_op, loss_tensor, global_step, summary_op],
            feed_dict=feed_dict)

        # add summary
        writer.add_summary(summary, global_step=step_value)
        total_loss += loss_value

        if (step_value >= eval_step_interval
                and step_value % eval_step_interval == 0):
            ef = 1.
            ef = min(ef, self.summary_evaluation(sess, "valid", writer))
            ef = min(ef, self.summary_evaluation(sess, "test", writer))
            if ef < 1e-9:
                self.logger.warning('Early stop')
                break

        if self.current_epoch != data_iter.epoch:
            self.logger.info(
                log_dict.EvalLog(step_value, self.current_epoch, 'train',
                                 'loss', total_loss))
            # NOTE: checkpointing is disabled; to re-enable:
            # saver.save(sess, checkpoint_dir, step_value)
            self.current_epoch = data_iter.epoch
            total_loss = 0.

    self.logger.debug("Training Done")