def test_fn(operand, source):
  """Apply XLA SelectAndScatter to `operand`, scattering `source`.

  Uses a window of size 2x3 with stride 2x2 over the first two dimensions
  and size/stride 1 over the last two, with zero padding everywhere. The
  window lists have four entries, so `operand` is presumably rank 4 (TODO
  confirm against callers). `ge_select` and `add_scatter` are defined
  elsewhere; their names suggest a >= selection and an additive scatter.
  """
  window = [2, 3, 1, 1]
  strides = [2, 2, 1, 1]
  no_padding = [[0, 0]] * 4  # (low, high) padding pair for each dimension
  return xla.select_and_scatter(
      operand,
      window_dimensions=window,
      window_strides=strides,
      padding=no_padding,
      source=source,
      init_value=0,
      select=ge_select,
      scatter=add_scatter)
def test_fn(operand, source):
  """Scatter `source` into `operand` via XLA SelectAndScatter.

  All settings except the operand and source are fixed: a 2x3 window with
  2x2 stride on the leading two dimensions, unit window/stride on the
  trailing two, no padding, and an initial value of 0. The selection and
  scatter computations are `ge_select` and `add_scatter`, which are defined
  outside this function.
  """
  # Only the first two dimensions are actually windowed; the rest use 1.
  fixed_options = dict(
      window_dimensions=[2, 3, 1, 1],
      window_strides=[2, 2, 1, 1],
      padding=[[0, 0]] * 4,
      init_value=0,
      select=ge_select,
      scatter=add_scatter,
  )
  return xla.select_and_scatter(operand, source=source, **fixed_options)
def _select_and_scatter_add(operand, source, init_value, select_jaxpr,
                            select_consts, scatter_jaxpr, scatter_consts,
                            window_dimensions, window_strides, padding):
  """Lower a select-and-scatter to `tfxla.select_and_scatter`.

  Known limitation: the provided select/scatter jaxprs and their constants
  are discarded. Fixed TF functions `_ge_fn` and `_add_fn` (defined
  elsewhere; presumably a >= comparison and an addition, per their names)
  are used instead.

  Args:
    operand: array the windows are taken over.
    source: values scattered back into the selected positions.
    init_value: initial value of the output, passed through unchanged.
    select_jaxpr, select_consts, scatter_jaxpr, scatter_consts: ignored.
    window_dimensions, window_strides, padding: passed through unchanged to
      `tfxla.select_and_scatter`.

  Returns:
    The result of `tfxla.select_and_scatter`.
  """
  del select_jaxpr, select_consts, scatter_jaxpr, scatter_consts
  # TODO(phawkins): handle the select and scatter jaxprs correctly.
  # A scalar of the operand's dtype, used only so the concrete select and
  # scatter functions are traced with a matching (dtype, dtype) signature.
  a = tf.constant(0, operand.dtype)
  select_fn = _ge_fn.get_concrete_function(a, a)
  scatter_fn = _add_fn.get_concrete_function(a, a)
  return tfxla.select_and_scatter(operand, window_dimensions, window_strides,
                                  padding, source, init_value, select_fn,
                                  scatter_fn)