def train_model(feature_size, hidden_size, init_window_size, generator_model, generator_gan_optimizer, generator_tf_optimizer, discriminator_feature_model, discriminator_output_model, discriminator_gan_optimizer, num_epochs, model_name): # generator updater print 'COMPILING GAN UPDATE FUNCTION ' gan_updater = set_gan_update_function(generator_model=generator_model, discriminator_feature_model=discriminator_feature_model, discriminator_output_model=discriminator_output_model, generator_optimizer=generator_gan_optimizer, discriminator_optimizer=discriminator_gan_optimizer, generator_grad_clipping=.0, discriminator_grad_clipping=.0) print 'COMPILING TF UPDATE FUNCTION ' tf_updater = set_tf_update_function(generator_model=generator_model, generator_optimizer=generator_tf_optimizer, generator_grad_clipping=.0) # evaluator print 'COMPILING EVALUATION FUNCTION ' evaluator = set_evaluation_function(generator_model=generator_model) # sample generator print 'COMPILING SAMPLING FUNCTION ' sample_generator = set_sample_function(generator_model=generator_model) print 'READ RAW WAV DATA' _, train_raw_data = wavfile.read('/data/lisatmp4/taesup/data/YouTubeAudio/XqaJ2Ol5cC4.wav') valid_raw_data = train_raw_data[160000000:] train_raw_data = train_raw_data[:160000000] train_raw_data = train_raw_data[2000:] train_raw_data = (train_raw_data/(1.15*2.**13)).astype(floatX) valid_raw_data = (valid_raw_data/(1.15*2.**13)).astype(floatX) num_train_total_steps = train_raw_data.shape[0] num_valid_total_steps = valid_raw_data.shape[0] batch_size = 64 num_valid_sequences = num_valid_total_steps/(feature_size*init_window_size)-1 valid_source_data = valid_raw_data[:num_valid_sequences*(feature_size*init_window_size)] valid_source_data = valid_source_data.reshape((num_valid_sequences, init_window_size, feature_size)) valid_target_data = valid_raw_data[feature_size:feature_size+num_valid_sequences*(feature_size*init_window_size)] valid_target_data = valid_target_data.reshape((num_valid_sequences, 
init_window_size, feature_size)) valid_raw_data = None num_seeds = 10 valid_shuffle_idx = np_rng.permutation(num_valid_sequences) valid_source_data = valid_source_data[valid_shuffle_idx] valid_target_data = valid_target_data[valid_shuffle_idx] valid_seed_data = valid_source_data[:num_seeds][0][:] valid_source_data = numpy.swapaxes(valid_source_data, axis1=0, axis2=1) valid_target_data = numpy.swapaxes(valid_target_data, axis1=0, axis2=1) num_valid_batches = num_valid_sequences/batch_size print 'NUM OF VALID BATCHES : ', num_valid_sequences/batch_size best_valid = 10000. print 'START TRAINING' # for each epoch tf_mse_list = [] tf_generator_grad_list = [] gan_generator_grad_list = [] gan_generator_cost_list = [] gan_discriminator_grad_list = [] gan_discriminator_cost_list = [] gan_true_score_list = [] gan_false_score_list = [] gan_mse_list = [] valid_mse_list = [] train_batch_count = 0 for e in xrange(num_epochs): window_size = init_window_size + 5*e sequence_size = feature_size*window_size last_seq_idx = num_train_total_steps-(sequence_size+feature_size) train_seq_orders = np_rng.permutation(last_seq_idx) train_seq_orders = train_seq_orders[:last_seq_idx-last_seq_idx%batch_size] train_seq_orders = train_seq_orders.reshape((-1, batch_size)) print 'NUM OF TRAIN BATCHES : ', train_seq_orders.shape[0] # for each batch for batch_idx, batch_info in enumerate(train_seq_orders): # source data train_source_idx = batch_info.reshape((batch_size, 1)) + numpy.repeat(numpy.arange(sequence_size).reshape((1, sequence_size)), batch_size, axis=0) train_source_data = train_raw_data[train_source_idx] train_source_data = train_source_data.reshape((batch_size, window_size, feature_size)) train_source_data = numpy.swapaxes(train_source_data, axis1=0, axis2=1) # target data train_target_idx = train_source_idx + feature_size train_target_data = train_raw_data[train_target_idx] train_target_data = train_target_data.reshape((batch_size, window_size, feature_size)) train_target_data = 
numpy.swapaxes(train_target_data, axis1=0, axis2=1) # tf update tf_update_output = tf_updater(train_source_data, train_target_data) tf_square_error = tf_update_output[0].mean() tf_generator_grad_norm = tf_update_output[1] # gan update gan_update_output = gan_updater(train_source_data, train_target_data) generator_gan_cost = gan_update_output[0].mean() discriminator_gan_cost = gan_update_output[1].mean() discriminator_true_score = gan_update_output[2].mean() discriminator_false_score = gan_update_output[3].mean() gan_square_error = gan_update_output[4].mean() gan_generator_grad_norm = gan_update_output[5] gan_discriminator_grad_norm = gan_update_output[6] train_batch_count += 1 tf_generator_grad_list.append(tf_generator_grad_norm) tf_mse_list.append(tf_square_error) gan_generator_grad_list.append(gan_generator_grad_norm) gan_generator_cost_list.append(generator_gan_cost) gan_discriminator_grad_list.append(gan_discriminator_grad_norm) gan_discriminator_cost_list.append(discriminator_gan_cost) gan_true_score_list.append(discriminator_true_score) gan_false_score_list.append(discriminator_false_score) gan_mse_list.append(gan_square_error) if train_batch_count%10==0: print '============{}_LENGTH{}============'.format(model_name, window_size) print 'epoch {}, batch_cnt {} => TF generator mse cost {}'.format(e, train_batch_count, tf_mse_list[-1]) print 'epoch {}, batch_cnt {} => GAN generator mse cost {}'.format(e, train_batch_count, gan_mse_list[-1]) print '----------------------------------------------------------' print 'epoch {}, batch_cnt {} => GAN generator cost {}'.format(e, train_batch_count, gan_generator_cost_list[-1]) print 'epoch {}, batch_cnt {} => GAN discriminator cost {}'.format(e, train_batch_count, gan_discriminator_cost_list[-1]) print '----------------------------------------------------------' print 'epoch {}, batch_cnt {} => GAN input score {}'.format(e, train_batch_count, gan_true_score_list[-1]) print 'epoch {}, batch_cnt {} => GAN sample score 
{}'.format(e, train_batch_count, gan_false_score_list[-1]) print '----------------------------------------------------------' print 'epoch {}, batch_cnt {} => GAN discrim. grad norm {}'.format(e, train_batch_count, gan_discriminator_grad_list[-1]) print 'epoch {}, batch_cnt {} => GAN generator grad norm {}'.format(e, train_batch_count, gan_generator_grad_list[-1]) print '----------------------------------------------------------' print 'epoch {}, batch_cnt {} => TF generator grad norm {}'.format(e, train_batch_count, tf_generator_grad_list[-1]) if train_batch_count%100==0: tf_valid_mse = 0.0 valid_batch_count = 0 for valid_idx in xrange(num_valid_batches): start_idx = batch_size*valid_idx end_idx = batch_size*(valid_idx+1) evaluation_outputs = evaluator(valid_source_data[:][start_idx:end_idx][:], valid_target_data[:][start_idx:end_idx][:]) tf_valid_mse += evaluation_outputs[0].mean() valid_batch_count += 1 if valid_idx==0: recon_data = evaluation_outputs[1] recon_data = numpy.swapaxes(recon_data, axis1=0, axis2=1) recon_data = recon_data[:10] recon_data = recon_data.reshape((10, -1)) recon_data = recon_data*(1.15*2.**13) recon_data = recon_data.astype(numpy.int16) save_wavfile(recon_data, model_name+'_recon') orig_data = valid_target_data[:][start_idx:end_idx][:] orig_data = numpy.swapaxes(orig_data, axis1=0, axis2=1) orig_data = orig_data[:10] orig_data = orig_data.reshape((10, -1)) orig_data = orig_data*(1.15*2.**13) orig_data = orig_data.astype(numpy.int16) save_wavfile(orig_data, model_name+'_orig') valid_mse_list.append(tf_valid_mse/valid_batch_count) print '----------------------------------------------------------' print 'epoch {}, batch_cnt {} => TF valid mse cost {}'.format(e, train_batch_count, valid_mse_list[-1]) if best_valid>valid_mse_list[-1]: best_valid = valid_mse_list[-1] if train_batch_count%500==0: numpy.save(file=model_name+'tf_mse', arr=numpy.asarray(tf_mse_list)) numpy.save(file=model_name+'tf_gen_grad', 
arr=numpy.asarray(tf_generator_grad_list)) numpy.save(file=model_name+'gan_mse', arr=numpy.asarray(gan_mse_list)) numpy.save(file=model_name+'gan_gen_cost', arr=numpy.asarray(gan_generator_cost_list)) numpy.save(file=model_name+'gan_disc_cost', arr=numpy.asarray(gan_true_score_list)) numpy.save(file=model_name+'gan_input_score', arr=numpy.asarray(gan_true_score_list)) numpy.save(file=model_name+'gan_sample_score', arr=numpy.asarray(gan_false_score_list)) numpy.save(file=model_name+'gan_gen_grad', arr=numpy.asarray(gan_generator_grad_list)) numpy.save(file=model_name+'gan_disc_grad', arr=numpy.asarray(gan_discriminator_grad_list)) numpy.save(file=model_name+'valid_mse', arr=numpy.asarray(valid_mse_list)) num_sec = 100 sampling_length = num_sec*sampling_rate/feature_size seed_input_data = valid_seed_data [generated_sequence, ] = sample_generator(seed_input_data, sampling_length) sample_data = numpy.swapaxes(generated_sequence, axis1=0, axis2=1) sample_data = sample_data.reshape((num_seeds, -1)) sample_data = sample_data*(1.15*2.**13) sample_data = sample_data.astype(numpy.int16) save_wavfile(sample_data, model_name+'_sample') if best_valid==valid_mse_list[-1]: save_model_params(generator_model, model_name+'_gen_model.pkl') save_model_params(discriminator_feature_model, model_name+'_disc_feat_model.pkl') save_model_params(discriminator_output_model, model_name+'_disc_output_model.pkl')
def train_model(feature_size, hidden_size, init_window_size, generator_model, generator_optimizer, num_epochs, model_name):
    """Train a sequence generator on raw audio with a regularized updater.

    Per batch the model takes one update combining a sample cost and a
    regularizer cost weighted by a fixed lambda (0.1).  Every 100 batches the
    model is evaluated on a held-out split (reconstructions dumped as wav);
    every 500 batches training curves are saved, free-running samples are
    generated, and the best-so-far parameters are pickled.

    NOTE(review): this redefines ``train_model``; the earlier GAN/TF variant
    defined above in this file is shadowed.
    """
    # model updater (project helper; each compilation time is printed)
    print 'COMPILING UPDATER FUNCTION '
    t = time()
    updater_function = set_updater_function(generator_model=generator_model,
                                            generator_optimizer=generator_optimizer,
                                            generator_grad_clipping=.0)
    print '%.2f SEC '%(time()-t)

    # evaluator
    print 'COMPILING EVALUATION FUNCTION '
    t = time()
    evaluation_function = set_evaluation_function(generator_model=generator_model)
    print '%.2f SEC '%(time()-t)

    # sample generator
    print 'COMPILING SAMPLING FUNCTION '
    t = time()
    sampling_function = set_sampling_function(generator_model=generator_model)
    print '%.2f SEC '%(time()-t)

    # Load one long wav track; first 160M samples train, the remainder validates.
    print 'READ RAW WAV DATA'
    _, train_raw_data = wavfile.read('/data/lisatmp4/taesup/data/YouTubeAudio/XqaJ2Ol5cC4.wav')
    valid_raw_data = train_raw_data[160000000:]
    train_raw_data = train_raw_data[:160000000]
    train_raw_data = train_raw_data[2000:]
    # scale int16 samples into roughly [-1, 1]
    train_raw_data = (train_raw_data/(1.15*2.**13)).astype(floatX)
    valid_raw_data = (valid_raw_data/(1.15*2.**13)).astype(floatX)

    num_train_total_steps = train_raw_data.shape[0]
    num_valid_total_steps = valid_raw_data.shape[0]
    batch_size = 64

    # Cut the validation stream into (window, feature) sequences; targets are
    # the sources shifted forward by one frame.
    num_valid_sequences = num_valid_total_steps/(feature_size*init_window_size)-1
    valid_source_data = valid_raw_data[:num_valid_sequences*(feature_size*init_window_size)]
    valid_source_data = valid_source_data.reshape((num_valid_sequences, init_window_size, feature_size))
    valid_target_data = valid_raw_data[feature_size:feature_size+num_valid_sequences*(feature_size*init_window_size)]
    valid_target_data = valid_target_data.reshape((num_valid_sequences, init_window_size, feature_size))
    valid_raw_data = None  # release the large intermediate array

    num_seeds = 10
    valid_shuffle_idx = np_rng.permutation(num_valid_sequences)
    valid_source_data = valid_source_data[valid_shuffle_idx]
    valid_target_data = valid_target_data[valid_shuffle_idx]
    # NOTE(review): [:num_seeds][0][:] keeps only the FIRST sequence, not ten
    # seeds — confirm whether valid_source_data[:num_seeds] was intended.
    valid_seed_data = valid_source_data[:num_seeds][0][:]
    # time-major layout: (window, sequence, feature)
    valid_source_data = numpy.swapaxes(valid_source_data, axis1=0, axis2=1)
    valid_target_data = numpy.swapaxes(valid_target_data, axis1=0, axis2=1)
    num_valid_batches = num_valid_sequences/batch_size
    print 'NUM OF VALID BATCHES : ', num_valid_sequences/batch_size

    best_valid = 10000.

    print 'START TRAINING'
    # for each epoch
    # training-curve histories (saved every 500 batches)
    train_sample_cost_list = []
    train_regularizer_cost_list = []
    train_gradient_norm_list = []
    train_lambda_regularizer_list = []
    valid_sample_cost_list = []
    train_batch_count = 0
    for e in xrange(num_epochs):
        # curriculum: window grows by 5 frames per epoch
        window_size = init_window_size + 5*e
        sequence_size = feature_size*window_size

        # random start offsets; one training sequence per offset
        last_seq_idx = num_train_total_steps-(sequence_size+feature_size)
        train_seq_orders = np_rng.permutation(last_seq_idx)
        train_seq_orders = train_seq_orders[:last_seq_idx-last_seq_idx%batch_size]
        train_seq_orders = train_seq_orders.reshape((-1, batch_size))
        print 'NUM OF TRAIN BATCHES : ', train_seq_orders.shape[0]

        # for each batch
        for batch_idx, batch_info in enumerate(train_seq_orders):
            # source data: gather windows, then go time-major
            train_source_idx = batch_info.reshape((batch_size, 1)) + numpy.repeat(numpy.arange(sequence_size).reshape((1, sequence_size)), batch_size, axis=0)
            train_source_data = train_raw_data[train_source_idx]
            train_source_data = train_source_data.reshape((batch_size, window_size, feature_size))
            train_source_data = numpy.swapaxes(train_source_data, axis1=0, axis2=1)

            # target data: source shifted one frame ahead
            train_target_idx = train_source_idx + feature_size
            train_target_data = train_raw_data[train_target_idx]
            train_target_data = train_target_data.reshape((batch_size, window_size, feature_size))
            train_target_data = numpy.swapaxes(train_target_data, axis1=0, axis2=1)

            # update model (fixed regularizer weight)
            lambda_regularizer = 0.1
            updater_outputs = updater_function(train_source_data, train_target_data, lambda_regularizer)
            train_sample_cost = updater_outputs[0].mean()
            train_regularizer_cost = updater_outputs[1].mean()
            train_gradient_norm = updater_outputs[2]

            train_batch_count += 1

            train_sample_cost_list.append(train_sample_cost)
            train_regularizer_cost_list.append(train_regularizer_cost)
            train_gradient_norm_list.append(train_gradient_norm)
            train_lambda_regularizer_list.append(lambda_regularizer)

            if train_batch_count%10==0:
                print '============{}_LENGTH{}============'.format(model_name, window_size)
                print 'epoch {}, batch_cnt {} => train sample cost {}'.format(e, train_batch_count, train_sample_cost_list[-1])
                print 'epoch {}, batch_cnt {} => train regularizer cost {}'.format(e, train_batch_count, train_regularizer_cost_list[-1])
                print '----------------------------------------------------------'
                print 'epoch {}, batch_cnt {} => train gradient norm {}'.format(e, train_batch_count, train_gradient_norm_list[-1])
                print 'epoch {}, batch_cnt {} => train regularizer lambda {}'.format(e, train_batch_count, train_lambda_regularizer_list[-1])

            if train_batch_count%100==0:
                tf_valid_mse = 0.0
                valid_batch_count = 0
                for valid_idx in xrange(num_valid_batches):
                    start_idx = batch_size*valid_idx
                    end_idx = batch_size*(valid_idx+1)

                    # NOTE(review): x[:][start_idx:end_idx][:] is equivalent to
                    # x[start_idx:end_idx] and slices axis 0 (time), not the
                    # sequence axis this batch loop indexes — confirm whether
                    # x[:, start_idx:end_idx, :] was intended.
                    evaluation_outputs = evaluation_function(valid_source_data[:][start_idx:end_idx][:], valid_target_data[:][start_idx:end_idx][:])
                    tf_valid_mse += evaluation_outputs[0].mean()
                    valid_batch_count += 1

                    if valid_idx==0:
                        # dump reconstructions of the first validation batch
                        recon_data = evaluation_outputs[1]
                        recon_data = numpy.swapaxes(recon_data, axis1=0, axis2=1)
                        recon_data = recon_data[:10]
                        recon_data = recon_data.reshape((10, -1))
                        recon_data = recon_data*(1.15*2.**13)  # undo input scaling
                        recon_data = recon_data.astype(numpy.int16)
                        save_wavfile(recon_data, model_name+'_recon')

                        orig_data = valid_target_data[:][start_idx:end_idx][:]
                        orig_data = numpy.swapaxes(orig_data, axis1=0, axis2=1)
                        orig_data = orig_data[:10]
                        orig_data = orig_data.reshape((10, -1))
                        orig_data = orig_data*(1.15*2.**13)
                        orig_data = orig_data.astype(numpy.int16)
                        save_wavfile(orig_data, model_name+'_orig')

                valid_sample_cost_list.append(tf_valid_mse/valid_batch_count)
                print '----------------------------------------------------------'
                print 'epoch {}, batch_cnt {} => valid sample cost {}'.format(e, train_batch_count, valid_sample_cost_list[-1])

                if best_valid>valid_sample_cost_list[-1]:
                    best_valid = valid_sample_cost_list[-1]

            if train_batch_count%500==0:
                # save training curves
                numpy.save(file=model_name+'_train_sample_cost', arr=numpy.asarray(train_sample_cost_list))
                numpy.save(file=model_name+'_train_regularizer_cost', arr=numpy.asarray(train_regularizer_cost_list))
                numpy.save(file=model_name+'_train_gradient_norm', arr=numpy.asarray(train_gradient_norm_list))
                numpy.save(file=model_name+'_train_lambda_value', arr=numpy.asarray(train_lambda_regularizer_list))
                numpy.save(file=model_name+'_valid_sample_cost', arr=numpy.asarray(valid_sample_cost_list))

                # free-running sampling: num_sec seconds of audio
                num_sec = 100
                sampling_length = num_sec*sampling_rate/feature_size
                seed_input_data = valid_seed_data
                [generated_sequence, ] = sampling_function(seed_input_data, sampling_length)
                sample_data = numpy.swapaxes(generated_sequence, axis1=0, axis2=1)
                sample_data = sample_data.reshape((num_seeds, -1))
                sample_data = sample_data*(1.15*2.**13)
                sample_data = sample_data.astype(numpy.int16)
                save_wavfile(sample_data, model_name+'_sample')

                # checkpoint when the last validation was the best so far
                if best_valid==valid_sample_cost_list[-1]:
                    save_model_params(generator_model, model_name+'_model.pkl')
num_sec = 100 sampling_length = num_sec*sampling_rate/feature_size seed_input_data = valid_seed_data [generated_sequence, ] = sampling_function(seed_input_data, sampling_length) sample_data = numpy.swapaxes(generated_sequence, axis1=0, axis2=1) sample_data = sample_data.reshape((num_seeds, -1)) sample_data = sample_data*(1.15*2.**13) sample_data = sample_data.astype(numpy.int16) save_wavfile(sample_data, model_name+'_sample') if best_valid==valid_sample_cost_list[-1]: save_model_params(generator_model, model_name+'_model.pkl') if __name__=="__main__": feature_size = 1600 hidden_size = 800 model_name = 'LSTM_REGULARIZER_LAMBDA' \ + '_FEATURE{}'.format(int(feature_size)) \ + '_HIDDEN{}'.format(int(hidden_size)) \ # generator model generator_model = set_generator_model(input_size=feature_size, hidden_size=hidden_size)