def test_merge_message_error_name_conflict() -> None:
    """Merging must report a name conflict between an inlined and an existing field.

    The outer message "P.M1" embeds "P.M2" in field "F1" and also declares its
    own field "F1_F2"; presumably the inlined "F2" is renamed to "F1_F2" during
    merging, which is the clash the expected error describes.
    """
    inner_f2 = Field(ID("F2", Location((10, 5))))
    inner = UnprovenMessage(
        "P.M2",
        [Link(INITIAL, inner_f2), Link(inner_f2, FINAL)],
        {Field("F2"): MODULAR_INTEGER},
        Location((15, 3)),
    )

    outer_f1 = Field(ID("F1", Location((20, 8))))
    outer_f1_f2 = Field(ID("F1_F2", Location((30, 5))))
    outer = UnprovenMessage(
        "P.M1",
        [Link(INITIAL, outer_f1), Link(outer_f1, outer_f1_f2), Link(outer_f1_f2, FINAL)],
        {Field("F1"): inner, Field("F1_F2"): MODULAR_INTEGER},
        Location((2, 9)),
    )

    # One error line followed by two info lines pointing at the merge origin.
    expected_error = (
        r"^"
        r'<stdin>:30:5: model: error: name conflict for "F1_F2" in "P.M1"\n'
        r'<stdin>:15:3: model: info: when merging message "P.M2"\n'
        r'<stdin>:20:8: model: info: into field "F1"$'
    )
    assert_type_error(outer.merged(), expected_error)
def test_prefixed_message() -> None:
    """``prefixed`` must rename every field and every reference to a field."""

    def build(prefix: str) -> UnprovenMessage:
        # The same four-field message, with ``prefix`` prepended to each field
        # name and to each use of that name in conditions, sizes and ``First``.
        n1, n2, n3, n4 = (prefix + name for name in ("F1", "F2", "F3", "F4"))
        return UnprovenMessage(
            "P.M",
            [
                Link(INITIAL, Field(n1)),
                Link(
                    Field(n1),
                    Field(n2),
                    LessEqual(Variable(n1), Number(100)),
                    first=First(n1),
                ),
                Link(
                    Field(n1),
                    Field(n3),
                    GreaterEqual(Variable(n1), Number(200)),
                    first=First(n1),
                ),
                Link(Field(n2), FINAL),
                Link(Field(n3), Field(n4), length=Variable(n3)),
                Link(Field(n4), FINAL),
            ],
            {
                Field(n1): deepcopy(MODULAR_INTEGER),
                Field(n2): deepcopy(MODULAR_INTEGER),
                Field(n3): deepcopy(RANGE_INTEGER),
                Field(n4): Opaque(),
            },
        )

    assert_equal(build("").prefixed("X_"), build("X_"))
def test_merge_message_simple() -> None:
    """Merging M_SMPL_REF yields a flat message whose fields carry the "NR_" prefix."""
    actual = deepcopy(M_SMPL_REF).merged()

    expected_structure = [
        Link(INITIAL, Field("NR_F1"), length=Number(16)),
        Link(Field("NR_F3"), FINAL, Equal(Variable("NR_F3"), Variable("P.ONE"))),
        Link(Field("NR_F4"), FINAL),
        Link(Field("NR_F1"), Field("NR_F2")),
        Link(
            Field("NR_F2"),
            Field("NR_F3"),
            LessEqual(Variable("NR_F2"), Number(100)),
            first=First("NR_F2"),
        ),
        Link(
            Field("NR_F2"),
            Field("NR_F4"),
            GreaterEqual(Variable("NR_F2"), Number(200)),
            first=First("NR_F2"),
        ),
    ]
    expected_types = {
        Field("NR_F1"): Opaque(),
        Field("NR_F2"): deepcopy(MODULAR_INTEGER),
        Field("NR_F3"): deepcopy(ENUMERATION),
        Field("NR_F4"): deepcopy(RANGE_INTEGER),
    }

    assert_equal(actual, UnprovenMessage("P.Smpl_Ref", expected_structure, expected_types))
def create_message(message: MessageSpec, types: Mapping[ID, Type]) -> Message:
    """Build a message model from a parsed message specification.

    The specification's components are turned into a field-type mapping and a
    list of links; the resulting ``UnprovenMessage`` is merged and proven.

    Args:
        message: Parsed message specification (components, package, identifier,
            location).
        types: Known types, keyed by their qualified name.

    Returns:
        The result of ``UnprovenMessage(...).merged().proven()``.

    Errors found here ("invalid first expression", "undefined field") are
    collected in a ``RecordFluxError`` that is handed to ``UnprovenMessage``;
    presumably they are reported when the message is proven — confirm in
    ``UnprovenMessage``.
    """
    components = list(message.components)

    # If the specification starts directly with a named field, prepend an
    # anonymous component so that INITIAL always has an explicit source entry.
    if components and components[0].name:
        components.insert(0, Component())

    field_types: Dict[Field, Type] = {}
    error = RecordFluxError()

    for component in components:
        if component.name and component.type_name:
            type_name = qualified_type_name(component.type_name, message.package)
            if type_name not in types:
                # NOTE(review): a field of unknown type is dropped silently here;
                # links pointing to it are later reported as "undefined field".
                continue
            field_types[Field(component.name)] = types[type_name]

    structure: List[Link] = []

    for i, component in enumerate(components):
        if not component.name:
            # The anonymous initial component must not carry "First" aspects.
            error.extend(
                [
                    (
                        "invalid first expression",
                        Subsystem.PARSER,
                        Severity.ERROR,
                        then.first.location,
                    )
                    for then in component.thens
                    if then.first != UNDEFINED
                ]
            )

        source_node = Field(component.name) if component.name else INITIAL

        if not component.thens:
            # Without explicit "then" clauses, fall through to the next
            # component, or to FINAL after the last one.
            name = components[i + 1].name if i + 1 < len(components) else None
            target_node = Field(name) if name else FINAL
            structure.append(Link(source_node, target_node))

        for then in component.thens:
            target_node = Field(then.name) if then.name else FINAL
            if then.name and target_node not in field_types:
                error.append(
                    f'undefined field "{then.name}"',
                    Subsystem.PARSER,
                    Severity.ERROR,
                    then.name.location,
                )
                continue
            structure.append(
                Link(
                    source_node,
                    target_node,
                    then.condition,
                    then.length,
                    then.first,
                    then.location,
                )
            )

    return (
        UnprovenMessage(message.identifier, structure, field_types, message.location, error)
        .merged()
        .proven()
    )
def test_field_locations() -> None:
    """``fields`` must return the Field objects from the structure, locations included."""
    second = Field(ID("F2", Location((2, 2))))
    third = Field(ID("F3", Location((3, 2))))

    links = [Link(INITIAL, second), Link(second, third), Link(third, FINAL)]
    field_types = {Field("F2"): MODULAR_INTEGER, Field("F3"): MODULAR_INTEGER}
    message = UnprovenMessage("P.M", links, field_types, Location((17, 9)))

    assert message.fields == (second, third)
[ Link(INITIAL, Field("Payload"), size=Number(16)), Link(Field("Payload"), FINAL, Equal(Variable("Payload"), Aggregate(Number(1), Number(2)))), ], {Field("Payload"): OPAQUE}, skip_proof=True, ) EXPRESSION_MODEL = Model([EXPRESSION_MESSAGE]) DERIVATION_MESSAGE = DerivedMessage("Derivation::Message", TLV_MESSAGE) DERIVATION_MODEL = Model([DERIVATION_MESSAGE]) VALID_MESSAGE = UnprovenMessage( "P::M", [ Link(INITIAL, Field("F"), size=Number(16)), Link(Field("F"), FINAL), ], {Field("F"): OPAQUE}, ) INVALID_MESSAGE = UnprovenMessage( "P::M", [ Link(INITIAL, Field("F")), Link(Field("X"), FINAL), ], {Field("F"): OPAQUE}, ) MODULAR_INTEGER = ModularInteger("P::Modular", Number(256)) RANGE_INTEGER = RangeInteger("P::Range_Integer", Number(1), Number(100), Number(8))
def messages(
    draw: Draw,
    unique_identifiers: ty.Generator[ID, None, None],
    not_null: bool = False,
) -> Message:
    """Draw a random message and return it proven.

    The bit offset of each field modulo 8 (its "alignment") is tracked while
    fields are generated. Scalars and composites may be drawn when the offset
    is byte aligned. Otherwise only scalars are drawn, with ``align_to_8`` set
    to the current offset. Alternative links may only go to fields with
    matching alignment, and may only go to FINAL from a byte-aligned offset. A
    trailing field that ends unaligned gets an extra padding scalar before
    FINAL.

    Args:
        draw: Hypothesis draw function.
        unique_identifiers: Source of fresh identifiers for fields, types and
            the message itself.
        not_null: If true, at least one field is generated.

    Returns:
        The proven message.

    Raises:
        error.RecordFluxError: If the generated message cannot be proven. The
            error is extended with the repr of the generated message before it
            is re-raised.
    """
    # pylint: disable=too-many-locals, too-many-statements

    @dataclass
    class FieldPair:
        source: Field
        target: Field
        source_type: ty.Optional[Type]
        target_type: ty.Optional[Type]

    def size(pair: FieldPair) -> expr.Expr:
        # Only links into opaque/sequence fields get an explicit size: the
        # preceding integer field's value in bytes if it is small enough,
        # otherwise a random multiple of 8 bits.
        max_size = 2**29 - 1
        if isinstance(pair.target_type, (Opaque, Sequence)):
            if isinstance(pair.source_type, Integer):
                if pair.source_type.last.value <= max_size:
                    return expr.Mul(expr.Variable(pair.source.name), expr.Number(8))
            return expr.Number(
                draw(st.integers(min_value=1, max_value=max_size).map(lambda x: x * 8))
            )
        return expr.UNDEFINED

    def condition(pair: FieldPair) -> expr.Expr:
        # Compare the source field with a random value of its type, if the type
        # has more than one value; otherwise the link is unconditional.
        if isinstance(pair.source_type, Integer):
            first = pair.source_type.first.value
            last = pair.source_type.last.value
            if last - first > 0:
                return expr.Equal(
                    expr.Variable(pair.source.name),
                    expr.Number(draw(st.integers(min_value=first, max_value=last))),
                )
        elif isinstance(pair.source_type, Enumeration) and len(pair.source_type.literals) > 1:
            return expr.Equal(
                expr.Variable(pair.source.name),
                expr.Variable(
                    list(pair.source_type.literals.keys())[
                        draw(st.integers(min_value=0, max_value=len(pair.source_type.literals) - 1))
                    ]
                ),
            )
        return expr.TRUE

    @st.composite
    def fields(_: ty.Callable[[], object]) -> Field:
        return Field(next(unique_identifiers).name)

    structure: ty.List[Link] = []

    def outgoing(field: Field) -> ty.Sequence[Link]:
        return [link for link in structure if link.source == field]

    alignment = 0
    alignments: ty.Dict[Field, int] = {}
    types_: ty.Dict[Field, Type] = {}

    for _ in range(draw(st.integers(min_value=1 if not_null else 0, max_value=4))):
        f = draw(fields())
        t = draw(
            st.one_of(scalars(unique_identifiers), composites(unique_identifiers))
            if alignment == 0
            else scalars(unique_identifiers, align_to_8=alignment)
        )
        types_[f] = t
        alignments[f] = alignment
        if isinstance(t, Scalar):
            alignment = (alignment + int(t.size)) % 8

    if types_:
        fields_ = list(types_.keys())

        # Chain all fields in generation order: INITIAL -> F1 -> F2 -> ...
        for i, target in enumerate(fields_):
            source = fields_[i - 1] if i > 0 else INITIAL
            pair = FieldPair(
                source,
                target,
                types_[source] if source != INITIAL else None,
                types_[target],
            )
            structure.append(Link(source, target, condition=condition(pair), size=size(pair)))

        # For each conditional link, add a complementary link to a later field
        # with matching alignment, or to FINAL from a byte-aligned offset.
        for i, source in enumerate(fields_):
            out = outgoing(source)
            if fields_[i + 1 :] and len(out) == 1 and out[0].condition != expr.TRUE:
                source_type = types_[source]
                field_size = int(source_type.size) if isinstance(source_type, Scalar) else 0
                target_alignment = (alignments[source] + field_size) % 8
                potential_targets = [
                    f for f in fields_[i + 1 :] if alignments[f] == target_alignment
                ]
                if target_alignment == 0:
                    potential_targets.append(FINAL)
                target = draw(st.sampled_from(potential_targets))
                pair = FieldPair(
                    source, target, types_[source], types_[target] if target != FINAL else None
                )
                structure.append(
                    Link(
                        source,
                        target,
                        condition=expr.Not(out[0].condition).simplified(),
                        size=size(pair),
                    )
                )

        # Connect fields without outgoing links to FINAL, inserting a padding
        # scalar when the field does not end on a byte boundary.
        loose_ends = [f for f in fields_ if all(link.source != f for link in structure)]
        for field in loose_ends:
            # Use the loose end's own type (previously the stale loop variable
            # ``f`` was used here by mistake).
            field_type = types_[field]
            field_size = int(field_type.size) if isinstance(field_type, Scalar) else 0
            padding = (alignments[field] + field_size) % 8
            if padding == 0:
                structure.append(Link(field, FINAL))
            else:
                padding_field = draw(fields())
                padding_type = draw(scalars(unique_identifiers, align_to_8=padding))
                types_[padding_field] = padding_type
                structure.append(Link(field, padding_field))
                structure.append(Link(padding_field, FINAL))

    message = UnprovenMessage(next(unique_identifiers), structure, types_)

    try:
        return message.proven()
    except error.RecordFluxError as e:
        e.extend(
            [
                (
                    f"incorrectly generated message:\n {message!r}",
                    error.Subsystem.MODEL,
                    error.Severity.INFO,
                    None,
                )
            ],
        )
        raise
from tests.models import ENUMERATION, ETHERNET_FRAME, MODULAR_INTEGER, RANGE_INTEGER from tests.utils import assert_equal, assert_message_model_error M_NO_REF = UnprovenMessage( "P.No_Ref", [ Link(INITIAL, Field("F1"), length=Number(16)), Link(Field("F1"), Field("F2")), Link(Field("F2"), Field("F3"), LessEqual(Variable("F2"), Number(100)), first=First("F2")), Link( Field("F2"), Field("F4"), GreaterEqual(Variable("F2"), Number(200)), first=First("F2"), ), Link(Field("F3"), FINAL, Equal(Variable("F3"), Variable("ONE"))), Link(Field("F4"), FINAL), ], { Field("F1"): Opaque(), Field("F2"): MODULAR_INTEGER, Field("F3"): ENUMERATION, Field("F4"): RANGE_INTEGER, }, ) M_SMPL_REF = UnprovenMessage( "P.Smpl_Ref",