コード例 #1
0
ファイル: test_jax.py プロジェクト: brandonwillard/aesara
def test_jax_CAReduce():
    """Pass `aet_sum`, `prod` and `aet_all` reduction graphs to `compare_jax_and_py`.

    A vector input is reduced with `aet_sum` over ``axis=None``.  A matrix
    input is reduced with `aet_sum` over axis 0 and axis 1, and a second
    matrix input is reduced with `prod` over axis 0 and with `aet_all`.
    """

    def vec_data():
        # A fresh array per call: the test value and the evaluation input
        # are always distinct objects.
        return np.r_[1, 2, 3].astype(config.floatX)

    def mat_data():
        return np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)

    # Vector input, full reduction.
    vec = vector("a")
    vec.tag.test_value = vec_data()
    reduced = aet_sum(vec, axis=None)
    compare_jax_and_py(FunctionGraph([vec], [reduced]), [vec_data()])

    # Matrix input, sum over each axis in turn.
    mat = matrix("a")
    mat.tag.test_value = mat_data()
    for axis in (0, 1):
        reduced = aet_sum(mat, axis=axis)
        compare_jax_and_py(FunctionGraph([mat], [reduced]), [mat_data()])

    # New matrix input, product over axis 0 followed by `aet_all`.  The
    # reductions are built lazily so each graph is created right before
    # it is compared.
    mat = matrix("a")
    mat.tag.test_value = mat_data()
    for build_reduction in (lambda m: prod(m, axis=0), aet_all):
        reduced = build_reduction(mat)
        compare_jax_and_py(FunctionGraph([mat], [reduced]), [mat_data()])
コード例 #2
0
def test_jax_logp():
    """Pass a switch-guarded normal log-density graph to `compare_jax_and_py`.

    The graph computes ``(-tau * (value - mu)**2 + log(tau / pi / 2)) / 2``
    and selects it through `aet.switch` when every condition (here only
    ``sigma > 0``) holds, falling back to ``-inf`` otherwise.
    """

    def tagged_vector(name, make_value):
        # Create the variable first, then compute and attach its test value.
        var = vector(name)
        var.tag.test_value = make_value().astype(config.floatX)
        return var

    mu = tagged_vector("mu", lambda: np.r_[0.0, 0.0])
    tau = tagged_vector("tau", lambda: np.r_[1.0, 1.0])
    sigma = tagged_vector("sigma", lambda: 1.0 / get_test_value(tau))
    value = tagged_vector("value", lambda: np.r_[0.1, -10])

    quadratic_term = -tau * (value - mu)**2
    normalizer = log(tau / np.pi / 2.0)
    density = (quadratic_term + normalizer) / 2.0

    # Collapse every condition to a scalar truth value, then combine them.
    checks = []
    for cond in [sigma > 0]:
        checks.append(aet_all(1 * cond))
    all_hold = aet_all(checks)

    guarded_density = aet.switch(all_hold, density, -np.inf)

    fgraph = FunctionGraph([mu, tau, sigma, value], [guarded_density])
    test_inputs = [get_test_value(inp) for inp in fgraph.inputs]

    compare_jax_and_py(fgraph, test_inputs)
コード例 #3
0
def bincount(x, weights=None, minlength=None, assert_nonneg=False):
    """Build a graph counting how often each integer value occurs in `x`.

    The output has ``max(x) + 1`` bins (cast to ``int64``); when
    `minlength` is given, the bin count is the larger of that and
    `minlength`. Bin ``n`` accumulates ``1`` for every position ``i``
    where ``x[i] == n``, or ``weights[i]`` when `weights` is supplied.

    Parameters
    ----------
    x : 1 dimension, nonnegative ints
    weights : array of the same shape as x with corresponding weights.
        Optional. When given, the output uses the dtype of `weights`;
        otherwise it uses the dtype of `x`.
    minlength : A minimum number of bins for the output array.
        Optional.
    assert_nonneg : A flag that inserts an assert_op to check if
        every input x is nonnegative.
        Optional.

    Raises
    ------
    TypeError
        If `x` is not one-dimensional.


    .. versionadded:: 0.6

    """
    if x.ndim != 1:
        raise TypeError("Inputs must be of dimension 1.")

    if assert_nonneg:
        nonneg_check = Assert("Input to bincount has negative values!")
        x = nonneg_check(x, aet_all(x >= 0))

    n_bins = aet.cast(x.max() + 1, "int64")
    if minlength is not None:
        n_bins = maximum(n_bins, minlength)

    # NOTE(review): unweighted counts are stored in the dtype of `x`;
    # presumably a narrow integer dtype could overflow -- confirm.
    if weights is None:
        increments, out_dtype = 1, x.dtype
    else:
        increments, out_dtype = weights, weights.dtype

    # Note: inc_subtensor(out[x], ...) is deliberately avoided here, since
    # out[x] raises an exception if the indices (x) are int8.
    bins = aet.zeros([n_bins], dtype=out_dtype)
    return advanced_inc_subtensor1(bins, increments, x)
コード例 #4
0
def broadcast_shape_iter(arrays, **kwargs):
    """Compute, symbolically, the shape obtained by broadcasting `arrays`.

    Parameters
    ----------
    arrays: Iterable[TensorVariable] or Iterable[Tuple[Variable]]
        The tensors--or, when `arrays_are_shapes` is set, the shape
        tuples--that are broadcast against each other.
        XXX: This is iterated more than once and never copied, so do not
        pass a generator/iterator.
    arrays_are_shapes: bool (Optional)
        Whether the elements of `arrays` are shape tuples rather than
        tensors.  In that mode an entry counts as broadcastable only when
        it equals `1` itself or exposes a `value` attribute equal to `1`.

    Returns
    -------
    tuple
        One entry per broadcast dimension: the shared constant `1` when
        every input is broadcastable there, otherwise the last
        non-broadcastable entry, wrapped in an `Assert` when its equality
        with the other entries cannot be established from the graphs.

    """
    unit = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)

    # Normalize every input to a shape tuple left-padded with `unit` up to
    # the largest rank, replacing broadcastable entries by `unit`.
    padded_shapes = []
    if kwargs.pop("arrays_are_shapes", False):
        rank = max(len(shape) for shape in arrays)

        for shape in arrays:
            padding = (unit, ) * (rank - len(shape))
            dims = tuple(unit if getattr(dim, "value", dim) == 1 else dim
                         for dim in shape)
            padded_shapes.append(padding + dims)
    else:
        rank = max(arr.ndim for arr in arrays)

        for arr in arrays:
            padding = (unit, ) * (rank - arr.ndim)
            dims = tuple(unit if is_bcast else dim
                         for dim, is_bcast in zip(arr.shape,
                                                  arr.broadcastable))
            padded_shapes.append(padding + dims)

    out_dims = []

    for column in zip(*padded_shapes):
        candidates = [dim for dim in column if dim != unit]

        if not candidates:
            # Every array was broadcastable in this dimension
            out_dims.append(unit)
            continue

        # The last non-broadcastable entry is the reference; either it is
        # the only one, or all the others must turn out equal to it.
        *others, ref_dim = candidates

        # TODO FIXME: This is a largely deficient means of comparing graphs
        # (and especially shapes)
        unresolved = [
            dim for dim in others
            if not equal_computations([ref_dim], [dim])
        ]

        if not unresolved:
            out_dims.append(ref_dim)
            continue

        # Equality can't be decided here, so it is asserted in the graph
        # and the error handling is deferred to evaluation time.
        broadcast_assert = Assert("Could not broadcast dimensions")
        pairwise_checks = [
            or_(eq(dim, unit), eq(ref_dim, dim)) for dim in unresolved
        ]
        all_compatible = aet_all(pairwise_checks)
        can_broadcast = or_(eq(ref_dim, unit), all_compatible)
        out_dims.append(broadcast_assert(ref_dim, can_broadcast))

    return tuple(out_dims)