Example #1
def split_dataset(args, dataset):
    """Split the dataset

    Parameters
    ----------
    args : dict
        Settings
    dataset
        Dataset instance

    Returns
    -------
    train_set
        Training subset
    val_set
        Validation subset
    test_set
        Test subset
    """
    train_ratio, val_ratio, test_ratio = map(float, args['split_ratio'].split(','))
    if args['split'] == 'scaffold':
        train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
            dataset, frac_train=train_ratio, frac_val=val_ratio, frac_test=test_ratio,
            scaffold_func='smiles')
    elif args['split'] == 'random':
        train_set, val_set, test_set = RandomSplitter.train_val_test_split(
            dataset, frac_train=train_ratio, frac_val=val_ratio, frac_test=test_ratio)
    else:
        raise ValueError("Expect the splitting method to be 'scaffold' or 'random', "
                         "got {}".format(args['split']))

    return train_set, val_set, test_set
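A minimal usage sketch, assuming the dgllife splitters are imported and a molecule dataset instance is already at hand (the split settings below are illustrative, not from the original):

from dgllife.utils import RandomSplitter, ScaffoldSplitter

args = {'split': 'scaffold',           # or 'random'
        'split_ratio': '0.8,0.1,0.1'}  # train/val/test fractions as a comma-separated string
# dataset: e.g. a dgllife MoleculeCSVDataset instance
train_set, val_set, test_set = split_dataset(args, dataset)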
Example #2
def main(args):
    args['device'] = torch.device(
        "cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(args['random_seed'])

    dataset = PubChemBioAssayAromaticity(
        smiles_to_graph=args['smiles_to_graph'],
        node_featurizer=args.get('node_featurizer', None),
        edge_featurizer=args.get('edge_featurizer', None))
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset,
        frac_train=args['frac_train'],
        frac_val=args['frac_val'],
        frac_test=args['frac_test'],
        random_state=args['random_seed'])
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(dataset=test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        # Evaluate a pre-trained model only; skip training
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        model = load_model(args)
        loss_fn = nn.MSELoss(reduction='none')
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args['lr'],
                                     weight_decay=args['weight_decay'])
        stopper = EarlyStopping(mode=args['mode'], patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print(
            'epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.
            format(epoch + 1, args['num_epochs'], args['metric_name'],
                   val_score, args['metric_name'], stopper.best_score))

        if early_stop:
            break

    if not args['pre_trained']:
        stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('test {} {:.4f}'.format(args['metric_name'], test_score))
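This snippet assumes imports along these lines (a best-guess reconstruction; set_random_seed, collate_molgraphs, load_model and the run_* helpers are local to the example script rather than public dgllife API):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dgllife.data import PubChemBioAssayAromaticity
from dgllife.model import load_pretrained
from dgllife.utils import RandomSplitter, EarlyStopping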
Example #3
        'all the columns except for the smiles_column in the CSV file. '
        '(default: None)')
    args = parser.parse_args().__dict__

    args['exp_name'] = '_'.join([args['model'], args['mode']])
    if args['tasks'] is not None:
        args['tasks'] = args['tasks'].split(',')
    args.update(configs[args['exp_name']])

    # Setup for experiments
    mkdir_p(args['result_path'])

    # atom_featurizer is defined earlier in the full script
    node_featurizer = atom_featurizer
    edge_featurizer = CanonicalBondFeaturizer(bond_data_field='he',
                                              self_loop=True)
    df = pd.read_csv(args['csv_path'])
    dataset = MoleculeCSVDataset(
        df,
        partial(smiles_to_bigraph, add_self_loop=True),
        node_featurizer=node_featurizer,
        edge_featurizer=edge_featurizer,
        smiles_column=args['smiles_column'],
        cache_file_path=args['result_path'] + '/graph.bin',
        task_names=args['tasks'])
    args['tasks'] = dataset.task_names
    args = setup(args)
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset, frac_train=0.8, frac_val=0.1, frac_test=0.1, random_state=0)

    main(args, node_featurizer, edge_featurizer, train_set, val_set, test_set)
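For context, here is a self-contained sketch of the dataset construction step above, substituting a tiny in-memory DataFrame for args['csv_path'] and a stock CanonicalAtomFeaturizer for the truncated atom_featurizer (both substitutions are assumptions):

from functools import partial

import pandas as pd
from dgllife.data import MoleculeCSVDataset
from dgllife.utils import (CanonicalAtomFeaturizer, CanonicalBondFeaturizer,
                           smiles_to_bigraph)

df = pd.DataFrame({'smiles': ['CCO', 'c1ccccc1'],  # illustrative molecules
                   'task1': [0.5, 1.2]})           # illustrative regression labels
dataset = MoleculeCSVDataset(
    df,
    partial(smiles_to_bigraph, add_self_loop=True),
    node_featurizer=CanonicalAtomFeaturizer(),
    edge_featurizer=CanonicalBondFeaturizer(bond_data_field='he', self_loop=True),
    smiles_column='smiles',
    cache_file_path='graph.bin',
    task_names=['task1'])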
Example #4
def load_dataset(args):
    """Load the dataset.

    Parameters
    ----------
    args : dict
        Input arguments.

    Returns
    -------
    dataset
        Full dataset.
    train_set
        Training subset of the dataset.
    val_set
        Validation subset of the dataset.
    test_set
        Test subset of the dataset.
    """
    assert args['dataset'] in ['PDBBind'], 'Unexpected dataset {}'.format(args['dataset'])
    if args['dataset'] == 'PDBBind':
        if args['model'] == 'PotentialNet':
            from functools import partial
            from dgllife.utils import potentialNet_graph_construction_featurization
            dataset = PDBBind(
                subset=args['subset'], pdb_version=args['version'],
                load_binding_pocket=args['load_binding_pocket'],
                construct_graph_and_featurize=partial(
                    potentialNet_graph_construction_featurization,
                    distance_bins=args['distance_bins'],
                    max_num_neighbors=args['max_num_neighbors']))
        elif args['model'] == 'ACNN':
            dataset = PDBBind(subset=args['subset'], pdb_version=args['version'],
                              load_binding_pocket=args['load_binding_pocket'])

        if args['split'] == 'random':
            train_set, val_set, test_set = RandomSplitter.train_val_test_split(
                dataset,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'],
                random_state=args['random_seed'])

        elif args['split'] == 'scaffold':
            train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
                dataset,
                mols=dataset.ligand_mols,
                sanitize=False,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'])

        elif args['split'] == 'stratified':
            train_set, val_set, test_set = SingleTaskStratifiedSplitter.train_val_test_split(
                dataset,
                labels=dataset.labels,
                task_id=0,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'],
                random_state=args['random_seed'])

        elif args['split'] == 'temporal':
            years = dataset.df['release_year'].values.astype(np.float32)
            indices = np.argsort(years).tolist()
            frac_list = np.array([args['frac_train'], args['frac_val'], args['frac_test']])
            num_data = len(dataset)
            lengths = (num_data * frac_list).astype(int)
            lengths[-1] = num_data - np.sum(lengths[:-1])
            train_set, val_set, test_set = [
                Subset(dataset, list(indices[offset - length:offset]))
                for offset, length in zip(accumulate(lengths), lengths)]

        else:
            raise ValueError('Expect the splitting method '
                             'to be "random", "scaffold", "stratified" or "temporal", got {}'.format(args['split']))
        if args['frac_train'] > 0:
            # Record training-label statistics for standardizing regression targets
            train_labels = torch.stack([train_set.dataset.labels[i] for i in train_set.indices])
            train_set.labels_mean = train_labels.mean(dim=0)
            train_set.labels_std = train_labels.std(dim=0)

    return dataset, train_set, val_set, test_set
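The labels_mean/labels_std attributes attached above are presumably used to standardize regression targets during training; a hedged sketch of that step (the helper names are hypothetical, not from the original script):

def standardize_labels(labels, train_set):
    # Shift and scale labels with the training-set statistics computed above
    return (labels - train_set.labels_mean) / train_set.labels_std

def unstandardize_predictions(preds, train_set):
    # Invert the transformation to report predictions on the original scale
    return preds * train_set.labels_std + train_set.labels_mean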
Example #5
def load_dataset(args):
    """Load the dataset.

    Parameters
    ----------
    args : dict
        Input arguments.

    Returns
    -------
    dataset
        Full dataset.
    train_set
        Training subset of the dataset.
    test_set
        Test subset of the dataset.
    """
    assert args['dataset'] in ['PDBBind'], 'Unexpected dataset {}'.format(
        args['dataset'])
    if args['dataset'] == 'PDBBind':
        dataset = PDBBind(subset=args['subset'],
                          load_binding_pocket=args['load_binding_pocket'],
                          zero_padding=True)
        # No validation set is used and frac_val = 0.
        if args['split'] == 'random':
            train_set, _, test_set = RandomSplitter.train_val_test_split(
                dataset,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'],
                random_state=args['random_seed'])

        elif args['split'] == 'scaffold':
            train_set, _, test_set = ScaffoldSplitter.train_val_test_split(
                dataset,
                mols=dataset.ligand_mols,
                sanitize=False,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'])

        elif args['split'] == 'stratified':
            train_set, _, test_set = SingleTaskStratifiedSplitter.train_val_test_split(
                dataset,
                labels=dataset.labels,
                task_id=0,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'],
                random_state=args['random_seed'])

        elif args['split'] == 'temporal':
            years = dataset.df['release_year'].values.astype(np.float32)
            indices = np.argsort(years).tolist()
            frac_list = np.array(
                [args['frac_train'], args['frac_val'], args['frac_test']])
            num_data = len(dataset)
            lengths = (num_data * frac_list).astype(int)
            lengths[-1] = num_data - np.sum(lengths[:-1])
            train_set, _, test_set = [
                Subset(dataset, list(indices[offset - length:offset]))
                for offset, length in zip(accumulate(lengths), lengths)
            ]

        else:
            raise ValueError('Expect the splitting method to be "random", '
                             '"scaffold", "stratified" or "temporal", got {}'.format(
                                 args['split']))
        train_labels = torch.stack(
            [train_set.dataset.labels[i] for i in train_set.indices])
        train_set.labels_mean = train_labels.mean(dim=0)
        train_set.labels_std = train_labels.std(dim=0)

    return dataset, train_set, test_set
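An illustrative args dict for this loader (key names follow the snippet; the concrete values, e.g. the 'core' subset and the 0.8/0.2 split, are assumptions):

args = {
    'dataset': 'PDBBind',
    'subset': 'core',              # PDBBind subset, e.g. 'core' or 'refined'
    'load_binding_pocket': True,
    'split': 'random',
    'frac_train': 0.8,
    'frac_val': 0.,                # no validation set is used here
    'frac_test': 0.2,
    'random_seed': 0,
}
dataset, train_set, test_set = load_dataset(args)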