def reduce_logmeanexp(input_tensor,
                      axis=None,
                      keepdims=False,
                      experimental_named_axis=None,
                      experimental_allow_all_gather=False,
                      name=None):
  """Computes `log(mean(exp(input_tensor)))`.

  Reduces `input_tensor` along the dimensions given in `axis`. Unless
  `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
  `axis`. If `keepdims` is true, the reduced dimensions are retained with
  length 1.

  If `axis` has no entries, all dimensions are reduced, and a tensor with a
  single element is returned.

  This function is more numerically stable than `log(reduce_mean(exp(input)))`.
  It avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keepdims: Boolean. Whether to keep the axis as singleton dimensions.
      Default value: `False` (i.e., squeeze the reduced dimensions).
    experimental_named_axis: A `str` or list of `str` axis names to
      additionally reduce over. Providing `None` will not reduce over any
      axes.
    experimental_allow_all_gather: Allow using an `all_gather`-based fallback
      under TensorFlow when computing the distributed maximum. This fallback
      is only efficient when `axis` reduces away most of the dimensions of
      `input_tensor`.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'reduce_logmeanexp'`).

  Returns:
    log_mean_exp: The reduced tensor.
  """
  with tf.name_scope(name or 'reduce_logmeanexp'):
    named_axes = distribute_lib.canonicalize_named_axis(
        experimental_named_axis)
    # Compute a numerically stable log-sum-exp, then subtract log(n) to turn
    # the sum into a mean.
    lse = distribute_lib.reduce_logsumexp(
        input_tensor,
        axis=axis,
        keepdims=keepdims,
        named_axis=named_axes,
        allow_all_gather=experimental_allow_all_gather)
    # Number of elements reduced away locally, times the sizes of any named
    # (distributed) axes that were also reduced.
    n = ps.size(input_tensor) // ps.size(lse)
    for named_axis in named_axes:
      n = n * distribute_lib.get_axis_size(named_axis)
    log_n = tf.math.log(tf.cast(n, lse.dtype))
    return lse - log_n
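# Usage sketch (illustrative only, not part of this module): assuming the
# public TensorFlow Probability wrapper `tfp.math.reduce_logmeanexp`, the
# reduction behaves like a numerically stable `log(reduce_mean(exp(x)))`:
#
#   import tensorflow as tf
#   import tensorflow_probability as tfp
#
#   x = tf.math.log([[1., 2., 3.],
#                    [4., 5., 6.]])
#   tfp.math.reduce_logmeanexp(x)           # ==> log(3.5), mean of 1..6
#   tfp.math.reduce_logmeanexp(x, axis=-1)  # ==> [log(2.), log(5.)]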
def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              experimental_named_axis=None,
                              name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.log(w))` is more
  efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `logx` along the dimensions given in `axis`. Unless `keep_dims` is
  true, the rank of the tensor is reduced by 1 for each entry in `axis`. If
  `keep_dims` is true, the reduced dimensions are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a tensor with a
  single element is returned.

  This function is more numerically stable than `log(sum(w * exp(input)))`. It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])
  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(logx), rank(logx))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    experimental_named_axis: A `str` or list of `str` axis names to
      additionally reduce over. Providing `None` will not reduce over any
      axes.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with tf.name_scope(name or 'reduce_weighted_logsumexp'):
    logx = tf.convert_to_tensor(logx, name='logx')
    if w is None:
      # Unweighted case: defer to the plain (distributed) log-sum-exp.
      lswe = distribute_lib.reduce_logsumexp(
          logx,
          axis=axis,
          keepdims=keep_dims,
          named_axis=experimental_named_axis)
      if return_sign:
        sgn = tf.ones_like(lswe)
        return lswe, sgn
      return lswe
    w = tf.convert_to_tensor(w, dtype=logx.dtype, name='w')
    # Work with log|w| + logx and track the sign of w separately.
    log_absw_x = logx + tf.math.log(tf.abs(w))
    max_log_absw_x = distribute_lib.reduce_max(
        log_absw_x,
        axis=axis,
        keepdims=True,
        named_axis=experimental_named_axis)
    # If the largest element is `-inf` or `inf` then we don't bother
    # subtracting off the max. We do this because otherwise we'd get
    # `inf - inf = NaN`. That this is ok follows from the fact that we're
    # actually free to subtract any value we like, so long as we add it back
    # after taking the `log(sum(...))`.
    max_log_absw_x = tf.where(
        tf.math.is_inf(max_log_absw_x),
        tf.zeros([], max_log_absw_x.dtype),
        max_log_absw_x)
    wx_over_max_absw_x = (tf.sign(w) * tf.exp(log_absw_x - max_log_absw_x))
    sum_wx_over_max_absw_x = distribute_lib.reduce_sum(
        wx_over_max_absw_x,
        axis=axis,
        keepdims=keep_dims,
        named_axis=experimental_named_axis)
    if not keep_dims:
      max_log_absw_x = tf.squeeze(max_log_absw_x, axis)
    sgn = tf.sign(sum_wx_over_max_absw_x)
    lswe = max_log_absw_x + tf.math.log(sgn * sum_wx_over_max_absw_x)
    if return_sign:
      return lswe, sgn
    return lswe
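# Usage sketch for `return_sign` (illustrative only, not part of this module):
# assuming the public wrapper `tfp.math.reduce_weighted_logsumexp`, a negative
# total weight yields the magnitude in log-space plus a separate sign:
#
#   import tensorflow as tf
#   import tensorflow_probability as tfp
#
#   x = tf.constant([0., 0.])
#   w = tf.constant([-1., -1.])
#   lswe, sign = tfp.math.reduce_weighted_logsumexp(x, w, return_sign=True)
#   # sum(w * exp(x)) = -2, so lswe == log(2.) and sign == -1.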