def create_knn_graph(data, k, opt=None):

    if opt is not None and hasattr(opt, 'ranks_path'):
        ranks = np.load(opt.ranks_path)
        ranks = torch.from_numpy(ranks)
    elif opt is not None and opt.normalize_data:
        '''
        data /= data.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-3)
        dist = torch.matmul(data, data.t())# - torch.eye(len(data))
        val, ranks = torch.topk(dist, k=k+1, dim=1)
        ranks = ranks[:, 1:]
        '''
        ranks = utils.dist_rank(data, k=k, opt=opt)
    else:
        #compute l2 dist; block the computation to stay memory-efficient
        '''
        dist = utils.l2_dist(data)        
        dist += 2*torch.max(dist).item()*torch.eye(len(data))
        val, ranks = torch.topk(dist, k=k, dim=1, largest=False)
        '''
        ranks = utils.dist_rank(data, k=k, opt=opt)

        if DEBUG:
            print(ranks)

    #add 1 since the indices for kahip must be 1-based.
    ranks += 1

    return ranks
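
#A minimal stand-alone sketch (not part of the original module) of the contract
#create_knn_graph provides: k nearest neighbors per point, excluding the point
#itself, shifted to 1-based indices because KaHIP vertex ids are 1-based.
#It builds a dense n x n distance matrix, so it only fits small inputs;
#utils.dist_rank above is the blocked, memory-efficient version.
def knn_ranks_1based_demo(data, k):
    import torch
    dist = torch.cdist(data, data)        #pairwise l2 distances, n x n
    dist.fill_diagonal_(float('inf'))     #exclude each point from its own neighbor list
    _, ranks = torch.topk(dist, k=k, dim=1, largest=False)
    return ranks + 1                      #1-based ids for the kahip graph file
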
def create_knn_sub_graph(all_ranks, idx2weights, ds_idx, data, opt):

    if False:
        if opt is not None and opt.normalize_data and not opt.glove:
            data /= data.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-3)
            dist = torch.matmul(data, data.t()) - torch.eye(len(data))
            val, local_ranks = torch.topk(dist, k=1, dim=1)
        else:
            '''
            dist = utils.l2_dist(data)        
            dist = dist + 2*torch.max(dist).item()*torch.eye(len(data))
            val, local_ranks = torch.topk(dist, k=1, dim=1, largest=False)
            '''
            #just compute this in the loop since not too frequent
            local_ranks = utils.dist_rank(data, k=1, opt=opt)

        local_ranks += 1

    ranks = []
    #is_tensor = isinstance(all_ranks, torch.Tensor)
    #map 1-based global dataset index to 1-based local index within ds_idx
    ds_idx2idx = {t.item() + 1: idx for idx, t in enumerate(ds_idx, 1)}
    #dict of idx of a point to its nearest neighbor
    idx2nn = {}
    #set of added tuples
    added_tups = set()

    if isinstance(all_ranks, torch.Tensor):
        all_ranks = all_ranks.cpu().numpy()
    if isinstance(ds_idx, torch.Tensor):
        ds_idx = ds_idx.cpu().numpy()
    for idx, i in enumerate(ds_idx):
        cur_ranks = []
        for j in all_ranks[i]:

            tup = (i + 1, j) if i + 1 < j else (j, i + 1)
            if idx2weights[tup] == 1 and tup in added_tups:
                continue

            added_tups.add(tup)

            if j in ds_idx2idx:
                cur_ranks.append(ds_idx2idx[j])

        if cur_ranks == []:
            #nearest_idx = local_ranks[idx][0].item()

            local_ranks = utils.dist_rank(data[idx].unsqueeze(0),
                                          k=1,
                                          data_y=data,
                                          opt=opt)
            nearest_idx = local_ranks[0][0].item() + 1
            cur_ranks.append(nearest_idx)
            idx2nn[idx] = nearest_idx

        ranks.append(cur_ranks)
    for idx, nn in idx2nn.items():
        ranks[nn - 1].append(idx + 1)

    return ranks
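
#Hedged illustration (helper name is mine, not from the source) of the index
#remapping create_knn_sub_graph performs: neighbor ids in all_ranks are 1-based
#indices into the full dataset, and must be translated into 1-based indices
#into the subset selected by ds_idx before being written out for kahip.
def remap_neighbors_to_subset_demo(neighbor_row, ds_idx):
    #neighbor_row: 1-based global neighbor ids for one point
    #ds_idx: 0-based global indices of the points that form the subset
    ds_idx2idx = {int(t) + 1: local for local, t in enumerate(ds_idx, 1)}
    return [ds_idx2idx[j] for j in neighbor_row if j in ds_idx2idx]
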
def knn_dist_lof(X, k=10):
    X_len = len(X)

    #dist_ = dist(X, X)
    #min_dist, min_idx = torch.topk(dist_, dim=-1, k=k, largest=False)

    min_dist, min_idx = utils.dist_rank(X, k=k, largest=False)
    kth_dist = min_dist[:, -1]
    # sum max(kth dist, dist(o, p)) over neighbors o of p
    kth_dist_exp = kth_dist.expand(X.size(0), -1)  #n x n
    kth_dist = torch.gather(input=kth_dist_exp, dim=1, index=min_idx)

    min_dist[kth_dist > min_dist] = kth_dist[kth_dist > min_dist]
    #inverse of lrd scores
    dist_avg = min_dist.mean(-1).clamp(min=0.0001)

    compare_density = False
    if compare_density:
        #compare with density. Get kth neighbor index.
        dist_avg_exp = dist_avg.unsqueeze(-1) / dist_avg.unsqueeze(0).expand(
            X_len, -1)
        #lof = torch.zeros(X_len, 1).to(utils.device)
        lof = torch.gather(input=dist_avg_exp, dim=-1, index=min_idx).sum(-1)
        torch.scatter_add_(lof, dim=-1, index=min_idx, src=dist_avg_exp)
        return -lof.squeeze(0)

    return dist_avg
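
#A small stand-alone sketch (assumed helper, dense distances only) of the LOF
#"reachability distance" step used above: for a point p with neighbor o,
#reach_dist_k(p, o) = max(k_dist(o), dist(p, o)), and the inverse of the local
#reachability density is the mean of these values over p's k neighbors.
def lrd_inverse_demo(X, k=10):
    import torch
    dist = torch.cdist(X, X)
    dist.fill_diagonal_(float('inf'))
    min_dist, min_idx = torch.topk(dist, k=k, dim=1, largest=False)
    kth_dist = min_dist[:, -1]                       #k-distance of every point
    reach = torch.max(min_dist, kth_dist[min_idx])   #max(k_dist(o), dist(p, o))
    return reach.mean(-1).clamp(min=1e-4)            #inverse lrd, as in knn_dist_lof
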
def create_dataset():
    #data_dir = 'data/prefix'
    #data_dir = '/large/prefix'
    #data_dir = '../deep1b.dat'
    data_dir = '/large/deep1b.dat'
    #with open(osp.join(data_dir, 'base_00'), 'rb') as file:
    with open(data_dir, 'rb') as file:
        data = file.read()  #(8*96+8)
    i = 0
    skip = 4 * 96  #4*97

    data_len = len(data)
    #last vector can be cut off in the middle of vec
    ##assert data_len % 4 == 0 and (data_len-4) % 96 == 0
    n_queries = 100000
    n_data = 10000000
    #n_queries = 1
    #n_data = 9
    n_total = n_data + 2 * n_queries  #2x to account for dups

    data_l = []
    stop_len = data_len - skip
    counter = 0
    byte_set = set()
    dup_count = 0

    while i < stop_len:

        #cur_bytes = data[i+4:i+skip] <--if download directly
        cur_bytes = data[i:i + skip]  #
        if cur_bytes in byte_set:
            dup_count += 1
            i += skip
            continue

        byte_set.add(cur_bytes)

        data_l.append(list(struct.unpack('96f', cur_bytes)))
        #except struct.error:
        #    print('struct error')
        #    pdb.set_trace()
        counter += 1
        if counter == n_total:
            #pdb.set_trace()
            break

        i += skip

    print('dup count {}'.format(dup_count))
    print('number of vectors {}'.format(len(data_l)))
    queries = torch.FloatTensor(data_l[:n_queries])
    dataset = torch.FloatTensor(data_l[n_queries:n_queries + n_data])
    torch.save(queries, '/large/prefix10m_queries.pt')
    torch.save(dataset, '/large/prefix10m_dataset.pt')
    #pdb.set_trace()
    #need batch size 200 for 10mil to not run out of memory
    answers = utils.dist_rank(queries, k=10, data_y=dataset, include_self=True)
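
#Hedged sketch of the record layout handled above (the helper is illustrative):
#each Deep1B vector is 96 float32 values, so skip = 4*96 bytes per record; a
#directly downloaded file prepends a 4-byte int32 dimension per record, which
#is what the commented-out data[i+4:i+skip] and the 4*97 note refer to.
def unpack_record_demo(buf, offset, has_dim_header=False):
    import struct
    if has_dim_header:
        dim = struct.unpack_from('i', buf, offset)[0]
        assert dim == 96
        offset += 4
    vec = list(struct.unpack_from('96f', buf, offset))
    return vec, offset + 4 * 96
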
def knn_dist(X, k=10, sum_dist=False):

    min_dist, idx = utils.dist_rank(X, k=k, largest=False)

    if sum_dist:
        dist_score = min_dist.sum(-1)
    else:
        dist_score = min_dist.mean(-1)

    return dist_score
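
#Hedged usage sketch: knn_dist gives a per-point sparsity score (mean or summed
#distance to the k nearest neighbors), so sorting it ascending ranks points
#from densest to most isolated.
def densest_points_demo(X, n, k=10):
    import torch
    scores = knn_dist(X, k=k)
    return torch.argsort(scores)[:n]
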
Example #6
    def predict(self, query, k):

        #query = query.to(utils.device)
        if isinstance(query, np.ndarray):
            query = torch.from_numpy(query).to(utils.device)
        
        #self.centers is stored d x n_centers, e.g. torch.Size([100, 1024]); size(1) is the number of centers
        if hasattr(self, 'opt') and (self.opt.glove or self.opt.sift) and self.centers.size(1) > 512:
            
            centers = self.centers.t()
            idx = utils.dist_rank(query, k, data_y=centers, largest=False)
        else:
            
            q_norm = torch.sum(query ** 2, dim=1).view(-1, 1)
            dist = q_norm + self.centers_norm - 2*torch.mm(query, self.centers)
        
            if k > dist.size(1):
                k = dist.size(1)
            _, idx = torch.topk(dist, k=k, dim=1, largest=False)
            #move predict to numpy
        
        idx = idx.cpu().numpy()
        return idx    
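
#Illustrative helper (not part of the original class) for the expansion used in
#predict's else branch: ||q - c||^2 = ||q||^2 + ||c||^2 - 2 q.c, so every
#query-to-center squared distance comes from a single matmul. Here centers is
#assumed to be n_centers x d; the class above keeps the transposed layout and a
#precomputed self.centers_norm.
def center_dist_demo(query, centers):
    import torch
    q_norm = torch.sum(query ** 2, dim=1).view(-1, 1)    #n_queries x 1
    c_norm = torch.sum(centers ** 2, dim=1).view(1, -1)  #1 x n_centers
    return q_norm + c_norm - 2 * torch.mm(query, centers.t())
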
def remove_dup(data):
    #reconstructed header, mask, and loop opening; the original snippet is truncated here
    selected = torch.ones(len(data), dtype=torch.bool, device=data.device)
    for i, d in enumerate(data):
        #a row matching d in all 96 coordinates is an exact duplicate
        prev = d.eq(data).sum(-1)
        if (prev == 96).sum().item() > 1:  #1 for itself
            selected[i] = 0
            print('{} dup'.format(i))

    print('{} duplicates'.format((~selected).sum().item()))
    data = data[selected]
    np.save('data/prefix1m_dataset2.npy', data.cpu().numpy())

    return data


if __name__ == '__main__':
    remove_dup_bool = False
    if remove_dup_bool:
        dataset = np.load('data/prefix1m_dataset.npy')
        dataset = torch.from_numpy(dataset).cuda()
        dataset = remove_dup(dataset)

        queries = torch.from_numpy(np.load('data/prefix1m_queries.npy')).cuda()
        answers = utils.dist_rank(queries,
                                  k=10,
                                  data_y=dataset,
                                  include_self=True)

        np.save('data/prefix1m_answers2.npy', answers.cpu().numpy())
    else:
        create_dataset()
def k_means(dataset, dataset_idx, ht2cutsz, height, n_clusters, opt):  #ranks

    num_points = dataset.shape[0]
    dimension = dataset.shape[1]

    use_kahip_solver = (opt.kmeans_use_kahip_height == height)
    if use_kahip_solver:
        solver = kahip_solver.KahipSolver()
    elif opt.fast_kmeans:
        solver = kmeans.FastKMeans(dataset, n_clusters, opt)
    elif opt.itq:
        solver = itq.ITQSolver(dataset, n_clusters)
    elif opt.cplsh:
        solver = cplsh.CPLshSolver(dataset, n_clusters, opt)
    elif opt.pca:
        assert n_clusters == 2
        solver = pca.PCASolver(dataset, opt)
    elif opt.st:
        assert n_clusters == 2
        solver = pca.STSolver(dataset, opt.glob_st_ranks, dataset_idx, opt)
    elif opt.rp:
        if n_clusters != 2:
            raise Exception('n_cluster {} must be 2!'.format(n_clusters))
        solver = pca.RPSolver(dataset, opt)
    elif km_method == 'km':
        solver = KMeans(n_clusters=n_clusters, max_iter=max_loyd)
        solver.fit(dataset)
    elif km_method == 'mbkm':
        solver = MiniBatchKMeans(n_clusters=n_clusters, max_iter=max_loyd)
        solver.fit(dataset)
    else:
        raise Exception('method {} not supported'.format(km_method))

    #print("Ranking clusters for data and query points...")
    #dataset_dist = solver.transform(dataset) #could be useful, commented out for speed
    #queries_dist = solver.transform(queries)

    #the distances to cluster centers, ranked smallest first.
    #dataset_dist_idx = np.argsort(dataset_dist, axis=1)
    #queries_dist_idx = np.argsort(queries_dist, axis=1)

    if use_kahip_solver:
        #output is numpy array
        d_cls_idx = solver.predict(dataset_idx)
    elif isinstance(solver, kmeans.FastKMeans):
        d_cls_idx = solver.predict(dataset, k=1)
        d_cls_idx = d_cls_idx.reshape(-1)
    elif isinstance(solver, cplsh.CPLshSolver):

        d_cls_idx = solver.predict(dataset, k=1)
    elif isinstance(solver, itq.ITQSolver):
        d_cls_idx = solver.predict(dataset, k=1)
        d_cls_idx = d_cls_idx.reshape(-1)
    else:
        d_cls_idx = solver.predict(dataset)

    #per-class lists of point indices (not the points themselves), as np arrays.
    #np.where returns a tuple, hence the [0] below.
    d_idx_l = [np.where(d_cls_idx == i)[0] for i in range(n_clusters)]

    #q_idx_l = [np.where(q_cls_idx==i) for i in range(n_clusters)] #could be useful, commented out for speed

    compute_cut_sz_b = False
    if compute_cut_sz_b:
        ranks = utils.dist_rank(dataset, k=opt.k, opt=opt)
        #compute_cut_size expects 1-based ranks
        ranks += 1
        cut_sz = compute_cut_size(d_cls_idx.tolist(), ranks)
        ht2cutsz[height].append(cut_sz)

    return d_idx_l, solver
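
#Hedged usage sketch (names are illustrative) of the sklearn branch of k_means
#above: fit a KMeans model, predict one cluster id per point, then group point
#indices by cluster exactly as d_idx_l is built.
def kmeans_partition_demo(dataset, n_clusters, max_iter=100):
    import numpy as np
    from sklearn.cluster import KMeans
    solver = KMeans(n_clusters=n_clusters, max_iter=max_iter)
    solver.fit(dataset)
    d_cls_idx = solver.predict(dataset)
    d_idx_l = [np.where(d_cls_idx == i)[0] for i in range(n_clusters)]
    return d_idx_l, solver
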
        res_l = ['Catalyzed data ']

    ds, qu, neigh = load_data(utils.data_dir, opt)
    if opt.cplsh and opt.sift:
        ds = ds / np.sqrt((ds**2).sum(-1, keepdims=True))
        qu = qu / np.sqrt((qu**2).sum(-1, keepdims=True))
        qu = qu[:500]
        neigh = neigh[:500]

    n_repeat = 1
    #search tree
    #global glob_st_ranks
    #if glob_st_ranks is None:
    if opt.st:
        opt.glob_st_ranks = utils.dist_rank(ds,
                                            opt.k,
                                            include_self=True,
                                            opt=opt)
        torch.save(opt.glob_st_ranks, 'st_ranks_glove')

    for i in range(n_repeat):
        for height in height_l:
            acc, probe, probe95 = run_main(height, ds, qu, neigh, opt)
            res_l.append(
                str(height) + ' ' + ' '.join(
                    [str(acc[0, 0]),
                     str(probe[0, 0]),
                     str(probe95[0, 0])]))
    res_str = '\n'.join(res_l)
    if opt.rp:
        with open(osp.join(utils.data_dir, 'rp_data_mnist.md'), 'a') as file:
            file.write(res_str + '\n')
Example #10
def run_kmkahip(height_preset, opt, dataset, queryset, neighbors):
    
    k = opt.k
    print('Configs: {} \n Starting data processing and training ...'.format(opt))
    #this root node is a dummy node, since it doesn't have a trained model or idx2bin
    train_node = train.TrainNode(-1, opt, -1)
    
    swap_query_to_data = False
    if swap_query_to_data:
        print('** NOTE: Modifying queryset as part of dataset **')        
        queryset = dataset[:11000]
        #queryset = dataset
        neighbors = utils.dist_rank(queryset, k=opt.k, data_y=dataset, largest=False)        
        #dist += 2*torch.max(dist).item()*torch.eye(len(dist)) #torch.diag(torch.max(dist))
        #val, neighbors = torch.topk(dist, k=opt.k, dim=1, largest=False)        
    
    #dsnode_path = opt.dsnode_path + str(opt.n_clusters)
    #dsnode = utils.pickle_load(dsnode_path)
    
    #check if data needs normalizing. Remove the extra conditions eventually.
    if opt.normalize_data and dataset[0].norm(p=2).item() != 1 and not opt.glove:
        print('Normalizing data ...')
        dataset = utils.normalize(dataset)
        queryset = utils.normalize(queryset)
        
    #create data tree used for training
    n_clusters = opt.n_clusters
    
    height = height_preset
    n_bins = 1
    
    ds_idx = torch.LongTensor(list(range(len(dataset))))
    print('{} height: {} level2action {}'.format(ds_idx.size(), height, opt.level2action))
    
    idx2bin = {}
    ht2cutsz = defaultdict(list) 
    #used for memoizing partition results
    branching_l = ['0']
    all_ranks = None
    
    root_dsnode = create_data_tree_root(dataset, all_ranks, ds_idx, train_node, idx2bin, height, branching_l,ht2cutsz, opt)
    print('Done creating training tree. Starting evaluation ...')

    #the root train node is a dummy; its first child is the actual trained node.
    eval_root = train.EvalNode(train_node.children[0])

    ''' Evaluate '''
    
    with torch.no_grad():
        print('About to evaluate model! {} height: {} level2action {}'.format(ds_idx.size(), height, opt.level2action))                    
        acc, probe_count, probe_count95 = train.eval_model(eval_root, queryset, neighbors, n_bins, opt)
    
    print('cut_sizes {}'.format(ht2cutsz))
    print('Configs: {}'.format(opt))
    print('acc {} probe count {} 95th {}'.format(acc, probe_count, probe_count95))

    ''' Serialize '''
    serialize_bool = 'kahip' not in set(opt.level2action.values())
    serialize_bool = True
    if serialize_bool:
        print('Serializing eval root...')
        if opt.sift:
            data_name = 'sift'
        elif opt.glove:
            data_name = 'glove'
        elif opt.prefix10m:
            data_name = 'prefix10m'
        else:
            data_name = 'mnist'
        idx2bin = eval_root.idx2bin
        if 'logreg' in opt.level2action.values():
            serial_path = 'evalroot_{}_ht{}_{}_{}{}nn{}logreg'
        else:
            serial_path = 'evalroot_{}_ht{}_{}_{}{}nn{}'
        eval_root_path = osp.join(opt.data_dir, serial_path.format(data_name, height, n_clusters, opt.k_graph, opt.k, opt.nn_mult)) 
        eval_root_dict = {'eval_root':eval_root, 'opt':opt}
        utils.pickle_dump(eval_root_dict, eval_root_path)
        print('Done serializing {}'.format(eval_root_path))
        #dsnode_path = opt.dsnode_path + str(opt.n_clusters)
        #utils.pickle_dump(root_dsnode, dsnode_path)
    
    with open(osp.join(opt.data_dir, 'cutsz_k{}_ht{}_{}'.format(k, height, n_clusters)), 'w') as file:
        file.write(str(ht2cutsz))
        file.write('\n\n')
        file.write(str(opt))
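
#Hedged sketch of reloading the evaluation tree serialized above; the path
#format is the one run_kmkahip builds, and utils.pickle_load is assumed to be
#the counterpart of the utils.pickle_dump call used during serialization.
def load_eval_root_demo(eval_root_path):
    eval_root_dict = utils.pickle_load(eval_root_path)
    return eval_root_dict['eval_root'], eval_root_dict['opt']
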
Example #11
def create_data_tree_root(dataset, all_ranks, ds_idx, train_node, idx2bin, height, branching_l, ht2cutsz, opt):

    datalen = len(ds_idx)
    if datalen <= opt.k:
        return None
    graph_path = os.path.join(opt.data_dir, opt.graph_file) #'../data/knn.graph'
    
    #ranks are 1-based
    if opt.glove or opt.sift or opt.prefix10m: #and len(branching_l) == 1:

        if opt.glove:
            #custom paths, e.g. for the april 50NN graph file (opt.k_graph == 50):
            #graph_path = os.path.join(opt.data_dir, 'glove50_' + opt.graph_file)
            #graph_path = os.path.join(opt.data_dir, 'glove10_sub10knn.graph')
            graph_path = os.path.join(opt.data_dir, opt.graph_file)  #'../data/knn.graph'
            print('graph file {}'.format(graph_path))
        parts_path = run_kahip(graph_path, datalen, branching_l, height, opt)
        print('Done partitioning top level!')
        lines = utils.load_lines(parts_path)
        classes = [int(line) for line in lines]
        
        #read in all_ranks, for partitioning on further levels.
        all_ranks, idx2weights = read_all_ranks(opt)
        if opt.dataset_name != 'prefix10m':
            k1 = max(1, int(opt.nn_mult*opt.k))
            ranks = utils.dist_rank(dataset, k=k1)
        else:
            #subtract 1 as graph was created with 1-indexing for kahip.
            ranks = torch.load('/large/prefix10m10knn.graph.pt') - 1
        #create root DataNode dataset, ds_idx, parent_train_node, idx2bin, height, opt
        dsnode = add_datanode_children(dataset, (all_ranks, idx2weights), ds_idx, train_node, idx2bin, height-1, branching_l, classes, ht2cutsz, 0, opt, ranks, toplevel=True, root=True)    
        return dsnode

    #create graph from data.
    data = dataset[ds_idx]
    if len(branching_l) == 1: #this is always the case now
        #use tree created at top level throughout the hierarchy
        ranks = create_graph.create_knn_graph(data, k=opt.k, opt=opt) #should supply opt
        all_ranks = ranks
    else:
        assert all_ranks is not None        
        #else compute part of previous graph
        ranks = create_graph.create_knn_sub_graph(all_ranks, ds_idx, data, opt)
    
    n_edges = create_graph.write_knn_graph(ranks, graph_path)
    _, idx2weights = read_all_ranks(opt, path=graph_path)
        
    #create partition from graph
    #this overrides file each iteration        
    parts_path = run_kahip(graph_path, datalen, branching_l, height, opt)

    lines = utils.load_lines(parts_path)
    classes = [int(line) for line in lines]
    
    compute_cut_size_b = False and not opt.glove
    if compute_cut_size_b:
        cut_sz = compute_cut_size(classes, ranks)
        ht2cutsz[height].append((cut_sz, n_edges))                
    
    #create root DataNode dataset, ds_idx, parent_train_node, idx2bin, height, opt
    dsnode = add_datanode_children(dataset, (all_ranks, idx2weights), ds_idx, train_node, idx2bin, height-1, branching_l, classes, ht2cutsz, 0, opt, all_ranks-1, toplevel=True, root=True)
    #Note: the all_ranks passed above contains only opt.k neighbors per point, not nn_mult (5) * opt.k.
    
    return dsnode
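
#Hedged sketch (not the repository's write_knn_graph) of emitting a k-NN graph
#in the METIS/KaHIP text format consumed by run_kahip: a header line with the
#vertex and edge counts, then one line per vertex listing its 1-based
#neighbors. The adjacency is symmetrized, since METIS requires it.
def write_metis_graph_demo(ranks, path):
    #ranks: per-vertex lists of 1-based neighbor ids
    adj = [set() for _ in range(len(ranks))]
    for i, nbrs in enumerate(ranks):
        for j in nbrs:
            j = int(j)
            if j != i + 1:
                adj[i].add(j)
                adj[j - 1].add(i + 1)
    n_edges = sum(len(a) for a in adj) // 2
    with open(path, 'w') as f:
        f.write('{} {}\n'.format(len(adj), n_edges))
        for a in adj:
            f.write(' '.join(str(j) for j in sorted(a)) + '\n')
    return n_edges
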
Example #12
def add_datanode_children(dataset, all_ranks_data, ds_idx, parent_train_node, idx2bin, height, branching_l, classes, ht2cutsz, cur_tn_idx, opt, ds_idx_ranks, toplevel=None, root=False):
    
    all_ranks, idx2weights = all_ranks_data
    n_class = opt.n_class if opt.n_class <= len(ds_idx) else len(ds_idx)
    '''
    For 2nd level, SIFT, say 64 parts, beyond 25 epochs train does not improve much.
    '''
    if opt.glove or opt.sift:
        #18 epochs at the top level, 15 for lower levels (also used with the MCCE loss)
        n_epochs = 18 if len(branching_l) == 1 else 15
    else:
        #18 epochs at the top level, 10 for lower levels; 85 works well at the top level for MNIST
        n_epochs = 18 if len(branching_l) == 1 else 10
    
    toplevel = toplevel if toplevel is not None else height > 0
    #need to train and get children idx (classes) from net.
    train_node = train.TrainNode(n_epochs, opt, height, toplevel=toplevel)
    #append node to parent
    parent_train_node.add_child(train_node, cur_tn_idx)
    
    dataset_data = dataset[ds_idx]
    if False and opt.sift:
        #'n' stands for neural and normalized
        dataset_n = dataset / dataset.norm(dim=1, p=2, keepdim=True).clamp(1)
        dataset_data_n = dataset_n[ds_idx]
    else:
        dataset_n = dataset
        dataset_data_n = dataset_data
        
    #height is 0 for leaf level nodes
    if False and height < 1:  #alternative condition: not opt.compute_gt_nn
                
        train_node.train(dataset, dsnode, idx2bin, height)
        model = train_node.model        
        model.eval()
        classes_l = []
        chunk_sz = 90000
        
        dataset_len = len(dataset_data)
        for i in range(0, dataset_len, chunk_sz):
            
            end = min(i+chunk_sz, dataset_len)            
            cur_data = dataset_data[i:end, :]            
            classes_l.append(torch.argmax(model(cur_data), dim=1))

        classes = torch.cat(classes_l)

    action = opt.level2action[height]
    if action == 'km':
        #bottom level, use kmeans
        train_node.model = None
        train_node.trained = True
        train_node.idx2bin = idx2bin         
        
        solver = kmeans.FastKMeans(dataset_data, n_class, opt)
        d_cls_idx = solver.predict(dataset_data, k=1)

        d_cls_idx = d_cls_idx.reshape(-1)
        
        classes = torch.LongTensor(d_cls_idx)
        
        train_node.kmsolver = solver
        
        d_idx_l = [np.where(d_cls_idx==i)[0] for i in range(n_class)]
        
        train_node.probe_count_l = [len(l) for l in d_idx_l] #[(classes == i).sum().item() for i in range(n_class) ]        
    else:
        classes = torch.LongTensor(classes)
        
        if action == 'train':
            device = dataset.device
            '''
            #compute the ranks of top  classes. Using centers of all points in a class
            
            sums = torch.zeros(n_class, dataset_data.size(-1), device=device)
            classes_exp = classes.unsqueeze(1).expand_as(dataset_data).to(device)
            sums.scatter_add_(0, classes_exp, dataset_data)
            
            lens = torch.zeros(n_class)#, dtype=torch.int64)
            lens_ones = torch.ones(dataset_data.size(0))# , dtype=torch.int64)
            lens.scatter_add_(0, classes, lens_ones)
            lens = lens.to(device)
            centers = sums / lens.unsqueeze(-1)
            
            ranks = utils.dist_rank(dataset_data, k=n_class, data_y=centers, include_self=True)
            '''
            
            dsnode = DataNode(ds_idx, classes, n_class, ranks=ds_idx_ranks)
            
            #if opt.sift:          
                #center as well?
            train_node.train(dataset_n, dsnode, idx2bin, height)
            #else:
            #    train_node.train(dataset, dsnode, idx2bin, height)
            model = train_node.model        
            model.eval()
            classes_l = []
            chunk_sz = 80000
            dataset_len = len(dataset_data_n)
            for i in range(0, dataset_len, chunk_sz):
                end = min(i+chunk_sz, dataset_len)            
                cur_data = dataset_data_n[i:end, :]            
                classes_l.append(torch.argmax(model(cur_data), dim=1))

            classes = torch.cat(classes_l)
        elif action == 'logreg':
            
            train_node.model = None
            train_node.trained = True
            train_node.idx2bin = idx2bin

            cur_path = None
            if opt.glove:
                cur_path = osp.join(utils.data_dir, 'lg_glove')
            elif opt.sift:
                cur_path = osp.join(utils.data_dir, 'lg_sift')
            
            if root and cur_path is not None:                
                if osp.exists(cur_path):
                    #deserialize
                    with open(cur_path, 'rb') as file:
                        solver = pickle.load(file)
                else:
                    #serialize
                    solver = logreg.LogReg(dataset_data, classes, opt)
                    with open(cur_path, 'wb') as file:
                        pickle.dump(solver, file)
                    
            else:
                solver = logreg.LogReg(dataset_data, classes, opt)
            
            d_cls_idx = solver.predict(dataset_data, k=1)
            d_cls_idx = d_cls_idx.reshape(-1)

            classes = torch.LongTensor(d_cls_idx)
            train_node.kmsolver = solver
            
            
            d_idx_l = [np.where(d_cls_idx==i)[0] for i in range(n_class)]

            train_node.probe_count_l = [len(l) for l in d_idx_l] 
        elif action == 'kahip':
            #kahip only
            train_node.model = None
            train_node.trained = True
            train_node.idx2bin = idx2bin         
            train_node.idx2kahip = {}

            for i, cur_idx in enumerate(ds_idx):
                train_node.idx2kahip[cur_idx.item()] = classes[i]
            
            train_node.probe_count_l = [(classes == i).sum().item() for i in range(n_class) ]
        else:
            raise Exception('Action must be one of kahip, km, logreg, or train')
    dsnode = DataNode(ds_idx, classes, n_class)    
    #ds_idx needs to be indices wrt entire dataset.    
    #y are labels of clusters, indices 0 to num_cluster. 
    
    if height > 0: 
        #recurse based on children
        procs = []
        next_act = opt.level2action[height-1]
        parallelize = next_act in ['train', 'kahip', 'logreg'] 
        if parallelize:
            p_man = mp.Manager()
            idx2classes = p_man.dict()
        
        branching_l_l = []
        child_ds_idx_l = []
        #index of child TrainNode
        tnode_idx_l = []
        ranks_l = []
        
        for cur_class in range(n_class):
            #pick the samples having this class        
            child_ds_idx = ds_idx[classes==cur_class]
            child_branching_l = list(branching_l)
            child_branching_l.append(str(cur_class))

            if len(child_ds_idx) < opt.k:
                #create train_node without model, but with base_idx, leaf_idx etc. Need to have placeholder for correct indexing.
                child_tn = train.TrainNode(opt.n_epochs, opt, height-1)
                child_tn.base_idx = len(set(idx2bin.values()))
                child_tn.leaf_idx = [child_tn.base_idx]
                for j in child_ds_idx:
                    idx2bin[j.item()] = child_tn.base_idx
                
                child_tn.probe_count_l = [len(child_ds_idx)]
                child_tn.idx2bin = idx2bin
                train_node.add_child(child_tn, cur_class)
            else:                
                ranks, all_ranks_data, graph_path = create_data_tree(dataset, all_ranks_data, child_ds_idx, train_node, idx2bin, height, child_branching_l, ht2cutsz, opt)
                branching_l_l.append(child_branching_l)
                
                #those knn graphs for kahip are one-based, and are lists and not tensors due to weights.
                if next_act == 'train':
                    k1 = max(1, int(opt.nn_mult*opt.k))
                    ranks_l.append(utils.dist_rank(dataset[child_ds_idx], k=k1))
                else:
                    ranks_l.append([])
                if parallelize:
                    datalen = len(child_ds_idx)                
                    p = mp.Process(target=process_child, args=(ranks, graph_path, datalen, child_branching_l, height, idx2classes, len(procs), ht2cutsz, opt))
                    
                    #print('processed child process!! len {}'.format(len(cur_classes)))                
                    procs.append(p)
                    p.start()
                tnode_idx_l.append(cur_class)
                child_ds_idx_l.append(child_ds_idx)
                
        for p in procs:
            p.join()
        print('~~~~~~~~~~finished p.join. check classes_l')
        
        for i in range(len(branching_l_l)):
            if parallelize:
                classes = idx2classes[i]
            else:
                classes = None
            child_branching_l = branching_l_l[i]
            child_ds_idx = child_ds_idx_l[i]
            child_ranks = ranks_l[i]
            #create root DataNode dataset, ds_idx, parent_train_node, idx2bin, height, opt
            child_dsnode = add_datanode_children(dataset, all_ranks_data, child_ds_idx, train_node, idx2bin, height-1, child_branching_l, classes, ht2cutsz, tnode_idx_l[i], opt, child_ranks)
            dsnode.add_child(child_dsnode)
    else:
        
        train_node.base_idx = len(set(idx2bin.values()))
        train_node.leaf_idx = range(train_node.base_idx, train_node.base_idx+n_class)

        if train_node.kmsolver is not None:
            predicted = train_node.kmsolver.predict(dataset_data, k=1)
            for i, pred in enumerate(predicted):
                idx2bin[ds_idx[i].item()] = train_node.base_idx + int(pred)
        else:
            #predict entire dataset at once!
            if opt.compute_gt_nn or action == 'kahip':
                for i, data in enumerate(dataset_data):                
                    predicted = train_node.idx2kahip[ds_idx[i].item()].item()
                    idx2bin[ds_idx[i].item()] = train_node.base_idx + predicted
            elif train_node.model is not None:
                dataset_data_len = len(dataset_data_n)

                chunk_sz = 80000
                if dataset_data_len > chunk_sz:
                    pred_l = []
                    
                    for p in range(0, dataset_data_len, chunk_sz):
                        cur_data = dataset_data_n[p : min(p+chunk_sz, dataset_data_len)]    
                        pred_l.append( torch.argmax(model(cur_data), dim=1) )

                    predicted = torch.cat(pred_l)
                else:
                    predicted = torch.argmax(model(dataset_data_n), dim=1)
                for i, pred in enumerate(predicted):                    
                    #idx2bin[ds_idx[i].item()] = train_node.base_idx + train_node.leaf_idx[predicted]
                    idx2bin[ds_idx[i].item()] = train_node.base_idx + int(pred)
            else:
                raise Exception('Training inconsistency')
    return dsnode
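
#Small stand-alone sketch (assumed helper, not from the source) of the chunked
#prediction pattern used twice above: run the trained model over the data in
#fixed-size chunks so no single forward pass exhausts memory, then concatenate
#the argmax class id per point.
def predict_in_chunks_demo(model, data, chunk_sz=80000):
    import torch
    pred_l = []
    with torch.no_grad():
        for i in range(0, len(data), chunk_sz):
            pred_l.append(torch.argmax(model(data[i:i + chunk_sz]), dim=1))
    return torch.cat(pred_l)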