def create_knn_graph(data, k, opt=None):
    if opt is not None and hasattr(opt, 'ranks_path'):
        ranks = np.load(opt.ranks_path)
        ranks = torch.from_numpy(ranks)
        pdb.set_trace()
    elif opt is not None and opt.normalize_data:
        '''
        data /= data.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-3)
        dist = torch.matmul(data, data.t())# - torch.eye(len(data))
        val, ranks = torch.topk(dist, k=k+1, dim=1)
        ranks = ranks[:, 1:]
        '''
        ranks = utils.dist_rank(data, k=k, opt=opt)
    else:
        #compute l2 dist <-- be memory-efficient by blocking
        '''
        dist = utils.l2_dist(data)
        dist += 2*torch.max(dist).item()*torch.eye(len(data))
        val, ranks = torch.topk(dist, k=k, dim=1, largest=False)
        '''
        ranks = utils.dist_rank(data, k=k, opt=opt)
    if DEBUG:
        print(ranks)
    #add 1 since the indices for kahip must be 1-based.
    ranks += 1
    return ranks
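# A minimal usage sketch (not part of the original repo) of how the 1-based ranks
# returned by create_knn_graph can be written out as a KaHIP input graph. It assumes
# this module's write_knn_graph (called elsewhere as create_graph.write_knn_graph)
# and utils.dist_rank accept the arguments used above; the path is a placeholder.
def _example_build_knn_graph_file(graph_path='/tmp/example_knn.graph', k=10):
    import torch
    data = torch.randn(1000, 32)          # toy dataset: 1000 points, 32 dims
    ranks = create_knn_graph(data, k=k)   # (1000, k) neighbor indices, 1-based for kahip
    n_edges = write_knn_graph(ranks, graph_path)  # serialize in KaHIP's graph format
    return ranks, n_edges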
def create_knn_sub_graph(all_ranks, idx2weights, ds_idx, data, opt):
    if False:
        if opt is not None and opt.normalize_data and not opt.glove:
            data /= data.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-3)
            dist = torch.matmul(data, data.t()) - torch.eye(len(data))
            val, local_ranks = torch.topk(dist, k=1, dim=1)
        else:
            '''
            dist = utils.l2_dist(data)
            dist = dist + 2*torch.max(dist).item()*torch.eye(len(data))
            val, local_ranks = torch.topk(dist, k=1, dim=1, largest=False)
            '''
            #just compute this in the loop since not too frequent
            local_ranks = utils.dist_rank(data, k=1, opt=opt)
            local_ranks += 1

    ranks = []
    #is_tensor = isinstance(all_ranks, torch.Tensor)
    ds_idx2idx = {t.item() + 1: idx for idx, t in enumerate(ds_idx, 1)}
    #dict of idx of a point to its nearest neighbor
    idx2nn = {}
    #set of added tuples
    added_tups = set()
    if isinstance(all_ranks, torch.Tensor):
        all_ranks = all_ranks.cpu().numpy()
    if isinstance(ds_idx, torch.Tensor):
        ds_idx = ds_idx.cpu().numpy()

    for idx, i in enumerate(ds_idx):
        cur_ranks = []
        for j in all_ranks[i]:
            tup = (i + 1, j) if i + 1 < j else (j, i + 1)
            if idx2weights[tup] == 1 and tup in added_tups:
                continue
            added_tups.add(tup)
            if j in ds_idx2idx:
                cur_ranks.append(ds_idx2idx[j])
        if cur_ranks == []:
            #nearest_idx = local_ranks[idx][0].item()
            local_ranks = utils.dist_rank(data[idx].unsqueeze(0), k=1, data_y=data, opt=opt)
            nearest_idx = local_ranks[0][0].item() + 1
            cur_ranks.append(nearest_idx)
            idx2nn[idx] = nearest_idx
        ranks.append(cur_ranks)

    for idx, nn in idx2nn.items():
        ranks[nn - 1].append(idx + 1)
    return ranks
def knn_dist_lof(X, k=10):
    X_len = len(X)
    #dist_ = dist(X, X)
    #min_dist, min_idx = torch.topk(dist_, dim=-1, k=k, largest=False)
    min_dist, min_idx = utils.dist_rank(X, k=k, largest=False)
    kth_dist = min_dist[:, -1]
    # sum max(kth dist, dist(o, p)) over neighbors o of p
    kth_dist_exp = kth_dist.expand(X.size(0), -1) #n x n
    kth_dist = torch.gather(input=kth_dist_exp, dim=1, index=min_idx)
    min_dist[kth_dist > min_dist] = kth_dist[kth_dist > min_dist]
    #inverse of lrd scores
    dist_avg = min_dist.mean(-1).clamp(min=0.0001)
    compare_density = False
    if compare_density:
        #compare with density. Get kth neighbor index.
        dist_avg_exp = dist_avg.unsqueeze(-1) / dist_avg.unsqueeze(0).expand(X_len, -1)
        #lof = torch.zeros(X_len, 1).to(utils.device)
        lof = torch.gather(input=dist_avg_exp, dim=-1, index=min_idx).sum(-1)
        torch.scatter_add_(lof, dim=-1, index=min_idx, src=dist_avg_exp)
        return -lof.squeeze(0)
    return dist_avg
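# A hedged usage sketch for the LOF-style score above: larger values indicate more
# isolated, outlier-like points. It assumes this module's utils.dist_rank is
# importable as used in knn_dist_lof; names and sizes here are illustrative only.
def _example_outlier_scores():
    import torch
    X = torch.randn(500, 16)
    X[0] += 50.0                                  # plant one obvious outlier
    scores = knn_dist_lof(X, k=10)                # mean reachability distance per point
    top_vals, top_idx = torch.topk(scores, k=5)   # indices of the most outlying points
    return top_idx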
def create_dataset():
    data_dir = 'data/prefix'
    data_dir = '/large/prefix'
    data_dir = '../deep1b.dat'
    data_dir = '/large/deep1b.dat'
    #with open(osp.join(data_dir, 'base_00'), 'rb') as file:
    with open(osp.join(data_dir), 'rb') as file:
        data = file.read()
    #(8*96+8)
    i = 0
    skip = 4 * 96 #4*97
    data_len = len(data)
    #last vector can be cut off in the middle of vec
    ##assert data_len % 4 == 0 and (data_len-4) % 96 == 0
    n_queries = 100000
    n_data = 10000000
    #n_queries = 1
    #n_data = 9
    n_total = n_data + 2 * n_queries #2x to account for dups
    data_l = []
    stop_len = data_len - skip
    counter = 0
    byte_set = set()
    dup_count = 0
    while i < stop_len:
        #cur_bytes = data[i+4:i+skip] <--if downloaded directly
        cur_bytes = data[i:i + skip]
        if cur_bytes in byte_set:
            dup_count += 1
            i += skip
            continue
        byte_set.add(cur_bytes)
        data_l.append(list(struct.unpack('96f', cur_bytes)))
        #except struct.error:
        #    print('struct error')
        #    pdb.set_trace()
        counter += 1
        if counter == n_total:
            #pdb.set_trace()
            break
        i += skip
    print('dup count {}'.format(dup_count))
    print('number of vectors {}'.format(len(data_l)))
    queries = torch.FloatTensor(data_l[:n_queries])
    dataset = torch.FloatTensor(data_l[n_queries:n_queries + n_data])
    torch.save(queries, '/large/prefix10m_queries.pt')
    torch.save(dataset, '/large/prefix10m_dataset.pt')
    #pdb.set_trace()
    #need batch size 200 for 10mil to not run out of memory
    answers = utils.dist_rank(queries, k=10, data_y=dataset, include_self=True)
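# A hedged alternative sketch for the byte-parsing loop above: if the file really is a
# flat stream of 96-dim float32 vectors (the '96f' / skip = 4*96 layout create_dataset
# assumes), numpy can read and exactly de-duplicate it in a few lines. The path and
# vector count are placeholders; if the downloaded format carries a 4-byte header per
# vector (as the commented data[i+4:i+skip] suggests), this flat read does not apply.
def _example_read_deep1b_flat(path='/large/deep1b.dat', dim=96, max_vecs=1000000):
    import numpy as np
    raw = np.fromfile(path, dtype=np.float32, count=dim * max_vecs)
    raw = raw[: (len(raw) // dim) * dim].reshape(-1, dim)  # drop a truncated tail vector
    vecs = np.unique(raw, axis=0)                          # remove exact duplicates (note: reorders rows)
    return vecs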
def knn_dist(X, k=10, sum_dist=False):
    min_dist, idx = utils.dist_rank(X, k=k, largest=False)
    if sum_dist:
        dist_score = min_dist.sum(-1)
    else:
        dist_score = min_dist.mean(-1)
    return dist_score
def predict(self, query, k):
    #query = query.to(utils.device)
    if isinstance(query, np.ndarray):
        query = torch.from_numpy(query).to(utils.device)
    #self.centers have dimension 1, torch.Size([100, 1024])
    if hasattr(self, 'opt') and (self.opt.glove or self.opt.sift) and self.centers.size(1) > 512:
        centers = self.centers.t()
        idx = utils.dist_rank(query, k, data_y=centers, largest=False)
    else:
        q_norm = torch.sum(query ** 2, dim=1).view(-1, 1)
        dist = q_norm + self.centers_norm - 2 * torch.mm(query, self.centers)
        if k > dist.size(1):
            k = dist.size(1)
        _, idx = torch.topk(dist, k=k, dim=1, largest=False)
    #move predict to numpy
    idx = idx.cpu().numpy()
    return idx
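# A standalone sketch (assumed names, not part of the solver class) checking that the
# ||q||^2 + ||c||^2 - 2 q.c expansion used in the fallback branch of predict matches
# squared Euclidean distances, and that topk with largest=False picks the nearest centers.
def _example_center_assignment():
    import torch
    query = torch.randn(8, 64)
    centers = torch.randn(5, 64)                        # 5 centers, feature dim 64
    centers_t = centers.t()                             # (64, 5): shape torch.mm(query, centers) expects
    centers_norm = (centers ** 2).sum(-1).view(1, -1)   # (1, 5)
    q_norm = (query ** 2).sum(dim=1).view(-1, 1)        # (8, 1)
    dist = q_norm + centers_norm - 2 * torch.mm(query, centers_t)
    assert torch.allclose(dist, torch.cdist(query, centers) ** 2, atol=1e-3)
    _, idx = torch.topk(dist, k=2, dim=1, largest=False)   # two nearest centers per query
    return idx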
        prev = d.eq(data).sum(-1)
        #if i == 10:
        #    pdb.set_trace()
        if (prev == 96).sum().item() > 1: #1 for itself
            selected[i] = 0
            print('{} dup '.format(i))
    #report the duplicate count before filtering, then keep only the selected rows
    print('{} duplicates'.format(len(data) - selected.sum()))
    data = data[selected]
    np.save('data/prefix1m_dataset2.npy', data.cpu().numpy())
    return data


if __name__ == '__main__':
    remove_dup_bool = False
    if remove_dup_bool:
        dataset = np.load('data/prefix1m_dataset.npy')
        dataset = torch.from_numpy(dataset).cuda()
        dataset = remove_dup(dataset)
        queries = torch.from_numpy(np.load('data/prefix1m_queries.npy')).cuda()
        answers = utils.dist_rank(queries, k=10, data_y=dataset, include_self=True)
        np.save('data/prefix1m_answers2.npy', answers.cpu().numpy())
    else:
        create_dataset()
def k_means(dataset, dataset_idx, ht2cutsz, height, n_clusters, opt):
    #ranks
    num_points = dataset.shape[0]
    dimension = dataset.shape[1]

    use_kahip_solver = False
    if opt.kmeans_use_kahip_height == height:
        use_kahip_solver = True

    if use_kahip_solver:
        solver = kahip_solver.KahipSolver()
    elif opt.fast_kmeans:
        solver = kmeans.FastKMeans(dataset, n_clusters, opt)
    elif opt.itq:
        solver = itq.ITQSolver(dataset, n_clusters)
    elif opt.cplsh:
        solver = cplsh.CPLshSolver(dataset, n_clusters, opt)
    elif opt.pca:
        assert n_clusters == 2
        solver = pca.PCASolver(dataset, opt)
    elif opt.st:
        assert n_clusters == 2
        solver = pca.STSolver(dataset, opt.glob_st_ranks, dataset_idx, opt)
    elif opt.rp:
        if n_clusters != 2:
            raise Exception('n_cluster {} must be 2!'.format(n_clusters))
        solver = pca.RPSolver(dataset, opt)
    elif km_method == 'km':
        solver = KMeans(n_clusters=n_clusters, max_iter=max_loyd)
        solver.fit(dataset)
    elif km_method == 'mbkm':
        solver = MiniBatchKMeans(n_clusters=n_clusters, max_iter=max_loyd)
        solver.fit(dataset)
    else:
        raise Exception('method {} not supported'.format(km_method))

    #print("Ranking clusters for data and query points...")
    #dataset_dist = solver.transform(dataset) #could be useful, commented out for speed
    #queries_dist = solver.transform(queries)
    #the distances to cluster centers, ranked smallest first.
    #dataset_dist_idx = np.argsort(dataset_dist, axis=1)
    #queries_dist_idx = np.argsort(queries_dist, axis=1)

    if use_kahip_solver:
        #output is numpy array
        d_cls_idx = solver.predict(dataset_idx)
    elif isinstance(solver, kmeans.FastKMeans):
        d_cls_idx = solver.predict(dataset, k=1)
        d_cls_idx = d_cls_idx.reshape(-1)
    elif isinstance(solver, cplsh.CPLshSolver):
        d_cls_idx = solver.predict(dataset, k=1)
    elif isinstance(solver, itq.ITQSolver):
        d_cls_idx = solver.predict(dataset, k=1)
        d_cls_idx = d_cls_idx.reshape(-1)
    else:
        d_cls_idx = solver.predict(dataset)

    #lists of indices (not dataset points) for each class. Note each list element is a tuple.
    #list of np arrays
    d_idx_l = [np.where(d_cls_idx == i)[0] for i in range(n_clusters)]
    #q_idx_l = [np.where(q_cls_idx==i) for i in range(n_clusters)] #could be useful, commented out for speed

    compute_cut_sz_b = False
    if compute_cut_sz_b:
        ranks = utils.dist_rank(dataset, k=opt.k, opt=opt)
        #ranks are assumed to be 1-based
        ranks += 1
        cut_sz = compute_cut_size(d_cls_idx.tolist(), ranks)
        ht2cutsz[height].append(cut_sz)

    return d_idx_l, solver
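# A hedged sketch of the plain sklearn path above (km_method == 'km'), showing how the
# per-cluster index lists d_idx_l are built from predicted labels. KMeans comes from
# sklearn.cluster as in this module's imports; max_iter stands in for the module-level
# max_loyd, and the data here is random toy data.
def _example_sklearn_partition(n_clusters=4, max_iter=50):
    import numpy as np
    from sklearn.cluster import KMeans
    dataset = np.random.randn(2000, 32).astype(np.float32)
    solver = KMeans(n_clusters=n_clusters, max_iter=max_iter)
    solver.fit(dataset)
    d_cls_idx = solver.predict(dataset)                                 # cluster label per point
    d_idx_l = [np.where(d_cls_idx == i)[0] for i in range(n_clusters)]  # point indices per cluster
    assert sum(len(l) for l in d_idx_l) == len(dataset)                 # every point lands in exactly one bin
    return d_idx_l, solver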
res_l = ['Catalyzed data ']
ds, qu, neigh = load_data(utils.data_dir, opt)
if opt.cplsh and opt.sift:
    ds = ds / np.sqrt((ds**2).sum(-1, keepdims=True))
    qu = qu / np.sqrt((qu**2).sum(-1, keepdims=True))
    qu = qu[:500]
    neigh = neigh[:500]
n_repeat = 1
#search tree
#global glob_st_ranks
#if glob_st_ranks is None:
if opt.st:
    opt.glob_st_ranks = utils.dist_rank(ds, opt.k, include_self=True, opt=opt)
    torch.save(opt.glob_st_ranks, 'st_ranks_glove')

for i in range(n_repeat):
    for height in height_l:
        acc, probe, probe95 = run_main(height, ds, qu, neigh, opt)
        res_l.append(str(height) + ' ' + ' '.join([str(acc[0, 0]), str(probe[0, 0]), str(probe95[0, 0])]))

res_str = '\n'.join(res_l)
if opt.rp:
    with open(osp.join(utils.data_dir, 'rp_data_mnist.md'), 'a') as file:
        file.write(res_str + '\n')
def run_kmkahip(height_preset, opt, dataset, queryset, neighbors):
    k = opt.k
    print('Configs: {} \n Starting data processing and training ...'.format(opt))
    #this root node is a dummy node, since it doesn't have a trained model or idx2bin
    train_node = train.TrainNode(-1, opt, -1)

    swap_query_to_data = False
    if swap_query_to_data:
        print('** NOTE: Modifying queryset as part of dataset **')
        queryset = dataset[:11000]
        #queryset = dataset
        neighbors = utils.dist_rank(queryset, k=opt.k, data_y=dataset, largest=False)
        #dist += 2*torch.max(dist).item()*torch.eye(len(dist)) #torch.diag(torch.max(dist))
        #val, neighbors = torch.topk(dist, k=opt.k, dim=1, largest=False)

    #dsnode_path = opt.dsnode_path + str(opt.n_clusters)
    #dsnode = utils.pickle_load(dsnode_path)

    #check if need to normalize data. Remove second condition eventually.
    if opt.normalize_data and dataset[0].norm(p=2).item() != 1 and not opt.glove:
        print('Normalizing data ...')
        dataset = utils.normalize(dataset)
        queryset = utils.normalize(queryset)

    #create data tree used for training
    n_clusters = opt.n_clusters
    height = height_preset
    n_bins = 1

    ds_idx = torch.LongTensor(list(range(len(dataset))))
    print('{} height: {} level2action {}'.format(ds_idx.size(), height, opt.level2action))
    idx2bin = {}
    ht2cutsz = defaultdict(list)

    #used for memoizing partition results
    branching_l = ['0']
    all_ranks = None
    root_dsnode = create_data_tree_root(dataset, all_ranks, ds_idx, train_node, idx2bin, height,
                                        branching_l, ht2cutsz, opt)
    print('Done creating training tree. Starting evaluation ...')

    #top node only; first child node is train node.
    eval_root = train.EvalNode(train_node.children[0])

    '''
    Evaluate
    '''
    with torch.no_grad():
        print('About to evaluate model! {} height: {} level2action {}'.format(ds_idx.size(), height, opt.level2action))
        acc, probe_count, probe_count95 = train.eval_model(eval_root, queryset, neighbors, n_bins, opt)
        print('cut_sizes {}'.format(ht2cutsz))
        print('Configs: {}'.format(opt))
        print('acc {} probe count {} 95th {}'.format(acc, probe_count, probe_count95))

    '''
    Serialize
    '''
    serialize_bool = False if 'kahip' in set(opt.level2action.values()) else True
    serialize_bool = True
    if serialize_bool:
        print('Serializing eval root...')
        if opt.sift:
            data_name = 'sift'
        elif opt.glove:
            data_name = 'glove'
        elif opt.prefix10m:
            data_name = 'prefix10m'
        else:
            data_name = 'mnist'
        idx2bin = eval_root.idx2bin
        if 'logreg' in opt.level2action.values():
            serial_path = 'evalroot_{}_ht{}_{}_{}{}nn{}logreg'
        else:
            serial_path = 'evalroot_{}_ht{}_{}_{}{}nn{}'
        eval_root_path = osp.join(opt.data_dir,
                                  serial_path.format(data_name, height, n_clusters, opt.k_graph, opt.k, opt.nn_mult))
        eval_root_dict = {'eval_root': eval_root, 'opt': opt}
        utils.pickle_dump(eval_root_dict, eval_root_path)
        print('Done serializing {}'.format(eval_root_path))

    #dsnode_path = opt.dsnode_path + str(opt.n_clusters)
    #utils.pickle_dump(root_dsnode, dsnode_path)

    with open(osp.join(opt.data_dir, 'cutsz_k{}_ht{}_{}'.format(k, height, n_clusters)), 'w') as file:
        file.write(str(ht2cutsz))
        file.write('\n\n')
        file.write(str(opt))
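# A hedged sketch (file-name pattern copied from the serialization branch above) of how a
# serialized eval root might be loaded back for later evaluation. It assumes utils.pickle_load
# is the counterpart of utils.pickle_dump used above; the argument values are placeholders,
# not a shipped configuration.
def _example_load_eval_root(data_dir, data_name='glove', height=2, n_clusters=16, k_graph=10, k=10, nn_mult=5):
    serial_path = 'evalroot_{}_ht{}_{}_{}{}nn{}'.format(data_name, height, n_clusters, k_graph, k, nn_mult)
    eval_root_dict = utils.pickle_load(osp.join(data_dir, serial_path))
    return eval_root_dict['eval_root'], eval_root_dict['opt']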
def create_data_tree_root(dataset, all_ranks, ds_idx, train_node, idx2bin, height, branching_l, ht2cutsz, opt):
    datalen = len(ds_idx)
    if datalen <= opt.k:
        return None

    graph_path = os.path.join(opt.data_dir, opt.graph_file) #'../data/knn.graph'
    #ranks are 1-based
    if opt.glove or opt.sift or opt.prefix10m: #and len(branching_l) == 1:
        if opt.glove:
            #custom paths
            #if opt.glove and opt.k_graph == 50: #april, 50NN graph file
            #    graph_path = os.path.join(opt.data_dir, 'glove50_'+opt.graph_file) #'../data/knn.graph'
            graph_path = os.path.join(opt.data_dir, opt.graph_file) #'../data/knn.graph'
            #graph_path = os.path.join(opt.data_dir, 'glove10_sub10knn.graph')
        print('graph file {}'.format(graph_path))
        parts_path = run_kahip(graph_path, datalen, branching_l, height, opt)
        print('Done partitioning top level!')
        lines = utils.load_lines(parts_path)
        classes = [int(line) for line in lines]
        #read in all_ranks, for partitioning on further levels.
        all_ranks, idx2weights = read_all_ranks(opt)
        if opt.dataset_name != 'prefix10m':
            k1 = max(1, int(opt.nn_mult * opt.k))
            ranks = utils.dist_rank(dataset, k=k1)
        else:
            #subtract 1 as graph was created with 1-indexing for kahip.
            ranks = torch.load('/large/prefix10m10knn.graph.pt') - 1
        #create root DataNode: dataset, ds_idx, parent_train_node, idx2bin, height, opt
        dsnode = add_datanode_children(dataset, (all_ranks, idx2weights), ds_idx, train_node, idx2bin, height-1,
                                       branching_l, classes, ht2cutsz, 0, opt, ranks, toplevel=True, root=True)
        return dsnode

    #create graph from data.
    data = dataset[ds_idx]
    if len(branching_l) == 1:
        #this is always the case now
        #use tree created at top level throughout the hierarchy
        ranks = create_graph.create_knn_graph(data, k=opt.k, opt=opt) #should supply opt
        all_ranks = ranks
    else:
        assert all_ranks is not None
        #else compute part of previous graph
        ranks = create_graph.create_knn_sub_graph(all_ranks, ds_idx, data, opt)

    n_edges = create_graph.write_knn_graph(ranks, graph_path)
    _, idx2weights = read_all_ranks(opt, path=graph_path)

    #create partition from graph. This overrides the file each iteration.
    parts_path = run_kahip(graph_path, datalen, branching_l, height, opt)
    lines = utils.load_lines(parts_path)
    classes = [int(line) for line in lines]

    compute_cut_size_b = False and not opt.glove
    if compute_cut_size_b:
        cut_sz = compute_cut_size(classes, ranks)
        ht2cutsz[height].append((cut_sz, n_edges))

    #create root DataNode: dataset, ds_idx, parent_train_node, idx2bin, height, opt
    dsnode = add_datanode_children(dataset, (all_ranks, idx2weights), ds_idx, train_node, idx2bin, height-1,
                                   branching_l, classes, ht2cutsz, 0, opt, all_ranks-1, toplevel=True, root=True)
    #Note the above all_ranks is not 5*opt.k number of nearest neighbors.
    return dsnode
def add_datanode_children(dataset, all_ranks_data, ds_idx, parent_train_node, idx2bin, height, branching_l,
                          classes, ht2cutsz, cur_tn_idx, opt, ds_idx_ranks, toplevel=None, root=False):

    all_ranks, idx2weights = all_ranks_data
    n_class = opt.n_class if opt.n_class <= len(ds_idx) else len(ds_idx)
    '''
    For 2nd level, SIFT, say 64 parts, beyond 25 epochs train does not improve much.
    '''
    if opt.glove or opt.sift:
        #stopping mechanism; 18 if len(branching_l)==1 else 15 <- for MCCE loss. glove+sift: 18 then 15
        n_epochs = 18 if len(branching_l) == 1 else 15 #44 #opt.n_epochs
    else:
        #85 is a good top-level epoch number for MNIST. glove+sift: 18 then 10
        n_epochs = 18 if len(branching_l) == 1 else 10 #opt.n_epochs
    toplevel = toplevel if toplevel is not None else (True if height > 0 else False)

    #need to train and get children idx (classes) from net.
    train_node = train.TrainNode(n_epochs, opt, height, toplevel=toplevel)
    #append node to parent
    parent_train_node.add_child(train_node, cur_tn_idx)

    dataset_data = dataset[ds_idx]
    if False and opt.sift:
        #'n' stands for neural and normalized
        dataset_n = dataset / dataset.norm(dim=1, p=2, keepdim=True).clamp(1)
        dataset_data_n = dataset_n[ds_idx]
    else:
        dataset_n = dataset
        dataset_data_n = dataset_data

    #height is 0 for leaf level nodes
    if False and height < 1: #not opt.compute_gt_nn:
        train_node.train(dataset, dsnode, idx2bin, height)
        model = train_node.model
        model.eval()
        classes_l = []
        chunk_sz = 90000
        dataset_len = len(dataset_data)
        for i in range(0, dataset_len, chunk_sz):
            end = min(i+chunk_sz, dataset_len)
            cur_data = dataset_data[i:end, :]
            classes_l.append(torch.argmax(model(cur_data), dim=1))
        classes = torch.cat(classes_l)

    action = opt.level2action[height]
    if action == 'km':
        #bottom level, use kmeans
        train_node.model = None
        train_node.trained = True
        train_node.idx2bin = idx2bin
        solver = kmeans.FastKMeans(dataset_data, n_class, opt)
        d_cls_idx = solver.predict(dataset_data, k=1)
        d_cls_idx = d_cls_idx.reshape(-1)
        classes = torch.LongTensor(d_cls_idx)
        train_node.kmsolver = solver
        d_idx_l = [np.where(d_cls_idx == i)[0] for i in range(n_class)]
        train_node.probe_count_l = [len(l) for l in d_idx_l] #[(classes == i).sum().item() for i in range(n_class)]
    else:
        classes = torch.LongTensor(classes)
        if action == 'train':
            device = dataset.device
            '''
            #compute the ranks of top classes. Using centers of all points in a class
            sums = torch.zeros(n_class, dataset_data.size(-1), device=device)
            classes_exp = classes.unsqueeze(1).expand_as(dataset_data).to(device)
            sums.scatter_add_(0, classes_exp, dataset_data)
            lens = torch.zeros(n_class) #, dtype=torch.int64)
            lens_ones = torch.ones(dataset_data.size(0)) #, dtype=torch.int64)
            lens.scatter_add_(0, classes, lens_ones)
            lens = lens.to(device)
            centers = sums / lens.unsqueeze(-1)
            ranks = utils.dist_rank(dataset_data, k=n_class, data_y=centers, include_self=True)
            '''
            dsnode = DataNode(ds_idx, classes, n_class, ranks=ds_idx_ranks)
            #if opt.sift: #center as well?
            train_node.train(dataset_n, dsnode, idx2bin, height)
            #else:
            #    train_node.train(dataset, dsnode, idx2bin, height)
            model = train_node.model
            model.eval()
            classes_l = []
            chunk_sz = 80000
            dataset_len = len(dataset_data_n)
            for i in range(0, dataset_len, chunk_sz):
                end = min(i+chunk_sz, dataset_len)
                cur_data = dataset_data_n[i:end, :]
                classes_l.append(torch.argmax(model(cur_data), dim=1))
            classes = torch.cat(classes_l)
        elif action == 'logreg':
            train_node.model = None
            train_node.trained = True
            train_node.idx2bin = idx2bin
            cur_path = None
            if opt.glove:
                cur_path = osp.join(utils.data_dir, 'lg_glove')
            elif opt.sift:
                cur_path = osp.join(utils.data_dir, 'lg_sift')
            if root and cur_path is not None:
                if osp.exists(cur_path):
                    #deserialize
                    with open(cur_path, 'rb') as file:
                        solver = pickle.load(file)
                else:
                    #serialize
                    solver = logreg.LogReg(dataset_data, classes, opt)
                    with open(cur_path, 'wb') as file:
                        pickle.dump(solver, file)
            else:
                solver = logreg.LogReg(dataset_data, classes, opt)
            d_cls_idx = solver.predict(dataset_data, k=1)
            d_cls_idx = d_cls_idx.reshape(-1)
            classes = torch.LongTensor(d_cls_idx)
            train_node.kmsolver = solver
            d_idx_l = [np.where(d_cls_idx == i)[0] for i in range(n_class)]
            train_node.probe_count_l = [len(l) for l in d_idx_l]
        elif action == 'kahip':
            #kahip only
            train_node.model = None
            train_node.trained = True
            train_node.idx2bin = idx2bin
            train_node.idx2kahip = {}
            for i, cur_idx in enumerate(ds_idx):
                train_node.idx2kahip[cur_idx.item()] = classes[i]
            train_node.probe_count_l = [(classes == i).sum().item() for i in range(n_class)]
        else:
            raise Exception('Action must be either kahip, km, or train')

    dsnode = DataNode(ds_idx, classes, n_class)
    #ds_idx needs to be indices wrt entire dataset.
    #y are labels of clusters, indices 0 to num_cluster.
    if height > 0:
        #recurse based on children
        procs = []
        next_act = opt.level2action[height-1]
        parallelize = next_act in ['train', 'kahip', 'logreg']
        if parallelize:
            p_man = mp.Manager()
            idx2classes = p_man.dict()
        branching_l_l = []
        child_ds_idx_l = []
        #index of child TrainNode
        tnode_idx_l = []
        ranks_l = []
        for cur_class in range(n_class):
            #pick the samples having this class
            child_ds_idx = ds_idx[classes == cur_class]
            child_branching_l = list(branching_l)
            child_branching_l.append(str(cur_class))
            if len(child_ds_idx) < opt.k:
                #create train_node without model, but with base_idx, leaf_idx etc.
                #Need to have a placeholder for correct indexing.
                child_tn = train.TrainNode(opt.n_epochs, opt, height-1)
                child_tn.base_idx = len(set(idx2bin.values()))
                child_tn.leaf_idx = [child_tn.base_idx]
                for j in child_ds_idx:
                    idx2bin[j.item()] = child_tn.base_idx
                child_tn.probe_count_l = [len(child_ds_idx)]
                child_tn.idx2bin = idx2bin
                train_node.add_child(child_tn, cur_class)
            else:
                ranks, all_ranks_data, graph_path = create_data_tree(dataset, all_ranks_data, child_ds_idx, train_node,
                                                                     idx2bin, height, child_branching_l, ht2cutsz, opt)
                branching_l_l.append(child_branching_l)
                #those knn graphs for kahip are one-based, and are lists and not tensors due to weights.
                if next_act == 'train':
                    k1 = max(1, int(opt.nn_mult * opt.k))
                    ranks_l.append(utils.dist_rank(dataset[child_ds_idx], k=k1))
                else:
                    ranks_l.append([])
                if parallelize:
                    datalen = len(child_ds_idx)
                    p = mp.Process(target=process_child,
                                   args=(ranks, graph_path, datalen, child_branching_l, height,
                                         idx2classes, len(procs), ht2cutsz, opt))
                    #print('processed child process!! len {}'.format(len(cur_classes)))
                    procs.append(p)
                    p.start()
                tnode_idx_l.append(cur_class)
                child_ds_idx_l.append(child_ds_idx)
        for p in procs:
            p.join()
        print('~~~~~~~~~~finished p.join. check classes_l')
        for i in range(len(branching_l_l)):
            if parallelize:
                classes = idx2classes[i]
            else:
                classes = None
            child_branching_l = branching_l_l[i]
            child_ds_idx = child_ds_idx_l[i]
            child_ranks = ranks_l[i]
            #create child DataNode: dataset, ds_idx, parent_train_node, idx2bin, height, opt
            child_dsnode = add_datanode_children(dataset, all_ranks_data, child_ds_idx, train_node, idx2bin,
                                                 height-1, child_branching_l, classes, ht2cutsz,
                                                 tnode_idx_l[i], opt, child_ranks)
            dsnode.add_child(child_dsnode)
    else:
        train_node.base_idx = len(set(idx2bin.values()))
        train_node.leaf_idx = range(train_node.base_idx, train_node.base_idx + n_class)
        if train_node.kmsolver is not None:
            predicted = train_node.kmsolver.predict(dataset_data, k=1)
            for i, pred in enumerate(predicted):
                idx2bin[ds_idx[i].item()] = train_node.base_idx + int(pred)
        else:
            #predict entire dataset at once!
            if opt.compute_gt_nn or action == 'kahip':
                for i, data in enumerate(dataset_data):
                    predicted = train_node.idx2kahip[ds_idx[i].item()].item()
                    idx2bin[ds_idx[i].item()] = train_node.base_idx + predicted
            elif train_node.model is not None:
                dataset_data_len = len(dataset_data_n)
                chunk_sz = 80000
                if dataset_data_len > chunk_sz:
                    pred_l = []
                    for p in range(0, dataset_data_len, chunk_sz):
                        cur_data = dataset_data_n[p:min(p+chunk_sz, dataset_data_len)]
                        pred_l.append(torch.argmax(model(cur_data), dim=1))
                    predicted = torch.cat(pred_l)
                else:
                    predicted = torch.argmax(model(dataset_data_n), dim=1)
                for i, pred in enumerate(predicted):
                    #idx2bin[ds_idx[i].item()] = train_node.base_idx + train_node.leaf_idx[predicted]
                    idx2bin[ds_idx[i].item()] = train_node.base_idx + int(pred)
            else:
                raise Exception('Training inconsistency')
    return dsnode
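# A hedged illustration (values are hypothetical, not a shipped config) of the
# opt.level2action mapping that drives add_datanode_children: keys are tree heights
# (height_preset down to 0 at the leaves), and each level is assigned one of the
# actions handled above ('train', 'km', 'kahip', 'logreg').
def _example_level2action():
    #root level trained with the neural model, middle level with logistic
    #regression, leaf level with k-means.
    return {2: 'train', 1: 'logreg', 0: 'km'}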