コード例 #1
0
def build_backbone(image_size, out_channels, tf_model_config):
    """
    Build the backbone network named by tf_model_config["backbone"]["name"].

    The backbone model accepts a single input of shape
    [batch, dim1, dim2, dim3, ch_in] and returns a single output of shape
    [batch, dim1, dim2, dim3, ch_out].

    :param image_size: [dim1, dim2, dim3]
    :param out_channels: ch_out
    :param tf_model_config: dict whose "backbone" entry provides "name"
        ("local" or "unet"), "out_kernel_initializer" and "out_activation"
        (an empty string means no activation); the chosen network's extra
        keyword arguments are read from tf_model_config["local"] or
        tf_model_config["unet"]. The dict is not modified.
    :return: the constructed LocalNet or UNet
    :raises ValueError: if the backbone name is neither "local" nor "unet"
    """
    backbone_config = tf_model_config["backbone"]

    # An empty string in the config means "no activation". Translate it into a
    # local variable rather than writing None back into the caller's dict.
    out_activation = backbone_config["out_activation"]
    if out_activation == "":
        out_activation = None

    name = backbone_config["name"]
    if name == "local":
        network_class = LocalNet
    elif name == "unet":
        network_class = UNet
    else:
        raise ValueError("Unknown model name {}".format(name))

    return network_class(
        image_size=image_size,
        out_channels=out_channels,
        out_kernel_initializer=backbone_config["out_kernel_initializer"],
        out_activation=out_activation,
        **tf_model_config[name],
    )
コード例 #2
0
ファイル: util.py プロジェクト: knvsmadhav/DeepReg
def build_backbone(image_size: tuple, out_channels: int, model_config: dict,
                   method_name: str) -> tf.keras.Model:
    """
    Build the backbone network selected by model_config["backbone"].

    Backbone model accepts a single input of shape (batch, dim1, dim2, dim3, ch_in)
    and returns a single output of shape (batch, dim1, dim2, dim3, ch_out)

    :param image_size: tuple, dims of image, (dim1, dim2, dim3)
    :param out_channels: int, number of out channels, ch_out
    :param model_config: dict, model configuration, returned from parser.yaml.load;
        model_config["backbone"] names the network ("local", "global" or "unet")
        and model_config[<that name>] holds its extra keyword arguments
    :param method_name: str, one of ddf | dvf | conditional | affine
    :return: tf.keras.Model
    :raises ValueError: if method_name or the backbone name is not supported
    """
    # (out_activation, out_kernel_initializer) for each supported method
    output_settings = {
        # TODO try random init with smaller number
        "ddf": (None, "zeros"),  # zero init to ensure small ddf and dvf
        "dvf": (None, "zeros"),
        "conditional": ("sigmoid", "glorot_uniform"),  # output is probability
        "affine": (None, "zeros"),
    }
    # the isinstance guard keeps non-str inputs on the ValueError path
    if not isinstance(method_name, str) or method_name not in output_settings:
        raise ValueError(
            "method name has to be one of ddf / dvf / conditional / affine "
            "in build_backbone, got {}".format(method_name))
    out_activation, out_kernel_initializer = output_settings[method_name]

    backbone_name = model_config["backbone"]
    if backbone_name == "local":
        backbone_class = LocalNet
    elif backbone_name == "global":
        backbone_class = GlobalNet
    elif backbone_name == "unet":
        backbone_class = UNet
    else:
        raise ValueError("Unknown model name {}".format(backbone_name))

    return backbone_class(
        image_size=image_size,
        out_channels=out_channels,
        out_kernel_initializer=out_kernel_initializer,
        out_activation=out_activation,
        **model_config[backbone_name],
    )
コード例 #3
0
 def test_get_config(self):
     """Check that LocalNet.get_config returns exactly the kwargs it was built with."""
     expected = {
         "image_size": (4, 5, 6),
         "out_channels": 3,
         "num_channel_initial": 2,
         "depth": 2,
         "extract_levels": (0, 1),
         "out_kernel_initializer": "he_normal",
         "out_activation": "softmax",
         "pooling": False,
         "concat_skip": False,
         "use_additive_upsampling": True,
         "encode_kernel_sizes": [7, 3, 3],
         "decode_kernel_sizes": 3,
         "encode_num_channels": (2, 4, 8),
         "decode_num_channels": (2, 4, 8),
         "strides": 2,
         "padding": "same",
         "name": "Test",
     }
     got = LocalNet(**expected).get_config()
     assert got == expected
コード例 #4
0
    def test_call(
        self,
        image_size: tuple,
        extract_levels: Tuple[int, ...],
        depth: int,
        use_additive_upsampling: bool,
        pooling: bool,
        concat_skip: bool,
    ):
        """
        Check that calling LocalNet returns a tensor shaped like its input.

        The input is created with as many channels as the network outputs,
        so equal shapes also confirm the output channel count.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param extract_levels: from which depths the output will be built.
        :param depth: input is at level 0, bottom is at level depth
        :param use_additive_upsampling: whether use additive up-sampling layer
            for decoding.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: if concatenate skip or add it
        """
        num_out_channels = 3
        batch_size = 5
        local_net = LocalNet(
            image_size=image_size,
            out_channels=num_out_channels,
            num_channel_initial=2,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer="he_normal",
            out_activation="softmax",
            use_additive_upsampling=use_additive_upsampling,
            pooling=pooling,
            concat_skip=concat_skip,
        )
        batch = tf.ones(shape=(batch_size, *image_size, num_out_channels))
        predicted = local_net.call(batch)
        assert batch.shape == predicted.shape
コード例 #5
0
ファイル: util.py プロジェクト: zcemycl/DeepReg
def build_backbone(image_size: tuple, out_channels: int, model_config: dict,
                   method_name: str) -> tf.keras.Model:
    """
    Build the backbone network selected by model_config["backbone"].

    Backbone model accepts a single input of shape (batch, dim1, dim2, dim3, ch_in)
    and returns a single output of shape (batch, dim1, dim2, dim3, ch_out).

    :param image_size: tuple (lists are accepted as well) of length 3,
        dims of image, (dim1, dim2, dim3)
    :param out_channels: int >= 1, number of out channels, ch_out
    :param model_config: dict, model configuration, returned from parser.yaml.load;
        must contain key "backbone" naming the network ("local", "global" or
        "unet"); model_config[<that name>] holds the network's extra kwargs
    :param method_name: str, one of ddf, dvf, conditional and affine
    :return: tf.keras.Model
    :raises ValueError: if an argument fails validation or the backbone name
        is not supported
    """
    if not (isinstance(image_size, (tuple, list)) and len(image_size) == 3):
        raise ValueError(
            f"image_size must be tuple of length 3, got {image_size}")
    if not (isinstance(out_channels, int) and out_channels >= 1):
        raise ValueError(f"out_channels must be int >=1, got {out_channels}")
    if not (isinstance(model_config, dict) and "backbone" in model_config):
        raise ValueError(
            f"model_config must be a dict having key 'backbone', got {model_config}"
        )

    # (out_activation, out_kernel_initializer) for each supported method
    output_settings = {
        # TODO try random init with smaller number
        "ddf": (None, "zeros"),  # zero init to ensure small ddf and dvf
        "dvf": (None, "zeros"),
        "conditional": ("sigmoid", "glorot_uniform"),  # output is probability
        "affine": (None, "zeros"),
    }
    # the isinstance guard keeps non-str inputs on the ValueError path
    if not isinstance(method_name, str) or method_name not in output_settings:
        raise ValueError(
            "method name has to be one of ddf/dvf/conditional/affine in build_backbone, "
            f"got {method_name}")
    out_activation, out_kernel_initializer = output_settings[method_name]

    backbone_name = model_config["backbone"]
    if backbone_name == "local":
        backbone_class = LocalNet
    elif backbone_name == "global":
        backbone_class = GlobalNet
    elif backbone_name == "unet":
        backbone_class = UNet
    else:
        raise ValueError(f"Unknown model name {backbone_name}")

    return backbone_class(
        image_size=image_size,
        out_channels=out_channels,
        out_kernel_initializer=out_kernel_initializer,
        out_activation=out_activation,
        **model_config[backbone_name],
    )