def convert_interpreter_value_to_ir(
        interpreter_value: dslx_value.Value) -> ir_value.Value:
    """Builds the IR Value that corresponds to a DSLX interpreter Value.

    Bits and enum values are turned into a bits-backed IR value; arrays and
    tuples are converted member by member through recursion.

    Args:
      interpreter_value: The DSLX value to translate.

    Returns:
      The IR value built from interpreter_value.

    Raises:
      UnsupportedJitConversionError: If interpreter_value is not a bits, enum,
        array or tuple value.
    """
    # Scalar case: bits and enums share the same bits-backed representation.
    if interpreter_value.is_bits() or interpreter_value.is_enum():
        raw = interpreter_value.get_bits_value_check_sign()
        width = interpreter_value.get_bit_count()
        return ir_value.Value(int_to_bits(raw, width))

    # Aggregate cases: convert every member, then wrap the converted list.
    if interpreter_value.is_array():
        converted = [
            convert_interpreter_value_to_ir(element)
            for element in interpreter_value.array_payload.elements
        ]
        return ir_value.Value.make_array(converted)

    if interpreter_value.is_tuple():
        converted = [
            convert_interpreter_value_to_ir(member)
            for member in interpreter_value.tuple_members
        ]
        return ir_value.Value.make_tuple(converted)

    raise UnsupportedJitConversionError(
        "Can't convert to JIT value: {}".format(interpreter_value))
def _value_compatible_with_type(module: ast.Module, type_: ConcreteType,
                                value: Value) -> bool:
    """Reports whether `value` can stand in for `type_`, checking recursively.

    The (type, value) pairings are tried in a fixed order; the first pairing
    that applies decides the result.

    Args:
      module: Module handed to `type_.get_nominal_type()` when matching enums.
      type_: The concrete type to check against.
      value: The value being checked.

    Returns:
      True if the value is compatible with the type, False otherwise.

    Raises:
      NotImplementedError: If no known (type, value) pairing applies.
    """
    assert isinstance(value, Value), value

    # Tuple type vs. tuple value: every paired member must be compatible.
    # NOTE(review): zip() stops at the shorter sequence, so a member-count
    # mismatch is not detected here -- confirm whether callers rely on that.
    if isinstance(type_, TupleType) and value.is_tuple():
        member_pairs = zip(type_.get_unnamed_members(), value.tuple_members)
        for member_type, member_value in member_pairs:
            if not _value_compatible_with_type(module, member_type,
                                               member_value):
                return False
        return True

    # Array type vs. array value: every element must match the element type.
    if isinstance(type_, ArrayType) and value.is_array():
        element_type = type_.get_element_type()
        for element in value.array_payload.elements:
            if not _value_compatible_with_type(module, element_type, element):
                return False
        return True

    # Enum type vs. enum value: the nominal types must be equal.
    if isinstance(type_, EnumType) and value.tag == Tag.ENUM:
        return type_.get_nominal_type(module) == value.type_

    # Bits type vs. bits value of the same signedness: widths must agree.
    if isinstance(type_, BitsType):
        wanted_tag = Tag.SBITS if type_.signed else Tag.UBITS
        if value.tag == wanted_tag:
            return value.bits_payload.bit_count == type_.get_total_bit_count()

    # Enum value vs. bits type: signedness and width must both agree.
    if value.tag == Tag.ENUM and isinstance(type_, BitsType):
        same_signedness = value.type_.get_signedness() == type_.signed
        return (same_signedness and
                value.bits_payload.bit_count == type_.get_total_bit_count())

    # Array value vs. unsigned bits type: compare the flattened bit count.
    if value.tag == Tag.ARRAY and is_ubits(type_):
        flattened = value.array_payload.flatten()
        return flattened.bits_payload.bit_count == type_.get_total_bit_count()

    # Bits value vs. enum type: signedness and width must both agree.
    if isinstance(type_, EnumType) and value.is_bits():
        value_is_signed = value.tag == Tag.SBITS
        return (type_.signed == value_is_signed and
                type_.get_total_bit_count() == value.get_bit_count())

    raise NotImplementedError(type_, value)
def _compare_member_lists(interpreter_members, jit_members,
                          kind: str) -> None:
    """Checks two member lists have equal length, then compares them pairwise.

    Args:
      interpreter_members: Members of the interpreter-side array or tuple.
      jit_members: Members of the JIT-side array or tuple.
      kind: Noun used in the length-mismatch message ('array' or 'tuple').

    Raises:
      JitMiscompareError: If the lengths differ or any pair of members differs.
    """
    interp_len = len(interpreter_members)
    jit_len = len(jit_members)
    if interp_len != jit_len:
        raise JitMiscompareError(
            f'Inconsistent {kind} lengths in return value -- '
            f'interp: {interp_len}, jit: {jit_len}')

    for interpreter_member, jit_member in zip(interpreter_members,
                                              jit_members):
        compare_values(interpreter_member, jit_member)


def compare_values(interpreter_value: dslx_value.Value,
                   jit_value: ir_value.Value) -> None:
    """Asserts equality between a DSLX Value and an IR Value.

    Recursively traverses the values (for arrays/tuples) and makes assertions
    about value and length properties.

    Args:
      interpreter_value: Value that resulted from DSL interpretation.
      jit_value: Value that resulted from JIT-compiled execution.

    Raises:
      JitMiscompareError: If interpreter_value and jit_value are not
        equivalent (bit count, bit value, or aggregate length differs).
      UnsupportedJitConversionError: If there is no JIT-supported type
        equivalent for the interpreter value.
      AssertionError: If jit_value is not the same kind of value (bits, array
        or tuple) as interpreter_value.
    """
    if interpreter_value.is_bits() or interpreter_value.is_enum():
        assert jit_value.is_bits(), f'Expected bits value: {jit_value!r}'
        # Keep the parameter intact; work on the extracted bits separately.
        jit_bits = jit_value.get_bits()

        bit_count = interpreter_value.get_bit_count()
        jit_bit_count = jit_bits.bit_count()
        if bit_count != jit_bit_count:
            raise JitMiscompareError('Inconsistent bit counts for value -- '
                                     f'interp: {bit_count}, '
                                     f'jit: {jit_bit_count}')

        # Interpret both sides with the signedness of the interpreter value.
        if interpreter_value.is_ubits():
            interpreter_bits_value = interpreter_value.get_bits_value()
            jit_bits_value = bits_to_int(jit_bits, signed=False)
        else:
            interpreter_bits_value = interpreter_value.get_bits_value_signed()
            jit_bits_value = bits_to_int(jit_bits, signed=True)

        if interpreter_bits_value != jit_bits_value:
            raise JitMiscompareError(
                'Inconsistent bit values in return value -- '
                'interp: {!r}, jit: {!r}'.format(interpreter_bits_value,
                                                 jit_bits_value))
    elif interpreter_value.is_array():
        assert jit_value.is_array(), f'Expected array value: {jit_value!r}'
        _compare_member_lists(interpreter_value.array_payload.elements,
                              jit_value.get_elements(), 'array')
    elif interpreter_value.is_tuple():
        # The f-prefix matters: without it the placeholder is printed
        # literally instead of the offending value.
        assert jit_value.is_tuple(), f'Expected tuple value: {jit_value!r}'
        _compare_member_lists(interpreter_value.tuple_members,
                              jit_value.get_elements(), 'tuple')
    else:
        raise UnsupportedJitConversionError(
            'No JIT-supported type equivalent: {}'.format(interpreter_value))