Example #1
0
    def init_from_dirs(cls, dirs, search_space=None, cfg_template_file=None):
        """
        Init population from directories.

        Args:
          dirs: [directory paths]
          search_space: SearchSpace; if None, it is built from the template config's
            `search_space_type` and `search_space_cfg` entries
          cfg_template_file: if not specified, default: "template.yaml" under `dirs[0]`
        Returns: Population

        There should be multiple meta-info (yaml) files named as `<number>.yaml` under each
        directory, each of them specifying the meta information for a model in the
        population, with `<number>` representing its index.
        Note there should not be duplicate indexes; if there are duplicate indexes,
        rename or soft-link the files.

        In each meta-info file, the possible meta information fields are:
        * genotype
        * train_config
        * checkpoint_path
        * (optional) confidence
        * perfs: a dict of performance name to performance value

        The template config (by default "template.yaml" under the first dir) will be used
        as the template training config for new candidate models. Yaml files whose file
        name contains "template.yaml", and `cfg_template_file` itself, are not parsed as
        meta-info files. Meta-info files are loaded in sorted path order within each
        directory, and directories in the order given.
        """
        assert dirs, "No dirs specified!"
        if cfg_template_file is None:
            cfg_template_file = os.path.join(dirs[0], "template.yaml")
        with open(cfg_template_file, "r") as cfg_f:
            cfg_template = ConfigTemplate(yaml.safe_load(cfg_f))
        _logger.getChild("population").info("Read the template config from %s",
                                            cfg_template_file)
        # Compared by absolute path so that a template given under another name that
        # lives inside one of `dirs` is not mistaken for a `<number>.yaml` meta file.
        template_abspath = os.path.abspath(cfg_template_file)
        model_records = collections.OrderedDict()
        if search_space is None:
            # assume can parse search space from config template
            from aw_nas.common import get_search_space
            search_space = get_search_space(cfg_template["search_space_type"],
                                            **cfg_template["search_space_cfg"])
        for dir_ in dirs:
            # glob order is filesystem-dependent; sort so the insertion order of
            # `model_records` (an OrderedDict) is reproducible.
            meta_files = sorted(glob.glob(os.path.join(dir_, "*.yaml")))
            for fname in meta_files:
                basename = os.path.basename(fname)
                # Check only the file name: a directory path that merely contains
                # "template.yaml" must not cause all its meta files to be skipped.
                if ("template.yaml" in basename
                        or os.path.abspath(fname) == template_abspath):
                    # do not parse the template config
                    continue
                index = int(basename.rsplit(".", 1)[0])
                expect(
                    index not in model_records,
                    "There are duplicate index: {}. rename or soft-link the files"
                    .format(index))
                model_records[index] = ModelRecord.init_from_file(
                    fname, search_space)
        _logger.getChild("population").info(
            "Parsed %d directories, total %d model records loaded.", len(dirs),
            len(model_records))
        return Population(search_space, model_records, cfg_template)
Example #2
0
#pylint: disable=unused-import
# Package entry point for final (re-trained) models and trainers; re-exports them.
from aw_nas.utils import logger as _logger

_LOGGER = _logger.getChild("final")

from .cnn_trainer import CNNFinalTrainer
from .cnn_model import CNNGenotypeModel

from .rnn_trainer import RNNFinalTrainer
from .rnn_model import RNNGenotypeModel

from .dense import DenseGenotypeModel

from .ofa_model import OFAGenotypeModel

try:
    from .ssd_model import SSDFinalModel, SSDHeadFinalModel
    from .det_trainer import DetectionFinalTrainer
except ImportError as e:
    # Detection support is optional (EXTRAS_REQUIRE `det`): a failed import only
    # logs a warning so the rest of the package stays importable.
    # `Logger.warn` is a deprecated alias; use `warning` with lazy %-formatting.
    _LOGGER.warning("Cannot import module detection: %s\n"
                    "Should install EXTRAS_REQUIRE `det`", e)

from .general_model import GeneralGenotypeModel
Example #3
0
from torch.backends import cudnn

from torchviz import make_dot

import aw_nas
from aw_nas import utils
from aw_nas.dataset import AVAIL_DATA_TYPES
from aw_nas import utils, BaseRollout
from aw_nas.common import rollout_from_genotype_str
from aw_nas.utils.common_utils import _OrderedCommandGroup
from aw_nas.utils.vis_utils import WrapWriter
from aw_nas.utils import RegistryMeta
from aw_nas.utils import logger as _logger
from aw_nas.utils.exception import expect

# Logger for this module: child "main" of aw_nas.utils.logger.
LOGGER = _logger.getChild("main")


def _init_components_from_cfg(
    cfg,
    device,
    evaluator_only=False,
    controller_only=False,
    from_controller=False,
    search_space=None,
    controller=None,
):
    """
    Initialize components using configuration.
    Order:
    `search_space`, `controller`, `dataset`, `weights_manager`, `objective`, `evaluator`, `trainer`
Example #4
0
# pylint: disable=invalid-name
import os
import re
import sys
import imp
import inspect
from collections import defaultdict

from aw_nas.utils.exception import PluginException
from aw_nas.utils.common_utils import get_awnas_dir
from aw_nas.utils import logger as _logger

# Logger for plugin handling: child "plugin" of aw_nas.utils.logger.
LOGGER = _logger.getChild("plugin")

# Module-level plugin bookkeeping; filled in outside this view.
# presumably: the loaded plugin classes -- TODO confirm
plugins = []
# Mapping whose missing keys default to an empty list (presumably plugin -> modules).
plugin_modules = defaultdict(list)
# presumably: import failures recorded while loading plugins -- TODO confirm
import_errors = {}
# Matches a single "/", "|" or "." character. NOTE(review): inside [...] the "|" is a
# literal pipe, not alternation; if only "/" and "." were intended this is "[/.]".
norm_pattern = re.compile(r"[/|.]")

class AwnasPlugin(object):
    NAME = None

    dataset_list = []
    controller_list = []
    evaluator_list = []
    weights_manager_list = []
    objective_list = []
    trainer_list = []

    @classmethod
    def validate(cls):
Example #5
0
from aw_nas.utils import logger as _logger

_LOGGER = _logger.getChild("hardware.compiler")

try:
    from aw_nas.hardware.compiler import dpu
except ImportError as e:
    # The dpu compiler is optional: a failed import only logs a warning.
    # `Logger.warn` is a deprecated alias of `warning`; logging already appends
    # a newline, so the message carries no trailing "\n".
    _LOGGER.warning("Cannot import hardware compiler for dpu: %s", e)
Example #6
0
    def _save_all(self):
        """Save controller and evaluator checkpoints under the training directory.

        Does nothing when `self.train_dir` is None.
        """
        if self.train_dir is None:
            return
        self.controller.save(self._save_path("controller"))
        # FIXME: does the evaluator really need to be saved, since it is not updated?
        self.evaluator.save(self._save_path("evaluator"))
        self.logger.info("Step %3d: Save all checkpoints to directory %s",
                         self.epoch, self._save_path())

    def save(self, path):
        """Currently a no-op; `path` is ignored.

        The trainer is assumed to hold no state that needs to be saved.
        """

    def load(self, path):
        """Currently a no-op counterpart of `save`; `path` is ignored."""
        return None

    @classmethod
    def get_default_config_str(cls):
        """Return the parent's default config string plus sample dispatcher configs."""
        base_str = super(AsyncTrainer, cls).get_default_config_str()
        # Sample configs for the possible dispatchers, rendered with the "#   "
        # prefix (presumably so they appear commented out in the output).
        dispatcher_str = utils.component_sample_config_str(
            "dispatcher", prefix="#   ") + "\n"
        return base_str + dispatcher_str


_LOGGER = _logger.getChild("ray_dispatcher")
try:
    from aw_nas.trainer.ray_dispatcher import RayDispatcher
except ImportError as e:
    # The ray-based dispatcher is optional: a failed import only logs a warning.
    # The message names the module actually imported above.
    _LOGGER.warning(
        "Cannot import module aw_nas.trainer.ray_dispatcher: %s\n"
        "Should install ray package first.", e)
Example #7
0
# -*- coding: utf-8 -*-
from itertools import product
from functools import reduce

import numpy as np

from aw_nas.hardware.base import BaseHardwareObjectiveModel, MixinProfilingSearchSpace
from aw_nas.hardware.utils import Prim

from aw_nas.utils import logger as _logger
from aw_nas.utils import make_divisible
from aw_nas.rollout.ofa import MNasNetOFASearchSpace

# Module logger: child "ofa_obj" of aw_nas.utils.logger.
logger = _logger.getChild("ofa_obj")


class OFAMixinProfilingSearchSpace(MNasNetOFASearchSpace,
                                   MixinProfilingSearchSpace):
    NAME = "ofa_mixin"

    def __init__(
        self,
        width_choice,
        depth_choice,
        kernel_choice,
        num_cell_groups,
        expansions,
        fixed_primitives=None,
        schedule_cfg=None,
    ):
        super(OFAMixinProfilingSearchSpace, self).__init__(
Example #8
0
 def logger(self):
     """Return this object's logger, creating and caching it on first use.

     The logger is `_logger.getChild(<class name>)`, stored in `self._logger`.
     """
     # logger should be a mixin class. but i'll leave it as it is...
     cached = self._logger
     if cached is None:
         cached = _logger.getChild(self.__class__.__name__)
         self._logger = cached
     return cached
Example #9
0
# pylint: disable=invalid-name
import os
import pickle

import numpy as np

from aw_nas.objective.detection_utils.base import Metrics
from aw_nas.utils import logger as _logger

# Module logger: child "det.metrics" of aw_nas.utils.logger.
_LOGGER = _logger.getChild("det.metrics")

try:
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
except ImportError as e:
    # pycocotools is optional (EXTRAS_REQUIRE `det`); `Logger.warn` is a deprecated
    # alias of `warning`, and %-args defer formatting to the logging framework.
    _LOGGER.warning("Cannot import pycocotools: %s\n"
                    "Should install EXTRAS_REQUIRE `det`", e)

    class COCODetectionMetrics(Metrics):
        """Placeholder defined when pycocotools cannot be imported.

        Instantiating it logs an error and raises, so any attempt to use the
        "coco" metrics fails loudly.
        """
        NAME = "coco"

        def __new__(cls, *args, **kwargs):
            _LOGGER.error(
                "COCODetectionMetrics cannot be used. Install required dependencies!"
            )
            # Same exception type as before, now with a message explaining why.
            raise Exception(
                "COCODetectionMetrics requires pycocotools "
                "(install EXTRAS_REQUIRE `det`)")

        def __call__(self, boxes):
            # Unreachable (construction always raises); presumably kept to satisfy
            # the Metrics interface.
            pass

else:
Example #10
0
import yaml
import click

import aw_nas
from aw_nas import utils
from aw_nas.common import get_search_space
from aw_nas.utils import logger as _logger
from aw_nas.utils.exception import expect
from aw_nas.utils.common_utils import _OrderedCommandGroup
from aw_nas.hardware.utils import assemble_profiling_nets, iterate, sample_networks
from aw_nas.hardware.base import BaseHardwareCompiler, MixinProfilingSearchSpace

# patch click.option to show the default values
# NOTE(review): this rebinds the attribute on the shared `click` module, so any
# code that looks up `click.option` after this line (not only this CLI) gets
# show_default=True. `functools` is not among the imports visible here; confirm it
# is imported earlier in the file.
click.option = functools.partial(click.option, show_default=True)

# Logger for this CLI: child "main_hw" of aw_nas.utils.logger.
LOGGER = _logger.getChild("main_hw")


@click.group(
    cls=_OrderedCommandGroup,
    help=("The awnas-hw command line interface. "
          "Use `AWNAS_LOG_LEVEL` environment variable to modify the log level."),
)
@click.version_option(version=aw_nas.__version__)
def main():
    # Root command group: it does nothing by itself; subcommands are attached via
    # `@main.command(...)`.
    return None


# ---- generate profiling networks ----
@main.command(
    help="Generate profiling networks for search space. "
    "hwobj_cfg_fil sections: profiling_primitive_cfg, profiling_net_cfg, "
Example #11
0
# -*- coding: utf-8 -*-
"""A simple registry meta class.
"""

import abc
import collections

from aw_nas.utils import logger as _logger

# Names exported by `from ... import *`.
__all__ = ["RegistryMeta", "RegistryError"]

# Module logger: child "registry" of aw_nas.utils.logger.
LOGGER = _logger.getChild("registry")


class RegistryError(Exception):
    """Exception type for registry-related errors."""


def _default_dct_of_list():
    """Return a fresh `defaultdict(list)`.

    Presumably a named module-level factory (rather than a lambda) so that the
    nested defaultdicts built from it stay picklable.
    """
    inner_factory = list
    return collections.defaultdict(inner_factory)


class RegistryMeta(abc.ABCMeta):
    registry_dct = collections.defaultdict(dict)
    supported_rollout_dct = collections.defaultdict(_default_dct_of_list)

    def __init__(cls, name, bases, namespace):
        super(RegistryMeta, cls).__init__(name, bases, namespace)
        ## DEPRECATED: the interface of every class is defined explicitly in the arguments
        ##             instead of a cover-all config dict,
        ##             as failing loudly can avoid subtle bugs (e.g. mistyping)
Example #12
0
"""
Built-in tight coupled NAS flows.
"""

from aw_nas.utils import logger as _logger
_LOGGER = _logger.getChild("btc")

# Each NAS-Bench flow depends on an external benchmark package that may not be
# installed; a failed import only logs a warning (`Logger.warn` is a deprecated
# alias of `warning`) so the other flows remain usable.
try:
    from aw_nas.btcs import nasbench_101
except ImportError as e:
    _LOGGER.warning("Cannot import module nasbench: %s\n"
                    "Should install the NASBench 101 package following "
                    "https://github.com/google-research/nasbench", e)

try:
    from aw_nas.btcs import nasbench_201
except ImportError as e:
    _LOGGER.warning("Cannot import module nasbench_201: %s\n"
                    "Should install the NASBench 201 package following "
                    "https://github.com/D-X-Y/NAS-Bench-201", e)
Example #13
0
import copy
import inspect
from inspect import signature
import os
import pickle
from collections import namedtuple

import numpy as np
import yaml

try:
    from sklearn import linear_model
except ImportError as e:
    # scikit-learn is optional: a failed import only logs a warning
    # (`Logger.warn` is a deprecated alias of `warning`).
    # NOTE(review): `linear_model` stays undefined in this case, so code that uses
    # it will raise NameError at call time.
    from aw_nas.utils import logger as _logger
    _logger.getChild("hardware").warning(
        "Cannot import module hardware.utils: %s\n"
        "Should install scikit-learn to make some hardware-related"
        " functionalities work", e)

from aw_nas.hardware.base import (BaseHardwareObjectiveModel,
                                  MixinProfilingSearchSpace, Preprocessor)
from aw_nas.ops import get_op

# Base namedtuple holding the fields of `Prim`; its typename is "Prim".
Prim_ = namedtuple(
    "Prim",
    "prim_type spatial_size C C_out stride affine kwargs",
)


class Prim(Prim_):
    def __new__(cls, prim_type, spatial_size, C, C_out, stride, affine,
                **kwargs):
Example #14
0
 def logger(self):
     """Return the cached logger, creating `_logger.getChild(<class name>)` on first use."""
     if self._logger is not None:
         return self._logger
     self._logger = _logger.getChild(self.__class__.__name__)
     return self._logger
Example #15
0
 def __new__(cls, *args, **kwargs):
     """Refuse to instantiate: the RandomForest arch network is unavailable.

     Logs the error stored in `imp_exception` (defined outside this view;
     presumably the ImportError raised when importing sklearn) and raises.
     """
     from aw_nas.utils import logger as _logger
     _logger.getChild("arch_network").error(
         "RandomForest arch network cannot be used: Cannot import module sklearn: %s",
         imp_exception)
     # Keep the exception type callers may already catch, but include the reason.
     raise Exception(
         "RandomForest arch network cannot be used: "
         "Cannot import module sklearn: {}".format(imp_exception))