def _full_validation(self, vali_data, sess):
    """Evaluate the model on the whole validation split.

    Runs ``FLAGS.num_eval_images // FLAGS.test_batch_size`` batches with the
    ``self.am_training`` placeholder fed as ``False``, then averages the
    per-batch loss and accuracy.

    Args:
        vali_data: not used by this implementation; kept so existing callers
            that pass it keep working.
        sess: session in which the graph is evaluated.

    Returns:
        ``(mean_loss, mean_accuracy)`` as numpy scalars. Both are NaN if no
        batch is run (numpy mean of an empty array).
    """
    # Put tflearn layers into inference mode for evaluation. The original
    # passed True here, so validation ran with training-mode layer behaviour
    # (presumably dropout / batch-norm in training mode).
    # NOTE(review): a caller that keeps training must switch training mode
    # back on before its next optimisation step.
    tflearn_dev.is_training(False)
    num_batches_vali = FLAGS.num_eval_images // FLAGS.test_batch_size
    loss_list = []
    accuracy_list = []
    for _ in range(num_batches_vali):
        loss, accuracy = sess.run([self.loss, self.accuracy],
                                  feed_dict={self.am_training: False})
        loss_list.append(loss)
        accuracy_list.append(accuracy)
    return np.mean(np.array(loss_list)), np.mean(np.array(accuracy_list))
def _full_validation(self, sess):
    """Compute validation loss and accuracy over the full validation split.

    Meant to be called after each training epoch. The returned values can be
    used to implement early stopping.

    Args:
        sess: session in which the graph is evaluated.

    Returns:
        ``(mean validation loss, mean validation accuracy)`` averaged over
        ``FLAGS.num_val_images // FLAGS.batch_size`` batches.
    """
    tflearn.is_training(False, session=sess)

    # Select the validation input branch and set both keep-probability
    # placeholders to 1 (no dropout at evaluation time, presumably).
    eval_feed = {
        self.am_training: False,
        self.prob_fc: 1,
        self.prob_conv: 1,
    }
    # The input tensors are fetched together with the metrics (and then
    # discarded), exactly as before, so input consumption is unchanged.
    fetches = [self.batch_data, self.batch_labels, self.loss, self.accuracy]

    losses = []
    accuracies = []
    n_batches = FLAGS.num_val_images // FLAGS.batch_size
    for _ in range(n_batches):
        outputs = sess.run(fetches, feed_dict=eval_feed)
        losses.append(outputs[2])
        accuracies.append(outputs[3])

    mean_loss = np.mean(np.array(losses))
    mean_accuracy = np.mean(np.array(accuracies))
    return mean_loss, mean_accuracy
def _full_validation(self, vali_data, sess):
    """Evaluate the model on the whole validation split.

    Runs ``FLAGS.num_eval_images // FLAGS.train_batch_size`` batches with the
    ``self.am_training`` placeholder fed as ``False``, then averages the
    per-batch loss and accuracy.

    Args:
        vali_data: not used by this implementation; kept so existing callers
            that pass it keep working.
        sess: session in which the graph is evaluated.

    Returns:
        ``(mean_loss, mean_accuracy)`` as numpy scalars. Both are NaN if no
        batch is run.
    """
    # Evaluate in inference mode. The original passed True, so validation
    # ran with training-mode tflearn layer behaviour.
    # NOTE(review): a caller that keeps training must call
    # ``tflearn.is_training(True)`` before its next optimisation step.
    tflearn.is_training(False, session=sess)
    num_batches_vali = FLAGS.num_eval_images // FLAGS.train_batch_size
    loss_list = []
    accuracy_list = []
    for _ in range(num_batches_vali):
        loss, accuracy = sess.run([self.loss, self.accuracy],
                                  feed_dict={self.am_training: False})
        loss_list.append(loss)
        accuracy_list.append(accuracy)
    return np.mean(np.array(loss_list)), np.mean(np.array(accuracy_list))
def _full_validation(self, sess):
    """Return the mean (loss, accuracy) over the validation split.

    Switches tflearn to inference mode, then evaluates
    ``FLAGS.num_val_images // FLAGS.batch_size`` batches with the validation
    input branch selected and both keep-probability placeholders fed as 1.

    Args:
        sess: session in which the graph is evaluated.

    Returns:
        Tuple of numpy scalars ``(mean_loss, mean_accuracy)``.
    """
    tflearn.is_training(False, session=sess)
    steps = FLAGS.num_val_images // FLAGS.batch_size
    # Each run also fetches the input tensors (as the original did). Only the
    # trailing [loss, accuracy] pair is kept.
    per_batch = [
        sess.run(
            [self.batch_data, self.batch_labels, self.loss, self.accuracy],
            feed_dict={self.am_training: False,
                       self.prob_fc: 1,
                       self.prob_conv: 1},
        )[2:]
        for _ in range(steps)
    ]
    losses = [pair[0] for pair in per_batch]
    accuracies = [pair[1] for pair in per_batch]
    return np.mean(np.array(losses)), np.mean(np.array(accuracies))
def train(self, **kwargs):
    """Train the model, validating and checkpointing once per epoch.

    Reads the 'train' and 'validation' splits with ``ReadData`` and routes
    one of them into ``self.batch_data`` / ``self.batch_labels`` through a
    ``tf.cond`` on the boolean placeholder ``self.am_training``. The graph is
    built by ``self._build_graph()`` and trained for
    ``FLAGS.train_epoch * (FLAGS.num_train_images // FLAGS.train_batch_size)``
    steps.

    At every epoch boundary (``step % steps_per_epoch == 0``, so also at
    step 0) the method does three things:

    * writes averaged training metrics and validation metrics to TensorBoard;
    * prints both sets of metrics;
    * saves a checkpoint if validation accuracy improved on the best so far.

    Args:
        **kwargs: accepted for interface compatibility; not used.

    Raises:
        AssertionError: if the training loss becomes NaN.
    """
    def _scalar_summary(pairs):
        # Pack (tag, value) pairs into one tf.Summary. ``float()`` replaces
        # the former ``.astype(np.float)``: the ``np.float`` alias was removed
        # in NumPy 1.24 and raised AttributeError there.
        summ = tf.Summary()
        for tag, value in pairs:
            summ.value.add(tag=tag, simple_value=float(value))
        return summ

    ops.reset_default_graph()
    sess = tf.Session(config=self.tf_config)
    coord = None
    threads = []
    try:
        with sess.as_default():
            # Data reading objects
            train_data = ReadData('train', self.img_size, self.set_id, self.train_range)
            vali_data = ReadData('validation', self.img_size, self.set_id, self.vali_range)
            train_batch_data, train_batch_labels = train_data.read_from_files()
            vali_batch_data, vali_batch_labels = vali_data.read_from_files()

            # Feeding True selects the training pipeline, False the validation one.
            self.am_training = tf.placeholder(dtype=bool, shape=())
            self.batch_data = tf.cond(self.am_training,
                                      lambda: train_batch_data,
                                      lambda: vali_batch_data)
            self.batch_labels = tf.cond(self.am_training,
                                        lambda: train_batch_labels,
                                        lambda: vali_batch_labels)

            self._build_graph()
            self.saver = tf.train.Saver(tf.global_variables())
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            # This summary writer object helps write summaries on tensorboard
            summary_writer = tf.summary.FileWriter(FLAGS.log_dir + self.run_id)
            summary_writer.add_graph(sess.graph)

            print('Start training...')
            print('----------------------------------')
            report_freq = FLAGS.num_train_images // FLAGS.train_batch_size  # steps per epoch
            train_steps = FLAGS.train_epoch * report_freq

            durations = []
            train_loss_list = []
            train_total_loss_list = []
            train_accuracy_list = []
            best_accuracy = 0
            for step in range(train_steps):
                tflearn.is_training(True)
                start_time = time.time()
                _, summary_str, loss_value, total_loss, accuracy = sess.run(
                    [self.train_op, self.summary_op, self.loss,
                     self.total_loss, self.accuracy],
                    feed_dict={self.am_training: True})
                durations.append(time.time() - start_time)
                train_loss_list.append(loss_value)
                train_total_loss_list.append(total_loss)
                train_accuracy_list.append(accuracy)
                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if step % report_freq != 0:
                    continue

                # Once per epoch: summaries, validation, checkpoint.
                start_time = time.time()
                summary_writer.add_summary(summary_str, step)
                sec_per_report = np.sum(np.array(durations))
                train_loss = np.mean(np.array(train_loss_list))
                train_total_loss = np.mean(np.array(train_total_loss_list))
                train_accuracy_value = np.mean(np.array(train_accuracy_list))
                durations = []
                train_loss_list = []
                train_total_loss_list = []
                train_accuracy_list = []
                summary_writer.add_summary(_scalar_summary([
                    ("train_loss", train_loss),
                    ("train_total_loss", train_total_loss),
                    ("train_accuracy", train_accuracy_value),
                ]), step)

                vali_loss_value, vali_accuracy_value = self._full_validation(
                    vali_data, sess)
                if vali_accuracy_value > best_accuracy:
                    best_accuracy = vali_accuracy_value
                    model_dir = os.path.join(FLAGS.log_dir, self.run_id, 'model')
                    # makedirs: log_dir/run_id itself may not exist (the
                    # FileWriter above writes to log_dir + run_id, which is a
                    # different path when log_dir has no trailing separator).
                    if not os.path.isdir(model_dir):
                        os.makedirs(model_dir)
                    checkpoint_path = os.path.join(
                        model_dir, 'vali_{:.3f}'.format(vali_accuracy_value))
                    self.saver.save(sess, checkpoint_path, global_step=step)

                summary_writer.add_summary(_scalar_summary([
                    ("vali_loss", vali_loss_value),
                    ("vali_accuracy", vali_accuracy_value),
                ]), step)
                summary_writer.flush()
                vali_duration = time.time() - start_time
                format_str = (
                    'Epoch %d, loss = %.4f, total_loss = %.4f, accuracy = %.4f, vali_loss = %.4f, vali_accuracy = %.4f (%.3f '
                    'sec/report)')
                print(format_str % (step // report_freq, train_loss,
                                    train_total_loss, train_accuracy_value,
                                    vali_loss_value, vali_accuracy_value,
                                    sec_per_report + vali_duration))
            summary_writer.close()
    finally:
        # Stop the input queue-runner threads and release the session, even
        # when training ends with an exception.
        if coord is not None:
            coord.request_stop()
            coord.join(threads)
        sess.close()
def main(model, data_path, output_path, test_range=None,
         image_dir='/home/spc/Documents/AD/'):
    """Evaluate a segmentation model plus the ``adversatial_2`` network per slice.

    For each selected subject in the HDF5 file the function:

    1. restores the checkpoint ``model``;
    2. predicts a one-hot mask for every slice;
    3. runs the ``adversatial_2`` network (presumably a discriminator) on the
       image paired with the binarised ground-truth mask and with the
       predicted mask;
    4. writes a side-by-side PNG (image | label | prediction, scaled by 255)
       to ``image_dir/<subject index>/``;
    5. appends ``jaccard<TAB>true_score<TAB>false_score`` per slice to
       ``<output_path>/result_info``.

    Args:
        model: path of the checkpoint to restore.
        data_path: HDF5 file with ``input`` / ``label`` / ``location`` groups.
        output_path: directory for ``result_info``; created if missing.
        test_range: iterable of indices into the subject list to evaluate.
            ``None`` (previously a TypeError) evaluates every subject.
        image_dir: root directory for the per-slice PNGs. The default is the
            path that used to be hard-coded.
    """
    datatype = 'val'
    if not os.path.isdir(output_path):
        os.makedirs(output_path)
    h5f = h5py.File(data_path, 'r')
    log_f = open(os.path.join(output_path, 'result_info'), 'a')
    try:
        for dataset in ['val']:
            # Find patient IDs in this dataset.
            if datatype == 'Sunny':
                subjects = list(h5f[dataset + '/location/'].keys())
            else:
                subjects = list(h5f['location/'].keys())
            if test_range is not None:
                subjects = [subjects[i] for i in test_range]
            print(subjects)

            with tf.Session(config=tf_config) as sess:
                tflearn.is_training(False, session=sess)
                batch_data = tf.placeholder(tf.float32, shape=[1, 256, 256, 1],
                                            name='batch_data')
                batch_label = tf.placeholder(tf.float32, shape=[1, 256, 256, 3],
                                             name='batch_label')
                # NOTE(review): scope placement reconstructed from a
                # whitespace-mangled original; verify the adversarial
                # network's variable names against the checkpoint.
                with tf.variable_scope(FLAGS.net_name):
                    logits = getattr(network, FLAGS.net_name)(
                        inputs=batch_data, prob_fc=1, prob_conv=1, wd=0,
                        wd_scale=0, training_phase=False)
                logits = logits[0]
                print(logits.shape)

                # Ground truth: label channels 1-2 thresholded at 0.5.
                label = batch_label[:, :, :, 1:3]
                label = tf.greater(label, tf.ones(label.shape) * 0.5)
                label = tf.cast(label, tf.float32)
                al_true = getattr(network, 'adversatial_2')(batch_data, label,
                                                            reuse=False)
                al_true = tf.nn.softmax(al_true)

                # Prediction: one-hot of the per-pixel argmax (ties mark every
                # maximal class), keeping channels 1-2 like the label.
                y_pred = tf.nn.softmax(logits[0])
                y_pred = tf.reshape(y_pred, [1, 256, 256, 3])
                maxvalue = tf.reduce_max(y_pred, axis=-1)
                print(y_pred.shape)
                y_pred = tf.equal(y_pred,
                                  tf.stack([maxvalue, maxvalue, maxvalue], axis=-1))
                pred = tf.cast(y_pred[:, :, :, 1:], tf.float32)
                al_false = getattr(network, 'adversatial_2')(batch_data, pred,
                                                             reuse=True)
                al_false = tf.nn.softmax(al_false)

                saver = tf.train.Saver(tf.global_variables())
                # Was ``model_path``, a name not defined in this function;
                # the checkpoint path is the ``model`` argument.
                saver.restore(sess, model)
                print('Model restored from ', model)

                # HDF5 key prefix: Sunny data is nested under the dataset name.
                prefix = dataset + '/' if datatype == 'Sunny' else '/'
                for Pid, Patient_No in enumerate(subjects):
                    path = os.path.join(image_dir, str(Pid), '')
                    if not os.path.isdir(path):
                        os.makedirs(path)
                    print(Pid)

                    # Read input (normalised to max 1) and label volumes.
                    volume = h5f[prefix + 'input/%s/' % Patient_No][:, :, :, 0]
                    volume = volume / np.max(volume)
                    volume = volume[..., np.newaxis]
                    label_volume = h5f[prefix + 'label/%s/' % Patient_No][:]

                    num_batches = volume.shape[0]
                    true_list = np.zeros(num_batches)
                    false_list = np.zeros(num_batches)
                    for step in range(num_batches):
                        true_label, false_label, label_np, pred_np, image_np = sess.run(
                            [al_true, al_false, label, pred, batch_data],
                            feed_dict={
                                batch_data: volume[step:step + 1],
                                batch_label: label_volume[step:step + 1],
                            })
                        true_list[step] = np.argmax(true_label)
                        false_list[step] = 1 - np.argmax(false_label)
                        jaccard = Jaccard(label_np[0, :, :, 0], pred_np[0, :, :, 0])
                        name = '{0}_{1:.03}_{2:.03}_{3:.03}'.format(
                            step, jaccard, true_label[0, 0], false_label[0, 0])
                        concat_img = np.hstack((image_np[0, :, :, 0],
                                                label_np[0, :, :, 0],
                                                pred_np[0, :, :, 0])) * 255
                        cv2.imwrite(os.path.join(path, name + '.png'), concat_img)
                        log_f.write('{0:.03}\t{1:.03}\t{2:.03}\n'.format(
                            jaccard, true_label[0, 0], false_label[0, 0]))
                    print(np.mean(true_list), np.mean(false_list))
            if datatype != 'Sunny':
                break
    finally:
        log_f.close()
        h5f.close()
def train(self, **kwargs):
    """Train from scratch, or fine-tune a pruned model, validating every epoch.

    With no keyword arguments the graph is built with ``self.dict_widx = None``
    and all variables are initialised. Otherwise ``kwargs`` must provide:

    * ``dict_widx`` -- stored on ``self.dict_widx`` before ``_build_graph()``;
    * ``pruned_model_path`` -- checkpoint restored after initialisation.

    Runs ``FLAGS.train_epoch * (FLAGS.num_train_images // FLAGS.batch_size)``
    steps. At every epoch boundary (including step 0) training and
    validation metrics are written to TensorBoard and printed. Every 50
    epochs a checkpoint named ``epoch_<epoch>_acc_<val acc>`` is saved.

    Raises:
        KeyError: if kwargs is non-empty but lacks a required key.
        AssertionError: if the training loss becomes NaN.
    """
    def _scalar_summary(pairs):
        # ``float()`` replaces ``.astype(np.float)``; ``np.float`` was removed
        # in NumPy 1.24.
        summ = tf.Summary()
        for tag, value in pairs:
            summ.value.add(tag=tag, simple_value=float(value))
        return summ

    ops.reset_default_graph()
    sess = tf.Session(config=self.tf_config)
    coord = None
    threads = []
    try:
        with sess.as_default():
            tflearn.is_training(True, session=sess)
            self.am_training = tf.placeholder(dtype=bool, shape=())
            self.prob_fc = tf.placeholder_with_default(0.5, shape=())
            self.prob_conv = tf.placeholder_with_default(0.5, shape=())
            data_fn = functools.partial(
                input_fn,
                data_dir=os.path.join(FLAGS.data_dir, FLAGS.set_id),
                num_shards=1,
                batch_size=FLAGS.batch_size,
                use_distortion_for_training=True)
            self.batch_data, self.batch_labels = tf.cond(
                self.am_training,
                lambda: data_fn(subset='train'),
                lambda: data_fn(subset='test'))
            # Single shard: unwrap the per-shard lists.
            self.batch_data = self.batch_data[0]
            self.batch_labels = self.batch_labels[0]

            if not kwargs:
                self.dict_widx = None
                self._build_graph()
                self.saver = tf.train.Saver(tf.global_variables())
                sess.run(tf.global_variables_initializer())
            else:
                self.dict_widx = kwargs['dict_widx']
                pruned_model = kwargs['pruned_model_path']
                self._build_graph()
                sess.run(tf.global_variables_initializer())
                self.saver = tf.train.Saver(tf.global_variables())
                self.saver.restore(sess, pruned_model)
                print('Pruned model restored from ', pruned_model)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            # This summary writer object helps write summaries on tensorboard
            summary_writer = tf.summary.FileWriter(FLAGS.log_dir + self.run_id)
            summary_writer.add_graph(sess.graph)

            print('Start training...')
            print('----------------------------------')
            report_freq = FLAGS.num_train_images // FLAGS.batch_size  # steps per epoch
            train_steps = FLAGS.train_epoch * report_freq
            durations = []
            train_loss_list = []
            train_total_loss_list = []
            train_accuracy_list = []
            nparams = calculate_number_of_parameters(tf.trainable_variables())
            print(nparams)

            for step in range(train_steps):
                start_time = time.time()
                tflearn.is_training(True, session=sess)
                # The input tensors are fetched alongside the train op so the
                # input pipeline is consumed exactly as before.
                _, _, _, summary_str, loss_value, total_loss, accuracy = sess.run(
                    [self.batch_data, self.batch_labels, self.train_op,
                     self.summary_op, self.loss, self.total_loss, self.accuracy],
                    feed_dict={self.am_training: True,
                               self.prob_fc: FLAGS.keep_prob_fc,
                               self.prob_conv: FLAGS.keep_prob_conv})
                tflearn.is_training(False, session=sess)
                durations.append(time.time() - start_time)
                train_loss_list.append(loss_value)
                train_total_loss_list.append(total_loss)
                train_accuracy_list.append(accuracy)
                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if step % report_freq != 0:
                    continue

                start_time = time.time()
                summary_writer.add_summary(summary_str, step)
                sec_per_report = np.sum(np.array(durations))
                train_loss = np.mean(np.array(train_loss_list))
                train_total_loss = np.mean(np.array(train_total_loss_list))
                train_accuracy_value = np.mean(np.array(train_accuracy_list))
                durations = []
                train_loss_list = []
                train_total_loss_list = []
                train_accuracy_list = []
                summary_writer.add_summary(_scalar_summary([
                    ("train_loss", train_loss),
                    ("train_total_loss", train_total_loss),
                    ("train_accuracy", train_accuracy_value),
                ]), step)

                vali_loss_value, vali_accuracy_value = self._full_validation(sess)
                if step % (report_freq * 50) == 0:
                    # Label the checkpoint with the real epoch number. The
                    # original used step / (report_freq * 10), a float equal
                    # to epoch / 10 that disagreed with the printed epoch.
                    epoch = step // report_freq
                    model_dir = os.path.join(FLAGS.log_dir, self.run_id, 'model')
                    # makedirs: log_dir/run_id may not exist yet.
                    if not os.path.isdir(model_dir):
                        os.makedirs(model_dir)
                    checkpoint_path = os.path.join(
                        model_dir,
                        'epoch_{}_acc_{:.3f}'.format(epoch, vali_accuracy_value))
                    self.saver.save(sess, checkpoint_path, global_step=step)

                summary_writer.add_summary(_scalar_summary([
                    ("vali_loss", vali_loss_value),
                    ("vali_accuracy", vali_accuracy_value),
                ]), step)
                summary_writer.flush()
                vali_duration = time.time() - start_time
                format_str = ('Epoch %d, loss = %.4f, total_loss = %.4f, acc = %.4f, vali_loss = %.4f, val_acc = %.4f (%.3f '
                              'sec/report)')
                print(format_str % (step // report_freq, train_loss,
                                    train_total_loss, train_accuracy_value,
                                    vali_loss_value, vali_accuracy_value,
                                    sec_per_report + vali_duration))
            summary_writer.close()
    finally:
        # Stop the queue-runner threads and free the session's resources.
        if coord is not None:
            coord.request_stop()
            coord.join(threads)
        sess.close()
def test(tf_config, test_path, probs):
    """Evaluate a trained model on the test split.

    Args:
        tf_config: TensorFlow session / GPU configuration.
        test_path: checkpoint path of the trained model to restore.
        probs: if True, collect the softmax class probabilities; if False,
            collect the argmax class predictions. Previously the False path
            collected nothing and returned empty lists.

    Returns:
        ``(prediction_array, label_array, accuracy)``. The two arrays are the
        per-batch results concatenated along axis 0 (empty arrays if no batch
        was run); labels are the argmax of the one-hot test labels.
        ``accuracy`` is the mean per-batch accuracy (Dice when ``FLAGS.seg``).
    """
    ops.reset_default_graph()
    sess = tf.Session(config=tf_config)
    try:
        with sess.as_default():
            tflearn.is_training(False, session=sess)
            batch_data, batch_labels, subject, index = input_fn(
                data_dir=os.path.join(FLAGS.data_dir, FLAGS.set_id),
                num_shards=1,
                batch_size=FLAGS.test_batch_size,
                data_range=TEST_RANGE,
                subset='test')
            with tf.variable_scope(FLAGS.net_name):
                logits = getattr(network, FLAGS.net_name)(
                    inputs=batch_data[0], prob_fc=1, prob_conv=1, wd=0,
                    wd_scale=0, training_phase=False)
            preclass = tf.nn.softmax(logits)
            prediction = tf.argmax(preclass, -1)
            label = tf.argmax(batch_labels[0], -1)
            if FLAGS.seg:
                accuracy = dice(labels=batch_labels[0], logits=logits,
                                am_training=False)
            else:
                correct_prediction = tf.equal(prediction, label)
                accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

            saver = tf.train.Saver(tf.global_variables())
            saver.restore(sess, test_path)
            print('Model restored from ', test_path)

            num_batches = FLAGS.num_test_images // FLAGS.test_batch_size
            accuracy_list = []
            prediction_list = []
            label_list = []
            for _ in range(num_batches):
                # Fetch lists are kept as before so the input pipeline is
                # consumed identically; unused outputs are discarded.
                if probs:
                    _, batch_pred, batch_accuracy, batch_label = sess.run(
                        [batch_data, preclass, accuracy, label])
                else:
                    _, batch_label, _, _, batch_pred, batch_accuracy = sess.run(
                        [batch_data, label, subject, index, prediction, accuracy])
                prediction_list.append(batch_pred)
                label_list.append(batch_label)
                accuracy_list.append(batch_accuracy)
    finally:
        sess.close()

    # Concatenate along the batch axis; this also works for per-pixel (seg)
    # labels, which the old 1-D concatenation could not handle.
    prediction_array = (np.concatenate(prediction_list, axis=0)
                        if prediction_list else np.array([]))
    label_array = (np.concatenate(label_list, axis=0)
                   if label_list else np.array([]))
    accuracy = np.mean(np.array(accuracy_list, dtype=np.float32))
    print('{:.3f}'.format(accuracy))
    return prediction_array, label_array, accuracy
def train(self, **kwargs):
    """Training body.

    The starting point is chosen by ``FLAGS.status``:

    * ``'scratch'``: build the graph with ``self.dict_widx = None`` and
      initialise all variables.
    * ``'prune'``: fine-tune a filter-pruned model.
      ``kwargs['dict_widx']`` (the pruned weight matrix) is stored on
      ``self.dict_widx``, and ``kwargs['pruned_model_path']`` is the
      checkpoint to restore. Variables whose name contains ``'Adam'`` are
      neither restored nor saved.
    * ``'transfer'``: build and initialise the graph, then restore the
      second variable list returned by ``get_trainable_variables`` from
      ``FLAGS.checkpoint_path``.

    Trains for ``FLAGS.train_epoch`` epochs of
    ``FLAGS.num_train_images // FLAGS.batch_size`` steps. At every epoch
    boundary (including step 0) training and validation metrics are written
    to TensorBoard and printed. A checkpoint is saved every
    ``FLAGS.save_epoch`` epochs.

    Raises:
        ValueError: if ``FLAGS.status`` is not one of the values above.
        KeyError: in prune mode, if a required kwarg is missing.
        AssertionError: if the training loss becomes NaN.
    """
    def _scalar_summary(pairs):
        # ``float()`` replaces ``.astype(np.float)``; ``np.float`` was removed
        # in NumPy 1.24.
        summ = tf.Summary()
        for tag, value in pairs:
            summ.value.add(tag=tag, simple_value=float(value))
        return summ

    ops.reset_default_graph()
    sess = tf.Session(config=self.tf_config)
    try:
        with sess.as_default():
            # Data reading objects
            tflearn.is_training(True, session=sess)
            self.am_training = tf.placeholder(dtype=bool, shape=())
            self.prob_fc = tf.placeholder_with_default(0.5, shape=())
            self.prob_conv = tf.placeholder_with_default(0.5, shape=())
            data_fn = functools.partial(
                input_fn,
                data_dir=os.path.join(FLAGS.data_dir, FLAGS.set_id),
                num_shards=FLAGS.num_gpus,
                batch_size=FLAGS.batch_size,
                use_distortion_for_training=True)
            self.batch_data, self.batch_labels, _, _ = tf.cond(
                self.am_training,
                lambda: data_fn(data_range=self.train_range, subset='train'),
                lambda: data_fn(data_range=self.vali_range, subset='test'))

            if FLAGS.status == 'scratch':
                self.dict_widx = None
                self._build_graph()
                self.saver = tf.train.Saver(tf.global_variables())
                sess.run(tf.global_variables_initializer())
            elif FLAGS.status == 'prune':
                self.dict_widx = kwargs['dict_widx']
                pruned_model = kwargs['pruned_model_path']
                self._build_graph()
                sess.run(tf.global_variables_initializer())
                # Skip Adam-related variables (e.g. optimizer slots), which
                # the pruned checkpoint is not expected to contain.
                restore_vars = [v for v in tf.global_variables()
                                if 'Adam' not in v.name.split(':')[0]]
                self.saver = tf.train.Saver(restore_vars)
                self.saver.restore(sess, pruned_model)
                print('Pruned model restored from ', pruned_model)
            elif FLAGS.status == 'transfer':
                self._build_graph()
                sess.run(tf.global_variables_initializer())
                _, transfer_vars = get_trainable_variables(FLAGS.checkpoint_path)
                self.saver = tf.train.Saver(tf.global_variables())
                tf.train.Saver(transfer_vars).restore(sess, FLAGS.checkpoint_path)
                print('Model restored.')
            else:
                # Previously fell through and failed later on self.summary_op.
                raise ValueError(
                    'Unknown FLAGS.status {!r}; expected scratch, prune or '
                    'transfer'.format(FLAGS.status))

            # This summary writer object helps write summaries on tensorboard
            summary_writer = tf.summary.FileWriter(FLAGS.log_dir + self.run_id)
            summary_writer.add_graph(sess.graph)

            print('Start training...')
            print('----------------------------------')
            report_freq = FLAGS.num_train_images // FLAGS.batch_size  # steps per epoch
            train_steps = FLAGS.train_epoch * report_freq
            durations = []
            train_loss_list = []
            train_total_loss_list = []
            train_accuracy_list = []
            nparams = op_utils.calculate_number_of_parameters(
                tf.trainable_variables())
            print(nparams)

            for step in range(train_steps):
                start_time = time.time()
                tflearn.is_training(True, session=sess)
                _, _, _, summary_str, loss_value, total_loss, accuracy = sess.run(
                    [self.batch_data, self.batch_labels, self.train_op,
                     self.summary_op, self.loss, self.total_loss, self.accuracy],
                    feed_dict={self.am_training: True,
                               self.prob_fc: FLAGS.keep_prob_fc,
                               self.prob_conv: FLAGS.keep_prob_conv})
                tflearn.is_training(False, session=sess)
                durations.append(time.time() - start_time)
                train_loss_list.append(loss_value)
                train_total_loss_list.append(total_loss)
                train_accuracy_list.append(accuracy)
                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if step % report_freq != 0:
                    continue

                start_time = time.time()
                summary_writer.add_summary(summary_str, step)
                sec_per_report = np.sum(np.array(durations))
                train_loss = np.mean(np.array(train_loss_list))
                train_total_loss = np.mean(np.array(train_total_loss_list))
                train_accuracy_value = np.mean(np.array(train_accuracy_list))
                durations = []
                train_loss_list = []
                train_total_loss_list = []
                train_accuracy_list = []
                summary_writer.add_summary(_scalar_summary([
                    ("train_loss", train_loss),
                    ("train_total_loss", train_total_loss),
                    ("train_accuracy", train_accuracy_value),
                ]), step)

                vali_loss_value, vali_accuracy_value = self._full_validation(sess)
                if step % (report_freq * FLAGS.save_epoch) == 0:
                    # Use the real epoch number. The original divided by
                    # report_freq * save_epoch, which gives the save index (as
                    # a float), not the epoch named in the file.
                    epoch = step // report_freq
                    model_dir = os.path.join(FLAGS.log_dir, self.run_id, 'model')
                    # makedirs: log_dir/run_id may not exist yet.
                    if not os.path.isdir(model_dir):
                        os.makedirs(model_dir)
                    checkpoint_path = os.path.join(
                        model_dir,
                        'epoch_{}_acc_{:.3f}'.format(epoch, vali_accuracy_value))
                    self.saver.save(sess, checkpoint_path, global_step=step)

                summary_writer.add_summary(_scalar_summary([
                    ("vali_loss", vali_loss_value),
                    ("vali_accuracy", vali_accuracy_value),
                ]), step)
                summary_writer.flush()
                vali_duration = time.time() - start_time
                format_str = (
                    'Epoch %d, loss = %.4f, total_loss = %.4f, acc = %.4f, vali_loss = %.4f, val_acc = %.4f (%.3f '
                    'sec/report)')
                print(format_str % (step // report_freq, train_loss,
                                    train_total_loss, train_accuracy_value,
                                    vali_loss_value, vali_accuracy_value,
                                    sec_per_report + vali_duration))
            summary_writer.close()
    finally:
        sess.close()