def convert_interpreter_value_to_ir(
    interpreter_value: dslx_value.Value) -> ir_value.Value:
  """Converts a DSLX interpreter value into the equivalent IR value.

  Bits and enum values become IR bits values. Arrays and tuples are converted
  element by element (recursively) and rebuilt with the matching IR aggregate
  constructor.

  Args:
    interpreter_value: The DSLX value to convert.

  Returns:
    An IR value with the same structure and leaf bits as the input.

  Raises:
    UnsupportedJitConversionError: If the value is not bits, enum, array, or
      tuple.
  """
  if interpreter_value.is_bits() or interpreter_value.is_enum():
    return ir_value.Value(interpreter_value.get_bits())

  # Pick the IR constructor for the aggregate kind up front; the element
  # conversion below is shared by arrays and tuples.
  if interpreter_value.is_array():
    make_aggregate = ir_value.Value.make_array
  elif interpreter_value.is_tuple():
    make_aggregate = ir_value.Value.make_tuple
  else:
    raise UnsupportedJitConversionError(
        "Can't convert to JIT value: {}".format(interpreter_value))

  converted_elements = [
      convert_interpreter_value_to_ir(element)
      for element in interpreter_value.get_elements()
  ]
  return make_aggregate(converted_elements)
def compare_values(interpreter_value: dslx_value.Value,
                   jit_value: ir_value.Value) -> None:
  """Asserts equality between a DSLX Value and an IR Value.

  Recursively traverses the values (for arrays/tuples) and makes assertions
  about value and length properties.

  Args:
    interpreter_value: Value that resulted from DSL interpretation.
    jit_value: Value that resulted from JIT-compiled execution.

  Raises:
    JitMiscompareError: If the dslx_value and jit_value are not equivalent.
    UnsupportedJitConversionError: If there is not JIT-supported type
      equivalent for the interpreter value.
    AssertionError: If jit_value is not the same kind (bits/array/tuple) as
      interpreter_value. These are `assert` checks, so they are skipped when
      Python runs with -O.
  """
  if interpreter_value.is_bits() or interpreter_value.is_enum():
    _compare_bits(interpreter_value, jit_value)
  elif interpreter_value.is_array():
    assert jit_value.is_array(), f'Expected array value: {jit_value!r}'
    _compare_elements('array', interpreter_value, jit_value)
  elif interpreter_value.is_tuple():
    # Must be an f-string so the offending JIT value is interpolated.
    assert jit_value.is_tuple(), f'Expected tuple value: {jit_value!r}'
    _compare_elements('tuple', interpreter_value, jit_value)
  else:
    raise UnsupportedJitConversionError(
        'No JIT-supported type equivalent: {}'.format(interpreter_value))


def _compare_bits(interpreter_value: dslx_value.Value,
                  jit_value: ir_value.Value) -> None:
  """Compares a bits/enum interpreter value against a JIT bits value.

  Raises:
    JitMiscompareError: If the bit counts or the bit values differ.
  """
  assert jit_value.is_bits(), f'Expected bits value: {jit_value!r}'
  jit_bits_value = jit_value.get_bits()
  assert isinstance(jit_bits_value, ir_bits.Bits), jit_bits_value

  # Check the width first so a width mismatch gets its own clearer message.
  bit_count = interpreter_value.get_bit_count()
  if bit_count != jit_bits_value.bit_count():
    raise JitMiscompareError(f'Inconsistent bit counts for value -- '
                             f'interp: {bit_count}, '
                             f'jit: {jit_bits_value.bit_count()}')

  interpreter_bits_value = interpreter_value.get_bits()
  if interpreter_bits_value != jit_bits_value:
    raise JitMiscompareError(
        'Inconsistent bit values in return value -- '
        'interp: {!r}, jit: {!r}'.format(interpreter_bits_value,
                                         jit_bits_value))


def _compare_elements(kind: str, interpreter_value: dslx_value.Value,
                      jit_value: ir_value.Value) -> None:
  """Compares two aggregates (array or tuple) element by element.

  Args:
    kind: Aggregate name used in the error message ('array' or 'tuple').
    interpreter_value: Aggregate from DSL interpretation.
    jit_value: Aggregate from JIT-compiled execution.

  Raises:
    JitMiscompareError: If the element counts differ, or (via the recursive
      compare_values call) any pair of elements differs.
  """
  interpreter_elements = interpreter_value.get_elements()
  jit_elements = jit_value.get_elements()
  interp_len = len(interpreter_elements)
  jit_len = len(jit_elements)
  # Length is checked before zip(), which would otherwise silently truncate.
  if interp_len != jit_len:
    raise JitMiscompareError(
        f'Inconsistent {kind} lengths in return value -- '
        f'interp: {interp_len}, jit: {jit_len}')
  for interpreter_element, jit_element in zip(interpreter_elements,
                                              jit_elements):
    compare_values(interpreter_element, jit_element)