def test_no_contradiction_multi() -> None:
    """Construct a message whose F3 branches again on the value of F0.

    F0 splits into F1/F2, both rejoin at F3, and F3 branches to F4/F5 using
    the same ``F0 = 1`` / ``F0 = 2`` conditions.  The test passes if
    ``Message`` construction raises nothing.
    """

    def f0_equals(value):
        # Fresh condition object per link, as in a literal per-link construction.
        return Equal(Variable("F0"), Number(value))

    structure = [
        Link(INITIAL, Field("F0")),
        Link(Field("F0"), Field("F1"), condition=f0_equals(1)),
        Link(Field("F0"), Field("F2"), condition=f0_equals(2)),
        Link(Field("F1"), Field("F3")),
        Link(Field("F2"), Field("F3")),
        Link(Field("F3"), Field("F4"), condition=f0_equals(1)),
        Link(Field("F3"), Field("F5"), condition=f0_equals(2)),
        Link(Field("F4"), FINAL),
        Link(Field("F5"), FINAL),
    ]
    types = {Field(f"F{index}"): RANGE_INTEGER for index in range(6)}
    Message("P.M", structure, types)
def create_invalid_function() -> UnitPart:
    """Generate the ``Invalid`` expression function.

    The generated function takes ``Ctx : Context`` and ``Fld : Field`` and is
    ``Ctx.Cursors (Fld).State = S_Invalid or ... = S_Incomplete``.
    """
    specification = FunctionSpecification(
        "Invalid",
        "Boolean",
        [Parameter(["Ctx"], "Context"), Parameter(["Fld"], "Field")],
    )

    def state_equals(state_name):
        cursor_state = Selected(Indexed(Variable("Ctx.Cursors"), Variable("Fld")), "State")
        return Equal(cursor_state, Variable(state_name))

    expression = Or(state_equals("S_Invalid"), state_equals("S_Incomplete"))
    return UnitPart(
        [SubprogramDeclaration(specification)],
        [ExpressionFunctionDeclaration(specification, expression)],
    )
def test_message_field_condition(self) -> None:
    """Check ``ETHERNET_FRAME.field_condition`` for INITIAL and three fields."""

    def tpid_equal():
        return Equal(Variable("Type_Length_TPID"), Number(33024, 16))

    def tpid_not_equal():
        return NotEqual(Variable("Type_Length_TPID"), Number(33024, 16))

    def type_length_condition():
        # Either value of Type_Length_TPID leads to Type_Length.
        return Or(tpid_not_equal(), tpid_equal())

    expectations = [
        (INITIAL, TRUE),
        (Field("TPID"), tpid_equal()),
        (Field("Type_Length"), type_length_condition()),
        (
            Field("Payload"),
            Or(
                And(type_length_condition(), LessEqual(Variable("Type_Length"), Number(1500))),
                And(type_length_condition(), GreaterEqual(Variable("Type_Length"), Number(1536))),
            ),
        ),
    ]
    for field, expected in expectations:
        self.assertEqual(ETHERNET_FRAME.field_condition(field), expected)
def constraints(self, name: str, proof: bool = False) -> Expr:
    """Return the constraints for a variable ``name`` of this type.

    Args:
        name: Name of the variable the constraints are stated for.
        proof: When true, return the proof constraints. These bind every
            literal name to its value and require ``name`` to equal one of
            the literals. When false, return ``TRUE``.

    Returns:
        The constraint expression.
    """
    if not proof:
        return TRUE
    literal_bindings = And(
        *[Equal(Variable(literal), value) for literal, value in self.literals.items()]
    )
    # Iterating the mapping directly yields its keys (the literal names).
    is_one_of_literals = Or(
        *[Equal(Variable(name), Variable(literal)) for literal in self.literals]
    )
    return And(literal_bindings, is_one_of_literals)
def test_exclusive_with_length_invalid() -> None:
    """Two outgoing links of F1 guarded by the same condition must be rejected.

    Both ``F1 -> Final`` and ``F1 -> F2`` require ``F1'Length = 32``. The
    expected error is reported at F1's location (98:10) and lists each
    condition at its own location.
    """
    f1 = Field(ID("F1", Location((98, 10))))

    def length_is_32(position):
        return Equal(Length("F1"), Number(32), Location(position))

    structure = [
        Link(INITIAL, f1, length=Number(32)),
        Link(f1, FINAL, condition=length_is_32((10, 2))),
        Link(f1, Field("F2"), condition=length_is_32((12, 4))),
        Link(Field("F2"), FINAL),
    ]
    types = {Field("F1"): Opaque(), Field("F2"): RANGE_INTEGER}
    expected_error = (
        r"^"
        r'<stdin>:98:10: model: error: conflicting conditions for field "F1"\n'
        r"<stdin>:10:2: model: info: condition 0 [(]F1 -> Final[)]: F1\'Length = 32\n"
        r"<stdin>:12:4: model: info: condition 1 [(]F1 -> F2[)]: F1\'Length = 32"
        r"$"
    )
    assert_message_model_error(structure, types, expected_error)
def create_tlv_message() -> Message:
    """Build the ``TLV.Message`` model.

    After ``Tag``, a ``Msg_Data`` tag continues with ``Length`` and then a
    ``Value`` of ``Length * 8``. A ``Msg_Error`` tag goes straight to Final.
    """
    tag_type = Enumeration(
        "TLV.Tag", {"Msg_Data": Number(1), "Msg_Error": Number(3)}, Number(2), False
    )
    length_type = ModularInteger("TLV.Length", Pow(Number(2), Number(14)))

    tag_field = Field("Tag")
    length_field = Field("Length")
    value_field = Field("Value")

    structure = [
        Link(INITIAL, tag_field),
        Link(tag_field, length_field, Equal(Variable("Tag"), Variable("Msg_Data"))),
        Link(tag_field, FINAL, Equal(Variable("Tag"), Variable("Msg_Error"))),
        Link(length_field, value_field, length=Mul(Variable("Length"), Number(8))),
        Link(value_field, FINAL),
    ]
    types = {tag_field: tag_type, length_field: length_type, value_field: Payload()}
    return Message("TLV.Message", structure, types)
def bounded_composite_setter_preconditions(message: Message, field: Field) -> Sequence[Expr]:
    """Return the preconditions of the bounded setter for composite ``field``.

    The list contains, in this order:

    * ``Field_Condition (Ctx, (Fld => <field>))``, with ``Length`` appended
      when ``common.length_dependent_condition(message)`` is true;
    * ``Available_Space (Ctx, <field>) >= Length``;
    * ``Field_First (Ctx, <field>) + Length <= Last(TYPES_BIT_INDEX) / 2``;
    * a disjunction over the incoming links of ``field`` whose length
      contains ``Message'Last``. Each disjunct requires ``Valid`` for every
      message field used in the link condition, plus the link condition with
      those fields replaced by their ``Get_<F> (Ctx)`` calls;
    * ``Field_First (...) mod Size(TYPES_BYTE) = 1``;
    * ``Length mod Size(TYPES_BYTE) = 0``.
    """

    def incoming_link_condition(link):
        # Message fields referenced by this link's condition. Named `other`
        # so the `field` parameter is not shadowed.
        referenced = [
            other
            for other in message.fields
            if Variable(other.name) in link.condition.variables()
        ]
        return And(
            *[
                Call("Valid", [Variable("Ctx"), Variable(other.affixed_name)])
                for other in referenced
            ],
            link.condition.substituted(
                mapping={
                    Variable(other.name): Call(f"Get_{other.name}", [Variable("Ctx")])
                    for other in referenced
                }
            ),
        )

    return [
        Call(
            "Field_Condition",
            [Variable("Ctx"), NamedAggregate(("Fld", Variable(field.affixed_name)))]
            + ([Variable("Length")] if common.length_dependent_condition(message) else []),
        ),
        GreaterEqual(
            Call("Available_Space", [Variable("Ctx"), Variable(field.affixed_name)]),
            Variable("Length"),
        ),
        LessEqual(
            Add(
                Call("Field_First", [Variable("Ctx"), Variable(field.affixed_name)]),
                Variable("Length"),
            ),
            Div(Last(const.TYPES_BIT_INDEX), Number(2)),
        ),
        Or(
            *[
                incoming_link_condition(link)
                for link in message.incoming(field)
                if Last("Message") in link.length
            ]
        ),
        Equal(
            Mod(
                Call("Field_First", [Variable("Ctx"), Variable(field.affixed_name)]),
                Size(const.TYPES_BYTE),
            ),
            Number(1),
        ),
        Equal(
            Mod(Variable("Length"), Size(const.TYPES_BYTE)),
            Number(0),
        ),
    ]
def test_tlv_message_with_not_operator_exhausting() -> None:
    """Check that a deeply nested chain of ``Not`` makes simplification give up.

    The ``Tag -> Final`` condition is wrapped in 32 extra ``Not`` operators
    by applying ``[Not, Not] * 16`` with ``reduce``. Building the PyRFLX
    model and parsing a message inside the ``with`` block is expected to
    raise ``FatalError``. The expected message says simplification failed
    "after `16` iterations" and shows the best-effort result.
    """
    message = Message(
        "TLV::Message_With_Not_Operator_Exhausting",
        [
            Link(INITIAL, Field("Tag")),
            Link(
                Field("Tag"),
                Field("Length"),
                Not(Not(Not(NotEqual(Variable("Tag"), Variable("Msg_Data"))))),
            ),
            Link(
                Field("Tag"),
                FINAL,
                # Fold the 32 Not constructors over the base condition, innermost first.
                reduce(
                    lambda acc, f: f(acc),
                    [Not, Not] * 16,
                    Not(
                        Or(
                            Not(
                                Not(
                                    Equal(Variable("Tag"), Variable("Msg_Data")))),
                            Not(Equal(Variable("Tag"), Variable("Msg_Error"))),
                        )),
                ),
            ),
            Link(Field("Length"), Field("Value"), size=Mul(Variable("Length"), Number(8))),
            Link(Field("Value"), FINAL),
        ],
        {
            Field("Tag"): TLV_TAG,
            Field("Length"): TLV_LENGTH,
            Field("Value"): OPAQUE
        },
    )
    # re.escape: the expected text contains regex metacharacters such as "(" and ")".
    with pytest.raises(
        FatalError,
        match=re.escape(
            "failed to simplify complex expression `not (not (not (not "
            "(not (not (not (not (not (not (not (not (not (not (not (not "
            "(not (not (not (not (not (not (not (not (not (not (not (not "
            "(not (not (not (not (not (not (not (Tag = TLV::Msg_Data))\n"
            " "
            "or not (Tag = TLV::Msg_Error))))))))))))))))))))))))))))))))))` "
            "after `16` iterations, best effort: "
            "`not (not (not (not (not (not (not (not (not (not (not (not (not "
            "(not (not (not (not (Tag = TLV::Msg_Data\n"
            " or Tag /= TLV::Msg_Error)))))))))))))))))`"),
    ):
        model = PyRFLX(model=Model([TLV_TAG, TLV_LENGTH, message]))
        pkg = model.package("TLV")
        msg = pkg.new_message("Message_With_Not_Operator_Exhausting")
        test_bytes = b"\x01\x00\x04\x00\x00\x00\x00"
        msg.parse(test_bytes)
def test_merge_message_recursive() -> None:
    """Check that merging ``M_DBL_REF`` inlines both nested message references.

    The expected ``UnprovenMessage`` holds two inlined copies of the inner
    fields, prefixed ``SR_NR_`` and ``NR_``.
    """
    merged = deepcopy(M_DBL_REF).merged()

    def inlined_links(prefix):
        # F1 -> F2, then F2 branches to F3 (<= 100) or F4 (>= 200).
        return [
            Link(Field(f"{prefix}F1"), Field(f"{prefix}F2")),
            Link(
                Field(f"{prefix}F2"),
                Field(f"{prefix}F3"),
                LessEqual(Variable(f"{prefix}F2"), Number(100)),
                first=First(f"{prefix}F2"),
            ),
            Link(
                Field(f"{prefix}F2"),
                Field(f"{prefix}F4"),
                GreaterEqual(Variable(f"{prefix}F2"), Number(200)),
                first=First(f"{prefix}F2"),
            ),
        ]

    def inlined_types(prefix):
        return {
            Field(f"{prefix}F1"): Opaque(),
            Field(f"{prefix}F2"): deepcopy(MODULAR_INTEGER),
            Field(f"{prefix}F3"): deepcopy(ENUMERATION),
            Field(f"{prefix}F4"): deepcopy(RANGE_INTEGER),
        }

    links = [
        Link(INITIAL, Field("SR_NR_F1"), length=Number(16)),
        Link(
            Field("SR_NR_F3"),
            Field("NR_F1"),
            Equal(Variable("SR_NR_F3"), Variable("P.ONE")),
            length=Number(16),
        ),
        Link(Field("SR_NR_F4"), Field("NR_F1"), length=Number(16)),
        Link(Field("NR_F3"), FINAL, Equal(Variable("NR_F3"), Variable("P.ONE"))),
        Link(Field("NR_F4"), FINAL),
        *inlined_links("SR_NR_"),
        *inlined_links("NR_"),
    ]
    types = {**inlined_types("SR_NR_"), **inlined_types("NR_")}
    assert_equal(merged, UnprovenMessage("P.Dbl_Ref", links, types))
def __prove_contradictions(self) -> None:
    """Raise ``ModelError`` for any outgoing link condition judged contradictory.

    For INITIAL and every field, each outgoing condition (extended by
    ``self.__with_constraints``) is placed in ``<condition> = False``. If
    ``forall()`` on that equation returns ``ProofResult.sat``, the link is
    reported with its index, source, target and the equation text.
    """
    for source in (INITIAL, *self.__fields):
        for position, link in enumerate(self.outgoing(source)):
            contradiction = Equal(self.__with_constraints(link.condition), FALSE)
            outcome = contradiction.forall()
            if outcome != ProofResult.sat:
                continue
            # Keep the error on a single line even for multi-line expressions.
            flattened = str(contradiction).replace("\n", "")
            raise ModelError(
                f'contradicting condition {position} from field "{source.name}" to'
                f' "{link.target.name}" in "{self.full_name}"'
                f" ({outcome}: {flattened})"
            )
def test_relation_simplified() -> None:
    """Check that ``Equal.simplified`` folds constant additions on either side."""

    def one_plus_one():
        return Add(Number(1), Number(1))

    cases = [
        (Equal(Variable("X"), one_plus_one()), Equal(Variable("X"), Number(2))),
        (Equal(one_plus_one(), Variable("X")), Equal(Number(2), Variable("X"))),
        (Equal(one_plus_one(), one_plus_one()), TRUE),
    ]
    for relation, expected in cases:
        assert_equal(relation.simplified(), expected)
def field_accessor_functions(field: Field, package_name: str) -> List[Subprogram]:
    """Generate the ``Get_<Field>`` accessor subprograms for ``field``.

    For an ``Array`` field, the result is:

    * expression functions ``Get_<F>_First`` and ``Get_<F>_Last`` that
      dispatch on the field's variants;
    * a procedure ``Get_<F>`` that sets ``First``/``Last`` from them. Unless
      the type name contains ``Payload``, it also assumes (and
      postconditions) ``<package>.<Type>.Is_Contained``.

    Otherwise the result is one expression function ``Get_<F>`` that
    dispatches on the variants. Every subprogram carries the precondition
    ``COMMON_PRECONDITION and Valid_<F> (Buffer)``.
    """
    name = field.name
    precondition = Precondition(
        And(COMMON_PRECONDITION, LogCall(f'Valid_{name} (Buffer)')))

    if not isinstance(field.type, Array):
        return [
            ExpressionFunction(
                f'Get_{name}',
                field.type.name,
                [('Buffer', 'Types.Bytes')],
                IfExpression(
                    [(LogCall(f'Valid_{name}_{variant} (Buffer)'),
                      MathCall(f'Get_{name}_{variant} (Buffer)'))
                     for variant in field.variants],
                    f'Unreachable_{field.type.name}'),
                [precondition])
        ]

    functions: List[Subprogram] = [
        ExpressionFunction(
            f'Get_{name}_{bound}',
            'Types.Index_Type',
            [('Buffer', 'Types.Bytes')],
            IfExpression(
                [(LogCall(f'Valid_{name}_{variant} (Buffer)'),
                  LogCall(f'Get_{name}_{variant}_{bound} (Buffer)'))
                 for variant in field.variants],
                'Unreachable_Types_Index_Type'),
            [precondition])
        for bound in ('First', 'Last')
    ]

    first_call = f'Get_{name}_First (Buffer)'
    last_call = f'Get_{name}_Last (Buffer)'
    body: List[Statement] = [
        Assignment('First', MathCall(first_call)),
        Assignment('Last', MathCall(last_call)),
    ]
    postcondition = Postcondition(
        And(Equal(Value('First'), MathCall(first_call)),
            Equal(Value('Last'), MathCall(last_call))))
    if 'Payload' not in field.type.name:
        predicate = f'{package_name}.{field.type.name}.Is_Contained (Buffer (First .. Last))'
        body.append(PragmaStatement('Assume', [predicate]))
        postcondition.expr = And(postcondition.expr, LogCall(predicate))

    functions.append(
        Procedure(f'Get_{name}',
                  [('Buffer', 'Types.Bytes'),
                   ('First', 'out Types.Index_Type'),
                   ('Last', 'out Types.Index_Type')],
                  [],
                  body,
                  [precondition, postcondition]))
    return functions
def __link_expression(self, link: Link) -> Expr:
    """Return the facts implied by taking ``link``, as one conjunction.

    The conjunction equates the target's First/Length/Last with the values
    from the ``__target_first``/``__target_length``/``__target_last``
    helpers. It constrains ``Message``'s First, Last and Length, and ends with
    ``link.condition``.
    """
    target = link.target.name
    target_facts = [
        Equal(First(target), self.__target_first(link)),
        Equal(Length(target), self.__target_length(link)),
        Equal(Last(target), self.__target_last(link)),
    ]
    message_facts = [
        GreaterEqual(First("Message"), Number(0)),
        GreaterEqual(Last("Message"), Last(target)),
        GreaterEqual(Last("Message"), First("Message")),
        Equal(Length("Message"), Add(Sub(Last("Message"), First("Message")), Number(1))),
    ]
    return And(*target_facts, *message_facts, link.condition)
def create_structural_valid_function() -> UnitPart:
    """Generate the ``Structural_Valid`` expression function.

    The generated function is true when ``Ctx.Cursors (Fld).State`` is
    ``S_Valid`` or ``S_Structural_Valid``.
    """
    specification = FunctionSpecification(
        "Structural_Valid",
        "Boolean",
        [Parameter(["Ctx"], "Context"), Parameter(["Fld"], "Field")],
    )
    state_checks = []
    for state in ("S_Valid", "S_Structural_Valid"):
        cursor_state = Selected(Indexed(Variable("Ctx.Cursors"), Variable("Fld")), "State")
        state_checks.append(Equal(cursor_state, Variable(state)))
    # The single-operand And is kept so the generated expression tree is unchanged.
    expression = And(Or(*state_checks))
    return UnitPart(
        [SubprogramDeclaration(specification)],
        [ExpressionFunctionDeclaration(specification, expression)],
    )
def create_verify_message_procedure(
    message: Message, context_invariant: Sequence[Expr]
) -> UnitPart:
    """Generate ``Verify_Message``, which calls ``Verify`` for every field.

    Args:
        message: Message whose fields are verified, in ``message.fields`` order.
        context_invariant: Expressions appended to the postcondition.

    The postcondition also states that ``Has_Buffer (Ctx)`` equals its
    ``'Old`` value.
    """
    specification = ProcedureSpecification(
        "Verify_Message", [InOutParameter(["Ctx"], "Context")]
    )
    postcondition = Postcondition(
        And(
            Equal(
                Call("Has_Buffer", [Variable("Ctx")]),
                Old(Call("Has_Buffer", [Variable("Ctx")])),
            ),
            *context_invariant,
        )
    )
    declaration = SubprogramDeclaration(specification, [postcondition])
    verify_calls = [
        CallStatement("Verify", [Variable("Ctx"), Variable(fld.affixed_name)])
        for fld in message.fields
    ]
    return UnitPart([declaration], [SubprogramBody(specification, [], verify_calls)])
def test_opaque_not_byte_aligned_dynamic() -> None:
    """Check that an opaque field reached at a non-byte-aligned offset is rejected.

    The expected error reports ``O2`` (at 44:3) as not aligned to an 8 bit
    boundary on the path ``L1 -> O1 -> L2 -> O2``.
    """
    expected_error = (
        r'^<stdin>:44:3: model: error: opaque field "O2" not aligned to'
        r" 8 bit boundary [(]L1 -> O1 -> L2 -> O2[)]"
    )
    with pytest.raises(RecordFluxError, match=expected_error):
        o2 = Field(ID("O2", location=Location((44, 3))))
        structure = [
            Link(INITIAL, Field("L1")),
            Link(
                Field("L1"),
                Field("O1"),
                length=Variable("L1"),
                condition=Equal(Mod(Variable("L1"), Number(8)), Number(0)),
            ),
            Link(Field("O1"), Field("L2")),
            Link(Field("L2"), o2, length=Number(128)),
            Link(o2, FINAL),
        ]
        types = {
            Field("L1"): MODULAR_INTEGER,
            Field("L2"): ModularInteger("P.T", Number(4)),
            Field("O1"): Opaque(),
            o2: Opaque(),
        }
        Message("P.M", structure, types)
def test_array_aggregate_invalid_element_type() -> None:
    """Check that comparing an array of messages with an aggregate is rejected.

    ``P.Array`` has the message ``P.I`` as its element type, so the condition
    comparing ``F`` with the aggregate (1, 2, 64) is expected to fail with
    "invalid array element type" at 90:10.
    """
    inner_message = Message(
        "P.I",
        [Link(INITIAL, Field("F")), Link(Field("F"), FINAL)],
        {Field("F"): MODULAR_INTEGER},
    )
    array_type = Array("P.Array", inner_message)
    f = Field("F")
    expected_error = (
        r"^<stdin>:90:10: model: error: invalid array element type"
        ' "P.I" for aggregate comparison$'
    )
    with pytest.raises(RecordFluxError, match=expected_error):
        aggregate_condition = Equal(
            Variable("F"),
            Aggregate(Number(1), Number(2), Number(64)),
            Location((90, 10)),
        )
        Message(
            "P.M",
            [
                Link(INITIAL, f, length=Number(18)),
                Link(f, FINAL, condition=aggregate_condition),
            ],
            {Field("F"): array_type},
        )
def test_array_aggregate_out_of_range() -> None:
    """Check that an aggregate element outside the element range is rejected.

    The element type is ``ModularInteger("P.Element", Number(64))``. The
    element 64 (at 44:3) is expected to be reported as outside ``0 .. 63``.
    """
    element_type = ModularInteger("P.Element", Number(64))
    array_type = Array("P.Array", element_type)
    f = Field("F")
    with pytest.raises(
        RecordFluxError,
        match=r"^<stdin>:44:3: model: error: aggregate element out of range 0 .. 63",
    ):
        elements = [Number(1), Number(2), Number(64, location=Location((44, 3)))]
        aggregate_condition = Equal(Variable("F"), Aggregate(*elements))
        Message(
            "P.M",
            [
                Link(INITIAL, f, length=Number(18)),
                Link(f, FINAL, condition=aggregate_condition),
            ],
            {Field("F"): array_type},
        )
def test_merge_message_simple_derived() -> None:
    """Check that merging ``M_SMPL_REF_DERI`` inlines its referenced message.

    The expected ``UnprovenDerivedMessage`` keeps ``M_SMPL_REF`` as its base
    and contains the inlined fields prefixed ``NR_``.
    """
    merged = deepcopy(M_SMPL_REF_DERI).merged()

    def nr(index):
        return Field(f"NR_F{index}")

    links = [
        Link(INITIAL, nr(1), length=Number(16)),
        Link(nr(3), FINAL, Equal(Variable("NR_F3"), Variable("P.ONE"))),
        Link(nr(4), FINAL),
        Link(nr(1), nr(2)),
        Link(
            nr(2),
            nr(3),
            LessEqual(Variable("NR_F2"), Number(100)),
            first=First("NR_F2"),
        ),
        Link(
            nr(2),
            nr(4),
            GreaterEqual(Variable("NR_F2"), Number(200)),
            first=First("NR_F2"),
        ),
    ]
    types = {
        nr(1): Opaque(),
        nr(2): deepcopy(MODULAR_INTEGER),
        nr(3): deepcopy(ENUMERATION),
        nr(4): deepcopy(RANGE_INTEGER),
    }
    assert_equal(
        merged,
        UnprovenDerivedMessage("P.Smpl_Ref_Deri", M_SMPL_REF, links, types),
    )
def test_ass_expr_findall() -> None:
    """Check that ``findall`` on an ``And`` collects nested ``Number`` nodes in order."""
    expression = And(Equal(Variable("X"), Number(1)), Variable("Y"), Number(2))

    def is_number(node):
        return isinstance(node, Number)

    assert_equal(expression.findall(is_number), [Number(1), Number(2)])
def test_conditionally_unreachable_field_enum_last() -> None:
    """Check that F2 and Final are unreachable when F1 must end the message.

    The link ``F1 -> F2`` requires ``F1'Last = Message'Last``. The expected
    errors list, for both F2 and Final, the same three unsatisfied facts.
    """
    structure = [
        Link(INITIAL, Field("F1")),
        Link(Field("F1"), Field("F2"), Equal(Last("F1"), Last("Message"))),
        Link(Field("F2"), FINAL),
    ]
    types = {Field("F1"): ENUMERATION, Field("F2"): ENUMERATION}
    # Same unsatisfied facts for both unreachable fields.
    unsatisfied = (
        r'model: info: unsatisfied "F2\'Last = [(][(][(]F1\'Last [+] 1[)] [+] 8[)][)] - 1"\n'
        r'model: info: unsatisfied "Message\'Last >= F2\'Last"\n'
        r'model: info: unsatisfied "F1\'Last = Message\'Last"'
    )
    expected_error = (
        r"^"
        r'model: error: unreachable field "F2" in "P.M"\n'
        r"model: info: path 0 [(]F1 -> F2[)]:\n"
        + unsatisfied
        + r"\n"
        r'model: error: unreachable field "Final" in "P.M"\n'
        r"model: info: path 0 [(]F1 -> F2 -> Final[)]:\n"
        + unsatisfied
    )
    assert_message_model_error(structure, types, expected_error)
def test_exclusive_enum_valid() -> None:
    """Check that outgoing links of F1 guarded by different literals are accepted.

    The test passes if ``Message`` construction raises nothing.
    """

    def f1_is(literal):
        return Equal(Variable("F1"), Variable(literal))

    structure = [
        Link(INITIAL, Field("F1")),
        Link(Field("F1"), FINAL, condition=f1_is("ONE")),
        Link(Field("F1"), Field("F2"), condition=f1_is("TWO")),
        Link(Field("F2"), FINAL),
    ]
    types = {Field("F1"): ENUMERATION, Field("F2"): MODULAR_INTEGER}
    Message("P.M", structure, types)
def valid_path_to_next_field_condition(
    self, message: Message, field: Field, field_type: Type
) -> Sequence[Expr]:
    """Return one simplified ``If`` per non-final successor link of ``field``.

    Each ``If`` has the form ``if <link condition> then Predecessor (Ctx,
    <target>) = <field> and Valid_Next (Ctx, <target>)``. It is simplified
    with a mapping that replaces ``field`` by ``Value`` (wrapped in
    ``Convert`` for an always-valid enumeration) and includes
    ``self.public_substitution(message)``.

    Links to FINAL are skipped, so ``Valid_Next`` is always emitted. The
    former ``else TRUE`` fallback for FINAL targets was unreachable.
    """
    return [
        If(
            [
                (
                    link.condition,
                    And(
                        Equal(
                            Call("Predecessor", [Name("Ctx"), Name(link.target.affixed_name)]),
                            Name(field.affixed_name),
                        ),
                        Call("Valid_Next", [Name("Ctx"), Name(link.target.affixed_name)]),
                    ),
                )
            ]
        ).simplified(
            {
                Variable(field.name): (
                    Call("Convert", [Name("Value")])
                    if isinstance(field_type, Enumeration) and field_type.always_valid
                    else Name("Value")
                ),
                **self.public_substitution(message),
            }
        )
        for link in message.outgoing(field)
        if link.target != FINAL
    ]
def setter_postconditions(
    self, message: Message, field: Field, field_type: Type
) -> Sequence[Expr]:
    """Return the postconditions of the setter for ``field``.

    The list contains, in this order:

    * ``Invalid (Ctx, <successor>)`` for every non-final successor;
    * the expressions from
      ``self.common.valid_path_to_next_field_condition``;
    * ``X = X'Old`` for the buffer bounds, ``Ctx.First``, the field's
      ``Predecessor``/``Valid_Next``, and ``Get_<P> (Ctx)`` of every
      definite predecessor whose type is ``Scalar``.
    """
    invalidated = [
        Call("Invalid", [Name("Ctx"), Name(successor.affixed_name)])
        for successor in message.successors(field)
        if successor != FINAL
    ]
    path_conditions = self.common.valid_path_to_next_field_condition(
        message, field, field_type
    )
    unchanged = [
        Selected("Ctx", "Buffer_First"),
        Selected("Ctx", "Buffer_Last"),
        Selected("Ctx", "First"),
        Call("Predecessor", [Name("Ctx"), Name(field.affixed_name)]),
        Call("Valid_Next", [Name("Ctx"), Name(field.affixed_name)]),
    ]
    unchanged.extend(
        Call(f"Get_{predecessor.name}", [Name("Ctx")])
        for predecessor in message.definite_predecessors(field)
        if isinstance(message.types[predecessor], Scalar)
    )
    preserved = [Equal(expression, Old(expression)) for expression in unchanged]
    return [*invalidated, *path_conditions, *preserved]
def enumeration_functions(enum: Enumeration) -> List[Subprogram]:
    """Generate ``Valid_<Enum>`` and ``Convert_To_<Enum>`` for ``enum``.

    Both functions read the raw value through ``Convert_To_<Base> (Buffer,
    Offset)``. They share a precondition requiring ``Offset < 8`` and
    ``Buffer'Length = (Base'Size + Offset - 1) / 8 + 1``.

    * Always-valid enumerations: validity is ``True``, and conversion
      returns ``(True, <literal>)`` for known values and ``(False, Raw)``
      otherwise.
    * Other enumerations: validity is a case over the literal values, and
      conversion maps each value to its literal, with
      ``Unreachable_<Enum>`` for others.

    Returns:
        ``[validation_function, conversion_function]``.
    """
    common_precondition = And(
        Less(Value('Offset'), Number(8)),
        Equal(
            Length('Buffer'),
            Add(
                Div(Add(Size(enum.base_name), Value('Offset'), Number(-1)), Number(8)),
                Number(1))))
    raw_value = LogCall(f'Convert_To_{enum.base_name} (Buffer, Offset)')

    if enum.always_valid:
        validation_expression: Expr = Value('True')
    else:
        validation_expression = CaseExpression(
            raw_value,
            [*((literal_value, Value('True')) for literal_value in enum.literals.values()),
             (Value('others'), Value('False'))])

    validation_function = ExpressionFunction(
        f'Valid_{enum.name}',
        'Boolean',
        [('Buffer', 'Types.Bytes'), ('Offset', 'Natural')],
        validation_expression,
        [Precondition(common_precondition)])

    function_name = f'Convert_To_{enum.name}'
    parameters = [('Buffer', 'Types.Bytes'), ('Offset', 'Natural')]
    precondition = Precondition(
        And(common_precondition, LogCall(f'Valid_{enum.name} (Buffer, Offset)')))

    conversion_cases: List[Tuple[Expr, Expr]]
    conversion_function: Subprogram
    if enum.always_valid:
        conversion_cases = [
            (literal_value, Aggregate(Value('True'), Value(literal)))
            for literal, literal_value in enum.literals.items()
        ]
        conversion_cases.append((Value('others'), Aggregate(Value('False'), Value('Raw'))))
        conversion_function = Function(
            function_name,
            enum.name,
            parameters,
            [Declaration('Raw', enum.base_name, raw_value)],
            [ReturnStatement(CaseExpression(Value('Raw'), conversion_cases))],
            [precondition])
    else:
        conversion_cases = [
            (literal_value, Value(literal))
            for literal, literal_value in enum.literals.items()
        ]
        conversion_cases.append((Value('others'), LogCall(f'Unreachable_{enum.name}')))
        conversion_function = ExpressionFunction(
            function_name,
            enum.name,
            parameters,
            CaseExpression(raw_value, conversion_cases),
            [precondition])

    return [validation_function, conversion_function]
def test_type_refinement_spec() -> None:
    """Check that message_type.rflx and type_refinement.rflx parse to the expected specs."""
    t_type = ModularInteger("__PACKAGE__.T", Number(256))
    pdu = MessageSpec(
        "__PACKAGE__.PDU",
        [
            Component(
                "Foo",
                "T",
                [
                    Then("Bar", UNDEFINED, UNDEFINED, LessEqual(Variable("Foo"), Number(30, 16))),
                    Then("Baz", UNDEFINED, UNDEFINED, Greater(Variable("Foo"), Number(30, 16))),
                ],
            ),
            Component("Bar", "T"),
            Component("Baz", "T"),
        ],
    )
    simple_pdu = MessageSpec(
        "__PACKAGE__.Simple_PDU",
        [Component("Bar", "T"), Component("Baz", "T")],
    )
    empty_pdu = MessageSpec("__PACKAGE__.Empty_PDU", [])
    message_type = Specification(
        ContextSpec([]),
        PackageSpec("Message_Type", [t_type, pdu, simple_pdu, empty_pdu]),
    )

    refinements = [
        RefinementSpec(
            "Message_Type.Simple_PDU",
            "Bar",
            "Message_Type.PDU",
            Equal(Variable("Baz"), Number(42)),
        ),
        RefinementSpec("Message_Type.PDU", "Bar", "Message_Type.Simple_PDU"),
    ]
    type_refinement = Specification(
        ContextSpec(["Message_Type"]),
        PackageSpec("Type_Refinement", refinements),
    )

    expected = {"Message_Type": message_type, "Type_Refinement": type_refinement}
    files = [f"{TESTDIR}/message_type.rflx", f"{TESTDIR}/type_refinement.rflx"]
    assert_specifications_files(files, expected)
def test_tlv_message_with_not_operator() -> None:
    """Check that a TLV message whose conditions use nested ``Not`` parses and round-trips.

    The parsed message must be valid, and its bytestring must equal the
    input bytes.
    """
    tag_to_length = Not(Not(Not(NotEqual(Variable("Tag"), Variable("Msg_Data")))))
    tag_to_final = Not(
        Not(
            Not(
                Or(
                    Not(Not(Equal(Variable("Tag"), Variable("Msg_Data")))),
                    Not(Equal(Variable("Tag"), Variable("Msg_Error"))),
                )
            )
        )
    )
    message = Message(
        "TLV::Message_With_Not_Operator",
        [
            Link(INITIAL, Field("Tag")),
            Link(Field("Tag"), Field("Length"), tag_to_length),
            Link(Field("Tag"), FINAL, tag_to_final),
            Link(Field("Length"), Field("Value"), size=Mul(Variable("Length"), Number(8))),
            Link(Field("Value"), FINAL),
        ],
        {Field("Tag"): TLV_TAG, Field("Length"): TLV_LENGTH, Field("Value"): OPAQUE},
    )

    pyrflx_model = PyRFLX(model=Model([TLV_TAG, TLV_LENGTH, message]))
    msg = pyrflx_model.package("TLV").new_message("Message_With_Not_Operator")
    raw = b"\x01\x00\x04\x00\x00\x00\x00"
    msg.parse(raw)
    assert msg.valid_message
    assert msg.bytestring == raw
def create_expression_message() -> Message:
    """Build ``Expression.Message``: one payload of length 16 that must equal the aggregate (1, 2)."""
    payload = Field("Payload")
    expected_content = Equal(Variable("Payload"), Aggregate(Number(1), Number(2)))
    structure = [
        Link(INITIAL, payload, length=Number(16)),
        Link(payload, FINAL, expected_content),
    ]
    return Message("Expression.Message", structure, {payload: Payload()})
def test_if_expr_substituted() -> None:
    """Check ``If.substituted`` with a variable-renaming and an If-extending function.

    In the second case, the expected result also shows the branch added by
    the function with its ``Z`` renamed to ``P_Z``.
    """

    def sample_if():
        return If(
            [
                (Equal(Variable("X"), Number(42)), Number(21)),
                (Variable("Y"), Number(42)),
                (Number(42), Variable("Z")),
            ]
        )

    def renamed_branches():
        return [
            (Equal(Variable("P_X"), Number(42)), Number(21)),
            (Variable("P_Y"), Number(42)),
            (Number(42), Variable("P_Z")),
        ]

    def prefix_variables(node):
        return Variable(f"P_{node}") if isinstance(node, Variable) else node

    def prefix_and_extend(node):
        if isinstance(node, Variable):
            return Variable(f"P_{node}")
        if isinstance(node, If):
            return If([*node.condition_expressions, (Variable("Z"), Number(1))], node.else_expression)
        return node

    assert_equal(sample_if().substituted(prefix_variables), If(renamed_branches()))
    assert_equal(
        sample_if().substituted(prefix_and_extend),
        If([*renamed_branches(), (Variable("P_Z"), Number(1))]),
    )
def test_if_expr_findall() -> None:
    """Check that ``If.findall`` returns every ``Number`` node in traversal order."""
    expression = If(
        [
            (Equal(Variable("X"), Number(42)), Number(21)),
            (Variable("Y"), Number(42)),
            (Number(42), Variable("Z")),
        ]
    )
    expected = [Number(42), Number(21), Number(42), Number(42)]
    assert_equal(expression.findall(lambda node: isinstance(node, Number)), expected)