Example #1
def load_dataset(args):
    # Fall back to the built-in USPTO subsets when no custom file is given;
    # otherwise build the dataset from the user-provided raw reaction file.
    if args['train_path'] is None:
        train_set = USPTOCenter('train', num_processes=args['num_processes'])
    else:
        train_set = WLNCenterDataset(raw_file_path=args['train_path'],
                                     mol_graph_path='train.bin',
                                     num_processes=args['num_processes'])
    if args['val_path'] is None:
        val_set = USPTOCenter('val', num_processes=args['num_processes'])
    else:
        val_set = WLNCenterDataset(raw_file_path=args['val_path'],
                                   mol_graph_path='val.bin',
                                   num_processes=args['num_processes'])

    return train_set, val_set
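For context, a minimal usage sketch of this helper follows. The args dict is a hypothetical configuration (the function only reads 'train_path', 'val_path' and 'num_processes'), and USPTOCenter / WLNCenterDataset are assumed to come from dgllife.data, as in DGL-LifeSci's reaction-prediction examples.

# Hypothetical configuration; passing None for a path falls back to the
# built-in USPTO subsets, a file path triggers WLNCenterDataset instead.
args = {
    'train_path': None,   # or 'data/my_train.txt' for a custom raw reaction file
    'val_path': None,     # or 'data/my_val.txt'
    'num_processes': 4,   # worker processes for molecular graph construction
}
train_set, val_set = load_dataset(args)
print(len(train_set), len(val_set))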
Example #2
def load_dataset(args):
    # Variant of Example #1 that also tags the reaction validity check
    # results with a per-subset prefix ('train' / 'val').
    if args['train_path'] is None:
        train_set = USPTOCenter('train', num_processes=args['num_processes'])
    else:
        train_set = WLNCenterDataset(raw_file_path=args['train_path'],
                                     mol_graph_path='./train.bin',
                                     num_processes=args['num_processes'],
                                     reaction_validity_result_prefix='train')
    if args['val_path'] is None:
        val_set = USPTOCenter('val', num_processes=args['num_processes'])
    else:
        val_set = WLNCenterDataset(raw_file_path=args['val_path'],
                                   mol_graph_path='./val.bin',
                                   num_processes=args['num_processes'],
                                   reaction_validity_result_prefix='val')

    return train_set, val_set
Example #3
def main(args):
    # Evaluate a reaction center prediction model (pretrained or user-supplied)
    # on the test set and write the results to disk.
    set_seed()
    if torch.cuda.is_available():
        args['device'] = torch.device('cuda:0')
        # Only set the current CUDA device when a GPU is present;
        # torch.cuda.set_device rejects CPU devices.
        torch.cuda.set_device(args['device'])
    else:
        args['device'] = torch.device('cpu')

    if args['test_path'] is None:
        test_set = USPTOCenter('test',
                               num_processes=args['num_processes'],
                               load=args['load'])
    else:
        test_set = WLNCenterDataset(raw_file_path=args['test_path'],
                                    mol_graph_path=args['test_path'] + '.bin',
                                    num_processes=args['num_processes'],
                                    load=args['load'],
                                    reaction_validity_result_prefix='test')
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_center,
                             shuffle=False)

    if args['model_path'] is None:
        model = load_pretrained('wln_center_uspto')
    else:
        model = WLNReactionCenter(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_pair_in_feats=args['node_pair_in_feats'],
            node_out_feats=args['node_out_feats'],
            n_layers=args['n_layers'],
            n_tasks=args['n_tasks'])
        model.load_state_dict(
            torch.load(args['model_path'],
                       map_location='cpu')['model_state_dict'])
    model = model.to(args['device'])

    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(args, args['top_ks_test'], model,
                                             test_loader, args['easy'])
    print(test_result)
    with open(args['result_path'] + '/test_eval.txt', 'w') as f:
        f.write(test_result)
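These snippets reference names that are not defined in the excerpts themselves. Assuming they come from DGL-LifeSci's reaction-prediction example scripts, the imports would look roughly like the sketch below; the utils module holding collate_center, reaction_center_prediction, reaction_center_final_eval, output_candidate_bonds_for_a_reaction and set_seed is an assumption about the surrounding example code, not part of the published dgllife API.

# Likely imports for the snippets in these examples (a sketch, not verified
# against a specific dgllife release).
import os
import torch
from torch.utils.data import DataLoader

from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import WLNReactionCenter, load_pretrained

# Helper functions assumed to live in the example's own utils module.
from utils import (collate_center, reaction_center_prediction,
                   reaction_center_final_eval,
                   output_candidate_bonds_for_a_reaction, set_seed)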
Example #4
def prepare_reaction_center(args, reaction_center_config):
    """Use a trained model for reaction center prediction to prepare candidate bonds.

    Parameters
    ----------
    args : dict
        Configuration for the experiment.
    reaction_center_config : dict
        Configuration for the experiment on reaction center prediction.

    Returns
    -------
    path_to_candidate_bonds : dict
        Mapping 'train', 'val', 'test' to the corresponding files for candidate bonds.
    """
    if args['center_model_path'] is None:
        reaction_center_model = load_pretrained('wln_center_uspto').to(
            args['device'])
    else:
        reaction_center_model = WLNReactionCenter(
            node_in_feats=reaction_center_config['node_in_feats'],
            edge_in_feats=reaction_center_config['edge_in_feats'],
            node_pair_in_feats=reaction_center_config['node_pair_in_feats'],
            node_out_feats=reaction_center_config['node_out_feats'],
            n_layers=reaction_center_config['n_layers'],
            n_tasks=reaction_center_config['n_tasks'])
        reaction_center_model.load_state_dict(
            torch.load(args['center_model_path'])['model_state_dict'])
        reaction_center_model = reaction_center_model.to(args['device'])
    reaction_center_model.eval()

    path_to_candidate_bonds = dict()
    for subset in ['train', 'val', 'test']:
        if '{}_path'.format(subset) not in args:
            continue

        path_to_candidate_bonds[subset] = args['result_path'] + \
                                          '/{}_candidate_bonds.txt'.format(subset)
        if os.path.isfile(path_to_candidate_bonds[subset]):
            continue

        print('Processing subset {}...'.format(subset))
        print('Stage 1/3: Loading dataset...')
        if args['{}_path'.format(subset)] is None:
            dataset = USPTOCenter(subset, num_processes=args['num_processes'])
        else:
            dataset = WLNCenterDataset(
                raw_file_path=args['{}_path'.format(subset)],
                mol_graph_path='{}.bin'.format(subset),
                num_processes=args['num_processes'])

        dataloader = DataLoader(dataset,
                                batch_size=args['reaction_center_batch_size'],
                                collate_fn=collate_center,
                                shuffle=False)

        print('Stage 2/3: Performing model prediction...')
        output_strings = []
        for batch_id, batch_data in enumerate(dataloader):
            print('Computing candidate bonds for batch {:d}/{:d}'.format(
                batch_id + 1, len(dataloader)))
            (batch_reactions, batch_graph_edits, batch_mol_graphs,
             batch_complete_graphs, batch_atom_pair_labels) = batch_data
            with torch.no_grad():
                pred, biased_pred = reaction_center_prediction(
                    args['device'], reaction_center_model, batch_mol_graphs,
                    batch_complete_graphs)
            batch_size = len(batch_reactions)
            start = 0
            for i in range(batch_size):
                end = start + batch_complete_graphs.batch_num_edges[i]
                output_strings.append(
                    output_candidate_bonds_for_a_reaction(
                        (batch_reactions[i],
                         biased_pred[start:end, :].flatten(),
                         batch_complete_graphs.batch_num_nodes[i]),
                        reaction_center_config['max_k']))
                start = end

        print('Stage 3/3: Output candidate bonds...')
        with open(path_to_candidate_bonds[subset], 'w') as f:
            for candidate_string in output_strings:
                f.write(candidate_string)

        del dataset
        del dataloader

    del reaction_center_model

    return path_to_candidate_bonds
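A hedged driver sketch for this helper follows. Every value is a placeholder chosen for illustration; with 'center_model_path' set to None only the 'max_k' entry of reaction_center_config is actually read, while the model hyperparameter fields would only matter when loading a custom checkpoint.

# Hypothetical driver; all values below are placeholders for illustration.
args = {
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'center_model_path': None,         # None -> use the pretrained 'wln_center_uspto'
    'result_path': 'results',          # candidate bond files are written here
    'train_path': None,                # None -> built-in USPTO subsets
    'val_path': None,
    'test_path': None,
    'num_processes': 4,
    'reaction_center_batch_size': 200,
}
reaction_center_config = {
    'max_k': 80,                       # number of candidate bonds kept per reaction
    # 'node_in_feats', 'edge_in_feats', etc. would be needed for a custom checkpoint.
}
os.makedirs(args['result_path'], exist_ok=True)
path_to_candidate_bonds = prepare_reaction_center(args, reaction_center_config)
print(path_to_candidate_bonds)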