Example #1
def main(args):

    # a negative GPU index selects the CPU
    cuda = args.gpu >= 0
    if cuda:
        torch.cuda.set_device(args.gpu)

    default_path = create_default_path()
    print('\n*** Set default saving/loading path to:', default_path)

    if args.dataset in (AIFB, MUTAG):
        module = importlib.import_module(MODULE.format('dglrgcn'))
        data = module.load_dglrgcn(args.data_path)
        data = to_cuda(data) if cuda else data
        mode = NODE_CLASSIFICATION
    elif args.dataset in (MUTAGENICITY, PTC_MR, PTC_MM, PTC_FR, PTC_FM):
        module = importlib.import_module(MODULE.format('dortmund'))
        data = module.load_dortmund(args.data_path)
        data = to_cuda(data) if cuda else data
        mode = GRAPH_CLASSIFICATION
    else:
        raise ValueError('Unable to load dataset: {}'.format(args.dataset))

    print_graph_stats(data[GRAPH])

    config_params = read_params(args.config_fpath, verbose=True)

    # create GNN model
    model = Model(g=data[GRAPH],
                  config_params=config_params[0],
                  n_classes=data[N_CLASSES],
                  n_rels=data[N_RELS] if N_RELS in data else None,
                  n_entities=data[N_ENTITIES] if N_ENTITIES in data else None,
                  is_cuda=cuda,
                  mode=mode)

    if cuda:
        model.cuda()

    # 1. Training
    app = App()
    learning_config = {
        'lr': args.lr,
        'n_epochs': args.n_epochs,
        'weight_decay': args.weight_decay,
        'batch_size': args.batch_size,
        'cuda': cuda
    }
    print('\n*** Start training ***\n')
    app.train(data, config_params[0], learning_config, default_path, mode=mode)

    # 2. Testing
    print('\n*** Start testing ***\n')
    app.test(data, default_path, mode=mode)

    # 3. Delete model
    remove_model(default_path)
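
Example #1 reads its hyperparameters from an `args` namespace. The harness below is a hypothetical sketch reconstructed only from the attributes `main()` accesses (`gpu`, `dataset`, `data_path`, `config_fpath`, `lr`, `n_epochs`, `weight_decay`, `batch_size`); the flag names and defaults are assumptions, not the repository's actual CLI.

# Hypothetical argparse harness; flag names and defaults are assumed.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train and test a GNN')
    parser.add_argument('--gpu', type=int, default=-1, help='GPU index; -1 runs on CPU')
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--data_path', type=str, required=True)
    parser.add_argument('--config_fpath', type=str, required=True)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--n_epochs', type=int, default=200)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--batch_size', type=int, default=16)
    main(parser.parse_args())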
Example #2
def _move_to_gpu(self, model: Model) -> Model:
    # a device index of -1 keeps the model on the CPU
    if self._cuda_device != -1:
        return model.cuda(self._cuda_device)
    return model
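
Example #2 follows the same convention as Example #1: -1 selects the CPU. A minimal call site, assuming `self._cuda_device` was populated from the same kind of `--gpu` argument:

# Hypothetical call site inside the owning class; _cuda_device is assumed
# to hold a GPU index, or -1 for CPU.
model = self._move_to_gpu(model)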
Example #3
class App:

    def __init__(self, early_stopping=True):
        # always define the attribute; train() below assumes early stopping is enabled
        self.early_stopping = EarlyStopping(patience=100, verbose=True) if early_stopping else None

    def train(self, data, model_config, learning_config, save_path='', mode=NODE_CLASSIFICATION):

        loss_fcn = torch.nn.CrossEntropyLoss()

        labels = data[LABELS]
        # initialize graph
        if mode == NODE_CLASSIFICATION:
            train_mask = data[TRAIN_MASK]
            val_mask = data[VAL_MASK]
            dur = []

            # create GNN model
            self.model = Model(g=data[GRAPH],
                               config_params=model_config,
                               n_classes=data[N_CLASSES],
                               n_rels=data[N_RELS] if N_RELS in data else None,
                               n_entities=data[N_ENTITIES] if N_ENTITIES in data else None,
                               is_cuda=learning_config['cuda'],
                               mode=mode)

            optimizer = torch.optim.Adam(self.model.parameters(),
                                         lr=learning_config['lr'],
                                         weight_decay=learning_config['weight_decay'])

            for epoch in range(learning_config['n_epochs']):
                self.model.train()
                if epoch >= 3:
                    t0 = time.time()
                # forward pass (the graph is held by the model, so no input batch is passed)
                logits = self.model(None)
                loss = loss_fcn(logits[train_mask], labels[train_mask])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if epoch >= 3:
                    dur.append(time.time() - t0)

                val_acc, val_loss = self.model.eval_node_classification(labels, val_mask)
                print("Epoch {:05d} | Time(s) {:.4f} | Train loss {:.4f} | Val accuracy {:.4f} | "
                      "Val loss {:.4f}".format(epoch,
                                               np.mean(dur) if dur else 0,
                                               loss.item(),
                                               val_acc,
                                               val_loss))

                self.early_stopping(val_loss, self.model, save_path)

                if self.early_stopping.early_stop:
                    print("Early stopping")
                    break

        elif mode == GRAPH_CLASSIFICATION:
            K = 10                               # number of folds for cross-validation
            self.accuracies = np.zeros(K)
            graphs = data[GRAPH]                 # load all the graphs

            # reshuffle the data before splitting it into folds
            random_indices = list(range(len(graphs)))
            random.shuffle(random_indices)
            graphs = [graphs[i] for i in random_indices]
            labels = labels[random_indices]

            for k in range(K):                  # K-fold cross-validation

                # create GNN model
                self.model = Model(g=data[GRAPH],
                                   config_params=model_config,
                                   n_classes=data[N_CLASSES],
                                   n_rels=data[N_RELS] if N_RELS in data else None,
                                   n_entities=data[N_ENTITIES] if N_ENTITIES in data else None,
                                   is_cuda=learning_config['cuda'],
                                   mode=mode)

                optimizer = torch.optim.Adam(self.model.parameters(),
                                             lr=learning_config['lr'],
                                             weight_decay=learning_config['weight_decay'])

                if learning_config['cuda']:
                    self.model.cuda()

                print('\n\nProcessing fold {}'.format(k))
                start = int(len(graphs) / K) * k
                end = int(len(graphs) / K) * (k + 1)

                # testing batch
                testing_graphs = graphs[start:end]
                self.testing_labels = labels[start:end]
                self.testing_batch = dgl.batch(testing_graphs)

                # training batch
                # keep labels aligned with training_graphs: everything outside the test fold
                training_graphs = graphs[:start] + graphs[end:]
                training_labels = labels[list(range(0, start)) + list(range(end, len(graphs)))]
                training_samples = list(map(list, zip(training_graphs, training_labels)))
                training_batches = DataLoader(training_samples,
                                              batch_size=learning_config['batch_size'],
                                              shuffle=True,
                                              collate_fn=collate)

                dur = []
                for epoch in range(learning_config['n_epochs']):
                    self.model.train()
                    if epoch >= 3:
                        t0 = time.time()
                    losses = []
                    training_accuracies = []
                    for bg, label in training_batches:
                        logits = self.model(bg)
                        loss = loss_fcn(logits, label)
                        losses.append(loss.item())
                        _, indices = torch.max(logits, dim=1)
                        correct = torch.sum(indices == label)
                        training_accuracies.append(correct.item() * 1.0 / len(label))

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    if epoch >= 3:
                        dur.append(time.time() - t0)
                    val_acc, val_loss = self.model.eval_graph_classification(self.testing_labels, self.testing_batch)
                    print("Epoch {:05d} | Time(s) {:.4f} | Train acc {:.4f} | Train loss {:.4f} "
                          "| Val accuracy {:.4f} | Val loss {:.4f}".format(epoch,
                                                                           np.mean(dur) if dur else 0,
                                                                           np.mean(training_accuracies),
                                                                           np.mean(losses),
                                                                           val_acc,
                                                                           val_loss))

                    is_better = self.early_stopping(val_loss, self.model, save_path)
                    if is_better:
                        self.accuracies[k] = val_acc

                    if self.early_stopping.early_stop:
                        print("Early stopping")
                        break
                self.early_stopping.reset()
        else:
            raise RuntimeError('Unknown mode: {}'.format(mode))

    def test(self, data, load_path='', mode=NODE_CLASSIFICATION):

        try:
            print('*** Load pre-trained model ***')
            self.model = load_checkpoint(self.model, load_path)
        except ValueError as e:
            print('Error while loading the model.', e)

        if mode == NODE_CLASSIFICATION:
            test_mask = data[TEST_MASK]
            labels = data[LABELS]
            acc, _ = self.model.eval_node_classification(labels, test_mask)
        else:
            acc = np.mean(self.accuracies)

        print("\nTest Accuracy {:.4f}".format(acc))

        return acc
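
Two helpers referenced in Example #3 are not shown. The `collate` function passed to the `DataLoader` is, in the common DGL pattern, a function that merges a list of `[graph, label]` pairs into one batched graph plus a label tensor; the sketch below follows that pattern and is an assumption, not the repository's exact code.

# Plausible sketch of the collate helper used by the DataLoader above.
import dgl
import torch

def collate(samples):
    # each sample is a [graph, label] pair, as built in train()
    graphs, labels = map(list, zip(*samples))
    return dgl.batch(graphs), torch.tensor(labels)

Likewise, `EarlyStopping` is only used through its call operator, its `early_stop` flag, and `reset()`. A minimal sketch consistent with those call sites, again an assumption rather than the repository's actual implementation:

# Minimal EarlyStopping sketch inferred from the call sites in App;
# the real class may save checkpoints differently.
class EarlyStopping:
    def __init__(self, patience=100, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.reset()

    def reset(self):
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model, save_path):
        # returns True when validation loss improved and a checkpoint was saved
        if self.best_loss is None or val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), save_path)
            if self.verbose:
                print('Validation loss improved; checkpoint saved.')
            return True
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False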