Example #1
0
def _sbm_without_isolates(seed):
    """Sample the 8-node, two-block (4 + 4) SBM used by ``test_FGW``.

    Isolated nodes are removed. The returned graph may therefore have fewer
    than 8 nodes, and its integer labels need not be contiguous.
    """
    g = nx.stochastic_block_model([4, 4], [[0.9, 0.1], [0.1, 0.9]], seed=seed)
    g.remove_nodes_from(list(nx.isolates(g)))
    return g


def _graph_from_learned_laplacian(args, g):
    """Run ``graph.graph_dist`` on ``g``'s Laplacian and build a graph from its result.

    Side effects: sets ``args.m`` to the node count of ``g`` and ``args.Lx``
    to ``g``'s Laplacian as a float32 tensor.

    The third value returned by ``graph.graph_dist`` (``L``, presumably a
    Laplacian-like matrix -- TODO confirm) is turned into a weighted
    adjacency matrix ``A = -L`` with a zeroed diagonal, so there are no
    self-loops, and returned as a networkx graph.
    """
    args.m = len(g)
    # Index by g's real node labels. After isolates are removed the labels
    # can skip values, so range(args.m) could name nodes that are gone.
    lx = nx.laplacian_matrix(g, nodelist=sorted(g)).toarray()
    args.Lx = torch.from_numpy(lx).to(dtype=torch.float32)
    _, _, L = graph.graph_dist(args, plot=False)
    if isinstance(L, torch.Tensor):
        # detach() so this also works if L still carries autograd history.
        L = L.detach().cpu().numpy()
    # Negate into a fresh array first, so fill_diagonal does not write into L.
    # For a tensor, L shares memory with what graph_dist returned.
    A = -np.asarray(L)
    np.fill_diagonal(A, 0)
    return nx.from_numpy_array(A)


def test_FGW(args):
    """
    Fused Gromov-Wasserstein distance.

    For each of two seeded 8-node, two-block SBM graphs:
      * run ``graph.graph_dist`` on the graph's Laplacian;
      * build a graph from the matrix it returns;
      * print the FGW distance between the SBM graph and that graph.
    The FGW settings are alpha=0.8 and a squared-Euclidean feature metric.

    Mutates ``args``: ``n``, ``n_epochs``, ``m`` and ``Lx``. Set a truthy
    ``args.debug`` to drop into pdb at the end. This used to happen
    unconditionally, which blocked non-interactive runs.
    """
    import lib.graph as gwGraph
    from lib.ot_distances import Fused_Gromov_Wasserstein_distance

    args.n = 4
    args.n_epochs = 150
    if args.fix_seed:
        torch.manual_seed(0)

    gwdist = Fused_Gromov_Wasserstein_distance(alpha=0.8, features_metric='sqeuclidean')
    for seed in (8576, 452):
        g = _sbm_without_isolates(seed)
        learned = _graph_from_learned_laplacian(args, g)
        dist = gwdist.graph_d(gwGraph.Graph(g), gwGraph.Graph(learned))
        print('GW dist ', dist)

    if getattr(args, 'debug', False):
        pdb.set_trace()
Example #2
0
def post_process_graph(graph_dict):
    """Clean up every road graph in ``graph_dict`` and save the results.

    For each ``region_name -> g`` entry:

    * Drop every edge of a road segment when both of these hold:
      - its ``marked_length`` is below ``2 * cfg.TEST.STEP_LENGTH``;
      - its source or destination vertex has at most one incoming edge.
      These are presumably short dangling spurs.
    * Drop edges whose two endpoints lie at the same point.
    * Copy the surviving edges into a fresh ``graph_helper.Graph``, merging
      vertices that share a point into one.

    Each cleaned graph is written to
    ``<SAVE_GRAPH_DIR>/<CKPT>_<NUM_TARGETS>/post/<region_name>.graph``.
    """
    run_name = '{}_{}'.format(cfg.TEST.CKPT, cfg.TEST.NUM_TARGETS)
    out_dir = os.path.join(cfg.DIR.SAVE_GRAPH_DIR, run_name, 'post')
    os.makedirs(out_dir, exist_ok=True)

    for region_name, g in graph_dict.items():
        segments, _ = graph_helper.get_graph_road_segments(g)

        # Collect the edges of short segments with a weakly connected end.
        pruned = set()
        for seg in segments:
            is_short = seg.marked_length < 2 * cfg.TEST.STEP_LENGTH
            if not is_short:
                continue
            if len(seg.src(g).in_edges_id) > 1 and len(seg.dst(g).in_edges_id) > 1:
                continue
            pruned.update(seg.edges(g))

        # Rebuild the graph from the remaining edges, one vertex per point.
        merged = graph_helper.Graph()
        point_to_vid = {}
        for edge in g.edges.values():
            if edge in pruned:
                continue
            endpoints = (edge.src(g).point, edge.dst(g).point)
            if endpoints[0] == endpoints[1]:
                continue
            vids = []
            for pnt in endpoints:
                if pnt not in point_to_vid:
                    point_to_vid[pnt] = merged.add_vertex(pnt).id
                vids.append(point_to_vid[pnt])
            merged.add_edge(*vids)

        out_path = os.path.join(out_dir, '{}.graph'.format(region_name))
        merged.save(out_path, clear_self=False)
Example #3
0
def buildgraph(matrix):
    """Build a grid graph whose nodes are the cells of a 2-D ``matrix``.

    Each cell ``(row, col)`` is added as a node with ``matrix[row][col]`` as
    its value. It is linked to the cell below it and the cell to its right,
    when those exist. The column count is taken from the first row. The
    graph is created with ``edge_values=defaultdict(int)``, presumably giving
    edges a default value of 0 -- confirm against ``graph.Graph``.
    """
    grid = graph.Graph(edge_values=defaultdict(int))
    last_row = len(matrix) - 1
    for r, row in enumerate(matrix):
        last_col = len(matrix[0]) - 1
        for c in range(last_col + 1):
            here = (r, c)
            grid.add_node(here, row[c])
            if r < last_row:
                grid.add_edge(here, (r + 1, c))
            if c < last_col:
                grid.add_edge(here, (r, c + 1))
    return grid
    def search(self, request):
        """Return the files that contain every word of ``request``.

        The request is split on non-word characters, and each token is
        normalised with ``Coordinated_Dictionary.process_word``. Lookups go
        through ``self.__reversed_index``. Judging by its use here, it maps a
        processed word to ``{file_id: iterable of positions}``. File ids are
        turned into results via ``self.__files``.

        With two or more words, matching files are ranked by keyword
        proximity. A graph is built for each file:
          * each word is linked to its positions with weight 0;
          * every pair of positions is linked with weight ``abs(a - b)``.
        The file's score is the sum of ``graph.shortest_path`` values from
        each word to the words after it. Files are returned in ascending
        score order.

        Returns:
            list: the matching entries of ``self.__files``; empty when the
            request has no words or no file contains all of them.
        """
        start = time.process_time()
        index = self.__reversed_index

        # Skip the empty strings re.split yields for leading or trailing
        # punctuation ("hello, world!" -> ['hello', 'world', '']). Otherwise
        # such queries match nothing.
        words = [
            Coordinated_Dictionary.process_word(token)
            for token in re.split(r"\W+", request.rstrip())
            if token
        ]

        # 0- and 1-word requests
        if not words or words[0] not in index:
            return []
        if len(words) == 1:
            return [self.__files[doc] for doc in index[words[0]]]

        # Intersect the file sets of ALL words before short-circuiting. A
        # single survivor midway may still lack a later word.
        candidates = set(index[words[0]])
        for word in words[1:]:
            if word not in index:
                return []
            candidates.intersection_update(index[word])
            if not candidates:
                return []
        if len(candidates) == 1:
            return [self.__files[doc] for doc in candidates]

        # Build a weighted graph per file: keywords and positions are nodes.
        score = {}
        for doc in candidates:
            g = graph.Graph()
            positions = set()
            for word in words:
                g.add_node(word)
                for pos in index[word][doc]:
                    g.add_node(pos)
                    g.add_edge(word, pos, 0)
                    if isinstance(pos, int):
                        positions.add(pos)
            # Link every pair of positions once, after all nodes exist. These
            # edges used to be re-added after every word, with the same weights.
            for a in positions:
                for b in positions:
                    g.add_edge(a, b, abs(a - b))

            # Shortest paths from each word to the words that follow it.
            score[doc] = sum(
                graph.shortest_path(g, word, words[i + 1:])
                for i, word in enumerate(words[:-1])
            )

        # sorted() is stable, so ties keep the same relative order the old
        # bubble sort produced.
        ranked = sorted(candidates, key=score.__getitem__)
        end = time.process_time()
        print("Search took", end - start, "s")
        return [self.__files[doc] for doc in ranked]