def nn_index(vocab, idx_to_vec, nAssign, checks=None):
    """
    >>> idx_to_vec = depc.d.get_feat_vecs(aid_list)[0]
    >>> vocab = vocab
    >>> nAssign = 1
    """
    # Assign each vector to the nearest visual words
    assert nAssign > 0, 'cannot assign to 0 neighbors'
    if checks is None:
        checks = vocab.flann_params['checks']
    try:
        idx_to_vec = idx_to_vec.astype(vocab.wordflann._FLANN__curindex_data.dtype)
        _idx_to_wx, _idx_to_wdist = vocab.wordflann.nn_index(
            idx_to_vec, nAssign, checks=checks)
    except pyflann.FLANNException as ex:
        ut.printex(
            ex,
            'probably misread the cached flann_fpath=%r'
            % (getattr(vocab.wordflann, 'flann_fpath', None),),
        )
        raise
    else:
        _idx_to_wx = vt.atleast_nd(_idx_to_wx, 2)
        _idx_to_wdist = vt.atleast_nd(_idx_to_wdist, 2)
        return _idx_to_wx, _idx_to_wdist
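# Note (not part of the original source): every snippet in this collection relies
# on vt.atleast_nd to coerce FLANN results into 2D arrays. A minimal numpy-only
# sketch of the assumed behavior -- trailing singleton axes are appended until the
# requested rank is reached -- so a (N,) result from a single-neighbor query
# becomes (N, 1):
import numpy as np

def atleast_nd_sketch(arr, n):
    # Append trailing axes until arr has at least n dimensions.
    arr = np.asarray(arr)
    while arr.ndim < n:
        arr = arr[..., None]
    return arr

assert atleast_nd_sketch(np.arange(5), 2).shape == (5, 1)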
def neighbors(query, temp_K):
    _idxs, _dists = query.get_neighbors(query.vecs, temp_K)
    idxs = vt.atleast_nd(_idxs, 2)
    dists = vt.atleast_nd(_dists, 2)
    # Flag any neighbors that are invalid
    validflags = ~in1d_shape(query.get_axs(idxs), query.invalid_axs)
    # Store results in an object
    cand = TempResults(query.index, idxs, dists, validflags)
    return cand
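# neighbors() above calls an in1d_shape helper that is not defined in this snippet;
# the nested helper inside conditional_knn at the end of this collection suggests
# the following definition (reproduced here for readability):
import numpy as np

def in1d_shape(arr1, arr2):
    # Elementwise membership test that preserves arr1's shape, so validflags
    # stays aligned with the (num_vecs, K) idxs array.
    return np.in1d(arr1, arr2).reshape(arr1.shape)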
def nn_index(self, idx2_vec, nAssign):
    """
    >>> idx2_vec = depc.d.get_feat_vecs(aid_list)[0]
    >>> self = vocab
    >>> nAssign = 1
    """
    # Assign each vector to the nearest visual words
    assert nAssign > 0, 'cannot assign to 0 neighbors'
    try:
        _idx2_wx, _idx2_wdist = self.wordflann.nn_index(idx2_vec, nAssign)
    except pyflann.FLANNException as ex:
        ut.printex(ex, 'probably misread the cached flann_fpath=%r' %
                   (self.wordflann.flann_fpath,))
        raise
    else:
        _idx2_wx = vt.atleast_nd(_idx2_wx, 2)
        _idx2_wdist = vt.atleast_nd(_idx2_wdist, 2)
        return _idx2_wx, _idx2_wdist
def augment_graph_mst(ibs, graph):
    import wbia.plottool as pt

    # spantree_aids1_ = []
    # spantree_aids2_ = []
    # Add edges between all names
    aid_list = list(graph.nodes())
    aug_digraph = graph.copy()
    # Change all weights in initial graph to be small (likely to be part of mst)
    nx.set_edge_attributes(aug_digraph, name='weight', values=0.0001)
    aids1, aids2 = get_name_rowid_edges_from_aids(ibs, aid_list)

    if False:
        # Weight edges in the MST based on tentative distances
        # Get tentative node positions
        initial_pos = pt.get_nx_layout(graph.to_undirected(), 'graphviz')['node_pos']
        # initial_pos = pt.get_nx_layout(graph.to_undirected(), 'agraph')['node_pos']
        edge_pts1 = ut.dict_take(initial_pos, aids1)
        edge_pts2 = ut.dict_take(initial_pos, aids2)
        edge_pts1 = vt.atleast_nd(np.array(edge_pts1, dtype=np.int32), 2)
        edge_pts2 = vt.atleast_nd(np.array(edge_pts2, dtype=np.int32), 2)
        edge_weights = vt.L2(edge_pts1, edge_pts2)
    else:
        edge_weights = [1.0] * len(aids1)

    # Create implicit fully connected (by name) graph
    aug_edges = [(a1, a2, {'weight': w})
                 for a1, a2, w in zip(aids1, aids2, edge_weights)]
    aug_digraph.add_edges_from(aug_edges)

    # Determine which edges need to be added to
    # make original graph connected by name
    aug_graph = aug_digraph.to_undirected()
    for cc_sub_graph in connected_component_subgraphs(aug_graph):
        mst_sub_graph = nx.minimum_spanning_tree(cc_sub_graph)
        mst_edges = mst_sub_graph.edges()
        for edge in mst_edges:
            redge = edge[::-1]
            # attr_dict = {'color': pt.DARK_ORANGE[0:3]}
            attr_dict = {'color': pt.BLACK[0:3]}
            if not (graph.has_edge(*edge) or graph.has_edge(*redge)):
                graph.add_edge(*redge, attr_dict=attr_dict)
def normalized_nearest_neighbors(flann, vecs2, K, checks=800):
    """
    Uses the flann index to return nearest neighbors with distances
    normalized between 0 and 1 using SIFT's uint8 trick
    """
    import vtool as vt
    if K == 0:
        (fx2_to_fx1, _fx2_to_dist_sqrd) = empty_neighbors(len(vecs2), 0)
    elif len(vecs2) == 0:
        (fx2_to_fx1, _fx2_to_dist_sqrd) = empty_neighbors(0, K)
    elif K > flann.get_indexed_shape()[0]:
        # Corner case, may be better to throw an assertion error
        raise MatchingError('not enough database features')
        # (fx2_to_fx1, _fx2_to_dist_sqrd) = empty_neighbors(len(vecs2), 0)
    else:
        fx2_to_fx1, _fx2_to_dist_sqrd = flann.nn_index(
            vecs2, num_neighbors=K, checks=checks)
    _fx2_to_dist = np.sqrt(_fx2_to_dist_sqrd.astype(np.float64))
    # normalized dist
    fx2_to_dist = np.divide(_fx2_to_dist, PSEUDO_MAX_DIST)
    fx2_to_fx1 = vt.atleast_nd(fx2_to_fx1, 2)
    fx2_to_dist = vt.atleast_nd(fx2_to_dist, 2)
    return fx2_to_fx1, fx2_to_dist
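# PSEUDO_MAX_DIST is not defined in this snippet. For SIFT descriptors stored as
# uint8 with components scaled to a pseudo-max of 512, two non-negative vectors of
# norm 512 can be at most sqrt(2) * 512 apart, so a plausible definition
# (an assumption consistent with the normalization above, not the verified vtool
# constant) is:
import numpy as np

PSEUDO_MAX_VEC_COMPONENT = 512
PSEUDO_MAX_DIST_SQRD = 2 * PSEUDO_MAX_VEC_COMPONENT ** 2
PSEUDO_MAX_DIST = np.sqrt(PSEUDO_MAX_DIST_SQRD)  # ~724.1, maps raw L2 distances into [0, 1]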
def augment_graph_mst(ibs, graph):
    import plottool as pt
    #spantree_aids1_ = []
    #spantree_aids2_ = []
    # Add edges between all names
    aid_list = list(graph.nodes())
    aug_digraph = graph.copy()
    # Change all weights in initial graph to be small (likely to be part of mst)
    nx.set_edge_attributes(aug_digraph, 'weight', .0001)
    aids1, aids2 = get_name_rowid_edges_from_aids(ibs, aid_list)
    if False:
        # Weight edges in the MST based on tentative distances
        # Get tentative node positions
        initial_pos = pt.get_nx_layout(graph.to_undirected(), 'graphviz')['node_pos']
        #initial_pos = pt.get_nx_layout(graph.to_undirected(), 'agraph')['node_pos']
        edge_pts1 = ut.dict_take(initial_pos, aids1)
        edge_pts2 = ut.dict_take(initial_pos, aids2)
        edge_pts1 = vt.atleast_nd(np.array(edge_pts1, dtype=np.int32), 2)
        edge_pts2 = vt.atleast_nd(np.array(edge_pts2, dtype=np.int32), 2)
        edge_weights = vt.L2(edge_pts1, edge_pts2)
    else:
        edge_weights = [1.0] * len(aids1)
    # Create implicit fully connected (by name) graph
    aug_edges = [(a1, a2, {'weight': w})
                 for a1, a2, w in zip(aids1, aids2, edge_weights)]
    aug_digraph.add_edges_from(aug_edges)
    # Determine which edges need to be added to
    # make original graph connected by name
    aug_graph = aug_digraph.to_undirected()
    for cc_sub_graph in nx.connected_component_subgraphs(aug_graph):
        mst_sub_graph = nx.minimum_spanning_tree(cc_sub_graph)
        mst_edges = mst_sub_graph.edges()
        for edge in mst_edges:
            redge = edge[::-1]
            #attr_dict = {'color': pt.DARK_ORANGE[0:3]}
            attr_dict = {'color': pt.BLACK[0:3]}
            if not (graph.has_edge(*edge) or graph.has_edge(*redge)):
                graph.add_edge(*redge, attr_dict=attr_dict)
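# nx.connected_component_subgraphs was deprecated in networkx 2.1 and removed in
# 2.4. If these snippets are run against a newer networkx, a small shim with the
# same behavior (a sketch, not part of the original code) can stand in for it:
import networkx as nx

def connected_component_subgraphs(graph):
    # Yield one subgraph per connected component, as the removed helper did.
    for nodes in nx.connected_components(graph):
        yield graph.subgraph(nodes)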
def _update_weights(model, thresh=None):
    int_factor = 100
    edge_weights = np.array(model.edge_weights)
    if thresh is None:
        thresh = model.estimate_threshold()
    else:
        if isinstance(thresh, six.string_types):
            thresh = model.estimate_threshold(method=thresh)
        #np.mean(edge_weights)
    weights = (edge_weights - thresh) * int_factor
    weights = weights.astype(np.int32)
    edges_ = np.array(model.edges).astype(np.int32)
    edges_ = vt.atleast_nd(edges_, 2)
    edges_.shape = (edges_.shape[0], 2)
    weighted_edges = np.vstack((edges_.T, weights)).T
    weighted_edges = np.ascontiguousarray(weighted_edges)
    # Update internals
    model.thresh = thresh
    model.weighted_edges = weighted_edges
    model.weights = weights
def _update_weights(model, thresh=None):
    int_factor = 1E2
    edge_weights = np.array(model.edge_weights)
    if thresh is None:
        thresh = model._estimate_threshold()
    else:
        if isinstance(thresh, six.string_types):
            thresh = model._estimate_threshold(method=thresh)
        #np.mean(edge_weights)
    if True:
        # Center and scale weights between -1 and 1
        centered = (edge_weights - thresh)
        centered[centered < 0] = (centered[centered < 0] / thresh)
        centered[centered > 0] = (centered[centered > 0] / (1 - thresh))
        newprob = (centered + 1) / 2
        newprob[np.isnan(newprob)] = .5
        # Apply logit rule
        # prevent infinity
        #pad = 1 / (int_factor * 2)
        pad = 1E-6
        perbprob = (newprob * (1.0 - pad * 2)) + pad
        weights = vt.logit(perbprob)
    else:
        weights = (edge_weights - thresh)
    # Conv
    weights[np.isnan(edge_weights)] = 0
    weights = (weights * int_factor).astype(np.int32)
    edges_ = np.round(model.edges).astype(np.int32)
    edges_ = vt.atleast_nd(edges_, 2)
    edges_.shape = (edges_.shape[0], 2)
    weighted_edges = np.vstack((edges_.T, weights)).T
    weighted_edges = np.ascontiguousarray(weighted_edges)
    weighted_edges = np.nan_to_num(weighted_edges)
    # Remove edges with 0 weight as they have no influence
    weighted_edges = weighted_edges.compress(weighted_edges.T[2] != 0, axis=0)
    # Update internals
    model.thresh = thresh
    model.weighted_edges = weighted_edges
    model.weights = weights
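# vt.logit is presumably the standard log-odds function. A numpy-only sketch of
# the center/scale/logit transform used above, on made-up edge probabilities
# (hypothetical values; logit assumed to be log(p / (1 - p))):
import numpy as np

def logit(p):
    return np.log(p / (1.0 - p))

edge_weights = np.array([0.1, 0.45, 0.5, 0.9])  # hypothetical edge probabilities
thresh = 0.6
centered = edge_weights - thresh
centered[centered < 0] /= thresh            # scale the negative side into [-1, 0)
centered[centered > 0] /= (1.0 - thresh)    # scale the positive side into (0, 1]
newprob = (centered + 1) / 2                # back into [0, 1]
pad = 1e-6                                  # keep probabilities away from 0 and 1
perbprob = newprob * (1.0 - 2 * pad) + pad
int_weights = np.round(logit(perbprob) * 100).astype(np.int32)
print(int_weights)  # strongly negative below thresh, strongly positive above it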
def run_asmk_script():
    with ut.embed_on_exception_context:  # NOQA
        """
        >>> from wbia.algo.smk.script_smk import *
        """  # NOQA
        # ==============================================
        # PREPROCESSING CONFIGURATION
        # ==============================================
        config = {
            # 'data_year': 2013,
            'data_year': None,
            'dtype': 'float32',
            # 'root_sift': True,
            'root_sift': False,
            # 'centering': True,
            'centering': False,
            'num_words': 2**16,
            # 'num_words': 1E6
            # 'num_words': 8000,
            'kmeans_impl': 'sklearn.mini',
            'extern_words': False,
            'extern_assign': False,
            'assign_algo': 'kdtree',
            'checks': 1024,
            'int_rvec': True,
            'only_xy': False,
        }
        # Define which params are relevant for which operations
        relevance = {}
        relevance['feats'] = ['dtype', 'root_sift', 'centering', 'data_year']
        relevance['words'] = relevance['feats'] + [
            'num_words',
            'extern_words',
            'kmeans_impl',
        ]
        relevance['assign'] = relevance['words'] + [
            'checks',
            'extern_assign',
            'assign_algo',
        ]
        # relevance['ydata'] = relevance['assign'] + ['int_rvec']
        # relevance['xdata'] = relevance['assign'] + ['only_xy', 'int_rvec']

        nAssign = 1

        class SMKCacher(ut.Cacher):
            def __init__(self, fname, ext='.cPkl'):
                relevant_params = relevance[fname]
                relevant_cfg = ut.dict_subset(config, relevant_params)
                cfgstr = ut.get_cfg_lbl(relevant_cfg)
                dbdir = ut.truepath('/raid/work/Oxford/')
                super(SMKCacher, self).__init__(fname, cfgstr, cache_dir=dbdir, ext=ext)

        # ==============================================
        # LOAD DATASET, EXTRACT AND POSTPROCESS FEATURES
        # ==============================================
        if config['data_year'] == 2007:
            data = load_oxford_2007()
        elif config['data_year'] == 2013:
            data = load_oxford_2013()
        elif config['data_year'] is None:
            data = load_oxford_wbia()

        offset_list = data['offset_list']
        all_kpts = data['all_kpts']
        raw_vecs = data['all_vecs']
        query_uri_order = data['query_uri_order']
        data_uri_order = data['data_uri_order']
        # del data

        # ================
        # PRE-PROCESS
        # ================
        import vtool as vt

        # Alias names to avoid errors in interactive sessions
        proc_vecs = raw_vecs
        del raw_vecs

        feats_cacher = SMKCacher('feats', ext='.npy')
        all_vecs = feats_cacher.tryload()
        if all_vecs is None:
            if config['dtype'] == 'float32':
                logger.info('Converting vecs to float32')
                proc_vecs = proc_vecs.astype(np.float32)
            else:
                proc_vecs = proc_vecs
                raise NotImplementedError('other dtype')

            if config['root_sift']:
                with ut.Timer('Apply root sift'):
                    np.sqrt(proc_vecs, out=proc_vecs)
                    vt.normalize(proc_vecs, ord=2, axis=1, out=proc_vecs)

            if config['centering']:
                with ut.Timer('Apply centering'):
                    mean_vec = np.mean(proc_vecs, axis=0)
                    # Center and then re-normalize
                    np.subtract(proc_vecs, mean_vec[None, :], out=proc_vecs)
                    vt.normalize(proc_vecs, ord=2, axis=1, out=proc_vecs)

            if config['dtype'] == 'int8':
                smk_funcs

            all_vecs = proc_vecs
            feats_cacher.save(all_vecs)
        del proc_vecs

        # =====================================
        # BUILD VISUAL VOCABULARY
        # =====================================
        if config['extern_words']:
            words = data['words']
            assert config['num_words'] is None or len(words) == config['num_words']
        else:
            word_cacher = SMKCacher('words')
            words = word_cacher.tryload()
            if words is None:
                with ut.embed_on_exception_context:
                    if config['kmeans_impl'] == 'sklearn.mini':
                        import sklearn.cluster

                        rng = np.random.RandomState(13421421)
                        # init_size = int(config['num_words'] * 8)
                        init_size = int(config['num_words'] * 4)
                        # converged after 26043 iterations
                        clusterer = sklearn.cluster.MiniBatchKMeans(
                            config['num_words'],
                            init_size=init_size,
                            batch_size=1000,
                            compute_labels=False,
                            max_iter=20,
                            random_state=rng,
                            n_init=1,
                            verbose=1,
                        )
                        clusterer.fit(all_vecs)
                        words = clusterer.cluster_centers_
                    elif config['kmeans_impl'] == 'yael':
                        from yael import ynumpy

                        centroids, qerr, dis, assign, nassign = ynumpy.kmeans(
                            all_vecs,
                            config['num_words'],
                            init='kmeans++',
                            verbose=True,
                            output='all',
                        )
                        words = centroids
                word_cacher.save(words)

        # =====================================
        # ASSIGN EACH VECTOR TO ITS NEAREST WORD
        # =====================================
        if config['extern_assign']:
            assert config['extern_words'], 'need extern cluster to extern assign'
            idx_to_wxs = vt.atleast_nd(data['idx_to_wx'], 2)
            idx_to_maws = np.ones(idx_to_wxs.shape, dtype=np.float32)
            idx_to_wxs = np.ma.array(idx_to_wxs)
            idx_to_maws = np.ma.array(idx_to_maws)
        else:
            from wbia.algo.smk import vocab_indexer

            vocab = vocab_indexer.VisualVocab(words)
            dassign_cacher = SMKCacher('assign')
            assign_tup = dassign_cacher.tryload()
            if assign_tup is None:
                vocab.flann_params['algorithm'] = config['assign_algo']
                vocab.build()
                # Takes 12 minutes to assign jegous vecs to 2**16 vocab
                with ut.Timer('assign vocab neighbors'):
                    _idx_to_wx, _idx_to_wdist = vocab.nn_index(
                        all_vecs, nAssign, checks=config['checks'])
                    if nAssign > 1:
                        idx_to_wxs, idx_to_maws = smk_funcs.weight_multi_assigns(
                            _idx_to_wx,
                            _idx_to_wdist,
                            massign_alpha=1.2,
                            massign_sigma=80.0,
                            massign_equal_weights=True,
                        )
                    else:
                        idx_to_wxs = np.ma.masked_array(_idx_to_wx, fill_value=-1)
                        idx_to_maws = np.ma.ones(
                            idx_to_wxs.shape, fill_value=-1, dtype=np.float32)
                        idx_to_maws.mask = idx_to_wxs.mask
                assign_tup = (idx_to_wxs, idx_to_maws)
                dassign_cacher.save(assign_tup)
            idx_to_wxs, idx_to_maws = assign_tup

        # Breakup vectors, keypoints, and word assignments by annotation
        wx_lists = [idx_to_wxs[left:right] for left, right in ut.itertwo(offset_list)]
        maw_lists = [idx_to_maws[left:right] for left, right in ut.itertwo(offset_list)]
        vecs_list = [all_vecs[left:right] for left, right in ut.itertwo(offset_list)]
        kpts_list = [all_kpts[left:right] for left, right in ut.itertwo(offset_list)]

        # =======================
        # FIND QUERY SUBREGIONS
        # =======================
        ibs, query_annots, data_annots, qx_to_dx = load_ordered_annots(
            data_uri_order, query_uri_order)
        daids = data_annots.aids
        qaids = query_annots.aids

        query_super_kpts = ut.take(kpts_list, qx_to_dx)
        query_super_vecs = ut.take(vecs_list, qx_to_dx)
        query_super_wxs = ut.take(wx_lists, qx_to_dx)
        query_super_maws = ut.take(maw_lists, qx_to_dx)
        # Mark which keypoints are within the bbox of the query
        query_flags_list = []
        only_xy = config['only_xy']
        for kpts_, bbox in zip(query_super_kpts, query_annots.bboxes):
            flags = kpts_inside_bbox(kpts_, bbox, only_xy=only_xy)
            query_flags_list.append(flags)

        logger.info('Queries are crops of existing database images.')
        logger.info('Looking at average percents')
        percent_list = [flags_.sum() / flags_.shape[0] for flags_ in query_flags_list]
        percent_stats = ut.get_stats(percent_list)
        logger.info('percent_stats = %s' % (ut.repr4(percent_stats),))

        import vtool as vt

        query_kpts = vt.zipcompress(query_super_kpts, query_flags_list, axis=0)
        query_vecs = vt.zipcompress(query_super_vecs, query_flags_list, axis=0)
        query_wxs = vt.zipcompress(query_super_wxs, query_flags_list, axis=0)
        query_maws = vt.zipcompress(query_super_maws, query_flags_list, axis=0)

        # =======================
        # CONSTRUCT QUERY / DATABASE REPR
        # =======================
        # int_rvec = not config['dtype'].startswith('float')
        int_rvec = config['int_rvec']

        X_list = []
        _prog = ut.ProgPartial(length=len(qaids), label='new X', bs=True, adjust=True)
        for aid, fx_to_wxs, fx_to_maws in _prog(zip(qaids, query_wxs, query_maws)):
            X = new_external_annot(aid, fx_to_wxs, fx_to_maws, int_rvec)
            X_list.append(X)

        # ydata_cacher = SMKCacher('ydata')
        # Y_list = ydata_cacher.tryload()
        # if Y_list is None:
        Y_list = []
        _prog = ut.ProgPartial(length=len(daids), label='new Y', bs=True, adjust=True)
        for aid, fx_to_wxs, fx_to_maws in _prog(zip(daids, wx_lists, maw_lists)):
            Y = new_external_annot(aid, fx_to_wxs, fx_to_maws, int_rvec)
            Y_list.append(Y)
        # ydata_cacher.save(Y_list)

        # ======================
        # Add in some groundtruth
        logger.info('Add in some groundtruth')
        for Y, nid in zip(Y_list, ibs.get_annot_nids(daids)):
            Y.nid = nid
        for X, nid in zip(X_list, ibs.get_annot_nids(qaids)):
            X.nid = nid
        for Y, qual in zip(Y_list, ibs.get_annot_quality_texts(daids)):
            Y.qual = qual

        # ======================
        # Add in other properties
        for Y, vecs, kpts in zip(Y_list, vecs_list, kpts_list):
            Y.vecs = vecs
            Y.kpts = kpts

        imgdir = ut.truepath('/raid/work/Oxford/oxbuild_images')
        for Y, imgid in zip(Y_list, data_uri_order):
            gpath = ut.unixjoin(imgdir, imgid + '.jpg')
            Y.gpath = gpath

        for X, vecs, kpts in zip(X_list, query_vecs, query_kpts):
            X.kpts = kpts
            X.vecs = vecs

        # ======================
        logger.info('Building inverted list')
        daids = [Y.aid for Y in Y_list]
        # wx_list = sorted(ut.list_union(*[Y.wx_list for Y in Y_list]))
        wx_list = sorted(set.union(*[Y.wx_set for Y in Y_list]))
        assert daids == data_annots.aids
        assert len(wx_list) <= config['num_words']

        wx_to_aids = smk_funcs.invert_lists(
            daids, [Y.wx_list for Y in Y_list], all_wxs=wx_list)

        # Compute IDF weights
        logger.info('Compute IDF weights')
        ndocs_total = len(daids)
        # Use only the unique number of words
        ndocs_per_word = np.array([len(set(wx_to_aids[wx])) for wx in wx_list])
        logger.info('ndocs_perword stats: ' + ut.repr4(ut.get_stats(ndocs_per_word)))
        idf_per_word = smk_funcs.inv_doc_freq(ndocs_total, ndocs_per_word)
        wx_to_weight = dict(zip(wx_list, idf_per_word))
        logger.info('idf stats: ' + ut.repr4(ut.get_stats(wx_to_weight.values())))

        # Filter junk
        Y_list_ = [Y for Y in Y_list if Y.qual != 'junk']

        # =======================
        # CHOOSE QUERY KERNEL
        # =======================
        params = {
            'asmk': dict(alpha=3.0, thresh=0.0),
            'bow': dict(),
            'bow2': dict(),
        }
        # method = 'bow'
        method = 'bow2'
        method = 'asmk'
        smk = SMK(wx_to_weight, method=method, **params[method])

        # Specific info for the type of query
        if method == 'asmk':
            # Make residual vectors
            if True:
                # The stacked way is 50x faster
                # TODO: extend for multi-assignment and record fxs
                flat_query_vecs = np.vstack(query_vecs)
                flat_query_wxs = np.vstack(query_wxs)
                flat_query_offsets = np.array([0] + ut.cumsum(ut.lmap(len, query_wxs)))

                flat_wxs_assign = flat_query_wxs
                flat_offsets = flat_query_offsets
                flat_vecs = flat_query_vecs
                tup = smk_funcs.compute_stacked_agg_rvecs(
                    words, flat_wxs_assign, flat_vecs, flat_offsets)
                all_agg_vecs, all_error_flags, agg_offset_list = tup
                if int_rvec:
                    all_agg_vecs = smk_funcs.cast_residual_integer(all_agg_vecs)
                agg_rvecs_list = [
                    all_agg_vecs[left:right]
                    for left, right in ut.itertwo(agg_offset_list)
                ]
                agg_flags_list = [
                    all_error_flags[left:right]
                    for left, right in ut.itertwo(agg_offset_list)
                ]

                for X, agg_rvecs, agg_flags in zip(X_list, agg_rvecs_list, agg_flags_list):
                    X.agg_rvecs = agg_rvecs
                    X.agg_flags = agg_flags[:, None]

                flat_wxs_assign = idx_to_wxs
                flat_offsets = offset_list
                flat_vecs = all_vecs
                tup = smk_funcs.compute_stacked_agg_rvecs(
                    words, flat_wxs_assign, flat_vecs, flat_offsets)
                all_agg_vecs, all_error_flags, agg_offset_list = tup
                if int_rvec:
                    all_agg_vecs = smk_funcs.cast_residual_integer(all_agg_vecs)

                agg_rvecs_list = [
                    all_agg_vecs[left:right]
                    for left, right in ut.itertwo(agg_offset_list)
                ]
                agg_flags_list = [
                    all_error_flags[left:right]
                    for left, right in ut.itertwo(agg_offset_list)
                ]

                for Y, agg_rvecs, agg_flags in zip(Y_list, agg_rvecs_list, agg_flags_list):
                    Y.agg_rvecs = agg_rvecs
                    Y.agg_flags = agg_flags[:, None]
            else:
                # This non-stacked way is about 500x slower
                _prog = ut.ProgPartial(label='agg Y rvecs', bs=True, adjust=True)
                for Y in _prog(Y_list_):
                    make_agg_vecs(Y, words, Y.vecs)

                _prog = ut.ProgPartial(label='agg X rvecs', bs=True, adjust=True)
                for X in _prog(X_list):
                    make_agg_vecs(X, words, X.vecs)
        elif method == 'bow2':
            # Hack for orig tf-idf bow vector
            nwords = len(words)
            for X in ut.ProgIter(X_list, label='make bow vector'):
                ensure_tf(X)
                bow_vector(X, wx_to_weight, nwords)

            for Y in ut.ProgIter(Y_list_, label='make bow vector'):
                ensure_tf(Y)
                bow_vector(Y, wx_to_weight, nwords)

        if method != 'bow2':
            for X in ut.ProgIter(X_list, 'compute X gamma'):
                X.gamma = smk.gamma(X)
            for Y in ut.ProgIter(Y_list_, 'compute Y gamma'):
                Y.gamma = smk.gamma(Y)

        # Execute matches (could go faster by enumerating candidates)
        scores_list = []
        for X in ut.ProgIter(X_list, label='query %s' % (smk,)):
            scores = [smk.kernel(X, Y) for Y in Y_list_]
            scores = np.array(scores)
            scores = np.nan_to_num(scores)
            scores_list.append(scores)

        import sklearn.metrics

        avep_list = []
        _iter = list(zip(scores_list, X_list))
        _iter = ut.ProgIter(_iter, label='evaluate %s' % (smk,))
        for scores, X in _iter:
            truth = [X.nid == Y.nid for Y in Y_list_]
            avep = sklearn.metrics.average_precision_score(truth, scores)
            avep_list.append(avep)
        avep_list = np.array(avep_list)
        mAP = np.mean(avep_list)
        logger.info('mAP = %r' % (mAP,))
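# The evaluation loop at the end of run_asmk_script reduces to per-query average
# precision followed by a mean. A tiny self-contained illustration with made-up
# scores (hypothetical values, not Oxford results):
import numpy as np
import sklearn.metrics

scores_list = [np.array([0.9, 0.2, 0.4]), np.array([0.1, 0.8, 0.3])]  # one row of database scores per query
truth_list = [[True, False, True], [False, True, False]]              # same-name ground truth per query
avep_list = [
    sklearn.metrics.average_precision_score(truth, scores)
    for truth, scores in zip(truth_list, scores_list)
]
mAP = float(np.mean(avep_list))
print('mAP = %r' % (mAP,))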
def conditional_knn(nnindexer, qfx2_vec, num_neighbors, invalid_axs):
    """
    >>> from ibeis.algo.hots.neighbor_index import *  # NOQA
    >>> qreq_ = ibeis.testdata_qreq_(defaultdb='seaturtles')
    >>> qreq_.load_indexer()
    >>> qfx2_vec = qreq_.ibs.get_annot_vecs(qreq_.qaids[0])
    >>> num_neighbors = 2
    >>> nnindexer = qreq_.indexer
    >>> ibs = qreq_.ibs
    >>> qaid = 1
    >>> qencid = ibs.get_annot_encounter_text([qaid])[0]
    >>> ax2_encid = np.array(ibs.get_annot_encounter_text(nnindexer.ax2_aid))
    >>> invalid_axs = np.where(ax2_encid == qencid)[0]
    """
    #import ibeis
    import itertools

    def in1d_shape(arr1, arr2):
        return np.in1d(arr1, arr2).reshape(arr1.shape)

    get_neighbors = ut.partial(nnindexer.flann.nn_index,
                               checks=nnindexer.checks,
                               cores=nnindexer.cores)

    # Alloc space for final results
    K = num_neighbors
    shape = (len(qfx2_vec), K)
    qfx2_idx = np.full(shape, -1, dtype=np.int32)
    qfx2_rawdist = np.full(shape, np.nan, dtype=np.float64)
    qfx2_truek = np.full(shape, -1, dtype=np.int32)

    # Make a set of temporary indexes and loop variables
    limit = None
    limit = 4
    K_ = K
    tx2_qfx = np.arange(len(qfx2_vec))
    tx2_vec = qfx2_vec
    iter_count = 0
    for iter_count in itertools.count():
        if limit is not None and iter_count >= limit:
            break
        # Find a set of neighbors
        (tx2_idx, tx2_rawdist) = get_neighbors(tx2_vec, K_)
        tx2_idx = vt.atleast_nd(tx2_idx, 2)
        tx2_rawdist = vt.atleast_nd(tx2_rawdist, 2)
        tx2_ax = nnindexer.get_nn_axs(tx2_idx)
        # Check to see if they meet the criteria
        tx2_invalid = in1d_shape(tx2_ax, invalid_axs)
        tx2_valid = np.logical_not(tx2_invalid)
        tx2_num_valid = tx2_valid.sum(axis=1)
        tx2_notdone = tx2_num_valid < K
        tx2_done = np.logical_not(tx2_notdone)

        # Move completely valid queries into the results
        if np.any(tx2_done):
            done_qfx = tx2_qfx.compress(tx2_done, axis=0)
            # Need to parse which columns are the completed ones
            done_valid_ = tx2_valid.compress(tx2_done, axis=0)
            done_rawdist_ = tx2_rawdist.compress(tx2_done, axis=0)
            done_idx_ = tx2_idx.compress(tx2_done, axis=0)
            # Get the complete valid indices
            rowxs, colxs = np.where(done_valid_)
            unique_rows, groupxs = vt.group_indices(rowxs)
            first_k_groupxs = [groupx[0:K] for groupx in groupxs]
            chosen_xs = np.hstack(first_k_groupxs)
            multi_index = (rowxs.take(chosen_xs), colxs.take(chosen_xs))
            flat_xs = np.ravel_multi_index(multi_index, done_valid_.shape)
            done_rawdist = done_rawdist_.take(flat_xs).reshape((-1, K))
            done_idx = done_idx_.take(flat_xs).reshape((-1, K))
            # Write done results in output
            qfx2_idx[done_qfx, :] = done_idx
            qfx2_rawdist[done_qfx, :] = done_rawdist
            qfx2_truek[done_qfx, :] = vt.apply_grouping(colxs, first_k_groupxs)
        if np.all(tx2_done):
            break
        K_increase = (K - tx2_num_valid.min())
        K_ += K_increase
        tx2_qfx = tx2_qfx.compress(tx2_notdone, axis=0)
        tx2_vec = tx2_vec.compress(tx2_notdone, axis=0)

    if nnindexer.max_distance_sqrd is not None:
        qfx2_dist = np.divide(qfx2_rawdist, nnindexer.max_distance_sqrd)
    else:
        qfx2_dist = qfx2_rawdist
    return (qfx2_idx, qfx2_dist, iter_count)
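# Inside the loop above, the grouping logic picks, for each completed row, the
# column indices of its first K valid neighbors. A numpy-only sketch of that
# selection on toy data (hypothetical arrays, without the vtool grouping helpers):
import numpy as np

K = 2
valid = np.array([[True, False, True, True],
                  [True, True, False, True]])
rowxs, colxs = np.where(valid)
# For each row, keep the column indices of its first K valid entries
first_k_cols = [colxs[rowxs == r][:K] for r in np.unique(rowxs)]
print(np.array(first_k_cols))  # [[0 2] [0 1]]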