def count_province_union_find(is_connected: List[List[int]]) -> int:
    """
    Count the connected components ("provinces") of an undirected graph given as an adjacency matrix.

    :param is_connected: matrix of 0 and 1 such that is_connected[i][j] = 1 if i and j are directly connected
    :return: number of connected components
    """
    city_count = len(is_connected)
    provinces = UnionFindArray(city_count)
    # Connections are assumed symmetric, so only entries strictly above the diagonal are inspected.
    for row_idx in range(city_count):
        row = is_connected[row_idx]
        for col_idx in range(row_idx + 1, len(row)):
            if row[col_idx]:
                provinces.unify(row_idx, col_idx)
    return provinces.components_count()
def min_swap_couples_union_find(row: List[int]) -> int:
    """
    Count the minimum number of swaps needed so that every couple (2k, 2k + 1) sits on one bench
    (seats 2i and 2i + 1).

    Persons are grouped with union-find:
    * persons sitting on the same bench, row[2 * i] and row[2 * i + 1], are unified;
    * the two members of a couple, 2 * i and 2 * i + 1, are unified.
    Each resulting component is a closed cycle of benches and couples. Swaps are only ever needed inside such a
    cycle, so the answer is the sum of the minimum swaps of every cycle.

    Claim 1: a cycle with m persons (m // 2 couples) can be fixed with at most m // 2 - 1 swaps.
    Proof: take any person whose partner is not on the same bench, find the partner and swap them onto that bench.
        Every swap seats one couple; once m // 2 - 1 couples are seated, the last couple is seated automatically.

    Claim 2: a cycle with m persons needs at least m // 2 - 1 swaps.
    Proof by contradiction: if fewer swaps sufficed, some swap would have to seat two couples at once. Those two
        couples would then have formed a closed cycle on their own, i.e. a separate component, contradicting the
        fact that the cycle we started from cannot be split any further.

    Conclusion: by claims 1 and 2, a cycle with m persons needs exactly m // 2 - 1 swaps.

    :param row: 4 <= len(row) <= 60, and row is a permutation of range(len(row))
    :return: minimum number of swaps to make couples (2 * i, 2 * i + 1) sit together
    """
    seat_count = len(row)
    groups = UnionFindArray(seat_count)
    for left_seat in range(0, seat_count, 2):
        # the couple (left_seat, left_seat + 1) belongs together ...
        groups.unify(left_seat, left_seat + 1)
        # ... and so do the two persons currently sharing this bench
        groups.unify(row[left_seat], row[left_seat + 1])

    total_swaps = 0
    for person in range(seat_count):
        if groups.find(person) != person:
            # only count each component once, via its root
            continue
        total_swaps += groups.component_size(person) // 2 - 1
    return total_swaps
def find_redundant_connections_union_find(
        edges: List[Tuple[int, int]]) -> Tuple[int, int]:
    """
    Find the edge whose removal turns the graph back into a tree.

    The graph is a tree with 3 <= n <= 1000 nodes labelled 1 .. n plus one additional edge that closes a cycle
    (so the graph is connected and len(edges) == n).

    :param edges: list of (ai, bi) representing an edge between ai and bi,
        1 <= ai < bi <= len(edges),
        no repeated edges
    :return: the first edge (in list order) whose endpoints are already connected by earlier edges, which is the
        edge of the cycle that appears last in the edges list; None if no edge closes a cycle
    """
    # node labels go up to len(edges), hence one extra slot (index 0 is unused)
    components = UnionFindArray(len(edges) + 1, use_recursion=True)
    for candidate in edges:
        already_joined = components.is_connected(*candidate)
        if already_joined:
            return candidate
        components.unify(*candidate)
    return None
Beispiel #4
0
def accounts_merge_union_find(accounts: List[List[str]]) -> List[List[str]]:
    """
    Merge accounts that share at least one email address.

    Every distinct email gets a dense integer id (in first-seen order). All emails of an account are unified with
    that account's first email, so accounts sharing any address end up in the same union-find component. Each
    merged account is named after the first account that mentioned its first-seen email.

    :param accounts: List of ["name", "email1", "email2", ...].
        - 1 <= len(accounts) <= 1000
        - 2 <= len(accounts[i]) <= 10
        - 1 <= len(accounts[i][j]) <= 30
    :return: post merge list of ["name", "email1", "email2", ...] where emails are sorted. Accounts can be returned in
        any order. Accounts without any email (name only) contribute nothing; an empty input yields [].
    """
    # Pass 1: assign ids and remember the name of the first account that mentions each email.
    email_to_name_lookup = {}
    email_to_id_lookup = {}
    for account_i in accounts:
        person_name = account_i[0]
        for email in account_i[1:]:
            if email not in email_to_id_lookup:
                email_to_name_lookup[email] = person_name
                email_to_id_lookup[email] = len(email_to_id_lookup)

    if not email_to_id_lookup:
        # Nothing to merge; also avoids constructing a zero-sized union-find.
        return []

    # Pass 2: size the union-find to exactly the number of distinct emails (instead of a fixed capacity that
    # silently caps the input) and unify every email with its account's first email.
    # The second positional argument is kept as in the original call (presumably use_recursion).
    union_find_object = UnionFindArray(len(email_to_id_lookup), True)
    for account_i in accounts:
        if len(account_i) < 2:
            # name-only account: no email to merge on
            continue
        main_email_id = email_to_id_lookup[account_i[1]]
        for email in account_i[2:]:
            union_find_object.unify(main_email_id, email_to_id_lookup[email])

    # Group the emails by the root of their component.
    union_id_emails: DefaultDict[int, List[str]] = defaultdict(list)
    for email, email_id in email_to_id_lookup.items():
        union_id_emails[union_find_object.find(email_id)].append(email)

    return [[email_to_name_lookup[email_list[0]]] + sorted(email_list)
            for email_list in union_id_emails.values()]
Beispiel #5
0
def largest_component_size(num_list: List[int]) -> int:
    """
    Size of the largest connected component of the graph whose nodes are the entries of num_list and in which two
    entries are linked when they share an element of get_prime_set (presumably a common prime factor).

    :param num_list: a non-empty array of unique positive integers A
    :return: largest component size in the graph
    """
    union_find = UnionFindArray(len(num_list), use_recursion=True)

    # prime -> positions (in num_list order) of the numbers whose prime set contains it
    indices_by_prime = defaultdict(list)
    for position, value in enumerate(num_list):
        for prime in get_prime_set(value):
            indices_by_prime[prime].append(position)

    # Linking consecutive holders of the same prime is enough to put all of them in one component.
    for holders in indices_by_prime.values():
        for left, right in zip(holders, holders[1:]):
            union_find.unify(left, right)

    return max(union_find.component_size(node) for node in range(union_find.size()))
    assert test_union_find.component_size(e2) == expected_component_size
    if not previously_connected:
        expected_component_count -= 1
    assert test_union_find.components_count() == expected_component_count
# The checks below use test_union_find, elements and join_list, which are defined elsewhere in the file
# (not visible here).
assert test_union_find.components_count() == 1

# Looking up an element that was never added returns the ELEMENT_NOT_FOUND sentinel ...
assert test_union_find.find('Z') == test_union_find.ELEMENT_NOT_FOUND
# ... while unifying with such an element is expected to raise ValueError.
# NOTE(review): if unify() does not raise, this try block passes silently; an explicit failure after the call
# would make the check meaningful.
try:
    test_union_find.unify('A', 'Z')
except ValueError as e:
    print("Expected Value Error Message:", e)

print("Testing Union Find Array")
# The array-based union-find is addressed by integer index, so map each element to its position in `elements`.
mapping = {e: i for i, e in enumerate(elements)}
# Creating a zero-sized array is expected to raise ValueError (same silent-pass caveat as above).
try:
    failed_creation_empty_array = UnionFindArray(0)
except ValueError as e:
    print("Expected Value Error Message:", e)

test_union_find_array = UnionFindArray(len(elements))
# Without a use_recursion argument, recursion is off.
assert not test_union_find_array.use_recursion
# A negative index is expected to be rejected with ValueError (same silent-pass caveat as above).
try:
    test_union_find_array.is_connected(-1, mapping['A'])
except ValueError as e:
    print("Expected Value Error Message:", e)

# A fresh array holds one singleton component per element.
assert test_union_find_array.size() == len(elements)
assert test_union_find_array.components_count() == len(elements)
# Running component count; presumably decremented by the join_list loop that follows whenever a join merges
# two previously separate components.
expected_component_count = len(elements)
for e1, e2, previously_connected, expected_component_size in join_list:
    assert test_union_find_array.is_connected(
def find_critical_and_pseudo_critical_edges(
        n: int, edges: List[List[int]]) -> List[List[int]]:
    """
    Classify the edges of a weighted undirected graph as critical or pseudo-critical with respect to its minimum
    spanning trees (MST).

    Edges are added in batches of equal weight, in ascending order of weight (Kruskal order). Before each batch,
    every group of nodes already joined by lighter edges is collapsed into a single node (its union-find root),
    giving a "union graph". Within a batch:
      * an edge whose endpoints already share a root would close a cycle with lighter edges and is skipped;
      * if several edges join the same pair of roots, all of them are pseudo-critical;
      * one representative edge per pair of roots is put into the union graph, and a DFS marks the edges that lie
        on no cycle of that graph (bridges) as critical, unless they are already pseudo-critical;
      * every other representative edge of the batch is pseudo-critical.

    :param n: number of nodes in the original graph; node labels must lie in range(n)
    :param edges: list of [u, v, weight] edges of the original weighted undirected graph; edges are identified by
        their index in this list
    :return: [sorted indices of critical edges, sorted indices of pseudo-critical edges]
    """

    # sentinel marking a node that the DFS has not reached yet
    _dfs_not_visited = -1

    def find_critical_connection(current_node: int,
                                 level: int = 0,
                                 previous_node: int = _dfs_not_visited) -> int:
        """
        DFS over union_graph that marks edges lying on no cycle (bridges) as critical.

        levels[node] is first set to the node's DFS depth and is then lowered to the smallest level reachable from
        its DFS subtree without going back over the incoming edge. If that value for a child is deeper than
        current_node, the edge current_node -> child cannot be bypassed and is therefore a bridge.
        levels, union_graph, edge_critical and edge_pseudos are read from the enclosing function.
        NOTE(review): recursive; depth grows with the number of unions joined by one batch.

        :param current_node: node whose outgoing edges are examined
        :param level: DFS depth of current_node
        :param previous_node: node we arrived from. Do NOT traverse the edge <previous_node, current_node> back
        :return: lowest level reachable from current_node's subtree (its final levels[] value)
        """
        levels[current_node] = level
        for child, edge_i in union_graph[current_node]:
            if child == previous_node:
                # do not go back from the incoming path
                continue
            elif levels[child] == _dfs_not_visited:
                # tree edge: explore child, then inherit the lowest level reachable from its subtree
                levels[current_node] = min(
                    levels[current_node],
                    find_critical_connection(child, level + 1, current_node))
            else:
                # non-tree edge to an already visited node: it closes a cycle
                levels[current_node] = min(levels[current_node], levels[child])
            if levels[child] >= level + 1 and edge_i not in edge_pseudos:
                # edge_i (current_node -> child) lies on no cycle of the current snapshot, i.e. it is a bridge.
                # No lighter edge joined these two unions (otherwise they would share a root) and no parallel
                # edge of the same weight exists (it is not pseudo-critical), so every MST must contain it.
                edge_critical.add(edge_i)
        return levels[current_node]

    # Initialize critical and pseudo-critical edge index sets
    edge_critical, edge_pseudos = set(), set()

    # use weight_distribution to break edges into weight classes: weight -> [(u, v, edge index), ...]
    weight_distribution = defaultdict(list)
    for i, (u, v, w) in enumerate(edges):
        weight_distribution[w].append((u, v, i))

    # union-find over the original nodes; records which nodes lighter edges have already joined
    union_set = UnionFindArray(n, use_recursion=True)

    # iterate through all weights in ascending order (Kruskal order)
    for weight_class in sorted(weight_distribution):
        # connections_between[(union_u, union_v)] contains all edges of this weight connecting union union_u and
        # union_v, where union_u < union_v are the root nodes of their corresponding groups
        connections_between = defaultdict(set)
        # populate connections_between
        for u, v, i in weight_distribution[weight_class]:
            union_u, union_v = union_set.find(u), union_set.find(v)

            if union_u != union_v:
                # Edge i joins two different unions, so it is a candidate MST edge.
                # (Edges whose endpoints already share a root would close a cycle with lighter edges: skipped.)
                connections_between[min(union_u, union_v),
                                    max(union_u, union_v)].add(i)

        # w_edges holds one representative edge (union_u, union_v, edge index) per pair of unions joined by this
        # batch; each representative ends up in either edge_critical or edge_pseudos
        w_edges = []
        # construct a snapshot of current graph of Unions to run DFS on
        union_graph = defaultdict(list)
        for union_u, union_v in connections_between:
            # if more than 1 edge can connect union_u and union_v, all of these edges are pseudo-critical
            # (this runs before the pop below, so the representative is included as well)
            if len(connections_between[union_u, union_v]) > 1:
                edge_pseudos |= connections_between[union_u, union_v]

            # Connect union_u and union_v in the Union Graph using one representative edge only,
            # so parallel edges do not create spurious cycles
            edge_idx = connections_between[union_u, union_v].pop()
            union_graph[union_u].append((union_v, edge_idx))
            union_graph[union_v].append((union_u, edge_idx))
            w_edges.append((union_u, union_v, edge_idx))
            # merge for later (heavier) batches; this batch's snapshot keeps using the roots computed above
            union_set.unify(union_u, union_v)

        # run find_critical_connection to mark all critical w_edges
        # (levels is reset per batch; the nested function reads it and union_graph from this scope)
        levels = [_dfs_not_visited] * n
        for u, v, i in w_edges:
            if levels[u] == _dfs_not_visited:
                find_critical_connection(u)

        # representative edges not marked critical lie on a cycle of the union graph, so they are pseudo-critical
        for u, v, i in w_edges:
            if i not in edge_critical:
                edge_pseudos.add(i)

    return [sorted(list(edge_critical)), sorted(list(edge_pseudos))]
Beispiel #8
0
def regions_by_slashes(grid: List[str]) -> int:
    r"""
    Count the regions into which the slashes in ``grid`` divide it.

    Each grid cell is split into 4 triangular sub cells: [0 - up, 1 - left, 2 - right, 3 - down].
    Inside a cell, '/' joins up with left and right with down, '\' joins up with right and left with down, and any
    other character (a blank) joins all four. Adjacent cells are glued along their shared side, and the answer is
    the number of union-find components over all sub cells.

    :param grid: rows of equal length with grid[i][j] in ('/', '\', ' '); the original problem guarantees a square
        grid with 1 <= len(grid) == len(grid[0]) <= 30, but any rectangular grid is accepted
    :return: number of regions in the grid; 0 for an empty grid
    """
    rows = len(grid)
    cols = len(grid[0]) if rows else 0
    if rows == 0 or cols == 0:
        # No cells at all; also avoids building a zero-sized union-find.
        return 0

    union_find_object = UnionFindArray(4 * rows * cols)
    for r, row_r in enumerate(grid):
        for c, cell_r_c in enumerate(row_r):
            # index of sub cell 0 (up) of cell (r, c); sub cells 1..3 follow it.
            # The row stride is the column count, which also makes non-square grids work.
            cell_up_position = 4 * (r * cols + c)
            if cell_r_c == '/':
                union_find_object.unify(cell_up_position + 0,
                                        cell_up_position + 1)
                union_find_object.unify(cell_up_position + 2,
                                        cell_up_position + 3)
            elif cell_r_c == '\\':
                union_find_object.unify(cell_up_position + 0,
                                        cell_up_position + 2)
                union_find_object.unify(cell_up_position + 1,
                                        cell_up_position + 3)
            else:
                union_find_object.unify(cell_up_position + 0,
                                        cell_up_position + 1)
                union_find_object.unify(cell_up_position + 0,
                                        cell_up_position + 2)
                union_find_object.unify(cell_up_position + 0,
                                        cell_up_position + 3)

            # Glue only to the neighbours above and to the left: the link to the cell below (right) is made when
            # that cell is processed and glues itself upwards (leftwards), so repeating it would be redundant.
            if r > 0:
                # up triangle of (r, c) touches the down triangle of (r - 1, c)
                union_find_object.unify(cell_up_position + 0,
                                        cell_up_position - 4 * cols + 3)
            if c > 0:
                # left triangle of (r, c) touches the right triangle of (r, c - 1)
                union_find_object.unify(cell_up_position + 1,
                                        cell_up_position - 4 + 2)

    return union_find_object.components_count()