def get_atom_hydrogen_bonding_one_hot(
    atom: RDKitAtom, hydrogen_bonding: List[Tuple[int, str]]) -> List[float]:
  """Get a one-hot feature about whether an atom donates or accepts hydrogen bonds.

  Parameters
  ----------
  atom: rdkit.Chem.rdchem.Atom
    RDKit atom object
  hydrogen_bonding: List[Tuple[int, str]]
    The return value of `construct_hydrogen_bonding_info`.
    The value is a list of tuple `(atom_index, hydrogen_bonding)` like (1, "Acceptor").

  Returns
  -------
  List[float]
    A two-element vector of the hydrogen bonding type. The first element
    is 1.0 if the atom is listed as a "Donor", and the second element is
    1.0 if it is listed as an "Acceptor". Both elements are 1.0 when the
    atom appears with both labels. Both stay 0.0 when the atom is not
    listed, or is listed only with other labels.
  """
  one_hot = [0.0, 0.0]
  atom_idx = atom.GetIdx()
  for bonding_atom_idx, bonding_type in hydrogen_bonding:
    # Entries for other atoms are irrelevant. Keep scanning rather than
    # stopping at the first match, because one atom may carry both labels.
    if bonding_atom_idx != atom_idx:
      continue
    if bonding_type == "Donor":
      one_hot[0] = 1.0
    elif bonding_type == "Acceptor":
      one_hot[1] = 1.0
  return one_hot
# Example #2
# 0
def get_atom_ring_size_one_hot(
    atom: RDKitAtom,
    sssr: Sequence,
    allowable_set: List[int] = DEFAULT_RING_SIZE_SET,
    include_unknown_set: bool = False) -> List[float]:
  """Get a one-hot feature about the ring size if an atom is in a ring.

  Parameters
  ----------
  atom: rdkit.Chem.rdchem.Atom
    RDKit atom object
  sssr: Sequence
    The return value of `Chem.GetSymmSSSR(mol)`.
    The value is a sequence of rings, each a sequence of atom indices.
  allowable_set: List[int]
    The ring size types to consider. The default set is `[3, 4, ..., 8]`.
  include_unknown_set: bool, default False
    If true, the index of all types not in `allowable_set` is `len(allowable_set)`.

  Returns
  -------
  List[float]
    A one-hot vector of the ring size type.
    If `include_unknown_set` is False, the length is `len(allowable_set)`.
    If `include_unknown_set` is True, the length is `len(allowable_set) + 1`.
    An atom that belongs to several rings in `sssr` gets one entry set for
    each distinct ring size it participates in. An atom that is not in a
    ring returns all zeros.
  """
  n_known = len(allowable_set)
  # Reserve a trailing "unknown ring size" slot only when requested.
  one_hot = [0.0] * (n_known + 1 if include_unknown_set else n_known)
  atom_index = atom.GetIdx()
  if not atom.IsInRing():
    return one_hot
  for ring in sssr:
    ring = list(ring)
    if atom_index not in ring:
      continue
    ring_size = len(ring)
    try:
      one_hot[allowable_set.index(ring_size)] = 1.0
    except ValueError:
      # The ring size is not in allowable_set. Record it in the unknown
      # slot when requested; otherwise ignore it.
      if include_unknown_set:
        one_hot[-1] = 1.0
  return one_hot