# Example #1
def classify(treeDic, x_test, x_train, TDdroprate, BUdroprate, lr, weight_decay,
             patience, n_epochs, batchsize, dataname, iter, fold_count):
    """Pre-train an unsupervised graph encoder, then train and evaluate a
    4-class classifier on top of its embeddings.

    Fixes vs. the previous version:
      * the Adam optimizer for unsupervised pre-training is created once,
        outside the epoch loop, so its running moments are no longer reset
        at every epoch;
      * `datasetname` (undefined in this scope) replaced with the `dataname`
        parameter when building the TensorBoard log paths;
      * the `patience` argument is now actually forwarded to EarlyStopping
        (was hardcoded to 10), consistent with train_GCN/train_model;
      * the accumulated unsupervised loss is kept as a Python float instead
        of a 0-d tensor.

    Args:
        treeDic: propagation-tree dictionary consumed by loadBiData.
        x_test, x_train: lists of event ids for the test/train folds.
        TDdroprate, BUdroprate: top-down / bottom-up edge drop rates.
        lr, weight_decay: optimizer hyper-parameters for pre-training.
        patience: early-stopping patience for the classifier stage.
        n_epochs: number of supervised training epochs.
        batchsize: mini-batch size for both stages.
        dataname: dataset name (used for data loading, log dirs, checkpoints).
        iter: outer iteration index (logging only).
        fold_count: fold index used in the early-stopping checkpoint name.

    Returns:
        (train_losses, val_losses, train_accs, val_accs, accs, F1, F2, F3, F4)
    """
    # ---- Stage 1: unsupervised pre-training of the encoder -----------------
    unsup_model = Net(64, 3).to(device)
    # Create the optimizer once so Adam's moment estimates persist across epochs.
    optimizer = th.optim.Adam(unsup_model.parameters(), lr=lr, weight_decay=weight_decay)
    for unsup_epoch in range(25):
        unsup_model.train()
        # NOTE(review): pre-training deliberately uses train+test graphs
        # (self-supervised, no labels consumed) with fixed 0.2 drop rates —
        # confirm this is intended.
        traindata_list, _ = loadBiData(dataname, treeDic, x_train + x_test, x_test, 0.2, 0.2)
        train_loader = DataLoader(traindata_list, batch_size=batchsize, shuffle=True, num_workers=4)
        batch_idx = 0
        loss_all = 0.0
        tqdm_train_loader = tqdm(train_loader)
        for Batch_data in tqdm_train_loader:
            optimizer.zero_grad()
            Batch_data = Batch_data.to(device)
            loss = unsup_model(Batch_data)  # forward pass returns the unsupervised loss
            # Weight by the number of graphs in the batch; keep a float scalar.
            loss_all += loss.item() * (int(Batch_data.batch.max()) + 1)
            loss.backward()
            optimizer.step()
            batch_idx = batch_idx + 1
        loss = loss_all / len(train_loader)
    name = "best_pre_" + dataname + "_4unsup" + ".pkl"
    th.save(unsup_model.state_dict(), name)
    print('Finished the unsuperivised training.', '  Loss:', loss)
    print("Start classify!!!")
    # unsup_model.eval()

    # ---- Stage 2: supervised classifier on top of the encoder --------------
    log_train = 'logs/' + dataname + '/' + 'train' + 'iter_' + str(iter)
    writer_train = SummaryWriter(log_train)
    log_test = 'logs/' + dataname + '/' + 'test' + 'iter_' + str(iter)
    writer_test = SummaryWriter(log_test)

    model = Classfier(64 * 3, 64, 4).to(device)
    opt = th.optim.Adam(model.parameters(), lr=0.0005, weight_decay=weight_decay)

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    for epoch in range(n_epochs):
        # Re-sample the dropped-edge graphs every epoch (data augmentation).
        traindata_list, testdata_list = loadBiData(dataname, treeDic, x_train, x_test, TDdroprate, BUdroprate)
        train_loader = DataLoader(traindata_list, batch_size=batchsize, shuffle=True, num_workers=4)
        test_loader = DataLoader(testdata_list, batch_size=batchsize, shuffle=True, num_workers=4)
        avg_loss = []
        avg_acc = []
        batch_idx = 0
        tqdm_train_loader = tqdm(train_loader)
        model.train()
        # NOTE(review): the encoder is left in train mode here (dropout/BN
        # active) even though only the classifier's optimizer steps — confirm
        # this is intended rather than unsup_model.eval().
        unsup_model.train()
        for Batch_data in tqdm_train_loader:
            Batch_data.to(device)
            _, Batch_embed = unsup_model.encoder(Batch_data.x, Batch_data.edge_index, Batch_data.batch)
            out_labels = model(Batch_embed, Batch_data)
            finalloss = F.nll_loss(out_labels, Batch_data.y)
            loss = finalloss
            opt.zero_grad()
            loss.backward()
            avg_loss.append(loss.item())
            opt.step()
            _, pred = out_labels.max(dim=-1)
            correct = pred.eq(Batch_data.y).sum().item()
            train_acc = correct / len(Batch_data.y)
            avg_acc.append(train_acc)
            print("Iter {:03d} | Epoch {:05d} | Batch{:02d} | Train_Loss {:.4f}| Train_Accuracy {:.4f}".format(
                iter, epoch, batch_idx, loss.item(), train_acc))
            batch_idx = batch_idx + 1

        writer_train.add_scalar('train_loss', np.mean(avg_loss), global_step=epoch + 1)
        writer_train.add_scalar('train_acc', np.mean(avg_acc), global_step=epoch + 1)
        train_losses.append(np.mean(avg_loss))
        train_accs.append(np.mean(avg_acc))

        # ---- validation ----
        temp_val_losses = []
        temp_val_accs = []
        temp_val_Acc_all, temp_val_Acc1, temp_val_Prec1, temp_val_Recll1, temp_val_F1, \
            temp_val_Acc2, temp_val_Prec2, temp_val_Recll2, temp_val_F2, \
            temp_val_Acc3, temp_val_Prec3, temp_val_Recll3, temp_val_F3, \
            temp_val_Acc4, temp_val_Prec4, temp_val_Recll4, temp_val_F4 = \
            [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []
        model.eval()
        unsup_model.eval()
        tqdm_test_loader = tqdm(test_loader)
        for Batch_data in tqdm_test_loader:
            Batch_data.to(device)
            Batch_embed = unsup_model.encoder.get_embeddings(Batch_data)
            val_out = model(Batch_embed, Batch_data)
            val_loss = F.nll_loss(val_out, Batch_data.y)
            temp_val_losses.append(val_loss.item())
            _, val_pred = val_out.max(dim=1)
            correct = val_pred.eq(Batch_data.y).sum().item()
            val_acc = correct / len(Batch_data.y)
            Acc_all, Acc1, Prec1, Recll1, F1, Acc2, Prec2, Recll2, F2, Acc3, Prec3, Recll3, F3, \
                Acc4, Prec4, Recll4, F4 = evaluation4class(val_pred, Batch_data.y)
            temp_val_Acc_all.append(Acc_all)
            temp_val_Acc1.append(Acc1)
            temp_val_Prec1.append(Prec1)
            temp_val_Recll1.append(Recll1)
            temp_val_F1.append(F1)
            temp_val_Acc2.append(Acc2)
            temp_val_Prec2.append(Prec2)
            temp_val_Recll2.append(Recll2)
            temp_val_F2.append(F2)
            temp_val_Acc3.append(Acc3)
            temp_val_Prec3.append(Prec3)
            temp_val_Recll3.append(Recll3)
            temp_val_F3.append(F3)
            temp_val_Acc4.append(Acc4)
            temp_val_Prec4.append(Prec4)
            temp_val_Recll4.append(Recll4)
            temp_val_F4.append(F4)
            temp_val_accs.append(val_acc)
        writer_test.add_scalar('val_loss', np.mean(temp_val_losses), global_step=epoch + 1)
        writer_test.add_scalar('val_accs', np.mean(temp_val_accs), global_step=epoch + 1)
        val_losses.append(np.mean(temp_val_losses))
        val_accs.append(np.mean(temp_val_accs))
        print("Epoch {:05d} | Val_Loss {:.4f}| Val_Accuracy {:.4f}".format(epoch, np.mean(temp_val_losses),
                                                                           np.mean(temp_val_accs)))

        res = ['acc:{:.4f}'.format(np.mean(temp_val_Acc_all)),
               'C1:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc1), np.mean(temp_val_Prec1),
                                                       np.mean(temp_val_Recll1), np.mean(temp_val_F1)),
               'C2:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc2), np.mean(temp_val_Prec2),
                                                       np.mean(temp_val_Recll2), np.mean(temp_val_F2)),
               'C3:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc3), np.mean(temp_val_Prec3),
                                                       np.mean(temp_val_Recll3), np.mean(temp_val_F3)),
               'C4:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc4), np.mean(temp_val_Prec4),
                                                       np.mean(temp_val_Recll4), np.mean(temp_val_F4))]
        print('unsup_epoch:', (unsup_epoch + 1), '   results:', res)
        early_stopping(np.mean(temp_val_losses), np.mean(temp_val_accs), np.mean(temp_val_F1), np.mean(temp_val_F2),
                       np.mean(temp_val_F3), np.mean(temp_val_F4), model, 'RDEA_' + str(fold_count) + '_', dataname)
        accs = np.mean(temp_val_accs)
        F1 = np.mean(temp_val_F1)
        F2 = np.mean(temp_val_F2)
        F3 = np.mean(temp_val_F3)
        F4 = np.mean(temp_val_F4)
        if epoch >= 199:
            # Past epoch 200, report the early-stopping best instead of the last epoch.
            accs = early_stopping.accs
            F1 = early_stopping.F1
            F2 = early_stopping.F2
            F3 = early_stopping.F3
            F4 = early_stopping.F4
        if early_stopping.early_stop:
            print("Early stopping")
            accs = early_stopping.accs
            F1 = early_stopping.F1
            F2 = early_stopping.F2
            F3 = early_stopping.F3
            F4 = early_stopping.F4
            break
    return train_losses, val_losses, train_accs, val_accs, accs, F1, F2, F3, F4
def train_GCN(treeDic, x_test, x_train, TDdroprate, BUdroprate, lr, weight_decay, patience, n_epochs, batchsize, dataname, iter):
    """Train the snapshot-based GCN (5 temporal snapshots per sample) and
    evaluate it each epoch with 4-class metrics.

    Fix vs. the previous version: `model.train()` is now called at the start
    of every epoch. Previously it was called once before the loop, so after
    the first validation pass (`model.eval()`) the model stayed in eval mode
    for all remaining training epochs.

    Args:
        treeDic: propagation-tree dictionary consumed by loadSnapshotData.
        x_test, x_train: lists of event ids for the test/train folds.
        TDdroprate, BUdroprate: top-down / bottom-up edge drop rates.
        lr, weight_decay: Adam hyper-parameters.
        patience: early-stopping patience.
        n_epochs: maximum number of epochs.
        batchsize: mini-batch size.
        dataname: dataset name passed to the loader and early stopping.
        iter: outer iteration index (logging only).

    Returns:
        (train_losses, val_losses, train_accs, val_accs, accs, F1, F2, F3, F4)
    """
    model = Network(5000, 64, 64).to(device)
    optimizer = th.optim.Adam(
        [{'params': model.parameters()}],
        lr=lr, weight_decay=weight_decay,
    )
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    # Defaults in case the epoch loop never runs (n_epochs == 0).
    accs = F1 = F2 = F3 = F4 = 0.0
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    for epoch in range(n_epochs):
        # Fix: restore train mode every epoch (validation flips it to eval below).
        model.train()
        # Re-sample the dropped-edge snapshot graphs every epoch.
        traindata_list, testdata_list = loadSnapshotData(dataname, treeDic, x_train, x_test, TDdroprate, BUdroprate)
        train_loader = DataLoader(traindata_list, batch_size=batchsize, shuffle=True, num_workers=5)
        test_loader = DataLoader(testdata_list, batch_size=batchsize, shuffle=True, num_workers=5)
        avg_loss = []
        avg_acc = []
        batch_idx = 0
        tqdm_train_loader = train_loader  # tqdm disabled on purpose
        for Batch_data in tqdm_train_loader:
            # Each element is a tuple of 5 temporal snapshot batches; labels
            # are read from the first snapshot (shared across snapshots).
            s0 = Batch_data[0].to(device)
            s1 = Batch_data[1].to(device)
            s2 = Batch_data[2].to(device)
            s3 = Batch_data[3].to(device)
            s4 = Batch_data[4].to(device)

            out_labels = model(s0, s1, s2, s3, s4)

            finalloss = F.nll_loss(out_labels, Batch_data[0].y)
            loss = finalloss
            optimizer.zero_grad()
            loss.backward()
            avg_loss.append(loss.item())
            optimizer.step()
            _, pred = out_labels.max(dim=-1)
            correct = pred.eq(Batch_data[0].y).sum().item()
            train_acc = correct / len(Batch_data[0].y)
            avg_acc.append(train_acc)
            print("Iter {:03d} | Epoch {:05d} | Batch{:02d} | Train_Loss {:.4f} | Train_Accuracy {:.4f}".format(
                iter, epoch, batch_idx, loss.item(), train_acc))
            batch_idx = batch_idx + 1

        train_losses.append(np.mean(avg_loss))
        train_accs.append(np.mean(avg_acc))

        # ---- validation ----
        temp_val_losses = []
        temp_val_accs = []
        temp_val_Acc_all, temp_val_Acc1, temp_val_Prec1, temp_val_Recll1, temp_val_F1, \
            temp_val_Acc2, temp_val_Prec2, temp_val_Recll2, temp_val_F2, \
            temp_val_Acc3, temp_val_Prec3, temp_val_Recll3, temp_val_F3, \
            temp_val_Acc4, temp_val_Prec4, temp_val_Recll4, temp_val_F4 = [], [
            ], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []
        model.eval()
        tqdm_test_loader = test_loader  # tqdm disabled on purpose
        for Batch_data in tqdm_test_loader:
            s0 = Batch_data[0].to(device)
            s1 = Batch_data[1].to(device)
            s2 = Batch_data[2].to(device)
            s3 = Batch_data[3].to(device)
            s4 = Batch_data[4].to(device)

            val_out = model(s0, s1, s2, s3, s4)

            val_loss = F.nll_loss(val_out, Batch_data[0].y)
            temp_val_losses.append(val_loss.item())
            _, val_pred = val_out.max(dim=1)
            correct = val_pred.eq(Batch_data[0].y).sum().item()
            val_acc = correct / len(Batch_data[0].y)
            Acc_all, Acc1, Prec1, Recll1, F1, \
                Acc2, Prec2, Recll2, F2, \
                Acc3, Prec3, Recll3, F3, \
                Acc4, Prec4, Recll4, F4 = evaluation4class(val_pred, Batch_data[0].y)

            temp_val_Acc_all.append(Acc_all)
            temp_val_Acc1.append(Acc1)
            temp_val_Prec1.append(Prec1)
            temp_val_Recll1.append(Recll1)
            temp_val_F1.append(F1)
            temp_val_Acc2.append(Acc2)
            temp_val_Prec2.append(Prec2)
            temp_val_Recll2.append(Recll2)
            temp_val_F2.append(F2)
            temp_val_Acc3.append(Acc3)
            temp_val_Prec3.append(Prec3)
            temp_val_Recll3.append(Recll3)
            temp_val_F3.append(F3)
            temp_val_Acc4.append(Acc4)
            temp_val_Prec4.append(Prec4)
            temp_val_Recll4.append(Recll4)
            temp_val_F4.append(F4)
            temp_val_accs.append(val_acc)
        val_losses.append(np.mean(temp_val_losses))
        val_accs.append(np.mean(temp_val_accs))
        print("Epoch {:05d} | Val_Loss {:.4f}| Val_Accuracy {:.4f}".format(
            epoch, np.mean(temp_val_losses), np.mean(temp_val_accs))
        )

        res = [
            'acc:{:.4f}'.format(np.mean(temp_val_Acc_all)),
            'C1:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc1), np.mean(temp_val_Prec1), np.mean(temp_val_Recll1), np.mean(temp_val_F1)),
            'C2:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc2), np.mean(temp_val_Prec2), np.mean(temp_val_Recll2), np.mean(temp_val_F2)),
            'C3:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc3), np.mean(temp_val_Prec3), np.mean(temp_val_Recll3), np.mean(temp_val_F3)),
            'C4:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc4), np.mean(temp_val_Prec4), np.mean(temp_val_Recll4), np.mean(temp_val_F4))
        ]

        print('results:', res)
        early_stopping(
            np.mean(temp_val_losses), np.mean(temp_val_accs), np.mean(temp_val_F1), np.mean(temp_val_F2),
            np.mean(temp_val_F3), np.mean(temp_val_F4), model, 'BiGCN', dataname
        )
        accs = np.mean(temp_val_accs)
        F1 = np.mean(temp_val_F1)
        F2 = np.mean(temp_val_F2)
        F3 = np.mean(temp_val_F3)
        F4 = np.mean(temp_val_F4)

        if early_stopping.early_stop:
            print("Early stopping")
            # Report the best checkpoint's metrics, not the last epoch's.
            accs = early_stopping.accs
            F1 = early_stopping.F1
            F2 = early_stopping.F2
            F3 = early_stopping.F3
            F4 = early_stopping.F4
            break

    return train_losses, val_losses, train_accs, val_accs, accs, F1, F2, F3, F4
def train_model(treeDic, x_test, x_train, lr, weight_decay, patience, n_epochs,
                batchsize, dataname, modelname, iter):
    """Train a GAE/VGAE/GCN model and track the metrics of the epoch with the
    lowest validation loss.

    Fixes vs. the previous version:
      * `model.train()` is now called at the start of every epoch — it was
        called once before the loop, so after the first validation pass
        (`model.eval()`) the model stayed in eval mode for all remaining
        training epochs;
      * an unknown `modelname` now raises an explicit ValueError instead of a
        NameError on the unbound `loss` variable.

    Args:
        treeDic: propagation-tree dictionary consumed by loadData.
        x_test, x_train: lists of event ids for the test/train folds.
        lr, weight_decay: Adam hyper-parameters.
        patience: early-stopping patience.
        n_epochs: maximum number of epochs.
        batchsize: mini-batch size.
        dataname: dataset name passed to the loader and early stopping.
        modelname: one of 'GAE', 'VGAE', 'GCN' — selects the loss formulation.
        iter: outer iteration index (logging only).

    Returns:
        (train_losses, val_losses, train_accs, val_accs, accs, F1, F2, F3, F4)
    """
    if modelname not in ('GAE', 'VGAE', 'GCN'):
        raise ValueError("modelname must be 'GAE', 'VGAE' or 'GCN', got %r" % (modelname,))
    model = Net(5000, 64, 64).to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    score = float('inf')  # best (lowest) mean validation loss seen so far
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    for epoch in range(n_epochs):
        # Fix: restore train mode every epoch (validation flips it to eval below).
        model.train()
        # NOTE(review): `TDdroprate` is not a parameter of this function — it
        # must come from module scope; confirm it is defined where this runs.
        traindata_list, testdata_list = loadData(dataname, treeDic, x_train,
                                                 x_test, TDdroprate)
        train_loader = DataLoader(traindata_list,
                                  batch_size=batchsize,
                                  shuffle=False,
                                  num_workers=5)
        test_loader = DataLoader(testdata_list,
                                 batch_size=batchsize,
                                 shuffle=False,
                                 num_workers=5)
        avg_loss = []
        avg_acc = []
        batch_idx = 0
        tqdm_train_loader = tqdm(train_loader)
        for Batch_data in tqdm_train_loader:
            Batch_data.to(device)
            logits = model(Batch_data)
            if modelname in ['GAE', 'VGAE']:
                # Classification loss plus the (V)GAE reconstruction/KL loss.
                loss = F.nll_loss(logits, Batch_data.y) + model.loss()
            else:  # 'GCN' (validated above)
                loss = F.nll_loss(logits, Batch_data.y)
            optimizer.zero_grad()
            loss.backward()
            avg_loss.append(loss.item())
            optimizer.step()
            _, pred = logits.max(dim=-1)
            correct = pred.eq(Batch_data.y).sum().item()
            train_acc = correct / len(Batch_data.y)
            avg_acc.append(train_acc)
            print(
                "Iter {:03d} | Epoch {:05d} | Batch{:02d} | Train_Loss {:.4f}| Train_Accuracy {:.4f}"
                .format(iter, epoch, batch_idx, loss.item(), train_acc))
            batch_idx = batch_idx + 1

        train_losses.append(np.mean(avg_loss))
        train_accs.append(np.mean(avg_acc))

        # ---- validation ----
        temp_val_losses = []
        temp_val_accs = []
        temp_val_Acc_all, temp_val_Acc1, temp_val_Prec1, temp_val_Recll1, temp_val_F1, \
            temp_val_Acc2, temp_val_Prec2, temp_val_Recll2, temp_val_F2, \
            temp_val_Acc3, temp_val_Prec3, temp_val_Recll3, temp_val_F3, \
            temp_val_Acc4, temp_val_Prec4, temp_val_Recll4, temp_val_F4 = \
            [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []
        model.eval()
        tqdm_test_loader = tqdm(test_loader)
        for Batch_data in tqdm_test_loader:
            Batch_data.to(device)
            val_out = model(Batch_data)
            if modelname in ['GAE', 'VGAE']:
                val_loss = F.nll_loss(val_out, Batch_data.y) + model.loss()
            else:  # 'GCN'
                val_loss = F.nll_loss(val_out, Batch_data.y)
            temp_val_losses.append(val_loss.item())
            _, val_pred = val_out.max(dim=1)
            correct = val_pred.eq(Batch_data.y).sum().item()
            val_acc = correct / len(Batch_data.y)
            Acc_all, Acc1, Prec1, Recll1, F1, Acc2, Prec2, Recll2, F2, Acc3, Prec3, Recll3, F3, \
                Acc4, Prec4, Recll4, F4 = evaluation4class(val_pred, Batch_data.y)
            temp_val_Acc_all.append(Acc_all)
            temp_val_Acc1.append(Acc1)
            temp_val_Prec1.append(Prec1)
            temp_val_Recll1.append(Recll1)
            temp_val_F1.append(F1)
            temp_val_Acc2.append(Acc2)
            temp_val_Prec2.append(Prec2)
            temp_val_Recll2.append(Recll2)
            temp_val_F2.append(F2)
            temp_val_Acc3.append(Acc3)
            temp_val_Prec3.append(Prec3)
            temp_val_Recll3.append(Recll3)
            temp_val_F3.append(F3)
            temp_val_Acc4.append(Acc4)
            temp_val_Prec4.append(Prec4)
            temp_val_Recll4.append(Recll4)
            temp_val_F4.append(F4)
            temp_val_accs.append(val_acc)
        val_losses.append(np.mean(temp_val_losses))
        val_accs.append(np.mean(temp_val_accs))
        print("Epoch {:05d} | Val_Loss {:.4f}| Val_Accuracy {:.4f}".format(
            epoch, np.mean(temp_val_losses), np.mean(temp_val_accs)))

        res = [
            'acc:{:.4f}'.format(np.mean(temp_val_Acc_all)),
            'C1:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc1),
                                                    np.mean(temp_val_Prec1),
                                                    np.mean(temp_val_Recll1),
                                                    np.mean(temp_val_F1)),
            'C2:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc2),
                                                    np.mean(temp_val_Prec2),
                                                    np.mean(temp_val_Recll2),
                                                    np.mean(temp_val_F2)),
            'C3:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc3),
                                                    np.mean(temp_val_Prec3),
                                                    np.mean(temp_val_Recll3),
                                                    np.mean(temp_val_F3)),
            'C4:{:.4f},{:.4f},{:.4f},{:.4f}'.format(np.mean(temp_val_Acc4),
                                                    np.mean(temp_val_Prec4),
                                                    np.mean(temp_val_Recll4),
                                                    np.mean(temp_val_F4))
        ]
        print('results:', res)
        early_stopping(np.mean(temp_val_losses), np.mean(temp_val_accs),
                       np.mean(temp_val_F1), np.mean(temp_val_F2),
                       np.mean(temp_val_F3), np.mean(temp_val_F4), model,
                       modelname, dataname)
        # Keep the metrics of the best-validation-loss epoch.
        if np.mean(temp_val_losses) < score:
            score = np.mean(temp_val_losses)
            accs = np.mean(temp_val_accs)
            F1 = np.mean(temp_val_F1)
            F2 = np.mean(temp_val_F2)
            F3 = np.mean(temp_val_F3)
            F4 = np.mean(temp_val_F4)

        if early_stopping.early_stop:
            print("Early stopping")
            accs = early_stopping.accs
            F1 = early_stopping.F1
            F2 = early_stopping.F2
            F3 = early_stopping.F3
            F4 = early_stopping.F4
            break
    print('*****************************')
    print('Acc {:.4f} | N {:.4f} | F {:.4f} | T {:.4f} | U{:.4f} '.format(
        accs, F1, F2, F3, F4))
    return train_losses, val_losses, train_accs, val_accs, accs, F1, F2, F3, F4