コード例 #1
0
def pow(dtype, power_dtype=None):
    """
    Returns a :py:class:`~reikna.cluda.Module` with a function of two arguments
    that raises the first argument of type ``dtype`` to the power of the second
    argument of type ``power_dtype`` (a real data type or an integer).

    ``dtype`` is normally a real or complex data type; an integer ``dtype`` is
    accepted only when ``power_dtype`` is given explicitly. If ``power_dtype``
    is real in that case, the base is treated as having type ``power_dtype``.

    If ``power_dtype`` is not given, it defaults to ``dtype`` for a real
    ``dtype`` and to the corresponding real type for a complex ``dtype``.

    Raises ``ValueError`` if ``dtype`` is an integer type and ``power_dtype``
    is not given; raises ``NotImplementedError`` for a complex ``power_dtype``.
    """
    # Resolve the default first, so that the complex-power check below never
    # has to interpret ``None`` as a data type.
    if power_dtype is None:
        if dtypes.is_integer(dtype):
            raise ValueError("Power dtype must be specified for an integer argument")
        elif dtypes.is_real(dtype):
            power_dtype = dtype
        else:
            power_dtype = dtypes.real_for(dtype)

    if dtypes.is_complex(power_dtype):
        raise NotImplementedError("pow() with a complex power is not supported")

    # Real type handed to ``polar``: the real counterpart of a complex base,
    # the base itself if it is real, otherwise the (real) power type;
    # an integer/integer pair falls back to float32.
    if dtypes.is_complex(dtype):
        r_dtype = dtypes.real_for(dtype)
    elif dtypes.is_real(dtype):
        r_dtype = dtype
    elif dtypes.is_real(power_dtype):
        r_dtype = power_dtype
    else:
        r_dtype = numpy.float32

    # An integer base raised to a real power is computed in the real type.
    if dtypes.is_integer(dtype) and dtypes.is_real(power_dtype):
        dtype = power_dtype

    return Module(
        TEMPLATE.get_def('pow'),
        render_kwds=dict(
            dtype=dtype, power_dtype=power_dtype,
            mul_=mul(dtype, dtype), div_=div(dtype, dtype),
            polar_=polar(r_dtype)))
コード例 #2
0
ファイル: helpers.py プロジェクト: ringw/reikna
def get_test_array(shape, dtype, strides=None, no_zeros=False, high=None):
    """
    Create a random array of the given ``shape`` and ``dtype`` for tests.

    Structured dtypes are filled field by field. Integer fields are drawn with
    ``numpy.random.randint`` from ``[low, high)``, where ``low`` is 1 if
    ``no_zeros`` else 0 and ``high`` defaults to 100. Other fields are drawn
    with ``numpy.random.uniform`` from ``[low, high)``, where ``low`` is 0.01 if
    ``no_zeros`` else 0 and ``high`` defaults to 1.0. Complex dtypes get
    independent random real and imaginary parts.

    If ``strides`` is given, the result is re-viewed with ``as_strided``
    using the same shape and the requested strides.
    """
    shape = wrap_in_tuple(shape)
    dtype = dtypes.normalize_type(dtype)

    if dtype.names is None:
        if dtypes.is_integer(dtype):
            lower = 1 if no_zeros else 0
            # will work even with signed chars
            upper = 100 if high is None else high

            def random_part():
                return numpy.random.randint(lower, upper, shape).astype(dtype)
        else:
            lower = 0.01 if no_zeros else 0
            upper = 1.0 if high is None else high

            def random_part():
                return numpy.random.uniform(lower, upper, shape).astype(dtype)

        # The real part is drawn before the imaginary one.
        result = random_part()
        if dtypes.is_complex(dtype):
            result = result + 1j * random_part()
    else:
        result = numpy.empty(shape, dtype)
        for field in dtype.names:
            result[field] = get_test_array(
                shape, dtype[field], no_zeros=no_zeros, high=high)

    if strides is not None:
        result = as_strided(result, result.shape, strides)

    return result
コード例 #3
0
ファイル: helpers.py プロジェクト: xexo7C8/reikna
def diff_is_negligible(m, m_ref, atol=None, rtol=None, verbose=True):
    """
    Check whether the array ``m`` matches the reference array ``m_ref``.

    Structured arrays are compared field by field, applying the same
    ``atol``, ``rtol`` and ``verbose`` settings to every field. Integer arrays
    must match exactly. Other arrays are compared with ``numpy.isclose``;
    ``atol``/``rtol`` default to ``DOUBLE_ATOL``/``DOUBLE_RTOL`` for
    double-precision data and ``SINGLE_ATOL``/``SINGLE_RTOL`` otherwise.

    If a mismatch is found and ``verbose`` is true, the number of differing
    elements and up to the first 10 of them are printed.

    Returns ``True`` if the difference is negligible, ``False`` otherwise.
    Raises ``AssertionError`` if the dtypes of ``m`` and ``m_ref`` differ.
    """
    if m.dtype.names is not None:
        # Forward the caller's settings; otherwise every field would silently
        # be compared with the default tolerances.
        return all(
            diff_is_negligible(
                m[name], m_ref[name], atol=atol, rtol=rtol, verbose=verbose)
            for name in m.dtype.names)

    assert m.dtype == m_ref.dtype

    if dtypes.is_integer(m.dtype):
        close = (m == m_ref)
    else:
        if atol is None:
            atol = DOUBLE_ATOL if dtypes.is_double(m.dtype) else SINGLE_ATOL
        if rtol is None:
            rtol = DOUBLE_RTOL if dtypes.is_double(m.dtype) else SINGLE_RTOL

        close = numpy.isclose(m, m_ref, atol=atol, rtol=rtol)

    if close.all():
        return True

    if verbose:
        # One row of indices per mismatching element; unlike
        # ``vstack(where(...)).T`` this also works for 0-d arrays.
        far_idxs = numpy.argwhere(~close)
        print(("diff_is_negligible() with atol={atol} and rtol={rtol} " +
               "found {diffs} differences, first ones are:").format(
                   atol=atol, rtol=rtol, diffs=str(far_idxs.shape[0])))
        for idx in far_idxs[:10]:
            idx = tuple(idx)
            print("idx: {idx}, test: {test}, ref: {ref}".format(
                idx=idx, test=m[idx], ref=m_ref[idx]))

    return False
コード例 #4
0
ファイル: helpers.py プロジェクト: mgolub2/reikna
def diff_is_negligible(m, m_ref, atol=None, rtol=None):
    """
    Check whether the array ``m`` matches the reference array ``m_ref``.

    Structured arrays are compared field by field, applying the same
    ``atol`` and ``rtol`` to every field. Integer arrays must match exactly.
    Other arrays are compared with ``numpy.isclose``; ``atol``/``rtol``
    default to ``DOUBLE_ATOL``/``DOUBLE_RTOL`` for double-precision data and
    ``SINGLE_ATOL``/``SINGLE_RTOL`` otherwise.

    On a mismatch, the number of differing elements and up to the first 10
    of them are printed.

    Returns ``True`` if the difference is negligible, ``False`` otherwise.
    Raises ``AssertionError`` if the dtypes of ``m`` and ``m_ref`` differ.
    """
    if m.dtype.names is not None:
        # Forward the caller's tolerances; otherwise every field would silently
        # be compared with the defaults.
        return all(
            diff_is_negligible(m[name], m_ref[name], atol=atol, rtol=rtol)
            for name in m.dtype.names)

    assert m.dtype == m_ref.dtype

    if dtypes.is_integer(m.dtype):
        close = (m == m_ref)
    else:
        if atol is None:
            atol = DOUBLE_ATOL if dtypes.is_double(m.dtype) else SINGLE_ATOL
        if rtol is None:
            rtol = DOUBLE_RTOL if dtypes.is_double(m.dtype) else SINGLE_RTOL

        close = numpy.isclose(m, m_ref, atol=atol, rtol=rtol)

    if close.all():
        return True

    # One row of indices per mismatching element; unlike
    # ``vstack(where(...)).T`` this also works for 0-d arrays.
    far_idxs = numpy.argwhere(~close)
    print(
        ("diff_is_negligible() with atol={atol} and rtol={rtol} " +
        "found {diffs} differences, first ones are:").format(
        atol=atol, rtol=rtol, diffs=str(far_idxs.shape[0])))
    for idx in far_idxs[:10]:
        idx = tuple(idx)
        print("idx: {idx}, test: {test}, ref: {ref}".format(
            idx=idx, test=m[idx], ref=m_ref[idx]))

    return False
コード例 #5
0
ファイル: helpers.py プロジェクト: mgolub2/reikna
def get_test_array(shape, dtype, strides=None, no_zeros=False, high=None):
    """
    Create a random test array of the given ``shape`` and ``dtype``.

    Structured dtypes are filled one field at a time with recursive calls.
    Integer data comes from ``numpy.random.randint`` over ``[low, high)``,
    where ``low`` is 1 if ``no_zeros`` else 0 and ``high`` defaults to 100.
    Everything else comes from ``numpy.random.uniform`` over ``[low, high)``,
    where ``low`` is 0.01 if ``no_zeros`` else 0 and ``high`` defaults to 1.0.
    Complex dtypes receive independent random real and imaginary parts.

    When ``strides`` is given, the result is re-viewed through ``as_strided``
    with its own shape and the requested strides.
    """
    shape = wrap_in_tuple(shape)
    dtype = dtypes.normalize_type(dtype)

    if dtype.names is not None:
        result = numpy.empty(shape, dtype)
        for field_name in dtype.names:
            result[field_name] = get_test_array(
                shape, dtype[field_name], no_zeros=no_zeros, high=high)
    else:
        if dtypes.is_integer(dtype):
            # 100 as the upper bound will work even with signed chars
            bounds = (1 if no_zeros else 0, 100 if high is None else high)
            draw = numpy.random.randint
        else:
            bounds = (0.01 if no_zeros else 0, 1.0 if high is None else high)
            draw = numpy.random.uniform

        def sample():
            return draw(bounds[0], bounds[1], shape).astype(dtype)

        if not dtypes.is_complex(dtype):
            result = sample()
        else:
            real_part = sample()
            result = real_part + 1j * sample()

    if strides is not None:
        result = as_strided(result, result.shape, strides)

    return result
コード例 #6
0
ファイル: helpers.py プロジェクト: SyamGadde/reikna
def diff_is_negligible(m, m_ref):
    """
    Return whether ``m`` is close enough to the reference array ``m_ref``.

    Structured arrays are checked field by field. Integer arrays must be
    exactly equal. For other arrays, ``float_diff(m, m_ref)`` must be below
    ``DOUBLE_EPS`` for double-precision data or ``SINGLE_EPS`` otherwise.
    Raises ``AssertionError`` if the dtypes of ``m`` and ``m_ref`` differ.
    """
    field_names = m.dtype.names
    if field_names is not None:
        return all(diff_is_negligible(m[field], m_ref[field]) for field in field_names)

    assert m.dtype == m_ref.dtype

    if dtypes.is_integer(m.dtype):
        delta = m - m_ref
        return (delta == 0).all()

    diff = float_diff(m, m_ref)
    eps = DOUBLE_EPS if dtypes.is_double(m.dtype) else SINGLE_EPS
    return diff < eps
コード例 #7
0
ファイル: helpers.py プロジェクト: ringw/reikna
def diff_is_negligible(m, m_ref):
    """
    Check that array ``m`` agrees with the reference array ``m_ref``.

    For structured dtypes every field must agree. Integer data is compared
    exactly. Floating-point and complex data is accepted when
    ``float_diff(m, m_ref)`` is below ``DOUBLE_EPS`` (double precision) or
    ``SINGLE_EPS`` (otherwise).
    Raises ``AssertionError`` if the dtypes of ``m`` and ``m_ref`` differ.
    """
    if m.dtype.names is not None:
        per_field = (
            diff_is_negligible(m[key], m_ref[key]) for key in m.dtype.names)
        return all(per_field)

    assert m.dtype == m_ref.dtype

    if dtypes.is_integer(m.dtype):
        return ((m - m_ref) == 0).all()

    measured = float_diff(m, m_ref)
    if not dtypes.is_double(m.dtype):
        return measured < SINGLE_EPS
    return measured < DOUBLE_EPS
コード例 #8
0
def exp(dtype):
    """
    Returns a :py:class:`~reikna.cluda.Module` with a single-argument function
    computing the exponent of a value of type ``dtype``.

    ``dtype`` must be a real or complex data type; integer types raise
    ``NotImplementedError``. For complex ``dtype`` the module is rendered
    with a ``polar_unit`` module for the corresponding real type; for real
    ``dtype`` that slot is ``None``.
    """
    if dtypes.is_integer(dtype):
        raise NotImplementedError("exp() of " + str(dtype) + " is not supported")

    polar_unit_ = None if dtypes.is_real(dtype) else polar_unit(dtypes.real_for(dtype))
    template_def = TEMPLATE.get_def('exp')
    return Module(template_def, render_kwds=dict(dtype=dtype, polar_unit_=polar_unit_))
コード例 #9
0
ファイル: functions.py プロジェクト: xexo7C8/reikna
def pow(dtype, exponent_dtype=None, output_dtype=None):
    """
    Returns a :py:class:`~reikna.cluda.Module` with a function of two arguments
    that raises the first argument of type ``dtype``
    to the power of the second argument of type ``exponent_dtype``
    (an integer or real data type).
    If ``exponent_dtype`` or ``output_dtype`` are not given, they default to ``dtype``.
    If ``dtype`` is not the same as ``output_dtype``,
    the input is cast to ``output_dtype`` *before* exponentiation.
    A real ``exponent_dtype`` is replaced by the real type matching ``output_dtype``
    (``output_dtype`` itself if it is real, its real counterpart if it is complex).
    If ``exponent_dtype`` is real but ``output_dtype`` is an integer type,
    a ``ValueError`` is raised (regardless of ``dtype``).
    A complex ``exponent_dtype`` raises ``NotImplementedError``.
    """
    if exponent_dtype is None:
        exponent_dtype = dtype

    if output_dtype is None:
        output_dtype = dtype

    if dtypes.is_complex(exponent_dtype):
        raise NotImplementedError("pow() with a complex exponent is not supported")

    # A real exponent is computed in the precision of the output.
    if dtypes.is_real(exponent_dtype):
        if dtypes.is_complex(output_dtype):
            exponent_dtype = dtypes.real_for(output_dtype)
        elif dtypes.is_real(output_dtype):
            exponent_dtype = output_dtype
        else:
            raise ValueError("pow(integer, float): integer is not supported")

    # Helper modules default to None and are only created when needed.
    kwds = dict(
        dtype=dtype, exponent_dtype=exponent_dtype, output_dtype=output_dtype,
        div_=None, mul_=None, cast_=None, polar_=None)
    if output_dtype != dtype:
        kwds['cast_'] = cast(output_dtype, dtype)
    # mul/div helpers are only provided for an integer exponent combined with
    # a non-real (integer or complex) output.
    if dtypes.is_integer(exponent_dtype) and not dtypes.is_real(output_dtype):
        kwds['mul_'] = mul(output_dtype, output_dtype)
        kwds['div_'] = div(output_dtype, output_dtype)
    if dtypes.is_complex(output_dtype):
        kwds['polar_'] = polar(dtypes.real_for(output_dtype))

    return Module(TEMPLATE.get_def('pow'), render_kwds=kwds)