def test_required_arg():
    """A collection method whose required argument is not supplied raises a
    ValueError whose message mentions ``bank_required``."""
    s = ast_lambda("e.Jets_req()")
    objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    with pytest.raises(ValueError) as e:
        remap_by_types(objs, "e", Event, s)

    # Inspect the exception message itself, not the ExceptionInfo wrapper's repr.
    assert "bank_required" in str(e.value)
def test_function_with_missing_arg():
    """Calling a ``func_adl_callable`` without its required argument raises a
    ValueError whose message names the missing parameter ``my_x``."""

    @func_adl_callable()
    def MySqrt(my_x: float) -> float:
        ...

    s = ast_lambda("MySqrt()")
    objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    with pytest.raises(ValueError) as e:
        remap_by_types(objs, "e", Event, s)

    # Inspect the exception message itself, not the ExceptionInfo wrapper's repr.
    assert "my_x" in str(e.value)
def test_index_callback_prop_not_dec():
    """Indexing a plain property that is NOT decorated with
    ``func_adl_parameterized_call`` raises a ValueError whose message mentions
    both the property (``info``) and the owning class (``TEvent``)."""

    class TEvent:
        @property
        def info(self):
            ...

    s = ast_lambda("e.info['fork'](55)")
    objs = ObjectStream[TEvent](ast.Name(id="e", ctx=ast.Load()))

    with pytest.raises(ValueError) as e:
        remap_by_types(objs, "e", TEvent, s)

    # Inspect the exception message itself, not the ExceptionInfo wrapper's repr.
    message = str(e.value)
    assert "info" in message
    assert "TEvent" in message
def test_index_callback_modify_ast_nested():
    """Indexed callback - make sure the ast can be correctly modified by the
    callback when the indexed property is used inside a Select lambda."""

    def my_callback(
        s: ObjectStream[T], a: ast.Call, param_1: str
    ) -> Tuple[ObjectStream[T], ast.Call, Type]:
        # Shallow-copy so that reassigning ``func`` below does not touch the
        # Call node we were handed.
        new_a = copy.copy(a)
        assert isinstance(a.func, ast.Attribute)
        # Keep the same base object but call ``dude`` instead of the indexed property.
        new_a.func = ast.Attribute(value=a.func.value, attr="dude", ctx=a.func.ctx)
        return (s, new_a, float)

    class MyJet:
        @func_adl_parameterized_call(my_callback)
        @property
        def info(self):
            ...

    class TEvent:
        def Jets(self) -> Iterable[MyJet]:
            ...

    s = ast_lambda("e.Jets().Select(lambda j: j.info['fork'](55))")
    objs = ObjectStream[TEvent](ast.Name(id="e", ctx=ast.Load()))

    _, new_s, _ = remap_by_types(objs, "e", TEvent, s)

    assert ast.dump(new_s) == ast.dump(
        ast_lambda("e.Jets().Select(lambda j: j.dude(55))")
    )
def test_index_callback_1arg():
    """A single index value ('fork') is forwarded to the parameterized callback,
    and the stream, AST and type the callback returns are all used."""
    captured_index = None

    def my_callback(
        s: ObjectStream[T], a: ast.Call, param_1: str
    ) -> Tuple[ObjectStream[T], ast.Call, Type]:
        nonlocal captured_index
        captured_index = param_1
        return (s.MetaData({"k": "stuff"}), a, float)

    class TEvent:
        @func_adl_parameterized_call(my_callback)
        @property
        def info(self):
            ...

    query = ast_lambda("e.info['fork'](55)")
    stream = ObjectStream[TEvent](ast.Name(id="e", ctx=ast.Load()))

    out_stream, out_ast, out_type = remap_by_types(stream, "e", TEvent, query)

    expected_ast = ast_lambda("e.info(55)")
    assert ast.dump(out_ast) == ast.dump(expected_ast)
    expected_query = ast_lambda("MetaData(e, {'k': 'stuff'})")
    assert ast.dump(out_stream.query_ast) == ast.dump(expected_query)
    assert out_type == float
    assert captured_index == "fork"
def test_method_on_collection_bool():
    """Call a method on a returned object (``e.MET().isGood()``) and check the
    expression is typed as bool."""
    s = ast_lambda("e.MET().isGood()")
    objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    _, _, expr_type = remap_by_types(objs, "e", Event, s)

    assert expr_type == bool
def test_math_method(caplog):
    """Applying a math builtin (``abs``) to a method result must not log any
    warnings."""
    caplog.set_level(logging.WARNING)

    s = ast_lambda("abs(e.MET.pxy())")
    objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    # Only the absence of warnings is checked here; the results are unused.
    remap_by_types(objs, "e", Event, s)

    assert len(caplog.text) == 0
def test_bogus_method():
    """An unknown (untyped) method passes through unchanged and is typed as Any."""
    text = "e.Jetsss('default')"
    query = ast_lambda(text)
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    out_stream, out_ast, out_type = remap_by_types(stream, "e", Event, query)

    # Nothing is rewritten and the stream gains no metadata.
    assert ast.dump(out_ast) == ast.dump(ast_lambda(text))
    assert ast.dump(out_stream.query_ast) == ast.dump(ast_lambda("e"))
    assert out_type == Any
def test_plain_object_method():
    """A typed method on a plain (non-event) object is left unchanged and is
    typed from its annotation (float)."""
    text = "j.pt()"
    s = ast_lambda(text)
    objs = ObjectStream[Jet](ast.Name(id="j", ctx=ast.Load()))

    new_objs, new_s, expr_type = remap_by_types(objs, "j", Jet, s)

    assert ast.dump(new_s) == ast.dump(ast_lambda(text))
    assert ast.dump(new_objs.query_ast) == ast.dump(ast_lambda("j"))
    assert expr_type == float
def test_method_modify_ast():
    """Calling ``e.EventNumber()`` rewrites the call AST to ``e.EventNumber(20)``,
    leaves the stream unchanged, and types the result as int."""
    s = ast_lambda("e.EventNumber()")
    objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s)

    assert ast.dump(new_s) == ast.dump(ast_lambda("e.EventNumber(20)"))
    assert ast.dump(new_objs.query_ast) == ast.dump(ast_lambda("e"))
    assert expr_type == int
def test_collection_with_default():
    """A collection called without arguments has its default argument filled in
    and attaches its metadata to the stream."""
    query = ast_lambda("e.Jets()")
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    out_stream, out_ast, out_type = remap_by_types(stream, "e", Event, query)

    want_ast = ast_lambda("e.Jets('default')")
    want_query = ast_lambda("MetaData(e, {'j': 'stuff'})")
    assert ast.dump(out_ast) == ast.dump(want_ast)
    assert ast.dump(out_stream.query_ast) == ast.dump(want_query)
    assert out_type == Iterable[Jet]
def test_method_on_collection():
    """Call a method on a returned object (``e.MET().pxy()``). The AST is
    unchanged, the method's metadata is attached to the stream, and the result
    is typed as float."""
    text = "e.MET().pxy()"
    s = ast_lambda(text)
    objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s)

    assert ast.dump(new_s) == ast.dump(ast_lambda(text))
    assert ast.dump(new_objs.query_ast) == ast.dump(
        ast_lambda("MetaData(e, {'j': 'pxy stuff'})")
    )
    assert expr_type == float
def test_collection_First(caplog):
    """``First()`` on a jet collection yields a single Jet and logs no warnings."""
    caplog.set_level(logging.WARNING)

    query = ast_lambda("e.Jets().First()")
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    _stream, _ast, result_type = remap_by_types(stream, "e", Event, query)

    assert result_type == Jet
    assert not caplog.text
def test_collection_CustomIterable_fallback(caplog):
    """A ``Where`` on a collection returned as a custom iterable (which defines
    its own terminals) still types as ``Iterable[Jet]`` and logs no warnings."""
    caplog.set_level(logging.WARNING)

    query = ast_lambda("e.JetsIterSub().Where(lambda j: j.pt() > 10)")
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    _stream, _ast, result_type = remap_by_types(stream, "e", Event, query)

    assert result_type == Iterable[Jet]
    assert not caplog.text
def test_method_with_no_inital_type(caplog):
    """When the object type is ``Any``, method calls pass through unchanged, are
    typed as ``Any``, and no warnings are logged."""
    caplog.set_level(logging.WARNING)

    text = "e.MET_bogus().pxy()"
    s = ast_lambda(text)
    objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    # Deliberately pass ``Any`` (not ``Event``) as the starting type.
    new_objs, new_s, expr_type = remap_by_types(objs, "e", Any, s)

    assert ast.dump(new_s) == ast.dump(ast_lambda(text))
    assert ast.dump(new_objs.query_ast) == ast.dump(ast_lambda("e"))
    assert expr_type == Any
    assert len(caplog.text) == 0
def test_collection_Select(caplog):
    """A ``Select`` over jets that maps each jet to ``pt()`` types as
    ``Iterable[float]`` and logs no warnings."""
    caplog.set_level(logging.WARNING)

    query = ast_lambda("e.Jets().Select(lambda j: j.pt())")
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    _stream, _ast, result_type = remap_by_types(stream, "e", Event, query)

    assert result_type == Iterable[float]
    assert not caplog.text
def test_collection_lambda_not_followed(caplog):
    """A lambda argument that cannot be tracked produces a warning that mentions
    the lambda and the method receiving it."""
    caplog.set_level(logging.WARNING)

    query = ast_lambda("e.MyLambdaCallback(lambda f: True)")
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    _stream, _ast, result_type = remap_by_types(stream, "e", Event, query)

    assert result_type == int
    log_text = caplog.text
    assert "lambda" in log_text.lower()
    assert "MyLambdaCallback" in log_text
def test_index_callback_2arg():
    """Indexed callback - when the property is indexed with two values
    (``['fork', 22]``), the callback receives them together as one tuple."""
    param_1_capture = None

    def my_callback(
        s: ObjectStream[T], a: ast.Call, param_1: Tuple[str, int]
    ) -> Tuple[ObjectStream[T], ast.Call, Type]:
        # Record what the framework passed in so the test can check it.
        nonlocal param_1_capture
        param_1_capture = param_1
        return (s.MetaData({"k": "stuff"}), a, float)

    class TEvent:
        @func_adl_parameterized_call(my_callback)
        @property
        def info(self):
            ...

    s = ast_lambda("e.info['fork', 22](55)")
    objs = ObjectStream[TEvent](ast.Name(id="e", ctx=ast.Load()))

    remap_by_types(objs, "e", TEvent, s)

    assert param_1_capture == ("fork", 22)
def test_collection_Select_meta(caplog):
    """Selecting over ``TrackStuffs()`` types as ``Iterable[float]``, attaches the
    collection's metadata to the stream, and logs no warnings."""
    caplog.set_level(logging.WARNING)

    query = ast_lambda("e.TrackStuffs().Select(lambda t: t.pt())")
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    out_stream, _ast, result_type = remap_by_types(stream, "e", Event, query)

    assert result_type == Iterable[float]
    want_query = ast_lambda("MetaData(e, {'t': 'track stuff'})")
    assert ast.dump(out_stream.query_ast) == ast.dump(want_query)
    assert not caplog.text
def test_shortcut_nested_callback():
    """For a method with a simple return (such as ``Where``), lambdas passed into
    the method are still processed, so the collection's metadata is picked up."""
    text = "e.TrackStuffs().Where(lambda t: abs(t.pt()) > 10)"
    query = ast_lambda(text)
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    out_stream, out_ast, out_type = remap_by_types(stream, "e", Event, query)

    assert ast.dump(out_ast) == ast.dump(ast_lambda(text))
    want_query = ast_lambda("MetaData(e, {'t': 'track stuff'})")
    assert ast.dump(out_stream.query_ast) == ast.dump(want_query)
    assert out_type == Iterable[TrackStuff]
def test_function_with_keyword():
    """A keyword argument to a ``func_adl_callable`` is rewritten into a
    positional argument."""

    @func_adl_callable()
    def MySqrt(x: float = 20) -> float:
        ...

    query = ast_lambda("MySqrt(x=15)")
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    out_stream, out_ast, out_type = remap_by_types(stream, "e", Event, query)

    want_ast = ast_lambda("MySqrt(15)")
    want_query = ast_lambda("e")
    assert ast.dump(out_ast) == ast.dump(want_ast)
    assert ast.dump(out_stream.query_ast) == ast.dump(want_query)
    assert out_type == float
def test_index_callback_bad_prop():
    """Indexing a property name that does not exist on the class (``infoo``)
    raises an AttributeError whose message mentions both the missing name and
    the class (``TEvent``)."""

    def my_callback(
        s: ObjectStream[T], a: ast.Call, param_1: str
    ) -> Tuple[ObjectStream[T], ast.Call, Type]:
        # Attached only to ``info``; the query below uses ``infoo``, so this is
        # not expected to run.
        return (s.MetaData({"k": "stuff"}), a, float)

    class TEvent:
        @func_adl_parameterized_call(my_callback)
        @property
        def info(self):
            ...

    s = ast_lambda("e.infoo['fork'](55)")
    objs = ObjectStream[TEvent](ast.Name(id="e", ctx=ast.Load()))

    with pytest.raises(AttributeError) as e:
        remap_by_types(objs, "e", TEvent, s)

    # Inspect the exception message itself, not the ExceptionInfo wrapper's repr.
    message = str(e.value)
    assert "infoo" in message
    assert "TEvent" in message
def test_function_with_default_inside():
    """A ``func_adl_callable`` with a defaulted argument, called inside a
    ``Select`` lambda, gets its default filled in; the enclosing collection also
    gets its default and metadata."""

    @func_adl_callable()
    def MySqrt(x: float = 20) -> float:
        ...

    query = ast_lambda("e.Jets().Select(lambda j: MySqrt())")
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

    out_stream, out_ast, out_type = remap_by_types(stream, "e", Event, query)

    want_ast = ast_lambda("e.Jets('default').Select(lambda j: MySqrt(20))")
    want_query = ast_lambda("MetaData(e, {'j': 'stuff'})")
    assert ast.dump(out_ast) == ast.dump(want_ast)
    assert ast.dump(out_stream.query_ast) == ast.dump(want_query)
    assert out_type == Iterable[float]
def test_shortcut_2nested_callback():
    """Like the single-nested case, but the ``Where`` lambda sits inside a
    second ``Select``. Its contents must still be processed so the collection's
    metadata is picked up."""
    text = (
        "ds.Select(lambda e: e.TrackStuffs())"
        ".Select(lambda ts: ts.Where(lambda t: t.pt() > 10))"
    )
    query = ast_lambda(text)
    stream = ObjectStream[Iterable[Event]](ast.Name(id="ds", ctx=ast.Load()))

    out_stream, out_ast, out_type = remap_by_types(
        stream, "ds", Iterable[Event], query
    )

    # The query text is expected to come back unchanged.
    assert ast.dump(out_ast) == ast.dump(ast_lambda(text))
    want_query = ast_lambda("MetaData(ds, {'t': 'track stuff'})")
    assert ast.dump(out_stream.query_ast) == ast.dump(want_query)
    assert out_type == Iterable[Iterable[TrackStuff]]
def test_collection_Custom_Method_multiple_args(caplog):
    """A method defined on a registered custom collection that takes several
    arguments is typed from its annotation and logs no warnings."""
    caplog.set_level(logging.WARNING)

    M = TypeVar("M")

    @register_func_adl_os_collection
    class CustomCollection(ObjectStream[M]):
        def __init__(self, a: ast.AST, item_type=Any):
            super().__init__(a, item_type)

        def MyFirst(self, arg1: int, arg2: int) -> int:
            ...

    query = ast_lambda("e.Jets().MyFirst(1,3)")
    stream = CustomCollection[Event](ast.Name(id="e", ctx=ast.Load()))

    _stream, _ast, result_type = remap_by_types(stream, "e", Event, query)

    assert result_type == int
    assert not caplog.text
def test_function_with_processor():
    """A ``func_adl_callable`` with a processor lets the processor attach
    metadata to the stream. The call AST, item type and return type are
    preserved."""

    def MySqrtProcessor(
        s: ObjectStream[T], a: ast.Call
    ) -> Tuple[ObjectStream[T], ast.Call]:
        return s.MetaData({"j": "func_stuff"}), a

    @func_adl_callable(MySqrtProcessor)
    def MySqrt(x: float) -> float:
        ...

    text = "MySqrt(2)"
    query = ast_lambda(text)
    stream = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()), item_type=Event)

    out_stream, out_ast, out_type = remap_by_types(stream, "e", Event, query)

    assert ast.dump(out_ast) == ast.dump(ast_lambda(text))
    want_query = ast_lambda("MetaData(e, {'j': 'func_stuff'})")
    assert ast.dump(out_stream.query_ast) == ast.dump(want_query)
    assert out_stream.item_type == Event
    assert out_type == float
def test_collection_Custom_Method_Jet(caplog):
    """A custom collection method whose return annotation is the element TypeVar
    resolves to the element type (Jet) and logs no warnings."""
    caplog.set_level(logging.WARNING)

    M = TypeVar("M")

    class CustomCollection_Jet(ObjectStream[M]):
        def __init__(self, a: ast.AST, item_type):
            super().__init__(a, item_type)

        def MyFirst(self) -> M:
            ...

    register_func_adl_os_collection(CustomCollection_Jet)

    query = ast_lambda("e.Jets().MyFirst()")
    stream = CustomCollection_Jet[Event](ast.Name(id="e", ctx=ast.Load()), Event)

    _stream, _ast, result_type = remap_by_types(stream, "e", Event, query)

    assert result_type == Jet
    assert not caplog.text
def test_shortcut_nested_with_iterable_subclass():
    """A ``Select`` nested in another ``Select``'s lambda is followed when the
    collection is typed via ``_itsb_FADLStream``, presumably an iterable
    subclass, so the innermost ``t.pt()`` resolves to float."""

    class MyEvent:
        def MyTracks(self) -> _itsb_FADLStream[_itsb_MyTrack]:
            ...

    s = ast_lambda(
        "ds.Select(lambda e: e.MyTracks()).Select(lambda ts: ts.Select(lambda t: t.pt()))"
    )
    objs = ObjectStream[Iterable[MyEvent]](ast.Name(id="ds", ctx=ast.Load()))

    _, new_s, expr_type = remap_by_types(objs, "ds", Iterable[MyEvent], s)

    assert ast.dump(new_s) == ast.dump(
        ast_lambda(
            "ds.Select(lambda e: e.MyTracks()).Select(lambda ts: ts.Select(lambda t: t.pt()))"
        )
    )
    # NOTE(review): the stream's query_ast is not checked here. A former,
    # commented-out check compared it to a `MetaData(e, ...)` expression, which
    # cannot match this `ds`-rooted stream. TODO decide whether a metadata
    # assertion belongs in this test.
    assert expr_type == Iterable[Iterable[float]]
def test_collection_Custom_Method_default(caplog):
    """A custom collection method with a defaulted argument gets that default
    filled in. It is typed from its annotation and logs no warnings."""
    caplog.set_level(logging.WARNING)

    M = TypeVar("M")

    @register_func_adl_os_collection
    class CustomCollection_default(ObjectStream[M]):
        def __init__(self, a: ast.AST, item_type):
            super().__init__(a, item_type)

        def Take(self, n: int = 5) -> ObjectStream[M]:
            ...

    query = ast_lambda("e.Jets().Take()")
    stream = CustomCollection_default[Event](ast.Name(id="e", ctx=ast.Load()), Event)

    _stream, out_ast, result_type = remap_by_types(stream, "e", Event, query)

    assert result_type == ObjectStream[Jet]
    want_ast = ast_lambda("e.Jets('default').Take(5)")
    assert ast.dump(out_ast) == ast.dump(want_ast)
    assert not caplog.text
def return_type_test(expr: str, arg_type: type, expected_type: type):
    """Helper: remap ``expr`` over a stream named ``e`` of ``arg_type`` and
    assert the resulting expression type equals ``expected_type``."""
    query = ast_lambda(expr)
    stream = ObjectStream(ast.Name(id="e", ctx=ast.Load()), arg_type)

    _stream, _ast, actual_type = remap_by_types(stream, "e", arg_type, query)

    assert actual_type == expected_type