コード例 #1
0
def test_genprof(case):
    """Check generated profiling primitives and the nets assembled from them.

    ``case["search_space_cfg"]`` is forwarded to the ``ofa_mixin`` search
    space and ``case["prof_prim_cfg"]`` to ``generate_profiling_primitives``.
    The test verifies the primitive count, the fields each primitive carries,
    and that consecutive layers of every assembled net agree on channel
    count and spatial size.
    """
    from aw_nas.common import get_search_space, genotype_from_str
    from aw_nas.hardware.base import MixinProfilingSearchSpace
    from aw_nas.hardware.utils import Prim, assemble_profiling_nets
    from aw_nas.rollout.general import GeneralSearchSpace

    ss_cfg = case["search_space_cfg"]
    search_space = get_search_space("ofa_mixin", **ss_cfg)
    assert isinstance(search_space, MixinProfilingSearchSpace)
    primitives = search_space.generate_profiling_primitives(
        **case["prof_prim_cfg"])

    # width x kernel x (cell groups after the first) x 2, plus 2 extra prims
    num_prims = len(primitives)
    expected_num = (len(ss_cfg["width_choice"])
                    * len(ss_cfg["kernel_choice"])
                    * len(ss_cfg["num_cell_groups"][1:]) * 2 + 2)
    assert num_prims == expected_num

    # every Prim field except the free-form ``kwargs`` must be present
    required_fields = set(Prim._fields) - {"kwargs"}
    for prim in primitives:
        assert isinstance(prim, dict)
        assert all(name in prim for name in required_fields)

    nets = assemble_profiling_nets(primitives, {"final_model_cfg": {}},
                                   image_size=224)
    general_ss = GeneralSearchSpace([])
    total_layers = 0
    for net in nets:
        layers = net["final_model_cfg"]["genotypes"]
        assert isinstance(layers, (list, str))
        if isinstance(layers, str):
            layers = genotype_from_str(layers, general_ss)
        total_layers += len(layers)

        # output channels of each layer must feed the next layer's input
        channel_ok = [prev["C_out"] == nxt["C"]
                      for prev, nxt in zip(layers, layers[1:])]
        assert all(channel_ok)

        # spatial size shrinks by each layer's stride
        sizes = [layer["spatial_size"] for layer in layers]
        strides = [layer["stride"] for layer in layers]
        size_ok = [round(cur / step) == nxt
                   for step, cur, nxt in zip(strides, sizes[:-1], sizes[1:])]
        assert all(size_ok)

    assert total_layers >= len(primitives)
コード例 #2
0
def genprof(cfg_file, hwobj_cfg_file, result_dir, compile_hardware,
            num_sample):
    """Generate profiling primitives/networks for a search space and
    optionally run hardware compilers on the resulting networks.

    Args:
        cfg_file: path to the search-space YAML config; its
            ``search_space_cfg`` section is passed to ``get_search_space``.
        hwobj_cfg_file: path to the hardware-object YAML config. Keys read:
            ``mixin_search_space_type``, ``mixin_search_space_cfg``,
            ``prof_prims_cfg`` (required), ``profiling_net_cfg`` and the
            optional ``hardware_compilers`` list.
        result_dir: output directory. Receives copies of both config files,
            ``prof_prims.yaml``, a ``prof_nets/`` directory with one YAML per
            profiling net and, if any compiler runs, a ``hardwares/`` dir.
        compile_hardware: extra hardware compiler type names (may be empty or
            None); each is constructed with an empty config.
        num_sample: if truthy, sample this many networks via
            ``sample_networks``; otherwise assemble nets from the primitives.
    """
    ss_cfg = _genprof_load_yaml(cfg_file)
    hw_cfg = _genprof_load_yaml(hwobj_cfg_file)

    ss = get_search_space(hw_cfg["mixin_search_space_type"],
                          **ss_cfg["search_space_cfg"],
                          **hw_cfg["mixin_search_space_cfg"])
    expect(isinstance(ss, MixinProfilingSearchSpace),
           "search space must be a subclass of MixinProfilingSearchSpace")
    # validate before writing anything to disk; ``expect`` (unlike
    # ``assert``) is not stripped under ``python -O``
    expect("prof_prims_cfg" in hw_cfg,
           "key prof_prims_cfg must be specified in hardware configuration file.")
    prof_prim_cfg = hw_cfg["prof_prims_cfg"]

    result_dir = utils.makedir(result_dir)
    # copy cfg files
    shutil.copyfile(cfg_file, os.path.join(result_dir, "config.yaml"))
    shutil.copyfile(hwobj_cfg_file,
                    os.path.join(result_dir, "hwobj_config.yaml"))

    # generate profiling primitive list
    prof_prims = list(ss.generate_profiling_primitives(**prof_prim_cfg))
    prof_prim_fname = os.path.join(result_dir, "prof_prims.yaml")
    with open(prof_prim_fname, "w") as prof_prim_f:
        yaml.dump(prof_prims, prof_prim_f)
    LOGGER.info("Save the list of profiling primitives to %s", prof_prim_fname)

    if num_sample:
        prof_net_cfgs = sample_networks(
            ss,
            base_cfg_template=hw_cfg["profiling_net_cfg"]["base_cfg_template"],
            num_sample=num_sample,
            **prof_prim_cfg)
    else:
        # the primitives can be mapped to layers in the model while assembling
        prof_net_cfgs = assemble_profiling_nets(prof_prims,
                                                **hw_cfg["profiling_net_cfg"])
    # materialize: the nets are both dumped here and compiled below
    prof_net_cfgs = list(prof_net_cfgs)
    prof_net_dir = utils.makedir(os.path.join(result_dir, "prof_nets"),
                                 remove=True)
    for i_net, prof_net_cfg in enumerate(prof_net_cfgs):
        prof_fname = os.path.join(prof_net_dir, "{}.yaml".format(i_net))
        with open(prof_fname, "w") as prof_net_f:
            yaml.dump(prof_net_cfg, prof_net_f)
    LOGGER.info("Save the profiling net configs to directory %s", prof_net_dir)

    # optional (hardware specific): call hardware-specific compiling process.
    # Copy the list so extending it never mutates the loaded config; ``or []``
    # also tolerates an empty ``hardware_compilers:`` key (loaded as None).
    compiler_cfgs = list(hw_cfg.get("hardware_compilers") or [])
    compiler_cfgs.extend({
        "hardware_compiler_type": hw_name,
        "hardware_compiler_cfg": {}
    } for hw_name in (compile_hardware or ()))
    if compiler_cfgs:
        _genprof_run_compilers(compiler_cfgs, prof_net_cfgs, result_dir)


def _genprof_load_yaml(path):
    """Load a YAML config file with the safe loader.

    ``yaml.load`` without an explicit ``Loader`` raises ``TypeError`` on
    PyYAML >= 6 and can construct arbitrary Python objects on older versions.
    Config files are plain mappings, so the safe loader is sufficient.
    """
    with open(path, "r") as yaml_f:
        return yaml.safe_load(yaml_f)


def _genprof_run_compilers(compiler_cfgs, prof_net_cfgs, result_dir):
    """Compile every profiling net with every configured hardware compiler.

    Results go to ``<result_dir>/hardwares/<i_hw>-<name>/<i_net>/``.
    """
    hw_compile_dir = utils.makedir(os.path.join(result_dir, "hardwares"),
                                   remove=True)
    LOGGER.info("Call hardware compilers: total %d", len(compiler_cfgs))
    for i_hw, compiler_cfg in enumerate(compiler_cfgs):
        hw_name = compiler_cfg["hardware_compiler_type"]
        # ``or {}`` guards against an explicit ``hardware_compiler_cfg: null``
        hw_kwargs = compiler_cfg.get("hardware_compiler_cfg") or {}
        hw_compiler = BaseHardwareCompiler.get_class_(hw_name)(**hw_kwargs)
        LOGGER.info("%d: Constructed hardware compiler %s%s", i_hw, hw_name,
                    ":{}".format(hw_kwargs) if hw_kwargs else "")
        hw_res_dir = utils.makedir(
            os.path.join(hw_compile_dir, "{}-{}".format(i_hw, hw_name)))
        for i_net, prof_cfg in enumerate(prof_net_cfgs):
            res_dir = utils.makedir(os.path.join(hw_res_dir, str(i_net)))
            hw_compiler.compile("{}-{}-{}".format(i_hw, hw_name, i_net),
                                prof_cfg, res_dir)