def packed_sequential_data_preparation(
    input_batch,
    input_keep=1,
    start_index=2,
    end_index=3,
    dropout_index=1,
    device=None,
    enforce_sorted=False
):
    """Sequential Training Data Builder.

    Args:
        input_batch (torch.Tensor): Batch of padded sequences, output of
            nn.utils.rnn.pad_sequence(batch) of size
            `[sequence length, batch_size, 1]`.
        input_keep (float): The probability not to drop input sequence tokens
            according to a Bernoulli distribution with p = input_keep.
            Defaults to 1.
        start_index (int): The index of the sequence start token.
        end_index (int): The index of the sequence end token.
        dropout_index (int): The index of the dropout token. Defaults to 1.
        device (torch.device): Device for the produced tensors. Defaults to
            None, in which case `get_device()` is resolved at call time
            (previously the default was evaluated once at import time).
        enforce_sorted (bool): Forwarded to `pack_sequence`. Defaults to
            False.

    Returns:
        (torch.Tensor, torch.Tensor, torch.Tensor): encoder_seq, decoder_seq,
            target_seq
        encoder_seq is a batch of padded input sequences starting with the
            start_index, of size `[sequence length +1, batch_size, 1]`.
        decoder_seq is like encoder_seq but word dropout is applied (so if
            input_keep==1, then decoder_seq = encoder_seq).
        target_seq (torch.Tensor): Batch of padded target sequences ending
            in the end_index, of size `[sequence length +1, batch_size, 1]`.
    """
    if device is None:
        # Resolve lazily so importing this module never requires a device.
        device = get_device()

    def _process_sample(sample):
        """Build the (encoder, decoder, target) triplet for one sequence."""
        if len(sample.shape) != 1:
            raise ValueError(
                f'Expected a 1D sample, got shape {tuple(sample.shape)}.'
            )
        input_seq = sample.long().to(device)
        decoder = input_seq.clone()

        # apply token dropout if keep != 1
        if input_keep != 1:
            # mask for token dropout
            mask = Bernoulli(input_keep).sample(input_seq.shape)
            mask = torch.LongTensor(mask.numpy())
            dropout_loc = np.where(mask == 0)[0]
            decoder[dropout_loc] = dropout_index

        # just .clone() propagates to graph
        target = torch.cat([
            input_seq[1:].detach().clone(),
            torch.Tensor([0]).long().to(device)
        ])
        return input_seq, decoder, target.to(device)

    batch = [_process_sample(sample) for sample in input_batch]

    encoder_decoder_target = zip(*batch)
    encoder_decoder_target = [
        torch.nn.utils.rnn.pack_sequence(entry, enforce_sorted=enforce_sorted)
        for entry in encoder_decoder_target
    ]
    return encoder_decoder_target
def load_pretrained_paccmann(
    self, params_file: str, lang_file: str, weights_file: str,
    batch_size: int, batch_mode: str
):
    """Restore a pretrained PaccMann VAE and set up the input pipeline.

    Reads hyperparameters from `params_file` (overriding batch mode/size),
    loads the SMILES language from `lang_file`, builds the encoder/decoder
    VAE on the detected device, restores weights from `weights_file`, puts
    the model in eval mode and composes the SMILES-to-tensor transform.
    """
    with open(params_file, 'r') as fp:
        params = dict(json.load(fp))
    params['batch_mode'] = batch_mode
    params['batch_size'] = batch_size

    self.selfies = params.get('selfies', False)
    self.device = get_device()
    self.smiles_language = SMILESLanguage.load(lang_file)

    # Assemble the VAE from its two halves and restore the checkpoint.
    self.gru_encoder = StackGRUEncoder(params).to(self.device)
    self.gru_decoder = StackGRUDecoder(params).to(self.device)
    self.gru_vae = TeacherVAE(self.gru_encoder,
                              self.gru_decoder).to(self.device)
    state_dict = torch.load(weights_file, map_location=self.device)
    self.gru_vae.load_state_dict(state_dict)
    self.gru_vae.eval()

    # Optional SELFIES conversion, then tokenization and tensorization.
    pipeline = []
    if self.selfies:
        pipeline.append(Selfies())
    pipeline.append(
        SMILESToTokenIndexes(smiles_language=self.smiles_language))
    pipeline.append(ToTensor(device=self.device))
    self.transform = Compose(pipeline)
def main(*, parser_namespace):
    """RL training of a conditional molecule generator vs. an IC50 predictor.

    Restores a pretrained SMILES/SELFIES VAE generator, a cell-line (omics)
    encoder and a PaccMann IC50 predictor, then optimizes a fresh copy of
    the generator with policy gradients (``ReinforceOmic``) while keeping a
    frozen baseline for comparison. Each epoch, molecules are generated and
    evaluated on a train and a held-out cell line; results go to CSV.

    Args:
        parser_namespace: argparse namespace; used as fallback for entries
            missing from the params JSON (JSON values take precedence).

    NOTE(review): formatting reconstructed from a collapsed source — verify
    statement grouping inside the epoch loop against version control.
    """
    disable_rdkit_logging()

    # read the params json
    params = dict()
    with open(parser_namespace.params_path) as f:
        params.update(json.load(f))

    # get params
    mol_model_path = params.get('mol_model_path',
                                parser_namespace.mol_model_path)
    omics_model_path = params.get('omics_model_path',
                                  parser_namespace.omics_model_path)
    ic50_model_path = params.get('ic50_model_path',
                                 parser_namespace.ic50_model_path)
    omics_data_path = params.get('omics_data_path',
                                 parser_namespace.omics_data_path)
    model_name = params.get(
        'model_name', parser_namespace.model_name
    )  # yapf: disable
    site = params.get(
        'site', parser_namespace.site
    )  # yapf: disable
    params['site'] = site

    logger.info(f'Model with name {model_name} starts.')

    # Load omics profiles for conditional generation,
    # complement with avg per site
    omics_df = pd.read_pickle(omics_data_path)
    omics_df = add_avg_profile(omics_df)

    # Restore SMILES Model
    with open(os.path.join(mol_model_path, 'model_params.json')) as f:
        mol_params = json.load(f)
    gru_encoder = StackGRUEncoder(mol_params)
    gru_decoder = StackGRUDecoder(mol_params)
    generator = TeacherVAE(gru_encoder, gru_decoder)
    generator.load(os.path.join(
        mol_model_path,
        f"weights/best_{params.get('smiles_metric', 'rec')}.pt"),
        map_location=get_device())
    # Load languages
    generator_smiles_language = SMILESLanguage.load(
        os.path.join(mol_model_path, 'selfies_language.pkl'))
    generator._associate_language(generator_smiles_language)

    # Restore omics model
    with open(os.path.join(omics_model_path, 'model_params.json')) as f:
        cell_params = json.load(f)

    # Define network
    cell_encoder = ENCODER_FACTORY['dense'](cell_params)
    cell_encoder.load(os.path.join(
        omics_model_path,
        f"weights/best_{params.get('omics_metric','both')}_encoder.pt"),
        map_location=get_device())
    cell_encoder.eval()

    # Restore PaccMann
    with open(os.path.join(ic50_model_path, 'model_params.json')) as f:
        paccmann_params = json.load(f)
    paccmann_predictor = MODEL_FACTORY['mca'](paccmann_params)
    paccmann_predictor.load(os.path.join(
        ic50_model_path,
        f"weights/best_{params.get('ic50_metric', 'rmse')}_mca.pt"),
        map_location=get_device())
    paccmann_predictor.eval()
    paccmann_smiles_language = SMILESLanguage.load(
        os.path.join(ic50_model_path, 'smiles_language.pkl'))
    paccmann_predictor._associate_language(paccmann_smiles_language)

    # Specifies the baseline model used for comparison
    baseline = ReinforceOmic(generator, cell_encoder, paccmann_predictor,
                             omics_df, params, 'baseline', logger)

    # Create a fresh model that will be optimized
    gru_encoder_rl = StackGRUEncoder(mol_params)
    gru_decoder_rl = StackGRUDecoder(mol_params)
    generator_rl = TeacherVAE(gru_encoder_rl, gru_decoder_rl)
    generator_rl.load(os.path.join(
        mol_model_path, f"weights/best_{params.get('metric', 'rec')}.pt"),
        map_location=get_device())
    generator_rl.eval()
    generator_rl._associate_language(generator_smiles_language)

    cell_encoder_rl = ENCODER_FACTORY['dense'](cell_params)
    cell_encoder_rl.load(os.path.join(
        omics_model_path,
        f"weights/best_{params.get('metric', 'both')}_encoder.pt"),
        map_location=get_device())
    cell_encoder_rl.eval()
    model_folder_name = site + '_' + model_name
    learner = ReinforceOmic(generator_rl, cell_encoder_rl,
                            paccmann_predictor, omics_df, params,
                            model_folder_name, logger)

    # Split the samples for conditional generation and initialize training
    train_omics, test_omics = omics_data_splitter(
        omics_df, site, params.get('test_fraction', 0.2))
    rewards, rl_losses = [], []
    gen_mols, gen_cell, gen_ic50, modes = [], [], [], []
    logger.info('Models restored, start training.')

    for epoch in range(1, params['epochs'] + 1):

        for step in range(1, params['steps']):

            # Randomly sample a cell line:
            cell_line = np.random.choice(train_omics)

            rew, loss = learner.policy_gradient(cell_line, epoch,
                                                params['batch_size'])
            print(f"Epoch {epoch:d}/{params['epochs']:d}, step {step:d}/"
                  f"{params['steps']:d}\t loss={loss:.2f}, rew={rew:.2f}")

            rewards.append(rew.item())
            rl_losses.append(loss)

        # Save model
        learner.save(f'gen_{epoch}.pt', f'enc_{epoch}.pt')

        # Compare baseline and trained model on cell line
        # NOTE(review): `cell_line` here is the last one sampled above.
        base_smiles, base_preds = baseline.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], cell_line)
        smiles, preds = learner.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], cell_line)
        # Keep only molecules predicted below the efficacy threshold.
        gs = [
            s for i, s in enumerate(smiles)
            if preds[i] < learner.ic50_threshold
        ]
        gp = preds[preds < learner.ic50_threshold]
        for p, s in zip(gp, gs):
            gen_mols.append(s)
            gen_cell.append(cell_line)
            gen_ic50.append(p)
            modes.append('train')

        plot_and_compare(base_preds, preds, site, cell_line, epoch,
                         learner.model_path, 'train',
                         params['eval_batch_size'])

        # Evaluate on a validation cell line.
        eval_cell_line = np.random.choice(test_omics)
        base_smiles, base_preds = baseline.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], eval_cell_line)
        smiles, preds = learner.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], eval_cell_line)
        plot_and_compare(base_preds, preds, site, eval_cell_line, epoch,
                         learner.model_path, 'test',
                         params['eval_batch_size'])
        gs = [
            s for i, s in enumerate(smiles)
            if preds[i] < learner.ic50_threshold
        ]
        gp = preds[preds < learner.ic50_threshold]
        for p, s in zip(gp, gs):
            gen_mols.append(s)
            gen_cell.append(eval_cell_line)
            gen_ic50.append(p)
            modes.append('test')

        inds = np.argsort(preds)
        for i in inds[:5]:
            logger.info(f'Epoch {epoch:d}, generated {smiles[i]} against '
                        f'{eval_cell_line}.\n Predicted IC50 = {preds[i]}. ')

        # Save results (good molecules!) in DF
        df = pd.DataFrame({
            'cell_line': gen_cell,
            'SMILES': gen_mols,
            'IC50': gen_ic50,
            'mode': modes,
            'tox21': [learner.tox21(s) for s in gen_mols]
        })
        df.to_csv(os.path.join(learner.model_path, 'results',
                               'generated.csv'))
        # Plot loss development
        loss_df = pd.DataFrame({'loss': rl_losses, 'rewards': rewards})
        loss_df.to_csv(learner.model_path +
                       '/results/loss_reward_evolution.csv')
        plot_loss(rl_losses, rewards, params['epochs'], cell_line,
                  learner.model_path, rolling=5, site=site)
"""SMILES decoding from latent space module.""" import sys from itertools import count from typing import Any, List import torch from rdkit import Chem from paccmann_chemistry.models.vae import TeacherVAE from paccmann_chemistry.utils import get_device from paccmann_chemistry.utils.search import SamplingSearch device = get_device() def get_stack_size(size_hint: int = 2) -> int: """Stack size from caller's frame from: https://stackoverflow.com/a/47956089. Args: size_hint: hint for the stack size. Defaults to 2. Returns: size of the stack. """ get_frame = sys._getframe frame = None try: while True: frame = get_frame(size_hint) size_hint *= 2 except ValueError: if frame:
def train(
    epoch,
    model,
    train_loader,
    optimizer,
    scheduler,
    graph_gamma,
    growth_rate=0.0015,
    writer=None,
    verbose=False
):
    """Run one training epoch of the graph-regularized VAE.

    Combines the VAE reconstruction loss (`model.train_step`) with a graph
    loss matching latent distances to shortest-path distances of the batch
    graph, weighted by `graph_gamma`.

    Args:
        epoch: current epoch index (used for tensorboard steps and the
            kl_scale schedule).
        model: wrapper exposing `gru_vae`, `gru_encoder`, `gru_decoder`
            and `train_step`.
        train_loader: yields batches with `.x` (sequences) and
            `.edge_index` (graph edges).
        optimizer: torch optimizer for the model parameters.
        scheduler: LR scheduler, stepped once at epoch end.
        graph_gamma: weight of the graph loss term.
        growth_rate: slope of the sigmoidal kl_scale schedule.
        writer: tensorboard SummaryWriter.
            NOTE(review): defaults to None but is used unconditionally —
            passing None will raise; confirm intended contract.
        verbose: show the tqdm progress bar.
    """
    start_time = time()
    device = get_device()
    # selfies = train_loader.dataset._dataset.selfies
    data_preparation = packed_sequential_data_preparation
    model.gru_vae.to(device)
    model.gru_vae.train()
    input_keep = 1.
    start_index = 2
    end_index = 3
    train_loss = 0
    for _iter, data in tqdm.tqdm(
        enumerate(train_loader), total=len(train_loader),
        disable=(not verbose)
    ):
        seqs = data.x
        batch_size = len(seqs)
        # FIXME? variable batch size in data
        model.gru_decoder._update_batch_size(batch_size)
        model.gru_encoder._update_batch_size(batch_size)

        # Pack sequences into (encoder, decoder, target) triplets.
        encoder_seq, decoder_seq, target_seq = data_preparation(
            seqs, input_keep=input_keep, start_index=start_index,
            end_index=end_index, device=device
        )

        optimizer.zero_grad()
        decoder_loss, mu, logvar, z = model.train_step(
            encoder_seq, decoder_seq, target_seq
        )

        # Compute distances
        # All-pairs shortest paths on the batch graph; unreachable pairs
        # (inf) are zeroed out before the loss.
        graph = nx.from_edgelist(data.edge_index.T.tolist())
        graph.add_nodes_from(list(range(len(seqs))))
        dists = nx.floyd_warshall_numpy(graph)
        dists[np.isinf(dists)] = 0
        dists = torch.tensor(dists).to(device)

        # Sigmoidal annealing factor.
        # NOTE(review): kl_scale is only logged below — it is never applied
        # to the loss; confirm whether that is intended.
        kl_scale = 1 / (
            1 + np.exp((6 - growth_rate * epoch +
                        (_iter / len(train_loader))))
        )
        z = z.squeeze()
        gr_loss = graph_loss(z, dists)
        loss = decoder_loss + graph_gamma * gr_loss
        loss.backward()
        optimizer.step()

        writer.add_scalar(
            'recon_loss', decoder_loss, _iter + epoch * len(train_loader)
        )
        writer.add_scalar(
            'graph_loss', gr_loss, _iter + epoch * len(train_loader)
        )
        writer.add_scalar(
            'loss', loss.item(), _iter + epoch * len(train_loader)
        )
        train_loss += loss.item()
        # NOTE(review): raises ZeroDivisionError when len(train_loader) < 10.
        if _iter % (len(train_loader) // 10) == 0:
            tqdm.tqdm.write(f'{decoder_loss}\t{gr_loss}')
    scheduler.step()
    logger.info(
        f"Learning rate {optimizer.param_groups[0]['lr']}"
        f"\tkl_scale {kl_scale}"
    )
    logger.info(f'{epoch}\t{train_loss/_iter}\t{time()-start_time}')
def main(parser_namespace):
    """Bayesian (GP) optimisation in VAE latent space against the ERG protein.

    Restores a SELFIES VAE generator and a bimodal-MCA affinity predictor,
    then runs 5 rounds of GP optimisation of a combined objective (affinity
    + QED + SA). Per round, the best latent point is decoded into 20 unique
    SMILES which are scored and written to a text report.

    Args:
        parser_namespace: argparse namespace with `affinity_path`,
            `svae_path` and `optimisation_name`.
    """
    # model loading
    disable_rdkit_logging()
    affinity_path = parser_namespace.affinity_path
    svae_path = parser_namespace.svae_path
    svae_weights_path = os.path.join(svae_path, "weights", "best_rec.pt")
    results_file_name = parser_namespace.optimisation_name
    logger.add(results_file_name + ".log", rotation="10 MB")

    svae_params = dict()
    with open(os.path.join(svae_path, "model_params.json"), "r") as f:
        svae_params.update(json.load(f))

    smiles_language = SMILESLanguage.load(
        os.path.join(svae_path, "selfies_language.pkl"))

    # initialize encoder, decoder, testVAE, and GP_generator_MW
    gru_encoder = StackGRUEncoder(svae_params)
    gru_decoder = StackGRUDecoder(svae_params)
    gru_vae = TeacherVAE(gru_encoder, gru_decoder)
    gru_vae.load_state_dict(
        torch.load(svae_weights_path, map_location=get_device()))
    gru_vae._associate_language(smiles_language)
    gru_vae.eval()
    smiles_generator = SmilesGenerator(gru_vae)

    with open(os.path.join(affinity_path, "model_params.json")) as f:
        predictor_params = json.load(f)
    affinity_predictor = MODEL_FACTORY["bimodal_mca"](predictor_params)
    affinity_predictor.load(
        os.path.join(
            affinity_path,
            f"weights/best_{predictor_params.get('p_metric', 'ROC-AUC')}_bimodal_mca.pt",
        ),
        map_location=get_device(),
    )
    affinity_protein_language = ProteinLanguage.load(
        os.path.join(affinity_path, "protein_language.pkl"))
    affinity_smiles_language = SMILESLanguage.load(
        os.path.join(affinity_path, "smiles_language.pkl"))
    affinity_predictor._associate_language(affinity_smiles_language)
    affinity_predictor._associate_language(affinity_protein_language)
    affinity_predictor.eval()

    # ERG (transcriptional regulator) amino-acid sequence used as the fixed
    # target for affinity minimization.
    erg_protein = "MASTIKEALSVVSEDQSLFECAYGTPHLAKTEMTASSSSDYGQTSKMSPRVPQQDWLSQPPARVTIKMECNPSQVNGSRNSPDECSVAKGGKMVGSPDTVGMNYGSYMEEKHMPPPNMTTNERRVIVPADPTLWSTDHVRQWLEWAVKEYGLPDVNILLFQNIDGKELCKMTKDDFQRLTPSYNADILLSHLHYLRETPLPHLTSDDVDKALQNSPRLMHARNTGGAAFIFPNTSVYPEATQRITTRPDLPYEPPRRSAWTGHGHPTPQSKAAQPSPSTVPKTEDQRPQLDPYQILGPTSSRLANPGSGQIQLWQFLLELLSDSSNSSCITWEGTNGEFKMTDPDEVARRWGERKSKPNMNYDKLSRALRYYYDKNIMTKVHGKRYAYKFDFHGIAQALQPHPPESSLYKYPSDLPYMGSYHAHPQKMNFVAPHPPALPVTSSSFFAAPNPYWNSPTGGIYPNTRLPTSHMPSHLGTYY"

    # Combined objective: affinity, QED and SA minimization with fixed
    # weights.
    target_minimization_function = AffinityMinimization(
        smiles_generator, 30, affinity_predictor, erg_protein)
    qed_function = QEDMinimization(smiles_generator, 30)
    sa_function = SAMinimization(smiles_generator, 30)
    combined_minimization = CombinedMinimization(
        [target_minimization_function, qed_function, sa_function], 1,
        [0.75, 1, 0.5])
    target_optimizer = GPOptimizer(combined_minimization.evaluate)

    # skopt gp_minimize settings over the 256-dim latent box [-5, 5].
    params = dict(
        dimensions=[(-5.0, 5.0)] * 256,
        acq_func="EI",
        n_calls=20,
        n_initial_points=19,
        initial_point_generator="random",
        random_state=1234,
    )
    logger.info("Optimisation parameters: {params}", params=params)

    # optimisation
    for j in range(5):
        res = target_optimizer.optimize(params)
        latent_point = torch.tensor([[res.x]])

        # Persist the optimised latent point for this round.
        with open(results_file_name + "_LP" + str(j + 1) + ".pkl",
                  "wb") as f:
            pickle.dump(latent_point, f, protocol=2)

        # Decode until 20 unique SMILES are collected.
        smile_set = set()
        while len(smile_set) < 20:
            smiles = smiles_generator.generate_smiles(
                latent_point.repeat(1, 30, 1))
            smile_set.update(set(smiles))
        smile_set = list(smile_set)

        # Re-score the generated molecules with the affinity predictor.
        pad_smiles_predictor = LeftPadding(
            affinity_predictor.smiles_padding_length,
            affinity_predictor.smiles_language.padding_index,
        )
        to_tensor = ToTensor(get_device())
        smiles_num = [
            torch.unsqueeze(
                to_tensor(
                    pad_smiles_predictor(
                        affinity_predictor.smiles_language.
                        smiles_to_token_indexes(smile))),
                0,
            ) for smile in smile_set
        ]
        smiles_tensor = torch.cat(smiles_num, dim=0)

        pad_protein_predictor = LeftPadding(
            affinity_predictor.protein_padding_length,
            affinity_predictor.protein_language.padding_index,
        )
        protein_num = torch.unsqueeze(
            to_tensor(
                pad_protein_predictor(
                    affinity_predictor.protein_language.
                    sequence_to_token_indexes([erg_protein]))),
            0,
        )
        protein_num = protein_num.repeat(len(smile_set), 1)

        with torch.no_grad():
            pred, _ = affinity_predictor(smiles_tensor, protein_num)
        affinities = torch.squeeze(pred, 1).numpy()

        sas = SAS()
        sa_scores = [sas(smile) for smile in smile_set]
        qed_scores = [qed(Chem.MolFromSmiles(smile)) for smile in smile_set]

        # save to file
        file = results_file_name + str(j + 1) + ".txt"
        logger.info("creating {file}", file=file)
        with open(file, "w") as f:
            f.write(
                f'{"point":<10}{"Affinity":<10}{"QED":<10}{"SA":<10}{"smiles":<15}\n'
            )
            # NOTE(review): iterates exactly 20 rows; smile_set may hold
            # more than 20 entries — extras are silently dropped.
            for i in range(20):
                dat = [
                    i + 1, affinities[i], qed_scores[i], sa_scores[i],
                    smile_set[i]
                ]
                f.write(
                    f'{dat[0]:<10}{"%.3f"%dat[1]:<10}{"%.3f"%dat[2]:<10}{"%.3f"%dat[3]:<10}{dat[4]:<15}\n'
                )
def main(parser_namespace):
    """Encode a SMILES dataset into VAE latent codes and dump them to TSV.

    Restores a trained VAE from `model_path`, streams the dataset at
    `data_path` through the encoder (word dropout disabled) and writes one
    `index\\tmu` line per sample to `samples_latent_code.tsv` next to the
    data file. Samples containing tokens unseen at training time are mapped
    to the unknown token and retried.

    Args:
        parser_namespace: argparse namespace with `model_path` and
            `data_path`.
    """
    disable_rdkit_logging()
    model_path = parser_namespace.model_path
    data_path = parser_namespace.data_path

    weights_path = os.path.join(model_path, 'weights', 'best_loss.pt')
    device = get_device()

    # read the params json
    params = dict()
    with open(os.path.join(model_path, 'model_params.json'), 'r') as f:
        params.update(json.load(f))
    # Force single-sample batches for encoding.
    params['batch_size'] = 1

    # Load SMILES language
    smiles_language = SMILESLanguage.load(
        os.path.join(model_path, 'selfies_language.pkl'))

    data_preparation = get_data_preparation(params.get('batch_mode'))
    # NOTE(review): device is computed twice (see above) — redundant.
    device = get_device()

    dataset = SMILESDataset(
        data_path,
        smiles_language=smiles_language,
        padding=False,
        selfies=params.get('selfies', False),
        add_start_and_stop=params.get('add_start_stop_token', True),
        augment=False,  #params.get('augment_smiles', False),
        canonical=params.get('canonical', False),
        kekulize=params.get('kekulize', False),
        all_bonds_explicit=params.get('all_bonds_explicit', False),
        all_hs_explicit=params.get('all_hs_explicit', False),
        remove_bonddir=params.get('remove_bonddir', False),
        remove_chirality=params.get('remove_chirality', False),
        backend='lazy',
        device=device)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=params.get('batch_size', 64),
        collate_fn=collate_fn,
        drop_last=True,
        shuffle=True,
        pin_memory=params.get('pin_memory', True),
        num_workers=params.get('num_workers', 8))

    # initialize encoder and decoder
    gru_encoder = StackGRUEncoder(params).to(device)
    gru_decoder = StackGRUDecoder(params).to(device)
    gru_vae = TeacherVAE(gru_encoder, gru_decoder).to(device)

    logger.info('\n****MODEL SUMMARY***\n')
    for name, parameter in gru_vae.named_parameters():
        logger.info(f'Param {name}, shape:\t{parameter.shape}')
    total_params = sum(p.numel() for p in gru_vae.parameters())
    logger.info(f'Total # params: {total_params}')

    gru_vae.load_state_dict(torch.load(weights_path, map_location=device))
    # Updating the vocab size will break the model
    params.update({
        # 'vocab_size': smiles_language.number_of_tokens,
        'pad_index': smiles_language.padding_index
    })  # yapf:disable
    # if params.get('embedding', 'learned') == 'one_hot':
    #     params.update({'embedding_size': params['vocab_size']})

    # train for n_epoch epochs
    logger.info(
        'Model creation, loading and data processing done. Evaluation starts.')

    gru_vae.eval()
    gru_vae.to(device)

    counter = 0
    with torch.no_grad():
        latent_code = []
        from tqdm import tqdm
        for batch in tqdm(dataloader, total=len(dataloader)):
            (encoder_seq, _,
             _) = data_preparation(batch,
                                   input_keep=0.,
                                   start_index=2,
                                   end_index=3,
                                   device=device)
            try:
                mu, logvar = gru_vae.encode(encoder_seq)
            except RuntimeError:
                # Substitute any new tokens by "<UNK>" tokens
                new_seq = []
                # NOTE(review): `lenghts` is a typo kept from the original.
                padd_encoder_seq, lenghts = (
                    torch.nn.utils.rnn.pad_packed_sequence(encoder_seq,
                                                           batch_first=True))
                for seq, _len in zip(padd_encoder_seq, lenghts):
                    seq = seq[:_len]
                    if any([x >= params['vocab_size'] for x in seq]):
                        seq = torch.tensor([
                            x if x < params['vocab_size'] else
                            smiles_language.unknown_index
                            for x in seq.tolist()
                        ]).short()
                        failed_smiles = smiles_language.selfies_to_smiles(
                            smiles_language.token_indexes_to_smiles(
                                seq.tolist()))
                        logger.warning(
                            f'Out of bounds sample: ~{counter}\t{failed_smiles}'
                        )
                    new_seq.append(seq)
                if new_seq:
                    # Pad the batch back to full size before re-encoding.
                    for _ in range(params['batch_size'] - len(new_seq)):
                        new_seq.append(torch.ones_like(new_seq[-1]))
                    (encoder_seq, _,
                     _) = data_preparation(new_seq,
                                           input_keep=0.,
                                           start_index=2,
                                           end_index=3,
                                           device=device)
                    mu, logvar = gru_vae.encode(encoder_seq)
            for _mu in mu.tolist():
                latent_code.append([counter, _mu])
                counter += 1

    LATENT_CODE_PATH = os.path.join(os.path.dirname(data_path),
                                    'samples_latent_code.tsv')
    with open(LATENT_CODE_PATH, 'w') as f:
        for i, mu in latent_code:
            f.write(f'{i}\t{",".join([str(x) for x in mu[0]])}\n')
def main(parser_namespace):
    """End-to-end training driver for the SMILES/SELFIES Teacher VAE.

    Reads hyperparameters from a JSON file, builds train/test SMILES
    datasets and loaders, derives vocabulary-dependent parameters from the
    (possibly freshly built) SMILES language, constructs the VAE and runs
    `train_vae` for `params['epochs'] + 1` epochs, tracking the best
    loss/reconstruction/KLD. All exceptions are logged at the top level.

    Args:
        parser_namespace: argparse namespace with `params_filepath`,
            `train_smiles_filepath`, `test_smiles_filepath`,
            `smiles_language_filepath` (or the string 'none'),
            `model_path` and `training_name`.
    """
    try:
        device = get_device()
        disable_rdkit_logging()

        # read the params json
        params = dict()
        with open(parser_namespace.params_filepath) as f:
            params.update(json.load(f))

        # get params
        train_smiles_filepath = parser_namespace.train_smiles_filepath
        test_smiles_filepath = parser_namespace.test_smiles_filepath
        # 'none' (any case) means: build the language from the data.
        smiles_language_filepath = (
            parser_namespace.smiles_language_filepath
            if parser_namespace.smiles_language_filepath.lower() != 'none'
            else None)
        model_path = parser_namespace.model_path
        training_name = parser_namespace.training_name

        writer = SummaryWriter(f'logs/{training_name}')
        logger.info(f'Model with name {training_name} starts.')

        model_dir = os.path.join(model_path, training_name)
        log_path = os.path.join(model_dir, 'logs')
        val_dir = os.path.join(log_path, 'val_logs')
        os.makedirs(os.path.join(model_dir, 'weights'), exist_ok=True)
        os.makedirs(os.path.join(model_dir, 'results'), exist_ok=True)
        os.makedirs(log_path, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)

        # Load SMILES language
        smiles_language = None
        if smiles_language_filepath is not None:
            smiles_language = SMILESLanguage.load(smiles_language_filepath)

        logger.info(f'Smiles filepath: {train_smiles_filepath}')

        # create SMILES eager dataset
        smiles_train_data = SMILESDataset(
            train_smiles_filepath,
            smiles_language=smiles_language,
            padding=False,
            selfies=params.get('selfies', False),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=params.get('augment_smiles', False),
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            backend='lazy',
            device=device,
        )
        smiles_test_data = SMILESDataset(
            test_smiles_filepath,
            smiles_language=smiles_language,
            padding=False,
            selfies=params.get('selfies', False),
            add_start_and_stop=params.get('add_start_stop_token', True),
            augment=params.get('augment_smiles', False),
            canonical=params.get('canonical', False),
            kekulize=params.get('kekulize', False),
            all_bonds_explicit=params.get('all_bonds_explicit', False),
            all_hs_explicit=params.get('all_hs_explicit', False),
            remove_bonddir=params.get('remove_bonddir', False),
            remove_chirality=params.get('remove_chirality', False),
            backend='lazy',
            device=device,
        )

        # Persist the language: freshly built ones next to the model root,
        # preloaded ones inside the model directory.
        if smiles_language_filepath is None:
            smiles_language = smiles_train_data.smiles_language
            smiles_language.save(
                os.path.join(model_path, f'{training_name}.lang'))
        else:
            smiles_language_filename = os.path.basename(
                smiles_language_filepath)
            smiles_language.save(
                os.path.join(model_dir, smiles_language_filename))

        # Vocabulary-dependent hyperparameters.
        params.update({
            'vocab_size': smiles_language.number_of_tokens,
            'pad_index': smiles_language.padding_index
        })

        vocab_dict = smiles_language.index_to_token
        params.update({
            'start_index':
                list(vocab_dict.keys())[list(
                    vocab_dict.values()).index('<START>')],
            'end_index':
                list(vocab_dict.keys())[list(
                    vocab_dict.values()).index('<STOP>')]
        })
        if params.get('embedding', 'learned') == 'one_hot':
            params.update({'embedding_size': params['vocab_size']})

        with open(os.path.join(model_dir, 'model_params.json'), 'w') as fp:
            json.dump(params, fp)

        # create DataLoaders
        train_data_loader = torch.utils.data.DataLoader(
            smiles_train_data,
            batch_size=params.get('batch_size', 64),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=True,
            pin_memory=params.get('pin_memory', True),
            num_workers=params.get('num_workers', 8))
        test_data_loader = torch.utils.data.DataLoader(
            smiles_test_data,
            batch_size=params.get('batch_size', 64),
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=True,
            pin_memory=params.get('pin_memory', True),
            num_workers=params.get('num_workers', 8))

        # initialize encoder and decoder
        gru_encoder = StackGRUEncoder(params).to(device)
        gru_decoder = StackGRUDecoder(params).to(device)
        gru_vae = TeacherVAE(gru_encoder, gru_decoder).to(device)

        # TODO I haven't managed to get this to work. I will leave it here
        # if somewant (or future me) wants to give it a look and get the
        # tensorboard graph to work
        # if writer and False:
        #     gru_vae.set_batch_mode('padded')
        #     dummy_input = torch.ones(smiles_train_data[0].shape)
        #     dummy_input = dummy_input.unsqueeze(0).to(device)
        #     writer.add_graph(gru_vae, (dummy_input, dummy_input, dummy_input))
        #     gru_vae.set_batch_mode(params.get('batch_mode'))

        logger.info('\n****MODEL SUMMARY***\n')
        for name, parameter in gru_vae.named_parameters():
            logger.info(f'Param {name}, shape:\t{parameter.shape}')
        total_params = sum(p.numel() for p in gru_vae.parameters())
        logger.info(f'Total # params: {total_params}')

        # Best-so-far metrics (value and the epoch it occurred in).
        loss_tracker = {
            'test_loss_a': 10e4,
            'test_rec_a': 10e4,
            'test_kld_a': 10e4,
            'ep_loss': 0,
            'ep_rec': 0,
            'ep_kld': 0
        }

        # train for n_epoch epochs
        logger.info(
            'Model creation and data processing done, Training starts.')
        decoder_search = SEARCH_FACTORY[
            params.get('decoder_search', 'sampling')
        ](
            temperature=params.get('temperature', 1.),
            beam_width=params.get('beam_width', 3),
            top_tokens=params.get('top_tokens', 5)
        )  # yapf: disable

        if writer:
            pparams = params.copy()
            pparams['training_file'] = train_smiles_filepath
            pparams['test_file'] = test_smiles_filepath
            pparams['language_file'] = smiles_language_filepath
            pparams['model_path'] = model_path
            # BUGFIX: iterate pparams (not params) so the file-path entries
            # added above survive the None -> 'N.A.' sanitization required
            # by add_hparams.
            pparams = {
                k: v if v is not None else 'N.A.'
                for k, v in pparams.items()
            }
            pparams['training_name'] = training_name

            from pprint import pprint
            pprint(pparams)
            writer.add_hparams(hparam_dict=pparams, metric_dict={})

        for epoch in range(params['epochs'] + 1):
            t = time()
            loss_tracker = train_vae(
                epoch,
                gru_vae,
                train_data_loader,
                test_data_loader,
                smiles_language,
                model_dir,
                search=decoder_search,
                optimizer=params.get('optimizer', 'adadelta'),
                lr=params['learning_rate'],
                kl_growth=params['kl_growth'],
                input_keep=params['input_keep'],
                test_input_keep=params['test_input_keep'],
                generate_len=params['generate_len'],
                log_interval=params['log_interval'],
                save_interval=params['save_interval'],
                eval_interval=params['eval_interval'],
                loss_tracker=loss_tracker,
                logger=logger,
                # writer=writer,
                batch_mode=params.get('batch_mode'))
            logger.info(f'Epoch {epoch}, took {time() - t:.1f}.')

        logger.info('OVERALL: \t Best loss = {0:.4f} in Ep {1}, '
                    'best Rec = {2:.4f} in Ep {3}, '
                    'best KLD = {4:.4f} in Ep {5}'.format(
                        loss_tracker['test_loss_a'],
                        loss_tracker['ep_loss'],
                        loss_tracker['test_rec_a'],
                        loss_tracker['ep_rec'],
                        loss_tracker['test_kld_a'],
                        loss_tracker['ep_kld']))
        logger.info('Training done, shutting down.')
    except Exception:
        logger.exception('Exception occurred while running train_vae.py.')
def train(
    epoch, model, train_loader, optimizer, scheduler, writer=None,
    verbose=False
):
    """Run one training epoch of the graph-infomax style model.

    Packs the sequence batch, runs the model on (sequences, edge_index),
    computes the model's own contrastive loss over (pos_z, neg_z, summary)
    and performs one optimizer step per batch.

    Args:
        epoch: current epoch index (for tensorboard steps and logging).
        model: module exposing `encoder.update_batch_size`, a forward over
            (packed sequences, edge_index) and a `loss` method.
        train_loader: yields batches with `.x` (sequences) and
            `.edge_index` (graph edges).
        optimizer: torch optimizer.
        scheduler: optional LR scheduler, stepped once at epoch end.
        writer: tensorboard SummaryWriter.
            NOTE(review): defaults to None but is used unconditionally —
            passing None will raise; confirm intended contract.
        verbose: show the tqdm progress bar.
    """
    start_time = time()
    device = get_device()
    # selfies = train_loader.dataset._dataset.selfies
    data_preparation = packed_sequential_data_preparation
    model.to(device)
    model.train()
    input_keep = 1.
    start_index = 2
    end_index = 3
    train_loss = 0
    for _iter, data in tqdm.tqdm(
        enumerate(train_loader), total=len(train_loader),
        disable=(not verbose)
    ):
        # Seqs are list of strings, so they must be first preprocessed
        # and the data has then to be moved .to(device)
        seqs = data.x
        batch_size = len(seqs)
        # FIXME? variable batch size in data
        model.encoder.update_batch_size(batch_size)

        # Only the encoder input is needed; decoder/target are discarded.
        encoder_seq, _, _ = data_preparation(
            seqs, input_keep=input_keep, start_index=start_index,
            end_index=end_index, device=device
        )

        optimizer.zero_grad()
        pos_z, neg_z, summary = model(encoder_seq,
                                      data.edge_index.to(device))
        loss = model.loss(pos_z, neg_z, summary)
        loss.backward()
        optimizer.step()

        writer.add_scalar(
            'loss', loss.item(), _iter + epoch * len(train_loader)
        )
        train_loss += loss.item()
        # NOTE(review): raises ZeroDivisionError when len(train_loader) < 10.
        if _iter % (len(train_loader) // 10) == 0:
            tqdm.tqdm.write(f'{loss}')
    if scheduler is not None:
        scheduler.step()
    logger.info(f"Learning rate {optimizer.param_groups[0]['lr']}")
    logger.info(f'Epoch: {epoch}\t{train_loss/_iter}\t{time()-start_time}')
def main(*, parser_namespace):
    """RL training of a molecule generator against a protein-affinity predictor.

    Restores a pretrained SMILES VAE generator, a protein encoder and a
    bimodal-MCA affinity predictor, then optimizes a fresh generator copy
    with policy gradients (``ReinforceProtein``) on all proteins except one
    held-out test protein, evaluating against unbiased predictions each
    epoch and writing molecules/ratios to CSV.

    Args:
        parser_namespace: argparse namespace; used as fallback for entries
            missing from the params JSON (JSON values take precedence).

    NOTE(review): formatting reconstructed from a collapsed source — verify
    statement grouping inside the epoch loop against version control.
    """
    disable_rdkit_logging()

    # read the params json
    params = dict()
    with open(parser_namespace.params_path) as f:
        params.update(json.load(f))

    # get params, json args take precedence
    mol_model_path = params.get('mol_model_path',
                                parser_namespace.mol_model_path)
    protein_model_path = params.get('protein_model_path',
                                    parser_namespace.protein_model_path)
    affinity_model_path = params.get('affinity_model_path',
                                     parser_namespace.affinity_model_path)
    protein_data_path = params.get('protein_data_path',
                                   parser_namespace.protein_data_path)
    model_name = params.get(
        'model_name', parser_namespace.model_name
    )  # yapf: disable
    test_id = int(params.get(
        'test_protein_id', parser_namespace.test_protein_id
    ))  # yapf: disable
    unbiased_preds_path = params.get(
        'unbiased_predictions_path',
        parser_namespace.unbiased_predictions_path
    )  # yapf: disable
    model_name += '_' + str(test_id)
    logger.info(f'Model with name {model_name} starts.')

    # passing optional paths to params to possibly update_reward_fn
    optional_reward_args = [
        'tox21_path', 'organdb_path', 'site', 'clintox_path', 'sider_path'
    ]
    for arg in optional_reward_args:
        if parser_namespace.__dict__[arg]:
            # json still has presedence
            params[arg] = params.get(arg, parser_namespace.__dict__[arg])

    # Load protein sequence data
    if protein_data_path.endswith('.smi'):
        protein_df = read_smi(protein_data_path, names=['Sequence'])
    elif protein_data_path.endswith('.csv'):
        protein_df = pd.read_csv(protein_data_path, index_col='entry_name')
    else:
        raise TypeError(
            f"{protein_data_path.split('.')[-1]} files are not supported.")

    protein_test_name = protein_df.iloc[test_id].name
    logger.info(f'Test protein is {protein_test_name}')

    # Restore SMILES Model
    with open(os.path.join(mol_model_path, 'model_params.json')) as f:
        mol_params = json.load(f)

    gru_encoder = StackGRUEncoder(mol_params)
    gru_decoder = StackGRUDecoder(mol_params)
    generator = TeacherVAE(gru_encoder, gru_decoder)
    generator.load(os.path.join(
        mol_model_path,
        f"weights/best_{params.get('smiles_metric', 'rec')}.pt"),
        map_location=get_device())
    # Load languages
    generator_smiles_language = SMILESLanguage.load(
        os.path.join(mol_model_path, 'selfies_language.pkl'))
    generator._associate_language(generator_smiles_language)

    # Restore protein model
    with open(os.path.join(protein_model_path, 'model_params.json')) as f:
        protein_params = json.load(f)

    # Define network
    protein_encoder = ENCODER_FACTORY['dense'](protein_params)
    protein_encoder.load(os.path.join(
        protein_model_path,
        f"weights/best_{params.get('omics_metric','both')}_encoder.pt"),
        map_location=get_device())
    protein_encoder.eval()

    # Restore affinity predictor
    with open(os.path.join(affinity_model_path, 'model_params.json')) as f:
        predictor_params = json.load(f)
    predictor = MODEL_FACTORY['bimodal_mca'](predictor_params)
    predictor.load(os.path.join(
        affinity_model_path,
        f"weights/best_{params.get('p_metric', 'ROC-AUC')}_bimodal_mca.pt"),
        map_location=get_device())
    predictor.eval()

    # Load languages
    affinity_smiles_language = SMILESLanguage.load(
        os.path.join(affinity_model_path, 'smiles_language.pkl'))
    affinity_protein_language = ProteinLanguage.load(
        os.path.join(affinity_model_path, 'protein_language.pkl'))
    predictor._associate_language(affinity_smiles_language)
    predictor._associate_language(affinity_protein_language)

    # Specifies the baseline model used for comparison
    unbiased_preds = np.array(
        pd.read_csv(
            os.path.join(unbiased_preds_path, protein_test_name + '.csv')
        )['affinity'].values
    )  # yapf: disable

    # Create a fresh model that will be optimized
    gru_encoder_rl = StackGRUEncoder(mol_params)
    gru_decoder_rl = StackGRUDecoder(mol_params)
    generator_rl = TeacherVAE(gru_encoder_rl, gru_decoder_rl)
    generator_rl.load(os.path.join(
        mol_model_path, f"weights/best_{params.get('metric', 'rec')}.pt"),
        map_location=get_device())
    generator_rl.eval()
    # Load languages
    generator_rl._associate_language(generator_smiles_language)

    protein_encoder_rl = ENCODER_FACTORY['dense'](protein_params)
    protein_encoder_rl.load(os.path.join(
        protein_model_path,
        f"weights/best_{params.get('metric', 'both')}_encoder.pt"),
        map_location=get_device())
    protein_encoder_rl.eval()
    model_folder_name = model_name
    learner = ReinforceProtein(generator_rl, protein_encoder_rl, predictor,
                               protein_df, params, model_folder_name, logger)

    biased_ratios, tox_ratios = [], []
    rewards, rl_losses = [], []
    gen_mols, gen_prot, gen_affinity, mode = [], [], [], []

    logger.info(f'Model stored at {learner.model_path}')

    for epoch in range(1, params['epochs'] + 1):

        for step in range(1, params['steps']):

            # Randomly sample a protein (never the held-out test protein).
            protein_name = np.random.choice(protein_df.index)
            while protein_name == protein_test_name:
                protein_name = np.random.choice(protein_df.index)
            logger.info(f'Current train protein: {protein_name}')

            rew, loss = learner.policy_gradient(protein_name, epoch,
                                                params['batch_size'])
            logger.info(
                f"Epoch {epoch:d}/{params['epochs']:d}, step {step:d}/"
                f"{params['steps']:d}\t loss={loss:.2f}, mean rew={rew:.2f}")

            rewards.append(rew.item())
            rl_losses.append(loss)

        # Save model
        if epoch % 10 == 0:
            learner.save(f'gen_{epoch}.pt', f'enc_{epoch}.pt')
        logger.info(f'EVAL protein: {protein_test_name}')

        smiles, preds = (learner.generate_compounds_and_evaluate(
            epoch, params['eval_batch_size'], protein_test_name))
        # Keep only molecules with predicted binding probability > 0.5.
        gs = [s for i, s in enumerate(smiles) if preds[i] > 0.5]
        gp = preds[preds > 0.5]
        for p, s in zip(gp, gs):
            gen_mols.append(s)
            gen_prot.append(protein_test_name)
            gen_affinity.append(p)
            mode.append('eval')

        inds = np.argsort(gp)[::-1]
        for i in inds[:5]:
            logger.info(f'Epoch {epoch:d}, generated {gs[i]} against '
                        f'{protein_test_name}.\n Predicted IC50 = {gp[i]}. ')

        plot_and_compare_proteins(unbiased_preds, preds, protein_test_name,
                                  epoch, learner.model_path, 'train',
                                  params['eval_batch_size'])
        biased_ratios.append(
            np.round(100 * (np.sum(preds > 0.5) / len(preds)), 1))
        all_toxes = np.array([learner.tox21(s) for s in smiles])
        tox_ratios.append(
            np.round(100 * (np.sum(all_toxes == 1.) / len(all_toxes)), 1))
        logger.info(f'Percentage of non-toxic compounds {tox_ratios[-1]}')

        toxes = [learner.tox21(s) for s in gen_mols]
        # Save results (good molecules!) in DF
        df = pd.DataFrame({
            'protein': gen_prot,
            'SMILES': gen_mols,
            'Binding probability': gen_affinity,
            'mode': mode,
            'Tox21': toxes
        })
        df.to_csv(os.path.join(learner.model_path, 'results',
                               'generated.csv'))
        # Plot loss development
        loss_df = pd.DataFrame({'loss': rl_losses, 'rewards': rewards})
        loss_df.to_csv(learner.model_path +
                       '/results/loss_reward_evolution.csv')
        plot_loss(rl_losses, rewards, params['epochs'], protein_name,
                  learner.model_path, rolling=5)
        pd.DataFrame({
            'efficacy_ratio': biased_ratios,
            'tox_ratio': tox_ratios
        }).to_csv(learner.model_path + '/results/ratios.csv')