def testMaximumSpanningTreeGradient(self):
  """Tests the MST max score gradient.

  Builds a named MST op over a batch of two score matrices, calls
  `mst_ops.maximum_spanning_tree_gradient` on that op with a fixed upstream
  gradient, and checks that no gradient is returned for `num_nodes` and that
  the gradient w.r.t. `scores` matches the expected values.
  """
  with self.test_session() as session:
    num_nodes = tf.constant([4, 3], tf.int32)
    scores = tf.constant([[[0, 0, 0, 0],
                           [1, 0, 0, 0],
                           [1, 2, 0, 0],
                           [1, 2, 3, 4]],
                          [[4, 3, 2, 9],
                           [0, 0, 2, 9],
                           [0, 0, 0, 9],
                           [9, 9, 9, 9]]], tf.int32)  # pyformat: disable

    # The outputs are not used directly; the op itself is fetched back from
    # the graph by name so it can be handed to the gradient function.
    mst_ops.maximum_spanning_tree(num_nodes, scores, forest=False, name='MST')
    mst_op = session.graph.get_operation_by_name('MST')

    d_loss_d_max_scores = tf.constant([3, 7], tf.float32)
    d_loss_d_num_nodes, d_loss_d_scores = (
        mst_ops.maximum_spanning_tree_gradient(mst_op, d_loss_d_max_scores))

    # The num_nodes input is non-differentiable.
    self.assertIsNone(d_loss_d_num_nodes)

    # Evaluate once and reuse the value for both the log line and the check.
    d_loss_d_scores_value = d_loss_d_scores.eval()
    tf.logging.info('\nd_loss_d_scores=\n%s', d_loss_d_scores_value)
    self.assertAllEqual(d_loss_d_scores_value,
                        [[[0, 0, 0, 3],
                          [3, 0, 0, 0],
                          [0, 3, 0, 0],
                          [0, 0, 0, 3]],
                         [[7, 0, 0, 0],
                          [0, 0, 7, 0],
                          [7, 0, 0, 0],
                          [0, 0, 0, 0]]])  # pyformat: disable
def testMaximumSpanningTreeGradient(self):
  """Tests the MST max score gradient.

  A named MST op is created for a batch of two score matrices and then passed
  to `mst_ops.maximum_spanning_tree_gradient` together with a fixed upstream
  gradient. The test checks that the `num_nodes` gradient is None and that
  the `scores` gradient equals the expected values.
  """
  with self.test_session() as session:
    num_nodes = tf.constant([4, 3], tf.int32)
    scores = tf.constant(
        [[[0, 0, 0, 0],
          [1, 0, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 3, 4]],
         [[4, 3, 2, 9],
          [0, 0, 2, 9],
          [0, 0, 0, 9],
          [9, 9, 9, 9]]], tf.int32)  # pyformat: disable

    # Create the op under a fixed name, then look the op up by that name.
    mst_ops.maximum_spanning_tree(num_nodes, scores, forest=False, name='MST')
    mst_op = session.graph.get_operation_by_name('MST')

    d_loss_d_max_scores = tf.constant([3, 7], tf.float32)
    d_loss_d_num_nodes, d_loss_d_scores = (
        mst_ops.maximum_spanning_tree_gradient(mst_op, d_loss_d_max_scores))

    # The num_nodes input is non-differentiable.
    self.assertIsNone(d_loss_d_num_nodes)

    # Run the gradient tensor a single time; log and assert on that value.
    actual_d_loss_d_scores = d_loss_d_scores.eval()
    tf.logging.info('\nd_loss_d_scores=\n%s', actual_d_loss_d_scores)
    self.assertAllEqual(
        actual_d_loss_d_scores,
        [[[0, 0, 0, 3],
          [3, 0, 0, 0],
          [0, 3, 0, 0],
          [0, 0, 0, 3]],
         [[7, 0, 0, 0],
          [0, 0, 7, 0],
          [7, 0, 0, 0],
          [0, 0, 0, 0]]])  # pyformat: disable
def testMaximumSpanningTreeGradientError(self):
  """Numerically validates the max score gradient.

  Compares the analytic gradient of the max score w.r.t. `scores` against a
  numeric estimate via `tf.test.compute_gradient_error` and asserts that the
  reported error is small.
  """
  with self.test_session():
    # The maximum-spanning-tree-score function, as a max of linear functions,
    # is piecewise-linear (i.e., faceted).  The numerical gradient estimate
    # may be inaccurate if the epsilon ball used for the estimate crosses an
    # edge from one facet to another.  To avoid spurious errors, we manually
    # set the sample point so the epsilon ball fits in a facet.  Or in other
    # words, we set the scores so there is a non-trivial margin between the
    # best and second-best trees.
    scores_raw = [[[0, 0, 0, 0],
                   [1, 0, 0, 0],
                   [1, 2, 0, 0],
                   [1, 2, 3, 4]],
                  [[4, 3, 2, 9],
                   [0, 0, 2, 9],
                   [0, 0, 0, 9],
                   [9, 9, 9, 9]]]  # pyformat: disable

    # Use 64-bit floats to reduce numerical error.
    scores = tf.constant(scores_raw, tf.float64)
    # Keep the sample point in the same dtype as the tensor under test rather
    # than relying on numpy's integer default for an all-int literal.
    init_scores = np.array(scores_raw, dtype=np.float64)

    num_nodes = tf.constant([4, 3], tf.int32)
    max_scores = mst_ops.maximum_spanning_tree(
        num_nodes, scores, forest=False)[0]

    gradient_error = tf.test.compute_gradient_error(
        scores, [2, 4, 4], max_scores, [2], init_scores)
    tf.logging.info('gradient_error=%s', gradient_error)

    # Without an assertion this test could never fail.  The score is locally
    # linear at the sample point, so the numeric and analytic Jacobians should
    # agree up to rounding; the bound is deliberately loose.
    self.assertLess(gradient_error, 1e-4)
def create(self,
           fixed_embeddings,
           linked_embeddings,
           context_tensor_arrays,
           attention_tensor,
           during_training,
           stride=None):
  """Forwards the lengths and scores, along with one-hot MST arcs.

  Looks up the 'lengths' and 'scores' inputs among the linked embeddings,
  reshapes the scores into one square matrix per batch element, runs the
  maximum spanning tree op on them, and one-hot encodes the selected source
  of every token.

  Args:
    fixed_embeddings: Not used by this method.
    linked_embeddings: Collection searched for the tensors named 'lengths'
      and 'scores'.
    context_tensor_arrays: Not used by this method.
    attention_tensor: Not used by this method.
    during_training: Not used by this method.
    stride: Leading dimension used when reshaping the scores; checked with
      `check.NotNone` (presumably rejecting None -- confirm against `check`).

  Returns:
    A list of four tensors: the int32 lengths with axis 1 squeezed out, the
    scores reshaped to [stride, max_length, max_length], the scores as they
    were looked up, and the float32 one-hot encoding (depth max_length) of
    the flattened argmax sources.
  """
  check.NotNone(stride, 'MstSolverNetwork requires stride')

  # Sequence lengths: drop axis 1 and convert to int32.
  lengths_input = network_units.lookup_named_tensor('lengths',
                                                    linked_embeddings)
  lengths_b = tf.to_int32(tf.squeeze(lengths_input.tensor, [1]))

  # Arc scores: dimension 1 of the looked-up tensor gives the padded length.
  scores_input = network_units.lookup_named_tensor('scores', linked_embeddings)
  scores_bnxn = scores_input.tensor
  max_length = tf.shape(scores_bnxn)[1]
  batched_shape = [stride, max_length, max_length]
  scores_bxnxn = tf.reshape(scores_bnxn, batched_shape)

  # Only the argmax sources are needed here; the max scores are discarded.
  _, sources_bxn = mst_ops.maximum_spanning_tree(
      forest=self._attrs['forest'], num_nodes=lengths_b, scores=scores_bxnxn)

  # Flatten the sources and encode each as a one-hot row over max_length.
  sources_bn = tf.reshape(sources_bxn, [-1])
  arcs_bnxn = tf.one_hot(sources_bn, max_length, dtype=tf.float32)

  return [lengths_b, scores_bxnxn, scores_bnxn, arcs_bnxn]
def testMaximumSpanningTreeGradientError(self):
  """Numerically validates the max score gradient.

  Uses `tf.test.compute_gradient_error` to compare the analytic gradient of
  the max score w.r.t. `scores` with a numeric estimate, and asserts that the
  resulting error stays below a small bound.
  """
  with self.test_session():
    # The maximum-spanning-tree-score function, as a max of linear functions,
    # is piecewise-linear (i.e., faceted).  The numerical gradient estimate
    # may be inaccurate if the epsilon ball used for the estimate crosses an
    # edge from one facet to another.  To avoid spurious errors, we manually
    # set the sample point so the epsilon ball fits in a facet.  Or in other
    # words, we set the scores so there is a non-trivial margin between the
    # best and second-best trees.
    scores_raw = [
        [[0, 0, 0, 0],
         [1, 0, 0, 0],
         [1, 2, 0, 0],
         [1, 2, 3, 4]],
        [[4, 3, 2, 9],
         [0, 0, 2, 9],
         [0, 0, 0, 9],
         [9, 9, 9, 9]],
    ]  # pyformat: disable

    # Use 64-bit floats to reduce numerical error.
    scores = tf.constant(scores_raw, tf.float64)
    # Give the initial value the tensor's dtype explicitly; an all-integer
    # literal would otherwise become an integer numpy array.
    init_scores = np.array(scores_raw, dtype=np.float64)

    num_nodes = tf.constant([4, 3], tf.int32)
    max_scores = mst_ops.maximum_spanning_tree(
        num_nodes, scores, forest=False)[0]

    gradient_error = tf.test.compute_gradient_error(
        scores, [2, 4, 4], max_scores, [2], init_scores)
    tf.logging.info('gradient_error=%s', gradient_error)

    # Logging alone validates nothing, so bound the error.  At this sample
    # point the score is locally linear, so the numeric and analytic
    # Jacobians should match up to rounding; the bound is deliberately loose.
    self.assertLess(gradient_error, 1e-4)
def testMaximumSpanningTree(self):
  """Tests that the MST op can recover a simple tree."""
  # The first batch element prefers 3 as root, then 3->0->1->2, for a total
  # score of 4+2+1=7.  The second batch element is smaller and has reversed
  # scores, so 0 is root and 0->2->1.
  batch_scores = [[[0, 0, 0, 0],
                   [1, 0, 0, 0],
                   [1, 2, 0, 0],
                   [1, 2, 3, 4]],
                  [[4, 3, 2, 9],
                   [0, 0, 2, 9],
                   [0, 0, 0, 9],
                   [9, 9, 9, 9]]]  # pyformat: disable
  expected_max_scores = [7, 6]
  expected_sources = [[3, 0, 1, 3],
                      [0, 2, 0, -1]]  # pyformat: disable

  with self.test_session() as session:
    num_nodes = tf.constant([4, 3], tf.int32)
    scores = tf.constant(batch_scores, tf.int32)

    # Run both outputs of the op in a single session call.
    max_scores, argmax_sources = session.run(
        mst_ops.maximum_spanning_tree(num_nodes, scores, forest=False))

    tf.logging.info('\nmax_scores=%s\nargmax_sources=\n%s', max_scores,
                    argmax_sources)

    self.assertAllEqual(max_scores, expected_max_scores)
    self.assertAllEqual(argmax_sources, expected_sources)
def _compute_m3n_loss(self, lengths, scores, gold):
  """Computes the M3N-style structured hinge loss for a batch.

  The result is the max score found by MST inference over the
  Hamming-loss-augmented scores, minus the score of the gold structure.

  Args:
    lengths: Passed to the MST op as `num_nodes`.
    scores: Arc scores; summed over axes 1 and 2, so presumably shaped
      [batch, n, n] -- confirm against callers.
    gold: Gold-arc tensor multiplied element-wise with `scores`; presumably a
      0/1 indicator of the same shape -- confirm against callers.

  Returns:
    The augmented max score minus the gold score, one value per batch
    element.
  """
  # Total score of the gold structure, reduced over the two trailing axes.
  gold_total_b = tf.reduce_sum(scores * gold, axis=[1, 2])

  # Perform hamming-loss-augmented inference: add (1 - gold) to the scores
  # before searching for the best tree.
  augmented_bxnxn = scores + (1 - gold)
  augmented_max_b, _ = mst_ops.maximum_spanning_tree(
      num_nodes=lengths,
      scores=augmented_bxnxn,
      forest=self._attrs['forest'])

  return augmented_max_b - gold_total_b
def testMaximumSpanningTree(self):
  """Tests that the MST op can recover a simple tree."""
  with self.test_session() as session:
    # The first batch element prefers 3 as root, then 3->0->1->2, for a total
    # score of 4+2+1=7.  The second batch element is smaller and has reversed
    # scores, so 0 is root and 0->2->1.
    node_counts = tf.constant([4, 3], tf.int32)
    arc_scores = tf.constant(
        [[[0, 0, 0, 0],
          [1, 0, 0, 0],
          [1, 2, 0, 0],
          [1, 2, 3, 4]],
         [[4, 3, 2, 9],
          [0, 0, 2, 9],
          [0, 0, 0, 9],
          [9, 9, 9, 9]]], tf.int32)  # pyformat: disable

    outputs = mst_ops.maximum_spanning_tree(
        node_counts, arc_scores, forest=False)
    best_scores, best_sources = session.run(outputs)

    tf.logging.info('\nmax_scores=%s\nargmax_sources=\n%s', best_scores,
                    best_sources)

    # One max score per batch element.
    self.assertAllEqual(best_scores, [7, 6])
    # One selected source per token; the second batch element has a -1 in
    # its last slot.
    self.assertAllEqual(
        best_sources,
        [[3, 0, 1, 3],
         [0, 2, 0, -1]])  # pyformat: disable