def test_fold_batch_norms_inference_no_fold(model, input_shape):
    # The model is left in training mode, so inference-time folding should be
    # a no-op: module types and parameter values must remain unchanged
    orig_model = deepcopy(model)
    folded_model = mt.fold_batch_norms(model, dummy_input=torch.randn(input_shape), inference=True)
    for (n_orig, m_orig), (n_folded, m_folded) in zip(orig_model.named_modules(),
                                                      folded_model.named_modules()):
        assert n_folded == n_orig
        assert type(m_folded) == type(m_orig)
    for (n_orig, p_orig), (n_folded, p_folded) in zip(orig_model.named_parameters(),
                                                      folded_model.named_parameters()):
        assert n_folded == n_orig
        assert (p_folded == p_orig).all().item() == 1
def test_fold_batch_norms_inference(model, input_shape):
    # Make sure we have non-trivial values to work with
    nn.init.uniform_(model.seq[1].weight)
    nn.init.uniform_(model.seq[1].bias)
    nn.init.uniform_(model.seq[1].running_mean)
    nn.init.uniform_(model.seq[1].running_var)

    model.eval()
    orig_model = deepcopy(model)
    dummy_input = torch.randn(input_shape)
    folded_model = mt.fold_batch_norms(model, dummy_input=dummy_input, inference=True)
    assert type(folded_model.seq[0]) == type(orig_model.seq[0])
    assert type(folded_model.seq[1]) == nn.Identity

    y_orig = orig_model(dummy_input)
    y_folded = folded_model(dummy_input)
    torch.testing.assert_allclose(y_folded, y_orig)
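# --- Hypothetical fixture sketch (not part of the original test file) ---
# The two tests above assume pytest fixtures named `model` and `input_shape`.
# A minimal sketch of such fixtures, assuming a Conv2d followed by a
# BatchNorm2d (matching the tests' use of model.seq[0] and model.seq[1]);
# the class name and shapes below are illustrative assumptions.
import pytest


class _ConvBnModel(nn.Module):
    def __init__(self):
        super(_ConvBnModel, self).__init__()
        self.seq = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1),
                                 nn.BatchNorm2d(8))

    def forward(self, x):
        return self.seq(x)


@pytest.fixture
def model():
    return _ConvBnModel()


@pytest.fixture
def input_shape():
    return 1, 3, 32, 32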
def ptq_greedy_search(model, dummy_input, eval_fn, calib_eval_fn=None,
                      recurrent=False, act_stats=None,
                      args=None,
                      module_override_gen_fn=None, input_override_gen_fn=None,
                      fold_sequences=True):
    """
    Perform greedy search on Post Train Quantization configuration for the model.
    Args:
        model (nn.Module): the model to quantize
        dummy_input (torch.Tensor): a dummy input to be passed to the model
        eval_fn (function): Test/Evaluation function for the model.
            It must have an argument named 'model' that accepts the model.
            All other arguments should be set in advance (can be done using functools.partial),
            or they will be left with their default values.
        calib_eval_fn (function): An 'evaluation' function to use for forward passing
            through the model to collect quantization calibration statistics.
            If None is provided, `eval_fn` is used as the default.
        recurrent (bool): a flag to indicate whether the model has recurrent connections.
        act_stats (OrderedDict): quantization calibration activation stats.
            If None is provided, they will be calculated at runtime.
        args (dict or argparse.Namespace): command line arguments. Alternatively - a dict.
        module_override_gen_fn: A function to generate module overrides.
            Assumes signature
            `def module_override_gen_fn(module: nn.Module,
                                        module_name: str,
                                        sg: distiller.SummaryGraph,
                                        overrides_dict: OrderedDict,
                                        **kwargs) -> Generator[OrderedDict, None, None]`
        input_override_gen_fn: Same as module_override_gen_fn, only for quantized inputs
            to the top-level layers.
        fold_sequences (bool): fold batch norms before quantizing
    Returns:
        (quantized_model, best_overrides_dict)
    Note:
        It is assumed that `eval_fn` returns a satisfying metric of performance (e.g. accuracy)
        and the greedy search aims to maximize this metric.
    """
    if args is None:
        args = get_default_args()
    elif isinstance(args, dict):
        updated_args = get_default_args()
        updated_args.__dict__.update(args)
        args = updated_args

    if fold_sequences:
        model = fold_batch_norms(model, dummy_input)
    best_overrides_dict = OrderedDict()
    if args.resume_search_from:
        with open(args.resume_search_from, 'r') as f:
            best_overrides_dict = distiller.yaml_ordered_load(f)
        msglogger.info('Loaded search checkpoint from %s' % args.resume_search_from)
    overrides_dict = OrderedDict()
    sg = SummaryGraph(model, dummy_input)
    modules_to_quantize = sg.layers_topological_order(recurrent)
    adjacency_map = sg.adjacency_map()
    modules_dict = OrderedDict(model.named_modules())  # type: OrderedDict[str, nn.Module]
    modules_to_quantize = [m for m in modules_to_quantize
                           if m not in args.qe_no_quant_layers]

    module_override_gen_fn = module_override_gen_fn or module_override_generator
    input_override_gen_fn = input_override_gen_fn or input_override_generator

    calib_eval_fn = calib_eval_fn or eval_fn
    if not act_stats:
        msglogger.info('Collecting stats for model...')
        model_temp = distiller.utils.make_non_parallel_copy(model)
        act_stats = collect_quant_stats(model_temp, calib_eval_fn)
        del model_temp
        if args:
            act_stats_path = '%s_act_stats.yaml' % args.arch
            msglogger.info('Done. Saving act stats into %s' % act_stats_path)
            distiller.yaml_ordered_save(act_stats_path, act_stats)
    msglogger.info('Evaluating baseline score for model...')
    base_score = args.base_score or eval_fn(model)
    msglogger.info("Base score: %.3f" % base_score)

    def recalibrate_stats(module_name, act_stats):
        """
        Re-collects quant-calibration stats for successor modules of the current module.
        """
        msglogger.info('Recalibrating stats...')
        modules_to_recalibrate = {op.name for op in adjacency_map[module_name].successors} & set(act_stats)
        if not modules_to_recalibrate:
            # either there aren't any successors or
            # the successors aren't in the stats file - skip
            return act_stats
        q = PostTrainLinearQuantizer(distiller.utils.make_non_parallel_copy(model),
                                     bits_activations=None, bits_parameters=None, bits_accum=32,
                                     mode=LinearQuantMode.ASYMMETRIC_SIGNED, clip_acts=ClipMode.NONE,
                                     overrides=deepcopy(best_overrides_dict),
                                     model_activation_stats=deepcopy(act_stats),
                                     inputs_quant_auto_fallback=False,
                                     per_channel_wts=args.qe_per_channel)
        q.prepare_model(dummy_input)
        # recalibrate on the current best quantized version of the model.
        recalib_act_stats = collect_quant_stats(q.model, calib_eval_fn,
                                                modules_to_collect=modules_to_recalibrate)
        msglogger.info('Done.')
        act_stats.update(recalib_act_stats)
        return act_stats

    loaded_from_checkpoint = []
    # Quantize inputs:
    input_modules = get_inputs_to_quantize(sg, args, recurrent)  # top level modules
    for module_name, input_idxs in input_modules.items():
        denormalized_module_name = distiller.denormalize_module_name(model, module_name)
        module = modules_dict[denormalized_module_name]
        if isinstance(module, SKIP_MODULES):
            msglogger.info('Skipping module \'%s\' of type %s.' % (module_name, type(module)))
            continue

        msglogger.info('Quantizing top level inputs for %s' % module_name)

        normalized_module_name = module_name
        if isinstance(model, nn.DataParallel):
            normalized_module_name = re.sub(r'module\.', '', normalized_module_name)
        if normalized_module_name in best_overrides_dict and \
                best_overrides_dict[normalized_module_name].get('input_overrides', None):
            # This means the loaded dict already has the module
            msglogger.info("  Quantizing '%s' based on loaded checkpoint: %s" %
                           (module_name, best_overrides_dict[normalized_module_name]))
            if best_overrides_dict[normalized_module_name].get('bits_activations'):
                loaded_from_checkpoint.append(normalized_module_name)
            continue

        if not best_overrides_dict.get(normalized_module_name, None):
            best_overrides_dict[normalized_module_name] = OrderedDict()
        for input_idx in input_idxs:
            best_module_override = search_best_local_settings(module, module_name, sg, act_stats, eval_fn,
                                                              best_overrides_dict,
                                                              input_override_gen_fn, input_idx=input_idx,
                                                              bits_activations=args.qe_bits_acts,
                                                              bits_weights=args.qe_bits_wts,
                                                              per_channel=args.qe_per_channel)
            best_overrides_dict[normalized_module_name].update(best_module_override)
        # Leave only the input_overrides settings:
        current_input_overrides = best_overrides_dict[normalized_module_name]['input_overrides']
        best_overrides_dict[normalized_module_name] = override_odict(input_overrides=current_input_overrides)

    # Quantize layers as a whole:
    for module_name in modules_to_quantize:
        module = modules_dict[module_name]
        if isinstance(module, SKIP_MODULES):
            msglogger.info('Skipping module \'%s\' of type %s.' %
                           (module_name, module.__class__.__name__))
            continue

        normalized_module_name = module_name
        if isinstance(model, nn.DataParallel):
            normalized_module_name = re.sub(r'module\.', '', normalized_module_name)

        if normalized_module_name in best_overrides_dict and \
                best_overrides_dict[normalized_module_name].get('bits_activations', None) \
                and normalized_module_name not in loaded_from_checkpoint:
            # This means the loaded dict already has the module
            msglogger.info("  Quantizing '%s'(%s) based on loaded checkpoint: %s" %
                           (module_name, module.__class__.__name__,
                            best_overrides_dict[normalized_module_name]))
            loaded_from_checkpoint.append(normalized_module_name)
            continue

        if not best_overrides_dict.get(normalized_module_name, None):
            best_overrides_dict[normalized_module_name] = OrderedDict()
        # Hard coded workaround for avgpool->reshape->fc
        if normalized_module_name == 'fc':
            input_override = override_odict(bits_activations=8, clip_acts='NONE')
            best_overrides_dict['fc'].update(OrderedDict([
                ('input_overrides', OrderedDict([(0, input_override)]))
            ]))

        best_module_override = search_best_local_settings(module, module_name, sg, act_stats, eval_fn,
                                                          best_overrides_dict,
                                                          module_override_gen_fn,
                                                          bits_activations=args.qe_bits_acts,
                                                          bits_weights=args.qe_bits_wts,
                                                          per_channel=args.qe_per_channel)
        best_overrides_dict[normalized_module_name].update(best_module_override)
        distiller.yaml_ordered_save('%s.ptq_greedy_search.yaml' % args.arch, best_overrides_dict)

        # # end of search - we update the calibration of the next layers:
        # recalibrate_stats(module_name, act_stats)

    quantizer = PostTrainLinearQuantizer(model,
                                         bits_activations=None, bits_parameters=None, bits_accum=32,
                                         mode=LinearQuantMode.ASYMMETRIC_SIGNED, clip_acts=ClipMode.NONE,
                                         overrides=deepcopy(best_overrides_dict),
                                         model_activation_stats=act_stats,
                                         inputs_quant_auto_fallback=False,
                                         per_channel_wts=args.qe_per_channel)
    quantizer.prepare_model(dummy_input)
    msglogger.info('best_overrides_dict: %s' % best_overrides_dict)
    msglogger.info('Best score: %f' % eval_fn(quantizer.model))
    return model, best_overrides_dict
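# --- Hypothetical usage sketch (not part of the original module) ---
# Shows how ptq_greedy_search might be driven. `accuracy_fn` is an assumed
# helper with signature `accuracy_fn(model) -> float`, returning a metric to
# *maximize* (e.g. top-1 accuracy), with any other arguments bound in advance
# (e.g. via functools.partial). The input shape is an ImageNet-like assumption.
def _example_ptq_greedy_search(model, accuracy_fn):
    dummy_input = torch.randn(1, 3, 224, 224)
    # args may be passed as a plain dict; it is merged into get_default_args()
    quantized_model, best_overrides = ptq_greedy_search(
        model, dummy_input, accuracy_fn, args={'arch': 'resnet18'})
    return quantized_model, best_overrides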
def ptq_coordinate_search(quantizer, dummy_input, eval_fn, test_fn=None, method='Powell',
                          maxiter=None, maxfev=None, basinhopping=False, basinhopping_niter=100,
                          init_mode=ClipMode.NONE, init_method=None, search_clipping=False,
                          minimizer_kwargs=None):
    """
    Searches for the optimal post-train quantization configuration (scale/zero_points)
    for a model using numerical methods, as described by scipy.optimize.minimize.
    Args:
        quantizer (distiller.quantization.PostTrainLinearQuantizer): A configured PostTrainLinearQuantizer
            object containing the model being quantized
        dummy_input: a sample of the input expected by the model
        eval_fn (callable): evaluation function for the model. Assumed to have a signature of the form
            `eval_fn(model)->float`. This is the function to be minimized by the optimization algorithm.
        test_fn (callable): a function to test the current performance of the model. Assumed to have a
            signature of the form `test_fn(model)->dict`, where the returned dict contains relevant
            results to be logged. For example: {'top-1': VAL, 'top-5': VAL, 'loss': VAL}
        method (str or callable): Minimization method as accepted by scipy.optimize.minimize.
        maxiter (int): Maximum number of iterations to perform during minimization
        maxfev (int): Maximum number of total function evaluations to perform during minimization
        basinhopping (bool): flag, indicates to use basinhopping as a global-minimization method,
            will pass the `method` argument to `scipy.optimize.basinhopping`.
        basinhopping_niter (int): Number of iterations to perform if basinhopping is set
        init_mode (ClipMode or callable or str or dict): See 'init_linear_quant_params'
        init_method (str or callable): See 'init_layer_linear_quant_params'
        search_clipping (bool): Search on clipping values instead of directly on scale/zero-point
            (scale and zero-point are inferred from the clipping values)
        minimizer_kwargs (dict): Optional additional arguments for scipy.optimize.minimize
    """
    if not isinstance(quantizer, PostTrainLinearQuantizer):
        raise ValueError('Only PostTrainLinearQuantizer supported, but got a {}'.format(
            quantizer.__class__.__name__))
    if quantizer.prepared:
        raise ValueError('Expecting a quantizer for which prepare_model has not been called')

    run_device = distiller.model_device(quantizer.model)

    original_model = deepcopy(quantizer.model).cpu()
    original_model = fold_batch_norms(original_model, dummy_input)

    if not quantizer.model_activation_stats:
        msglogger.info('Collecting stats for model...')
        model_temp = _make_non_parallel_copy(original_model).to(device=run_device)
        act_stats = collect_quant_stats(model_temp, eval_fn,
                                        inplace_runtime_check=True, disable_inplace_attrs=True,
                                        save_dir=getattr(msglogger, 'logdir', '.'))
        if model_temp != original_model:
            del model_temp
        quantizer.model_activation_stats = act_stats
        quantizer.model.quantizer_metadata['params']['model_activation_stats'] = act_stats

    # Preparing model and init conditions:
    msglogger.info("Initializing quantizer...")

    # Make sure weights are re-quantizable and clip-able
    quantizer.save_fp_weights = True
    quantizer.also_clip_weights = True

    # Disable any user set activations clipping - we'll be using init_args
    quantizer.clip_acts = ClipMode.NONE
    for overrides_dict in quantizer.module_overrides_map.values():
        overrides_dict.pop('clip_acts', None)

    quantizer.prepare_model(dummy_input)
    quantizer.model.eval()
    quantizer.model = quantizer.model.cpu()

    validate_quantization_settings(quantizer.model, search_clipping)

    msglogger.info("Initializing quantization parameters...")
    init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, init_mode, init_method,
                             search_clipping=search_clipping, run_device=run_device)

    msglogger.info("Evaluating initial quantization score...")
    best_data = {
        'score': eval_fn(quantizer.model),
        'qp_dict': deepcopy(quantizer.linear_quant_params)
    }
    msglogger.info("Evaluation set loss after initialization %.3f" % best_data['score'])
    if test_fn:
        msglogger.info('Evaluating on full test set...')
        results = test_fn(quantizer.model)
        s = ', '.join(['{} = {:.3f}'.format(k, v) for k, v in results.items()])
        msglogger.info('Test: ' + s)

    init_qp_dict = OrderedDict(quantizer.named_linear_quant_params(search_clipping, filter=True))
    keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, search_clipping)

    iter_counter = count(1)
    eval_counter = count(1)

    def feed_forward_fn(qp_vec):
        # if not _check_qp_vec(keys, qp_vec, quant_mode, args.search_clipping):
        #     return 1e6
        qp_dict = quant_params_vec2dict(keys, qp_vec, search_clipping)
        quantizer.update_linear_quant_params(qp_dict)
        loss = eval_fn(quantizer.model)

        i = next(eval_counter)
        if i % 20 == 0:
            msglogger.info('%d evaluations: loss=%.3f' % (i, loss))

        return loss

    def callback(qp_vec):
        score = feed_forward_fn(qp_vec)
        i = next(iter_counter)
        msglogger.info("Iteration %d: \t Score=%.3f" % (i, score))
        if score < best_data['score']:
            best_data['score'] = score
            best_data['qp_dict'] = quant_params_vec2dict(keys, qp_vec, search_clipping)
            msglogger.info("Saving current best quantization parameters.")
        if test_fn:
            msglogger.info('Evaluating on full test set...')
            results = test_fn(quantizer.model)
            s = ', '.join(['{} = {:.3f}'.format(k, v) for k, v in results.items()])
            msglogger.info('Test: ' + s)

    options = OrderedDict()
    options['maxiter'] = maxiter
    options['maxfev'] = maxfev

    minimizer_kwargs = minimizer_kwargs or OrderedDict()
    minimizer_kwargs.update({
        'method': method, 'options': options
    })
    if basinhopping:
        msglogger.info('Using basinhopping global minimum search with "%s" local minimization method' %
                       method)
        res = opt.basinhopping(feed_forward_fn, init_qp_vec, basinhopping_niter, callback=callback,
                               minimizer_kwargs=minimizer_kwargs)
    else:
        msglogger.info('Using "%s" minimization algorithm.' % method)
        res = opt.minimize(feed_forward_fn, init_qp_vec, callback=callback, **minimizer_kwargs)

    msglogger.info('Optimization done')
    msglogger.info('Best score: {}'.format(best_data['score']))
    msglogger.info('Best Configuration: {}'.format(best_data['qp_dict']))
    return quantizer.model, best_data['qp_dict']
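# --- Hypothetical usage sketch (not part of the original module) ---
# The quantizer must be constructed but *not* yet prepared. `loss_fn` is an
# assumed eval_fn with signature `loss_fn(model) -> float`, returning a loss
# to *minimize* (note the opposite sense of ptq_greedy_search's eval_fn);
# the bit widths and input shape are illustrative assumptions.
def _example_ptq_coordinate_search(model, loss_fn):
    quantizer = PostTrainLinearQuantizer(model, bits_activations=8,
                                         bits_parameters=8, bits_accum=32)
    qmodel, qp_dict = ptq_coordinate_search(
        quantizer, torch.randn(1, 3, 224, 224), loss_fn,
        method='Powell', maxiter=2, search_clipping=False)
    return qmodel, qp_dict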
def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell',
                          options=None, act_stats=None, args=None,
                          fold_sequences=True, basinhopping=False,
                          init_args=None, minimizer_kwargs=None,
                          test_fn=None):
    """
    Searches for the optimal post-train quantization configuration (scale/zero_points)
    for a model using numerical methods, as described by scipy.optimize.minimize.
    Args:
        model (nn.Module): model to quantize
        dummy_input: a sample of the input expected by the model
        eval_fn (callable): evaluation function for the model. Assumed to have a signature of the form
            `eval_fn(model)->float`. This is the function to be minimized by the optimization algorithm.
        method (str or callable): minimization method as accepted by scipy.optimize.minimize.
        options (dict or None): options for the scipy optimizer
        act_stats (OrderedDict): dictionary of statistics per layer, including inputs and outputs.
            For more context refer to collect_quant_stats.
        args: arguments from the command line. Alternatively - a dict.
        fold_sequences (bool): flag, indicates to fold sequences before performing the search.
        basinhopping (bool): flag, indicates to use basinhopping as a global-minimization method,
            will pass the `method` argument to `scipy.optimize.basinhopping`.
        init_args (tuple): arguments for initializing the linear quantization parameters.
            Refer to `init_linear_quant_params` for more details.
        minimizer_kwargs (dict): the kwargs for the scipy.optimize.minimize procedure.
        test_fn (callable): a function to test the current performance of the model.
    """
    if fold_sequences:
        model = fold_batch_norms(model, dummy_input)
    if args is None:
        args = get_default_args()
    elif isinstance(args, dict):
        updated_args = get_default_args()
        updated_args.__dict__.update(args)
        args = updated_args
    original_model = deepcopy(model)

    if not act_stats and not args.qe_config_file:
        msglogger.info('Collecting stats for model...')
        model_temp = distiller.utils.make_non_parallel_copy(model)
        act_stats = collect_quant_stats(model_temp, eval_fn)
        del model_temp
        if args:
            act_stats_path = '%s_act_stats.yaml' % args.arch
            msglogger.info('Done. Saving act stats into %s' % act_stats_path)
            distiller.yaml_ordered_save(act_stats_path, act_stats)
            args.qe_stats_file = act_stats_path

    # Preparing model and init conditions:
    msglogger.info("Initializing quantizer...")
    quantizer = PostTrainLinearQuantizer.from_args(model, args)

    # Make sure weights are re-quantizable and clip-able
    quantizer.save_fp_weights = True
    quantizer.also_clip_weights = True

    # Disable any user set activations clipping - we'll be using init_args
    quantizer.clip_acts = ClipMode.NONE
    for overrides_dict in quantizer.module_overrides_map.values():
        overrides_dict.pop('clip_acts', None)

    quantizer.prepare_model(dummy_input)
    quantizer.model.eval()

    validate_quantization_settings(args, quantizer.model)

    msglogger.info("Initializing quantization parameters...")
    init_args = init_args or (args.init_mode, args.init_mode_method)
    init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, *init_args,
                             search_clipping=args.search_clipping)

    msglogger.info("Evaluating initial quantization score...")
    best_data = {
        'score': eval_fn(model),
        'qp_dict': deepcopy(quantizer.linear_quant_params)
    }
    msglogger.info("Evaluation set loss after initialization %.3f" % best_data['score'])
    if test_fn:
        msglogger.info('Evaluating on full test set...')
        l_top1, l_top5, l_loss = test_fn(quantizer.model)
        msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' % (l_loss, l_top1, l_top5))

    init_qp_dict = OrderedDict(quantizer.named_linear_quant_params(args.search_clipping, filter=True))
    keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, args.search_clipping)

    iter_counter = count(1)
    eval_counter = count(1)

    def feed_forward_fn(qp_vec):
        # if not _check_qp_vec(keys, qp_vec, quant_mode, args.search_clipping):
        #     return 1e6
        qp_dict = quant_params_vec2dict(keys, qp_vec, args.search_clipping)
        quantizer.update_linear_quant_params(qp_dict)
        loss = eval_fn(quantizer.model)

        i = next(eval_counter)
        if i % 20 == 0:
            msglogger.info('%d evaluations: loss=%.3f' % (i, loss))

        return loss

    def callback(qp_vec):
        score = feed_forward_fn(qp_vec)
        i = next(iter_counter)
        msglogger.info("Iteration %d: \t Score=%.3f" % (i, score))
        if score < best_data['score']:
            best_data['score'] = score
            best_data['qp_dict'] = quant_params_vec2dict(keys, qp_vec, args.search_clipping)
            msglogger.info("Saving current best quantization parameters.")
        if test_fn:
            msglogger.info('Evaluating on full test set...')
            l_top1, l_top5, l_loss = test_fn(quantizer.model)
            msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' % (l_loss, l_top1, l_top5))

    options = options or OrderedDict()
    if args.maxiter is not None:
        options['maxiter'] = args.maxiter
    if args.maxfev is not None:
        options['maxfev'] = args.maxfev
    minimizer_kwargs = minimizer_kwargs or OrderedDict()
    minimizer_kwargs.update({
        'method': method, 'options': options
    })
    basinhopping = basinhopping or args.basinhopping
    if basinhopping:
        msglogger.info('Using basinhopping global minimum search with "%s" local minimization method' %
                       method)
        res = opt.basinhopping(feed_forward_fn, init_qp_vec, args.niter, callback=callback,
                               minimizer_kwargs=minimizer_kwargs)
    else:
        msglogger.info('Using "%s" minimization algorithm.' % method)
        res = opt.minimize(feed_forward_fn, init_qp_vec, callback=callback, **minimizer_kwargs)

    msglogger.info("Optimization done. Best configuration: %s" % best_data['qp_dict'])
    return model, best_data['qp_dict']
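# --- Hypothetical usage sketch (not part of the original module) ---
# The args-driven variant: a plain dict is merged into get_default_args(),
# and the quantizer is built internally via PostTrainLinearQuantizer.from_args.
# `loss_fn` is an assumed eval_fn returning a loss to minimize; the dict keys
# below mirror attributes this function reads off `args` and are assumptions
# about the available defaults.
def _example_ptq_coordinate_search_args(model, loss_fn):
    qmodel, qp_dict = ptq_coordinate_search(
        model, torch.randn(1, 3, 224, 224), loss_fn,
        method='Powell', args={'arch': 'resnet18', 'maxiter': 2},
        fold_sequences=True)
    return qmodel, qp_dict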