Example #1
import time

import librosa

# Project-local helpers assumed available: get_spec, separate_magnitude_phase,
# combine_magnitdue_phase, R_pca, bss_eval, and the score accumulator scorekeepr.
def train_rpca(dataloader):
    start_time = time.time()
    for batch_idx, (mixed, s1, s2, lengths) in enumerate(dataloader):
        for i in range(len(mixed)):
            mixed_spec = get_spec(mixed[i])
            mixed_mag, mixed_phase = separate_magnitude_phase(mixed_spec)
            rpca = R_pca(mixed_mag)
            X_music, X_sing = rpca.fit()
            # X_sing, X_music = time_freq_masking(mixed_spec, X_music, X_sing)

            # reconstruct wav
            pred_music_wav = librosa.istft(
                combine_magnitdue_phase(X_music, mixed_phase))
            pred_sing_wav = librosa.istft(
                combine_magnitdue_phase(X_sing, mixed_phase))

            nsdr, sir, sar, lens = bss_eval(mixed[i], s1[i], s2[i],
                                            pred_music_wav, pred_sing_wav)
            scorekeepr.update(nsdr, sir, sar, lens)
        scorekeepr.print_score()

        print("time elasped", time.time() - start_time)
        print("{} / {}".format(batch_idx, len(dataloader)))
Example #2
import os

# Project-local helpers assumed available: load_file, load_wavs,
# wavs_to_specs, separate_magnitude_phase, get_next_batch, and SVMRNN.
def main(args):

    # First check that the dataset directories exist
    if not os.path.exists(args.dataset_train_dir) or not os.path.exists(
            args.dataset_validate_dir):
        raise FileNotFoundError(
            'Dataset path "./dataset/MIR-1K/Wavfile" or "./dataset/MIR-1K/UndividedWavfile" does not exist!'
        )

    # 1. Collect the dataset file paths into lists
    train_file_list = load_file(args.dataset_train_dir)
    valid_file_list = load_file(args.dataset_validate_dir)

    # Dataset sample rate
    mir1k_sr = args.dataset_sr
    # STFT window size
    n_fft = 1024
    # Hop length; the frame shift, analogous to stride in a convolution
    hop_length = n_fft // 4

    # Model parameters
    # Learning rate
    learning_rate = args.learning_rate

    # Number of hidden units per RNN layer
    num_hidden_units = [1024, 1024, 1024, 1024, 1024]
    # Batch size
    batch_size = args.batch_size
    # Number of frames sampled per example
    sample_frames = args.sample_frames
    # Number of training iterations
    iterations = args.iterations
    # dropout
    dropout_rate = args.dropout_rate

    # Model save path
    model_dir = args.model_dir
    model_filename = args.model_filename

    # Load the training wav data:
    # wavs_mono_train holds the mono mixtures, wavs_music_train the accompaniment, wavs_voice_train the clean vocals
    wavs_mono_train, wavs_music_train, wavs_voice_train = load_wavs(
        filenames=train_file_list, sr=mir1k_sr)
    # Convert the audio to the frequency domain via the short-time Fourier transform
    stfts_mono_train, stfts_music_train, stfts_voice_train = wavs_to_specs(
        wavs_mono=wavs_mono_train,
        wavs_music=wavs_music_train,
        wavs_voice=wavs_voice_train,
        n_fft=n_fft,
        hop_length=hop_length)

    # Same as above, but for the validation data
    wavs_mono_valid, wavs_music_valid, wavs_voice_valid = load_wavs(
        filenames=valid_file_list, sr=mir1k_sr)
    stfts_mono_valid, stfts_music_valid, stfts_voice_valid = wavs_to_specs(
        wavs_mono=wavs_mono_valid,
        wavs_music=wavs_music_valid,
        wavs_voice=wavs_voice_valid,
        n_fft=n_fft,
        hop_length=hop_length)

    # Initialize the model
    model = SVMRNN(num_features=n_fft // 2 + 1,
                   num_hidden_units=num_hidden_units)

    # Load a saved model; if none exists, initialize all variables
    startepo = model.load(file_dir=model_dir)

    print('startepo:' + str(startepo))

    # Start training
    for i in range(iterations):
        # Resume from where a previous run was interrupted
        if i < startepo:
            continue

        # Fetch the next batch of data
        data_mono_batch, data_music_batch, data_voice_batch = get_next_batch(
            stfts_mono=stfts_mono_train,
            stfts_music=stfts_music_train,
            stfts_voice=stfts_voice_train,
            batch_size=batch_size,
            sample_frames=sample_frames)

        # Extract the magnitude spectra
        x_mixed_src, _ = separate_magnitude_phase(data=data_mono_batch)
        y_music_src, _ = separate_magnitude_phase(data=data_music_batch)
        y_voice_src, _ = separate_magnitude_phase(data=data_voice_batch)

        # Feed into the network and train
        train_loss = model.train(x_mixed_src=x_mixed_src,
                                 y_music_src=y_music_src,
                                 y_voice_src=y_voice_src,
                                 learning_rate=learning_rate,
                                 dropout_rate=dropout_rate)

        if i % 10 == 0:
            print('Step: %d Train Loss: %f' % (i, train_loss))

        if i % 200 == 0:
            # Evaluate the model on the validation set
            print('==============================================')
            data_mono_batch, data_music_batch, data_voice_batch = get_next_batch(
                stfts_mono=stfts_mono_valid,
                stfts_music=stfts_music_valid,
                stfts_voice=stfts_voice_valid,
                batch_size=batch_size,
                sample_frames=sample_frames)

            x_mixed_src, _ = separate_magnitude_phase(data=data_mono_batch)
            y_music_src, _ = separate_magnitude_phase(data=data_music_batch)
            y_voice_src, _ = separate_magnitude_phase(data=data_voice_batch)

            y_music_src_pred, y_voice_src_pred, validate_loss = model.validate(
                x_mixed_src=x_mixed_src,
                y_music_src=y_music_src,
                y_voice_src=y_voice_src,
                dropout_rate=dropout_rate)
            print('Step: %d Validation Loss: %f' % (i, validate_loss))
            print('==============================================')

        if i % 200 == 0:
            model.save(directory=model_dir,
                       filename=model_filename,
                       global_step=i)
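
`get_next_batch` is assumed to sample fixed-length windows of consecutive STFT frames from randomly chosen training clips. A plausible sketch follows (the sampling strategy is an assumption, not necessarily the repo's implementation; clips shorter than `sample_frames` are assumed not to occur):

import numpy as np

def get_next_batch(stfts_mono, stfts_music, stfts_voice,
                   batch_size=64, sample_frames=8):
    # Sample `batch_size` windows of `sample_frames` consecutive frames,
    # taking the same window from the mixture, music, and vocal STFTs.
    mono, music, voice = [], [], []
    for _ in range(batch_size):
        idx = np.random.randint(len(stfts_mono))
        num_frames = stfts_mono[idx].shape[0]  # STFTs stored time-major
        start = np.random.randint(num_frames - sample_frames + 1)
        mono.append(stfts_mono[idx][start:start + sample_frames])
        music.append(stfts_music[idx][start:start + sample_frames])
        voice.append(stfts_voice[idx][start:start + sample_frames])
    return np.array(mono), np.array(music), np.array(voice)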
Example #3
import json
import os

import librosa
import numpy as np
import torch

# Assumed module-level device, shared with the other PyTorch examples
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Project-local helpers assumed available: load_wavs, wavs_to_specs,
# prepare_data_full, separate_magnitude_phase, combine_magnitdue_phase,
# bss_eval_global, and EnsembleModel.
# NOTE: `eval` shadows the Python builtin; the name is kept as in the original.
def eval(args):
    mir1k_sr = 16000
    n_fft = 1024
    hop_length = n_fft // 4
    num_rnn_layer = 3
    num_hidden_units = args['hidden_size']
    checkpoint = torch.load("model_10000.pth")

    mir1k_dir = 'data/MIR1K/MIR-1K'
    test_path = os.path.join(mir1k_dir, 'test_temp.json')
    # test_path = os.path.join(mir1k_dir, 'MIR-1K_test.json')

    with open(test_path, 'r') as text_file:
        content = json.load(text_file)
        # content = text_file.readlines()
    # wav_filenames = [file.strip() for file in content] 
    wav_filenames = ["{}/{}".format("data/MIR1K/MIR-1K/Wavfile", f) for f in content]
    print(len(wav_filenames))
    split_size = int(len(wav_filenames)/5.)
    model = EnsembleModel(n_fft // 2, 512).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    wavs_src1_pred = list()
    wavs_src2_pred = list()
    model.eval()
    step = 1
    for i in range(5):
        start = i*split_size
        wavs_mono, wavs_src1, wavs_src2 = load_wavs(filenames = wav_filenames[start:start+split_size], sr = mir1k_sr)

        stfts_mono, stfts_src1, stfts_src2 = wavs_to_specs(
            wavs_mono = wavs_mono, wavs_src1 = wavs_src1, wavs_src2 = wavs_src2, n_fft = n_fft, hop_length = hop_length)
        stfts_mono_full, stfts_src1_full, stfts_src2_full = prepare_data_full(stfts_mono = stfts_mono, stfts_src1 = stfts_src1, stfts_src2 = stfts_src2)
        # print(len(stfts_mono_full))
        with torch.no_grad():
            for wav_filename, wav_mono, stft_mono_full in zip(wav_filenames, wavs_mono, stfts_mono_full):
                # print(stft_mono_full.shape)
                stft_mono_magnitude, stft_mono_phase = separate_magnitude_phase(data = stft_mono_full)
                max_length_even = stft_mono_magnitude.shape[0]-1 if (stft_mono_magnitude.shape[0]%2 != 0) else stft_mono_magnitude.shape[0]
                stft_mono_magnitude = np.array([stft_mono_magnitude[:max_length_even,:512]])
                # print(stft_mono_magnitude.shape)
                stft_mono_magnitude = torch.Tensor(stft_mono_magnitude).to(device)

                orig_length = max_length_even
                # remainder = np.floor(orig_length / 64)
                # print(64*reminder)
                startIdx = 0
                y1_pred_list = np.zeros((orig_length, 512), dtype=np.float32)  # (frames, 512)
                y2_pred_list = np.zeros((orig_length, 512), dtype=np.float32)
                while startIdx+64 < orig_length:
                    y1_pred, y2_pred = model(stft_mono_magnitude[:, startIdx: startIdx+64, :])

                    # ISTFT below will reuse the phase from the mono mixture
                    y1_pred = y1_pred.cpu().numpy()
                    y2_pred = y2_pred.cpu().numpy()
                    y1_pred_list[startIdx: startIdx+64, :] = y1_pred[0]
                    y2_pred_list[startIdx: startIdx+64, :] = y2_pred[0]

                    startIdx += 64
                # calculate the frames left over outside the 64-frame blocks
                # y1_pred, y2_pred = model(stft_mono_magnitude[:, startIdx: orig_length, :])

                # y1_pred = y1_pred.cpu().numpy()
                # y2_pred = y2_pred.cpu().numpy()
                # y1_pred_list[startIdx: orig_length, :] = y1_pred[0]
                # y2_pred_list[startIdx: orig_length, :] = y2_pred[0]


                y1_stft_hat = combine_magnitdue_phase(magnitudes = y1_pred_list[:(startIdx),:], phases = stft_mono_phase[:(startIdx), :512])
                y2_stft_hat = combine_magnitdue_phase(magnitudes = y2_pred_list[:(startIdx),:], phases = stft_mono_phase[:(startIdx), :512])

                y1_stft_hat = y1_stft_hat.transpose()
                y2_stft_hat = y2_stft_hat.transpose()

                y1_hat = librosa.istft(y1_stft_hat, hop_length = hop_length)
                y2_hat = librosa.istft(y2_stft_hat, hop_length = hop_length)


                wavs_src1_pred.append(y1_hat)
                wavs_src2_pred.append(y2_hat)
                print("{}/{}\n".format(step, len(wav_filenames)))
                step += 1
    wavs_mono, wavs_src1, wavs_src2 = load_wavs(filenames = wav_filenames, sr = mir1k_sr)
    gnsdr, gsir, gsar = bss_eval_global(wavs_mono = wavs_mono, wavs_src1 = wavs_src1, wavs_src2 = wavs_src2, wavs_src1_pred = wavs_src1_pred, wavs_src2_pred = wavs_src2_pred)

    print('GNSDR:', gnsdr)
    print('GSIR:', gsir)
    print('GSAR:', gsar)
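
`bss_eval_global` is not shown. A plausible sketch using mir_eval, computing length-weighted global NSDR/SIR/SAR (the exact weighting in the repo may differ; NSDR is measured against using the raw mixture as the estimate):

import numpy as np
from mir_eval.separation import bss_eval_sources

def bss_eval_global(wavs_mono, wavs_src1, wavs_src2,
                    wavs_src1_pred, wavs_src2_pred):
    gnsdr, gsir, gsar = np.zeros(2), np.zeros(2), np.zeros(2)
    total_len = 0
    for mono, src1, src2, pred1, pred2 in zip(
            wavs_mono, wavs_src1, wavs_src2, wavs_src1_pred, wavs_src2_pred):
        length = min(len(mono), len(src1), len(src2), len(pred1), len(pred2))
        refs = np.array([src1[:length], src2[:length]])
        preds = np.array([pred1[:length], pred2[:length]])
        sdr, sir, sar, _ = bss_eval_sources(refs, preds,
                                            compute_permutation=False)
        # Baseline SDR: score the unprocessed mixture as both estimates
        sdr_mixed, _, _, _ = bss_eval_sources(
            refs, np.array([mono[:length], mono[:length]]),
            compute_permutation=False)
        gnsdr = gnsdr + length * (sdr - sdr_mixed)  # normalized SDR gain
        gsir = gsir + length * sir
        gsar = gsar + length * sar
        total_len += length
    return gnsdr / total_len, gsir / total_len, gsar / total_len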
Example #4
import os

import librosa
import numpy as np

# Project-local helpers assumed available: SVMRNN, separate_magnitude_phase,
# and combine_magnitude_phase.
def main(args):
    input_dir = args.input_dir
    output_dir = args.output_dir
    dataset_sr = args.dataset_sr
    model_dir = args.model_dir
    dropout_rate = args.dropout_rate

    # If the input directory does not exist, raise an error
    if not os.path.exists(input_dir):
        raise FileNotFoundError('Input audio folder "./songs/input" does not exist!')

    # Create the output directory if it does not exist
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # Find the audio files whose accompaniment and vocals are to be separated
    song_filenames = list()
    for file in os.listdir(input_dir):
        if file.endswith('.mp3'):
            song_filenames.append(os.path.join(input_dir, file))

    # Load the input audio files
    wavs_mono = list()
    for filename in song_filenames:
        wav_mono, _ = librosa.load(filename, sr=dataset_sr, mono=True)
        wavs_mono.append(wav_mono)

    # STFT window size
    n_fft = 1024
    # Hop length; the frame shift, analogous to stride in a convolution
    hop_length = n_fft // 4
    # Number of hidden units per RNN layer
    num_hidden_units = [1024, 1024, 1024, 1024, 1024]

    # Convert to the frequency domain
    stfts_mono = list()
    for wav_mono in wavs_mono:
        stft_mono = librosa.stft(wav_mono, n_fft = n_fft, hop_length = hop_length)
        stfts_mono.append(stft_mono.transpose())

    # Initialize the network
    model = SVMRNN(num_features = n_fft // 2 + 1, num_hidden_units = num_hidden_units)
    # Load the model
    model.load(file_dir = model_dir)

    for wav_filename, wav_mono, stft_mono in zip(song_filenames, wavs_mono, stfts_mono):
        wav_filename_base = os.path.basename(wav_filename)
        # Mono audio file
        wav_mono_filename = wav_filename_base.split('.')[0] + '_mono.wav'
        # Separated accompaniment audio file
        wav_music_filename = wav_filename_base.split('.')[0] + '_music.wav'
        # Separated vocals audio file
        wav_voice_filename = wav_filename_base.split('.')[0] + '_voice.wav'

        # Relative paths for the files to be saved
        wav_mono_filepath = os.path.join(output_dir, wav_mono_filename)
        wav_music_hat_filepath = os.path.join(output_dir, wav_music_filename)
        wav_voice_hat_filepath = os.path.join(output_dir, wav_voice_filename)

        print('Processing %s ...' % wav_filename_base)

        stft_mono_magnitude, stft_mono_phase = separate_magnitude_phase(data = stft_mono)
        stft_mono_magnitude = np.array([stft_mono_magnitude])

        y_music_pred, y_voice_pred = model.test(x_mixed_src = stft_mono_magnitude, dropout_rate = dropout_rate)

        # Combine magnitude and phase into complex values for the inverse STFT below
        y_music_stft_hat = combine_magnitude_phase(magnitudes = y_music_pred[0], phases = stft_mono_phase)
        y_voice_stft_hat = combine_magnitude_phase(magnitudes = y_voice_pred[0], phases = stft_mono_phase)

        y_music_stft_hat = y_music_stft_hat.transpose()
        y_voice_stft_hat = y_voice_stft_hat.transpose()

        # Inverse STFT: back from the frequency domain to the time domain
        y_music_hat = librosa.istft(y_music_stft_hat, hop_length = hop_length)
        y_voice_hat = librosa.istft(y_voice_stft_hat, hop_length = hop_length)

        # Save the results
        librosa.output.write_wav(wav_mono_filepath, wav_mono, dataset_sr)
        librosa.output.write_wav(wav_music_hat_filepath, y_music_hat, dataset_sr)
        librosa.output.write_wav(wav_voice_hat_filepath, y_voice_hat, dataset_sr)
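
Note that `librosa.output.write_wav` was removed in librosa 0.8, so on current librosa versions the three saves above fail. An equivalent using the soundfile package:

import soundfile as sf

# Drop-in replacement for librosa.output.write_wav on librosa >= 0.8
sf.write(wav_mono_filepath, wav_mono, dataset_sr)
sf.write(wav_music_hat_filepath, y_music_hat, dataset_sr)
sf.write(wav_voice_hat_filepath, y_voice_hat, dataset_sr)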
Example #5
import os

import librosa
import numpy as np

# Project-local helpers assumed available: SVMRNN, separate_magnitude_phase,
# and combine_magnitude_phase.
def main(args):
    input_dir = args.input_dir
    output_dir = args.output_dir
    dataset_sr = args.dataset_sr
    model_dir = args.model_dir
    dropout_rate = args.dropout_rate

    if not os.path.exists(input_dir):
        raise FileNotFoundError('Input audio folder "./songs/input" does not exist!')

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    song_filenames = list()
    file_route1 = 'voice'
    file_route2 = 'music'
    for file in os.listdir(input_dir):
        if file.endswith('.mp3'):
            song_filenames.append(os.path.join(input_dir, file))

    wavs_mono = list()
    for filename in song_filenames:
        print('wf', filename)
        wav_mono, _ = librosa.load(filename, sr=dataset_sr, mono=True)
        wavs_mono.append(wav_mono)

    n_fft = 1024
    hop_length = n_fft // 4
    num_hidden_units = [1024, 1024, 1024, 1024, 1024]

    stfts_mono = list()
    for wav_mono in wavs_mono:
        stft_mono = librosa.stft(wav_mono, n_fft=n_fft, hop_length=hop_length)
        stfts_mono.append(stft_mono.transpose())

    model = SVMRNN(num_features=n_fft // 2 + 1,
                   num_hidden_units=num_hidden_units)
    model.load(file_dir=model_dir)

    for wav_filename, wav_mono, stft_mono in zip(song_filenames, wavs_mono,
                                                 stfts_mono):
        wav_filename_base = os.path.basename(wav_filename)
        wav_mono_filename = 'mono.wav'
        # Separated accompaniment audio file
        wav_music_filename = 'music.wav'
        # Separated vocals audio file
        wav_voice_filename = 'voice.wav'

        # Relative paths for the files to be saved
        wav_mono_filepath = os.path.join(output_dir, wav_mono_filename)
        wav_music_hat_filepath = os.path.join(output_dir, wav_music_filename)
        wav_voice_hat_filepath = os.path.join(output_dir, wav_voice_filename)

        print('Processing %s ...' % wav_filename_base)

        stft_mono_magnitude, stft_mono_phase = separate_magnitude_phase(
            data=stft_mono)
        stft_mono_magnitude = np.array([stft_mono_magnitude])

        y_music_pred, y_voice_pred = model.test(
            x_mixed_src=stft_mono_magnitude, dropout_rate=dropout_rate)

        y_music_stft_hat = combine_magnitude_phase(magnitudes=y_music_pred[0],
                                                   phases=stft_mono_phase)
        y_voice_stft_hat = combine_magnitude_phase(magnitudes=y_voice_pred[0],
                                                   phases=stft_mono_phase)

        y_music_stft_hat = y_music_stft_hat.transpose()
        y_voice_stft_hat = y_voice_stft_hat.transpose()

        y_music_hat = librosa.istft(y_music_stft_hat, hop_length=hop_length)
        y_voice_hat = librosa.istft(y_voice_stft_hat, hop_length=hop_length)

        librosa.output.write_wav(wav_mono_filepath, wav_mono, dataset_sr)
        librosa.output.write_wav(wav_music_hat_filepath, y_music_hat,
                                 dataset_sr)
        librosa.output.write_wav(wav_voice_hat_filepath, y_voice_hat,
                                 dataset_sr)
    # NOTE: `music_flie` (sic) is undefined in the original snippet; it is
    # assumed here to name one of the input songs.
    music_flie = song_filenames[0]
    y, sr = librosa.load(music_flie)
    S = np.abs(librosa.stft(y))
    print(librosa.power_to_db(S**2))
    if music_flie == 'songs/input/bbb.wav':
        print(file_route1)
    else:
        print(file_route2)
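
Each `main(args)` above reads `input_dir`, `output_dir`, `dataset_sr`, `model_dir`, and `dropout_rate` from its argument object. A minimal sketch of the argparse wiring these entry points assume (the defaults are illustrative, not the repo's):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Separate accompaniment and vocals from mixed songs.')
    parser.add_argument('--input_dir', type=str, default='./songs/input')
    parser.add_argument('--output_dir', type=str, default='./songs/output')
    parser.add_argument('--dataset_sr', type=int, default=16000)
    parser.add_argument('--model_dir', type=str, default='./model')
    parser.add_argument('--dropout_rate', type=float, default=0.95)
    main(parser.parse_args())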
Example #6
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Assumed module-level device, shared with the other PyTorch examples
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Project-local helpers assumed available: load_wavs, wavs_to_specs,
# sample_data_batch, separate_magnitude_phase, get_specs_transpose,
# adjust_learning_rate, and BaselineModelTemp.
def train_rnn(args):

    mir1k_dir = 'data/MIR1K/MIR-1K'
    # train_path = os.path.join(mir1k_dir, 'MIR-1K_train.json')
    # valid_path = os.path.join(mir1k_dir, 'MIR-1K_val.json')

    train_path = os.path.join(mir1k_dir, 'train_temp.json')
    valid_path = os.path.join(mir1k_dir, 'val_temp.json')

    with open(train_path, 'r') as f:
        content = json.load(f)
    wav_filenames_train = np.array(
        ["{}/{}".format("data/MIR1K/MIR-1K/Wavfile", f) for f in content])

    with open(valid_path, 'r') as text_file:
        content = json.load(text_file)
    wav_filenames_valid = np.array(
        ["{}/{}".format("data/MIR1K/MIR-1K/Wavfile", f) for f in content])

    # Preprocess parameters
    mir1k_sr = 16000
    n_fft = 1024
    hop_length = n_fft // 4
    # Model parameters
    learning_rate = args['learning_rate']
    num_rnn_layer = args['num_layers']
    num_hidden_units = args['hidden_size']
    dropout = args['dropout']
    sample_frames = args['sample_frames']
    save_dirs = "checkpoint"
    batch_size = 64
    iterations = 100000

    # trial_id = nni.get_sequence_id()
    # if not os.path.isdir('checkpoint'):
    #     os.mkdir('checkpoint')
    # save_dir = 'checkpoint/trial'+str(trial_id) + "/"
    # os.makedirs(save_dir)
    # # store param in each trail (for testing)
    # with open(save_dir+"params.json", "w") as f:
    #     json.dump(args, f)
    save_dir = "./"
    # train_log_filename = save_dir + 'train_log_temp.csv'
    train_log_filename = 'train_log_temp.csv'

    # Load the training wavs
    # Convert the waveforms to spectrograms
    random_wavs = np.random.choice(len(wav_filenames_train),
                                   len(wav_filenames_train),
                                   replace=False)
    wavs_mono_train, wavs_src1_train, wavs_src2_train = load_wavs(
        filenames=wav_filenames_train[random_wavs], sr=mir1k_sr)

    stfts_mono_train, stfts_src1_train, stfts_src2_train = wavs_to_specs(
        wavs_mono=wavs_mono_train,
        wavs_src1=wavs_src1_train,
        wavs_src2=wavs_src2_train,
        n_fft=n_fft,
        hop_length=hop_length)
    # Initialize model

    model = BaselineModelTemp(n_fft // 2, 512, dropout).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 1645314
    loss_fn = nn.MSELoss().to(device)

    start_time = time.time()
    total_loss = 0.
    train_step = 1.
    best_val_loss = np.inf
    stop = 0
    for i in range(iterations):
        model.train()
        data_mono_batch, data_src1_batch, data_src2_batch = sample_data_batch(
            stfts_mono=stfts_mono_train,
            stfts_src1=stfts_src1_train,
            stfts_src2=stfts_src2_train,
            batch_size=batch_size,
            sample_frames=sample_frames)

        x_mixed, _ = separate_magnitude_phase(data=data_mono_batch)
        y1, _ = separate_magnitude_phase(data=data_src1_batch)
        y2, _ = separate_magnitude_phase(data=data_src2_batch)

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i, learning_rate)

        max_length_even = x_mixed.shape[2] - 1 if (
            x_mixed.shape[2] % 2 != 0) else x_mixed.shape[2]

        x_mixed = torch.Tensor(x_mixed[:, :, :max_length_even]).to(device)
        y1 = torch.Tensor(y1[:, :, :max_length_even]).to(device)
        y2 = torch.Tensor(y2[:, :, :max_length_even]).to(device)
        pred_s1, pred_s2 = model(x_mixed)

        # cf. ((y1 - pred_s1)**2 + (y2 - pred_s2)**2).sum() / y1.data.nelement()
        loss = loss_fn(torch.cat((pred_s1, pred_s2), 1), torch.cat((y1, y2), 1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        print("iteration: ", i, ", loss: ", total_loss / train_step)
        print("time elasped", time.time() - start_time)

        train_step += 1
        # validate and save progress
        if i % 2000 == 0:
            # reset
            wavs_mono_train, wavs_src1_train, wavs_src2_train = None, None, None
            stfts_mono_train, stfts_src1_train, stfts_src2_train = None, None, None
            # start validation
            wavs_mono_valid, wavs_src1_valid, wavs_src2_valid = load_wavs(
                filenames=wav_filenames_valid, sr=mir1k_sr)

            mixed_stft, s1_stft, s2_stft = get_specs_transpose(
                wavs_mono_valid, wavs_src1_valid, wavs_src2_valid)
            val_len = len(mixed_stft)
            model.eval()
            val_losses = 0.
            with torch.no_grad():
                for j, (mix_spec, s1_spec,
                        s2_spec) in enumerate(zip(mixed_stft, s1_stft,
                                                  s2_stft)):
                    x_mixed, _ = separate_magnitude_phase(data=mix_spec)
                    y1, _ = separate_magnitude_phase(data=s1_spec)
                    y2, _ = separate_magnitude_phase(data=s2_spec)
                    length = x_mixed.shape[0] - x_mixed.shape[0] % 2
                    # print(length)
                    x_mixed = torch.Tensor(
                        x_mixed[:length, :512]).unsqueeze(0).to(device)
                    y1 = torch.Tensor(
                        y1[:length, :512]).unsqueeze(0).to(device)
                    y2 = torch.Tensor(
                        y2[:length, :512]).unsqueeze(0).to(device)

                    pred_s1, pred_s2 = model(x_mixed)
                    loss = ((y1 - pred_s1)**2 +
                            (y2 - pred_s2)**2).sum() / y1.data.nelement()
                    val_losses += loss.cpu().numpy()
            # nni.report_intermediate_result(val_losses/val_len)
            print("{}, {}, {}\n".format(i, total_loss / train_step,
                                        val_losses / len(mixed_stft)))

            with open(train_log_filename, "a") as f:
                f.write("{}, {}, {}\n".format(i, total_loss / train_step,
                                              val_losses / len(mixed_stft)))
            current_val_loss = val_losses / len(mixed_stft)
            if current_val_loss < best_val_loss:
                best_val_loss = current_val_loss
                stop = 0
            elif current_val_loss > best_val_loss:
                stop += 1
            if stop >= 2 and i >= 10000:
                break
            # if i % 10000==0:
            torch.save(
                {
                    'epoch': i,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, save_dir + "model_" + str(i) + ".pth")
            wavs_mono_valid, wavs_src1_valid, wavs_src2_valid = None, None, None
            mixed_stft, s1_stft, s2_stft = None, None, None
            random_wavs = np.random.choice(len(wav_filenames_train),
                                           len(wav_filenames_train),
                                           replace=False)
            wavs_mono_train, wavs_src1_train, wavs_src2_train = load_wavs(
                filenames=wav_filenames_train[random_wavs], sr=mir1k_sr)

            stfts_mono_train, stfts_src1_train, stfts_src2_train = wavs_to_specs(
                wavs_mono=wavs_mono_train,
                wavs_src1=wavs_src1_train,
                wavs_src2=wavs_src2_train,
                n_fft=n_fft,
                hop_length=hop_length)
            train_step = 1.
            total_loss = 0.
    # nni.report_final_result(val_losses/val_len)

    torch.save(
        {
            'epoch': i,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, save_dir + "final_model.pth")