def test_backbone(self):
    """A backbone registered under a key is returned unchanged for that key."""
    reg = Registry()
    backbone_key = "new_backbone"
    backbone_value = 0
    reg.register_backbone(backbone_key, backbone_value)
    assert reg.get_backbone(backbone_key) == backbone_value
def test_conditional_forward():
    """
    Check that conditional_forward returns a predicted fixed label of shape
    (batch,) + fixed_image_size and a grid of shape fixed_image_size + (3,).
    """
    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)
    batch_size = 1

    local_config = {
        "name": "local",
        "num_channel_initial": 4,
        "extract_levels": [1, 2, 3],
    }
    local_net = build_backbone(
        image_size=fixed_image_size,
        out_channels=1,
        config=local_config,
        method_name="conditional",
        registry=Registry(),
    )

    moving_shape = (batch_size,) + moving_image_size
    fixed_shape = (batch_size,) + fixed_image_size

    # Check conditional mode network output shapes - Pass
    pred_fixed_label, grid_fixed = conditional_forward(
        backbone=local_net,
        moving_image=tf.ones(moving_shape),
        fixed_image=tf.ones(fixed_shape),
        moving_label=tf.ones(moving_shape),
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
    )
    assert pred_fixed_label.shape == fixed_shape
    assert grid_fixed.shape == fixed_image_size + (3,)
def build_backbone(
    image_size: tuple,
    out_channels: int,
    config: dict,
    method_name: str,
    registry: Registry = None,
) -> tf.keras.Model:
    """
    Backbone model accepts a single input of shape (batch, dim1, dim2, dim3, ch_in)
    and returns a single output of shape (batch, dim1, dim2, dim3, ch_out).

    :param image_size: tuple or list of length 3, dims of image, (dim1, dim2, dim3)
    :param out_channels: int, number of out channels, ch_out
    :param config: dict, backbone configuration; config["name"] selects the
        backbone class and all entries are forwarded to its constructor
    :param method_name: str, one of ddf, dvf, conditional and affine
    :param registry: the registry object having all backbone classes;
        a new Registry() is created when None
    :return: tf.keras.Model
    :raises ValueError: if image_size does not have length 3 or
        method_name is not supported
    """
    if registry is None:
        # created per call instead of as a default argument, which would be
        # evaluated once at import time and shared by every call
        registry = Registry()

    if not (isinstance(image_size, (tuple, list)) and len(image_size) == 3):
        raise ValueError(f"image_size must be tuple of length 3, got {image_size}")

    if method_name not in ["ddf", "dvf", "conditional", "affine"]:
        raise ValueError(
            f"method name has to be one of ddf/dvf/conditional/affine in build_backbone, "
            f"got {method_name}"
        )

    if method_name == "conditional":
        out_activation = "sigmoid"  # output is probability
        out_kernel_initializer = "glorot_uniform"
    else:
        # ddf, dvf and affine: linear output with zero-initialised kernel
        # to ensure a small initial ddf / dvf
        # TODO try random init with smaller number
        out_activation = None
        out_kernel_initializer = "zeros"

    backbone_cls = registry.get_backbone(key=config["name"])
    return backbone_cls(
        image_size=image_size,
        out_channels=out_channels,
        out_kernel_initializer=out_kernel_initializer,
        out_activation=out_activation,
        **config,
    )
def test_build_conditional_model():
    """
    Check that the model from build_conditional_model only outputs
    pred_fixed_label, with shape (batch,) + fixed_image_size.
    """
    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)
    batch_size = 1

    label_loss = {
        "name": "multi_scale",
        "weight": 1,
        "multi_scale": {
            "loss_type": "dice",
            "loss_scales": [0, 1, 2, 4, 8, 16, 32],
        },
    }
    train_config = {
        "method": "conditional",
        "backbone": {
            "name": "local",
            "num_channel_initial": 4,
            "extract_levels": [1, 2, 3],
        },
        "loss": {
            "dissimilarity": {
                "image": {"name": "lncc", "weight": 0.0},
                "label": label_loss,
            },
            "regularization": {"weight": 0.5, "energy_type": "bending"},
        },
    }
    model = build_conditional_model(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=1,
        labeled=True,
        batch_size=batch_size,
        train_config=train_config,
        registry=Registry(),
    )

    moving_shape = (batch_size,) + moving_image_size
    fixed_shape = (batch_size,) + fixed_image_size
    outputs = model(
        {
            "moving_image": tf.ones(moving_shape),
            "fixed_image": tf.ones(fixed_shape),
            "indices": 1,
            "moving_label": tf.ones(moving_shape),
            "fixed_label": tf.ones(fixed_shape),
        }
    )

    allowed_keys = ["pred_fixed_label"]
    for output_key in outputs:
        assert output_key in allowed_keys
    assert outputs["pred_fixed_label"].shape == fixed_shape
def test_wrong_image_size(self):
    """An image_size with four entries is rejected with a ValueError."""
    with pytest.raises(ValueError) as exc:
        util.build_backbone(
            image_size=(1, 1, 1, 1),
            out_channels=1,
            config={},
            method_name="ddf",
            registry=Registry(),
        )
    message = str(exc.value)
    assert "image_size must be tuple of length 3" in message
def test_unet_backbone(self, method_name, out_channels):
    """Smoke test: building a UNet backbone must not raise."""
    unet_config = {"name": "unet", "num_channel_initial": 4, "depth": 4}
    util.build_backbone(
        image_size=(2, 3, 4),
        out_channels=out_channels,
        config=unet_config,
        method_name=method_name,
        registry=Registry(),
    )
def test_local_global_backbone(self, method_name, out_channels, backbone_name):
    """Smoke test: building the named local/global backbone must not raise."""
    backbone_config = {
        "name": backbone_name,
        "num_channel_initial": 4,
        "extract_levels": [1, 2, 3],
    }
    util.build_backbone(
        image_size=(2, 3, 4),
        out_channels=out_channels,
        config=backbone_config,
        method_name=method_name,
        registry=Registry(),
    )
def test_wrong_method_name(self):
    """An unsupported method_name is rejected with a ValueError."""
    expected = (
        "method name has to be one of ddf/dvf/conditional/affine in build_backbone"
    )
    with pytest.raises(ValueError) as exc:
        util.build_backbone(
            image_size=(1, 2, 3),
            out_channels=1,
            config={"backbone": "local"},
            method_name="wrong",
            registry=Registry(),
        )
    assert expected in str(exc.value)
def test_build(self, method, backbone):
    """Check that build_model succeeds for the given method / backbone pair."""
    import copy

    # Deep copy: the nested "backbone" dict is modified below, and a shallow
    # copy would write that change back into self.train_config, leaking it
    # into every other test that reads the shared config.
    train_config = copy.deepcopy(self.train_config)
    train_config["method"] = method
    train_config["backbone"]["name"] = backbone
    build_model(
        moving_image_size=self.moving_image_size,
        fixed_image_size=self.fixed_image_size,
        index_size=self.index_size,
        labeled=True,
        batch_size=self.batch_size,
        train_config=train_config,
        registry=Registry(),
    )
def test_build_err(self):
    """build_model raises ValueError for an unknown training method."""
    # a shallow copy suffices: only a top-level key is replaced
    bad_config = self.train_config.copy()
    bad_config["method"] = "unknown"
    with pytest.raises(ValueError) as exc:
        build_model(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=True,
            batch_size=self.batch_size,
            train_config=bad_config,
            registry=Registry(),
        )
    message = str(exc.value)
    assert "Unknown method" in message
def test_register_err(self, category, key, err_msg):
    """Invalid registrations raise ValueError containing err_msg."""
    reg = Registry()
    with pytest.raises(ValueError) as exc:
        reg.register(category, key, 0)
    message = str(exc.value)
    assert err_msg in message
def test_get_err(self):
    """Looking up a key that was never registered raises ValueError."""
    reg = Registry()
    with pytest.raises(ValueError) as exc:
        reg.get("backbone_class", "wrong_key")
    message = str(exc.value)
    assert "has not been registered" in message
def test_get(self):
    """get returns the value registered under the same category and key."""
    reg = Registry()
    category = "backbone_class"
    key = "test_key"
    value = 0
    reg.register(category, key, value)
    fetched = reg.get(category, key)
    assert fetched == value
def test_register(self):
    """register stores the value in the internal dict under (category, key)."""
    reg = Registry()
    category = "backbone_class"
    key = "test_key"
    value = 0
    reg.register(category, key, value)
    # inspect private storage directly to check the registration itself
    stored = reg._dict[(category, key)]
    assert stored == value
def train(
    gpu: str,
    config_path: (str, list),
    gpu_allow_growth: bool,
    ckpt_path: str,
    log_dir: str,
    log_root: str = "logs",
    max_epochs: int = -1,
    registry: Registry = None,
):
    """
    Function to train a model.

    :param gpu: str, value for CUDA_VISIBLE_DEVICES, i.e. which local gpu to use to train
    :param config_path: str or list, path(s) to configuration set up
    :param gpu_allow_growth: bool, if True, TF_FORCE_GPU_ALLOW_GROWTH is set to "true"
        so TensorFlow allocates GPU memory on demand instead of all at once
    :param ckpt_path: str, checkpoint whose weights are loaded before training;
        loading is skipped when it is an empty string
    :param log_dir: str, where to store logs in training
    :param log_root: str, root of logs
    :param max_epochs: int, if max_epochs > 0, will use it to overwrite the configuration
    :param registry: the registry object having all backbone classes;
        a new Registry() is created when None
    """
    if registry is None:
        # created per call instead of as a default argument, which would be
        # evaluated once at import time and shared by every call
        registry = Registry()

    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir = build_config(
        config_path=config_path,
        log_root=log_root,
        log_dir=log_dir,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None

    data_loader_val = None
    try:
        data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
            dataset_config=config["dataset"],
            preprocess_config=config["train"]["preprocess"],
            mode="valid",
            training=False,
            repeat=True,
        )

        # build callbacks
        callbacks = build_callbacks(
            log_dir=log_dir,
            # use save_period for histogram_freq
            histogram_freq=config["train"]["save_period"],
            save_period=config["train"]["save_period"],
        )

        # use strategy to support multiple GPUs
        # the network is mirrored in each GPU so that we can use larger batch size
        # https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_tfkerasmodelfit
        # only model, optimizer and metrics need to be defined inside the strategy
        if len(tf.config.list_physical_devices("GPU")) > 1:
            strategy = tf.distribute.MirroredStrategy()  # pragma: no cover
        else:
            strategy = tf.distribute.get_strategy()
        with strategy.scope():
            model = build_model(
                moving_image_size=data_loader_train.moving_image_shape,
                fixed_image_size=data_loader_train.fixed_image_shape,
                index_size=data_loader_train.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=config["train"]["preprocess"]["batch_size"],
                train_config=config["train"],
                registry=registry,
            )
            optimizer = opt.build_optimizer(
                optimizer_config=config["train"]["optimizer"]
            )

        # compile
        model.compile(optimizer=optimizer)

        # load weights
        if ckpt_path != "":
            model.load_weights(ckpt_path)

        # train
        # it's necessary to define the steps_per_epoch and validation_steps to prevent errors like
        # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
        model.fit(
            x=dataset_train,
            steps_per_epoch=steps_per_epoch_train,
            epochs=config["train"]["epochs"],
            validation_data=dataset_val,
            validation_steps=steps_per_epoch_val,
            callbacks=callbacks,
        )
    finally:
        # close file loaders in data loaders, also when building or training fails
        data_loader_train.close()
        if data_loader_val is not None:
            data_loader_val.close()
def predict(
    gpu: str,
    gpu_allow_growth: bool,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    log_dir: str,
    sample_label: str,
    config_path: (str, list),
    save_nifti: bool = True,
    save_png: bool = True,
    log_root: str = "logs",
    registry: Registry = None,
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: str, value for CUDA_VISIBLE_DEVICES, i.e. which env gpu to use
    :param gpu_allow_growth: bool, if True, TF_FORCE_GPU_ALLOW_GROWTH is set to "true"
        so TensorFlow allocates GPU memory on demand instead of all at once
    :param ckpt_path: str, where model is stored, should be like log_folder/save/xxx.ckpt
    :param mode: train / valid / test, to define which split of dataset to be evaluated
    :param batch_size: int, batch size per GPU to perform predictions in
    :param log_dir: str, path to store logs
    :param sample_label: sample/all, not used
    :param config_path: to overwrite the default config
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :param log_root: str, root of logs
    :param registry: the registry object having all backbone classes;
        a new Registry() is created when None
    """
    if registry is None:
        # created per call instead of as a default argument, which would be
        # evaluated once at import time and shared by every call
        registry = Registry()

    # TODO support custom sample_label
    logging.warning(
        "sample_label is not used in predict. It is True if and only if mode == 'train'."
    )

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # "true" only when growth is requested (this was previously inverted)
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir = build_config(
        config_path=config_path,
        log_root=log_root,
        log_dir=log_dir,
        ckpt_path=ckpt_path,
    )
    preprocess_config = config["train"]["preprocess"]
    # batch_size corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    preprocess_config["batch_size"] = batch_size * max(len(gpus), 1)

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=preprocess_config,
        mode=mode,
        training=False,
        repeat=False,
    )

    try:
        # optimizer
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])

        # model
        model = build_model(
            moving_image_size=data_loader.moving_image_shape,
            fixed_image_size=data_loader.fixed_image_shape,
            index_size=data_loader.num_indices,
            labeled=config["dataset"]["labeled"],
            batch_size=preprocess_config["batch_size"],
            train_config=config["train"],
            registry=registry,
        )

        # metrics
        model.compile(optimizer=optimizer)

        # load weights
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-objec
        model.load_weights(ckpt_path).expect_partial()

        # predict
        fixed_grid_ref = tf.expand_dims(
            layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
            axis=0,
        )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        predict_on_dataset(
            dataset=dataset,
            fixed_grid_ref=fixed_grid_ref,
            model=model,
            model_method=config["train"]["method"],
            save_dir=log_dir + "/test",
            save_nifti=save_nifti,
            save_png=save_png,
        )
    finally:
        # close the opened files in data loaders, also when prediction fails;
        # guard against a None loader so the original error is not masked
        if data_loader is not None:
            data_loader.close()
def test_build_ddf_dvf_model():
    """
    Testing that build_ddf_dvf_model function returns the tensors with correct shapes.

    The DDF model must output ddf and pred_fixed_label (no dvf); the DVF model
    must output dvf, ddf and pred_fixed_label.
    """
    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)
    batch_size = 1
    ddf_config = {
        "method": "ddf",
        "backbone": {
            "name": "local",
            "num_channel_initial": 4,
            "extract_levels": [1, 2, 3],
        },
        "loss": {
            "dissimilarity": {
                "image": {"name": "lncc", "weight": 0.1},
                "label": {
                    "name": "multi_scale",
                    "weight": 1,
                    "multi_scale": {
                        "loss_type": "dice",
                        "loss_scales": [0, 1, 2, 4, 8, 16, 32],
                    },
                },
            },
            "regularization": {"weight": 0.0, "energy_type": "bending"},
        },
    }
    # Build the DVF config as a separate dict instead of mutating ddf_config
    # in place, so the already-built DDF model cannot observe a changed "method"
    # if the builder keeps a reference to its train_config.
    dvf_config = {**ddf_config, "method": "dvf"}

    # Create DDF model
    model_ddf = build_ddf_dvf_model(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=1,
        labeled=True,
        batch_size=batch_size,
        train_config=ddf_config,
        registry=Registry(),
    )

    # Create DVF model
    model_dvf = build_ddf_dvf_model(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=1,
        labeled=True,
        batch_size=batch_size,
        train_config=dvf_config,
        registry=Registry(),
    )

    inputs = {
        "moving_image": tf.ones((batch_size,) + moving_image_size),
        "fixed_image": tf.ones((batch_size,) + fixed_image_size),
        "indices": 1,
        "moving_label": tf.ones((batch_size,) + moving_image_size),
        "fixed_label": tf.ones((batch_size,) + fixed_image_size),
    }
    outputs_ddf = model_ddf(inputs)
    outputs_dvf = model_dvf(inputs)

    label_shape = (batch_size,) + fixed_image_size
    field_shape = (batch_size,) + fixed_image_size + (3,)
    expected_outputs_keys = ["dvf", "ddf", "pred_fixed_label"]

    assert all(keys in expected_outputs_keys for keys in outputs_ddf)
    assert outputs_ddf["pred_fixed_label"].shape == label_shape
    assert outputs_ddf["ddf"].shape == field_shape
    with pytest.raises(KeyError):
        outputs_ddf["dvf"]

    assert all(keys in expected_outputs_keys for keys in outputs_dvf)
    assert outputs_dvf["pred_fixed_label"].shape == label_shape
    assert outputs_dvf["dvf"].shape == field_shape
    assert outputs_dvf["ddf"].shape == field_shape