def stft(x):
    """Return the magnitude and phase of the STFT of ``x``.

    Dimension 1 of ``x`` is squeezed before the transform, which runs with a
    fixed configuration (``n_fft=1024``, ``hop_length=256``,
    ``win_length=1024``).

    NOTE(review): the result of ``stft_patch`` is indexed as ``[..., 0]`` /
    ``[..., 1]``, so it is presumably a real tensor whose last dimension holds
    the (real, imaginary) pair -- confirm against ``stft_patch``.

    Returns:
        tuple: ``(magnitude, phase)`` where magnitude is
        ``sqrt(real**2 + imag**2)`` and phase is ``atan2(imag, real)``.
    """
    spec = stft_patch(x.squeeze(1), n_fft=1024, hop_length=256, win_length=1024)
    re_part = spec[..., 0]
    im_part = spec[..., 1]
    magnitude = torch.sqrt(re_part**2 + im_part**2)
    return magnitude, torch.atan2(im_part, re_part)
def yet_another_patch(audio, n_fft, hop_length, win_length, window):
    """Compute the STFT of ``audio`` and return its magnitude and phase.

    All arguments are forwarded to ``stft_patch``. If the transform comes back
    as a ``torch.cfloat`` / ``torch.cdouble`` tensor it is first converted with
    ``torch.view_as_real`` so that the last dimension holds the
    (real, imaginary) pair.

    Returns:
        tuple: ``(magnitude, phase)``; magnitude is the square root of the sum
        of squares over the last dimension, phase is ``atan2`` of the last
        and the first entry of that dimension.
    """
    transformed = stft_patch(
        audio,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
    )
    if transformed.dtype == torch.cfloat or transformed.dtype == torch.cdouble:
        transformed = torch.view_as_real(transformed)

    magnitude = torch.sqrt(transformed.pow(2).sum(-1))
    phase = torch.atan2(transformed[..., -1], transformed[..., 0])
    return magnitude, phase
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window: Window argument, forwarded unchanged as the fifth positional
            argument of ``stft_patch``.
            NOTE(review): earlier documentation described this as a window
            type string; confirm what ``stft_patch`` actually expects.

    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).

    """
    spec = stft_patch(x, fft_size, hop_size, win_length, window, return_complex=False)
    power = spec[..., 0] ** 2 + spec[..., 1] ** 2

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    magnitude = torch.sqrt(torch.clamp(power, min=1e-7))
    return magnitude.transpose(2, 1)
def __init__(
    self,
    sample_rate=16000,
    n_window_size=320,
    n_window_stride=160,
    window="hann",
    normalize="per_feature",
    n_fft=None,
    preemph=0.97,
    nfilt=64,
    lowfreq=0,
    highfreq=None,
    log=True,
    log_zero_guard_type="add",
    log_zero_guard_value=2**-24,
    dither=CONSTANT,
    pad_to=16,
    max_duration=16.7,
    frame_splicing=1,
    exact_pad=False,
    stft_exact_pad=False,  # TODO: Remove this in 1.1.0
    stft_conv=False,  # TODO: Remove this in 1.1.0
    pad_value=0,
    mag_power=2.0,
    use_grads=False,
):
    """Configure the STFT, mel filterbank and padding used by the featurizer.

    Args:
        sample_rate: Sample rate used to build the mel filterbank, to derive
            the default ``highfreq`` (``sample_rate / 2``) and to convert
            ``max_duration`` into a number of samples.
        n_window_size: STFT window length in samples; must be a positive int.
        n_window_stride: STFT hop length in samples; must be a positive int.
        window: Window name. The torch STFT path understands ``'hann'``,
            ``'hamming'``, ``'blackman'``, ``'bartlett'`` and ``'none'``; any
            other name is treated like ``'none'`` (no window tensor).
        normalize, log, dither, frame_splicing, preemph, pad_value,
        mag_power, log_zero_guard_value: Stored on the instance unchanged;
            they are not otherwise used in this constructor.
        n_fft: FFT size. Defaults to the smallest power of two that is
            greater than or equal to ``n_window_size``.
        nfilt: Number of mel filters.
        lowfreq: Lower frequency bound of the mel filterbank.
        highfreq: Upper frequency bound of the mel filterbank.
        log_zero_guard_type: Either ``"add"`` or ``"clamp"``.
        pad_to: When positive, ``max_length`` is increased by
            ``pad_to - (max_length % pad_to)``.
        max_duration: Multiplied by ``sample_rate`` and passed to
            ``get_seq_len`` to derive ``max_length``.
        exact_pad: When true the torch STFT runs with ``center=False`` and
            ``stft_pad_amount`` is set to ``(n_fft - hop_length) // 2``.
            Requires an even ``n_window_stride``.
        stft_exact_pad: Deprecated; only selects ``STFTExactPad`` when
            ``stft_conv`` is also set. Requires an even ``n_window_stride``.
        stft_conv: Deprecated; use the convolution-based STFT modules instead
            of the torch STFT.
        use_grads: When false, ``forward`` is wrapped in ``torch.no_grad()``.

    Raises:
        NotImplementedError: If ``exact_pad`` or ``stft_exact_pad`` is set
            and ``n_window_stride`` is odd.
        ValueError: If the window size / stride are not positive ints, or if
            ``log_zero_guard_type`` is not ``"add"`` or ``"clamp"``.
    """
    super().__init__()
    if stft_conv or stft_exact_pad:
        logging.warning(
            "Using torch_stft is deprecated and will be removed in 1.1.0. Please set stft_conv and stft_exact_pad "
            "to False for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
            "as needed."
        )
    if (exact_pad or stft_exact_pad) and n_window_stride % 2 == 1:
        raise NotImplementedError(
            f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
            "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
        )
    self.log_zero_guard_value = log_zero_guard_value
    if (
        n_window_size is None
        or n_window_stride is None
        or not isinstance(n_window_size, int)
        or not isinstance(n_window_stride, int)
        or n_window_size <= 0
        or n_window_stride <= 0
    ):
        raise ValueError(
            f"{self} got an invalid value for either n_window_size or "
            f"n_window_stride. Both must be positive ints."
        )
    logging.info(f"PADDING: {pad_to}")

    self.win_length = n_window_size
    self.hop_length = n_window_stride
    self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
    self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
    self.stft_exact_pad = stft_exact_pad
    self.stft_conv = stft_conv

    if stft_conv:
        logging.info("STFT using conv")
        if stft_exact_pad:
            logging.info("STFT using exact pad")
            self.stft = STFTExactPad(self.n_fft, self.hop_length, self.win_length, window)
        else:
            self.stft = STFTPatch(self.n_fft, self.hop_length, self.win_length, window)
    else:
        logging.info("STFT using torch")
        if exact_pad:
            logging.info("STFT using exact pad")
        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
        self.register_buffer("window", window_tensor)
        # The buffer is None for window='none' (or an unknown name); guard the
        # dtype conversion so the STFT can still run without a window tensor
        # instead of failing with AttributeError on the first call.
        self.stft = lambda x: stft_patch(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=False if exact_pad else True,
            window=self.window.to(dtype=torch.float) if self.window is not None else None,
            return_complex=False,
        )

    self.normalize = normalize
    self.log = log
    self.dither = dither
    self.frame_splicing = frame_splicing
    self.nfilt = nfilt
    self.preemph = preemph
    self.pad_to = pad_to
    highfreq = highfreq or sample_rate / 2

    # sr / n_fft are passed by keyword: recent librosa releases make them
    # keyword-only, and older releases accept the keywords as well.
    filterbanks = torch.tensor(
        librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq),
        dtype=torch.float,
    ).unsqueeze(0)
    self.register_buffer("fb", filterbanks)

    # Calculate maximum sequence length
    max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
    max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
    self.max_length = max_length + max_pad
    self.pad_value = pad_value
    self.mag_power = mag_power

    # We want to avoid taking the log of zero
    # There are two options: either adding or clamping to a small value
    if log_zero_guard_type not in ["add", "clamp"]:
        raise ValueError(
            f"{self} received {log_zero_guard_type} for the "
            f"log_zero_guard_type parameter. It must be either 'add' or "
            f"'clamp'."
        )

    self.use_grads = use_grads
    if not use_grads:
        # Shadow the bound method on the instance with a no-grad wrapper.
        self.forward = torch.no_grad()(self.forward)

    # log_zero_guard_value is the small value we want to use, we support
    # an actual number, or "tiny", or "eps"
    self.log_zero_guard_type = log_zero_guard_type
    logging.debug(f"sr: {sample_rate}")
    logging.debug(f"n_fft: {self.n_fft}")
    logging.debug(f"win_length: {self.win_length}")
    logging.debug(f"hop_length: {self.hop_length}")
    logging.debug(f"n_mels: {nfilt}")
    logging.debug(f"fmin: {lowfreq}")
    logging.debug(f"fmax: {highfreq}")
    logging.debug(f"using grads: {use_grads}")
def __init__(
    self,
    sample_rate=16000,
    n_window_size=320,
    n_window_stride=160,
    window="hann",
    normalize="per_feature",
    n_fft=None,
    preemph=0.97,
    nfilt=64,
    lowfreq=0,
    highfreq=None,
    log=True,
    log_zero_guard_type="add",
    log_zero_guard_value=2**-24,
    dither=CONSTANT,
    pad_to=16,
    max_duration=16.7,
    frame_splicing=1,
    stft_exact_pad=False,
    stft_conv=False,
    pad_value=0,
    mag_power=2.0,
):
    """Configure the STFT, mel filterbank and padding used by the featurizer.

    Args:
        sample_rate: Sample rate used to build the mel filterbank, to derive
            the default ``highfreq`` (``sample_rate / 2``) and to convert
            ``max_duration`` into a number of samples.
        n_window_size: STFT window length in samples; must be a positive int.
        n_window_stride: STFT hop length in samples; must be a positive int.
        window: Window name. The torch STFT path understands ``'hann'``,
            ``'hamming'``, ``'blackman'``, ``'bartlett'`` and ``'none'``; any
            other name is treated like ``'none'`` (no window tensor).
        normalize, log, dither, frame_splicing, preemph, pad_value,
        mag_power, log_zero_guard_value: Stored on the instance unchanged;
            they are not otherwise used in this constructor.
        n_fft: FFT size. Defaults to the smallest power of two that is
            greater than or equal to ``n_window_size``.
        nfilt: Number of mel filters.
        lowfreq: Lower frequency bound of the mel filterbank.
        highfreq: Upper frequency bound of the mel filterbank.
        log_zero_guard_type: Either ``"add"`` or ``"clamp"``.
        pad_to: When positive, ``max_length`` is increased by
            ``pad_to - (max_length % pad_to)``.
        max_duration: Multiplied by ``sample_rate`` and passed to
            ``get_seq_len`` to derive ``max_length``.
        stft_exact_pad: With ``stft_conv`` selects ``STFTExactPad``; without
            it the torch STFT runs with ``center=False``.
        stft_conv: Use the convolution-based STFT modules instead of the
            torch STFT.

    Raises:
        ValueError: If the window size / stride are not positive ints, or if
            ``log_zero_guard_type`` is not ``"add"`` or ``"clamp"``.
    """
    super().__init__()
    self.log_zero_guard_value = log_zero_guard_value
    if (
        n_window_size is None
        or n_window_stride is None
        or not isinstance(n_window_size, int)
        or not isinstance(n_window_stride, int)
        or n_window_size <= 0
        or n_window_stride <= 0
    ):
        raise ValueError(
            f"{self} got an invalid value for either n_window_size or "
            f"n_window_stride. Both must be positive ints."
        )
    logging.info(f"PADDING: {pad_to}")

    self.win_length = n_window_size
    self.hop_length = n_window_stride
    self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
    self.stft_exact_pad = stft_exact_pad
    self.stft_conv = stft_conv

    if stft_conv:
        logging.info("STFT using conv")
        if stft_exact_pad:
            logging.info("STFT using exact pad")
            self.stft = STFTExactPad(self.n_fft, self.hop_length, self.win_length, window)
        else:
            self.stft = STFTPatch(self.n_fft, self.hop_length, self.win_length, window)
    else:
        logging.info("STFT using torch")
        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
        self.register_buffer("window", window_tensor)
        # The buffer is None for window='none' (or an unknown name); guard the
        # dtype conversion so the STFT can still run without a window tensor
        # instead of failing with AttributeError on the first call.
        self.stft = lambda x: stft_patch(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=False if stft_exact_pad else True,
            window=self.window.to(dtype=torch.float) if self.window is not None else None,
            return_complex=False,
        )

    self.normalize = normalize
    self.log = log
    self.dither = dither
    self.frame_splicing = frame_splicing
    self.nfilt = nfilt
    self.preemph = preemph
    self.pad_to = pad_to
    highfreq = highfreq or sample_rate / 2

    # sr / n_fft are passed by keyword: recent librosa releases make them
    # keyword-only, and older releases accept the keywords as well.
    filterbanks = torch.tensor(
        librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq),
        dtype=torch.float,
    ).unsqueeze(0)
    self.register_buffer("fb", filterbanks)

    # Calculate maximum sequence length
    max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
    max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
    self.max_length = max_length + max_pad
    self.pad_value = pad_value
    self.mag_power = mag_power

    # We want to avoid taking the log of zero
    # There are two options: either adding or clamping to a small value
    if log_zero_guard_type not in ["add", "clamp"]:
        raise ValueError(
            f"{self} received {log_zero_guard_type} for the "
            f"log_zero_guard_type parameter. It must be either 'add' or "
            f"'clamp'."
        )

    # log_zero_guard_value is the small value we want to use, we support
    # an actual number, or "tiny", or "eps"
    self.log_zero_guard_type = log_zero_guard_type
    logging.debug(f"sr: {sample_rate}")
    logging.debug(f"n_fft: {self.n_fft}")
    logging.debug(f"win_length: {self.win_length}")
    logging.debug(f"hop_length: {self.hop_length}")
    logging.debug(f"n_mels: {nfilt}")
    logging.debug(f"fmin: {lowfreq}")
    logging.debug(f"fmax: {highfreq}")
def __init__(
    self,
    manifest_filepath: str,
    sample_rate: int,
    text_tokenizer: Union[BaseTokenizer, Callable[[str], List[int]]],
    tokens: Optional[List[str]] = None,
    text_normalizer: Optional[Union[Normalizer, Callable[[str], str]]] = None,
    text_normalizer_call_args: Optional[Dict] = None,
    text_tokenizer_pad_id: Optional[int] = None,
    sup_data_types: Optional[List[str]] = None,
    sup_data_path: Optional[Union[Path, str]] = None,
    max_duration: Optional[float] = None,
    min_duration: Optional[float] = None,
    ignore_file: Optional[str] = None,
    trim: bool = False,
    n_fft=1024,
    win_length=None,
    hop_length=None,
    window="hann",
    n_mels=80,
    lowfreq=0,
    highfreq=None,
    **kwargs,
):
    """Dataset that loads main data types (audio and text) and specified supplementary data types (e.g. log mel,
    durations, pitch). Most supplementary data types will be computed on the fly and saved in the
    supplementary_folder if they did not exist before. Arguments for supplementary data should be also specified
    in this class and they will be used from kwargs (see keyword args section).

    Args:
        manifest_filepath (str, Path, List[str, Path]): Path(s) to the .json manifests containing information on the
            dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
            json. Each line should contain the following:
                "audio_filepath": <PATH_TO_WAV>
                "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
                "duration": <Duration of audio clip in seconds> (Optional)
                "text": <THE_TRANSCRIPT> (Optional)
        sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
        text_tokenizer (Optional[Union[BaseTokenizer, Callable[[str], List[int]]]]): BaseTokenizer or callable which
            represents text tokenizer.
        tokens (Optional[List[str]]): Tokens from text_tokenizer. Should be specified if text_tokenizer is not
            BaseTokenizer.
        text_normalizer (Optional[Union[Normalizer, Callable[[str], str]]]): Normalizer or callable which represents
            text normalizer.
        text_normalizer_call_args (Optional[Dict]): Additional arguments for text_normalizer function. None is
            treated as "no additional arguments".
        text_tokenizer_pad_id (Optional[int]): Index of padding. Should be specified if text_tokenizer is not
            BaseTokenizer.
        sup_data_types (Optional[List[str]]): List of supplementary data types.
        sup_data_path (Optional[Union[Path, str]]): A folder that contains or will contain supplementary data
            (e.g. pitch).
        max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
            pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
            audio to compute duration. Defaults to None which does not prune.
        min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
            pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
            audio to compute duration. Defaults to None which does not prune.
        ignore_file (Optional[str, Path]): The location of a pickle-saved list of audio_ids (the stem of the audio
            files) that will be pruned prior to training. Defaults to None which does not prune.
        trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
        n_fft (Optional[int]): The number of fft samples. Defaults to 1024
        win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
        hop_length (Optional[int]): The hop length between fft computations. Defaults to None which uses n_fft//4.
        window (Optional[str]): One of 'hann', 'hamming', 'blackman','bartlett', 'none'. Which corresponds to the
            equivalent torch window function.
        n_mels (Optional[int]): The number of mel filters. Defaults to 80.
        lowfreq (Optional[int]): The lowfreq input to the mel filter calculation. Defaults to 0.
        highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.

    Keyword Args:
        durs_file (Optional[str]): String path to pickled durations location.
        durs_type (Optional[str]): Type of durations. Currently supported only "aligned-based".
        pitch_fmin (Optional[float]): The fmin input to librosa.pyin. Defaults to librosa.note_to_hz('C2').
        pitch_fmax (Optional[float]): The fmax input to librosa.pyin. Defaults to librosa.note_to_hz('C7').
        pitch_avg (Optional[float]): The mean that we use to normalize the pitch.
        pitch_std (Optional[float]): The std that we use to normalize the pitch.
        pitch_norm (Optional[bool]): Whether to normalize pitch (via pitch_avg and pitch_std) or not.

    Raises:
        ValueError: If text_tokenizer is not a BaseTokenizer and text_tokenizer_pad_id or tokens is missing.
        NotImplementedError: If a requested supplementary data type is not in VALID_SUPPLEMENTARY_DATA_TYPES.
    """
    super().__init__()

    self.text_normalizer = text_normalizer
    self.text_normalizer_call = (
        self.text_normalizer.normalize if isinstance(self.text_normalizer, Normalizer) else self.text_normalizer
    )
    self.text_normalizer_call_args = text_normalizer_call_args
    # The default (None) cannot be **-expanded; fall back to an empty mapping
    # for the normalizer calls below while keeping the attribute as given.
    normalizer_call_args = text_normalizer_call_args or {}

    self.text_tokenizer = text_tokenizer
    if isinstance(self.text_tokenizer, BaseTokenizer):
        self.text_tokenizer_pad_id = text_tokenizer.pad
        self.tokens = text_tokenizer.tokens
    else:
        if text_tokenizer_pad_id is None:
            raise ValueError("text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer")

        if tokens is None:
            raise ValueError("tokens must be specified if text_tokenizer is not BaseTokenizer")

        self.text_tokenizer_pad_id = text_tokenizer_pad_id
        self.tokens = tokens

    if isinstance(manifest_filepath, str):
        manifest_filepath = [manifest_filepath]
    self.manifest_filepath = manifest_filepath

    # Assigned unconditionally so the attribute always exists, even when no
    # supplementary folder was given.
    self.sup_data_path = sup_data_path
    if sup_data_path is not None:
        Path(sup_data_path).mkdir(parents=True, exist_ok=True)

    self.sup_data_types = (
        [DATA_STR2DATA_CLASS[d_as_str] for d_as_str in sup_data_types] if sup_data_types is not None else []
    )
    self.sup_data_types_set = set(self.sup_data_types)

    self.data = []
    audio_files = []
    # Becomes None as soon as one manifest entry lacks a duration; duration
    # based logging and min/max pruning are then disabled.
    total_duration = 0
    for manifest_file in self.manifest_filepath:
        with open(Path(manifest_file).expanduser(), 'r') as f:
            logging.info(f"Loading dataset from {manifest_file}.")
            for line in tqdm(f):
                item = json.loads(line)

                file_info = {
                    "audio_filepath": item["audio_filepath"],
                    "mel_filepath": item["mel_filepath"] if "mel_filepath" in item else None,
                    "duration": item["duration"] if "duration" in item else None,
                    "text_tokens": None,
                }

                if "text" in item:
                    text = item["text"]

                    if self.text_normalizer is not None:
                        text = self.text_normalizer_call(text, **normalizer_call_args)

                    text_tokens = self.text_tokenizer(text)

                    file_info["raw_text"] = item["text"]
                    file_info["text_tokens"] = text_tokens

                audio_files.append(file_info)

                if file_info["duration"] is None:
                    logging.info(
                        "Not all audio files have duration information. Duration logging will be disabled."
                    )
                    total_duration = None

                if total_duration is not None:
                    total_duration += item["duration"]

    logging.info(f"Loaded dataset with {len(audio_files)} files.")
    if total_duration is not None:
        logging.info(f"Dataset contains {total_duration / 3600:.2f} hours.")

    if ignore_file:
        logging.info(f"using {ignore_file} to prune dataset.")
        # NOTE: pickle.load can execute arbitrary code; ignore_file must come
        # from a trusted source.
        with open(Path(ignore_file).expanduser(), "rb") as f:
            wavs_to_ignore = set(pickle.load(f))

    pruned_duration = 0 if total_duration is not None else None
    pruned_items = 0
    for item in audio_files:
        audio_path = item['audio_filepath']
        audio_id = Path(audio_path).stem

        # Prune data according to min/max_duration & the ignore file
        if total_duration is not None:
            if (min_duration and item["duration"] < min_duration) or (
                max_duration and item["duration"] > max_duration
            ):
                pruned_duration += item["duration"]
                pruned_items += 1
                continue

        if ignore_file and (audio_id in wavs_to_ignore):
            pruned_items += 1
            # Durations may be unavailable (pruned_duration is None); only
            # accumulate when duration tracking is active.
            if pruned_duration is not None:
                pruned_duration += item["duration"]
            wavs_to_ignore.remove(audio_id)
            continue

        self.data.append(item)

    logging.info(f"Pruned {pruned_items} files. Final dataset contains {len(self.data)} files")
    if pruned_duration is not None:
        logging.info(
            f"Pruned {pruned_duration / 3600:.2f} hours. Final dataset contains "
            f"{(total_duration - pruned_duration) / 3600:.2f} hours."
        )

    self.sample_rate = sample_rate
    self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate)
    self.trim = trim

    self.n_fft = n_fft
    self.n_mels = n_mels
    self.lowfreq = lowfreq
    self.highfreq = highfreq
    self.window = window
    self.win_length = win_length or self.n_fft
    self.hop_length = hop_length
    self.hop_len = self.hop_length or self.n_fft // 4

    # sr / n_fft are passed by keyword: recent librosa releases make them
    # keyword-only, and older releases accept the keywords as well.
    self.fb = torch.tensor(
        librosa.filters.mel(
            sr=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, fmin=self.lowfreq, fmax=self.highfreq
        ),
        dtype=torch.float,
    ).unsqueeze(0)

    window_fn = {
        'hann': torch.hann_window,
        'hamming': torch.hamming_window,
        'blackman': torch.blackman_window,
        'bartlett': torch.bartlett_window,
        'none': None,
    }.get(self.window, None)

    self.stft = lambda x: stft_patch(
        input=x,
        n_fft=self.n_fft,
        hop_length=self.hop_len,
        win_length=self.win_length,
        window=window_fn(self.win_length, periodic=False).to(torch.float) if window_fn else None,
    )

    # Each supplementary data type is configured by its own add_<name> method,
    # which receives the extra keyword arguments.
    for data_type in self.sup_data_types:
        if data_type not in VALID_SUPPLEMENTARY_DATA_TYPES:
            raise NotImplementedError(f"Current implementation of TTSDataset doesn't support {data_type} type.")

        getattr(self, f"add_{data_type.name}")(**kwargs)
def __init__(
    self,
    manifest_filepath: str,
    sample_rate: int,
    supplementary_folder: Path,
    max_duration: Optional[float] = None,
    min_duration: Optional[float] = None,
    ignore_file: Optional[str] = None,
    trim: bool = False,
    n_fft=1024,
    win_length=None,
    hop_length=None,
    window="hann",
    n_mels=64,
    lowfreq=0,
    highfreq=None,
    pitch_fmin=80,
    pitch_fmax=640,
    pitch_avg=0,
    pitch_std=1,
    tokenize_text=True,
):
    """Dataset that loads audio, log mel specs, text tokens, duration / attention priors, pitches, and energies.
    Log mels, priors, pitches, and energies will be computed on the fly and saved in the supplementary_folder if
    they did not exist before.

    Args:
        manifest_filepath (str, Path, List[str, Path]): Path(s) to the .json manifests containing information on the
            dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
            json. Each line should contain the following:
                "audio_filepath": <PATH_TO_WAV>
                "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
                "duration": <Duration of audio clip in seconds> (Optional)
                "text": <THE_TRANSCRIPT> (Optional)
        sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
        supplementary_folder (Path): A folder that contains or will contain extra information such as log_mel if not
            specified in the manifest .json file. It will also contain priors, pitches, and energies
        max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
            pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
            audio to compute duration. Defaults to None which does not prune.
        min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
            pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
            audio to compute duration. Defaults to None which does not prune.
        ignore_file (Optional[str, Path]): The location of a pickle-saved list of audio_ids (the stem of the audio
            files) that will be pruned prior to training. Defaults to None which does not prune.
        trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
        n_fft (Optional[int]): The number of fft samples. Defaults to 1024
        win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
        hop_length (Optional[int]): The hop length between fft computations. Defaults to None which uses n_fft//4.
        window (Optional[str]): One of 'hann', 'hamming', 'blackman','bartlett', 'none'. Which corresponds to the
            equivalent torch window function.
        n_mels (Optional[int]): The number of mel filters. Defaults to 64.
        lowfreq (Optional[int]): The lowfreq input to the mel filter calculation. Defaults to 0.
        highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
        pitch_fmin (Optional[int]): The fmin input to librosa.pyin. Defaults to 80.
        pitch_fmax (Optional[int]): The fmax input to librosa.pyin. Defaults to 640.
        pitch_avg (Optional[float]): The mean that we use to normalize the pitch. Defaults to 0.
        pitch_std (Optional[float]): The std that we use to normalize the pitch. Defaults to 1.
        tokenize_text (Optional[bool]): Whether to tokenize (turn chars into ints). Defaults to True.
    """
    super().__init__()

    self.pitch_fmin = pitch_fmin
    self.pitch_fmax = pitch_fmax
    self.pitch_avg = pitch_avg
    self.pitch_std = pitch_std
    self.win_length = win_length or n_fft
    self.sample_rate = sample_rate
    self.hop_len = hop_length or n_fft // 4

    self.parser = make_parser(name="en", do_tokenize=tokenize_text)
    self.pad_id = self.parser._blank_id

    Path(supplementary_folder).mkdir(parents=True, exist_ok=True)
    self.supplementary_folder = supplementary_folder

    audio_files = []
    # Becomes None as soon as one manifest entry lacks a duration; duration
    # based logging and min/max pruning are then disabled.
    total_duration = 0
    # Load data from manifests
    # Note: audio is always required, even for text -> mel_spectrogram models, due to the fact that most models
    # extract pitch from the audio
    # Note: mel_filepath is not required and if not present, we then check the supplementary folder. If we fail, we
    # compute the mel on the fly and save it to the supplementary folder
    # Note: text is not required. Any models that require on text (spectrogram generators, end-to-end models) will
    # fail if not set. However vocoders (mel -> audio) will be able to work without text
    if isinstance(manifest_filepath, str):
        manifest_filepath = [manifest_filepath]
    for manifest_file in manifest_filepath:
        with open(Path(manifest_file).expanduser(), 'r') as f:
            logging.info(f"Loading dataset from {manifest_file}.")
            for line in f:
                item = json.loads(line)
                # Grab audio, text, mel if they exist
                file_info = {}
                file_info["audio_filepath"] = item["audio_filepath"]
                file_info["mel_filepath"] = item["mel_filepath"] if "mel_filepath" in item else None
                file_info["duration"] = item["duration"] if "duration" in item else None
                # Parse text
                file_info["text_tokens"] = None
                if "text" in item:
                    text = item["text"]
                    text_tokens = self.parser(text)
                    file_info["text_tokens"] = text_tokens
                audio_files.append(file_info)

                if file_info["duration"] is None:
                    logging.info(
                        "Not all audio files have duration information. Duration logging will be disabled."
                    )
                    total_duration = None

                if total_duration is not None:
                    total_duration += item["duration"]

    logging.info(f"Loaded dataset with {len(audio_files)} files.")
    if total_duration is not None:
        logging.info(f"Dataset contains {total_duration/3600:.2f} hours.")

    self.data = []

    if ignore_file:
        logging.info(f"using {ignore_file} to prune dataset.")
        # NOTE: pickle.load can execute arbitrary code; ignore_file must come
        # from a trusted source.
        with open(Path(ignore_file).expanduser(), "rb") as f:
            wavs_to_ignore = set(pickle.load(f))

    pruned_duration = 0 if total_duration is not None else None
    pruned_items = 0
    for item in audio_files:
        audio_path = item['audio_filepath']
        audio_id = Path(audio_path).stem

        # Prune data according to min/max_duration & the ignore file
        if total_duration is not None:
            if (min_duration and item["duration"] < min_duration) or (
                max_duration and item["duration"] > max_duration
            ):
                pruned_duration += item["duration"]
                pruned_items += 1
                continue

        if ignore_file and (audio_id in wavs_to_ignore):
            pruned_items += 1
            # Durations may be unavailable (pruned_duration is None); only
            # accumulate when duration tracking is active.
            if pruned_duration is not None:
                pruned_duration += item["duration"]
            wavs_to_ignore.remove(audio_id)
            continue

        self.data.append(item)

    logging.info(f"Pruned {pruned_items} files. Final dataset contains {len(self.data)} files")
    if pruned_duration is not None:
        logging.info(
            f"Pruned {pruned_duration/3600:.2f} hours. Final dataset contains "
            f"{(total_duration-pruned_duration)/3600:.2f} hours."
        )

    self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
    self.trim = trim

    # sr / n_fft are passed by keyword: recent librosa releases make them
    # keyword-only, and older releases accept the keywords as well.
    filterbanks = torch.tensor(
        librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=lowfreq, fmax=highfreq),
        dtype=torch.float,
    ).unsqueeze(0)
    self.fb = filterbanks

    torch_windows = {
        'hann': torch.hann_window,
        'hamming': torch.hamming_window,
        'blackman': torch.blackman_window,
        'bartlett': torch.bartlett_window,
        'none': None,
    }
    window_fn = torch_windows.get(window, None)
    window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None

    # window_tensor is None for window='none' (or an unknown name); guard the
    # dtype conversion so the STFT can still run without a window tensor
    # instead of failing with AttributeError on the first call.
    self.stft = lambda x: stft_patch(
        input=x,
        n_fft=n_fft,
        hop_length=self.hop_len,
        win_length=self.win_length,
        window=window_tensor.to(torch.float) if window_tensor is not None else None,
    )