Exemplo n.º 1
0
    def read_file(self):
        """Load the next shuffled sample, cast it to float64 and optionally resize/crop it.

        The sample id is ``self.random_batch[self.file_index]``. Its image is
        loaded from ``<dir_name>/<dir_image>/<image file>`` via ``m2n.loadimage``.
        Its mask file, at ``<dir_name>/<dir_mask>/<mask file>``, is parsed with
        ``m2n.parsemask``, and only the entry chosen by
        ``self.mask_index[sample]`` is kept. The file name, the image path and
        the selected structure name are printed. ``self.file_index`` is
        advanced by one.

        Returns:
            tuple: ``(image, mask)`` as float64 arrays, resized when
            ``opt_resize`` is True and cropped when ``opt_crop`` is True (a
            warning is printed, but both are still applied, if both are set).
        """
        sample = self.random_batch[self.file_index]
        image_name = self.image_fname[sample]
        image_path = '/'.join((self.dir_name, self.dir_image, image_name))
        print(image_name)
        print(image_path)

        image = m2n.loadimage(image_path)
        mask_path = '/'.join((self.dir_name, self.dir_mask, self.mask_fname[sample]))
        _, masks, structure_names = m2n.parsemask(mask_path)
        structure = self.mask_index[sample]
        selected_mask = masks[structure]

        # Author's note: casting to float64 is required so that resizing does
        # not change the pixel values.
        image, mask = np.float64(image), np.float64(selected_mask)

        print(structure_names[structure])
        self.file_index += 1

        do_resize = self.opt_resize == True
        do_crop = self.opt_crop == True
        if do_resize and do_crop:
            print("Both resize and crop are selected. Please choose either one")

        # Both operations act on the whole stack (N x width x height per the
        # original author's comment) and are applied to image and mask alike.
        if do_resize:
            print("image is resized")
            image, mask = [dpp.resize_3d(dcmimage=volume, resize_shape=self.resize_shape)
                           for volume in (image, mask)]
        if do_crop:
            print("image is cropped")
            image, mask = [dpp.crop_3d(dcmimage=volume, crop_shape=self.crop_shape)
                           for volume in (image, mask)]

        return image, mask
Exemplo n.º 2
0
    def _read_py_function(self, img_fname, pix_upper_lim, pix_lower_lim):
        """Load a MATLAB image, resize it and rescale its intensities to 0-255.

        The array stored under the ``'img'`` key of the ``.mat`` file is resized
        with ``dpp.resize_3d`` to ``self.img_resize``. It is then clipped to
        ``[pix_lower_lim, pix_upper_lim]``, and that fixed window is mapped
        linearly onto ``[0, 255]`` and rounded.

        Args:
            img_fname: Path of the ``.mat`` file. ``bytes`` are decoded as
                UTF-8 (as delivered by TensorFlow py_func-style wrappers); a
                ``str`` is used as-is.
            pix_upper_lim: Intensities above this value are clipped to it.
            pix_lower_lim: Intensities below this value are clipped to it.

        Returns:
            The resized image with values in ``[0, 255]`` (rounded).

        Raises:
            ValueError: If ``pix_upper_lim`` is not greater than ``pix_lower_lim``.
        """
        if isinstance(img_fname, bytes):
            img_fname = img_fname.decode('utf-8')

        pix_range = pix_upper_lim - pix_lower_lim
        if pix_range <= 0:
            raise ValueError("pix_upper_lim (%r) must be greater than pix_lower_lim (%r)"
                             % (pix_upper_lim, pix_lower_lim))

        # read matlab file and turn it into numpy
        print("img_fname : ", img_fname)
        image = sio.loadmat(img_fname)['img']

        image = dpp.resize_3d(image, self.img_resize)

        # Clip to the intensity window [pix_lower_lim, pix_upper_lim]
        # (in-place assignment keeps the array's dtype).
        image[image < pix_lower_lim] = pix_lower_lim
        image[image > pix_upper_lim] = pix_upper_lim

        # Shift by the fixed lower limit rather than the per-image minimum, so
        # every image uses the same mapping [lower, upper] -> [0, upper - lower].
        image = image - pix_lower_lim

        # float() avoids floor division when the limits are integers (Python 2
        # semantics, NumPy integer scalars), which would skew or zero the bin size.
        bin_size = float(pix_range) / 255
        return np.round(image / bin_size)  # (0, upper - lower) -> (0, 255)
Exemplo n.º 3
0
    def read_file(self):
        """Load the sample at ``self.file_index`` in order and resize it.

        The image comes from ``<dir_name>/image/<image file>`` via
        ``m2n.loadimage``. The mask file, ``<dir_name>/mask/<mask file>``, is
        parsed with ``m2n.parsemask``, and only the entry selected by
        ``self.mask_index`` is kept; its structure name is printed. Both arrays
        are cast to float64 and resized with ``dpp.resize_3d`` to
        ``self.resize_shape``. ``self.file_index`` is advanced by one.

        Returns:
            tuple: ``(image, mask)`` after casting and resizing.
        """
        idx = self.file_index

        image = m2n.loadimage('/'.join((self.dir_name, 'image', self.image_fname[idx])))
        mask_file = '/'.join((self.dir_name, 'mask', self.mask_fname[idx]))
        _, masks, structure_names = m2n.parsemask(mask_file)
        structure = self.mask_index[idx]
        selected_mask = masks[structure]

        # Author's note: casting to float64 is required so that resizing does
        # not change the pixel values.
        image = np.float64(image)
        mask = np.float64(selected_mask)

        print(structure_names[structure])

        self.file_index = idx + 1

        # Resize the whole stack (N x width x height per the original author's
        # comment) for both the image and the mask.
        return tuple(dpp.resize_3d(dcmimage=volume, resize_shape=self.resize_shape)
                     for volume in (image, mask))
Exemplo n.º 4
0
def _resize_annotations_to_logits(annotations):
    """Resize a batch of masks to the network's output resolution.

    Axis 3 (a singleton channel axis) is squeezed away. The remaining stack is
    resized with ``dpp.resize_3d`` to ``(LOGITS_SIZE, LOGITS_SIZE)``, and the
    channel axis is appended again, so the result matches the ``annotation``
    placeholder shape ``[None, LOGITS_SIZE, LOGITS_SIZE, 1]``.
    """
    resized = dpp.resize_3d(np.squeeze(annotations, axis=3),
                            resize_shape=(LOGITS_SIZE, LOGITS_SIZE))
    return resized[:, :, :, np.newaxis]


def main(argv=None):
    """Build the U-Net segmentation graph and run it in ``FLAGS.mode``.

    ``FLAGS.optimization`` selects the loss, either ``"cross_entropy"`` or
    ``"dice"``. A checkpoint found in ``FLAGS.logs_dir`` is restored before
    running. Modes:

    * ``"train"``: trains on ``<dir_name>training_set``. Every 20 steps it
      prints the training loss, and every 50 steps the loss on a batch from
      ``<dir_name>validation_set``. It checkpoints every 2000 steps and at the
      end, then saves the loss curves to ``loss_functions.png``.
    * ``"test"``: predicts masks for unlabeled images and saves image /
      prediction plots into ``FLAGS.logs_dir``.
    * ``"visualize"``: predicts masks for 20 validation images, saves image /
      ground-truth / prediction plots and a DICE histogram.

    Args:
        argv: Unused.

    Raises:
        ValueError: If ``FLAGS.optimization`` is not a supported loss.
    """
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="input_image")
    annotation = tf.placeholder(tf.int32, shape=[None, LOGITS_SIZE, LOGITS_SIZE, 1], name="annotation")

    # Pass the dropout placeholder (not a constant) so the 1.0 fed during
    # validation / prediction actually disables dropout.
    logits = u_net(x=image, keep_prob=keep_probability, channels=1, n_class=2)

    if FLAGS.optimization == "cross_entropy":
        label = tf.squeeze(annotation, axis=[3])
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label, name="entropy"))
    elif FLAGS.optimization == "dice":
        # argmax is not differentiable, so DICE is optimized on the softmax
        # probability of class 1 instead of on the hard prediction.
        probabilities = tf.nn.softmax(logits)  # softmax over the last axis
        foreground = tf.slice(probabilities, [0, 0, 0, 1], [-1, LOGITS_SIZE, LOGITS_SIZE, 1])
        loss = 1 - tl.cost.dice_coe(foreground, tf.cast(annotation, dtype=tf.float32))
    else:
        raise ValueError("Unknown FLAGS.optimization %r; expected 'cross_entropy' or 'dice'"
                         % (FLAGS.optimization,))

    # Hard per-pixel class prediction used by the "test" and "visualize" modes.
    # NOTE(review): assumes logits is [batch, LOGITS_SIZE, LOGITS_SIZE, n_class].
    pred_annotation = tf.expand_dims(tf.argmax(logits, axis=3), axis=3)

    # Train the whole model. To restrict training to one scope, filter instead:
    #   trainable_var = [v for v in tf.trainable_variables() if 'inference' in v.name]
    trainable_var = tf.trainable_variables()
    train_op = train(loss, trainable_var)

    # ---- Data / augmentation configuration ----------------------------------
    dir_name = 'AQA/'
    # NOTE(review): dir_image / dir_mask were passed to the readers but never
    # assigned in main(); 'image' / 'mask' are the sub-directory names the
    # project's sequential DICOM reader hard-codes. Confirm against the data.
    dir_image = 'image'
    dir_mask = 'mask'
    contour_name = 'external'

    batch_size = 3

    opt_crop = False
    crop_shape = (224, 224)
    opt_resize = False
    resize_shape = (224, 224)
    rotation = True
    rotation_angle = [-5, 5]
    bitsampling = False
    bitsampling_bit = [4, 8]

    # Options shared by every labelled-data reader; the augmentation switches
    # (rotation, bitsampling) are chosen per reader.
    reader_kwargs = dict(dir_image=dir_image, dir_mask=dir_mask, contour_name=contour_name,
                         opt_resize=opt_resize, resize_shape=resize_shape,
                         opt_crop=opt_crop, crop_shape=crop_shape,
                         rotation_angle=rotation_angle, bitsampling_bit=bitsampling_bit)
    # --------------------------------------------------------------------------

    sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))  # CPU only

    print("Setting up Saver...")
    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    if FLAGS.mode == "train":
        print("Setting up training data...")
        dicom_records = dicom_batch.read_DICOM(dir_name=dir_name + 'training_set',
                                               rotation=rotation, bitsampling=bitsampling,
                                               **reader_kwargs)

        print("Setting up validation data...")
        validation_records = dicom_batch.read_DICOM(dir_name=dir_name + 'validation_set',
                                                    rotation=False, bitsampling=False,
                                                    **reader_kwargs)

        print("Start training")
        start = time.time()
        train_loss_list = []
        x_train = []
        validation_loss_list = []
        x_validation = []
        for itr in xrange(MAX_ITERATION):  # author's estimate: about 12 hours / 2000 steps
            step = itr + 1
            train_images, train_annotations = dicom_records.next_batch(batch_size=batch_size)
            # The network output is smaller than its input, so masks are
            # resized to the logits resolution before being fed.
            train_annotations = _resize_annotations_to_logits(train_annotations)

            feed_dict = {image: train_images, annotation: train_annotations, keep_probability: 0.75}
            sess.run(train_op, feed_dict=feed_dict)

            if step % 20 == 0:
                train_loss = sess.run(loss, feed_dict=feed_dict)
                print("Step: %d, Train_loss:%g" % (itr, train_loss))
                train_loss_list.append(train_loss)
                x_train.append(step)

            if step % 50 == 0:
                valid_images, valid_annotations = validation_records.next_batch(batch_size=batch_size)
                valid_annotations = _resize_annotations_to_logits(valid_annotations)

                valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations,
                                                       keep_probability: 1.0})
                print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
                validation_loss_list.append(valid_loss)
                x_validation.append(step)

            if step % 2000 == 0:
                saver.save(sess, FLAGS.logs_dir + "model.ckpt", global_step=step)

            end = time.time()
            print("Iteration #", step, ",", np.int32(end - start), "s")

        # MAX_ITERATION equals the last step; unlike itr it is defined even
        # when the loop body never runs.
        saver.save(sess, FLAGS.logs_dir + "model.ckpt", global_step=MAX_ITERATION)

        # Draw loss functions
        plt.figure()
        plt.plot(x_train, train_loss_list, label='train')
        plt.plot(x_validation, validation_loss_list, label='validation')
        plt.title("loss functions")
        plt.xlabel("epoch")
        plt.ylabel("loss")
        if train_loss_list:  # min()/max() raise on an empty list (short runs)
            plt.ylim(min(train_loss_list), max(train_loss_list) * 1.1)
        plt.legend()
        plt.savefig("loss_functions.png")
        plt.close()

    elif FLAGS.mode == "test":
        print("Setting up test data...")
        img_dir_name = r'..\H&N_CTONLY'  # raw string: same value, no invalid "\H" escape
        test_batch_size = 10
        test_index = 5
        ind = 0
        test_records = dicom_batchImage.read_DICOMbatchImage(dir_name=img_dir_name, opt_resize=opt_resize,
                                                             resize_shape=resize_shape, opt_crop=opt_crop,
                                                             crop_shape=crop_shape)

        for _ in range(test_index):
            print("Start creating data")
            test_images = test_records.next_batch(batch_size=test_batch_size)
            # Prediction depends only on the image and the dropout rate, so no
            # (fake) annotation is fed; a wrongly-shaped feed would be rejected.
            pred = sess.run(pred_annotation, feed_dict={image: test_images, keep_probability: 1.0})
            pred = np.squeeze(pred, axis=3)

            print("Start saving data")
            for itr in range(test_batch_size):
                plt.subplot(121)
                plt.imshow(test_images[itr, :, :, 0], cmap='gray')
                plt.title("image")
                plt.subplot(122)
                plt.imshow(pred[itr], cmap='gray')
                plt.title("pred mask")
                plt.savefig(FLAGS.logs_dir + "/Prediction_test" + str(ind) + ".png")
                plt.close()  # start each plot on a fresh figure
                print("Test iteration : ", ind)
                ind += 1

    elif FLAGS.mode == "visualize":
        print("Setting up validation data...")
        validation_records = dicom_batch.read_DICOM(dir_name=dir_name + 'validation_set',
                                                    rotation=False, bitsampling=False,
                                                    **reader_kwargs)

        dice_array = []
        bins = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
        # Save the image for display. Use matplotlib to draw this.
        for itr in range(20):
            valid_images, valid_annotations = validation_records.next_batch(batch_size=1)
            # Compare against ground truth at the prediction's resolution, as
            # done for the validation loss during training.
            valid_annotations = _resize_annotations_to_logits(valid_annotations)

            pred = sess.run(pred_annotation, feed_dict={image: valid_images, keep_probability: 1.0})
            pred = np.squeeze(pred, axis=3)

            print(valid_images.shape, valid_annotations.shape, pred.shape)
            valid_annotations = np.squeeze(valid_annotations, axis=3)
            dice_coeff = dice(valid_annotations[0], pred[0])

            dice_array.append(dice_coeff)
            print("min max of prediction : ", pred.flatten().min(), pred.flatten().max())
            print("min max of validation : ", valid_annotations.flatten().min(), valid_annotations.flatten().max())
            print("DICE : ", dice_coeff)
            print(valid_annotations.shape)

            # Save images
            plt.subplot(131)
            plt.imshow(valid_images[0, :, :, 0], cmap='gray')
            plt.title("image")
            plt.subplot(132)
            plt.imshow(valid_annotations[0, :, :], cmap='gray')
            plt.title("mask original")
            plt.subplot(133)
            plt.imshow(pred[0], cmap='gray')
            plt.title("mask predicted")
            plt.suptitle("DICE : " + str(dice_coeff))

            plt.savefig(FLAGS.logs_dir + "/Prediction_validation" + str(itr) + ".png")
            plt.close()  # start each plot on a fresh figure

        plt.figure()
        plt.hist(dice_array, bins)
        plt.xlabel('Dice')
        plt.ylabel('frequency')
        plt.title('Dice coefficient distribution of validation dataset')
        plt.savefig(FLAGS.logs_dir + "/dice histogram" + ".png")
        plt.close()