Esempio n. 1
0
 def test_backbone(self):
     """A backbone registered under a key is returned unchanged by get_backbone."""
     reg = Registry()
     backbone_name, backbone = "new_backbone", 0
     reg.register_backbone(backbone_name, backbone)
     # Round-trip: whatever was stored under the key must come back equal.
     assert reg.get_backbone(backbone_name) == backbone
def build_backbone(
        image_size: tuple,
        out_channels: int,
        config: dict,
        method_name: str,
        registry: Registry = Registry(),
) -> tf.keras.Model:
    """
    Instantiate the backbone class selected by ``config["name"]``.

    Backbone model accepts a single input of shape (batch, dim1, dim2, dim3, ch_in)
    and returns a single output of shape (batch, dim1, dim2, dim3, ch_out).
    The output layer's activation and kernel initializer are chosen from
    ``method_name``. Every entry of ``config`` (including ``"name"``) is
    forwarded to the backbone class as a keyword argument.

    :param image_size: tuple or list of length 3, dims of image, (dim1, dim2, dim3)
    :param out_channels: int, number of out channels, ch_out
    :param config: dict, backbone configuration; ``config["name"]`` is the key
        used to look up the backbone class in ``registry``
    :param method_name: str, one of ddf, dvf, conditional and affine
    :param registry: the registry object having all backbone classes.
        NOTE: the default ``Registry()`` is evaluated once, when this function
        is defined, and is shared by every call that omits this argument; this
        function only calls ``get_backbone`` on it.
    :return: tf.keras.Model, the instantiated backbone
    :raises ValueError: if image_size is not a tuple/list of length 3, or if
        method_name is not one of ddf/dvf/conditional/affine
    :raises KeyError: if config has no "name" entry
    """
    if not (isinstance(image_size, (tuple, list)) and len(image_size) == 3):
        raise ValueError(
            f"image_size must be tuple of length 3, got {image_size}")

    # (out_activation, out_kernel_initializer) per supported method. Using one
    # table for both validation and dispatch means a method name can never be
    # accepted without also having output-layer settings.
    out_layer_settings = {
        # TODO try random init with smaller number
        "ddf": (None, "zeros"),  # zero init to ensure small ddf and dvf
        "dvf": (None, "zeros"),
        "conditional": ("sigmoid", "glorot_uniform"),  # output is probability
        "affine": (None, "zeros"),
    }
    # The str check keeps unhashable inputs (e.g. a list) on the ValueError
    # path instead of raising TypeError from the dict membership test.
    if not isinstance(method_name, str) or method_name not in out_layer_settings:
        raise ValueError(
            f"method name has to be one of ddf/dvf/conditional/affine in build_backbone, "
            f"got {method_name}")
    out_activation, out_kernel_initializer = out_layer_settings[method_name]

    backbone_cls = registry.get_backbone(key=config["name"])
    return backbone_cls(
        image_size=image_size,
        out_channels=out_channels,
        out_kernel_initializer=out_kernel_initializer,
        out_activation=out_activation,
        **config,
    )