def outputVars(self, head):
    """Return the head and guard terms that this lookup does not take as input.

    ``self.pred_args`` is read position by position against
    ``[loc] + terms`` of the head's located fact (``head.fact`` when
    ``head.is_atom`` is set, otherwise the first fact of ``head.fact``).
    ``self.guard_args`` is read position by position against the free
    variables of every guard in ``self.assoc_guards`` (``term1`` first,
    then ``term2``).

    A position contributes its term when its marker is ``OUTPUT``, or when
    it holds a variable whose ``rule_idx`` has not contributed a term yet.
    Positions marked ``INPUT`` are skipped.
    """
    def is_marker(arg, marker):
        # Markers are plain strings; anything else is a variable term.
        return isinstance(arg, str) and arg == marker

    emitted = []
    seen_idxs = set()

    def collect(markers, terms):
        # Shared selection logic for predicate and guard positions.
        for pos, marker in enumerate(markers):
            if is_marker(marker, INPUT):
                continue
            if not is_marker(marker, OUTPUT):
                # A variable shared between positions is emitted only once.
                if marker.rule_idx in seen_idxs:
                    continue
                seen_idxs.add(marker.rule_idx)
            emitted.append(terms[pos])

    head_atom = head.fact if head.is_atom else head.fact.facts[0]
    collect(self.pred_args, [head_atom.loc] + head_atom.fact.terms)

    inspect = Inspector()
    guard_terms = []
    for guard in self.assoc_guards:
        guard_terms.extend(inspect.free_vars(guard.term1))
        guard_terms.extend(inspect.free_vars(guard.term2))
    collect(self.guard_args, guard_terms)

    return emitted
def __init__(self, head, head_idx):
    """Set up the task for the comprehension pattern carried by `head`.

    Stores the head, its index, the comprehension (``head.fact``), the
    first fact pattern inside it and the head's ``fact_pat``.  Also records
    the free variables of that first fact pattern as ``output_vars`` and
    the comprehension's binders as ``compre_binders``.
    """
    compre = head.fact
    self.initialize(compre)

    self.head = head
    self.head_idx = head_idx
    self.compre = compre
    self.fact = compre.facts[0]
    self.fact_pat = head.fact_pat

    inspector = Inspector()
    self.output_vars = inspector.free_vars(self.fact)
    # Only the comprehension binders are requested here: location and
    # argument variables are switched off.
    self.compre_binders = inspector.free_vars(
        compre, loc=False, args=False, compre_binders=True)
def __init__(self, head, head_idx):
    """Set up the task for the comprehension pattern carried by `head`.

    Keeps the head, its index, the comprehension (``head.fact``), the first
    fact pattern of that comprehension and the head's ``fact_pat``.  The
    free variables of the first fact pattern become ``output_vars``.
    """
    comprehension = head.fact
    self.initialize(comprehension)

    self.head_idx = head_idx
    self.head = head
    self.fact_pat = head.fact_pat
    self.compre = comprehension
    self.fact = comprehension.facts[0]

    self.output_vars = Inspector().free_vars(self.fact)
def __init__(self, head, head_idx):
    """Record `head` (a comprehension pattern) and derive its output variables.

    ``head.fact`` is kept as ``compre`` and its first fact pattern as
    ``fact``; ``output_vars`` holds the free variables of that pattern.
    """
    self.initialize(head.fact)

    pattern = head.fact.facts[0]
    self.head, self.head_idx = head, head_idx
    self.fact, self.fact_pat = pattern, head.fact_pat
    self.compre = head.fact

    inspector = Inspector()
    self.output_vars = inspector.free_vars(pattern)
class LHSCompre(Transformer):
    """Transformer that turns comprehension modifiers on rule heads into guards.

    For every rule of the first ensemble declaration, each head
    comprehension whose modifier is ``COMP_NONE_EXISTS`` or
    ``COMP_ONE_OR_MORE`` gets a guard on the size of its comprehension
    range appended to the rule's guards.
    """

    def __init__(self, decs):
        self.initialize(decs)
        self.inspect = Inspector()

    def transform(self):
        """Apply `transformRule` to every rule of the first ensemble declaration."""
        ensem_dec = self.inspect.filter_decs(self.decs, ensem=True)[0]
        rule_decs = self.inspect.filter_decs(ensem_dec.decs, rule=True)
        # Explicit loop: the calls are made for their side effects, and a
        # lazy `map` (Python 3) would never execute them.
        for rule_dec in rule_decs:
            self.transformRule(rule_dec)

    def transformRule(self, rule_dec):
        """Apply `transformCompre` to every comprehension among the rule heads."""
        _, _, _, compres = self.inspect.partition_rule_heads(
            rule_dec.plhs + rule_dec.slhs)
        for compre in compres:
            self.transformCompre(compre, rule_dec)

    def transformCompre(self, compre, rule_dec):
        """Add the size guard implied by the modifier of `compre` to `rule_dec`.

        A comprehension that carries one of the two modifiers but has no
        comprehension range first receives a synthetic range: a fresh
        variable ``comp_mod_<idx>`` (consuming ``rule_dec.next_rule_idx``)
        ranging over the integer literal ``1``.
        """
        modifier = compre.compre_mod
        if len(compre.comp_ranges) == 0 and modifier in [ast.COMP_NONE_EXISTS, ast.COMP_ONE_OR_MORE]:
            term_vars = ast.TermLit(1, "int")
            next_rule_idx = rule_dec.next_rule_idx
            rule_dec.next_rule_idx += 1
            term_range = ast.TermVar("comp_mod_%s" % next_rule_idx)
            term_range.rule_idx = next_rule_idx
            term_range.type = ast.TypeMSet(term_vars.type)
            term_range.smt_type = tyMSet(tyInt)
            compre.comp_ranges = [ast.CompRange(term_vars, term_range)]

        if modifier == ast.COMP_NONE_EXISTS:
            comparison = "=="    # size(range) == 0
        elif modifier == ast.COMP_ONE_OR_MORE:
            comparison = ">"     # size(range) > 0
        else:
            return

        term_range = compre.comp_ranges[0].term_range
        size_guard = ast.TermBinOp(
            ast.TermApp(ast.TermCons("size"), term_range),
            comparison,
            ast.TermLit(0, "int"))
        rule_dec.grd += [size_guard]

    def getVarsFilterByRuleIdx(self, facts):
        """Return the variables of `facts`, keeping the first one per ``rule_idx``."""
        uniq_vars = []
        seen_idxs = set()
        for fact_var in self.getVars(facts):
            if fact_var.rule_idx not in seen_idxs:
                seen_idxs.add(fact_var.rule_idx)
                uniq_vars.append(fact_var)
        return uniq_vars

    @visit.on('fact')
    def getVars(self, fact):
        pass

    @visit.when(list)
    def getVars(self, fact):
        # Build a real list so `foldl` does not depend on `map` being eager.
        return foldl([self.getVars(f) for f in fact], [])

    @visit.when(ast.FactLoc)
    def getVars(self, fact):
        return self.inspect.free_vars(fact)
def __init__(self, compre_body, body_idx, compre_idx):
    """Capture a comprehension found in a rule body.

    Copies ``fact``, ``local``, ``priority`` and ``monotone`` from
    `compre_body`, and derives the binder variables and the domain term
    from the first comprehension range of its fact.
    """
    first_range = compre_body.fact.comp_ranges[0]

    self.body = compre_body
    self.body_idx = body_idx
    self.compre_idx = compre_idx

    self.fact = compre_body.fact
    self.local = compre_body.local
    self.priority = compre_body.priority
    self.monotone = compre_body.monotone

    self.term_vars = Inspector().free_vars(first_range.term_vars)
    self.compre_dom = first_range.term_range
def __init__(self, compre_body, body_idx, compre_idx):
    """Record a body comprehension together with its position indices.

    ``term_vars`` holds the free variables of the first comprehension
    range's binder term; ``compre_dom`` holds that range's domain term.
    The remaining attributes are copied from `compre_body` unchanged.
    """
    inspector = Inspector()
    compre_fact = compre_body.fact

    self.body, self.fact = compre_body, compre_fact
    for attr in ('local', 'priority', 'monotone'):
        setattr(self, attr, getattr(compre_body, attr))

    self.term_vars = inspector.free_vars(compre_fact.comp_ranges[0].term_vars)
    self.compre_dom = compre_fact.comp_ranges[0].term_range
    self.body_idx, self.compre_idx = body_idx, compre_idx
def __init__(self, head, head_idx, lookup, var_gen):
    """Set up a lookup-driven task for the comprehension pattern in `head`.

    The binder variables and the domain term come from the first
    comprehension range of ``head.fact``; the input and output variables
    are whatever `lookup` reports for `head`.  `var_gen` is accepted but
    not used here.
    """
    self.initialize()

    first_range = head.fact.comp_ranges[0]
    self.term_vars = Inspector().free_vars(first_range.term_vars)
    self.compre_dom = first_range.term_range

    self.lookup = lookup
    self.input_vars = lookup.inputVars(head)
    self.output_vars = lookup.outputVars(head)

    self.head_idx = head_idx
    self.head = head
    self.fact = head.fact
    self.fact_pat = head.fact_pat
def __init__(self, fact_dir, loc_fact, mem_guard, var_ctxt):
    """Build a membership lookup for `loc_fact` constrained by `mem_guard`.

    Every predicate position (``[loc] + terms`` of `loc_fact`) and every
    free variable of ``mem_guard.term1`` is classified as:

    * ``INPUT``  - its ``rule_idx`` is already in `var_ctxt`;
    * itself     - it is shared between the fact and the guard's ``term1``;
    * ``OUTPUT`` - otherwise (each one adds to ``degree_freedom``).

    ``guard_args`` always ends with one extra ``INPUT`` entry, which lines
    up with the trailing ``%s`` of ``guard_str`` (the right-hand side of
    ``in``).

    Fix: the guard loop used to test the stale ``pred_arg`` left over from
    the predicate loop and pushed ``INPUT`` onto ``ppred_args``.  That
    lengthened the predicate argument list and left guard variables
    unclassified.  It now tests ``mem_arg`` and appends to ``guard_args``.
    """
    pred_idx, _ = fact_dir.getFactFromName(loc_fact.fact.name)
    inspect = Inspector()

    fact_vars = inspect.free_var_idxs(loc_fact.loc) | inspect.free_var_idxs(loc_fact.fact)
    mem_vars = inspect.free_var_idxs(mem_guard.term1)
    common_vars = fact_vars & mem_vars
    self.degree_join = len(common_vars)
    degree_freedom = 0

    # Existing dependencies from variable context
    self.var_dependencies = (fact_vars | mem_vars) & var_ctxt

    has_hash_index = False
    ppred_args = []
    for pred_arg in [loc_fact.loc] + loc_fact.fact.terms:
        if pred_arg.rule_idx in self.var_dependencies:
            ppred_args.append(INPUT)
            has_hash_index = True
        elif pred_arg.rule_idx in common_vars:
            ppred_args.append(pred_arg)
        else:
            ppred_args.append(OUTPUT)
            degree_freedom += 1

    guard_args = []
    for mem_arg in inspect.free_vars(mem_guard.term1):
        if mem_arg.rule_idx in self.var_dependencies:
            # NOTE(review): the original also set has_hash_index here, but
            # its condition could only hold when the flag was already True.
            # The flag is left to the predicate arguments so the lookup
            # name is unchanged - confirm whether a bound guard variable
            # alone should select "hash+mem".
            guard_args.append(INPUT)
        elif mem_arg.rule_idx in common_vars:
            guard_args.append(mem_arg)
        else:
            guard_args.append(OUTPUT)
            degree_freedom += 1
    guard_args.append(INPUT)

    # One '%s' per guard variable, plus one for the collection after 'in'.
    guard_str = '(%s) in %s' % (','.join(['%s'] * (len(guard_args) - 1)), '%s')
    self.degree_freedom = degree_freedom

    lk_name = "hash+mem" if has_hash_index else "mem"
    self.initialize(MEM_LK, pred_idx, fact_dir, lk_name,
                    pred_args=ppred_args, guard_args=guard_args,
                    guard_str=guard_str, assoc_guards=[mem_guard])
def check_int(self, ast_node):
    """Report comprehension patterns that are outside the supported fragment.

    Three shapes are rejected: more than one comprehension range, a number
    of base fact patterns other than one, and a single fact pattern whose
    location variable is one of the comprehension binders.  Supported
    sub-nodes are checked recursively.
    """
    ranges = ast_node.comp_ranges
    patterns = ast_node.facts

    if len(ranges) > 1:
        err = self.declare_error("Unsupported LHS Pattern: Comprehension pattern with multiple comprehension range.")
        for rng in ranges:
            self.extend_error(err, rng)
    else:
        for rng in ranges:
            self.check_int(rng)

    inspector = Inspector()
    if len(inspector.get_base_facts(patterns)) != 1:
        err = self.declare_error("Unsupported LHS Pattern: Comprehension pattern with multiple fact patterns.")
        for pattern in patterns:
            self.extend_error(err, pattern)
    else:
        self.check_int(patterns[0])

    if len(patterns) != 1 or len(ranges) != 1:
        return

    # The location of the pattern must not be bound by the comprehension.
    loc = patterns[0].loc
    binder_names = [v.name for v in inspector.free_vars(ranges[0].term_vars)]
    if loc.name in binder_names:
        err = self.declare_error("Unsupported LHS Pattern: Multi-location comprehension patterns.")
        self.extend_error(err, loc)
def int_check(self, ast_node):
    """Annotate fact and rule declarations of `ast_node` with derived properties.

    One pass over the rule declarations collects, per predicate name,
    whether it is simplified by some rule, used non-locally in some rule
    body, matched by a head comprehension, or given a priority in some
    rule body.  The same pass annotates head facts with ``unique_head`` /
    ``collision_idx`` and body facts with ``local = False`` where needed.
    The declarations are then annotated:

    * fact declarations: ``persistent``, ``local``, ``monotone``,
      ``uses_priority``;
    * rule declarations: ``unique_head_names``, ``rule_priority_body_names``;
    * body facts of every rule: ``monotone``.

    Fixes relative to the previous version:

    * body comprehensions were recorded under the *location variable's*
      name instead of the predicate name, so the corresponding fact
      declaration was never marked non-local;
    * one of those branches assigned to the undefined name
      ``non_local_pred_name`` and raised ``NameError``;
    * ``loc_vars`` is now a real list, so indexing it does not depend on
      ``map`` returning a list.
    """
    inspect = Inspector()
    decs = ast_node.decs

    # Sets of predicate names, represented as dictionaries with () values.
    simplified_pred_names = {}
    non_local_pred_names = {}
    lhs_compre_pred_names = {}
    prioritized_pred_names = {}

    def mark_non_local(base_fact):
        # Record the predicate as non-local and flag this occurrence.
        non_local_pred_names[base_fact.name] = ()
        base_fact.local = False

    for rule_dec in inspect.filter_decs(decs, rule=True):
        simp_heads = rule_dec.slhs
        prop_heads = rule_dec.plhs
        rule_body = rule_dec.rhs

        # Scan for simplified predicate names
        for fact in inspect.get_base_facts(simp_heads):
            simplified_pred_names[fact.name] = ()

        # Scan for non local predicate names
        # Annotates non local rule body facts as well.
        loc_var_terms = inspect.free_vars(simp_heads + prop_heads, args=False)
        loc_vars = [t.name for t in loc_var_terms]
        if len(set(loc_vars)) > 1:
            # Flag all body predicates as non local
            for fact in inspect.get_base_facts(rule_body):
                mark_non_local(fact)
        else:
            loc_var = loc_vars[0]
            (bfs, lfs, lfcs, comps) = inspect.partition_rule_heads(rule_body)
            for lf in lfs:
                # A non-variable location is treated as non-local.
                if not isinstance(lf.loc, ast.TermVar) or lf.loc.name != loc_var:
                    mark_non_local(lf.fact)
            for lfc in lfcs:
                if not isinstance(lfc.loc, ast.TermVar) or lfc.loc.name != loc_var:
                    for f in lfc.facts:
                        mark_non_local(f)
            for comp in comps:
                # Assumes that comprehension fact patterns are solo
                loc_fact = comp.facts[0]
                if loc_fact.loc.name != loc_var:
                    mark_non_local(loc_fact.fact)
                else:
                    # Same variable name as the rule location, but bound by
                    # the comprehension itself.
                    binder_names = [tv.name for tv in
                                    inspect.free_vars(comp.comp_ranges[0].term_vars)]
                    if loc_var in binder_names:
                        mark_non_local(loc_fact.fact)

        # Scan for LHS comprehension predicate names
        (bfs, lfs, lfcs, comps) = inspect.partition_rule_heads(simp_heads + prop_heads)
        for comp in comps:
            loc_fact = comp.facts[0]
            lhs_compre_pred_names[loc_fact.fact.name] = ()

        # Scan for non-unique rule heads
        rule_head_pred_names = {}
        for fact in inspect.get_base_facts(simp_heads + prop_heads):
            rule_head_pred_names.setdefault(fact.name, []).append(fact)

        self.rule_unique_heads[rule_dec.name] = []
        collision_idx = 0
        for name in rule_head_pred_names:
            facts = rule_head_pred_names[name]
            unique_head = len(facts) == 1
            for fact in facts:
                fact.unique_head = unique_head
                fact.collision_idx = collision_idx
                collision_idx += 1
            if unique_head:
                self.rule_unique_heads[rule_dec.name].append(name)

        # Scan for priorities
        rule_priorities = {}
        self.rule_priority_body[rule_dec.name] = rule_priorities

        def mark_prioritized(pred_name):
            prioritized_pred_names[pred_name] = ()
            rule_priorities[pred_name] = ()

        (bfs, lfs, lfcs, comps) = inspect.partition_rule_heads(rule_body)
        for bf in bfs:
            if bf.priority is not None:
                mark_prioritized(bf.name)
        for lf in lfs:
            if lf.priority is not None:
                mark_prioritized(lf.fact.name)
        for lfc in lfcs:
            if lfc.priority is not None:
                for f in lfc.facts:
                    mark_prioritized(f.name)
        for comp in comps:
            if comp.priority is not None:
                for f in comp.facts:
                    mark_prioritized(f.name)

    # Annotate fact declaration nodes with relevant information
    fact_decs = inspect.filter_decs(decs, fact=True)
    for fact_dec in fact_decs:
        fact_dec.persistent = fact_dec.name not in simplified_pred_names
        fact_dec.local = fact_dec.name not in non_local_pred_names
        fact_dec.monotone = fact_dec.name not in lhs_compre_pred_names
        fact_dec.uses_priority = fact_dec.name in prioritized_pred_names
    self.fact_decs = fact_decs

    # Annotate rule declaration nodes with relevant information
    rule_decs = inspect.filter_decs(decs, rule=True)
    for rule_dec in rule_decs:
        rule_dec.unique_head_names = self.rule_unique_heads[rule_dec.name]
        rule_dec.rule_priority_body_names = list(self.rule_priority_body[rule_dec.name])

    # Annotate RHS constraints with monotonicity information
    for rule_dec in rule_decs:
        for fact in inspect.get_base_facts(rule_dec.rhs):
            fact.monotone = fact.name not in lhs_compre_pred_names
def initialize(self, ast_node=None):
    """Give this task a fresh id and, if a node is supplied, its free variables.

    ``self.free_vars`` is only assigned when `ast_node` is provided.
    """
    self.id = get_next_matchtask_index()
    if ast_node != None:
        self.free_vars = Inspector().free_vars(ast_node)
def initialize(self, ast_node=None):
    """Assign the next match-task index to ``self.id``.

    When `ast_node` is given, its free variables (as computed by an
    ``Inspector``) are additionally stored in ``self.free_vars``.
    """
    task_id = get_next_matchtask_index()
    self.id = task_id
    if ast_node != None:
        inspector = Inspector()
        node_vars = inspector.free_vars(ast_node)
        self.free_vars = node_vars
class VarScopeChecker(Checker):
    """Checker for the scoping of variables, predicates, names and ensembles.

    Scope contexts are dictionaries produced by ``new_ctxt`` /
    ``copy_ctxt``; this class reads and writes their ``'vars'``,
    ``'preds'``, ``'cons'`` and ``'ensem'`` entries.  Out-of-scope items
    are accumulated in ``self.curr_out_scopes`` while the tree is visited
    and are turned into error reports (and cleared) by
    ``compose_out_scope_error_report``.

    All traversals use explicit loops: the visit calls are made for their
    side effects, which a lazy ``map`` would silently skip.
    """

    def __init__(self, decs, source_text):
        self.inspect = Inspector()
        self.initialize(decs, source_text)
        self.curr_out_scopes = new_ctxt()
        self.curr_duplicates = {'vars': {}}
        self.ensems = {}

    # Main checking operation
    def check(self):
        """Check extern, ensemble and execute declarations, in that order."""
        inspect = self.inspect
        ctxt = new_ctxt()

        # Check scoping of extern name declarations
        for extern in inspect.filter_decs(self.decs, extern=True):
            self.check_scope(extern, ctxt)

        for ensem_dec in inspect.filter_decs(self.decs, ensem=True):
            self.check_scope(ensem_dec, ctxt)

        for exec_dec in inspect.filter_decs(self.decs, execute=True):
            self.check_scope(exec_dec, ctxt)

    @visit.on('ast_node')
    def check_scope(self, ast_node, ctxt, lhs=False):
        pass

    @visit.when(ast.EnsemDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        # The ensemble gets its own copy of the context; the caller's
        # context is returned untouched.
        old_ctxt = ctxt
        ctxt = copy_ctxt(ctxt)
        decs = ast_node.decs
        inspect = self.inspect
        dec_preds = {}

        # Check scoping of extern name declarations
        for extern in inspect.filter_decs(decs, extern=True):
            self.check_scope(extern, ctxt)

        # Check scoping of predicate declarations
        for dec in inspect.filter_decs(decs, fact=True):
            local_ctxt = self.check_scope(dec, ctxt)
            for pred in local_ctxt['preds']:
                dec_preds.setdefault(pred.name, []).append(pred)
            extend_ctxt(ctxt, local_ctxt)
        self.compose_duplicate_error_reports("predicate", dec_preds)

        # Check scoping of rule declarations
        for dec in inspect.filter_decs(decs, rule=True):
            self.check_scope(dec, ctxt)

        # The context remembered for the ensemble starts with no variables.
        ctxt['vars'] = []
        self.ensems[ast_node.name] = ctxt
        return old_ctxt

    @visit.when(ast.ExecDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        if ast_node.name not in self.ensems:
            self.curr_out_scopes['ensem'].append(ast_node)
            self.compose_out_scope_error_report(ctxt)
        else:
            ctxt = self.ensems[ast_node.name]
            for dec in ast_node.decs:
                self.check_scope(dec, ctxt)
            self.compose_duplicate_error_reports("variables", self.curr_duplicates['vars'])
            self.curr_duplicates['vars'] = {}
            self.compose_out_scope_error_report(ctxt)

    @visit.when(ast.ExistDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for tvar in ast_node.exist_vars:
            if not lookup_var(ctxt, tvar):
                ctxt['vars'] += [tvar]
            else:
                self.record_duplicates(tvar, ctxt)

    @visit.when(ast.LocFactDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for loc_fact in ast_node.loc_facts:
            self.check_scope(loc_fact, ctxt)

    @visit.when(ast.ExternDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for type_sig in ast_node.type_sigs:
            self.check_scope(type_sig, ctxt)

    @visit.when(ast.ExternTypeSig)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt['cons'].append(ast_node)

    @visit.when(ast.FactDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        # Returns a fresh context holding only this declaration; the caller
        # merges it into the ensemble context.
        ctxt = new_ctxt()
        ctxt['preds'].append(ast_node)
        # TODO: check types
        return ctxt

    @visit.when(ast.RuleDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt = copy_ctxt(ctxt)
        heads = ast_node.slhs + ast_node.plhs
        inspect = self.inspect

        # Extend location context with all rule head variables
        ctxt['vars'] += self.get_rule_scope(heads, compre=False)

        # Check scope of rule heads. This step checks consistency of constant names and
        # scoping of comprehension patterns.
        for head in heads:
            self.check_scope(head, ctxt, lhs=True)
        for guard in ast_node.grd:
            self.check_scope(guard, ctxt, lhs=True)

        ctxt['vars'] += self.get_rule_scope(heads, atoms=False)

        # Include exist variables scopes and check for overlaps with existing variables.
        # (We currently disallow them.)
        dup_vars = {}
        for v in ctxt['vars']:
            dup_vars[v.name] = [v]

        for ex_var in ast_node.exists:
            dup_vars.setdefault(ex_var.name, []).append(ex_var)
        ctxt['vars'] += ast_node.exists

        # Incremental include where assign statements
        for ass_stmt in ast_node.where:
            self.check_scope(ass_stmt.builtin_exp, ctxt)
            self.compose_out_scope_error_report(ctxt)
            a_vars = inspect.filter_atoms(inspect.get_atoms(ass_stmt.term_pat), var=True)
            for a_var in a_vars:
                dup_vars.setdefault(a_var.name, []).append(a_var)
            ctxt['vars'] += a_vars

        self.compose_duplicate_error_reports("variables", dup_vars)

        for body in ast_node.rhs:
            self.check_scope(body, ctxt)

        self.compose_out_scope_error_report(ctxt)

    @visit.when(ast.FactBase)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt = copy_ctxt(ctxt)
        self.check_pred(ctxt, ast_node)
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.FactLoc)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt = copy_ctxt(ctxt)
        self.check_scope(ast_node.loc, ctxt)
        self.check_scope(ast_node.fact, ctxt)
        return ctxt

    @visit.when(ast.FactLocCluster)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt = copy_ctxt(ctxt)
        self.check_scope(ast_node.loc, ctxt)
        for fact in ast_node.facts:
            self.check_scope(fact, ctxt)
        return ctxt

    @visit.when(ast.FactCompre)
    def check_scope(self, ast_node, old_ctxt, lhs=False):
        ctxt = copy_ctxt(old_ctxt)
        comp_ranges = ast_node.comp_ranges

        # Check scope of comprehension ranges
        if not lhs:
            for comp_range in comp_ranges:
                self.check_scope(comp_range, ctxt)
            self.compose_out_scope_error_report(ctxt)

        # Extend variable context with comprehension binders
        ctxt['vars'] += self.inspect.free_vars([cr.term_vars for cr in comp_ranges])

        # With extended variable context, check scopes of the fact pattern and guards
        for fact in ast_node.facts:
            self.check_scope(fact, ctxt)
        for guard in ast_node.guards:
            self.check_scope(guard, ctxt)

        self.compose_out_scope_error_report(ctxt)

        # On the left-hand side the range variables are added to the
        # caller's context (mutated in place).
        if lhs:
            old_ctxt['vars'] += self.inspect.free_vars([cr.term_range for cr in comp_ranges])

    @visit.when(ast.CompRange)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term_range, ctxt)

    @visit.when(ast.TermCons)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_cons(ctxt, ast_node)
        return ctxt

    @visit.when(ast.TermVar)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_var(ctxt, ast_node)
        return ctxt

    @visit.when(ast.TermApp)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term1, ctxt)
        self.check_scope(ast_node.term2, ctxt)
        return ctxt

    @visit.when(ast.TermTuple)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.TermList)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.TermListCons)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term1, ctxt)
        self.check_scope(ast_node.term2, ctxt)
        return ctxt

    @visit.when(ast.TermMSet)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.TermBinOp)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term1, ctxt)
        self.check_scope(ast_node.term2, ctxt)
        return ctxt

    @visit.when(ast.TermUnaryOp)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term, ctxt)
        return ctxt

    @visit.when(ast.TermLit)
    def check_scope(self, ast_node, ctxt, lhs=False):
        return ctxt

    @visit.when(ast.TermUnderscore)
    def check_scope(self, ast_node, ctxt, lhs=False):
        return ctxt

    # Error state operations

    def flush_error_ctxt(self):
        """Discard all pending out-of-scope items and duplicate records."""
        self.curr_out_scopes = new_ctxt()
        self.curr_duplicates = {'vars': {}}

    def check_var(self, ctxt, var):
        """Return whether `var` is in `ctxt`; queue it as out of scope if not."""
        if lookup_var(ctxt, var):
            return True
        self.curr_out_scopes['vars'].append(var)
        return False

    def check_pred(self, ctxt, pred):
        """Return whether `pred` is in `ctxt`; queue it as out of scope if not."""
        if lookup_pred(ctxt, pred):
            return True
        self.curr_out_scopes['preds'].append(pred)
        return False

    def check_cons(self, ctxt, cons):
        """Return whether `cons` is in `ctxt`; queue it as out of scope if not."""
        if lookup_cons(ctxt, cons):
            return True
        self.curr_out_scopes['cons'].append(cons)
        return False

    # Get rule scope

    @visit.on('ast_node')
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        pass

    @visit.when(list)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        this_free_vars = []
        for obj in ast_node:
            this_free_vars += self.get_rule_scope(obj, atoms=atoms, compre=compre)
        return this_free_vars

    @visit.when(ast.FactBase)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        return self.inspect.free_vars(ast_node) if atoms else []

    @visit.when(ast.FactLoc)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        return self.inspect.free_vars(ast_node) if atoms else []

    @visit.when(ast.FactLocCluster)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        return self.inspect.free_vars(ast_node) if atoms else []

    @visit.when(ast.FactCompre)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        # Only a comprehension with exactly one range contributes, and it
        # contributes the range term itself.
        if compre and len(ast_node.comp_ranges) == 1:
            return [ast_node.comp_ranges[0].term_range]
        return []

    # Reporting

    def record_duplicates(self, tvar, ctxt):
        """Remember `tvar` as a duplicate of the same-named variables in `ctxt`."""
        known = self.curr_duplicates['vars']
        if tvar.name not in known:
            # First clash for this name: start with the occurrences already
            # in scope, then add the new one.
            dups = [t for t in ctxt['vars'] if tvar.name == t.name]
            dups.append(tvar)
            known[tvar.name] = dups
        else:
            known[tvar.name].append(tvar)

    def _report_out_of_scope(self, key, title, noun, ctxt):
        """Declare one error for the pending out-of-scope items stored under `key`."""
        offenders = self.curr_out_scopes[key]
        if len(offenders) == 0:
            return
        legend = ("%s %s: Scope context %s.\n" % (terminal.T_GREEN_BACK, terminal.T_NORM, noun)) + \
                 ("%s %s: Out of scope %s." % (terminal.T_RED_BACK, terminal.T_NORM, noun))
        # Sorted so that the message does not depend on set iteration order.
        names = ','.join(sorted(set([t.name for t in offenders])))
        error_idx = self.declare_error("%s %s not in scope." % (title, names), legend)
        for offender in offenders:
            self.extend_error(error_idx, offender)
        for in_scope in ctxt[key]:
            self.extend_info(error_idx, in_scope)

    def compose_out_scope_error_report(self, ctxt):
        """Turn all pending out-of-scope items into errors and clear them.

        Offending nodes are attached as error locations; the matching
        entries of `ctxt` are attached as additional info.
        """
        self._report_out_of_scope('vars', "Variable(s)", "variable(s)", ctxt)
        self._report_out_of_scope('preds', "Predicate(s)", "predicate(s)", ctxt)
        self._report_out_of_scope('cons', "Name(s)", "name(s)", ctxt)
        for exec_node in self.curr_out_scopes['ensem']:
            error_idx = self.declare_error("Ensemble %s not in scope." % exec_node.name)
            self.extend_error(error_idx, exec_node)
        self.curr_out_scopes = new_ctxt()

    def compose_duplicate_error_reports(self, kind, dups):
        """Declare one error per name in `dups` that has more than one element."""
        for name in dups:
            elems = dups[name]
            if len(elems) > 1:
                error_idx = self.declare_error("Duplicated declaration of %s %s." % (kind, name))
                for elem in elems:
                    self.extend_error(error_idx, elem)
def __init__(self, head, head_idx):
    """Record `head` and the binder variables / domain of its comprehension.

    Both values come from the first comprehension range of ``head.fact``:
    ``term_vars`` holds the free variables of its binder term and
    ``compre_dom`` its range term.
    """
    first_range = head.fact.comp_ranges[0]

    self.head = head
    self.head_idx = head_idx
    self.term_vars = Inspector().free_vars(first_range.term_vars)
    self.compre_dom = first_range.term_range
class AlphaIndexer(Transformer):
    """Transformer that stamps every variable occurrence with a ``rule_idx``.

    Each rule is processed with its own ``FramedCtxt``.  A variable's index
    is whatever ``ctxt.get_index(name)`` returns for its name, and
    comprehension binders are indexed inside a frame pushed for the
    duration of the comprehension.  After a rule has been processed, the
    context's ``var_idx`` is stored on it as ``next_rule_idx``.
    """

    def __init__(self, decs):
        self.initialize(decs)
        self.inspect = Inspector()

    def transform(self):
        """Index all declarations held by this transformer."""
        self.int_transform(self.decs)

    def _index_compre(self, comp_ranges, body, guards, ctxt):
        # Shared by fact and term comprehensions.  The ranges are indexed
        # in the enclosing frame; binders, body and guards in a new frame
        # keyed by the binder names.
        for comp_range in comp_ranges:
            self.int_transform(comp_range.term_range, ctxt)
        binders = self.inspect.free_vars([cr.term_vars for cr in comp_ranges])
        ctxt.push_frame(keys=set([binder.name for binder in binders]))
        self.int_transform(binders, ctxt)
        self.int_transform(body, ctxt)
        self.int_transform(guards, ctxt)
        ctxt.pop_frame()

    @visit.on('ast_node')
    def int_transform(self, ast_node, ctxt=None):
        pass

    @visit.when(list)
    def int_transform(self, ast_node, ctxt=None):
        for element in ast_node:
            self.int_transform(element, ctxt)

    @visit.when(ast.EnsemDec)
    def int_transform(self, ast_node, ctxt=None):
        # Rules are visited without a context; each one creates its own.
        for rule_dec in self.inspect.filter_decs(ast_node.decs, rule=True):
            self.int_transform(rule_dec)

    @visit.when(ast.RuleDec)
    def int_transform(self, ast_node, ctxt=None):
        ctxt = FramedCtxt()
        sections = (ast_node.plhs, ast_node.slhs, ast_node.grd,
                    ast_node.exists, ast_node.where, ast_node.rhs)
        for section in sections:
            self.int_transform(section, ctxt)
        ast_node.next_rule_idx = ctxt.var_idx

    @visit.when(ast.AssignDec)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term_pat, ctxt)
        self.int_transform(ast_node.builtin_exp, ctxt)

    @visit.when(ast.FactBase)
    def int_transform(self, ast_node, ctxt=None):
        for arg in ast_node.terms:
            self.int_transform(arg, ctxt)

    @visit.when(ast.FactLoc)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.loc, ctxt)
        self.int_transform(ast_node.fact, ctxt)

    @visit.when(ast.FactLocCluster)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.loc, ctxt)
        for member in ast_node.facts:
            self.int_transform(member, ctxt)

    @visit.when(ast.FactCompre)
    def int_transform(self, ast_node, ctxt=None):
        self._index_compre(ast_node.comp_ranges, ast_node.facts, ast_node.guards, ctxt)

    @visit.when(ast.TermVar)
    def int_transform(self, ast_node, ctxt=None):
        ast_node.rule_idx = ctxt.get_index(ast_node.name)

    @visit.when(ast.TermApp)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term1, ctxt)
        self.int_transform(ast_node.term2, ctxt)

    @visit.when(ast.TermTuple)
    def int_transform(self, ast_node, ctxt=None):
        for component in ast_node.terms:
            self.int_transform(component, ctxt)

    @visit.when(ast.TermList)
    def int_transform(self, ast_node, ctxt=None):
        for item in ast_node.terms:
            self.int_transform(item, ctxt)

    @visit.when(ast.TermListCons)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term1, ctxt)
        self.int_transform(ast_node.term2, ctxt)

    @visit.when(ast.TermMSet)
    def int_transform(self, ast_node, ctxt=None):
        for member in ast_node.terms:
            self.int_transform(member, ctxt)

    @visit.when(ast.TermCompre)
    def int_transform(self, ast_node, ctxt=None):
        self._index_compre(ast_node.comp_ranges, ast_node.term, ast_node.guards, ctxt)

    @visit.when(ast.TermBinOp)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term1, ctxt)
        self.int_transform(ast_node.term2, ctxt)

    @visit.when(ast.TermUnaryOp)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term, ctxt)
def check_int(self, ast_node):
    """Check every head of the rule `ast_node`.

    While the heads are being checked, ``self.rule_free_vars`` holds the
    free variables of all heads (propagated heads first, then simplified
    ones); it is reset to an empty list afterwards.
    """
    heads = ast_node.plhs + ast_node.slhs
    self.rule_free_vars = Inspector().free_vars(heads)
    for head in heads:
        self.check_int(head)
    self.rule_free_vars = []
class VarScopeChecker(Checker):
    """Scope checker for variables, predicates, constant names and ensembles.

    While visiting, out-of-scope items are accumulated in
    ``self.curr_out_scopes`` (keyed by 'vars', 'preds', 'cons' and 'ensem')
    and duplicated variables in ``self.curr_duplicates``. The ``compose_*``
    methods turn those accumulators into error reports.
    """

    def __init__(self, decs, source_text):
        self.inspect = Inspector()
        self.initialize(decs, source_text)
        self.curr_out_scopes = new_ctxt()
        self.curr_duplicates = {'vars': {}}
        # Maps an ensemble name to the context built while checking it.
        self.ensems = {}

    # Main checking operation
    def check(self):
        """Check extern, ensemble and execute declarations, in that order."""
        inspect = self.inspect
        ctxt = new_ctxt()

        # Check scoping of extern name declarations
        for extern in inspect.filter_decs(self.decs, extern=True):
            self.check_scope(extern, ctxt)

        for ensem_dec in inspect.filter_decs(self.decs, ensem=True):
            self.check_scope(ensem_dec, ctxt)

        for exec_dec in inspect.filter_decs(self.decs, execute=True):
            self.check_scope(exec_dec, ctxt)

    @visit.on('ast_node')
    def check_scope(self, ast_node, ctxt, lhs=False):
        pass

    @visit.when(ast.EnsemDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Check an ensemble and remember its context under its name.

        Returns the context that was passed in; the ensemble's own
        declarations only extend a copy of it.
        """
        old_ctxt = ctxt
        ctxt = copy_ctxt(ctxt)
        decs = ast_node.decs
        inspect = Inspector()
        dec_preds = {}

        # Check scoping of extern name declarations
        for extern in inspect.filter_decs(decs, extern=True):
            self.check_scope(extern, ctxt)

        # Check scoping of predicate declarations
        for dec in inspect.filter_decs(decs, fact=True):
            local_ctxt = self.check_scope(dec, ctxt)
            for pred in local_ctxt['preds']:
                dec_preds.setdefault(pred.name, []).append(pred)
            extend_ctxt(ctxt, local_ctxt)
        self.compose_duplicate_error_reports("predicate", dec_preds)

        # Check scoping of rule declarations
        for dec in inspect.filter_decs(decs, rule=True):
            self.check_scope(dec, ctxt)

        # Variables are not carried over into the stored ensemble context.
        ctxt['vars'] = []
        self.ensems[ast_node.name] = ctxt

        return old_ctxt

    @visit.when(ast.ExecDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Check an execute declaration against the ensemble it names."""
        if ast_node.name not in self.ensems:
            self.curr_out_scopes['ensem'].append(ast_node)
            self.compose_out_scope_error_report(ctxt)
        else:
            ctxt = self.ensems[ast_node.name]
            for dec in ast_node.decs:
                self.check_scope(dec, ctxt)
            self.compose_duplicate_error_reports(
                "variables", self.curr_duplicates['vars'])
            self.curr_duplicates['vars'] = {}
            self.compose_out_scope_error_report(ctxt)

    @visit.when(ast.ExistDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Add each exist variable to the context, recording re-declarations."""
        for tvar in ast_node.exist_vars:
            if not lookup_var(ctxt, tvar):
                ctxt['vars'] += [tvar]
            else:
                self.record_duplicates(tvar, ctxt)

    @visit.when(ast.LocFactDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for loc_fact in ast_node.loc_facts:
            self.check_scope(loc_fact, ctxt)

    @visit.when(ast.ExternDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for type_sig in ast_node.type_sigs:
            self.check_scope(type_sig, ctxt)

    @visit.when(ast.ExternTypeSig)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt['cons'].append(ast_node)

    @visit.when(ast.FactDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Return a fresh context that contains only this predicate."""
        ctxt = new_ctxt()
        ctxt['preds'].append(ast_node)
        # TODO: check types
        return ctxt

    @visit.when(ast.RuleDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Check a rule: heads, guards, exist variables, where clauses, body."""
        ctxt = copy_ctxt(ctxt)
        heads = ast_node.slhs + ast_node.plhs
        inspect = self.inspect

        # Extend the variable context with the variables of all atomic rule
        # heads (comprehension heads are handled after the head checks).
        ctxt['vars'] += self.get_rule_scope(heads, compre=False)

        # Check scope of rule heads. This step checks consistency of constant
        # names and scoping of comprehension patterns. Explicit loops are used
        # because the calls are made purely for their side effects.
        for head in heads:
            self.check_scope(head, ctxt, lhs=True)
        for guard in ast_node.grd:
            self.check_scope(guard, ctxt, lhs=True)

        ctxt['vars'] += self.get_rule_scope(heads, atoms=False)

        # Include exist variables scopes and check for overlaps with existing
        # variables. (We currently disallow them.)
        dup_vars = {}
        for v in ctxt['vars']:
            dup_vars[v.name] = [v]
        for ex_var in ast_node.exists:
            dup_vars.setdefault(ex_var.name, []).append(ex_var)
        ctxt['vars'] += ast_node.exists

        # Incrementally include the variables bound by where-assignments:
        # each assignment sees the variables bound by the ones before it.
        for ass_stmt in ast_node.where:
            self.check_scope(ass_stmt.builtin_exp, ctxt)
            self.compose_out_scope_error_report(ctxt)
            a_vars = inspect.filter_atoms(
                inspect.get_atoms(ass_stmt.term_pat), var=True)
            for a_var in a_vars:
                dup_vars.setdefault(a_var.name, []).append(a_var)
            ctxt['vars'] += a_vars
        self.compose_duplicate_error_reports("variables", dup_vars)

        for body_fact in ast_node.rhs:
            self.check_scope(body_fact, ctxt)

        self.compose_out_scope_error_report(ctxt)

    @visit.when(ast.FactBase)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt = copy_ctxt(ctxt)
        self.check_pred(ctxt, ast_node)
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.FactLoc)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt = copy_ctxt(ctxt)
        self.check_scope(ast_node.loc, ctxt)
        self.check_scope(ast_node.fact, ctxt)
        return ctxt

    @visit.when(ast.FactLocCluster)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt = copy_ctxt(ctxt)
        self.check_scope(ast_node.loc, ctxt)
        for fact in ast_node.facts:
            self.check_scope(fact, ctxt)
        return ctxt

    @visit.when(ast.FactCompre)
    def check_scope(self, ast_node, old_ctxt, lhs=False):
        """Check a comprehension pattern.

        On the right-hand side the comprehension ranges are checked against
        the current scope. On the left-hand side (``lhs=True``) they are not
        checked; instead their free variables are added to the caller's
        context once the pattern has been checked.
        """
        ctxt = copy_ctxt(old_ctxt)
        comp_ranges = ast_node.comp_ranges

        # Check scope of comprehension ranges
        if not lhs:
            for comp_range in comp_ranges:
                self.check_scope(comp_range, ctxt)
        self.compose_out_scope_error_report(ctxt)

        # Extend variable context with comprehension binders
        ctxt['vars'] += self.inspect.free_vars(
            [cr.term_vars for cr in comp_ranges])

        # With extended variable context, check scopes of the fact pattern
        # and guards
        for fact in ast_node.facts:
            self.check_scope(fact, ctxt)
        for guard in ast_node.guards:
            self.check_scope(guard, ctxt)
        self.compose_out_scope_error_report(ctxt)

        if lhs:
            old_ctxt['vars'] += self.inspect.free_vars(
                [cr.term_range for cr in comp_ranges])

    @visit.when(ast.CompRange)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term_range, ctxt)

    @visit.when(ast.TermCons)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_cons(ctxt, ast_node)
        return ctxt

    @visit.when(ast.TermVar)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_var(ctxt, ast_node)
        return ctxt

    @visit.when(ast.TermApp)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term1, ctxt)
        self.check_scope(ast_node.term2, ctxt)
        return ctxt

    @visit.when(ast.TermTuple)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.TermList)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.TermListCons)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term1, ctxt)
        self.check_scope(ast_node.term2, ctxt)
        return ctxt

    @visit.when(ast.TermMSet)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.TermBinOp)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term1, ctxt)
        self.check_scope(ast_node.term2, ctxt)
        return ctxt

    @visit.when(ast.TermUnaryOp)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term, ctxt)
        return ctxt

    @visit.when(ast.TermLit)
    def check_scope(self, ast_node, ctxt, lhs=False):
        return ctxt

    @visit.when(ast.TermUnderscore)
    def check_scope(self, ast_node, ctxt, lhs=False):
        return ctxt

    # Error state operations

    def flush_error_ctxt(self):
        """Discard all pending out-of-scope and duplicate records."""
        self.curr_out_scopes = new_ctxt()
        self.curr_duplicates = {'vars': {}}

    def check_var(self, ctxt, var):
        """Return True if ``var`` is in scope; otherwise record it, return False."""
        if lookup_var(ctxt, var):
            return True
        self.curr_out_scopes['vars'].append(var)
        return False

    def check_pred(self, ctxt, pred):
        """Return True if ``pred`` is in scope; otherwise record it, return False."""
        if lookup_pred(ctxt, pred):
            return True
        self.curr_out_scopes['preds'].append(pred)
        return False

    def check_cons(self, ctxt, cons):
        """Return True if ``cons`` is in scope; otherwise record it, return False."""
        if lookup_cons(ctxt, cons):
            return True
        self.curr_out_scopes['cons'].append(cons)
        return False

    # Get rule scope

    @visit.on('ast_node')
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        pass

    @visit.when(list)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        this_free_vars = []
        for obj in ast_node:
            this_free_vars += self.get_rule_scope(
                obj, atoms=atoms, compre=compre)
        return this_free_vars

    @visit.when(ast.FactBase)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        return self.inspect.free_vars(ast_node) if atoms else []

    @visit.when(ast.FactLoc)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        return self.inspect.free_vars(ast_node) if atoms else []

    @visit.when(ast.FactLocCluster)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        return self.inspect.free_vars(ast_node) if atoms else []

    @visit.when(ast.FactCompre)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        # Only a comprehension with exactly one range contributes its range
        # term to the rule scope.
        if compre and len(ast_node.comp_ranges) == 1:
            return [ast_node.comp_ranges[0].term_range]
        return []

    # Reporting

    def record_duplicates(self, tvar, ctxt):
        """Record ``tvar`` as a duplicate of same-named variables in ``ctxt``."""
        known = self.curr_duplicates['vars']
        if tvar.name not in known:
            dups = [t for t in ctxt['vars'] if tvar.name == t.name]
            dups.append(tvar)
            known[tvar.name] = dups
        else:
            known[tvar.name].append(tvar)

    def _report_out_of_scope(self, key, noun, ctxt):
        """Declare one error covering all pending out-of-scope items of ``key``.

        ``noun`` is the lower-case label used in the legend (for example
        "variable(s)"); the error message uses its capitalized form.
        """
        offenders = self.curr_out_scopes[key]
        if len(offenders) == 0:
            return
        legend = ("%s %s: Scope context %s.\n"
                  % (terminal.T_GREEN_BACK, terminal.T_NORM, noun)) + (
                  "%s %s: Out of scope %s."
                  % (terminal.T_RED_BACK, terminal.T_NORM, noun))
        names = ','.join(set(map(lambda t: t.name, offenders)))
        error_idx = self.declare_error(
            "%s %s not in scope." % (noun.capitalize(), names), legend)
        for item in offenders:
            self.extend_error(error_idx, item)
        for item in ctxt[key]:
            self.extend_info(error_idx, item)

    def compose_out_scope_error_report(self, ctxt):
        """Report all pending out-of-scope items, then clear them.

        The in-scope items of ``ctxt`` are attached to each report as
        additional information.
        """
        self._report_out_of_scope('vars', "variable(s)", ctxt)
        self._report_out_of_scope('preds', "predicate(s)", ctxt)
        self._report_out_of_scope('cons', "name(s)", ctxt)
        for exec_node in self.curr_out_scopes['ensem']:
            error_idx = self.declare_error(
                "Ensemble %s not in scope." % exec_node.name)
            self.extend_error(error_idx, exec_node)
        self.curr_out_scopes = new_ctxt()

    def compose_duplicate_error_reports(self, kind, dups):
        """Declare an error for every name in ``dups`` with more than one entry."""
        for name in dups:
            elems = dups[name]
            if len(elems) > 1:
                error_idx = self.declare_error(
                    "Duplicated declaration of %s %s." % (kind, name))
                for elem in elems:
                    self.extend_error(error_idx, elem)
class AlphaIndexer(Transformer):
    """Annotates variable occurrences in every rule with a ``rule_idx``.

    Each rule is indexed with its own ``FramedCtxt``. Named variables get the
    index the context returns for their name, underscores get a fresh index,
    and comprehension binders are indexed inside a frame pushed for the
    duration of the comprehension.
    """

    def __init__(self, decs):
        self.initialize(decs)
        self.inspect = Inspector()

    def transform(self):
        self.int_transform(self.decs)

    def _index_under_binders(self, comp_ranges, body, guards, ctxt):
        """Index a comprehension: ranges first, then binders, body and guards
        inside a frame keyed by the binder names."""
        for comp_range in comp_ranges:
            self.int_transform(comp_range.term_range, ctxt)
        binder_terms = [comp_range.term_vars for comp_range in comp_ranges]
        binders = self.inspect.free_vars(binder_terms)
        ctxt.push_frame(keys=set(binder.name for binder in binders))
        self.int_transform(binders, ctxt)
        self.int_transform(body, ctxt)
        self.int_transform(guards, ctxt)
        ctxt.pop_frame()

    @visit.on('ast_node')
    def int_transform(self, ast_node, ctxt=None):
        pass

    @visit.when(list)
    def int_transform(self, ast_node, ctxt=None):
        for element in ast_node:
            self.int_transform(element, ctxt)

    @visit.when(ast.EnsemDec)
    def int_transform(self, ast_node, ctxt=None):
        # Rules are visited without a context; each rule creates its own.
        for rule_dec in self.inspect.filter_decs(ast_node.decs, rule=True):
            self.int_transform(rule_dec)

    @visit.when(ast.RuleDec)
    def int_transform(self, ast_node, ctxt=None):
        ctxt = FramedCtxt()
        sections = (ast_node.plhs, ast_node.slhs, ast_node.grd,
                    ast_node.exists, ast_node.where, ast_node.rhs)
        for section in sections:
            self.int_transform(section, ctxt)
        # Record the context's index counter on the rule once it is indexed.
        ast_node.next_rule_idx = ctxt.var_idx

    @visit.when(ast.AssignDec)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term_pat, ctxt)
        self.int_transform(ast_node.builtin_exp, ctxt)

    @visit.when(ast.FactBase)
    def int_transform(self, ast_node, ctxt=None):
        for arg in ast_node.terms:
            self.int_transform(arg, ctxt)

    @visit.when(ast.FactLoc)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.loc, ctxt)
        self.int_transform(ast_node.fact, ctxt)

    @visit.when(ast.FactLocCluster)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.loc, ctxt)
        for member in ast_node.facts:
            self.int_transform(member, ctxt)

    @visit.when(ast.FactCompre)
    def int_transform(self, ast_node, ctxt=None):
        self._index_under_binders(
            ast_node.comp_ranges, ast_node.facts, ast_node.guards, ctxt)

    @visit.when(ast.TermVar)
    def int_transform(self, ast_node, ctxt=None):
        ast_node.rule_idx = ctxt.get_index(ast_node.name)

    @visit.when(ast.TermUnderscore)
    def int_transform(self, ast_node, ctxt=None):
        ast_node.rule_idx = ctxt.new_index()

    @visit.when(ast.TermApp)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term1, ctxt)
        self.int_transform(ast_node.term2, ctxt)

    @visit.when(ast.TermTuple)
    def int_transform(self, ast_node, ctxt=None):
        for component in ast_node.terms:
            self.int_transform(component, ctxt)

    @visit.when(ast.TermList)
    def int_transform(self, ast_node, ctxt=None):
        for item in ast_node.terms:
            self.int_transform(item, ctxt)

    @visit.when(ast.TermListCons)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term1, ctxt)
        self.int_transform(ast_node.term2, ctxt)

    @visit.when(ast.TermMSet)
    def int_transform(self, ast_node, ctxt=None):
        for member in ast_node.terms:
            self.int_transform(member, ctxt)

    @visit.when(ast.TermEnumMSet)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.texp1, ctxt)
        self.int_transform(ast_node.texp2, ctxt)

    @visit.when(ast.TermCompre)
    def int_transform(self, ast_node, ctxt=None):
        self._index_under_binders(
            ast_node.comp_ranges, ast_node.term, ast_node.guards, ctxt)

    @visit.when(ast.TermBinOp)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term1, ctxt)
        self.int_transform(ast_node.term2, ctxt)

    @visit.when(ast.TermUnaryOp)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term, ctxt)
def int_check(self, ast_node):
    """Derive per-predicate and per-rule properties for one ensemble.

    Scans every rule of ``ast_node`` and then annotates:

    * fact declarations with ``persistent``, ``local``, ``monotone`` and
      ``uses_priority``;
    * rule declarations with ``unique_head_names`` and
      ``rule_priority_body_names``;
    * rule heads with ``unique_head`` / ``collision_idx``, and rule body
      facts with ``local`` (only when non-local) and ``monotone``.

    Also fills ``self.rule_unique_heads``, ``self.rule_priority_body`` and
    ``self.fact_decs``.
    """
    inspect = Inspector()
    decs = ast_node.decs

    # Name tables used as sets: the keys matter, the values are unused.
    simplified_pred_names = {}
    non_local_pred_names = {}
    lhs_compre_pred_names = {}
    prioritized_pred_names = {}

    for rule_dec in inspect.filter_decs(decs, rule=True):
        simp_heads = rule_dec.slhs
        prop_heads = rule_dec.plhs
        rule_body = rule_dec.rhs

        # Scan for simplified predicate names
        for fact in inspect.get_base_facts(simp_heads):
            simplified_pred_names[fact.name] = ()

        # Scan for non local predicate names
        # Annotates non local rule body facts as well.
        loc_var_terms = inspect.free_vars(simp_heads + prop_heads, args=False)
        # A real list (not a lazy map) because it is indexed below.
        loc_vars = [t.name for t in loc_var_terms]
        if len(set(loc_vars)) > 1:
            # Flag all body predicates as non local
            for fact in inspect.get_base_facts(rule_body):
                non_local_pred_names[fact.name] = ()
                fact.local = False
        else:
            # NOTE(review): raises IndexError if the heads expose no location
            # variable at all -- confirm that this cannot happen.
            loc_var = loc_vars[0]
            (bfs, lfs, lfcs, comps) = inspect.partition_rule_heads(rule_body)
            for lf in lfs:
                if isinstance(lf.loc, ast.TermVar):
                    if lf.loc.name != loc_var:
                        non_local_pred_names[lf.fact.name] = ()
                        lf.fact.local = False
                else:
                    # Location is not variable, hence treat as non-local
                    non_local_pred_names[lf.fact.name] = ()
                    lf.fact.local = False
            for lfc in lfcs:
                if isinstance(lfc.loc, ast.TermVar):
                    if lfc.loc.name != loc_var:
                        for f in lfc.facts:
                            non_local_pred_names[f.name] = ()
                            f.local = False
                else:
                    # Location is not variable, hence treat as non-local
                    for f in lfc.facts:
                        non_local_pred_names[f.name] = ()
                        f.local = False
            for comp in comps:
                # Assumes that comprehension fact patterns are solo
                loc_fact = comp.facts[0]
                if loc_fact.loc.name != loc_var:
                    is_non_local = True
                else:
                    # Same location name as the rule, but it is one of the
                    # comprehension's own binders.
                    binder_names = [
                        tv.name for tv in
                        inspect.free_vars(comp.comp_ranges[0].term_vars)]
                    is_non_local = loc_var in binder_names
                if is_non_local:
                    # Keyed by predicate name: the table is looked up by
                    # ``fact_dec.name`` below. The original keyed it by the
                    # location variable name and, in one branch, wrote to an
                    # undefined name (NameError).
                    non_local_pred_names[loc_fact.fact.name] = ()
                    loc_fact.fact.local = False

        # Scan for LHS comprehension predicate names
        (bfs, lfs, lfcs, comps) = inspect.partition_rule_heads(
            simp_heads + prop_heads)
        for comp in comps:
            loc_fact = comp.facts[0]
            lhs_compre_pred_names[loc_fact.fact.name] = ()

        # Scan for non-unique rule heads
        rule_head_pred_names = {}
        for fact in inspect.get_base_facts(simp_heads + prop_heads):
            rule_head_pred_names.setdefault(fact.name, []).append(fact)
        self.rule_unique_heads[rule_dec.name] = []
        collision_idx = 0
        for name in rule_head_pred_names:
            facts = rule_head_pred_names[name]
            unique_head = len(facts) == 1
            for fact in facts:
                fact.unique_head = unique_head
                fact.collision_idx = collision_idx
                collision_idx += 1
            if unique_head:
                self.rule_unique_heads[rule_dec.name].append(name)

        # Scan for priorities
        priority_body = {}
        self.rule_priority_body[rule_dec.name] = priority_body
        (bfs, lfs, lfcs, comps) = inspect.partition_rule_heads(rule_body)
        for bf in bfs:
            if bf.priority is not None:
                prioritized_pred_names[bf.name] = ()
                priority_body[bf.name] = ()
        for lf in lfs:
            if lf.priority is not None:
                prioritized_pred_names[lf.fact.name] = ()
                priority_body[lf.fact.name] = ()
        for lfc in lfcs:
            if lfc.priority is not None:
                for f in lfc.facts:
                    prioritized_pred_names[f.name] = ()
                    priority_body[f.name] = ()
        for comp in comps:
            if comp.priority is not None:
                # NOTE(review): elsewhere comprehension facts are accessed as
                # ``comp.facts[0].fact.name``; confirm ``f.name`` is valid here.
                for f in comp.facts:
                    prioritized_pred_names[f.name] = ()
                    priority_body[f.name] = ()

    # Annotate fact declaration nodes with relevant information
    fact_decs = inspect.filter_decs(decs, fact=True)
    for fact_dec in fact_decs:
        fact_dec.persistent = fact_dec.name not in simplified_pred_names
        fact_dec.local = fact_dec.name not in non_local_pred_names
        fact_dec.monotone = fact_dec.name not in lhs_compre_pred_names
        fact_dec.uses_priority = fact_dec.name in prioritized_pred_names
    self.fact_decs = fact_decs

    # Annotate rule declaration nodes with relevant information
    rule_decs = inspect.filter_decs(decs, rule=True)
    for rule_dec in rule_decs:
        rule_dec.unique_head_names = self.rule_unique_heads[rule_dec.name]
        rule_dec.rule_priority_body_names = list(
            self.rule_priority_body[rule_dec.name].keys())

    # Annotate RHS constraints with monotonicity information
    for rule_dec in rule_decs:
        for fact in inspect.get_base_facts(rule_dec.rhs):
            fact.monotone = fact.name not in lhs_compre_pred_names
class NeighborRestrictChecker(Checker):
    """Classifies rules by how many locations their heads must match at.

    A rule whose heads all sit at one location is marked as not
    system-centric. Otherwise the checker searches for primary locations
    whose head arguments mention every other location, and declares an error
    when no such option exists.
    """

    def __init__(self, decs, source_text, builtin_preds=None):
        # A fresh list per instance: a ``[]`` default would be shared by
        # every checker constructed without the argument.
        if builtin_preds is None:
            builtin_preds = []
        self.initialize(decs, source_text, builtin_preds=builtin_preds)
        self.inspect = Inspector()
        self.fact_dict = {}

    def check(self):
        for ensem_dec in self.inspect.filter_decs(self.decs, ensem=True):
            self.checkEnsem(ensem_dec)

    def checkEnsem(self, ensem_dec):
        """Check every rule of an ensemble and annotate the ensemble with
        ``max_nbr_level`` and ``requires_sync``."""
        self.fact_dict = {}
        for fact_dec in self.inspect.filter_decs(ensem_dec.decs, fact=True):
            self.fact_dict[fact_dec.name] = fact_dec

        max_nbr_level = -1
        some_requires_sync = False
        for rule_dec in self.inspect.filter_decs(ensem_dec.decs, rule=True):
            nbr_level, requires_sync = self.checkRule(rule_dec)
            if max_nbr_level < nbr_level:
                max_nbr_level = nbr_level
            if requires_sync:
                some_requires_sync = True
        ensem_dec.max_nbr_level = max_nbr_level
        ensem_dec.requires_sync = some_requires_sync

    def checkRule(self, rule_dec):
        """Annotate ``rule_dec`` and return ``(nbr_level, requires_sync)``."""
        match_obligations = {}
        for fact in rule_dec.slhs:
            self.checkFact(fact, match_obligations, True)
        for fact in rule_dec.plhs:
            self.checkFact(fact, match_obligations, False)

        if len(match_obligations) <= 1:
            # All heads share one location.
            rule_dec.primary_loc = list(match_obligations.keys())[0]
            rule_dec.is_system_centric = False
            rule_dec.requires_sync = False
            return (0, False)

        rule_dec.is_system_centric = True
        rule_dec.match_obligations = match_obligations

        # TODO: Check neighbor relation and determine viable primary location
        nbr_options = []
        for primary_loc in match_obligations:
            my_facts = retrieveAll(match_obligations[primary_loc])
            other_locs = [loc for loc in match_obligations
                          if primary_loc != loc]

            my_args = []
            for fact in my_facts:
                my_args += [
                    v.name for v in
                    self.inspect.free_vars(fact, loc=False, args=True)]

            # The primary location is only viable if its head arguments
            # mention every other location.
            if not subseteq(other_locs, my_args):
                continue

            other_fact_dict = {}
            for loc in other_locs:
                other_fact_dict[loc] = match_obligations[loc]

            has_trigger = False
            for fact in my_facts:
                if ast.TRIGGER_FACT == self.getFactRole(fact):
                    has_trigger = True

            primary_grds = []
            other_grds = {}
            for loc in other_fact_dict:
                other_grds[loc] = []
            non_iso_grds = []
            for grd in rule_dec.grd:
                # NOTE(review): guards are tested as term objects against
                # lists of variable *names*; confirm subseteq handles this.
                if subseteq(self.inspect.free_vars(grd), my_args):
                    primary_grds.append(grd)
                    continue
                grd_added = False
                for loc in other_fact_dict:
                    loc_obligation = other_fact_dict[loc]
                    other_args = [
                        v.name for v in
                        self.inspect.free_vars(
                            loc_obligation, loc=False, args=True)]
                    if subseteq(self.inspect.free_vars(grd),
                                my_args + other_args):
                        other_grds[loc].append(grd)
                        grd_added = True
                if not grd_added:
                    non_iso_grds.append(grd)

            nbr_option = {
                'primary_loc': primary_loc,
                'primary_obligation': match_obligations[primary_loc],
                'primary_guards': primary_grds,
                'other_obligations': other_fact_dict,
                'other_guards': other_grds,
                'non_iso_guards': non_iso_grds,
                'primary_has_trigger': has_trigger,
            }
            if len(non_iso_grds) == 0:
                # Options whose primary location has a trigger fact go first.
                if has_trigger:
                    nbr_options = [nbr_option] + nbr_options
                else:
                    nbr_options.append(nbr_option)

        if len(nbr_options) < 1:
            error_idx = self.declare_error(
                "System-centric rule is not neighbor-restricted.")
            self.extend_error(error_idx, rule_dec)

        rule_dec.nbr_options = nbr_options
        # Every location except the primary one counts as a neighbor.
        rule_dec.nbr_level = len(match_obligations) - 1

        # Currently always requires sync
        # TODO: Check LHS patterns and programmer pragmas
        rule_dec.requires_sync = True

        return (rule_dec.nbr_level, rule_dec.requires_sync)

    @visit.on('fact')
    def checkFact(self, fact, match_obligations, is_simp):
        pass

    @visit.when(ast.FactLoc)
    def checkFact(self, fact, match_obligations, is_simp):
        vs = self.inspect.free_vars(fact.loc)
        role = self.getFactRole(fact)
        extend(match_obligations, vs[0], fact, is_simp, role)

    @visit.when(ast.FactCompre)
    def checkFact(self, fact, match_obligations, is_simp):
        # The location is taken from the comprehension's first fact pattern.
        vs = self.inspect.free_vars(fact.facts[0].loc)
        role = self.getFactRole(fact)
        extend(match_obligations, vs[0], fact, is_simp, role)

    @visit.on('fact')
    def getFactRole(self, fact):
        pass

    @visit.when(ast.FactLoc)
    def getFactRole(self, fact):
        return self.fact_dict[fact.fact.name].fact_role

    @visit.when(ast.FactCompre)
    def getFactRole(self, fact):
        return self.fact_dict[fact.facts[0].fact.name].fact_role
def check_int(self, ast_node):
    """Run the per-fact check on all heads of a rule.

    The free variables of the rule heads are published on
    ``self.rule_free_vars`` before the heads are visited and cleared again
    (set to ``[]``) once every head has been checked.
    """
    all_heads = ast_node.plhs + ast_node.slhs
    inspector = Inspector()
    self.rule_free_vars = inspector.free_vars(all_heads)
    for head_fact in all_heads:
        self.check_int(head_fact)
    self.rule_free_vars = []