Code Example #1
    def __init__(self,
                 squeezewave,
                 filter_length=1024,
                 n_overlap=4,
                 win_length=1024,
                 mode='zeros'):
        super(Denoiser, self).__init__()
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length / n_overlap),
                         win_length=win_length).cuda()
        if mode == 'zeros':
            mel_input = torch.zeros(
                (1, 80, 88),
                dtype=squeezewave.WN[0].start.weight.dtype,
                device=squeezewave.WN[0].start.weight.device)
        elif mode == 'normal':
            mel_input = torch.randn(
                (1, 80, 88),
                dtype=squeezewave.WN[0].start.weight.dtype,
                device=squeezewave.WN[0].start.weight.device)
        else:
            raise Exception("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = squeezewave.infer(mel_input, sigma=0.0).float()
            bias_spec, _ = self.stft.transform(bias_audio)

        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
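The snippet above only shows construction. As a hedged sketch (not part of the quoted file), the registered bias_spec is typically consumed in a forward pass like the one below, assuming the STFT class also exposes an inverse(magnitude, phase) counterpart to the transform call used above; the default strength value is an assumption.

    def forward(self, audio, strength=0.1):
        # Project the audio into the same STFT domain used to build bias_spec.
        audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
        # Subtract a scaled copy of the bias spectrum and clamp so magnitudes
        # stay non-negative.
        audio_spec_denoised = torch.clamp(audio_spec - self.bias_spec * strength, 0.0)
        # Resynthesize the waveform with the original phases (assumes inverse()).
        return self.stft.inverse(audio_spec_denoised, audio_angles)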
Code Example #2
def test_stft():
    y, sr = librosa.load(librosa.util.example_audio_file())
    audio = torch.autograd.Variable(torch.FloatTensor(y), requires_grad=False).unsqueeze(0)

    if torch.cuda.is_available():
        audio = audio.cuda()

    def mse(ground_truth, estimated):
        return torch.mean((ground_truth - estimated)**2)

    def to_np(tensor):
        return tensor.cpu().data.numpy()

    filter_length = 320
    hop_length = 160
    stft = STFT(filter_length=filter_length, hop_length=hop_length)
    if torch.cuda.is_available():
        stft = stft.cuda()
    output, m_hat, p_hat = stft(audio)

    D = librosa.stft(y, n_fft=filter_length, hop_length=hop_length)
    m, p = librosa.magphase(D)

    import ipdb; ipdb.set_trace()

    import matplotlib.pyplot as plt
    librosa.display.specshow(librosa.amplitude_to_db(D, ref=np.max), y_axis='log', x_axis='time')
    plt.title('Power spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
Code Example #3
File: test_stft.py Project: yluo42/pytorch-stft
def test_stft():
    audio, sr = librosa.load('mixture.mp3', sr=None)
    audio = torch.autograd.Variable(torch.FloatTensor(audio),
                                    requires_grad=False).unsqueeze(0)
    if torch.cuda.is_available():
        audio = audio.cuda()

    def mse(ground_truth, estimated):
        return torch.mean((ground_truth - estimated)**2)

    def to_np(tensor):
        return tensor.cpu().data.numpy()

    for i in range(12):
        filter_length = 2**i
        for j in range(i + 1):
            try:
                hop_length = 2**j
                stft = STFT(filter_length=filter_length, hop_length=hop_length)
                if torch.cuda.is_available():
                    stft = stft.cuda()
                output = stft(audio)
                loss = mse(output, audio)
                print('MSE: %s @ filter_length = %d, hop_length = %d' % (str(
                    to_np(loss)[0]), filter_length, hop_length))
            except Exception:
                print('Failed @ filter_length = %d, hop_length = %d' % (
                    filter_length, hop_length))
                traceback.print_exc()
Code Example #4
    def __init__(self,
                 waveglow,
                 filter_length=1024,
                 n_overlap=4,
                 win_length=1024,
                 mode='zeros',
                 device="cuda"):
        super(Denoiser, self).__init__()

        self.device = torch.device(
            "cpu" if not torch.cuda.is_available() else device)

        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length / n_overlap),
                         win_length=win_length)
        self.stft.to(self.device)

        if mode == 'zeros':
            mel_input = torch.zeros((1, 80, 88),
                                    dtype=waveglow.upsample.weight.dtype,
                                    device=waveglow.upsample.weight.device)
        elif mode == 'normal':
            mel_input = torch.randn((1, 80, 88),
                                    dtype=waveglow.upsample.weight.dtype,
                                    device=waveglow.upsample.weight.device)
        else:
            raise Exception("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
            bias_spec, _ = self.stft.transform(bias_audio)

        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
Code Example #5
 def __init__(self, n_fft, hop_length, win_length):
     super(TacotronSTFT, self).__init__()
     self.n_mel_channels = hp.num_mels
     self.sampling_rate = hp.sample_rate
     self.stft_fn = STFT(n_fft, hop_length, win_length)
     self.max_abs_mel_value = hp.max_abs_value
     mel_basis = librosa_mel_fn(hp.sample_rate, n_fft, hp.num_mels,
                                hp.mel_fmin, hp.mel_fmax)
     mel_basis = torch.from_numpy(mel_basis).float()
     self.register_buffer('mel_basis', mel_basis)
Code Example #6
 def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
              n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
              mel_fmax=8000.0):
     super(TacotronSTFT, self).__init__()
     self.n_mel_channels = n_mel_channels
     self.sampling_rate = sampling_rate
     self.stft_fn = STFT(filter_length, hop_length, win_length)
     mel_basis = librosa_mel_fn(
         sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
     mel_basis = torch.from_numpy(mel_basis).float()
     self.register_buffer('mel_basis', mel_basis)
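Examples #5 and #6 only show construction. A minimal sketch of how the mel_basis buffer is usually applied (the method name, the input range check, and the log-compression floor are assumptions, not taken from the quoted files):

    def mel_spectrogram(self, y):
        # y: batch of waveforms in [-1, 1], shape (B, T)
        assert torch.min(y.data) >= -1 and torch.max(y.data) <= 1
        magnitudes, phases = self.stft_fn.transform(y)  # (B, n_freq, frames)
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        # Dynamic-range compression: log of the clamped magnitudes.
        return torch.log(torch.clamp(mel_output, min=1e-5))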
Code Example #7
File: nnet.py Project: yaoydong/voice-filter
    def __init__(self,
                 frame_len,
                 frame_hop,
                 round_pow_of_two=True,
                 embedding_dim=512,
                 log_mag=False,
                 mvn_mag=False,
                 lstm_dim=400,
                 linear_dim=600,
                 l2_norm=True,
                 bidirectional=False,
                 non_linear="relu"):
        super(VoiceFilter, self).__init__()
        supported_nonlinear = {
            "relu": F.relu,
            "sigmoid": th.sigmoid,
            "tanh": th.tanh
        }
        if non_linear not in supported_nonlinear:
            raise RuntimeError(
                "Unsupported non-linear function: {}".format(non_linear))
        N = 2**math.ceil(
            math.log2(frame_len)) if round_pow_of_two else frame_len
        num_bins = N // 2 + 1

        self.stft = STFT(frame_len,
                         frame_hop,
                         round_pow_of_two=round_pow_of_two)
        self.istft = iSTFT(frame_len,
                           frame_hop,
                           round_pow_of_two=round_pow_of_two)
        self.cnn_f = Conv2dBlock(1, 64, kernel_size=(7, 1))
        self.cnn_t = Conv2dBlock(64, 64, kernel_size=(1, 7))
        blocks = []
        for d in range(5):
            blocks.append(
                Conv2dBlock(64, 64, kernel_size=(5, 5), dilation=(1, 2**d)))
        self.cnn_tf = nn.Sequential(*blocks)
        self.proj = Conv2dBlock(64, 8, kernel_size=(1, 1))
        self.lstm = nn.LSTM(8 * num_bins + embedding_dim,
                            lstm_dim,
                            batch_first=True,
                            bidirectional=bidirectional)
        self.mask = nn.Sequential(
            nn.Linear(lstm_dim * 2 if bidirectional else lstm_dim, linear_dim),
            nn.ReLU(), nn.Linear(linear_dim, num_bins))
        self.non_linear = supported_nonlinear[non_linear]
        self.embedding_dim = embedding_dim
        self.l2_norm = l2_norm
        self.log_mag = log_mag
        self.bn = nn.BatchNorm1d(num_bins) if mvn_mag else None
Code Example #8
File: layers.py Project: fasongsong/Speech_synthesis
 def __init__(self, hparams):
     super(TacotronSTFT, self).__init__()
     self.n_mel_channels = hparams.n_mel_channels
     self.sampling_rate = hparams.sampling_rate
     # print(hparams.filter_length, hparams.hop_length, hparams.win_length)
     self.stft_fn = STFT(hparams.filter_length, hparams.hop_length,
                         hparams.win_length)
     self.max_abs_mel_value = hparams.max_abs_mel_value
     mel_basis = librosa_mel_fn(hparams.sampling_rate,
                                hparams.filter_length,
                                hparams.n_mel_channels, hparams.mel_fmin,
                                hparams.mel_fmax)
     mel_basis = torch.from_numpy(mel_basis).float()
     self.register_buffer('mel_basis', mel_basis)
Code Example #9
File: melspec.py Project: taalua/speech2singing
    def __init__(self, hp):
        super(MelSpectrogram, self).__init__()
        self.n_mel_channels = hp.n_mel_channels
        self.sampling_rate = hp.sampling_rate

        self.stft_fn = STFT(hp.filter_length, hp.hop_length,
                            hp.win_length).cuda()
        mel_basis = librosa_mel_fn(hp.sampling_rate, hp.filter_length,
                                   hp.n_mel_channels, hp.mel_fmin, None)

        inv_mel_basis = np.linalg.pinv(mel_basis)

        mel_basis = torch.from_numpy(mel_basis).float()
        inv_mel_basis = torch.from_numpy(inv_mel_basis).float().cuda()

        self.register_buffer('mel_basis', mel_basis)
        self.register_buffer('inv_mel_basis', inv_mel_basis)
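A hedged sketch of what the extra inv_mel_basis buffer is for: mapping a mel spectrogram back to an approximate linear-frequency magnitude via the pseudo-inverse (the method below is illustrative and not taken from the project):

    def mel_to_linear(self, mel_spec):
        # mel_spec: (B, n_mel_channels, frames). The pseudo-inverse can produce
        # negative values, so clamp to a small positive floor.
        return torch.clamp(torch.matmul(self.inv_mel_basis, mel_spec), min=1e-5)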
Code Example #10
    def __init__(self):

        # Initialize node
        rospy.init_node('detect_alpha', anonymous=True)

        # Get ros parameters
        sleep(10)
        fs = rospy.get_param("sampling_rate")
        #fs = 125
        print(fs)
        channel_count = rospy.get_param("eeg_channel_count")
        print(channel_count)

        # Initialize STFT
        self.stft = STFT(fs, 1.0, 0.25, channel_count)
        self.stft.remove_dc()
        self.stft.bandpass(5.0, 15.0)
        self.stft.window('hann')
        self.freq_bins = self.stft.freq_bins
        self.FFT = np.zeros((len(self.freq_bins), channel_count))

        # Choose channels
        self.channel_mask = np.full(channel_count, False, dtype=bool)
        self.channel_mask[7 - 1] = True
        self.channel_mask[8 - 1] = True

        # Define bands
        self.G1_mask = np.logical_and(5 < self.freq_bins, self.freq_bins < 7.5)
        self.Al_mask = np.logical_and(8.5 < self.freq_bins,
                                      self.freq_bins < 11.5)
        self.G2_mask = np.logical_and(12.5 < self.freq_bins,
                                      self.freq_bins < 15)

        # Initialize filters
        self.movavg = MovAvg(4)
        self.ignore = Ignore(0)

        # Setup publishers
        self.pub_guard1 = rospy.Publisher('guard1', Float32, queue_size=1)
        self.pub_alpha = rospy.Publisher('alpha', Float32, queue_size=1)
        self.pub_guard2 = rospy.Publisher('guard2', Float32, queue_size=1)
        self.pub_eyes = rospy.Publisher('eyes_closed', Bool, queue_size=1)

        # Subscribe
        rospy.Subscriber("eeg_channels", BCIuVolts, self.newSample)
Code Example #11
def wav_to_image(
    filename,
    wlen,
    mindata,
    maxdata,
    save=False,
    name_save=None,
):
    h = wlen // 4  # hop size must be an integer number of samples
    K = np.sum(hamming(wlen, False)) / wlen

    nfft = int(2**(np.ceil(np.log2(wlen))))
    Fs, data_seq = wavfile.read(filename)
    raw_data = data_seq.astype('float32')
    max_dt = np.amax(np.absolute(raw_data))
    raw_data = raw_data / max_dt
    stft_data, _, _ = STFT(raw_data, wlen, h, nfft, Fs)
    s = abs(stft_data) / wlen / K
    if np.fmod(nfft, 2):
        s[1:, :] *= 2
    else:
        s[1:-2] *= 2
    data_temp = 20 * np.log10(s + 10**-6)
    outdata = data_temp.transpose()
    """Scaling"""
    mindata = np.amin(outdata, axis=0, keepdims=True)
    maxdata = np.amax(outdata, axis=0, keepdims=True)
    outdata -= mindata
    outdata /= (maxdata - mindata)
    outdata *= 0.8
    outdata += 0.1
    figmin = np.zeros((5, outdata.shape[1]))
    figmax = np.ones((5, outdata.shape[1]))
    outdata = np.concatenate((outdata, figmin, figmax), axis=0)

    dpi = 96
    a = float(outdata.shape[0]) / dpi
    b = float(outdata.shape[1]) / dpi

    f = plt.figure(figsize=(b, a), dpi=dpi)
    f.figimage(outdata)
    if save:
        f.savefig(name_save, dpi=f.dpi)
    return f
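Note that the mindata and maxdata parameters are immediately overwritten by the per-column minima and maxima computed in the "Scaling" block, so the values passed in are effectively ignored. A hypothetical call (file name and window length are placeholders):

fig = wav_to_image('speech.wav', wlen=512, mindata=None, maxdata=None,
                   save=True, name_save='speech_spectrogram.png')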
Code Example #12
 def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
              n_mel_channels=40, sampling_rate=16000, mel_fmin=0.0,
              mel_fmax=8000.0):
     """ mel 特征抽取
     :param filter_length: fft采样点数
     :param hop_length:  移动 stride
     :param win_length: 窗长
     :param n_mel_channels: mel channel 个数
     :param sampling_rate: 采样率
     :param mel_fmin:   最小截止频率
     :param mel_fmax:  最大截止频率
     """
     super(MelSpec, self).__init__()
     self.n_mel_channels = n_mel_channels
     self.sampling_rate = sampling_rate
     self.stft_fn = STFT(filter_length=filter_length, hop_length=hop_length, win_length=win_length)
     mel_bias = librosa_mel_fn(sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
     mel_bias = torch.from_numpy(mel_bias).float()
     self.register_buffer('mel_bias', mel_bias)
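As a usage note, a hypothetical instantiation for 16 kHz speech with the documented defaults; the feature-extraction method itself is not shown in this snippet, but the registered mel_bias buffer has shape (n_mel_channels, filter_length // 2 + 1):

mel_front_end = MelSpec(filter_length=1024, hop_length=256, win_length=1024,
                        n_mel_channels=40, sampling_rate=16000)
print(mel_front_end.mel_bias.shape)  # torch.Size([40, 513])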
Code Example #13
def inference(wav, model, sample_length):

    vocal = []
    bgm = []
    print(len(wav))
    print(sample_length)
    batch_size = 2**13
    for i in tqdm.tqdm(range(len(wav) // (sample_length * batch_size))):
        start = i * sample_length * batch_size
        end = min((i + 1) * sample_length, len(wav))
        small_wavs = np.stack([
            wav[start + j * sample_length:start + (j + 1) * sample_length]
            for j in range(batch_size)
        ])
        #print(small_wavs.shape)
        in_wav = torch.autograd.Variable(torch.FloatTensor(small_wavs),
                                         requires_grad=False).cuda()
        #print(in_wav.shape)
        stft = STFT(input_data=in_wav).cuda()

        magnitude, phase = stft()
        magnitude = torch.squeeze(magnitude)
        phase = torch.squeeze(phase)
        size = [in_wav.size(1) for _ in range(in_wav.size(0))]

        #print(magnitude.shape)
        vocal_recon, noise_recon = model(magnitude.transpose(1, 2))

        #print(vocal_recon.shape)
        #print(noise_recon.shape)

        vocal.append(
            reConstructWav(size,
                           vocal_recon.transpose(1, 2).cpu().detach(),
                           phase.cpu().detach()).view(-1))
        bgm.append(
            reConstructWav(size,
                           noise_recon.transpose(1, 2).cpu().detach(),
                           phase.cpu().detach()).view(-1))

    print(torch.cat(vocal).shape)
    return torch.cat(vocal).numpy(), torch.cat(bgm).numpy()
Code Example #14
def prepareDataFiles(store_data, song_name, mix_path, vocal_path, bgm_path):
    try:
        os.mkdir(os.path.join(store_data, song_name))
        os.mkdir(os.path.join(os.path.join(store_data, song_name), "mixture"))
        os.mkdir(os.path.join(os.path.join(store_data, song_name), "vocal"))
        os.mkdir(os.path.join(os.path.join(store_data, song_name), "noise"))
    except OSError:  # directories may already exist
        pass

    mixture, mix_rate = librosa.core.load(mix_path, sr=16000)
    vocal, vocal_rate = librosa.core.load(vocal_path, sr=16000)
    bgm, bgm_rate = librosa.core.load(bgm_path, sr=16000)

    # Loop through wave form and zero out any values that are close to zero so that
    # there are no points that will explode into large values.
    # Need to check effect on Spectrum, since loss is done with the spectrums rather
    # than the waveforms themselves

    for stype, data, rate in zip(["mixture", "vocal", "noise"],
                                 [mixture, vocal, bgm],
                                 [mix_rate, vocal_rate, bgm_rate]):
        path = os.path.join(os.path.join(store_data, song_name), stype)
        filename = song_name

        in_wav = torch.autograd.Variable(torch.FloatTensor(data),
                                         requires_grad=False).unsqueeze(0)
        stft = STFT(input_data=in_wav)
        magnitude, phase = stft()
        magnitude = torch.squeeze(magnitude)
        phase = torch.squeeze(phase)
        size = in_wav.size(1)
        # f, t, Sxx = signal.stft(data,rate,nperseg=1000)
        # magnitude = np.abs(Sxx)
        # phase = np.unwrap(np.angle(Sxx),axis=-2)

        np.save(os.path.join(path, "rate_" + filename), rate)
        # np.save(os.path.join(path,"freq_"+ filename),f)
        # np.save(os.path.join(path,"time_"+ filename),t)
        np.save(os.path.join(path, "magnitude_" + filename), magnitude)
        np.save(os.path.join(path, "phase_" + filename), phase)
        np.save(os.path.join(path, "size_" + filename), size)
Code Example #15
File: denoiser.py Project: SortAnon/hifi-gan
    def __init__(
        self, hifigan, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"
    ):
        super(Denoiser, self).__init__()
        self.stft = STFT(
            filter_length=filter_length,
            hop_length=int(filter_length / n_overlap),
            win_length=win_length,
        ).cuda()
        if mode == "zeros":
            mel_input = torch.zeros((1, 80, 88)).cuda()
        elif mode == "normal":
            mel_input = torch.randn((1, 80, 88)).cuda()
        else:
            raise Exception("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = hifigan(mel_input).view(1, -1).float()
            bias_spec, _ = self.stft.transform(bias_audio)

        self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
Code Example #16
File: layers.py Project: yhgon/NanoFlow
    def __init__(self,
                 filter_length=1024,
                 hop_length=256,
                 win_length=1024,
                 n_mel_channels=80,
                 sampling_rate=22050,
                 mel_fmin=0.0,
                 mel_fmax=None,
                 ref_level_db=10.,
                 min_level_db=-100.):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(sampling_rate, filter_length,
                                   n_mel_channels, mel_fmin, mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

        # to be used for mel_spectrogram_dbver
        self.ref_level_db = ref_level_db
        self.min_level_db = min_level_db
Code Example #17
    def __init__(self,
                 model,
                 sound,
                 target,
                 decoder,
                 sample_rate=16000,
                 device="cpu",
                 save=None):
        """
        model: deepspeech model
        sound: raw sound data [-1 to +1] (read from torchaudio.load)
        label: string
        """
        self.sound = sound
        self.sample_rate = sample_rate
        self.target_string = target
        self.target = target
        self.__init_target()

        self.model = model
        self.model.to(device)
        self.model.train()
        self.decoder = decoder
        self.criterion = nn.CTCLoss()
        self.device = device
        n_fft = int(self.sample_rate * 0.02)
        hop_length = int(self.sample_rate * 0.01)
        win_length = int(self.sample_rate * 0.02)
        self.torch_stft = STFT(n_fft=n_fft,
                               hop_length=hop_length,
                               win_length=win_length,
                               window='hamming',
                               center=True,
                               pad_mode='reflect',
                               freeze_parameters=True,
                               device=self.device)
        self.save = save
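The attack class only wires up the STFT here. A hypothetical helper (not from the quoted file) showing how such a torchlibrosa-style STFT, which returns separate real and imaginary tensors, would typically be turned into the log-magnitude features a DeepSpeech-style model consumes:

    def _log_spectrogram(self, sound):
        # Assumes self.torch_stft follows the torchlibrosa convention of
        # returning (real, imag) tensors.
        real, imag = self.torch_stft(sound)
        magnitude = torch.sqrt(real ** 2 + imag ** 2)
        return torch.log1p(magnitude)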
Code Example #18
def reConstructWav(size, magnitude, phase):
    """the differentiable reconstruction for mixture spectrogram with vocal and noise"""
    stft = STFT(size=size, magnitude=magnitude, phase=phase)
    stft = stft.cuda()
    xrec = stft(inv=True)
    return xrec
Code Example #19
        for i in range(len(batch)):
            wav = batch[i][0]
            assert wav.shape[-1] >= n_samples
            wav_truncated[i, :] = wav[0, :n_samples]
        return wav_truncated


dataset = torchaudio.datasets.LJSPEECH('./data')
dataset = torch.utils.data.Subset(dataset, range(100))

dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=8,
                                         shuffle=True,
                                         collate_fn=Collate())

stft_deterministic = STFT(filter_length=256, hop_length=128, win_length=256)
stft_model = STFT(filter_length=256,
                  hop_length=128,
                  win_length=256,
                  trainable=True)

criterion = nn.MSELoss()
optimizer = Adam(stft_model.parameters(), lr=1e-1)

n_epoch = 100
torch.save(stft_model.state_dict(), f'./experiments/trainable_fft_{-1}')
for epoch in range(n_epoch):

    for i, batch in enumerate(dataloader):
        stft_model.zero_grad()
        targ = stft_deterministic(batch)
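The snippet is cut off inside the training loop. A hedged continuation of one optimization step (assuming the trainable STFT's forward output is directly comparable to the deterministic target) might look like:

        pred = stft_model(batch)
        loss = criterion(pred, targ)
        loss.backward()
        optimizer.step()
    torch.save(stft_model.state_dict(), f'./experiments/trainable_fft_{epoch}')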
Code Example #20
File: generate.py Project: ktho22/vctts
if args.gpu is None:
    args.use_gpu = False
    args.gpu = []
else:
    args.use_gpu = True
    torch.cuda.manual_seed(args.seed)
    torch.cuda.set_device(args.gpu[0])

model = Tacotron(args)
if args.init_from:
    model.load_state_dict(checkpoint['state_dict'])
    model.reset_decoder_states()
    print('loaded checkpoint %s' % (args.init_from))

stft = STFT(filter_length=args.n_fft)
model = model.eval()
if args.use_gpu:
    model = model.cuda()
    stft = stft.cuda()


def main():
    db = TTSDataset()
    collate = collate_class(use_txt=args.use_txt)
    loader = torch.utils.data.DataLoader(db,
                                         batch_size=1,
                                         shuffle=False,
                                         collate_fn=collate.fn,
                                         drop_last=True)
    model_name = args.init_from.split('/')[-1][:-3]
Code Example #21
# from torchaudio_stft import ISTFT, STFT

x = torch.rand((64, 1, 8000), requires_grad=True, dtype=torch.float32).cuda()
n_fft = 320
hop_length = 160

# ##### Test torchaudio.transforms.Spectrogram on multi-gpus
# window_fn = torch.hann_window
# power = None
# spectrogram = torchaudio.transforms.Spectrogram(
#     n_fft=n_fft,
#     hop_length=hop_length,
#     window_fn=window_fn,
#     power=power
# )
# spectrogram = nn.DataParallel(spectrogram)
# spectrogram.cuda()
# out = spectrogram(input_data)

##### Test torch.stft and torch.istft
x = F.pad(x, pad=(0, n_fft // 2), mode='constant', value=0)
stft_extractor = STFT(n_fft=n_fft, hop_length=hop_length, window='hann')
stft_extractor = nn.DataParallel(stft_extractor)
stft_extractor.cuda()
x_stft_real, x_stft_imag = stft_extractor(x)

istft_extractor = ISTFT(n_fft=n_fft, hop_length=hop_length, window='hann')
istft_extractor = nn.DataParallel(istft_extractor)
istft_extractor.cuda()
x_reconst = istft_extractor(x_stft_real, x_stft_imag, length=8000)
print(torch.max(torch.abs(x[..., :8000] - x_reconst)))
Code Example #22
File: gan_accom.py Project: taalua/speech2singing
    m, opt, iteration = load_checkpoint(
        f'checkpoint/{args.checkpoint_path}/gen', m, opt)
    dis_high, opt_dis, iteration = load_checkpoint(
        f'checkpoint/{args.checkpoint_path}/dis', dis_high, opt_dis)
    dis_accom, opt_accom, iteration = load_checkpoint(
        f'checkpoint/{args.checkpoint_path}/dis_accom', dis_accom, opt_accom)
'''
###########################################################
    In general, we preprocess the data into npy files and put
    them in a specific folder; the dataloader then loads the
    npy files.

    But in this example, I show how to transform audio into an
    STFT / mel spectrogram with a torch.nn.Module (MelSpectrogram).
###########################################################
'''
stft_fn = STFT(hp.filter_length, hp.hop_length, hp.win_length).cuda()

while True:

    voc = next(inf_iterator_voc_speech).cuda()
    #voc = stft_fn.transform_mag(voc)

    linear = next(inf_iterator_lin_speech).cuda()
    #linear = stft_fn.transform_mag(linear)
    accom = next(inf_iterator_accom_speech).cuda()

    voc = voc[..., :voc.size(2) // 16 * 16]
    linear = linear[..., :linear.size(2) // 16 * 16]
    accom = accom[..., :accom.size(2) // 16 * 16]

    fake_accom = voc