    def __init__(self, algo: 'QuantizationController', config: Config,
                 default_activation_bitwidth: int, default_weight_bitwidth: int, criterion: _Loss,
                 data_loader: DataLoader, is_distributed: bool = False):
        super().__init__(algo, config, default_activation_bitwidth, default_weight_bitwidth,
                         criterion, data_loader, is_distributed)
        self._traces_per_layer_path = config.get('traces_per_layer_path', None)
        self._num_data_points = config.get('num_data_points', 200)
        self._iter_number = config.get('iter_number', 200)
        self._tolerance = config.get('tolerance', 1e-5)
        self._bits = config.get('bits', [4, 8])
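
# A minimal sketch of the configuration keys read above, assuming Config
# behaves like a plain mapping; the values shown are the defaults used by the
# get() calls, and the variable name is illustrative only.
example_precision_init_config = {
    'traces_per_layer_path': None,  # optional path to precomputed per-layer traces
    'num_data_points': 200,         # number of data points used for initialization
    'iter_number': 200,             # iteration budget for the initialization procedure
    'tolerance': 1e-5,              # convergence tolerance
    'bits': [4, 8],                 # candidate bitwidths to choose from
}
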
    def __init__(self, algo: 'QuantizationController', config: Config,
                 default_activation_bitwidth: int, default_weight_bitwidth: int,
                 criterion: _Loss, data_loader: DataLoader, is_distributed: bool = False):
        self._algo = algo
        self._model = self._algo._model  # type: NNCFNetwork
        self._bitwidth_per_scope = config.get('bitwidth_per_scope', {})  # type: List[List]
        self._default_activation_bitwidth = default_activation_bitwidth
        self._default_weight_bitwidth = default_weight_bitwidth
        self._criterion = criterion
        self._data_loader = data_loader
        self._is_distributed = is_distributed

        # Collect every quantizer module in the model, keyed by scope, and keep
        # the weight quantizers in model order. The module dictionaries do not
        # depend on the quantizer class, so they are retrieved once before the loop.
        self._all_quantizations = {}
        self._ordered_weight_quantizations = []
        act_module_dict = self._model.get_compression_modules_by_type(
            CompressionModuleType.ACTIVATION_QUANTIZER)
        func_module_dict = self._model.get_compression_modules_by_type(
            CompressionModuleType.FUNCTION_QUANTIZER)
        weight_module_dict = self._model.get_nncf_wrapped_model()
        for class_type in QUANTIZATION_MODULES.registry_dict.values():
            quantization_type = class_type.__name__
            self._all_quantizations.update(get_all_modules_by_type(act_module_dict, quantization_type))
            self._all_quantizations.update(get_all_modules_by_type(func_module_dict, quantization_type))
            ops_quantizations = get_all_modules_by_type(weight_module_dict, quantization_type)
            self._ordered_weight_quantizations.extend(q for q in ops_quantizations.values() if q.is_weights)
            self._all_quantizations.update(ops_quantizations)
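
# A minimal sketch of the module-collection pattern used above, assuming
# get_all_modules_by_type walks a module tree and keys each match by its scope
# name; collect_by_class_name is a hypothetical stand-in, not the NNCF helper.
import torch.nn as nn

def collect_by_class_name(root: nn.Module, class_name: str) -> dict:
    """Gather all submodules whose class name matches, keyed by dotted scope."""
    found = {}
    for scope, module in root.named_modules():
        if type(module).__name__ == class_name:
            found[scope] = module
    return found
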
class PruningScheduler(CompressionScheduler):
    def __init__(self, pruning_algo, params: Config = None):
        super().__init__()
        if params is None:
            self._params = Config()
        else:
            self._params = params

        self.algo = pruning_algo

        # Number of initial training epochs before pruning starts
        self.num_init_steps = self._params.get('num_init_steps', 0)
        # Number of epochs over which the pruning level is scheduled
        self.pruning_steps = self._params.get('pruning_steps', 100)

        # Initial and target pruning levels
        self.initial_pruning = self._params.get('pruning_init', 0)
        self.pruning_target = self._params.get('pruning_target', 0.5)

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self._set_pruning_level()

    def epoch_step(self, epoch=None):
        super().epoch_step(epoch)
        self._set_pruning_level()

    def _set_pruning_level(self):
        self.algo.set_pruning_rate(self.current_pruning_level)

        # Freeze pruning once the scheduled pruning phase is over
        if self.last_epoch >= (self.pruning_steps + self.num_init_steps):
            self.algo.freeze()

    def _calc_density_level(self):
        """Return the current density level (1 - pruning level); implemented by subclasses."""
        raise NotImplementedError

    @property
    def current_pruning_level(self):
        """Pruning level for the current epoch; zero during the initial warm-up epochs."""
        if self.last_epoch >= self.num_init_steps:
            return 1 - self._calc_density_level()
        return 0
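
# A minimal sketch of a concrete scheduler, assuming a linear schedule; this
# subclass is hypothetical and not part of the original source. Density decays
# from (1 - pruning_init) to (1 - pruning_target) over pruning_steps epochs.
class LinearPruningScheduler(PruningScheduler):
    def _calc_density_level(self):
        # Progress through the pruning phase, clamped to [0, 1]; guard against
        # a zero-length pruning phase.
        progress = (self.last_epoch - self.num_init_steps) / max(1, self.pruning_steps)
        progress = min(1.0, max(0.0, progress))
        pruning_level = self.initial_pruning + (self.pruning_target - self.initial_pruning) * progress
        return 1.0 - pruning_level
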
def create_compressed_model(
        model: Module,
        config: Config,
        dummy_forward_fn: Callable[[Module], Any] = None,
        dump_graphs=True
) -> Tuple[CompressionAlgorithmController, NNCFNetwork]:
    """
    The main function used to produce a model ready for compression fine-tuning from an original PyTorch
    model and a configuration object.
    dummy_forward_fn
    :param model: The original model. Should have its parameters already loaded from a checkpoint or another
    source.
    :param config: A configuration object used to determine the exact compression modifications to be applied
    to the model
    :param dummy_forward_fn: will be used instead of a *forward* function call to build
    the internal graph representation via tracing. Specifying this is useful when the original training pipeline
    has special formats of data loader output or has additional *forward* arguments other than input tensors.
    Otherwise, the *forward* call of the model during graph tracing will be made with mock tensors according
    to the shape specified in the config object.
    :param dump_graphs: Whether or not should also dump the internal graph representation of the
    original and compressed models in the .dot format into the log directory.
    :return: A controller for the compression algorithm (or algorithms, in which case the controller
    is an instance of CompositeCompressionController) and the model ready for compression wrapped
    as an object of NNCFNetwork."""

    # The input infos are needed both for graph dumping and for model wrapping,
    # so they are created once up front.
    input_info_list = create_input_infos(config)

    if dump_graphs:
        if dummy_forward_fn is None:
            graph_builder = GraphBuilder(
                custom_forward_fn=create_dummy_forward_fn(
                    input_info_list, with_input_tracing=True))
        else:
            graph_builder = GraphBuilder(custom_forward_fn=dummy_forward_fn)

        if is_main_process():
            graph = graph_builder.build_graph(model)
            graph.dump_graph(osp.join(config.log_dir, "original_graph.dot"),
                             extended=True)

    if is_debug():
        set_debug_log_dir(config.log_dir)

    scopes_without_shape_matching = config.get('scopes_without_shape_matching',
                                               [])
    ignored_scopes = config.get('ignored_scopes')
    target_scopes = config.get('target_scopes')

    compressed_model = NNCFNetwork(
        model,
        input_infos=input_info_list,
        dummy_forward_fn=dummy_forward_fn,
        ignored_scopes=ignored_scopes,
        target_scopes=target_scopes,
        scopes_without_shape_matching=scopes_without_shape_matching)

    compression_algo_builder_list = create_compression_algorithm_builders(
        config)

    for builder in compression_algo_builder_list:
        compressed_model = builder.apply_to(compressed_model)
    compression_ctrl = compressed_model.commit_compression_changes()

    if dump_graphs and is_main_process() and compression_algo_builder_list:
        if dummy_forward_fn is None:
            compressed_graph_builder = GraphBuilder(
                custom_forward_fn=create_dummy_forward_fn(
                    input_info_list, with_input_tracing=False))
        else:
            compressed_graph_builder = GraphBuilder(
                custom_forward_fn=dummy_forward_fn)

        graph = compressed_graph_builder.build_graph(
            compressed_model, compressed_model.get_tracing_context())
        graph.dump_graph(osp.join(config.log_dir, "compressed_graph.dot"),
                         extended=True)

    return compression_ctrl, compressed_model
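
# A minimal usage sketch; Config.from_json as a loader, the config file path,
# and the torchvision model are assumptions standing in for a real pipeline.
if __name__ == '__main__':
    import torchvision.models as models

    model = models.resnet18()
    config = Config.from_json('nncf_config.json')  # assumed loader and path
    compression_ctrl, compressed_model = create_compressed_model(model, config)
    # Fine-tune compressed_model as usual; the returned controller exposes the
    # compression loss and scheduler hooks for the training loop.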