def test_global_return():
    """
    Check that util.build_backbone returns a GlobalNet instance
    when the model config selects the "global" backbone.

    The backbone is requested for the "ddf" method with a single
    output channel and an image size of (1, 2, 3).
    """
    out = util.build_backbone(
        image_size=(1, 2, 3),
        out_channels=1,
        model_config={
            "backbone": "global",
            "global": {
                "num_channel_initial": 4,
                "extract_levels": [1, 2, 3]
            },
        },
        method_name="ddf",
    )
    # Compare against the class itself: building a second network only to
    # call type() on it is wasted work and adds an unrelated failure mode.
    assert isinstance(out, global_net.GlobalNet)
示例#2
0
def build_conditional_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    train_config: dict,
    registry: Registry,
) -> tf.keras.Model:
    """
    Assemble the conditional registration model, whose only output
    is the predicted fixed label.

    :param moving_image_size: (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
    :param index_size: int, the number of indices for identifying a sample
    :param labeled: bool, indicating if the data is labeled
    :param batch_size: int, size of mini-batch
    :param train_config: dict, read here for the keys
        "backbone", "method" and "loss"
    :param registry: forwarded unchanged to build_backbone
    :return: the built tf.keras.Model
    """
    # create the input tensors
    input_tensors = build_inputs(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        batch_size=batch_size,
        labeled=labeled,
    )
    (moving_image, fixed_image, moving_label, fixed_label,
     indices) = input_tensors

    # create the backbone network; a single output channel is requested
    network = build_backbone(
        image_size=fixed_image_size,
        out_channels=1,
        config=train_config["backbone"],
        method_name=train_config["method"],
        registry=registry,
    )

    # run the forward pass
    # pred_fixed_label is noted upstream as (batch, f_dim1, f_dim2, f_dim3)
    forward_outputs = conditional_forward(
        backbone=network,
        moving_image=moving_image,
        fixed_image=fixed_image,
        moving_label=moving_label,
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
    )
    pred_fixed_label, grid_fixed = forward_outputs

    # wrap everything into a keras model; both labels are always model inputs
    model_inputs = {
        "moving_image": moving_image,
        "fixed_image": fixed_image,
        "moving_label": moving_label,
        "fixed_label": fixed_label,
        "indices": indices,
    }
    model_outputs = {"pred_fixed_label": pred_fixed_label}
    registration_model = tf.keras.Model(
        inputs=model_inputs,
        outputs=model_outputs,
        name="ConditionalRegistrationModel",
    )

    # attach the label loss and metric, then hand the model back
    return add_label_loss(
        model=registration_model,
        grid_fixed=grid_fixed,
        fixed_label=fixed_label,
        pred_fixed_label=pred_fixed_label,
        loss_config=train_config["loss"],
    )
示例#3
0
文件: ddf_dvf.py 项目: CV-IP/DeepReg
def build_ddf_dvf_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    train_config: dict,
    registry: Registry,
) -> tf.keras.Model:
    """
    Assemble a registration model that outputs a DDF and,
    when the forward pass provides one, a DVF.

    Label inputs and the predicted fixed label are attached to the
    model only if build_inputs returned a moving label.

    :param moving_image_size: (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
    :param index_size: int, the number of indices for identifying a sample
    :param labeled: bool, indicating if the data is labeled
    :param batch_size: int, size of mini-batch
    :param train_config: dict, read here for the keys
        "backbone", "method" and "loss"
    :param registry: forwarded unchanged to build_backbone
    :return: the built tf.keras.Model
    """

    # create the input tensors
    input_tensors = build_inputs(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        batch_size=batch_size,
        labeled=labeled,
    )
    (moving_image, fixed_image, moving_label, fixed_label,
     indices) = input_tensors

    # create the backbone network; three output channels are requested
    backbone_config = train_config["backbone"]
    method = train_config["method"]
    network = build_backbone(
        image_size=fixed_image_size,
        out_channels=3,
        config=backbone_config,
        method_name=method,
        registry=registry,
    )

    # run the forward pass; a DVF is requested only for the "dvf" method
    forward_outputs = ddf_dvf_forward(
        backbone=network,
        moving_image=moving_image,
        fixed_image=fixed_image,
        moving_label=moving_label,
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        output_dvf=method == "dvf",
    )
    dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = forward_outputs

    # collect model inputs and outputs
    model_inputs = {
        "moving_image": moving_image,
        "fixed_image": fixed_image,
        "indices": indices,
    }
    model_outputs = {"ddf": ddf}
    if dvf is not None:
        model_outputs["dvf"] = dvf

    # label tensors take part only when they exist
    model_name = method.upper() + "RegistrationModel"
    name_suffix = "WithoutLabel"
    if moving_label is not None:
        model_inputs["moving_label"] = moving_label
        model_inputs["fixed_label"] = fixed_label
        model_outputs["pred_fixed_label"] = pred_fixed_label
        name_suffix = "WithLabel"
    registration_model = tf.keras.Model(
        inputs=model_inputs,
        outputs=model_outputs,
        name=model_name + name_suffix,
    )

    # attach losses and metrics in order: ddf, image, label
    loss_config = train_config["loss"]
    registration_model = add_ddf_loss(
        model=registration_model,
        ddf=ddf,
        loss_config=loss_config,
    )
    registration_model = add_image_loss(
        model=registration_model,
        fixed_image=fixed_image,
        pred_fixed_image=pred_fixed_image,
        loss_config=loss_config,
    )
    return add_label_loss(
        model=registration_model,
        grid_fixed=grid_fixed,
        fixed_label=fixed_label,
        pred_fixed_label=pred_fixed_label,
        loss_config=loss_config,
    )
示例#4
0
文件: affine.py 项目: zcemycl/DeepReg
def build_affine_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    model_config: dict,
    loss_config: dict,
):
    """
    Assemble a registration model that outputs the affine
    transformation parameters together with the resulting DDF.

    Label inputs and the predicted fixed label are attached to the
    model only if build_inputs returned a moving label.

    :param moving_image_size: (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
    :param index_size: int, the number of indices for identifying a sample
    :param labeled: bool, indicating if the data is labeled
    :param batch_size: int, size of mini-batch
    :param model_config: config for the model, passed whole to
        build_backbone; its "method" key is also read here
    :param loss_config: config for the loss
    :return: the built tf.keras.Model
    """

    # create the input tensors
    input_tensors = build_inputs(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        batch_size=batch_size,
        labeled=labeled,
    )
    (moving_image, fixed_image, moving_label, fixed_label,
     indices) = input_tensors

    # create the backbone network; three output channels are requested
    network = build_backbone(
        image_size=fixed_image_size,
        out_channels=3,
        model_config=model_config,
        method_name=model_config["method"],
    )

    # run the forward pass
    forward_outputs = affine_forward(
        backbone=network,
        moving_image=moving_image,
        fixed_image=fixed_image,
        moving_label=moving_label,
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
    )
    (affine, ddf, pred_fixed_image, pred_fixed_label,
     grid_fixed) = forward_outputs

    # collect model inputs and outputs
    model_inputs = {
        "moving_image": moving_image,
        "fixed_image": fixed_image,
        "indices": indices,
    }
    model_outputs = {"ddf": ddf, "affine": affine}

    # label tensors take part only when they exist
    model_name = model_config["method"].upper() + "RegistrationModel"
    name_suffix = "WithoutLabel"
    if moving_label is not None:
        model_inputs["moving_label"] = moving_label
        model_inputs["fixed_label"] = fixed_label
        model_outputs["pred_fixed_label"] = pred_fixed_label
        name_suffix = "WithLabel"
    registration_model = tf.keras.Model(
        inputs=model_inputs,
        outputs=model_outputs,
        name=model_name + name_suffix,
    )

    # attach losses and metrics in order: ddf, image, label
    registration_model = add_ddf_loss(
        model=registration_model,
        ddf=ddf,
        loss_config=loss_config,
    )
    registration_model = add_image_loss(
        model=registration_model,
        fixed_image=fixed_image,
        pred_fixed_image=pred_fixed_image,
        loss_config=loss_config,
    )
    return add_label_loss(
        model=registration_model,
        grid_fixed=grid_fixed,
        fixed_label=fixed_label,
        pred_fixed_label=pred_fixed_label,
        loss_config=loss_config,
    )