Example #1
def train_step(warper, weights, optimizer, mov, fix) -> tuple:
    """
    Train step function for backprop using gradient tape

    :param warper: warping function returned from layer.Warping
    :param weights: trainable ddf [1, f_dim1, f_dim2, f_dim3, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return:
        a tuple:
            - loss: overall loss to optimise
            - loss_image: image dissimilarity
            - loss_deform: deformation regularisation
    """
    with tf.GradientTape() as tape:
        pred = warper(inputs=[weights, mov])
        loss_image = REGISTRY.build_loss(config=image_loss_config)(
            y_true=fix,
            y_pred=pred,
        )
        loss_deform = REGISTRY.build_loss(config=deform_loss_config)(
            inputs=weights,
        )
        loss = loss_image + weight_deform_loss * loss_deform
    gradients = tape.gradient(loss, [weights])
    optimizer.apply_gradients(zip(gradients, [weights]))
    return loss, loss_image, loss_deform
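
A minimal driver for the step above, as a sketch: the loss registry keys ("lncc", "bending"), the deformation weight, the image size, and the import paths are assumptions, and `image_loss_config`, `deform_loss_config`, and `weight_deform_loss` are the module-level globals the function closes over.

import tensorflow as tf
from deepreg.model import layer
from deepreg.registry import REGISTRY

# assumed setup: tiny volumes and hypothetical loss configs
fixed_image_size = (16, 16, 16)
image_loss_config = {"name": "lncc"}      # assumed registry key
deform_loss_config = {"name": "bending"}  # assumed registry key
weight_deform_loss = 0.5

warper = layer.Warping(fixed_image_size=fixed_image_size)
weights = tf.Variable(tf.zeros((1, *fixed_image_size, 3)))  # trainable ddf
optimizer = tf.optimizers.Adam(learning_rate=0.1)
mov = tf.random.uniform((1, *fixed_image_size))
fix = tf.random.uniform((1, *fixed_image_size))

for step in range(10):
    loss, loss_image, loss_deform = train_step(warper, weights, optimizer, mov, fix)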
Example #2
def get_data_loader(data_config: dict, split: str) -> Optional[DataLoader]:
    """
    Return the corresponding data loader.

    Can't be placed in the same file as the loader interfaces, as that would cause an import cycle.

    :param data_config: a dictionary containing configuration for data
    :param split: must be train/valid/test
    :return: DataLoader or None, returns None if the split or dir is empty.
    """
    if split not in KNOWN_DATA_SPLITS:
        raise ValueError(
            f"split must be one of {KNOWN_DATA_SPLITS}, got {split}")

    if split not in data_config:
        return None
    data_dir_paths = data_config[split].get("dir", None)
    if data_dir_paths is None or data_dir_paths == "":
        return None

    if isinstance(data_dir_paths, str):
        data_dir_paths = [data_dir_paths]
    # replace ~ with user home path
    data_dir_paths = list(map(os.path.expanduser, data_dir_paths))
    for data_dir_path in data_dir_paths:
        if not os.path.isdir(data_dir_path):
            raise ValueError(
                f"Data directory path {data_dir_path} for split {split}"
                f" is not a directory or does not exist")

    # prepare data loader config
    data_loader_config = deepcopy(data_config)
    data_loader_config = {
        k: v
        for k, v in data_loader_config.items() if k not in KNOWN_DATA_SPLITS
    }
    data_loader_config["name"] = data_loader_config.pop("type")

    default_args = dict(
        data_dir_paths=data_dir_paths,
        file_loader=REGISTRY.get(category=FILE_LOADER_CLASS,
                                 key=data_config[split]["format"]),
        labeled=data_config[split]["labeled"],
        sample_label="sample" if split == "train" else "all",
        seed=None if split == "train" else 0,
    )
    data_loader: DataLoader = REGISTRY.build_data_loader(
        config=data_loader_config, default_args=default_args)
    return data_loader
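
For reference, a hypothetical `data_config` that satisfies the accesses made above; the "paired" type and "nifti" format are assumed registry keys, and the directory paths are placeholders.

data_config = {
    "type": "paired",  # popped and passed on as data_loader_config["name"]
    "train": {
        "dir": "~/data/train",  # str or list of str; "~" is expanded
        "format": "nifti",      # assumed file-loader key
        "labeled": True,
    },
    "valid": {"dir": "~/data/valid", "format": "nifti", "labeled": True},
    # no "test" entry -> get_data_loader(..., split="test") returns None
}

loader = get_data_loader(data_config=data_config, split="train")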
Example #3
    def get_dataset_and_preprocess(
        self,
        training: bool,
        batch_size: int,
        repeat: bool,
        shuffle_buffer_num_batch: int,
        data_augmentation: Optional[Union[List, Dict]] = None,
    ) -> tf.data.Dataset:
        """
        :param training: bool, indicating if it's training or not
        :param batch_size: int, size of mini batch
        :param repeat: bool, indicating if we need to repeat the dataset
        :param shuffle_buffer_num_batch: int, when shuffling,
            the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
        :param data_augmentation: augmentation config, can be a list of dict or dict.
        :return: dataset
        """

        dataset = self.get_dataset()

        # resize
        dataset = dataset.map(
            lambda x: resize_inputs(
                inputs=x,
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
            ),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )

        # shuffle / repeat / batch / preprocess
        if training and shuffle_buffer_num_batch > 0:
            dataset = dataset.shuffle(
                buffer_size=batch_size * shuffle_buffer_num_batch,
                reshuffle_each_iteration=True,
            )
        if repeat:
            dataset = dataset.repeat()

        dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        if training and data_augmentation is not None:
            if isinstance(data_augmentation, dict):
                data_augmentation = [data_augmentation]
            for config in data_augmentation:
                da_fn = REGISTRY.build_data_augmentation(
                    config=config,
                    default_args={
                        "moving_image_size": self.moving_image_shape,
                        "fixed_image_size": self.fixed_image_shape,
                        "batch_size": batch_size,
                    },
                )
                dataset = dataset.map(
                    da_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
                )

        return dataset
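
A hypothetical call, assuming `data_loader` is an instance of the class this method belongs to and that "affine" is a registered data-augmentation key:

dataset = data_loader.get_dataset_and_preprocess(
    training=True,
    batch_size=2,
    repeat=True,
    shuffle_buffer_num_batch=4,            # shuffle buffer holds 2 * 4 samples
    data_augmentation={"name": "affine"},  # assumed registry key
)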
Example #4
    def test_image_loss(self, config: dict, option: int, expected: int):
        method = "ddf"
        backbone = "local"
        labeled = True
        copied = deepcopy(config)
        copied["method"] = method
        copied["backbone"]["name"] = backbone
        copied["backbone"] = {
            **backbone_args[backbone],  # type: ignore
            **copied["backbone"],
        }

        if option == 0:
            # remove image loss config, so loss is not used
            copied["loss"].pop("image")
        elif option == 1:
            # set image loss weight to zero, so loss is not used
            copied["loss"]["image"]["weight"] = 0.0
        elif option == 2:
            # remove image loss weight, so loss is used with default weight 1
            copied["loss"]["image"].pop("weight")

        ddf_model = REGISTRY.build_model(config=dict(
            name=method,  # TODO we store method twice
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            index_size=index_size,
            labeled=labeled,
            batch_size=batch_size,
            config=copied,
        ))

        assert len(ddf_model._model.losses) == expected  # type: ignore
Example #5
def model(method: str, labeled: bool, backbone: str) -> RegistrationModel:
    """
    A specific registration model object.

    :param method: name of method
    :param labeled: whether the data is labeled
    :param backbone: name of backbone
    :return: the built object
    """
    copied = deepcopy(config)
    copied["method"] = method
    copied["backbone"]["name"] = backbone  # type: ignore
    if method == "conditional":
        copied["backbone"].pop("control_points", None)  # type: ignore
    copied["backbone"].update(backbone_args[backbone])  # type: ignore
    return REGISTRY.build_model(  # type: ignore
        config=dict(
            name=method,  # TODO we store method twice
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            index_size=index_size,
            labeled=labeled,
            batch_size=batch_size,
            config=copied,
        ))
Example #6
    def build_model(self):
        """Build the model to be saved as self._model."""
        assert self.labeled

        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        moving_label = self._inputs["moving_label"]

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image,
                                             moving_label)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=1,
                out_kernel_initializer="glorot_uniform",
                out_activation="sigmoid",
            ),
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = backbone(inputs=backbone_inputs)
        pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4)

        self._outputs = dict(pred_fixed_label=pred_fixed_label)
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
Example #7
    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add one weighted loss together with the metrics.

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """

        if name not in self.config["loss"]:
            # loss config is not defined
            logger.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used.")
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1
                logger.warning(f"The weight for loss {name} is not defined. "
                               f"Default weight = 1.0 is used.")
                loss_config["weight"] = 1.0

            # build loss
            weight = loss_config["weight"]

            if weight == 0:
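                # note: this return exits _build_loss entirely, so any further
                # configs in loss_configs are skipped as well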
                logger.warning(f"The weight for loss {name} is zero. "
                               f"Loss is not used.")
                return

            # do not perform reduction over batch axis for supporting multi-device
            # training, model.fit() will average over global batch size automatically
            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight"),
                default_args={"reduction": tf.keras.losses.Reduction.NONE},
            )
            loss_value = loss_layer(**inputs_dict)
            weighted_loss = loss_value * weight

            # add loss
            self._model.add_loss(weighted_loss)

            # add metric
            self._model.add_metric(loss_value,
                                   name=f"loss/{name}_{loss_layer.name}",
                                   aggregation="mean")
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )
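
A hypothetical `config["loss"]` fragment that exercises the branches above; the loss names are assumed registry keys.

loss_config = {
    "image": {"name": "lncc", "weight": 1.0},
    "label": [  # a list attaches several weighted losses under one name
        {"name": "dice", "weight": 1.0},
        {"name": "cross-entropy", "weight": 0.2},
    ],
    # "regularization" omitted -> a warning is logged and the loss is skipped
}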
Example #8
def get_data_loader(data_config: dict, mode: str) -> Optional[DataLoader]:
    """
    Return the corresponding data loader.
    Can't be placed in the same file as the loader interfaces, as that would cause an import cycle.
    :param data_config: a dictionary containing configuration for data
    :param mode: string, must be train/valid/test
    :return: DataLoader or None, returns None if the data_dir_paths is empty
    """
    assert mode in ["train", "valid",
                    "test"], "mode must be one of train/valid/test"

    data_dir_paths = data_config["dir"].get(mode, None)
    if data_dir_paths is None or data_dir_paths == "":
        return None
    if isinstance(data_dir_paths, str):
        data_dir_paths = [data_dir_paths]
    # replace ~ with user home path
    data_dir_paths = list(map(os.path.expanduser, data_dir_paths))
    for data_dir_path in data_dir_paths:
        if not os.path.isdir(data_dir_path):
            raise ValueError(
                f"Data directory path {data_dir_path} for mode {mode}"
                f" is not a directory or does not exist")

    # prepare data loader config
    data_loader_config = deepcopy(data_config)
    data_loader_config.pop("dir")
    data_loader_config.pop("format")
    data_loader_config["name"] = data_loader_config.pop("type")

    default_args = dict(
        data_dir_paths=data_dir_paths,
        file_loader=REGISTRY.get(category=FILE_LOADER_CLASS,
                                 key=data_config["format"]),
        labeled=data_config["labeled"],
        sample_label="sample" if mode == "train" else "all",
        seed=None if mode == "train" else 0,
    )
    data_loader = REGISTRY.build_data_loader(config=data_loader_config,
                                             default_args=default_args)
    return data_loader
Example #9
    def test_data_loader(self, data_type: str, format: str):
        """
        Test the data loader can be successfully built.

        :param data_type: name of data loader for registry
        :param format: name of file loader for registry
        """
        # single paired data loader
        config = load_yaml(f"config/test/{data_type}_{format}.yaml")
        got = load.get_data_loader(data_config=config["dataset"], mode="train")
        expected = REGISTRY.get(category=DATA_LOADER_CLASS, key=data_type)
        assert isinstance(got, expected)  # type: ignore
Example #10
    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add one weighted loss together with the metrics.

        :param name: name of loss
        :param inputs_dict: inputs for loss function
        """

        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used.")
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1
                logging.warning(f"The weight for loss {name} is not defined. "
                                f"Default weight = 1.0 is used.")
                loss_config["weight"] = 1.0

            # build loss
            weight = loss_config["weight"]

            if weight == 0:
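                # note: this return exits _build_loss entirely, so any further
                # configs in loss_configs are skipped as well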
                logging.warning(f"The weight for loss {name} is zero. "
                                f"Loss is not used.")
                return

            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight"))
            loss_value = loss_layer(**inputs_dict) / self.global_batch_size
            weighted_loss = loss_value * weight

            # add loss
            self._model.add_loss(weighted_loss)

            # add metric
            self._model.add_metric(loss_value,
                                   name=f"loss/{name}_{loss_layer.name}",
                                   aggregation="mean")
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )
Example #11
    def build_model(self):
        """Build the model to be saved as self._model."""
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs[
            "moving_image"]  # (batch, m_dim1, m_dim2, m_dim3)
        fixed_image = self._inputs[
            "fixed_image"]  # (batch, f_dim1, f_dim2, f_dim3)

        # build ddf
        control_points = self.config["backbone"].pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # (f_dim1, f_dim2, f_dim3, 3), (4, 3)
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            ddf = (self._resize_interpolate(ddf, control_points)
                   if control_points else ddf)
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # (f_dim1, f_dim2, f_dim3)
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # (f_dim1, f_dim2, f_dim3)
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
Example #12
    def build_model(self):
        """Build the model to be saved as self._model."""
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        control_points = self.config["backbone"].pop("control_points", False)

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        dvf = self._resize_interpolate(
            dvf, control_points) if control_points else dvf
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # (f_dim1, f_dim2, f_dim3, 3)
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf,
                             ddf=ddf,
                             pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # (f_dim1, f_dim2, f_dim3, 3)
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
Example #13
def train_step(grid, weights, optimizer, mov, fix) -> object:
    """
    Train step function for backprop using gradient tape

    :param grid: reference grid return from layer_util.get_reference_grid
    :param weights: trainable affine parameters [1, 4, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return: loss, the image dissimilarity to minimise
    """
    with tf.GradientTape() as tape:
        pred = layer_util.resample(vol=mov,
                                   loc=layer_util.warp_grid(grid, weights))
        loss = REGISTRY.build_loss(config=image_loss_config)(
            y_true=fix,
            y_pred=pred,
        )
    gradients = tape.gradient(loss, [weights])
    optimizer.apply_gradients(zip(gradients, [weights]))
    return loss
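
A sketch of driving this affine variant, with an identity initialisation for the [1, 4, 3] parameters; the image size, the "ssd" loss key, and the import paths are assumptions.

import tensorflow as tf
from deepreg.model import layer_util
from deepreg.registry import REGISTRY

image_loss_config = {"name": "ssd"}  # assumed registry key
mov = tf.random.uniform((1, 16, 16, 16))
fix = tf.random.uniform((1, 16, 16, 16))
grid = layer_util.get_reference_grid(grid_size=(16, 16, 16))
# identity affine transform: 3x3 linear part stacked on a zero translation
weights = tf.Variable([[[1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0],
                        [0.0, 0.0, 0.0]]])
optimizer = tf.optimizers.Adam(learning_rate=0.01)

for step in range(10):
    loss = train_step(grid, weights, optimizer, mov, fix)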
Example #14
def predict(
    gpu: str,
    ckpt_path: str,
    split: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and log the results.

    :param gpu: which env gpu to use.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param split: train / valid / test, to define the split to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: whether to allow GPU memory growth instead of allocating all GPU memory up front.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    """

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ[
        "TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(config_path=config_path,
                                              log_dir=log_dir,
                                              exp_name=exp_name,
                                              ckpt_path=ckpt_path)
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split=split,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices.")
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(config=dict(
            name=config["train"]["method"],
            moving_image_size=data_loader.moving_image_shape,
            fixed_image_size=data_loader.fixed_image_shape,
            index_size=data_loader.num_indices,
            labeled=config["dataset"][split]["labeled"],
            batch_size=batch_size,
            config=config["train"],
        ))
        optimizer = opt.build_optimizer(
            optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0)  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
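
A hypothetical invocation; the checkpoint path, config path, and experiment name are placeholders.

predict(
    gpu="0",
    ckpt_path="logs/demo/save/ckpt-100",
    split="test",
    batch_size=1,
    exp_name="demo_predict",
    config_path="demo.yaml",
)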
Example #15
    def test_register_err(self, category, key, err_msg):
        with pytest.raises(ValueError) as err_info:
            REGISTRY.register(category=category, name=key, cls=0)
        assert err_msg in str(err_info.value)
Example #16
    def test_register(self):
        category, key, value = "backbone_class", "test_key", 0
        REGISTRY.register(category=category, name=key, cls=value)
        assert REGISTRY._dict[(category, key)] == value
        assert REGISTRY.get(category, key) == value
Example #17
    def test_get_err(self):
        with pytest.raises(ValueError) as err_info:
            REGISTRY.get("backbone_class", "wrong_key")
        assert "has not been registered" in str(err_info.value)
Example #18
    def test_get_backbone(self):
        # no error means the unet has been registered
        _ = REGISTRY.get("backbone_class", "unet")
Example #19
    def reg(self):
        return REGISTRY.copy()
Example #20
    def test_register_backbone(self):
        key = "new_backbone"
        value = 0
        REGISTRY.register_backbone(name=key, cls=value)
        got = REGISTRY.get_backbone(key)
        assert got == value
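
A minimal sketch of extending the registry with a custom class, mirroring the calls the tests above rely on; `MyBackbone` is a hypothetical stand-in (a real backbone would subclass the library's backbone base class).

class MyBackbone:
    """Hypothetical placeholder for a custom backbone."""

REGISTRY.register_backbone(name="my_backbone", cls=MyBackbone)
assert REGISTRY.get_backbone("my_backbone") is MyBackbone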
Example #21
def predict(
    gpu: str,
    gpu_allow_growth: bool,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and log the results.

    :param gpu: which env gpu to use.
    :param gpu_allow_growth: whether to allow gpu growth or not
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x
    :param mode: train / valid / test, to define which split of the dataset is evaluated
    :param batch_size: int, batch size to perform predictions with
    :param exp_name: name of the experiment
    :param config_path: to overwrite the default config
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :param log_dir: path of the log directory
    """
    # TODO support custom sample_label
    logging.warning("sample_label is not used in predict. "
                    "It is 'sample' if and only if mode == 'train', "
                    "otherwise 'all'.")

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ[
        "TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir, ckpt_path = build_config(config_path=config_path,
                                              log_dir=log_dir,
                                              exp_name=exp_name,
                                              ckpt_path=ckpt_path)
    preprocess_config = config["train"]["preprocess"]
    # batch_size corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    preprocess_config["batch_size"] = batch_size * max(len(gpus), 1)

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=preprocess_config,
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # optimizer
    optimizer = opt.build_optimizer(
        optimizer_config=config["train"]["optimizer"])

    # model
    model: tf.keras.Model = REGISTRY.build_model(config=dict(
        name=config["train"]["method"],
        moving_image_size=data_loader.moving_image_shape,
        fixed_image_size=data_loader.fixed_image_shape,
        index_size=data_loader.num_indices,
        labeled=config["dataset"]["labeled"],
        batch_size=config["train"]["preprocess"]["batch_size"],
        config=config["train"],
    ))

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0)  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["method"],
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
Example #22
def train(
    gpu: str,
    config_path: Union[str, List[str]],
    gpu_allow_growth: bool,
    ckpt_path: str,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Function to train a model.

    :param gpu: which local gpu to use to train.
    :param config_path: path to configuration set up.
    :param gpu_allow_growth: whether to allow GPU memory growth instead of allocating all GPU memory up front.
    :param ckpt_path: where to store training checkpoints.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    """
    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:
        strategy = tf.distribute.MirroredStrategy()  # pragma: no cover
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader_train.moving_image_shape,
                fixed_image_size=data_loader_train.fixed_image_shape,
                index_size=data_loader_train.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=config["train"]["preprocess"]["batch_size"],
                config=config["train"],
                num_devices=num_devices,
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])

    # compile
    model.compile(optimizer=optimizer)
    model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir, histogram_freq=config["train"]["save_period"]
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # it's necessary to define the steps_per_epoch
    # and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()
Example #23
def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Function to train a model.

    :param gpu: which local gpu to use to train.
    :param config_path: path to configuration set up.
    :param ckpt_path: where to store training checkpoints.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: whether to allow GPU memory growth instead of allocating all GPU memory up front.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    """
    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ[
        "TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the training. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    batch_size = config["train"]["preprocess"]["batch_size"]
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices.")
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(config=dict(
            name=config["train"]["method"],
            moving_image_size=data_loader_train.moving_image_shape,
            fixed_image_size=data_loader_train.fixed_image_shape,
            index_size=data_loader_train.num_indices,
            labeled=config["dataset"]["train"]["labeled"],
            batch_size=batch_size,
            config=config["train"],
        ))
        optimizer = opt.build_optimizer(
            optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],
        update_freq=config["train"].get("update_freq", "epoch"),
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # it's necessary to define the steps_per_epoch
    # and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()
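
A hypothetical invocation; the config path and experiment name are placeholders, and an empty `ckpt_path` is assumed to mean training from scratch.

train(
    gpu="0",
    config_path="demo.yaml",
    ckpt_path="",  # assumed: empty string starts training from scratch
    exp_name="demo_train",
    max_epochs=2,
)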
Example #24
    def get_dataset_and_preprocess(
        self,
        training: bool,
        batch_size: int,
        repeat: bool,
        shuffle_buffer_num_batch: int,
        data_augmentation: Optional[Union[List, Dict]] = None,
        num_parallel_calls: int = tf.data.experimental.AUTOTUNE,
    ) -> tf.data.Dataset:
        """
        Generate a tf.data.Dataset.

        Reference:

            - https://www.tensorflow.org/guide/data_performance#parallelizing_data_transformation
            - https://www.tensorflow.org/api_docs/python/tf/data/Dataset

        :param training: indicating if it's training or not
        :param batch_size: size of mini batch
        :param repeat: indicating if we need to repeat the dataset
        :param shuffle_buffer_num_batch: when shuffling,
            the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
        :param data_augmentation: augmentation config, can be a list of dict or dict.
        :param num_parallel_calls: number of elements to process asynchronously and
            in parallel during preprocessing; heuristically it should be set to the
            number of available CPU cores. AUTOTUNE (-1) lets TensorFlow choose the
            value dynamically.
        :return: dataset
        """

        dataset = self.get_dataset()

        # resize
        dataset = dataset.map(
            lambda x: resize_inputs(
                inputs=x,
                moving_image_size=self.moving_image_shape,
                fixed_image_size=self.fixed_image_shape,
            ),
            num_parallel_calls=num_parallel_calls,
        )

        # shuffle / repeat / batch / preprocess
        if training and shuffle_buffer_num_batch > 0:
            dataset = dataset.shuffle(
                buffer_size=batch_size * shuffle_buffer_num_batch,
                reshuffle_each_iteration=True,
            )
        if repeat:
            dataset = dataset.repeat()

        dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        if training and data_augmentation is not None:
            if isinstance(data_augmentation, dict):
                data_augmentation = [data_augmentation]
            for config in data_augmentation:
                da_fn = REGISTRY.build_data_augmentation(
                    config=config,
                    default_args={
                        "moving_image_size": self.moving_image_shape,
                        "fixed_image_size": self.fixed_image_shape,
                        "batch_size": batch_size,
                    },
                )
                dataset = dataset.map(da_fn,
                                      num_parallel_calls=num_parallel_calls)

        return dataset