def fn(x):
    """Reduces `x` over `dims` with `kahan_sum_reducer` via XLA variadic reduce.

    `x` is paired with a zero tensor of the same shape as the second operand,
    and both initial values are scalar zeros of `dtype`. Only output number
    `output_idx` of the variadic reduce is returned. Judging by the reducer's
    name, the two outputs are presumably (sum, compensation) — confirm.
    """
    # `dtype` is a loop variable in the enclosing scope (hence the suppression).
    zero = array_ops.zeros([], dtype)  # pylint: disable=cell-var-from-loop
    # Trace the reducer with (pair, pair) scalar signatures.
    traced_reducer = kahan_sum_reducer.get_concrete_function(
        (zero, zero), (zero, zero))

    operands = (x, array_ops.zeros_like(x))
    # NOTE(review): this snippet passes `init_value=`; other call sites in the
    # same page use `init_values=` — the keyword depends on the TF version.
    outputs = xla.variadic_reduce(operands,
                                  init_value=(zero, zero),
                                  dimensions_to_reduce=dims,
                                  reducer=traced_reducer)
    return outputs[output_idx]
 def reduce(*values, dimensions_to_reduce):
     """Runs `xla.variadic_reduce` over `values` with `reducer_func`.

     The initial values `(init_val_1, init_val_2)` and `reducer_func` come
     from the enclosing scope; the cell-var-from-loop suppressions indicate
     they are loop variables there.
     """
     initial_values = (
         init_val_1,  # pylint: disable=cell-var-from-loop
         init_val_2,  # pylint: disable=cell-var-from-loop
     )
     return xla.variadic_reduce(
         values,
         initial_values,
         dimensions_to_reduce=dimensions_to_reduce,
         reducer=reducer_func)  # pylint: disable=cell-var-from-loop
 def reduce(values, *, dimensions_to_reduce):
     """Single-operand variadic reduce; returns the sole output tensor.

     Uses `xla.variadic_reduce` when `is_v2` is true, otherwise the raw
     `gen_xla_ops.xla_variadic_reduce` op. `init_val` and `reducer_func` are
     loop variables captured from the enclosing scope.
     """
     # Only the selected callee is looked up, exactly as in a branch.
     reduce_op = xla.variadic_reduce if is_v2 else gen_xla_ops.xla_variadic_reduce
     outputs = reduce_op(
         (values, ),
         (init_val, ),  # pylint: disable=cell-var-from-loop
         dimensions_to_reduce=dimensions_to_reduce,
         reducer=reducer_func)  # pylint: disable=cell-var-from-loop
     return outputs[0]
Example #4
0
 def _xla_reduce(operands, inits, axis):
     """JIT-ed wrapper for TF `xla.variadic_reduce(..., reducer)`.

     Args:
       operands: Sequence of tensors reduced together.
       inits: Initial values, one per operand. They are also the argument
         signature used to trace `reducer` (captured from the enclosing scope).
       axis: Dimensions to reduce, passed as `dimensions_to_reduce`.

     Returns:
       The outputs of `xla.variadic_reduce`. When the merged operand shape has
       known rank, each output's static shape is set to that shape with the
       `axis` dimensions removed.
     """
     from tensorflow.compiler.tf2xla.python import xla  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
     result = xla.variadic_reduce(
         operands,
         init_values=inits,
         dimensions_to_reduce=axis,
         reducer=tf.function(reducer).get_concrete_function(inits, inits))
     # Graph mode: variadic reduce doesn't specify output shapes. Patch that.
     shp = operands[0].shape
     for arg in operands:
         shp = tensorshape_util.merge_with(shp, arg.shape)
     if tensorshape_util.rank(shp) is None:
         # Unknown rank: nothing static to patch, and iterating the shape
         # would raise instead of returning the (valid) reduce result.
         return result
     # Same output shape for every part; compute it once.
     out_shape = tuple(dim for i, dim in enumerate(shp) if i not in axis)
     for part in result:
         tensorshape_util.set_shape(part, out_shape)
     return result
        def reduce_with_shapes(shape1,
                               shape2,
                               shape3,
                               dimensions_to_reduce=(1, )):
            """Builds a 3-operand variadic reduce over fresh placeholders.

            Operands are float32/int32/int32 placeholders with the given
            shapes; each initial value is a scalar placeholder of the matching
            dtype. Returns the (unevaluated) `xla.variadic_reduce` outputs.
            """
            operand_dtypes = (np.float32, np.int32, np.int32)
            # Operand placeholders are created first, then the init values,
            # in the same order as listed.
            inputs = tuple(
                array_ops.placeholder(dt, shape=shp)
                for dt, shp in zip(operand_dtypes, (shape1, shape2, shape3)))
            init_values = tuple(
                array_ops.placeholder(dt, shape=()) for dt in operand_dtypes)

            return xla.variadic_reduce(
                inputs,
                init_values,
                dimensions_to_reduce=dimensions_to_reduce,
                reducer=reducer_func)
    def testVariadicReduceV2SingleArg(self):
        """Single-operand variadic reduce yields one output of reduced shape.

        Reducing a (3, 4, 5) operand over dimension 1 must produce exactly one
        output whose static shape is [3, 5].
        """

        @def_function.function
        def reducer_add(op_element, acc_val):
            return (op_element + acc_val, )

        dtype = np.float32
        # Scalar tensor used only as the tracing signature for the reducer.
        arg_spec = array_ops.zeros([], dtype)
        reducer_func = reducer_add.get_concrete_function(arg_spec, arg_spec)

        # Use `dtype` everywhere so operand/init dtypes track the reducer's.
        res = xla.variadic_reduce(
            (array_ops.placeholder(dtype, shape=(3, 4, 5)), ),
            (array_ops.placeholder(dtype, shape=()), ),
            dimensions_to_reduce=(1, ),
            reducer=reducer_func)
        self.assertLen(res, 1)
        self.assertEqual(res[0].shape, tensor_shape.TensorShape([3, 5]))