Example #1
    def export_samples(
        self,
        sample_batches: List[Any],
        sample_labels: List[Any] = None,
        exp_counter: int = 0,
    ):
        """
        Export a list of sample batches as inputs and outputs run through the model.

        :param sample_batches: a list of the sample batches to feed through the module
                               for saving inputs and outputs
        :param sample_labels: an optional list of sample labels that correspond to
            the batches for saving
        :param exp_counter: the counter to start exporting the tensor files at
        """
        sample_batches = [
            tensors_to_device(batch, "cpu") for batch in sample_batches
        ]
        inputs_dir = os.path.join(self._output_dir, "_sample-inputs")
        outputs_dir = os.path.join(self._output_dir, "_sample-outputs")
        labels_dir = os.path.join(self._output_dir, "_sample-labels")

        with torch.no_grad():
            for batch, lab in zip(
                sample_batches,
                sample_labels if sample_labels else [None for _ in sample_batches],
            ):
                out = tensors_module_forward(batch, self._module)

                exported_input = tensors_export(
                    batch,
                    inputs_dir,
                    name_prefix="inp",
                    counter=exp_counter,
                    break_batch=True,
                )
                if isinstance(out, dict):
                    # flatten dict outputs into a list of tensors for export
                    out = list(out.values())
                exported_output = tensors_export(
                    out,
                    outputs_dir,
                    name_prefix="out",
                    counter=exp_counter,
                    break_batch=True,
                )

                if lab is not None:
                    tensors_export(
                        lab,
                        labels_dir,
                        name_prefix="lab",
                        counter=exp_counter,
                        break_batch=True,
                    )

                assert len(exported_input) == len(exported_output)
                exp_counter += len(exported_input)
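
A minimal usage sketch for export_samples, assuming the method lives on an exporter-style class constructed with a module and an output directory (the ModuleExporter name and constructor shown here are assumptions, not something the snippet above confirms):

    import torch

    # hypothetical usage; ModuleExporter and its constructor are assumptions
    model = torch.nn.Sequential(
        torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10)
    )
    exporter = ModuleExporter(model, output_dir="exports")
    batches = [torch.randn(8, 3, 32, 32) for _ in range(2)]
    labels = [torch.randint(0, 10, (8,)) for _ in range(2)]
    exporter.export_samples(batches, sample_labels=labels)
    # expected to write tensor files under exports/_sample-inputs,
    # exports/_sample-outputs, and exports/_sample-labels
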
Example #2
    def export_onnx(
        self,
        sample_batch: Any,
        name: str = "model.onnx",
        opset: int = DEFAULT_ONNX_OPSET,
        disable_bn_fusing: bool = True,
        convert_qat: bool = False,
        **export_kwargs,
    ):
        """
        Export an onnx file for the current module using a sample batch.
        The sample batch is fed through the model to freeze the graph for a
        particular execution.

        :param sample_batch: the batch to export an onnx for, handles creating the
            static graph for onnx as well as setting dimensions
        :param name: name of the onnx file to save
        :param opset: onnx opset to use for the exported model. Default is 11; if the
            torch version is 1.2 or below, the default is 9
        :param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
            fusing during torch export. The default and suggested setting is True.
            Batch norm fusing will change the exported parameter names as well as
            affect sensitivity analyses of the exported graph. Additionally, the
            DeepSparse inference engine and other engines perform batch norm fusing
            at model compilation.
        :param convert_qat: if True and quantization aware training is detected in
            the module being exported, the resulting QAT ONNX model will be converted
            to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
            is False.
        :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
            call. Useful to pass in dynamic_axes, input_names, output_names, etc.
            See more on the torch.onnx.export api spec in the PyTorch docs:
            https://pytorch.org/docs/stable/onnx.html
        """
        if not export_kwargs:
            export_kwargs = {}
        if "output_names" not in export_kwargs:
            sample_batch = tensors_to_device(sample_batch, "cpu")
            module = deepcopy(self._module).cpu()
            module.eval()
            with torch.no_grad():
                out = tensors_module_forward(sample_batch,
                                             module,
                                             check_feat_lab_inp=False)
                export_kwargs["output_names"] = self.get_output_names(out)
        export_onnx(
            module=self._module,
            sample_batch=sample_batch,
            file_path=os.path.join(self._output_dir, name),
            opset=opset,
            disable_bn_fusing=disable_bn_fusing,
            convert_qat=convert_qat,
            **export_kwargs,
        )
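
A hedged usage sketch for this method; ModuleExporter is again an assumed wrapper name, and the extra keyword arguments ride along in export_kwargs to torch.onnx.export as described in the docstring:

    # model: any torch.nn.Module; ModuleExporter is an assumed wrapper class
    exporter = ModuleExporter(model, output_dir="exports")
    exporter.export_onnx(
        sample_batch=torch.randn(1, 3, 224, 224),
        name="model.onnx",
        input_names=["input"],                   # forwarded via export_kwargs
        dynamic_axes={"input": {0: "batch"}},    # make the batch dimension dynamic
    )
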
Example #3
    def forward(self, data: Any, pred: Any) -> Dict[str, Tensor]:
        """
        Override to calculate the knowledge distillation loss if kd_settings
        is supplied and not None.

        :param data: the input data to the model, expected to contain the labels
        :param pred: the predicted output from the model
        :return: a dictionary containing all calculated losses and metrics with
            the loss from the loss_fn at DEFAULT_LOSS_KEY
        """
        losses = super().forward(data, pred)

        if self._kd_settings is not None:
            with torch.no_grad():
                teacher = self._kd_settings.teacher  # type: Module
                preds_teacher = tensors_module_forward(
                    self.get_inputs(data, pred, TEACHER_LOSS_KEY),
                    teacher.eval())

            preds_teacher = self.get_preds(data, preds_teacher,
                                           TEACHER_LOSS_KEY)

            soft_log_probs = TF.log_softmax(
                self.get_preds(data, pred, DEFAULT_LOSS_KEY) /
                self._kd_settings.temp_student,
                dim=1,
            )
            soft_targets = TF.softmax(preds_teacher /
                                      self._kd_settings.temp_teacher,
                                      dim=1)
            distill_loss = (
                TF.kl_div(soft_log_probs, soft_targets, size_average=False) /
                soft_targets.shape[0])

            if not self._kd_settings.contradict_hinton:
                # Hinton's original paper includes T^2 as a scaling factor;
                # some implementations dropped this factor, so "contradicting
                # Hinton" means not scaling by T^2
                distill_loss = (
                    (self._kd_settings.temp_student +
                     self._kd_settings.temp_teacher) / 2)**2 * distill_loss

            losses[DEFAULT_LOSS_KEY] = (
                self._kd_settings.weight * distill_loss +
                (1 - self._kd_settings.weight) * losses[DEFAULT_LOSS_KEY])

        return losses
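
The distillation term above is the standard soft-target loss from Hinton et al.; a standalone sketch of the same computation (illustrative only, without the class state or the contradict_hinton branch) may make the temperature scaling easier to follow:

    import torch.nn.functional as TF

    def distillation_loss_sketch(student_logits, teacher_logits,
                                 temp_student=2.0, temp_teacher=2.0):
        # soften both distributions with their temperatures
        soft_log_probs = TF.log_softmax(student_logits / temp_student, dim=1)
        soft_targets = TF.softmax(teacher_logits / temp_teacher, dim=1)
        # sum-reduced KL divergence averaged over the batch;
        # reduction="sum" is the non-deprecated equivalent of size_average=False
        loss = TF.kl_div(soft_log_probs, soft_targets, reduction="sum")
        loss = loss / soft_targets.shape[0]
        # Hinton-style T^2 scaling, using the averaged temperature as above
        return ((temp_student + temp_teacher) / 2) ** 2 * loss
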
Example #4
def export_onnx(
    module: Module,
    sample_batch: Any,
    file_path: str,
    opset: int = DEFAULT_ONNX_OPSET,
    disable_bn_fusing: bool = True,
    convert_qat: bool = False,
    dynamic_axes: Union[str, Dict[str, List[int]]] = None,
    skip_input_quantize: bool = False,
    **export_kwargs,
):
    """
    Export an onnx file for the current module using a sample batch.
    The sample batch is fed through the model to freeze the graph for a
    particular execution.

    :param module: torch Module object to export
    :param sample_batch: the batch to export an onnx for, handles creating the
        static graph for onnx as well as setting dimensions
    :param file_path: path to the onnx file to save
    :param opset: onnx opset to use for the exported model. Default is 11; if the
        torch version is 1.2 or below, the default is 9
    :param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
        fusing during torch export. The default and suggested setting is True.
        Batch norm fusing will change the exported parameter names as well as
        affect sensitivity analyses of the exported graph. Additionally, the
        DeepSparse inference engine and other engines perform batch norm fusing
        at model compilation.
    :param convert_qat: if True and quantization aware training is detected in
        the module being exported, the resulting QAT ONNX model will be converted
        to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
        is False.
    :param dynamic_axes: dictionary of input or output names to a list of dimensions
        of those tensors that should be exported as dynamic. Pass the string 'batch'
        to set the first dimension of all inputs and outputs to dynamic. Default
        is None
    :param skip_input_quantize: if True, the export flow will attempt to delete
        the first QuantizeLinear node(s) immediately after the model input and set
        the model input type to UINT8. Default is False
    :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
        call. Useful to pass in dynamic_axes, input_names, output_names, etc.
        See more on the torch.onnx.export api spec in the PyTorch docs:
        https://pytorch.org/docs/stable/onnx.html
    """
    if not export_kwargs:
        export_kwargs = {}

    if isinstance(sample_batch, Dict) and not isinstance(
            sample_batch, collections.OrderedDict):
        warnings.warn(
            "Sample inputs passed into the ONNX exporter should be in "
            "the same order defined in the model forward function. "
            "Consider using OrderedDict for this purpose.",
            UserWarning,
        )

    sample_batch = tensors_to_device(sample_batch, "cpu")
    create_parent_dirs(file_path)

    module = deepcopy(module).cpu()
    module.eval()

    with torch.no_grad():
        out = tensors_module_forward(sample_batch,
                                     module,
                                     check_feat_lab_inp=False)

    if "input_names" not in export_kwargs:
        if isinstance(sample_batch, Tensor):
            export_kwargs["input_names"] = ["input"]
        elif isinstance(sample_batch, Dict):
            export_kwargs["input_names"] = list(sample_batch.keys())
            sample_batch = tuple(
                [sample_batch[f] for f in export_kwargs["input_names"]])
        elif isinstance(sample_batch, Iterable):
            export_kwargs["input_names"] = [
                "input_{}".format(index)
                for index, _ in enumerate(iter(sample_batch))
            ]
            if isinstance(sample_batch, List):
                # torch.onnx.export requires a tuple rather than a list
                sample_batch = tuple(sample_batch)

    if "output_names" not in export_kwargs:
        export_kwargs["output_names"] = _get_output_names(out)

    if dynamic_axes == "batch":
        dynamic_axes = {
            tensor_name: {0: "batch"}
            for tensor_name in (
                export_kwargs["input_names"] + export_kwargs["output_names"]
            )
        }

    # disable active quantization observers because they cannot be exported
    disabled_observers = []
    for submodule in module.modules():
        if (hasattr(submodule, "observer_enabled")
                and submodule.observer_enabled[0] == 1):
            submodule.observer_enabled[0] = 0
            disabled_observers.append(submodule)

    is_quant_module = any(
        hasattr(submodule, "qconfig") and submodule.qconfig
        for submodule in module.modules())
    batch_norms_wrapped = False
    if torch.__version__ >= "1.7" and not is_quant_module and disable_bn_fusing:
        # prevent batch norm fusing by adding a trivial operation before every
        # batch norm layer
        batch_norms_wrapped = _wrap_batch_norms(module)

    torch.onnx.export(
        module,
        sample_batch,
        file_path,
        strip_doc_string=True,
        verbose=False,
        opset_version=opset,
        dynamic_axes=dynamic_axes,
        **export_kwargs,
    )

    # re-enable disabled quantization observers
    for submodule in disabled_observers:
        submodule.observer_enabled[0] = 1

    # onnx file fixes
    onnx_model = onnx.load(file_path)
    # fix changed batch norm names
    _fix_batch_norm_names(onnx_model)
    if batch_norms_wrapped:
        # clean up graph from any injected / wrapped operations
        _delete_trivial_onnx_adds(onnx_model)
    onnx.save(onnx_model, file_path)

    if convert_qat and is_quant_module:
        # overwrite exported model with fully quantized version
        quantize_torch_qat_export(model=file_path, output_file_path=file_path)

    if skip_input_quantize:
        try:
            skip_onnx_input_quantize(file_path, file_path)
        except Exception as e:
            _LOGGER.warning(
                f"Unable to skip input QuantizeLinear op with exception {e}")
Example #5
    def export_onnx(
        self,
        sample_batch: Any,
        name: str = "model.onnx",
        opset: int = DEFAULT_ONNX_OPSET,
        disable_bn_fusing: bool = True,
    ):
        """
        Export an onnx file for the current module using a sample batch.
        The sample batch is fed through the model to freeze the graph for a
        particular execution.

        :param sample_batch: the batch to export an onnx for, handles creating the
            static graph for onnx as well as setting dimensions
        :param name: name of the onnx file to save
        :param opset: onnx opset to use for the exported model. Default is 11; if the
            torch version is 1.2 or below, the default is 9
        :param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
            fusing during torch export. The default and suggested setting is True.
            Batch norm fusing will change the exported parameter names as well as
            affect sensitivity analyses of the exported graph. Additionally, the
            DeepSparse inference engine and other engines perform batch norm fusing
            at model compilation.
        """
        sample_batch = tensors_to_device(sample_batch, "cpu")
        onnx_path = os.path.join(self._output_dir, name)
        create_parent_dirs(onnx_path)

        with torch.no_grad():
            out = tensors_module_forward(sample_batch, self._module)

        input_names = None
        if isinstance(sample_batch, Tensor):
            input_names = ["input"]
        elif isinstance(sample_batch, Iterable):
            input_names = [
                "input_{}".format(index)
                for index, _ in enumerate(iter(sample_batch))
            ]

        output_names = None
        if isinstance(out, Tensor):
            output_names = ["output"]
        elif isinstance(out, Iterable):
            output_names = [
                "output_{}".format(index) for index, _ in enumerate(iter(out))
            ]

        # disable active quantization observers because they cannot be exported
        disabled_observers = []
        for submodule in self._module.modules():
            if (hasattr(submodule, "observer_enabled")
                    and submodule.observer_enabled[0] == 1):
                submodule.observer_enabled[0] = 0
                disabled_observers.append(submodule)

        is_quant_module = any(
            hasattr(submodule, "qconfig") and submodule.qconfig
            for submodule in self._module.modules())
        batch_norms_wrapped = False
        if torch.__version__ >= "1.7" and not is_quant_module and disable_bn_fusing:
            # prevent batch norm fusing by adding a trivial operation before every
            # batch norm layer
            export_module = deepcopy(self._module)
            batch_norms_wrapped = _wrap_batch_norms(export_module)
        else:
            export_module = self._module

        torch.onnx.export(
            export_module,
            sample_batch,
            onnx_path,
            input_names=input_names,
            output_names=output_names,
            strip_doc_string=True,
            verbose=False,
            opset_version=opset,
        )

        # re-enable disabled quantization observers
        for submodule in disabled_observers:
            submodule.observer_enabled[0] = 1

        # clean up graph from any injected / wrapped operations
        if batch_norms_wrapped:
            onnx_model = onnx.load(onnx_path)
            _delete_trivial_onnx_adds(onnx_model)
            onnx.save(onnx_model, onnx_path)
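
The _wrap_batch_norms helper used in these exporters is not shown in the snippets; a minimal sketch of the underlying idea, assumed rather than copied from the library, is to insert a trivial Add in front of each batch norm so the ONNX tracer cannot fuse it into the preceding convolution:

    import torch

    class _TrivialAddBeforeBN(torch.nn.Module):
        # assumed sketch, not the library implementation: adding zero before the
        # wrapped batch norm leaves a trivial Add node in the traced graph, which
        # blocks conv + batch norm fusing during export
        def __init__(self, bn: torch.nn.Module):
            super().__init__()
            self.bn = bn

        def forward(self, x):
            return self.bn(x + 0.0)

    def wrap_batch_norms_sketch(module: torch.nn.Module) -> bool:
        # returns True if any batch norm was wrapped, mirroring how the
        # batch_norms_wrapped flag is used above
        wrapped = False
        for name, child in list(module.named_children()):
            if isinstance(child, torch.nn.modules.batchnorm._BatchNorm):
                setattr(module, name, _TrivialAddBeforeBN(child))
                wrapped = True
            else:
                wrapped = wrap_batch_norms_sketch(child) or wrapped
        return wrapped

The _delete_trivial_onnx_adds cleanup called above would then strip these injected Add nodes back out of the saved ONNX graph.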