예제 #1
0
    def testArcSourcePotentialsFromTokens(self):
        """Checks ArcSourcePotentialsFromTokens on a fixed 2x3x3 token batch.

        With weights [2, 3, 5], the token vectors [4, 5, 6], [5, 6, 7] and
        [6, 7, 8] score 53, 63 and 73 respectively; the expected output repeats
        each token's score across the last axis.
        """
        token_values = [[[4, 5, 6], [5, 6, 7], [6, 7, 8]],
                        [[6, 7, 8], [5, 6, 7], [4, 5, 6]]]
        expected_potentials = [[[53, 53, 53], [63, 63, 63], [73, 73, 73]],
                               [[73, 73, 73], [63, 63, 63], [53, 53, 53]]]

        with self.test_session():
            token_tensor = tf.constant(token_values, tf.float32)
            weight_vector = tf.constant([2, 3, 5], tf.float32)

            potentials = digraph_ops.ArcSourcePotentialsFromTokens(
                token_tensor, weight_vector)

            self.assertAllEqual(potentials.eval(), expected_potentials)
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Builds flattened adjacency potentials; requires |stride|.

        Otherwise see base class.

        Args:
          fixed_embeddings: Not used by this method.
          linked_embeddings: Searched for the named tensors 'sources' and
            'targets', which supply the token activations.
          context_tensor_arrays: Not used by this method.
          attention_tensor: Not used by this method.
          during_training: Discarded; dropout is not implemented yet.
          stride: Leading dimension used when reshaping the token activations.
            Must not be None.

        Returns:
          A single-element list holding the transposed adjacency potentials,
          reshaped to [-1, num_tokens], where num_tokens is dimension 1 of the
          reshaped source activations.
        """
        check.NotNone(
            stride,
            'BiaffineDigraphNetwork requires "stride" and must be called '
            'in the bulk feature extractor component.')

        # TODO(googleuser): Add dropout during training.
        del during_training

        # Fetch the (possibly averaged) parameters from the owning component.
        component = self._component
        arc_weights = component.get_variable('weights_arc')
        source_weights = component.get_variable('weights_source')
        root_weights = component.get_variable('root')

        # Look up both linked inputs, then fold batch and beam into one leading
        # dimension of size |stride|.
        source_input = network_units.lookup_named_tensor(
            'sources', linked_embeddings)
        target_input = network_units.lookup_named_tensor(
            'targets', linked_embeddings)
        source_shape = [stride, -1, self._source_dim]
        src_tokens_bxnxs = tf.reshape(source_input.tensor, source_shape)
        target_shape = [stride, -1, self._target_dim]
        tgt_tokens_bxnxt = tf.reshape(target_input.tensor, target_shape)
        num_tokens = tf.shape(src_tokens_bxnxs)[1]

        # Score every arc, every source, and every root candidate.
        arc_scores_bxnxn = digraph_ops.ArcPotentialsFromTokens(
            src_tokens_bxnxs, tgt_tokens_bxnxt, arc_weights)
        source_scores_bxnxn = digraph_ops.ArcSourcePotentialsFromTokens(
            src_tokens_bxnxs, source_weights)
        root_scores_bxn = digraph_ops.RootPotentialsFromTokens(
            root_weights, tgt_tokens_bxnxt, arc_weights, source_weights)

        # Merge into one matrix per batch item, with root scores on the diagonal.
        pairwise_scores_bxnxn = arc_scores_bxnxn + source_scores_bxnxn
        combined_bxnxn = digraph_ops.CombineArcAndRootPotentials(
            pairwise_scores_bxnxn, root_scores_bxn)

        # At this point rows index sources and columns index targets.  Transpose
        # so rows index targets: maximizing within a row then corresponds to
        # selecting a source for that target.  Finally flatten to 2-D.
        flat_shape = [-1, num_tokens]
        return [tf.reshape(tf.matrix_transpose(combined_bxnxn), flat_shape)]