Esempio n. 1
0
    def __init__(
        self,
        block_names=None,
        down_conv_nn=None,
        grid_size=None,
        prev_grid_size=None,
        has_bottleneck=None,
        max_num_neighbors=None,
        deformable=False,
        add_one=False,
        **kwargs,
    ):
        """Stack one sub-block per entry of ``block_names`` into ``self.blocks``.

        Each entry of ``block_names`` is the name of a class looked up in this
        module. ``down_conv_nn``, ``grid_size``, ``prev_grid_size``,
        ``has_bottleneck`` and ``max_num_neighbors`` are indexed per block.
        ``deformable``, ``add_one`` and every extra keyword argument are
        indexed per block when given as a list, and shared by all blocks
        otherwise.
        """
        super(KPDualBlock, self).__init__()

        assert len(block_names) == len(down_conv_nn)
        self.blocks = torch.nn.ModuleList()
        for idx, cls_name in enumerate(block_names):
            # Extra kwargs for this block: list values are indexed, scalars shared.
            extra_kwargs = {
                key: (value[idx] if is_list(value) else value)
                for key, value in kwargs.items()
            }

            block_cls = getattr(sys.modules[__name__], cls_name)
            new_block = block_cls(
                down_conv_nn=down_conv_nn[idx],
                grid_size=grid_size[idx],
                prev_grid_size=prev_grid_size[idx],
                has_bottleneck=has_bottleneck[idx],
                max_num_neighbors=max_num_neighbors[idx],
                deformable=deformable[idx] if is_list(deformable) else deformable,
                add_one=add_one[idx] if is_list(add_one) else add_one,
                **extra_kwargs,
            )
            self.blocks.append(new_block)
    def __init__(
        self,
        radius: Union[float, List[float]],
        max_num_neighbors: Union[int, List[int]] = 64,
    ):
        """Store per-scale search radii and neighbour caps as equal-length lists.

        A scalar ``radius`` or ``max_num_neighbors`` is broadcast to the length
        of the other argument when that one is a list; two scalars become
        single-element lists.

        Args:
            radius: search radius, or one radius per scale.
            max_num_neighbors: maximum number of neighbours, or one value per
                scale (default 64).

        Raises:
            ValueError: if both arguments are lists of different lengths.
        """
        radius_is_list = is_list(radius)
        max_is_list = is_list(max_num_neighbors)

        if radius_is_list and not max_is_list:
            self._radius = radius
            self._max_num_neighbors = [max_num_neighbors for _ in range(len(radius))]
        elif max_is_list and not radius_is_list:
            self._max_num_neighbors = max_num_neighbors
            self._radius = [radius for _ in range(len(max_num_neighbors))]
        elif max_is_list:
            # Both are lists: they must describe the same number of scales.
            if len(max_num_neighbors) != len(radius):
                raise ValueError(
                    "Both lists max_num_neighbors and radius should be of the same length"
                )
            self._max_num_neighbors = max_num_neighbors
            self._radius = radius
        else:
            self._max_num_neighbors = [max_num_neighbors]
            self._radius = [radius]

        # Debug mode: one DistributionNeighbour per radius and the cap raised to
        # 256 for every scale. Applied after normalisation so that scalar
        # arguments (including the default max_num_neighbors=64) are not
        # iterated directly, which previously raised TypeError.
        if DEBUGGING_VARS["FIND_NEIGHBOUR_DIST"]:
            self._dist_meters = [DistributionNeighbour(r) for r in self._radius]
            self._max_num_neighbors = [256 for _ in self._max_num_neighbors]
Esempio n. 3
0
    def _save_sampling_and_search(self, down_conv):
        """Append ``down_conv``'s sampler and neighbour finder to the spatial ops.

        For each of the ``sampler`` and ``neighbour_finder`` attributes, a list
        value is concatenated onto the stored list. Any other value is appended
        as a single entry, including ``None`` when the attribute is missing.
        """
        for op_key in ("sampler", "neighbour_finder"):
            found = getattr(down_conv, op_key, None)
            if not is_list(found):
                self._spatial_ops_dict[op_key].append(found)
                continue
            self._spatial_ops_dict[op_key] += found
    def _save_sampling_and_search(self, down_conv):
        """Prepend ``down_conv``'s sampler and neighbour finder to the spatial ops.

        For each of the ``sampler`` and ``neighbour_finder`` attributes, a list
        value is placed in front of the stored list as-is. Any other value,
        including ``None`` when the attribute is missing, is wrapped in a
        one-element list first.
        """
        for op_key in ("sampler", "neighbour_finder"):
            found = getattr(down_conv, op_key, None)
            front = found if is_list(found) else [found]
            self._spatial_ops_dict[op_key] = front + self._spatial_ops_dict[op_key]
Esempio n. 5
0
    def __init__(self, opt, model_type, dataset: BaseDataset, modules_lib):
        """Construct an unwrapped U-Net generator.

        The layers are meant to be appended to lists with the following names:
        * down_modules: all the down modules
        * inner_modules: one or more inner modules
        * up_modules: all the up modules

        Parameters:
            opt - options for the network generation
            model_type - type of the model to be generated
            dataset - dataset object, forwarded to ``_init_from_compact_format``
            modules_lib - all modules that can be used in the UNet

        For a recursive implementation, see UnetBaseModel.

        opt is expected to contain the following keys:
        * down_conv
        * up_conv
        * OPTIONAL: innermost

        Raises:
            NotImplementedError - if ``opt.down_conv`` is a list or has no
            ``down_conv_nn`` entry; only the compact options format is supported.
        """
        super(UnwrappedUnetBasedModel, self).__init__(opt)
        # Spatial operations gathered per layer (presumably filled by
        # _save_sampling_and_search while the layers are built).
        self._spatial_ops_dict = {
            "neighbour_finder": [],
            "sampler": [],
            "upsample_op": [],
        }

        # Detect which options format has been used to define the model.
        if is_list(opt.down_conv) or "down_conv_nn" not in opt.down_conv:
            raise NotImplementedError(
                "Only the compact options format (down_conv with down_conv_nn) is supported"
            )
        self._init_from_compact_format(opt, model_type, dataset, modules_lib)
Esempio n. 6
0
    def _save_sampling_and_search(self, submodule):
        """Record the spatial ops used by ``submodule`` in ``self._spatial_ops_dict``.

        The ``sampler`` and ``neighbour_finder`` of ``submodule.down`` are
        prepended to the stored lists. A list value is used as-is; any other
        value, including ``None`` when the attribute is missing, is wrapped in
        a one-element list. ``submodule.up.upsample_op`` is appended when
        present.
        """
        # NOTE(review): down ops are prepended while upsample ops are appended;
        # presumably this matches the order in which submodules are visited
        # during a recursive build — confirm against the caller.
        for op_key in ("sampler", "neighbour_finder"):
            found = getattr(submodule.down, op_key, None)
            front = found if is_list(found) else [found]
            self._spatial_ops_dict[op_key] = front + self._spatial_ops_dict[op_key]

        upsample_op = getattr(submodule.up, "upsample_op", None)
        # Compare against None explicitly: a valid op may define __len__/__bool__
        # and evaluate falsy, and must not be silently dropped.
        if upsample_op is not None:
            self._spatial_ops_dict["upsample_op"].append(upsample_op)
Esempio n. 7
0
 def _fetch_arguments_from_list(self, opt, index):
     """Fetch the arguments for a single convolution from multiple lists
     of arguments - for models specified in the compact format.

     For every ``(key, value)`` pair of ``opt``:
     * a non-empty list value is replaced by its ``index``-th element, and the
       key loses a trailing "s" unless it is listed in ``SPECIAL_NAMES``
       (e.g. a key ``foos`` becomes ``foo``);
     * any other value (including an empty list) is kept under the
       stringified key.
     List-like results are converted to plain ``list`` objects.

     Args:
         opt: mapping of argument names to values (iterated with ``items()``).
         index: position of the convolution whose arguments are fetched.

     Returns:
         dict mapping argument names to the values for this convolution.

     Raises:
         IndexError: if a non-empty list has no element at ``index``.
     """
     args = {}
     for key, value in opt.items():
         name = str(key)
         # Use the value yielded by items() directly: getattr(opt, key) fails
         # on plain dicts and returns a bound method for keys such as "items".
         if is_list(value) and len(value) > 0:
             if name.endswith("s") and name not in SPECIAL_NAMES:
                 name = name[:-1]
             selected = value[index]
             args[name] = list(selected) if is_list(selected) else selected
         else:
             args[name] = list(value) if is_list(value) else value
     return args
Esempio n. 8
0
    def _create_inner_modules(self, args_innermost, modules_lib):
        """Instantiate the innermost module(s) of the network.

        ``args_innermost`` is a single option mapping or a list of them. For
        each mapping, the ``module_name`` entry (read through
        ``_get_from_kwargs``) selects a class in ``modules_lib``. That class is
        then called with the mapping as keyword arguments.

        Returns:
            list of the created modules, in input order.
        """
        option_sets = args_innermost if is_list(args_innermost) else [args_innermost]
        created = []
        for inner_opt in option_sets:
            # Read module_name before expanding inner_opt: _get_from_kwargs may
            # modify the mapping, so this order is kept deliberately.
            cls_name = self._get_from_kwargs(inner_opt, "module_name")
            module_cls = getattr(modules_lib, cls_name)
            created.append(module_cls(**inner_opt))
        return created
Esempio n. 9
0
    def __init__(self, opt, model_type, dataset: BaseDataset, modules_lib):
        """Construct a backbone generator (a simple stack of down modules).

        Parameters:
            opt - options for the network generation
            model_type - type of the model to be generated
            dataset - dataset object, forwarded to ``_init_from_compact_format``
            modules_lib - all modules that can be used in the backbone

        opt is expected to contain the following keys:
        * down_conv

        Raises:
            NotImplementedError - if ``opt.down_conv`` is a list or has no
            ``down_conv_nn`` entry; only the compact options format is supported.
        """
        super(BackboneBasedModel, self).__init__(opt)
        self._spatial_ops_dict = {"neighbour_finder": [], "sampler": []}

        # Detect which options format has been used to define the model.
        if is_list(opt.down_conv) or "down_conv_nn" not in opt.down_conv:
            raise NotImplementedError(
                "Only the compact options format (down_conv with down_conv_nn) is supported"
            )
        self._init_from_compact_format(opt, model_type, dataset, modules_lib)