示例#1
0
 def __init__(self, moving_image_size, fixed_image_size, batch_size, scale=0.1):
     """
     Store the transformation settings and precompute reference grids.

     :param moving_image_size: grid size of the moving image, passed to
         ``layer_util.get_reference_grid``
     :param fixed_image_size: grid size of the fixed image, passed to
         ``layer_util.get_reference_grid``
     :param batch_size: size of mini-batch, stored as ``self._batch_size``
     :param scale: scale hyper-parameter, stored as ``self._scale``
     """
     self._batch_size, self._scale = batch_size, scale
     # reference grids are computed once here, moving image first
     moving_ref = layer_util.get_reference_grid(grid_size=moving_image_size)
     self._moving_grid_ref = moving_ref
     fixed_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)
     self._fixed_grid_ref = fixed_ref
示例#2
0
 def test_get_reference_grid(self):
     """Compare get_reference_grid([1, 2, 3]) against a hand-written grid."""
     expected_rows = [
         [[0, 0, 0], [0, 0, 1], [0, 0, 2]],
         [[0, 1, 0], [0, 1, 1], [0, 1, 2]],
     ]
     # wrap the rows once more so the expected grid has shape (1, 2, 3, 3)
     want = tf.constant(np.array([expected_rows], dtype=np.float32))
     got = layer_util.get_reference_grid(grid_size=[1, 2, 3])
     self.check_equal(want, got)
示例#3
0
    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "GlobalNet",
        **kwargs,
    ):
        """
        Build a backbone that predicts a global affine transformation.

        The image is encoded gradually from level 0 to E, where E is the
        largest entry of ``extract_levels``. A densely-connected layer with
        12 units then outputs the affine transformation; its bias is
        initialised from ``np.eye(4, 3)`` flattened.

        :param image_size: tuple, such as (dim1, dim2, dim3)
        :param out_channels: int, number of channels for the output, must be 3
        :param num_channel_initial: int, number of initial channels
        :param extract_levels: list, which levels from net to extract
        :param out_kernel_initializer: not used
        :param out_activation: not used
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # only 3 output channels are supported
        assert out_channels == 3
        self._extract_levels = extract_levels
        top_level = max(extract_levels)  # E
        self._extract_max_level = top_level
        self.reference_grid = layer_util.get_reference_grid(image_size)
        # bias initial value: the 4x3 matrix np.eye(4, 3), flattened
        initial_bias = list(np.eye(4, 3).ravel())
        self.transform_initial = tf.constant_initializer(value=initial_bias)

        # channels double at every level, level 0 to E
        num_channels = [num_channel_initial * 2**lvl for lvl in range(top_level + 1)]
        # one down-sampling block per level 0 to E-1; level 0 uses a 7-wide kernel
        self._downsample_blocks = [
            layer.DownSampleResnetBlock(filters=filters, kernel_size=3 if idx else 7)
            for idx, filters in enumerate(num_channels[:-1])
        ]
        self._conv3d_block = layer.Conv3dBlock(filters=num_channels[-1])  # level E
        self._dense_layer = layer.Dense(units=12, bias_initializer=self.transform_initial)
示例#4
0
 def __init__(self, fixed_image_size, **kwargs):
     """
     Create the warping layer and cache a batched reference grid.

     :param fixed_image_size: shape = [f_dim1, f_dim2, f_dim3]
                              or [f_dim1, f_dim2, f_dim3, ch] with the last channel for features
     :param kwargs: additional arguments forwarded to the parent layer
     """
     super().__init__(**kwargs)
     reference_grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
     # prepend a batch axis: shape = (1, f_dim1, f_dim2, f_dim3, 3)
     self.grid_ref = tf.expand_dims(reference_grid, axis=0)
示例#5
0
def test_get_reference_grid():
    """
    Check that get_reference_grid(grid_size=[1, 2, 3]) matches a
    hand-written grid within is_equal_tf's tolerance.
    """
    rows = [
        [[0, 0, 0], [0, 0, 1], [0, 0, 2]],
        [[0, 1, 0], [0, 1, 1], [0, 1, 2]],
    ]
    # one extra level of nesting gives the expected shape (1, 2, 3, 3)
    want = tf.constant(np.array([rows], dtype=np.float32))
    get = layer_util.get_reference_grid(grid_size=[1, 2, 3])
    assert is_equal_tf(want, get)
示例#6
0
    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        batch_size: int,
        name: str = "RandomTransformation3D",
        trainable: bool = False,
    ):
        """
        Initialise the abstract base of random image transformations.

        Reference grids for both the moving and the fixed image size are
        computed once at construction.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: size of mini-batch
        :param name: name of layer
        :param trainable: if this layer is trainable
        """
        super().__init__(trainable=trainable, name=name)
        # sizes and batch settings
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.batch_size = batch_size
        # precomputed reference grids, moving first then fixed
        moving_grid = get_reference_grid(grid_size=moving_image_size)
        self.moving_grid_ref = moving_grid
        fixed_grid = get_reference_grid(grid_size=fixed_image_size)
        self.fixed_grid_ref = fixed_grid
示例#7
0
    def __init__(
        self,
        moving_image_size: Tuple[int, ...],
        fixed_image_size: Tuple[int, ...],
        batch_size: int,
        name: str = "RandomTransformation3D",
        trainable: bool = False,
    ):
        """
        Set up the shared state of random image transformations.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param batch_size: total number of samples consumed per step, over all devices.
        :param name: name of layer
        :param trainable: if this layer is trainable
        """
        super().__init__(trainable=trainable, name=name)
        self.moving_image_size, self.fixed_image_size = moving_image_size, fixed_image_size
        self.batch_size = batch_size
        # one reference grid per image size, built once and reused
        self.moving_grid_ref = get_reference_grid(grid_size=self.moving_image_size)
        self.fixed_grid_ref = get_reference_grid(grid_size=self.fixed_image_size)
示例#8
0
    def __init__(
        self,
        image_size,
        out_channels,
        num_channel_initial,
        extract_levels,
        out_kernel_initializer,
        out_activation,
        **kwargs,
    ):
        """
        Image is encoded gradually, from level 0 to E = max(extract_levels).
        Then, a densely-connected layer with 12 units outputs an affine
        transformation, which is reshaped to a (4, 3) matrix.

        :param image_size: tuple, such as (dim1, dim2, dim3), used to build
            the reference grid
        :param out_channels: int, number of channels for the output
            (not used in this constructor)
        :param num_channel_initial: int, number of initial channels
        :param extract_levels: list, which levels from net to extract
        :param out_kernel_initializer: str, which kernel to use as initialiser
            (not used in this constructor)
        :param out_activation: str, activation at last layer
            (not used in this constructor)
        :param kwargs: additional arguments passed to the parent class
        """
        super(GlobalNet, self).__init__(**kwargs)

        # save parameters
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self.reference_grid = layer_util.get_reference_grid(image_size)
        # The dense output is reshaped to (4, 3) below, so the bias must be
        # np.eye(4, 3) flattened row by row: ones at indices 0, 4 and 8.
        # (assumes rows 0-2 are the linear part and row 3 the translation —
        # TODO confirm against the warp/affine helper that consumes it)
        self.transform_initial = tf.constant_initializer(
            value=[
                1.0, 0.0, 0.0,
                0.0, 1.0, 0.0,
                0.0, 0.0, 1.0,
                0.0, 0.0, 0.0,
            ]
        )

        # init layer variables
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        self._downsample_blocks = [
            layer.DownSampleResnetBlock(
                filters=num_channels[i], kernel_size=7 if i == 0 else 3
            )
            for i in range(self._extract_max_level)
        ]  # level 0 to E-1
        self._conv3d_block = layer.Conv3dBlock(filters=num_channels[-1])  # level E
        self._dense_layer = layer.Dense(
            units=12, bias_initializer=self.transform_initial
        )
        self._reshape = tf.keras.layers.Reshape(target_shape=(4, 3))
示例#9
0
    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Create the warping layer and cache its batched reference grid.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
        # add a leading batch axis: shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = grid[None, ...]
示例#10
0
    def __init__(self, image_size: tuple, name: str = "AffineHead"):
        """
        Create a head that flattens its input and predicts 12 affine parameters.

        :param image_size: such as (dim1, dim2, dim3)
        :param name: name of the layer
        """
        super().__init__(name=name)
        self.reference_grid = layer_util.get_reference_grid(image_size)
        # dense bias starts at the 4x3 matrix np.eye(4, 3), flattened
        initial_values = list(np.eye(4, 3).ravel())
        self.transform_initial = tf.constant_initializer(value=initial_values)
        self._flatten = tfkl.Flatten()
        self._dense = tfkl.Dense(units=12, bias_initializer=self.transform_initial)
示例#11
0
    def __init__(self, fixed_image_size: tuple, **kwargs):
        """
        A layer warps an image using DDF.

        Reference:

        - transform of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

          where vol = image, loc_shift = ddf

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param kwargs: additional arguments forwarded to the parent layer.
        """
        super().__init__(**kwargs)
        unbatched_grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
        # batch axis in front: shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = tf.expand_dims(unbatched_grid, axis=0)
示例#12
0
    def __init__(
        self,
        moving_image_size: Tuple,
        fixed_image_size: Tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        name: str = "RegistrationModel",
    ):
        """
        Store the registration settings, then build the inner model and losses.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: total number of samples consumed per step, over all devices.
            When using multiple devices, TensorFlow automatically split the tensors.
            Therefore, input shapes should be defined over batch_size.
        :param config: config for method, backbone, and loss.
        :param name: name of the model
        """
        super().__init__(name=name)
        # data description
        self.moving_image_size, self.fixed_image_size = moving_image_size, fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        # method / training description
        self.config = config
        self.batch_size = batch_size

        # dicts holding the inputs / outputs of self._model
        self._inputs = None
        self._outputs = None

        # reference grid over the fixed image with a leading batch axis of 1
        fixed_grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
        self.grid_ref = fixed_grid[None, ...]
        # the model is built after all attributes above are set
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()
示例#13
0
def predict(
    gpu: str,
    gpu_allow_growth: bool,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    log_dir: str,
    sample_label: str,
    config_path: (str, list),
    save_nifti: bool = True,
    save_png: bool = True,
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: str, which env gpu to use (assigned to CUDA_VISIBLE_DEVICES).
    :param gpu_allow_growth: bool, whether to let TensorFlow grow GPU memory on
        demand instead of reserving all of it up front.
    :param ckpt_path: str, where model is stored, should be like log_folder/save/xxx.ckpt
    :param mode: train / valid / test, to define which split of dataset to be evaluated
    :param batch_size: int, batch size to perform predictions in
    :param log_dir: str, path to store logs
    :param sample_label: sample/all, not used
    :param config_path: to overwrite the default config
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # TODO support custom sample_label
    logging.warning(
        "sample_label is not used in predict. It is True if and only if mode == 'train'."
    )

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # the env var must mirror the flag: "true" enables on-demand memory growth
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir = build_config(
        config_path=config_path, log_dir=log_dir, ckpt_path=ckpt_path
    )
    preprocess_config = config["train"]["preprocess"]
    preprocess_config["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=preprocess_config,
        mode=mode,
        training=False,
        repeat=False,
    )

    # optimizer
    optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])

    # model
    model = build_model(
        moving_image_size=data_loader.moving_image_shape,
        fixed_image_size=data_loader.fixed_image_shape,
        index_size=data_loader.num_indices,
        labeled=config["dataset"]["labeled"],
        batch_size=preprocess_config["batch_size"],
        model_config=config["train"]["model"],
        loss_config=config["train"]["loss"],
    )

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    # expect_partial() silences warnings about checkpoint values that are not restored
    # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-objec
    model.load_weights(ckpt_path).expect_partial()

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0,
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["model"]["method"],
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
示例#14
0
## load image
# raising a bare string is a TypeError in Python 3, so raise a real exception
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError("Download the data using demo_data.py script")
if not os.path.exists(FILE_PATH):
    raise FileNotFoundError("Download the data using demo_data.py script")

# NOTE(review): fid is left open; code past this section may still read it
fid = h5py.File(FILE_PATH, "r")
# add a leading batch axis and cast to float32
fixed_image = tf.cast(tf.expand_dims(fid["image"], axis=0), dtype=tf.float32)
fixed_image = (fixed_image - tf.reduce_min(fixed_image)) / (
    tf.reduce_max(fixed_image) - tf.reduce_min(fixed_image)
)  # normalisation to [0,1]

# generate a randomly-affine-transformed moving image
fixed_image_size = fixed_image.shape
transform_random = layer_util.random_transform_generator(batch_size=1, scale=0.2)
# spatial dims only: drop the batch axis (index 0)
grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size[1:4])
grid_random = layer_util.warp_grid(grid_ref, transform_random)
moving_image = layer_util.resample(vol=fixed_image, loc=grid_random)
# warp the labels to get ground-truth using the same random affine, for validation
fixed_labels = tf.cast(tf.expand_dims(fid["label"], axis=0), dtype=tf.float32)
# resample each label channel (axis 4) separately, then stack them back
moving_labels = tf.stack(
    [
        layer_util.resample(vol=fixed_labels[..., idx], loc=grid_random)
        for idx in range(fixed_labels.shape[4])
    ],
    axis=4,
)


## optimisation
@tf.function
示例#15
0
def main(gpu, gpu_allow_growth, ckpt_path, mode, batch_size, log):
    """
    Evaluate a saved checkpoint on a data split and save the predictions.

    :param gpu: str, value assigned to CUDA_VISIBLE_DEVICES
    :param gpu_allow_growth: bool, whether to let TensorFlow grow GPU memory
        on demand instead of reserving all of it up front
    :param ckpt_path: str, checkpoint path like log_folder/save/xxx.ckpt;
        the config is read from log_folder/config.yaml
    :param mode: str, passed to load.get_data_loader to select the data
    :param batch_size: int, batch size used for prediction
    :param log: str, name of the output log folder; a timestamp is used if empty
    :raises ValueError: if ckpt_path does not end with .ckpt
    """
    # sanity check
    if not ckpt_path.endswith(".ckpt"):  # should be like log_folder/save/xxx.ckpt
        raise ValueError("checkpoint path should end with .ckpt")

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # the env var must mirror the flag: "true" enables on-demand memory growth
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config, stored two levels above the checkpoint file
    config = config_parser.load(
        "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml"
    )
    data_config = config["data"]
    if data_config["name"] == "mr_us":
        data_config["sample_label"]["train"] = "all"
        data_config["sample_label"]["test"] = "all"
    tf_data_config = config["tf"]["data"]
    tf_data_config["batch_size"] = batch_size
    tf_opt_config = config["tf"]["opt"]
    tf_model_config = config["tf"]["model"]
    tf_loss_config = config["tf"]["loss"]
    log_folder_name = log if log != "" else datetime.now().strftime("%Y%m%d-%H%M%S")
    # rstrip avoids the IndexError that [-1] raised on an empty log_dir
    log_dir = config["log_dir"].rstrip("/")
    log_dir = log_dir + "/" + log_folder_name

    # data
    data_loader = load.get_data_loader(data_config, mode)
    dataset = data_loader.get_dataset_and_preprocess(
        training=False, repeat=False, **tf_data_config
    )

    # optimizer
    optimizer = opt.get_optimizer(tf_opt_config)

    # model
    model = network.build_model(
        moving_image_size=data_loader.moving_image_shape,
        fixed_image_size=data_loader.fixed_image_shape,
        index_size=data_loader.num_indices,
        batch_size=tf_data_config["batch_size"],
        tf_model_config=tf_model_config,
        tf_loss_config=tf_loss_config,
    )

    # metrics
    model.compile(
        optimizer=optimizer,
        loss=label_loss.get_similarity_fn(config=tf_loss_config["similarity"]["label"]),
        metrics=[
            metric.MeanDiceScore(),
            metric.MeanCentroidDistance(grid_size=data_loader.fixed_image_shape),
        ],
    )

    # load weights
    model.load_weights(ckpt_path)

    # predict
    fixed_grid_ref = layer_util.get_reference_grid(
        grid_size=data_loader.fixed_image_shape
    )
    predict(
        data_loader=data_loader,
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=log_dir + "/test",
    )
示例#16
0
def predict(
    gpu,
    gpu_allow_growth,
    ckpt_path,
    mode,
    batch_size,
    log_dir,
    sample_label,
    config_path,
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: str, which env gpu to use (assigned to CUDA_VISIBLE_DEVICES).
    :param gpu_allow_growth: bool, whether to let TensorFlow grow GPU memory on
        demand instead of reserving all of it up front
    :param ckpt_path: str, where model is stored, should be like
                      log_folder/save/xxx.ckpt
    :param mode: str, passed to load.get_data_loader to select which data to load
    :param batch_size: int, batch size to perform predictions in
    :param log_dir: str, path to store logs
    :param sample_label: currently not used
    :param config_path: to overwrite the default config
    :raises ValueError: if no data loader can be built for the given mode
    """
    # a TODO notice is not an error condition, so log it as a warning
    logging.warning("TODO sample_label is not used in predict")

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # the env var must mirror the flag: "true" enables on-demand memory growth
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir = init(log_dir, ckpt_path, config_path)
    dataset_config = config["dataset"]
    preprocess_config = config["train"]["preprocess"]
    preprocess_config["batch_size"] = batch_size
    optimizer_config = config["train"]["optimizer"]
    model_config = config["train"]["model"]
    loss_config = config["train"]["loss"]

    # data
    data_loader = load.get_data_loader(dataset_config, mode)
    if data_loader is None:
        raise ValueError(
            "Data loader for prediction is None. Probably the data dir path is not defined."
        )
    dataset = data_loader.get_dataset_and_preprocess(
        training=False, repeat=False, **preprocess_config
    )

    # optimizer
    optimizer = opt.build_optimizer(optimizer_config)

    # model
    model = build_model(
        moving_image_size=data_loader.moving_image_shape,
        fixed_image_size=data_loader.fixed_image_shape,
        index_size=data_loader.num_indices,
        labeled=dataset_config["labeled"],
        batch_size=preprocess_config["batch_size"],
        model_config=model_config,
        loss_config=loss_config,
    )

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    # expect_partial() silences warnings about checkpoint values that are not restored
    # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-objec
    model.load_weights(ckpt_path).expect_partial()

    # predict
    fixed_grid_ref = layer_util.get_reference_grid(
        grid_size=data_loader.fixed_image_shape
    )
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=os.path.join(log_dir, "test"),
    )

    data_loader.close()
示例#17
0
def predict(
    gpu: str,
    gpu_allow_growth: bool,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: which env gpu to use (assigned to CUDA_VISIBLE_DEVICES).
    :param gpu_allow_growth: whether to let TensorFlow grow GPU memory on demand
        instead of reserving all of it up front
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x
    :param mode: train / valid / test, to define which split of dataset to be evaluated
    :param batch_size: int, batch size per GPU to perform predictions in
    :param exp_name: name of the experiment
    :param config_path: to overwrite the default config
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :param log_dir: path of the log directory
    """
    # TODO support custom sample_label
    logging.warning("sample_label is not used in predict. "
                    "It is True if and only if mode == 'train'.")

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # the env var must mirror the flag: "true" enables on-demand memory growth
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
    )
    preprocess_config = config["train"]["preprocess"]
    # batch_size corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    preprocess_config["batch_size"] = batch_size * max(len(gpus), 1)

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=preprocess_config,
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # optimizer
    optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])

    # model
    model: tf.keras.Model = REGISTRY.build_model(
        config=dict(
            name=config["train"]["method"],
            moving_image_size=data_loader.moving_image_shape,
            fixed_image_size=data_loader.fixed_image_shape,
            index_size=data_loader.num_indices,
            labeled=config["dataset"]["labeled"],
            batch_size=config["train"]["preprocess"]["batch_size"],
            config=config["train"],
        )
    )

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0,
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["method"],
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
示例#18
0
def predict(
    gpu: str,
    ckpt_path: str,
    split: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: which env gpu to use.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param split: train / valid / test, to define the split to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: whether to let TensorFlow grow GPU memory on demand;
        if False, the whole GPU memory is reserved up front.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    :raises ValueError: if batch_size is not divisible by the number of GPUs.
    """

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # the env var must mirror the flag: "true" enables on-demand memory growth
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    # only a positive value limits CPU usage; <= 0 leaves it unlimited
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split=split,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices ({num_devices})."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader.moving_image_shape,
                fixed_image_size=data_loader.fixed_image_shape,
                index_size=data_loader.num_indices,
                labeled=config["dataset"][split]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0,
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
示例#19
0
 def __init__(self, grid_size, name="metric/centroid_distance_mean", **kwargs):
     """
     Initialise the mean-centroid-distance metric.

     :param grid_size: grid size passed to ``layer_util.get_reference_grid``;
         the resulting grid is stored as ``self.grid``
     :param name: name of the metric
     :param kwargs: additional arguments forwarded to the parent class
     """
     super().__init__(name=name, **kwargs)
     self.grid = layer_util.get_reference_grid(grid_size)