Exemplo n.º 1
0
def convert_interpreter_value_to_ir(
        interpreter_value: dslx_value.Value) -> ir_value.Value:
    """Converts a DSLX interpreter Value into the equivalent IR Value.

    Bits and enum values become IR bits values built from their underlying
    bits. Arrays and tuples are converted element by element (recursively)
    and then rebuilt with the matching IR aggregate constructor.

    Args:
      interpreter_value: The DSLX interpreter value to translate.

    Returns:
      The IR value mirroring the structure and contents of the input.

    Raises:
      UnsupportedJitConversionError: If the value is not a bits, enum, array,
        or tuple value.
    """
    if interpreter_value.is_bits() or interpreter_value.is_enum():
        return ir_value.Value(interpreter_value.get_bits())

    # Select the IR constructor for the aggregate kind; anything that is
    # neither an array nor a tuple has no IR counterpart here.
    if interpreter_value.is_array():
        build_aggregate = ir_value.Value.make_array
    elif interpreter_value.is_tuple():
        build_aggregate = ir_value.Value.make_tuple
    else:
        raise UnsupportedJitConversionError(
            "Can't convert to JIT value: {}".format(interpreter_value))

    converted_members = [
        convert_interpreter_value_to_ir(member)
        for member in interpreter_value.get_elements()
    ]
    return build_aggregate(converted_members)
Exemplo n.º 2
0
def compare_values(interpreter_value: dslx_value.Value,
                   jit_value: ir_value.Value) -> None:
    """Asserts equality between a DSLX Value and an IR Value.

    Recursively traverses the values (for arrays/tuples) and checks that the
    value kinds, bit counts, bit values, and element counts agree.

    Args:
      interpreter_value: Value that resulted from DSL interpretation.
      jit_value: Value that resulted from JIT-compiled execution.

    Raises:
      JitMiscompareError: If the dslx_value and jit_value are not equivalent,
        including when the JIT value is of a different kind (bits, array, or
        tuple) than the interpreter value.
      UnsupportedJitConversionError: If there is not JIT-supported type
        equivalent for the interpreter value.
    """
    if interpreter_value.is_bits() or interpreter_value.is_enum():
        _compare_bits_values(interpreter_value, jit_value)
    elif interpreter_value.is_array():
        # Explicit raise (not assert) so the check survives `python -O`.
        if not jit_value.is_array():
            raise JitMiscompareError(f'Expected array value: {jit_value!r}')
        _compare_element_values('array', interpreter_value, jit_value)
    elif interpreter_value.is_tuple():
        if not jit_value.is_tuple():
            raise JitMiscompareError(f'Expected tuple value: {jit_value!r}')
        _compare_element_values('tuple', interpreter_value, jit_value)
    else:
        raise UnsupportedJitConversionError(
            'No JIT-supported type equivalent: {}'.format(interpreter_value))


def _compare_bits_values(interpreter_value: dslx_value.Value,
                         jit_value: ir_value.Value) -> None:
    """Checks a bits/enum interpreter value against a JIT value.

    Raises:
      JitMiscompareError: If the JIT value is not bits, or the bit counts or
        bit values differ.
    """
    if not jit_value.is_bits():
        raise JitMiscompareError(f'Expected bits value: {jit_value!r}')

    jit_bits_value = jit_value.get_bits()
    # Internal invariant of the IR value API rather than a miscompare.
    assert isinstance(jit_bits_value, ir_bits.Bits), jit_bits_value
    bit_count = interpreter_value.get_bit_count()
    if bit_count != jit_bits_value.bit_count():
        raise JitMiscompareError(f'Inconsistent bit counts for value -- '
                                 f'interp: {bit_count}, '
                                 f'jit: {jit_bits_value.bit_count()}')

    interpreter_bits_value = interpreter_value.get_bits()
    if interpreter_bits_value != jit_bits_value:
        raise JitMiscompareError(
            'Inconsistent bit values in return value -- '
            'interp: {!r}, jit: {!r}'.format(interpreter_bits_value,
                                             jit_bits_value))


def _compare_element_values(kind: str, interpreter_value: dslx_value.Value,
                            jit_value: ir_value.Value) -> None:
    """Compares element counts, then each element pair, of two aggregates.

    Args:
      kind: 'array' or 'tuple'; used only in the mismatch message.
      interpreter_value: Aggregate value from DSL interpretation.
      jit_value: Aggregate value from JIT-compiled execution (already checked
        by the caller to be of the same kind).

    Raises:
      JitMiscompareError: If the lengths differ or any element pair differs.
      UnsupportedJitConversionError: Propagated from nested comparisons.
    """
    interpreter_elements = interpreter_value.get_elements()
    jit_elements = jit_value.get_elements()
    interp_len = len(interpreter_elements)
    jit_len = len(jit_elements)
    if interp_len != jit_len:
        raise JitMiscompareError(
            f'Inconsistent {kind} lengths in return value -- '
            f'interp: {interp_len}, jit: {jit_len}')

    for interpreter_element, jit_element in zip(interpreter_elements,
                                                jit_elements):
        compare_values(interpreter_element, jit_element)