def generate_filenames(config, name, system_config, skip_targets=False, raise_if_not_exists=False):
    """Generate the filenames for one subject group using the configured strategy.

    The strategy is read from ``config["generate_filenames"]``; when that key
    is missing the "classic" strategy is used.

    Args:
        config: Configuration mapping. It must contain the keys needed by the
            selected strategy ("target_basenames" and "feature_basenames" for
            "classic"; "generate_filenames_kwargs" for every other strategy).
        name: Key under which the subject ids of the group are stored in
            ``config`` (the callers in this file pass e.g. "training" or
            "validation"). When the key is missing,
            ``load_subject_ids(config, name)`` is called first -- presumably
            to populate ``config[name]``; verify against that helper.
        system_config: Mapping queried through ``in_config`` for an optional
            "directory" entry (default ""). Only the "classic" and "paired"
            strategies use it.
        skip_targets: Forwarded to the "templates" strategy only.
        raise_if_not_exists: Forwarded to every strategy except "classic".

    Returns:
        Whatever the selected filename generator returns.

    Raises:
        ValueError: If ``config["generate_filenames"]`` is not one of the
            supported strategies. Previously this case fell through and
            returned ``None``, which only failed later and far from the cause.
    """
    if name not in config:
        load_subject_ids(config, name)

    mode = config.get("generate_filenames", "classic")

    if mode == "classic":
        return generate_hcp_filenames(
            in_config('directory', system_config, ""),
            config.get('surface_basename_template'),
            config['target_basenames'],
            config['feature_basenames'],
            config[name],
            config.get('hemispheres'))

    if mode == "paired":
        return generate_paired_filenames(
            in_config('directory', system_config, ""),
            config[name],
            name,
            raise_if_not_exists=raise_if_not_exists,
            **config["generate_filenames_kwargs"])

    if mode == "multisource_templates":
        return generate_filenames_from_multisource_templates(
            config[name],
            raise_if_not_exists=raise_if_not_exists,
            **config["generate_filenames_kwargs"])

    if mode == "templates":
        return generate_filenames_from_templates(
            config[name],
            raise_if_not_exists=raise_if_not_exists,
            **config["generate_filenames_kwargs"],
            skip_targets=skip_targets)

    raise ValueError(
        "Unknown 'generate_filenames' option: {!r}. Expected one of 'classic', 'paired', "
        "'multisource_templates' or 'templates'.".format(mode))
def check_hierarchy(config):
    """Expand the configured labels into a nested label hierarchy, in place.

    Nothing happens unless ``in_config`` yields a truthy result for both
    "labels" and "use_label_hierarchy" in ``config["sequence_kwargs"]``.
    Otherwise the "use_label_hierarchy" entry is removed and "labels" is
    replaced by the list of its successive suffixes, e.g.
    ``[1, 2, 3]`` becomes ``[[1, 2, 3], [2, 3], [3]]``.

    Args:
        config: Configuration mapping containing a "sequence_kwargs" mapping,
            which is modified in place.
    """
    sequence_kwargs = config["sequence_kwargs"]

    if not in_config("labels", sequence_kwargs):
        return
    if not in_config("use_label_hierarchy", sequence_kwargs):
        return

    sequence_kwargs.pop("use_label_hierarchy")
    remaining = sequence_kwargs.pop("labels")

    # Each level holds the labels left after dropping the leading one.
    hierarchy = []
    while len(remaining):
        hierarchy.append(remaining)
        remaining = remaining[1:]

    sequence_kwargs["labels"] = hierarchy
def main():
    """Command-line entry point: train a model described by a JSON config.

    Steps, in order:

    1. Parse the command line and load the JSON configuration.
    2. Check that "n_outputs" matches the number of "metric_names" (if given).
    3. Optionally adapt the configuration to the available GPU memory.
    4. Choose the metric to monitor and generate any missing
       "training_filenames" / "validation_filenames" entries.
    5. Select the sequence/dataset class, run the training (once per target
       parcel for ``ParcelBasedSequence``), and finally run inference when
       the "predict" sub-command was requested.

    Raises:
        ValueError: If "metric_names" is configured and its length differs
            from "n_outputs".
    """
    import nibabel as nib
    # 40 is the numeric value of logging.ERROR.
    nib.imageglobals.logger.level = 40

    namespace = parse_args()

    print("Config: ", namespace.config_filename)
    config = load_json(namespace.config_filename)

    package = config.get("package", "keras")

    if "metric_names" in config and not config["n_outputs"] == len(config["metric_names"]):
        raise ValueError("n_outputs set to {}, but number of metrics is {}.".format(config["n_outputs"],
                                                                                    len(config["metric_names"])))

    print("Model: ", namespace.model_filename)
    print("Log: ", namespace.training_log_filename)
    system_config = get_machine_config(namespace)

    if namespace.fit_gpu_mem and namespace.fit_gpu_mem > 0:
        update_config_to_fit_gpu_memory(config=config,
                                        n_gpus=system_config["n_gpus"],
                                        gpu_memory=namespace.fit_gpu_mem,
                                        output_filename=namespace.config_filename.replace(".json", "_auto.json"))

    if namespace.group_average_filenames is not None:
        group_average = get_metric_data_from_config(namespace.group_average_filenames, namespace.config_filename)
        model_metrics = [wrapped_partial(compare_scores, comparison=group_average)]
        metric_to_monitor = "compare_scores"
    else:
        model_metrics = []
        if config['skip_validation']:
            metric_to_monitor = "loss"
        else:
            metric_to_monitor = "val_loss"

    if config["skip_validation"]:
        groups = ("training",)
    else:
        groups = ("training", "validation")

    for name in groups:
        key = name + "_filenames"
        if key not in config:
            config[key] = generate_filenames(config, name, system_config)

    # "directory" is removed from system_config because the remaining entries
    # are forwarded to run_training as keyword arguments below.
    if "directory" in system_config:
        directory = system_config.pop("directory")
    else:
        directory = "."

    # Select the sequence/dataset class. An explicit "sequence" entry wins.
    # Without one, the class is inferred from the config filename and the
    # package. These fallbacks must not read config["sequence"]: the key is
    # known to be absent here, so doing so raised KeyError and made every
    # fallback unreachable.
    config_basename = os.path.basename(namespace.config_filename)
    if "sequence" in config:
        sequence_class = load_sequence(config["sequence"])
    elif "_wb_" in config_basename:
        if package == "pytorch":
            sequence_class = WholeBrainCIFTI2DenseScalarDataset
        else:
            sequence_class = WholeVolumeToSurfaceSequence
    elif "_pb_" in config_basename:
        sequence_class = ParcelBasedSequence
        config["sequence_kwargs"]["parcellation_template"] = os.path.join(
            directory, config["sequence_kwargs"]["parcellation_template"])
    elif package == "pytorch":
        # Uses the defaulted ``package`` so a config without "package" works.
        sequence_class = HCPRegressionDataset
    else:
        sequence_class = HCPRegressionSequence

    if "bias_filename" in config and config["bias_filename"] is not None:
        bias = load_bias(config["bias_filename"])
    else:
        bias = None

    check_hierarchy(config)

    if in_config("add_contours", config["sequence_kwargs"], False):
        config["n_outputs"] = config["n_outputs"] * 2

    if sequence_class == ParcelBasedSequence:
        # Train one model per target parcel, each with its own model and log file.
        target_parcels = config["sequence_kwargs"].pop("target_parcels")
        for target_parcel in target_parcels:
            config["sequence_kwargs"]["target_parcel"] = target_parcel
            print("Training on parcel: {}".format(target_parcel))
            if isinstance(target_parcel, list):
                parcel_id = "-".join([str(i) for i in target_parcel])
            else:
                parcel_id = str(target_parcel)
            parcel_log_filename = namespace.training_log_filename.replace(".csv", "_{}.csv".format(parcel_id))
            if os.path.exists(parcel_log_filename):
                parcel_log = pd.read_csv(parcel_log_filename)
                # Skip the parcel when the best monitored value lies at least
                # "early_stopping_patience" rows before the end of the log.
                if (parcel_log[metric_to_monitor].values.argmin()
                        <= len(parcel_log) - int(config["early_stopping_patience"])):
                    print("Already trained")
                    continue
            run_training(package,
                         config,
                         namespace.model_filename.replace(".h5", "_{}.h5".format(parcel_id)),
                         parcel_log_filename,
                         sequence_class=sequence_class,
                         model_metrics=model_metrics,
                         metric_to_monitor=metric_to_monitor,
                         **system_config)
    else:
        run_training(package, config, namespace.model_filename, namespace.training_log_filename,
                     sequence_class=sequence_class,
                     model_metrics=model_metrics,
                     metric_to_monitor=metric_to_monitor,
                     bias=bias,
                     **system_config)

    if namespace.sub_command == "predict":
        run_inference(namespace)
def _apply_replacements(text, replacements):
    """Return ``text`` after applying each ``(old, new)`` pair in order."""
    for old, new in replacements:
        text = text.replace(old, new)
    return text


def run_inference(namespace):
    """Run model predictions as described by the parsed command-line arguments.

    The prediction filenames come from, in order of precedence:
    ``namespace.filenames``; the "<group>_filenames" entry of the config;
    or ``generate_filenames`` (after the optional ``namespace.replace``
    substitutions have been applied to the templates).

    Args:
        namespace: Parsed command-line arguments. Attributes read here include
            ``config_filename``, ``group``, ``filenames``, ``replace``,
            ``directory_template``, ``subjects_config_filename``,
            ``model_filename``, ``output_directory``, ``activation``,
            ``source``, ``sub_volumes``, ``segment``, ``threshold``, ``sum``,
            ``use_contours``, ``alternate_prediction_func``, ``eval``,
            ``no_resample``, ``interpolation``, ``output_template`` and
            ``write_input_images``. ``threshold`` and ``sum`` are overwritten
            when a label hierarchy is in use.

    Returns:
        The return value of the prediction function (``volumetric_predictions``
        or the function named by ``namespace.alternate_prediction_func``).

    Raises:
        RuntimeError: If contours are requested for segmentation while a label
            hierarchy is specified.
    """
    print("Config: ", namespace.config_filename)
    config = load_json(namespace.config_filename)
    key = namespace.group + "_filenames"

    system_config = get_machine_config(namespace)

    if namespace.filenames:
        filenames = list()
        for filename in namespace.filenames:
            filenames.append([filename,
                              namespace.sub_volumes,
                              None,
                              None,
                              os.path.basename(filename).split(".")[0]])
    elif key not in config:
        if namespace.replace is not None:
            # namespace.replace is a flat sequence: old_1, new_1, old_2, new_2, ...
            # Every pair is applied to string templates and to lists of
            # templates alike; a trailing unpaired value is ignored.
            replacements = list(zip(namespace.replace[0::2], namespace.replace[1::2]))
            filename_kwargs = config["generate_filenames_kwargs"]
            for _key in ("directory", "feature_templates", "target_templates"):
                if _key not in filename_kwargs:
                    continue
                if isinstance(filename_kwargs[_key], str):
                    filename_kwargs[_key] = _apply_replacements(filename_kwargs[_key], replacements)
                else:
                    filename_kwargs[_key] = [_apply_replacements(template, replacements)
                                             for template in filename_kwargs[_key]]

        if namespace.directory_template is not None:
            directory = namespace.directory_template
        elif "directory" in system_config and system_config["directory"]:
            directory = system_config["directory"]
        elif "directory" in config:
            directory = config["directory"]
        else:
            directory = ""

        if namespace.subjects_config_filename:
            config[namespace.group] = load_json(namespace.subjects_config_filename)[namespace.group]
        else:
            load_subject_ids(config, namespace.group)

        # generate_filenames looks up "directory" in its third argument, so the
        # chosen directory is wrapped in a mapping rather than passed as a bare
        # string (which would be ignored or fail the lookup).
        filenames = generate_filenames(config, namespace.group, {"directory": directory},
                                       skip_targets=(not namespace.eval))
    else:
        filenames = config[key]

    print("Model: ", namespace.model_filename)
    print("Output Directory:", namespace.output_directory)

    os.makedirs(namespace.output_directory, exist_ok=True)

    if "evaluation_metric" in config and config["evaluation_metric"] is not None:
        criterion_name = config['evaluation_metric']
    else:
        criterion_name = config['loss']

    if "model_kwargs" in config:
        model_kwargs = config["model_kwargs"]
    else:
        model_kwargs = dict()

    if namespace.activation:
        model_kwargs["activation"] = namespace.activation

    if "sequence_kwargs" in config:
        sequence_kwargs = config["sequence_kwargs"]
        # make sure any augmentations are set to None
        for augmentation_key in ("augment_scale_std", "additive_noise_std"):
            if augmentation_key in sequence_kwargs:
                sequence_kwargs[augmentation_key] = None
    else:
        sequence_kwargs = dict()

    if "reorder" not in sequence_kwargs:
        sequence_kwargs["reorder"] = in_config("reorder", config, False)

    if "generate_filenames" in config and config["generate_filenames"] == "multisource_templates":
        if namespace.filenames is not None:
            sequence_kwargs["inputs_per_epoch"] = None
        else:
            # set which source(s) to use for prediction filenames
            if "inputs_per_epoch" not in sequence_kwargs:
                sequence_kwargs["inputs_per_epoch"] = dict()
            if namespace.source is not None:
                # just use the named source
                for dataset in filenames:
                    sequence_kwargs["inputs_per_epoch"][dataset] = 0
                sequence_kwargs["inputs_per_epoch"][namespace.source] = "all"
            else:
                # use all sources
                for dataset in filenames:
                    sequence_kwargs["inputs_per_epoch"][dataset] = "all"

    if namespace.sub_volumes is not None:
        sequence_kwargs["extract_sub_volumes"] = True

    if "sequence" in config:
        sequence = load_sequence(config["sequence"])
    else:
        sequence = None

    labels = sequence_kwargs["labels"] if namespace.segment else None

    if "use_label_hierarchy" in sequence_kwargs:
        label_hierarchy = sequence_kwargs.pop("use_label_hierarchy")
    else:
        label_hierarchy = False

    if label_hierarchy and (namespace.threshold != 0.5 or namespace.sum):
        # TODO: put a warning here instead of a print statement
        print("Using label hierarchy. Resetting threshold to 0.5 and turning the summation off.")
        namespace.threshold = 0.5
        namespace.sum = False

    if in_config("add_contours", sequence_kwargs, False):
        config["n_outputs"] = config["n_outputs"] * 2
        if namespace.use_contours:
            # this sets the labels for the contours
            if label_hierarchy:
                raise RuntimeError("Cannot use contours for segmentation while a label hierarchy is specified.")
            labels = list(labels) + list(labels)

    if namespace.alternate_prediction_func:
        from unet3d import predict
        func = getattr(predict, namespace.alternate_prediction_func)
    else:
        func = volumetric_predictions

    return func(model_filename=namespace.model_filename,
                filenames=filenames,
                prediction_dir=namespace.output_directory,
                model_name=config["model_name"],
                n_features=config["n_features"],
                window=config["window"],
                criterion_name=criterion_name,
                package=config['package'],
                n_gpus=system_config['n_gpus'],
                batch_size=config['validation_batch_size'],
                n_workers=system_config["n_workers"],
                model_kwargs=model_kwargs,
                sequence_kwargs=sequence_kwargs,
                sequence=sequence,
                n_outputs=config["n_outputs"],
                metric_names=in_config("metric_names", config, None),
                evaluate_predictions=namespace.eval,
                resample_predictions=(not namespace.no_resample),
                interpolation=namespace.interpolation,
                output_template=namespace.output_template,
                segmentation=namespace.segment,
                segmentation_labels=labels,
                threshold=namespace.threshold,
                sum_then_threshold=namespace.sum,
                label_hierarchy=label_hierarchy,
                write_input_images=namespace.write_input_images)