def test_valtest(sess, dcmf, epoch, step): val_epochs = get_epoch_files(dcmf.mfp.val_epochs) print 'Testing Perf: Val\t', epoch, step val_file = val_epochs[0] print 'Val file: ', val_file v_val_review = FromDisk.FromDisk(val_file) val_iter = v_val_review.BatchIter(dcmf.mfp.batch_size) oth_mse_val, full_mse_val = test(sess, dcmf, val_iter) print 'Testing MSE Other: Val\t', epoch, step, '\t', oth_mse_val print 'Testing MSE Full: Val\t', epoch, step, '\t', full_mse_val test_epochs = get_epoch_files(dcmf.mfp.test_epochs) print 'Testing Perf: Test\t', epoch, step test_file = test_epochs[0] print 'Test file: ', test_file v_test_review = FromDisk.FromDisk(test_file) test_iter = v_test_review.BatchIter(dcmf.mfp.batch_size) oth_mse_test, full_mse_test = test(sess, dcmf, test_iter) print 'Testing MSE Other: Test\t', epoch, step, '\t', oth_mse_test print 'Testing MSE Full: Test\t', epoch, step, '\t', full_mse_test
def train(dcmf): cfg = tf.ConfigProto(allow_soft_placement=True ) cfg.gpu_options.allow_growth = True sess = tf.Session(config=cfg) #read the embedding emb = pickle.load( open( dcmf.mfp.word_embedding_file, "rb" ) ) dcmf.run_init_all(sess, emb) del emb step = 0 #get the epoch files from the train dir train_epochs = get_epoch_files(dcmf.mfp.train_epochs) print 'Train Epochs: found ', len(train_epochs), ' files' #get the epoch files from the val dir val_epochs = get_epoch_files(dcmf.mfp.val_epochs) print 'Val Epochs: found ', len(val_epochs), ' files' #get the epoch files from the test dir test_epochs = get_epoch_files(dcmf.mfp.test_epochs) print 'Test Epochs: found ', len(test_epochs), ' files' #load the revAB from the train file dp_mgr = DataPairMgr.DataPairMgr(dcmf.mfp.train_data) names = True for epoch in range(dcmf.mfp.max_epoch): print 'Epoch: ', epoch train_time = 0 train_file = train_epochs[epoch % len(train_epochs)] print 'Train file: ', train_file trainIter = FromDisk.FromDisk(train_file) batch_iter = trainIter.BatchIter(dcmf.mfp.batch_size) while True: step += 1 uList, bList, rList, retUTextInt, retBTextInt, revABList = ([] for i in range(6)) try: #read the values uList, bList, rList, retUTextInt, retBTextInt = batch_iter.next() #get the revAB revABList = [ dp_mgr.get_int_review(u, b) for u,b in zip(uList, bList) ] except StopIteration: #end of this data epoch break start = time.time() act_rmse, oth_rmse, full_rmse = dcmf.run_train_step(sess, uList, bList, rList, retUTextInt, retBTextInt, revABList, dcmf.mfp.dropout_keep_prob) end = time.time() if names: names = False names_all, names_act, names_oth, names_full = dcmf.get_params() print 'Variables - all trainable: ', names_all print 'Variables trained in act: ', names_act print 'Variables trained in oth: ', names_oth print 'Variables trained in full: ', names_full tt= end - start print 'Train time ', (tt), ' sec' train_time += tt print 'Step ', step, ' act: ', act_rmse print 'Step ', step, ' oth: ', oth_rmse print 'Step ', step, ' full: ', full_rmse if step % 1000 == 0: test_valtest(sess, dcmf, epoch, step) print 'End of Epoch Testing' test_valtest(sess, dcmf, epoch, step) sys.stdout.flush()