Code example #1 — file: base.py (project: wangtaogithub/vega)
class CallbackBase():
    """Base callback class.

    Binds event handlers (from a handler configuration mapping) on
    construction and can unbind them again via unbind_handlers().
    """

    logger = get_logger('callback')
    # Default handler priority, used when a handler entry omits its own.
    priority = 0

    def __init__(self, handler_conf=None) -> None:
        # Maps event name -> bound handler; populated by bind_handlers.
        self.handlers = None
        self.bind_handlers(handler_conf)

    def bind_handlers(self, handler_conf):
        """Bind event handlers.

        handler_conf maps event name -> handler, or -> (handler, priority);
        a missing priority falls back to the class-level default.
        """
        handlers = {}
        # FIX: the default handler_conf=None used to crash on .items();
        # treat None as "no handlers to bind".
        for ev, conf in (handler_conf or {}).items():
            prio = None
            if isinstance(conf, (list, tuple)):
                h = conf[0]
                if len(conf) > 1:
                    prio = conf[1]
            else:
                h = conf
            event_on(ev, h, self.priority if prio is None else prio)
            handlers[ev] = h
        self.handlers = handlers

    def unbind_handlers(self):
        """Un-bind event handlers."""
        for ev, h in self.handlers.items():
            event_off(ev, h)
Code example #2 — file: registry.py (project: huawei-noah/vega)
class Registry():
    """Registry class.

    Stores component classes under normalized "path.id" keys; re-registering
    an existing id raises unless allow_replace was set.
    """

    logger = get_logger('registry')

    def __init__(self, allow_replace=False):
        # allow_replace: when True, a duplicate registration overwrites
        # (with a warning) instead of raising.
        self.allow_replace = allow_replace
        self._reg_class = {}

    def get_full_path(self, reg_path, reg_id):
        """Return full registration path."""
        return '{}.{}'.format(reg_path, reg_id)

    def get_reg_name(self, reg_path, reg_id):
        """Return proper registration name."""
        full = self.get_full_path(reg_path, reg_id)
        # Normalize: lower-case, then strip '-', '_' and ' ' in one pass.
        return full.lower().translate(str.maketrans('', '', '-_ '))

    def register(self, regclass, reg_path, reg_id):
        """Register a component class."""
        name = self.get_reg_name(reg_path, reg_id)
        if name in self._reg_class:
            self.logger.warning('re-register id: {}'.format(name))
            if not self.allow_replace:
                raise ValueError('Cannot re-register id: {}'.format(name))
        self._reg_class[name] = regclass
        self.logger.debug('registered: {}'.format(name))

    def get(self, reg_path, reg_id):
        """Return registered class by name."""
        name = self.get_reg_name(reg_path, reg_id)
        if name not in self._reg_class:
            raise ValueError("id '{}' not found in registry".format(name))
        return self._reg_class[name]
Code example #3
class TrainerBase():
    """Base Trainer class.

    Defines the trainer interface (init / epoch / step hooks and state
    (de)serialization); subclasses implement the abstract parts.
    """

    logger = get_logger('trainer')

    def __init__(self, writer=None):
        # Fall back to a no-op writer when none is supplied.
        self.writer = DummyWriter() if writer is None else writer

    def init(self, model, config=None):
        """Initialize trainer states."""
        raise NotImplementedError()

    def model_input(self, data):
        """Return model input as (args, kwargs)."""
        # Everything except the last element of `data` is passed positionally.
        return data[:-1], {}

    def model_output(self, *args, data=None, model=None, attr=None, **kwargs):
        """Return model output.

        When `attr` is given, that attribute of `model` is invoked instead
        of the model itself; when `data` is given, it overrides the passed
        args/kwargs via model_input().
        """
        fn = model if attr is None else getattr(model, attr)
        if data is not None:
            args, kwargs = self.model_input(data)
        return fn(*args, **kwargs)

    def loss(self, output=None, data=None, model=None):
        """Return loss (none in the base class)."""
        return None

    def train_epoch(self):
        """Train for one epoch."""
        raise NotImplementedError()

    def valid_epoch(self):
        """Validate for one epoch."""
        raise NotImplementedError()

    def train_step(self):
        """Train for one step."""
        raise NotImplementedError()

    def valid_step(self):
        """Validate for one step."""
        raise NotImplementedError()

    def state_dict(self):
        """Return current states."""
        return {}

    def load_state_dict(self, sd):
        """Resume states."""
        raise NotImplementedError()
Code example #4
class Registry():
    """Registry class.

    Thin adapter that delegates registration and lookup to ClassFactory,
    keyed by a registry type derived from the registration path.
    """

    logger = get_logger('registry')

    def __init__(self, allow_replace=False):
        # Kept for interface compatibility; replacement policy is handled
        # by ClassFactory in this variant.
        self.allow_replace = allow_replace

    def get_reg_name(self, name):
        """Return proper registration name."""
        # Lower-case, then strip '-', '_' and ' ' in one pass.
        return name.lower().translate(str.maketrans('', '', '-_ '))

    def register(self, regclass, reg_path, reg_id):
        """Register a component class."""
        alias = self.get_reg_name(reg_id)
        type_name = get_reg_type(reg_path)
        ClassFactory.register_cls(regclass, type_name=type_name, alias=alias)
        self.logger.debug('registered: {}'.format(alias))

    def get(self, reg_path, reg_id):
        """Return registered class by name."""
        alias = self.get_reg_name(reg_id)
        return ClassFactory.get_cls(get_reg_type(reg_path), alias)
Code example #5 — file: base.py (project: vineetrao25/vega)
class MetricsBase():
    """Base Metrics class.

    The "current Estimator" is tracked as a class-level attribute via the
    static set_estim/get_estim pair, so any metric can see which Estimator
    was active when it was created.
    """

    logger = get_logger('metrics')
    # Class-level slot holding the currently active Estimator.
    cur_estim = None

    def __init__(self):
        # Capture the Estimator that is current at construction time.
        self.estim = MetricsBase.get_estim()

    def __call__(self, *args, **kwargs):
        """Return metrics output."""
        raise NotImplementedError()

    @staticmethod
    def get_estim():
        """Return the current Estimator."""
        return MetricsBase.cur_estim

    @staticmethod
    def set_estim(estim):
        """Set the current Estimator."""
        MetricsBase.cur_estim = estim
Code example #6
class DataProviderBase():
    """Base DataProvider class.

    Subclasses supply the train/validation iterators and batch counts; the
    next-batch helpers are derived from those iterators here.
    """

    logger = get_logger('data_provider')

    def get_next_train_batch(self):
        """Return the next train batch."""
        it = self.get_train_iter()
        return next(it)

    def get_next_valid_batch(self):
        """Return the next validate batch."""
        it = self.get_valid_iter()
        return next(it)

    def get_train_iter(self):
        """Return train iterator."""
        raise NotImplementedError()

    def get_valid_iter(self):
        """Return validate iterator."""
        raise NotImplementedError()

    def reset_train_iter(self):
        """Reset train iterator."""
        raise NotImplementedError()

    def reset_valid_iter(self):
        """Reset validate iterator."""
        raise NotImplementedError()

    def get_num_train_batch(self):
        """Return number of train batches in current epoch."""
        raise NotImplementedError()

    def get_num_valid_batch(self):
        """Return number of validate batches in current epoch."""
        raise NotImplementedError()
Code example #7 — file: base.py (project: huawei-noah/vega)
class OptimBase():
    """Base Optimizer class.

    Iterates over a parameter space via has_next()/_next(); next() batches
    the iteration and step() lets subclasses consume evaluation feedback.
    """

    logger = get_logger('optim')

    def __init__(self, space=None):
        # FIX: use an explicit None check. The original `space or ParamSpace()`
        # silently replaced any provided-but-falsy space (e.g. one whose
        # __len__ is 0) with a fresh ParamSpace.
        self.space = ParamSpace() if space is None else space

    def state_dict(self):
        """Return current states."""
        return {}

    def load_state_dict(self, sd):
        """Resume states."""
        raise NotImplementedError

    def has_next(self):
        """Return True if Optimizer has the next set of parameters."""
        raise NotImplementedError

    def _next(self):
        """Return the next set of parameters."""
        raise NotImplementedError

    def next(self, batch_size=1):
        """Return the next batch of parameter sets.

        Stops early when the optimizer is exhausted, so the returned list
        may hold fewer than batch_size entries.
        """
        batch = []
        for _ in range(batch_size):
            if not self.has_next():
                break
            batch.append(self._next())
        return batch

    def step(self, estim):
        """Update Optimizer states using Estimator evaluation results."""
        pass
Code example #8
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Event managing and triggering."""
import inspect
from functools import wraps
from . import singleton, make_decorator
from modnas.utils.logging import get_logger
from modnas.utils import merge_config

logger = get_logger(__name__)


@singleton
class EventManager():
    """Event manager class."""
    def __init__(self):
        # Event name -> registered handlers; entries are managed by the rest
        # of the class (registration code is outside this view).
        self.handlers = {}
        # Queue of pending events awaiting dispatch.
        self.event_queue = []

    def reset(self):
        """Reset event states.

        Drops all registered handlers and any queued events in place.
        """
        self.handlers.clear()
        self.event_queue.clear()

    def get_handlers(self, ev):
Code example #9 — file: slot.py (project: vineetrao25/vega)
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Slot module."""

from functools import partial
import copy
import torch.nn as nn
from modnas.registry.arch_space import register
from modnas.utils.logging import get_logger

logger = get_logger('arch_space')


def _simplify_list(data):
    return data[0] if isinstance(data,
                                 (list, tuple)) and len(data) == 1 else data


@register
class Slot(nn.Module):
    """Stub module that is converted to actual modules by Constructors."""

    _slots = []
    _slot_id = -1
    _convert_fn = None
    _export_fn = None
Code example #10 — file: base.py (project: vineetrao25/vega)
class EstimBase():
    """Base Estimator class.

    An Estimator drives one evaluation procedure of a search: it delegates
    training/validation to a Trainer, aggregates criterion losses and
    Metrics, buffers (input, result, arch_desc) triples for the Optimizer,
    and checkpoints its own state, the model, and arch_descs via the
    experiment manager (expman).
    """

    logger = get_logger('estim')

    def __init__(self,
                 config=None,
                 expman=None,
                 trainer=None,
                 constructor=None,
                 exporter=None,
                 model=None,
                 writer=None,
                 name=None):
        # `name` is used as a suffix in checkpoint / output file names.
        self.name = '' if name is None else name
        self.config = config
        self.expman = expman
        self.constructor = constructor
        self.exporter = exporter
        self.model = model
        self.writer = writer
        self.cur_epoch = -1
        # Metrics built from config, keyed by metric name (see compute_metrics).
        self.metrics = build_metrics_all(config.get('metrics', None), self)
        self.criterions_all, self.criterions_train, self.criterions_eval, self.criterions_valid = build_criterions_all(
            config.get('criterion', None), getattr(model, 'device_ids', None))
        self.trainer = trainer
        # Parallel buffers: inputs[i] / results[i] / arch_descs[i] describe one
        # evaluated parameter set (see step_done / buffer / clear_buffer).
        self.results = []
        self.inputs = []
        self.arch_descs = []
        self.stats = {}
        # Lock used as a completion flag: held while a batch of evaluations is
        # in flight and released by step_done (see stepped / wait_done).
        self.step_cond = threading.Lock()
        self.cur_trn_batch = None
        self.cur_val_batch = None

    def set_trainer(self, trainer):
        """Set current trainer."""
        self.trainer = trainer

    def model_output(self, *args, data=None, model=None, **kwargs):
        """Return model output for given data.

        Falls back to self.model when no model is given; the actual forward
        pass is delegated to the trainer.
        """
        model = self.model if model is None else model
        return self.trainer.model_output(*args,
                                         data=data,
                                         model=model,
                                         **kwargs)

    def loss(self, data, output=None, model=None, mode=None):
        """Return loss.

        mode selects which extra criterions apply ('train'/'eval'/'valid',
        None for base criterions only); the trainer loss is then threaded
        through each criterion in turn.
        """
        model = self.model if model is None else model
        output = self.model_output(data=data,
                                   model=model) if output is None else output
        if mode is None:
            crits = []
        elif mode == 'train':
            crits = self.criterions_train
        elif mode == 'eval':
            crits = self.criterions_eval
        elif mode == 'valid':
            crits = self.criterions_valid
        else:
            raise ValueError('invalid criterion mode: {}'.format(mode))
        crits = self.criterions_all + crits
        loss = self.trainer.loss(model=model, output=output, data=data)
        # Each criterion receives the running loss and may transform it.
        for crit in crits:
            loss = crit(loss, self, output, *data)
        return loss

    def loss_output(self, data, model=None, mode=None):
        """Return loss and model output."""
        model = self.model if model is None else model
        output = self.model_output(data=data, model=model)
        return self.loss(data, output, model, mode), output

    def step(self, params):
        """Return evaluation results of a parameter set."""
        raise NotImplementedError

    def stepped(self, params):
        """Evaluate a parameter set and record the result when available.

        Acquires step_cond (if free) to mark a batch in flight; step_done
        releases it once a full result batch has been collected.
        """
        if not self.step_cond.locked():
            self.step_cond.acquire()
        value = self.step(params)
        # A None value means the result will be reported asynchronously
        # via step_done — presumably by a subclass; TODO confirm.
        if value is not None:
            self.step_done(params, value)

    def wait_done(self):
        """Wait evaluation steps to finish.

        Blocks until step_done releases step_cond, then releases it again
        immediately.
        """
        self.step_cond.acquire()
        self.step_cond.release()

    def step_done(self, params, value, arch_desc=None):
        """Store evaluation results of a parameter set."""
        self.inputs.append(params)
        self.results.append(value)
        self.arch_descs.append(
            self.get_arch_desc() if arch_desc is None else arch_desc)
        # Signal waiters once a full batch of results has accumulated.
        if len(self.results) == self.config.arch_update_batch:
            self.step_cond.release()

    def print_model_info(self):
        """Output model information."""
        model = self.model
        if model is not None:
            self.logger.info(backend.model_summary(model))

    def clear_buffer(self):
        """Clear evaluation results."""
        self.inputs, self.results, self.arch_descs = [], [], []

    def get_last_results(self):
        """Return last evaluation results."""
        return self.inputs, self.results

    def buffer(self):
        """Return generator over evaluated results with parameters and arch_descs."""
        for inp, res, desc in zip(self.inputs, self.results, self.arch_descs):
            yield inp, res, desc

    def compute_metrics(self,
                        *args,
                        name=None,
                        model=None,
                        to_scalar=True,
                        **kwargs):
        """Return Metrics results.

        Runs the named metric (or all metrics) on the model, flattening
        nested dict results into dotted keys and optionally coercing values
        to float scalars.
        """
        # Build a dotted key: "<metric>.<subkey>".
        def fmt_key(n, k):
            return '{}.{}'.format(n, k)

        # Recursively rewrite nested dict keys to their dotted paths.
        def flatten_dict(n, r):
            if isinstance(r, dict):
                return {
                    fmt_key(n, k): flatten_dict(fmt_key(n, k), v)
                    for k, v in r.items()
                }
            return r

        # Merge one metric's (possibly scalar) result into the output dict.
        def merge_results(dct, n, r):
            if not isinstance(r, dict):
                r = {n: r}
            r = {
                k: None if v is None else (float(v) if to_scalar else v)
                for k, v in r.items()
            }
            dct.update(r)

        ret = {}
        model = self.model if model is None else model
        names = [name] if name is not None else self.metrics.keys()
        for mt_name in names:
            res = self.metrics[mt_name](model, *args, **kwargs)
            merge_results(ret, mt_name, flatten_dict(mt_name, res))
        return ret

    def run_epoch(self, optim, epoch, tot_epochs):
        """Run Estimator routine for one epoch."""
        raise NotImplementedError

    def run(self, optim):
        """Run Estimator routine."""
        raise NotImplementedError

    def get_score(self, res):
        """Return scalar value from evaluation results.

        Prefers the 'default' entry of a result dict, then the first value,
        then 0 for an empty dict; non-dict results pass through unchanged.
        """
        if not isinstance(res, dict):
            return res
        score = res.get('default', None)
        if score is None:
            score = 0 if len(res) == 0 else list(res.values())[0]
        return score

    def train_epoch(self, epoch, tot_epochs, model=None):
        """Train model for one epoch."""
        model = self.model if model is None else model
        ret = self.trainer.train_epoch(
            estim=self,
            model=model,
            tot_steps=self.get_num_train_batch(epoch),
            epoch=epoch,
            tot_epochs=tot_epochs)
        return ret

    def train_step(self, epoch, tot_epochs, step, tot_steps, model=None):
        """Train model for one step."""
        model = self.model if model is None else model
        return self.trainer.train_step(estim=self,
                                       model=model,
                                       epoch=epoch,
                                       tot_epochs=tot_epochs,
                                       step=step,
                                       tot_steps=tot_steps)

    def valid_epoch(self, epoch=0, tot_epochs=1, model=None):
        """Validate model for one epoch."""
        model = self.model if model is None else model
        return self.trainer.valid_epoch(
            estim=self,
            model=model,
            tot_steps=self.get_num_valid_batch(epoch),
            epoch=epoch,
            tot_epochs=tot_epochs)

    def valid_step(self, epoch, tot_epochs, step, tot_steps, model=None):
        """Validate model for one step."""
        model = self.model if model is None else model
        return self.trainer.valid_step(estim=self,
                                       model=model,
                                       epoch=epoch,
                                       tot_epochs=tot_epochs,
                                       step=step,
                                       tot_steps=tot_steps)

    def reset_trainer(self, *args, trainer_config=None, model=None, **kwargs):
        """Reinitialize trainer.

        Extra keyword arguments are merged into the trainer config; the
        epoch counter is reset so training restarts from the beginning.
        """
        model = self.model if model is None else model
        trainer_config = trainer_config or {}
        trainer_config.update({'epochs': self.config.epochs})
        trainer_config.update(kwargs)
        if self.trainer is not None:
            self.trainer.init(*args, model=model, config=trainer_config)
        self.cur_epoch = -1

    def get_num_train_batch(self, epoch=None):
        """Return number of training batches."""
        epoch = self.cur_epoch if epoch is None else epoch
        return 0 if self.trainer is None else self.trainer.get_num_train_batch(
            epoch=epoch)

    def get_num_valid_batch(self, epoch=None):
        """Return number of validating batches."""
        epoch = self.cur_epoch if epoch is None else epoch
        return 0 if self.trainer is None else self.trainer.get_num_valid_batch(
            epoch=epoch)

    def get_next_train_batch(self):
        """Return the next training batch."""
        ret = self.trainer.get_next_train_batch()
        self.cur_trn_batch = ret
        return ret

    def get_cur_train_batch(self):
        """Return the current training batch, fetching a new one if unset."""
        return self.cur_trn_batch or self.get_next_train_batch()

    def get_next_valid_batch(self):
        """Return the next validating batch."""
        ret = self.trainer.get_next_valid_batch()
        self.cur_val_batch = ret
        return ret

    def get_cur_valid_batch(self):
        """Return the current validating batch."""
        # NOTE(review): unlike get_cur_train_batch, this does not fall back
        # to fetching a new batch when none is cached — confirm intended.
        return self.cur_val_batch

    def load_state_dict(self, state_dict):
        """Resume states."""
        pass

    def state_dict(self):
        """Return current states."""
        return {'cur_epoch': self.cur_epoch}

    def get_arch_desc(self):
        """Return current archdesc."""
        return None if self.exporter is None else self.exporter(self.model)

    def save_model(self,
                   save_name=None,
                   exporter='DefaultTorchCheckpointExporter'):
        """Save model checkpoint to file."""
        expman = self.expman
        save_name = 'model_{}_{}.pt'.format(self.name, save_name)
        chkpt_path = expman.join('chkpt', save_name)
        build_exporter(exporter, path=chkpt_path)(self.model)

    def save(self, epoch=None, save_name=None):
        """Save Estimator states to file."""
        expman = self.expman
        logger = self.logger
        save_name = 'estim_{}_{}.pkl'.format(self.name, save_name)
        chkpt_path = expman.join('chkpt', save_name)
        # NOTE(review): `epoch` is computed but not used below; also pickle
        # failures typically raise PicklingError/TypeError, which this
        # RuntimeError handler would not catch — confirm intended.
        epoch = epoch or self.cur_epoch
        try:
            chkpt = self.state_dict()
            with open(chkpt_path, 'wb') as f:
                pickle.dump(chkpt, f)
        except RuntimeError:
            logger.error("Failed saving estimator: {}".format(
                traceback.format_exc()))

    def save_checkpoint(self, epoch=None, save_name=None):
        """Save Estimator & model to file."""
        epoch = epoch or self.cur_epoch
        # Default name encodes the 1-based epoch number, e.g. "ep001".
        save_name = save_name or 'ep{:03d}'.format(epoch + 1)
        self.save_model(save_name)
        self.save(epoch, save_name)

    def save_arch_desc(self,
                       epoch=None,
                       arch_desc=None,
                       save_name=None,
                       exporter='DefaultToFileExporter'):
        """Save archdesc to file."""
        expman = self.expman
        logger = self.logger
        if save_name is not None:
            fname = 'arch_{}_{}'.format(self.name, save_name)
        else:
            epoch = epoch or self.cur_epoch
            fname = 'arch_{}_ep{:03d}'.format(self.name, epoch + 1)
        save_path = expman.join('output', fname)
        try:
            build_exporter(exporter, path=save_path)(arch_desc)
        except RuntimeError:
            logger.error("Failed saving arch_desc: {}".format(
                traceback.format_exc()))

    def load(self, chkpt_path):
        """Load states from file.

        A None path is a no-op; otherwise the pickled state dict is loaded
        and applied via load_state_dict.
        """
        if chkpt_path is None:
            return
        self.logger.info("Resuming from checkpoint: {}".format(chkpt_path))
        with open(chkpt_path, 'rb') as f:
            chkpt = pickle.load(f)
        self.load_state_dict(chkpt)
Code example #11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Default Constructors."""

import importlib
import copy
from functools import partial
from collections import OrderedDict
from modnas.registry.arch_space import build as build_module
from modnas.registry.construct import register, build
from modnas.arch_space.slot import Slot
from modnas.utils.logging import get_logger
from modnas.utils import import_file

logger = get_logger('construct')


def get_convert_fn(convert_fn, **kwargs):
    """Return a new convert function.

    Accepts either a registry name (built via the construct registry, with
    kwargs forwarded) or a ready-made callable, which is returned as-is.
    """
    if isinstance(convert_fn, str):
        return build(convert_fn, **kwargs)
    if callable(convert_fn):
        return convert_fn
    raise ValueError('unsupported convert_fn type: {}'.format(type(convert_fn)))


@register
class DefaultModelConstructor():
Code example #12 — file: default.py (project: vineetrao25/vega)
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Default DataLoader."""

import random
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from modnas.registry.data_loader import register
from modnas.utils.logging import get_logger

logger = get_logger('data_loader')


@register
def DefaultDataLoader(trn_data,
                      val_data,
                      parallel_multiplier=1,
                      trn_batch_size=64,
                      val_batch_size=64,
                      workers=2,
                      train_size=0,
                      train_ratio=1.,
                      train_seed=1,
                      valid_size=0,
                      valid_ratio=0.,
                      valid_seed=1):
Code example #13

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Default Torch Exporters."""
import traceback
import torch
from modnas.registry.export import register
from modnas.utils.logging import get_logger

logger = get_logger('export')


@register
class DefaultTorchCheckpointExporter():
    """Exporter that saves model checkpoint to file."""
    def __init__(self, path, zip_file=None):
        """Initialize the exporter.

        path: destination file path for the saved checkpoint.
        zip_file: if not None, forwarded as torch.save's
            ``_use_new_zipfile_serialization`` flag (available in torch>=1.4).
        """
        self.path = path
        save_kwargs = {}
        # BUG FIX: the original computed int('.'.join(torch.__version__.split('.'))),
        # which just reconstructs the dotted version string (e.g. "1.7.0") and
        # makes int() raise ValueError whenever zip_file is set. The old
        # ">= 140" threshold meant torch 1.4.0, where the new zipfile
        # serialization flag was introduced — compare (major, minor) instead.
        if zip_file is not None:
            version = torch.__version__.split('+')[0]
            try:
                major, minor = (int(v) for v in version.split('.')[:2])
            except ValueError:
                # Non-numeric version components (e.g. dev builds):
                # conservatively skip the flag.
                major, minor = 0, 0
            if (major, minor) >= (1, 4):
                save_kwargs['_use_new_zipfile_serialization'] = zip_file
        self.save_kwargs = save_kwargs

    def __call__(self, model):
        """Run Exporter."""
Code example #14
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.

"""Simulated annealing model optimum finder."""
import heapq
import random
import numpy as np
from .base import ModelOptim
from modnas.registry.model_optim import register
from modnas.utils.logging import get_logger


logger = get_logger('model_optim')


@register
class SimulatedAnnealingModelOptim(ModelOptim):
    """Simulated annealing model optimum finder class."""

    def __init__(self,
                 space,
                 temp_init=1e4,
                 temp_end=1e-4,
                 cool=0.95,
                 cool_type='exp',
                 batch_size=128,
                 n_iter=1,
                 keep_history=True):
Code example #15
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Run hyperparameter tuning on python scripts."""

import os
import sys
import yaml
import optparse
from modnas.utils.wrapper import run_hptune
from modnas.utils.logging import get_logger

logger = get_logger()

# Fallback configuration used when the caller supplies none: a random-search
# optimizer driving an HPTuneEstim.
_default_hptune_config = {
    'optim': {
        'type': 'RandomSearchOptim'
    },
    'estimator': {
        'tune': {
            'type': 'HPTuneEstim',
            # -1 presumably means "no fixed epoch limit" — TODO confirm
            # against HPTuneEstim's handling of this value.
            'epochs': -1,
        }
    }
}


def tune_script():
Code example #16
import copy
import threading
import traceback
from zeus.common import FileOps
from zeus.common import ClassFactory, ClassType
from zeus.trainer.callbacks import Callback
from vega.core.search_space import SearchSpace
from vega.core.search_algs import SearchAlgorithm
from modnas.data_provider.predefined.default import DefaultDataProvider
from modnas.trainer.base import TrainerBase
from modnas.utils.wrapper import init_all
from modnas.utils.logging import get_logger
from modnas.utils import merge_config

logger = get_logger('compat')


class VegaTrainerWrapper(TrainerBase):
    """Trainer wrapper for ModularNAS."""
    def __init__(self, trainer):
        super().__init__()
        self.trainer = trainer
        self.model = trainer.model
        self.data_provider = None
        self.optimizer = None
        self.lr_scheduler = None
        self.trainer_loss = None
        self.proc_batch = None
        self.cur_batch = None
        self.step = -1