def train(args): fnames = glob.glob('../mp3/*.mp3')[:1] traces = [util.loadf(fname) for fname in fnames] traces = np.hstack(traces) dirname = 'save-vrnn' if not os.path.exists(dirname): os.makedirs(dirname) with open(os.path.join(dirname, 'config.pkl'), 'w') as f: cPickle.dump(args, f) model = VRNN(args) # load previously trained model if applicable ckpt = tf.train.get_checkpoint_state(dirname) if ckpt: model.load_model(dirname) with tf.Session() as sess: summary_writer = tf.train.SummaryWriter('logs/'+datetime.now().isoformat().replace(':','-'), sess.graph) check = tf.add_check_numerics_ops() merged = tf.merge_all_summaries() tf.initialize_all_variables().run() saver = tf.train.Saver(tf.all_variables()) start = time.time() for e in xrange(args.num_epochs): sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e))) state = model.initial_state for b in xrange(100): #t0 = np.random.randn(args.batch_size,1,(args.chunk_samples)) #x = np.sin(2*np.pi*(np.arange(args.seq_length)[np.newaxis,:,np.newaxis]/30.+t0)) + np.random.randn(args.batch_size,args.seq_length,(args.chunk_samples))*0.1 #y = np.sin(2*np.pi*(np.arange(1,args.seq_length+1)[np.newaxis,:,np.newaxis]/30.+t0)) + np.random.randn(args.batch_size,args.seq_length,(args.chunk_samples))*0.1 if (e * 100 + b)%int(traces.shape[0]/(args.chunk_samples*args.batch_size)) == 0: data, _, _ = util.load_augment_data(traces,args.chunk_samples) print "Refreshed data" #x,y = next_batch(data,args) slopes = 10*np.random.random((1,1,2*args.chunk_samples))+1 x,y = (slopes*np.arange(args.seq_length)[np.newaxis,:,np.newaxis])-1,(slopes*np.arange(args.seq_length)[np.newaxis,:,np.newaxis]) y[:,:,args.chunk_samples:] = 0. x[:,:,args.chunk_samples:] = 0. feed = {model.input_data: x, model.target_data: y} train_loss, _, cr, summary, sigma = sess.run([model.cost, model.train_op, check, merged, model.sigma], feed) #train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed) summary_writer.add_summary(summary, e * 100 + b) if (e * 100 + b) % args.save_every == 0 and ((e * 100 + b) > 0): checkpoint_path = os.path.join('save', 'model.ckpt') saver.save(sess, checkpoint_path, global_step = e * 100 + b) print "model saved to {}".format(checkpoint_path) end = time.time() print "{}/{} (epoch {}), train_loss = {:.6f}, time/batch = {:.1f}, std = {:.3f}/{:.3f}" \ .format(e * 100 + b, args.num_epochs * 100, e, args.chunk_samples*train_loss, end - start, (sigma[:,200:]).mean(axis=0).mean(axis=0),(sigma[:,:200]).mean(axis=0).mean(axis=0)) start = time.time()
def main(): ''' Main function ''' # Laod the saved arguments with open(os.path.join('save-vrnn', 'config.pkl')) as f: saved_args = cPickle.load(f) # Initialize the model with the saved arguments in inference mode model = VRNN(saved_args, True) # Initialize the TensorFlow session sess = tf.InteractiveSession() # Initialize the saver saver = tf.train.Saver(tf.all_variables()) # Get model checkpoint ckpt = tf.train.get_checkpoint_state('save-vrnn') print "loading model: ", ckpt.model_checkpoint_path # Restore the model from the saved file saver.restore(sess, ckpt.model_checkpoint_path) # Sample the model sample_data, mus, sigmas = model.sample(sess, saved_args)
def train(args): dirname = 'save-vrnn' if not os.path.exists(dirname): os.makedirs(dirname) with open(os.path.join(dirname, 'config.pkl'), 'w') as f: cPickle.dump(args, f) model = VRNN(args) ckpt = tf.train.get_checkpoint_state(dirname) with tf.Session() as sess: summary_writer = tf.train.SummaryWriter('logs/' + datetime.now().isoformat().replace(':', '-'), sess.graph) check = tf.add_check_numerics_ops() merged = tf.merge_all_summaries() tf.initialize_all_variables().run() saver = tf.train.Saver(tf.all_variables()) if ckpt: saver.restore(sess, ckpt.model_checkpoint_path) print "Loaded model" start = time.time() for e in xrange(args.num_epochs): sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e))) state = model.initial_state for b in xrange(100): x, y = next_batch(args) feed = {model.input_data: x, model.target_data: y} train_loss, _, cr, summary, sigma, mu, input, target= sess.run( [model.cost, model.train_op, check, merged, model.sigma, model.mu, model.flat_input, model.target], feed) summary_writer.add_summary(summary, e * 100 + b) if (e * 100 + b) % args.save_every == 0 and ((e * 100 + b) > 0): checkpoint_path = os.path.join(dirname, 'model.ckpt') saver.save(sess, checkpoint_path, global_step=e * 100 + b) print "model saved to {}".format(checkpoint_path) end = time.time() print "{}/{} (epoch {}), train_loss = {:.6f}, time/batch = {:.1f}, std = {:.3f}" \ .format(e * 100 + b, args.num_epochs * 100, e, args.chunk_samples * train_loss, end - start, sigma.mean(axis=0).mean(axis=0)) start = time.time()
help='RNN sequence length') parser.add_argument('--num_epochs', type=int, default=100, help='number of epochs') parser.add_argument('--save_every', type=int, default=500, help='save frequency') parser.add_argument('--grad_clip', type=float, default=10., help='clip gradients at this value') parser.add_argument('--learning_rate', type=float, default=0.0005, help='learning rate') parser.add_argument('--decay_rate', type=float, default=1., help='decay of learning rate') parser.add_argument('--chunk_samples', type=int, default=1, help='number of samples per mdct chunk') args = parser.parse_args() model = VRNN(args) train(args, model)
import tensorflow as tf import os import cPickle from model_vrnn import VRNN import numpy as np from train_vrnn import next_batch with open(os.path.join('save-vrnn', 'config.pkl')) as f: saved_args = cPickle.load(f) model = VRNN(saved_args, True) sess = tf.InteractiveSession() saver = tf.train.Saver(tf.all_variables()) ckpt = tf.train.get_checkpoint_state('save-vrnn') print "loading model: ", ckpt.model_checkpoint_path saver.restore(sess, ckpt.model_checkpoint_path) sample_data, mus, sigmas = model.sample(sess, saved_args)
import tensorflow as tf import os import cPickle from model_vrnn import VRNN import numpy as np from train_vrnn import next_batch with open(os.path.join('save-vrnn', 'config.pkl')) as f: saved_args = cPickle.load(f) model = VRNN(saved_args, True) sess = tf.InteractiveSession() saver = tf.train.Saver(tf.all_variables()) ckpt = tf.train.get_checkpoint_state('save-vrnn') print "loading model: ",ckpt.model_checkpoint_path saver.restore(sess, ckpt.model_checkpoint_path) sample_data,mus,sigmas = model.sample(sess,saved_args)
def train(args): ''' The train function Params: args : Input arguments ''' # Initialize the model model = VRNN(args) # Initialize the data loader object dataloader = DataLoader(args) pdb.set_trace() # Directory to save the trained model dirname = 'save-vrnn' if not os.path.exists(dirname): os.makedirs(dirname) # Dump the input arguments into the file with open(os.path.join(dirname, 'config.pkl'), 'w') as f: cPickle.dump(args, f) # get checkpoint ckpt = tf.train.get_checkpoint_state(dirname) # Initialize TensorFlow session with tf.Session() as sess: # Initialize summary writer summary_writer = tf.train.SummaryWriter( 'logs/' + datetime.now().isoformat().replace(':', '-'), sess.graph) check = tf.add_check_numerics_ops() # Write all summaries merged = tf.merge_all_summaries() # Initialize all variables in the graph tf.initialize_all_variables().run() # Initialize a saver for all variables saver = tf.train.Saver(tf.all_variables()) # Load already saved model if ckpt: saver.restore(sess, ckpt.model_checkpoint_path) print "Loaded model" # Initialize timer start = time.time() # For each epoch for e in xrange(args.num_epochs): # Set the learning rate for this epoch sess.run( tf.assign(model.lr, args.learning_rate * (args.decay_rate**e))) # state = model.initial_state_c, model.initial_state_h # For each minibatch for b in xrange(100): # Get the input and target data of the next minibatch x, y = dataloader.next_batch() # Create the feed dict feed = {model.input_data: x, model.target_data: y} # Run the session and get loss train_loss, _, cr, summary, sigma, mu, input, target = sess.run( [ model.cost, model.train_op, check, merged, model.sigma, model.mu, model.flat_input, model.target ], feed) # Write summary summary_writer.add_summary(summary, e * 100 + b) # Save model if (e * 100 + b) % args.save_every == 0 and ( (e * 100 + b) > 0): checkpoint_path = os.path.join(dirname, 'model.ckpt') saver.save(sess, checkpoint_path, global_step=e * 100 + b) print "model saved to {}".format(checkpoint_path) # End timer end = time.time() # Print info print "{}/{} (epoch {}), train_loss = {:.6f}, time/batch = {:.1f}, std = {:.3f}" \ .format(e * 100 + b, args.num_epochs * 100, e, args.chunk_samples * train_loss, end - start, sigma.mean(axis=0).mean(axis=0)) start = time.time()