def infer(self, image_dir, batch_size, ckpt, output_folder):
    """
    Uses a trained model file to get predictions on the specified images.

    Every ``.png``/``.jpg`` file in ``image_dir`` is fed through ``self.model``
    and the arg-max class map, multiplied by 255, is written to
    ``output_folder`` as ``<n>.png``, where ``n`` is a running index over all
    processed images (so later batches do not overwrite earlier ones).

    Args:
        image_dir: Directory where the images are located.
        batch_size: Batch size to use while inferring (relevant if batch norm
            is used, since the graph is run with ``training=True``).
        ckpt: Name of the checkpoint file to use.
        output_folder: Folder where the predictions on the images should be
            saved; created if it does not exist.
    """
    image_paths = sorted(
        os.path.join(image_dir, name)
        for name in os.listdir(image_dir)
        if name.endswith(('.png', '.jpg'))
    )
    infer_data, infer_queue_init = utility.data_batch(
        image_paths, None, batch_size)

    image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
    training = tf.placeholder(tf.bool, shape=[])
    # `not tensor` raises TypeError in graph mode, so compare with None.
    # NOTE(review): a previously cached self.logits is wired to a different
    # placeholder than image_ph; confirm infer() runs on a fresh graph/object.
    if self.logits is None:
        self.logits = self.model(image_ph, training)
    # (batch, H, W) class map. No tf.squeeze: with a batch of one it would
    # drop the batch axis and the per-image loop would iterate over rows.
    mask = tf.argmax(self.logits, axis=3)

    # cv2.imwrite returns False instead of raising when the folder is missing.
    os.makedirs(output_folder, exist_ok=True)
    # Ceil division so a trailing partial batch is processed when the input
    # pipeline yields one; OutOfRangeError covers pipelines that drop it.
    num_batches = -(-len(image_paths) // batch_size)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, ckpt)
        sess.run(infer_queue_init)
        image_index = 0
        for _ in range(num_batches):
            try:
                image = sess.run(infer_data)
            except tf.errors.OutOfRangeError:
                break
            # training=True: batch norm uses batch statistics (see batch_size).
            feed_dict = {image_ph: image, training: True}
            prediction = sess.run(mask, feed_dict)
            for class_map in prediction:
                cv2.imwrite(
                    os.path.join(output_folder, '{}.png'.format(image_index)),
                    255 * class_map)
                image_index += 1
def infer(self, image_paths, batch_size,
          ckpt='trained_tiramisu/model.ckpt-18', output_folder='predictions'):
    """
    Predict masks for ``image_paths`` with a DenseTiramisu(16, [2, 3, 3], 2).

    Every image is processed (previously only the first batch was). Each
    arg-max class map, multiplied by 255, is written to ``output_folder`` as
    ``<n>.png``, where ``n`` is a running index over all images.

    Args:
        image_paths: Paths of the images to segment.
        batch_size: Batch size used by the input pipeline.
        ckpt: Checkpoint to restore; defaults to the previously hard-coded path.
        output_folder: Output directory; defaults to the previously hard-coded
            ``predictions`` and is created if missing.
    """
    import os

    infer_data, infer_queue_init = utility.data_batch(
        image_paths, None, batch_size)
    image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
    training = tf.placeholder(tf.bool, shape=[])
    tiramisu = DenseTiramisu(16, [2, 3, 3], 2)
    logits = tiramisu.model(image_ph, training)
    # (batch, H, W); no tf.squeeze so a batch of one keeps its batch axis.
    mask = tf.argmax(logits, axis=3)

    os.makedirs(output_folder, exist_ok=True)
    num_batches = -(-len(image_paths) // batch_size)  # ceil division
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, ckpt)
        sess.run(infer_queue_init)
        image_index = 0
        for _ in range(num_batches):
            try:
                image = sess.run(infer_data)
            except tf.errors.OutOfRangeError:
                break  # the pipeline dropped the final partial batch
            prediction = sess.run(mask, {image_ph: image, training: True})
            for class_map in prediction:
                cv2.imwrite(
                    os.path.join(output_folder, str(image_index) + '.png'),
                    255 * class_map)
                image_index += 1
def train(train_image_paths, train_mask_paths, val_image_path, val_mask_path, lr=1e-4):
    """
    Train ``network`` with Adam, validating and checkpointing once per epoch.

    Relies on the module-level settings BATCH_SIZE, H, W, NUM_CLASSES,
    FINETUNE, MODEL, EPOCHS and TRAIN_ITERATIONS.

    Args:
        train_image_paths: Training image file paths.
        train_mask_paths: Training mask file paths, parallel to the images.
        val_image_path: Validation image file paths.
        val_mask_path: Validation mask file paths.
        lr: Adam learning rate.
    """
    tf.reset_default_graph()
    with tf.Graph().as_default():
        data, init_op = utility.data_batch(
            train_image_paths, train_mask_paths, augment=True, batch_size=BATCH_SIZE)
        val_data, val_init_op = utility.data_batch(
            val_image_path, val_mask_path, augment=False, batch_size=3)
        image_tensor, mask_tensor = data
        val_image_tensor, val_mask_tensor = val_data

        image_placeholder = tf.placeholder(tf.float32, shape=[None, H, W, 3])
        mask_placeholder = tf.placeholder(tf.int32, shape=[None, H, W, 1])
        training_flag = tf.placeholder(tf.bool)

        logits = network(image_placeholder, training_flag)
        cost = xentropy_loss(mask_placeholder, logits)

        optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            # Named train_op so it does not shadow this function.
            train_op = optimizer.minimize(cost)

        iou_metric, iou_update = calculate_iou(
            mask_placeholder, logits, "iou_metric", BATCH_SIZE, NUM_CLASSES)
        running_vars = tf.get_collection(
            tf.GraphKeys.LOCAL_VARIABLES, scope="iou_metric")
        running_vars_init = tf.variables_initializer(var_list=running_vars)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            print('Training with learning rate: ', lr)
            if FINETUNE:
                saver.restore(sess, tf.train.latest_checkpoint(MODEL))
            else:
                sess.run(tf.global_variables_initializer())

            for epoch in range(1, EPOCHS + 1):
                # Initialise the iterators once per epoch. Re-running them on
                # every step restarted the input pipeline, so training kept
                # seeing the first batch.
                sess.run([init_op, val_init_op])
                total_loss, total_iou, steps_done = 0, 0, 0
                try:
                    for step in range(TRAIN_ITERATIONS):
                        print('Step:', step)
                        sess.run(running_vars_init)  # per-batch IoU
                        train_image, train_mask = sess.run(
                            [image_tensor, mask_tensor])
                        train_feed_dict = {
                            image_placeholder: train_image,
                            mask_placeholder: train_mask,
                            training_flag: True
                        }
                        _, loss, _ = sess.run(
                            [train_op, cost, iou_update], feed_dict=train_feed_dict)
                        total_loss += loss
                        total_iou += sess.run(iou_metric)
                        steps_done += 1
                except tf.errors.OutOfRangeError:
                    pass  # training data exhausted before TRAIN_ITERATIONS

                # Validation on one batch. This runs only after a normal or
                # out-of-range end of the epoch, not on unrelated errors.
                sess.run(running_vars_init)
                val_image, val_mask = sess.run(
                    [val_image_tensor, val_mask_tensor])
                val_feed_dict = {
                    image_placeholder: val_image,
                    mask_placeholder: val_mask,
                    training_flag: False
                }
                val_loss, _ = sess.run(
                    [cost, iou_update], feed_dict=val_feed_dict)
                val_iou = sess.run(iou_metric)

                # Average over the steps actually run.
                denom = max(steps_done, 1)
                print('Train loss: ', total_loss / denom,
                      'Train iou: ', total_iou / denom)
                print('Val. loss: ', val_loss, 'Val. iou: ', val_iou)
                saver.save(sess, MODEL + 'model.ckpt', global_step=epoch)
                print('Starting epoch: ', epoch)
def infer(self, image_dir, batch_size, ckpt, output_folder):
    """
    Uses a trained model file to get predictions on the specified images.

    Each predicted mask goes through these steps:

    1. Upsample to 512x512 and fill holes.
    2. Clear columns 0-255.
    3. Drop connected components smaller than 50 px.
    4. Restrict to pixels whose 5x5 local mean lies within 1.5 standard
       deviations of that local mean over the predicted region.

    The result is written to ``output_folder`` under the source image's file
    name.

    Args:
        image_dir: Directory where the images are located.
        batch_size: Batch size to use while inferring (relevant if batch norm
            is used).
        ckpt: Name of the checkpoint file to use.
        output_folder: Folder where the predictions on the images should be
            saved.
    """
    print(image_dir)
    image_paths = [
        os.path.join(image_dir, x) for x in os.listdir(image_dir)
        if x.endswith(('.png', '.jpg'))
    ]
    # shuffle=False keeps batch order aligned with image_paths for naming.
    infer_data, infer_queue_init = utility.data_batch(
        image_paths, None, batch_size, shuffle=False)

    # The placeholder and mask op are always built. The old self.restored
    # gating left both as None (and crashed) whenever it was truthy, and a
    # fresh Session always needs a restore anyway.
    image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 1])
    if self.logits is None:
        self.logits = network.deeplab_v3(
            image_ph, is_training=True, reuse=tf.AUTO_REUSE)
    # (batch, H, W); no squeeze so batch_size > 1 keeps its batch axis.
    mask = tf.argmax(self.logits, axis=3)

    def _refine_mask(pred, img):
        """Post-process one (256, 256) prediction against its input image."""
        pred = binary_fill_holes(
            resize((255 * pred).astype(np.uint8), (512, 512))) * 255
        pred[:, 0:256] = 0
        # Boolean input makes remove_small_objects drop each small connected
        # component. On the 0/255 int array it saw a single label instead.
        pred = remove_small_objects(
            pred.astype(bool), min_size=50, connectivity=1) * 255
        if not pred.any():
            return pred  # nothing segmented: stats would be nan
        img = resize((255 * img).astype(np.uint8), (512, 512)) * 255
        mean5_image = convolve2d(img, np.ones((5, 5)) / 25,
                                 mode='same', boundary='symm')
        region = pred == 255
        print(np.std(img[region]), np.mean(img[region]))
        mean5_segmean = np.mean(mean5_image[region])
        mean5_segstd = np.std(mean5_image[region])
        refined = pred.copy()
        refined[mean5_image > mean5_segmean + 1.5 * mean5_segstd] = 0
        refined[mean5_image < mean5_segmean - 1.5 * mean5_segstd] = 0
        return refined

    num_batches = -(-len(image_paths) // batch_size)  # ceil division
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, ckpt)
        sess.run(infer_queue_init)
        for batch_idx in range(num_batches):
            try:
                image = sess.run(infer_data)
            except tf.errors.OutOfRangeError:
                break
            prediction = sess.run(mask, {image_ph: image})
            # Name each output after its own source image (the batch index
            # alone was wrong whenever batch_size > 1).
            for offset in range(prediction.shape[0]):
                source = image_paths[batch_idx * batch_size + offset]
                out_path = os.path.join(output_folder, os.path.basename(source))
                print(out_path)
                cv2.imwrite(out_path,
                            _refine_mask(prediction[offset], image[offset, :, :, 0]))
def train(self, train_path, val_path, save_dir, batch_size, epochs,
          learning_rate):
    """
    Trains DeepLab v3 on the specified training data and validates on the
    validation data after every epoch.

    Masks are read from the ``masks_spleen`` sub-folder of each data path.
    Training resumes from the latest checkpoint for ``save_dir`` when one
    exists; otherwise it starts from freshly initialised weights.

    Args:
        train_path: Directory where the training data is present.
        val_path: Directory where the validation data is present.
        save_dir: Checkpoint path prefix; summaries go to its parent directory.
        batch_size: Batch size to use for training.
        epochs: Number of epochs (complete passes over one dataset) to train for.
        learning_rate: Learning rate for the optimizer.
    Returns:
        None
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    train_image_path = os.path.join(train_path, 'images')
    train_mask_path = os.path.join(train_path, 'masks_spleen')
    val_image_path = os.path.join(val_path, 'images')
    val_mask_path = os.path.join(val_path, 'masks_spleen')

    assert os.path.exists(train_image_path), "No training image folder found"
    assert os.path.exists(train_mask_path), "No training mask folder found"
    assert os.path.exists(val_image_path), "No validation image folder found"
    assert os.path.exists(val_mask_path), "No validation mask folder found"

    train_image_paths, train_mask_paths = get_data_paths_list(
        train_image_path, train_mask_path)
    val_image_paths, val_mask_paths = get_data_paths_list(
        val_image_path, val_mask_path)

    assert len(train_image_paths) == len(
        train_mask_paths), "Number of images and masks dont match in train folder"
    assert len(val_image_paths) == len(
        val_mask_paths), "Number of images and masks dont match in validation folder"

    self.num_train_images = len(train_image_paths)
    self.num_val_images = len(val_image_paths)

    print("Loading Data")
    train_data, train_queue_init = utility.data_batch(
        train_image_paths, train_mask_paths, batch_size)
    train_image_tensor, train_mask_tensor = train_data
    eval_data, eval_queue_init = utility.data_batch(
        val_image_paths, val_mask_paths, batch_size)
    eval_image_tensor, eval_mask_tensor = eval_data
    print("Loading Data Finished")

    image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 1])
    mask_ph = tf.placeholder(tf.int32, shape=[None, 256, 256, 1])
    # Fed below for parity with the other models; deeplab_v3 is built with
    # is_training=True and does not take this placeholder.
    training = tf.placeholder(tf.bool, shape=[])
    if self.logits is None:
        self.logits = network.deeplab_v3(image_ph, is_training=True, reuse=False)

    loss = tf.reduce_mean(self.xentropy_loss(self.logits, mask_ph))
    with tf.variable_scope("mean_iou_train"):
        iou, iou_update = self.calculate_iou(mask_ph, self.logits)
    # merge_all() returns None when the graph defines no summaries.
    merged = tf.summary.merge_all()

    optimizer = tf.train.AdamOptimizer(learning_rate)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt = optimizer.minimize(loss)

    running_vars = tf.get_collection(
        tf.GraphKeys.LOCAL_VARIABLES, scope="mean_iou_train")
    reset_iou = tf.variables_initializer(var_list=running_vars)
    saver = tf.train.Saver(max_to_keep=30)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # NOTE(review): 13x the image count per epoch (minus one step) is kept from
    # the original; presumably data_batch yields ~13 samples per image — confirm.
    train_steps = (13 * self.num_train_images) // batch_size - 1
    val_steps = self.num_val_images // batch_size
    train_fetches = [loss, opt, iou_update]
    if merged is not None:
        train_fetches.append(merged)

    print(self.num_train_images)
    with tf.Session(config=config) as sess:
        print("Initializing Variables")
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        print("Initializing Variables Finished")
        # saver.save(sess, save_dir) records its checkpoint state file in
        # dirname(save_dir), so look there too. Only restore when something
        # was found; restoring None used to crash a from-scratch run.
        ckpt = (tf.train.latest_checkpoint(save_dir)
                or tf.train.latest_checkpoint(os.path.dirname(save_dir) or '.'))
        if ckpt:
            saver.restore(sess, ckpt)
            print("Checkpoint restored")

        # One writer for the whole run (was re-created every epoch, never closed).
        writer = tf.summary.FileWriter(os.path.dirname(save_dir), sess.graph)
        try:
            for epoch in range(epochs):
                print("Epoch: ", epoch)
                print("Epoch queue init start")
                sess.run([train_queue_init, eval_queue_init])
                print("Epoch queue init ends")
                total_train_cost, total_val_cost = 0, 0
                total_train_iou, total_val_iou = 0, 0

                for train_step in range(train_steps):
                    image_batch, mask_batch, _ = sess.run(
                        [train_image_tensor, train_mask_tensor, reset_iou])
                    feed_dict = {
                        image_ph: image_batch,
                        mask_ph: mask_batch,
                        training: True
                    }
                    results = sess.run(train_fetches, feed_dict=feed_dict)
                    cost = results[0]
                    train_iou = sess.run(iou, feed_dict=feed_dict)
                    total_train_cost += cost
                    total_train_iou += train_iou
                    if merged is not None:
                        # Monotonic step so epochs don't overlap in TensorBoard.
                        writer.add_summary(results[3],
                                           epoch * train_steps + train_step)
                    if train_step % 50 == 0:
                        print("Step: ", train_step, "Cost: ", cost, "IoU:", train_iou)

                for _ in range(val_steps):
                    image_batch, mask_batch, _ = sess.run(
                        [eval_image_tensor, eval_mask_tensor, reset_iou])
                    feed_dict = {
                        image_ph: image_batch,
                        mask_ph: mask_batch,
                        training: True
                    }
                    eval_cost, _ = sess.run([loss, iou_update], feed_dict=feed_dict)
                    eval_iou = sess.run(iou, feed_dict=feed_dict)
                    total_val_cost += eval_cost
                    total_val_iou += eval_iou

                # Divide by the number of steps run. The last loop index was
                # off by one and raised when a loop ran 0 or 1 times.
                n_train = max(train_steps, 1)
                n_val = max(val_steps, 1)
                print("Epoch: {0}, training loss: {1}, validation loss: {2}".
                      format(epoch, total_train_cost / n_train,
                             total_val_cost / n_val))
                print("Epoch: {0}, training iou: {1}, val iou: {2}".format(
                    epoch, total_train_iou / n_train, total_val_iou / n_val))
                print("Saving model...")
                saver.save(sess, save_dir, global_step=epoch)
        finally:
            writer.close()
def train_eval(self, batch_size, growth_k, layers_per_block, epochs, learning_rate=1e-3):
    """Trains the model on the dataset, and does periodic validations.

    Builds a DenseTiramisu(growth_k, layers_per_block, self.num_classes), trains
    it with Adam, evaluates it after every epoch and saves a checkpoint to
    ``self.model_save_dir`` once per epoch.

    Args:
        batch_size: Batch size for training and evaluation.
        growth_k: Growth rate passed to DenseTiramisu.
        layers_per_block: Layers-per-block spec passed to DenseTiramisu.
        epochs: Number of epochs to train for.
        learning_rate: Adam learning rate.
    """
    train_data, train_queue_init = utility.data_batch(
        self.train_image_paths, self.train_mask_paths, batch_size)
    train_image_tensor, train_mask_tensor = train_data
    eval_data, eval_queue_init = utility.data_batch(
        self.eval_image_paths, self.eval_mask_paths, batch_size)
    eval_image_tensor, eval_mask_tensor = eval_data

    image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
    mask_ph = tf.placeholder(tf.int32, shape=[None, 256, 256, 1])
    training = tf.placeholder(tf.bool, shape=[])
    tiramisu = DenseTiramisu(growth_k, layers_per_block, self.num_classes)
    logits = tiramisu.model(image_ph, training)
    loss = tf.reduce_mean(tiramisu.xentropy_loss(logits, mask_ph))
    with tf.variable_scope("mean_iou_train"):
        iou, iou_update = tiramisu.calculate_iou(mask_ph, logits)

    optimizer = tf.train.AdamOptimizer(learning_rate)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt = optimizer.minimize(loss)

    running_vars = tf.get_collection(
        tf.GraphKeys.LOCAL_VARIABLES, scope="mean_iou_train")
    reset_iou = tf.variables_initializer(var_list=running_vars)
    saver = tf.train.Saver(max_to_keep=20)

    train_steps = self.num_train_images // batch_size
    eval_steps = self.num_eval_images // batch_size
    n_train = max(train_steps, 1)
    n_eval = max(eval_steps, 1)

    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        # Record the graph once. A new, never-closed writer used to be created
        # every epoch.
        tf.summary.FileWriter(self.model_save_dir, sess.graph).close()
        for epoch in range(epochs):
            sess.run([train_queue_init, eval_queue_init])
            total_train_cost, total_eval_cost = 0, 0
            total_train_iou, total_eval_iou = 0, 0
            for train_step in range(train_steps):
                image_batch, mask_batch, _ = sess.run(
                    [train_image_tensor, train_mask_tensor, reset_iou])
                feed_dict = {image_ph: image_batch, mask_ph: mask_batch, training: True}
                cost, _, _ = sess.run([loss, opt, iou_update], feed_dict=feed_dict)
                train_iou = sess.run(iou, feed_dict=feed_dict)
                total_train_cost += cost
                total_train_iou += train_iou
                if train_step % 50 == 0:
                    print("Step: ", train_step, "Cost: ", cost, "IoU:", train_iou)

            # NOTE(review): evaluation runs with training=True, as before;
            # confirm batch-norm/dropout should stay in training mode here.
            for _ in range(eval_steps):
                image_batch, mask_batch, _ = sess.run(
                    [eval_image_tensor, eval_mask_tensor, reset_iou])
                feed_dict = {image_ph: image_batch, mask_ph: mask_batch, training: True}
                eval_cost, _ = sess.run([loss, iou_update], feed_dict=feed_dict)
                eval_iou = sess.run(iou, feed_dict=feed_dict)
                total_eval_cost += eval_cost
                total_eval_iou += eval_iou

            # Both train and eval metrics are per-step averages. Train used to
            # be divided by the last index (off by one) and eval printed raw
            # totals.
            print("Epoch: ", epoch, "train loss: ", total_train_cost / n_train,
                  "eval loss: ", total_eval_cost / n_eval)
            print("Epoch: ", epoch, "train iou: ", total_train_iou / n_train,
                  "eval iou: ", total_eval_iou / n_eval)
            print("Saving model...")
            saver.save(sess, self.model_save_dir, global_step=epoch)
def train(self, train_path, val_path, save_dir, batch_size, epochs,
          learning_rate, prior_model=""):
    """
    Trains the Tiramisu on the specified training data and periodically
    validates on the validation data.

    Args:
        train_path: Directory where the training data is present.
        val_path: Directory where the validation data is present.
        save_dir: Checkpoint path prefix; the graph is logged to its parent
            directory.
        batch_size: Batch size to use for training.
        epochs: Number of epochs (complete passes over one dataset) to train for.
        learning_rate: Learning rate for the optimizer.
        prior_model: Optional checkpoint to warm-start from. An empty string
            (the default) or None trains from scratch.
    Returns:
        None
    """
    train_image_path = os.path.join(train_path, 'images')
    train_mask_path = os.path.join(train_path, 'masks')
    val_image_path = os.path.join(val_path, 'images')
    val_mask_path = os.path.join(val_path, 'masks')

    assert os.path.exists(train_image_path), "No training image folder found"
    assert os.path.exists(train_mask_path), "No training mask folder found"
    assert os.path.exists(val_image_path), "No validation image folder found"
    assert os.path.exists(val_mask_path), "No validation mask folder found"

    train_image_paths, train_mask_paths = get_data_paths_list(
        train_image_path, train_mask_path)
    val_image_paths, val_mask_paths = get_data_paths_list(
        val_image_path, val_mask_path)

    assert len(train_image_paths) == len(
        train_mask_paths), "Number of images and masks dont match in train folder"
    assert len(val_image_paths) == len(
        val_mask_paths), "Number of images and masks dont match in validation folder"

    self.num_train_images = len(train_image_paths)
    self.num_val_images = len(val_image_paths)

    train_data, train_queue_init = utility.data_batch(
        train_image_paths, train_mask_paths, batch_size)
    train_image_tensor, train_mask_tensor = train_data
    eval_data, eval_queue_init = utility.data_batch(
        val_image_paths, val_mask_paths, batch_size)
    eval_image_tensor, eval_mask_tensor = eval_data

    image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
    mask_ph = tf.placeholder(tf.int32, shape=[None, 256, 256, 1])
    training = tf.placeholder(tf.bool, shape=[])
    # `not tensor` raises TypeError in graph mode, so compare with None.
    if self.logits is None:
        self.logits = self.model(image_ph, training)
    loss = tf.reduce_mean(self.xentropy_loss(self.logits, mask_ph))
    with tf.variable_scope("mean_iou_train"):
        iou, iou_update = self.calculate_iou(mask_ph, self.logits)

    optimizer = tf.train.AdamOptimizer(learning_rate)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt = optimizer.minimize(loss)

    running_vars = tf.get_collection(
        tf.GraphKeys.LOCAL_VARIABLES, scope="mean_iou_train")
    reset_iou = tf.variables_initializer(var_list=running_vars)
    saver = tf.train.Saver(max_to_keep=20)

    train_steps = self.num_train_images // batch_size
    val_steps = self.num_val_images // batch_size
    n_train = max(train_steps, 1)
    n_val = max(val_steps, 1)

    with tf.Session() as sess:
        # Initialise first, then restore. Running the initializers after
        # saver.restore, as before, overwrote the prior model's weights.
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        if prior_model:
            saver.restore(sess, prior_model)
        # Record the graph once (a writer used to leak every epoch).
        tf.summary.FileWriter(os.path.dirname(save_dir), sess.graph).close()

        for epoch in range(epochs):
            sess.run([train_queue_init, eval_queue_init])
            total_train_cost, total_val_cost = 0, 0
            total_train_iou, total_val_iou = 0, 0
            for train_step in range(train_steps):
                image_batch, mask_batch, _ = sess.run(
                    [train_image_tensor, train_mask_tensor, reset_iou])
                feed_dict = {
                    image_ph: image_batch,
                    mask_ph: mask_batch,
                    training: True
                }
                cost, _, _ = sess.run([loss, opt, iou_update], feed_dict=feed_dict)
                train_iou = sess.run(iou, feed_dict=feed_dict)
                total_train_cost += cost
                total_train_iou += train_iou
                if train_step % 50 == 0:
                    print("Step: ", train_step, "Cost: ", cost, "IoU:", train_iou)

            # NOTE(review): validation runs with training=True, as before.
            for _ in range(val_steps):
                image_batch, mask_batch, _ = sess.run(
                    [eval_image_tensor, eval_mask_tensor, reset_iou])
                feed_dict = {
                    image_ph: image_batch,
                    mask_ph: mask_batch,
                    training: True
                }
                eval_cost, _ = sess.run([loss, iou_update], feed_dict=feed_dict)
                eval_iou = sess.run(iou, feed_dict=feed_dict)
                total_val_cost += eval_cost
                total_val_iou += eval_iou

            # Per-step averages; the last loop index was off by one.
            print("Epoch: {0}, training loss: {1}, validation loss: {2}".
                  format(epoch, total_train_cost / n_train, total_val_cost / n_val))
            print("Epoch: {0}, training iou: {1}, val iou: {2}".format(
                epoch, total_train_iou / n_train, total_val_iou / n_val))
            print("Saving model...")
            saver.save(sess, save_dir, global_step=epoch)
def train(self, train_path, val_path, save_dir, batch_size, epochs,
          learning_rate, learning_policy):
    """
    Trains the Tiramisu on the specified training data and periodically
    validates on the validation data. Per-epoch loss/IoU curves are plotted
    and saved to ``./result.jpg``.

    Args:
        train_path: Directory where the training data is present.
        val_path: Directory where the validation data is present.
        save_dir: Checkpoint path prefix; the graph is logged to its parent
            directory.
        batch_size: Batch size to use for training.
        epochs: Number of epochs (complete passes over one dataset) to train for.
        learning_rate: Initial learning rate for the SGD optimizer.
        learning_policy: One of:
            - ``'step'``: exponential decay by 0.9 per 10 epochs.
            - ``'poly'``: linear decay to 1e-4 over the first 10 epochs.
            - anything else: a constant learning rate.
    Returns:
        None
    """
    train_image_path = os.path.join(train_path, 'images')
    train_mask_path = os.path.join(train_path, 'masks')
    val_image_path = os.path.join(val_path, 'images')
    val_mask_path = os.path.join(val_path, 'masks')

    assert os.path.exists(train_image_path), "No training image folder found"
    assert os.path.exists(train_mask_path), "No training mask folder found"
    assert os.path.exists(val_image_path), "No validation image folder found"
    assert os.path.exists(val_mask_path), "No validation mask folder found"

    train_image_paths, train_mask_paths = get_data_paths_list(train_image_path, train_mask_path)
    val_image_paths, val_mask_paths = get_data_paths_list(val_image_path, val_mask_path)

    assert len(train_image_paths) == len(train_mask_paths), "Number of images and masks dont match in train folder"
    assert len(val_image_paths) == len(val_mask_paths), "Number of images and masks dont match in validation folder"

    self.num_train_images = len(train_image_paths)
    self.num_val_images = len(val_image_paths)

    train_data, train_queue_init = utility.data_batch(
        train_image_paths, train_mask_paths, batch_size, train_flag=True)
    train_image_tensor, train_mask_tensor = train_data
    # NOTE(review): validation data is also loaded with train_flag=True, as
    # before; confirm that is intended.
    eval_data, eval_queue_init = utility.data_batch(
        val_image_paths, val_mask_paths, batch_size, train_flag=True)
    eval_image_tensor, eval_mask_tensor = eval_data

    image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
    mask_ph = tf.placeholder(tf.int32, shape=[None, 256, 256, 1])
    training = tf.placeholder(tf.bool, shape=[])
    # Counts completed epochs and drives the learning-rate schedules below.
    # It was never incremented before, so the schedules never decayed.
    global_step = tf.Variable(0, trainable=False)
    next_epoch = tf.assign_add(global_step, 1)

    # `not tensor` raises TypeError in graph mode, so compare with None.
    if self.logits is None:
        self.logits = self.model(image_ph, training)

    regularizer = tf.contrib.layers.l2_regularizer(scale=0.001)
    loss = tf.reduce_mean(self.xentropy_loss(self.logits, mask_ph))
    # NOTE(review): this regularizer only attaches to variables created in the
    # IoU scope and is never added to `loss`, so it does not regularize the
    # model weights.
    with tf.variable_scope("mean_iou_train", regularizer=regularizer):
        iou, iou_update = self.calculate_iou(mask_ph, self.logits)

    initial_learning_rate = learning_rate
    if learning_policy == 'step':
        learning_rate = tf.train.exponential_decay(
            initial_learning_rate, global_step=global_step,
            decay_steps=10, decay_rate=0.9)
    elif learning_policy == 'poly':
        learning_rate = tf.train.polynomial_decay(
            initial_learning_rate, global_step=global_step,
            decay_steps=10, end_learning_rate=0.0001, power=1.0, cycle=False)
    else:
        learning_rate = initial_learning_rate

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt = optimizer.minimize(loss)

    running_vars = tf.get_collection(
        tf.GraphKeys.LOCAL_VARIABLES, scope="mean_iou_train")
    reset_iou = tf.variables_initializer(var_list=running_vars)
    saver = tf.train.Saver(max_to_keep=20)

    train_steps = self.num_train_images // batch_size
    val_steps = self.num_val_images // batch_size
    n_train = max(train_steps, 1)
    n_val = max(val_steps, 1)

    total_train_losses = []
    total_train_ious = []
    total_val_losses = []
    total_val_ious = []
    steps = []
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        # Record the graph once (a writer used to leak every epoch).
        tf.summary.FileWriter(os.path.dirname(save_dir), sess.graph).close()
        for epoch in range(epochs):
            sess.run([train_queue_init, eval_queue_init])
            total_train_cost, total_val_cost = 0, 0
            total_train_iou, total_val_iou = 0, 0
            for train_step in range(train_steps):
                image_batch, mask_batch, _ = sess.run(
                    [train_image_tensor, train_mask_tensor, reset_iou])
                feed_dict = {image_ph: image_batch, mask_ph: mask_batch, training: True}
                cost, _, _ = sess.run([loss, opt, iou_update], feed_dict=feed_dict)
                train_iou = sess.run(iou, feed_dict=feed_dict)
                total_train_cost += cost
                total_train_iou += train_iou
                if train_step % 50 == 0:
                    print("Step: ", train_step, "Cost: ", cost, "IoU:", train_iou)

            # NOTE(review): validation runs with training=True, as before.
            for _ in range(val_steps):
                image_batch, mask_batch, _ = sess.run(
                    [eval_image_tensor, eval_mask_tensor, reset_iou])
                feed_dict = {image_ph: image_batch, mask_ph: mask_batch, training: True}
                eval_cost, _ = sess.run([loss, iou_update], feed_dict=feed_dict)
                eval_iou = sess.run(iou, feed_dict=feed_dict)
                total_val_cost += eval_cost
                total_val_iou += eval_iou

            # Per-step averages; the last loop index was off by one.
            train_loss = total_train_cost / n_train
            val_loss = total_val_cost / n_val
            train_iou_avg = total_train_iou / n_train
            val_iou_avg = total_val_iou / n_val
            print("Epoch: {0}, training loss: {1}, validation loss: {2}".format(epoch, train_loss, val_loss))
            print("Epoch: {0}, training iou: {1}, val iou: {2}".format(epoch, train_iou_avg, val_iou_avg))
            total_train_losses.append(train_loss)
            total_val_losses.append(val_loss)
            total_train_ious.append(train_iou_avg)
            total_val_ious.append(val_iou_avg)
            steps.append(epoch)
            print("Saving model...")
            saver.save(sess, save_dir, global_step=epoch)
            sess.run(next_epoch)  # advance the learning-rate schedule

    plot(steps, total_train_losses, color='r', label='train_loss')
    plot(steps, total_val_losses, color='g', label='val_loss')
    plot(steps, total_train_ious, color='k', label='train_iou')
    plot(steps, total_val_ious, color='b', label='val_iou')
    xlabel('epoch')
    ylabel('value')
    title('Loss-Iou')
    legend(loc='best')
    savefig('./result.jpg')
def infer(self, image_dir, batch_size, ckpt, output_folder, organ):
    """
    Uses a trained model file to get predictions on the specified images.

    ``cvt_dcm_png(image_dir, "CT")`` runs first. Each network prediction is
    then post-processed in this order:

    1. Drop components under 450 px.
    2. Upsample to 512x512 and fill holes.
    3. For ``organ == "spleen"``, clear columns 0-249.
    4. Grow the region by intensity- and superpixel-based BFS, limited to
       ``visit_depth`` levels.
    5. Filter by 5x5 local mean.
    6. Open with a disk(5) structuring element.
    7. Drop components under 400 px.

    The result is written as a 0/255 uint8 image named after its source image.

    Args:
        image_dir: Directory where the images are located.
        batch_size: Batch size to use while inferring (relevant if batch norm
            is used).
        ckpt: Name of the checkpoint file to use.
        output_folder: Folder where the predictions on the images should be
            saved.
        organ: Target organ name; ``"spleen"`` additionally clears the left
            columns of the mask.
    """
    from collections import deque

    print(image_dir)
    cvt_dcm_png(image_dir, "CT")
    image_paths = [os.path.join(image_dir, x) for x in os.listdir(image_dir)
                   if x.endswith(('.png', '.jpg'))]
    # shuffle=False keeps batch order aligned with image_paths for naming.
    infer_data, infer_queue_init = utility.data_batch(
        image_paths, None, batch_size, shuffle=False)

    # Always built. The old self.restored gating left these as None (a crash)
    # whenever it was truthy, and a fresh Session always needs a restore.
    image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 1])
    if self.logits is None:
        self.logits = network.deeplab_v3(image_ph, is_training=True, reuse=tf.AUTO_REUSE)
    # (batch, H, W); no squeeze so batch_size > 1 keeps its batch axis.
    mask = tf.argmax(self.logits, axis=3)

    segment = 1800     # SLIC n_segments
    compact = 100      # SLIC compactness
    std = 1.2          # region-growing band, in std of raw region intensity
    std_rm = 1.85      # final filter band, in std of the 5x5-mean intensity
    visit_depth = 150  # maximum BFS depth of region growing

    def _grow_region(pred, img):
        """Return the post-processed 512x512 uint8 (0/255) mask for one image."""
        pred = remove_small_objects(pred.astype(bool), min_size=450, connectivity=2)
        pred = binary_fill_holes(resize(pred.astype(np.uint8) * 255, (512, 512))) * 255
        if organ == "spleen":
            pred[:, 0:250] = 0
        fg = pred == 255
        if not fg.any():
            return np.zeros(pred.shape, dtype=np.uint8)  # stats would be nan

        img = resize((255 * img).astype(np.uint8), (512, 512)) * 255
        mean5_image = convolve2d(img, np.ones((5, 5)) / 25, mode='same', boundary='symm')
        segstd = np.std(img[fg])
        segmean = np.mean(img[fg])
        mean5_segstd = np.std(mean5_image[fg])
        mean5_segmean = np.mean(mean5_image[fg])
        print(segstd, segmean)

        # Mean intensity of every SLIC superpixel, computed once instead of a
        # full-image mask-and-mean per visited pixel.
        image_slic = slic(img, n_segments=segment, compactness=compact, sigma=1)
        labels = image_slic.ravel()
        slic_means = (np.bincount(labels, weights=img.ravel())
                      / np.maximum(np.bincount(labels), 1))

        # visited: 255 = region, 1 = rejected, other non-zero = BFS depth of a
        # queued pixel, 0 = unseen.
        visited = pred.copy()
        queue = deque()
        # Seed with the unseen 4-neighbours of foreground pixels, scanning the
        # same window as before (rows 1..509, cols 1..510).
        for row, col in zip(*np.nonzero(fg[1:510, 1:511])):
            x, y = int(row) + 1, int(col) + 1
            for nx, ny in ((x + 1, y), (x, y + 1), (x - 1, y), (x, y - 1)):
                if visited[nx, ny] == 0:
                    visited[nx, ny] = 1
                    queue.append((nx, ny))

        lower = segmean - std * segstd
        upper = segmean + std * segstd
        while queue:
            x, y = queue.popleft()
            depth = visited[x, y]
            # Depth limit on the pixel just popped. BFS order makes depths
            # non-decreasing. The old check read the previous pixel and never
            # fired.
            if depth >= visit_depth:
                break
            if img[x, y] > upper or img[x, y] < lower or slic_means[image_slic[x, y]] < segmean:
                visited[x, y] = 1
                continue
            for nx, ny, inside in ((x + 1, y, x + 1 < 512), (x, y + 1, y + 1 < 512),
                                   (x - 1, y, x - 1 > 0), (x, y - 1, y - 1 > 0)):
                if inside and visited[nx, ny] == 0:
                    visited[nx, ny] = depth + 1
                    queue.append((nx, ny))
            visited[x, y] = 255

        # Only accepted pixels count. Rejected (1) and unprocessed pixels are
        # non-zero too and were previously treated as foreground by the opening.
        region = visited == 255
        region[mean5_image > mean5_segmean + std_rm * mean5_segstd] = False
        region[mean5_image < mean5_segmean - std_rm * mean5_segstd] = False
        region = binary_opening(region, disk(5))
        # Boolean input makes this per-component; 0/255 uint8 was one label.
        region = remove_small_objects(region, min_size=400, connectivity=1)
        return region.astype(np.uint8) * 255

    num_batches = -(-len(image_paths) // batch_size)  # ceil division
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, ckpt)
        sess.run(infer_queue_init)
        for batch_idx in range(num_batches):
            try:
                image = sess.run(infer_data)
            except tf.errors.OutOfRangeError:
                break
            prediction = sess.run(mask, {image_ph: image})
            # Name each output after its own source image (the batch index
            # alone was wrong whenever batch_size > 1).
            for offset in range(prediction.shape[0]):
                source = image_paths[batch_idx * batch_size + offset]
                out_path = os.path.join(output_folder, os.path.basename(source))
                print(out_path)
                cv2.imwrite(out_path, _grow_region(prediction[offset], image[offset, :, :, 0]))