 def testAlphaDropout(self):
   x_dim, y_dim = 40, 30
   for keep_prob in [0.1, 0.5, 0.8]:
     with self.test_session():
       t = random_ops.random_normal([x_dim, y_dim])
       output = alpha_dropout(t, keep_prob)
       self.assertEqual([x_dim, y_dim], output.get_shape())
       t_mean, t_std = nn_impl.moments(t, axes=[0, 1])
       output_mean, output_std = nn_impl.moments(output, axes=[0, 1])
       self.assertLess(abs(t_mean.eval() - output_mean.eval()), 0.1)
       self.assertLess(abs(t_std.eval() - output_std.eval()), 0.1)
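Alpha dropout keeps the self-normalizing property of SELU networks: dropped units are set to the SELU saturation value and an affine correction restores the input mean and variance, which is exactly what the test above checks with nn_impl.moments. A minimal NumPy sketch of that idea; the helper name alpha_dropout_np and the constant (alpha' = -lambda * alpha from the SELU paper) are assumptions for illustration, not the API under test.

import numpy as np

def alpha_dropout_np(x, keep_prob, rng=np.random):
  # Dropped units are set to alpha' (the SELU saturation value) instead of 0,
  # then an affine correction a*x + b restores the original mean and variance.
  alpha_p = -1.7580993408473766  # assumed: -lambda * alpha from the SELU paper
  keep = rng.random_sample(x.shape) < keep_prob
  a = (keep_prob + alpha_p ** 2 * keep_prob * (1 - keep_prob)) ** -0.5
  b = -a * alpha_p * (1 - keep_prob)
  return a * np.where(keep, x, alpha_p) + b

x = np.random.randn(40, 30)
y = alpha_dropout_np(x, keep_prob=0.8)
print(abs(x.mean() - y.mean()), abs(x.std() - y.std()))  # both should stay small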
Example #3
    def call(self, inputs):
        output_size = self.kernel.get_shape().as_list()[1]
        g = vs.get_variable('weight_norm', [output_size],
                            initializer=init_ops.constant_initializer(1.0),
                            dtype=self.kernel.dtype)  # trainable

        self.kernel = nn_impl.l2_normalize(self.kernel, dim=0) * g

        outputs = self._convolution_op(inputs, self.kernel)

        if self.use_bias:
            if self.data_format == 'channels_first':
                bias = array_ops.reshape(self.bias, (1, self.filters, 1))
                outputs += bias
            else:
                outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')

        #  data-dependent initialization
        if self.wn_init:
            mean, variance = nn_impl.moments(outputs, axes=[0, 1, 2])
            scale_init = 1. / math_ops.sqrt(variance + 1e-10)
            with ops.control_dependencies([
                    g.assign(g * scale_init),
                    self.bias.assign_add(-mean * scale_init)
            ]):
                outputs = array_ops.identity(outputs)
            tf.assign(self.wn_init, False)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
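The data-dependent initialization above rescales g and shifts the bias so the first batch of pre-activations comes out roughly zero-mean, unit-variance; note that in graph mode the trailing tf.assign(self.wn_init, False) only takes effect if its op is run or placed in a control dependency (in eager mode it executes immediately). A standalone NumPy sketch of the same weight-norm reparameterization for a dense layer, with illustrative names v, g, b:

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(64, 16)            # a batch of inputs
v = rng.randn(16, 8) * 0.05      # unnormalized weight
g = np.ones(8)                   # per-output-unit gain
b = np.zeros(8)                  # bias

w = g * v / np.linalg.norm(v, axis=0)      # weight norm: w = g * v / ||v||
pre = x @ w + b

# data-dependent init: scale g and shift b using the first batch's statistics
mean, var = pre.mean(axis=0), pre.var(axis=0)
scale = 1.0 / np.sqrt(var + 1e-10)
g, b = g * scale, b - mean * scale

out = x @ (g * v / np.linalg.norm(v, axis=0)) + b
print(out.mean(axis=0).round(6), out.var(axis=0).round(6))  # ~0 and ~1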
    def _training_ref(self, x, scale, offset, old_mean, old_var,
                      exponential_avg_factor, epsilon, data_format):
        if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
            raise ValueError(
                'data_format must be NCHW or NHWC for 4D tensors or '
                'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
        use_4d_tensor = (x.shape.ndims == 4)
        if data_format == 'NCHW':
            x = array_ops.transpose(x, [0, 2, 3, 1])
        elif data_format == 'NCDHW':
            x = array_ops.transpose(x, [0, 2, 3, 4, 1])

        mean_axis = [0, 1, 2] if use_4d_tensor else [0, 1, 2, 3]
        batch_mean, batch_var = nn_impl.moments(math_ops.cast(x, scale.dtype),
                                                mean_axis,
                                                keep_dims=False)

        y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
        if data_format == 'NCHW':
            y = array_ops.transpose(y, [0, 3, 1, 2])
        elif data_format == 'NCDHW':
            y = array_ops.transpose(y, [0, 4, 1, 2, 3])

        # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
        # the denominator in the formula to calculate variance, while
        # tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
        sample_size = math_ops.cast(
            array_ops.size(x) / array_ops.size(scale), scale.dtype)
        batch_var_corrected = batch_var * sample_size / (math_ops.maximum(
            sample_size - 1.0, 1.0))

        mean = self._running_mean(old_mean, batch_mean, exponential_avg_factor)
        var = self._running_mean(old_var, batch_var_corrected,
                                 exponential_avg_factor)
        return self.evaluate(y), self.evaluate(mean), self.evaluate(var)
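The correction above converts the population variance that nn_impl.moments returns (denominator n) into the sample variance (denominator n - 1) that fused_batch_norm tracks in its running statistics. A quick NumPy check of that identity and of the running update, assuming _running_mean has the usual old * (1 - factor) + new * factor form:

import numpy as np

x = np.random.randn(4, 8, 8, 3).astype(np.float32)
n = x.size // x.shape[-1]                       # sample size per channel

batch_var = x.var(axis=(0, 1, 2))               # denominator n (what moments returns)
corrected = batch_var * n / max(n - 1, 1)       # Bessel's correction
print(np.allclose(corrected, x.var(axis=(0, 1, 2), ddof=1)))  # True

# running-statistics update with an exponential_avg_factor f
f, old_var = 0.1, np.ones(3, np.float32)
new_var = (1.0 - f) * old_var + f * corrected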
Example #5
  def doOutputTest(self, input_shape, moments_axes, tol=1e-4,
                   check_gradients=False):
    for mu in [0.0, 1.0, 1e3]:
      for sigma in [1.0, 0.1]:
        for keep_dims in [True, False]:
          input_values = np.random.rand(*input_shape) * sigma + mu
          expected_mean = np.mean(
              input_values, axis=moments_axes, keepdims=keep_dims)
          expected_var = np.var(
              input_values, axis=moments_axes, keepdims=keep_dims)
          with ops.Graph().as_default() as g:
            with self.test_session(graph=g) as sess:
              inputs = constant_op.constant(
                  input_values, shape=input_shape, dtype=dtypes.float32)
              mean, variance = nn_impl.moments(
                  inputs, moments_axes, keep_dims=keep_dims)

              if check_gradients:
                err = gradient_checker.compute_gradient_error(
                    inputs, input_shape, mean, mean.shape.as_list())
                self.assertLess(err, 1e-3)
                err = gradient_checker.compute_gradient_error(
                    inputs, input_shape, variance, variance.shape.as_list())
                self.assertLess(err, 1e-3)

              # Evaluate.
              [mean, variance] = sess.run([mean, variance])
              # Make sure that there are no NaNs
              self.assertFalse(np.isnan(mean).any())
              self.assertFalse(np.isnan(variance).any())
              self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
              self.assertAllClose(variance, expected_var, rtol=tol, atol=tol)
Example #6
    def doOutputTest(self, input_shape, moments_axes, tol=1e-4):
        for mu in [0.0, 1.0, 1e3]:
            for sigma in [1.0, 0.1]:
                for keep_dims in [True, False]:
                    input_values = np.random.rand(*input_shape) * sigma + mu
                    expected_mean = np.mean(input_values,
                                            axis=moments_axes,
                                            keepdims=keep_dims)
                    expected_var = np.var(input_values,
                                          axis=moments_axes,
                                          keepdims=keep_dims)
                    with ops.Graph().as_default() as g:
                        with self.test_session(graph=g) as sess:
                            inputs = constant_op.constant(input_values,
                                                          shape=input_shape,
                                                          dtype=dtypes.float32)
                            mean, variance = nn_impl.moments(
                                inputs, moments_axes, keep_dims=keep_dims)

                            [mean, variance] = sess.run([mean, variance])
                            # Make sure that there are no NaNs
                            self.assertFalse(np.isnan(mean).any())
                            self.assertFalse(np.isnan(variance).any())
                            self.assertAllClose(mean,
                                                expected_mean,
                                                rtol=tol,
                                                atol=tol)
                            self.assertAllClose(variance,
                                                expected_var,
                                                rtol=tol,
                                                atol=tol)
  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    # Method to compute moments of `x` wrt `axes`.
    #
    # This is exposed so WeightedMomentsTest can inherit the tests and
    # assertions from MomentsTest; the extra_out_grads argument allows
    # its inherited gradient tests to assert gradients against the
    # weights as well as the input values.

    return nn_impl.moments(x, axes, keep_dims=keep_dims)
  def mean_var_with_update():
    mean, variance = nn_impl.moments(input_x, [0, 1, 2], name='moments')
    with ops.control_dependencies([
        moving_averages.assign_moving_average(moving_mean, mean, 0.9),
        moving_averages.assign_moving_average(moving_variance, variance, 0.9)
    ]):
      return array_ops.identity(mean), array_ops.identity(variance)
 def _training_ref(self, x, scale, offset, epsilon, data_format):
   if data_format not in ['NHWC', 'NCHW']:
     raise ValueError('data_format must be NCHW or NHWC, '
                      'got %s.' % data_format)
   if data_format == 'NCHW':
     x = array_ops.transpose(x, [0, 2, 3, 1])
   mean, var = nn_impl.moments(x, [0, 1, 2], keep_dims=False)
   y = nn_impl.batch_normalization(x, mean, var, offset, scale, epsilon)
   if data_format == 'NCHW':
     y = array_ops.transpose(y, [0, 3, 1, 2])
   return y.eval(), mean.eval(), var.eval()
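The reference above is just the textbook batch-norm transform applied with per-channel statistics taken over N, H, W. A NumPy equivalent of the NHWC path, handy for comparing against nn_impl.batch_normalization (the epsilon value is illustrative):

import numpy as np

def batch_norm_ref(x, scale, offset, epsilon=1e-3):
  # x is NHWC; statistics are computed over N, H, W for each channel.
  mean = x.mean(axis=(0, 1, 2))
  var = x.var(axis=(0, 1, 2))
  return (x - mean) / np.sqrt(var + epsilon) * scale + offset, mean, var

x = np.random.randn(2, 4, 4, 3)
y, mean, var = batch_norm_ref(x, scale=np.ones(3), offset=np.zeros(3))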
Example #12
def build_graph(device, input_shape, axes, num_layers, mode, scale, train):
    """Build a graph containing a sequence of batch normalizations.

  Args:
    device: string, the device to run on.
    input_shape: shape of the input tensor.
    axes: axes that are to be normalized across.
    num_layers: number of batch normalization layers in the graph.
    mode: "op", "py" or "slow" depending on the implementation.
    scale: scale after normalization.
    train: if true, also run backprop.

  Returns:
    An array of tensors to run()
  """
    moment_shape = []
    keep_dims = mode == "py" or mode == "slow"
    if keep_dims:
        for axis in range(len(input_shape)):
            if axis in axes:
                moment_shape.append(1)
            else:
                moment_shape.append(input_shape[axis])
    else:
        for axis in range(len(input_shape)):
            if axis not in axes:
                moment_shape.append(input_shape[axis])
    with ops.device("/%s:0" % device):
        tensor = variables.Variable(random_ops.truncated_normal(input_shape))
        for _ in range(num_layers):
            if train:
                mean, variance = nn_impl.moments(tensor,
                                                 axes,
                                                 keep_dims=keep_dims)
            else:
                mean = array_ops.zeros(moment_shape)
                variance = array_ops.ones(moment_shape)
            beta = variables.Variable(array_ops.zeros(moment_shape))
            gamma = variables.Variable(
                constant_op.constant(1.0, shape=moment_shape))
            if mode == "py":
                tensor = batch_norm_py(tensor, mean, variance, beta, gamma,
                                       scale)
            elif mode == "op":
                tensor = batch_norm_op(tensor, mean, variance, beta, gamma,
                                       scale)
            elif mode == "slow":
                tensor = batch_norm_slow(tensor, mean, variance, beta, gamma,
                                         scale)
        if train:
            return gradients_impl.gradients([tensor],
                                            variables.trainable_variables())
        else:
            return [tensor]
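A hedged usage sketch for the benchmark helper above, run in TF1 graph mode; it assumes build_graph and the batch_norm_py / batch_norm_op / batch_norm_slow helpers it dispatches to are defined in the same module:

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables

with ops.Graph().as_default():
  # train=True returns the gradients, so running the targets exercises both
  # nn_impl.moments and its backward pass.
  targets = build_graph(device="cpu", input_shape=[8, 16, 16, 32],
                        axes=[0, 1, 2], num_layers=2, mode="op",
                        scale=True, train=True)
  with session_lib.Session() as sess:
    sess.run(variables.global_variables_initializer())
    sess.run(targets)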
Example #13
 def _training_ref(self, x, scale, offset, epsilon, data_format):
   if data_format not in ['NHWC', 'NCHW']:
     raise ValueError('data_format must be NCHW or NHWC, '
                      'got %s.' % data_format)
   if data_format == 'NCHW':
     x = array_ops.transpose(x, [0, 2, 3, 1])
   mean, var = nn_impl.moments(
       math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
   y = self._batch_norm(x, mean, var, offset, scale, epsilon)
   if data_format == 'NCHW':
     y = array_ops.transpose(y, [0, 3, 1, 2])
   return self.evaluate(y), self.evaluate(mean), self.evaluate(var)
Example #15
 def _layer_normalization(self, inputs, scope=None):
     """
     :param inputs: (batch, shape)
     :param scope: optional variable scope name
     :return: layer-normalized inputs (batch, shape)
     """
     shape = inputs.get_shape()[-1:]
     with vs.variable_scope(scope or "layer_norm"):
         # Initialize beta and gamma for use by layer_norm.
         g = vs.get_variable("gain", shape=shape, initializer=init_ops.constant_initializer(self._g))  # (shape,)
         s = vs.get_variable("shift", shape=shape, initializer=init_ops.constant_initializer(self._b))  # (shape,)
     m, v = nn_impl.moments(inputs, [1], keep_dims=True)  # (batch, 1) each
     normalized_input = (inputs - m) / math_ops.sqrt(v + _EPSILON)  # (batch, shape)
     return normalized_input * g + s
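The cell-level layer normalization above takes moments over the feature axis of each example and then applies a learned gain and shift. A NumPy sketch of the same transform (epsilon and the parameter values are placeholders):

import numpy as np

def layer_norm_np(x, g, s, epsilon=1e-5):
  # x: (batch, features); normalize each row, then apply learned gain/shift.
  m = x.mean(axis=1, keepdims=True)          # (batch, 1)
  v = x.var(axis=1, keepdims=True)           # (batch, 1)
  return (x - m) / np.sqrt(v + epsilon) * g + s

x = np.random.randn(4, 10)
y = layer_norm_np(x, g=np.ones(10), s=np.zeros(10))
print(y.mean(axis=1).round(6), y.std(axis=1).round(6))  # ~0 and ~1 per row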
Example #17
    def call(self, inputs, **kwargs):
        input_shape = K.int_shape(inputs)
        tensor_input_shape = K.shape(inputs)

        # Prepare broadcasting shape.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
        broadcast_shape.insert(1, self.groups)

        reshape_group_shape = K.shape(inputs)
        group_axes = [reshape_group_shape[i] for i in range(len(input_shape))]
        group_axes[self.axis] = input_shape[self.axis] // self.groups
        group_axes.insert(1, self.groups)

        # reshape inputs to new group shape
        group_shape = [group_axes[0], self.groups] + group_axes[2:]
        group_shape = K.stack(group_shape)
        inputs = K.reshape(inputs, group_shape)

        group_reduction_axes = list(range(len(group_axes)))
        mean, variance = nn_impl.moments(inputs,
                                         group_reduction_axes[2:],
                                         shift=None,
                                         keep_dims=True)
        inputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))

        # prepare broadcast shape
        inputs = K.reshape(inputs, group_shape)

        outputs = inputs

        # In this case we must explicitly broadcast all parameters.
        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            outputs = outputs * broadcast_gamma

        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            outputs = outputs + broadcast_beta

        # finally we reshape the output back to the input shape
        outputs = K.reshape(outputs, tensor_input_shape)

        return outputs
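Group normalization splits the channels into groups and normalizes each group over its channels and the spatial axes. A standalone NumPy sketch of the standard channels-last formulation (group count and epsilon are illustrative; the reshape order here is the conventional one and may differ from the Keras-style reshaping above):

import numpy as np

def group_norm_np(x, groups, gamma, beta, epsilon=1e-5):
  # x: (N, H, W, C), channels split into `groups` groups of C // groups each.
  n, h, w, c = x.shape
  g = x.reshape(n, h, w, groups, c // groups)
  mean = g.mean(axis=(1, 2, 4), keepdims=True)
  var = g.var(axis=(1, 2, 4), keepdims=True)
  g = (g - mean) / np.sqrt(var + epsilon)
  return g.reshape(n, h, w, c) * gamma + beta

x = np.random.randn(2, 8, 8, 16)
y = group_norm_np(x, groups=4, gamma=np.ones(16), beta=np.zeros(16))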
  def testUnstableOutputShiftNone(self):
    input_shape = (10, 300)
    moments_axes = (0, 1)
    mu, sigma = 1e3, 0.1
    tol = 1e-3
    input_values = np.random.rand(*input_shape) * sigma + mu
    expected_mean = np.mean(input_values, axis=moments_axes)
    expected_var = np.var(input_values, axis=moments_axes)

    with self.test_session() as sess:
      inputs = constant_op.constant(
          input_values, shape=input_shape, dtype=dtypes.float32)
      mean, variance = nn_impl.moments(inputs, moments_axes, shift=0.0)

      [mean, variance] = sess.run([mean, variance])
      # Make sure that there are no NaNs
      self.assertFalse(np.isnan(mean).any())
      self.assertFalse(np.isnan(variance).any())
      self.assertAllClose(mean, expected_mean, rtol=tol, atol=tol)
      # The variance is unstable
      self.assertGreater(np.abs(variance - expected_var), 0.1)
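The instability the test expects comes from evaluating the variance as E[x^2] - E[x]^2 in float32 with a shift of zero: when the mean (~1e3) dwarfs the spread (~0.1), the two terms nearly cancel and almost all significant bits are lost. A NumPy demonstration of the same cancellation:

import numpy as np

x = (np.random.rand(10, 300) * 0.1 + 1e3).astype(np.float32)

true_var = x.astype(np.float64).var()       # ~8.3e-4
naive = np.mean(x * x) - np.mean(x) ** 2    # float32 E[x^2] - E[x]^2, shift of 0
shifted = np.var(x - np.float32(1e3))       # shifting first preserves precision
print(true_var, naive, shifted)             # naive bears little relation to true_var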
    def _training_ref(self, x, scale, offset, epsilon, data_format):
        if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
            raise ValueError(
                'data_format must be NCHW or NHWC for 4D tensors or '
                'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
        use_4d_tensor = (x.shape.ndims == 4)
        if data_format == 'NCHW':
            x = array_ops.transpose(x, [0, 2, 3, 1])
        elif data_format == 'NCDHW':
            x = array_ops.transpose(x, [0, 2, 3, 4, 1])

        mean_axis = [0, 1, 2] if use_4d_tensor else [0, 1, 2, 3]
        mean, var = nn_impl.moments(math_ops.cast(x, scale.dtype),
                                    mean_axis,
                                    keep_dims=False)
        y = self._batch_norm(x, mean, var, offset, scale, epsilon)
        if data_format == 'NCHW':
            y = array_ops.transpose(y, [0, 3, 1, 2])
        elif data_format == 'NCDHW':
            y = array_ops.transpose(y, [0, 4, 1, 2, 3])
        return self.evaluate(y), self.evaluate(mean), self.evaluate(var)
def diagonal_only_frechet_classifier_distance_from_activations(
    real_activations, generated_activations):
  """Classifier distance for evaluating a generative model.

  This is based on the Frechet Inception distance, but for an arbitrary
  classifier.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance matrices
  C and C_w, this function calculates

          |m - m_w|^2 + (sigma + sigma_w - 2(sigma x sigma_w)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images. In this variant, we compute diagonal-only covariance matrices.
  As a result, instead of computing an expensive matrix square root, we can do
  something much simpler that has O(n) rather than O(n^2) space complexity.

  Note that when computed using sample means and sample covariance matrices,
  Frechet distance is biased; the bias is larger for small sample sizes (e.g.
  even if the two distributions are the same, the expected Frechet distance is
  large for a small sample size). It is therefore important to use the same
  sample size when comparing two generative models by their Frechet classifier
  distance.

  Args:
    real_activations: Real images to use to compute Frechet Inception distance.
    generated_activations: Generated images to use to compute Frechet Inception
      distance.

  Returns:
    The diagonal-only Frechet Inception distance. A floating-point scalar of
    the same type as the output of the activations.

  Raises:
    ValueError: If the shape of the variance and mean vectors are not equal.
  """
  real_activations.shape.assert_has_rank(2)
  generated_activations.shape.assert_has_rank(2)

  activations_dtype = real_activations.dtype
  if activations_dtype != dtypes.float64:
    real_activations = math_ops.cast(real_activations, dtypes.float64)
    generated_activations = math_ops.cast(generated_activations, dtypes.float64)

  # Compute mean and covariance matrices of activations.
  m, var = nn_impl.moments(real_activations, axes=[0])
  m_w, var_w = nn_impl.moments(generated_activations, axes=[0])

  actual_shape = var.get_shape()
  expected_shape = m.get_shape()

  if actual_shape != expected_shape:
    raise ValueError('shape: {} must match expected shape: {}'.format(
        actual_shape, expected_shape))

  # Compute the two components of FID.

  # First the covariance component.
  # Here, note that trace(A + B) = trace(A) + trace(B)
  trace = math_ops.reduce_sum(
      (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w)))

  # Next the distance between means.
  mean = math_ops.reduce_sum(
      math_ops.squared_difference(m, m_w))  # Equivalent to L2 but more stable.
  dofid = trace + mean
  if activations_dtype != dtypes.float64:
    dofid = math_ops.cast(dofid, activations_dtype)

  return dofid
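A NumPy sketch of the same diagonal-only distance, useful as a sanity check against the graph version above (the activation arrays are random placeholders):

import numpy as np

def diagonal_fid_np(real, gen):
  # Per-dimension means and variances stand in for full covariance matrices.
  m, var = real.mean(axis=0), real.var(axis=0)
  m_w, var_w = gen.mean(axis=0), gen.var(axis=0)
  trace = np.sum(var + var_w - 2.0 * np.sqrt(var * var_w))
  return trace + np.sum((m - m_w) ** 2)

real = np.random.randn(1000, 64)
gen = np.random.randn(1000, 64) + 0.1
print(diagonal_fid_np(real, gen))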
Example #23
def _batch_norm(_input,
                trnFlag,
                eps=1e-3,
                name="batch_norm",
                ema_decay=0.5,
                dtype=dtypes.float32):
    """
    A wrapped BN operation used for 2D or 3D convolution as described in:

        https://gist.github.com/tomokishii/0ce3bdac1588b5cca9fa5fbdf6e1c412
        https://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow?answertab=votes#tab-top

    :param _input: tensor, always convolution result before Relu
    :param eps: scalar,
    :param trnFlag: bool, whether training or not
    :param ema_decay: scalar, moving average used of BN's beta and gamma
    :param dtype: tf.dtype, data type
    :return:
        tensor, BN reuslt which has the same shape as _input
    """
    shape = _input.get_shape().as_list()
    with variable_scope.variable_scope(name) as scope:
        beta = variable_scope.get_variable(
            "beta", [shape[-1]],
            dtype=dtype,
            initializer=init_ops.constant_initializer(0., dtype=dtype),
            trainable=True)
        gamma = variable_scope.get_variable(
            "gamma", [shape[-1]],
            dtype=dtype,
            initializer=init_ops.random_normal_initializer(1.,
                                                           0.01,
                                                           dtype=dtype,
                                                           seed=20170705),
            trainable=True)

        if len(shape) == 2:  # fc, [batch_size, num_dim]
            batch_mean, batch_var = nn_impl.moments(_input, [0],
                                                    name="moments")
        elif len(shape) == 4:  # conv, [batch_size, width, height, channel]
            batch_mean, batch_var = nn_impl.moments(_input, [0, 1, 2],
                                                    name="moments")
        elif len(shape) == 5:  # conv, [batch_size, depth, width, height, channel]
            batch_mean, batch_var = nn_impl.moments(_input, [0, 1, 2, 3],
                                                    name="moments")
        else:
            raise ValueError('wrong _input shape: it must have rank 2, 4 or 5, '
                             'got %d.' % len(shape))

        ema = training.ExponentialMovingAverage(decay=ema_decay)

        def mean_var_with_update():
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with ops.control_dependencies([ema_apply_op]):
                return array_ops.identity(batch_mean), array_ops.identity(
                    batch_var)

        mean, var = control_flow_ops.cond(
            trnFlag, mean_var_with_update, lambda:
            (ema.average(batch_mean), ema.average(batch_var)))

        bn_out = nn_impl.batch_normalization(_input, mean, var, beta, gamma,
                                             eps)
    return bn_out
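A hedged usage sketch for the wrapper above in TF1 graph mode; trnFlag must be a boolean tensor so control_flow_ops.cond can switch between batch statistics and their moving averages. The placeholder name and shapes are illustrative:

# Assumed to run in the same module as _batch_norm, TF1 graph mode.
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops

is_training = array_ops.placeholder(dtypes.bool, shape=[], name="is_training")
conv_out = random_ops.random_normal([8, 16, 16, 32])  # stand-in for a conv output
bn_out = _batch_norm(conv_out, trnFlag=is_training, name="bn1")
# sess.run(bn_out, feed_dict={is_training: True}) while training,
# feed_dict={is_training: False} at inference (uses the moving averages).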