def main(args, path_to_candidate_bonds):
    """Train a WLNReactionRanking model and periodically evaluate it.

    The training set is built from ``USPTORank`` when ``args['train_path']``
    is None and from ``WLNRankDataset`` otherwise; the validation set is
    chosen the same way from ``args['val_path']``. Every time the running
    sample counter hits a multiple of ``args['decay_every']`` the learning
    rate is decayed, a checkpoint is saved under ``args['result_path']`` and
    the output of ``candidate_ranking_eval`` on the validation loader is
    printed and appended to ``val_eval.txt``.

    Parameters
    ----------
    args : dict
        Configuration. Keys read here: 'train_path', 'val_path',
        'max_num_change_combos_per_reaction_train',
        'max_num_change_combos_per_reaction_eval', 'num_processes',
        'num_workers', 'batch_size', 'node_in_feats', 'edge_in_feats',
        'hidden_size', 'num_encode_gnn_layers', 'device', 'lr', 'max_norm',
        'num_epochs', 'print_every', 'decay_every', 'lr_decay_factor'
        and 'result_path'.
    path_to_candidate_bonds : dict
        Its 'train' and 'val' entries are forwarded as ``candidate_bond_path``
        to the corresponding dataset constructors.
    """
    # ------------------------------------------------------------------
    # Datasets
    # ------------------------------------------------------------------
    train_path = args['train_path']
    if train_path is not None:
        train_set = WLNRankDataset(
            path_to_reaction_file=train_path,
            candidate_bond_path=path_to_candidate_bonds['train'],
            mode='train',
            max_num_change_combos_per_reaction=args[
                'max_num_change_combos_per_reaction_train'],
            num_processes=args['num_processes'])
    else:
        train_set = USPTORank(
            subset='train',
            candidate_bond_path=path_to_candidate_bonds['train'],
            max_num_change_combos_per_reaction=args[
                'max_num_change_combos_per_reaction_train'],
            num_processes=args['num_processes'])
    train_set.ignore_large()

    val_path = args['val_path']
    if val_path is not None:
        val_set = WLNRankDataset(
            path_to_reaction_file=val_path,
            candidate_bond_path=path_to_candidate_bonds['val'],
            mode='val',
            max_num_change_combos_per_reaction=args[
                'max_num_change_combos_per_reaction_eval'],
            num_processes=args['num_processes'])
    else:
        val_set = USPTORank(
            subset='val',
            candidate_bond_path=path_to_candidate_bonds['val'],
            max_num_change_combos_per_reaction=args[
                'max_num_change_combos_per_reaction_eval'],
            num_processes=args['num_processes'])

    # ------------------------------------------------------------------
    # Data loaders
    # ------------------------------------------------------------------
    if args['num_workers'] > 1:
        torch.multiprocessing.set_sharing_strategy('file_system')

    train_loader = DataLoader(
        train_set,
        batch_size=args['batch_size'],
        collate_fn=collate_rank_train,
        shuffle=True,
        num_workers=args['num_workers'])
    val_loader = DataLoader(
        val_set,
        batch_size=args['batch_size'],
        collate_fn=collate_rank_eval,
        shuffle=False,
        num_workers=args['num_workers'])

    # ------------------------------------------------------------------
    # Model, loss and optimizer
    # ------------------------------------------------------------------
    model = WLNReactionRanking(
        node_in_feats=args['node_in_feats'],
        edge_in_feats=args['edge_in_feats'],
        node_hidden_feats=args['hidden_size'],
        num_encode_gnn_layers=args['num_encode_gnn_layers'])
    model = model.to(args['device'])
    criterion = CrossEntropyLoss(reduction='sum')
    base_optimizer = Adam(model.parameters(), lr=args['lr'])
    from utils import Optimizer
    # Project wrapper; used below for backward_and_step, decay_lr and lr.
    optimizer = Optimizer(model,
                          args['lr'],
                          base_optimizer,
                          max_grad_norm=args['max_norm'])

    # Running statistics; acc_sum / grad_norm_sum are reset after each print.
    acc_sum = 0
    grad_norm_sum = 0
    dur = []
    total_samples = 0

    for epoch in range(args['num_epochs']):
        t0 = time.time()
        model.train()
        for batch_id, batch_data in enumerate(train_loader):
            (reactant_graphs, product_graphs, combo_scores,
             labels, num_candidates_per_rxn) = batch_data

            device = args['device']
            reactant_graphs = reactant_graphs.to(device)
            product_graphs = product_graphs.to(device)
            combo_scores = combo_scores.to(device)
            labels = labels.to(device)
            # Features are popped off the graphs and passed to the model
            # as separate arguments.
            reactant_node_feats = reactant_graphs.ndata.pop('hv').to(device)
            reactant_edge_feats = reactant_graphs.edata.pop('he').to(device)
            product_node_feats = product_graphs.ndata.pop('hv').to(device)
            product_edge_feats = product_graphs.edata.pop('he').to(device)

            pred = model(
                reactant_graph=reactant_graphs,
                reactant_node_feats=reactant_node_feats,
                reactant_edge_feats=reactant_edge_feats,
                product_graphs=product_graphs,
                product_node_feats=product_node_feats,
                product_edge_feats=product_edge_feats,
                candidate_scores=combo_scores,
                batch_num_candidate_products=num_candidates_per_rxn)

            # Walk over the consecutive row slices of ``pred``, one slice per
            # reaction. A reaction counts as correct when the highest score
            # in its slice sits at index 0.
            batch_loss = 0
            slice_start = 0
            for rxn_idx, num_candidates in enumerate(num_candidates_per_rxn):
                slice_end = slice_start + num_candidates
                rxn_scores = pred[slice_start:slice_end, :]
                top_index = rxn_scores.max(dim=0)[1].detach().cpu().data.item()
                acc_sum += float(top_index == 0)
                batch_loss += criterion(rxn_scores.reshape(1, -1),
                                        labels[rxn_idx, :])
                slice_start = slice_end

            grad_norm_sum += optimizer.backward_and_step(batch_loss)
            # The counter advances by the configured batch size for every
            # batch; the print/decay triggers below rely on it.
            total_samples += args['batch_size']

            if total_samples % args['print_every'] == 0:
                epoch_no = epoch + 1
                iter_done = (batch_id + 1) * args['batch_size'] // args['print_every']
                iter_total = len(train_set) // args['print_every']
                elapsed = sum(dur) + time.time() - t0
                time_per_print = elapsed / total_samples * args['print_every']
                mean_acc = acc_sum / args['print_every']
                mean_grad_norm = grad_norm_sum / args['print_every']
                progress = ('Epoch {:d}/{:d}, iter {:d}/{:d} | time {:.4f} | '
                            'accuracy {:.4f} | grad norm {:.4f}').format(
                    epoch_no, args['num_epochs'], iter_done, iter_total,
                    time_per_print, mean_acc, mean_grad_norm)
                print(progress)
                acc_sum = 0
                grad_norm_sum = 0

            if total_samples % args['decay_every'] == 0:
                dur.append(time.time() - t0)

                old_lr = optimizer.lr
                optimizer.decay_lr(args['lr_decay_factor'])
                new_lr = optimizer.lr
                print('Learning rate decayed from {:.4f} to {:.4f}'.format(
                    old_lr, new_lr))

                checkpoint_path = args['result_path'] + \
                    '/model_{:d}.pkl'.format(total_samples)
                torch.save({'model_state_dict': model.state_dict()},
                           checkpoint_path)

                summary_header = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d})\n'.format(
                    total_samples, epoch + 1, args['num_epochs'],
                    (batch_id + 1) * args['batch_size'] // args['print_every'],
                    len(train_set) // args['print_every'])
                prediction_summary = summary_header + candidate_ranking_eval(
                    args, model, val_loader)
                print(prediction_summary)
                with open(args['result_path'] + '/val_eval.txt', 'a') as f:
                    f.write(prediction_summary)

                # Restart the timer and return to training mode after the
                # evaluation call.
                t0 = time.time()
                model.train()
Exemple #2
0
def main(rank, dev_id, args):
    """Train a WLNReactionCenter model in the process identified by ``rank``.

    Only the process with ``rank == 0`` prints progress, runs
    ``reaction_center_final_eval`` on the validation loader, appends the
    result to ``val_eval.txt`` and saves checkpoints. Every process decays
    its learning rate whenever the global iteration counter reaches a
    multiple of ``args['decay_every']`` and calls ``synchronize`` at the end
    of each epoch.

    Parameters
    ----------
    rank : int
        Index of this process; compared against 0 and passed to
        ``get_center_subset``.
    dev_id : int
        -1 selects the CPU; any other value selects ``cuda:<dev_id>``.
    args : dict
        Configuration. ``args['device']`` is written by this function.
        Keys read here: 'num_devices', 'batch_size', 'node_in_feats',
        'edge_in_feats', 'node_pair_in_feats', 'node_out_feats', 'n_layers',
        'n_tasks', 'lr', 'max_norm', 'num_epochs', 'print_every',
        'decay_every', 'lr_decay_factor', 'top_ks_val' and 'result_path'.
    """
    set_seed()
    multi_device = args['num_devices'] > 1
    # Remove the line below will result in problems for multiprocess
    if multi_device:
        torch.set_num_threads(1)

    on_cpu = dev_id == -1
    args['device'] = torch.device(
        'cpu' if on_cpu else 'cuda:{}'.format(dev_id))
    if not on_cpu:
        # Set current device
        torch.cuda.set_device(args['device'])

    train_set, val_set = load_dataset(args)
    # NOTE(review): presumably restricts train_set to this rank's share of
    # the data -- confirm against get_center_subset.
    get_center_subset(train_set, rank, args['num_devices'])
    train_loader = DataLoader(train_set,
                              batch_size=args['batch_size'],
                              collate_fn=collate_center,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate_center,
                            shuffle=False)

    model = WLNReactionCenter(
        node_in_feats=args['node_in_feats'],
        edge_in_feats=args['edge_in_feats'],
        node_pair_in_feats=args['node_pair_in_feats'],
        node_out_feats=args['node_out_feats'],
        n_layers=args['n_layers'],
        n_tasks=args['n_tasks'])
    model = model.to(args['device'])
    model.train()
    if rank == 0:
        print('# trainable parameters in the model: ', count_parameters(model))

    criterion = BCEWithLogitsLoss(reduction='sum')
    base_optimizer = Adam(model.parameters(), lr=args['lr'])
    # Wrap the Adam optimizer with the project helper matching the number
    # of devices in use.
    if args['num_devices'] > 1:
        from utils import MultiProcessOptimizer
        optimizer = MultiProcessOptimizer(args['num_devices'],
                                          model,
                                          args['lr'],
                                          base_optimizer,
                                          max_grad_norm=args['max_norm'])
    else:
        from utils import Optimizer
        optimizer = Optimizer(model,
                              args['lr'],
                              base_optimizer,
                              max_grad_norm=args['max_norm'])

    # total_iter advances by num_devices per local step, rank_iter by one.
    total_iter = 0
    rank_iter = 0
    grad_norm_sum = 0
    loss_sum = 0
    dur = []

    for epoch in range(args['num_epochs']):
        t0 = time.time()
        for batch_id, batch_data in enumerate(train_loader):
            total_iter += args['num_devices']
            rank_iter += 1

            (batch_reactions, _graph_edits, mol_graphs,
             complete_graphs, atom_pair_labels) = batch_data
            labels = atom_pair_labels.to(args['device'])
            pred, _biased_pred = reaction_center_prediction(
                args['device'], model, mol_graphs, complete_graphs)

            # Summed loss divided by the number of reactions in the batch.
            loss = criterion(pred, labels) / len(batch_reactions)
            loss_sum += loss.cpu().detach().data.item()
            grad_norm_sum += optimizer.backward_and_step(loss)

            if rank == 0 and rank_iter % args['print_every'] == 0:
                mean_loss = loss_sum / args['print_every']
                mean_grad_norm = grad_norm_sum / args['print_every']
                progress = ('Epoch {:d}/{:d}, iter {:d}/{:d} | '
                            'loss {:.4f} | grad norm {:.4f}').format(
                    epoch + 1, args['num_epochs'], batch_id + 1,
                    len(train_loader), mean_loss, mean_grad_norm)
                print(progress)
                grad_norm_sum = 0
                loss_sum = 0

            if total_iter % args['decay_every'] == 0:
                # Every process decays its learning rate ...
                optimizer.decay_lr(args['lr_decay_factor'])
                # ... but only rank 0 evaluates, logs and checkpoints.
                if rank == 0:
                    if epoch >= 1:
                        dur.append(time.time() - t0)
                        print('Training time per {:d} iterations: {:.4f}'.format(
                            rank_iter, np.mean(dur)))

                    total_samples = total_iter * args['batch_size']
                    summary_header = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d}) '.format(
                        total_samples, epoch + 1, args['num_epochs'],
                        batch_id + 1, len(train_loader))
                    prediction_summary = summary_header + \
                        reaction_center_final_eval(
                            args, args['top_ks_val'], model, val_loader,
                            easy=True)
                    print(prediction_summary)
                    with open(args['result_path'] + '/val_eval.txt', 'a') as f:
                        f.write(prediction_summary)

                    checkpoint_path = args['result_path'] + \
                        '/model_{:d}.pkl'.format(total_samples)
                    torch.save({'model_state_dict': model.state_dict()},
                               checkpoint_path)

                    # Restart the timer and return to training mode after
                    # the evaluation call.
                    t0 = time.time()
                    model.train()
        synchronize(args['num_devices'])