Ejemplo n.º 1
0
def load_dataset_for_classification(args):
    """Load a dataset for classification tasks and split it.

    Parameters
    ----------
    args : dict
        Configurations. The keys read here are ``'dataset'`` (only
        ``'Tox21'`` is supported), ``'atom_featurizer'`` (forwarded to the
        dataset constructor) and ``'train_val_test_split'`` (forwarded to
        ``split_dataset``).

    Returns
    -------
    dataset
        The whole dataset.
    train_set
        Subset for training.
    val_set
        Subset for validation.
    test_set
        Subset for test.

    Raises
    ------
    ValueError
        If ``args['dataset']`` is not a supported dataset name.
    """
    supported = ('Tox21',)
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would let an unsupported name fall through to an
    # UnboundLocalError at the return statement.
    if args['dataset'] not in supported:
        raise ValueError(
            'Unsupported dataset {!r}; expected one of {}'.format(
                args['dataset'], list(supported)))

    # Imported here, as in the original code, so the dependency is only
    # needed when this dataset is actually requested.
    from dgl.data.chem import Tox21
    dataset = Tox21(atom_featurizer=args['atom_featurizer'])
    train_set, val_set, test_set = split_dataset(dataset, args['train_val_test_split'])

    return dataset, train_set, val_set, test_set
Ejemplo n.º 2
0
def load_dataset_for_classification(args):
    """Load a dataset for classification tasks and split it randomly.

    Parameters
    ----------
    args : dict
        Configurations. The keys read here are ``'dataset'`` (only
        ``'Tox21'`` is supported), ``'atom_featurizer'`` (forwarded to the
        dataset constructor), and ``'frac_train'``, ``'frac_val'``,
        ``'frac_test'`` and ``'random_seed'`` (forwarded to
        ``RandomSplitter.train_val_test_split``).

    Returns
    -------
    dataset
        The whole dataset.
    train_set
        Subset for training.
    val_set
        Subset for validation.
    test_set
        Subset for test.

    Raises
    ------
    ValueError
        If ``args['dataset']`` is not a supported dataset name.
    """
    supported = ('Tox21',)
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would let an unsupported name fall through to an
    # UnboundLocalError at the return statement.
    if args['dataset'] not in supported:
        raise ValueError(
            'Unsupported dataset {!r}; expected one of {}'.format(
                args['dataset'], list(supported)))

    # Imported here, as in the original code, so the dependency is only
    # needed when this dataset is actually requested.
    from dgl.data.chem import Tox21
    dataset = Tox21(smiles_to_bigraph, args['atom_featurizer'])
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset,
        frac_train=args['frac_train'],
        frac_val=args['frac_val'],
        frac_test=args['frac_test'],
        random_state=args['random_seed'])

    return dataset, train_set, val_set, test_set
Ejemplo n.º 3
0
def main(args):
    """Train (or load) a molecule classifier and print its test score.

    The dataset is split with ``split_dataset`` and wrapped in three
    ``DataLoader`` objects. When ``args['pre_trained']`` is truthy, a model is
    obtained from ``model_zoo.chem.load_pretrained(args['exp'])`` and training
    is skipped by forcing ``args['num_epochs']`` to 0. Otherwise a GCN or GAT
    classifier is built and trained with Adam, using the stopper for early
    stopping on the validation score; the stopper's checkpoint is loaded back
    into the model before the test split is evaluated.

    Parameters
    ----------
    args : dict
        Configurations. Read here: ``'dataset'``, ``'train_val_test_split'``,
        ``'batch_size'``, ``'pre_trained'``, ``'num_epochs'`` and, depending
        on the branch taken, ``'exp'``, ``'model'``, ``'in_feats'``,
        ``'gcn_hidden_feats'``, ``'gat_hidden_feats'``, ``'num_heads'``,
        ``'classifier_hidden_feats'``, ``'lr'`` and ``'patience'``.
        The dict is modified in place: ``'device'`` is always set and
        ``'num_epochs'`` is set to 0 when a pre-trained model is used.

    Raises
    ------
    ValueError
        If ``args['dataset']`` is not ``'Tox21'``, or if a model has to be
        built and ``args['model']`` is neither ``'GCN'`` nor ``'GAT'``.
    """
    # Fail fast with a clear message. Without these checks an unsupported
    # name leaves ``dataset`` / ``model`` unbound and the run only fails later
    # with an unrelated-looking NameError/UnboundLocalError.
    if args['dataset'] != 'Tox21':
        raise ValueError(
            "Unsupported dataset {!r}; expected 'Tox21'".format(args['dataset']))
    if not args['pre_trained'] and args['model'] not in ('GCN', 'GAT'):
        raise ValueError(
            "Unsupported model {!r}; expected 'GCN' or 'GAT'".format(args['model']))

    args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
    set_random_seed()

    # Interchangeable with other datasets
    from dgl.data.chem import Tox21
    dataset = Tox21()

    trainset, valset, testset = split_dataset(dataset, args['train_val_test_split'])
    train_loader = DataLoader(trainset, batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs_for_classification)
    val_loader = DataLoader(valset, batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs_for_classification)
    test_loader = DataLoader(testset, batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs_for_classification)

    if args['pre_trained']:
        # No training: the epoch loop below runs zero times, so the loss,
        # optimizer and stopper are intentionally not created on this path.
        args['num_epochs'] = 0
        model = model_zoo.chem.load_pretrained(args['exp'])
    else:
        # Interchangeable with other models
        if args['model'] == 'GCN':
            model = model_zoo.chem.GCNClassifier(in_feats=args['in_feats'],
                                                 gcn_hidden_feats=args['gcn_hidden_feats'],
                                                 classifier_hidden_feats=args['classifier_hidden_feats'],
                                                 n_tasks=dataset.n_tasks)
        else:  # 'GAT', guaranteed by the validation above
            model = model_zoo.chem.GATClassifier(in_feats=args['in_feats'],
                                                 gat_hidden_feats=args['gat_hidden_feats'],
                                                 num_heads=args['num_heads'],
                                                 classifier_hidden_feats=args['classifier_hidden_feats'],
                                                 n_tasks=dataset.n_tasks)

        # reduction='none' keeps the loss unreduced; any masking/averaging is
        # left to run_a_train_epoch.
        loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights.to(args['device']),
                                           reduction='none')
        optimizer = Adam(model.parameters(), lr=args['lr'])
        stopper = EarlyStopping(patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)

        # Validation and early stop
        val_roc_auc = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_roc_auc, model)
        print('epoch {:d}/{:d}, validation roc-auc score {:.4f}, best validation roc-auc score {:.4f}'.format(
            epoch + 1, args['num_epochs'], val_roc_auc, stopper.best_score))
        if early_stop:
            break

    if not args['pre_trained']:
        # Evaluate the checkpoint kept by the stopper rather than whatever
        # weights the final epoch left in the model.
        stopper.load_checkpoint(model)
    test_roc_auc = run_an_eval_epoch(args, model, test_loader)
    print('test roc-auc score {:.4f}'.format(test_roc_auc))
Ejemplo n.º 4
0
def main(args):
    """Train a GCN classifier on Tox21 with early stopping, then test it.

    The configuration is first passed through ``setup``. The dataset is split
    into train/validation/test subsets and one shuffled ``DataLoader`` is
    built for each. The model is optimised with Adam for at most
    ``args['n_epochs']`` epochs; after every epoch the accumulated list of
    ``[train_score, val_score]`` pairs is written to ``./history.pt``. Once
    training ends, the stopper's checkpoint is loaded into the model, the
    test split is evaluated, ``plot_save`` is called on the history and the
    best validation and test scores are printed.

    Parameters
    ----------
    args : dict
        Configurations. After ``setup`` the keys read here are
        ``'batch_size'``, ``'n_input'``, ``'n_hidden'``, ``'n_layers'``,
        ``'device'``, ``'lr'``, ``'patience'`` and ``'n_epochs'``.
    """
    args = setup(args)

    dataset = Tox21()
    train_set, val_set, test_set = split_dataset(dataset, shuffle=True)

    def make_loader(subset):
        # All three splits use the same loader settings.
        return DataLoader(subset,
                          batch_size=args['batch_size'],
                          shuffle=True,
                          collate_fn=collate_molgraphs)

    train_loader = make_loader(train_set)
    val_loader = make_loader(val_set)
    test_loader = make_loader(test_set)

    # One GCN layer per entry, all of the same width.
    hidden_sizes = [args['n_hidden'] for _ in range(args['n_layers'])]
    classifier = model_zoo.chem.GCNClassifier(
        in_feats=args['n_input'],
        gcn_hidden_feats=hidden_sizes,
        n_tasks=dataset.n_tasks,
        classifier_hidden_feats=args['n_hidden'])
    model = classifier.to(args['device'])

    pos_weight = torch.tensor(dataset.task_pos_weights).to(args['device'])
    loss_criterion = BCEWithLogitsLoss(pos_weight=pos_weight, reduction='none')
    optimizer = Adam(model.parameters(), lr=args['lr'])
    stopper = EarlyStopper(args['patience'])

    progress_template = 'epoch {:d}/{:d}, validation roc-auc {:.4f}, best validation roc-auc {:.4f}'
    history = []
    for epoch in range(args['n_epochs']):
        # Train
        train_score = run_a_train_epoch(
            args, epoch, model, train_loader, loss_criterion, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        history.append([train_score, val_score])
        should_stop = stopper.step(val_score, model)
        print(progress_template.format(
            epoch + 1, args['n_epochs'], val_score, stopper.best_score))
        # Persist the history every epoch so an interrupted run keeps it.
        torch.save(history, "./history.pt")
        if should_stop:
            break

    stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    plot_save(history)
    print('Best validation score {:.4f}'.format(stopper.best_score))
    print('Test score {:.4f}'.format(test_score))
Ejemplo n.º 5
0
def main(args):
    """Train a GCN classifier on Tox21 with early stopping, then test it.

    The configuration is first passed through ``setup``. The dataset is split
    into train/validation/test subsets and one shuffled ``DataLoader`` is
    built for each. The model is optimised with Adam for at most
    ``args["n_epochs"]`` epochs, stopping early when the stopper says so.
    Afterwards the stopper's checkpoint is loaded into the model, the test
    split is evaluated and the best validation and test scores are printed.

    Parameters
    ----------
    args : dict
        Configurations. After ``setup`` the keys read here are
        ``"batch_size"``, ``"n_input"``, ``"n_hidden"``, ``"n_layers"``,
        ``"device"``, ``"lr"``, ``"patience"`` and ``"n_epochs"``.
    """
    args = setup(args)

    dataset = Tox21()
    train_set, val_set, test_set = split_dataset(dataset, shuffle=True)

    # Settings shared by the three data loaders.
    loader_options = {
        "batch_size": args["batch_size"],
        "shuffle": True,
        "collate_fn": collate_molgraphs,
    }
    train_loader = DataLoader(train_set, **loader_options)
    val_loader = DataLoader(val_set, **loader_options)
    test_loader = DataLoader(test_set, **loader_options)

    # One GCN layer per entry, all of the same width.
    layer_widths = [args["n_hidden"] for _ in range(args["n_layers"])]
    network = model_zoo.chem.GCNClassifier(
        in_feats=args["n_input"],
        gcn_hidden_feats=layer_widths,
        n_tasks=dataset.n_tasks,
        classifier_hidden_feats=args["n_hidden"],
    )
    model = network.to(args["device"])

    task_weights = torch.tensor(dataset.task_pos_weights).to(args["device"])
    loss_criterion = BCEWithLogitsLoss(pos_weight=task_weights,
                                       reduction="none")
    optimizer = Adam(model.parameters(), lr=args["lr"])
    stopper = EarlyStopper(args["patience"])

    report = "epoch {:d}/{:d}, validation roc-auc {:.4f}, best validation roc-auc {:.4f}"
    for epoch in range(args["n_epochs"]):
        # Train
        run_a_train_epoch(
            args, epoch, model, train_loader, loss_criterion, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        stop_now = stopper.step(val_score, model)
        print(report.format(
            epoch + 1, args["n_epochs"], val_score, stopper.best_score))
        if stop_now:
            break

    stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print("Best validation score {:.4f}".format(stopper.best_score))
    print("Test score {:.4f}".format(test_score))