예제 #1
0
def registry(table, name, regex, verbose):
    """Print the classes registered in ``RegistryMeta`` tables.

    Args:
        table: Iterable of registry/table names to restrict the output to.
            When empty, every table from ``RegistryMeta.avail_tables()`` is
            printed. When ``regex`` is true, each entry is a pattern tried
            with ``re.match`` (anchored at the start of the string only)
            against the available table names.
        name: Iterable of class names to restrict the output to within each
            printed table. When empty, every registered class is printed.
            Interpreted as ``re.match`` patterns when ``regex`` is true,
            otherwise as exact names.
        regex: Whether the entries of ``table`` and ``name`` are regular
            expressions rather than exact names.
        verbose: Whether to also print each class's supported rollout types
            (only when it defines ``all_supported_rollout_types``) and its
            docstring (when it has one).
    """

    def _selected(candidate, selectors):
        # Single place for the "regex vs. exact name" selection rule, shared
        # by the table filter and the class-name filter below.
        if regex:
            return any(
                re.match(pattern, candidate) is not None
                for pattern in selectors)
        return candidate in selectors

    if not table:
        registry_names = RegistryMeta.avail_tables()
    elif regex:
        registry_names = [
            table_name for table_name in RegistryMeta.avail_tables()
            if _selected(table_name, table)
        ]
    else:
        registry_names = table

    print("Registries\n================")
    for registry_name in registry_names:
        avails = RegistryMeta.registry_dct[registry_name]
        print("**{}**".format(registry_name))
        for class_name, class_ in avails.items():
            if name and not _selected(class_name, name):
                # not match
                continue

            print("{}: {}".format(class_name, str(class_)))
            if verbose:
                if hasattr(class_, "all_supported_rollout_types"):
                    print("    *Supported rollout types*: ",
                          class_.all_supported_rollout_types())
                if class_.__doc__ is not None:
                    print(class_.__doc__)
        print("----------------")
예제 #2
0
    def __init__(self,
                 search_space,
                 device,
                 genotypes,
                 backbone_type,
                 backbone_cfg,
                 feature_levels=[4, 5],
                 supernet_state_dict=None,
                 head_type='ssd_header',
                 head_cfg={},
                 num_classes=10,
                 schedule_cfg=None):
        """Build the backbone and detection head, then finalize to ``genotypes``.

        Args:
            search_space: Passed to the backbone constructor and used to turn
                ``genotypes`` into a rollout via ``rollout_from_genotype``.
            device: Passed to the backbone and head constructors; the whole
                model is moved to it with ``self.to``.
            genotypes: Converted to the rollout this model is finalized with.
            backbone_type: Class name looked up in the ``final_model``
                registry.
            backbone_cfg: Keyword arguments for the backbone constructor.
            feature_levels: Levels whose channel numbers are queried from the
                backbone (``get_feature_channel_num``) to build the head.
            supernet_state_dict: If truthy, loaded with
                ``load_supernet_state_dict`` before finalization.
            head_type: Class name looked up in the ``detection_header``
                registry.
            head_cfg: Keyword arguments for the head constructor.
            num_classes: Number of classes passed to the head.
            schedule_cfg: Forwarded to the parent constructor.
        """
        # NOTE(review): ``feature_levels`` / ``head_cfg`` use mutable defaults
        # (and ``self.feature_levels`` aliases the default list). They look
        # like they may be introspected for sample-config generation, so
        # confirm before switching them to ``None`` sentinels.
        super(SSDFinalModel, self).__init__(schedule_cfg=schedule_cfg)
        self.search_space = search_space
        self.device = device
        self.num_classes = num_classes
        self.feature_levels = feature_levels

        # Resolve the backbone class from the registry, then instantiate it.
        backbone_cls = RegistryMeta.get_class('final_model', backbone_type)
        self.backbone = backbone_cls(search_space, device, **backbone_cfg)

        # The head is built from the channel numbers of the chosen levels.
        level_channels = self.backbone.get_feature_channel_num(feature_levels)
        head_cls = RegistryMeta.get_class("detection_header", head_type)
        self.head = head_cls(device, num_classes, level_channels, **head_cfg)

        if supernet_state_dict:
            self.load_supernet_state_dict(supernet_state_dict)
        self.finalize(search_space.rollout_from_genotype(genotypes))

        # Re-set after finalization (kept as in the original; possibly guards
        # against ``finalize`` touching these attributes — TODO confirm).
        self.search_space = search_space
        self.device = device
        self.num_classes = num_classes
        self.to(self.device)

        # for flops calculation
        self.total_flops = 0
        self._flops_calculated = False
        self.set_hook()
예제 #3
0
def _init_component(cfg, registry_name, **addi_args):
    """Instantiate the ``registry_name`` component described by ``cfg``.

    The class is looked up with ``RegistryMeta.get_class(registry_name,
    cfg[registry_name + "_type"])``. It is constructed with ``addi_args``
    merged with ``cfg[registry_name + "_cfg"]``; on key collisions the config
    entries win.

    Args:
        cfg: Mapping holding ``"<registry_name>_type"`` and, optionally,
            ``"<registry_name>_cfg"``. A missing or falsy sub-config means no
            extra keyword arguments.
        registry_name: Registry table to look the component class up in.
        **addi_args: Extra keyword arguments for the component constructor.

    Returns:
        The constructed component instance.
    """
    import logging

    type_ = cfg[registry_name + "_type"]
    # Keep the caller's ``cfg`` name intact; a missing/empty sub-config -> {}.
    comp_cfg = cfg.get(registry_name + "_cfg", None) or {}
    # config items will override addi_args items
    addi_args.update(comp_cfg)
    LOGGER.info("Component [%s] type: %s", registry_name, type_)
    cls = RegistryMeta.get_class(registry_name, type_)
    # Check the *effective* level: a logger left at NOTSET (level 0) would
    # pass a raw ``LOGGER.level < 20`` test even when DEBUG records are
    # filtered out, building the config string for nothing.
    if LOGGER.isEnabledFor(logging.DEBUG):
        whole_cfg_str = cls.get_current_config_str(comp_cfg)
        LOGGER.debug("%s %s config:\n%s", registry_name, type_,
                     utils.add_text_prefix(whole_cfg_str, "  "))
    return cls(**addi_args)
예제 #4
0
                        lambda cls: data_type in cls.supported_data_types())

            out_f.write(
                utils.component_sample_config_str(comp_name,
                                                  prefix="# ",
                                                  filter_funcs=filter_funcs))
            out_f.write("\n")


@main.command(help="Print registry information.")
@click.option(
    "-t",
    "--table",
    default=[],
    multiple=True,
    type=click.Choice(RegistryMeta.avail_tables()),
    help=
    "If specified, only print classes of the corresponding registries/tables.")
@click.option(
    "-n",
    "--name",
    default=[],
    multiple=True,
    help=
    "If specified, only print the information of the corresponding classes in the registry."
)
@click.option("-r",
              "--regex",
              default=False,
              type=bool,
              is_flag=True,