Example #1
def build_dataset(
    dataset_config: dict, preprocess_config: dict
) -> Tuple[
    Tuple[DataLoader, tf.data.Dataset, int],
    Tuple[Optional[DataLoader], Optional[tf.data.Dataset], Optional[int]],
]:
    """
    Function to prepare the datasets for training and validation.

    :param dataset_config: configuration for the dataset
    :param preprocess_config: configuration for preprocessing
    :return:
    - (data_loader_train, dataset_train, steps_per_epoch_train)
    - (data_loader_val, dataset_val, steps_per_epoch_valid),
      where the validation entries are None if no validation data is configured
    """
    data_loader_train = get_data_loader(dataset_config, "train")
    if data_loader_train is None:
        raise ValueError(
            "Training data loader is None. Probably the data dir path is not defined."
        )
    data_loader_val = get_data_loader(dataset_config, "valid")
    dataset_train = data_loader_train.get_dataset_and_preprocess(
        training=True, repeat=True, **preprocess_config)
    dataset_val = (data_loader_val.get_dataset_and_preprocess(
        training=False, repeat=True, **preprocess_config)
                   if data_loader_val is not None else None)
    dataset_size_train = data_loader_train.num_samples
    dataset_size_val = (data_loader_val.num_samples
                        if data_loader_val is not None else None)
    steps_per_epoch_train = max(
        dataset_size_train // preprocess_config["batch_size"], 1)
    steps_per_epoch_valid = (max(
        dataset_size_val // preprocess_config["batch_size"], 1)
                             if data_loader_val is not None else None)

    return (
        (data_loader_train, dataset_train, steps_per_epoch_train),
        (data_loader_val, dataset_val, steps_per_epoch_valid),
    )
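A minimal usage sketch for the helper above (an illustration, not part of the original file): `dataset_config`, `preprocess_config`, `model` and `num_epochs` are assumed to be defined as in the `train` example further below, and the `model.fit` call mirrors that example.

# build the training and validation pipelines in one call
(
    (data_loader_train, dataset_train, steps_per_epoch_train),
    (data_loader_val, dataset_val, steps_per_epoch_valid),
) = build_dataset(dataset_config, preprocess_config)

# dataset_val and steps_per_epoch_valid are None when no validation dir is configured
model.fit(
    x=dataset_train,
    steps_per_epoch=steps_per_epoch_train,
    epochs=num_epochs,
    validation_data=dataset_val,
    validation_steps=steps_per_epoch_valid,
)

# release any file handles held by the loaders
data_loader_train.close()
if data_loader_val is not None:
    data_loader_val.close()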
Example #2
    def test_mode_err(self):
        """Check the error is raised when the split is wrong."""
        config = load_yaml("config/test/paired_nifti.yaml")
        with pytest.raises(ValueError) as err_info:
            load.get_data_loader(data_config=config["dataset"], split="example")
        assert "split must be one of ['train', 'valid', 'test']" in str(
            err_info.value
        )
Example #3
    def test_dir_err(self, path: Optional[str]):
        """
        Check the error is raised when the path is wrong.

        :param path: training data path to be used
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"]["dir"]["train"] = path
        with pytest.raises(ValueError) as err_info:
            load.get_data_loader(data_config=config["dataset"], mode="train")
        assert "is not a directory or does not exist" in str(err_info.value)
Example #4
def build_dataset(
    dataset_config: dict,
    preprocess_config: dict,
    mode: str,
    training: bool,
    repeat: bool,
) -> Tuple[Optional[DataLoader], Optional[tf.data.Dataset], Optional[int]]:
    """
    Function to prepare the dataset for training, validation or testing.

    :param dataset_config: configuration for dataset
    :param preprocess_config: configuration for preprocess
    :param mode: train or valid or test
    :param training: bool, if true, data augmentation and shuffling will be added
    :param repeat: bool, if true, the dataset will be repeated,
        true for train/valid dataset during model.fit
    :return: (data_loader, dataset, steps_per_epoch),
        all None if the data dir for the mode is not defined

    Cannot move this function into deepreg/dataset/util.py
    as we need DataLoader to define the output
    """
    assert mode in ["train", "valid", "test"]
    data_loader = get_data_loader(dataset_config, mode)
    if data_loader is None:
        return None, None, None
    dataset = data_loader.get_dataset_and_preprocess(
        training=training, repeat=repeat, **preprocess_config
    )
    dataset_size = data_loader.num_samples
    steps_per_epoch = max(dataset_size // preprocess_config["batch_size"], 1)
    return data_loader, dataset, steps_per_epoch
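A short usage sketch for this per-mode variant (an illustration, not part of the original file): `dataset_config` and `preprocess_config` are assumed to be defined as in the surrounding examples; the function is called once per split, and a split without a configured data directory yields three `None` values.

# the training split must exist
data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
    dataset_config=dataset_config,
    preprocess_config=preprocess_config,
    mode="train",
    training=True,
    repeat=True,
)
if data_loader_train is None:
    raise ValueError(
        "Training data loader is None. Probably the data dir path is not defined."
    )

# the validation split is optional; all three values are None if it is not configured
data_loader_val, dataset_val, steps_per_epoch_valid = build_dataset(
    dataset_config=dataset_config,
    preprocess_config=preprocess_config,
    mode="valid",
    training=False,
    repeat=True,
)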
Example #5
    def test_empty_config(self, mode: str):
        """
        Test return without data path for the mode.

        :param mode: train or valid or test
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"]["dir"].pop(mode)
        got = load.get_data_loader(data_config=config["dataset"], mode=mode)
        assert got is None
Example #6
    def test_empty_path(self, path: Optional[str]):
        """
        Test return without data path.

        :param path: training data path to be used
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"]["dir"]["train"] = path
        got = load.get_data_loader(data_config=config["dataset"], mode="train")
        assert got is None
Example #7
    def test_empty_config(self, split: str):
        """
        Test return without data path for the split.

        :param split: train or valid or test
        """
        config = load_yaml("config/test/paired_nifti.yaml")
        config["dataset"].pop(split)
        got = load.get_data_loader(data_config=config["dataset"], split=split)
        assert got is None
Example #8
    def test_data_loader(self, data_type: str, format: str):
        """
        Test the data loader can be successfully built.

        :param data_type: name of data loader for registry
        :param format: name of file loader for registry
        """
        # single paired data loader
        config = load_yaml(f"config/test/{data_type}_{format}.yaml")
        got = load.get_data_loader(data_config=config["dataset"], mode="train")
        expected = REGISTRY.get(category=DATA_LOADER_CLASS, key=data_type)
        assert isinstance(got, expected)  # type: ignore
Example #9
def main(args=None):
    """Entry point for gen_tfrecord"""

    parser = argparse.ArgumentParser(description="gen_tfrecord")

    # add command line arguments
    parser.add_argument("--config_path",
                        "-c",
                        help="Path of config",
                        type=str,
                        required=True)

    parser.add_argument(
        "--examples_per_tfrecord",
        "-n",
        help="Number of examples per tfrecord",
        type=int,
        default=64,
    )

    args = parser.parse_args(args)

    config = config_parser.load_configs(args.config_path)
    dataset_config = config["dataset"]
    tfrecord_dir = dataset_config["tfrecord_dir"]
    dataset_config["tfrecord_dir"] = ""

    if os.path.exists(tfrecord_dir) and os.path.isdir(tfrecord_dir):
        remove = input("%s already exists. Remove it? Y/N\n" % tfrecord_dir)
        if remove.lower() == "y":
            shutil.rmtree(tfrecord_dir)
    for mode in ["train", "valid", "test"]:
        data_loader = load.get_data_loader(dataset_config, mode)
        if data_loader is None:
            # no data dir configured for this mode, nothing to write
            continue
        write_tfrecords(
            data_dir=os.path.join(tfrecord_dir, mode),
            data_generator=data_loader.data_generator(),
            examples_per_tfrecord=args.examples_per_tfrecord,
        )
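Because `main` forwards its optional `args` list to `argparse`, the generator can also be driven directly from Python; a small sketch with a hypothetical config path, assuming that config defines `dataset.tfrecord_dir`:

# equivalent to: gen_tfrecord --config_path config/my_config.yaml --examples_per_tfrecord 128
main(args=["--config_path", "config/my_config.yaml",
           "--examples_per_tfrecord", "128"])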
Example #10
def test_get_data_loader():
    """
    Test for get_data_loader to make sure it gets the correct data loader
    and raises the correct errors.
    """

    # single paired data loader
    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, PairedDataLoader)

    config = load_yaml("deepreg/config/test/paired_h5.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, PairedDataLoader)

    # single unpaired data loader
    config = load_yaml("deepreg/config/test/unpaired_nifti.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, UnpairedDataLoader)

    config = load_yaml("deepreg/config/test/unpaired_h5.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, UnpairedDataLoader)

    # single grouped data loader
    config = load_yaml("deepreg/config/test/grouped_nifti.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, GroupedDataLoader)

    config = load_yaml("deepreg/config/test/grouped_h5.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, GroupedDataLoader)

    # empty data loader
    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    config["dataset"]["dir"]["train"] = ""
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert got is None

    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    config["dataset"]["dir"]["train"] = None
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert got is None

    # unpaired data loader with multiple dirs
    config = load_yaml("deepreg/config/test/unpaired_nifti_multi_dirs.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, UnpairedDataLoader)

    # check not a directory error
    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    config["dataset"]["dir"]["train"] += ".h5"
    with pytest.raises(ValueError) as err_info:
        load.get_data_loader(data_config=config["dataset"], mode="train")
    assert "is not a directory or does not exist" in str(err_info.value)

    # check error when the directory does not exist
    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    config["dataset"]["dir"]["train"] = "/this_should_not_existed"
    with pytest.raises(ValueError) as err_info:
        load.get_data_loader(data_config=config["dataset"], mode="train")
    assert "is not a directory or does not exist" in str(err_info.value)

    # check mode
    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    with pytest.raises(AssertionError) as err_info:
        load.get_data_loader(data_config=config["dataset"], mode="example")
    assert "mode must be one of train/valid/test" in str(err_info.value)
Example #11
def predict(
    gpu,
    gpu_allow_growth,
    ckpt_path,
    mode,
    batch_size,
    log_dir,
    sample_label,
    config_path,
):
    """
    Function to predict some metrics from the saved model and logging results.
    :param gpu: str, which env gpu to use.
    :param gpu_allow_growth: bool, whether to allow gpu growth or not
    :param ckpt_path: str, where model is stored, should be like
                      log_folder/save/xxx.ckpt
    :param mode: which mode to load the data ??
    :param batch_size: int, batch size to perform predictions in
    :param log_dir: str, path to store logs
    :param sample_label:
    :param config_path: to overwrite the default config
    """
    logging.error("TODO sample_label is not used in predict")

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ[
        "TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir = init(log_dir, ckpt_path, config_path)
    dataset_config = config["dataset"]
    preprocess_config = config["train"]["preprocess"]
    preprocess_config["batch_size"] = batch_size
    optimizer_config = config["train"]["optimizer"]
    model_config = config["train"]["model"]
    loss_config = config["train"]["loss"]

    # data
    data_loader = load.get_data_loader(dataset_config, mode)
    if data_loader is None:
        raise ValueError(
            "Data loader for prediction is None. Probably the data dir path is not defined."
        )
    dataset = data_loader.get_dataset_and_preprocess(training=False,
                                                     repeat=False,
                                                     **preprocess_config)

    # optimizer
    optimizer = opt.build_optimizer(optimizer_config)

    # model
    model = build_model(
        moving_image_size=data_loader.moving_image_shape,
        fixed_image_size=data_loader.fixed_image_shape,
        index_size=data_loader.num_indices,
        labeled=dataset_config["labeled"],
        batch_size=preprocess_config["batch_size"],
        model_config=model_config,
        loss_config=loss_config,
    )

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-objec
    model.load_weights(ckpt_path).expect_partial()

    # predict
    fixed_grid_ref = layer_util.get_reference_grid(
        grid_size=data_loader.fixed_image_shape)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=log_dir + "/test",
    )

    data_loader.close()
Example #12
    def test_mode_err(self):
        """Check the error is raised when the mode is wrong."""
        config = load_yaml("config/test/paired_nifti.yaml")
        with pytest.raises(AssertionError) as err_info:
            load.get_data_loader(data_config=config["dataset"], mode="example")
        assert "mode must be one of train/valid/test" in str(err_info.value)
Example #13
    def test_multi_dir_data_loader(self):
        """Unpaired data loader with multiple dirs."""
        config = load_yaml("config/test/unpaired_nifti_multi_dirs.yaml")
        got = load.get_data_loader(data_config=config["dataset"], mode="train")
        assert isinstance(got, UnpairedDataLoader)
Example #14
def train(gpu: str, config_path: list, gpu_allow_growth: bool, ckpt_path: str,
          log_dir: str):
    """
    Function to train a model
    :param gpu: str, which local gpu to use to train
    :param config_path: str, path to configuration set up
    :param gpu_allow_growth: bool, whether or not to allocate
                             whole GPU memory to training
    :param ckpt_path: str, where to store training ckpts
    :param log_dir: str, where to store logs in training
    """
    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ[
        "TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir = init(config_path, log_dir, ckpt_path)
    dataset_config = config["dataset"]
    preprocess_config = config["train"]["preprocess"]
    optimizer_config = config["train"]["optimizer"]
    model_config = config["train"]["model"]
    loss_config = config["train"]["loss"]
    num_epochs = config["train"]["epochs"]
    save_period = config["train"]["save_period"]
    histogram_freq = save_period

    # data
    data_loader_train = get_data_loader(dataset_config, "train")
    if data_loader_train is None:
        raise ValueError(
            "Training data loader is None. Probably the data dir path is not defined."
        )
    data_loader_val = get_data_loader(dataset_config, "valid")
    dataset_train = data_loader_train.get_dataset_and_preprocess(
        training=True, repeat=True, **preprocess_config)
    dataset_val = (data_loader_val.get_dataset_and_preprocess(
        training=False, repeat=True, **preprocess_config)
                   if data_loader_val is not None else None)
    dataset_size_train = data_loader_train.num_samples
    dataset_size_val = (data_loader_val.num_samples
                        if data_loader_val is not None else None)
    steps_per_epoch_train = max(
        dataset_size_train // preprocess_config["batch_size"], 1)
    steps_per_epoch_valid = (max(
        dataset_size_val // preprocess_config["batch_size"], 1)
                             if data_loader_val is not None else None)

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        # model
        model = build_model(
            moving_image_size=data_loader_train.moving_image_shape,
            fixed_image_size=data_loader_train.fixed_image_shape,
            index_size=data_loader_train.num_indices,
            labeled=dataset_config["labeled"],
            batch_size=preprocess_config["batch_size"],
            model_config=model_config,
            loss_config=loss_config,
        )

        # compile
        optimizer = opt.get_optimizer(optimizer_config)

        model.compile(optimizer=optimizer)

        # load weights
        if ckpt_path != "":
            model.load_weights(ckpt_path)

        # train
        # callbacks
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=log_dir, histogram_freq=histogram_freq)
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=log_dir + "/save/weights-epoch{epoch:d}.ckpt",
            save_weights_only=True,
            period=save_period,
        )
        # it's necessary to define the steps_per_epoch and validation_steps to prevent errors like
        # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
        model.fit(
            x=dataset_train,
            steps_per_epoch=steps_per_epoch_train,
            epochs=num_epochs,
            validation_data=dataset_val,
            validation_steps=steps_per_epoch_valid,
            callbacks=[tensorboard_callback, checkpoint_callback],
        )

    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()