Example #1
0
def is_pure_symbol(symbol, clauses):
    """Check whether ``symbol`` occurs with only one polarity in ``clauses``.

    Every literal in ``clauses`` that equals ``symbol`` or whose negation
    equals ``symbol`` is collected. The symbol is pure when exactly one
    distinct literal form was found.

    :param symbol: a SymPy literal (``A`` or ``~A``).
    :param clauses: iterable of clauses; each is an ``Or`` of literals or a
        single literal.
    :return: ``(is_pure, symbol)``. ``is_pure`` is False both for mixed
        polarity and when ``symbol`` does not occur at all.
    """
    instances = set()
    for clause in clauses:
        # A clause that is not an Or is a unit clause: its only literal is
        # the clause itself.
        literals = clause.args if isinstance(clause, sp.Or) else (clause,)
        for literal in literals:
            # Matches ``symbol`` as well as its complement.
            if literal == symbol or sp.Not(literal) == symbol:
                instances.add(literal)

    # Same sign everywhere <=> a single distinct literal was seen.
    return (len(instances) == 1, symbol)
Example #2
0
    def claripy_ast_to_sympy_expr(ast, memo=None):
        """Convert a claripy boolean AST into a SymPy boolean expression.

        ``And``/``Or``/``Not`` nodes are translated recursively into the
        corresponding SymPy connectives. An AST whose op is in
        ``_UNIFIABLE_COMPARISONS`` is rewritten as ``Not`` of the inverse
        comparison of its two arguments, so equivalent comparisons can share
        one leaf. Any other AST becomes an opaque ``sympy.Symbol`` named
        ``str(hash(ast))``.

        :param ast: claripy AST to convert.
        :param memo: optional dict. When given, each leaf symbol created here
                     is recorded as ``memo[symbol] = ast``.
        :return: SymPy boolean expression.
        """
        if ast.op == "And":
            return sympy.And(
                *(ConditionProcessor.claripy_ast_to_sympy_expr(arg, memo=memo)
                  for arg in ast.args))
        if ast.op == "Or":
            return sympy.Or(
                *(ConditionProcessor.claripy_ast_to_sympy_expr(arg, memo=memo)
                  for arg in ast.args))
        if ast.op == "Not":
            return sympy.Not(
                ConditionProcessor.claripy_ast_to_sympy_expr(ast.args[0],
                                                             memo=memo))

        if ast.op in _UNIFIABLE_COMPARISONS:
            # unify comparisons to enable more simplification opportunities without going "deep" in sympy
            # a OP b  ->  Not(a INVERSE_OP b). NOTE(review): presumably the
            # inverse operator is not itself in _UNIFIABLE_COMPARISONS,
            # otherwise this would recurse forever — confirm.
            inverse_op = getattr(ast.args[0],
                                 claripy.operations.inverse_operations[ast.op])
            return sympy.Not(
                ConditionProcessor.claripy_ast_to_sympy_expr(inverse_op(
                    ast.args[1]),
                                                             memo=memo))

        # NOTE(review): entries are stored below as memo[symbol] = ast, but
        # this lookup is keyed by the claripy AST, so it only hits if the memo
        # also holds AST keys from elsewhere — confirm the intended direction.
        if memo is not None and ast in memo:
            return memo[ast]
        # Leaf: an opaque symbol named after the AST's hash.
        symbol = sympy.Symbol(str(hash(ast)))
        if memo is not None:
            # Reverse mapping so callers can translate the symbol back.
            memo[symbol] = ast
        return symbol
Example #3
0
 def sanitized_nand(*args):
     """
     Build the NAND of ``args`` as an Or of Nots (De Morgan form).

     SymPy's own Nand node causes problems with other replacements, so the
     equivalent ``Or(Not(a1), Not(a2), ...)`` is returned instead.
     :param args: boolean expressions to combine
     :return: sympy.Or of the negated arguments
     """
     negated = [sympy.Not(arg) for arg in args]
     return sympy.Or(*negated)
Example #4
0
def negate_expr(node):
    """Negate an expression by wrapping it in a logical ``not``.

    - SymPy objects are wrapped in ``sympy.Not``.
    - Numeric constants become the string of their negated truth value.
    - Strings become ``"not (<string>)"``.
    - Anything else is treated as Python AST. A ``CodeBlock`` is unwrapped to
      its ``code``, a sequence must contain a single expression, and the
      result is a new ``ast.Expr`` holding ``ast.UnaryOp(Not, expr)``.

    :raises ValueError: if a sequence with more than one expression is given.
    """
    if isinstance(node, sympy.Basic):
        return sympy.Not(node)
    if isinstance(node, numbers.Number):
        return str(not node)
    if isinstance(node, str):
        # Most likely a dace.Data.Scalar name.
        return "not ({})".format(node)

    from dace.properties import CodeBlock  # Avoid import loop

    code = node.code if isinstance(node, CodeBlock) else node
    if not hasattr(code, "__len__"):
        expr = code
    elif len(code) > 1:
        raise ValueError("negate_expr only expects "
                         "single expressions, got: {}".format(code))
    else:
        expr = code[0]
    if isinstance(expr, ast.Expr):
        # Negate the wrapped value, not the statement node.
        expr = expr.value

    negated = ast.Expr(value=ast.UnaryOp(op=ast.Not(), operand=expr))
    return ast.fix_missing_locations(ast.copy_location(negated, expr))
Example #5
0
    def as_cpp(self, defined_vars, symbols) -> str:
        """Generate C++ code for this block by concatenating its elements.

        After each ``SingleState`` element, its outgoing interstate edges
        (except those in ``self.edges_to_ignore``) are emitted as transitions.
        A ``goto __state_exit_<sdfg_id>;`` line follows unless the state is
        marked ``last_state``, has exactly two outgoing edges whose conditions
        negate each other, or has a single unconditional outgoing edge.
        """
        code = ''
        for idx, elem in enumerate(self.elements):
            code += elem.as_cpp(defined_vars, symbols)
            if not isinstance(elem, SingleState):
                continue

            # In a general block, transitions and assignments come right
            # after each individual state.
            sdfg = elem.state.parent
            out_edges = sdfg.out_edges(elem.state)
            for edge_idx, edge in enumerate(out_edges):
                if edge in self.edges_to_ignore:
                    continue
                # The last generated edge may fall through to the state that
                # is generated next, in which case no goto is needed.
                is_last_edge = edge_idx == (len(out_edges) - 1)
                if is_last_edge and (idx + 1) < len(self.elements):
                    successor = self.elements[idx + 1].first_state
                else:
                    successor = None
                code += elem.generate_transition(sdfg, edge, successor)

            # The exit goto is redundant for a last state, for a pair of
            # mutually negated conditions, or for one unconditional edge.
            exit_implied = (
                elem.last_state
                or (len(out_edges) == 2
                    and out_edges[0].data.condition_sympy() == sp.Not(
                        out_edges[1].data.condition_sympy()))
                or (len(out_edges) == 1
                    and out_edges[0].data.is_unconditional()))
            if not exit_implied:
                code += f'goto __state_exit_{sdfg.sdfg_id};\n'

        return code
Example #6
0
def backward_chaining(kb, q):
    """Build Horn-clause tables for ``kb`` and run backward chaining on ``q``.

    Clauses that are a single literal (a symbol or a negated symbol) are
    collected as known facts. Every other clause is read literal by literal.
    Its negated literals ``~P`` give the premises ``P``, and its positive
    literal (if any) gives the conclusion. The tables are then passed to
    ``backward_chaining_helper``.

    :param kb: conjunction of clauses; ``kb.args`` are the clauses.
    :param q: query literal.
    :return: the result of ``backward_chaining_helper``.
    """
    clauses = list(kb.args)
    literal_types = (sp.Symbol, sp.Not)

    ### Construct list of symbols known to be true
    known_true_symbols = [c for c in clauses if isinstance(c, literal_types)]

    ### check if query is a negation
    # A negated query whose atom is not itself a known fact is replaced by the
    # bare atom, and ``negation`` is passed to the helper to record this.
    negation = False
    if isinstance(q, sp.Not) and q.args[0] not in known_true_symbols:
        negation = True
        q = q.args[0]

    ### Construct tables of premises and conclusions keyed by clauses
    premises = {}
    conclusions = {}
    for c in clauses:
        if isinstance(c, literal_types):
            # Facts have neither premises nor a conclusion.
            premises[c] = None
            conclusions[c] = None
            continue
        premise_list = []
        # Reset for every clause. A clause with no positive literal must not
        # inherit the previous clause's conclusion (or hit an unbound name).
        conclusion = None
        for s in c.args:
            if isinstance(s, sp.Not):
                # ~P in the clause makes P a premise; Not(~P) gives back P.
                premise_list.append(sp.Not(s))
            else:
                conclusion = s
        premises[c] = tuple(premise_list)
        conclusions[c] = conclusion

    return backward_chaining_helper(kb, clauses, known_true_symbols, premises, conclusions, q, negation)
Example #7
0
def to_automaton(f) -> SymbolicDFA:  # noqa: C901
    """Translate to automaton.

    Puts ``f`` in negation normal form. It then explores every state reachable
    from the initial state ``{{PLAtomic(f)}}`` by applying ``_make_transition``
    for each letter of ``powerset(f.find_labels())``, and builds a
    ``SymbolicAutomaton`` from the result. Each transition is labelled with the
    conjunction of the labels in the letter and the negations of all remaining
    labels. The automaton is determinized and minimized before returning.

    :param f: formula to translate; must provide ``to_nnf()``.
    :return: the minimized automaton.
    """
    f = f.to_nnf()
    initial_state = frozenset({frozenset({PLAtomic(f)})})
    states = {initial_state}
    final_states = set()
    # source state -> {letter (actions_set) -> destination state}
    transition_function = {}  # type: Dict

    all_labels = f.find_labels()
    # NOTE(review): iterated once per explored state below, so it must be
    # re-iterable (not a one-shot generator) — confirm powerset's return type.
    alphabet = powerset(all_labels)

    # The initial state is accepting when the epsilon-delta of the formula
    # evaluates to PLTrue.
    if f.delta({}, epsilon=True) == PLTrue():
        final_states.add(initial_state)

    # States that have already been checked for acceptance via _is_true.
    visited = set()  # type: Set
    to_be_visited = {initial_state}

    # Worklist exploration: process a snapshot of the frontier. States
    # discovered meanwhile are added to it and handled in the next round.
    while len(to_be_visited) != 0:

        for q in list(to_be_visited):
            to_be_visited.remove(q)
            for actions_set in alphabet:
                # Letter = set of labels that are true; all others are false.
                new_state = _make_transition(
                    q, {label: True
                        for label in actions_set})
                if new_state not in states:
                    states.add(new_state)
                    to_be_visited.add(new_state)

                transition_function.setdefault(q, {})[actions_set] = new_state

                if new_state not in visited:
                    visited.add(new_state)
                    if _is_true(new_state):
                        final_states.add(new_state)

    # Materialize the explicit state set into a SymbolicAutomaton.
    automaton = SymbolicAutomaton()
    state2idx = {}
    for state in states:
        state_idx = automaton.create_state()
        state2idx[state] = state_idx
        if state == initial_state:
            automaton.set_initial_state(state_idx)
        if state in final_states:
            automaton.set_accepting_state(state_idx, True)

    for source in transition_function:
        for symbol, destination in transition_function[source].items():
            source_idx = state2idx[source]
            dest_idx = state2idx[destination]
            # Guard: labels in the letter are true, every other label false.
            pos_expr = sympy.And(*map(sympy.Symbol, symbol))
            neg_expr = sympy.And(*map(lambda x: sympy.Not(sympy.Symbol(x)),
                                      all_labels.difference(symbol)))
            automaton.add_transition(
                (source_idx, sympy.And(pos_expr, neg_expr), dest_idx))

    determinized = automaton.determinize()
    minimized = determinized.minimize()
    return minimized
Example #8
0
 def check(cond__):
     """Decide ``cond__`` under the enclosing ``extra_condition``.

     Returns True if ``Not(cond__) & extra_condition`` simplifies to False,
     meaning the condition must hold. Returns False if
     ``cond__ & extra_condition`` simplifies to False, meaning it cannot hold.
     Returns None when neither can be decided.
     """
     def is_contradiction(expr):
         return simplify_and(expr) == sympy.sympify(False)

     if is_contradiction(sympy.And(sympy.Not(cond__), extra_condition)):
         return True
     if is_contradiction(sympy.And(cond__, extra_condition)):
         return False
     return None
Example #9
0
 def on_lblfact(self, match, sub=None, neg=None, atom=None):
     """Build the SymPy expression for a parsed label factor.

     A negated factor yields ``Not(neg)``. An integer atom yields a symbol
     named ``CONST:<atom>``. A string atom yields a symbol of that name.
     Otherwise the sub-expression ``sub`` is returned unchanged.
     """
     if neg is not None:
         return sympy.Not(neg)
     if isinstance(atom, (int, sympy.Integer)):
         return sympy.Symbol("CONST:%s" % atom)
     if isinstance(atom, str):
         return sympy.Symbol(atom)
     return sub
Example #10
0
def test_logic():
    """Check that logic objects round-trip with their SymPy counterparts.

    For And, Or, Not, Xor, Piecewise and Contains this verifies that:
    - mixing native and SymPy arguments builds equal objects;
    - ``_sympy_()`` yields the SymPy equivalent;
    - ``sympify`` of the SymPy object maps back.
    """
    x = true
    y = false
    x1 = sympy.true
    y1 = sympy.false

    assert And(x, y) == And(x1, y1)
    assert And(x1, y) == And(x1, y1)
    assert And(x, y)._sympy_() == sympy.And(x1, y1)
    assert sympify(sympy.And(x1, y1)) == And(x, y)

    assert Or(x, y) == Or(x1, y1)
    assert Or(x1, y) == Or(x1, y1)
    assert Or(x, y)._sympy_() == sympy.Or(x1, y1)
    assert sympify(sympy.Or(x1, y1)) == Or(x, y)

    assert Not(x) == Not(x1)
    # NOTE(review): both sides are identical, so this assertion can never
    # fail; it was probably meant to compare against Not(x) — confirm.
    assert Not(x1) == Not(x1)
    assert Not(x)._sympy_() == sympy.Not(x1)
    assert sympify(sympy.Not(x1)) == Not(x)

    assert Xor(x, y) == Xor(x1, y1)
    assert Xor(x1, y) == Xor(x1, y1)
    assert Xor(x, y)._sympy_() == sympy.Xor(x1, y1)
    assert sympify(sympy.Xor(x1, y1)) == Xor(x, y)

    # Same round-trip checks with a symbolic variable.
    x = Symbol("x")
    x1 = sympy.Symbol("x")

    assert Piecewise((x, x < 1), (0, True)) == Piecewise((x1, x1 < 1),
                                                         (0, True))
    assert Piecewise((x, x1 < 1), (0, True)) == Piecewise((x1, x1 < 1),
                                                          (0, True))
    assert Piecewise((x, x < 1), (0, True))._sympy_() == sympy.Piecewise(
        (x1, x1 < 1), (0, True))
    assert sympify(sympy.Piecewise((x1, x1 < 1), (0, True))) == Piecewise(
        (x, x < 1), (0, True))

    assert Contains(x, Interval(1, 1)) == Contains(x1, Interval(1, 1))
    assert Contains(x, Interval(1, 1))._sympy_() == sympy.Contains(
        x1, Interval(1, 1))
    assert sympify(sympy.Contains(x1,
                                  Interval(1,
                                           1))) == Contains(x, Interval(1, 1))
Example #11
0
 def print_stmt(self, i, name_to_id, statement, indent=1, end_of_line=""):
     """Emit the Promela code for one ``ast_inter`` statement.

     :param i: index used for this process's receive channel (``channel_<i>``)
         and exit label (``LEXIT_<i>``).
     :param name_to_id: maps ``Send.comp`` to the index of the target channel.
     :param statement: the statement node to print.
     :param indent: indentation level passed to ``write_indent``.
     :param end_of_line: text appended to the statement's closing line
         (``";"`` for children of a statement sequence).
     :raises Exception: for an unsupported statement type.
     """
     if isinstance(statement, ast_inter.Statement):
         # Sequence: each child is printed one level deeper and ends in ";".
         for c in statement.children:
             self.print_stmt(i, name_to_id, c, indent + 4, ";")
     elif isinstance(statement, ast_inter.Receive):  # Action
         #TODO update to restrict by sender!
         # One do-loop option per accepted message; each breaks out after
         # running its handler.
         self.write_indent(indent, "do")
         for a in statement.actions:
             self.write_indent(
                 indent, ":: channel_" + str(i) + "?" +
                 self.as_msg(a.str_msg_type) + " ->")
             self.print_stmt(i,
                             name_to_id,
                             a.program,
                             indent + 1,
                             end_of_line=";")
             self.write_indent(indent + 1, "break")
         if statement.motion is not None:
             # Without a message, the receive falls back to its motion.
             self.write_indent(indent, ":: timeout ->")
             self.print_motion(i, statement.get_label(), statement.motion,
                               indent + 1)
         self.write_indent(indent, "od" + end_of_line)
     elif isinstance(statement, ast_inter.If):  # IfComponent
         self.write_indent(indent, "if")
         for c in statement.if_list:
             self.write_indent(
                 indent,
                 ":: " + self.condition_as_string(c.condition) + " ->")
             self.print_stmt(i, name_to_id, c.program, indent + 1)
         self.write_indent(indent, "fi" + end_of_line)
     elif isinstance(statement, ast_inter.While):
         # Loop body under the condition; break out under its negation.
         self.write_indent(indent, "do")
         self.write_indent(
             indent,
             "::" + self.condition_as_string(statement.condition) + " ->")
         self.print_stmt(i, name_to_id, statement.program, indent + 1)
         self.write_indent(
             indent, "::" +
             self.condition_as_string(sp.Not(statement.condition)) + " ->")
         self.write_indent(indent + 1, "break")
         self.write_indent(indent, "od" + end_of_line)
     elif isinstance(statement, ast_inter.Send):
         self.write_indent(
             indent, "channel_" + str(name_to_id[statement.comp]) + "!" +
             self.as_msg(statement.msg_type) + end_of_line)
     elif isinstance(statement, ast_inter.Motion):
         self.print_motion(i, statement.get_label(), statement, indent,
                           end_of_line)
     elif isinstance(statement, (ast_inter.Print, ast_inter.Skip)):
         # Print has no effect on the model; both are emitted as skip.
         self.write_indent(indent, "skip" + end_of_line)
     elif isinstance(statement, ast_inter.Exit):
         self.write_indent(indent, "goto LEXIT_" + str(i) + end_of_line)
     else:
         raise Exception("!??! " + str(statement))
Example #12
0
 def render_UnaryOp(self, node):
     """Render a Python ``ast.UnaryOp`` node.

     ``+x`` renders to the rendered operand, ``-x`` to its arithmetic
     negation and ``not x`` to ``sympy.Not`` of it.

     :raises ValueError: for any other unary operator (e.g. ``~``).
     """
     op_name = type(node.op).__name__
     if op_name not in ('UAdd', 'USub', 'Not'):
         raise ValueError('Unknown unary operator: ' + op_name)
     operand = self.render_node(node.operand)
     if op_name == 'USub':
         return -operand
     if op_name == 'Not':
         return sympy.Not(operand)
     return operand
Example #13
0
def TEST_resolve():
    """Print the resolvent of a few hand-picked clause pairs.

    The output is for manual inspection; nothing is asserted. ``print`` is
    called as a function so the module parses on Python 3.
    """
    A = sp.sympify("A")
    B = sp.sympify("B")
    C = sp.sympify("C")

    ### Test a case that produces the empty clause
    clause_i = A
    clause_j = sp.Not(A)
    print(resolve(clause_i, clause_j))

    ### Test a case that produces a non-empty clause
    clause_i = A
    clause_j = sp.Or(C, sp.Not(A), sp.Not(B))
    print(resolve(clause_i, clause_j))

    ### Test a case that produces a non-empty clause with a repeated literal
    clause_i = A
    clause_j = sp.Or(C, sp.Not(A), sp.Not(B), A)
    print(resolve(clause_i, clause_j))
Example #14
0
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """Check whether ``candidate`` has the structure of a for-loop.

        Requirements checked, in order:
        - the guard has exactly two incoming edges, each assigning exactly
          one variable, and both assign the same variable;
        - the guard has exactly two outgoing edges, neither with assignments,
          whose conditions are SymPy negations of each other;
        - every state reached from the loop-begin state (not traversing into
          the guard) has the guard on its immediate-dominator chain;
        - at least one of those states has an edge back to the guard.

        ``expr_index`` and ``strict`` are not used in this body.

        :return: True if all checks pass, False otherwise.
        """
        guard = graph.node(candidate[DetectLoop._loop_guard])
        begin = graph.node(candidate[DetectLoop._loop_begin])

        # A for-loop guard only has two incoming edges (init and increment)
        guard_inedges = graph.in_edges(guard)
        if len(guard_inedges) != 2:
            return False
        # A for-loop guard only has two outgoing edges (loop and exit-loop)
        guard_outedges = graph.out_edges(guard)
        if len(guard_outedges) != 2:
            return False

        # Both incoming edges to guard must set exactly one variable and
        # the same one
        if (len(guard_inedges[0].data.assignments) != 1
                or len(guard_inedges[1].data.assignments) != 1):
            return False
        itervar = list(guard_inedges[0].data.assignments.keys())[0]
        if itervar not in guard_inedges[1].data.assignments:
            return False

        # Outgoing edges must not have assignments and be a negation of each
        # other
        if any(len(e.data.assignments) > 0 for e in guard_outedges):
            return False
        if guard_outedges[0].data.condition_sympy() != (sp.Not(
                guard_outedges[1].data.condition_sympy())):
            return False

        # All nodes inside loop must be dominated by loop guard
        dominators = nx.dominance.immediate_dominators(sdfg.nx,
                                                       sdfg.start_state)
        # Loop body: states reachable from `begin` without entering the guard.
        loop_nodes = nxutil.dfs_topological_sort(
            sdfg, sources=[begin], condition=lambda _, child: child != guard)
        backedge_found = False
        for node in loop_nodes:
            if any(e.dst == guard for e in graph.out_edges(node)):
                backedge_found = True

            # Traverse the dominator tree upwards, if we reached the guard,
            # the node is in the loop. If we reach the starting state
            # without passing through the guard, fail.
            # (while/else: the else branch runs only when the loop ends
            # without `break`, i.e. the root was reached first.)
            dom = node
            while dom != dominators[dom]:
                if dom == guard:
                    break
                dom = dominators[dom]
            else:
                return False

        if not backedge_found:
            return False

        return True
Example #15
0
    def _parse(self, expr, trivialize=True):
        """Convert a kconfiglib expression into a SymPy boolean expression.

        :param expr: a kconfiglib ``Symbol``/``Choice`` or an operator tuple
                     whose first element is the kconfiglib operator.
        :param trivialize: forwarded to ``_add_symbol_if_nontrivial`` for
                           symbols and comparisons created at this level.
        :return: a SymPy expression. Constant symbols become ``sympy.true`` or
                 ``sympy.false``; Choices and non-bool/tristate symbols yield
                 ``self.expr_ignore()``.
        :raises ValueError: for an unexpected expression class, an unknown
                            operator, or a comparison whose operand is itself
                            an operator tuple.
        """
        def add_sym(expr, trivialize=trivialize):
            # Wrap a kconfiglib object in an ExprSymbol and register it.
            return self._add_symbol_if_nontrivial(ExprSymbol(expr), trivialize)

        if expr.__class__ is not tuple:
            if expr.__class__ is kconfiglib.Symbol:
                if expr.is_constant:
                    return sympy.true if tri_to_bool(expr) else sympy.false
                elif expr.type in [kconfiglib.BOOL, kconfiglib.TRISTATE]:
                    return add_sym(expr)
                else:
                    # Ignore unknown symbol types
                    return self.expr_ignore()
            elif expr.__class__ is kconfiglib.Choice:
                return self.expr_ignore()
            else:
                raise ValueError("Unexpected expression type '{}'".format(
                    expr.__class__.__name__))
        else:
            # If the expression is an operator, resolve the operator.
            # NOTE: AND/OR operands are parsed with the default trivialize=True
            # rather than this call's value; NOT operands with False.
            if expr[0] is kconfiglib.AND:
                return sympy.And(self._parse(expr[1]), self._parse(expr[2]))
            elif expr[0] is kconfiglib.OR:
                return sympy.Or(self._parse(expr[1]), self._parse(expr[2]))
            elif expr[0] is kconfiglib.NOT:
                return sympy.Not(self._parse(expr[1], trivialize=False))
            elif expr[0] is kconfiglib.EQUAL and expr[2].is_constant:
                # `sym = y/m` -> sym ; `sym = n` -> Not(sym)
                if tri_to_bool(expr[2]):
                    return add_sym(expr[1], trivialize=False)
                else:
                    # NOTE(review): unlike the branch above, this ExprSymbol is
                    # not passed through add_sym, so it is never registered
                    # via _add_symbol_if_nontrivial — confirm this is intended.
                    return sympy.Not(ExprSymbol(expr[1]))
            # NOTE: EQUAL with a non-constant right side falls through to the
            # "Unknown expression type" error below.
            elif expr[0] in [
                    kconfiglib.UNEQUAL, kconfiglib.LESS, kconfiglib.LESS_EQUAL,
                    kconfiglib.GREATER, kconfiglib.GREATER_EQUAL
            ]:
                if expr[1].__class__ is tuple or expr[2].__class__ is tuple:
                    raise ValueError("Cannot compare expressions")
                # Comparisons become a single opaque ExprCompare atom.
                return self._add_symbol_if_nontrivial(
                    ExprCompare(expr[0], expr[1], expr[2]), trivialize)
            else:
                raise ValueError("Unknown expression type: '{}'".format(
                    expr[0]))
Example #16
0
def resolve(i, j):
    """Resolve two clauses and return the resolvent.

    Each clause may be a bare symbol, a negated symbol, or an ``Or`` of
    literals. The literals of both clauses are pooled (duplicates merged),
    and every literal whose complement is also in the pool is dropped.

    :param i: first clause.
    :param j: second clause.
    :return: the ``Or`` of the surviving literals (a single literal if only
        one survives), or ``None`` for the empty clause.
    """
    def clause_literals(clause):
        # A bare symbol is a one-literal clause.
        if isinstance(clause, sp.Symbol):
            return [clause]
        # A negated symbol has a single argument; rebuild the negated literal.
        if len(clause.args) == 1:
            return [sp.Not(clause.args[0])]
        return list(clause.args)

    ### First makes a set of all unique literals
    pool = set(clause_literals(i)) | set(clause_literals(j))

    ### Then prunes complementary literals
    # Build a new list instead of removing from the list being iterated, so
    # every literal is examined and no complementary pair survives.
    new_literals = [lit for lit in pool if ~lit not in pool]

    ### Construct new clause from literals or the empty clause
    if not new_literals:
        return None
    return sp.Or(*new_literals)
Example #17
0
def dpll_satisfiable(kb, q):
    """Run DPLL on the clauses of ``kb & ~q``.

    The clauses of the sentence and every symbol occurring in them are
    passed to ``dpll`` with an empty model.

    :param kb: knowledge base in CNF.
    :param q: query sentence.
    :return: the result of ``dpll(clauses, symbols, {})``.
    """
    sentence = kb & ~q
    clauses = list(sentence.args)
    symbols = set()
    for c in clauses:
        # A bare symbol is a unit clause with empty ``args``, so it must be
        # taken as its own literal; otherwise its symbol would never be
        # collected. Not/Or clauses expose their contents through ``args``.
        literals = (c,) if isinstance(c, sp.Symbol) else c.args
        for a in literals:
            # Record the underlying symbol of each literal: Not(~A) is A.
            symbols.add(a if isinstance(a, sp.Symbol) else sp.Not(a))
    model = {}
    return dpll(clauses, list(symbols), model)
Example #18
0
    def visit_node(node, mask):
        """Recursively fold ``mask`` into vector memory accesses below ``node``.

        For a ``Conditional``, the true block is visited with
        ``And(cond, mask)`` and the false block with
        ``And(Not(cond), mask)``. A condition that does not contain the loop
        counter, or that already is a ``vec_all``/``vec_any`` call, counts as
        ``True`` for the true-block mask and is left as is. Any other
        condition is wrapped in ``vec_any`` after its blocks are visited.

        In a ``SympyAssignment``, every ``vector_memory_access`` has its fifth
        argument (``args[4]``) and-ed with ``mask``. Other nodes are visited
        through their ``args``.
        """
        if isinstance(node, ast.Conditional):
            cond = node.condition_expr
            skip = (loop_body.loop_counter_symbol not in cond.atoms(sp.Symbol)) or cond.func in (vec_all, vec_any)
            cond = True if skip else cond

            true_mask = sp.And(cond, mask)
            visit_node(node.true_block, true_mask)
            if node.false_block:
                # NOTE(review): unlike the true branch, this mask ignores
                # `skip` — confirm whether it should then be plain `mask`.
                false_mask = sp.And(sp.Not(node.condition_expr), mask)
                # Recurse into the false block. Recursing on `node` itself
                # would re-enter this Conditional forever.
                visit_node(node.false_block, false_mask)
            if not skip:
                node.condition_expr = vec_any(node.condition_expr)
        elif isinstance(node, ast.SympyAssignment):
            if mask is not True:
                s = {ma: vector_memory_access(*ma.args[0:4], sp.And(mask, ma.args[4]), *ma.args[5:])
                     for ma in node.atoms(vector_memory_access)}
                # NOTE(review): the return value is discarded, so this relies
                # on SympyAssignment.subs updating the node in place — verify.
                node.subs(s)
        else:
            for arg in node.args:
                visit_node(arg, mask)
Example #19
0
def to_sympy(
        formula: Formula,
        replace: Optional[Dict[AtomSymbol, sympy.Symbol]] = None) -> Boolean:
    """
    Translate a PLFormula object into a SymPy expression.

    Or/And/Implies/Equivalent results are passed through ``sympy.simplify``.

    :param formula: the formula to translate.
    :param replace: an optional mapping from atom symbols to replacement
        symbols. A value may be a ``sympy.Symbol`` (returned as is) or a name
        that is turned into one.
    :return: the SymPy formula object equivalent to the formula.
    :raises ValueError: if ``formula`` is not a supported PL formula type.
    """
    if replace is None:
        replace = {}

    if isinstance(formula, PLTrue):
        return BooleanTrue()
    elif isinstance(formula, PLFalse):
        return BooleanFalse()
    elif isinstance(formula, PLAtomic):
        symbol = replace.get(formula.s, formula.s)
        # `replace` is typed to hold sympy.Symbol values. sympy.Symbol()
        # needs a string name, so an existing Symbol is returned directly.
        if isinstance(symbol, sympy.Symbol):
            return symbol
        return sympy.Symbol(symbol)
    elif isinstance(formula, PLNot):
        return sympy.Not(to_sympy(formula.f, replace=replace))
    elif isinstance(formula, PLOr):
        return sympy.simplify(
            sympy.Or(*[to_sympy(f, replace=replace)
                       for f in formula.formulas]))
    elif isinstance(formula, PLAnd):
        return sympy.simplify(
            sympy.And(
                *[to_sympy(f, replace=replace) for f in formula.formulas]))
    elif isinstance(formula, PLImplies):
        return sympy.simplify(
            sympy.Implies(
                *[to_sympy(f, replace=replace) for f in formula.formulas]))
    elif isinstance(formula, PLEquivalence):
        return sympy.simplify(
            sympy.Equivalent(
                *[to_sympy(f, replace=replace) for f in formula.formulas]))
    else:
        raise ValueError("Formula is not valid.")
Example #20
0
 def test_negation(self):
     """``Not(a)`` is rendered as the C expression ``!a``."""
     negated = sympy.Not(self.a)
     self.assertEqual(Expression(negated).rhs_cstr, '!a')
Example #21
0
 def test_triple_negation(self):
     """Three stacked ``!`` operators parse to a single negation."""
     parsed = Expression('!!!a')
     expected = sympy.Not(self.a)
     self.assertEqual(parsed.rhs, expected)
Example #22
0
def _structured_control_flow_traversal(
        sdfg: SDFG,
        start: SDFGState,
        ptree: Dict[SDFGState, SDFGState],
        branch_merges: Dict[SDFGState, SDFGState],
        back_edges: List[Edge[InterstateEdge]],
        dispatch_state: Callable[[SDFGState], str],
        parent_block: GeneralBlock,
        stop: SDFGState = None,
        generate_children_of: SDFGState = None) -> Set[SDFGState]:
    """
    Helper function for ``structured_control_flow_tree``.

    Walks the state machine from ``start`` using an explicit stack. It appends
    structured blocks (single states, if/else, switch/case, if/else-if chains
    and loops) to ``parent_block``, recursing into branch and loop bodies.

    :param sdfg: SDFG.
    :param start: Starting state for traversal.
    :param ptree: State parent tree (computed from ``state_parent_tree``).
    :param branch_merges: Dictionary mapping from branch state to its merge
                          state.
    :param back_edges: Interstate edges treated as loop back edges. Incoming
                       guard edges not in this list are treated as loop init
                       edges.
    :param dispatch_state: A function that dispatches code generation for a
                           single state.
    :param parent_block: The block to append children to.
    :param stop: Stopping state to not traverse through (merge state of a
                 branch or guard state of a loop).
    :param generate_children_of: If given, states for which
                                 ``_child_of(state, generate_children_of,
                                 ptree)`` is false are skipped.
    :return: The set of states visited by this traversal and its recursive
             calls, excluding ``stop``.
    """
    # Traverse states in custom order
    visited = set()
    if stop is not None:
        # Pre-mark the stop state so it is never traversed through.
        visited.add(stop)
    stack = [start]
    while stack:
        node = stack.pop()
        if (generate_children_of is not None
                and not _child_of(node, generate_children_of, ptree)):
            continue
        if node in visited:
            continue
        visited.add(node)
        stateblock = SingleState(dispatch_state, node)

        oe = sdfg.out_edges(node)
        if len(oe) == 0:  # End state
            # If there are no remaining nodes, this is the last state and it can
            # be marked as such
            if len(stack) == 0:
                stateblock.last_state = True
            parent_block.elements.append(stateblock)
            continue
        elif len(oe) == 1:  # No traversal change
            stack.append(oe[0].dst)
            parent_block.elements.append(stateblock)
            continue

        # Potential branch or loop
        if node in branch_merges:
            mergestate = branch_merges[node]

            # Add branching node and ignore outgoing edges
            parent_block.elements.append(stateblock)
            parent_block.edges_to_ignore.extend(oe)
            stateblock.last_state = True

            # Parse all outgoing edges recursively first
            cblocks: Dict[Edge[InterstateEdge], GeneralBlock] = {}
            for branch in oe:
                cblocks[branch] = GeneralBlock(dispatch_state, [], [])
                visited |= _structured_control_flow_traversal(
                    sdfg,
                    branch.dst,
                    ptree,
                    branch_merges,
                    back_edges,
                    dispatch_state,
                    cblocks[branch],
                    stop=mergestate,
                    generate_children_of=node)

            # Classify branch type:
            branch_block = None
            # If there are 2 out edges, one negation of the other:
            #   * if/else in case both branches are not merge state
            #   * if without else in case one branch is merge state
            if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(
                    oe[1].data.condition_sympy())):
                # If without else
                if oe[0].dst is mergestate:
                    branch_block = IfScope(dispatch_state, sdfg, node,
                                           oe[1].data.condition,
                                           cblocks[oe[1]])
                elif oe[1].dst is mergestate:
                    branch_block = IfScope(dispatch_state, sdfg, node,
                                           oe[0].data.condition,
                                           cblocks[oe[0]])
                else:
                    branch_block = IfScope(dispatch_state, sdfg, node,
                                           oe[0].data.condition,
                                           cblocks[oe[0]], cblocks[oe[1]])
            else:
                # If there are 2 or more edges (one is not the negation of the
                # other):
                switch = _cases_from_branches(oe, cblocks)
                if switch:
                    # If all edges are of form "x == y" for a single x and
                    # integer y, it is a switch/case
                    branch_block = SwitchCaseScope(dispatch_state, sdfg, node,
                                                   switch[0], switch[1])
                else:
                    # Otherwise, create if/else if/.../else goto exit chain
                    branch_block = IfElseChain(dispatch_state, sdfg, node,
                                               [(e.data.condition, cblocks[e])
                                                for e in oe])
            # End of branch classification
            parent_block.elements.append(branch_block)
            # Continue after the branch unless its merge is where we stop.
            if mergestate != stop:
                stack.append(mergestate)

        elif len(oe) == 2:  # Potential loop
            # TODO(later): Recognize do/while loops
            # If loop, traverse body, then exit
            body_start = None
            loop_exit = None
            scope = None
            # The body edge leads to a state whose parent-tree entry is this
            # node; the exit edge does not.
            if ptree[oe[0].dst] == node and ptree[oe[1].dst] != node:
                scope = _loop_from_structure(sdfg, node, oe[0], oe[1],
                                             back_edges, dispatch_state)
                body_start = oe[0].dst
                loop_exit = oe[1].dst
            elif ptree[oe[1].dst] == node and ptree[oe[0].dst] != node:
                scope = _loop_from_structure(sdfg, node, oe[1], oe[0],
                                             back_edges, dispatch_state)
                body_start = oe[1].dst
                loop_exit = oe[0].dst

            if scope:
                # Traverse the loop body, stopping when it returns to the guard.
                visited |= _structured_control_flow_traversal(
                    sdfg,
                    body_start,
                    ptree,
                    branch_merges,
                    back_edges,
                    dispatch_state,
                    scope.body,
                    stop=node,
                    generate_children_of=node)

                # Add branching node and ignore outgoing edges
                parent_block.elements.append(stateblock)
                parent_block.edges_to_ignore.extend(oe)

                parent_block.elements.append(scope)

                # If for loop, ignore certain edges
                if isinstance(scope, ForScope):
                    # Mark init edge(s) to ignore in parent_block and all children
                    _ignore_recursive([
                        e for e in sdfg.in_edges(node) if e not in back_edges
                    ], parent_block)
                    # Mark back edge for ignoring in all children of loop body
                    _ignore_recursive(
                        [e for e in sdfg.in_edges(node) if e in back_edges],
                        scope.body)

                stack.append(loop_exit)
                continue

            # No proper loop detected: Unstructured control flow
            parent_block.elements.append(stateblock)
            stack.extend([e.dst for e in oe])
        else:  # No merge state: Unstructured control flow
            parent_block.elements.append(stateblock)
            stack.extend([e.dst for e in oe])

    return visited - {stop}
Example #23
0
def _loop_from_structure(
        sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge],
        leave_edge: Edge[InterstateEdge],
        back_edges: List[Edge[InterstateEdge]],
        dispatch_state: Callable[[SDFGState],
                                 str]) -> Union[ForScope, WhileScope, None]:
    """
    Helper method that constructs the correct structured loop construct from a
    set of states. Can construct for or while loops.

    :param sdfg: The SDFG containing the loop.
    :param guard: The loop guard state.
    :param enter_edge: Outgoing guard edge that leads into the loop body.
    :param leave_edge: Outgoing guard edge that leaves the loop.
    :param back_edges: All back edges of the SDFG.
    :param dispatch_state: Callback mapping a state to a string; passed on to
                           the constructed scopes.
    :return: A ``ForScope`` if all incoming guard edges set the same single
             iteration variable, a ``WhileScope`` for other well-formed
             loops, or ``None`` if the structure cannot be expressed as a
             structured loop.
    """

    body = GeneralBlock(dispatch_state, [], [])

    guard_inedges = sdfg.in_edges(guard)
    increment_edges = [e for e in guard_inedges if e in back_edges]
    init_edges = [e for e in guard_inedges if e not in back_edges]

    # Exactly one back edge is required: none means this is not a loop, more
    # than one indicates a "continue" statement. Disregard both cases.
    if len(increment_edges) != 1:
        return None
    increment_edge = increment_edges[0]

    # Mark increment edge to be ignored in body
    body.edges_to_ignore.append(increment_edge)

    # Outgoing edges must be a negation of each other
    if enter_edge.data.condition_sympy() != (sp.Not(
            leave_edge.data.condition_sympy())):
        return None

    # Body of guard state must be empty
    if not guard.is_empty():
        return None

    if not increment_edge.data.is_unconditional():
        return None
    if len(enter_edge.data.assignments) > 0:
        return None

    condition = enter_edge.data.condition

    # Detect whether this loop is a for loop:
    # All incoming edges to the guard must set the same variable
    itvars = None
    for iedge in guard_inedges:
        if itvars is None:
            itvars = set(iedge.data.assignments.keys())
        else:
            itvars &= iedge.data.assignments.keys()
    # A for loop also needs at least one (non-back) init edge to take the
    # initial value from; without one, fall through to a while loop instead
    # of failing with an IndexError.
    if itvars and len(itvars) == 1 and init_edges:
        itvar = next(iter(itvars))
        init = init_edges[0].data.assignments[itvar]

        # Check that all init edges are the same and that increment edge only
        # increments
        if (all(e.data.assignments[itvar] == init for e in init_edges)
                and len(increment_edge.data.assignments) == 1):
            update = increment_edge.data.assignments[itvar]
            return ForScope(dispatch_state, itvar, guard, init, condition,
                            update, body)

    # Otherwise, it is a while loop
    return WhileScope(dispatch_state, guard, condition, body)
Example #24
0
def create_staggered_kernel(staggered_field,
                            expressions,
                            subexpressions=(),
                            target='cpu',
                            gpu_exclusive_conditions=False,
                            **kwargs):
    """Kernel that updates a staggered field.

    .. image:: /img/staggered_grid.svg

    Args:
        staggered_field: field where the first index coordinate defines the location of the staggered value
                can have 1 or 2 index coordinates, in case of two index coordinates at every staggered location
                a vector is stored, expressions parameter has to be a sequence of sequences then
                where e.g. ``f[0,0](0)`` is interpreted as value at the left cell boundary, ``f[1,0](0)`` the right cell
                boundary and ``f[0,0](1)`` the southern cell boundary etc.
        expressions: sequence of expressions of length dim, defining how the west, southern, (bottom) cell boundary
                     should be updated.
        subexpressions: optional sequence of Assignments, that define subexpressions used in the main expressions
        target: 'cpu' or 'gpu'
        gpu_exclusive_conditions: if/else construct to have only one code block for each of 2**dim code paths;
                                  only takes effect when target is not 'cpu'
        kwargs: passed directly to create_kernel, iteration slice and ghost_layers parameters are not allowed.
                The options ``cpu_blocking``, ``cpu_vectorize_info`` and ``cpu_openmp`` are always removed from
                kwargs; for ``target='cpu'`` this function applies them itself to the generated AST.

    Returns:
        AST, see `create_kernel`
    """
    assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs
    assert staggered_field.index_dimensions in (
        1, 2), 'Staggered field must have one or two index dimensions'
    dim = staggered_field.spatial_dimensions

    counters = [
        LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)
    ]
    # conditions[i]: the loop counter i is not on the last cell in dimension i
    conditions = [
        counters[i] < staggered_field.shape[i] - 1 for i in range(dim)
    ]
    assert len(expressions) == dim
    if staggered_field.index_dimensions == 2:
        assert all(len(sublist) == len(expressions[0]) for sublist in expressions), \
            "If staggered field has two index dimensions expressions has to be a sequence of sequences of all the " \
            "same length."

    final_assignments = []
    last_conditional = None

    def add(condition, dimensions, as_else_block=False):
        # Append a Conditional that updates the staggered values of
        # `dimensions`; with as_else_block it is chained into the false
        # branch of the previously added Conditional.
        nonlocal last_conditional
        if staggered_field.index_dimensions == 1:
            assignments = [
                Assignment(staggered_field(d), expressions[d])
                for d in dimensions
            ]
            a_coll = AssignmentCollection(assignments, list(subexpressions))
            a_coll = a_coll.new_filtered(
                [staggered_field(d) for d in dimensions])
        elif staggered_field.index_dimensions == 2:
            assert staggered_field.has_fixed_index_shape
            assignments = [
                Assignment(staggered_field(d, i), expr) for d in dimensions
                for i, expr in enumerate(expressions[d])
            ]
            a_coll = AssignmentCollection(assignments, list(subexpressions))
            a_coll = a_coll.new_filtered([
                staggered_field(d, i)
                for i in range(staggered_field.index_shape[1])
                for d in dimensions
            ])
        sp_assignments = [
            SympyAssignment(a.lhs, a.rhs) for a in a_coll.all_assignments
        ]
        if as_else_block and last_conditional:
            new_cond = Conditional(condition, Block(sp_assignments))
            last_conditional.false_block = Block([new_cond])
            last_conditional = new_cond
        else:
            last_conditional = Conditional(condition, Block(sp_assignments))
            final_assignments.append(last_conditional)

    if target == 'cpu' or not gpu_exclusive_conditions:
        for d in range(dim):
            cond = sp.And(*[conditions[i] for i in range(dim) if d != i])
            add(cond, [d])
    elif target == 'gpu':
        full_conditions = [
            sp.And(*[conditions[i] for i in range(dim) if d != i])
            for d in range(dim)
        ]
        # One if/else-if chain entry per combination of included dimensions
        for include in itertools.product(*[[1, 0]] * dim):
            case_conditions = sp.And(*[
                c if value else sp.Not(c)
                for c, value in zip(full_conditions, include)
            ])
            dimensions_to_include = [i for i in range(dim) if include[i]]
            if dimensions_to_include:
                add(case_conditions, dimensions_to_include, True)

    ghost_layers = [(1, 0)] * dim

    # These CPU options are applied below (target 'cpu' only), after the
    # staggered conditionals have been removed, so they must not be
    # forwarded to create_kernel -- whatever their value.
    blocking = kwargs.pop('cpu_blocking', None)
    cpu_vectorize_info = kwargs.pop('cpu_vectorize_info', None)
    openmp = kwargs.pop('cpu_openmp', None)

    ast = create_kernel(final_assignments,
                        ghost_layers=ghost_layers,
                        target=target,
                        **kwargs)

    if target == 'cpu':
        remove_conditionals_in_staggered_kernel(ast)
        move_constants_before_loop(ast)
        omp_collapse = None
        if blocking:
            omp_collapse = loop_blocking(ast, blocking)
        if openmp:
            from pystencils.cpu import add_openmp
            add_openmp(ast,
                       num_threads=openmp,
                       collapse=omp_collapse,
                       assume_single_outer_loop=False)
        if cpu_vectorize_info is True:
            vectorize(ast)
        elif isinstance(cpu_vectorize_info, dict):
            vectorize(ast, **cpu_vectorize_info)
    return ast
Example #25
0
 def gen_not(self, ident, val):
     """Return the sympy negation of the expression carried by ``val``.

     ``ident`` is not used by this method.
     """
     negate = sym.Not
     return negate(val.exp)
Example #26
0
def simplify_and(
        x: sympy.Basic,
        gen: typing.Optional[sympy.Symbol] = None,
        extra_conditions: typing.Optional[sympy.Basic] = True) -> sympy.Basic:
    """
    Simplify the ``sympy.And`` sub-expressions of ``x`` with some rules,
    because SymPy currently does not automatically simplify them...

    The variable solved for (``_c``) is ``gen`` if given, otherwise the first
    symbol found in an ``Eq`` sub-expression, otherwise a free symbol of
    ``x``. The next symbol found (if any) is used as ``n`` by the hard-coded
    rules at the end; if there is none, ``n`` is a ``Wild``.

    :param x: expression to simplify.
    :param gen: optional symbol to solve the inequalities for.
    :param extra_conditions: conjoined with bound comparisons when deciding
        whether one bound makes another redundant.
    :return: the simplified expression; ``x`` itself if it has no symbols.
    """
    assert isinstance(x, sympy.Basic), "type x: %r" % type(x)
    from sympy.solvers.inequalities import reduce_rational_inequalities
    from sympy.core.relational import Relational

    syms = []
    if gen is not None:
        syms.append(gen)

    w1 = sympy.Wild("w1")
    w2 = sympy.Wild("w2")
    for sub_expr in x.find(sympy.Eq(w1, w2)):
        m = sub_expr.match(sympy.Eq(w1, w2))
        ws_ = m[w1], m[w2]
        for w_ in ws_:
            if isinstance(w_, sympy.Symbol) and w_ not in syms:
                syms.append(w_)
    for w_ in x.free_symbols:
        if w_ not in syms:
            syms.append(w_)

    if len(syms) >= 1:
        _c = syms[0]
        if len(syms) >= 2:
            n = syms[1]
        else:
            n = sympy.Wild("n")
    else:
        return x

    # Whether _c ranges over the integers. A symbol without an explicit
    # "integer" assumption is treated as non-integer (indexing
    # assumptions0["integer"] directly raises KeyError for such symbols).
    c_is_integer = bool(_c.is_Integer
                        or _c.assumptions0.get("integer", False))

    x = x.replace(((_c - 2 * n >= -1) & (_c - 2 * n <= -1)),
                  sympy.Eq(_c, 2 * n - 1))  # probably not needed anymore...
    apply_rules = True
    while apply_rules:
        apply_rules = False
        for and_expr in x.find(sympy.And):
            assert isinstance(and_expr, sympy.And)

            and_expr_ = reduce_rational_inequalities([and_expr.args], _c)
            if and_expr_ != and_expr:
                x = x.replace(and_expr, and_expr_)
                and_expr = and_expr_
                if and_expr == sympy.sympify(False):
                    continue
                if isinstance(and_expr, sympy.Rel):
                    continue
                assert isinstance(and_expr, sympy.And)

            and_expr_args = list(and_expr.args)
            if all([
                    isinstance(part, Relational) and _c in part.free_symbols
                    for part in and_expr_args
            ]):
                # No equality, as that should have been resolved above.
                rel_ops = ["==", ">=", "<="]
                if not c_is_integer:
                    rel_ops.extend(["<", ">"])
                # Right-hand sides of all "_c <op> rhs" parts, grouped by op.
                rhs_by_c = {op: [] for op in rel_ops}
                for part in and_expr_args:
                    assert isinstance(part, Relational)
                    part = _solve_inequality(part, _c)
                    assert isinstance(part, Relational)
                    assert part.lhs == _c
                    rel_op, rhs = part.rel_op, part.rhs
                    if c_is_integer:
                        # For integers, strict bounds become inclusive ones.
                        if rel_op == "<":
                            rhs = rhs - 1
                            rel_op = "<="
                        elif rel_op == ">":
                            rhs = rhs + 1
                            rel_op = ">="
                    assert rel_op in rhs_by_c, "x: %r, _c: %r, and expr: %r, part %r" % (
                        x, _c, and_expr, part)
                    other_rhs = rhs_by_c[rel_op]
                    assert isinstance(other_rhs, list)
                    need_to_add = True
                    for rhs_ in other_rhs:
                        cmp = Relational.ValidRelationOperator[rel_op](rhs,
                                                                       rhs_)
                        if simplify_and(
                                sympy.And(sympy.Not(cmp),
                                          extra_conditions)) == sympy.sympify(
                                              False):  # checks True...
                            # cmp always holds: the new bound supersedes rhs_.
                            other_rhs.remove(rhs_)
                            break
                        elif simplify_and(sympy.And(
                                cmp,
                                extra_conditions)) == sympy.sympify(False):
                            # cmp never holds: keep rhs_, skip the new bound.
                            need_to_add = False
                            break
                    if need_to_add:
                        other_rhs.append(rhs)
                if rhs_by_c[">="] and rhs_by_c["<="]:
                    all_false = False
                    for lhs in rhs_by_c[">="]:
                        for rhs in rhs_by_c["<="]:
                            # lhs <= _c <= rhs has no solution only if
                            # lhs > rhs. Equal bounds are feasible; they pin
                            # _c and are recorded as an equality below.
                            if sympy.Le(lhs, rhs) == sympy.sympify(False):
                                all_false = True
                            if sympy.Eq(lhs, rhs) == sympy.sympify(True):
                                rhs_by_c["=="].append(lhs)
                    if all_false:
                        x = x.replace(and_expr, False)
                        continue
                if rhs_by_c["=="]:
                    all_false = False
                    while len(rhs_by_c["=="]) >= 2:
                        lhs, rhs = rhs_by_c["=="][:2]
                        if sympy.Eq(lhs, rhs) == sympy.sympify(False):
                            all_false = True
                            break
                        elif sympy.Eq(lhs, rhs) == sympy.sympify(True):
                            rhs_by_c["=="].pop(1)
                        else:
                            raise NotImplementedError(
                                "cannot cmp %r == %r. rhs_by_c %r" %
                                (lhs, rhs, rhs_by_c))
                    if all_false:
                        x = x.replace(and_expr, False)
                        continue
                    # _c is fixed: rewrite every other bound as a condition
                    # on that value.
                    new_parts = [sympy.Eq(_c, rhs_by_c["=="][0])]
                    for op in rel_ops:
                        for part in rhs_by_c[op]:
                            new_parts.append(
                                Relational.ValidRelationOperator[op](
                                    rhs_by_c["=="][0], part).simplify())
                else:  # no "=="
                    new_parts = []
                    for op in rel_ops:
                        for part in rhs_by_c[op]:
                            new_parts.append(
                                Relational.ValidRelationOperator[op](_c, part))
                    assert new_parts
                and_expr_ = sympy.And(*new_parts)
                x = x.replace(and_expr, and_expr_)
                and_expr = and_expr_

            # Probably all the remaining hard-coded rules are not needed anymore with the more generic code above...
            if sympy.Eq(_c, 2 * n) in and_expr.args:
                if (_c - 2 * n <= -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
                if sympy.Eq(_c - 2 * n, -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
                if (_c - n <= -1) in and_expr.args:
                    x = x.replace(and_expr, False)
                    continue
            if (_c >= n) in and_expr.args and (_c - n <= -1) in and_expr.args:
                x = x.replace(and_expr, False)
                continue
            if sympy.Eq(_c - 2 * n, -1) in and_expr.args:  # assume n>=1
                if (_c >= n) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(*[
                            arg for arg in and_expr.args
                            if arg != (_c >= n)
                        ]))
                    apply_rules = True
                    break
                if (_c - n >= -1) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(*[
                            arg for arg in and_expr.args
                            if arg != (_c - n >= -1)
                        ]))
                    apply_rules = True
                    break
            if (_c >= n) in and_expr.args:
                if (_c - n >= -1) in and_expr.args:
                    x = x.replace(
                        and_expr,
                        sympy.And(*[
                            arg for arg in and_expr.args
                            if arg != (_c - n >= -1)
                        ]))
                    apply_rules = True
                    break
            if (_c - n >= -1) in and_expr.args and (_c - n <=
                                                    -1) in and_expr.args:
                args = list(and_expr.args)
                args.remove((_c - n >= -1))
                args.remove((_c - n <= -1))
                args.append(sympy.Eq(_c - n, -1))
                if (_c - 2 * n <= -1) in args:
                    args.remove((_c - 2 * n <= -1))
                x = x.replace(and_expr, sympy.And(*args))
                apply_rules = True
                break
    return x
Example #27
0
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        """Return True if ``candidate`` describes a loop this pattern accepts.

        The guard state must have at least two incoming edges and exactly two
        outgoing edges whose conditions are (structurally) negations of each
        other. Every state reached from the loop begin without passing
        through the guard must be dominated by the guard; at least one of
        those states must have an edge back to the guard; and exactly one
        variable must be assigned on all incoming guard edges, including
        that back edge.
        """
        guard = graph.node(candidate[DetectLoop._loop_guard])
        begin = graph.node(candidate[DetectLoop._loop_begin])

        # The guard needs at least two incoming edges (e.g. init and increment)
        guard_inedges = graph.in_edges(guard)
        if len(guard_inedges) < 2:
            return False
        # ... and exactly two outgoing edges (enter loop / exit loop)
        guard_outedges = graph.out_edges(guard)
        if len(guard_outedges) != 2:
            return False

        # Candidate iteration variables: assigned on every incoming guard edge
        itvar = None
        for inedge in guard_inedges:
            assigned = inedge.data.assignments.keys()
            if itvar is None:
                itvar = set(assigned)
            else:
                itvar &= assigned
        if itvar is None:
            return False

        # The two outgoing conditions must negate each other
        first_condition = guard_outedges[0].data.condition_sympy()
        negated_second = sp.Not(guard_outedges[1].data.condition_sympy())
        if first_condition != negated_second:
            return False

        # Every loop state must be dominated by the guard
        dominators = nx.dominance.immediate_dominators(sdfg.nx,
                                                       sdfg.start_state)
        loop_nodes = sdutil.dfs_conditional(
            sdfg, sources=[begin], condition=lambda _, child: child != guard)
        backedge = None
        for state in loop_nodes:
            for outedge in graph.out_edges(state):
                if outedge.dst == guard:
                    backedge = outedge
                    break

            # Climb the immediate-dominator tree until the root (the state
            # that is its own immediate dominator). Meeting the guard on the
            # way means the state is inside the loop; the root itself is
            # never compared with the guard.
            ancestor = state
            dominated_by_guard = False
            while ancestor != dominators[ancestor]:
                if ancestor == guard:
                    dominated_by_guard = True
                    break
                ancestor = dominators[ancestor]
            if not dominated_by_guard:
                return False

        if backedge is None:
            return False

        # The back edge must also assign the iteration variable; exactly one
        # consistent iteration variable may remain.
        itvar &= backedge.data.assignments.keys()
        return len(itvar) == 1
Example #28
0
    def _assign_statements(self):
        """Translate the statements under ``self.root`` into ``ModelStatements``.

        Three statement kinds are handled (variable names are upper-cased):

        * ``assignment``: ``NAME = expr`` becomes ``Assignment(NAME, expr)``.
        * ``logical_if``: ``IF (cond) NAME = expr`` becomes a Piecewise that
          falls back to the symbol ``NAME`` if it was assigned earlier, or
          to 0 otherwise.
        * ``block_if``: IF / ELSE IF / ELSE blocks become one Piecewise
          assignment per assigned symbol. The same fallback is added when the
          symbol's last branch condition is not unconditionally true.

        Side effect: ``self.nodes`` is reset, and the originating
        ``statement`` node is appended each time one of the kinds above is
        processed. This is not one-to-one with the returned assignments.
        """
        s = []
        self.nodes = []

        def previous_value(name):
            # Value a conditional assignment falls back to: the symbol itself
            # if it was assigned earlier, otherwise 0.
            for prevass in s:
                if prevass.symbol.name == name:
                    return sympy.Symbol(name)
            return sympy.Integer(0)

        def collect_assignments(interpreter, block, symbols):
            # (NAME, expr) pairs for all assignments in the statements of
            # `block`, in order; every NAME is also added to `symbols`.
            symb_exprs = []
            for stat in block.all('statement'):
                for assign_node in stat.all('assignment'):
                    name = str(assign_node.variable).upper()
                    symb_exprs.append(
                        (name, interpreter.visit(assign_node.expression)))
                    symbols.add(name)
            return symb_exprs

        for statement in self.root.all('statement'):
            for node in statement.children:
                if node.rule == 'assignment':
                    name = str(node.variable).upper()
                    expr = ExpressionInterpreter().visit(node.expression)
                    s.append(Assignment(name, expr))
                    self.nodes.append(statement)
                elif node.rule == 'logical_if':
                    logic_expr = ExpressionInterpreter().visit(
                        node.logical_expression)
                    try:
                        assignment = node.assignment
                    except NoSuchRuleException:
                        pass
                    else:
                        name = str(assignment.variable).upper()
                        expr = ExpressionInterpreter().visit(
                            assignment.expression)
                        pw = sympy.Piecewise((expr, logic_expr),
                                             (previous_value(name), True))
                        s.append(Assignment(name, pw))
                    self.nodes.append(statement)
                elif node.rule == 'block_if':
                    interpreter = ExpressionInterpreter()
                    blocks = []  # [(logic, [(symb1, expr1), ...]), ...]
                    symbols = OrderedSet()

                    first_logic = interpreter.visit(
                        node.block_if_start.logical_expression)
                    blocks.append((first_logic,
                                   collect_assignments(
                                       interpreter, node.block_if_start,
                                       symbols)))

                    else_if_blocks = node.all('block_if_elseif')
                    for elseif in else_if_blocks:
                        logic = interpreter.visit(elseif.logical_expression)
                        blocks.append((logic,
                                       collect_assignments(
                                           interpreter, elseif, symbols)))

                    else_block = node.find('block_if_else')
                    if else_block:
                        else_symb_exprs = collect_assignments(
                            interpreter, else_block, symbols)
                        piecewise_logic = True
                        if len(blocks[0][1]) == 0 and not else_if_blocks:
                            # Special case for empty if
                            piecewise_logic = sympy.Not(blocks[0][0])
                        blocks.append((piecewise_logic, else_symb_exprs))

                    # NOTE(review): a symbol assigned only in a later branch
                    # gets no pair for earlier branches that do not assign
                    # it, so those earlier conditions are not respected for
                    # that symbol -- confirm this is intended.
                    for symbol in symbols:
                        pairs = []
                        for logic, symb_exprs in blocks:
                            for cursymb, expr in symb_exprs:
                                if cursymb == symbol:
                                    pairs.append((expr, logic))
                        # Without a catch-all condition the Piecewise is nan
                        # when no branch applies; fall back exactly like the
                        # logical IF above. (`!= True` also matches sympy.true.)
                        if pairs[-1][1] != True:  # noqa: E712
                            pairs.append((previous_value(symbol), True))
                        pw = sympy.Piecewise(*pairs)
                        s.append(Assignment(symbol, pw))
                    self.nodes.append(statement)

        return ModelStatements(s)
Example #29
0
def iterative_backward_chaining(kb, q):
    """Decide by backward chaining whether the knowledge base ``kb`` entails ``q``.

    ``kb.args`` are the clauses. Bare literals (``sp.Symbol`` / ``sp.Not``)
    are facts. Other clauses (presumably ``sp.Or`` of literals, Horn form --
    confirm) contribute their negated literals as premises and their
    positive literal as conclusion.

    A negated query ``Not(a)`` whose ``a`` is not a known fact is answered by
    negation as failure: it holds iff ``a`` cannot be proven.

    :param kb: conjunction of clauses.
    :param q: query literal.
    :return: True if the query holds, False otherwise.
    """
    clauses = list(kb.args)

    def is_literal(c):
        return isinstance(c, (sp.Symbol, sp.Not))

    ### Construct list of symbols known to be true
    known_true_symbols = [c for c in clauses if is_literal(c)]

    negation = False
    ### check if query is a negation
    if isinstance(q, sp.Not) and q.args[0] not in known_true_symbols:
        negation = True
        q = q.args[0]

    ### Construct tables of premises and conclusions keyed by clauses
    premises = {}
    conclusions = {}
    for c in clauses:
        if is_literal(c):
            premises[c] = None
            conclusions[c] = None
        else:
            premise_list = []
            # Reset per clause: a clause without a positive literal has no
            # conclusion (it must not inherit the previous clause's one).
            conclusion = None
            for s in c.args:
                if isinstance(s, sp.Not):
                    premise_list.append(sp.Not(s))
                else:
                    conclusion = s
            premises[c] = tuple(premise_list)
            conclusions[c] = conclusion

    ### Check if query is a known true symbol
    if q in known_true_symbols:
        return True

    ### Check if query can possibly be entailed
    if q not in conclusions.values():
        return negation

    ### Determine the clauses that can entail the query
    candidates = [c for c in clauses if conclusions[c] == q]

    ### Loop over the candidates
    tried_to_prove = [q]
    old_known_true_symbols = list(known_true_symbols)
    for candidate in candidates:
        things_to_prove = list(premises[candidate])

        while things_to_prove:
            if set(things_to_prove) < set(tried_to_prove):
                return negation

            if old_known_true_symbols != known_true_symbols:
                # More is known now: re-check goals that failed before.
                # Iterate over a copy, as proven goals are removed.
                for r in list(tried_to_prove):
                    if r in conclusions.values():
                        possible_premises = [
                            premises[c] for c in clauses if conclusions[c] == r
                        ]
                        for pp in possible_premises:
                            if set(pp) <= set(known_true_symbols):
                                known_true_symbols.append(r)
                                tried_to_prove.remove(r)
                                # r is proven; another premise set must not
                                # append/remove it a second time.
                                break

            t = things_to_prove.pop()

            ### If known to be true, do nothing
            if t in known_true_symbols:
                pass
            ### Can it be proved by anything?
            elif t not in conclusions.values():
                tried_to_prove.append(t)
                tried_to_prove = remove_duplicates_maintain_order(tried_to_prove)
                things_to_prove.insert(0, t)
                continue

            ### See if t can be proved using nothing but known true symbols.
            ### If so, add it as a known true symbol.
            ### If not, add its premises to the stack.
            else:
                ### Did we already try to prove this and fail?
                ### Has the known true symbol list changed?
                if t in tried_to_prove and old_known_true_symbols == known_true_symbols:
                    break
                for c in clauses:
                    if conclusions[c] == t:
                        things_that_prove_t = list(premises[c])

                        ### Check if t can be proved with nothing but known true symbols
                        if contains_sublist(known_true_symbols, things_that_prove_t):
                            old_known_true_symbols = list(known_true_symbols)
                            known_true_symbols.append(t)
                            if t in tried_to_prove:
                                tried_to_prove.remove(t)
                            break
                        things_to_prove = things_to_prove + things_that_prove_t
                        tried_to_prove.append(t)
                        tried_to_prove = remove_duplicates_maintain_order(tried_to_prove)

            things_to_prove = remove_duplicates_maintain_order(things_to_prove)

            if not things_to_prove:
                # Every premise of this candidate holds, so q is proven; for
                # a negated query that makes the query itself false.
                return not negation

    return negation
Example #30
0
def forward_chaining(kb, q):
    """Decide by forward chaining whether the knowledge base ``kb`` entails ``q``.

    ``kb.args`` are the clauses. Bare literals (``sp.Symbol`` / ``sp.Not``)
    are facts. Other clauses (presumably ``sp.Or`` of literals, Horn form --
    confirm) contribute their negated literals as premises and their
    positive literal as conclusion.

    A negated query ``Not(a)`` whose ``a`` is not a fact is answered by
    negation as failure: it holds iff ``a`` is never derived.

    :param kb: conjunction of clauses.
    :param q: query literal.
    :return: True if the query holds, False otherwise.
    """
    clauses = list(kb.args)

    def is_literal(c):
        return isinstance(c, (sp.Symbol, sp.Not))

    ### Construct agenda queue
    ### Initially the symbols known to be true in the kb
    agenda = [c for c in clauses if is_literal(c)]

    negation = False
    ### check if query is a negation
    if isinstance(q, sp.Not) and q.args[0] not in agenda:
        negation = True
        q = q.args[0]

    ### Every literal occurring in the kb (facts and clause arguments)
    symbols = []
    for c in clauses:
        if is_literal(c):
            symbols.append(c)
        else:
            symbols.extend(c.args)

    ### Construct inferred table
    ### Initially false for all symbols
    inferred = {k: False for k in symbols}

    ### Construct count table
    ### Where count[c] = number of symbols in c's premise
    ### (facts get 1 and have no premises, so they never reach 0)
    counts = {c: 1 if is_literal(c) else len(c.args) - 1 for c in clauses}

    ### Auxiliary tables of premises and conclusions
    premises = {}
    conclusions = {}
    for c in clauses:
        if is_literal(c):
            premises[c] = None
            conclusions[c] = None
        else:
            premise_list = []
            # Reset per clause: a clause without a positive literal has no
            # conclusion (it must not inherit the previous clause's one).
            conclusion = None
            for s in c.args:
                if isinstance(s, sp.Not):
                    premise_list.append(sp.Not(s))
                else:
                    conclusion = s
            premises[c] = tuple(premise_list)
            conclusions[c] = conclusion

    ### Forward chaining algorithm
    checked_agenda = []
    while agenda:
        p = agenda.pop()
        checked_agenda.append(p)

        if p == q:
            # q is derived; a negated query is therefore false.
            return not negation
        if not inferred[p]:
            inferred[p] = True
            for c in clauses:
                if premises[c] and p in premises[c]:
                    counts[c] -= 1
                conclusion = conclusions[c]
                # Clauses without a conclusion cannot add anything.
                if (counts[c] == 0 and conclusion is not None
                        and conclusion not in agenda
                        and conclusion not in checked_agenda):
                    agenda.insert(0, conclusion)

    return negation