def induce_structure(self, sentences):
    """Induce a PCFG from raw sentences (sequences of characters).

    Builds unary rules for each distinct terminal (grouping all digits
    under a single ``Digit`` nonterminal), then repeatedly introduces
    new productions (recursive ones first when recursion is detected)
    until every sentence reduces to a single symbol, and finally adds
    start productions for each reduced symbol.

    Sets ``self.grammar`` to the induced PCFG.

    :param sentences: iterable of strings/sequences; each is split into
        its individual characters as terminals.
    """
    sentences = [[c for c in s] for s in sentences]
    start_symbols = set()
    productions = []
    prod_table = {}

    # Group all digits together under one nonterminal.
    digit_terminals = {str(i) for i in range(10)}

    # Unary rules: one production per distinct terminal.
    terminals = set()
    for s in sentences:
        terminals.update(s)
    for t in terminals:
        if t in digit_terminals:
            nt = nltk.Nonterminal("Digit")
        else:
            nt = nltk.Nonterminal("Unary%s" % self.gen_nt())
        p = Production(nt, [t])
        productions.append(p)
        prod_table[tuple(p.rhs())] = p.lhs()
    sentences = self.apply_unary_prod(sentences, prod_table)

    # Repeatedly introduce productions until every sentence is reduced
    # to a single symbol (each such symbol becomes a start candidate).
    while len(sentences) > 0:
        if self.has_recursion(sentences):
            p = self.generate_recursive_prod(sentences)
        else:
            p = self.generate_most_frequent_prod(sentences)
        productions.append(p)
        prod_table[tuple(p.rhs())] = p.lhs()
        sentences = self.update_with_prod(sentences, prod_table)
        new_sentences = []
        for s in sentences:
            if len(s) == 1:
                start_symbols.add(s[0])
            else:
                new_sentences.append(s)
        sentences = new_sentences

    # Generate the start productions.
    # BUG FIX: the original appended to `productions` while iterating it,
    # which can re-visit newly appended items; collect the new start
    # productions separately and extend afterwards.
    start_prods = []
    for symbol in start_symbols:
        for p in productions:
            if p.lhs() == symbol:
                start_prods.append(Production(self.start, p.rhs()))
    productions.extend(start_prods)
    self.grammar = nltk.induce_pcfg(self.start, productions)
def train():
    """Train a PCFG parser on a sub-corpus of the Penn Treebank.

    Extracts CNF parse trees from ``nltk.corpus.treebank``, builds a
    capped vocabulary, estimates rule probabilities by relative
    frequency (lower-casing lexical rules), applies Katz smoothing,
    and indexes the rules for parsing.

    :returns: dict with keys ``vocab``, ``left_rules``, ``right_rules``,
        ``unary_rules``, ``rules_to_prob``, ``terminal_nonterms``.
    """
    print("Collecting sub-corpus from Penn Treebank (nltk.corpus)")

    # Prepare parse trees extracted from the treebank, converted to CNF.
    tbank_trees = []
    for sent in treebank.parsed_sents():
        sent.chomsky_normal_form()
        tbank_trees.append(sent)

    # Build the vocabulary list from the treebank words.
    vocab_size = 10000  # cap vocabulary at the 10000 most frequent words
    words = [wrd.lower() for wrd in treebank.words()]
    # BUG FIX: count the lower-cased words so the vocabulary matches the
    # lower-cased lexical productions built below (the original counted
    # the raw case-sensitive words and left `words` unused).
    vocab = [wrd for wrd, freq in Counter(words).most_common(vocab_size)]

    # Generate the grammar-rule list extracted from the treebank; the CFG
    # is constructed for validation (its production list is not reused).
    tbank_productions = set(
        production for tree in tbank_trees for production in tree.productions()
    )
    tbank_grammar = CFG(Nonterminal('S'), list(tbank_productions))

    # Estimate each rule's probability from its relative frequency,
    # lower-casing lexical (terminal-producing) rules.
    rules_to_prob = defaultdict(int)
    nonterm_occurrence = defaultdict(int)
    for sent in tbank_trees:
        for production in sent.productions():
            if len(production.rhs()) == 1 and not isinstance(production.rhs()[0], Nonterminal):
                production = Production(production.lhs(), [production.rhs()[0].lower()])
            nonterm_occurrence[production.lhs()] += 1
            rules_to_prob[production] += 1
    for rule in rules_to_prob:
        rules_to_prob[rule] /= nonterm_occurrence[rule.lhs()]

    # Use Katz smoothing.
    rules_to_prob, vocab = katz_smooth(rules_to_prob, vocab)

    rules = list(rules_to_prob.keys())
    left_rules = defaultdict(set)
    right_rules = defaultdict(set)
    unary_rules = defaultdict(set)
    # Index binary rules by their left/right RHS symbol; unary rules by
    # their single RHS symbol.
    for rule in rules:
        if len(rule.rhs()) > 1:
            left_rules[rule.rhs()[0]].add(rule)
            right_rules[rule.rhs()[1]].add(rule)
        else:
            unary_rules[rule.rhs()[0]].add(rule)

    # Count, per nonterminal, how many terminal-producing rules it has.
    terminal_nonterms_rules = set(
        rule for rule in rules_to_prob
        if len(rule.rhs()) == 1 and isinstance(rule.rhs()[0], str)
    )
    terminal_nonterms = defaultdict(int)
    for rule in terminal_nonterms_rules:
        terminal_nonterms[rule.lhs()] += 1

    pcfg_parser = {
        'vocab': vocab,
        'left_rules': left_rules,
        'right_rules': right_rules,
        'unary_rules': unary_rules,
        'rules_to_prob': rules_to_prob,
        'terminal_nonterms': terminal_nonterms,
    }
    return pcfg_parser
def _remove_empty_productions(input_productions, letters):
    """Remove productions with empty right-hand sides.

    First computes the set of nonterminals that derive the empty string
    (directly, then transitively to a fixed point), then adds new
    productions with those nonterminals elided from binary RHSs, and
    finally drops every production whose RHS is empty.

    :param input_productions: list of nltk ``Production`` objects;
        not modified (work is done on deep copies).
    :param letters: accepted for interface compatibility; unused here.
    :returns: new list of productions with no empty RHS.
    """
    copied_prods = deepcopy(input_productions)

    # Basis: a nonterminal generates the empty string if it is the LHS
    # of a production whose RHS is empty.
    gen_empty = [prod.lhs() for prod in copied_prods if len(prod.rhs()) == 0]
    N = len(gen_empty)

    # Induction: strip empty-generating nonterminals out of RHSs until
    # no new empty-generating nonterminal appears (fixed point).
    while True:
        for nonterm in gen_empty:
            for prod in copied_prods:
                if nonterm in prod.rhs():
                    better = list(prod.rhs())
                    better.remove(nonterm)
                    # NOTE(review): mutates nltk's private attribute;
                    # only safe because these are throwaway copies.
                    prod._rhs = tuple(better)
        gen_empty[:] = [prod.lhs() for prod in copied_prods if len(prod.rhs()) == 0]
        new_len = len(gen_empty)
        if new_len == N:
            break
        N = new_len
    # BUG FIX: was a Python 2 print statement (`print 'gen_empty', ...`),
    # a SyntaxError under Python 3, which the rest of this file targets.
    print('gen_empty', gen_empty)

    # ADD NEW RULES: for every binary production containing an
    # empty-generating nonterminal, add variants with it removed.
    new_prods = []
    productions = deepcopy(input_productions)
    for nonterm in gen_empty:
        prods = [prod for prod in productions
                 if len(prod.rhs()) == 2 and nonterm in prod.rhs()]
        for prod in prods:
            rhs = list(prod.rhs())
            while nonterm in rhs:
                lhs = prod.lhs()
                rhs.remove(nonterm)
                p = Production(lhs, tuple(rhs))
                new_prods.append(p)
    productions += new_prods
    # Finally drop all productions whose RHS is (now) empty.
    productions[:] = [p for p in productions if p.rhs()]
    return productions