Example #1
    def _pre_process_container(self, container, prefix=''):
        # Iterate through model, insert quantization functions as appropriate
        for name, module in container.named_children():
            full_name = prefix + name
            current_qbits = self.module_qbits_map[full_name]
            if current_qbits.acts is None and current_qbits.wts is None:
                continue
            try:
                new_module = self.replacement_factory[type(module)](
                    module, full_name, self.module_qbits_map)
                msglogger.debug(
                    'Module {0}: Replacing \n{1} with \n{2}'.format(
                        full_name, module, new_module))
                setattr(container, name, new_module)

                # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
                if not distiller.has_children(module) and distiller.has_children(new_module):
                    for sub_module_name, sub_module in new_module.named_modules():
                        self._add_qbits_entry(
                            full_name + '.' + sub_module_name,
                            type(sub_module), current_qbits)
                    self.module_qbits_map[full_name] = QBits(
                        acts=current_qbits.acts, wts=None)
            except KeyError:
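                # type(module) has no entry in replacement_factory - leave the module unreplaced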
                pass

            if distiller.has_children(module):
                # For container we call recursively
                self._pre_process_container(module, full_name + '.')
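Every example in this collection gates its logic on distiller.has_children. For orientation, here is a minimal sketch of what such a helper amounts to; this is an assumption about its behavior, not a quote of the Distiller source:

def has_children(module):
    # A module counts as a container exactly when it has at least one child
    try:
        next(module.children())
        return True
    except StopIteration:
        return False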
Example #2
    def _pre_process_container(self, container, prefix=''):
        # Iterate through model, insert quantization functions as appropriate
        for name, module in container.named_children():
            full_name = prefix + name
            if module in self.modules_replaced:
                previous_name, previous_wrapper = self.modules_replaced[module]
                warnings.warn(
                    "Module '{0}' references the same module as '{1}'."
                    ' Replacing with a reference to the same wrapper.'.format(
                        full_name, previous_name), UserWarning)
                msglogger.debug(
                    'Module {0}: Replacing \n{1} with \n{2}'.format(
                        full_name, module, previous_wrapper))
                setattr(container, name, previous_wrapper)
                continue
            current_qbits = self.module_qbits_map[full_name]
            if current_qbits.acts is None and current_qbits.wts is None:
                if self.module_overrides_map[full_name]:
                    raise ValueError(
                        "Adding overrides while not quantizing is not allowed."
                    )
                continue

            # We use a type hint comment to let IDEs know replace_fn is a function
            replace_fn = self.replacement_factory[type(module)]  # type: Optional[Callable]
            # If the replacement function wasn't specified - continue without replacing this module.
            if replace_fn is not None:
                valid_kwargs, invalid_kwargs = distiller.filter_kwargs(
                    self.module_overrides_map[full_name], replace_fn)
                if invalid_kwargs:
                    raise TypeError(
                        'Quantizer of type %s doesn\'t accept "%s" as override arguments '
                        'for %s. Allowed kwargs: %s' %
                        (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
                new_module = replace_fn(module, full_name,
                                        self.module_qbits_map, **valid_kwargs)
                msglogger.debug(
                    'Module {0}: Replacing \n{1} with \n{2}'.format(
                        full_name, module, new_module))
                # Add to history of prepared submodules
                self.modules_replaced[module] = full_name, new_module
                setattr(container, name, new_module)

                # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
                if not distiller.has_children(module) and distiller.has_children(new_module):
                    for sub_module_name, sub_module in new_module.named_modules():
                        self._add_qbits_entry(
                            full_name + '.' + sub_module_name,
                            type(sub_module), current_qbits)
                    self.module_qbits_map[full_name] = QBits(
                        acts=current_qbits.acts, wts=None, bias=None)

            if distiller.has_children(module):
                # For container we call recursively
                self._pre_process_container(module, full_name + '.')
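The modules_replaced bookkeeping above guards against one module object being reachable under several names. A small hypothetical model that triggers exactly that case (names and layers are illustrative only):

import torch.nn as nn

class SharedActivation(nn.Module):
    # One ReLU instance is registered inside two containers, so the recursive
    # walk meets the same object twice under two full names; keying the
    # replacement history by the module object binds both names to a single
    # shared wrapper instead of wrapping the module twice.
    def __init__(self):
        super().__init__()
        relu = nn.ReLU()
        self.block1 = nn.Sequential(nn.Conv2d(3, 8, 3), relu)
        self.block2 = nn.Sequential(nn.Conv2d(8, 8, 3), relu)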
Example #3
 def dedicated_module_check(n):
     if not dedicated_modules_only:
         return True
     module_name = self.ops[n]['module-name']
     module = named_modules[module_name]
     return (len(self.module_ops_map[module_name]) == 1
             and not distiller.has_children(module))
Example #4
def _ptq_convert_pass_replace_range_linear_wrappers(module):
    # Hacky deferred import for now to work around a circular dependency
    # TODO: Proper fix
    from distiller.quantization import RangeLinearQuantWrapper

    reassign = OrderedDict()
    for n, m in module.named_children():
        new_m = m
        if isinstance(m, RangeLinearQuantWrapper):
            new_m = m.to_pytorch_quant(need_reduce_range(m.output_quant_settings.quant_mode, torch.quint8))

            requires_quantized_inputs = not (isinstance(new_m, nn.Sequential) and
                                             isinstance(new_m[0], ConditionalDeQuantizeWrapper))

            if requires_quantized_inputs:
                d = OrderedDict()
                for idx, qmd in m.inputs_quant_metadata_fallback.items():
                    qset = m.inputs_quant_settings_overrides.get(idx, m.output_quant_settings)
                    scale, zp = distiller_qparams_to_pytorch(qmd.scale, qmd.zero_point, qset.num_bits,
                                                             qset.quant_mode, torch.quint8,
                                                             need_reduce_range(qset.quant_mode, torch.quint8))
                    d[idx] = (scale, zp, torch.quint8)
                new_m = ConditionalQuantizeWrapper(new_m, d)
        elif distiller.has_children(m):
            new_m = _ptq_convert_pass_replace_range_linear_wrappers(m)
        elif not isinstance(m, nn.Identity):
            # Module not quantized in Distiller, possibly need to de-quant input
            new_m = ConditionalDeQuantizeWrapper(m)
        reassign[n] = new_m

    for n, new_m in reassign.items():
        module._modules[n] = new_m

    return module
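The reassign dict exists so that module._modules is not mutated while named_children() is still iterating over it. The same deferred-reassignment pattern, reduced to a self-contained sketch (the ReLU-to-ReLU6 swap is just an illustrative stand-in for the quantization wrappers):

from collections import OrderedDict
import torch.nn as nn

def replace_relu_with_relu6(module):
    # Phase 1: record the desired replacements without touching the container
    reassign = OrderedDict()
    for name, child in module.named_children():
        new_child = child
        if isinstance(child, nn.ReLU):
            new_child = nn.ReLU6(inplace=child.inplace)
        elif len(list(child.children())) > 0:
            replace_relu_with_relu6(child)
        reassign[name] = new_child
    # Phase 2: commit the replacements after iteration has finished
    for name, new_child in reassign.items():
        module._modules[name] = new_child
    return module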
Example #5
    def log_model_buffers(self, model, buffer_names, tag_prefix, epoch,
                          completed, total, freq):
        """Logs values of model buffers.

        Notes:
            1. Buffers are logged separately per-layer (i.e. module) within the model
            2. All values in a single buffer are logged such that they will be displayed on the same graph in
               TensorBoard
            3. Similarly, if multiple buffers are provided in buffer_names, all are presented on the same graph.
               If this is undesirable, call the function separately for each buffer
            4. USE WITH CAUTION: While sometimes desirable, displaying multiple distinct values in a single
               graph isn't well supported in TensorBoard. It is achieved using a work-around, which slows
               down TensorBoard loading time considerably as the number of distinct values increases.
               Therefore, although no hard limit is imposed, this function is only meant for use with a very
               small number of buffers and/or values, e.g. 2-5.

        """
        for module_name, module in model.named_modules():
            if distiller.has_children(module):
                continue

            sd = module.state_dict()
            values = []
            for buf_name in buffer_names:
                try:
                    values += sd[buf_name].view(-1).tolist()
                except KeyError:
                    continue

            if values:
                tag = '/'.join([tag_prefix, module_name])
                self.tblogger.list_summary(tag, values,
                                           total * epoch + completed,
                                           len(values) > 1)
        self.tblogger.sync_to_file()
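A hypothetical invocation, logging the running statistics of every BatchNorm layer; tb_logger, epoch, completed and total stand in for whatever the surrounding training loop provides:

tb_logger.log_model_buffers(model,
                            buffer_names=['running_mean', 'running_var'],
                            tag_prefix='bn_stats',
                            epoch=epoch, completed=completed, total=total,
                            freq=1)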
Example #6
def test_param_quantization(model, optimizer, qbits, overrides,
                            explicit_expected_overrides, train_with_fp_copy):
    # Build expected QBits
    expected_qbits, post_prepare_changes = get_expected_qbits(
        model, qbits, explicit_expected_overrides)

    q = DummyQuantizer(model,
                       optimizer=optimizer,
                       bits_activations=qbits.acts,
                       bits_weights=qbits.wts,
                       bits_bias=qbits.bias,
                       overrides=deepcopy(overrides),
                       train_with_fp_copy=train_with_fp_copy)
    q.prepare_model()
    expected_qbits.update(post_prepare_changes)

    q_model_pre_quant = deepcopy(model)
    q.quantize_params()
    for (name, pre_quant_module), post_quant_module in zip(
            q_model_pre_quant.named_modules(), model.modules()):
        # Skip containers
        if has_children(pre_quant_module):
            continue

        num_qbits = expected_qbits[name].wts

        for param_name, pre_quant_param in pre_quant_module.named_parameters():
            quantizable = num_qbits is not None
            if param_name.endswith('bias'):
                num_bits = expected_qbits[name].bias
            else:
                num_bits = num_qbits

            if quantizable and train_with_fp_copy:
                # "param_name" and "pre_quant_param" refer to the float copy

                # Check the float copy didn't change
                post_quant_fp_copy = getattr(post_quant_module, param_name)
                assert torch.equal(pre_quant_param, post_quant_fp_copy)

                quant_param = getattr(post_quant_module,
                                      param_name.replace(FP_BKP_PREFIX, ''))

                # Check weights quantization properly recorded for autograd
                gfn = quant_param.grad_fn
                assert gfn is not None
                assert str(type(gfn).__name__) == 'AddBackward0'
                gfn = gfn.next_functions[0][0]
                assert str(type(gfn).__name__) == 'AccumulateGrad'
                assert id(gfn.variable) == id(post_quant_fp_copy)
            else:
                quant_param = getattr(post_quant_module, param_name)

            if quantizable:
                expected = dummy_quantize_params(
                    pre_quant_param,
                    _ParamToQuant(None, None, None, None, num_bits))
            else:
                expected = pre_quant_param
            assert torch.equal(quant_param, expected)
Example #7
 def start_laplace(self):
     self._check_required_stats()
     self.collecting_laplace = True
     # reset batch_idx for all leaf modules
     for module in self.model.modules():
         if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
             continue
         module.batch_idx = 0
Example #8
    def _pre_process_container(self, container, prefix=''):
        # Iterate through model, insert quantization functions as appropriate
        for name, module in container.named_children():
            full_name = prefix + name
            current_qbits = self.module_qbits_map[full_name]
            if current_qbits.acts is None and current_qbits.wts is None:
                if self.module_overrides_map[full_name]:
                    raise ValueError(
                        "Adding overrides while not quantizing is not allowed."
                    )
                continue
            try:
                replace_fn = self.replacement_factory[type(module)]
                valid_kwargs, invalid_kwargs = distiller.filter_kwargs(
                    self.module_overrides_map[full_name], replace_fn)
                if invalid_kwargs:
                    raise TypeError(
                        'Quantizer of type %s doesn\'t accept "%s" as override arguments '
                        'for %s. Allowed kwargs: %s' %
                        (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
                new_module = replace_fn(module, full_name,
                                        self.module_qbits_map, **valid_kwargs)
                msglogger.debug(
                    'Module {0}: Replacing \n{1} with \n{2}'.format(
                        full_name, module, new_module))
                setattr(container, name, new_module)

                # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
                if not distiller.has_children(module) and distiller.has_children(new_module):
                    for sub_module_name, sub_module in new_module.named_modules():
                        self._add_qbits_entry(
                            full_name + '.' + sub_module_name,
                            type(sub_module), current_qbits)
                    self.module_qbits_map[full_name] = QBits(
                        acts=current_qbits.acts, wts=None, bias=None)
            except KeyError:
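                # type(module) has no entry in replacement_factory - leave the module unreplaced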
                pass

            if distiller.has_children(module):
                # For container we call recursively
                self._pre_process_container(module, full_name + '.')
Example #9
def replace_quantize_module(model):
    for name, module in model.named_children():
        if isinstance(module, ClippedLinearQuantization):
            setattr(
                model, name,
                ClippedOnly(num_bits=module.num_bits,
                            clip_val=module.clip_val,
                            inplace=module.inplace))

        if distiller.has_children(module):
            replace_quantize_module(module)
Example #10
 def start_second_pass(self):
     self._check_required_stats()
     self.collecting_second_pass = True
     # reset batch_idx for all leaf modules
     for module in self.model.modules():
         if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
             continue
         module.batch_idx = 0
         for record in module.quant_stats.inputs:
             record['total_numel'] = 0
         module.quant_stats.output['total_numel'] = 0
Example #11
    def start_module(self, module):
        """Iteratively register to the forward-pass callback of all eligible modules.

        Eligible modules are currently filtered by their class type.
        """
        if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
            return
        register_all_class_types = not self.classes
        if register_all_class_types or isinstance(module, tuple(self.classes)):
            self.fwd_hook_handles.append(
                module.register_forward_hook(self._activation_stats_cb))
            self._start_counter(module)
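The collector builds on PyTorch's standard forward-hook mechanism: a callback registered on a module fires after each of its forward passes. A self-contained illustration of that mechanism (plain PyTorch, not Distiller code):

import torch
import torch.nn as nn

def print_output_mean(module, inputs, output):
    # (module, inputs, output) is the signature PyTorch passes to forward hooks
    print(module.__class__.__name__, output.detach().mean().item())

layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(print_output_mean)
layer(torch.randn(1, 4))
handle.remove()  # analogous to releasing the handles kept in self.fwd_hook_handles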
Example #12
    def _should_collect(self, module):
        if module.distiller_name in self._dont_collect_list:
            return False
        # In general, we only collect stats for "leaf" modules.
        # We make an exception for models that were quantized with 'PostTrainLinearQuantizer'. In these
        # models, the quantized modules are actually wrappers of the original FP32 modules, so they are
        # NOT leaf modules - but we still want to track them.
        if distiller.has_children(module) and not is_post_train_quant_wrapper(module):
            return False
        if isinstance(module, torch.nn.Identity):
            return False

        register_all_class_types = not self.classes
        if register_all_class_types or isinstance(module, tuple(self.classes)):
            return True

        return False
Example #13
 def convert_container(container):
     named_children = OrderedDict(container.named_children())
     # To maintain a similar order of registered modules compared to the original container, we unregister
     # all modules and then register them again
     for n, _ in named_children.items():
         delattr(container, n)
     for name, child in named_children.items():
         if isinstance(child, nn.ModuleList):
             child = _DistillerModuleList(name, container, child)
             to_check = child.modules()
         else:
             to_check = [child]
         setattr(container, name, child)
         for m in to_check:
             if isinstance(m, _DistillerModuleList):
                 continue
             if distiller.has_children(m):
                 convert_container(m)
     return container
Example #14
    def cleanup(module):
        reassign = OrderedDict()
        for n, m in module.named_children():
            new_m = m
            if isinstance(m, ConditionalQuantizeWrapper):
                for idx in m.quant.already_quantized:
                    if str(idx) in m.quant.quantizers:
                        m.quant.quantizers.pop(str(idx))
                if len(m.quant.quantizers) == 0:
                    new_m = m.wrapped
            elif isinstance(m, ConditionalDeQuantizeWrapper):
                if not m.dequant.any_quantized:
                    new_m = m.wrapped
            elif distiller.has_children(m):
                cleanup(m)
            reassign[n] = new_m
        for n, new_m in reassign.items():
            module._modules[n] = new_m

        return module
Example #15
 def convert_container(container):
     # To maintain a similar order of registered modules compared to the original container, we unregister
     # all modules and then register them again
     # We take care to include duplicated modules, which are not returned by the original named_modules/children
     # implementation in torch.nn.Module
     named_children = OrderedDict(_named_children_with_duplicates(container))
     for n, _ in named_children.items():
         delattr(container, n)
     for name, child in named_children.items():
         if isinstance(child, nn.ModuleList):
             child = _DistillerModuleList(name, container, child)
             to_check = child.modules()
         else:
             to_check = [child]
         setattr(container, name, child)
         for m in to_check:
             if isinstance(m, _DistillerModuleList):
                 continue
             if distiller.has_children(m):
                 convert_container(m)
     return container
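For context on why nn.ModuleList is singled out: it registers its entries under purely numeric names and has no forward() of its own, which is presumably what the transparent _DistillerModuleList conversion compensates for. A quick demonstration of the names it produces:

import torch.nn as nn

class Stack(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

print([name for name, _ in Stack().named_modules()])
# ['', 'layers', 'layers.0', 'layers.1', 'layers.2']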
Example #16
 def _check_required_stats(self):
     """
     Check whether the statistics required for collecting Laplace distribution stats have been gathered.
     """
     for name, module in self.model.named_modules():
         if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
             continue
         if not hasattr(module, 'quant_stats'):
             raise RuntimeError(
                 'Collection of Laplace distribution statistics is '
                 'only allowed after collection of stats has started.')
         for i, input_stats_record in enumerate(module.quant_stats.inputs):
             if 'mean' not in input_stats_record:
                 raise RuntimeError(
                     'The required stats for input[%d] in module "%s" were not collected. '
                     'Please collect the required statistics by calling `collector.start()` and then'
                     ' evaluating the model on enough batches.' % (i, name))
         if 'mean' not in module.quant_stats.output:
             raise RuntimeError(
                 'The required stats for the output in module "%s" were not collected. '
                 'Please collect the required statistics by calling `collector.start()` and then'
                 ' evaluating the model on enough batches.' % name)
Example #17
    def _pre_process_container(self, container, prefix=''):
        def replace_msg(module_name, modules=None):
            msglogger.debug('Module ' + module_name)
            if modules:
                msglogger.debug('\tReplacing: {}.{}'.format(
                    modules[0].__module__, modules[0].__class__.__name__))
                msglogger.debug('\tWith:      {}.{}'.format(
                    modules[1].__module__, modules[1].__class__.__name__))
            else:
                msglogger.debug('\tSkipping')

        # Iterate through model, insert quantization functions as appropriate
        for name, module in container.named_children():
            full_name = prefix + name
            if isinstance(module, tuple(self.replacement_blacklist)):
                replace_msg(full_name)
                continue
            if module in self.modules_processed:
                previous_name, previous_wrapper = self.modules_processed[module]
                warnings.warn(
                    "Module '{0}' references the same module as '{1}'."
                    ' Replacing with a reference to the same wrapper.'.format(
                        full_name, previous_name), UserWarning)
                if previous_wrapper:
                    replace_msg(full_name, (module, previous_wrapper))
                    setattr(container, name, previous_wrapper)
                else:
                    replace_msg(full_name)
                continue
            current_qbits = self.module_qbits_map[full_name]
            # TODO - Review necessity of the block below
            if (current_qbits.acts is None and current_qbits.wts is None
                    and not self.module_overrides_map[full_name]):
                # We indicate this module wasn't replaced by a wrapper
                replace_msg(full_name)
                self.modules_processed[module] = full_name, None
            else:
                # We use a type hint comment to let IDEs know replace_fn is a function
                replace_fn = self.replacement_factory.get(
                    type(module), self.default_repalcement_fn)  # type: Optional[Callable]
                # If the replacement function wasn't specified - continue without replacing this module.
                if replace_fn is not None:
                    valid_kwargs, invalid_kwargs = distiller.filter_kwargs(
                        self.module_overrides_map[full_name], replace_fn)
                    if invalid_kwargs:
                        raise TypeError(
                            'Quantizer of type %s doesn\'t accept "%s" as override arguments '
                            'for %s. Allowed kwargs: %s' %
                            (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
                    new_module = replace_fn(module, full_name,
                                            self.module_qbits_map,
                                            **valid_kwargs)
                    if new_module != module:
                        replace_msg(full_name, (module, new_module))
                        # Add to history of prepared submodules
                        self.modules_processed[module] = full_name, new_module
                        # To allow recreating this wrapper later on
                        valid_args = full_name, deepcopy(self.module_qbits_map)
                        self.modules_processed_args[full_name] = valid_args, valid_kwargs
                        setattr(container, name, new_module)

                        # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
                        if (not distiller.has_children(module)
                                and distiller.has_children(new_module)):
                            for sub_module_name, sub_module in new_module.named_modules():
                                self._add_qbits_entry(
                                    full_name + '.' + sub_module_name,
                                    type(sub_module), current_qbits)
                            self.module_qbits_map[full_name] = QBits(
                                acts=current_qbits.acts, wts=None, bias=None)
                    else:
                        replace_msg(full_name)
                        self.modules_processed[module] = full_name, None

            if distiller.has_children(module):
                # For container we call recursively
                self._pre_process_container(module, full_name + '.')
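For orientation, every _pre_process_container variant above is driven from a quantizer's prepare_model() entry point. A hedged sketch of the calling pattern, assuming the Distiller post-training quantization API (the exact prepare_model signature varies between Distiller versions):

import torch
from distiller.quantization import PostTrainLinearQuantizer

quantizer = PostTrainLinearQuantizer(model, bits_activations=8, bits_parameters=8)
quantizer.prepare_model(torch.randn(1, 3, 224, 224))  # kicks off the recursive replacement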