import torch
import torch.nn as nn
import torch.optim as optim

# Relies on module-level names defined elsewhere in this repo:
# TrainDataLoader, Net, validate, save_snapshot, device,
# student_n, exer_n, knowledge_n, epoch_n.
def train():
    data_loader = TrainDataLoader()
    net = Net(student_n, exer_n, knowledge_n)
    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.002)
    print('training model...')
    loss_function = nn.NLLLoss()
    for epoch in range(epoch_n):
        data_loader.reset()
        running_loss = 0.0
        batch_count = 0
        while not data_loader.is_end():
            batch_count += 1
            input_stu_ids, input_exer_ids, input_knowledge_embs, labels = data_loader.next_batch()
            input_stu_ids = input_stu_ids.to(device)
            input_exer_ids = input_exer_ids.to(device)
            input_knowledge_embs = input_knowledge_embs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            # P(correct) from the model; build the (P(wrong), P(correct)) pair
            # so NLLLoss on the log-probabilities acts as binary cross-entropy.
            output_1 = net.forward(input_stu_ids, input_exer_ids, input_knowledge_embs)
            output_0 = torch.ones(output_1.size()).to(device) - output_1
            output = torch.cat((output_0, output_1), 1)
            # grad_penalty = 0
            loss = loss_function(torch.log(output), labels)
            loss.backward()
            optimizer.step()
            net.apply_clipper()  # re-apply weight constraints after the update
            running_loss += loss.item()
            if batch_count % 200 == 199:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_count + 1, running_loss / 200))
                running_loss = 0.0
        # validate and save current model every epoch
        rmse, auc = validate(net, epoch)
        save_snapshot(net, 'model/model_epoch' + str(epoch + 1))
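
# --- Illustration (not part of the original script) ---
# train() turns the single success probability p into a two-column tensor
# (1 - p, p) and feeds its log into nn.NLLLoss. That is numerically the same
# as binary cross-entropy on p. A minimal, self-contained check of the
# equivalence; all names below are illustrative:
def _check_nll_equals_bce():
    p = torch.tensor([[0.9], [0.2], [0.6]])  # predicted P(correct) per response
    labels = torch.tensor([1, 0, 1])         # 1 = answered correctly
    two_col = torch.cat((1.0 - p, p), 1)     # columns: P(wrong), P(correct)
    nll = nn.NLLLoss()(torch.log(two_col), labels)
    bce = nn.BCELoss()(p.squeeze(1), labels.float())
    assert torch.allclose(nll, bce)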
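
# train() also calls net.apply_clipper() after every optimizer step. In
# cognitive-diagnosis models this is typically a non-negativity constraint on
# the prediction layers, so predicted performance stays monotone in student
# proficiency. A hedged sketch of one way such a clipper could look (the
# class name and the "clip every Linear layer" choice are assumptions, not
# taken from the original Net):
class NonNegClipper(object):
    def __call__(self, module):
        if isinstance(module, nn.Linear):
            with torch.no_grad():
                # Project negative weights back to zero after the gradient step.
                module.weight.clamp_(min=0.0)

# Usage sketch: net.apply(NonNegClipper()) inside the training loop.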
import copy
import random

import torch
import torch.nn as nn
import torch.optim as optim

# Relies on module-level names defined elsewhere in this repo:
# TrainDataLoader, ValTestDataLoader, Net, validate, Fedknow, Apply,
# total, device, epoch_n, Epoch.
def train(method):
    seed = 0
    client_list = [0, 1]
    Nets = []
    random.seed(seed)
    path = 'data/'
    # Build one local model (and its loaders) per client.
    for client in client_list:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        data_loader = TrainDataLoader(client, path)
        val_loader = ValTestDataLoader(client, path)
        net = Net(data_loader.student_n, data_loader.exer_n, data_loader.knowledge_n)
        net = net.to(device)
        # Each entry: [client id, train loader, model, best state_dict, val loader].
        Nets.append([client, data_loader, net, copy.deepcopy(net.state_dict()), val_loader])
    global_model1 = Nets[0][3]
    loss_function = nn.MSELoss(reduction='mean')
    Gauc = 0
    # Communication rounds: local training, then aggregation and redistribution.
    for i in range(Epoch):
        AUC = []
        ACC = []
        for index in range(len(Nets)):
            school = Nets[index][0]
            net = Nets[index][2]
            data_loader = Nets[index][1]
            val_loader = Nets[index][4]
            optimizer = optim.Adam(net.parameters(), lr=0.001)
            print('training model...' + str(school))
            best = 0
            best_epoch = 0
            best_knowauc = None
            best_indice = 0
            for epoch in range(epoch_n):
                # Validate first; keep the state_dict of the best-AUC epoch.
                metric, _, _, know_auc, know_acc = validate(net, epoch, school, path, val_loader)
                auc = metric[1]
                rmse = metric[0]
                indice = metric[1]
                if auc > best:
                    best = auc
                    best_knowauc = know_auc
                    best_indice = indice
                    best_epoch = epoch
                    best_knowacc = know_acc
                    Nets[index][3] = copy.deepcopy(net.state_dict())
                # Early stopping: no AUC improvement for 5 epochs.
                if epoch - best_epoch >= 5:
                    break
                data_loader.reset()
                running_loss = 0.0
                batch_count = 0
                # Count how often each knowledge concept occurs in this client's data.
                know_distribution = torch.zeros((data_loader.knowledge_n))
                while not data_loader.is_end():
                    batch_count += 1
                    input_stu_ids, input_exer_ids, input_knowledge_embs, labels = data_loader.next_batch()
                    know_distribution += torch.sum(input_knowledge_embs, 0)
                    input_stu_ids = input_stu_ids.to(device)
                    input_exer_ids = input_exer_ids.to(device)
                    input_knowledge_embs = input_knowledge_embs.to(device)
                    labels = labels.to(device)
                    optimizer.zero_grad()
                    output = net.forward(input_stu_ids, input_exer_ids, input_knowledge_embs)
                    loss = loss_function(output, labels)
                    loss.backward()
                    optimizer.step()
            # Restore the best local checkpoint before aggregation.
            net.load_state_dict(Nets[index][3])
            Nets[index][2] = net
            # Per-knowledge aggregation weight: occurrence count scaled by
            # per-knowledge accuracy; zeros get a small floor value.
            distribution = know_distribution * best_knowacc
            distribution[distribution == 0] = 0.001
            Nets[index].append(distribution.unsqueeze(1).to(device))
            print('Best AUC:', best)
            AUC.append([best_indice, best_knowacc])
            ACC.append(best_indice)
        l_school = [item[0] for item in Nets]
        l_weights = [len(item[1].data) for item in Nets]  # training-set size per client
        l_know = [item[5] for item in Nets]               # per-knowledge weight vectors
        l_net = [item[3] for item in Nets]                # best local state_dicts
        metric0 = []
        metric1 = []
        metric2 = []
        # Aggregate local models into a global one, grouped into student- and
        # question-related parameter groups.
        global_model2, student_group, question_group, _ = Fedknow(l_net, l_weights, l_know, AUC, method)
        print('global test ===========')
        for k in range(len(Nets)):
            metric2.append(validate(Nets[k][2], i, l_school[k], path, Nets[k][4]))
        globalauc = total(metric2)
        # Load the aggregated parameters back into each client's model.
        for k in range(len(Nets)):
            Apply(copy.deepcopy(global_model2), Nets[k][2], AUC[k], student_group, question_group, method)
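
# --- Illustration (not part of the original script) ---
# Fedknow() aggregates the clients' best state_dicts; l_weights carries each
# client's training-set size, so the natural baseline it generalizes is
# sample-weighted FedAvg. A minimal sketch of that baseline (fed_avg is an
# illustrative stand-in, not the original Fedknow, which additionally uses
# the per-knowledge weight vectors and AUC statistics):
def fed_avg(state_dicts, weights):
    total_w = float(sum(weights))
    avg = copy.deepcopy(state_dicts[0])
    for key in avg:
        # Weighted average of each parameter tensor across clients.
        avg[key] = sum(sd[key] * (w / total_w) for sd, w in zip(state_dicts, weights))
    return avg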
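
# The per-knowledge vectors in l_know (occurrence count x accuracy, shape
# [knowledge_n, 1]) would also allow knowledge-related parameter rows to be
# averaged with a different weight per concept rather than one scalar per
# client. A sketch under assumed shapes (each matrix [knowledge_n, dim]);
# the function name and shapes are assumptions for illustration:
def knowledge_weighted_avg(matrices, know_weights):
    stacked_w = torch.stack(know_weights)                    # [n_clients, knowledge_n, 1]
    stacked_m = torch.stack(matrices)                        # [n_clients, knowledge_n, dim]
    norm_w = stacked_w / stacked_w.sum(dim=0, keepdim=True)  # normalize per concept
    return (norm_w * stacked_m).sum(dim=0)                   # [knowledge_n, dim]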