Example #1
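The example is shown without its imports; a best-guess sketch of what it relies on follows (get_dataset, UGRL_GCN_test, LogReg, cosine_dist, and normalize_graph are project-specific helpers assumed to come from the surrounding repository, not library calls):

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from colorama import Fore   # used for the colored progress string
from tqdm import tqdm

# Assumed project-local helpers (names taken from their usage below):
# from dataset import get_dataset
# from model import UGRL_GCN_test, LogReg
# from utils import cosine_dist, normalize_graph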
def run_GCN(args,
            gpu_id=None,
            exp_name=None,
            number=0,
            return_model=False,
            return_time_series=False):
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    final_acc = 0
    best_acc = 0
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    running_device = "cpu" if gpu_id is None \
        else torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() else 'cpu')
    dataset_kwargs = {}

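    # Load the dataset along with its adjacency matrices (adj_list) and node-feature matrices (x_list).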
    train_d, adj_list, x_list = get_dataset(args, dataset_kwargs)

    lable = train_d.data.y
    A_I = adj_list[0]
    A_I_nomal = adj_list[1]

    nb_edges = train_d.data.num_edges
    nb_nodes = train_d.data.num_nodes
    nb_feature = train_d.data.num_features
    nb_classes = int(lable.max() - lable.min()) + 1

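    # Pairwise indicator of nodes sharing a label, with the self-comparisons on the diagonal zeroed out.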
    lable_matrix = (lable.view(nb_nodes, 1).repeat(1, nb_nodes) == lable.view(
        1, nb_nodes).repeat(nb_nodes, 1)) + 0
    I = (torch.eye(nb_nodes).to(lable_matrix.device) == 1)
    lable_matrix[I] = 0
    zero_vec = torch.zeros_like(A_I_nomal)
    if args.dataset_name in [
            'Photo', 'DBLP', 'Crocodile', 'CoraFull', 'WikiCS'
    ]:
        useA = True
    else:
        useA = False
    model = UGRL_GCN_test(nb_nodes,
                          nb_feature,
                          args.dim,
                          dim_x=args.dim_x,
                          useact=args.usingact,
                          liner=args.UsingLiner,
                          dropout=args.dropout,
                          useA=useA)

    optimiser = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=0.0001)

    model.to(running_device)
    lable = lable.to(running_device)
    if args.dataset_name == 'WikiCS':
        train_lbls = lable[train_d.data.train_mask[:, args.NewATop]]  # capture
        test_lbls = lable[train_d.data.test_mask]
    elif args.dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
        train_lbls = lable[train_d.data.train_mask]
        test_lbls = lable[train_d.data.test_mask]
    elif args.dataset_name in ['Photo', 'DBLP', 'Crocodile', 'CoraFull']:
        train_index = []
        test_index = []
        for j in range(lable.max().item() + 1):
            num = ((lable == j) + 0).sum().item()
            index = torch.arange(len(lable), device=lable.device)[lable == j]
            x_list0 = random.sample(list(index), int(len(index) * 0.1))
            for x in x_list0:
                train_index.append(int(x))
        for c in range(len(lable)):
            if int(c) not in train_index:
                test_index.append(int(c))
        train_lbls = lable[train_index]
        test_lbls = lable[test_index]
        val_lbls = lable[train_index]

    A_I_nomal_dense = A_I_nomal
    I_input = torch.eye(A_I_nomal.shape[1])  # .to(A_I_nomal.device)
    if args.dataset_name in ['PubMed', 'CoraFull', 'DBLP']:
        pass
    elif args.dataset_name in ['Crocodile', 'Photo', 'WikiCS']:
        A_I_nomal_dense = A_I_nomal_dense.to(running_device)
    ######################sparse################
    if args.dataset_name in [
            'PubMed', 'Crocodile', 'CoraFull', 'DBLP', 'Photo', 'WikiCS'
    ]:
        A_I_nomal = A_I_nomal.to_sparse()
        model.sparse = True
        I_input = I_input.to_sparse()
    ######################sparse################
    A_I_nomal = A_I_nomal.to(running_device)
    I_input = I_input.to(A_I_nomal.device)
    mask_I = I.to(running_device)
    zero_vec = zero_vec.to(running_device)
    my_margin = args.margin1
    my_margin_2 = my_margin + args.margin2
    margin_loss = torch.nn.MarginRankingLoss(margin=my_margin, reduction='none')
    num_neg = args.NN
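    # Training loop: anchor and positive views share the node features; each negative view is a row-permuted copy.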
    for current_iter, epoch in enumerate(
            tqdm(range(args.start_epoch, args.start_epoch + args.epochs + 1))):
        model.train()
        optimiser.zero_grad()
        idx = np.random.permutation(nb_nodes)
        feature_X = x_list[0].to(running_device)
        lbl_z = torch.tensor([0.]).to(running_device)
        feature_a = feature_X
        feature_p = feature_X
        feature_n = []
        idx_list = []
        idx_lable = []
        for i in range(num_neg):
            idx_0 = np.random.permutation(nb_nodes)
            idx_list.append(idx_0)
            idx_lable.append(lable[idx_0])
            feature_temp = feature_X[idx_0]
            feature_n.append(feature_temp)
        h_a, h_p, h_n_lsit, h_a_0, h_p_0, h_n_0_list = model(feature_a,
                                                             feature_p,
                                                             feature_n,
                                                             A_I_nomal,
                                                             I=I_input)
        s_p = F.pairwise_distance(h_a, h_p)
        cos_0_list = []
        for h_n_0 in h_n_0_list:
            cos_0 = F.pairwise_distance(h_a_0, h_n_0)
            cos_0_list.append(cos_0)
        cos_0_stack = torch.stack(cos_0_list).detach()
        cos_0_min = cos_0_stack.min(dim=0)[0]
        cos_0_max = cos_0_stack.max(dim=0)[0]
        gap = cos_0_max - cos_0_min
        weight_list = []
        for i in range(cos_0_stack.size()[0]):
            weight_list.append((cos_0_stack[i] - cos_0_min) / gap)
        s_n_list = []
        s_n_cosin_list = []
        for h_n in h_n_lsit:
            if args.dataset_name in ['Cora', 'CiteSeer']:
                s_n_cosin_list.append(cosine_dist(h_a, h_n)[mask_I].detach())
            s_n = F.pairwise_distance(h_a, h_n)
            s_n_list.append(s_n)
        margin_label = -1 * torch.ones_like(s_p)
        loss_mar = 0
        mask_margin_N = 0
        i = 0
        for s_n in s_n_list:
            loss_mar += (margin_loss(s_p, s_n, margin_label) *
                         weight_list[i]).mean()
            mask_margin_N += torch.max((s_n - s_p.detach() - my_margin_2),
                                       lbl_z).sum()
            i += 1
        mask_margin_N = mask_margin_N / num_neg
        string_1 = " loss_1: {:.3f}||loss_2: {:.3f}||".format(
            loss_mar.item(), mask_margin_N.item())
        loss = loss_mar * args.w_loss1 + mask_margin_N * args.w_loss2 / nb_nodes
        if args.dataset_name in ['Cora']:
            loss = loss_mar * args.w_loss1 + mask_margin_N * args.w_loss2
        loss.backward()
        optimiser.step()

        model.eval()
        if args.dataset_name in ['Crocodile', 'WikiCS', 'Photo']:
            h_p_d = h_p.detach()
            S_new = cosine_dist(h_p_d, h_p_d)
            model.A = normalize_graph(torch.mul(S_new,
                                                A_I_nomal_dense)).to_sparse()
        elif args.dataset_name in ['Cora', 'CiteSeer']:
            h_a, h_p = model.embed(feature_a,
                                   feature_p,
                                   feature_n,
                                   A_I_nomal,
                                   I=I_input)
            s_a = cosine_dist(h_a, h_a).detach()
            S = (torch.stack(s_n_cosin_list).mean(dim=0).expand_as(A_I) -
                 s_a).detach()
            # zero_vec = -9e15 * torch.ones_like(S)
            one_vec = torch.ones_like(S)
            s_a = torch.where(A_I_nomal > 0, one_vec, zero_vec)
            attention = torch.where(S < 0, s_a, zero_vec)
            attention_N = normalize_graph(attention)
            attention[I] = 0
            model.A = attention_N

        if epoch % 50 == 0:
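            # Freeze the encoder and fit a logistic-regression probe (LogReg) on the node labels.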
            model.eval()
            h_a, h_p = model.embed(feature_a,
                                   feature_p,
                                   feature_n,
                                   A_I_nomal,
                                   I=I_input)
            if args.useNewA:
                embs = h_p  #torch.cat((h_a,h_p),dim=1)
            else:
                embs = h_a
            if args.dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
                embs = embs / embs.norm(dim=1)[:, None]

            if args.dataset_name == 'WikiCS':
                train_embs = embs[train_d.data.train_mask[:, args.NewATop]]
                test_embs = embs[train_d.data.test_mask]
            elif args.dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
                train_embs = embs[train_d.data.train_mask]
                test_embs = embs[train_d.data.test_mask]
                # val_embs = embs[train_d.data.val_mask]
            elif args.dataset_name in [
                    'Photo', 'DBLP', 'Crocodile', 'CoraFull'
            ]:
                train_embs = embs[train_index]
                test_embs = embs[test_index]
                #val_embs = embs[train_index]

            accs = []
            accs_small = []
            xent = nn.CrossEntropyLoss()
            for _ in range(2):
                log = LogReg(args.dim, nb_classes)
                opt = torch.optim.Adam(log.parameters(),
                                       lr=1e-2,
                                       weight_decay=args.wd)
                log.to(running_device)
                for _ in range(args.num1):
                    log.train()
                    opt.zero_grad()
                    logits = log(train_embs)
                    loss = xent(logits, train_lbls)
                    loss.backward()
                    opt.step()
                logits = log(test_embs)
                preds = torch.argmax(logits, dim=1)
                acc = torch.sum(
                    preds == test_lbls).float() / test_lbls.shape[0]
                accs.append(acc * 100)
                ac = []
                for i in range(nb_classes):
                    acc_small = torch.sum(
                        preds[test_lbls == i] == test_lbls[test_lbls == i]
                    ).float() / test_lbls[test_lbls == i].shape[0]
                    ac.append(acc_small * 100)
                accs_small = ac
            accs = torch.stack(accs)
            string_3 = ""
            for i in range(nb_classes):
                string_3 = string_3 + "|{:.1f}".format(accs_small[i].item())
            string_2 = Fore.GREEN + " epoch: {},accs: {:.1f},std: {:.2f} ".format(
                epoch,
                accs.mean().item(),
                accs.std().item())
            tqdm.write(string_1 + string_2 + string_3)
            final_acc = accs.mean().item()
            best_acc = max(best_acc, final_acc)
    return final_acc, best_acc
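A minimal invocation sketch, assuming an argparse-style namespace. The field names are read off the function body above; every value is a placeholder rather than a recommended setting, and get_dataset may expect additional fields:

from argparse import Namespace

args = Namespace(
    seed=0, dataset_name='Cora', dim=512, dim_x=512,
    usingact=True, UsingLiner=True, dropout=0.2, useNewA=True,
    lr=1e-3, wd=1e-4, num1=300, NewATop=0,
    margin1=0.8, margin2=0.2, NN=4,
    start_epoch=0, epochs=500, w_loss1=1.0, w_loss2=1.0)

final_acc, best_acc = run_GCN(args, gpu_id=0)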
Example #2
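As above, the imports are omitted; a sketch of what this function appears to need (LogReg, run_kmeans, and run_similarity_search are assumed project-local helpers):

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score

# Assumed project-local helpers (names taken from their usage below):
# from model import LogReg
# from evaluate_utils import run_kmeans, run_similarity_search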
def evaluate(embeds, idx_train, idx_val, idx_test, labels, device, isTest=True):
    hid_units = embeds.shape[2]
    nb_classes = labels.shape[2]
    xent = nn.CrossEntropyLoss()
    train_embs = embeds[0, idx_train]
    val_embs = embeds[0, idx_val]
    test_embs = embeds[0, idx_test]

    train_lbls = torch.argmax(labels[0, idx_train], dim=1)
    val_lbls = torch.argmax(labels[0, idx_val], dim=1)
    test_lbls = torch.argmax(labels[0, idx_test], dim=1)

    accs = []
    micro_f1s = []
    macro_f1s = []
    macro_f1s_val = []  # validation macro-F1 at the best-validation iteration of each run
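    # 50 independent logistic-regression probes; each reports its test metric at the best-validation iteration.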
    for _ in range(50):
        log = LogReg(hid_units, nb_classes)
        opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
        log.to(device)

        val_accs = []
        test_accs = []
        val_micro_f1s = []
        test_micro_f1s = []
        val_macro_f1s = []
        test_macro_f1s = []
        for iter_ in range(50):
            # train
            log.train()
            opt.zero_grad()

            logits = log(train_embs)
            loss = xent(logits, train_lbls)

            loss.backward()
            opt.step()

            # val
            logits = log(val_embs)
            preds = torch.argmax(logits, dim=1)

            val_acc = torch.sum(preds == val_lbls).float() / val_lbls.shape[0]
            val_f1_macro = f1_score(val_lbls.cpu(), preds.cpu(), average='macro')
            val_f1_micro = f1_score(val_lbls.cpu(), preds.cpu(), average='micro')

            val_accs.append(val_acc.item())
            val_macro_f1s.append(val_f1_macro)
            val_micro_f1s.append(val_f1_micro)

            # test
            logits = log(test_embs)
            preds = torch.argmax(logits, dim=1)

            test_acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
            test_f1_macro = f1_score(test_lbls.cpu(), preds.cpu(), average='macro')
            test_f1_micro = f1_score(test_lbls.cpu(), preds.cpu(), average='micro')

            test_accs.append(test_acc.item())
            test_macro_f1s.append(test_f1_macro)
            test_micro_f1s.append(test_f1_micro)


        max_iter = val_accs.index(max(val_accs))
        accs.append(test_accs[max_iter])

        max_iter = val_macro_f1s.index(max(val_macro_f1s))
        macro_f1s.append(test_macro_f1s[max_iter])
        macro_f1s_val.append(val_macro_f1s[max_iter])

        max_iter = val_micro_f1s.index(max(val_micro_f1s))
        micro_f1s.append(test_micro_f1s[max_iter])

    if isTest:
        print("\t[Classification] Macro-F1: {:.4f} ({:.4f}) | Micro-F1: {:.4f} ({:.4f})".format(np.mean(macro_f1s),
                                                                                                np.std(macro_f1s),
                                                                                                np.mean(micro_f1s),
                                                                                                np.std(micro_f1s)))
    else:
        return np.mean(macro_f1s_val), np.mean(macro_f1s)

    test_embs = np.array(test_embs.cpu())
    test_lbls = np.array(test_lbls.cpu())

    run_kmeans(test_embs, test_lbls, nb_classes)
    run_similarity_search(test_embs, test_lbls)
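A hypothetical call, assuming the shapes implied by the indexing above (embeds of shape [1, N, hid_units] and one-hot labels of shape [1, N, nb_classes]):

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
evaluate(embeds, idx_train, idx_val, idx_test, labels, device, isTest=True)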