def build_dataset(
    dataset_config: dict, preprocess_config: dict
) -> Tuple[
    Tuple[DataLoader, tf.data.Dataset, int],
    Tuple[Optional[DataLoader], Optional[tf.data.Dataset], Optional[int]],
]:
    """
    Function to prepare datasets for training and validation.

    :param dataset_config: configuration for dataset
    :param preprocess_config: configuration for preprocess
    :return:
        - (data_loader_train, dataset_train, steps_per_epoch_train)
        - (data_loader_val, dataset_val, steps_per_epoch_valid)
    """
    data_loader_train = get_data_loader(dataset_config, "train")
    if data_loader_train is None:
        raise ValueError(
            "Training data loader is None. Probably the data dir path is not defined."
        )
    data_loader_val = get_data_loader(dataset_config, "valid")

    dataset_train = data_loader_train.get_dataset_and_preprocess(
        training=True, repeat=True, **preprocess_config
    )
    dataset_val = (
        data_loader_val.get_dataset_and_preprocess(
            training=False, repeat=True, **preprocess_config
        )
        if data_loader_val is not None
        else None
    )

    dataset_size_train = data_loader_train.num_samples
    dataset_size_val = (
        data_loader_val.num_samples if data_loader_val is not None else None
    )
    steps_per_epoch_train = max(
        dataset_size_train // preprocess_config["batch_size"], 1
    )
    steps_per_epoch_valid = (
        max(dataset_size_val // preprocess_config["batch_size"], 1)
        if data_loader_val is not None
        else None
    )

    return (
        (data_loader_train, dataset_train, steps_per_epoch_train),
        (data_loader_val, dataset_val, steps_per_epoch_valid),
    )
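# A minimal usage sketch for build_dataset; it assumes the config was parsed
# as elsewhere in this codebase (config_parser.load_configs), and the yaml
# path is a placeholder, not a real file:
config = config_parser.load_configs("config/my_config.yaml")
(
    (data_loader_train, dataset_train, steps_per_epoch_train),
    (data_loader_val, dataset_val, steps_per_epoch_valid),
) = build_dataset(
    dataset_config=config["dataset"],
    preprocess_config=config["train"]["preprocess"],
)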
def test_mode_err(self):
    """Check the error is raised when the split is wrong."""
    config = load_yaml("config/test/paired_nifti.yaml")
    with pytest.raises(ValueError) as err_info:
        load.get_data_loader(data_config=config["dataset"], split="example")
    assert "split must be one of ['train', 'valid', 'test']" in str(err_info.value)
def test_dir_err(self, path: Optional[str]):
    """
    Check the error is raised when the path is wrong.

    :param path: training data path to be used
    """
    config = load_yaml("config/test/paired_nifti.yaml")
    config["dataset"]["dir"]["train"] = path
    with pytest.raises(ValueError) as err_info:
        load.get_data_loader(data_config=config["dataset"], mode="train")
    assert "is not a directory or does not exist" in str(err_info.value)
def build_dataset(
    dataset_config: dict,
    preprocess_config: dict,
    mode: str,
    training: bool,
    repeat: bool,
) -> Tuple[Optional[DataLoader], Optional[tf.data.Dataset], Optional[int]]:
    """
    Function to prepare the dataset for a given mode.

    :param dataset_config: configuration for dataset
    :param preprocess_config: configuration for preprocess
    :param mode: train or valid or test
    :param training: bool, if true, data augmentation and shuffling will be added
    :param repeat: bool, if true, dataset will be repeated,
        true for train/valid dataset during model.fit
    :return: (data_loader, dataset, steps_per_epoch) for the given mode,
        or (None, None, None) if no data loader is defined for it

    Cannot move this function into deepreg/dataset/util.py
    as we need DataLoader to define the output.
    """
    assert mode in ["train", "valid", "test"]
    data_loader = get_data_loader(dataset_config, mode)
    if data_loader is None:
        return None, None, None

    dataset = data_loader.get_dataset_and_preprocess(
        training=training, repeat=repeat, **preprocess_config
    )
    dataset_size = data_loader.num_samples
    steps_per_epoch = max(dataset_size // preprocess_config["batch_size"], 1)
    return data_loader, dataset, steps_per_epoch
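# A minimal sketch of calling the per-mode build_dataset for train and valid,
# mirroring what the two-tuple variant above does in one call; dataset_config
# and preprocess_config are assumed to be the parsed config sections:
data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
    dataset_config=dataset_config,
    preprocess_config=preprocess_config,
    mode="train",
    training=True,
    repeat=True,
)
data_loader_val, dataset_val, steps_per_epoch_valid = build_dataset(
    dataset_config=dataset_config,
    preprocess_config=preprocess_config,
    mode="valid",
    training=False,
    repeat=True,
)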
def test_empty_config(self, mode: str):
    """
    Test return without data path for the mode.

    :param mode: train or valid or test
    """
    config = load_yaml("config/test/paired_nifti.yaml")
    config["dataset"]["dir"].pop(mode)
    got = load.get_data_loader(data_config=config["dataset"], mode=mode)
    assert got is None
def test_empty_path(self, path: Optional[str]):
    """
    Test return without data path.

    :param path: training data path to be used
    """
    config = load_yaml("config/test/paired_nifti.yaml")
    config["dataset"]["dir"]["train"] = path
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert got is None
def test_empty_config(self, split: str):
    """
    Test return without data path for the split.

    :param split: train or valid or test
    """
    config = load_yaml("config/test/paired_nifti.yaml")
    config["dataset"].pop(split)
    got = load.get_data_loader(data_config=config["dataset"], split=split)
    assert got is None
def test_data_loader(self, data_type: str, format: str):
    """
    Test the data loader can be successfully built.

    :param data_type: name of data loader for registry
    :param format: name of file loader for registry
    """
    # build the data loader registered under the given name
    config = load_yaml(f"config/test/{data_type}_{format}.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    expected = REGISTRY.get(category=DATA_LOADER_CLASS, key=data_type)
    assert isinstance(got, expected)  # type: ignore
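# The data_type / format parameters above are presumably supplied by pytest
# parametrization; a plausible sketch of the wiring (the exact combinations
# are an assumption, inferred from the config files referenced in this
# module, not a copy of the real decorators):
import pytest

@pytest.mark.parametrize("data_type", ["paired", "unpaired", "grouped"])
@pytest.mark.parametrize("format", ["nifti", "h5"])
def test_data_loader_combinations(data_type: str, format: str):
    config = load_yaml(f"config/test/{data_type}_{format}.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    expected = REGISTRY.get(category=DATA_LOADER_CLASS, key=data_type)
    assert isinstance(got, expected)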
def main(args=None):
    """Entry point for gen_tfrecord."""
    parser = argparse.ArgumentParser(description="gen_tfrecord")

    # add arguments
    parser.add_argument(
        "--config_path", "-c", help="Path of config", type=str, required=True
    )
    parser.add_argument(
        "--examples_per_tfrecord",
        "-n",
        help="Number of examples per tfrecord",
        type=int,
        default=64,
    )
    args = parser.parse_args(args)

    config = config_parser.load_configs(args.config_path)
    dataset_config = config["dataset"]
    tfrecord_dir = dataset_config["tfrecord_dir"]
    dataset_config["tfrecord_dir"] = ""

    if os.path.exists(tfrecord_dir) and os.path.isdir(tfrecord_dir):
        remove = input("%s exists. Remove it or not? Y/N\n" % tfrecord_dir)
        if remove.lower() == "y":
            shutil.rmtree(tfrecord_dir)

    for mode in ["train", "valid", "test"]:
        data_loader = load.get_data_loader(dataset_config, mode)
        if data_loader is None:
            # no data dir is configured for this mode, skip it
            continue
        write_tfrecords(
            data_dir=os.path.join(tfrecord_dir, mode),
            data_generator=data_loader.data_generator(),
            examples_per_tfrecord=args.examples_per_tfrecord,
        )
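# A minimal usage sketch: main() accepts an argv-style list, so the entry
# point can be exercised directly; the config path is a placeholder and the
# flags match the argparse definitions above:
main(["--config_path", "config/my_config.yaml", "--examples_per_tfrecord", "64"])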
def test_get_data_loader():
    """
    Test for get_data_loader to make sure it gets the correct data loader
    and raises the correct errors.
    """
    # single paired data loader
    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, PairedDataLoader)

    config = load_yaml("deepreg/config/test/paired_h5.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, PairedDataLoader)

    # single unpaired data loader
    config = load_yaml("deepreg/config/test/unpaired_nifti.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, UnpairedDataLoader)

    config = load_yaml("deepreg/config/test/unpaired_h5.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, UnpairedDataLoader)

    # single grouped data loader
    config = load_yaml("deepreg/config/test/grouped_nifti.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, GroupedDataLoader)

    config = load_yaml("deepreg/config/test/grouped_h5.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, GroupedDataLoader)

    # empty data loader
    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    config["dataset"]["dir"]["train"] = ""
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert got is None

    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    config["dataset"]["dir"]["train"] = None
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert got is None

    # unpaired data loader with multiple dirs
    config = load_yaml("deepreg/config/test/unpaired_nifti_multi_dirs.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, UnpairedDataLoader)

    # check the not-a-directory error
    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    config["dataset"]["dir"]["train"] += ".h5"
    with pytest.raises(ValueError) as err_info:
        load.get_data_loader(data_config=config["dataset"], mode="train")
    assert "is not a directory or does not exist" in str(err_info.value)

    # check the directory-does-not-exist error
    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    config["dataset"]["dir"]["train"] = "/this_should_not_existed"
    with pytest.raises(ValueError) as err_info:
        load.get_data_loader(data_config=config["dataset"], mode="train")
    assert "is not a directory or does not exist" in str(err_info.value)

    # check the invalid mode error
    config = load_yaml("deepreg/config/test/paired_nifti.yaml")
    with pytest.raises(AssertionError) as err_info:
        load.get_data_loader(data_config=config["dataset"], mode="example")
    assert "mode must be one of train/valid/test" in str(err_info.value)
def predict(
    gpu,
    gpu_allow_growth,
    ckpt_path,
    mode,
    batch_size,
    log_dir,
    sample_label,
    config_path,
):
    """
    Function to predict some metrics from the saved model and log the results.

    :param gpu: str, which env gpu to use
    :param gpu_allow_growth: bool, whether to allow gpu growth or not
    :param ckpt_path: str, where the model is stored,
        should be like log_folder/save/xxx.ckpt
    :param mode: str, which data split to load, train/valid/test
    :param batch_size: int, batch size to perform predictions in
    :param log_dir: str, path to store logs
    :param sample_label: method for sampling labels (currently unused, see TODO below)
    :param config_path: to overwrite the default config
    """
    logging.error("TODO sample_label is not used in predict")

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir = init(log_dir, ckpt_path, config_path)
    dataset_config = config["dataset"]
    preprocess_config = config["train"]["preprocess"]
    preprocess_config["batch_size"] = batch_size
    optimizer_config = config["train"]["optimizer"]
    model_config = config["train"]["model"]
    loss_config = config["train"]["loss"]

    # data
    data_loader = load.get_data_loader(dataset_config, mode)
    if data_loader is None:
        raise ValueError(
            "Data loader for prediction is None. "
            "Probably the data dir path is not defined."
        )
    dataset = data_loader.get_dataset_and_preprocess(
        training=False, repeat=False, **preprocess_config
    )

    # optimizer
    optimizer = opt.build_optimizer(optimizer_config)

    # model
    model = build_model(
        moving_image_size=data_loader.moving_image_shape,
        fixed_image_size=data_loader.fixed_image_shape,
        index_size=data_loader.num_indices,
        labeled=dataset_config["labeled"],
        batch_size=preprocess_config["batch_size"],
        model_config=model_config,
        loss_config=loss_config,
    )

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-objec
    model.load_weights(ckpt_path).expect_partial()

    # predict
    fixed_grid_ref = layer_util.get_reference_grid(
        grid_size=data_loader.fixed_image_shape
    )
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=log_dir + "/test",
    )
    data_loader.close()
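# A minimal usage sketch for predict(); the checkpoint path, log dir, and
# sample_label value are placeholders, not outputs of a real run:
predict(
    gpu="0",
    gpu_allow_growth=False,
    ckpt_path="logs/my_run/save/weights-epoch100.ckpt",
    mode="test",
    batch_size=1,
    log_dir="logs/my_run",
    sample_label="all",
    config_path="",
)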
def test_mode_err(self):
    """Check the error is raised when the mode is wrong."""
    config = load_yaml("config/test/paired_nifti.yaml")
    with pytest.raises(AssertionError) as err_info:
        load.get_data_loader(data_config=config["dataset"], mode="example")
    assert "mode must be one of train/valid/test" in str(err_info.value)
def test_multi_dir_data_loader(self):
    """Unpaired data loader with multiple dirs."""
    config = load_yaml("config/test/unpaired_nifti_multi_dirs.yaml")
    got = load.get_data_loader(data_config=config["dataset"], mode="train")
    assert isinstance(got, UnpairedDataLoader)
def train(
    gpu: str,
    config_path: list,
    gpu_allow_growth: bool,
    ckpt_path: str,
    log_dir: str,
):
    """
    Function to train a model.

    :param gpu: str, which local gpu to use to train
    :param config_path: list, path(s) of the configuration file(s)
    :param gpu_allow_growth: bool, whether to allow GPU memory growth
        instead of allocating the whole GPU memory up front
    :param ckpt_path: str, where to store training ckpts
    :param log_dir: str, where to store logs in training
    """
    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir = init(config_path, log_dir, ckpt_path)
    dataset_config = config["dataset"]
    preprocess_config = config["train"]["preprocess"]
    optimizer_config = config["train"]["optimizer"]
    model_config = config["train"]["model"]
    loss_config = config["train"]["loss"]
    num_epochs = config["train"]["epochs"]
    save_period = config["train"]["save_period"]
    histogram_freq = save_period

    # data
    data_loader_train = get_data_loader(dataset_config, "train")
    if data_loader_train is None:
        raise ValueError(
            "Training data loader is None. Probably the data dir path is not defined."
        )
    data_loader_val = get_data_loader(dataset_config, "valid")

    dataset_train = data_loader_train.get_dataset_and_preprocess(
        training=True, repeat=True, **preprocess_config
    )
    dataset_val = (
        data_loader_val.get_dataset_and_preprocess(
            training=False, repeat=True, **preprocess_config
        )
        if data_loader_val is not None
        else None
    )
    dataset_size_train = data_loader_train.num_samples
    dataset_size_val = (
        data_loader_val.num_samples if data_loader_val is not None else None
    )
    steps_per_epoch_train = max(
        dataset_size_train // preprocess_config["batch_size"], 1
    )
    steps_per_epoch_valid = (
        max(dataset_size_val // preprocess_config["batch_size"], 1)
        if data_loader_val is not None
        else None
    )

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        # model
        model = build_model(
            moving_image_size=data_loader_train.moving_image_shape,
            fixed_image_size=data_loader_train.fixed_image_shape,
            index_size=data_loader_train.num_indices,
            labeled=dataset_config["labeled"],
            batch_size=preprocess_config["batch_size"],
            model_config=model_config,
            loss_config=loss_config,
        )

        # compile
        optimizer = opt.get_optimizer(optimizer_config)
        model.compile(optimizer=optimizer)

        # load weights
        if ckpt_path != "":
            model.load_weights(ckpt_path)

        # train
        # callbacks
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=log_dir, histogram_freq=histogram_freq
        )
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=log_dir + "/save/weights-epoch{epoch:d}.ckpt",
            save_weights_only=True,
            period=save_period,
        )
        # it is necessary to define steps_per_epoch and validation_steps
        # to prevent errors like
        # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
        model.fit(
            x=dataset_train,
            steps_per_epoch=steps_per_epoch_train,
            epochs=num_epochs,
            validation_data=dataset_val,
            validation_steps=steps_per_epoch_valid,
            callbacks=[tensorboard_callback, checkpoint_callback],
        )

    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()
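# A minimal usage sketch for train(); the config and log paths are
# placeholders for illustration only, and ckpt_path="" starts training
# from scratch rather than resuming from a checkpoint:
train(
    gpu="0",
    config_path=["config/my_config.yaml"],
    gpu_allow_growth=True,
    ckpt_path="",
    log_dir="logs/my_run",
)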