Example #1
0
 def reduce(values, *, dimensions_to_reduce):
     """Reduce `values` over `dimensions_to_reduce` with an XLA variadic reduce.

     The reduction runs as a one-operand variadic reduce. `values` and the
     captured `init_val` are each wrapped in a 1-tuple, and element 0 of the
     op's outputs is returned. `is_v2` (from the enclosing scope) chooses
     between the `xla.variadic_reduce` wrapper and the raw
     `gen_xla_ops.xla_variadic_reduce` op. Both are called with identical
     positional operands and keyword attributes.

     `init_val` and `reducer_func` are free variables of this closure. The
     pylint suppressions indicate they are bound by an enclosing loop.

     Args:
       values: The single operand to reduce.
       dimensions_to_reduce: Forwarded unchanged to the op.

     Returns:
       Element 0 of the op's outputs.
     """
     # Pick the op flavour up front; the call itself is the same for both.
     variadic_reduce_op = (
         xla.variadic_reduce if is_v2 else gen_xla_ops.xla_variadic_reduce)
     outputs = variadic_reduce_op(
         (values,),
         (init_val,),  # pylint: disable=cell-var-from-loop
         dimensions_to_reduce=dimensions_to_reduce,
         reducer=reducer_func)  # pylint: disable=cell-var-from-loop
     return outputs[0]
Example #2
0
                def fn(x):
                    """Return output `output_idx` of a two-operand variadic reduce of `x`.

                    The operands are `(x, zeros_like(x))`. The initial values
                    are two scalar zeros of `dtype`. The reducer is the
                    concrete function of `kahan_sum_reducer` traced for
                    `((zero, zero), (zero, zero))`. `is_v2` selects the
                    `xla.variadic_reduce` wrapper over the raw
                    `gen_xla_ops.xla_variadic_reduce` op.

                    `dtype`, `kahan_sum_reducer`, `is_v2`, `dims` and
                    `output_idx` are free variables from the enclosing scope.
                    """
                    # Scalar zero of the captured dtype. It serves both as the
                    # tracing signature for the reducer and as the initial value
                    # for each operand (presumably running sum and Kahan
                    # compensation; confirm against kahan_sum_reducer).
                    zero = array_ops.zeros([], dtype)  # pylint: disable=cell-var-from-loop
                    kahan_reducer = kahan_sum_reducer.get_concrete_function(
                        (zero, zero), (zero, zero))

                    reduce_op = (xla.variadic_reduce if is_v2
                                 else gen_xla_ops.xla_variadic_reduce)
                    # The v2 op calls its init input `init_values` and the v1
                    # op calls it `init_value`. Passing it positionally lets one
                    # call serve both.
                    results = reduce_op(
                        (x, array_ops.zeros_like(x)),
                        (zero, zero),
                        dimensions_to_reduce=dims,
                        reducer=kahan_reducer)
                    return results[output_idx]