コード例 #1
0
ファイル: train.py プロジェクト: Aditya239233/GNNExplainer
def syn_task1(args, size, writer=None):
    """Train a node classifier on the graph produced by ``gengraph.gen_syn1``.

    Args:
        args: Run configuration; this function reads ``input_dim``,
            ``hidden_dim``, ``output_dim``, ``num_gc_layers``, ``bn``,
            ``method`` and ``gpu`` from it and forwards it to the model and
            the trainer.
        size: Passed to ``gen_syn1`` as ``width_basis``.
        writer: Forwarded unchanged to ``train_node_classifier``.
    """
    # Every node receives the same all-ones feature vector of length input_dim.
    const_features = featgen.ConstFeatureGen(
        np.ones(args.input_dim, dtype=float))
    G, labels, _name = gengraph.gen_syn1(width_basis=size,
                                         feature_generator=const_features)
    # Number of classes is derived from the largest label value.
    num_classes = max(labels) + 1

    # The method only affects the log line here; the encoder is constructed
    # with the same arguments either way.
    if args.method == "att":
        print("Method: att")
    else:
        print("Method:", args.method)
    model = models.GcnEncoderNode(
        args.input_dim,
        args.hidden_dim,
        args.output_dim,
        num_classes,
        args.num_gc_layers,
        bn=args.bn,
        args=args,
    )

    if args.gpu:
        model = model.cuda()

    train_node_classifier(G, labels, model, args, writer=writer)
コード例 #2
0
def fan(args, feat=None):
    """Train and evaluate a graph classifier on a pickled train/test split.

    The pickle at ``args.datadir / args.pkl_fname`` is indexed as
    ``(graphs, labels, test_graphs, test_labels)``.

    Args:
        args: Run configuration (paths, model dimensions, ``num_classes``,
            ``num_gc_layers``, ``bn``); also forwarded to ``prepare_data``,
            ``train`` and ``evaluate``.
        feat: When ``None``, constant all-ones node features of length
            ``args.input_dim`` are generated for every graph.
    """
    pkl_path = os.path.join(args.datadir, args.pkl_fname)
    with open(pkl_path, "rb") as pkl_file:
        data = pickle.load(pkl_file)
    graphs, labels, test_graphs, test_labels = (data[0], data[1], data[2],
                                                data[3])

    # Store each graph-level label on the graph object itself.
    for pos, graph in enumerate(graphs):
        graph.graph["label"] = labels[pos]
    for pos, graph in enumerate(test_graphs):
        graph.graph["label"] = test_labels[pos]

    if feat is None:
        featgen_const = featgen.ConstFeatureGen(
            np.ones(args.input_dim, dtype=float))
        for graph in graphs:
            featgen_const.gen_node_features(graph)
        for graph in test_graphs:
            featgen_const.gen_node_features(graph)

    train_dataset, test_dataset, _max_num_nodes = prepare_data(
        graphs, args, test_graphs=test_graphs)

    encoder = models.GcnEncoderGraph(
        args.input_dim,
        args.hidden_dim,
        args.output_dim,
        args.num_classes,
        args.num_gc_layers,
        bn=args.bn,
    )
    # .cuda() is applied unconditionally (args.gpu is not consulted here).
    model = encoder.cuda()

    train(train_dataset, model, args, test_dataset=test_dataset)
    # The test split is evaluated under the "Validation" tag.
    evaluate(test_dataset, model, args, "Validation")
コード例 #3
0
ファイル: gengraph.py プロジェクト: Aditya239233/GNNExplainer
def gen_syn3(nb_shapes=80, width_basis=300, feature_generator=None, m=5):
    """ Synthetic Graph #3:

    Start with Barabasi-Albert graph and attach grid-shaped subgraphs.

    Args:
        nb_shapes         :  The number of shapes (here 'grid') that should be added to the base graph.
        width_basis       :  The width of the basis graph (here 'Barabasi-Albert' random graph).
        feature_generator :  A `FeatureGenerator` for node features. If `None`, add constant features to nodes.
        m                 :  number of edges to attach to existing node (for BA graph).
                             Forwarded to `synthetic_structsim.build_graph`.

    Returns:
        G                 :  A networkx graph
        role_id           :  Role ID for each node in synthetic graph.
        name              :  A graph identifier ("ba_<width_basis>_<nb_shapes>")
    """
    basis_type = "ba"
    # Note: list multiplication repeats the SAME ["grid", 3] list object.
    list_shapes = [["grid", 3]] * nb_shapes

    plt.figure(figsize=(8, 6), dpi=300)

    # Forward the caller's `m` instead of a hard-coded 5, so the documented
    # parameter actually takes effect (the default keeps the old behaviour).
    G, role_id, _ = synthetic_structsim.build_graph(width_basis,
                                                    basis_type,
                                                    list_shapes,
                                                    start=0,
                                                    m=m)
    G = perturb([G], 0.01)[0]

    if feature_generator is None:
        feature_generator = featgen.ConstFeatureGen(1)
    feature_generator.gen_node_features(G)

    name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes)
    return G, role_id, name
コード例 #4
0
def gen_syn5(nb_shapes=80, width_basis=8, feature_generator=None, m=3):
    """ Synthetic Graph #5:
    Build a "tree" basis graph and attach "grid" shapes to it.
    Args:
        nb_shapes         :  The number of ["grid", m] shape specs handed to build_graph.
        width_basis       :  The width of the basis graph (basis type "tree").
        feature_generator :  A `FeatureGenerator` for node features. If `None`, `featgen.ConstFeatureGen(1)` is used.
        m                 :  Second entry of every ["grid", m] shape spec (presumably the grid size; verify against build_graph).
    Returns:
        G                 :  A networkx graph
        role_id           :  Role ID for each node in synthetic graph
        name              :  A graph identifier ("tree_<width_basis>_<nb_shapes>")
    """
    basis_type = "tree"
    # The same spec object is repeated nb_shapes times (shared reference).
    grid_spec = ["grid", m]
    list_shapes = [grid_spec] * nb_shapes

    plt.figure(figsize=(8, 6), dpi=300)

    built = synthetic_structsim.build_graph(width_basis,
                                            basis_type,
                                            list_shapes,
                                            start=0)
    G, role_id, _ = built
    G = perturb([G], 0.1)[0]

    if feature_generator is None:
        feature_generator = featgen.ConstFeatureGen(1)
    feature_generator.gen_node_features(G)

    name = f"{basis_type}_{width_basis}_{nb_shapes}"

    # NOTE(review): this writer is created but never used or closed in this
    # function; kept because constructing it may create the log directory.
    log_dir = os.path.join("log/syn5_base_h20_o20")
    writer = SummaryWriter(log_dir)

    return G, role_id, name
コード例 #5
0
def syn_task5(args, writer=None):
    """Train a node classifier on the graph produced by ``gengraph.gen_syn5``.

    Args:
        args: Run configuration; reads ``input_dim``, ``hidden_dim``,
            ``output_dim``, ``num_gc_layers``, ``bn``, ``method`` and ``gpu``.
        writer: Forwarded unchanged to ``train_node_classifier``.

    Raises:
        NotImplementedError: If ``args.method == "attn"``. No model is built
            for that method, so training previously crashed with an
            ``UnboundLocalError`` on ``model``.
    """
    # data
    G, labels, name = gengraph.gen_syn5(
        feature_generator=featgen.ConstFeatureGen(
            np.ones(args.input_dim, dtype=float)))
    print(labels)
    print("Number of nodes: ", G.number_of_nodes())
    num_classes = max(labels) + 1

    if args.method == "attn":
        print("Method: attn")
        # Fail with an explicit message rather than an unbound `model` below.
        raise NotImplementedError(
            "syn_task5 does not implement method 'attn'; "
            "only the base method is available")

    print("Method: base")
    model = models.GcnEncoderNode(
        args.input_dim,
        args.hidden_dim,
        args.output_dim,
        num_classes,
        args.num_gc_layers,
        bn=args.bn,
        args=args,
    )

    if args.gpu:
        model = model.cuda()

    train_node_classifier(G, labels, model, args, writer=writer)
コード例 #6
0
ファイル: gengraph.py プロジェクト: lizzij/LCProj
def gen_sat1(nb_shapes=80, width_basis=300, feature_generator=None, m=5):
    """ Sat Graph #1:

    Start with Barabasi-Albert graph and attach house-shaped subgraphs.

    Args:
        nb_shapes         :  The number of shapes (here 'houses') that should be added to the base graph.
        width_basis       :  The width of the basis graph (here 'Barabasi-Albert' random graph).
        feature_generator :  A `FeatureGenerator` for node features. If `None`, add constant features to nodes.
        m                 :  number of edges to attach to existing node (for BA graph).
                             Forwarded to `synthetic_structsim.build_graph`.

    Returns:
        G                 :  A networkx graph
        role_id           :  A list with length equal to number of nodes in the entire graph (basis
                          :  + shapes). role_id[i] is the ID of the role of node i. It is the label.
        name              :  A graph identifier ("ba_<width_basis>_<nb_shapes>")
    """
    basis_type = "ba"
    # Note: list multiplication repeats the SAME ["house"] list object.
    list_shapes = [["house"]] * nb_shapes

    plt.figure(figsize=(8, 6), dpi=300)

    # Forward the caller's `m` instead of a hard-coded 5, so the documented
    # parameter actually takes effect (the default keeps the old behaviour).
    G, role_id, _ = synthetic_structsim.build_graph(width_basis,
                                                    basis_type,
                                                    list_shapes,
                                                    start=0,
                                                    m=m)
    G = perturb([G], 0.01)[0]

    if feature_generator is None:
        feature_generator = featgen.ConstFeatureGen(1)
    feature_generator.gen_node_features(G)

    name = basis_type + "_" + str(width_basis) + "_" + str(nb_shapes)
    return G, role_id, name
コード例 #7
0
def synthetic_data(dataset, dirname, train_ratio=0.8, input_dim=10):
    """
    Create synthetic data, similarly to what was done in GNNExplainer.
    The pipeline was adapted so as to fit ours.

    The result is cached at ``data/<dataset>.pth``; if that file exists it is
    loaded and returned without regenerating the graph.

    Args:
        dataset: One of ``'syn1'``, ``'syn2'``, ``'syn4'``, ``'syn5'``.
        dirname: Not used by this function (kept for interface compatibility).
        train_ratio: Forwarded to ``split_function``.
        input_dim: Length of the constant all-ones node feature vector. For
            ``'syn2'`` it is replaced by the length of node 0's ``"feat"``.

    Returns:
        A ``SimpleNamespace`` with ``x``, ``edge_index``, ``y``,
        ``num_classes``, ``num_features``, ``num_nodes``, ``name`` and the
        ``train_mask`` / ``val_mask`` / ``test_mask`` attributes.

    Raises:
        ValueError: If ``dataset`` is not a supported name and no cached file
            exists (previously this surfaced as an unbound-variable error).
    """
    # Define path where dataset should be saved
    data_path = "data/{}.pth".format(dataset)

    # If already created, do not recreate
    if os.path.exists(data_path):
        return torch.load(data_path)

    # Construct graph
    if dataset == 'syn1':
        G, labels, name = gengraph.gen_syn1(
            feature_generator=featgen.ConstFeatureGen(np.ones(input_dim)))
    elif dataset == 'syn4':
        G, labels, name = gengraph.gen_syn4(
            feature_generator=featgen.ConstFeatureGen(
                np.ones(input_dim, dtype=float)))
    elif dataset == 'syn5':
        G, labels, name = gengraph.gen_syn5(
            feature_generator=featgen.ConstFeatureGen(
                np.ones(input_dim, dtype=float)))
    elif dataset == 'syn2':
        G, labels, name = gengraph.gen_syn2()
        # syn2 brings its own features; take the dimension from node 0.
        input_dim = len(G.nodes[0]["feat"])
    else:
        raise ValueError(
            "Unknown synthetic dataset: {!r} (expected 'syn1', 'syn2', "
            "'syn4' or 'syn5')".format(dataset))

    # Create dataset
    data = SimpleNamespace()
    data.x, data.edge_index, data.y = gengraph.preprocess_input_graph(
        G, labels)
    data.x = data.x.type(torch.FloatTensor)
    data.num_classes = max(labels) + 1
    data.num_features = input_dim
    data.num_nodes = G.number_of_nodes()
    data.name = dataset

    # Train/test split only for nodes
    data.train_mask, data.val_mask, data.test_mask = split_function(
        data.y.numpy(), train_ratio)

    # Save data; make sure the cache directory exists so torch.save cannot
    # fail on a fresh checkout.
    os.makedirs(os.path.dirname(data_path), exist_ok=True)
    torch.save(data, data_path)

    return data
コード例 #8
0
def benchmark_task_val(args, writer=None, feat="node-label"):
    """Run 10 train/validation splits on a benchmark and print mean accuracy.

    For each of 10 splits a fresh ``GcnEncoderGraph`` is trained; the
    validation-accuracy sequences returned by ``train`` are stacked and
    averaged across splits. The averaged curve, its maximum and the index of
    the maximum are printed.

    Args:
        args: Run configuration (data location, model dimensions, etc.).
        writer: Forwarded unchanged to ``train``.
        feat: ``"node-feat"`` to use existing node features, ``"node-label"``
            to copy each node's ``"label"`` into ``"feat"``; otherwise (or if
            the requested data is absent on the first graph) constant
            all-ones features of length ``args.input_dim`` are generated.
    """
    graphs = io_utils.read_graphfile(args.datadir,
                                     args.bmname,
                                     max_nodes=args.max_nodes)

    if feat == "node-feat" and "feat_dim" in graphs[0].graph:
        print("Using node features")
        # Overwritten by prepare_val_data below.
        input_dim = graphs[0].graph["feat_dim"]
    elif feat == "node-label" and "label" in graphs[0].nodes[0]:
        print("Using node labels")
        for graph in graphs:
            for node in graph.nodes():
                graph.nodes[node]["feat"] = np.array(
                    graph.nodes[node]["label"])
    else:
        print("Using constant labels")
        featgen_const = featgen.ConstFeatureGen(
            np.ones(args.input_dim, dtype=float))
        for graph in graphs:
            featgen_const.gen_node_features(graph)

    # 10 splits
    per_split_accs = []
    for split in range(10):
        prepared = cross_val.prepare_val_data(graphs,
                                              args,
                                              split,
                                              max_nodes=args.max_nodes)
        (train_dataset, val_dataset, _max_num_nodes, input_dim,
         _assign_input_dim) = prepared

        print("Method: base")
        encoder = models.GcnEncoderGraph(
            input_dim,
            args.hidden_dim,
            args.output_dim,
            args.num_classes,
            args.num_gc_layers,
            bn=args.bn,
            dropout=args.dropout,
            args=args,
        )
        # .cuda() is applied unconditionally (args.gpu is not consulted here).
        model = encoder.cuda()

        _, val_accs = train(
            train_dataset,
            model,
            args,
            val_dataset=val_dataset,
            test_dataset=None,
            writer=writer,
        )
        per_split_accs.append(np.array(val_accs))

    # Average the per-split accuracy curves position by position.
    mean_accs = np.mean(np.vstack(per_split_accs), axis=0)
    print(mean_accs)
    print(np.max(mean_accs))
    print(np.argmax(mean_accs))
コード例 #9
0
def reveal(args, idx=None, writer=None):
    """Train a node classifier on the disjoint union of 10 Enron slices.

    When ``idx`` is ``None``, the pickled graphs
    ``data/gnn-explainer-enron/enron_slice_{0..9}.pkl`` are loaded, given
    constant all-ones node features of length ``args.input_dim``, merged with
    ``nx.disjoint_union_all`` and used to train a ``GcnEncoderNode``. Node
    labels come from each node's ``"role"`` attribute (``"None"`` if absent),
    mapped through ``labels_dict``.

    Args:
        args: Run configuration (model dimensions, ``bn``, ``gpu``, ...).
        idx: If not ``None``, only a message is printed and nothing is trained.
        writer: Forwarded unchanged to ``train_node_classifier``.
    """
    labels_dict = {
        "None": 5,
        "Employee": 0,
        "Vice President": 1,
        "Manager": 2,
        "Trader": 3,
        "CEO+Managing Director+Director+President": 4,
    }
    if idx is not None:
        print("Running Enron full task")
        return

    G_list = []
    for i in range(10):
        slice_path = "data/gnn-explainer-enron/enron_slice_{}.pkl".format(i)
        # Context manager so the file handle is closed deterministically.
        # NOTE: pickle.load executes arbitrary code from the file; only use
        # with trusted data files.
        with open(slice_path, "rb") as pkl_file:
            net = pickle.load(pkl_file)
        featgen_const = featgen.ConstFeatureGen(
            np.ones(args.input_dim, dtype=float))
        featgen_const.gen_node_features(net)
        G_list.append(net)
        print(net.number_of_nodes())

    G = nx.disjoint_union_all(G_list)
    model = models.GcnEncoderNode(
        args.input_dim,
        args.hidden_dim,
        args.output_dim,
        len(labels_dict),
        args.num_gc_layers,
        bn=args.bn,
        args=args,
    )
    labels = [n[1].get("role", "None") for n in G.nodes(data=True)]
    labels_num = [labels_dict[l] for l in labels]
    # NOTE(review): only labels 0-4 are counted here; label 5 ("None") is not
    # reported. Left as-is in case the omission is deliberate.
    for i in range(5):
        print("Label ", i, ": ", labels_num.count(i))

    print("Total num nodes: ", len(labels_num))
    print(labels_num)

    if args.gpu:
        model = model.cuda()
    train_node_classifier(G, labels_num, model, args, writer=writer)
コード例 #10
0
def enron_task_multigraph(args, idx=None, writer=None):
    """Train a node classifier over 10 separate Enron slice graphs.

    When ``idx`` is ``None``, each pickled graph
    ``data/gnn-explainer-enron/enron_slice_{0..9}.pkl`` is loaded, padded
    with nodes ``0..max_enron_id-1``, labelled from each node's ``"role"``
    attribute (``"None"`` if absent) via ``labels_dict``, and given constant
    all-ones node features of length ``args.input_dim``. The graphs and their
    label lists are passed to ``train_node_classifier_multigraph``.

    Args:
        args: Run configuration (model dimensions, ``num_classes``, ``bn``,
            ``gpu``, ...).
        idx: If not ``None``, only a message is printed and nothing is trained.
        writer: Forwarded unchanged to the trainer.
    """
    labels_dict = {
        "None": 5,
        "Employee": 0,
        "Vice President": 1,
        "Manager": 2,
        "Trader": 3,
        "CEO+Managing Director+Director+President": 4,
    }
    max_enron_id = 183
    if idx is not None:
        print("Running Enron full task")
        return

    G_list = []
    labels_list = []
    for i in range(10):
        slice_path = "data/gnn-explainer-enron/enron_slice_{}.pkl".format(i)
        # Context manager so the file handle is closed deterministically.
        # NOTE: pickle.load executes arbitrary code from the file; only use
        # with trusted data files.
        with open(slice_path, "rb") as pkl_file:
            net = pickle.load(pkl_file)
        # Ensure every slice contains node ids 0..max_enron_id-1.
        net.add_nodes_from(range(max_enron_id))
        labels = [n[1].get("role", "None") for n in net.nodes(data=True)]
        labels_num = [labels_dict[l] for l in labels]
        featgen_const = featgen.ConstFeatureGen(
            np.ones(args.input_dim, dtype=float))
        featgen_const.gen_node_features(net)
        G_list.append(net)
        labels_list.append(labels_num)

    model = models.GcnEncoderNode(
        args.input_dim,
        args.hidden_dim,
        args.output_dim,
        args.num_classes,
        args.num_gc_layers,
        bn=args.bn,
        args=args,
    )
    if args.gpu:
        model = model.cuda()
    # Prints the numeric labels of the last slice only.
    print(labels_num)
    train_node_classifier_multigraph(G_list,
                                     labels_list,
                                     model,
                                     args,
                                     writer=writer)
コード例 #11
0
def gen_syn1(nb_shapes=80, width_basis=300, feature_generator=None, m=5):
    """Build a synthetic graph: a 'ba' basis with 'house' shapes attached.

    Args:
        nb_shapes: Number of ``['house']`` shape specs handed to
            ``synthetic_structsim.build_graph``.
        width_basis: Width of the basis graph (basis type ``'ba'``).
        feature_generator: Object whose ``gen_node_features(G)`` assigns node
            features. If ``None``, ``featgen.ConstFeatureGen(1)`` is used.
        m: Forwarded to ``build_graph`` as ``m`` (default 5).

    Returns:
        Tuple ``(G, role_id, name)`` where ``G`` is the graph after
        ``perturb_new``, ``role_id`` is the second value returned by
        ``build_graph`` and ``name`` is ``"ba_<width_basis>_<nb_shapes>"``.
    """
    basis_type = 'ba'
    # Note: list multiplication repeats the SAME ['house'] list object.
    list_shapes = [['house']] * nb_shapes

    plt.figure(figsize=(8, 6), dpi=300)

    # Forward the caller's `m` instead of a hard-coded 5, so the parameter
    # actually takes effect (the default keeps the old behaviour).
    G, role_id, _ = synthetic_structsim.build_graph(width_basis,
                                                    basis_type,
                                                    list_shapes,
                                                    start=0,
                                                    m=m)
    G = perturb_new([G], 0.01)[0]

    if feature_generator is None:
        feature_generator = featgen.ConstFeatureGen(1)
    feature_generator.gen_node_features(G)

    name = basis_type + '_' + str(width_basis) + '_' + str(nb_shapes)
    return G, role_id, name
コード例 #12
0
def gen_syn5(nb_shapes=80, width_basis=8, feature_generator=None, m=3):
    """Build a synthetic graph: a 'tree' basis with 'grid' shapes attached.

    Args:
        nb_shapes: Number of ``['grid', m]`` shape specs handed to
            ``synthetic_structsim.build_graph``.
        width_basis: Width of the basis graph (basis type ``'tree'``).
        feature_generator: Object whose ``gen_node_features(G)`` assigns node
            features. If ``None``, ``featgen.ConstFeatureGen(1)`` is used.
        m: Second entry of every ``['grid', m]`` shape spec.

    Returns:
        Tuple ``(G, role_id, name)`` where ``G`` is the graph after
        ``perturb_new``, ``role_id`` is the second value returned by
        ``build_graph`` and ``name`` is ``"tree_<width_basis>_<nb_shapes>"``.
    """
    basis_type = 'tree'
    # The same spec object is repeated nb_shapes times (shared reference).
    shape_spec = ['grid', m]
    list_shapes = [shape_spec] * nb_shapes
    fig = plt.figure(figsize=(8, 6), dpi=300)

    G, role_id, plugins = synthetic_structsim.build_graph(
        width_basis, basis_type, list_shapes, start=0)
    perturbed = perturb_new([G], 0.1)
    G = perturbed[0]

    node_feature_gen = (featgen.ConstFeatureGen(1)
                        if feature_generator is None else feature_generator)
    node_feature_gen.gen_node_features(G)

    name = '_'.join([basis_type, str(width_basis), str(nb_shapes)])

    # NOTE(review): the writer is constructed but never used or closed here;
    # kept because constructing it may create the log directory.
    writer = SummaryWriter(os.path.join('log/syn5_base_h20_o20'))

    return G, role_id, name
コード例 #13
0
def read_biosnap(datadir,
                 edgelist_file,
                 label_file,
                 feat_file=None,
                 concat=True):
    """ Read data from BioSnap

    Builds a graph from an edge list, keeps only its largest connected
    component, attaches a binary ``"label"`` to each node (1 if the label
    column equals ``"Essential"``, else 0), drops unlabelled nodes and finally
    attaches a ``"feat"`` vector to every remaining node.

    Args:
        datadir: Directory containing the input files.
        edgelist_file: Edge list file name; tab-delimited if the name contains
            ``"tsv"``, comma-delimited otherwise. Read without a header row.
        label_file: Tab-delimited label file; column 0 is the node id and
            column 1 the label string.
        feat_file: Optional comma-delimited feature file whose first column is
            the node id. If ``None``, constant all-ones features of length 10
            are generated instead.
        concat: If true, each feature vector is ``log(row[1:] + 0.1)``
            followed by a constant ``1.0`` and a 10-slot one-hot encoding of
            ``min(degree, 10)``; otherwise only the log features are used.

    Returns:
        A networkx graph with ``"label"`` and ``"feat"`` node attributes.
    """
    G = nx.Graph()
    delimiter = "\t" if "tsv" in edgelist_file else ","
    print(delimiter)
    df = pd.read_csv(os.path.join(datadir, edgelist_file),
                     delimiter=delimiter,
                     header=None)
    data = list(map(tuple, df.values.tolist()))
    G.add_edges_from(data)
    print("Total nodes: ", G.number_of_nodes())

    # nx.connected_component_subgraphs was removed in networkx 2.4. Take the
    # largest component's node set and copy the induced subgraph; the copy is
    # required because nodes are removed and attributes added further down.
    largest_component = max(nx.connected_components(G), key=len)
    G = G.subgraph(largest_component).copy()
    print("Total nodes in largest connected component: ", G.number_of_nodes())

    df = pd.read_csv(os.path.join(datadir, label_file),
                     delimiter="\t",
                     usecols=[0, 1])
    data = list(map(tuple, df.values.tolist()))

    missing_node = 0
    for line in data:
        node = int(line[0])
        if node not in G:
            missing_node += 1
        else:
            G.nodes[node]["label"] = int(line[1] == "Essential")

    print("missing node: ", missing_node)

    # Drop every node that did not receive a label.
    remove_nodes = [u for u in G.nodes() if "label" not in G.nodes[u]]
    missing_label = len(remove_nodes)
    G.remove_nodes_from(remove_nodes)
    print("missing_label: ", missing_label)

    if feat_file is None:
        feature_generator = featgen.ConstFeatureGen(np.ones(10, dtype=float))
        feature_generator.gen_node_features(G)
    else:
        df = pd.read_csv(os.path.join(datadir, feat_file), delimiter=",")
        data = np.array(df.values)
        print("Feat shape: ", data.shape)

        for row in data:
            node = int(row[0])
            if node not in G:
                continue
            if concat:
                onehot = np.zeros(10)
                # NOTE(review): a node whose degree dropped to 0 after the
                # removal above gets index -1, i.e. the same slot as
                # degree >= 10 - confirm this is intended.
                onehot[min(G.degree[node], 10) - 1] = 1.0
                G.nodes[node]["feat"] = np.hstack(
                    (np.log(row[1:] + 0.1), [1.0], onehot))
            else:
                G.nodes[node]["feat"] = np.log(row[1:] + 0.1)

        # Drop every node that did not receive a feature vector.
        remove_nodes = [u for u in G.nodes() if "feat" not in G.nodes[u]]
        missing_feat = len(remove_nodes)
        G.remove_nodes_from(remove_nodes)
        print("missing feat: ", missing_feat)

    return G
コード例 #14
0
def syn_task1(args, writer=None):
    """Train a ``GCNNet`` node classifier on one of the synthetic graphs.

    The graph generator is selected by ``args.dataset`` (``'syn1'`` to
    ``'syn5'``). Nodes are shuffled and split into train/test index lists by
    ``args.train_ratio``; the model is trained for ``args.num_epochs`` epochs
    and test accuracy is reported every 10 epochs.

    Args:
        args: Run configuration; reads ``input_dim``, ``hidden_dim``,
            ``output_dim``, ``num_gc_layers``, ``dataset``, ``gpu``,
            ``train_ratio``, ``weight_decay``, ``num_epochs`` and ``clip``.
        writer: Optional summary writer with ``add_scalar``; logging is
            skipped when it is ``None``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.

    Note:
        The final report and the evaluation pass reuse ``epoch``,
        ``test_acc`` and ``batch`` from the training loop, so
        ``args.num_epochs`` must be at least 1.
    """
    print('Generating graph.')
    feature_generator = featgen.ConstFeatureGen(
        np.ones(args.input_dim, dtype=float))
    if args.dataset == 'syn1':
        gen_fn = gengraph.gen_syn1
    elif args.dataset == 'syn2':
        gen_fn = gengraph.gen_syn2
        # syn2 is generated without an external feature generator.
        feature_generator = None
    elif args.dataset == 'syn3':
        gen_fn = gengraph.gen_syn3
    elif args.dataset == 'syn4':
        gen_fn = gengraph.gen_syn4
    elif args.dataset == 'syn5':
        gen_fn = gengraph.gen_syn5
    else:
        # Previously an unknown name fell through and crashed on `gen_fn`.
        raise ValueError('Unknown dataset: {!r}'.format(args.dataset))
    G, labels, name = gen_fn(feature_generator=feature_generator)
    # 'gpu' is not a valid torch device type; the CUDA device is 'cuda'.
    pyg_G = NxDataset([G],
                      device=torch.device('cuda' if args.gpu else 'cpu'))[0]
    num_classes = max(labels) + 1
    labels = torch.LongTensor(labels)
    print('Done generating graph.')

    model = GCNNet(args.input_dim,
                   args.hidden_dim,
                   args.output_dim,
                   num_classes,
                   args.num_gc_layers,
                   args=args)

    if args.gpu:
        model = model.cuda()

    train_ratio = args.train_ratio
    num_train = int(train_ratio * G.number_of_nodes())

    # Random node split: first num_train shuffled indices train, rest test.
    idx = list(range(G.number_of_nodes()))
    np.random.shuffle(idx)
    train_mask = idx[:num_train]
    test_mask = idx[num_train:]

    loader = torch_geometric.data.DataLoader([pyg_G], batch_size=1)
    # NOTE(review): the returned scheduler is never stepped in this function.
    scheduler, opt = train_utils.build_optimizer(
        args, model.parameters(), weight_decay=args.weight_decay)
    for epoch in range(args.num_epochs):
        model.train()
        total_loss = 0
        for batch in loader:
            opt.zero_grad()
            pred = model(batch)

            # Loss is computed on the training nodes only.
            pred = pred[train_mask]
            label = labels[train_mask]
            loss = model.loss(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            opt.step()
            total_loss += loss.item() * 1
        if writer is not None:
            writer.add_scalar("loss", total_loss, epoch)

        if epoch % 10 == 0:
            test_acc = test(loader, model, args, labels, test_mask)
            print("{} {:.4f} {:.4f}".format(epoch, total_loss, test_acc))
            if writer is not None:
                writer.add_scalar("test", test_acc, epoch)

    print("{} {:.4f} {:.4f}".format(epoch, total_loss, test_acc))
    # NOTE(review): adj, x and ypred are computed but not used in the visible
    # part of this function; kept in case the snippet is truncated.
    data = gengraph.preprocess_input_graph(G, labels)
    adj = torch.tensor(data['adj'], dtype=torch.float)
    x = torch.tensor(data['feat'], requires_grad=True, dtype=torch.float)

    model.eval()
    ypred = model(batch)
コード例 #15
0
def FFMpeg(args, writer=None, feat="node-label"):
    """Train and evaluate a graph classifier on a benchmark dataset.

    Args:
        args: Run configuration (data location, model dimensions, ``method``,
            pooling options, ...); also forwarded to ``prepare_data``,
            ``train`` and ``evaluate``.
        writer: Forwarded unchanged to ``train``.
        feat: ``"node-feat"`` to use existing node features, ``"node-label"``
            to copy each node's ``"label"`` into ``"feat"``; otherwise (or if
            the requested data is absent on the first graph) constant
            all-ones features of length ``args.input_dim`` are generated.
    """
    graphs = io_utils.read_graphfile(args.datadir,
                                     args.bmname,
                                     max_nodes=args.max_nodes)
    # Largest graph-level label in the dataset.
    graph_labels = [G.graph["label"] for G in graphs]
    print(max(graph_labels))

    if feat == "node-feat" and "feat_dim" in graphs[0].graph:
        print("Using node features")
        # Overwritten by prepare_data below.
        input_dim = graphs[0].graph["feat_dim"]
    elif feat == "node-label" and "label" in graphs[0].nodes[0]:
        print("Using node labels")
        for graph in graphs:
            for node in graph.nodes():
                graph.nodes[node]["feat"] = np.array(
                    graph.nodes[node]["label"])
    else:
        print("Using constant labels")
        featgen_const = featgen.ConstFeatureGen(
            np.ones(args.input_dim, dtype=float))
        for graph in graphs:
            featgen_const.gen_node_features(graph)

    prepared = prepare_data(graphs, args, max_nodes=args.max_nodes)
    (train_dataset, val_dataset, test_dataset, max_num_nodes, input_dim,
     assign_input_dim) = prepared

    soft_assign = args.method == "soft-assign"
    if soft_assign:
        print("Method: soft-assign")
        encoder = models.SoftPoolingGcnEncoder(
            max_num_nodes,
            input_dim,
            args.hidden_dim,
            args.output_dim,
            args.num_classes,
            args.num_gc_layers,
            args.hidden_dim,
            assign_ratio=args.assign_ratio,
            num_pooling=args.num_pool,
            bn=args.bn,
            dropout=args.dropout,
            linkpred=args.linkpred,
            args=args,
            assign_input_dim=assign_input_dim,
        )
    else:
        print("Method: base")
        encoder = models.GcnEncoderGraph(
            input_dim,
            args.hidden_dim,
            args.output_dim,
            args.num_classes,
            args.num_gc_layers,
            bn=args.bn,
            dropout=args.dropout,
            args=args,
        )
    # .cuda() is applied unconditionally (args.gpu is not consulted here).
    model = encoder.cuda()

    train(
        train_dataset,
        model,
        args,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        writer=writer,
    )
    # The test split is evaluated under the "Validation" tag.
    evaluate(test_dataset, model, args, "Validation")
コード例 #16
0
import networkx as nx
from synthetic import gen_syn4
import pickle
import numpy as np
from utils import featgen

# Generate the syn4 graph with constant all-ones node features of length 10.
G, labels, _ = gen_syn4(
    feature_generator=featgen.ConstFeatureGen(np.ones(10, dtype=float)))

# nx.write_gpickle was removed in networkx 3.0. It was a thin wrapper around
# pickle.dump with the highest protocol, so this writes an equivalent file
# on every networkx version.
with open('data/syn4_G.pickle', 'wb') as f:
    pickle.dump(G, f, pickle.HIGHEST_PROTOCOL)

# Store the labels alongside the graph.
with open('data/syn4_lab.pickle', 'wb') as f:
    pickle.dump(labels, f)
コード例 #17
0
def syn_task1(args, writer=None):
    """Train a ``GCNNet`` node classifier on the ``gen_syn1`` graph.

    The graph is generated with constant all-ones node features of length
    ``args.input_dim``, converted with ``from_networkx`` and trained for
    ``args.num_epochs`` epochs with Adam. Nodes are randomly split into
    train/test boolean masks according to ``args.train_ratio``; test accuracy
    is reported every 10 epochs.

    Args:
        args: Run configuration; reads ``input_dim``, ``hidden_dim``,
            ``num_gc_layers``, ``gpu``, ``train_ratio``, ``lr`` and
            ``num_epochs``.
        writer: Object with ``add_scalar``. It is called unconditionally, so
            the ``None`` default fails as soon as the first epoch finishes.
    """
    # data
    print('Generating graph.')
    G, labels, name = gengraph.gen_syn1(
        feature_generator=featgen.ConstFeatureGen(
            np.ones(args.input_dim, dtype=float)))
    # print ('G.node[0]:', G.node[0]['feat'].dtype)
    # print ('Original labels:', labels)
    pyg_G = from_networkx(G)
    # Number of classes is derived from the largest label value.
    num_classes = max(labels) + 1
    labels = torch.LongTensor(labels)
    print('Done generating graph.')

    # if args.method == 'att':
    # print('Method: att')
    # model = models.GcnEncoderNode(args.input_dim, args.hidden_dim, args.output_dim, num_classes,
    # args.num_gc_layers, bn=args.bn, args=args)

    # else:
    # print('Method:', args.method)
    # model = models.GcnEncoderNode(args.input_dim, args.hidden_dim, args.output_dim, num_classes,
    # args.num_gc_layers, bn=args.bn, args=args)

    model = GCNNet(args.input_dim,
                   args.hidden_dim,
                   num_classes,
                   args.num_gc_layers,
                   args=args)

    if args.gpu:
        model = model.cuda()

    train_ratio = args.train_ratio
    num_train = int(train_ratio * G.number_of_nodes())
    num_test = G.number_of_nodes() - num_train
    shuffle_indices = list(range(G.number_of_nodes()))
    shuffle_indices = np.random.permutation(shuffle_indices)

    # Build complementary boolean masks, then permute both with the same
    # random order so each node lands in exactly one of train/test.
    train_mask = num_train * [True] + num_test * [False]
    train_mask = torch.BoolTensor([train_mask[i] for i in shuffle_indices])
    test_mask = num_train * [False] + num_test * [True]
    test_mask = torch.BoolTensor([test_mask[i] for i in shuffle_indices])

    # Single-graph loader: each epoch iterates over exactly one batch.
    loader = torch_geometric.data.DataLoader([pyg_G], batch_size=1)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.num_epochs):
        total_loss = 0
        model.train()
        for batch in loader:
            # print ('batch:', batch.feat)
            opt.zero_grad()
            pred = model(batch)

            # Loss is computed on the training nodes only.
            pred = pred[train_mask]
            # print ('pred:', pred)
            label = labels[train_mask]
            # print ('label:', label)
            loss = model.loss(pred, label)
            print('loss:', loss)
            loss.backward()
            opt.step()
            total_loss += loss.item() * 1
        # The accumulated loss is divided by the number of training nodes
        # before being logged.
        total_loss /= num_train
        writer.add_scalar("loss", total_loss, epoch)

        # Evaluate on the held-out nodes every 10th epoch.
        if epoch % 10 == 0:
            test_acc = test(loader, model, args, labels, test_mask)
            print("Epoch {}. Loss: {:.4f}. Test accuracy: {:.4f}".format(
                epoch, total_loss, test_acc))
            writer.add_scalar("test accuracy", test_acc, epoch)