示例#1
0
文件: jax2tf.py 项目: ekelsen/jax
def _reduce_window(jax_f,
                   reducer,
                   init_val,
                   operand,
                   window_dimensions,
                   window_strides,
                   padding,
                   input_shape=None):
    """Lower reduce_window_{sum,min,max} to a TensorFlow `tfxla.reduce_window`.

    Args:
      jax_f: JAX-side function handed to `_reduce_window_shape` to compute the
        static output shape.
      reducer: object exposing `get_concrete_function`; it is traced with two
        scalar arguments of `operand.dtype` (presumably a `tf.function` —
        confirm against callers).
      init_val: initial value of the reduction, materialized as a constant of
        `operand.dtype`.
      operand: the tensor or array being reduced.
      window_dimensions: per-axis window sizes.
      window_strides: per-axis window strides.
      padding: padding specification, converted to explicit per-axis pads via
        `lax.padtype_to_pads`.
      input_shape: ignored; accepted only so the signature matches callers.

    Returns:
      The `tfxla.reduce_window` result, with its static shape set to the shape
      computed by `_reduce_window_shape`.
    """
    del input_shape  # Unused.
    # TODO(tomhennigan): tf2xla should have a shape inference function.
    # The static shape is computed from the *original* padding spec, before
    # it is converted to explicit pads below.
    static_out_shape = _reduce_window_shape(jax_f, operand, window_dimensions,
                                            window_strides, padding)
    operand_shape = _get_shape_from_tensor_or_array(operand)
    explicit_pads = lax.padtype_to_pads(operand_shape, window_dimensions,
                                        window_strides, padding)
    dtype = operand.dtype
    # A scalar constant whose only job is to fix the reducer's traced
    # input signature (two scalars of the operand's dtype).
    scalar = tf.constant(0, dtype)
    concrete_reducer = reducer.get_concrete_function(scalar, scalar)
    result = tfxla.reduce_window(
        operand,
        tf.constant(init_val, dtype),
        concrete_reducer,
        window_dimensions,
        window_strides,
        padding=explicit_pads)
    result.set_shape(static_out_shape)
    return result
示例#2
0
 def _reduce_window(self, operand, init, reducer, **kwargs):
     """Build xla.reduce_window over a fed placeholder and evaluate it.

     `operand` is not embedded in the graph directly; instead a placeholder of
     the same dtype is created and `operand` is supplied through `feed_dict`
     at evaluation time.

     Args:
       operand: value to reduce; must expose `.dtype` (presumably a NumPy
         array — confirm against callers).
       init: initial value of the reduction, forwarded unchanged.
       reducer: reduction computation, forwarded unchanged to
         `xla.reduce_window`.
       **kwargs: any further keyword arguments for `xla.reduce_window`
         (presumably window dimensions, strides, padding — confirm).

     Returns:
       The evaluated output of `xla.reduce_window` with `operand` fed in.
     """
     # NOTE(review): uses `test_session()`; whether this resolves to a
     # base-class override (vs. the deprecated TF default) depends on the
     # enclosing test class, which is not visible here.
     with self.test_session():
         operand_input = array_ops.placeholder(operand.dtype)
         # Only the reduce_window op is built inside `test_scope()`.
         with self.test_scope():
             reduced = xla.reduce_window(
                 operand_input, init, reducer, **kwargs)
         feed = {operand_input: operand}
         return reduced.eval(feed_dict=feed)
示例#3
0
 def _reduce_window(self, operand, init, reducer, **kwargs):
   """Evaluate xla.reduce_window on `operand`, fed through a placeholder.

   Args:
     operand: value to reduce; must expose `.dtype` (presumably a NumPy
       array — confirm against callers).
     init: initial value of the reduction, forwarded unchanged.
     reducer: reduction computation, forwarded unchanged to
       `xla.reduce_window`.
     **kwargs: remaining keyword arguments for `xla.reduce_window`
       (presumably window dimensions, strides, padding — confirm).

   Returns:
     The result of evaluating the reduce_window output with `operand` bound
     to the placeholder via `feed_dict`.
   """
   with self.cached_session():
     # The graph input has the operand's dtype; the concrete value is only
     # supplied at eval time.
     operand_input = array_ops.placeholder(operand.dtype)
     # Only the reduce_window op itself is created inside `test_scope()`.
     with self.test_scope():
       reduced = xla.reduce_window(
           operand_input, init, reducer, **kwargs)
     feed = {operand_input: operand}
     return reduced.eval(feed_dict=feed)