def score_data(input_folder, output_folder, model_path, exp_config, do_postprocessing=False, recursion=None):
    """Segment every image of the prostate dataset and report Dice scores.

    Each image returned by ``read_data`` is fed through the network restored
    from ``model_path``. The predicted mask, the ground-truth mask and the
    input image are written as NIfTI files below ``output_folder`` (in the
    ``prediction``, ``ground_truth`` and ``image`` sub-folders), the Dice value
    of every image is printed, and finally the mean over all images.

    :param input_folder: Unused; the dataset path is hard-coded in the body.
    :param output_folder: Root folder for the NIfTI outputs.
    :param model_path: Folder that contains the model checkpoints.
    :param exp_config: Experiment configuration; ``image_size``, ``nlabels``
        and ``model_handle`` are read from it.
    :param do_postprocessing: If true, the prediction is passed through
        ``image_utils.keep_largest_connected_components`` before saving.
    :param recursion: Recursion index whose checkpoint is restored. ``None``
        restores the latest ``model_best_dice.ckpt``.
    :return: Iteration number parsed from the restored checkpoint file name.
    """
    print("KOD YENİ")
    dices = []
    images, labels = read_data('/scratch/cany/scribble/scribble_data/prostate_divided.h5')
    num_images = images.shape[0]
    print(str(num_images))
    print(str(images.shape))

    batch_size = 1
    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
    mask_pl, softmax_pl, _ = model.predict_logits(images_pl,
                                                  exp_config.model_handle,
                                                  exp_config.nlabels)

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)

        if recursion is None:
            checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
        else:
            try:
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path, 'recursion_{}_model_best_dice.ckpt'.format(recursion))
            except Exception:
                # No best-dice checkpoint for this recursion: fall back to the
                # regular checkpoint. `except Exception` (instead of a bare
                # except) keeps KeyboardInterrupt/SystemExit propagating.
                checkpoint_path = utils.get_latest_model_checkpoint_path(
                    model_path, 'recursion_{}_model.ckpt'.format(recursion))

        saver.restore(sess, checkpoint_path)

        # Checkpoint files are named '<prefix>-<iteration>'.
        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        for k in range(num_images):
            # Add the batch axis in front and the channel axis at the end.
            network_input = np.expand_dims(np.expand_dims(images[k, :, :], 2), 0)
            _, softmax_out = sess.run([mask_pl, softmax_pl],
                                      feed_dict={images_pl: network_input})
            prediction_cropped = np.squeeze(softmax_out[0, ...])
            prediction_arr = np.uint8(np.argmax(prediction_cropped, axis=-1))

            # Ground truth of the current image. The original code referenced
            # `mask` below without ever assigning it.
            mask = labels[k, :, :]

            if do_postprocessing:
                print("Entered post processing " + str(True))
                prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)

            # Save predicted mask
            out_file_name = os.path.join(output_folder, 'prediction', 'patient' + str(k) + '.nii.gz')
            logging.info('saving to: %s' % out_file_name)
            save_nii(out_file_name, prediction_arr)

            # Save GT image
            gt_file_name = os.path.join(output_folder, 'ground_truth', 'patient' + str(k) + '.nii.gz')
            logging.info('saving to: %s' % gt_file_name)
            save_nii(gt_file_name, np.uint8(mask))

            # Save image data to the same folder for convenience
            image_file_name = os.path.join(output_folder, 'image', 'patient' + str(k) + '.nii.gz')
            logging.info('saving to: %s' % image_file_name)
            save_nii(image_file_name, images[k, :, :])

            # NOTE(review): this adds new ops to the graph on every iteration
            # and hard-codes 4 classes instead of exp_config.nlabels; left
            # unchanged because the contract of get_dice is not visible here.
            cdice = sess.run(get_dice(tf.one_hot(np.uint8(prediction_arr), depth=4),
                                      np.uint8(np.squeeze(labels[k, :, :])),
                                      4))
            print(str(cdice))
            dices.append(cdice)

        print("Average Dice : " + str(np.mean(dices)))

    return init_iteration
def postprocess(mask_out, images_train, scribbles_train=None):
    '''
    Postprocesses predictions of the CNN to create ground truths for the next recursion.

    Which steps run is controlled by the module-level ``exp_config``:
    ``rw_intersection``, ``rw_reversion``, ``keep_largest_cluster`` and
    ``smooth_edges``.

    :param mask_out: Numpy array of predicted label masks, indexed by image along the first axis
    :param images_train: Numpy array of training images
    :param scribbles_train: Numpy array of weakly annotated images
    :return: The postprocessed masks. When neither ``rw_intersection`` nor
             ``rw_reversion`` is enabled, the passed-in array is modified in place.
    '''
    # Labels present in the scribbles; 0 is excluded.
    labels = np.unique(scribbles_train)
    labels = labels[labels != 0]

    # Use the full segmentation of the random walker as an upper bound:
    # a pixel keeps its label only where prediction and random walker agree.
    if exp_config.rw_intersection:
        rw_segmentation = random_walker.segment(images_train,
                                                scribbles_train,
                                                threshold=0,
                                                beta=exp_config.rw_beta)
        mask = mask_out[:]
        mask_out = np.zeros_like(mask_out)
        for label in labels:
            indices = (rw_segmentation == label)
            indices &= (mask == label)
            mask_out[indices] = label

    # Revert to the original scribble for 'bad' predictions.
    if exp_config.rw_reversion:
        mask = mask_out[:]
        mask_out = np.zeros_like(mask_out)
        for img_id in range(exp_config.batch_size):
            for label in labels:
                predicted_pixels = np.sum(mask[img_id, ...] == label)
                scribble_pixels = np.sum(scribbles_train[img_id, ...] == label)
                if predicted_pixels < scribble_pixels:
                    # The prediction covers less than the original scribble,
                    # so fall back to the scribble.
                    mask_out[img_id, scribbles_train[img_id, ...] == label] = label
                else:
                    mask_out[img_id, mask[img_id, ...] == label] = label

    # Keep only the largest cluster for the output.
    if exp_config.keep_largest_cluster:
        for img_id in range(exp_config.batch_size):
            mask_out[img_id, ...] = image_utils.keep_largest_connected_components(
                np.squeeze(mask_out[img_id, ...]))

    if exp_config.smooth_edges:
        # The highest label is left out of the smoothing, so its pixels end up
        # as 0 in the smoothed mask.
        labels = labels[labels != np.max(labels)]
        for img_id in range(exp_config.batch_size):
            mask = mask_out[img_id, ...]
            new_mask = np.zeros_like(mask)
            for label in labels:
                # np.float was removed in NumPy 1.24; np.float64 is the type
                # the old alias resolved to.
                struct = (mask == label).astype(np.float64)
                blurred_struct = gaussian_filter(struct,
                                                 sigma=exp_config.edge_smoother_sigma)
                # Binarise the blurred structure at the configured threshold.
                blurred_struct[blurred_struct >= exp_config.edge_smoother_threshold] = 1
                blurred_struct[blurred_struct < exp_config.edge_smoother_threshold] = 0
                new_mask[blurred_struct != 0] = label
            mask_out[img_id, ...] = new_mask

    return mask_out
def main(exp_config, batch_size=3):
    """Export images, ground truths and per-recursion predictions for three fixed test slices.

    For every recursion whose checkpoint can be restored, the network output
    is passed through ``image_utils.keep_largest_connected_components`` and
    then through ``segment``. The images, ground-truth masks and predictions
    are written with ``image_utils.print_grayscale`` / ``print_coloured`` into
    ``base_path + "/poster/"``.

    :param exp_config: Experiment configuration; ``scribble_data``,
        ``image_size``, ``model_handle``, ``nlabels``, ``rw_beta`` and
        ``experiment_name`` are read from it.
    :param batch_size: Ignored. The slice indices are hard-coded, so the
        effective batch size is always their count. Kept for backward
        compatibility.
    """
    # The slices are fixed. The previous random draw was overwritten right
    # after it was computed, so it has been dropped.
    slices = [80, 275, 370]
    batch_size = len(slices)

    # Load data. Indexing an h5py dataset yields in-memory arrays, so the
    # file can be closed as soon as the slices have been read.
    with h5py.File(sys_config.project_root + exp_config.scribble_data, 'r') as data:
        images = data['images_test'][slices, ...]
        masks = data['masks_test'][slices, ...]

    num_recursions = most_recent_recursion(model_path)

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
    mask_pl, softmax_pl = model.predict(images_pl,
                                        exp_config.model_handle,
                                        exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    predictions = np.zeros([batch_size] + list(exp_config.image_size) + [num_recursions + 1])
    feed_dict = {
        images_pl: np.expand_dims(images, -1),
    }

    with tf.Session() as sess:
        sess.run(init)
        # Number of recursions classified so far; also the next free index
        # along the last axis of `predictions`.
        pred_size = 0
        for recursion in range(num_recursions + 1):
            try:
                try:
                    checkpoint_path = utils.get_latest_model_checkpoint_path(
                        model_path,
                        'recursion_{}_model_best_dice.ckpt'.format(recursion))
                except Exception:
                    # No best-dice checkpoint: fall back to the regular one.
                    checkpoint_path = utils.get_latest_model_checkpoint_path(
                        model_path,
                        'recursion_{}_model.ckpt'.format(recursion))
                saver.restore(sess, checkpoint_path)
                mask_out, _ = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)
                for img_id in range(batch_size):
                    predictions[img_id, ..., pred_size] = \
                        image_utils.keep_largest_connected_components(
                            np.squeeze(mask_out[img_id, ...]))
                print("Classified for recursion {}".format(recursion))
                pred_size += 1
            except Exception as e:
                # Best effort: a recursion that cannot be restored or
                # classified is reported and skipped.
                print(e)

    # Only the recursions that were classified are processed further.
    num_recursions = pred_size
    fig = plt.figure()

    # RW:
    path = base_path + "/poster/"
    for recursion in range(num_recursions):
        predictions[..., recursion] = segment(images,
                                              np.squeeze(predictions[..., recursion]),
                                              beta=exp_config.rw_beta,
                                              threshold=0)

    for r in range(batch_size):
        # Add the image
        image_utils.print_grayscale(
            np.squeeze(images[r, ...]), path,
            '{}_{}_image'.format(exp_config.experiment_name, slices[r]))

        # Add the mask
        image_utils.print_coloured(
            np.squeeze(masks[r, ...]), path,
            '{}_{}_gt'.format(exp_config.experiment_name, slices[r]))

        for recursion in range(num_recursions):
            # Add each prediction
            image_utils.print_coloured(
                np.squeeze(predictions[r, ..., recursion]), path,
                '{}_{}_pred_r{}'.format(exp_config.experiment_name, slices[r], recursion))

    # Show the figure once. The previous `while True` loop never terminated.
    plt.axis('off')
    plt.show()
def score_data(input_folder, output_folder, model_path, exp_config, do_postprocessing=False, gt_exists=True, evaluate_all=False, use_iter=None):
    """Segment the ED/ES volumes of every selected patient folder and save the results.

    Each sub-folder of ``input_folder`` is expected to hold an ``Info.cfg``
    file with ``ED`` and ``ES`` entries and volumes matching
    ``patient???_frame??.nii.gz``. Every volume is resampled to
    ``exp_config.target_resolution``, cropped or zero-padded to the network
    input size, classified, and the softmax output is resampled back before
    taking the arg-max. Prediction and image (and, with ground truth, the
    ground-truth and difference masks) are written below ``output_folder`` in
    the ``prediction``, ``image``, ``ground_truth`` and ``difference``
    sub-folders.

    :param input_folder: Folder containing one sub-folder per patient.
    :param output_folder: Root folder for the NIfTI outputs.
    :param model_path: Folder that contains the model checkpoints.
    :param exp_config: Experiment configuration; ``image_size``, ``nlabels``,
        ``data_mode`` ('2D' or '3D') and ``target_resolution`` are read.
    :param do_postprocessing: If true, the prediction is passed through
        ``image_utils.keep_largest_connected_components``.
    :param gt_exists: Whether ``*_gt.nii.gz`` files accompany the volumes.
        When false, every patient folder is evaluated.
    :param evaluate_all: Evaluate every patient folder instead of only those
        whose three-digit suffix is divisible by 5.
    :param use_iter: Iteration of the ``model.ckpt-<iter>`` checkpoint to
        restore; when falsy the latest ``model_best_dice.ckpt`` is used.
    :return: Iteration number parsed from the restored checkpoint file name.
    :raises ValueError: If a frame is neither the ED nor the ES frame, if the
        3D prediction does not have the expected number of slices, or if
        ``exp_config.data_mode`` is not '2D' or '3D'.
    """
    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = exp_config.nlabels

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

    mask_pl, softmax_pl = model.predict(images_pl, exp_config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    # Without ground truth every patient folder is treated as test data.
    evaluate_test_set = not gt_exists

    with tf.Session() as sess:
        sess.run(init)

        if not use_iter:
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                model_path, 'model_best_dice.ckpt')
        else:
            checkpoint_path = os.path.join(model_path, 'model.ckpt-%d' % use_iter)

        saver.restore(sess, checkpoint_path)

        # Checkpoint files are named '<prefix>-<iteration>'.
        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0

        for folder in os.listdir(input_folder):
            folder_path = os.path.join(input_folder, folder)
            if not os.path.isdir(folder_path):
                continue

            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'
            if train_test != 'test':
                continue

            # Parse the 'key: value' lines of Info.cfg; the context manager
            # makes sure the file handle is closed again.
            infos = {}
            with open(os.path.join(folder_path, 'Info.cfg')) as info_file:
                for line in info_file:
                    label, value = line.split(':')
                    infos[label] = value.rstrip('\n').lstrip(' ')

            patient_id = folder.lstrip('patient')
            ED_frame = int(infos['ED'])
            ES_frame = int(infos['ES'])

            for file in glob.glob(os.path.join(folder_path, 'patient???_frame??.nii.gz')):
                logging.info(' ----- Doing image: -------------------------')
                logging.info('Doing: %s' % file)
                logging.info(' --------------------------------------------')

                file_base = file.split('.nii.gz')[0]
                frame = int(file_base.split('frame')[-1])

                img_dat = utils.load_nii(file)
                img = img_dat[0].copy()
                img = image_utils.normalise_image(img)

                if gt_exists:
                    file_mask = file_base + '_gt.nii.gz'
                    mask_dat = utils.load_nii(file_mask)
                    mask = mask_dat[0]

                start_time = time.time()

                if exp_config.data_mode == '2D':
                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] / exp_config.target_resolution[0],
                                    pixel_size[1] / exp_config.target_resolution[1])

                    predictions = []
                    for zz in range(img.shape[2]):
                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(slice_img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')

                        x, y = slice_rescaled.shape
                        # *_s: crop offsets when the slice is larger than the
                        # network input; *_c: pad offsets when it is smaller.
                        x_s = (x - nx) // 2
                        y_s = (y - ny) // 2
                        x_c = (nx - x) // 2
                        y_c = (ny - y) // 2

                        # Crop section of image for prediction
                        if x > nx and y > ny:
                            slice_cropped = slice_rescaled[x_s:x_s + nx, y_s:y_s + ny]
                        else:
                            slice_cropped = np.zeros((nx, ny))
                            if x <= nx and y > ny:
                                slice_cropped[x_c:x_c + x, :] = slice_rescaled[:, y_s:y_s + ny]
                            elif x > nx and y <= ny:
                                slice_cropped[:, y_c:y_c + y] = slice_rescaled[x_s:x_s + nx, :]
                            else:
                                slice_cropped[x_c:x_c + x, y_c:y_c + y] = slice_rescaled[:, :]

                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (batch_size, 1, 1, 1)))
                        mask_out, logits_out = sess.run(
                            [mask_pl, softmax_pl],
                            feed_dict={images_pl: network_input})
                        prediction_cropped = np.squeeze(logits_out[0, ...])

                        # ASSEMBLE BACK THE SLICES
                        slice_predictions = np.zeros((x, y, num_channels))
                        # insert cropped region into original image again
                        if x > nx and y > ny:
                            slice_predictions[x_s:x_s + nx, y_s:y_s + ny, :] = prediction_cropped
                        else:
                            if x <= nx and y > ny:
                                slice_predictions[:, y_s:y_s + ny, :] = \
                                    prediction_cropped[x_c:x_c + x, :, :]
                            elif x > nx and y <= ny:
                                slice_predictions[x_s:x_s + nx, :, :] = \
                                    prediction_cropped[:, y_c:y_c + y, :]
                            else:
                                slice_predictions[:, :, :] = \
                                    prediction_cropped[x_c:x_c + x, y_c:y_c + y, :]

                        # RESCALING ON THE LOGITS
                        if gt_exists:
                            prediction = transform.resize(
                                slice_predictions,
                                (mask.shape[0], mask.shape[1], num_channels),
                                order=1,
                                preserve_range=True,
                                mode='constant')
                        else:
                            # This can occasionally lead to wrong volume size, therefore if gt_exists
                            # we use the gt mask size for resizing.
                            prediction = transform.rescale(
                                slice_predictions,
                                (1.0 / scale_vector[0], 1.0 / scale_vector[1], 1),
                                order=1,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant')

                        prediction = np.uint8(np.argmax(prediction, axis=-1))
                        predictions.append(prediction)

                    # Stack the slices back into an (x, y, z) volume.
                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                elif exp_config.data_mode == '3D':
                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2],
                                  img_dat[2].structarr['pixdim'][3])
                    scale_vector = (pixel_size[0] / exp_config.target_resolution[0],
                                    pixel_size[1] / exp_config.target_resolution[1],
                                    pixel_size[2] / exp_config.target_resolution[2])

                    vol_scaled = transform.rescale(img,
                                                   scale_vector,
                                                   order=1,
                                                   preserve_range=True,
                                                   multichannel=False,
                                                   mode='constant')

                    nz_max = exp_config.image_size[2]
                    slice_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)

                    # Centre the available slices along z inside the
                    # zero-padded network input volume.
                    nz_curr = vol_scaled.shape[2]
                    stack_from = (nz_max - nz_curr) // 2
                    stack_counter = stack_from

                    x, y, z = vol_scaled.shape
                    x_s = (x - nx) // 2
                    y_s = (y - ny) // 2
                    x_c = (nx - x) // 2
                    y_c = (ny - y) // 2

                    for zz in range(nz_curr):
                        slice_rescaled = vol_scaled[:, :, zz]

                        if x > nx and y > ny:
                            slice_cropped = slice_rescaled[x_s:x_s + nx, y_s:y_s + ny]
                        else:
                            slice_cropped = np.zeros((nx, ny))
                            if x <= nx and y > ny:
                                slice_cropped[x_c:x_c + x, :] = slice_rescaled[:, y_s:y_s + ny]
                            elif x > nx and y <= ny:
                                slice_cropped[:, y_c:y_c + y] = slice_rescaled[x_s:x_s + nx, :]
                            else:
                                slice_cropped[x_c:x_c + x, y_c:y_c + y] = slice_rescaled[:, :]

                        slice_vol[:, :, stack_counter] = slice_cropped
                        stack_counter += 1

                    stack_to = stack_counter

                    network_input = np.float32(
                        np.reshape(slice_vol, (1, nx, ny, nz_max, 1)))

                    # The 3D timing starts right before the forward pass.
                    start_time = time.time()
                    mask_out, logits_out = sess.run(
                        [mask_pl, softmax_pl],
                        feed_dict={images_pl: network_input})
                    logging.info('Classified 3D: %f secs' % (time.time() - start_time))

                    prediction_nzs = logits_out[0, :, :, stack_from:stack_to, ...]  # non-zero-slices

                    if not prediction_nzs.shape[2] == nz_curr:
                        raise ValueError('sizes mismatch')

                    # ASSEMBLE BACK THE SLICES
                    # last dim is for logits classes
                    prediction_scaled = np.zeros(list(vol_scaled.shape) + [num_channels])

                    # insert cropped region into original image again
                    if x > nx and y > ny:
                        prediction_scaled[x_s:x_s + nx, y_s:y_s + ny, :, ...] = prediction_nzs
                    else:
                        if x <= nx and y > ny:
                            prediction_scaled[:, y_s:y_s + ny, :, ...] = \
                                prediction_nzs[x_c:x_c + x, :, :, ...]
                        elif x > nx and y <= ny:
                            # The original indexed with ':...' (a slice whose
                            # stop is Ellipsis), which fails at runtime.
                            prediction_scaled[x_s:x_s + nx, :, :, ...] = \
                                prediction_nzs[:, y_c:y_c + y, :, ...]
                        else:
                            prediction_scaled[:, :, :, ...] = \
                                prediction_nzs[x_c:x_c + x, y_c:y_c + y, :, ...]

                    logging.info('Prediction_scaled mean %f' % (np.mean(prediction_scaled)))

                    # Resample back to the ground-truth size, or to the size
                    # of the input volume when no ground truth is available
                    # (`mask` is undefined in that case).
                    target_shape = mask.shape[:3] if gt_exists else img.shape[:3]
                    prediction = transform.resize(
                        prediction_scaled,
                        (target_shape[0], target_shape[1], target_shape[2], num_channels),
                        order=1,
                        preserve_range=True,
                        mode='constant')
                    prediction = np.argmax(prediction, axis=-1)
                    prediction_arr = np.asarray(prediction, dtype=np.uint8)

                else:
                    raise ValueError('Unknown data mode: %s' % exp_config.data_mode)

                # This is the same for 2D and 3D again
                if do_postprocessing:
                    prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)

                elapsed_time = time.time() - start_time
                total_time += elapsed_time
                total_volumes += 1

                logging.info('Evaluation of volume took %f secs.' % elapsed_time)

                if frame == ED_frame:
                    frame_suffix = '_ED'
                elif frame == ES_frame:
                    frame_suffix = '_ES'
                else:
                    raise ValueError(
                        'Frame doesnt correspond to ED or ES. frame = %d, ED = %d, ES = %d'
                        % (frame, ED_frame, ES_frame))

                # Save predicted mask
                out_file_name = os.path.join(output_folder, 'prediction',
                                             'patient' + patient_id + frame_suffix + '.nii.gz')
                if gt_exists:
                    out_affine = mask_dat[1]
                    out_header = mask_dat[2]
                else:
                    out_affine = img_dat[1]
                    out_header = img_dat[2]

                logging.info('saving to: %s' % out_file_name)
                utils.save_nii(out_file_name, prediction_arr, out_affine, out_header)

                # Save image data to the same folder for convenience
                image_file_name = os.path.join(output_folder, 'image',
                                               'patient' + patient_id + frame_suffix + '.nii.gz')
                logging.info('saving to: %s' % image_file_name)
                utils.save_nii(image_file_name, img_dat[0], out_affine, out_header)

                if gt_exists:
                    # Save GT image
                    gt_file_name = os.path.join(output_folder, 'ground_truth',
                                                'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % gt_file_name)
                    utils.save_nii(gt_file_name, mask, out_affine, out_header)

                    # Save difference mask between predictions and ground truth
                    difference_mask = np.where(np.abs(prediction_arr - mask) > 0, [1], [0])
                    difference_mask = np.asarray(difference_mask, dtype=np.uint8)
                    diff_file_name = os.path.join(output_folder, 'difference',
                                                  'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % diff_file_name)
                    utils.save_nii(diff_file_name, difference_mask, out_affine, out_header)

        # Guard against ZeroDivisionError when no volume was evaluated.
        if total_volumes:
            logging.info('Average time per volume: %f' % (total_time / total_volumes))
        else:
            logging.info('No volumes were evaluated.')

    return init_iteration
def score_data(input_folder, output_folder, model_path, config, do_postprocessing=False, gt_exists=True): nx, ny = config.image_size[:2] batch_size = 1 num_channels = config.nlabels image_tensor_shape = [batch_size] + list(config.image_size) + [1] images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images') # According to the experiment config, pick a model and predict the output mask_pl, softmax_pl = model.predict(images_pl, config) saver = tf.train.Saver() init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt') saver.restore(sess, checkpoint_path) init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1]) total_time = 0 total_volumes = 0 scale_vector = [config.pixel_size[0] / target_resolution[0], config.pixel_size[1] / target_resolution[1]] path_img = os.path.join(input_folder, 'img') if gt_exists: path_mask = os.path.join(input_folder, 'mask') for folder in os.listdir(path_img): logging.info(' ----- Doing image: -------------------------') logging.info('Doing: %s' % folder) logging.info(' --------------------------------------------') folder_path = os.path.join(path_img, folder) #ciclo su cartelle paz utils.makefolder(os.path.join(path_pred, folder)) if os.path.isdir(folder_path): for phase in os.listdir(folder_path): #ciclo su cartelle ED ES save_path = os.path.join(path_pred, folder, phase) utils.makefolder(save_path) predictions = [] mask_arr = [] img_arr = [] masks = [] imgs = [] path = os.path.join(folder_path, phase) for file in os.listdir(path): img = plt.imread(os.path.join(path,file)) if config.standardize: img = image_utils.standardize_image(img) if config.normalize: img = cv2.normalize(img, dst=None, alpha=config.min, beta=config.max, norm_type=cv2.NORM_MINMAX) img_arr.append(img) if gt_exists: for file in os.listdir(os.path.join(path_mask,folder,phase)): 
mask_arr.append(plt.imread(os.path.join(path_mask,folder,phase,file))) img_arr = np.transpose(np.asarray(img_arr),(1,2,0)) # x,y,N if gt_exists: mask_arr = np.transpose(np.asarray(mask_arr),(1,2,0)) start_time = time.time() if config.data_mode == '2D': for zz in range(img_arr.shape[2]): slice_img = np.squeeze(img_arr[:,:,zz]) slice_rescaled = transform.rescale(slice_img, scale_vector, order=1, preserve_range=True, multichannel=False, anti_aliasing=True, mode='constant') slice_mask = np.squeeze(mask_arr[:, :, zz]) slice_cropped = read_data.crop_or_pad_slice_to_size(slice_rescaled, nx, ny) slice_cropped = np.float32(slice_cropped) x = image_utils.reshape_2Dimage_to_tensor(slice_cropped) imgs.append(np.squeeze(x)) if gt_exists: mask_rescaled = transform.rescale(slice_mask, scale_vector, order=0, preserve_range=True, multichannel=False, anti_aliasing=True, mode='constant') mask_cropped = read_data.crop_or_pad_slice_to_size(mask_rescaled, nx, ny) mask_cropped = np.asarray(mask_cropped, dtype=np.uint8) y = image_utils.reshape_2Dimage_to_tensor(mask_cropped) masks.append(np.squeeze(y)) # GET PREDICTION feed_dict = { images_pl: x, } mask_out, logits_out = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict) prediction_cropped = np.squeeze(logits_out[0,...]) # ASSEMBLE BACK THE SLICES slice_predictions = np.zeros((nx,ny,num_channels)) slice_predictions = prediction_cropped # RESCALING ON THE LOGITS if gt_exists: prediction = transform.resize(slice_predictions, (nx, ny, num_channels), order=1, preserve_range=True, anti_aliasing=True, mode='constant') else: prediction = transform.rescale(slice_predictions, (1.0/scale_vector[0], 1.0/scale_vector[1], 1), order=1, preserve_range=True, multichannel=False, anti_aliasing=True, mode='constant') prediction = np.uint8(np.argmax(prediction, axis=-1)) predictions.append(prediction) predictions = np.transpose(np.asarray(predictions, dtype=np.uint8), (1,2,0)) masks = np.transpose(np.asarray(masks, dtype=np.uint8), (1,2,0)) imgs = 
np.transpose(np.asarray(imgs, dtype=np.float32), (1,2,0)) # This is the same for 2D and 3D if do_postprocessing: predictions = image_utils.keep_largest_connected_components(predictions) elapsed_time = time.time() - start_time total_time += elapsed_time total_volumes += 1 logging.info('Evaluation of volume took %f secs.' % elapsed_time) # Save predicted mask for ii in range(predictions.shape[2]): image_file_name = os.path.join('paz', str(ii).zfill(3) + '.png') cv2.imwrite(os.path.join(save_path , image_file_name), np.squeeze(predictions[:,:,ii])) if gt_exists:
def score_data(input_folder, output_folder, model_path, config, do_postprocessing=False, gt_exists=True, evaluate_all=False, use_iter=None):
    """Segment the ED/ES volumes of every selected patient folder in network space.

    Each sub-folder of ``input_folder`` is expected to hold an ``Info.cfg``
    file with ``ED`` and ``ES`` entries and volumes matching
    ``patient???_frame??.nii.gz``. Every slice is resampled to
    ``config.target_resolution``, cropped or padded to the network input size
    with ``acdc_data.crop_or_pad_slice_to_size`` and classified. Predictions
    (and, with ground truth, the equally resampled masks) stay at the network
    input size; they are not resampled back to the original resolution.
    Results are written below ``output_folder`` in the ``prediction``,
    ``image``, ``ground_truth`` and ``difference`` sub-folders. With ground
    truth, every slice is additionally shown with matplotlib (image, mask,
    prediction), which blocks until each window is closed.

    :param input_folder: Folder containing one sub-folder per patient.
    :param output_folder: Root folder for the NIfTI outputs.
    :param model_path: Folder that contains the model checkpoints.
    :param config: Experiment configuration; ``image_size``, ``nlabels``,
        ``data_mode`` and ``target_resolution`` are read.
    :param do_postprocessing: If true, the prediction is passed through
        ``image_utils.keep_largest_connected_components``.
    :param gt_exists: Whether ``*_gt.nii.gz`` files accompany the volumes.
        When false, every patient folder is evaluated and no ground-truth or
        difference output is produced.
    :param evaluate_all: Evaluate every patient folder instead of only those
        whose three-digit suffix is divisible by 5.
    :param use_iter: Iteration of the ``model.ckpt-<iter>`` checkpoint to
        restore; when falsy the latest ``model_best_dice.ckpt`` is used.
    :return: Iteration number parsed from the restored checkpoint file name.
    :raises ValueError: If ``config.data_mode`` is not '2D' or a frame is
        neither the ED nor the ES frame.
    """
    # Only the 2D pipeline is implemented; fail with a clear message instead
    # of a NameError further down.
    if config.data_mode != '2D':
        raise ValueError('Unsupported data mode: %s' % config.data_mode)

    nx, ny = config.image_size[:2]
    batch_size = 1
    num_channels = config.nlabels

    image_tensor_shape = [batch_size] + list(config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')

    # According to the experiment config, pick a model and predict the output
    # TODO: Implement majority voting using 3 models.
    mask_pl, softmax_pl = model.predict(images_pl, config)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    # Without ground truth every patient folder is treated as test data.
    evaluate_test_set = not gt_exists

    with tf.Session() as sess:
        sess.run(init)

        # `use_iter` was previously accepted but ignored; it is now honoured
        # the same way as in the sibling ACDC score_data of this file.
        if not use_iter:
            checkpoint_path = utils.get_latest_model_checkpoint_path(model_path, 'model_best_dice.ckpt')
        else:
            checkpoint_path = os.path.join(model_path, 'model.ckpt-%d' % use_iter)
        saver.restore(sess, checkpoint_path)

        # Checkpoint files are named '<prefix>-<iteration>'.
        init_iteration = int(checkpoint_path.split('/')[-1].split('-')[-1])

        total_time = 0
        total_volumes = 0

        for folder in os.listdir(input_folder):
            folder_path = os.path.join(input_folder, folder)
            if not os.path.isdir(folder_path):
                continue

            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'
            if train_test != 'test':
                continue

            # Parse the 'key: value' lines of Info.cfg; the context manager
            # makes sure the file handle is closed again.
            infos = {}
            with open(os.path.join(folder_path, 'Info.cfg')) as info_file:
                for line in info_file:
                    label, value = line.split(':')
                    infos[label] = value.rstrip('\n').lstrip(' ')

            patient_id = folder.lstrip('patient')
            ED_frame = int(infos['ED'])
            ES_frame = int(infos['ES'])

            for file in glob.glob(os.path.join(folder_path, 'patient???_frame??.nii.gz')):
                logging.info(' ----- Doing image: -------------------------')
                logging.info('Doing: %s' % file)
                logging.info(' --------------------------------------------')

                file_base = file.split('.nii.gz')[0]
                frame = int(file_base.split('frame')[-1])

                img_dat = utils.load_nii(file)
                img = img_dat[0].copy()

                print('img')
                print(img.shape)
                print(img.dtype)

                if gt_exists:
                    file_mask = file_base + '_gt.nii.gz'
                    mask_dat = utils.load_nii(file_mask)
                    mask = mask_dat[0]

                start_time = time.time()

                pixel_size = (img_dat[2].structarr['pixdim'][1],
                              img_dat[2].structarr['pixdim'][2])
                scale_vector = (pixel_size[0] / config.target_resolution[0],
                                pixel_size[1] / config.target_resolution[1])
                print('pixel_size', pixel_size)
                print('scale_vector', scale_vector)

                predictions = []
                mask_arr = []
                img_arr = []

                for zz in range(img.shape[2]):
                    slice_img = np.squeeze(img[:, :, zz])
                    slice_rescaled = transform.rescale(slice_img,
                                                       scale_vector,
                                                       order=1,
                                                       preserve_range=True,
                                                       multichannel=False,
                                                       anti_aliasing=True,
                                                       mode='constant')
                    print('slice_img', slice_img.shape)
                    print('slice_rescaled', slice_rescaled.shape)

                    slice_cropped = acdc_data.crop_or_pad_slice_to_size(slice_rescaled, nx, ny)
                    print('slice_cropped', slice_cropped.shape)
                    slice_cropped = np.float32(slice_cropped)
                    x = image_utils.reshape_2Dimage_to_tensor(slice_cropped)

                    if gt_exists:
                        # The mask is only available (and only needed) with
                        # ground truth; previously it was read unconditionally.
                        slice_mask = np.squeeze(mask[:, :, zz])
                        # Label masks must not be Gaussian-smoothed before
                        # nearest-neighbour sampling, otherwise label values
                        # get blended at the borders; hence anti_aliasing=False.
                        mask_rescaled = transform.rescale(slice_mask,
                                                          scale_vector,
                                                          order=0,
                                                          preserve_range=True,
                                                          multichannel=False,
                                                          anti_aliasing=False,
                                                          mode='constant')
                        mask_cropped = acdc_data.crop_or_pad_slice_to_size(mask_rescaled, nx, ny)
                        mask_cropped = np.asarray(mask_cropped, dtype=np.uint8)
                        y = image_utils.reshape_2Dimage_to_tensor(mask_cropped)
                        mask_arr.append(np.squeeze(y))

                    # GET PREDICTION
                    feed_dict = {
                        images_pl: x,
                    }
                    mask_out, logits_out = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)
                    slice_predictions = np.squeeze(logits_out[0, ...])

                    # RESCALING ON THE LOGITS
                    # The prediction stays at the network input size. This was
                    # previously only computed when gt_exists, leaving
                    # `prediction` undefined otherwise.
                    prediction = transform.resize(slice_predictions,
                                                  (nx, ny, num_channels),
                                                  order=1,
                                                  preserve_range=True,
                                                  anti_aliasing=True,
                                                  mode='constant')
                    prediction = np.uint8(np.argmax(prediction, axis=-1))

                    predictions.append(prediction)
                    img_arr.append(np.squeeze(x))

                # Stack the slices back into (x, y, z) volumes.
                prediction_arr = np.transpose(np.asarray(predictions, dtype=np.uint8), (1, 2, 0))
                img_arrs = np.transpose(np.asarray(img_arr, dtype=np.float32), (1, 2, 0))
                if gt_exists:
                    mask_arrs = np.transpose(np.asarray(mask_arr, dtype=np.uint8), (1, 2, 0))

                if do_postprocessing:
                    prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)

                elapsed_time = time.time() - start_time
                total_time += elapsed_time
                total_volumes += 1

                logging.info('Evaluation of volume took %f secs.' % elapsed_time)

                if frame == ED_frame:
                    frame_suffix = '_ED'
                elif frame == ES_frame:
                    frame_suffix = '_ES'
                else:
                    raise ValueError('Frame doesnt correspond to ED or ES. frame = %d, ED = %d, ES = %d' %
                                     (frame, ED_frame, ES_frame))

                # Save predicted mask
                out_file_name = os.path.join(output_folder, 'prediction',
                                             'patient' + patient_id + frame_suffix + '.nii.gz')
                if gt_exists:
                    out_affine = mask_dat[1]
                    out_header = mask_dat[2]
                else:
                    out_affine = img_dat[1]
                    out_header = img_dat[2]

                logging.info('saving to: %s' % out_file_name)
                utils.save_nii(out_file_name, prediction_arr, out_affine, out_header)

                # Save image data to the same folder for convenience
                image_file_name = os.path.join(output_folder, 'image',
                                               'patient' + patient_id + frame_suffix + '.nii.gz')
                logging.info('saving to: %s' % image_file_name)
                utils.save_nii(image_file_name, img_dat[0], out_affine, out_header)

                if gt_exists:
                    # Save GT image
                    gt_file_name = os.path.join(output_folder, 'ground_truth',
                                                'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % gt_file_name)
                    utils.save_nii(gt_file_name, mask_arrs, out_affine, out_header)

                    # Save difference mask between predictions and ground truth
                    difference_mask = np.where(np.abs(prediction_arr - mask_arrs) > 0, [1], [0])
                    difference_mask = np.asarray(difference_mask, dtype=np.uint8)

                    # Show image, ground truth and prediction of every slice.
                    # Each plt.show() blocks until the window is closed.
                    for zz in range(difference_mask.shape[2]):
                        plt.imshow(img_arrs[:, :, zz])
                        plt.gray()
                        plt.axis('off')
                        plt.show()
                        plt.imshow(mask_arrs[:, :, zz])
                        plt.gray()
                        plt.axis('off')
                        plt.show()
                        plt.imshow(prediction_arr[:, :, zz])
                        plt.gray()
                        plt.axis('off')
                        plt.show()
                        print('...')

                    diff_file_name = os.path.join(output_folder, 'difference',
                                                  'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % diff_file_name)
                    utils.save_nii(diff_file_name, difference_mask, out_affine, out_header)

        # Guard against ZeroDivisionError when no volume was evaluated.
        if total_volumes:
            logging.info('Average time per volume: %f' % (total_time / total_volumes))
        else:
            logging.info('No volumes were evaluated.')

    return init_iteration
def main(ws_exp_config, slices, test):
    """Export images, ground truths and weak-supervision predictions for the given slices.

    For every recursion up to the most recent one, the checkpoint is restored
    and three variants of the prediction are stored: the raw network mask, the
    mask after ``image_utils.keep_largest_connected_components``, and that
    mask after ``segment``. All of them are written to ``OUTPUT_FOLDER`` with
    ``print_grayscale`` / ``print_coloured``, and the Dice coefficient of the
    last variant against the ground truth is printed. Recursions that cannot
    be restored or classified are skipped.

    :param ws_exp_config: Weak-supervision experiment configuration;
        ``experiment_name``, ``image_size``, ``model_handle``, ``nlabels`` and
        ``rw_beta`` are read.
    :param slices: Indices of the slices to export. Indices beyond the size
        of the selected data set are dropped.
    :param test: If true, read the test images and masks; otherwise the
        training images, masks and scribbles (the scribbles are exported too).
    """
    exp_dir = sys_config.project_root + 'acdc_logdir/' + ws_exp_config.experiment_name + '/'

    # np.asarray allows plain sequences as well; the boolean filtering below
    # needs an array.
    slices = np.asarray(slices)
    scr = None

    # Load data. Indexing an h5py dataset yields in-memory arrays, so the
    # file can be closed as soon as the slices have been read.
    with h5py.File(os.path.join(exp_dir, 'base_data.hdf5'), 'r') as base_data:
        if test:
            slices = slices[slices < len(base_data['images_test'])]
            images = base_data['images_test'][slices, ...]
            gt = base_data['masks_test'][slices, ...]
            prefix = 'test'
        else:
            slices = slices[slices < len(base_data['images_train'])]
            images = base_data['images_train'][slices, ...]
            gt = base_data['masks_train'][slices, ...]
            scr = base_data['scribbles_train'][slices, ...]
            prefix = 'train'

    # The batch size must reflect the slices that survived the range filter;
    # it was previously computed before filtering.
    batch_size = len(slices)

    # Get number of recursions
    num_recursions = acdc_data.most_recent_recursion(
        sys_config.project_root + 'acdc_logdir/' + ws_exp_config.experiment_name)
    print(num_recursions)
    num_recursions += 1

    image_tensor_shape = [batch_size] + list(ws_exp_config.image_size) + [1]
    images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
    feed_dict = {
        images_pl: np.expand_dims(images, -1),
    }

    # Get weak supervision predictions
    mask_pl, softmax_pl = model.predict(images_pl,
                                        ws_exp_config.model_handle,
                                        ws_exp_config.nlabels)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    predictions = np.zeros([batch_size] + list(ws_exp_config.image_size) + [num_recursions])
    predictions_klc = np.zeros_like(predictions)
    predictions_rw = np.zeros_like(predictions)

    with tf.Session() as sess:
        sess.run(init)
        for recursion in range(num_recursions):
            try:
                # Checkpoint preference: best cross-entropy, then best Dice,
                # then the regular checkpoint.
                try:
                    checkpoint_path = utils.get_latest_model_checkpoint_path(
                        ws_model_path,
                        'recursion_{}_model_best_xent.ckpt'.format(recursion))
                except Exception:
                    try:
                        checkpoint_path = utils.get_latest_model_checkpoint_path(
                            ws_model_path,
                            'recursion_{}_model_best_dice.ckpt'.format(recursion))
                    except Exception:
                        checkpoint_path = utils.get_latest_model_checkpoint_path(
                            ws_model_path,
                            'recursion_{}_model.ckpt'.format(recursion))
                saver.restore(sess, checkpoint_path)
                mask_out, _ = sess.run([mask_pl, softmax_pl], feed_dict=feed_dict)
                predictions[..., recursion] = mask_out
                for i in range(batch_size):
                    predictions_klc[i, :, :, recursion] = \
                        image_utils.keep_largest_connected_components(mask_out[i, ...])
                predictions_rw[..., recursion] = segment(
                    images,
                    np.squeeze(predictions_klc[..., recursion]),
                    beta=ws_exp_config.rw_beta,
                    threshold=0)

                print("Classified for recursion {}".format(recursion))
            except Exception:
                # Mark the recursion as unavailable with -1. The previous
                # `-1 * np.zeros_like(...)` produced zeros, so the skip check
                # below never triggered.
                predictions[..., recursion] = -1
                print("Could not find checkpoint for recursion {} - skipping".
                      format(recursion))

    for i in range(batch_size):
        pref = '{}{}'.format(prefix, slices[i])

        print_grayscale(images[i, ...],
                        filepath=OUTPUT_FOLDER,
                        filename='{}_image'.format(pref))
        print_coloured(gt[i, ...],
                       filepath=OUTPUT_FOLDER,
                       filename='{}_gt'.format(pref))

        for recursion in range(num_recursions):
            # Skip recursions flagged with the -1 sentinel above.
            if np.max(predictions[i, :, :, recursion]) >= -0.5:
                print_coloured(predictions[i, :, :, recursion],
                               filepath=OUTPUT_FOLDER,
                               filename="{}_ws_pred_r{}".format(pref, recursion))
                print_coloured(predictions_klc[i, :, :, recursion],
                               filepath=OUTPUT_FOLDER,
                               filename="{}_ws_pred_klc_r{}".format(pref, recursion))
                print_coloured(predictions_rw[i, :, :, recursion],
                               filepath=OUTPUT_FOLDER,
                               filename="{}_ws_pred_klc_rw_r{}".format(pref, recursion))
                print("Dice coefficient for slice {} is {}".format(
                    slices[i],
                    dice(predictions_rw[i, :, :, recursion], gt[i, ...])))

        if not test:
            print_coloured(scr[i, ...],
                           filepath=OUTPUT_FOLDER,
                           filename='{}_scribble'.format(pref))