def test_genprof(case):
    """Check that an ``ofa_mixin`` search space yields consistent profiling nets.

    What is verified:

    1. ``generate_profiling_primitives`` returns exactly
       ``len(width_choice) * len(kernel_choice) * len(num_cell_groups[1:]) * 2 + 2``
       primitives. Each one is a dict carrying every ``Prim`` field except
       ``kwargs``.
    2. Every net produced by ``assemble_profiling_nets`` chains correctly.
       Each layer's ``C_out`` equals the next layer's ``C``, and each layer's
       ``spatial_size / stride`` (rounded) equals the next layer's
       ``spatial_size``.
    3. Summed over all nets, the number of layers is at least the number of
       primitives.

    Args:
        case: mapping with ``search_space_cfg`` (kwargs for
            ``get_search_space``, must contain ``width_choice``,
            ``kernel_choice`` and ``num_cell_groups``) and ``prof_prim_cfg``
            (kwargs for ``generate_profiling_primitives``).
    """
    from aw_nas.common import get_search_space, genotype_from_str
    from aw_nas.hardware.base import MixinProfilingSearchSpace
    from aw_nas.hardware.utils import Prim, assemble_profiling_nets
    from aw_nas.rollout.general import GeneralSearchSpace

    search_space = get_search_space("ofa_mixin", **case["search_space_cfg"])
    assert isinstance(search_space, MixinProfilingSearchSpace)
    prims = search_space.generate_profiling_primitives(**case["prof_prim_cfg"])

    # Expected primitive count, derived from the search-space choices.
    ss_cfg = case["search_space_cfg"]
    n_widths = len(ss_cfg["width_choice"])
    n_kernels = len(ss_cfg["kernel_choice"])
    n_stages = len(ss_cfg["num_cell_groups"][1:])
    assert len(prims) == n_widths * n_kernels * n_stages * 2 + 2

    # Every Prim field except the free-form ``kwargs`` must be present.
    required_keys = set(Prim._fields) - {"kwargs"}
    for prim in prims:
        assert isinstance(prim, dict)
        assert required_keys.issubset(prim.keys())

    nets = assemble_profiling_nets(prims, {"final_model_cfg": {}}, image_size=224)
    general_ss = GeneralSearchSpace([])
    total_layers = 0
    for net in nets:
        layers = net["final_model_cfg"]["genotypes"]
        assert isinstance(layers, (list, str))
        if isinstance(layers, str):
            layers = genotype_from_str(layers, general_ss)
        total_layers += len(layers)

        # Output channels of one layer feed the input channels of the next.
        channels_match = [prev["C_out"] == nxt["C"]
                          for prev, nxt in zip(layers[:-1], layers[1:])]
        assert all(channels_match)

        # The spatial size shrinks by each layer's stride.
        sizes = [layer["spatial_size"] for layer in layers]
        strides = [layer["stride"] for layer in layers]
        sizes_match = [round(cur / step) == nxt
                       for step, cur, nxt in zip(strides, sizes[:-1], sizes[1:])]
        assert all(sizes_match)

    assert total_layers >= len(prims)
def genprof(cfg_file, hwobj_cfg_file, result_dir, compile_hardware, num_sample):
    """Generate profiling primitives/networks for a search space, optionally compiling them.

    Files written under ``result_dir``:

    * ``config.yaml`` / ``hwobj_config.yaml``: copies of the two input files;
    * ``prof_prims.yaml``: the list returned by ``generate_profiling_primitives``;
    * ``prof_nets/<i_net>.yaml``: one config per profiling network;
    * ``hardwares/<i_hw>-<type>/<i_net>/``: per-compiler output, only when at
      least one hardware compiler is configured or requested.

    Args:
        cfg_file: path to the search-space YAML. Its ``search_space_cfg``
            section is forwarded to ``get_search_space``.
        hwobj_cfg_file: path to the hardware-object YAML. Sections read are
            ``mixin_search_space_type``, ``mixin_search_space_cfg``,
            ``prof_prims_cfg``, ``profiling_net_cfg`` and, optionally,
            ``hardware_compilers``.
        result_dir: output directory.
        compile_hardware: iterable of extra hardware-compiler type names to
            run, each with an empty compiler config. May be empty or None.
        num_sample: if truthy, sample this many networks with
            ``sample_networks`` instead of assembling nets from the primitives.
    """
    ss_cfg = _load_yaml_file(cfg_file)
    hw_cfg = _load_yaml_file(hwobj_cfg_file)

    # Validate before touching the filesystem, so a bad config does not leave
    # a half-populated result directory behind (and is not skipped under -O).
    expect("prof_prims_cfg" in hw_cfg,
           "key prof_prims_cfg must be specified in hardware configuration file.")

    ss = get_search_space(hw_cfg["mixin_search_space_type"],
                          **ss_cfg["search_space_cfg"],
                          **hw_cfg["mixin_search_space_cfg"])
    expect(isinstance(ss, MixinProfilingSearchSpace),
           "search space must be a subclass of MixinProfilingSearchSpace")

    result_dir = utils.makedir(result_dir)
    # Keep copies of the inputs next to the results for reproducibility.
    shutil.copyfile(cfg_file, os.path.join(result_dir, "config.yaml"))
    shutil.copyfile(hwobj_cfg_file, os.path.join(result_dir, "hwobj_config.yaml"))

    # Generate and save the profiling primitive list.
    prof_prims_cfg = hw_cfg["prof_prims_cfg"]
    prof_prims = list(ss.generate_profiling_primitives(**prof_prims_cfg))
    prof_prim_fname = os.path.join(result_dir, "prof_prims.yaml")
    with open(prof_prim_fname, "w") as prof_prim_f:
        yaml.dump(prof_prims, prof_prim_f)
    LOGGER.info("Save the list of profiling primitives to %s", prof_prim_fname)

    if num_sample:
        prof_net_cfgs = sample_networks(
            ss,
            base_cfg_template=hw_cfg["profiling_net_cfg"]["base_cfg_template"],
            num_sample=num_sample,
            **prof_prims_cfg)
    else:
        # Assemble profiling nets; primitives can be mapped to model layers
        # during the assembling process.
        prof_net_cfgs = assemble_profiling_nets(prof_prims, **hw_cfg["profiling_net_cfg"])
    # Materialize: the configs are iterated twice (dump below, compile after),
    # and the producers above may return a lazy iterable.
    prof_net_cfgs = list(prof_net_cfgs)

    prof_net_dir = utils.makedir(os.path.join(result_dir, "prof_nets"), remove=True)
    for i_net, prof_net_cfg in enumerate(prof_net_cfgs):
        prof_fname = os.path.join(prof_net_dir, "{}.yaml".format(i_net))
        with open(prof_fname, "w") as prof_net_f:
            yaml.dump(prof_net_cfg, prof_net_f)
    LOGGER.info("Save the profiling net configs to directory %s", prof_net_dir)

    # Optional (hardware specific): run hardware compilers on every profiling net.
    # Copy the configured list instead of extending the one inside ``hw_cfg``,
    # and tolerate an explicit ``hardware_compilers: null``.
    compiler_cfgs = list(hw_cfg.get("hardware_compilers") or [])
    compiler_cfgs.extend({"hardware_compiler_type": hw_name, "hardware_compiler_cfg": {}}
                         for hw_name in (compile_hardware or ()))
    if compiler_cfgs:
        _compile_profiling_nets(compiler_cfgs, prof_net_cfgs, result_dir)


def _load_yaml_file(path):
    """Parse the YAML file at ``path``; an empty file yields ``{}``."""
    # ``yaml.load`` without an explicit Loader raises TypeError on PyYAML >= 6
    # (and warns on 5.x). These are plain config files, so the safe loader,
    # which never constructs arbitrary Python objects, is sufficient.
    with open(path, "r") as yaml_f:
        return yaml.safe_load(yaml_f) or {}


def _compile_profiling_nets(compiler_cfgs, prof_net_cfgs, result_dir):
    """Run each configured hardware compiler on every profiling-net config.

    Args:
        compiler_cfgs: list of dicts with a ``hardware_compiler_type`` key and
            an optional ``hardware_compiler_cfg`` (constructor kwargs).
        prof_net_cfgs: list of profiling-net configs to compile.
        result_dir: root directory. Output goes to
            ``<result_dir>/hardwares/<i_hw>-<type>/<i_net>/``.
    """
    hw_compile_dir = utils.makedir(os.path.join(result_dir, "hardwares"), remove=True)
    LOGGER.info("Call hardware compilers: total %d", len(compiler_cfgs))
    for i_hw, compiler_cfg in enumerate(compiler_cfgs):
        hw_name = compiler_cfg["hardware_compiler_type"]
        hw_kwargs = compiler_cfg.get("hardware_compiler_cfg") or {}
        hw_compiler = BaseHardwareCompiler.get_class_(hw_name)(**hw_kwargs)
        LOGGER.info("%d: Constructed hardware compiler %s%s", i_hw, hw_name,
                    ":{}".format(hw_kwargs) if hw_kwargs else "")
        hw_res_dir = utils.makedir(
            os.path.join(hw_compile_dir, "{}-{}".format(i_hw, hw_name)))
        for i_net, prof_cfg in enumerate(prof_net_cfgs):
            res_dir = utils.makedir(os.path.join(hw_res_dir, str(i_net)))
            hw_compiler.compile("{}-{}-{}".format(i_hw, hw_name, i_net), prof_cfg, res_dir)