Example #1
0
def test_branch_and_ring_at_beginning_of_branch():
    """Check decoding of SELFIES in which a branch opens with both a
    branch symbol and a ring symbol.
    """

    reset_alphabet()

    # (SELFIES to decode, SMILES it must be equivalent to)
    cases = (
        # CC1CCCS((Br)1Cl)F
        ("[C][C][C][C][C][S][Branch1_2][Branch1_3]"
         "[Branch1_1][C][Br]"
         "[Ring1][Branch1_1][Cl][F]",
         "CC1CCCS1(Br)(Cl)F"),
        # CC1CCCS(1(Br)Cl)F
        ("[C][C][C][C][C][S][Branch1_2][Branch1_3]"
         "[Ring1][Branch1_1]"
         "[Branch1_1][C][Br][Cl][F]",
         "CC1CCCS1(Br)(Cl)F"),
        # CC1CCCS(((1Br)Cl)I)F
        ("[C][C][C][C][C][S][Branch1_3][Branch2_3]"
         "[Branch1_2][Branch1_3]"
         "[Branch1_1][Ring2][Ring1][Branch1_1][Br]"
         "[Cl][I][F]",
         "CC1CCCS1(Br)(Cl)(I)F"),
    )

    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
Example #2
0
def test_explicit_hydrogen_symbols():
    """Check that symbols carrying an explicit hydrogen specification
    are constrained properly by the decoder.
    """

    # (SELFIES to decode, exact SMILES expected back)
    expectations = (
        ("[CHexpl][Branch1_1][C][Cl][#C]", "[CH](Cl)=C"),
        ("[CH3expl][=C]", "[CH3]C"),
    )
    for selfies_str, smiles in expectations:
        assert sf.decoder(selfies_str) == smiles
def time_roundtrip(file_path: str, sample_size: int = -1):
    """Time how long it takes to encode and then decode the SMILES strings
    of a .txt file, and print both durations.

    <file_path> is resolved relative to the directory of this module. The
    first line of the file is skipped (presumably a header -- confirm
    against the data files). If <sample_size> is positive, a random sample
    of that many SMILES is timed instead of the whole file.
    """

    curr_dir = os.path.dirname(__file__)
    file_path = os.path.join(curr_dir, file_path)

    # load data; slicing (rather than pop) tolerates an empty file
    with open(file_path, 'r') as file:
        smiles = [line.rstrip() for line in file][1:]

    if sample_size > 0:
        smiles = random.sample(smiles, sample_size)

    # pre-compute the decoder inputs outside of the timed sections
    selfies = list(map(sf.encoder, smiles))

    print(f"Timing {len(smiles)} SMILES from {file_path}")

    # time sf.encoder; perf_counter is monotonic and high resolution,
    # unlike time.time which can jump when the system clock is adjusted
    start = time.perf_counter()
    for s in smiles:
        sf.encoder(s)
    enc_time = time.perf_counter() - start
    print(f"--> selfies.encoder: {enc_time:0.7f}s")

    # time sf.decoder
    start = time.perf_counter()
    for s in selfies:
        sf.decoder(s)
    dec_time = time.perf_counter() - start
    print(f"--> selfies.decoder: {dec_time:0.7f}s")
Example #4
0
def log_reconstruction(true_idces, probas, idx_to_char, string_type='smiles'):
    """
    Turn model outputs back into strings and report the fraction that RDKit
    can parse.

    Input :
        true_idces : shape (N, seq_len)
        probas : shape (N, voc_size, seq_len)
        idx_to_char : dict with idx to char mapping
        string_type : smiles or selfies

        Argmax on probas array (dim 1) to find most likely character indices

    Output :
        string_type == 'smiles' : (df, frac_valid), df being a DataFrame with
            columns 'input smiles' and 'output smiles'
        otherwise : (0, frac_valid); up to 3 decoded input/output pairs are
            printed as a side effect
    """
    probas = probas.to('cpu').numpy()
    true_idces = true_idces.cpu().numpy()
    N, _, _ = probas.shape  # also enforces that probas is 3-dimensional
    out_idces = np.argmax(probas, axis=1)  # get char_indices

    def to_string(indices):
        # Map one row of character indices back to its string.
        return ''.join(idx_to_char[idx] for idx in indices)

    out_smiles = [to_string(out_idces[i]) for i in range(N)]
    in_smiles = [to_string(true_idces[i]) for i in range(N)]

    if string_type == 'smiles':
        df = pd.DataFrame.from_dict(
            {'input smiles': in_smiles, 'output smiles': out_smiles})
        mols = [Chem.MolFromSmiles(o.rstrip('\n')) for o in out_smiles]
        frac_valid = np.mean([int(m is not None) for m in mols])
        return df, frac_valid

    # Any other string_type is handled as SELFIES: decode before validating.
    smiles = [decoder(out) for out in out_smiles]
    mols = [Chem.MolFromSmiles(s) for s in smiles]
    frac_valid = np.mean([int(m is not None) for m in mols])
    # printing only 3 samples (fewer when the batch is smaller, instead of
    # raising IndexError); the outputs were already decoded above
    for i in range(min(3, N)):
        print(decoder(in_smiles[i]), ' => ', smiles[i])
    return 0, frac_valid
Example #5
0
def test_oversized_branch():
    """Check decoding of SELFIES containing a branch whose Q is larger
    than the length of the SELFIES.
    """

    oversized = [
        ("[C][Branch2_1][O][O][C][C][S][F][C]", "C(CCSF)"),
        ("[C][Branch2_3][O][O][#C][C][S][F]", "C(#CCSF)"),
    ]
    for selfies_str, smiles in oversized:
        assert is_eq(sf.decoder(selfies_str), smiles)
Example #6
0
def test_ring_at_end_of_selfies():
    """Check decoding of SELFIES whose very last symbol is a ring symbol.
    """

    reset_alphabet()

    expected = "CCCC=C"
    for trailing_ring in ("[Ring1]", "[Ring3]"):
        assert is_eq(sf.decoder("[C][C][C][C][C]" + trailing_ring), expected)
Example #7
0
def test_branch_at_end_of_selfies():
    """Check decoding of SELFIES whose very last symbol is a branch symbol.
    """

    reset_alphabet()

    expected = "CCCC"
    for trailing_branch in ("[Branch1_1]", "[Branch3_3]"):
        assert is_eq(sf.decoder("[C][C][C][C]" + trailing_branch), expected)
Example #8
0
def test_branch_at_beginning_of_branch():
    """Check decoding of SELFIES in which a branch opens with another
    branch symbol.
    """

    reset_alphabet()

    # (SELFIES to decode, SMILES it must be equivalent to)
    cases = (
        # [C@]((Br)Cl)F
        ("[C@expl][Branch1_2][Branch1_1]"
         "[Branch1_1][C][Br]"
         "[Cl][F]",
         "[C@](Br)(Cl)F"),
        # [C@](((Br)Cl)I)F
        ("[C@expl][Branch1_3][Branch2_1]"
         "[Branch1_2][Branch1_1]"
         "[Branch1_1][C][Br]"
         "[Cl][I][F]",
         "[C@](Br)(Cl)(I)F"),
        # [C@]((Br)(Cl)I)F
        ("[C@expl][Branch1_3][Branch2_1]"
         "[Branch1_1][C][Br]"
         "[Branch1_1][C][Cl]"
         "[I][F]",
         "[C@](Br)(Cl)(I)F"),
    )

    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
Example #9
0
def test_branch_and_ring_at_state_X0():
    """Check that branch and ring symbols at state X0 (i.e. at the very
    beginning of a SELFIES) are skipped by the decoder.
    """

    leading_symbols = (
        "[Branch3_1]",
        "[Ring3]",
        "[Branch1_1][Ring1][Ring3]",
    )
    for prefix in leading_symbols:
        assert is_eq(sf.decoder(prefix + "[C][S][C][O]"), "CSCO")
Example #10
0
def test_branch_at_state_X1():
    """Check decoding of SELFIES with branches at state X1 (i.e. at an atom
    that can only make one bond). In this case, the branch symbol should be
    skipped.
    """

    reset_alphabet()

    cases = (
        ("[C][C][O][Branch1_1][C][I]", "CCOCI"),
        ("[C][C][C][O][Branch3_3][C][I]", "CCCOCI"),
    )
    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
Example #11
0
def test_isotope_symbols():
    """Check that symbols carrying an isotope specification are
    constrained properly by the decoder.
    """

    # (SELFIES to decode, exact SMILES expected back)
    expectations = (
        ("[13Cexpl][Branch1_1][C][Cl][Branch1_1][C][F]"
         "[Branch1_1][C][Br][Branch1_1][C][I]",
         "[13C](Cl)(F)(Br)CI"),
        ("[C][36Clexpl][C]", "C[36Cl]"),
    )
    for selfies_str, smiles in expectations:
        assert sf.decoder(selfies_str) == smiles
Example #12
0
def test_ring_on_top_of_existing_bond():
    """Check decoding of SELFIES with a ring between two atoms that are
    already bonded in the main scaffold.
    """

    # C1C1, C1C=1, C1C#1, ...
    cases = (
        ("[C][C][Ring1][C]", "C=C"),
        ("[C][/C][Ring1][C]", "C=C"),
        ("[C][C][Expl=Ring1][C]", "C#C"),
        ("[C][C][Expl#Ring1][C]", "C#C"),
    )
    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
Example #13
0
def test_consecutive_rings():
    """Check decoding of SELFIES containing several consecutive rings.
    """

    # Compared through is_eq; trailing comments give the ring bond orders.
    equivalent_cases = (
        ("[C][C][C][C][Ring1][Ring2][Ring1][Ring2]",
         "C=1CCC=1"),  # 1 + 1
        ("[C][C][C][C][Ring1][Ring2][Ring1][Ring2]"
         "[Ring1][Ring2]",
         "C#1CCC#1"),  # 1 + 1 + 1
        ("[C][C][C][C][Expl=Ring1][Ring2][Ring1][Ring2]",
         "C#1CCC#1"),  # 2 + 1
        ("[C][C][C][C][Ring1][Ring2][Expl=Ring1][Ring2]",
         "C#1CCC#1"),  # 1 + 2

        # consecutive rings that exceed bond constraints
        ("[C][C][C][C][Expl#Ring1][Ring2]"
         "[Expl=Ring1][Ring2]",
         "C#1CCC#1"),  # 3 + 2
        ("[C][C][C][C][Expl=Ring1][Ring2]"
         "[Expl#Ring1][Ring2]",
         "C#1CCC#1"),  # 2 + 3
        ("[C][C][C][C][Expl=Ring1][Ring2]"
         "[Expl=Ring1][Ring2]",
         "C#1CCC#1"),  # 2 + 2
    )
    for selfies_str, expected_smiles in equivalent_cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)

    # consecutive rings with stereochemical single bond (exact string match)
    exact_cases = (
        ("[C][C][C][C][Expl/Ring1][Ring2]", "C/1CCC/1"),
        ("[C][C][C][C][Expl/Ring1][Ring2][Ring1][Ring2]", "C=1CCC=1"),
    )
    for selfies_str, expected_smiles in exact_cases:
        assert sf.decoder(selfies_str) == expected_smiles
Example #14
0
def test_chiral_symbols():
    """Check that symbols carrying a chirality specification are
    constrained properly by the decoder.
    """

    # (SELFIES to decode, exact SMILES expected back)
    expectations = (
        ("[C@@expl][Branch1_1][C][Cl][Branch1_1][C][F]"
         "[Branch1_1][C][Br][Branch1_1][C][I]",
         "[C@@](Cl)(F)(Br)CI"),
        ("[C@Hexpl][Branch1_1][C][Cl][Branch1_1][C][F]"
         "[Branch1_1][C][Br]",
         "[C@H](Cl)(F)CBr"),
    )
    for selfies_str, smiles in expectations:
        assert sf.decoder(selfies_str) == smiles
def test_ring_at_beginning_of_branch():
    """Check decoding of SELFIES in which a branch opens with a ring symbol.
    """

    cases = (
        # CC1CCC(1CCl)F
        ("[C][C][C][C][C][Branch1_1][Branch1_1]"
         "[Ring1][Ring2][C][Cl][F]",
         "CC1CCC1(CCl)F"),
        # CC1CCS(Br)(1CCl)F
        ("[C][C][C][C][S][Branch1_1][C][Br]"
         "[Branch1_1][Branch1_1][Ring1][Ring2][C][Cl][F]",
         "CC1CCS1(Br)(CCl)F"),
    )
    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
def test_ring_after_branch():
    """Check decoding of SELFIES that have a ring following a branch, but
    not immediately after it.
    """

    # Both cases share the same scaffold-plus-branch prefix.
    prefix = "[C][C][C][C][C][C][C][Branch1_1][Ring2][O][C][O]"

    cases = (
        # CCCCCCC1(OCO)1
        (prefix + "[C][Ring1][Branch1_1]",
         "CCCCCCC(OCO)=C"),
        (prefix + "[Branch1_1][C][F][C][C][Ring1][Branch2_2]",
         "CCCCC1CC(OCO)(F)CC1"),
    )
    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
Example #17
0
def test_ring_immediately_following_branch():
    """Check decoding of SELFIES that have a ring immediately after a branch.
    """

    # Both cases share the same scaffold-plus-branch prefix.
    prefix = "[C][C][C][C][C][C][C][Branch1_1][Ring2][O][C][O]"

    cases = (
        # CCC1CCCC(OCO)1
        (prefix + "[Ring1][Branch1_1]",
         "CCC1CCCC1(OCO)"),
        # CCC1CCCC(OCO)(F)1
        (prefix + "[Branch1_1][C][F][Ring1][Branch1_1]",
         "CCC1CCCC1(OCO)(F)"),
    )
    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
Example #18
0
def test_change_constraints_cache_clear():
    """Check that changing the semantic constraints is reflected in both
    the semantic robust alphabet and the decoder output.

    The default constraints are restored at the end, even when an assertion
    fails, so that the modified constraints cannot leak into other tests.
    """

    alphabet = sf.get_semantic_robust_alphabet()
    assert alphabet == sf.get_semantic_robust_alphabet()
    assert sf.decoder("[C][#C]") == "C#C"

    new_constraints = sf.get_semantic_constraints()
    new_constraints["C"] = 1

    try:
        sf.set_semantic_constraints(new_constraints)

        new_alphabet = sf.get_semantic_robust_alphabet()
        assert new_alphabet != alphabet
        assert sf.decoder("[C][#C]") == "CC"
    finally:
        sf.set_semantic_constraints()  # re-set alphabet
Example #19
0
def test_oversized_ring():
    """Check decoding of SELFIES containing a ring whose Q is so large that
    the (Q + 1)-th previously derived atom does not exist.
    """

    cases = (
        ("[C][C][C][C][Ring1][O]", "C1CCC1"),
        ("[C][C][C][C][Ring2][O][C]", "C1CCC1"),

        # special case: Ring2 takes Q_1 = [O] and Q_2 = '', leading to
        # Q = 9 * 16 + 0 (i.e. an oversized ring)
        ("[C][C][C][C][Ring2][O]", "C1CCC1"),

        # special case: ring between 1st atom and 1st atom should not be
        # formed
        ("[C][Ring1][O]", "C"),
    )
    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
Example #20
0
def test_explicit_hydrogen_symbols():
    """Check that symbols carrying an explicit hydrogen specification
    are constrained properly by the decoder.
    """

    # (SELFIES to decode, SMILES that decode_eq compares it against)
    decodable = (
        ("[CH1][Branch1][C][Cl][#C]", "[CH1](Cl)=C"),
        ("[CH3][=C]", "[CH3]C"),

        ("[CH4][C][C]", "[CH4]"),
        ("[C][C][C][CH4]", "CCC"),
        ("[C][Branch1][Ring2][C][=CH4][C][=C]", "C(C)=C"),
    )
    for selfies_str, smiles in decodable:
        assert decode_eq(selfies_str, smiles)

    # These inputs are expected to be rejected with a DecoderError.
    for rejected in ("[C][C][CH5]", "[C][C][C][OH9]"):
        with pytest.raises(sf.DecoderError):
            sf.decoder(rejected)
Example #21
0
def hasher(q, hasher, valid, total, i):
    """Worker loop that decodes batches of SELFIES pulled from a queue and
    counts the valid, canonicalised SMILES they produce.

    The loop never returns on its own; it only ends through
    KeyboardInterrupt (or by the process being stopped from outside).

    q : queue yielding (smis, count) pairs, smis being an iterable of
        SELFIES strings
    hasher : mapping from SMILES string to occurrence count, updated in
        place (note: this parameter shadows the function's own name)
    valid : object whose .value is incremented once per valid molecule
    total : object whose .value is incremented by each batch's count
    i : worker index, also used as the torch random seed
    """
    from rdkit import rdBase
    rdBase.DisableLog('rdApp.error')
    print("Hasher Thread on", i)
    torch.manual_seed(i)
    torch.cuda.manual_seed(i)
    while True:
        # Block until a batch is available instead of spinning on q.empty(),
        # which kept a CPU core busy while the queue was idle.
        smis, count = q.get(block=True)
        total.value += count
        for smi in smis:
            try:
                smi = selfies.decoder(smi)
                m = Chem.MolFromSmiles(smi)
                s = Chem.MolToSmiles(m)
                if s is not None:
                    valid.value += 1
                    hasher[s] = hasher.get(s, 0) + 1
            except KeyboardInterrupt:
                print("Bye")
                # Same effect as exit(), without relying on the site module.
                raise SystemExit
            except Exception:
                # Best effort: a molecule that fails to decode or parse is
                # skipped. Unlike a bare except, this lets SystemExit and
                # GeneratorExit propagate.
                pass
Example #22
0
def predict_SMILES(image_path):
    """Predict a SMILES string for the image at image_path.

    The SELFIES tokens returned by evaluate() are joined, the "<start>" and
    "<end>" markers are removed, and the result is decoded with the
    'hypervalent' constraints.
    """
    selfies_str = ''.join(evaluate(image_path))
    for marker in ("<start>", "<end>"):
        selfies_str = selfies_str.replace(marker, "")

    return decoder(selfies_str, constraints='hypervalent')
def linear_interpolation(x_from, x_to, steps):
    """Decode molecules along a straight line between two encoded inputs.

    Both x_from and x_to are one-hot encoded, passed through model_encode,
    and steps + 1 evenly spaced points between the two resulting vectors
    (endpoints included) are decoded with model_decode.

    Relies on module-level names: largest_molecule_len, encoding_alphabet,
    device, model_encode, model_decode.

    x_from, x_to : inputs accepted by multiple_selfies_to_hot /
        selfies_to_hot (presumably SELFIES strings -- TODO confirm)
    steps : number of interpolation intervals; 0 makes the division
        `i / steps` below raise ZeroDivisionError

    Returns a list with steps + 1 entries, each the result of
    decoder(hot_to_selfies(...)) for one interpolation point.
    """
    n = steps + 1
    hot = multiple_selfies_to_hot([x_from], largest_molecule_len,
                                  encoding_alphabet)
    x_from = torch.tensor(hot, dtype=torch.float).to(device)
    # Remember the two trailing dimensions before they are flattened; they
    # size the decoded output below.
    input_shape1 = x_from.shape[1]
    input_shape2 = x_from.shape[2]
    x_from = x_from.reshape(x_from.shape[0], x_from.shape[1] * x_from.shape[2])
    _, hot = selfies_to_hot(x_to, largest_molecule_len, encoding_alphabet)
    x_to = torch.tensor([hot], dtype=torch.float).to(device)
    x_to = x_to.reshape(x_to.shape[0], x_to.shape[1] * x_to.shape[2])
    # Only the first element returned by the encoder is used
    # (presumably the latent mean -- TODO confirm against model_encode).
    t_from = model_encode(x_from)[0]
    t_from = t_from.reshape(1, 1, t_from.shape[1])
    t_to = model_encode(x_to)[0]
    t_to = t_to.reshape(1, 1, t_to.shape[1])
    diff = t_to[0][0] - t_from[0][0]
    # NOTE(review): inter is created without .to(device), unlike the other
    # tensors here -- confirm this is intended when device is not the CPU.
    inter = torch.zeros((1, n, t_to.shape[2]))
    for i in range(n):
        # i = 0 gives t_from, i = steps gives t_to.
        inter[0][i] = t_from[0][0] + i / steps * diff

    hidden = model_decode.init_hidden(batch_size=n)

    decoded_one_hot = torch.zeros(n, input_shape1, input_shape2).to(device)
    # The same interpolated points are fed at every sequence position; only
    # the hidden state carries over from one position to the next.
    for seq_index in range(input_shape1):
        decoded_one_hot_line, hidden = model_decode(inter, hidden)
        decoded_one_hot[:, seq_index, :] = decoded_one_hot_line[0]

    # This reshape uses the same shape the tensor was allocated with.
    decoded_one_hot = decoded_one_hot.reshape(n, input_shape1, input_shape2)
    output_mol = []
    # Index of the largest value along the last dimension at each position.
    _, decoded_mol = decoded_one_hot.max(2)
    for mol in decoded_mol:
        output_mol.append(decoder(hot_to_selfies(mol, encoding_alphabet)))

    return output_mol
def sample_latent_space(latent_dimension, total_samples):
    """Draw one random latent point and greedily decode it into a list of
    alphabet indices.

    The point is sampled from a normal distribution with zero mean and unit
    standard deviation per dimension. Both models are switched to eval mode
    for the sampling and switched to train mode afterwards, whatever mode
    they were in before the call.

    Relies on module-level names: model_encode, model_decode, len_max_molec,
    device, encoding_alphabet.

    latent_dimension : size of the latent vector to sample
    total_samples : only used to decide whether to print; the decoded
        molecule is printed when it is <= 5

    Returns a list of len_max_molec indices, one per decoding step.
    """
    model_encode.eval()
    model_decode.eval()

    fancy_latent_point = torch.normal(torch.zeros(latent_dimension),
                                      torch.ones(latent_dimension))
    print(fancy_latent_point)
    hidden = model_decode.init_hidden()
    gathered_atoms = []
    for ii in range(
            len_max_molec
    ):  # runs over letters from molecules (len=size of largest molecule)
        # The same latent point is fed at every step; only the hidden state
        # changes between steps.
        fancy_latent_point = fancy_latent_point.reshape(1, 1, latent_dimension)
        fancy_latent_point = fancy_latent_point.to(device)
        decoded_one_hot, hidden = model_decode(fancy_latent_point, hidden)
        decoded_one_hot = decoded_one_hot.flatten()
        decoded_one_hot = decoded_one_hot.detach()
        soft = nn.Softmax(0)
        decoded_one_hot = soft(decoded_one_hot)
        # Keep the index of the largest value for this step.
        _, max_index = decoded_one_hot.max(0)
        gathered_atoms.append(max_index.data.cpu().numpy().tolist())

    model_encode.train()
    model_decode.train()

    # test molecules visually
    if total_samples <= 5:
        print('Sample #', total_samples,
              decoder(hot_to_selfies(gathered_atoms, encoding_alphabet)))

    return gathered_atoms
Example #25
0
def sf2sm(line):
    """Convert a line of '.'-separated, space-delimited SELFIES symbols
    into '.'-separated SMILES.

    Each fragment has its spaces replaced by "][" and is wrapped in square
    brackets before being passed to sf.decoder.
    """
    fragments = line.strip().split(".")
    return ".".join(
        sf.decoder("[" + fragment.replace(" ", "][") + "]")
        for fragment in fragments
    )
Example #26
0
def test_branch_with_no_atoms():
    """Check decoding of SELFIES that have a branch with no atoms in it.
    Such branches should not be made in the outputted SMILES.
    """

    cases = (
        ("[C][Branch1_1][Ring2][Branch1_1]"
         "[Branch1_1][Branch1_1][F]",
         "CF"),
        ("[C][Branch1_1][Ring2][Ring1]"
         "[Ring1][Branch1_1][F]",
         "CF"),
        ("[C][Branch1_2][Ring2][Branch1_1]"
         "[C][Cl][F]",
         "C(Cl)F"),

        # special case: Branch3_3 takes Q_1, Q_2 = [O] and Q_3 = ''.
        # However, there are no more symbols in the branch.
        ("[C][C][C][C][Branch3_3][O][O]", "CCCC"),
    )
    for selfies_str, expected_smiles in cases:
        assert is_eq(sf.decoder(selfies_str), expected_smiles)
Example #27
0
def sf2sm(line):
    """Convert a line of '.'-separated SELFIES into '.'-separated SMILES.

    All spaces are removed from the line before it is split on '.'; each
    piece is then decoded with sf.decoder.
    """
    compact = line.replace(" ", "")
    return ".".join(map(sf.decoder, compact.split(".")))
Example #28
0
def test_nop_symbol_decoder(max_selfies_len, large_alphabet):
    """Check that the decoder always skips over the '[nop]' symbol.
    """

    nop_free = list(large_alphabet)
    nop_free.remove("[nop]")

    trial = 0
    while trial < 100:
        trial += 1

        # random SELFIES of random length, padded with [nop] up to
        # max_selfies_len and then shuffled
        n_real = random.randint(1, max_selfies_len)
        tokens = random_choices(nop_free, k=n_real)
        tokens += ["[nop]"] * (max_selfies_len - n_real)
        random.shuffle(tokens)

        padded = "".join(tokens)
        stripped = padded.replace("[nop]", "")

        assert sf.decoder(padded) == sf.decoder(stripped)
Example #29
0
def test_decoder_attribution():
    """Check that the attribution map returned by the decoder links the
    'P' output token back to the '[P]' input symbol.
    """
    _, attribution_map = sf.decoder(
        "[C][N][C][Branch1][C][P][C][C][Ring1][=Branch1]",
        attribute=True)

    # check that P lined up
    p_lined_up = any(
        source == '[P]'
        for entry in attribution_map
        if entry[0] == 'P'
        for _, source in entry[1]
    )
    if not p_lined_up:
        raise ValueError('Failed to find P in attribution map')
Example #30
0
def test_old_symbols():
    """Tests backward compatibility of SELFIES with old (<v2) symbols.
    """

    s = "[C@@Hexpl][Branch1_2][Branch1_1][Branch1_1][C][C][Cl][F]"
    assert sf.decoder(s, compatible=True) == "[C@@H1](C)(Cl)F"

    s = "[C][C][C][C][Expl=Ring1][Ring2][Expl#Ring1][Ring2]"
    assert sf.decoder(s, compatible=True) == "C#1CCC#1"

    long_s = "[C@@Hexpl][=C][C@@Hexpl][N+expl][=C][C+expl][N+expl][O+expl]" \
             "[Fe++expl][C@@Hexpl][C][N+expl][Branch1_2][Fe++expl][S+expl]" \
             "[=C][Expl=Ring1][Fe++expl][S+expl][Expl=Ring1][O+expl]" \
             "[C@@Hexpl][Expl=Ring1][C@@Hexpl][C@@Hexpl][N+expl][Expl=Ring1]" \
             "[Expl=Ring1][S+expl][=C]"

    # Decoding only has to succeed here. Any exception propagates and fails
    # the test with its real traceback, instead of being replaced by a bare
    # `assert False` (which would also be stripped when running under -O).
    sf.decoder(long_s, compatible=True)