def get_quantization_handlers():
    """Build the map of quantization handlers for every parameter class.

    The priority for schemes is:

    - If a handler is registered for the class itself, it is matched for its
      declared scheme, or for any scheme if its SCHEME is None.
    - If a handler is registered for a superclass of the class, it is matched
      for its declared scheme, or for any scheme if its SCHEME is None.
    - Otherwise the ``'__default__'`` handler for the scheme is matched.

    Returns:
        dict keyed by parameter class. Each value maps a scheme (or None for
        handlers registered without a scheme) to a list of handler classes.
        The ``'__default__'`` entry is removed before returning.
    """
    handlers = {}
    schemes = set()
    # Collect all handlers, grouped by the parameter class they declare and
    # then by their SCHEME (None means "applies to any scheme").
    for cls in get_all_subclasses(QuantizionHandler):
        if not cls.PARAMS_TYPE:
            continue
        if cls.SCHEME:
            schemes.add(cls.SCHEME)
        for params_cls in cls.PARAMS_TYPE:
            phandlers = handlers.setdefault(params_cls, {})
            pscheme_handlers = phandlers.setdefault(cls.SCHEME, [])
            pscheme_handlers.append(cls)

    # Iterate through all parameter classes and resolve handlers per scheme
    for pclass in get_all_subclasses(Parameters):
        # skip parameter classes that do not define an op name
        if not pclass.CLS_OP_NAME:
            continue
        # see if we have any handlers for this class
        phandlers = handlers.setdefault(pclass, {})
        for scheme in schemes:
            # handlers registered directly for this class and scheme are
            # already present in this list
            matched_handlers = phandlers.setdefault(scheme, [])
            for hpclass, class_handlers in handlers.items():
                if hpclass == '__default__':
                    continue
                if not issubclass(pclass, hpclass):
                    continue
                # is subclass and a handler for this scheme is present
                if scheme in class_handlers:
                    add_in_handlers(matched_handlers, class_handlers[scheme])
                # is subclass and a handler matching all schemes is present;
                # those handlers live under the None key, not under `scheme`
                if None in class_handlers:
                    add_in_handlers(matched_handlers, class_handlers[None])
            if not matched_handlers:
                # fall back to the default handler for the scheme
                add_in_handlers(matched_handlers,
                                handlers['__default__'][scheme])
    del handlers['__default__']
    return handlers
def get_all_options_by_params():
    """Return the quantization OPTIONS applicable to each parameter class.

    First pass: every handler that declares both PARAMS_TYPE and OPTIONS
    contributes its OPTIONS to each class listed in PARAMS_TYPE. The
    ``'__default__'`` entry is filed under ``Parameters``.

    Second pass: for every Parameters subclass, merge the options of every
    key class it is a subclass of. The merge runs in the dict's current
    order, and each merged result is written back in place, so later entries
    win on duplicate option names.
    """
    by_params = {}
    for handler_cls in get_all_subclasses(QuantizionHandler):
        if handler_cls.PARAMS_TYPE is None or not handler_cls.OPTIONS:
            continue
        for declared in handler_cls.PARAMS_TYPE:
            target = Parameters if declared == '__default__' else declared
            by_params.setdefault(target, {}).update(handler_cls.OPTIONS)

    for params_cls in get_all_subclasses(Parameters):
        merged = {}
        for base_cls, base_opts in by_params.items():
            if not issubclass(params_cls, base_cls):
                continue
            merged.update(base_opts)
        by_params[params_cls] = merged
    return by_params
def __init__(self) -> None:
    """Index every GeneratorBase subclass by the parameter classes it lists.

    ``self._generators`` maps a parameter class to a list with one dict per
    registering generator. Each dict maps every value in the generator's
    KTYPES to the generator class, or ``None`` to it when KTYPES is empty.
    A subclass of a registered parameter class that has no entry of its own
    shares (the same list object) the entry of the first registered ancestor
    encountered.
    """
    self._generators = {}
    for gen_class in get_all_subclasses(GeneratorBase):
        if not gen_class.PARAMS:
            continue
        for params in gen_class.PARAMS:
            if gen_class.KTYPES:
                by_ktype = dict.fromkeys(gen_class.KTYPES, gen_class)
            else:
                by_ktype = {None: gen_class}
            self._generators.setdefault(params, []).append(by_ktype)

    # snapshot the registered entries before adding subclass aliases
    registered = list(self._generators.items())
    for params, generator_list in registered:
        for params_subclass in get_all_subclasses(params):
            self._generators.setdefault(params_subclass, generator_list)
def get_all_handlers(cls, opts):
    """Instantiate handler subclasses of ``cls`` and index them by params class.

    Every subclass with a truthy ``HANDLES`` is read as follows: ``HANDLES[0]``
    is an iterable of parameter classes, and ``HANDLES[1]`` is stored next to
    each instance. A fresh instance ``handler_cls(opts)`` is created for every
    parameter class listed.

    Returns:
        dict mapping parameter class -> list of (handler instance, HANDLES[1]).
    """
    registry = {}
    for sub in get_all_subclasses(cls):
        if not sub.HANDLES:
            continue
        for params_cls in sub.HANDLES[0]:
            bucket = registry.setdefault(params_cls, [])
            bucket.append((sub(opts), sub.HANDLES[1]))
    return registry
def deco(cls):
    """Give ``cls`` its own OPTIONS dict extended with the options in ``args``.

    ``args`` comes from the enclosing scope. Each element is an option dict,
    and its ``'name'`` entry is used as the key.

    Returns:
        ``cls`` itself, so this can be used as a class decorator.
    """
    # copy the closest base class options so that we create
    # a new class variable on this class instead of mutating the inherited one
    setattr(cls, "OPTIONS", deepcopy(getattr(cls, "OPTIONS")))
    # Now add / override options
    cls_opts = getattr(cls, "OPTIONS")
    new_opts = {opt['name']: opt for opt in args}
    cls_opts.update(new_opts)
    # Since classes can be initialized in an arbitrary order, copy the new
    # options to subclasses that already own a separate OPTIONS dict.
    # Subclasses still resolving to this class's dict already see the update.
    for subcls in get_all_subclasses(cls):
        sub_cls_opts = getattr(subcls, "OPTIONS")
        if sub_cls_opts is not cls_opts:
            sub_cls_opts.update(new_opts)
    return cls
def get_all_options():
    """Collect every quantization option declared by any QuantizionHandler.

    Returns:
        dict mapping option name to a merged record. The record holds the
        fields of the option definition plus a ``'handlers'`` set of the
        handler classes that declare the option. For ``'help'``, the first
        definition seen is kept.

    Raises:
        ValueError: if two handlers declare the same option with different
            values for any field other than ``'help'``.
    """
    options = {}
    for handler in get_all_subclasses(QuantizionHandler):
        if not handler.OPTIONS:
            continue
        for opt_name, opt in handler.OPTIONS.items():
            optrec = options.setdefault(opt_name, {'handlers': set()})
            for field, value in opt.items():
                if field not in optrec:
                    optrec[field] = value
                elif field != 'help' and value != optrec[field]:
                    # report the option name, not just the conflicting field
                    raise ValueError(
                        f'Quantization option {opt_name} has different '
                        f'definitions for {field}'
                    )
            optrec['handlers'].add(handler)
    return options
LOG = logging.getLogger("nntool." + __name__) def general_validation(match: Matcher): if match.DESCRIPTION is None: LOG.warning('matcher %s has no description', match.NAME) if match.NAME is None: raise ValueError(f'match {match.NAME} has no name') if '*' in match.RUN_BEFORE and '*' in match.RUN_AFTER: raise ValueError( f'match {match.NAME} has wildcard in run_before and run_after') return match ALL_MATCHERS = [general_validation(match_class) for match_class in get_all_subclasses(Matcher) if match_class.NAME is not None] def select_matchers(group=None): return [match_class for match_class in ALL_MATCHERS if (group is None or '*' in match_class.GROUPS or group in match_class.GROUPS)] def order_matchers(matchers): first_matchers = [match for match in matchers if '*' in match.RUN_BEFORE] last_matchers = [match for match in matchers if '*' in match.RUN_AFTER] rest = [match for match in matchers if match not in first_matchers + last_matchers] rest_sorted = [] while rest: