Beispiel #1
0
def main():
    """Split multi-channel one-hot prediction volumes into per-task label maps.

    For every image in ``namespace.prediction_dir`` the channel axis is split
    into the two label groups from the config, paired left/right labels are
    separated, the data is resampled back onto the original image grid, and a
    single-label map per group is written to ``namespace.output_dir`` as
    ``<basename>_pred1.<ext>`` / ``<basename>_pred2.<ext>``.
    """
    namespace = parse_args()
    os.makedirs(namespace.output_dir, exist_ok=True)
    config = load_json(namespace.config_filename)
    labels1, labels2 = config["sequence_kwargs"]["labels"]
    for fn in glob.glob(os.path.join(namespace.prediction_dir, "*")):
        print(fn)
        bn = os.path.basename(fn)
        image = nib.load(fn)
        _image = reorder_img(image)
        data = _image.get_fdata()
        # Channels are concatenated: the first len(labels1) belong to group 1.
        data1 = data[..., :len(labels1)]
        data2 = data[..., len(labels1):]
        for i, (l, d) in enumerate(((labels1, data1), (labels2, data2))):
            volumes = list()
            labels = list()
            for ii, label in enumerate(l):
                # A two-element label list marks a bilateral structure whose
                # channel must be split into left/right volumes.
                if isinstance(label, list) and len(label) == 2:
                    volumes.extend(split_left_right(d[..., ii]))
                    labels.extend(label)
                else:
                    volumes.append(d[..., ii])
                    labels.append(label)
            # Resample from the reordered orientation back onto the original
            # image grid before collapsing one-hot channels to label values.
            fixed_data = resample_to_img(_image.__class__(
                dataobj=np.stack(volumes, axis=-1), affine=_image.affine),
                                         target_img=image).get_fdata()
            label_map = convert_one_hot_to_single_label_map_volume(
                fixed_data, labels, dtype=np.uint8)
            out_image = image.__class__(dataobj=label_map, affine=image.affine)
            # Insert the group suffix into the basename only (not the full
            # path), so a dot in the output directory name cannot corrupt
            # the output filename.
            out_name = bn.replace(".", "_pred{}.".format(i + 1), 1)
            out_image.to_filename(os.path.join(namespace.output_dir, out_name))
Beispiel #2
0
def main():
    """Score each prediction file against its target volume and write a CSV.

    Subject ids are parsed from the prediction basenames; each matching
    target filename is built from the config's first target template.
    Scores are computed in parallel and written as a labels-by-subject CSV.
    """
    args = parse_args()
    config = load_json(args.config_filename)
    labels = config["labels"]
    labels = [1] if labels is None else labels
    filenames = get_filenames(args)
    subject_ids = [os.path.basename(fn).split("_")[0] for fn in filenames]
    gen_kwargs = config["generate_filenames_kwargs"]
    template = gen_kwargs["target_templates"][0]
    target_filenames = [
        os.path.join(gen_kwargs["directory"], template.format(subject=sid))
        for sid in subject_ids
    ]

    score_func = partial(_evaluate_filenames, labels=labels)

    with Pool(args.n_threads) as pool:
        scores = pool.map(score_func, zip(filenames, target_filenames))

    df = pd.DataFrame(scores, columns=labels, index=subject_ids)
    df.to_csv(args.output_filename)
Beispiel #3
0
def get_metric_data_from_config(metric_filenames,
                                config_filename,
                                subject_id=100206):
    """Load metric file(s) and return the flattened metric data array.

    Parameters
    ----------
    metric_filenames : str or sequence of str
        One metric filename or a sequence of them; a single string is
        wrapped in a list before loading.
    config_filename : str
        JSON config providing "metric_names" and "surface_names".
    subject_id : int, optional
        Subject whose metric columns are extracted (default 100206).

    Returns
    -------
    numpy.ndarray
        Transposed, raveled metric data.
    """
    config = load_json(config_filename)
    # isinstance (rather than a type equality check) also handles str
    # subclasses correctly.
    if isinstance(metric_filenames, str):
        metrics = nib_load_files([metric_filenames])
    else:
        metrics = nib_load_files(metric_filenames)
    metric_data = get_metric_data(metrics, config["metric_names"],
                                  config["surface_names"],
                                  subject_id).T.ravel()
    return metric_data
Beispiel #4
0
def get_machine_config(namespace):
    """Return the machine/runtime configuration as a dict.

    Loads the JSON machine config when a filename was supplied on the
    command line; otherwise the config dict is built directly from the
    parsed command-line arguments.
    """
    config_filename = namespace.machine_config_filename
    if config_filename:
        print("MP Config: ", config_filename)
        return load_json(config_filename)
    return {
        "n_workers": namespace.nthreads,
        "n_gpus": namespace.ngpus,
        # Worker multiprocessing only pays off with more than one thread.
        "use_multiprocessing": namespace.nthreads > 1,
        "pin_memory": namespace.pin_memory,
        "directory": namespace.directory,
    }
Beispiel #5
0
def load_subject_ids(config):
    """Merge subject-id entries from the configured subjects file into config.

    No-op when the config has no "subjects_filename" key; otherwise every
    key/value pair of the loaded JSON is copied into ``config`` in place.
    """
    if "subjects_filename" not in config:
        return
    subjects_path = os.path.join(unet3d_path, config["subjects_filename"])
    config.update(load_json(subjects_path))
Beispiel #6
0
def main():
    """Configure and launch model training (and optionally inference).

    Reads the JSON config named on the command line, resolves the machine
    configuration, picks a sequence/dataset class, then calls
    ``run_training`` — once per parcel for ``ParcelBasedSequence``,
    otherwise a single run.  When the sub-command is "predict", inference
    is run afterwards via ``run_inference``.
    """
    import nibabel as nib
    # Silence nibabel's image-loading chatter (40 == logging.ERROR).
    nib.imageglobals.logger.level = 40

    namespace = parse_args()

    print("Config: ", namespace.config_filename)
    config = load_json(namespace.config_filename)

    # Training backend; defaults to keras when the config does not say.
    if "package" in config:
        package = config["package"]
    else:
        package = "keras"

    # n_outputs must match the number of configured metrics, when present.
    if "metric_names" in config and not config["n_outputs"] == len(config["metric_names"]):
        raise ValueError("n_outputs set to {}, but number of metrics is {}.".format(config["n_outputs"],
                                                                                    len(config["metric_names"])))

    print("Model: ", namespace.model_filename)
    print("Log: ", namespace.training_log_filename)
    system_config = get_machine_config(namespace)

    # Optionally shrink the model config to fit the requested GPU memory;
    # the adjusted config is written alongside the original as *_auto.json.
    if namespace.fit_gpu_mem and namespace.fit_gpu_mem > 0:
        update_config_to_fit_gpu_memory(config=config, n_gpus=system_config["n_gpus"], gpu_memory=namespace.fit_gpu_mem,
                                        output_filename=namespace.config_filename.replace(".json", "_auto.json"))

    # When group-average files are given, monitor a compare-to-average
    # metric; otherwise monitor (validation) loss.
    if namespace.group_average_filenames is not None:
        group_average = get_metric_data_from_config(namespace.group_average_filenames, namespace.config_filename)
        model_metrics = [wrapped_partial(compare_scores, comparison=group_average)]
        metric_to_monitor = "compare_scores"
    else:
        model_metrics = []
        if config['skip_validation']:
            metric_to_monitor = "loss"
        else:
            metric_to_monitor = "val_loss"

    if config["skip_validation"]:
        groups = ("training",)
    else:
        groups = ("training", "validation")

    # Generate per-group filename lists unless the config already has them.
    for name in groups:
        key = name + "_filenames"
        if key not in config:
            config[key] = generate_filenames(config, name, system_config)
    if "directory" in system_config:
        directory = system_config.pop("directory")
    else:
        directory = "."

    # Resolve the sequence/dataset class from the config and config filename.
    # NOTE(review): the elif branches below read config["sequence"] even
    # though they are only reachable when "sequence" is NOT in config, which
    # would raise KeyError — presumably those paths are never hit with
    # current configs; verify before relying on them.
    if "sequence" in config:
        sequence_class = load_sequence(config["sequence"])
    elif "_wb_" in os.path.basename(namespace.config_filename):
        if "package" in config and config["package"] == "pytorch":
            if config["sequence"] == "AEDataset":
                sequence_class = AEDataset
            elif config["sequence"] == "WholeVolumeSegmentationDataset":
                sequence_class = WholeVolumeSegmentationDataset
            else:
                sequence_class = WholeBrainCIFTI2DenseScalarDataset
        else:
            sequence_class = WholeVolumeToSurfaceSequence
    elif config["sequence"] == "WindowedAutoEncoderSequence":
        sequence_class = WindowedAutoEncoderSequence
    elif config["sequence"] == "WindowedAEDataset":
        sequence_class = WindowedAEDataset
    elif "_pb_" in os.path.basename(namespace.config_filename):
        sequence_class = ParcelBasedSequence
        config["sequence_kwargs"]["parcellation_template"] = os.path.join(
            directory, config["sequence_kwargs"]["parcellation_template"])
    else:
        if config["package"] == "pytorch":
            sequence_class = HCPRegressionDataset
        else:
            sequence_class = HCPRegressionSequence

    if "bias_filename" in config and config["bias_filename"] is not None:
        bias = load_bias(config["bias_filename"])
    else:
        bias = None

    check_hierarchy(config)

    # Contours double the number of output channels.
    if in_config("add_contours", config["sequence_kwargs"], False):
        config["n_outputs"] = config["n_outputs"] * 2

    if sequence_class == ParcelBasedSequence:
        # Train one model per target parcel, suffixing the log and model
        # filenames with the parcel id.
        target_parcels = config["sequence_kwargs"].pop("target_parcels")
        for target_parcel in target_parcels:
            config["sequence_kwargs"]["target_parcel"] = target_parcel
            print("Training on parcel: {}".format(target_parcel))
            if type(target_parcel) == list:
                parcel_id = "-".join([str(i) for i in target_parcel])
            else:
                parcel_id = str(target_parcel)
            _training_log_filename = namespace.training_log_filename.replace(".csv", "_{}.csv".format(parcel_id))
            # Skip parcels whose training already early-stopped (the best
            # epoch is at least "patience" epochs before the last log row).
            if os.path.exists(_training_log_filename):
                _training_log = pd.read_csv(_training_log_filename)
                if (_training_log[metric_to_monitor].values.argmin()
                        <= len(_training_log) - int(config["early_stopping_patience"])):
                    print("Already trained")
                    continue
            run_training(package,
                         config,
                         namespace.model_filename.replace(".h5", "_{}.h5".format(parcel_id)),
                         _training_log_filename,
                         sequence_class=sequence_class,
                         model_metrics=model_metrics,
                         metric_to_monitor=metric_to_monitor,
                         **system_config)

    else:
        run_training(package, config, namespace.model_filename, namespace.training_log_filename,
                     sequence_class=sequence_class,
                     model_metrics=model_metrics, metric_to_monitor=metric_to_monitor, bias=bias, **system_config)

    if namespace.sub_command == "predict":
        run_inference(namespace)
Beispiel #7
0
def main():
    """Run whole-brain scalar predictions for the validation subjects.

    Positional argv: model config JSON, model filename, machine config
    JSON, output directory, and optionally a reference CIFTI filename
    plus the reference subject id.
    """
    config_path = sys.argv[1]
    print("Config: ", config_path)
    config = load_json(config_path)
    model_path = sys.argv[2]
    print("Model: ", model_path)

    machine_config_path = sys.argv[3]
    print("Machine config: ", machine_config_path)
    machine_config = load_json(machine_config_path)

    output_directory = os.path.abspath(sys.argv[4])
    print("Output Directory:", output_directory)

    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # The reference arguments are optional: running off the end of argv
    # simply disables the reference comparison.
    try:
        reference_path = sys.argv[5]
        reference_subject_id = sys.argv[6]
        reference_array = get_metric_data(nib_load_files([reference_path]),
                                          config["metric_names"],
                                          config["surface_names"],
                                          reference_subject_id)
    except IndexError:
        reference_array = None

    load_subject_ids(config)

    # Membership tests (rather than .get) keep explicit None values intact.
    if "evaluation_metric" in config:
        criterion_name = config["evaluation_metric"]
    else:
        criterion_name = config["loss"]

    if "model_kwargs" in config:
        model_kwargs = config["model_kwargs"]
    else:
        model_kwargs = dict()

    return whole_brain_scalar_predictions(
        model_filename=model_path,
        subject_ids=config["validation"],
        hcp_dir=machine_config["directory"],
        output_dir=output_directory,
        hemispheres=config["hemispheres"],
        feature_basenames=config["feature_basenames"],
        surface_basename_template=config["surface_basename_template"],
        target_basenames=config["target_basenames"],
        model_name=config["model_name"],
        n_outputs=config["n_outputs"],
        n_features=config["n_features"],
        window=config["window"],
        criterion_name=criterion_name,
        metric_names=config["metric_names"],
        surface_names=config["surface_names"],
        reference=reference_array,
        package=config["package"],
        n_gpus=machine_config["n_gpus"],
        batch_size=config["validation_batch_size"],
        n_workers=machine_config["n_workers"],
        model_kwargs=model_kwargs)
Beispiel #8
0
def make_predictions(config_filename,
                     model_filename,
                     output_directory='./',
                     n_subjects=None,
                     shuffle=False,
                     key='validation_filenames',
                     use_multiprocessing=False,
                     n_workers=1,
                     max_queue_size=5,
                     batch_size=50,
                     overwrite=True,
                     single_subject=None,
                     output_task_name=None,
                     package="keras",
                     directory="./",
                     n_gpus=1):
    """Predict surface metrics for the subjects listed under ``key``.

    For each (feature, surfaces, metrics, subject_id) entry the model is
    applied via ``predict_subject`` and a CIFTI file named
    "<task>-<model>_prediction.dscalar.nii" is written into
    ``output_directory``.

    Selected parameters:
        config_filename: JSON config providing filename templates,
            window/spacing, metric and surface names.
        model_filename: saved model; a ".h5" suffix is stripped for the
            output basename.
        n_subjects: if given, restrict to the first ``n_subjects`` entries
            (after optional shuffling).
        single_subject: if given, only that subject id is predicted and
            model loading is deferred until the subject is encountered.
        package: "keras" or "pytorch"; overridden to "pytorch" if the
            config says so.
    """
    output_directory = os.path.abspath(output_directory)
    config = load_json(config_filename)

    # Build the filename list from the config templates when it is not
    # already stored under ``key``.
    if key not in config:
        name = key.split("_")[0]
        if name not in config:
            load_subject_ids(config)
        config[key] = generate_hcp_filenames(
            directory, config['surface_basename_template'],
            config['target_basenames'], config['feature_basenames'],
            config[name], config['hemispheres'])

    filenames = config[key]

    model_basename = os.path.basename(model_filename).replace(".h5", "")

    # The config's package setting wins over the ``package`` argument.
    if "package" in config and config["package"] == "pytorch":
        generator = HCPSubjectDataset
        package = "pytorch"
    else:
        generator = SubjectPredictionSequence

    if "model_kwargs" in config:
        model_kwargs = config["model_kwargs"]
    else:
        model_kwargs = dict()

    # A batch size in the config overrides the ``batch_size`` argument.
    if "batch_size" in config:
        batch_size = config["batch_size"]

    # Load the model up front unless we are waiting for a single subject.
    if single_subject is None:
        if package == "pytorch":
            from unet3d.models.pytorch.build import build_or_load_model

            model = build_or_load_model(model_filename=model_filename,
                                        model_name=config["model_name"],
                                        n_features=config["n_features"],
                                        n_outputs=config["n_outputs"],
                                        n_gpus=n_gpus,
                                        **model_kwargs)
        else:
            from keras.models import load_model
            model = load_model(model_filename)
    else:
        model = None

    # NOTE(review): shuffling mutates the list object stored in
    # ``config[key]`` — confirm callers do not reuse that ordering.
    if n_subjects is not None:
        if shuffle:
            np.random.shuffle(filenames)
        filenames = filenames[:n_subjects]

    for feature_filename, surface_filenames, metric_filenames, subject_id in filenames:
        if single_subject is None or subject_id == single_subject:
            # Lazily load the model the first time a matching subject is hit.
            if model is None:
                if package == "pytorch":
                    from unet3d.models.pytorch.build import build_or_load_model

                    model = build_or_load_model(
                        model_filename=model_filename,
                        model_name=config["model_name"],
                        n_features=config["n_features"],
                        n_outputs=config["n_outputs"],
                        n_gpus=n_gpus,
                        **model_kwargs)
                else:
                    # NOTE(review): on this lazy keras path ``load_model``
                    # has not been imported in this scope — verify it is
                    # available before relying on single_subject + keras.
                    model = load_model(model_filename)
            # Derive the task name from the first metric filename; multiple
            # metric files are collapsed into an "ALL47" task name.
            if output_task_name is None:
                _output_task_name = os.path.basename(
                    metric_filenames[0]).split(".")[0]
                if len(metric_filenames) > 1:
                    _output_task_name = "_".join(
                        _output_task_name.split("_")[:2] + ["ALL47"] +
                        _output_task_name.split("_")[3:])
            else:
                _output_task_name = output_task_name

            output_basename = "{task}-{model}_prediction.dscalar.nii".format(
                model=model_basename, task=_output_task_name)
            output_filename = os.path.join(output_directory, output_basename)
            # Flatten the nested metric-name lists, formatting in the id.
            subject_metric_names = list()
            for metric_list in config["metric_names"]:
                for metric_name in metric_list:
                    subject_metric_names.append(metric_name.format(subject_id))
            predict_subject(model,
                            feature_filename,
                            surface_filenames,
                            config['surface_names'],
                            subject_metric_names,
                            output_filename=output_filename,
                            batch_size=batch_size,
                            window=np.asarray(config['window']),
                            spacing=np.asarray(config['spacing']),
                            flip=False,
                            overwrite=overwrite,
                            use_multiprocessing=use_multiprocessing,
                            workers=n_workers,
                            max_queue_size=max_queue_size,
                            reference_filename=metric_filenames[0],
                            package=package,
                            generator=generator)
Beispiel #9
0
def run_inference(namespace):
    """Run volumetric predictions for the configured group of subjects.

    Resolves the input filenames (explicit --filenames, generated from the
    config's templates, or stored directly in the config), prepares
    segmentation/label options, and dispatches to
    ``volumetric_predictions`` (or an alternate prediction function named
    on the command line).
    """
    print("Config: ", namespace.config_filename)
    config = load_json(namespace.config_filename)
    key = namespace.group + "_filenames"

    system_config = get_machine_config(namespace)

    # 1) Explicit filenames from the command line take precedence.
    if namespace.filenames:
        filenames = list()
        for filename in namespace.filenames:
            filenames.append([
                filename, namespace.sub_volumes, None, None,
                os.path.basename(filename).split(".")[0]
            ])
    # 2) Otherwise generate them from the config templates.
    elif key not in config:
        # Apply (old, new) replacement pairs to the filename templates.
        if namespace.replace is not None:
            for _key in ("directory", "feature_templates", "target_templates"):
                if _key in config["generate_filenames_kwargs"]:
                    if type(config["generate_filenames_kwargs"][_key]) == str:

                        for i in range(0, len(namespace.replace), 2):
                            config["generate_filenames_kwargs"][_key] = config[
                                "generate_filenames_kwargs"][_key].replace(
                                    namespace.replace[i],
                                    namespace.replace[i + 1])
                    else:
                        # NOTE(review): the list branch applies only the
                        # first replacement pair, unlike the string branch
                        # above which applies all pairs — confirm this
                        # asymmetry is intentional.
                        config["generate_filenames_kwargs"][_key] = [
                            template.replace(namespace.replace[0],
                                             namespace.replace[1]) for template
                            in config["generate_filenames_kwargs"][_key]
                        ]
        # Resolve the data directory with decreasing priority: CLI template,
        # machine config, model config, current directory.
        if namespace.directory_template is not None:
            directory = namespace.directory_template
        elif "directory" in system_config and system_config["directory"]:
            directory = system_config["directory"]
        elif "directory" in config:
            directory = config["directory"]
        else:
            directory = ""
        if namespace.subjects_config_filename:
            config[namespace.group] = load_json(
                namespace.subjects_config_filename)[namespace.group]
        else:
            load_subject_ids(config, namespace.group)
        filenames = generate_filenames(config,
                                       namespace.group,
                                       directory,
                                       skip_targets=(not namespace.eval))

    # 3) Filenames were already stored in the config.
    else:
        filenames = config[key]

    print("Model: ", namespace.model_filename)

    print("Output Directory:", namespace.output_directory)

    if not os.path.exists(namespace.output_directory):
        os.makedirs(namespace.output_directory)

    # Prefer an explicit (non-None) evaluation metric over the training loss.
    if "evaluation_metric" in config and config[
            "evaluation_metric"] is not None:
        criterion_name = config['evaluation_metric']
    else:
        criterion_name = config['loss']

    if "model_kwargs" in config:
        model_kwargs = config["model_kwargs"]
    else:
        model_kwargs = dict()

    if namespace.activation:
        model_kwargs["activation"] = namespace.activation

    if "sequence_kwargs" in config:
        sequence_kwargs = config["sequence_kwargs"]
        # make sure any augmentations are set to None
        # (this also rebinds the name ``key``; its earlier value is no
        # longer needed at this point)
        for key in ["augment_scale_std", "additive_noise_std"]:
            if key in sequence_kwargs:
                sequence_kwargs[key] = None
    else:
        sequence_kwargs = dict()

    if "reorder" not in sequence_kwargs:
        sequence_kwargs["reorder"] = in_config("reorder", config, False)

    # Multi-source datasets: select which source(s) feed the prediction.
    if "generate_filenames" in config and config[
            "generate_filenames"] == "multisource_templates":
        if namespace.filenames is not None:
            sequence_kwargs["inputs_per_epoch"] = None
        else:
            # set which source(s) to use for prediction filenames
            if "inputs_per_epoch" not in sequence_kwargs:
                sequence_kwargs["inputs_per_epoch"] = dict()
            if namespace.source is not None:
                # just use the named source
                for dataset in filenames:
                    sequence_kwargs["inputs_per_epoch"][dataset] = 0
                sequence_kwargs["inputs_per_epoch"][namespace.source] = "all"
            else:
                # use all sources
                for dataset in filenames:
                    sequence_kwargs["inputs_per_epoch"][dataset] = "all"
    if namespace.sub_volumes is not None:
        sequence_kwargs["extract_sub_volumes"] = True

    if "sequence" in config:
        sequence = load_sequence(config["sequence"])
    else:
        sequence = None

    # NOTE(review): with --segment this requires sequence_kwargs["labels"]
    # to exist (KeyError otherwise) — presumably segmentation configs
    # always provide it; verify.
    labels = sequence_kwargs["labels"] if namespace.segment else None
    if "use_label_hierarchy" in sequence_kwargs:
        label_hierarchy = sequence_kwargs.pop("use_label_hierarchy")
    else:
        label_hierarchy = False

    if label_hierarchy and (namespace.threshold != 0.5 or namespace.sum):
        # TODO: put a warning here instead of a print statement
        print(
            "Using label hierarchy. Resetting threshold to 0.5 and turning the summation off."
        )
        namespace.threshold = 0.5
        namespace.sum = False
    if in_config("add_contours", sequence_kwargs, False):
        # Contours double the number of output channels.
        config["n_outputs"] = config["n_outputs"] * 2
        if namespace.use_contours:
            # this sets the labels for the contours
            if label_hierarchy:
                raise RuntimeError(
                    "Cannot use contours for segmentation while a label hierarchy is specified."
                )
            labels = list(labels) + list(labels)

    if namespace.alternate_prediction_func:
        from unet3d import predict
        func = getattr(predict, namespace.alternate_prediction_func)
    else:
        func = volumetric_predictions

    return func(model_filename=namespace.model_filename,
                filenames=filenames,
                prediction_dir=namespace.output_directory,
                model_name=config["model_name"],
                n_features=config["n_features"],
                window=config["window"],
                criterion_name=criterion_name,
                package=config['package'],
                n_gpus=system_config['n_gpus'],
                batch_size=config['validation_batch_size'],
                n_workers=system_config["n_workers"],
                model_kwargs=model_kwargs,
                sequence_kwargs=sequence_kwargs,
                sequence=sequence,
                n_outputs=config["n_outputs"],
                metric_names=in_config("metric_names", config, None),
                evaluate_predictions=namespace.eval,
                resample_predictions=(not namespace.no_resample),
                interpolation=namespace.interpolation,
                output_template=namespace.output_template,
                segmentation=namespace.segment,
                segmentation_labels=labels,
                threshold=namespace.threshold,
                sum_then_threshold=namespace.sum,
                label_hierarchy=label_hierarchy,
                write_input_images=namespace.write_input_images)
Beispiel #10
0
def load_subject_ids(config, name):
    """Load the subject-id list for ``name`` from the configured subjects file.

    Does nothing when the config lacks a "subjects_filename" entry;
    otherwise stores the named id list into ``config[name]`` in place.
    """
    if "subjects_filename" not in config:
        return
    subjects_path = os.path.join(unet3d_path, config["subjects_filename"])
    config[name] = load_json(subjects_path)[name]