import gc

import pandas as pd
import torch
import torch.nn as nn
from rdkit import Chem
from rdkit.Chem import AllChem
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local helpers. The module names below are assumptions based on the
# REINVENT-style layout this code follows and may need adjusting to the
# actual repository structure.
from data_structs import MolData, Vocabulary
from model import RNN
from utils import decrease_learning_rate, unique, cano_smi_file


def Generator(restore_path, voc_path, csv_num, gen_num, mol_num):
    """Sample SMILES from a trained RNN and write the valid ones to CSV files.

    Args:
        restore_path: checkpoint to restore the RNN from (None to skip).
        voc_path: location of the vocabulary file.
        csv_num: number of CSV files to generate.
        gen_num: number of SMILES to sample per file.
        mol_num: number of sampled SMILES to print per file.
    """
    restore_from = restore_path

    # Read vocabulary from a file
    voc = Vocabulary(init_from_file=voc_path)

    for n in range(csv_num):
        Prior = RNN(voc)
        if restore_from:
            Prior.rnn.load_state_dict(torch.load(restore_from))

        seqs, likelihood, _ = Prior.sample(gen_num)
        valid = 0
        smiles = []
        val_smi = []
        for i, seq in enumerate(seqs.cpu().numpy()):
            smile = voc.decode(seq)
            smiles.append(smile)
            if Chem.MolFromSmiles(smile):
                valid += 1
                val_smi.append(smile)
            if i < mol_num:
                print(smile)

        Val_s = pd.DataFrame(data=val_smi, columns=['smiles'])
        Val_s.to_csv('./model/data_gen_' + str(n) + '.csv', index=False)
        print(valid)
        gc.collect()
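# Example invocation (a sketch: the checkpoint path is an assumption, and the
# ./model/ directory must exist before the CSVs can be written):
#
#   Generator(restore_path='data/Prior.ckpt', voc_path='data/Voc',
#             csv_num=3, gen_num=1000, mol_num=5)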
def pretrain(restore_from=None):
    """Trains the Prior RNN"""

    # Read vocabulary from a file
    voc = Vocabulary(init_from_file="data/Voc")

    # Create a Dataset from a SMILES file
    moldata = MolData("data/mols_filtered.smi", voc)
    data = DataLoader(moldata, batch_size=128, shuffle=True, drop_last=True,
                      collate_fn=MolData.collate_fn)

    Prior = RNN(voc)

    # Can restore from a saved RNN
    if restore_from:
        Prior.rnn.load_state_dict(torch.load(restore_from))

    optimizer = torch.optim.Adam(Prior.rnn.parameters(), lr=0.001)
    for epoch in range(1, 6):
        # When training on a few million compounds, this model converges
        # in a few epochs or even faster. If the model size is increased,
        # it's probably a good idea to check the loss against an external
        # set of validation SMILES to make sure we don't overfit too much.
        for step, batch in tqdm(enumerate(data), total=len(data)):

            # Sample from DataLoader
            seqs = batch.long()

            # Calculate loss
            log_p, _ = Prior.likelihood(seqs)
            loss = -log_p.mean()

            # Calculate gradients and take a step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Every 500 steps we decrease the learning rate and print some information
            if step % 500 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write("*" * 50)
                tqdm.write("Epoch {:3d} step {:3d} loss: {:5.2f}\n".format(
                    epoch, step, loss.item()))
                seqs, likelihood, _ = Prior.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(seqs)))
                tqdm.write("*" * 50 + "\n")
def train_model():
    """Do transfer learning for generating SMILES"""
    voc = Vocabulary(init_from_file='data/Voc')
    cano_smi_file('refined_smii.csv', 'refined_smii_cano.csv')
    moldata = MolData('refined_smii_cano.csv', voc)
    # Monomers 67 and 180 were removed because of the unseen [C-] in voc.
    # DAs containing [se] [SiH2] [n] removed: 38 molecules.
    data = DataLoader(moldata, batch_size=64, shuffle=True, drop_last=False,
                      collate_fn=MolData.collate_fn)

    transfer_model = RNN(voc)
    if torch.cuda.is_available():
        transfer_model.rnn.load_state_dict(torch.load('data/Prior.ckpt'))
    else:
        transfer_model.rnn.load_state_dict(
            torch.load('data/Prior.ckpt',
                       map_location=lambda storage, loc: storage))

    # for param in transfer_model.rnn.parameters():
    #     param.requires_grad = False

    optimizer = torch.optim.Adam(transfer_model.rnn.parameters(), lr=0.001)

    for epoch in range(1, 10):
        for step, batch in tqdm(enumerate(data), total=len(data)):
            seqs = batch.long()
            log_p, _ = transfer_model.likelihood(seqs)
            loss = -log_p.mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 5 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write('*' * 50)
                tqdm.write("Epoch {:3d} step {:3d} loss: {:5.2f}\n".format(
                    epoch, step, loss.item()))
                seqs, likelihood, _ = transfer_model.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(seqs)))
                tqdm.write("*" * 50 + '\n')

        # Checkpoint after each epoch, plus a final save below
        torch.save(transfer_model.rnn.state_dict(), "data/transfer_model2.ckpt")

    torch.save(transfer_model.rnn.state_dict(), "data/transfer_modelw.ckpt")
def train_model(voc_dir, smi_dir, prior_dir, tf_dir, tf_process_dir,
                freeze=False):
    """Transfer learning on target molecules using the SMILES structures.

    Args:
        voc_dir: location of the vocabulary
        smi_dir: location of the SMILES file used for transfer learning
        prior_dir: location of the prior-trained model used to initialize
            transfer learning
        tf_dir: location to save the transfer learning model
        tf_process_dir: location to save the SMILES sampled while doing
            transfer learning
        freeze: Bool. If True, all parameters in the RNN will be frozen
            except for the last linear layer during transfer learning.

    Returns:
        None
    """
    voc = Vocabulary(init_from_file=voc_dir)
    # cano_smi_file('all_smi_refined.csv', 'all_smi_refined_cano.csv')
    moldata = MolData(smi_dir, voc)
    # Monomers 67 and 180 were removed because of the unseen [C-] in voc.
    # DAs containing [C] removed: 43 molecules in 5356; Ge removed: 154 in 5356;
    # [c] removed: 4 in 5356; [S] removed: 1 molecule in 5356.
    data = DataLoader(moldata, batch_size=64, shuffle=True, drop_last=False,
                      collate_fn=MolData.collate_fn)

    transfer_model = RNN(voc)

    # If freeze=True, freeze all parameters except those in the linear layer
    if freeze:
        for param in transfer_model.rnn.parameters():
            param.requires_grad = False
        transfer_model.rnn.linear = nn.Linear(512, voc.vocab_size)

    if torch.cuda.is_available():
        transfer_model.rnn.load_state_dict(torch.load(prior_dir))
    else:
        transfer_model.rnn.load_state_dict(
            torch.load(prior_dir, map_location=lambda storage, loc: storage))

    optimizer = torch.optim.Adam(transfer_model.rnn.parameters(), lr=0.0005)

    smi_lst = []
    epoch_lst = []
    for epoch in range(1, 11):
        for step, batch in tqdm(enumerate(data), total=len(data)):
            seqs = batch.long()
            log_p, _ = transfer_model.likelihood(seqs)
            loss = -log_p.mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 80 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write('*' * 50)
                tqdm.write("Epoch {:3d} step {:3d} loss: {:5.2f}\n".format(
                    epoch, step, loss.item()))
                seqs, likelihood, _ = transfer_model.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(seqs)))
                tqdm.write("*" * 50 + '\n')
                torch.save(transfer_model.rnn.state_dict(), tf_dir)

        # Sample at the end of each epoch to track how the model evolves
        seqs, likelihood, _ = transfer_model.sample(1024)
        valid = 0
        for i, seq in enumerate(seqs.cpu().numpy()):
            smile = voc.decode(seq)
            if Chem.MolFromSmiles(smile):
                try:
                    AllChem.GetMorganFingerprintAsBitVect(
                        Chem.MolFromSmiles(smile), 2, 1024)
                    valid += 1
                    smi_lst.append(smile)
                    epoch_lst.append(epoch)
                except Exception:
                    continue

    torch.save(transfer_model.rnn.state_dict(), tf_dir)

    transfer_process_df = pd.DataFrame(columns=['SMILES', 'Epoch'])
    transfer_process_df['SMILES'] = pd.Series(data=smi_lst)
    transfer_process_df['Epoch'] = pd.Series(data=epoch_lst)
    transfer_process_df.to_csv(tf_process_dir)
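# Example invocation (a sketch: all paths below are hypothetical placeholders
# for the vocabulary, target SMILES, prior checkpoint, and output locations):
#
#   train_model(voc_dir='data/Voc', smi_dir='data/target_smiles.csv',
#               prior_dir='data/Prior.ckpt', tf_dir='data/transfer_model.ckpt',
#               tf_process_dir='data/transfer_process.csv', freeze=True)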
def sample_smiles(voc_dir, nums, outfn, tf_dir, until=False):
    """Sample SMILES using the transferred model"""
    voc = Vocabulary(init_from_file=voc_dir)
    transfer_model = RNN(voc)
    output = open(outfn, 'w')

    if torch.cuda.is_available():
        transfer_model.rnn.load_state_dict(torch.load(tf_dir))
    else:
        transfer_model.rnn.load_state_dict(
            torch.load(tf_dir, map_location=lambda storage, loc: storage))

    # Inference only: no gradients needed
    for param in transfer_model.rnn.parameters():
        param.requires_grad = False

    if not until:
        seqs, likelihood, _ = transfer_model.sample(nums)
        valid = 0
        double_br = 0
        unique_idx = unique(seqs)
        seqs = seqs[unique_idx]
        for i, seq in enumerate(seqs.cpu().numpy()):
            smile = voc.decode(seq)
            if Chem.MolFromSmiles(smile):
                try:
                    AllChem.GetMorganFingerprintAsBitVect(
                        Chem.MolFromSmiles(smile), 2, 1024)
                    valid += 1
                    output.write(smile + '\n')
                except Exception:
                    continue
                # if smile.count('Br') == 2:
                #     double_br += 1
                #     output.write(smile + '\n')
        tqdm.write(
            '\n{} molecules sampled, {} valid SMILES, {} with double Br'.format(
                nums, valid, double_br))
        output.close()
    else:
        # Keep sampling one molecule at a time until `nums` valid SMILES are written
        valid = 0
        n_sample = 0
        while valid < nums:
            seq, likelihood, _ = transfer_model.sample(1)
            n_sample += 1
            seq = seq.cpu().numpy()[0]
            smile = voc.decode(seq)
            if Chem.MolFromSmiles(smile):
                try:
                    AllChem.GetMorganFingerprintAsBitVect(
                        Chem.MolFromSmiles(smile), 2, 1024)
                    valid += 1
                    output.write(smile + '\n')
                    # if valid % 100 == 0:
                    #     tqdm.write('\n{} valid molecules sampled, with {} total samples'
                    #                .format(valid, n_sample))
                except Exception:
                    continue
        tqdm.write('\n{} valid molecules sampled, with {} total samples'.format(
            nums, n_sample))
        output.close()
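# Example invocations (a sketch; paths are assumptions). With until=False, a
# single batch of `nums` sequences is sampled and only the unique, valid ones
# are kept; with until=True, sampling continues one molecule at a time until
# `nums` valid SMILES have been written:
#
#   sample_smiles('data/Voc', 500, 'sampled.smi', 'data/transfer_model.ckpt')
#   sample_smiles('data/Voc', 500, 'sampled.smi', 'data/transfer_model.ckpt',
#                 until=True)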
def pretrain(restore_from=None, save_to="data/Prior.ckpt",
             data="data/mols_filtered.smi", voc_file="data/Voc",
             batch_size=128, learning_rate=0.001, n_epochs=5,
             store_loss_dir=None, embedding_size=32):
    """Trains the Prior RNN"""

    # Read vocabulary from a file
    voc = Vocabulary(init_from_file=voc_file)

    # Create a Dataset from a SMILES file
    moldata = MolData(data, voc)
    data = DataLoader(moldata, batch_size=batch_size, shuffle=True,
                      drop_last=True, collate_fn=MolData.collate_fn)

    Prior = RNN(voc, embedding_size)

    # Open a file to log loss info
    if store_loss_dir is None:
        out_f = open("loss.csv", "w")
    else:
        out_f = open("{}/loss.csv".format(store_loss_dir.rstrip("/")), "w")
    out_f.write("Step,Loss\n")

    # Can restore from a saved RNN
    if restore_from:
        Prior.rnn.load_state_dict(torch.load(restore_from))

    # For later plotting of the loss
    training_step_counter = 0
    n_logging = 100

    optimizer = torch.optim.Adam(Prior.rnn.parameters(), lr=learning_rate)
    for epoch in range(1, n_epochs + 1):
        # When training on a few million compounds, this model converges
        # in a few epochs or even faster. If the model size is increased,
        # it's probably a good idea to check the loss against an external
        # set of validation SMILES to make sure we don't overfit too much.
        for step, batch in tqdm(enumerate(data), total=len(data)):

            # Sample from DataLoader
            seqs = batch.long()

            # Calculate loss
            log_p, _ = Prior.likelihood(seqs)
            loss = -log_p.mean()

            # Calculate gradients and take a step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log the loss to the file every n_logging training steps
            if training_step_counter % n_logging == 0:
                out_f.write("{},{}\n".format(training_step_counter, loss.item()))
            training_step_counter += 1

            # Every 500 steps we decrease the learning rate and print some information
            if step % 500 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write("*" * 50)
                tqdm.write("Epoch {:3d} step {:3d} loss: {:5.2f}\n".format(
                    epoch, step, loss.item()))
                seqs, likelihood, _ = Prior.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(seqs)))
                tqdm.write("*" * 50 + "\n")

        # Checkpoint after each epoch
        torch.save(Prior.rnn.state_dict(), save_to)

    # Save the Prior
    torch.save(Prior.rnn.state_dict(), save_to)
    out_f.close()
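# Example invocation of this parameterized pretrain variant (which shadows the
# earlier pretrain definition if both live in one module). A sketch: 'logs' is
# a hypothetical directory that must exist for loss.csv to be written there:
#
#   pretrain(save_to="data/Prior.ckpt", data="data/mols_filtered.smi",
#            voc_file="data/Voc", batch_size=128, learning_rate=0.001,
#            n_epochs=5, store_loss_dir="logs", embedding_size=32)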
def fit(voc_path, mol_path, restore_path, max_save_path, last_save_path,
        epoch_num, step_num, decay_step_num, smile_num, lr, weight_decay):
    """Train the Prior RNN, tracking the loss and the fraction of valid SMILES.

    The checkpoint with the highest valid fraction is written to
    max_save_path; the final checkpoint is written to last_save_path.
    """
    restore_from = restore_path  # None if not restoring a model

    # Read vocabulary from a file
    voc = Vocabulary(init_from_file=voc_path)

    # Create a Dataset from a SMILES file
    moldata = MolData(mol_path, voc)
    data = DataLoader(moldata, batch_size=128, shuffle=True, drop_last=True,
                      collate_fn=MolData.collate_fn)

    Prior = RNN(voc)

    # Can restore from a saved RNN
    if restore_from:
        Prior.rnn.load_state_dict(torch.load(restore_from))

    total_loss = []
    total_valid = []
    max_valid_pro = 0
    optimizer = torch.optim.Adam(Prior.rnn.parameters(), lr=lr)
    for epoch in range(1, epoch_num):
        for step, batch in tqdm(enumerate(data), total=len(data)):

            # Sample from DataLoader
            seqs = batch.long()

            # Calculate loss
            log_p, _ = Prior.likelihood(seqs)
            loss = -log_p.mean()

            # Calculate gradients and take a step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Every decay_step_num steps we decrease the learning rate
            if step != 0 and step % decay_step_num == 0:
                decrease_learning_rate(optimizer, decrease_by=weight_decay)

            # Every step_num steps we sample and report progress
            if step % step_num == 0:
                tqdm.write("*" * 50)
                tqdm.write("Epoch {:3d} step {:3d} loss: {:5.2f}\n".format(
                    epoch, step, loss.item()))
                total_loss.append(float(loss))
                seqs, likelihood, _ = Prior.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < smile_num:
                        print(smile)
                vali_pro = valid / len(seqs)
                total_valid.append(float(vali_pro))
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(seqs)))
                tqdm.write("*" * 50 + "\n")

                # Keep the checkpoint with the best valid fraction so far
                if vali_pro > max_valid_pro:
                    max_valid_pro = vali_pro
                    torch.save(Prior.rnn.state_dict(), max_save_path)

    # Save the final Prior
    torch.save(Prior.rnn.state_dict(), last_save_path)
    print("total loss:", total_loss)
    print("total valid:", total_valid)
    return total_loss, total_valid
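# A minimal end-to-end sketch of the pipeline, shown for illustration only:
# every path, filename, and hyperparameter below is an assumption and must be
# adapted to the actual data layout before running. Note that the calls bind
# to the later (parameterized) pretrain/train_model definitions above.
if __name__ == "__main__":
    # 1. Pretrain the Prior on the filtered SMILES set.
    pretrain(save_to="data/Prior.ckpt", data="data/mols_filtered.smi",
             voc_file="data/Voc", n_epochs=5)

    # 2. Transfer-learn on the target molecules.
    train_model(voc_dir="data/Voc", smi_dir="data/target_smiles.csv",
                prior_dir="data/Prior.ckpt",
                tf_dir="data/transfer_model.ckpt",
                tf_process_dir="data/transfer_process.csv", freeze=False)

    # 3. Sample new molecules with the transferred model until 1000 valid
    #    SMILES have been written.
    sample_smiles(voc_dir="data/Voc", nums=1000, outfn="data/sampled.smi",
                  tf_dir="data/transfer_model.ckpt", until=True)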