Ejemplo n.º 1
0
def setitem_computation(dest, source):
    """
    Returns a computation (not yet compiled) that broadcasts ``source`` to ``dest``,
    where ``dest`` is a GPU array, and ``source`` is either a GPU array or a scalar.

    :param dest: destination array; its shape and dtype define the output.
    :param source: the value to write. It is treated as a scalar if it is a
        plain Python/NumPy number or has an empty ``shape``. In that case the
        result is built from ``transformations.broadcast_param(dest)`` and
        ``source`` itself is not inspected further. Otherwise ``source`` is
        treated as an array. Its elements are cast to ``dest.dtype`` and
        copied into ``dest``.
    :returns: a ``PureParallel`` computation. Call ``.compile(thread)`` on it
        before use.
    """
    import numbers

    # Plain Python numbers have no ``shape`` attribute. Check for them first
    # so the documented "scalar" case does not raise AttributeError.
    is_scalar = isinstance(source, numbers.Number) or len(source.shape) == 0
    if is_scalar:
        trf = transformations.broadcast_param(dest)
        return PureParallel.from_trf(trf, guiding_array=trf.output)

    # The copy kernel reads data that already has ``dest``'s dtype, so its
    # input type is ``source``'s type with the dtype swapped.
    source_dt = Type.from_value(source).with_dtype(dest.dtype)
    trf = transformations.copy(source_dt, dest)
    comp = PureParallel.from_trf(trf, guiding_array=trf.output)

    # Feed the copy's ``input`` through a cast from ``source``'s own dtype.
    # The computation then exposes ``src_input`` instead of ``input``.
    cast_trf = transformations.cast(source, dest.dtype)
    comp.parameter.input.connect(cast_trf, cast_trf.output, src_input=cast_trf.input)
    return comp
Ejemplo n.º 2
0
def test_broadcast_param(some_thr, dtype_to_broadcast):
    """
    Check that ``tr.broadcast_param`` fills every element of the output array
    with one scalar parameter value.
    """
    aligned_dtype = dtypes.align(dtype_to_broadcast)
    scalar = get_test_array(1, aligned_dtype)[0]

    # Host-side reference: 1000 copies of the same scalar.
    expected = numpy.empty((1000,), aligned_dtype)
    expected[:] = scalar

    result_dev = some_thr.empty_like(expected)

    computation = get_test_computation(result_dev)
    broadcast = tr.broadcast_param(result_dev)
    # Replace the computation's ``input`` with a broadcast of a scalar; the
    # scalar becomes a new argument named ``param``.
    computation.parameter.input.connect(
        broadcast, broadcast.output, param=broadcast.param)
    compiled = computation.compile(some_thr)

    compiled(result_dev, scalar)
    assert diff_is_negligible(result_dev.get(), expected)
Ejemplo n.º 3
0
def test_broadcast_param(some_thr, dtype_to_broadcast):
    """Broadcasting a scalar parameter must fill the whole output array."""
    size = 1000
    dt = dtypes.align(dtype_to_broadcast)
    value = get_test_array(1, dt)[0]

    reference = numpy.empty((size, ), dt)
    reference[...] = value
    out_gpu = some_thr.empty_like(reference)

    comp = get_test_computation(out_gpu)
    trf = tr.broadcast_param(out_gpu)
    # After connecting, the scalar argument of the computation is ``param``.
    comp.parameter.input.connect(trf, trf.output, param=trf.param)
    comp_c = comp.compile(some_thr)

    comp_c(out_gpu, value)
    assert diff_is_negligible(out_gpu.get(), reference)
Ejemplo n.º 4
0
def setitem_computation(dest, source, is_array):
    """
    Returns a computation (not yet compiled) that broadcasts ``source`` to ``dest``,
    where ``dest`` is a GPU array, and ``source`` is either a GPU array or a scalar.

    :param dest: destination array; its shape and dtype define the output.
    :param source: the value to write. It is only inspected when ``is_array``
        is true, where its type is used to build the copy and cast
        transformations. In the scalar case it is not referenced.
    :param is_array: ``True`` if ``source`` is an array whose elements are
        copied (with a cast to ``dest.dtype``). ``False`` if ``source`` is a
        scalar to be broadcast via ``transformations.broadcast_param``.
    :returns: a ``PureParallel`` computation. Call ``.compile(thread)`` on it
        before use.
    """
    if not is_array:
        trf = transformations.broadcast_param(dest)
        return PureParallel.from_trf(trf, guiding_array=trf.output)

    # The copy kernel reads data that already has ``dest``'s dtype, so its
    # input type is ``source``'s type with the dtype swapped.
    source_dt = Type.from_value(source).with_dtype(dest.dtype)
    trf = transformations.copy(source_dt, dest)
    comp = PureParallel.from_trf(trf, guiding_array=trf.output)

    # Feed the copy's ``input`` through a cast from ``source``'s own dtype.
    # The computation then exposes ``src_input`` instead of ``input``.
    cast_trf = transformations.cast(source, dest.dtype)
    comp.parameter.input.connect(cast_trf,
                                 cast_trf.output,
                                 src_input=cast_trf.input)
    return comp
Ejemplo n.º 5
0
def LweNoiselessTrivialConstant(shape_info):
    """
    Build an ``LweNoiselessTrivial`` computation whose ``mus`` parameter is
    driven by ``transformations.broadcast_param``. As a result, the computation
    takes a single value under the new parameter name ``mu`` instead of ``mus``.
    """
    computation = LweNoiselessTrivial(shape_info, shape_info.shape)
    mus_param = computation.parameter.mus
    broadcast = transformations.broadcast_param(mus_param)
    mus_param.connect(broadcast, broadcast.output, mu=broadcast.param)
    return computation