def score_data(input_folder, output_folder, model_path, exp_config, do_postprocessing=False, recursion=None,
               data_path='/scratch/cany/scribble/scribble_data/prostate_divided.h5'):
    """Segment every slice of an HDF5 dataset, save NIfTI outputs and print per-slice Dice scores.

    Each ``images[k]`` returned by ``read_data(data_path)`` is fed through the network, the
    arg-max of the softmax output is optionally post-processed, and the prediction, the
    ground truth and the input image are written to
    ``<output_folder>/{prediction,ground_truth,image}/patient<k>.nii.gz`` (the folders are
    not created here). A Dice score is printed per slice and their mean at the end.

    :param input_folder: Unused; kept for interface compatibility. Data is read from ``data_path``.
    :param output_folder: Root folder receiving the NIfTI outputs.
    :param model_path: Folder holding the checkpoints to restore.
    :param exp_config: Experiment config; ``image_size``, ``nlabels`` and ``model_handle`` are read.
    :param do_postprocessing: If True, keep only the largest connected components of each prediction.
    :param recursion: If None, restore ``model_best_dice.ckpt``; otherwise restore
        ``recursion_<n>_model_best_dice.ckpt`` and fall back to ``recursion_<n>_model.ckpt``.
    :param data_path: HDF5 file passed to ``read_data`` (defaults to the previously hard-coded path).
    :return: Iteration number parsed from the restored checkpoint file name.
    """
    print("KOD YENİ")
    dices = []
    images, labels = read_data(data_path)
    num_images = images.shape[0]
    print(str(num_images))
    print(str(images.shape))
    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    # NOTE(review): Dice is computed with a fixed depth of 4 classes independently of
    # exp_config.nlabels; kept to preserve the reported numbers -- confirm this is intended.
    dice_depth = 4

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

    mask_pl, softmax_pl, _ = model.predict_logits(images_pl, exp_config.model_handle, exp_config.nlabels)

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    # Build the Dice op ONCE and feed it per slice. Creating get_dice/tf.one_hot inside the
    # loop (as before) added new nodes to the graph on every iteration, so the graph grew
    # without bound and every sess.run became slower.
    prediction_pl = tf.placeholder(tf.uint8, shape=[nx, ny], name='dice_prediction')
    gt_pl = tf.placeholder(tf.uint8, shape=[nx, ny], name='dice_ground_truth')
    dice_op = get_dice(tf.one_hot(prediction_pl, depth=dice_depth), gt_pl, dice_depth)

    with tf.Session() as sess:

        sess.run(init)
        if recursion is None:
            checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
        else:
            try:
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path, 'recursion_{}_model_best_dice.ckpt'.format(recursion))
            except Exception:
                # No "best dice" checkpoint for this recursion: use the plain one instead.
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path, 'recursion_{}_model.ckpt'.format(recursion))

        saver.restore(sess, checkpoint_path)

        # The iteration is the suffix after the last '-' of the checkpoint file name.
        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        for k in range(num_images):
            network_input = np.expand_dims(np.expand_dims(images[k, :, :], 2), 0)
            softmax_out = sess.run(softmax_pl, feed_dict={images_pl: network_input})
            prediction_cropped = np.squeeze(softmax_out[0, ...])
            prediction_arr = np.uint8(np.argmax(prediction_cropped, axis=-1))

            mask = labels[k, :, :]

            if do_postprocessing:
                print("Entered post processing " + str(True))
                prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)

            # Save predicted mask
            out_file_name = os.path.join(output_folder, 'prediction', 'patient' + str(k) + '.nii.gz')
            logging.info('saving to: %s' % out_file_name)
            save_nii(out_file_name, prediction_arr)

            # Save GT image
            gt_file_name = os.path.join(output_folder, 'ground_truth', 'patient' + str(k) + '.nii.gz')
            logging.info('saving to: %s' % gt_file_name)
            save_nii(gt_file_name, np.uint8(mask))

            # Save image data to the same folder for convenience
            image_file_name = os.path.join(output_folder, 'image', 'patient' + str(k) + '.nii.gz')
            logging.info('saving to: %s' % image_file_name)
            save_nii(image_file_name, images[k, :, :])

            cdice = sess.run(dice_op, feed_dict={prediction_pl: np.uint8(prediction_arr),
                                                 gt_pl: np.uint8(np.squeeze(mask))})
            print(str(cdice))
            dices.append(cdice)

    print("Average Dice : " + str(np.mean(dices)))
    return init_iteration
# ---- Example #2 ----
def postprocess(mask_out, images_train, scribbles_train=None):
    '''
    Post-process CNN predictions to create the ground truths for the next recursion.

    The steps run in this order, each enabled by a flag on the module-level ``exp_config``:

    - ``rw_intersection``: run ``random_walker.segment`` on the scribbles and keep a pixel's
      label only where the random-walker result and the prediction agree on it.
    - ``rw_reversion``: per image and scribble label, if the prediction covers fewer pixels
      than the scribble, use the scribble pixels instead of the predicted ones. Labels that
      do not occur in the scribbles are dropped (set to 0).
    - ``keep_largest_cluster``: keep only the largest connected components of each image.
    - ``smooth_edges``: Gaussian-blur each label (except the highest scribble label) and
      threshold it back to a binary region.

    :param mask_out: Numpy array of predicted label maps, indexed as ``mask_out[img_id, ...]``.
    :param images_train: Numpy array of training images (only used by the random walker).
    :param scribbles_train: Numpy array of weakly annotated label maps with the same leading
        dimension as ``mask_out``; 0 marks unlabelled pixels. Needed by every step except
        ``keep_largest_cluster``.
    :return: The post-processed label maps as a new array (the input array is not modified).
    '''
    # Work on a copy: the later steps write into mask_out in place.
    mask_out = np.array(mask_out, copy=True)
    # Iterate over the images actually present rather than exp_config.batch_size, which
    # previously zeroed (rw_reversion) or crashed on batches of a different size.
    num_images = mask_out.shape[0]

    # Labels present in the scribbles (0 = unlabelled).
    labels = np.unique(scribbles_train)
    labels = labels[labels != 0]

    # Use the full random-walker segmentation as an upper bound.
    if exp_config.rw_intersection:
        rw_segmentation = random_walker.segment(images_train,
                                                scribbles_train,
                                                threshold=0,
                                                beta=exp_config.rw_beta)
        mask = mask_out
        mask_out = np.zeros_like(mask)
        for label in labels:
            mask_out[(rw_segmentation == label) & (mask == label)] = label

    # Revert to the original scribble for 'bad' predictions.
    if exp_config.rw_reversion:
        mask = mask_out
        mask_out = np.zeros_like(mask)
        for img_id in range(num_images):
            for label in labels:
                predicted = mask[img_id, ...] == label
                scribbled = scribbles_train[img_id, ...] == label
                # Fewer predicted than scribbled pixels: fall back to the scribble.
                if np.sum(predicted) < np.sum(scribbled):
                    mask_out[img_id, scribbled] = label
                else:
                    mask_out[img_id, predicted] = label

    # Keep only the largest cluster for the output.
    if exp_config.keep_largest_cluster:
        for img_id in range(num_images):
            mask_out[img_id, ...] = image_utils.keep_largest_connected_components(
                np.squeeze(mask_out[img_id, ...]))

    if exp_config.smooth_edges:
        # NOTE(review): the highest scribble label is excluded from smoothing and therefore
        # does not appear in the smoothed output -- confirm this is intended.
        labels = labels[labels != np.max(labels)]
        for img_id in range(num_images):
            mask = mask_out[img_id, ...]
            new_mask = np.zeros_like(mask)
            for label in labels:
                # np.float64 replaces the np.float alias removed in NumPy 1.24 (same dtype).
                struct = (mask == label).astype(np.float64)
                blurred_struct = gaussian_filter(
                    struct, sigma=exp_config.edge_smoother_sigma)
                blurred_struct[
                    blurred_struct >= exp_config.edge_smoother_threshold] = 1
                blurred_struct[
                    blurred_struct < exp_config.edge_smoother_threshold] = 0
                new_mask[blurred_struct != 0] = label
            mask_out[img_id, ...] = new_mask

    return mask_out
# ---- Example #3 ----
def main(exp_config, batch_size=3, slices=None):
    """Export per-recursion predictions for a few test slices as poster images.

    For every recursion checkpoint that can be restored from the module-level ``model_path``,
    the network predicts the selected ``images_test`` slices. The largest connected
    components are kept, and each prediction is then passed through ``segment`` together
    with the images. The image, ground truth and each prediction are written via
    ``image_utils.print_grayscale`` / ``print_coloured`` into ``base_path + "/poster/"``.

    :param exp_config: Experiment config (``scribble_data``, ``image_size``, ``model_handle``,
        ``nlabels``, ``rw_beta``, ``experiment_name`` are read).
    :param batch_size: Kept for backward compatibility; ignored. The number of slices is
        ``len(slices)``.
    :param slices: Indices into ``images_test`` / ``masks_test`` to export. h5py fancy
        indexing requires them to be unique and increasing. Defaults to ``[80, 275, 370]``,
        the selection previously hard-coded here.
    """
    if slices is None:
        slices = [80, 275, 370]
    slices = list(slices)
    batch_size = len(slices)

    # Read only the selected slices and close the HDF5 file right away.
    with h5py.File(sys_config.project_root + exp_config.scribble_data, 'r') as data:
        images = data['images_test'][slices, ...]
        masks = data['masks_test'][slices, ...]

    num_recursions = most_recent_recursion(model_path)

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    mask_pl, softmax_pl = model.predict(images_pl, exp_config.model_handle,
                                        exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    # One label map per (slice, recursion); only the first `pred_size` entries get filled.
    predictions = np.zeros([batch_size] + list(exp_config.image_size) +
                           [num_recursions + 1])

    feed_dict = {
        images_pl: np.expand_dims(images, -1),
    }
    with tf.Session() as sess:
        sess.run(init)
        pred_size = 0
        for recursion in range(num_recursions + 1):
            try:
                try:
                    checkpoint_path = utils.get_latest_model_checkpoint_path(
                        model_path,
                        'recursion_{}_model_best_dice.ckpt'.format(recursion))
                except Exception:
                    checkpoint_path = utils.get_latest_model_checkpoint_path(
                        model_path,
                        'recursion_{}_model.ckpt'.format(recursion))
                saver.restore(sess, checkpoint_path)
                mask_out, _ = sess.run([mask_pl, softmax_pl],
                                       feed_dict=feed_dict)
                for idx in range(batch_size):
                    predictions[idx, ..., pred_size] = \
                        image_utils.keep_largest_connected_components(
                            np.squeeze(mask_out[idx, ...]))
                print("Classified for recursion {}".format(recursion))
                pred_size += 1
            except Exception as e:
                # Best effort: a recursion without a usable checkpoint is reported and skipped.
                print(e)

    num_recursions = pred_size
    path = base_path + "/poster/"

    # '#RW' step: refine each recursion's prediction with `segment` (presumably the random walker).
    for recursion in range(num_recursions):
        predictions[..., recursion] = segment(images,
                                              np.squeeze(predictions[..., recursion]),
                                              beta=exp_config.rw_beta,
                                              threshold=0)

    for r in range(batch_size):
        image_utils.print_grayscale(
            np.squeeze(images[r, ...]), path,
            '{}_{}_image'.format(exp_config.experiment_name, slices[r]))
        image_utils.print_coloured(
            np.squeeze(masks[r, ...]), path,
            '{}_{}_gt'.format(exp_config.experiment_name, slices[r]))
        for recursion in range(num_recursions):
            image_utils.print_coloured(
                np.squeeze(predictions[r, ..., recursion]), path,
                '{}_{}_pred_r{}'.format(exp_config.experiment_name, slices[r],
                                        recursion))

    # Show once. The previous `while True` loop re-opened the window every time it was
    # closed, so this function never returned.
    plt.axis('off')
    plt.show()
def _acdc_crop_or_pad_slice(slice_img, nx, ny):
    """Centre-crop and/or zero-pad a 2D slice so that it has shape (nx, ny)."""
    x, y = slice_img.shape
    x_s = (x - nx) // 2
    y_s = (y - ny) // 2
    x_c = (nx - x) // 2
    y_c = (ny - y) // 2

    if x > nx and y > ny:
        return slice_img[x_s:x_s + nx, y_s:y_s + ny]

    cropped = np.zeros((nx, ny))
    if x <= nx and y > ny:
        cropped[x_c:x_c + x, :] = slice_img[:, y_s:y_s + ny]
    elif x > nx and y <= ny:
        cropped[:, y_c:y_c + y] = slice_img[x_s:x_s + nx, :]
    else:
        cropped[x_c:x_c + x, y_c:y_c + y] = slice_img[:, :]
    return cropped


def _acdc_restore_cropped_prediction(prediction_cropped, out_shape, nx, ny):
    """Inverse of `_acdc_crop_or_pad_slice` on the first two axes.

    Places a prediction whose first two axes are (nx, ny) into a zero array of ``out_shape``,
    whose first two axes are the pre-crop (x, y). Trailing axes (slices, classes) are copied as-is.
    """
    x, y = out_shape[0], out_shape[1]
    x_s = (x - nx) // 2
    y_s = (y - ny) // 2
    x_c = (nx - x) // 2
    y_c = (ny - y) // 2

    restored = np.zeros(out_shape)
    if x > nx and y > ny:
        restored[x_s:x_s + nx, y_s:y_s + ny, ...] = prediction_cropped
    elif x <= nx and y > ny:
        restored[:, y_s:y_s + ny, ...] = prediction_cropped[x_c:x_c + x, :, ...]
    elif x > nx and y <= ny:
        restored[x_s:x_s + nx, :, ...] = prediction_cropped[:, y_c:y_c + y, ...]
    else:
        restored[...] = prediction_cropped[x_c:x_c + x, y_c:y_c + y, ...]
    return restored


def _acdc_read_info_cfg(folder_path):
    """Parse ``<folder_path>/Info.cfg`` ('key: value' lines) into a dict of strings."""
    infos = {}
    with open(os.path.join(folder_path, 'Info.cfg')) as cfg_file:
        for line in cfg_file:
            label, value = line.split(':')
            infos[label] = value.rstrip('\n').lstrip(' ')
    return infos


def score_data(input_folder,
               output_folder,
               model_path,
               exp_config,
               do_postprocessing=False,
               gt_exists=True,
               evaluate_all=False,
               use_iter=None):
    """Segment the patient volumes in ``input_folder`` and write NIfTI outputs.

    For each selected sub-folder, every ``patient???_frame??.nii.gz`` volume goes through
    these steps:
    - normalised;
    - rescaled to ``exp_config.target_resolution``;
    - centre-cropped/zero-padded to ``exp_config.image_size``;
    - classified slice by slice ('2D' data_mode) or as one volume ('3D');
    - mapped back to the original geometry;
    - saved to ``<output_folder>/prediction/patient<id>_ED|_ES.nii.gz``.

    The input image is copied to ``<output_folder>/image``. With ground truth, the GT and a
    binary difference mask go to ``ground_truth`` and ``difference``.

    :param input_folder: Folder with one sub-folder per patient, each holding ``Info.cfg``.
    :param output_folder: Root of the output folders (not created here).
    :param model_path: Folder holding the checkpoints.
    :param exp_config: Experiment config (image_size, nlabels, data_mode, target_resolution, ...).
    :param do_postprocessing: Keep only the largest connected components of each prediction.
    :param gt_exists: If True, ``<frame>_gt.nii.gz`` masks are loaded and only patients whose
        folder number is divisible by 5 are evaluated. If False, all patients are evaluated.
    :param evaluate_all: Evaluate all patients even when ground truth exists.
    :param use_iter: If falsy, restore the latest ``model_best_dice.ckpt``; otherwise
        ``model.ckpt-<use_iter>``.
    :return: Iteration number parsed from the restored checkpoint name.
    :raises ValueError: if a frame is neither ED nor ES, the 3D slice count does not match,
        or ``exp_config.data_mode`` is neither '2D' nor '3D'.
    """
    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = exp_config.nlabels

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')

    mask_pl, softmax_pl = model.predict(images_pl, exp_config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    evaluate_test_set = not gt_exists

    with tf.Session() as sess:

        sess.run(init)

        if not use_iter:
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'model_best_dice.ckpt')
        else:
            checkpoint_path = os.path.join(model_path,
                                           'model.ckpt-%d' % use_iter)

        saver.restore(sess, checkpoint_path)

        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0

        for folder in os.listdir(input_folder):

            folder_path = os.path.join(input_folder, folder)
            if not os.path.isdir(folder_path):
                continue

            # Without ground truth (or with evaluate_all) every patient is evaluated;
            # otherwise every 5th patient (last three digits of the folder name) is.
            is_test = (evaluate_test_set or evaluate_all
                       or int(folder[-3:]) % 5 == 0)
            if not is_test:
                continue

            infos = _acdc_read_info_cfg(folder_path)

            # NOTE: lstrip removes the *characters* of 'patient', which is fine for numeric IDs.
            patient_id = folder.lstrip('patient')
            ED_frame = int(infos['ED'])
            ES_frame = int(infos['ES'])

            for file in glob.glob(
                    os.path.join(folder_path, 'patient???_frame??.nii.gz')):

                logging.info(' ----- Doing image: -------------------------')
                logging.info('Doing: %s' % file)
                logging.info(' --------------------------------------------')

                file_base = file.split('.nii.gz')[0]

                frame = int(file_base.split('frame')[-1])
                img_dat = utils.load_nii(file)
                img = img_dat[0].copy()
                img = image_utils.normalise_image(img)

                if gt_exists:
                    file_mask = file_base + '_gt.nii.gz'
                    mask_dat = utils.load_nii(file_mask)
                    mask = mask_dat[0]

                start_time = time.time()

                if exp_config.data_mode == '2D':

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] / exp_config.target_resolution[0],
                                    pixel_size[1] / exp_config.target_resolution[1])

                    predictions = []

                    for zz in range(img.shape[2]):

                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(slice_img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')
                        x, y = slice_rescaled.shape

                        # Crop section of image for prediction
                        slice_cropped = _acdc_crop_or_pad_slice(slice_rescaled, nx, ny)

                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (batch_size, 1, 1, 1)))
                        softmax_out = sess.run(softmax_pl,
                                               feed_dict={images_pl: network_input})
                        prediction_cropped = np.squeeze(softmax_out[0, ...])

                        # ASSEMBLE BACK THE SLICES: insert the cropped region into the
                        # rescaled-image geometry again.
                        slice_predictions = _acdc_restore_cropped_prediction(
                            prediction_cropped, (x, y, num_channels), nx, ny)

                        # RESCALING ON THE LOGITS
                        if gt_exists:
                            prediction = transform.resize(
                                slice_predictions,
                                (mask.shape[0], mask.shape[1], num_channels),
                                order=1,
                                preserve_range=True,
                                mode='constant')
                        else:
                            # This can occasionally lead to a wrong volume size, therefore
                            # the gt mask size is used for resizing when it exists.
                            prediction = transform.rescale(
                                slice_predictions,
                                (1.0 / scale_vector[0], 1.0 / scale_vector[1], 1),
                                order=1,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant')

                        prediction = np.uint8(np.argmax(prediction, axis=-1))
                        predictions.append(prediction)

                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                elif exp_config.data_mode == '3D':

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2],
                                  img_dat[2].structarr['pixdim'][3])

                    scale_vector = (pixel_size[0] / exp_config.target_resolution[0],
                                    pixel_size[1] / exp_config.target_resolution[1],
                                    pixel_size[2] / exp_config.target_resolution[2])

                    vol_scaled = transform.rescale(img,
                                                   scale_vector,
                                                   order=1,
                                                   preserve_range=True,
                                                   multichannel=False,
                                                   mode='constant')

                    nz_max = exp_config.image_size[2]
                    slice_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)

                    x, y, nz_curr = vol_scaled.shape
                    # Centre the nz_curr real slices inside the nz_max-deep network input.
                    stack_from = (nz_max - nz_curr) // 2
                    stack_to = stack_from + nz_curr

                    for zz in range(nz_curr):
                        slice_vol[:, :, stack_from + zz] = _acdc_crop_or_pad_slice(
                            vol_scaled[:, :, zz], nx, ny)

                    network_input = np.float32(
                        np.reshape(slice_vol, (1, nx, ny, nz_max, 1)))

                    # Separate timer so the per-volume time below still covers the whole volume.
                    classify_start = time.time()
                    softmax_out = sess.run(softmax_pl,
                                           feed_dict={images_pl: network_input})

                    logging.info('Classified 3D: %f secs' %
                                 (time.time() - classify_start))

                    prediction_nzs = softmax_out[0, :, :, stack_from:stack_to, ...]  # non-zero-slices

                    if not prediction_nzs.shape[2] == nz_curr:
                        raise ValueError('sizes mismatch')

                    # ASSEMBLE BACK THE SLICES (last dim is for the logit classes).
                    # The previous code used `:...` here, i.e. slice(None, Ellipsis), which
                    # NumPy rejects; the helper uses proper `:, ...` indexing.
                    prediction_scaled = _acdc_restore_cropped_prediction(
                        prediction_nzs, tuple(vol_scaled.shape) + (num_channels,), nx, ny)

                    logging.info('Prediction_scaled mean %f' %
                                 (np.mean(prediction_scaled)))

                    # Without ground truth `mask` is undefined; the input image has the
                    # original geometry to resize back to.
                    reference_shape = mask.shape if gt_exists else img.shape
                    prediction = transform.resize(
                        prediction_scaled,
                        tuple(reference_shape[:3]) + (num_channels,),
                        order=1,
                        preserve_range=True,
                        mode='constant')
                    prediction = np.argmax(prediction, axis=-1)
                    prediction_arr = np.asarray(prediction, dtype=np.uint8)

                else:
                    raise ValueError('Unknown data_mode: %s' % exp_config.data_mode)

                # This is the same for 2D and 3D again
                if do_postprocessing:
                    prediction_arr = image_utils.keep_largest_connected_components(
                        prediction_arr)

                elapsed_time = time.time() - start_time
                total_time += elapsed_time
                total_volumes += 1

                logging.info('Evaluation of volume took %f secs.' % elapsed_time)

                if frame == ED_frame:
                    frame_suffix = '_ED'
                elif frame == ES_frame:
                    frame_suffix = '_ES'
                else:
                    raise ValueError(
                        'Frame doesnt correspond to ED or ES. frame = %d, ED = %d, ES = %d'
                        % (frame, ED_frame, ES_frame))

                # Save prediced mask
                out_file_name = os.path.join(
                    output_folder, 'prediction',
                    'patient' + patient_id + frame_suffix + '.nii.gz')
                if gt_exists:
                    out_affine = mask_dat[1]
                    out_header = mask_dat[2]
                else:
                    out_affine = img_dat[1]
                    out_header = img_dat[2]

                logging.info('saving to: %s' % out_file_name)
                utils.save_nii(out_file_name, prediction_arr, out_affine, out_header)

                # Save image data to the same folder for convenience
                image_file_name = os.path.join(
                    output_folder, 'image',
                    'patient' + patient_id + frame_suffix + '.nii.gz')
                logging.info('saving to: %s' % image_file_name)
                utils.save_nii(image_file_name, img_dat[0], out_affine, out_header)

                if gt_exists:

                    # Save GT image
                    gt_file_name = os.path.join(
                        output_folder, 'ground_truth',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % gt_file_name)
                    utils.save_nii(gt_file_name, mask, out_affine, out_header)

                    # Save difference mask between predictions and ground truth
                    difference_mask = np.where(
                        np.abs(prediction_arr - mask) > 0, [1], [0])
                    difference_mask = np.asarray(difference_mask, dtype=np.uint8)
                    diff_file_name = os.path.join(
                        output_folder, 'difference',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % diff_file_name)
                    utils.save_nii(diff_file_name, difference_mask, out_affine, out_header)

        if total_volumes:
            logging.info('Average time per volume: %f' %
                         (total_time / total_volumes))
        else:
            # Previously a ZeroDivisionError when nothing matched.
            logging.warning('No volumes were evaluated in %s' % input_folder)

    return init_iteration
# ---- Example #5 ----
def score_data(input_folder, output_folder, model_path, config, do_postprocessing=False, gt_exists=True):
    """Run a trained 2D segmentation network over folders of slice images and save PNG predictions.

    Reads ``input_folder/img/<patient>/<phase>/*`` slice images (and, when
    ``gt_exists``, ``input_folder/mask/<patient>/<phase>/*`` masks), stacks each
    phase folder into an (x, y, N) array, rescales every slice by
    ``config.pixel_size / target_resolution``, crops/pads it to
    ``config.image_size[:2]``, runs the network and writes the per-slice argmax
    prediction with ``cv2.imwrite``.

    Args:
        input_folder: root folder containing ``img`` (and ``mask`` if gt_exists).
        output_folder: not referenced in the visible code.
            NOTE(review): predictions are written under ``path_pred`` instead.
        model_path: directory searched for the latest ``model_best_dice.ckpt``.
        config: experiment config; reads image_size, nlabels, pixel_size,
            standardize, normalize, min, max and data_mode.
        do_postprocessing: if True, keep only the largest connected components
            of each predicted volume.
        gt_exists: whether ground-truth masks are available.

    NOTE(review): this example is truncated in the source. The final
    ``if gt_exists:`` has no body and nothing is returned.
    """

    nx, ny = config.image_size[:2]
    batch_size = 1
    num_channels = config.nlabels

    image_tensor_shape = [batch_size] + list(config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

    # According to the experiment config, pick a model and predict the output
    mask_pl, softmax_pl = model.predict(images_pl, config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    with tf.Session() as sess:

        sess.run(init)

        checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)

        # The iteration number is the part of the checkpoint file name after the last '-'.
        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0
        # NOTE(review): `target_resolution` is not defined in this function;
        # presumably a module-level global (the other variant uses
        # config.target_resolution). Confirm.
        scale_vector = [config.pixel_size[0] / target_resolution[0], config.pixel_size[1] / target_resolution[1]]

        path_img = os.path.join(input_folder, 'img')
        if gt_exists:
            path_mask = os.path.join(input_folder, 'mask')

        for folder in os.listdir(path_img):

            logging.info(' ----- Doing image: -------------------------')
            logging.info('Doing: %s' % folder)
            logging.info(' --------------------------------------------')
            folder_path = os.path.join(path_img, folder)   # loop over patient folders

            # NOTE(review): `path_pred` is not defined in this function;
            # presumably a module-level global. Confirm. The folder is also
            # created before the isdir() check below.
            utils.makefolder(os.path.join(path_pred, folder))

            if os.path.isdir(folder_path):

                for phase in os.listdir(folder_path):    # loop over the ED / ES phase folders

                    save_path = os.path.join(path_pred, folder, phase)
                    utils.makefolder(save_path)

                    predictions = []
                    mask_arr = []
                    img_arr = []
                    masks = []
                    imgs = []
                    path = os.path.join(folder_path, phase)
                    # Read every slice of this phase, optionally standardised and/or
                    # min-max normalised to [config.min, config.max].
                    # NOTE(review): os.listdir order is arbitrary and is not sorted
                    # here, so slice order is not guaranteed. Confirm this is intended.
                    for file in os.listdir(path):
                        img = plt.imread(os.path.join(path,file))
                        if config.standardize:
                            img = image_utils.standardize_image(img)
                        if config.normalize:
                            img = cv2.normalize(img, dst=None, alpha=config.min, beta=config.max, norm_type=cv2.NORM_MINMAX)
                        img_arr.append(img)
                    if  gt_exists:
                        for file in os.listdir(os.path.join(path_mask,folder,phase)):
                            mask_arr.append(plt.imread(os.path.join(path_mask,folder,phase,file)))

                    # Stack the slices along the last axis.
                    img_arr = np.transpose(np.asarray(img_arr),(1,2,0))      # x,y,N
                    if  gt_exists:
                        mask_arr = np.transpose(np.asarray(mask_arr),(1,2,0))

                    start_time = time.time()

                    if config.data_mode == '2D':

                        for zz in range(img_arr.shape[2]):

                            # Rescale the slice to the target resolution (bilinear).
                            slice_img = np.squeeze(img_arr[:,:,zz])
                            slice_rescaled = transform.rescale(slice_img,
                                                               scale_vector,
                                                               order=1,
                                                               preserve_range=True,
                                                               multichannel=False,
                                                               anti_aliasing=True,
                                                               mode='constant')

                            # NOTE(review): this runs even when gt_exists is False, where
                            # mask_arr is still an empty list and the indexing fails.
                            slice_mask = np.squeeze(mask_arr[:, :, zz])
                            slice_cropped = read_data.crop_or_pad_slice_to_size(slice_rescaled, nx, ny)
                            slice_cropped = np.float32(slice_cropped)
                            x = image_utils.reshape_2Dimage_to_tensor(slice_cropped)
                            imgs.append(np.squeeze(x))
                            if gt_exists:
                                # Nearest-neighbour (order=0) rescale for the label mask.
                                mask_rescaled = transform.rescale(slice_mask,
                                                                  scale_vector,
                                                                  order=0,
                                                                  preserve_range=True,
                                                                  multichannel=False,
                                                                  anti_aliasing=True,
                                                                  mode='constant')

                                mask_cropped = read_data.crop_or_pad_slice_to_size(mask_rescaled, nx, ny)
                                mask_cropped = np.asarray(mask_cropped, dtype=np.uint8)
                                y = image_utils.reshape_2Dimage_to_tensor(mask_cropped)
                                masks.append(np.squeeze(y))

                            # GET PREDICTION
                            feed_dict = {
                            images_pl: x,
                            }

                            mask_out, logits_out = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)

                            prediction_cropped = np.squeeze(logits_out[0,...])

                            # ASSEMBLE BACK THE SLICES
                            # (the zeros allocation is immediately overwritten by the next line)
                            slice_predictions = np.zeros((nx,ny,num_channels))
                            slice_predictions = prediction_cropped
                            # RESCALING ON THE LOGITS
                            # With GT: stay in the cropped (nx, ny) frame so predictions align
                            # with the cropped masks. Without GT: undo the resolution rescale.
                            if gt_exists:
                                prediction = transform.resize(slice_predictions,
                                                              (nx, ny, num_channels),
                                                              order=1,
                                                              preserve_range=True,
                                                              anti_aliasing=True,
                                                              mode='constant')
                            else:
                                prediction = transform.rescale(slice_predictions,
                                                               (1.0/scale_vector[0], 1.0/scale_vector[1], 1),
                                                               order=1,
                                                               preserve_range=True,
                                                               multichannel=False,
                                                               anti_aliasing=True,
                                                               mode='constant')

                            # Class map = argmax over the last (class) axis.
                            prediction = np.uint8(np.argmax(prediction, axis=-1))

                            predictions.append(prediction)


                        # NOTE(review): `masks` is an empty list when gt_exists is False,
                        # so this transpose raises in that case.
                        predictions = np.transpose(np.asarray(predictions, dtype=np.uint8), (1,2,0))
                        masks = np.transpose(np.asarray(masks, dtype=np.uint8), (1,2,0))
                        imgs = np.transpose(np.asarray(imgs, dtype=np.float32), (1,2,0))


                    # This is the same for 2D and 3D
                    # NOTE(review): there is no non-2D branch, so `predictions` is
                    # still a list here if data_mode != '2D'.
                    if do_postprocessing:
                        predictions = image_utils.keep_largest_connected_components(predictions)

                    elapsed_time = time.time() - start_time
                    total_time += elapsed_time
                    total_volumes += 1

                    logging.info('Evaluation of volume took %f secs.' % elapsed_time)


                    # Save predicted mask
                    # NOTE(review): files go to save_path/paz/NNN.png, but the 'paz'
                    # sub-folder is not created here. cv2.imwrite returns False
                    # (rather than raising) when it cannot write.
                    for ii in range(predictions.shape[2]):
                        image_file_name = os.path.join('paz', str(ii).zfill(3) + '.png')
                        cv2.imwrite(os.path.join(save_path , image_file_name), np.squeeze(predictions[:,:,ii]))

                    # NOTE(review): the example is truncated here in the source; this `if` has no body.
                    if gt_exists:
# Example #6
# 0
def score_data(input_folder, output_folder, model_path, config, do_postprocessing=False, gt_exists=True,
               evaluate_all=False, use_iter=None, show_plots=False):
    """Segment ACDC NIfTI volumes with a trained 2D network and save the results.

    For every patient folder in ``input_folder`` that belongs to the test split,
    each ``patient???_frame??.nii.gz`` volume is processed slice by slice:
    rescaled to ``config.target_resolution``, cropped/padded to
    ``config.image_size[:2]`` and segmented. The test split is every patient
    whose 3-digit id is divisible by 5, or all patients when ``evaluate_all``
    is set or no ground truth exists.

    Results are written as NIfTI files into ``output_folder`` sub-folders:
    the prediction (``prediction``) and the input image (``image``). When
    ground truth exists, the cropped ground truth (``ground_truth``) and a
    prediction/ground-truth difference mask (``difference``) are written too.

    Args:
        input_folder: folder containing one sub-folder per patient, each with an
            ``Info.cfg`` that has ``ED`` and ``ES`` entries.
        output_folder: folder whose ``prediction``/``image``/``ground_truth``/
            ``difference`` sub-folders receive the results.
        model_path: directory searched for the latest ``model_best_dice.ckpt``.
        config: experiment config (image_size, nlabels, target_resolution,
            data_mode, ...).
        do_postprocessing: keep only the largest connected components of the
            predicted volume.
        gt_exists: whether ``*_gt.nii.gz`` ground-truth files exist. With ground
            truth the prediction stays in the cropped frame of the ground truth.
            Without it the prediction is rescaled back by the inverse
            resolution factor.
        evaluate_all: evaluate every patient instead of only the test split.
        use_iter: accepted for interface compatibility; not used.
        show_plots: if True (and ground truth exists), show image, ground truth
            and prediction of every slice with blocking ``plt.show()`` calls.

    Returns:
        The training iteration parsed from the restored checkpoint file name.

    Raises:
        NotImplementedError: if ``config.data_mode`` is not ``'2D'``.
        ValueError: if a frame is neither the ED nor the ES frame.
    """
    nx, ny = config.image_size[:2]
    batch_size = 1
    num_channels = config.nlabels

    image_tensor_shape = [batch_size] + list(config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

    # According to the experiment config, pick a model and predict the output
    # TODO: Implement majority voting using 3 models.
    mask_pl, softmax_pl = model.predict(images_pl, config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    evaluate_test_set = not gt_exists

    with tf.Session() as sess:

        sess.run(init)

        checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)

        # The iteration number is the part of the checkpoint file name after the last '-'.
        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0

        for folder in os.listdir(input_folder):

            folder_path = os.path.join(input_folder, folder)
            if not os.path.isdir(folder_path):
                continue

            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'

            if train_test != 'test':
                continue

            # Parse 'key: value' lines; close the file and tolerate ':' inside values.
            infos = {}
            with open(os.path.join(folder_path, 'Info.cfg')) as info_file:
                for line in info_file:
                    if ':' not in line:
                        continue  # skip blank / malformed lines
                    label, value = line.split(':', 1)
                    infos[label] = value.rstrip('\n').lstrip(' ')

            patient_id = folder.lstrip('patient')
            ED_frame = int(infos['ED'])
            ES_frame = int(infos['ES'])

            for file in glob.glob(os.path.join(folder_path, 'patient???_frame??.nii.gz')):

                logging.info(' ----- Doing image: -------------------------')
                logging.info('Doing: %s' % file)
                logging.info(' --------------------------------------------')

                file_base = file.split('.nii.gz')[0]

                frame = int(file_base.split('frame')[-1])
                img_dat = utils.load_nii(file)
                img = img_dat[0].copy()
                logging.debug('img shape %s, dtype %s', img.shape, img.dtype)

                if gt_exists:
                    file_mask = file_base + '_gt.nii.gz'
                    mask_dat = utils.load_nii(file_mask)
                    mask = mask_dat[0]

                if config.data_mode != '2D':
                    # Only the 2D slice-wise path is implemented; fail loudly instead
                    # of hitting an undefined `prediction_arr` further down.
                    raise NotImplementedError(
                        "score_data only supports config.data_mode == '2D', got %r" % (config.data_mode,))

                start_time = time.time()

                pixdim = img_dat[2].structarr['pixdim']
                pixel_size = (pixdim[1], pixdim[2])
                scale_vector = (pixel_size[0] / config.target_resolution[0],
                                pixel_size[1] / config.target_resolution[1])
                logging.debug('pixel_size %s, scale_vector %s', pixel_size, scale_vector)

                predictions = []
                mask_arr = []
                img_arr = []

                for zz in range(img.shape[2]):

                    # Rescale the slice to the target resolution (bilinear), then crop/pad.
                    slice_img = np.squeeze(img[:, :, zz])
                    slice_rescaled = transform.rescale(slice_img,
                                                       scale_vector,
                                                       order=1,
                                                       preserve_range=True,
                                                       multichannel=False,
                                                       anti_aliasing=True,
                                                       mode='constant')
                    slice_cropped = np.float32(acdc_data.crop_or_pad_slice_to_size(slice_rescaled, nx, ny))
                    x = image_utils.reshape_2Dimage_to_tensor(slice_cropped)
                    img_arr.append(np.squeeze(x))

                    if gt_exists:
                        # Nearest-neighbour (order=0) rescale keeps label values intact.
                        slice_mask = np.squeeze(mask[:, :, zz])
                        mask_rescaled = transform.rescale(slice_mask,
                                                          scale_vector,
                                                          order=0,
                                                          preserve_range=True,
                                                          multichannel=False,
                                                          anti_aliasing=True,
                                                          mode='constant')
                        mask_cropped = np.asarray(acdc_data.crop_or_pad_slice_to_size(mask_rescaled, nx, ny),
                                                  dtype=np.uint8)
                        y = image_utils.reshape_2Dimage_to_tensor(mask_cropped)
                        mask_arr.append(np.squeeze(y))

                    # GET PREDICTION
                    _, logits_out = sess.run([mask_pl, softmax_pl], feed_dict={images_pl: x})
                    slice_predictions = np.squeeze(logits_out[0, ...])

                    # RESCALING ON THE LOGITS
                    if gt_exists:
                        # Stay in the cropped (nx, ny) frame so the prediction lines up
                        # with the cropped ground truth saved below.
                        prediction = transform.resize(slice_predictions,
                                                      (nx, ny, num_channels),
                                                      order=1,
                                                      preserve_range=True,
                                                      anti_aliasing=True,
                                                      mode='constant')
                    else:
                        # No ground truth: undo the resolution rescale.
                        prediction = transform.rescale(slice_predictions,
                                                       (1.0 / scale_vector[0], 1.0 / scale_vector[1], 1),
                                                       order=1,
                                                       preserve_range=True,
                                                       multichannel=False,
                                                       anti_aliasing=True,
                                                       mode='constant')

                    # Class map = argmax over the last (class) axis.
                    predictions.append(np.uint8(np.argmax(prediction, axis=-1)))

                # Stack slices back into (x, y, N) volumes.
                prediction_arr = np.transpose(np.asarray(predictions, dtype=np.uint8), (1, 2, 0))
                img_arrs = np.transpose(np.asarray(img_arr, dtype=np.float32), (1, 2, 0))
                if gt_exists:
                    mask_arrs = np.transpose(np.asarray(mask_arr, dtype=np.uint8), (1, 2, 0))

                if do_postprocessing:
                    prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)

                elapsed_time = time.time() - start_time
                total_time += elapsed_time
                total_volumes += 1

                logging.info('Evaluation of volume took %f secs.' % elapsed_time)

                if frame == ED_frame:
                    frame_suffix = '_ED'
                elif frame == ES_frame:
                    frame_suffix = '_ES'
                else:
                    raise ValueError('Frame doesnt correspond to ED or ES. frame = %d, ED = %d, ES = %d' %
                                     (frame, ED_frame, ES_frame))

                # Save predicted mask
                out_file_name = os.path.join(output_folder, 'prediction',
                                             'patient' + patient_id + frame_suffix + '.nii.gz')
                if gt_exists:
                    out_affine = mask_dat[1]
                    out_header = mask_dat[2]
                else:
                    out_affine = img_dat[1]
                    out_header = img_dat[2]

                logging.info('saving to: %s' % out_file_name)
                utils.save_nii(out_file_name, prediction_arr, out_affine, out_header)

                # Save image data to the same folder for convenience
                image_file_name = os.path.join(output_folder, 'image',
                                               'patient' + patient_id + frame_suffix + '.nii.gz')
                logging.info('saving to: %s' % image_file_name)
                utils.save_nii(image_file_name, img_dat[0], out_affine, out_header)

                if gt_exists:

                    # Save GT image
                    gt_file_name = os.path.join(output_folder, 'ground_truth',
                                                'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % gt_file_name)
                    utils.save_nii(gt_file_name, mask_arrs, out_affine, out_header)

                    # Difference mask: 1 where prediction and ground truth disagree.
                    difference_mask = np.asarray(prediction_arr != mask_arrs, dtype=np.uint8)

                    if show_plots:
                        # Debug visualisation; each plt.show() blocks until closed.
                        for zz in range(difference_mask.shape[2]):
                            for volume in (img_arrs, mask_arrs, prediction_arr):
                                plt.imshow(volume[:, :, zz])
                                plt.gray()
                                plt.axis('off')
                                plt.show()

                    diff_file_name = os.path.join(output_folder,
                                                  'difference',
                                                  'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % diff_file_name)
                    utils.save_nii(diff_file_name, difference_mask, out_affine, out_header)

        if total_volumes:
            logging.info('Average time per volume: %f' % (total_time / total_volumes))
        else:
            logging.warning('No volumes were evaluated.')

    return init_iteration
# Example #7
# 0
def main(ws_exp_config, slices, test):
    """Visualise weakly-supervised predictions of every recursion for selected slices.

    Reads the images and ground truth of ``slices`` from the experiment's
    ``base_data.hdf5`` (plus the scribbles in train mode). For each recursion
    it restores that recursion's checkpoint and predicts the slices. It then
    writes, per slice:

    - the image and the ground truth;
    - the raw prediction, the largest-connected-component prediction and the
      random-walker refined prediction of every recursion whose checkpoint
      could be used;
    - in train mode, the scribble.

    The Dice coefficient of the refined prediction is printed for each
    written recursion.

    Args:
        ws_exp_config: weak-supervision experiment config (reads experiment_name,
            image_size, model_handle, nlabels, rw_beta).
        slices: slice indices. Presumably a NumPy integer array, since it is
            filtered with a boolean mask; TODO confirm.
        test: read the test split if True, otherwise the train split.

    NOTE(review): ``ws_model_path`` and ``OUTPUT_FOLDER`` are not defined in
    this function; presumably module-level globals. Confirm.
    """
    exp_dir = sys_config.project_root + 'acdc_logdir/' + ws_exp_config.experiment_name + '/'

    # Get number of recursions
    num_recursions = acdc_data.most_recent_recursion(
        sys_config.project_root + 'acdc_logdir/' +
        ws_exp_config.experiment_name)
    print(num_recursions)
    # Recursions are iterated with range(num_recursions) below, so +1 makes the
    # value returned above inclusive.
    num_recursions += 1

    # Load data. Indexing an h5py dataset returns NumPy arrays, so the file can
    # be closed as soon as everything is read.
    with h5py.File(os.path.join(exp_dir, 'base_data.hdf5'), 'r') as base_data:
        if test:
            slices = slices[slices < len(base_data['images_test'])]
            images = base_data['images_test'][slices, ...]
            gt = base_data['masks_test'][slices, ...]
            prefix = 'test'
        else:
            slices = slices[slices < len(base_data['images_train'])]
            images = base_data['images_train'][slices, ...]
            gt = base_data['masks_train'][slices, ...]
            scr = base_data['scribbles_train'][slices, ...]
            prefix = 'train'

    # Size the batch AFTER dropping out-of-range slices; otherwise the
    # placeholder shape no longer matches the images being fed.
    batch_size = len(slices)

    image_tensor_shape = [batch_size] + list(ws_exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    feed_dict = {
        images_pl: np.expand_dims(images, -1),
    }

    # Get weak supervision predictions
    mask_pl, softmax_pl = model.predict(images_pl, ws_exp_config.model_handle,
                                        ws_exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    predictions = np.zeros([batch_size] + list(ws_exp_config.image_size) +
                           [num_recursions])
    predictions_klc = np.zeros_like(predictions)
    predictions_rw = np.zeros_like(predictions)

    # Checkpoint names tried per recursion, in order of preference.
    checkpoint_templates = ('recursion_{}_model_best_xent.ckpt',
                            'recursion_{}_model_best_dice.ckpt',
                            'recursion_{}_model.ckpt')

    with tf.Session() as sess:
        sess.run(init)
        for recursion in range(num_recursions):
            try:
                checkpoint_path = None
                last_error = None
                for template in checkpoint_templates:
                    try:
                        checkpoint_path = utils.get_latest_model_checkpoint_path(
                            ws_model_path, template.format(recursion))
                        break
                    except Exception as err:
                        last_error = err
                if checkpoint_path is None:
                    raise FileNotFoundError(
                        'no checkpoint for recursion {}'.format(recursion)) from last_error

                saver.restore(sess, checkpoint_path)
                mask_out, _ = sess.run([mask_pl, softmax_pl],
                                       feed_dict=feed_dict)
                predictions[..., recursion] = mask_out
                for i in range(batch_size):
                    predictions_klc[i, :, :, recursion] = \
                        image_utils.keep_largest_connected_components(mask_out[i, ...])

                predictions_rw[..., recursion] = segment(
                    images,
                    np.squeeze(predictions_klc[..., recursion]),
                    beta=ws_exp_config.rw_beta,
                    threshold=0)

                print("Classified for recursion {}".format(recursion))
            except Exception:
                # Mark the recursion as unavailable with -1 so the `>= -0.5` check
                # below skips it. The old `-1 * np.zeros_like(...)` produced zeros,
                # so the marker never worked.
                predictions[..., recursion] = -1
                print("Could not find checkpoint for recursion {} - skipping".
                      format(recursion))

    for i in range(batch_size):
        pref = '{}{}'.format(prefix, slices[i])

        print_grayscale(images[i, ...],
                        filepath=OUTPUT_FOLDER,
                        filename='{}_image'.format(pref))
        print_coloured(gt[i, ...],
                       filepath=OUTPUT_FOLDER,
                       filename='{}_gt'.format(pref))
        for recursion in range(num_recursions):
            # Skip recursions marked unavailable (-1) above.
            if np.max(predictions[i, :, :, recursion]) >= -0.5:
                print_coloured(predictions[i, :, :, recursion],
                               filepath=OUTPUT_FOLDER,
                               filename="{}_ws_pred_r{}".format(
                                   pref, recursion))
                print_coloured(predictions_klc[i, :, :, recursion],
                               filepath=OUTPUT_FOLDER,
                               filename="{}_ws_pred_klc_r{}".format(
                                   pref, recursion))
                print_coloured(predictions_rw[i, :, :, recursion],
                               filepath=OUTPUT_FOLDER,
                               filename="{}_ws_pred_klc_rw_r{}".format(
                                   pref, recursion))
                print("Dice coefficient for slice {} is {}".format(
                    slices[i],
                    dice(predictions_rw[i, :, :, recursion], gt[i, ...])))
        if not test:
            print_coloured(scr[i, ...],
                           filepath=OUTPUT_FOLDER,
                           filename='{}_scribble'.format(pref))