-
Notifications
You must be signed in to change notification settings - Fork 0
/
RBMAlgorithm.py
59 lines (47 loc) · 2.28 KB
/
RBMAlgorithm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 15 11:54:29 2020
@author: Rakin Shahriar
"""
from surprise import AlgoBase
from surprise import PredictionImpossible
import numpy as np
from RBM import RBM
class RBMAlgorithm(AlgoBase):
    """Surprise prediction algorithm backed by a Restricted Boltzmann Machine.

    Each user's ratings are one-hot encoded into ``_RATING_BUCKETS`` bins per
    item (bin ``k`` stands for a rating of ``(k + 1) / 2``, i.e. 0.5 ... 5.0),
    flattened into one row per user and used to train ``RBM``. After
    training, every (user, item) score is the expected rating under a softmax
    over the per-bucket scores returned by ``RBM.GetRecommendations``.
    """

    # Number of one-hot rating bins per item. Matches the ``rating * 2 - 1``
    # encoding in _build_training_matrix (0.5 -> 0, ..., 5.0 -> 9).
    _RATING_BUCKETS = 10

    def __init__(self, epochs=40, hiddenDim=100, learningRate=0.001,
                 batchSize=100, sim_options=None):
        """Store the RBM hyper-parameters.

        Args:
            epochs: Number of training epochs forwarded to ``RBM``.
            hiddenDim: Number of hidden units forwarded to ``RBM``.
            learningRate: Learning rate forwarded to ``RBM``.
            batchSize: Mini-batch size forwarded to ``RBM``.
            sim_options: Accepted for signature compatibility with other
                Surprise algorithms; not used. Defaults to ``None`` rather
                than a shared mutable ``{}``.
        """
        AlgoBase.__init__(self)
        self.epochs = epochs
        self.hiddenDim = hiddenDim
        self.learningRate = learningRate
        self.batchSize = batchSize

    def softmax(self, x):
        """Return the softmax of ``x`` computed along axis 0.

        The maximum along axis 0 is subtracted before exponentiating. This
        leaves the result mathematically unchanged but keeps ``np.exp`` from
        overflowing to ``inf`` (which would turn the quotient into ``nan``)
        when the input scores are large.
        """
        x = np.asarray(x)
        exps = np.exp(x - np.max(x, axis=0, keepdims=True))
        return exps / np.sum(exps, axis=0, keepdims=True)

    def fit(self, trainset):
        """Train the RBM on ``trainset`` and cache a prediction for every pair.

        Args:
            trainset: A Surprise ``Trainset``.

        Returns:
            ``self``, as Surprise's ``AlgoBase.fit`` convention expects.

        Raises:
            ValueError: if a training rating falls outside the 0.5-5.0 scale
                that the one-hot encoding supports.
        """
        AlgoBase.fit(self, trainset)

        numUsers = trainset.n_users
        numItems = trainset.n_items

        trainingMatrix = self._build_training_matrix(trainset, numUsers, numItems)

        # Previously self.epochs was stored but never passed on, so the
        # ``epochs`` constructor argument had no effect.
        # NOTE(review): assumes RBM's constructor accepts an ``epochs``
        # keyword -- confirm against RBM.py.
        rbm = RBM(trainingMatrix.shape[1],
                  epochs=self.epochs,
                  hiddenDimensions=self.hiddenDim,
                  learningRate=self.learningRate,
                  batchSize=self.batchSize)
        rbm.Train(trainingMatrix)

        self.predictedRatings = self._predict_ratings(
            rbm, trainingMatrix, numUsers, numItems)
        return self

    def _build_training_matrix(self, trainset, numUsers, numItems):
        """Return a float32 (numUsers, numItems * _RATING_BUCKETS) one-hot matrix."""
        buckets = self._RATING_BUCKETS
        matrix = np.zeros([numUsers, numItems, buckets], dtype=np.float32)
        for uid, iid, rating in trainset.all_ratings():
            # 0.5 -> 0, 1.0 -> 1, ..., 5.0 -> 9.
            bucket = int(float(rating) * 2.0) - 1
            # A negative index would silently wrap around to the 5.0 bucket,
            # so reject anything the encoding cannot represent.
            if not 0 <= bucket < buckets:
                raise ValueError(
                    f"Rating {rating} is outside the supported 0.5-5.0 scale")
            matrix[int(uid), int(iid), bucket] = 1
        return np.reshape(matrix, [numUsers, -1])

    def _predict_ratings(self, rbm, trainingMatrix, numUsers, numItems):
        """Return a float32 (numUsers, numItems) matrix of expected ratings."""
        buckets = self._RATING_BUCKETS
        bucketIndices = np.arange(buckets)
        predicted = np.zeros([numUsers, numItems], dtype=np.float32)
        for userID in range(numUsers):
            if userID % 50 == 0:
                print("Processing user ", userID)
            recs = rbm.GetRecommendations([trainingMatrix[userID]])
            # One row per item, one column per rating bucket.
            recs = np.reshape(recs, [numItems, buckets])
            # softmax normalizes along axis 0, so transpose to put buckets
            # there: probs is (buckets, numItems), each column summing to 1.
            probs = self.softmax(recs.T)
            # Expected bucket index per item (0..9), mapped back onto the
            # 0.5-5.0 rating scale.
            expectedBucket = bucketIndices @ probs
            predicted[userID] = (expectedBucket + 1) * 0.5
        return predicted

    def estimate(self, u, i):
        """Return the cached predicted rating for inner ids ``u`` and ``i``.

        Raises:
            PredictionImpossible: if the user or item is unknown to the
                trainset, or the cached prediction is (near) zero.
        """
        if not (self.trainset.knows_user(u) and self.trainset.knows_item(i)):
            raise PredictionImpossible('User and/or item is unknown. ')
        rating = self.predictedRatings[u, i]
        # fit() writes (expected bucket + 1) * 0.5, i.e. a value in
        # [0.5, 5.0], so a near-zero entry means nothing was stored here.
        if rating < 0.001:
            raise PredictionImpossible('No valid prediction exists. ')
        return rating