def doEva(net, d, G):
    """Evaluate ``net`` on rating rows and return precision, recall and accuracy.

    Args:
        net: model called as ``net(users, adj_lists)``; its output is thresholded
            at 0.5, so it presumably yields one probability per row (it is trained
            with BCELoss elsewhere in this file) -- TODO confirm.
        d: rows of ``(user, item, rating)`` convertible by ``torch.LongTensor``.
        G: graph handed to ``dataloader4graph.graphSage4RecAdjType`` to build the
            neighbour lists for the evaluated items.

    Returns:
        Tuple ``(precision, recall, accuracy)`` from scikit-learn, comparing the
        thresholded predictions against the rating column.

    Note:
        Leaves ``net`` in eval mode; the caller must switch back to train mode.
    """
    net.eval()
    data = torch.LongTensor(d)
    users, items, ratings = data[:, 0], data[:, 1], data[:, 2]
    adj_lists = dataloader4graph.graphSage4RecAdjType(G, items.detach().numpy())
    # Pure inference: no_grad avoids building an autograd graph over the whole set.
    with torch.no_grad():
        out = net(users, adj_lists)
    # Vectorised 0.5 threshold (replaces a Python loop over tensor elements);
    # reshape(-1) keeps a flat label vector for both (N,) and (N, 1) outputs.
    y_pred = (out >= 0.5).long().reshape(-1).cpu().numpy()
    y_true = ratings.detach().numpy()
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    acc = accuracy_score(y_true, y_pred)
    return precision, recall, acc
def train(epoch=20, batchSize=1024, dim=128, tdim=64, lr=0.002, eva_per_epochs=1,
          atten_way='base'):
    """Train a GAFM model on the rating data, periodically reporting metrics.

    Args:
        epoch: number of passes over the training set.
        batchSize: mini-batch size for the training DataLoader.
        dim: passed to ``GAFM`` as its third argument (presumably an embedding
            dimension -- TODO confirm).
        tdim: passed to ``GAFM`` as its fourth argument (presumably a second
            embedding/attention dimension -- TODO confirm).
        lr: AdamW learning rate.
        eva_per_epochs: evaluate on train and test sets every this many epochs
            (starting with epoch 0).
        atten_way: attention variant forwarded to ``GAFM``.

    Returns:
        The trained ``net``.
    """
    user_set, item_set, train_set, test_set = dataloader4graph.readRecData()
    entitys, pairs = dataloader4graph.readGraphData()
    G = dataloader4graph.get_graph(pairs)
    net = GAFM(max(user_set) + 1, max(entitys) + 1, dim, tdim, atten_way)
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr)
    # A shuffling DataLoader draws a new permutation on every iteration,
    # so one instance can be reused across epochs.
    loader = DataLoader(train_set, batch_size=batchSize, shuffle=True)
    for e in range(epoch):
        net.train()
        all_lose = 0.0
        n_batches = 0
        for u, i, r in tqdm(loader):
            r = r.float()  # BCELoss needs float targets
            optimizer.zero_grad()
            adj_lists = dataloader4graph.graphSage4RecAdjType(G, i.detach().numpy())
            logits = net(u, adj_lists)
            loss = criterion(logits, r)
            loss.backward()
            optimizer.step()
            # .item() yields a plain float; summing the tensor itself would keep
            # every batch's autograd graph alive for the whole epoch.
            all_lose += loss.item()
            n_batches += 1
        # Average over the batches actually seen (handles a partial last batch
        # and avoids dividing by zero when train_set is smaller than batchSize).
        print('epoch {}, avg_loss = {:.4f}'.format(
            e, all_lose / max(n_batches, 1)))
        # Evaluate the model on both the training and test sets.
        if e % eva_per_epochs == 0:
            prec, rec, acc = doEva(net, train_set, G)
            print('train: Precision {:.4f} | Recall {:.4f} | accuracy {:.4f}'.
                  format(prec, rec, acc))
            prec, rec, acc = doEva(net, test_set, G)
            print('test: Precision {:.4f} | Recall {:.4f} | accuracy {:.4f}'.
                  format(prec, rec, acc))
    return net