def test_backbone(self):
    """Check that a value registered as a backbone is returned by lookup."""
    reg = Registry()
    name, expected = "new_backbone", 0

    reg.register_backbone(name, expected)

    # The lookup must hand back exactly what was registered under the key.
    assert reg.get_backbone(name) == expected
def build_backbone(
    image_size: tuple,
    out_channels: int,
    config: dict,
    method_name: str,
    registry: Registry = Registry(),
) -> tf.keras.Model:
    """
    Look up a backbone class in the registry and instantiate it.

    The backbone model accepts a single input of shape
    (batch, dim1, dim2, dim3, ch_in) and returns a single output of shape
    (batch, dim1, dim2, dim3, ch_out).

    The output layer settings depend on the method:

    - ddf / dvf / affine: no activation, "zeros" kernel initializer
    - conditional: "sigmoid" activation, "glorot_uniform" kernel initializer

    Note that the default ``registry`` is created once, when the function is
    defined, so every call that omits ``registry`` shares that same instance.

    :param image_size: tuple or list of length 3, dims of image,
        (dim1, dim2, dim3)
    :param out_channels: int, number of out channels, ch_out
    :param config: dict, backbone configuration; ``config["name"]`` is the
        registry key, and the whole dict (including ``name``) is forwarded
        to the backbone class as keyword arguments
    :param method_name: str, one of ddf, dvf, conditional and affine
    :param registry: the registry object having all backbone classes
    :return: tf.keras.Model
    :raises ValueError: if image_size is not a tuple/list of length 3,
        or method_name is not one of the supported methods
    """
    # Guard: image_size must be a 3-element tuple or list.
    if not isinstance(image_size, (tuple, list)) or len(image_size) != 3:
        raise ValueError(
            f"image_size must be tuple of length 3, got {image_size}")

    # Guard: only the known registration methods are accepted.
    supported_methods = ("ddf", "dvf", "conditional", "affine")
    if method_name not in supported_methods:
        raise ValueError(
            f"method name has to be one of ddf/dvf/conditional/affine in build_backbone, "
            f"got {method_name}")

    if method_name == "conditional":
        # output is probability
        out_activation = "sigmoid"
        out_kernel_initializer = "glorot_uniform"
    else:
        # ddf / dvf / affine: zero-initialised kernel, to ensure small
        # ddf and dvf.
        # TODO try random init with smaller number
        out_activation = None
        out_kernel_initializer = "zeros"

    backbone_cls = registry.get_backbone(key=config["name"])
    backbone_kwargs = dict(
        image_size=image_size,
        out_channels=out_channels,
        out_kernel_initializer=out_kernel_initializer,
        out_activation=out_activation,
    )
    # config is unpacked alongside the fixed arguments, so a key in config
    # that duplicates one of them raises TypeError at the call.
    return backbone_cls(**backbone_kwargs, **config)