def run_cyclic_graph(args): """ Test for sketching Cyclic graph, and some other sanity checks. """ #e.g. test if Lx and Ly are the same, then dist is very small args.m = 8 args.n = 4 ''' if args.fix_seed: torch.manual_seed(0) #args.Lx = torch.randn(args.m*(args.m-1)//2) #torch.FloatTensor([[1, -1], [-1, 2]]) #args.Lx = realize_upper(args.Lx, args.m) #pdb.set_trace() g = nx.stochastic_block_model([6,6],[[0.9,0.1],[0.1,0.9]], seed = 8576) #components = nx.connected_components(g) g.remove_nodes_from(list(nx.isolates(g))) args.m = len(g) Lx = nx.laplacian_matrix(g, range(args.m)).todense() args.Lx = torch.from_numpy(Lx).to(dtype=torch.float32) #+ torch.ones(args.m, args.m)/args.m ''' ones = torch.ones((args.m, args.m), dtype=torch.uint8) Lx = torch.zeros((args.m, args.m)) Lx[torch.triu(ones, diagonal=1) & torch.tril(ones, diagonal=1)] = -1 Lx[0, -1] = -1 Lx += Lx.t() Lx[torch.eye(args.m) > 0] = -Lx.sum(0) args.Lx = Lx args.n_epochs = 370 #370 graph.graph_dist(args)
def test_FGW(args): """ Fused Gromov-Wasserstein distance """ args.m = 8 args.n = 4 if args.fix_seed: torch.manual_seed(0) g = nx.stochastic_block_model([4, 4], [[0.9, 0.1], [0.1, 0.9]], seed=8576) #components = nx.connected_components(g) g.remove_nodes_from(list(nx.isolates(g))) args.m = len(g) g2 = nx.stochastic_block_model([4, 4], [[0.9, 0.1], [0.1, 0.9]]) #components = nx.connected_components(g) g2.remove_nodes_from(list(nx.isolates(g2))) args.n = len(g2) gwdist = Fused_Gromov_Wasserstein_distance(alpha=0.8, features_metric='sqeuclidean') graph.graph_dist(args) ''' g = gwGraph.Graph(g) g2 = gwGraph.Graph(g2) dist = gwdist.graph_d(g,g2) print('GW dist ', dist) ''' pdb.set_trace()
def test_FGW(args): """ Fused Gromov-Wasserstein distance """ import lib.graph as gwGraph from lib.ot_distances import Fused_Gromov_Wasserstein_distance args.m = 8 args.n = 4 if args.fix_seed: torch.manual_seed(0) #args.Lx = torch.randn(args.m*(args.m-1)//2) #torch.FloatTensor([[1, -1], [-1, 2]]) #args.Lx = realize_upper(args.Lx, args.m) #pdb.set_trace() g = nx.stochastic_block_model([4,4],[[0.9,0.1],[0.1,0.9]], seed = 8576) #components = nx.connected_components(g) g.remove_nodes_from(list(nx.isolates(g))) args.m = len(g) Lx = nx.laplacian_matrix(g, range(args.m)).todense() args.Lx = torch.from_numpy(Lx).to(dtype=torch.float32) #+ torch.ones(args.m, args.m)/args.m args.n_epochs = 150 ''' g2 = nx.stochastic_block_model([4,4],[[0.9,0.1],[0.1,0.9]]) g2.remove_nodes_from(list(nx.isolates(g2))) args.n = len(g2) ''' loss, P, L = graph.graph_dist(args, plot=False) if isinstance(L, torch.Tensor): L = L.numpy() np.fill_diagonal(L, 0) A = -L g2 = nx.from_numpy_array(A) gwdist = Fused_Gromov_Wasserstein_distance(alpha=0.8,features_metric='sqeuclidean') g = gwGraph.Graph(g) g2 = gwGraph.Graph(g2) dist = gwdist.graph_d(g,g2) print('GW dist ', dist) ### g3 = nx.stochastic_block_model([4,4],[[0.9,0.1],[0.1,0.9]],seed=452) g3.remove_nodes_from(list(nx.isolates(g3))) args.m = len(g3) Lx = nx.laplacian_matrix(g3, range(args.m)).todense() args.Lx = torch.from_numpy(Lx).to(dtype=torch.float32) #+ torch.ones(args.m, args.m)/args.m loss2, P2, L2 = graph.graph_dist(args, plot=False) L=L2 if isinstance(L, torch.Tensor): L = L.numpy() np.fill_diagonal(L, 0) A = -L g4 = nx.from_numpy_array(A) #gwdist = Fused_Gromov_Wasserstein_distance(alpha=0.8,features_metric='sqeuclidean') g3 = gwGraph.Graph(g3) g4 = gwGraph.Graph(g4) dist = gwdist.graph_d(g3,g4) print('GW dist ', dist) pdb.set_trace()
def run_same_dim(args): """ When m = n, Ly converges to Lx and P converges to identity mx. """ args.Lx = torch.eye(args.m)*torch.abs(torch.randn((args.m, args.m)))*2 #utils.symmetrize(torch.randn((args.m, args.m))) args.m = 5 args.n = 5 args.Lx = torch.randn(args.m*(args.m-1)//2) #torch.FloatTensor([[1, -1], [-1, 2]]) args.Lx = graph.realize_upper(args.Lx, args.m) #args.Lx = torch.exp(torch.FloatTensor([[2, -2], [-2, 1]])) #good initializations?! checks & stability args.n_epochs = 280 graph.graph_dist(args) return
def run_community_graph(args): #test if Lx and Ly are the same, then dist should be small! #laplacian make integral at end. Inverse is often quite small for images! ?? leading to tiny evals. even neg #args.Lx = torch.eye(args.m)*torch.abs(torch.randn((args.m, args.m)))*2 #utils.symmetrize(torch.randn((args.m, args.m))) args.m = 12 args.n = 4 if args.fix_seed: torch.manual_seed(0) g = nx.stochastic_block_model([6, 6], [[0.9, 0.1], [0.1, 0.9]], seed=8576) #components = nx.connected_components(g) g.remove_nodes_from(list(nx.isolates(g))) args.m = len(g) Lx = nx.laplacian_matrix(g, range(args.m)).todense() args.Lx = torch.from_numpy(Lx).to( dtype=torch.float32) #+ torch.ones(args.m, args.m)/args.m args.n_epochs = 370 #100 graph.graph_dist(args)
def run_community_graph(args): args.m = 12 args.n = 4 if args.fix_seed: torch.manual_seed(0) #args.Lx = torch.randn(args.m*(args.m-1)//2) #torch.FloatTensor([[1, -1], [-1, 2]]) #args.Lx = realize_upper(args.Lx, args.m) #pdb.set_trace() g = nx.stochastic_block_model([6,6],[[0.9,0.1],[0.1,0.9]], seed = 8576) #components = nx.connected_components(g) g.remove_nodes_from(list(nx.isolates(g))) args.m = len(g) Lx = nx.laplacian_matrix(g, range(args.m)).todense() args.Lx = torch.from_numpy(Lx).to(dtype=torch.float32) #+ torch.ones(args.m, args.m)/args.m args.n_epochs = 370 #100 graph.graph_dist(args)
def sketch_graph(args): data_dim = 20 lo_dim = 5 g1 = utils.create_graph(data_dim, 'random_regular') #args.n_epochs = 300 <--parameters like this can be set here or in command line args.Lx = utils.graph_to_lap(g1) args.m = len(args.Lx) args.n = lo_dim # sketch graphs of lo_dim. # Returns optimization loss, transport plan P, and Laplacian of sketched graph loss, P, Ly = graph.graph_dist(args, plot=False) print('sketched graph Laplacian {}'.format(Ly)) #can convert Ly to a networkx graph with utils.lap_to_graph(Ly) return loss, P, Ly
def sketch_graph(graphs, lo_dim, args): ''' Run graph sketching. Input: graphs: graphs to be dimension-reduced for.. ''' args.n = lo_dim lo_graphs = [] args.n_epochs = 230 for g in tqdm(graphs, desc='sketching'): args.Lx = utils.graph_to_lap(g) #graph.graph_dist(args, plot=False) args.m = len(args.Lx) #sys.stdout.write(' ' +str(len(g.nodes()))) #sys.stdout.write(str(args.m) +' ') loss, P, Ly = graph.graph_dist(args, plot=False) lo_graphs.append(utils.lap_to_graph(Ly)) return lo_graphs
def sketch_graph(graphs, dataset_cls, lo_dim, args): ''' Run graph sketching. Input: graphs: graphs to be dimension-reduced for. ''' args.n = lo_dim lo_graphs = [] lo_cls = [] args.n_epochs = 230 #250 for i, g in enumerate(tqdm(graphs, desc='sketching')): args.Lx = utils.graph_to_lap(g) #graph.graph_dist(args, plot=False) args.m = len(args.Lx) try: #rarely, 0.2% of time pytorch's eigenvalue finding doesn't converge loss, P, Ly = graph.graph_dist(args, plot=False) except RuntimeError as e: # pdb.set_trace() print(e) continue lo_graphs.append(utils.lap_to_graph(Ly)) lo_cls.append(dataset_cls[i]) return lo_graphs, lo_cls
def classify_st(dataset, queries, dataset_cls, target, args, dataset0=None, queries0=None): """ classify graphs. Can be used to compare COPT, GOT. Input: dataset, queries: could be sketched or non-sketched. dataset0, queries0: original, non-sketched graphs. """ if dataset0 is None: dataset0 = dataset queries0 = queries n_data = len(dataset) n_queries = len(queries) ot_cost = np.zeros((len(queries), len(dataset))) st_cost = np.zeros((len(queries), len(dataset))) Ly_mx = [] Lx_mx = [] data_graphs = [] for i, data in enumerate(dataset): n_nodes = len(data.nodes()) L = utils.graph_to_lap(data) #Ly_mx.append(L[torch.triu(torch.ones(n_nodes, n_nodes), diagonal=1) > 0]) Ly_mx.append(L) #pdb.set_trace() for i, q in enumerate(tqdm(queries, desc='queries')): Lx = utils.graph_to_lap(q) args.Lx = Lx args.m = len(q.nodes()) Lx_mx.append(args.Lx) n_repeat = 1 #1 works fine for j, data in enumerate(dataset): Ly = Ly_mx[j].clone() args.n = len(Ly) min_loss = 10000 for _ in range(n_repeat): loss, P, Ly_ = graph.graph_dist(args, plot=False, Ly=Ly, take_ly_exp=False) if loss < min_loss: min_loss = loss ot_cost[i][j] = min_loss try: x_reg, y_reg, (P_st, loss_st) = st.find_permutation( Lx.cpu().numpy(), Ly.cpu().numpy(), args.st_it, args.st_tau, args.st_n_samples, args.st_epochs, args.st_lr, loss_type='w', alpha=0, ones=True, graphs=True) #l2 except Exception: print('Exception encountered during GOT') #pdb.set_trace() st_cost[i][j] = loss_st ##can also try median, or dataset_cls[np.argsort(ot_cost[-8],-1)[:10]], or dataset_cls[np.argpartition(ot_cost[6],10)[:10]] ot_cost_ = torch.from_numpy(ot_cost) #for combined, can add dist here ot_cost_ranks = torch.argsort(ot_cost_, -1)[:, :args.n_per_cls] ones = torch.ones(100) #args.n_per_cls*2 (n_cls*2) ot_cls = np.ones(n_queries) dataset_cls_t = torch.from_numpy(dataset_cls) for i in range(n_queries): #for each cls cur_ranks = dataset_cls_t[ot_cost_ranks[i]] ranked = torch.zeros(100) #n_cls*2 ranked.scatter_add_(src=ones, index=cur_ranks, dim=-1) ot_cls[i] = torch.argmax(ranked).item() ot_cost_means = np.mean(ot_cost.reshape(n_queries, n_data // args.n_per_cls, args.n_per_cls), axis=-1) ot_idx = np.argmin(ot_cost_means, axis=-1) * args.n_per_cls st_cost_means = np.mean(st_cost.reshape(n_queries, n_data // args.n_per_cls, args.n_per_cls), axis=-1) st_idx = np.argmin(st_cost_means, axis=-1) * args.n_per_cls ot_cls1 = dataset_cls[ot_idx] st_cls = dataset_cls[st_idx] ot_acc, ot_acc1 = np.equal(ot_cls, target).sum() / len(target), np.equal( ot_cls1, target).sum() / len(target) st_acc = np.equal(st_cls, target).sum() / len(target) print('ot acc1 {} ot acc {} st acc {}'.format(ot_acc1, ot_acc, st_acc)) return
def perm_mi(args): ''' Remove edges, permute, align, then measure MI. ''' args.n_epochs = 1000 params = {'n_blocks': 4} use_given_graph = False if use_given_graph: #True:#False: #True: g = torch.load('mi_g_.pt') else: seed = 0 if args.fix_seed else None g = utils.create_graph(40, gtype='block', params=params, seed=seed) #torch.save(g, 'mi_g.pt') orig_cls = [] for i in range(4): orig_cls.extend([i for _ in range(10)]) orig_cls = np.array(orig_cls) Lg = utils.graph_to_lap(g) args.Lx = Lg.clone() args.m = len(Lg) #remove edges and permute n_remove = args.n_remove #150 rand_seed = 0 if args.fix_seed else None Lg_removed = utils.remove_edges(Lg, n_remove=n_remove, seed=rand_seed) Lg_perm, perm = utils.permute_nodes(Lg_removed.numpy(), seed=rand_seed) inv_perm = np.empty(args.m, perm.dtype) inv_perm[perm] = np.arange(args.m) ##Ly = torch.from_numpy(Lg_perm) Ly = torch.from_numpy(Lg_perm) #Lg_removed.clone() #args.Lx.clone() args.n = len(Ly) #8 st_n_samples worked best, 5 sinkhorn iter, 1 as tau #align time0 = time.time() loss, P, Ly_ = graph.graph_dist(args, plot=False, Ly=Ly, take_ly_exp=False) dur_ot = time.time() - time0 orig_idx = P.argmax(-1).cpu().numpy() perm_mx = False if perm_mx: P_max = P.max(-1, keepdim=True)[0] P[P < P_max - .1] = 0 P[P > 0] = 1 new_cls = orig_cls[perm][orig_idx].reshape(-1) mi = utils.normalizedMI(orig_cls, new_cls) #return mi Lx = args.Lx time0 = time.time() x_reg, y_reg, (P_st, loss_st) = st.find_permutation(Ly.cpu().numpy(), Lx.cpu().numpy(), args.st_it, args.st_tau, args.st_n_samples, args.st_epochs, args.st_lr, loss_type='w', alpha=0, ones=True, graphs=True) dur_st = time.time() - time0 orig_idx = P_st.argmax(-1) new_cls_st = orig_cls[perm][orig_idx].reshape(-1) mi_st = utils.normalizedMI(orig_cls, new_cls_st) #print('{} COPT {} GOT {} dur ot {} dur st {}'.format(n_remove, mi, mi_st, dur_ot, dur_st)) print('{} {} {} {} {}'.format(n_remove, mi, mi_st, dur_ot, dur_st)) return mi
def classify(dataset, queries, dataset_cls, target, args, dataset0=None, queries0=None): """ classification tasks using various methods. dataset0, queries0 are original, non-sketched graphs. dataset, queries contain sketched graphs. """ if dataset0 is None: dataset0 = dataset queries0 = queries #with open(args.graph_fname, 'rb') as f: # graphs = pickle.read(f) n_data = len(dataset) n_queries = len(queries) ot_cost = np.zeros((len(queries), len(dataset))) netlsd_cost = np.zeros((len(queries), len(dataset))) Ly_mx = [] Lx_mx = [] data_graphs = [] heat_l = [] #avg_deg = 0 for i, data in enumerate(dataset): #pdb.set_trace() if isinstance(data, torch.Tensor): L = data else: n_nodes = len(data.nodes()) L = utils.graph_to_lap(data) avg_deg = (L.diag().mean()) L /= avg_deg #Ly_mx.append(L[torch.triu(torch.ones(n_nodes, n_nodes), diagonal=1) > 0]) Ly_mx.append(L) #pdb.set_trace() heat_l.append(netlsd.heat(L.numpy())) #avg_deg /= len(dataset) for i, q in enumerate(tqdm(queries, desc='queries')): '''### if isinstance(data, torch.Tensor): L = data else: n_nodes = len(data.nodes()) L = utils.graph_to_lap(data) ''' Lx = utils.graph_to_lap(q) avg_deg = (Lx.diag().mean()) Lx /= avg_deg args.Lx = Lx args.m = len(q.nodes()) q_heat = netlsd.heat(Lx.numpy()) Lx_mx.append(args.Lx) for j, data in enumerate(dataset): Ly = Ly_mx[j].clone() args.n = len(Ly) min_loss = 10000 for _ in range(1): loss, P, Ly_ = graph.graph_dist(args, plot=False, Ly=Ly, take_ly_exp=False) #pdb.set_trace() if loss < min_loss: min_loss = loss ot_cost[i][j] = min_loss netlsd_cost[i][j] = netlsd.compare(q_heat, heat_l[j]) if args.dataset_type == 'real': ot_cost1 = (ot_cost - ot_cost.mean()) / np.std(ot_cost) ot_pred = ot_cost.argmin(1) ot_acc00 = np.equal(dataset_cls[ot_pred], target).sum() / len(target) print('OT ACC |{} '.format(ot_acc00)) ot_sorted = np.argsort(ot_cost, axis=-1) #pdb.set_trace() ot_cls = dataset_cls[ot_sorted[:, :3]].tolist() combine_pred = np.zeros(len(target)) for i, ot_c in enumerate(ot_cls): counter = collections.Counter() counter.update(ot_c) #pdb.set_trace() common = counter.most_common(1)[0][0] combine_pred[i] = common combine_acc = np.equal(combine_pred, target).sum() / len(target) #pdb.set_trace() ### ot_pred = ot_cost.argmin(1) ot_acc = np.equal(dataset_cls[ot_pred], target).sum() / len(target) netlsd_pred = netlsd_cost.argmin(1) netlsd_acc = np.equal(dataset_cls[netlsd_pred], target).sum() / len(target) print('OT ACC |{} '.format(ot_acc)) return ot_acc00, netlsd_acc ot_cost_ = torch.from_numpy(ot_cost) #for combined, can add dist here ot_cost_ranks = torch.argsort(ot_cost_, -1)[:, :args.n_per_cls] ones = torch.ones(args.n_per_cls * 3) #args.n_per_cls*2 (n_cls*2) 100 ot_cls = np.ones(n_queries) combine_cls = np.ones(n_queries) dataset_cls_t = torch.from_numpy(dataset_cls) #pdb.set_trace() for i in range(n_queries): #for each cls cur_ranks_ot = dataset_cls_t[ot_cost_ranks[i]] ranked = torch.zeros(100) #n_cls*2 ranked.scatter_add_(src=ones, index=cur_ranks_ot, dim=-1) ot_cls[i] = torch.argmax(ranked).item() ot_cost_means = np.mean(ot_cost.reshape(n_queries, n_data // args.n_per_cls, args.n_per_cls), axis=-1) ot_idx = np.argmin(ot_cost_means, axis=-1) * args.n_per_cls print('ot_cost mx ', ot_cost) ot_cls1 = dataset_cls[ot_idx] ot_acc, ot_acc1 = np.equal(ot_cls, target).sum() / len(target), np.equal( ot_cls1, target).sum() / len(target) print('ot acc1 {} ot acc {} '.format(ot_acc1, ot_acc))