def read_cif(cif):
    """Build a periodic bond graph from a CIF file.

    The structure is read with ``read(cif, format='cif')`` and made periodic
    in all three directions. Bonds are found with an ASE neighbor list that
    uses ``neighborlist.natural_cutoffs`` and a skin of 0.1.

    Parameters
    ----------
    cif
        CIF source handed to ``read`` (typically a file path).

    Returns
    -------
    G : networkx.Graph
        One node per atom, named ``<symbol><index>`` with the 0-based atom
        index, carrying ``element_symbol`` and ``fcoord`` (scaled position).
        Edges are the detected bonds and carry no attributes.
    numpy.ndarray
        The transposed 3x3 cell matrix, i.e. lattice vectors as columns.
    """
    atoms = read(cif, format='cif')
    atoms.set_pbc(True)
    cutoffs = neighborlist.natural_cutoffs(atoms)
    unit_cell = atoms.get_cell()

    NL = neighborlist.NewPrimitiveNeighborList(cutoffs,
                                               use_scaled_positions=True,
                                               self_interaction=False,
                                               skin=0.1)
    NL.build([True, True, True], unit_cell, atoms.get_scaled_positions())

    G = nx.Graph()

    for a, fcoord in zip(atoms, atoms.get_scaled_positions()):
        G.add_node(a.symbol + str(a.index),
                   element_symbol=a.symbol,
                   fcoord=fcoord)

    for a in G.nodes():
        # nl() is presumably the helper that returns the numeric suffix of a
        # node label (the 0-based atom index here) -- confirm
        nbors = [
            atoms[i].symbol + str(atoms[i].index)
            for i in NL.get_neighbors(int(nl(a)))[0]
        ]

        for nbor in nbors:
            G.add_edge(a, nbor)

    return G, np.asarray(unit_cell).T
def atoms2graph(atoms, kwargs=None):
    """Build a non-periodic bond graph from an ASE ``Atoms`` object.

    Bonds come from an ASE neighbor list built with
    ``neighborlist.natural_cutoffs``, a skin of 0.1 and no periodicity.

    Parameters
    ----------
    atoms
        ASE ``Atoms`` object; its Cartesian ``positions`` are used.
    kwargs : dict, optional
        Extra attributes added to every node. Keys must not clash with the
        attributes set here (``element_symbol``, ``ccoord``, ``fcoord``,
        ``freeze_code``); otherwise ``G.add_node`` raises ``TypeError``.

    Returns
    -------
    networkx.Graph
        One node per atom, named ``<symbol><index + 1>`` (1-based), with
        ``element_symbol``, ``ccoord``, a zero ``fcoord`` placeholder and
        ``freeze_code='0'``. Edges carry ``bond_type=''``, ``bond_sym='.'``
        and ``bond_length=0``.
    """
    # None instead of a shared mutable default dict
    if kwargs is None:
        kwargs = {}

    unit_cell = atoms.get_cell()
    cutoffs = neighborlist.natural_cutoffs(atoms)
    NL = neighborlist.NewPrimitiveNeighborList(
        cutoffs, self_interaction=False,
        skin=0.1)  # default atom cutoffs work well
    NL.build([False, False, False], unit_cell, atoms.positions)

    G = nx.Graph()

    for a, ccoord in zip(atoms, atoms.positions):
        ind = a.index + 1
        G.add_node(a.symbol + str(ind),
                   element_symbol=a.symbol,
                   ccoord=ccoord,
                   fcoord=np.array([0.0, 0.0, 0.0]),
                   freeze_code='0',
                   **kwargs)

    for a in G.nodes():
        # labels are 1-based, neighbor-list indices 0-based; nl() is
        # presumably the helper extracting the numeric label suffix -- confirm
        nbors = [
            atoms[i].symbol + str(atoms[i].index + 1)
            for i in NL.get_neighbors(int(nl(a)) - 1)[0]
        ]

        for nbor in nbors:
            G.add_edge(a, nbor, bond_type='', bond_sym='.', bond_length=0)

    return G
def _merge_lammps_style(section, style):
    """Record *style* in ``section['style']``, promoting it to ``hybrid`` if needed.

    ``section`` is one of the force-field dictionaries (pair, bond or angle
    data) whose ``'style'`` entry is a whitespace-separated style string.
    """
    current = section['style']
    if 'hybrid' not in current:
        if style != current:
            section['style'] = ' '.join(['hybrid', current, style])
    elif style not in current.split():
        # match whole words: a substring test would wrongly treat e.g.
        # 'lj/cut' as already present in 'hybrid lj/cut/coul/long'
        section['style'] += ' ' + style


def _box_to_column_cell(box):
    """Return the 3x3 cell matrix (lattice vectors as columns) for *box*.

    *box* is ``(a, b, c, alpha, beta, gamma)`` with the angles in degrees.
    The first lattice vector lies along x and the second in the xy-plane.
    """
    a, b, c, alpha, beta, gamma = box
    pi = np.pi
    bx = b * np.cos(gamma * pi / 180.0)
    by = b * np.sin(gamma * pi / 180.0)
    cx = c * np.cos(beta * pi / 180.0)
    cy = (c * b * np.cos(alpha * pi / 180.0) - bx * cx) / by
    cz = (c**2.0 - cx**2.0 - cy**2.0)**0.5
    return np.asarray([[a, 0.0, 0.0], [bx, by, 0.0], [cx, cy, cz]]).T


def _bond_small_molecules_by_radii(SMG, box):
    """Add covalent-radius bonds between the nodes of *SMG* (in place).

    Bonds come from a periodic ASE neighbor list with natural cutoffs and a
    0.10 skin. Each edge stores ``bond_length`` (from ``get_distances`` with
    the periodic cell, rounded to 3 decimals), ``bond_order='1.0'`` and
    ``bond_type='S'``.
    """
    atoms = Atoms()
    # neighbor-list index -> graph node id; avoids assuming the node ids are
    # contiguous and iterated in sorted order
    node_ids = []
    for node, data in SMG.nodes(data=True):
        node_ids.append(node)
        atoms.append(Atom(data['element_symbol'], data['cartesian_position']))

    atoms.set_cell(box)
    unit_cell = atoms.get_cell()
    cutoffs = neighborlist.natural_cutoffs(atoms)
    NL = neighborlist.NewPrimitiveNeighborList(cutoffs,
                                               use_scaled_positions=False,
                                               self_interaction=False,
                                               skin=0.10)
    NL.build([True, True, True], unit_cell, atoms.get_positions())

    for atom in atoms:
        for j in NL.get_neighbors(atom.index)[0]:
            distances = get_distances(atom.position,
                                      p2=atoms[j].position,
                                      cell=unit_cell,
                                      pbc=[True, True, True])
            SMG.add_edge(node_ids[atom.index],
                         node_ids[j],
                         bond_length=np.round(distances[1][0][0], 3),
                         bond_order='1.0',
                         bond_type='S')


def add_small_molecules(FF, ff_string):
    """Type the guest (small-molecule) graph and merge it into the force field.

    Each connected component of ``FF.system['SM_graph']`` is one molecule,
    identified by its sorted element formula (e.g. ``'H2O1'``). Its atoms get
    a ``mol_flag`` and a ``force_field_type``. O2 and N2 molecules also get an
    extra centre-of-mass site bonded to both atoms. Atoms are then copied
    into ``FF.system['graph']``, and new atom, bond and angle types are added
    to the force-field tables. New bond and angle type numbers continue after
    the highest existing ones and are shared by all molecules with the same
    force-field types.

    Parameters
    ----------
    FF
        Force-field object. Reads the ``FF.system`` keys ``'graph'``,
        ``'SM_graph'``, ``'box'`` and ``'max_ind'``. Updates ``atom_types``,
        ``atom_masses``, ``pair_data``, ``bond_data`` and ``angle_data`` in
        place.
    ff_string : str
        One of ``'TraPPE'``, ``'TIP4P_2005_long'``, ``'TIP4P_cutoff'``,
        ``'TIP4P_2005_cutoff'`` or ``'Ions'``. Selects the parameter table
        from ``small_molecule_constants``.

    Raises
    ------
    ValueError
        If *ff_string* is not one of the supported names.
    """
    if ff_string == 'TraPPE':
        SM_constants = small_molecule_constants.TraPPE
    elif ff_string == 'TIP4P_2005_long':
        SM_constants = small_molecule_constants.TIP4P_2005_long
        FF.pair_data['special_bonds'] = 'lj 0.0 0.0 1.0 coul 0.0 0.0 0.0'
    elif ff_string == 'TIP4P_cutoff':
        SM_constants = small_molecule_constants.TIP4P_cutoff
        FF.pair_data['special_bonds'] = 'lj/coul 0.0 0.0 1.0'
    elif ff_string == 'TIP4P_2005_cutoff':
        # NOTE(review): reuses the TIP4P_cutoff parameter table, while the
        # M-site distance below is the TIP4P/2005 value -- confirm intended
        SM_constants = small_molecule_constants.TIP4P_cutoff
        FF.pair_data['special_bonds'] = 'lj/coul 0.0 0.0 1.0'
    elif ff_string == 'Ions':
        SM_constants = small_molecule_constants.Ions
        FF.pair_data['special_bonds'] = 'lj/coul 0.0 0.0 1.0'
    # insert more force fields here if needed
    else:
        raise ValueError(
            'the small molecule force field {} is not defined'.format(
                ff_string))

    SG = FF.system['graph']
    SMG = FF.system['SM_graph']

    if len(SMG.nodes()) > 0 and len(SMG.edges()) == 0:
        print(
            'there are no small molecule bonds in the CIF, calculating based on covalent radii...'
        )
        _bond_small_molecules_by_radii(SMG, FF.system['box'])
        NMOL = len(list(nx.connected_components(SMG)))
        print(NMOL, 'small molecules were recovered after bond calculation')

    unit_cell = _box_to_column_cell(FF.system['box'])
    inv_unit_cell = np.linalg.inv(unit_cell)
    index = FF.system['max_ind']
    mol_flag = 1

    add_nodes = []
    add_edges = []
    # (node ids of the molecule including any added COM site, ID string)
    molecules = []

    for comp in nx.connected_components(SMG):

        mol_flag += 1
        comp = sorted(comp)
        symbols = sorted(SMG.nodes[n]['element_symbol'] for n in comp)
        # formula-like identifier, e.g. ['H', 'H', 'O'] -> 'H2O1'
        ID_string = ''.join(
            key + str(len(list(group))) for key, group in groupby(symbols))

        for n in comp:
            node = SMG.nodes[n]
            node['mol_flag'] = str(mol_flag)
            if ID_string == 'H2O1':
                node['force_field_type'] = node['element_symbol'] + '_w'
            else:
                node['force_field_type'] = (node['element_symbol'] + '_' +
                                            ID_string)

        members = list(comp)

        # add COM sites where relevant, extend this list as new types are added
        if ID_string in ('O2', 'N2'):

            anchor = SMG.nodes[comp[0]]['fractional_position']
            coords = []
            for n in comp:
                fcoord = SMG.nodes[n]['fractional_position']
                # image of the atom closest to the anchor; a new array is
                # built so the node's stored position is left untouched
                fcoord = fcoord + PBC3DF_sym(fcoord, anchor)[1]
                coords.append(np.dot(unit_cell, fcoord))

            ccom = np.average(coords, axis=0)
            fcom = np.dot(inv_unit_cell, ccom)
            index += 1
            fft = 'O_com' if ID_string == 'O2' else 'N_com'

            ndata = {
                'element_symbol': 'NA',
                'mol_flag': mol_flag,
                'index': index,
                'force_field_type': fft,
                'cartesian_position': ccom,
                'fractional_position': fcom,
                'charge': 0.0,
                'replication': np.array([0.0, 0.0, 0.0]),
                'duplicated_version_of': None
            }
            edata = {'sym_code': None, 'length': None, 'bond_type': None}

            add_nodes.append([index, ndata])
            add_edges.extend([(index, comp[0], edata),
                              (index, comp[1], edata)])
            members.append(index)

        molecules.append((members, ID_string))

    for n, data in add_nodes:
        SMG.add_node(n, **data)
    for e0, e1, data in add_edges:
        SMG.add_edge(e0, e1, **data)

    ntypes = max(FF.atom_types.values(), default=0)
    nbonds = max(FF.bond_data['params'], default=0)
    nangles = max(FF.angle_data['params'], default=0)

    # force-field-type tuple -> type number, shared across all molecules
    new_bond_types = {}
    new_angle_types = {}

    for members, ID_string in molecules:

        subG = SMG.subgraph(members).copy()
        constants = SM_constants[ID_string]

        # add new atom types
        charges = constants['pair']['charges']
        for name, data in sorted(subG.nodes(data=True), key=lambda x: x[0]):

            fft = data['force_field_type']
            data['charge'] = charges[fft]
            SG.add_node(name, **data)

            if fft in FF.atom_types:
                continue

            ntypes += 1
            FF.atom_types[fft] = ntypes
            style = constants['pair']['style']
            vdW = constants['pair']['vdW'][fft]
            FF.pair_data['params'][ntypes] = (style, vdW[0], vdW[1])
            FF.pair_data['comments'][ntypes] = [fft, fft]
            FF.atom_masses[fft] = mass_key[data['element_symbol']]
            _merge_lammps_style(FF.pair_data, style)

        # add new bonds; bonds with no entry in the table are skipped
        bond_params = constants.get('bonds', {})
        for e0, e1 in subG.edges():

            # make sure the order corresponds to that in the molecule dictionary
            bond = tuple(
                sorted([
                    subG.nodes[e0]['force_field_type'],
                    subG.nodes[e1]['force_field_type']
                ]))
            if bond not in bond_params:
                continue

            if bond not in new_bond_types:
                nbonds += 1
                new_bond_types[bond] = nbonds
                FF.bond_data['params'][nbonds] = list(bond_params[bond])
                FF.bond_data['comments'][nbonds] = list(bond)

            _merge_lammps_style(FF.bond_data, bond_params[bond][0])

            count = FF.bond_data['count']
            FF.bond_data['count'] = (count[0] + 1, count[1])
            FF.bond_data['all_bonds'].setdefault(new_bond_types[bond],
                                                 []).append((e0, e1))

        # add new angles (i-j-k with j the vertex); unknown angles are skipped
        angle_params = constants.get('angles', {})
        for j in subG.nodes():
            for i, k in combinations(list(subG.neighbors(j)), 2):

                fft_i = subG.nodes[i]['force_field_type']
                fft_j = subG.nodes[j]['force_field_type']
                fft_k = subG.nodes[k]['force_field_type']
                end0, end1 = sorted((fft_i, fft_k))
                angle = (end0, fft_j, end1)
                if angle not in angle_params:
                    continue

                count = FF.angle_data['count']
                FF.angle_data['count'] = (count[0] + 1, count[1])

                if angle not in new_angle_types:
                    nangles += 1
                    new_angle_types[angle] = nangles
                    FF.angle_data['params'][nangles] = list(angle_params[angle])
                    FF.angle_data['comments'][nangles] = list(angle)

                _merge_lammps_style(FF.angle_data, angle_params[angle][0])

                FF.angle_data['all_angles'].setdefault(new_angle_types[angle],
                                                       []).append((i, j, k))

        # add new dihedrals

    # second count entry = number of defined types
    FF.bond_data['count'] = (FF.bond_data['count'][0],
                             len(FF.bond_data['params']))
    FF.angle_data['count'] = (FF.angle_data['count'][0],
                              len(FF.angle_data['params']))

    if 'tip4p' in FF.pair_data['style']:

        for ty, pair in FF.pair_data['comments'].items():
            fft = pair[0]
            if fft == 'O_w':
                FF.pair_data['O_type'] = ty
            if fft == 'H_w':
                FF.pair_data['H_type'] = ty

        for ty, bond in FF.bond_data['comments'].items():
            if sorted(bond) == ['H_w', 'O_w']:
                FF.pair_data['H2O_bond_type'] = ty

        for ty, angle in FF.angle_data['comments'].items():
            if angle == ['H_w', 'O_w', 'H_w']:
                FF.pair_data['H2O_angle_type'] = ty

        # choose the O-M distance from the requested water model; the merged
        # style string may contain 'long' coming from the framework styles
        if ff_string in ('TIP4P_2005_long', 'TIP4P_2005_cutoff'):
            FF.pair_data['M_site_dist'] = 0.1546  # TIP4P/2005
        elif ff_string == 'TIP4P_cutoff':
            FF.pair_data['M_site_dist'] = 0.1500
        elif 'long' in FF.pair_data['style']:
            FF.pair_data['M_site_dist'] = 0.1546  # only TIP4P/2005 is implemented
def _parse_cif_charge_column(filename):
    """Return the charges listed in the CIF loop that declares ``_atom_site_charge``.

    The last token of every line with more than five tokens is read as a
    charge. Reading starts at the ``_atom_site_charge`` header and stops at
    the next ``loop_`` keyword.
    """
    charge_list = []
    in_charge_loop = False
    with open(filename, 'r') as handle:
        for line in handle:
            tokens = line.split()
            if '_atom_site_charge' in tokens:
                in_charge_loop = True
            # CIF loops are opened by the 'loop_' keyword
            if 'loop_' in tokens:
                in_charge_loop = False
            if in_charge_loop and len(tokens) > 5:
                charge_list.append(float(tokens[-1]))
    return charge_list


def _neighbor_list_skin(elements):
    """Return the neighbor-list skin used for bond detection, given the elements."""
    small_skin_metals = ('Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy',
                         'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Ba', 'La')
    if 'Zn' in elements:
        return 0.30
    if any(e in elements for e in small_skin_metals):
        return 0.05
    return 0.20


def _max_plane_deviation(coords):
    """Return the largest distance of the points *coords* (N x 3) from their best-fit plane."""
    centred = coords - np.average(coords, axis=0)
    # the last right-singular vector is the best-fit plane normal; unlike a
    # cross product of two ring atoms it stays well defined when those atoms
    # are collinear with the ring centroid
    normal = np.linalg.svd(centred)[2][-1]
    return np.max(np.abs(np.dot(centred, normal)))


def cif_read_pymatgen(filename, charges=False, coplanarity_tolerance=0.1):
    """Read a CIF with pymatgen and assign bonds with guessed bond types.

    Bonds are found with a periodic ASE neighbor list. The skin depends on
    the elements present. Bond orders come from pymatgen ``CovalentBond``
    for non-metal pairs (1.0 when it fails or involves 'X'), and 0.5 for any
    pair with a metal. They are then rounded and adjusted by heuristics, and
    every coplanar cycle of more than four atoms is marked aromatic (1.5).
    Non-metal atoms whose total bond order differs from the reference
    valence trigger a warning.

    Parameters
    ----------
    filename : str
        Path of the CIF file.
    charges : bool
        If true, read per-atom charges from the ``_atom_site_charge`` loop.
        Otherwise every charge is 0.0.
    coplanarity_tolerance : float
        Largest distance from the best-fit plane (same units as the atomic
        positions) for a cycle to be treated as aromatic.

    Returns
    -------
    tuple
        ``(elems, names, ccoords, fcoords, charge_list, bond_list,
        (A, B, C, alpha, beta, gamma), unit_cell)``. ``names`` are
        ``<symbol><index>`` labels. Each ``bond_list`` entry is
        ``[name0, name1, '.', bond_type, bond_length]``. ``unit_cell`` is the
        transposed cell matrix (lattice vectors as columns).
    """
    valencies = {
        'C': 4.0,
        'Si': 4.0,
        'Ge': 4.0,
        'N': 3.0,
        'P': 3.0,
        'As': 3.0,
        'Sb': 3.0,
        'O': 2.0,
        'S': 2.0,
        'Se': 2.0,
        'Te': 2.0,
        'F': 1.0,
        'Cl': 1.0,
        'Br': 1.0,
        'I': 1.0,
        'H': 1.0,
        'X': 1.0
    }

    bond_types = {0.5: 'S', 1.0: 'S', 1.5: 'A', 2.0: 'D', 3.0: 'T'}

    cif = CifParser(filename)
    struct = cif.get_structures(primitive=False)[0]
    atoms = AseAtomsAdaptor.get_atoms(struct)
    unit_cell = atoms.get_cell()
    inv_uc = inv(unit_cell.T)
    elements = atoms.get_chemical_symbols()

    skin = _neighbor_list_skin(elements)
    print('skin for bond calculation is', skin)

    if charges:
        charge_list = _parse_cif_charge_column(filename)
    else:
        charge_list = [0.0] * len(atoms)

    cutoffs = neighborlist.natural_cutoffs(atoms)
    NL = neighborlist.NewPrimitiveNeighborList(
        cutoffs, use_scaled_positions=False, self_interaction=False,
        skin=skin)  # default atom cutoffs work well
    NL.build([True, True, True], unit_cell, atoms.get_positions())

    G = nx.Graph()
    for a in atoms:
        G.add_node(a.index, element_symbol=a.symbol, position=a.position)

    for i in atoms:

        nbors = NL.get_neighbors(i.index)[0]
        isym = i.symbol

        for j in nbors:

            jsym = atoms[j].symbol

            if (isym not in metals) and (jsym not in metals) and not any(
                    e == 'X' for e in [isym, jsym]):
                try:
                    bond = bonds.CovalentBond(struct[i.index], struct[j])
                    bond_order = bond.get_bond_order()
                except ValueError:
                    bond_order = 1.0
            elif (isym == 'X' or jsym
                  == 'X') and (isym not in metals) and (jsym not in metals):
                bond_order = 1.0
            else:
                bond_order = 0.5

            bond_length = get_distances(i.position,
                                        p2=atoms[j].position,
                                        cell=unit_cell,
                                        pbc=[True, True, True])
            bond_length = np.round(bond_length[1][0][0], 3)

            G.add_edge(i.index,
                       j,
                       bond_length=bond_length,
                       bond_order=bond_order,
                       bond_type='',
                       pymatgen_bond_order=bond_order)

    # copy of the bond graph without any bond to a metal
    NMG = G.copy()
    for e0, e1 in list(NMG.edges()):
        if (NMG.nodes[e0]['element_symbol'] in metals
                or NMG.nodes[e1]['element_symbol'] in metals):
            NMG.remove_edge(e0, e1)

    # remove C-M bonds if C is also bonded to carboxylate atoms, these are almost always wrong
    for i, data in G.nodes(data=True):
        if data['element_symbol'] != 'C':
            continue
        nbors = list(G.neighbors(i))
        nbor_symbols = [G.nodes[n]['element_symbol'] for n in nbors]
        nonmetal_nbor_symbols = [s for s in nbor_symbols if s not in metals]
        if sorted(nonmetal_nbor_symbols) == ['C', 'O', 'O']:
            for n, nsym in zip(nbors, nbor_symbols):
                if nsym in metals:
                    G.remove_edge(i, n)

    ### initial bond typing, guessed from rounding pymatgen bond orders
    for linker in nx.connected_components(NMG):

        SG = NMG.subgraph(linker)
        # the cycle basis depends only on the linker, so compute it once
        CB = nx.cycle_basis(SG)
        check_cycles = len(CB) >= 3

        for i, data in SG.nodes(data=True):

            isym = data['element_symbol']
            nbors = list(G.neighbors(i))
            nbor_symbols = [G.nodes[n]['element_symbol'] for n in nbors]
            has_metal_nbor = any(s in metals for s in nbor_symbols)

            # index of the last basis cycle containing i, if any
            cyloc = None
            if check_cycles:
                for cy, cycle_nodes in enumerate(CB):
                    if i in cycle_nodes:
                        cyloc = cy

            for n, nsym in zip(nbors, nbor_symbols):

                edge_data = G[n][i]
                bond_order = edge_data['bond_order']

                if bond_order < 1.0 and bond_order != 0.5:
                    bond_order = 1.0
                # shortest observed single bond had order 1.321
                elif 1.00 <= bond_order < 1.33:
                    bond_order = 1.0
                elif 1.33 <= bond_order < 1.75:
                    bond_order = 1.5
                # bond orders tend to be on the high end for aromatic compounds
                elif 1.75 <= bond_order < 2.00:
                    bond_order = 1.5
                elif 2.00 <= bond_order < 3.00:
                    bond_order = round(bond_order)

                # bonds between two disparate cycles or cycles and non-cycles should have order 1.0
                if cyloc is not None and n not in CB[cyloc]:
                    bond_order = 1.0

                if has_metal_nbor and isym == 'O' and nsym == 'C':
                    bond_order = 1.5

                if nsym in metals:
                    bond_order = 0.5

                if isym == 'C' and len(nbor_symbols) == 4:
                    bond_order = 1.0

                if isym == 'C' and sorted(nbor_symbols) == ['C', 'O', 'O']:
                    if nsym == 'C':
                        bond_order = 1.0
                    elif nsym == 'O':
                        bond_order = 1.5

                edge_data['bond_order'] = bond_order
                edge_data['bond_type'] = bond_types[bond_order]

        all_cycles = nx.simple_cycles(nx.to_directed(SG))
        all_cycles = set(
            [tuple(sorted(cy)) for cy in all_cycles if len(cy) > 4])

        ### assign aromatic bond orders as 1.5 (in most cases they will be already)
        for cycle in all_cycles:

            # unwrap the ring across periodic boundaries before measuring it
            coords = np.array([G.nodes[c]['position'] for c in cycle])
            fcoords = np.dot(inv_uc, coords.T).T
            anchor = fcoords[0]
            fcoords = np.array(
                [vec - PBC3DF_sym(anchor, vec)[1] for vec in fcoords])
            coords = np.dot(unit_cell.T, fcoords.T).T

            # if coplanar make all bond orders 1.5
            if _max_plane_deviation(coords) < coplanarity_tolerance:
                for e0, e1 in SG.subgraph(cycle).edges():
                    G[e0][e1]['bond_order'] = 1.5
                    # keep the written bond type consistent with the order
                    G[e0][e1]['bond_type'] = bond_types[1.5]

    # valence sanity check for non-metals with a reference valence
    for i, data in G.nodes(data=True):

        isym = data['element_symbol']
        if isym in metals:
            continue
        expected = valencies.get(isym)
        if expected is None:
            # no reference valence for this element, nothing to compare
            continue

        bond_orders = [G[i][n]['bond_order'] for n in G.neighbors(i)]
        total_bond_order = np.sum(bond_orders)

        if total_bond_order != expected:
            nbor_symbols = ' '.join(
                [G.nodes[n]['element_symbol'] for n in G.neighbors(i)])
            message = ' '.join([
                str(isym), 'has total bond order',
                str(total_bond_order), 'with neighbors', nbor_symbols,
                'and bond orders'
            ] + [str(o) for o in bond_orders])
            warnings.warn(message)

    elems = atoms.get_chemical_symbols()
    names = [a.symbol + str(a.index) for a in atoms]
    ccoords = atoms.get_positions()
    fcoords = atoms.get_scaled_positions()

    bond_list = []
    for e0, e1, data in G.edges(data=True):

        name0 = G.nodes[e0]['element_symbol'] + str(e0)
        name1 = G.nodes[e1]['element_symbol'] + str(e1)

        bond_list.append(
            [name0, name1, '.', data['bond_type'], data['bond_length']])

    A, B, C = unit_cell.lengths()
    alpha, beta, gamma = unit_cell.angles()

    return elems, names, ccoords, fcoords, charge_list, bond_list, (
        A, B, C, alpha, beta, gamma), np.asarray(unit_cell).T