def main():
    """Train an RBM recommender on the ratings data and export its scores.

    All settings come from the module-level ``args`` namespace (presumably the
    parsed command-line arguments defined elsewhere in this file — confirm):
    ``data_dir``, ``rows``, ``free_energy``, ``num_hid``, ``user``, ``alpha``,
    ``epochs``, ``batch_size`` and ``verbose``.

    Steps:
      1. read ratings / to_read / books from ``args.data_dir`` and reduce the
         ratings with ``Util.clean_subset(ratings, args.rows)``;
      2. preprocess the ratings into a training set and, when
         ``args.free_energy`` is truthy, split off a validation set;
      3. train the RBM, then score read / unread items for ``args.user``;
      4. hand both score tables to ``RBM.export``.
    """
    util = Util()
    data_dir = args.data_dir  # avoid shadowing the builtin ``dir``
    ratings, to_read, books = util.read_data(data_dir)
    ratings = util.clean_subset(ratings, args.rows)
    # NOTE(review): the visible-layer size is the number of rating rows kept
    # after clean_subset — confirm this is what RBM expects.
    num_vis = len(ratings)

    free_energy = args.free_energy
    train = util.preprocess(ratings)
    valid = None
    # A validation split is only produced when free_energy is requested;
    # otherwise ``valid`` is passed to training as None.
    if free_energy:
        train, valid = util.split_data(train)

    user = args.user
    rbm = RBM(args.alpha, args.num_hid, num_vis)
    # Only the first return value (the reconstruction) is used for scoring;
    # the other three are ignored here.
    reco, _prv_w, _prv_vb, _prv_hb = rbm.training(
        train, valid, user, args.epochs, args.batch_size,
        free_energy, args.verbose,
    )
    unread, read = rbm.calculate_scores(ratings, books, to_read, reco, user)
    rbm.export(unread, read)
def _most_similar_user(ratings, attractions, cat_rating):
    """Return the user_id whose rating history scores lowest against ``cat_rating``.

    ratings: DataFrame with at least ``attraction_id``, ``user_id`` and
        ``rating`` columns (the columns this function reads).
    attractions: DataFrame with ``attraction_id`` and ``category`` columns.
    cat_rating: target profile; it is placed unchanged in every row's
        ``user_data`` cell and compared by ``sim_score`` (defined elsewhere).
    """
    # Attach each rated attraction's category to its rating row.
    categories = attractions[["attraction_id", "category"]].set_index("attraction_id")
    joined = (
        ratings.set_index('attraction_id')
        .join(categories)
        .reset_index('attraction_id')
    )

    # One row per user: the list of categories and the list of ratings.
    grouped = joined.groupby('user_id')
    category_df = grouped['category'].apply(list).reset_index()
    rating_df = grouped['rating'].apply(list).reset_index()
    cat_rat_df = category_df.set_index('user_id').join(rating_df.set_index('user_id'))

    # ``f`` (defined elsewhere in this module) is applied row-wise to each
    # user's (category list, rating list) pair.
    cat_rat_df['cat_rat'] = cat_rat_df.apply(f, axis=1)
    cat_rat_df = cat_rat_df.reset_index()[['user_id', 'cat_rat']]

    # Every user is compared against the same target profile.
    cat_rat_df['user_data'] = [cat_rating] * len(cat_rat_df)
    cat_rat_df['sim_score'] = cat_rat_df.apply(sim_score, axis=1)

    # NOTE(review): an ascending sort followed by taking the first row picks
    # the LOWEST sim_score, i.e. sim_score is treated as a distance. If
    # sim_score is a similarity (higher = closer) this selects the least
    # similar user — confirm against sim_score's definition.
    # Column 0 is 'user_id' because of the column selection above.
    return cat_rat_df.sort_values(['sim_score']).values[0][0]


def get_recc(att_df, cat_rating, *, epochs=50, rows=40000, alpha=0.01,
             num_hid=128, batch_size=16, data_dir='etl/'):
    """Recommend attractions by proxy through the most similar existing user.

    Loads the ratings from ``data_dir``, finds the user whose history best
    matches ``cat_rating`` (see ``_most_similar_user``), loads a saved RBM
    model identified by the hyperparameters, scores seen/unseen attractions
    for that user and exports the result under ``rbm_models/``.

    Args:
        att_df: unused by this function; kept for caller compatibility.
        cat_rating: target profile passed to ``sim_score`` for every user.
        epochs, rows, alpha, num_hid, batch_size: hyperparameters. ``rows``
            is passed to ``Util.clean_subset``; ``alpha``/``num_hid`` build
            the RBM; all five also form the saved-model filename.
        data_dir: directory handed to ``Util.read_data``.

    Returns:
        (filename, user, rbm_att): the model filename stem, the similar
        user's id, and the first value returned by ``Util.preprocess``.
    """
    util = Util()
    ratings, attractions = util.read_data(data_dir)
    ratings = util.clean_subset(ratings, rows)
    rbm_att, train = util.preprocess(ratings)
    num_vis = len(ratings)
    rbm = RBM(alpha, num_hid, num_vis)

    user = _most_similar_user(ratings, attractions, cat_rating)
    print("Similar User: {u}".format(u=user))

    # Filename stem encodes the hyperparameters; presumably it must match the
    # name the model was saved under — confirm against the training code.
    filename = f"e{epochs}_r{rows}_lr{alpha}_hu{num_hid}_bs{batch_size}"
    # Only the predictions are used; the loaded weights/biases are ignored.
    reco, _weights, _vb, _hb = rbm.load_predict(filename, train, user)
    unseen, seen = rbm.calculate_scores(ratings, attractions, reco, user)
    rbm.export(unseen, seen, 'rbm_models/' + filename, str(user))
    return filename, user, rbm_att