def testLabelPotentialsFromTokenPairs(self):
    """Checks LabelPotentialsFromTokenPairs on a small hand-computed case.

    The sources and targets both have shape [2, 3, 2] and the weights have
    shape [3, 2, 2].  Every expected entry [b, n, l] equals the bilinear form

      sum_{s,t} sources[b, n, s] * weights[l, s, t] * targets[b, n, t]

    e.g. entry [0, 0, 0] is 1*2*3 + 1*3*4 + 2*5*3 + 2*7*4 = 104.
    """
    source_values = [
        [[1, 2], [3, 4], [5, 6]],
        [[6, 5], [4, 3], [2, 1]],
    ]
    target_values = [
        [[3, 4], [5, 6], [7, 8]],
        [[8, 7], [6, 5], [4, 3]],
    ]
    # One 2x2 weight matrix per index of the output's last axis.
    weight_values = [
        [[2, 3], [5, 7]],
        [[11, 13], [17, 19]],
        [[23, 29], [31, 37]],
    ]
    expected_potentials = [
        [[104, 339, 667],
         [352, 1195, 2375],
         [736, 2531, 5043]],
        [[667, 2419, 4857],
         [303, 1115, 2245],
         [75, 291, 593]],
    ]

    with self.test_session():
        sources = tf.constant(source_values, tf.float32)
        targets = tf.constant(target_values, tf.float32)
        weights = tf.constant(weight_values, tf.float32)

        potentials = digraph_ops.LabelPotentialsFromTokenPairs(
            sources, targets, weights)

        # All inputs are small integers, so exact float equality is safe here.
        self.assertAllEqual(potentials.eval(), expected_potentials)
Esempio n. 2
0
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Builds label potentials from the linked 'sources' and 'targets'.

        Requires |stride|; otherwise see base class.

        Args:
          fixed_embeddings: Not used by this method.
          linked_embeddings: Collection searched (via
            network_units.lookup_named_tensor) for the 'sources' and 'targets'
            named tensors.
          context_tensor_arrays: Not used by this method.
          attention_tensor: Not used by this method.
          during_training: Currently ignored; see the dropout TODO below.
          stride: Leading dimension used when reshaping the linked inputs.
            Validated with check.NotNone, so callers must supply it.

        Returns:
          A single-element list holding the combined potentials, reshaped to
          [-1, self._num_labels].
        """
        check.NotNone(
            stride,
            'BiaffineLabelNetwork requires "stride" and must be called '
            'in the bulk feature extractor component.')

        # TODO(googleuser): Add dropout during training.
        del during_training

        # Fetch the (possibly averaged) parameters from the owning component.
        get_variable = self._component.get_variable
        pair_weights = get_variable('weights_pair')
        source_weights = get_variable('weights_source')
        target_weights = get_variable('weights_target')
        bias_terms = get_variable('biases')

        # Look up both linked inputs first, then reshape each one so that
        # |stride| becomes the leading dimension (collapsing batch and beam).
        source_link, target_link = [
            network_units.lookup_named_tensor(link_name, linked_embeddings)
            for link_name in ('sources', 'targets')
        ]
        batched_sources = tf.reshape(source_link.tensor,
                                     [stride, -1, self._source_dim])
        batched_targets = tf.reshape(target_link.tensor,
                                     [stride, -1, self._target_dim])

        # Pairwise term first, then the two per-token terms.
        pair_potentials = digraph_ops.LabelPotentialsFromTokenPairs(
            batched_sources, batched_targets, pair_weights)
        source_potentials = digraph_ops.LabelPotentialsFromTokens(
            batched_sources, source_weights)
        target_potentials = digraph_ops.LabelPotentialsFromTokens(
            batched_targets, target_weights)

        # Accumulate left to right: ((pairs + sources) + targets) + biases.
        combined = pair_potentials + source_potentials
        combined = combined + target_potentials
        combined = combined + bias_terms

        # Flatten everything except the label dimension.
        flattened = tf.reshape(combined, [-1, self._num_labels])
        return [flattened]