def __init__(self, grammar=grammar_zinc_new, checks=False):
    self.term_dist = {}
    self.d_term_dist = {}
    self.grammar = grammar
    self.GCFG = self.grammar.GCFG
    self.checks = checks
    # terminals have term distance 0
    for p in self.GCFG.productions():
        for s in p.rhs():
            if is_terminal(s):
                self.term_dist[frozendict({'token': s})] = 0
    self.term_dist[frozendict({'token': Nonterminal('None')})] = 0
    # seed the search with the root symbol
    self.term_dist[frozendict({'token': Nonterminal('smiles')})] = float('inf')
    while True:  # iterate to convergence
        last_term_dist = copy.copy(self.term_dist)
        for sym in last_term_dist.keys():
            if is_terminal(sym['token']):
                self.term_dist[sym] = 0
            if self.term_dist[sym] > 0:
                mask = self.get_mask_from_token(sym)
                if self.checks:
                    assert not all(x == 0 for x in mask)
                for ip, p in enumerate(self.GCFG.productions()):
                    if mask[ip]:
                        this_exp = apply_rule([sym], 0, p, None, self.checks)
                        this_term_dist = 1
                        for this_sym in this_exp:
                            if frozendict(this_sym) not in self.term_dist:
                                self.term_dist[frozendict(this_sym)] = float('inf')
                                print('added ', this_sym, 'from', sym, 'via', p)
                            this_term_dist += self.term_dist[frozendict(this_sym)]
                        if this_term_dist < self.term_dist[frozendict(sym)]:
                            print('improving:', p, self.term_dist[frozendict(sym)], this_term_dist,
                                  [self.term_dist[frozendict(this_sym)] for this_sym in this_exp])
                            self.term_dist[frozendict(sym)] = this_term_dist
        if last_term_dist == self.term_dist:
            break
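# The loop above is a Bellman-Ford-style fixpoint: term_dist[sym] ends up as the
# minimum number of rule applications needed to expand sym into terminals only.
# A self-contained sketch of the same relaxation on a toy grammar (hypothetical
# rules, independent of the class above):
toy_rules = {'S': [['A', 'a'], ['a']], 'A': [['S', 'S']]}
toy_dist = {'a': 0, 'S': float('inf'), 'A': float('inf')}
changed = True
while changed:
    changed = False
    for lhs, expansions in toy_rules.items():
        for exp in expansions:
            d = 1 + sum(toy_dist[s] for s in exp)  # one rule application plus children
            if d < toy_dist[lhs]:
                toy_dist[lhs], changed = d, True
print(toy_dist)  # {'a': 0, 'S': 1, 'A': 3}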
def format_grammar(task, primitives, encoders=None):
    grammar = create_completegrammar(primitives)
    grammar = create_taskgrammar(grammar, task, encoders or [])
    formatted_grammar = {'NON_TERMINALS': {}, 'TERMINALS': {}, 'RULES': {}, 'RULES_LOOKUP': {}}
    formatted_grammar['START'] = grammar.start().symbol()
    terminals = []
    logger.info('Formatting grammar to style of pipeline game')
    for production in grammar.productions():
        non_terminal = production.lhs().symbol()
        production_str = str(production).replace("'", '')
        formatted_grammar['RULES'][production_str] = len(formatted_grammar['RULES']) + 1
        if non_terminal not in formatted_grammar['NON_TERMINALS']:
            formatted_grammar['NON_TERMINALS'][non_terminal] = len(formatted_grammar['NON_TERMINALS']) + 1
        if non_terminal not in formatted_grammar['RULES_LOOKUP']:
            formatted_grammar['RULES_LOOKUP'][non_terminal] = []
        formatted_grammar['RULES_LOOKUP'][non_terminal].append(production_str)
        for token in production.rhs():
            if is_terminal(token) and token != 'E' and token not in terminals:
                terminals.append(token)
    formatted_grammar['TERMINALS'] = {t: i + len(formatted_grammar['NON_TERMINALS'])
                                      for i, t in enumerate(terminals, 1)}
    formatted_grammar['TERMINALS']['E'] = 0  # special case for the empty symbol
    return formatted_grammar
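# Illustrative shape of the returned dict (hypothetical task and primitives; the
# actual rules depend on create_completegrammar/create_taskgrammar):
#
# format_grammar('classification_task', my_primitives) ->
# {'START': 'S',
#  'NON_TERMINALS': {'S': 1, 'ESTIMATORS': 2, ...},   # IDs start at 1
#  'TERMINALS': {'svc': 3, ..., 'E': 0},              # IDs continue after the nonterminals
#  'RULES': {'S -> ESTIMATORS': 1, ...},              # quotes stripped from terminals
#  'RULES_LOOKUP': {'S': ['S -> ESTIMATORS'], ...}}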
def main(args):
    sentence = args.sentence.lower()
    args.sentence = sentence
    tokens = sentence.split()
    grammar = loadGrammar(args)
    nonterm = getnonterm(grammar)
    terminalProductionRules = getTerminalProbability(args, grammar, nonterm)

    # Replace the HS and ES rules with the estimated terminal productions.
    # Iterate over copies, since we are removing from the live production list.
    for rule in list(grammar.productions(Nonterminal('HS'))):
        grammar.productions().remove(rule)
    for rule in list(grammar.productions(Nonterminal('ES'))):
        grammar.productions().remove(rule)
    grammar.productions().extend(terminalProductionRules)
    for token in tokens:
        grammar.productions().append(
            ProbabilisticProduction(Nonterminal(token.upper()), [str(token)], prob=1))

    # Serialize the modified grammar and re-parse it, so the new productions
    # are properly indexed by the PCFG.
    strgrammar = ''
    for p in grammar.productions():
        rhsstr = ''
        for r in p.rhs():
            if is_terminal(r):
                rhsstr += "'" + str(r) + "' "
            else:
                rhsstr += str(r) + ' '
        strgrammar += str(p.lhs()) + ' -> ' + rhsstr + ' [' + '{0:.8f}'.format(p.prob()) + ']\n'
    grammar = PCFG.fromstring(strgrammar.split('\n'))

    CYK(tokens, nonterm, grammar)
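# The serialize-and-reparse step above works because PCFG.fromstring accepts a
# list of production lines. A minimal round-trip demonstrating the same idea:
from nltk import PCFG

g = PCFG.fromstring("S -> 'a' [1.0]")
g2 = PCFG.fromstring(str(g).split('\n')[1:])  # drop the "Grammar with ..." header line
print(g2.productions())  # [S -> 'a' [1.0]]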
def convert_hybrid(grammar):
    '''
    Convert rules of the form [A -> 'b' C], where the rhs mixes
    non-terminals and terminals, into rules of the form
    [A -> B C] & [B -> 'b'] with a dummy non-terminal B.
    '''
    rules = grammar.productions()
    new_rules = []
    for rule in rules:
        lhs = rule.lhs()
        rhs = rule.rhs()
        # check for hybrid rules
        if rule.is_lexical() and len(rhs) > 1:
            new_rhs = []
            for item in rhs:
                if is_terminal(item):
                    new_sym = Nonterminal(item)
                    new_rhs.append(new_sym)
                    # add new lexical rule with dummy lhs nonterminal
                    new_rules.append(Production(new_sym, (item,)))
                else:
                    new_rhs.append(item)
            # add converted mixed rule with only non-terminals on rhs
            new_rules.append(Production(lhs, tuple(new_rhs)))
        else:
            new_rules.append(rule)
    new_grammar = CFG(grammar.start(), new_rules)
    return new_grammar
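# Quick check of convert_hybrid (uses only nltk; the hybrid rule A -> 'b' C is
# split into A -> b C plus the unit rule b -> 'b'):
from nltk import CFG

hybrid_grammar = CFG.fromstring("""
A -> 'b' C
C -> 'c'
""")
for p in convert_hybrid(hybrid_grammar).productions():
    print(p)  # A -> b C, b -> 'b', C -> 'c'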
def apply(self, chart, grammar, edge):
    if edge.is_complete():
        return
    nextsym, index = edge.nextsym(), edge.end()
    if not is_nonterminal(nextsym):
        return

    # If we've already applied this rule to an edge with the same
    # next & end, and the chart & grammar have not changed, then
    # just return (no new edges to add).
    nextsym_with_bindings = edge.next_with_bindings()
    done = self._done.get((nextsym_with_bindings, index), (None, None))
    if done[0] is chart and done[1] is grammar:
        return

    for prod in grammar.productions(lhs=nextsym):
        # If the left corner of the predicted production is a
        # leaf, it must match the input.
        if prod.rhs():
            first = prod.rhs()[0]
            if is_terminal(first):
                if index >= chart.num_leaves():
                    continue
                if first != chart.leaf(index):
                    continue
        # We rename variables here, because we don't want variables
        # from the two different productions to match.
        if unify(prod.lhs(), nextsym_with_bindings, rename_vars=True):
            new_edge = FeatureTreeEdge.from_production(prod, edge.end())
            if chart.insert(new_edge, ()):
                yield new_edge

    # Record the fact that we've applied this rule.
    self._done[nextsym_with_bindings, index] = (chart, grammar)
def process_hybrid_productions(productions):
    new_productions_list = []  # list of new productions
    to_remove_list = []        # hybrid productions to delete
    for p in productions:
        is_hybrid = False  # flags whether the current production is hybrid
        if len(p.rhs()) > 1:  # more than one symbol on the right-hand side
            rh_list = []  # new list of right-hand symbols
            for r_symbol in p.rhs():
                if is_terminal(r_symbol):
                    dummy_symbol = Nonterminal(r_symbol)  # create dummy nonterminal
                    new_productions_list.append(
                        Production(dummy_symbol, [r_symbol]))  # new unit production
                    rh_list.append(dummy_symbol)
                    is_hybrid = True  # hybrid production confirmed
                else:
                    rh_list.append(r_symbol)
            if is_hybrid:
                # We can't change the list while looping over it, so store the
                # production to remove and its replacement for later.
                new_productions_list.append(
                    Production(p.lhs(), rh_list))  # new production with dummy symbols
                to_remove_list.append(p)
    return to_remove_list, new_productions_list
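# Minimal usage sketch (nltk grammar; the returned lists are applied by the
# caller, since the function deliberately does not mutate its input):
from nltk import CFG

g = CFG.fromstring("A -> 'b' C\nC -> 'c'")
to_remove, to_add = process_hybrid_productions(g.productions())
rules = [p for p in g.productions() if p not in to_remove] + to_add
# rules is now: [C -> 'c', b -> 'b', A -> b C]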
def sample(self, max_seq_len):
    """ Sample a derivation from the grammar """
    q = deque()  # queue of (symbol, ID) pairs still to expand
    d = {}       # derivation
    i = 0        # next fresh symbol ID
    t_count = 0  # number of generated terminals (so far)
    q.append((Nonterminal(args.start), i))  # enqueue the start symbol
    i += 1
    while len(q) > 0:
        # stop if this sequence is going to be longer than the
        # requested sequence length (we discard it anyway)
        if t_count > max_seq_len:
            return None
        lhs, lhs_id = q.popleft()  # Nonterminal, ID
        if not is_terminal(lhs):
            # choose a production with lhs on its left-hand side
            r = self.sample_production(lhs)
            # count the terminals this production generates
            for item in r.irhs():
                if is_terminal(item):
                    t_count += 1
            # give every rhs symbol a unique ID so we can
            # reconstruct the derivation later
            irhs = list(zip(r.irhs(), range(i, i + len(r))))
            i += len(r)
            # add the rhs symbols to the queue (terminals are simply
            # skipped when they are popped)
            q.extend(irhs)
            # save to the derivation
            d[(lhs, lhs_id)] = (irhs, r.orhs())
    return d
def _get_terminal_symbols(cfg):
    """ Returns the set of all terminal symbols used in an nltk context-free grammar. """
    terminal_symbols = set()
    for prod in cfg.productions():
        terminal_symbols.update(x for x in prod.rhs() if grammar.is_terminal(x))
    return terminal_symbols
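# Quick check (assumes `grammar` above refers to the nltk.grammar module):
import nltk

g = nltk.CFG.fromstring("S -> 'a' S 'b' | 'c'")
print(_get_terminal_symbols(g))  # {'a', 'b', 'c'} (in arbitrary set order)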
def rule_adds_atom(p):
    atoms = ['c', 'n', 'o', 's', 'f', 'cl', 'br', 'i']
    if any(x.lower() in atoms for x in p.rhs() if is_terminal(x)) or \
            any('valence' in x.symbol() for x in p.rhs() if is_nonterminal(x)):
        return 1
    elif any('segment' in x.symbol() for x in p.rhs() if is_nonterminal(x)):
        return 2
    else:
        return 0
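# Quick check (nltk Production; the symbol names follow the SMILES grammar
# assumed by the function):
from nltk.grammar import Production, Nonterminal

p = Production(Nonterminal('aromatic_organic'), ('c',))
print(rule_adds_atom(p))  # 1: the rule emits an aromatic-carbon terminal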
def preprocessingGrammar(grammar):
    parent = defaultdict(list)
    probdict = {}
    for prod in grammar.productions():
        if is_terminal(prod.rhs()[0]):
            parent[str(prod.rhs()[0])].append(str(prod.lhs()))
        probdict[str(prod.lhs()) + ' -> ' +
                 ' '.join(str(x) for x in prod.rhs() if len(str(x)) > 0)] = prod.prob()
    return parent, probdict
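# Minimal usage sketch with an nltk PCFG:
from nltk import PCFG

g = PCFG.fromstring("S -> 'a' [0.4] | 'b' [0.6]")
parent, probdict = preprocessingGrammar(g)
print(dict(parent))  # {'a': ['S'], 'b': ['S']}
print(probdict)      # {'S -> a': 0.4, 'S -> b': 0.6}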
def is_cnf(production):
    rhs = production.rhs()
    if len(rhs) == 1:
        return grammar.is_terminal(rhs[0])
    elif len(rhs) == 2:
        return grammar.is_nonterminal(rhs[0]) and grammar.is_nonterminal(rhs[1])
    else:
        return False
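# Quick check (assumes `grammar` above is the nltk.grammar module):
from nltk.grammar import Production, Nonterminal

S, A, B = Nonterminal('S'), Nonterminal('A'), Nonterminal('B')
assert is_cnf(Production(S, (A, B)))        # binary nonterminal rule: CNF
assert is_cnf(Production(A, ('a',)))        # unit lexical rule: CNF
assert not is_cnf(Production(S, (A, 'b')))  # mixed rhs: not CNF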
def get_lhs_terminal(grammar=load(grammar_url)):
    """
    Return the left-hand sides of all productions whose right-hand side
    starts with a terminal.

    :param grammar: an nltk grammar (note: loaded once, at definition time)
    :return: list of left-hand-side nonterminals, without duplicates
    """
    lhs_list = []
    for p in grammar.productions():
        if p.lhs() not in lhs_list and is_terminal(p.rhs()[0]):
            lhs_list.append(p.lhs())
    return lhs_list
def _LengthVector(pcfg, character=None):
    nonterminals = NonTerminalsPCFG(pcfg)
    if character is None:
        # expected number of terminals emitted directly by one expansion
        ruleset_lens = matrix([[sum(r.prob() * len([s for s in r.rhs() if is_terminal(s)])
                                    for r in pcfg.productions(lhs=nt))]
                               for nt in nonterminals])
    else:
        # expected number of occurrences of the given character
        ruleset_lens = matrix([[sum(r.prob() * len([s for s in r.rhs() if s == character])
                                    for r in pcfg.productions(lhs=nt))]
                               for nt in nonterminals])
    return ruleset_lens
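# Quick check (assumes NonTerminalsPCFG lists the grammar's nonterminals and
# matrix is numpy.matrix, as used above). For S -> 'a' S [0.5] | 'a' [0.5],
# every expansion of S emits exactly one terminal, so the entry is 1.0:
from nltk import PCFG

print(_LengthVector(PCFG.fromstring("S -> 'a' S [0.5] | 'a' [0.5]")))  # [[1.]]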
def derivation_to_tree_orhs(self, d, node, node_id):
    """ Convert a derivation map to an NLTK tree - orhs version """
    irhs_with_id, orhs = d[(node, node_id)]
    children = []
    for child_rhs in orhs:
        if isinstance(child_rhs, int):
            # an integer refers to the n-th nonterminal of the inner rhs
            child, child_id = [x for x in irhs_with_id if not is_terminal(x[0])][child_rhs - 1]
            children.append(self.derivation_to_tree_orhs(d, child, child_id))
        elif is_terminal(child_rhs):
            children.append(child_rhs)
    t = Tree(node, children)
    return t
def is_transition(p):
    """
    Checks whether a Production object is a transition. A transition is a
    Production whose right-hand side begins with a terminal. See BUTA for
    the interpretation of a transition.

    :type p: gr.Production
    :param p: a production
    :rtype: bool
    :return: True if p is a transition, False otherwise
    """
    check_type(p, gr.Production)
    return len(p.rhs()) > 0 and gr.is_terminal(p.rhs()[0])
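# Quick check (gr is nltk.grammar, as the type annotations above assume):
A = gr.Nonterminal('A')
print(is_transition(gr.Production(A, ('a', A))))  # True: rhs starts with a terminal
print(is_transition(gr.Production(A, (A, 'a'))))  # False: rhs starts with a nonterminal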
def get_cfg(filename):
    global N, sigma, R
    my_grammar = nltk.data.load(filename)
    for p in my_grammar.productions():
        N.append(p.lhs().symbol())  # add lhs nonterminal
        for element in p.rhs():
            if grammar.is_terminal(element):
                sigma.append(element)  # add rhs terminal(s)
            else:
                N.append(element.symbol())  # add rhs nonterminal(s)
        R.append(p)  # add rule
    # remove duplicates
    N = list(set(N))
    sigma = list(set(sigma))
def remove_rhs_terminals(production):
    rhs = production.rhs()
    if len(rhs) > 1:
        new_rhs = ()
        for element in rhs:
            if grammar.is_terminal(element):
                # create a dummy nonterminal for the terminal
                new_nt = grammar.Nonterminal(create_nonterminal())
                new_rhs = new_rhs + (new_nt,)
                # define the dummy nonterminal with a unit production
                R.append(grammar.Production(new_nt, (element,)))
            else:
                new_rhs = new_rhs + (element,)
        # replace the rule in the global rule list
        R[R.index(production)] = grammar.Production(production.lhs(), new_rhs)
def find_free_numerals(S, this_index, grammar, reuse_numerals=True):
    # collect all the un-paired numeral terminals before the current token
    used_tokens = set()
    for j in range(this_index):  # up to, and excluding, this_index
        current_token = S[j]['token']
        # we assume the token we want to expand now is the leftmost nonterminal
        assert is_terminal(current_token)
        # the second check excludes numerals that describe charge
        if current_token in grammar.numeric_tokens and 'is_cycle_numeral' in S[j]:
            if current_token in used_tokens and reuse_numerals:
                # this cycle has been closed, so the numeral can be reused
                used_tokens.remove(current_token)
            else:
                used_tokens.add(current_token)
    # find the unused numerals
    free_numerals = [nt for nt in grammar.numeric_tokens if nt not in used_tokens]
    if not free_numerals:
        raise ValueError("Too many nested cycles - can't find a valid numeral")
    return free_numerals
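# Minimal sketch with a hypothetical stub grammar (the real grammar object
# supplies numeric_tokens); S is a list of token dicts as the function expects:
class _StubGrammar:
    numeric_tokens = ['1', '2', '3']

S = [{'token': '1', 'is_cycle_numeral': True},   # ring bond 1 opened...
     {'token': '1', 'is_cycle_numeral': True}]   # ...and closed again
print(find_free_numerals(S, 2, _StubGrammar()))  # ['1', '2', '3']: '1' was freed for reuse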
def derivation_to_tree(self, d, node, node_id):
    """ Convert a derivation map to an NLTK tree """
    irhs_with_id, orhs = d[(node, node_id)]
    children = []
    for child, child_id in irhs_with_id:
        if is_terminal(child):
            children.append(child)
        else:
            children.append(self.derivation_to_tree(d, child, child_id))
    t = Tree(node, children)
    return t
def _get_code_for(self, null):
    """
    Creates an encoding of a CFG's terminal symbols as numbers.

    :type null: unicode
    :param null: a string representing "null"
    :rtype: dict
    :return: a dict associating each terminal of the grammar with a
        unique number. The highest number represents "null"
    """
    rhss = [r.rhs() for r in self.grammar.productions()]
    rhs_symbols = set()
    rhs_symbols.update(*rhss)
    rhs_symbols = set(x for x in rhs_symbols if gr.is_terminal(x))
    code_for = {x: i for i, x in enumerate(rhs_symbols)}
    code_for[null] = len(code_for)
    return code_for
def convert_grammar(cfg_grammar):
    """ Convert a CFG to Chomsky normal form """
    if cfg_grammar.is_chomsky_normal_form():
        return cfg_grammar
    # Go through every rule, and do the following conversions:
    # - remove terminals in non-solitary rules
    # - break up greater-than-2 rules
    # Note that this loop will blissfully ignore small productions
    new_productions = []
    for production in cfg_grammar.productions():
        rhs_size = len(production)
        lhs = production.lhs()
        rhs = list(production.rhs())  # copy to a list so items can be replaced
        if rhs_size < 2:
            new_productions.append(Production(lhs, rhs))
        else:
            # Replace each terminal with a fresh nonterminal plus a unit rule
            for i in range(rhs_size):
                if is_terminal(rhs[i]):
                    newnonterm = Nonterminal(rhs[i])
                    new_productions.append(Production(newnonterm, (rhs[i],)))
                    rhs[i] = newnonterm
            # Now break up large groups
            new_productions += break_large_rhs(lhs, rhs)
    new_cfg = CFG(cfg_grammar.start(), new_productions)
    assert new_cfg.is_binarised()
    # Remove empty productions
    new_cfg = remove_empty_productions(new_cfg)
    # Go through the rules again, removing nonterminals in solitary rules
    new_cfg = remove_unitary_productions(new_cfg)
    assert new_cfg.is_chomsky_normal_form()
    return new_cfg
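# Usage sketch (break_large_rhs, remove_empty_productions and
# remove_unitary_productions are helpers assumed to be defined elsewhere in
# this module):
from nltk import CFG

g = CFG.fromstring("""
S -> A 'b' C
A -> 'a'
C -> 'c'
""")
print(convert_grammar(g).productions())  # CNF rules with a dummy `b` nonterminal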