Example #1
    def load_state_dict(self, state, normalize_dataparallel_keys=False):
        try:
            loaded_masks = state['masks_dict']
        except KeyError as exception:
            msglogger.error('could not load the CompressionScheduler state.'
                            ' masks_dict is missing from state')
            with contextlib.suppress(TypeError):
                msglogger.debug('Scheduler state keys are: {}'.format(', '.join(state)))
            raise

        if normalize_dataparallel_keys:
            loaded_masks = {normalize_module_name(k): v for k, v in loaded_masks.items()}
        device = model_device(self.model)
        for name, masker in self.zeros_mask_dict.items():
            masker.mask = loaded_masks[name]
            if masker.mask is not None:
                masker.mask = masker.mask.to(device)
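
# A minimal usage sketch (an assumption, not part of the original example):
# restoring the scheduler's pruning masks from a checkpoint saved earlier with
# torch.save(). The checkpoint path and the 'compression_sched' key are
# hypothetical placeholders.
import torch

def restore_scheduler_masks(compression_scheduler, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # load_state_dict (above) expects a dict containing a 'masks_dict' entry;
    # otherwise it logs the problem and re-raises the KeyError.
    compression_scheduler.load_state_dict(checkpoint['compression_sched'],
                                          normalize_dataparallel_keys=True)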
Example #2
def _create_graph(dataset, model):
    dummy_input = utils.get_dummy_input(dataset, utils.model_device(model))
    return SummaryGraph(model, dummy_input)
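
# Hypothetical call sketch (an assumption): 'imagenet' stands in for whatever
# dataset name utils.get_dummy_input understands, and `model` is any
# torch.nn.Module already placed on its target device.
#
#   graph = _create_graph('imagenet', model)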
Example #3
def execute_thinning_recipe(model,
                            zeros_mask_dict,
                            recipe,
                            optimizer,
                            loaded_from_file=False):
    """Apply a thinning recipe to a model.
    This will remove filters and channels, handle the corresponding batch-normalization parameter
    adjustments, and thin the weight tensors.
    """
    device = utils.model_device(model)
    layers = {mod_name: m for mod_name, m in model.named_modules()}
    for layer_name, directives in recipe.modules.items():
        for attr, val in directives.items():
            if attr in ['running_mean', 'running_var']:
                running = getattr(layers[layer_name], attr)
                dim_to_trim = val[0]
                indices_to_select = val[1]
                # Check that we're not trying to trim a parameter that is already "thin"
                if running.size(dim_to_trim) != indices_to_select.nelement():
                    msglogger.debug("[thinning] {}: setting {} to {}".format(
                        layer_name, attr, indices_to_select.nelement()))
                    setattr(
                        layers[layer_name], attr,
                        torch.index_select(running,
                                           dim=dim_to_trim,
                                           index=indices_to_select.to(
                                               running.device)))
            else:
                msglogger.debug("[thinning] {}: setting {} to {}".format(
                    layer_name, attr, val))
                setattr(layers[layer_name], attr, val)

    assert len(recipe.parameters) > 0

    with torch.no_grad():
        for param_name, param_directives in recipe.parameters.items():
            msglogger.debug("{} : {}".format(param_name, param_directives))
            param = utils.model_find_param(model, param_name)
            assert param is not None
            for directive in param_directives:
                dim = directive[0]
                indices = directive[1].to(device)
                len_indices = indices.nelement()
                if len(directive) == 4:  # TODO: this code is hard to follow
                    msglogger.debug("{}-{}-{}: SHAPE = {}".format(
                        param_name, param.shape, id(param),
                        list(directive[2])))
                    selection_view = param.view(*directive[2])
                    # Check that we're not trying to trim a parameter that is already "thin"
                    if param.data.size(dim) != len_indices:
                        param.data = torch.index_select(
                            selection_view, dim, indices)
                        if param.grad is not None:
                            # We also need to change the dimensions of the gradient tensor.
                            grad_selection_view = param.grad.resize_(
                                *directive[2])
                            if grad_selection_view.size(dim) != len_indices:
                                param.grad = torch.index_select(
                                    grad_selection_view, dim, indices)
                                # update optimizer
                                if _optimizer_thinning(optimizer, param, dim,
                                                       indices, directive[3]):
                                    msglogger.debug(
                                        "Updated [4D] velocity buffer for {} (dim={},size={},shape={})"
                                        .format(param_name, dim, len_indices,
                                                directive[3]))

                    param.data = param.view(*directive[3])
                    if param.grad is not None:
                        param.grad = param.grad.resize_(*directive[3])
                else:
                    if param.data.size(dim) != len_indices:
                        msglogger.debug(
                            "[thinning] changing param {} ({})  dim:{}  new len: {}"
                            .format(param_name, param.shape, dim, len_indices))
                        assert param.size(dim) > len_indices
                        param.data = torch.index_select(
                            param.data, dim, indices.to(param.device))
                        msglogger.debug(
                            "[thinning] changed param {}".format(param_name))
                    # We also need to change the dimensions of the gradient tensor.
                    # If we have not done a backward pass thus far, then the gradient will
                    # not exist, and therefore won't need to be re-dimensioned.
                    if param.grad is not None and param.grad.size(
                            dim) != len_indices:
                        param.grad = torch.index_select(
                            param.grad, dim, indices.to(param.device))
                        # update optimizer
                        if _optimizer_thinning(optimizer, param, dim, indices):
                            msglogger.debug("Updated velocity buffer %s" %
                                            param_name)

                if not loaded_from_file and zeros_mask_dict:
                    # If the masks are loaded from a checkpoint file, then we don't need to change
                    # their shape, because they are already correctly shaped
                    mask = zeros_mask_dict[param_name].mask
                    if mask is not None and (mask.size(dim) != len_indices):
                        zeros_mask_dict[param_name].mask = torch.index_select(
                            mask, dim, indices)
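
# Toy sketch (not from the original source) of what a single thinning directive
# does: a directive carries a dimension and the indices to keep, and the tensor
# is shrunk along that dimension with torch.index_select, exactly as above.
import torch

conv_weight = torch.randn(64, 3, 3, 3)        # e.g. 64 output filters
keep = torch.tensor([0, 5, 7, 42])            # indices of filters that survive
thinned = torch.index_select(conv_weight, dim=0, index=keep)
assert thinned.shape == (4, 3, 3, 3)          # only 4 filters remain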
Example #4
def ptq_coordinate_search(quantizer,
                          dummy_input,
                          eval_fn,
                          test_fn=None,
                          method='Powell',
                          maxiter=None,
                          maxfev=None,
                          basinhopping=False,
                          basinhopping_niter=100,
                          init_mode=ClipMode.NONE,
                          init_method=None,
                          search_clipping=False,
                          minimizer_kwargs=None):
    """
    Searches for the optimal post-train quantization configuration (scale/zero-point)
    for a model, using the numerical optimization methods provided by scipy.optimize.minimize.
    Args:
        quantizer (quantization.PostTrainLinearQuantizer): A configured PostTrainLinearQuantizer object
          containing the model being quantized
        dummy_input: a sample input in the shape expected by the model
        eval_fn (callable): evaluation function for the model. Assumed it has a signature of the form
          `eval_fn(model)->float`. This is the function to be minimized by the optimization algorithm.
        test_fn (callable): a function to test the current performance of the model. Assumed it has a signature of
          the form `test_fn(model)->dict`, where the returned dict contains relevant results to be logged.
          For example: {'top-1': VAL, 'top-5': VAL, 'loss': VAL}
        method (str or callable): Minimization method as accepted by scipy.optimize.minimize.
        maxiter (int): Maximum number of iterations to perform during minimization
        maxfev (int): Maximum number of total function evaluations to perform during minimization
        basinhopping (bool): if set, use basinhopping as a global-minimization method; the `method`
          argument is passed to `scipy.optimize.basinhopping` as the local minimization method.
        basinhopping_niter (int): Number of iterations to perform if basinhopping is set
        init_mode (ClipMode or callable or str or dict): See 'init_linear_quant_params'
        init_method (str or callable): See 'init_layer_linear_quant_params'
        search_clipping (bool): Search on clipping values instead of directly on scale/zero-point (scale and zero-
          point are inferred from the clipping values)
        minimizer_kwargs (dict): Optional additional arguments for scipy.optimize.minimize
    """
    if not isinstance(quantizer, PostTrainLinearQuantizer):
        raise ValueError(
            'Only PostTrainLinearQuantizer supported, but got a {}'.format(
                quantizer.__class__.__name__))
    if quantizer.prepared:
        raise ValueError(
            'Expecting a quantizer for which prepare_model has not been called'
        )

    run_device = utils.model_device(quantizer.model)

    original_model = deepcopy(quantizer.model).cpu()
    original_model = fold_batch_norms(original_model, dummy_input)

    if not quantizer.model_activation_stats:
        msglogger.info('Collecting stats for model...')
        model_temp = _make_non_parallel_copy(original_model).to(
            device=run_device)
        act_stats = collect_quant_stats(model_temp,
                                        eval_fn,
                                        inplace_runtime_check=True,
                                        disable_inplace_attrs=True,
                                        save_dir=getattr(
                                            msglogger, 'logdir', '.'))
        if model_temp != original_model:
            del model_temp
        quantizer.model_activation_stats = act_stats
        quantizer.model.quantizer_metadata['params'][
            'model_activation_stats'] = act_stats

    # Preparing model and init conditions:
    msglogger.info("Initializing quantizer...")

    # Make sure weights are re-quantizable and clip-able
    quantizer.save_fp_weights = True
    quantizer.also_clip_weights = True

    # Disable any user-set activation clipping - we'll be using init_args
    quantizer.clip_acts = ClipMode.NONE
    for overrides_dict in quantizer.module_overrides_map.values():
        overrides_dict.pop('clip_acts', None)

    quantizer.prepare_model(dummy_input)
    quantizer.model.eval()
    quantizer.model = quantizer.model.cpu()

    validate_quantization_settings(quantizer.model, search_clipping)

    msglogger.info("Initializing quantization parameters...")
    init_linear_quant_params(quantizer,
                             original_model,
                             eval_fn,
                             dummy_input,
                             init_mode,
                             init_method,
                             search_clipping=search_clipping,
                             run_device=run_device)

    msglogger.info("Evaluating initial quantization score...")
    best_data = {
        'score': eval_fn(quantizer.model),
        'qp_dict': deepcopy(quantizer.linear_quant_params)
    }
    msglogger.info("Evaluation set loss after initialization %.3f" %
                   best_data['score'])
    if test_fn:
        msglogger.info('Evaluating on full test set...')
        results = test_fn(quantizer.model)
        s = ', '.join(['{} = {:.3f}'.format(k, v) for k, v in results.items()])
        msglogger.info('Test: ' + s)

    init_qp_dict = OrderedDict(
        quantizer.named_linear_quant_params(search_clipping, filter=True))
    keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, search_clipping)

    iter_counter = count(1)
    eval_counter = count(1)

    def feed_forward_fn(qp_vec):
        # if not _check_qp_vec(keys, qp_vec, quant_mode, args.search_clipping):
        #     return 1e6
        qp_dict = quant_params_vec2dict(keys, qp_vec, search_clipping)
        quantizer.update_linear_quant_params(qp_dict)
        loss = eval_fn(quantizer.model)

        i = next(eval_counter)
        if i % 20 == 0:
            msglogger.info('%d evaluations: loss=%.3f' % (i, loss))

        return loss

    def callback(qp_vec):
        score = feed_forward_fn(qp_vec)
        i = next(iter_counter)
        msglogger.info("Iteration %d: \t Score=%.3f" % (i, score))
        if score < best_data['score']:
            best_data['score'] = score
            best_data['qp_dict'] = quant_params_vec2dict(
                keys, qp_vec, search_clipping)
            msglogger.info("Saving current best quantization parameters.")
        if test_fn:
            msglogger.info('Evaluating on full test set...')
            results = test_fn(quantizer.model)
            s = ', '.join(
                ['{} = {:.3f}'.format(k, v) for k, v in results.items()])
            msglogger.info('Test: ' + s)

    options = OrderedDict()
    options['maxiter'] = maxiter
    options['maxfev'] = maxfev

    minimizer_kwargs = minimizer_kwargs or OrderedDict()
    minimizer_kwargs.update({'method': method, 'options': options})
    if basinhopping:
        msglogger.info(
            'Using basinhopping global minimum search with "%s" local minimization method'
            % method)
        res = opt.basinhopping(feed_forward_fn,
                               init_qp_vec,
                               basinhopping_niter,
                               callback=callback,
                               minimizer_kwargs=minimizer_kwargs)
    else:
        msglogger.info('Using "%s" minimization algorithm.' % method)
        res = opt.minimize(feed_forward_fn,
                           init_qp_vec,
                           callback=callback,
                           **minimizer_kwargs)

    msglogger.info('Optimization done')
    msglogger.info('Best score: {}'.format(best_data['score']))
    msglogger.info('Best Configuration: {}'.format(best_data['qp_dict']))
    return quantizer.model, best_data['qp_dict']
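
# Usage sketch (an assumption, not taken from the original source). Only the
# ptq_coordinate_search signature above comes from the source; the quantizer
# construction, the input shape and the eval-fn contract shown here are
# hypothetical placeholders.
import torch

def run_ptq_search(model, calibration_eval_fn):
    # calibration_eval_fn(model) -> float, lower is better (e.g. average loss
    # on a held-out calibration set), matching the eval_fn contract above.
    quantizer = PostTrainLinearQuantizer(model)   # default settings assumed
    dummy_input = torch.randn(1, 3, 224, 224)     # shape is model-specific
    return ptq_coordinate_search(quantizer, dummy_input, calibration_eval_fn,
                                 method='Powell', basinhopping=False)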
Example #5
    def prepare_model(self, dummy_input=None):
        """
        Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
        and overrides configuration provided to __init__(), and according to the replacement_factory as
        defined by the Quantizer sub-class being used.

        Note:
            If multiple sub-modules within the model actually reference the same module, then that module
            is replaced only once, according to the configuration (bit-width and/or overrides) of the
            first encountered reference.
            Toy Example - say a module is constructed using this bit of code:

                shared_relu = nn.ReLU()
                self.relu1 = shared_relu
                self.relu2 = shared_relu

            When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
            Let's call it 'new_relu1'. When 'self.relu2' is encountered, it will simply be replaced
            with a reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2'
            will be ignored, and a warning message will be shown.
        """
        if self.prepared:
            raise RuntimeError('prepare_model can be called only once')

        msglogger.info('Preparing model for quantization using {0}'.format(
            self.__class__.__name__))

        self.model.quantizer_metadata["dummy_input"] = dummy_input
        if dummy_input is not None:
            summary_graph = utils.SummaryGraph(self.model, dummy_input)
            self.adjacency_map = summary_graph.adjacency_map(
                dedicated_modules_only=False)
            del summary_graph

        model_device = utils.model_device(self.model)

        self._pre_prepare_model(dummy_input)

        self._pre_process_container(self.model)
        for module_name, module in self.model.named_modules():
            qbits = self.module_qbits_map[module_name]
            curr_parameters = dict(module.named_parameters())
            for param_name, param in curr_parameters.items():
                n_bits = qbits.bias if param_name.endswith(
                    'bias') else qbits.wts
                if n_bits is None:
                    continue
                fp_attr_name = param_name
                if self.train_with_fp_copy:
                    hack_float_backup_parameter(module, param_name, n_bits)
                    fp_attr_name = FP_BKP_PREFIX + param_name
                self.params_to_quantize.append(
                    _ParamToQuant(module, module_name, fp_attr_name,
                                  param_name, n_bits))

                param_full_name = '.'.join([module_name, param_name])
                msglogger.debug(
                    "Parameter '{0}' will be quantized to {1} bits".format(
                        param_full_name, n_bits))

        # If an optimizer was passed, assume we need to update it
        if self.optimizer:
            for pg in self._get_new_optimizer_params_groups():
                self.optimizer.add_param_group(pg)

        self._post_prepare_model()

        # Re-transfer model to the device it was on, in case the quantizer created new parameters/buffers
        self.model.to(model_device)

        utils.assign_layer_fq_names(self.model)

        self.prepared = True

        msglogger.debug('Quantized model:\n\n{0}\n'.format(self.model))
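
# Toy sketch (an assumption) of the shared-module case described in the
# docstring above: both attributes reference the SAME nn.ReLU instance, so the
# quantizer replaces it once and the second reference reuses that replacement.
import torch.nn as nn

class SharedReluNet(nn.Module):
    def __init__(self):
        super().__init__()
        shared_relu = nn.ReLU()
        self.relu1 = shared_relu
        self.relu2 = shared_relu      # same object as self.relu1

    def forward(self, x):
        return self.relu2(self.relu1(x))

net = SharedReluNet()
assert net.relu1 is net.relu2         # a single module, referenced twice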
Example #6
    def __init__(self, model, dummy_input, apply_scope_name_workarounds=True):
        self._src_model = model
        self._named_modules = OrderedDict(model.named_modules())
        self._adj_map = None
        self._layers_topological_order = None
        self._top_level_ops = set()
        model_clone = utils.make_non_parallel_copy(model)

        # Switch all instances of torch.nn.ModuleList in the model to our CACPModuleList
        # See documentation of _ModuleList class for details on why this is done
        model_clone, converted_module_names_map = _to_modulelist(model_clone)

        with torch.onnx.set_training(model_clone, False):

            device = utils.model_device(model_clone)
            dummy_input = utils.convert_tensors_recursively_to(dummy_input,
                                                               device=device)
            self.dummy_input = dummy_input
            trace, _ = jit.get_trace_graph(model_clone,
                                           dummy_input,
                                           _force_outplace=True)

            # As of PyTorch 1.3.0, ONNX trace optimization has an issue that results in incorrect scope names
            # of nodes in the trace graph.
            # These can make it impossible, in some cases, to derive the connectivity of the model using the original
            # module names. So we try to detect these cases and apply workarounds

            # The issue is:
            #   Dropout ops are removed by ONNX trace optimization. However, the op BEFORE the original dropout op
            #   gets the scope name of the dropout op
            pre_dropout_nodes_scope_names = OrderedDict()

            prev_non_dropout_op = None
            for node in trace.graph().nodes():
                kind = node.kind()
                if 'aten' not in kind:
                    continue
                if kind == 'aten::dropout':
                    if prev_non_dropout_op:
                        pre_dropout_nodes_scope_names[node.scopeName()] = \
                            prev_non_dropout_op.scopeName()
                else:
                    prev_non_dropout_op = node

            # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
            # composing a GEMM operation; etc.
            torch.onnx._optimize_trace(trace,
                                       torch.onnx.OperatorExportTypes.ONNX)

            graph = trace.graph()
            self.ops = OrderedDict()
            self.module_ops_map = defaultdict(list)
            self.params = OrderedDict()
            self.edges = []
            self.temp = OrderedDict()

            in_out = list(graph.inputs()) + list(graph.outputs())
            for param in in_out:
                self.__add_param(param)

            for node in graph.nodes():
                new_op = self.__create_op(node)

                if apply_scope_name_workarounds:
                    # Here we apply the workaround to the issue of dropout op scope name overriding previous op's
                    # scope name
                    if new_op['name'] in pre_dropout_nodes_scope_names:
                        new_op['orig-name'] = pre_dropout_nodes_scope_names[
                            new_op['name']]
                        new_op['name'] = new_op['orig-name']

                # Convert the graph node's scope name to a PyTorch module name
                module_name = onnx_name_2_pytorch_name(new_op['orig-name'])

                # Get name from before conversion to CACPModuleList
                module_name = converted_module_names_map[module_name]

                if len(module_name) == 0:
                    # Special case where the module name is an empty string - this happens
                    # when the op is called from the "top-level" of the model
                    new_op['name'] = 'top_level_op'
                else:
                    new_op['name'] = module_name

                # Save the calling module name in the op dict. Denormalize it so it can
                # be directly matched with the actual model
                module_name = utils.denormalize_module_name(
                    self._src_model, module_name)
                new_op['module-name'] = module_name

                # The node's scope name in the graph corresponds to the module from which the op was called.
                # This means that when ops are invoked from the same module via functional calls or direct
                # operations on tensors, these ops will have the SAME MODULE NAME associated with them.
                # For example:
                #   t = t1 + t2
                #   t = F.relu(t)
                # In this case the add operation and the ReLU operation will have the same name, which is
                # derived from the module they're contained in.
                #
                # Another case where different ops will have the same module name is when a module is reused:
                #   out = self.conv1(x)
                #   out = self.relu(out)    <=== First use of self.relu
                #   out = self.conv2(out)
                #   out = self.relu(out)    <=== Second use of self.relu
                # In this case the graph will have 2 distinct ReLU nodes, with the same scope name.
                #
                # Operators with the same name create very confusing graphs (in ResNet, for example),
                # so we "unroll" them.
                same_module_cnt = len(self.module_ops_map[module_name])
                if same_module_cnt:
                    # TODO: Was this meant to be applied only to 'top_level_ops'? Also, it's not
                    #       applied to the first module that had the same name
                    new_op['name'] += "_%s_%d" % (new_op['type'],
                                                  same_module_cnt)
                self.module_ops_map[module_name].append(new_op['name'])

                # Finally we register the new op in the ops collection
                self.ops[new_op['name']] = new_op

                for input_ in node.inputs():
                    self.__add_input(new_op, input_)
                    self.edges.append(
                        SummaryGraph.Edge(input_.debugName(), new_op['name']))

                for output in node.outputs():
                    self.__add_output(new_op, output)
                    self.edges.append(
                        SummaryGraph.Edge(new_op['name'], output.debugName()))

                new_op['attrs'] = OrderedDict([
                    (attr_name, node[attr_name])
                    for attr_name in node.attributeNames()
                ])

        self.__merge_pad_avgpool()
        self.add_macs_attr()
        self.add_footprint_attr()
        self.add_arithmetic_intensity_attr()
        del trace
        del graph
        del model_clone
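
# Toy sketch (an assumption) of the module-reuse case discussed in the comments
# above: self.relu is invoked twice, so the trace contains two distinct ReLU
# nodes sharing one scope name - the reason SummaryGraph "unrolls" op names.
import torch
import torch.nn as nn

class ReusedReluNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.conv1(x))     # first use of self.relu
        return self.relu(self.conv2(out))  # second use of self.relu

# Hypothetical invocation - SummaryGraph would assign each ReLU op a unique name:
#   sg = SummaryGraph(ReusedReluNet(), torch.randn(1, 3, 32, 32))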