from modnas.utils.logging import get_logger
# event_on / event_off are defined alongside EventManager in the event module
# (see below); their exact import path is not shown in this excerpt.


class CallbackBase():
    """Base callback class."""

    logger = get_logger('callback')
    priority = 0

    def __init__(self, handler_conf=None) -> None:
        self.handlers = None
        self.bind_handlers(handler_conf)

    def bind_handlers(self, handler_conf):
        """Bind event handlers."""
        handlers = {}
        # tolerate the default handler_conf=None (the original iterated
        # handler_conf.items() directly, which fails on the default)
        for ev, conf in (handler_conf or {}).items():
            prio = None
            if isinstance(conf, (list, tuple)):
                h = conf[0]
                if len(conf) > 1:
                    prio = conf[1]
            else:
                h = conf
            event_on(ev, h, self.priority if prio is None else prio)
            handlers[ev] = h
        self.handlers = handlers

    def unbind_handlers(self):
        """Un-bind event handlers."""
        for ev, h in self.handlers.items():
            event_off(ev, h)
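# Usage sketch (illustrative; the event names and LoggingCallback are
# assumptions, not from the source): handler_conf maps an event name to
# either a bare handler or a (handler, priority) pair.
#
#     class LoggingCallback(CallbackBase):
#         def __init__(self):
#             super().__init__({
#                 'run_epoch': self.on_epoch,  # bound with the class priority
#                 'run': (self.on_run, 10),    # bound with explicit priority 10
#             })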
class Registry():
    """Registry class."""

    logger = get_logger('registry')

    def __init__(self, allow_replace=False):
        self.allow_replace = allow_replace
        self._reg_class = {}

    def get_full_path(self, reg_path, reg_id):
        """Return full registration path."""
        return '{}.{}'.format(reg_path, reg_id)

    def get_reg_name(self, reg_path, reg_id):
        """Return proper registration name."""
        name = self.get_full_path(reg_path, reg_id)
        return name.lower().replace('-', '').replace('_', '').replace(' ', '')

    def register(self, regclass, reg_path, reg_id):
        """Register a component class."""
        reg_id = self.get_reg_name(reg_path, reg_id)
        if reg_id in self._reg_class:
            self.logger.warning('re-register id: {}'.format(reg_id))
            if not self.allow_replace:
                raise ValueError('Cannot re-register id: {}'.format(reg_id))
        self._reg_class[reg_id] = regclass
        self.logger.debug('registered: {}'.format(reg_id))

    def get(self, reg_path, reg_id):
        """Return registered class by name."""
        reg_id = self.get_reg_name(reg_path, reg_id)
        if reg_id not in self._reg_class:
            raise ValueError('id \'{}\' not found in registry'.format(reg_id))
        return self._reg_class[reg_id]
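# Usage sketch (illustrative; MyOptim is a hypothetical class): get_reg_name
# lower-cases and strips '-', '_' and ' ', so different spellings of an id
# resolve to the same key.
#
#     registry = Registry()
#     registry.register(MyOptim, 'optim', 'My-Optim')  # stored as 'optim.myoptim'
#     cls = registry.get('optim', 'my optim')          # same normalized key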
class TrainerBase():
    """Base Trainer class."""

    logger = get_logger('trainer')

    def __init__(self, writer=None):
        if writer is None:
            writer = DummyWriter()
        self.writer = writer

    def init(self, model, config=None):
        """Initialize trainer states."""
        raise NotImplementedError

    def model_input(self, data):
        """Return model input."""
        return data[:-1], {}

    def model_output(self, *args, data=None, model=None, attr=None, **kwargs):
        """Return model output."""
        model_fn = model if attr is None else getattr(model, attr)
        if data is not None:
            args, kwargs = self.model_input(data)
        return model_fn(*args, **kwargs)

    def loss(self, output=None, data=None, model=None):
        """Return loss."""
        return None

    def train_epoch(self):
        """Train for one epoch."""
        raise NotImplementedError

    def valid_epoch(self):
        """Validate for one epoch."""
        raise NotImplementedError

    def train_step(self):
        """Train for one step."""
        raise NotImplementedError

    def valid_step(self):
        """Validate for one step."""
        raise NotImplementedError

    def state_dict(self):
        """Return current states."""
        return {}

    def load_state_dict(self, sd):
        """Resume states."""
        raise NotImplementedError
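# Note (inferred from model_input above): a data batch is expected to be a
# sequence whose last element is the target, so data[:-1] are the positional
# model inputs. For a typical (images, labels) batch:
#
#     args, kwargs = trainer.model_input((images, labels))
#     # args == (images,), kwargs == {}
#     output = trainer.model_output(data=(images, labels), model=model)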
class Registry():
    """Registry class."""

    logger = get_logger('registry')

    def __init__(self, allow_replace=False):
        self.allow_replace = allow_replace

    def get_reg_name(self, name):
        """Return proper registration name."""
        return name.lower().replace('-', '').replace('_', '').replace(' ', '')

    def register(self, regclass, reg_path, reg_id):
        """Register a component class."""
        reg_id = self.get_reg_name(reg_id)
        ClassFactory.register_cls(regclass, type_name=get_reg_type(reg_path), alias=reg_id)
        self.logger.debug('registered: {}'.format(reg_id))

    def get(self, reg_path, reg_id):
        """Return registered class by name."""
        reg_id = self.get_reg_name(reg_id)
        return ClassFactory.get_cls(get_reg_type(reg_path), reg_id)
class MetricsBase():
    """Base Metrics class."""

    logger = get_logger('metrics')
    cur_estim = None

    def __init__(self):
        self.estim = MetricsBase.get_estim()

    def __call__(self, *args, **kwargs):
        """Return metrics output."""
        raise NotImplementedError

    @staticmethod
    def get_estim():
        """Get current Estimator."""
        return MetricsBase.cur_estim

    @staticmethod
    def set_estim(estim):
        """Set current Estimator."""
        MetricsBase.cur_estim = estim
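# Usage sketch (illustrative; MyMetric is a hypothetical subclass implementing
# __call__): the current Estimator is shared through class state, so it must
# be set before a metrics object is constructed.
#
#     MetricsBase.set_estim(estim)
#     metric = MyMetric()      # captures estim via MetricsBase.get_estim()
#     result = metric(model)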
class DataProviderBase():
    """Base DataProvider class."""

    logger = get_logger('data_provider')

    def get_next_train_batch(self):
        """Return the next train batch."""
        return next(self.get_train_iter())

    def get_next_valid_batch(self):
        """Return the next validate batch."""
        return next(self.get_valid_iter())

    def get_train_iter(self):
        """Return train iterator."""
        raise NotImplementedError

    def get_valid_iter(self):
        """Return validate iterator."""
        raise NotImplementedError

    def reset_train_iter(self):
        """Reset train iterator."""
        raise NotImplementedError

    def reset_valid_iter(self):
        """Reset validate iterator."""
        raise NotImplementedError

    def get_num_train_batch(self):
        """Return number of train batches in current epoch."""
        raise NotImplementedError

    def get_num_valid_batch(self):
        """Return number of validate batches in current epoch."""
        raise NotImplementedError
class OptimBase():
    """Base Optimizer class."""

    logger = get_logger('optim')

    def __init__(self, space=None):
        self.space = space or ParamSpace()

    def state_dict(self):
        """Return current states."""
        return {}

    def load_state_dict(self, sd):
        """Resume states."""
        raise NotImplementedError

    def has_next(self):
        """Return True if Optimizer has the next set of parameters."""
        raise NotImplementedError

    def _next(self):
        """Return the next set of parameters."""
        raise NotImplementedError

    def next(self, batch_size=1):
        """Return the next batch of parameter sets."""
        batch = []
        for _ in range(batch_size):
            if not self.has_next():
                break
            batch.append(self._next())
        return batch

    def step(self, estim):
        """Update Optimizer states using Estimator evaluation results."""
        pass
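# Usage sketch (the loop shape is an assumption, not from the source): a
# search loop drains parameter sets in batches and feeds results back in.
#
#     while optim.has_next():
#         for params in optim.next(batch_size=4):  # yields at most 4 sets
#             estim.stepped(params)                # evaluate each set
#         estim.wait_done()
#         optim.step(estim)                        # consume evaluation results
#
# estim refers to an EstimBase instance (defined elsewhere in this codebase).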
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Event managing and triggering."""
import inspect
from functools import wraps
from . import singleton, make_decorator
from modnas.utils.logging import get_logger
from modnas.utils import merge_config

logger = get_logger(__name__)


@singleton
class EventManager():
    """Event manager class."""

    def __init__(self):
        self.handlers = {}
        self.event_queue = []

    def reset(self):
        """Reset event states."""
        self.handlers.clear()
        self.event_queue.clear()

    def get_handlers(self, ev):
        """Return handlers bound to an event."""
        # minimal completion (the original body is truncated in this excerpt)
        return self.handlers.get(ev, [])
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Slot module."""
from functools import partial
import copy
import torch.nn as nn
from modnas.registry.arch_space import register
from modnas.utils.logging import get_logger

logger = get_logger('arch_space')


def _simplify_list(data):
    return data[0] if isinstance(data, (list, tuple)) and len(data) == 1 else data


@register
class Slot(nn.Module):
    """Stub module that is converted to actual modules by Constructors."""

    _slots = []
    _slot_id = -1
    _convert_fn = None
    _export_fn = None
import pickle
import threading
import traceback
from modnas.utils.logging import get_logger
# backend, build_metrics_all, build_criterions_all and build_exporter are
# package-internal helpers whose import paths are not shown in this excerpt.


class EstimBase():
    """Base Estimator class."""

    logger = get_logger('estim')

    def __init__(self, config=None, expman=None, trainer=None, constructor=None,
                 exporter=None, model=None, writer=None, name=None):
        self.name = '' if name is None else name
        self.config = config
        self.expman = expman
        self.constructor = constructor
        self.exporter = exporter
        self.model = model
        self.writer = writer
        self.cur_epoch = -1
        self.metrics = build_metrics_all(config.get('metrics', None), self)
        self.criterions_all, self.criterions_train, self.criterions_eval, self.criterions_valid = build_criterions_all(
            config.get('criterion', None), getattr(model, 'device_ids', None))
        self.trainer = trainer
        self.results = []
        self.inputs = []
        self.arch_descs = []
        self.stats = {}
        self.step_cond = threading.Lock()
        self.cur_trn_batch = None
        self.cur_val_batch = None

    def set_trainer(self, trainer):
        """Set current trainer."""
        self.trainer = trainer

    def model_output(self, *args, data=None, model=None, **kwargs):
        """Return model output for given data."""
        model = self.model if model is None else model
        return self.trainer.model_output(*args, data=data, model=model, **kwargs)

    def loss(self, data, output=None, model=None, mode=None):
        """Return loss."""
        model = self.model if model is None else model
        output = self.model_output(data=data, model=model) if output is None else output
        if mode is None:
            crits = []
        elif mode == 'train':
            crits = self.criterions_train
        elif mode == 'eval':
            crits = self.criterions_eval
        elif mode == 'valid':
            crits = self.criterions_valid
        else:
            raise ValueError('invalid criterion mode: {}'.format(mode))
        crits = self.criterions_all + crits
        loss = self.trainer.loss(model=model, output=output, data=data)
        for crit in crits:
            loss = crit(loss, self, output, *data)
        return loss

    def loss_output(self, data, model=None, mode=None):
        """Return loss and model output."""
        model = self.model if model is None else model
        output = self.model_output(data=data, model=model)
        return self.loss(data, output, model, mode), output

    def step(self, params):
        """Return evaluation results of a parameter set."""
        raise NotImplementedError

    def stepped(self, params):
        """Evaluate a parameter set and store the results."""
        if not self.step_cond.locked():
            self.step_cond.acquire()
        value = self.step(params)
        if value is not None:
            self.step_done(params, value)

    def wait_done(self):
        """Wait evaluation steps to finish."""
        self.step_cond.acquire()
        self.step_cond.release()

    def step_done(self, params, value, arch_desc=None):
        """Store evaluation results of a parameter set."""
        self.inputs.append(params)
        self.results.append(value)
        self.arch_descs.append(self.get_arch_desc() if arch_desc is None else arch_desc)
        if len(self.results) == self.config.arch_update_batch:
            self.step_cond.release()

    def print_model_info(self):
        """Output model information."""
        model = self.model
        if model is not None:
            self.logger.info(backend.model_summary(model))

    def clear_buffer(self):
        """Clear evaluation results."""
        self.inputs, self.results, self.arch_descs = [], [], []

    def get_last_results(self):
        """Return last evaluation results."""
        return self.inputs, self.results

    def buffer(self):
        """Return generator over evaluated results with parameters and arch_descs."""
        for inp, res, desc in zip(self.inputs, self.results, self.arch_descs):
            yield inp, res, desc

    def compute_metrics(self, *args, name=None, model=None, to_scalar=True, **kwargs):
        """Return Metrics results."""
        def fmt_key(n, k):
            return '{}.{}'.format(n, k)

        def flatten_dict(n, r):
            if isinstance(r, dict):
                return {fmt_key(n, k): flatten_dict(fmt_key(n, k), v) for k, v in r.items()}
            return r

        def merge_results(dct, n, r):
            if not isinstance(r, dict):
                r = {n: r}
            r = {k: None if v is None else (float(v) if to_scalar else v) for k, v in r.items()}
            dct.update(r)

        ret = {}
        model = self.model if model is None else model
        names = [name] if name is not None else self.metrics.keys()
        for mt_name in names:
            res = self.metrics[mt_name](model, *args, **kwargs)
            merge_results(ret, mt_name, flatten_dict(mt_name, res))
        return ret

    def run_epoch(self, optim, epoch, tot_epochs):
        """Run Estimator routine for one epoch."""
        raise NotImplementedError

    def run(self, optim):
        """Run Estimator routine."""
        raise NotImplementedError

    def get_score(self, res):
        """Return scalar value from evaluation results."""
        if not isinstance(res, dict):
            return res
        score = res.get('default', None)
        if score is None:
            score = 0 if len(res) == 0 else list(res.values())[0]
        return score

    def train_epoch(self, epoch, tot_epochs, model=None):
        """Train model for one epoch."""
        model = self.model if model is None else model
        ret = self.trainer.train_epoch(estim=self,
                                       model=model,
                                       tot_steps=self.get_num_train_batch(epoch),
                                       epoch=epoch,
                                       tot_epochs=tot_epochs)
        return ret

    def train_step(self, epoch, tot_epochs, step, tot_steps, model=None):
        """Train model for one step."""
        model = self.model if model is None else model
        return self.trainer.train_step(estim=self,
                                       model=model,
                                       epoch=epoch,
                                       tot_epochs=tot_epochs,
                                       step=step,
                                       tot_steps=tot_steps)

    def valid_epoch(self, epoch=0, tot_epochs=1, model=None):
        """Validate model for one epoch."""
        model = self.model if model is None else model
        return self.trainer.valid_epoch(estim=self,
                                        model=model,
                                        tot_steps=self.get_num_valid_batch(epoch),
                                        epoch=epoch,
                                        tot_epochs=tot_epochs)

    def valid_step(self, epoch, tot_epochs, step, tot_steps, model=None):
        """Validate model for one step."""
        model = self.model if model is None else model
        return self.trainer.valid_step(estim=self,
                                       model=model,
                                       epoch=epoch,
                                       tot_epochs=tot_epochs,
                                       step=step,
                                       tot_steps=tot_steps)

    def reset_trainer(self, *args, trainer_config=None, model=None, **kwargs):
        """Reinitialize trainer."""
        model = self.model if model is None else model
        trainer_config = trainer_config or {}
        trainer_config.update({'epochs': self.config.epochs})
        trainer_config.update(kwargs)
        if self.trainer is not None:
            self.trainer.init(*args, model=model, config=trainer_config)
        self.cur_epoch = -1

    def get_num_train_batch(self, epoch=None):
        """Return number of training batches."""
        epoch = self.cur_epoch if epoch is None else epoch
        return 0 if self.trainer is None else self.trainer.get_num_train_batch(epoch=epoch)

    def get_num_valid_batch(self, epoch=None):
        """Return number of validating batches."""
        epoch = self.cur_epoch if epoch is None else epoch
        return 0 if self.trainer is None else self.trainer.get_num_valid_batch(epoch=epoch)

    def get_next_train_batch(self):
        """Return the next training batch."""
        ret = self.trainer.get_next_train_batch()
        self.cur_trn_batch = ret
        return ret

    def get_cur_train_batch(self):
        """Return the current training batch."""
        return self.cur_trn_batch or self.get_next_train_batch()

    def get_next_valid_batch(self):
        """Return the next validating batch."""
        ret = self.trainer.get_next_valid_batch()
        self.cur_val_batch = ret
        return ret

    def get_cur_valid_batch(self):
        """Return the current validating batch."""
        return self.cur_val_batch

    def load_state_dict(self, state_dict):
        """Resume states."""
        pass

    def state_dict(self):
        """Return current states."""
        return {'cur_epoch': self.cur_epoch}

    def get_arch_desc(self):
        """Return current arch_desc."""
        return None if self.exporter is None else self.exporter(self.model)

    def save_model(self, save_name=None, exporter='DefaultTorchCheckpointExporter'):
        """Save model checkpoint to file."""
        expman = self.expman
        save_name = 'model_{}_{}.pt'.format(self.name, save_name)
        chkpt_path = expman.join('chkpt', save_name)
        build_exporter(exporter, path=chkpt_path)(self.model)

    def save(self, epoch=None, save_name=None):
        """Save Estimator states to file."""
        expman = self.expman
        logger = self.logger
        save_name = 'estim_{}_{}.pkl'.format(self.name, save_name)
        chkpt_path = expman.join('chkpt', save_name)
        epoch = epoch or self.cur_epoch
        try:
            chkpt = self.state_dict()
            with open(chkpt_path, 'wb') as f:
                pickle.dump(chkpt, f)
        except RuntimeError:
            logger.error("Failed saving estimator: {}".format(traceback.format_exc()))

    def save_checkpoint(self, epoch=None, save_name=None):
        """Save Estimator & model to file."""
        epoch = epoch or self.cur_epoch
        save_name = save_name or 'ep{:03d}'.format(epoch + 1)
        self.save_model(save_name)
        self.save(epoch, save_name)

    def save_arch_desc(self, epoch=None, arch_desc=None, save_name=None, exporter='DefaultToFileExporter'):
        """Save arch_desc to file."""
        expman = self.expman
        logger = self.logger
        if save_name is not None:
            fname = 'arch_{}_{}'.format(self.name, save_name)
        else:
            epoch = epoch or self.cur_epoch
            fname = 'arch_{}_ep{:03d}'.format(self.name, epoch + 1)
        save_path = expman.join('output', fname)
        try:
            build_exporter(exporter, path=save_path)(arch_desc)
        except RuntimeError:
            logger.error("Failed saving arch_desc: {}".format(traceback.format_exc()))

    def load(self, chkpt_path):
        """Load states from file."""
        if chkpt_path is None:
            return
        self.logger.info("Resuming from checkpoint: {}".format(chkpt_path))
        with open(chkpt_path, 'rb') as f:
            chkpt = pickle.load(f)
        self.load_state_dict(chkpt)
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Default Constructors."""
import importlib
import copy
from functools import partial
from collections import OrderedDict
from modnas.registry.arch_space import build as build_module
from modnas.registry.construct import register, build
from modnas.arch_space.slot import Slot
from modnas.utils.logging import get_logger
from modnas.utils import import_file

logger = get_logger('construct')


def get_convert_fn(convert_fn, **kwargs):
    """Return a new convert function."""
    if isinstance(convert_fn, str):
        return build(convert_fn, **kwargs)
    elif callable(convert_fn):
        return convert_fn
    else:
        raise ValueError('unsupported convert_fn type: {}'.format(type(convert_fn)))


@register
class DefaultModelConstructor():
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Default DataLoader."""
import random
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from modnas.registry.data_loader import register
from modnas.utils.logging import get_logger

logger = get_logger('data_loader')


@register
def DefaultDataLoader(trn_data,
                      val_data,
                      parallel_multiplier=1,
                      trn_batch_size=64,
                      val_batch_size=64,
                      workers=2,
                      train_size=0,
                      train_ratio=1.,
                      train_seed=1,
                      valid_size=0,
                      valid_ratio=0.,
                      valid_seed=1):
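# Note (an assumption based on the imports and parameter names above; not the
# confirmed implementation): train_size/train_ratio and valid_size/valid_ratio
# presumably carve index subsets out of the datasets via SubsetRandomSampler,
# with train_seed/valid_seed fixing the shuffles, before the DataLoaders are
# built with the given batch sizes and workers.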
# -*- coding:utf-8 -*-
# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Default Torch Exporters."""
import traceback
import torch
from modnas.registry.export import register
from modnas.utils.logging import get_logger

logger = get_logger('export')


@register
class DefaultTorchCheckpointExporter():
    """Exporter that saves model checkpoint to file."""

    def __init__(self, path, zip_file=None):
        self.path = path
        save_kwargs = {}
        # enable the zipfile flag only on torch >= 1.4.0; joining the version
        # digits ('1.4.0' -> 140) keeps the comparison numeric (the original
        # joined with '.', which would make int() fail)
        if zip_file is not None and int(''.join(torch.__version__.split('.'))) >= 140:
            save_kwargs['_use_new_zipfile_serialization'] = zip_file
        self.save_kwargs = save_kwargs

    def __call__(self, model):
        """Run Exporter."""
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Simulated annealing model optimum finder."""
import heapq
import random
import numpy as np
from .base import ModelOptim
from modnas.registry.model_optim import register
from modnas.utils.logging import get_logger

logger = get_logger('model_optim')


@register
class SimulatedAnnealingModelOptim(ModelOptim):
    """Simulated annealing model optimum finder class."""

    def __init__(self,
                 space,
                 temp_init=1e4,
                 temp_end=1e-4,
                 cool=0.95,
                 cool_type='exp',
                 batch_size=128,
                 n_iter=1,
                 keep_history=True):
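# Note (an assumption from the parameter names above; not the confirmed
# implementation): with cool_type='exp' the temperature typically decays
# geometrically each iteration until it reaches temp_end, e.g.:
#
#     temp = temp_init
#     while temp > temp_end:
#         # propose neighbors; accept worse ones with prob. exp(-delta / temp)
#         temp *= cool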
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
"""Run hyperparameter tuning on python scripts."""
import os
import sys
import yaml
import optparse
from modnas.utils.wrapper import run_hptune
from modnas.utils.logging import get_logger

logger = get_logger()

_default_hptune_config = {
    'optim': {
        'type': 'RandomSearchOptim'
    },
    'estimator': {
        'tune': {
            'type': 'HPTuneEstim',
            'epochs': -1,
        }
    }
}


def tune_script():
import copy
import threading
import traceback
from zeus.common import FileOps
from zeus.common import ClassFactory, ClassType
from zeus.trainer.callbacks import Callback
from vega.core.search_space import SearchSpace
from vega.core.search_algs import SearchAlgorithm
from modnas.data_provider.predefined.default import DefaultDataProvider
from modnas.trainer.base import TrainerBase
from modnas.utils.wrapper import init_all
from modnas.utils.logging import get_logger
from modnas.utils import merge_config

logger = get_logger('compat')


class VegaTrainerWrapper(TrainerBase):
    """Trainer wrapper for ModularNAS."""

    def __init__(self, trainer):
        super().__init__()
        self.trainer = trainer
        self.model = trainer.model
        self.data_provider = None
        self.optimizer = None
        self.lr_scheduler = None
        self.trainer_loss = None
        self.proc_batch = None
        self.cur_batch = None
        self.step = -1