Example #1
0
def main(args, path_to_candidate_bonds):
    """Evaluate a WLN candidate-ranking model on the test split.

    The test data comes from the built-in ``USPTORank`` test subset when
    ``args['test_path']`` is None, and otherwise from ``WLNRankDataset`` over
    that reaction file. The model is the pretrained ``'wln_rank_uspto'`` when
    ``args['model_path']`` is None, and otherwise a freshly built
    ``WLNReactionRanking`` restored from the checkpoint's
    ``'model_state_dict'`` entry. The summary string returned by
    ``candidate_ranking_eval`` is written to
    ``args['result_path'] + '/test_eval.txt'``.

    Args:
        args (dict): run configuration (paths, model sizes, device, workers).
        path_to_candidate_bonds (dict): maps split name to the candidate-bond
            file; only the ``'test'`` entry is read here.
    """
    reaction_path = args['test_path']
    bond_path = path_to_candidate_bonds['test']
    max_combos = args['max_num_change_combos_per_reaction_eval']
    if reaction_path is not None:
        test_set = WLNRankDataset(path_to_reaction_file=reaction_path,
                                  candidate_bond_path=bond_path,
                                  mode='test',
                                  max_num_change_combos_per_reaction=max_combos,
                                  num_processes=args['num_processes'])
    else:
        test_set = USPTORank(subset='test',
                             candidate_bond_path=bond_path,
                             max_num_change_combos_per_reaction=max_combos,
                             num_processes=args['num_processes'])

    # One reaction per batch during evaluation, in dataset order.
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             collate_fn=collate_rank_eval,
                             shuffle=False,
                             num_workers=args['num_workers'])

    # Share worker tensors through the file system instead of file
    # descriptors when several loader workers are used.
    if args['num_workers'] > 1:
        torch.multiprocessing.set_sharing_strategy('file_system')

    checkpoint_path = args['model_path']
    if checkpoint_path is not None:
        model = WLNReactionRanking(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_hidden_feats=args['hidden_size'],
            num_encode_gnn_layers=args['num_encode_gnn_layers'])
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model = load_pretrained('wln_rank_uspto')
    model = model.to(args['device'])

    summary = candidate_ranking_eval(args, model, test_loader)
    with open(args['result_path'] + '/test_eval.txt', 'w') as out_file:
        out_file.write(summary)
def _build_rank_dataset(args, path_to_candidate_bonds, subset, combos_key):
    """Build the candidate-ranking dataset for one data split.

    Uses the built-in ``USPTORank`` split when ``args[subset + '_path']`` is
    None, and otherwise ``WLNRankDataset`` over that reaction file.

    Args:
        args (dict): run configuration; reads ``subset + '_path'``,
            ``combos_key`` and ``'num_processes'``.
        path_to_candidate_bonds (dict): maps split name to its candidate-bond
            file; the ``subset`` entry is read.
        subset (str): split name, e.g. ``'train'`` or ``'val'``; also used as
            the ``mode`` of ``WLNRankDataset``.
        combos_key (str): key in ``args`` holding the value passed as
            ``max_num_change_combos_per_reaction``.

    Returns:
        The constructed ``USPTORank`` or ``WLNRankDataset`` instance.
    """
    reaction_path = args[subset + '_path']
    if reaction_path is None:
        return USPTORank(
            subset=subset,
            candidate_bond_path=path_to_candidate_bonds[subset],
            max_num_change_combos_per_reaction=args[combos_key],
            num_processes=args['num_processes'])
    return WLNRankDataset(
        path_to_reaction_file=reaction_path,
        candidate_bond_path=path_to_candidate_bonds[subset],
        mode=subset,
        max_num_change_combos_per_reaction=args[combos_key],
        num_processes=args['num_processes'])


def _rank_loss_and_num_correct(pred, batch_labels, batch_num_candidate_products,
                               criterion):
    """Sum the ranking loss over the reactions of a batch and count top-1 hits.

    ``pred`` stacks the scores of every candidate product in the batch;
    reaction ``i`` owns the next ``batch_num_candidate_products[i]`` rows.
    A reaction counts as correct when its row 0 (the ground-truth candidate)
    has the highest score.

    Returns:
        (batch_loss, num_correct): the summed criterion value and the number
        of correctly ranked reactions (as a float).
    """
    batch_loss = 0
    num_correct = 0
    start = 0
    for i, num_candidates in enumerate(batch_num_candidate_products):
        end = start + num_candidates
        reaction_pred = pred[start:end, :]
        num_correct += float(reaction_pred.detach().max(dim=0)[1].item() == 0)
        batch_loss += criterion(reaction_pred.reshape(1, -1), batch_labels[i, :])
        start = end
    return batch_loss, num_correct


def main(args, path_to_candidate_bonds):
    """Train a WLN candidate-ranking model.

    Every ``args['print_every']`` samples, prints the running accuracy, the
    mean per-batch gradient norm and the time per ``print_every`` samples.
    Every ``args['decay_every']`` samples, it does four things:
    decays the learning rate by ``args['lr_decay_factor']``, saves
    ``model_<total_samples>.pkl`` under ``args['result_path']``, evaluates on
    the validation set, and appends the summary to ``val_eval.txt``.

    NOTE(review): ``total_samples`` grows by ``args['batch_size']`` per batch,
    even for a short final batch. The modulo checks below therefore assume
    ``print_every`` and ``decay_every`` are multiples of ``batch_size``.

    Args:
        args (dict): run configuration (data paths, model sizes, optimizer
            and schedule settings, device, workers, result path).
        path_to_candidate_bonds (dict): maps ``'train'`` / ``'val'`` to the
            candidate-bond file of that split.
    """
    train_set = _build_rank_dataset(args, path_to_candidate_bonds, 'train',
                                    'max_num_change_combos_per_reaction_train')
    train_set.ignore_large()
    val_set = _build_rank_dataset(args, path_to_candidate_bonds, 'val',
                                  'max_num_change_combos_per_reaction_eval')

    if args['num_workers'] > 1:
        torch.multiprocessing.set_sharing_strategy('file_system')

    train_loader = DataLoader(train_set,
                              batch_size=args['batch_size'],
                              collate_fn=collate_rank_train,
                              shuffle=True,
                              num_workers=args['num_workers'])
    val_loader = DataLoader(val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_rank_eval,
                            shuffle=False,
                            num_workers=args['num_workers'])

    model = WLNReactionRanking(
        node_in_feats=args['node_in_feats'],
        edge_in_feats=args['edge_in_feats'],
        node_hidden_feats=args['hidden_size'],
        num_encode_gnn_layers=args['num_encode_gnn_layers']).to(args['device'])
    criterion = CrossEntropyLoss(reduction='sum')
    optimizer = Adam(model.parameters(), lr=args['lr'])
    from utils import Optimizer
    # Wraps Adam with gradient-norm clipping and LR decay helpers.
    optimizer = Optimizer(model,
                          args['lr'],
                          optimizer,
                          max_grad_norm=args['max_norm'])

    device = args['device']
    acc_sum = 0
    grad_norm_sum = 0
    # grad_norm_sum gets one entry per batch, so it is averaged over batches.
    num_batches_since_print = 0
    dur = []  # training time between checkpoints, excluding validation
    total_samples = 0
    for epoch in range(args['num_epochs']):
        t0 = time.time()
        model.train()
        for batch_id, batch_data in enumerate(train_loader):
            batch_reactant_graphs, batch_product_graphs, \
            batch_combo_scores, batch_labels, batch_num_candidate_products = batch_data

            batch_reactant_graphs = batch_reactant_graphs.to(device)
            batch_product_graphs = batch_product_graphs.to(device)
            batch_combo_scores = batch_combo_scores.to(device)
            batch_labels = batch_labels.to(device)
            reactant_node_feats = batch_reactant_graphs.ndata.pop('hv').to(device)
            reactant_edge_feats = batch_reactant_graphs.edata.pop('he').to(device)
            product_node_feats = batch_product_graphs.ndata.pop('hv').to(device)
            product_edge_feats = batch_product_graphs.edata.pop('he').to(device)

            pred = model(
                reactant_graph=batch_reactant_graphs,
                reactant_node_feats=reactant_node_feats,
                reactant_edge_feats=reactant_edge_feats,
                product_graphs=batch_product_graphs,
                product_node_feats=product_node_feats,
                product_edge_feats=product_edge_feats,
                candidate_scores=batch_combo_scores,
                batch_num_candidate_products=batch_num_candidate_products)

            batch_loss, num_correct = _rank_loss_and_num_correct(
                pred, batch_labels, batch_num_candidate_products, criterion)
            acc_sum += num_correct

            grad_norm_sum += optimizer.backward_and_step(batch_loss)
            num_batches_since_print += 1
            total_samples += args['batch_size']
            if total_samples % args['print_every'] == 0:
                progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time {:.4f} | ' \
                           'accuracy {:.4f} | grad norm {:.4f}'.format(
                    epoch + 1, args['num_epochs'],
                    (batch_id + 1) * args['batch_size'] // args['print_every'],
                    len(train_set) // args['print_every'],
                    (sum(dur) + time.time() - t0) / total_samples * args['print_every'],
                    acc_sum / args['print_every'],
                    grad_norm_sum / num_batches_since_print)
                print(progress)
                acc_sum = 0
                grad_norm_sum = 0
                num_batches_since_print = 0

            if total_samples % args['decay_every'] == 0:
                dur.append(time.time() - t0)
                old_lr = optimizer.lr
                optimizer.decay_lr(args['lr_decay_factor'])
                new_lr = optimizer.lr
                print('Learning rate decayed from {:.4f} to {:.4f}'.format(
                    old_lr, new_lr))
                torch.save({'model_state_dict': model.state_dict()},
                           args['result_path'] +
                           '/model_{:d}.pkl'.format(total_samples))
                prediction_summary = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d})\n'.format(
                    total_samples, epoch + 1, args['num_epochs'],
                    (batch_id + 1) * args['batch_size'] // args['print_every'],
                    len(train_set) //
                    args['print_every']) + candidate_ranking_eval(
                        args, model, val_loader)
                print(prediction_summary)
                with open(args['result_path'] + '/val_eval.txt', 'a') as f:
                    f.write(prediction_summary)
                # Restart the clock after validation so eval time is excluded.
                t0 = time.time()
                model.train()

        # Record the rest of this epoch's training time. Without this, the
        # t0 reset at the start of the next epoch would drop it from sum(dur).
        dur.append(time.time() - t0)