Пример #1
0
 def test_rollback_on_mapping(self):
     """Check that nested rollbacks undo mapping writes one level at a time.

     Two transactions are opened on top of an initial two-key mapping. The
     inner and then the outer transaction are rolled back, and the mapping
     contents are verified before and after each rollback.
     """
     memory = Memory()
     stock: 'TransactionalMapping[str, int]' = TransactionalMapping(memory)

     def expect(contents):
         # Size first, then every value, then the sorted key listing.
         self.assertEqual(len(contents), len(stock))
         for key, value in contents.items():
             self.assertEqual(value, stock[key])
         self.assertEqual(sorted(contents), sorted(stock))

     # Baseline written outside of any transaction.
     stock["apple"] = 10
     stock["tomato"] = 20

     outer = memory.begin_transaction()
     stock["apple"] = 30
     stock["potato"] = 40

     inner = memory.begin_transaction()
     stock["tomato"] = 50
     stock["potato"] = 60
     expect({"apple": 30, "tomato": 50, "potato": 60})

     # Undoing the inner transaction keeps the writes made in the outer one.
     inner.rollback()
     expect({"apple": 30, "tomato": 20, "potato": 40})

     # Undoing the outer transaction brings back the baseline.
     outer.rollback()
     expect({"apple": 10, "tomato": 20})
Пример #2
0
    def test_rollback_on_set(self):
        """Check that rolling back set transactions removes only their additions.

        Three nested transactions are opened; the innermost one only re-adds an
        element that already exists. The innermost and then the outermost
        transaction are rolled back, with membership checked at every step.
        """
        memory = Memory()
        letters: 'TransactionalSet[str]' = TransactionalSet(memory)
        self.assertEqual(0, len(letters))

        def expect(present, absent):
            # Size, positive membership, negative membership, sorted listing.
            self.assertEqual(len(present), len(letters))
            for item in present:
                self.assertIn(item, letters)
            for item in absent:
                self.assertNotIn(item, letters)
            self.assertEqual(present, sorted(letters))

        letters.add("a")
        outermost = memory.begin_transaction()
        letters.add("b")
        memory.begin_transaction()  # middle transaction, never referenced again
        letters.add("c")
        innermost = memory.begin_transaction()
        letters.add("a")  # already present before this transaction started
        expect(["a", "b", "c"], ["d"])

        # The innermost transaction added nothing new, so nothing disappears.
        innermost.rollback()
        expect(["a", "b", "c"], ["d"])

        # Rolling back the outermost transaction leaves only the baseline item.
        outermost.rollback()
        expect(["a"], ["b", "c", "d"])
Пример #3
0
    def test_transactions_on_value(self):
        """Check commit and rollback on a single transactional value.

        The value is changed in an outer and an inner transaction; the inner
        one is committed and the outer one is rolled back, after which the
        value is expected to be ``None`` again.
        """
        memory = Memory()
        cell: 'Transactional[Optional[str]]' = Transactional(memory, None)
        self.assertIsNone(cell.value)

        outer = memory.begin_transaction()
        cell.value = "a"
        self.assertEqual("a", cell.value)

        inner = memory.begin_transaction()
        # Opening a transaction must not change what is currently visible.
        self.assertEqual("a", cell.value)
        cell.value = "b"
        self.assertEqual("b", cell.value)

        # Committing the inner transaction keeps its write visible.
        inner.commit()
        self.assertEqual("b", cell.value)

        # Rolling back the outer transaction restores the initial value.
        outer.rollback()
        self.assertIsNone(cell.value)
Пример #4
0
    def test_discard_and_rollback_on_set(self):
        """Check that a rollback restores elements discarded in nested transactions."""
        memory = Memory()
        letters: 'TransactionalSet[str]' = TransactionalSet(memory)
        self.assertEqual(0, len(letters))

        # Baseline content added outside of any transaction.
        for item in ("a", "b", "c"):
            letters.add(item)
        self.assertEqual(["a", "b", "c"], sorted(letters))

        outer = memory.begin_transaction()
        letters.add("d")
        letters.discard("a")
        memory.begin_transaction()  # nested transaction, left open
        letters.discard("b")
        self.assertEqual(["c", "d"], sorted(letters))

        # Rolling back the outer transaction is expected to restore the baseline.
        outer.rollback()
        self.assertEqual(["a", "b", "c"], sorted(letters))
Пример #5
0
    def test_delete_and_commit_on_mapping_v2(self):
        """Check that committing keeps a key re-created after a delete in a parent transaction."""
        memory = Memory()
        table: 'TransactionalMapping[str, int]' = TransactionalMapping(memory)

        def expect_final_state():
            # Identical expectations apply before and after the commit.
            self.assertEqual(2, len(table))
            self.assertEqual(10, table["a"])
            self.assertEqual(20, table["b"])
            self.assertEqual(["a", "b"], sorted(table))

        table["a"] = 1
        memory.begin_transaction()  # parent transaction, left open
        table["b"] = 20
        del table["a"]
        child = memory.begin_transaction()
        table["a"] = 10
        expect_final_state()

        child.commit()
        expect_final_state()
Пример #6
0
 def test_commit_on_set(self):
     """Check that committing nested set transactions keeps every added element.

     The three transactions are committed starting from the outermost one,
     in the same order in which they were opened.
     """
     memory = Memory()
     letters: 'TransactionalSet[str]' = TransactionalSet(memory)
     self.assertEqual(0, len(letters))

     letters.add("a")
     first = memory.begin_transaction()
     letters.add("b")
     second = memory.begin_transaction()
     letters.add("c")
     third = memory.begin_transaction()
     letters.add("a")  # already present before the first transaction

     for transaction in (first, second, third):
         transaction.commit()

     self.assertEqual(3, len(letters))
     for item in ("a", "b", "c"):
         self.assertIn(item, letters)
     self.assertNotIn("d", letters)
     self.assertEqual(["a", "b", "c"], sorted(letters))
Пример #7
0
    def test_delete_on_mapping_with_transactions(self):
        """Check deleting and re-adding a key that was written in an earlier open transaction."""
        memory = Memory()
        table: 'TransactionalMapping[int, str]' = TransactionalMapping(memory)
        self.assertEqual(0, len(table))

        def expect(contents):
            # Size, membership of keys 1..3, stored values, sorted key listing.
            self.assertEqual(len(contents), len(table))
            for key in (1, 2, 3):
                if key in contents:
                    self.assertIn(key, table)
                else:
                    self.assertNotIn(key, table)
            for key, value in contents.items():
                self.assertEqual(value, table[key])
            self.assertEqual(sorted(contents), sorted(table))

        # Each key is written one transaction level deeper than the previous one.
        table[1] = "a"
        memory.begin_transaction()
        table[2] = "b"
        memory.begin_transaction()
        table[3] = "c"
        memory.begin_transaction()

        del table[2]
        expect({1: "a", 3: "c"})

        # Writing the deleted key again makes it visible with the new value.
        table[2] = "d"
        expect({1: "a", 2: "d", 3: "c"})
Пример #8
0
    def test_delete_and_rollback_on_mapping(self):
        """Check that rolling back a child transaction re-hides a key deleted by its parent."""
        memory = Memory()
        table: 'TransactionalMapping[str, int]' = TransactionalMapping(memory)

        table["a"] = 1
        table["b"] = 20
        memory.begin_transaction()  # parent transaction, left open
        table["c"] = 30
        del table["a"]

        child = memory.begin_transaction()
        table["a"] = 10
        self.assertEqual(3, len(table))
        for key, value in (("a", 10), ("b", 20), ("c", 30)):
            self.assertEqual(value, table[key])
        self.assertEqual(["a", "b", "c"], sorted(table))

        # After the rollback only the parent's view is left: "a" stays deleted.
        child.rollback()
        self.assertEqual(2, len(table))
        for key, value in (("b", 20), ("c", 30)):
            self.assertEqual(value, table[key])
        self.assertEqual(["b", "c"], sorted(table))
Пример #9
0
    def test_discard_and_commit_on_set(self):
        """Check that committed transactions keep both additions and discards."""
        memory = Memory()
        letters: 'TransactionalSet[str]' = TransactionalSet(memory)
        self.assertEqual(0, len(letters))
        survivors = ["d", "e"]

        letters.add("a")
        letters.add("b")

        outer = memory.begin_transaction()
        letters.add("c")
        letters.add("d")
        letters.discard("c")  # added in this same transaction
        letters.discard("b")  # added before any transaction

        inner = memory.begin_transaction()
        letters.add("e")
        letters.discard("a")
        self.assertEqual(survivors, sorted(letters))

        # Committing inner and then outer must not change the visible content.
        inner.commit()
        self.assertEqual(survivors, sorted(letters))

        outer.commit()
        self.assertEqual(survivors, sorted(letters))
Пример #10
0
    def test_operations_on_set_with_transactions(self):
        """Check size, membership and iteration of a set across several open transactions."""
        memory = Memory()
        letters: 'TransactionalSet[str]' = TransactionalSet(memory)
        self.assertEqual(0, len(letters))

        # One element is added per level; the last add repeats an existing one.
        letters.add("a")
        memory.begin_transaction()
        letters.add("b")
        memory.begin_transaction()
        letters.add("c")
        memory.begin_transaction()
        letters.add("a")
        memory.begin_transaction()  # empty transaction opened just before reading

        self.assertEqual(3, len(letters))
        for item in ("a", "b", "c"):
            self.assertIn(item, letters)
        self.assertNotIn("d", letters)
        self.assertEqual(["a", "b", "c"], sorted(letters))
Пример #11
0
    def test_delete_and_commit_on_mapping_v1(self):
        """Check that an update, an insert and a delete survive a commit."""
        memory = Memory()
        table: 'TransactionalMapping[str, int]' = TransactionalMapping(memory)

        def expect_final_state():
            # Identical expectations apply before and after the commit.
            self.assertEqual(3, len(table))
            for key, value in (("b", 20), ("c", 30), ("d", 40)):
                self.assertEqual(value, table[key])
            self.assertEqual(["b", "c", "d"], sorted(table))

        table["a"] = 10
        table["b"] = 2
        table["c"] = 30

        transaction = memory.begin_transaction()
        table["b"] *= 10  # read-modify-write inside the transaction
        table["d"] = 40
        del table["a"]
        expect_final_state()

        transaction.commit()
        expect_final_state()
Пример #12
0
 def test_operations_on_mapping_with_transactions(self):
     """Check size, membership, lookup and iteration of a mapping across open transactions."""
     memory = Memory()
     table: 'TransactionalMapping[int, str]' = TransactionalMapping(memory)
     self.assertEqual(0, len(table))

     table[1] = "d"
     memory.begin_transaction()
     table[2] = "b"
     memory.begin_transaction()  # transaction with no writes of its own
     memory.begin_transaction()
     table[3] = "c"
     memory.begin_transaction()
     table[1] = "a"  # overwrites the value stored before any transaction

     self.assertEqual(3, len(table))
     for key in (1, 2, 3):
         self.assertIn(key, table)
     self.assertNotIn(4, table)
     for key, value in ((1, "a"), (2, "b"), (3, "c")):
         self.assertEqual(value, table[key])
     self.assertIsNone(table.get(5))
     self.assertEqual([1, 2, 3], sorted(table))
Пример #13
0
class Smtlib(VoidVisitor[Tag]):
    """Visitor that executes a parsed list of commands.

    The visitor owns a ``Memory`` instance that backs both the symbol table
    and the set of asserted expressions, an expression stack used while
    visiting terms, and the most recently built ``Model`` (if any).
    Diagnostics are reported by adding ``Message`` objects to the message set
    passed to the constructor.
    """

    def __init__(self, ms: 'MessageSet') -> 'None':
        """Create an interpreter that reports diagnostics into ``ms``."""
        self.__ms = ms
        self.__mem = Memory()
        self.__symbols = SymbolTable(self.__mem)
        self.__assertions: 'MutableSet[Expr]' = TransactionalSet(self.__mem)
        self.__stack = _ExprStack()
        # Set by the check-sat visitor; ``None`` until then.
        self.__model: 'Optional[Model]' = None

    @property
    def symbols(self) -> 'SymbolTable':
        """Return the symbol table used by this interpreter."""
        return self.__symbols

    @property
    def assertion(self) -> 'Expr':
        """Return the conjunction of all stored assertions.

        ``boolean(True)`` is returned when nothing has been asserted.
        """
        if len(self.__assertions) == 0:
            return boolean(True)
        return boolean_and(*self.__assertions)

    def execute(self, pos: 'Position') -> 'None':
        """Scan a command list starting at ``pos`` and visit it."""
        self._visit(CommandListNode(Scanner(pos, self.__ms)))

    def _visit_command_list_node(self, node: 'CommandListNode') -> 'None':
        """Visit every command of the list in order."""
        for command in node.seq:
            self._visit(command)

    def _visit_command_node(self, node: 'CommandNode') -> 'None':
        """Visit the content wrapped by a command node."""
        self._visit(node.content)

    def _visit_assert_node(self, node: 'AssertNode') -> 'None':
        """Evaluate the asserted term and store it in the assertion set.

        A message is reported when the term's sort is neither ``UNKNOWN`` nor
        ``BOOL``. Expressions that contain wrappers are not stored.
        """
        self._visit(node.term)
        _, expr = self.__stack.pop()
        if expr.symbol.sort != Sort.UNKNOWN and expr.symbol.sort != Sort.BOOL:
            self.__ms.add(
                Message(node.term.start,
                        "invalid assert command, term is not Bool"))
        # NOTE(review): a non-Bool term without wrappers is still stored after
        # the diagnostic above — confirm this is intended.
        if not expr.has_wrappers:
            self.__assertions.add(expr)

    def _visit_check_sat_node(self, _: 'CheckSatNode') -> 'None':
        """Build a model for the current assertion, solve it, print its status."""
        self.__model = Model(self.assertion)
        self.__model.solve()
        print(self.__model.status)

    def _visit_declare_const_node(self, node: 'DeclareConstNode') -> 'None':
        """Declare a variable symbol; an inconsistent sort becomes ``UNKNOWN``."""
        sort = node.sort.value if node.sort.is_consistent else Sort.UNKNOWN
        self.__declare_symbol(node.ident, VariableSymbol(sort))

    def _visit_declare_fun_node(self, node: 'DeclareFunNode') -> 'None':
        """Declare a function symbol from its argument sorts and result sort.

        Any inconsistent sort node is replaced by ``Sort.UNKNOWN``.
        """
        arg_sorts = tuple(π.value if π.is_consistent else Sort.UNKNOWN
                          for π in node.args)
        sort = node.sort.value if node.sort.is_consistent else Sort.UNKNOWN
        symbol = FunctionSymbol(sort, arg_sorts)
        self.__declare_symbol(node.ident, symbol)

    def _visit_define_fun_node(self, node: 'DefineFunNode') -> 'None':
        """Declare a macro symbol whose body is the visited term.

        The formal arguments are declared inside a memory transaction that is
        rolled back once the body has been evaluated (presumably so that they
        do not remain declared afterwards — verify against ``SymbolTable``).
        A message is reported when the declared sort disagrees with the sort
        of a body whose symbol is a ``ValencySymbol``.
        """
        tr = self.__mem.begin_transaction()
        try:
            formal_args: 'List[VariableSymbol]' = []
            for π in node.args:
                var_sort = (π.sort.value
                            if π.sort.is_consistent else Sort.UNKNOWN)
                var = VariableSymbol(var_sort)
                formal_args.append(var)
                self.__declare_symbol(π.ident, var)
            self._visit(node.term)
            _, body = self.__stack.pop()
        finally:
            # Always close the scope, even if visiting the body raised, so an
            # open transaction is never left behind.
            tr.rollback()

        sort = body.symbol.sort
        if node.sort.is_consistent:
            sort = node.sort.value
        if isinstance(body.symbol, ValencySymbol) and sort != body.symbol.sort:
            self.__ms.add(
                Message(node.sort.start,
                        "invalid function definition, sort mismatch"))
        symbol = MacroSymbol(sort, tuple(formal_args), body)
        self.__declare_symbol(node.ident, symbol)

    # NOTE(review): ``node`` is annotated as 'CheckSatNode' in a get-model
    # visitor; this looks like a copy-paste leftover — confirm the node type.
    def _visit_get_model_node(self, node: 'CheckSatNode') -> 'None':
        """Print the model value of every variable symbol.

        A "model not available" message is reported when no model has been
        built yet or when the model's status is ``Status.UNSAT``.
        """
        if (self.__model is None) or (self.__model.status is Status.UNSAT):
            self.__ms.add(Message(node.start, "model not available"))
            return

        for name in self.__symbols:
            sym = self.__symbols.get_symbol(name, ())
            assert sym is not None
            if isinstance(sym, VariableSymbol):
                expr = self.__model.eval(sym.apply())
                if expr is not None:
                    s = self.__symbols.serialize_expr(expr)
                    print(f"{name}: {s}")

    def _visit_simplify_node(self, node: 'SimplifyNode') -> 'None':
        """Evaluate the term and print the serialized result of ``to_cnf``."""
        self._visit(node.term)
        _, expr = self.__stack.pop()
        print(self.__symbols.serialize_expr(to_cnf(expr)))

    def _visit_term_node(self, node: 'TermNode') -> 'None':
        """Visit the content wrapped by a term node."""
        self._visit(node.content)

    def _visit_inconsistent_term_node(self, node: 'TermNode') -> 'None':
        """Push a wrapper expression in place of an inconsistent term."""
        self.__stack.push(node, WrapperSymbol().apply())

    def _visit_call_expr_node(self, node: 'CallExprNode') -> 'None':
        """Evaluate the arguments, then apply the called symbol to them.

        When the identifier node is inconsistent a wrapper expression is
        pushed instead of applying a symbol.
        """
        for term in node.args:
            self._visit(term)
        if node.ident.is_consistent:
            self.__apply_symbol(node.ident, len(node.args))
        else:
            self.__stack.push(node, WrapperSymbol().apply())

    def _visit_inconsistent_call_expr_node(self, node: 'IdentNode') -> 'None':
        """Push a wrapper expression in place of an inconsistent call."""
        self.__stack.push(node, WrapperSymbol().apply())

    def _visit_let_expr_node(self, node: 'LetExprNode') -> 'None':
        """Evaluate a let expression and push its body with bindings substituted.

        Bound terms are evaluated first, outside the new scope. Each consistent
        binding is then declared as a fresh variable inside a memory
        transaction, the body is evaluated, and the fresh variables are
        replaced by the bound expressions via ``substitute``. The transaction
        is rolled back afterwards.
        """
        symbols: 'List[IdentNode]' = []
        es: 'List[Expr]' = []
        for binding in node.bindings:
            if binding.is_consistent:
                self._visit(binding.term)
                _, e = self.__stack.pop()
                symbols.append(binding.ident)
                es.append(e)

        tr = self.__mem.begin_transaction()
        try:
            table: 'MutableMapping[Expr, Expr]' = {}
            # ``symbols`` and ``es`` are filled in lockstep above, so zip
            # pairs each identifier with its own bound expression.
            for ident, bound in zip(symbols, es):
                var = VariableSymbol(bound.symbol.sort)
                table[var.apply()] = bound
                self.__declare_symbol(ident, var)
            self._visit(node.term)
            _, expr = self.__stack.pop()
            self.__stack.push(node, expr.substitute(table))
        finally:
            # Always close the scope, even if visiting the body raised, so an
            # open transaction is never left behind.
            tr.rollback()

    def _visit_inconsistent_let_expr_node(self, node: 'IdentNode') -> 'None':
        """Push a wrapper expression in place of an inconsistent let expression."""
        self.__stack.push(node, WrapperSymbol().apply())

    def _visit_ident_node(self, node: 'IdentNode') -> 'None':
        """Apply the named symbol to zero arguments."""
        self.__apply_symbol(node, 0)

    def _visit_inconsistent_ident_node(self, node: 'IdentNode') -> 'None':
        """Push a wrapper expression in place of an inconsistent identifier."""
        self.__stack.push(node, WrapperSymbol().apply())

    def _visit_number_node(self, node: 'NumberNode') -> 'None':
        """Push an integer expression built from the node's value."""
        self.__stack.push(node, integer(node.value))

    def _visit_inconsistent_number_node(self, node: 'IdentNode') -> 'None':
        """Push a wrapper expression in place of an inconsistent number."""
        self.__stack.push(node, WrapperSymbol().apply())

    def __declare_symbol(self, node: 'IdentNode', symbol: 'Symbol') -> 'None':
        """Declare ``symbol`` under the name of ``node`` and report failures.

        Nothing happens for an inconsistent identifier node. When the symbol
        table rejects the declaration, the message distinguishes standard
        (builtin) names from names that are already declared.
        """
        if not node.is_consistent:
            return

        name = node.name
        if self.__symbols.declare(name, symbol):
            return

        if SymbolTable.is_standard_symbol(name):
            self.__ms.add(
                Message(
                    node.start,
                    f"invalid declaration, builtin symbol '{name}'"))
        else:
            self.__ms.add(
                Message(
                    node.start,
                    f"invalid declaration, symbol '{name}' already declared"
                ))

    def __apply_symbol(self, node: 'IdentNode', args_num: 'int') -> 'None':
        """Apply the symbol named by ``node`` to the top ``args_num`` stack entries.

        The arguments are taken off the expression stack and the symbol is
        looked up by name and argument sorts. An undeclared name is reported
        and declared as a wrapper symbol. The resulting expression is pushed
        back on the stack. When a ``ValencySymbol`` produced a wrapper
        expression, per-argument diagnostics (extra argument, sort mismatch,
        not enough arguments) are reported.
        """
        assert node.is_consistent
        nodes, es = self.__stack.take_off(args_num)
        arg_sorts = tuple(ε.symbol.sort for ε in es)
        name = node.name
        symbol = self.__symbols.get_symbol(name, arg_sorts)
        if symbol is None:
            self.__ms.add(Message(node.start, f"symbol '{name}' not declared"))
            symbol = WrapperSymbol()
            self.__symbols.declare(name, symbol)

        expr = symbol.apply(*es)
        self.__stack.push(node, expr)
        if isinstance(symbol, ValencySymbol) and isinstance(
                expr.symbol, WrapperSymbol):
            for i in range(args_num):
                pos = nodes[i].start
                actual_sort = arg_sorts[i]
                formal_sort = symbol.get_arg_sort(i, True)
                if formal_sort is None:
                    # No formal sort at this position: report once and stop.
                    self.__ms.add(
                        Message(pos,
                                f"extra argument passed to function '{name}'"))
                    break
                if actual_sort != Sort.UNKNOWN and actual_sort != formal_sort:
                    self.__ms.add(
                        Message(
                            pos,
                            f"sort mismatch at argument #{i+1} for function '{name}'"
                        ))
            # A formal sort just past the supplied arguments means the call
            # passed too few of them.
            if symbol.get_arg_sort(args_num, False) is not None:
                pos = nodes[-1].follow if args_num > 0 else node.follow
                self.__ms.add(
                    Message(
                        pos,
                        f"not enough arguments ({args_num}) passed to function '{name}'"
                    ))