예제 #1
0
	def test_number_10(self):
		# A rule whose symbol path resolves fully against the input must give the
		# same result with a plain Context and with Context(default_value=None).
		expected = random.randint(1, 10000)
		nested = {'c': {'c1': expected}}
		text = 'c.c1 == ' + str(expected)
		plain_rule = engine.Rule(text, context=engine.Context())
		defaulted_rule = engine.Rule(text, context=engine.Context(default_value=None))
		self.assertEqual(plain_rule.evaluate(nested), defaulted_rule.evaluate(nested))
예제 #2
0
 def test_ast_expression_symbol_type(self):
     """A SymbolExpression built with a type resolver reports the resolved STRING type."""
     # presumably self._type_resolver maps self.sym_name to STRING (fixture not shown)
     typed_context = engine.Context(type_resolver=self._type_resolver)
     expression = ast.SymbolExpression(typed_context, self.sym_name)
     lookup = {self.sym_name: self.sym_value}
     self.assertIs(expression.result_type, ast.DataType.STRING)
     self.assertEqual(expression.name, self.sym_name)
     self.assertEqual(expression.evaluate(lookup), self.sym_value)
예제 #3
0
 def test_ast_expression_symbol(self):
     """Without a type resolver a SymbolExpression is UNDEFINED-typed and returns the looked-up value."""
     name, value = self.sym_strname, self.sym_strvalue
     expression = ast.SymbolExpression(engine.Context(), name)
     self.assertIs(expression.result_type, ast.DataType.UNDEFINED)
     self.assertEqual(expression.name, name)
     self.assertEqual(expression.evaluate({name: value}), value)
예제 #4
0
class LiteralExpressionTests(unittest.TestCase):
	"""Truthiness and type-validation tests for the literal expression classes."""
	context = engine.Context()
	def assertLiteralTests(self, ExpressionClass, false_value, *true_values):
		"""Assert the behavior shared by the literal expression classes.

		:param ExpressionClass: The literal expression class under test.
		:param false_value: A value whose expression must evaluate to something falsy.
		:param true_values: Values whose expressions must each evaluate to something truthy.
		"""
		# the class under test itself must reject an unsupported value type;
		# UnknownType is presumably a type no literal class accepts (defined elsewhere)
		with self.assertRaises(TypeError):
			ExpressionClass(self.context, UnknownType())

		expression = ExpressionClass(self.context, false_value)
		self.assertIsInstance(expression, ast.LiteralExpressionBase)
		self.assertFalse(expression.evaluate(None))

		for true_value in true_values:
			expression = ExpressionClass(self.context, true_value)
			self.assertTrue(expression.evaluate(None))

	def test_ast_expression_literal_array(self):
		self.assertLiteralTests(ast.ArrayExpression, tuple(), tuple((ast.NullExpression(self.context),)))

	def test_ast_expression_literal_boolean(self):
		self.assertLiteralTests(ast.BooleanExpression, False, True)

	def test_ast_expression_literal_float(self):
		# reuse the float values from the module-level ``trueish`` fixtures; NaN is truthy in Python
		trueish_floats = (expression.value for expression in trueish if isinstance(expression, ast.FloatExpression))
		self.assertLiteralTests(ast.FloatExpression, 0.0, float('nan'), *trueish_floats)

	def test_ast_expression_literal_null(self):
		expression = ast.NullExpression(self.context)
		self.assertIsNone(expression.evaluate(None))

	def test_ast_expression_literal_string(self):
		self.assertLiteralTests(ast.StringExpression, '', 'non-empty')
예제 #5
0
 def test_tls_for_regex1(self):
     """Regex groups captured on this thread survive the same rule running via RuleThread."""
     context = engine.Context()
     # raw string: '\w' is an invalid escape sequence in a regular string literal
     rule = engine.Rule(r'words =~ "(\w+) \w+"', context=context)
     rule.evaluate({'words': 'MainThread Test'})
     self.assertEqual(context._tls.regex_groups, ('MainThread', ))
     # presumably RuleThread evaluates the rule on a separate thread and join() waits for it
     RuleThread(rule, {'words': 'AlternateThread Test'}).join()
     self.assertEqual(context._tls.regex_groups, ('MainThread', ))
예제 #6
0
	def test_ast_type_hints(self):
		"""Type hints supplied by a type resolver are checked when a rule is parsed.

		Each case must parse with no type information, parse with a compatible
		type for ``symbol`` and raise :py:exc:`errors.EvaluationError` with an
		incompatible one.
		"""
		parser_ = parser.Parser()
		cases = (
			# case,             type_is,             type_is_not
			('symbol << 1',     ast.DataType.FLOAT,  ast.DataType.STRING),
			('symbol + 1',      ast.DataType.FLOAT,  ast.DataType.STRING),
			('symbol > 1',      ast.DataType.FLOAT,  ast.DataType.STRING),
			('symbol =~ "foo"', ast.DataType.STRING, ast.DataType.FLOAT),
		)
		for case, type_is, type_is_not in cases:
			parser_.parse(case, self.context)
			context = engine.Context(type_resolver=engine.type_resolver_from_dict({'symbol': type_is}))
			parser_.parse(case, context)
			context = engine.Context(type_resolver=engine.type_resolver_from_dict({'symbol': type_is_not}))
			# report which case failed, since all cases share one test method
			with self.assertRaises(errors.EvaluationError, msg='case: ' + case):
				parser_.parse(case, context)
예제 #7
0
    def test_engine_builtins(self):
        """Builtins.from_defaults builds namespaced builtins and honors value type hints."""
        builtins = engine.Builtins.from_defaults(
            {'test': {
                'one': 1.0,
                'two': 2.0
            }})
        self.assertIsInstance(builtins, engine.Builtins)
        self.assertIsNone(builtins.namespace)
        self.assertRegex(
            repr(builtins),
            r'<Builtins namespace=None keys=\(\'\S+\'(, \'\S+\')*\)')

        # nested mappings become their own namespaced Builtins
        self.assertIn('test', builtins)
        test_builtins = builtins['test']
        self.assertIsInstance(test_builtins, engine.Builtins)
        self.assertEqual(test_builtins.namespace, 'test')

        self.assertIn('today', builtins)
        today = builtins['today']
        self.assertIsInstance(today, datetime.date)

        # test that builtins have correct type hints
        builtins = engine.Builtins.from_defaults(
            {'name': 'Alice'}, value_types={'name': ast.DataType.STRING})
        self.assertEqual(builtins.resolve_type('name'), ast.DataType.STRING)
        self.assertEqual(builtins.resolve_type('missing'),
                         ast.DataType.UNDEFINED)
        context = engine.Context()
        context.builtins = builtins
        # the compatible use must also be parsed with the typed context,
        # otherwise it never exercises the STRING hint
        engine.Rule('$name =~ ""', context=context)
        with self.assertRaises(errors.EvaluationError):
            engine.Rule('$name + 1', context=context)
예제 #8
0
    def test_ast_expression_datetime_attributes(self):
        """Each supported datetime attribute resolves to the expected value."""
        # one context for both the target and the attribute expressions, instead of
        # mixing a fresh Context with a module-level global
        context = engine.Context()
        timestamp = datetime.datetime(2019,
                                      9,
                                      11,
                                      20,
                                      46,
                                      57,
                                      506406,
                                      tzinfo=dateutil.tz.UTC)
        symbol = ast.DatetimeExpression(context, timestamp)

        attributes = {
            'day': 11,
            'hour': 20,
            'microsecond': 506406,
            'millisecond': 506.406,
            'minute': 46,
            'month': 9,
            'second': 57,
            'weekday': timestamp.strftime('%A'),
            'year': 2019,
            'zone_name': 'UTC',
        }
        for attribute_name, value in attributes.items():
            expression = ast.GetAttributeExpression(context, symbol,
                                                    attribute_name)
            self.assertEqual(expression.evaluate(None), value,
                             "attribute {} failed".format(attribute_name))
예제 #9
0
    def test_engine_builtins_re_groups(self):
        """$re_groups holds the last successful match's groups and is None after a failed match."""
        context = engine.Context()
        rule = engine.Rule(
            'words =~ "(\\w+) (\\w+) (\\w+)" and $re_groups[0] == word0',
            context=context)
        self.assertIsNone(context._tls.regex_groups)

        def random_word():
            # length is drawn first, then one letter per position
            length = random.randint(4, 12)
            return ''.join(random.choice(string.ascii_letters) for _ in range(length))

        words = tuple(random_word() for _ in range(3))
        first = words[0]

        self.assertTrue(rule.matches({'words': ' '.join(words), 'word0': first}))
        self.assertEqual(context._tls.regex_groups, words)

        # with no separating spaces the three-group pattern cannot match
        self.assertFalse(rule.matches({'words': ''.join(words), 'word0': first}))
        self.assertIsNone(context._tls.regex_groups)
예제 #10
0
 def test_tls_for_comprehension(self):
     """No assignment scopes remain on the context once a comprehension has been evaluated."""
     context = engine.Context()
     engine.Rule('[word for word in words][0]', context=context).evaluate(
         {'words': ('MainThread', 'Test')})
     # not strictly a threading test: the comprehension's assignment scope is
     # expected to be cleared by the time evaluation returns
     remaining_scopes = context._tls.assignment_scopes
     self.assertEqual(len(remaining_scopes), 0)
예제 #11
0
	def test_engine_resolve_item_with_defaults(self):
		# with default_value=None, names the item resolver cannot find evaluate to None
		thing = {'name': 'Alice'}
		context = engine.Context(resolver=engine.resolve_item, default_value=None)

		def evaluate(text):
			return engine.Rule(text, context=context).evaluate(thing)

		self.assertEqual(evaluate('name'), thing['name'])
		for missing in ('name.first', 'address', 'address.city'):
			self.assertIsNone(evaluate(missing))
예제 #12
0
	def test_engine_resolve_attribute_with_defaults(self):
		# with default_value=None, names the attribute resolver cannot find evaluate to None
		Person = collections.namedtuple('Person', ('name',))
		thing = Person(name='alice')
		context = engine.Context(resolver=engine.resolve_attribute, default_value=None)

		def evaluate(text):
			return engine.Rule(text, context=context).evaluate(thing)

		self.assertEqual(evaluate('name'), thing.name)
		for missing in ('name.first', 'address', 'address.city'):
			self.assertIsNone(evaluate(missing))
예제 #13
0
 def test_number_14(self):
     """A null check in a ternary guards a FLOAT-typed symbol whose value is None."""
     context = engine.Context(type_resolver=engine.type_resolver_from_dict({
         'TEST_FLOAT':
         ast.DataType.FLOAT,
     }))
     rule = engine.Rule('(TEST_FLOAT == null ? 0 : TEST_FLOAT) < 42',
                        context=context)
     # None takes the ``0`` branch and 0 < 42, so the rule must match; asserting
     # the result also guards against a silently wrong evaluation
     self.assertTrue(rule.matches({'TEST_FLOAT': None}))
예제 #14
0
 def test_number_19(self):
     """Attribute access on a typed MAPPING symbol yields the mapped value."""
     string_to_string = ast.DataType.MAPPING(key_type=ast.DataType.STRING,
                                             value_type=ast.DataType.STRING)
     resolver = engine.type_resolver_from_dict({'facts': string_to_string})
     context = engine.Context(type_resolver=resolver)
     facts = {'abc': 'def'}
     rule = engine.Rule('facts.abc == "def"', context=context)
     self.assertTrue(rule.matches({'facts': facts}))
예제 #15
0
 def test_ast_expression_symbol_type_errors(self):
     """A STRING-typed symbol rejects a bool value but accepts None."""
     context = engine.Context(type_resolver=self._type_resolver)
     symbol = ast.SymbolExpression(context, self.sym_name)
     self.assertIs(symbol.result_type, ast.DataType.STRING)
     self.assertEqual(symbol.name, self.sym_name)
     # ``not <value>`` is a bool; only the evaluate call is expected to raise, so it
     # is the sole statement inside assertRaises
     with self.assertRaises(errors.SymbolTypeError):
         symbol.evaluate({self.sym_name: not self.sym_value})
     self.assertIsNone(symbol.evaluate({self.sym_name: None}))
예제 #16
0
class ParserTestsBase(unittest.TestCase):
	"""Base class for parser tests providing shared parsing helpers."""
	_parser = parser.Parser()
	context = engine.Context()
	def _parse(self, string, context=None):
		"""Parse *string* into a statement.

		:param str string: The rule text to parse.
		:param context: The context to parse with; :py:attr:`context` is used when ``None``.
		:return: The parsed statement.
		"""
		# use the caller-supplied context so typed contexts actually take effect
		if context is None:
			context = self.context
		return self._parser.parse(string, context)

	def assertStatementType(self, string, ast_expression):
		"""Parse *string* and assert it yields a statement wrapping an *ast_expression* instance.

		:return: The parsed statement.
		"""
		statement = self._parse(string, self.context)
		self.assertIsInstance(statement, ast.Statement, msg='the parser did not return a statement')
		expression = statement.expression
		self.assertIsInstance(expression, ast_expression, msg='the statement expression is not the correct expression type')
		return statement
예제 #17
0
    def test_ast_expression_string_attributes(self):
        """Each supported string attribute resolves to the expected value."""
        # one context for both the target and the attribute expressions, instead of
        # mixing a fresh Context with a module-level global
        context = engine.Context()
        text = 'Rule Engine'  # avoid shadowing the stdlib ``string`` module name
        symbol = ast.StringExpression(context, text)

        attributes = {
            'as_lower': text.lower(),
            'as_upper': text.upper(),
            'length': len(text),
        }
        for attribute_name, value in attributes.items():
            expression = ast.GetAttributeExpression(context, symbol,
                                                    attribute_name)
            self.assertEqual(expression.evaluate(None), value,
                             "attribute {} failed".format(attribute_name))
예제 #18
0
 def test_tls_for_regex2(self):
     """$re_groups seen by this thread are unaffected while another thread evaluates the same rule."""
     lock = threading.RLock()
     # NOTE(review): testing_resolver presumably acquires *lock* when resolving the
     # ``lock`` symbol, which is why each evaluation here is followed by a release
     context = engine.Context(
         resolver=functools.partial(testing_resolver, lock))
     # raw string: '\w' is an invalid escape sequence in a regular string literal
     rule = engine.Rule(
         r'words =~ "(\w+) \w+" and lock and $re_groups[0] == "MainThread"',
         context=context)
     self.assertTrue(rule.evaluate({'words': 'MainThread Test'}))
     lock.release()
     with lock:
         thread = RuleThread(rule, {'words': 'AlternateThread Test'})
         self.assertTrue(rule.evaluate({'words': 'MainThread Test'}))
         lock.release()
     # the other thread's groups start with "AlternateThread", so its result is falsy
     self.assertFalse(thread.join())
예제 #19
0
    def test_parser_comprehension_expressions_errors(self):
        """Comprehensions reject bad iterables and targets, and propagate element types."""
        # iterating over a non-iterable is an evaluation error
        with self.assertRaises(errors.EvaluationError):
            self._parse('[null for something in null]', self.context)

        # null is not a valid assignment target
        with self.assertRaises(errors.SyntaxError):
            self._parse('[null for null in something]', self.context)

        string_array = types.DataType.ARRAY(types.DataType.STRING)
        words_context = engine.Context(
            type_resolver=engine.type_resolver_from_dict({'words': string_array}))
        # the element type reaches the result expression first, then the condition expression
        pairs = (
            ('[word =~ ".*" for word in words]', '[word % 2 for word in words]'),
            ('[null for word in words if word =~ ".*"]',
             '[null for word in words if word % 2]'),
        )
        for accepted, rejected in pairs:
            self._parse(accepted, words_context)
            with self.assertRaises(errors.EvaluationError):
                self._parse(rejected, words_context)
예제 #20
0
    def test_ast_expression_symbol_type_errors(self):
        """SymbolTypeError is raised only for values that conflict with a symbol's resolved type."""
        context = engine.Context(type_resolver=self._type_resolver)
        symbol = ast.SymbolExpression(context, self.sym_strname)
        self.assertIs(symbol.result_type, ast.DataType.STRING)
        self.assertEqual(symbol.name, self.sym_strname)
        # ``not <value>`` is a bool; only the evaluate call is expected to raise, so it
        # is the sole statement inside assertRaises
        with self.assertRaises(errors.SymbolTypeError):
            symbol.evaluate({self.sym_strname: not self.sym_strvalue})
        self.assertIsNone(symbol.evaluate({self.sym_strname: None}))

        # NOTE(review): the array fixtures and their declared types come from the
        # test setup (not shown); the names suggest typed, untyped and nullable members
        symbol = ast.SymbolExpression(context, self.sym_aryname)
        with self.assertRaises(errors.SymbolTypeError):
            symbol.evaluate({self.sym_aryname: self.sym_aryvalue_nullable})
        try:
            symbol.evaluate({self.sym_aryname: self.sym_aryvalue})
        except errors.SymbolTypeError:
            self.fail('raises SymbolTypeError when it should not')

        symbol = ast.SymbolExpression(context, self.sym_aryname_nontyped)
        try:
            symbol.evaluate({self.sym_aryname_nontyped: self.sym_aryvalue})
            symbol.evaluate(
                {self.sym_aryname_nontyped: self.sym_aryvalue_nontyped})
            symbol.evaluate(
                {self.sym_aryname_nontyped: self.sym_aryvalue_nullable})
        except errors.SymbolTypeError:
            self.fail('raises SymbolTypeError when it should not')

        symbol = ast.SymbolExpression(context, self.sym_aryname_nullable)
        try:
            symbol.evaluate({self.sym_aryname_nullable: self.sym_aryvalue})
            symbol.evaluate(
                {self.sym_aryname_nullable: self.sym_aryvalue_nullable})
        except errors.SymbolTypeError:
            self.fail('raises SymbolTypeError when it should not')
예제 #21
0
    def test_context_default_timezone(self):
        """The names 'Local' and 'UTC' map to the matching dateutil tz objects."""
        expectations = (
            ('Local', dateutil.tz.tzlocal),
            ('UTC', dateutil.tz.tzutc),
        )
        for zone_name, make_expected in expectations:
            context = engine.Context(default_timezone=zone_name)
            self.assertEqual(context.default_timezone, make_expected())
예제 #22
0
 def test_context_default_timezone_errors(self):
     """An unknown zone name raises ValueError; a non-string value raises TypeError."""
     for bad_value, expected_error in (('doesnotexist', ValueError), (600, TypeError)):
         with self.assertRaises(expected_error):
             engine.Context(default_timezone=bad_value)
예제 #23
0
class AstTests(unittest.TestCase):
    """Parse rule text and check how the resulting statements evaluate, type-check and reduce."""
    context = engine.Context()
    thing = {'age': 21.0, 'name': 'Alice'}

    def _assert_truth(self, value, expected):
        # assertTrue/assertFalse check truthiness rather than identity with True/False
        if expected:
            self.assertTrue(value)
        else:
            self.assertFalse(value)

    def _assert_reduces_to_float(self, rule_parser, text, value):
        # a statement built only from literals collapses into a single FloatExpression
        statement = rule_parser.parse(text, self.context)
        self.assertIsInstance(statement.expression, ast.FloatExpression)
        self.assertEqual(statement.evaluate(None), value)

    def test_ast_evaluates_arithmetic_comparisons(self):
        rule_parser = parser.Parser()
        for text, expected in (('age >= 21', True), ('age > 100', False)):
            statement = rule_parser.parse(text, self.context)
            self._assert_truth(statement.evaluate(self.thing), expected)

    def test_ast_evaluates_logic(self):
        rule_parser = parser.Parser()
        cases = (
            ('true and true', True),
            ('true and false', False),
            ('true or false', True),
            ('false or false', False),
        )
        for text, expected in cases:
            self._assert_truth(rule_parser.parse(text, self.context).evaluate(None), expected)

    def test_ast_evaluates_fuzzy_comparisons(self):
        rule_parser = parser.Parser()
        for text in ('name =~ ".lic."', 'name =~~ "lic"'):
            statement = rule_parser.parse(text, self.context)
            self.assertTrue(statement.evaluate(self.thing))

    def test_ast_evaluates_string_comparisons(self):
        rule_parser = parser.Parser()
        for text, expected in (('name == "Alice"', True), ('name == "calie"', False)):
            statement = rule_parser.parse(text, self.context)
            self._assert_truth(statement.evaluate(self.thing), expected)

    def test_ast_evaluates_unary_not(self):
        rule_parser = parser.Parser()
        cases = (
            ('not false', True),
            ('not true', False),
            ('true and not false', True),
            ('false and not true', False),
        )
        for text, expected in cases:
            statement = rule_parser.parse(text, self.context)
            self._assert_truth(statement.evaluate(None), expected)

    def test_ast_evaluates_unary_uminus(self):
        rule_parser = parser.Parser()
        negated = rule_parser.parse('-(2 * 5)', self.context)
        self.assertEqual(negated.evaluate(None), -10)

    def test_ast_raises_type_mismatch_arithmetic_comparisons(self):
        rule_parser = parser.Parser()
        statement = rule_parser.parse('symbol < 1', self.context)
        for bad_value in ('string', True):
            with self.assertRaises(errors.EvaluationError):
                statement.evaluate({'symbol': bad_value})
        self.assertTrue(statement.evaluate({'symbol': 0.0}))

    def test_ast_raises_type_mismatch_bitwise(self):
        rule_parser = parser.Parser()
        statement = rule_parser.parse('symbol << 1', self.context)
        # NOTE(review): inf and nan are module-level names defined outside this view
        for bad_value in (1.1, 'string', True, inf, nan):
            with self.assertRaises(errors.EvaluationError):
                statement.evaluate({'symbol': bad_value})
        self.assertEqual(statement.evaluate({'symbol': 1}), 2)

        # mismatched literal operands are rejected while parsing
        rejected = (
            'symbol << 1.1',
            'symbol << "string"',
            'symbol << true',
            'inf << 1',
            'nan << 1',
        )
        for text in rejected:
            with self.assertRaises(errors.EvaluationError):
                rule_parser.parse(text, self.context)

    def test_ast_raises_type_mismatch_fuzzy_comparisons(self):
        rule_parser = parser.Parser()
        statement = rule_parser.parse('symbol =~ "string"', self.context)
        for bad_value in (1.1, True):
            with self.assertRaises(errors.EvaluationError):
                statement.evaluate({'symbol': bad_value})
        self.assertTrue(statement.evaluate({'symbol': 'string'}))

        for text in ('"string" =~ 1', '"string" =~ true'):
            with self.assertRaises(errors.EvaluationError):
                rule_parser.parse(text, self.context)

    def test_ast_reduces_arithmetic(self):
        operands = {'one': 1, 'two': 2}
        rule_parser = parser.Parser()
        self._assert_reduces_to_float(rule_parser, '1 + 2', 3)

        # a symbol operand prevents reduction, leaving an ArithmeticExpression
        for text in ('one + 2', '1 + two'):
            statement = rule_parser.parse(text, self.context)
            self.assertIsInstance(statement.expression, ast.ArithmeticExpression)
            self.assertEqual(statement.evaluate(operands), 3)

    def test_ast_reduces_array_literals(self):
        rule_parser = parser.Parser()
        literal_array = rule_parser.parse('[1, 2, 1 + 2]', self.context)
        self.assertIsInstance(literal_array.expression, ast.ArrayExpression)
        self.assertTrue(literal_array.expression.is_reduced)
        self.assertEqual(literal_array.evaluate(None), (1, 2, 3))

        symbol_array = rule_parser.parse('[foobar]', self.context)
        self.assertIsInstance(symbol_array.expression, ast.ArrayExpression)
        self.assertFalse(symbol_array.expression.is_reduced)

    def test_ast_reduces_attributes(self):
        self._assert_reduces_to_float(parser.Parser(), '"foobar".length', 6)

    def test_ast_reduces_bitwise(self):
        self._assert_reduces_to_float(parser.Parser(), '1 << 2', 4)

    def test_ast_reduces_ternary(self):
        self._assert_reduces_to_float(parser.Parser(), 'true ? 1 : 0', 1)

    def test_ast_type_hints(self):
        rule_parser = parser.Parser()
        cases = (
            # case,             type_is,             type_is_not
            ('symbol << 1', ast.DataType.FLOAT, ast.DataType.STRING),
            ('symbol + 1', ast.DataType.FLOAT, ast.DataType.STRING),
            ('symbol[1]', ast.DataType.STRING, ast.DataType.FLOAT),
            ('symbol[1]', ast.DataType.ARRAY, ast.DataType.FLOAT),
            ('symbol =~ "foo"', ast.DataType.STRING, ast.DataType.FLOAT),
        )

        def typed_context(data_type):
            return engine.Context(
                type_resolver=engine.type_resolver_from_dict(
                    {'symbol': data_type}))

        for case, type_is, type_is_not in cases:
            rule_parser.parse(case, self.context)
            rule_parser.parse(case, typed_context(type_is))
            incompatible = typed_context(type_is_not)
            with self.assertRaises(errors.EvaluationError,
                                   msg='case: ' + case):
                rule_parser.parse(case, incompatible)
예제 #24
0
#  DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
#  THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import datetime
import decimal
import unittest

import rule_engine.ast as ast
import rule_engine.engine as engine

__all__ = ('LiteralExpressionTests',)

# shared context for the module-level literal fixtures, carrying a 'test' builtins namespace
context = engine.Context()
context.builtins = engine.Builtins.from_defaults({'test': {'one': 1.0, 'two': 2.0}})
# literal expressions which should evaluate to false
falseish = (
    ast.ArrayExpression(context, tuple()),
    ast.BooleanExpression(context, False),
    ast.FloatExpression(context, 0.0),
    ast.NullExpression(context),
    ast.StringExpression(context, ''),
)
# literal expressions which should evaluate to true
trueish = (ast.ArrayExpression(context, tuple(
    (ast.NullExpression(context), ))),
           ast.ArrayExpression(context,
                               tuple((ast.FloatExpression(context, 1.0), ))),