예제 #1
0
def main(arg):
    """Train an MLP classifier on features produced by a pretrained graph model.

    Loads the ``SelectGraph`` dataset named by ``arg.d``, builds the graph
    network selected by ``arg.m``, restores its weights from
    ``arg.model_dir + arg.m + ".ckpt"``, extracts features for the train and
    test splits with ``load_model_result``, and finally trains an ``MLP`` on
    those features with ``train_cf``.

    Args:
        arg: Parsed command-line namespace. Attributes read here: ``device``,
            ``e`` (number of epochs), ``batch`` (batch size), ``d`` (dataset
            name), ``shapes`` (comma-separated ints), ``n_skip``, ``n_train``,
            ``n_test`` (contiguous split sizes), ``m`` (model name), ``k``,
            ``depth``, ``c_rate``, ``model_dir``, ``hidden``, ``dropout``,
            ``lr``.

    Returns:
        None. Prints "model not found" and returns early when ``arg.m`` is
        not one of "MIAGAE", "UNet", "Gpool", "SAGpool".
    """
    device = torch.device(arg.device)

    num_epoch = arg.e
    batch_size = arg.batch

    # Read the dataset name from this function's ``arg`` parameter (the
    # original referenced an undefined/global ``args`` here).
    SelectGraph.data_name = arg.d
    data_set = SelectGraph('data/' + SelectGraph.data_name)
    input_size = data_set.num_features
    num_classes = data_set.num_classes
    shapes = list(map(int, arg.shapes.split(",")))

    # Contiguous, unshuffled splits: skip the first n_skip graphs, take the
    # next n_train for training and the following n_test for testing.
    train_end = arg.n_skip + arg.n_train
    test_end = train_end + arg.n_test
    train_set = DataLoader(data_set[arg.n_skip:train_end],
                           batch_size=batch_size,
                           shuffle=False)
    test_set = DataLoader(data_set[train_end:test_end],
                          batch_size=batch_size,
                          shuffle=False)

    model = _build_model(arg, input_size, shapes, device)
    if model is None:
        print("model not found")
        return

    # map_location lets a checkpoint saved on another device (e.g. a
    # different GPU index) be loaded directly onto ``device``.
    state_dict = torch.load(arg.model_dir + arg.m + ".ckpt",
                            map_location=device)
    model.load_state_dict(state_dict, strict=True)

    group1, group2 = load_model_result(model, train_set, test_set, device)
    # NOTE(review): assumes group1[0] is a 2-D tensor whose second dimension
    # is the extracted feature size — confirm against load_model_result.
    input_size2 = group1[0].shape[1]
    c_model = MLP(input_size2, arg.hidden, num_classes, arg.dropout).to(device)
    optimizer = torch.optim.Adam(c_model.parameters(), lr=arg.lr)
    train_cf(c_model, optimizer, device, train_set, test_set, num_epoch,
             group1, group2)


def _build_model(arg, input_size, shapes, device):
    """Instantiate the graph network named by ``arg.m`` and move it to ``device``.

    Returns None when ``arg.m`` is not a supported model name. Each model's
    module is imported lazily so only the selected one needs to be importable.
    """
    if arg.m == "MIAGAE":
        from classification.Graph_AE import Net
        return Net(input_size, arg.k, arg.depth, [arg.c_rate] * arg.depth,
                   shapes, device).to(device)
    if arg.m == "UNet":
        from classification.UNet import Net
        return Net(input_size, arg.depth, arg.c_rate, shapes,
                   device).to(device)
    if arg.m == "Gpool":
        from classification.Gpool_model import Net
        return Net(input_size, arg.depth, arg.c_rate, shapes,
                   device).to(device)
    if arg.m == "SAGpool":
        from classification.SAG_model import Net
        return Net(input_size, arg.depth, [arg.c_rate] * arg.depth, shapes,
                   device).to(device)
    return None
예제 #2
0
from torch_geometric.data import DataLoader
from utils.CustomDataSet import SelectGraph, SceneGraphs
from classification.Gpool_model import Net
from classification.Classifier import MLP
import utils.Display_Plot as dp
import torch

# Device on which both the graph model and the classifier are run.
device = torch.device('cuda:1')

num_epoch = 100
# NOTE(review): batch_size is not used by the visible code; each DataLoader
# below passes its own explicit batch size.
batch_size = 200
comp_model = Net.get_instance().to(device)
cfy_model = MLP.get_instance().to(device)

# Load the Shana7000 dataset once and carve it into three contiguous splits:
#   [0, 5000)    -> train_set  (shuffled, batches of 500)
#   [5000, 6000) -> train_set2 (also used to train the MLP classifier)
#   [6000, 7000) -> test_set
# The original constructed the identical dataset a second time for the last
# two splits; a single instance serves all three.
SelectGraph.data_name = 'Shana7000'
data_set_Shana = SelectGraph('data/' + SelectGraph.data_name)
train_set = DataLoader(data_set_Shana[:5000], 500, shuffle=True)
train_set2 = DataLoader(data_set_Shana[5000:6000], 1000, shuffle=False)
test_set = DataLoader(data_set_Shana[6000:7000], 1000, shuffle=False)

m_name = "Gpool_TRANSFER.ckpt"
data_list1, group1, group2 = comp_model.train_model(
    train_set, train_set2, test_set, num_epoch, m_name)
# The classifier is trained for half as many epochs, on the features
# (group1, group2) returned by the graph model.
data_list2 = cfy_model.train_model(
    train_set2, test_set, num_epoch // 2, group1, group2)

# The model trained here is Gpool (see the Gpool_model import and m_name);
# the original title mislabeled these plots as "SAG TRANSFER".
title = "Gpool TRANSFER"
labels = ['MSE Loss', 'Num Nodes', 'Total Loss', title]
dp.display(data_list1, num_epoch, labels, title)
labels = ['Train Loss', 'Train Acc', title, 'Test Acc']