Ejemplo n.º 1
0
    def infer(self, image_dir, batch_size, ckpt, output_folder):
        """
        Uses a trained model file to get predictions on the specified images.

        Every full batch of images is run through the network and each
        predicted mask is written to `output_folder` as `<n>.png`, where `n`
        counts predictions across all batches (so later batches no longer
        overwrite the files of earlier ones).

        Args:
            image_dir: Directory where the images are located.
            batch_size: Batch size to use while inferring (relevant if batch norm is used)
            ckpt: Name of the checkpoint file to use.
            output_folder: Folder where the predictions on the images should be saved.
        """
        image_paths = [
            os.path.join(image_dir, x) for x in os.listdir(image_dir)
            if x.endswith(('.png', '.jpg'))
        ]
        infer_data, infer_queue_init = utility.data_batch(
            image_paths, None, batch_size)
        image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
        training = tf.placeholder(tf.bool, shape=[])

        # Compare against None explicitly: truth-testing a graph tensor is
        # rejected by TensorFlow 1.x once the logits have been cached.
        # NOTE(review): cached logits stay bound to the placeholder of the
        # call that built them -- confirm repeated calls are supported.
        if self.logits is None:
            self.logits = self.model(image_ph, training)

        # argmax over axis 3 already removes that axis; an additional squeeze
        # would also drop the batch axis whenever the batch holds one image.
        mask = tf.argmax(self.logits, axis=3)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, ckpt)
            sess.run(infer_queue_init)
            saved = 0  # running index used for the output file names
            # Only full batches are requested; a trailing partial batch is
            # not read.
            for _ in range(len(image_paths) // batch_size):
                image = sess.run(infer_data)
                # NOTE(review): the model is fed training=True during
                # inference; presumably intentional for batch norm -- confirm.
                feed_dict = {image_ph: image, training: True}
                prediction = sess.run(mask, feed_dict)
                for single_mask in prediction:
                    cv2.imwrite(
                        os.path.join(output_folder, '{}.png'.format(saved)),
                        255 * single_mask)
                    saved += 1
Ejemplo n.º 2
0
    def infer(self, image_paths, batch_size):
        """
        Predicts masks for the first batch of `image_paths` and saves them.

        A `DenseTiramisu(16, [2, 3, 3], 2)` model is built, its weights are
        restored from 'trained_tiramisu/model.ckpt-18', a single batch is
        pulled from the input pipeline and each predicted mask is written to
        'predictions/<index>.png'.

        Args:
            image_paths: Paths of the images to run inference on.
            batch_size: Number of images requested from the input pipeline.
        """
        infer_data, infer_queue_init = utility.data_batch(
            image_paths, None, batch_size)
        image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
        training = tf.placeholder(tf.bool, shape=[])
        tiramisu = DenseTiramisu(16, [2, 3, 3], 2)
        logits = tiramisu.model(image_ph, training)
        # argmax over axis 3 already removes that axis; an additional squeeze
        # would also drop the batch axis whenever the batch holds one image.
        mask = tf.argmax(logits, axis=3)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            saver.restore(sess, 'trained_tiramisu/model.ckpt-18')
            sess.run(infer_queue_init)
            # Only one batch is evaluated.
            image = sess.run(infer_data)
            # NOTE(review): the model is fed training=True during inference;
            # presumably intentional for batch norm -- confirm.
            feed_dict = {image_ph: image, training: True}
            prediction = sess.run(mask, feed_dict)
            for j, single_mask in enumerate(prediction):
                cv2.imwrite('predictions/' + str(j) + '.png',
                            255 * single_mask)
Ejemplo n.º 3
0
def train(train_image_paths,
          train_mask_paths,
          val_image_path,
          val_mask_path,
          lr=1e-4):
    """Train the segmentation network and validate on one batch per epoch.

    Builds a fresh graph, optimises `xentropy_loss` with Adam and, after each
    epoch, evaluates a single validation batch, prints the metrics and saves
    a checkpoint under `MODEL`. Relies on the module-level settings
    `BATCH_SIZE`, `H`, `W`, `NUM_CLASSES`, `FINETUNE`, `MODEL`, `EPOCHS` and
    `TRAIN_ITERATIONS`.

    Args:
        train_image_paths: Training image paths passed to `utility.data_batch`.
        train_mask_paths: Training mask paths passed to `utility.data_batch`.
        val_image_path: Validation image paths passed to `utility.data_batch`.
        val_mask_path: Validation mask paths passed to `utility.data_batch`.
        lr: Learning rate for the Adam optimizer.
    """
    tf.reset_default_graph()
    with tf.Graph().as_default():
        data, init_op = utility.data_batch(train_image_paths,
                                           train_mask_paths,
                                           augment=True,
                                           batch_size=BATCH_SIZE)

        val_data, val_init_op = utility.data_batch(val_image_path,
                                                   val_mask_path,
                                                   augment=False,
                                                   batch_size=3)

        image_tensor, mask_tensor = data
        val_image_tensor, val_mask_tensor = val_data

        image_placeholder = tf.placeholder(tf.float32, shape=[None, H, W, 3])
        mask_placeholder = tf.placeholder(tf.int32, shape=[None, H, W, 1])

        training_flag = tf.placeholder(tf.bool)

        logits = network(image_placeholder, training_flag)
        cost = xentropy_loss(mask_placeholder, logits)

        optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        # Ops registered in UPDATE_OPS must run together with the train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            # Named train_op so it does not shadow this function's own name.
            train_op = optimizer.minimize(cost)

        iou_metric, iou_update = calculate_iou(mask_placeholder, logits,
                                               "iou_metric", BATCH_SIZE,
                                               NUM_CLASSES)

        # Local variables of the IoU metric; re-initialising them resets it.
        running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                         scope="iou_metric")

        running_vars_init = tf.variables_initializer(var_list=running_vars)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            print('Training with learning rate: ', lr)
            if FINETUNE:
                saver.restore(sess, tf.train.latest_checkpoint(MODEL))
            else:
                sess.run(tf.global_variables_initializer())
            for epoch in range(1, EPOCHS + 1):
                total_loss, total_iou = 0, 0
                steps_run = 0
                try:
                    # Restart both input pipelines once per epoch.
                    # Re-initialising them on every step would rewind the
                    # data before each batch is fetched.
                    sess.run([init_op, val_init_op])
                    for step in range(0, TRAIN_ITERATIONS):
                        print('Step:', step)
                        sess.run(running_vars_init)
                        train_image, train_mask = sess.run(
                            [image_tensor, mask_tensor])
                        train_feed_dict = {
                            image_placeholder: train_image,
                            mask_placeholder: train_mask,
                            training_flag: True
                        }

                        _, loss, _ = sess.run(
                            [train_op, cost, iou_update],
                            feed_dict=train_feed_dict)

                        iou = sess.run(iou_metric)
                        total_loss += loss
                        total_iou += iou
                        steps_run += 1
                except tf.errors.OutOfRangeError:
                    # The training data ran out before TRAIN_ITERATIONS
                    # steps; the epoch simply ends early.
                    pass
                finally:
                    # Runs even when the training loop above raised, so the
                    # epoch is still validated and checkpointed.
                    sess.run(running_vars_init)
                    val_image, val_mask = sess.run(
                        [val_image_tensor, val_mask_tensor])
                    val_feed_dict = {
                        image_placeholder: val_image,
                        mask_placeholder: val_mask,
                        training_flag: False
                    }
                    _, val_loss = sess.run(
                        [iou_update, cost], feed_dict=val_feed_dict)
                    val_iou = sess.run(iou_metric)

                    # Average over the steps that actually ran.
                    completed = max(steps_run, 1)
                    print('Train loss: ', total_loss / completed,
                          'Train iou: ', total_iou / completed)
                    print('Val. loss: ', val_loss, 'Val. iou: ', val_iou)
                    saver.save(sess, MODEL + 'model.ckpt', global_step=epoch)
                    print('Starting epoch: ', epoch)
Ejemplo n.º 4
0
    def infer(self, image_dir, batch_size, ckpt, output_folder):
        """
        Uses a trained model file to get predictions on the specified images.

        For every batch the predicted mask is resized to (512, 512), passed
        through `binary_fill_holes`, has columns 0-255 cleared and is cleaned
        with `remove_small_objects`. Pixels whose 5x5 mean intensity lies
        more than 1.5 standard deviations from the mean of the segmented
        region are then removed, and the result is written to `output_folder`
        under the input image's file name.

        NOTE(review): the post-processing indexes the prediction as a 2-D
        array and maps batch index to `image_paths` index, so this appears to
        assume batch_size == 1 -- confirm against callers.

        Args:
            image_dir: Directory where the images are located.
            batch_size: Batch size to use while inferring (relevant if batch norm is used)
            ckpt: Name of the checkpoint file to use.
            output_folder: Folder where the predictions on the images should be saved.
        """
        print(image_dir)
        image_paths = [
            os.path.join(image_dir, x) for x in os.listdir(image_dir)
            if x.endswith(('.png', '.jpg'))
        ]
        infer_data, infer_queue_init = utility.data_batch(image_paths,
                                                          None,
                                                          batch_size,
                                                          shuffle=False)

        image_ph = None
        mask = None
        if not self.restored:
            image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 1])

        # Identity check: `== None` on a tensor relies on operator overloads.
        if self.logits is None:
            self.logits = network.deeplab_v3(image_ph,
                                             is_training=True,
                                             reuse=tf.AUTO_REUSE)
        if not self.restored:
            mask = tf.squeeze(tf.argmax(self.logits, axis=3))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            # NOTE(review): the original assigned a local `restored = True`
            # here, which had no effect and was removed. `self.restored` is
            # never updated by this method, and with it set the placeholder
            # and mask above would stay None -- the caching design needs
            # review.
            if not self.restored:
                saver.restore(sess, ckpt)
            sess.run(infer_queue_init)
            for batch_idx in range(len(image_paths) // batch_size):
                image = sess.run(infer_data)

                feed_dict = {image_ph: image}
                prediction = sess.run(mask, feed_dict)

                prediction = binary_fill_holes(
                    resize((255 * prediction[:, :]).astype(np.uint8),
                           (512, 512))) * 255
                # Discard everything in columns 0-255 of the mask.
                prediction[:, 0:256] = 0

                out_path = os.path.join(
                    output_folder, os.path.basename(image_paths[batch_idx]))
                # Intermediate result; overwritten by the final mask below.
                cv2.imwrite(out_path, prediction)
                print(out_path)

                remove_small_objects(prediction,
                                     min_size=50,
                                     connectivity=1,
                                     in_place=True)

                image = resize((255 * image[0, :, :, 0]).astype(np.uint8),
                               (512, 512)) * 255

                # 5x5 box-filtered copy of the input image.
                mean5_image = convolve2d(image,
                                         np.ones((5, 5)) * 1 / 25,
                                         mode='same',
                                         boundary='symm')

                # Intensity statistics of the pixels currently in the mask.
                in_mask = prediction == 255
                segstd = np.std(image[in_mask])
                segmean = np.mean(image[in_mask])
                mean5_segstd = np.std(mean5_image[in_mask])
                mean5_segmean = np.mean(mean5_image[in_mask])
                print(segstd, segmean)

                # Keep only mask pixels whose local mean intensity is within
                # 1.5 standard deviations of the segmented region's mean.
                filtered = copy.deepcopy(prediction)
                filtered[mean5_image > mean5_segmean +
                         1.5 * mean5_segstd] = 0
                filtered[mean5_image < mean5_segmean -
                         1.5 * mean5_segstd] = 0
                prediction = filtered

                cv2.imwrite(out_path, prediction)
Ejemplo n.º 5
0
    def train(self, train_path, val_path, save_dir, batch_size, epochs,
              learning_rate):
        """
        Trains the network built by `network.deeplab_v3` on the specified
        training data and validates on the validation data after every epoch.

        Images are read from the `images` sub-folder and masks from the
        `masks_spleen` sub-folder of each data directory. If a checkpoint is
        found via `tf.train.latest_checkpoint(save_dir)` training resumes
        from it; otherwise training starts from freshly initialized weights.

        Args:
            train_path: Directory where the training data is present.
            val_path: Directory where the validation data is present.
            save_dir: Checkpoint path prefix used when saving the model; its
                parent directory receives the training summaries.
            batch_size: Batch size to use for training.
            epochs: Number of epochs (complete passes over one dataset) to train for.
            learning_rate: Learning rate for the optimizer.
        Returns:
            None
        """
        tf.logging.set_verbosity(tf.logging.INFO)
        train_image_path = os.path.join(train_path, 'images')
        train_mask_path = os.path.join(train_path, 'masks_spleen')
        val_image_path = os.path.join(val_path, 'images')
        val_mask_path = os.path.join(val_path, 'masks_spleen')

        assert os.path.exists(
            train_image_path), "No training image folder found"
        assert os.path.exists(train_mask_path), "No training mask folder found"
        assert os.path.exists(
            val_image_path), "No validation image folder found"
        assert os.path.exists(val_mask_path), "No validation mask folder found"

        train_image_paths, train_mask_paths = get_data_paths_list(
            train_image_path, train_mask_path)
        val_image_paths, val_mask_paths = get_data_paths_list(
            val_image_path, val_mask_path)

        assert len(train_image_paths) == len(
            train_mask_paths
        ), "Number of images and masks dont match in train folder"
        assert len(val_image_paths) == len(
            val_mask_paths
        ), "Number of images and masks dont match in validation folder"

        self.num_train_images = len(train_image_paths)
        self.num_val_images = len(val_image_paths)

        print("Loading Data")
        train_data, train_queue_init = utility.data_batch(
            train_image_paths, train_mask_paths, batch_size)
        train_image_tensor, train_mask_tensor = train_data

        eval_data, eval_queue_init = utility.data_batch(
            val_image_paths, val_mask_paths, batch_size)
        eval_image_tensor, eval_mask_tensor = eval_data
        print("Loading Data Finished")
        image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 1])
        mask_ph = tf.placeholder(tf.int32, shape=[None, 256, 256, 1])
        training = tf.placeholder(tf.bool, shape=[])

        # Identity check: `== None` on a tensor relies on operator overloads.
        if self.logits is None:
            self.logits = network.deeplab_v3(image_ph,
                                             is_training=True,
                                             reuse=False)

        loss = tf.reduce_mean(self.xentropy_loss(self.logits, mask_ph))

        with tf.variable_scope("mean_iou_train"):
            iou, iou_update = self.calculate_iou(mask_ph, self.logits)
        merged = tf.summary.merge_all()
        optimizer = tf.train.AdamOptimizer(learning_rate)
        # Ops registered in UPDATE_OPS must run together with the train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt = optimizer.minimize(loss)

        # Local variables of the IoU metric; re-initialising them resets it.
        running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                         scope="mean_iou_train")

        reset_iou = tf.variables_initializer(var_list=running_vars)

        saver = tf.train.Saver(max_to_keep=30)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        print(self.num_train_images)

        # NOTE(review): the factor 13 presumably reflects how many samples
        # utility.data_batch yields per training image -- confirm.
        num_train_steps = (13 * self.num_train_images) // batch_size - 1
        num_val_steps = self.num_val_images // batch_size

        with tf.Session(config=config) as sess:
            print("Initializing Variables")
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])
            print("Initializing Variables Finished")
            # latest_checkpoint returns None when nothing has been saved yet;
            # restoring from None would abort the run before training starts.
            checkpoint = tf.train.latest_checkpoint(save_dir)
            if checkpoint is not None:
                saver.restore(sess, checkpoint)
                print("Checkpoint restored")
            else:
                print("No checkpoint found, training from scratch")

            # One writer for the whole run instead of a new, never-closed
            # writer per epoch.
            writer = tf.summary.FileWriter(os.path.dirname(save_dir),
                                           sess.graph)
            for epoch in range(epochs):
                print("Epoch: ", epoch)

                print("Epoch queue init start")
                sess.run([train_queue_init, eval_queue_init])
                print("Epoch queue init ends")

                total_train_cost, total_val_cost = 0, 0
                total_train_iou, total_val_iou = 0, 0
                for train_step in range(num_train_steps):

                    image_batch, mask_batch, _ = sess.run(
                        [train_image_tensor, train_mask_tensor, reset_iou])
                    feed_dict = {
                        image_ph: image_batch,
                        mask_ph: mask_batch,
                        training: True
                    }

                    cost, _, _, summary = sess.run(
                        [loss, opt, iou_update, merged], feed_dict=feed_dict)
                    train_iou = sess.run(iou, feed_dict=feed_dict)
                    total_train_cost += cost
                    total_train_iou += train_iou

                    # Use a step that keeps increasing across epochs so the
                    # summaries of different epochs do not overlap.
                    writer.add_summary(summary,
                                       epoch * num_train_steps + train_step)
                    if train_step % 50 == 0:
                        print("Step: ", train_step, "Cost: ", cost, "IoU:",
                              train_iou)

                for val_step in range(num_val_steps):
                    image_batch, mask_batch, _ = sess.run(
                        [eval_image_tensor, eval_mask_tensor, reset_iou])
                    feed_dict = {
                        image_ph: image_batch,
                        mask_ph: mask_batch,
                        training: True
                    }
                    eval_cost, _ = sess.run([loss, iou_update],
                                            feed_dict=feed_dict)
                    eval_iou = sess.run(iou, feed_dict=feed_dict)
                    total_val_cost += eval_cost
                    total_val_iou += eval_iou

                # Average over the number of batches, not the last loop
                # index, and never divide by zero.
                train_batches = max(num_train_steps, 1)
                val_batches = max(num_val_steps, 1)
                print("Epoch: {0}, training loss: {1}, validation loss: {2}".
                      format(epoch, total_train_cost / train_batches,
                             total_val_cost / val_batches))
                print("Epoch: {0}, training iou: {1}, val iou: {2}".format(
                    epoch, total_train_iou / train_batches,
                    total_val_iou / val_batches))

                print("Saving model...")
                saver.save(sess, save_dir, global_step=epoch)
            writer.close()
Ejemplo n.º 6
0
    def train_eval(self, batch_size, growth_k, layers_per_block, epochs, learning_rate=1e-3):
        """Trains the model on the dataset, and does periodic validations.

        A `DenseTiramisu` is built from `growth_k`, `layers_per_block` and
        `self.num_classes`, trained with Adam, evaluated after every epoch
        and checkpointed to `self.model_save_dir`.

        Args:
            batch_size: Batch size used for both training and evaluation.
            growth_k: First argument passed to `DenseTiramisu`.
            layers_per_block: Second argument passed to `DenseTiramisu`.
            epochs: Number of epochs to train for.
            learning_rate: Learning rate for the Adam optimizer.
        """
        train_data, train_queue_init = utility.data_batch(
            self.train_image_paths, self.train_mask_paths, batch_size)
        train_image_tensor, train_mask_tensor = train_data

        eval_data, eval_queue_init = utility.data_batch(
            self.eval_image_paths, self.eval_mask_paths, batch_size)
        eval_image_tensor, eval_mask_tensor = eval_data

        image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
        mask_ph = tf.placeholder(tf.int32, shape=[None, 256, 256, 1])
        training = tf.placeholder(tf.bool, shape=[])

        tiramisu = DenseTiramisu(growth_k, layers_per_block, self.num_classes)

        logits = tiramisu.model(image_ph, training)

        loss = tf.reduce_mean(tiramisu.xentropy_loss(logits, mask_ph))

        with tf.variable_scope("mean_iou_train"):
            iou, iou_update = tiramisu.calculate_iou(mask_ph, logits)

        optimizer = tf.train.AdamOptimizer(learning_rate)
        # Ops registered in UPDATE_OPS must run together with the train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt = optimizer.minimize(loss)

        # Local variables of the IoU metric; re-initialising them resets it.
        running_vars = tf.get_collection(
            tf.GraphKeys.LOCAL_VARIABLES, scope="mean_iou_train")

        reset_iou = tf.variables_initializer(var_list=running_vars)

        saver = tf.train.Saver(max_to_keep=20)

        num_train_steps = self.num_train_images // batch_size
        num_eval_steps = self.num_eval_images // batch_size
        # Divisors for the epoch averages; guarded against empty loops.
        train_batches = max(num_train_steps, 1)
        eval_batches = max(num_eval_steps, 1)

        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            # One writer for the whole run instead of a new, never-closed
            # writer per epoch.
            writer = tf.summary.FileWriter(self.model_save_dir, sess.graph)
            for epoch in range(epochs):
                sess.run([train_queue_init, eval_queue_init])
                total_train_cost, total_eval_cost = 0, 0
                total_train_iou, total_eval_iou = 0, 0
                for train_step in range(num_train_steps):
                    image_batch, mask_batch, _ = sess.run(
                        [train_image_tensor, train_mask_tensor, reset_iou])
                    feed_dict = {image_ph: image_batch,
                                 mask_ph: mask_batch,
                                 training: True}
                    cost, _, _ = sess.run([loss, opt, iou_update], feed_dict=feed_dict)
                    train_iou = sess.run(iou, feed_dict=feed_dict)
                    total_train_cost += cost
                    total_train_iou += train_iou
                    if train_step % 50 == 0:
                        print("Step: ", train_step, "Cost: ", cost, "IoU:", train_iou)

                for _ in range(num_eval_steps):
                    image_batch, mask_batch, _ = sess.run(
                        [eval_image_tensor, eval_mask_tensor, reset_iou])
                    feed_dict = {image_ph: image_batch,
                                 mask_ph: mask_batch,
                                 training: True}
                    eval_cost, _ = sess.run([loss, iou_update], feed_dict=feed_dict)
                    eval_iou = sess.run(iou, feed_dict=feed_dict)
                    total_eval_cost += eval_cost
                    total_eval_iou += eval_iou

                # Report per-batch averages for both splits (the divisor is
                # the batch count, not the last loop index).
                print("Epoch: ", epoch, "train loss: ", total_train_cost / train_batches,
                      "eval loss: ", total_eval_cost / eval_batches)
                print("Epoch: ", epoch, "train iou: ", total_train_iou / train_batches,
                      "eval iou: ", total_eval_iou / eval_batches)

                print("Saving model...")
                saver.save(sess, self.model_save_dir, global_step=epoch)
            writer.close()
Ejemplo n.º 7
0
    def train(self, train_path, val_path, save_dir, batch_size, epochs,
              learning_rate, prior_model):
        """
        Trains the Tiramisu on the specified training data and validates on
        the validation data after every epoch.

        Images are read from the `images` sub-folder and masks from the
        `masks` sub-folder of each data directory.

        Args:
            train_path: Directory where the training data is present.
            val_path: Directory where the validation data is present.
            save_dir: Checkpoint path prefix used when saving the model; its
                parent directory receives the training summaries.
            batch_size: Batch size to use for training.
            epochs: Number of epochs (complete passes over one dataset) to train for.
            learning_rate: Learning rate for the optimizer.
            prior_model: Checkpoint to start training from; an empty value
                starts from freshly initialized weights.
        Returns:
            None
        """

        train_image_path = os.path.join(train_path, 'images')
        train_mask_path = os.path.join(train_path, 'masks')
        val_image_path = os.path.join(val_path, 'images')
        val_mask_path = os.path.join(val_path, 'masks')

        assert os.path.exists(
            train_image_path), "No training image folder found"
        assert os.path.exists(train_mask_path), "No training mask folder found"
        assert os.path.exists(
            val_image_path), "No validation image folder found"
        assert os.path.exists(val_mask_path), "No validation mask folder found"

        train_image_paths, train_mask_paths = get_data_paths_list(
            train_image_path, train_mask_path)
        val_image_paths, val_mask_paths = get_data_paths_list(
            val_image_path, val_mask_path)

        assert len(train_image_paths) == len(
            train_mask_paths
        ), "Number of images and masks dont match in train folder"
        assert len(val_image_paths) == len(
            val_mask_paths
        ), "Number of images and masks dont match in validation folder"

        self.num_train_images = len(train_image_paths)
        self.num_val_images = len(val_image_paths)

        train_data, train_queue_init = utility.data_batch(
            train_image_paths, train_mask_paths, batch_size)
        train_image_tensor, train_mask_tensor = train_data

        eval_data, eval_queue_init = utility.data_batch(
            val_image_paths, val_mask_paths, batch_size)
        eval_image_tensor, eval_mask_tensor = eval_data

        image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
        mask_ph = tf.placeholder(tf.int32, shape=[None, 256, 256, 1])
        training = tf.placeholder(tf.bool, shape=[])

        # Compare against None explicitly: truth-testing a graph tensor is
        # rejected by TensorFlow 1.x once the logits have been cached.
        if self.logits is None:
            self.logits = self.model(image_ph, training)

        loss = tf.reduce_mean(self.xentropy_loss(self.logits, mask_ph))

        with tf.variable_scope("mean_iou_train"):
            iou, iou_update = self.calculate_iou(mask_ph, self.logits)

        optimizer = tf.train.AdamOptimizer(learning_rate)
        # Ops registered in UPDATE_OPS must run together with the train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt = optimizer.minimize(loss)

        # Local variables of the IoU metric; re-initialising them resets it.
        running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                         scope="mean_iou_train")

        reset_iou = tf.variables_initializer(var_list=running_vars)

        num_train_steps = self.num_train_images // batch_size
        num_val_steps = self.num_val_images // batch_size
        # Divisors for the epoch averages; guarded against empty loops.
        train_batches = max(num_train_steps, 1)
        val_batches = max(num_val_steps, 1)

        with tf.Session() as sess:
            saver = tf.train.Saver(max_to_keep=20)
            # Initialize first and restore afterwards: running the
            # initializers after the restore would overwrite the weights
            # loaded from the prior model.
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])
            if prior_model:
                saver.restore(sess, prior_model)

            # One writer for the whole run instead of a new, never-closed
            # writer per epoch.
            writer = tf.summary.FileWriter(os.path.dirname(save_dir),
                                           sess.graph)
            for epoch in range(epochs):
                sess.run([train_queue_init, eval_queue_init])
                total_train_cost, total_val_cost = 0, 0
                total_train_iou, total_val_iou = 0, 0
                for train_step in range(num_train_steps):
                    image_batch, mask_batch, _ = sess.run(
                        [train_image_tensor, train_mask_tensor, reset_iou])
                    feed_dict = {
                        image_ph: image_batch,
                        mask_ph: mask_batch,
                        training: True
                    }
                    cost, _, _ = sess.run([loss, opt, iou_update],
                                          feed_dict=feed_dict)
                    train_iou = sess.run(iou, feed_dict=feed_dict)
                    total_train_cost += cost
                    total_train_iou += train_iou
                    if train_step % 50 == 0:
                        print("Step: ", train_step, "Cost: ", cost, "IoU:",
                              train_iou)

                for _ in range(num_val_steps):
                    image_batch, mask_batch, _ = sess.run(
                        [eval_image_tensor, eval_mask_tensor, reset_iou])
                    feed_dict = {
                        image_ph: image_batch,
                        mask_ph: mask_batch,
                        training: True
                    }
                    eval_cost, _ = sess.run([loss, iou_update],
                                            feed_dict=feed_dict)
                    eval_iou = sess.run(iou, feed_dict=feed_dict)
                    total_val_cost += eval_cost
                    total_val_iou += eval_iou

                # Average over the number of batches, not the last loop
                # index, and never divide by zero.
                print("Epoch: {0}, training loss: {1}, validation loss: {2}".
                      format(epoch, total_train_cost / train_batches,
                             total_val_cost / val_batches))
                print("Epoch: {0}, training iou: {1}, val iou: {2}".format(
                    epoch, total_train_iou / train_batches,
                    total_val_iou / val_batches))

                print("Saving model...")
                saver.save(sess, save_dir, global_step=epoch)
            writer.close()
Ejemplo n.º 8
0
    def train(self, train_path, val_path, save_dir, batch_size, epochs, learning_rate,
              learning_policy, decay_steps=10):
        """
        Trains the Tiramisu on the specified training data and validates on the
        validation data after every epoch.

        A checkpoint is saved after each epoch and, once all epochs are done, the
        per-epoch loss / IoU curves are written to './result.jpg'.

        Args:
            train_path: Directory where the training data is present; it must
                contain 'images' and 'masks' sub-folders.
            val_path: Directory where the validation data is present; same layout.
            save_dir: Checkpoint path prefix passed to the saver. The graph
                summary is written to its parent directory.
            batch_size: Batch size to use for training and validation.
            epochs: Number of epochs (complete passes over one dataset) to train for.
            learning_rate: Initial learning rate for the optimizer.
            learning_policy: 'step' selects exponential decay, 'poly' selects
                polynomial decay towards 0.0001; any other value keeps the
                learning rate constant.
            decay_steps: Decay period of the 'step' / 'poly' schedules, counted
                in optimizer steps (one per training batch). The default is the
                value that used to be hard-coded; it decays quickly, so pass a
                larger value for long runs.
        Returns:
            None
        Raises:
            AssertionError: If a data folder is missing or the number of images
                and masks differ.
            ValueError: If the training or validation set holds fewer images
                than one batch.
        """

        train_image_path = os.path.join(train_path, 'images')
        train_mask_path = os.path.join(train_path, 'masks')
        val_image_path = os.path.join(val_path, 'images')
        val_mask_path = os.path.join(val_path, 'masks')

        assert os.path.exists(train_image_path), "No training image folder found"
        assert os.path.exists(train_mask_path), "No training mask folder found"
        assert os.path.exists(val_image_path), "No validation image folder found"
        assert os.path.exists(val_mask_path), "No validation mask folder found"

        train_image_paths, train_mask_paths = get_data_paths_list(train_image_path, train_mask_path)
        val_image_paths, val_mask_paths = get_data_paths_list(val_image_path, val_mask_path)

        assert len(train_image_paths) == len(train_mask_paths), "Number of images and masks dont match in train folder"
        assert len(val_image_paths) == len(val_mask_paths), "Number of images and masks dont match in validation folder"

        self.num_train_images = len(train_image_paths)
        self.num_val_images = len(val_image_paths)

        # Batch counts are the divisors for the per-epoch averages below, so a
        # dataset without a single full batch is rejected up front instead of
        # failing at the end of the first epoch.
        num_train_batches = self.num_train_images // batch_size
        num_val_batches = self.num_val_images // batch_size
        if num_train_batches == 0 or num_val_batches == 0:
            raise ValueError(
                "Need at least one full batch of size {0}: found {1} training "
                "and {2} validation images".format(
                    batch_size, self.num_train_images, self.num_val_images))

        train_data, train_queue_init = utility.data_batch(
            train_image_paths, train_mask_paths, batch_size, train_flag=True)
        train_image_tensor, train_mask_tensor = train_data

        # NOTE(review): the validation pipeline is also built with
        # train_flag=True; confirm this does not enable training-only behaviour.
        eval_data, eval_queue_init = utility.data_batch(
            val_image_paths, val_mask_paths, batch_size, train_flag=True)
        eval_image_tensor, eval_mask_tensor = eval_data

        image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
        mask_ph = tf.placeholder(tf.int32, shape=[None, 256, 256, 1])
        training = tf.placeholder(tf.bool, shape=[])
        global_step = tf.Variable(0, trainable=False)
        # Identity check: truth-testing a tf.Tensor raises TypeError, so
        # `not self.logits` breaks as soon as logits have been built once.
        # NOTE(review): logits cached by an earlier call stay wired to that
        # call's placeholder, not to `image_ph` above - confirm reuse is intended.
        if self.logits is None:
            self.logits = self.model(image_ph, training)

        # NOTE(review): this regularizer is only attached to the metric scope
        # below and no regularization term is added to `loss`, so it presumably
        # does not influence optimisation - confirm intent.
        regularizer = tf.contrib.layers.l2_regularizer(scale=0.001)

        loss = tf.reduce_mean(self.xentropy_loss(self.logits, mask_ph))

        with tf.variable_scope("mean_iou_train", regularizer=regularizer):
            iou, iou_update = self.calculate_iou(mask_ph, self.logits)

        if learning_policy == 'step':
            lr = tf.train.exponential_decay(
                learning_rate, global_step=global_step,
                decay_steps=decay_steps, decay_rate=0.9)
        elif learning_policy == 'poly':
            lr = tf.train.polynomial_decay(
                learning_rate, global_step=global_step,
                decay_steps=decay_steps, end_learning_rate=0.0001,
                power=1.0, cycle=False)
        else:
            lr = learning_rate

        # Plain SGD; the schedule above supplies the learning rate.
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=lr)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            # Passing global_step makes every optimizer step increment it;
            # without this the decay schedules never leave their initial value.
            opt = optimizer.minimize(loss, global_step=global_step)

        running_vars = tf.get_collection(
            tf.GraphKeys.LOCAL_VARIABLES, scope="mean_iou_train")

        # Running this op clears the IoU accumulators, so the metric fetched
        # after each batch describes that batch only.
        reset_iou = tf.variables_initializer(var_list=running_vars)

        saver = tf.train.Saver(max_to_keep=20)

        total_train_losses = []
        total_train_ious = []
        total_val_losses = []
        total_val_ious = []
        steps = []

        with tf.Session() as sess:
            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer()])

            # One writer for the whole run (it only records the graph); it is
            # closed in `finally` so the event file is flushed even on error.
            writer = tf.summary.FileWriter(os.path.dirname(save_dir), sess.graph)
            try:
                for epoch in range(epochs):
                    sess.run([train_queue_init, eval_queue_init])
                    total_train_cost, total_val_cost = 0, 0
                    total_train_iou, total_val_iou = 0, 0

                    for train_step in range(num_train_batches):
                        image_batch, mask_batch, _ = sess.run(
                            [train_image_tensor, train_mask_tensor, reset_iou])
                        feed_dict = {image_ph: image_batch,
                                     mask_ph: mask_batch,
                                     training: True}
                        cost, _, _ = sess.run(
                            [loss, opt, iou_update], feed_dict=feed_dict)
                        train_iou = sess.run(iou, feed_dict=feed_dict)
                        total_train_cost += cost
                        total_train_iou += train_iou
                        if train_step % 50 == 0:
                            print("Step: ", train_step, "Cost: ",
                                  cost, "IoU:", train_iou)

                    for _ in range(num_val_batches):
                        image_batch, mask_batch, _ = sess.run(
                            [eval_image_tensor, eval_mask_tensor, reset_iou])
                        # NOTE(review): validation also feeds training=True, so
                        # layers driven by this flag run in training mode -
                        # confirm this is intended.
                        feed_dict = {image_ph: image_batch,
                                     mask_ph: mask_batch,
                                     training: True}
                        eval_cost, _ = sess.run(
                            [loss, iou_update], feed_dict=feed_dict)
                        eval_iou = sess.run(iou, feed_dict=feed_dict)
                        total_val_cost += eval_cost
                        total_val_iou += eval_iou

                    # Average over the number of batches, not the last index.
                    mean_train_cost = total_train_cost / num_train_batches
                    mean_val_cost = total_val_cost / num_val_batches
                    mean_train_iou = total_train_iou / num_train_batches
                    mean_val_iou = total_val_iou / num_val_batches

                    print("Epoch: {0}, training loss: {1}, validation loss: {2}".format(
                        epoch, mean_train_cost, mean_val_cost))
                    print("Epoch: {0}, training iou: {1}, val iou: {2}".format(
                        epoch, mean_train_iou, mean_val_iou))
                    total_train_losses.append(mean_train_cost)
                    total_val_losses.append(mean_val_cost)
                    total_train_ious.append(mean_train_iou)
                    total_val_ious.append(mean_val_iou)
                    steps.append(epoch)
                    print("Saving model...")
                    saver.save(sess, save_dir, global_step=epoch)
            finally:
                writer.close()

        # Plotting needs no session, so it runs after the session is released.
        plot(steps, total_train_losses, color='r', label='train_loss')
        plot(steps, total_val_losses, color='g', label='val_loss')
        plot(steps, total_train_ious, color='k', label='train_iou')
        plot(steps, total_val_ious, color='b', label='val_iou')
        xlabel('epoch')
        ylabel('value')
        title('Loss-Iou')
        legend(loc='best')
        savefig('./result.jpg')
Ejemplo n.º 9
0
    def infer(self, image_dir, batch_size, ckpt, output_folder, organ):
        """
        Uses a trained model file to get predictions on the specified images,
        refines every predicted mask with intensity-based region growing and
        writes the refined mask to disk.

        Args:
            image_dir: Directory where the images are located.
            batch_size: Batch size to use while inferring (relevant if batch norm is used).
                NOTE(review): the loop below indexes `image_paths` with the batch
                counter and reads only `image[0]`, so it presumably expects
                batch_size == 1 - confirm.
            ckpt: Name of the checkpoint file to use.
            output_folder: Folder where the predictions on the images should be saved.
            organ: Organ name; for "spleen" the first 250 columns of the
                prediction are cleared before refinement.
        """
        print(image_dir)
        # Helper defined elsewhere; presumably converts the DICOM files in
        # image_dir to PNG before they are listed - TODO confirm.
        cvt_dcm_png(image_dir, "CT")
        image_paths = [os.path.join(image_dir, x) for x in os.listdir(image_dir) if x.endswith('.png') or x.endswith('.jpg')]
        infer_data, infer_queue_init = utility.data_batch(
            image_paths, None, batch_size, shuffle=False)

        image_ph = None
        mask = None
        #training = None
        # NOTE(review): when self.restored is truthy, image_ph and mask stay
        # None and are still used in the session loop below.
        if not self.restored:
            image_ph = tf.placeholder(tf.float32, shape=[None, 256, 256, 1])
            #training = tf.placeholder(tf.bool, shape=[])

        # The network is built only once per instance and with is_training=True.
        if self.logits == None:
            self.logits = network.deeplab_v3(image_ph, is_training=True, reuse=tf.AUTO_REUSE) 
            #self.logits = self.model(image_ph, training)
        if not self.restored:
            # Per-pixel class index (argmax over axis 3 of the logits).
            mask = tf.squeeze(tf.argmax(self.logits, axis=3))
        j = 0
        saver = tf.train.Saver()
        with tf.Session() as sess:
            if not self.restored:
                saver.restore(sess, ckpt)
                # NOTE(review): this binds a local name; self.restored is not
                # updated here.
                restored = True
            sess.run(infer_queue_init)
            # `_` is the batch counter; it is also used as an index into
            # image_paths further down.
            for _ in range(len(image_paths) // batch_size):
                image = sess.run(infer_data)
                #print(np.max(image), np.min(image), image.shape, image.dtype)


                feed_dict = {
                    image_ph: image
                    #training: False 
                }
                prediction = sess.run(mask, feed_dict)

                # Drop connected components smaller than 450 pixels, then
                # resize the mask to 512x512, fill holes and scale to 0 / 255.
                prediction = remove_small_objects(prediction.astype(np.bool), min_size=450, connectivity=2) * 255
                #cv2.imwrite(os.path.join(output_folder, "raw_" + os.path.basename(image_paths[_])), prediction)
                prediction = binary_fill_holes(resize((255 * prediction[:, :]).astype(np.uint8), (512,512)))* 255 
                if organ=="spleen":                
                    # Clear columns 0..249 of the mask for the spleen.
                    prediction[:, 0:250] = 0

                print(os.path.join(output_folder, os.path.basename(image_paths[_])))


                
                # First image of the batch, channel 0, resized to 512x512.
                image = resize((255 * image[0,:, :,0]).astype(np.uint8), (512,512))* 255 

                # 3x3 and 5x5 box-filtered copies of the image.
                mean3_image = convolve2d(image, np.ones((3,3)) * 1/9, mode='same', boundary='symm')
                mean5_image = convolve2d(image, np.ones((5,5)) * 1/25, mode='same', boundary='symm')

                # Intensity statistics of the pixels under the predicted mask.
                segstd = np.std(image[prediction == 255])
                segmean = np.mean(image[prediction == 255])


                """
                Visualization Skip Purpose

                if np.mean(prediction) < 1:
                    continue
                """



                # Same statistics on the smoothed images; the mean3_* values are
                # not used again in this block.
                mean3_segstd = np.std(mean3_image[prediction == 255])
                mean3_segmean = np.mean(mean3_image[prediction == 255])
                mean5_segstd = np.std(mean5_image[prediction == 255])
                mean5_segmean = np.mean(mean5_image[prediction == 255])
                print(segstd, segmean)


                # SLIC superpixel parameters.
                segment = 1800
                compact = 100

                # std: acceptance band (in segstd units) for region growing.
                # std_rm: band (in mean5_segstd units) for the final clean-up.
                # std_slic is not used in this block.
                std = 1.2 # 1.2
                std_rm = 1.85 #1.85
                std_slic = 1.0 # 1.0
                queue_visit = []
                # visited: 255 = inside the region, 0 = not seen yet, any other
                # value = already queued or rejected.
                visited = copy.deepcopy(prediction)
                visit_depth = 150
                
                image_slic = slic(image, n_segments=segment, compactness=compact, sigma=1)
                # Mask-pixel count per superpixel label; only referenced by the
                # commented-out code below.
                liver_slic = defaultdict(lambda: 0)
                # Region Growing 'queue_visit initialization' 
                # Seed the queue with the unseen 4-neighbours of every mask
                # pixel. Rows 1..509 and columns 1..510 are scanned.
                for x in range(1, 510):
                    for y in range(1, 511):
                        if visited[x, y] == 255:
                            liver_slic[image_slic[x, y]] += 1
                            if visited[x + 1, y] == 0:
                                visited[x + 1, y] = 1
                                queue_visit.append((x + 1, y))

                            if visited[x, y + 1] == 0:
                                visited[x, y + 1] = 1
                                queue_visit.append((x, y + 1))

                            if visited[x - 1, y] == 0:
                                visited[x - 1, y] = 1
                                queue_visit.append((x - 1, y))

                            if visited[x, y - 1] == 0:
                                visited[x, y - 1] = 1
                                queue_visit.append((x, y - 1))


                # Breadth-first growth: pixels are taken from the front of the
                # queue in insertion order.
                while(len(queue_visit)):
                    # NOTE(review): this test runs before the pop, so it looks
                    # at the previously processed pixel (or the last x, y of the
                    # loops above on the first pass). Processed pixels are set to
                    # 1 or 255, so with visit_depth = 150 this break appears
                    # never to fire - confirm the intended depth limit.
                    if visited[x, y] == visit_depth:
                        break


                    x, y = queue_visit.pop(0)
                    #print("Target Pixel", (x, y), image[x, y], liver_slic[image_slic[x, y]]) 
                    # Check the tissue value is eligible  and liver_slic[image_slic[x,y]] < 300
                    # Mean intensity of the superpixel containing (x, y);
                    # recomputed from the full image on every pop.
                    m = np.mean(image[image_slic == image_slic[x,y]])
                    # Reject when the pixel lies outside segmean +/- std * segstd
                    # or its superpixel mean is below segmean.
                    if image[x, y] > segmean + std * segstd or image[x, y] < segmean - std * segstd  or  m < segmean:
                        visited[x, y] = 1
                    else:
                        # Traverse neighborhood pixels
                        # Unseen neighbours are tagged with this pixel's value
                        # plus one and queued. The `> 0` checks leave row 0 and
                        # column 0 untouched.
                        if x + 1 < 512 and visited[x + 1, y] == 0:
                            visited[x + 1, y] = visited[x, y] + 1
                            queue_visit.append((x + 1, y))
                        if y + 1 < 512 and visited[x, y + 1] == 0:
                            visited[x, y + 1] = visited[x, y] + 1
                            queue_visit.append((x, y + 1))
                        if x - 1 > 0 and visited[x - 1, y] == 0:
                            visited[x - 1, y] = visited[x, y] + 1
                            queue_visit.append((x - 1, y))
                        if y  - 1 > 0 and visited[x, y - 1] == 0:
                            visited[x, y - 1] = visited[x, y] + 1
                            queue_visit.append((x, y - 1))

                        # Accepted: the pixel joins the region.
                        visited[x, y] = 255

                # Clean-up: clear pixels whose 5x5-smoothed intensity lies
                # outside mean5_segmean +/- std_rm * mean5_segstd.
                tmp_img_mean5 = copy.deepcopy(visited)
                tmp_img_mean5[mean5_image > mean5_segmean + std_rm * mean5_segstd] = 0
                tmp_img_mean5[mean5_image < mean5_segmean - std_rm * mean5_segstd] = 0

                # Morphological opening with a radius-5 disk, scaled to 0 / 255.
                tmp_img_mean5 = binary_opening(tmp_img_mean5, selem=disk(5)).astype(np.uint8) * 255

                # Small-object removal (min_size 400), modifying the array in place.
                remove_small_objects(tmp_img_mean5, min_size=400, connectivity=1, in_place=True)


                # The result is saved under the input file's base name.
                cv2.imwrite(os.path.join(output_folder, os.path.basename(image_paths[_])), tmp_img_mean5)
                # j is incremented but not read anywhere else in this block.
                j = j + 1