Example #1
def main(args, exp_config, train_set, val_set, test_set):
    if args['featurizer_type'] != 'pre_train':
        exp_config['in_node_feats'] = args['node_featurizer'].feat_size()
        if args['edge_featurizer'] is not None:
            exp_config['in_edge_feats'] = args['edge_featurizer'].feat_size()
    exp_config.update({
        'n_tasks': args['n_tasks'],
        'model': args['model']
    })

    train_loader = DataLoader(dataset=train_set, batch_size=exp_config['batch_size'], shuffle=True,
                              collate_fn=collate_molgraphs, num_workers=args['num_workers'])
    val_loader = DataLoader(dataset=val_set, batch_size=exp_config['batch_size'],
                            collate_fn=collate_molgraphs, num_workers=args['num_workers'])
    test_loader = DataLoader(dataset=test_set, batch_size=exp_config['batch_size'],
                             collate_fn=collate_molgraphs, num_workers=args['num_workers'])

    if args['pretrain']:
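        # Use the released pre-trained checkpoint as-is; setting num_epochs to 0 skips the training loop below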
        args['num_epochs'] = 0
        if args['featurizer_type'] == 'pre_train':
            model = load_pretrained('{}_{}'.format(
                args['model'], args['dataset'])).to(args['device'])
        else:
            model = load_pretrained('{}_{}_{}'.format(
                args['model'], args['featurizer_type'], args['dataset'])).to(args['device'])
    else:
        model = load_model(exp_config).to(args['device'])
        loss_criterion = nn.SmoothL1Loss(reduction='none')
        optimizer = Adam(model.parameters(), lr=exp_config['lr'],
                         weight_decay=exp_config['weight_decay'])
        stopper = EarlyStopping(patience=exp_config['patience'],
                                filename=args['result_path'] + '/model.pth',
                                metric=args['metric'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.format(
            epoch + 1, args['num_epochs'], args['metric'],
            val_score, args['metric'], stopper.best_score))

        if early_stop:
            break

    if not args['pretrain']:
        stopper.load_checkpoint(model)
    val_score = run_an_eval_epoch(args, model, val_loader)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('val {} {:.4f}'.format(args['metric'], val_score))
    print('test {} {:.4f}'.format(args['metric'], test_score))

    with open(args['result_path'] + '/eval.txt', 'w') as f:
        if not args['pretrain']:
            f.write('Best val {}: {}\n'.format(args['metric'], stopper.best_score))
        f.write('Val {}: {}\n'.format(args['metric'], val_score))
        f.write('Test {}: {}\n'.format(args['metric'], test_score))
Example #2
def test_dgmg():
    model = load_pretrained('DGMG_ZINC_canonical')
    run_dgmg_ZINC(model)
    model = load_pretrained('DGMG_ZINC_random')
    run_dgmg_ZINC(model)
    model = load_pretrained('DGMG_ChEMBL_canonical')
    run_dgmg_ChEMBL(model)
    model = load_pretrained('DGMG_ChEMBL_random')
    run_dgmg_ChEMBL(model)

    remove_file('DGMG_ChEMBL_canonical_pre_trained.pth')
    remove_file('DGMG_ChEMBL_random_pre_trained.pth')
    remove_file('DGMG_ZINC_canonical_pre_trained.pth')
    remove_file('DGMG_ZINC_random_pre_trained.pth')
Example #3
def prepare_for_evaluation(rank, args):
    worker_seed = args['seed'] + rank * 10000
    set_random_seed(worker_seed)
    torch.set_num_threads(1)

    # Setup dataset and data loader
    dataset = MoleculeDataset(args['dataset'],
                              subset_id=rank,
                              n_subsets=args['num_processes'])

    # Initialize model
    if not args['pretrained']:
        model = DGMG(atom_types=dataset.atom_types,
                     bond_types=dataset.bond_types,
                     node_hidden_size=args['node_hidden_size'],
                     num_prop_rounds=args['num_propagation_rounds'],
                     dropout=args['dropout'])
        model.load_state_dict(
            torch.load(args['model_path'])['model_state_dict'])
    else:
        model = load_pretrained('_'.join(
            ['DGMG', args['dataset'], args['order']]),
                                log=False)
    model.eval()

    worker_num_samples = args['num_samples'] // args['num_processes']
    if rank == args['num_processes'] - 1:
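        # The last worker also takes the remainder so all requested samples are generated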
        worker_num_samples += args['num_samples'] % args['num_processes']

    worker_log_dir = os.path.join(args['log_dir'], str(rank))
    mkdir_p(worker_log_dir, log=False)
    generate_and_save(worker_log_dir, worker_num_samples,
                      args['max_num_steps'], model)
Example #4
def test_attentivefp_aromaticity():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    node_featurizer = BaseAtomFeaturizer(
        featurizer_funcs={'hv': ConcatFeaturizer([
            partial(atom_type_one_hot, allowable_set=[
                'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],
                    encode_unknown=True),
            partial(atom_degree_one_hot, allowable_set=list(range(6))),
            atom_formal_charge, atom_num_radical_electrons,
            partial(atom_hybridization_one_hot, encode_unknown=True),
            lambda atom: [0],  # a placeholder for aromatic information
            atom_total_num_H_one_hot, chirality
        ],
        )}
    )
    edge_featurizer = BaseBondFeaturizer({
        'he': lambda bond: [0 for _ in range(10)]
    })
    g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer,
                           edge_featurizer=edge_featurizer)
    g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer,
                           edge_featurizer=edge_featurizer)
    bg = dgl.batch([g1, g2])

    model = load_pretrained('AttentiveFP_Aromaticity').to(device)
    model(bg.to(device), bg.ndata.pop('hv').to(device), bg.edata.pop('he').to(device))
    model.eval()
    model(g1.to(device), g1.ndata.pop('hv').to(device), g1.edata.pop('he').to(device))

    remove_file('AttentiveFP_Aromaticity_pre_trained.pth')
Example #5
def main(args, dataset):
    data_loader = DataLoader(dataset,
                             batch_size=args['batch_size'],
                             collate_fn=collate,
                             shuffle=False)
    model = load_pretrained(args['model']).to(args['device'])
    model.eval()
    readout = AvgPooling()

    mol_emb = []
    for batch_id, bg in enumerate(data_loader):
        print('Processing batch {:d}/{:d}'.format(batch_id + 1,
                                                  len(data_loader)))
        bg = bg.to(args['device'])
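        # Pre-trained GIN models expect categorical atom/bond features as lists of tensors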
        nfeats = [
            bg.ndata.pop('atomic_number').to(args['device']),
            bg.ndata.pop('chirality_type').to(args['device'])
        ]
        efeats = [
            bg.edata.pop('bond_type').to(args['device']),
            bg.edata.pop('bond_direction_type').to(args['device'])
        ]
        with torch.no_grad():
            node_repr = model(bg, nfeats, efeats)
        mol_emb.append(readout(bg, node_repr))
    mol_emb = torch.cat(mol_emb, dim=0).detach().cpu().numpy()
    np.save(args['out_dir'] + '/mol_emb.npy', mol_emb)
Example #6
def test_weave_tox21():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    node_featurizer = WeaveAtomFeaturizer()
    edge_featurizer = WeaveEdgeFeaturizer(max_distance=2)
    g1 = smiles_to_complete_graph('CO',
                                  node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer,
                                  add_self_loop=True)
    g2 = smiles_to_complete_graph('CCO',
                                  node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer,
                                  add_self_loop=True)
    bg = dgl.batch([g1, g2])

    model = load_pretrained('Weave_Tox21').to(device)
    model(bg.to(device),
          bg.ndata.pop('h').to(device),
          bg.edata.pop('e').to(device))
    model.eval()
    model(g1.to(device),
          g1.ndata.pop('h').to(device),
          g1.edata.pop('e').to(device))

    remove_file('Weave_Tox21_pre_trained.pth')
Example #7
def test_jtnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = load_pretrained('JTNN_ZINC_no_kl').to(device)
Example #8
def main(args):
    args['device'] = torch.device(
        "cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(args['random_seed'])

    dataset = PubChemBioAssayAromaticity(
        smiles_to_graph=args['smiles_to_graph'],
        node_featurizer=args.get('node_featurizer', None),
        edge_featurizer=args.get('edge_featurizer', None))
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset,
        frac_train=args['frac_train'],
        frac_val=args['frac_val'],
        frac_test=args['frac_test'],
        random_state=args['random_seed'])
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(dataset=test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
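        # Evaluate the released pre-trained model directly; no training epochs are run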
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        model = load_model(args)
        loss_fn = nn.MSELoss(reduction='none')
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args['lr'],
                                     weight_decay=args['weight_decay'])
        stopper = EarlyStopping(mode=args['mode'], patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print(
            'epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.
            format(epoch + 1, args['num_epochs'], args['metric_name'],
                   val_score, args['metric_name'], stopper.best_score))

        if early_stop:
            break

    if not args['pre_trained']:
        stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('test {} {:.4f}'.format(args['metric_name'], test_score))
Example #9
def test_jtnn():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = load_pretrained('JTNN_ZINC').to(device)

    remove_file('JTNN_ZINC_pre_trained.pth')
Example #10
	def __init__(self, predictor_dim=None):
		super(DGL_GIN_ContextPred, self).__init__()
		from dgllife.model import load_pretrained
		from dgl.nn.pytorch.glob import AvgPooling

		# These hyperparameters are fixed because the GIN encoder is pre-trained
		self.gnn = load_pretrained('gin_supervised_contextpred')

		self.readout = AvgPooling()
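		# 300 is the embedding size produced by the pre-trained GIN encoder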
		self.transform = nn.Linear(300, predictor_dim)
Example #11
def main(args):
    args['device'] = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(args['random_seed'])

    train_set, val_set, test_set = load_dataset_for_regression(args)
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args['batch_size'],
                              shuffle=True,
                              collate_fn=collate_molgraphs)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args['batch_size'],
                            shuffle=True,
                            collate_fn=collate_molgraphs)
    if test_set is not None:
        test_loader = DataLoader(dataset=test_set,
                                 batch_size=args['batch_size'],
                                 collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        model = load_model(args)
        loss_fn = nn.MSELoss(reduction='none')
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args['lr'],
                                     weight_decay=args['weight_decay'])
        stopper = EarlyStopping(mode='lower', patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print(
            'epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.
            format(epoch + 1, args['num_epochs'], args['metric_name'],
                   val_score, args['metric_name'], stopper.best_score))

        if early_stop:
            break

    if test_set is not None:
        if not args['pre_trained']:
            stopper.load_checkpoint(model)
        test_score = run_an_eval_epoch(args, model, test_loader)
        print('test {} {:.4f}'.format(args['metric_name'], test_score))
Example #12
def main(args):
    args['device'] = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    set_random_seed(args['random_seed'])

    # Interchangeable with other datasets
    dataset, train_set, val_set, test_set = load_dataset_for_classification(
        args)
    train_loader = DataLoader(train_set,
                              batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        args['n_tasks'] = dataset.n_tasks
        model = load_model(args)
        loss_criterion = BCEWithLogitsLoss(pos_weight=dataset.task_pos_weights(
            torch.tensor(train_set.indices)).to(args['device']),
                                           reduction='none')
        optimizer = Adam(model.parameters(), lr=args['lr'])
        stopper = EarlyStopping(patience=args['patience'])
    model.to(args['device'])

    for epoch in range(args['num_epochs']):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, loss_criterion,
                          optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, val_loader)
        early_stop = stopper.step(val_score, model)
        print(
            'epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}'.
            format(epoch + 1, args['num_epochs'], args['metric_name'],
                   val_score, args['metric_name'], stopper.best_score))
        if early_stop:
            break

    if not args['pre_trained']:
        stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print('test {} {:.4f}'.format(args['metric_name'], test_score))
Example #13
def test_gat_tox21():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    node_featurizer = CanonicalAtomFeaturizer()
    g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
    g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
    bg = dgl.batch([g1, g2])

    model = load_pretrained('GAT_Tox21').to(device)
    model(bg.to(device), bg.ndata.pop('h').to(device))
    model.eval()
    model(g1.to(device), g1.ndata.pop('h').to(device))

    remove_file('GAT_Tox21_pre_trained.pth')
Example #14
def main(args):
    set_seed()
    if torch.cuda.is_available():
        args['device'] = torch.device('cuda:0')
        # Set current device (only valid when CUDA is available)
        torch.cuda.set_device(args['device'])
    else:
        args['device'] = torch.device('cpu')

    if args['test_path'] is None:
        test_set = USPTOCenter('test',
                               num_processes=args['num_processes'],
                               load=args['load'])
    else:
        test_set = WLNCenterDataset(raw_file_path=args['test_path'],
                                    mol_graph_path=args['test_path'] + '.bin',
                                    num_processes=args['num_processes'],
                                    load=args['load'],
                                    reaction_validity_result_prefix='test')
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_center,
                             shuffle=False)

    if args['model_path'] is None:
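        # Fall back to the pre-trained reaction center model released with DGL-LifeSci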
        model = load_pretrained('wln_center_uspto')
    else:
        model = WLNReactionCenter(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_pair_in_feats=args['node_pair_in_feats'],
            node_out_feats=args['node_out_feats'],
            n_layers=args['n_layers'],
            n_tasks=args['n_tasks'])
        model.load_state_dict(
            torch.load(args['model_path'],
                       map_location='cpu')['model_state_dict'])
    model = model.to(args['device'])

    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(args, args['top_ks_test'], model,
                                             test_loader, args['easy'])
    print(test_result)
    with open(args['result_path'] + '/test_eval.txt', 'w') as f:
        f.write(test_result)
Example #15
def main(args, path_to_candidate_bonds):
    if args['test_path'] is None:
        test_set = USPTORank(
            subset='test',
            candidate_bond_path=path_to_candidate_bonds['test'],
            max_num_change_combos_per_reaction=args[
                'max_num_change_combos_per_reaction_eval'],
            num_processes=args['num_processes'])
    else:
        test_set = WLNRankDataset(
            path_to_reaction_file=args['test_path'],
            candidate_bond_path=path_to_candidate_bonds['test'],
            mode='test',
            max_num_change_combos_per_reaction=args[
                'max_num_change_combos_per_reaction_eval'],
            num_processes=args['num_processes'])

    test_loader = DataLoader(test_set,
                             batch_size=1,
                             collate_fn=collate_rank_eval,
                             shuffle=False,
                             num_workers=args['num_workers'])

    if args['num_workers'] > 1:
        torch.multiprocessing.set_sharing_strategy('file_system')

    if args['model_path'] is None:
        model = load_pretrained('wln_rank_uspto')
    else:
        model = WLNReactionRanking(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_hidden_feats=args['hidden_size'],
            num_encode_gnn_layers=args['num_encode_gnn_layers'])
        model.load_state_dict(
            torch.load(args['model_path'],
                       map_location='cpu')['model_state_dict'])
    model = model.to(args['device'])

    prediction_summary = candidate_ranking_eval(args, model, test_loader)
    with open(args['result_path'] + '/test_eval.txt', 'w') as f:
        f.write(prediction_summary)
Example #16
def main(args):
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    if args.use_cpu or not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:0')

    vocab = JTVAEVocab(file_path=args.train_path)
    if args.test_path is None:
        dataset = JTVAEZINC('test', vocab)
    else:
        dataset = JTVAEDataset(args.test_path, vocab, training=False)
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            collate_fn=JTVAECollator(training=False))

    if args.model_path is None:
        model = load_pretrained('JTVAE_ZINC_no_kl')
    else:
        model = JTNNVAE(vocab, args.hidden_size, args.latent_size, args.depth)
        model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
    model = model.to(device)

    acc = 0.0
    for it, (tree, tree_graph, mol_graph) in enumerate(dataloader):
        tot = it + 1
        smiles = tree.smiles
        tree_graph = tree_graph.to(device)
        mol_graph = mol_graph.to(device)
        dec_smiles = model.reconstruct(tree_graph, mol_graph)
        if dec_smiles == smiles:
            acc += 1
        if tot % args.print_iter == 0:
            print('Iter {:d}/{:d} | Acc {:.4f}'.format(
                tot // args.print_iter, len(dataloader) // args.print_iter, acc / tot))
    print('Final acc: {:.4f}'.format(acc / tot))
Example #17
dataset = JTNNDataset(data=args.train, vocab=args.vocab, training=False)
vocab_file = dataset.vocab_file

hidden_size = int(args.hidden_size)
latent_size = int(args.latent_size)
depth = int(args.depth)

model = DGLJTNNVAE(vocab_file=vocab_file,
                   hidden_size=hidden_size,
                   latent_size=latent_size,
                   depth=depth)

if args.model_path is not None:
    model.load_state_dict(torch.load(args.model_path))
else:
    model = load_pretrained("JTNN_ZINC")

model = cuda(model)
model.eval()
print("Model #Params: %dK" %
      (sum([x.nelement() for x in model.parameters()]) / 1000,))

MAX_EPOCH = 100
PRINT_ITER = 20

def reconstruct():
    dataset.training = False
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
Example #18
def prepare_reaction_center(args, reaction_center_config):
    """Use a trained model for reaction center prediction to prepare candidate bonds.

    Parameters
    ----------
    args : dict
        Configuration for the experiment.
    reaction_center_config : dict
        Configuration for the experiment on reaction center prediction.

    Returns
    -------
    path_to_candidate_bonds : dict
        Mapping 'train', 'val', 'test' to the corresponding files for candidate bonds.
    """
    if args['center_model_path'] is None:
        reaction_center_model = load_pretrained('wln_center_uspto').to(
            args['device'])
    else:
        reaction_center_model = WLNReactionCenter(
            node_in_feats=reaction_center_config['node_in_feats'],
            edge_in_feats=reaction_center_config['edge_in_feats'],
            node_pair_in_feats=reaction_center_config['node_pair_in_feats'],
            node_out_feats=reaction_center_config['node_out_feats'],
            n_layers=reaction_center_config['n_layers'],
            n_tasks=reaction_center_config['n_tasks'])
        reaction_center_model.load_state_dict(
            torch.load(args['center_model_path'])['model_state_dict'])
        reaction_center_model = reaction_center_model.to(args['device'])
    reaction_center_model.eval()

    path_to_candidate_bonds = dict()
    for subset in ['train', 'val', 'test']:
        if '{}_path'.format(subset) not in args:
            continue

        path_to_candidate_bonds[subset] = args['result_path'] + \
                                          '/{}_candidate_bonds.txt'.format(subset)
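        # Reuse candidate bonds generated in a previous run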
        if os.path.isfile(path_to_candidate_bonds[subset]):
            continue

        print('Processing subset {}...'.format(subset))
        print('Stage 1/3: Loading dataset...')
        if args['{}_path'.format(subset)] is None:
            dataset = USPTOCenter(subset, num_processes=args['num_processes'])
        else:
            dataset = WLNCenterDataset(
                raw_file_path=args['{}_path'.format(subset)],
                mol_graph_path='{}.bin'.format(subset),
                num_processes=args['num_processes'])

        dataloader = DataLoader(dataset,
                                batch_size=args['reaction_center_batch_size'],
                                collate_fn=collate_center,
                                shuffle=False)

        print('Stage 2/3: Performing model prediction...')
        output_strings = []
        for batch_id, batch_data in enumerate(dataloader):
            print('Computing candidate bonds for batch {:d}/{:d}'.format(
                batch_id + 1, len(dataloader)))
            batch_reactions, batch_graph_edits, batch_mol_graphs, \
            batch_complete_graphs, batch_atom_pair_labels = batch_data
            with torch.no_grad():
                pred, biased_pred = reaction_center_prediction(
                    args['device'], reaction_center_model, batch_mol_graphs,
                    batch_complete_graphs)
            batch_size = len(batch_reactions)
            start = 0
            for i in range(batch_size):
                end = start + batch_complete_graphs.batch_num_edges[i]
                output_strings.append(
                    output_candidate_bonds_for_a_reaction(
                        (batch_reactions[i],
                         biased_pred[start:end, :].flatten(),
                         batch_complete_graphs.batch_num_nodes[i]),
                        reaction_center_config['max_k']))
                start = end

        print('Stage 3/3: Output candidate bonds...')
        with open(path_to_candidate_bonds[subset], 'w') as f:
            for candidate_string in output_strings:
                f.write(candidate_string)

        del dataset
        del dataloader

    del reaction_center_model

    return path_to_candidate_bonds
Example #19
def main(args):
    worker_init_fn(None)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if args.model_path is not None:
        model = DGLJTNNVAE(vocab_file=get_vocab_file(args.vocab),
                           hidden_size=args.hidden_size,
                           latent_size=args.latent_size,
                           depth=args.depth)
        model.load_state_dict(torch.load(args.model_path))
    else:
        model = load_pretrained("JTNN_ZINC")
    print("# model parameters: {:d}K".format(
        sum([x.nelement() for x in model.parameters()]) // 1000))

    dataset = JTVAEDataset(data=args.data, vocab=model.vocab, training=False)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTVAECollator(False),
        worker_init_fn=worker_init_fn)
    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    acc = 0.0
    tot = 0
    model = model.to(device)
    model.eval()
    for it, batch in enumerate(tqdm(dataloader)):
        gt_smiles = batch['mol_trees'][0].smiles
        batch = dataset.move_to_device(batch, device)
        try:
            _, tree_vec, mol_vec = model.encode(batch)

            tree_mean = model.T_mean(tree_vec)
            # Following Mueller et al.
            tree_log_var = -torch.abs(model.T_var(tree_vec))
            epsilon = torch.randn(1, model.latent_size // 2).to(device)
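            # Reparameterization: z = mean + exp(log_var / 2) * epsilon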
            tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon

            mol_mean = model.G_mean(mol_vec)
            # Following Mueller et al.
            mol_log_var = -torch.abs(model.G_var(mol_vec))
            epsilon = torch.randn(1, model.latent_size // 2).to(device)
            mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

            dec_smiles = model.decode(tree_vec, mol_vec)

            if dec_smiles == gt_smiles:
                acc += 1
            tot += 1
        except Exception as e:
            print("Failed to encode: {}".format(gt_smiles))
            print(e)

        if it % 20 == 1:
            print("Progress {}/{}; Current Reconstruction Accuracy: {:.4f}".format(
                it, len(dataloader), acc / tot))

    print("Reconstruction Accuracy: {}".format(acc / tot))
Example #20
def load_model(exp_configure):
    if exp_configure['model'] == 'GCN':
        from dgllife.model import GCNPredictor
        model = GCNPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] *
            exp_configure['num_gnn_layers'],
            activation=[F.relu] * exp_configure['num_gnn_layers'],
            residual=[exp_configure['residual']] *
            exp_configure['num_gnn_layers'],
            batchnorm=[exp_configure['batchnorm']] *
            exp_configure['num_gnn_layers'],
            dropout=[exp_configure['dropout']] *
            exp_configure['num_gnn_layers'],
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'GAT':
        from dgllife.model import GATPredictor
        model = GATPredictor(
            in_feats=exp_configure['in_node_feats'],
            hidden_feats=[exp_configure['gnn_hidden_feats']] *
            exp_configure['num_gnn_layers'],
            num_heads=[exp_configure['num_heads']] *
            exp_configure['num_gnn_layers'],
            feat_drops=[exp_configure['dropout']] *
            exp_configure['num_gnn_layers'],
            attn_drops=[exp_configure['dropout']] *
            exp_configure['num_gnn_layers'],
            alphas=[exp_configure['alpha']] * exp_configure['num_gnn_layers'],
            residuals=[exp_configure['residual']] *
            exp_configure['num_gnn_layers'],
            predictor_hidden_feats=exp_configure['predictor_hidden_feats'],
            predictor_dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'Weave':
        from dgllife.model import WeavePredictor
        model = WeavePredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            num_gnn_layers=exp_configure['num_gnn_layers'],
            gnn_hidden_feats=exp_configure['gnn_hidden_feats'],
            graph_feats=exp_configure['graph_feats'],
            gaussian_expand=exp_configure['gaussian_expand'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'MPNN':
        from dgllife.model import MPNNPredictor
        model = MPNNPredictor(
            node_in_feats=exp_configure['in_node_feats'],
            edge_in_feats=exp_configure['in_edge_feats'],
            node_out_feats=exp_configure['node_out_feats'],
            edge_hidden_feats=exp_configure['edge_hidden_feats'],
            num_step_message_passing=exp_configure['num_step_message_passing'],
            num_step_set2set=exp_configure['num_step_set2set'],
            num_layer_set2set=exp_configure['num_layer_set2set'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] == 'AttentiveFP':
        from dgllife.model import AttentiveFPPredictor
        model = AttentiveFPPredictor(
            node_feat_size=exp_configure['in_node_feats'],
            edge_feat_size=exp_configure['in_edge_feats'],
            num_layers=exp_configure['num_layers'],
            num_timesteps=exp_configure['num_timesteps'],
            graph_feat_size=exp_configure['graph_feat_size'],
            dropout=exp_configure['dropout'],
            n_tasks=exp_configure['n_tasks'])
    elif exp_configure['model'] in [
            'gin_supervised_contextpred', 'gin_supervised_infomax',
            'gin_supervised_edgepred', 'gin_supervised_masking'
    ]:
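        # Wrap the pre-trained GIN encoder in a prediction head; the GIN hyperparameters below match the released checkpoints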
        from dgllife.model import GINPredictor
        from dgllife.model import load_pretrained
        model = GINPredictor(num_node_emb_list=[120, 3],
                             num_edge_emb_list=[6, 3],
                             num_layers=5,
                             emb_dim=300,
                             JK=exp_configure['jk'],
                             dropout=0.5,
                             readout=exp_configure['readout'],
                             n_tasks=exp_configure['n_tasks'])
        model.gnn = load_pretrained(exp_configure['model'])
        model.gnn.JK = exp_configure['jk']
    else:
        raise ValueError(
            "Expect model to be from ['GCN', 'GAT', 'Weave', 'MPNN', 'AttentiveFP', "
            "'gin_supervised_contextpred', 'gin_supervised_infomax', "
            "'gin_supervised_edgepred', 'gin_supervised_masking'], "
            "got {}".format(exp_configure['model']))

    return model
Example #21
def main(args):
    setup(args)
    if args['train_path'] is None:
        train_set = USPTO('train')
    else:
        train_set = WLNReactionDataset(raw_file_path=args['train_path'],
                                       mol_graph_path='train.bin')
    if args['val_path'] is None:
        val_set = USPTO('val')
    else:
        val_set = WLNReactionDataset(raw_file_path=args['val_path'],
                                     mol_graph_path='val.bin')
    if args['test_path'] is None:
        test_set = USPTO('test')
    else:
        test_set = WLNReactionDataset(raw_file_path=args['test_path'],
                                      mol_graph_path='test.bin')
    train_loader = DataLoader(train_set,
                              batch_size=args['batch_size'],
                              collate_fn=collate,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate,
                            shuffle=False)
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate,
                             shuffle=False)

    if args['pre_trained']:
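        # Use the released model and skip training (num_epochs is set to 0 below)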
        model = load_pretrained('wln_center_uspto').to(args['device'])
        args['num_epochs'] = 0
    else:
        model = WLNReactionCenter(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_pair_in_feats=args['node_pair_in_feats'],
            node_out_feats=args['node_out_feats'],
            n_layers=args['n_layers'],
            n_tasks=args['n_tasks']).to(args['device'])

        criterion = BCEWithLogitsLoss(reduction='sum')
        optimizer = Adam(model.parameters(), lr=args['lr'])
        scheduler = StepLR(optimizer,
                           step_size=args['decay_every'],
                           gamma=args['lr_decay_factor'])

        total_iter = 0
        grad_norm_sum = 0
        loss_sum = 0
        dur = []

    for epoch in range(args['num_epochs']):
        t0 = time.time()
        for batch_id, batch_data in enumerate(train_loader):
            total_iter += 1

            batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
            batch_complete_graphs, batch_atom_pair_labels = batch_data
            labels = batch_atom_pair_labels.to(args['device'])
            pred, biased_pred = reaction_center_prediction(
                args['device'], model, batch_mol_graphs, batch_complete_graphs)
            loss = criterion(pred, labels) / len(batch_reactions)
            loss_sum += loss.cpu().detach().data.item()
            optimizer.zero_grad()
            loss.backward()
            grad_norm = clip_grad_norm_(model.parameters(), args['max_norm'])
            grad_norm_sum += grad_norm
            optimizer.step()
            scheduler.step()

            if total_iter % args['print_every'] == 0:
                progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time/minibatch {:.4f} | ' \
                           'loss {:.4f} | grad norm {:.4f}'.format(
                    epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
                    (np.sum(dur) + time.time() - t0) / total_iter, loss_sum / args['print_every'],
                    grad_norm_sum / args['print_every'])
                grad_norm_sum = 0
                loss_sum = 0
                print(progress)

            if total_iter % args['decay_every'] == 0:
                torch.save(model.state_dict(),
                           args['result_path'] + '/model.pkl')

        dur.append(time.time() - t0)
        print('Epoch {:d}/{:d}, validation '.format(epoch + 1, args['num_epochs']) + \
              rough_eval_on_a_loader(args, model, val_loader))

    del train_loader
    del val_loader
    del train_set
    del val_set
    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(args, model, test_loader,
                                             args['easy'])
    print(test_result)
    with open(args['result_path'] + '/results.txt', 'w') as f:
        f.write(test_result)
Example #22
def main(args):

    torch.cuda.set_device(args['gpu'])
    set_random_seed(args['random_seed'])

    dataset, train_set, val_set, test_set = load_dataset_for_classification(
        args)  # 6264, 783, 784

    train_loader = DataLoader(train_set,
                              batch_size=args['batch_size'],
                              collate_fn=collate_molgraphs,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_molgraphs)
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_molgraphs)

    if args['pre_trained']:
        args['num_epochs'] = 0
        model = load_pretrained(args['exp'])
    else:
        args['n_tasks'] = dataset.n_tasks
        if args['method'] == 'twp':
            model = load_mymodel(args)
            print(model)
        else:
            model = load_model(args)
            for name, parameters in model.named_parameters():
                print(name, ':', parameters.size())

        method = args['method']
        life_model = importlib.import_module(f'LifeModel.{method}_model')
        life_model_ins = life_model.NET(model, args)
        data_loader = DataLoader(train_set,
                                 batch_size=len(train_set),
                                 collate_fn=collate_molgraphs,
                                 shuffle=True)
        life_model_ins.data_loader = data_loader

        loss_criterion = BCEWithLogitsLoss(
            pos_weight=dataset.task_pos_weights.cuda(), reduction='none')

    model.cuda()
    score_mean = []
    score_matrix = np.zeros([args['n_tasks'], args['n_tasks']])

    prev_model = None
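    # Continual learning: train on the 12 tasks one at a time, tracking test performance on all tasks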
    for task_i in range(12):
        print('\n********' + str(task_i))
        stopper = EarlyStopping(patience=args['patience'])
        for epoch in range(args['num_epochs']):
            # Train
            if args['method'] == 'lwf':
                life_model_ins.observe(train_loader, loss_criterion, task_i,
                                       args, prev_model)
            else:
                life_model_ins.observe(train_loader, loss_criterion, task_i,
                                       args)

            # Validation and early stop
            val_score = run_an_eval_epoch(args, model, val_loader, task_i)
            early_stop = stopper.step(val_score, model)

            if early_stop:
                print(epoch)
                break

        if not args['pre_trained']:
            stopper.load_checkpoint(model)

        score_matrix[task_i] = run_eval_epoch(args, model, test_loader)
        prev_model = copy.deepcopy(life_model_ins).cuda()

    print('AP: ', round(np.mean(score_matrix[-1, :]), 4))
    backward = []
    for t in range(args['n_tasks'] - 1):
        b = score_matrix[args['n_tasks'] - 1][t] - score_matrix[t][t]
        backward.append(round(b, 4))
    mean_backward = round(np.mean(backward), 4)
    print('AF: ', mean_backward)
Example #23
def test_moleculenet():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    for dataset in [
            'BACE', 'BBBP', 'ClinTox', 'FreeSolv', 'HIV', 'MUV', 'SIDER',
            'ToxCast', 'PCBA', 'ESOL', 'Lipophilicity', 'Tox21'
    ]:
        for featurizer_type in ['canonical', 'attentivefp']:
            if featurizer_type == 'canonical':
                node_featurizer = CanonicalAtomFeaturizer(atom_data_field='hv')
                edge_featurizer = CanonicalBondFeaturizer(bond_data_field='he',
                                                          self_loop=True)
            else:
                node_featurizer = AttentiveFPAtomFeaturizer(
                    atom_data_field='hv')
                edge_featurizer = AttentiveFPBondFeaturizer(
                    bond_data_field='he', self_loop=True)

            for model_type in ['GCN', 'GAT']:
                g1 = smiles_to_bigraph('CO', node_featurizer=node_featurizer)
                g2 = smiles_to_bigraph('CCO', node_featurizer=node_featurizer)
                bg = dgl.batch([g1, g2])

                model = load_pretrained('{}_{}_{}'.format(
                    model_type, featurizer_type, dataset)).to(device)
                with torch.no_grad():
                    model(bg.to(device), bg.ndata.pop('hv').to(device))
                    model.eval()
                    model(g1.to(device), g1.ndata.pop('hv').to(device))
                remove_file('{}_{}_{}_pre_trained.pth'.format(
                    model_type.lower(), featurizer_type, dataset))

            for model_type in ['Weave', 'MPNN', 'AttentiveFP']:
                g1 = smiles_to_bigraph('CO',
                                       add_self_loop=True,
                                       node_featurizer=node_featurizer,
                                       edge_featurizer=edge_featurizer)
                g2 = smiles_to_bigraph('CCO',
                                       add_self_loop=True,
                                       node_featurizer=node_featurizer,
                                       edge_featurizer=edge_featurizer)
                bg = dgl.batch([g1, g2])

                model = load_pretrained('{}_{}_{}'.format(
                    model_type, featurizer_type, dataset)).to(device)
                with torch.no_grad():
                    model(bg.to(device),
                          bg.ndata.pop('hv').to(device),
                          bg.edata.pop('he').to(device))
                    model.eval()
                    model(g1.to(device),
                          g1.ndata.pop('hv').to(device),
                          g1.edata.pop('he').to(device))
                remove_file('{}_{}_{}_pre_trained.pth'.format(
                    model_type.lower(), featurizer_type, dataset))

        if dataset == 'ClinTox':
            continue

        node_featurizer = PretrainAtomFeaturizer()
        edge_featurizer = PretrainBondFeaturizer()
        for model_type in [
                'gin_supervised_contextpred', 'gin_supervised_infomax',
                'gin_supervised_edgepred', 'gin_supervised_masking'
        ]:
            g1 = smiles_to_bigraph('CO',
                                   add_self_loop=True,
                                   node_featurizer=node_featurizer,
                                   edge_featurizer=edge_featurizer)
            g2 = smiles_to_bigraph('CCO',
                                   add_self_loop=True,
                                   node_featurizer=node_featurizer,
                                   edge_featurizer=edge_featurizer)
            bg = dgl.batch([g1, g2])

            model = load_pretrained('{}_{}'.format(model_type,
                                                   dataset)).to(device)
            with torch.no_grad():
                node_feats = [
                    bg.ndata.pop('atomic_number').to(device),
                    bg.ndata.pop('chirality_type').to(device)
                ]
                edge_feats = [
                    bg.edata.pop('bond_type').to(device),
                    bg.edata.pop('bond_direction_type').to(device)
                ]
                model(bg.to(device), node_feats, edge_feats)
                model.eval()
                node_feats = [
                    g1.ndata.pop('atomic_number').to(device),
                    g1.ndata.pop('chirality_type').to(device)
                ]
                edge_feats = [
                    g1.edata.pop('bond_type').to(device),
                    g1.edata.pop('bond_direction_type').to(device)
                ]
                model(g1.to(device), node_feats, edge_feats)
            remove_file('{}_{}_pre_trained.pth'.format(model_type.lower(),
                                                       dataset))