Example #1
0
def convert_interpreter_value_to_ir(
        interpreter_value: dslx_value.Value) -> ir_value.Value:
    """Converts a DSLX interpreter value into the equivalent IR value.

    Bits and enum values become IR bits values of the same bit count; arrays
    and tuples are converted element-by-element, recursively.

    Args:
      interpreter_value: The DSLX interpreter value to convert.

    Returns:
      The corresponding IR value.

    Raises:
      UnsupportedJitConversionError: If the value is not a bits, enum, array,
        or tuple value.
    """
    if interpreter_value.is_bits() or interpreter_value.is_enum():
        # Integer value first, then width -- same call order as int_to_bits'
        # positional arguments.
        as_int = interpreter_value.get_bits_value_check_sign()
        width = interpreter_value.get_bit_count()
        return ir_value.Value(int_to_bits(as_int, width))

    if interpreter_value.is_array():
        converted_elements = [
            convert_interpreter_value_to_ir(element)
            for element in interpreter_value.array_payload.elements
        ]
        return ir_value.Value.make_array(converted_elements)

    if interpreter_value.is_tuple():
        converted_members = [
            convert_interpreter_value_to_ir(member)
            for member in interpreter_value.tuple_members
        ]
        return ir_value.Value.make_tuple(converted_members)

    raise UnsupportedJitConversionError(
        f"Can't convert to JIT value: {interpreter_value}")
Example #2
0
def _value_compatible_with_type(module: ast.Module, type_: ConcreteType,
                                value: Value) -> bool:
    """Returns whether value is compatible with type_ (recursively).

    Args:
      module: Module passed to `get_nominal_type` when comparing an enum value
        against an enum type.
      type_: The concrete type to check against.
      value: The interpreter value to check.

    Returns:
      True if value is compatible with type_, False otherwise.

    Raises:
      NotImplementedError: If no rule below covers the (type_, value)
        combination.
    """
    assert isinstance(value, Value), value

    if isinstance(type_, TupleType) and value.is_tuple():
        member_types = tuple(type_.get_unnamed_members())
        members = tuple(value.tuple_members)
        # zip() silently truncates to the shorter input, so an arity mismatch
        # has to be rejected explicitly; otherwise a tuple value with fewer (or
        # more) members than the type would be reported as compatible.
        if len(member_types) != len(members):
            return False
        return all(
            _value_compatible_with_type(module, ct, m)
            for ct, m in zip(member_types, members))

    if isinstance(type_, ArrayType) and value.is_array():
        et = type_.get_element_type()
        return all(
            _value_compatible_with_type(module, et, m)
            for m in value.array_payload.elements)

    if isinstance(type_, EnumType) and value.tag == Tag.ENUM:
        return type_.get_nominal_type(module) == value.type_

    # Bits values must match both signedness (via the tag) and bit width.
    if isinstance(type_,
                  BitsType) and not type_.signed and value.tag == Tag.UBITS:
        return value.bits_payload.bit_count == type_.get_total_bit_count()

    if isinstance(type_, BitsType) and type_.signed and value.tag == Tag.SBITS:
        return value.bits_payload.bit_count == type_.get_total_bit_count()

    # An enum value is accepted where a bits type is expected if the enum's
    # signedness and width agree with the bits type.
    if value.tag == Tag.ENUM and isinstance(type_, BitsType):
        return (value.type_.get_signedness() == type_.signed and
                value.bits_payload.bit_count == type_.get_total_bit_count())

    # An array value is accepted for an unsigned bits type when its flattened
    # bit count equals the type's total bit count.
    if value.tag == Tag.ARRAY and is_ubits(type_):
        flat_bit_count = value.array_payload.flatten().bits_payload.bit_count
        return flat_bit_count == type_.get_total_bit_count()

    # A bits value is accepted for an enum type when signedness (taken from
    # the SBITS tag) and width agree.
    if isinstance(type_, EnumType) and value.is_bits():
        return (type_.signed == (value.tag == Tag.SBITS)
                and type_.get_total_bit_count() == value.get_bit_count())

    raise NotImplementedError(type_, value)
Example #3
0
def compare_values(interpreter_value: dslx_value.Value,
                   jit_value: ir_value.Value) -> None:
    """Asserts equality between a DSLX Value and an IR Value.

    Recursively traverses the values (for arrays/tuples) and makes assertions
    about value and length properties.

    Args:
      interpreter_value: Value that resulted from DSL interpretation.
      jit_value: Value that resulted from JIT-compiled execution.

    Raises:
      JitMiscompareError: If the dslx_value and jit_value are not equivalent.
      UnsupportedJitConversionError: If there is not JIT-supported type
        equivalent for the interpreter value.
      AssertionError: If jit_value is not the same kind (bits/array/tuple) as
        interpreter_value.
    """
    if interpreter_value.is_bits() or interpreter_value.is_enum():
        assert jit_value.is_bits(), f'Expected bits value: {jit_value!r}'

        jit_bits = jit_value.get_bits()
        bit_count = interpreter_value.get_bit_count()
        if bit_count != jit_bits.bit_count():
            raise JitMiscompareError(f'Inconsistent bit counts for value -- '
                                     f'interp: {bit_count}, '
                                     f'jit: {jit_bits.bit_count()}')

        # Interpret both sides with the same signedness so the integer values
        # are directly comparable. Anything that is not ubits (including
        # enums) is read as signed on both sides.
        if interpreter_value.is_ubits():
            interpreter_bits_value = interpreter_value.get_bits_value()
            jit_bits_value = bits_to_int(jit_bits, signed=False)
        else:
            interpreter_bits_value = interpreter_value.get_bits_value_signed()
            jit_bits_value = bits_to_int(jit_bits, signed=True)

        if interpreter_bits_value != jit_bits_value:
            raise JitMiscompareError(
                f'Inconsistent bit values in return value -- '
                f'interp: {interpreter_bits_value!r}, jit: {jit_bits_value!r}')

    elif interpreter_value.is_array():
        assert jit_value.is_array(), f'Expected array value: {jit_value!r}'
        _compare_sequence_values('array',
                                 interpreter_value.array_payload.elements,
                                 jit_value.get_elements())

    elif interpreter_value.is_tuple():
        assert jit_value.is_tuple(), f'Expected tuple value: {jit_value!r}'
        _compare_sequence_values('tuple', interpreter_value.tuple_members,
                                 jit_value.get_elements())

    else:
        raise UnsupportedJitConversionError(
            f'No JIT-supported type equivalent: {interpreter_value}')


def _compare_sequence_values(kind: str, interpreter_values,
                             jit_values) -> None:
    """Checks the lengths match, then recursively compares elements pairwise.

    Args:
      kind: 'array' or 'tuple'; used only in the error message.
      interpreter_values: Elements of the DSLX interpreter aggregate.
      jit_values: Elements of the JIT (IR) aggregate.

    Raises:
      JitMiscompareError: If the lengths differ or any element pair differs.
    """
    interp_len = len(interpreter_values)
    jit_len = len(jit_values)
    if interp_len != jit_len:
        raise JitMiscompareError(
            f'Inconsistent {kind} lengths in return value -- '
            f'interp: {interp_len}, jit: {jit_len}')

    for interpreter_element, jit_element in zip(interpreter_values,
                                                jit_values):
        compare_values(interpreter_element, jit_element)