def init_from_dirs(cls, dirs, search_space=None, cfg_template_file=None):
    """
    Init population from directories.

    Args:
        dirs: [directory paths]
        search_space: SearchSpace. If None, it is built from the
            `search_space_type` / `search_space_cfg` fields of the config template.
        cfg_template_file: if not specified, default: "template.yaml" under `dirs[0]`

    Returns:
        Population

    Raises:
        ValueError: if a meta-info file is not named as `<number>.yaml`.

    There should be multiple meta-info (yaml) files named as `<number>.yaml` under
    each directory. Each of them specifies the meta information for one model in the
    population, with `<number>` representing its index.
    Note there should not be duplicate indices; if there are, rename or soft-link
    the files.

    In each meta-info file, the possible meta information fields are:
    * genotype
    * train_config
    * checkpoint_path
    * (optional) confidence
    * perfs: a dict of performance name to performance value

    "template.yaml" under the first dir will be used as the template training config
    for new candidate models. Files whose base name contains "template.yaml", and the
    template file actually used, are not parsed as meta-info files.
    """
    assert dirs, "No dirs specified!"
    pop_logger = _logger.getChild("population")

    if cfg_template_file is None:
        cfg_template_file = os.path.join(dirs[0], "template.yaml")
    with open(cfg_template_file, "r") as cfg_f:
        cfg_template = ConfigTemplate(yaml.safe_load(cfg_f))
    pop_logger.info("Read the template config from %s", cfg_template_file)
    # Used to skip the template if it lives inside one of `dirs` under another name
    template_path = os.path.abspath(cfg_template_file)

    if search_space is None:
        # assume can parse search space from config template
        from aw_nas.common import get_search_space
        search_space = get_search_space(cfg_template["search_space_type"],
                                        **cfg_template["search_space_cfg"])

    model_records = collections.OrderedDict()
    for dir_ in dirs:
        # glob order is arbitrary; sort so the record insertion order is reproducible
        for fname in sorted(glob.glob(os.path.join(dir_, "*.yaml"))):
            basename = os.path.basename(fname)
            # Check only the file name: a directory path that happens to contain
            # "template.yaml" must not cause every file in it to be skipped
            if "template.yaml" in basename or os.path.abspath(fname) == template_path:
                # do not parse the template config as a meta-info file
                continue
            stem = basename.rsplit(".", 1)[0]
            try:
                index = int(stem)
            except ValueError:
                raise ValueError(
                    "Meta-info file {} is not named as `<number>.yaml`".format(fname))
            expect(index not in model_records,
                   "There are duplicate index: {}. rename or soft-link the files"
                   .format(index))
            model_records[index] = ModelRecord.init_from_file(fname, search_space)

    pop_logger.info(
        "Parsed %d directories, total %d model records loaded.",
        len(dirs), len(model_records))
    return Population(search_space, model_records, cfg_template)
# pylint: disable=unused-import
# The imports below expose the final-training models/trainers from this package
# (hence unused-import is disabled); the detection ones are optional.
from aw_nas.utils import logger as _logger

_LOGGER = _logger.getChild("final")

from .cnn_trainer import CNNFinalTrainer
from .cnn_model import CNNGenotypeModel
from .rnn_trainer import RNNFinalTrainer
from .rnn_model import RNNGenotypeModel
from .dense import DenseGenotypeModel
from .ofa_model import OFAGenotypeModel

try:
    from .ssd_model import SSDFinalModel, SSDHeadFinalModel
    from .det_trainer import DetectionFinalTrainer
except ImportError as e:
    # Missing detection extras only produce a warning; the other imports still work.
    # `Logger.warn` is a deprecated alias, so use `warning` with lazy %-args.
    _LOGGER.warning(("Cannot import module detection: %s\n"
                     "Should install EXTRAS_REQUIRE `det`"), e)

from .general_model import GeneralGenotypeModel
from torch.backends import cudnn from torchviz import make_dot import aw_nas from aw_nas import utils from aw_nas.dataset import AVAIL_DATA_TYPES from aw_nas import utils, BaseRollout from aw_nas.common import rollout_from_genotype_str from aw_nas.utils.common_utils import _OrderedCommandGroup from aw_nas.utils.vis_utils import WrapWriter from aw_nas.utils import RegistryMeta from aw_nas.utils import logger as _logger from aw_nas.utils.exception import expect LOGGER = _logger.getChild("main") def _init_components_from_cfg( cfg, device, evaluator_only=False, controller_only=False, from_controller=False, search_space=None, controller=None, ): """ Initialize components using configuration. Order: `search_space`, `controller`, `dataset`, `weights_manager`, `objective`, `evaluator`, `trainer`
# pylint: disable=invalid-name import os import re import sys import imp import inspect from collections import defaultdict from aw_nas.utils.exception import PluginException from aw_nas.utils.common_utils import get_awnas_dir from aw_nas.utils import logger as _logger LOGGER = _logger.getChild("plugin") plugins = [] plugin_modules = defaultdict(list) import_errors = {} norm_pattern = re.compile(r"[/|.]") class AwnasPlugin(object): NAME = None dataset_list = [] controller_list = [] evaluator_list = [] weights_manager_list = [] objective_list = [] trainer_list = [] @classmethod def validate(cls):
from aw_nas.utils import logger as _logger

_LOGGER = _logger.getChild("hardware.compiler")

try:
    from aw_nas.hardware.compiler import dpu
except ImportError as e:
    # The dpu compiler is optional: log and continue instead of failing the import.
    # `Logger.warn` is a deprecated alias, so use `warning` with lazy %-args.
    _LOGGER.warning("Cannot import hardware compiler for dpu: %s\n", e)
def _save_all(self):
    """Save controller and evaluator checkpoints; does nothing if `train_dir` is None."""
    if self.train_dir is not None:
        self.controller.save(self._save_path("controller"))
        # FIXME: do evaluator really need to be saved, since evaluator is not updated
        self.evaluator.save(self._save_path("evaluator"))
        self.logger.info("Step %3d: Save all checkpoints to directory %s",
                         self.epoch, self._save_path())

def save(self, path):
    """No-op: the trainer currently has no state of its own to save."""
    # No state of trainer need to be saved?
    pass

def load(self, path):
    """No-op counterpart of `save`."""
    pass

@classmethod
def get_default_config_str(cls):
    """Return the parent's default config string plus commented dispatcher samples."""
    all_str = super(AsyncTrainer, cls).get_default_config_str()
    # Possible dispatcher configs
    all_str += utils.component_sample_config_str("dispatcher", prefix="# ") + "\n"
    return all_str

_LOGGER = _logger.getChild("ray_dispatcher")

try:
    from aw_nas.trainer.ray_dispatcher import RayDispatcher
except ImportError as e:
    # The message now names the module that is actually imported above
    # (it previously said aw_nas.evaluator.ray_dispatcher). `Logger.warn` is a
    # deprecated alias, so use `warning` with lazy %-args.
    _LOGGER.warning(("Cannot import module aw_nas.trainer.ray_dispatcher: %s\n"
                     "Should install ray package first."), e)
# -*- coding: utf-8 -*- from itertools import product from functools import reduce import numpy as np from aw_nas.hardware.base import BaseHardwareObjectiveModel, MixinProfilingSearchSpace from aw_nas.hardware.utils import Prim from aw_nas.utils import logger as _logger from aw_nas.utils import make_divisible from aw_nas.rollout.ofa import MNasNetOFASearchSpace logger = _logger.getChild("ofa_obj") class OFAMixinProfilingSearchSpace(MNasNetOFASearchSpace, MixinProfilingSearchSpace): NAME = "ofa_mixin" def __init__( self, width_choice, depth_choice, kernel_choice, num_cell_groups, expansions, fixed_primitives=None, schedule_cfg=None, ): super(OFAMixinProfilingSearchSpace, self).__init__(
def logger(self):
    """Return the logger for this object, creating and caching it on first use.

    The logger is `_logger.getChild(<runtime class name>)`, stored on
    ``self._logger`` so subsequent calls return the same object.
    NOTE: this logic would better live in a mixin class.
    """
    if self._logger is not None:
        return self._logger
    self._logger = _logger.getChild(self.__class__.__name__)
    return self._logger
# pylint: disable=invalid-name import os import pickle import numpy as np from aw_nas.objective.detection_utils.base import Metrics from aw_nas.utils import logger as _logger _LOGGER = _logger.getChild("det.metrics") try: from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval except ImportError as e: _LOGGER.warn(("Cannot import pycocotools: {}\n" "Should install EXTRAS_REQUIRE `det`").format(e)) class COCODetectionMetrics(Metrics): NAME = "coco" def __new__(cls, *args, **kwargs): _LOGGER.error( "COCODetectionMetrics cannot be used. Install required dependencies!" ) raise Exception() def __call__(self, boxes): pass else:
import yaml import click import aw_nas from aw_nas import utils from aw_nas.common import get_search_space from aw_nas.utils import logger as _logger from aw_nas.utils.exception import expect from aw_nas.utils.common_utils import _OrderedCommandGroup from aw_nas.hardware.utils import assemble_profiling_nets, iterate, sample_networks from aw_nas.hardware.base import BaseHardwareCompiler, MixinProfilingSearchSpace # patch click.option to show the default values click.option = functools.partial(click.option, show_default=True) LOGGER = _logger.getChild("main_hw") @click.group( cls=_OrderedCommandGroup, help="The awnas-hw command line interface. " "Use `AWNAS_LOG_LEVEL` environment variable to modify the log level.") @click.version_option(version=aw_nas.__version__) def main(): pass # ---- generate profiling networks ---- @main.command( help="Generate profiling networks for search space. " "hwobj_cfg_fil sections: profiling_primitive_cfg, profiling_net_cfg, "
# -*- coding: utf-8 -*- """A simple registry meta class. """ import abc import collections from aw_nas.utils import logger as _logger __all__ = ["RegistryMeta", "RegistryError"] LOGGER = _logger.getChild("registry") class RegistryError(Exception): pass def _default_dct_of_list(): return collections.defaultdict(list) class RegistryMeta(abc.ABCMeta): registry_dct = collections.defaultdict(dict) supported_rollout_dct = collections.defaultdict(_default_dct_of_list) def __init__(cls, name, bases, namespace): super(RegistryMeta, cls).__init__(name, bases, namespace) ## DEPRECATED: the interface of every class is defined explicitly in the arguments ## instead of a cover-all config dict, ## as failing loudly can avoid subtle bugs (e.g. mistyping)
""" Built-in tight coupled NAS flows. """ from aw_nas.utils import logger as _logger _LOGGER = _logger.getChild("btc") try: from aw_nas.btcs import nasbench_101 except ImportError as e: _LOGGER.warn(("Cannot import module nasbench: {}\n" "Should install the NASBench 101 package following " "https://github.com/google-research/nasbench").format(e)) try: from aw_nas.btcs import nasbench_201 except ImportError as e: _LOGGER.warn(("Cannot import module nasbench_201: {}\n" "Should install the NASBench 201 package following " "https://github.com/D-X-Y/NAS-Bench-201").format(e))
import copy import inspect from inspect import signature import os import pickle from collections import namedtuple import numpy as np import yaml try: from sklearn import linear_model except ImportError as e: from aw_nas.utils import logger as _logger _logger.getChild("hardware").warn( ("Cannot import module hardware.utils: {}\n" "Should install scikit-learn to make some hardware-related" " functionalities work").format(e)) from aw_nas.hardware.base import (BaseHardwareObjectiveModel, MixinProfilingSearchSpace, Preprocessor) from aw_nas.ops import get_op Prim_ = namedtuple( "Prim", ["prim_type", "spatial_size", "C", "C_out", "stride", "affine", "kwargs"], ) class Prim(Prim_): def __new__(cls, prim_type, spatial_size, C, C_out, stride, affine, **kwargs):
def logger(self):
    """Lazily build this object's logger and cache it on ``self._logger``.

    The logger is a child of `_logger` named after the runtime class.
    """
    cached = self._logger
    if cached is None:
        cached = _logger.getChild(self.__class__.__name__)
        self._logger = cached
    return cached
def __new__(cls, *args, **kwargs):
    """Refuse instantiation because the RandomForest arch network cannot be used.

    Logs an error naming the import failure held in `imp_exception`, then raises.
    NOTE(review): `imp_exception` is defined outside this view; presumably it is
    the error captured when importing sklearn failed - confirm.

    Raises:
        Exception: always; it now carries the same message that is logged, so
            callers that catch it can see why instantiation failed.
    """
    # Local import kept from the original.
    from aw_nas.utils import logger as _logger
    msg = ("RandomForest arch network cannot be used: "
           "Cannot import module sklearn: {}".format(imp_exception))
    _logger.getChild("arch_network").error(msg)
    raise Exception(msg)