Example no. 1
0
def quick_check(run_makelinks=False,
                config_path="./config/examples/train_example.yaml"):
    """Load a YAML training config and run ``train`` on it.

    Args:
        run_makelinks (bool, optional): call ``makelinks()`` first to create
            data links. Defaults to False.
        config_path (str, optional): YAML config to train with. Defaults to
            ``"./config/examples/train_example.yaml"`` (the path this
            function previously hard-coded).
    """
    if run_makelinks:
        makelinks()
    with open(config_path, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader can build some Python objects; prefer
        # yaml.safe_load if configs may come from untrusted sources.
        config = yaml.load(f, Loader=yaml.FullLoader)
    train(config)
def load_config(config_path, run_makelinks=False):
    """Build the model, dataloader and helpers needed to run inference.

    Reads the YAML config at ``config_path`` and instantiates the model,
    loading ``_MODEL_CHECKPOINT`` into it when that entry is not None. It
    also instantiates the dataloader named by ``_LOADER_TO_EVAL``. The model
    is moved to ``cuda:0`` when CUDA is available (CPU otherwise) and put in
    eval mode.

    Args:
        config_path (str): config file path
        run_makelinks (bool, optional): Creates symbolic links during the
            first run. Defaults to False.

    Returns:
        tuple: ``(dataloader, model, device, pred_process, save_dir,
        model_name, split)`` where

        - ``pred_process`` comes from the ``pred_process`` config of the
          ``dice_metric`` criterion;
        - ``save_dir`` is always ``"./saved_inference"``;
        - ``model_name`` is the directory two levels above the config file;
        - ``split`` is the lower-cased ``splitter_key`` of the evaluated
          loader's dataset config.
    """
    if run_makelinks:
        makelinks()
    with open(config_path, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader can build some Python objects; prefer
        # yaml.safe_load if configs may come from untrusted sources.
        config = yaml.load(f, Loader=yaml.FullLoader)

    model_config = config["_MODEL_CONFIG"]
    loader_to_eval = config["_LOADER_TO_EVAL"]
    dataloader_config = config[loader_to_eval]
    split = dataloader_config["dataset"]["splitter_key"].lower()
    saved_checkpoint = config["_MODEL_CHECKPOINT"]
    checkpoint_format = config["_NEW_CKP_FORMAT"]

    model = get_object_instance(model_config)()
    if saved_checkpoint is not None:
        load_model_data(saved_checkpoint, model, new_format=checkpoint_format)

    dataloader = get_object_instance(dataloader_config)()

    # TODO: support other metrics as needed
    # NOTE(review): unlike the model/dataloader above, this instance is not
    # called -- presumably it is used directly as the post-processing
    # callable; confirm against get_object_instance.
    pred_process_config = config["_LOSSES_METRICS_CONFIG"]["criterions_dict"][
        "dice_metric"]["pred_process"]
    pred_process = get_object_instance(pred_process_config)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Grandparent directory of the config file, e.g. "a/<name>/b/cfg.yaml".
    model_name = Path(config_path).parts[-3]

    save_dir = "./saved_inference"

    return (
        dataloader,
        model,
        device,
        pred_process,
        save_dir,
        model_name,
        split,
    )
Example no. 3
0
def _summarize_TKVs(dcm2attrib, split):
    """Sum per-slice volumes into one TKV summary per patient+MR key."""
    TKV_data = OrderedDict()
    for attrib in dcm2attrib.values():
        patient_MR = attrib["patient"] + attrib["MR"]
        summary = TKV_data.get(patient_MR)
        if summary is None:
            # "sequence" is taken from the first slice seen for this key.
            summary = {
                "TKV_GT": 0.0,
                "TKV_Pred": 0.0,
                "sequence": attrib["seq"],
                "split": split,
            }
            TKV_data[patient_MR] = summary
        summary["TKV_GT"] += attrib["Vol_GT"]
        summary["TKV_Pred"] += attrib["Vol_Pred"]
    return TKV_data


def calculate_TKVs(config_path, run_makelinks=False, output=None):
    """Compute ground-truth and predicted total volumes per patient/MR.

    Runs ``evaluate`` on the YAML config at ``config_path``. It then sums
    the ``Vol_GT`` / ``Vol_Pred`` entries of every item sharing the same
    ``patient + MR`` key.

    Args:
        config_path (str): YAML config path.
        run_makelinks (bool, optional): call ``makelinks()`` first to create
            data links. Defaults to False.
        output (str, optional): when given, the summaries are also written
            there as CSV (one row per patient/MR).

    Returns:
        OrderedDict: ``patient + MR`` -> dict with keys ``"TKV_GT"``,
        ``"TKV_Pred"``, ``"sequence"`` and ``"split"``, in first-seen order.
    """
    if run_makelinks:
        makelinks()
    with open(config_path, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader can build some Python objects; prefer
        # yaml.safe_load if configs may come from untrusted sources.
        config = yaml.load(f, Loader=yaml.FullLoader)

    # val or test: second "_"-separated token of the loader key, lower-cased.
    split = config["_LOADER_TO_EVAL"].split("_")[1].lower()

    dcm2attrib = evaluate(config)
    TKV_data = _summarize_TKVs(dcm2attrib, split)

    if output is not None:
        pd.DataFrame(TKV_data).transpose().to_csv(output)

    return TKV_data
def quick_check(config_path, run_makelinks=False):
    """Load a YAML config, run ``evaluate`` on it and return the result.

    Args:
        config_path (str): YAML config path.
        run_makelinks (bool, optional): call ``makelinks()`` first to create
            data links. Defaults to False.

    Returns:
        Whatever ``evaluate(config)`` returns, so callers can inspect it.
    """
    if run_makelinks:
        makelinks()
    with open(config_path, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader can build some Python objects; prefer
        # yaml.safe_load if configs may come from untrusted sources.
        config = yaml.load(f, Loader=yaml.FullLoader)
    return evaluate(config)

# %%
def quick_check(config_path, run_makelinks=False):
    """Load a YAML config, run ``evaluate`` on it and return the result.

    Args:
        config_path (str): YAML config path.
        run_makelinks (bool, optional): call ``makelinks()`` first to create
            data links. Defaults to False.

    Returns:
        Whatever ``evaluate(config)`` returns, so callers can inspect it.
    """
    if run_makelinks:
        makelinks()
    with open(config_path, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader can build some Python objects; prefer
        # yaml.safe_load if configs may come from untrusted sources.
        config = yaml.load(f, Loader=yaml.FullLoader)
    return evaluate(config)


# %%
if __name__ == "__main__":
    # CLI entry point: evaluate the model described by --config.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config", type=str, required=True, help="YAML config path"
    )
    cli.add_argument("--makelinks", action="store_true", help="Make data links")
    opts = cli.parse_args()

    # The config is parsed first; data links are created only on request.
    with open(opts.config, "r") as stream:
        cfg = yaml.load(stream, Loader=yaml.FullLoader)

    if opts.makelinks:
        makelinks()

    evaluate(cfg)