def load_dataset_for_classification(args):
    """Load a dataset for classification tasks and split it three ways.

    Parameters
    ----------
    args : dict
        Configurations. The keys read here are ``'dataset'`` (only
        ``'Tox21'`` is supported), ``'atom_featurizer'``, ``'frac_train'``,
        ``'frac_val'``, ``'frac_test'`` and ``'random_seed'``.

    Returns
    -------
    dataset
        The whole dataset.
    train_set
        Subset for training.
    val_set
        Subset for validation.
    test_set
        Subset for test.

    Raises
    ------
    ValueError
        If ``args['dataset']`` is not a supported dataset name.
    """
    # Validate with an explicit exception rather than ``assert``: asserts are
    # stripped under ``python -O``, which would let an unsupported name fall
    # through to an unbound ``dataset`` variable further down.
    if args['dataset'] != 'Tox21':
        raise ValueError('Unexpected dataset: {}'.format(args['dataset']))

    # Imported lazily so dgllife is only touched once the name is validated.
    from dgllife.data import Tox21

    dataset = Tox21(smiles_to_bigraph, args['atom_featurizer'])

    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset,
        frac_train=args['frac_train'],
        frac_val=args['frac_val'],
        frac_test=args['frac_test'],
        random_state=args['random_seed'])

    return dataset, train_set, val_set, test_set
def load_dataset_for_classification(args):
    """Load a dataset for classification tasks and split it three ways.

    Parameters
    ----------
    args : dict
        Configurations. The keys read here are ``'dataset'`` (only
        ``'Tox21'`` is supported), ``'smiles_to_graph'``, ``'exp'``,
        ``'frac_train'``, ``'frac_val'``, ``'frac_test'`` and
        ``'random_seed'``. ``'node_featurizer'`` and ``'edge_featurizer'``
        are optional and default to ``None``.

    Returns
    -------
    dataset
        The whole dataset.
    train_set
        Subset for training.
    val_set
        Subset for validation.
    test_set
        Subset for test.

    Raises
    ------
    ValueError
        If ``args['dataset']`` is not a supported dataset name.
    """
    # Validate with an explicit exception rather than ``assert``: asserts are
    # stripped under ``python -O``, which would let an unsupported name fall
    # through to an unbound ``dataset`` variable further down.
    if args['dataset'] != 'Tox21':
        raise ValueError('Unexpected dataset: {}'.format(args['dataset']))

    # Imported lazily so dgllife is only touched once the name is validated.
    from dgllife.data import Tox21

    # NOTE(review): ``load=False`` with ``cache_file_path=args['exp']``
    # presumably rebuilds the graphs and caches them under the experiment
    # path -- confirm against the dgllife version in use.
    dataset = Tox21(
        smiles_to_graph=args['smiles_to_graph'],
        node_featurizer=args.get('node_featurizer'),
        edge_featurizer=args.get('edge_featurizer'),
        load=False,
        cache_file_path=args['exp'])

    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset,
        frac_train=args['frac_train'],
        frac_val=args['frac_val'],
        frac_test=args['frac_test'],
        random_state=args['random_seed'])

    return dataset, train_set, val_set, test_set
from dgllife.data import ToxCast dataset = ToxCast(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), node_featurizer=args['node_featurizer'], edge_featurizer=args['edge_featurizer'], n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'HIV': from dgllife.data import HIV dataset = HIV(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), node_featurizer=args['node_featurizer'], edge_featurizer=args['edge_featurizer'], n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'PCBA': from dgllife.data import PCBA dataset = PCBA(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), node_featurizer=args['node_featurizer'], edge_featurizer=args['edge_featurizer'], n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) elif args['dataset'] == 'Tox21': from dgllife.data import Tox21 dataset = Tox21(smiles_to_graph=partial(smiles_to_bigraph, add_self_loop=True), node_featurizer=args['node_featurizer'], edge_featurizer=args['edge_featurizer'], n_jobs=1 if args['num_workers'] == 0 else args['num_workers']) else: raise ValueError('Unexpected dataset: {}'.format(args['dataset'])) args['n_tasks'] = dataset.n_tasks train_set, val_set, test_set = split_dataset(args, dataset) exp_config = get_configure(args['model'], args['featurizer_type'], args['dataset']) main(args, exp_config, train_set, val_set, test_set)