Example #1
def on_pick(event, infr=None):
    import wbia.plottool as pt

    logger.info('ON PICK: %r' % (event, ))
    artist = event.artist
    plotdat = pt.get_plotdat_dict(artist)
    if plotdat:
        if 'node' in plotdat:
            all_node_data = ut.sort_dict(plotdat['node_data'].copy())
            visual_node_data = ut.dict_subset(all_node_data,
                                              infr.visual_node_attrs, None)
            node_data = ut.delete_dict_keys(all_node_data,
                                            infr.visual_node_attrs)
            node = plotdat['node']
            node_data['degree'] = infr.graph.degree(node)
            node_label = infr.pos_graph.node_label(node)
            logger.info('visual_node_data: ' +
                        ut.repr2(visual_node_data, nl=1))
            logger.info('node_data: ' + ut.repr2(node_data, nl=1))
            ut.cprint('node: ' + ut.repr2(plotdat['node']), 'blue')
            logger.info('(pcc) node_label = %r' % (node_label, ))
            logger.info('artist = %r' % (artist, ))
        elif 'edge' in plotdat:
            all_edge_data = ut.sort_dict(plotdat['edge_data'].copy())
            logger.info(infr.repr_edge_data(all_edge_data))
            ut.cprint('edge: ' + ut.repr2(plotdat['edge']), 'blue')
            logger.info('artist = %r' % (artist, ))
        else:
            logger.info('???: ' + ut.repr2(plotdat))
    logger.info(ut.get_timestamp())
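
The handler above expects to be registered as a matplotlib pick callback and to find wbia.plottool metadata attached to the picked artist. A minimal sketch of the plain-matplotlib wiring, independent of wbia (the demo handler and data are made up):

import matplotlib.pyplot as plt

def on_pick_demo(event):
    # event.artist is whatever pickable object was clicked
    print('picked artist: %r' % (event.artist,))

fig, ax = plt.subplots()
(line,) = ax.plot([0, 1, 2], [0, 1, 0], marker='o', picker=True)
fig.canvas.mpl_connect('pick_event', on_pick_demo)
plt.show()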
Example #2
def case_redo_incon():
    """
    CommandLine:
        python -m wbia.algo.graph.tests.dyn_cases case_redo_incon --show

    Example:
        >>> # ENABLE_DOCTEST
        >>> from wbia.algo.graph.tests.dyn_cases import *  # NOQA
        >>> case_redo_incon()
    """
    ccs = [[1, 2], [3, 4]]  # [6, 7]]
    edges = [
        (2, 3, {
            'evidence_decision': NEGTV
        }),
        (1, 4, {
            'evidence_decision': NEGTV
        }),
    ]
    edges += []
    new_edges = [(2, 3, {'evidence_decision': POSTV})]
    infr1, infr2, check = do_infr_test(ccs, edges, new_edges)

    maybe_splits = infr2.get_edge_attrs('maybe_error', default=None)
    logger.info('maybe_splits = %r' % (maybe_splits, ))
    if not any(maybe_splits.values()):
        ut.cprint('FAILURE', 'red')
        logger.info('At least one edge should be marked as a split')

    check.after()
Example #3
def print_acronymn_def(english_line):
    # Highlight parenthesized acronyms, then reprint the sentence with \cite commands stripped.
    words = re.split(r'[~\s]', english_line.rstrip('.'))
    words = [w.rstrip(',').rstrip('.') for w in words]
    flag = False
    for word in words:
        if re.match(r'\([A-Z]+\)', word):
            ut.cprint(word, 'blue')
            flag = True
    if flag:
        print(re.sub(r'\\cite{[^}]*}', '', english_line))
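
A hypothetical call to the function above, using a made-up sentence, shows what it flags:

# '(CNN)' matches the acronym pattern, so it is highlighted in blue and the
# whole sentence is reprinted with any \cite{...} commands stripped out.
print_acronymn_def('We train a convolutional neural network (CNN) on held-out data.')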
Example #4
def update_all(repo, master, mixins):
    ut.cprint('--- UPDATE ALL ---', 'blue')
    repo.checkout2(master)
    repo.pull2()
    repo.issue('git fetch --all')

    for branch in mixins:
        repo.checkout2(branch)
        # repo.issue('git checkout ' + branch)
        # gitrepo = repo.as_gitpython()  # NOQA
        repo.reset_branch_to_remote(branch)
        repo.issue('git pull')
Example #5
def main():
    tests_ = tests
    subset = ['consistent_info', 'inconsistent_info']
    subset = ['chain1', 'chain2', 'chain3']
    subset += ['triangle1', 'triangle2', 'triangle3']
    # subset = ['inconsistent_info']
    tests_ = ut.dict_subset(tests, subset)

    for name, func in tests_.items():
        logger.info('\n==============')
        ut.cprint('name = %r' % (name, ), 'yellow')
        uvw_list, pass_values, fail_values = func()
        G = build_graph(uvw_list)

        nodes = sorted(G.nodes())
        edges = [tuple(sorted(e)) for e in G.edges()]
        edges = ut.sortedby2(edges, edges)

        n_annots = len(nodes)
        n_names = n_annots

        annot_idxs = list(range(n_annots))
        lookup_annot_idx = ut.dzip(nodes, annot_idxs)
        nx.set_node_attributes(G, name='annot_idx', values=lookup_annot_idx)

        edge_probs = np.array([
            get_edge_id_probs(G, aid1, aid2, n_names) for aid1, aid2 in edges
        ])

        logger.info('nodes = %r' % (nodes, ))
        # logger.info('edges = %r' % (edges,))
        logger.info('Noisy Observations')
        logger.info(
            pd.DataFrame(edge_probs,
                         columns=['same', 'diff'],
                         index=pd.Series(edges)))
        edge_probs = None

        cut_step(
            G,
            nodes,
            edges,
            n_annots,
            n_names,
            lookup_annot_idx,
            edge_probs,
            pass_values,
            fail_values,
        )

        edge_probs = bp_step(G, nodes, edges, n_annots, n_names,
                             lookup_annot_idx)
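
The annot_idx bookkeeping in the loop above is a plain node-to-position lookup; a small networkx sketch of the same step with toy node ids (ut.dzip replaced with dict(zip(...)) here):

import networkx as nx

G = nx.Graph([(101, 102), (102, 103)])
nodes = sorted(G.nodes())
lookup_annot_idx = dict(zip(nodes, range(len(nodes))))  # {101: 0, 102: 1, 103: 2}
nx.set_node_attributes(G, name='annot_idx', values=lookup_annot_idx)
print(G.nodes[101]['annot_idx'])  # 0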
Example #6
    def ensure_results(self, expt_name=None, nocompute=None):
        """
        Subclasses must obey the measure_<expt_name>, draw_<expt_name> contract
        """
        if nocompute is None:
            nocompute = ut.get_argflag('--nocompute')

        if expt_name is None and exists(self.dpath):
            # Load all
            fpaths = ut.glob(str(self.dpath), '*.pkl')
            expt_names = [splitext(basename(fpath))[0] for fpath in fpaths]
            for fpath, expt_name in zip(fpaths, expt_names):
                self.expt_results[expt_name] = ut.load_data(fpath)
        else:
            # expt_name = splitext(basename(fpath))[0]
            fpath = join(str(self.dpath), expt_name + '.pkl')
            # fpath = ut.truepath(fpath)
            if not exists(fpath):
                ut.cprint(
                    'Experiment results {} do not exist'.format(expt_name),
                    'red')
                ut.cprint('First re-setup to check if it is a path issue',
                          'red')
                if nocompute:
                    raise Exception(
                        str(expt_name) + ' does not exist for ' +
                        str(self.dbname))

                if self.ibs is None:
                    self._precollect()
                ut.cprint('Checking new fpath', 'yellow')
                fpath = join(str(self.dpath), expt_name + '.pkl')
                logger.info('fpath = %r' % (fpath, ))
                if not exists(fpath):
                    ut.cprint('Results still missing need to re-measure',
                              'red')
                    # assert False
                    # self._setup()
                    getattr(self, 'measure_' + expt_name)()
                else:
                    ut.cprint('Re-setup fixed it', 'green')
            else:
                logger.info('Experiment results {} exist'.format(expt_name))
            self.expt_results[expt_name] = ut.load_data(fpath)
            return self.expt_results[expt_name]
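
The docstring's measure_<expt_name> / draw_<expt_name> contract can be sketched as a hypothetical subclass; only the naming convention below comes from the code above, the class and method bodies are illustrative:

class DemoExpt(object):  # hypothetical subclass sketch
    def measure_accuracy(self):
        # expected to write its results to <dpath>/accuracy.pkl so that
        # ensure_results('accuracy') can load them on the next call
        ...

    def draw_accuracy(self):
        # expected to call self.ensure_results('accuracy') and plot the result
        ...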
Example #7
        def after(self, errors=[]):
            """
            Delays error reporting until after visualization

            prints errors, then shows you the graph, then
            finally if any errors were discovered they are raised
            """

            errors = errors + self._errors
            if errors:
                ut.cprint('PRINTING %d FAILURES' % (len(errors)), 'red')
                for msg in errors:
                    logger.info(msg)
                ut.cprint('HAD %d FAILURES' % (len(errors)), 'red')
            if ut.show_was_requested():
                pt.all_figures_tile(percent_w=0.5)
                ut.show_if_requested()
            if errors:
                raise AssertionError('There were errors')
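
Outside the wbia test harness, the delay-errors-until-after-visualization pattern in after reduces to a few lines; a standalone sketch (function names are made up):

def run_checks_then_raise(checks, visualize=None):
    errors = []
    for name, check in checks:
        if not check():
            errors.append('check failed: %s' % (name,))
    for msg in errors:
        print(msg)          # report every failure first
    if visualize is not None:
        visualize()         # still show the figures before raising
    if errors:
        raise AssertionError('There were %d errors' % len(errors))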
Example #8
def cut_step(G, nodes, edges, n_annots, n_names, lookup_annot_idx, edge_probs, pass_values, fail_values):
    # Create variables in the graphical model.  In this case there are n_annots
    # variables and each one can be assigned one of n_names possible labels
    space = np.full((n_annots,), fill_value=n_names, dtype=opengm.index_type)
    gm = opengm.gm(space, operator='adder')

    # Use one potts function for each edge
    gm = build_factor_graph(G, nodes, edges, n_annots, n_names,
                            lookup_annot_idx, use_unaries=False,
                            edge_probs=edge_probs, operator='adder')

    with ut.Indenter('[CUTS]'):
        ut.cprint('Brute Force Labels: (energy minimization)', 'blue')
        infr = opengm.inference.Bruteforce(gm, accumulator='minimizer')
        infr.infer()
        labels = rectify_labels(G, infr.arg())
        print(pd.DataFrame(labels, columns=['nid'], index=pd.Series(nodes)).T)
        print('value = %r' % (infr.value(),))

        mc_params = opengm.InfParam(maximalNumberOfConstraintsPerRound=1000000,
                                    initializeWith3Cycles=True,
                                    edgeRoundingValue=1e-08, timeOut=36000000.0,
                                    cutUp=1e+75, reductionMode=3, numThreads=0,
                                    # allowCutsWithin=?
                                    # workflow=workflow
                                    verbose=False, verboseCPLEX=False)
        infr = opengm.inference.Multicut(gm, parameter=mc_params,
                                         accumulator='minimizer')

        infr.infer()
        labels = infr.arg()
        labels = rectify_labels(G, infr.arg())

        ut.cprint('Multicut Labels: (energy minimization)', 'blue')
        print(pd.DataFrame(labels, columns=['nid'], index=pd.Series(nodes)).T)
        print('value = %r' % (infr.value(),))

        if pass_values is not None:
            gotany = False
            for pval in pass_values:
                if all(labels == pval):
                    gotany = True
                    break
            if not gotany:
                ut.cprint('INCORRECT DID NOT GET PASS VALUES', 'red')
                print('pass_values = %r' % (pass_values,))

        if fail_values is not None:
            for fail in fail_values:
                if all(labels == fail):
                    ut.cprint('INCORRECT', 'red')
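
The 'one potts function for each edge' comment means every pairwise factor costs nothing when its two endpoints take the same label and a constant penalty otherwise; a tiny numpy sketch of that energy with made-up weights:

import numpy as np

def potts_energy(labels, edges, beta=1.0):
    # an edge (u, v) contributes 0 if labels agree and beta if they differ
    labels = np.asarray(labels)
    return sum(beta * (labels[u] != labels[v]) for u, v in edges)

demo_edges = [(0, 1), (1, 2), (0, 2)]
print(potts_energy([0, 0, 1], demo_edges))  # 2.0 -> two edges cross the label boundary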
Example #9
def apply_species_with_detector_hack(ibs, cfgdict, qaids, daids,
                                     verbose=None):
    """
    HACK: turns off featweights if they cannot be applied
    """
    if verbose is None:
        verbose = VERBOSE_QREQ
    if True:
        # Hack for test speed
        if ibs.dbname in {'PZ_MTEST', 'GZ_Master1', 'PZ_Master1'}:
            return True
    # Only apply the hack with respect to the queried annotations
    aid_list = set(it.chain(qaids, daids))
    unique_species = ibs.get_database_species(aid_list)
    # turn off featureweights when not absolutely sure they are ok to use
    candetect = (len(unique_species) == 1 and
                 ibs.has_species_detector(unique_species[0]))
    if not candetect:
        if ut.NOT_QUIET:
            ut.cprint(
                '[qreq] HACKING FG_WEIGHT OFF (db species is not supported)',
                'yellow')
            if verbose > 1:
                if len(unique_species) != 1:
                    print('[qreq]  * len(unique_species) = %r' % len(unique_species))
                else:
                    print('[qreq]  * unique_species = %r' % (unique_species,))
        #print('[qreq]  * valid species = %r' % (
        #    ibs.get_species_with_detectors(),))
        #cfg._featweight_cfg.featweight_enabled = 'ERR'
        cfgdict['featweight_enabled'] = False  # 'ERR'
        cfgdict['fg_on'] = False
    else:
        #print(ibs.get_annot_species_texts(aid_list))
        if verbose:
            print('[qreq] Using fgweights of unique_species=%r' % (
                unique_species,))
    return unique_species
Example #10
    def ensure_data(qreq_):
        """
            >>> import wbia
            qreq_ = wbia.testdata_qreq_(
                defaultdb='Oxford', a='oxford',
                p='default:proot=smk,nAssign=1,num_words=64000,SV=False,can_match_sameimg=True,dim_size=None')
        """
        logger.info('Ensure data for %s' % (qreq_, ))

        # qreq_.cachedir = ut.ensuredir((ibs.cachedir, 'smk'))
        qreq_.ensure_nids()

        def make_cacher(name, cfgstr=None):
            if cfgstr is None:
                cfgstr = ut.hashstr27(qreq_.get_cfgstr())
            if False and ut.is_developer():
                return ut.Cacher(
                    fname=name + '_' + qreq_.ibs.get_dbname(),
                    cfgstr=cfgstr,
                    cache_dir=ut.ensuredir(ut.truepath('~/Desktop/smkcache')),
                )
            else:
                wrp = ut.DynStruct()

                def ensure(func):
                    return func()

                wrp.ensure = ensure
                return wrp

        import copy

        dconfig = copy.deepcopy(qreq_.qparams)
        qconfig = qreq_.qparams
        if qreq_.qparams['data_ma']:
            # Disable database-side multi-assignment
            dconfig['nAssign'] = 1
        wwm = qreq_.qparams['word_weight_method']

        depc = qreq_.ibs.depc
        vocab_aids = qreq_.daids

        cheat = False
        if cheat:
            import wbia

            ut.cprint('CHEATING', 'red')
            vocab_aids = wbia.init.filter_annots.sample_annots_wrt_ref(
                qreq_.ibs,
                qreq_.daids,
                {'exclude_ref_contact': True},
                qreq_.qaids,
                verbose=1,
            )
            vocab_rowid = depc.get_rowids('vocab', (vocab_aids, ),
                                          config=dconfig,
                                          ensure=False)[0]
            assert vocab_rowid is not None

        depc = qreq_.ibs.depc
        dinva_pcfgstr = depc.stacked_config(None,
                                            'inverted_agg_assign',
                                            config=dconfig).get_cfgstr()
        qinva_pcfgstr = depc.stacked_config(None,
                                            'inverted_agg_assign',
                                            config=qconfig).get_cfgstr()
        dannot_vuuid = qreq_.ibs.get_annot_hashid_visual_uuid(
            qreq_.daids).strip('_')
        qannot_vuuid = qreq_.ibs.get_annot_hashid_visual_uuid(
            qreq_.qaids).strip('_')
        tannot_vuuid = dannot_vuuid
        dannot_suuid = qreq_.ibs.get_annot_hashid_semantic_uuid(
            qreq_.daids).strip('_')
        qannot_suuid = qreq_.ibs.get_annot_hashid_semantic_uuid(
            qreq_.qaids).strip('_')

        dinva_phashid = ut.hashstr27(dinva_pcfgstr + tannot_vuuid)
        qinva_phashid = ut.hashstr27(qinva_pcfgstr + tannot_vuuid)
        dinva_cfgstr = '_'.join([dannot_vuuid, dinva_phashid])
        qinva_cfgstr = '_'.join([qannot_vuuid, qinva_phashid])

        # vocab = inverted_index.new_load_vocab(ibs, qreq_.daids, config)
        dinva_cacher = make_cacher('inva', dinva_cfgstr)
        qinva_cacher = make_cacher('inva', qinva_cfgstr)
        dwwm_cacher = make_cacher('word_weight', wwm + dinva_cfgstr)

        gamma_phashid = ut.hashstr27(qreq_.get_pipe_cfgstr() + tannot_vuuid)
        dgamma_cfgstr = '_'.join([dannot_suuid, gamma_phashid])
        qgamma_cfgstr = '_'.join([qannot_suuid, gamma_phashid])
        dgamma_cacher = make_cacher('dgamma', cfgstr=dgamma_cfgstr)
        qgamma_cacher = make_cacher('qgamma', cfgstr=qgamma_cfgstr)

        dinva = dinva_cacher.ensure(
            lambda: inverted_index.InvertedAnnots.from_depc(
                depc, qreq_.daids, vocab_aids, dconfig))

        qinva = qinva_cacher.ensure(
            lambda: inverted_index.InvertedAnnots.from_depc(
                depc, qreq_.qaids, vocab_aids, qconfig))

        dinva.wx_to_aids = dinva.compute_inverted_list()

        wx_to_weight = dwwm_cacher.ensure(
            lambda: dinva.compute_word_weights(wwm))
        dinva.wx_to_weight = wx_to_weight
        qinva.wx_to_weight = wx_to_weight

        thresh = qreq_.qparams['smk_thresh']
        alpha = qreq_.qparams['smk_alpha']

        dinva.gamma_list = dgamma_cacher.ensure(
            lambda: dinva.compute_gammas(alpha, thresh))

        qinva.gamma_list = qgamma_cacher.ensure(
            lambda: qinva.compute_gammas(alpha, thresh))

        qreq_.qinva = qinva
        qreq_.dinva = dinva

        logger.info('loading keypoints')
        if qreq_.qparams.sv_on:
            qreq_.data_kpts = qreq_.ibs.get_annot_kpts(
                qreq_.daids, config2_=qreq_.extern_data_config2)

        logger.info('building aid index')
        qreq_.daid_to_didx = ut.make_index_lookup(qreq_.daids)
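
make_cacher above hands back either a real utool Cacher or a throwaway stand-in; the only interface the rest of ensure_data touches is .ensure(func). A library-free sketch of that interface (the pickle-backed class is illustrative, not utool's implementation):

import os
import pickle

class PickleCacher(object):
    """Illustrative .ensure() cacher: load a pickle if present, else compute and save."""
    def __init__(self, fpath):
        self.fpath = fpath

    def ensure(self, func):
        if os.path.exists(self.fpath):
            with open(self.fpath, 'rb') as file:
                return pickle.load(file)
        data = func()
        with open(self.fpath, 'wb') as file:
            pickle.dump(data, file)
        return data

class PassThroughCacher(object):
    """Stand-in mirroring the DynStruct branch above: always recompute."""
    def ensure(self, func):
        return func()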
Example #11
def main():
    target = 'dev_combo'
    master = 'master'
    mixins = [
        # 'mbkm_fixup',
        # 'progiter',
        # 'multiclass_mcc',
        'missing_values_rf',
    ]
    ut.cprint('--- OPEN REPO ---', 'blue')
    # dpath = os.getcwd()
    dpath = ut.truepath('~/code/scikit-learn')
    repo = ut.Repo(dpath=dpath, url='git@github.com:Erotemic/scikit-learn.git')

    if not repo.is_cloned():
        repo.clone()
        # repo.issue('pip install -e .')

    # Make sure remotes are properly setup
    repo._ensure_remote_exists(
        'source', 'https://github.com/scikit-learn/scikit-learn.git')
    repo._ensure_remote_exists('raghavrv',
                               'https://github.com/raghavrv/scikit-learn.git')

    # Master should point to the scikit-learn official repo
    if repo.get_branch_remote('master') != 'source':
        repo.set_branch_remote('master', 'source')

    update_all(repo, master, mixins)

    REBASE_VERSION = True
    if REBASE_VERSION:
        ut.cprint('--- REBASE BRANCHES ON MASTER ---', 'blue')
        rebase_mixins = []
        for branch in mixins:
            new_branch = make_dev_rebased_mixin(repo, master, branch)
            rebase_mixins.append(new_branch)
        ut.cprint('--- CHECKOUT DEV MASTER --- ', 'blue')
        reset_dev_branch(repo, master, target)
        ut.cprint('--- MERGE INTO DEV MASTER --- ', 'blue')
        for branch in rebase_mixins:
            repo.issue('git merge --no-edit -s recursive ' + branch)
        # repo.issue('git merge --no-edit -s recursive -Xours ' + branch)
    else:
        # Attempt to automerge, taking whatever is in the mixin branches as the de-facto version
        ut.cprint('--- CHECKOUT DEV MASTER --- ', 'blue')
        reset_dev_branch(repo, master, target)
        ut.cprint('--- MERGE INTO DEV MASTER --- ', 'blue')
        for branch in mixins:
            repo.issue('git merge --no-edit -s recursive -Xtheirs ' + branch)
        # cleanup because we didn't rebase
        fpath = join(repo.dpath, 'sklearn/utils/validation.py')
        ut.sedfile(fpath,
                   'accept_sparse=None',
                   'accept_sparse=False',
                   force=True)
        repo.issue('git commit -am "quick fix of known merge issue"')

    # Recompile and reinstall the package
    if True:
        repo.issue('python setup.py clean')
        repo.issue('python setup.py build -j9')
        repo.issue('pip install -e .')
Example #12
    def pair_connection_info(infr, aid1, aid2):
        """
        Helps debugging when ibs.nids has info that annotmatch/staging do not

        Examples:
            >>> # # FIXME failing-test (22-Jul-2020) GZ_Master1 doesn't exist
            >>> # xdoctest: +SKIP
            >>> from wbia.algo.graph.mixin_helpers import *  # NOQA
            >>> import wbia
            >>> ibs = wbia.opendb(defaultdb='GZ_Master1')
            >>> infr = wbia.AnnotInference(ibs, 'all', autoinit=True)
            >>> infr.reset_feedback('staging', apply=True)
            >>> infr.relabel_using_reviews(rectify=False)
            >>> aid1, aid2 = 1349, 3087
            >>> aid1, aid2 = 1535, 2549
            >>> infr.pair_connection_info(aid1, aid2)


            >>> aid1, aid2 = 4055, 4286
            >>> aid1, aid2 = 6555, 6882
            >>> aid1, aid2 = 712, 803
            >>> aid1, aid2 = 3883, 4220
            >>> infr.pair_connection_info(aid1, aid2)
        """

        nid1, nid2 = infr.pos_graph.node_labels(aid1, aid2)
        cc1 = infr.pos_graph.connected_to(aid1)
        cc2 = infr.pos_graph.connected_to(aid2)
        ibs = infr.ibs

        # First check direct relationships

        def get_aug_df(edges):
            df = infr.get_edge_dataframe(edges)
            if len(df):
                df.index.names = ('aid1', 'aid2')
                nids = np.array(
                    [infr.pos_graph.node_labels(u, v) for u, v in list(df.index)]
                )
                df = df.assign(nid1=nids.T[0], nid2=nids.T[1])
                part = ['nid1', 'nid2', 'evidence_decision', 'tags', 'user_id']
                neworder = ut.partial_order(df.columns, part)
                df = df.reindex(neworder, axis=1)
                df = df.drop(['review_id', 'timestamp'], axis=1)
            return df

        def print_df(df, lbl):
            df_str = df.to_string()
            df_str = ut.highlight_regex(df_str, ut.regex_word(str(aid1)), color='blue')
            df_str = ut.highlight_regex(df_str, ut.regex_word(str(aid2)), color='red')
            if nid1 not in {aid1, aid2}:
                df_str = ut.highlight_regex(
                    df_str, ut.regex_word(str(nid1)), color='blue'
                )
            if nid2 not in {aid1, aid2}:
                df_str = ut.highlight_regex(df_str, ut.regex_word(str(nid2)), color='red')
            logger.info('\n\n=====')
            logger.info(lbl)
            logger.info('=====')
            logger.info(df_str)

        logger.info('================')
        logger.info('Pair Connection Info')
        logger.info('================')

        nid1_, nid2_ = ibs.get_annot_nids([aid1, aid2])
        logger.info('AIDS        aid1, aid2 = %r, %r' % (aid1, aid2))
        logger.info('INFR NAMES: nid1, nid2 = %r, %r' % (nid1, nid2))
        if nid1 == nid2:
            logger.info('INFR cc = %r' % (sorted(cc1),))
        else:
            logger.info('INFR cc1 = %r' % (sorted(cc1),))
            logger.info('INFR cc2 = %r' % (sorted(cc2),))

        if (nid1 == nid2) != (nid1_ == nid2_):
            ut.cprint('DISAGREEMENT IN GRAPH AND DB', 'red')
        else:
            ut.cprint('GRAPH AND DB AGREE', 'green')

        logger.info('IBS  NAMES: nid1, nid2 = %r, %r' % (nid1_, nid2_))
        if nid1_ == nid2_:
            logger.info('IBS CC: %r' % (sorted(ibs.get_name_aids(nid1_)),))
        else:
            logger.info('IBS CC1: %r' % (sorted(ibs.get_name_aids(nid1_)),))
            logger.info('IBS CC2: %r' % (sorted(ibs.get_name_aids(nid2_)),))

        # Does this exist in annotmatch?
        in_am = ibs.get_annotmatch_rowid_from_undirected_superkey([aid1], [aid2])
        logger.info('in_am = %r' % (in_am,))

        # Does this exist in staging?
        staging_rowids = ibs.get_review_rowids_from_edges([(aid1, aid2)])[0]
        logger.info('staging_rowids = %r' % (staging_rowids,))

        if False:
            # Make absolutely sure
            stagedf = ibs.staging.get_table_as_pandas('reviews')
            aid_cols = ['annot_1_rowid', 'annot_2_rowid']
            has_aid1 = (stagedf[aid_cols] == aid1).any(axis=1)
            from_aid1 = stagedf[has_aid1]
            conn_aid2 = (from_aid1[aid_cols] == aid2).any(axis=1)
            logger.info('# connections = %r' % (conn_aid2.sum(),))

        # Next check indirect relationships
        graph = infr.graph
        if cc1 != cc2:
            edge_df1 = get_aug_df(nxu.edges_between(graph, cc1))
            edge_df2 = get_aug_df(nxu.edges_between(graph, cc2))
            print_df(edge_df1, 'Inside1')
            print_df(edge_df2, 'Inside2')

            out_df1 = get_aug_df(nxu.edges_outgoing(graph, cc1))
            print_df(out_df1, 'Outgoing1')

            out_df2 = get_aug_df(nxu.edges_outgoing(graph, cc2))
            print_df(out_df2, 'Outgoing2')
        else:
            subgraph = infr.pos_graph.subgraph(cc1)
            logger.info('Shortest path between endpoints')
            logger.info(nx.shortest_path(subgraph, aid1, aid2))

        edge_df3 = get_aug_df(nxu.edges_between(graph, cc1, cc2))
        print_df(edge_df3, 'Between')
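
get_aug_df above pushes a preferred subset of columns to the front of the dataframe; the same reordering can be sketched with plain pandas (column names made up):

import pandas as pd

df = pd.DataFrame({'tags': [[]], 'nid1': [1], 'user_id': ['x'], 'nid2': [2]})
preferred = ['nid1', 'nid2']
# preferred columns first, everything else keeps its original relative order
neworder = preferred + [c for c in df.columns if c not in preferred]
df = df.reindex(neworder, axis=1)
print(list(df.columns))  # ['nid1', 'nid2', 'tags', 'user_id']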
Example #13
def bp_step(G, nodes, edges, n_annots, n_names, lookup_annot_idx):
    gm = build_factor_graph(G,
                            nodes,
                            edges,
                            n_annots,
                            n_names,
                            lookup_annot_idx,
                            use_unaries=False,
                            edge_probs=None,
                            operator='multiplier')

    with ut.Indenter('[BELIEF]'):
        ut.cprint('Brute Force Labels: (probability maximization)', 'blue')
        infr = opengm.inference.Bruteforce(gm, accumulator='maximizer')
        infr.infer()
        labels = rectify_labels(G, infr.arg())
        print(pd.DataFrame(labels, columns=['nid'], index=pd.Series(nodes)).T)
        print('value = %r' % (infr.value(), ))

        lpb_parmas = opengm.InfParam(
            damping=0.00,
            steps=10000,
            # convergenceBound=0,
            isAcyclic=False)
        # http://www.andres.sc/publications/opengm-2.0.2-beta-manual.pdf
        # I believe multiplier + integrator = marginalization
        # Manual says multiplier + adder = marginalization
        # Manual says multiplier + maximizer = probability maximization
        # infr = opengm.inference.TreeReweightedBp(
        LBP_algorithm = opengm.inference.BeliefPropagation
        # LBP_algorithm = opengm.inference.TreeReweightedBp

        ut.cprint('Belief Propagation (maximization)', 'blue')
        infr = LBP_algorithm(gm, parameter=lpb_parmas, accumulator='maximizer')
        infr.infer()
        labels = rectify_labels(G, infr.arg())
        pairwise_factor_idxs = gm.pairwise_factor_idxs
        factor_marginals = infr.factorMarginals(pairwise_factor_idxs)
        # print('factor_marginals =\n%r' % (factor_marginals,))
        edge_marginals_same_diff_ = [(np.diag(f).sum(),
                                      f[~np.eye(f.shape[0], dtype=bool)].sum())
                                     for f in factor_marginals]
        edge_marginals_same_diff_ = np.array(edge_marginals_same_diff_)
        edge_marginals_same_diff = edge_marginals_same_diff_.copy()
        edge_marginals_same_diff /= edge_marginals_same_diff.sum(axis=1,
                                                                 keepdims=True)
        print('Unnormalized Edge Marginals:')
        print(
            pd.DataFrame(edge_marginals_same_diff,
                         columns=['same', 'diff'],
                         index=pd.Series(edges)))
        # print('Edge marginals after Belief Propagation')
        # print(pd.DataFrame(edge_marginals_same_diff, columns=['same', 'diff'], index=pd.Series(edges)))
        print('Labels:')
        print(pd.DataFrame(labels, columns=['nid'], index=pd.Series(nodes)).T)
        print('value = %r' % (infr.value(), ))

        ut.cprint('Belief Propagation (marginalization)', 'blue')
        infr = LBP_algorithm(gm,
                             parameter=lpb_parmas,
                             accumulator='integrator')
        infr.infer()
        labels = rectify_labels(G, infr.arg())
        pairwise_factor_idxs = gm.pairwise_factor_idxs
        factor_marginals = infr.factorMarginals(pairwise_factor_idxs)
        # print('factor_marginals =\n%r' % (factor_marginals,))
        edge_marginals_same_diff_ = [(np.diag(f).sum(),
                                      f[~np.eye(f.shape[0], dtype=bool)].sum())
                                     for f in factor_marginals]
        edge_marginals_same_diff_ = np.array(edge_marginals_same_diff_)
        edge_marginals_same_diff = edge_marginals_same_diff_.copy()
        edge_marginals_same_diff /= edge_marginals_same_diff.sum(axis=1,
                                                                 keepdims=True)
        print('Unnormalized Edge Marginals:')
        print(
            pd.DataFrame(edge_marginals_same_diff,
                         columns=['same', 'diff'],
                         index=pd.Series(edges)))
        # print('Edge marginals after Belief Propagation')
        # print(pd.DataFrame(edge_marginals_same_diff, columns=['same', 'diff'], index=pd.Series(edges)))
        print('Labels:')
        print(pd.DataFrame(labels, columns=['nid'], index=pd.Series(nodes)).T)
        print('value = %r' % (infr.value(), ))

    # import plottool as pt
    # viz_factor_graph(gm)
    # # _ = pt.show_nx(G)
    # print("SHOW")
    # pt.plt.show()

    # marginals = infr.marginals(annot_idxs)
    # print('node marginals are')
    # print(pd.DataFrame(marginals, index=pd.Series(nodes)))
    return edge_marginals_same_diff
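
The same/diff split taken from each pairwise factor marginal is just diagonal mass versus off-diagonal mass, normalized per edge; a tiny numpy example with a made-up 2x2 marginal:

import numpy as np

f = np.array([[0.4, 0.1],
              [0.2, 0.3]])  # made-up pairwise marginal over (label_u, label_v)
same = np.diag(f).sum()                            # 0.7: both endpoints share a label
diff = f[~np.eye(f.shape[0], dtype=bool)].sum()    # 0.3: endpoints take different labels
pair = np.array([same, diff])
print(pair / pair.sum())                           # approximately [0.7 0.3]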
Example #14
def main(bib_fpath=None):
    r"""
    Entry point for the fixbib script

    CommandLine:
        fixbib
        python -m fixtex bib
        python -m fixtex bib --dryrun
        python -m fixtex bib --dryrun --debug
    """

    if bib_fpath is None:
        bib_fpath = 'My Library.bib'

    # DEBUG = ub.argflag('--debug')
    # Read in text and ensure ascii format
    dirty_text = ut.readfrom(bib_fpath)

    from fixtex.fix_tex import find_used_citations, testdata_fpaths

    if exists('custom_extra.bib'):
        extra_parser = bparser.BibTexParser(ignore_nonstandard_types=False)
        parser = bparser.BibTexParser()
        ut.delete_keys(parser.alt_dict, ['url', 'urls'])
        print('Parsing extra bibtex file')
        extra_text = ut.readfrom('custom_extra.bib')
        extra_database = extra_parser.parse(extra_text, partial=False)
        print('Finished parsing extra')
        extra_dict = extra_database.get_entry_dict()
    else:
        extra_dict = None

    #udata = dirty_text.decode("utf-8")
    #dirty_text = udata.encode("ascii", "ignore")
    #dirty_text = udata

    # parser = bparser.BibTexParser()
    # bib_database = parser.parse(dirty_text)
    # d = bib_database.get_entry_dict()

    print('BIBTEXPARSER LOAD')
    parser = bparser.BibTexParser(ignore_nonstandard_types=False,
                                  common_strings=True)
    ut.delete_keys(parser.alt_dict, ['url', 'urls'])
    print('Parsing bibtex file')
    bib_database = parser.parse(dirty_text, partial=False)
    print('Finished parsing')

    bibtex_dict = bib_database.get_entry_dict()
    old_keys = list(bibtex_dict.keys())
    new_keys = []
    for key in ub.ProgIter(old_keys, label='fixing keys'):
        new_key = key
        new_key = new_key.replace(':', '')
        new_key = new_key.replace('-', '_')
        new_key = re.sub('__*', '_', new_key)
        new_keys.append(new_key)

    # assert len(ut.find_duplicate_items(new_keys)) == 0, 'new keys created conflict'
    assert len(ub.find_duplicates(new_keys)) == 0, 'new keys created conflict'

    for key, new_key in zip(old_keys, new_keys):
        if key != new_key:
            entry = bibtex_dict[key]
            entry['ID'] = new_key
            bibtex_dict[new_key] = entry
            del bibtex_dict[key]

    # The bibtex is now clean. Print it to stdout
    #print(clean_text)
    verbose = None
    if verbose is None:
        verbose = 1

    # Find citations from the tex documents
    key_list = None
    if key_list is None:
        cacher = ub.Cacher('texcite1', enabled=0)
        data = cacher.tryload()
        if data is None:
            fpaths = testdata_fpaths()
            key_list, inverse = find_used_citations(fpaths,
                                                    return_inverse=True)
            # ignore = ['JP', '?', 'hendrick']
            # for item in ignore:
            #     try:
            #         key_list.remove(item)
            #     except ValueError:
            #         pass
            if verbose:
                print('Found %d citations used in the document' %
                      (len(key_list), ))
            data = key_list, inverse
            cacher.save(data)
        key_list, inverse = data

    # else:
    #     key_list = None

    unknown_pubkeys = []
    debug_author = ub.argval('--debug-author', default=None)
    # ./fix_bib.py --debug_author=Kappes

    if verbose:
        print('Fixing %d/%d bibtex entries' %
              (len(key_list), len(bibtex_dict)))

    # debug = True
    debug = False
    if debug_author is not None:
        debug = False

    known_keys = list(bibtex_dict.keys())
    missing_keys = set(key_list) - set(known_keys)
    if extra_dict is not None:
        missing_keys.difference_update(set(extra_dict.keys()))

    if missing_keys:
        print('The library is missing keys found in tex files %s' %
              (ub.repr2(missing_keys), ))

    # Search for possible typos:
    candidate_typos = {}
    sedlines = []
    for key in missing_keys:
        candidates = ut.closet_words(key, known_keys, num=3, subset=True)
        if len(candidates) > 1:
            top = candidates[0]
            if ut.edit_distance(key, top) == 1:
                # "sed -i -e 's/{}/{}/g' *.tex".format(key, top)
                import os
                replpaths = ' '.join(
                    [relpath(p, os.getcwd()) for p in inverse[key]])
                sedlines.append("sed -i -e 's/{}/{}/g' {}".format(
                    key, top, replpaths))
        candidate_typos[key] = candidates
        print('Cannot find key = %r' % (key, ))
        print('Did you mean? %r' % (candidates, ))

    print('Quick fixes')
    print('\n'.join(sedlines))

    # group by file
    just = max([0] + list(map(len, missing_keys)))
    missing_fpaths = [inverse[key] for key in missing_keys]
    for fpath in sorted(set(ub.flatten(missing_fpaths))):
        # ut.fix_embed_globals()
        subkeys = [k for k in missing_keys if fpath in inverse[k]]
        print('')
        ut.cprint('--- Missing Keys ---', 'blue')
        ut.cprint('fpath = %r' % (fpath, ), 'blue')
        ut.cprint('{} | {}'.format('Missing'.ljust(just), 'Did you mean?'),
                  'blue')
        for key in subkeys:
            print('{} | {}'.format(ut.highlight_text(key.ljust(just), 'red'),
                                   ' '.join(candidate_typos[key])))

    # for key in list(bibtex_dict.keys()):

    if extra_dict is not None:
        # Extra database takes precedence over regular
        key_list = list(ut.unique(key_list + list(extra_dict.keys())))
        for k, v in extra_dict.items():
            bibtex_dict[k] = v

    full = ub.argflag('--full')

    for key in key_list:
        try:
            entry = bibtex_dict[key]
        except KeyError:
            continue
        self = BibTexCleaner(key, entry, full=full)

        if debug_author is not None:
            debug = debug_author in entry.get('author', '')

        if debug:
            ut.cprint(' --- ENTRY ---', 'yellow')
            print(ub.repr2(entry, nl=1))

        entry = self.fix()
        # self.clip_abstract()
        # self.shorten_keys()
        # self.fix_authors()
        # self.fix_year()
        # old_pubval = self.fix_pubkey()
        # if old_pubval:
        #     unknown_pubkeys.append(old_pubval)
        # self.fix_arxiv()
        # self.fix_general()
        # self.fix_paper_types()

        if debug:
            print(ub.repr2(entry, nl=1))
            ut.cprint(' --- END ENTRY ---', 'yellow')
        bibtex_dict[key] = entry

    unwanted_keys = set(bibtex_dict.keys()) - set(key_list)
    if verbose:
        print('Removing %d unwanted entries' % (len(unwanted_keys)))
    ut.delete_dict_keys(bibtex_dict, unwanted_keys)

    if 0:
        d1 = bibtex_dict.copy()
        full = True
        for key, entry in d1.items():
            self = BibTexCleaner(key, entry, full=full)
            pub = self.publication()
            if pub is None:
                print(self.entry['ENTRYTYPE'])

            old = self.fix_pubkey()
            x1 = self._pubval()
            x2 = self.standard_pubval(full=full)
            # if x2 is not None and len(x2) > 5:
            #     print(ub.repr2(self.entry))

            if x1 != x2:
                print('x2 = %r' % (x2, ))
                print('x1 = %r' % (x1, ))
                print(ub.repr2(self.entry))

            # if 'CVPR' in self.entry.get('booktitle', ''):
            #     if 'CVPR' != self.entry.get('booktitle', ''):
            #         break
            if old:
                print('old = %r' % (old, ))
            d1[key] = self.entry

    if full:
        d1 = bibtex_dict.copy()

        import numpy as np
        import pandas as pd
        df = pd.DataFrame.from_dict(d1, orient='index')

        paged_items = df[~pd.isnull(df['pub_accro'])]
        has_pages = ~pd.isnull(paged_items['pages'])
        print('have pages {} / {}'.format(has_pages.sum(), len(has_pages)))
        print(ub.repr2(paged_items[~has_pages]['title'].values.tolist()))

        entrytypes = dict(list(df.groupby('pub_type')))
        if False:
            # entrytypes['misc']
            g = entrytypes['online']
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]

            entrytypes['book']
            entrytypes['thesis']
            g = entrytypes['article']
            g = entrytypes['incollection']
            g = entrytypes['conference']

        def lookup_pub(e):
            if e == 'article':
                return 'journal', 'journal'
            elif e == 'incollection':
                return 'booksection', 'booktitle'
            elif e == 'conference':
                return 'conference', 'booktitle'
            return None, None

        for e, g in entrytypes.items():
            print('e = %r' % (e, ))
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]
            if 'pub_full' in g.columns:
                place_title = g['pub_full'].tolist()
                print(ub.repr2(ub.dict_hist(place_title)))
            else:
                print('Unknown publications')

        if 'report' in entrytypes:
            g = entrytypes['report']
            missing = g[pd.isnull(g['title'])]
            if len(missing):
                print('Missing Title')
                print(ub.repr2(missing[['title', 'author']].values.tolist()))

        if 'journal' in entrytypes:
            g = entrytypes['journal']
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]

            missing = g[pd.isnull(g['journal'])]
            if len(missing):
                print('Missing Journal')
                print(ub.repr2(missing[['title', 'author']].values.tolist()))

        if 'conference' in entrytypes:
            g = entrytypes['conference']
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]

            missing = g[pd.isnull(g['booktitle'])]
            if len(missing):
                print('Missing Booktitle')
                print(ub.repr2(missing[['title', 'author']].values.tolist()))

        if 'incollection' in entrytypes:
            g = entrytypes['incollection']
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]

            missing = g[pd.isnull(g['booktitle'])]
            if len(missing):
                print('Missing Booktitle')
                print(ub.repr2(missing[['title', 'author']].values.tolist()))

        if 'thesis' in entrytypes:
            g = entrytypes['thesis']
            g = g[g.columns[~np.all(pd.isnull(g), axis=0)]]
            missing = g[pd.isnull(g['institution'])]
            if len(missing):
                print('Missing Institution')
                print(ub.repr2(missing[['title', 'author']].values.tolist()))

        # import utool
        # utool.embed()

    # Overwrite BibDatabase structure
    bib_database._entries_dict = bibtex_dict
    bib_database.entries = list(bibtex_dict.values())

    #conftitle_to_types_set_hist = {key: set(val) for key, val in conftitle_to_types_hist.items()}
    #print(ub.repr2(conftitle_to_types_set_hist))

    print('Unknown conference keys:')
    print(ub.repr2(sorted(unknown_pubkeys)))
    print('len(unknown_pubkeys) = %r' % (len(unknown_pubkeys), ))

    writer = BibTexWriter()
    writer.contents = ['comments', 'entries']
    writer.indent = '  '
    writer.order_entries_by = ('type', 'author', 'year')

    new_bibtex_str = bibtexparser.dumps(bib_database, writer)

    # Need to check
    #jegou_aggregating_2012

    # Fix the Journal Abbreviations
    # References:
    # https://www.ieee.org/documents/trans_journal_names.pdf

    # Write out clean bibfile in ascii format
    clean_bib_fpath = ub.augpath(bib_fpath.replace(' ', '_'), suffix='_clean')

    if not ub.argflag('--dryrun'):
        ut.writeto(clean_bib_fpath, new_bibtex_str)
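
The missing-key repair near the end boils down to nearest-match lookup plus an edit-distance check; a self-contained sketch of the same idea using difflib instead of utool's helpers (the keys here are made up):

import difflib

known_keys = ['jegou_aggregating_2012', 'lowe_distinctive_2004']
missing_keys = ['jegou_aggregating_2011']

for key in missing_keys:
    candidates = difflib.get_close_matches(key, known_keys, n=3)
    if candidates:
        top = candidates[0]
        # analogous to the sedlines emitted above
        print("sed -i -e 's/{}/{}/g' *.tex".format(key, top))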