Example #1
import torch

from torch.utils.data import DataLoader

# WLNReactionCenter, load_pretrained, USPTOCenter and WLNCenterDataset ship
# with dgllife; set_seed, collate_center and reaction_center_final_eval are
# assumed to be helpers from the example's local utils module.
from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import WLNReactionCenter, load_pretrained


def main(args):
    set_seed()
    if torch.cuda.is_available():
        args['device'] = torch.device('cuda:0')
        # Set the current device; torch.cuda.set_device only accepts CUDA
        # devices, so it must stay inside this branch to avoid an error on
        # CPU-only machines.
        torch.cuda.set_device(args['device'])
    else:
        args['device'] = torch.device('cpu')

    if args['test_path'] is None:
        test_set = USPTOCenter('test',
                               num_processes=args['num_processes'],
                               load=args['load'])
    else:
        test_set = WLNCenterDataset(raw_file_path=args['test_path'],
                                    mol_graph_path=args['test_path'] + '.bin',
                                    num_processes=args['num_processes'],
                                    load=args['load'],
                                    reaction_validity_result_prefix='test')
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_center,
                             shuffle=False)

    if args['model_path'] is None:
        model = load_pretrained('wln_center_uspto')
    else:
        model = WLNReactionCenter(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_pair_in_feats=args['node_pair_in_feats'],
            node_out_feats=args['node_out_feats'],
            n_layers=args['n_layers'],
            n_tasks=args['n_tasks'])
        model.load_state_dict(
            torch.load(args['model_path'],
                       map_location='cpu')['model_state_dict'])
    model = model.to(args['device'])

    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(args, args['top_ks_test'], model,
                                             test_loader, args['easy'])
    print(test_result)
    with open(args['result_path'] + '/test_eval.txt', 'w') as f:
        f.write(test_result)
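
This script reads its entire configuration from a flat args dict. As a rough guide, a minimal invocation might look like the sketch below; every value is a hypothetical placeholder rather than the project's real defaults, and the feature dimensions must match the featurization used when the checkpoint was trained.

# Hypothetical configuration; all values below are placeholders.
args = {
    'test_path': None,           # None -> use the bundled USPTO test set
    'num_processes': 8,
    'load': True,
    'batch_size': 20,
    'model_path': None,          # None -> load the pre-trained wln_center_uspto model
    'node_in_feats': 82,
    'edge_in_feats': 6,
    'node_pair_in_feats': 10,
    'node_out_feats': 300,
    'n_layers': 3,
    'n_tasks': 5,
    'top_ks_test': [6, 8, 10],
    'easy': False,
    'result_path': 'center_results',  # assumed to be an existing directory
}
main(args)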
Example #2
import time

import numpy as np
import torch

from torch.nn import BCEWithLogitsLoss
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

# USPTO, WLNReactionDataset, WLNReactionCenter, load_pretrained, setup,
# collate, reaction_center_prediction, rough_eval_on_a_loader and
# reaction_center_final_eval are assumed to come from the surrounding project.


def main(args):
    setup(args)
    if args['train_path'] is None:
        train_set = USPTO('train')
    else:
        train_set = WLNReactionDataset(raw_file_path=args['train_path'],
                                       mol_graph_path='train.bin')
    if args['val_path'] is None:
        val_set = USPTO('val')
    else:
        val_set = WLNReactionDataset(raw_file_path=args['val_path'],
                                     mol_graph_path='val.bin')
    if args['test_path'] is None:
        test_set = USPTO('test')
    else:
        test_set = WLNReactionDataset(raw_file_path=args['test_path'],
                                      mol_graph_path='test.bin')
    train_loader = DataLoader(train_set,
                              batch_size=args['batch_size'],
                              collate_fn=collate,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate,
                            shuffle=False)
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate,
                             shuffle=False)

    if args['pre_trained']:
        model = load_pretrained('wln_center_uspto').to(args['device'])
        # With a pre-trained model there is nothing to train,
        # so the epoch loop below runs zero iterations.
        args['num_epochs'] = 0
    else:
        model = WLNReactionCenter(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_pair_in_feats=args['node_pair_in_feats'],
            node_out_feats=args['node_out_feats'],
            n_layers=args['n_layers'],
            n_tasks=args['n_tasks']).to(args['device'])

        criterion = BCEWithLogitsLoss(reduction='sum')
        optimizer = Adam(model.parameters(), lr=args['lr'])
        scheduler = StepLR(optimizer,
                           step_size=args['decay_every'],
                           gamma=args['lr_decay_factor'])

    # Logging book-keeping; kept at function level so the code below works in
    # both branches (the loop simply runs zero epochs in the pre-trained case).
    total_iter = 0
    grad_norm_sum = 0
    loss_sum = 0
    dur = []

    for epoch in range(args['num_epochs']):
        t0 = time.time()
        for batch_id, batch_data in enumerate(train_loader):
            total_iter += 1

            batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
            batch_complete_graphs, batch_atom_pair_labels = batch_data
            labels = batch_atom_pair_labels.to(args['device'])
            pred, biased_pred = reaction_center_prediction(
                args['device'], model, batch_mol_graphs, batch_complete_graphs)
            loss = criterion(pred, labels) / len(batch_reactions)
            loss_sum += loss.item()
            optimizer.zero_grad()
            loss.backward()
            grad_norm = clip_grad_norm_(model.parameters(), args['max_norm'])
            grad_norm_sum += grad_norm
            optimizer.step()
            scheduler.step()

            if total_iter % args['print_every'] == 0:
                progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time/minibatch {:.4f} | ' \
                           'loss {:.4f} | grad norm {:.4f}'.format(
                    epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
                    (np.sum(dur) + time.time() - t0) / total_iter, loss_sum / args['print_every'],
                    grad_norm_sum / args['print_every'])
                grad_norm_sum = 0
                loss_sum = 0
                print(progress)

            if total_iter % args['decay_every'] == 0:
                torch.save(model.state_dict(),
                           args['result_path'] + '/model.pkl')

        dur.append(time.time() - t0)
        print('Epoch {:d}/{:d}, validation '.format(epoch + 1, args['num_epochs']) + \
              rough_eval_on_a_loader(args, model, val_loader))

    del train_loader
    del val_loader
    del train_set
    del val_set
    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(args, model, test_loader,
                                             args['easy'])
    print(test_result)
    with open(args['result_path'] + '/results.txt', 'w') as f:
        f.write(test_result)
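
One detail worth noting in the loss computation above: BCEWithLogitsLoss is built with reduction='sum' and the summed loss is then divided by the number of reactions in the batch, so every reaction contributes equally regardless of how many atom pairs its molecules produce. The toy snippet below (all shapes and values are made up) contrasts this with reduction='mean', which would instead average over atom pairs and down-weight large molecules.

import torch
from torch.nn import BCEWithLogitsLoss

pred = torch.randn(12, 5)                     # 12 atom pairs, 5 bond-change tasks
labels = torch.randint(0, 2, (12, 5)).float()
num_reactions = 3                             # pretend the batch holds 3 reactions

per_reaction = BCEWithLogitsLoss(reduction='sum')(pred, labels) / num_reactions
per_pair = BCEWithLogitsLoss(reduction='mean')(pred, labels)
print(per_reaction.item(), per_pair.item())   # the two normalizations differ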
Example #3
import time

import numpy as np
import torch

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

# WLNReactionCenter ships with dgllife; set_seed, load_dataset,
# get_center_subset, collate_center, count_parameters,
# reaction_center_prediction, reaction_center_final_eval and synchronize are
# assumed to be helpers from the example's local modules.
from dgllife.model import WLNReactionCenter


def main(rank, dev_id, args):
    set_seed()
    # Removing the line below will cause problems for multiprocessing
    if args['num_devices'] > 1:
        torch.set_num_threads(1)
    if dev_id == -1:
        args['device'] = torch.device('cpu')
    else:
        args['device'] = torch.device('cuda:{}'.format(dev_id))
        # Set current device
        torch.cuda.set_device(args['device'])

    train_set, val_set = load_dataset(args)
    get_center_subset(train_set, rank, args['num_devices'])
    train_loader = DataLoader(train_set, batch_size=args['batch_size'],
                              collate_fn=collate_center, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args['batch_size'],
                            collate_fn=collate_center, shuffle=False)

    model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
                              edge_in_feats=args['edge_in_feats'],
                              node_pair_in_feats=args['node_pair_in_feats'],
                              node_out_feats=args['node_out_feats'],
                              n_layers=args['n_layers'],
                              n_tasks=args['n_tasks']).to(args['device'])
    model.train()
    if rank == 0:
        print('# trainable parameters in the model: ', count_parameters(model))

    criterion = BCEWithLogitsLoss(reduction='sum')
    optimizer = Adam(model.parameters(), lr=args['lr'])
    if args['num_devices'] <= 1:
        from utils import Optimizer
        optimizer = Optimizer(model, args['lr'], optimizer, max_grad_norm=args['max_norm'])
    else:
        from utils import MultiProcessOptimizer
        optimizer = MultiProcessOptimizer(args['num_devices'], model, args['lr'],
                                          optimizer, max_grad_norm=args['max_norm'])

    total_iter = 0
    rank_iter = 0
    grad_norm_sum = 0
    loss_sum = 0
    dur = []

    for epoch in range(args['num_epochs']):
        t0 = time.time()
        for batch_id, batch_data in enumerate(train_loader):
            total_iter += args['num_devices']
            rank_iter += 1

            batch_reactions, batch_graph_edits, batch_mol_graphs, \
            batch_complete_graphs, batch_atom_pair_labels = batch_data
            labels = batch_atom_pair_labels.to(args['device'])
            pred, biased_pred = reaction_center_prediction(
                args['device'], model, batch_mol_graphs, batch_complete_graphs)
            loss = criterion(pred, labels) / len(batch_reactions)
            loss_sum += loss.item()
            grad_norm_sum += optimizer.backward_and_step(loss)

            if rank_iter % args['print_every'] == 0 and rank == 0:
                progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | ' \
                           'loss {:.4f} | grad norm {:.4f}'.format(
                    epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
                    loss_sum / args['print_every'], grad_norm_sum / args['print_every'])
                print(progress)
                grad_norm_sum = 0
                loss_sum = 0

            if total_iter % args['decay_every'] == 0:
                optimizer.decay_lr(args['lr_decay_factor'])
                # Periodic timing, validation and checkpointing, master process only
                if rank == 0:
                    if epoch >= 1:
                        dur.append(time.time() - t0)
                        print('Training time per {:d} iterations: {:.4f}'.format(
                            rank_iter, np.mean(dur)))
                    total_samples = total_iter * args['batch_size']
                    prediction_summary = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d}) '.format(
                        total_samples, epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader)) + \
                        reaction_center_final_eval(args, args['top_ks_val'], model, val_loader, easy=True)
                    print(prediction_summary)
                    with open(args['result_path'] + '/val_eval.txt', 'a') as f:
                        f.write(prediction_summary)
                    torch.save({'model_state_dict': model.state_dict()},
                               args['result_path'] + '/model_{:d}.pkl'.format(total_samples))
                    t0 = time.time()
                    model.train()
        synchronize(args['num_devices'])
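
Since main takes a process rank and a device id, it is presumably spawned once per GPU by a separate launcher. The sketch below shows one way this could be wired up with torch.multiprocessing; it assumes one process per device with dev_id equal to the rank, and it is an illustration rather than the project's actual entry point.

import copy

import torch
import torch.multiprocessing as mp

def launch(args):
    if args['num_devices'] <= 1:
        main(rank=0, dev_id=0 if torch.cuda.is_available() else -1, args=args)
        return
    # 'spawn' is required when child processes use CUDA
    ctx = mp.get_context('spawn')
    procs = [ctx.Process(target=main, args=(rank, rank, copy.deepcopy(args)))
             for rank in range(args['num_devices'])]
    for p in procs:
        p.start()
    for p in procs:
        p.join()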
Example #4
import os

import torch

from torch.utils.data import DataLoader

# USPTOCenter, WLNCenterDataset, WLNReactionCenter and load_pretrained ship
# with dgllife; collate_center, reaction_center_prediction and
# output_candidate_bonds_for_a_reaction are assumed to be helpers from the
# example's local modules.
from dgllife.data import USPTOCenter, WLNCenterDataset
from dgllife.model import WLNReactionCenter, load_pretrained


def prepare_reaction_center(args, reaction_center_config):
    """Use a trained model for reaction center prediction to prepare candidate bonds.

    Parameters
    ----------
    args : dict
        Configuration for the experiment.
    reaction_center_config : dict
        Configuration for the experiment on reaction center prediction.

    Returns
    -------
    path_to_candidate_bonds : dict
        Mapping 'train', 'val', 'test' to the corresponding files for candidate bonds.
    """
    if args['center_model_path'] is None:
        reaction_center_model = load_pretrained('wln_center_uspto').to(
            args['device'])
    else:
        reaction_center_model = WLNReactionCenter(
            node_in_feats=reaction_center_config['node_in_feats'],
            edge_in_feats=reaction_center_config['edge_in_feats'],
            node_pair_in_feats=reaction_center_config['node_pair_in_feats'],
            node_out_feats=reaction_center_config['node_out_feats'],
            n_layers=reaction_center_config['n_layers'],
            n_tasks=reaction_center_config['n_tasks'])
        reaction_center_model.load_state_dict(
            torch.load(args['center_model_path'])['model_state_dict'])
        reaction_center_model = reaction_center_model.to(args['device'])
    reaction_center_model.eval()

    path_to_candidate_bonds = dict()
    for subset in ['train', 'val', 'test']:
        if '{}_path'.format(subset) not in args:
            continue

        path_to_candidate_bonds[subset] = args['result_path'] + \
                                          '/{}_candidate_bonds.txt'.format(subset)
        if os.path.isfile(path_to_candidate_bonds[subset]):
            continue

        print('Processing subset {}...'.format(subset))
        print('Stage 1/3: Loading dataset...')
        if args['{}_path'.format(subset)] is None:
            dataset = USPTOCenter(subset, num_processes=args['num_processes'])
        else:
            dataset = WLNCenterDataset(
                raw_file_path=args['{}_path'.format(subset)],
                mol_graph_path='{}.bin'.format(subset),
                num_processes=args['num_processes'])

        dataloader = DataLoader(dataset,
                                batch_size=args['reaction_center_batch_size'],
                                collate_fn=collate_center,
                                shuffle=False)

        print('Stage 2/3: Performing model prediction...')
        output_strings = []
        for batch_id, batch_data in enumerate(dataloader):
            print('Computing candidate bonds for batch {:d}/{:d}'.format(
                batch_id + 1, len(dataloader)))
            batch_reactions, batch_graph_edits, batch_mol_graphs, \
            batch_complete_graphs, batch_atom_pair_labels = batch_data
            with torch.no_grad():
                pred, biased_pred = reaction_center_prediction(
                    args['device'], reaction_center_model, batch_mol_graphs,
                    batch_complete_graphs)
            batch_size = len(batch_reactions)
            start = 0
            for i in range(batch_size):
                end = start + batch_complete_graphs.batch_num_edges[i]
                output_strings.append(
                    output_candidate_bonds_for_a_reaction(
                        (batch_reactions[i],
                         biased_pred[start:end, :].flatten(),
                         batch_complete_graphs.batch_num_nodes[i]),
                        reaction_center_config['max_k']))
                start = end

        print('Stage 3/3: Output candidate bonds...')
        with open(path_to_candidate_bonds[subset], 'w') as f:
            for candidate_string in output_strings:
                f.write(candidate_string)

        del dataset
        del dataloader

    del reaction_center_model

    return path_to_candidate_bonds
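
Downstream code would call this once before training the candidate-ranking stage and then read the returned file paths. A hedged usage sketch, with every value a placeholder:

# Hypothetical usage; all values below are placeholders.
reaction_center_config = {
    'node_in_feats': 82, 'edge_in_feats': 6, 'node_pair_in_feats': 10,
    'node_out_feats': 300, 'n_layers': 3, 'n_tasks': 5, 'max_k': 80,
}
args = {
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'center_model_path': None,          # None -> use the pre-trained model
    'result_path': 'rank_results',      # assumed to be an existing directory
    'num_processes': 8,
    'reaction_center_batch_size': 200,
    'train_path': None, 'val_path': None, 'test_path': None,
}
path_to_candidate_bonds = prepare_reaction_center(args, reaction_center_config)
print(path_to_candidate_bonds['train'])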