예제 #1
0
 def reduce(*values, dimensions_to_reduce):
     """Reduce all positional operands together with `xla.variadic_reduce_v2`.

     `init_val_1`, `init_val_2` and `reducer_func` are not defined here; they
     are looked up in the surrounding scope when this function is called.

     Args:
       *values: Operands; forwarded to the op as a single tuple.
       dimensions_to_reduce: Keyword-only; passed through unchanged.

     Returns:
       The value returned by `xla.variadic_reduce_v2`.
     """
     # One init value per operand (presumably two operands -- confirm at caller).
     init_values = (init_val_1, init_val_2)  # pylint: disable=cell-var-from-loop
     return xla.variadic_reduce_v2(
         values,
         init_values,
         dimensions_to_reduce=dimensions_to_reduce,
         reducer=reducer_func)  # pylint: disable=cell-var-from-loop
예제 #2
0
        def reduce_with_shapes(shape1,
                               shape2,
                               shape3,
                               dimensions_to_reduce=(1, )):
            """Build a three-operand `variadic_reduce_v2` over placeholders.

            The operands are placeholders of dtype float32, int32 and int32
            with the given shapes; the init values are scalar placeholders of
            the same dtypes. `reducer_func` is taken from the surrounding
            scope.

            Args:
              shape1: Shape of the float32 operand.
              shape2: Shape of the first int32 operand.
              shape3: Shape of the second int32 operand.
              dimensions_to_reduce: Forwarded to the op; defaults to `(1, )`.

            Returns:
              The value returned by `xla.variadic_reduce_v2`.
            """
            operand_dtypes = (np.float32, np.int32, np.int32)
            operand_shapes = (shape1, shape2, shape3)

            # Operands are created first, then the init values, each group in
            # dtype order.
            operands = tuple(
                array_ops.placeholder(operand_dtype, shape=operand_shape)
                for operand_dtype, operand_shape in zip(operand_dtypes,
                                                        operand_shapes))
            scalar_inits = tuple(
                array_ops.placeholder(operand_dtype, shape=())
                for operand_dtype in operand_dtypes)

            return xla.variadic_reduce_v2(
                operands,
                scalar_inits,
                dimensions_to_reduce=dimensions_to_reduce,
                reducer=reducer_func)
예제 #3
0
    def testVariadicReduceV2SingleArg(self):
        """Checks result count and static shape for a one-operand reduce."""
        dtype = np.float32

        @def_function.function
        def reducer_add(op_element, acc_val):
            # The reducer hands back a 1-tuple for the single operand.
            return (op_element + acc_val, )

        scalar_spec = array_ops.zeros([], dtype)
        reducer_func = reducer_add.get_concrete_function(
            scalar_spec, scalar_spec)

        operand = array_ops.placeholder(dtype, shape=(3, 4, 5))
        init_value = array_ops.placeholder(dtype, shape=())
        res = xla.variadic_reduce_v2(
            (operand, ),
            (init_value, ),
            dimensions_to_reduce=(1, ),
            reducer=reducer_func)

        # Reducing dimension 1 of the (3, 4, 5) operand is expected to leave
        # exactly one result of static shape (3, 5).
        self.assertLen(res, 1)
        self.assertEqual(res[0].shape, tensor_shape.TensorShape([3, 5]))
예제 #4
0
                def fn(x):
                    """Reduce `x` paired with a zeros tensor; return one output.

                    `dtype`, `kahan_sum_reducer`, `use_v2`, `dims` and
                    `output_idx` are read from the surrounding scope.

                    Args:
                      x: First operand; the second is `zeros_like(x)`.

                    Returns:
                      Element `output_idx` of the reduce op's result.
                    """
                    zero = array_ops.zeros([], dtype)  # pylint: disable=cell-var-from-loop
                    init_pair = (zero, zero)
                    reducer = kahan_sum_reducer.get_concrete_function(
                        init_pair, init_pair)
                    operands = (x, array_ops.zeros_like(x))

                    # The two APIs spell the init keyword differently:
                    # `init_values` for v2, `init_value` for the original op.
                    if use_v2:
                        outputs = xla.variadic_reduce_v2(
                            operands,
                            init_values=init_pair,
                            dimensions_to_reduce=dims,
                            reducer=reducer)
                    else:
                        outputs = xla.variadic_reduce(
                            operands,
                            init_value=init_pair,
                            dimensions_to_reduce=dims,
                            reducer=reducer)
                    return outputs[output_idx]
예제 #5
0
 def reduce(values, *, dimensions_to_reduce):
     """Reduce a single operand with `xla.variadic_reduce_v2`.

     `init_val` and `reducer_func` are not defined here; they are looked up
     in the surrounding scope when this function is called.

     Args:
       values: The one operand; wrapped in a 1-tuple for the variadic op.
       dimensions_to_reduce: Keyword-only; passed through unchanged.

     Returns:
       The first element of the op's result.
     """
     operands = (values, )
     init_values = (init_val, )  # pylint: disable=cell-var-from-loop
     results = xla.variadic_reduce_v2(
         operands,
         init_values,
         dimensions_to_reduce=dimensions_to_reduce,
         reducer=reducer_func)  # pylint: disable=cell-var-from-loop
     return results[0]