def test_fold_batch_norms_inference_no_fold(model, input_shape):
    orig_model = deepcopy(model)
    folded_model = mt.fold_batch_norms(model,
                                       dummy_input=torch.randn(input_shape),
                                       inference=True)
    for (n_orig, m_orig), (n_folded, m_folded) in zip(
            orig_model.named_modules(), folded_model.named_modules()):
        assert n_folded == n_orig
        assert type(m_folded) == type(m_orig)

    for (n_orig, p_orig), (n_folded, p_folded) in zip(
            orig_model.named_parameters(), folded_model.named_parameters()):
        assert n_folded == n_orig
        assert torch.equal(p_folded, p_orig)
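

# Batch-norm folding background (informal): with BN in eval mode, the affine
# transform y = gamma * (x - mu) / sqrt(var + eps) + beta that follows a
# conv/linear layer can be absorbed into that layer's parameters:
#     W' = W * gamma / sqrt(var + eps)
#     b' = (b - mu) * gamma / sqrt(var + eps) + beta
# so the folded model should match the original's outputs up to floating-point
# error, which the test below verifies with assert_allclose.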
def test_fold_batch_norms_inference(model, input_shape):
    # Make sure we have non-trivial values to work with
    nn.init.uniform_(model.seq[1].weight)
    nn.init.uniform_(model.seq[1].bias)
    nn.init.uniform_(model.seq[1].running_mean)
    nn.init.uniform_(model.seq[1].running_var)

    model.eval()
    orig_model = deepcopy(model)
    dummy_input = torch.randn(input_shape)
    folded_model = mt.fold_batch_norms(model,
                                       dummy_input=dummy_input,
                                       inference=True)
    assert type(folded_model.seq[0]) == type(orig_model.seq[0])
    assert type(folded_model.seq[1]) == nn.Identity

    y_orig = orig_model(dummy_input)
    y_folded = folded_model(dummy_input)
    torch.testing.assert_allclose(y_folded, y_orig)
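

# Greedy PTQ search: visit the layers in topological order and, for each one,
# evaluate the candidate quantization overrides produced by the override
# generator, keep the best-scoring setting, then continue to the next layer
# with that choice fixed.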
def ptq_greedy_search(model,
                      dummy_input,
                      eval_fn,
                      calib_eval_fn=None,
                      recurrent=False,
                      act_stats=None,
                      args=None,
                      module_override_gen_fn=None,
                      input_override_gen_fn=None,
                      fold_sequences=True):
    """
    Perform a greedy search over the post-training quantization configuration for the model.
    Args:
        model (nn.Module): the model to quantize
        dummy_input (torch.Tensor): a dummy input to be passed to the model
        eval_fn (function): Test/Evaluation function for the model. It must have an argument named 'model' that
          accepts the model. All other arguments should be set in advance (can be done using functools.partial), or
          they will be left with their default values.
        calib_eval_fn (function): An 'evaluation' function to use for forward passing
          through the model to collect quantization calibration statistics.
          If None, `eval_fn` is used as a default.
        recurrent (bool): a flag to indicate whether the model has recurrent connections.
        act_stats (OrderedDict): quantization calibration activation stats.
          If None, they will be calculated at runtime.
        args (dict or argparse.Namespace): command line arguments, or a dict of the same.
        module_override_gen_fn: A function to generate module overrides.
          assumes signature
          `def module_override_gen_fn(module: nn.Module,
                                      module_name: str,
                                      sg: distiller.SummaryGraph,
                                      overrides_dict: OrderedDict,
                                      **kwargs) -> Generator[OrderedDict, None, None]`
        input_override_gen_fn: Same as module_override_gen_fn, but generates overrides for
          quantizing the inputs to the top-level layers.
        fold_sequences (bool): whether to fold batch norms before quantizing.
    Returns:
        (quantized_model, best_overrides_dict)
    Note:
        It is assumed that `eval_fn` returns a satisfying metric of performance (e.g. accuracy)
        and the greedy search aims to maximize this metric.
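
    Example:
        A minimal usage sketch; `my_model`, `validate` and `val_loader` are
        caller-supplied placeholders, not part of this module::

            import functools

            eval_fn = functools.partial(validate, data_loader=val_loader)
            quantized_model, best_overrides = ptq_greedy_search(
                my_model, dummy_input=torch.randn(1, 3, 224, 224),
                eval_fn=eval_fn)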
    """
    if args is None:
        args = get_default_args()
    elif isinstance(args, dict):
        updated_args = get_default_args()
        updated_args.__dict__.update(args)
        args = updated_args

    if fold_sequences:
        model = fold_batch_norms(model, dummy_input)
    best_overrides_dict = OrderedDict()
    if args.resume_search_from:
        with open(args.resume_search_from, 'r') as f:
            best_overrides_dict = distiller.yaml_ordered_load(f)
        msglogger.info('Loaded search checkpoint from %s' %
                       args.resume_search_from)
    overrides_dict = OrderedDict()
    sg = SummaryGraph(model, dummy_input)
    modules_to_quantize = sg.layers_topological_order(recurrent)
    adjacency_map = sg.adjacency_map()
    modules_dict = OrderedDict(
        model.named_modules())  # type: OrderedDict[str, nn.Module]
    modules_to_quantize = [
        m for m in modules_to_quantize if m not in args.qe_no_quant_layers
    ]

    module_override_gen_fn = module_override_gen_fn or module_override_generator
    input_override_gen_fn = input_override_gen_fn or input_override_generator

    calib_eval_fn = calib_eval_fn or eval_fn
    if not act_stats:
        msglogger.info('Collecting stats for model...')
        model_temp = distiller.utils.make_non_parallel_copy(model)
        act_stats = collect_quant_stats(model_temp, calib_eval_fn)
        del model_temp
        act_stats_path = '%s_act_stats.yaml' % args.arch
        msglogger.info('Done. Saving act stats into %s' % act_stats_path)
        distiller.yaml_ordered_save(act_stats_path, act_stats)
    msglogger.info('Evaluating baseline score for model...')
    base_score = args.base_score or eval_fn(model)
    msglogger.info("Base score: %.3f" % base_score)

    def recalibrate_stats(module_name, act_stats):
        """
        Re-collects quant-calibration stats for successor modules of the current module.
        """
        msglogger.info('Recalibrating stats...')
        modules_to_recalibrate = {
            op.name
            for op in adjacency_map[module_name].successors
        } & set(act_stats)
        if not modules_to_recalibrate:
            # either there aren't any successors or
            # the successors aren't in the stats file - skip
            return act_stats
        q = PostTrainLinearQuantizer(
            distiller.utils.make_non_parallel_copy(model),
            bits_activations=None,
            bits_parameters=None,
            bits_accum=32,
            mode=LinearQuantMode.ASYMMETRIC_SIGNED,
            clip_acts=ClipMode.NONE,
            overrides=deepcopy(best_overrides_dict),
            model_activation_stats=deepcopy(act_stats),
            inputs_quant_auto_fallback=False,
            per_channel_wts=args.qe_per_channel)
        q.prepare_model(dummy_input)
        # recalibrate on the current best quantized version of the model.
        recalib_act_stats = collect_quant_stats(
            q.model, calib_eval_fn, modules_to_collect=modules_to_recalibrate)
        msglogger.info('Done.')
        act_stats.update(recalib_act_stats)
        return act_stats

    loaded_from_checkpoint = []
    # Quantize inputs:
    input_modules = get_inputs_to_quantize(sg, args,
                                           recurrent)  # top level modules
    for module_name, input_idxs in input_modules.items():
        denormalized_module_name = distiller.denormalize_module_name(
            model, module_name)
        module = modules_dict[denormalized_module_name]
        if isinstance(module, SKIP_MODULES):
            msglogger.info('Skipping module \'%s\' of type %s.' %
                           (module_name, module.__class__.__name__))
            continue
        msglogger.info('Quantizing top level inputs for %s' % module_name)

        normalized_module_name = module_name
        if isinstance(model, nn.DataParallel):
            normalized_module_name = re.sub(r'module\.', '',
                                            normalized_module_name)
        if normalized_module_name in best_overrides_dict and \
                best_overrides_dict[normalized_module_name].get('input_overrides', None):
            # This means the loaded dict already has the module
            msglogger.info(
                "  Quantizing '%s' based on loaded checkpoint: %s" %
                (module_name, best_overrides_dict[normalized_module_name]))
            if best_overrides_dict[normalized_module_name].get(
                    'bits_activations'):
                loaded_from_checkpoint.append(normalized_module_name)
            continue
        if not best_overrides_dict.get(normalized_module_name, None):
            best_overrides_dict[normalized_module_name] = OrderedDict()
        for input_idx in input_idxs:
            best_module_override = search_best_local_settings(
                module,
                module_name,
                sg,
                act_stats,
                eval_fn,
                best_overrides_dict,
                input_override_gen_fn,
                input_idx=input_idx,
                bits_activations=args.qe_bits_acts,
                bits_weights=args.qe_bits_wts,
                per_channel=args.qe_per_channel)
            best_overrides_dict[normalized_module_name].update(
                best_module_override)
        # Leave only the input_overrides settings:
        current_input_overrides = best_overrides_dict[normalized_module_name][
            'input_overrides']
        best_overrides_dict[normalized_module_name] = override_odict(
            input_overrides=current_input_overrides)

    # Quantize layers as a whole:
    for module_name in modules_to_quantize:
        module = modules_dict[module_name]
        if isinstance(module, SKIP_MODULES):
            msglogger.info('Skipping module \'%s\' of type %s.' %
                           (module_name, module.__class__.__name__))
            continue

        normalized_module_name = module_name
        if isinstance(model, nn.DataParallel):
            normalized_module_name = re.sub(r'module\.', '',
                                            normalized_module_name)

        if normalized_module_name in best_overrides_dict and \
                best_overrides_dict[normalized_module_name].get('bits_activations', None)\
                and normalized_module_name not in loaded_from_checkpoint:
            # This means the loaded dict already has the module
            msglogger.info(
                "  Quantizing '%s'(%s) based on loaded checkpoint: %s" %
                (module_name, module.__class__.__name__,
                 best_overrides_dict[normalized_module_name]))
            loaded_from_checkpoint.append(normalized_module_name)
            continue
        if not best_overrides_dict.get(normalized_module_name, None):
            best_overrides_dict[normalized_module_name] = OrderedDict()
        # Hard coded workaround for avgpool->reshape->fc
        if normalized_module_name == 'fc':
            input_override = override_odict(bits_activations=8,
                                            clip_acts='NONE')
            best_overrides_dict['fc'].update(
                OrderedDict([('input_overrides',
                              OrderedDict([(0, input_override)]))]))
        best_module_override = search_best_local_settings(
            module,
            module_name,
            sg,
            act_stats,
            eval_fn,
            best_overrides_dict,
            module_override_gen_fn,
            bits_activations=args.qe_bits_acts,
            bits_weights=args.qe_bits_wts,
            per_channel=args.qe_per_channel)
        best_overrides_dict[normalized_module_name].update(
            best_module_override)
        distiller.yaml_ordered_save('%s.ptq_greedy_search.yaml' % args.arch,
                                    best_overrides_dict)
        # At the end of each module's search we could re-collect calibration
        # stats for the successor layers via recalibrate_stats() - currently
        # disabled:
        # recalibrate_stats(module_name, act_stats)

    quantizer = PostTrainLinearQuantizer(
        model,
        bits_activations=None,
        bits_parameters=None,
        bits_accum=32,
        mode=LinearQuantMode.ASYMMETRIC_SIGNED,
        clip_acts=ClipMode.NONE,
        overrides=deepcopy(best_overrides_dict),
        model_activation_stats=act_stats,
        inputs_quant_auto_fallback=False,
        per_channel_wts=args.qe_per_channel)
    quantizer.prepare_model(dummy_input)
    msglogger.info('best_overrides_dict: %s' % best_overrides_dict)
    msglogger.info('Best score: %f' % eval_fn(quantizer.model))
    return quantizer.model, best_overrides_dict
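

# Coordinate search over PTQ parameters: treat the per-layer quantization
# parameters (scales and zero-points, or clipping values when search_clipping
# is set) as one flat vector and minimize the evaluation loss over that vector
# with scipy.optimize.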
def ptq_coordinate_search(quantizer,
                          dummy_input,
                          eval_fn,
                          test_fn=None,
                          method='Powell',
                          maxiter=None,
                          maxfev=None,
                          basinhopping=False,
                          basinhopping_niter=100,
                          init_mode=ClipMode.NONE,
                          init_method=None,
                          search_clipping=False,
                          minimizer_kwargs=None):
    """
    Searches for the optimal post-train quantization configuration (scale/zero-points)
    for a model, using the numerical methods provided by scipy.optimize.minimize.
    Args:
        quantizer (distiller.quantization.PostTrainLinearQuantizer): A configured PostTrainLinearQuantizer object
          containing the model being quantized
        dummy_input: a sample input expected by the model
        eval_fn (callable): evaluation function for the model. Assumed to have a signature of the form
          `eval_fn(model)->float`. This is the function to be minimized by the optimization algorithm.
        test_fn (callable): a function to test the current performance of the model. Assumed it has a signature of
          the form `test_fn(model)->dict`, where the returned dict contains relevant results to be logged.
          For example: {'top-1': VAL, 'top-5': VAL, 'loss': VAL}
        method (str or callable): Minimization method as accepted by scipy.optimize.minimize.
        maxiter (int): Maximum number of iterations to perform during minimization
        maxfev (int): Maximum number of total function evaluations to perform during minimization
        basinhopping (bool): flag indicating whether to use basinhopping as a global-minimization method;
          the `method` argument is passed on to `scipy.optimize.basinhopping` as the local minimizer.
        basinhopping_niter (int): Number of iterations to perform if basinhopping is set
        init_mode (ClipMode or callable or str or dict): See 'init_linear_quant_params'
        init_method (str or callable): See 'init_layer_linear_quant_params'
        search_clipping (bool): Search on clipping values instead of directly on scale/zero-point (scale and zero-
          point are inferred from the clipping values)
        minimizer_kwargs (dict): Optional additional arguments for scipy.optimize.minimize
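
    Example:
        A minimal usage sketch; assumes `quantizer` is a configured
        PostTrainLinearQuantizer on which prepare_model() has NOT been called
        yet, and that `eval_fn` returns a loss value to minimize::

            qmodel, qp_dict = ptq_coordinate_search(
                quantizer, torch.randn(1, 3, 224, 224), eval_fn,
                method='Powell', maxiter=100)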
    """
    if not isinstance(quantizer, PostTrainLinearQuantizer):
        raise ValueError(
            'Only PostTrainLinearQuantizer supported, but got a {}'.format(
                quantizer.__class__.__name__))
    if quantizer.prepared:
        raise ValueError(
            'Expecting a quantizer for which prepare_model has not been called'
        )

    run_device = distiller.model_device(quantizer.model)

    original_model = deepcopy(quantizer.model).cpu()
    original_model = fold_batch_norms(original_model, dummy_input)

    if not quantizer.model_activation_stats:
        msglogger.info('Collecting stats for model...')
        model_temp = _make_non_parallel_copy(original_model).to(
            device=run_device)
        act_stats = collect_quant_stats(model_temp,
                                        eval_fn,
                                        inplace_runtime_check=True,
                                        disable_inplace_attrs=True,
                                        save_dir=getattr(
                                            msglogger, 'logdir', '.'))
        if model_temp != original_model:
            del model_temp
        quantizer.model_activation_stats = act_stats
        quantizer.model.quantizer_metadata['params'][
            'model_activation_stats'] = act_stats

    # Preparing model and init conditions:
    msglogger.info("Initializing quantizer...")

    # Make sure weights are re-quantizable and clip-able
    quantizer.save_fp_weights = True
    quantizer.also_clip_weights = True

    # Disable any user set activations clipping - we'll be using init_args
    quantizer.clip_acts = ClipMode.NONE
    for overrides_dict in quantizer.module_overrides_map.values():
        overrides_dict.pop('clip_acts', None)

    quantizer.prepare_model(dummy_input)
    quantizer.model.eval()
    quantizer.model = quantizer.model.cpu()

    validate_quantization_settings(quantizer.model, search_clipping)

    msglogger.info("Initializing quantization parameters...")
    init_linear_quant_params(quantizer,
                             original_model,
                             eval_fn,
                             dummy_input,
                             init_mode,
                             init_method,
                             search_clipping=search_clipping,
                             run_device=run_device)

    msglogger.info("Evaluating initial quantization score...")
    best_data = {
        'score': eval_fn(quantizer.model),
        'qp_dict': deepcopy(quantizer.linear_quant_params)
    }
    msglogger.info("Evaluation set loss after initialization %.3f" %
                   best_data['score'])
    if test_fn:
        msglogger.info('Evaluating on full test set...')
        results = test_fn(quantizer.model)
        s = ', '.join(['{} = {:.3f}'.format(k, v) for k, v in results.items()])
        msglogger.info('Test: ' + s)

    init_qp_dict = OrderedDict(
        quantizer.named_linear_quant_params(search_clipping, filter=True))
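    # Flatten the named quant params into a vector that scipy optimizes over;
    # `keys` lets us map candidate vectors back to the named parameters.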
    keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, search_clipping)

    iter_counter = count(1)
    eval_counter = count(1)

    def feed_forward_fn(qp_vec):
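        # Objective for scipy: map the candidate vector back into named quant
        # params, load them into the quantizer, and return the loss to minimize.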
        # if not _check_qp_vec(keys, qp_vec, quant_mode, args.search_clipping):
        #     return 1e6
        qp_dict = quant_params_vec2dict(keys, qp_vec, search_clipping)
        quantizer.update_linear_quant_params(qp_dict)
        loss = eval_fn(quantizer.model)

        i = next(eval_counter)
        if i % 20 == 0:
            msglogger.info('%d evaluations: loss=%.3f' % (i, loss))

        return loss

    def callback(qp_vec):
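        # Called by scipy once per iteration. Since the optimizer may terminate
        # at a worse point than the best one visited, track the best seen so far.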
        score = feed_forward_fn(qp_vec)
        i = next(iter_counter)
        msglogger.info("Iteration %d: \t Score=%.3f" % (i, score))
        if score < best_data['score']:
            best_data['score'] = score
            best_data['qp_dict'] = quant_params_vec2dict(
                keys, qp_vec, search_clipping)
            msglogger.info("Saving current best quantization parameters.")
        if test_fn:
            msglogger.info('Evaluating on full test set...')
            results = test_fn(quantizer.model)
            s = ', '.join(
                ['{} = {:.3f}'.format(k, v) for k, v in results.items()])
            msglogger.info('Test: ' + s)

    options = OrderedDict()
    options['maxiter'] = maxiter
    options['maxfev'] = maxfev

    minimizer_kwargs = minimizer_kwargs or OrderedDict()
    minimizer_kwargs.update({'method': method, 'options': options})
    if basinhopping:
        msglogger.info(
            'Using basinhopping global minimum search with "%s" local minimization method'
            % method)
        res = opt.basinhopping(feed_forward_fn,
                               init_qp_vec,
                               basinhopping_niter,
                               callback=callback,
                               minimizer_kwargs=minimizer_kwargs)
    else:
        msglogger.info('Using "%s" minimization algorithm.' % method)
        res = opt.minimize(feed_forward_fn,
                           init_qp_vec,
                           callback=callback,
                           **minimizer_kwargs)

    msglogger.info('Optimization done')
    msglogger.info('Best score: {}'.format(best_data['score']))
    msglogger.info('Best Configuration: {}'.format(best_data['qp_dict']))
    return quantizer.model, best_data['qp_dict']
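

# Variant of the coordinate search that takes a plain model plus command-line
# style args, and builds the PostTrainLinearQuantizer internally via
# PostTrainLinearQuantizer.from_args.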
def ptq_coordinate_search(model,
                          dummy_input,
                          eval_fn,
                          method='Powell',
                          options=None,
                          act_stats=None,
                          args=None,
                          fold_sequences=True,
                          basinhopping=False,
                          init_args=None,
                          minimizer_kwargs=None,
                          test_fn=None):
    """
    Searches for the optimal post-train quantization configuration (scale/zero-points)
    for a model, using the numerical methods provided by scipy.optimize.minimize.
    Args:
        model (nn.Module): model to quantize
        dummy_input: a sample input expected by the model
        eval_fn (callable): evaluation function for the model. Assumed to have a signature of the form
          `eval_fn(model)->float`. This is the function to be minimized by the optimization algorithm.
        method (str or callable): minimization method as accepted by scipy.optimize.minimize.
        options (dict or None): options for the scipy optimizer
        act_stats (OrderedDict): dictionary of statistics per layer, including inputs and outputs.
          For more context refer to collect_quant_stats.
        args (dict or argparse.Namespace): command line arguments, or a dict of the same.
        fold_sequences (bool): flag indicating whether to fold batch-norm sequences before performing the search.
        basinhopping (bool): flag indicating whether to use basinhopping as a global-minimization method;
          the `method` argument is passed on to `scipy.optimize.basinhopping` as the local minimizer.
        init_args (tuple): arguments for initializing the linear quantization parameters.
          Refer to `init_linear_quant_params` for more details.
        minimizer_kwargs (dict): the kwargs for scipy.optimize.minimize procedure.
        test_fn (callable): a function to test the current performance of the model. Assumed to
          return a tuple (top1, top5, loss), which is logged.
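
    Example:
        A minimal usage sketch; `my_model` and `eval_fn` are caller-supplied,
        and command-line defaults are filled in via get_default_args()::

            qmodel, qp_dict = ptq_coordinate_search(
                my_model, torch.randn(1, 3, 224, 224), eval_fn,
                method='Powell')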
    """
    if fold_sequences:
        model = fold_batch_norms(model, dummy_input)
    if args is None:
        args = get_default_args()
    elif isinstance(args, dict):
        updated_args = get_default_args()
        updated_args.__dict__.update(args)
        args = updated_args
    original_model = deepcopy(model)

    if not act_stats and not args.qe_config_file:
        msglogger.info('Collecting stats for model...')
        model_temp = distiller.utils.make_non_parallel_copy(model)
        act_stats = collect_quant_stats(model_temp, eval_fn)
        del model_temp
        act_stats_path = '%s_act_stats.yaml' % args.arch
        msglogger.info('Done. Saving act stats into %s' % act_stats_path)
        distiller.yaml_ordered_save(act_stats_path, act_stats)
        args.qe_stats_file = act_stats_path

    # Preparing model and init conditions:
    msglogger.info("Initializing quantizer...")
    quantizer = PostTrainLinearQuantizer.from_args(model, args)

    # Make sure weights are re-quantizable and clip-able
    quantizer.save_fp_weights = True
    quantizer.also_clip_weights = True

    # Disable any user set activations clipping - we'll be using init_args
    quantizer.clip_acts = ClipMode.NONE
    for overrides_dict in quantizer.module_overrides_map.values():
        overrides_dict.pop('clip_acts', None)

    quantizer.prepare_model(dummy_input)
    quantizer.model.eval()

    validate_quantization_settings(args, quantizer.model)

    msglogger.info("Initializing quantization parameters...")
    init_args = init_args or (args.init_mode, args.init_mode_method)
    init_linear_quant_params(quantizer,
                             original_model,
                             eval_fn,
                             dummy_input,
                             *init_args,
                             search_clipping=args.search_clipping)

    msglogger.info("Evaluating initial quantization score...")
    best_data = {
        'score': eval_fn(quantizer.model),
        'qp_dict': deepcopy(quantizer.linear_quant_params)
    }
    msglogger.info("Evaluation set loss after initialization %.3f" %
                   best_data['score'])
    if test_fn:
        msglogger.info('Evaluating on full test set...')
        l_top1, l_top5, l_loss = test_fn(quantizer.model)
        msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' %
                       (l_loss, l_top1, l_top5))

    init_qp_dict = OrderedDict(
        quantizer.named_linear_quant_params(args.search_clipping, filter=True))
    keys, init_qp_vec = quant_params_dict2vec(init_qp_dict,
                                              args.search_clipping)

    iter_counter = count(1)
    eval_counter = count(1)

    def feed_forward_fn(qp_vec):
        # if not _check_qp_vec(keys, qp_vec, quant_mode, args.search_clipping):
        #     return 1e6
        qp_dict = quant_params_vec2dict(keys, qp_vec, args.search_clipping)
        quantizer.update_linear_quant_params(qp_dict)
        loss = eval_fn(quantizer.model)

        i = next(eval_counter)
        if i % 20 == 0:
            msglogger.info('%d evaluations: loss=%.3f' % (i, loss))

        return loss

    def callback(qp_vec):
        score = feed_forward_fn(qp_vec)
        i = next(iter_counter)
        msglogger.info("Iteration %d: \t Score=%.3f" % (i, score))
        if score < best_data['score']:
            best_data['score'] = score
            best_data['qp_dict'] = quant_params_vec2dict(
                keys, qp_vec, args.search_clipping)
            msglogger.info("Saving current best quantization parameters.")
        if test_fn:
            msglogger.info('Evaluating on full test set...')
            l_top1, l_top5, l_loss = test_fn(quantizer.model)
            msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' %
                           (l_loss, l_top1, l_top5))

    options = options or OrderedDict()
    if args.maxiter is not None:
        options['maxiter'] = args.maxiter
    if args.maxfev is not None:
        options['maxfev'] = args.maxfev
    minimizer_kwargs = minimizer_kwargs or OrderedDict()
    minimizer_kwargs.update({'method': method, 'options': options})
    basinhopping = basinhopping or args.basinhopping
    if basinhopping:
        msglogger.info(
            'Using basinhopping global minimum search with "%s" local minimization method'
            % method)
        res = opt.basinhopping(feed_forward_fn,
                               init_qp_vec,
                               args.niter,
                               callback=callback,
                               minimizer_kwargs=minimizer_kwargs)
    else:
        msglogger.info('Using "%s" minimization algorithm.' % method)
        res = opt.minimize(feed_forward_fn,
                           init_qp_vec,
                           callback=callback,
                           **minimizer_kwargs)

    msglogger.info("Optimization done. Best configuration: %s" %
                   best_data['qp_dict'])
    return model, best_data['qp_dict']