Example #1
    def statistics(self,
                   quickly_collected_only: bool = False) -> NNCFStatistics:
        if not quickly_collected_only and is_debug():
            stats = PrunedModelTheoreticalBorderline(self._pruned_layers_num,
                                                     self._prunable_layers_num,
                                                     self._max_prunable_flops,
                                                     self._max_prunable_params,
                                                     self.full_flops,
                                                     self.full_params_num)

            nncf_logger.debug(stats.to_str())

        pruned_layers_summary = self._calculate_pruned_layers_summary()
        self._update_benchmark_statistics()
        model_statistics = PrunedModelStatistics(
            self.full_flops, self.current_flops, self.full_params_num,
            self.current_params_num, self.full_filters_num,
            self.current_filters_num, pruned_layers_summary)

        stats = FilterPruningStatistics(model_statistics,
                                        self.scheduler.current_pruning_level,
                                        self.scheduler.target_level,
                                        self.prune_flops)

        nncf_stats = NNCFStatistics()
        nncf_stats.register('filter_pruning', stats)
        return nncf_stats
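A hedged caller-side sketch: NNCFStatistics is assumed to expose the name registered above ('filter_pruning') as an attribute, and the attribute names on the statistics objects are taken from the constructor arguments shown; the controller variable is illustrative.

# Hypothetical usage of the statistics() method above.
stats = pruning_ctrl.statistics(quickly_collected_only=True)  # skips the debug-only borderline report
fp_stats = stats.filter_pruning  # assumed: register() exposes entries as attributes
print(fp_stats.model_statistics.current_flops)  # attribute names per the constructors above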
Example #2
    def __enter__(self):
        global _CURRENT_CONTEXT
        self._save_context = _CURRENT_CONTEXT
        _CURRENT_CONTEXT = self
        self._init_thread_local()
        if is_debug():
            self.reset_node_call_counters()

        return self
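For symmetry, a minimal sketch of the matching __exit__, inferred from the save/restore pattern above (the real implementation may do additional cleanup):

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CURRENT_CONTEXT
        # Unwind to the context that was active before __enter__, so nested
        # "with" blocks restore their outer context correctly.
        _CURRENT_CONTEXT = self._save_context
        self._save_context = None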
Example #3
    def _get_potential_quantizers_num(self) -> Tuple[int, Optional[int]]:
        """
        Returns the potential number of quantizers for weights and activations.

        :return: A tuple (wq_potential_num, aq_potential_num) where
            - `wq_potential_num` is the potential number of weight quantizers.
            - `aq_potential_num` is the potential number of activation quantizers,
              or None if debug mode is disabled.
        """
        aq_potential_num = self._info.aq_potential_num if is_debug() else None
        return self._info.wq_potential_num, aq_potential_num
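A hedged caller-side sketch of handling the debug-only activation count (the algo variable is illustrative):

wq_num, aq_num = algo._get_potential_quantizers_num()
# aq_potential_num is only collected in debug mode, so guard the None case.
if aq_num is not None:
    nncf_logger.debug('Potential quantizers: {} weight / {} activation'.format(wq_num, aq_num))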
Example #4
    def evaluate_strategy(self,
                          collected_strategy: List,
                          skip_constraint=True) -> Tuple:
        assert len(collected_strategy) == len(self.master_df)
        if not skip_constraint:
            collected_strategy = self._constrain_model_size(collected_strategy)
        self.master_df['action'] = collected_strategy  # Must be set after the size constraint

        if self.performant_bw:
            self._align_bw_action()
            configs_to_set = self.select_config_for_actions(
                self.master_df['action_aligned'])

            if self._dump_autoq_data or is_debug():
                self._dump_adjacent_quantizer_group_alignment()

            self.master_df['action'] = self.master_df['action_aligned']
        else:
            configs_to_set = self.select_config_for_actions(
                self.master_df['action'])

        self._apply_quantizer_configs_to_model(configs_to_set)

        for idx, qid in zip(self.master_df.index, self.master_df['qid']):
            logger.info("[Q.Env] {:50} | {}".format(
                str(self.qctrl.all_quantizations[find_qid_by_str(
                    self.qctrl, qid)]), idx))

        quantized_score = self._run_quantization_pipeline(
            finetune=self.finetune)

        current_model_size = self.model_size_calculator(
            self._get_quantizer_bitwidth())
        current_model_ratio = self.model_size_calculator.get_model_size_ratio(
            self._get_quantizer_bitwidth())

        current_model_bop_ratio = self.compression_ratio_calculator.run_for_quantizer_setup(
            self.qctrl.get_quantizer_setup_for_current_state())

        reward = self.reward(quantized_score, current_model_ratio)

        info_set = {
            'model_ratio': current_model_ratio,
            'accuracy': quantized_score,
            'model_size': current_model_size,
            'bop_ratio': current_model_bop_ratio
        }

        obs = self.get_normalized_obs(len(collected_strategy) - 1)
        done = True
        self._n_eval += 1

        return obs, reward, done, info_set
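evaluate_strategy is gym-like but single-step: done is always True, so one episode is one full strategy evaluation. A hedged sketch of the consuming rollout (the agent API is an assumption):

# Hypothetical rollout: one action (precision choice) per quantizable layer.
collected_strategy = []
for i in range(len(env.master_df)):
    obs = env.get_normalized_obs(i)
    collected_strategy.append(agent.select_action(obs))  # assumed agent method
obs, reward, done, info = env.evaluate_strategy(collected_strategy, skip_constraint=False)
assert done  # the episode always terminates after a single evaluation
print(info['accuracy'], info['model_ratio'])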
Example #5
    def forward(self, x):
        if is_debug():
            self.call_count += 1
        # TODO: refactor to get rid of extra if's and calls on each forward
        if not self.is_enabled_quantization():
            return x
        self.set_level_ranges()
        is_exporting = is_tracing_state()
        if is_exporting:
            with no_nncf_trace():
                x = self.run_export_quantization(x)

            # The underlying operator (registered via register_operator) must be executed,
            # otherwise the dynamic graph won't be traced as it was during regular inference.
            # While this does not impact the regular, non-RNN models, for which the graph
            # building and pre-/post-hook calling is only determined by input-agnostic,
            # graph-structure independent trace info (e.g. current op scope and call count),
            # this is important for LSTMs etc. where determining the "first nodes in iteration
            # scopes" depends on whether the input tensors to an operation were traced or not.
            return self.quantize(x, execute_traced_op_as_identity=True)

        return self.quantize(x, execute_traced_op_as_identity=False)
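no_nncf_trace() is used above to keep export-time quantization arithmetic out of the traced graph. A hedged sketch of the same pattern for any auxiliary computation (context-manager semantics assumed from the usage above; debug_stats is hypothetical):

    def forward(self, x):
        with no_nncf_trace():
            # Executes normally but leaves no nodes in the dynamic graph, so
            # auxiliary arithmetic cannot distort graph matching.
            stats = x.abs().mean()
        self.debug_stats.append(stats)  # hypothetical side channel
        return x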
Example #6
    def statistics(self,
                   quickly_collected_only: bool = False) -> NNCFStatistics:
        if not quickly_collected_only and is_debug():
            stats = PrunedModelTheoreticalBorderline(self._pruned_layers_num,
                                                     self._prunable_layers_num,
                                                     self._max_prunable_flops,
                                                     self._max_prunable_params,
                                                     self.full_flops,
                                                     self.full_params_num)

            nncf_logger.debug(stats.to_str())

        pruned_layers_summary = {}
        for minfo in self.pruned_module_groups_info.get_all_nodes():
            layer_name = str(minfo.module_scope)
            if layer_name not in pruned_layers_summary:
                pruned_layers_summary[layer_name] = \
                    PrunedLayerSummary(layer_name,
                                       list(minfo.module.weight.size()),
                                       list(self.mask_shape(minfo)),
                                       self.pruning_level_for_mask(minfo))

        self._update_benchmark_statistics()
        model_statistics = PrunedModelStatistics(
            self.full_flops, self.current_flops, self.full_params_num,
            self.current_params_num, self.full_filters_num,
            self.current_filters_num, list(pruned_layers_summary.values()))

        stats = FilterPruningStatistics(model_statistics,
                                        self.scheduler.current_pruning_level,
                                        self.scheduler.target_level,
                                        self.prune_flops)

        nncf_stats = NNCFStatistics()
        nncf_stats.register('filter_pruning', stats)
        return nncf_stats
Example #7
    def __init__(self, model: NNCFNetwork,
                 quantization_controller: ExperimentalQuantizationController,
                 hw_precision_constraints: HardwareQuantizationConstraints,
                 eval_loader: torch.utils.data.DataLoader,
                 eval_fn: Callable[[nn.Module, torch.utils.data.DataLoader],
                                   float], hw_config_type: HWConfigType,
                 params: QuantizationEnvParams):

        logger.info("[Q.Env] Instantiating NNCF Quantization Environment")
        self.qctrl = quantization_controller
        self.qmodel = model
        self.eval_loader = eval_loader
        self.eval_fn = eval_fn
        self._hw_precision_constraints = hw_precision_constraints
        self._bn_adaptation = None

        self.model_name = self.qmodel.nncf_module.__class__.__name__

        # Check and only proceed if target device is supported by Q.Env
        self.hw_cfg_type = hw_config_type
        assert self.hw_cfg_type in [None, HWConfigType.VPU]

        # Set target compression ratio
        self.compression_ratio = params.compression_ratio

        self.eval_loader = PartialDataLoader(
            self.eval_loader, iter_ratio=params.eval_subset_ratio)

        # Bool to disable hard resource constraint
        self.skip_constraint = params.skip_constraint

        # Bool to enable bw alignment of adj. Q group to lower precision
        self.performant_bw = params.performant_bw

        # Bool to enable fine-tuning in each episode. Placeholder for now
        self.finetune = False

        # Counter for number of evaluate_strategy calls
        self._n_eval = 0

        # Configure search space for precision according to target device
        if self.hw_cfg_type is None:
            self.model_bitwidth_space = params.bits
        elif self.hw_cfg_type is HWConfigType.VPU:
            self.model_bitwidth_space = \
                self._hw_precision_constraints.get_all_unique_bitwidths()
        self.model_bitwidth_space = sorted(self.model_bitwidth_space)

        # Create mapping of QuantizerId to the space of the corresponding quantizer's allowed qconfigs
        #pylint:disable=line-too-long
        self.qconfig_space_map = OrderedDict.fromkeys(
            self.qctrl.all_quantizations.keys())  # type: Dict[QuantizerId, List[QuantizerConfig]]
        if self.hw_cfg_type is None:
            for qid in self.qconfig_space_map.keys():
                conf = self.qctrl.all_quantizations[qid].get_quantizer_config()
                conf_list_to_set = []
                for bit in self.model_bitwidth_space:
                    bit_adjusted_conf = deepcopy(conf)
                    bit_adjusted_conf.num_bits = bit
                    conf_list_to_set.append(bit_adjusted_conf)
                self.qconfig_space_map[qid] = conf_list_to_set
        else:
            for qid in self.qconfig_space_map:
                conf_list_to_set = []
                bw_vs_qconfigs_dict = self._hw_precision_constraints.get_bitwidth_vs_qconfigs_dict(
                    qid)
                for bitwidth, qconf_list in bw_vs_qconfigs_dict.items():
                    target_qconf = qconf_list[0]
                    if len(qconf_list) > 1:
                        logger.warning(
                            "Received multiple quantizer configurations {qc_lst} for same bitwidth {bw} "
                            "for quantizer {q} - AutoQ can currently only choose among bitwidths, but not "
                            "within quantizer configuration space with the same bitwidths. Selecting {qc} "
                            "as the target configuration for bitwidth {bw}".
                            format(qc_lst=";".join(
                                [str(qconf) for qconf in qconf_list]),
                                   bw=bitwidth,
                                   q=str(qid),
                                   qc=str(target_qconf)))
                    conf_list_to_set.append(target_qconf)

                self.qconfig_space_map[qid] = conf_list_to_set

        # Quantizer Master Table Creation
        self.groups_of_adjacent_quantizers = self.qctrl._groups_of_adjacent_quantizers
        self.quantizer_table = self._create_quantizer_table()

        # Create master dataframe to keep track of quantizable layers and their attributes
        self.master_df, self.state_list = self._get_state_space(
            self.qctrl, self.qmodel, self.quantizer_table)
        if self.master_df.isnull().values.any():
            raise ValueError("Q.Env Master Dataframe has null value(s)")

        assert len(self.quantizer_table) == len(self.qctrl.all_quantizations), \
            "Number of quantizers in the quantizer table does not tally with the quantization controller"

        # MinMaxScaler for State Embedding
        self.state_scaler = MinMaxScaler()
        self.state_scaler.fit(self.master_df[self.state_list])

        # Mapping required for quantizer BW alignment flow
        self.adjq_groupwise_intersecting_bw_space = \
            self._create_map_of_adjq_groupid_to_common_bw_space()
        self.adjq_groupwise_df_lut_keys = \
            self._create_map_of_adjq_groupid_to_df_lut_keys()

        # Model Size Calculation
        self.model_size_calculator = ModelSizeCalculator(
            self.qmodel, self.qconfig_space_map)
        self.orig_model_size = self.model_size_calculator.fp_model_size
        self.min_model_size = self.model_size_calculator.min_model_size
        self.max_model_size = self.model_size_calculator.max_model_size
        self.target_model_size = self.orig_model_size * self.compression_ratio

        if not self.min_model_size <= self.target_model_size <= self.max_model_size:
            raise ValueError(
                "Model Size Ratio {} is out of bounds ({}, {})".format(
                    self.compression_ratio,
                    self.min_model_size / self.orig_model_size,
                    self.max_model_size / self.orig_model_size))

        # Compression Ratio Calculation (BOP relative to 8-bit)
        self.compression_ratio_calculator = CompressionRatioCalculator(
            self.qmodel.get_flops_per_module(),
            self.qctrl.get_quantizer_setup_for_current_state(), self.qctrl.
            groups_of_adjacent_quantizers.weight_qp_id_per_activation_qp_id)

        # Evaluate and store metric score of pretrained model
        self._evaluate_pretrained_model()
        self.qmodel_init_sd = deepcopy(self.qmodel.state_dict())

        self.reset()

        self._dump_autoq_data = params.dump_init_precision_data
        if self._dump_autoq_data or is_debug():
            dump_dir = params.log_dir
            if dump_dir is None:
                dump_dir = DEBUG_LOG_DIR
            self.dump_dir = Path(dump_dir) / Path("autoq_env_dump")
            self.dump_dir.mkdir(parents=True, exist_ok=True)
            # Serialize Q.Env information. Note that these functions should be at the end of Q.Env Initialization.
            self._dump_master_df()
            self._dump_quantized_graph()
            self._dump_groups_of_adjacent_quantizers()
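A worked micro-example of the model-size bound check near the end of __init__, assuming model size scales linearly with weight bitwidth (which is how a size calculator over qconfig bitwidths would behave):

# Assume a 100 MB FP32 model and a bitwidth space of {2, 8}:
#   min_model_size = 100 * 2/32 = 6.25 MB, max_model_size = 100 * 8/32 = 25 MB
# compression_ratio = 0.15 -> target = 15 MB, inside [6.25, 25]: accepted.
# compression_ratio = 0.50 -> target = 50 MB, above the maximum: ValueError,
# reported against the bounds (6.25/100, 25/100) = (0.0625, 0.25).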
Example #8
    def apply_init(self) -> SingleConfigQuantizerSetup:
        if not self._weight_quantizations_by_execution_order:
            return self._algo.get_quantizer_setup_for_current_state()

        original_device = next(self._model.parameters()).device
        self._model.to(self._init_device)

        traces_per_layer = self._calc_traces(self._criterion_fn,
                                             self._criterion,
                                             self._iter_number,
                                             self._tolerance)
        if not traces_per_layer:
            raise RuntimeError('Failed to calculate hessian traces!')

        traces_order = traces_per_layer.traces_order
        weight_qconfig_sequences_in_trace_order, covering_qconfig_sequences = \
            self.get_qconfig_sequences_constrained_by_traces_order(traces_order)

        weight_quantizer_ids_in_execution_order = list(
            self._weight_quantizations_by_execution_order.keys())

        if not weight_qconfig_sequences_in_trace_order:
            warnings.warn(
                'All bitwidth configurations are incompatible with HW Config!',
                RuntimeWarning)
            return None

        weight_qconfig_sequences_in_trace_order = \
            self._filter_qconfig_sequences_by_excessive_bitwidth(weight_qconfig_sequences_in_trace_order)

        if self._bitwidth_assignment_mode == BitwidthAssignmentMode.STRICT:
            weight_qconfig_sequences_in_trace_order = \
                self._filter_qconfig_sequences_by_grouped_weight_quantizers(weight_qconfig_sequences_in_trace_order,
                                                                            weight_quantizer_ids_in_execution_order,
                                                                            self._groups_of_adjacent_quantizers,
                                                                            traces_order)
        if not weight_qconfig_sequences_in_trace_order:
            warnings.warn(
                'No bitwidth configurations are left after removing inconsistent groups of weight quantizers'
                ' with adjacent activation quantizers!', RuntimeWarning)
            return self._algo.get_quantizer_setup_for_current_state()

        compression_ratio_per_qconfig = self.get_compression_ratio_per_qconfig_sequence(
            weight_qconfig_sequences_in_trace_order, traces_order)
        min_ratio = min(compression_ratio_per_qconfig)
        max_ratio = max(compression_ratio_per_qconfig)
        if not min_ratio <= self._compression_ratio <= max_ratio:
            raise AttributeError(
                'Invalid compression ratio={}. Should be within range [{:.3f}, {:.3f}]'
                .format(self._compression_ratio, min_ratio, max_ratio))

        perturbations, weight_observers = self.calc_quantization_noise(
            covering_qconfig_sequences, traces_order)

        metric_per_qconfig_sequence = self.calc_hawq_metric_per_qconfig_sequence(
            weight_qconfig_sequences_in_trace_order, perturbations,
            traces_per_layer, self._init_device)

        qconfig_sequence_index = self.choose_qconfig_sequence(
            metric_per_qconfig_sequence, compression_ratio_per_qconfig,
            self._compression_ratio)
        chosen_qconfig_sequence_in_traces_order = weight_qconfig_sequences_in_trace_order[
            qconfig_sequence_index]
        chosen_qconfig_sequence_in_execution_order = traces_order.get_execution_order_configs(
            chosen_qconfig_sequence_in_traces_order)
        bitwidth_sequence = [
            qconfig.num_bits
            for qconfig in chosen_qconfig_sequence_in_execution_order
        ]
        nncf_logger.info(
            'Chosen HAWQ bitwidth sequence with ratio={:.2f}, bitwidth per weightable layer={}'
            .format(compression_ratio_per_qconfig[qconfig_sequence_index],
                    bitwidth_sequence))
        nncf_logger.debug(
            'Order of the weightable layers in the HAWQ bitwidth sequence (in descending order of average'
            ' Hessian traces) ={}'.format(traces_order))

        final_quantizer_setup = self.get_quantizer_setup_for_qconfig_sequence(
            chosen_qconfig_sequence_in_traces_order, traces_order)
        if is_debug() or self._dump_hawq_data:
            hawq_debugger = HAWQDebugger(
                weight_qconfig_sequences_in_trace_order, perturbations,
                weight_observers, traces_per_layer, self._bitwidths)
            hawq_debugger.dump_metric_MB(metric_per_qconfig_sequence)
            hawq_debugger.dump_metric_flops(metric_per_qconfig_sequence,
                                            compression_ratio_per_qconfig,
                                            qconfig_sequence_index)
            hawq_debugger.dump_avg_traces()
            hawq_debugger.dump_density_of_quantization_noise()
            hawq_debugger.dump_perturbations_ratio()
            new_ctrl, new_model = self._algo.apply_new_quantizer_setup(
                final_quantizer_setup)
            groups_of_adjacent_quantizers = new_ctrl.groups_of_adjacent_quantizers
            hawq_debugger.dump_bitwidth_graph(new_ctrl, new_model,
                                              groups_of_adjacent_quantizers)
        bitwidth_per_scope = self.get_bitwidth_per_scope(final_quantizer_setup)
        str_bw = [
            str(element)
            for element in self.get_bitwidth_per_scope(final_quantizer_setup)
        ]
        nncf_logger.info('\n'.join(
            ['\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))
        from nncf.common.utils.debug import DEBUG_LOG_DIR
        Path(DEBUG_LOG_DIR).mkdir(parents=True, exist_ok=True)
        with safe_open(Path(DEBUG_LOG_DIR) / 'bitwidth_per_scope.json',
                       "w") as outfile:
            json.dump({'bitwidth_per_scope': bitwidth_per_scope},
                      outfile,
                      indent=4,
                      sort_keys=False)
        self._model.to(original_device)
        return final_quantizer_setup
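A hedged illustration of the two orderings that apply_init juggles (values are invented; get_execution_order_configs is assumed to invert the permutation, per its name and usage above):

# Three weightable layers, execution order [conv1, conv2, fc] with average
# Hessian traces [0.9, 0.1, 0.5]. Descending-trace order: [conv1, fc, conv2].
# A qconfig sequence in trace order of [8, 6, 4] bits therefore means
#   conv1 -> 8 bits (most sensitive), fc -> 6, conv2 -> 4,
# i.e. [8, 4, 6] once mapped back to execution order.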
Example #9
    def apply_init(self) -> SingleConfigQuantizerSetup:
        from nncf.torch.automl.environment.quantization_env import QuantizationEnv
        from nncf.torch.automl.agent.ddpg.ddpg import DDPG
        from nncf.common.utils.debug import DEBUG_LOG_DIR

        if self._dump_autoq_data or is_debug():
            dump_dir = self._init_args.config.get('log_dir', None)
            if dump_dir is None:
                dump_dir = DEBUG_LOG_DIR
            self.dump_dir = Path(dump_dir) / Path("autoq") / Path(
                "autoq_agent_dump")
            self.dump_dir.mkdir(parents=True, exist_ok=True)

            self.policy_dict = OrderedDict()  #key: episode
            self.best_policy_dict = OrderedDict()  #key: episode

            self._init_args.config['episodic_nncfcfg'] = str(
                self.dump_dir / "episodic_nncfcfg")
            os.makedirs(self._init_args.config['episodic_nncfcfg'],
                        exist_ok=True)

            try:
                from torch.utils.tensorboard import SummaryWriter
                self.tb_writer = SummaryWriter(self.dump_dir)
                # log compression config to tensorboard
                self.tb_writer.add_text(
                    'AutoQ/run_config',
                    json.dumps(self._init_args.config['compression'],
                               indent=4,
                               sort_keys=False).replace("\n", "\n\n"), 0)
            except ModuleNotFoundError:
                logger.warning(
                    "Tensorboard installation not found! Install the tensorboard Python package "
                    "so that AutoQ tensorboard statistics can be dumped")

        start_ts = datetime.now()

        from nncf.torch.automl.environment.quantization_env import QuantizationEnvParams
        env_params = QuantizationEnvParams(
            compression_ratio=self._params.compression_ratio,
            eval_subset_ratio=self._params.eval_subset_ratio,
            skip_constraint=self._params.skip_constraint,
            performant_bw=True,
            finetune=self._params.finetune,
            bits=self._params.bits,
            dump_init_precision_data=self._dump_autoq_data,
            log_dir=Path(DEBUG_LOG_DIR) / Path("autoq"))

        # Instantiate Quantization Environment
        env = QuantizationEnv(self._model,
                              self.quantization_controller,
                              self._hw_precision_constraints,
                              self._init_args.data_loader,
                              self._init_args.eval_fn,
                              hw_config_type=self._hw_cfg_type,
                              params=env_params)

        nb_state = len(env.state_list)
        nb_action = 1

        # Control buffer length at run manager level
        if "warmup_iter_number" not in self._ddpg_hparams_override:
            self._ddpg_hparams_override["warmup_iter_number"] = 10

        self._ddpg_hparams_override["rmsize"] = \
            self._ddpg_hparams_override["warmup_iter_number"] * (len(env.master_df)+1)

        # Instantiate Automation Agent
        agent = DDPG(nb_state,
                     nb_action,
                     self._iter_number,
                     hparam_override=self._ddpg_hparams_override)

        if self._dump_autoq_data and self.tb_writer is not None:
            # Need to replace '|' in nodestr (QuantizerId/QuantizerPointId)
            # to '+' as it is a special character in markdown
            temp_df = deepcopy(env.master_df[env.state_list + ['n_op']])
            temp_df["modified_nodestr"] = list(
                map(lambda x: x.replace("|", "+"), temp_df.index.tolist()))
            temp_df = temp_df.set_index("modified_nodestr").reset_index()
            self.tb_writer.add_text('AutoQ/state_embedding',
                                    temp_df.to_markdown())

        best_policy, best_reward = self._search(agent, env)

        end_ts = datetime.now()

        final_qid_vs_qconfig_map = env.select_config_for_actions(best_policy)

        final_quantizer_setup = \
            self.quantization_controller.get_quantizer_setup_for_current_state()
        for qp_id, qconf in final_qid_vs_qconfig_map.items():
            final_quantizer_setup.quantization_points[qp_id].qconfig = qconf

        str_bw = [
            str(element)
            for element in self.get_bitwidth_per_scope(final_quantizer_setup)
        ]
        logger.info('\n'.join(
            ['[AutoQ]\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))
        logger.info('[AutoQ] best_reward: {}'.format(best_reward))
        logger.info('[AutoQ] best_policy: {}'.format(best_policy))
        logger.info("[AutoQ] Search Complete")
        logger.info(
            "[AutoQ] Elapsed time of AutoQ Precision Initialization (): {}".
            format(end_ts - start_ts))
        return final_quantizer_setup
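The replay-buffer sizing above ties agent memory to the warmup length. A hedged reading (the per-episode transition count is inferred from len(env.master_df), not stated in the code):

ddpg_hparams_override = {
    # Pure-exploration episodes before policy updates begin.
    'warmup_iter_number': 20,
    # 'rmsize' is then derived as warmup_iter_number * (len(env.master_df) + 1),
    # i.e. the buffer holds all warmup transitions: presumably one per
    # quantizable layer per episode, plus one spare slot.
}
agent = DDPG(nb_state, nb_action, iter_number, hparam_override=ddpg_hparams_override)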
Example #10
    def wrapped(*args, **kwargs):
        ctx = get_current_context()
        if not ctx or getattr(ctx, 'in_operator', False) or not ctx.is_tracing:
            return operator(*args, **kwargs)

        ctx.in_operator = True

        try:
            if operator_info.skip_trace:
                result = operator(*args, **kwargs)
            elif ctx.is_forwarding:
                from nncf.torch.dynamic_graph.trace_functions import forward_trace_only
                result = forward_trace_only(operator, *args, **kwargs)
            else:
                node = None
                op_name = operator_info.name
                op_address = ctx.get_caller_context(op_name)

                layer_attrs = None
                ignored_algos = []
                # Collect module attributes, if required
                if ctx.trace_dynamic_graph:
                    if op_name in OP_NAMES_REQUIRING_MODULE_ATTRS:
                        curr_module = ctx.get_current_module()
                        if curr_module is None:
                            raise RuntimeError("Operation {} requires module attributes, "
                                               "but it was executed outside any module".format(op_name))
                        layer_attrs = _get_layer_attributes(curr_module, op_name)
                        if isinstance(curr_module, _NNCFModuleMixin):
                            ignored_algos = deepcopy(curr_module.ignored_algorithms)

                ctx.register_operator_call(op_address.operator_name, op_address.scope_in_model)
                op_input = OperatorInput(list(args), kwargs)
                processed_input = ctx.execute_pre_hooks(op_address, op_input)

                if ctx.trace_dynamic_graph:
                    tensor_metas = make_tensor_metas(processed_input)
                    node = ctx.find_operator_node(tensor_metas, op_address)

                args = tuple(processed_input.op_args)
                kwargs = processed_input.op_kwargs
                result = operator(*args, **kwargs)

                if result is NotImplemented:
                    nncf_logger.debug("Operation {} returned NotImplemented".format(op_name))
                elif ctx.trace_dynamic_graph and node is None:
                    node = ctx.maybe_add_node(processed_input, tensor_metas, op_address, layer_attrs, ignored_algos)

                if is_debug() and ctx.trace_dynamic_graph and node is not None:
                    ctx.register_node_call(node)
                result = trace_tensors(result, node)
                result = ctx.execute_post_hooks(op_address, result)
        except:
            # The __repr__ call an IDE debugger makes to display tensor contents can
            # throw instead of exiting normally. This try...except handles that case;
            # otherwise the context would be stuck in the "in_operator == True" state.
            ctx.in_operator = False
            raise

        ctx.in_operator = False
        return result
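The except-and-re-raise above exists solely to reset ctx.in_operator; a behavior-equivalent try/finally sketch (an alternative shape, not the code as shipped):

ctx.in_operator = True
try:
    result = ...  # the same dispatch logic as in the body above
finally:
    # Runs on success and on exceptions alike (including a throwing __repr__
    # during IDE debugging), so in_operator can never be left stuck at True.
    ctx.in_operator = False
return result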
Example #11
    def __init__(self,
                 module,
                 input_infos: List[ModelInputInfo],
                 dummy_forward_fn=None,
                 wrap_inputs_fn=None,
                 scopes_without_shape_matching=None,
                 ignored_scopes=None,
                 target_scopes=None,
                 reset: bool = False,
                 wrap_outputs_fn=None,
                 original_model_accuracy=None):
        super().__init__()
        self._set_nncf_wrapped_model(module)
        self._forward_signature = inspect.signature(module.forward)
        self.input_infos = input_infos

        self._original_model_accuracy = original_model_accuracy

        self.ignored_scopes = ignored_scopes
        self.target_scopes = target_scopes
        self._user_dummy_forward_fn = dummy_forward_fn
        self._kd_loss_handler = None

        try:
            device = next(module.parameters()).device
        except StopIteration:
            # Param-less model, assume CPU
            device = 'cpu'

        if wrap_inputs_fn is not None:
            self._wrap_inputs_fn = wrap_inputs_fn
        else:
            self.__input_infos_based_input_wrapper = InputInfoWrapManager(
                self.input_infos,
                self._forward_signature,
                module_ref_for_device=self)
            self._wrap_inputs_fn = self.__input_infos_based_input_wrapper.wrap_inputs

        if wrap_outputs_fn is not None:
            self._wrap_outputs_fn = wrap_outputs_fn
        else:
            self._wrap_outputs_fn = wrap_nncf_model_outputs_with_objwalk

        self._nncf_module_scopes = []  # type: List[Scope]
        self.scopes_without_shape_matching = scopes_without_shape_matching
        self.debug_interface = CombinedDebugInterface() if is_debug() else None
        self._extra_module_types = []  # type: List[ExtraCompressionModuleType]
        # pylint:disable=line-too-long
        self._insertions_into_original_graph = {
        }  # type: Dict[PTTargetPoint, List[Tuple[Callable, TransformationPriority]]]

        _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=True, with_output_tracing=True)

        nncf_wrapped_model = self.get_nncf_wrapped_model()
        eval_only_op_scopes = self._collect_eval_only_op_scopes(
            nncf_wrapped_model, _orig_graph_build_forward_fn)

        # all modules called in eval mode should be replaced prior to graph building
        self._replace_modules_by_nncf_modules(device, eval_only_op_scopes,
                                              reset)

        _orig_context = TracingContext()

        _orig_context.add_node_comparators([MODEL_INPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        _orig_context.add_node_comparators([MODEL_OUTPUT_OP_NAME],
                                           ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            _orig_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        self._original_dynamic_graph = GraphTracer(
            _orig_graph_build_forward_fn).trace_graph(nncf_wrapped_model,
                                                      _orig_context,
                                                      as_eval=True)
        self._original_graph = GraphConverter.convert(
            self._original_dynamic_graph, input_infos=self.input_infos)
        self._compressed_graph = None  # type: PTNNCFGraph

        self._compressed_context = TracingContext()

        self._dummy_forward_fn = self._get_dummy_forward_fn_for_graph_building(
            with_input_tracing=False, with_output_tracing=False)
        self._in_user_dummy_forward = False

        self._compressed_context.add_node_comparators(
            [MODEL_INPUT_OP_NAME], ShapeIgnoringTensorMetaComparator())
        self._compressed_context.add_node_comparators(
            [MODEL_OUTPUT_OP_NAME], ShapeIgnoringTensorMetaComparator())
        if self.scopes_without_shape_matching:
            self._compressed_context.add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())
        self._load_listener = None
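A hedged construction sketch (the ModelInputInfo signature is assumed from its name and usage here; exact arguments may differ):

import torchvision

# Hypothetical: wrap a vanilla torch.nn.Module for graph tracing and compression.
model = torchvision.models.resnet18()
nncf_net = NNCFNetwork(model, input_infos=[ModelInputInfo([1, 3, 224, 224])])
# __init__ above traces the original dynamic graph eagerly, so a dummy forward
# over input_infos must already succeed when the wrapper is constructed.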