-
Notifications
You must be signed in to change notification settings - Fork 0
/
als.py
79 lines (56 loc) · 2.26 KB
/
als.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from implicit.als import AlternatingLeastSquares as als
from scipy.sparse import csr_matrix
import scipy.sparse as sp
from loguru import logger
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from validation import Validator
from preprocess import Preprocessor
class ALS_helper:
    """Train an implicit ALS recommender on a ratings split and score it with NDCG.

    The training ratings are stored as an item x user sparse matrix whose row and
    column indices are the pandas category codes of ``movieId`` and ``userId``.
    The categorical columns of ``self.train_df`` map those indices back to raw ids.
    """

    def __init__(self, factors=10, train_df_path='train_df.csv', test_df_path='test_df.csv'):
        """Load the train/test splits, build the rating matrices and create the model.

        Args:
            factors: number of latent factors passed to ``AlternatingLeastSquares``.
            train_df_path: CSV of the training split; must contain ``userId``,
                ``movieId`` and ``rating`` columns.
            test_df_path: CSV of the test split; must contain a ``userId`` column.
                It is also handed to ``Validator``.
        """
        self.model = als(factors=factors)
        self.train_df = pd.read_csv(train_df_path)
        self.test_df = pd.read_csv(test_df_path)
        # Item x user matrix (rows = movie codes, columns = user codes).
        self.ratings_matrix = self.__get_matrix()
        # User x item view of the same data, used for recommendation/filtering.
        self.user_items = self.ratings_matrix.T.tocsr()
        self.validator = Validator(test_df_path)
        # NOTE(review): OpenBLAS reads this variable when the library is loaded,
        # and numpy is already imported here, so this probably does not limit
        # BLAS threads; it may only satisfy implicit's BLAS-threading check.
        # Consider setting it before numpy is first imported.
        os.environ["OPENBLAS_NUM_THREADS"] = '1'
        logger.info('The helper is initialized successfully')

    def train(self):
        """Fit the ALS model on the item x user rating matrix."""
        # NOTE(review): passing an item x user matrix to fit() matches
        # implicit < 0.5; implicit >= 0.5 expects user x item. Confirm version.
        logger.info('Model is training now')
        self.model.fit(self.ratings_matrix)
        logger.info('Model is trained')

    def __recommend(self):
        """Return top-N recommendations for every user row of ``self.user_items``.

        Presumably row ``i`` belongs to the user with category code ``i``. Entries
        are column indices of ``self.user_items`` (movieId category codes), not
        raw movieIds. Items the user already rated are filtered out.
        """
        # NOTE(review): recommend_all is not available in newer implicit
        # releases — confirm the installed version.
        recom_list = self.model.recommend_all(self.user_items, filter_already_liked_items=True)
        return recom_list

    def __user_code(self, user):
        """Return the matrix row index of raw ``user``, or None if not in training data."""
        categories = self.train_df['userId'].cat.categories
        if user not in categories:
            return None
        return int(categories.get_loc(user))

    def __item_ids(self, item_codes):
        """Translate matrix column indices (movieId codes) back to raw movieIds."""
        return np.asarray(self.train_df['movieId'].cat.categories[item_codes])

    def validate(self):
        """Score the model's recommendations against the test split.

        Each distinct test ``userId`` is mapped to its training-matrix row. That
        row's recommendations are translated back to raw movieIds and passed to
        ``Validator.valid``. Test users that never appear in the training data
        cannot be scored and are skipped, with a warning.

        Returns:
            tuple: (mean NDCG@1, mean NDCG@10) over the evaluated users, or
            (nan, nan) when no test user could be evaluated.
        """
        users = np.unique(self.test_df.userId.values)
        recom_list = self.__recommend()
        ndcg1_scores = []
        ndcg10_scores = []
        skipped = 0
        for user in tqdm(users):
            code = self.__user_code(user)
            if code is None:
                # Cold-start user: there is no row for them in the trained model.
                skipped += 1
                continue
            # Index by training code, not raw userId; report raw movieIds so they
            # are comparable with the test split.
            n1, n10 = self.validator.valid(user, self.__item_ids(recom_list[code]))
            ndcg1_scores.append(n1)
            ndcg10_scores.append(n10)
        if skipped:
            logger.warning('Skipped {} test users absent from the training data', skipped)
        if not ndcg1_scores:
            logger.warning('No test user could be evaluated')
            return float('nan'), float('nan')
        return np.mean(ndcg1_scores), np.mean(ndcg10_scores)

    def __get_matrix(self):
        """Build the item x user CSR rating matrix from ``self.train_df``.

        Converts the ``userId`` and ``movieId`` columns of ``self.train_df`` to
        categorical in place. Their category codes become the column and row
        indices, respectively, and the categories are later used to map indices
        back to raw ids. Ratings are stored as float32; duplicate
        (movie, user) pairs are summed when converting COO to CSR.
        """
        self.train_df['userId'] = self.train_df['userId'].astype('category')
        self.train_df['movieId'] = self.train_df['movieId'].astype('category')
        ratings_matrix = sp.coo_matrix(
            (
                self.train_df['rating'].astype(np.float32),
                (
                    self.train_df['movieId'].cat.codes.copy(),
                    self.train_df['userId'].cat.codes.copy(),
                ),
            )
        )
        return ratings_matrix.tocsr()
if __name__ == '__main__':
    # Entry point: build the splits, train ALS and log NDCG@1 / NDCG@10.
    # Guarded so that importing this module (e.g. to reuse ALS_helper) does not
    # run preprocessing and training as a side effect.
    # NOTE(review): presumably Preprocessor writes train_df.csv / test_df.csv
    # (the ALS_helper defaults) and 0.4 is the test fraction — confirm against
    # preprocess.Preprocessor.
    prep = Preprocessor('ratings.csv')
    prep.process(0.4)
    mdl = ALS_helper()
    mdl.train()
    ndcg1, ndcg10 = mdl.validate()
    logger.info('ndcg@1 = {}, ndcg@10 = {}'.format(ndcg1, ndcg10))