Example #1
from dgllife.utils import RandomSplitter, ScaffoldSplitter

def split_dataset(args, dataset):
    """Split the dataset

    Parameters
    ----------
    args : dict
        Settings
    dataset
        Dataset instance

    Returns
    -------
    train_set
        Training subset
    val_set
        Validation subset
    test_set
        Test subset
    """
    train_ratio, val_ratio, test_ratio = map(float, args['split_ratio'].split(','))
    if args['split'] == 'scaffold':
        train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
            dataset, frac_train=train_ratio, frac_val=val_ratio, frac_test=test_ratio,
            scaffold_func='smiles')
    elif args['split'] == 'random':
        train_set, val_set, test_set = RandomSplitter.train_val_test_split(
            dataset, frac_train=train_ratio, frac_val=val_ratio, frac_test=test_ratio)
    else:
        raise ValueError("Expect the splitting method to be 'scaffold' or 'random', got {}".format(args['split']))

    return train_set, val_set, test_set
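
A minimal usage sketch, assuming a dgllife dataset object is already in hand (the `args` values below are placeholders; `split_dataset` only reads the 'split' and 'split_ratio' keys):

# Hypothetical usage: 'dataset' stands in for any dgllife dataset the splitters accept.
args = {'split': 'scaffold', 'split_ratio': '0.8,0.1,0.1'}
train_set, val_set, test_set = split_dataset(args, dataset)
print(len(train_set), len(val_set), len(test_set))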
Example #2
from itertools import accumulate

import numpy as np
import torch
from dgllife.data import PDBBind
from dgllife.utils import (RandomSplitter, ScaffoldSplitter,
                           SingleTaskStratifiedSplitter)
from torch.utils.data import Subset

def load_dataset(args):
    """Load the dataset.
    Parameters
    ----------
    args : dict
        Input arguments.
    Returns
    -------
    dataset
        Full dataset.
    train_set
        Train subset of the dataset.
    val_set
        Validation subset of the dataset.
    """
    assert args['dataset'] in ['PDBBind'], 'Unexpected dataset {}'.format(args['dataset'])
    if args['dataset'] == 'PDBBind':
        if args['model'] == 'PotentialNet':
            from functools import partial
            from dgllife.utils import potentialNet_graph_construction_featurization
            dataset = PDBBind(subset=args['subset'], pdb_version=args['version'],
                              load_binding_pocket=args['load_binding_pocket'],
                              construct_graph_and_featurize=partial(
                                  potentialNet_graph_construction_featurization,
                                  distance_bins=args['distance_bins'],
                                  max_num_neighbors=args['max_num_neighbors']))
        elif args['model'] == 'ACNN':
            dataset = PDBBind(subset=args['subset'], pdb_version=args['version'],
                              load_binding_pocket=args['load_binding_pocket'])
        else:
            raise ValueError("Expect the model to be 'PotentialNet' or 'ACNN', "
                             "got {}".format(args['model']))

        if args['split'] == 'random':
            train_set, val_set, test_set = RandomSplitter.train_val_test_split(
                dataset,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'],
                random_state=args['random_seed'])

        elif args['split'] == 'scaffold':
            train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
                dataset,
                mols=dataset.ligand_mols,
                sanitize=False,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'])

        elif args['split'] == 'stratified':
            train_set, val_set, test_set = SingleTaskStratifiedSplitter.train_val_test_split(
                dataset,
                labels=dataset.labels,
                task_id=0,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'],
                random_state=args['random_seed'])

        elif args['split'] == 'temporal':
            years = dataset.df['release_year'].values.astype(np.float32)
            indices = np.argsort(years).tolist()
            frac_list = np.array([args['frac_train'], args['frac_val'], args['frac_test']])
            num_data = len(dataset)
            lengths = (num_data * frac_list).astype(int)
            lengths[-1] = num_data - np.sum(lengths[:-1])
            train_set, val_set, test_set = [
                Subset(dataset, list(indices[offset - length:offset]))
                for offset, length in zip(accumulate(lengths), lengths)]

        else:
            raise ValueError('Expect the splitting method '
                             'to be "random", "scaffold", "stratified" or "temporal", got {}'.format(args['split']))
        if args['frac_train'] > 0:
            train_labels = torch.stack([train_set.dataset.labels[i] for i in train_set.indices])
            train_set.labels_mean = train_labels.mean(dim=0)
            train_set.labels_std = train_labels.std(dim=0)

    return dataset, train_set, val_set, test_set
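
The temporal branch above is the least obvious one: `accumulate` turns the three subset lengths into end offsets, so each `indices[offset - length:offset]` slice is a contiguous window of the chronologically sorted indices. A self-contained sketch of just that arithmetic, with made-up fractions and dataset size:

from itertools import accumulate

import numpy as np

num_data = 10
frac_list = np.array([0.8, 0.1, 0.1])
lengths = (num_data * frac_list).astype(int)   # [8, 1, 1]
lengths[-1] = num_data - np.sum(lengths[:-1])  # last split absorbs rounding error
indices = list(range(num_data))                # stands in for argsort over release years
windows = [indices[offset - length:offset]
           for offset, length in zip(accumulate(lengths), lengths)]
print(windows)  # [[0, 1, 2, 3, 4, 5, 6, 7], [8], [9]]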
Example #3
from itertools import accumulate

import numpy as np
import torch
from dgllife.data import PDBBind
from dgllife.utils import (RandomSplitter, ScaffoldSplitter,
                           SingleTaskStratifiedSplitter)
from torch.utils.data import Subset

def load_dataset(args):
    """Load the dataset.
    Parameters
    ----------
    args : dict
        Input arguments.
    Returns
    -------
    dataset
        Full dataset.
    train_set
        Train subset of the dataset.
    val_set
        Validation subset of the dataset.
    """
    assert args['dataset'] in ['PDBBind'], 'Unexpected dataset {}'.format(
        args['dataset'])
    if args['dataset'] == 'PDBBind':
        dataset = PDBBind(subset=args['subset'],
                          load_binding_pocket=args['load_binding_pocket'],
                          zero_padding=True)
        # No validation set is used and frac_val = 0.
        if args['split'] == 'random':
            train_set, _, test_set = RandomSplitter.train_val_test_split(
                dataset,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'],
                random_state=args['random_seed'])

        elif args['split'] == 'scaffold':
            train_set, _, test_set = ScaffoldSplitter.train_val_test_split(
                dataset,
                mols=dataset.ligand_mols,
                sanitize=False,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'])

        elif args['split'] == 'stratified':
            train_set, _, test_set = SingleTaskStratifiedSplitter.train_val_test_split(
                dataset,
                labels=dataset.labels,
                task_id=0,
                frac_train=args['frac_train'],
                frac_val=args['frac_val'],
                frac_test=args['frac_test'],
                random_state=args['random_seed'])

        elif args['split'] == 'temporal':
            years = dataset.df['release_year'].values.astype(np.float32)
            indices = np.argsort(years).tolist()
            frac_list = np.array(
                [args['frac_train'], args['frac_val'], args['frac_test']])
            num_data = len(dataset)
            lengths = (num_data * frac_list).astype(int)
            lengths[-1] = num_data - np.sum(lengths[:-1])
            train_set, _, test_set = [
                Subset(dataset, list(indices[offset - length:offset]))
                for offset, length in zip(accumulate(lengths), lengths)
            ]

        else:
            raise ValueError('Expect the splitting method to be "random", "scaffold", '
                             '"stratified" or "temporal", got {}'.format(args['split']))
        train_labels = torch.stack(
            [train_set.dataset.labels[i] for i in train_set.indices])
        train_set.labels_mean = train_labels.mean(dim=0)
        train_set.labels_std = train_labels.std(dim=0)

    return dataset, train_set, test_set
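
A hypothetical call sketch: every key below is read by `load_dataset` above, but the values are placeholders, and actually running it fetches the PDBBind data through dgllife:

# Hypothetical usage sketch; the values are placeholders, not recommendations.
args = {
    'dataset': 'PDBBind',
    'subset': 'core',              # assumption: one of the PDBBind subsets
    'load_binding_pocket': True,
    'split': 'random',
    'frac_train': 0.8,
    'frac_val': 0.,
    'frac_test': 0.2,
    'random_seed': 0,
}
dataset, train_set, test_set = load_dataset(args)
# The normalization statistics attached above travel with the training subset.
print(train_set.labels_mean, train_set.labels_std)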