def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None): weights = constant_op.constant(1, dtype=x.dtype) if extra_out_grads is not None: # We want to assert gradients WRT weights as well as X! extra_out_grads.append(weights) return nn_impl.weighted_moments(x, axes, weights, keep_dims=keep_dims)
def RunWeightedMomentTest(self, shape, weights_shape, axes, keep_dims, dtype, dynshapes=False): with self.cached_session() as s: x_numpy = np.random.normal(size=shape).astype(np.float32) weights_numpy = np.absolute( # weights must be positive np.random.normal(size=weights_shape, loc=1.0).astype(np.float32)) # Expand the numpy version to higher precision x_numpy = x_numpy.astype(np.float128) weights_numpy = weights_numpy.astype(np.float128) x_shape = [None] * len(shape) if dynshapes else shape weights_shape = ([None] * len(weights_shape) if dynshapes else weights_shape) x = array_ops.placeholder(dtype, shape=x_shape) weights = array_ops.placeholder(dtype, shape=weights_shape) mean, var = nn_impl.weighted_moments(x, axes, weights, keep_dims=keep_dims) ax = tuple(axes) def _np_weighted_sum(v): return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims) weight_sum = _np_weighted_sum(np.ones_like(x_numpy)) expected_mean = _np_weighted_sum(x_numpy) / weight_sum expected_mean_squared = np.multiply(expected_mean, expected_mean) expected_x_squared = ( _np_weighted_sum(np.multiply(x_numpy, x_numpy)) / weight_sum) expected_variance = expected_x_squared - expected_mean_squared mean_v, var_v = s.run([mean, var], feed_dict={ x: x_numpy, weights: weights_numpy }) self.assertAllCloseAccordingToType(expected_mean, mean_v) self.assertAllCloseAccordingToType(expected_variance, var_v)
def RunWeightedMomentTest(self, shape, weights_shape, axes, keep_dims, dtype, dynshapes=False): with self.cached_session() as s: x_numpy = np.random.normal(size=shape).astype(np.float32) weights_numpy = np.absolute( # weights must be positive np.random.normal( size=weights_shape, loc=1.0).astype(np.float32)) # Expand the numpy version to higher precision x_numpy = x_numpy.astype(np.float128) weights_numpy = weights_numpy.astype(np.float128) x_shape = [None] * len(shape) if dynshapes else shape weights_shape = ([None] * len(weights_shape) if dynshapes else weights_shape) x = array_ops.placeholder(dtype, shape=x_shape) weights = array_ops.placeholder(dtype, shape=weights_shape) mean, var = nn_impl.weighted_moments( x, axes, weights, keep_dims=keep_dims) ax = tuple(axes) def _np_weighted_sum(v): return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims) weight_sum = _np_weighted_sum(np.ones_like(x_numpy)) expected_mean = _np_weighted_sum(x_numpy) / weight_sum expected_mean_squared = np.multiply(expected_mean, expected_mean) expected_x_squared = (_np_weighted_sum(np.multiply(x_numpy, x_numpy)) / weight_sum) expected_variance = expected_x_squared - expected_mean_squared mean_v, var_v = s.run([mean, var], feed_dict={x: x_numpy, weights: weights_numpy}) self.assertAllCloseAccordingToType(expected_mean, mean_v) self.assertAllCloseAccordingToType(expected_variance, var_v)