def rulebreakdowns(self, limit=10):
    """Print breakdowns for the most frequent rule mismatches."""
    acc = self.acc
    # NB: unary nodes not handled properly
    gmismatch = {(n, indices): rule
            for n, indices, rule in acc.goldrule - acc.candrule}
    wrong = multiset((rule, gmismatch[n, indices])
            for n, indices, rule in acc.candrule - acc.goldrule
            if len(indices) > 1 and (n, indices) in gmismatch)
    print('\n Rewrite rule mismatches (for given span)')
    print(' count cand / gold rules')
    for (crule, grule), cnt in wrong.most_common(limit):
        print(' %7d %s' % (cnt, grammar.printrule(*crule)))
        print(' %7s %s' % (' ', grammar.printrule(*grule)))
    gspans = {(n, indices) for n, indices, _ in acc.goldrule}
    wrong = multiset(rule
            for n, indices, rule in acc.candrule - acc.goldrule
            if len(indices) > 1 and (n, indices) not in gspans)
    print('\n Rewrite rules (span not in gold trees)')
    print(' count rule in candidate parses')
    for crule, cnt in wrong.most_common(limit):
        print(' %7d %s' % (cnt, grammar.printrule(*crule)))
    cspans = {(n, indices) for n, indices, _ in acc.candrule}
    wrong = multiset(rule
            for n, indices, rule in acc.goldrule - acc.candrule
            if len(indices) > 1 and (n, indices) not in cspans)
    print('\n Rewrite rules (span missing from candidate parses)')
    print(' count rule in gold standard set')
    for grule, cnt in wrong.most_common(limit):
        print(' %7d %s' % (cnt, grammar.printrule(*grule)))
def tagbreakdown(self, limit=10):
    """Print breakdowns for the most frequent tags."""
    acc = self.acc
    print('\n Tag Statistics (%s tags / errors)' % (
            ('%d most frequent ' % limit) if limit else 'all'), end='')
    print('\n tag % gold recall prec. F1', ' cand gold count')
    print(' ' + 38 * '_' + 7 * ' ' + 20 * '_')
    tags = multiset(tag for _, tag in acc.goldpos)
    wrong = multiset((c, g)
            for (_, g), (_, c) in zip(acc.goldpos, acc.candpos)
            if g != c)
    for tag, mismatch in zip_longest(tags.most_common(limit),
            wrong.most_common(limit)):
        if tag is None:
            print(''.rjust(40), end='')
        else:
            goldtag = multiset(n for n, (w, t) in enumerate(acc.goldpos)
                    if t == tag[0])
            candtag = multiset(n for n, (w, t) in enumerate(acc.candpos)
                    if t == tag[0])
            print('%s %6.2f %6.2f %6.2f %6.2f' % (
                    tag[0].rjust(7),
                    100 * len(goldtag) / len(acc.goldpos),
                    100 * recall(goldtag, candtag),
                    100 * precision(goldtag, candtag),
                    100 * f_measure(goldtag, candtag)), end='')
        if mismatch is not None:
            print(' %s %7d' % (' '.join((mismatch[0][0].rjust(8),
                    mismatch[0][1].ljust(8))).rjust(12), mismatch[1]),
                    end='')
        print()
def __init__(self, n, gtree, gsent, ctree, csent, param):
    """Construct a pair of gold and candidate trees for evaluation."""
    self.n = n
    self.param = param
    self.gtree, self.ctree = gtree, ctree
    self.csentorig, self.gsentorig = csent, gsent
    self.csent, self.gsent = csent[:], gsent[:]
    self.cpos, self.gpos = sorted(ctree.pos()), sorted(gtree.pos())
    self.lencpos = sum(1 for _, b in self.cpos
            if b not in self.param['DELETE_LABEL_FOR_LENGTH'])
    self.lengpos = sum(1 for _, b in self.gpos
            if b not in self.param['DELETE_LABEL_FOR_LENGTH'])
    if self.lencpos != self.lengpos:
        raise ValueError('sentence length mismatch. sents:\n%s\n%s' % (
                ' '.join(self.csent), ' '.join(self.gsent)))
    grootpos = {child[0] for child in gtree
            if isinstance(child[0], int)}
    # massage the data (in-place modifications)
    transform(self.ctree, self.csent, self.cpos, dict(self.gpos),
            self.param, grootpos)
    transform(self.gtree, self.gsent, self.gpos, dict(self.gpos),
            self.param, grootpos)
    # if not gtree or not ctree:
    #     return dict(LP=0, LR=0, LF=0)
    if self.csent != self.gsent:
        raise ValueError('candidate & gold sentences do not match:\n'
                '%r // %r' % (' '.join(csent), ' '.join(gsent)))
    self.cbrack = bracketings(ctree, self.param['LABELED'],
            self.param['DELETE_LABEL'], self.param['DISC_ONLY'])
    self.gbrack = bracketings(gtree, self.param['LABELED'],
            self.param['DELETE_LABEL'], self.param['DISC_ONLY'])
    self.lascore = self.ted = self.denom = Decimal('nan')
    self.cdep = self.gdep = ()
    self.pgbrack = self.pcbrack = self.grule = self.crule = ()
    if not self.gpos:
        return  # avoid 'sentences' with only punctuation.
    self.lascore = leafancestor(self.gtree, self.ctree,
            self.param['DELETE_LABEL'])
    if self.param['TED']:
        self.ted, self.denom = treedisteval(self.gtree, self.ctree,
                includeroot=self.gtree.label
                    not in self.param['DELETE_LABEL'])
    if self.param['DEP']:
        self.cdep = dependencies(self.ctree, self.param['HEADRULES'])
        self.gdep = dependencies(self.gtree, self.param['HEADRULES'])
    assert self.lascore != 1 or self.gbrack == self.cbrack, (
            'leaf ancestor score 1.0 but no exact match: (bug?)')
    self.pgbrack = parentedbracketings(self.gtree, labeled=True,
            dellabel=self.param['DELETE_LABEL'],
            disconly=self.param['DISC_ONLY'])
    self.pcbrack = parentedbracketings(self.ctree, labeled=True,
            dellabel=self.param['DELETE_LABEL'],
            disconly=self.param['DISC_ONLY'])
    self.grule = multiset((node.indices, rule)
            for node, rule in zip(self.gtree.subtrees(),
                grammar.lcfrsproductions(self.gtree, self.gsent)))
    self.crule = multiset((node.indices, rule)
            for node, rule in zip(self.ctree.subtrees(),
                grammar.lcfrsproductions(self.ctree, self.csent)))
def getunknownwordmodel(tagged_sents, unknownword,
        unknownthreshold, openclassthreshold):
    """Compute an unknown word model that smooths lexical probabilities
    for unknown & rare words.

    :param tagged_sents: the sentences from the training set with the
        gold POS tags from the treebank.
    :param unknownword: a function that returns a signature for a given
        word; e.g., "eschewed" => "_UNK-L-d".
    :param unknownthreshold: words with a frequency lower than or equal
        to this are replaced by their signature.
    :param openclassthreshold: tags that rewrite to at least this many
        word types are considered to be open class categories.
    """
    wordsfortag = defaultdict(set)
    tags = multiset()
    wordtags = multiset()
    sigs = multiset()
    sigtag = multiset()
    words = multiset(word for sent in tagged_sents for word, tag in sent)
    lexicon = {word for word, freq in words.items()
            if freq > unknownthreshold}
    wordsig = {}
    for sent in tagged_sents:
        for n, (word, tag) in enumerate(sent):
            wordsfortag[tag].add(word)
            tags[tag] += 1
            wordtags[word, tag] += 1
            sig = unknownword(word, n, lexicon)
            wordsig[word] = sig  # NB: sig may also depend on n and lexicon
            sigtag[sig, tag] += 1
    if openclassthreshold:
        openclasstags = {tag: len({w.lower() for w in ws})
                for tag, ws in wordsfortag.items()
                if len({w.lower() for w in ws}) >= openclassthreshold}
        closedclasstags = set(tags) - set(openclasstags)
        closedclasswords = {word for tag in closedclasstags
                for word in wordsfortag[tag]}
        openclasswords = lexicon - closedclasswords
        # add rare closed-class words back to lexicon
        lexicon.update(closedclasswords)
    else:
        openclasstags = {}
        openclasswords = set()
        closedclasstags = set()  # referenced below; avoid a NameError
    for sent in tagged_sents:
        for n, (word, _) in enumerate(sent):
            if word not in lexicon:
                sig = unknownword(word, n, lexicon)
                sigs[sig] += 1
    msg = "known words: %d, signature types seen: %d\n" % (
            len(lexicon), len(sigs))
    msg += "open class tags: {%s}\n" % ", ".join(
            "%s:%d" % a for a in openclasstags.items())
    msg += "closed class tags: {%s}" % ", ".join(a for a in closedclasstags)
    return (sigs, words, lexicon, wordsfortag, openclasstags,
            openclasswords, tags, wordtags, wordsig, sigtag), msg
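
# A minimal usage sketch for getunknownwordmodel (hypothetical toy data;
# ``simplesig`` below is a stand-in for a real signature function and is
# not part of this codebase):
def _demo_unknownwordmodel():
    def simplesig(word, n, lexicon):
        # coarse shape signature: capitalized vs. lowercase unknowns
        return '_UNK-C' if word[0].isupper() else '_UNK'

    tagged = [[('the', 'DT'), ('cat', 'NN')],
            [('the', 'DT'), ('dog', 'NN'), ('barked', 'VBD')]]
    (sigs, words, lexicon, _, openclasstags, _, _, _, _, _), msg = \
            getunknownwordmodel(tagged, simplesig,
            unknownthreshold=1, openclassthreshold=2)
    # 'cat', 'dog', and 'barked' each occur once (<= threshold), so they
    # are excluded from the lexicon and counted under their signatures.
    print(msg)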
def treebankgrammar(trees, sents):
    """Induce a probabilistic LCFRS with relative frequencies of
    productions.

    When trees contain no discontinuities, the result is equivalent to a
    treebank PCFG.
    """
    grammar = multiset(rule for tree, sent in zip(trees, sents)
            for rule in lcfrs_productions(tree, sent))
    lhsfd = multiset()
    for rule, freq in grammar.items():
        lhsfd[rule[0][0]] += freq
    for rule, freq in grammar.items():
        grammar[rule] = Fraction(freq, lhsfd[rule[0][0]])
    return sortgrammar(grammar.items())
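
# Worked example of the relative-frequency estimate (hypothetical toy
# counts): if the productions with LHS 'NP' occur with frequencies
# {NP -> DT NN: 3, NP -> NN: 1}, then lhsfd['NP'] == 4 and the two rules
# receive Fraction(3, 4) and Fraction(1, 4); the probabilities for each
# LHS sum to 1 by construction.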
def __init__(self, disconly=False):
    """:param disconly: if True, only collect discontinuous
    bracketings."""
    self.disconly = disconly
    self.maxlenseen, self.sentcount = Decimal(0), Decimal(0)
    self.exact = Decimal(0)
    self.dicenoms, self.dicedenoms = Decimal(0), Decimal(0)
    self.goldb, self.candb = multiset(), multiset()  # all brackets
    self.lascores = []
    self.golddep, self.canddep = [], []
    self.goldpos, self.candpos = [], []
    # extra accounting for breakdowns:
    self.goldbcat = defaultdict(multiset)  # brackets per category
    self.candbcat = defaultdict(multiset)
    self.goldbatt, self.candbatt = set(), set()  # attachments per category
    self.goldrule, self.candrule = multiset(), multiset()
def bracketings(tree, labeled=True, dellabel=(), disconly=False):
    """Return the labeled set of bracketings for a tree.

    For each nonterminal node, the set will contain a tuple with the
    label and the set of terminals which it dominates.
    ``tree`` must have been processed by ``transform()``.
    The argument ``dellabel`` is only used to exclude the ROOT node from
    the results (because it cannot be deleted by ``transform()`` when
    non-unary).

    >>> tree = Tree.parse('(S (NP 1) (VP (VB 0) (JJ 2)))', parse_leaf=int)
    >>> params = {'DELETE_LABEL': set(), 'DELETE_WORD': set(),
    ...         'EQ_LABEL': {}, 'EQ_WORD': {},
    ...         'DELETE_ROOT_PRETERMS': 0}
    >>> transform(tree, tree.leaves(), tree.pos(), dict(tree.pos()),
    ...         params, set())
    >>> sorted(bracketings(tree).items())
    [(('S', (0, 1, 2)), 1), (('VP', (0, 2)), 1)]
    >>> tree = Tree.parse('(S (NP 1) (VP (VB 0) (JJ 2)))', parse_leaf=int)
    >>> params['DELETE_LABEL'] = {'VP'}
    >>> transform(tree, tree.leaves(), tree.pos(), dict(tree.pos()),
    ...         params, set())
    >>> bracketings(tree)
    Counter({('S', (0, 1, 2)): 1})"""
    return multiset(bracketing(a, labeled) for a in tree.subtrees()
            if a and isinstance(a[0], Tree)  # nonempty, not a preterminal
            and a.label not in dellabel
            and (not disconly or disc(a)))
def catbreakdown(self, limit=10):
    """Print breakdowns for the most frequent labels."""
    acc = self.acc
    print('\n Attachment errors (correct labeled bracketing, wrong parent)')
    print(' label cand gold count')
    print(' ' + 33 * '_')
    gmismatch = dict(acc.goldbatt - acc.candbatt)
    # items are (label, candidate parent, gold parent)
    wrong = multiset((label, cparent, gmismatch[n, label, indices])
            for (n, label, indices), cparent
            in acc.candbatt - acc.goldbatt
            if (n, label, indices) in gmismatch)
    for (cat, cparent, gparent), cnt in wrong.most_common(limit):
        print('%s %s %s %7d' % (cat.rjust(7), cparent.rjust(7),
                gparent.rjust(7), cnt))
    print('\n Category Statistics (%s categories / errors)' % (
            ('%d most frequent ' % limit) if limit else 'all'))
    print(' label % gold recall prec. F1', ' cand gold count')
    print(' ' + 38 * '_' + 7 * ' ' + 24 * '_')
    gmismatch = {(n, indices): label
            for n, (label, indices) in acc.goldb - acc.candb}
    wrong = multiset((label, gmismatch[n, indices])
            for n, (label, indices) in acc.candb - acc.goldb
            if (n, indices) in gmismatch)
    freqcats = sorted(set(acc.goldbcat) | set(acc.candbcat),
            key=lambda x: len(acc.goldbcat[x]), reverse=True)
    for cat, mismatch in zip_longest(freqcats[:limit],
            wrong.most_common(limit)):
        if cat is None:
            print(39 * ' ', end='')
        else:
            print('%s %6.2f %s %s %s' % (
                    cat.rjust(7),
                    100 * sum(acc.goldbcat[cat].values()) / len(acc.goldb),
                    nozerodiv(lambda: recall(
                        acc.goldbcat[cat], acc.candbcat[cat])),
                    nozerodiv(lambda: precision(
                        acc.goldbcat[cat], acc.candbcat[cat])),
                    nozerodiv(lambda: f_measure(
                        acc.goldbcat[cat], acc.candbcat[cat])),
                    ), end='')
        if mismatch is not None:
            print(' %s %7d' % (' '.join((mismatch[0][0].rjust(8),
                    mismatch[0][1].ljust(8))), mismatch[1]), end='')
        print()
def treebankgrammar(trees, sents, extrarules=None):
    """Induce a probabilistic LCFRS with relative frequencies of
    productions.

    When trees contain no discontinuities, the result is equivalent to a
    treebank PCFG.

    :param extrarules: A dictionary of productions that will be merged
        with the grammar, with (pseudo)frequencies as values."""
    grammar = multiset(rule for tree, sent in zip(trees, sents)
            for rule in lcfrsproductions(tree, sent))
    if extrarules is not None:
        for rule in extrarules:
            grammar[rule] += extrarules[rule]
    lhsfd = multiset()
    for rule, freq in grammar.items():
        lhsfd[rule[0][0]] += freq
    return sortgrammar((rule, (freq, lhsfd[rule[0][0]]))
            for rule, freq in grammar.items())
def read_from_export(filename, param, delete_pos, encoding):
    """Read a signature from an export-format file.

    Returns a dict which maps sentence numbers to lists of
    bracketings."""
    # will be returned
    signatures = {}
    pos_tags = {}
    # for reading export data
    within_sentence = False
    sentence = []
    # stores tuples of terminals dominated by node
    tuples_by_nodenum = {}
    for line in io.open(filename, 'r', encoding=encoding):
        line = line.strip()
        if not within_sentence:
            if line.startswith("#BOS"):
                within_sentence = True
                sentence.append(line)
        else:
            sentence.append(line)
            if line.startswith("#EOS"):
                # complete sentence collected, process it
                within_sentence = False
                # get the sentence number from the EOS line
                sent_num = int(line.split(None, 1)[1])
                # remove BOS and EOS lines, split lines into fields
                sentence = [export_split(a) for a in sentence[1:-1]]
                # extract all non-terminal labels
                labels_by_nodenum = {int(fields[NODE][1:]): fields[TAG]
                        for fields in sentence
                        if fields[NODE].startswith('#')}
                labels_by_nodenum[0] = u"VROOT"
                # initialize bracketing store for this sentence
                signatures[sent_num] = multiset()
                pos_tags[sent_num] = set()
                # get the non-terminal labels and the terminals which the
                # corresponding nodes dominate
                export_process_sentence(sentence, labels_by_nodenum,
                        tuples_by_nodenum, pos_tags[sent_num], 0,
                        param, delete_pos)
                for nodenum in tuples_by_nodenum:
                    # the label of a nonterminal
                    label = tuples_by_nodenum[nodenum][0]
                    # the terminals dominated by this nonterminal
                    terminals = tuples_by_nodenum[nodenum][1]
                    # only add non-deleted, non-empty bracketings
                    if (label not in param.get('DELETE_LABEL', ())
                            and terminals):
                        bracketing = Bracketing(
                                label if param['LABELED'] else 'X',
                                terminals)
                        signatures[sent_num][bracketing] += 1
                # reset
                sentence = []
                tuples_by_nodenum = {}
    return signatures, pos_tags
def parentedbracketings(tree, labeled=True, dellabel=(), disconly=False):
    """Return the labeled bracketings with parents for a tree.

    :returns: multiset with items of the form
        ``((label, indices), parentlabel)``
    """
    return multiset((bracketing(a, labeled),
                getattr(a.parent, 'label', ''))
            for a in tree.subtrees()
            if a and isinstance(a[0], Tree)  # nonempty, not a preterminal
            and a.label not in dellabel
            and (not disconly or disc(a)))
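
# For example, for the tree (S (NP 1) (VP (VB 0) (JJ 2))) processed as in
# the bracketings() doctest above, this should yield something along the
# lines of (a sketch; the root gets '' since it has no parent):
#     Counter({(('S', (0, 1, 2)), ''): 1, (('VP', (0, 2)), 'S'): 1})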
def grammarinfo(grammar, dump=None):
    """Print some statistics on a grammar, before it goes through
    Grammar().

    :param dump: if given a filename, will dump distribution of parsing
        complexity to a file (i.e., p.c. 3 occurs 234 times, 4 occurs
        120 times, etc.)
    """
    from discodop.eval import mean
    lhs = {rule[0] for (rule, yf), w in grammar}
    l = len(grammar)
    result = "labels: %d" % len({rule[a] for (rule, yf), w in grammar
            for a in range(3) if len(rule) > a})
    result += " of which preterminals: %d\n" % (
            len({rule[0] for (rule, yf), w in grammar
                if rule[1] == 'Epsilon'})
            or len({rule[a] for (rule, yf), w in grammar
                for a in range(1, 3)
                if len(rule) > a and rule[a] not in lhs}))
    ll = sum(1 for (rule, yf), w in grammar if rule[1] == 'Epsilon')
    result += "clauses: %d lexical clauses: %d" % (l, ll)
    result += " non-lexical clauses: %d\n" % (l - ll)
    n, r, yf, w = max((len(yf), rule, yf, w) for (rule, yf), w in grammar)
    result += "max fan-out: %d in " % n
    result += printrule(r, yf, w)
    result += " average: %g\n" % mean([len(yf) for (_, yf), _ in grammar])
    n, r, yf, w = max((sum(map(len, yf)), rule, yf, w)
            for (rule, yf), w in grammar if rule[1] != 'Epsilon')
    result += "max variables: %d in %s\n" % (n, printrule(r, yf, w))

    def parsingcomplexity(yf):
        """Sum the fan-outs of LHS & RHS."""
        if isinstance(yf[0], tuple):
            return len(yf) + sum(map(len, yf))
        return 1  # NB: a lexical production has complexity 1

    pc = {(rule, yf, w): parsingcomplexity(yf)
            for (rule, yf), w in grammar}
    r, yf, w = max(pc, key=pc.get)
    result += "max parsing complexity: %d in %s" % (
            pc[r, yf, w], printrule(r, yf, w))
    result += " average %g" % mean(pc.values())
    if dump:
        pcdist = multiset(pc.values())
        with open(dump, "w") as out:
            out.writelines("%d\t%d\n" % x for x in pcdist.items())
    return result
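
# Worked example of parsingcomplexity (hypothetical rule): for a binary
# LCFRS rule with yield function yf = ((0, 1), (0, 1)), the LHS has
# fan-out len(yf) == 2 and the RHS occupies sum(map(len, yf)) == 4
# variables, giving a parsing complexity of 2 + 4 == 6; a lexical
# production (yf like ('word',)) counts as 1.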
def entrypoint(this_path, json_path):
    t1 = time.time()
    ds = routines.dicts_from_json_path(json_path)
    authors = multiset()
    for d in ds:
        if "author" in d:
            authors.update([d["author"]])
        else:
            # Counter.update(None) is a silent no-op; count documents
            # without an author under a placeholder key instead
            authors.update(["<no author>"])
    print("AUTHORS:")
    maxlen = max((len(a) for a in authors), default=0)
    lst = sorted(authors.items(), key=lambda p: p[1], reverse=True)
    for a, n in lst:
        a = f'"{a}"'
        print(f" {a.ljust(maxlen + 2)} : {n}")
    print("Total execution time: {:.3f} seconds.".format(time.time() - t1))
def evaluate(key, answer, param, encoding):
    """Initiate evaluation of answer file against key file (gold)."""
    # read signature from key file
    key_sig, key_tags = read_from_export(key, param,
            "DELETE_LABEL_FOR_LENGTH", encoding)
    # read signature from answer file
    answer_sig, answer_tags = read_from_export(answer, param,
            "DELETE_LABEL", encoding)
    if not key_sig:
        raise ValueError("no sentences in key")
    if len(answer_sig) > len(key_sig):
        raise ValueError("more sentences in answer than key")
    print("""\
 sent. prec. rec. F1 match gold test words matched tags
====================================================================""")
    # missing sentences
    missing = 0
    # number of matching bracketings (labeled)
    total_match = 0
    # total number of brackets in key
    total_key = 0
    # total number of brackets in answer
    total_answer = 0
    total_exact = 0
    total_words = total_matched_pos = total_sents = 0
    # get all sentence numbers from gold
    for sent_num in sorted(key_sig):
        if len(key_tags[sent_num]) > param['CUTOFF_LEN']:
            continue
        total_sents += 1
        sent_match = 0
        # get bracketings for key
        # there must be something for every sentence in key
        if sent_num not in key_sig:
            raise ValueError("no data for sent. %d in key" % sent_num)
        key_sent_sig = key_sig[sent_num]
        # get bracketings for answer
        if sent_num in answer_sig:
            answer_sent_sig = answer_sig[sent_num]
        else:
            answer_sent_sig = multiset()
            missing += 1
        # compute matching brackets
        sent_match = sum((key_sent_sig & answer_sent_sig).values())
        if key_sent_sig == answer_sent_sig:
            total_exact += 1
        sent_prec = 0.0
        if len(answer_sent_sig) > 0:
            sent_prec = 100 * sent_match / len(answer_sent_sig)
        sent_rec = 0.0
        if len(key_sent_sig) > 0:
            sent_rec = 100 * sent_match / len(key_sent_sig)
        sent_fb1 = 0.0
        if sent_prec + sent_rec > 0:
            sent_fb1 = 2 * sent_prec * sent_rec / (sent_prec + sent_rec)
        tag_match = len(key_tags[sent_num] & answer_tags[sent_num])
        print("%4d %6.2f %6.2f %6.2f %3d %3d %3d %3d %3d" % (
                sent_num, sent_prec, sent_rec, sent_fb1, sent_match,
                len(key_sent_sig), len(answer_sent_sig),
                len(answer_tags[sent_num]), tag_match))
        total_match += sent_match
        total_key += len(key_sent_sig)
        total_answer += len(answer_sent_sig)
        total_matched_pos += tag_match
        total_words += len(key_tags[sent_num])
    prec = 0.0
    if total_answer > 0:
        prec = 100 * total_match / total_answer
    rec = 0.0
    if total_key > 0:
        rec = 100 * total_match / total_key
    fb1 = 0.0
    if prec + rec > 0:
        fb1 = 2 * prec * rec / (prec + rec)
    labeled = ('unlabeled', 'labeled')[param['LABELED']]
    print("===========================================================")
    print()
    print()
    print("Summary (%s, <= %d):" % (labeled, param['CUTOFF_LEN']))
    print("===========================================================")
    print()
    print("Sentences in key".ljust(30), ":", total_sents)
    print("Sentences missing in answer".ljust(30), ":", missing)
    print()
    print("Total edges in key".ljust(30), ":", total_key)
    print("Total edges in answer".ljust(30), ":", total_answer)
    print("Total matching edges".ljust(30), ":", total_match)
    print()
    print("POS : %6.2f %%" % (100 * total_matched_pos / total_words))
    print("%sP : %6.2f %%" % (labeled[0].upper(), prec))
    print("%sR : %6.2f %%" % (labeled[0].upper(), rec))
    print("%sF1 : %6.2f %%" % (labeled[0].upper(), fb1))
    print("EX : %6.2f %%" % (100 * total_exact / total_sents))
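
# Worked example of the summary metrics (hypothetical totals): with
# total_match = 80, total_key = 100, and total_answer = 90,
#   prec = 100 * 80 / 90  ~ 88.89
#   rec  = 100 * 80 / 100 = 80.00
#   fb1  = 2 * prec * rec / (prec + rec) ~ 84.21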
def doparsing(**kwds):
    """Parse a set of sentences using worker processes."""
    params = DictObj(usetags=True, numproc=None, tailmarker='',
            category=None, deletelabel=(), deleteword=(),
            corpusfmt='export')
    params.update(kwds)
    goldbrackets = multiset()
    totaltokens = 0
    results = [DictObj(name=stage.name) for stage in params.parser.stages]
    for result in results:
        result.update(elapsedtime=dict.fromkeys(params.testset),
                parsetrees=dict.fromkeys(params.testset),
                brackets=multiset(), tagscorrect=0, exact=0, noparse=0)
    if params.numproc == 1:
        initworker(params)
        dowork = (worker(a) for a in params.testset.items())
    else:
        pool = multiprocessing.Pool(processes=params.numproc,
                initializer=initworker, initargs=(params,))
        dowork = pool.imap_unordered(worker, params.testset.items())
    logging.info('going to parse %d sentences.', len(params.testset))
    # main parse loop over each sentence in test corpus
    for nsent, data in enumerate(dowork, 1):
        sentid, msg, sentresults = data
        sent, goldtree, goldsent, _ = params.testset[sentid]
        logging.debug('%d/%d (%s). [len=%d] %s\n%s', nsent,
                len(params.testset), sentid, len(sent),
                ' '.join(a[0] for a in goldsent), msg)
        evaltree = goldtree.copy(True)
        evalmod.transform(evaltree, [w for w, _ in sent], evaltree.pos(),
                dict(evaltree.pos()), params.deletelabel,
                params.deleteword, {}, {})
        goldb = evalmod.bracketings(evaltree, dellabel=params.deletelabel)
        goldbrackets.update((sentid, (label, span))
                for label, span in goldb.elements())
        totaltokens += sum(1 for _, t in goldsent
                if t not in params.deletelabel)
        for n, result in enumerate(sentresults):
            results[n].brackets.update((sentid, (label, span))
                    for label, span in result.candb.elements())
            assert (results[n].parsetrees[sentid] is None
                    and results[n].elapsedtime[sentid] is None)
            results[n].parsetrees[sentid] = result.parsetree
            results[n].elapsedtime[sentid] = result.elapsedtime
            if result.noparse:
                results[n].noparse += 1
            if result.exact:
                results[n].exact += 1
            results[n].tagscorrect += sum(1 for (_, a), (_, b)
                    in zip(goldsent, sorted(result.parsetree.pos()))
                    if b not in params.deletelabel and a == b)
            logging.debug(
                    '%s cov %5.2f tag %5.2f ex %5.2f '
                    'lp %5.2f lr %5.2f lf %5.2f%s',
                    result.name.ljust(7),
                    100 * (1 - results[n].noparse / nsent),
                    100 * (results[n].tagscorrect / totaltokens),
                    100 * (results[n].exact / nsent),
                    100 * evalmod.precision(goldbrackets,
                        results[n].brackets),
                    100 * evalmod.recall(goldbrackets,
                        results[n].brackets),
                    100 * evalmod.f_measure(goldbrackets,
                        results[n].brackets),
                    ('' if n + 1 < len(sentresults) else '\n'))
    if params.numproc != 1:
        pool.terminate()
        pool.join()
        del dowork, pool
    writeresults(results, params)
    return results, goldbrackets
def breakdowns(param, goldb, candb, goldpos, candpos, goldbcat, candbcat,
        maxlenseen):
    """Print breakdowns for the most frequent labels / tags."""
    if param['LABELED'] and param['DEBUG'] != -1:
        print()
        print(' Category Statistics (10 most frequent categories'
                ' / errors)', end='')
        if maxlenseen > param['CUTOFF_LEN']:
            print(' for length <= %d' % param['CUTOFF_LEN'], end='')
        print()
        print(' label % gold recall prec. F1', ' test/gold count')
        print('_______________________________________',
                ' ____________________')
        gmismatch = {(n, indices): label
                for n, (label, indices) in (goldb - candb)}
        wrong = multiset((label, gmismatch[n, indices])
                for n, (label, indices) in (candb - goldb)
                if (n, indices) in gmismatch)
        freqcats = sorted(set(goldbcat) | set(candbcat),
                key=lambda x: len(goldbcat[x]), reverse=True)
        for cat, mismatch in zip_longest(freqcats[:10],
                wrong.most_common(10)):
            if cat is None:
                print(' ', end='')
            else:
                print('%s %6.2f %s %s %s' % (
                        cat.rjust(7),
                        100 * sum(goldbcat[cat].values()) / len(goldb),
                        nozerodiv(lambda: recall(goldbcat[cat],
                            candbcat[cat])),
                        nozerodiv(lambda: precision(goldbcat[cat],
                            candbcat[cat])),
                        nozerodiv(lambda: f_measure(goldbcat[cat],
                            candbcat[cat])),
                        ), end='')
            if mismatch is not None:
                print(' %s %7d' % (
                        '/'.join(mismatch[0]).rjust(12), mismatch[1]),
                        end='')
            print()
    if accuracy(goldpos, candpos) != 1:
        print()
        print(' Tag Statistics (10 most frequent tags / errors)', end='')
        if maxlenseen > param['CUTOFF_LEN']:
            print(' for length <= %d' % param['CUTOFF_LEN'], end='')
        print('\n tag % gold recall prec. F1', ' test/gold count')
        print('_______________________________________',
                ' ____________________')
        tags = multiset(tag for _, tag in goldpos)
        wrong = multiset((c, g)
                for (_, g), (_, c) in zip(goldpos, candpos) if g != c)
        for tag, mismatch in zip_longest(tags.most_common(10),
                wrong.most_common(10)):
            if tag is None:
                print(''.rjust(40), end='')
            else:
                goldtag = multiset(n for n, (w, t) in enumerate(goldpos)
                        if t == tag[0])
                candtag = multiset(n for n, (w, t) in enumerate(candpos)
                        if t == tag[0])
                print('%s %6.2f %6.2f %6.2f %6.2f' % (
                        tag[0].rjust(7),
                        100 * len(goldtag) / len(goldpos),
                        100 * recall(goldtag, candtag),
                        100 * precision(goldtag, candtag),
                        100 * f_measure(goldtag, candtag)), end='')
            if mismatch is not None:
                print(' %s %7d' % (
                        '/'.join(mismatch[0]).rjust(12), mismatch[1]),
                        end='')
            print()
    print()
def parsetepacoc(
        stages=(dict(mode='pcfg', split=True, markorigin=True),
            dict(mode='plcfrs', prune=True, k=10000, splitprune=True),
            dict(mode='plcfrs', prune=True, k=5000, dop=True,
                usedoubledop=True, estimator='dop1', objective='mpp',
                sample=False, kbest=True)),
        trainmaxwords=999, trainnumsents=25005, testmaxwords=999,
        bintype='binarize', h=1, v=1, factor='right', tailmarker='',
        markhead=False, revmarkov=False, pospa=False,
        leftmostunary=True, rightmostunary=True,
        fanout_marks_before_bin=False, transformations=None,
        usetagger='stanford', resultdir='tepacoc', numproc=1):
    """Parse the tepacoc test set."""
    for stage in stages:
        for key in stage:
            assert key in DEFAULTSTAGE, 'unrecognized option: %r' % key
    stages = [DictObj({k: stage.get(k, v)
            for k, v in DEFAULTSTAGE.items()}) for stage in stages]
    os.mkdir(resultdir)
    # Log everything, and send it to stderr, in a format with just the
    # message.
    formatstr = '%(message)s'
    logging.basicConfig(level=logging.DEBUG, format=formatstr)
    # log up to INFO to a results log file
    fileobj = logging.FileHandler(filename='%s/output.log' % resultdir)
    fileobj.setLevel(logging.INFO)
    fileobj.setFormatter(logging.Formatter(formatstr))
    logging.getLogger('').addHandler(fileobj)
    tepacocids, tepacocsents = readtepacoc()
    try:
        (corpus_sents, corpus_taggedsents,
                corpus_trees, corpus_blocks) = pickle.load(
                gzip.open('tiger.pickle.gz', 'rb'))
    except IOError:  # file not found
        corpus = getreader('export')('../tiger/corpus',
                'tiger_release_aug07.export',
                headrules='negra.headrules'
                    if bintype == 'binarize' else None,
                headfinal=True, headreverse=False, punct='move',
                encoding='iso-8859-1')
        corpus_sents = list(corpus.sents().values())
        corpus_taggedsents = list(corpus.tagged_sents().values())
        corpus_trees = list(corpus.parsed_sents().values())
        if transformations:
            corpus_trees = [transform(tree, sent, transformations)
                    for tree, sent in zip(corpus_trees, corpus_sents)]
        corpus_blocks = list(corpus.blocks().values())
        pickle.dump((corpus_sents, corpus_taggedsents, corpus_trees,
                corpus_blocks), gzip.open('tiger.pickle.gz', 'wb'),
                protocol=-1)
    # test sets (one for each category)
    testsets = {}
    allsents = []
    for cat, catsents in tepacocsents.items():
        testset = sents, trees, goldsents, blocks = [], [], [], []
        for n, sent in catsents:
            if sent != corpus_sents[n]:
                logging.error(
                        'mismatch. sent %d:\n%r\n%r\n'
                        'not in corpus %r\nnot in tepacoc %r',
                        n + 1, sent, corpus_sents[n],
                        [a for a, b in zip_longest(sent, corpus_sents[n])
                            if a and a != b],
                        [b for a, b in zip_longest(sent, corpus_sents[n])
                            if b and a != b])
            elif len(corpus_sents[n]) <= testmaxwords:
                sents.append(corpus_taggedsents[n])
                trees.append(corpus_trees[n])
                goldsents.append(corpus_taggedsents[n])
                blocks.append(corpus_blocks[n])
        allsents.extend(sents)
        logging.info('category: %s, %d of %d sentences',
                cat, len(testset[0]), len(catsents))
        testsets[cat] = testset
    testsets['baseline'] = zip(*[sent for n, sent in
            enumerate(zip(corpus_taggedsents, corpus_trees,
                corpus_taggedsents, corpus_blocks))
            if len(sent[1]) <= trainmaxwords
            and n not in tepacocids][trainnumsents:trainnumsents + 2000])
    allsents.extend(testsets['baseline'][0])
    if usetagger:
        overridetags = ('PTKANT', 'VAIMP')
        taglex = defaultdict(set)
        for sent in corpus_taggedsents[:trainnumsents]:
            for word, tag in sent:
                taglex[word].add(tag)
        overridetagdict = {tag:
                {word for word, tags in taglex.items() if tags == {tag}}
                for tag in overridetags}
        tagmap = {'$(': '$[', 'PAV': 'PROAV', 'PIDAT': 'PIAT'}
        # the sentences in the list allsents are modified in-place so
        # that the relevant copy in testsets[cat][0] is updated as well.
        externaltagging(usetagger, '', allsents, overridetagdict, tagmap)
    # training set
    trees, sents, blocks = zip(*[sent for n, sent in
            enumerate(zip(corpus_trees, corpus_sents, corpus_blocks))
            if len(sent[1]) <= trainmaxwords
            and n not in tepacocids][:trainnumsents])
    getgrammars(trees, sents, stages, bintype, h, v, factor, tailmarker,
            revmarkov, leftmostunary, rightmostunary, pospa, markhead,
            fanout_marks_before_bin, testmaxwords, resultdir, numproc,
            None, False, trees[0].label, None)
    del corpus_sents, corpus_taggedsents, corpus_trees, corpus_blocks
    results = {}
    cnt = 0
    parser = Parser(stages, tailmarker=tailmarker,
            transformations=transformations)
    for cat, testset in sorted(testsets.items()):
        if cat == 'baseline':
            continue
        logging.info('category: %s', cat)
        begin = time.clock()
        results[cat] = doparsing(parser=parser, testset=testset,
                resultdir=resultdir, usetags=True, numproc=numproc,
                category=cat)
        cnt += len(testset[0])
        if numproc == 1:
            logging.info('time elapsed during parsing: %g',
                    time.clock() - begin)
        # else:  # wall clock time here
    goldbrackets = multiset()
    totresults = [DictObj(name=stage.name) for stage in stages]
    for result in totresults:
        result.elapsedtime = [None] * cnt
        result.parsetrees = [None] * cnt
        result.brackets = multiset()
        result.exact = result.noparse = 0
    goldblocks = []
    goldsents = []
    for cat, res in results.items():
        logging.info('category: %s', cat)
        goldbrackets |= res[2]
        goldblocks.extend(res[3])
        goldsents.extend(res[4])
        for result, totresult in zip(res[0], totresults):
            totresult.exact += result.exact
            totresult.noparse += result.noparse
            totresult.brackets |= result.brackets
            totresult.elapsedtime.extend(result.elapsedtime)
        oldeval(*res)
    logging.info('TOTAL')
    oldeval(totresults, goldbrackets)
    # write TOTAL results file with all tepacoc sentences
    # (not the baseline)
    for stage in stages:
        open('TOTAL.%s.export' % stage.name, 'w').writelines(
                open('%s.%s.export' % (cat, stage.name)).read()
                for cat in list(results) + ['gold'])
    # do baseline separately because it shouldn't count towards the
    # total score
    cat = 'baseline'
    logging.info('category: %s', cat)
    oldeval(*doparsing(parser=parser, testset=testsets[cat],
            resultdir=resultdir, usetags=True, numproc=numproc,
            category=cat))
def doubledop(trees, fragments, debug=False):
    """Extract a Double-DOP grammar from a treebank.

    That is, a fragment grammar containing all fragments that occur at
    least twice, plus all individual productions needed to obtain full
    coverage. Input trees need to be binarized.
    A second level of binarization (a normal form) is needed when
    fragments are converted to individual grammar rules, which occurs
    through the removal of internal nodes. The binarization adds unique
    identifiers so that each grammar rule can be mapped back to its
    fragment. In fragments with terminals, we replace their POS tags
    with a tag uniquely identifying that terminal and tag: tag@word.

    :returns: a tuple (grammar, altweights, backtransform)
        altweights is a dictionary containing alternate weights.
    """
    def getweight(frag, terminals):
        """:returns: frequency, EWE, and other weights for fragment."""
        freq = sum(fragments[frag, terminals].values())
        root = frag[1:frag.index(' ')]
        nonterms = frag.count('(') - 1
        # Sangati & Zuidema (2011, eq. 5)
        # FIXME: verify that this formula is equivalent to Bod (2003).
        ewe = sum(Fraction(v, fragmentcount[k])
                for k, v in fragments[frag, terminals].items())
        # Bonnema (2003, p. 34)
        bon = 2 ** -nonterms * (freq / ntfd[root])
        short = 0.5
        return freq, ewe, bon, short

    uniformweight = (1, 1, 1, 1)
    grammar = {}
    backtransform = {}
    ids = UniqueIDs()
    # build index of the number of fragments extracted from a tree,
    # for ewe
    fragmentcount = defaultdict(int)
    for indices in fragments.values():
        for index, cnt in indices.items():
            fragmentcount[index] += cnt
    # ntfd: frequency of a non-terminal node in treebank
    ntfd = multiset(node.label for tree in trees
            for node in tree.subtrees())
    # binarize, turn to lcfrs productions
    # use artificial markers of binarization as disambiguation,
    # construct a mapping of productions to fragments
    for frag, terminals in fragments:
        prods, newfrag = flatten(frag, terminals, ids)
        prod = prods[0]
        if prod[0][1] == 'Epsilon':  # lexical production
            grammar[prod] = getweight(frag, terminals)
            continue
        elif prod in backtransform:
            # normally, rules of fragments are disambiguated by
            # binarization IDs; in case there's a fragment with only one
            # or two frontier nodes, we add an artificial node.
            newlabel = "%s}<%d>%s" % (prod[0][0], next(ids),
                    '' if len(prod[1]) == 1 else '_%d' % len(prod[1]))
            prod1 = ((prod[0][0], newlabel) + prod[0][2:], prod[1])
            # we have to determine fanout of the first nonterminal
            # on the right hand side
            prod2 = ((newlabel, prod[0][1]),
                    tuple((0,) for component in prod[1]
                        for a in component if a == 0))
            prods[:1] = [prod1, prod2]
        # first binarized production gets prob. mass
        grammar[prod] = getweight(frag, terminals)
        grammar.update(zip(prods[1:], repeat(uniformweight)))
        # & becomes key in backtransform
        backtransform[prod] = newfrag
    if debug:
        ids = count()
        flatfrags = [flatten(frag, terminals, ids)
                for frag, terminals in fragments]
        print("recurring fragments:")
        for a, b in zip(flatfrags, fragments):
            print("fragment: %s\nprod: %s" % (b[0], "\n\t".join(
                    printrule(r, yf, 0) for r, yf in a[0])))
            print("template: %s\nfreq: %2d sent: %s\n" % (
                    a[1], len(fragments[b]),
                    ' '.join('_' if x is None else quotelabel(x)
                        for x in b[1])))
        print("backtransform:")
        for a, b in backtransform.items():
            print(a, b)
    # fix order of grammar rules; backtransform will mirror this order
    grammar = sortgrammar(grammar.items())
    # replace keys with numeric ids of rules, drop terminals.
    backtransform = [backtransform[r] for r, _ in grammar
            if r in backtransform]
    # relative frequencies as probabilities
    # (don't normalize shortest & bon)
    ntsums = defaultdict(int)
    ntsumsewe = defaultdict(int)
    for rule, (freq, ewe, _, _) in grammar:
        ntsums[rule[0][0]] += freq
        ntsumsewe[rule[0][0]] += ewe
    eweweights = [float(ewe) / ntsumsewe[rule[0][0]]
            for rule, (_, ewe, _, _) in grammar]
    bonweights = [bon for rule, (_, _, bon, _) in grammar]
    shortest = [s for rule, (_, _, _, s) in grammar]
    grammar = [(rule, Fraction(freq, ntsums[rule[0][0]]))
            for rule, (freq, _, _, _) in grammar]
    return grammar, backtransform, dict(
            ewe=eweweights, bon=bonweights, shortest=shortest)
def doeval(gold_trees, gold_sents, cand_trees, cand_sents, param):
    """Do the actual evaluation on given parse trees and parameters.

    Results are printed to standard output."""
    assert gold_trees, 'no trees in gold file'
    assert cand_trees, 'no trees in parses file'
    keylen = max(len(str(x)) for x in cand_trees)
    if param['DEBUG'] == 1:
        print('Parameters:')
        for a in param:
            print('%s\t%s' % (a, param[a]))
        for a in HEADER:
            print(' ' * (keylen - 4) + a)
        print('', '_' * ((keylen - 5) + len(HEADER[-1])))
    # the suffix '40' is for the length restricted results
    maxlenseen = sentcount = maxlenseen40 = sentcount40 = 0
    goldb = multiset()
    candb = multiset()
    goldb40 = multiset()
    candb40 = multiset()
    goldbcat = defaultdict(multiset)
    candbcat = defaultdict(multiset)
    goldbcat40 = defaultdict(multiset)
    candbcat40 = defaultdict(multiset)
    lascores = []
    lascores40 = []
    golddep = []
    canddep = []
    golddep40 = []
    canddep40 = []
    goldpos = []
    candpos = []
    goldpos40 = []
    candpos40 = []
    exact = exact40 = 0.0
    dicenoms = dicedenoms = dicenoms40 = dicedenoms40 = 0
    for n, ctree in cand_trees.items():
        gtree = gold_trees[n]
        cpos = sorted(ctree.pos())
        gpos = sorted(gtree.pos())
        csent = [w for w, _ in cand_sents[n]]
        gsent = [w for w, _ in gold_sents[n]]
        lencpos = sum(1 for _, b in cpos
                if b not in param['DELETE_LABEL_FOR_LENGTH'])
        lengpos = sum(1 for _, b in gpos
                if b not in param['DELETE_LABEL_FOR_LENGTH'])
        assert lencpos == lengpos, ('sentence length mismatch. '
                'sents:\n%s\n%s' % (' '.join(csent), ' '.join(gsent)))
        # massage the data (in-place modifications)
        transform(ctree, csent, cpos, dict(gpos), param['DELETE_LABEL'],
                param['DELETE_WORD'], param['EQ_LABEL'], param['EQ_WORD'])
        transform(gtree, gsent, gpos, dict(gpos), param['DELETE_LABEL'],
                param['DELETE_WORD'], param['EQ_LABEL'], param['EQ_WORD'])
        # if not gtree or not ctree:
        #     continue
        assert csent == gsent, ('candidate & gold sentences do not '
                'match:\n%r // %r' % (' '.join(csent), ' '.join(gsent)))
        cbrack = bracketings(ctree, param['LABELED'],
                param['DELETE_LABEL'], param['DISC_ONLY'])
        gbrack = bracketings(gtree, param['LABELED'],
                param['DELETE_LABEL'], param['DISC_ONLY'])
        if not param['DISC_ONLY'] or cbrack or gbrack:
            sentcount += 1
        # this is to deal with 'sentences' with only a single
        # punctuation mark.
        if not gpos:
            continue
        if maxlenseen < lencpos:
            maxlenseen = lencpos
        if cbrack == gbrack:
            if not param['DISC_ONLY'] or cbrack or gbrack:
                exact += 1
        candb.update((n, a) for a in cbrack.elements())
        goldb.update((n, a) for a in gbrack.elements())
        for a in gbrack:
            goldbcat[a[0]][(n, a)] += 1
        for a in cbrack:
            candbcat[a[0]][(n, a)] += 1
        goldpos.extend(gpos)
        candpos.extend(cpos)
        lascores.append(leafancestor(gtree, ctree,
                param['DELETE_LABEL']))
        if param['TED']:
            ted, denom = treedisteval(gtree, ctree,
                    includeroot=gtree.label not in param['DELETE_LABEL'])
            dicenoms += ted
            dicedenoms += denom
        if param['DEP']:
            cdep = dependencies(ctree, param['HEADRULES'])
            gdep = dependencies(gtree, param['HEADRULES'])
            canddep.extend(cdep)
            golddep.extend(gdep)
        if lencpos <= param['CUTOFF_LEN']:
            if not param['DISC_ONLY'] or cbrack or gbrack:
                sentcount40 += 1
            if maxlenseen40 < lencpos:
                maxlenseen40 = lencpos
            candb40.update((n, a) for a in cbrack.elements())
            goldb40.update((n, a) for a in gbrack.elements())
            for a in gbrack:
                goldbcat40[a[0]][(n, a)] += 1
            for a in cbrack:
                candbcat40[a[0]][(n, a)] += 1
            if cbrack == gbrack:
                if not param['DISC_ONLY'] or cbrack or gbrack:
                    exact40 += 1
            goldpos40.extend(gpos)
            candpos40.extend(cpos)
            if lascores[-1] is not None:
                lascores40.append(lascores[-1])
            if param['TED']:
                dicenoms40 += ted
                dicedenoms40 += denom
            if param['DEP']:
                canddep40.extend(cdep)
                golddep40.extend(gdep)
        assert lascores[-1] != 1 or gbrack == cbrack, (
                'leaf ancestor score 1.0 but no exact match: (bug?)')
        if lascores[-1] is None:
            del lascores[-1]
        if param['DEBUG'] <= 0:
            continue
        if param['DEBUG'] > 1:
            for a in HEADER:
                print(' ' * (keylen - 4) + a)
            print('', '_' * ((keylen - 5) + len(HEADER[-1])))
        print(('%' + str(keylen)
                + 's %5d %s %s %5d %5d %5d %5d %4d %s %6.2f%s%s') % (
                n,
                lengpos,
                nozerodiv(lambda: recall(gbrack, cbrack)),
                nozerodiv(lambda: precision(gbrack, cbrack)),
                sum((gbrack & cbrack).values()),
                sum(gbrack.values()),
                sum(cbrack.values()),
                len(gpos),
                sum(1 for a, b in zip(gpos, cpos) if a == b),
                nozerodiv(lambda: accuracy(gpos, cpos)),
                100 * lascores[-1],
                str(ted).rjust(3) if param['TED'] else '',
                nozerodiv(lambda: accuracy(gdep, cdep))
                    if param['DEP'] else ''))
        if param['DEBUG'] > 1:
            print('Sentence:', ' '.join(gsent))
            print('Gold tree:\n%s\nCandidate tree:\n%s' % (
                    DrawTree(gtree, gsent, abbr=True).text(
                        unicodelines=True, ansi=True),
                    DrawTree(ctree, csent, abbr=True).text(
                        unicodelines=True, ansi=True)))
            print('Gold brackets: %s\nCandidate brackets: %s' % (
                    strbracketings(gbrack), strbracketings(cbrack)))
            print('Matched brackets: %s\nUnmatched brackets: %s' % (
                    strbracketings(gbrack & cbrack),
                    strbracketings((cbrack - gbrack)
                        | (gbrack - cbrack))))
            goldpaths = leafancestorpaths(gtree, param['DELETE_LABEL'])
            candpaths = leafancestorpaths(ctree, param['DELETE_LABEL'])
            for leaf in goldpaths:
                print('%6.3g %s %s : %s' % (
                        pathscore(goldpaths[leaf], candpaths[leaf]),
                        gsent[leaf].ljust(15),
                        ' '.join(goldpaths[leaf][::-1]).rjust(20),
                        ' '.join(candpaths[leaf][::-1])))
            print('%6.3g average = leaf-ancestor score' % lascores[-1])
            print('POS: ', ' '.join('%s/%s' % (a[1], b[1])
                    for a, b in zip(cpos, gpos)))
            if param['TED']:
                print('Tree-dist: %g / %g = %g' % (
                        ted, denom, 1 - ted / denom))
                newtreedist(gtree, ctree, True)
            if param['DEP']:
                print('Sentence:', ' '.join(gsent))
                print('dependencies gold', ' ' * 35, 'cand')
                for (_, a, b), (_, c, d) in zip(gdep, cdep):
                    # use original sentences because we don't delete
                    # punctuation for dependency evaluation
                    print('%15s -> %15s %15s -> %15s' % (
                            gold_sents[n][a - 1][0],
                            gold_sents[n][b - 1][0],
                            cand_sents[n][c - 1][0],
                            cand_sents[n][d - 1][0]))
            print()
    breakdowns(param, goldb40, candb40, goldpos40, candpos40,
            goldbcat40, candbcat40, maxlenseen)
    msg = summary(param, goldb, candb, goldpos, candpos, sentcount,
            maxlenseen, exact, lascores, dicenoms, dicedenoms, golddep,
            canddep, goldb40, candb40, goldpos40, candpos40, sentcount40,
            maxlenseen40, exact40, lascores40, dicenoms40, dicedenoms40,
            golddep40, canddep40)
    return msg
def doubledop(trees, sents, debug=False, binarized=True, complement=False,
        iterate=False, numproc=None, extrarules=None):
    """Extract a Double-DOP grammar from a treebank.

    That is, a fragment grammar containing fragments that occur at least
    twice, plus all individual productions needed to obtain full
    coverage. Input trees need to be binarized. A second level of
    binarization (a normal form) is needed when fragments are converted
    to individual grammar rules, which occurs through the removal of
    internal nodes. The binarization adds unique identifiers so that
    each grammar rule can be mapped back to its fragment. In fragments
    with terminals, we replace their POS tags with a tag uniquely
    identifying that terminal and tag: ``tag@word``.

    :param binarized: Whether the resulting grammar should be binarized.
    :param iterate, complement, numproc: cf. fragments.getfragments()
    :returns: a tuple (grammar, altweights, backtransform)
        altweights is a dictionary containing alternate weights."""
    def getweight(frag, terminals):
        """:returns: frequency, EWE, and other weights for fragment."""
        freq = sum(fragments[frag, terminals].values())
        root = frag[1:frag.index(' ')]
        nonterms = frag.count('(') - 1
        # Sangati & Zuidema (2011, eq. 5)
        # FIXME: verify that this formula is equivalent to Bod (2003).
        ewe = sum(v / fragmentcount[k]
                for k, v in fragments[frag, terminals].items())
        # Bonnema (2003, p. 34)
        bon = 2 ** -nonterms * (freq / ntfd[root])
        short = 0.5
        return freq, ewe, bon, short

    from discodop.fragments import getfragments
    uniformweight = (1, 1, 1, 1)
    grammar = {}
    backtransform = {}
    ids = UniqueIDs()
    fragments = getfragments(trees, sents, numproc,
            iterate=iterate, complement=complement)
    # build index of the number of fragments extracted from a tree,
    # for ewe
    fragmentcount = defaultdict(int)
    for indices in fragments.values():
        for index, cnt in indices.items():
            fragmentcount[index] += cnt
    # ntfd: frequency of a non-terminal node in treebank
    ntfd = multiset(node.label for tree in trees
            for node in tree.subtrees())
    # binarize, turn into LCFRS productions
    # use artificial markers of binarization as disambiguation,
    # construct a mapping of productions to fragments
    for origfrag in fragments:
        frag, terminals = origfrag
        prods, newfrag = flatten(frag, terminals, ids, backtransform,
                binarized)
        prod = prods[0]
        if prod[0][1] == 'Epsilon':  # lexical production
            grammar[prod] = getweight(frag, terminals)
            continue
        # first binarized production gets prob. mass
        grammar[prod] = getweight(frag, terminals)
        grammar.update(zip(prods[1:], repeat(uniformweight)))
        # & becomes key in backtransform
        backtransform[prod] = origfrag, newfrag
    if debug:
        ids = count()
        flatfrags = [flatten(frag, terminals, ids, {}, binarized)
                for frag, terminals in fragments]
        print("recurring fragments:")
        for a, b in zip(flatfrags, fragments):
            print("fragment: %s\nprod: %s" % (b[0], "\n\t".join(
                    printrule(r, yf, 0) for r, yf in a[0])))
            print("template: %s\nfreq: %2d sent: %s\n" % (
                    a[1], len(fragments[b]),
                    ' '.join('_' if x is None else quotelabel(x)
                        for x in b[1])))
        print("backtransform:")
        for a, b in backtransform.items():
            print(a, b)
    if extrarules is not None:
        for rule in extrarules:
            x = extrarules[rule]
            a = b = c = 0
            if rule in grammar:
                a, b, c, _ = grammar[rule]
            grammar[rule] = (a + x, b + x, c + x, 0.5)
    # fix order of grammar rules
    grammar = sortgrammar(grammar.items())
    # align fragments and backtransform with corresponding grammar rules
    fragments = OrderedDict((frag, fragments[frag])
            for frag in (backtransform[rule][0]
                for rule, _ in grammar if rule in backtransform))
    backtransform = [backtransform[rule][1]
            for rule, _ in grammar if rule in backtransform]
    # relative frequencies as probabilities
    # (don't normalize shortest & bon)
    ntsums = defaultdict(int)
    ntsumsewe = defaultdict(int)
    for rule, (freq, ewe, _, _) in grammar:
        ntsums[rule[0][0]] += freq
        ntsumsewe[rule[0][0]] += ewe
    eweweights = [float(ewe) / ntsumsewe[rule[0][0]]
            for rule, (_, ewe, _, _) in grammar]
    bonweights = [bon for rule, (_, _, bon, _) in grammar]
    shortest = [s for rule, (_, _, _, s) in grammar]
    grammar = [(rule, (freq, ntsums[rule[0][0]]))
            for rule, (freq, _, _, _) in grammar]
    return grammar, backtransform, dict(
            ewe=eweweights, bon=bonweights, shortest=shortest), fragments
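
# Worked example of getweight (hypothetical counts): suppose the fragment
# '(NP (DT 0) (NN 1))' has fragments[frag, terminals] == {5: 2, 9: 1}
# (tree index -> count), fragmentcount[5] == 10, fragmentcount[9] == 5,
# and ntfd['NP'] == 20. Then:
#   freq  = 2 + 1 = 3
#   ewe   = 2/10 + 1/5 = 0.4            (Sangati & Zuidema 2011, eq. 5)
#   bon   = 2**-2 * (3 / 20) = 0.0375   (nonterms = '('-count - 1 = 2)
#   short = 0.5                          (constant weight per fragment)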
def apply_rule(self, obs1, obs2, obs3):
    # majority vote: take the most common label across the elements of
    # all three observations
    label, count = multiset(obs1.elements + obs2.elements
            + obs3.elements).most_common()[0]
    return Observation(label, self.confidence_boost(obs1, obs2, obs3),
            number_observations=count)
#!/usr/bin/env python3
from itertools import product, combinations, starmap
from collections import Counter as multiset
from tabulate import tabulate

# Configure
MIN_N, MAX_N = 1, 50

table = [["N", "ob(N)", "sb(N)", "quotient", "branch lengths",
        "orbits", "fundamental", "err"]]
for N, odd_N in [(N, N % 2) for N in range(MIN_N, MAX_N + 1)]:
    ob_branches, sb_branches, lengths = multiset(), set(), multiset()

    # Identify the middle rows and columns of the board, based on
    # natural 1..N coordinate-indices
    indices = tuple(range(1, N + 1))
    # middle/median indices (single if odd_N else pair)
    middle = tuple(indices[(N - 1) // 2 : N // 2 + 1])
    intersection = frozenset(product(middle, middle))
    rows = ((product(indices, middle),) if odd_N
            else (product(indices, middle[:1]),
                product(indices, middle[1:])))
    cols = ((product(middle, indices),) if odd_N
            else (product(middle[:1], indices),
                product(middle[1:], indices)))
    numrc = len(rows) + len(cols)

    # Key Functions
    def legal(branch):
        return ((len(branch) == numrc
                or (len(branch) == numrc - 1
                    and len(branch & intersection)))
                and all((x, y) == (a, b)
                    or (x != a and y != b
                        and x + y != a + b and x - y != a - b)
                    for (x, y), (a, b) in combinations(branch, 2)))

    def symmetries(squares):
        mirror = lambda x, y: (N - x + 1, y)
        rotate = lambda x, y: (N - y + 1, x)
        mirrors = lambda squares: starmap(mirror, squares)
        rotates = lambda squares: starmap(rotate, squares)
        return {
            frozenset(squares),
            frozenset(mirrors(squares)),
            frozenset(mirrors(rotates(squares))),
            frozenset(mirrors(rotates(rotates(squares)))),
            frozenset(mirrors(rotates(rotates(rotates(squares))))),
            frozenset(rotates(squares)),
            frozenset(rotates(rotates(squares))),
            frozenset(rotates(rotates(rotates(squares)))),
        }

    def lexo(branch):
        return tuple(sorted(map(sorted, symmetries(branch)))[0])
def test_condition(self, obs1, obs2, obs3):
    # compare the element multisets of the overlapping pairs
    # (obs1, obs2) and (obs2, obs3) via Jaccard similarity
    mset1 = multiset(obs1.elements + obs2.elements)
    mset2 = multiset(obs2.elements + obs3.elements)
    return (self._compute_jaccard_similarity(mset1, mset2)
            > self.similarity_distance_threshold)
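
# ``_compute_jaccard_similarity`` is referenced but not defined in this
# excerpt. A minimal sketch of a plausible implementation over Counter
# multisets (an assumption, not necessarily the actual method):
#
#     def _compute_jaccard_similarity(self, mset1, mset2):
#         """Multiset Jaccard: |mset1 & mset2| / |mset1 | mset2|."""
#         union = sum((mset1 | mset2).values())
#         return sum((mset1 & mset2).values()) / union if union else 0.0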