コード例 #1
0
def generate_filenames(config, name, system_config, skip_targets=False, raise_if_not_exists=False):
    """Build the filename list for one subject group using the configured strategy.

    The strategy is selected by ``config["generate_filenames"]``:

    * missing or ``"classic"`` -- ``generate_hcp_filenames`` driven by the
      ``surface_basename_template``, ``target_basenames``, ``feature_basenames``
      and ``hemispheres`` entries of ``config``;
    * ``"paired"`` -- ``generate_paired_filenames``;
    * ``"multisource_templates"`` -- ``generate_filenames_from_multisource_templates``;
    * ``"templates"`` -- ``generate_filenames_from_templates``.

    The non-classic strategies receive ``config["generate_filenames_kwargs"]``
    as extra keyword arguments.

    Args:
        config: Training configuration dict. ``config[name]`` holds the subject
            ids of the group; when absent, ``load_subject_ids(config, name)`` is
            called first (it is expected to fill ``config[name]``).
        name: Group key, e.g. ``"training"`` or ``"validation"``.
        system_config: Machine configuration mapping; its ``"directory"`` entry
            (default ``""``) is the base directory for the classic and paired
            strategies.
        skip_targets: Forwarded to the ``"templates"`` strategy only.
        raise_if_not_exists: Forwarded to every non-classic strategy.

    Returns:
        Whatever the selected filename generator returns.

    Raises:
        ValueError: If ``config["generate_filenames"]`` is not a known strategy
            (previously this silently returned ``None``).
    """
    if name not in config:
        load_subject_ids(config, name)
    strategy = config.get("generate_filenames", "classic")

    if strategy == "classic":
        return generate_hcp_filenames(in_config('directory', system_config, ""),
                                      config.get('surface_basename_template'),
                                      config['target_basenames'],
                                      config['feature_basenames'],
                                      config[name],
                                      config.get('hemispheres'))
    if strategy == "paired":
        return generate_paired_filenames(in_config('directory', system_config, ""),
                                         config[name],
                                         name,
                                         raise_if_not_exists=raise_if_not_exists,
                                         **config["generate_filenames_kwargs"])
    if strategy == "multisource_templates":
        return generate_filenames_from_multisource_templates(config[name],
                                                             raise_if_not_exists=raise_if_not_exists,
                                                             **config["generate_filenames_kwargs"])
    if strategy == "templates":
        return generate_filenames_from_templates(config[name],
                                                 raise_if_not_exists=raise_if_not_exists,
                                                 skip_targets=skip_targets,
                                                 **config["generate_filenames_kwargs"])
    # Fail loudly: returning None here would be stored as a filename list and
    # only blow up much later during training/prediction.
    raise ValueError("Unknown generate_filenames strategy {!r}; expected one of "
                     "'classic', 'paired', 'multisource_templates', 'templates'.".format(strategy))
コード例 #2
0
ファイル: train.py プロジェクト: YanglanOu/3DUnetCNN
def check_hierarchy(config):
    """Turn a flat label list into a label hierarchy, in place.

    Acts only when both ``labels`` and ``use_label_hierarchy`` are set (per
    ``in_config``) in ``config["sequence_kwargs"]``. The flag is removed and
    ``labels`` is replaced by its successive suffixes, e.g.
    ``[1, 2, 3] -> [[1, 2, 3], [2, 3], [3]]``.
    """
    sequence_kwargs = config["sequence_kwargs"]
    wants_hierarchy = (in_config("labels", sequence_kwargs)
                       and in_config("use_label_hierarchy", sequence_kwargs))
    if not wants_hierarchy:
        return
    sequence_kwargs.pop("use_label_hierarchy")
    flat_labels = sequence_kwargs.pop("labels")
    # Offset 0 keeps the original object itself; later entries are suffix slices.
    sequence_kwargs["labels"] = [flat_labels[offset:] if offset else flat_labels
                                 for offset in range(len(flat_labels))]
コード例 #3
0
ファイル: train.py プロジェクト: YanglanOu/3DUnetCNN
def _select_sequence_class(config, package, config_filename, directory):
    """Return the sequence/dataset class used for training.

    An explicit ``config["sequence"]`` always wins and is resolved with
    ``load_sequence``. Otherwise the class is inferred from markers in the
    config file's basename: ``_wb_``, ``_pb_``, or the HCP regression default,
    each with a pytorch/keras variant chosen by ``package``.

    The ``_pb_`` case rewrites ``config["sequence_kwargs"]["parcellation_template"]``
    in place so that it is joined onto ``directory``.
    """
    if "sequence" in config:
        return load_sequence(config["sequence"])

    # Past this point "sequence" is not in config, so config["sequence"] must
    # not be read (comparing it by name here could only raise KeyError).
    basename = os.path.basename(config_filename)
    if "_wb_" in basename:
        if package == "pytorch":
            return WholeBrainCIFTI2DenseScalarDataset
        return WholeVolumeToSurfaceSequence
    if "_pb_" in basename:
        config["sequence_kwargs"]["parcellation_template"] = os.path.join(
            directory, config["sequence_kwargs"]["parcellation_template"])
        return ParcelBasedSequence
    if package == "pytorch":
        return HCPRegressionDataset
    return HCPRegressionSequence


def _parcel_id(target_parcel):
    """Return the filename suffix for a parcel; list parcels are joined with '-'."""
    if isinstance(target_parcel, list):
        return "-".join(str(i) for i in target_parcel)
    return str(target_parcel)


def _already_trained(training_log_filename, metric_to_monitor, config):
    """Return True if an existing training log says this run already finished.

    The run counts as finished when the row index of the minimum of
    ``metric_to_monitor`` is ``<= len(log) - config["early_stopping_patience"]``
    (one row per epoch is assumed). A missing or row-less log means not finished.
    """
    if not os.path.exists(training_log_filename):
        return False
    training_log = pd.read_csv(training_log_filename)
    if training_log.empty:
        # argmin() on an empty column would raise instead of answering "no".
        return False
    best_index = training_log[metric_to_monitor].values.argmin()
    return best_index <= len(training_log) - int(config["early_stopping_patience"])


def main():
    """Train a model from the command-line arguments and its JSON config.

    Steps:

    1. Parse the arguments and load the config.
    2. Check that ``n_outputs`` matches ``metric_names``.
    3. Optionally adapt the config to the GPU memory.
    4. Generate any missing ``training``/``validation`` filename lists.
    5. Pick the sequence class and call ``run_training``. With
       ``ParcelBasedSequence`` this happens once per target parcel, skipping
       parcels whose log shows training already finished.
    6. If the ``predict`` sub-command was given, run ``run_inference`` afterwards.

    Raises:
        ValueError: If ``metric_names`` is set and its length differs from
            ``n_outputs``.
    """
    import nibabel as nib
    # 40 == logging.ERROR: hide nibabel messages below error level.
    nib.imageglobals.logger.level = 40

    namespace = parse_args()

    print("Config: ", namespace.config_filename)
    config = load_json(namespace.config_filename)

    package = config.get("package", "keras")

    if "metric_names" in config and config["n_outputs"] != len(config["metric_names"]):
        raise ValueError("n_outputs set to {}, but number of metrics is {}.".format(config["n_outputs"],
                                                                                    len(config["metric_names"])))

    print("Model: ", namespace.model_filename)
    print("Log: ", namespace.training_log_filename)
    system_config = get_machine_config(namespace)

    if namespace.fit_gpu_mem and namespace.fit_gpu_mem > 0:
        update_config_to_fit_gpu_memory(config=config, n_gpus=system_config["n_gpus"], gpu_memory=namespace.fit_gpu_mem,
                                        output_filename=namespace.config_filename.replace(".json", "_auto.json"))

    # Read after the GPU-fit step, which may rewrite the config.
    skip_validation = config.get("skip_validation", False)

    if namespace.group_average_filenames is not None:
        group_average = get_metric_data_from_config(namespace.group_average_filenames, namespace.config_filename)
        model_metrics = [wrapped_partial(compare_scores, comparison=group_average)]
        metric_to_monitor = "compare_scores"
    else:
        model_metrics = []
        metric_to_monitor = "loss" if skip_validation else "val_loss"

    groups = ("training",) if skip_validation else ("training", "validation")
    for name in groups:
        key = name + "_filenames"
        if key not in config:
            config[key] = generate_filenames(config, name, system_config)

    # Popping also keeps "directory" out of the **system_config given to run_training.
    directory = system_config.pop("directory", ".")

    sequence_class = _select_sequence_class(config, package, namespace.config_filename, directory)

    bias = load_bias(config["bias_filename"]) if config.get("bias_filename") is not None else None

    check_hierarchy(config)

    if in_config("add_contours", config["sequence_kwargs"], False):
        config["n_outputs"] = config["n_outputs"] * 2

    if sequence_class is ParcelBasedSequence:
        target_parcels = config["sequence_kwargs"].pop("target_parcels")
        for target_parcel in target_parcels:
            config["sequence_kwargs"]["target_parcel"] = target_parcel
            print("Training on parcel: {}".format(target_parcel))
            parcel_id = _parcel_id(target_parcel)
            _training_log_filename = namespace.training_log_filename.replace(".csv", "_{}.csv".format(parcel_id))
            if _already_trained(_training_log_filename, metric_to_monitor, config):
                print("Already trained")
                continue
            # bias is passed here too, matching the non-parcel branch.
            run_training(package,
                         config,
                         namespace.model_filename.replace(".h5", "_{}.h5".format(parcel_id)),
                         _training_log_filename,
                         sequence_class=sequence_class,
                         model_metrics=model_metrics,
                         metric_to_monitor=metric_to_monitor,
                         bias=bias,
                         **system_config)
    else:
        run_training(package, config, namespace.model_filename, namespace.training_log_filename,
                     sequence_class=sequence_class,
                     model_metrics=model_metrics, metric_to_monitor=metric_to_monitor, bias=bias, **system_config)

    if namespace.sub_command == "predict":
        run_inference(namespace)
コード例 #4
0
ファイル: predict.py プロジェクト: ellisdg/3DUnetCNN
def _apply_template_replacements(generate_kwargs, replace):
    """Apply ``replace`` substring pairs to path/template kwargs, in place.

    ``replace`` is a flat sequence ``[old1, new1, old2, new2, ...]``; every pair
    is applied in order.

    Each of ``directory``, ``feature_templates`` and ``target_templates`` present
    in ``generate_kwargs`` is rewritten:

    * a ``str`` value is rewritten directly;
    * any other value is treated as an iterable of template strings and
      replaced by a list.

    Raises:
        ValueError: If ``replace`` has an odd number of values.
    """
    if len(replace) % 2:
        raise ValueError("replace must contain an even number of values (old, new pairs), "
                         "got {}.".format(len(replace)))
    pairs = list(zip(replace[0::2], replace[1::2]))

    def substitute(text):
        for old, new in pairs:
            text = text.replace(old, new)
        return text

    for kwarg_name in ("directory", "feature_templates", "target_templates"):
        if kwarg_name not in generate_kwargs:
            continue
        value = generate_kwargs[kwarg_name]
        if isinstance(value, str):
            generate_kwargs[kwarg_name] = substitute(value)
        else:
            # Apply all pairs to list templates too (not just the first pair).
            generate_kwargs[kwarg_name] = [substitute(template) for template in value]


def _resolve_directory(namespace, system_config, config):
    """Return the base directory.

    Sources, in priority order:

    1. ``namespace.directory_template``, if not None;
    2. a non-empty ``system_config["directory"]``;
    3. ``config["directory"]``;
    4. otherwise ``""``.
    """
    if namespace.directory_template is not None:
        return namespace.directory_template
    if system_config.get("directory"):
        return system_config["directory"]
    return config.get("directory", "")


def run_inference(namespace):
    """Predict for explicit files or a configured subject group and return the result.

    Filenames come from one of three sources:

    * ``namespace.filenames``, when given;
    * ``config[<group>_filenames]``, when present;
    * otherwise ``generate_filenames``, after optional template replacements
      and subject-id loading.

    Training-time augmentation in ``sequence_kwargs`` is disabled. The
    prediction function is then called: ``volumetric_predictions``, or
    ``unet3d.predict.<namespace.alternate_prediction_func>`` when that is set.

    Returns:
        Whatever the prediction function returns.

    Raises:
        RuntimeError: If contours are requested together with a label hierarchy.
        ValueError: If ``namespace.replace`` has an odd number of values.
    """
    print("Config: ", namespace.config_filename)
    config = load_json(namespace.config_filename)
    filenames_key = namespace.group + "_filenames"

    system_config = get_machine_config(namespace)

    explicit_filenames = bool(namespace.filenames)
    if explicit_filenames:
        filenames = [[filename, namespace.sub_volumes, None, None,
                      os.path.basename(filename).split(".")[0]]
                     for filename in namespace.filenames]
    elif filenames_key not in config:
        if namespace.replace is not None:
            _apply_template_replacements(config["generate_filenames_kwargs"], namespace.replace)
        directory = _resolve_directory(namespace, system_config, config)
        if namespace.subjects_config_filename:
            config[namespace.group] = load_json(
                namespace.subjects_config_filename)[namespace.group]
        else:
            load_subject_ids(config, namespace.group)
        # NOTE(review): ``directory`` is a str. A generate_filenames whose third
        # parameter is a system_config mapping (read via in_config('directory', ...))
        # would not use it as the base directory -- confirm which signature is in scope.
        filenames = generate_filenames(config,
                                       namespace.group,
                                       directory,
                                       skip_targets=(not namespace.eval))
    else:
        filenames = config[filenames_key]

    print("Model: ", namespace.model_filename)

    print("Output Directory:", namespace.output_directory)

    if not os.path.exists(namespace.output_directory):
        os.makedirs(namespace.output_directory)

    criterion_name = config.get("evaluation_metric")
    if criterion_name is None:
        criterion_name = config['loss']

    model_kwargs = config.get("model_kwargs", dict())

    if namespace.activation:
        model_kwargs["activation"] = namespace.activation

    if "sequence_kwargs" in config:
        sequence_kwargs = config["sequence_kwargs"]
        # make sure any augmentations are set to None
        for augmentation_key in ("augment_scale_std", "additive_noise_std"):
            if augmentation_key in sequence_kwargs:
                sequence_kwargs[augmentation_key] = None
    else:
        sequence_kwargs = dict()

    if "reorder" not in sequence_kwargs:
        sequence_kwargs["reorder"] = in_config("reorder", config, False)

    if config.get("generate_filenames") == "multisource_templates":
        # Same "were files given explicitly" test as used above to pick filenames.
        if explicit_filenames:
            sequence_kwargs["inputs_per_epoch"] = None
        else:
            # set which source(s) to use for prediction filenames
            if "inputs_per_epoch" not in sequence_kwargs:
                sequence_kwargs["inputs_per_epoch"] = dict()
            if namespace.source is not None:
                # just use the named source
                for dataset in filenames:
                    sequence_kwargs["inputs_per_epoch"][dataset] = 0
                sequence_kwargs["inputs_per_epoch"][namespace.source] = "all"
            else:
                # use all sources
                for dataset in filenames:
                    sequence_kwargs["inputs_per_epoch"][dataset] = "all"
    if namespace.sub_volumes is not None:
        sequence_kwargs["extract_sub_volumes"] = True

    sequence = load_sequence(config["sequence"]) if "sequence" in config else None

    labels = sequence_kwargs["labels"] if namespace.segment else None
    label_hierarchy = sequence_kwargs.pop("use_label_hierarchy", False)

    if label_hierarchy and (namespace.threshold != 0.5 or namespace.sum):
        # TODO: put a warning here instead of a print statement
        print(
            "Using label hierarchy. Resetting threshold to 0.5 and turning the summation off."
        )
        namespace.threshold = 0.5
        namespace.sum = False
    if in_config("add_contours", sequence_kwargs, False):
        config["n_outputs"] = config["n_outputs"] * 2
        if namespace.use_contours:
            # this sets the labels for the contours
            if label_hierarchy:
                raise RuntimeError(
                    "Cannot use contours for segmentation while a label hierarchy is specified."
                )
            # labels is None when not segmenting; list(None) would raise TypeError.
            if labels is not None:
                labels = list(labels) + list(labels)

    if namespace.alternate_prediction_func:
        from unet3d import predict
        func = getattr(predict, namespace.alternate_prediction_func)
    else:
        func = volumetric_predictions

    return func(model_filename=namespace.model_filename,
                filenames=filenames,
                prediction_dir=namespace.output_directory,
                model_name=config["model_name"],
                n_features=config["n_features"],
                window=config["window"],
                criterion_name=criterion_name,
                package=config['package'],
                n_gpus=system_config['n_gpus'],
                batch_size=config['validation_batch_size'],
                n_workers=system_config["n_workers"],
                model_kwargs=model_kwargs,
                sequence_kwargs=sequence_kwargs,
                sequence=sequence,
                n_outputs=config["n_outputs"],
                metric_names=in_config("metric_names", config, None),
                evaluate_predictions=namespace.eval,
                resample_predictions=(not namespace.no_resample),
                interpolation=namespace.interpolation,
                output_template=namespace.output_template,
                segmentation=namespace.segment,
                segmentation_labels=labels,
                threshold=namespace.threshold,
                sum_then_threshold=namespace.sum,
                label_hierarchy=label_hierarchy,
                write_input_images=namespace.write_input_images)