Example #1
def test_backbone(self):
    registry = Registry()
    key = "new_backbone"
    value = 0
    registry.register_backbone(key, value)
    got = registry.get_backbone(key)
    assert got == value
Example #2
def test_conditional_forward():
    """
    Test that the conditional_forward function returns tensors with the correct shapes.
    """

    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)
    batch_size = 1

    local_net = build_backbone(
        image_size=fixed_image_size,
        out_channels=1,
        config={
            "name": "local",
            "num_channel_initial": 4,
            "extract_levels": [1, 2, 3],
        },
        method_name="conditional",
        registry=Registry(),
    )

    # Check that the conditional-mode network outputs have the expected shapes
    pred_fixed_label, grid_fixed = conditional_forward(
        backbone=local_net,
        moving_image=tf.ones((batch_size, ) + moving_image_size),
        fixed_image=tf.ones((batch_size, ) + fixed_image_size),
        moving_label=tf.ones((batch_size, ) + moving_image_size),
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
    )
    assert pred_fixed_label.shape == (batch_size, ) + fixed_image_size
    assert grid_fixed.shape == fixed_image_size + (3, )
Example #3
def build_backbone(
        image_size: tuple,
        out_channels: int,
        config: dict,
        method_name: str,
        registry: Registry = Registry(),
) -> tf.keras.Model:
    """
    Backbone model accepts a single input of shape (batch, dim1, dim2, dim3, ch_in)
    and returns a single output of shape (batch, dim1, dim2, dim3, ch_out).

    :param image_size: tuple, dims of image, (dim1, dim2, dim3)
    :param out_channels: int, number of out channels, ch_out
    :param config: dict, backbone configuration
    :param method_name: str, one of ddf, dvf, conditional and affine
    :param registry: the registry object holding all backbone classes
    :return: tf.keras.Model
    """
    if not (isinstance(image_size, (tuple, list)) and len(image_size) == 3):
        raise ValueError(
            f"image_size must be tuple of length 3, got {image_size}")

    if method_name not in ["ddf", "dvf", "conditional", "affine"]:
        raise ValueError(
            f"method name has to be one of ddf/dvf/conditional/affine in build_backbone, "
            f"got {method_name}")

    if method_name in ["ddf", "dvf"]:
        out_activation = None
        # TODO try random init with smaller number
        out_kernel_initializer = "zeros"  # to ensure small ddf and dvf
    elif method_name in ["conditional"]:
        out_activation = "sigmoid"  # output is probability
        out_kernel_initializer = "glorot_uniform"
    elif method_name in ["affine"]:
        out_activation = None
        out_kernel_initializer = "zeros"

    backbone_cls = registry.get_backbone(key=config["name"])
    return backbone_cls(
        image_size=image_size,
        out_channels=out_channels,
        out_kernel_initializer=out_kernel_initializer,
        out_activation=out_activation,
        **config,
    )
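As a usage sketch, a custom backbone class can be registered under a new key and then instantiated through build_backbone. MyBackbone below is a hypothetical minimal backbone invented for illustration; only Registry, register_backbone and build_backbone come from the examples in this section.

import tensorflow as tf

class MyBackbone(tf.keras.Model):
    # Hypothetical backbone: maps (batch, dim1, dim2, dim3, ch_in)
    # to (batch, dim1, dim2, dim3, ch_out), as build_backbone expects.
    def __init__(self, image_size, out_channels, out_kernel_initializer,
                 out_activation, **kwargs):
        super().__init__()
        self._conv = tf.keras.layers.Conv3D(
            filters=out_channels,
            kernel_size=3,
            padding="same",
            kernel_initializer=out_kernel_initializer,
            activation=out_activation,
        )

    def call(self, inputs):
        return self._conv(inputs)

registry = Registry()
registry.register_backbone("my_backbone", MyBackbone)
backbone = build_backbone(
    image_size=(2, 3, 4),
    out_channels=3,
    config={"name": "my_backbone"},  # key under which the class was registered
    method_name="ddf",
    registry=registry,
)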
Example #4
def test_build_conditional_model():
    """
    Test that the build_conditional_model function returns tensors with the correct shapes.
    """
    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)
    batch_size = 1

    model = build_conditional_model(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=1,
        labeled=True,
        batch_size=batch_size,
        train_config={
            "method": "conditional",
            "backbone": {
                "name": "local",
                "num_channel_initial": 4,
                "extract_levels": [1, 2, 3],
            },
            "loss": {
                "dissimilarity": {
                    "image": {
                        "name": "lncc",
                        "weight": 0.0
                    },
                    "label": {
                        "name": "multi_scale",
                        "weight": 1,
                        "multi_scale": {
                            "loss_type": "dice",
                            "loss_scales": [0, 1, 2, 4, 8, 16, 32],
                        },
                    },
                },
                "regularization": {
                    "weight": 0.5,
                    "energy_type": "bending"
                },
            },
        },
        registry=Registry(),
    )

    inputs = {
        "moving_image": tf.ones((batch_size, ) + moving_image_size),
        "fixed_image": tf.ones((batch_size, ) + fixed_image_size),
        "indices": 1,
        "moving_label": tf.ones((batch_size, ) + moving_image_size),
        "fixed_label": tf.ones((batch_size, ) + fixed_image_size),
    }
    outputs = model(inputs)

    expected_outputs_keys = ["pred_fixed_label"]
    assert all(keys in expected_outputs_keys for keys in outputs)
    assert outputs["pred_fixed_label"].shape == (
        batch_size, ) + fixed_image_size
Example #5
def test_wrong_image_size(self):
    with pytest.raises(ValueError) as err_info:
        util.build_backbone(
            image_size=(1, 1, 1, 1),
            out_channels=1,
            config={},
            method_name="ddf",
            registry=Registry(),
        )
    assert "image_size must be tuple of length 3" in str(err_info.value)
Example #6
def test_unet_backbone(self, method_name, out_channels):
    """Only test that the function returns successfully."""
    util.build_backbone(
        image_size=(2, 3, 4),
        out_channels=out_channels,
        config={
            "name": "unet",
            "num_channel_initial": 4,
            "depth": 4,
        },
        method_name=method_name,
        registry=Registry(),
    )
Example #7
def test_local_global_backbone(self, method_name, out_channels, backbone_name):
    """Only test that the function returns successfully."""
    util.build_backbone(
        image_size=(2, 3, 4),
        out_channels=out_channels,
        config={
            "name": backbone_name,
            "num_channel_initial": 4,
            "extract_levels": [1, 2, 3],
        },
        method_name=method_name,
        registry=Registry(),
    )
Example #8
def test_wrong_method_name(self):
    with pytest.raises(ValueError) as err_info:
        util.build_backbone(
            image_size=(1, 2, 3),
            out_channels=1,
            config={"backbone": "local"},
            method_name="wrong",
            registry=Registry(),
        )
    assert (
        "method name has to be one of ddf/dvf/conditional/affine in build_backbone"
        in str(err_info.value)
    )
Example #9
def test_build(self, method, backbone):
    train_config = self.train_config.copy()
    train_config["method"] = method
    train_config["backbone"]["name"] = backbone
    build_model(
        moving_image_size=self.moving_image_size,
        fixed_image_size=self.fixed_image_size,
        index_size=self.index_size,
        labeled=True,
        batch_size=self.batch_size,
        train_config=train_config,
        registry=Registry(),
    )
Example #10
def test_build_err(self):
    train_config = self.train_config.copy()
    train_config["method"] = "unknown"
    with pytest.raises(ValueError) as err_info:
        build_model(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=True,
            batch_size=self.batch_size,
            train_config=train_config,
            registry=Registry(),
        )
    assert "Unknown method" in str(err_info.value)
Example #11
def test_register_err(self, category, key, err_msg):
    registry = Registry()
    with pytest.raises(ValueError) as err_info:
        registry.register(category, key, 0)
    assert err_msg in str(err_info.value)
Example #12
def test_get_err(self):
    registry = Registry()
    with pytest.raises(ValueError) as err_info:
        registry.get("backbone_class", "wrong_key")
    assert "has not been registered" in str(err_info.value)
Example #13
def test_get(self):
    category, key, value = "backbone_class", "test_key", 0
    registry = Registry()
    registry.register(category, key, value)
    assert registry.get(category, key) == value
Example #14
def test_register(self):
    category, key, value = "backbone_class", "test_key", 0
    registry = Registry()
    registry.register(category, key, value)
    assert registry._dict[(category, key)] == value
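Examples #1 and #11 to #14 together pin down the Registry contract: values live in an internal _dict keyed by (category, key) pairs, invalid registrations raise ValueError, and looking up an unregistered key raises a ValueError containing "has not been registered". A minimal sketch consistent with those tests (an assumption for illustration, not the project's actual implementation) could look like this:

class Registry:
    # Minimal sketch of a (category, key) -> value registry.
    BACKBONE_CLASS = "backbone_class"  # category used by the tests above

    def __init__(self):
        self._dict = {}

    def register(self, category, key, value):
        # the real register also validates category and key and raises
        # ValueError on bad input, as exercised by test_register_err
        if (category, key) in self._dict:
            raise ValueError(f"{category} {key} has already been registered")
        self._dict[(category, key)] = value

    def get(self, category, key):
        if (category, key) not in self._dict:
            raise ValueError(f"{category} {key} has not been registered")
        return self._dict[(category, key)]

    def register_backbone(self, key, value):
        self.register(self.BACKBONE_CLASS, key, value)

    def get_backbone(self, key):
        return self.get(self.BACKBONE_CLASS, key)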
Example #15
def train(
        gpu: str,
        config_path: (str, list),
        gpu_allow_growth: bool,
        ckpt_path: str,
        log_dir: str,
        log_root: str = "logs",
        max_epochs: int = -1,
        registry: Registry = Registry(),
):
    """
    Function to train a model.

    :param gpu: str, which local gpu to use to train
    :param config_path: str or list of str, path(s) to the configuration file(s)
    :param gpu_allow_growth: bool, if true, allocate GPU memory on demand instead of reserving it all up front
    :param ckpt_path: str, where to store training checkpoints
    :param log_dir: str, where to store logs in training
    :param log_root: str, root of logs
    :param max_epochs: int, if max_epochs > 0, it overwrites the number of epochs in the configuration
    :param registry: the registry object holding all registered classes
    """
    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
        "true" if gpu_allow_growth else "false")

    # load config
    config, log_dir = build_config(
        config_path=config_path,
        log_root=log_root,
        log_dir=log_dir,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="valid",
        training=False,
        repeat=True,
    )

    # build callbacks
    callbacks = build_callbacks(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],  # use save_period for histogram_freq
        save_period=config["train"]["save_period"],
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_tfkerasmodelfit
    # only model, optimizer and metrics need to be defined inside the strategy
    if len(tf.config.list_physical_devices("GPU")) > 1:
        strategy = tf.distribute.MirroredStrategy()  # pragma: no cover
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model = build_model(
            moving_image_size=data_loader_train.moving_image_shape,
            fixed_image_size=data_loader_train.fixed_image_shape,
            index_size=data_loader_train.num_indices,
            labeled=config["dataset"]["labeled"],
            batch_size=config["train"]["preprocess"]["batch_size"],
            train_config=config["train"],
            registry=registry,
        )
        optimizer = opt.build_optimizer(
            optimizer_config=config["train"]["optimizer"])

    # compile
    model.compile(optimizer=optimizer)

    # load weights
    if ckpt_path != "":
        model.load_weights(ckpt_path)

    # train
    # it's necessary to define the steps_per_epoch and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()
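For orientation, here is a hypothetical call to train; the GPU index, config file name and log directory are placeholders, not values taken from the source.

# Hypothetical invocation; paths and names are placeholders.
train(
    gpu="0",                    # use the first local GPU
    config_path="config.yaml",  # placeholder configuration file
    gpu_allow_growth=True,      # allocate GPU memory on demand
    ckpt_path="",               # empty string: no checkpoint is loaded
    log_dir="my_run",           # combined with log_root by build_config
)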
Example #16
def predict(
        gpu: str,
        gpu_allow_growth: bool,
        ckpt_path: str,
        mode: str,
        batch_size: int,
        log_dir: str,
        sample_label: str,
        config_path: (str, list),
        save_nifti: bool = True,
        save_png: bool = True,
        log_root: str = "logs",
        registry: Registry = Registry(),
):
    """
    Function to predict some metrics from the saved model and log the results.

    :param gpu: str, which env gpu to use
    :param gpu_allow_growth: bool, whether to allow gpu growth or not
    :param ckpt_path: str, where the model is stored, should be like log_folder/save/xxx.ckpt
    :param mode: train / valid / test, to define which split of the dataset is evaluated
    :param batch_size: int, batch size to perform predictions in
    :param log_dir: str, path to store logs
    :param sample_label: sample/all, not used
    :param config_path: to overwrite the default config
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :param log_root: str, root of logs
    :param registry: the registry object holding all registered classes
    """
    # TODO support custom sample_label
    logging.warning(
        "sample_label is not used in predict. It is True if and only if mode == 'train'."
    )

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
        "true" if gpu_allow_growth else "false")

    # load config
    config, log_dir = build_config(config_path=config_path,
                                   log_root=log_root,
                                   log_dir=log_dir,
                                   ckpt_path=ckpt_path)
    preprocess_config = config["train"]["preprocess"]
    # batch_size corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    preprocess_config["batch_size"] = batch_size * max(len(gpus), 1)

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=preprocess_config,
        mode=mode,
        training=False,
        repeat=False,
    )

    # optimizer
    optimizer = opt.build_optimizer(
        optimizer_config=config["train"]["optimizer"])

    # model
    model = build_model(
        moving_image_size=data_loader.moving_image_shape,
        fixed_image_size=data_loader.fixed_image_shape,
        index_size=data_loader.num_indices,
        labeled=config["dataset"]["labeled"],
        batch_size=preprocess_config["batch_size"],
        train_config=config["train"],
        registry=registry,
    )

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-objec
    model.load_weights(ckpt_path).expect_partial()

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0)  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["method"],
        save_dir=log_dir + "/test",
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
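Similarly, a hypothetical call to predict; the checkpoint path is a placeholder that merely follows the log_folder/save/xxx.ckpt pattern named in the docstring.

# Hypothetical invocation; paths are placeholders.
predict(
    gpu="0",
    gpu_allow_growth=True,
    ckpt_path="logs/my_run/save/weights.ckpt",  # placeholder checkpoint
    mode="test",                # evaluate the test split
    batch_size=1,
    log_dir="my_run_eval",
    sample_label="all",         # currently unused, see the warning above
    config_path="config.yaml",  # placeholder; overwrites the default config
)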
Example #17
def test_build_ddf_dvf_model():
    """
    Test that the build_ddf_dvf_model function returns tensors with the correct shapes.
    """
    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)
    batch_size = 1
    train_config = {
        "method": "ddf",
        "backbone": {
            "name": "local",
            "num_channel_initial": 4,
            "extract_levels": [1, 2, 3],
        },
        "loss": {
            "dissimilarity": {
                "image": {"name": "lncc", "weight": 0.1},
                "label": {
                    "name": "multi_scale",
                    "weight": 1,
                    "multi_scale": {
                        "loss_type": "dice",
                        "loss_scales": [0, 1, 2, 4, 8, 16, 32],
                    },
                },
            },
            "regularization": {"weight": 0.0, "energy_type": "bending"},
        },
    }

    # Create DDF model
    model_ddf = build_ddf_dvf_model(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=1,
        labeled=True,
        batch_size=batch_size,
        train_config=train_config,
        registry=Registry(),
    )

    # Create DVF model
    train_config["method"] = "dvf"
    model_dvf = build_ddf_dvf_model(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=1,
        labeled=True,
        batch_size=batch_size,
        train_config=train_config,
        registry=Registry(),
    )
    inputs = {
        "moving_image": tf.ones((batch_size,) + moving_image_size),
        "fixed_image": tf.ones((batch_size,) + fixed_image_size),
        "indices": 1,
        "moving_label": tf.ones((batch_size,) + moving_image_size),
        "fixed_label": tf.ones((batch_size,) + fixed_image_size),
    }
    outputs_ddf = model_ddf(inputs)
    outputs_dvf = model_dvf(inputs)

    expected_outputs_keys = ["dvf", "ddf", "pred_fixed_label"]
    assert all(keys in expected_outputs_keys for keys in outputs_ddf)
    assert outputs_ddf["pred_fixed_label"].shape == (batch_size,) + fixed_image_size
    assert outputs_ddf["ddf"].shape == (batch_size,) + fixed_image_size + (3,)
    with pytest.raises(KeyError):
        outputs_ddf["dvf"]

    assert all(keys in expected_outputs_keys for keys in outputs_dvf)
    assert outputs_dvf["pred_fixed_label"].shape == (batch_size,) + fixed_image_size
    assert outputs_dvf["dvf"].shape == (batch_size,) + fixed_image_size + (3,)
    assert outputs_dvf["ddf"].shape == (batch_size,) + fixed_image_size + (3,)