def train(self):
    """Run the training loop with periodic validation, summaries and checkpoints.

    If ``self.conf.reload_step`` is positive that checkpoint is restored
    first; summary and checkpoint steps are then offset by it.  Steps run
    from 0 to ``self.conf.max_step`` inclusive.  On every non-zero step:

    * a multiple of ``test_interval`` additionally evaluates one validation
      batch (loss and validation summary, no weight update);
    * a multiple of ``summary_interval`` runs the training op together with
      the training summary;
    * a multiple of ``save_interval`` writes a checkpoint.

    Every other step (including step 0) runs a plain training step and
    prints its loss.
    """
    conf = self.conf
    if conf.reload_step > 0:
        self.reload(conf.reload_step)
    train_reader = H5DataLoader(conf.data_dir + conf.train_data)
    valid_reader = H5DataLoader(conf.data_dir + conf.valid_data)

    def make_feed(reader):
        # Pull the next batch from ``reader`` and map it onto the placeholders.
        batch_inputs, batch_labels = reader.next_batch(conf.batch)
        return {self.inputs: batch_inputs, self.labels: batch_labels}

    for step in range(conf.max_step + 1):
        offset_step = step + conf.reload_step
        if step > 0 and step % conf.test_interval == 0:
            loss, summary = self.sess.run(
                [self.loss_op, self.valid_summary],
                feed_dict=make_feed(valid_reader))
            self.save_summary(summary, offset_step)
            print('----testing loss', loss)
        write_summary = step > 0 and step % conf.summary_interval == 0
        feed = make_feed(train_reader)
        if write_summary:
            loss, _, summary = self.sess.run(
                [self.loss_op, self.train_op, self.train_summary],
                feed_dict=feed)
            self.save_summary(summary, offset_step)
        else:
            loss, _ = self.sess.run([self.loss_op, self.train_op],
                                    feed_dict=feed)
            print('----training loss', loss)
        if step > 0 and step % conf.save_interval == 0:
            self.save(offset_step)
def train(self):
    """Train a 2D or 3D segmentation model for ``self.conf.max_step`` steps.

    The data loader is ``H5DataLoader`` when ``conf.data_type == '2D'`` and
    ``H53DDataLoader`` (given ``self.input_shape``) otherwise.  Each step does
    exactly one of:

    * on a multiple of ``test_interval`` (including step 0): evaluate one
      validation batch, write the validation summary, and print the loss,
      accuracy and dice accuracy (no training step);
    * on a multiple of ``summary_interval``: one training step with the
      training summary;
    * otherwise: one training step that also fetches accuracy and dice
      accuracy, prints them and writes the training summary.

    A checkpoint is saved on every multiple of ``save_interval`` (including
    step 0).  Summary/checkpoint steps are offset by ``conf.reload_step``.
    """
    conf = self.conf
    if conf.reload_step > 0:
        self.reload(conf.reload_step)
    train_path = conf.data_dir + conf.train_data
    valid_path = conf.data_dir + conf.valid_data
    if conf.data_type == '2D':
        train_reader = H5DataLoader(train_path)
        valid_reader = H5DataLoader(valid_path)
    else:
        train_reader = H53DDataLoader(train_path, self.input_shape)
        valid_reader = H53DDataLoader(valid_path, self.input_shape)

    def next_feed(reader):
        # Next batch from ``reader`` mapped onto the input/annotation placeholders.
        batch_inputs, batch_annotations = reader.next_batch(conf.batch)
        return {self.inputs: batch_inputs,
                self.annotations: batch_annotations}

    for step in range(conf.max_step):
        offset_step = step + conf.reload_step
        if step % conf.test_interval == 0:
            fetched = self.sess.run(
                [self.loss_op, self.valid_summary,
                 self.accuracy_op, self.dice_accuracy_op],
                feed_dict=next_feed(valid_reader))
            loss, summary, accuracy, dice_accuracy = fetched
            self.save_summary(summary, offset_step)
            print('----valid loss', loss)
            print('----valid accuracy', accuracy)
            print('----valid dice accuracy', dice_accuracy)
        elif step % conf.summary_interval == 0:
            loss, _, summary = self.sess.run(
                [self.loss_op, self.train_op, self.train_summary],
                feed_dict=next_feed(train_reader))
            self.save_summary(summary, offset_step)
        else:
            fetched = self.sess.run(
                [self.loss_op, self.train_summary, self.train_op,
                 self.accuracy_op, self.dice_accuracy_op],
                feed_dict=next_feed(train_reader))
            loss, summary, _, accuracy, dice_accuracy = fetched
            print('----train loss', loss)
            print('----train accuracy', accuracy)
            print('----train dice accuracy', dice_accuracy)
            self.save_summary(summary, offset_step)
        if step % conf.save_interval == 0:
            self.save(offset_step)
def train(self):
    """Alternate discriminator and generator updates until ``max_step``.

    Progress is measured by ``train_reader.iter + conf.reload_step``.  Each
    pass of the loop does exactly one of:

    * if that counter changed since the last pass: evaluate one validation
      batch for the D loss and then the G loss, writing a validation summary
      for each, save a checkpoint tagged with the counter, and print both
      losses;
    * every ``summary_interval`` passes: one D step and one G step, each
      writing a training summary tagged ``epoch_num + reload_step``;
    * otherwise: one D step and one G step, printing each loss.
    """
    conf = self.conf
    if conf.reload_step > 0:
        self.reload(conf.reload_step)
    train_reader = H5DataLoader(conf.data_dir+conf.train_data)
    valid_reader = H5DataLoader(conf.data_dir+conf.valid_data)

    def feed_from(reader):
        # Next (inputs, labels, catgory) batch mapped onto the placeholders.
        batch_inputs, batch_labels, batch_catgory = reader.next_batch(
            conf.batch)
        return {self.inputs: batch_inputs, self.labels: batch_labels,
                self.catgory: batch_catgory}

    def reader_position():
        return train_reader.iter + conf.reload_step

    iteration = reader_position()
    last_seen = iteration
    epoch_num = 0
    while iteration < conf.max_step:
        if iteration != last_seen:
            last_seen = iteration
            feed = feed_from(valid_reader)
            d_loss, summary = self.sess.run(
                [self.d_loss_total, self.valid_summary], feed_dict=feed)
            self.save_summary(summary, iteration)
            print('----testing d loss', d_loss)
            g_loss, summary = self.sess.run(
                [self.g_loss_total, self.valid_summary], feed_dict=feed)
            self.save_summary(summary, iteration)
            self.save(iteration)
            print('----testing g loss', g_loss)
        else:
            write_summaries = epoch_num % conf.summary_interval == 0
            feed = feed_from(train_reader)
            summary_step = epoch_num + conf.reload_step
            d_loss, _, summary = self.sess.run(
                [self.d_loss_total, self.d_train, self.train_summary],
                feed_dict=feed)
            if write_summaries:
                self.save_summary(summary, summary_step)
            else:
                print('----training d loss', d_loss)
            g_loss, _, summary = self.sess.run(
                [self.g_loss_total, self.g_train, self.train_summary],
                feed_dict=feed)
            if write_summaries:
                self.save_summary(summary, summary_step)
            else:
                print('----training g loss', g_loss)
        iteration = reader_position()
        epoch_num += 1
def test(self):
    """Evaluate checkpoint ``self.conf.test_step`` on the test set.

    The loader is ``H5DataLoader`` for 2D data and ``H53DDataLoader`` (given
    ``self.input_shape``) otherwise.  ``next_batch`` returns
    ``(more, inputs, labels)``.  Batches are consumed until ``more`` is
    falsy or ``inputs`` is ``None``.

    Prints per-batch loss / accuracy / mean IoU, then the mean loss, mean
    accuracy and the mean-IoU value from the last batch.  If the reader
    yields no batch at all, a message is printed instead of crashing.
    """
    print('---->testing ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    test_path = self.conf.data_dir + self.conf.test_data
    if self.conf.data_type == '2D':
        test_reader = H5DataLoader(test_path, False)
    else:
        test_reader = H53DDataLoader(test_path, self.input_shape)
    # Reset local variables; presumably these back the streaming mean-IoU
    # metric updated by self.miou_op — confirm against the model definition.
    self.sess.run(tf.local_variables_initializer())
    losses = []
    accuracies = []
    m_iou = None
    has_more = True
    while has_more:
        has_more, inputs, labels = test_reader.next_batch(self.conf.batch)
        if inputs is None:
            break
        feed_dict = {self.inputs: inputs, self.labels: labels}
        loss, accuracy, m_iou, _ = self.sess.run(
            [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op],
            feed_dict=feed_dict)
        print('values----->', loss, accuracy, m_iou)
        losses.append(loss)
        accuracies.append(accuracy)
    if not losses:
        # Previously this fell through to m_ious[-1] and raised IndexError.
        print('no test batches were produced by the reader')
        return
    print('Loss: ', np.mean(losses))
    print('Accuracy: ', np.mean(accuracies))
    print('M_iou: ', m_iou)
def store(self):
    """Dump the test set to disk as JPEG images and annotation PNGs.

    Images go to ``JPEGImages/<n>.jpg`` (via ``scipy.misc.imsave``).
    Annotations go to ``Annotations/<n>.png`` (via the project ``imsave``
    helper, called as ``imsave(array, path)``).  Both directories are
    relative to the working directory and are created if missing.  Samples
    are numbered consecutively.

    Reading stops at the first batch that is ``None`` or smaller than
    ``self.conf.batch``, so a trailing partial batch is not stored.
    """
    import os

    print('---->storing ', self.conf.test_step)
    test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data, False)
    images = []
    ground_truth = []
    while True:
        inputs, annotations = test_reader.next_batch(self.conf.batch)
        if inputs is None or inputs.shape[0] < self.conf.batch:
            break
        images.append(inputs)
        ground_truth.append(annotations)
    print('----->saving inputs and annotations')
    # The writers below do not create missing directories themselves.
    os.makedirs('JPEGImages', exist_ok=True)
    os.makedirs('Annotations', exist_ok=True)
    # All stored batches are full, so index * batch_size + i is a unique id.
    for index, image in enumerate(images):
        print(index)
        for i in range(image.shape[0]):
            scipy.misc.imsave(
                "JPEGImages/" + str(index * image.shape[0] + i) + '.jpg',
                image[i])
    print("Done storing JPEG imeges")
    for index, annotation in enumerate(ground_truth):
        print(index)
        for i in range(annotation.shape[0]):
            imsave(
                annotation[i],
                'Annotations/' + str(index * annotation.shape[0] + i) + '.png')
    print("Done storing annotations")
def test(self):
    """Run checkpoint ``test_step`` over the test set and dump the results.

    Batches are read while ``test_reader.iter < 1`` (presumably a single pass
    over the data — confirm against H5DataLoader).  The accuracy of each batch
    is printed.  All decoded predictions and annotations are concatenated
    along axis 0 and saved with ``np.savez('temp', predictions, labels)``.
    Finally, ``ops.dice_ratio`` is printed for the first two samples.
    """
    print('---->predicting ', self.conf.test_step)
    if not (self.conf.test_step > 0):
        print("please set a reasonable test_step")
        return
    self.reload(self.conf.test_step)
    test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data, True)
    batch_predictions, batch_annotations = [], []
    while test_reader.iter < 1:
        inputs, annotations = test_reader.next_batch(self.conf.batch)
        decoded, acc = self.sess.run(
            [self.decoded_predictions, self.accuracy_op],
            feed_dict={self.inputs: inputs, self.annotations: annotations})
        print(acc)
        batch_predictions.append(decoded)
        batch_annotations.append(annotations)
    predictions = np.concatenate(batch_predictions, axis=0)
    labels = np.concatenate(batch_annotations, axis=0)
    print('---', predictions.shape)
    print(labels.shape)
    np.savez('temp', predictions, labels)
    print(predictions.shape)
    for sample in (0, 1):
        print(ops.dice_ratio(predictions[sample], labels[sample]))
def predict(self):
    """Predict the test set with checkpoint ``test_step`` and save the results.

    Each decoded prediction is written with the project ``imsave`` helper
    (called as ``imsave(array, path)``) to ``self.conf.sampledir + '<n>.png'``,
    numbered consecutively.  Only full batches are predicted: the loop stops
    at the first batch that is ``None`` or smaller than ``self.conf.batch``.
    """
    print('---->predicting ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    if self.conf.data_type == '2D':
        test_reader = H5DataLoader(
            self.conf.data_dir+self.conf.test_data, False)
    else:
        test_reader = H53DDataLoader(
            self.conf.data_dir+self.conf.test_data, self.input_shape)
    predictions = []
    while True:
        inputs, labels = test_reader.next_batch(self.conf.batch)
        # Also stop when the reader signals exhaustion with None; without
        # this check ``inputs.shape`` raised AttributeError in that case.
        if inputs is None or inputs.shape[0] < self.conf.batch:
            break
        feed_dict = {self.inputs: inputs, self.labels: labels}
        predictions.append(self.sess.run(
            self.decoded_preds, feed_dict=feed_dict))
    print('----->saving predictions')
    # Every stored batch is full, so index * batch_size + i is a unique id.
    for index, prediction in enumerate(predictions):
        for i in range(prediction.shape[0]):
            imsave(prediction[i], self.conf.sampledir +
                   str(index*prediction.shape[0]+i)+'.png')
def train(self):
    """Train until ``train_reader.iter`` reaches ``self.conf.max_step``.

    Each pass of the loop does exactly one of:

    * if ``train_reader.iter`` changed since the last pass: evaluate one
      validation batch, write its summary, save a checkpoint tagged with
      ``train_reader.iter`` and print the loss;
    * every ``summary_interval`` passes: a training step with the training
      summary (tagged ``epoch_num + reload_step``);
    * otherwise: a training step whose loss is printed.
    """
    conf = self.conf
    if conf.reload_step > 0:
        self.reload(conf.reload_step)
    train_reader = H5DataLoader(conf.data_dir + conf.train_data)
    valid_reader = H5DataLoader(conf.data_dir + conf.valid_data)

    def annotated_feed(reader):
        # Next (inputs, annotations) batch mapped onto the placeholders.
        batch_inputs, batch_annotations = reader.next_batch(conf.batch)
        return {self.inputs: batch_inputs,
                self.annotations: batch_annotations}

    iteration = train_reader.iter
    last_seen = iteration
    epoch_num = 0
    while iteration < conf.max_step:
        if iteration != last_seen:
            last_seen = iteration
            loss, summary = self.sess.run(
                [self.loss_op, self.valid_summary],
                feed_dict=annotated_feed(valid_reader))
            self.save_summary(summary, iteration)
            self.save(iteration)
            print('----testing loss', loss)
        else:
            with_summary = epoch_num % conf.summary_interval == 0
            feed = annotated_feed(train_reader)
            if with_summary:
                loss, _, summary = self.sess.run(
                    [self.loss_op, self.train_op, self.train_summary],
                    feed_dict=feed)
                self.save_summary(summary, epoch_num + conf.reload_step)
            else:
                loss, _ = self.sess.run([self.loss_op, self.train_op],
                                        feed_dict=feed)
                print('----training loss', loss)
        iteration = train_reader.iter
        epoch_num += 1
def train(self):
    """Restore state via ``self.restore()`` and train up to ``max_step``.

    Training starts at ``self.global_step + 1`` when ``self.global_step`` is
    not ``None``, otherwise at 0, and runs to ``self.conf.max_step``
    inclusive.  Every step prints its index.  On each step:

    * a multiple of ``test_interval`` first evaluates one validation batch;
    * a multiple of ``summary_interval`` trains with the training summary,
      and any other step trains and prints its loss;
    * a multiple of ``save_interval`` saves a checkpoint.

    Summary and checkpoint steps are offset by ``self.conf.reload_step``.
    """
    self.restore()
    self.sess.run(tf.local_variables_initializer())
    conf = self.conf
    train_reader = H5DataLoader(conf.data_dir + conf.train_data)
    valid_reader = H5DataLoader(conf.data_dir + conf.valid_data)

    def feed_for(reader):
        # Next (inputs, annotations) batch mapped onto the placeholders.
        images, annotations = reader.next_batch(conf.batch)
        return {self.inputs: images, self.annotations: annotations}

    if self.global_step is None:
        first_step = 0
    else:
        first_step = self.global_step + 1
    for step in range(first_step, conf.max_step + 1):
        print(step)
        logged_step = step + conf.reload_step
        if step % conf.test_interval == 0:
            loss, summary = self.sess.run(
                [self.loss_op, self.valid_summary],
                feed_dict=feed_for(valid_reader))
            self.save_summary(summary, logged_step)
            print("Step: %d, Test_loss:%g" % (step, loss))
        with_summary = step % conf.summary_interval == 0
        feed = feed_for(train_reader)
        if with_summary:
            loss, _, summary = self.sess.run(
                [self.loss_op, self.train_op, self.train_summary],
                feed_dict=feed)
            self.save_summary(summary, logged_step)
        else:
            loss, _ = self.sess.run([self.loss_op, self.train_op],
                                    feed_dict=feed)
            print("Step: %d, Train_loss:%g" % (step, loss))
        if step % conf.save_interval == 0:
            self.save(logged_step)
def test(self):
    """Test the metric learning part in the end-to-end model.

    Runs the parent class's ``test`` first.  It then scores the learned
    metric with ``self.calculate_cost``, using an unpaired ``GenDataLoader``
    as the reference set and the paired test split as queries.  The k-NN
    accuracy, F1, cluster purity and Rand index are printed as one
    '&'-separated row.
    """
    super(base_model_metric, self).test()
    print('testing metric learning results...')
    conf = self.conf
    reference_reader = GenDataLoader(model_type='unpaired', conf=conf,
                                     portion=conf.portion)
    query_reader = H5DataLoader(conf.data_dir + conf.test_data,
                                model_type='paired', is_train=False)
    acc, f1, RI, purity = self.calculate_cost(self.sess, reference_reader,
                                              query_reader)
    row_format = ('knn accuracy f1, and cluster purity, RI '
                  '&%0.4f &%0.4f &&%0.4f &%0.4f')
    print(row_format % (acc, f1, purity, RI))
def predict(self):
    """Predict the test set with checkpoint ``test_step`` and save samples.

    ``next_batch`` returns ``(more, inputs, labels)``.  Batches are consumed
    until ``more`` is falsy or ``inputs`` is ``None``.  For every sample
    three files are written under ``self.conf.sampledir``, numbered
    consecutively across all batches:

    * ``<n>.png`` — the decoded prediction, via the project ``imsave``
      helper, called as ``imsave(array, path)``;
    * ``<n>.jpg`` — the input, cast to uint8;
    * ``<n>_label.png`` — the label, cast to uint8 and multiplied by 45.
      The factor presumably spreads class ids over the grey range; note that
      uint8 arithmetic wraps for labels >= 6.
    """
    print('---->predicting ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    if self.conf.data_type == '2D':
        test_reader = H5DataLoader(
            self.conf.data_dir + self.conf.test_data, False)
    else:
        test_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.test_data, self.input_shape)
    predictions = []
    final_inputs = []
    final_labels = []
    has_more = True
    while has_more:
        has_more, inputs, labels = test_reader.next_batch(self.conf.batch)
        if inputs is None:
            break
        final_inputs.append(inputs)
        final_labels.append(labels)
        feed_dict = {self.inputs: inputs, self.labels: labels}
        predictions.append(
            self.sess.run(self.decoded_preds, feed_dict=feed_dict))
    print('----->saving predictions')
    # Batches are kept as a list and cast per sample.  Stacking them with
    # np.array(..., 'uint8') failed when the last batch was shorter.  A
    # running counter replaces index * batch_len + i, which produced
    # colliding file names for a short final batch.
    sample_id = 0
    for prediction, batch_inputs, batch_labels in zip(
            predictions, final_inputs, final_labels):
        for i in range(prediction.shape[0]):
            prefix = self.conf.sampledir + str(sample_id)
            imsave(prediction[i], prefix + '.png')
            scipy.misc.imsave(prefix + '.jpg',
                              np.asarray(batch_inputs[i], 'uint8'))
            scipy.misc.imsave(prefix + '_label.png',
                              np.asarray(batch_labels[i], 'uint8') * 45)
            sample_id += 1
def test(self):
    """Evaluate checkpoint ``test_step`` and dump the summed confusion matrix.

    Only full batches are evaluated; reading stops at the first batch that is
    ``None`` or smaller than ``self.conf.batch``.  The method prints:

    * per-batch loss / accuracy / mean IoU;
    * the mean loss and mean accuracy;
    * the mean-IoU value from the last batch;
    * the confusion matrix summed over all batches.

    The summed matrix is also pickled to ``data_best.pickle`` in the working
    directory.  If no batch is evaluated, a message is printed instead.
    """
    print('---->testing ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data, False)
    self.sess.run(tf.local_variables_initializer())
    losses = []
    accuracies = []
    m_iou = None
    # Accumulate on the host.  Building tf.add(...) inside the loop added a
    # new graph op per batch, and its .eval() required a default session.
    confusion_matrix_total = np.zeros(
        [self.conf.class_num, self.conf.class_num], np.int32)
    while True:
        inputs, annotations = test_reader.next_batch(self.conf.batch)
        if inputs is None or inputs.shape[0] < self.conf.batch:
            break
        feed_dict = {self.inputs: inputs, self.annotations: annotations}
        loss, accuracy, m_iou, _, confusion_matrix = self.sess.run(
            [
                self.loss_op, self.accuracy_op, self.m_iou, self.miou_op,
                self.confusion_matrix
            ],
            feed_dict=feed_dict)
        print('values----->', loss, accuracy, m_iou)
        losses.append(loss)
        accuracies.append(accuracy)
        confusion_matrix_total = confusion_matrix_total + np.asarray(
            confusion_matrix)
    if not losses:
        print('no full test batch was available; nothing to report')
        return
    print('Loss: ', np.mean(losses))
    print('Accuracy: ', np.mean(accuracies))
    print('M_iou: ', m_iou)
    print('Confusion Matrix:')
    print('Dumping confusion matrix')
    with open('data_best.pickle', 'wb') as f:
        pickle.dump(confusion_matrix_total, f, pickle.HIGHEST_PROTOCOL)
    print(confusion_matrix_total)
def test(self):
    """Print the mean accuracy of checkpoint ``test_step`` over the test set.

    Only full batches are evaluated; reading stops at the first batch that is
    ``None`` or smaller than ``self.conf.batch``.  If no full batch exists, a
    message is printed instead of dividing by zero.
    """
    print('---->testing ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data, False)
    accuracies = []
    while True:
        inputs, labels = test_reader.next_batch(self.conf.batch)
        if inputs is None or inputs.shape[0] < self.conf.batch:
            break
        feed_dict = {self.inputs: inputs, self.labels: labels}
        accur = self.sess.run(self.accuracy_op, feed_dict=feed_dict)
        accuracies.append(accur)
    if not accuracies:
        print('no full test batch available; accuracy is undefined')
        return
    print('accuracy is ', sum(accuracies) / len(accuracies))
def train(self):
    """Jointly train the paired/unpaired generator-discriminator and metric head.

    Every step runs one discriminator update, then one generator update, on
    a paired batch (``train_pair_reader``) plus an unpaired batch
    (``train_unpair_reader``).

    * Every 200 steps the training losses and elapsed time are printed.
    * Every 1000 steps one validation batch is scored: D/G loss and the mean
      PSNR/SSIM of the generated images.  The k-NN / clustering metrics
      returned by ``self.calculate_cost`` are scored alongside it.
    * A checkpoint is saved whenever validation accuracy improves.
    * Training stops early after more than five consecutive validations
      without improvement, printing the best metrics seen.
    """
    conf = self.conf
    if conf.reload_step > 0:
        self.reload(conf.reload_step)
    train_pair_reader = H5DataLoader(conf.data_dir + conf.train_pair,
                                     model_type='paired_metric',
                                     portion=conf.portion)
    train_unpair_reader = H5DataLoader(conf.data_dir + conf.train_unpair,
                                       conf=conf,
                                       model_type='unpaired_metric')
    valid_reader = H5DataLoader(conf.data_dir + conf.valid_data,
                                model_type='paired_metric',
                                is_train=False)
    train_p2 = GenDataLoader(model_type='unpaired', conf=conf,
                             portion=conf.portion)
    valid_p2 = H5DataLoader(conf.data_dir + conf.valid_data,
                            model_type='paired', is_train=False)
    epoch_num = conf.reload_step
    start_time = time.time()
    best_acc = 0
    # Initialised up front so the early-stop report is always defined.  The
    # original raised NameError if no validation ever beat best_acc == 0.
    best_f1 = best_purity = best_RI = 0.
    epochs_no_performance_gain = 0
    while epoch_num < conf.max_step:
        Ap, Bp, p_label = train_pair_reader.next_batch(conf.batch)
        Au, Bu, up_label = train_unpair_reader.next_batch(conf.batch)
        pl_lab = get_pairwiselabel(np.concatenate((p_label, up_label), 0))
        feed_dict = {
            self.a_p: Ap,
            self.b_p: Bp,
            self.cat: p_label,
            self.a_u: Au,
            self.b_u: Bu,
            self.pl_total: pl_lab
        }
        d_loss, m_loss_d, _ = self.sess.run(
            [self.d_loss_total, self.m_loss, self.d_train],
            feed_dict=feed_dict)
        # self.tmp is still fetched as before; its value is not used here.
        g_loss, g1, g2, m_loss_g, _, _ = self.sess.run(
            [
                self.g_loss_total, self.g1_loss, self.g2_loss, self.m_loss,
                self.g_train, self.tmp
            ],
            feed_dict=feed_dict)
        if epoch_num % 200 == 0:
            print(
                'epoch %d, duration %0.2f, d_loss %0.3f, m_loss_d %0.3f, g_loss %0.3f, g1 %0.3f, g2 %0.3f, m_loss_g %0.3f'
                % (epoch_num, time.time() - start_time, d_loss, m_loss_d,
                   g_loss, g1, g2, m_loss_g))
            start_time = time.time()
        if epoch_num % 1000 == 0:
            # Validation batch.  The original note labels its parts as
            # (MRI, PET, category labels).
            Ap, Bp, p_label = valid_reader.next_batch(conf.batch)
            # NOTE(review): three copies of the labels are used here (training
            # uses paired + unpaired), presumably to match the size pl_total
            # expects — confirm.
            pl_lab = get_pairwiselabel(
                np.concatenate((p_label, p_label, p_label), 0))
            feed_dict = {
                self.a_p: Ap,
                self.b_p: Bp,
                self.cat: p_label,
                # a_u/b_u are fed the paired batch only to satisfy the
                # placeholders.
                self.a_u: Ap,
                self.b_u: Bp,
                self.pl_total: pl_lab
            }
            d_loss = self.sess.run(self.d_loss_total, feed_dict=feed_dict)
            g_loss, gen_a, gen_b = self.sess.run(
                [self.g_loss_total, self.fake_ap, self.fake_bp],
                feed_dict=feed_dict)
            # Drop the trailing singleton channel axis before image metrics.
            real_a = np.squeeze(Ap, axis=-1)
            real_b = np.squeeze(Bp, axis=-1)
            gen_a = np.squeeze(gen_a, axis=-1)
            gen_b = np.squeeze(gen_b, axis=-1)
            ps = 0.
            ss = 0.
            n_pairs = 0
            for ap, ga, bp, gb in zip(real_a, gen_a, real_b, gen_b):
                ap, bp, ga, gb = (ops.normalize(ap), ops.normalize(bp),
                                  ops.normalize(ga), ops.normalize(gb))
                ps += (ops.psnr(ap, ga) + ops.psnr(bp, gb)) / 2
                ss += (ops.ssim(ap, ga) + ops.ssim(bp, gb)) / 2
                n_pairs += 1
            # Average over the samples actually scored rather than assuming
            # the reader returned exactly conf.batch of them.
            n_pairs = max(n_pairs, 1)
            acc, f1, RI, purity = self.calculate_cost(
                self.sess, train_p2, valid_p2)
            print(
                'proposed_model valid d loss %0.3f, g loss %0.3f, PSNR %0.3f, SSIM %0.3f, acc %0.4f, f1 %0.4f, RI %0.4f, purity %0.4f'
                % (d_loss, g_loss, ps / n_pairs, ss / n_pairs, acc, f1, RI,
                   purity))
            if best_acc < acc:
                best_acc = acc
                best_f1, best_purity, best_RI = f1, purity, RI
                self.save(epoch_num)
                epochs_no_performance_gain = 0
            else:
                epochs_no_performance_gain += 1
            if epochs_no_performance_gain > 5:
                print('stop since no improvement for %d epochs' %
                      epochs_no_performance_gain)
                print('acc, f1, purity, RI [%.4f, %.4f, %.4f, %.4f]' %
                      (best_acc, best_f1, best_purity, best_RI))
                break
        epoch_num += 1
def test(self):
    """Evaluate classification accuracy on the test set.

    With ``self.conf.test_all`` set: every checkpoint from
    ``1 + save_interval`` up to the latest one in ``self.conf.modeldir`` is
    evaluated, in steps of ``save_interval``.  One ``<step>, \\t <accuracy>``
    line per checkpoint is printed and written to
    ``./test/<model_name>.csv``.

    Otherwise only checkpoint ``self.conf.test_step`` is evaluated, with
    results under ``./test/<model_name>/``:

    * every misclassified input is de-normalised and saved as ``<n>.png``;
    * a 10x10 table of misclassification counts (row = label, column =
      prediction) is printed;
    * the table is dumped as JSON to ``<count>.json``;
    * the list of wrong predictions is dumped to ``<count>_wrong_preds.json``.

    Only full batches are evaluated; reading stops at the first batch that is
    ``None`` or smaller than ``self.conf.batch``.  An accuracy over zero
    batches is reported as nan.
    """

    def full_batches(reader):
        # Yield (inputs, labels) per full batch; stop at the first short/None one.
        while True:
            inputs, labels = reader.next_batch(self.conf.batch)
            if inputs is None or inputs.shape[0] < self.conf.batch:
                return
            yield inputs, labels

    def mean_accuracy(values):
        return sum(values) / len(values) if values else float('nan')

    if self.conf.test_all:
        model_name = os.path.basename(self.conf.modeldir)
        if not os.path.exists('./test'):
            os.makedirs('./test')
        with open(os.path.join('./test', model_name + '.csv'), "w+") as f:
            print('testing all')
            latest_checkpoint_path = tf.train.latest_checkpoint(
                self.conf.modeldir)
            if latest_checkpoint_path is None:
                print('no checkpoint found in', self.conf.modeldir)
                return
            # The step is the text after the LAST '-'; a bare rsplit('-')[1]
            # broke on paths containing other dashes.
            latest_checkpoint = int(latest_checkpoint_path.rsplit('-', 1)[1])
            checkpoint = 1 + self.conf.save_interval
            while checkpoint <= latest_checkpoint:
                self.reload(checkpoint)
                test_reader = H5DataLoader(
                    self.conf.data_dir + self.conf.test_data, False)
                accuracies = [
                    self.sess.run(self.accuracy_op,
                                  feed_dict={self.inputs: inputs,
                                             self.labels: labels})
                    for inputs, labels in full_batches(test_reader)
                ]
                line = '%d, \t %f' % (checkpoint, mean_accuracy(accuracies))
                # One record per line (previously no newline was written).
                f.write(line + '\n')
                print(line)
                checkpoint += self.conf.save_interval
        return

    print('---->testing ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    test_reader = H5DataLoader(self.conf.data_dir + self.conf.test_data, False)
    accuracies = []
    model_name = os.path.basename(self.conf.modeldir)
    test_dir = './test/' + model_name
    if not os.path.exists(test_dir):
        os.makedirs(test_dir)
    index = 0
    wrong_preds = []
    debug_stat = [[0] * 10 for _ in range(10)]
    # Per-channel (std, mean) used to undo the input normalisation before
    # saving an image.
    channel_stats = ((0.24703233, 0.49139968),
                     (0.24348505, 0.48215827),
                     (0.26158768, 0.44653118))
    for inputs, labels in full_batches(test_reader):
        feed_dict = {self.inputs: inputs, self.labels: labels}
        accur, preds = self.sess.run(
            [self.accuracy_op, self.decoded_preds], feed_dict=feed_dict)
        for i, (pred, label) in enumerate(zip(preds, labels)):
            if pred == label:
                continue
            img = inputs[i]
            for channel, (std, mean) in enumerate(channel_stats):
                img[:, :, channel] = (img[:, :, channel] * std + mean) * 255
            img = np.array(np.uint8(img))
            Image.fromarray(img).save(
                os.path.join(test_dir, str(index) + '.png'))
            debug_stat[int(label)][int(pred)] += 1
            wrong_preds.append(
                {index: {'label': int(label), 'pred': int(pred)}})
            index += 1
        accuracies.append(accur)
    label_list = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck'
    ]
    total_wrong = sum(sum(row) for row in debug_stat)
    print('\t', end='')
    for item in label_list:
        print(item + '\t', end=' ')
    print("total \t percentage")
    for i in range(10):
        print(label_list[i], end='\t')
        for j in range(10):
            print(debug_stat[i][j], end='\t')
        row_total = sum(debug_stat[i])
        # Guard against division by zero when nothing was misclassified.
        share = row_total / total_wrong if total_wrong else 0.0
        print("%d \t %f" % (row_total, share))
    # Separate files: both dumps previously targeted <index>.json, so the
    # wrong-prediction list was silently overwritten.  The count table keeps
    # the original file name.
    with open(os.path.join(test_dir, str(index) + '_wrong_preds.json'),
              'w+') as fp:
        json.dump(wrong_preds, fp, sort_keys=True, indent=4)
    with open(os.path.join(test_dir, str(index) + '.json'), 'w+') as fp:
        json.dump(debug_stat, fp, sort_keys=True, indent=4)
    print('step: %d, accuracy %f' %
          (self.conf.test_step, mean_accuracy(accuracies)))
def train(self):
    """Train with flip/crop augmentation and a scheduled learning rate.

    Resumes from ``tf.train.global_step(self.sess, self.global_step)`` and
    runs to ``self.conf.max_step`` inclusive.

    Augmentation: each training image (H x W x C) is mirrored left-right
    with probability 1/2.  It is then zero-padded by 4 pixels on each side
    of H and W and randomly cropped back to H x W.

    The learning rate fed each step is ``self.learning_rate_schedule(step)``.
    On non-zero steps:

    * a multiple of ``test_interval`` first evaluates one validation batch;
    * a multiple of ``summary_interval`` also writes a training summary, and
      any other step (including step 0) prints the training loss;
    * a multiple of ``save_interval`` saves a checkpoint.
    """
    pad = 4

    def random_flip(image):
        # Return the image or its left-right mirror with equal probability.
        return random.choice([image, np.fliplr(image)])

    def random_crop(image):
        # Zero-pad H and W by ``pad`` and crop back to the original size.
        # The crop size follows the input (it was hard-coded to 32).
        height, width = image.shape[0], image.shape[1]
        padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)),
                        mode='constant', constant_values=0)
        start_h = random.randint(0, 2 * pad)
        start_w = random.randint(0, 2 * pad)
        return padded[start_h:start_h + height, start_w:start_w + width]

    def augment(inputs):
        # Flip every image, then crop every image (same draw order as before).
        flipped = np.array(list(map(random_flip, inputs)))
        return np.array(list(map(random_crop, flipped)))

    train_reader = H5DataLoader(self.conf.data_dir + self.conf.train_data)
    valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data)
    first_step = tf.train.global_step(self.sess, self.global_step)
    for epoch_num in range(first_step, self.conf.max_step + 1):
        if epoch_num and epoch_num % self.conf.test_interval == 0:
            inputs, labels = valid_reader.next_batch(self.conf.batch)
            feed_dict = {
                self.inputs: inputs,
                self.labels: labels,
                self.learning_rate_placeholder:
                self.learning_rate_schedule(epoch_num)
            }
            loss, summary = self.sess.run(
                [self.loss_op, self.valid_summary], feed_dict=feed_dict)
            self.save_summary(summary, epoch_num)
            # This reports the validation loss (it was mislabelled "training").
            print('global step: %d; validation loss %f' % (epoch_num, loss))
        write_summary = bool(
            epoch_num and epoch_num % self.conf.summary_interval == 0)
        inputs, labels = train_reader.next_batch(self.conf.batch)
        inputs = augment(inputs)
        feed_dict = {
            self.inputs: inputs,
            self.labels: labels,
            self.learning_rate_placeholder:
            self.learning_rate_schedule(epoch_num)
        }
        if write_summary:
            loss, _, summary = self.sess.run(
                [self.loss_op, self.train_op, self.train_summary],
                feed_dict=feed_dict)
            self.save_summary(summary, epoch_num)
        else:
            loss, _ = self.sess.run([self.loss_op, self.train_op],
                                    feed_dict=feed_dict)
            print('global step: %d; training loss %f' % (epoch_num, loss))
        if epoch_num and epoch_num % self.conf.save_interval == 0:
            self.save(epoch_num)