def ancestral_sample(num_samples: int, forest: CFG, tsort: list, edge_weights: dict, inside: dict, root: Symbol) -> tuple:
    """
    Sample derivations top-down from the forest and return the most frequent one.

    Each sample starts at `root` and repeatedly picks one incoming edge per
    nonterminal, with probability proportional to the edge weight times the
    inside weights of the edge's children.

    :param num_samples: number of derivations to draw
    :param forest: the forest; only `forest.get(symbol)` is used
    :param tsort: not used by this function (kept for a uniform signature)
    :param edge_weights: maps each edge to its weight
    :param inside: maps each symbol to its inside weight
    :param root: the symbol every sampled derivation starts from
    :returns: a pair `(derivation, count)` where `derivation` is a list of the
        selected edges and `count` is how often its yield was sampled
    :raises IndexError: if `num_samples` is 0, because there is no most common yield
    """
    samples = []
    for _ in range(num_samples):
        queue = deque([root])
        derivation = []
        while queue:
            symbol = queue.popleft()
            incoming = forest.get(symbol)
            weights = [0.0] * len(incoming)
            # `idx` is deliberately distinct from the sample loop variable
            for idx, edge in enumerate(incoming):
                weights[idx] = edge_weights[edge]
                for u in edge.rhs:  # u in tail(e)
                    weights[idx] *= inside[u]  # TODO: change to log-sum-exp
            probs = np.array(weights) / sum(weights)
            # a single multinomial draw is one-hot, so argmax recovers the drawn index
            index = np.argmax(np.random.multinomial(1, probs))
            selected = incoming[index]
            for sym in selected.rhs:
                if not sym.is_terminal():
                    queue.append(sym)
            derivation.append(selected)
        samples.append(derivation)

    # Derivations are lists (unhashable), so we count one yield per derivation
    # and map the winning yield back to a derivation.
    ys = [write_derrivation(d).pop() for d in samples]
    most_y, counts = Counter(ys).most_common(1)[0]
    # when several samples share a yield, the last sampled derivation is kept
    derivation_by_yield = dict(zip(ys, samples))
    most_sampled = derivation_by_yield[most_y]
    return most_sampled, counts
def predict(cfg: CFG, item: Item) -> list:
    r"""
    Prediction for Earley.

    Inference rule:
        [X -> alpha * Y beta, [r, ..., s]]
        --------------------   (Y -> gamma) \in R
        [Y -> * gamma, [s]]

    R is the ruleset of the grammar.

    :param cfg: the grammar; rules are looked up with `cfg.get(item.next)`
    :param item: an active Item
    :returns: a list with one predicted Item per rule found for `item.next`
        (empty when there is no such rule)
    """
    # the docstring is raw so that "\in" is not parsed as an (invalid) escape sequence
    return [Item(rule, [item.dot]) for rule in cfg.get(item.next)]
def outside_algorithm(forest: CFG, tsort: list, edge_weights: dict, inside: dict, root: Symbol) -> dict:
    """
    Returns the outside weight of each node.

    Nodes are visited in reverse topological order. For every edge entering a
    node, each child receives the edge weight times the node's outside weight
    times the inside weights of the child's siblings on that edge.

    :param forest: the forest; only `forest.get(symbol)` is used
    :param tsort: the nodes in topological order (children before parents)
    :param edge_weights: maps each edge to its weight
    :param inside: maps each symbol to its inside weight
    :param root: the node whose outside weight is fixed to 1.0
    :returns: a dict mapping each node in `tsort` (and `root`) to its outside weight
    """
    outside = {symbol: 0.0 for symbol in tsort}
    outside[root] = 1.0
    for symbol in reversed(tsort):
        for edge in forest.get(symbol):
            for position, u in enumerate(edge.rhs):  # u in tail(e)
                k = edge_weights[edge] * outside[symbol]
                # Siblings are told apart by position, not by symbol equality:
                # in a rule such as X -> A A the other occurrence of A is still
                # a sibling and must contribute its inside weight.
                for other_position, s in enumerate(edge.rhs):
                    if other_position != position:
                        k *= inside[s]  # TODO: change to log-sum-exp
                outside[u] += k
    return outside
def top_sort(forest: CFG) -> list:
    """
    Returns ordered list of nodes according to topsort order in an acyclic forest.

    Children always precede the heads of the rules they occur in. Only
    terminals seed the order, so a symbol whose dependencies are never all
    emitted does not appear in the result. The relative order of nodes that
    are ready at the same time is arbitrary (they are kept in a set).

    :param forest: the forest; uses `forest.terminals`, `forest.nonterminals`,
        `forest.get(symbol)` and iteration over its rules
    :returns: a list of symbols in topological order
    """
    ready = set(forest.terminals)  # (copy!) only terminals start without dependencies
    # for every node in V, the set of children it still waits for
    pending = {
        symbol: {child for rule in forest.get(symbol) for child in rule.rhs}
        for symbol in forest.nonterminals | forest.terminals
    }
    # FS(u), indexed once: heads of the rules that have u in their tail.
    # This replaces a scan over every rule of the forest for each emitted node.
    heads_of = {}
    for rule in forest:
        for child in rule.rhs:
            heads_of.setdefault(child, set()).add(rule.lhs)

    order = []
    while ready:
        u = ready.pop()
        order.append(u)
        for v in heads_of.get(u, ()):
            waiting = pending[v]
            waiting.discard(u)
            if not waiting:
                ready.add(v)
    return order
def inside_algorithm(forest: CFG, tsort: list, edge_weights: dict) -> dict:
    """
    Returns the inside weight of each node.

    Nodes without incoming edges get weight 1.0. Every other node gets the sum,
    over its incoming edges, of the edge weight times the inside weights of the
    edge's children.

    :param forest: the forest; `forest.get(symbol)` gives the rules with `symbol` as lhs
    :param tsort: the nodes ordered so that children come before their parents
    :param edge_weights: maps each edge to its weight
    :returns: a dict mapping each node in `tsort` to its inside weight
    """
    scores = {}

    def edge_score(rule):
        # weight of one edge: w(e) times the inside weight of every node in tail(e)
        product = edge_weights[rule]
        for tail_node in rule.rhs:
            product *= scores[tail_node]  # TODO: change to log-sum-exp
        return product

    for node in tsort:
        backward_star = forest.get(node)  # BS(v): all rules whose lhs is this node
        if len(backward_star) == 0:
            scores[node] = 1.0  # leaves
            continue
        total = 0.0
        for rule in backward_star:
            total += edge_score(rule)
        scores[node] = total
    return scores
def viterbi_log(forest: CFG, tsort: list, edge_weights: dict, inside: dict, root: Symbol) -> list:
    """
    Returns a greedy top-down decoding of the hypergraph in log space.

    Starting at `root`, the best incoming edge of each visited nonterminal is
    selected, where an edge scores its log weight plus the log inside values
    of its children. The nonterminal children of the selected edge are
    visited next.

    Edge weights are added as they are, matching how `inside_algorithm_log`
    consumes `edge_weights`. Exponentiating them before adding log-domain
    inside values would mix the two domains.

    NOTE(review): edges are ranked with the supplied `inside` values; confirm
    with the caller whether those are sum-based or max-based scores.

    :param forest: the forest; only `forest.get(symbol)` is used
    :param tsort: not used by this function (kept for a uniform signature)
    :param edge_weights: maps each edge to its log weight
    :param inside: maps each symbol to its log inside value
    :param root: the symbol decoding starts from
    :returns: the list of selected edges, in the order they were visited
    :raises ValueError: if a visited symbol has no incoming edges
    """
    queue = deque([root])
    derivation = []
    while queue:
        symbol = queue.popleft()
        incoming = forest.get(symbol)
        scores = []
        for edge in incoming:
            score = edge_weights[edge]
            for u in edge.rhs:  # u in tail(e)
                score = score + inside[u]
            scores.append(score)
        # ties are resolved in favour of the first best edge
        _, selected = max(zip(scores, incoming), key=lambda pair: pair[0])
        for sym in selected.rhs:
            if not sym.is_terminal():
                queue.append(sym)
        derivation.append(selected)
    return derivation
def axioms(cfg: CFG, fsa: FSA, s: Symbol) -> list:
    r"""
    Axioms for Earley.

    Inference rule:
        -------------------- (S -> alpha) \in R and q0 \in I
        [S -> * alpha, [q0]]

    R is the rule set of the grammar.
    I is the set of initial states of the automaton.

    :param cfg: a CFG
    :param fsa: an FSA
    :param s: the CFG's start symbol (S)
    :returns: a list of items that are Earley axioms, one per pair of initial
        state and rule found for `s`
    """
    # the docstring is raw so that "\in" is not parsed as an (invalid) escape sequence
    return [Item(rule, [q0]) for q0 in fsa.iterinitial() for rule in cfg.get(s)]
def inside_algorithm_log(forest: CFG, tsort: list, edge_weights: dict) -> dict:
    """
    Returns the log-space inside value of each node.

    Nodes without incoming edges get 0.0. Every other node gets the
    log-sum-exp, over its incoming edges, of the edge weight plus the values
    of the edge's children.

    :param forest: the forest; `forest.get(symbol)` gives the rules with `symbol` as lhs
    :param tsort: the nodes ordered so that children come before their parents
    :param edge_weights: maps each edge to its (additive) weight
    :returns: a dict mapping each node in `tsort` to its log-space inside value
    """
    log_scores = {}

    def edge_log_score(rule):
        # log weight of one edge: w(e) plus the value of every node in tail(e)
        acc = edge_weights[rule]
        for tail_node in rule.rhs:
            acc += log_scores[tail_node]
        return acc

    for node in tsort:
        backward_star = forest.get(node)  # BS(v): all rules whose lhs is this node
        if len(backward_star) == 0:
            log_scores[node] = 0.0  # leaves
            continue
        total = -np.inf  # neutral element of log-sum-exp
        for rule in backward_star:
            total = np.logaddexp(total, edge_log_score(rule))
        log_scores[node] = total
    return log_scores
def write_derrivation(d):
    """
    Enumerate the paths of the FSA built from a single derivation.

    The derivation is wrapped in a CFG and converted with
    `libitg.forest_to_fsa`, using the lhs of its first edge as the root.

    :param d: a derivation, i.e. a non-empty sequence of edges
    :returns: whatever `libitg.enumerate_paths_in_fsa` returns for that FSA
    """
    grammar = CFG(d)
    top_symbol = d[0].lhs
    derivation_fsa = libitg.forest_to_fsa(grammar, top_symbol)
    return libitg.enumerate_paths_in_fsa(derivation_fsa)
def earley(cfg: CFG, fsa: FSA, start_symbol: Symbol, sprime_symbol=None, eps_symbol=Terminal('-EPS-'), clean=True):
    """
    Earley intersection between a CFG and an FSA.

    The function works in two phases. First an agenda is seeded with the
    axioms and inference rules (predict, scan, complete) are applied until no
    active item is left. Then the complete items stored in the agenda are
    converted into rules whose symbols are spans over FSA states.

    :param cfg: a grammar or forest
    :param fsa: an acyclic FSA
    :param start_symbol: the grammar/forest start symbol
    :param sprime_symbol: if specified, the resulting forest will have sprime_symbol as its starting symbol
    :param eps_symbol: if not None, the parser will support epsilon rules
    :param clean: if True, returns a forest without dead edges.
    :returns: a CFG object representing the intersection between the cfg and the fsa
    """
    # NOTE(review): with clean=True the cleanup root is `sprime_symbol`, which
    # defaults to None -- confirm how `cleanup_forest` treats a None root.

    # start an agenda of items
    A = Agenda()

    # this is used to avoid a bit of spurious computation:
    # it stores the (next symbol, dot state) pairs for which prediction already ran
    have_predicted = set()

    # populate the agenda with axioms
    for item in axioms(cfg, fsa, start_symbol):
        A.push(item)

    # call inference rules for as long as we have active items in the agenda
    while len(A) > 0:
        antecedent = A.pop()
        consequents = []
        if antecedent.is_complete():  # dot at the end of rule
            # try to complete other items
            consequents = complete(A, antecedent)
        else:
            if antecedent.next.is_terminal():  # dot before a terminal
                consequents = scan(fsa, antecedent, eps_symbol=eps_symbol)
            else:  # dot before a nonterminal
                if (antecedent.next, antecedent.dot) not in have_predicted:  # test for spurious computation
                    consequents = predict(cfg, antecedent)  # attempt prediction
                    have_predicted.add((antecedent.next, antecedent.dot))
                else:  # we have already predicted in this context, let's attempt completion
                    consequents = complete(A, antecedent)
        # every consequent goes back to the agenda
        for item in consequents:
            A.push(item)
        # mark this antecedent as processed
        A.make_passive(antecedent)

    def iter_intersected_rules():
        """
        Here we convert complete items into CFG rules.
        This is a top-down process where we visit complete items at most once.

        Yields one Rule per accepted complete item and, when `sprime_symbol`
        is set, one extra goal rule per collected top symbol.
        """
        # in the agenda, items are organised by "context" where a context is a tuple (LHS, start state)
        to_do = deque()  # contexts to be processed
        discovered_set = set()  # contexts discovered
        top_symbols = []  # here we store tuples of the kind (start_symbol, initial state, final state)

        # we start with items that rewrite the start_symbol from an initial FSA state
        for q0 in fsa.iterinitial():
            to_do.append((start_symbol, q0))
            # let's mark these as discovered
            discovered_set.add((start_symbol, q0))

        # for as long as there are rules to be discovered
        while to_do:
            nonterminal, start = to_do.popleft()
            # give every complete item matching the context above a chance to yield a rule
            for end, items in A.complete(nonterminal, start):
                for item in items:
                    # create a new LHS symbol based on intersected states
                    lhs = Span(item.lhs, item.start, item.dot)
                    # if LHS is the start_symbol, then we must respect FSA initial/final states
                    # also, we must remember to add a goal rule for this
                    if item.lhs == start_symbol:
                        if not (fsa.is_initial(start) and fsa.is_final(item.dot)):
                            continue  # we discard this item because S can only span from initial to final in FSA
                        else:
                            top_symbols.append(lhs)
                    # create new RHS symbols based on intersected states
                    # and update discovered set
                    rhs = []
                    for i, sym in enumerate(item.rule.rhs):
                        context = (sym, item.state(i))
                        if not sym.is_terminal() and context not in discovered_set:
                            to_do.append(context)  # book this nonterminal context
                            discovered_set.add(context)  # mark as discovered
                        # create a new RHS symbol based on intersected states
                        rhs.append(Span(sym, item.state(i), item.state(i + 1)))
                    yield Rule(lhs, rhs)

        # goal rules: the new start symbol rewrites to every accepted top span
        if sprime_symbol:
            for lhs in top_symbols:
                yield Rule(sprime_symbol, [lhs])

    # return the intersected CFG :)
    out_forest = CFG(iter_intersected_rules())
    if clean:  # possibly cleaning it first
        out_forest = cleanup_forest(out_forest, sprime_symbol)
    return out_forest
def cleanup_forest(forest: CFG, root: Symbol) -> CFG:
    """
    Wrap `iter_useful_edges` and return its edges as a new CFG.

    :param forest: the forest to clean
    :param root: the root symbol handed to `iter_useful_edges`
    :returns: a CFG built from the edges yielded by `iter_useful_edges(forest, root)`
    """
    useful_edges = iter_useful_edges(forest, root)
    return CFG(useful_edges)