def test_number_10(self):
    """Regression check: a present nested value evaluates the same with and without ``default_value=None``."""
    expected = random.randint(1, 10000)
    thing = {'c': {'c1': expected}}
    rule_text = 'c.c1 == {}'.format(expected)
    # build both rules first, then evaluate them against the same thing
    rules = [
        engine.Rule(rule_text, context=ctx)
        for ctx in (engine.Context(), engine.Context(default_value=None))
    ]
    first, second = (rule.evaluate(thing) for rule in rules)
    self.assertEqual(first, second)
def test_ast_expression_symbol_type(self):
    """A symbol typed by the test's type resolver reports STRING and evaluates to its bound value."""
    name, value = self.sym_name, self.sym_value
    symbol = ast.SymbolExpression(engine.Context(type_resolver=self._type_resolver), name)
    self.assertIs(symbol.result_type, ast.DataType.STRING)
    self.assertEqual(symbol.name, name)
    self.assertEqual(symbol.evaluate({name: value}), value)
def test_ast_expression_symbol(self):
    """A symbol in a context without a type resolver is UNDEFINED-typed and evaluates to its bound value."""
    name = self.sym_strname
    expected = self.sym_strvalue
    symbol = ast.SymbolExpression(engine.Context(), name)
    self.assertIs(symbol.result_type, ast.DataType.UNDEFINED)
    self.assertEqual(symbol.name, name)
    self.assertEqual(symbol.evaluate({name: expected}), expected)
class LiteralExpressionTests(unittest.TestCase):
    """Tests for the literal expression classes of ``rule_engine.ast``."""
    context = engine.Context()

    def assertLiteralTests(self, ExpressionClass, false_value, *true_values):
        """Assert construction and truthiness rules for a literal expression class.

        :param ExpressionClass: The literal expression class under test.
        :param false_value: A value whose expression must evaluate to something falsy.
        :param true_values: Values whose expressions must each evaluate to something truthy.
        """
        # the class under test (not always StringExpression) must reject a value of an unsupported type
        with self.assertRaises(TypeError):
            ExpressionClass(self.context, UnknownType())

        expression = ExpressionClass(self.context, false_value)
        self.assertIsInstance(expression, ast.LiteralExpressionBase)
        self.assertFalse(expression.evaluate(None))

        for true_value in true_values:
            expression = ExpressionClass(self.context, true_value)
            self.assertTrue(expression.evaluate(None))

    def test_ast_expression_literal_array(self):
        self.assertLiteralTests(ast.ArrayExpression, tuple(), tuple((ast.NullExpression(self.context),)))

    def test_ast_expression_literal_boolean(self):
        self.assertLiteralTests(ast.BooleanExpression, False, True)

    def test_ast_expression_literal_float(self):
        # reuse the values of the float expressions in the module-level `trueish` tuple
        trueish_floats = (expression.value for expression in trueish if isinstance(expression, ast.FloatExpression))
        self.assertLiteralTests(ast.FloatExpression, 0.0, float('nan'), *trueish_floats)

    def test_ast_expression_literal_null(self):
        expression = ast.NullExpression(self.context)
        self.assertIsNone(expression.evaluate(None))

    def test_ast_expression_literal_string(self):
        self.assertLiteralTests(ast.StringExpression, '', 'non-empty')
def test_tls_for_regex1(self):
    """Regex groups captured by a rule evaluated in another thread do not replace this thread's groups."""
    context = engine.Context()
    # raw string: '\w' is an invalid escape sequence in a normal string literal
    rule = engine.Rule(r'words =~ "(\w+) \w+"', context=context)
    rule.evaluate({'words': 'MainThread Test'})
    self.assertEqual(context._tls.regex_groups, ('MainThread', ))
    RuleThread(rule, {'words': 'AlternateThread Test'}).join()
    self.assertEqual(context._tls.regex_groups, ('MainThread', ))
def test_ast_type_hints(self):
    """Symbol type hints supplied through the context are enforced at parse time."""
    parser_ = parser.Parser()
    cases = (
        # rule text, compatible type for `symbol`, incompatible type for `symbol`
        ('symbol << 1', ast.DataType.FLOAT, ast.DataType.STRING),
        ('symbol + 1', ast.DataType.FLOAT, ast.DataType.STRING),
        ('symbol > 1', ast.DataType.FLOAT, ast.DataType.STRING),
        ('symbol =~ "foo"', ast.DataType.STRING, ast.DataType.FLOAT),
    )
    for case, type_is, type_is_not in cases:
        # the shared test context (presumably without a hint for `symbol`) must accept every case
        parser_.parse(case, self.context)
        context = engine.Context(type_resolver=engine.type_resolver_from_dict({'symbol': type_is}))
        parser_.parse(case, context)
        context = engine.Context(type_resolver=engine.type_resolver_from_dict({'symbol': type_is_not}))
        # name the failing case in the assertion message
        with self.assertRaises(errors.EvaluationError, msg='case: ' + case):
            parser_.parse(case, context)
def test_engine_builtins(self):
    """Builtins created from defaults expose namespaces, extra default values and type hints."""
    builtins = engine.Builtins.from_defaults({'test': {'one': 1.0, 'two': 2.0}})
    self.assertIsInstance(builtins, engine.Builtins)
    self.assertIsNone(builtins.namespace)
    self.assertRegex(repr(builtins), r'<Builtins namespace=None keys=\(\'\S+\'(, \'\S+\')*\)')

    # a nested mapping becomes a namespaced Builtins instance
    self.assertIn('test', builtins)
    test_builtins = builtins['test']
    self.assertIsInstance(test_builtins, engine.Builtins)
    self.assertEqual(test_builtins.namespace, 'test')

    # `today` is present even though it was not supplied above
    self.assertIn('today', builtins)
    today = builtins['today']
    self.assertIsInstance(today, datetime.date)

    # test that builtins have correct type hints
    builtins = engine.Builtins.from_defaults({'name': 'Alice'}, value_types={'name': ast.DataType.STRING})
    self.assertEqual(builtins.resolve_type('name'), ast.DataType.STRING)
    self.assertEqual(builtins.resolve_type('missing'), ast.DataType.UNDEFINED)
    context = engine.Context()
    context.builtins = builtins
    # a string operation on the STRING-typed builtin must parse with the hinted context ...
    engine.Rule('$name =~ ""', context=context)
    # ... while arithmetic on it must be rejected
    with self.assertRaises(errors.EvaluationError):
        engine.Rule('$name + 1', context=context)
def test_ast_expression_datetime_attributes(self):
    """Each datetime attribute read through GetAttributeExpression yields the expected value."""
    # use one local context for both the symbol and the attribute expressions instead of a module global
    context = engine.Context()
    timestamp = datetime.datetime(2019, 9, 11, 20, 46, 57, 506406, tzinfo=dateutil.tz.UTC)
    symbol = ast.DatetimeExpression(context, timestamp)
    attributes = {
        'day': 11,
        'hour': 20,
        'microsecond': 506406,
        'millisecond': 506.406,
        'minute': 46,
        'month': 9,
        'second': 57,
        'weekday': timestamp.strftime('%A'),
        'year': 2019,
        'zone_name': 'UTC',
    }
    for attribute_name, value in attributes.items():
        expression = ast.GetAttributeExpression(context, symbol, attribute_name)
        self.assertEqual(expression.evaluate(None), value, "attribute {} failed".format(attribute_name))
def test_engine_builtins_re_groups(self):
    """``$re_groups`` holds the groups of a successful match and is None again after a failed match."""
    context = engine.Context()
    rule = engine.Rule(
        'words =~ "(\\w+) (\\w+) (\\w+)" and $re_groups[0] == word0',
        context=context)
    self.assertIsNone(context._tls.regex_groups)

    def random_word():
        # 4 to 12 random ASCII letters
        length = random.randint(4, 12)
        return ''.join(random.choice(string.ascii_letters) for _ in range(length))

    words = tuple(random_word() for _ in range(3))

    # space separated: the three-group pattern matches
    matching = {'words': ' '.join(words), 'word0': words[0]}
    self.assertTrue(rule.matches(matching))
    self.assertEqual(context._tls.regex_groups, words)

    # concatenated without spaces: the pattern cannot match
    non_matching = {'words': ''.join(words), 'word0': words[0]}
    self.assertFalse(rule.matches(non_matching))
    self.assertIsNone(context._tls.regex_groups)
def test_tls_for_comprehension(self):
    """No assignment scopes remain on the context after a comprehension has been evaluated."""
    context = engine.Context()
    engine.Rule('[word for word in words][0]', context=context).evaluate({'words': ('MainThread', 'Test')})
    # not strictly a threading test: the comprehension's assignment scope is expected to be
    # cleared once the comprehension completes
    self.assertEqual(len(context._tls.assignment_scopes), 0)
def test_engine_resolve_item_with_defaults(self):
    """With ``resolve_item`` and ``default_value=None``, unresolvable names evaluate to None."""
    thing = {'name': 'Alice'}
    context = engine.Context(resolver=engine.resolve_item, default_value=None)

    def evaluate(text):
        return engine.Rule(text, context=context).evaluate(thing)

    self.assertEqual(evaluate('name'), thing['name'])
    for text in ('name.first', 'address', 'address.city'):
        self.assertIsNone(evaluate(text))
def test_engine_resolve_attribute_with_defaults(self):
    """With ``resolve_attribute`` and ``default_value=None``, unresolvable names evaluate to None."""
    Person = collections.namedtuple('Person', ('name',))
    thing = Person(name='alice')
    context = engine.Context(resolver=engine.resolve_attribute, default_value=None)

    def evaluate(text):
        return engine.Rule(text, context=context).evaluate(thing)

    self.assertEqual(evaluate('name'), thing.name)
    for text in ('name.first', 'address', 'address.city'):
        self.assertIsNone(evaluate(text))
def test_number_14(self):
    """A null-guarding ternary over a FLOAT-typed symbol evaluates for both null and non-null values."""
    context = engine.Context(type_resolver=engine.type_resolver_from_dict({
        'TEST_FLOAT': ast.DataType.FLOAT,
    }))
    rule = engine.Rule('(TEST_FLOAT == null ? 0 : TEST_FLOAT) < 42', context=context)
    # null takes the `0` branch, and 0 < 42
    self.assertTrue(rule.matches({'TEST_FLOAT': None}))
    # a non-null value is compared directly, and 100 is not < 42
    self.assertFalse(rule.matches({'TEST_FLOAT': 100.0}))
def test_number_19(self):
    """Attribute access on a STRING-to-STRING mapping symbol compares against a string literal."""
    facts_type = ast.DataType.MAPPING(key_type=ast.DataType.STRING, value_type=ast.DataType.STRING)
    resolver = engine.type_resolver_from_dict({'facts': facts_type})
    rule = engine.Rule('facts.abc == "def"', context=engine.Context(type_resolver=resolver))
    self.assertTrue(rule.matches({'facts': {'abc': 'def'}}))
def test_ast_expression_symbol_type_errors(self):
    """A STRING-typed symbol raises SymbolTypeError for a boolean value and evaluates None to None."""
    context = engine.Context(type_resolver=self._type_resolver)
    symbol = ast.SymbolExpression(context, self.sym_name)
    self.assertIs(symbol.result_type, ast.DataType.STRING)
    self.assertEqual(symbol.name, self.sym_name)
    # `not self.sym_value` is a bool, which must be rejected by evaluate() itself
    with self.assertRaises(errors.SymbolTypeError):
        symbol.evaluate({self.sym_name: not self.sym_value})
    self.assertIsNone(symbol.evaluate({self.sym_name: None}))
class ParserTestsBase(unittest.TestCase):
    """Base class for parser tests providing a shared parser and parsing helpers."""
    _parser = parser.Parser()
    context = engine.Context()

    def _parse(self, string, context=None):
        """Parse *string* with the shared parser.

        :param str string: The rule text to parse.
        :param context: The context to parse with. When None, :attr:`context` is used.
        :return: The parsed statement.
        """
        # honour the caller's context; previously it was ignored in favour of self.context
        if context is None:
            context = self.context
        return self._parser.parse(string, context)

    def assertStatementType(self, string, ast_expression):
        """Parse *string* and assert that the statement's expression is an instance of *ast_expression*.

        :param str string: The rule text to parse.
        :param ast_expression: The expected expression class.
        :return: The parsed statement.
        """
        statement = self._parse(string, self.context)
        self.assertIsInstance(statement, ast.Statement, msg='the parser did not return a statement')
        expression = statement.expression
        self.assertIsInstance(expression, ast_expression, msg='the statement expression is not the correct expression type')
        return statement
def test_ast_expression_string_attributes(self):
    """Each string attribute read through GetAttributeExpression yields the expected value."""
    # use one local context for both the symbol and the attribute expressions instead of a module global
    context = engine.Context()
    string = 'Rule Engine'
    symbol = ast.StringExpression(context, string)
    attributes = {
        'as_lower': string.lower(),
        'as_upper': string.upper(),
        'length': len(string),
    }
    for attribute_name, value in attributes.items():
        expression = ast.GetAttributeExpression(context, symbol, attribute_name)
        self.assertEqual(expression.evaluate(None), value, "attribute {} failed".format(attribute_name))
def test_tls_for_regex2(self):
    """``$re_groups`` seen by a rule reflect the match made in the evaluating thread, not another thread."""
    lock = threading.RLock()
    # NOTE(review): assumes testing_resolver acquires `lock` when resolving the `lock` symbol,
    # which is what lets this test pause the other thread mid-evaluation -- TODO confirm
    context = engine.Context(resolver=functools.partial(testing_resolver, lock))
    # raw string: '\w' is an invalid escape sequence in a normal string literal
    rule = engine.Rule(
        r'words =~ "(\w+) \w+" and lock and $re_groups[0] == "MainThread"',
        context=context)
    self.assertTrue(rule.evaluate({'words': 'MainThread Test'}))
    lock.release()
    with lock:
        thread = RuleThread(rule, {'words': 'AlternateThread Test'})
        self.assertTrue(rule.evaluate({'words': 'MainThread Test'}))
        lock.release()
    self.assertFalse(thread.join())
def test_parser_comprehension_expressions_errors(self):
    """Invalid comprehensions are rejected, and the iterable's element type reaches both sub-expressions."""
    # iterating over a non-iterable raises an evaluation error
    with self.assertRaises(errors.EvaluationError):
        self._parse('[null for something in null]', self.context)
    # assigning to something that is not a name raises a syntax error
    with self.assertRaises(errors.SyntaxError):
        self._parse('[null for null in something]', self.context)

    # declare `words` as an array of strings so `word` has a known type at parse time
    string_array = types.DataType.ARRAY(types.DataType.STRING)
    typed_context = engine.Context(type_resolver=engine.type_resolver_from_dict({'words': string_array}))

    # (accepted, rejected) pairs: first for the result expression, then for the condition expression
    checks = (
        ('[word =~ ".*" for word in words]', '[word % 2 for word in words]'),
        ('[null for word in words if word =~ ".*"]', '[null for word in words if word % 2]'),
    )
    for accepted, rejected in checks:
        self._parse(accepted, typed_context)
        with self.assertRaises(errors.EvaluationError):
            self._parse(rejected, typed_context)
def test_ast_expression_symbol_type_errors(self):
    """Typed symbols raise SymbolTypeError for incompatible values and accept compatible ones."""
    context = engine.Context(type_resolver=self._type_resolver)

    # STRING-typed symbol: a bool is rejected by evaluate() itself, None evaluates to None
    symbol = ast.SymbolExpression(context, self.sym_strname)
    self.assertIs(symbol.result_type, ast.DataType.STRING)
    self.assertEqual(symbol.name, self.sym_strname)
    with self.assertRaises(errors.SymbolTypeError):
        symbol.evaluate({self.sym_strname: not self.sym_strvalue})
    self.assertIsNone(symbol.evaluate({self.sym_strname: None}))

    # typed array symbol: the "_nullable" value is rejected, the plain value is accepted
    symbol = ast.SymbolExpression(context, self.sym_aryname)
    with self.assertRaises(errors.SymbolTypeError):
        symbol.evaluate({self.sym_aryname: self.sym_aryvalue_nullable})
    try:
        symbol.evaluate({self.sym_aryname: self.sym_aryvalue})
    except errors.SymbolTypeError:
        self.fail('raises SymbolTypeError when it should not')

    # array symbol named "_nontyped": all three array values are accepted
    symbol = ast.SymbolExpression(context, self.sym_aryname_nontyped)
    try:
        symbol.evaluate({self.sym_aryname_nontyped: self.sym_aryvalue})
        symbol.evaluate({self.sym_aryname_nontyped: self.sym_aryvalue_nontyped})
        symbol.evaluate({self.sym_aryname_nontyped: self.sym_aryvalue_nullable})
    except errors.SymbolTypeError:
        self.fail('raises SymbolTypeError when it should not')

    # array symbol named "_nullable": both the plain and the "_nullable" values are accepted
    symbol = ast.SymbolExpression(context, self.sym_aryname_nullable)
    try:
        symbol.evaluate({self.sym_aryname_nullable: self.sym_aryvalue})
        symbol.evaluate({self.sym_aryname_nullable: self.sym_aryvalue_nullable})
    except errors.SymbolTypeError:
        self.fail('raises SymbolTypeError when it should not')
def test_context_default_timezone(self):
    """The 'Local' and 'UTC' timezone names map to the matching dateutil tz objects."""
    expectations = (
        ('Local', dateutil.tz.tzlocal()),
        ('UTC', dateutil.tz.tzutc()),
    )
    for tz_name, expected in expectations:
        context = engine.Context(default_timezone=tz_name)
        self.assertEqual(context.default_timezone, expected)
def test_context_default_timezone_errors(self):
    """An unknown timezone name raises ValueError; an integer timezone raises TypeError."""
    for expected_error, bad_timezone in ((ValueError, 'doesnotexist'), (TypeError, 600)):
        with self.assertRaises(expected_error):
            engine.Context(default_timezone=bad_timezone)
class AstTests(unittest.TestCase):
    """Parse rule text with ``parser.Parser`` and check evaluation, type errors and reduction of the AST."""
    context = engine.Context()
    thing = {'age': 21.0, 'name': 'Alice'}

    def test_ast_evaluates_arithmetic_comparisons(self):
        parser_ = parser.Parser()
        statement = parser_.parse('age >= 21', self.context)
        self.assertTrue(statement.evaluate(self.thing))
        statement = parser_.parse('age > 100', self.context)
        self.assertFalse(statement.evaluate(self.thing))

    def test_ast_evaluates_logic(self):
        parser_ = parser.Parser()
        self.assertTrue(parser_.parse('true and true', self.context).evaluate(None))
        self.assertFalse(parser_.parse('true and false', self.context).evaluate(None))
        self.assertTrue(parser_.parse('true or false', self.context).evaluate(None))
        self.assertFalse(parser_.parse('false or false', self.context).evaluate(None))

    def test_ast_evaluates_fuzzy_comparisons(self):
        parser_ = parser.Parser()
        statement = parser_.parse('name =~ ".lic."', self.context)
        self.assertTrue(statement.evaluate(self.thing))
        statement = parser_.parse('name =~~ "lic"', self.context)
        self.assertTrue(statement.evaluate(self.thing))

    def test_ast_evaluates_string_comparisons(self):
        parser_ = parser.Parser()
        statement = parser_.parse('name == "Alice"', self.context)
        self.assertTrue(statement.evaluate(self.thing))
        statement = parser_.parse('name == "calie"', self.context)
        self.assertFalse(statement.evaluate(self.thing))

    def test_ast_evaluates_unary_not(self):
        parser_ = parser.Parser()
        statement = parser_.parse('not false', self.context)
        self.assertTrue(statement.evaluate(None))
        statement = parser_.parse('not true', self.context)
        self.assertFalse(statement.evaluate(None))
        statement = parser_.parse('true and not false', self.context)
        self.assertTrue(statement.evaluate(None))
        statement = parser_.parse('false and not true', self.context)
        self.assertFalse(statement.evaluate(None))

    def test_ast_evaluates_unary_uminus(self):
        parser_ = parser.Parser()
        statement = parser_.parse('-(2 * 5)', self.context)
        self.assertEqual(statement.evaluate(None), -10)

    def test_ast_raises_type_mismatch_arithmetic_comparisons(self):
        parser_ = parser.Parser()
        statement = parser_.parse('symbol < 1', self.context)
        with self.assertRaises(errors.EvaluationError):
            statement.evaluate({'symbol': 'string'})
        with self.assertRaises(errors.EvaluationError):
            statement.evaluate({'symbol': True})
        self.assertTrue(statement.evaluate({'symbol': 0.0}))

    def test_ast_raises_type_mismatch_bitwise(self):
        parser_ = parser.Parser()
        statement = parser_.parse('symbol << 1', self.context)
        # runtime values that are not integral numbers are rejected at evaluation time
        with self.assertRaises(errors.EvaluationError):
            statement.evaluate({'symbol': 1.1})
        with self.assertRaises(errors.EvaluationError):
            statement.evaluate({'symbol': 'string'})
        with self.assertRaises(errors.EvaluationError):
            statement.evaluate({'symbol': True})
        # build inf/nan locally instead of relying on module-level names
        with self.assertRaises(errors.EvaluationError):
            statement.evaluate({'symbol': float('inf')})
        with self.assertRaises(errors.EvaluationError):
            statement.evaluate({'symbol': float('nan')})
        self.assertEqual(statement.evaluate({'symbol': 1}), 2)

        # literal operands of the wrong type are rejected at parse time
        with self.assertRaises(errors.EvaluationError):
            parser_.parse('symbol << 1.1', self.context)
        with self.assertRaises(errors.EvaluationError):
            parser_.parse('symbol << "string"', self.context)
        with self.assertRaises(errors.EvaluationError):
            parser_.parse('symbol << true', self.context)
        with self.assertRaises(errors.EvaluationError):
            parser_.parse('inf << 1', self.context)
        with self.assertRaises(errors.EvaluationError):
            parser_.parse('nan << 1', self.context)

    def test_ast_raises_type_mismatch_fuzzy_comparisons(self):
        parser_ = parser.Parser()
        statement = parser_.parse('symbol =~ "string"', self.context)
        with self.assertRaises(errors.EvaluationError):
            statement.evaluate({'symbol': 1.1})
        with self.assertRaises(errors.EvaluationError):
            statement.evaluate({'symbol': True})
        self.assertTrue(statement.evaluate({'symbol': 'string'}))
        with self.assertRaises(errors.EvaluationError):
            parser_.parse('"string" =~ 1', self.context)
        with self.assertRaises(errors.EvaluationError):
            parser_.parse('"string" =~ true', self.context)

    def test_ast_reduces_arithmetic(self):
        thing = {'one': 1, 'two': 2}
        parser_ = parser.Parser()
        # literal-only arithmetic is folded to a single FloatExpression at parse time
        statement = parser_.parse('1 + 2', self.context)
        self.assertIsInstance(statement.expression, ast.FloatExpression)
        self.assertEqual(statement.evaluate(None), 3)
        # arithmetic involving a symbol stays an ArithmeticExpression
        statement = parser_.parse('one + 2', self.context)
        self.assertIsInstance(statement.expression, ast.ArithmeticExpression)
        self.assertEqual(statement.evaluate(thing), 3)
        statement = parser_.parse('1 + two', self.context)
        self.assertIsInstance(statement.expression, ast.ArithmeticExpression)
        self.assertEqual(statement.evaluate(thing), 3)

    def test_ast_reduces_array_literals(self):
        parser_ = parser.Parser()
        statement = parser_.parse('[1, 2, 1 + 2]', self.context)
        self.assertIsInstance(statement.expression, ast.ArrayExpression)
        self.assertTrue(statement.expression.is_reduced)
        self.assertEqual(statement.evaluate(None), (1, 2, 3))
        statement = parser_.parse('[foobar]', self.context)
        self.assertIsInstance(statement.expression, ast.ArrayExpression)
        self.assertFalse(statement.expression.is_reduced)

    def test_ast_reduces_attributes(self):
        parser_ = parser.Parser()
        statement = parser_.parse('"foobar".length', self.context)
        self.assertIsInstance(statement.expression, ast.FloatExpression)
        self.assertEqual(statement.evaluate(None), 6)

    def test_ast_reduces_bitwise(self):
        parser_ = parser.Parser()
        statement = parser_.parse('1 << 2', self.context)
        self.assertIsInstance(statement.expression, ast.FloatExpression)
        self.assertEqual(statement.evaluate(None), 4)

    def test_ast_reduces_ternary(self):
        parser_ = parser.Parser()
        statement = parser_.parse('true ? 1 : 0', self.context)
        self.assertIsInstance(statement.expression, ast.FloatExpression)
        self.assertEqual(statement.evaluate(None), 1)

    def test_ast_type_hints(self):
        parser_ = parser.Parser()
        cases = (
            # rule text, compatible type for `symbol`, incompatible type for `symbol`
            ('symbol << 1', ast.DataType.FLOAT, ast.DataType.STRING),
            ('symbol + 1', ast.DataType.FLOAT, ast.DataType.STRING),
            ('symbol[1]', ast.DataType.STRING, ast.DataType.FLOAT),
            ('symbol[1]', ast.DataType.ARRAY, ast.DataType.FLOAT),
            ('symbol =~ "foo"', ast.DataType.STRING, ast.DataType.FLOAT),
        )
        for case, type_is, type_is_not in cases:
            parser_.parse(case, self.context)
            context = engine.Context(type_resolver=engine.type_resolver_from_dict({'symbol': type_is}))
            parser_.parse(case, context)
            context = engine.Context(type_resolver=engine.type_resolver_from_dict({'symbol': type_is_not}))
            with self.assertRaises(errors.EvaluationError, msg='case: ' + case):
                parser_.parse(case, context)
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # import datetime import decimal import unittest import rule_engine.ast as ast import rule_engine.engine as engine __all__ = ('LiteralExpressionTests', ) context = engine.Context() context.builtins = engine.Builtins.from_defaults( {'test': { 'one': 1.0, 'two': 2.0 }}) # literal expressions which should evaluate to false falseish = (ast.ArrayExpression(context, tuple()), ast.BooleanExpression(context, False), ast.FloatExpression(context, 0.0), ast.NullExpression(context), ast.StringExpression(context, '')) # literal expressions which should evaluate to true trueish = (ast.ArrayExpression(context, tuple( (ast.NullExpression(context), ))), ast.ArrayExpression(context, tuple((ast.FloatExpression(context, 1.0), ))),