示例#1
0
文件: main.py 项目: rmacas/mlg
def inference(config, cla):
    """Denoise a single wav file, or every wav file in a directory, with DenoisingWavenet.

    ``cla.noisy_input_path`` may name one ``.wav`` file or a directory; in the
    directory case every entry ending in ``.wav`` is processed, in sorted
    order. If ``cla.clean_input_path`` is given, a clean reference is loaded
    for each noisy file and passed along as the ``'clean'`` entry:

    * single-file mode with a clean path ending in ``.wav``: that exact file;
    * otherwise: the file with the same name inside the clean directory.

    In normal mode one model is built up front. In one-shot mode
    (``cla.one_shot`` truthy) a new model is built for every file, with
    ``input_length`` equal to that file's length after trimming one sample
    from even-length inputs.

    Each result is written by ``denoise.denoise_sample`` into the folder
    returned by ``get_valid_output_folder_path`` for
    ``<config['training']['path']>/samples``, using the prefix
    ``<file stem>_``.

    Args:
        config: Nested configuration dict. Reads ``['training']['batch_size']``,
            ``['training']['path']``, ``['dataset']['sample_rate']`` and
            ``['model']['condition_encoding']``.
        cla: Parsed command-line arguments. ``cla.target_field_length`` is
            converted to ``int`` in place when set; the input-path attributes
            are no longer modified.
    """
    if cla.batch_size is not None:
        batch_size = int(cla.batch_size)
    else:
        batch_size = config['training']['batch_size']

    if cla.target_field_length is not None:
        cla.target_field_length = int(cla.target_field_length)

    one_shot = bool(cla.one_shot)
    sample_rate = config['dataset']['sample_rate']

    if not one_shot:
        model = models.DenoisingWavenet(config, target_field_length=cla.target_field_length,
                                        load_checkpoint=cla.load_checkpoint,
                                        print_model_summary=cla.print_model_summary)
        print('Performing inference..')
    else:
        print('Performing one-shot inference..')

    samples_folder_path = os.path.join(config['training']['path'], 'samples')
    output_folder_path = get_valid_output_folder_path(samples_folder_path)

    noisy_path = cla.noisy_input_path
    clean_path = cla.clean_input_path
    single_file = noisy_path.endswith('.wav')
    if single_file:
        # os.path.split yields '' as the directory for a bare filename, whereas
        # rsplit('/') + '/' turned 'x.wav' into the bogus directory 'x.wav/'.
        noisy_dir, noisy_name = os.path.split(noisy_path)
        filenames = [noisy_name]
    else:
        noisy_dir = noisy_path
        filenames = sorted(name for name in os.listdir(noisy_dir) if name.endswith('.wav'))

    if not filenames:
        print('No .wav files found in ' + noisy_dir)
        return

    # The condition vector depends only on the CLI argument, so build it once.
    condition_value = int(cla.condition_value)
    if config['model']['condition_encoding'] == 'one_hot':
        condition_input = util.one_hot_encode(condition_value, 29)[0]
    else:
        condition_input = util.binary_encode(condition_value, 29)[0]

    for filename in filenames:
        noisy_input = util.load_wav(os.path.join(noisy_dir, filename), sample_rate)

        clean_input = None
        if clean_path is not None:
            if single_file and clean_path.endswith('.wav'):
                clean_file = clean_path
            else:
                # Treat the clean path as a directory holding a same-named file.
                clean_file = os.path.join(clean_path, filename)
            clean_input = util.load_wav(clean_file, sample_rate)

        sample = {'noisy': noisy_input, 'clean': clean_input}
        output_filename_prefix = os.path.splitext(filename)[0] + '_'

        if one_shot:
            # Trim even-length inputs by one sample so the one-shot model gets an
            # odd input_length (NOTE(review): presumably required by
            # models.DenoisingWavenet -- confirm there).
            if len(sample['noisy']) % 2 == 0:
                sample['noisy'] = sample['noisy'][:-1]
                if sample['clean'] is not None:
                    sample['clean'] = sample['clean'][:-1]
            model = models.DenoisingWavenet(config, load_checkpoint=cla.load_checkpoint,
                                            input_length=len(sample['noisy']),
                                            print_model_summary=cla.print_model_summary)

        print("Denoising: " + filename)
        denoise.denoise_sample(model, sample, condition_input, batch_size, output_filename_prefix,
                               sample_rate, output_folder_path)
def inference(config, cla):
    """Run DenoisingWavenet over a fixed benchmark of target + interference mixtures.

    For each of the first ``SampleN`` (target, interference) index pairs stored
    in ``Data/TestSignalList500.pkl``:

    * the target is loaded with ``np.load`` and the interference with
      ``librosa.load``;
    * the interference is looped and truncated to the target's length, then
      peak-normalised to ``s2_scale``;
    * the sum is rescaled so its peak is ``hparams.rescaling_max`` (or 0.99
      when rescaling is off);
    * the result is passed to ``denoise.denoise_sample``.

    Outputs are written to the ``bbbbbb`` folder with the prefix
    ``Ind_<i>_est_wavenet_``.

    Args:
        config: Configuration dict; ``config['model']['condition_encoding']``
            selects one-hot vs binary condition encoding.
        cla: Parsed command-line arguments (``load_checkpoint``,
            ``print_model_summary``, ``condition_value``).
    """
    from collections import namedtuple
    MyStruct = namedtuple("MyStruct", "rescaling rescaling_max multiProcFlag")
    hparams = MyStruct(rescaling=True,
                       rescaling_max=0.999,
                       multiProcFlag=False)

    # Both the target and the interference are sampled at 22050 Hz.
    sample_rate = 22050
    batch_size = 10
    s2_scale = 0.5  # peak amplitude of the interference before mixing

    outputfolder = 'bbbbbb'
    os.makedirs(outputfolder, exist_ok=True)

    import pickle
    # NOTE(review): pickle.load can execute arbitrary code; only use this with
    # the trusted local benchmark file.
    with open('Data/TestSignalList500.pkl', 'rb') as f:
        sequence_i_save, interf_i_save = pickle.load(f)

    # This is for statistical analysis
    SampleN = 100
    random.seed(66666666)

    # # This is for one example
    # SampleN = 1
    # temp = 66
    # sequence_i_save = [sequence_i_save[temp]]
    # interf_i_save = [interf_i_save[temp]]

    # Never index past the stored (target, interference) pairs.
    SampleN = min(SampleN, len(sequence_i_save), len(interf_i_save))

    # Instantiate Model
    model = models.DenoisingWavenet(
        config,
        load_checkpoint=cla.load_checkpoint,
        print_model_summary=cla.print_model_summary)
    model.model.load_weights(cla.load_checkpoint)

    from DataGenerator import dataGenBig
    dg = dataGenBig(model, seedNum=123456789, verbose=False)

    # The condition vector depends only on the CLI argument, so build it once.
    condition_value = int(cla.condition_value)
    if config['model']['condition_encoding'] == 'one_hot':
        condition_input = util.one_hot_encode(condition_value, 29)[0]
    else:
        condition_input = util.binary_encode(condition_value, 29)[0]

    for sample_i in range(SampleN):  # SampleN groups of mixtures and separated signals.
        print("Sample number {}".format(sample_i + 1))
        sequence_i = sequence_i_save[sample_i]
        interf_i = interf_i_save[sample_i]

        target_path = dg.target_test[sequence_i]
        interf_path = dg.interf_test[interf_i]
        print(target_path, '\n', interf_path)

        # generate the mixture
        s1 = np.load(target_path)  # read in the target
        s2_original, _ = librosa.load(interf_path)

        # np.load / librosa.load raise instead of returning None, so the real
        # failure mode is an empty signal: nothing to mix for an empty target,
        # and an empty interference can never be looped up to the target length.
        if s1.size == 0 or s2_original.size == 0:
            print("Data loading fail")
            sys.exit()

        L = len(s1)
        # Loop the interference until it covers the target, then truncate.
        reps = -(-L // len(s2_original))  # ceiling division
        s2 = np.tile(s2_original, reps)[:L]

        # first normalise s2 (leave a silent interference untouched)
        s2_peak = np.max(np.abs(s2))
        if s2_peak > 0:
            s2 = s2 * (s2_scale / s2_peak)

        mixture = s1 + s2
        # normalise the mixture so its maximum magnitude equals peak_target
        peak_target = hparams.rescaling_max if hparams.rescaling else 0.99
        mixture_peak = np.max(np.abs(mixture))
        if mixture_peak > 0:
            mixture *= 1 / mixture_peak * peak_target

        sample_input = {'noisy': mixture, 'clean': None}

        print("Denoising: " + target_path.split('/')[-1])
        dst_wav_name = "Ind_{}_est_wavenet_".format(sample_i)
        denoise.denoise_sample(model, sample_input, condition_input, batch_size,
                               dst_wav_name, sample_rate, outputfolder)