Example #1
    def __init__(self, model_args=None):
        super().__init__()

        self.model_args = model_args
        self.hop_size = model_args['stride']
        self.channels = model_args['n_filters']
        self.win_len = model_args['kernel_size']
        self.bias = model_args['bias']
        self.enc_act = model_args['enc_act']
        self.regularizer = model_args['mask']

        # optional flag; defaults to None when absent
        self.use_stft = model_args.get('use_stft')
        if self.use_stft:
            self.stft = STFT(filter_length=(self.channels - 1) * 2,
                             hop_length=self.hop_size,
                             win_length=self.win_len,
                             window='hann')
            self.encoder = Encoder1D(stft=self.stft)
            self.decoder = Decoder1D(stft=self.stft)
        else:
            self.encoder = Encoder1D(self.channels, self.win_len, self.hop_size, enc_act=self.enc_act, bias=self.bias)
            self.decoder = Decoder1D(self.channels, self.win_len, self.hop_size, bias=self.bias)

        self.eps = 10e-9
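
A hypothetical model_args dict that would send this constructor down the learned-filterbank branch; the key names come from the lookups above, but the values are illustrative only:

model_args = {
    'stride': 8,        # hop size between frames
    'n_filters': 512,   # number of encoder channels
    'kernel_size': 16,  # analysis window length
    'bias': False,
    'enc_act': 'relu',  # encoder activation
    'mask': 'relu',     # mask/regularizer nonlinearity
    'use_stft': False,  # omit or set to False for the learned filterbank
}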
Example #2
def _prepare_network(device,
                     filter_length=1024,
                     hop_length=512,
                     win_length=None,
                     window='hann'):
    stft = STFT(filter_length=filter_length,
                hop_length=hop_length,
                win_length=win_length,
                window=window).to(device)
    return stft
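
A minimal round-trip sketch using the helper above; STFT.transform and STFT.inverse are the same calls the later examples rely on, and the random input is only a stand-in for real audio:

import torch

stft = _prepare_network(device='cpu')
audio = torch.randn(1, 16000)                    # [BATCH, N] dummy signal
magnitude, phase = stft.transform(audio)         # analysis
reconstruction = stft.inverse(magnitude, phase)  # synthesis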
Example #3
def inverse_stft(
    magnitude_list,
    phase_list,
    hop_length=512,
    win_length=2048,
    window="hann",
    device="cpu",
    input_tensor=False,
    output_tensor=False,
    stft=None,
):
    # hop_length is the step between successive windows; win_length // 4 is a common choice
    filter_length = win_length

    if input_tensor is False:
        magnitude_list = [
            torch.FloatTensor(magnitude) for magnitude in magnitude_list
        ]
        phase_list = [torch.FloatTensor(phase) for phase in phase_list]

    magnitude_cat = torch.cat(magnitude_list, dim=0)
    phase_cat = torch.cat(phase_list, dim=0)

    if stft is None:
        stft = STFT(
            filter_length=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
        ).to(device)

    reconstructed_audio = stft.inverse(magnitude_cat, phase_cat)
    reconstructed_audio = torch.split(reconstructed_audio,
                                      split_size_or_sections=1,
                                      dim=0)

    if output_tensor is True:
        return reconstructed_audio

    return [r_audio.cpu().numpy() for r_audio in reconstructed_audio]
Example #4
def transform_stft(
    audio_list,
    hop_length=512,
    win_length=2048,
    window="hann",
    device="cpu",
    input_tensor=False,
    output_tensor=False,
    stft=None,
):
    # hop_length is the step between successive windows; win_length // 4 is a common choice
    filter_length = win_length

    if input_tensor is False:
        audio_list = [torch.FloatTensor(audio) for audio in audio_list]

    audio_list = [audio.unsqueeze(0) for audio in audio_list]
    audio_cat = torch.cat(audio_list, dim=0)
    audio_cat = audio_cat.to(device)

    if stft is None:
        stft = STFT(
            filter_length=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
        ).to(device)

    magnitudes, phases = stft.transform(audio_cat)
    magnitudes = torch.split(magnitudes, split_size_or_sections=1, dim=0)
    phases = torch.split(phases, split_size_or_sections=1, dim=0)

    if output_tensor is True:
        return magnitudes, phases

    return (
        [magnitude.cpu().numpy() for magnitude in magnitudes],
        [phase.cpu().numpy() for phase in phases],
    )
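
Examples #3 and #4 are inverses of each other. A sketch of the NumPy-in/NumPy-out round trip, assuming audio is a 1-D float array:

magnitudes, phases = transform_stft([audio], hop_length=512, win_length=2048)
reconstructed = inverse_stft(magnitudes, phases, hop_length=512, win_length=2048)
# reconstructed[0] approximates the input waveform (up to edge padding)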
Example #5
def _stft(y, pad_mode="constant"):
    global _torch_stft_instance
    if wavenet_hparams.use_torch_stft:
        y_torch = torch.FloatTensor(y)
        y_torch = y_torch.unsqueeze(0)
        if _torch_stft_instance is None:
            _torch_stft_instance = STFT(
                filter_length=wavenet_hparams.fft_size,
                hop_length=get_hop_size(),
                win_length=get_win_length(),
                window=wavenet_hparams.window,
            ).to("cpu")
        # note: this branch returns a (magnitude, phase) tuple of tensors,
        # whereas the librosa branch below returns a complex spectrogram
        return _torch_stft_instance.transform(y_torch)
    # use constant padding (defaults to zeros) instead of reflection padding
    return librosa.stft(
        y=y,
        n_fft=wavenet_hparams.fft_size,
        hop_length=get_hop_size(),
        win_length=get_win_length(),
        window=wavenet_hparams.window,
        pad_mode=pad_mode,
    )
Beispiel #6
0
def main():
    # Define constant
    device = 'cuda'
    filter_length = 1024
    hop_length = 256
    win_length = 1024
    window = 'hann'
    n_iter = 50
    duration = 23

    # Load audio and form tensor
    audio, sr = librosa.load("your_music.mp3", duration=duration, offset=0)
    audio = torch.FloatTensor(audio)
    audio = audio.unsqueeze(0)
    audio = audio.to(device)

    # STFT (input size is [BATCH, N], N is the number of sample points)
    stft = STFT(filter_length=filter_length, hop_length=hop_length, win_length=win_length, window=window).to(device)
    magnitude, phase = stft.transform(audio)

    # Griffin-Lim reconstructs the waveform from the magnitude alone (discarding phase)
    wave = griffinlim(magnitude, sr*duration, n_iter=n_iter, angles=None, hop_length=hop_length, win_length=win_length, device=device)
    wave = wave.cpu().numpy()
    # librosa.output.write_wav was removed in librosa 0.8; on newer
    # versions use soundfile.write('out.wav', wave[0], sr) instead
    librosa.output.write_wav('out.wav', wave[0], sr, norm=True)
Example #7
    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 window_size=.02,
                 window_stride=.01,
                 window_type='hamming',
                 normalize=True,
                 max_len=101):
        classes, class_to_idx = find_classes(root)
        spects = make_dataset(root, class_to_idx)
        if len(spects) == 0:
            raise RuntimeError("Found 0 sound files in subfolders of: " +
                               root +
                               "\nSupported audio file extensions are: " +
                               ",".join(AUDIO_EXTENSIONS))

        self.root = root
        self.spects = spects
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = spect_loader
        self.window_size = window_size
        self.window_stride = window_stride
        self.window_type = window_type
        self.normalize = normalize
        self.max_len = max_len

        sr = 16000
        n_fft = int(sr * window_size)
        win_length = n_fft
        hop_length = int(sr * window_stride)
        device = 'cpu'
        self.stft = STFT(filter_length=n_fft,
                         hop_length=hop_length,
                         win_length=win_length,
                         window=window_type).to(device)
Example #8
import os
import torch.nn as nn
import torch.nn.functional as F
import librosa
from torch.utils.data import DataLoader, Dataset
from hparam import hparam as hp
# from dvector_create import get_dvector_vf
from utils import for_stft, for_stft_2
# from torchsummary import summary
from model_gan_multi_SN_ori_ff import *
from torch_stft import STFT
# from tensorboardX import SummaryWriter

window_length = int(hp.data.window_gan * hp.data.sr)
hop_length = int(hp.data.hop * hp.data.sr)
stft = STFT(filter_length=hp.data.nfft,
            hop_length=hop_length,
            win_length=window_length)


class dataset_preprocess_FTGAN(Dataset):
    def __init__(self, shuffle=False):
        self.clean_tf = hp.data.gan_clean_tf
        self.noisy_tf = hp.data.gan_noisy_tf
        # self.clean_wav = hp.data.gan_clean_wav
        # self.noisy_wav = hp.data.gan_noisy_wav

        # self.file_list = os.listdir(self.clean_wav)
        self.file_list = os.listdir(self.clean_tf)
        self.shuffle = shuffle
        # self.utter_start = utter_start
    def __len__(self):
        return len(self.file_list)
Example #9
def frame_n_audio_to_datasets(
    frame_path_dir,
    audio_path_dir,
    file_filter="*",
    save_dir="./dataset",
    device="cpu",
    start_index=0,
):
    frame_path_list = glob(os.path.join(frame_path_dir, file_filter))
    audio_path_list = glob(os.path.join(audio_path_dir, file_filter))

    dirs = {
        data_type: os.path.join(save_dir, data_type)
        for data_type in FILENAME_TEMPLATE.keys()
    }

    for dir_path in dirs.values():
        os.makedirs(dir_path, exist_ok=True)

    i = start_index
    raw_data = RawDatasetV2(frame_path_list=frame_path_list,
                            audio_path_list=audio_path_list,
                            transforms={})

    data_loader = DataLoader(
        dataset=raw_data,
        batch_size=DEFAULT_CONFIG["batch_size"],
        shuffle=DEFAULT_CONFIG["shuffle"],
        num_workers=DEFAULT_CONFIG["batch_size"] // 2,
        collate_fn=custom_collate_fn,
    )

    # build stft
    stft = STFT(
        filter_length=DEFAULT_CONFIG["win_length"],
        hop_length=DEFAULT_CONFIG["hop_length"],
        win_length=DEFAULT_CONFIG["win_length"],
        window=DEFAULT_CONFIG["window"],
    ).to(device)

    # Audio Part
    for frame_list, audio_list in data_loader:
        if frame_list is None or audio_list is None:
            continue

        magnitude_list, phase_list = transform_stft(
            audio_list=audio_list,
            hop_length=DEFAULT_CONFIG["hop_length"],
            win_length=DEFAULT_CONFIG["win_length"],
            window=DEFAULT_CONFIG["window"],
            device=device,
            input_tensor=True,
            output_tensor=False,
            stft=stft,
        )

        stft_to_mel_params = [{
            "magnitude": magnitude,
            "phase": phase
        } for magnitude, phase in zip(magnitude_list, phase_list)]
        melspecgrams = parallelize(
            func=stft_to_mel,
            params=stft_to_mel_params,
            n_jobs=DEFAULT_CONFIG["n_jobs"],
        )

        log_mel_spec_list = [item[0] for item in melspecgrams]
        mel_if_list = [item[1] for item in melspecgrams]

        # Save
        frame_list = [frame.cpu().numpy() for frame in frame_list]
        audio_list = [audio.cpu().numpy() for audio in audio_list]

        params = [{
            "data_dict": {
                "frame": frame,
                "audio": audio,
                "log_mel_spec": log_mel_spec,
                "mel_if": mel_if,
            },
            "dirs": dirs,
            "file_index": i + idx,
        } for idx, (frame, audio, log_mel_spec, mel_if) in enumerate(
            zip(frame_list, audio_list, log_mel_spec_list, mel_if_list))]

        # save files with parallel
        parallelize(func=save_files,
                    params=params,
                    n_jobs=DEFAULT_CONFIG["n_jobs"])

        i += len(frame_list)
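
Note that the single stft instance is built once and handed to transform_stft for every batch, which avoids re-allocating the filter and window buffers on each call. A condensed sketch of that reuse pattern, where batches stands in for any iterable of audio-tensor lists:

stft = STFT(filter_length=1024, hop_length=256,
            win_length=1024, window='hann').to('cpu')
for audio_list in batches:  # 'batches' is a placeholder iterable
    mags, phases = transform_stft(audio_list=audio_list, stft=stft,
                                  input_tensor=True, output_tensor=True)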
Example #10
filter_length = 1024
hop_length = 256
win_length = 1024  # optional; if not specified it defaults to filter_length
window = 'hann'
librosa_stft = librosa.stft(audio,
                            n_fft=filter_length,
                            hop_length=hop_length,
                            window=window)
_magnitude = np.abs(librosa_stft)

audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0)
audio = audio.to(device)

stft = STFT(filter_length=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            window=window).to(device)

magnitude, phase = stft.transform(audio)
plt.figure(figsize=(6, 6))
plt.subplot(211)
plt.title('PyTorch STFT magnitude')
plt.xlabel('Frames')
plt.ylabel('FFT bin')
plt.imshow(20 * np.log10(1 + magnitude[0].cpu().data.numpy()),
           aspect='auto',
           origin='lower')

plt.subplot(212)
plt.title('Librosa STFT magnitude')
plt.xlabel('Frames')
plt.ylabel('FFT bin')
plt.imshow(20 * np.log10(1 + _magnitude),
           aspect='auto',
           origin='lower')
Example #11
    def __init__(self):
        super(VoiceFilter_SN, self).__init__()
        self.CNN = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=64,
                      kernel_size=(1, 7),
                      padding=(0, 3),
                      dilation=(1, 1)), nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=64,
                      out_channels=64,
                      kernel_size=(7, 1),
                      padding=(3, 0),
                      dilation=(1, 1)), nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=64,
                      out_channels=64,
                      kernel_size=(5, 5),
                      padding=(2, 2),
                      dilation=(1, 1)), nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=64,
                      out_channels=64,
                      kernel_size=(5, 5),
                      padding=(6, 2),
                      dilation=(3, 1)), nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=64,
                      out_channels=64,
                      kernel_size=(5, 5),
                      padding=(10, 2),
                      dilation=(5, 1)), nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=64,
                      out_channels=64,
                      kernel_size=(5, 5),
                      padding=(26, 2),
                      dilation=(13, 1)), nn.BatchNorm2d(64),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=64,
                      out_channels=8,
                      kernel_size=(1, 1),
                      dilation=(1, 1)), nn.BatchNorm2d(8),
            nn.ReLU(inplace=False))

        self.LSTM1 = nn.LSTM(257 * 8, 400, num_layers=1, batch_first=True)
        for name, param in self.LSTM1.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)

        self.FC3 = nn.Sequential(nn.Dropout(0.5, False), nn.Linear(400, 600),
                                 nn.ReLU())
        for name, param in self.FC3.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)

        self.FC4 = nn.Sequential(nn.Dropout(0.5, False), nn.Linear(600, 257))
        for name, param in self.FC4.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)
        # self.loss1 = nn.L1Loss()
        self.stft = STFT(filter_length=hp.data.nfft,
                         hop_length=hop_length,
                         win_length=window_length)
Example #12
def train():
    device_ids = [8, 9, 10, 11, 12, 13, 14, 15]
    lr_factor = 6
    iteration = 0
    # writer = SummaryWriter('Gan_Loss_Log')
    train_dataset = dataset_preprocess_FTGAN()
    # num_loader < 2*12
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.batch_gan * len(device_ids),
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)
    G = VoiceFilter_SN()
    D = discriminator_SN()
    # G.eval()
    # D.eval()
    # G.load_state_dict(torch.load("/data2/lps/model_gan/G/epoch_4.pth"))
    # D.load_state_dict(torch.load("/data2/lps/model_gan/D/epoch_4.pth"))
    # save_model_G = torch.load("/data2/lps/model_gan/G/multi_epoch_6.pth")
    # save_model_D = torch.load("/data2/lps/model_gan/D/multi_epoch_6.pth")
    # model_dict_G = G.state_dict()
    # model_dict_D = D.state_dict()
    # state_dict_G = {k:v for k,v in save_model_G.items() if k in model_dict_G.keys()}
    # state_dict_D = {k:v for k,v in save_model_D.items() if k in model_dict_D.keys()}
    # model_dict_G.update(state_dict_G)
    # model_dict_D.update(state_dict_D)
    # G.load_state_dict(model_dict_G)
    # D.load_state_dict(model_dict_D)
    stft = STFT(filter_length=hp.data.nfft,
                hop_length=hop_length,
                win_length=window_length).cuda(device_ids[0])
    G = G.cuda(device_ids[0])
    D = D.cuda(device_ids[0])
    G = torch.nn.DataParallel(G, device_ids=device_ids)
    D = torch.nn.DataParallel(D, device_ids=device_ids)

    BCE_loss_fn = nn.BCELoss()
    L1loss = nn.L1Loss()
    # optimizer = torch.optim.SGD(voicefilter_net.parameters(), lr=hp.train.lr * lr_factor, momentum=hp.train.momentum)
    optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0001 * lr_factor)
    optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0004 * lr_factor)
    G.train()
    D.train()

    real_vector = torch.ones(hp.train.batch_gan * len(device_ids),
                             1).cuda(device_ids[0])
    fake_vector = torch.zeros(hp.train.batch_gan * len(device_ids),
                              1).cuda(device_ids[0])
    # a = torch.Tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
    # b = a.transpose(1, 2)
    # c = a.permute(0, 2, 1)
    # d = a.view(2, 6)
    # print('a:', a)
    # print('b:', b)
    # print('c:', c)
    # print('d:', d)  # view() reshapes by filling the new shape in element order
    for e in range(1000):
        # print('e:', e)
        total_loss_L1 = 0
        total_loss_G = 0
        total_loss_D = 0
        total_loss_G_all = 0
        # (10, 257, 101), (10, 257, 101), (10, 16000), (10, 16000), (10, 257, 101)
        for num, (clean_tf, noisy_tf, clean_wav, noisy_wav,
                  noisy_phase) in enumerate(train_loader):
            # print('num:', num)
            # print('noisy_phase.size:', noisy_phase.shape)
            clean_wav_gpu = clean_wav.unsqueeze(1).cuda(
                device_ids[0])  # (10, 1, 16000)
            noisy_wav_gpu = noisy_wav.unsqueeze(1).cuda(
                device_ids[0])  # (10, 1, 16000)
            clean_tf_gpu = clean_tf.transpose(1, 2).cuda(
                device_ids[0])  # (10, 101, 257)
            noisy_tf_gpu = noisy_tf.transpose(1, 2).unsqueeze(1).cuda(
                device_ids[0])  # (10, 1, 101, 257)
            noisy_phase_gpu = noisy_phase.cuda(device_ids[0])  # (10, 257, 101)
            # print('nosiy_phase_gpu.size:', noisy_phase_gpu.shape)
            # print('clean_wav_gpu.size:', clean_wav_gpu.shape)
            # print('noisy_wav_gpu.size:', noisy_wav_gpu.shape)
            # -----------------------------------------------------
            optimizer_G.zero_grad()
            G_mask = G(noisy_tf_gpu)
            noisy_tf_gpu = torch.squeeze(noisy_tf_gpu)
            G_tf = G_mask * noisy_tf_gpu
            loss_G_1 = L1loss(G_tf, clean_tf_gpu)
            G_tf_1 = G_tf.transpose(1, 2)
            # print('G_tf_1.shape:', G_tf_1.shape, 'noisy_phase_gpu.shape:', noisy_phase_gpu.shape)
            G_wav_gpu = stft.inverse(G_tf_1, noisy_phase_gpu)
            G_wav_gpu = G_wav_gpu.unsqueeze(1)
            G_wav_D = torch.cat((G_wav_gpu, noisy_wav_gpu), dim=1)
            D_clean = torch.cat((clean_wav_gpu, noisy_wav_gpu), dim=1)
            in_1 = D(G_wav_D)
            in_2 = D(D_clean)
            in_3 = torch.sigmoid(in_1 - in_2)  # relativistic logit: D(fake) - D(real)
            loss_G_2 = BCE_loss_fn(in_3, real_vector)
            loss_G = 100 * loss_G_1 + loss_G_2
            loss_G.backward()
            optimizer_G.step()
            #    Train Discriminator

            optimizer_D.zero_grad()
            G_wav_gpu_1 = G_wav_gpu.detach()
            G_wav_D_1 = torch.cat((G_wav_gpu_1, noisy_wav_gpu), dim=1)
            in_4 = D(G_wav_D_1)
            in_5 = D(D_clean)
            in_6 = torch.sigmoid(in_5 - in_4)  # D(real) - D(fake)
            in_7 = torch.sigmoid(in_4 - in_5)  # D(fake) - D(real)
            # loss_D = D(D_clean, G_wav_D_1, real_vector)
            loss_D_real = BCE_loss_fn(in_6, real_vector)
            loss_D_fake = BCE_loss_fn(in_7, fake_vector)
            loss_D = (loss_D_real + loss_D_fake) / 2
            loss_D.backward()
            optimizer_D.step()
            # print('loss_G:', loss_G, 'loss_D:', loss_D)
            # print('finish the first')
            # -------------------------------------------------
            iteration = iteration + 1
            # accumulate as Python floats so autograd graphs are not retained
            total_loss_G = total_loss_G + loss_G_2.item()
            total_loss_L1 = total_loss_L1 + loss_G_1.item()
            total_loss_D = total_loss_D + loss_D.item()
            total_loss_G_all = total_loss_G_all + loss_G.item()
            # if (iteration % 10 == 0):
            #     niter = hp.train.batch_gan*len(device_ids)*len(train_loader) + iteration
            #     writer.add_scalar('Train/G_1', loss_G_1, niter)
            #     writer.add_scalar('Train/G_2', loss_G_2, niter)
            #     writer.add_scalar('Train/G', loss_G, niter)
            #     writer.add_scalar('Train/D', loss_D, niter)
            if (iteration % 50 == 0):
                batches = num + 1  # enumerate() is 0-based
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss_L1:{5:.5f}\tLoss_G:{6:.5f}\tLoss_D:{7:.5f}\tLoss_G_all:{8:.5f}\tL1_loss_ave:{9:.6f}\tG_loss_ave:{10:.6f}\tD_loss_ave:{11:.6f}\tG_all_loss_ave:{12:.6f}\t\n".format(
                    time.ctime(), e + 1, num,
                    len(train_dataset) //
                    (hp.train.batch_gan * len(device_ids)), iteration,
                    loss_G_1, loss_G_2, loss_D, loss_G,
                    total_loss_L1 / batches, total_loss_G / batches,
                    total_loss_D / batches, total_loss_G_all / batches)
                print(mesg)

        mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss_L1:{5:.5f}\tLoss_G:{6:.5f}\tLoss_D:{7:.5f}\tLoss_G_all:{8:.5f}\tL1_loss_ave:{9:.6f}\tG_loss_ave:{10:.6f}\tD_loss_ave:{11:.6f}\tG_all_loss_ave:{12:.6f}\t\n".format(
            time.ctime(), e + 1, num,
            len(train_dataset) // (hp.train.batch_gan * len(device_ids)),
            iteration, loss_G_1, loss_G_2, loss_D, loss_G, total_loss_L1 / num,
            total_loss_G / num, total_loss_D / num, total_loss_G_all / num)
        print(mesg)
        if (e + 1) % 1 == 0:
            model_mid_name = 'multi_epoch_' + str(e + 1) + '.pth'
            G.eval()
            G.cpu()
            model_mid_path_G = os.path.join('/workspace/model/rgan_SN_2/G',
                                            model_mid_name)
            torch.save(G.state_dict(), model_mid_path_G)
            D.eval()
            D.cpu()
            model_mid_path_D = os.path.join('/workspace/model/rgan_SN_2/D',
                                            model_mid_name)
            torch.save(D.state_dict(), model_mid_path_D)
            G.cuda(device_ids[0]).train()
            D.cuda(device_ids[0]).train()