Esempio n. 1
0
    def change_context_window(self, context_window: int):
        """
        Update the context window of the SqueezeExcitation module, in-place if possible.

        Will update the pooling layer to either an adaptive average pool (for global SE) or a
        stride 1 average pool (for limited context SE). When quantization is enabled the
        `quant_nn` counterparts of these layers are used instead of the `nn` ones.

        If only the context window is changing but it is still a limited SE context block, then
        the existing pooling layer is kept and only its kernel size is updated.

        Args:
            context_window: An integer representing the number of input timeframes that will be used
                to compute the context. Each timeframe corresponds to a single window stride of the
                STFT features. Values <= 0 select global context.

                Say the window_stride = 0.01s, then a context window of 128 represents 128 * 0.01 s
                of context to compute the Squeeze step.

        Raises:
            ImportError: If quantization was requested for this module but pytorch-quantization
                is not installed.
        """
        if hasattr(self, 'context_window'):
            logging.info(
                f"Changing Squeeze-Excitation context window from {self.context_window} to {context_window}"
            )

        self.context_window = int(context_window)

        # Pick the layer classes once so that both the global and the limited
        # branch share a single code path (and a single availability check).
        if self._quantize:
            if not PYTORCH_QUANTIZATION_AVAILABLE:
                raise ImportError(
                    "pytorch-quantization is not installed. Install from "
                    "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
                )
            global_pool_cls = quant_nn.QuantAdaptiveAvgPool1d
            limited_pool_cls = quant_nn.QuantAvgPool1d
        else:
            global_pool_cls = nn.AdaptiveAvgPool1d
            limited_pool_cls = nn.AvgPool1d

        # The pool may not exist yet (e.g. first call on a fresh module).
        current_pool = getattr(self, 'pool', None)

        if self.context_window <= 0:
            # isinstance() needs the class itself, not an instance of it.
            if not isinstance(current_pool, global_pool_cls):
                self.pool = global_pool_cls(1)  # context window = T
        elif isinstance(current_pool, limited_pool_cls):
            # Same kind of layer: update the context window in place. This applies to the
            # quantized pool as well, otherwise the new window would be silently ignored.
            current_pool.kernel_size = _single(self.context_window)
        else:
            self.pool = limited_pool_cls(self.context_window, stride=1)
Esempio n. 2
0
    def __init__(
        self,
        channels: int,
        reduction_ratio: int,
        context_window: int = -1,
        interpolation_mode: str = 'nearest',
        activation: Optional[Callable] = None,
        quantize: bool = False,
    ):
        """
        Squeeze-and-Excitation sub-module.

        Args:
            channels: Input number of channels.
            reduction_ratio: Reduction ratio for "squeeze" layer. The hidden size of the
                excitation bottleneck is `channels // reduction_ratio`.
            context_window: Integer number of timesteps that the context
                should be computed over, using stride 1 average pooling.
                If value < 1, then global context is computed.
            interpolation_mode: Interpolation mode of timestep dimension.
                Stored on the module as `interpolation_mode`; documented as used only
                for limited context windows (it is not consumed in this constructor).
                The modes available for resizing are: `nearest`, `linear` (3D-only),
                `bilinear`, `area`
            activation: Intermediate activation function used. Must be a
                callable activation function. Defaults to `nn.ReLU(inplace=True)`.
            quantize: If True, build the pooling and linear layers from `quant_nn`
                (pytorch-quantization) instead of `nn`.

        Raises:
            ImportError: If `quantize` is set but pytorch-quantization is not installed.
        """
        super(SqueezeExcite, self).__init__()

        # Fail once, up front, instead of repeating the check before every layer.
        if quantize and not PYTORCH_QUANTIZATION_AVAILABLE:
            raise ImportError(
                "pytorch-quantization is not installed. Install from "
                "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
            )

        self.context_window = int(context_window)
        self.interpolation_mode = interpolation_mode
        # Kept so that change_context_window() can rebuild the matching pool type later.
        self._quantize = quantize

        if quantize:
            global_pool_cls = quant_nn.QuantAdaptiveAvgPool1d
            limited_pool_cls = quant_nn.QuantAvgPool1d
            linear_cls = quant_nn.QuantLinear
        else:
            global_pool_cls = nn.AdaptiveAvgPool1d
            limited_pool_cls = nn.AvgPool1d
            linear_cls = nn.Linear

        if self.context_window <= 0:
            self.pool = global_pool_cls(1)  # context window = T
        else:
            self.pool = limited_pool_cls(self.context_window, stride=1)

        if activation is None:
            activation = nn.ReLU(inplace=True)

        # Bottleneck: channels -> channels // reduction_ratio -> channels, without biases.
        squeezed_channels = channels // reduction_ratio
        self.fc = nn.Sequential(
            linear_cls(channels, squeezed_channels, bias=False),
            activation,
            linear_cls(squeezed_channels, channels, bias=False),
        )