def test_global_return():
    """
    Check that build_backbone returns a GlobalNet instance when it is
    initialised with the associated "global" backbone config.
    """
    out = util.build_backbone(
        image_size=(1, 2, 3),
        out_channels=1,
        model_config={
            "backbone": "global",
            "global": {"num_channel_initial": 4, "extract_levels": [1, 2, 3]},
        },
        method_name="ddf",
    )
    # Compare against the class itself: building a second network only to
    # take its type is wasteful and couples this test to the GlobalNet
    # constructor signature.
    assert isinstance(out, global_net.GlobalNet)
def build_conditional_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    train_config: dict,
    registry: Registry,
) -> tf.keras.Model:
    """
    Assemble the conditional registration model, whose single output is
    the predicted fixed label.

    :param moving_image_size: (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
    :param index_size: int, the number of indices for identifying a sample
    :param labeled: bool, indicating if the data is labeled
    :param batch_size: int, size of mini-batch
    :param train_config: config for the model and loss, the keys
        "backbone", "method" and "loss" are read here
    :param registry: forwarded to build_backbone
    :return: the built tf.keras.Model
    """
    # model inputs
    input_tensors = build_inputs(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        batch_size=batch_size,
        labeled=labeled,
    )
    moving_image, fixed_image, moving_label, fixed_label, indices = input_tensors

    # backbone network, one output channel
    backbone_config = train_config["backbone"]
    method = train_config["method"]
    backbone = build_backbone(
        image_size=fixed_image_size,
        out_channels=1,
        config=backbone_config,
        method_name=method,
        registry=registry,
    )

    # forward pass, pred_fixed_label is (batch, f_dim1, f_dim2, f_dim3)
    forward_result = conditional_forward(
        backbone=backbone,
        moving_image=moving_image,
        fixed_image=fixed_image,
        moving_label=moving_label,
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
    )
    pred_fixed_label, grid_fixed = forward_result

    # wrap inputs and outputs into a keras model
    model = tf.keras.Model(
        inputs=dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        ),
        outputs=dict(pred_fixed_label=pred_fixed_label),
        name="ConditionalRegistrationModel",
    )

    # attach label loss and metric, then hand the model back
    return add_label_loss(
        model=model,
        grid_fixed=grid_fixed,
        fixed_label=fixed_label,
        pred_fixed_label=pred_fixed_label,
        loss_config=train_config["loss"],
    )
def build_ddf_dvf_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    train_config: dict,
    registry: Registry,
) -> tf.keras.Model:
    """
    Assemble a registration model which outputs DDF/DVF.

    The model outputs always contain "ddf". "dvf" is added when the forward
    pass returns one, and "pred_fixed_label" is added when a moving label
    input exists.

    :param moving_image_size: (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
    :param index_size: int, the number of indices for identifying a sample
    :param labeled: bool, indicating if the data is labeled
    :param batch_size: int, size of mini-batch
    :param train_config: config for the model and loss, the keys
        "backbone", "method" and "loss" are read here
    :param registry: forwarded to build_backbone
    :return: the built tf.keras.Model
    """
    # model inputs
    input_tensors = build_inputs(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        batch_size=batch_size,
        labeled=labeled,
    )
    moving_image, fixed_image, moving_label, fixed_label, indices = input_tensors

    # backbone network, three output channels
    backbone_config = train_config["backbone"]
    method = train_config["method"]
    backbone = build_backbone(
        image_size=fixed_image_size,
        out_channels=3,
        config=backbone_config,
        method_name=method,
        registry=registry,
    )

    # forward pass, a DVF is requested only for the "dvf" method
    forward_result = ddf_dvf_forward(
        backbone=backbone,
        moving_image=moving_image,
        fixed_image=fixed_image,
        moving_label=moving_label,
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        output_dvf=(method == "dvf"),
    )
    dvf, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = forward_result

    # collect model inputs and outputs
    inputs = dict(
        moving_image=moving_image,
        fixed_image=fixed_image,
        indices=indices,
    )
    outputs = dict(ddf=ddf)
    if dvf is not None:
        outputs["dvf"] = dvf

    base_name = method.upper() + "RegistrationModel"
    has_label = moving_label is not None
    if has_label:
        # labeled data: labels become inputs, the warped label an output
        inputs["moving_label"] = moving_label
        inputs["fixed_label"] = fixed_label
        outputs["pred_fixed_label"] = pred_fixed_label
    name_suffix = "WithLabel" if has_label else "WithoutLabel"
    model = tf.keras.Model(
        inputs=inputs, outputs=outputs, name=base_name + name_suffix
    )

    # attach losses and metrics
    loss_config = train_config["loss"]
    model = add_ddf_loss(model=model, ddf=ddf, loss_config=loss_config)
    model = add_image_loss(
        model=model,
        fixed_image=fixed_image,
        pred_fixed_image=pred_fixed_image,
        loss_config=loss_config,
    )
    # NOTE(review): the label loss is added for the unlabeled case as well;
    # presumably add_label_loss copes with missing labels - confirm.
    return add_label_loss(
        model=model,
        grid_fixed=grid_fixed,
        fixed_label=fixed_label,
        pred_fixed_label=pred_fixed_label,
        loss_config=loss_config,
    )
def build_affine_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    model_config: dict,
    loss_config: dict,
):
    """
    Assemble a registration model which outputs the parameters of an
    affine transformation together with the DDF.

    The model outputs always contain "ddf" and "affine". "pred_fixed_label"
    is added when a moving label input exists.

    :param moving_image_size: (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
    :param index_size: int, the number of indices for identifying a sample
    :param labeled: bool, indicating if the data is labeled
    :param batch_size: int, size of mini-batch
    :param model_config: config for the model, the key "method" is read here
    :param loss_config: config for the loss
    :return: the built tf.keras.Model
    """
    # model inputs
    input_tensors = build_inputs(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        batch_size=batch_size,
        labeled=labeled,
    )
    moving_image, fixed_image, moving_label, fixed_label, indices = input_tensors

    # backbone network, three output channels
    method = model_config["method"]
    backbone = build_backbone(
        image_size=fixed_image_size,
        out_channels=3,
        model_config=model_config,
        method_name=method,
    )

    # forward pass
    forward_result = affine_forward(
        backbone=backbone,
        moving_image=moving_image,
        fixed_image=fixed_image,
        moving_label=moving_label,
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
    )
    affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = forward_result

    # collect model inputs and outputs
    inputs = dict(
        moving_image=moving_image,
        fixed_image=fixed_image,
        indices=indices,
    )
    outputs = dict(ddf=ddf, affine=affine)

    base_name = method.upper() + "RegistrationModel"
    has_label = moving_label is not None
    if has_label:
        # labeled data: labels become inputs, the warped label an output
        inputs["moving_label"] = moving_label
        inputs["fixed_label"] = fixed_label
        outputs["pred_fixed_label"] = pred_fixed_label
    name_suffix = "WithLabel" if has_label else "WithoutLabel"
    model = tf.keras.Model(
        inputs=inputs, outputs=outputs, name=base_name + name_suffix
    )

    # attach losses and metrics
    model = add_ddf_loss(model=model, ddf=ddf, loss_config=loss_config)
    model = add_image_loss(
        model=model,
        fixed_image=fixed_image,
        pred_fixed_image=pred_fixed_image,
        loss_config=loss_config,
    )
    # NOTE(review): the label loss is added for the unlabeled case as well;
    # presumably add_label_loss copes with missing labels - confirm.
    return add_label_loss(
        model=model,
        grid_fixed=grid_fixed,
        fixed_label=fixed_label,
        pred_fixed_label=pred_fixed_label,
        loss_config=loss_config,
    )