def load_dataset_for_classification(args):
    """Load dataset for classification tasks.

    Parameters
    ----------
    args : dict
        Configurations. Must contain the keys ``'dataset'`` (currently only
        ``'Tox21'`` is supported), ``'atom_featurizer'`` and
        ``'train_val_test_split'``.

    Returns
    -------
    dataset
        The whole dataset.
    train_set
        Subset for training.
    val_set
        Subset for validation.
    test_set
        Subset for test.

    Raises
    ------
    ValueError
        If ``args['dataset']`` names an unsupported dataset.
    """
    if args['dataset'] == 'Tox21':
        # Import lazily so the heavy dgl dependency is only paid when needed.
        from dgl.data.chem import Tox21
        dataset = Tox21(atom_featurizer=args['atom_featurizer'])
        train_set, val_set, test_set = split_dataset(
            dataset, args['train_val_test_split'])
    else:
        # Was an `assert` before the branch: asserts are stripped under -O,
        # which would have surfaced as an UnboundLocalError below instead.
        raise ValueError('Unexpected dataset: {}'.format(args['dataset']))
    return dataset, train_set, val_set, test_set
def load_dataset_for_classification(args):
    """Load dataset for classification tasks.

    Parameters
    ----------
    args : dict
        Configurations. Must contain the keys ``'dataset'`` (currently only
        ``'Tox21'`` is supported), ``'atom_featurizer'``, ``'frac_train'``,
        ``'frac_val'``, ``'frac_test'`` and ``'random_seed'``.

    Returns
    -------
    dataset
        The whole dataset.
    train_set
        Subset for training.
    val_set
        Subset for validation.
    test_set
        Subset for test.

    Raises
    ------
    ValueError
        If ``args['dataset']`` names an unsupported dataset.
    """
    if args['dataset'] == 'Tox21':
        # Import lazily so the heavy dgl dependency is only paid when needed.
        from dgl.data.chem import Tox21
        dataset = Tox21(smiles_to_bigraph, args['atom_featurizer'])
        train_set, val_set, test_set = RandomSplitter.train_val_test_split(
            dataset,
            frac_train=args['frac_train'],
            frac_val=args['frac_val'],
            frac_test=args['frac_test'],
            random_state=args['random_seed'])
    else:
        # Was an `assert` before the branch: asserts are stripped under -O,
        # which would have surfaced as an UnboundLocalError below instead.
        raise ValueError('Unexpected dataset: {}'.format(args['dataset']))
    return dataset, train_set, val_set, test_set
def main(args):
    """Train and evaluate a molecular property classifier.

    Loads the configured dataset, builds the configured model (or loads a
    pre-trained one), trains with early stopping on validation ROC-AUC, and
    reports the test ROC-AUC.

    Parameters
    ----------
    args : dict
        Configurations; mutated in place ('device' is set, and 'num_epochs'
        is zeroed when a pre-trained model is used).

    Raises
    ------
    ValueError
        If ``args['dataset']`` or ``args['model']`` is unsupported.
    """
    args['device'] = "cuda" if torch.cuda.is_available() else "cpu"
    set_random_seed()

    # Interchangeable with other datasets
    if args['dataset'] == 'Tox21':
        from dgl.data.chem import Tox21
        dataset = Tox21()
        trainset, valset, testset = split_dataset(
            dataset, args['train_val_test_split'])
    else:
        raise ValueError('Unexpected dataset: {}'.format(args['dataset']))

    # Shuffle training batches each epoch (matches the other training
    # scripts in this file); evaluation order does not affect ROC-AUC.
    train_loader = DataLoader(trainset, batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs_for_classification)
    val_loader = DataLoader(valset, batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs_for_classification)
    test_loader = DataLoader(testset, batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs_for_classification)

    if args['pre_trained']:
        # Skip training entirely; evaluate the published checkpoint.
        args['num_epochs'] = 0
        model = model_zoo.chem.load_pretrained(args['exp'])
    else:
        # Interchangeable with other models
        if args['model'] == 'GCN':
            model = model_zoo.chem.GCNClassifier(
                in_feats=args['in_feats'],
                gcn_hidden_feats=args['gcn_hidden_feats'],
                classifier_hidden_feats=args['classifier_hidden_feats'],
                n_tasks=dataset.n_tasks)
        elif args['model'] == 'GAT':
            model = model_zoo.chem.GATClassifier(
                in_feats=args['in_feats'],
                gat_hidden_feats=args['gat_hidden_feats'],
                num_heads=args['num_heads'],
                classifier_hidden_feats=args['classifier_hidden_feats'],
                n_tasks=dataset.n_tasks)
        else:
            # Previously fell through silently and crashed below with a
            # NameError on `model`; fail fast with a clear message instead.
            raise ValueError('Unexpected model: {}'.format(args['model']))

    # Per-task positive weights counteract label imbalance; 'none' reduction
    # lets the train loop mask missing labels before averaging.
    loss_criterion = BCEWithLogitsLoss(
        pos_weight=dataset.task_pos_weights.to(args['device']),
        reduction='none')
    optimizer = Adam(model.parameters(), lr=args['lr'])
    stopper = EarlyStopping(patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader,
                          loss_criterion, optimizer)
        # Validation and early stop
        val_roc_auc = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_roc_auc, model)
        print('epoch {:d}/{:d}, validation roc-auc score {:.4f}, best validation roc-auc score {:.4f}'.format(
            epoch + 1, args['num_epochs'], val_roc_auc, stopper.best_score))
        if early_stop:
            break

    if not args['pre_trained']:
        # Restore the best-on-validation weights saved by the stopper.
        stopper.load_checkpoint(model)
    test_roc_auc = run_an_eval_epoch(args, model, test_loader)
    print('test roc-auc score {:.4f}'.format(test_roc_auc))
def main(args):
    """Run the full train/validate/test pipeline on Tox21 with a GCN.

    Records per-epoch (train, validation) scores, checkpointing the history
    to disk every epoch, and plots/saves it at the end alongside the final
    test score.
    """
    args = setup(args)

    dataset = Tox21()
    subsets = split_dataset(dataset, shuffle=True)

    def _make_loader(subset):
        # All three splits share batch size, shuffling and collation.
        return DataLoader(subset, batch_size=args['batch_size'],
                          shuffle=True, collate_fn=collate_molgraphs)

    train_loader, val_loader, test_loader = [_make_loader(s) for s in subsets]

    hidden_feats = [args['n_hidden']] * args['n_layers']
    model = model_zoo.chem.GCNClassifier(
        in_feats=args['n_input'],
        gcn_hidden_feats=hidden_feats,
        n_tasks=dataset.n_tasks,
        classifier_hidden_feats=args['n_hidden']).to(args['device'])

    # Weight positives per task to counter label imbalance; keep per-element
    # losses ('none') so masking can happen downstream.
    pos_weight = torch.tensor(dataset.task_pos_weights).to(args['device'])
    loss_criterion = BCEWithLogitsLoss(pos_weight=pos_weight, reduction='none')
    optimizer = Adam(model.parameters(), lr=args['lr'])
    stopper = EarlyStopper(args['patience'])

    history = []
    for epoch in range(args['n_epochs']):
        # One optimization pass over the training split.
        train_score = run_a_train_epoch(args, epoch, model, train_loader,
                                        loss_criterion, optimizer)
        # Score on the validation split and let the stopper decide.
        val_score = run_an_eval_epoch(args, model, val_loader)
        history.append([train_score, val_score])
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation roc-auc {:.4f}, best validation roc-auc {:.4f}'
              .format(epoch + 1, args['n_epochs'], val_score,
                      stopper.best_score))
        # Persist the learning curve after every epoch so a crash loses nothing.
        torch.save(history, "./history.pt")
        if early_stop:
            break

    # Evaluate the best-on-validation checkpoint on the held-out test split.
    stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    plot_save(history)
    print('Best validation score {:.4f}'.format(stopper.best_score))
    print('Test score {:.4f}'.format(test_score))
def main(args):
    """Train a GCN classifier on Tox21 and report validation/test ROC-AUC.

    Splits the dataset, trains with early stopping on validation score,
    restores the best checkpoint, and prints the final scores.
    """
    args = setup(args)
    dataset = Tox21()
    train_set, val_set, test_set = split_dataset(dataset, shuffle=True)

    batch_size = args["batch_size"]
    collate = collate_molgraphs
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              collate_fn=collate)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True,
                            collate_fn=collate)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True,
                             collate_fn=collate)

    model = model_zoo.chem.GCNClassifier(
        in_feats=args["n_input"],
        gcn_hidden_feats=[args["n_hidden"]] * args["n_layers"],
        n_tasks=dataset.n_tasks,
        classifier_hidden_feats=args["n_hidden"],
    ).to(args["device"])

    # Per-task positive weights offset class imbalance; 'none' keeps the
    # loss per-element for downstream masking.
    loss_criterion = BCEWithLogitsLoss(
        pos_weight=torch.tensor(dataset.task_pos_weights).to(args["device"]),
        reduction="none")
    optimizer = Adam(model.parameters(), lr=args["lr"])
    stopper = EarlyStopper(args["patience"])

    for epoch in range(args["n_epochs"]):
        # Train, then validate; the stopper tracks the best checkpoint.
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion,
                          optimizer)
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print("epoch {:d}/{:d}, validation roc-auc {:.4f}, best validation roc-auc {:.4f}"
              .format(epoch + 1, args["n_epochs"], val_score,
                      stopper.best_score))
        if early_stop:
            break

    # Reload the best-on-validation weights before the final test pass.
    stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print("Best validation score {:.4f}".format(stopper.best_score))
    print("Test score {:.4f}".format(test_score))