Esempio n. 1
0
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Scope-check an ensemble declaration and store its resulting context.

        Works on a copy of ``ctxt`` made by ``copy_ctxt``. Extern declarations
        are checked first. Fact declarations come next: each one's context is
        merged into the working copy, and its predicates are grouped by name so
        duplicates can be reported. Rule declarations are checked last. The
        working copy, with its ``'vars'`` entry reset to an empty list, is saved
        in ``self.ensems[ast_node.name]``.

        ``lhs`` is not used by this overload.

        Returns the ``ctxt`` object that was passed in, not the working copy.
        """
        outer_ctxt = ctxt
        scope = copy_ctxt(outer_ctxt)
        members = ast_node.decs
        inspector = Inspector()

        for extern_dec in inspector.filter_decs(members, extern=True):
            self.check_scope(extern_dec, scope)

        preds_by_name = {}
        for fact_dec in inspector.filter_decs(members, fact=True):
            fact_ctxt = self.check_scope(fact_dec, scope)
            for pred in fact_ctxt['preds']:
                preds_by_name.setdefault(pred.name, []).append(pred)
            extend_ctxt(scope, fact_ctxt)

        self.compose_duplicate_error_reports("predicate", preds_by_name)

        for rule_dec in inspector.filter_decs(members, rule=True):
            self.check_scope(rule_dec, scope)

        scope['vars'] = []
        self.ensems[ast_node.name] = scope
        return outer_ctxt
Esempio n. 2
0
	def outputVars(self, head):
		"""Return the terms this lookup binds (its outputs) for ``head``.

		Head argument slots are the location followed by the fact's terms.
		Guard argument slots are the free variables of each associated
		guard's ``term1`` then ``term2``. A slot marked ``INPUT`` is skipped.
		A slot marked ``OUTPUT`` contributes its term. Any other slot contributes
		its term only the first time its ``rule_idx`` is seen (the bookkeeping
		is shared across head and guard slots).
		"""
		atom = head.fact if head.is_atom else head.fact.facts[0]
		arg_terms = [atom.loc] + atom.fact.terms
		outputs = []
		reported = set()

		def collect(slots, terms):
			# Append the output terms of ``slots`` (parallel to ``terms``).
			for pos, slot in enumerate(slots):
				if isinstance(slot, str) and slot == INPUT:
					continue
				if isinstance(slot, str) and slot == OUTPUT:
					outputs.append(terms[pos])
				elif slot.rule_idx not in reported:
					outputs.append(terms[pos])
					reported.add(slot.rule_idx)

		collect(self.pred_args, arg_terms)

		inspector = Inspector()
		guard_terms = []
		for guard in self.assoc_guards:
			guard_terms += inspector.free_vars(guard.term1) + inspector.free_vars(guard.term2)
		collect(self.guard_args, guard_terms)

		return outputs
Esempio n. 3
0
	def check_scope(self, ast_node, ctxt, lhs=False):
		"""Scope-check an ensemble declaration and store its resulting context.

		Works on a copy of ``ctxt`` made by ``copy_ctxt``. Extern declarations
		are checked first. Fact declarations come next: each one's context is
		merged into the working copy, and its predicates are grouped by name
		so duplicates can be reported. Rule declarations are checked last. The
		working copy, with its ``'vars'`` entry reset to an empty list, is
		saved in ``self.ensems[ast_node.name]``.

		``lhs`` is not used by this overload.

		Returns the ``ctxt`` object that was passed in, not the working copy.
		"""
		outer_ctxt = ctxt
		scope = copy_ctxt(outer_ctxt)
		members = ast_node.decs
		inspector = Inspector()

		for extern_dec in inspector.filter_decs(members, extern=True):
			self.check_scope(extern_dec, scope)

		preds_by_name = {}
		for fact_dec in inspector.filter_decs(members, fact=True):
			fact_ctxt = self.check_scope(fact_dec, scope)
			for pred in fact_ctxt['preds']:
				preds_by_name.setdefault(pred.name, []).append(pred)
			extend_ctxt(scope, fact_ctxt)

		self.compose_duplicate_error_reports("predicate", preds_by_name)

		for rule_dec in inspector.filter_decs(members, rule=True):
			self.check_scope(rule_dec, scope)

		scope['vars'] = []
		self.ensems[ast_node.name] = scope
		return outer_ctxt
Esempio n. 4
0
	def __init__(self, head, head_idx, init_binders=False):
		"""Capture a comprehension head and its first comprehension range.

		``term_vars`` is the unfolded term sequence of the first range's
		binders. ``compre_dom`` is that range's domain term.
		"""
		inspector = Inspector()
		self.head = head
		self.head_idx = head_idx
		first_range = head.fact.comp_ranges[0]
		self.term_vars = inspector.unfold_term_seq(first_range.term_vars)
		self.compre_dom = first_range.term_range
		self.init_binders = init_binders
Esempio n. 5
0
class LHSCompre(Transformer):
	"""Desugar LHS comprehension modifiers into explicit size guards.

	The rewrite applies to every rule of the first ensemble declaration. Each
	LHS comprehension with modifier ``ast.COMP_NONE_EXISTS`` gets the guard
	``size(range) == 0``. Each with ``ast.COMP_ONE_OR_MORE`` gets
	``size(range) > 0``. The guard is appended to the rule's ``grd`` list.
	A comprehension without a range is first given a fresh multiset range
	variable ``comp_mod_<idx>`` so the guard has something to measure.
	"""

	def __init__(self, decs):
		self.initialize(decs)
		self.inspect = Inspector()

	def transform(self):
		"""Rewrite every rule of the first ensemble declaration in place."""
		ensem_dec = self.inspect.filter_decs(self.decs, ensem=True)[0]
		rule_decs = self.inspect.filter_decs(ensem_dec.decs, rule=True)
		# Explicit loop rather than map(): map() is lazy on Python 3, so a
		# map() used only for side effects would never run.
		for rule_dec in rule_decs:
			self.transformRule(rule_dec)

	def transformRule(self, rule_dec):
		"""Rewrite each comprehension among the rule's propagated and simplified heads."""
		_, _, _, compres = self.inspect.partition_rule_heads(rule_dec.plhs + rule_dec.slhs)
		for compre in compres:
			self.transformCompre(compre, rule_dec)

	def transformCompre(self, compre, rule_dec):
		"""Append the size guard implied by ``compre``'s modifier to ``rule_dec.grd``.

		Comprehensions with any other modifier are left untouched.
		"""
		if compre.compre_mod == ast.COMP_NONE_EXISTS:
			size_op = "=="
		elif compre.compre_mod == ast.COMP_ONE_OR_MORE:
			size_op = ">"
		else:
			return
		if len(compre.comp_ranges) == 0:
			# No range to measure: bind a fresh multiset-of-int range variable,
			# numbered from rule_dec.next_rule_idx (which is then advanced).
			term_vars = ast.TermLit(1, "int")
			next_rule_idx = rule_dec.next_rule_idx
			rule_dec.next_rule_idx += 1
			term_range = ast.TermVar("comp_mod_%s" % next_rule_idx)
			term_range.rule_idx = next_rule_idx
			term_range.type = ast.TypeMSet(term_vars.type)
			term_range.smt_type = tyMSet(tyInt)
			compre.comp_ranges = [ast.CompRange(term_vars, term_range)]
		term_range = compre.comp_ranges[0].term_range
		size_grd = ast.TermBinOp(ast.TermApp(ast.TermCons("size"), term_range), size_op, ast.TermLit(0, "int"))
		rule_dec.grd += [size_grd]

	def getVarsFilterByRuleIdx(self, facts):
		"""Return the free variables of ``facts``, keeping only the first one per ``rule_idx``."""
		uniq_vars = []
		seen_idxs = set()
		for fact_var in self.getVars(facts):
			if fact_var.rule_idx not in seen_idxs:
				seen_idxs.add(fact_var.rule_idx)
				uniq_vars.append(fact_var)
		return uniq_vars

	# getVars dispatches on the type of ``fact`` via the ``visit`` decorators.
	@visit.on( 'fact' )
	def getVars(self, fact):
		pass

	@visit.when( list )
	def getVars(self, fact):
		# Presumably foldl concatenates the per-element variable lists.
		# A list (not a lazy map object) is passed on every Python version.
		return foldl([self.getVars(f) for f in fact], [])

	@visit.when( ast.FactLoc )
	def getVars(self, fact):
		return self.inspect.free_vars(fact)
Esempio n. 6
0
	def __init__(self, fact_dir):
		"""Start with contexts from ``emptyset()`` and no pending guards.

		``fact_dir`` is kept for predicate lookups by name.
		"""
		self.bs_eq_ctxt = emptyset()
		self.eq_ctxt = emptyset()
		self.mem_grds, self.ord_grds = [], []
		self.eq_grds, self.non_idx_grds = [], []
		self.inspect = Inspector()
		self.fact_dir = fact_dir
Esempio n. 7
0
File: rules.py Progetto: sllam/chrcp
def process_rules(ensem_dec):
    """Split ``ensem_dec`` into a fact directory and compiled rules.

    :param ensem_dec: declarations to scan; passed directly to
        ``Inspector.filter_decs``.
    :return: ``(fact_dir, rules)``. ``fact_dir`` is a ``FactDirectory`` over
        the fact declarations. ``rules`` is a list holding one ``Rule`` per
        rule declaration, in declaration order.
    """
    inspect = Inspector()

    facts = inspect.filter_decs(ensem_dec, fact=True)
    rules = inspect.filter_decs(ensem_dec, rule=True)

    fact_dir = FactDirectory(facts)

    # List comprehension rather than map(): on Python 3 map() returns a lazy,
    # single-use iterator instead of a list.
    return (fact_dir, [Rule(r, fact_dir) for r in rules])
Esempio n. 8
0
 def __init__(self, head, head_idx):
     """Record a comprehension head whose first fact pattern is matched.

     ``fact`` is the comprehension's first fact pattern. ``output_vars``
     holds that pattern's free variables.
     """
     compre = head.fact
     self.initialize(compre)
     self.head_idx = head_idx
     self.fact = compre.facts[0]
     self.fact_pat = head.fact_pat
     self.head = head
     self.compre = compre
     self.output_vars = Inspector().free_vars(self.fact)
Esempio n. 9
0
	def __init__(self, head, head_idx):
		"""Record a comprehension head whose first fact pattern is matched.

		``fact`` is the comprehension's first fact pattern. ``output_vars``
		holds that pattern's free variables.
		"""
		compre = head.fact
		self.initialize(compre)
		self.head_idx = head_idx
		self.fact = compre.facts[0]
		self.fact_pat = head.fact_pat
		self.head = head
		self.compre = compre
		self.output_vars = Inspector().free_vars(self.fact)
Esempio n. 10
0
File: rules.py Progetto: sllam/chrcp
def process_rules(ensem_dec):
	"""Split ``ensem_dec`` into a fact directory and compiled rules.

	:param ensem_dec: declarations to scan; passed directly to
	    ``Inspector.filter_decs``.
	:return: ``(fact_dir, rules)``. ``fact_dir`` is a ``FactDirectory`` over
	    the fact declarations. ``rules`` is a list holding one ``Rule`` per
	    rule declaration, in declaration order.
	"""
	inspect = Inspector()

	facts   = inspect.filter_decs(ensem_dec, fact=True)
	rules   = inspect.filter_decs(ensem_dec, rule=True)

	fact_dir = FactDirectory( facts )

	# List comprehension rather than map(): on Python 3 map() returns a lazy,
	# single-use iterator instead of a list.
	return (fact_dir, [Rule(r, fact_dir) for r in rules])
Esempio n. 11
0
	def __init__(self, head, head_idx):
		"""Record a comprehension head, its output variables and binders.

		``output_vars`` holds the free variables of the first fact pattern.
		``compre_binders`` holds the comprehension's binder variables,
		queried with ``loc=False, args=False, compre_binders=True``.
		"""
		compre = head.fact
		self.initialize(compre)
		self.head_idx = head_idx
		self.fact = compre.facts[0]
		self.fact_pat = head.fact_pat
		self.head = head
		self.compre = compre
		inspector = Inspector()
		self.output_vars = inspector.free_vars(self.fact)
		self.compre_binders = inspector.free_vars(compre, loc=False, args=False, compre_binders=True)
Esempio n. 12
0
	def __init__(self, compre_body, body_idx, compre_idx):
		"""Capture an RHS comprehension body together with its first range.

		Copies the body's ``fact``, ``local``, ``priority`` and ``monotone``
		attributes. ``term_vars`` holds the free variables of the first
		range's binders, and ``compre_dom`` is that range's domain term.
		"""
		inspector = Inspector()
		self.body = compre_body
		self.fact = compre_body.fact
		for attr in ('local', 'priority', 'monotone'):
			setattr(self, attr, getattr(compre_body, attr))
		first_range = compre_body.fact.comp_ranges[0]
		self.term_vars = inspector.free_vars(first_range.term_vars)
		self.compre_dom = first_range.term_range
		self.body_idx = body_idx
		self.compre_idx = compre_idx
Esempio n. 13
0
 def __init__(self, compre_body, body_idx, compre_idx):
     """Capture an RHS comprehension body together with its first range.

     Copies the body's ``fact``, ``local``, ``priority`` and ``monotone``
     attributes. ``term_vars`` holds the free variables of the first range's
     binders, and ``compre_dom`` is that range's domain term.
     """
     inspector = Inspector()
     self.body = compre_body
     self.fact = compre_body.fact
     for attr in ('local', 'priority', 'monotone'):
         setattr(self, attr, getattr(compre_body, attr))
     first_range = compre_body.fact.comp_ranges[0]
     self.term_vars = inspector.free_vars(first_range.term_vars)
     self.compre_dom = first_range.term_range
     self.body_idx = body_idx
     self.compre_idx = compre_idx
Esempio n. 14
0
def process_ensemble(ensem_dec):
	"""Split an ensemble declaration into its fact directory, externs and rules.

	:param ensem_dec: ensemble declaration; its ``decs`` are scanned.
	:return: ``(fact_dir, externs, rules)``. ``fact_dir`` is a
	    ``FactDirectory`` over the fact declarations and ``externs`` holds
	    the extern declarations. ``rules`` is a list with one ``Rule`` per
	    rule declaration.
	"""
	inspect = Inspector()
	facts   = inspect.filter_decs(ensem_dec.decs, fact=True)
	rules   = inspect.filter_decs(ensem_dec.decs, rule=True)
	externs = inspect.filter_decs(ensem_dec.decs, extern=True)

	fact_dir = FactDirectory( facts )

	# List comprehension rather than map(): on Python 3 map() is a lazy,
	# single-use iterator. A caller that iterates the rules once and then
	# stores them would be left holding an exhausted iterator.
	return (fact_dir, externs, [Rule(r, fact_dir) for r in rules])
Esempio n. 15
0
 def __init__(self, head, head_idx, lookup, var_gen):
     """Bind a comprehension head to the lookup chosen for it.

     ``input_vars`` and ``output_vars`` come from ``lookup.inputVars`` /
     ``lookup.outputVars`` for this head. ``var_gen`` is accepted but not
     used by this constructor.
     """
     self.initialize()
     inspector = Inspector()
     first_range = head.fact.comp_ranges[0]
     self.term_vars = inspector.free_vars(first_range.term_vars)
     self.compre_dom = first_range.term_range
     self.lookup = lookup
     self.input_vars = lookup.inputVars(head)
     self.output_vars = lookup.outputVars(head)
     self.head_idx = head_idx
     self.head = head
     self.fact = head.fact
     self.fact_pat = head.fact_pat
Esempio n. 16
0
	def __init__(self, head, head_idx, lookup, var_gen):
		"""Bind a comprehension head to the lookup chosen for it.

		``input_vars`` and ``output_vars`` come from ``lookup.inputVars`` /
		``lookup.outputVars`` for this head. ``var_gen`` is accepted but not
		used by this constructor.
		"""
		self.initialize()
		inspector = Inspector()
		first_range = head.fact.comp_ranges[0]
		self.term_vars = inspector.free_vars(first_range.term_vars)
		self.compre_dom = first_range.term_range
		self.lookup = lookup
		self.input_vars = lookup.inputVars(head)
		self.output_vars = lookup.outputVars(head)
		self.head_idx = head_idx
		self.head = head
		self.fact = head.fact
		self.fact_pat = head.fact_pat
Esempio n. 17
0
def process_prog(decs, prog_name):
	"""Build a ``ProgCompilation`` from a program's top-level declarations.

	Assumes there is exactly one ensemble declaration and one execute
	declaration for it. Only the first of each is used, and indexing ``[0]``
	fails if either is missing.
	"""
	inspector = Inspector()
	ensem_dec = inspector.filter_decs(decs, ensem=True)[0]
	exec_dec = inspector.filter_decs(decs, execute=True)[0]
	fact_dir, externs, rules = process_ensemble(ensem_dec)
	return ProgCompilation(ensem_dec, rules, fact_dir, externs, exec_dec, prog_name)
Esempio n. 18
0
	def check_int(self, ast_node):
		"""Report LHS comprehension shapes that are not supported.

		An error is declared when the comprehension has more than one range;
		otherwise each range is checked recursively. Another error is declared
		when its facts do not reduce to exactly one base fact.
		"""
		ranges = ast_node.comp_ranges
		if len(ranges) <= 1:
			for rng in ranges:
				self.check_int(rng)
		else:
			err = self.declare_error("Unsupported LHS Pattern: Comprehension pattern with multiple comprehension range.")
			for rng in ranges:
				self.extend_error(err, rng)
		base_facts = Inspector().get_base_facts(ast_node.facts)
		if len(base_facts) == 1:
			return
		err = self.declare_error("Unsupported LHS Pattern: Comprehension pattern with multiple fact patterns.")
		for fact in ast_node.facts:
			self.extend_error(err, fact)
Esempio n. 19
0
 def initialize(self, fact, fact_name, fact_dir, is_compre):
     """Record a head fact, a fresh fact id and its predicate index.

     ``fact_idx`` is the first element of ``fact_dir.getFactFromName``.
     ``is_atom`` is simply the negation of ``is_compre``.
     """
     self.fact = fact
     self.id = get_next_fact_index()
     self.inspect = Inspector()
     pred_idx, _unused = fact_dir.getFactFromName(fact_name)
     self.fact_idx = pred_idx
     self.is_compre = is_compre
     self.is_atom = not is_compre
Esempio n. 20
0
 def check_int(self, ast_node):
     """Report LHS comprehension shapes that are not supported.

     An error is declared when the comprehension has more than one range;
     otherwise each range is checked recursively. Another error is declared
     when its facts do not reduce to exactly one base fact.
     """
     ranges = ast_node.comp_ranges
     if len(ranges) <= 1:
         for rng in ranges:
             self.check_int(rng)
     else:
         err = self.declare_error(
             "Unsupported LHS Pattern: Comprehension pattern with multiple comprehension range."
         )
         for rng in ranges:
             self.extend_error(err, rng)
     base_facts = Inspector().get_base_facts(ast_node.facts)
     if len(base_facts) == 1:
         return
     err = self.declare_error(
         "Unsupported LHS Pattern: Comprehension pattern with multiple fact patterns."
     )
     for fact in ast_node.facts:
         self.extend_error(err, fact)
Esempio n. 21
0
	def __init__(self, fact_dir, loc_fact, ord_guard, is_left_input, var_ctxt):
		"""Build an order-guard lookup (``ord`` / ``hash+ord``) for ``loc_fact``.

		Each predicate argument (location, then fact terms) is classified:

		* ``INPUT`` if its variable index is in ``self.var_dependencies``;
		* the term itself if it occurs in the order guard;
		* ``OUTPUT`` otherwise.

		``is_left_input`` selects which side of the guard is the ``INPUT`` slot.
		The guard template uses ``<=`` when ``ord_guard.include_eq`` is set and
		``<`` otherwise.
		"""
		pred_idx, _ = fact_dir.getFactFromName(loc_fact.fact.name)
		inspector = Inspector()
		fact_vars = inspector.free_var_idxs(loc_fact.loc) | inspector.free_var_idxs(loc_fact.fact)
		self.degree_freedom = len(fact_vars) - 1
		self.degree_join = 1
		ord_vars = inspector.free_var_idxs(ord_guard.term1) | inspector.free_var_idxs(ord_guard.term2)
		self.var_dependencies = (fact_vars | ord_vars) & var_ctxt

		uses_hash = False
		pred_slots = []
		for arg in [loc_fact.loc] + loc_fact.fact.terms:
			if arg.rule_idx in self.var_dependencies:
				pred_slots.append(INPUT)
				uses_hash = True
			elif arg.rule_idx in ord_vars:
				pred_slots.append(arg)
			else:
				pred_slots.append(OUTPUT)

		guard_slots = [INPUT, ord_guard.term2] if is_left_input else [ord_guard.term1, INPUT]
		comparison = '<=' if ord_guard.include_eq else '<'
		guard_template = '%s ' + comparison + ' %s'
		lookup_name = "hash+ord" if uses_hash else "ord"
		self.initialize(ORD_LK, pred_idx, fact_dir, lookup_name, pred_args=pred_slots, guard_args=guard_slots,
		                guard_str=guard_template, assoc_guards=[ord_guard])
Esempio n. 22
0
	def __init__(self, ensem_dec, rules, fact_dir, extern_decs, exec_dec, prog_name, source_text="", origin_text=""):
		"""Compile every rule of an ensemble and index the results.

		A ``RuleCompilation`` is built for each rule against a shared
		``LookupTables``. The tables are then padded with linear and exported
		lookups. Join orderings are grouped by their ``fact_idx`` in
		``pred_rule_compilations``. ``transformed`` is true when a non-empty
		``origin_text`` is given. ``role_dict`` maps each role name to
		``{'sig': ..., 'def': ...}``.
		"""
		self.ensem_dec = ensem_dec
		self.prog_name = prog_name
		self.ensem_name = ensem_dec.name
		self.fact_dir = fact_dir
		self.lookup_tables = LookupTables(fact_dir)
		self.rule_compilations = [RuleCompilation(rule, fact_dir, self.lookup_tables) for rule in rules]
		self.lookup_tables.padWithLinearLookup()
		self.lookup_tables.padWithExportedLookup()
		self.rules = rules

		by_pred = defaultdict(list)
		for rule_comp in self.rule_compilations:
			for ordering in rule_comp.join_orderings:
				by_pred[ordering.fact_idx].append(ordering)
		self.pred_rule_compilations = by_pred

		self.extern_decs = extern_decs
		self.exec_dec = exec_dec

		self.source_text = source_text
		self.origin_text = origin_text
		self.transformed = len(origin_text) > 0

		inspector = Inspector()
		sig_decs = inspector.filter_decs(ensem_dec.decs, rolesig=True)
		def_decs = inspector.filter_decs(ensem_dec.decs, roledef=True)
		roles = {}
		for sig_dec in sig_decs:
			roles[sig_dec.name] = {'sig': sig_dec}
		# A role definition without a matching signature raises KeyError here.
		for def_dec in def_decs:
			roles[def_dec.name]['def'] = def_dec
		self.role_dict = roles
Esempio n. 23
0
	def inputVars(self, head):
		"""Return the terms this lookup needs already bound (its inputs) for ``head``.

		Head argument slots are the location followed by the fact's terms.
		Guard argument slots are the flattened ``term1`` then ``term2`` of each
		associated guard. Every slot marked ``INPUT`` contributes its term,
		with head slots first.
		"""
		atom = head.fact if head.is_atom else head.fact.facts[0]
		arg_terms = [atom.loc] + atom.fact.terms
		inputs = []
		for pos, slot in enumerate(self.pred_args):
			if isinstance(slot, str) and slot == INPUT:
				inputs.append(arg_terms[pos])

		inspector = Inspector()
		guard_terms = []
		for guard in self.assoc_guards:
			guard_terms += inspector.flatten_term(guard.term1) + inspector.flatten_term(guard.term2)
		for pos, slot in enumerate(self.guard_args):
			if isinstance(slot, str) and slot == INPUT:
				inputs.append(guard_terms[pos])

		return inputs
Esempio n. 24
0
	def check_int(self, ast_node):
		"""Report LHS comprehension shapes that are not supported.

		The following errors are declared:

		* more than one comprehension range (otherwise each range is checked
		  recursively);
		* facts that do not reduce to exactly one base fact (otherwise the
		  single fact is checked recursively);
		* with one fact and one range, a location variable that is also bound
		  by the range (a multi-location comprehension).
		"""
		ranges = ast_node.comp_ranges
		if len(ranges) <= 1:
			for rng in ranges:
				self.check_int(rng)
		else:
			err = self.declare_error("Unsupported LHS Pattern: Comprehension pattern with multiple comprehension range.")
			for rng in ranges:
				self.extend_error(err, rng)

		inspector = Inspector()
		base_facts = inspector.get_base_facts(ast_node.facts)
		if len(base_facts) == 1:
			self.check_int(ast_node.facts[0])
		else:
			err = self.declare_error("Unsupported LHS Pattern: Comprehension pattern with multiple fact patterns.")
			for fact in ast_node.facts:
				self.extend_error(err, fact)

		if len(ast_node.facts) != 1 or len(ranges) != 1:
			return
		loc = ast_node.facts[0].loc
		loc_name = loc.name
		binder_names = [v.name for v in inspector.free_vars(ranges[0].term_vars)]
		if loc_name in binder_names:
			err = self.declare_error("Unsupported LHS Pattern: Multi-location comprehension patterns.")
			self.extend_error(err, loc)
Esempio n. 25
0
	def __init__(self, fact_dir, loc_fact, mem_guard, var_ctxt):
		"""Build a membership-guard lookup (``mem`` / ``hash+mem``) for ``loc_fact``.

		Each predicate argument (location, then fact terms) and each free
		variable of ``mem_guard.term1`` is classified as:

		* ``INPUT`` if its variable index is in ``self.var_dependencies``
		  (variables of the fact or guard that are already in ``var_ctxt``);
		* the term itself if it is a variable shared by the fact and the guard;
		* ``OUTPUT`` otherwise. Each such slot counts towards ``degree_freedom``.

		The final guard slot, the right-hand side of ``in`` in ``guard_str``,
		is always ``INPUT``. The lookup is named ``hash+mem`` when some
		predicate argument is ``INPUT``, and ``mem`` otherwise.
		"""
		pred_idx,_ = fact_dir.getFactFromName( loc_fact.fact.name )
		inspect = Inspector()
		fact_vars = inspect.free_var_idxs(loc_fact.loc) | inspect.free_var_idxs(loc_fact.fact)
		mem_vars  = inspect.free_var_idxs(mem_guard.term1)
		common_vars = fact_vars & mem_vars
		self.degree_join = len(common_vars)
		degree_freedom = 0

		# Existing dependencies from variable context
		self.var_dependencies = (fact_vars | mem_vars) & var_ctxt

		has_hash_index = False
		ppred_args = []
		for pred_arg in [loc_fact.loc] + loc_fact.fact.terms:
			if pred_arg.rule_idx in self.var_dependencies:
				ppred_args.append( INPUT )
				has_hash_index = True
			elif pred_arg.rule_idx in common_vars:
				ppred_args.append( pred_arg )
			else:
				ppred_args.append( OUTPUT )
				degree_freedom += 1

		guard_args = []
		for mem_arg in inspect.free_vars(mem_guard.term1):
			# Classify mem_arg itself and record it in guard_args. Testing the
			# leftover pred_arg from the loop above, or appending to ppred_args
			# here, would misalign both argument lists.
			if mem_arg.rule_idx in self.var_dependencies:
				guard_args.append( INPUT )
			elif mem_arg.rule_idx in common_vars:
				guard_args.append( mem_arg )
			else:
				guard_args.append( OUTPUT )
				degree_freedom += 1
		# Final slot: the collection on the right-hand side of ``in``.
		guard_args.append( INPUT )
		placeholders = ','.join('%s' for _ in range(len(guard_args) - 1))
		guard_str = '(%s) in %s' % (placeholders, '%s')
		self.degree_freedom = degree_freedom
		lk_name = "hash+mem" if has_hash_index else "mem"
		self.initialize(MEM_LK, pred_idx, fact_dir, lk_name, pred_args=ppred_args, guard_args=guard_args,
		                guard_str=guard_str, assoc_guards=[mem_guard])
Esempio n. 26
0
 def check(self):
     """Run ``check_int`` on each ensemble declaration in ``self.decs``."""
     inspector = Inspector()
     ensem_decs = inspector.filter_decs(self.decs, ensem=True)
     for ensem_dec in ensem_decs:
         self.check_int(ensem_dec)
Esempio n. 27
0
	def __init__(self, decs):
		"""Initialise over ``decs``, then attach an ``Inspector`` as ``self.inspect``."""
		self.initialize(decs)
		self.inspect = Inspector()
Esempio n. 28
0
	def __init__(self, decs, source_text):
		"""Set up an inspector, the base checker state, a ``Solver`` and an empty goal map.

		``infer_goals`` starts as an empty dict.
		"""
		self.inspect = Inspector()
		self.initialize(decs, source_text)
		self.solver = Solver()
		self.infer_goals = dict()
Esempio n. 29
0
	def initialize(self, ast_node=None):
		"""Assign a fresh match-task id and, if given a node, record its free variables.

		:param ast_node: optional AST node. When provided, ``self.free_vars``
		    is set to its free variables. When omitted, ``free_vars`` is left
		    unset.
		"""
		self.id = get_next_matchtask_index()
		# Identity test: ``!= None`` would dispatch to any custom __ne__/__eq__
		# an AST node class defines.
		if ast_node is not None:
			inspect = Inspector()
			self.free_vars = inspect.free_vars( ast_node )
Esempio n. 30
0
	def __init__(self, head, head_idx):
		"""Capture a comprehension head and its first range.

		``term_vars`` holds the free variables of the first range's binders.
		``compre_dom`` is that range's domain term.
		"""
		inspector = Inspector()
		self.head = head
		self.head_idx = head_idx
		first_range = head.fact.comp_ranges[0]
		self.term_vars = inspector.free_vars(first_range.term_vars)
		self.compre_dom = first_range.term_range
Esempio n. 31
0
	def __init__(self, decs, source_text, builtin_preds=None):
		"""Initialise the checker over ``decs`` and start with an empty ``fact_dict``.

		:param builtin_preds: optional built-in predicates forwarded to
		    ``initialize``. A fresh empty list is used when omitted, so no list
		    is shared between instances. A mutable ``[]`` default is created
		    once and reused by every call.
		"""
		if builtin_preds is None:
			builtin_preds = []
		self.initialize(decs, source_text, builtin_preds=builtin_preds)
		self.inspect = Inspector()
		self.fact_dict = {}
Esempio n. 32
0
    def int_check(self, ast_node):
        """Derive per-predicate and per-rule properties of an ensemble and annotate the AST.

        Every rule in ``ast_node.decs`` is scanned to collect four sets of
        predicate names:

        * predicates in simplified heads (``slhs``). Such fact declarations get
          ``persistent = False``.
        * predicates produced in the body at a location other than the rule's
          single head location. Such declarations get ``local = False``, and
          the body facts themselves are flagged ``local = False``.
        * predicates matched by an LHS comprehension. Such declarations, and
          the RHS facts of those predicates, get ``monotone = False``.
        * predicates with a body priority. Such declarations get
          ``uses_priority = True``.

        Head facts also get ``unique_head`` (their predicate occurs exactly
        once in the rule head) and a per-predicate ``collision_idx``.
        ``self.rule_unique_heads`` and ``self.rule_priority_body`` are filled
        per rule name.
        """
        inspect = Inspector()

        decs = ast_node.decs

        simplified_pred_names = {}
        non_local_pred_names = {}
        lhs_compre_pred_names = {}
        prioritized_pred_names = {}

        def mark_non_local(fact):
            # Record the predicate of body fact ``fact`` as non-local and flag it.
            non_local_pred_names[fact.name] = ()
            fact.local = False

        def at_loc(loc, loc_var):
            # True only when ``loc`` is the variable named ``loc_var``. A
            # non-variable location is always treated as non-local.
            return isinstance(loc, ast.TermVar) and loc.name == loc_var

        for rule_dec in inspect.filter_decs(decs, rule=True):
            simp_heads = rule_dec.slhs
            prop_heads = rule_dec.plhs
            rule_body = rule_dec.rhs

            # Scan for simplified predicate names
            for fact in inspect.get_base_facts(simp_heads):
                simplified_pred_names[fact.name] = ()

            # Scan for non local predicate names; also annotates the non local
            # rule body facts.
            loc_vars = [t.name for t in inspect.free_vars(simp_heads + prop_heads, args=False)]
            if len(set(loc_vars)) > 1:
                # Heads span several locations: flag all body predicates as non local.
                for fact in inspect.get_base_facts(rule_body):
                    mark_non_local(fact)
            else:
                # With no head location variable at all, nothing in the body
                # can be shown to be local (None never matches a variable name).
                loc_var = loc_vars[0] if loc_vars else None
                _, lfs, lfcs, comps = inspect.partition_rule_heads(rule_body)
                for lf in lfs:
                    if not at_loc(lf.loc, loc_var):
                        mark_non_local(lf.fact)
                for lfc in lfcs:
                    if not at_loc(lfc.loc, loc_var):
                        for f in lfc.facts:
                            mark_non_local(f)
                for comp in comps:
                    # Assumes that comprehension fact patterns are solo
                    loc_fact = comp.facts[0]
                    if not at_loc(loc_fact.loc, loc_var):
                        mark_non_local(loc_fact.fact)
                        continue
                    # The location variable is itself a comprehension binder,
                    # so the generated facts may land on several locations.
                    binder_names = [tv.name for tv in inspect.free_vars(comp.comp_ranges[0].term_vars)]
                    if loc_var in binder_names:
                        mark_non_local(loc_fact.fact)

            # Scan for LHS comprehension predicate names
            _, _, _, comps = inspect.partition_rule_heads(simp_heads + prop_heads)
            for comp in comps:
                lhs_compre_pred_names[comp.facts[0].fact.name] = ()

            # Scan for non-unique rule heads
            rule_head_pred_names = {}
            for fact in inspect.get_base_facts(simp_heads + prop_heads):
                rule_head_pred_names.setdefault(fact.name, []).append(fact)

            unique_heads = []
            self.rule_unique_heads[rule_dec.name] = unique_heads
            for collision_idx, name in enumerate(rule_head_pred_names):
                facts = rule_head_pred_names[name]
                unique_head = len(facts) == 1
                for fact in facts:
                    fact.unique_head = unique_head
                    fact.collision_idx = collision_idx
                if unique_head:
                    unique_heads.append(name)

            # Scan for priorities
            rule_priorities = {}
            self.rule_priority_body[rule_dec.name] = rule_priorities
            bfs, lfs, lfcs, comps = inspect.partition_rule_heads(rule_body)
            prioritized = [bf.name for bf in bfs if bf.priority is not None]
            prioritized += [lf.fact.name for lf in lfs if lf.priority is not None]
            for grouped in (lfcs, comps):
                for group in grouped:
                    if group.priority is not None:
                        prioritized += [f.name for f in group.facts]
            for name in prioritized:
                prioritized_pred_names[name] = ()
                rule_priorities[name] = ()

        # Annotate fact declaration nodes with relevant information
        fact_decs = inspect.filter_decs(decs, fact=True)
        for fact_dec in fact_decs:
            fact_dec.persistent = fact_dec.name not in simplified_pred_names
            fact_dec.local = fact_dec.name not in non_local_pred_names
            fact_dec.monotone = fact_dec.name not in lhs_compre_pred_names
            fact_dec.uses_priority = fact_dec.name in prioritized_pred_names
        self.fact_decs = fact_decs

        # Annotate rule declaration nodes with relevant information
        rule_decs = inspect.filter_decs(decs, rule=True)
        for rule_dec in rule_decs:
            rule_dec.unique_head_names = self.rule_unique_heads[rule_dec.name]
            # list(): keys() is a live view on Python 3, not a list.
            rule_dec.rule_priority_body_names = list(
                self.rule_priority_body[rule_dec.name].keys())

        # Annotate RHS constraints with monotonicity information
        for rule_dec in rule_decs:
            for fact in inspect.get_base_facts(rule_dec.rhs):
                fact.monotone = fact.name not in lhs_compre_pred_names
Esempio n. 33
0
	def initialize(self, term, guard_type):
		"""Store the guard's type and term, and assign it the next guard id."""
		self.type, self.term = guard_type, term
		self.id = get_next_guard_index()
		self.inspect = Inspector()
Esempio n. 34
0
class LookupContext:
	"""Bound variables and pending guards tracked while ordering a rule's joins.

	``eq_ctxt`` and ``bs_eq_ctxt`` are sets of variable indices (from
	``free_var_idxs``). The second holds only bootstrapped bindings.
	Guards wait in per-kind lists until ``scheduleGuards`` finds them
	schedulable in the current variable context.
	"""

	def __init__(self, fact_dir):
		self.bs_eq_ctxt = emptyset()
		self.eq_ctxt  = emptyset()
		self.mem_grds = []
		self.ord_grds = []
		self.eq_grds  = []
		self.non_idx_grds = []
		self.inspect  = Inspector()
		self.fact_dir = fact_dir

	def __repr__(self):
		strs  = "======== Lookup Context ========\n"
		strs += "BootStrap Eqs: %s\n" % self.bs_eq_ctxt
		strs += "Eqs: %s\n" % self.eq_ctxt
		all_grds = self.mem_grds + self.ord_grds + self.eq_grds + self.non_idx_grds
		if len(all_grds) > 0:
			strs += "Guards: %s\n" % (','.join([g.__repr__() for g in all_grds]))
		strs += "================================"
		return strs

	def varCtxt(self):
		"""Return every variable index bound so far (regular plus bootstrapped)."""
		return self.eq_ctxt | self.bs_eq_ctxt

	def removeBootStrapped(self):
		"""Forget all bootstrapped bindings."""
		self.bs_eq_ctxt = emptyset()

	def _bind(self, var_idxs, boot_strap):
		# Add var_idxs to the bootstrap context or to the regular one.
		if boot_strap:
			self.bs_eq_ctxt = self.bs_eq_ctxt | var_idxs
		else:
			self.eq_ctxt = self.eq_ctxt | var_idxs

	def addFactHead(self, head, boot_strap=False):
		"""Bind the free variables of ``head``.

		For a comprehension head, the first fact pattern's variables are
		bound. Unless bootstrapping, the variables of every comprehension
		range domain are bound too.
		"""
		if head.is_atom:
			free_vars = self.inspect.free_var_idxs( head.fact )
		else:
			free_vars = self.inspect.free_var_idxs( head.fact.facts[0] )
			if not boot_strap:
				# Pass a list: map() is a lazy iterator on Python 3.
				free_vars |= self.inspect.free_var_idxs( [cr.term_range for cr in head.fact.comp_ranges] )
		self._bind(free_vars, boot_strap)

	def addVars(self, term_vars, boot_strap=False):
		"""Bind the free variables of ``term_vars``."""
		self._bind(self.inspect.free_var_idxs( term_vars ), boot_strap)

	def addGuard(self, guard):
		"""Queue ``guard`` by kind; non-indexable or unknown kinds go to ``non_idx_grds``."""
		if not guard.indexable():
			self.non_idx_grds.append( guard )
		elif guard.type == MEM_GRD:
			self.mem_grds.append( guard )
		elif guard.type == ORD_GRD:
			self.ord_grds.append( guard )
		elif guard.type == EQ_GRD:
			self.eq_grds.append( guard )
		else:
			self.non_idx_grds.append( guard )

	def _takeSchedulable(self, grds, sch_grds):
		# Move guards schedulable under varCtxt() into sch_grds; return the rest.
		remaining = []
		for grd in grds:
			if grd.scheduleAsGuard( self.varCtxt() ):
				sch_grds.append( grd )
			else:
				remaining.append( grd )
		return remaining

	def scheduleGuards(self):
		"""Remove and return every pending guard that is schedulable now.

		Guards are examined in the order eq, ord, mem, non-indexable. Those
		that cannot be scheduled yet stay pending.
		"""
		sch_grds = []
		self.eq_grds = self._takeSchedulable(self.eq_grds, sch_grds)
		self.ord_grds = self._takeSchedulable(self.ord_grds, sch_grds)
		self.mem_grds = self._takeSchedulable(self.mem_grds, sch_grds)
		self.non_idx_grds = self._takeSchedulable(self.non_idx_grds, sch_grds)
		return sch_grds

	def bestLookupOption(self, new_head_info):
		"""Return ``(head_idx, lookup)`` for the cheapest lookup among candidate heads.

		``new_head_info`` maps head indices to heads. Returns ``(-1, None)``
		when no candidate's cheapest lookup costs less than the
		``(10000,0,0)`` starting bound.
		"""
		curr_best_head_idx = -1
		curr_best_lookup   = None
		curr_best_cost     = (10000,0,0)
		for head_idx,new_head in new_head_info.items():
			# lookupOptions returns options sorted by cost, so [0] is cheapest.
			cheapest = self.lookupOptions(new_head)[0]
			cost = cheapest.cost()
			if cost < curr_best_cost:
				curr_best_head_idx = head_idx
				curr_best_lookup = cheapest
				curr_best_cost   = cost

		return (curr_best_head_idx, curr_best_lookup)

	def remove_guards(self, rm_grds):
		"""Drop every guard in ``rm_grds`` from all pending lists (which stay lists)."""
		self.eq_grds  = [g for g in self.eq_grds if g not in rm_grds]
		self.mem_grds = [g for g in self.mem_grds if g not in rm_grds]
		self.ord_grds = [g for g in self.ord_grds if g not in rm_grds]
		self.non_idx_grds = [g for g in self.non_idx_grds if g not in rm_grds]

	# Current implementation ignores Eq guards.
	def lookupOptions(self, new_head):
		"""Return candidate lookups for ``new_head``, cheapest first.

		A linear lookup is always included. A hash lookup is added when the
		head shares variables with the current context.
		"""
		if new_head.is_atom:
			loc_fact = new_head.fact
			head_eq_grds  = []
			head_mem_grds = []
			head_ord_grds = []
		else:
			loc_fact = new_head.fact.facts[0]
			head_eq_grds  = new_head.eq_grds
			head_mem_grds = new_head.mem_grds
			head_ord_grds = new_head.ord_grds
		pred_idx,_ = self.fact_dir.getFactFromName( loc_fact.fact.name )

		lookup_opts = [ LinearLookup(self.fact_dir, loc_fact.fact.name) ]
		free_vars = self.inspect.free_var_idxs( loc_fact )
		join_vars = self.varCtxt() & free_vars
		new_vars  = free_vars - self.varCtxt()

		# Add a hash lookup if the new head has overlapping variables with the current
		# variable context.
		if len(join_vars) > 0:
			hash_lookup = HashLookup(self.fact_dir, loc_fact, join_vars)
			lookup_opts.append( hash_lookup )

		# TODO: Mem guard lookup and ord guard lookup omitted for now.
		# The head_*_grds and new_vars locals above are only used by the
		# disabled code below.
		'''
		# Add member lookup if the head has overlapping variables with a member guard
		for mem_guard in self.mem_grds + head_mem_grds:
			index_info = mem_guard.scheduleAsIndex( self.varCtxt() )
			if index_info != None:
				input_vars,output_vars,_ = index_info
				if len(new_vars & output_vars) > 0:
					mem_lookup = MemLookup(self.fact_dir, loc_fact, mem_guard, self.varCtxt())
					lookup_opts.append( mem_lookup )

		# Add order lookup
		for ord_guard in self.ord_grds + head_ord_grds:
			index_info = ord_guard.scheduleAsIndex( self.varCtxt() )
			if index_info != None:
				input_vars,output_vars,is_left_input = index_info
				if len(new_vars & output_vars) > 0:
					ord_lookup = OrdLookup(self.fact_dir, loc_fact, ord_guard, is_left_input, self.varCtxt())
					lookup_opts.append( ord_lookup )
		'''

		return sorted(lookup_opts, key=lambda lk: lk.cost())
Esempio n. 35
0
	def getRuleNames(self):
		"""Return the names of the rule declarations in ``self.ensem_dec``, in order.

		Always returns a list. On Python 3, ``map()`` would return a lazy,
		single-use iterator.
		"""
		inspect = Inspector()
		rules = inspect.filter_decs( self.ensem_dec.decs, rule=True )
		return [r.name for r in rules]
Esempio n. 36
0
class VarScopeChecker(Checker):
    """Scope checker for variables, predicates, constant names and ensembles.

    ``check_scope`` is a ``visit`` multimethod dispatched on the type of
    ``ast_node``. Scope contexts come from ``new_ctxt`` / ``copy_ctxt``, and
    the handlers below read and extend their 'vars', 'preds' and 'cons'
    entries. Out-of-scope references accumulate in ``self.curr_out_scopes``
    and duplicate declarations in ``self.curr_duplicates``. They stay there
    until a ``compose_*`` method reports them through ``declare_error`` /
    ``extend_error`` / ``extend_info``.

    Side effects are performed with explicit ``for`` loops rather than
    ``map(lambda ...)``. Under Python 3 ``map`` is lazy, so those calls would
    silently never run.
    """

    def __init__(self, decs, source_text):
        self.inspect = Inspector()
        self.initialize(decs, source_text)
        # Out-of-scope references collected since the last report.
        self.curr_out_scopes = new_ctxt()
        # Variable name -> declarations, for duplicates found in exec blocks.
        self.curr_duplicates = {'vars': {}}
        # Ensemble name -> scope context recorded for that ensemble.
        self.ensems = {}

    # Main checking operation

    def check(self):
        """Scope-check all extern, ensemble and exec declarations.

        Ensembles are checked before exec blocks because an exec block looks
        up its ensemble's context in self.ensems.
        """
        inspect = self.inspect
        ctxt = new_ctxt()
        for extern in inspect.filter_decs(self.decs, extern=True):
            self.check_scope(extern, ctxt)
        for ensem_dec in inspect.filter_decs(self.decs, ensem=True):
            self.check_scope(ensem_dec, ctxt)
        for exec_dec in inspect.filter_decs(self.decs, execute=True):
            self.check_scope(exec_dec, ctxt)

    @visit.on('ast_node')
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Multimethod stub; implementations are the @visit.when overloads below.

        lhs is True when the node is checked as part of a rule's left-hand side.
        """
        pass

    @visit.when(ast.EnsemDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Scope-check an ensemble and record its context in self.ensems.

        Externs are checked first. Predicate declarations come next: each
        one's context is merged in with extend_ctxt, and predicates declared
        more than once are reported. Rules are checked last. The recorded
        context has its 'vars' reset to []. Returns the caller's ctxt object.
        """
        old_ctxt = ctxt
        ctxt = copy_ctxt(ctxt)
        decs = ast_node.decs
        inspect = self.inspect

        for extern in inspect.filter_decs(decs, extern=True):
            self.check_scope(extern, ctxt)

        # Predicate name -> every declaration of that name.
        dec_preds = {}
        for dec in inspect.filter_decs(decs, fact=True):
            local_ctxt = self.check_scope(dec, ctxt)
            for pred in local_ctxt['preds']:
                dec_preds.setdefault(pred.name, []).append(pred)
            extend_ctxt(ctxt, local_ctxt)
        self.compose_duplicate_error_reports("predicate", dec_preds)

        for dec in inspect.filter_decs(decs, rule=True):
            self.check_scope(dec, ctxt)

        ctxt['vars'] = []
        self.ensems[ast_node.name] = ctxt
        return old_ctxt

    @visit.when(ast.ExecDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Scope-check an exec block against its ensemble's recorded context.

        Reports an error if the named ensemble was never declared.
        """
        if ast_node.name not in self.ensems:
            self.curr_out_scopes['ensem'].append(ast_node)
            self.compose_out_scope_error_report(ctxt)
            return
        # NOTE(review): this is the context stored in self.ensems, not a copy.
        # Variables declared here (e.g. by ExistDec) therefore remain visible
        # to later exec blocks of the same ensemble. Confirm this is intended.
        ctxt = self.ensems[ast_node.name]
        for dec in ast_node.decs:
            self.check_scope(dec, ctxt)
        self.compose_duplicate_error_reports("variables",
                                             self.curr_duplicates['vars'])
        self.curr_duplicates['vars'] = {}
        self.compose_out_scope_error_report(ctxt)

    @visit.when(ast.ExistDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Bring exist variables into scope; record any that are already bound."""
        for tvar in ast_node.exist_vars:
            if not lookup_var(ctxt, tvar):
                ctxt['vars'].append(tvar)
            else:
                self.record_duplicates(tvar, ctxt)

    @visit.when(ast.LocFactDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for loc_fact in ast_node.loc_facts:
            self.check_scope(loc_fact, ctxt)

    @visit.when(ast.ExternDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for type_sig in ast_node.type_sigs:
            self.check_scope(type_sig, ctxt)

    @visit.when(ast.ExternTypeSig)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Declare an extern constant name in ctxt."""
        ctxt['cons'].append(ast_node)

    @visit.when(ast.FactDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Return a fresh context declaring only this predicate."""
        local_ctxt = new_ctxt()
        local_ctxt['preds'].append(ast_node)
        # TODO: check types
        return local_ctxt

    @visit.when(ast.RuleDec)
    def check_scope(self, ast_node, ctxt, lhs=False):
        """Scope-check one rule in a private copy of ctxt.

        Scope is built up in stages:
        1. Non-comprehension head variables.
        2. Comprehension ranges, added after heads and guards are checked.
        3. Exist variables.
        4. Each where-assignment's pattern variables, added after its
           right-hand side is checked.
        Exist and where variables whose name is already in scope, or repeats
        among themselves, are reported as duplicates. The rule body is
        checked last.
        """
        ctxt = copy_ctxt(ctxt)
        heads = ast_node.slhs + ast_node.plhs
        inspect = self.inspect

        ctxt['vars'] += self.get_rule_scope(heads, compre=False)

        # Check scope of rule heads (consistency of constant names and scoping
        # of comprehension patterns) and of the guards.
        for head in heads:
            self.check_scope(head, ctxt, lhs=True)
        for grd in ast_node.grd:
            self.check_scope(grd, ctxt, lhs=True)

        ctxt['vars'] += self.get_rule_scope(heads, atoms=False)

        # Name -> declarations. Variables already in scope seed one entry per
        # name, so repeated head occurrences are not duplicates. Exist/where
        # variables are appended, so a clash shows up as more than one entry.
        dup_vars = {}
        for v in ctxt['vars']:
            dup_vars[v.name] = [v]
        for ex_var in ast_node.exists:
            dup_vars.setdefault(ex_var.name, []).append(ex_var)
        ctxt['vars'] += ast_node.exists

        # Incrementally include where-assignment variables.
        for ass_stmt in ast_node.where:
            self.check_scope(ass_stmt.builtin_exp, ctxt)
            self.compose_out_scope_error_report(ctxt)
            # Materialised: it is iterated and then added to scope.
            a_vars = list(inspect.filter_atoms(
                inspect.get_atoms(ass_stmt.term_pat), var=True))
            for a_var in a_vars:
                dup_vars.setdefault(a_var.name, []).append(a_var)
            ctxt['vars'] += a_vars

        self.compose_duplicate_error_reports("variables", dup_vars)

        for body in ast_node.rhs:
            self.check_scope(body, ctxt)
        self.compose_out_scope_error_report(ctxt)

    # NOTE: a check_scope handler for ast.SetComprehension existed here and is
    # currently disabled.

    @visit.when(ast.FactBase)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt = copy_ctxt(ctxt)
        self.check_pred(ctxt, ast_node)
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.FactLoc)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt = copy_ctxt(ctxt)
        self.check_scope(ast_node.loc, ctxt)
        self.check_scope(ast_node.fact, ctxt)
        return ctxt

    @visit.when(ast.FactLocCluster)
    def check_scope(self, ast_node, ctxt, lhs=False):
        ctxt = copy_ctxt(ctxt)
        self.check_scope(ast_node.loc, ctxt)
        for fact in ast_node.facts:
            self.check_scope(fact, ctxt)
        return ctxt

    @visit.when(ast.FactCompre)
    def check_scope(self, ast_node, old_ctxt, lhs=False):
        """Scope-check a fact comprehension.

        On the right-hand side (lhs False), the comprehension ranges are
        checked against the enclosing scope. The binder variables (term_vars)
        are then in scope for the comprehension's facts and guards. On the
        left-hand side (lhs True), the ranges are not checked; instead their
        free variables are added to the caller's old_ctxt.
        """
        ctxt = copy_ctxt(old_ctxt)
        comp_ranges = ast_node.comp_ranges

        if not lhs:
            for comp_range in comp_ranges:
                self.check_scope(comp_range, ctxt)
        self.compose_out_scope_error_report(ctxt)

        # A real list, so free_vars dispatches on `list` (a Python 3 map
        # object would not).
        ctxt['vars'] += self.inspect.free_vars(
            [cr.term_vars for cr in comp_ranges])

        for fact in ast_node.facts:
            self.check_scope(fact, ctxt)
        for guard in ast_node.guards:
            self.check_scope(guard, ctxt)
        self.compose_out_scope_error_report(ctxt)

        if lhs:
            old_ctxt['vars'] += self.inspect.free_vars(
                [cr.term_range for cr in comp_ranges])

    @visit.when(ast.CompRange)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term_range, ctxt)

    @visit.when(ast.TermCons)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_cons(ctxt, ast_node)
        return ctxt

    @visit.when(ast.TermVar)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_var(ctxt, ast_node)
        return ctxt

    @visit.when(ast.TermApp)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term1, ctxt)
        self.check_scope(ast_node.term2, ctxt)
        return ctxt

    @visit.when(ast.TermTuple)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.TermList)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.TermListCons)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term1, ctxt)
        self.check_scope(ast_node.term2, ctxt)
        return ctxt

    @visit.when(ast.TermMSet)
    def check_scope(self, ast_node, ctxt, lhs=False):
        for term in ast_node.terms:
            self.check_scope(term, ctxt)
        return ctxt

    @visit.when(ast.TermBinOp)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term1, ctxt)
        self.check_scope(ast_node.term2, ctxt)
        return ctxt

    @visit.when(ast.TermUnaryOp)
    def check_scope(self, ast_node, ctxt, lhs=False):
        self.check_scope(ast_node.term, ctxt)
        return ctxt

    @visit.when(ast.TermLit)
    def check_scope(self, ast_node, ctxt, lhs=False):
        return ctxt

    @visit.when(ast.TermUnderscore)
    def check_scope(self, ast_node, ctxt, lhs=False):
        return ctxt

    # Error state operations

    def flush_error_ctxt(self):
        """Discard all collected out-of-scope and duplicate records."""
        self.curr_out_scopes = new_ctxt()
        self.curr_duplicates = {'vars': {}}

    def check_var(self, ctxt, var):
        """Return True if var is in scope; otherwise record it and return False."""
        if lookup_var(ctxt, var):
            return True
        self.curr_out_scopes['vars'].append(var)
        return False

    def check_pred(self, ctxt, pred):
        """Return True if pred is in scope; otherwise record it and return False."""
        if lookup_pred(ctxt, pred):
            return True
        self.curr_out_scopes['preds'].append(pred)
        return False

    def check_cons(self, ctxt, cons):
        """Return True if cons is in scope; otherwise record it and return False."""
        if lookup_cons(ctxt, cons):
            return True
        self.curr_out_scopes['cons'].append(cons)
        return False

    # Get rule scope

    @visit.on('ast_node')
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        """Multimethod stub: variables a rule head contributes to rule scope.

        atoms selects plain (non-comprehension) heads; compre selects
        comprehension heads.
        """
        pass

    @visit.when(list)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        this_free_vars = []
        for obj in ast_node:
            this_free_vars += self.get_rule_scope(obj,
                                                  atoms=atoms,
                                                  compre=compre)
        return this_free_vars

    @visit.when(ast.FactBase)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        return self._head_scope(ast_node, atoms)

    @visit.when(ast.FactLoc)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        return self._head_scope(ast_node, atoms)

    @visit.when(ast.FactLocCluster)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        return self._head_scope(ast_node, atoms)

    @visit.when(ast.FactCompre)
    def get_rule_scope(self, ast_node, atoms=True, compre=True):
        # Only a comprehension with exactly one range contributes its range term.
        if not compre:
            return []
        comp_ranges = ast_node.comp_ranges
        if len(comp_ranges) == 1:
            return [comp_ranges[0].term_range]
        return []

    def _head_scope(self, ast_node, atoms):
        """Free variables of a non-comprehension head, or [] if atoms is False."""
        if atoms:
            return self.inspect.free_vars(ast_node)
        return []

    # Reporting

    def record_duplicates(self, tvar, ctxt):
        """Record tvar as a duplicate of the same-named variables already in ctxt."""
        dups = self.curr_duplicates['vars']
        if tvar.name not in dups:
            dups[tvar.name] = [t for t in ctxt['vars'] if t.name == tvar.name]
        dups[tvar.name].append(tvar)

    def _report_out_of_scope(self, ctxt, key, noun):
        """Emit one error for the out-of-scope entries of kind key, if any.

        noun ('variable', 'predicate', 'name') is used in the message and
        legend. ctxt[key] supplies the in-scope entries shown as info.
        """
        missing = self.curr_out_scopes[key]
        if not missing:
            return
        legend = ("%s %s: Scope context %s(s).\n" %
                  (terminal.T_GREEN_BACK, terminal.T_NORM, noun)) + (
                      "%s %s: Out of scope %s(s)." %
                      (terminal.T_RED_BACK, terminal.T_NORM, noun))
        # Sorted so the message does not depend on set iteration order.
        names = ','.join(sorted(set(t.name for t in missing)))
        error_idx = self.declare_error(
            "%s(s) %s not in scope." % (noun.capitalize(), names), legend)
        for t in missing:
            self.extend_error(error_idx, t)
        for t in ctxt[key]:
            self.extend_info(error_idx, t)

    def compose_out_scope_error_report(self, ctxt):
        """Report, then clear, all collected out-of-scope references.

        Emits one error each for variables, predicates and names, with ctxt's
        in-scope entries attached as info. Emits one error per unknown
        ensemble referenced by an exec block.
        """
        self._report_out_of_scope(ctxt, 'vars', 'variable')
        self._report_out_of_scope(ctxt, 'preds', 'predicate')
        self._report_out_of_scope(ctxt, 'cons', 'name')
        for exec_node in self.curr_out_scopes['ensem']:
            error_idx = self.declare_error("Ensemble %s not in scope." %
                                           exec_node.name)
            self.extend_error(error_idx, exec_node)
        self.curr_out_scopes = new_ctxt()

    def compose_duplicate_error_reports(self, kind, dups):
        """Emit one error for every name in dups declared more than once."""
        for name, elems in dups.items():
            if len(elems) > 1:
                error_idx = self.declare_error(
                    "Duplicated declaration of %s %s." % (kind, name))
                for elem in elems:
                    self.extend_error(error_idx, elem)
Esempio n. 37
0
	def __init__(self, decs, source_text):
		"""Attach an Inspector, then delegate to self.initialize(decs, source_text)."""
		inspector = Inspector()
		self.inspect = inspector
		self.initialize(decs, source_text)
Esempio n. 38
0
 def check_int(self, ast_node):
     """Run check_int on every rule declaration found in ast_node.decs."""
     rule_decs = Inspector().filter_decs(ast_node.decs, rule=True)
     for rule_dec in rule_decs:
         self.check_int(rule_dec)
Esempio n. 39
0
 def check_int(self, ast_node):
     """Check every LHS head of a rule with the heads' free variables in view.

     self.rule_free_vars holds the free variables of all propagated and
     simplified heads while each head is checked, and is reset to [] after.
     """
     heads = ast_node.plhs + ast_node.slhs
     self.rule_free_vars = Inspector().free_vars(heads)
     for head in heads:
         self.check_int(head)
     self.rule_free_vars = []
Esempio n. 40
0
 def __init__(self, decs, source_text):
     """Set up the inspector, base initialization, inference-goal map and solver."""
     self.inspect = Inspector()
     self.initialize(decs, source_text)
     self.infer_goals = {}
     self.solver = Solver()
Esempio n. 41
0
class NeighborRestrictChecker(Checker):
	"""Classify rules as node-centric or system-centric, and check that every
	system-centric rule is neighbor-restricted.

	A rule's LHS heads are grouped by location variable into match
	obligations. If they occupy a single location the rule is node-centric.
	Otherwise it is system-centric, and one neighbor option is built for each
	candidate primary location that satisfies two conditions:
	- the primary location's facts mention every other location;
	- every guard can be attributed to the primary location or a neighbor.
	A system-centric rule with no such option is reported as an error.
	"""

	def __init__(self, decs, source_text, builtin_preds=None):
		"""builtin_preds defaults to None, meaning [].

		This avoids a shared mutable default list; passing a list behaves as
		before.
		"""
		if builtin_preds is None:
			builtin_preds = []
		self.initialize(decs, source_text, builtin_preds=builtin_preds)
		self.inspect = Inspector()
		# Fact name -> fact declaration, for the ensemble currently being checked.
		self.fact_dict = {}

	def check(self):
		"""Run checkEnsem on every ensemble declaration."""
		for ensem_dec in self.inspect.filter_decs(self.decs, ensem=True):
			self.checkEnsem(ensem_dec)

	def checkEnsem(self, ensem_dec):
		"""Check every rule of ensem_dec and record ensemble-wide results.

		Sets ensem_dec.max_nbr_level to the largest rule nbr_level (-1 if there
		are no rules). Sets ensem_dec.requires_sync to True if any rule
		requires sync.
		"""
		self.fact_dict = {}
		for fact_dec in self.inspect.filter_decs(ensem_dec.decs, fact=True):
			self.fact_dict[fact_dec.name] = fact_dec

		max_nbr_level = -1
		some_requires_sync = False
		for rule_dec in self.inspect.filter_decs(ensem_dec.decs, rule=True):
			nbr_level, requires_sync = self.checkRule(rule_dec)
			max_nbr_level = max(max_nbr_level, nbr_level)
			if requires_sync:
				some_requires_sync = True

		ensem_dec.max_nbr_level = max_nbr_level
		ensem_dec.requires_sync = some_requires_sync

	def checkRule(self, rule_dec):
		"""Classify rule_dec, annotate it, and return (nbr_level, requires_sync).

		Node-centric rules get primary_loc and return (0, False).
		System-centric rules get match_obligations, nbr_options (options whose
		primary location holds a trigger fact come first) and nbr_level.
		"""
		match_obligations = {}
		for fact in rule_dec.slhs:
			self.checkFact(fact, match_obligations, True)
		for fact in rule_dec.plhs:
			self.checkFact(fact, match_obligations, False)

		if len(match_obligations) <= 1:
			# All heads at one location: node-centric rule.
			rule_dec.primary_loc = list(match_obligations)[0]
			rule_dec.is_system_centric = False
			rule_dec.requires_sync = False
			return (0, False)

		rule_dec.is_system_centric = True
		rule_dec.match_obligations = match_obligations
		# TODO: Check neighbor relation and determine viable primary location

		nbr_options = []
		for primary_loc in match_obligations:
			nbr_option = self._nbrOption(rule_dec, match_obligations, primary_loc)
			if nbr_option is None:
				continue
			if nbr_option['primary_has_trigger']:
				nbr_options.insert(0, nbr_option)
			else:
				nbr_options.append(nbr_option)

		if not nbr_options:
			error_idx = self.declare_error( "System-centric rule is not neighbor-restricted.")
			self.extend_error( error_idx, rule_dec )

		rule_dec.nbr_options = nbr_options
		# Every option has one primary location and all others as neighbors.
		rule_dec.nbr_level = len(match_obligations) - 1

		# Currently always requires sync
		# TODO: Check LHS patterns and programmer pragmas
		rule_dec.requires_sync = True

		return (rule_dec.nbr_level, rule_dec.requires_sync)

	def _nbrOption(self, rule_dec, match_obligations, primary_loc):
		"""Build the neighbor option rooted at primary_loc.

		Returns None when the option is not viable: either the primary facts
		do not mention every other location, or some guard cannot be attributed
		to the primary location or a single neighbor.
		"""
		my_facts = retrieveAll( match_obligations[primary_loc] )
		other_locs = [loc for loc in match_obligations if primary_loc != loc]

		my_args = []
		for fact in my_facts:
			my_args += [v.name for v in self.inspect.free_vars( fact, loc=False, args=True )]

		if not subseteq(other_locs, my_args):
			return None

		other_fact_dict = {}
		for loc in other_locs:
			other_fact_dict[loc] = match_obligations[loc]

		roles = [self.getFactRole( fact ) for fact in my_facts]
		has_trigger = ast.TRIGGER_FACT in roles

		primary_grds = []
		other_grds = {}
		for loc in other_fact_dict:
			other_grds[loc] = []
		non_iso_grds = []
		for grd in rule_dec.grd:
			grd_vars = self.inspect.free_vars( grd )
			if subseteq(grd_vars, my_args):
				primary_grds.append( grd )
				continue
			grd_added = False
			for loc in other_fact_dict:
				# Use the neighbor's facts, extracted with retrieveAll exactly as
				# for the primary location, not the raw obligation record.
				other_facts = retrieveAll( other_fact_dict[loc] )
				other_args = [v.name for v in self.inspect.free_vars( other_facts, loc=False, args=True )]
				if subseteq(grd_vars, my_args + other_args):
					other_grds[loc].append( grd )
					grd_added = True
			if not grd_added:
				non_iso_grds.append( grd )

		if non_iso_grds:
			return None

		return { 'primary_loc'         : primary_loc
		       , 'primary_obligation'  : match_obligations[primary_loc]
		       , 'primary_guards'      : primary_grds
		       , 'other_obligations'   : other_fact_dict
		       , 'other_guards'        : other_grds
		       , 'non_iso_guards'      : non_iso_grds
		       , 'primary_has_trigger' : has_trigger }

	@visit.on( 'fact' )
	def checkFact(self, fact, match_obligations, is_simp):
		"""Multimethod stub: file a head fact under its location variable."""
		pass

	@visit.when( ast.FactLoc )
	def checkFact(self, fact, match_obligations, is_simp):
		# The first free variable of the location term is the obligation key.
		vs = self.inspect.free_vars( fact.loc )
		role = self.getFactRole( fact )
		extend(match_obligations, vs[0], fact, is_simp, role)

	@visit.when( ast.FactCompre )
	def checkFact(self, fact, match_obligations, is_simp):
		# A comprehension is filed under the location of its first fact pattern.
		vs = self.inspect.free_vars( fact.facts[0].loc )
		role = self.getFactRole( fact )
		extend(match_obligations, vs[0], fact, is_simp, role)

	@visit.on( 'fact' )
	def getFactRole(self, fact):
		"""Multimethod stub: the declared fact_role of a head's predicate."""
		pass

	@visit.when( ast.FactLoc )
	def getFactRole(self, fact):
		return self.fact_dict[fact.fact.name].fact_role

	@visit.when( ast.FactCompre )
	def getFactRole(self, fact):
		return self.fact_dict[fact.facts[0].fact.name].fact_role
Esempio n. 42
0
class TypeChecker(Checker):
    def __init__(self, decs, source_text):
        """Set up the inspector, base initialization, inference-goal map and solver."""
        self.inspect = Inspector()
        self.initialize(decs, source_text)
        self.infer_goals = {}
        self.solver = Solver()

    def check(self):
        """Type-check all ensembles, then all exec blocks, in one shared context."""
        ctxt = new_ctxt()
        filter_decs = self.inspect.filter_decs
        for kind in ('ensem', 'execute'):
            for dec in filter_decs(self.decs, **{kind: True}):
                self.int_check_dec(dec, ctxt)

    # Check Declarations

    @visit.on('ast_node')
    def int_check_dec(self, ast_node, ctxt):
        """Multimethod stub for type-checking declarations.

        Implementations are registered below with @visit.when(<node type>) and
        selected by the type of ast_node. Each returns (success, constraints).
        Behavior for an unregistered node type is determined by the visit
        module (not shown here).
        """
        pass

    @visit.when(ast.EnsemDec)
    def int_check_dec(self, ast_node, ctxt):
        """Type-check an ensemble: externs, then predicate declarations, then rules.

        Each rule's constraints are tested for satisfiability together with the
        extern and predicate constraints. An unsatisfiable rule marks the
        ensemble as failed, and its constraints are left out of the result.
        Stores {'succ', 'cons'} for the ensemble in ctxt['ensem']; the stored
        'cons' excludes rule constraints.
        """
        filter_decs = self.inspect.filter_decs
        ok = True

        extern_cons = []
        for dec in filter_decs(ast_node.decs, extern=True):
            dec_ok, dec_cons = self.int_check_dec(dec, ctxt)
            ok = ok and dec_ok
            extern_cons.extend(dec_cons)

        pred_cons = []
        for dec in filter_decs(ast_node.decs, fact=True):
            dec_ok, dec_cons = self.int_check_dec(dec, ctxt)
            ok = ok and dec_ok
            pred_cons.extend(dec_cons)

        base_cons = extern_cons + pred_cons
        rules_sat = True
        rule_cons = []
        for dec in filter_decs(ast_node.decs, rule=True):
            dec_ok, dec_cons = self.int_check_dec(dec, ctxt)
            ok = ok and dec_ok
            if not self.check_type_sat(dec_cons + base_cons):
                rules_sat = False
                continue
            rule_cons.extend(dec_cons)

        succ = ok and rules_sat
        ctxt['ensem'][ast_node.name] = {'succ': succ, 'cons': base_cons}
        return (succ, base_cons + rule_cons)

    @visit.when(ast.ExternDec)
    def int_check_dec(self, ast_node, ctxt):
        """Type extern signatures, binding each name in ctxt['cons'] to a fresh type var."""
        ok = True
        cons = []
        for sig in ast_node.type_sigs:
            sig_ok, sig_ty, sig_cons = self.int_check_type(sig.type_sig, ctxt)
            ok = ok and sig_ok
            name_ty = tyVar()
            cons.append(name_ty | Eq | sig_ty | just | [sig])
            cons.extend(sig_cons)
            ctxt['cons'][sig.name] = name_ty
        return (ok, cons)

    @visit.when(ast.FactDec)
    def int_check_dec(self, ast_node, ctxt):
        """Type a predicate declaration.

        Binds a fresh type variable for the predicate in ctxt['pred']. The
        variable is constrained to unit when no argument type is declared,
        and to the declared argument type otherwise.
        """
        # Copy taken before the predicate is bound; the declared type is
        # checked in it (NOTE(review): presumably to keep anything the type
        # check adds out of the shared ctxt -- confirm).
        local_ctxt = copy_ctxt(ctxt)
        tvar1 = tyVar()
        ctxt['pred'][ast_node.name] = tvar1
        if ast_node.type is None:
            return (True, [tvar1 | Eq | tyUnit | just | [ast_node]])
        (s, tvar2, cons) = self.int_check_type(ast_node.type, local_ctxt)
        return (s, [tvar1 | Eq | tvar2 | just | [ast_node]] + cons)

    @visit.when(ast.RuleDec)
    def int_check_dec(self, ast_node, ctxt):
        """Type a rule in a private copy of ctxt.

        Constraint groups, in order:
        - heads: simplified, then propagated;
        - guards: each must be bool;
        - exist variables: each must be a location or a destination;
        - where-assignments;
        - body facts.
        """
        scope = copy_ctxt(ctxt)
        ok = True

        head_cs = []
        for head in ast_node.slhs + ast_node.plhs:
            head_ok, cs = self.int_check_fact(head, scope)
            ok = ok and head_ok
            head_cs.extend(cs)

        guard_cs = []
        for guard in ast_node.grd:
            guard_ok, guard_ty, cs = self.int_check_term(guard, scope)
            ok = ok and guard_ok
            guard_cs.append(guard_ty | Eq | tyBool | just | [ast_node])
            guard_cs.extend(cs)

        exist_cs = []
        for exist in ast_node.exists:
            exist_ok, exist_ty, cs = self.int_check_term(exist, scope)
            ok = ok and exist_ok
            exist_cs.append((exist_ty | Eq | tyLoc | just | [exist]) | Or |
                            (exist_ty | Eq | tyDest | just | [exist]))
            exist_cs.extend(cs)

        where_cs = []
        for assign in ast_node.where:
            assign_ok, cs = self.int_check_assign(assign, scope)
            ok = ok and assign_ok
            where_cs.extend(cs)

        body_cs = []
        for body in ast_node.rhs:
            body_ok, cs = self.int_check_fact(body, scope)
            ok = ok and body_ok
            body_cs.extend(cs)

        return (ok, head_cs + guard_cs + exist_cs + where_cs + body_cs)

    def int_check_assign(self, assign, ctxt):
        """Type a where-assignment: the pattern and the expression must share a type."""
        pat_ok, pat_ty, pat_cs = self.int_check_term(assign.term_pat, ctxt)
        exp_ok, exp_ty, exp_cs = self.int_check_term(assign.builtin_exp, ctxt)
        all_ok = pat_ok and exp_ok
        link = pat_ty | Eq | exp_ty | just | [assign]
        return (all_ok, [link] + pat_cs + exp_cs)

    @visit.when(ast.ExecDec)
    def int_check_dec(self, ast_node, ctxt):
        """Type an exec block against its ensemble's recorded base constraints.

        The block's declarations are checked in a copy of ctxt. The result
        fails if any declaration fails, or if the combined constraints are
        unsatisfiable.
        """
        scope = copy_ctxt(ctxt)
        ensem_info = ctxt['ensem'][ast_node.name]
        ok = True
        cs = []
        for dec in ast_node.decs:
            dec_ok, dec_cs = self.int_check_dec(dec, scope)
            ok = ok and dec_ok
            cs.extend(dec_cs)
        if not self.check_type_sat(ensem_info['cons'] + cs):
            ok = False
        return (ok, ensem_info['cons'] + cs)

    @visit.when(ast.ExistDec)
    def int_check_dec(self, ast_node, ctxt):
        """Type exist declarations: each variable is a location or a destination."""
        ok = True
        cs = []
        for var in ast_node.exist_vars:
            var_ok, var_ty, var_cs = self.int_check_term(var, ctxt)
            ok = ok and var_ok
            cs.append((var_ty | Eq | tyLoc | just | [var]) | Or |
                      (var_ty | Eq | tyDest | just | [var]))
            cs.extend(var_cs)
        return (ok, cs)

    @visit.when(ast.LocFactDec)
    def int_check_dec(self, ast_node, ctxt):
        """Type each located fact of a fact declaration and collect the constraints."""
        ok = True
        cs = []
        for loc_fact in ast_node.loc_facts:
            fact_ok, fact_cs = self.int_check_fact(loc_fact, ctxt)
            ok = ok and fact_ok
            cs.extend(fact_cs)
        return (ok, cs)

    # Check Facts

    @visit.on('ast_node')
    def int_check_fact(self, ast_node, ctxt):
        """Multimethod stub for type-checking facts.

        Implementations are registered below with @visit.when(<node type>) and
        selected by the type of ast_node. Each returns (success, constraints).
        Behavior for an unregistered node type is determined by the visit
        module (not shown here).
        """
        pass

    @visit.when(ast.FactBase)
    def int_check_fact(self, ast_node, ctxt):
        """Type a predicate application against the predicate's declared type.

        A single argument is compared directly. Otherwise the arguments form a
        right-nested tuple type terminated by unit.
        """
        pred_ty = ctxt['pred'][ast_node.name]
        ok, arg_tys, arg_cs = self.int_check_term(ast_node.terms, ctxt)
        if len(arg_tys) == 1:
            args_ty = arg_tys[0]
        else:
            args_ty = tyUnit
            for arg_ty in reversed(arg_tys):
                args_ty = tyTuple(arg_ty, args_ty)
        return (ok, [pred_ty | Eq | args_ty | just | [ast_node]] + arg_cs)

    @visit.when(ast.FactLoc)
    def int_check_fact(self, ast_node, ctxt):
        """Type a located fact: the location term must have location type."""
        loc_ok, loc_ty, loc_cs = self.int_check_term(ast_node.loc, ctxt)
        fact_ok, fact_cs = self.int_check_fact(ast_node.fact, ctxt)
        all_ok = loc_ok and fact_ok
        loc_is_loc = loc_ty | Eq | tyLoc | just | [ast_node]
        return (all_ok, [loc_is_loc] + loc_cs + fact_cs)

    @visit.when(ast.FactLocCluster)
    def int_check_fact(self, ast_node, ctxt):
        """Type a cluster of facts sharing one location term (which must be a location)."""
        ok, loc_ty, loc_cs = self.int_check_term(ast_node.loc, ctxt)
        member_cs = []
        for member in ast_node.facts:
            member_ok, cs = self.int_check_fact(member, ctxt)
            ok = ok and member_ok
            member_cs.extend(cs)
        loc_is_loc = loc_ty | Eq | tyLoc | just | [ast_node]
        return (ok, [loc_is_loc] + loc_cs + member_cs)

    @visit.when(ast.FactCompre)
    def int_check_fact(self, ast_node, ctxt):
        """Type a fact comprehension.

        Ranges, then fact patterns, then guards are typed. Binders and the
        body are typed in a copy of ctxt; each guard must be bool.
        """
        inner = copy_ctxt(ctxt)
        ok = True
        cs = []
        for comp_range in ast_node.comp_ranges:
            range_ok, range_cs = self.int_check_comp_range(comp_range, ctxt, inner)
            ok = ok and range_ok
            cs.extend(range_cs)
        for fact in ast_node.facts:
            fact_ok, fact_cs = self.int_check_fact(fact, inner)
            ok = ok and fact_ok
            cs.extend(fact_cs)
        for guard in ast_node.guards:
            guard_ok, guard_ty, guard_cs = self.int_check_term(guard, inner)
            ok = ok and guard_ok
            cs.append(guard_ty | Eq | tyBool | just | [ast_node])
            cs.extend(guard_cs)
        return (ok, cs)

    def int_check_comp_range(self, comp_range, ctxt, local_ctxt):
        """Type one comprehension range: binders term_vars drawn from term_range.

        term_vars is typed in local_ctxt (the comprehension's own scope) and
        term_range in the enclosing ctxt. The range must be a multiset or a
        list of the binder type.
        """
        bind_ok, bind_ty, bind_cs = self.int_check_term(comp_range.term_vars, local_ctxt)
        range_ok, range_ty, range_cs = self.int_check_term(comp_range.term_range, ctxt)
        all_ok = bind_ok and range_ok
        range_shape = ((range_ty | Eq | tyMSet(bind_ty) | just | [comp_range]) | Or |
                       (range_ty | Eq | tyList(bind_ty) | just | [comp_range]))
        return (all_ok, [range_shape] + bind_cs + range_cs)

    # Check Terms

    @visit.on('ast_node')
    def int_check_term(self, ast_node, ctxt):
        """Multimethod stub for type-checking terms.

        Implementations are registered below with @visit.when(<node type>) and
        selected by the type of ast_node. Each returns (success, type,
        constraints); the list overload returns a list of types instead. Behavior
        for an unregistered node type is determined by the visit module (not
        shown here).
        """
        pass

    @visit.when(list)
    def int_check_term(self, ast_node, ctxt):
        """Type each term of a list, returning (success, [types], constraints)."""
        ok = True
        tys = []
        cs = []
        for node in ast_node:
            node_ok, node_ty, node_cs = self.int_check_term(node, ctxt)
            ok = ok and node_ok
            tys.append(node_ty)
            cs.extend(node_cs)
        return (ok, tys, cs)

    @visit.when(ast.TermVar)
    def int_check_term(self, ast_node, ctxt):
        """Type a variable occurrence.

        Each occurrence gets a fresh type tied to the variable's type in
        ctxt['vars']. That entry is created with a fresh type variable the
        first time the name is seen.
        """
        use_ty = tyVar()
        var_tys = ctxt['vars']
        if ast_node.name not in var_tys:
            var_tys[ast_node.name] = tyVar()
        bound_ty = var_tys[ast_node.name]
        self.mark_for_infer(ast_node, use_ty)
        return (True, use_ty, [use_ty | Eq | bound_ty | just | [ast_node]])

    @visit.when(ast.TermCons)
    def int_check_term(self, ast_node, ctxt):
        """Type a constant name: a fresh type tied to its declared type in ctxt['cons']."""
        use_ty = tyVar()
        decl_ty = ctxt['cons'][ast_node.name]
        self.mark_for_infer(ast_node, use_ty)
        return (True, use_ty, [use_ty | Eq | decl_ty | just | [ast_node]])

    @visit.when(ast.TermLit)
    def int_check_term(self, ast_node, ctxt):
        """Type a literal: a fresh type tied to the literal's own type annotation."""
        lit_ty = tyVar()
        ok, decl_ty, decl_cs = self.int_check_type(ast_node.type, ctxt)
        self.mark_for_infer(ast_node, lit_ty)
        return (ok, lit_ty, [lit_ty | Eq | decl_ty | just | [ast_node]] + decl_cs)

    @visit.when(ast.TermApp)
    def int_check_term(self, ast_node, ctxt):
        """Type an application term1 term2.

        term1 must be an arrow whose parameter type matches term2. The result
        is the arrow's return type.
        """
        fn_ok, fn_ty, fn_cs = self.int_check_term(ast_node.term1, ctxt)
        arg_ok, arg_ty, arg_cs = self.int_check_term(ast_node.term2, ctxt)
        param_ty = tyVar()
        result_ty = tyVar()
        app_cs = [
            fn_ty | Eq | tyArrow(param_ty, result_ty) | just | [ast_node],
            arg_ty | Eq | param_ty | just | [ast_node],
        ]
        self.mark_for_infer(ast_node, result_ty)
        return (fn_ok and arg_ok, result_ty, app_cs + fn_cs + arg_cs)

    @visit.when(ast.TermTuple)
    def int_check_term(self, ast_node, ctxt):
        """Type a tuple as a right-nested chain of tyTuple ending in unit."""
        ok, elem_tys, elem_cs = self.int_check_term(ast_node.terms, ctxt)
        tuple_ty = tyUnit
        for elem_ty in reversed(elem_tys):
            tuple_ty = tyTuple(elem_ty, tuple_ty)
        result_ty = tyVar()
        self.mark_for_infer(ast_node, result_ty)
        return (ok, result_ty, [result_ty | Eq | tuple_ty | just | [ast_node]] + elem_cs)

    @visit.when(ast.TermList)
    def int_check_term(self, ast_node, ctxt):
        """Type a list literal: its type is a list of every element's type."""
        ok, elem_tys, elem_cs = self.int_check_term(ast_node.terms, ctxt)
        list_ty = tyVar()
        member_cs = [list_ty | Eq | tyList(elem_ty) | just | [ast_node]
                     for elem_ty in elem_tys]
        self.mark_for_infer(ast_node, list_ty)
        return (ok, list_ty, member_cs + elem_cs)

    @visit.when(ast.TermListCons)
    def int_check_term(self, ast_node, ctxt):
        """Type a list cons term1 : term2.

        term2 must be a list of term1's type, and the result has term2's type.
        """
        (s0, t0, cs0) = self.int_check_term(ast_node.term1, ctxt)
        (s1, t1, cs1) = self.int_check_term(ast_node.term2, ctxt)
        t2 = tyVar()
        cs2 = [
            t1 | Eq | tyList(t0) | just | [ast_node],
            # Justified like every other constraint in this checker; it was
            # previously emitted without `| just | [ast_node]`.
            t2 | Eq | t1 | just | [ast_node],
        ]
        self.mark_for_infer(ast_node, t2)
        return (s0 and s1, t2, cs2 + cs0 + cs1)

    @visit.when(ast.TermMSet)
    def int_check_term(self, ast_node, ctxt):
        """Type a multiset literal: its type is a multiset of every element's type."""
        ok, elem_tys, elem_cs = self.int_check_term(ast_node.terms, ctxt)
        mset_ty = tyVar()
        member_cs = [mset_ty | Eq | tyMSet(elem_ty) | just | [ast_node]
                     for elem_ty in elem_tys]
        self.mark_for_infer(ast_node, mset_ty)
        return (ok, mset_ty, member_cs + elem_cs)

    @visit.when(ast.TermCompre)
    def int_check_term(self, ast_node, ctxt):
        """Type a term comprehension.

        The result is a multiset of the body term's type. Binders, body and
        guards are typed in a copy of ctxt; each guard must be bool.
        """
        inner = copy_ctxt(ctxt)
        ok = True
        cs = []
        for comp_range in ast_node.comp_ranges:
            range_ok, range_cs = self.int_check_comp_range(comp_range, ctxt, inner)
            ok = ok and range_ok
            cs.extend(range_cs)
        elem_ok, elem_ty, elem_cs = self.int_check_term(ast_node.term, inner)
        mset_ty = tyVar()
        ok = ok and elem_ok
        cs.append(mset_ty | Eq | tyMSet(elem_ty) | just | [ast_node])
        cs.extend(elem_cs)
        for guard in ast_node.guards:
            guard_ok, guard_ty, guard_cs = self.int_check_term(guard, inner)
            ok = ok and guard_ok
            cs.append(guard_ty | Eq | tyBool | just | [ast_node])
            cs.extend(guard_cs)
        self.mark_for_infer(ast_node, mset_ty)
        return (ok, mset_ty, cs)

    @visit.when(ast.TermBinOp)
    def int_check_term(self, ast_node, ctxt):
        """Infer the type of a binary operation ``term1 op term2``.

        * ``op`` in ``BOOL_OPS_1``: both operands share a type; the result
          is ``tyBool``.
        * ``op`` in ``BOOL_OPS_2``: ``term2`` is a multiset of ``term1``'s
          type; the result is ``tyBool``.
        * ``op`` in ``NUM_OPS``: both operands and the result share a type.

        Any other operator is reported as a type error and yields
        ``success=False``. It no longer falls through with the constraint
        list unbound.

        Returns ``(success, type, constraints)``.
        """
        (s0, t0, cs0) = self.int_check_term(ast_node.term1, ctxt)
        (s1, t1, cs1) = self.int_check_term(ast_node.term2, ctxt)
        t3 = tyVar()
        if ast_node.op in BOOL_OPS_1:
            cs3 = [
                t0 | Eq | t1 | just | [ast_node],
                t3 | Eq | tyBool | just | [ast_node]
            ]
        elif ast_node.op in BOOL_OPS_2:
            cs3 = [
                tyMSet(t0) | Eq | t1 | just | [ast_node],
                t3 | Eq | tyBool | just | [ast_node]
            ]
        elif ast_node.op in NUM_OPS:
            cs3 = [
                t0 | Eq | t1 | just | [ast_node],
                t0 | Eq | t3 | just | [ast_node]
            ]
        else:
            # Report unknown operators the same way unknown data types are
            # reported (see the TypeCons handler) instead of crashing.
            error_idx = self.declare_error("Unknown binary operator \'%s\'" %
                                           ast_node.op)
            self.extend_error(error_idx, ast_node)
            return (False, t3, cs0 + cs1)
        self.mark_for_infer(ast_node, t3)
        return (s0 and s1, t3, cs3 + cs0 + cs1)

    @visit.when(ast.TermUnaryOp)
    def int_check_term(self, ast_node, ctxt):
        """Infer the type of a unary operation: operand and result are ``tyBool``.

        Returns ``(success, type, constraints)``.
        """
        ok, operand_ty, operand_cons = self.int_check_term(ast_node.term, ctxt)
        result_ty = tyVar()
        self.mark_for_infer(ast_node, result_ty)
        unary_cons = [
            result_ty | Eq | tyBool | just | [ast_node],
            operand_ty | Eq | tyBool | just | [ast_node],
        ]
        return (ok, result_ty, unary_cons + operand_cons)

    @visit.when(ast.TermUnderscore)
    def int_check_term(self, ast_node, ctxt):
        """Give ``_`` an unconstrained fresh type variable (always succeeds)."""
        fresh_ty = tyVar()
        self.mark_for_infer(ast_node, fresh_ty)
        no_cons = []
        return (True, fresh_ty, no_cons)

    # Check Types

    @visit.on('ast_node')
    def int_check_type(self, ast_node, ctxt):
        # Dispatch root declared with ``@visit.on('ast_node')``. The concrete
        # handlers are registered below with ``@visit.when(<type node>)``.
        # Each handler returns ``(success, type_var, constraints)``.
        pass

    @visit.when(ast.TypeVar)
    def int_check_type(self, ast_node, ctxt):
        """Check a type variable occurrence.

        Every occurrence of the same name (within ``ctxt['type_vars']``)
        is tied to one shared type variable, which is created on first use.
        The occurrence itself gets a fresh variable equated to it.

        Returns ``(success, type, constraints)``.
        """
        occurrence = tyVar()
        known = ctxt['type_vars']
        if ast_node.name not in known:
            # Created only on first sight; setdefault would allocate a
            # fresh variable on every call.
            known[ast_node.name] = tyVar()
        shared = known[ast_node.name]
        return (True, occurrence, [occurrence | Eq | shared | just | [ast_node]])

    @visit.when(ast.TypeCons)
    def int_check_type(self, ast_node, ctxt):
        """Check a named data type against ``BASE_TYPES``.

        Known names are equated to their SMT base type. Unknown names are
        reported as an error, and the result is returned with no constraints.

        Returns ``(success, type, constraints)``.
        """
        result_ty = tyVar()
        if ast_node.name in BASE_TYPES:
            base = smt_base_type(ast_node.name)
            return (True, result_ty, [result_ty | Eq | base | just | [ast_node]])
        error_idx = self.declare_error("Unknown data type \'%s\'" %
                                       ast_node.name)
        self.extend_error(error_idx, ast_node)
        return (False, result_ty, [])

    @visit.when(ast.TypeApp)
    def int_check_type(self, ast_node, ctxt):
        """Check a type application ``type1 type2``.

        ``type1`` is constrained to an arrow whose domain equals ``type2``;
        the result is the arrow's codomain.

        Returns ``(success, type, constraints)``.
        """
        (s1, t1, cons1) = self.int_check_type(ast_node.type1, ctxt)
        (s2, t2, cons2) = self.int_check_type(ast_node.type2, ctxt)
        t3 = tyVar()
        t4 = tyVar()
        return (s1 and s2, t4, [
            t1 | Eq | tyArrow(t3, t4) | just | [ast_node],
            # Uses the ``Eq`` relation like every other constraint here
            # (this line previously read ``eq``).
            t2 | Eq | t3 | just | [ast_node]
        ] + cons1 + cons2)

    @visit.when(ast.TypeArrow)
    def int_check_type(self, ast_node, ctxt):
        """Check an arrow type ``type1 -> type2``.

        Returns ``(success, type, constraints)``.
        """
        dom_ok, dom_ty, dom_cons = self.int_check_type(ast_node.type1, ctxt)
        cod_ok, cod_ty, cod_cons = self.int_check_type(ast_node.type2, ctxt)
        arrow_ty = tyVar()
        arrow_cons = [arrow_ty | Eq | tyArrow(dom_ty, cod_ty) | just | [ast_node]]
        return (dom_ok and cod_ok, arrow_ty, arrow_cons + dom_cons + cod_cons)

    @visit.when(ast.TypeTuple)
    def int_check_type(self, ast_node, ctxt):
        """Check a tuple type.

        Components are checked right-to-left and nested with ``tyTuple``
        onto a ``tyUnit`` tail.

        Returns ``(success, type, constraints)``.
        """
        tuple_ty = tyUnit
        ok = True
        cons = []
        for component in reversed(ast_node.types):
            comp_ok, comp_ty, comp_cons = self.int_check_type(component, ctxt)
            ok = ok and comp_ok
            cons += comp_cons
            tuple_ty = tyTuple(comp_ty, tuple_ty)
        result_ty = tyVar()
        return (ok, result_ty,
                [result_ty | Eq | tuple_ty | just | [ast_node]] + cons)

    @visit.when(ast.TypeMSet)
    def int_check_type(self, ast_node, ctxt):
        """Check a multiset type: the result is ``tyMSet`` of the element type."""
        elem_ok, elem_ty, elem_cons = self.int_check_type(ast_node.type, ctxt)
        mset_ty = tyVar()
        mset_con = mset_ty | Eq | tyMSet(elem_ty) | just | [ast_node]
        return (elem_ok, mset_ty, [mset_con] + elem_cons)

    @visit.when(ast.TypeList)
    def int_check_type(self, ast_node, ctxt):
        """Check a list type: the result is ``tyList`` of the element type."""
        elem_ok, elem_ty, elem_cons = self.int_check_type(ast_node.type, ctxt)
        list_ty = tyVar()
        list_con = list_ty | Eq | tyList(elem_ty) | just | [ast_node]
        return (elem_ok, list_ty, [list_con] + elem_cons)

    # Check Type constraint satisfiability

    def check_type_sat(self, cons):
        """Check that the type constraints ``cons`` are satisfiable.

        If ``min_unsat_subset`` finds no unsatisfiable subset:
        - Solve ``cons`` with ``self.solver``.
        - For every solved variable registered through ``mark_for_infer``,
          set ``smt_type`` and ``type`` on the recorded AST node.
        - Print each assignment and drop it from ``self.infer_goals``.

        Otherwise, declare a type error whose sites are the justifications
        of the unsatisfiable subset.

        Returns True when satisfiable, False otherwise.
        """
        unsat_core = min_unsat_subset(self.solver, cons)
        if len(unsat_core) != 0:
            error_idx = self.declare_error("Type Error in the following sites")
            sites = foldl([con.get_just() for con in unsat_core], [])
            for site in sites:
                self.extend_error(error_idx, site)
            return False

        model = self.solver.solve(cons)
        for decl in model:
            key = str(decl)
            if key not in self.infer_goals:
                continue
            goal = self.infer_goals[key]
            smt_type = model[decl]
            goal.smt_type = smt_type
            goal.type = coerce_type(smt_type)
            print("%s -> %s::%s" % (goal,
                                    type_to_data_sort(model[decl]),
                                    goal.type))
            del self.infer_goals[key]
        return True

    # Type inference operations

    def mark_for_infer(self, ast_node, ty_var):
        """Remember that ``ast_node``'s type is the solution for ``ty_var``.

        Goals are keyed by the string form of the type variable, which is
        how ``check_type_sat`` matches them against model declarations.
        """
        goal_key = str(ty_var)
        self.infer_goals[goal_key] = ast_node
Esempio n. 43
0
 def __init__(self, decs):
     """Store ``decs`` via ``initialize`` and create the shared ``Inspector``."""
     # Kept in this order: ``initialize`` runs before ``self.inspect`` exists.
     self.initialize(decs)
     self.inspect = Inspector()
Esempio n. 44
0
	def int_check(self, ast_node):
		"""Annotate the fact and rule declarations of ``ast_node`` with usage info.

		Scans every rule declaration in ``ast_node.decs``, then sets on each fact
		declaration:
		  * ``persistent``    -- False if the predicate occurs in some rule's ``slhs``;
		  * ``local``         -- False if some rule body produces it at a location
		                         other than the rule heads' single location variable;
		  * ``monotone``      -- False if it is matched by an LHS comprehension;
		  * ``uses_priority`` -- True if some rule body introduces it with a priority.

		Rule declarations receive ``unique_head_names`` and
		``rule_priority_body_names``. Head facts receive ``unique_head`` and
		``collision_idx``; body facts receive ``local`` (when non-local) and
		``monotone``. Also fills ``self.rule_unique_heads`` and
		``self.rule_priority_body`` (keyed by rule name) and stores the fact
		declarations in ``self.fact_decs``.
		"""
		inspect = Inspector()

		decs = ast_node.decs

		simplified_pred_names = {}
		non_local_pred_names = {}
		lhs_compre_pred_names = {}
		prioritized_pred_names = {}

		for rule_dec in inspect.filter_decs(decs, rule=True):
			simp_heads = rule_dec.slhs
			prop_heads = rule_dec.plhs
			rule_body = rule_dec.rhs

			# Scan for simplified predicate names
			for fact in inspect.get_base_facts(simp_heads):
				simplified_pred_names[fact.name] = ()

			# Scan for non local predicate names.
			# Annotates non local rule body facts as well.
			loc_var_terms = inspect.free_vars(simp_heads + prop_heads, args=False)
			# A real list (not ``map``) so ``loc_vars[0]`` also works on Python 3.
			loc_vars = [t.name for t in loc_var_terms]
			if len(set(loc_vars)) > 1:
				# Heads span several locations: flag all body predicates as non local
				for fact in inspect.get_base_facts(rule_body):
					non_local_pred_names[fact.name] = ()
					fact.local = False
			else:
				# NOTE(review): raises IndexError when the heads carry no location
				# variable at all -- confirm that cannot happen.
				loc_var = loc_vars[0]
				(bfs, lfs, lfcs, comps) = inspect.partition_rule_heads(rule_body)
				for lf in lfs:
					# A non-variable location is always treated as non-local.
					if not isinstance(lf.loc, ast.TermVar) or lf.loc.name != loc_var:
						non_local_pred_names[lf.fact.name] = ()
						lf.fact.local = False
				for lfc in lfcs:
					if not isinstance(lfc.loc, ast.TermVar) or lfc.loc.name != loc_var:
						for f in lfc.facts:
							non_local_pred_names[f.name] = ()
							f.local = False
				for comp in comps:
					# Assumes that comprehension fact patterns are solo
					loc_fact = comp.facts[0]
					if not isinstance(loc_fact.loc, ast.TermVar) or loc_fact.loc.name != loc_var:
						is_local = False
					else:
						# Also non-local when the comprehension's (first) range binds
						# the rule's location variable.
						range_vars = inspect.free_vars(comp.comp_ranges[0].term_vars)
						is_local = loc_var not in [tv.name for tv in range_vars]
					if not is_local:
						# Keyed by predicate name (not location name): this table is
						# looked up with fact declaration names below.
						non_local_pred_names[loc_fact.fact.name] = ()
						loc_fact.fact.local = False

			# Scan for LHS comprehension predicate names
			(bfs, lfs, lfcs, comps) = inspect.partition_rule_heads(simp_heads + prop_heads)
			for comp in comps:
				loc_fact = comp.facts[0]
				lhs_compre_pred_names[loc_fact.fact.name] = ()

			# Scan for non-unique rule heads
			rule_head_pred_names = {}
			for fact in inspect.get_base_facts(simp_heads + prop_heads):
				rule_head_pred_names.setdefault(fact.name, []).append(fact)

			self.rule_unique_heads[rule_dec.name] = []
			collision_idx = 0
			for name in rule_head_pred_names:
				facts = rule_head_pred_names[name]
				unique_head = len(facts) == 1
				for fact in facts:
					fact.unique_head = unique_head
					fact.collision_idx = collision_idx
				collision_idx += 1
				if unique_head:
					self.rule_unique_heads[rule_dec.name].append(name)

			# Scan for priorities
			priority_body = {}
			self.rule_priority_body[rule_dec.name] = priority_body
			(bfs, lfs, lfcs, comps) = inspect.partition_rule_heads(rule_body)
			for bf in bfs:
				if bf.priority is not None:
					prioritized_pred_names[bf.name] = ()
					priority_body[bf.name] = ()
			for lf in lfs:
				if lf.priority is not None:
					prioritized_pred_names[lf.fact.name] = ()
					priority_body[lf.fact.name] = ()
			for lfc in lfcs:
				if lfc.priority is not None:
					for f in lfc.facts:
						prioritized_pred_names[f.name] = ()
						priority_body[f.name] = ()
			for comp in comps:
				if comp.priority is not None:
					for f in comp.facts:
						prioritized_pred_names[f.name] = ()
						priority_body[f.name] = ()

		# Annotate fact declaration nodes with relevant information.
		# Materialized because the list is both iterated and stored.
		fact_decs = list(inspect.filter_decs(decs, fact=True))
		for fact_dec in fact_decs:
			fact_dec.persistent = fact_dec.name not in simplified_pred_names
			fact_dec.local = fact_dec.name not in non_local_pred_names
			fact_dec.monotone = fact_dec.name not in lhs_compre_pred_names
			fact_dec.uses_priority = fact_dec.name in prioritized_pred_names
		self.fact_decs = fact_decs

		# Annotate rule declaration nodes with relevant information.
		# Materialized because it is iterated twice below.
		rule_decs = list(inspect.filter_decs(decs, rule=True))
		for rule_dec in rule_decs:
			rule_dec.unique_head_names = self.rule_unique_heads[rule_dec.name]
			rule_dec.rule_priority_body_names = list(self.rule_priority_body[rule_dec.name].keys())

		# Annotate RHS constraints with monotonicity information
		for rule_dec in rule_decs:
			for fact in inspect.get_base_facts(rule_dec.rhs):
				fact.monotone = fact.name not in lhs_compre_pred_names
Esempio n. 45
0
class AlphaIndexer(Transformer):
    """Number the variables occurring in every rule declaration.

    Each ``ast.RuleDec`` gets a fresh ``FramedCtxt``. Every ``ast.TermVar``
    reached from the rule gets ``rule_idx = ctxt.get_index(name)``, and the
    context's final ``var_idx`` is stored on the rule as ``next_rule_idx``.
    Comprehension binders are indexed inside a frame pushed for the
    duration of the comprehension.
    """

    def __init__(self, decs):
        self.initialize(decs)
        self.inspect = Inspector()

    def transform(self):
        """Run the indexer over the declarations given at construction."""
        self.int_transform(self.decs)

    @visit.on('ast_node')
    def int_transform(self, ast_node, ctxt=None):
        # Dispatch root; concrete handlers are registered with visit.when below.
        pass

    @visit.when(list)
    def int_transform(self, ast_node, ctxt=None):
        for item in ast_node:
            self.int_transform(item, ctxt)

    @visit.when(ast.EnsemDec)
    def int_transform(self, ast_node, ctxt=None):
        # Rules build their own context, so none is forwarded.
        for rule_dec in self.inspect.filter_decs(ast_node.decs, rule=True):
            self.int_transform(rule_dec)

    @visit.when(ast.RuleDec)
    def int_transform(self, ast_node, ctxt=None):
        # Any incoming ctxt is ignored: numbering restarts for every rule.
        rule_ctxt = FramedCtxt()
        rule_parts = (ast_node.plhs, ast_node.slhs, ast_node.grd,
                      ast_node.exists, ast_node.where, ast_node.rhs)
        for part in rule_parts:
            self.int_transform(part, rule_ctxt)
        ast_node.next_rule_idx = rule_ctxt.var_idx

    @visit.when(ast.AssignDec)
    def int_transform(self, ast_node, ctxt=None):
        for part in (ast_node.term_pat, ast_node.builtin_exp):
            self.int_transform(part, ctxt)

    @visit.when(ast.FactBase)
    def int_transform(self, ast_node, ctxt=None):
        for arg in ast_node.terms:
            self.int_transform(arg, ctxt)

    @visit.when(ast.FactLoc)
    def int_transform(self, ast_node, ctxt=None):
        for part in (ast_node.loc, ast_node.fact):
            self.int_transform(part, ctxt)

    @visit.when(ast.FactLocCluster)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.loc, ctxt)
        for member in ast_node.facts:
            self.int_transform(member, ctxt)

    @visit.when(ast.FactCompre)
    def int_transform(self, ast_node, ctxt=None):
        self._index_comprehension(ast_node, ast_node.facts, ctxt)

    @visit.when(ast.TermVar)
    def int_transform(self, ast_node, ctxt=None):
        ast_node.rule_idx = ctxt.get_index(ast_node.name)

    @visit.when(ast.TermApp)
    def int_transform(self, ast_node, ctxt=None):
        for operand in (ast_node.term1, ast_node.term2):
            self.int_transform(operand, ctxt)

    @visit.when(ast.TermTuple)
    def int_transform(self, ast_node, ctxt=None):
        for component in ast_node.terms:
            self.int_transform(component, ctxt)

    @visit.when(ast.TermList)
    def int_transform(self, ast_node, ctxt=None):
        for element in ast_node.terms:
            self.int_transform(element, ctxt)

    @visit.when(ast.TermListCons)
    def int_transform(self, ast_node, ctxt=None):
        for operand in (ast_node.term1, ast_node.term2):
            self.int_transform(operand, ctxt)

    @visit.when(ast.TermMSet)
    def int_transform(self, ast_node, ctxt=None):
        for element in ast_node.terms:
            self.int_transform(element, ctxt)

    @visit.when(ast.TermCompre)
    def int_transform(self, ast_node, ctxt=None):
        self._index_comprehension(ast_node, ast_node.term, ctxt)

    @visit.when(ast.TermBinOp)
    def int_transform(self, ast_node, ctxt=None):
        for operand in (ast_node.term1, ast_node.term2):
            self.int_transform(operand, ctxt)

    @visit.when(ast.TermUnaryOp)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term, ctxt)

    def _index_comprehension(self, ast_node, body, ctxt):
        """Index a fact/term comprehension whose pattern or element is ``body``.

        Range expressions are indexed in the enclosing frame. The binders,
        the body and the guards are indexed in a frame keyed by the binder
        names, which is popped afterwards.
        """
        for comp_range in ast_node.comp_ranges:
            self.int_transform(comp_range.term_range, ctxt)

        binders = self.inspect.free_vars(
            [comp_range.term_vars for comp_range in ast_node.comp_ranges])
        ctxt.push_frame(keys=set(binder.name for binder in binders))

        self.int_transform(binders, ctxt)
        self.int_transform(body, ctxt)
        self.int_transform(ast_node.guards, ctxt)

        ctxt.pop_frame()
Esempio n. 46
0
class AlphaIndexer(Transformer):
    """Assign per-rule variable indices (``rule_idx``) to rule terms.

    A fresh ``FramedCtxt`` is used for every ``ast.RuleDec``:
    - Named variables get ``ctxt.get_index(name)``.
    - Each ``_`` gets ``ctxt.new_index()``.
    - The final ``var_idx`` is stored on the rule as ``next_rule_idx``.

    Comprehension binders are indexed in a frame pushed for the
    comprehension and popped afterwards.
    """

    def __init__(self, decs):
        self.initialize(decs)
        self.inspect = Inspector()

    def transform(self):
        """Index every declaration supplied at construction."""
        self.int_transform(self.decs)

    @visit.on('ast_node')
    def int_transform(self, ast_node, ctxt=None):
        # Dispatch root; per-node handlers are registered with visit.when.
        pass

    @visit.when(list)
    def int_transform(self, ast_node, ctxt=None):
        for entry in ast_node:
            self.int_transform(entry, ctxt)

    @visit.when(ast.EnsemDec)
    def int_transform(self, ast_node, ctxt=None):
        rule_decs = self.inspect.filter_decs(ast_node.decs, rule=True)
        for rule_dec in rule_decs:
            # No ctxt: every rule creates its own.
            self.int_transform(rule_dec)

    @visit.when(ast.RuleDec)
    def int_transform(self, ast_node, ctxt=None):
        fresh_ctxt = FramedCtxt()
        for section in (ast_node.plhs, ast_node.slhs, ast_node.grd,
                        ast_node.exists, ast_node.where, ast_node.rhs):
            self.int_transform(section, fresh_ctxt)
        ast_node.next_rule_idx = fresh_ctxt.var_idx

    @visit.when(ast.AssignDec)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term_pat, ctxt)
        self.int_transform(ast_node.builtin_exp, ctxt)

    @visit.when(ast.FactBase)
    def int_transform(self, ast_node, ctxt=None):
        for arg in ast_node.terms:
            self.int_transform(arg, ctxt)

    @visit.when(ast.FactLoc)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.loc, ctxt)
        self.int_transform(ast_node.fact, ctxt)

    @visit.when(ast.FactLocCluster)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.loc, ctxt)
        for clustered in ast_node.facts:
            self.int_transform(clustered, ctxt)

    @visit.when(ast.FactCompre)
    def int_transform(self, ast_node, ctxt=None):
        self._index_compre(ast_node, ast_node.facts, ctxt)

    @visit.when(ast.TermVar)
    def int_transform(self, ast_node, ctxt=None):
        ast_node.rule_idx = ctxt.get_index(ast_node.name)

    @visit.when(ast.TermUnderscore)
    def int_transform(self, ast_node, ctxt=None):
        # Every wildcard gets its own fresh index.
        ast_node.rule_idx = ctxt.new_index()

    @visit.when(ast.TermApp)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term1, ctxt)
        self.int_transform(ast_node.term2, ctxt)

    @visit.when(ast.TermTuple)
    def int_transform(self, ast_node, ctxt=None):
        for part in ast_node.terms:
            self.int_transform(part, ctxt)

    @visit.when(ast.TermList)
    def int_transform(self, ast_node, ctxt=None):
        for part in ast_node.terms:
            self.int_transform(part, ctxt)

    @visit.when(ast.TermListCons)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term1, ctxt)
        self.int_transform(ast_node.term2, ctxt)

    @visit.when(ast.TermMSet)
    def int_transform(self, ast_node, ctxt=None):
        for part in ast_node.terms:
            self.int_transform(part, ctxt)

    @visit.when(ast.TermEnumMSet)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.texp1, ctxt)
        self.int_transform(ast_node.texp2, ctxt)

    @visit.when(ast.TermCompre)
    def int_transform(self, ast_node, ctxt=None):
        self._index_compre(ast_node, ast_node.term, ctxt)

    @visit.when(ast.TermBinOp)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term1, ctxt)
        self.int_transform(ast_node.term2, ctxt)

    @visit.when(ast.TermUnaryOp)
    def int_transform(self, ast_node, ctxt=None):
        self.int_transform(ast_node.term, ctxt)

    def _index_compre(self, ast_node, body, ctxt):
        """Shared indexing for fact and term comprehensions.

        Ranges are indexed in the current frame. Binders, ``body`` and
        guards are indexed in a frame keyed by the binder names.
        """
        for comp_range in ast_node.comp_ranges:
            self.int_transform(comp_range.term_range, ctxt)

        binder_vars = self.inspect.free_vars(
            [comp_range.term_vars for comp_range in ast_node.comp_ranges])
        ctxt.push_frame(keys=set(var.name for var in binder_vars))

        self.int_transform(binder_vars, ctxt)
        self.int_transform(body, ctxt)
        self.int_transform(ast_node.guards, ctxt)

        ctxt.pop_frame()
Esempio n. 47
0
    def __init__(self, rule, occ_idx, fact_dir, lookup_tables):
        """Compile the matching plan for occurrence ``occ_idx`` of ``rule``.

        Sets on ``self``:
          * ``head_match_tasks``: dict from head index to that head's task
            list. Index 0 is the active occurrence. Atom partners follow in
            the order chosen by ``LookupContext.bestLookupOption``. Then, if
            the active head is a comprehension, its bootstrapped lookup comes
            next, followed by the comprehension partners.
          * ``head_indices``, ``delete_head_tasks`` (one per simplified head),
            ``exist_tasks``, ``letbind_tasks``, ``atom_body_tasks`` and
            ``compre_body_tasks``.
          * ``is_active_prop``: True when the active head is propagated.
          * ``is_propagated``: False as soon as any head is simplified.

        Args:
            rule: compiled rule providing ``occ_heads``, guards,
                existentials, ``where`` bindings and ``rule_body``.
            occ_idx: key in ``rule.occ_heads`` of the active head.
            fact_dir: handed to ``LookupContext``.
            lookup_tables: every chosen lookup is registered with it.
        """
        occ_head = rule.occ_heads[occ_idx]

        # Split the other heads into atom and comprehension partners.
        atom_partner_heads = {}
        compre_partner_heads = {}
        for partner_idx, partner_head in rule.occ_heads.items():
            if partner_idx == occ_idx:
                continue
            if partner_head.is_atom:
                atom_partner_heads[partner_idx] = partner_head
            else:
                compre_partner_heads[partner_idx] = partner_head

        self.occ_head = occ_head
        self.fact_idx = occ_head.fact_idx
        self.rule = rule
        self.occ_idx = occ_idx
        self.is_active_prop = False
        self.is_propagated = True

        simp_head_idx = []
        prop_head_idx = []

        # Build Active Match task and Initialize guard pool.
        guard_pool = list(rule.idx_grds + rule.non_idx_grds)
        collision_map = defaultdict(list)
        if occ_head.is_atom:
            active_task = ActiveAtom(occ_head, 0)
            boot_strap = False
            collision_map[occ_head.collision_idx].append((occ_head.is_atom, 0))
            if occ_head.head_type == SIMP_HEAD:
                simp_head_idx.append(0)
                self.is_propagated = False
            else:
                prop_head_idx.append(0)
                self.is_active_prop = True
        else:
            # We bootstrap active comprehensions: the comprehension pattern is
            # first executed as though it were an atom (a 'phantom' head).
            # Bootstrapping is completed below at (*).
            active_task = ActiveCompre(occ_head, 0)
            boot_strap = True
            guard_pool += (occ_head.eq_grds + occ_head.mem_grds +
                           occ_head.ord_grds + occ_head.non_idx_grds)
            if occ_head.head_type == SIMP_HEAD:
                self.is_propagated = False
            else:
                self.is_active_prop = True

        # Initiate Matching Context. A plain loop, not ``map``: ``map`` is
        # lazy on Python 3 and the guards would silently never be added.
        lctxt = LookupContext(fact_dir)
        lctxt.addFactHead(occ_head, boot_strap=boot_strap)
        for guard in guard_pool:
            lctxt.addGuard(guard)

        # Initiate Head Match Tasks
        head_match_tasks = {}
        head_match_tasks[0] = [active_task] + [
            CheckGuard(g) for g in lctxt.scheduleGuards()]

        # Build atom partner match tasks, best lookup first.
        head_idx = 1
        while atom_partner_heads:
            (best_idx,
             best_lookup) = lctxt.bestLookupOption(atom_partner_heads)
            partner_head = atom_partner_heads.pop(best_idx)

            # Register head type
            if partner_head.head_type == SIMP_HEAD:
                simp_head_idx.append(head_idx)
                self.is_propagated = False
            else:
                prop_head_idx.append(head_idx)

            lookup_task = LookupAtom(partner_head, head_idx, best_lookup, rule)
            lookup_tables.registerLookup(best_lookup)

            # Update lookup context
            lctxt.addVars(lookup_task.input_vars + lookup_task.output_vars)
            lctxt.remove_guards(best_lookup.assoc_guards)

            partner_match_tasks = [lookup_task] + collision_guard_task(
                True, head_idx, collision_map[partner_head.collision_idx])
            partner_match_tasks += [CheckGuard(g)
                                    for g in lctxt.scheduleGuards()]
            head_match_tasks[head_idx] = partner_match_tasks

            # Register partner collision index in collision map
            collision_map[partner_head.collision_idx].append((True, head_idx))

            head_idx += 1

        # (*) Complete active comprehension bootstrapping: the active
        # comprehension is now looked up in full (after the atom partners),
        # followed by its residual guards and domain task.
        if occ_head.is_compre:
            if occ_head.head_type == SIMP_HEAD:
                simp_head_idx.append(head_idx)
                self.is_propagated = False
            else:
                prop_head_idx.append(head_idx)

            lctxt.removeBootStrapped()
            lookup = lctxt.lookupOptions(occ_head)[0]
            head_match_tasks[head_idx] = self._compre_match_tasks(
                occ_head, head_idx, lookup, rule, lctxt, lookup_tables,
                collision_map)
            head_idx += 1

        # Build comprehension partner match tasks, best lookup first.
        while compre_partner_heads:
            (best_idx,
             best_lookup) = lctxt.bestLookupOption(compre_partner_heads)
            partner_head = compre_partner_heads.pop(best_idx)

            # Register head type
            if partner_head.head_type == SIMP_HEAD:
                simp_head_idx.append(head_idx)
                self.is_propagated = False
            else:
                prop_head_idx.append(head_idx)

            head_match_tasks[head_idx] = self._compre_match_tasks(
                partner_head, head_idx, best_lookup, rule, lctxt,
                lookup_tables, collision_map)
            head_idx += 1

        self.head_indices = list(range(head_idx))
        self.head_match_tasks = head_match_tasks

        self.delete_head_tasks = [DeleteHead(hidx) for hidx in simp_head_idx]

        # Extract exist tasks
        self.exist_tasks = ([ExistDest(dest) for dest in rule.exist_dests] +
                            [ExistLoc(loc) for loc in rule.exist_locs])

        # Extract let binding tasks
        self.letbind_tasks = [LetBind(where) for where in rule.where]

        # Extract rule body tasks; compre_idx counts comprehensions only.
        atom_body_tasks = []
        compre_body_tasks = []
        compre_idx = 0
        for body_idx, body in enumerate(rule.rule_body):
            if body.is_atom:
                atom_body_tasks.append(IntroAtom(body, body_idx))
            else:
                compre_body_tasks.append(
                    IntroCompre(body, body_idx, compre_idx))
                compre_idx += 1
        self.atom_body_tasks = atom_body_tasks
        self.compre_body_tasks = compre_body_tasks

    @staticmethod
    def _compre_match_tasks(head, head_idx, lookup, rule, lctxt,
                            lookup_tables, collision_map):
        """Build the match tasks for comprehension ``head`` looked up via ``lookup``.

        The steps are:
        - Register the lookup.
        - Extend ``lctxt`` with the lookup's output variables and the
          comprehension domain, and drop the guards the lookup already
          enforces.
        - Emit, in order: the ``LookupAll`` task; collision checks; a
          ``FilterGuard`` for the remaining head guards (if any); the
          ``CompreDomain`` task; and the newly schedulable guards.
        - Record the head in ``collision_map``.
        """
        lookup_task = LookupAll(head, head_idx, lookup, rule)
        lookup_tables.registerLookup(lookup)

        lctxt.addVars(lookup_task.output_vars + [head.compre_dom])
        lctxt.remove_guards(lookup.assoc_guards)

        compre_guards = [
            guard
            for guard in (head.eq_grds + head.mem_grds + head.ord_grds +
                          head.non_idx_grds)
            if guard not in lookup.assoc_guards
        ]

        tasks = [lookup_task] + collision_guard_task(
            False, head_idx, collision_map[head.collision_idx])
        if compre_guards:
            tasks.append(FilterGuard(lookup_task.term_vars,
                                     lookup_task.compre_dom, head_idx,
                                     compre_guards))
        tasks.append(CompreDomain(head, head_idx))
        tasks += [CheckGuard(g) for g in lctxt.scheduleGuards()]

        collision_map[head.collision_idx].append((False, head_idx))
        return tasks
Esempio n. 48
0
class TypeChecker(Checker):

	def __init__(self, decs, source_text):
		self.inspect = Inspector()
		self.initialize(decs, source_text)	
		self.solver = Solver()
		self.infer_goals = {}

	def check(self):
		ctxt = new_ctxt()
		for ensem_dec in self.inspect.filter_decs(self.decs, ensem=True):
			self.int_check_dec(ensem_dec, ctxt)
		for exec_dec in self.inspect.filter_decs(self.decs, execute=True):
			self.int_check_dec(exec_dec, ctxt)

	# Check Declarations

	@visit.on( 'ast_node' )
	def int_check_dec(self, ast_node, ctxt):
		pass

	@visit.when( ast.EnsemDec )
	def int_check_dec(self, ast_node, ctxt):
		inspect = self.inspect
		s0 = True
		extern_cons = []
		for extern_dec in inspect.filter_decs(ast_node.decs, extern=True):
			(s,cons) = self.int_check_dec(extern_dec, ctxt)
			s0 = s0 and s
			extern_cons += cons
		fact_dec_cons = []
		for fact_dec in inspect.filter_decs(ast_node.decs, fact=True):
			(s,cons) = self.int_check_dec(fact_dec, ctxt)
			s0 = s0 and s
			fact_dec_cons += cons
		rule_dec_cons = []
		this_s = True
		for rule_dec in inspect.filter_decs(ast_node.decs, rule=True):
			(s,cons) = self.int_check_dec(rule_dec, ctxt)
			s0 = s0 and s
			if self.check_type_sat( cons + extern_cons + fact_dec_cons ):	
				rule_dec_cons += cons
			else:
				this_s = False
		ctxt['ensem'][ast_node.name] = { 'succ':s0 and this_s, 'cons':extern_cons + fact_dec_cons }
		return (s0 and this_s, extern_cons + fact_dec_cons + rule_dec_cons)

	@visit.when( ast.ExternDec )
	def int_check_dec(self, ast_node, ctxt):
		s0 = True
		type_sig_cons = []		
		for ty_sig in ast_node.type_sigs:
			(s,t,cs) = self.int_check_type(ty_sig.type_sig, ctxt)
			s0 = s0 and s
			t0 = tyVar()
			type_sig_cons += [t0 |Eq| t |just| [ty_sig]] + cs
			ctxt['cons'][ty_sig.name] = t0
		return (s0, type_sig_cons)

	@visit.when( ast.FactDec )
	def int_check_dec(self, ast_node, ctxt):
		local_ctxt  = copy_ctxt(ctxt)
		tvar1 = tyVar()
		ctxt['pred'][ast_node.name] = tvar1
		if ast_node.type == None:
			return (True,[tvar1 |Eq| tyUnit |just| [ast_node]])
		else:
			(s,tvar2,cons) = self.int_check_type(ast_node.type, local_ctxt)
			return (s,[tvar1 |Eq| tvar2 |just| [ast_node]] + cons)

	@visit.when( ast.RuleDec )
	def int_check_dec(self, ast_node, ctxt):
		local_ctxt = copy_ctxt(ctxt)
		head_cons = []
		s0 = True
		for lhs in ast_node.slhs + ast_node.plhs:
			(s,cons) = self.int_check_fact(lhs, local_ctxt)
			s0 = s0 and s
			head_cons += cons
		guard_cons = []
		for guard in ast_node.grd:
			(s,tg,cons) = self.int_check_term(guard, local_ctxt)
			s0 = s0 and s
			guard_cons += [tg |Eq| tyBool |just| [ast_node]] + cons
		exist_cons = []
		for exist in ast_node.exists:
			(s,te,cons) = self.int_check_term(exist, local_ctxt)
			s0 = s0 and s
			exist_cons += [(te |Eq| tyLoc |just| [exist])|Or|(te |Eq| tyDest |just| [exist])] + cons 
		where_cons = []
		for wh in ast_node.where:
			(s,cons) = self.int_check_assign(wh, local_ctxt)
			s0 = s0 and s
			where_cons += cons
		body_cons = []
		for rhs in ast_node.rhs:
			(s,cons) = self.int_check_fact(rhs, local_ctxt)
			s0 = s0 and s
			body_cons += cons
		return (s0, head_cons + guard_cons + exist_cons + where_cons + body_cons)

	def int_check_assign(self, assign, ctxt):
		(s0,t0,cs0) = self.int_check_term(assign.term_pat, ctxt)
		(s1,t1,cs1) = self.int_check_term(assign.builtin_exp, ctxt)
		return (s0 and s1,[t0 |Eq| t1 |just| [assign]] + cs0 + cs1)

	@visit.when( ast.ExecDec )
	def int_check_dec(self, ast_node, ctxt):
		local_ctxt = copy_ctxt( ctxt )
		ensem_ty_data = ctxt['ensem'][ast_node.name]
		s0 = True
		exec_cons = []
		for dec in ast_node.decs:
			(s,cons) = self.int_check_dec(dec, local_ctxt)
			s0 = s0 and s
			exec_cons += cons
		if not self.check_type_sat( ensem_ty_data['cons'] + exec_cons ):
			s0 = False
		return (s0, ensem_ty_data['cons'] + exec_cons)

	@visit.when( ast.ExistDec )
	def int_check_dec(self, ast_node, ctxt):
		s0   = True
		cons = []		
		for exist in ast_node.exist_vars:
			(s,t,cs) = self.int_check_term(exist, ctxt)
			s0 = s0 and s
			cons += [(t |Eq| tyLoc |just| [exist])|Or|(t |Eq| tyDest |just| [exist])] + cs
		return (s0, cons)

	@visit.when( ast.LocFactDec )
	def int_check_dec(self, ast_node, ctxt):
		s0   = True
		cons = []		
		for loc_fact in ast_node.loc_facts:
			(s,cs) = self.int_check_fact(loc_fact, ctxt)
			s0 = s0 and s
			cons += cs
		return (s0, cons)

	# Check Facts

	@visit.on( 'ast_node' )
	def int_check_fact(self, ast_node, ctxt):
		pass


	@visit.when( ast.FactBase )
	def int_check_fact(self, ast_node, ctxt):
		t0 = ctxt['pred'][ast_node.name]
		(s1,ts,term_cons) = self.int_check_term(ast_node.terms, ctxt)
		if len(ts) == 1:
			return (s1, [t0 |Eq| ts[0] |just| [ast_node]] + term_cons)
		else:
			curr = len(ts)-1
			t1 = tyUnit
			while curr >= 0:
				t1 = tyTuple(ts[curr],t1)
				curr -= 1
			return (s1, [t0 |Eq| t1 |just| [ast_node]] + term_cons)

	@visit.when( ast.FactLoc )
	def int_check_fact(self, ast_node, ctxt):
		(s0, tl, loc_cons) = self.int_check_term(ast_node.loc, ctxt)
		(s1, fact_cons)    = self.int_check_fact(ast_node.fact, ctxt)
		return (s0 and s1, [tl |Eq| tyLoc |just| [ast_node]] + loc_cons + fact_cons)

	@visit.when( ast.FactLocCluster )
	def int_check_fact(self, ast_node, ctxt):
		(s0, tl, loc_cons) = self.int_check_term(ast_node.loc, ctxt)
		fact_cons = []
		for fact in ast_node.facts:
			(s1, c) = self.int_check_fact(fact, ctxt)
			s0 = s0 and s1
			fact_cons += c
		return (s0, [tl |Eq| tyLoc |just| [ast_node]] + loc_cons + fact_cons)
		
	@visit.when( ast.FactCompre )
	def int_check_fact(self, ast_node, ctxt):
		local_ctxt = copy_ctxt( ctxt )
		s0 = True
		all_cons = []
		for comp_range in ast_node.comp_ranges:
			(s,cs) = self.int_check_comp_range( comp_range, ctxt, local_ctxt )
			s0 = s0 and s
			all_cons += cs
		for fact in ast_node.facts:
			(s,cs) = self.int_check_fact( fact, local_ctxt )
			s0 = s0 and s
			all_cons += cs
		for guard in ast_node.guards:
			(s,tg,cs) = self.int_check_term( guard, local_ctxt )
			s0 = s0 and s
			all_cons += [tg |Eq| tyBool |just| [ast_node]] + cs		
		return (s0, all_cons)

	def int_check_comp_range(self, comp_range, ctxt, local_ctxt):
		(s0,t0,cs0) = self.int_check_term(comp_range.term_vars, local_ctxt)
		(s1,t1,cs1) = self.int_check_term(comp_range.term_range, ctxt)
		return (s0 and s1, [(t1 |Eq| tyMSet(t0) |just| [comp_range])|Or|(t1 |Eq| tyList(t0) |just| [comp_range])] + cs0 + cs1)

	# Check Terms

	@visit.on( 'ast_node' )
	def int_check_term(self, ast_node, ctxt):
		pass

	@visit.when( list )
	def int_check_term(self, ast_node, ctxt):
		ts = []
		cons = []
		s0 = True
		for node in ast_node:
			(s,t,c) = self.int_check_term(node, ctxt)
			s0 = s0 and s
			ts   += [t]
			cons += c
		return (s0,ts,cons)

	@visit.when( ast.TermVar )
	def int_check_term(self, ast_node, ctxt):
		tvar1 = tyVar()
		if ast_node.name in ctxt['vars']:
			tvar2 = ctxt['vars'][ast_node.name]
		else:
			tvar2 = tyVar()
			ctxt['vars'][ast_node.name] = tvar2
		self.mark_for_infer(ast_node, tvar1)
		return (True, tvar1, [tvar1 |Eq| tvar2 |just| [ast_node]])
		
	@visit.when( ast.TermCons )
	def int_check_term(self, ast_node, ctxt):
		tvar1 = tyVar()
		tvar2 = ctxt['cons'][ast_node.name] 
		self.mark_for_infer(ast_node, tvar1)
		return (True, tvar1, [tvar1 |Eq| tvar2 |just| [ast_node]])

	@visit.when( ast.TermLit )
	def int_check_term(self, ast_node, ctxt):
		tvar1 = tyVar()
		(s,tvar2,cons) = self.int_check_type( ast_node.type, ctxt)
		# print map(lambda c: str(c),cons)
		self.mark_for_infer(ast_node, tvar1)
		return (s,tvar1,[tvar1 |Eq| tvar2 |just| [ast_node]] + cons)

	@visit.when( ast.TermApp )
	def int_check_term(self, ast_node, ctxt):
		(s0,t0,cs0) = self.int_check_term(ast_node.term1, ctxt)
		(s1,t1,cs1) = self.int_check_term(ast_node.term2, ctxt)
		t2  = tyVar()
		t3  = tyVar()
		cs2 = [t0 |Eq| tyArrow(t2,t3) |just| [ast_node], t1 |Eq| t2 |just| [ast_node]]
		self.mark_for_infer(ast_node, t3)
		return (s0 and s1, t3, cs2 + cs0 + cs1)

	@visit.when( ast.TermTuple )
	def int_check_term(self, ast_node, ctxt):
		(s0, ts, cs0) = self.int_check_term(ast_node.terms, ctxt)
		curr = len(ts) - 1
		t1   = tyUnit
		while curr >= 0:
			t1 = tyTuple(ts[curr],t1)
			curr -= 1
		t0 = tyVar()
		self.mark_for_infer(ast_node, t0)
		return (s0, t0, [t0 |Eq| t1 |just| [ast_node]] + cs0)

	@visit.when( ast.TermList )
	def int_check_term(self, ast_node, ctxt):
		(s0, ts, cs0) = self.int_check_term(ast_node.terms, ctxt)
		cs1 = []
		t1  = tyVar()
		for t0 in ts:
			cs1.append(t1 |Eq| tyList(t0) |just| [ast_node])
		self.mark_for_infer(ast_node, t1)
		return (s0, t1, cs1 + cs0)

	@visit.when( ast.TermListCons )
	def int_check_term(self, ast_node, ctxt):
		(s0,t0,cs0) = self.int_check_term(ast_node.term1, ctxt)
		(s1,t1,cs1) = self.int_check_term(ast_node.term2, ctxt)
		t2  = tyVar()
		cs2 = [t1 |Eq| tyList(t0) |just| [ast_node], t2 |Eq| t1]
		self.mark_for_infer(ast_node, t2)
		return (s0 and s1, t2, cs2 + cs0 + cs1)

	@visit.when( ast.TermMSet )
	def int_check_term(self, ast_node, ctxt):
		(s0, ts, cs0) = self.int_check_term(ast_node.terms, ctxt)
		cs1 = []
		t1  = tyVar()
		for t0 in ts:
			cs1.append(t1 |Eq| tyMSet(t0) |just| [ast_node])
		self.mark_for_infer(ast_node, t1)
		return (s0, t1, cs1 + cs0)

	@visit.when( ast.TermCompre )
	def int_check_term(self, ast_node, ctxt):
		local_ctxt = copy_ctxt( ctxt )
		s0 = True
		all_cons = []
		for comp_range in ast_node.comp_ranges:
			(s,cs) = self.int_check_comp_range( comp_range, ctxt, local_ctxt )
			s0 = s0 and s
			all_cons += cs
		(s,t0,cs) = self.int_check_term(ast_node.term, local_ctxt)
		t1 = tyVar()
		s0 = s0 and s
		all_cons += [t1 |Eq| tyMSet(t0) |just| [ast_node]] + cs
		for guard in ast_node.guards:
			(s,tg,cs) = self.int_check_term( guard, local_ctxt )
			s0 = s0 and s
			all_cons += [tg |Eq| tyBool |just| [ast_node]] + cs		
		self.mark_for_infer(ast_node, t1)
		return (s0, t1, all_cons)

	@visit.when( ast.TermBinOp )
	def int_check_term(self, ast_node, ctxt):
		(s0,t0,cs0) = self.int_check_term(ast_node.term1, ctxt)
		(s1,t1,cs1) = self.int_check_term(ast_node.term2, ctxt)
		t3 = tyVar()
		if ast_node.op in BOOL_OPS_1:
			cs3 = [t0 |Eq| t1 |just| [ast_node], t3 |Eq| tyBool |just| [ast_node]]
		elif ast_node.op in BOOL_OPS_2:
			cs3 = [tyMSet(t0) |Eq| t1 |just| [ast_node], t3 |Eq| tyBool |just| [ast_node]]
		elif ast_node.op in NUM_OPS:
			cs3 = [t0 |Eq| t1 |just| [ast_node], t0 |Eq| t3 |just| [ast_node]]
		self.mark_for_infer(ast_node, t3)
		return (s0 and s1, t3, cs3 + cs0 + cs1)

	@visit.when( ast.TermUnaryOp )
	def int_check_term(self, ast_node, ctxt):
		(s0,t0,cs0) = self.int_check_term(ast_node.term, ctxt)
		t1 = tyVar()
		self.mark_for_infer(ast_node, t1)
		return (s0, t1, [t1 |Eq| tyBool |just| [ast_node], t0 |Eq| tyBool |just| [ast_node]] + cs0)

	@visit.when( ast.TermUnderscore )
	def int_check_term(self, ast_node, ctxt):
		t = tyVar()
		self.mark_for_infer(ast_node, t)
		return (True, t, [])

	# Check Types

	@visit.on( 'ast_node' )
	def int_check_type(self, ast_node, ctxt):
		pass

	@visit.when( ast.TypeVar )
	def int_check_type(self, ast_node, ctxt):
		tvar1 = tyVar()
		if ast_node.name in ctxt['type_vars']:
			tvar2 = ctxt['type_vars'][ast_node.name]
		else:
			tvar2 = tyVar()
			ctxt['type_vars'][ast_node.name] = tvar2
		return (True, tvar1, [tvar1 |Eq| tvar2 |just| [ast_node]])

	@visit.when(ast.TypeCons)
	def int_check_type(self, ast_node, ctxt):
		tvar = tyVar()
		if ast_node.name not in BASE_TYPES:
			error_idx = self.declare_error("Unknown data type \'%s\'" % ast_node.name)
			self.extend_error(error_idx,ast_node)
			return (False,tvar,[])
		else:
			return (True,tvar,[tvar |Eq| smt_base_type(ast_node.name) |just| [ast_node]])

	@visit.when(ast.TypeApp)
	def int_check_type(self, ast_node, ctxt):
		(s1,t1,cons1) = self.int_check_type(ast_node.type1, ctxt)
		(s2,t2,cons2) = self.int_check_type(ast_node.type2, ctxt)
		t3 = tyVar()
		t4 = tyVar()
		return (s1 and s2, t4, [t1 |Eq| tyArrow(t3,t4) |just| [ast_node], t2 |eq| t3 |just| [ast_node]] + cons1 + cons2)

	@visit.when(ast.TypeArrow)
	def int_check_type(self, ast_node, ctxt):
		(s1,t1,cons1) = self.int_check_type(ast_node.type1, ctxt)
		(s2,t2,cons2) = self.int_check_type(ast_node.type2, ctxt)
		t3 = tyVar()
		return (s1 and s2,t3,[t3 |Eq| tyArrow(t1,t2) |just| [ast_node]] + cons1 + cons2)

	@visit.when(ast.TypeTuple)
	def int_check_type(self, ast_node, ctxt):
		types = ast_node.types
		curr  = len(types)-1
		all_cons = []
		t1 = tyUnit
		s0 = True
		while curr >= 0:
			(s,t2,cons) = self.int_check_type(types[curr], ctxt)
			s0 = s0 and s
			all_cons += cons
			t1 = tyTuple(t2,t1)
			curr -= 1
		t0 = tyVar()
		return (s0,t0,[t0 |Eq| t1 |just| [ast_node]] + all_cons)

	@visit.when(ast.TypeMSet)
	def int_check_type(self, ast_node, ctxt):
		(s,t1,cons) = self.int_check_type(ast_node.type, ctxt)
		t0 = tyVar()
		return (s,t0,[t0 |Eq| tyMSet(t1) |just| [ast_node]] + cons)

	@visit.when(ast.TypeList)
	def int_check_type(self, ast_node, ctxt):
		(s,t1,cons) = self.int_check_type(ast_node.type, ctxt)
		t0 = tyVar()
		return (s,t0,[t0 |Eq| tyList(t1) |just| [ast_node]] + cons)

	# Check Type constraint satisfiability
	
	def check_type_sat(self, cons):
		# print (map(lambda c: str(c),cons))
		mus = min_unsat_subset(self.solver, cons)
		if len(mus) == 0:
			model = self.solver.solve(cons)
			for d in model:
				key = str(d)
				if key in self.infer_goals:
					smt_type = model[d]
					self.infer_goals[key].smt_type = smt_type
					self.infer_goals[key].type = coerce_type( smt_type )
					print "%s -> %s::%s" % (self.infer_goals[key], type_to_data_sort( model[d] ), self.infer_goals[key].type )
					del self.infer_goals[key]
			return True
		else:
			error_idx = self.declare_error("Type Error in the following sites")
			for error_site in foldl( map(lambda c: c.get_just(),mus), []):
				self.extend_error(error_idx, error_site)
			return False

	# Type inference operations

	def mark_for_infer(self, ast_node, ty_var):
		self.infer_goals[str(ty_var)] = ast_node
Esempio n. 49
0
 def initialize(self, ast_node=None):
     """Assign a fresh match-task id and, if a node is given, record its free variables.

     :param ast_node: optional AST node; when provided, ``self.free_vars`` is
         set to ``Inspector().free_vars(ast_node)``. When omitted,
         ``self.free_vars`` is left untouched.
     """
     self.id = get_next_matchtask_index()
     # Identity test per PEP 8: an AST node's ``!=`` could be overloaded.
     if ast_node is not None:
         inspect = Inspector()
         self.free_vars = inspect.free_vars(ast_node)
Esempio n. 50
0
 def __init__(self, decs, source_text):
     """Set up the checker over ``decs``/``source_text`` with empty scope tracking.

     ``curr_out_scopes`` starts as a fresh context, ``curr_duplicates``
     tracks duplicate variables under the 'vars' key, and ``ensems`` maps
     ensemble names to their recorded contexts.
     """
     self.inspect = Inspector()
     self.initialize(decs, source_text)
     self.curr_out_scopes = new_ctxt()
     self.curr_duplicates = dict(vars={})
     self.ensems = dict()
Esempio n. 51
0
 def __init__(self, head, head_idx):
     """Record a comprehension head and the binder of its first range.

     :param head: head whose ``fact.comp_ranges[0]`` supplies the binder
         pattern (``term_vars``) and the domain (``term_range``).
     :param head_idx: index of this head within the rule.
     """
     inspector = Inspector()
     self.head = head
     self.head_idx = head_idx
     first_range = head.fact.comp_ranges[0]
     self.term_vars = inspector.free_vars(first_range.term_vars)
     self.compre_dom = first_range.term_range
Esempio n. 52
0
	def __init__(self, decs, source_text):
		"""Initialise the checker state for ``decs`` parsed from ``source_text``.

		Starts with a fresh output-scope context, an empty duplicate-variable
		table keyed by 'vars', and no recorded ensembles.
		"""
		self.inspect = Inspector()
		self.initialize(decs, source_text)
		self.curr_out_scopes = new_ctxt()
		self.curr_duplicates = dict(vars=dict())
		self.ensems = dict()