def test_bpr_train_stores_data():
    """Training keeps track of the users, items and per-user items it saw.

    The fixture is a list of (user, item) pairs; after training, the model's
    private bookkeeping attributes must reflect exactly those pairs.
    """
    pairs = [(0, 1), (0, 2), (1, 0), (1, 2)]
    expected_users = {0, 1}
    expected_items = {0, 1, 2}
    # Items grouped by the user they were paired with, in input order.
    expected_by_user = {0: [1, 2], 1: [0, 2]}

    model = BPR(1, 2, 3)
    model.train(pairs, batch_size=4)

    assert_equal(model._train_users, expected_users)
    assert_equal(model._train_items, expected_items)
    assert_equal(model._train_dict, expected_by_user)
def test_bpr_predictions():
    """Check the shape and mutual consistency of the prediction helpers.

    After one epoch on random (user, item) pairs:
      * ``predictions(user)`` has shape ``(50,)``, matching the third
        constructor argument;
      * ``prediction(user, item)`` equals the matching entry of
        ``predictions(user)``;
      * ``top_predictions(user, topn=20)`` returns 20 entries.
    """
    bpr = BPR(10, 100, 50)
    # Materialise the pairs: on Python 3 ``zip`` yields a one-shot iterator
    # with no length, whereas a list is a reusable, sized sequence on both
    # Python 2 and Python 3.
    train_data = list(zip(randint(100, size=1000), randint(50, size=1000)))
    bpr.train(train_data, epochs=1)

    assert_equal(bpr.predictions(0).shape, (50,))
    assert_equal(bpr.prediction(0, 0), bpr.predictions(0)[0])
    assert_equal(len(bpr.top_predictions(0, topn=20)), 20)
def test_bpr_train_and_test():
    """Train on random pairs, then score the model on seen and unseen data.

    The score on the training pairs must exceed 0.8, while the score on an
    independent random sample must fall strictly between 0.4 and 0.6.
    NOTE(review): ``BPR.test`` presumably returns an AUC-like value in
    [0, 1] -- confirm against its implementation.
    """
    bpr = BPR(10, 200, 50)
    # The pairs are used twice (train, then test), so they must be a list:
    # a Python 3 ``zip`` iterator would already be exhausted by ``train``.
    train_data = list(zip(randint(100, size=1000), randint(50, size=1000)))
    bpr.train(train_data, batch_size=50)
    assert bpr.test(train_data) > 0.8

    test_data = list(zip(randint(100, size=1000), randint(50, size=1000)))
    # Score once so both bounds are checked against the same value.
    test_score = bpr.test(test_data)
    assert 0.4 < test_score < 0.6
def test_bpr_train_no_epochs():
    """With ``epochs=0`` the score on the training pairs stays in (0.4, 0.6).

    NOTE(review): ``BPR.test`` presumably returns an AUC-like value in
    [0, 1] -- confirm against its implementation.
    """
    bpr = BPR(10, 100, 50)
    # The pairs are used twice (train, then test), so they must be a list:
    # a Python 3 ``zip`` iterator would already be exhausted by ``train``.
    train_data = list(zip(randint(100, size=1000), randint(50, size=1000)))
    bpr.train(train_data, epochs=0)

    # Score once so both bounds are checked against the same value.
    score = bpr.test(train_data)
    assert 0.4 < score < 0.6
session = sessions[s] ssl = usl[s] + 1 for i in range(ssl): a.append((user, session[i][1])) return a print("converting training data to array") a = convert_data(train, train_sl) training_array, uti, iti = load_data_from_array(a) a = convert_data(test, test_sl) testing_array, uti, iti = load_data_from_array(a, uti, iti) print("creating BPR model") bpr = BPR(embedding_size, len(list(uti.keys())), len(list(iti.keys())), learning_rate=learning_rate) print("training model") split2 = int(len(training_array) / 2) split1 = int(split2 / 2) split3 = split1 + split2 users = train.keys() session_batch = [] sl = [] for user in users: session_batch.append(test[user][0]) sl.append(test_sl[user][0]) session_batch = [[event[1] for event in session] for session in session_batch] for s in range(len(session_batch)):
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example script: train a BPR model on one CSV file, then print the score
# it obtains on a second CSV file.

import sys

from theano_bpr import BPR
from theano_bpr.utils import load_data_from_csv

# Exactly two arguments are expected: the training and the testing CSV.
if len(sys.argv) != 3:
    # Parenthesised single-argument print parses and behaves the same on
    # Python 2 and Python 3.
    print("Usage: ./example.py training_data.csv testing_data.csv")
    sys.exit(1)

# Loading train data
train_data, users_to_index, items_to_index = load_data_from_csv(sys.argv[1])
# Loading test data; the index maps from the training load are passed in and
# rebound to the maps returned by this second load.
test_data, users_to_index, items_to_index = load_data_from_csv(
    sys.argv[2], users_to_index, items_to_index)
# Initialising BPR model, 10 latent factors
bpr = BPR(10, len(users_to_index), len(items_to_index))
# Training model, 30 epochs
bpr.train(train_data, epochs=30)
# Testing model
print(bpr.test(test_data))
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example script: train a BPR model on one CSV file, then print the score
# it obtains on a second CSV file.

import sys

from theano_bpr import BPR
from theano_bpr.utils import load_data_from_csv

# Exactly two arguments are expected: the training and the testing CSV.
if len(sys.argv) != 3:
    # Parenthesised single-argument print parses and behaves the same on
    # Python 2 and Python 3.
    print("Usage: ./example.py training_data.csv testing_data.csv")
    sys.exit(1)

# Loading train data
train_data, users_to_index, items_to_index = load_data_from_csv(sys.argv[1])
# Loading test data; the index maps from the training load are passed in and
# rebound to the maps returned by this second load.
test_data, users_to_index, items_to_index = load_data_from_csv(
    sys.argv[2], users_to_index, items_to_index)
# Initialising BPR model, 10 latent factors
bpr = BPR(10, len(users_to_index), len(items_to_index))
# Training model, 30 epochs
bpr.train(train_data, epochs=30)
# Testing model
print(bpr.test(test_data))