def view2d(dataset_lo, dataset_cls_lo, args):
    """Scatter-plot 2D COPT sketches, colored by graph class."""
    fig = plt.figure()
    ax = plt.gca()
    n_data = len(dataset_lo)
    Ly_mx = []
    # Upper-triangular mask for the 3x3 Laplacians of the 3-node sketches.
    ones = torch.ones(3, 3).triu(diagonal=1)
    for data in dataset_lo:
        L = utils.graph_to_lap(data)
        # Keep two of the three off-diagonal entries as the 2D coordinates.
        L = L[ones > 0][:2]
        L = torch.sort(L, -1)[0]
        Ly_mx.append(L)
    cls2c = {0: 'r', 1: 'b', 2: 'g', 3: 'c', 4: 'm'}
    cls2label = {0: 'block2', 1: 'rand_reg', 2: 'barabasi', 3: 'block3', 4: 'block4'}
    ar = torch.stack(Ly_mx).t().numpy()
    # Axis limits follow the plotted coordinates: ar[0] on x, ar[1] on y.
    ax.set_xlim(-7, ar[0].max())
    ax.set_ylim(-25, ar[1].max())
    range_ar = np.arange(n_data)
    for i in range(5):
        idx = range_ar[dataset_cls_lo == i]
        ax.scatter(ar[0][idx], ar[1][idx], c=cls2c[i], label=cls2label[i])
    ax.legend()
    path = 'data/projection_2d{}.jpg'.format(args.data_dim)
    fig.savefig(path)
    print('plot saved to {}'.format(path))
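# Usage sketch for view2d (a hypothetical driver, not part of the original
# pipeline): it assumes `graphs` and `dataset_cls` come from this module's
# dataset-loading code, and uses the class-aware sketch_graph defined below
# to reduce each graph to 3 nodes so that view2d sees 3x3 Laplacians.
def _demo_view2d(graphs, dataset_cls, args):
    lo_graphs, lo_cls = sketch_graph(graphs, dataset_cls, lo_dim=3, args=args)
    view2d(lo_graphs, np.array(lo_cls), args)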
def sketch_graph(args):
    """Sketch a single random regular graph down to lo_dim nodes.

    NOTE: shadowed by the sketch_graph definitions below; kept as a
    minimal single-graph example.
    """
    data_dim = 20
    lo_dim = 5
    g1 = utils.create_graph(data_dim, 'random_regular')
    # Parameters such as args.n_epochs can be set here or on the command line.
    args.Lx = utils.graph_to_lap(g1)
    args.m = len(args.Lx)
    args.n = lo_dim  # sketch down to lo_dim nodes
    # Returns the optimization loss, transport plan P, and the Laplacian of
    # the sketched graph.
    loss, P, Ly = graph.graph_dist(args, plot=False)
    print('sketched graph Laplacian {}'.format(Ly))
    # Ly can be converted to a networkx graph with utils.lap_to_graph(Ly).
    return loss, P, Ly
def sketch_graph(graphs, lo_dim, args):
    """Sketch each graph in `graphs` down to lo_dim nodes.

    NOTE: shadowed by the class-aware definition below; kept for reference.
    """
    args.n = lo_dim
    lo_graphs = []
    args.n_epochs = 230
    for g in tqdm(graphs, desc='sketching'):
        args.Lx = utils.graph_to_lap(g)
        args.m = len(args.Lx)
        loss, P, Ly = graph.graph_dist(args, plot=False)
        lo_graphs.append(utils.lap_to_graph(Ly))
    return lo_graphs
def sketch_graph(graphs, dataset_cls, lo_dim, args):
    """Sketch each graph down to lo_dim nodes, keeping class labels aligned.

    Graphs whose eigendecomposition fails to converge are skipped, so the
    returned lists can be shorter than the input.
    """
    args.n = lo_dim
    lo_graphs = []
    lo_cls = []
    args.n_epochs = 230
    for i, g in enumerate(tqdm(graphs, desc='sketching')):
        args.Lx = utils.graph_to_lap(g)
        args.m = len(args.Lx)
        try:
            # Rarely (~0.2% of the time) PyTorch's eigenvalue routine fails
            # to converge; skip those graphs.
            loss, P, Ly = graph.graph_dist(args, plot=False)
        except RuntimeError as e:
            print(e)
            continue
        lo_graphs.append(utils.lap_to_graph(Ly))
        lo_cls.append(dataset_cls[i])
    return lo_graphs, lo_cls
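# Usage sketch (hypothetical; assumes `graphs` is a list of networkx graphs,
# `dataset_cls` a numpy array of their class labels, and `args` an argparse
# namespace carrying the COPT optimization settings used in this module):
def _demo_sketch(graphs, dataset_cls, args):
    lo_graphs, lo_cls = sketch_graph(graphs, dataset_cls, lo_dim=5, args=args)
    # Each lo_graphs[k] is a 5-node networkx graph whose COPT distance to its
    # source graph was minimized; lo_cls keeps the labels of the graphs that
    # survived sketching.
    return lo_graphs, np.array(lo_cls)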
def classify_st(dataset, queries, dataset_cls, target, args, dataset0=None, queries0=None):
    """Classify query graphs by nearest-neighbor cost; compares COPT and GOT.

    Input:
        dataset, queries: graphs, possibly sketched.
        dataset0, queries0: the original, non-sketched graphs.
    """
    if dataset0 is None:
        dataset0 = dataset
        queries0 = queries
    n_data = len(dataset)
    n_queries = len(queries)
    ot_cost = np.zeros((n_queries, n_data))
    st_cost = np.zeros((n_queries, n_data))
    Ly_mx = [utils.graph_to_lap(data) for data in dataset]
    for i, q in enumerate(tqdm(queries, desc='queries')):
        Lx = utils.graph_to_lap(q)
        args.Lx = Lx
        args.m = len(q.nodes())
        n_repeat = 1  # one optimization run per pair works fine in practice
        for j, data in enumerate(dataset):
            Ly = Ly_mx[j].clone()
            args.n = len(Ly)
            min_loss = float('inf')
            for _ in range(n_repeat):
                loss, P, Ly_ = graph.graph_dist(args, plot=False, Ly=Ly, take_ly_exp=False)
                if loss < min_loss:
                    min_loss = loss
            ot_cost[i][j] = min_loss
            try:
                x_reg, y_reg, (P_st, loss_st) = st.find_permutation(
                    Lx.cpu().numpy(), Ly.cpu().numpy(), args.st_it, args.st_tau,
                    args.st_n_samples, args.st_epochs, args.st_lr,
                    loss_type='w', alpha=0, ones=True, graphs=True)
                st_cost[i][j] = loss_st
            except Exception:
                print('Exception encountered during GOT')
                st_cost[i][j] = np.inf
    # Majority vote over the args.n_per_cls nearest neighbors. One could also
    # use the median cost per class instead of the vote.
    ot_cost_ = torch.from_numpy(ot_cost)
    ot_cost_ranks = torch.argsort(ot_cost_, -1)[:, :args.n_per_cls]
    ones = torch.ones(args.n_per_cls)
    ot_cls = np.ones(n_queries)
    dataset_cls_t = torch.from_numpy(dataset_cls)
    for i in range(n_queries):
        cur_ranks = dataset_cls_t[ot_cost_ranks[i]]
        ranked = torch.zeros(100)  # upper bound on the number of classes
        ranked.scatter_add_(src=ones, index=cur_ranks, dim=-1)
        ot_cls[i] = torch.argmax(ranked).item()
    # Alternative: average the cost over each class block and take the argmin.
    ot_cost_means = np.mean(ot_cost.reshape(n_queries, n_data // args.n_per_cls, args.n_per_cls), axis=-1)
    ot_idx = np.argmin(ot_cost_means, axis=-1) * args.n_per_cls
    st_cost_means = np.mean(st_cost.reshape(n_queries, n_data // args.n_per_cls, args.n_per_cls), axis=-1)
    st_idx = np.argmin(st_cost_means, axis=-1) * args.n_per_cls
    ot_cls1 = dataset_cls[ot_idx]
    st_cls = dataset_cls[st_idx]
    ot_acc = np.equal(ot_cls, target).sum() / len(target)
    ot_acc1 = np.equal(ot_cls1, target).sum() / len(target)
    st_acc = np.equal(st_cls, target).sum() / len(target)
    print('ot acc1 {} ot acc {} st acc {}'.format(ot_acc1, ot_acc, st_acc))
    return
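# Usage sketch for classify_st (hypothetical driver; assumes the dataset holds
# args.n_per_cls graphs per class laid out class-by-class, as the block
# averaging inside classify_st requires, and that the GOT hyperparameters
# args.st_it, args.st_tau, args.st_n_samples, args.st_epochs, args.st_lr are
# set). Note sketch_graph may drop graphs that fail to converge, which would
# misalign the per-class blocks; this sketch assumes none are dropped.
def _demo_classify_st(dataset, queries, dataset_cls, target, args):
    lo_data, lo_data_cls = sketch_graph(dataset, dataset_cls, lo_dim=5, args=args)
    lo_queries, lo_target = sketch_graph(queries, target, lo_dim=5, args=args)
    classify_st(lo_data, lo_queries, np.array(lo_data_cls), np.array(lo_target),
                args, dataset0=dataset, queries0=queries)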
def perm_mi(args):
    """Remove edges from a block graph, permute its nodes, align with COPT and
    with GOT, then measure the normalized mutual information (NMI) between the
    original and recovered block labels."""
    args.n_epochs = 1000
    params = {'n_blocks': 4}
    use_given_graph = False
    if use_given_graph:
        g = torch.load('mi_g_.pt')
    else:
        seed = 0 if args.fix_seed else None
        g = utils.create_graph(40, gtype='block', params=params, seed=seed)
    # Ground-truth block labels: four blocks of ten nodes each.
    orig_cls = np.array([i for i in range(4) for _ in range(10)])
    Lg = utils.graph_to_lap(g)
    args.Lx = Lg.clone()
    args.m = len(Lg)
    # Remove edges and permute the nodes.
    n_remove = args.n_remove
    rand_seed = 0 if args.fix_seed else None
    Lg_removed = utils.remove_edges(Lg, n_remove=n_remove, seed=rand_seed)
    Lg_perm, perm = utils.permute_nodes(Lg_removed.numpy(), seed=rand_seed)
    # Inverse permutation (not used below; kept for reference).
    inv_perm = np.empty(args.m, perm.dtype)
    inv_perm[perm] = np.arange(args.m)
    Ly = torch.from_numpy(Lg_perm)
    args.n = len(Ly)
    # Align with COPT.
    time0 = time.time()
    loss, P, Ly_ = graph.graph_dist(args, plot=False, Ly=Ly, take_ly_exp=False)
    dur_ot = time.time() - time0
    orig_idx = P.argmax(-1).cpu().numpy()
    perm_mx = False
    if perm_mx:
        # Optionally round P to a hard permutation matrix.
        P_max = P.max(-1, keepdim=True)[0]
        P[P < P_max - .1] = 0
        P[P > 0] = 1
    new_cls = orig_cls[perm][orig_idx].reshape(-1)
    mi = utils.normalizedMI(orig_cls, new_cls)
    # Align with GOT. (Empirically, 8 st_n_samples, 5 Sinkhorn iterations,
    # and tau=1 worked best.)
    Lx = args.Lx
    time0 = time.time()
    x_reg, y_reg, (P_st, loss_st) = st.find_permutation(
        Ly.cpu().numpy(), Lx.cpu().numpy(), args.st_it, args.st_tau,
        args.st_n_samples, args.st_epochs, args.st_lr,
        loss_type='w', alpha=0, ones=True, graphs=True)
    dur_st = time.time() - time0
    orig_idx = P_st.argmax(-1)
    new_cls_st = orig_cls[perm][orig_idx].reshape(-1)
    mi_st = utils.normalizedMI(orig_cls, new_cls_st)
    # Columns: n_remove, COPT NMI, GOT NMI, COPT time, GOT time.
    print('{} {} {} {} {}'.format(n_remove, mi, mi_st, dur_ot, dur_st))
    return mi
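# Usage sketch (hypothetical sweep; assumes `args` carries the COPT and GOT
# settings read by perm_mi, and that args.n_remove / args.fix_seed exist as in
# the rest of this module). Each call prints one line per edge-removal level:
# n_remove, COPT NMI, GOT NMI, and the two wall-clock times.
def _demo_perm_mi_sweep(args):
    for n_remove in [50, 100, 150]:
        args.n_remove = n_remove
        perm_mi(args)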
def classify(dataset, queries, dataset_cls, target, args, dataset0=None, queries0=None):
    """Classification using the COPT cost and the NetLSD heat-trace baseline.

    dataset, queries may contain sketched graphs (or Laplacian tensors);
    dataset0, queries0 are the original, non-sketched graphs.
    """
    if dataset0 is None:
        dataset0 = dataset
        queries0 = queries
    n_data = len(dataset)
    n_queries = len(queries)
    ot_cost = np.zeros((n_queries, n_data))
    netlsd_cost = np.zeros((n_queries, n_data))
    Ly_mx = []
    heat_l = []
    for data in dataset:
        if isinstance(data, torch.Tensor):
            L = data
        else:
            L = utils.graph_to_lap(data)
        # Normalize by average degree so graphs of different density compare.
        L = L / L.diag().mean()
        Ly_mx.append(L)
        heat_l.append(netlsd.heat(L.numpy()))
    for i, q in enumerate(tqdm(queries, desc='queries')):
        Lx = utils.graph_to_lap(q)
        Lx = Lx / Lx.diag().mean()
        args.Lx = Lx
        args.m = len(Lx)
        q_heat = netlsd.heat(Lx.numpy())
        for j, data in enumerate(dataset):
            Ly = Ly_mx[j].clone()
            args.n = len(Ly)
            # A single optimization run per pair suffices in practice
            # (cf. n_repeat in classify_st).
            loss, P, Ly_ = graph.graph_dist(args, plot=False, Ly=Ly, take_ly_exp=False)
            ot_cost[i][j] = loss
            netlsd_cost[i][j] = netlsd.compare(q_heat, heat_l[j])
    if args.dataset_type == 'real':
        # 1-NN prediction from the COPT cost matrix.
        ot_pred = ot_cost.argmin(1)
        ot_acc = np.equal(dataset_cls[ot_pred], target).sum() / len(target)
        # Majority vote among the 3 nearest neighbors.
        ot_sorted = np.argsort(ot_cost, axis=-1)
        ot_cls = dataset_cls[ot_sorted[:, :3]].tolist()
        combine_pred = np.zeros(len(target))
        for i, ot_c in enumerate(ot_cls):
            combine_pred[i] = collections.Counter(ot_c).most_common(1)[0][0]
        combine_acc = np.equal(combine_pred, target).sum() / len(target)
        netlsd_pred = netlsd_cost.argmin(1)
        netlsd_acc = np.equal(dataset_cls[netlsd_pred], target).sum() / len(target)
        print('OT acc {} | top-3 vote acc {} | NetLSD acc {}'.format(ot_acc, combine_acc, netlsd_acc))
        return ot_acc, netlsd_acc
    # Synthetic data: majority vote over the args.n_per_cls nearest neighbors.
    ot_cost_ = torch.from_numpy(ot_cost)
    ot_cost_ranks = torch.argsort(ot_cost_, -1)[:, :args.n_per_cls]
    ones = torch.ones(args.n_per_cls)
    ot_cls = np.ones(n_queries)
    dataset_cls_t = torch.from_numpy(dataset_cls)
    for i in range(n_queries):
        cur_ranks_ot = dataset_cls_t[ot_cost_ranks[i]]
        ranked = torch.zeros(100)  # upper bound on the number of classes
        ranked.scatter_add_(src=ones, index=cur_ranks_ot, dim=-1)
        ot_cls[i] = torch.argmax(ranked).item()
    # Alternative: average the cost over each class block and take the argmin.
    ot_cost_means = np.mean(ot_cost.reshape(n_queries, n_data // args.n_per_cls, args.n_per_cls), axis=-1)
    ot_idx = np.argmin(ot_cost_means, axis=-1) * args.n_per_cls
    ot_cls1 = dataset_cls[ot_idx]
    ot_acc = np.equal(ot_cls, target).sum() / len(target)
    ot_acc1 = np.equal(ot_cls1, target).sum() / len(target)
    print('ot acc1 {} ot acc {}'.format(ot_acc1, ot_acc))
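# Usage sketch for classify (hypothetical; assumes a real dataset split where
# dataset_cls/target are numpy label arrays and args.dataset_type == 'real',
# so classify returns a (COPT accuracy, NetLSD accuracy) pair):
def _demo_classify(dataset, queries, dataset_cls, target, args):
    ot_acc, netlsd_acc = classify(dataset, queries, dataset_cls, target, args)
    print('COPT {} vs NetLSD {}'.format(ot_acc, netlsd_acc))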
def view(dataset_lo, dataset_cls_lo, args):
    """Scatter-plot 3D COPT sketches, colored and marked by graph class."""
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    n_data = len(dataset_lo)
    Ly_mx = []
    # Upper-triangular mask for the 3x3 Laplacians of the 3-node sketches.
    ones = torch.ones(3, 3).triu(diagonal=1)
    for data in dataset_lo:
        L = utils.graph_to_lap(data)
        # The three off-diagonal entries, sorted, are the 3D coordinates.
        L = torch.sort(L[ones > 0].view(-1))[0]
        Ly_mx.append(L)
    cls2c = {0: 'r', 1: 'b', 2: 'g', 3: 'c', 4: 'm', 5: 'k', 6: 'y', 7: '.77'}
    cls2label = {0: 'block-2', 1: 'random regular', 2: 'barabasi', 3: 'block-3',
                 4: 'powerlaw tree', 5: 'caveman', 6: 'watts-strogatz', 7: 'binomial'}
    ar = torch.stack(Ly_mx).t().numpy()
    # Zoomed-out limits:
    # ax.set_xlim3d(-20, ar[0].max()); ax.set_ylim3d(-17, ar[1].max()); ax.set_zlim3d(-10, ar[2].max())
    # Zoomed-in limits:
    ax.set_xlim3d(-5, ar[0].max())
    ax.set_ylim3d(-2, ar[1].max())
    ax.set_zlim3d(-1, ar[2].max())
    range_ar = np.arange(n_data)
    markers = ['^', 'o', 'x', '.', '1', '3', '+', '4']
    marker_cnt = 0
    for i in range(8):
        # Skip classes not included in the figure.
        if i in (2, 3, 6, 7):
            continue
        idx = range_ar[dataset_cls_lo == i]
        ax.scatter(ar[0][idx], ar[1][idx], ar[2][idx], c=cls2c[i],
                   marker=markers[marker_cnt], label=cls2label[i])
        marker_cnt += 1
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.legend()
    plt.title('3D sketches (zoomed)', fontsize=18)
    path = 'data/projection_{}.jpg'.format(args.data_dim)
    fig.savefig(path)
    print('plot saved to {}'.format(path))
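# Usage sketch for view (hypothetical driver; like _demo_view2d above, it
# assumes 3-node sketches, since view reads the three off-diagonal entries of
# each 3x3 Laplacian as 3D coordinates):
def _demo_view(graphs, dataset_cls, args):
    lo_graphs, lo_cls = sketch_graph(graphs, dataset_cls, lo_dim=3, args=args)
    view(lo_graphs, np.array(lo_cls), args)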