Exemplo n.º 1
0
  def doOutputTest(self,
                   input_shape,
                   moments_axes,
                   tol=1e-4,
                   check_gradients=False):
    """Checks `nn_impl.moments_v2` against NumPy's mean and variance.

    For every combination of offset `mu`, scale `sigma` and `keep_dims`, a
    random input (uniform samples in [0, 1) scaled by `sigma` and shifted by
    `mu`) is built. Its moments are computed in a fresh graph and compared
    with `np.mean` / `np.var` over the same axes.

    Args:
      input_shape: Shape of the random input tensor.
      moments_axes: Axes over which the moments are reduced.
      tol: Relative and absolute tolerance for the value comparison.
      check_gradients: If True, also require the gradient error reported by
        `gradient_checker.compute_gradient_error` for both outputs to be
        below 1e-3.
    """
    import itertools  # local import: the file's import block is not visible

    offsets = [0.0, 1.0, 1e3]
    scales = [1.0, 0.1]
    keep_dims_options = [True, False]
    # itertools.product visits combinations in the same order as the
    # equivalent nested loops (last factor varies fastest), so the random
    # draws happen in the same sequence.
    for mu, sigma, keep_dims in itertools.product(
        offsets, scales, keep_dims_options):
      input_values = np.random.rand(*input_shape) * sigma + mu
      expected_mean = np.mean(
          input_values, axis=moments_axes, keepdims=keep_dims)
      expected_var = np.var(
          input_values, axis=moments_axes, keepdims=keep_dims)
      with ops.Graph().as_default() as g:
        # The session is entered only for its context; the session handle
        # itself is never used, so it is not bound to a name.
        with self.session(graph=g):
          inputs = constant_op.constant(
              input_values, shape=input_shape, dtype=dtypes.float32)
          mean, variance = nn_impl.moments_v2(
              inputs, moments_axes, keepdims=keep_dims)

          if check_gradients:
            err = gradient_checker.compute_gradient_error(
                inputs, input_shape, mean, mean.shape.as_list())
            self.assertLess(err, 1e-3)
            err = gradient_checker.compute_gradient_error(
                inputs, input_shape, variance, variance.shape.as_list())
            self.assertLess(err, 1e-3)

          # Evaluate the graph tensors into concrete arrays.
          mean, variance = self.evaluate([mean, variance])
          # Check for NaNs explicitly so such failures are easy to diagnose.
          self.assertFalse(np.isnan(mean).any())
          self.assertFalse(np.isnan(variance).any())
          self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
          self.assertAllClose(variance, expected_var, rtol=tol, atol=tol)
Exemplo n.º 2
0
  def update_state(self, data):
    """Folds one batch of `data` into the running mean/variance statistics.

    The batch mean and variance along `self._reduce_axis` are merged with the
    stored `self.mean` and `self.variance`. Each side is weighted by its share
    of the combined element count, and `self.count` is then advanced by the
    batch's element count.

    Args:
      data: The batch to accumulate; it is passed through
        `self._standardize_inputs` first.

    Raises:
      RuntimeError: If the layer has not been built yet.
    """
    if not self.built:
      raise RuntimeError('`build` must be called before `update_state`.')

    values = self._standardize_inputs(data)
    batch_mean, batch_variance = nn_impl.moments_v2(
        values, axes=self._reduce_axis)

    # Elements contributing to each statistic: the product of the sizes of
    # the reduced dimensions.
    batch_count = math_ops.reduce_prod(
        array_ops.gather(
            array_ops.shape(values, out_type=self.count.dtype),
            self._reduce_axis))
    total_count = batch_count + self.count

    def as_float(t):
      return math_ops.cast(t, dtype=self.dtype)

    # NOTE(review): if the batch is empty and the running count is zero,
    # total_count is zero and this ratio is 0/0 -- confirm callers never
    # adapt on an empty first batch.
    batch_weight = as_float(batch_count) / as_float(total_count)
    existing_weight = 1. - batch_weight

    def blend(existing, batch):
      # Count-weighted combination of a stored statistic and a batch one.
      return existing * existing_weight + batch * batch_weight

    total_mean = blend(self.mean, batch_mean)
    # Lack-of-fit sum of squares: each part contributes its own variance plus
    # the squared offset of its mean from the combined mean
    # (https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
    total_variance = blend(self.variance + (self.mean - total_mean)**2,
                           batch_variance + (batch_mean - total_mean)**2)

    self.mean.assign(total_mean)
    self.variance.assign(total_variance)
    self.count.assign(total_count)
Exemplo n.º 3
0
  def doOutputTest(self,
                   input_shape,
                   moments_axes,
                   tol=1e-4,
                   check_gradients=False):
    """Checks `nn_impl.moments_v2` against NumPy's mean and variance.

    For every combination of offset `mu`, scale `sigma` and `keep_dims`, a
    random input (uniform samples in [0, 1) scaled by `sigma` and shifted by
    `mu`) is built. Its moments are computed in a fresh graph and compared
    with `np.mean` / `np.var` over the same axes.

    Args:
      input_shape: Shape of the random input tensor.
      moments_axes: Axes over which the moments are reduced.
      tol: Relative and absolute tolerance for the value comparison.
      check_gradients: If True, also require the gradient error reported by
        `gradient_checker.compute_gradient_error` for both outputs to be
        below 1e-3.
    """
    import itertools  # local import: the file's import block is not visible

    offsets = [0.0, 1.0, 1e3]
    scales = [1.0, 0.1]
    keep_dims_options = [True, False]
    # itertools.product visits combinations in the same order as the
    # equivalent nested loops (last factor varies fastest), so the random
    # draws happen in the same sequence.
    for mu, sigma, keep_dims in itertools.product(
        offsets, scales, keep_dims_options):
      input_values = np.random.rand(*input_shape) * sigma + mu
      expected_mean = np.mean(
          input_values, axis=moments_axes, keepdims=keep_dims)
      expected_var = np.var(
          input_values, axis=moments_axes, keepdims=keep_dims)
      with ops.Graph().as_default() as g:
        # The session is entered only for its context; the session handle
        # itself is never used, so it is not bound to a name.
        with self.session(graph=g):
          inputs = constant_op.constant(
              input_values, shape=input_shape, dtype=dtypes.float32)
          mean, variance = nn_impl.moments_v2(
              inputs, moments_axes, keepdims=keep_dims)

          if check_gradients:
            err = gradient_checker.compute_gradient_error(
                inputs, input_shape, mean, mean.shape.as_list())
            self.assertLess(err, 1e-3)
            err = gradient_checker.compute_gradient_error(
                inputs, input_shape, variance, variance.shape.as_list())
            self.assertLess(err, 1e-3)

          # Evaluate the graph tensors into concrete arrays.
          mean, variance = self.evaluate([mean, variance])
          # Check for NaNs explicitly so such failures are easy to diagnose.
          self.assertFalse(np.isnan(mean).any())
          self.assertFalse(np.isnan(variance).any())
          self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
          self.assertAllClose(variance, expected_var, rtol=tol, atol=tol)
Exemplo n.º 4
0
    def update_state(self, data):
        """Accumulates one batch of `data` into the adapted statistics.

        The batch is cast to the dtype of `self.adapt_mean`. Its mean and
        variance along `self._reduce_axis` are then merged with the stored
        `self.adapt_mean` and `self.adapt_variance`, each side weighted by its
        share of the combined element count. Finally `self.count` is advanced
        by the batch's element count.

        Args:
            data: The batch to accumulate; it is passed through
                `self._standardize_inputs` first.

        Raises:
            ValueError: If `self.input_mean` is set, i.e. the layer was given
                a static mean and variance and cannot be adapted.
            RuntimeError: If the layer has not been built yet.
        """
        if self.input_mean is not None:
            message = (
                'Cannot `adapt` a Normalization layer that is initialized with '
                'static `mean` and `variance`, you passed mean {} and variance {}.')
            raise ValueError(
                message.format(self.input_mean, self.input_variance))

        if not self.built:
            raise RuntimeError('`build` must be called before `update_state`.')

        values = math_ops.cast(self._standardize_inputs(data),
                               self.adapt_mean.dtype)
        batch_mean, batch_variance = nn_impl.moments_v2(
            values, axes=self._reduce_axis)

        # Elements contributing to each statistic: the product of the sizes
        # of the reduced dimensions.
        batch_count = math_ops.reduce_prod(
            array_ops.gather(
                array_ops.shape(values, out_type=self.count.dtype),
                self._reduce_axis))
        total_count = batch_count + self.count

        def as_float(t):
            return math_ops.cast(t, dtype=self.dtype)

        # NOTE(review): if the batch is empty and the running count is zero,
        # total_count is zero and this ratio is 0/0 -- confirm callers never
        # adapt on an empty first batch.
        batch_weight = as_float(batch_count) / as_float(total_count)
        existing_weight = 1. - batch_weight

        def blend(existing, batch):
            # Count-weighted combination of a stored statistic and a batch one.
            return existing * existing_weight + batch * batch_weight

        total_mean = blend(self.adapt_mean, batch_mean)
        # Lack-of-fit sum of squares: each part contributes its own variance
        # plus the squared offset of its mean from the combined mean
        # (https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
        total_variance = blend(
            self.adapt_variance + (self.adapt_mean - total_mean)**2,
            batch_variance + (batch_mean - total_mean)**2)

        self.adapt_mean.assign(total_mean)
        self.adapt_variance.assign(total_variance)
        self.count.assign(total_count)