Example #1
    def log_model_buffers(self, model, buffer_names, tag_prefix, epoch,
                          completed, total, freq):
        """Logs values of model buffers.

        Notes:
            1. Buffers are logged separately per layer (i.e. per module) within the model
            2. All values in a single buffer are logged such that they will be displayed on the same graph in
               TensorBoard
            3. Similarly, if multiple buffers are provided in buffer_names, all are presented on the same graph.
               If this is undesirable, call the function separately for each buffer
            4. USE WITH CAUTION: While sometimes desirable, displaying multiple distinct values on a single
               graph isn't well supported in TensorBoard. It is achieved using a workaround, which slows
               down TensorBoard loading time considerably as the number of distinct values increases.
               Therefore, although not enforced, this function is only meant for use with a very small number
               of buffers and/or values, e.g. 2-5.

        """
        for module_name, module in model.named_modules():
            # Buffers are logged only for leaf modules; containers are skipped
            if utils.has_children(module):
                continue

            sd = module.state_dict()
            values = []
            for buf_name in buffer_names:
                try:
                    values += sd[buf_name].view(-1).tolist()
                except KeyError:
                    continue

            if values:
                tag = '/'.join([tag_prefix, module_name])
                self.tblogger.list_summary(tag, values,
                                           total * epoch + completed,
                                           len(values) > 1)
        self.tblogger.sync_to_file()
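A minimal usage sketch, assuming logger is an object exposing this method (the model and arguments are illustrative; BatchNorm layers register running_mean and running_var buffers, which is what makes them show up in each leaf module's state_dict here):

    import torchvision

    model = torchvision.models.resnet18()
    # Track BatchNorm running statistics; both buffers share one graph (note 3)
    logger.log_model_buffers(model, ['running_mean', 'running_var'],
                             tag_prefix='bn_stats', epoch=0,
                             completed=100, total=100, freq=1)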
Example #2
 def _dedicated_module_check(self, n, dedicated_modules_only=False):
     if not dedicated_modules_only:
         return True
     # An op has a "dedicated" module if that module is a leaf and is not
     # shared with any other op in the graph
     module_name = self.ops[n]['module-name']
     module = self._named_modules[module_name]
     return (len(self.module_ops_map[module_name]) == 1
             and not utils.has_children(module))
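Every example in this collection gates on utils.has_children to tell leaf modules apart from containers. The helper itself isn't shown; below is a minimal sketch of the usual implementation (an assumption, not the verbatim CACP code):

    def has_children(module):
        # True if the module registers at least one child module,
        # i.e. it is a container rather than a leaf
        try:
            next(module.children())
            return True
        except StopIteration:
            return False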
Example #3
    def _should_collect(self, module):
        if module.cacp_name in self._dont_collect_list:
            return False
        # In general, we only collect stats for "leaf" modules.
        # We make an exception for models that were quantized with 'PostTrainLinearQuantizer'. In these
        # models, the quantized modules are actually wrappers of the original FP32 modules, so they are
        # NOT leaf modules - but we still want to track them.
        if utils.has_children(module) and not (
                is_post_train_quant_wrapper(module)
                or isinstance(module, QFunctionalWrapper)):
            return False
        if isinstance(module, torch.nn.Identity):
            return False

        register_all_class_types = not self.classes
        if register_all_class_types or isinstance(module, tuple(self.classes)):
            return True

        return False
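A sketch of how a predicate like this is typically consumed, assuming a collector that attaches forward hooks (register_hooks, _collect_stats and _handles are hypothetical names; each module is expected to carry the cacp_name attribute that _should_collect reads):

    def register_hooks(self, model):
        for module in model.modules():
            if self._should_collect(module):
                # Hook only the modules that passed the filter
                self._handles.append(
                    module.register_forward_hook(self._collect_stats))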
Example #4
def _ptq_convert_pass_replace_range_linear_wrappers(module):
    # Hacky deferred import for now to workaround circular dependency
    # TODO: Proper fix
    from . import RangeLinearQuantWrapper, RangeLinearEmbeddingWrapper

    reassign = OrderedDict()
    for n, m in module.named_children():
        new_m = m
        if isinstance(m, RangeLinearQuantWrapper):
            new_m = m.to_pytorch_quant(
                need_reduce_range(m.output_quant_settings.quant_mode,
                                  torch.quint8))

            requires_quantized_inputs = not (
                isinstance(new_m, nn.Sequential)
                and isinstance(new_m[0], ConditionalDeQuantizeWrapper))

            if requires_quantized_inputs:
                d = OrderedDict()
                for idx, qmd in m.inputs_quant_metadata_fallback.items():
                    qset = m.inputs_quant_settings_overrides.get(
                        idx, m.output_quant_settings)
                    scale, zp = qparams_to_pytorch(
                        qmd.scale, qmd.zero_point, qset.num_bits,
                        qset.quant_mode, torch.quint8,
                        need_reduce_range(qset.quant_mode, torch.quint8))
                    d[idx] = (scale, zp, torch.quint8)
                new_m = ConditionalQuantizeWrapper(new_m, d)
        elif isinstance(m, RangeLinearEmbeddingWrapper):
            new_m = m.to_pytorch_quant(
                need_reduce_range(m.wts_quant_settings.quant_mode,
                                  torch.quint8))
        elif utils.has_children(m):
            new_m = _ptq_convert_pass_replace_range_linear_wrappers(m)
        elif not isinstance(m, nn.Identity):
            # Module not quantized in CACP, possibly need to de-quant input
            new_m = ConditionalDeQuantizeWrapper(m)
        reassign[n] = new_m

    for n, new_m in reassign.items():
        module._modules[n] = new_m

    return module
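For reference, the per-input override dict d built above maps an input index to a (scale, zero_point, dtype) triple, which ConditionalQuantizeWrapper uses to quantize any input that arrives unquantized. An illustration with made-up values (inner_module is a placeholder):

    from collections import OrderedDict
    import torch

    d = OrderedDict()
    d[0] = (0.02, 128, torch.quint8)  # quantize input 0: scale 0.02, zero-point 128
    wrapped = ConditionalQuantizeWrapper(inner_module, d)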
Example #5
    def cleanup(module):
        reassign = OrderedDict()
        for n, m in module.named_children():
            new_m = m
            if isinstance(m, ConditionalQuantizeWrapper):
                for idx in m.quant.already_quantized:
                    if str(idx) in m.quant.quantizers:
                        m.quant.quantizers.pop(str(idx))
                if len(m.quant.quantizers) == 0:
                    new_m = m.wrapped
            elif isinstance(m, ConditionalDeQuantizeWrapper):
                if not m.dequant.any_quantized:
                    new_m = m.wrapped
            elif utils.has_children(m):
                # Recurse into child containers; cleanup mutates them in place
                cleanup(m)
            reassign[n] = new_m
        for n, new_m in reassign.items():
            module._modules[n] = new_m

        return module
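The two passes compose naturally: first replace the wrappers (Example #4), then strip the conditional quant/dequant wrappers whose conditions can never fire. A sketch, assuming cleanup is reachable from the same scope as the conversion pass:

    model = _ptq_convert_pass_replace_range_linear_wrappers(model)
    model = cleanup(model)  # unwrap Conditional(De)QuantizeWrappers that became no-ops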
Example #6
 def convert_container(container):
     # To maintain a similar order of registered modules compared to the original container, we unregister
     # all modules and then register them again
     # We take care to include duplicated modules, which are not returned by the original
     # named_modules/children implementations in torch.nn.Module
     named_children = OrderedDict(
         _named_children_with_duplicates(container))
     for n, _ in named_children.items():
         delattr(container, n)
     for name, child in named_children.items():
         if isinstance(child, nn.ModuleList):
             child = _ModuleList(name, container, child)
             to_check = child.modules()
         else:
             to_check = [child]
         setattr(container, name, child)
         for m in to_check:
             if isinstance(m, _ModuleList):
                 continue
             if utils.has_children(m):
                 convert_container(m)
     return container
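_named_children_with_duplicates is not shown. torch.nn.Module.named_children deduplicates by object identity, so a plausible sketch of the helper (an assumption about the CACP implementation) is:

    def _named_children_with_duplicates(module):
        # Unlike nn.Module.named_children, yield every registered child,
        # including the same module object registered under several names
        for name, child in module._modules.items():
            if child is not None:
                yield name, child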
Example #7
    def _pre_process_container(self, container, prefix=''):
        def replace_msg(module_name, modules=None):
            msglogger.debug('Module ' + module_name)
            if modules:
                msglogger.debug('\tReplacing: {}.{}'.format(
                    modules[0].__module__, modules[0].__class__.__name__))
                msglogger.debug('\tWith:      {}.{}'.format(
                    modules[1].__module__, modules[1].__class__.__name__))
            else:
                msglogger.debug('\tSkipping')

        # Iterate through model, insert quantization functions as appropriate
        for name, module in container.named_children():
            full_name = prefix + name
            if isinstance(module, tuple(self.replacement_blacklist)):
                replace_msg(full_name)
                continue
            if module in self.modules_processed:
                previous_name, previous_wrapper = self.modules_processed[module]
                warnings.warn(
                    "Module '{0}' references the same module as '{1}'."
                    ' Replacing with a reference to the same wrapper.'.format(
                        full_name, previous_name), UserWarning)
                if previous_wrapper:
                    replace_msg(full_name, (module, previous_wrapper))
                    setattr(container, name, previous_wrapper)
                else:
                    replace_msg(full_name)
                continue
            current_qbits = self.module_qbits_map[full_name]
            # TODO - Review necessity of the block below
            if (current_qbits.acts is None and current_qbits.wts is None
                    and not self.module_overrides_map[full_name]):
                # We indicate this module wasn't replaced by a wrapper
                replace_msg(full_name)
                self.modules_processed[module] = full_name, None
            else:
                # We use a type hint comment to let IDEs know replace_fn is a function
                replace_fn = self.replacement_factory.get(
                    type(module),
                    self.default_repalcement_fn)  # type: Optional[Callable]
                # If the replacement function wasn't specified - continue without replacing this module.
                if replace_fn is not None:
                    valid_kwargs, invalid_kwargs = utils.filter_kwargs(
                        self.module_overrides_map[full_name], replace_fn)
                    if invalid_kwargs:
                        raise TypeError(
                            'Quantizer of type %s doesn\'t accept "%s" as'
                            ' override arguments for %s. Allowed kwargs: %s'
                            % (type(self), list(invalid_kwargs), type(module),
                               list(valid_kwargs)))
                    new_module = replace_fn(module, full_name,
                                            self.module_qbits_map,
                                            **valid_kwargs)
                    if new_module != module:
                        replace_msg(full_name, (module, new_module))
                        # Add to history of prepared submodules
                        self.modules_processed[module] = full_name, new_module
                        # To allow recreating this wrapper later on
                        valid_args = full_name, deepcopy(self.module_qbits_map)
                        self.modules_processed_args[full_name] = (
                            valid_args, valid_kwargs)
                        setattr(container, name, new_module)

                        # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
                        if not utils.has_children(
                                module) and utils.has_children(new_module):
                            for sub_module_name, sub_module in new_module.named_modules(
                            ):
                                self._add_qbits_entry(
                                    full_name + '.' + sub_module_name,
                                    type(sub_module), current_qbits)
                            self.module_qbits_map[full_name] = QBits(
                                acts=current_qbits.acts, wts=None, bias=None)
                    else:
                        replace_msg(full_name)
                        self.modules_processed[module] = full_name, None

            if utils.has_children(module):
                # For container we call recursively
                self._pre_process_container(module, full_name + '.')
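The recursion is typically seeded at the model root with the default empty prefix, e.g. from the quantizer's model-preparation step (the method name below is illustrative):

    def prepare_model(self):
        # Walk the whole model; submodules are handled by the
        # recursive call at the bottom of _pre_process_container
        self._pre_process_container(self.model)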