def main(args):
    """Entry point: load a dataset, train a GNN on it, test it, clean up.

    :param args: parsed CLI namespace; uses ``gpu``, ``dataset``, ``data_path``,
        ``config_fpath``, ``lr``, ``n_epochs``, ``weight_decay``, ``batch_size``.
    :raises ValueError: when ``args.dataset`` is not a known dataset name.
    """
    # A negative GPU index means "run on CPU".
    cuda = args.gpu >= 0
    if cuda:
        torch.cuda.set_device(args.gpu)

    default_path = create_default_path()
    print('\n*** Set default saving/loading path to:', default_path)

    # Dispatch on dataset family: RGCN-style node classification vs.
    # Dortmund-style graph classification.
    if args.dataset in (AIFB, MUTAG):
        module = importlib.import_module(MODULE.format('dglrgcn'))
        data = module.load_dglrgcn(args.data_path)
        mode = NODE_CLASSIFICATION
    elif args.dataset in (MUTAGENICITY, PTC_MR, PTC_MM, PTC_FR, PTC_FM):
        module = importlib.import_module(MODULE.format('dortmund'))
        data = module.load_dortmund(args.data_path)
        mode = GRAPH_CLASSIFICATION
    else:
        raise ValueError('Unable to load dataset', args.dataset)

    if cuda:
        data = to_cuda(data)

    print_graph_stats(data[GRAPH])

    config_params = read_params(args.config_fpath, verbose=True)

    # NOTE(review): this model is never used below — App.train constructs its
    # own. Kept in case Model.__init__ has side effects on data[GRAPH];
    # confirm and remove.
    model = Model(g=data[GRAPH],
                  config_params=config_params[0],
                  n_classes=data[N_CLASSES],
                  n_rels=data[N_RELS] if N_RELS in data else None,
                  n_entities=data[N_ENTITIES] if N_ENTITIES in data else None,
                  is_cuda=cuda,
                  mode=mode)
    if cuda:
        model.cuda()

    # 1. Training
    app = App()
    learning_config = {
        'lr': args.lr,
        'n_epochs': args.n_epochs,
        'weight_decay': args.weight_decay,
        'batch_size': args.batch_size,
        'cuda': cuda
    }
    print('\n*** Start training ***\n')
    app.train(data, config_params[0], learning_config, default_path, mode=mode)

    # 2. Testing
    print('\n*** Start testing ***\n')
    app.test(data, default_path, mode=mode)

    # 3. Delete model
    remove_model(default_path)
def _move_to_gpu(self, model: Model) -> Model: if self._cuda_device != -1: return model.cuda(self._cuda_device) else: return model
class App:
    """Train and evaluate a GNN for node- or graph-classification.

    ``train`` constructs the model itself; ``test`` reloads the best
    checkpoint (node mode) or averages the per-fold accuracies collected
    during 10-fold cross-validation (graph mode).
    """

    def __init__(self, early_stopping=True):
        """:param early_stopping: enable validation-loss early stopping.

        ``self.early_stopping`` is None when disabled; every use site below
        guards on that (previously the attribute was simply never set,
        causing an AttributeError in ``train``).
        """
        self.early_stopping = EarlyStopping(patience=100, verbose=True) if early_stopping else None

    def train(self, data, model_config, learning_config, save_path='', mode=NODE_CLASSIFICATION):
        """Train a model on *data*.

        :param data: dict holding GRAPH, LABELS and, per mode, masks or
            class/relation/entity counts.
        :param model_config: model section of the parsed config file.
        :param learning_config: dict with 'lr', 'n_epochs', 'weight_decay',
            'batch_size' and 'cuda' keys.
        :param save_path: checkpoint path handed to the early-stopping helper.
        :param mode: NODE_CLASSIFICATION or GRAPH_CLASSIFICATION.
        :raises RuntimeError: for an unknown *mode*.
        """
        loss_fcn = torch.nn.CrossEntropyLoss()
        labels = data[LABELS]

        if mode == NODE_CLASSIFICATION:
            train_mask = data[TRAIN_MASK]
            val_mask = data[VAL_MASK]
            dur = []

            # create GNN model
            self.model = Model(g=data[GRAPH],
                               config_params=model_config,
                               n_classes=data[N_CLASSES],
                               n_rels=data[N_RELS] if N_RELS in data else None,
                               n_entities=data[N_ENTITIES] if N_ENTITIES in data else None,
                               is_cuda=learning_config['cuda'],
                               mode=mode)

            optimizer = torch.optim.Adam(self.model.parameters(),
                                         lr=learning_config['lr'],
                                         weight_decay=learning_config['weight_decay'])

            for epoch in range(learning_config['n_epochs']):
                self.model.train()
                # Skip the first epochs when timing (warm-up).
                if epoch >= 3:
                    t0 = time.time()

                # forward
                logits = self.model(None)
                loss = loss_fcn(logits[train_mask], labels[train_mask])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if epoch >= 3:
                    dur.append(time.time() - t0)

                val_acc, val_loss = self.model.eval_node_classification(labels, val_mask)
                # Guard np.mean against an empty list (first 3 epochs
                # previously printed NaN).
                print("Epoch {:05d} | Time(s) {:.4f} | Train loss {:.4f} | Val accuracy {:.4f} | "
                      "Val loss {:.4f}".format(epoch,
                                               np.mean(dur) if dur else 0,
                                               loss.item(),
                                               val_acc,
                                               val_loss))

                if self.early_stopping is not None:
                    self.early_stopping(val_loss, self.model, save_path)
                    if self.early_stopping.early_stop:
                        print("Early stopping")
                        break

        elif mode == GRAPH_CLASSIFICATION:
            K = 10  # number of cross-validation folds
            self.accuracies = np.zeros(K)
            graphs = data[GRAPH]  # load all the graphs

            # debug purposes: reshuffle all the data before the splitting
            random_indices = list(range(len(graphs)))
            random.shuffle(random_indices)
            graphs = [graphs[i] for i in random_indices]
            labels = labels[random_indices]

            fold_size = len(graphs) // K
            for k in range(K):  # K-fold cross validation
                # Fresh model and optimizer for every fold.
                self.model = Model(g=data[GRAPH],
                                   config_params=model_config,
                                   n_classes=data[N_CLASSES],
                                   n_rels=data[N_RELS] if N_RELS in data else None,
                                   n_entities=data[N_ENTITIES] if N_ENTITIES in data else None,
                                   is_cuda=learning_config['cuda'],
                                   mode=mode)
                optimizer = torch.optim.Adam(self.model.parameters(),
                                             lr=learning_config['lr'],
                                             weight_decay=learning_config['weight_decay'])
                if learning_config['cuda']:
                    self.model.cuda()

                print('\n\n\nProcess new k')
                start = fold_size * k
                end = fold_size * (k + 1)

                # testing batch: the held-out fold
                testing_graphs = graphs[start:end]
                self.testing_labels = labels[start:end]
                self.testing_batch = dgl.batch(testing_graphs)

                # training batch — BUGFIX: the label indices previously began
                # at end+1, which dropped one sample and misaligned every
                # label after the held-out fold; graphs[end:] must pair with
                # labels starting at index `end`.
                training_graphs = graphs[:start] + graphs[end:]
                training_labels = labels[list(range(0, start)) + list(range(end, len(graphs)))]
                training_samples = list(map(list, zip(training_graphs, training_labels)))
                training_batches = DataLoader(training_samples,
                                              batch_size=learning_config['batch_size'],
                                              shuffle=True,
                                              collate_fn=collate)

                dur = []
                for epoch in range(learning_config['n_epochs']):
                    self.model.train()
                    if epoch >= 3:
                        t0 = time.time()
                    losses = []
                    training_accuracies = []
                    for bg, label in training_batches:
                        logits = self.model(bg)
                        loss = loss_fcn(logits, label)
                        losses.append(loss.item())
                        _, indices = torch.max(logits, dim=1)
                        correct = torch.sum(indices == label)
                        training_accuracies.append(correct.item() * 1.0 / len(label))

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    if epoch >= 3:
                        dur.append(time.time() - t0)

                    val_acc, val_loss = self.model.eval_graph_classification(self.testing_labels,
                                                                             self.testing_batch)
                    print("Epoch {:05d} | Time(s) {:.4f} | Train acc {:.4f} | Train loss {:.4f} "
                          "| Val accuracy {:.4f} | Val loss {:.4f}".format(epoch,
                                                                           np.mean(dur) if dur else 0,
                                                                           np.mean(training_accuracies),
                                                                           np.mean(losses),
                                                                           val_acc,
                                                                           val_loss))

                    if self.early_stopping is not None:
                        is_better = self.early_stopping(val_loss, self.model, save_path)
                        if is_better:
                            self.accuracies[k] = val_acc
                        if self.early_stopping.early_stop:
                            print("Early stopping")
                            break
                    else:
                        # No early stopping: record the latest fold accuracy.
                        self.accuracies[k] = val_acc

                if self.early_stopping is not None:
                    self.early_stopping.reset()
        else:
            raise RuntimeError('Unknown mode: {}'.format(mode))

    def test(self, data, load_path='', mode=NODE_CLASSIFICATION):
        """Evaluate the trained model.

        Must be called after ``train``: node mode reloads the checkpoint and
        evaluates on the test mask; graph mode averages the per-fold
        accuracies recorded during cross-validation.

        :returns: the test accuracy as a float.
        """
        try:
            print('*** Load pre-trained model ***')
            self.model = load_checkpoint(self.model, load_path)
        except ValueError as e:
            # Best-effort: fall back to the in-memory model from training.
            print('Error while loading the model.', e)

        if mode == NODE_CLASSIFICATION:
            test_mask = data[TEST_MASK]
            labels = data[LABELS]
            acc, _ = self.model.eval_node_classification(labels, test_mask)
        else:
            acc = np.mean(self.accuracies)

        print("\nTest Accuracy {:.4f}".format(acc))
        return acc