def _multi_view_predict_on(image_pair, image_pair_loader, model, views,
                           hparams, results, per_view_results, out_dir, args):
    """
    Predict on a single image along every view and map each view's
    prediction back onto the image's own voxel grid.

    Args:
        image_pair:        Image to predict on (.id and .image are used;
                           .labels is forwarded for evaluation)
        image_pair_loader: Loader that provides the sequencer. Its .images
                           attribute is set to [image_pair] (side effect
                           visible to the caller)
        model:             Model passed to predict_volume
        views:             Views to sample planes along
        hparams:           Project hyperparameters. The "fit" and "build"
                           sections are merged into the sequencer kwargs and
                           n_classes is read from "build". hparams itself is
                           not modified.
        results, per_view_results, out_dir:
                           Forwarded to _per_view_evaluation
        args:              Namespace-like object; per-view evaluation runs
                           unless args.no_eval is truthy (args is forwarded)

    Returns:
        float32 array of shape (len(views), d0, d1, d2, n_classes), where
        (d0, d1, d2) = image_pair.image.shape[:-1], holding one mapped
        prediction per view.
    """
    from MultiPlanarUNet.utils.fusion import predict_volume, map_real_space_pred
    from MultiPlanarUNet.interpolation.sample_grid import get_voxel_grid_real_space

    # Set image_pair_loader object with only the given file
    image_pair_loader.images = [image_pair]
    n_classes = hparams["build"]["n_classes"]

    # Sequencer kwargs: merge into a *copy* so the caller's hparams["fit"]
    # is not silently extended with the "build" keys
    kwargs = dict(hparams["fit"])
    kwargs.update(hparams["build"])
    seq = image_pair_loader.get_sequencer(views=views, **kwargs)

    # Get voxel grid in real space
    voxel_grid_real_space = get_voxel_grid_real_space(image_pair)

    # Prepare tensor to store combined prediction (every view slot is
    # overwritten below, so uninitialized memory is fine here)
    d = image_pair.image.shape[:-1]
    combined = np.empty(shape=(len(views), d[0], d[1], d[2], n_classes),
                        dtype=np.float32)
    print("Predicting on brain hyper-volume of shape:", combined.shape)

    # Predict for each view
    for n_view, view in enumerate(views):
        print("\n[*] (%i/%i) View: %s" % (n_view + 1, len(views), view))
        # For each view, predict on all voxels and map the predictions
        # back into the original coordinate system

        # Sample planes from the image at grid_real_space grid
        # in real space (scanner RAS) coordinates.
        X, y, grid, inv_basis = seq.get_view_from(image_pair.id, view,
                                                  n_planes="same+20")

        # Predict on volume using model
        pred = predict_volume(model, X, axis=2, batch_size=seq.batch_size)

        # Map the real space coordinate predictions to nearest
        # real space coordinates defined on voxel grid
        mapped_pred = map_real_space_pred(pred, grid, inv_basis,
                                          voxel_grid_real_space,
                                          method="nearest")
        combined[n_view] = mapped_pred

        if not args.no_eval:
            _per_view_evaluation(image_id=image_pair.id,
                                 pred=pred,
                                 true=y,
                                 mapped_pred=mapped_pred,
                                 mapped_true=image_pair.labels,
                                 view=view,
                                 n_classes=n_classes,
                                 results=results,
                                 per_view_results=per_view_results,
                                 out_dir=out_dir,
                                 args=args)
    return combined
def predict_and_map(model, seq, image, view, batch_size=None,
                    voxel_grid_real_space=None, targets=None, eval_prob=1.0,
                    n_planes='same+20', torch=False):
    """
    Predict on one view of an image and map the result onto its voxel grid.

    Args:
        model:                 Model passed to predict_volume
        seq:                   Sequencer providing get_view_from() and
                               batch_size
        image:                 Image to predict on (.id is used; the object
                               itself is used to compute the voxel grid when
                               voxel_grid_real_space is None)
        view:                  View to sample planes along
        batch_size:            Prediction batch size (seq.batch_size if None)
        voxel_grid_real_space: Real-space voxel grid to map onto; computed
                               from `image` when None
        targets:               Ground-truth labels. When given, dice scores
                               are printed with probability `eval_prob`
        eval_prob:             Probability of running the evaluation when
                               targets are given
        n_planes:              Forwarded to seq.get_view_from
        torch:                 If True, use the PyTorch predict_volume
                               implementation instead of the default one

    Returns:
        The output of map_real_space_pred for this view.
    """
    # Sample planes along `view` in real space (scanner RAS) coordinates
    X, y, grid, inv_basis = seq.get_view_from(image.id, view,
                                              n_planes=n_planes)

    if batch_size is None:
        batch_size = seq.batch_size

    # Select the backend-specific predict_volume, then predict once
    if torch:
        from MultiPlanarUNet.torch.utils.fusion import predict_volume
    else:
        from MultiPlanarUNet.utils.fusion import predict_volume
    pred = predict_volume(model, X, axis=2, batch_size=batch_size)

    # Lazily build the voxel grid if the caller did not supply one
    if voxel_grid_real_space is None:
        from MultiPlanarUNet.interpolation.sample_grid import \
            get_voxel_grid_real_space
        voxel_grid_real_space = get_voxel_grid_real_space(image)

    # Map the predicted volume to real space
    mapped = map_real_space_pred(pred, grid, inv_basis, voxel_grid_real_space)

    # Evaluation is optional and randomly subsampled via eval_prob
    should_eval = targets is not None and np.random.rand(1)[0] <= eval_prob
    if not should_eval:
        print("-- Skipping evaluation")
        return mapped

    print("Computing evaluations...")
    view_scores = dice_all(y, pred.argmax(-1), ignore_zero=False)
    print("View dice scores: ", view_scores)
    mapped_scores = dice_all(targets, mapped.argmax(-1).reshape(-1, 1),
                             ignore_zero=False)
    print("Mapped dice scores: ", mapped_scores)
    return mapped
def _run_fusion_training(sets, logger, hparams, min_val_images, is_validation,
                         views, n_classes, unet, fusion_model_org,
                         fusion_model, early_stopping, fm_batch_size, epochs,
                         eval_prob, fusion_weights_path):
    """
    Train the fusion model on per-view UNet predictions, one image set at a
    time.

    For every set in `sets`, each member image is predicted along all `views`
    with `unet`. The mapped per-voxel predictions become the fusion model's
    inputs and the voxel labels its targets. After shuffling, the first 20%
    of the points are held out for the validation-dice callback and the
    fusion model is fitted on the remainder. The weights of
    `fusion_model_org` are saved to `fusion_weights_path` after every set.

    Args:
        sets:                Sequence of collections of image IDs; one
                             training round per set
        logger:              Callable used for logging (called with a string)
        hparams:             Project hyperparameters. "val_data",
                             "train_data", "fit" and "build" are read; hparams
                             is not modified
        min_val_images:      If fewer validation images than this are found,
                             the training images are added as well
        is_validation:       Mapping image ID -> bool (only used in log output)
        views:               Views to predict along
        n_classes:           Number of output classes of `unet`
        unet:                Model producing the per-view predictions
        fusion_model_org:    Model whose last layer's weights are printed and
                             whose weights are saved
        fusion_model:        Model that is fitted (presumably
                             fusion_model_org or a wrapper around it; confirm)
        early_stopping:      Patience for the EarlyStopping callback
        fm_batch_size:       Batch size for fusion_model.fit
        epochs:              Number of epochs for fusion_model.fit
        eval_prob:           Probability of printing per-view dice scores
        fusion_weights_path: Destination of fusion_model_org.save_weights
    """
    for _round, _set in enumerate(sets):
        s = "Set %i/%i:\n%s" % (_round + 1, len(sets), _set)
        logger("\n%s" % highlighted(s))

        X, y = _collect_fusion_points(_set, logger, hparams, min_val_images,
                                      is_validation, views, n_classes, unet,
                                      eval_prob)
        _fit_fusion_model(X, y, logger, n_classes, fusion_model_org,
                          fusion_model, early_stopping, fm_batch_size, epochs,
                          fusion_weights_path)


def _collect_fusion_points(_set, logger, hparams, min_val_images,
                           is_validation, views, n_classes, unet, eval_prob):
    """Predict every image of `_set` along all views; return stacked (X, y)."""
    # Reload data
    images = ImagePairLoader(**hparams["val_data"])
    if len(images) < min_val_images:
        images.add_images(ImagePairLoader(**hparams["train_data"]))

    # ImagePair objects belonging to the current set
    image_set_dict = {m.id: m for m in images if m.id in _set}

    # Sequencer kwargs: a merged *copy*, so hparams["fit"] is not mutated.
    # hparams is not changed below, so this is built once for all images.
    seq_kwargs = dict(hparams["fit"])
    seq_kwargs.update(hparams["build"])

    points_collection = []
    targets_collection = []
    N_im = len(image_set_dict)
    for num_im, image_id in enumerate(list(image_set_dict)):
        logger("")
        logger(highlighted("(%i/%i) Running on %s (%s)"
                           % (num_im + 1, N_im, image_id,
                              "val" if is_validation[image_id] else "train")))

        # Take the current ImagePair out of the dict (frees the dict's
        # reference early) and make it the loader's only image
        image = image_set_dict.pop(image_id)
        images.images = [image]
        seq = images.get_sequencer(views=views, **seq_kwargs)

        # Get voxel grid in real space
        voxel_grid_real_space = get_voxel_grid_real_space(image)

        # One row per voxel, one column per view, n_classes scores each
        targets = image.labels.reshape(-1, 1)
        points = np.full((len(targets), len(views), n_classes), np.nan,
                         dtype=np.float32)

        # Predict on all views
        for k, v in enumerate(views):
            print("\nView: %s" % v)
            points[:, k, :] = predict_and_map(
                model=unet,
                seq=seq,
                image=image,
                view=v,
                voxel_grid_real_space=voxel_grid_real_space,
                n_planes='same+20',
                targets=targets,
                eval_prob=eval_prob,
            ).reshape(-1, n_classes)

        del image  # Should be GC at this point anyway

        points_collection.append(points)
        targets_collection.append(targets)

    # Stack points into one matrix
    logger("Stacking points...")
    return stack_collections(points_collection, targets_collection)


def _fit_fusion_model(X, y, logger, n_classes, fusion_model_org, fusion_model,
                      early_stopping, fm_batch_size, epochs,
                      fusion_weights_path):
    """Shuffle, split off 20% validation, fit the fusion model, save weights."""
    # Shuffle train
    print("Shuffling points...")
    X, y = shuffle(X, y)

    print("Getting validation set...")
    val_ind = int(0.20 * X.shape[0])
    X_val, y_val = X[:val_ind], y[:val_ind]
    X, y = X[val_ind:], y[val_ind:]

    # Callbacks (constructed in the same order as before)
    cbs = [
        ValDiceScores((X_val, y_val), n_classes, 50000, logger),
        CSVLogger(filename="logs/fusion_training.csv",
                  separator=",", append=True),
        PrintLayerWeights(fusion_model_org.layers[-1],
                          every=1, first=1000, per_epoch=True, logger=logger),
        EarlyStopping(monitor='val_dice', min_delta=0.0,
                      patience=early_stopping, verbose=1, mode='max'),
    ]

    # Start training
    try:
        fusion_model.fit(X, y, batch_size=fm_batch_size, epochs=epochs,
                         callbacks=cbs, verbose=1)
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C stops training early; weights are still saved
        pass
    fusion_model_org.save_weights(fusion_weights_path)
def predict_single(image, model, hparams, verbose=1):
    """
    A generic prediction function that sets up an ImagePairLoader object for
    the given image, prepares the image and predicts.

    Note that this function should only be used for convenience in scripts
    that work on single images at a time anyway, as batch-preparing the
    entire ImagePairLoader object prior to prediction is faster.

    NOTE: Only works with iso_live intrp modes at this time

    Args:
        image:   Image to predict on (added to a single-file ImagePairLoader;
                 .id and .image are also used directly)
        model:   Model passed to predict_volume / pred_3D_iso
        hparams: Project hyperparameters. hparams["fit"]["intrp_style"]
                 selects the mode, and "fit" and "build" are merged into the
                 sequencer kwargs. hparams is not modified.
        verbose: Passed as 'verbose' to the sequencer and, via bool(verbose),
                 as 'no_log' to ImagePairLoader

    Returns:
        'iso_live':    float32 array of shape (n_views, d0, d1, d2, n_classes)
                       holding one mapped prediction per view
        'iso_live_3d': the return value of pred_3D_iso

    Raises:
        AssertionError: if intrp_style is not 'iso_live' or 'iso_live_3d'
    """
    mode = hparams["fit"]["intrp_style"].lower()
    assert mode in ("iso_live", "iso_live_3d"), \
        "predict_single does not support intrp_style '%s'" % mode

    # Prepare sequencer kwargs on a *copy*: the previous in-place update
    # leaked the "build" keys and "views" into hparams["fit"], and the
    # verbose save/restore was skipped whenever prediction raised.
    kwargs = dict(hparams["fit"])
    kwargs.update(hparams["build"])
    kwargs["verbose"] = verbose

    # Create a ImagePairLoader with only the given file
    from MultiPlanarUNet.image import ImagePairLoader
    # NOTE(review): no_log=bool(verbose) disables the loader's logging exactly
    # when verbose is set, which looks inverted. Kept unchanged; confirm intent.
    image_pair_loader = ImagePairLoader(predict_mode=True,
                                        single_file_mode=True,
                                        no_log=bool(verbose))
    image_pair_loader.add_image(image)

    # Get N classes
    n_classes = kwargs["n_classes"]

    if mode != "iso_live":
        # 'iso_live_3d' is delegated to pred_3D_iso
        return pred_3D_iso(model=model,
                           sequence=image_pair_loader.get_sequencer(**kwargs),
                           image=image, extra_boxes="3x", min_coverage=None)

    # Add views if SMMV model; the context manager closes the .npz file
    with np.load(hparams.project_path + "/views.npz") as views_file:
        views = views_file["arr_0"]
    kwargs["views"] = views

    # Get sequence object
    sequence = image_pair_loader.get_sequencer(**kwargs)

    # Get voxel grid in real space
    voxel_grid_real_space = get_voxel_grid_real_space(image)

    # Prepare tensor to store combined prediction (each view slot is
    # overwritten below)
    d = image.image.shape
    predicted = np.empty(shape=(len(views), d[0], d[1], d[2], n_classes),
                         dtype=np.float32)
    print("Predicting on brain hyper-volume of shape:", predicted.shape)

    for n_view, v in enumerate(views):
        print("\nView %i/%i: %s" % (n_view+1, len(views), v))
        # Sample the volume along the view
        X, _, grid, inv_basis = sequence.get_view_from(image.id, v,
                                                       n_planes="same+20")

        # Predict on volume using model
        pred = predict_volume(model, X, axis=2)

        # Map the real space coordinate predictions to nearest
        # real space coordinates defined on voxel grid
        predicted[n_view] = map_real_space_pred(pred, grid, inv_basis,
                                                voxel_grid_real_space,
                                                method="nearest")
    return predicted
def entry_func(args=None): # Get command line arguments args = vars(get_argparser().parse_args(args)) base_dir = os.path.abspath(args["project_dir"]) analytical = args["analytical"] majority = args["majority"] _file = args["f"] label = args["l"] await_PID = args["wait_for"] eval_prob = args["eval_prob"] _continue = args["continue"] if analytical and majority: raise ValueError("Cannot specify both --analytical and --majority.") # Get settings from YAML file from MultiPlanarUNet.train.hparams import YAMLHParams hparams = YAMLHParams(os.path.join(base_dir, "train_hparams.yaml")) if not _file: try: # Data specified from command line? data_dir = os.path.abspath(args["data_dir"]) # Set with default sub dirs hparams["test_data"] = { "base_dir": data_dir, "img_subdir": "images", "label_subdir": "labels" } except (AttributeError, TypeError): data_dir = hparams["test_data"]["base_dir"] else: data_dir = False out_dir = os.path.abspath(args["out_dir"]) overwrite = args["overwrite"] predict_mode = args["no_eval"] save_input_files = args["save_input_files"] no_argmax = args["no_argmax"] on_val = args["on_val"] # Check if valid dir structures validate_folders(base_dir, out_dir, overwrite, _continue) # Import all needed modules (folder is valid at this point) import numpy as np from MultiPlanarUNet.image import ImagePairLoader, ImagePair from MultiPlanarUNet.models import FusionModel from MultiPlanarUNet.models.model_init import init_model from MultiPlanarUNet.utils import await_and_set_free_gpu, get_best_model, \ create_folders, pred_to_class, set_gpu from MultiPlanarUNet.logging import init_result_dicts, save_all, load_result_dicts from MultiPlanarUNet.evaluate import dice_all from MultiPlanarUNet.utils.fusion import predict_volume, map_real_space_pred from MultiPlanarUNet.interpolation.sample_grid import get_voxel_grid_real_space # Wait for PID? 
if await_PID: from MultiPlanarUNet.utils import await_PIDs await_PIDs(await_PID) # Set GPU device # Fetch GPU(s) num_GPUs = args["num_GPUs"] force_gpu = args["force_GPU"] # Wait for free GPU if not force_gpu: await_and_set_free_gpu(N=num_GPUs, sleep_seconds=120) num_GPUs = 1 else: set_gpu(force_gpu) num_GPUs = len(force_gpu.split(",")) # Read settings from the project hyperparameter file n_classes = hparams["build"]["n_classes"] # Get views views = np.load("%s/views.npz" % base_dir)["arr_0"] # Force settings hparams["fit"]["max_background"] = 1 hparams["fit"]["test_mode"] = True hparams["fit"]["mix_planes"] = False hparams["fit"]["live_intrp"] = False if "use_bounds" in hparams["fit"]: del hparams["fit"]["use_bounds"] del hparams["fit"]["views"] if hparams["build"]["out_activation"] == "linear": # Trained with logit targets? hparams["build"][ "out_activation"] = "softmax" if n_classes > 1 else "sigmoid" # Set ImagePairLoader object if not _file: data = "test_data" if not on_val else "val_data" image_pair_loader = ImagePairLoader(predict_mode=predict_mode, **hparams[data]) else: predict_mode = not bool(label) image_pair_loader = ImagePairLoader(predict_mode=predict_mode, single_file_mode=True) image_pair_loader.add_image(ImagePair(_file, label)) # Put them into a dict and remove from image_pair_loader to gain more control with # garbage collection all_images = {image.id: image for image in image_pair_loader.images} image_pair_loader.images = None if _continue: all_images = remove_already_predicted(all_images, out_dir) # Evaluate? 
if not predict_mode: if _continue: csv_dir = os.path.join(out_dir, "csv") results, detailed_res = load_result_dicts(csv_dir=csv_dir, views=views) else: # Prepare dictionary to store results in pd df results, detailed_res = init_result_dicts(views, all_images, n_classes) # Save to check correct format save_all(results, detailed_res, out_dir) # Define result paths nii_res_dir = os.path.join(out_dir, "nii_files") create_folders(nii_res_dir) """ Define UNet model """ model_path = get_best_model(base_dir + "/model") unet = init_model(hparams["build"]) unet.load_weights(model_path, by_name=True) if num_GPUs > 1: from tensorflow.keras.utils import multi_gpu_model n_classes = unet.n_classes unet = multi_gpu_model(unet, gpus=num_GPUs) unet.n_classes = n_classes weights_name = os.path.splitext(os.path.split(model_path)[1])[0] if not analytical and not majority: # Get Fusion model fm = FusionModel(n_inputs=len(views), n_classes=n_classes) weights = base_dir + "/model/fusion_weights/%s_fusion_weights.h5" % weights_name print("\n[*] Loading weights:\n", weights) # Load fusion weights fm.load_weights(weights) print("\nLoaded weights:\n\n%s\n%s\n---" % tuple(fm.layers[-1].get_weights())) # Multi-gpu? 
if num_GPUs > 1: print("Using multi-GPU model (%i GPUs)" % num_GPUs) fm = multi_gpu_model(fm, gpus=num_GPUs) """ Finally predict on the images """ image_ids = sorted(all_images) N_images = len(image_ids) for n_image, image_id in enumerate(image_ids): print("\n[*] (%i/%s) Running on: %s" % (n_image + 1, N_images, image_id)) # Set image_pair_loader object with only the given file image = all_images[image_id] image_pair_loader.images = [image] # Load views kwargs = hparams["fit"] kwargs.update(hparams["build"]) seq = image_pair_loader.get_sequencer(views=views, **kwargs) # Get voxel grid in real space voxel_grid_real_space = get_voxel_grid_real_space(image) # Prepare tensor to store combined prediction d = image.image.shape[:-1] if not majority: combined = np.empty(shape=(len(views), d[0], d[1], d[2], n_classes), dtype=np.float32) else: combined = np.empty(shape=(d[0], d[1], d[2], n_classes), dtype=np.float32) print("Predicting on brain hyper-volume of shape:", combined.shape) # Predict for each view for n_view, v in enumerate(views): print("\n[*] (%i/%i) View: %s" % (n_view + 1, len(views), v)) # for each view, predict on all voxels and map the predictions # back into the original coordinate system # Sample planes from the image at grid_real_space grid # in real space (scanner RAS) coordinates. X, y, grid, inv_basis = seq.get_view_from(image.id, v, n_planes="same+20") # Predict on volume using model pred = predict_volume(unet, X, axis=2, batch_size=seq.batch_size) # Map the real space coordiante predictions to nearest # real space coordinates defined on voxel grid mapped_pred = map_real_space_pred(pred, grid, inv_basis, voxel_grid_real_space, method="nearest") if not majority: combined[n_view] = mapped_pred else: combined += mapped_pred if n_classes == 1: # Set to background if outside pred domain combined[n_view][np.isnan(combined[n_view])] = 0. 
if not predict_mode and np.random.rand() <= eval_prob: view_dices = dice_all(y, pred_to_class(pred, img_dims=3, has_batch_dim=False), ignore_zero=False, n_classes=n_classes, skip_if_no_y=False) mapped_dices = dice_all(image.labels, pred_to_class(mapped_pred, img_dims=3, has_batch_dim=False), ignore_zero=False, n_classes=n_classes, skip_if_no_y=False) mean_dice = mapped_dices[~np.isnan(mapped_dices)][1:].mean() # Print dice scores print("View dice scores: ", view_dices) print("Mapped dice scores: ", mapped_dices) print("Mean dice (n=%i): " % (len(mapped_dices) - 1), mean_dice) # Add to results results.loc[image_id, str(v)] = mean_dice detailed_res[str(v)][image_id] = mapped_dices[1:] # Overwrite with so-far results save_all(results, detailed_res, out_dir) else: print("Skipping evaluation for this view... " "(eval_prob=%.3f, predict_mode=%s)" % (eval_prob, predict_mode)) if not analytical and not majority: # Combine predictions across views using Fusion model print("\nFusing views...") combined = np.moveaxis(combined, 0, -2).reshape( (-1, len(views), n_classes)) combined = fm.predict(combined, batch_size=10**4, verbose=1).reshape( (d[0], d[1], d[2], n_classes)) elif analytical: print("\nFusing views (analytical)...") combined = np.sum(combined, axis=0) if not no_argmax: print("\nComputing majority vote...") combined = pred_to_class(combined.squeeze(), img_dims=3).astype(np.uint8) if not predict_mode: if no_argmax: # MAP only for dice calculation c_temp = pred_to_class(combined, img_dims=3).astype(np.uint8) else: c_temp = combined # Calculate combined prediction dice dices = dice_all(image.labels, c_temp, n_classes=n_classes, ignore_zero=True, skip_if_no_y=False) mean_dice = dices[~np.isnan(dices)].mean() detailed_res["MJ"][image_id] = dices print("Combined dices: ", dices) print("Combined mean dice: ", mean_dice) results.loc[image_id, "MJ"] = mean_dice # Overwrite with so-far results save_all(results, detailed_res, out_dir) # Save combined prediction volume as .nii 
file print("Saving .nii files...") save_nii_files(combined, image, nii_res_dir, save_input_files) # Remove image from dictionary and image_pair_loader to free memory del all_images[image_id] image_pair_loader.images.remove(image) if not predict_mode: # Write final results save_all(results, detailed_res, out_dir)