Beispiel #1
0
    def test_gasp_abs_max(self):
        """GASP with the 'abs_max' linkage criteria.

        Expects node 1 to end up in a different cluster from node 0, while nodes 2
        and 3 share node 0's cluster. Relies on the ``self.g`` graph and the
        ``self.edgeIndicators`` signed weights defined elsewhere in the test class
        (presumably in ``setUp``).
        """
        cluster_policy = nagglo.get_GASP_policy(
            graph=self.g,
            signed_edge_weights=self.edgeIndicators,
            linkage_criteria='abs_max')

        agglomerative_clustering = nagglo.agglomerativeClustering(cluster_policy)
        agglomerative_clustering.run()
        # Plain Python list (as in test_gasp_sum) so failure messages show readable ints.
        seg = agglomerative_clustering.result().tolist()

        # One assertion per relation instead of assertTrue over a conjunction, so a
        # failure reports exactly which pair of labels was wrong.
        self.assertNotEqual(seg[0], seg[1])
        self.assertEqual(seg[0], seg[2])
        self.assertEqual(seg[0], seg[3])
Beispiel #2
0
    def test_gasp_sum(self):
        """GASP with the 'sum' linkage criteria and no cannot-link constraints.

        Expects node 1 to end up in a different cluster from node 0, while nodes 2
        and 3 share node 0's cluster. Relies on the ``self.g`` graph and the
        ``self.edgeIndicators`` signed weights defined elsewhere in the test class
        (presumably in ``setUp``).
        """
        cluster_policy = nagglo.get_GASP_policy(
            graph=self.g,
            signed_edge_weights=self.edgeIndicators,
            linkage_criteria='sum',
            add_cannot_link_constraints=False)

        agglomerative_clustering = nagglo.agglomerativeClustering(cluster_policy)
        agglomerative_clustering.run()
        seg = agglomerative_clustering.result().tolist()

        # One assertion per relation instead of assertTrue over a conjunction, so a
        # failure reports exactly which pair of labels was wrong.
        self.assertNotEqual(seg[0], seg[1])
        self.assertEqual(seg[0], seg[2])
        self.assertEqual(seg[0], seg[3])
Beispiel #3
0
def run_GASP(
        graph,
        signed_edge_weights,
        linkage_criteria='mean',
        add_cannot_link_constraints=False,
        edge_sizes=None,
        is_mergeable_edge=None,
        use_efficient_implementations=True,
        verbose=False,
        linkage_criteria_kwargs=None,
        print_every=100000):
    """
    Run the Generalized Algorithm for Agglomerative Clustering on Signed Graphs (GASP).
    The C++ implementation is currently part of the nifty library (https://github.com/abailoni/nifty).

    Parameters
    ----------
    graph : nifty.graph
        Instance of a graph, e.g. nifty.graph.UndirectedGraph, nifty.graph.undirectedLongRangeGridGraph or
        nifty.graph.rag.gridRag

    signed_edge_weights : numpy.array(float) with shape (nb_graph_edges, )
        Attractive weights are positive; repulsive weights are negative.

    linkage_criteria : str (default 'mean')
        Specifies the linkage criteria / update rule used during agglomeration.
        List of available criteria:
            - 'mean', 'average', 'avg'
            - 'max', 'single_linkage'
            - 'min', 'complete_linkage'
            - 'mutex_watershed', 'abs_max'
            - 'sum'
            - 'quantile', 'rank' keeps statistics in a histogram, with parameters:
                    * q : float (default 0.5 equivalent to the median)
                    * numberOfBins: int (default: 40)
            - 'generalized_mean', 'gmean' with parameters:
                    * p : float (default: 1.0)
                    * https://en.wikipedia.org/wiki/Generalized_mean
            - 'smooth_max', 'smax' with parameters:
                    * p : float (default: 0.0)
                    * https://en.wikipedia.org/wiki/Smooth_maximum

    add_cannot_link_constraints : bool (default: False)
        Forwarded to nifty's `get_GASP_policy`. When True, the efficient 'max' implementation
        is not used (see `use_efficient_implementations`).

    edge_sizes : numpy.array(float) with shape (nb_graph_edges, )
        Depending on the linkage criteria, they can be used during the agglomeration to weight differently
        the edges (e.g. with sum or avg linkage criteria). Commonly used with regionAdjGraphs when edges
        represent boundaries of different length between segments / super-pixels. By default, all edges have
        the same weighting.

    is_mergeable_edge : numpy.array(bool) with shape (nb_graph_edges, )
        Specifies if an edge can be merged or not. Sometimes some edges represent direct-neighbor relations
        and others describe long-range connections. If a long-range connection / edge is assigned to
        `is_mergeable_edge = False`, then the two associated nodes are not merged until they become
        direct neighbors and they get connected in the image-plane.
        By default all edges are mergeable.

    use_efficient_implementations : bool (default: True)
        In the following special cases, alternative efficient implementations from the affogato
        module are used instead of nifty:
            - 'mutex_watershed' / 'abs_max' criteria: Mutex Watershed (https://github.com/hci-unihd/mutex-watershed.git)
            - 'max' criteria without cannot-link constraints: maximum spanning tree
        In these cases `edge_sizes`, `linkage_criteria_kwargs`, `verbose` and `print_every` are
        ignored, and `is_mergeable_edge` is NOT enforced (only a warning is printed when some
        edges are not mergeable). Only the exact string 'max' selects the spanning-tree path;
        its alias 'single_linkage' is always run through nifty.

    verbose : bool (default: False)
        Forwarded to nifty's agglomeration `run` (unused by the efficient implementations).

    linkage_criteria_kwargs : dict
        Additional optional parameters passed to the chosen linkage criteria (see previous list)

    print_every : int (default: 100000)
        After how many agglomeration iterations to print in verbose mode

    Returns
    -------
    node_labels : numpy.array(uint) with shape (nb_graph_nodes, )
        Node labels representing the final clustering

    runtime : float
        Wall-clock seconds spent in the clustering call itself (setting up the nifty cluster
        policy is not included).

    Raises
    ------
    ImportError
        If an efficient implementation is selected but the affogato module (`aff_segm`) is
        not available.
    """
    use_mws = linkage_criteria in ('mutex_watershed', 'abs_max')
    use_spanning_tree = linkage_criteria == 'max' and not add_cannot_link_constraints

    if use_efficient_implementations and (use_mws or use_spanning_tree):
        if is_mergeable_edge is not None and not is_mergeable_edge.all():
            print("WARNING: Efficient implementations only works when all edges are mergeable")
        if aff_segm is None:
            # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
            # which would turn this into an obscure AttributeError on None below.
            raise ImportError("For the efficient implementation of GASP, affogato module is needed")

        nb_nodes = graph.numberOfNodes
        uv_ids = graph.uvIds()
        mutex_edges = signed_edge_weights < 0.
        attractive_edges = np.logical_not(mutex_edges)

        clustering_fn = (aff_segm.compute_mws_clustering if use_mws
                         else aff_segm.compute_single_linkage_clustering)

        tick = time.time()
        # These implementations use the convention where all edge weights are positive:
        # attractive weights are passed as-is, repulsive (mutex) weights are negated.
        node_labels = clustering_fn(nb_nodes,
                                    uv_ids[attractive_edges],
                                    uv_ids[mutex_edges],
                                    signed_edge_weights[attractive_edges],
                                    -signed_edge_weights[mutex_edges])
        runtime = time.time() - tick
    else:
        cluster_policy = nifty_agglo.get_GASP_policy(graph, signed_edge_weights,
                                                     edge_sizes=edge_sizes,
                                                     linkage_criteria=linkage_criteria,
                                                     linkage_criteria_kwargs=linkage_criteria_kwargs,
                                                     add_cannot_link_constraints=add_cannot_link_constraints,
                                                     is_mergeable_edge=is_mergeable_edge)
        agglomerative_clustering = nifty_agglo.agglomerativeClustering(cluster_policy)

        # Only the agglomeration itself is timed.
        tick = time.time()
        agglomerative_clustering.run(verbose=verbose,
                                     printNth=print_every)
        runtime = time.time() - tick

        node_labels = agglomerative_clustering.result()
    return node_labels, runtime
def runGreedyGraphEdgeContraction(graph,
                                  signed_edge_weights,
                                  linkage_criteria='mean',
                                  add_cannot_link_constraints=False,
                                  edge_sizes=None,
                                  node_sizes=None,
                                  is_merge_edge=None,
                                  size_regularizer=0.0,
                                  return_UCM=False,
                                  return_agglomeration_data=False,
                                  ignored_edge_weights=None,
                                  **run_kwargs):
    """
    Deprecated legacy entry point: always raises ``DeprecationWarning``.

    The error message directs callers to the version in the GASP repository. Within this
    module, ``run_GASP`` offers a similar agglomeration interface (note that it returns
    ``(node_labels, runtime)`` rather than ``(node_labels, out_dict)``).

    The original signature is kept so that existing call sites fail with the deprecation
    message instead of a ``TypeError`` about unexpected arguments. None of the arguments
    are used.

    Raises
    ------
    DeprecationWarning
        Unconditionally, on every call.
    """
    # The former implementation followed this statement and was therefore unreachable;
    # it has been removed (retrieve it from version control if it is ever needed).
    raise DeprecationWarning("use version in GASP repo instead")