예제 #1
0
def average(a, axis=None, weights=None, returned=False):
    """Compute the (optionally weighted) average of `a` along `axis`.

    Inputs whose dtype is not inexact (e.g. integer or bool dtypes) are
    promoted together with `dtypes.default_float_type()` before averaging.

    Args:
      a: Array-like input; converted with `array_ops.array`.
      axis: `None` to average over all elements, or a single integer axis.
        Any integer-like value accepted by `operator.index` (for example a
        NumPy integer scalar) is normalized to a Python int. Tuples of ints
        are not supported yet.
      weights: Optional array-like weights; `None` treats every weight as 1.
        If `axis` is `None`, `weights` must have the same shape as `a`. If
        `axis` is given, `weights` must either have the same shape as `a` or
        be 1-D, in which case it is contracted against dimension `axis` of
        `a`. These conditions are checked with `tf.debugging.Assert`.
      returned: If True, also return the sum of the weights, broadcast to the
        shape of the average.

    Returns:
      The average (an `array_ops.array`), or the tuple
      `(average, weights_sum)` when `returned` is True.

    Raises:
      ValueError: If `axis` is neither `None` nor integer-like.
    """
    import operator  # Only used to normalize integer-like `axis` values.

    if axis is not None:
        try:
            # Plain Python ints pass through unchanged; other integer-like
            # values (e.g. NumPy integer scalars) become plain ints so the TF
            # calls below see the same kind of value as for a Python int.
            axis = operator.index(axis)
        except TypeError:
            # TODO(wangpeng): Support tuple of ints as `axis`
            raise ValueError('`axis` must be an integer. Tuple of ints is not '
                             'supported yet. Got type: %s' % type(axis))
    a = array_ops.array(a)
    if weights is None:  # Treat all weights as 1
        # Averaging non-inexact (integer/bool) data should not stay in that
        # dtype, so promote it with the default float type first.
        if not np.issubdtype(a.dtype, np.inexact):
            a = a.astype(
                utils.result_type(a.dtype, dtypes.default_float_type()))
        avg = tf.reduce_mean(a.data, axis=axis)
        if returned:
            # With unit weights the weight sum is the number of averaged
            # elements: all of them, or the length of dimension `axis`.
            if axis is None:
                weights_sum = tf.size(a.data)
            else:
                weights_sum = tf.shape(a.data)[axis]
            weights_sum = tf.cast(weights_sum, a.data.dtype)
    else:
        # Common dtype for `a` and `weights`; non-inexact `a` additionally
        # folds in the default float type, mirroring the unweighted branch.
        if np.issubdtype(a.dtype, np.inexact):
            out_dtype = utils.result_type(a.dtype, weights)
        else:
            out_dtype = utils.result_type(a.dtype, weights,
                                          dtypes.default_float_type())
        # From here on `a` and `weights` are the raw TF tensors (`.data`).
        a = array_ops.array(a, out_dtype).data
        weights = array_ops.array(weights, out_dtype).data

        def rank_equal_case():
            # Elementwise weighting: `weights` must have exactly `a`'s shape.
            tf.debugging.Assert(
                tf.reduce_all(tf.shape(a) == tf.shape(weights)),
                [tf.shape(a), tf.shape(weights)])
            weights_sum = tf.reduce_sum(weights, axis=axis)
            avg = tf.reduce_sum(a * weights, axis=axis) / weights_sum
            return avg, weights_sum

        if axis is None:
            avg, weights_sum = rank_equal_case()
        else:

            def rank_not_equal_case():
                # 1-D weights are applied along `axis` via a tensordot that
                # contracts dimension `axis` of `a` with the weight vector.
                tf.debugging.Assert(tf.rank(weights) == 1, [tf.rank(weights)])
                weights_sum = tf.reduce_sum(weights)
                axes = tf.convert_to_tensor([[axis], [0]])
                avg = tf.tensordot(a, weights, axes) / weights_sum
                return avg, weights_sum

            # We condition on rank rather than shape equality, because if we do the
            # latter, when the shapes are partially unknown but the ranks are known
            # and different, utils.cond will run shape checking on the true branch,
            # which will raise a shape-checking error.
            avg, weights_sum = utils.cond(
                tf.rank(a) == tf.rank(weights), rank_equal_case,
                rank_not_equal_case)

    avg = array_ops.array(avg)
    if returned:
        # Report the weight sum with the same shape as the average.
        weights_sum = array_ops.broadcast_to(weights_sum, tf.shape(avg.data))
        return avg, weights_sum
    return avg
예제 #2
0
 def run_test(arr, shape):
     """Check `array_ops.broadcast_to` against `np.broadcast_to` for `arr`.

     For each callable in `self.array_transforms`, the transformed `arr` is
     broadcast to `shape` by both implementations, and the two results are
     handed to `self.match`. `self` is not a parameter here; it presumably
     comes from an enclosing test method (not visible in this chunk).
     """
     for transform in self.array_transforms:
         operand = transform(arr)
         got = array_ops.broadcast_to(operand, shape)
         want = np.broadcast_to(operand, shape)
         self.match(got, want)
예제 #3
0
 def replicate(x, num_devices=2):
   """Replicate `x` `num_devices` times along a new leading axis.

   Done by broadcasting `x` to the shape `(num_devices,) + x.shape`.
   """
   leading_dims = (num_devices,)
   return array_ops.broadcast_to(x, leading_dims + x.shape)