def __init__(
    self,
    block_names=None,
    down_conv_nn=None,
    grid_size=None,
    prev_grid_size=None,
    has_bottleneck=None,
    max_num_neighbors=None,
    deformable=False,
    add_one=False,
    **kwargs,
):
    """Build a sequence of KPConv-style blocks, one per entry of ``block_names``.

    Each block class is resolved by name from the current module. Per-index
    arguments (``down_conv_nn``, ``grid_size``, ...) are indexed with the block
    position; ``deformable``, ``add_one`` and every extra keyword argument are
    indexed only when they are lists, otherwise shared by all blocks.
    """
    super(KPDualBlock, self).__init__()
    assert len(block_names) == len(down_conv_nn)

    self.blocks = torch.nn.ModuleList()
    for i, class_name in enumerate(block_names):
        # Per-block extra kwargs: pick element i out of list-valued args,
        # pass scalar args through unchanged.
        block_kwargs = {key: (arg[i] if is_list(arg) else arg) for key, arg in kwargs.items()}

        # Resolve the block class by name from this module and instantiate it.
        kpcls = getattr(sys.modules[__name__], class_name)
        self.blocks.append(
            kpcls(
                down_conv_nn=down_conv_nn[i],
                grid_size=grid_size[i],
                prev_grid_size=prev_grid_size[i],
                has_bottleneck=has_bottleneck[i],
                max_num_neighbors=max_num_neighbors[i],
                deformable=deformable[i] if is_list(deformable) else deformable,
                add_one=add_one[i] if is_list(add_one) else add_one,
                **block_kwargs,
            )
        )
def __init__(
    self,
    radius: Union[float, List[float]],
    max_num_neighbors: Union[int, List[int]] = 64,
):
    """Normalize ``radius`` and ``max_num_neighbors`` into two parallel lists.

    Scalars are broadcast against a list-valued counterpart; two scalars become
    singleton lists; two lists must have equal length.

    Raises:
        ValueError: if both arguments are lists of different lengths.
    """
    if DEBUGGING_VARS["FIND_NEIGHBOUR_DIST"]:
        # Debug mode: track the neighbour-distance distribution per radius and
        # widen every neighbour cap to 256.
        # NOTE(review): this branch iterates both arguments, so it assumes they
        # are lists — confirm callers never enable it with scalar inputs.
        self._dist_meters = [DistributionNeighbour(r) for r in radius]
        max_num_neighbors = [256 for _ in max_num_neighbors]

    radius_is_list = is_list(radius)
    max_is_list = is_list(max_num_neighbors)

    if radius_is_list and max_is_list:
        if len(max_num_neighbors) != len(radius):
            raise ValueError(
                "Both lists max_num_neighbors and radius should be of the same length"
            )
        self._radius = radius
        self._max_num_neighbors = max_num_neighbors
    elif radius_is_list:
        # Broadcast the scalar neighbour cap across every radius.
        self._radius = radius
        self._max_num_neighbors = [max_num_neighbors] * len(radius)
    elif max_is_list:
        # Broadcast the scalar radius across every neighbour cap.
        self._max_num_neighbors = max_num_neighbors
        self._radius = [radius] * len(max_num_neighbors)
    else:
        # Two scalars: store them as singleton lists.
        self._radius = [radius]
        self._max_num_neighbors = [max_num_neighbors]
def _save_sampling_and_search(self, down_conv):
    """Record the sampler and neighbour finder exposed by a down-conv module.

    Each operator is APPENDED to the matching registry list in
    ``self._spatial_ops_dict``; list-valued operators are flattened in.
    A missing attribute is recorded as ``None``.
    """
    for attr in ("sampler", "neighbour_finder"):
        found = getattr(down_conv, attr, None)
        if is_list(found):
            self._spatial_ops_dict[attr].extend(found)
        else:
            self._spatial_ops_dict[attr].append(found)
def _save_sampling_and_search(self, down_conv):
    """Record the sampler and neighbour finder exposed by a down-conv module.

    Each operator is PREPENDED to the matching registry list in
    ``self._spatial_ops_dict`` (presumably because modules are registered
    innermost-first — confirm against the caller). A missing attribute is
    recorded as ``None``.
    """
    for attr in ("sampler", "neighbour_finder"):
        found = getattr(down_conv, attr, None)
        head = found if is_list(found) else [found]
        self._spatial_ops_dict[attr] = head + self._spatial_ops_dict[attr]
def __init__(self, opt, model_type, dataset: BaseDataset, modules_lib):
    """Construct an unwrapped UNet generator.

    The concrete layers are created by ``_init_from_compact_format`` and end up
    in the ``down_modules`` / ``inner_modules`` / ``up_modules`` lists.

    Parameters:
        opt - options describing the network; must use the compact format and
              contain ``down_conv`` (with ``down_conv_nn``), ``up_conv`` and
              optionally ``innermost``
        model_type - type of the model to be generated
        dataset - dataset the model is built for
        modules_lib - library of modules usable inside the UNet

    Raises:
        NotImplementedError: if ``opt`` is not in the compact format.
    """
    super(UnwrappedUnetBasedModel, self).__init__(opt)

    # Registries for spatial operators collected while the network is built.
    self._spatial_ops_dict = {"neighbour_finder": [], "sampler": [], "upsample_op": []}

    # Only the compact options format is supported: a single down_conv entry
    # carrying a down_conv_nn key.
    if is_list(opt.down_conv) or "down_conv_nn" not in opt.down_conv:
        raise NotImplementedError

    self._init_from_compact_format(opt, model_type, dataset, modules_lib)
def _save_sampling_and_search(self, submodule):
    """Record the spatial operators of one down/up submodule pair.

    The sampler and neighbour finder from ``submodule.down`` are PREPENDED to
    their registry lists (a missing attribute is recorded as ``None``); the
    upsample operator from ``submodule.up``, when present and truthy, is
    appended to the ``upsample_op`` registry.
    """
    for attr in ("sampler", "neighbour_finder"):
        found = getattr(submodule.down, attr, None)
        head = found if is_list(found) else [found]
        self._spatial_ops_dict[attr] = head + self._spatial_ops_dict[attr]

    upsample_op = getattr(submodule.up, "upsample_op", None)
    if upsample_op:
        self._spatial_ops_dict["upsample_op"].append(upsample_op)
def _fetch_arguments_from_list(self, opt, index):
    """Extract the arguments of one convolution from compact-format options.

    For every option: a non-empty list yields its ``index``-th element under a
    de-pluralized key (trailing "s" stripped unless the key is in
    SPECIAL_NAMES); anything else is passed through under its own key.
    List values are copied so the options object is never aliased.
    """
    fetched = {}
    for o, v in opt.items():
        name = str(o)
        if is_list(v) and len(getattr(opt, o)) > 0:
            # Per-layer option: strip the plural "s" (e.g. radii -> radius is
            # handled by SPECIAL_NAMES exclusions) and pick element `index`.
            if name.endswith("s") and name not in SPECIAL_NAMES:
                name = name[:-1]
            entry = v[index]
            if is_list(entry):
                entry = list(entry)
            fetched[name] = entry
        else:
            # Shared option: forward as-is (copying list values).
            if is_list(v):
                v = list(v)
            fetched[name] = v
    return fetched
def _create_inner_modules(self, args_innermost, modules_lib):
    """Instantiate the innermost module(s) of the network.

    ``args_innermost`` is either a single options object or a list of them;
    each must expose a ``module_name`` key naming a class in ``modules_lib``.

    Returns:
        list of instantiated inner modules, in input order.
    """
    # Normalize to a list so a single options object and a list of them share
    # one construction path (the two branches previously duplicated it).
    options = args_innermost if is_list(args_innermost) else [args_innermost]

    inners = []
    for inner_opt in options:
        module_name = self._get_from_kwargs(inner_opt, "module_name")
        inner_module_cls = getattr(modules_lib, module_name)
        inners.append(inner_module_cls(**inner_opt))
    return inners
def __init__(self, opt, model_type, dataset: BaseDataset, modules_lib):
    """Construct a backbone generator (a simple stack of down modules).

    Parameters:
        opt - options describing the network; must use the compact format and
              contain a ``down_conv`` entry carrying ``down_conv_nn``
        model_type - type of the model to be generated
        dataset - dataset the model is built for
        modules_lib - library of modules usable in the backbone

    Raises:
        NotImplementedError: if ``opt`` is not in the compact format.
    """
    super(BackboneBasedModel, self).__init__(opt)

    # Registries for the spatial operators collected while building the net.
    self._spatial_ops_dict = {"neighbour_finder": [], "sampler": []}

    # Only the compact options format is supported.
    if is_list(opt.down_conv) or "down_conv_nn" not in opt.down_conv:
        raise NotImplementedError

    self._init_from_compact_format(opt, model_type, dataset, modules_lib)