def __init__(
        self,
        frame_length=3,
        sample_rate=44100,
        num_worker=1,
        MUSDB18_PATH="",
        BIG_DATA=False,
        additional_background_data=None,
        additional_vocal_data=None,
    ):
        """Initialize the dataset wrapper around MUSDB18 plus extra data.

        Args:
            frame_length: Clip length in seconds used downstream.
            sample_rate: Audio sample rate (Hz).
            num_worker: Number of loader workers.
            MUSDB18_PATH: Root of the MUSDB18 dataset (wav layout).
            BIG_DATA: Flag stored on the instance; semantics decided by callers.
            additional_background_data: Optional list of list-files; each is
                expanded via self.readList into background folders.
            additional_vocal_data: Optional list of list-files expanded into
                vocal folders.
        """
        # Fix: the defaults were mutable lists ([]), which are shared across
        # all calls in Python. Use None sentinels and create fresh lists.
        if additional_background_data is None:
            additional_background_data = []
        if additional_vocal_data is None:
            additional_vocal_data = []
        np.random.seed(1)  # fixed seed for reproducible sampling
        self.sample_rate = sample_rate
        self.wh = WaveHandler()
        self.BIG_DATA = BIG_DATA
        self.music_folders = []
        for each in additional_background_data:
            self.music_folders += self.readList(each)
        self.vocal_folders = []
        for each in additional_vocal_data:
            self.vocal_folders += self.readList(each)
        self.frame_length = frame_length
        self.bac_file_num = len(self.music_folders)
        self.voc_file_num = len(self.vocal_folders)

        self.num_worker = num_worker
        self.mus = musdb.DB(MUSDB18_PATH, is_wav=True, subsets='train')
        # Candidate semitone offsets for pitch-shift augmentation.
        self.pitch_shift_high = [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
        self.pitch_shift_low = [-2.0, -3.0, -4.0, -5.0, -6.0, -7.0]
def read_wav(estimate_fname, target_fname):
    """Load an estimate/target wav pair, trim both to the shorter length,
    and return each reshaped to (1, length, 1).

    NOTE(review): channel=1 for the estimate vs channel=2 for the target
    mirrors the original call; presumably mono estimate vs stereo target —
    confirm against WaveHandler.read_wave.
    """
    from util.wave_util import WaveHandler
    handler = WaveHandler()
    est = handler.read_wave(estimate_fname, channel=1)
    ref = handler.read_wave(target_fname, channel=2)
    # Align lengths before reshaping so both arrays match.
    length = min(est.shape[0], ref.shape[0])
    est = est[:length].reshape((1, length, 1))
    ref = ref[:length].reshape((1, length, 1))
    return est, ref
Esempio n. 3
0
def get_total_time_in_folder(path):
    """Print and return the summed duration (seconds) of files in *path*.

    Args:
        path: Directory path; must end with '/' (files are joined by
            plain concatenation below).

    Returns:
        Total duration in seconds (also printed in s / min / h).

    Raises:
        ValueError: If *path* does not end with '/'.
    """
    # endswith also handles the empty-string path, where path[-1] would raise.
    if not path.endswith('/'):
        raise ValueError("Error: path should end with /")
    wh = WaveHandler()
    # The original enumerate() counter was unused; a generator sum is enough.
    # Assumes every entry is a file WaveHandler can read — TODO confirm.
    total_time = sum(wh.get_duration(path + file) for file in os.listdir(path))
    print("total: ")
    print(total_time, "s")
    print(total_time / 60, "min")
    print(total_time / 3600, "h")
    return total_time
Esempio n. 4
0
def delete_unproper_training_data(path):
    """Remove .wav files in *path* that fail WaveHandler's format check.

    Args:
        path: Directory path; must end with '/'.

    Raises:
        ValueError: If *path* does not end with '/'.
    """
    # endswith also handles the empty-string path, where path[-1] would raise.
    if not path.endswith('/'):
        raise ValueError("Error: path should end with /")
    wh = WaveHandler()
    for each in os.listdir(path):  # unused enumerate() counter dropped
        file_pth = path + each
        # Guard clause: skip non-wav entries instead of nesting.
        if not file_pth.endswith('.wav'):
            continue
        # judge[0] is the pass/fail flag; judge[1] holds the offending
        # parameters — indexing kept in case the tuple has more fields.
        judge = wh.get_channels_sampwidth_and_sample_rate(file_pth)
        if not judge[0]:
            print(each, "Unproper! params:", judge[1])
            os.remove(file_pth)
Esempio n. 5
0
def eval_spleeter():
    """Evaluate spleeter separation output against musdb18hq test truth.

    For each test track: read ground-truth vocals/background and spleeter's
    vocals/accompaniment, align lengths with unify(), compute SDR for both
    stems, print per-track values, then print the averages.
    """
    from evaluate.sdr import sdr_evaluate
    wh = WaveHandler()
    from evaluate.si_sdr_numpy import sdr, si_sdr
    output_test_pth = Config.datahub_root + "musdb18hq/spleeter_out/test/"
    mus_test_pth = Config.datahub_root + "musdb18hq/test/"

    vocal = []
    background = []
    # BUG FIX: the loop previously listed `musdb_test_pth`, a name that is
    # never defined in this function (NameError at runtime); the intended
    # directory is `mus_test_pth` defined above. (A stale commented-out
    # train-set copy of this loop was removed.)
    for each in sorted(os.listdir(mus_test_pth)):
        mus_dir = mus_test_pth + each + "/"
        out_dir = output_test_pth + each + "/output/combined/"
        mus_vocal = wh.read_wave(mus_dir + "vocals.wav")
        mus_background = wh.read_wave(mus_dir + "background.wav")
        output_vocal = wh.read_wave(out_dir + "vocals.wav")
        output_background = wh.read_wave(out_dir + "accompaniment.wav")

        # unify() presumably trims/aligns the pair to a common length —
        # confirm against its definition.
        output_vocal, mus_vocal = unify(output_vocal, mus_vocal)
        output_background, mus_background = unify(output_background,
                                                  mus_background)

        v = sdr(output_vocal, mus_vocal)
        b = sdr(output_background, mus_background)
        vocal.append(v)
        background.append(b)
        print("FileName: ", each, "\tSDR-BACKGROUND: ", b, "\tSDR-VOCAL: ", v)
    # Guard the averages: an empty test directory would otherwise raise
    # ZeroDivisionError.
    if not vocal:
        print("No tracks evaluated under", mus_test_pth)
        return
    print("AVG-SDR-VOCAL", sum(vocal) / len(vocal))
    print("AVG-SDR-BACKGROUND", sum(background) / len(background))
Esempio n. 6
0
def seg_data():
    """Tile each wav in 441_song_data into 20 contiguous 5%-length
    segments and save them into seg_song_data.

    np.linspace(0, 0.95, 20) steps by exactly 0.05, so the segments
    [start, start + 0.05) cover the whole file without overlap.
    """
    wh = WaveHandler()
    # Renamed locals: 'dir' shadowed the builtin, 'seg_data' shadowed this
    # function's own name, and the enumerate() counter was unused.
    src_dir = Config.datahub_root + "song/441_song_data/"
    seg_dir = Config.datahub_root + "song/seg_song_data/"
    for fname in os.listdir(src_dir):
        print("Doing segmentation on ", fname + "...")
        data = wh.read_wave(src_dir + fname, channel=2)
        length = data.shape[0]
        for start in np.linspace(0, 0.95, 20):
            segment = data[int(start * length):int((start + 0.05) * length)]
            # Output name: "<basename>_<start>.wav", start with 2 decimals.
            wh.save_wave(segment,
                         seg_dir + fname.split('.')[-2] + "_" +
                         str('%.2f' % start) + ".wav",
                         channels=2)
Esempio n. 7
0
def get_total_time_in_txt(txtpath):
    """Sum the durations of the audio files listed in *txtpath*.

    Args:
        txtpath: Path to a text file whose lines (via readList) are audio
            file paths.

    Returns:
        Tuple of (total_hours, count_of_readable_files). Unreadable files
        are reported and skipped.
    """
    wh = WaveHandler()
    cnt = 0
    files = readList(txtpath)
    total_time = 0
    for file in files:
        try:
            total_time += wh.get_duration(file)
            cnt += 1
        # Fix: was a bare `except:`, which also swallows KeyboardInterrupt
        # and SystemExit; keep the best-effort skip but narrow the catch.
        except Exception:
            print("error:", file)

    # Prints "<basename>, <hours>h, <count>, <path>" for quick inspection.
    print(
        txtpath.split('/')[-1].split('.')[-2], ",",
        str(total_time / 3600) + "h,", cnt, ", " + txtpath)
    return total_time / 3600, cnt
Esempio n. 8
0
 def __init__(
     self,
     frame_length=Config.frame_length,
     sample_rate=Config.sample_rate,
     num_worker=Config.num_workers,
     sampleNo=20000,
     mu=Config.mu,
     empty_every_n=50,
     sigma=Config.sigma,
     alpha_low=Config.alpha_low,
     alpha_high=Config.
     alpha_high  # values above 0.5 risk overflow when scaling/mixing audio
 ):
     """Initialize the on-the-fly vocal/background mixing loader.

     Collects background and vocal folder lists from Config, then
     precomputes a pool of Gaussian(mu, sigma) "alpha" scaling factors
     truncated to the open interval (alpha_low, alpha_high); per the inline
     comments these balance/simulate the energy ratio between vocal and
     background during mixing.
     """
     np.random.seed(1)  # fixed seed: the alpha pool below is reproducible
     self.sample_rate = sample_rate
     self.frame_length = frame_length
     # self.music_folders = self.readList(Config.musdb_train_background)
     self.music_folders = []
     for each in Config.background_data:
         self.music_folders += self.readList(each)
     self.vocal_folders = []
     for each in Config.vocal_data:
         self.vocal_folders += self.readList(each)
     # prev_data_size = len(self.vocal_folders)
     # if(Config.exclude_list != ""):
     #     for each in self.readList(Config.exclude_list):
     #         self.vocal_folders.remove(each)
     # print(prev_data_size-len(self.vocal_folders)," songs were removed from vocal datasets")
     # Number of audio samples per clip (seconds * rate).
     self.sample_length = int(self.sample_rate * self.frame_length)
     self.cnt = 0
     self.data_counter = 0
     self.empty_every_n = empty_every_n
     self.sampleNo = sampleNo
     self.num_worker = num_worker
     self.wh = WaveHandler()
     # This alpha is to balance the energy between vocal and background.
     # Also, this alpha is used to simulate different energy levels between
     # vocal and background.
     self.normal_distribution = np.random.normal(mu, sigma, sampleNo)
     # Truncate by discarding draws outside (alpha_low, alpha_high).
     self.normal_distribution = self.normal_distribution[
         self.normal_distribution > alpha_low]
     self.normal_distribution = self.normal_distribution[
         self.normal_distribution < alpha_high]
     # sampleNo is rebound to the count remaining after truncation.
     self.sampleNo = self.normal_distribution.shape[0]
Esempio n. 9
0
 def __init__(self):
     """Set up a WebRTC VAD wrapper with a smoothing kernel and threshold."""
     self.vad = webrtcvad.Vad()
     self.wh = WaveHandler()
     # Constant kernel of 44100 taps, each 5/4410 (≈0.00113). Presumably a
     # one-second window at 44.1 kHz used to smooth VAD decisions via
     # convolution — TODO confirm against the consuming method.
     self.kernal = np.ones(44100*1)/4410*5
     # Decision threshold applied to the smoothed signal; the unit depends
     # on how the kernel output is used — can't tell from here.
     self.threshold = 20
Esempio n. 10
0
    Config.trail_name + "/" + "data_background.txt")
# Persist the vocal-data file list next to this run's saved models.
write_list(
    Config.vocal_data, Config.project_root + "saved_models/" +
    Config.trail_name + "/" + "data_vocal.txt")

# Caches for frequency-domain loss values accumulated during training.
freq_bac_loss_cache = []
freq_voc_loss_cache = []
freq_cons_loss_cache = []

best_sdr_vocal, best_sdr_background = Config.best_sdr_vocal, Config.best_sdr_background

wh = WaveHandler()
loss = torch.nn.L1Loss()

# Resume from a checkpoint when a start point is configured; otherwise
# build a fresh Spleeter (sub-band variant when split_band is set).
if Config.start_point != 0:
    checkpoint = Config.load_model_path + "/model" + str(Config.start_point) + ".pkl"
    model = torch.load(checkpoint, map_location=Config.device)
elif Config.split_band:
    model = Spleeter(channels=2,
                     unet_inchannels=2 * Config.subband,
                     unet_outchannels=2 * Config.subband).cuda(Config.device)
else:
    model = Spleeter(channels=2, unet_inchannels=2,
                     unet_outchannels=2).cuda(Config.device)