Example #1
0
  def testMaximumSpanningTreeGradient(self):
    """Tests the MST max score gradient.

    Builds a named MST op over a batch of two score matrices and asks
    `mst_ops.maximum_spanning_tree_gradient` for the input gradients given
    upstream gradients of 3 and 7 on the two max scores.  The expected
    `d_loss_d_scores` holds 3 (resp. 7) on the arcs selected for batch
    element 0 (resp. 1) and zero everywhere else.
    """
    with self.test_session() as session:
      num_nodes = tf.constant([4, 3], tf.int32)
      scores = tf.constant([[[0, 0, 0, 0],
                             [1, 0, 0, 0],
                             [1, 2, 0, 0],
                             [1, 2, 3, 4]],
                            [[4, 3, 2, 9],
                             [0, 0, 2, 9],
                             [0, 0, 0, 9],
                             [9, 9, 9, 9]]], tf.int32)  # pyformat: disable

      # Name the op so its tf.Operation can be fetched from the graph and
      # handed to the gradient function directly.
      mst_ops.maximum_spanning_tree(num_nodes, scores, forest=False, name='MST')
      mst_op = session.graph.get_operation_by_name('MST')

      d_loss_d_max_scores = tf.constant([3, 7], tf.float32)
      d_loss_d_num_nodes, d_loss_d_scores = (
          mst_ops.maximum_spanning_tree_gradient(mst_op, d_loss_d_max_scores))

      # The num_nodes input is non-differentiable.
      self.assertIsNone(d_loss_d_num_nodes)

      # Evaluate once and reuse the value for both logging and the check.
      d_loss_d_scores_value = d_loss_d_scores.eval()
      tf.logging.info('\nd_loss_d_scores=\n%s', d_loss_d_scores_value)

      self.assertAllEqual(d_loss_d_scores_value,
                          [[[0, 0, 0, 3],
                            [3, 0, 0, 0],
                            [0, 3, 0, 0],
                            [0, 0, 0, 3]],
                           [[7, 0, 0, 0],
                            [0, 0, 7, 0],
                            [7, 0, 0, 0],
                            [0, 0, 0, 0]]])  # pyformat: disable
Example #2
0
    def testMaximumSpanningTreeGradient(self):
        """Tests the MST max score gradient.

        Builds a named MST op over a batch of two score matrices and asks
        `mst_ops.maximum_spanning_tree_gradient` for the input gradients
        given upstream gradients of 3 and 7 on the two max scores.  The
        expected `d_loss_d_scores` holds 3 (resp. 7) on the arcs selected for
        batch element 0 (resp. 1) and zero everywhere else.
        """
        with self.test_session() as session:
            num_nodes = tf.constant([4, 3], tf.int32)
            scores = tf.constant(
                [[[0, 0, 0, 0], [1, 0, 0, 0], [1, 2, 0, 0], [1, 2, 3, 4]],
                 [[4, 3, 2, 9], [0, 0, 2, 9], [0, 0, 0, 9], [9, 9, 9, 9]]],
                tf.int32)  # pyformat: disable

            # Name the op so its tf.Operation can be fetched from the graph
            # and handed to the gradient function directly.
            mst_ops.maximum_spanning_tree(num_nodes,
                                          scores,
                                          forest=False,
                                          name='MST')
            mst_op = session.graph.get_operation_by_name('MST')

            d_loss_d_max_scores = tf.constant([3, 7], tf.float32)
            d_loss_d_num_nodes, d_loss_d_scores = (
                mst_ops.maximum_spanning_tree_gradient(mst_op,
                                                       d_loss_d_max_scores))

            # The num_nodes input is non-differentiable.
            self.assertIsNone(d_loss_d_num_nodes)

            # Evaluate once and reuse the value for both logging and the check.
            d_loss_d_scores_value = d_loss_d_scores.eval()
            tf.logging.info('\nd_loss_d_scores=\n%s', d_loss_d_scores_value)

            self.assertAllEqual(
                d_loss_d_scores_value,
                [[[0, 0, 0, 3], [3, 0, 0, 0], [0, 3, 0, 0], [0, 0, 0, 3]],
                 [[7, 0, 0, 0], [0, 0, 7, 0], [7, 0, 0, 0], [0, 0, 0, 0]]
                 ])  # pyformat: disable
Example #3
0
    def testMaximumSpanningTreeGradientError(self):
        """Numerically validates the max score gradient.

        Compares the analytic Jacobian of the max tree score with respect to
        `scores` against a finite-difference estimate (via
        `tf.test.compute_gradient_error`) and fails if they disagree by more
        than a small tolerance.
        """
        with self.test_session():
            # The maximum-spanning-tree-score function, as a max of linear functions,
            # is piecewise-linear (i.e., faceted).  The numerical gradient estimate
            # may be inaccurate if the epsilon ball used for the estimate crosses an
            # edge from one facet to another.  To avoid spurious errors, we manually
            # set the sample point so the epsilon ball fits in a facet.  Or in other
            # words, we set the scores so there is a non-trivial margin between the
            # best and second-best trees.
            scores_raw = [[[0, 0, 0, 0], [1, 0, 0, 0], [1, 2, 0, 0],
                           [1, 2, 3, 4]],
                          [[4, 3, 2, 9], [0, 0, 2, 9], [0, 0, 0, 9],
                           [9, 9, 9, 9]]]  # pyformat: disable

            # Use 64-bit floats to reduce numerical error.
            scores = tf.constant(scores_raw, tf.float64)
            # Match the dtype of `scores` so the finite-difference probe starts
            # from exactly the same float64 point.
            init_scores = np.array(scores_raw, dtype=np.float64)

            num_nodes = tf.constant([4, 3], tf.int32)
            max_scores = mst_ops.maximum_spanning_tree(num_nodes,
                                                       scores,
                                                       forest=False)[0]

            gradient_error = tf.test.compute_gradient_error(
                scores, [2, 4, 4], max_scores, [2], init_scores)
            tf.logging.info('gradient_error=%s', gradient_error)

            # Assert instead of only logging so a wrong gradient fails the
            # test.  Inside a facet the score is linear in `scores`, so the
            # numeric and analytic Jacobians should agree up to rounding.
            self.assertLess(gradient_error, 1e-6)
Example #4
0
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Solves for the maximum spanning tree over the linked 'scores'.

        Reads the linked 'lengths' and 'scores' tensors and reshapes the
        scores to [stride, n, n], where n is their second dimension.  It then
        runs `mst_ops.maximum_spanning_tree` and one-hot encodes the chosen
        source of every node.  Only `linked_embeddings`, `stride`, and
        `self._attrs['forest']` are consulted here.

        Args:
          fixed_embeddings: Not used here.
          linked_embeddings: Must provide tensors named 'lengths' and
            'scores'.
          context_tensor_arrays: Not used here.
          attention_tensor: Not used here.
          during_training: Not used here.
          stride: Leading dimension of the reshaped scores; must not be None.

        Returns:
          [int32 lengths, scores reshaped to [stride, n, n], the linked
          scores tensor as given, float32 one-hot arcs of depth n].
        """
        check.NotNone(stride, 'MstSolverNetwork requires stride')

        # 'lengths' is squeezed along axis 1 and cast to int32 for the solver.
        lengths_input = network_units.lookup_named_tensor(
            'lengths', linked_embeddings).tensor
        node_counts = tf.to_int32(tf.squeeze(lengths_input, [1]))

        # Fold the linked scores into `stride` square matrices whose side is
        # the scores' second dimension.
        flat_scores = network_units.lookup_named_tensor(
            'scores', linked_embeddings).tensor
        side = tf.shape(flat_scores)[1]
        batched_scores = tf.reshape(flat_scores, [stride, side, side])

        _, best_sources = mst_ops.maximum_spanning_tree(
            num_nodes=node_counts,
            scores=batched_scores,
            forest=self._attrs['forest'])

        # Flatten the per-node sources and expand each into a one-hot row.
        selected_arcs = tf.one_hot(tf.reshape(best_sources, [-1]),
                                   side,
                                   dtype=tf.float32)

        return [node_counts, batched_scores, flat_scores, selected_arcs]
Example #5
0
  def testMaximumSpanningTreeGradientError(self):
    """Numerically validates the max score gradient.

    Compares the analytic Jacobian of the max tree score with respect to
    `scores` against a finite-difference estimate (via
    `tf.test.compute_gradient_error`) and fails if they disagree by more than
    a small tolerance.
    """
    with self.test_session():
      # The maximum-spanning-tree-score function, as a max of linear functions,
      # is piecewise-linear (i.e., faceted).  The numerical gradient estimate
      # may be inaccurate if the epsilon ball used for the estimate crosses an
      # edge from one facet to another.  To avoid spurious errors, we manually
      # set the sample point so the epsilon ball fits in a facet.  Or in other
      # words, we set the scores so there is a non-trivial margin between the
      # best and second-best trees.
      scores_raw = [[[0, 0, 0, 0],
                     [1, 0, 0, 0],
                     [1, 2, 0, 0],
                     [1, 2, 3, 4]],
                    [[4, 3, 2, 9],
                     [0, 0, 2, 9],
                     [0, 0, 0, 9],
                     [9, 9, 9, 9]]]  # pyformat: disable

      # Use 64-bit floats to reduce numerical error.
      scores = tf.constant(scores_raw, tf.float64)
      # Match the dtype of `scores` so the finite-difference probe starts from
      # exactly the same float64 point.
      init_scores = np.array(scores_raw, dtype=np.float64)

      num_nodes = tf.constant([4, 3], tf.int32)
      max_scores = mst_ops.maximum_spanning_tree(
          num_nodes, scores, forest=False)[0]

      gradient_error = tf.test.compute_gradient_error(
          scores, [2, 4, 4], max_scores, [2], init_scores)
      tf.logging.info('gradient_error=%s', gradient_error)

      # Assert instead of only logging so a wrong gradient fails the test.
      # Inside a facet the score is linear in `scores`, so the numeric and
      # analytic Jacobians should agree up to rounding.
      self.assertLess(gradient_error, 1e-6)
Example #6
0
  def testMaximumSpanningTree(self):
    """Checks that the MST op recovers a known tree for each batch element.

    Batch element 0 (4 nodes) should pick 3 as root with arcs 3->0->1->2, for
    a total score of 4+2+1=7.  Batch element 1 (3 nodes) has roughly reversed
    scores, so 0 is root with arcs 0->2->1 (total 6).  Its unused fourth slot
    is expected to report -1 as its source.
    """
    with self.test_session() as session:
      num_nodes = tf.constant([4, 3], tf.int32)
      score_matrices = [[[0, 0, 0, 0],
                         [1, 0, 0, 0],
                         [1, 2, 0, 0],
                         [1, 2, 3, 4]],
                        [[4, 3, 2, 9],
                         [0, 0, 2, 9],
                         [0, 0, 0, 9],
                         [9, 9, 9, 9]]]  # pyformat: disable
      scores = tf.constant(score_matrices, tf.int32)

      solved_scores, solved_sources = session.run(
          mst_ops.maximum_spanning_tree(num_nodes, scores, forest=False))
      tf.logging.info('\nmax_scores=%s\nargmax_sources=\n%s', solved_scores,
                      solved_sources)

      expected_scores = [7, 6]
      expected_sources = [[3, 0, 1, 3],
                          [0, 2, 0, -1]]  # pyformat: disable
      self.assertAllEqual(solved_scores, expected_scores)
      self.assertAllEqual(solved_sources, expected_sources)
Example #7
0
 def _compute_m3n_loss(self, lengths, scores, gold):
   """Returns the M3N-style structured hinge loss for a batch.

   The loss is the best loss-augmented tree score minus the gold tree score.
   Augmentation adds (1 - gold) to every arc score.  Assuming `gold` is a 0/1
   arc indicator, this is a +1 bonus on each non-gold arc.

   Args:
     lengths: Passed to the MST solver as `num_nodes`.
     scores: Arc scores, reduced over axes 1 and 2 per leading-axis element;
       presumably (batch, n, n) -- TODO confirm.
     gold: Gold arc indicators, broadcast-compatible with `scores`.

   Returns:
     One hinge loss per element along the leading axis.
   """
   # Score of the gold tree: sum of scores on gold arcs, per leading index.
   gold_total = tf.reduce_sum(scores * gold, axis=[1, 2])

   # Hamming-loss-augmented inference: reward arcs that disagree with gold.
   augmented_scores = scores + (1 - gold)
   best_total, _ = mst_ops.maximum_spanning_tree(
       num_nodes=lengths, scores=augmented_scores,
       forest=self._attrs['forest'])

   return best_total - gold_total
Example #8
0
 def _compute_m3n_loss(self, lengths, scores, gold):
   """Returns the M3N-style structured hinge loss for a batch.

   The loss is the best loss-augmented tree score minus the gold tree score.
   Augmentation adds (1 - gold) to every arc score.  Assuming `gold` is a 0/1
   arc indicator, this is a +1 bonus on each non-gold arc.

   Args:
     lengths: Passed to the MST solver as `num_nodes`.
     scores: Arc scores, reduced over axes 1 and 2 per leading-axis element;
       presumably (batch, n, n) -- TODO confirm.
     gold: Gold arc indicators, broadcast-compatible with `scores`.

   Returns:
     One hinge loss per element along the leading axis.
   """
   # Score of the gold tree: sum of scores on gold arcs, per leading index.
   gold_total = tf.reduce_sum(scores * gold, axis=[1, 2])

   # Hamming-loss-augmented inference: reward arcs that disagree with gold.
   augmented_scores = scores + (1 - gold)
   best_total, _ = mst_ops.maximum_spanning_tree(
       num_nodes=lengths, scores=augmented_scores,
       forest=self._attrs['forest'])

   return best_total - gold_total
Example #9
0
    def testMaximumSpanningTree(self):
        """Checks that the MST op recovers a known tree for each batch element.

        Batch element 0 (4 nodes) should pick 3 as root with arcs 3->0->1->2,
        for a total score of 4+2+1=7.  Batch element 1 (3 nodes) has roughly
        reversed scores, so 0 is root with arcs 0->2->1 (total 6).  Its unused
        fourth slot is expected to report -1 as its source.
        """
        with self.test_session() as session:
            num_nodes = tf.constant([4, 3], tf.int32)
            score_matrices = [
                [[0, 0, 0, 0], [1, 0, 0, 0], [1, 2, 0, 0], [1, 2, 3, 4]],
                [[4, 3, 2, 9], [0, 0, 2, 9], [0, 0, 0, 9], [9, 9, 9, 9]],
            ]  # pyformat: disable
            scores = tf.constant(score_matrices, tf.int32)

            solved_scores, solved_sources = session.run(
                mst_ops.maximum_spanning_tree(num_nodes, scores, forest=False))
            tf.logging.info('\nmax_scores=%s\nargmax_sources=\n%s',
                            solved_scores, solved_sources)

            expected_scores = [7, 6]
            expected_sources = [[3, 0, 1, 3], [0, 2, 0, -1]]  # pyformat: disable
            self.assertAllEqual(solved_scores, expected_scores)
            self.assertAllEqual(solved_sources, expected_sources)
Example #10
0
  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Solves for the maximum spanning tree over the linked 'scores'.

    Reads the linked 'lengths' and 'scores' tensors and reshapes the scores
    to [stride, n, n], where n is their second dimension.  It then runs
    `mst_ops.maximum_spanning_tree` and one-hot encodes the chosen source of
    every node.  Only `linked_embeddings`, `stride`, and
    `self._attrs['forest']` are consulted here.

    Args:
      fixed_embeddings: Not used here.
      linked_embeddings: Must provide tensors named 'lengths' and 'scores'.
      context_tensor_arrays: Not used here.
      attention_tensor: Not used here.
      during_training: Not used here.
      stride: Leading dimension of the reshaped scores; must not be None.

    Returns:
      [int32 lengths, scores reshaped to [stride, n, n], the linked scores
      tensor as given, float32 one-hot arcs of depth n].
    """
    check.NotNone(stride, 'MstSolverNetwork requires stride')

    # 'lengths' is squeezed along axis 1 and cast to int32 for the solver.
    lengths_input = network_units.lookup_named_tensor(
        'lengths', linked_embeddings).tensor
    node_counts = tf.to_int32(tf.squeeze(lengths_input, [1]))

    # Fold the linked scores into `stride` square matrices whose side is the
    # scores' second dimension.
    flat_scores = network_units.lookup_named_tensor(
        'scores', linked_embeddings).tensor
    side = tf.shape(flat_scores)[1]
    batched_scores = tf.reshape(flat_scores, [stride, side, side])

    _, best_sources = mst_ops.maximum_spanning_tree(
        num_nodes=node_counts,
        scores=batched_scores,
        forest=self._attrs['forest'])

    # Flatten the per-node sources and expand each into a one-hot row.
    selected_arcs = tf.one_hot(tf.reshape(best_sources, [-1]),
                               side,
                               dtype=tf.float32)

    return [node_counts, batched_scores, flat_scores, selected_arcs]