def setitem_computation(dest, source):
    """
    Returns a computation (created with ``PureParallel.from_trf``) that
    broadcasts ``source`` to ``dest``, where ``dest`` is a GPU array, and
    ``source`` is either a GPU array or a scalar.

    ``source`` is treated as a scalar when it has no ``shape`` attribute
    (e.g. a plain Python ``int``/``float``/``complex``) or when its shape is
    empty; in that case ``transformations.broadcast_param(dest)`` is used.
    Otherwise ``source`` is handled as an array: a ``transformations.copy``
    into ``dest`` is built, and ``transformations.cast(source, dest.dtype)``
    is connected to the copy's ``input`` parameter.
    """
    # Plain Python scalars carry no ``shape`` attribute; treat them as
    # 0-dimensional instead of failing with AttributeError.
    source_shape = getattr(source, "shape", ())
    if len(source_shape) == 0:
        trf = transformations.broadcast_param(dest)
        return PureParallel.from_trf(trf, guiding_array=trf.output)

    # Array source: the copy's input is typed like ``source`` but with
    # ``dest``'s dtype, and the cast transformation bridges the two.
    source_dt = Type.from_value(source).with_dtype(dest.dtype)
    trf = transformations.copy(source_dt, dest)
    comp = PureParallel.from_trf(trf, guiding_array=trf.output)
    cast_trf = transformations.cast(source, dest.dtype)
    comp.parameter.input.connect(cast_trf, cast_trf.output, src_input=cast_trf.input)
    return comp
def test_broadcast_param(some_thr, dtype_to_broadcast):
    """
    Connect ``tr.broadcast_param`` to the ``input`` of a test computation and
    check that the device output matches a 1000-element reference array
    filled with the same scalar parameter.
    """
    aligned = dtypes.align(dtype_to_broadcast)
    scalar = get_test_array(1, aligned)[0]

    expected = numpy.empty((1000,), aligned)
    expected[:] = scalar

    result_dev = some_thr.empty_like(expected)

    computation = get_test_computation(result_dev)
    broadcast = tr.broadcast_param(result_dev)
    computation.parameter.input.connect(
        broadcast, broadcast.output, param=broadcast.param)
    compiled = computation.compile(some_thr)

    compiled(result_dev, scalar)
    assert diff_is_negligible(result_dev.get(), expected)
def test_broadcast_param(some_thr, dtype_to_broadcast):
    """
    Feed a single scalar through ``tr.broadcast_param`` into a test
    computation and compare the device result with a host reference array
    of shape ``(1000,)`` filled with that scalar.
    """
    param_dtype = dtypes.align(dtype_to_broadcast)
    param_value = get_test_array(1, param_dtype)[0]

    reference_shape = (1000,)
    reference = numpy.empty(reference_shape, param_dtype)
    reference[:] = param_value

    out_dev = some_thr.empty_like(reference)

    test_comp = get_test_computation(out_dev)
    bcast_trf = tr.broadcast_param(out_dev)
    input_param = test_comp.parameter.input
    input_param.connect(bcast_trf, bcast_trf.output, param=bcast_trf.param)

    compiled_comp = test_comp.compile(some_thr)
    compiled_comp(out_dev, param_value)

    assert diff_is_negligible(out_dev.get(), reference)
def setitem_computation(dest, source, is_array):
    """
    Build a computation that writes ``source`` into the GPU array ``dest``.

    When ``is_array`` is falsy, ``source`` is handled as a scalar: the result
    is built from ``transformations.broadcast_param(dest)``. Otherwise a
    ``transformations.copy`` into ``dest`` is created and
    ``transformations.cast(source, dest.dtype)`` is connected to the copy's
    ``input`` parameter.
    """
    if not is_array:
        # Scalar source: one parameter broadcast over the whole of ``dest``.
        broadcast = transformations.broadcast_param(dest)
        return PureParallel.from_trf(broadcast, guiding_array=broadcast.output)

    # Array source: the copy's input takes ``source``'s layout with
    # ``dest``'s dtype; the cast transformation bridges from ``source``.
    typed_source = Type.from_value(source).with_dtype(dest.dtype)
    copy_trf = transformations.copy(typed_source, dest)
    computation = PureParallel.from_trf(copy_trf, guiding_array=copy_trf.output)

    caster = transformations.cast(source, dest.dtype)
    computation.parameter.input.connect(
        caster, caster.output, src_input=caster.input)
    return computation
def LweNoiselessTrivialConstant(shape_info):
    """
    Return an ``LweNoiselessTrivial`` computation for ``shape_info`` whose
    ``mus`` parameter is driven by ``transformations.broadcast_param``,
    exposing a single ``mu`` parameter in its place.
    """
    computation = LweNoiselessTrivial(shape_info, shape_info.shape)
    mu_broadcast = transformations.broadcast_param(computation.parameter.mus)
    computation.parameter.mus.connect(
        mu_broadcast, mu_broadcast.output, mu=mu_broadcast.param)
    return computation