Ejemplo n.º 1
0
def train():
    """Train a fresh ``Net`` for ``epoch_n`` epochs, validating and saving each epoch.

    Uses names defined outside this block: ``TrainDataLoader``, ``Net``,
    ``student_n``, ``exer_n``, ``knowledge_n``, ``device``, ``epoch_n``,
    ``validate`` and ``save_snapshot``.

    After every epoch ``validate(net, epoch)`` is called and the model is
    written to ``'model/model_epoch<N>'`` (1-based epoch number) through
    ``save_snapshot``. The mean training loss is printed every 200 batches.
    """
    data_loader = TrainDataLoader()
    net = Net(student_n, exer_n, knowledge_n)

    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.002)
    print('training model...')

    loss_function = nn.NLLLoss()
    for epoch in range(epoch_n):
        data_loader.reset()
        running_loss = 0.0
        batch_count = 0
        while not data_loader.is_end():
            batch_count += 1
            (input_stu_ids, input_exer_ids, input_knowledge_embs,
             labels) = data_loader.next_batch()
            input_stu_ids = input_stu_ids.to(device)
            input_exer_ids = input_exer_ids.to(device)
            input_knowledge_embs = input_knowledge_embs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            # Call the module itself rather than .forward() so that any
            # registered nn.Module hooks run.
            output_1 = net(input_stu_ids, input_exer_ids,
                           input_knowledge_embs)
            # Presumably output_1 is the predicted probability of class 1
            # with shape (batch, 1) -- TODO confirm against Net. The
            # complement is computed directly on output_1's device instead
            # of building a CPU tensor of ones and copying it every batch.
            output_0 = 1 - output_1
            output = torch.cat((output_0, output_1), 1)

            # NLLLoss expects log-probabilities, hence the explicit log.
            # NOTE(review): torch.log yields -inf if the net emits exactly
            # 0 or 1 -- consider clamping if that can happen.
            loss = loss_function(torch.log(output), labels)
            loss.backward()
            optimizer.step()
            net.apply_clipper()

            running_loss += loss.item()
            # Report once 200 batches have actually been accumulated, so the
            # printed mean really is sum / 200 (the first window used to
            # hold only 199 batches).
            if batch_count % 200 == 0:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, batch_count, running_loss / 200))
                running_loss = 0.0

        # validate and save current model every epoch
        validate(net, epoch)
        save_snapshot(net, 'model/model_epoch' + str(epoch + 1))
Ejemplo n.º 2
0
def train(method):
    """Run ``Epoch`` rounds of per-client training followed by aggregation.

    One entry is built per client in ``client_list``. In every round each
    client trains locally for up to ``epoch_n`` epochs, validating before
    each epoch and stopping once the validation score (``metric[1]``) has
    not improved for 5 epochs; the client's net is then restored to the
    best-scoring weights. The clients' best weights are handed to
    ``Fedknow`` and its result is pushed to every client's net via ``Apply``.

    ``epoch_n`` is expected to be at least 1: the per-client knowledge
    distribution is only computed inside the local epoch loop.

    Args:
        method: Passed through unchanged to ``Fedknow`` and ``Apply``.
    """
    seed = 0
    client_list = [0, 1]

    # Each entry: [client id, train loader, net, best state_dict, val loader];
    # slot 5 (added after local training) holds the knowledge weighting.
    Nets = []
    random.seed(seed)
    path = 'data/'
    for client in client_list:
        # Re-seed torch before constructing each client's Net.
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        data_loader = TrainDataLoader(client, path)
        val_loader = ValTestDataLoader(client, path)
        net = Net(data_loader.student_n, data_loader.exer_n,
                  data_loader.knowledge_n)
        net = net.to(device)
        Nets.append([
            client, data_loader, net,
            copy.deepcopy(net.state_dict()), val_loader
        ])

    # NOTE(review): global_model1, Gauc, ACC, metric0, metric1 and globalauc
    # are never read in the visible part of this function; they are kept in
    # case code outside this view depends on them.
    global_model1 = Nets[0][3]
    loss_function = nn.MSELoss(reduction='mean')

    Gauc = 0
    for i in range(Epoch):
        AUC = []
        ACC = []
        for entry in Nets:
            school = entry[0]
            data_loader = entry[1]
            net = entry[2]
            val_loader = entry[4]
            optimizer = optim.Adam(net.parameters(), lr=0.001)
            print('training model...' + str(school))
            best = 0
            best_epoch = 0
            best_knowauc = None
            best_knowacc = None
            # Reset per client so a client whose score never exceeds 0 cannot
            # silently reuse the previous client's best_knowacc.
            has_best = False
            for epoch in range(epoch_n):
                # Validation runs *before* the epoch's training pass.
                metric, _, _, know_auc, know_acc = validate(
                    net, epoch, school, path, val_loader)
                auc = metric[1]
                # Always record the first validation so the "best" values
                # are defined even when the score never rises above 0.
                if not has_best or auc > best:
                    has_best = True
                    best = auc
                    best_knowauc = know_auc
                    best_knowacc = know_acc
                    best_epoch = epoch
                    entry[3] = copy.deepcopy(net.state_dict())
                if epoch - best_epoch >= 5:
                    break

                data_loader.reset()
                # Accumulated on the CPU from the raw batches, before they
                # are moved to the training device.
                know_distribution = torch.zeros((data_loader.knowledge_n))
                while not data_loader.is_end():
                    (input_stu_ids, input_exer_ids, input_knowledge_embs,
                     labels) = data_loader.next_batch()

                    know_distribution += torch.sum(input_knowledge_embs, 0)
                    input_stu_ids = input_stu_ids.to(device)
                    input_exer_ids = input_exer_ids.to(device)
                    input_knowledge_embs = input_knowledge_embs.to(device)
                    labels = labels.to(device)

                    optimizer.zero_grad()
                    # Call the module itself rather than .forward() so that
                    # any registered nn.Module hooks run.
                    output = net(input_stu_ids, input_exer_ids,
                                 input_knowledge_embs)
                    loss = loss_function(output, labels)
                    loss.backward()
                    optimizer.step()

            # Roll the client back to its best-scoring weights.
            net.load_state_dict(entry[3])
            distribution = know_distribution * best_knowacc
            # Replace exact zeros with a small positive value.
            distribution[distribution == 0] = 0.001
            know_weight = distribution.unsqueeze(1).to(device)
            # Overwrite slot 5 on later rounds; appending every round left
            # the aggregation reading the stale first-round value.
            if len(entry) > 5:
                entry[5] = know_weight
            else:
                entry.append(know_weight)
            print('Best AUC:', best)
            AUC.append([best, best_knowacc])
            ACC.append(best)

        l_school = [item[0] for item in Nets]
        l_weights = [len(item[1].data) for item in Nets]
        l_know = [item[5] for item in Nets]
        l_net = [item[3] for item in Nets]
        metric0 = []
        metric1 = []
        metric2 = []
        global_model2, student_group, question_group, _ = Fedknow(
            l_net, l_weights, l_know, AUC, method)
        print('global test ===========')
        for entry, school in zip(Nets, l_school):
            metric2.append(validate(entry[2], i, school, path, entry[4]))
        globalauc = total(metric2)

        # Each client receives its own deep copy of the aggregated model.
        for entry, client_auc in zip(Nets, AUC):
            Apply(copy.deepcopy(global_model2), entry[2], client_auc,
                  student_group, question_group, method)