import os
import sys

import numpy as np
import tensorflow as tf
from ray import tune

import nvlda
import prodlda
import prodlda2


def train(network_architecture, minibatches, type, learning_rate=0.001,
          batch_size=200, training_epochs=100, display_step=5):
    tf.reset_default_graph()
    print("type = ", type)
    vae = None
    if type == 'prodlda':
        vae = prodlda.VAE(network_architecture, learning_rate=learning_rate,
                          batch_size=batch_size)
    elif type == 'nvlda':
        vae = nvlda.VAE(network_architecture, learning_rate=learning_rate,
                        batch_size=batch_size)
    elif type == 'prodlda2':
        vae = prodlda2.VAE(network_architecture, learning_rate=learning_rate,
                           batch_size=batch_size)
    emb = 0
    # Training cycle
    for epoch in range(training_epochs):
        avg_cost = 0.
        # n_samples_tr: total number of training documents (defined elsewhere)
        total_batch = int(n_samples_tr / batch_size)
        # Loop over all batches
        for i in range(total_batch):
            batch_xs = next(minibatches)
            # Fit training using batch data
            cost, emb = vae.partial_fit(batch_xs)
            # Compute average loss
            avg_cost += cost / n_samples_tr * batch_size

            if np.isnan(avg_cost):
                print(epoch, i, np.sum(batch_xs, 1).astype(np.int), batch_xs.shape)
                print('Encountered NaN, stopping training. '
                      'Please check the learning_rate settings and the momentum.')
                sys.exit()

        # Display logs per epoch step
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1),
                  "cost=", "{:.9f}".format(avg_cost))
    return vae, emb
def train(config):
    """Ray Tune trainable: trains the ProdLDA model for one hyperparameter config."""
    tf.reset_default_graph()
    # load_data, evaluate, t (number of topics) and save_path are assumed to be
    # defined elsewhere in the module.
    vocab, docs_tr = load_data('train')
    vocab, docs_va_h1, docs_va_h2 = load_data('valid')
    network_architecture = {
        'n_hidden_recog_1': config['layer1'],
        'n_hidden_recog_2': config['layer2'],
        'n_hidden_gener_1': docs_tr.shape[1],
        'n_input': docs_tr.shape[1],
        'n_z': int(t)
    }
    vae = prodlda.VAE(network_architecture,
                      learning_rate=config['learning_rate'],
                      batch_size=config['batch_size'],
                      keep_prob=config['keep_prob'])
    saver = tf.train.Saver()
    vae.sess.graph.finalize()

    for epoch in range(config['epochs']):
        avg_cost = 0.
        # Shuffle the training documents and iterate over minibatches
        indices = np.random.permutation(docs_tr.shape[0])
        for base in range(0, docs_tr.shape[0], config['batch_size']):
            batch_xs = docs_tr[indices[base:min(base + config['batch_size'],
                                                docs_tr.shape[0])]]
            cost = vae.partial_fit(batch_xs)
            # Compute average loss
            avg_cost += cost / docs_tr.shape[0] * config['batch_size']
            if np.isnan(avg_cost):
                # Loss diverged (NaN); abort this trial
                return

        # Topic-word matrix: kernel of the final (beta) layer
        emb = [v for v in tf.trainable_variables()
               if v.name == 'beta/kernel:0'][0].eval(vae.sess)
        perplexity = evaluate(vae, emb, docs_va_h1, docs_va_h2, 'valid', config)

        saver.save(vae.sess, os.path.join(save_path, 'model.ckpt'))
        # Report validation perplexity to Ray Tune after every epoch
        tune.track.log(validation_perplexity=perplexity)
    return
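# --- Illustrative usage (assumption, not part of the original code) ----------
# A minimal sketch of how the Ray Tune trainable above might be launched. The
# config keys mirror the ones read inside train(); the candidate values and the
# helper name run_tuning() are hypothetical.
def run_tuning():
    search_space = {
        'layer1': tune.grid_search([100, 200]),
        'layer2': tune.grid_search([100, 200]),
        'learning_rate': tune.grid_search([0.001, 0.002]),
        'batch_size': 200,
        'keep_prob': 0.75,
        'epochs': 100,
    }
    # Each sampled config is passed to train(config); progress is reported via
    # tune.track.log() inside the training loop.
    tune.run(train, config=search_space)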
# Test-time evaluation: rebuild the network, restore a trained checkpoint from
# `load_from`, and compute perplexity on the held-out test set.
vocab, docs_tr = load_data('train')
vocab, docs_te_h1, docs_te_h2 = load_data('test')
network_architecture = {
    'n_hidden_recog_1': config['layer1'],
    'n_hidden_recog_2': config['layer2'],
    'n_hidden_gener_1': docs_tr.shape[1],
    'n_input': docs_tr.shape[1],
    'n_z': int(t)
}
vae = prodlda.VAE(network_architecture,
                  learning_rate=config['learning_rate'],
                  batch_size=config['batch_size'],
                  keep_prob=config['keep_prob'])
saver = tf.train.Saver()
saver.restore(vae.sess, load_from)
print("Model restored.")
# Topic-word matrix: kernel of the final (beta) layer
emb = [v for v in tf.trainable_variables()
       if v.name == 'beta/kernel:0'][0].eval(vae.sess)
perplexity = evaluate(vae, emb, docs_te_h1, docs_te_h2, 'test', config)