Exemple #1
0
def test_valtest(sess, dcmf, epoch, step):
    """Evaluate the model on the first validation and first test epoch file.

    For each of the two splits, the first file returned by
    ``Misc.get_epoch_files`` is wrapped in ``FromDisk.FromDisk``, batched with
    ``dcmf.mfp.batch_size`` and handed to ``test``. The two values ``test``
    returns (logged as "Other" and "Full" MSE) are printed for each split.

    Args:
        sess: session object, forwarded unchanged to ``test``.
        dcmf: model wrapper; only ``dcmf.mfp.val_epochs``,
            ``dcmf.mfp.test_epochs`` and ``dcmf.mfp.batch_size`` are read
            here, the object itself is forwarded to ``test``.
        epoch: current epoch, used only in the log lines.
        step: current step, used only in the log lines.

    Returns:
        The "Full" MSE measured on the validation file.

    Raises:
        IndexError: if ``Misc.get_epoch_files`` returns an empty sequence
            for either split (the first file is taken with ``[0]``).
    """
    # Every log line is built as ONE string before it is printed, so the call
    # parses and emits the same text whether ``print`` is the Python 2
    # statement or the Python 3 function.
    val_epochs = Misc.get_epoch_files(dcmf.mfp.val_epochs)
    print('Testing Perf: Val\t%s %s' % (epoch, step))
    val_file = val_epochs[0]
    print('Val file:  %s' % (val_file,))
    val_iter = FromDisk.FromDisk(val_file).BatchIter(dcmf.mfp.batch_size)

    oth_mse_val, full_mse_val = test(sess, dcmf, val_iter)

    print('Testing MSE Other: Val\t%s %s \t%s' % (epoch, step, oth_mse_val))
    print('Testing MSE Full: Val\t%s %s \t%s' % (epoch, step, full_mse_val))

    test_epochs = Misc.get_epoch_files(dcmf.mfp.test_epochs)
    print('Testing Perf: Test\t%s %s' % (epoch, step))
    test_file = test_epochs[0]
    print('Test file:  %s' % (test_file,))
    test_iter = FromDisk.FromDisk(test_file).BatchIter(dcmf.mfp.batch_size)

    oth_mse_test, full_mse_test = test(sess, dcmf, test_iter)

    print('Testing MSE Other: Test\t%s %s \t%s' % (epoch, step, oth_mse_test))
    print('Testing MSE Full: Test\t%s %s \t%s' % (epoch, step, full_mse_test))

    # Only the validation score is returned; the test score is informational.
    return full_mse_val
def test_valtest(sess, dcmf, epoch, step, save_threshold=1.7):
    """Evaluate on the first validation and test epoch files; advise on saving.

    For each of the two splits, the first file returned by
    ``Misc.get_epoch_files`` is wrapped in ``FromDisk.FromDisk``, batched with
    ``dcmf.mfp.batch_size`` and handed to ``test``. The two values ``test``
    returns (logged as "Other" and "Full" MSE) are printed for each split.

    Args:
        sess: session object, forwarded unchanged to ``test``.
        dcmf: model wrapper; only ``dcmf.mfp.val_epochs``,
            ``dcmf.mfp.test_epochs`` and ``dcmf.mfp.batch_size`` are read
            here, the object itself is forwarded to ``test``.
        epoch: current epoch, used only in the log lines.
        step: current step, used only in the log lines.
        save_threshold: the model is flagged for saving when the validation
            "Full" MSE is strictly below this value. Defaults to 1.7.

    Returns:
        A ``(save, full_mse_val)`` tuple: ``save`` is ``True`` when the
        validation "Full" MSE is below ``save_threshold``, and
        ``full_mse_val`` is that MSE.

    Raises:
        IndexError: if ``Misc.get_epoch_files`` returns an empty sequence
            for either split (the first file is taken with ``[0]``).
    """
    val_epochs = Misc.get_epoch_files(dcmf.mfp.val_epochs)
    print('Testing Perf: Val\t', epoch, step)
    val_file = val_epochs[0]
    print('Val file: ', val_file)
    val_iter = FromDisk.FromDisk(val_file).BatchIter(dcmf.mfp.batch_size)

    oth_mse_val, full_mse_val = test(sess, dcmf, val_iter)

    print('Testing MSE Other: Val\t', epoch, step, '\t', oth_mse_val)
    print('Testing MSE Full: Val\t', epoch, step, '\t', full_mse_val)

    test_epochs = Misc.get_epoch_files(dcmf.mfp.test_epochs)
    print('Testing Perf: Test\t', epoch, step)
    test_file = test_epochs[0]
    print('Test file: ', test_file)
    test_iter = FromDisk.FromDisk(test_file).BatchIter(dcmf.mfp.batch_size)

    oth_mse_test, full_mse_test = test(sess, dcmf, test_iter)

    print('Testing MSE Other: Test\t', epoch, step, '\t', oth_mse_test)
    print('Testing MSE Full: Test\t', epoch, step, '\t', full_mse_test)

    # TODO: derive the saving criterion from the previous MSE instead of a
    # fixed threshold.
    # bool() keeps the flag a plain True/False even if the MSE is a
    # non-builtin scalar type.
    return bool(full_mse_val < save_threshold), full_mse_val
Exemple #3
0
    def BatchIter(self, batch_size):
        '''
        Generator over batches of examples read from ``self.fin``.

        Each non-blank line of ``self.fin`` is split on tabs; the first five
        fields are taken as user, item, rating, user text and item text.

        batch_size = maximum number of u,b,r examples per batch (the last
        batch may be smaller).

        yields a 5-tuple of equally long lists:
        uList: field 0 of each line (users)
        bList: field 1 of each line (items)
        rList: field 2 of each line, converted with float()
        uTextList: field 3 passed through Misc.int_list
        bTextList: field 4 passed through Misc.int_list

        When no more examples can be read, ``self._close()`` is called, the
        accumulated batch generation time is printed and the generator ends.

        raises IndexError / ValueError for a non-blank line with fewer than
        five tab-separated fields or a non-numeric rating.
        '''
        while True:
            # one batch
            start = time.time()
            uList = []
            bList = []
            rList = []
            uTextList = []
            bTextList = []

            for line in self.fin:
                # str.split("\t") never returns an empty list, so blank lines
                # have to be detected on the raw line itself.
                if not line.strip():
                    continue

                vals = line.split("\t")

                uList.append(vals[0])
                bList.append(vals[1])
                rList.append(float(vals[2]))
                uTextList.append(Misc.int_list(vals[3]))
                bTextList.append(Misc.int_list(vals[4]))

                if len(uList) >= batch_size:
                    break

            if not uList:
                # end of data
                self._close()
                print('Total Batch gen time = ', (self.tot_batch / 60.0),
                      ' min')
                # A plain return ends the generator. Raising StopIteration
                # here would surface as RuntimeError on Python 3.7+ (PEP 479)
                # instead of terminating the iteration.
                return

            bg = time.time() - start

            print('Batch gen time = ', bg, ' sec')

            self.tot_batch += bg

            yield uList, bList, rList, uTextList, bTextList
Exemple #4
0
def train(dcmf, savedir):
    """Run the training loop, saving the model whenever its score improves.

    A TensorFlow session is created, the word embedding is unpickled from
    ``dcmf.mfp.word_embedding_file`` and passed to ``dcmf.run_init_all``.
    For each of ``dcmf.mfp.max_epoch`` epochs one train file is chosen
    round-robin from ``Misc.get_epoch_files(dcmf.mfp.train_epochs)`` and
    consumed batch by batch through ``dcmf.run_train_step``.

    Every 500 steps and at the end of every epoch ``test_valtest`` is called;
    if the value it returns is lower than the best seen so far, the model is
    stored with ``save_model`` and the best value is updated.

    Args:
        dcmf: model wrapper exposing ``mfp`` (hyper-parameters and paths),
            ``run_init_all``, ``run_train_step`` and ``get_params``.
        savedir: forwarded to ``save_model``.

    Returns:
        None.
    """
    # All log lines are built as ONE string before printing so the calls
    # parse, and emit the same text, on both Python 2 and Python 3.
    cfg = tf.ConfigProto(allow_soft_placement=True)
    cfg.gpu_options.allow_growth = True

    sess = tf.Session(config=cfg)

    min_MSE = float("inf")

    # read the embedding
    # NOTE(review): unpickling executes code embedded in the file; only point
    # word_embedding_file at trusted data.
    with open(dcmf.mfp.word_embedding_file, "rb") as emb_file:
        emb = pickle.load(emb_file)

    dcmf.run_init_all(sess, emb)
    del emb  # release the embedding once the model has been initialised

    step = 0

    # get the epoch files from the train dir
    train_epochs = Misc.get_epoch_files(dcmf.mfp.train_epochs)
    print('Train Epochs: found  %s  files' % (len(train_epochs),))

    # get the epoch files from the val dir
    val_epochs = Misc.get_epoch_files(dcmf.mfp.val_epochs)
    print('Val Epochs: found  %s  files' % (len(val_epochs),))

    # get the epoch files from the test dir
    test_epochs = Misc.get_epoch_files(dcmf.mfp.test_epochs)
    print('Test Epochs: found  %s  files' % (len(test_epochs),))

    # load the revAB from the train file
    dp_mgr = DataPairMgr.DataPairMgr(dcmf.mfp.train_data)

    names = True  # print the variable lists once, after the first step
    for epoch in range(dcmf.mfp.max_epoch):
        print('Epoch:  %s' % (epoch,))

        train_time = 0

        # cycle through the available train files
        train_file = train_epochs[epoch % len(train_epochs)]
        print('Train file:  %s' % (train_file,))

        trainIter = FromDisk.FromDisk(train_file)
        batch_iter = trainIter.BatchIter(dcmf.mfp.batch_size)

        while True:
            # NOTE: step advances before the fetch, so the pass that finds
            # the iterator exhausted is counted as a step too.
            step += 1
            try:
                # read the values; next() works on Python 2.6+ and Python 3
                uList, bList, rList, retUTextInt, retBTextInt = next(
                    batch_iter)
            except StopIteration:
                # end of this data epoch
                break

            # get the revAB
            revABList = [
                dp_mgr.get_int_review(u, b) for u, b in zip(uList, bList)
            ]

            start = time.time()
            act_rmse, oth_rmse, full_rmse = dcmf.run_train_step(
                sess, rList, retUTextInt, retBTextInt, revABList,
                dcmf.mfp.dropout_keep_prob)
            end = time.time()

            if names:
                names = False
                names_all, names_act, names_oth, names_full = dcmf.get_params()
                print('Variables - all trainable:  %s' % (names_all,))
                print('Variables trained in act:  %s' % (names_act,))
                print('Variables trained in oth:  %s' % (names_oth,))
                print('Variables trained in full:  %s' % (names_full,))

            tt = end - start
            print('Train time  %s  sec' % (tt,))
            train_time += tt

            print('Step  %s  act:  %s' % (step, act_rmse))
            print('Step  %s  oth:  %s' % (step, oth_rmse))
            print('Step  %s  full:  %s' % (step, full_rmse))

            if step % 500 == 0:
                mse = test_valtest(sess, dcmf, epoch, step)

                if mse < min_MSE:
                    # save the current model
                    min_MSE = mse
                    save_model(sess, dcmf, mse, savedir)

        print('End of Epoch Testing')
        mse = test_valtest(sess, dcmf, epoch, step)
        if mse < min_MSE:
            # save the current model; the best score must be updated here as
            # well, otherwise a later, worse model could still pass the
            # comparison and be saved.
            min_MSE = mse
            save_model(sess, dcmf, mse, savedir)

        sys.stdout.flush()
def train(dcmf, savedir):
    """Run the training loop, saving the model periodically.

    A TensorFlow session is created, the word embedding is unpickled from
    ``dcmf.mfp.word_embedding_file`` and passed to ``dcmf.run_init_all``.
    For each of ``dcmf.mfp.max_epoch`` epochs one train file is chosen
    round-robin from ``Misc.get_epoch_files(dcmf.mfp.train_epochs)`` and
    consumed batch by batch through ``dcmf.run_train_step``.

    The model is saved unconditionally every 1000 steps (tagged with the
    latest training "full" score) and at the end of every epoch (tagged with
    the score returned by ``test_valtest``). The save flag returned by
    ``test_valtest`` is currently ignored.

    Args:
        dcmf: model wrapper exposing ``mfp`` (hyper-parameters and paths),
            ``run_init_all``, ``run_train_step`` and ``get_params``.
        savedir: forwarded to ``save_model``.

    Returns:
        None.
    """
    cfg = tf.ConfigProto(allow_soft_placement=True)
    cfg.gpu_options.allow_growth = True

    sess = tf.Session(config=cfg)

    # read the embedding
    # NOTE(review): unpickling executes code embedded in the file; only point
    # word_embedding_file at trusted data.
    with open(dcmf.mfp.word_embedding_file, "rb") as emb_file:
        emb = pickle.load(emb_file, encoding="bytes")

    dcmf.run_init_all(sess, emb)
    del emb  # release the embedding once the model has been initialised

    step = 0

    # get the epoch files from the train dir
    train_epochs = Misc.get_epoch_files(dcmf.mfp.train_epochs)
    print('Train Epochs: found ', len(train_epochs), ' files')

    # get the epoch files from the val dir
    val_epochs = Misc.get_epoch_files(dcmf.mfp.val_epochs)
    print('Val Epochs: found ', len(val_epochs), ' files')

    # get the epoch files from the test dir
    test_epochs = Misc.get_epoch_files(dcmf.mfp.test_epochs)
    print('Test Epochs: found ', len(test_epochs), ' files')

    # No DataPairMgr here: the batch iterator yields revABList itself.

    names = True  # print the variable lists once, after the first step
    for epoch in range(dcmf.mfp.max_epoch):
        print('Epoch: ', epoch)

        # accumulated per epoch but not reported anywhere in this function
        train_time = 0

        # cycle through the available train files
        train_file = train_epochs[epoch % len(train_epochs)]
        print('Train file: ', train_file)

        trainIter = FromDisk.FromDisk(train_file, dcmf.mfp.review_delim,
                                      dcmf.mfp.review_emb)
        batch_iter = trainIter.BatchIter(dcmf.mfp.batch_size)

        while True:
            # NOTE: step advances before the fetch, so the pass that finds
            # the iterator exhausted is counted as a step too.
            step += 1
            try:
                # read the values, including the revAB
                (uList, bList, rList, retUTextInt, retBTextInt,
                 revABList) = next(batch_iter)
            except StopIteration:
                # end of this data epoch
                break

            start = time.time()
            act_rmse, oth_rmse, full_rmse = dcmf.run_train_step(
                sess, rList, retUTextInt, retBTextInt, revABList,
                dcmf.mfp.dropout_keep_prob)
            end = time.time()

            if names:
                names = False
                names_all, names_act, names_oth, names_full = dcmf.get_params()
                print('Variables - all trainable: ', names_all)
                print('Variables trained in act: ', names_act)
                print('Variables trained in oth: ', names_oth)
                print('Variables trained in full: ', names_full)

            train_time += end - start

            print('Step ', step, ' full: ', full_rmse)

            if step % 1000 == 0:
                # Periodic checkpoint, saved before (and regardless of) the
                # evaluation; the evaluation result is only logged.
                print('SAVING MODEL:')
                save_model(sess, dcmf, full_rmse, savedir)
                _save, _mse = test_valtest(sess, dcmf, epoch, step)

        print('End of Epoch Testing')
        _save, mse = test_valtest(sess, dcmf, epoch, step)
        # save the current model regardless of the save flag
        save_model(sess, dcmf, mse, savedir)

        sys.stdout.flush()