def main(cls, cmdline=True, **kw):
    """
    Example:
        >>> kw = {'src': 'special:shapes8',
        >>>       'dst1': 'train.json', 'dst2': 'test.json'}
        >>> cmdline = False
        >>> cls = CocoSplitCLI
        >>> cls.main(cmdline, **kw)
    """
    import kwcoco
    import kwarray
    from kwcoco.util import util_sklearn

    config = cls.CLIConfig(kw, cmdline=cmdline)
    print('config = {}'.format(ub.repr2(dict(config), nl=1)))

    if config['src'] is None:
        raise Exception('must specify source: {}'.format(config['src']))

    print('reading fpath = {!r}'.format(config['src']))
    dset = kwcoco.CocoDataset.coerce(config['src'])
    annots = dset.annots()
    gids = annots.gids
    cids = annots.cids

    # Balanced category split
    rng = kwarray.ensure_rng(config['rng'])
    shuffle = rng is not None
    self = util_sklearn.StratifiedGroupKFold(
        n_splits=config['factor'], random_state=rng, shuffle=shuffle)
    split_idxs = list(self.split(X=gids, y=cids, groups=gids))
    idxs1, idxs2 = split_idxs[0]

    gids1 = sorted(ub.unique(ub.take(gids, idxs1)))
    gids2 = sorted(ub.unique(ub.take(gids, idxs2)))

    dset1 = dset.subset(gids1)
    dset2 = dset.subset(gids2)

    dset1.fpath = config['dst1']
    print('Writing dset1 = {!r}'.format(dset1.fpath))
    dset1.dump(dset1.fpath, newlines=True)

    dset2.fpath = config['dst2']
    print('Writing dset2 = {!r}'.format(dset2.fpath))
    dset2.dump(dset2.fpath, newlines=True)
def _cm_breaking(infr, cm_list=None, review_cfg={}):
    """
    >>> review_cfg = {}
    """
    if cm_list is None:
        cm_list = infr.cm_list
    ranks_top = review_cfg.get('ranks_top', None)
    ranks_bot = review_cfg.get('ranks_bot', None)

    # Construct K-broken graph
    edges = []

    if ranks_bot is None:
        ranks_bot = 0

    for count, cm in enumerate(cm_list):
        score_list = cm.annot_score_list
        rank_list = ub.argsort(score_list)[::-1]
        sortx = ub.argsort(rank_list)

        top_sortx = sortx[:ranks_top]
        bot_sortx = sortx[len(sortx) - ranks_bot:]
        short_sortx = list(ub.unique(top_sortx + bot_sortx))

        daid_list = list(ub.take(cm.daid_list, short_sortx))
        for daid in daid_list:
            u, v = (cm.qaid, daid)
            if v < u:
                u, v = v, u
            edges.append((u, v))
    return edges
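# A quick standalone note (toy data, not from the original source) on the
# double-argsort idiom used above: argsort of a permutation yields its
# inverse, so `sortx[i]` recovers where element `i` landed in the ranking.
import ubelt as ub
scores = [0.2, 0.9, 0.5]
rank_list = ub.argsort(scores)[::-1]  # [1, 2, 0]: indices, best score first
sortx = ub.argsort(rank_list)         # [2, 0, 1]: the rank of each index
print(rank_list, sortx)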
def paths(self, cwd=None, recursive=False):
    groups = (p.paths(cwd=cwd, recursive=recursive) for p in self.patterns)
    if self.predicate in {any}:
        # Union of the matches from each pattern
        yield from ub.unique(ub.flatten(groups))
    elif self.predicate in {all}:
        # Intersection of the matches from each pattern
        yield from set.intersection(*map(set, groups))
    else:
        raise NotImplementedError
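# A minimal sketch (not part of the original source) illustrating the
# predicate semantics above, with plain lists standing in for the
# per-pattern path groups: `any` unions the groups, `all` intersects them.
import ubelt as ub
groups = [['a.py', 'b.py'], ['b.py', 'c.py']]
union = list(ub.unique(ub.flatten(groups)))          # ['a.py', 'b.py', 'c.py']
intersection = set.intersection(*map(set, groups))   # {'b.py'}
print(union, intersection)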
def make_data(num_items, num_other, remove_fraction, keytype):
    if keytype == 'str':
        keytype = str
    if keytype == 'int':
        keytype = int
    first_keys = [random.randint(0, 1000) for _ in range(num_items)]
    k = int(remove_fraction * len(first_keys))
    remove_sets = [
        list(ub.unique(random.choices(first_keys, k=k) +
                       [random.randint(0, 1000) for _ in range(num_items)]))
        for _ in range(num_other)
    ]
    first_dict = {keytype(k): k for k in first_keys}
    args = [first_dict] + [{keytype(k): k for k in ks} for ks in remove_sets]
    return args
def _balance_report(self, limit=None):
    # Print the epoch / item label frequency per epoch
    label_sequence = []
    index_sequence = []
    if limit is None:
        limit = self.num_batches
    for item_indices, _ in zip(self, range(limit)):
        item_indices = np.array(item_indices)
        item_labels = list(ub.take(self.index_to_label, item_indices))
        index_sequence.extend(item_indices)
        label_sequence.extend(ub.unique(item_labels))
    label_hist = ub.dict_hist(label_sequence)
    index_hist = ub.dict_hist(index_sequence)
    label_hist = ub.sorted_vals(label_hist, reverse=True)
    index_hist = ub.sorted_vals(index_hist, reverse=True)
    index_hist = ub.dict_subset(index_hist, list(index_hist.keys())[0:5])
    print('label_hist = {}'.format(ub.repr2(label_hist, nl=1)))
    print('index_hist = {}'.format(ub.repr2(index_hist, nl=1)))
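# A small standalone sketch (assumed toy data, not from the original source)
# of the histogram pattern used above: count occurrences, order the counters
# by frequency, then keep only the most frequent entries.
import ubelt as ub
labels = ['cat', 'dog', 'cat', 'cat', 'bird']
hist = ub.dict_hist(labels)                        # {'cat': 3, 'dog': 1, 'bird': 1}
hist = ub.sorted_vals(hist, reverse=True)          # most frequent first
top = ub.dict_subset(hist, list(hist.keys())[:2])  # keep the top two
print(top)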
def variant():
    import random
    import ubelt as ub
    num_items = 100
    num_other = 1
    first_keys = [random.randint(0, 1000) for _ in range(num_items)]
    remove_sets = [
        list(ub.unique(random.choices(first_keys, k=10) +
                       [random.randint(0, 1000) for _ in range(num_items)]))
        for _ in range(num_other)
    ]
    first_dict = {k: k for k in first_keys}
    args = [first_dict] + [{k: k for k in ks} for ks in remove_sets]
    dictclass = dict

    import timerit
    ti = timerit.Timerit(100, bestof=10, verbose=2)

    for timer in ti.reset('orig'):
        with timer:
            keys = set(first_dict)
            keys.difference_update(*map(set, args[1:]))
            new0 = dictclass((k, first_dict[k]) for k in keys)

    for timer in ti.reset('alt1'):
        with timer:
            remove_keys = {k for ks in args[1:] for k in ks}
            new1 = dictclass((k, v) for k, v in first_dict.items()
                             if k not in remove_keys)

    for timer in ti.reset('alt2'):
        with timer:
            remove_keys = set.union(*map(set, args[1:]))
            new2 = dictclass((k, v) for k, v in first_dict.items()
                             if k not in remove_keys)

    for timer in ti.reset('alt3'):
        with timer:
            remove_keys = set.union(*map(set, args[1:]))
            new3 = dictclass((k, first_dict[k]) for k in first_dict.keys()
                             if k not in remove_keys)

    # Cannot use until Python 3.6 support is dropped (it is faster)
    for timer in ti.reset('alt4'):
        with timer:
            remove_keys = set.union(*map(set, args[1:]))
            new4 = {k: v for k, v in first_dict.items()
                    if k not in remove_keys}

    assert new1 == new0
    assert new2 == new0
    assert new3 == new0
    assert new4 == new0
def fix_path(r):
    """ Removes duplicates from the PATH variable """
    PATH_SEP = os.path.pathsep
    pathstr = robos.get_env_var('PATH')
    import ubelt as ub
    pathlist = list(ub.unique(pathstr.split(PATH_SEP)))
    new_path = ''
    failed_bit = False
    for p in pathlist:
        if os.path.exists(p):
            new_path = new_path + p + PATH_SEP
        elif p == '':
            pass
        elif p.find('%') > -1 or p.find('$') > -1:
            print('PATH=%s has an envvar. Not checking existence' % p)
            new_path = new_path + p + PATH_SEP
        else:
            print('PATH=%s does not exist!!' % p)
            failed_bit = True
    if failed_bit:
        ans = input('Should I overwrite the path? yes/no?')
        if ans == 'yes':
            failed_bit = False
    # Remove the trailing path separator
    if len(new_path) > 0 and new_path[-1] == PATH_SEP:
        new_path = new_path[0:-1]
    if failed_bit:
        print("Path FIXING Failed. A good path would be: \n%s" % new_path)
        print("\n\n====\n\n The old path was:\n%s" % pathstr)
    elif pathstr == new_path:
        print("The path was already clean")
    else:
        robos.set_env_var('PATH', new_path)
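# A tiny illustrative example (hypothetical path entries): ub.unique is an
# order-preserving dedup, so the first occurrence of each PATH entry wins.
import ubelt as ub
pathstr = '/usr/bin:/usr/local/bin:/usr/bin'
print(list(ub.unique(pathstr.split(':'))))  # ['/usr/bin', '/usr/local/bin']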
def _squash_between(repo, start, stop, dry=False, verbose=True):
    """
    Inplace squash between <start> and <stop>. To use this directly from
    the command line, use the external function that sets up temporary
    branches.
    """
    assert len(start.parents) == 1
    assert start.authored_datetime < stop.authored_datetime
    assert repo.is_ancestor(ancestor_rev=start, rev=stop)

    # Use RFC 2822 timestamps
    # ISO_8601 = '%Y-%m-%d %H:%M:%S %z'  # NOQA
    # ts_start = start.authored_datetime.strftime(ISO_8601)
    # ts_stop = stop.authored_datetime.strftime(ISO_8601)
    ts_start = email.utils.format_datetime(start.authored_datetime)
    ts_stop = email.utils.format_datetime(stop.authored_datetime)
    # if ts_start.split()[0:4] == ts_stop.split()[0:4]:
    if start.authored_datetime.date() == stop.authored_datetime.date():
        ts_stop_short = ' '.join(ts_stop.split()[4:])
    else:
        ts_stop_short = ts_stop

    # Construct a new message
    commits = commits_between(repo, start, stop)
    messages = [commit.message for commit in commits]
    # messages = [commit.message for commit in streak._streak]
    unique_messages = ub.unique(messages)
    summary = '\n'.join(unique_messages)
    if summary == 'wip\n':
        summary = summary.strip('\n')
    new_msg = '{} - Squashed {} commits from <{}> to <{}>\n'.format(
        summary, len(commits), ts_start, ts_stop_short)

    if verbose:
        print(' * Creating new commit with message:')
        print(new_msg)

    old_head = repo.commit('HEAD')
    assert (stop == old_head or
            repo.is_ancestor(ancestor_rev=stop, rev=old_head))

    if not dry:
        # ------------------
        # MODIFICATION LOGIC
        # ------------------
        # Go back in time to the sequence stopping point
        repo.git.reset(stop.hexsha, hard=True)

        # Undo commits from start to stop by softly resetting to just
        # before the start
        before_start = start.parents[0]
        if verbose:
            print(' * resetting to before <start>')
        repo.git.reset(before_start.hexsha, soft=True)

        # Commit the changes in a new squashed commit and preserve the
        # authored date
        if verbose:
            print(' * creating one commit with all modifications up to <stop>')
        repo.index.commit(new_msg, author_date=ts_stop)

        # If <stop> was not the most recent commit, we need to graft the
        # newer commits back on
        if stop != old_head:
            # Copy commits following the end of the streak in front of our
            # new commit
            if verbose:
                print(' * fixing up the head')
            try:
                # above_commits = commits_between(repo, stop, old_head)
                # print('above_commits = {}'.format(ub.repr2(above_commits, si=True)))
                above = stop.hexsha + '..' + old_head.hexsha
                # above = streak.child.hexsha + '..' + old_head
                repo.git.cherry_pick(above, allow_empty=True)
            except git.GitCommandError:
                print('ERROR: need to roll back')
                raise
        else:
            if verbose:
                print(' * already at the head, no need to fix')
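# `commits_between` is defined elsewhere in this module; a minimal sketch of
# what it plausibly does with the GitPython API (an assumption, not the
# verbatim implementation): list the commits reachable from <stop> but not
# from <start>, oldest first.
import git

def commits_between_sketch(repo, start, stop):
    rev_range = '{}..{}'.format(start.hexsha, stop.hexsha)
    # iter_commits yields newest-first; reverse for chronological order
    return list(repo.iter_commits(rev_range))[::-1]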
def _split_train_vali_test(coco_dset, factor=3):
    """
    Args:
        factor (int): number of pieces to divide images into

    CommandLine:
        xdoctest -m /home/joncrall/code/ndsampler/ndsampler/coerce_data.py _split_train_vali_test

    Example:
        >>> from ndsampler.coerce_data import _split_train_vali_test
        >>> import kwcoco
        >>> coco_dset = kwcoco.CocoDataset.demo('shapes8')
        >>> split_gids = _split_train_vali_test(coco_dset)
        >>> print('split_gids = {}'.format(ub.repr2(split_gids, nl=1)))
    """
    import kwarray
    images = coco_dset.images()

    def _stratified_split(gids, cids, n_splits=2, rng=None):
        """ helper to split while trying to maintain class balance within
        images """
        rng = kwarray.ensure_rng(rng)
        from ndsampler.utils import util_sklearn
        selector = util_sklearn.StratifiedGroupKFold(
            n_splits=n_splits, random_state=rng, shuffle=True)
        # from sklearn import model_selection
        # selector = model_selection.StratifiedKFold(
        #     n_splits=n_splits, random_state=rng, shuffle=True)
        skf_list = list(selector.split(X=gids, y=cids, groups=gids))
        trainx, testx = skf_list[0]
        if 0:
            _train_gids = set(ub.take(gids, trainx))
            _test_gids = set(ub.take(gids, testx))
            print('_train_gids = {!r}'.format(_train_gids))
            print('_test_gids = {!r}'.format(_test_gids))
        return trainx, testx

    # Create a flat table of image-ids and category-ids
    gids, cids = [], []
    for gid_, cids_ in zip(images, images.annots.cids):
        cids.extend(cids_)
        gids.extend([gid_] * len(cids_))

    # Split into learn/test, then split learn into train/vali
    learnx, testx = _stratified_split(gids, cids, rng=2997217409,
                                      n_splits=factor)
    learn_gids = list(ub.take(gids, learnx))
    learn_cids = list(ub.take(cids, learnx))
    _trainx, _valix = _stratified_split(learn_gids, learn_cids,
                                        rng=140860164, n_splits=factor)
    trainx = learnx[_trainx]
    valix = learnx[_valix]

    split_gids = {
        'train': sorted(ub.unique(ub.take(gids, trainx))),
        'vali': sorted(ub.unique(ub.take(gids, valix))),
        'test': sorted(ub.unique(ub.take(gids, testx))),
    }
    if True:
        # Hack to favor training a good model over testing it properly. The
        # only real fix to this is to add more data; otherwise it is simply
        # a systemic issue.
        split_gids['vali'] = sorted(
            set(split_gids['vali']) - set(split_gids['train']))
        split_gids['test'] = sorted(
            set(split_gids['test']) - set(split_gids['train']))
        split_gids['test'] = sorted(
            set(split_gids['test']) - set(split_gids['vali']))

    if __debug__:
        import itertools as it
        for a, b in it.combinations(split_gids.values(), 2):
            if set(a) & set(b):
                print('split_gids = {!r}'.format(split_gids))
                assert False
    return split_gids
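# The helper above uses a custom StratifiedGroupKFold from
# ndsampler.utils.util_sklearn. A minimal sketch of the same idea with the
# stock scikit-learn class (available in sklearn >= 1.0); the toy gids/cids
# arrays here are assumed for illustration only.
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

gids = np.array([1, 1, 2, 2, 3, 3, 4, 4])   # group (image) id per annotation
cids = np.array([0, 1, 0, 1, 0, 1, 0, 1])   # class id per annotation
selector = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=0)
trainx, testx = next(iter(selector.split(X=gids, y=cids, groups=gids)))
# Each image's annotations land entirely in one side of the split
print(sorted(set(gids[trainx])), sorted(set(gids[testx])))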
def _cm_training_pairs(infr, qreq_=None, cm_list=None, top_gt=2, mid_gt=2,
                       bot_gt=2, top_gf=2, mid_gf=2, bot_gf=2, rand_gt=2,
                       rand_gf=2, rng=None):
    """
    Constructs training data for a pairwise classifier

    Example:
        >>> # ENABLE_DOCTEST
        >>> infr = testdata_infr('PZ_MTEST')
        >>> infr.exec_matching(cfgdict={
        >>>     'can_match_samename': True,
        >>>     'K': 4,
        >>>     'Knorm': 1,
        >>>     'prescore_method': 'csum',
        >>>     'score_method': 'csum'
        >>> })
        >>> exec(ut.execstr_funckw(infr._cm_training_pairs))
        >>> rng = np.random.RandomState(42)
        >>> aid_pairs = np.array(infr._cm_training_pairs(rng=rng))
        >>> print(len(aid_pairs))
        >>> assert np.sum(aid_pairs.T[0] == aid_pairs.T[1]) == 0
    """
    if qreq_ is None:
        cm_list = infr.cm_list
        qreq_ = infr.qreq_
    ibs = infr.ibs
    aid_pairs = []
    dnids = qreq_.get_qreq_annot_nids(qreq_.daids)
    rng = util.ensure_rng(rng)
    for cm in ub.ProgIter(cm_list, desc='building pairs'):
        all_gt_aids = cm.get_top_gt_aids(ibs)
        all_gf_aids = cm.get_top_gf_aids(ibs)
        gt_aids = util.take_percentile_parts(all_gt_aids, top_gt, mid_gt,
                                             bot_gt)
        gf_aids = util.take_percentile_parts(all_gf_aids, top_gf, mid_gf,
                                             bot_gf)
        # Get unscored examples
        unscored_gt_aids = [aid for aid in qreq_.daids[cm.qnid == dnids]
                            if aid not in cm.daid2_idx]
        rand_gt_aids = util.random_sample(unscored_gt_aids, rand_gt, rng=rng)
        # gf_aids = cm.get_groundfalse_daids()
        _gf_aids = qreq_.daids[cm.qnid != dnids]
        _gf_aids = qreq_.daids.compress(cm.qnid != dnids)
        # gf_aids = ibs.get_annot_groundfalse(cm.qaid, daid_list=qreq_.daids)
        rand_gf_aids = util.random_sample(_gf_aids, rand_gf, rng=rng).tolist()
        chosen_daids = list(ub.unique(gt_aids + gf_aids + rand_gf_aids +
                                      rand_gt_aids))
        aid_pairs.extend([(cm.qaid, aid) for aid in chosen_daids
                          if cm.qaid != aid])
    return aid_pairs
def ford_circles():
    """
    Draw Ford Circles

    This is a Ford Circle diagram of the Rationals and Float32 numbers.
    Only 163 of the 32608 rationals I generated can be exactly represented
    by a float32. [MF 14] [MF 95]

    [MF 14] https://www.youtube.com/watch?v=83ZjYvkdzYI&list=PL5A714C94D40392AB&index=14
    [MF 95] https://www.youtube.com/watch?v=gATEJ3f3FBM&list=PL5A714C94D40392AB&index=95

    Examples:
        import kwplot
        kwplot.autompl()
    """
    import kwplot
    import ubelt as ub
    import matplotlib as mpl
    plt = kwplot.autoplt()
    sns = kwplot.autosns()  # NOQA

    limit = 256 * 256
    print('limit = {!r}'.format(limit))

    rats_to_plot = set()
    maxx = 1
    _iter = Rational.members(limit=limit)
    _genrat = set(ub.ProgIter(_iter, total=limit, desc='gen rats'))
    rats_to_plot |= _genrat
    rats_to_plot2 = sorted({Rational(r % maxx) for r in rats_to_plot} | {maxx})

    floats = sorted(ub.unique(map(float, rats_to_plot2),
                              key=lambda f: f.as_integer_ratio()))
    print(f'{len(rats_to_plot) = }')
    print(f'{len(rats_to_plot2) = }')
    print(f'{len(floats) = }')

    import numpy as np
    ax = kwplot.figure(fnum=1, doclf=True).gca()
    prog = ub.ProgIter(sorted(rats_to_plot2), verbose=1)
    dtype = np.float32
    patches = ub.ddict(list)
    errors = []
    for rat in prog:
        denominator = rat.denominator
        radius = 1 / (2 * (denominator * denominator))
        point = (rat, radius)
        flt = dtype(rat)
        a, b = flt.as_integer_ratio()
        flt_as_rat = Rational(a, b)
        error = abs(rat - flt_as_rat)
        if error == 0:
            new_circle = plt.Circle(point, radius, facecolor='dodgerblue',
                                    edgecolor='none', linewidth=0, alpha=0.5)
            patches['good'].append(new_circle)
        else:
            errors.append(error)
            # Plot a line for the approximation error
            new_circle = plt.Circle(point, radius, facecolor='orangered',
                                    edgecolor='none', linewidth=0, alpha=0.5)
            patches['bad'].append(new_circle)
            ax.plot((rat - error, rat + error), (radius, radius), 'x-',
                    color='darkgray')
    print(ub.map_vals(len, patches))
    total = float(sum(errors))
    print('total = {!r}'.format(total))
    print(max(errors))
    print(min(errors))

    for v in patches.values():
        first = ub.peek(v)
        prop = ub.dict_isect(first.properties(),
                             ['facecolor', 'linewidth', 'alpha', 'edgecolor'])
        col = mpl.collections.PatchCollection(v, **prop)
        ax.add_collection(col)

    # Lets look for the holes in IEEE float
    # for flt in ub.ProgIter(sorted(floats), verbose=1):
    kwplot.phantom_legend({
        f'rationals without a {dtype}': 'orangered',
        f'rationals with a {dtype}': 'dodgerblue',
        f'x-x indicates {dtype} approximation error': 'darkgray',
    })

    ax.set_title('Holes in IEEE 754 Float32')
    ax.set_xlabel('A rational number')
    ax.set_ylabel('Circle radius: 1 / (2 * denominator ** 2)')

    # import numpy as np
    # points = np.array([c.center for c in _circles])
    # maxx, maxy = points.max(axis=0)
    # print('maxx = {!r}'.format(maxx))
    # print('maxy = {!r}'.format(maxy))
    # maxx, maxy = maxx // 2, maxy // 2
    # ax.set_xlim(0, np.sqrt(int(maxx)))
    # ax.set_ylim(0, np.sqrt(int(maxy)))
    # ax.set_aspect('equal')
    # ax.set_xlim(0.2, 0.22)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 0.1)
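# A small worked check (not in the original source) of the classic Ford
# circle fact the plot relies on: the circles for p/q and r/s, centered at
# (p/q, 1/(2*q*q)) with radius 1/(2*q*q), are tangent exactly when
# |p*s - r*q| == 1. Exact Fraction arithmetic avoids any float error.
from fractions import Fraction

def ford_tangent(x, y):
    p, q = x.numerator, x.denominator
    r, s = y.numerator, y.denominator
    # Tangency: distance between centers equals the sum of the radii
    dx = Fraction(p, q) - Fraction(r, s)
    dy = Fraction(1, 2 * q * q) - Fraction(1, 2 * s * s)
    dist_sq = dx * dx + dy * dy
    rad_sum = Fraction(1, 2 * q * q) + Fraction(1, 2 * s * s)
    return dist_sq == rad_sum * rad_sum

print(ford_tangent(Fraction(1, 2), Fraction(1, 3)))  # True: |1*3 - 1*2| == 1
print(ford_tangent(Fraction(1, 2), Fraction(1, 4)))  # False: |1*4 - 1*2| == 2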
def simple_munkres(part_oldnames):
    """
    Defines a munkres problem to solve name rectification.

    Notes:
        We create a matrix where each row represents a group of annotations
        in the same PCC and each column represents an original name. If
        there are more PCCs than original names, the columns are padded with
        extra values. The matrix is first initialized to be negative
        infinity, representing impossible assignments. Then, for each column
        representing a padded name, we set its value to $1$, indicating that
        each new name could be assigned to a padded name for some small
        profit. Finally, let $f_{rc}$ be the number of annotations in row
        $r$ with an original name of $c$. Each matrix value $(r, c)$ is set
        to $f_{rc} + 1$ if $f_{rc} > 0$, to represent how much each name
        ``wants'' to be labeled with a particular original name, and the
        extra one ensures that these original names are always preferred
        over padded names.

    Example:
        >>> part_oldnames = [['a', 'b'], ['b', 'c'], ['c', 'a', 'a']]
        >>> new_names = simple_munkres(part_oldnames)
        >>> result = ub.repr2(new_names)
        >>> print(new_names)
        ['b', 'c', 'a']

    Example:
        >>> part_oldnames = [[], ['a', 'a'], [],
        >>>                  ['a', 'a', 'a', 'a', 'a', 'a', 'a', 'b'], ['a']]
        >>> new_names = simple_munkres(part_oldnames)
        >>> result = ub.repr2(new_names)
        >>> print(new_names)
        [None, 'a', None, 'b', None]

    Example:
        >>> part_oldnames = [[], ['b'], ['a', 'b', 'c'], ['b', 'c'], ['c', 'e', 'e']]
        >>> new_names = find_consistent_labeling(part_oldnames)
        >>> result = ub.repr2(new_names)
        >>> print(new_names)
        ['_extra_name0', 'b', 'a', 'c', 'e']

        Profit Matrix
             b    a    c    e   _0
        0  -10  -10  -10  -10    1
        1    2  -10  -10  -10    1
        2    2    2    2  -10    1
        3    2  -10    2  -10    1
        4  -10  -10    2    3    1
    """
    unique_old_names = list(ub.unique(ub.flatten(part_oldnames)))
    num_new_names = len(part_oldnames)
    num_old_names = len(unique_old_names)

    # Create padded dummy values. This accounts for the case where it is
    # impossible to uniquely map to the old db
    num_pad = max(num_new_names - num_old_names, 0)
    total = num_old_names + num_pad
    shape = (total, total)

    # Allocate the assignment matrix.
    # Rows are new-names and cols are old-names.
    # Initially the profit of any assignment is effectively -inf,
    # which marks all assignments as invalid.
    profit_matrix = np.full(shape, -2 * total, dtype=int)
    # Overwrite valid assignments with positive profits
    from graphid import util
    oldname2_idx = util.make_index_lookup(unique_old_names)
    name_freq_list = [ub.dict_hist(names) for names in part_oldnames]
    # Initialize the profit of a valid assignment as 1 + freq.
    # This incentivizes using a previously used name.
    for rowx, name_freq in enumerate(name_freq_list):
        for name, freq in name_freq.items():
            colx = oldname2_idx[name]
            profit_matrix[rowx, colx] = freq + 1
    # Set a much smaller profit for using an extra name.
    # This allows the solution to always exist.
    profit_matrix[:, num_old_names:total] = 1

    # Convert to a minimization problem
    big_value = (profit_matrix.max()) - (profit_matrix.min())
    cost_matrix = big_value - profit_matrix

    # Use the scipy implementation of the munkres algorithm
    rx2_cx = dict(zip(*scipy.optimize.linear_sum_assignment(cost_matrix)))

    # Each row (new-name) has now been assigned a column (old-name).
    # Map this back to the input-space (using None to indicate extras)
    cx2_name = dict(enumerate(unique_old_names))

    if False:
        import pandas as pd
        columns = unique_old_names + ['_%r' % x for x in range(num_pad)]
        print('Profit Matrix')
        print(pd.DataFrame(profit_matrix, columns=columns))
        print('Cost Matrix')
        print(pd.DataFrame(cost_matrix, columns=columns))

    assignment_ = [cx2_name.get(rx2_cx[rx], None)
                   for rx in range(num_new_names)]
    return assignment_
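# A standalone sketch (toy numbers, not from the original source) of the
# profit-to-cost flip and the scipy assignment call used above.
import numpy as np
import scipy.optimize

profit = np.array([[2, 1],    # row 0 mildly prefers column 0
                   [3, 1]])   # row 1 strongly prefers column 0
cost = (profit.max() - profit.min()) - profit  # flip to a minimization
rows, cols = scipy.optimize.linear_sum_assignment(cost)
print(dict(zip(rows, cols)))  # {0: 1, 1: 0}; row 1 wins the contested column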
def find_consistent_labeling(grouped_oldnames, extra_prefix='_extra_name',
                             verbose=False):
    """
    Solves a maximum bipartite matching problem to find a consistent name
    assignment that minimizes the number of annotations with different
    names. Each group of annotations must all be assigned the same name,
    either one of the old names or an entirely new one. To reduce the
    running time, trivial groups are resolved directly and the remaining
    groups are partitioned into disjoint subproblems.

    Args:
        grouped_oldnames (list): A group of old names where the grouping is
            based on new names. For instance:

                Given:
                    aids      = [1, 2, 3, 4, 5]
                    old_names = [0, 1, 1, 1, 0]
                    new_names = [0, 0, 1, 1, 0]

                The grouping is
                    [[0, 1, 0], [1, 1]]

            This lets us keep the old names in a split case and re-use
            existing names and make minimal changes to current annotation
            names while still being consistent with the new and improved
            grouping.

            The output will be:
                [0, 1]

            Meaning that all annots in the first group are assigned the
            name 0 and all annots in the second group are assigned the
            name 1.

    References:
        http://stackoverflow.com/questions/1398822/assignment-problem-numpy

    Example:
        >>> grouped_oldnames = demodata_oldnames(25, 15, 5, n_per_incon=5)
        >>> new_names = find_consistent_labeling(grouped_oldnames, verbose=1)
        >>> grouped_oldnames = demodata_oldnames(0, 15, 5, n_per_incon=1)
        >>> new_names = find_consistent_labeling(grouped_oldnames, verbose=1)
        >>> grouped_oldnames = demodata_oldnames(0, 0, 0, n_per_incon=1)
        >>> new_names = find_consistent_labeling(grouped_oldnames, verbose=1)

    Example:
        >>> ydata = []
        >>> xdata = list(range(10, 150, 50))
        >>> for x in xdata:
        >>>     print('x = %r' % (x,))
        >>>     grouped_oldnames = demodata_oldnames(x, 15, 5, n_per_incon=5)
        >>>     t = ub.Timerit(3, verbose=1)
        >>>     for timer in t:
        >>>         with timer:
        >>>             new_names = find_consistent_labeling(grouped_oldnames)
        >>>     ydata.append(t.min())
        >>> # xdoc: +REQUIRES(--show)
        >>> import plottool as pt
        >>> pt.qtensure()
        >>> pt.multi_plot(xdata, [ydata])
        >>> util.show_if_requested()

    Example:
        >>> grouped_oldnames = [['a', 'b', 'c'], ['b', 'c'], ['c', 'e', 'e']]
        >>> new_names = find_consistent_labeling(grouped_oldnames, verbose=1)
        >>> result = ub.repr2(new_names)
        >>> print(new_names)
        ['a', 'b', 'e']

    Example:
        >>> grouped_oldnames = [['a', 'b'], ['a', 'a', 'b'], ['a']]
        >>> new_names = find_consistent_labeling(grouped_oldnames)
        >>> result = ub.repr2(new_names)
        >>> print(new_names)
        ['b', 'a', '_extra_name0']

    Example:
        >>> grouped_oldnames = [['a', 'b'], ['e'], ['a', 'a', 'b'], [], ['a'], ['d']]
        >>> new_names = find_consistent_labeling(grouped_oldnames)
        >>> result = ub.repr2(new_names)
        >>> print(new_names)
        ['b', 'e', 'a', '_extra_name0', '_extra_name1', 'd']

    Example:
        >>> grouped_oldnames = [[], ['a', 'a'], [],
        >>>                     ['a', 'a', 'a', 'a', 'a', 'a', 'a', 'b'], ['a']]
        >>> new_names = find_consistent_labeling(grouped_oldnames)
        >>> result = ub.repr2(new_names)
        >>> print(new_names)
        ['_extra_name0', 'a', '_extra_name1', 'b', '_extra_name2']
    """
    unique_old_names = list(ub.unique(ub.flatten(grouped_oldnames)))
    n_old_names = len(unique_old_names)
    n_new_names = len(grouped_oldnames)

    # Initialize the assignment to all Nones
    assignment = [None for _ in range(n_new_names)]

    if verbose:
        print('finding maximally consistent labeling')
        print('n_old_names = %r' % (n_old_names,))
        print('n_new_names = %r' % (n_new_names,))

    # For each old_name, determine how many new_names use it.
    oldname_sets = list(map(set, grouped_oldnames))
    oldname_usage = ub.dict_hist(ub.flatten(oldname_sets))

    # Any name used more than once is a conflict and must be resolved
    conflict_oldnames = {k for k, v in oldname_usage.items() if v > 1}

    # Partition into trivial and non-trivial cases
    nontrivial_oldnames = []
    nontrivial_new_idxs = []
    trivial_oldnames = []
    trivial_new_idxs = []
    for new_idx, group in enumerate(grouped_oldnames):
        if set(group).intersection(conflict_oldnames):
            nontrivial_oldnames.append(group)
            nontrivial_new_idxs.append(new_idx)
        else:
            trivial_oldnames.append(group)
            trivial_new_idxs.append(new_idx)

    # Rectify trivial cases
    # Any new-name that does not share any of its old-names with other
    # new-names can be resolved trivially
    n_trivial_unchanged = 0
    n_trivial_ignored = 0
    n_trivial_merges = 0
    for group, new_idx in zip(trivial_oldnames, trivial_new_idxs):
        if len(group) > 0:
            # new-names that use more than one old-name are simple merges
            h = ub.dict_hist(group)
            if len(h) > 1:
                n_trivial_merges += 1
            else:
                n_trivial_unchanged += 1
            hitems = list(h.items())
            hvals = [i[1] for i in hitems]
            maxval = max(hvals)
            g = min([k for k, v in hitems if v == maxval])
            assignment[new_idx] = g
        else:
            # new-names that use no old-names can be ignored
            n_trivial_ignored += 1

    if verbose:
        n_trivial = len(trivial_oldnames)
        n_nontrivial = len(nontrivial_oldnames)
        print('rectify %d trivial groups' % (n_trivial,))
        print(' * n_trivial_unchanged = %r' % (n_trivial_unchanged,))
        print(' * n_trivial_merges = %r' % (n_trivial_merges,))
        print(' * n_trivial_ignored = %r' % (n_trivial_ignored,))
        print('rectify %d non-trivial groups' % (n_nontrivial,))

    # Partition nontrivial_oldnames into smaller disjoint sets
    nontrivial_oldnames_sets = list(map(set, nontrivial_oldnames))
    import networkx as nx
    g = nx.Graph()
    g.add_nodes_from(range(len(nontrivial_oldnames_sets)))
    for u, group1 in enumerate(nontrivial_oldnames_sets):
        rest = nontrivial_oldnames_sets[u + 1:]
        for v, group2 in enumerate(rest, start=u + 1):
            if group1.intersection(group2):
                g.add_edge(u, v)
    nontrivial_partition = list(nx.connected_components(g))
    if verbose:
        print(' * partitioned non-trivial into %d subgroups' %
              (len(nontrivial_partition)))
        from graphid import util
        part_size_stats = util.stats_dict(map(len, nontrivial_partition))
        stats_str = ub.repr2(part_size_stats, precision=2, strkeys=True)
        print(' * partition size stats = %s' % (stats_str,))

    # Rectify nontrivial cases
    for part_idxs in ub.ProgIter(nontrivial_partition,
                                 desc='rectify parts', enabled=verbose):
        part_oldnames = list(ub.take(nontrivial_oldnames, part_idxs))
        part_newidxs = list(ub.take(nontrivial_new_idxs, part_idxs))
        # Rectify this part
        assignment_ = simple_munkres(part_oldnames)
        for new_idx, new_name in zip(part_newidxs, assignment_):
            assignment[new_idx] = new_name

    # Any unassigned name is now given a new unique label with a prefix
    if extra_prefix is not None:
        num_extra = 0
        for idx, val in enumerate(assignment):
            if val is None:
                assignment[idx] = '%s%d' % (extra_prefix, num_extra,)
                num_extra += 1
    return assignment
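# A tiny standalone sketch (toy data, not from the original source) of the
# partitioning step above: groups that share an old name land in the same
# connected component and are rectified together, while independent groups
# form separate, smaller assignment subproblems.
import networkx as nx
groups = [{'a', 'b'}, {'b', 'c'}, {'x'}, {'x', 'y'}]
g = nx.Graph()
g.add_nodes_from(range(len(groups)))
for u in range(len(groups)):
    for v in range(u + 1, len(groups)):
        if groups[u] & groups[v]:
            g.add_edge(u, v)
print(list(nx.connected_components(g)))  # [{0, 1}, {2, 3}]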
def _squash_between(repo, start, stop, dry=False, verbose=True):
    """
    Inplace squash between <start> and <stop>. To use this directly from
    the command line, use the external function that sets up temporary
    branches.
    """
    if len(start.parents) != 1:
        raise AssertionError('cant handle')
    # assert start.authored_datetime < stop.authored_datetime
    if not repo.is_ancestor(ancestor_rev=start, rev=stop):
        raise AssertionError('cant handle')

    # Use RFC 2822 timestamps
    # ISO_8601 = '%Y-%m-%d %H:%M:%S %z'  # NOQA
    # ts_start = start.authored_datetime.strftime(ISO_8601)
    # ts_stop = stop.authored_datetime.strftime(ISO_8601)
    ts_start = email.utils.format_datetime(start.authored_datetime)
    ts_stop = email.utils.format_datetime(stop.authored_datetime)
    # if ts_start.split()[0:4] == ts_stop.split()[0:4]:
    if start.authored_datetime.date() == stop.authored_datetime.date():
        ts_stop_short = ' '.join(ts_stop.split()[4:])
    else:
        ts_stop_short = ts_stop

    # Construct a new message
    commits = commits_between(repo, start, stop)
    messages = [commit.message for commit in commits]
    # messages = [commit.message for commit in streak._streak]
    unique_messages = ub.unique(messages)
    summary = '\n'.join(unique_messages)
    if summary == 'wip\n':
        summary = summary.strip('\n')
    new_msg = '{} - Squashed {} commits from <{}> to <{}>\n'.format(
        summary, len(commits), ts_start, ts_stop_short)

    if verbose:
        print(' * Creating new commit with message:')
        print(new_msg)

    old_head = repo.commit('HEAD')
    if (stop != old_head and
            not repo.is_ancestor(ancestor_rev=stop, rev=old_head)):
        raise Exception('stop={} is not an ancestor of old_head={}'.format(
            stop, old_head))

    if not dry:
        # ------------------
        # MODIFICATION LOGIC
        # ------------------
        # Go back in time to the sequence stopping point
        repo.git.reset(stop.hexsha, hard=True)

        # Undo commits from start to stop by softly resetting to just
        # before the start
        before_start = start.parents[0]
        if verbose:
            print(' * resetting to before <start>')
        repo.git.reset(before_start.hexsha, soft=True)

        # Commit the changes in a new squashed commit and preserve the
        # authored date
        if verbose:
            print(' * creating one commit with all modifications up to <stop>')
        repo.index.commit(new_msg, author_date=ts_stop)

        # If <stop> was not the most recent commit, we need to graft the
        # newer commits back on
        if stop != old_head:
            # Copy commits following the end of the streak in front of our
            # new commit
            if verbose:
                print(' * fixing up the head')
            try:
                # above_commits = commits_between(repo, stop, old_head)
                # print('above_commits = {}'.format(ub.repr2(above_commits, si=True)))
                above = stop.hexsha + '..' + old_head.hexsha
                # above = streak.child.hexsha + '..' + old_head
                repo.git.cherry_pick(above, allow_empty=True)
            except git.GitCommandError:
                print('ERROR: need to roll back')
                raise
        else:
            if verbose:
                print(' * already at the head, no need to fix')