def fn(x):
  """Kahan-sum reduce of `x` over `dims`, returning component `output_idx`.

  Pairs `x` with a zero compensation tensor and reduces both with the
  enclosing `kahan_sum_reducer`. NOTE(review): `dtype`, `dims` and
  `output_idx` are captured from an enclosing loop (hence the pylint
  disables) — confirm they are bound as intended at trace time.
  """
  zero = array_ops.zeros([], dtype)  # pylint: disable=cell-var-from-loop
  concrete_reducer = kahan_sum_reducer.get_concrete_function(
      (zero, zero), (zero, zero))
  reduced = xla.variadic_reduce(
      (x, array_ops.zeros_like(x)),
      init_value=(zero, zero),
      dimensions_to_reduce=dims,
      reducer=concrete_reducer)
  return reduced[output_idx]
def reduce(*values, dimensions_to_reduce):
  """Apply `reducer_func` to `values` along `dimensions_to_reduce`.

  NOTE(review): `init_val_1`, `init_val_2` and `reducer_func` are captured
  from an enclosing loop scope (hence the pylint disables).
  """
  initial = (
      init_val_1,  # pylint: disable=cell-var-from-loop
      init_val_2,  # pylint: disable=cell-var-from-loop
  )
  return xla.variadic_reduce(
      values,
      initial,
      dimensions_to_reduce=dimensions_to_reduce,
      reducer=reducer_func)  # pylint: disable=cell-var-from-loop
def reduce(values, *, dimensions_to_reduce):
  """Single-operand variadic reduce; dispatches on the `is_v2` flag.

  Both branches take identical arguments, so select the op once and make
  one call. Returns the sole component of the reduce result.
  """
  reduce_op = (
      xla.variadic_reduce if is_v2 else gen_xla_ops.xla_variadic_reduce)
  return reduce_op(
      (values,),
      (init_val,),  # pylint: disable=cell-var-from-loop
      dimensions_to_reduce=dimensions_to_reduce,
      reducer=reducer_func)[0]  # pylint: disable=cell-var-from-loop
def _xla_reduce(operands, inits, axis):
  """JIT-ed wrapper for TF `xla.variadic_reduce(..., reducer)`.

  NOTE(review): `reducer` is a free variable from the enclosing scope;
  it is traced here with `inits` as both argument tuples.
  """
  from tensorflow.compiler.tf2xla.python import xla  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
  concrete_reducer = tf.function(reducer).get_concrete_function(inits, inits)
  result = xla.variadic_reduce(
      operands,
      init_values=inits,
      dimensions_to_reduce=axis,
      reducer=concrete_reducer)
  # Graph mode: variadic reduce doesn't specify output shapes. Patch that.
  merged = operands[0].shape
  for operand in operands:
    merged = tensorshape_util.merge_with(merged, operand.shape)
  reduced_shape = tuple(
      dim for i, dim in enumerate(merged) if i not in axis)
  for part in result:
    tensorshape_util.set_shape(part, reduced_shape)
  return result
def reduce_with_shapes(shape1, shape2, shape3, dimensions_to_reduce=(1,)):
  """Build a 3-operand variadic reduce over placeholders of given shapes.

  Operands are (float32, int32, int32) placeholders with the given shapes;
  init values are scalar placeholders of matching dtypes. Reduced with the
  enclosing `reducer_func` along `dimensions_to_reduce`.
  """
  inputs = (
      array_ops.placeholder(np.float32, shape=shape1),
      array_ops.placeholder(np.int32, shape=shape2),
      array_ops.placeholder(np.int32, shape=shape3),
  )
  init_values = (
      array_ops.placeholder(np.float32, shape=()),
      array_ops.placeholder(np.int32, shape=()),
      array_ops.placeholder(np.int32, shape=()),
  )
  return xla.variadic_reduce(
      inputs,
      init_values,
      dimensions_to_reduce=dimensions_to_reduce,
      reducer=reducer_func)
def testVariadicReduceV2SingleArg(self):
  """Single-operand V2 reduce yields one output with the reduced shape."""

  @def_function.function
  def reducer_add(op_element, acc_val):
    return (op_element + acc_val,)

  dtype = np.float32
  scalar_spec = array_ops.zeros([], dtype)  # pylint: disable=cell-var-from-loop
  reducer_func = reducer_add.get_concrete_function(scalar_spec, scalar_spec)

  # Reduce a (3, 4, 5) operand along axis 1 -> expect a single (3, 5) output.
  res = xla.variadic_reduce(
      (array_ops.placeholder(np.float32, shape=(3, 4, 5)),),
      (array_ops.placeholder(np.float32, shape=()),),
      dimensions_to_reduce=(1,),
      reducer=reducer_func)
  self.assertLen(res, 1)
  self.assertEqual(res[0].shape, tensor_shape.TensorShape([3, 5]))