def train(self):
    if self.conf.reload_step > 0:
        self.reload(self.conf.reload_step)
    if self.conf.data_type == '2D':
        train_reader = H5DataLoader(self.conf.data_dir + self.conf.train_data)
        valid_reader = H5DataLoader(self.conf.data_dir + self.conf.valid_data)
    else:
        train_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.train_data, self.input_shape)
        valid_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.valid_data, self.input_shape)
    for epoch_num in range(self.conf.max_step):
        if epoch_num % self.conf.test_interval == 0:
            inputs, annotations = valid_reader.next_batch(self.conf.batch)
            feed_dict = {self.inputs: inputs, self.annotations: annotations}
            loss, summary, accuracy, dice_accuracy = self.sess.run(
                [self.loss_op, self.valid_summary, self.accuracy_op,
                 self.dice_accuracy_op],
                feed_dict=feed_dict)
            self.save_summary(summary, epoch_num + self.conf.reload_step)
            print('----valid loss', loss)
            print('----valid accuracy', accuracy)
            print('----valid dice accuracy', dice_accuracy)
        elif epoch_num % self.conf.summary_interval == 0:
            inputs, annotations = train_reader.next_batch(self.conf.batch)
            feed_dict = {self.inputs: inputs, self.annotations: annotations}
            loss, _, summary = self.sess.run(
                [self.loss_op, self.train_op, self.train_summary],
                feed_dict=feed_dict)
            self.save_summary(summary, epoch_num + self.conf.reload_step)
        else:
            inputs, annotations = train_reader.next_batch(self.conf.batch)
            feed_dict = {self.inputs: inputs, self.annotations: annotations}
            loss, summary, _, accuracy, dice_accuracy = self.sess.run(
                [self.loss_op, self.train_summary, self.train_op,
                 self.accuracy_op, self.dice_accuracy_op],
                feed_dict=feed_dict)
            print('----train loss', loss)
            print('----train accuracy', accuracy)
            print('----train dice accuracy', dice_accuracy)
            self.save_summary(summary, epoch_num + self.conf.reload_step)
        if epoch_num % self.conf.save_interval == 0:
            self.save(epoch_num + self.conf.reload_step)
def predict(self, test_step, data_index, sub_batch_index, test_type):
    print('---->predicting ', test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    if test_type == 'valid':
        test_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.valid_data,
            self.input_shape, is_train=False)
    elif test_type == 'predict':
        test_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.test_data,
            self.input_shape, is_train=False)
    else:
        print("invalid type")
        return
    predict_generator = test_reader.generate_data(
        data_index, sub_batch_index,
        [self.conf.depth, self.conf.height, self.conf.width],
        [self.conf.d_gap, self.conf.w_gap, self.conf.h_gap])
    s_predictions, s_labels = self.predict_func(predict_generator)
    # process labels
    expand_annotations = tf.expand_dims(
        s_labels, -1, name='s_labels/expand_dims')
    one_hot_annotations = tf.squeeze(
        expand_annotations, axis=[self.channel_axis], name='s_labels/squeeze')
    one_hot_annotations = tf.one_hot(
        one_hot_annotations, depth=self.conf.class_num,
        axis=self.channel_axis, name='s_labels/one_hot')
    # process predictions
    decoded_predictions = tf.argmax(
        s_predictions, self.channel_axis, name='accuracy/decode_pred')
    correct_prediction = tf.equal(
        tf.cast(s_labels, tf.int64), decoded_predictions,
        name='accuracy/predict_correct_pred')
    # accuracy
    accuracy_op = tf.reduce_mean(
        tf.cast(correct_prediction, tf.float32, name='accuracy/cast'),
        name='accuracy/accuracy_op')
    # loss
    loss = tf.reduce_mean(
        tf.losses.softmax_cross_entropy(one_hot_annotations, s_predictions))
    return accuracy_op, loss
def test(self):
    sig = True
    print('---->testing ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    if self.conf.data_type == '2D':
        test_reader = H5DataLoader(
            self.conf.data_dir + self.conf.test_data, False)
    else:
        test_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.test_data, self.input_shape)
    self.sess.run(tf.local_variables_initializer())
    count = 0
    losses = []
    accuracies = []
    m_ious = []
    while sig:
        sig, inputs, labels = test_reader.next_batch(self.conf.batch)
        if inputs is None:
            break
        feed_dict = {self.inputs: inputs, self.labels: labels}
        loss, accuracy, m_iou, _ = self.sess.run(
            [self.loss_op, self.accuracy_op, self.m_iou, self.miou_op],
            feed_dict=feed_dict)
        print('values----->', loss, accuracy, m_iou)
        count += 1
        losses.append(loss)
        accuracies.append(accuracy)
        m_ious.append(m_iou)
    print('Loss: ', np.mean(losses))
    print('Accuracy: ', np.mean(accuracies))
    # m_iou is a streaming metric updated by miou_op, so the last value
    # reflects the mean IoU accumulated over all test batches.
    print('M_iou: ', m_ious[-1])
def test(self):
    print('---->testing ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    data_reader = H53DDataLoader(
        self.conf.data_dir, self.conf.patch_size,
        self.conf.validation_id, self.conf.overlap_stepsize)
    self.sess.run(tf.local_variables_initializer())
    losses = []
    accuracies = []
    for i in range(data_reader.num_of_valid_patches):
        inputs, annotations, _ = data_reader.valid_next_batch()
        feed_dict = {self.inputs: inputs, self.annotations: annotations}
        loss, accuracy = self.sess.run(
            [self.loss_op, self.accuracy_op], feed_dict=feed_dict)
        print('values----->', loss, accuracy)
        losses.append(loss)
        accuracies.append(accuracy)
    print('Loss: ', np.mean(losses))
    print('Accuracy: ', np.mean(accuracies))
def predict(self):
    print('---->predicting ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    data_reader = H53DDataLoader(
        self.conf.data_dir, self.conf.patch_size,
        self.conf.validation_id, self.conf.overlap_stepsize)
    self.sess.run(tf.local_variables_initializer())
    predictions = {}
    for i in range(data_reader.num_of_valid_patches):
        inputs, annotations, location = data_reader.valid_next_batch()
        feed_dict = {self.inputs: inputs, self.annotations: annotations}
        preds = self.sess.run(self.softmax_predictions, feed_dict=feed_dict)
        print('--->processing results for: ', location)
        for j in range(self.conf.patch_size):
            for k in range(self.conf.patch_size):
                for l in range(self.conf.patch_size):
                    key = (location[0] + j, location[1] + k, location[2] + l)
                    if key not in predictions:
                        predictions[key] = []
                    predictions[key].append(preds[0, j, k, l, :])
    print('--->averaging results')
    results = np.zeros(
        (data_reader.t_d, data_reader.t_h, data_reader.t_w, self.conf.class_num),
        dtype=np.float32)
    for key in predictions.keys():
        results[key[0], key[1], key[2]] = np.mean(predictions[key], axis=0)
    print('--->saving results')
    save_filename = ('results' + str(self.conf.test_step) +
                     '_sub' + str(self.conf.validation_id) +
                     '_overlap' + str(self.conf.overlap_stepsize) + '.npy')
    save_file = os.path.join(self.conf.savedir, save_filename)
    np.save(save_file, results)
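# The volume saved above stores, for every voxel, class probabilities averaged over
# all overlapping patches that covered it. A minimal usage sketch for turning that
# file into a hard label map; the filename below is illustrative and follows the
# 'results<test_step>_sub<validation_id>_overlap<overlap_stepsize>.npy' pattern
# produced by predict().
import numpy as np

averaged = np.load('results40000_sub1_overlap8.npy')          # shape: (t_d, t_h, t_w, class_num)
label_volume = np.argmax(averaged, axis=-1).astype(np.uint8)  # one predicted class per voxel
print(label_volume.shape, np.unique(label_volume))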
def predict(self):
    print('---->predicting ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    if self.conf.data_type == '2D':
        test_reader = H5DataLoader(
            self.conf.data_dir + self.conf.test_data, False)
    else:
        test_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.test_data, self.input_shape)
    predictions = []
    while True:
        inputs, labels = test_reader.next_batch(self.conf.batch)
        if inputs.shape[0] < self.conf.batch:
            break
        feed_dict = {self.inputs: inputs, self.labels: labels}
        predictions.append(self.sess.run(
            self.decoded_preds, feed_dict=feed_dict))
    print('----->saving predictions')
    for index, prediction in enumerate(predictions):
        for i in range(prediction.shape[0]):
            imsave(prediction[i],
                   self.conf.sampledir + str(index * prediction.shape[0] + i) + '.png')
def predict(self, test_step, data_index, sub_batch_index, test_type):
    print('---->predicting ', self.conf.test_step)
    if test_type == 'valid':
        test_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.valid_data,
            self.input_shape, is_train=False)
    elif test_type == 'predict':
        print("start loading predict data")
        test_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.test_data,
            self.input_shape, is_train=False)
        print("finish loading data")
    else:
        print("invalid type")
        return
    predict_generator = test_reader.generate_data(
        data_index, sub_batch_index,
        [self.conf.depth, self.conf.height, self.conf.width],
        [self.conf.d_gap, self.conf.w_gap, self.conf.h_gap])
    s_predictions, s_labels = self.predict_func(predict_generator)
    # process labels
    expand_annotations = tf.expand_dims(
        s_labels, -1, name='s_labels/expand_dims')
    one_hot_annotations = tf.squeeze(
        expand_annotations, axis=[self.channel_axis], name='s_labels/squeeze')
    one_hot_annotations = tf.one_hot(
        one_hot_annotations, depth=self.conf.class_num,
        axis=self.channel_axis, name='s_labels/one_hot')
    # process predictions
    decoded_predictions = tf.argmax(
        s_predictions, self.channel_axis, name='accuracy/decode_pred')
    correct_prediction = tf.equal(
        tf.cast(s_labels, tf.int64), decoded_predictions,
        name='accuracy/predict_correct_pred')
    # accuracy
    accuracy_op = tf.reduce_mean(
        tf.cast(correct_prediction, tf.float32, name='accuracy/cast'),
        name='accuracy/accuracy_op')
    # loss
    loss = tf.reduce_mean(
        tf.losses.softmax_cross_entropy(one_hot_annotations, s_predictions))
    # dice ratio
    DiceRatio, sublist = ops.dice_accuracy(
        decoded_predictions, s_labels, self.conf.class_num)
    # evaluate all metrics in a single session run
    accuracy, Loss, diceRatio, Sublist = self.sess.run(
        [accuracy_op, loss, DiceRatio, sublist])
    del DiceRatio, sublist
    gc.collect()
    return accuracy, Loss, diceRatio, Sublist
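# Hedged illustration, not the repository's ops.dice_accuracy: the dice ratio returned
# above is, per class c, 2 * |pred==c and label==c| / (|pred==c| + |label==c|) between
# the decoded predictions and the ground-truth labels. A minimal NumPy sketch, assuming
# both arguments are integer label volumes of the same shape:
import numpy as np

def per_class_dice(pred, label, class_num, eps=1e-7):
    """Return one Dice score per class; their mean gives an overall dice ratio."""
    scores = []
    for c in range(class_num):
        p = (pred == c)
        g = (label == c)
        intersection = np.logical_and(p, g).sum()
        scores.append((2.0 * intersection + eps) / (p.sum() + g.sum() + eps))
    return scores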
def train(self):
    if self.conf.reload_step > 0:
        self.reload(self.conf.reload_step)
    if self.conf.data_type == '2D':
        train_reader = H5DataLoader(
            self.conf.data_dir + self.conf.train_data)
        valid_reader = H5DataLoader(
            self.conf.data_dir + self.conf.valid_data)
    else:
        train_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.train_data, self.input_shape)
        valid_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.valid_data, self.input_shape)
    for epoch_num in range(self.conf.max_step + 1):
        if epoch_num and epoch_num % self.conf.test_interval == 0:
            inputs, labels = valid_reader.next_batch(self.conf.batch)
            feed_dict = {self.inputs: inputs, self.labels: labels}
            loss, summary = self.sess.run(
                [self.loss_op, self.valid_summary], feed_dict=feed_dict)
            self.save_summary(summary, epoch_num + self.conf.reload_step)
            print('----testing loss', loss)
        if epoch_num and epoch_num % self.conf.summary_interval == 0:
            inputs, labels = train_reader.next_batch(self.conf.batch)
            feed_dict = {self.inputs: inputs, self.labels: labels}
            loss, _, summary = self.sess.run(
                [self.loss_op, self.train_op, self.train_summary],
                feed_dict=feed_dict)
            self.save_summary(summary, epoch_num + self.conf.reload_step)
        else:
            inputs, labels = train_reader.next_batch(self.conf.batch)
            feed_dict = {self.inputs: inputs, self.labels: labels}
            loss, _ = self.sess.run(
                [self.loss_op, self.train_op], feed_dict=feed_dict)
            print('----training loss', loss)
        if epoch_num and epoch_num % self.conf.save_interval == 0:
            self.save(epoch_num + self.conf.reload_step)
def predict(self):
    print('---->predicting ', self.conf.test_step)
    if self.conf.test_step > 0:
        self.reload(self.conf.test_step)
    else:
        print("please set a reasonable test_step")
        return
    if self.conf.data_type == '2D':
        test_reader = H5DataLoader(
            self.conf.data_dir + self.conf.test_data, False)
    else:
        test_reader = H53DDataLoader(
            self.conf.data_dir + self.conf.test_data, self.input_shape)
    predictions = []
    sig = True
    final_inputs = []
    final_labels = []
    while sig:
        sig, inputs, labels = test_reader.next_batch(self.conf.batch)
        if inputs is None:
            break
        final_inputs.append(inputs)
        final_labels.append(labels)
        feed_dict = {self.inputs: inputs, self.labels: labels}
        predictions.append(
            self.sess.run(self.decoded_preds, feed_dict=feed_dict))
    final_inputs = np.array(final_inputs, 'uint8')
    final_labels = np.array(final_labels, 'uint8')
    print('----->saving predictions')
    for index, prediction in enumerate(predictions):
        for i in range(prediction.shape[0]):
            imsave(prediction[i],
                   self.conf.sampledir + str(index * prediction.shape[0] + i) + '.png')
            scipy.misc.imsave(
                self.conf.sampledir + str(index * prediction.shape[0] + i) + '.jpg',
                final_inputs[index][i])
            scipy.misc.imsave(
                self.conf.sampledir + str(index * prediction.shape[0] + i) + '_label.png',
                final_labels[index][i] * 45)
def train(self):
    if self.conf.reload_step > 0:
        self.reload(self.conf.reload_step)
    data_reader = H53DDataLoader(
        self.conf.data_dir, self.conf.patch_size, self.conf.validation_id,
        self.conf.overlap_stepsize, self.conf.aug_flip, self.conf.aug_rotate)
    for train_step in range(1, self.conf.max_step + 1):
        # NOTE: validation during training is disabled; use test() for
        # patch-based evaluation on the validation subject.
        if train_step % self.conf.summary_interval == 0:
            inputs, annotations = data_reader.next_batch(self.conf.batch)
            feed_dict = {self.inputs: inputs, self.annotations: annotations}
            loss, _, summary = self.sess.run(
                [self.loss_op, self.train_op, self.train_summary],
                feed_dict=feed_dict)
            self.save_summary(summary, train_step + self.conf.reload_step)
        else:
            inputs, annotations = data_reader.next_batch(self.conf.batch)
            feed_dict = {self.inputs: inputs, self.annotations: annotations}
            loss, _ = self.sess.run(
                [self.loss_op, self.train_op], feed_dict=feed_dict)
            print('----training loss', loss)
        if train_step % self.conf.save_interval == 0:
            self.save(train_step + self.conf.reload_step)