def log_model_buffers(self, model, buffer_names, tag_prefix, epoch, completed, total, freq):
    """Logs values of model buffers.

    Notes:
        1. Buffers are logged separately per layer (i.e. per module) within the model
        2. All values in a single buffer are logged such that they will be displayed on the same
           graph in TensorBoard
        3. Similarly, if multiple buffers are provided in buffer_names, all are presented on the
           same graph. If this is undesirable, call the function separately for each buffer
        4. USE WITH CAUTION: While sometimes desirable, displaying multiple distinct values in a
           single graph isn't well supported in TensorBoard. It is achieved using a work-around,
           which slows down TensorBoard loading time considerably as the number of distinct values
           increases. Therefore, while not limited, this function is only meant for use with a very
           limited number of buffers and/or values, e.g. 2-5.
    """
    for module_name, module in model.named_modules():
        if utils.has_children(module):
            continue
        sd = module.state_dict()
        values = []
        for buf_name in buffer_names:
            try:
                values += sd[buf_name].view(-1).tolist()
            except KeyError:
                continue
        if values:
            tag = '/'.join([tag_prefix, module_name])
            self.tblogger.list_summary(tag, values, total * epoch + completed, len(values) > 1)
    self.tblogger.sync_to_file()
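# Usage sketch (illustrative only, not part of the original API): 'running_mean' /
# 'running_var' are the standard BatchNorm buffer names; the `logger` object and the
# surrounding training-loop variables are assumptions.
def _example_log_bn_stats(logger, model, epoch, steps_per_epoch):
    """Log BatchNorm running statistics once at the end of an epoch (sketch)."""
    logger.log_model_buffers(model,
                             buffer_names=['running_mean', 'running_var'],
                             tag_prefix='bn_stats',
                             epoch=epoch,
                             completed=steps_per_epoch,
                             total=steps_per_epoch,
                             freq=1)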
def _dedicated_module_check(self, n, dedicated_modules_only=False):
    # A module is "dedicated" if it is a leaf module that is used by a single op only
    if not dedicated_modules_only:
        return True
    module_name = self.ops[n]['module-name']
    module = self._named_modules[module_name]
    return len(self.module_ops_map[module_name]) == 1 and not utils.has_children(module)
def _should_collect(self, module):
    if module.cacp_name in self._dont_collect_list:
        return False
    # In general, we only collect stats for "leaf" modules.
    # We make an exception for models that were quantized with 'PostTrainLinearQuantizer'. In these
    # models, the quantized modules are actually wrappers of the original FP32 modules, so they are
    # NOT leaf modules - but we still want to track them.
    if utils.has_children(module) and not (is_post_train_quant_wrapper(module) or
                                           isinstance(module, QFunctionalWrapper)):
        return False
    if isinstance(module, torch.nn.Identity):
        return False
    register_all_class_types = not self.classes
    if register_all_class_types or isinstance(module, tuple(self.classes)):
        return True
    return False
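# Illustration (simplified re-implementation, hypothetical helper): which modules a
# leaf-only, class-filtered collector would end up tracking. This mirrors only the
# leaf/Identity/class checks above, using stock PyTorch.
def _example_collected_modules(model, classes=()):
    collected = []
    for name, m in model.named_modules():
        if len(list(m.children())) > 0:  # skip containers (non-leaf modules)
            continue
        if isinstance(m, torch.nn.Identity):
            continue
        if not classes or isinstance(m, tuple(classes)):
            collected.append(name)
    return collected

# e.g. _example_collected_modules(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),
#      classes=[nn.Conv2d]) returns ['0'] - the ReLU is filtered out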
def _ptq_convert_pass_replace_range_linear_wrappers(module):
    # Hacky deferred import for now to work around a circular dependency
    # TODO: Proper fix
    from . import quantization

    reassign = OrderedDict()
    for n, m in module.named_children():
        new_m = m
        if isinstance(m, quantization.RangeLinearQuantWrapper):
            new_m = m.to_pytorch_quant(
                need_reduce_range(m.output_quant_settings.quant_mode, torch.quint8))

            requires_quantized_inputs = not (isinstance(new_m, nn.Sequential) and
                                             isinstance(new_m[0], ConditionalDeQuantizeWrapper))
            if requires_quantized_inputs:
                d = OrderedDict()
                for idx, qmd in m.inputs_quant_metadata_fallback.items():
                    qset = m.inputs_quant_settings_overrides.get(idx, m.output_quant_settings)
                    scale, zp = qparams_to_pytorch(qmd.scale, qmd.zero_point, qset.num_bits,
                                                   qset.quant_mode, torch.quint8,
                                                   need_reduce_range(qset.quant_mode, torch.quint8))
                    d[idx] = (scale, zp, torch.quint8)
                new_m = ConditionalQuantizeWrapper(new_m, d)
        elif isinstance(m, quantization.RangeLinearEmbeddingWrapper):
            new_m = m.to_pytorch_quant(
                need_reduce_range(m.wts_quant_settings.quant_mode, torch.quint8))
        elif utils.has_children(m):
            new_m = _ptq_convert_pass_replace_range_linear_wrappers(m)
        elif not isinstance(m, nn.Identity):
            # Module not quantized in CACP, possibly need to de-quant input
            new_m = ConditionalDeQuantizeWrapper(m)
        reassign[n] = new_m

    for n, new_m in reassign.items():
        module._modules[n] = new_m

    return module
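# Illustration (stock PyTorch only): how the (scale, zero_point, torch.quint8) triplets
# collected above are consumed at runtime. The wrapper classes are part of this codebase;
# the function below is a hypothetical stand-in showing the equivalent native call.
def _example_quantize_input(x, scale, zero_point):
    """Quantize a float tensor with pre-computed PyTorch qparams (sketch)."""
    q = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point,
                                  dtype=torch.quint8)
    return q  # q.int_repr() exposes the raw uint8 representation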
def cleanup(module):
    reassign = OrderedDict()
    for n, m in module.named_children():
        new_m = m
        if isinstance(m, ConditionalQuantizeWrapper):
            for idx in m.quant.already_quantized:
                if str(idx) in m.quant.quantizers:
                    m.quant.quantizers.pop(str(idx))
            if len(m.quant.quantizers) == 0:
                new_m = m.wrapped
        elif isinstance(m, ConditionalDeQuantizeWrapper):
            if not m.dequant.any_quantized:
                new_m = m.wrapped
        elif utils.has_children(m):
            cleanup(m)
        reassign[n] = new_m
    for n, new_m in reassign.items():
        module._modules[n] = new_m
    return module
def convert_container(container):
    # To maintain a similar order of registered modules compared to the original container,
    # we unregister all modules and then register them again.
    # We take care to include duplicated modules, which are not returned by the original
    # named_modules/children implementation in torch.nn.Module
    named_children = OrderedDict(_named_children_with_duplicates(container))
    for n, _ in named_children.items():
        delattr(container, n)
    for name, child in named_children.items():
        if isinstance(child, nn.ModuleList):
            child = _ModuleList(name, container, child)
            to_check = child.modules()
        else:
            to_check = [child]
        setattr(container, name, child)
        for m in to_check:
            if isinstance(m, _ModuleList):
                continue
            if utils.has_children(m):
                convert_container(m)
    return container
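# Why a duplicates-aware iterator is needed (runnable sketch, stock PyTorch only):
# nn.Module.named_children() yields each child *instance* at most once, even when the same
# instance is registered under two attribute names, so a plain iteration would silently
# drop the second registration. The internal _modules dict still holds both entries.
def _example_duplicate_children():
    shared = nn.ReLU()
    container = nn.Module()
    container.act1 = shared
    container.act2 = shared  # same instance registered twice
    print([n for n, _ in container.named_children()])  # ['act1']
    print(list(container._modules.keys()))             # ['act1', 'act2']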
def _pre_process_container(self, container, prefix=''):
    def replace_msg(module_name, modules=None):
        msglogger.debug('Module ' + module_name)
        if modules:
            msglogger.debug('\tReplacing: {}.{}'.format(modules[0].__module__,
                                                        modules[0].__class__.__name__))
            msglogger.debug('\tWith:      {}.{}'.format(modules[1].__module__,
                                                        modules[1].__class__.__name__))
        else:
            msglogger.debug('\tSkipping')

    # Iterate through model, insert quantization functions as appropriate
    for name, module in container.named_children():
        full_name = prefix + name
        if isinstance(module, tuple(self.replacement_blacklist)):
            replace_msg(full_name)
            continue
        if module in self.modules_processed:
            previous_name, previous_wrapper = self.modules_processed[module]
            warnings.warn("Module '{0}' references the same module as '{1}'."
                          ' Replacing with a reference to the same wrapper.'.format(full_name, previous_name),
                          UserWarning)
            if previous_wrapper:
                replace_msg(full_name, (module, previous_wrapper))
                setattr(container, name, previous_wrapper)
            else:
                replace_msg(full_name)
            continue
        current_qbits = self.module_qbits_map[full_name]
        # TODO - Review necessity of the block below
        if current_qbits.acts is None and current_qbits.wts is None and \
                not self.module_overrides_map[full_name]:
            # We indicate this module wasn't replaced by a wrapper
            replace_msg(full_name)
            self.modules_processed[module] = full_name, None
        else:
            # We use a type hint comment to let IDEs know replace_fn is a function
            replace_fn = self.replacement_factory.get(type(module),
                                                      self.default_repalcement_fn)  # type: Optional[Callable]
            # If the replacement function wasn't specified - continue without replacing this module
            if replace_fn is not None:
                valid_kwargs, invalid_kwargs = utils.filter_kwargs(self.module_overrides_map[full_name],
                                                                   replace_fn)
                if invalid_kwargs:
                    raise TypeError('Quantizer of type %s doesn\'t accept "%s" as override arguments for %s.'
                                    ' Allowed kwargs: %s'
                                    % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
                new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
                if new_module != module:
                    replace_msg(full_name, (module, new_module))
                    # Add to history of prepared submodules
                    self.modules_processed[module] = full_name, new_module
                    # To allow recreating this wrapper later on
                    valid_args = full_name, deepcopy(self.module_qbits_map)
                    self.modules_processed_args[full_name] = valid_args, valid_kwargs
                    setattr(container, name, new_module)

                    # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
                    if not utils.has_children(module) and utils.has_children(new_module):
                        for sub_module_name, sub_module in new_module.named_modules():
                            self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module),
                                                  current_qbits)
                        self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)
                else:
                    replace_msg(full_name)
                    self.modules_processed[module] = full_name, None

        if utils.has_children(module):
            # For containers we recurse
            self._pre_process_container(module, full_name + '.')
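# Minimal sketch of the replacement-factory pattern implemented above (all names here are
# hypothetical): a dict maps module types to functions that build a replacement, and
# setattr() swaps each child in place, recursing into containers that have no replacement
# of their own.
def _example_replace_children(container, factory):
    """Recursively replace children per `factory` ({type: fn(module) -> module})."""
    for name, child in container.named_children():
        replace_fn = factory.get(type(child))
        if replace_fn is not None:
            setattr(container, name, replace_fn(child))
        else:
            _example_replace_children(child, factory)

# e.g. _example_replace_children(model, {nn.ReLU: lambda m: nn.ReLU6(m.inplace)})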