Example #1
0
def test_build_optimizer_not_dict():
    """
    A config that is not a dict (here a list)
    must make build_optimizer raise an AssertionError.
    """
    not_a_dict = ["name"]
    with pytest.raises(AssertionError):
        optimizer.build_optimizer(not_a_dict)
Example #2
0
def test_build_optimizer_error():
    """
    A config naming an unknown optimizer
    must make build_optimizer raise a ValueError.
    """
    unknown_config = {"name": "random"}
    with pytest.raises(ValueError):
        optimizer.build_optimizer(unknown_config)
Example #3
0
def test_build_optimizer_adam():
    """
    Check that a config selecting "adam" makes
    build_optimizer return a tf.keras.optimizers.Adam instance.
    """
    built = optimizer.build_optimizer({"name": "adam", "adam": {}})
    assert isinstance(built, tf.keras.optimizers.Adam)
Example #4
0
def test_build_optimizer_rms():
    """
    Check that a config selecting "rms" makes
    build_optimizer return a tf.keras.optimizers.RMSprop instance.
    """
    built = optimizer.build_optimizer({"name": "rms", "rms": {}})
    assert isinstance(built, tf.keras.optimizers.RMSprop)
Example #5
0
def predict(
    gpu: str,
    gpu_allow_growth: bool,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    log_dir: str,
    sample_label: str,
    config_path: (str, list),
    save_nifti: bool = True,
    save_png: bool = True,
):
    """
    Run prediction with a saved model on one dataset split and save the results.

    :param gpu: str, value written to CUDA_VISIBLE_DEVICES
    :param gpu_allow_growth: bool, if true TF_FORCE_GPU_ALLOW_GROWTH is set
        to "true", otherwise to "false"
    :param ckpt_path: str, where model is stored, should be like log_folder/save/xxx.ckpt
    :param mode: train / valid / test, to define which split of dataset to be evaluated
    :param batch_size: int, batch size to perform predictions in
    :param log_dir: str, path to store logs
    :param sample_label: sample/all, not used
    :param config_path: to overwrite the default config
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # TODO support custom sample_label
    logging.warning(
        "sample_label is not used in predict. It is True if and only if mode == 'train'."
    )

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # the flag must mirror the argument: "true" enables memory growth
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir = build_config(config_path=config_path,
                                   log_dir=log_dir,
                                   ckpt_path=ckpt_path)
    preprocess_config = config["train"]["preprocess"]
    # the batch size given by the caller overrides the configured one
    preprocess_config["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=preprocess_config,
        mode=mode,
        training=False,
        repeat=False,
    )

    # optimizer
    optimizer = opt.build_optimizer(
        optimizer_config=config["train"]["optimizer"])

    # model
    model = build_model(
        moving_image_size=data_loader.moving_image_shape,
        fixed_image_size=data_loader.fixed_image_shape,
        index_size=data_loader.num_indices,
        labeled=config["dataset"]["labeled"],
        batch_size=preprocess_config["batch_size"],
        model_config=config["train"]["model"],
        loss_config=config["train"]["loss"],
    )

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-objec
    model.load_weights(ckpt_path).expect_partial()

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0)  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["model"]["method"],
        save_dir=log_dir + "/test",
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
Example #6
0
def train(
    gpu: str,
    config_path: (str, list),
    gpu_allow_growth: bool,
    ckpt_path: str,
    log_dir: str,
):
    """
    Train a model described by the given configuration.

    :param gpu: str, value written to CUDA_VISIBLE_DEVICES
    :param config_path: str or list, path(s) of the configuration file(s)
    :param gpu_allow_growth: bool, if true TF_FORCE_GPU_ALLOW_GROWTH is set
        to "true", otherwise to "false"
    :param ckpt_path: str, checkpoint to load weights from, "" to skip loading
    :param log_dir: str, where to store logs in training
    """
    # environment
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    if gpu_allow_growth:
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
    else:
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "false"

    # configuration
    config, log_dir = build_config(
        config_path=config_path,
        log_dir=log_dir,
        ckpt_path=ckpt_path,
    )
    dataset_config = config["dataset"]
    train_config = config["train"]
    preprocess_config = train_config["preprocess"]

    # training and validation data
    train_loader, train_dataset, train_steps = build_dataset(
        dataset_config=dataset_config,
        preprocess_config=preprocess_config,
        mode="train",
        training=True,
        repeat=True,
    )
    assert train_loader is not None  # train data should not be None
    val_loader, val_dataset, val_steps = build_dataset(
        dataset_config=dataset_config,
        preprocess_config=preprocess_config,
        mode="valid",
        training=False,
        repeat=True,
    )

    # callbacks, save_period is also used for histogram_freq
    save_period = train_config["save_period"]
    callbacks = build_callbacks(
        log_dir=log_dir,
        histogram_freq=save_period,
        save_period=save_period,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_tfkerasmodelfit
    # only model, optimizer and metrics need to be defined inside the strategy
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = build_model(
            moving_image_size=train_loader.moving_image_shape,
            fixed_image_size=train_loader.fixed_image_shape,
            index_size=train_loader.num_indices,
            labeled=dataset_config["labeled"],
            batch_size=preprocess_config["batch_size"],
            model_config=train_config["model"],
            loss_config=train_config["loss"],
        )
        optimizer = opt.build_optimizer(optimizer_config=train_config["optimizer"])

    model.compile(optimizer=optimizer)

    # weights are only restored when a checkpoint path is given
    if ckpt_path != "":
        model.load_weights(ckpt_path)

    # steps_per_epoch and validation_steps have to be given to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=train_dataset,
        steps_per_epoch=train_steps,
        epochs=train_config["epochs"],
        validation_data=val_dataset,
        validation_steps=val_steps,
        callbacks=callbacks,
    )

    # release the files held by the data loaders
    train_loader.close()
    if val_loader is not None:
        val_loader.close()
Example #7
0
def predict(
    gpu,
    gpu_allow_growth,
    ckpt_path,
    mode,
    batch_size,
    log_dir,
    sample_label,
    config_path,
):
    """
    Run prediction with a saved model on one dataset split and save the results.

    :param gpu: str, value written to CUDA_VISIBLE_DEVICES
    :param gpu_allow_growth: bool, if true TF_FORCE_GPU_ALLOW_GROWTH is set
                             to "true", otherwise to "false"
    :param ckpt_path: str, where model is stored, should be like
                      log_folder/save/xxx.ckpt
    :param mode: passed to load.get_data_loader to select the data to load
    :param batch_size: int, batch size to perform predictions in
    :param log_dir: str, path to store logs
    :param sample_label: not used
    :param config_path: to overwrite the default config
    :raises ValueError: if no data loader could be built for the given mode
    """
    logging.error("TODO sample_label is not used in predict")

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # the flag must mirror the argument: "true" enables memory growth
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir = init(log_dir, ckpt_path, config_path)
    dataset_config = config["dataset"]
    preprocess_config = config["train"]["preprocess"]
    # the batch size given by the caller overrides the configured one
    preprocess_config["batch_size"] = batch_size
    optimizer_config = config["train"]["optimizer"]
    model_config = config["train"]["model"]
    loss_config = config["train"]["loss"]

    # data
    data_loader = load.get_data_loader(dataset_config, mode)
    if data_loader is None:
        raise ValueError(
            "Data loader for prediction is None. Probably the data dir path is not defined."
        )
    dataset = data_loader.get_dataset_and_preprocess(training=False,
                                                     repeat=False,
                                                     **preprocess_config)

    # optimizer
    optimizer = opt.build_optimizer(optimizer_config)

    # model
    model = build_model(
        moving_image_size=data_loader.moving_image_shape,
        fixed_image_size=data_loader.fixed_image_shape,
        index_size=data_loader.num_indices,
        labeled=dataset_config["labeled"],
        batch_size=preprocess_config["batch_size"],
        model_config=model_config,
        loss_config=loss_config,
    )

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-objec
    model.load_weights(ckpt_path).expect_partial()

    # predict
    fixed_grid_ref = layer_util.get_reference_grid(
        grid_size=data_loader.fixed_image_shape)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=log_dir + "/test",
    )

    # close the opened files in data loaders
    data_loader.close()
Example #8
0
def predict(
    gpu: str,
    gpu_allow_growth: bool,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Run prediction with a saved model on one dataset split and save the results.

    :param gpu: value written to CUDA_VISIBLE_DEVICES.
    :param gpu_allow_growth: if true TF_FORCE_GPU_ALLOW_GROWTH is set
        to "true", otherwise to "false".
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x
    :param mode: train / valid / test, to define which split of dataset to be evaluated
    :param batch_size: int, batch size per GPU, it is multiplied by the
        number of detected GPUs (at least one) before building the dataset
    :param exp_name: name of the experiment
    :param config_path: to overwrite the default config
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :param log_dir: path of the log directory
    """
    # TODO support custom sample_label
    logging.warning("sample_label is not used in predict. "
                    "It is True if and only if mode == 'train'.")

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # the flag must mirror the argument: "true" enables memory growth
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir, ckpt_path = build_config(config_path=config_path,
                                              log_dir=log_dir,
                                              exp_name=exp_name,
                                              ckpt_path=ckpt_path)
    preprocess_config = config["train"]["preprocess"]
    # batch_size corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    preprocess_config["batch_size"] = batch_size * max(len(gpus), 1)

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=preprocess_config,
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # optimizer
    optimizer = opt.build_optimizer(
        optimizer_config=config["train"]["optimizer"])

    # model
    model: tf.keras.Model = REGISTRY.build_model(config=dict(
        name=config["train"]["method"],
        moving_image_size=data_loader.moving_image_shape,
        fixed_image_size=data_loader.fixed_image_shape,
        index_size=data_loader.num_indices,
        labeled=config["dataset"]["labeled"],
        batch_size=config["train"]["preprocess"]["batch_size"],
        config=config["train"],
    ))

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        # the returned callback and epoch are not needed for prediction
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0)  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["method"],
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
Example #9
0
def train(
    gpu: str,
    config_path: Union[str, List[str]],
    gpu_allow_growth: bool,
    ckpt_path: str,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Train a model described by the given configuration.

    :param gpu: value written to CUDA_VISIBLE_DEVICES.
    :param config_path: path(s) of the configuration file(s).
    :param gpu_allow_growth: if true TF_FORCE_GPU_ALLOW_GROWTH is set
        to "true", otherwise to "false".
    :param ckpt_path: checkpoint path forwarded to build_config and
        build_checkpoint_callback.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: forwarded to build_config,
        if max_epochs > 0, will use it to overwrite the configuration.
    """
    # environment
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    if gpu_allow_growth:
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
    else:
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "false"

    # configuration
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )
    dataset_config = config["dataset"]
    train_config = config["train"]
    preprocess_config = train_config["preprocess"]

    # training and validation data
    train_loader, train_dataset, train_steps = build_dataset(
        dataset_config=dataset_config,
        preprocess_config=preprocess_config,
        mode="train",
        training=True,
        repeat=True,
    )
    assert train_loader is not None  # train data should not be None
    val_loader, val_dataset, val_steps = build_dataset(
        dataset_config=dataset_config,
        preprocess_config=preprocess_config,
        mode="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    gpu_count = len(tf.config.list_physical_devices("GPU"))
    num_devices = max(gpu_count, 1)
    if num_devices <= 1:
        strategy = tf.distribute.get_strategy()
    else:
        strategy = tf.distribute.MirroredStrategy()  # pragma: no cover
    with strategy.scope():
        model_kwargs = dict(
            name=train_config["method"],
            moving_image_size=train_loader.moving_image_shape,
            fixed_image_size=train_loader.fixed_image_shape,
            index_size=train_loader.num_indices,
            labeled=dataset_config["labeled"],
            batch_size=preprocess_config["batch_size"],
            config=train_config,
            num_devices=num_devices,
        )
        model: tf.keras.Model = REGISTRY.build_model(config=model_kwargs)
        optimizer = opt.build_optimizer(optimizer_config=train_config["optimizer"])

    model.compile(optimizer=optimizer)
    model.plot_model(output_dir=log_dir)

    # callbacks, save_period is also used for histogram_freq
    save_period = train_config["save_period"]
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=save_period,
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=train_dataset,
        log_dir=log_dir,
        save_period=save_period,
        ckpt_path=ckpt_path,
    )

    # steps_per_epoch and validation_steps have to be given
    # to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=train_dataset,
        steps_per_epoch=train_steps,
        initial_epoch=initial_epoch,
        epochs=train_config["epochs"],
        validation_data=val_dataset,
        validation_steps=val_steps,
        callbacks=[tensorboard_callback, ckpt_callback],
    )

    # release the files held by the data loaders
    train_loader.close()
    if val_loader is not None:
        val_loader.close()
Example #10
0
def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Train a model described by the given configuration.

    :param gpu: value written to CUDA_VISIBLE_DEVICES.
    :param config_path: path(s) of the configuration file(s).
    :param ckpt_path: checkpoint path forwarded to build_config and
        build_checkpoint_callback.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if true TF_FORCE_GPU_ALLOW_GROWTH is set
        to "true", otherwise to "false".
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: forwarded to build_config,
        if max_epochs > 0, will use it to overwrite the configuration.
    :raises ValueError: if more than one GPU is detected and the batch size
        is not a multiple of the number of GPUs.
    """
    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ[
        "TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    # only a positive value is a limit, zero or negative leaves the
    # thread-count variables untouched
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the training. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    batch_size = config["train"]["preprocess"]["batch_size"]
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices.")
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(config=dict(
            name=config["train"]["method"],
            moving_image_size=data_loader_train.moving_image_shape,
            fixed_image_size=data_loader_train.fixed_image_shape,
            index_size=data_loader_train.num_indices,
            labeled=config["dataset"]["train"]["labeled"],
            batch_size=batch_size,
            config=config["train"],
        ))
        optimizer = opt.build_optimizer(
            optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],
        update_freq=config["train"].get("update_freq", "epoch"),
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # it's necessary to define the steps_per_epoch
    # and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()
Example #11
0
 def test_build_optimizer_sgd(self):
     """A config naming "SGD" must yield a tf.keras.optimizers.SGD instance"""
     built = optimizer.build_optimizer({"name": "SGD"})
     assert isinstance(built, tf.keras.optimizers.SGD)
Example #12
0
 def test_build_optimizer_adam(self):
     """A config naming "Adam" must yield a tf.keras.optimizers.Adam instance"""
     built = optimizer.build_optimizer({"name": "Adam", "learning_rate": 1.0e-5})
     assert isinstance(built, tf.keras.optimizers.Adam)
Example #13
0
def predict(
    gpu: str,
    ckpt_path: str,
    split: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Run prediction with a saved model on one dataset split and save the results.

    :param gpu: value written to CUDA_VISIBLE_DEVICES.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param split: train / valid / test, to define the split to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if true TF_FORCE_GPU_ALLOW_GROWTH is set
        to "true", otherwise to "false".
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    :raises ValueError: if more than one GPU is detected and the batch size
        is not a multiple of the number of GPUs.
    """

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # the flag must mirror the argument: "true" enables memory growth
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    # only a positive value is a limit, zero or negative leaves the
    # thread-count variables untouched
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(config_path=config_path,
                                              log_dir=log_dir,
                                              exp_name=exp_name,
                                              ckpt_path=ckpt_path)
    # the batch size given by the caller overrides the configured one
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split=split,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices.")
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(config=dict(
            name=config["train"]["method"],
            moving_image_size=data_loader.moving_image_shape,
            fixed_image_size=data_loader.fixed_image_shape,
            index_size=data_loader.num_indices,
            labeled=config["dataset"][split]["labeled"],
            batch_size=batch_size,
            config=config["train"],
        ))
        optimizer = opt.build_optimizer(
            optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        # the returned callback and epoch are not needed for prediction
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0)  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()