Пример #1
0
def read_whole_image_test(filename, flags):
    """
    read_whole_image_test reads in the test image and its DCIS mask, resizes
    both from 10x to 2.5x, and returns them along with the image shape.

    The mask is looked up by the image's basename under
    experiment_dcis_segmentation/perm1/result. If no such file exists, an
    all-255 uint8 mask with the image's height and width is used instead.

    param: filename: path of the test image
    param: flags: not used by this function; kept for interface compatibility
    return: image, dcis_mask, image.shape
        Arrays that are 2-D after resizing are given a trailing singleton
        channel axis.
    """

    image = KSimage.imread(filename)
    basename = os.path.basename(filename)
    dcis_mask_file = os.path.join('experiment_dcis_segmentation', 'perm1',
                                  'result', basename)

    if os.path.exists(dcis_mask_file):
        dcis_mask = KSimage.imread(dcis_mask_file)
    else:
        # No segmentation result for this image: use an all-255 mask.
        dcis_mask = np.full((image.shape[0], image.shape[1]), 255,
                            dtype=np.uint8)

    # Resizing from 10x to 2.5x
    image = KSimage.imresize(image, 0.25)
    dcis_mask = KSimage.imresize(dcis_mask, 0.25)

    # Append a channel axis to 2-D arrays. axis=-1 is used because axis=3 is
    # out of range for a 2-D array and raises AxisError in current NumPy.
    if image.ndim == 2:
        image = np.expand_dims(image, axis=-1)

    if dcis_mask.ndim == 2:
        dcis_mask = np.expand_dims(dcis_mask, axis=-1)

    return image, dcis_mask, image.shape
Пример #2
0
def read_data_test(filename, flags, dcis_segmentation_result_path):
    """
    read_data_test reads a test image and its matching segmentation mask,
    resizes both by a factor of 2.0, pads them symmetrically and extracts
    test patches from each.

    The mask is the first regular file (in sorted name order) inside
    dcis_segmentation_result_path whose name contains the image's filename
    stem. If no file matches, an all-255 uint8 mask of shape
    (rows, cols, 1) matching the image is used instead.

    param: filename: path of the test image
    param: flags: dict; reads 'stride_test' and 'size_input_patch'
    param: dcis_segmentation_result_path: directory holding mask files
    return: patches, patches_mask, image.shape (after resize and padding),
        nPatches
    """
    stride = flags['stride_test']

    image = KSimage.imread(filename)

    stem = os.path.splitext(os.path.basename(filename))[0]
    files = sorted(
        f for f in os.listdir(dcis_segmentation_result_path)
        if os.path.isfile(os.path.join(dcis_segmentation_result_path, f)))
    # Use None (and fall back to the all-255 mask) instead of raising
    # IndexError when no result file matches this image.
    mask_name = next((f for f in files if stem in f), None)

    if mask_name is not None:
        dcis_mask = KSimage.imread(
            os.path.join(dcis_segmentation_result_path, mask_name))
    else:
        dcis_mask = np.full((image.shape[0], image.shape[1], 1), 255,
                            dtype=np.uint8)

    image = KSimage.imresize(image, 2.0)
    dcis_mask = KSimage.imresize(dcis_mask, 2.0)

    # Append a channel axis to 2-D arrays (axis=3 is out of range for a 2-D
    # array and raises AxisError in current NumPy).
    if image.ndim == 2:
        image = np.expand_dims(image, axis=-1)

    if dcis_mask.ndim == 2:
        dcis_mask = np.expand_dims(dcis_mask, axis=-1)

    # Pad by one input-patch size on each side of rows and columns.
    padrow = flags['size_input_patch'][0]
    padcol = flags['size_input_patch'][1]
    pad_width = ((padrow, padrow), (padcol, padcol), (0, 0))

    image = np.pad(image, pad_width, 'symmetric')
    dcis_mask = np.pad(dcis_mask, pad_width, 'symmetric')

    # extract patches
    patches = ExtractPatches_test(flags['size_input_patch'], stride, image)
    patches_mask = ExtractPatches_test(flags['size_input_patch'], stride,
                                       dcis_mask)

    ntimes_row = int(
        np.floor((image.shape[0] - flags['size_input_patch'][0]) /
                 float(stride[0])) + 1)
    ntimes_col = int(
        np.floor((image.shape[1] - flags['size_input_patch'][1]) /
                 float(stride[1])) + 1)
    rowRange = range(0, ntimes_row * stride[0], stride[0])
    colRange = range(0, ntimes_col * stride[1], stride[1])

    nPatches = len(rowRange) * len(colRange)

    return patches, patches_mask, image.shape, nPatches
def read_data_test(filename, flags):
    """
    read_data_test reads a test image, resizes it to 1/4 (10x to 2.5x),
    pads it symmetrically by one input-patch size on each side and extracts
    test patches.

    param: filename: path of the test image
    param: flags: dict; reads 'stride_test' and 'size_input_patch'
    return: patches, image.shape (after resize and padding), nPatches,
        ori_dim (shape of the image as read, before resizing)
    """
    stride = flags['stride_test']
    patch_size = flags['size_input_patch']

    image = KSimage.imread(filename)
    ori_dim = image.shape
    image = KSimage.imresize(image, 0.25)  # resize to 1/4 times

    # Append a channel axis to 2-D images (axis=3 is out of range for a 2-D
    # array and raises AxisError in current NumPy).
    if image.ndim == 2:
        image = np.expand_dims(image, axis=-1)

    padrow = patch_size[0]
    padcol = patch_size[1]

    image = np.pad(image, ((padrow, padrow), (padcol, padcol), (0, 0)),
                   'symmetric')

    # extract patches
    patches = ExtractPatches_test(patch_size, stride, image)

    ntimes_row = int(
        np.floor((image.shape[0] - patch_size[0]) / float(stride[0])) + 1)
    ntimes_col = int(
        np.floor((image.shape[1] - patch_size[1]) / float(stride[1])) + 1)
    rowRange = range(0, ntimes_row * stride[0], stride[0])
    colRange = range(0, ntimes_col * stride[1], stride[1])

    nPatches = len(rowRange) * len(colRange)

    return patches, image.shape, nPatches, ori_dim
Пример #4
0
def read_data_test(filename, flags):
    """
    read_data_test reads in the test image, resizes it
    from 10x to 2.5x, and extracts patches from the image and its mask.

    The mask is read from
    experiment_epi_stromal_segmentation/perm1/result/<basename>. If that file
    does not exist, an all-255 uint8 mask matching the resized image's height
    and width is used.

    param: filename: path of the test image
    param: flags: dict; reads 'stride_test' and 'size_input_patch'
    return: patches, patches_mask, image.shape (after resize and padding),
        nPatches
    """

    stride = flags['stride_test']

    image = KSimage.imread(filename)
    image = KSimage.imresize(image, 0.25)  #Resizing from 10x to 2.5x
    basename = os.path.basename(filename)
    dcis_mask_file = os.path.join('experiment_epi_stromal_segmentation',
                                  'perm1', 'result', basename)
    if os.path.exists(dcis_mask_file):
        # NOTE(review): unlike the image, a mask read from disk is not
        # resized here; this assumes the stored mask is already at 2.5x —
        # confirm, otherwise image and mask patches will not line up.
        dcis_mask = KSimage.imread(dcis_mask_file)
    else:
        dcis_mask = np.full((image.shape[0], image.shape[1]), 255,
                            dtype=np.uint8)

    # Append a channel axis to 2-D arrays (axis=3 is out of range for a 2-D
    # array and raises AxisError in current NumPy).
    if image.ndim == 2:
        image = np.expand_dims(image, axis=-1)

    if dcis_mask.ndim == 2:
        dcis_mask = np.expand_dims(dcis_mask, axis=-1)

    # Pad by one input-patch size on each side of rows and columns.
    padrow = flags['size_input_patch'][0]
    padcol = flags['size_input_patch'][1]
    pad_width = ((padrow, padrow), (padcol, padcol), (0, 0))

    image = np.pad(image, pad_width, 'symmetric')
    dcis_mask = np.pad(dcis_mask, pad_width, 'symmetric')

    print("INPUT SHAPE: " + str(image.shape))

    # extract patches
    patches = ExtractPatches_test(flags['size_input_patch'], stride, image)
    patches_mask = ExtractPatches_test(flags['size_input_patch'], stride,
                                       dcis_mask)

    ntimes_row = int(
        np.floor((image.shape[0] - flags['size_input_patch'][0]) /
                 float(stride[0])) + 1)
    ntimes_col = int(
        np.floor((image.shape[1] - flags['size_input_patch'][1]) /
                 float(stride[1])) + 1)
    rowRange = range(0, ntimes_row * stride[0], stride[0])
    colRange = range(0, ntimes_col * stride[1], stride[1])

    nPatches = len(rowRange) * len(colRange)

    return patches, patches_mask, image.shape, nPatches
Пример #5
0
def batch_processing(filename, sess, logits_test, parameters, images_test,
                     keep_prob, mean_image, variance_image, flags):
    """
    batch_processing runs the network over every patch of an image and
    returns a mask of the large class-1 regions.

    The image is read and split into patches by
    tf_model_input_test.read_data_test. Patches are evaluated in batches of
    flags['test_batch_size'] and the predictions are merged. The merged map
    is resized by 4.0 and reduced to a label map with argmax, then
    padded/cropped to the original image size. Finally, only class-1
    contours whose area exceeds 500**2 are drawn (filled) into the output.

    param: filename: path of the image to process
    param: sess: TensorFlow session used to evaluate logits_test
    param: logits_test: network output tensor
    param: parameters: not used by this function; kept for compatibility
    param: images_test: input placeholder fed with each batch
    param: keep_prob: dropout placeholder, fed 1.0
    param: mean_image: passed to tf_model_input_test.inputs_test
    param: variance_image: passed to tf_model_input_test.inputs_test
    param: flags: dict; reads 'test_batch_size', 'size_output_patch',
        'stride_test', 'size_input_patch'
    return: uint8 array of shape ori_dim[:2] with values 0 or 255
    """
    from itertools import islice

    # Read image and extract patches
    patches, image_size, nPatches, ori_dim = tf_model_input_test.read_data_test(
        filename, flags)
    batch_size = flags['test_batch_size']

    def batches(iterable, size):
        # iter() lets this work on lists/arrays as well as generators
        # (re-iterating a list would yield its first chunk forever). The
        # generator ends with `return`: under PEP 479 a StopIteration raised
        # inside a generator is turned into RuntimeError.
        source = iter(iterable)
        while True:
            chunk = list(islice(source, size))
            if not chunk:
                return
            yield chunk

    # Construct batch indices; list() so append also works on Python 3.
    batch_index = list(range(0, nPatches, batch_size))
    if nPatches not in batch_index:
        batch_index.append(nPatches)

    # Process all_patches
    shape = np.hstack([nPatches, flags['size_output_patch']])
    shape[-1] = logits_test.get_shape()[3].value
    all_patches = np.zeros(shape, dtype=np.float32)

    for ipatch, chunk in enumerate(batches(patches, batch_size)):
        start_idx = batch_index[ipatch]
        end_idx = batch_index[ipatch + 1]

        temp = tf_model_input_test.inputs_test(chunk, mean_image,
                                               variance_image)

        # Pad a short final batch to the full batch size by repeating its
        # last patch; predictions for the padding are discarded below.
        if temp.shape[0] < batch_size:
            rep = np.tile(temp[-1, :, :, :],
                          [batch_size - temp.shape[0], 1, 1, 1])
            temp = np.vstack([temp, rep])

        pred = sess.run(logits_test,
                        feed_dict={
                            images_test: temp,
                            keep_prob: 1.0
                        })
        all_patches[start_idx:end_idx, :, :, :] = pred[:end_idx - start_idx]

    result = tf_model_input_test.MergePatches_test(
        all_patches, flags['stride_test'], image_size,
        flags['size_input_patch'], flags['size_output_patch'], flags)
    result = result * 255.0
    result = result.astype(np.uint8)
    result = KSimage.imresize(result, 4.0)
    result = np.argmax(result, axis=2)

    # resize may not preserve the original dimensions of the image
    # append with zero or remove excessive pixels in each dimension
    if result.shape[0] < ori_dim[0]:
        result = np.pad(result, ((0, ori_dim[0] - result.shape[0]), (0, 0)),
                        'constant',
                        constant_values=0)
    else:
        result = result[0:ori_dim[0], :]

    if result.shape[1] < ori_dim[1]:
        result = np.pad(result, ((0, 0), (0, ori_dim[1] - result.shape[1])),
                        'constant',
                        constant_values=0)
    else:
        result = result[:, 0:ori_dim[1]]

    mask = result == 1
    mask = mask.astype(np.uint8) * 255
    # findContours returns (image, contours, hierarchy) in OpenCV 3 but
    # (contours, hierarchy) in OpenCV 2 and 4; [-2] is the contours in both.
    contours = cv2.findContours(mask, cv2.RETR_TREE,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]

    # Keep only regions whose contour area exceeds 500**2.
    temp_mask = np.zeros(mask.shape[:2], dtype='uint8')
    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area > 500**2:
            cv2.drawContours(temp_mask, [cnt], -1, 255, -1)

    return temp_mask
Пример #6
0
def gen_train_val_data(nth_fold, flags):
    """
    gen_train_val_data generates training and validation patches for
    training the network.

    Steps:
    - Locate the fold folder (<experiment_folder>/cv<n> or perm<n>).
    - Create its train/ and val/ directories.
    - Only the 'sliding_window' method is supported. For each split:
      - Read the image/groundtruth/weight paths from
        <split>_{image,groundtruth,weight}_list.csv.
      - Resize every file from 10x to 2.5x and remap groundtruth labels
        (3 -> 2, then 4 -> 3).
      - Extract sliding-window patches and write them under
        <split>/{image,groundtruth,weight}/.
      - Record the written patch triples in <split>/<split>_log.csv.
    - A split whose log file already exists is skipped.

    param: nth_fold: fold number used to select the cv/perm folder
    param: flags: dict; reads 'experiment_folder', 'gen_train_val_method',
        'size_input_patch', 'size_output_patch', 'stride', 'image_ext',
        'groundtruth_ext', 'weight_ext'
    return: None
    raises: ValueError if both, or neither, of cv<n> / perm<n> exist
    """
    fold = str(nth_fold)

    ########## check whether 'cv' or 'perm' exists and which one to use ##########
    list_dir = os.listdir(flags['experiment_folder'])
    has_cv = ('cv' + fold) in list_dir
    has_perm = ('perm' + fold) in list_dir
    if has_cv and has_perm:
        raise ValueError('Dangerous! You have both cv and perm on the path.')
    elif has_cv:
        object_folder = os.path.join(flags['experiment_folder'], 'cv' + fold)
    elif has_perm:
        object_folder = os.path.join(flags['experiment_folder'],
                                     'perm' + fold)
    else:
        raise ValueError('No cv or perm folder!')

    ########## create train and val paths ##########
    path_dict = dict()
    path_dict['train_folder'] = os.path.join(object_folder, 'train')
    path_dict['val_folder'] = os.path.join(object_folder, 'val')
    create_dir(path_dict['train_folder'])
    create_dir(path_dict['val_folder'])

    if flags['gen_train_val_method'] != 'sliding_window':
        print(
            "ONLY SLIDING WINDOW TRAINING IS SUPPORTED!!!! Training terminated."
        )
        return

    ########## extract patches and put in a designated directory ##########
    key_list = ['image', 'groundtruth', 'weight']

    for key in key_list:
        for split in ['train', 'val']:
            folder_key = split + '_' + key + '_folder'
            path_dict[folder_key] = os.path.join(path_dict[split + '_folder'],
                                                 key)
            create_dir(path_dict[folder_key])

    list_dict = dict()
    for key in key_list:
        list_dict['train_' + key + '_list'] = KScsv.read_csv(
            os.path.join(object_folder, 'train_' + key + '_list.csv'))
        list_dict['val_' + key + '_list'] = KScsv.read_csv(
            os.path.join(object_folder, 'val_' + key + '_list.csv'))

    ########## generate patches for each split (train, then val) ##########
    for split in ['train', 'val']:
        log_path = os.path.join(path_dict[split + '_folder'],
                                split + '_log.csv')
        if os.path.isfile(log_path):
            continue  # this split already has a log; skip it

        image_list = list_dict[split + '_image_list']
        groundtruth_list = list_dict[split + '_groundtruth_list']
        weight_list = list_dict[split + '_weight_list']
        image_folder = path_dict[split + '_image_folder']
        groundtruth_folder = path_dict[split + '_groundtruth_folder']
        weight_folder = path_dict[split + '_weight_folder']
        log_data = list()

        for i_image, image_row in enumerate(image_list):
            tic = time.time()

            path_image = image_row[0]
            path_groundtruth = groundtruth_list[i_image][0]
            path_weight = weight_list[i_image][0]

            # Resize image, groundtruth, and weight from 10x input size to
            # 2.5x (level at which network operates).
            # NOTE(review): groundtruth is a label map; whether imresize
            # keeps labels intact depends on KSimage.imresize's
            # interpolation — confirm.
            image = KSimage.imresize(KSimage.imread(path_image), 0.25)
            groundtruth = KSimage.imresize(KSimage.imread(path_groundtruth),
                                           0.25)
            weight = KSimage.imresize(KSimage.imread(path_weight), 0.25)

            # Make sure groundtruth has depth 1 (channel index 1 is kept).
            if len(groundtruth.shape) > 2 and groundtruth.shape[2] > 1:
                groundtruth = groundtruth[:, :, 1]

            # Order matters: fold label 3 (intra-stromal epithelium) into
            # stroma (2) first, then move fat from 4 to the freed label 3.
            groundtruth[groundtruth == 3] = 2
            groundtruth[groundtruth == 4] = 3

            dict_obj = {
                'image': image,
                'groundtruth': groundtruth,
                'weight': weight
            }

            extractor = extract_patches.sliding_window(
                dict_obj, flags['size_input_patch'],
                flags['size_output_patch'], flags['stride'])

            # Loop-invariant: every patch name starts with the image stem.
            basename = os.path.splitext(os.path.basename(path_image))[0]

            for j, (out_obj_dict, coord_dict) in enumerate(extractor):
                coord_images = coord_dict['image']
                stem = (basename + '_idx' + str(j) + '_row' +
                        str(coord_images[0]) + '_col' + str(coord_images[1]))

                image_name = os.path.join(image_folder,
                                          stem + flags['image_ext'])
                label_name = os.path.join(groundtruth_folder,
                                          stem + flags['groundtruth_ext'])
                weight_name = os.path.join(weight_folder,
                                           stem + flags['weight_ext'])

                # Existing patch files are not overwritten.
                if not os.path.isfile(image_name):
                    KSimage.imwrite(out_obj_dict['image'], image_name)

                if not os.path.isfile(label_name):
                    KSimage.imwrite(out_obj_dict['groundtruth'], label_name)

                if not os.path.isfile(weight_name):
                    KSimage.imwrite(out_obj_dict['weight'], weight_name)

                log_data.append((image_name, label_name, weight_name))

            print('finish processing %d image from %d images : %.2f' %
                  (i_image + 1, len(image_list), time.time() - tic))

        KScsv.write_csv(log_data, log_path)
Пример #7
0
def batch_processing(filename, sess, logits_test, parameters, images_test,
                     keep_prob, mean_image, variance_image, flags,
                     he_dcis_segmentation_path):
    """
    batch_processing runs the network over the patches of an image,
    skipping batches whose segmentation-mask patches contain no 255 pixel,
    and returns a filled mask of the class-1 regions.

    Skipped batches are given a prediction of 1.0 in channel 0 and 0.0
    elsewhere. After merging, the map is resized by 2.0 and reduced with
    argmax, then padded/cropped to ori_dim. Class-1 pixels are passed
    through KSimage.imfill.

    param: filename: path of the image to process
    param: sess: TensorFlow session used to evaluate logits_test
    param: logits_test: network output tensor
    param: parameters: not used by this function; kept for compatibility
    param: images_test: input placeholder fed with each batch
    param: keep_prob: dropout placeholder, fed 1.0
    param: mean_image: passed to tf_model_input_test.inputs_test
    param: variance_image: passed to tf_model_input_test.inputs_test
    param: flags: dict; reads 'test_batch_size', 'size_output_patch',
        'stride_test', 'size_input_patch'
    param: he_dcis_segmentation_path: forwarded to read_data_test
    return: uint8 array of shape ori_dim[:2]
    """
    from itertools import islice

    # Read image and mask, and extract patches from both
    patches, patches_mask, image_size, nPatches, ori_dim = tf_model_input_test.read_data_test(
        filename, flags, he_dcis_segmentation_path)
    batch_size = flags['test_batch_size']

    def batches(iterable, size):
        # iter() lets this work on lists/arrays as well as generators. The
        # generator ends with `return`: under PEP 479 a StopIteration raised
        # inside a generator is turned into RuntimeError.
        source = iter(iterable)
        while True:
            chunk = list(islice(source, size))
            if not chunk:
                return
            yield chunk

    # Construct batch indices; list() so append also works on Python 3.
    batch_index = list(range(0, nPatches, batch_size))
    if nPatches not in batch_index:
        batch_index.append(nPatches)

    # Process all_patches
    n_channels = logits_test.get_shape()[3].value
    shape = np.hstack([nPatches, flags['size_output_patch']])
    shape[-1] = n_channels
    all_patches = np.zeros(shape, dtype=np.float32)

    for ipatch, (image_chunk, mask_chunk) in enumerate(
            zip(batches(patches, batch_size),
                batches(patches_mask, batch_size))):
        start_idx = batch_index[ipatch]
        end_idx = batch_index[ipatch + 1]

        # Run the network only if some mask patch has at least one 255 pixel.
        if any(np.any(m == 255.0) for m in mask_chunk):
            temp = tf_model_input_test.inputs_test(image_chunk, mean_image,
                                                   variance_image)

            # Pad a short final batch to the full batch size by repeating
            # its last patch; predictions for the padding are discarded.
            if temp.shape[0] < batch_size:
                rep = np.tile(temp[-1, :, :, :],
                              [batch_size - temp.shape[0], 1, 1, 1])
                temp = np.vstack([temp, rep])

            pred = sess.run(logits_test,
                            feed_dict={
                                images_test: temp,
                                keep_prob: 1.0
                            })
        else:
            # Skipped batch: 1.0 in channel 0, 0.0 elsewhere (presumably
            # resolves to label 0 after merging and argmax).
            pred_shape = np.hstack([batch_size, flags['size_output_patch']])
            pred_shape[-1] = n_channels
            pred = np.zeros(pred_shape, dtype=np.float32)
            pred[..., 0] = 1.0

        all_patches[start_idx:end_idx, :, :, :] = pred[:end_idx - start_idx]

    result = tf_model_input_test.MergePatches_test(
        all_patches, flags['stride_test'], image_size,
        flags['size_input_patch'], flags['size_output_patch'], flags)

    result = result * 255.0
    result = result.astype(np.uint8)
    result = KSimage.imresize(result, 2.0)
    result = np.argmax(result, axis=2)

    # resize may not preserve the original dimensions of the image
    # append with zero or remove excessive pixels in each dimension
    if result.shape[0] < ori_dim[0]:
        result = np.pad(result, ((0, ori_dim[0] - result.shape[0]), (0, 0)),
                        'constant',
                        constant_values=0)
    else:
        result = result[0:ori_dim[0], :]

    if result.shape[1] < ori_dim[1]:
        result = np.pad(result, ((0, 0), (0, ori_dim[1] - result.shape[1])),
                        'constant',
                        constant_values=0)
    else:
        result = result[:, 0:ori_dim[1]]

    im_in = result == 1
    im_in = im_in * 255.0
    im_in = im_in.astype(np.uint8)

    # NOTE(review): the *255 below assumes KSimage.imfill returns a 0/1
    # (or boolean) image; if it returns 0/255 the uint8 cast would wrap.
    im_out = KSimage.imfill(im_in)
    im_out = im_out * 255.0
    im_out = im_out.astype(np.uint8)

    return im_out
Пример #8
0
def batch_processing(filename, sess, logits_test, parameters, images_test, keep_prob, mean_image, variance_image, flags):
    """
    batch_processing reads in an image, splits it up into patches,
    preprocesses the patches, passes them through the network, and returns
    the results.

    A batch is sent through the network only if at least one of its mask
    patches has more than 50% of its pixels equal to 255. Other batches get
    a prediction of 1.0 in channel 0 and 0.0 elsewhere.

    param: filename: path of the image to process
    param: sess: TensorFlow session used to evaluate logits_test
    param: logits_test: network output tensor
    param: parameters: not used by this function; kept for compatibility
    param: images_test: input placeholder fed with each batch
    param: keep_prob: dropout placeholder, fed 1.0
    param: mean_image: passed to tf_model_input_test.inputs_test
    param: variance_image: passed to tf_model_input_test.inputs_test
    param: flags: dict; reads 'test_batch_size', 'size_output_patch',
        'stride_test', 'size_input_patch'
    return: result: argmax label map of the merged predictions, resized by 4.0
    """
    from itertools import islice

    # Read image and extract patches
    patches, patches_mask, image_size, nPatches = tf_model_input_test.read_data_test(filename, flags)
    batch_size = flags['test_batch_size']

    def batches(iterable, size):
        # iter() lets this work on lists/arrays as well as generators. The
        # generator ends with `return`: under PEP 479 a StopIteration raised
        # inside a generator is turned into RuntimeError.
        source = iter(iterable)
        while True:
            chunk = list(islice(source, size))
            if not chunk:
                return
            yield chunk

    # Construct batch indices; list() so append also works on Python 3.
    batch_index = list(range(0, nPatches, batch_size))
    if nPatches not in batch_index:
        batch_index.append(nPatches)

    # Process all_patches
    n_channels = logits_test.get_shape()[3].value
    shape = np.hstack([nPatches, flags['size_output_patch']])
    shape[-1] = n_channels
    all_patches = np.zeros(shape, dtype=np.float32)

    for ipatch, (image_chunk, mask_chunk) in enumerate(zip(batches(patches, batch_size),
                                                           batches(patches_mask, batch_size))):
        start_time = time.time()
        start_idx = batch_index[ipatch]
        end_idx = batch_index[ipatch + 1]

        # Fraction of 255-valued (uncovered by mask) pixels in each mask patch.
        coverage = [np.sum(m == 255.0) / float(m.size) for m in mask_chunk]

        # Process the batch if any patch within it is more than 50% uncovered.
        if np.any(np.array(coverage) > 0.5):
            temp = tf_model_input_test.inputs_test(image_chunk, mean_image, variance_image)

            # Pad a short final batch to the full batch size by repeating its
            # last patch; predictions for the padding are discarded below.
            if temp.shape[0] < batch_size:
                rep = np.tile(temp[-1, :, :, :], [batch_size - temp.shape[0], 1, 1, 1])
                temp = np.vstack([temp, rep])

            pred = sess.run(logits_test, feed_dict={images_test: temp, keep_prob: 1.0})

        else:
            pred_shape = np.hstack([batch_size, flags['size_output_patch']])
            pred_shape[-1] = n_channels
            pred = np.zeros(pred_shape, dtype=np.float32)
            pred[..., 0] = 1.0

        all_patches[start_idx:end_idx, :, :, :] = pred[:end_idx - start_idx]

        duration = time.time() - start_time
        print('processing step %d/%d (%.2f sec/step)' % (ipatch + 1, len(batch_index) - 1, duration))

    # Merge patch predictions into one map. NOTE: per-patch certainty is not
    # retained after this point.
    result = tf_model_input_test.MergePatches_test(all_patches, flags['stride_test'], image_size, flags['size_input_patch'], flags['size_output_patch'], flags)

    # np.squeeze matches the previous tf.squeeze(...).eval(), but does not add
    # a new op to the TF graph on every call or require a default session.
    result = np.squeeze(result)

    result = result * 255.0
    result = result.astype(np.uint8)
    result = np.argmax(result, axis=2)
    # NOTE(review): this resizes the integer label map after argmax, whereas
    # the other batch_processing variants resize the probability map before
    # argmax; correctness depends on KSimage.imresize's interpolation.
    result = KSimage.imresize(result, 4.0)

    return result