Exemplo n.º 1
0
def average(a, axis=None, weights=None, returned=False):  # pylint: disable=missing-docstring
  # TensorFlow counterpart of NumPy's `average`: the (optionally weighted)
  # mean of `a`, either over all elements (`axis=None`) or along one axis.
  #
  # Args:
  #   a: array-like input; converted with `np_array_ops.array`.
  #   axis: None (reduce over every element) or a single integer axis.
  #     Tuples of axes are not supported yet.
  #   weights: None (every element has weight 1) or array-like weights. With
  #     `axis=None` they must have exactly the shape of `a`; with an integer
  #     `axis` they must either have the shape of `a` or be 1-D, in which case
  #     they are contracted against `a` along `axis`.
  #   returned: if True, also return the sum of the weights, broadcast to the
  #     shape of the average.
  #
  # Returns:
  #   The average as an ndarray (via `np_array_ops.array`), or the tuple
  #   `(avg, weights_sum)` when `returned` is True.
  #
  # Raises:
  #   ValueError: if `axis` is neither None nor an integer.
  #   Shape problems with `weights` surface as a failing
  #   `control_flow_ops.Assert`.
  if axis is not None and not isinstance(axis, six.integer_types):
    # TODO(wangpeng): Support tuple of ints as `axis`
    raise ValueError('`axis` must be an integer. Tuple of ints is not '
                     'supported yet. Got type: %s' % type(axis))
  a = np_array_ops.array(a)
  if weights is None:  # Treat all weights as 1
    # Promote non-inexact (integer/bool) inputs to a floating dtype before
    # taking the mean, so the result is floating point.
    if not np.issubdtype(a.dtype, np.inexact):
      a = a.astype(
          np_utils.result_type(a.dtype, np_dtypes.default_float_type()))
    avg = math_ops.reduce_mean(a.data, axis=axis)
    if returned:
      # With unit weights, the weight sum is the number of elements that
      # were reduced over.
      if axis is None:
        weights_sum = array_ops.size(a.data)
      else:
        weights_sum = array_ops.shape(a.data)[axis]
      weights_sum = math_ops.cast(weights_sum, a.data.dtype)
  else:
    # Common result dtype for `a` and `weights`. A non-inexact `a` also
    # folds in the default float type so the result is floating point.
    if np.issubdtype(a.dtype, np.inexact):
      out_dtype = np_utils.result_type(a.dtype, weights)
    else:
      out_dtype = np_utils.result_type(a.dtype, weights,
                                       np_dtypes.default_float_type())
    # From here on `a` and `weights` are raw tensors (`.data`), not ndarrays.
    a = np_array_ops.array(a, out_dtype).data
    weights = np_array_ops.array(weights, out_dtype).data

    def rank_equal_case():
      # `weights` must have exactly the shape of `a`. The elementwise `==` on
      # the two shape vectors broadcasts, so shape vectors of different
      # lengths (e.g. [3, 3] vs. [3]) can compare all-True. Also require
      # equal ranks, so such mismatches are rejected instead of silently
      # broadcasting `a * weights` into a wrong result.
      shapes_match = math_ops.logical_and(
          math_ops.equal(array_ops.rank(a), array_ops.rank(weights)),
          math_ops.reduce_all(array_ops.shape(a) == array_ops.shape(weights)))
      control_flow_ops.Assert(
          shapes_match, [array_ops.shape(a), array_ops.shape(weights)])
      weights_sum = math_ops.reduce_sum(weights, axis=axis)
      avg = math_ops.reduce_sum(a * weights, axis=axis) / weights_sum
      return avg, weights_sum

    if axis is None:
      avg, weights_sum = rank_equal_case()
    else:

      def rank_not_equal_case():
        # 1-D weights are contracted against `a` along `axis`.
        control_flow_ops.Assert(
            array_ops.rank(weights) == 1, [array_ops.rank(weights)])
        weights_sum = math_ops.reduce_sum(weights)
        axes = ops.convert_to_tensor([[axis], [0]])
        avg = math_ops.tensordot(a, weights, axes) / weights_sum
        return avg, weights_sum

      # We condition on rank rather than shape equality, because if we do the
      # latter, when the shapes are partially unknown but the ranks are known
      # and different, np_utils.cond will run shape checking on the true branch,
      # which will raise a shape-checking error.
      avg, weights_sum = np_utils.cond(
          math_ops.equal(array_ops.rank(a), array_ops.rank(weights)),
          rank_equal_case, rank_not_equal_case)

  avg = np_array_ops.array(avg)
  if returned:
    # Report the weight sum with the same shape as the average.
    weights_sum = np_array_ops.broadcast_to(weights_sum,
                                            array_ops.shape(avg.data))
    return avg, weights_sum
  return avg
Exemplo n.º 2
0
 def run_test(arr, shape):
   """Compares np_array_ops.broadcast_to with np.broadcast_to.

   Every callable in `self.array_transforms` is applied to `arr`. The
   transformed value is broadcast to `shape` by both implementations, and
   the two results are handed to `self.match`.
   """
   for transform in self.array_transforms:
     transformed = transform(arr)
     actual = np_array_ops.broadcast_to(transformed, shape)
     expected = np.broadcast_to(transformed, shape)
     self.match(actual, expected)