Exemple #1
0
def main(args):
    """Evaluate a WLN reaction-center model on a test set.

    Selects the device, builds the test loader, obtains the model (either the
    pre-trained one or a checkpoint from ``args['model_path']``), runs the
    final evaluation, prints the summary and writes it to
    ``args['result_path'] + '/test_eval.txt'``.

    Parameters
    ----------
    args : dict
        Configuration. Keys read here: ``test_path``, ``num_processes``,
        ``load``, ``batch_size``, ``model_path``, ``top_ks_test``, ``easy``
        and ``result_path``; when ``model_path`` is not None also
        ``node_in_feats``, ``edge_in_feats``, ``node_pair_in_feats``,
        ``node_out_feats``, ``n_layers`` and ``n_tasks``.
        ``args['device']`` is set by this function.
    """
    set_seed()
    if torch.cuda.is_available():
        args['device'] = torch.device('cuda:0')
        # Make the GPU current only when we really have one:
        # torch.cuda.set_device expects a CUDA device and must not be
        # handed the CPU device on CPU-only machines.
        torch.cuda.set_device(args['device'])
    else:
        args['device'] = torch.device('cpu')

    if args['test_path'] is None:
        # No custom file given: fall back to the built-in USPTO test split.
        test_set = USPTOCenter('test',
                               num_processes=args['num_processes'],
                               load=args['load'])
    else:
        test_set = WLNCenterDataset(raw_file_path=args['test_path'],
                                    mol_graph_path=args['test_path'] + '.bin',
                                    num_processes=args['num_processes'],
                                    load=args['load'],
                                    reaction_validity_result_prefix='test')
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate_center,
                             shuffle=False)

    if args['model_path'] is None:
        model = load_pretrained('wln_center_uspto')
    else:
        model = WLNReactionCenter(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_pair_in_feats=args['node_pair_in_feats'],
            node_out_feats=args['node_out_feats'],
            n_layers=args['n_layers'],
            n_tasks=args['n_tasks'])
        # Load weights onto the CPU first; the model is moved to the target
        # device right below.
        model.load_state_dict(
            torch.load(args['model_path'],
                       map_location='cpu')['model_state_dict'])
    model = model.to(args['device'])

    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(args, args['top_ks_test'], model,
                                             test_loader, args['easy'])
    print(test_result)
    with open(args['result_path'] + '/test_eval.txt', 'w') as f:
        f.write(test_result)
Exemple #2
0
def main(rank, dev_id, args):
    """Train a WLN reaction-center model inside one worker process.

    Parameters
    ----------
    rank : int
        Index of this worker. Only rank 0 prints progress, evaluates on the
        validation set and writes checkpoints.
    dev_id : int
        CUDA device index to use, or -1 to run on the CPU.
    args : dict
        Configuration. ``args['device']`` is set by this function; the
        remaining keys (``num_devices``, ``batch_size``, ``lr``,
        ``max_norm``, ``num_epochs``, ``print_every``, ``decay_every``,
        ``lr_decay_factor``, ``top_ks_val``, ``result_path`` and the model
        hyperparameters) are read.
    """
    set_seed()
    # Removing the thread cap below causes problems for multiprocess runs.
    if args['num_devices'] > 1:
        torch.set_num_threads(1)

    on_cpu = dev_id == -1
    args['device'] = torch.device(
        'cpu' if on_cpu else 'cuda:{}'.format(dev_id))
    if not on_cpu:
        # Make the selected GPU the current device.
        torch.cuda.set_device(args['device'])

    train_set, val_set = load_dataset(args)
    get_center_subset(train_set, rank, args['num_devices'])
    loader_kwargs = dict(batch_size=args['batch_size'],
                         collate_fn=collate_center)
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)

    model = WLNReactionCenter(
        node_in_feats=args['node_in_feats'],
        edge_in_feats=args['edge_in_feats'],
        node_pair_in_feats=args['node_pair_in_feats'],
        node_out_feats=args['node_out_feats'],
        n_layers=args['n_layers'],
        n_tasks=args['n_tasks'])
    model = model.to(args['device'])
    model.train()
    if rank == 0:
        print('# trainable parameters in the model: ', count_parameters(model))

    criterion = BCEWithLogitsLoss(reduction='sum')
    base_optimizer = Adam(model.parameters(), lr=args['lr'])
    # Wrap the raw optimizer; the wrapper takes care of the backward pass,
    # the step and learning-rate decay.
    if args['num_devices'] > 1:
        from utils import MultiProcessOptimizer
        optimizer = MultiProcessOptimizer(
            args['num_devices'], model, args['lr'], base_optimizer,
            max_grad_norm=args['max_norm'])
    else:
        from utils import Optimizer
        optimizer = Optimizer(
            model, args['lr'], base_optimizer, max_grad_norm=args['max_norm'])

    total_iter = 0      # iterations summed over all workers
    rank_iter = 0       # iterations performed by this worker
    grad_norm_sum = 0
    loss_sum = 0
    dur = []

    progress_template = ('Epoch {:d}/{:d}, iter {:d}/{:d} | '
                         'loss {:.4f} | grad norm {:.4f}')

    for epoch in range(args['num_epochs']):
        t0 = time.time()
        for batch_id, batch_data in enumerate(train_loader):
            total_iter += args['num_devices']
            rank_iter += 1

            (reactions, _graph_edits, mol_graphs,
             complete_graphs, pair_labels) = batch_data
            labels = pair_labels.to(args['device'])
            pred, _biased_pred = reaction_center_prediction(
                args['device'], model, mol_graphs, complete_graphs)
            # Summed loss divided by the number of reactions in the batch.
            loss = criterion(pred, labels) / len(reactions)
            loss_sum += loss.item()
            grad_norm_sum += optimizer.backward_and_step(loss)

            if rank_iter % args['print_every'] == 0 and rank == 0:
                print(progress_template.format(
                    epoch + 1, args['num_epochs'],
                    batch_id + 1, len(train_loader),
                    loss_sum / args['print_every'],
                    grad_norm_sum / args['print_every']))
                grad_norm_sum = 0
                loss_sum = 0

            if total_iter % args['decay_every'] == 0:
                optimizer.decay_lr(args['lr_decay_factor'])
                if rank == 0:
                    if epoch > 0:
                        dur.append(time.time() - t0)
                        print('Training time per {:d} iterations: {:.4f}'.format(
                            rank_iter, np.mean(dur)))
                    total_samples = total_iter * args['batch_size']
                    header = 'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d}) '.format(
                        total_samples, epoch + 1, args['num_epochs'],
                        batch_id + 1, len(train_loader))
                    eval_summary = reaction_center_final_eval(
                        args, args['top_ks_val'], model, val_loader, easy=True)
                    prediction_summary = header + eval_summary
                    print(prediction_summary)
                    with open(args['result_path'] + '/val_eval.txt', 'a') as f:
                        f.write(prediction_summary)
                    checkpoint_path = args['result_path'] + \
                        '/model_{:d}.pkl'.format(total_samples)
                    torch.save({'model_state_dict': model.state_dict()},
                               checkpoint_path)
                    # Restart the timer and return to training mode after
                    # the evaluation pass.
                    t0 = time.time()
                    model.train()
        synchronize(args['num_devices'])
Exemple #3
0
def main(args):
    """Train (or load) a WLN reaction-center model and evaluate it on the test set.

    Builds the train/validation/test loaders, then either loads the
    pre-trained model (skipping training) or trains a fresh model with Adam
    and a step learning-rate schedule, validating after every epoch. Finally
    the model is evaluated on the test set and the summary is written to
    ``args['result_path'] + '/results.txt'``.

    Parameters
    ----------
    args : dict
        Configuration. ``setup(args)`` is called first; it presumably fills
        in entries such as ``args['device']`` (not visible here - confirm).
        Keys read here: ``train_path``, ``val_path``, ``test_path``,
        ``batch_size``, ``pre_trained``, ``device``, ``num_epochs``,
        ``print_every``, ``decay_every``, ``lr``, ``lr_decay_factor``,
        ``max_norm``, ``easy``, ``result_path`` and the model
        hyperparameters. ``args['num_epochs']`` is overwritten with 0 when
        ``args['pre_trained']`` is truthy.
    """
    setup(args)
    if args['train_path'] is None:
        train_set = USPTO('train')
    else:
        train_set = WLNReactionDataset(raw_file_path=args['train_path'],
                                       mol_graph_path='train.bin')
    if args['val_path'] is None:
        val_set = USPTO('val')
    else:
        val_set = WLNReactionDataset(raw_file_path=args['val_path'],
                                     mol_graph_path='val.bin')
    if args['test_path'] is None:
        test_set = USPTO('test')
    else:
        test_set = WLNReactionDataset(raw_file_path=args['test_path'],
                                      mol_graph_path='test.bin')
    train_loader = DataLoader(train_set,
                              batch_size=args['batch_size'],
                              collate_fn=collate,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=args['batch_size'],
                            collate_fn=collate,
                            shuffle=False)
    test_loader = DataLoader(test_set,
                             batch_size=args['batch_size'],
                             collate_fn=collate,
                             shuffle=False)

    if args['pre_trained']:
        model = load_pretrained('wln_center_uspto').to(args['device'])
        # Zero epochs means the training loop below is skipped entirely, so
        # the training-only variables of the other branch are never needed.
        args['num_epochs'] = 0
    else:
        model = WLNReactionCenter(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_pair_in_feats=args['node_pair_in_feats'],
            node_out_feats=args['node_out_feats'],
            n_layers=args['n_layers'],
            n_tasks=args['n_tasks']).to(args['device'])

        criterion = BCEWithLogitsLoss(reduction='sum')
        optimizer = Adam(model.parameters(), lr=args['lr'])
        # The scheduler is stepped once per minibatch, so the learning rate
        # is multiplied by lr_decay_factor every decay_every iterations.
        scheduler = StepLR(optimizer,
                           step_size=args['decay_every'],
                           gamma=args['lr_decay_factor'])

        total_iter = 0
        grad_norm_sum = 0
        loss_sum = 0
        dur = []

    for epoch in range(args['num_epochs']):
        # Make sure the model is in training mode at the start of every
        # epoch; the validation pass at the end of the previous epoch may
        # have switched it to evaluation mode.
        model.train()
        t0 = time.time()
        for batch_id, batch_data in enumerate(train_loader):
            total_iter += 1

            batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
            batch_complete_graphs, batch_atom_pair_labels = batch_data
            labels = batch_atom_pair_labels.to(args['device'])
            pred, biased_pred = reaction_center_prediction(
                args['device'], model, batch_mol_graphs, batch_complete_graphs)
            # Summed loss divided by the number of reactions in the batch.
            loss = criterion(pred, labels) / len(batch_reactions)
            loss_sum += loss.item()
            optimizer.zero_grad()
            loss.backward()
            grad_norm = clip_grad_norm_(model.parameters(), args['max_norm'])
            grad_norm_sum += grad_norm
            optimizer.step()
            scheduler.step()

            if total_iter % args['print_every'] == 0:
                progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time/minibatch {:.4f} | ' \
                           'loss {:.4f} | grad norm {:.4f}'.format(
                    epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
                    (np.sum(dur) + time.time() - t0) / total_iter, loss_sum / args['print_every'],
                    grad_norm_sum / args['print_every'])
                grad_norm_sum = 0
                loss_sum = 0
                print(progress)

            if total_iter % args['decay_every'] == 0:
                # Checkpoint at every learning-rate decay boundary; the same
                # file is overwritten each time.
                torch.save(model.state_dict(),
                           args['result_path'] + '/model.pkl')

        dur.append(time.time() - t0)
        print('Epoch {:d}/{:d}, validation '.format(epoch + 1, args['num_epochs']) + \
              rough_eval_on_a_loader(args, model, val_loader))

    # Drop the training/validation data before the test evaluation.
    del train_loader
    del val_loader
    del train_set
    del val_set
    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(args, model, test_loader,
                                             args['easy'])
    print(test_result)
    with open(args['result_path'] + '/results.txt', 'w') as f:
        f.write(test_result)