def read_QI(info, ist, file, line=None):
    """Read the complex QI coefficients of structure ``ist`` from ``file``.

    Layout of the result: QI[it][il][ib*ia*im, ie]

    Values are whitespace-separated (real, imag) pairs, read in loop order
    ib, it, ia, il, im, ie (outermost to innermost). Tokens already split
    from the current line but not yet consumed are carried between calls
    through ``line``.

    Args:
        info: project info object; this function reads ``Nt[ist]``,
            ``Nl[it]``, ``Nb[ist]``, ``Na[ist][it]``, ``Ne[it]`` and calls
            ``Nm(il)``.
        ist: index of the structure being read.
        file: text file object positioned at the first QI value.
        line: unconsumed tokens of the current line, or ``None``/empty to
            start reading at the next line of ``file``.

    Returns:
        ``(QI, line)``: ``QI[it][il]`` is a ``torch_complex.ComplexTensor``
        whose float32 real/imag parts have shape
        ``(Nb[ist] * Na[ist][it] * Nm(il), Ne[it])``; ``line`` holds the
        tokens of the last line read that were not consumed (``None`` if
        nothing was ever read and ``None`` was passed in).

    Raises:
        EOFError: if the file ends before all values have been read.
        ValueError: if a token cannot be parsed as a float.
    """
    tokens, pos = line, 0

    def next_value():
        # Walk the current token list with an index instead of list.pop(0)
        # (which is O(len) per call); blank lines are skipped.
        nonlocal tokens, pos
        while tokens is None or pos >= len(tokens):
            raw = file.readline()
            if not raw:
                raise EOFError('unexpected end of file while reading QI')
            tokens, pos = raw.split(), 0
        pos += 1
        return float(tokens[pos - 1])

    # Stage values in plain numpy arrays and wrap them only once at the end:
    # torch.from_numpy requires real ndarrays, and element-wise writes are
    # much cheaper on numpy arrays than on torch tensors.
    real, imag = {}, {}
    for it in info.Nt[ist]:
        shapes = [
            (info.Nb[ist], info.Na[ist][it], info.Nm(il), info.Ne[it])
            for il in range(info.Nl[it])
        ]
        real[it] = [np.empty(shape, dtype=np.float32) for shape in shapes]
        imag[it] = [np.empty(shape, dtype=np.float32) for shape in shapes]

    for ib in range(info.Nb[ist]):
        for it in info.Nt[ist]:
            for ia in range(info.Na[ist][it]):
                for il in range(info.Nl[it]):
                    re_block, im_block = real[it][il], imag[it][il]
                    for im in range(info.Nm(il)):
                        for ie in range(info.Ne[it]):
                            re_block[ib, ia, im, ie] = next_value()
                            im_block[ib, ia, im, ie] = next_value()

    QI = dict()
    for it in info.Nt[ist]:
        QI[it] = ND_list(info.Nl[it])
        for il in range(info.Nl[it]):
            # Flatten (ib, ia, im) into one leading axis.
            QI[it][il] = torch_complex.ComplexTensor(
                torch.from_numpy(real[it][il]).view(-1, info.Ne[it]),
                torch.from_numpy(imag[it][il]).view(-1, info.Ne[it]))

    leftover = tokens if tokens is None else tokens[pos:]
    return QI, leftover
def read_SI(info, ist, file, line=None):
    """Read the complex SI coefficients of structure ``ist`` from ``file``.

    Layout of the result: SI[it1,it2][il1][il2][ia1,im1,ie1,ia2,im2,ie2]

    Values are whitespace-separated (real, imag) pairs, read in loop order
    it1, ia1, il1, im1, it2, ia2, il2, im2, ie1, ie2 (outermost to
    innermost). Tokens already split from the current line but not yet
    consumed are carried between calls through ``line``.

    Args:
        info: project info object; this function reads ``Nt[ist]``,
            ``Nl[it]``, ``Na[ist][it]``, ``Ne[it]`` and calls ``Nm(il)``.
        ist: index of the structure being read.
        file: text file object positioned at the first SI value.
        line: unconsumed tokens of the current line, or ``None``/empty to
            start reading at the next line of ``file``.

    Returns:
        ``(SI, line)``: ``SI[it1, it2][il1][il2]`` is a
        ``torch_complex.ComplexTensor`` whose float32 real/imag parts have
        shape ``(Na[ist][it1], Nm(il1), Ne[it1], Na[ist][it2], Nm(il2),
        Ne[it2])``; ``line`` holds the tokens of the last line read that were
        not consumed (``None`` if nothing was ever read and ``None`` was
        passed in).

    Raises:
        EOFError: if the file ends before all values have been read.
        ValueError: if a token cannot be parsed as a float.
    """
    tokens, pos = line, 0

    def next_value():
        # Walk the current token list with an index instead of list.pop(0)
        # (which is O(len) per call); blank lines are skipped.
        nonlocal tokens, pos
        while tokens is None or pos >= len(tokens):
            raw = file.readline()
            if not raw:
                raise EOFError('unexpected end of file while reading SI')
            tokens, pos = raw.split(), 0
        pos += 1
        return float(tokens[pos - 1])

    # Stage values in plain numpy arrays keyed by (it1, it2, il1, il2) and
    # wrap them only once at the end: torch.from_numpy requires real
    # ndarrays, and element-wise writes are much cheaper on numpy arrays.
    real, imag = {}, {}
    for it1, it2 in itertools.product(info.Nt[ist], info.Nt[ist]):
        for il1, il2 in itertools.product(range(info.Nl[it1]), range(info.Nl[it2])):
            shape = (info.Na[ist][it1], info.Nm(il1), info.Ne[it1],
                     info.Na[ist][it2], info.Nm(il2), info.Ne[it2])
            real[it1, it2, il1, il2] = np.empty(shape, dtype=np.float32)
            imag[it1, it2, il1, il2] = np.empty(shape, dtype=np.float32)

    for it1 in info.Nt[ist]:
        for ia1 in range(info.Na[ist][it1]):
            for il1 in range(info.Nl[it1]):
                for im1 in range(info.Nm(il1)):
                    for it2 in info.Nt[ist]:
                        for ia2 in range(info.Na[ist][it2]):
                            for il2 in range(info.Nl[it2]):
                                key = (it1, it2, il1, il2)
                                re_block, im_block = real[key], imag[key]
                                for im2 in range(info.Nm(il2)):
                                    for ie1 in range(info.Ne[it1]):
                                        for ie2 in range(info.Ne[it2]):
                                            index = (ia1, im1, ie1, ia2, im2, ie2)
                                            re_block[index] = next_value()
                                            im_block[index] = next_value()

    SI = dict()
    for it1, it2 in itertools.product(info.Nt[ist], info.Nt[ist]):
        SI[it1, it2] = ND_list(info.Nl[it1], info.Nl[it2])
        for il1, il2 in itertools.product(range(info.Nl[it1]), range(info.Nl[it2])):
            key = (it1, it2, il1, il2)
            SI[it1, it2][il1][il2] = torch_complex.ComplexTensor(
                torch.from_numpy(real[key]),
                torch.from_numpy(imag[key]))

    leftover = tokens if tokens is None else tokens[pos:]
    return SI, leftover
def stft(
        time_signal,
        size: int = 1024,
        shift: int = 256,
        *,
        # There is deliberately no ``axis`` argument: the transform always
        # runs over the last axis (supporting others complicated the code).
        window: typing.Union[str, typing.Callable] = 'blackman',
        window_length: int = None,
        fading: typing.Optional[typing.Union[bool, str]] = 'full',
        pad: bool = True,
        symmetric_window: bool = False,
):
    """Short-time Fourier transform of a torch tensor over its last axis.

    Torch counterpart of ``paderbox.transform.module_stft.stft``; the doctest
    below checks that both agree.

    Args:
        time_signal: ``torch.Tensor``; the STFT is computed over the last axis.
        size: FFT size. Must equal ``window_length``.
        shift: hop size in samples between consecutive frames.
        window: window name or callable, forwarded to ``_get_window``.
        window_length: defaults to ``size``; any other value raises
            ``NotImplementedError``.
        fading: ``'full'`` (or ``True``) zero-pads ``window_length - shift``
            samples on both sides, ``'half'`` pads half of that (split
            floor/ceil), ``False``/``None`` disables padding.
        pad: forwarded to ``segment_axis`` as ``end='pad'`` if true, else
            ``end='cut'``.
        symmetric_window: forwarded to ``_get_window``.

    Returns:
        ``torch_complex.ComplexTensor`` holding the one-sided spectrum
        (``size // 2 + 1`` bins on the last axis) of each windowed segment.

    Raises:
        NotImplementedError: if ``window_length != size``.
        AssertionError: if ``time_signal`` is not a ``torch.Tensor`` or
            ``fading`` is not one of the accepted values.

    >>> import numpy as np
    >>> import random
    >>> from paderbox.transform.module_stft import stft as np_stft, istft as np_istft
    >>> kwargs = dict(
    ...     size=np.random.randint(100, 200),
    ...     shift=np.random.randint(40, 100),
    ...     window=random.choice(['blackman', 'hann', 'hamming']),
    ...     fading=random.choice(['full', 'half', False]),
    ... )
    >>> num_samples = np.random.randint(200, 500)
    >>> a = np.random.rand(num_samples)
    >>> A_np = np_stft(a, **kwargs)
    >>> A_pt = stft(torch.tensor(a), **kwargs)
    >>> np.testing.assert_allclose(
    ...     A_np, A_pt.numpy(), err_msg=str(kwargs), atol=1e-10)
    """
    assert isinstance(time_signal, torch.Tensor)
    if window_length is None:
        window_length = size
    elif window_length != size:
        raise NotImplementedError(
            'Torch does not support window_length != size\n'
            f'window_length = {window_length} != {size} = size')

    # Pad with zeros to have enough samples for the window function to fade.
    assert fading in [None, True, False, 'full', 'half'], (fading, type(fading))
    if fading not in [False, None]:
        if fading == 'half':
            pad_width = [
                (window_length - shift) // 2,
                math.ceil((window_length - shift) / 2),
            ]
        else:
            pad_width = [
                window_length - shift,
                window_length - shift,
            ]
        time_signal = torch.nn.functional.pad(
            time_signal, pad_width, mode='constant')

    window = _get_window(
        window=window,
        symmetric_window=symmetric_window,
        window_length=window_length,
    )

    time_signal_seg = segment_axis(
        time_signal, window_length, shift=shift, axis=-1,
        end='pad' if pad else 'cut')

    frames = time_signal_seg * window
    if hasattr(torch, 'rfft'):
        # torch < 1.8: real/imag parts are stacked in a trailing axis of size 2.
        out = torch.rfft(frames, 1)
        assert out.shape[-1] == 2, out.shape
        return torch_complex.ComplexTensor(out[..., 0], out[..., 1])

    # torch >= 1.8 removed torch.rfft; torch.fft.rfft returns a native complex
    # tensor with the same one-sided, unnormalized spectrum.
    out = torch.fft.rfft(frames, dim=-1)
    return torch_complex.ComplexTensor(out.real, out.imag)
def complex_numpy(self, example, device):
    """Return ``example`` wrapped as a ``torch_complex.ComplexTensor`` on ``device``.

    The import is local, so ``torch_complex`` is only needed when this
    conversion is actually called.
    """
    from torch_complex import ComplexTensor
    return ComplexTensor(example, device=device)