def _build_model(arg, input_size, shapes, device):
    """Instantiate the compression model named by ``arg.m`` and move it to ``device``.

    Imports stay local to each branch so only the selected model's module is
    imported.

    Args:
        arg: Parsed argument namespace; reads ``m``, ``depth``, ``c_rate`` and,
            for ``"MIAGAE"``, ``k``.
        input_size: Number of input node features of the dataset.
        shapes: List of ints parsed from ``arg.shapes``.
        device: ``torch.device`` the model is moved to.

    Returns:
        The model instance, or ``None`` if ``arg.m`` is not a known model name.
    """
    if arg.m == "MIAGAE":
        from classification.Graph_AE import Net
        return Net(input_size, arg.k, arg.depth, [arg.c_rate] * arg.depth, shapes, device).to(device)
    if arg.m == "UNet":
        from classification.UNet import Net
        return Net(input_size, arg.depth, arg.c_rate, shapes, device).to(device)
    if arg.m == "Gpool":
        from classification.Gpool_model import Net
        return Net(input_size, arg.depth, arg.c_rate, shapes, device).to(device)
    if arg.m == "SAGpool":
        from classification.SAG_model import Net
        return Net(input_size, arg.depth, [arg.c_rate] * arg.depth, shapes, device).to(device)
    return None


def main(arg):
    """Restore a pretrained graph model and train an MLP classifier on its outputs.

    Loads the ``SelectGraph`` dataset named by ``arg.d`` and splits it into
    contiguous, non-shuffled train/test ranges. It then restores the model
    selected by ``arg.m`` from ``arg.model_dir + arg.m + ".ckpt"`` and collects
    its results with ``load_model_result``. Finally it trains an ``MLP``
    classifier with ``train_cf``.

    Args:
        arg: Parsed argument namespace. Attributes read: ``device``, ``e``
            (epochs), ``batch``, ``d`` (dataset name), ``shapes``
            (comma-separated ints), ``n_skip``, ``n_train``, ``n_test``, ``m``
            (model name), ``k``, ``depth``, ``c_rate``, ``model_dir``,
            ``hidden``, ``dropout`` and ``lr``.

    Returns:
        None. Prints ``"model not found"`` and returns early if ``arg.m`` is
        not one of the supported model names.
    """
    device = torch.device(arg.device)
    num_epoch = arg.e
    batch_size = arg.batch

    # Fixed: previously read the global ``args.d`` instead of the ``arg``
    # parameter, which raised NameError when main() was called directly.
    SelectGraph.data_name = arg.d
    data_set = SelectGraph('data/' + SelectGraph.data_name)
    input_size = data_set.num_features
    num_classes = data_set.num_classes
    shapes = list(map(int, arg.shapes.split(",")))

    # Training split: [n_skip, n_skip + n_train). The test split is the n_test
    # graphs that follow it.
    train_end = arg.n_skip + arg.n_train
    test_end = train_end + arg.n_test
    train_set = DataLoader(data_set[arg.n_skip:train_end], batch_size=batch_size, shuffle=False)
    test_set = DataLoader(data_set[train_end:test_end], batch_size=batch_size, shuffle=False)

    model = _build_model(arg, input_size, shapes, device)
    if model is None:
        print("model not found")
        return

    # map_location lets a checkpoint saved on another device (e.g. a different
    # GPU index) be restored onto ``device``.
    state_dict = torch.load(arg.model_dir + arg.m + ".ckpt", map_location=device)
    model.load_state_dict(state_dict, strict=True)

    group1, group2 = load_model_result(model, train_set, test_set, device)
    # NOTE(review): assumes group1[0] is 2-D with features on dim 1 - confirm
    # against load_model_result.
    input_size2 = group1[0].shape[1]
    c_model = MLP(input_size2, arg.hidden, num_classes, arg.dropout).to(device)
    optimizer = torch.optim.Adam(c_model.parameters(), lr=arg.lr)
    train_cf(c_model, optimizer, device, train_set, test_set, num_epoch, group1, group2)
from torch_geometric.data import DataLoader
from utils.CustomDataSet import SelectGraph, SceneGraphs
from classification.Gpool_model import Net
from classification.Classifier import MLP
import utils.Display_Plot as dp
import torch

# Transfer script:
# 1. Train the Gpool compression model (comp_model) on graphs [0, 5000) of
#    Shana7000, passing graphs [5000, 6000) and [6000, 7000) alongside.
# 2. Train the MLP (cfy_model) on those two ranges plus the group1/group2
#    values returned by step 1.
# NOTE(review): group1/group2 are presumably comp_model outputs for the two
# later ranges - confirm in Gpool_model.train_model.
device = torch.device('cuda:1')
num_epoch = 100
batch_size = 200  # NOTE(review): unused in the code shown; loaders pass explicit batch sizes.

comp_model = Net.get_instance().to(device)
cfy_model = MLP.get_instance().to(device)

# Load the dataset once. The original constructed the identical
# SelectGraph('data/Shana7000') twice, which was redundant I/O.
SelectGraph.data_name = 'Shana7000'
data_set_Shana = SelectGraph('data/' + SelectGraph.data_name)
train_set = DataLoader(data_set_Shana[:5000], 500, shuffle=True)
train_set2 = DataLoader(data_set_Shana[5000:6000], 1000, shuffle=False)
test_set = DataLoader(data_set_Shana[6000:7000], 1000, shuffle=False)

m_name = "Gpool_TRANSFER.ckpt"
data_list1, group1, group2 = comp_model.train_model(train_set, train_set2, test_set, num_epoch, m_name)
# The classifier runs for half as many epochs; `//` already yields an int.
data_list2 = cfy_model.train_model(train_set2, test_set, num_epoch // 2, group1, group2)

# NOTE(review): this script trains the Gpool model, but the title says "SAG".
# It looks like a copy-paste leftover; confirm before relying on the plots.
title = "SAG TRANSFER"
labels = ['MSE Loss', 'Num Nodes', 'Total Loss', title]
dp.display(data_list1, num_epoch, labels, title)
labels = ['Train Loss', 'Train Acc', title, 'Test Acc']