def load_model(dataset_name, algo_name, dataset, data_dir='../data/dataset'):
    """Build the model named by ``algo_name`` and load its saved checkpoint.

    Args:
        dataset_name: Name of the dataset sub-directory holding the
            checkpoint (e.g. ``'Electronics'``).
        algo_name: Which model to construct: ``"vbpr"`` or ``"deepstyle"``.
        dataset: Dataset object whose ``n_users``, ``n_items``,
            ``n_categories``, ``corpus.image_features`` and
            ``corpus.item_category`` attributes are passed to the model
            constructor.
        data_dir: Root directory containing per-dataset folders. Defaults to
            ``'../data/dataset'``, which is relative to the current working
            directory.

    Returns:
        The constructed model after ``model.load(...)`` has been called with
        ``{data_dir}/{dataset_name}/models/{algo_name}_resnet50.pth``.

    Raises:
        ValueError: If ``algo_name`` is not a supported algorithm.
    """
    if algo_name == "vbpr":
        model = VBPR(dataset.n_users, dataset.n_items, dataset.corpus.image_features)
    elif algo_name == "deepstyle":
        model = DeepStyle(
            dataset.n_users,
            dataset.n_items,
            dataset.n_categories,
            dataset.corpus.image_features,
            dataset.corpus.item_category,
        )
    else:
        # Without this branch `model` would be unbound below and the caller
        # would see an opaque UnboundLocalError instead of a clear message.
        raise ValueError(
            f"Unknown algorithm {algo_name!r}; expected 'vbpr' or 'deepstyle'"
        )
    model.load(f'{data_dir}/{dataset_name}/models/{algo_name}_resnet50.pth')
    return model
from dataset import RecSysDataset
from train import Trainer
from models import VBPR
import torch

if __name__ == '__main__':
    # Training schedule.
    n_epochs, batch_size = 20, 128
    # Model sizes, forwarded positionally to VBPR after the image features.
    k, k2 = 10, 20

    data = RecSysDataset()
    model = VBPR(data.n_users, data.n_items, data.corpus.image_features, k, k2)

    trainer = Trainer(model, data)
    trainer.train(n_epochs, batch_size)

    # NOTE(review): torch.save on the module object pickles the whole model,
    # so reloading needs the VBPR class importable from the same path.
    # Saving model.state_dict() is the more portable option if that matters.
    torch.save(model, 'vbpr_resnet50_v1.pth')
parser.add_argument('--lambda_e', type=float, default=0.0001)
# Restrict to the algorithms handled below. A typo then fails immediately with
# an argparse usage error, instead of a NameError on `model` after the dataset
# has already been loaded.
parser.add_argument('--algorithm', type=str, default='deepstyle',
                    choices=['bpr', 'vbpr', 'deepstyle'])
parser.add_argument('--dataset', type=str, default='Electronics')
args = parser.parse_args()
print(args)

# Seed the NumPy and PyTorch RNGs from --seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

dataset = RecSysDataset(args.dataset)
if args.algorithm == "vbpr":
    model = VBPR(dataset.n_users, dataset.n_items, dataset.corpus.image_features,
                 args.k, args.k2, args.lambda_w, args.lambda_b, args.lambda_e)
elif args.algorithm == "deepstyle":
    model = DeepStyle(dataset.n_users, dataset.n_items, dataset.n_categories,
                      dataset.corpus.image_features, dataset.corpus.item_category,
                      args.k, args.lambda_w, args.lambda_e)
elif args.algorithm == "bpr":
    model = BPR(dataset.n_users, dataset.n_items, args.k, args.lambda_w, args.lambda_b)

# Move the model to the GPU when one is available.
if torch.cuda.is_available():
    model = model.cuda()
if __name__ == '__main__': parser = ArgumentParser(description="Experiments") parser.add_argument('--k', type=int, default=10) parser.add_argument('--k2', type=int, default=10) parser.add_argument('--algorithm', type=str, default='deepstyle') # bpr, vbpr, vbprc, deepstyle parser.add_argument('--dataset', type=str, default='Electronics') args = parser.parse_args() print(args) dataset = RecSysDataset(args.dataset) if args.algorithm == "vbpr": model = VBPR(dataset.n_users, dataset.n_items, dataset.corpus.image_features, args.k, args.k2) elif args.algorithm == "vbprc": model = VBPRC(dataset.n_users, dataset.n_items, dataset.n_categories, dataset.corpus.image_features, dataset.corpus.item_category, args.k, args.k2) elif args.algorithm == "deepstyle": model = DeepStyle(dataset.n_users, dataset.n_items, dataset.n_categories, dataset.corpus.image_features, dataset.corpus.item_category, args.k) elif args.algorithm == "bpr": model = BPR(dataset.n_users, dataset.n_items, args.k) model.load(