def change_context_window(self, context_window: int):
    """
    Update the context window of the SqueezeExcitation module, in-place if possible.

    Will update the pooling layer to either nn.AdaptiveAvgPool1d() (for global SE) or nn.AvgPool1d()
    (for limited context SE), or their pytorch-quantization counterparts when the module is quantized.
    If only the context window is changing but still a limited SE context block - then the earlier
    instance of the average pooling layer will be updated in-place.

    Args:
        context_window: An integer representing the number of input timeframes that will be used
            to compute the context. Each timeframe corresponds to a single window stride of the
            STFT features. A value <= 0 selects global context (pooling over the whole time axis).

            Say the window_stride = 0.01s, then a context window of 128 represents 128 * 0.01 s
            of context to compute the Squeeze step.

    Raises:
        ImportError: If the module is quantized but pytorch-quantization is not installed.
    """
    quantize = getattr(self, '_quantize', None)
    if quantize is None:
        # The constructor may not have recorded `_quantize`; infer it from the pooling layer it built.
        # `quant_nn` is only touched when pytorch-quantization is available (short-circuit `and`).
        quantize = PYTORCH_QUANTIZATION_AVAILABLE and isinstance(
            getattr(self, 'pool', None), (quant_nn.QuantAdaptiveAvgPool1d, quant_nn.QuantAvgPool1d)
        )

    # Validate before mutating any state so the module is left consistent if this raises.
    if quantize and not PYTORCH_QUANTIZATION_AVAILABLE:
        raise ImportError(
            "pytorch-quantization is not installed. Install from "
            "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
        )

    if hasattr(self, 'context_window'):
        logging.info(f"Changing Squeeze-Excitation context window from {self.context_window} to {context_window}")

    self.context_window = int(context_window)

    if self.context_window <= 0:
        # Global context: a single adaptive-average output per channel (context window = T).
        # isinstance() needs the class itself, not an instance of it.
        pool_cls = quant_nn.QuantAdaptiveAvgPool1d if quantize else nn.AdaptiveAvgPool1d
        if not isinstance(self.pool, pool_cls):
            self.pool = pool_cls(1)
    else:
        # Limited context: stride-1 average pooling over `context_window` timesteps.
        pool_cls = quant_nn.QuantAvgPool1d if quantize else nn.AvgPool1d
        if isinstance(self.pool, pool_cls):
            # Reuse the existing layer and only update its window, for both quantized and plain pools.
            self.pool.kernel_size = _single(self.context_window)
        else:
            self.pool = pool_cls(self.context_window, stride=1)
def __init__(
    self,
    channels: int,
    reduction_ratio: int,
    context_window: int = -1,
    interpolation_mode: str = 'nearest',
    activation: Optional[Callable] = None,
    quantize: bool = False,
):
    """
    Squeeze-and-Excitation sub-module.

    Args:
        channels: Input number of channels.
        reduction_ratio: Reduction ratio for "squeeze" layer.
        context_window: Integer number of timesteps that the context
            should be computed over, using stride 1 average pooling.
            If value < 1, then global context is computed.
        interpolation_mode: Interpolation mode of timestep dimension.
            Used only if context window is > 1.
            The modes available for resizing are: `nearest`, `linear` (3D-only),
            `bilinear`, `area`
        activation: Intermediate activation function used. Must be a
            callable activation function. Defaults to nn.ReLU(inplace=True).
        quantize: If True, the pooling and linear layers are built from
            pytorch-quantization's `quant_nn` modules instead of plain `torch.nn` ones.

    Raises:
        ImportError: If `quantize` is True but pytorch-quantization is not installed.
    """
    super(SqueezeExcite, self).__init__()
    self.context_window = int(context_window)
    self.interpolation_mode = interpolation_mode

    if quantize and not PYTORCH_QUANTIZATION_AVAILABLE:
        raise ImportError(
            "pytorch-quantization is not installed. Install from "
            "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
        )
    # change_context_window() reads this to rebuild the pool with the matching module family.
    self._quantize = quantize

    # From here on, `quantize` being truthy implies pytorch-quantization is available.
    if self.context_window <= 0:
        pool_cls = quant_nn.QuantAdaptiveAvgPool1d if quantize else nn.AdaptiveAvgPool1d
        self.pool = pool_cls(1)  # context window = T
    else:
        pool_cls = quant_nn.QuantAvgPool1d if quantize else nn.AvgPool1d
        self.pool = pool_cls(self.context_window, stride=1)

    if activation is None:
        activation = nn.ReLU(inplace=True)

    # Bottleneck MLP: channels -> channels // reduction_ratio -> channels, without biases.
    linear_cls = quant_nn.QuantLinear if quantize else nn.Linear
    self.fc = nn.Sequential(
        linear_cls(channels, channels // reduction_ratio, bias=False),
        activation,
        linear_cls(channels // reduction_ratio, channels, bias=False),
    )