def main(args, exp_config, train_set, val_set, test_set):
    if args['featurizer_type'] != 'pre_train':
        exp_config['in_node_feats'] = args['node_featurizer'].feat_size()
        if args['edge_featurizer'] is not None:
            exp_config['in_edge_feats'] = args['edge_featurizer'].feat_size()
    exp_config.update({
        'n_tasks': args['n_tasks'],
        'model': args['model']
    })

    train_loader = DataLoader(dataset=train_set, batch_size=exp_config['batch_size'],
                              shuffle=True, collate_fn=collate_molgraphs,
                              num_workers=args['num_workers'])
    val_loader = DataLoader(dataset=val_set, batch_size=exp_config['batch_size'],
                            collate_fn=collate_molgraphs, num_workers=args['num_workers'])
    test_loader = DataLoader(dataset=test_set, batch_size=exp_config['batch_size'],
                             collate_fn=collate_molgraphs, num_workers=args['num_workers'])

    if args['pretrain']:
        args['num_epochs'] = 0
        if args['featurizer_type'] == 'pre_train':
            model = load_pretrained('{}_{}'.format(
                args['model'], args['dataset'])).to(args['device'])
        else:
            model = load_pretrained('{}_{}_{}'.format(
                args['model'], args['featurizer_type'], args['dataset'])).to(args['device'])
    else:
        model = load_model(exp_config).to(args['device'])
        loss_criterion = nn.SmoothL1Loss(reduction='none')
        optimizer = Adam(model.parameters(), lr=exp_config['lr'],
                         weight_decay=exp_config['weight_decay'])
        stopper = EarlyStopping(patience=exp_config['patience'],
                                filename=args['result_path'] + '/model.pth',
                                metric=args['metric'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args['num_epochs'], args['metric'],
            val_score, args['metric'], stopper.best_score))

        if early_stop:
            break

    if not args['pretrain']:
        stopper.load_checkpoint(model)
    val_score = run_an_eval_epoch(args, model, val_loader)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('val {} {:.4f}'.format(args['metric'], val_score))
    print('test {} {:.4f}'.format(args['metric'], test_score))

    with open(args['result_path'] + '/eval.txt', 'w') as f:
        if not args['pretrain']:
            f.write('Best val {}: {}\n'.format(args['metric'], stopper.best_score))
        f.write('Val {}: {}\n'.format(args['metric'], val_score))
        f.write('Test {}: {}\n'.format(args['metric'], test_score))
def test_dgmg():
    model = load_pretrained('DGMG_ZINC_canonical')
    run_dgmg_ZINC(model)
    model = load_pretrained('DGMG_ZINC_random')
    run_dgmg_ZINC(model)
    model = load_pretrained('DGMG_ChEMBL_canonical')
    run_dgmg_ChEMBL(model)
    model = load_pretrained('DGMG_ChEMBL_random')
    run_dgmg_ChEMBL(model)

    remove_file('DGMG_ChEMBL_canonical_pre_trained.pth')
    remove_file('DGMG_ChEMBL_random_pre_trained.pth')
    remove_file('DGMG_ZINC_canonical_pre_trained.pth')
    remove_file('DGMG_ZINC_random_pre_trained.pth')
def prepare_for_evaluation(rank, args):
    worker_seed = args['seed'] + rank * 10000
    set_random_seed(worker_seed)
    torch.set_num_threads(1)

    # Setup dataset and data loader
    dataset = MoleculeDataset(args['dataset'], subset_id=rank,
                              n_subsets=args['num_processes'])

    # Initialize model
    if not args['pretrained']:
        model = DGMG(atom_types=dataset.atom_types,
                     bond_types=dataset.bond_types,
                     node_hidden_size=args['node_hidden_size'],
                     num_prop_rounds=args['num_propagation_rounds'],
                     dropout=args['dropout'])
        model.load_state_dict(torch.load(args['model_path'])['model_state_dict'])
    else:
        model = load_pretrained('_'.join(['DGMG', args['dataset'], args['order']]), log=False)
    model.eval()

    worker_num_samples = args['num_samples'] // args['num_processes']
    if rank == args['num_processes'] - 1:
        worker_num_samples += args['num_samples'] % args['num_processes']

    worker_log_dir = os.path.join(args['log_dir'], str(rank))
    mkdir_p(worker_log_dir, log=False)
    generate_and_save(worker_log_dir, worker_num_samples,
                      args['max_num_steps'], model)
def test_attentivefp_aromaticity():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    node_featurizer = BaseAtomFeaturizer(
        featurizer_funcs={'hv': ConcatFeaturizer([
            partial(atom_type_one_hot, allowable_set=[
                'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br',
                'Te', 'I', 'At'], encode_unknown=True),
            partial(atom_degree_one_hot, allowable_set=list(range(6))),
            atom_formal_charge, atom_num_radical_electrons,
            partial(atom_hybridization_one_hot, encode_unknown=True),
            lambda atom: [0],  # A placeholder for aromatic information
            atom_total_num_H_one_hot, chirality
        ])}
    )
    edge_featurizer = BaseBondFeaturizer({
        'he': lambda bond: [0 for _ in range(10)]
    })
    g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer,
                           edge_featurizer=edge_featurizer)
    g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer,
                           edge_featurizer=edge_featurizer)
    bg = dgl.batch([g1, g2])

    model = load_pretrained('AttentiveFP_Aromaticity').to(device)
    model(bg.to(device), bg.ndata.pop('hv').to(device), bg.edata.pop('he').to(device))
    model.eval()
    model(g1.to(device), g1.ndata.pop('hv').to(device), g1.edata.pop('he').to(device))
    remove_file('AttentiveFP_Aromaticity_pre_trained.pth')
def main(args, dataset):
    data_loader = DataLoader(dataset, batch_size=args['batch_size'],
                             collate_fn=collate, shuffle=False)
    model = load_pretrained(args['model']).to(args['device'])
    model.eval()
    readout = AvgPooling()

    mol_emb = []
    for batch_id, bg in enumerate(data_loader):
        print('Processing batch {:d}/{:d}'.format(batch_id + 1, len(data_loader)))
        bg = bg.to(args['device'])
        nfeats = [bg.ndata.pop('atomic_number').to(args['device']),
                  bg.ndata.pop('chirality_type').to(args['device'])]
        efeats = [bg.edata.pop('bond_type').to(args['device']),
                  bg.edata.pop('bond_direction_type').to(args['device'])]
        with torch.no_grad():
            node_repr = model(bg, nfeats, efeats)
        mol_emb.append(readout(bg, node_repr))
    mol_emb = torch.cat(mol_emb, dim=0).detach().cpu().numpy()
    np.save(args['out_dir'] + '/mol_emb.npy', mol_emb)
def test_weave_tox21():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    node_featurizer = WeaveAtomFeaturizer()
    edge_featurizer = WeaveEdgeFeaturizer(max_distance=2)
    g1 = smiles_to_complete_graph('CO', node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer, add_self_loop=True)
    g2 = smiles_to_complete_graph('CCO', node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer, add_self_loop=True)
    bg = dgl.batch([g1, g2])

    model = load_pretrained('Weave_Tox21').to(device)
    model(bg.to(device), bg.ndata.pop('h').to(device), bg.edata.pop('e').to(device))
    model.eval()
    model(g1.to(device), g1.ndata.pop('h').to(device), g1.edata.pop('e').to(device))
    remove_file('Weave_Tox21_pre_trained.pth')
def test_jtnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = load_pretrained('JTNN_ZINC_no_kl').to(device)
def main(args):
    args['device'] = torch.device('cuda:0') if torch.cuda.is_available() \
        else torch.device('cpu')
    set_random_seed(args['random_seed'])

    dataset = PubChemBioAssayAromaticity(
        smiles_to_graph=args['smiles_to_graph'],
        node_featurizer=args.get('node_featurizer', None),
        edge_featurizer=args.get('edge_featurizer', None))
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset, frac_train=args['frac_train'], frac_val=args['frac_val'],
        frac_test=args['frac_test'], random_state=args['random_seed'])

    train_loader = DataLoader(dataset=train_set, batch_size=args['batch_size'],
                              shuffle=True, collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set, batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(dataset=test_set, batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        model = load_model(args)
    loss_fn = nn.MSELoss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    stopper = EarlyStopping(mode=args['mode'], patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args['num_epochs'], args['metric_name'],
            val_score, args['metric_name'], stopper.best_score))
        if early_stop:
            break

    if not args['pre_trained']:
        stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('test {} {:.4f}'.format(args['metric_name'], test_score))
def test_jtnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = load_pretrained('JTNN_ZINC').to(device)
    remove_file('JTNN_ZINC_pre_trained.pth')
def __init__(self, predictor_dim=None):
    super(DGL_GIN_ContextPred, self).__init__()
    from dgllife.model import load_pretrained
    from dgl.nn.pytorch.glob import AvgPooling

    ## these hyperparameters are fixed since the GIN encoder is a pretrained model
    self.gnn = load_pretrained('gin_supervised_contextpred')
    self.readout = AvgPooling()
    self.transform = nn.Linear(300, predictor_dim)
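# A minimal sketch (not part of the original wrapper) of how a matching forward
# pass could look, assuming the input is a batched DGLGraph prepared with
# PretrainAtomFeaturizer/PretrainBondFeaturizer, i.e. carrying the
# 'atomic_number'/'chirality_type' node fields and 'bond_type'/'bond_direction_type'
# edge fields used elsewhere in this section. The method name and the way features
# are popped are illustrative assumptions, not the repository's fixed API.
def forward(self, bg):
    node_feats = [bg.ndata.pop('atomic_number'), bg.ndata.pop('chirality_type')]
    edge_feats = [bg.edata.pop('bond_type'), bg.edata.pop('bond_direction_type')]
    node_repr = self.gnn(bg, node_feats, edge_feats)   # 300-dim per-node embeddings
    graph_repr = self.readout(bg, node_repr)           # average node embeddings per graph
    return self.transform(graph_repr)                  # project to predictor_dim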
def main(args):
    args['device'] = torch.device('cuda') if torch.cuda.is_available() \
        else torch.device('cpu')
    set_random_seed(args['random_seed'])

    train_set, val_set, test_set = load_dataset_for_regression(args)
    train_loader = DataLoader(dataset=train_set, batch_size=args['batch_size'],
                              shuffle=True, collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set, batch_size=args['batch_size'],
                            shuffle=True, collate_fn=collate_molgraphs)
    if test_set is not None:
        test_loader = DataLoader(dataset=test_set, batch_size=args['batch_size'],
                                 collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        model = load_model(args)
    loss_fn = nn.MSELoss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    stopper = EarlyStopping(mode='lower', patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args['num_epochs'], args['metric_name'],
            val_score, args['metric_name'], stopper.best_score))
        if early_stop:
            break

    if test_set is not None:
        if not args['pre_trained']:
            stopper.load_checkpoint(model)
        test_score = run_an_eval_epoch(args, model, test_loader)
        print('test {} {:.4f}'.format(args['metric_name'], test_score))
def main(args):
    args['device'] = torch.device('cuda') if torch.cuda.is_available() \
        else torch.device('cpu')
    set_random_seed(args['random_seed'])

    # Interchangeable with other datasets
    dataset, train_set, val_set, test_set = load_dataset_for_classification(args)
    train_loader = DataLoader(train_set, batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(test_set, batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        args['n_tasks'] = dataset.n_tasks
        model = load_model(args)
    loss_criterion = BCEWithLogitsLoss(
        pos_weight=dataset.task_pos_weights(torch.tensor(train_set.indices)).to(args['device']),
        reduction='none')
    optimizer = Adam(model.parameters(), lr=args['lr'])
    stopper = EarlyStopping(patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args['num_epochs'], args['metric_name'],
            val_score, args['metric_name'], stopper.best_score))
        if early_stop:
            break

    if not args['pre_trained']:
        stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('test {} {:.4f}'.format(args['metric_name'], test_score))
def test_gat_tox21():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    node_featurizer = CanonicalAtomFeaturizer()
    g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
    g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
    bg = dgl.batch([g1, g2])

    model = load_pretrained('GAT_Tox21').to(device)
    model(bg.to(device), bg.ndata.pop('h').to(device))
    model.eval()
    model(g1.to(device), g1.ndata.pop('h').to(device))
    remove_file('GAT_Tox21_pre_trained.pth')
def main(args):
    set_seed()
    if torch.cuda.is_available():
        args['device'] = torch.device('cuda:0')
        # Set current device
        torch.cuda.set_device(args['device'])
    else:
        args['device'] = torch.device('cpu')

    if args['test_path'] is None:
        test_set = USPTOCenter('test', num_processes=args['num_processes'],
                               load=args['load'])
    else:
        test_set = WLNCenterDataset(raw_file_path=args['test_path'],
                                    mol_graph_path=args['test_path'] + '.bin',
                                    num_processes=args['num_processes'],
                                    load=args['load'],
                                    reaction_validity_result_prefix='test')
    test_loader = DataLoader(test_set, batch_size=args['batch_size'],
                             collate_fn=collate_center, shuffle=False)

    if args['model_path'] is None:
        model = load_pretrained('wln_center_uspto')
    else:
        model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
                                  edge_in_feats=args['edge_in_feats'],
                                  node_pair_in_feats=args['node_pair_in_feats'],
                                  node_out_feats=args['node_out_feats'],
                                  n_layers=args['n_layers'],
                                  n_tasks=args['n_tasks'])
        model.load_state_dict(torch.load(
            args['model_path'], map_location='cpu')['model_state_dict'])
    model = model.to(args['device'])

    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(args, args['top_ks_test'], model,
                                             test_loader, args['easy'])
    print(test_result)
    with open(args['result_path'] + '/test_eval.txt', 'w') as f:
        f.write(test_result)
def main(args, path_to_candidate_bonds):
    if args['test_path'] is None:
        test_set = USPTORank(
            subset='test', candidate_bond_path=path_to_candidate_bonds['test'],
            max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
            num_processes=args['num_processes'])
    else:
        test_set = WLNRankDataset(
            path_to_reaction_file=args['test_path'],
            candidate_bond_path=path_to_candidate_bonds['test'], mode='test',
            max_num_change_combos_per_reaction=args['max_num_change_combos_per_reaction_eval'],
            num_processes=args['num_processes'])
    test_loader = DataLoader(test_set, batch_size=1, collate_fn=collate_rank_eval,
                             shuffle=False, num_workers=args['num_workers'])

    if args['num_workers'] > 1:
        torch.multiprocessing.set_sharing_strategy('file_system')

    if args['model_path'] is None:
        model = load_pretrained('wln_rank_uspto')
    else:
        model = WLNReactionRanking(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_hidden_feats=args['hidden_size'],
            num_encode_gnn_layers=args['num_encode_gnn_layers'])
        model.load_state_dict(torch.load(
            args['model_path'], map_location='cpu')['model_state_dict'])
    model = model.to(args['device'])

    prediction_summary = candidate_ranking_eval(args, model, test_loader)
    with open(args['result_path'] + '/test_eval.txt', 'w') as f:
        f.write(prediction_summary)
def main(args):
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    if args.use_cpu or not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:0')

    vocab = JTVAEVocab(file_path=args.train_path)
    if args.test_path is None:
        dataset = JTVAEZINC('test', vocab)
    else:
        dataset = JTVAEDataset(args.test_path, vocab, training=False)
    dataloader = DataLoader(dataset, batch_size=1,
                            collate_fn=JTVAECollator(training=False))

    if args.model_path is None:
        model = load_pretrained('JTVAE_ZINC_no_kl')
    else:
        model = JTNNVAE(vocab, args.hidden_size, args.latent_size, args.depth)
        model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
    model = model.to(device)

    acc = 0.0
    for it, (tree, tree_graph, mol_graph) in enumerate(dataloader):
        tot = it + 1
        smiles = tree.smiles
        tree_graph = tree_graph.to(device)
        mol_graph = mol_graph.to(device)
        dec_smiles = model.reconstruct(tree_graph, mol_graph)
        if dec_smiles == smiles:
            acc += 1
        if tot % args.print_iter == 0:
            print('Iter {:d}/{:d} | Acc {:.4f}'.format(
                tot // args.print_iter, len(dataloader) // args.print_iter, acc / tot))
    print('Final acc: {:.4f}'.format(acc / tot))
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=False)
vocab_file = dataset.vocab_file

hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)

model = DGLJTNNVAE(vocab_file=vocab_file,
                   hidden_size=hidden_size,
                   latent_size=latent_size,
                   depth=depth)

if args.model_path is not None:
    model.load_state_dict(torch.load(args.model_path))
else:
    model = load_pretrained("JTNN_ZINC")
model = cuda(model)
model.eval()
print("Model #Params: %dK" %
      (sum([x.nelement() for x in model.parameters()]) / 1000,))

MAX_EPOCH = 100
PRINT_ITER = 20

def reconstruct():
    dataset.training = False
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
def prepare_reaction_center(args, reaction_center_config):
    """Use a trained model for reaction center prediction to prepare candidate bonds.

    Parameters
    ----------
    args : dict
        Configuration for the experiment.
    reaction_center_config : dict
        Configuration for the experiment on reaction center prediction.

    Returns
    -------
    path_to_candidate_bonds : dict
        Mapping 'train', 'val', 'test' to the corresponding files for candidate bonds.
    """
    if args['center_model_path'] is None:
        reaction_center_model = load_pretrained('wln_center_uspto').to(args['device'])
    else:
        reaction_center_model = WLNReactionCenter(
            node_in_feats=reaction_center_config['node_in_feats'],
            edge_in_feats=reaction_center_config['edge_in_feats'],
            node_pair_in_feats=reaction_center_config['node_pair_in_feats'],
            node_out_feats=reaction_center_config['node_out_feats'],
            n_layers=reaction_center_config['n_layers'],
            n_tasks=reaction_center_config['n_tasks'])
        reaction_center_model.load_state_dict(
            torch.load(args['center_model_path'])['model_state_dict'])
        reaction_center_model = reaction_center_model.to(args['device'])
    reaction_center_model.eval()

    path_to_candidate_bonds = dict()
    for subset in ['train', 'val', 'test']:
        if '{}_path'.format(subset) not in args:
            continue

        path_to_candidate_bonds[subset] = args['result_path'] + \
            '/{}_candidate_bonds.txt'.format(subset)
        if os.path.isfile(path_to_candidate_bonds[subset]):
            continue

        print('Processing subset {}...'.format(subset))
        print('Stage 1/3: Loading dataset...')
        if args['{}_path'.format(subset)] is None:
            dataset = USPTOCenter(subset, num_processes=args['num_processes'])
        else:
            dataset = WLNCenterDataset(
                raw_file_path=args['{}_path'.format(subset)],
                mol_graph_path='{}.bin'.format(subset),
                num_processes=args['num_processes'])

        dataloader = DataLoader(dataset, batch_size=args['reaction_center_batch_size'],
                                collate_fn=collate_center, shuffle=False)

        print('Stage 2/3: Performing model prediction...')
        output_strings = []
        for batch_id, batch_data in enumerate(dataloader):
            print('Computing candidate bonds for batch {:d}/{:d}'.format(
                batch_id + 1, len(dataloader)))
            batch_reactions, batch_graph_edits, batch_mol_graphs, \
                batch_complete_graphs, batch_atom_pair_labels = batch_data
            with torch.no_grad():
                pred, biased_pred = reaction_center_prediction(
                    args['device'], reaction_center_model,
                    batch_mol_graphs, batch_complete_graphs)
            batch_size = len(batch_reactions)
            start = 0
            for i in range(batch_size):
                end = start + batch_complete_graphs.batch_num_edges[i]
                output_strings.append(output_candidate_bonds_for_a_reaction(
                    (batch_reactions[i], biased_pred[start:end, :].flatten(),
                     batch_complete_graphs.batch_num_nodes[i]),
                    reaction_center_config['max_k']))
                start = end

        print('Stage 3/3: Output candidate bonds...')
        with open(path_to_candidate_bonds[subset], 'w') as f:
            for candidate_string in output_strings:
                f.write(candidate_string)

        del dataset
        del dataloader

    del reaction_center_model

    return path_to_candidate_bonds
def main(args):
    worker_init_fn(None)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if args.model_path is not None:
        model = DGLJTNNVAE(vocab_file=get_vocab_file(args.vocab),
                           hidden_size=args.hidden_size,
                           latent_size=args.latent_size,
                           depth=args.depth)
        model.load_state_dict(torch.load(args.model_path))
    else:
        model = load_pretrained("JTNN_ZINC")
    print("# model parameters: {:d}K".format(
        sum([x.nelement() for x in model.parameters()]) // 1000))

    dataset = JTVAEDataset(data=args.data, vocab=model.vocab, training=False)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTVAECollator(False),
        worker_init_fn=worker_init_fn)

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    acc = 0.0
    tot = 0
    model = model.to(device)
    model.eval()
    for it, batch in enumerate(tqdm(dataloader)):
        gt_smiles = batch['mol_trees'][0].smiles
        batch = dataset.move_to_device(batch, device)
        try:
            _, tree_vec, mol_vec = model.encode(batch)

            tree_mean = model.T_mean(tree_vec)
            # Following Mueller et al.
            tree_log_var = -torch.abs(model.T_var(tree_vec))
            epsilon = torch.randn(1, model.latent_size // 2).to(device)
            # Reparameterization: std = exp(log_var / 2)
            tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon

            mol_mean = model.G_mean(mol_vec)
            # Following Mueller et al.
            mol_log_var = -torch.abs(model.G_var(mol_vec))
            epsilon = torch.randn(1, model.latent_size // 2).to(device)
            mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

            dec_smiles = model.decode(tree_vec, mol_vec)

            if dec_smiles == gt_smiles:
                acc += 1
            tot += 1
        except Exception as e:
            print("Failed to encode: {}".format(gt_smiles))
            print(e)

        if it % 20 == 1:
            print("Progress {}/{}; Current Reconstruction Accuracy: {:.4f}".format(
                it, len(dataloader), acc / tot))
    print("Reconstruction Accuracy: {}".format(acc / tot))
def load_model(exp_configure):
    if exp_configure['model'] == 'GCN':
        from dgllife.model import GCNPredictor
        model = GCNPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * exp_configure['num_gnn_layers'],
            activation=[F.relu] * exp_configure['num_gnn_layers'],
            residual=[exp_configure['residual']] * exp_configure['num_gnn_layers'],
            batchnorm=[exp_configure['batchnorm']] * exp_configure['num_gnn_layers'],
            dropout=[exp_configure['dropout']] * exp_configure['num_gnn_layers'],
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'GAT':
        from dgllife.model import GATPredictor
        model = GATPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] * exp_configure['num_gnn_layers'],
            num_heads=[exp_configure['num_heads']] * exp_configure['num_gnn_layers'],
            feat_drops=[exp_configure['dropout']] * exp_configure['num_gnn_layers'],
            attn_drops=[exp_configure['dropout']] * exp_configure['num_gnn_layers'],
            alphas=[exp_configure['alpha']] * exp_configure['num_gnn_layers'],
            residuals=[exp_configure['residual']] * exp_configure['num_gnn_layers'],
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'Weave':
        from dgllife.model import WeavePredictor
        model = WeavePredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            num_gnn_layers=exp_configure['num_gnn_layers'],
            gnn_hidden_feats=exp_configure['gnn_hidden_feats'],
            graph_feats=exp_configure['graph_feats'],
            gaussian_expand=exp_configure['gaussian_expand'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            node_out_feats=exp_configure['node_out_feats'],
            edge_hidden_feats=exp_configure['edge_hidden_feats'],
            num_step_message_passing=exp_configure['num_step_message_passing'],
            num_step_set2set=exp_configure['num_step_set2set'],
            num_layer_set2set=exp_configure['num_layer_set2set'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(
            node_feat_size=exp_configure['in_node_feats'],
            edge_feat_size=exp_configure['in_edge_feats'],
            num_layers=exp_configure['num_layers'],
            num_timesteps=exp_configure['num_timesteps'],
            graph_feat_size=exp_configure['graph_feat_size'],
            dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] in ['gin_supervised_contextpred', 'gin_supervised_infomax',
                                    'gin_supervised_edgepred', 'gin_supervised_masking']:
        from dgllife.model import GINPredictor
        from dgllife.model import load_pretrained
        model = GINPredictor(
            num_node_emb_list=[120, 3],
            num_edge_emb_list=[6, 3],
            num_layers=5,
            emb_dim=300,
            JK=exp_configure['jk'],
            dropout=0.5,
            readout=exp_configure['readout'],
            n_tasks=exp_configure['n_tasks'])
        model.gnn = load_pretrained(exp_configure['model'])
        model.gnn.JK = exp_configure['jk']
    else:
        raise ValueError(
            "Expect model to be from ['GCN', 'GAT', 'Weave', 'MPNN', 'AttentiveFP', "
            "'gin_supervised_contextpred', 'gin_supervised_infomax', "
            "'gin_supervised_edgepred', 'gin_supervised_masking'], "
            "got {}".format(exp_configure['model']))

    return model
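# A minimal usage sketch for the 'GCN' branch of load_model above. The
# hyperparameter values are illustrative placeholders, not tuned defaults from
# this repository; 'in_node_feats' would typically come from the node
# featurizer's feat_size().
exp_configure = {
    'model': 'GCN',
    'in_node_feats': 74,             # e.g. CanonicalAtomFeaturizer().feat_size()
    'gnn_hidden_feats': 64,
    'num_gnn_layers': 2,
    'residual': True,
    'batchnorm': False,
    'dropout': 0.1,
    'predictor_hidden_feats': 128,
    'n_tasks': 1
}
model = load_model(exp_configure)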
def main(args):
    setup(args)
    if args['train_path'] is None:
        train_set = USPTO('train')
    else:
        train_set = WLNReactionDataset(raw_file_path=args['train_path'],
                                       mol_graph_path='train.bin')
    if args['val_path'] is None:
        val_set = USPTO('val')
    else:
        val_set = WLNReactionDataset(raw_file_path=args['val_path'],
                                     mol_graph_path='val.bin')
    if args['test_path'] is None:
        test_set = USPTO('test')
    else:
        test_set = WLNReactionDataset(raw_file_path=args['test_path'],
                                      mol_graph_path='test.bin')

    train_loader = DataLoader(train_set, batch_size=args['batch_size'],
                              collate_fn=collate, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args['batch_size'],
                            collate_fn=collate, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=args['batch_size'],
                             collate_fn=collate, shuffle=False)

    if args['pre_trained']:
        model = load_pretrained('wln_center_uspto').to(args['device'])
        args['num_epochs'] = 0
    else:
        model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
                                  edge_in_feats=args['edge_in_feats'],
                                  node_pair_in_feats=args['node_pair_in_feats'],
                                  node_out_feats=args['node_out_feats'],
                                  n_layers=args['n_layers'],
                                  n_tasks=args['n_tasks']).to(args['device'])

        criterion = BCEWithLogitsLoss(reduction='sum')
        optimizer = Adam(model.parameters(), lr=args['lr'])
        scheduler = StepLR(optimizer, step_size=args['decay_every'],
                           gamma=args['lr_decay_factor'])

    total_iter = 0
    grad_norm_sum = 0
    loss_sum = 0
    dur = []

    for epoch in range(args['num_epochs']):
        t0 = time.time()
        for batch_id, batch_data in enumerate(train_loader):
            total_iter += 1

            batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
                batch_complete_graphs, batch_atom_pair_labels = batch_data
            labels = batch_atom_pair_labels.to(args['device'])
            pred, biased_pred = reaction_center_prediction(
                args['device'], model, batch_mol_graphs, batch_complete_graphs)
            loss = criterion(pred, labels) / len(batch_reactions)
            loss_sum += loss.cpu().detach().data.item()

            optimizer.zero_grad()
            loss.backward()
            grad_norm = clip_grad_norm_(model.parameters(), args['max_norm'])
            grad_norm_sum += grad_norm
            optimizer.step()
            scheduler.step()

            if total_iter % args['print_every'] == 0:
                progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time/minibatch {:.4f} | ' \
                           'loss {:.4f} | grad norm {:.4f}'.format(
                               epoch + 1, args['num_epochs'], batch_id + 1,
                               len(train_loader),
                               (np.sum(dur) + time.time() - t0) / total_iter,
                               loss_sum / args['print_every'],
                               grad_norm_sum / args['print_every'])
                grad_norm_sum = 0
                loss_sum = 0
                print(progress)

            if total_iter % args['decay_every'] == 0:
                torch.save(model.state_dict(), args['result_path'] + '/model.pkl')

        dur.append(time.time() - t0)
        print('Epoch {:d}/{:d}, validation '.format(epoch + 1, args['num_epochs']) +
              rough_eval_on_a_loader(args, model, val_loader))

    del train_loader
    del val_loader
    del train_set
    del val_set

    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(args, model, test_loader, args['easy'])
    print(test_result)
    with open(args['result_path'] + '/results.txt', 'w') as f:
        f.write(test_result)
def main(args):
    torch.cuda.set_device(args['gpu'])
    set_random_seed(args['random_seed'])

    dataset, train_set, val_set, test_set = load_dataset_for_classification(args)  # 6264, 783, 784
    train_loader = DataLoader(train_set, batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(test_set, batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        args['n_tasks'] = dataset.n_tasks
        if args['method'] == 'twp':
            model = load_mymodel(args)
            print(model)
        else:
            model = load_model(args)
    for name, parameters in model.named_parameters():
        print(name, ':', parameters.size())

    method = args['method']
    life_model = importlib.import_module(f'LifeModel.{method}_model')
    life_model_ins = life_model.NET(model, args)
    data_loader = DataLoader(train_set, batch_size=len(train_set),
                             collate_fn=collate_molgraphs, shuffle=True)
    life_model_ins.data_loader = data_loader

    loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights.cuda(),
                                       reduction='none')
    model.cuda()

    score_mean = []
    score_matrix = np.zeros([args['n_tasks'], args['n_tasks']])
    prev_model = None
    for task_i in range(12):
        print('\n********' + str(task_i))
        stopper = EarlyStopping(patience=args['patience'])
        for epoch in range(args['num_epochs']):
            # Train
            if args['method'] == 'lwf':
                life_model_ins.observe(train_loader, loss_criterion, task_i, args, prev_model)
            else:
                life_model_ins.observe(train_loader, loss_criterion, task_i, args)

            # Validation and early stop
            val_score = run_an_eval_epoch(args, model, val_loader, task_i)
            early_stop = stopper.step(val_score, model)
            if early_stop:
                print(epoch)
                break

        if not args['pre_trained']:
            stopper.load_checkpoint(model)

        score_matrix[task_i] = run_eval_epoch(args, model, test_loader)
        prev_model = copy.deepcopy(life_model_ins).cuda()

    print('AP: ', round(np.mean(score_matrix[-1, :]), 4))
    backward = []
    for t in range(args['n_tasks'] - 1):
        b = score_matrix[args['n_tasks'] - 1][t] - score_matrix[t][t]
        backward.append(round(b, 4))
    mean_backward = round(np.mean(backward), 4)
    print('AF: ', mean_backward)
def test_moleculenet():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    for dataset in ['BACE', 'BBBP', 'ClinTox', 'FreeSolv', 'HIV', 'MUV', 'SIDER',
                    'ToxCast', 'PCBA', 'ESOL', 'Lipophilicity', 'Tox21']:
        for featurizer_type in ['canonical', 'attentivefp']:
            if featurizer_type == 'canonical':
                node_featurizer = CanonicalAtomFeaturizer(atom_data_field='hv')
                edge_featurizer = CanonicalBondFeaturizer(bond_data_field='he',
                                                          self_loop=True)
            else:
                node_featurizer = AttentiveFPAtomFeaturizer(atom_data_field='hv')
                edge_featurizer = AttentiveFPBondFeaturizer(bond_data_field='he',
                                                            self_loop=True)

            for model_type in ['GCN', 'GAT']:
                g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
                g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
                bg = dgl.batch([g1, g2])

                model = load_pretrained('{}_{}_{}'.format(
                    model_type, featurizer_type, dataset)).to(device)
                with torch.no_grad():
                    model(bg.to(device), bg.ndata.pop('hv').to(device))
                    model.eval()
                    model(g1.to(device), g1.ndata.pop('hv').to(device))
                remove_file('{}_{}_{}_pre_trained.pth'.format(
                    model_type.lower(), featurizer_type, dataset))

            for model_type in ['Weave', 'MPNN', 'AttentiveFP']:
                g1 = smiles_to_bigraph('CO', add_self_loop=True,
                                       node_featurizer=node_featurizer,
                                       edge_featurizer=edge_featurizer)
                g2 = smiles_to_bigraph('CCO', add_self_loop=True,
                                       node_featurizer=node_featurizer,
                                       edge_featurizer=edge_featurizer)
                bg = dgl.batch([g1, g2])

                model = load_pretrained('{}_{}_{}'.format(
                    model_type, featurizer_type, dataset)).to(device)
                with torch.no_grad():
                    model(bg.to(device), bg.ndata.pop('hv').to(device),
                          bg.edata.pop('he').to(device))
                    model.eval()
                    model(g1.to(device), g1.ndata.pop('hv').to(device),
                          g1.edata.pop('he').to(device))
                remove_file('{}_{}_{}_pre_trained.pth'.format(
                    model_type.lower(), featurizer_type, dataset))

        if dataset == 'ClinTox':
            continue

        node_featurizer = PretrainAtomFeaturizer()
        edge_featurizer = PretrainBondFeaturizer()
        for model_type in ['gin_supervised_contextpred', 'gin_supervised_infomax',
                           'gin_supervised_edgepred', 'gin_supervised_masking']:
            g1 = smiles_to_bigraph('CO', add_self_loop=True,
                                   node_featurizer=node_featurizer,
                                   edge_featurizer=edge_featurizer)
            g2 = smiles_to_bigraph('CCO', add_self_loop=True,
                                   node_featurizer=node_featurizer,
                                   edge_featurizer=edge_featurizer)
            bg = dgl.batch([g1, g2])

            model = load_pretrained('{}_{}'.format(model_type, dataset)).to(device)
            with torch.no_grad():
                node_feats = [
                    bg.ndata.pop('atomic_number').to(device),
                    bg.ndata.pop('chirality_type').to(device)
                ]
                edge_feats = [
                    bg.edata.pop('bond_type').to(device),
                    bg.edata.pop('bond_direction_type').to(device)
                ]
                model(bg.to(device), node_feats, edge_feats)
                model.eval()
                node_feats = [
                    g1.ndata.pop('atomic_number').to(device),
                    g1.ndata.pop('chirality_type').to(device)
                ]
                edge_feats = [
                    g1.edata.pop('bond_type').to(device),
                    g1.edata.pop('bond_direction_type').to(device)
                ]
                model(g1.to(device), node_feats, edge_feats)
            remove_file('{}_{}_pre_trained.pth'.format(model_type.lower(), dataset))