def gen_wav(savename):
    """Plot the first-layer weights of a pickled model.

    The model is taken as element 0 of the object pickled at ``savename``.
    Its ``layers[0].W`` value is squeezed to drop singleton axes and handed to
    ``grid_plot``; the resulting figure is shown interactively.

    Args:
        savename: path of the pickle file holding the model.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    # Binary mode plus a context manager: the handle is closed
    # deterministically and the read does not depend on newline translation.
    with open(savename, 'rb') as f:
        mdl = cPickle.load(f)[0]

    weights = np.squeeze(mdl.layers[0].W.get_value())

    # NOTE(review): called as a function here but as ``grid_plot.grid_plot``
    # in ``fine_tune`` -- confirm which form matches the file's import.
    grid_plot(weights)
    plt.show()
def fine_tune(savename):
    """Fit per-example latent codes ``z`` for a pickled model on TIMIT data.

    The model is element 0 of the object pickled at ``savename``. Initial
    codes come from ``model.encode`` applied to the whole loaded set; they are
    kept in a shared variable and are the only quantity that is updated (the
    cost is differentiated with respect to the codes alone).

    Every 50th epoch, if the epoch cost improved on the best checkpointed
    cost, a plot (``.png``) and a pickle (``.pkl``) containing
    ``[model, z, [cost histories]]`` are written using the input path with its
    extension replaced by ``'_z' + str(image_w)``.

    Args:
        savename: path of the pickle file holding the model.

    Raises:
        ValueError: if the (hard-coded) learning rule is not one of
            ``'ada'``, ``'con'`` or ``'mom'``.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(savename, 'rb') as f:
        osrtdnn = cPickle.load(f)[0]

    # ------------------------------------------------------------------
    # Hyper-parameters (hard-coded)
    # ------------------------------------------------------------------
    learning_rate = 0.1
    n_epochs = 10000000
    image_w = 2048
    batch_size = 100
    start = 1
    stop = start + batch_size
    channel = 1
    wavtype = 'timit'
    learning_rule = 'ada'
    slice_mode = 'N'  # renamed from ``slice`` to avoid shadowing the builtin
    mom = 0.96
    postfix = '_z' + str(image_w)

    # From here on ``savename`` is the output prefix, not the input path.
    savename = os.path.splitext(savename)[0] + postfix

    train_set_x = load_data_timit_seq('test', start, stop, image_w, wavtype,
                                      slice_mode)

    # Number of loaded examples and of full minibatches per epoch.
    n_examples = train_set_x.get_value(borrow=True).shape[0]
    n_train_batches = n_examples // batch_size

    ######################
    # BUILD ACTUAL MODEL #
    ######################
    print('... building the model')

    index = T.lscalar()  # index to a [mini]batch
    x = T.matrix('x')    # 2-D input, reshaped to 4-D below

    # Z initialisation: encode the whole set in one pass.
    osrtdnn.set_batch_size(n_examples)
    osrtdnn.set_image_w(image_w)
    x_tot_shape = x.reshape((n_examples, channel, 1, image_w))
    z_init = theano.function([x], osrtdnn.encode(x_tot_shape))
    z_tot = theano.shared(value=z_init(train_set_x.get_value()), borrow=True)

    x_re = x.reshape((batch_size, channel, 1,
                      osrtdnn.layers[0].input_shape[3]))
    cost, cost_dec, cost_rec = osrtdnn.cost(
        x_re, z_tot[index * batch_size:(index + 1) * batch_size])

    # Gradient with respect to the codes only.
    zgrads = T.grad(cost, z_tot)
    zgradsdic = {z_tot: zgrads}

    if learning_rule == 'ada':
        zupdates = AdaDelta().get_updates(learning_rate, zgradsdic)
    elif learning_rule == 'con':
        # Plain gradient descent on the single shared variable. The original
        # iterated over an undefined ``zparams`` and zipped over one tensor.
        zupdates = [(z_tot, z_tot - learning_rate * zgrads)]
    elif learning_rule == 'mom':
        zupdates = Momentum(mom).get_updates(learning_rate, zgradsdic)
    else:
        raise ValueError('invalid learning_rule')

    train_z_model = theano.function(
        inputs=[index],
        outputs=[cost, cost_dec, cost_rec],
        updates=zupdates,
        givens={x: train_set_x[index * batch_size:(index + 1) * batch_size]})

    z_in = T.tensor4()
    decode_out = theano.function([z_in], osrtdnn.decode(z_in))

    ###############
    # TRAIN MODEL #
    ###############
    print('... training')

    pat_time = np.inf  # wall-clock trigger for a forced checkpoint (disabled)
    first_lr = learning_rate
    st_an = 200        # epoch after which the learning rate is annealed
    best_validation_loss = np.inf
    start_time = time.clock()

    score_cum = []
    score_dec_cum = []
    score_rec_cum = []

    epoch = 0
    while epoch < n_epochs:
        epoch_start_time = time.clock()
        epoch += 1

        if epoch > st_an and learning_rule in ['con', 'mom']:
            # NOTE(review): ``learning_rate`` was passed as a plain float when
            # the updates were built, so rebinding it here cannot change the
            # already-compiled ``train_z_model``. A shared variable would be
            # needed for the annealing to take effect.
            learning_rate = first_lr / (epoch - st_an)

        cost_ij = 0
        cost_dec_ij = 0
        cost_rec_ij = 0
        for minibatch_index in xrange(n_train_batches):
            # One call per minibatch: every call applies the updates, so
            # calling once per output would take three steps per batch and
            # report costs from three different states.
            batch_cost, batch_dec, batch_rec = train_z_model(minibatch_index)
            cost_ij += batch_cost
            cost_dec_ij += batch_dec
            cost_rec_ij += batch_rec

        # NOTE(review): the factor 2 is kept from the original; its purpose is
        # not evident from this function.
        cost_ij /= (2 * n_train_batches)
        cost_dec_ij /= (2 * n_train_batches)
        cost_rec_ij /= (2 * n_train_batches)

        score_cum.append(cost_ij)
        score_dec_cum.append(cost_dec_ij)
        score_rec_cum.append(cost_rec_ij)

        print('%3i, training error %.2f, %.2f, %.2f, %.2fs, %s ' %
              (epoch, cost_ij, cost_dec_ij, cost_rec_ij,
               (time.clock() - epoch_start_time), savename))

        # Checkpoint on an improved cost every 50th epoch, or once the
        # wall-clock budget is exceeded.
        if ((epoch % 50 == 0 and cost_ij < best_validation_loss)
                or time.clock() - start_time > pat_time):
            best_validation_loss = cost_ij
            z_dec = decode_out(z_tot.get_value())

            # NOTE(review): called as ``grid_plot.grid_plot`` here but as a
            # bare ``grid_plot(...)`` in ``gen_wav`` -- confirm which form
            # matches the file's import.
            grid_plot.grid_plot((train_set_x.get_value(), z_dec))
            # plt.legend('test','decoded')
            plt.savefig(savename + '.png')
            plt.close()

            with open(savename + '.pkl', 'wb') as f:
                cPickle.dump(
                    [osrtdnn, z_tot,
                     [score_cum, score_dec_cum, score_rec_cum]], f)

            # Disabled WAV export, kept for reference (needs ``ind``,
            # ``_std``, ``_mean`` and ``wavfile`` to be defined):
            # for i, save_wav in enumerate([z_dec[ind],
            #                               train_set_x.get_value()[ind]]):
            #     x_dec_sav = save_wav*_std+_mean
            #     x_dec_sav = np.asarray(x_dec_sav, dtype=np.int16)
            #     wavfile.write(os.path.splitext(savename)[0]+'_'+str(ind)
            #                   +'_'+str(i)+'.wav', 16000, x_dec_sav)

    print('Optimization complete.')
    print(savename)