def testMaximumSpanningTreeGradient(self):
  """Tests the MST max score gradient.

  Feeds upstream gradients of 3 and 7 for the two batch elements' max scores.
  It checks two things:
    - The `num_nodes` input gets no gradient.
    - Each score entry on the chosen tree gets that batch element's upstream
      gradient, and every other entry gets zero.
  """
  with self.test_session() as sess:
    op_name = 'MST'
    num_nodes = constant_op.constant([4, 3], dtypes.int32)
    scores = constant_op.constant(
        [[[0, 0, 0, 0],
          [1, 0, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 3, 4]],
         [[4, 3, 2, 9],
          [0, 0, 2, 9],
          [0, 0, 0, 9],
          [9, 9, 9, 9]]],
        dtypes.int32)  # pyformat: disable

    # Give the op a fixed name so its Operation object can be looked up and
    # handed to the gradient function directly.
    mst_ops.max_spanning_tree(num_nodes, scores, forest=False, name=op_name)
    mst_op = sess.graph.get_operation_by_name(op_name)

    upstream_grads = constant_op.constant([3, 7], dtypes.float32)
    grad_num_nodes, grad_scores = mst_ops.max_spanning_tree_gradient(
        mst_op, upstream_grads)

    # The num_nodes input is non-differentiable.
    self.assertIsNone(grad_num_nodes)

    expected_grad_scores = [[[0, 0, 0, 3],
                             [3, 0, 0, 0],
                             [0, 3, 0, 0],
                             [0, 0, 0, 3]],
                            [[7, 0, 0, 0],
                             [0, 0, 7, 0],
                             [7, 0, 0, 0],
                             [0, 0, 0, 0]]]  # pyformat: disable
    self.assertAllEqual(grad_scores.eval(), expected_grad_scores)
def testMaximumSpanningTreeGradientError(self):
  """Numerically validates the max score gradient.

  Compares the analytic gradient of the max spanning tree score, taken with
  respect to the arc scores, against a finite-difference estimate. The test
  requires the two to agree closely.
  """
  with self.test_session():
    # The maximum-spanning-tree-score function, as a max of linear functions,
    # is piecewise-linear (i.e., faceted).  The numerical gradient estimate
    # may be inaccurate if the epsilon ball used for the estimate crosses an
    # edge from one facet to another.  To avoid spurious errors, we manually
    # set the sample point so the epsilon ball fits in a facet.  Or in other
    # words, we set the scores so there is a non-trivial margin between the
    # best and second-best trees.
    scores_raw = [[[0, 0, 0, 0],
                   [1, 0, 0, 0],
                   [1, 2, 0, 0],
                   [1, 2, 3, 4]],
                  [[4, 3, 2, 9],
                   [0, 0, 2, 9],
                   [0, 0, 0, 9],
                   [9, 9, 9, 9]]]  # pyformat: disable

    # Use 64-bit floats to reduce numerical error.
    scores = constant_op.constant(scores_raw, dtypes.float64)
    # Keep the sample point in the same dtype as `scores`, so the
    # finite-difference perturbations act on float64 values.
    init_scores = np.array(scores_raw, dtype=np.float64)

    num_nodes = constant_op.constant([4, 3], dtypes.int32)
    max_scores = mst_ops.max_spanning_tree(num_nodes, scores, forest=False)[0]

    gradient_error = test.compute_gradient_error(
        scores, [2, 4, 4], max_scores, [2], init_scores)

    # `compute_gradient_error` always returns a number, so a non-None check
    # verifies nothing.  The sample point lies inside one linear facet, so
    # the analytic and numerical Jacobians should match up to float64
    # rounding error.
    self.assertLess(gradient_error, 1e-6)
def testMaximumSpanningTree(self):
  """Tests that the MST op can recover a simple tree.

  Batch element 0 (4 nodes) prefers node 3 as the root, then 3->0->1->2, for
  a total score of 4+2+1=7.

  Batch element 1 is smaller (3 nodes) and has reversed scores, so node 0 is
  the root and the tree is 0->2->1.  Its source entry beyond `num_nodes` is
  expected to be -1.
  """
  num_nodes = constant_op.constant([4, 3], dtypes.int32)
  scores = constant_op.constant(
      [[[0, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 2, 0, 0],
        [1, 2, 3, 4]],
       [[4, 3, 2, 9],
        [0, 0, 2, 9],
        [0, 0, 0, 9],
        [9, 9, 9, 9]]],
      dtypes.int32)  # pyformat: disable

  mst_outputs = mst_ops.max_spanning_tree(num_nodes, scores, forest=False)
  max_scores, argmax_sources = mst_outputs

  expected_max_scores = [7, 6]
  expected_sources = [[3, 0, 1, 3],
                      [0, 2, 0, -1]]  # pyformat: disable
  self.assertAllEqual(max_scores, expected_max_scores)
  self.assertAllEqual(argmax_sources, expected_sources)