def test_remove_duplicates__pick_random_removes_at_random():
    # The same input is de-duplicated under two different seeds; which of the
    # two equivalent tuples survives depends on the random state.
    candidates = [(3, 2, 1), (3, 6, 5), (3, 5, 6)]
    seeded_expectations = (
        (1, [(3, 2, 1), (3, 6, 5)]),
        (4, [(3, 2, 1), (3, 5, 6)]),
    )
    for seed, expected in seeded_expectations:
        random.seed(seed)
        # pass a fresh list each time in case remove_duplicates mutates its input
        assert remove_duplicates(list(candidates), pick_random=True) == expected
def test_remove_duplicates__leaves_order_untouched():
    # A lone tuple comes back unchanged, and when two tuples are permutations of
    # each other the first one is kept in its original element order.
    cases = [
        ([(3, 2, 1)], [(3, 2, 1)]),
        ([(3, 2, 1), (1, 2, 3)], [(3, 2, 1)]),
    ]
    for given, expected in cases:
        assert remove_duplicates(given) == expected
def test_remove_duplicates__with_constant_key_is_one_item():
    # When every item maps to the same key, only the first item is kept.
    def same_key(_match):
        return "asdf"

    items = [(3, 2, 1), (3, 6, 5), (3, 6, 5)]
    assert remove_duplicates(items, key=same_key) == [(3, 2, 1)]
def test_remove_duplicates__with_no_key_uses_sorted_tuples():
    # Without a key, (3, 6, 5) and (3, 5, 6) are equal once sorted, so only the
    # first of that pair survives.
    expected = [(3, 2, 1), (3, 6, 5)]
    result = remove_duplicates([(3, 2, 1), (3, 6, 5), (3, 5, 6)])
    assert result == expected
Example #5
0
def find_pattern_in_structure(structure,
                              pattern,
                              return_positions=False,
                              abstol=5e-2,
                              verbose=False):
    """Looks for instances of `pattern` in `structure`, where a match in the structure has the same number
    of atoms, the same elements and the same relative coordinates as in `pattern`.

    Returns a list of tuples, one tuple per match found in `structure` where each tuple has the size
    `len(pattern)` and contains the indices in the structure that matched the pattern. If
    `return_positions=True` then an additional array is returned containing positions for each
    matched index for each match.

    The search grows candidate matches one pattern atom at a time. Every unit-cell atom with the
    same element as the first pattern atom seeds a candidate. In round `i`, a candidate is extended
    by each nearby atom whose element equals `pattern.elements[i]` and whose distances to all atoms
    already in the candidate agree with the corresponding pattern distances to within `abstol`.

    Args:
        structure (Atoms): an Atoms object to search in.
        pattern (Atoms): an Atoms object to search for.
        return_positions (bool): additionally returns the positions for each index.
        abstol (float): the absolute tolerance (how close an atom must be in the structure to the
            position in pattern to be considered a match).
        verbose (bool): print debugging info.
    Returns:
        List [tuple(len(pattern))]: one tuple of size `len(pattern)` per match, containing the
        indices in structure that matched the pattern. Matches that map to the same set of
        unit-cell indices (in any order) are reported only once. When `return_positions` is True,
        a tuple `(matches, positions)` is returned instead, where `positions` is a numpy array
        holding the position of each matched atom for each match.
    """

    if verbose:
        print("calculating point distances...")
    # Pairwise pattern distances. Distances (not squared distances) are compared below so that
    # `abstol` is an absolute tolerance in length units.
    p_dist = distance.cdist(pattern.positions, pattern.positions, "sqeuclidean")**0.5
    # Every atom of a match lies within the largest pattern distance (plus tolerance) of the
    # starting atom, so this bounds the neighbourhood that has to be searched.
    pattern_length = p_dist.max() + 2 * abstol
    s_types_view, index_mapper, s_pos_view, s_positions = get_types_ss_map_limited_near_uc(
        structure, pattern_length)

    num_uc_atoms = len(structure)
    pattern_elements = pattern.elements

    # Sorted rows of (x, y, z, view index) used to cut out a search box around each starting
    # atom. The index column is stored as float and converted back to int where it is used.
    p = np.array(sorted([(*r, i) for i, r in enumerate(s_pos_view)]))

    # Search instances of first atom in a search pattern
    # 0,0,0 uc atoms are always indexed first from 0 to # atoms in structure.
    starting_atoms = list(
        atoms_of_type(s_types_view[0:num_uc_atoms], pattern_elements[0]))
    if verbose:
        print(
            "round %d (%d) [%s]: " %
            (0, len(starting_atoms), pattern_elements[0]), starting_atoms)

    def get_nearby_atoms(a):
        """Return the rows of `p` inside the axis-aligned box of half-width `pattern_length`
        centred on view atom `a`, preserving the row order of `p`."""
        center = s_pos_view[a]
        in_box = np.ones(len(p), dtype=bool)
        for axis in range(3):
            in_box &= ((p[:, axis] <= center[axis] + pattern_length)
                       & (p[:, axis] >= center[axis] - pattern_length))
        return p[in_box]

    all_match_index_tuples = []
    for a_idx, a in enumerate(starting_atoms):
        match_index_tuples = [[a]]

        nearby = get_nearby_atoms(a)
        nearby_atom_indices = nearby[:, 3].astype(np.int32)
        nearby_positions = nearby[:, 0:3]

        # map a structure-view index to its row/column in the local distance matrix
        idx2ssidx = {
            atom_idx: i
            for i, atom_idx in enumerate(nearby_atom_indices)
        }
        s_dist = distance.cdist(nearby_positions, nearby_positions,
                                "sqeuclidean")**0.5

        # start for loop at one since we've searched for starting atoms (index == 0) above
        for i in range(1, len(pattern)):
            if not match_index_tuples:
                break
            # The element filter does not depend on the partial match, so apply it once per round.
            candidates = [(ss_idx, atom_idx)
                          for ss_idx, atom_idx in enumerate(nearby_atom_indices)
                          if s_types_view[atom_idx] == pattern_elements[i]]
            last_match_index_tuples = match_index_tuples
            match_index_tuples = []
            for match in last_match_index_tuples:
                # local distance-matrix rows of the atoms already in this partial match
                match_ss = [idx2ssidx[m] for m in match]
                for ss_idx, atom_idx in candidates:
                    # anything that matches the distance to all prior pattern atoms is a good match so far
                    if all(
                            math.isclose(p_dist[i, j],
                                         s_dist[m_ss, ss_idx],
                                         abs_tol=abstol)
                            for j, m_ss in enumerate(match_ss)):
                        match_index_tuples.append(match + [atom_idx])
            if verbose:
                print(
                    "round %d (%d) [%s]: " %
                    (i, len(match_index_tuples), pattern_elements[i]),
                    match_index_tuples)
        if verbose:
            print("starting atom %d: found %d matches: %s" %
                  (a_idx, len(match_index_tuples), match_index_tuples))
        all_match_index_tuples += match_index_tuples

    # Matches that map to the same set of unit-cell indices (in any order) are kept only once.
    all_match_index_tuples = remove_duplicates(
        all_match_index_tuples,
        key=lambda m: tuple(
            sorted(index_mapper[i] % num_uc_atoms for i in m)))

    match_index_tuples_in_uc = [
        tuple(index_mapper[m] % num_uc_atoms for m in match)
        for match in all_match_index_tuples
    ]
    if return_positions:
        match_index_tuple_positions = np.array(
            [[s_positions[index_mapper[m]] for m in match]
             for match in all_match_index_tuples])
        return match_index_tuples_in_uc, match_index_tuple_positions
    return match_index_tuples_in_uc