def _compute_pi_tracenorm(left_cov, right_cov):
  """Returns the trace-norm scalar pi used for Tikhonov regularization/damping.

  The value computed is

    pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )

  with A = `left_cov` and B = `right_cov`. See section 6.3 of
  https://arxiv.org/pdf/1503.05671.pdf for details.

  Args:
    left_cov: The left Kronecker factor "covariance".
    right_cov: The right Kronecker factor "covariance".

  Returns:
    The scalar constant pi for this pair of Kronecker factors (as a Tensor).
  """
  left_dim = left_cov.shape.as_list()[0]
  right_dim = right_cov.shape.as_list()[0]

  # Rather than dividing each trace by its own dimension, scale it by the
  # dimension of the *other* factor; the ratio below is the same either way.
  scaled_left = math_ops.trace(left_cov) * right_dim
  scaled_right = math_ops.trace(right_cov) * left_dim
  return math_ops.sqrt(scaled_left / scaled_right)
def run_test(self, axes, expanded_axes=None):
  """Compares `special_math_ops.einsum` with a reference for one equation.

  Random inputs are generated for every operand on the left-hand side of
  `axes`. The reference value is `math_ops.trace` of the inputs when `axes`
  is exactly 'ijji', and `np.einsum` otherwise.

  Args:
    axes: The einsum equation to evaluate.
    expanded_axes: Optional string whose alphabetic characters are the axis
      labels that need a size. Falls back to `axes` when None.
  """
  if expanded_axes is None:
    expanded_axes = axes

  # One random extent in [4, 12) per alphabetic axis label.
  extent_of = {
      label: np.random.randint(4, 12)
      for label in expanded_axes
      if label.isalpha()
  }

  operand_specs = axes.partition('->')[0]
  input_vals = []
  for spec in operand_specs.split(','):
    operand_shape = [extent_of[label] for label in spec if label.isalpha()]
    input_vals.append(np.random.random(operand_shape))

  input_tensors = [constant_op.constant(val) for val in input_vals]
  output_tensor = special_math_ops.einsum(axes, *input_tensors)
  with self.session(use_gpu=True):
    output_value = self.evaluate(output_tensor)

  if axes == 'ijji':
    correct_value = self.evaluate(math_ops.trace(*input_tensors))
  else:
    correct_value = np.einsum(axes, *input_vals)

  max_abs_err = np.abs(correct_value - output_value).max()
  self.assertLess(max_abs_err, 1e-8)
def trace_sqrt_product(sigma, sigma_v):
  """Returns trace(sqrt(A sigma_v A)), where A is the square root of `sigma`.

  Both square roots are taken with `_symmetric_matrix_square_root`.

  Args:
    sigma: Matrix whose symmetric square root A is taken first.
    sigma_v: Matrix that is multiplied by A on both sides.

  Returns:
    The trace of the symmetric square root of A sigma_v A.
  """
  root_sigma = _symmetric_matrix_square_root(sigma)
  # Build A sigma_v A from the inside out: (sigma_v A) first, then A (sigma_v A).
  right_product = math_ops.matmul(sigma_v, root_sigma)
  sandwiched = math_ops.matmul(root_sigma, right_product)
  return math_ops.trace(_symmetric_matrix_square_root(sandwiched))
def compute_kid_block(i):
  """Computes the ith block of the KID estimate.

  Reads `inds_r`, `inds_g`, `real_activations`, `generated_activations`, `dim`
  and `dtype` from outside this function.

  Args:
    i: Index of the block; rows `inds_r[i]:inds_r[i + 1]` of the real
      activations and `inds_g[i]:inds_g[i + 1]` of the generated ones are used.

  Returns:
    The estimate for block `i`.
  """

  def cubic_kernel(x, y):
    # Kernel (x y^T / dim + 1)^3.
    return (math_ops.matmul(x, y, transpose_b=True) / dim + 1)**3

  def off_diagonal_mean(kernel, count):
    # Sum of all entries minus the diagonal, over count * (count - 1).
    return ((math_ops.reduce_sum(kernel) - math_ops.trace(kernel)) /
            (count * (count - 1)))

  real_start = inds_r[i]
  real_end = inds_r[i + 1]
  real_block = real_activations[real_start:real_end]
  num_real = math_ops.cast(real_end - real_start, dtype)

  gen_start = inds_g[i]
  gen_end = inds_g[i + 1]
  gen_block = generated_activations[gen_start:gen_end]
  num_gen = math_ops.cast(gen_end - gen_start, dtype)

  k_rr = cubic_kernel(real_block, real_block)
  k_rg = cubic_kernel(real_block, gen_block)
  k_gg = cubic_kernel(gen_block, gen_block)

  return (-2 * math_ops.reduce_mean(k_rg) +
          off_diagonal_mean(k_rr, num_real) +
          off_diagonal_mean(k_gg, num_gen))
def _trace(cov):
  """Returns the trace of `cov`, given as a full matrix or as its diagonal.

  Args:
    cov: A rank-1 tensor (treated as the diagonal of a diagonal matrix) or a
      rank-2 tensor (treated as a full matrix).

  Returns:
    The sum of the entries for rank 1, or `math_ops.trace(cov)` for rank 2.

  Raises:
    ValueError: If `cov` has any other rank.
  """
  rank = len(cov.shape)
  if rank == 1:
    # Diagonal matrix: the trace is the sum of the stored diagonal.
    return math_ops.reduce_sum(cov)
  if rank == 2:
    # Full matrix.
    return math_ops.trace(cov)
  raise ValueError("What's the trace of a Tensor of rank %d?" % rank)
def frechet_classifier_distance_from_activations_new(real_activations,
                                                     generated_activations):
  """Computes the Frechet distance between two sets of activations.

  Evaluates

    |m - m_v|^2 + Tr(sigma + sigma_v) - 2 * Tr(sqrt(sigma sigma_v))

  where m, sigma are the sample mean and unbiased sample covariance of
  `real_activations` and m_v, sigma_v those of `generated_activations`.
  The computation is done in float64 and cast back at the end.

  Args:
    real_activations: Rank-2 tensor of activations for real data; rows are
      examples (the mean is taken over axis 0).
    generated_activations: Rank-2 tensor of activations for generated data,
      laid out the same way.

  Returns:
    The distance as a scalar with the dtype of `real_activations`.
  """
  real_activations.shape.assert_has_rank(2)
  generated_activations.shape.assert_has_rank(2)

  activations_dtype = real_activations.dtype
  if activations_dtype != dtypes.float64:
    real_activations = math_ops.to_double(real_activations)
    generated_activations = math_ops.to_double(generated_activations)

  # Compute mean and covariance matrices of activations.
  m = math_ops.reduce_mean(real_activations, 0)
  m_v = math_ops.reduce_mean(generated_activations, 0)
  # Each covariance is normalised by its own sample count; the two inputs are
  # not required to contain the same number of examples.
  num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0])
  num_examples_generated = math_ops.to_double(
      array_ops.shape(generated_activations)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
  real_centered = real_activations - m
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / (
          num_examples_real - 1)

  gen_centered = generated_activations - m_v
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / (
          num_examples_generated - 1)

  # Find the Tr(sqrt(sigma sigma_v)) component of FID. This needs a *matrix*
  # square root; element-wise sqrt and `*` would only yield
  # sum_i sigma_ii * sigma_v_ii on the diagonal that the trace reads.
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)

  # Compute the two components of FID.

  # First the covariance component.
  # Here, note that trace(A + B) = trace(A) + trace(B)
  trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component

  # Next the squared L2 distance between means, computed directly as a sum of
  # squared differences so no square root is taken and then undone.
  mean = math_ops.reduce_sum(math_ops.squared_difference(m, m_v))
  fid = trace + mean
  if activations_dtype != dtypes.float64:
    fid = math_ops.cast(fid, activations_dtype)

  return fid
def test_trace(self):
  """Checks `operator.trace()` against `math_ops.trace` of the dense matrix.

  `shapes_info`, `dtype` and `use_placeholder` are free variables resolved
  outside this function.
  """
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    actual = operator.trace()
    expected = math_ops.trace(mat)
    # Static shapes are only compared when placeholders are not in use.
    if not use_placeholder:
      self.assertAllEqual(actual.get_shape(), expected.get_shape())
    fetched = sess.run([actual, expected])
    actual_value, expected_value = fetched
    self.assertAC(actual_value, expected_value)
def test_trace(self):
  """Verifies that the operator's trace matches the trace of its matrix.

  Relies on `shapes_info`, `dtype` and `use_placeholder` being defined
  outside this function.
  """
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    trace_from_operator = operator.trace()
    trace_from_matrix = math_ops.trace(mat)
    if not use_placeholder:
      # Without placeholders the static shapes must agree as well.
      self.assertAllEqual(trace_from_operator.shape, trace_from_matrix.shape)
    values = sess.run([trace_from_operator, trace_from_matrix])
    operator_value, matrix_value = values
    self.assertAC(operator_value, matrix_value)
def _get_weights(self, hessian_shape, hessians):
  """Derives weights to be used based on hessians and multiclass strategy.

  Args:
    hessian_shape: Shape of a single hessian; selects which reduction is used.
    hessians: The hessians to turn into weights.

  Returns:
    `hessians` unchanged for a scalar shape, their sum over axis 1 for a
    one-dimensional shape, and their trace otherwise.
  """
  if hessian_shape == tensor_shape.scalar():
    # This is tree per class.
    return hessians
  if len(hessian_shape.dims) == 1:
    # This is diagonal hessian.
    return math_ops.reduce_sum(hessians, axis=1)
  # This is full hessian.
  return math_ops.trace(hessians)
def test_trace(self):
  """Checks `operator.trace()` for every placeholder/build-info/dtype combo.

  The test is skipped through `_skip_if_tests_to_skip_contains("trace")`.
  Each combination runs in its own fresh graph.
  """
  self._skip_if_tests_to_skip_contains("trace")

  def check_combination(use_placeholder, build_info, dtype):
    # Compare the operator's trace with the trace of its dense matrix.
    with self.session(graph=ops.Graph()) as sess:
      sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
      operator, mat = self._operator_and_matrix(
          build_info, dtype, use_placeholder=use_placeholder)
      actual = operator.trace()
      expected = math_ops.trace(mat)
      if not use_placeholder:
        self.assertAllEqual(actual.get_shape(), expected.get_shape())
      actual_value, expected_value = sess.run([actual, expected])
      self.assertAC(actual_value, expected_value)

  for use_placeholder in self._use_placeholder_options:
    for build_info in self._operator_build_infos:
      for dtype in self._dtypes_to_test:
        check_combination(use_placeholder, build_info, dtype)
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
  # Sums the `offset`-th diagonal of `a` taken over `axis1`/`axis2`.
  #
  # `dtype`, when given, is resolved with `np_utils.result_type` and used both
  # to convert `a` and as the dtype of the final sum.

  # Test against None explicitly rather than relying on truthiness: a dtype
  # argument is not guaranteed to be truthy (NumPy dtype objects define
  # `__len__` as their number of fields), and a falsy one would silently skip
  # the normalisation below.
  if dtype is not None:
    dtype = np_utils.result_type(dtype)
  a = np_array_ops.asarray(a, dtype)

  if offset == 0:
    a_shape = a.shape
    if a_shape.rank is not None:
      rank = len(a_shape)
      # Fast path: main diagonal over the last two axes is exactly what
      # `math_ops.trace` computes.
      if axis1 in (-2, rank - 2) and axis2 in (-1, rank - 1):
        return math_ops.trace(a)

  # General path: extract the requested diagonal, then sum along it.
  a = np_array_ops.diagonal(a, offset, axis1, axis2)
  return np_array_ops.sum(a, -1, dtype)
def test_trace(self):
  """Checks `operator.trace()` for every placeholder/shape/dtype combination.

  The test is skipped through `_skip_if_tests_to_skip_contains("trace")`.
  Each combination runs in its own fresh graph and is evaluated with the
  feed dict returned alongside the operator.
  """
  self._skip_if_tests_to_skip_contains("trace")

  def check_combination(use_placeholder, shape, dtype):
    # Compare the operator's trace with the trace of its dense matrix.
    with self.test_session(graph=ops.Graph()) as sess:
      sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
      operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
          shape, dtype, use_placeholder=use_placeholder)
      actual = operator.trace()
      expected = math_ops.trace(mat)
      if not use_placeholder:
        self.assertAllEqual(actual.get_shape(), expected.get_shape())
      actual_value, expected_value = sess.run(
          [actual, expected], feed_dict=feed_dict)
      self.assertAC(actual_value, expected_value)

  for use_placeholder in False, True:
    for shape in self._shapes_to_test:
      for dtype in self._dtypes_to_test:
        check_combination(use_placeholder, shape, dtype)
def einsum(equation, *inputs, **kwargs):
  """Evaluates an einsum equation through `einsum_lib.einsum_cu_tensor`.

  Special cases handled before reaching the kernel:
    * a single palindromic input with one label repeated twice and no '->'
      in the equation is returned as `math_ops.trace` of that input;
    * any other repeated label within one input raises;
    * a label summed over more than two inputs falls back to
      `special_math_ops._exponential_space_einsum`.

  Args:
    equation: The einsum equation string.
    *inputs: The input tensors.
    **kwargs: Only `name` is accepted.

  Returns:
    The result tensor.

  Raises:
    TypeError: If a keyword argument other than `name` is passed.
    ValueError: If an axis label appears more than once in one input (other
      than the trace special case).
  """
  name = kwargs.pop('name', None)
  if kwargs:
    unexpected = ', '.join(format(key) for key in sorted(kwargs))
    raise TypeError(
        'invalid keyword arguments for this function: ' + unexpected)

  with ops.name_scope(name, 'einsum', [equation, inputs]):
    inputs = list(inputs)
    input_shapes = [x.get_shape() for x in inputs]
    parsed = special_math_ops._einsum_parse_and_resolve_equation(
        equation, input_shapes)
    input_axis_labels, output_axis_labels = parsed

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    for a in axis_labels:
      for input_labels in input_axis_labels:
        is_trace = (len(input_axis_labels) == 1 and
                    input_labels.count(a) == 2 and
                    input_labels == input_labels[::-1] and
                    '->' not in equation)
        if is_trace:
          return math_ops.trace(inputs[0])
        if input_labels.count(a) > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              input_labels)

    for a in axis_labels:
      input_count = sum(1 for s in input_axis_labels if a in s)
      if input_count > 2 and a not in output_axis_labels:
        tf.logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', a)
        return special_math_ops._exponential_space_einsum(equation, *inputs)

    equation = ','.join(input_axis_labels) + '->' + output_axis_labels
    if len(inputs) == 1:
      # The kernel is always called with two inputs; a one-element constant
      # of the same dtype stands in for the missing second one.
      inputs.append(tf.constant([0], dtype=inputs[0].dtype))
    return einsum_lib.einsum_cu_tensor(
        input_0=inputs[0], input_1=inputs[1], equation=equation)
def trace_sqrt_product(sigma, sigma_v):
  """Find the trace of the positive sqrt of product of covariance matrices.

  `_symmetric_matrix_square_root` requires a symmetric argument. `sigma` and
  `sigma_v` are symmetric but their product need not be, so the root of
  sigma * sigma_v cannot be taken directly.

  Write sigma = A A with A = sqrt(sigma), and sigma_v = B B. The target is

    trace(sqrt(sigma sigma_v)) = trace(sqrt(A A B B))

  Three facts are used:
    (i)   for all M1, M2: eigenvalues(M1 M2) = eigenvalues(M2 M1)
          => eigenvalues(A A B B) = eigenvalues(A B B A)
    (ii)  if M1 = sqrt(M2), then eigenvalues(M1) = sqrt(eigenvalues(M2))
          => eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A))
    (iii) for all M: trace(M) = sum(eigenvalues(M))
          => trace(sqrt(sigma sigma_v))
               = sum(eigenvalues(sqrt(sigma sigma_v)))
               = sum(sqrt(eigenvalues(A B B A)))
               = sum(eigenvalues(sqrt(A B B A)))
               = trace(sqrt(A B B A))
               = trace(sqrt(A sigma_v A))

  Both sigma and A sigma_v A are symmetric, so
  `_symmetric_matrix_square_root` can be applied to each of them.

  Args:
    sigma: a square, symmetric, real, positive semi-definite covariance matrix
    sigma_v: same as sigma

  Returns:
    The trace of the positive square root of sigma*sigma_v
  """
  # This is "A" in the derivation above.
  a = _symmetric_matrix_square_root(sigma)

  # A sigma_v A, built as A (sigma_v A).
  sigma_v_a = math_ops.matmul(sigma_v, a)
  a_sigma_v_a = math_ops.matmul(a, sigma_v_a)

  return math_ops.trace(_symmetric_matrix_square_root(a_sigma_v_a))
def trace_sqrt_product(sigma, sigma_v):
  """Returns trace(sqrt(sigma sigma_v)) for two covariance matrices.

  The product sigma * sigma_v of two symmetric matrices is not necessarily
  symmetric, and `_symmetric_matrix_square_root` only works for symmetric
  matrices, so the product's root is not computed directly. Instead, with
  A = sqrt(sigma) (so sigma = A A) and sigma_v = B B:

    (i)   eigenvalues(M1 M2) = eigenvalues(M2 M1) for all M1, M2, hence
          eigenvalues(A A B B) = eigenvalues(A B B A).
    (ii)  if M1 = sqrt(M2) then eigenvalues(M1) = sqrt(eigenvalues(M2)), hence
          eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A)).
    (iii) trace(M) = sum(eigenvalues(M)) for all M, hence
          trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v)))
                                     = sum(sqrt(eigenvalues(A B B A)))
                                     = sum(eigenvalues(sqrt(A B B A)))
                                     = trace(sqrt(A B B A))
                                     = trace(sqrt(A sigma_v A)).

  sigma and A sigma_v A are both symmetric, so the symmetric square-root
  routine **can** be used for the two roots that remain.

  Args:
    sigma: a square, symmetric, real, positive semi-definite covariance matrix
    sigma_v: same as sigma

  Returns:
    The trace of the positive square root of sigma*sigma_v
  """
  root_of_sigma = _symmetric_matrix_square_root(sigma)  # "A" above.
  symmetric_product = math_ops.matmul(
      root_of_sigma, math_ops.matmul(sigma_v, root_of_sigma))  # A sigma_v A.
  root_of_product = _symmetric_matrix_square_root(symmetric_product)
  return math_ops.trace(root_of_product)
def frechet_classifier_distance_from_activations(real_activations,
                                                 generated_activations):
  """Computes the Frechet distance between two sets of rank-2 activations.

  Evaluates

    |m - m_w|^2 + Tr(sigma + sigma_w) - 2 * Tr(sqrt(sigma sigma_w))

  from the sample means and unbiased sample covariances of the two inputs.
  The arithmetic is carried out in float64.

  Args:
    real_activations: Rank-2 tensor of activations for real data; the mean is
      taken over axis 0.
    generated_activations: Rank-2 tensor of activations for generated data.

  Returns:
    The distance as a scalar with the dtype of `real_activations`.
  """
  real_activations.shape.assert_has_rank(2)
  generated_activations.shape.assert_has_rank(2)

  activations_dtype = real_activations.dtype
  cast_needed = activations_dtype != dtypes.float64
  if cast_needed:
    real_activations = math_ops.to_double(real_activations)
    generated_activations = math_ops.to_double(generated_activations)

  def unbiased_covariance(centered, count):
    # (1 / (n - 1)) * X^T X for already-centered X.
    return math_ops.matmul(centered, centered, transpose_a=True) / (count - 1)

  m = math_ops.reduce_mean(real_activations, 0)
  m_w = math_ops.reduce_mean(generated_activations, 0)
  num_examples_real = math_ops.to_double(
      array_ops.shape(real_activations)[0])
  num_examples_generated = math_ops.to_double(
      array_ops.shape(generated_activations)[0])

  sigma = unbiased_covariance(real_activations - m, num_examples_real)
  sigma_w = unbiased_covariance(generated_activations - m_w,
                                num_examples_generated)

  # Covariance part: Tr(sigma + sigma_w) - 2 Tr(sqrt(sigma sigma_w)).
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
  trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component

  # Mean part: squared L2 distance between the two means.
  mean = math_ops.reduce_sum(math_ops.squared_difference(m, m_w))

  fid = trace + mean
  if cast_needed:
    fid = math_ops.cast(fid, activations_dtype)
  return fid
def _process_input_helper(self,
                          update_row_factors,
                          sp_input=None,
                          transpose_input=False,
                          row_weights=None):
  """Creates the graph for processing a sparse slice of input.

  The rows/columns present in `sp_input` are re-indexed to contiguous ids,
  the normal equations are assembled and solved for the new factor values,
  an op that scatters those values into the stored factors is created, and
  the normalized loss terms for the slice are computed.

  Args:
    update_row_factors: if True, update or project the row_factors, else
      update or project the column factors.
    sp_input: Please refer to comments for update_row_factors,
      update_col_factors, project_row_factors, and project_col_factors for
      restrictions. Must be a `SparseTensor` (asserted below) despite the
      `None` default.
    transpose_input: If True, the input is logically transposed and then the
      corresponding rows/columns of the transposed input are updated.
    row_weights: If not None, this is the row/column weights to be used for
      the update or projection. If None, use the corresponding weights from
      the model. Note that the feature (column/row) weights will be
      determined by the model. When not None, it can either be a scalar or
      a rank-1 tensor with the same number of elements as the number of rows
      of columns to be updated/projected.

  Returns:
    A tuple consisting of the following elements:
    new_values: New values for the row/column factors.
    update_op: An op that assigns the newly computed values to the row/column
      factors.
    unregularized_loss: A tensor (scalar) that contains the normalized
      minibatch loss corresponding to sp_input, without the regularization
      term. Add the regularization term below to yield the loss.
    regularization: A tensor (scalar) that contains the normalized
      regularization term for the minibatch loss corresponding to sp_input.
    sum_weights: The sum of the weights corresponding to sp_input. This
      can be used with unregularized loss to calculate the root weighted
      squared error.
  """
  assert isinstance(sp_input, sparse_tensor.SparseTensor)

  # Pick the factors being solved for (`left`) together with the cached
  # opposite-side factors, weights and gramian.
  if update_row_factors:
    left = self._row_factors
    right_factors = self._col_factors_cache
    row_wt = self._row_wt_cache
    col_wt = self._col_wt_cache
    total_rows = self._input_rows
    total_cols = self._input_cols
    sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                 self._num_row_shards)
    gramian = self._col_gramian_cache
  else:
    left = self._col_factors
    right_factors = self._row_factors_cache
    row_wt = self._col_wt_cache
    col_wt = self._row_wt_cache
    total_rows = self._input_cols
    total_cols = self._input_rows
    sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                 self._num_col_shards)
    gramian = self._row_gramian_cache
    # Roles of rows and columns are swapped in this branch, so the transpose
    # flag is flipped to match.
    transpose_input = not transpose_input

  # Note that the row indices of sp_input are based on the original full input
  # Here we reindex the rows and give them contiguous ids starting at 0.
  # We use tf.unique to achieve this reindexing. Note that this is done so
  # that the downstream kernel can assume that the input is "dense" along the
  # row dimension.
  row_ids, col_ids = array_ops.split(
      value=sp_input.indices, num_or_size_splits=2, axis=1)
  update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
  update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
  col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
  row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

  # `update_indices` are the original ids of the factors being updated;
  # `gather_indices` are the original ids of the opposite-side factors that
  # are looked up.
  if transpose_input:
    update_indices = update_col_indices
    row_shape = [
        math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
    ]
    gather_indices = update_row_indices
  else:
    update_indices = update_row_indices
    row_shape = [
        math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
    ]
    gather_indices = update_col_indices

  num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
  col_shape = [num_rows]
  right = embedding_ops.embedding_lookup(
      right_factors, gather_indices, partition_strategy="div")
  # Rebuild the sparse input over the contiguous ids.
  new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
  new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                  if transpose_input else
                  array_ops.concat([col_shape, row_shape], 0))
  new_sp_input = sparse_tensor.SparseTensor(
      indices=new_sp_indices,
      values=sp_input.values,
      dense_shape=new_sp_shape)

  # Compute lhs and rhs of the normal equations
  total_lhs = (self._unobserved_weight * gramian)
  if self._regularization_matrix is not None:
    total_lhs += self._regularization_matrix
  if self._row_weights is None:
    # Special case of ALS. Use a much simpler update rule.
    total_rhs = (
        self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
            new_sp_input, right, adjoint_a=transpose_input))
    # TODO (rmlarsen): handle transposing in tf.matrix_solve instead of
    # transposing explicitly. id:1138
    # https://github.com/imdone/tensorflow/issues/1139
    # TODO (rmlarsen): multi-thread tf.matrix_solve. id:1249
    # https://github.com/imdone/tensorflow/issues/1250
    new_left_values = array_ops.transpose(
        linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
  else:
    if row_weights is None:
      # TODO (yifanchen): Add special handling for single shard without using
      # embedding_lookup and perform benchmarks for those cases. Same for
      # col_weights lookup below. id:845
      # https://github.com/imdone/tensorflow/issues/846
      row_weights_slice = embedding_ops.embedding_lookup(
          row_wt, update_indices, partition_strategy="div")
    else:
      num_indices = array_ops.shape(update_indices)[0]
      # Caller-supplied weights must have rank <= 1; a scalar is expanded to
      # one weight per updated index, a vector is cast to float32.
      with ops.control_dependencies(
          [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
        row_weights_slice = control_flow_ops.cond(
            math_ops.equal(array_ops.rank(row_weights), 0),
            lambda: (array_ops.ones([num_indices]) * row_weights),
            lambda: math_ops.cast(row_weights, dtypes.float32))

    col_weights = embedding_ops.embedding_lookup(
        col_wt, gather_indices, partition_strategy="div")
    partial_lhs, total_rhs = (
        gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
            right,
            col_weights,
            self._unobserved_weight,
            row_weights_slice,
            new_sp_input.indices,
            new_sp_input.values,
            num_rows,
            transpose_input,
            name="wals_compute_partial_lhs_rhs"))
    # The shared lhs gets a leading axis before the partial lhs is added; the
    # rhs gets a trailing axis for the solve, which is squeezed out (axis 2)
    # afterwards.
    total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
    total_rhs = array_ops.expand_dims(total_rhs, -1)
    new_left_values = array_ops.squeeze(
        linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

  update_op_name = "row_update" if update_row_factors else "col_update"
  update_op = self.scatter_update(
      left,
      update_indices,
      new_left_values,
      sharding_func,
      name=update_op_name)

  # Create the loss subgraph
  loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                   if transpose_input else new_sp_input)
  # sp_approx is the low rank estimate of the input matrix, formed by
  # computing the product <\\(u_i, v_j\\)> for (i, j) in loss_sp_input.indices.
  sp_approx_vals = gen_factorization_ops.masked_matmul(
      new_left_values,
      right,
      loss_sp_input.indices,
      transpose_a=False,
      transpose_b=True)
  sp_approx = sparse_tensor.SparseTensor(
      loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)

  sp_approx_sq = math_ops.square(sp_approx)
  sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
  sp_residual_sq = math_ops.square(sp_residual)
  # NOTE(review): `row_weights_slice` and `col_weights` are only assigned in
  # the weighted branch above (self._row_weights is not None). If the model
  # can have column weights without row weights, `col_weights` would be
  # unbound here -- confirm that combination cannot occur.
  row_wt_mat = (constant_op.constant(0.)
                if self._row_weights is None else array_ops.expand_dims(
                    row_weights_slice, 1))
  col_wt_mat = (constant_op.constant(0.)
                if self._col_weights is None else array_ops.expand_dims(
                    col_weights, 0))

  # We return the normalized loss
  partial_row_gramian = math_ops.matmul(
      new_left_values, new_left_values, transpose_a=True)
  # Ratio of the total number of rows to the number of rows in this slice.
  normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

  unregularized_loss = (
      self._unobserved_weight * (  # pyformat line break
          sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
          sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
          math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
      sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
  ) * normalization_factor

  if self._regularization is not None:
    regularization = self._regularization * (
        math_ops.trace(partial_row_gramian) * normalization_factor +
        math_ops.trace(gramian))
  else:
    regularization = constant_op.constant(0.)

  sum_weights = self._unobserved_weight * math_ops.cast(
      total_rows * total_cols, dtypes.float32)
  if self._row_weights is not None and self._col_weights is not None:
    # Add the row-weight * column-weight product summed over the observed
    # entries, scaled like the loss.
    ones = sparse_tensor.SparseTensor(
        indices=loss_sp_input.indices,
        values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
        dense_shape=loss_sp_input.dense_shape)
    sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
        ones * col_wt_mat)) * normalization_factor

  return (new_left_values, update_op, unregularized_loss, regularization,
          sum_weights)
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
  """Classifier distance for evaluating a conditional generative model.

  This is based on the Frechet Inception distance, but for an arbitrary
  classifier.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_v and covariance
  matrices C and C_v, this function calculates

    |m - m_v|^2 + Tr(C + C_v - 2(C * C_v)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Args:
    real_images: Real images to use to compute Frechet Inception distance.
    generated_images: Generated images to use to compute Frechet Inception
      distance. They are stacked with the real image batches, so both inputs
      must split into batches of the same shape.
    classifier_fn: A function that takes images and produces activations
      based on a classifier.
    num_batches: Number of batches to split images in to in order to
      efficiently run them through the classifier network.

  Returns:
    The Frechet Inception distance. A floating-point scalar.
  """
  real_images_list = array_ops.split(
      real_images, num_or_size_splits=num_batches)
  generated_images_list = array_ops.split(
      generated_images, num_or_size_splits=num_batches)

  imgs = array_ops.stack(real_images_list + generated_images_list)

  # Compute the activations using the memory-efficient `map_fn`.
  activations = functional_ops.map_fn(
      fn=classifier_fn,
      elems=imgs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')

  # Split the activations by the real and generated images.
  real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)

  # Ensure the activations have the right shapes.
  real_a = array_ops.concat(array_ops.unstack(real_a), 0)
  gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
  real_a.shape.assert_has_rank(2)
  gen_a.shape.assert_has_rank(2)

  # Compute mean and covariance matrices of activations.
  m = math_ops.reduce_mean(real_a, 0)
  m_v = math_ops.reduce_mean(gen_a, 0)
  # Real and generated batches were stacked above, so both sets contain the
  # same number of examples.
  num_examples = math_ops.to_float(array_ops.shape(real_a)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu): a feature-by-feature
  # covariance (transpose_a), with each set centered on its OWN mean.
  real_centered = real_a - m
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / (num_examples - 1)

  gen_centered = gen_a - m_v
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / (num_examples - 1)

  # Take matrix square root of the product of covariance matrices.
  # NOTE(review): the product of two symmetric matrices is not necessarily
  # symmetric -- confirm `_matrix_square_root` supports that.
  sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v))

  # Compute the two components of FID.
  trace = math_ops.trace(sigma + sigma_v - 2.0 * sqcc)
  mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
  fid = trace + mean

  return fid
def frechet_classifier_distance_from_activations(
    real_activations, generated_activations):
  """Classifier distance for evaluating a generative model.

  This is based on the Frechet Inception distance, but for an arbitrary
  classifier.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_v and covariance
  matrices C and C_v, this function calculates

    |m - m_v|^2 + Tr(C + C_v - 2(C * C_v)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Note that when computed using sample means and sample covariance matrices,
  Frechet distance is biased. It is more biased for small sample sizes. (e.g.
  even if the two distributions are the same, for a small sample size, the
  expected Frechet distance is large). It is important to use the same
  sample size to compute frechet classifier distance when comparing two
  generative models.

  Args:
    real_activations: Rank-2 activations of real images used to compute the
      Frechet Inception distance; the mean is taken over axis 0.
    generated_activations: Rank-2 activations of generated images used to
      compute the Frechet Inception distance.

  Returns:
    The Frechet Inception distance. A floating-point scalar of the same type
    as the output of the activations.
  """
  real_activations.shape.assert_has_rank(2)
  generated_activations.shape.assert_has_rank(2)

  # All arithmetic is done in float64 and cast back at the end.
  activations_dtype = real_activations.dtype
  if activations_dtype != dtypes.float64:
    real_activations = math_ops.to_double(real_activations)
    generated_activations = math_ops.to_double(generated_activations)

  # Compute mean and covariance matrices of activations.
  m = math_ops.reduce_mean(real_activations, 0)
  m_v = math_ops.reduce_mean(generated_activations, 0)
  # Each covariance is normalised by its own sample count; nothing above
  # requires the two inputs to contain the same number of examples.
  num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0])
  num_examples_generated = math_ops.to_double(
      array_ops.shape(generated_activations)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
  real_centered = real_activations - m
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / (
          num_examples_real - 1)

  gen_centered = generated_activations - m_v
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / (
          num_examples_generated - 1)

  # Find the Tr(sqrt(sigma sigma_v)) component of FID
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)

  # Compute the two components of FID.

  # First the covariance component.
  # Here, note that trace(A + B) = trace(A) + trace(B)
  trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component

  # Next the squared L2 distance between means, computed directly as a sum of
  # squared differences so no square root is taken and then undone.
  mean = math_ops.reduce_sum(math_ops.squared_difference(m, m_v))
  fid = trace + mean
  if activations_dtype != dtypes.float64:
    fid = math_ops.cast(fid, activations_dtype)

  return fid
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
  """Classifier distance for evaluating a generative model.

  This is based on the Frechet Inception distance, but for an arbitrary
  classifier.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_v and covariance
  matrices C and C_v, this function calculates

    |m - m_v|^2 + Tr(C + C_v - 2(C * C_v)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Note that when computed using sample means and sample covariance matrices,
  Frechet distance is biased. It is more biased for small sample sizes. (e.g.
  even if the two distributions are the same, for a small sample size, the
  expected Frechet distance is large). It is important to use the same
  sample size to compute frechet classifier distance when comparing two
  generative models.

  Args:
    real_images: Real images to use to compute Frechet Inception distance.
    generated_images: Generated images to use to compute Frechet Inception
      distance.
    classifier_fn: A function that takes images and produces activations
      based on a classifier.
    num_batches: Number of batches to split images in to in order to
      efficiently run them through the classifier network.

  Returns:
    The Frechet Inception distance. A floating-point scalar.
  """

  def merge_batches(batched):
    # Concatenate the per-batch activations back along axis 0.
    return array_ops.concat(array_ops.unstack(batched), 0)

  real_batches = array_ops.split(real_images, num_or_size_splits=num_batches)
  generated_batches = array_ops.split(
      generated_images, num_or_size_splits=num_batches)
  imgs = array_ops.stack(real_batches + generated_batches)

  # Compute the activations using the memory-efficient `map_fn`.
  activations = functional_ops.map_fn(
      fn=classifier_fn,
      elems=imgs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')

  # The first `num_batches` entries belong to the real images, the rest to
  # the generated ones.
  real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)

  # Ensure the activations have the right shapes.
  real_a = merge_batches(real_a)
  gen_a = merge_batches(gen_a)
  real_a.shape.assert_has_rank(2)
  gen_a.shape.assert_has_rank(2)

  # Compute mean and covariance matrices of activations.
  m = math_ops.reduce_mean(real_a, 0)
  m_v = math_ops.reduce_mean(gen_a, 0)
  num_examples = math_ops.to_float(array_ops.shape(real_a)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
  sigma = math_ops.matmul(
      real_a - m, real_a - m, transpose_a=True) / (num_examples - 1)
  sigma_v = math_ops.matmul(
      gen_a - m_v, gen_a - m_v, transpose_a=True) / (num_examples - 1)

  # Covariance component: trace(A + B) = trace(A) + trace(B), so only the
  # Tr(sqrt(sigma sigma_v)) part needs the dedicated helper.
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)
  trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component

  # Mean component: squared L2 norm of the difference of the means.
  mean = math_ops.square(linalg_ops.norm(m - m_v))

  return trace + mean
def _einsum_v1(equation, *inputs, **kwargs):
  """Legacy implementation of einsum without using EinsumOp.

  Args:
    equation: The einsum equation string.
    *inputs: The input tensors.
    **kwargs: Only `name` is accepted.

  Returns:
    The contracted tensor, with axes ordered as in the resolved output labels.

  Raises:
    TypeError: If a keyword argument other than `name` is passed.
    ValueError: If an axis label appears more than once in one input (other
      than the single-input trace special case), or if the labels left after
      reduction do not match the output labels.
  """
  name = kwargs.pop('name', None)
  if kwargs:
    unexpected = ', '.join(format(key) for key in sorted(kwargs))
    raise TypeError(
        'invalid keyword arguments for this function: ' + unexpected)

  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = [x.shape for x in inputs]
    input_axis_labels, output_axis_labels = (
        _einsum_v1_parse_and_resolve_equation(equation, input_shapes))

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    for a in axis_labels:
      for input_labels in input_axis_labels:
        # A lone palindromic input with a doubled label and no explicit
        # output is answered with a plain trace.
        is_trace = (len(input_axis_labels) == 1 and
                    input_labels.count(a) == 2 and
                    input_labels == input_labels[::-1] and
                    '->' not in equation)
        if is_trace:
          return math_ops.trace(inputs[0])
        if input_labels.count(a) > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              input_labels)

    for a in axis_labels:
      input_count = sum(1 for s in input_axis_labels if a in s)
      if input_count > 2 and a not in output_axis_labels:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', a)
        return _exponential_space_einsum_v1(equation, *inputs)

    # Use xla_einsum if executing on TPU and if the operation is a 2 input
    # einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(inputs) == 2:
      xla_equation = (input_axis_labels[0] + ',' + input_axis_labels[1] +
                      '->' + output_axis_labels)
      return gen_xla_ops.xla_einsum(inputs[0], inputs[1], xla_equation)

    # Fold the inputs in from left to right, contracting at each step the
    # labels shared with the next input that are absent from the output.
    result = inputs[0]
    result_labels = input_axis_labels[0]
    for i in xrange(1, len(inputs)):
      next_labels = input_axis_labels[i]
      axes_to_sum = (
          set(result_labels) & (set(next_labels) - set(output_axis_labels)))
      result, result_labels = _einsum_v1_reduction(
          result, result_labels, inputs[i], next_labels, axes_to_sum)

    # Sum out whatever labels remain that the output does not ask for.
    if set(result_labels) - set(output_axis_labels):
      axis = [
          position for position, label in enumerate(result_labels)
          if label not in output_axis_labels
      ]
      result = math_ops.reduce_sum(result, axis=axis)
      result_labels = ''.join(
          label for label in result_labels if label in output_axis_labels)

    if sorted(result_labels) != sorted(output_axis_labels):
      raise ValueError('Invalid equation: %s' % equation)

    perm = [result_labels.index(a) for a in output_axis_labels]
    return _transpose_if_necessary(result, perm)
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
  """Frechet distance between classifier activations of two image sets.

  This generalizes the Frechet Inception distance, described in
  https://arxiv.org/abs/1706.08500, to an arbitrary classifier.

  For two Gaussian distributions with means m and m_w and covariance matrices
  C and C_w, the quantity computed is

        |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which measures how far apart the distributions of real and generated images
  (more precisely, of their classifier features) are. Unlike the Inception
  score, this is a true distance and it uses information about real images.

  When sample means and sample covariances are used, the Frechet distance is
  biased, and more so for small samples (even for identical distributions the
  expected distance is large when the sample is small). Use the same sample
  size whenever two generative models are compared.

  Args:
    real_images: Real images to use to compute Frechet Inception distance.
    generated_images: Generated images to use to compute Frechet Inception
      distance.
    classifier_fn: A function that takes images and produces activations
      based on a classifier.
    num_batches: Number of batches to split images into so that they can be
      run through the classifier network efficiently.

  Returns:
    The Frechet Inception distance. A floating-point scalar.
  """
  # Real batches first, generated batches second, stacked into one tensor.
  batches = (
      array_ops.split(real_images, num_or_size_splits=num_batches) +
      array_ops.split(generated_images, num_or_size_splits=num_batches))
  imgs = array_ops.stack(batches)

  # `map_fn` runs one batch at a time, which keeps memory usage low.
  activations = functional_ops.map_fn(
      fn=classifier_fn,
      elems=imgs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')

  # Undo the stacking: separate real from generated, then merge the batches.
  real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)
  real_a = array_ops.concat(array_ops.unstack(real_a), 0)
  gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
  for merged in (real_a, gen_a):
    merged.shape.assert_has_rank(2)

  # Sample means and covariances: sigma = (1 / (n - 1)) * (X - mu)^T (X - mu).
  m = math_ops.reduce_mean(real_a, 0)
  m_v = math_ops.reduce_mean(gen_a, 0)
  denominator = math_ops.to_float(array_ops.shape(real_a)[0]) - 1

  real_centered = real_a - m
  gen_centered = gen_a - m_v
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / denominator
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / denominator

  # Tr(sqrt(sigma sigma_v)) is computed by a dedicated helper.
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)

  # Covariance term, using trace(A + B) = trace(A) + trace(B).
  covariance_term = (
      math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component)

  # Squared L2 distance between the two means.
  mean_term = math_ops.square(linalg_ops.norm(m - m_v))
  return covariance_term + mean_term
def compare(self, x):
  """Asserts that `math_ops.trace(x)` agrees with NumPy's trace of `x`.

  The NumPy reference takes the trace over the last two axes of `x`.
  """
  expected = np.trace(x, axis1=-2, axis2=-1)
  with self.test_session(use_gpu=True):
    actual = math_ops.trace(x).eval()
  self.assertAllClose(actual, expected)
def frechet_classifier_distance_from_activations(
    real_activations, generated_activations):
  """Classifier distance for evaluating a generative model from activations.

  This method computes the Frechet classifier distance from activations of
  real images and generated images. It can be used independently of
  frechet_classifier_distance(), in particular when large batches are used
  during evaluation and all activations are precomputed before the distance
  is calculated.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance
  matrices C and C_w, this function calculates

        |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Args:
    real_activations: 2D Tensor containing activations of real data. Shape is
      [batch_size, activation_size].
    generated_activations: 2D Tensor containing activations of generated data.
      Shape is [batch_size, activation_size]. The batch size may differ from
      that of `real_activations`.

  Returns:
    The Frechet Inception distance. A floating-point scalar of the same type
    as the output of the activations.
  """
  real_activations.shape.assert_has_rank(2)
  generated_activations.shape.assert_has_rank(2)

  # All arithmetic is carried out in float64; the result is cast back below.
  activations_dtype = real_activations.dtype
  if activations_dtype != dtypes.float64:
    real_activations = math_ops.to_double(real_activations)
    generated_activations = math_ops.to_double(generated_activations)

  # Compute mean and covariance matrices of activations.
  m = math_ops.reduce_mean(real_activations, 0)
  m_v = math_ops.reduce_mean(generated_activations, 0)
  # Each sample set gets its own count. Normalizing the generated covariance
  # by the number of real examples is wrong whenever the two sizes differ.
  num_examples_real = math_ops.to_double(
      array_ops.shape(real_activations)[0])
  num_examples_generated = math_ops.to_double(
      array_ops.shape(generated_activations)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
  real_centered = real_activations - m
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / (
          num_examples_real - 1)

  gen_centered = generated_activations - m_v
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / (
          num_examples_generated - 1)

  # Find the Tr(sqrt(sigma sigma_v)) component of FID
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)

  # Compute the two components of FID.

  # First the covariance component.
  # Here, note that trace(A + B) = trace(A) + trace(B)
  trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component

  # Next the distance between means.
  mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
  fid = trace + mean
  if activations_dtype != dtypes.float64:
    fid = math_ops.cast(fid, activations_dtype)

  return fid
def trace_sqrt_product(cov1, cov2):
  """Returns trace(S(S(cov1) * cov2 * S(cov1))), with S = `sym_matrix_sqrt`.

  `sym_matrix_sqrt` is defined elsewhere; presumably it is a symmetric matrix
  square root, in which case the result is Tr(sqrt(cov1 cov2)) -- confirm
  against that helper.

  Args:
    cov1: First matrix; its `sym_matrix_sqrt` is applied on both sides.
    cov2: Second matrix, sandwiched between the two roots of `cov1`.

  Returns:
    The trace of `sym_matrix_sqrt` applied to the sandwiched product.
  """
  root = sym_matrix_sqrt(cov1)
  right_product = math_ops.matmul(cov2, root)
  sandwiched = math_ops.matmul(root, right_product)
  return math_ops.trace(sym_matrix_sqrt(sandwiched))
def _process_input_helper(self,
                          update_row_factors,
                          sp_input=None,
                          transpose_input=False,
                          row_weights=None):
  """Creates the graph for processing a sparse slice of input.

  Args:
    update_row_factors: if True, update or project the row_factors, else
      update or project the column factors.
    sp_input: Please refer to comments for update_row_factors,
      update_col_factors, project_row_factors, and project_col_factors for
      restrictions. Must be a `SparseTensor` despite the `None` default.
    transpose_input: If True, the input is logically transposed and then the
      corresponding rows/columns of the transposed input are updated.
    row_weights: If not None, this is the row/column weights to be used for
      the update or projection. If None, use the corresponding weights from
      the model. Note that the feature (column/row) weights will be
      determined by the model. When not None, it can either be a scalar or
      a rank-1 tensor with the same number of elements as the number of rows
      or columns to be updated/projected.

  Returns:
    A tuple consisting of the following elements:
    new_values: New values for the row/column factors.
    update_op: An op that assigns the newly computed values to the row/column
      factors.
    unregularized_loss: A tensor (scalar) that contains the normalized
      minibatch loss corresponding to sp_input, without the regularization
      term. Add the regularization term below to yield the loss.
    regularization: A tensor (scalar) that contains the normalized
      regularization term for the minibatch loss corresponding to sp_input.
    sum_weights: The sum of the weights corresponding to sp_input. This
      can be used with unregularized loss to calculate the root weighted
      squared error.
  """
  # NOTE(review): this is input validation done with `assert`, which is
  # skipped when Python runs with -O.
  assert isinstance(sp_input, sparse_tensor.SparseTensor)

  # Pick the factors being solved for (`left`), the cached factors of the
  # other side that are held fixed, the matching weight caches, and the
  # cached gramian of the fixed side.
  if update_row_factors:
    left = self._row_factors
    right_factors = self._col_factors_cache
    row_wt = self._row_wt_cache
    col_wt = self._col_wt_cache
    total_rows = self._input_rows
    total_cols = self._input_cols
    sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                 self._num_row_shards)
    gramian = self._col_gramian_cache
  else:
    left = self._col_factors
    right_factors = self._row_factors_cache
    row_wt = self._col_wt_cache
    col_wt = self._row_wt_cache
    total_rows = self._input_cols
    total_cols = self._input_rows
    sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                 self._num_col_shards)
    gramian = self._row_gramian_cache
    # The roles of rows and columns are swapped for column updates, so the
    # transpose flag is inverted.
    transpose_input = not transpose_input

  # Note that the row indices of sp_input are based on the original full input
  # Here we reindex the rows and give them contiguous ids starting at 0.
  # We use tf.unique to achieve this reindexing. Note that this is done so
  # that the downstream kernel can assume that the input is "dense" along the
  # row dimension.
  row_ids, col_ids = array_ops.split(
      value=sp_input.indices, num_or_size_splits=2, axis=1)
  update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
  update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
  col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
  row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

  # `update_indices` selects the factors being solved for; `gather_indices`
  # selects the fixed factors looked up from `right_factors`.
  if transpose_input:
    update_indices = update_col_indices
    row_shape = [
        math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
    ]
    gather_indices = update_row_indices
  else:
    update_indices = update_row_indices
    row_shape = [
        math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
    ]
    gather_indices = update_col_indices

  num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
  col_shape = [num_rows]
  right = embedding_ops.embedding_lookup(
      right_factors, gather_indices, partition_strategy="div")
  # Rebuild the sparse input on the compacted (contiguous) index space.
  new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
  new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                  if transpose_input else
                  array_ops.concat([col_shape, row_shape], 0))
  new_sp_input = sparse_tensor.SparseTensor(
      indices=new_sp_indices,
      values=sp_input.values,
      dense_shape=new_sp_shape)

  # Compute lhs and rhs of the normal equations
  total_lhs = (self._unobserved_weight * gramian)
  if self._regularization_matrix is not None:
    total_lhs += self._regularization_matrix
  if self._row_weights is None:
    # Special case of ALS. Use a much simpler update rule.
    total_rhs = (
        self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
            new_sp_input, right, adjoint_a=transpose_input))
    # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
    # transposing explicitly.
    # TODO(rmlarsen): multi-thread tf.matrix_solve.
    new_left_values = array_ops.transpose(
        linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
  else:
    if row_weights is None:
      # TODO(yifanchen): Add special handling for single shard without using
      # embedding_lookup and perform benchmarks for those cases. Same for
      # col_weights lookup below.
      row_weights_slice = embedding_ops.embedding_lookup(
          row_wt, update_indices, partition_strategy="div")
    else:
      # Caller-supplied weights: a scalar is expanded to one weight per
      # updated index, a rank-1 tensor is cast to float32. Rank > 1 is
      # rejected by the assertion.
      num_indices = array_ops.shape(update_indices)[0]
      with ops.control_dependencies(
          [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
        row_weights_slice = control_flow_ops.cond(
            math_ops.equal(array_ops.rank(row_weights), 0),
            lambda: (array_ops.ones([num_indices]) * row_weights),
            lambda: math_ops.cast(row_weights, dtypes.float32))

    col_weights = embedding_ops.embedding_lookup(
        col_wt, gather_indices, partition_strategy="div")
    partial_lhs, total_rhs = (
        gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
            right,
            col_weights,
            self._unobserved_weight,
            row_weights_slice,
            new_sp_input.indices,
            new_sp_input.values,
            num_rows,
            transpose_input,
            name="wals_compute_partial_lhs_rhs"))
    # Presumably the shared lhs is broadcast over a leading per-row axis of
    # partial_lhs, giving one linear system per updated row -- TODO confirm
    # against the kernel's output shapes.
    total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
    total_rhs = array_ops.expand_dims(total_rhs, -1)
    new_left_values = array_ops.squeeze(
        linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

  update_op_name = "row_update" if update_row_factors else "col_update"
  update_op = self.scatter_update(
      left,
      update_indices,
      new_left_values,
      sharding_func,
      name=update_op_name)

  # Create the loss subgraph
  loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                   if transpose_input else new_sp_input)
  # sp_approx is the low rank estimate of the input matrix, formed by
  # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
  sp_approx_vals = gen_factorization_ops.masked_matmul(
      new_left_values,
      right,
      loss_sp_input.indices,
      transpose_a=False,
      transpose_b=True)
  sp_approx = sparse_tensor.SparseTensor(
      loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)

  sp_approx_sq = math_ops.square(sp_approx)
  sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
  sp_residual_sq = math_ops.square(sp_residual)
  # NOTE(review): row_weights_slice and col_weights are bound only in the
  # weighted branch above (self._row_weights is not None). The expressions
  # below therefore rely on self._row_weights and self._col_weights being
  # either both None or both set -- confirm that invariant with the
  # constructor.
  row_wt_mat = (constant_op.constant(0.)
                if self._row_weights is None else array_ops.expand_dims(
                    row_weights_slice, 1))
  col_wt_mat = (constant_op.constant(0.)
                if self._col_weights is None else array_ops.expand_dims(
                    col_weights, 0))

  # We return the normalized loss
  partial_row_gramian = math_ops.matmul(
      new_left_values, new_left_values, transpose_a=True)
  # Scales minibatch quantities by (total rows) / (rows in this slice).
  normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

  unregularized_loss = (
      self._unobserved_weight * (  # pyformat line break
          sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
          sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
          math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
      sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
  ) * normalization_factor

  if self._regularization is not None:
    regularization = self._regularization * (
        math_ops.trace(partial_row_gramian) * normalization_factor +
        math_ops.trace(gramian))
  else:
    regularization = constant_op.constant(0.)

  sum_weights = self._unobserved_weight * math_ops.cast(
      total_rows * total_cols, dtypes.float32)
  if self._row_weights is not None and self._col_weights is not None:
    # Add the observed-entry weights: a sparse tensor of ones on the observed
    # positions, scaled by the row and column weights.
    ones = sparse_tensor.SparseTensor(
        indices=loss_sp_input.indices,
        values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
        dense_shape=loss_sp_input.dense_shape)
    sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
        ones * col_wt_mat)) * normalization_factor

  return (new_left_values, update_op, unregularized_loss, regularization,
          sum_weights)
def einsum(equation, *inputs, **kwargs):
  """A generalized contraction between tensors of arbitrary dimension.

  The elements of the returned tensor are defined by `equation`, a shorthand
  inspired by the Einstein summation convention. For example, the product C
  of two matrices A and B has elements

  ```
  C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  and the corresponding `equation` is

  ```
  ij,jk->ik
  ```

  In general, the `equation` is obtained from the more familiar element-wise
  equation by
    1. removing variable names, brackets, and commas,
    2. replacing "*" with ",",
    3. dropping summation signs, and
    4. moving the output to the right, and replacing "=" with "->".

  Many common operations can be expressed in this way. For example:

  ```python
  # Matrix multiplication
  >>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

  # Dot product
  >>> einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

  # Outer product
  >>> einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

  # Transpose
  >>> einsum('ij->ji', m)  # output[j,i] = m[i,j]

  # Trace
  >>> einsum('ii', m)  # output = trace(m) = sum_i m[i, i]

  # Batch matrix multiplication
  >>> einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  ```

  To enable and control broadcasting, use an ellipsis. For example, batch
  matrix multiplication can be written as:

  ```python
  >>> einsum('...ij,...jk->...ik', u, v)
  ```

  This function behaves like `numpy.einsum`, but does not support:

  * Subscripts where an axis appears more than once for a single input
    (e.g. `ijj,k->ik`) unless it is a trace (e.g. `ijji`).

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    name: A name for the operation (optional, keyword only).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    TypeError: If a keyword argument other than `name` is passed.
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs implied by `equation` does not match
        `len(inputs)`,
      - an axis appears in the output subscripts but not in any of the
        inputs,
      - the number of dimensions of an input differs from the number of
        indices in its subscript, or
      - the input shapes are inconsistent along a particular axis.
  """
  name = kwargs.pop('name', None)
  if kwargs:
    unexpected = ', '.join(format(key) for key in sorted(kwargs))
    raise TypeError(
        'invalid keyword arguments for this function: ' + unexpected)

  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_axis_labels, output_axis_labels = _einsum_parse_and_resolve_equation(
        equation, [tensor.get_shape() for tensor in inputs])

    all_labels = set(''.join(input_axis_labels) + output_axis_labels)

    # A single palindromic subscript with a doubled label and no explicit
    # output is routed to math_ops.trace; any other repeated label is rejected.
    trace_eligible = len(input_axis_labels) == 1 and '->' not in equation
    for label in all_labels:
      for subscript in input_axis_labels:
        occurrences = subscript.count(label)
        if (trace_eligible and occurrences == 2 and
            subscript == subscript[::-1]):
          return math_ops.trace(inputs[0])
        if occurrences > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              subscript)

    # The pairwise strategy below cannot sum a label shared by three or more
    # inputs, so hand those equations to the fallback implementation.
    for label in all_labels:
      if label in output_axis_labels:
        continue
      uses = sum(1 for subscript in input_axis_labels if label in subscript)
      if uses > 2:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', label)
        return _exponential_space_einsum(equation, *inputs)

    # Use xla_einsum if executing on TPU and if the operation is a 2 input
    # einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(inputs) == 2:
      xla_equation = '%s,%s->%s' % (
          input_axis_labels[0], input_axis_labels[1], output_axis_labels)
      return gen_xla_ops.xla_einsum(inputs[0], inputs[1], xla_equation)

    # Fold the remaining inputs into the running result one at a time.
    result = inputs[0]
    result_labels = input_axis_labels[0]
    for position in xrange(1, len(inputs)):
      operand_labels = input_axis_labels[position]
      # `-` binds tighter than `&`; the parentheses only make that explicit.
      shared_summed = set(result_labels) & (
          set(operand_labels) - set(output_axis_labels))
      result, result_labels = _einsum_reduction(
          result, result_labels, inputs[position], operand_labels,
          shared_summed)

    # Sum out whatever labels are still present but absent from the output.
    if set(result_labels) - set(output_axis_labels):
      leftover_axes = [
          index for index, label in enumerate(result_labels)
          if label not in output_axis_labels
      ]
      result = math_ops.reduce_sum(result, axis=leftover_axes)
      result_labels = ''.join(
          label for label in result_labels if label in output_axis_labels)

    if sorted(result_labels) != sorted(output_axis_labels):
      raise ValueError('Invalid equation: %s' % equation)

    perm = [result_labels.index(label) for label in output_axis_labels]
    return _transpose_if_necessary(result, perm)
def einsum(equation, *inputs, **kwargs):
  """A generalized contraction between tensors of arbitrary dimension.

  This function returns a tensor whose elements are defined by `equation`,
  which is written in a shorthand form inspired by the Einstein summation
  convention.  As an example, consider multiplying two matrices
  A and B to form a matrix C.  The elements of C are given by:

  ```
  C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The corresponding `equation` is:

  ```
  ij,jk->ik
  ```

  In general, the `equation` is obtained from the more familiar element-wise
  equation by
    1. removing variable names, brackets, and commas,
    2. replacing "*" with ",",
    3. dropping summation signs, and
    4. moving the output to the right, and replacing "=" with "->".

  Many common operations can be expressed in this way.  For example:

  ```python
  # Matrix multiplication
  >>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

  # Dot product
  >>> einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

  # Outer product
  >>> einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

  # Transpose
  >>> einsum('ij->ji', m)  # output[j,i] = m[i,j]

  # Trace
  >>> einsum('ii', m)  # output = trace(m) = sum_i m[i, i]

  # Batch matrix multiplication
  >>> einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  ```

  To enable and control broadcasting, use an ellipsis.  For example, to do
  batch matrix multiplication, you could use:

  ```python
  >>> einsum('...ij,...jk->...ik', u, v)
  ```

  This function behaves like `numpy.einsum`, but does not support:

  * Subscripts where an axis appears more than once for a single input
    (e.g. `ijj,k->ik`) unless it is a trace (e.g. `ijji`).

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    name: A name for the operation (optional, keyword only).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    TypeError: If a keyword argument other than `name` is passed.
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs implied by `equation` does not match
        `len(inputs)`,
      - an axis appears in the output subscripts but not in any of the
        inputs,
      - the number of dimensions of an input differs from the number of
        indices in its subscript, or
      - the input shapes are inconsistent along a particular axis.
  """
  name = kwargs.pop('name', None)
  if kwargs:
    raise TypeError('invalid keyword arguments for this function: ' +
                    ', '.join(sorted(kwargs)))
  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = [x.get_shape() for x in inputs]
    input_axis_labels, output_axis_labels = _einsum_parse_and_resolve_equation(
        equation, input_shapes)

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    # A single palindromic subscript with a doubled label and no explicit
    # output is routed to math_ops.trace; any other repeated label is rejected.
    for a in axis_labels:
      for input_labels in input_axis_labels:
        if (len(input_axis_labels) == 1 and input_labels.count(a) == 2 and
            input_labels == input_labels[::-1] and '->' not in equation):
          return math_ops.trace(inputs[0])
        if input_labels.count(a) > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              input_labels)

    # The pairwise strategy below cannot sum a label shared by three or more
    # inputs, so hand those equations to the fallback implementation.
    for a in axis_labels:
      input_count = sum(1 for s in input_axis_labels if a in s)
      if input_count > 2 and a not in output_axis_labels:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', a)
        return _exponential_space_einsum(equation, *inputs)

    # Fold the remaining inputs into the running result one at a time.
    temp = inputs[0]
    temp_axis_labels = input_axis_labels[0]
    for i in xrange(len(inputs) - 1):
      # `-` binds tighter than `&`; the parentheses only make that explicit.
      axes_to_sum = set(temp_axis_labels) & (
          set(input_axis_labels[i + 1]) - set(output_axis_labels))
      temp, temp_axis_labels = _einsum_reduction(
          temp, temp_axis_labels, inputs[i + 1], input_axis_labels[i + 1],
          axes_to_sum)

    # Sum out whatever labels are still present but absent from the output.
    missing_indices = set(temp_axis_labels) - set(output_axis_labels)
    if missing_indices:
      axis = [
          i for i, a in enumerate(temp_axis_labels)
          if a not in output_axis_labels
      ]
      # `axis` replaces the deprecated `reduction_indices` keyword.
      temp = math_ops.reduce_sum(temp, axis=axis)
      temp_axis_labels = ''.join(
          a for a in temp_axis_labels if a in output_axis_labels)

    if sorted(temp_axis_labels) != sorted(output_axis_labels):
      raise ValueError('Invalid equation: %s' % equation)

    perm = [temp_axis_labels.index(a) for a in output_axis_labels]
    return _transpose_if_necessary(temp, perm)
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
  """Classifier distance for evaluating a conditional generative model.

  This is based on the Frechet Inception distance, but for an arbitrary
  classifier.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance
  matrices C and C_w, this function calculates

        |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Args:
    real_images: Real images to use to compute Frechet Inception distance.
    generated_images: Generated images to use to compute Frechet Inception
      distance.
    classifier_fn: A function that takes images and produces activations
      based on a classifier.
    num_batches: Number of batches to split images into in order to
      efficiently run them through the classifier network.

  Returns:
    The Frechet Inception distance. A floating-point scalar.
  """
  real_images_list = array_ops.split(
      real_images, num_or_size_splits=num_batches)
  generated_images_list = array_ops.split(
      generated_images, num_or_size_splits=num_batches)

  imgs = array_ops.stack(real_images_list + generated_images_list)

  # Compute the activations using the memory-efficient `map_fn`.
  activations = functional_ops.map_fn(
      fn=classifier_fn,
      elems=imgs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')

  # Split the activations by the real and generated images.
  real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)

  # Ensure the activations have the right shapes.
  real_a = array_ops.concat(array_ops.unstack(real_a), 0)
  gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
  real_a.shape.assert_has_rank(2)
  gen_a.shape.assert_has_rank(2)

  # Compute mean and covariance matrices of activations.
  m = math_ops.reduce_mean(real_a, 0)
  m_v = math_ops.reduce_mean(gen_a, 0)
  # Number of samples (rows). The stack above requires real and generated
  # batches to have the same shape, so one count serves both sets.
  num_examples = math_ops.to_float(array_ops.shape(real_a)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu), a [features, features]
  # matrix. Each set is centered with its OWN mean, and `transpose_a` (not
  # `transpose_b`) is what yields a covariance instead of an [n, n] Gram
  # matrix.
  real_centered = real_a - m
  gen_centered = gen_a - m_v
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / (num_examples - 1)
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / (num_examples - 1)

  # Take matrix square root of the product of covariance matrices.
  sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v))

  # Compute the two components of FID.
  trace = math_ops.trace(sigma + sigma_v - 2.0 * sqcc)
  mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
  fid = trace + mean

  return fid
def compare(self, x):
  """Asserts that `math_ops.trace(x)` agrees with NumPy's trace of `x`.

  The NumPy reference takes the trace over the last two axes of `x`.
  """
  expected = np.trace(x, axis1=-2, axis2=-1)
  with self.cached_session():
    actual = math_ops.trace(x).eval()
  self.assertAllClose(actual, expected)
def frechet_classifier_distance_from_activations(real_activations,
                                                 generated_activations):
  """Frechet classifier distance computed from precomputed activations.

  Use this instead of frechet_classifier_distance() when the activations of
  the real and generated images are already available, e.g. when evaluation
  uses large batches and all activations are computed up front.

  The technique is described in https://arxiv.org/abs/1706.08500. For two
  Gaussian distributions with means m and m_w and covariance matrices C and
  C_w, the value computed is

        |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which measures how far apart the distributions of real and generated images
  (more precisely, of their classifier features) are. Unlike the Inception
  score, this is a true distance and it uses information about real images.

  When sample means and sample covariances are used, the Frechet distance is
  biased, and more so for small samples (even for identical distributions the
  expected distance is large when the sample is small). Use the same sample
  size whenever two generative models are compared.

  Args:
    real_activations: 2D Tensor containing activations of real data. Shape is
      [batch_size, activation_size].
    generated_activations: 2D Tensor containing activations of generated data.
      Shape is [batch_size, activation_size].

  Returns:
    The Frechet Inception distance. A floating-point scalar of the same type
    as the output of the activations.
  """

  def _mean_and_covariance(activations):
    """Returns the column mean and (1 / (n - 1)) * (X - mu)^T (X - mu)."""
    mean = math_ops.reduce_mean(activations, 0)
    count = math_ops.to_double(array_ops.shape(activations)[0])
    centered = activations - mean
    covariance = math_ops.matmul(
        centered, centered, transpose_a=True) / (count - 1)
    return mean, covariance

  real_activations.shape.assert_has_rank(2)
  generated_activations.shape.assert_has_rank(2)

  # Work in float64 and cast the final scalar back to the caller's dtype.
  activations_dtype = real_activations.dtype
  needs_cast = activations_dtype != dtypes.float64
  if needs_cast:
    real_activations = math_ops.to_double(real_activations)
    generated_activations = math_ops.to_double(generated_activations)

  # Each set is normalized by its own sample count.
  m, sigma = _mean_and_covariance(real_activations)
  m_w, sigma_w = _mean_and_covariance(generated_activations)

  # Tr(sqrt(sigma sigma_w)) is computed by a dedicated helper.
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)

  # Covariance term, using trace(A + B) = trace(A) + trace(B).
  covariance_term = (
      math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component)

  # Squared distance between the means; equivalent to the squared L2 norm but
  # more stable.
  mean_term = math_ops.reduce_sum(math_ops.squared_difference(m, m_w))

  fid = covariance_term + mean_term
  if needs_cast:
    fid = math_ops.cast(fid, activations_dtype)
  return fid
def _compute_pi_tracenorm(left_cov, right_cov):
  """Computes the trace-norm ratio pi of two Kronecker factors.

    pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )

  where A is `left_cov`, B is `right_cov`, and dim is the statically known
  size of a factor's first axis.

  Args:
    left_cov: The left Kronecker factor "covariance".
    right_cov: The right Kronecker factor "covariance".

  Returns:
    The scalar pi for these factors (as a Tensor).
  """
  left_dim = left_cov.shape.as_list()[0]
  right_dim = right_cov.shape.as_list()[0]
  # Cross-multiplying by the other factor's dim gives the same ratio as
  # dividing each trace by its own dim.
  scaled_left = math_ops.trace(left_cov) * right_dim
  scaled_right = math_ops.trace(right_cov) * left_dim
  return math_ops.sqrt(scaled_left / scaled_right)