Exemplo n.º 1
0
def test_valtest(sess, dcmf, epoch, step):  
    val_epochs = get_epoch_files(dcmf.mfp.val_epochs)      
    print 'Testing Perf: Val\t', epoch, step
    val_file = val_epochs[0]
    print 'Val file: ', val_file
    v_val_review = FromDisk.FromDisk(val_file)
    val_iter = v_val_review.BatchIter(dcmf.mfp.batch_size)    
    
    oth_mse_val, full_mse_val = test(sess, dcmf, val_iter)

    print 'Testing MSE Other: Val\t', epoch, step, '\t', oth_mse_val
    print 'Testing MSE Full: Val\t', epoch, step, '\t', full_mse_val
    
    test_epochs = get_epoch_files(dcmf.mfp.test_epochs)
    print 'Testing Perf: Test\t', epoch, step
    test_file = test_epochs[0]
    print 'Test file: ', test_file
    v_test_review = FromDisk.FromDisk(test_file)
    test_iter = v_test_review.BatchIter(dcmf.mfp.batch_size)    

    oth_mse_test, full_mse_test = test(sess, dcmf, test_iter)
    
    print 'Testing MSE Other: Test\t', epoch, step, '\t', oth_mse_test
    print 'Testing MSE Full: Test\t', epoch, step, '\t', full_mse_test
Exemplo n.º 2
0
def train(dcmf):
    
    cfg = tf.ConfigProto(allow_soft_placement=True )
    cfg.gpu_options.allow_growth = True
    
    sess = tf.Session(config=cfg)


    #read the embedding
    emb = pickle.load( open( dcmf.mfp.word_embedding_file, "rb" ) )
    
    dcmf.run_init_all(sess, emb)
    del emb
    
    step = 0
    
    #get the epoch files from the train dir
    train_epochs = get_epoch_files(dcmf.mfp.train_epochs)
    print 'Train Epochs: found ', len(train_epochs), ' files'
        
    #get the epoch files from the val dir
    val_epochs = get_epoch_files(dcmf.mfp.val_epochs)
    print 'Val Epochs: found ', len(val_epochs), ' files'
    

    #get the epoch files from the test dir
    test_epochs = get_epoch_files(dcmf.mfp.test_epochs)
    print 'Test Epochs: found ', len(test_epochs), ' files'
    
    #load the revAB from the train file
    dp_mgr = DataPairMgr.DataPairMgr(dcmf.mfp.train_data)
    
    names = True
    for epoch in range(dcmf.mfp.max_epoch):
        print 'Epoch: ', epoch
        
        train_time = 0
        
        train_file = train_epochs[epoch % len(train_epochs)]
        print 'Train file: ', train_file        
        
        trainIter = FromDisk.FromDisk(train_file)
        batch_iter = trainIter.BatchIter(dcmf.mfp.batch_size)
        
        while True:
            step += 1
            uList, bList, rList, retUTextInt, retBTextInt, revABList  = ([] for i in range(6))
            try:
                #read the values
                uList, bList, rList, retUTextInt, retBTextInt  = batch_iter.next()
                #get the revAB 
                revABList = [ dp_mgr.get_int_review(u, b) for u,b in zip(uList, bList) ]
                
            except StopIteration:
                #end of this data epoch
                break
            
            start = time.time()
            act_rmse, oth_rmse, full_rmse = dcmf.run_train_step(sess, uList, bList, rList, retUTextInt, retBTextInt, revABList, dcmf.mfp.dropout_keep_prob)
            end = time.time()
            
            if names:
                names = False
                names_all, names_act, names_oth, names_full = dcmf.get_params()
                print 'Variables - all trainable: ', names_all
                print 'Variables trained in act: ', names_act
                print 'Variables trained in oth: ', names_oth
                print 'Variables trained in full: ', names_full
            
            
            tt= end - start
            print 'Train time ', (tt), ' sec'
            train_time += tt
            
            
            print 'Step ', step, ' act: ', act_rmse
            print 'Step ', step, ' oth: ', oth_rmse
            print 'Step ', step, ' full: ', full_rmse
            
            if step % 1000 == 0:
                test_valtest(sess, dcmf, epoch, step)
            
        
        print 'End of Epoch Testing'
        test_valtest(sess, dcmf, epoch, step)
        
        sys.stdout.flush()