def test_branch_and_ring_at_beginning_of_branch():
    """Branches that open with both a branch symbol and a ring symbol.

    Each case pairs the intended SMILES (in the comment) with the SELFIES
    fed to the decoder and the SMILES it is expected to decode to.
    """
    reset_alphabet()

    cases = [
        # CC1CCCS((Br)1Cl)F
        ("[C][C][C][C][C][S][Branch1_2][Branch1_3]"
         "[Branch1_1][C][Br]"
         "[Ring1][Branch1_1][Cl][F]",
         "CC1CCCS1(Br)(Cl)F"),
        # CC1CCCS(1(Br)Cl)F
        ("[C][C][C][C][C][S][Branch1_2][Branch1_3]"
         "[Ring1][Branch1_1]"
         "[Branch1_1][C][Br][Cl][F]",
         "CC1CCCS1(Br)(Cl)F"),
        # CC1CCCS(((1Br)Cl)I)F
        ("[C][C][C][C][C][S][Branch1_3][Branch2_3]"
         "[Branch1_2][Branch1_3]"
         "[Branch1_1][Ring2][Ring1][Branch1_1][Br]"
         "[Cl][I][F]",
         "CC1CCCS1(Br)(Cl)(I)F"),
    ]
    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
def test_explicit_hydrogen_symbols():
    """Bracketed atoms with explicit hydrogens must respect their bond limits.

    The expected outputs show the decoder lowering the requested bond order:
    '[#C]' after '[CHexpl]'(Cl) becomes '=C', and '[=C]' after '[CH3expl]'
    becomes a single bond.
    """
    expectations = {
        "[CHexpl][Branch1_1][C][Cl][#C]": "[CH](Cl)=C",
        "[CH3expl][=C]": "[CH3]C",
    }
    for selfies_str, smiles in expectations.items():
        assert sf.decoder(selfies_str) == smiles
def time_roundtrip(file_path: str, sample_size: int = -1):
    """Time a single encode pass and a single decode pass over a SMILES file.

    The file is resolved relative to this module's directory and read one
    SMILES per line. Its first line is treated as a header and dropped.
    Every SMILES is encoded once up front to build the decoder's input.
    Then ``sf.encoder`` and ``sf.decoder`` are each timed over the whole
    list, and the elapsed times are printed.

    Args:
        file_path: path of the SMILES file, relative to this module.
        sample_size: if positive and smaller than the number of SMILES in
            the file, time a random sample of that many SMILES instead.
    """
    curr_dir = os.path.dirname(__file__)
    file_path = os.path.join(curr_dir, file_path)

    # Load data. Slicing off the header (rather than pop(0)) is safe on an
    # empty file.
    with open(file_path, 'r', encoding='utf-8') as file:
        smiles = [line.rstrip() for line in file][1:]

    # random.sample raises ValueError when asked for more items than exist;
    # in that case, just time the whole file.
    if 0 < sample_size < len(smiles):
        smiles = random.sample(smiles, sample_size)

    selfies = list(map(sf.encoder, smiles))

    print(f"Timing {len(smiles)} SMILES from {file_path}")

    # time sf.encoder (perf_counter is monotonic and high-resolution,
    # unlike time.time)
    start = time.perf_counter()
    for s in smiles:
        sf.encoder(s)
    enc_time = time.perf_counter() - start
    print(f"--> selfies.encoder: {enc_time:0.7f}s")

    # time sf.decoder
    start = time.perf_counter()
    for s in selfies:
        sf.decoder(s)
    dec_time = time.perf_counter() - start
    print(f"--> selfies.decoder: {dec_time:0.7f}s")
def log_reconstruction(true_idces, probas, idx_to_char, string_type='smiles'):
    """Turn model outputs into strings and measure how many are valid molecules.

    The argmax over dim 1 of ``probas`` selects the most likely token at each
    position. Token indices are then mapped to characters with
    ``idx_to_char``.

    Args:
        true_idces: tensor of target token indices, shape (N, seq_len).
        probas: tensor of token scores, shape (N, voc_size, seq_len).
        idx_to_char: mapping from token index to token string.
        string_type: 'smiles' or 'selfies'. Any value other than 'smiles'
            is treated as SELFIES.

    Returns:
        For 'smiles': a tuple ``(df, frac_valid)``. ``df`` is a DataFrame
        with 'input smiles' / 'output smiles' columns, and ``frac_valid`` is
        the fraction of outputs RDKit can parse.
        Otherwise: a tuple ``(0, frac_valid)``, where ``frac_valid`` is the
        fraction of decoded outputs RDKit can parse. In this case up to
        three decoded input => output pairs are also printed.
    """
    # detach() so the numpy conversion also works on tensors that require
    # grad (e.g. when called with live model outputs during training).
    probas = probas.detach().cpu().numpy()
    true_idces = true_idces.detach().cpu().numpy()
    N, _voc_size, _seq_len = probas.shape
    out_idces = np.argmax(probas, axis=1)  # most likely token per position

    in_smiles, out_smiles = [], []
    for i in range(N):
        out_smiles.append(''.join(idx_to_char[idx] for idx in out_idces[i]))
        in_smiles.append(''.join(idx_to_char[idx] for idx in true_idces[i]))

    if string_type == 'smiles':
        d = {'input smiles': in_smiles, 'output smiles': out_smiles}
        df = pd.DataFrame.from_dict(d)
        mols = [Chem.MolFromSmiles(o.rstrip('\n')) for o in out_smiles]
        frac_valid = np.mean([int(m is not None) for m in mols])
        return df, frac_valid

    smiles = [decoder(out) for out in out_smiles]
    mols = [Chem.MolFromSmiles(s) for s in smiles]
    frac_valid = np.mean([int(m is not None) for m in mols])
    # Print a few samples. min() avoids IndexError on batches smaller than 3.
    for i in range(min(3, N)):
        print(decoder(in_smiles[i]), ' => ', decoder(out_smiles[i]))
    return 0, frac_valid
def test_oversized_branch():
    """Branches whose size Q is larger than the rest of the SELFIES.

    Such a branch is expected to run to the end of the string, as the
    expected outputs show.
    """
    cases = (
        ("[C][Branch2_1][O][O][C][C][S][F][C]", "C(CCSF)"),
        ("[C][Branch2_3][O][O][#C][C][S][F]", "C(#CCSF)"),
    )
    for selfies_str, expected in cases:
        assert is_eq(sf.decoder(selfies_str), expected)
def test_ring_at_end_of_selfies():
    """A ring symbol that is the very last symbol of a SELFIES.

    Both the one- and three-symbol ring variants are expected to decode to
    'CCCC=C'.
    """
    reset_alphabet()
    for selfies_str in ("[C][C][C][C][C][Ring1]", "[C][C][C][C][C][Ring3]"):
        assert is_eq(sf.decoder(selfies_str), "CCCC=C")
def test_branch_at_end_of_selfies():
    """A branch symbol that is the very last symbol of a SELFIES.

    The trailing branch symbol is expected to contribute nothing to the
    output.
    """
    reset_alphabet()
    for selfies_str in ("[C][C][C][C][Branch1_1]", "[C][C][C][C][Branch3_3]"):
        assert is_eq(sf.decoder(selfies_str), "CCCC")
def test_branch_at_beginning_of_branch():
    """Branches whose first symbol is itself another branch.

    The comments show the intended nested SMILES. The decoder is expected to
    emit them as consecutive sibling branches on the chiral carbon.
    """
    reset_alphabet()

    cases = [
        # [C@]((Br)Cl)F
        ("[C@expl][Branch1_2][Branch1_1]"
         "[Branch1_1][C][Br]"
         "[Cl][F]",
         "[C@](Br)(Cl)F"),
        # [C@](((Br)Cl)I)F
        ("[C@expl][Branch1_3][Branch2_1]"
         "[Branch1_2][Branch1_1]"
         "[Branch1_1][C][Br]"
         "[Cl][I][F]",
         "[C@](Br)(Cl)(I)F"),
        # [C@]((Br)(Cl)I)F
        ("[C@expl][Branch1_3][Branch2_1]"
         "[Branch1_1][C][Br]"
         "[Branch1_1][C][Cl]"
         "[I][F]",
         "[C@](Br)(Cl)(I)F"),
    ]
    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
def test_branch_and_ring_at_state_X0():
    """Branch and ring symbols at state X0 (the very start of a SELFIES).

    These leading symbols should be skipped, so every case decodes exactly
    like the trailing '[C][S][C][O]' does.
    """
    inputs = (
        "[Branch3_1][C][S][C][O]",
        "[Ring3][C][S][C][O]",
        "[Branch1_1][Ring1][Ring3][C][S][C][O]",
    )
    for selfies_str in inputs:
        assert is_eq(sf.decoder(selfies_str), "CSCO")
def test_branch_at_state_X1():
    """Branch symbols at state X1 (an atom that can make only one more bond).

    In this case the branch symbol should be skipped, and the following
    atoms continue the main chain.
    """
    reset_alphabet()
    cases = (
        ("[C][C][O][Branch1_1][C][I]", "CCOCI"),
        ("[C][C][C][O][Branch3_3][C][I]", "CCCOCI"),
    )
    for selfies_str, expected in cases:
        assert is_eq(sf.decoder(selfies_str), expected)
def test_isotope_symbols():
    """Atoms with an isotope label must still obey their bond constraints.

    The expected outputs show '[13C]' taking only three branches, with the
    fourth substituent continuing the main chain. They also show '[36Cl]'
    accepting no bond beyond the one to the preceding carbon.
    """
    expectations = {
        "[13Cexpl][Branch1_1][C][Cl][Branch1_1][C][F]"
        "[Branch1_1][C][Br][Branch1_1][C][I]": "[13C](Cl)(F)(Br)CI",
        "[C][36Clexpl][C]": "C[36Cl]",
    }
    for selfies_str, smiles in expectations.items():
        assert sf.decoder(selfies_str) == smiles
def test_ring_on_top_of_existing_bond():
    """Rings between two atoms that are already bonded in the main chain.

    Cases correspond to C1C1, C1C=1, C1C#1, ...; per the expected outputs,
    the ring raises the order of the existing bond instead of forming a
    new one.
    """
    cases = [
        ("[C][C][Ring1][C]", "C=C"),
        ("[C][/C][Ring1][C]", "C=C"),
        ("[C][C][Expl=Ring1][C]", "C#C"),
        ("[C][C][Expl#Ring1][C]", "C#C"),
    ]
    for selfies_str, expected in cases:
        assert is_eq(sf.decoder(selfies_str), expected)
def test_consecutive_rings():
    """Several ring symbols in a row closing the same ring.

    The "a + b" comments give the bond orders being combined. The expected
    outputs show the combined order capped at a triple bond. A stereo '/'
    ring bond survives on its own but becomes a double bond when followed
    by another ring.
    """
    combined_cases = [
        ("[C][C][C][C][Ring1][Ring2][Ring1][Ring2]",
         "C=1CCC=1"),  # 1 + 1
        ("[C][C][C][C][Ring1][Ring2][Ring1][Ring2]"
         "[Ring1][Ring2]",
         "C#1CCC#1"),  # 1 + 1 + 1
        ("[C][C][C][C][Expl=Ring1][Ring2][Ring1][Ring2]",
         "C#1CCC#1"),  # 2 + 1
        ("[C][C][C][C][Ring1][Ring2][Expl=Ring1][Ring2]",
         "C#1CCC#1"),  # 1 + 2
        # consecutive rings that exceed bond constraints
        ("[C][C][C][C][Expl#Ring1][Ring2]"
         "[Expl=Ring1][Ring2]",
         "C#1CCC#1"),  # 3 + 2
        ("[C][C][C][C][Expl=Ring1][Ring2]"
         "[Expl#Ring1][Ring2]",
         "C#1CCC#1"),  # 2 + 3
        ("[C][C][C][C][Expl=Ring1][Ring2]"
         "[Expl=Ring1][Ring2]",
         "C#1CCC#1"),  # 2 + 2
    ]
    for selfies_str, expected in combined_cases:
        assert is_eq(sf.decoder(selfies_str), expected)

    # consecutive rings with stereochemical single bond (exact match)
    stereo_cases = [
        ("[C][C][C][C][Expl/Ring1][Ring2]", "C/1CCC/1"),
        ("[C][C][C][C][Expl/Ring1][Ring2][Ring1][Ring2]", "C=1CCC=1"),
    ]
    for selfies_str, expected in stereo_cases:
        assert sf.decoder(selfies_str) == expected
def test_chiral_symbols():
    """Chiral atoms must still obey their bond constraints.

    The expected outputs show the surplus substituent continuing the main
    chain instead of forming another branch.
    """
    expectations = {
        "[C@@expl][Branch1_1][C][Cl][Branch1_1][C][F]"
        "[Branch1_1][C][Br][Branch1_1][C][I]": "[C@@](Cl)(F)(Br)CI",
        "[C@Hexpl][Branch1_1][C][Cl][Branch1_1][C][F]"
        "[Branch1_1][C][Br]": "[C@H](Cl)(F)CBr",
    }
    for selfies_str, smiles in expectations.items():
        assert sf.decoder(selfies_str) == smiles
def test_ring_at_beginning_of_branch():
    """A ring symbol as the very first symbol inside a branch.

    The comments show the intended SMILES. Per the expected outputs, the
    ring closure moves onto the atom that owns the branch.
    """
    cases = [
        # CC1CCC(1CCl)F
        ("[C][C][C][C][C][Branch1_1][Branch1_1]"
         "[Ring1][Ring2][C][Cl][F]",
         "CC1CCC1(CCl)F"),
        # CC1CCS(Br)(1CCl)F
        ("[C][C][C][C][S][Branch1_1][C][Br]"
         "[Branch1_1][Branch1_1][Ring1][Ring2][C][Cl][F]",
         "CC1CCS1(Br)(CCl)F"),
    ]
    for selfies_str, expected in cases:
        assert is_eq(sf.decoder(selfies_str), expected)
def test_ring_after_branch():
    """A ring that follows a branch, but not immediately after it."""
    cases = [
        # CCCCCCC1(OCO)1
        ("[C][C][C][C][C][C][C][Branch1_1][Ring2][O][C][O]"
         "[C][Ring1][Branch1_1]",
         "CCCCCCC(OCO)=C"),
        ("[C][C][C][C][C][C][C][Branch1_1][Ring2][O][C][O]"
         "[Branch1_1][C][F][C][C][Ring1][Branch2_2]",
         "CCCCC1CC(OCO)(F)CC1"),
    ]
    for selfies_str, expected in cases:
        assert is_eq(sf.decoder(selfies_str), expected)
def test_ring_immediately_following_branch():
    """A ring symbol that comes directly after one or more branches."""
    cases = [
        # CCC1CCCC(OCO)1
        ("[C][C][C][C][C][C][C][Branch1_1][Ring2][O][C][O]"
         "[Ring1][Branch1_1]",
         "CCC1CCCC1(OCO)"),
        # CCC1CCCC(OCO)(F)1
        ("[C][C][C][C][C][C][C][Branch1_1][Ring2][O][C][O]"
         "[Branch1_1][C][F][Ring1][Branch1_1]",
         "CCC1CCCC1(OCO)(F)"),
    ]
    for selfies_str, expected in cases:
        assert is_eq(sf.decoder(selfies_str), expected)
def test_change_constraints_cache_clear():
    """Changing semantic constraints must invalidate cached results.

    Sets the constraint for "C" to 1 and checks two things: the semantic
    robust alphabet changes, and "[C][#C]" no longer decodes to a triple
    bond. The constraints are restored in ``finally`` (via
    ``sf.set_semantic_constraints()`` with no arguments, as before), so a
    failing assertion cannot leak the modified global constraints into
    later tests.
    """
    alphabet = sf.get_semantic_robust_alphabet()
    assert alphabet == sf.get_semantic_robust_alphabet()
    assert sf.decoder("[C][#C]") == "C#C"

    new_constraints = sf.get_semantic_constraints()
    new_constraints["C"] = 1
    try:
        sf.set_semantic_constraints(new_constraints)
        new_alphabet = sf.get_semantic_robust_alphabet()
        assert new_alphabet != alphabet
        assert sf.decoder("[C][#C]") == "CC"
    finally:
        sf.set_semantic_constraints()  # re-set alphabet
def test_oversized_ring():
    """Rings whose size Q points before the first derived atom.

    That is, the (Q + 1)-th previously derived atom does not exist. Per the
    expected outputs, such rings close onto the first atom instead.
    """
    clamped_inputs = (
        "[C][C][C][C][Ring1][O]",
        "[C][C][C][C][Ring2][O][C]",
        # special case: Ring2 takes Q_1 = [O] and Q_2 = '', leading to
        # Q = 9 * 16 + 0 (i.e. an oversized ring)
        "[C][C][C][C][Ring2][O]",
    )
    for selfies_str in clamped_inputs:
        assert is_eq(sf.decoder(selfies_str), "C1CCC1")

    # special case: ring between 1st atom and 1st atom should not be formed
    assert is_eq(sf.decoder("[C][Ring1][O]"), "C")
def test_explicit_hydrogen_symbols():
    """Bracketed atoms with explicit hydrogen counts obey their bond limits.

    Valid hydrogen counts reduce how many further bonds the atom can make
    (see the expected outputs). Counts the decoder cannot accept ('[CH5]',
    '[OH9]') must raise ``sf.DecoderError``.
    """
    expected_pairs = [
        ("[CH1][Branch1][C][Cl][#C]", "[CH1](Cl)=C"),
        ("[CH3][=C]", "[CH3]C"),
        ("[CH4][C][C]", "[CH4]"),
        ("[C][C][C][CH4]", "CCC"),
        ("[C][Branch1][Ring2][C][=CH4][C][=C]", "C(C)=C"),
    ]
    for selfies_str, smiles in expected_pairs:
        assert decode_eq(selfies_str, smiles)

    for invalid_selfies in ("[C][C][CH5]", "[C][C][C][OH9]"):
        with pytest.raises(sf.DecoderError):
            sf.decoder(invalid_selfies)
def hasher(q, hasher, valid, total, i):
    """Worker loop: decode batches of SELFIES from a queue and tally molecules.

    Runs forever, pulling ``(smis, count)`` items from ``q`` and adding
    ``count`` to ``total.value`` for each batch. Each SELFIES string in
    ``smis`` is decoded to SMILES and parsed with RDKit. When parsing
    succeeds, ``valid.value`` is incremented and the count for the
    molecule's canonical SMILES in ``hasher`` is bumped. Strings that fail
    to decode or parse are skipped (best effort).

    Args:
        q: queue yielding ``(list_of_selfies, count)`` tuples.
        hasher: dict-like mapping canonical SMILES -> occurrence count.
            This parameter shadows the function's own name inside the body.
        valid: object with a mutable ``.value`` counting parsed molecules.
        total: object with a mutable ``.value`` accumulating ``count``.
        i: worker index, used for logging and as the torch RNG seed.

    NOTE(review): ``x.value += ...`` and ``hasher[s] += 1`` are
    read-modify-write operations. If several workers share these objects,
    they are not atomic unless the caller guards them with a lock; confirm.
    """
    from rdkit import rdBase
    rdBase.DisableLog('rdApp.error')
    print("Hasher Thread on", i)
    torch.manual_seed(i)
    torch.cuda.manual_seed(i)
    while True:
        # Block until work arrives. The previous `if not q.empty()` poll
        # busy-spun a CPU core at 100% whenever the queue was empty.
        smis, count = q.get(block=True)
        total.value += count
        for smi in smis:
            try:
                smi = selfies.decoder(smi)
                m = Chem.MolFromSmiles(smi)
                if m is None:
                    # RDKit could not parse it: not a valid molecule.
                    continue
                s = Chem.MolToSmiles(m)
                valid.value += 1
                if s in hasher:
                    hasher[s] += 1
                else:
                    hasher[s] = 1
            except KeyboardInterrupt:
                print("Bye")
                raise SystemExit
            except Exception:
                # Best effort: malformed input is simply not counted.
                # Narrower than a bare `except:`, so SystemExit and similar
                # still propagate.
                continue
def predict_SMILES(image_path):
    """Predict a SMILES string for the image at ``image_path``.

    The output of ``evaluate`` (presumably an iterable of SELFIES token
    strings; confirm) is concatenated. The "<start>" and "<end>" markers are
    then stripped, and the result is decoded with the 'hypervalent'
    constraint set.
    """
    tokens = evaluate(image_path)
    selfies_str = ''.join(tokens)
    for marker in ("<start>", "<end>"):
        selfies_str = selfies_str.replace(marker, "")
    return decoder(selfies_str, constraints='hypervalent')
def linear_interpolation(x_from, x_to, steps):
    """Decode molecules along a straight line between two SELFIES in latent space.

    Each endpoint is one-hot encoded, flattened and passed through
    ``model_encode``; element ``[0]`` of its output is used as the latent
    point. ``steps + 1`` evenly spaced points run from the ``x_from`` latent
    (fraction 0) to the ``x_to`` latent (fraction 1). They are decoded
    position by position with ``model_decode`` and converted back with
    ``decoder(hot_to_selfies(...))``.

    Relies on the module-level ``model_encode``, ``model_decode``,
    ``device``, ``largest_molecule_len`` and ``encoding_alphabet``.

    Args:
        x_from: start SELFIES string.
        x_to: end SELFIES string.
        steps: number of interpolation intervals. ``steps + 1`` molecules
            are returned; ``steps == 0`` returns only the start point.

    Returns:
        list of decoded strings, one per interpolation point.
    """
    n = steps + 1

    hot = multiple_selfies_to_hot([x_from], largest_molecule_len, encoding_alphabet)
    x_from = torch.tensor(hot, dtype=torch.float).to(device)
    input_shape1 = x_from.shape[1]
    input_shape2 = x_from.shape[2]
    x_from = x_from.reshape(x_from.shape[0], input_shape1 * input_shape2)

    _, hot = selfies_to_hot(x_to, largest_molecule_len, encoding_alphabet)
    x_to = torch.tensor([hot], dtype=torch.float).to(device)
    x_to = x_to.reshape(x_to.shape[0], x_to.shape[1] * x_to.shape[2])

    t_from = model_encode(x_from)[0]
    t_from = t_from.reshape(1, 1, t_from.shape[1])
    t_to = model_encode(x_to)[0]
    t_to = t_to.reshape(1, 1, t_to.shape[1])

    diff = t_to[0][0] - t_from[0][0]
    # Allocate on `device` so the decoder input sits alongside the model.
    # A CPU tensor here breaks decoding when `device` is a GPU.
    inter = torch.zeros((1, n, t_to.shape[2]), device=device)
    for i in range(n):
        # Guard steps == 0 (previously ZeroDivisionError): the single
        # point is then x_from itself.
        frac = i / steps if steps else 0.0
        inter[0][i] = t_from[0][0] + frac * diff

    hidden = model_decode.init_hidden(batch_size=n)
    decoded_one_hot = torch.zeros(n, input_shape1, input_shape2).to(device)
    for seq_index in range(input_shape1):
        decoded_one_hot_line, hidden = model_decode(inter, hidden)
        decoded_one_hot[:, seq_index, :] = decoded_one_hot_line[0]

    # decoded_one_hot is already (n, input_shape1, input_shape2), so no
    # reshape is needed before the argmax over the vocabulary axis.
    _, decoded_mol = decoded_one_hot.max(2)
    return [decoder(hot_to_selfies(mol, encoding_alphabet)) for mol in decoded_mol]
def sample_latent_space(latent_dimension, total_samples):
    """Decode one random latent point into a sequence of token indices.

    A point is drawn from a standard normal of size ``latent_dimension``. It
    is fed to ``model_decode`` once per position (``len_max_molec`` times),
    and the highest-scoring index at each position is collected. Both
    models are put in eval mode while sampling and are switched back to
    train mode afterwards, even if decoding raises.

    Args:
        latent_dimension: size of the latent vector.
        total_samples: when <= 5, the decoded molecule is printed for
            visual inspection.

    Returns:
        list of ints, one token index per position.
    """
    model_encode.eval()
    model_decode.eval()
    try:
        fancy_latent_point = torch.normal(torch.zeros(latent_dimension),
                                          torch.ones(latent_dimension))
        print(fancy_latent_point)
        # Loop invariants: reshape/move the latent point and build the
        # softmax once, not on every position.
        latent_input = fancy_latent_point.reshape(1, 1, latent_dimension).to(device)
        soft = nn.Softmax(0)

        hidden = model_decode.init_hidden()
        gathered_atoms = []
        # Inference only: no need to build an autograd graph across all
        # decoding steps.
        with torch.no_grad():
            # runs over letters from molecules (len=size of largest molecule)
            for _ in range(len_max_molec):
                decoded_one_hot, hidden = model_decode(latent_input, hidden)
                probs = soft(decoded_one_hot.flatten())
                _, max_index = probs.max(0)
                gathered_atoms.append(max_index.item())
    finally:
        model_encode.train()
        model_decode.train()

    # test molecules visually
    if total_samples <= 5:
        print('Sample #', total_samples,
              decoder(hot_to_selfies(gathered_atoms, encoding_alphabet)))

    return gathered_atoms
def sf2sm(line):
    """Convert a line of space-separated SELFIES tokens into dotted SMILES.

    The stripped line is split into fragments on '.'. Within a fragment,
    tokens are separated by single spaces and written without their
    brackets (e.g. "C C O" -> "[C][C][O]"). Each fragment is decoded
    independently, and the SMILES are re-joined with '.'.
    """
    decoded_fragments = []
    for fragment in line.strip().split("."):
        bracketed = "[" + "][".join(fragment.split(" ")) + "]"
        decoded_fragments.append(sf.decoder(bracketed))
    return ".".join(decoded_fragments)
def test_branch_with_no_atoms():
    """Branches that end up containing no atoms.

    Such empty branches should not appear in the output SMILES.
    """
    cases = [
        ("[C][Branch1_1][Ring2][Branch1_1]"
         "[Branch1_1][Branch1_1][F]",
         "CF"),
        ("[C][Branch1_1][Ring2][Ring1]"
         "[Ring1][Branch1_1][F]",
         "CF"),
        ("[C][Branch1_2][Ring2][Branch1_1]"
         "[C][Cl][F]",
         "C(Cl)F"),
        # special case: Branch3_3 takes Q_1, Q_2 = [O] and Q_3 = ''. However,
        # there are no more symbols in the branch.
        ("[C][C][C][C][Branch3_3][O][O]",
         "CCCC"),
    ]
    for selfies_str, expected in cases:
        assert is_eq(sf.decoder(selfies_str), expected)
def sf2sm(line):
    """Decode a '.'-separated SELFIES line into a '.'-separated SMILES line.

    All spaces are removed first; each fragment is then decoded on its own.
    """
    compact = line.replace(" ", "")
    return ".".join([sf.decoder(fragment) for fragment in compact.split(".")])
def test_nop_symbol_decoder(max_selfies_len, large_alphabet):
    """The '[nop]' symbol must always be skipped by the decoder.

    Builds 100 random SELFIES from ``large_alphabet`` (excluding '[nop]'),
    pads each with '[nop]' up to ``max_selfies_len`` symbols and shuffles
    it. Decoding must then match decoding the same string with every
    '[nop]' removed.
    """
    symbols = list(large_alphabet)
    symbols.remove("[nop]")

    for _ in range(100):
        # create random SELFIES with and without [nop]
        length = random.randint(1, max_selfies_len)
        tokens = random_choices(symbols, k=length)
        padding = ["[nop]"] * (max_selfies_len - length)
        tokens.extend(padding)
        random.shuffle(tokens)

        padded = "".join(tokens)
        stripped = padded.replace("[nop]", "")
        assert sf.decoder(padded) == sf.decoder(stripped)
def test_decoder_attribution():
    """Decoding with ``attribute=True`` must attribute the output 'P' to '[P]'.

    ``sf.decoder(..., attribute=True)`` returns the SMILES plus an
    attribution map. This test assumes each map entry is indexable as
    ``(output_token, [(index, symbol), ...])``. It passes when some entry
    whose first field is 'P' lists the input symbol '[P]'.
    """
    _smiles, am = sf.decoder(
        "[C][N][C][Branch1][C][P][C][C][Ring1][=Branch1]", attribute=True)
    # Check that P lined up with [P]. A plain assertion (instead of a
    # manual loop that raises ValueError) reports as a test failure.
    found = any(
        ta[0] == 'P' and any(v == '[P]' for _, v in ta[1])
        for ta in am
    )
    assert found, 'Failed to find P in attribution map'
def test_old_symbols():
    """Backward compatibility with old (<v2) SELFIES symbols.

    With ``compatible=True``, the decoder must accept pre-v2 symbols such
    as '[C@@Hexpl]', '[Branch1_2]' and '[Expl=Ring1]'.
    """
    s = "[C@@Hexpl][Branch1_2][Branch1_1][Branch1_1][C][C][Cl][F]"
    assert sf.decoder(s, compatible=True) == "[C@@H1](C)(Cl)F"

    s = "[C][C][C][C][Expl=Ring1][Ring2][Expl#Ring1][Ring2]"
    assert sf.decoder(s, compatible=True) == "C#1CCC#1"

    long_s = "[C@@Hexpl][=C][C@@Hexpl][N+expl][=C][C+expl][N+expl][O+expl]" \
             "[Fe++expl][C@@Hexpl][C][N+expl][Branch1_2][Fe++expl][S+expl]" \
             "[=C][Expl=Ring1][Fe++expl][S+expl][Expl=Ring1][O+expl]" \
             "[C@@Hexpl][Expl=Ring1][C@@Hexpl][C@@Hexpl][N+expl][Expl=Ring1]" \
             "[Expl=Ring1][S+expl][=C]"
    # Must decode without raising. Calling it directly (rather than
    # `try/except Exception: assert False`) lets pytest show the original
    # exception and traceback on failure.
    sf.decoder(long_s, compatible=True)