Example #1
0
def _compute_pi_tracenorm(left_cov, right_cov):
  """Returns the trace-norm constant pi for a pair of Kronecker factors.

  pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
  See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.

  Args:
    left_cov: The left Kronecker factor "covariance" (A above).
    right_cov: The right Kronecker factor "covariance" (B above).

  Returns:
    The scalar constant pi for these Kronecker factors (as a Tensor).
  """
  left_dim = left_cov.shape.as_list()[0]
  right_dim = right_cov.shape.as_list()[0]
  # Cross-multiplying each trace by the *other* factor's dimension yields the
  # same ratio as dividing each trace by its own dimension.
  scaled_left_trace = math_ops.trace(left_cov) * right_dim
  scaled_right_trace = math_ops.trace(right_cov) * left_dim
  return math_ops.sqrt(scaled_left_trace / scaled_right_trace)
  def run_test(self, axes, expanded_axes=None):
    """Checks `special_math_ops.einsum` against a reference for one equation.

    Args:
      axes: The einsum equation to evaluate.
      expanded_axes: Optional string whose alphabetic characters name every
        axis that needs a size. Defaults to `axes`.
    """
    if expanded_axes is None:
      expanded_axes = axes
    # Draw a random size in [4, 12) for every axis label.
    axis_sizes = {
        label: np.random.randint(4, 12)
        for label in expanded_axes if label.isalpha()
    }

    # Everything before '->' lists the comma-separated input subscripts.
    operand_specs = axes.partition('->')[0].split(',')
    input_vals = [
        np.random.random([axis_sizes[ch] for ch in spec if ch.isalpha()])
        for spec in operand_specs
    ]

    input_tensors = [constant_op.constant(val) for val in input_vals]
    output_tensor = special_math_ops.einsum(axes, *input_tensors)

    with self.session(use_gpu=True):
      actual = self.evaluate(output_tensor)

    if axes == 'ijji':
      # This equation is checked against the trace op rather than np.einsum.
      expected = self.evaluate(math_ops.trace(*input_tensors))
    else:
      expected = np.einsum(axes, *input_vals)
    max_abs_err = np.abs(expected - actual).max()
    self.assertLess(max_abs_err, 1e-8)
Example #3
0
  def run_test(self, axes, expanded_axes=None):
    """Checks `special_math_ops.einsum` against a reference for one equation.

    Args:
      axes: The einsum equation to evaluate.
      expanded_axes: Optional string whose alphabetic characters name every
        axis that needs a size. Defaults to `axes`.
    """
    if expanded_axes is None:
      expanded_axes = axes
    # Draw a random size in [4, 12) for every axis label.
    axis_sizes = {
        label: np.random.randint(4, 12)
        for label in expanded_axes if label.isalpha()
    }

    # Everything before '->' lists the comma-separated input subscripts.
    operand_specs = axes.partition('->')[0].split(',')
    input_vals = [
        np.random.random([axis_sizes[ch] for ch in spec if ch.isalpha()])
        for spec in operand_specs
    ]

    input_tensors = [constant_op.constant(val) for val in input_vals]
    output_tensor = special_math_ops.einsum(axes, *input_tensors)

    with self.session(use_gpu=True):
      actual = self.evaluate(output_tensor)

    if axes == 'ijji':
      # This equation is checked against the trace op rather than np.einsum.
      expected = self.evaluate(math_ops.trace(*input_tensors))
    else:
      expected = np.einsum(axes, *input_vals)
    max_abs_err = np.abs(expected - actual).max()
    self.assertLess(max_abs_err, 1e-8)
def trace_sqrt_product(sigma, sigma_v):
    """Returns trace(sqrt(A sigma_v A)), where A is the root of `sigma`.

    Both square roots are taken with `_symmetric_matrix_square_root`.

    Args:
      sigma: Matrix whose square root A sandwiches `sigma_v`.
      sigma_v: Matrix placed between the two copies of A.

    Returns:
      The trace of the square root of A sigma_v A.
    """
    root_sigma = _symmetric_matrix_square_root(sigma)
    # A sigma_v A; its square root is only taken on the next line.
    a_sigma_v_a = math_ops.matmul(root_sigma,
                                  math_ops.matmul(sigma_v, root_sigma))
    root_of_product = _symmetric_matrix_square_root(a_sigma_v_a)
    return math_ops.trace(root_of_product)
Example #5
0
def _compute_pi_tracenorm(left_cov, right_cov):
  """Returns the trace-norm constant pi for a pair of Kronecker factors.

  pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
  See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.

  Args:
    left_cov: The left Kronecker factor "covariance" (A above).
    right_cov: The right Kronecker factor "covariance" (B above).

  Returns:
    The scalar constant pi for these Kronecker factors (as a Tensor).
  """
  left_dim = left_cov.shape.as_list()[0]
  right_dim = right_cov.shape.as_list()[0]
  # Cross-multiplying each trace by the *other* factor's dimension yields the
  # same ratio as dividing each trace by its own dimension.
  scaled_left_trace = math_ops.trace(left_cov) * right_dim
  scaled_right_trace = math_ops.trace(right_cov) * left_dim
  return math_ops.sqrt(scaled_left_trace / scaled_right_trace)
  def compute_kid_block(i):
    """Computes the ith block of the KID estimate."""

    def cubic_kernel(x, y):
      # (x y^T / dim + 1)^3 between the rows of x and the rows of y.
      return (math_ops.matmul(x, y, transpose_b=True) / dim + 1)**3

    # Slice of the real activations belonging to block i, and its length.
    real_start, real_stop = inds_r[i], inds_r[i + 1]
    real_block = real_activations[real_start:real_stop]
    m = math_ops.cast(real_stop - real_start, dtype)

    # Same for the generated activations.
    gen_start, gen_stop = inds_g[i], inds_g[i + 1]
    gen_block = generated_activations[gen_start:gen_stop]
    n = math_ops.cast(gen_stop - gen_start, dtype)

    k_rr = cubic_kernel(real_block, real_block)
    k_rg = cubic_kernel(real_block, gen_block)
    k_gg = cubic_kernel(gen_block, gen_block)

    cross_term = -2 * math_ops.reduce_mean(k_rg)
    # Subtracting the trace removes the diagonal kernel entries, leaving the
    # m * (m - 1) (resp. n * (n - 1)) off-diagonal ones to be averaged.
    real_term = ((math_ops.reduce_sum(k_rr) - math_ops.trace(k_rr)) /
                 (m * (m - 1)))
    gen_term = ((math_ops.reduce_sum(k_gg) - math_ops.trace(k_gg)) /
                (n * (n - 1)))
    return cross_term + real_term + gen_term
Example #7
0
  def compute_kid_block(i):
    """Computes the ith block of the KID estimate."""

    def cubic_kernel(x, y):
      # (x y^T / dim + 1)^3 between the rows of x and the rows of y.
      return (math_ops.matmul(x, y, transpose_b=True) / dim + 1)**3

    # Slice of the real activations belonging to block i, and its length.
    real_start, real_stop = inds_r[i], inds_r[i + 1]
    real_block = real_activations[real_start:real_stop]
    m = math_ops.cast(real_stop - real_start, dtype)

    # Same for the generated activations.
    gen_start, gen_stop = inds_g[i], inds_g[i + 1]
    gen_block = generated_activations[gen_start:gen_stop]
    n = math_ops.cast(gen_stop - gen_start, dtype)

    k_rr = cubic_kernel(real_block, real_block)
    k_rg = cubic_kernel(real_block, gen_block)
    k_gg = cubic_kernel(gen_block, gen_block)

    cross_term = -2 * math_ops.reduce_mean(k_rg)
    # Subtracting the trace removes the diagonal kernel entries, leaving the
    # m * (m - 1) (resp. n * (n - 1)) off-diagonal ones to be averaged.
    real_term = ((math_ops.reduce_sum(k_rr) - math_ops.trace(k_rr)) /
                 (m * (m - 1)))
    gen_term = ((math_ops.reduce_sum(k_gg) - math_ops.trace(k_gg)) /
                (n * (n - 1)))
    return cross_term + real_term + gen_term
Example #8
0
 def _trace(cov):
   """Returns the trace of `cov`, given as a diagonal (rank 1) or a matrix."""
   rank = len(cov.shape)
   if rank == 1:
     # Only the diagonal is stored, so the trace is the sum of the entries.
     return math_ops.reduce_sum(cov)
   if rank == 2:
     # Full matrix.
     return math_ops.trace(cov)
   raise ValueError(
       "What's the trace of a Tensor of rank %d?" % rank)
 def _trace(cov):
   """Returns the trace of `cov`, given as a diagonal (rank 1) or a matrix."""
   rank = len(cov.shape)
   if rank == 1:
     # Only the diagonal is stored, so the trace is the sum of the entries.
     return math_ops.reduce_sum(cov)
   if rank == 2:
     # Full matrix.
     return math_ops.trace(cov)
   raise ValueError(
       "What's the trace of a Tensor of rank %d?" % rank)
Example #10
0
def frechet_classifier_distance_from_activations_new(real_activations,
                                                     generated_activations):
    """Computes a Frechet-style distance between two sets of activations.

    The result is `trace_term + |m - m_v|^2`, where `m` / `m_v` are the
    per-column means and `sigma` / `sigma_v` the sample covariances of the real
    and generated activations. All arithmetic is done in float64.

    Args:
      real_activations: Rank-2 tensor of real activations; statistics are
        taken over axis 0.
      generated_activations: Rank-2 tensor of generated activations; statistics
        are taken over axis 0.

    Returns:
      The distance as a scalar tensor, cast back to the dtype of
      `real_activations`.
    """
    real_activations.shape.assert_has_rank(2)
    generated_activations.shape.assert_has_rank(2)

    activations_dtype = real_activations.dtype
    if activations_dtype != dtypes.float64:
        real_activations = math_ops.to_double(real_activations)
        generated_activations = math_ops.to_double(generated_activations)

    # Compute mean and covariance matrices of activations.
    m = math_ops.reduce_mean(real_activations, 0)
    m_v = math_ops.reduce_mean(generated_activations, 0)
    # Each covariance is normalized by its OWN sample count, so the two sets
    # may contain different numbers of examples.
    num_examples_real = math_ops.to_double(
        array_ops.shape(real_activations)[0])
    num_examples_generated = math_ops.to_double(
        array_ops.shape(generated_activations)[0])

    # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
    real_centered = real_activations - m
    sigma = math_ops.matmul(real_centered, real_centered,
                            transpose_a=True) / (num_examples_real - 1)

    gen_centered = generated_activations - m_v
    sigma_v = math_ops.matmul(gen_centered, gen_centered,
                              transpose_a=True) / (num_examples_generated - 1)

    # Intended as the Tr(sqrt(sigma sigma_v)) component of FID.
    # NOTE(review): `math_ops.sqrt` and `*` are element-wise here, so this is
    # not a matrix square root of a matrix product -- confirm whether this
    # variant is deliberate before relying on it as a true FID.
    sqrt_trace_component = math_ops.trace(
        math_ops.sqrt(sigma) * sigma_v * math_ops.sqrt(sigma))

    # Compute the two components of FID.

    # First the covariance component.
    # Here, note that trace(A + B) = trace(A) + trace(B)
    trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component

    # Next the distance between means.
    mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
    fid = trace + mean
    if activations_dtype != dtypes.float64:
        fid = math_ops.cast(fid, activations_dtype)

    return fid
 def test_trace(self):
   # operator.trace() must agree with math_ops.trace of the matrix that
   # operator_and_matrix returns alongside the operator.
   with self.session(graph=ops.Graph()) as sess:
     sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
     linop, matrix = self.operator_and_matrix(
         shapes_info, dtype, use_placeholder=use_placeholder)
     from_operator = linop.trace()
     from_matrix = math_ops.trace(matrix)
     if not use_placeholder:
       # Static shapes are only compared when no placeholder is used.
       self.assertAllEqual(from_operator.get_shape(),
                           from_matrix.get_shape())
     actual, expected = sess.run([from_operator, from_matrix])
     self.assertAC(actual, expected)
Example #12
0
 def test_trace(self):
     # operator.trace() must agree with math_ops.trace of the matrix that
     # operator_and_matrix returns alongside the operator.
     with self.session(graph=ops.Graph()) as sess:
         sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
         linop, matrix = self.operator_and_matrix(
             shapes_info, dtype, use_placeholder=use_placeholder)
         from_operator = linop.trace()
         from_matrix = math_ops.trace(matrix)
         if not use_placeholder:
             # Static shapes are only compared when no placeholder is used.
             self.assertAllEqual(from_operator.shape, from_matrix.shape)
         actual, expected = sess.run([from_operator, from_matrix])
         self.assertAC(actual, expected)
Example #13
0
 def _get_weights(self, hessian_shape, hessians):
   """Derives weights to be used based on hessians and multiclass strategy.

   Args:
     hessian_shape: Shape of a single hessian; selects the strategy.
     hessians: The hessians to turn into weights.

   Returns:
     `hessians` unchanged for a scalar shape, their sum over axis 1 for a
     rank-1 shape, and their trace otherwise.
   """
   if hessian_shape == tensor_shape.scalar():
     # Tree per class: the hessians already are the weights.
     return hessians
   if len(hessian_shape.dims) == 1:
     # Diagonal hessian.
     return math_ops.reduce_sum(hessians, axis=1)
   # Full hessian.
   return math_ops.trace(hessians)
Example #14
0
 def _get_weights(self, hessian_shape, hessians):
     """Derives weights to be used based on hessians and multiclass strategy.

     Args:
       hessian_shape: Shape of a single hessian; selects the strategy.
       hessians: The hessians to turn into weights.

     Returns:
       `hessians` unchanged for a scalar shape, their sum over axis 1 for a
       rank-1 shape, and their trace otherwise.
     """
     if hessian_shape == tensor_shape.scalar():
         # Tree per class: the hessians already are the weights.
         return hessians
     if len(hessian_shape.dims) == 1:
         # Diagonal hessian.
         return math_ops.reduce_sum(hessians, axis=1)
     # Full hessian.
     return math_ops.trace(hessians)
 def test_trace(self):
   self._skip_if_tests_to_skip_contains("trace")
   # For every placeholder option, build info and dtype, operator.trace()
   # must agree with math_ops.trace of the accompanying matrix.
   for use_placeholder in self._use_placeholder_options:
     for build_info in self._operator_build_infos:
       for dtype in self._dtypes_to_test:
         with self.session(graph=ops.Graph()) as sess:
           sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
           linop, matrix = self._operator_and_matrix(
               build_info, dtype, use_placeholder=use_placeholder)
           from_operator = linop.trace()
           from_matrix = math_ops.trace(matrix)
           if not use_placeholder:
             # Static shapes are only compared without placeholders.
             self.assertAllEqual(from_operator.get_shape(),
                                 from_matrix.get_shape())
           actual, expected = sess.run([from_operator, from_matrix])
           self.assertAC(actual, expected)
 def test_trace(self):
   self._skip_if_tests_to_skip_contains("trace")
   # For every placeholder option, build info and dtype, operator.trace()
   # must agree with math_ops.trace of the accompanying matrix.
   for use_placeholder in self._use_placeholder_options:
     for build_info in self._operator_build_infos:
       for dtype in self._dtypes_to_test:
         with self.session(graph=ops.Graph()) as sess:
           sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
           linop, matrix = self._operator_and_matrix(
               build_info, dtype, use_placeholder=use_placeholder)
           from_operator = linop.trace()
           from_matrix = math_ops.trace(matrix)
           if not use_placeholder:
             # Static shapes are only compared without placeholders.
             self.assertAllEqual(from_operator.get_shape(),
                                 from_matrix.get_shape())
           actual, expected = sess.run([from_operator, from_matrix])
           self.assertAC(actual, expected)
Example #17
0
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
    if dtype:
        dtype = np_utils.result_type(dtype)
    a = np_array_ops.asarray(a, dtype)

    # Fast path: with no offset, a statically known rank and the two axes
    # being the last two, the trace op can be applied to `a` directly.
    if offset == 0 and a.shape.rank is not None:
        ndims = len(a.shape)
        axis1_is_penultimate = axis1 == -2 or axis1 == ndims - 2
        axis2_is_last = axis2 == -1 or axis2 == ndims - 1
        if axis1_is_penultimate and axis2_is_last:
            return math_ops.trace(a)

    # General path: extract the requested diagonal, then sum its last axis.
    diag = np_array_ops.diagonal(a, offset, axis1, axis2)
    return np_array_ops.sum(diag, -1, dtype)
 def test_trace(self):
   self._skip_if_tests_to_skip_contains("trace")
   # With and without placeholders, for every shape and dtype,
   # operator.trace() must agree with math_ops.trace of the matrix.
   for use_placeholder in (False, True):
     for shape in self._shapes_to_test:
       for dtype in self._dtypes_to_test:
         with self.test_session(graph=ops.Graph()) as sess:
           sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
           linop, matrix, feed_dict = self._operator_and_mat_and_feed_dict(
               shape, dtype, use_placeholder=use_placeholder)
           from_operator = linop.trace()
           from_matrix = math_ops.trace(matrix)
           if not use_placeholder:
             # Static shapes are only compared without placeholders.
             self.assertAllEqual(from_operator.get_shape(),
                                 from_matrix.get_shape())
           actual, expected = sess.run([from_operator, from_matrix],
                                       feed_dict=feed_dict)
           self.assertAC(actual, expected)
Example #19
0
 def test_trace(self):
   self._skip_if_tests_to_skip_contains("trace")
   # With and without placeholders, for every shape and dtype,
   # operator.trace() must agree with math_ops.trace of the matrix.
   for use_placeholder in (False, True):
     for shape in self._shapes_to_test:
       for dtype in self._dtypes_to_test:
         with self.test_session(graph=ops.Graph()) as sess:
           sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
           linop, matrix, feed_dict = self._operator_and_mat_and_feed_dict(
               shape, dtype, use_placeholder=use_placeholder)
           from_operator = linop.trace()
           from_matrix = math_ops.trace(matrix)
           if not use_placeholder:
             # Static shapes are only compared without placeholders.
             self.assertAllEqual(from_operator.get_shape(),
                                 from_matrix.get_shape())
           actual, expected = sess.run([from_operator, from_matrix],
                                       feed_dict=feed_dict)
           self.assertAC(actual, expected)
Example #20
0
def einsum(equation, *inputs, **kwargs):
    """Evaluates an einsum `equation` through `einsum_lib.einsum_cu_tensor`.

    Args:
      equation: The einsum equation string.
      *inputs: The input tensors. Only the first two are passed on to
        `einsum_cu_tensor`; a lone input is paired with a dummy `[0]` constant.
      **kwargs: Only `name` (an optional name for the op scope) is accepted.

    Returns:
      The resulting tensor. A single palindromic input with one label repeated
      twice and no explicit output (e.g. 'ii') short-circuits to
      `math_ops.trace`; an index summed over more than two inputs falls back
      to `special_math_ops._exponential_space_einsum`.

    Raises:
      TypeError: If a keyword argument other than `name` is given.
      ValueError: If a label is repeated within one input and the trace
        shortcut does not apply.
    """
    name = kwargs.pop('name', None)

    if kwargs:
        unexpected = ', '.join(format(key) for key in sorted(kwargs))
        raise TypeError(
            'invalid keyword arguments for this function: ' + unexpected)

    with ops.name_scope(name, 'einsum', [equation, inputs]):
        inputs = list(inputs)

        shapes = [tensor.get_shape() for tensor in inputs]
        parsed = special_math_ops._einsum_parse_and_resolve_equation(
            equation, shapes)
        input_axis_labels, output_axis_labels = parsed

        all_labels = set(''.join(input_axis_labels) + output_axis_labels)

        # The trace shortcut is only considered for a single input written
        # without an explicit '->' output.
        single_implicit_input = (len(input_axis_labels) == 1
                                 and '->' not in equation)
        for label in all_labels:
            for operand_labels in input_axis_labels:
                occurrences = operand_labels.count(label)
                if (single_implicit_input and occurrences == 2
                        and operand_labels == operand_labels[::-1]):
                    return math_ops.trace(inputs[0])
                if occurrences > 1:
                    raise ValueError(
                        'Subscript not supported: an axis appears more than once: %s'
                        % operand_labels)

        for label in all_labels:
            operands_using_label = sum(
                1 for operand_labels in input_axis_labels
                if label in operand_labels)

            if operands_using_label > 2 and label not in output_axis_labels:
                tf.logging.warn(
                    'Falling back to exponential-space implementation of einsum()'
                    ' because index "%s" is summed over more than two inputs.',
                    label)
                return special_math_ops._exponential_space_einsum(
                    equation, *inputs)

        # Rebuild the equation in fully explicit "inputs->output" form.
        equation = ','.join(input_axis_labels) + '->' + output_axis_labels
        if len(inputs) == 1:
            # The kernel takes two inputs, so pad with a dummy constant.
            inputs.append(tf.constant([0], dtype=inputs[0].dtype))
        first, second = inputs[0], inputs[1]
        return einsum_lib.einsum_cu_tensor(input_0=first,
                                           input_1=second,
                                           equation=equation)
def trace_sqrt_product(sigma, sigma_v):
  """Computes trace(sqrt(sigma sigma_v)) for two covariance matrices.

  The product sigma sigma_v of two symmetric matrices is not necessarily
  symmetric, so `_symmetric_matrix_square_root` cannot be applied to it
  directly. Instead, write A = sqrt(sigma) (so sigma = A A) and
  sigma_v = B B, and use:

  (i) eigenvalues(M1 M2) = eigenvalues(M2 M1) for all M1, M2, hence
      eigenvalues(A A B B) = eigenvalues(A B B A);
  (ii) if M1 = sqrt(M2) then eigenvalues(M1) = sqrt(eigenvalues(M2)), hence
      eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A));
  (iii) trace(M) = sum(eigenvalues(M)) for all M, hence
      trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v)))
                                 = sum(sqrt(eigenvalues(A B B A)))
                                 = sum(eigenvalues(sqrt(A B B A)))
                                 = trace(sqrt(A B B A))
                                 = trace(sqrt(A sigma_v A)).

  Both sigma and A sigma_v A are symmetric, so both square roots **can** be
  taken with `_symmetric_matrix_square_root`.

  Args:
    sigma: a square, symmetric, real, positive semi-definite covariance matrix
    sigma_v: same as sigma

  Returns:
    The trace of the positive square root of sigma*sigma_v
  """
  # "A" in the derivation above.
  root_sigma = _symmetric_matrix_square_root(sigma)

  # A sigma_v A; its square root is only taken in the next step.
  a_sigma_v_a = math_ops.matmul(
      root_sigma, math_ops.matmul(sigma_v, root_sigma))

  root_of_product = _symmetric_matrix_square_root(a_sigma_v_a)
  return math_ops.trace(root_of_product)
def trace_sqrt_product(sigma, sigma_v):
    """Computes trace(sqrt(sigma sigma_v)) for two covariance matrices.

    The product sigma sigma_v of two symmetric matrices is not necessarily
    symmetric, so `_symmetric_matrix_square_root` cannot be applied to it
    directly. Instead, write A = sqrt(sigma) (so sigma = A A) and
    sigma_v = B B, and use:

    (i) eigenvalues(M1 M2) = eigenvalues(M2 M1) for all M1, M2, hence
        eigenvalues(A A B B) = eigenvalues(A B B A);
    (ii) if M1 = sqrt(M2) then eigenvalues(M1) = sqrt(eigenvalues(M2)), hence
        eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A));
    (iii) trace(M) = sum(eigenvalues(M)) for all M, hence
        trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v)))
                                   = sum(sqrt(eigenvalues(A B B A)))
                                   = sum(eigenvalues(sqrt(A B B A)))
                                   = trace(sqrt(A B B A))
                                   = trace(sqrt(A sigma_v A)).

    Both sigma and A sigma_v A are symmetric, so both square roots **can** be
    taken with `_symmetric_matrix_square_root`.

    Args:
      sigma: a square, symmetric, real, positive semi-definite covariance
        matrix
      sigma_v: same as sigma

    Returns:
      The trace of the positive square root of sigma*sigma_v
    """
    # "A" in the derivation above.
    root_sigma = _symmetric_matrix_square_root(sigma)

    # A sigma_v A; its square root is only taken in the next step.
    a_sigma_v_a = math_ops.matmul(root_sigma,
                                  math_ops.matmul(sigma_v, root_sigma))

    root_of_product = _symmetric_matrix_square_root(a_sigma_v_a)
    return math_ops.trace(root_of_product)
def frechet_classifier_distance_from_activations(real_activations,
                                                 generated_activations):
    """Computes the Frechet distance between two sets of activations.

    With means m, m_w and sample covariances sigma, sigma_w of the real and
    generated activations, the result is

      |m - m_w|^2 + Tr(sigma + sigma_w) - 2 * trace_sqrt_product(sigma, sigma_w)

    evaluated in float64.

    Args:
      real_activations: Rank-2 tensor of real activations; statistics are
        taken over axis 0.
      generated_activations: Rank-2 tensor of generated activations; statistics
        are taken over axis 0.

    Returns:
      The distance as a scalar tensor, cast back to the dtype of
      `real_activations`.
    """
    real_activations.shape.assert_has_rank(2)
    generated_activations.shape.assert_has_rank(2)

    # Work in double precision and cast the result back at the very end.
    activations_dtype = real_activations.dtype
    needs_cast = activations_dtype != dtypes.float64
    if needs_cast:
        real_activations = math_ops.to_double(real_activations)
        generated_activations = math_ops.to_double(generated_activations)

    # Per-column means and per-set example counts.
    m = math_ops.reduce_mean(real_activations, 0)
    m_w = math_ops.reduce_mean(generated_activations, 0)
    real_count = math_ops.to_double(array_ops.shape(real_activations)[0])
    generated_count = math_ops.to_double(
        array_ops.shape(generated_activations)[0])

    # Sample covariances: (X - mean)^T (X - mean) / (n - 1), each set with
    # its own n.
    real_residual = real_activations - m
    real_scatter = math_ops.matmul(
        real_residual, real_residual, transpose_a=True)
    sigma = real_scatter / (real_count - 1)

    generated_residual = generated_activations - m_w
    generated_scatter = math_ops.matmul(
        generated_residual, generated_residual, transpose_a=True)
    sigma_w = generated_scatter / (generated_count - 1)

    sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)

    covariance_term = (math_ops.trace(sigma + sigma_w) -
                       2.0 * sqrt_trace_component)
    # Squared Euclidean distance between the two mean vectors.
    mean_term = math_ops.reduce_sum(math_ops.squared_difference(m, m_w))

    fid = covariance_term + mean_term
    if needs_cast:
        fid = math_ops.cast(fid, activations_dtype)

    return fid
Example #24
0
  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions. Must be a SparseTensor despite the None default.
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected.

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. Add the regularization term below to yield the loss.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
      sum_weights: The sum of the weights corresponding to sp_input. This
        can be used with unregularized loss to calculate the root weighted
        squared error.
    """
    # sp_input defaults to None in the signature but is required in practice.
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    # Select the factors being solved for ("left") and the cached factors,
    # weights and gramian of the other side that are held fixed.
    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      total_rows = self._input_rows
      total_cols = self._input_cols
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      total_rows = self._input_cols
      total_cols = self._input_rows
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      # The column case swaps the roles of rows and columns above, so the
      # transpose flag is flipped to match.
      transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full input
    # Here we reindex the rows and give them contiguous ids starting at 0.
    # We use tf.unique to achieve this reindexing. Note that this is done so
    # that the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    # update_indices: original ids of the factors being solved for.
    # gather_indices: original ids of the fixed factors to look up.
    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    # Rebuild the sparse input on the contiguous (re-indexed) ids.
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
      total_lhs += self._regularization_matrix
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule.
      total_rhs = (
          self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
              new_sp_input, right, adjoint_a=transpose_input))
      # TODO (rmlarsen): handle transposing in tf.matrix_solve instead of id:1138
      # https://github.com/imdone/tensorflow/issues/1139
      # transposing explicitly.
      # TODO (rmlarsen): multi-thread tf.matrix_solve. id:1249
      # https://github.com/imdone/tensorflow/issues/1250
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO (yifanchen): Add special handling for single shard without using id:845
        # https://github.com/imdone/tensorflow/issues/846
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        # Caller-supplied weights: a scalar is broadcast to one weight per
        # updated index, a rank-1 tensor is used as given (cast to float32).
        num_indices = array_ops.shape(update_indices)[0]
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights), 0),
              lambda: (array_ops.ones([num_indices]) * row_weights),
              lambda: math_ops.cast(row_weights, dtypes.float32))

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      partial_lhs, total_rhs = (
          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
              right,
              col_weights,
              self._unobserved_weight,
              row_weights_slice,
              new_sp_input.indices,
              new_sp_input.values,
              num_rows,
              transpose_input,
              name="wals_compute_partial_lhs_rhs"))
      # The shared lhs gets a leading axis so it broadcasts against
      # partial_lhs; the rhs gets a trailing axis for matrix_solve and the
      # solution drops it again.
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <\\(u_i, v_j\\)> for (i, j) in loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    # A scalar 0 stands in for a weight matrix the model does not have, which
    # zeroes the weighted loss term below.
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    # Scales the minibatch quantities by total_rows / (rows in this batch).
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    if self._regularization is not None:
      regularization = self._regularization * (
          math_ops.trace(partial_row_gramian) * normalization_factor +
          math_ops.trace(gramian))
    else:
      regularization = constant_op.constant(0.)

    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
      # Add the row_weight * col_weight products at the entries present in
      # the input, using a sparse tensor of ones to pick those entries.
      ones = sparse_tensor.SparseTensor(
          indices=loss_sp_input.indices,
          values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
          dense_shape=loss_sp_input.dense_shape)
      sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
          ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)
Example #25
0
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
    """Classifier distance for evaluating a conditional generative model.

    This is based on the Frechet Inception distance, but for an arbitrary
    classifier.

    This technique is described in detail in https://arxiv.org/abs/1706.08500.
    Given two Gaussian distributions with means m and m_w and covariance
    matrices C and C_w, this function calculates

      |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

    which captures how different the distributions of real images and
    generated images (or more accurately, their visual features) are. Note
    that unlike the Inception score, this is a true distance and utilizes
    information about real world images.

    Args:
      real_images: Real images to use to compute Frechet Inception distance.
      generated_images: Generated images to use to compute Frechet Inception
        distance.
      classifier_fn: A function that takes images and produces activations
        based on a classifier.
      num_batches: Number of batches to split images in to in order to
        efficiently run them through the classifier network.

    Returns:
      The Frechet Inception distance. A floating-point scalar.
    """
    real_images_list = array_ops.split(real_images,
                                       num_or_size_splits=num_batches)
    generated_images_list = array_ops.split(generated_images,
                                            num_or_size_splits=num_batches)

    # Real batches come first, generated batches second; the split below
    # relies on this ordering.
    imgs = array_ops.stack(real_images_list + generated_images_list)

    # Compute the activations using the memory-efficient `map_fn`.
    activations = functional_ops.map_fn(fn=classifier_fn,
                                        elems=imgs,
                                        parallel_iterations=1,
                                        back_prop=False,
                                        swap_memory=True,
                                        name='RunClassifier')

    # Split the activations by the real and generated images.
    real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)

    # Ensure the activations have the right shapes.
    real_a = array_ops.concat(array_ops.unstack(real_a), 0)
    gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
    real_a.shape.assert_has_rank(2)
    gen_a.shape.assert_has_rank(2)

    # Compute mean and covariance matrices of activations.
    m = math_ops.reduce_mean(real_a, 0)
    m_v = math_ops.reduce_mean(gen_a, 0)
    # Both image sets were stacked into one tensor above, so they contain the
    # same number of examples and a single count serves both covariances.
    num_examples = math_ops.to_float(array_ops.shape(real_a)[0])

    # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
    # `transpose_a=True` yields the feature-by-feature covariance; the
    # previous `transpose_b=True` produced an example-by-example Gram matrix.
    real_centered = real_a - m
    sigma = math_ops.matmul(
        real_centered, real_centered, transpose_a=True) / (num_examples - 1)

    # The generated activations must be centered with their own mean `m_v`
    # (the previous code subtracted the real mean `m` here).
    gen_centered = gen_a - m_v
    sigma_v = math_ops.matmul(
        gen_centered, gen_centered, transpose_a=True) / (num_examples - 1)

    # Take matrix square root of the product of covariance matrices.
    sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v))

    # Compute the two components of FID.
    trace = math_ops.trace(sigma + sigma_v - 2.0 * sqcc)
    mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
    fid = trace + mean

    return fid
def frechet_classifier_distance_from_activations(
    real_activations, generated_activations):
  """Classifier distance for evaluating a generative model from activations.

  This is based on the Frechet Inception distance, but for an arbitrary
  classifier.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance
  matrices C and C_w, this function calculates

    |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Note that when computed using sample means and sample covariance matrices,
  Frechet distance is biased. It is more biased for small sample sizes. (e.g.
  even if the two distributions are the same, for a small sample size, the
  expected Frechet distance is large). It is important to use the same
  sample size to compute frechet classifier distance when comparing two
  generative models.

  Args:
    real_activations: Rank-2 Tensor of activations of real data; axis 0
      indexes examples.
    generated_activations: Rank-2 Tensor of activations of generated data;
      axis 0 indexes examples.

  Returns:
    The Frechet Inception distance. A floating-point scalar of the same type
    as `real_activations`.
  """
  real_activations.shape.assert_has_rank(2)
  generated_activations.shape.assert_has_rank(2)

  # All intermediate math runs in float64; the result is cast back at the end.
  # Each input is cast on its own dtype so mixed-precision inputs also work.
  activations_dtype = real_activations.dtype
  if activations_dtype != dtypes.float64:
    real_activations = math_ops.to_double(real_activations)
  if generated_activations.dtype != dtypes.float64:
    generated_activations = math_ops.to_double(generated_activations)

  # Compute mean and covariance matrices of activations.
  m = math_ops.reduce_mean(real_activations, 0)
  m_v = math_ops.reduce_mean(generated_activations, 0)
  # Each covariance is normalized by its own sample count, so the result is
  # still correct when the two sets contain different numbers of examples.
  num_examples_real = math_ops.to_double(
      array_ops.shape(real_activations)[0])
  num_examples_generated = math_ops.to_double(
      array_ops.shape(generated_activations)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
  real_centered = real_activations - m
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / (
          num_examples_real - 1)

  gen_centered = generated_activations - m_v
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / (
          num_examples_generated - 1)

  # Find the Tr(sqrt(sigma sigma_v)) component of FID
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)

  # Compute the two components of FID.

  # First the covariance component.
  # Here, note that trace(A + B) = trace(A) + trace(B)
  trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component

  # Next the distance between means.
  mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
  fid = trace + mean
  if activations_dtype != dtypes.float64:
    fid = math_ops.cast(fid, activations_dtype)

  return fid
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
    """Frechet distance between classifier activations of two image sets.

    Generalizes the Frechet Inception distance
    (https://arxiv.org/abs/1706.08500) to an arbitrary classifier. With
    means m, m_w and covariance matrices C, C_w of the two activation sets,
    the value computed is

      |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

    The estimate is built from sample means and sample covariances and is
    therefore biased, more so for small samples. Use the same sample size
    whenever the distances of two generative models are compared.

    Args:
      real_images: Real images to use to compute Frechet Inception distance.
      generated_images: Generated images to use to compute Frechet Inception
        distance.
      classifier_fn: A function that takes images and produces activations
        based on a classifier.
      num_batches: Number of batches to split images in to in order to
        efficiently run them through the classifier network.

    Returns:
      The Frechet Inception distance. A floating-point scalar.
    """
    # Chunk both image sets identically and stack all chunks (real ones
    # first) so one `map_fn` call feeds everything through the classifier.
    real_chunks = array_ops.split(real_images, num_or_size_splits=num_batches)
    fake_chunks = array_ops.split(generated_images,
                                  num_or_size_splits=num_batches)
    stacked_chunks = array_ops.stack(real_chunks + fake_chunks)

    # `map_fn` with a single parallel iteration keeps memory usage low.
    classifier_out = functional_ops.map_fn(
        fn=classifier_fn,
        elems=stacked_chunks,
        parallel_iterations=1,
        back_prop=False,
        swap_memory=True,
        name='RunClassifier')

    # The first `num_batches` entries are real, the remaining ones generated.
    real_out, fake_out = array_ops.split(
        classifier_out, [num_batches, num_batches], 0)

    # Merge the per-chunk activations back into one rank-2 tensor per set.
    real_acts = array_ops.concat(array_ops.unstack(real_out), 0)
    fake_acts = array_ops.concat(array_ops.unstack(fake_out), 0)
    real_acts.shape.assert_has_rank(2)
    fake_acts.shape.assert_has_rank(2)

    real_mean = math_ops.reduce_mean(real_acts, 0)
    fake_mean = math_ops.reduce_mean(fake_acts, 0)
    sample_count = math_ops.to_float(array_ops.shape(real_acts)[0])

    def _sample_covariance(acts, mean):
        # (1 / (n - 1)) * (X - mu)^T (X - mu)
        return math_ops.matmul(
            acts - mean, acts - mean, transpose_a=True) / (sample_count - 1)

    real_cov = _sample_covariance(real_acts, real_mean)
    fake_cov = _sample_covariance(fake_acts, fake_mean)

    # Tr(sqrt(sigma sigma_v)) component of FID.
    sqrt_term = trace_sqrt_product(real_cov, fake_cov)

    # Covariance part, using trace(A + B) = trace(A) + trace(B).
    covariance_term = math_ops.trace(real_cov + fake_cov) - 2.0 * sqrt_term

    # Squared L2 distance between the two means.
    mean_term = math_ops.square(linalg_ops.norm(real_mean - fake_mean))
    return covariance_term + mean_term
Example #28
0
def _einsum_v1(equation, *inputs, **kwargs):
    """Legacy einsum implementation that does not use EinsumOp.

    Args:
      equation: a `str` describing the contraction.
      *inputs: the tensors to contract.
      **kwargs: only `name` (optional name for the operation) is accepted.

    Returns:
      The contracted tensor.

    Raises:
      TypeError: if a keyword argument other than `name` is passed.
      ValueError: if one input repeats an axis label (other than the
        supported trace form), or if the labels left after contraction do
        not match the output labels.
    """
    name = kwargs.pop('name', None)
    if kwargs:
        unexpected = ', '.join(format(key) for key in sorted(kwargs))
        raise TypeError(
            'invalid keyword arguments for this function: ' + unexpected)

    with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
        operands = list(inputs)
        operand_shapes = [operand.shape for operand in operands]
        input_axis_labels, output_axis_labels = (
            _einsum_v1_parse_and_resolve_equation(equation, operand_shapes))

        axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

        # A single palindromic input with implicit output (e.g. `ijji`) is
        # handled as a trace; any other repeated label within one input is
        # rejected.
        single_input = len(input_axis_labels) == 1
        implicit_output = '->' not in equation
        for label in axis_labels:
            for subscripts in input_axis_labels:
                occurrences = subscripts.count(label)
                if (single_input and occurrences == 2
                        and subscripts == subscripts[::-1]
                        and implicit_output):
                    return math_ops.trace(operands[0])
                if occurrences > 1:
                    raise ValueError(
                        'Subscript not supported: an axis appears more than once: %s'
                        % subscripts)

        # The pairwise reduction below cannot sum a label shared by three or
        # more inputs, so hand such equations to the fallback implementation.
        for label in axis_labels:
            if label in output_axis_labels:
                continue
            inputs_with_label = sum(
                1 for subscripts in input_axis_labels if label in subscripts)
            if inputs_with_label > 2:
                logging.warn(
                    'Falling back to exponential-space implementation of einsum()'
                    ' because index "%s" is summed over more than two inputs.',
                    label)
                return _exponential_space_einsum_v1(equation, *operands)

        # Use xla_einsum if executing on TPU and if the operation is a 2 input
        # einsum supported by XlaEinsumOp.
        if _enclosing_tpu_context() is not None and len(operands) == 2:
            xla_equation = (input_axis_labels[0] + ',' +
                            input_axis_labels[1] + '->' + output_axis_labels)
            return gen_xla_ops.xla_einsum(
                operands[0], operands[1], xla_equation)

        # Fold the operands together left to right, summing each label shared
        # by the running result and the next operand unless the output keeps it.
        result = operands[0]
        result_labels = input_axis_labels[0]
        for idx in xrange(1, len(operands)):
            next_labels = input_axis_labels[idx]
            axes_to_sum = (
                set(result_labels)
                & (set(next_labels) - set(output_axis_labels)))
            result, result_labels = _einsum_v1_reduction(
                result, result_labels, operands[idx], next_labels,
                axes_to_sum)

        # Sum out whatever labels remain that the output does not mention.
        leftover_axes = [
            position for position, label in enumerate(result_labels)
            if label not in output_axis_labels
        ]
        if leftover_axes:
            result = math_ops.reduce_sum(result, axis=leftover_axes)
            result_labels = ''.join(
                label for label in result_labels
                if label in output_axis_labels)

        if sorted(result_labels) != sorted(output_axis_labels):
            raise ValueError('Invalid equation: %s' % equation)

        # Reorder the axes into the order requested by the output subscripts.
        perm = [result_labels.index(label) for label in output_axis_labels]
        return _transpose_if_necessary(result, perm)
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
  """Computes a Frechet distance from the activations of `classifier_fn`.

  This follows the Frechet Inception distance of
  https://arxiv.org/abs/1706.08500 but lets the caller pick the classifier.
  For activation means m, m_w and covariance matrices C, C_w the result is

    |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  Because sample means and sample covariances are plugged in, the estimate
  is biased, and the bias grows as the sample shrinks. Compare generative
  models only on distances computed from the same number of samples.

  Args:
    real_images: Real images to use to compute Frechet Inception distance.
    generated_images: Generated images to use to compute Frechet Inception
      distance.
    classifier_fn: A function that takes images and produces activations
      based on a classifier.
    num_batches: Number of batches to split images in to in order to
      efficiently run them through the classifier network.

  Returns:
    The Frechet Inception distance. A floating-point scalar.
  """

  def _merge_batches(batched):
    # Concatenates the leading (batch-index) axis away.
    return array_ops.concat(array_ops.unstack(batched), 0)

  # Real batches are placed ahead of generated batches in the stacked input.
  batched_real = array_ops.split(real_images, num_or_size_splits=num_batches)
  batched_gen = array_ops.split(
      generated_images, num_or_size_splits=num_batches)
  classifier_inputs = array_ops.stack(batched_real + batched_gen)

  # Run the classifier one batch at a time via the memory-efficient `map_fn`.
  outputs = functional_ops.map_fn(
      fn=classifier_fn,
      elems=classifier_inputs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')

  # Undo the stacking: first half real, second half generated.
  real_part, gen_part = array_ops.split(outputs, [num_batches] * 2, 0)
  real_feats = _merge_batches(real_part)
  gen_feats = _merge_batches(gen_part)
  for feats in (real_feats, gen_feats):
    feats.shape.assert_has_rank(2)

  # First and second moments of each activation set.
  mu_real = math_ops.reduce_mean(real_feats, 0)
  mu_gen = math_ops.reduce_mean(gen_feats, 0)
  n = math_ops.to_float(array_ops.shape(real_feats)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
  real_scatter = math_ops.matmul(
      real_feats - mu_real, real_feats - mu_real, transpose_a=True)
  cov_real = real_scatter / (n - 1)

  gen_scatter = math_ops.matmul(
      gen_feats - mu_gen, gen_feats - mu_gen, transpose_a=True)
  cov_gen = gen_scatter / (n - 1)

  # Tr(sqrt(sigma sigma_v)) component of FID.
  trace_of_sqrt = trace_sqrt_product(cov_real, cov_gen)

  # Covariance component; trace(A + B) = trace(A) + trace(B).
  cov_component = math_ops.trace(cov_real + cov_gen) - 2.0 * trace_of_sqrt

  # Mean component: squared L2 norm of the difference of means.
  mean_component = math_ops.square(linalg_ops.norm(mu_real - mu_gen))

  return cov_component + mean_component
 def compare(self, x):
     """Asserts that `math_ops.trace(x)` is close to NumPy's trace of `x`.

     The NumPy reference traces over the last two axes of `x`.
     """
     expected = np.trace(x, axis1=-2, axis2=-1)
     with self.test_session(use_gpu=True):
         trace_tensor = math_ops.trace(x)
         actual = trace_tensor.eval()
     self.assertAllClose(actual, expected)
def frechet_classifier_distance_from_activations(
    real_activations, generated_activations):
  """Classifier distance for evaluating a generative model from activations.

  This method computes the Frechet classifier distance from activations of
  real images and generated images. This can be used independently of the
  frechet_classifier_distance() method, especially in the case of using large
  batches during evaluation where we would like to precompute all of the
  activations before computing the classifier distance.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance
  matrices C and C_w, this function calculates

    |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Args:
    real_activations: 2D Tensor containing activations of real data. Shape is
      [batch_size, activation_size].
    generated_activations: 2D Tensor containing activations of generated data.
      Shape is [batch_size, activation_size]. Its batch size may differ from
      that of `real_activations`.

  Returns:
    The Frechet Inception distance. A floating-point scalar of the same type
    as `real_activations`.
  """
  real_activations.shape.assert_has_rank(2)
  generated_activations.shape.assert_has_rank(2)

  # All intermediate math runs in float64; the result is cast back at the end.
  # Each input is cast on its own dtype so mixed-precision inputs also work.
  activations_dtype = real_activations.dtype
  if activations_dtype != dtypes.float64:
    real_activations = math_ops.to_double(real_activations)
  if generated_activations.dtype != dtypes.float64:
    generated_activations = math_ops.to_double(generated_activations)

  # Compute mean and covariance matrices of activations.
  m = math_ops.reduce_mean(real_activations, 0)
  m_v = math_ops.reduce_mean(generated_activations, 0)
  # Normalize each covariance by the size of its own sample; using the real
  # count for both is wrong when the two batch sizes differ.
  num_examples_real = math_ops.to_double(
      array_ops.shape(real_activations)[0])
  num_examples_generated = math_ops.to_double(
      array_ops.shape(generated_activations)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
  real_centered = real_activations - m
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / (
          num_examples_real - 1)

  gen_centered = generated_activations - m_v
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / (
          num_examples_generated - 1)

  # Find the Tr(sqrt(sigma sigma_v)) component of FID
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)

  # Compute the two components of FID.

  # First the covariance component.
  # Here, note that trace(A + B) = trace(A) + trace(B)
  trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component

  # Next the distance between means.
  mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
  fid = trace + mean
  if activations_dtype != dtypes.float64:
    fid = math_ops.cast(fid, activations_dtype)

  return fid
Example #32
0
def trace_sqrt_product(cov1, cov2):
    """Returns the trace of `sym_matrix_sqrt(S * cov2 * S)`.

    Here `S = sym_matrix_sqrt(cov1)`. Callers in this file use the result as
    the Tr(sqrt(sigma sigma_v)) term of the Frechet distance.

    Args:
      cov1: First matrix; presumably a symmetric covariance matrix -- verify
        against `sym_matrix_sqrt`.
      cov2: Second matrix, with shape compatible with `cov1` for `matmul`.

    Returns:
      The trace of the square root of the sandwiched product.
    """
    root_cov1 = sym_matrix_sqrt(cov1)
    # Build S * (cov2 * S) with the inner product evaluated first.
    inner_product = math_ops.matmul(cov2, root_cov1)
    sandwiched = math_ops.matmul(root_cov1, inner_product)
    root_sandwiched = sym_matrix_sqrt(sandwiched)
    return math_ops.trace(root_sandwiched)
Example #33
0
  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Solves the normal equations for the factors touched by `sp_input`, builds
    the op that writes the solution back into the factor variables, and builds
    the normalized loss terms for the slice.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions. Must be a `SparseTensor` (enforced by an assert).
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of rows
        or columns to be updated/projected.

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the row/column
        factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. Add the regularization term below to yield the loss.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
      sum_weights: The sum of the weights corresponding to sp_input. This
        can be used with unregularized loss to calculate the root weighted
        squared error.
    """
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    # Pick the factors being solved for (`left`) and the cached state of the
    # other side. For a column update the roles of rows and columns are
    # swapped and the transpose flag is flipped accordingly.
    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      total_rows = self._input_rows
      total_cols = self._input_cols
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      total_rows = self._input_cols
      total_cols = self._input_rows
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full input
    # Here we reindex the rows and give them contiguous ids starting at 0.
    # We use tf.unique to achieve this reindexing. Note that this is done so
    # that the downstream kernel can assume that the input is "dense" along the
    # row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    # `update_indices` are the original ids of the factors being solved for;
    # `gather_indices` are the ids of the fixed factors looked up from the
    # cache. Which input axis plays which role depends on `transpose_input`.
    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    # Rebuild the sparse input on the compacted (contiguous) index space.
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
      total_lhs += self._regularization_matrix
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule.
      total_rhs = (
          self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
              new_sp_input, right, adjoint_a=transpose_input))
      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
      # transposing explicitly.
      # TODO(rmlarsen): multi-thread tf.matrix_solve.
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO(yifanchen): Add special handling for single shard without using
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        # Caller-supplied weights: a scalar is broadcast to one weight per
        # updated index; a rank-1 tensor is used as given (cast to float32).
        num_indices = array_ops.shape(update_indices)[0]
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights), 0),
              lambda: (array_ops.ones([num_indices]) * row_weights),
              lambda: math_ops.cast(row_weights, dtypes.float32))

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      partial_lhs, total_rhs = (
          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
              right,
              col_weights,
              self._unobserved_weight,
              row_weights_slice,
              new_sp_input.indices,
              new_sp_input.values,
              num_rows,
              transpose_input,
              name="wals_compute_partial_lhs_rhs"))
      # Add the shared lhs to every per-row partial lhs, then solve all the
      # systems in one batched `matrix_solve` call.
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    # NOTE(review): `row_weights_slice` and `col_weights` are only assigned in
    # the weighted branch above (`self._row_weights is not None`). This looks
    # like it relies on `_col_weights` being None whenever `_row_weights` is
    # None -- confirm against the constructor.
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    # Ratio of the total number of rows to the number of rows in this slice.
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    if self._regularization is not None:
      # Only the slice's own gramian is rescaled; `gramian` is taken from the
      # cache as is.
      regularization = self._regularization * (
          math_ops.trace(partial_row_gramian) * normalization_factor +
          math_ops.trace(gramian))
    else:
      regularization = constant_op.constant(0.)

    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
      # Add the weights of the entries present in the input: a sparse tensor
      # of ones scaled by the outer product of row and column weights.
      ones = sparse_tensor.SparseTensor(
          indices=loss_sp_input.indices,
          values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
          dense_shape=loss_sp_input.dense_shape)
      sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
          ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)
Example #34
0
 def compare(self, x):
   """Checks `math_ops.trace(x)` against `np.trace` over the last two axes."""
   reference = np.trace(x, axis1=-2, axis2=-1)
   with self.test_session(use_gpu=True):
     computed = math_ops.trace(x).eval()
   self.assertAllClose(computed, reference)
Example #35
0
def einsum(equation, *inputs, **kwargs):
  """Contracts tensors according to an Einstein-summation style `equation`.

  Each element of the result is defined by `equation`, a compact notation
  modelled on the Einstein summation convention. Take the product of two
  matrices A and B, giving a matrix C whose elements are

  ```
    C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The `equation` describing it is

  ```
    ij,jk->ik
  ```

  More generally, start from the element-wise formula and
    1. strip variable names, brackets, and commas,
    2. turn every "*" into ",",
    3. omit the summation signs, and
    4. put the output on the right-hand side, writing "->" instead of "=".

  A few familiar operations written this way:

  ```python
  # Matrix multiplication
  >>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

  # Dot product
  >>> einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

  # Outer product
  >>> einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

  # Transpose
  >>> einsum('ij->ji', m)  # output[j,i] = m[i,j]

  # Trace
  >>> einsum('ii', m)  # output = trace(m) = sum_i m[i, i]

  # Batch matrix multiplication
  >>> einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  ```

  An ellipsis enables and controls broadcasting. Batch matrix
  multiplication, for instance, can be written as:

  ```python
  >>> einsum('...ij,...jk->...ik', u, v)
  ```

  The behavior follows `numpy.einsum`, with one unsupported case:

  * Subscripts where an axis appears more than once for a single input
    (e.g. `ijj,k->ik`) unless it is a trace (e.g. `ijji`).

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    name: A name for the operation (optional).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    TypeError: If a keyword argument other than `name` is given.
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs implied by `equation` does not match `len(inputs)`,
      - an axis appears in the output subscripts but not in any of the inputs,
      - the number of dimensions of an input differs from the number of
        indices in its subscript, or
      - the input shapes are inconsistent along a particular axis.
  """
  name = kwargs.pop('name', None)
  if kwargs:
    bad_keys = [format(key) for key in sorted(kwargs)]
    raise TypeError(
        'invalid keyword arguments for this function: ' + ', '.join(bad_keys))

  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    tensors = list(inputs)
    tensor_shapes = [tensor.get_shape() for tensor in tensors]
    input_axis_labels, output_axis_labels = _einsum_parse_and_resolve_equation(
        equation, tensor_shapes)

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    # Repeated labels within one input are only supported for the trace form:
    # a single palindromic subscript with implicit output (e.g. `ijji`).
    trace_candidate = len(input_axis_labels) == 1 and '->' not in equation
    for label in axis_labels:
      for subscripts in input_axis_labels:
        repeats = subscripts.count(label)
        if (trace_candidate and repeats == 2 and
            subscripts == subscripts[::-1]):
          return math_ops.trace(tensors[0])
        if repeats > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              subscripts)

    # A summed label shared by more than two inputs cannot be handled by the
    # pairwise reduction below; use the fallback implementation instead.
    for label in axis_labels:
      uses = sum(1 for subscripts in input_axis_labels if label in subscripts)
      if uses > 2 and label not in output_axis_labels:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', label)
        return _exponential_space_einsum(equation, *tensors)

    # Use xla_einsum if executing on TPU and if the operation is a 2 input
    # einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(tensors) == 2:
      first_labels, second_labels = input_axis_labels[0], input_axis_labels[1]
      return gen_xla_ops.xla_einsum(
          tensors[0], tensors[1],
          first_labels + ',' + second_labels + '->' + output_axis_labels)

    # Reduce the inputs pairwise from left to right. At each step the labels
    # common to the accumulated result and the next input are summed, except
    # those that must survive into the output.
    accumulated = tensors[0]
    accumulated_labels = input_axis_labels[0]
    for position in xrange(1, len(tensors)):
      incoming_labels = input_axis_labels[position]
      axes_to_sum = (
          set(accumulated_labels) &
          (set(incoming_labels) - set(output_axis_labels)))
      accumulated, accumulated_labels = _einsum_reduction(
          accumulated, accumulated_labels, tensors[position], incoming_labels,
          axes_to_sum)

    # Any label still present but absent from the output is summed away.
    dropped_axes = [
        index for index, label in enumerate(accumulated_labels)
        if label not in output_axis_labels
    ]
    if dropped_axes:
      accumulated = math_ops.reduce_sum(accumulated, axis=dropped_axes)
      accumulated_labels = ''.join(
          label for label in accumulated_labels if label in output_axis_labels)

    if sorted(accumulated_labels) != sorted(output_axis_labels):
      raise ValueError('Invalid equation: %s' % equation)

    # Permute the remaining axes into the order given by the output labels.
    perm = [accumulated_labels.index(label) for label in output_axis_labels]
    return _transpose_if_necessary(accumulated, perm)
Example #36
0
def einsum(equation, *inputs, **kwargs):
  """Contracts tensors according to an Einstein-summation style `equation`.

  The `equation` is a shorthand, inspired by the Einstein summation
  convention, that defines every element of the returned tensor.  Take the
  product of two matrices A and B, giving a matrix C whose elements are:

  ```
    C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The matching `equation` is:

  ```
    ij,jk->ik
  ```

  More generally, an element-wise formula becomes an `equation` by
    1. removing variable names, brackets, and commas,
    2. replacing "*" with ",",
    3. dropping summation signs, and
    4. moving the output to the right, and replacing "=" with "->".

  A number of familiar operations can be written this way:

  ```python
  # Matrix multiplication
  >>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

  # Dot product
  >>> einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

  # Outer product
  >>> einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

  # Transpose
  >>> einsum('ij->ji', m)  # output[j,i] = m[i,j]

  # Trace
  >>> einsum('ii', m)  # output = trace(m) = sum_i m[i, i]

  # Batch matrix multiplication
  >>> einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  ```

  Broadcasting is enabled and controlled with an ellipsis.  Batch matrix
  multiplication, for instance, can also be spelled:

  ```python
  >>> einsum('...ij,...jk->...ik', u, v)
  ```

  The behavior follows `numpy.einsum`, with this unsupported case:

  * Subscripts where an axis appears more than once for a single input
    (e.g. `ijj,k->ik`) unless it is a trace (e.g. `ijji`).

  When an index that is summed away appears in more than two inputs, a
  warning is logged and the computation is delegated to the exponential-space
  implementation.

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    name: A name for the operation (optional).  Keyword-only.

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    TypeError: If a keyword argument other than `name` is supplied.
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs implied by `equation` does not match `len(inputs)`,
      - an axis appears in the output subscripts but not in any of the inputs,
      - an axis is repeated within one input subscript (other than a trace),
      - the number of dimensions of an input differs from the number of
        indices in its subscript, or
      - the input shapes are inconsistent along a particular axis.
  """
  name = kwargs.pop('name', None)
  if kwargs:
    unexpected = ', '.join(format(key) for key in sorted(kwargs))
    raise TypeError(
        'invalid keyword arguments for this function: ' + unexpected)
  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = [tensor.get_shape() for tensor in inputs]
    input_axis_labels, output_axis_labels = _einsum_parse_and_resolve_equation(
        equation, input_shapes)

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    # A trace is only recognized for a lone input written without '->'.
    trace_allowed = len(input_axis_labels) == 1 and '->' not in equation
    for label in axis_labels:
      for subscript in input_axis_labels:
        occurrences = subscript.count(label)
        if occurrences < 2:
          continue
        # A label repeated exactly twice in a palindromic subscript is
        # treated as a trace; every other repetition is rejected.
        if trace_allowed and occurrences == 2 and subscript == subscript[::-1]:
          return math_ops.trace(inputs[0])
        raise ValueError(
            'Subscript not supported: an axis appears more than once: %s' %
            subscript)

    for label in axis_labels:
      if label in output_axis_labels:
        continue
      holders = sum(1 for subscript in input_axis_labels if label in subscript)
      if holders > 2:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', label)
        return _exponential_space_einsum(equation, *inputs)

    # Fold the inputs together left to right, one pairwise reduction at a time.
    result = inputs[0]
    result_labels = input_axis_labels[0]
    for position, operand in enumerate(inputs[1:], 1):
      operand_labels = input_axis_labels[position]
      # Labels shared by both operands that do not survive into the output.
      shared_summed = set(result_labels) & (
          set(operand_labels) - set(output_axis_labels))
      result, result_labels = _einsum_reduction(
          result, result_labels, operand, operand_labels, shared_summed)

    # Sum away whatever labels remain that the output does not ask for.
    leftover_positions = [
        position for position, label in enumerate(result_labels)
        if label not in output_axis_labels
    ]
    if leftover_positions:
      result = math_ops.reduce_sum(
          result, reduction_indices=leftover_positions)
      result_labels = ''.join(
          label for label in result_labels if label in output_axis_labels)
    if sorted(result_labels) != sorted(output_axis_labels):
      raise ValueError('Invalid equation: %s' % equation)

    # Reorder the remaining axes into the order requested by the output.
    perm = [result_labels.index(label) for label in output_axis_labels]
    return _transpose_if_necessary(result, perm)
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
  """Classifier distance for evaluating a conditional generative model.

  This is based on the Frechet Inception distance, but for an arbitrary
  classifier.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance
  matrices C and C_w, this function calculates

  |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Args:
    real_images: Real images to use to compute Frechet Inception distance.
    generated_images: Generated images to use to compute Frechet Inception
      distance.
    classifier_fn: A function that takes images and produces activations
      based on a classifier.
    num_batches: Number of batches to split images in to in order to
      efficiently run them through the classifier network.

  Returns:
    The Frechet Inception distance. A floating-point scalar.
  """

  real_images_list = array_ops.split(
      real_images, num_or_size_splits=num_batches)
  generated_images_list = array_ops.split(
      generated_images, num_or_size_splits=num_batches)

  imgs = array_ops.stack(real_images_list + generated_images_list)

  # Compute the activations using the memory-efficient `map_fn`.
  activations = functional_ops.map_fn(
      fn=classifier_fn,
      elems=imgs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')

  # Split the activations by the real and generated images.
  real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)

  # Ensure the activations have the right shapes.
  real_a = array_ops.concat(array_ops.unstack(real_a), 0)
  gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
  real_a.shape.assert_has_rank(2)
  gen_a.shape.assert_has_rank(2)

  # Compute mean and covariance matrices of activations. Axis 0 indexes the
  # examples, so the means are taken over it and the covariances are
  # normalized by the example count, not by the activation size.
  m = math_ops.reduce_mean(real_a, 0)
  m_v = math_ops.reduce_mean(gen_a, 0)
  num_real = math_ops.to_float(array_ops.shape(real_a)[0])
  num_gen = math_ops.to_float(array_ops.shape(gen_a)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu). `transpose_a=True` yields the
  # feature-by-feature covariance rather than an example-by-example Gram
  # matrix, and each set is centered on its own mean.
  real_centered = real_a - m
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / (num_real - 1)

  gen_centered = gen_a - m_v
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / (num_gen - 1)

  # Take matrix square root of the product of covariance matrices.
  sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v))

  # Compute the two components of FID.
  trace = math_ops.trace(sigma + sigma_v - 2.0 * sqcc)
  mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
  fid = trace + mean

  return fid
Example #38
0
 def compare(self, x):
     """Checks `math_ops.trace(x)` against NumPy's trace of the last two axes.

     Args:
       x: Array-like input handed unchanged to both implementations.
     """
     # NumPy reference: trace taken over the two innermost axes.
     expected = np.trace(x, axis1=-2, axis2=-1)
     with self.cached_session():
         actual = math_ops.trace(x).eval()
     self.assertAllClose(actual, expected)
def frechet_classifier_distance_from_activations(real_activations,
                                                 generated_activations):
    """Frechet classifier distance computed from precomputed activations.

    Computes the Frechet classifier distance between the activations of real
    images and of generated images. It can be used on its own, without
    frechet_classifier_distance(), which is convenient when evaluation uses
    large batches and all activations are precomputed before the distance is
    taken.

    This technique is described in detail in https://arxiv.org/abs/1706.08500.
    Given two Gaussian distributions with means m and m_w and covariance
    matrices C and C_w, this function calculates

                  |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

    which captures how different the distributions of real images and generated
    images (or more accurately, their visual features) are. Note that unlike the
    Inception score, this is a true distance and utilizes information about real
    world images.

    Note that when computed using sample means and sample covariance matrices,
    Frechet distance is biased. It is more biased for small sample sizes. (e.g.
    even if the two distributions are the same, for a small sample size, the
    expected Frechet distance is large). It is important to use the same
    sample size to compute frechet classifier distance when comparing two
    generative models.

    Args:
      real_activations: 2D Tensor containing activations of real data. Shape is
        [batch_size, activation_size].
      generated_activations: 2D Tensor containing activations of generated data.
        Shape is [batch_size, activation_size].

    Returns:
     The Frechet Inception distance. A floating-point scalar with the dtype of
     `real_activations`.

    """
    real_activations.shape.assert_has_rank(2)
    generated_activations.shape.assert_has_rank(2)

    # The arithmetic below runs in float64; the result is cast back at the end
    # when the incoming activations used a different dtype.
    activations_dtype = real_activations.dtype
    needs_cast = activations_dtype != dtypes.float64
    if needs_cast:
        real_activations = math_ops.to_double(real_activations)
        generated_activations = math_ops.to_double(generated_activations)

    def _sample_covariance(samples, sample_mean, sample_count):
        # (1 / (n - 1)) * (X - mu)^T (X - mu), with one example per row of X.
        centered = samples - sample_mean
        return math_ops.matmul(
            centered, centered, transpose_a=True) / (sample_count - 1)

    # Per-set means and example counts.
    mean_real = math_ops.reduce_mean(real_activations, 0)
    mean_gen = math_ops.reduce_mean(generated_activations, 0)
    count_real = math_ops.to_double(array_ops.shape(real_activations)[0])
    count_gen = math_ops.to_double(array_ops.shape(generated_activations)[0])

    cov_real = _sample_covariance(real_activations, mean_real, count_real)
    cov_gen = _sample_covariance(generated_activations, mean_gen, count_gen)

    # Tr(sqrt(sigma sigma_w)) component of the distance.
    sqrt_term = trace_sqrt_product(cov_real, cov_gen)

    # Covariance component; relies on trace(A + B) = trace(A) + trace(B).
    covariance_term = math_ops.trace(cov_real + cov_gen) - 2.0 * sqrt_term

    # Squared distance between the means. Equivalent to the squared L2 norm
    # but more stable.
    mean_term = math_ops.reduce_sum(
        math_ops.squared_difference(mean_real, mean_gen))

    distance = covariance_term + mean_term
    if needs_cast:
        distance = math_ops.cast(distance, activations_dtype)
    return distance
Example #40
0
def _compute_pi_tracenorm(left_cov, right_cov):
    """Computes the scalar constant pi for Tikhonov regularization/damping.

    pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
    See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.

    Args:
      left_cov: The left Kronecker factor "covariance".
      right_cov: The right Kronecker factor "covariance".

    Returns:
      The computed scalar constant pi for these Kronecker factors.
    """
    # Cross-multiply rather than divide: scaling each trace by the *other*
    # factor's leading dimension gives the same ratio as dividing each trace
    # by its own dimension.
    scaled_left = math_ops.trace(left_cov) * right_cov.shape.as_list()[0]
    scaled_right = math_ops.trace(right_cov) * left_cov.shape.as_list()[0]
    ratio = scaled_left / scaled_right
    return math_ops.sqrt(ratio)