def test_build_affine_model():
    """
    Check that build_affine_model returns tensors with the expected shapes.

    A labeled affine model is built for a (1, 3, 5) moving image and a
    (2, 4, 6) fixed image with batch size 1. It is called once on all-ones
    inputs, and the "affine", "ddf" and "pred_fixed_label" outputs are checked.
    """
    batch_size = 1
    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)

    train_config = {
        "method": "affine",
        "backbone": {
            "name": "global",
            "num_channel_initial": 4,
            "extract_levels": [1, 2, 3],
        },
        "loss": {
            "dissimilarity": {
                "image": {"name": "lncc", "weight": 0.1},
                "label": {
                    "name": "multi_scale",
                    "weight": 1,
                    "multi_scale": {
                        "loss_type": "dice",
                        "loss_scales": [0, 1, 2, 4, 8, 16, 32],
                    },
                },
            },
            "regularization": {"weight": 0.0, "energy_type": "bending"},
        },
    }
    model = build_affine_model(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=1,
        labeled=True,
        batch_size=batch_size,
        train_config=train_config,
    )

    # Batched shapes shared by the image and label inputs.
    moving_shape = (batch_size,) + moving_image_size
    fixed_shape = (batch_size,) + fixed_image_size
    inputs = {
        "moving_image": tf.ones(moving_shape),
        "fixed_image": tf.ones(fixed_shape),
        "indices": 1,
        "moving_label": tf.ones(moving_shape),
        "fixed_label": tf.ones(fixed_shape),
    }
    outputs = model(inputs)

    # Every returned key must be one we know about; the indexing below
    # additionally requires all three to be present.
    expected_keys = ["affine", "ddf", "pred_fixed_label"]
    assert all(name in expected_keys for name in outputs)
    assert outputs["pred_fixed_label"].shape == fixed_shape
    # The affine output is expected to be a (4, 3) matrix per batch element.
    assert outputs["affine"].shape == (batch_size, 4, 3)
    # One 3-component displacement per fixed-image voxel.
    assert outputs["ddf"].shape == fixed_shape + (3,)
def build_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    model_config: dict,
    loss_config: dict,
):
    """
    Parsing algorithm types to model building functions.

    Dispatches on ``model_config["method"]``:

    - "ddf" or "dvf" -> build_ddf_dvf_model
    - "conditional" -> build_conditional_model
    - "affine" -> build_affine_model

    All arguments are forwarded unchanged, as keyword arguments, to the
    selected builder.

    :param moving_image_size: [m_dim1, m_dim2, m_dim3]
    :param fixed_image_size: [f_dim1, f_dim2, f_dim3]
    :param index_size: dataset size
    :param labeled: true if the label of moving/fixed images are provided
    :param batch_size: mini-batch size
    :param model_config: model configuration, e.g. dictionary return from parser.yaml.load
    :param loss_config: loss configuration, e.g. dictionary return from parser.yaml.load
    :return: the built tf.keras.Model
    :raises KeyError: if model_config has no "method" entry
    :raises ValueError: if model_config["method"] is not a supported method
    """
    method = model_config["method"]
    # Only the matching builder name is looked up, exactly as in a plain if/elif dispatch.
    if method in ("ddf", "dvf"):
        builder = build_ddf_dvf_model
    elif method == "conditional":
        builder = build_conditional_model
    elif method == "affine":
        builder = build_affine_model
    else:
        # Report the offending value; the original message prefix is kept so
        # existing checks on "Unknown model method" still match.
        raise ValueError(f"Unknown model method {method}")
    return builder(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        labeled=labeled,
        batch_size=batch_size,
        model_config=model_config,
        loss_config=loss_config,
    )
def build_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    train_config: dict,
    registry: Registry,
):
    """
    Select and invoke the model builder named by ``train_config["method"]``.

    Supported methods:

    - "ddf" or "dvf" -> build_ddf_dvf_model
    - "conditional" -> build_conditional_model
    - "affine" -> build_affine_model

    Every argument is passed through unchanged, by keyword, to the chosen
    builder.

    :param moving_image_size: [m_dim1, m_dim2, m_dim3]
    :param fixed_image_size: [f_dim1, f_dim2, f_dim3]
    :param index_size: dataset size
    :param labeled: true if the label of moving/fixed images are provided
    :param batch_size: mini-batch size
    :param train_config: train configuration; its "method" entry selects the builder
    :param registry: registry object forwarded as-is to the selected builder
    :return: the built tf.keras.Model
    :raises ValueError: if the method is not one of the supported values
    """
    method = train_config["method"]
    if method in ("ddf", "dvf"):
        builder = build_ddf_dvf_model
    elif method == "conditional":
        builder = build_conditional_model
    elif method == "affine":
        builder = build_affine_model
    else:
        raise ValueError(f"Unknown method {method}")

    return builder(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        labeled=labeled,
        batch_size=batch_size,
        train_config=train_config,
        registry=registry,
    )