コード例 #1
0
    def init_from_dirs(cls, dirs, search_space=None, cfg_template_file=None):
        """
        Init population from directories.

        Args:
          dirs: [directory paths]
          search_space: SearchSpace. If not specified, it is constructed from the
            `search_space_type` / `search_space_cfg` entries of the template config.
          cfg_template_file: if not specified, default: "template.yaml" under `dirs[0]`
        Returns: Population

        There should be multiple meta-info (yaml) files named as `<number>.yaml` under each
        directory, each of them specifying the meta information for a model in the population,
        with `<number>` representing its index.
        Note there should not be duplicate indexes; if there are duplicate indexes,
        rename or soft-link the files.

        In each meta-info file, the possible meta informations are:
        * genotype
        * train_config
        * checkpoint_path
        * (optional) confidence
        * perfs: a dict of performance name to performance value

        "template.yaml" under the first dir will be used as the template training config for
        new candidate model. Any yaml file whose *file name* contains "template.yaml" is
        skipped when collecting meta-info files.

        Meta-info files are loaded directory by directory (in the order of `dirs`), and in
        ascending index order within each directory, so the record order is deterministic.
        """
        assert dirs, "No dirs specified!"
        logger = getLogger("population")
        if cfg_template_file is None:
            cfg_template_file = os.path.join(dirs[0], "template.yaml")
        with open(cfg_template_file, "r") as cfg_f:
            cfg_template = ConfigTemplate(yaml.safe_load(cfg_f))
        logger.info("Read the template config from %s", cfg_template_file)
        if search_space is None:
            # assume can parse search space from config template
            from aw_nas.common import get_search_space
            search_space = get_search_space(cfg_template["search_space_type"],
                                            **cfg_template["search_space_cfg"])

        model_records = collections.OrderedDict()
        for dir_ in dirs:
            indexed_files = []
            for fname in glob.glob(os.path.join(dir_, "*.yaml")):
                basename = os.path.basename(fname)
                # Only inspect the file name: a directory whose path happens to contain
                # "template.yaml" must not cause its meta files to be skipped.
                if "template.yaml" in basename:
                    # do not parse template.yaml
                    continue
                indexed_files.append((int(basename.rsplit(".", 1)[0]), fname))
            # `glob` returns files in arbitrary order; sort by index for a stable order.
            for index, fname in sorted(indexed_files):
                expect(
                    index not in model_records,
                    "There are duplicate index: {}. rename or soft-link the files"
                    .format(index))
                model_records[index] = ModelRecord.init_from_file(
                    fname, search_space)
        logger.info(
            "Parsed %d directories, total %d model records loaded.", len(dirs),
            len(model_records))
        return Population(search_space, model_records, cfg_template)
コード例 #2
0
from aw_nas import AwnasPlugin, utils
from aw_nas.objective.base import BaseObjective
from aw_nas.weights_manager.super_net import SuperNet, SubCandidateNet
from aw_nas.weights_manager.diff_super_net import DiffSuperNet, DiffSubCandidateNet
from aw_nas.utils.torch_utils import accuracy, _to_device
from aw_nas.utils.exception import expect, ConfigException
from aw_nas.common import get_search_space
from aw_nas.final.base import FinalModel
from aw_nas.utils import DistributedDataParallel
from aw_nas.objective.image import CrossEntropyLabelSmooth

try:
    import foolbox as fb
except ImportError:
    utils.getLogger("robustness plugin").warn(
        "Cannot import foolbox. You should install FOOLBOX toolbox (version 2.4.0) for running distance attacks!"
    )


# ---- different types of Adversaries ----
class PgdAdvGenerator(object):
    def __init__(self,
                 epsilon,
                 n_step,
                 step_size,
                 rand_init,
                 mean,
                 std,
                 use_eval_mode=False):
        self.epsilon = epsilon
        self.n_step = n_step
コード例 #3
0
import yaml
import click

import aw_nas
from aw_nas import utils
from aw_nas.common import get_search_space
from aw_nas.utils.exception import expect
from aw_nas.utils.common_utils import _OrderedCommandGroup
from aw_nas.hardware.utils import assemble_profiling_nets, iterate, sample_networks
from aw_nas.hardware.base import BaseHardwareCompiler, MixinProfilingSearchSpace

# patch click.option to show the default values:
# every `click.option(...)` looked up after this line is created with
# `show_default=True`, so `--help` prints each option's default value.
# NOTE: this rebinds the attribute on the shared `click` module, so it also
# affects any other code that looks up `click.option` after this module is imported.
click.option = functools.partial(click.option, show_default=True)

# Module-level logger for the awnas-hw command line interface.
LOGGER = utils.getLogger("main_hw")


@click.group(
    cls=_OrderedCommandGroup,
    help=("The awnas-hw command line interface. "
          "Use `AWNAS_LOG_LEVEL` environment variable to modify the log level."),
)
@click.version_option(version=aw_nas.__version__)
def main():
    """Root ``awnas-hw`` command group; subcommands attach to it via ``@main.command``."""


# ---- generate profiling networks ----
@main.command(
    help="Generate profiling networks for search space. "
    "hwobj_cfg_fil sections: profiling_primitive_cfg, profiling_net_cfg, "
コード例 #4
0
    def eval_queue(self,
                   queue,
                   criterions,
                   steps=1,
                   mode="eval",
                   aggregate_fns=None,
                   **kwargs):
        """
        Evaluate on `steps` batches drawn from `queue`, optionally calibrating BN first.

        BN running statistics calibration:
          * if `self.calib_bn_num > 0`: draw batches until at least `calib_bn_num` samples
            (counted as `len(batch[1])`) are collected, but never more than `steps` batches;
          * elif `self.calib_bn_batch > 0`: draw `min(calib_bn_batch, steps)` batches;
          * otherwise no calibration is done.
        The calibration batches are reused as the first evaluation batches, so at most
        `steps` batches are consumed from `queue` in total.

        Args:
          queue: an iterator of batches; `batch[0]` is fed to `self.forward_data` and
            `batch[1]` is passed to the criterions as the targets.
          criterions: callables `c(inputs, outputs, targets)`; their results for one batch
            are flattened (`utils.flatten_list`) into that batch's answer list.
          steps: number of batches to evaluate.
          mode: not used in this method body; the model is always switched to "eval" mode
            after calibration.
          aggregate_fns: one callable per flattened criterion output, aggregating its
            per-batch values. Defaults to the mean (0. when there are no values).
          **kwargs: forwarded to `self.forward_data`.

        Returns:
          A list with one aggregated value per flattened criterion output.
        """
        # BN running statistics calibration
        if self.calib_bn_num > 0:
            # check `calib_bn_num` first
            calib_num = 0
            calib_data = []
            while calib_num < self.calib_bn_num:
                if len(calib_data) == steps:
                    utils.getLogger(
                        "robustness plugin.{}".format(self.__class__.__name__)).warning(
                            "steps (%d) reached, true calib bn num (%d)", steps, calib_num)
                    break
                calib_data.append(next(queue))
                calib_num += len(calib_data[-1][1])
            self.calib_bn(calib_data)
        elif self.calib_bn_batch > 0:
            # check `calib_bn_batch` then; never draw more than `steps` batches
            calib_bn_batch = self.calib_bn_batch
            if calib_bn_batch > steps:
                utils.getLogger(
                    "robustness plugin.{}".format(self.__class__.__name__)).warning(
                        "eval steps (%d) < `calib_bn_batch` (%d). Only use %d batches.",
                        steps, calib_bn_batch, steps)
                calib_bn_batch = steps
            calib_data = [next(queue) for _ in range(calib_bn_batch)]
            self.calib_bn(calib_data)
        else:
            calib_data = []

        self._set_mode("eval")  # Use eval mode after BN calibration

        aggr_ans = []
        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for i in range(steps):
                # Reuse the calibration batches first so they are not drawn from `queue` twice.
                data = calib_data[i] if i < len(calib_data) else next(queue)
                data = _to_device(data, self.get_device())
                outputs = self.forward_data(data[0], **kwargs)
                ans = utils.flatten_list(
                    [c(data[0], outputs, data[1]) for c in criterions])
                aggr_ans.append(ans)
                del outputs
                print("\reva step {}/{} ".format(i, steps), end="", flush=True)

        # rows: one per flattened criterion output; columns: one per batch
        aggr_ans = np.asarray(aggr_ans).transpose()

        if aggregate_fns is None:
            # by default, aggregate batch rewards with MEAN
            aggregate_fns = [lambda perfs: np.mean(perfs) if len(perfs) > 0 else 0.]\
                            * len(aggr_ans)
        return [aggr_fn(ans) for aggr_fn, ans in zip(aggregate_fns, aggr_ans)]
コード例 #5
0
"""
Built-in tight coupled NAS flows.
"""

from aw_nas.utils import getLogger
_LOGGER = getLogger("btc")

try:
    from aw_nas.btcs import nasbench_101
except ImportError as e:
    _LOGGER.warn(("Cannot import module nasbench: {}\n"
                  "Should install the NASBench 101 package following "
                  "https://github.com/google-research/nasbench").format(e))

try:
    from aw_nas.btcs import nasbench_201
except ImportError as e:
    _LOGGER.warn(("Cannot import module nasbench_201: {}\n"
                  "Should install the NASBench 201 package following "
                  "https://github.com/D-X-Y/NAS-Bench-201").format(e))

try:
    from aw_nas.btcs import nasbench_301
except ImportError as e:
    _LOGGER.warn((
        "Cannot import module nasbench_301: {}\n"
        "Should install the NASBench 301 package following "
        "https://github.com/automl/nasbench301.\n"
        "There still exist some bugs in commit 48a5f0ca152b83ae2fa31365116c0fb480466fb1, "
        "by the time of 2020/12/29, if these bugs are not fixed, can temporarily install this: "
        "pip install git+https://github.com/walkerning/nasbench301.git"
コード例 #6
0
from torch import nn

from aw_nas import utils
from aw_nas.common import assert_rollout_type
from aw_nas.utils import data_parallel
from aw_nas.utils.torch_utils import _to_device
from aw_nas.utils.common_utils import make_divisible, nullcontext
from aw_nas.utils import DistributedDataParallel
from aw_nas.weights_manager.base import BaseWeightsManager, CandidateNet
from aw_nas.weights_manager.detection_header import DetectionHeader

try:
    from torch.nn import SyncBatchNorm
    convert_sync_bn = SyncBatchNorm.convert_sync_batchnorm
except (ImportError, AttributeError):
    # ImportError: this torch has no SyncBatchNorm; AttributeError is also caught so a
    # torch build exposing SyncBatchNorm without `convert_sync_batchnorm` still falls back.
    utils.getLogger("weights_manager.detection").warning(
        "Import convert_sync_bn failed! SyncBatchNorm might not work!")

    def convert_sync_bn(module, process_group=None):  # pylint: disable=unused-argument
        """No-op fallback: return `module` unchanged (no SyncBatchNorm conversion)."""
        return module

__all__ = ["DetectionBackboneSupernet"]


class DetectionBackboneSupernet(BaseWeightsManager, nn.Module):
    NAME = "det_supernet"

    def __init__(
            self,
            search_space,
            device,
            rollout_type,
            feature_levels=[3, 4, 5],
            search_backbone_type="ofa_supernet",
コード例 #7
0
ファイル: metrics.py プロジェクト: zzzDavid/aw_nas
# pylint: disable=invalid-name
import os
import pickle

import numpy as np

from aw_nas.objective.detection_utils.base import Metrics
from aw_nas.utils import getLogger

_LOGGER = getLogger("det.metrics")

try:
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
except ImportError as e:
    _LOGGER.warn(("Cannot import pycocotools: {}\n"
                  "Should install EXTRAS_REQUIRE `det`").format(e))

    class COCODetectionMetrics(Metrics):
        """Placeholder defined because pycocotools could not be imported.

        Any attempt to construct it logs an error and raises ImportError, so using the
        "coco" metrics fails loudly instead of running without pycocotools.
        """
        NAME = "coco"

        def __new__(cls, *args, **kwargs):
            _LOGGER.error(
                "COCODetectionMetrics cannot be used. Install required dependencies!"
            )
            # ImportError (a subclass of Exception) keeps `except Exception` callers working
            # while telling them exactly what is missing.
            raise ImportError(
                "COCODetectionMetrics requires pycocotools; "
                "install EXTRAS_REQUIRE `det`.")

        def __call__(self, boxes):
            pass

else:
コード例 #8
0
from aw_nas.utils import getLogger

_LOGGER = getLogger("hardware.compiler")

try:
    from aw_nas.hardware.compiler import dpu
except ImportError as e:
    _LOGGER.warn("Cannot import hardware compiler for dpu: {}\n".format(e))

try:
    from aw_nas.hardware.compiler import xavier
except ImportError as e:
    _LOGGER.warn("Cannot import hardware compiler for xavier: {}\n".format(e))


コード例 #9
0
 def __new__(cls, *args, **kwargs):
     """Refuse construction because sklearn could not be imported.

     Logs the original import failure (`imp_exception`, bound outside this method)
     and raises an ImportError describing it.
     """
     from aw_nas.utils import getLogger
     getLogger("arch_network").error(
         "RandomForest arch network cannot be used: Cannot import module sklearn: %s",
         imp_exception)
     # ImportError is a subclass of Exception, so existing `except Exception` callers still work.
     raise ImportError(
         "RandomForest arch network cannot be used: Cannot import module sklearn: {}"
         .format(imp_exception))
コード例 #10
0
# -*- coding: utf-8 -*-
from itertools import product
from functools import reduce

import numpy as np

from aw_nas.hardware.base import BaseHardwarePerformanceModel, MixinProfilingSearchSpace
from aw_nas.hardware.utils import Prim
from aw_nas.utils import getLogger
from aw_nas.utils import make_divisible
from aw_nas.rollout.ofa import MNasNetOFASearchSpace, SSDOFASearchSpace

logger = getLogger("ofa_obj")


class OFAMixinProfilingSearchSpace(MNasNetOFASearchSpace,
                                   MixinProfilingSearchSpace):
    NAME = "ofa_mixin"

    def __init__(
        self,
        width_choice,
        depth_choice,
        kernel_choice,
        image_size_choice,
        num_cell_groups,
        expansions,
        fixed_primitives=None,
        schedule_cfg=None,
    ):
        super(OFAMixinProfilingSearchSpace, self).__init__(
コード例 #11
0
 def logger(self):
     """Return this object's logger, creating (and caching) it on first access.

     The logger is named after the concrete class and stored in `self._logger`.
     """
     cached = self._logger
     if cached is None:
         cached = self._logger = getLogger(self.__class__.__name__)
     return cached
コード例 #12
0
import random
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch import optim

import numpy as np
import yaml

try:
    from sklearn import linear_model
    from sklearn.neural_network import MLPRegressor
except ImportError as e:
    from aw_nas.utils import getLogger
    getLogger("hardware").warn(
        ("Cannot import module hardware.utils: {}\n"
         "Should install scikit-learn to make some hardware-related"
         " functionalities work").format(e))

from aw_nas.hardware.base import (BaseHardwarePerformanceModel,
                                  MixinProfilingSearchSpace, Preprocessor)
from aw_nas.ops import get_op

# Immutable record backing the `Prim` class; its typename matches the subclass name.
Prim_ = namedtuple(
    "Prim",
    "prim_type spatial_size C C_out stride affine kwargs",
)


class Prim(Prim_):
    def __new__(cls, prim_type, spatial_size, C, C_out, stride, affine,
                **kwargs):
コード例 #13
0
from torch import nn
from torch.utils.data.distributed import DistributedSampler

from aw_nas import utils
from aw_nas.final.base import FinalTrainer
from aw_nas.utils.common_utils import nullcontext
from aw_nas.utils.exception import expect
from aw_nas.utils import DataParallel
from aw_nas.utils import DistributedDataParallel
from aw_nas.utils.torch_utils import calib_bn

try:
    from torch.nn import SyncBatchNorm
    convert_sync_bn = SyncBatchNorm.convert_sync_batchnorm
except (ImportError, AttributeError):
    # ImportError: this torch has no SyncBatchNorm; AttributeError is also caught so a
    # torch build exposing SyncBatchNorm without `convert_sync_batchnorm` still falls back.
    utils.getLogger("cnn_trainer").warning(
        "Import convert_sync_bn failed! SyncBatchNorm might not work!")

    def convert_sync_bn(module, process_group=None):  # pylint: disable=unused-argument
        """No-op fallback: return `module` unchanged (no SyncBatchNorm conversion)."""
        return module


def _warmup_update_lr(optimizer, epoch, init_lr, warmup_epochs):
    """
    update learning rate of optimizers
    """
    lr = init_lr * epoch / warmup_epochs
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


class CNNFinalTrainer(FinalTrainer):  #pylint: disable=too-many-instance-attributes
    NAME = "cnn_trainer"
コード例 #14
0
#pylint: disable=unused-import

from aw_nas.utils import getLogger
_LOGGER = getLogger("dataset")

from aw_nas.dataset.base import BaseDataset
from aw_nas.dataset import cifar10
from aw_nas.dataset import ptb
from aw_nas.dataset import imagenet
from aw_nas.dataset import tiny_imagenet
from aw_nas.dataset import cifar100
from aw_nas.dataset import svhn

try:
    from aw_nas.dataset import voc
    from aw_nas.dataset import coco
except ImportError as e:
    _LOGGER.warn(
        ("Cannot import module detection: {}\n"
         "Should install EXTRAS_REQUIRE `det`").format(e))


# Recognized dataset data types: "image" and "sequence".
# NOTE(review): presumably a dataset's declared data type must be one of these — confirm in BaseDataset.
AVAIL_DATA_TYPES = ["image", "sequence"]
コード例 #15
0
#pylint: disable=unused-import

from aw_nas.utils import getLogger
_LOGGER = getLogger("final")

from .cnn_trainer import CNNFinalTrainer
from .cnn_model import CNNGenotypeModel

from .bnn_model import BNNGenotypeModel

from .rnn_trainer import RNNFinalTrainer
from .rnn_model import RNNGenotypeModel

from .dense import DenseGenotypeModel

from .ofa_model import OFAGenotypeModel

try:
    from .ssd_model import SSDFinalModel, SSDHeadFinalModel
    from .det_trainer import DetectionFinalTrainer
except ImportError as e:
    _LOGGER.warn(("Cannot import module detection: {}\n"
                  "Should install EXTRAS_REQUIRE `det`").format(e))

from .general_model import GeneralGenotypeModel
コード例 #16
0
# -*- coding: utf-8 -*-
"""A simple registry meta class.
"""

import abc
import collections

from aw_nas.utils import getLogger

__all__ = ["RegistryMeta", "RegistryError"]

LOGGER = getLogger("registry")


class RegistryError(Exception):
    """Registry-specific error type, exported via ``__all__`` alongside ``RegistryMeta``."""


def _default_dct_of_list():
    return collections.defaultdict(list)


class RegistryMeta(abc.ABCMeta):
    registry_dct = collections.defaultdict(dict)
    supported_rollout_dct = collections.defaultdict(_default_dct_of_list)

    def __init__(cls, name, bases, namespace):
        super(RegistryMeta, cls).__init__(name, bases, namespace)
        ## DEPRECATED: the interface of every class is defined explicitly in the arguments
        ##             instead of a cover-all config dict,
        ##             as failing loudly can avoid subtle bugs (e.g. mistyping)