Esempio n. 1
0
def build_ddf_dvf_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    train_config: dict,
    registry: Registry,
) -> tf.keras.Model:
    """
    Build a model which outputs DDF/DVF.

    The model always has a "ddf" output. It also has a "dvf" output when the
    forward pass returns one, and a "pred_fixed_label" output when labels are
    present. DDF, image and label losses are attached before returning.

    :param moving_image_size: (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
    :param index_size: int, the number of indices for identifying a sample
    :param labeled: bool, indicating if the data is labeled
    :param batch_size: int, size of mini-batch
    :param train_config: config for the model and loss; the keys read here are
        "method" (registration method name; a DVF output is requested when it
        equals "dvf", and it prefixes the model name), "backbone" (passed to
        ``build_backbone`` as ``config``) and "loss" (passed to the loss
        builders as ``loss_config``)
    :param registry: registry forwarded to ``build_backbone``
    :return: the built tf.keras.Model
    """
    method = train_config["method"]

    # inputs; for unlabeled data moving_label / fixed_label are expected to be
    # None (the model-building step below relies on that)
    moving_image, fixed_image, moving_label, fixed_label, indices = build_inputs(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        batch_size=batch_size,
        labeled=labeled,
    )

    # backbone with 3 output channels, built for the fixed image size
    # (presumably one displacement component per spatial axis)
    backbone = build_backbone(
        image_size=fixed_image_size,
        out_channels=3,
        config=train_config["backbone"],
        method_name=method,
        registry=registry,
    )

    # forward; dvf may be None (presumably whenever output_dvf is False)
    dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = ddf_dvf_forward(
        backbone=backbone,
        moving_image=moving_image,
        fixed_image=fixed_image,
        moving_label=moving_label,
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        output_dvf=method == "dvf",
    )

    # build model: assemble inputs/outputs once, then construct a single
    # tf.keras.Model whose name records whether labels are used
    inputs = {
        "moving_image": moving_image,
        "fixed_image": fixed_image,
        "indices": indices,
    }
    outputs = {"ddf": ddf}
    if dvf is not None:
        outputs["dvf"] = dvf
    if moving_label is None:  # unlabeled
        name_suffix = "WithoutLabel"
    else:  # labeled: labels are extra inputs, predicted label an extra output
        inputs["moving_label"] = moving_label
        inputs["fixed_label"] = fixed_label
        outputs["pred_fixed_label"] = pred_fixed_label
        name_suffix = "WithLabel"
    model = tf.keras.Model(
        inputs=inputs,
        outputs=outputs,
        name=method.upper() + "RegistrationModel" + name_suffix,
    )

    # add loss and metric; for unlabeled data fixed_label / pred_fixed_label
    # are passed as-is (NOTE(review): presumably add_label_loss ignores None)
    loss_config = train_config["loss"]
    model = add_ddf_loss(model=model, ddf=ddf, loss_config=loss_config)
    model = add_image_loss(
        model=model,
        fixed_image=fixed_image,
        pred_fixed_image=pred_fixed_image,
        loss_config=loss_config,
    )
    model = add_label_loss(
        model=model,
        grid_fixed=grid_fixed,
        fixed_label=fixed_label,
        pred_fixed_label=pred_fixed_label,
        loss_config=loss_config,
    )

    return model
Esempio n. 2
0
def build_affine_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    model_config: dict,
    loss_config: dict,
) -> "tf.keras.Model":
    """
    Build a model which outputs the parameters for affine transformation.

    The model always has "ddf" and "affine" outputs. It also has a
    "pred_fixed_label" output when labels are present. DDF, image and label
    losses are attached before returning.

    :param moving_image_size: (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
    :param index_size: int, the number of indices for identifying a sample
    :param labeled: bool, indicating if the data is labeled
    :param batch_size: int, size of mini-batch
    :param model_config: config for the model; passed whole to
        ``build_backbone``, and its "method" key is used as the backbone
        method name and as the model-name prefix
    :param loss_config: config for the loss
    :return: the built tf.keras.Model
    """
    # inputs; for unlabeled data moving_label / fixed_label are expected to be
    # None (the model-building step below relies on that)
    moving_image, fixed_image, moving_label, fixed_label, indices = build_inputs(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        batch_size=batch_size,
        labeled=labeled,
    )

    # backbone with 3 output channels, built for the fixed image size
    backbone = build_backbone(
        image_size=fixed_image_size,
        out_channels=3,
        model_config=model_config,
        method_name=model_config["method"],
    )

    # forward
    affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = affine_forward(
        backbone=backbone,
        moving_image=moving_image,
        fixed_image=fixed_image,
        moving_label=moving_label,
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
    )

    # build model: assemble inputs/outputs once, then construct a single
    # tf.keras.Model whose name records whether labels are used
    inputs = {
        "moving_image": moving_image,
        "fixed_image": fixed_image,
        "indices": indices,
    }
    outputs = {"ddf": ddf, "affine": affine}
    if moving_label is None:  # unlabeled
        name_suffix = "WithoutLabel"
    else:  # labeled: labels are extra inputs, predicted label an extra output
        inputs["moving_label"] = moving_label
        inputs["fixed_label"] = fixed_label
        outputs["pred_fixed_label"] = pred_fixed_label
        name_suffix = "WithLabel"
    model = tf.keras.Model(
        inputs=inputs,
        outputs=outputs,
        name=model_config["method"].upper() + "RegistrationModel" + name_suffix,
    )

    # add loss and metric; for unlabeled data fixed_label / pred_fixed_label
    # are passed as-is (NOTE(review): presumably add_label_loss ignores None)
    model = add_ddf_loss(model=model, ddf=ddf, loss_config=loss_config)
    model = add_image_loss(
        model=model,
        fixed_image=fixed_image,
        pred_fixed_image=pred_fixed_image,
        loss_config=loss_config,
    )
    model = add_label_loss(
        model=model,
        grid_fixed=grid_fixed,
        fixed_label=fixed_label,
        pred_fixed_label=pred_fixed_label,
        loss_config=loss_config,
    )

    return model