Example 1
import os

import numpy as np

# time_glot_model, generator and nc_data_provider are assumed to be
# defined in this project's own modules.
def generate(file_list, data_dir, output_dir, context_len=32, stats=None,
             base_model_path='./pls.model', gan_model_path='./noise_gen.model'):
    
    # build the pulse model and the GAN generator, then restore their
    # pre-trained weights
    pulse_model = time_glot_model(timesteps=context_len)
    gan_model = generator()

    pulse_model.compile(loss='mse', optimizer="adam")
    gan_model.compile(loss='mse', optimizer="adam")

    pulse_model.load_weights(base_model_path)
    gan_model.load_weights(gan_model_path)

    # the provider yields {filename: acoustic features} dicts
    for data in nc_data_provider(file_list, data_dir, input_only=True,
                                 context_len=context_len):
        for fname, ac_data in data.items():
            print(fname)
                                              
            # predict the deterministic pulse, then refine it with the
            # noise GAN, conditioned on the prediction plus random noise
            pls_pred, _ = pulse_model.predict([ac_data])
            noise = np.random.randn(pls_pred.shape[0], pls_pred.shape[1])
            pls_gan, _ = gan_model.predict([pls_pred, noise])

            out_file = os.path.join(output_dir, fname + '.pls')
            pls_gan.astype(np.float32).tofile(out_file)

            out_file = os.path.join(output_dir, fname + '.pls_nonoise')
            pls_pred.astype(np.float32).tofile(out_file)
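
A minimal usage sketch for generate (the .scp list path and directories below are hypothetical, and file_list is assumed to be the collection of utterance identifiers that nc_data_provider expects):

if __name__ == '__main__':
    with open('./test_files.scp') as f:
        file_list = [line.strip() for line in f if line.strip()]
    generate(file_list,
             data_dir='./data/test',
             output_dir='./synthesis',
             context_len=32)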
Example 2
def train_pls_model(BATCH_SIZE, data_dir, file_list, context_len=32, max_files=30):
    
    no_epochs = 20
    max_epochs_no_improvement = 5

    timesteps = context_len

    optim = Adam(lr=0.0001)  # Adam assumed imported from keras.optimizers
    pls_model = time_glot_model(timesteps=timesteps)
    pls_model.compile(loss=['mse', 'mse'], loss_weights=[1.0, 0.0], optimizer=optim) # disregard fft loss

    fft_mod = fft_model()

    patience = max_epochs_no_improvement
    best_val_loss = 1e20
    for epoch in range(no_epochs):
        print("Pre-train epoch is", epoch)
        epoch_error = [0.0, 0.0]
        total_batches = 0
        val_data = []
        for data in nc_data_provider(file_list, data_dir,
                                     max_files=max_files, context_len=timesteps):

            if len(val_data) == 0:
                val_data = data
                print("using data subset for validation")
                continue

            X_train = data[0]
            Y_train = data[1]

            no_batches = X_train.shape[0] // BATCH_SIZE
            print("Number of batches", no_batches)

            # shuffle data
            ind = np.random.permutation(X_train.shape[0])
            X_train = X_train[ind]
            Y_train = Y_train[ind]
            for index in range(no_batches):
                x_feats_batch = X_train[
                    index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
                y_feats_batch = Y_train[
                    index * BATCH_SIZE:(index + 1) * BATCH_SIZE]

                x_feats_batch_fft = fft_mod.predict(x_feats_batch)
                
                # train_on_batch returns [total_loss, wave_loss, spec_loss]
                # for this two-output model
                d = pls_model.train_on_batch([y_feats_batch],
                                             [x_feats_batch, x_feats_batch_fft])

                epoch_error[0] += d[1]
                epoch_error[1] += d[2]

                if (index + total_batches) % 500 == 0:

                    print("pre-training batch %d, wave loss: %f, spec loss %f" %
                          (index + total_batches, d[1], d[2]))

                    wave, spec = pls_model.predict([y_feats_batch])

                    wav_gen = wave[0,:]
                    wav_ref = x_feats_batch[0,:]
                    wavs = np.array([wav_ref, wav_gen])
                    plot_feats(wavs, epoch, index+total_batches, ext='.wave-pls')

                    spec_gen = spec[0,:]
                    spec_ref = x_feats_batch_fft[0,:]
                    specs = np.array([spec_ref, spec_gen])
                    plot_feats(specs, epoch, index+total_batches, ext='.spec-pls')
                    
            total_batches += no_batches                

        epoch_error[0] /= total_batches
        epoch_error[1] /= total_batches

        val_spec = fft_mod.predict(val_data[0])
        # evaluate() mirrors train_on_batch: [total_loss, wave_loss, spec_loss]
        val_loss = pls_model.evaluate([val_data[1]],
                                      [val_data[0], val_spec],
                                      batch_size=BATCH_SIZE)

        print("epoch %d validation wave loss: %f, spec loss %f \n" %
              (epoch, val_loss[1], val_loss[2]))

        print("epoch %d training wave loss: %f, spec loss %f \n" %
              (epoch, epoch_error[0], epoch_error[1]))
        
        # early stopping tracks the wave loss only
        if val_loss[1] < best_val_loss:
            best_val_loss = val_loss[1]
            patience = max_epochs_no_improvement
            pls_model.save_weights('./pls.model')
        else:
            patience -= 1

        if patience == 0:
            break

    print "Finished training" 
Example 3
def train_noise_model(BATCH_SIZE, data_dir, file_list, save_weights=False,
                      context_len=32, max_files=30, stats=None):

    no_epochs = 15
    
    timesteps = context_len

    optim_container = Adam(lr=1e-4)  # Adam assumed imported from keras.optimizers
    optim_discriminator = SGD(lr=1e-5)

    fft_mod = fft_model()
    pls_model = time_glot_model(timesteps=timesteps)

    # restore the pulse model pre-trained by train_pls_model
    pls_model.compile(loss=['mse', 'mse'], loss_weights=[1.0, 1.0], optimizer='adam')
    pls_model.load_weights("./pls.model")

    disc_model = discriminator()
    gen_model = generator()
    disc_on_gen = gan_container(gen_model, disc_model)
 
    gen_model.compile(loss='mse', optimizer="adam")

    # generator training: freeze the discriminator so gradients flow only
    # into the generator; use both the adversarial and the peek mse loss
    disc_model.trainable = False
    disc_on_gen.compile(loss=['mse', 'mse'], loss_weights=[1.0, 1.0],
                        optimizer=optim_container)

    # discriminator training: unfreeze it and zero out the peek loss
    disc_model.trainable = True
    disc_model.compile(loss=['mse', 'mse'], loss_weights=[1.0, 0.0],
                       optimizer=optim_discriminator)

    print "Discriminator model:"
    print disc_model.summary()
    print "Generator model:"
    print gen_model.summary()
    print "Joint model:"
    print disc_on_gen.summary()

    label_fake = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    label_real = np.ones((BATCH_SIZE, 1), dtype=np.float32)

    # train residual GAN with FFT     
    for epoch in range(no_epochs):
        print("Epoch is", epoch)

        epoch_error = 0
        total_batches = 0
      
        for data in nc_data_provider(file_list, data_dir,
                                     max_files=max_files, context_len=timesteps):   

            X_train = data[0]
            Y_train = data[1]

            pls_len = X_train.shape[1]

            no_batches = X_train.shape[0] // BATCH_SIZE

            # shuffle data
            ind = np.random.permutation(X_train.shape[0])
            X_train = X_train[ind]
            Y_train = Y_train[ind]
            for index in range(no_batches):
                x_feats_batch = X_train[
                    index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
                y_feats_batch = Y_train[
                    index * BATCH_SIZE:(index + 1) * BATCH_SIZE]

                x_pred_batch, x_pred_batch_fft = pls_model.predict([y_feats_batch])
                                
                pls_pred = x_pred_batch
                pls_real = x_feats_batch

                # smoothing windows to prevent edge effects (smoothwin is
                # assumed to be defined at module scope)
                pls_pred *= smoothwin
                pls_real *= smoothwin

                # evaluate target fft
                fft_real = fft_mod.predict(pls_real)

                noise = np.random.randn(BATCH_SIZE, pls_len)

                # train generator through discriminator
                _, peek_real = disc_model.predict([pls_real, fft_real])
                disc_model.trainable = False
                loss_g = disc_on_gen.train_on_batch([pls_pred, noise], [label_real, peek_real])
 
                noise = np.random.randn(BATCH_SIZE, pls_len)

                # train discriminator with real data
                disc_model.trainable = True
                loss_dr = disc_model.train_on_batch([pls_real, fft_real], [label_real, peek_real])

                # train discriminator with fake data
                pls_fake, fft_fake = gen_model.predict([pls_pred, noise])
                loss_df = disc_model.train_on_batch([pls_fake, fft_fake], [label_fake, peek_real])
        
                if (index + total_batches) % 500 == 0:

                    print("training batch %d, G loss: %f, D loss (real): %f, D loss (fake): %f" %
                          (index + total_batches, loss_g[0], loss_dr[0], loss_df[0]))

                    # plot reference, deterministic and noised waveforms
                    wav_ref = pls_real[0, :]
                    wav_gen = pls_pred[0, :]
                    wav_noised = pls_fake[0, :]
                    wavs = np.array([wav_ref, wav_gen, wav_noised])
                    plot_feats(wavs, epoch, index + total_batches, ext='.wave')
                 
            total_batches += no_batches

        # assumes the ./models directory already exists
        gen_model.save_weights('./models/noise_gen_epoch' + str(epoch) + '.model')

    print "Finished noise model training"