Example #1
0
File: main.py Project: rmacas/mlg
def inference(config, cla):
    """Denoise a single ``.wav`` file or every ``.wav`` file in a directory.

    If ``cla.noisy_input_path`` ends with ``.wav`` only that file is
    processed; otherwise it is treated as a directory and every ``.wav`` file
    directly inside it is processed.

    When ``cla.clean_input_path`` is given, the clean reference for each noisy
    file is read from the file with the same basename in the clean directory.
    In single-file mode the clean directory is the directory containing
    ``cla.clean_input_path`` if that path ends with ``.wav``, otherwise the
    path itself.

    Outputs are written by ``denoise.denoise_sample`` into the folder returned
    by ``get_valid_output_folder_path`` for ``<training path>/samples``.

    Args:
        config: Nested configuration dict; reads ``['training']['batch_size']``,
            ``['training']['path']``, ``['dataset']['sample_rate']`` and
            ``['model']['condition_encoding']``.
        cla: Parsed command-line arguments; reads ``batch_size``,
            ``target_field_length`` (converted to ``int`` in place),
            ``one_shot``, ``load_checkpoint``, ``print_model_summary``,
            ``noisy_input_path``, ``clean_input_path`` and
            ``condition_value``.
    """
    if cla.batch_size is not None:
        batch_size = int(cla.batch_size)
    else:
        batch_size = config['training']['batch_size']

    if cla.target_field_length is not None:
        cla.target_field_length = int(cla.target_field_length)

    one_shot = bool(cla.one_shot)
    sample_rate = config['dataset']['sample_rate']

    if one_shot:
        # In one-shot mode the model is rebuilt per file below, sized to that
        # file's length.
        model = None
        print('Performing one-shot inference..')
    else:
        model = models.DenoisingWavenet(config, target_field_length=cla.target_field_length,
                                        load_checkpoint=cla.load_checkpoint,
                                        print_model_summary=cla.print_model_summary)
        print('Performing inference..')

    samples_folder_path = os.path.join(config['training']['path'], 'samples')
    output_folder_path = get_valid_output_folder_path(samples_folder_path)

    if cla.noisy_input_path.endswith('.wav'):
        # os.path.split yields '' as the directory for a bare filename, which
        # os.path.join handles; splitting on '/' would instead treat the file
        # name itself as the directory.
        noisy_dir, wav_name = os.path.split(cla.noisy_input_path)
        filenames = [wav_name]
        clean_path = cla.clean_input_path
        if clean_path is not None and clean_path.endswith('.wav'):
            clean_dir = os.path.dirname(clean_path)
        else:
            clean_dir = clean_path
    else:
        noisy_dir = cla.noisy_input_path
        filenames = [name for name in os.listdir(noisy_dir) if name.endswith('.wav')]
        clean_dir = cla.clean_input_path

    for filename in filenames:
        noisy_input = util.load_wav(os.path.join(noisy_dir, filename), sample_rate)
        clean_input = None
        if clean_dir is not None:
            clean_input = util.load_wav(os.path.join(clean_dir, filename), sample_rate)

        input_data = {'noisy': noisy_input, 'clean': clean_input}

        # Every filename ends with '.wav'; drop that 4-character extension.
        output_filename_prefix = filename[:-4] + '_'

        if config['model']['condition_encoding'] == 'one_hot':
            condition_input = util.one_hot_encode(int(cla.condition_value), 29)[0]
        else:
            condition_input = util.binary_encode(int(cla.condition_value), 29)[0]

        if one_shot:
            # Even-length inputs are trimmed by one sample (presumably the
            # network requires an odd input length -- TODO confirm).
            if len(input_data['noisy']) % 2 == 0:
                input_data['noisy'] = input_data['noisy'][:-1]
                if input_data['clean'] is not None:
                    input_data['clean'] = input_data['clean'][:-1]
            model = models.DenoisingWavenet(config, load_checkpoint=cla.load_checkpoint,
                                            input_length=len(input_data['noisy']),
                                            print_model_summary=cla.print_model_summary)

        print("Denoising: " + filename)
        denoise.denoise_sample(model, input_data, condition_input, batch_size,
                               output_filename_prefix, sample_rate, output_folder_path)
Example #2
0
def test(config, cla):
    """Evaluate a separation model on every ``.wav`` mixture in a directory.

    For each mixture in ``cla.noisy_input_path`` the references are read from
    ``<clean_input_path>/s1/<name>`` and ``<clean_input_path>/s2/<name>`` (if
    ``cla.clean_input_path`` is set). ``denoise.denoise_sample`` is called on
    the model, and the per-speaker SDR and PIT index it returns are printed.
    The mean SDR over all files is printed at the end.

    Filenames are expected to contain ``_``-separated fields. The first three
    characters of the 1st and 3rd fields are used as speaker ids and looked
    up in ``spk_info.json`` (read from the current working directory) to get
    each speaker's gender.

    Args:
        config: Nested configuration dict; reads ``['training']['batch_size']``,
            ``['training']['path']``, optional ``['training']['n_output']`` and
            ``['training']['n_speaker']`` (both default to 2) and
            ``['dataset']['sample_rate']``.
        cla: Parsed command-line arguments; reads ``batch_size``,
            ``target_field_length`` (converted to ``int`` in place),
            ``load_checkpoint``, ``print_model_summary``,
            ``noisy_input_path``, ``clean_input_path``, ``use_pit`` and
            ``zero_pad``.
    """
    if cla.batch_size is not None:
        batch_size = int(cla.batch_size)
    else:
        batch_size = config['training']['batch_size']

    if cla.target_field_length is not None:
        cla.target_field_length = int(cla.target_field_length)

    model = models.DenoisingWavenet(
        config,
        target_field_length=cla.target_field_length,
        load_checkpoint=cla.load_checkpoint,
        print_model_summary=cla.print_model_summary)

    samples_folder_path = os.path.join(config['training']['path'], 'samples')
    output_folder_path = get_valid_output_folder_path(samples_folder_path)

    sample_rate = config['dataset']['sample_rate']
    noisy_dir = cla.noisy_input_path
    clean_dir = cla.clean_input_path
    filenames = [
        filename for filename in os.listdir(noisy_dir)
        if filename.endswith('.wav')
    ]

    # JSON text is UTF-8 by specification; don't depend on the locale.
    with open('spk_info.json', encoding='utf-8') as f:
        spk_info = json.load(f)

    sdr = []
    training_cfg = config['training']
    n_output = training_cfg['n_output'] if 'n_output' in training_cfg else 2
    n_speaker = training_cfg['n_speaker'] if 'n_speaker' in training_cfg else 2
    # NOTE(review): gender_stat is printed at the end but never updated; the
    # per-channel accumulation from ``ch_gender`` is currently disabled.
    gender_stat = {
        'ch' + str(i + 1): {'M': 0, 'F': 0}
        for i in range(n_output)
    }

    for filename in filenames:
        noisy_input = util.load_wav(os.path.join(noisy_dir, filename), sample_rate)
        # Bind the references up front so a missing clean directory does not
        # raise UnboundLocalError below. Whether denoise_sample accepts None
        # references is not visible here -- confirm.
        clean_input_1 = None
        clean_input_2 = None
        if clean_dir is not None:
            clean_input_1 = util.load_wav(
                os.path.join(clean_dir, 's1', filename), sample_rate)
            clean_input_2 = util.load_wav(
                os.path.join(clean_dir, 's2', filename), sample_rate)
        input_data = {
            'noisy': noisy_input,
            'clean_1': clean_input_1,
            'clean_2': clean_input_2,
        }

        output_filename_prefix = filename[0:-4] + '_'
        name_fields = output_filename_prefix.split('_')
        spk1 = name_fields[0][:3]
        spk2 = name_fields[2][:3]
        spk_name = [spk1, spk2]
        spk_gender = [spk_info[spk1], spk_info[spk2]]

        condition_input = None
        print(filename)
        _sdr, _ch_gender, pit_idx = denoise.denoise_sample(
            model,
            input_data,
            condition_input,
            batch_size,
            output_filename_prefix,
            sample_rate,
            n_speaker,
            n_output,
            output_folder_path,
            spk_gender=spk_gender,
            use_pit=cla.use_pit,
            pad=cla.zero_pad)

        # For a female/male pair, report the male speaker (index 1) first.
        if spk_gender[0] == 'F' and spk_gender[1] == 'M':
            report_order = (1, 0)
        else:
            report_order = (0, 1)
        for i in report_order:
            print('{} {}: sdr={}, idx={}'.format(spk_gender[i], spk_name[i],
                                                 _sdr[i], pit_idx[i]))
        sdr.append(_sdr)
    sdr = np.array(sdr)
    print('Testing SDR:', np.mean(sdr))
    print(gender_stat)
def _make_mixture(target, interferer, interferer_peak, mixture_peak):
    """Mix ``target`` with a looped, peak-normalised ``interferer``.

    The interferer is repeated cyclically to ``len(target)`` samples and
    scaled so its peak magnitude is ``interferer_peak``. It is then added to
    the target, and the sum is scaled so its peak magnitude is
    ``mixture_peak``.

    Raises:
        ValueError: if the interferer is empty, or if the looped interferer
            or the mixture is all zeros (peak normalisation would divide by
            zero).
    """
    if len(interferer) == 0:
        raise ValueError('interferer signal is empty')
    # np.resize repeats its input cyclically up to the requested length, the
    # same result as concatenating copies and truncating.
    looped = np.resize(interferer, len(target))
    looped_peak = np.max(np.abs(looped))
    if looped_peak == 0:
        raise ValueError('interferer signal is silent; cannot peak-normalise it')
    looped = looped * (interferer_peak / looped_peak)

    mixture = target + looped
    peak = np.max(np.abs(mixture))
    if peak == 0:
        raise ValueError('mixture is silent; cannot peak-normalise it')
    mixture *= 1 / peak * mixture_peak
    return mixture


def inference(config, cla):
    """Run the model on a fixed evaluation set of target/interferer mixtures.

    Loads ``(sequence_i_save, interf_i_save)`` index lists from
    ``Data/TestSignalList500.pkl`` and builds ``SampleN`` (100) mixtures from
    the test target/interference file lists of ``DataGenerator.dataGenBig``.
    Each mixture is passed to ``denoise.denoise_sample`` with output folder
    ``bbbbbb`` and filename prefix ``Ind_<i>_est_wavenet_``.

    Args:
        config: Model configuration; ``config['model']['condition_encoding']``
            selects one-hot vs binary encoding of ``cla.condition_value``.
        cla: Parsed command-line arguments; reads ``load_checkpoint``,
            ``print_model_summary`` and ``condition_value``.

    Raises:
        ValueError: if a mixture cannot be built because the interferer is
            empty or silent.
    """
    from collections import namedtuple
    MyStruct = namedtuple("MyStruct", "rescaling rescaling_max multiProcFlag")
    hparams = MyStruct(rescaling=True,
                       rescaling_max=0.999,
                       multiProcFlag=False)

    outputfolder = 'bbbbbb'
    os.makedirs(outputfolder, exist_ok=True)

    import pickle
    # pickle.load can execute code from a crafted file; only point this at a
    # trusted, locally generated list.
    with open('Data/TestSignalList500.pkl', 'rb') as f:
        sequence_i_save, interf_i_save = pickle.load(f)

    # This is for statistical analysis
    SampleN = 100
    random.seed(66666666)

    # # This is for one example
    # SampleN = 1
    # temp = 66
    # sequence_i_save = [sequence_i_save[temp]]
    # interf_i_save = [interf_i_save[temp]]

    # Instantiate Model
    model = models.DenoisingWavenet(
        config,
        load_checkpoint=cla.load_checkpoint,
        print_model_summary=cla.print_model_summary)
    model.model.load_weights(cla.load_checkpoint)

    from DataGenerator import dataGenBig
    dg = dataGenBig(model, seedNum=123456789, verbose=False)
    s2_scale = 0.5
    # Target peak magnitude of each normalised mixture.
    mixture_peak = hparams.rescaling_max if hparams.rescaling else 0.99
    batch_size = 10

    for sample_i in range(SampleN):  # SampleN groups of mixtures and separated signals.
        print("Sample number {}".format(sample_i + 1))
        sequence_i = sequence_i_save[sample_i]
        interf_i = interf_i_save[sample_i]

        target_path = dg.target_test[sequence_i]
        interf_path = dg.interf_test[interf_i]
        print(target_path, '\n', interf_path)

        # generate the mixture
        s1 = np.load(target_path)  # read in the target
        # both the target and the interference are sampled at 22050 Hz
        s2_original, _ = librosa.load(interf_path)
        try:
            mixture = _make_mixture(s1, s2_original, s2_scale, mixture_peak)
        except ValueError as err:
            raise ValueError('Cannot build mixture from {} and {}'.format(
                target_path, interf_path)) from err

        input_data = {'noisy': mixture, 'clean': None}

        print("Denoising: " + target_path.split('/')[-1])
        if config['model']['condition_encoding'] == 'one_hot':
            condition_input = util.one_hot_encode(int(cla.condition_value),
                                                  29)[0]
        else:
            condition_input = util.binary_encode(int(cla.condition_value),
                                                 29)[0]

        dst_wav_name = "Ind_{}_est_wavenet_".format(sample_i)
        denoise.denoise_sample(model, input_data, condition_input, batch_size,
                               dst_wav_name, 22050, outputfolder)