def doOutputTest(self, input_shape, moments_axes, tol=1e-4, check_gradients=False):
  """Compares `nn_impl.moments_v2` against a numpy mean/var reference.

  Sweeps a grid of offsets (mu), scales (sigma) and `keepdims` settings,
  asserting the op's outputs contain no NaNs and are close to numpy's
  `mean`/`var` within `tol`. When `check_gradients` is set, also runs a
  numeric gradient check on both outputs.

  Args:
    input_shape: Shape of the random input tensor to generate.
    moments_axes: Axes to reduce over when computing the moments.
    tol: Relative and absolute tolerance for the closeness assertions.
    check_gradients: If True, additionally verify gradients numerically.
  """
  configs = [(mu, sigma, keep_dims)
             for mu in (0.0, 1.0, 1e3)
             for sigma in (1.0, 0.1)
             for keep_dims in (True, False)]
  for mu, sigma, keep_dims in configs:
    sample = np.random.rand(*input_shape) * sigma + mu
    ref_mean = np.mean(sample, axis=moments_axes, keepdims=keep_dims)
    ref_var = np.var(sample, axis=moments_axes, keepdims=keep_dims)
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        inputs = constant_op.constant(
            sample, shape=input_shape, dtype=dtypes.float32)
        mean, variance = nn_impl.moments_v2(
            inputs, moments_axes, keepdims=keep_dims)
        if check_gradients:
          # Check mean first, then variance, mirroring the output order.
          for output in (mean, variance):
            err = gradient_checker.compute_gradient_error(
                inputs, input_shape, output, output.shape.as_list())
            self.assertLess(err, 1e-3)
        # Evaluate the op outputs.
        mean, variance = self.evaluate([mean, variance])
        # NaNs would make the closeness checks below meaningless.
        self.assertFalse(np.isnan(mean).any())
        self.assertFalse(np.isnan(variance).any())
        self.assertAllClose(mean, ref_mean, rtol=tol, atol=tol)
        self.assertAllClose(variance, ref_var, rtol=tol, atol=tol)
def update_state(self, data):
  """Folds one batch of `data` into the running mean/variance/count.

  Uses a weighted combination of the stored moments and the batch moments;
  the variance update follows the lack-of-fit sum of squares formula
  (https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).

  Args:
    data: Input batch to accumulate statistics from.

  Raises:
    RuntimeError: If the layer has not been built yet.
  """
  if not self.built:
    raise RuntimeError('`build` must be called before `update_state`.')
  data = self._standardize_inputs(data)
  batch_mean, batch_variance = nn_impl.moments_v2(
      data, axes=self._reduce_axis)
  # Number of elements reduced over in this batch.
  dims = array_ops.shape(data, out_type=self.count.dtype)
  batch_count = math_ops.reduce_prod(
      array_ops.gather(dims, self._reduce_axis))
  updated_count = batch_count + self.count
  # Weight of the incoming batch relative to everything seen so far.
  fresh_weight = (math_ops.cast(batch_count, dtype=self.dtype) /
                  math_ops.cast(updated_count, dtype=self.dtype))
  prior_weight = 1. - fresh_weight
  updated_mean = self.mean * prior_weight + batch_mean * fresh_weight
  # Lack-of-fit sum of squares: each component's variance is shifted by the
  # squared distance of its mean from the combined mean before weighting.
  updated_variance = (
      (self.variance + (self.mean - updated_mean)**2) * prior_weight +
      (batch_variance + (batch_mean - updated_mean)**2) * fresh_weight)
  self.mean.assign(updated_mean)
  self.variance.assign(updated_variance)
  self.count.assign(updated_count)
def update_state(self, data):
  """Accumulates `data` statistics into the adapt-time mean and variance.

  Combines the stored moments with the batch moments using counts as
  weights; the variance update follows the lack-of-fit sum of squares
  formula (https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).

  Args:
    data: Input batch to accumulate statistics from.

  Raises:
    ValueError: If the layer was constructed with static `mean`/`variance`.
    RuntimeError: If the layer has not been built yet.
  """
  if self.input_mean is not None:
    raise ValueError(
        'Cannot `adapt` a Normalization layer that is initialized with '
        'static `mean` and `variance`, you passed mean {} and variance {}.'
        .format(self.input_mean, self.input_variance))
  if not self.built:
    raise RuntimeError('`build` must be called before `update_state`.')
  data = self._standardize_inputs(data)
  data = math_ops.cast(data, self.adapt_mean.dtype)
  batch_mean, batch_variance = nn_impl.moments_v2(
      data, axes=self._reduce_axis)
  # Number of elements reduced over in this batch.
  dims = array_ops.shape(data, out_type=self.count.dtype)
  batch_count = math_ops.reduce_prod(
      array_ops.gather(dims, self._reduce_axis))
  updated_count = batch_count + self.count
  # Weight of the incoming batch relative to everything seen so far.
  fresh_weight = (math_ops.cast(batch_count, dtype=self.dtype) /
                  math_ops.cast(updated_count, dtype=self.dtype))
  prior_weight = 1. - fresh_weight
  updated_mean = (self.adapt_mean * prior_weight +
                  batch_mean * fresh_weight)
  # Lack-of-fit sum of squares: each component's variance is shifted by the
  # squared distance of its mean from the combined mean before weighting.
  updated_variance = (
      (self.adapt_variance + (self.adapt_mean - updated_mean)**2) *
      prior_weight +
      (batch_variance + (batch_mean - updated_mean)**2) * fresh_weight)
  self.adapt_mean.assign(updated_mean)
  self.adapt_variance.assign(updated_variance)
  self.count.assign(updated_count)