Example #1
0
def build_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    model_config: dict,
    loss_config: dict,
):
    """
    Dispatch to the model building function selected by model_config["method"].

    All arguments are forwarded unchanged, as keyword arguments, to the
    selected builder.

    :param moving_image_size: [m_dim1, m_dim2, m_dim3]
    :param fixed_image_size: [f_dim1, f_dim2, f_dim3]
    :param index_size: dataset size
    :param labeled: true if the label of moving/fixed images are provided
    :param batch_size: mini-batch size
    :param model_config: model configuration, e.g. dictionary return from parser.yaml.load
    :param loss_config: loss configuration, e.g. dictionary return from parser.yaml.load
    :return: the built tf.keras.Model
    :raises ValueError: if model_config["method"] is not one of
        "ddf", "dvf", "conditional" or "affine"
    """
    method = model_config["method"]

    # Reject unknown methods first; the message names the offending value
    # so that a misconfigured file can be diagnosed from the error alone.
    if method in ["ddf", "dvf"]:
        builder_name = "ddf_dvf"
    elif method == "conditional":
        builder_name = "conditional"
    elif method == "affine":
        builder_name = "affine"
    else:
        raise ValueError(
            f"Unknown model method {method!r}, "
            "supported methods are ddf, dvf, conditional and affine"
        )

    # Every builder takes the same keyword arguments.
    builder_kwargs = dict(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        labeled=labeled,
        batch_size=batch_size,
        model_config=model_config,
        loss_config=loss_config,
    )

    # The builder is looked up only for the branch actually taken.
    if builder_name == "ddf_dvf":
        return build_ddf_dvf_model(**builder_kwargs)
    if builder_name == "conditional":
        return build_conditional_model(**builder_kwargs)
    return build_affine_model(**builder_kwargs)
Example #2
0
def build_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    train_config: dict,
    registry: Registry,
):
    """
    Select and run the model building function matching train_config["method"].

    All arguments are forwarded unchanged, as keyword arguments, to the
    selected builder.

    :param moving_image_size: [m_dim1, m_dim2, m_dim3]
    :param fixed_image_size: [f_dim1, f_dim2, f_dim3]
    :param index_size: dataset size
    :param labeled: true if the label of moving/fixed images are provided
    :param batch_size: mini-batch size
    :param train_config: train configuration, its "method" entry selects the builder
    :param registry: registry object passed through to the selected builder
    :return: the built tf.keras.Model
    :raises ValueError: if train_config["method"] is not one of
        "ddf", "dvf", "conditional" or "affine"
    """
    method = train_config["method"]

    # Every builder takes the same keyword arguments.
    builder_kwargs = dict(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        labeled=labeled,
        batch_size=batch_size,
        train_config=train_config,
        registry=registry,
    )

    if method in ["ddf", "dvf"]:
        return build_ddf_dvf_model(**builder_kwargs)
    if method == "conditional":
        return build_conditional_model(**builder_kwargs)
    if method == "affine":
        return build_affine_model(**builder_kwargs)

    raise ValueError(f"Unknown method {train_config['method']}")
def test_build_ddf_dvf_model():
    """
    Check output keys and tensor shapes of the models built by build_ddf_dvf_model.

    One model is built with method "ddf" and one with method "dvf"; both are
    then called on the same all-ones inputs.
    """
    batch_size = 1
    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)

    image_loss = {"name": "lncc", "weight": 0.1}
    label_loss = {
        "name": "multi_scale",
        "weight": 1,
        "multi_scale": {
            "loss_type": "dice",
            "loss_scales": [0, 1, 2, 4, 8, 16, 32],
        },
    }
    train_config = {
        "method": "ddf",
        "backbone": {
            "name": "local",
            "num_channel_initial": 4,
            "extract_levels": [1, 2, 3],
        },
        "loss": {
            "dissimilarity": {"image": image_loss, "label": label_loss},
            "regularization": {"weight": 0.0, "energy_type": "bending"},
        },
    }

    # Both models are built with the same arguments; only the "method"
    # entry of the shared train_config differs between the two calls.
    build_kwargs = dict(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=1,
        labeled=True,
        batch_size=batch_size,
        train_config=train_config,
        registry=REGISTRY,
    )

    # The DDF model is built before the config is switched to "dvf".
    model_ddf = build_ddf_dvf_model(**build_kwargs)
    train_config["method"] = "dvf"
    model_dvf = build_ddf_dvf_model(**build_kwargs)

    moving_shape = (batch_size, *moving_image_size)
    fixed_shape = (batch_size, *fixed_image_size)
    # Expected shape of the "ddf"/"dvf" outputs: fixed image shape plus 3.
    field_shape = (*fixed_shape, 3)

    inputs = {
        "moving_image": tf.ones(moving_shape),
        "fixed_image": tf.ones(fixed_shape),
        "indices": 1,
        "moving_label": tf.ones(moving_shape),
        "fixed_label": tf.ones(fixed_shape),
    }
    outputs_ddf = model_ddf(inputs)
    outputs_dvf = model_dvf(inputs)

    allowed_keys = ["dvf", "ddf", "pred_fixed_label"]

    # DDF model: only allowed keys, and no "dvf" entry.
    assert all(key in allowed_keys for key in outputs_ddf)
    assert outputs_ddf["pred_fixed_label"].shape == fixed_shape
    assert outputs_ddf["ddf"].shape == field_shape
    with pytest.raises(KeyError):
        outputs_ddf["dvf"]

    # DVF model: only allowed keys, with both "dvf" and "ddf" present.
    assert all(key in allowed_keys for key in outputs_dvf)
    assert outputs_dvf["pred_fixed_label"].shape == fixed_shape
    assert outputs_dvf["dvf"].shape == field_shape
    assert outputs_dvf["ddf"].shape == field_shape