def main(args):
    """Evaluate a trained WLN reaction-center model on the test set.

    Loads the test set (pre-processed USPTO by default), restores the model
    weights, runs the final top-k evaluation and writes the summary to
    ``args['result_path'] + '/test_eval.txt'``.

    Parameters
    ----------
    args : dict
        Experiment configuration. Mutated in place to record ``'device'``.
    """
    set_seed()
    if torch.cuda.is_available():
        args['device'] = torch.device('cuda:0')
        # Set current device. torch.cuda.set_device only accepts CUDA
        # devices, so keep the call in the CUDA branch to avoid crashing
        # on CPU-only machines.
        torch.cuda.set_device(args['device'])
    else:
        args['device'] = torch.device('cpu')

    if args['test_path'] is None:
        # Fall back to the pre-processed USPTO test split.
        test_set = USPTOCenter('test', num_processes=args['num_processes'],
                               load=args['load'])
    else:
        test_set = WLNCenterDataset(
            raw_file_path=args['test_path'],
            mol_graph_path=args['test_path'] + '.bin',
            num_processes=args['num_processes'],
            load=args['load'],
            reaction_validity_result_prefix='test')
    test_loader = DataLoader(test_set, batch_size=args['batch_size'],
                             collate_fn=collate_center, shuffle=False)

    if args['model_path'] is None:
        model = load_pretrained('wln_center_uspto')
    else:
        model = WLNReactionCenter(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_pair_in_feats=args['node_pair_in_feats'],
            node_out_feats=args['node_out_feats'],
            n_layers=args['n_layers'],
            n_tasks=args['n_tasks'])
        # Checkpoints are saved as {'model_state_dict': ...}; load on CPU
        # first so GPU-trained checkpoints also restore on CPU-only hosts.
        model.load_state_dict(
            torch.load(args['model_path'],
                       map_location='cpu')['model_state_dict'])
    model = model.to(args['device'])

    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(
        args, args['top_ks_test'], model, test_loader, args['easy'])
    print(test_result)
    with open(args['result_path'] + '/test_eval.txt', 'w') as f:
        f.write(test_result)
def main(args):
    """Train a WLN reaction-center model and evaluate it on the test set.

    Builds train/val/test loaders (pre-processed USPTO by default), trains
    with summed BCE-with-logits loss normalized per reaction, periodically
    logs progress and checkpoints, and writes the final test evaluation to
    ``args['result_path'] + '/results.txt'``.

    Parameters
    ----------
    args : dict
        Experiment configuration; mutated by ``setup`` and, when a
        pre-trained model is used, by zeroing ``'num_epochs'``.
    """
    setup(args)
    # Each split comes either from the bundled USPTO dataset or from a
    # user-provided raw file (with a cached .bin of molecular graphs).
    if args['train_path'] is None:
        train_set = USPTO('train')
    else:
        train_set = WLNReactionDataset(raw_file_path=args['train_path'],
                                       mol_graph_path='train.bin')
    if args['val_path'] is None:
        val_set = USPTO('val')
    else:
        val_set = WLNReactionDataset(raw_file_path=args['val_path'],
                                     mol_graph_path='val.bin')
    if args['test_path'] is None:
        test_set = USPTO('test')
    else:
        test_set = WLNReactionDataset(raw_file_path=args['test_path'],
                                      mol_graph_path='test.bin')

    train_loader = DataLoader(train_set, batch_size=args['batch_size'],
                              collate_fn=collate, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args['batch_size'],
                            collate_fn=collate, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=args['batch_size'],
                             collate_fn=collate, shuffle=False)

    if args['pre_trained']:
        # Skip training entirely when evaluating a pre-trained model.
        model = load_pretrained('wln_center_uspto').to(args['device'])
        args['num_epochs'] = 0
    else:
        model = WLNReactionCenter(
            node_in_feats=args['node_in_feats'],
            edge_in_feats=args['edge_in_feats'],
            node_pair_in_feats=args['node_pair_in_feats'],
            node_out_feats=args['node_out_feats'],
            n_layers=args['n_layers'],
            n_tasks=args['n_tasks']).to(args['device'])

    # Summed (not averaged) BCE; the loss is normalized per reaction below.
    criterion = BCEWithLogitsLoss(reduction='sum')
    optimizer = Adam(model.parameters(), lr=args['lr'])
    # scheduler.step() is called per iteration, so the LR decays every
    # 'decay_every' iterations (not epochs).
    scheduler = StepLR(optimizer, step_size=args['decay_every'],
                       gamma=args['lr_decay_factor'])

    total_iter = 0      # iterations since training started
    grad_norm_sum = 0   # running sums, reset after each progress report
    loss_sum = 0
    dur = []            # per-epoch wall-clock durations
    for epoch in range(args['num_epochs']):
        t0 = time.time()
        for batch_id, batch_data in enumerate(train_loader):
            total_iter += 1
            batch_reactions, batch_graph_edits, batch_mols, batch_mol_graphs, \
            batch_complete_graphs, batch_atom_pair_labels = batch_data
            labels = batch_atom_pair_labels.to(args['device'])
            pred, biased_pred = reaction_center_prediction(
                args['device'], model, batch_mol_graphs, batch_complete_graphs)
            # Normalize the summed loss by the number of reactions in batch.
            loss = criterion(pred, labels) / len(batch_reactions)
            loss_sum += loss.cpu().detach().data.item()
            optimizer.zero_grad()
            loss.backward()
            grad_norm = clip_grad_norm_(model.parameters(), args['max_norm'])
            grad_norm_sum += grad_norm
            optimizer.step()
            scheduler.step()
            if total_iter % args['print_every'] == 0:
                progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | time/minibatch {:.4f} | ' \
                           'loss {:.4f} | grad norm {:.4f}'.format(
                    epoch + 1, args['num_epochs'], batch_id + 1, len(train_loader),
                    (np.sum(dur) + time.time() - t0) / total_iter,
                    loss_sum / args['print_every'], grad_norm_sum / args['print_every'])
                grad_norm_sum = 0
                loss_sum = 0
                print(progress)
            if total_iter % args['decay_every'] == 0:
                # Periodic checkpoint, overwritten each time.
                torch.save(model.state_dict(), args['result_path'] + '/model.pkl')
        dur.append(time.time() - t0)
        print('Epoch {:d}/{:d}, validation '.format(epoch + 1, args['num_epochs']) + \
              rough_eval_on_a_loader(args, model, val_loader))

    # Free training/validation data before the final (memory-hungry) eval.
    del train_loader
    del val_loader
    del train_set
    del val_set

    print('Evaluation on the test set.')
    test_result = reaction_center_final_eval(args, model, test_loader, args['easy'])
    print(test_result)
    with open(args['result_path'] + '/results.txt', 'w') as f:
        f.write(test_result)
def main(rank, dev_id, args):
    """Train the WLN reaction-center model in one worker process.

    Each process trains on its own subset of the training data; gradients
    are synchronized by the optimizer wrapper when multiple devices are
    used. Only rank 0 logs progress, evaluates on the validation set and
    saves checkpoints.

    Parameters
    ----------
    rank : int
        Index of this worker process.
    dev_id : int
        CUDA device id for this worker, or -1 for CPU.
    args : dict
        Experiment configuration. Mutated in place to record ``'device'``.
    """
    set_seed()
    # Removing the line below will result in problems for multiprocess
    if args['num_devices'] > 1:
        torch.set_num_threads(1)
    if dev_id == -1:
        args['device'] = torch.device('cpu')
    else:
        args['device'] = torch.device('cuda:{}'.format(dev_id))
        # Set current device. torch.cuda.set_device only accepts CUDA
        # devices, so keep it in the GPU branch to avoid crashing when
        # dev_id == -1 (CPU).
        torch.cuda.set_device(args['device'])

    train_set, val_set = load_dataset(args)
    # Restrict this process to its shard of the training set.
    get_center_subset(train_set, rank, args['num_devices'])
    train_loader = DataLoader(train_set, batch_size=args['batch_size'],
                              collate_fn=collate_center, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args['batch_size'],
                            collate_fn=collate_center, shuffle=False)

    model = WLNReactionCenter(node_in_feats=args['node_in_feats'],
                              edge_in_feats=args['edge_in_feats'],
                              node_pair_in_feats=args['node_pair_in_feats'],
                              node_out_feats=args['node_out_feats'],
                              n_layers=args['n_layers'],
                              n_tasks=args['n_tasks']).to(args['device'])
    model.train()
    if rank == 0:
        print('# trainable parameters in the model: ', count_parameters(model))

    # Summed (not averaged) BCE; normalized per reaction below.
    criterion = BCEWithLogitsLoss(reduction='sum')
    optimizer = Adam(model.parameters(), lr=args['lr'])
    if args['num_devices'] <= 1:
        from utils import Optimizer
        optimizer = Optimizer(model, args['lr'], optimizer,
                              max_grad_norm=args['max_norm'])
    else:
        # Wrapper that synchronizes gradients across processes on each step.
        from utils import MultiProcessOptimizer
        optimizer = MultiProcessOptimizer(args['num_devices'], model,
                                          args['lr'], optimizer,
                                          max_grad_norm=args['max_norm'])

    total_iter = 0     # iterations counted across all devices
    rank_iter = 0      # iterations performed by this process
    grad_norm_sum = 0  # running sums, reset after each progress report
    loss_sum = 0
    dur = []
    for epoch in range(args['num_epochs']):
        t0 = time.time()
        for batch_id, batch_data in enumerate(train_loader):
            total_iter += args['num_devices']
            rank_iter += 1
            batch_reactions, batch_graph_edits, batch_mol_graphs, \
            batch_complete_graphs, batch_atom_pair_labels = batch_data
            labels = batch_atom_pair_labels.to(args['device'])
            pred, biased_pred = reaction_center_prediction(
                args['device'], model, batch_mol_graphs, batch_complete_graphs)
            # Normalize the summed loss by the number of reactions in batch.
            loss = criterion(pred, labels) / len(batch_reactions)
            loss_sum += loss.cpu().detach().data.item()
            grad_norm_sum += optimizer.backward_and_step(loss)

            if rank_iter % args['print_every'] == 0 and rank == 0:
                progress = 'Epoch {:d}/{:d}, iter {:d}/{:d} | ' \
                           'loss {:.4f} | grad norm {:.4f}'.format(
                    epoch + 1, args['num_epochs'], batch_id + 1,
                    len(train_loader), loss_sum / args['print_every'],
                    grad_norm_sum / args['print_every'])
                print(progress)
                grad_norm_sum = 0
                loss_sum = 0

            if total_iter % args['decay_every'] == 0:
                optimizer.decay_lr(args['lr_decay_factor'])
            if total_iter % args['decay_every'] == 0 and rank == 0:
                if epoch >= 1:
                    # Skip the first epoch so warm-up does not skew timing.
                    dur.append(time.time() - t0)
                    print('Training time per {:d} iterations: {:.4f}'.format(
                        rank_iter, np.mean(dur)))
                total_samples = total_iter * args['batch_size']
                prediction_summary = \
                    'total samples {:d}, (epoch {:d}/{:d}, iter {:d}/{:d}) '.format(
                        total_samples, epoch + 1, args['num_epochs'],
                        batch_id + 1, len(train_loader)) + \
                    reaction_center_final_eval(args, args['top_ks_val'],
                                               model, val_loader, easy=True)
                print(prediction_summary)
                with open(args['result_path'] + '/val_eval.txt', 'a') as f:
                    f.write(prediction_summary)
                # Checkpoint keyed by the number of samples seen so far.
                torch.save({'model_state_dict': model.state_dict()},
                           args['result_path'] +
                           '/model_{:d}.pkl'.format(total_samples))
                t0 = time.time()
                # reaction_center_final_eval may leave the model in eval mode.
                model.train()
        synchronize(args['num_devices'])
def prepare_reaction_center(args, reaction_center_config):
    """Use a trained model for reaction center prediction to prepare candidate bonds.

    For each subset with a ``'{subset}_path'`` key in ``args``, runs the
    reaction-center model over the subset and writes the top-k candidate
    bonds per reaction to ``{result_path}/{subset}_candidate_bonds.txt``.
    Subsets whose output file already exists are skipped.

    Parameters
    ----------
    args : dict
        Configuration for the experiment.
    reaction_center_config : dict
        Configuration for the experiment on reaction center prediction.

    Returns
    -------
    path_to_candidate_bonds : dict
        Mapping 'train', 'val', 'test' to the corresponding files for
        candidate bonds.
    """
    if args['center_model_path'] is None:
        reaction_center_model = load_pretrained('wln_center_uspto').to(
            args['device'])
    else:
        reaction_center_model = WLNReactionCenter(
            node_in_feats=reaction_center_config['node_in_feats'],
            edge_in_feats=reaction_center_config['edge_in_feats'],
            node_pair_in_feats=reaction_center_config['node_pair_in_feats'],
            node_out_feats=reaction_center_config['node_out_feats'],
            n_layers=reaction_center_config['n_layers'],
            n_tasks=reaction_center_config['n_tasks'])
        # Load on CPU first so GPU-trained checkpoints also restore on
        # CPU-only hosts; the model is moved to the target device below.
        reaction_center_model.load_state_dict(
            torch.load(args['center_model_path'],
                       map_location='cpu')['model_state_dict'])
        reaction_center_model = reaction_center_model.to(args['device'])
    reaction_center_model.eval()

    path_to_candidate_bonds = dict()
    for subset in ['train', 'val', 'test']:
        if '{}_path'.format(subset) not in args:
            continue

        path_to_candidate_bonds[subset] = args['result_path'] + \
            '/{}_candidate_bonds.txt'.format(subset)
        # Candidate bonds were already computed in a previous run; skip.
        if os.path.isfile(path_to_candidate_bonds[subset]):
            continue

        print('Processing subset {}...'.format(subset))
        print('Stage 1/3: Loading dataset...')
        if args['{}_path'.format(subset)] is None:
            dataset = USPTOCenter(subset, num_processes=args['num_processes'])
        else:
            dataset = WLNCenterDataset(
                raw_file_path=args['{}_path'.format(subset)],
                mol_graph_path='{}.bin'.format(subset),
                num_processes=args['num_processes'])

        dataloader = DataLoader(dataset,
                                batch_size=args['reaction_center_batch_size'],
                                collate_fn=collate_center, shuffle=False)

        print('Stage 2/3: Performing model prediction...')
        output_strings = []
        for batch_id, batch_data in enumerate(dataloader):
            print('Computing candidate bonds for batch {:d}/{:d}'.format(
                batch_id + 1, len(dataloader)))
            batch_reactions, batch_graph_edits, batch_mol_graphs, \
            batch_complete_graphs, batch_atom_pair_labels = batch_data
            with torch.no_grad():
                pred, biased_pred = reaction_center_prediction(
                    args['device'], reaction_center_model,
                    batch_mol_graphs, batch_complete_graphs)
            batch_size = len(batch_reactions)
            start = 0
            for i in range(batch_size):
                # Predictions for all atom pairs of reaction i occupy a
                # contiguous slice of the batched output.
                end = start + batch_complete_graphs.batch_num_edges[i]
                output_strings.append(
                    output_candidate_bonds_for_a_reaction(
                        (batch_reactions[i],
                         biased_pred[start:end, :].flatten(),
                         batch_complete_graphs.batch_num_nodes[i]),
                        reaction_center_config['max_k']))
                start = end

        print('Stage 3/3: Output candidate bonds...')
        with open(path_to_candidate_bonds[subset], 'w') as f:
            for candidate_string in output_strings:
                f.write(candidate_string)

        del dataset
        del dataloader
    # Free the model before the downstream candidate-ranking stage.
    del reaction_center_model

    return path_to_candidate_bonds