def registry(table, name, regex, verbose):
    """Print registered classes, grouped by registry table.

    Args:
        table: Registry table names to list. When empty, every table returned
            by ``RegistryMeta.avail_tables()`` is listed. When ``regex`` is
            true, the entries are patterns and an available table is listed if
            any pattern ``re.match``-es its name.
        name: Class names to list. When empty, every class of a listed table
            is printed. When ``regex`` is true, the entries are patterns
            matched with ``re.match`` against the class name; otherwise the
            class name must be one of the entries.
        regex: Whether ``table`` and ``name`` hold regex patterns rather than
            exact names.
        verbose: Whether to additionally print each class's supported rollout
            types (when it defines ``all_supported_rollout_types``) and its
            docstring (when it is not ``None``).
    """
    def _any_pattern_matches(patterns, candidate):
        # ``re.match`` anchors at the start of ``candidate`` only, so a pattern
        # matching a prefix of the name is enough.
        return any(re.match(pat, candidate) is not None for pat in patterns)

    if not table:
        tables_to_show = RegistryMeta.avail_tables()
    elif regex:
        tables_to_show = [
            tbl for tbl in RegistryMeta.avail_tables()
            if _any_pattern_matches(table, tbl)
        ]
    else:
        tables_to_show = table

    print("Registries\n================")
    for tbl in tables_to_show:
        entries = RegistryMeta.registry_dct[tbl]
        print("**{}**".format(tbl))
        for cls_name, component_cls in entries.items():
            if name:
                selected = (_any_pattern_matches(name, cls_name) if regex
                            else cls_name in name)
                if not selected:
                    continue
            print("{}: {}".format(cls_name, str(component_cls)))
            if not verbose:
                continue
            if hasattr(component_cls, "all_supported_rollout_types"):
                print(" *Supported rollout types*: ",
                      component_cls.all_supported_rollout_types())
            if component_cls.__doc__ is not None:
                print(component_cls.__doc__)
        print("----------------")
def __init__(self, search_space, device, genotypes,
             backbone_type, backbone_cfg,
             feature_levels=None, supernet_state_dict=None,
             head_type='ssd_header', head_cfg=None,
             num_classes=10, schedule_cfg=None):
    """Build the final SSD model: a registered backbone plus a detection head.

    Args:
        search_space: Search space used to convert ``genotypes`` into a
            rollout (``search_space.rollout_from_genotype``).
        device: Device the finished model is moved to via ``self.to``.
        genotypes: Genotype(s) that are converted to a rollout and applied
            through ``self.finalize``.
        backbone_type: Name of the backbone class in the ``final_model``
            registry table.
        backbone_cfg: Keyword arguments forwarded to the backbone constructor.
        feature_levels: Backbone feature levels whose channel counts are
            passed to the head. ``None`` (the default) means ``[4, 5]``.
        supernet_state_dict: If truthy, loaded with
            ``self.load_supernet_state_dict`` before finalizing.
        head_type: Name of the head class in the ``detection_header``
            registry table.
        head_cfg: Keyword arguments forwarded to the head constructor.
            ``None`` (the default) means no extra arguments.
        num_classes: Number of classes, passed to the head constructor.
        schedule_cfg: Forwarded to the parent class constructor.
    """
    super(SSDFinalModel, self).__init__(schedule_cfg=schedule_cfg)

    # Defaults are resolved here rather than in the signature: a list/dict
    # default is created once and shared by every call, and the list would be
    # aliased by ``self.feature_levels`` across all instances.
    if feature_levels is None:
        feature_levels = [4, 5]
    if head_cfg is None:
        head_cfg = {}

    self.search_space = search_space
    self.device = device
    self.num_classes = num_classes
    self.feature_levels = feature_levels

    self.backbone = RegistryMeta.get_class("final_model", backbone_type)(
        search_space, device, **backbone_cfg)
    feature_channels = self.backbone.get_feature_channel_num(feature_levels)
    self.head = RegistryMeta.get_class("detection_header", head_type)(
        device, num_classes, feature_channels, **head_cfg)

    if supernet_state_dict:
        self.load_supernet_state_dict(supernet_state_dict)

    rollout = search_space.rollout_from_genotype(genotypes)
    self.finalize(rollout)

    # NOTE(review): these attributes are assigned again after ``finalize``;
    # kept in case ``finalize`` rebinds them — confirm and drop if redundant.
    self.search_space = search_space
    self.device = device
    self.num_classes = num_classes

    self.to(self.device)

    # for flops calculation
    self.total_flops = 0
    self._flops_calculated = False
    self.set_hook()
def _init_component(cfg, registry_name, **addi_args):
    """Instantiate the component configured under ``registry_name`` in ``cfg``.

    ``cfg[registry_name + "_type"]`` selects the class from the registry table
    named ``registry_name``. The optional ``cfg[registry_name + "_cfg"]``
    supplies constructor keyword arguments; these override same-named entries
    in ``addi_args``.

    Args:
        cfg: Configuration mapping holding the ``<registry_name>_type`` key and
            optionally a ``<registry_name>_cfg`` mapping.
        registry_name: Registry table name, also used as the key prefix.
        **addi_args: Extra constructor keyword arguments; lower priority than
            the entries of ``<registry_name>_cfg``.

    Returns:
        The constructed component instance.

    Raises:
        KeyError: if ``cfg`` is a dict lacking ``<registry_name>_type``.
    """
    import logging

    type_ = cfg[registry_name + "_type"]
    # A missing or empty/None ``*_cfg`` entry means "no component overrides".
    comp_cfg = cfg.get(registry_name + "_cfg", None) or {}
    # config items will override addi_args items
    addi_args.update(comp_cfg)
    LOGGER.info("Component [%s] type: %s", registry_name, type_)
    cls = RegistryMeta.get_class(registry_name, type_)
    # ``LOGGER.level`` is only the logger's own level; it is 0 (NOTSET) when the
    # level is inherited from a parent, which made the old ``< 20`` test build
    # the config string even with DEBUG output disabled. ``isEnabledFor`` uses
    # the effective level.
    if LOGGER.isEnabledFor(logging.DEBUG):
        whole_cfg_str = cls.get_current_config_str(comp_cfg)
        LOGGER.debug("%s %s config:\n%s", registry_name, type_,
                     utils.add_text_prefix(whole_cfg_str, " "))
    return cls(**addi_args)
lambda cls: data_type in cls.supported_data_types()) out_f.write( utils.component_sample_config_str(comp_name, prefix="# ", filter_funcs=filter_funcs)) out_f.write("\n") @main.command(help="Print registry information.") @click.option( "-t", "--table", default=[], multiple=True, type=click.Choice(RegistryMeta.avail_tables()), help= "If specified, only print classes of the corresponding registries/tables.") @click.option( "-n", "--name", default=[], multiple=True, help= "If specified, only print the information of the corresponding classes in the registry." ) @click.option("-r", "--regex", default=False, type=bool, is_flag=True,