コード例 #1
0
ファイル: check_test.py プロジェクト: vincentcheny/models
 def testCheckNotNone(self):
     """Checks that NotNone() accepts non-None values and rejects None."""
     # Any value other than None must pass, including a falsy one such as [].
     for accepted in (1, []):
         check.NotNone(accepted, 'foo')

     # Without an explicit error type the failure is a ValueError that carries
     # the message.
     self.assertRaisesRegexp(ValueError, 'bar', check.NotNone, None, 'bar')

     # A caller-supplied error type replaces the default one.
     self.assertRaisesRegexp(RuntimeError, 'baz', check.NotNone, None, 'baz',
                             RuntimeError)
コード例 #2
0
def RootPotentialsFromTokens(root, tokens, weights):
    r"""Computes the potential of selecting each token as the root.

    The potential of a token is the bilinear form between the activations of
    the artificial root token and the activations of that token, parameterized
    by |weights|:

      roots[b,r] = \sum_{i,j} root[i] * weights[i,j] * tokens[b,r,j]

    Args:
      root: [S] vector of activations for the artificial root token.
      tokens: [B,N,T] tensor of batched activations for root tokens.
      weights: [S,T] matrix of weights.

      B and N may only be known at run time, but S and T must be known
      statically.  The dtypes of all arguments must be compatible.

    Returns:
      [B,N] matrix of root-selection potentials as defined above, with the same
      dtype as the arguments.
    """
    # Ranks must be known statically.
    check.Eq(root.get_shape().ndims, 1, 'root must be a vector')
    check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
    check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

    # The activation dimensions S and T must be known statically and must agree
    # across the arguments.
    source_dim, target_dim = weights.get_shape().as_list()
    check.NotNone(source_dim, 'unknown source activation dimension')
    check.NotNone(target_dim, 'unknown target activation dimension')
    check.Eq(root.get_shape().as_list()[0], source_dim,
             'dimension mismatch between weights and root')
    check.Eq(tokens.get_shape().as_list()[2], target_dim,
             'dimension mismatch between weights and tokens')

    # Every argument must have the same base dtype.
    base_dtypes = [
        weights.dtype.base_dtype,
        root.dtype.base_dtype,
        tokens.dtype.base_dtype,
    ]
    check.Same(base_dtypes, 'dtype mismatch')

    root_row = tf.expand_dims(root, 0)  # [1,S]

    dynamic_shape = tf.shape(tokens)
    batch_size, num_tokens = dynamic_shape[0], dynamic_shape[1]

    # Merge B and N so that the whole computation is two plain matmuls:
    # [BN,T] x [S,T]^T -> [BN,S], then [1,S] x [BN,S]^T -> [1,BN].
    flat_tokens = tf.reshape(tokens, [-1, target_dim])
    projected_tokens = tf.matmul(flat_tokens, weights, transpose_b=True)
    flat_roots = tf.matmul(root_row, projected_tokens, transpose_b=True)

    # Split the merged dimension back into [B,N].
    return tf.reshape(flat_roots, [batch_size, num_tokens])
コード例 #3
0
ファイル: component.py プロジェクト: zhoukiller/models
  def get_variable(self, var_name=None, var_params=None):
    """Returns either the original or averaged version of a given variable.

    If the master.read_from_avg flag is set to True, and the
    ExponentialMovingAverage (EMA) object has been attached, then this will ask
    the EMA object for the given variable.

    This is to allow executing inference from the averaged version of
    parameters.

    Arguments:
      var_name: Name of the variable.
      var_params: tf.Variable for which to retrieve an average.

    Only one of |var_name| or |var_params| needs to be provided.  If both are
    provided, |var_params| takes precedence.

    Returns:
      tf.Variable object corresponding to original or averaged version.

    Raises:
      AssertionError: If averaging is active but the EMA object returns no
        average for the variable.
    """
    # Test for presence, not truthiness: a variable object is allowed to be
    # falsy (or to refuse boolean conversion), and it must still take
    # precedence over |var_name| whenever it was provided.
    if var_params is not None:
      var_name = var_params.name
    else:
      check.NotNone(var_name, 'specify at least one of var_name or var_params')
      var_params = tf.get_variable(var_name)

    if self.moving_average and self.master.read_from_avg:
      logging.info('Retrieving average for: %s', var_name)
      var_params = self.moving_average.average(var_params)
      # Only a missing average is an error; the average itself may be falsy.
      assert var_params is not None, 'no average found for: %s' % var_name
    logging.info('Returning: %s', var_params.name)
    return var_params
コード例 #4
0
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Runs the MST solver on the linked scores and forwards the results.

        Only |linked_embeddings| and |stride| are read here; the remaining
        arguments are accepted but not used by this method.

        Args:
          fixed_embeddings: Unused.
          linked_embeddings: Named tensors that must include 'lengths' and
            'scores'.
          context_tensor_arrays: Unused.
          attention_tensor: Unused.
          during_training: Unused.
          stride: Leading dimension used to un-flatten the scores; required.

        Returns:
          List of four tensors: the int32 lengths, the scores reshaped to
          [stride, max_length, max_length], the scores as received, and the
          one-hot encoding of the selected source of every token.
        """
        check.NotNone(stride, 'MstSolverNetwork requires stride')

        lengths_input = network_units.lookup_named_tensor(
            'lengths', linked_embeddings)
        num_nodes = tf.to_int32(tf.squeeze(lengths_input.tensor, [1]))

        scores_input = network_units.lookup_named_tensor(
            'scores', linked_embeddings)
        flat_scores = scores_input.tensor
        max_length = tf.shape(flat_scores)[1]
        batched_scores = tf.reshape(flat_scores,
                                    [stride, max_length, max_length])

        # The first output of the op is not needed here; only the selected
        # sources are forwarded.
        unused_first_output, selected_sources = (
            mst_ops.maximum_spanning_tree(
                forest=self._attrs['forest'],
                num_nodes=num_nodes,
                scores=batched_scores))

        # Flatten the selected sources and encode each one as a one-hot row.
        flat_sources = tf.reshape(selected_sources, [-1])
        one_hot_arcs = tf.one_hot(flat_sources, max_length, dtype=tf.float32)

        return [num_nodes, batched_scores, flat_scores, one_hot_arcs]
コード例 #5
0
ファイル: mst_ops.py プロジェクト: vincentcheny/models
def maximum_spanning_tree_gradient(mst_op, d_loss_d_max_scores, *_):
    """Computes a subgradient of the MaximumSpanningTree op.

    MaximumSpanningTree is only differentiable w.r.t. its |scores| input and
    its |max_scores| output, so only that pair is handled.

    Args:
      mst_op: The MaximumSpanningTree op being differentiated.
      d_loss_d_max_scores: [B] vector whose entry b is the gradient of the
        network loss w.r.t. entry b of the |max_scores| output of |mst_op|.
      *_: Gradients w.r.t. the remaining outputs; ignored.

    Returns:
      A pair of:
      1. None, because the op is not differentiable w.r.t. |num_nodes|.
      2. [B,M,M] tensor whose entry b,t,s is a subgradient of the network loss
         w.r.t. entry b,t,s of the |scores| input.  Its dtype is that of
         |d_loss_d_max_scores|.
    """
    grad_dtype = d_loss_d_max_scores.dtype.base_dtype
    check.NotNone(grad_dtype)

    selected_sources = mst_op.outputs[1]  # [B,M]
    num_nodes = tf.shape(selected_sources)[1]  # M in the docstring

    # A one-hot argmax is a subgradient of max, so the selected sources become
    # 0/1 indicators of shape [B,M,M].
    arc_indicators = tf.one_hot(selected_sources, num_nodes, dtype=grad_dtype)

    # Append two unit dimensions ([B] -> [B,1,1]) so that each per-batch
    # gradient broadcasts over that batch element's whole [M,M] indicator
    # matrix.
    upstream_bx1x1 = tf.expand_dims(
        tf.expand_dims(d_loss_d_max_scores, -1), -1)

    return None, arc_indicators * upstream_bx1x1
コード例 #6
0
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Applies the stacked bi-LSTM to the batched inputs.

        Requires |stride|; otherwise see base class.

        Args:
          fixed_embeddings: Fixed feature embeddings to concatenate.
          linked_embeddings: Linked feature embeddings; must include a
            'lengths' entry, which is consumed as the sequence lengths rather
            than as an input.
          context_tensor_arrays: Unused by this method.
          attention_tensor: Unused by this method.
          during_training: Selects the training or the inference cells.
          stride: Passed to the input-tensor builder and to the initial-state
            builder; required.

        Returns:
          The result of _append_base_layers() applied to the bi-LSTM outputs
          reshaped to [-1, output_dim].
        """
        check.NotNone(
            stride, 'BulkBiLSTMNetwork requires "stride" and must be called '
            'in the bulk feature extractor component.')

        # The lengths arrive as a column; squeeze them into a vector.
        lengths_input = dragnn.lookup_named_tensor('lengths',
                                                   linked_embeddings)
        sequence_lengths = tf.squeeze(lengths_input.tensor, [1])

        # Everything except 'lengths' is an actual input to the network.
        linked_embeddings = [
            linked for linked in linked_embeddings if linked.name != 'lengths'
        ]
        batched_inputs = dragnn.get_input_tensor_with_stride(
            fixed_embeddings, linked_embeddings, stride)

        # Concatenating the embeddings hides the static activation dimension
        # that the RNN library needs.  set_shape() merges with the known shape,
        # so pin only the last dimension and leave the others unknown.
        unknown = tf.Dimension(None)
        other_unknown = tf.Dimension(None)
        batched_inputs.set_shape([unknown, other_unknown, self._input_dim])

        initial_states_forward, initial_states_backward = (
            self._create_initial_states(stride))

        if during_training:
            cells_forward, cells_backward = (self._train_cells_forward,
                                             self._train_cells_backward)
        else:
            cells_forward, cells_backward = (self._inference_cells_forward,
                                             self._inference_cells_backward)

        def _bilstm_closure(scope):
            """Runs the stacked bi-LSTM in |scope| and returns its outputs."""
            rnn_results = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                cells_forward,
                cells_backward,
                batched_inputs,
                initial_states_fw=initial_states_forward,
                initial_states_bw=initial_states_backward,
                sequence_length=sequence_lengths,
                parallel_iterations=self._attrs['parallel_iterations'],
                scope=scope)
            # Only the outputs are needed; the two final states are dropped.
            rnn_outputs, _, _ = rnn_results
            return rnn_outputs

        # Layer outputs are not batched, so merge the leading dimensions.
        batched_outputs = self._apply_with_captured_variables(_bilstm_closure)
        flat_outputs = tf.reshape(batched_outputs, [-1, self._output_dim])
        return self._append_base_layers([flat_outputs])
コード例 #7
0
def LabelPotentialsFromTokens(tokens, weights):
    r"""Computes a potential for every label of every token.

    Each potential is the inner product between a token's activations and the
    row of |weights| that belongs to the label:

      labels[b,t,l] = \sum_{i} weights[l,i] * tokens[b,t,i]

    Args:
      tokens: [B,N,T] tensor of batched token activations.
      weights: [L,T] matrix of weights.

      B and N may be dynamic, but L and T must be static.  The dtypes of all
      arguments must be compatible.

    Returns:
      [B,N,L] tensor of label potentials as defined above, with the same dtype
      as the arguments.
    """
    check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
    check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

    # L and T must be static, and T must agree with the tokens.
    label_count, activation_dim = weights.get_shape().as_list()
    check.NotNone(label_count, 'unknown number of labels')
    check.NotNone(activation_dim, 'unknown activation dimension')
    check.Eq(tokens.get_shape().as_list()[2], activation_dim,
             'activation mismatch between weights and tokens')

    dynamic_shape = tf.shape(tokens)
    batch_size, num_tokens = dynamic_shape[0], dynamic_shape[1]

    base_dtypes = [tokens.dtype.base_dtype, weights.dtype.base_dtype]
    check.Same(base_dtypes, 'dtype mismatch')

    # Merge B and N so a single matmul suffices: [BN,T] x [L,T]^T -> [BN,L].
    flat_tokens = tf.reshape(tokens, [-1, activation_dim])
    flat_labels = tf.matmul(flat_tokens, weights, transpose_b=True)

    # Split the merged dimension back into [B,N,L].
    return tf.reshape(flat_labels, [batch_size, num_tokens, label_count])
コード例 #8
0
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Builds the flattened matrix of arc and root potentials.

        Requires |stride|; otherwise see base class.

        Args:
          fixed_embeddings: Unused by this method.
          linked_embeddings: Named tensors that must include 'sources' and
            'targets'.
          context_tensor_arrays: Unused by this method.
          attention_tensor: Unused by this method.
          during_training: Currently ignored (see the TODO below).
          stride: Leading dimension used to un-flatten the token activations;
            required.

        Returns:
          Single-element list holding the combined potentials, transposed so
          that targets index the rows and reshaped to [-1, num_tokens].
        """
        check.NotNone(
            stride,
            'BiaffineDigraphNetwork requires "stride" and must be called '
            'in the bulk feature extractor component.')

        # TODO(googleuser): Add dropout during training.
        del during_training

        # Fetch the (possibly averaged) parameters.
        fetch_variable = self._component.get_variable
        arc_weights = fetch_variable('weights_arc')
        source_weights = fetch_variable('weights_source')
        root_vector = fetch_variable('root')

        # Look up the token activations and use |stride| to fold batch and beam
        # into one leading dimension.
        source_input = network_units.lookup_named_tensor(
            'sources', linked_embeddings)
        target_input = network_units.lookup_named_tensor(
            'targets', linked_embeddings)
        sources_bxnxs = tf.reshape(source_input.tensor,
                                   [stride, -1, self._source_dim])
        targets_bxnxt = tf.reshape(target_input.tensor,
                                   [stride, -1, self._target_dim])
        num_tokens = tf.shape(sources_bxnxs)[1]

        # Arc, per-source, and root potentials.
        arc_potentials = digraph_ops.ArcPotentialsFromTokens(
            sources_bxnxs, targets_bxnxt, arc_weights)
        source_potentials = digraph_ops.ArcSourcePotentialsFromTokens(
            sources_bxnxs, source_weights)
        root_potentials = digraph_ops.RootPotentialsFromTokens(
            root_vector, targets_bxnxt, arc_weights, source_weights)

        # One matrix per batch element, with the root potentials placed on the
        # diagonal.
        sources_on_rows = digraph_ops.CombineArcAndRootPotentials(
            arc_potentials + source_potentials, root_potentials)

        # Rows currently index sources.  Transpose so that rows index targets,
        # which makes a per-row maximum select the source of each target.
        targets_on_rows = tf.matrix_transpose(sources_on_rows)

        return [tf.reshape(targets_on_rows, [-1, num_tokens])]
コード例 #9
0
def ArcSourcePotentialsFromTokens(tokens, weights):
    r"""Computes arc potentials that depend only on the source token.

    The potential is the inner product between the activations of the source
    token and |weights|, repeated for every possible target:

      arc[b,s,:] = \sum_{i} weights[i] * tokens[b,s,i]

    Args:
      tokens: [B,N,S] tensor of batched activations for source tokens.
      weights: [S] vector of weights.

      B and N may only be known at run time, but S must be known statically.
      The dtypes of all arguments must be compatible.

    Returns:
      [B,N,N] tensor of arc potentials as defined above, with the same dtype as
      the arguments.  The diagonal entries (i.e., where s==t) represent
      self-loops and may not be meaningful.
    """
    # Ranks must be known statically.
    check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
    check.Eq(weights.get_shape().ndims, 1, 'weights must be a vector')

    # S must be known statically and must agree with the tokens.
    (source_dim,) = weights.get_shape().as_list()
    check.NotNone(source_dim, 'unknown source activation dimension')
    check.Eq(tokens.get_shape().as_list()[2], source_dim,
             'dimension mismatch between weights and tokens')

    # Both arguments must have the same base dtype.
    base_dtypes = [weights.dtype.base_dtype, tokens.dtype.base_dtype]
    check.Same(base_dtypes, 'dtype mismatch')

    dynamic_shape = tf.shape(tokens)
    batch_size, num_tokens = dynamic_shape[0], dynamic_shape[1]

    # Merge B and N, score every source with one matmul
    # ([BN,S] x [S,1] -> [BN,1]), then copy that score across all N targets.
    flat_tokens = tf.reshape(tokens, [-1, source_dim])
    weights_column = tf.expand_dims(weights, 1)
    flat_scores = tf.matmul(flat_tokens, weights_column)
    tiled_scores = tf.tile(flat_scores, [1, num_tokens])

    # Split the merged dimension back into [B,N,N].
    return tf.reshape(tiled_scores, [batch_size, num_tokens, num_tokens])
コード例 #10
0
ファイル: biaffine_units.py プロジェクト: zhoukiller/models
    def create(self,
               fixed_embeddings,
               linked_embeddings,
               context_tensor_arrays,
               attention_tensor,
               during_training,
               stride=None):
        """Builds the flattened label potentials for aligned token pairs.

        Requires |stride|; otherwise see base class.

        Args:
          fixed_embeddings: Unused by this method.
          linked_embeddings: Named tensors that must include 'sources' and
            'targets'.
          context_tensor_arrays: Unused by this method.
          attention_tensor: Unused by this method.
          during_training: Currently ignored (see the TODO below).
          stride: Leading dimension used to un-flatten the token activations;
            required.

        Returns:
          Single-element list holding the summed potentials reshaped to
          [-1, num_labels].
        """
        check.NotNone(
            stride,
            'BiaffineLabelNetwork requires "stride" and must be called '
            'in the bulk feature extractor component.')

        # TODO(googleuser): Add dropout during training.
        del during_training

        # Fetch the (possibly averaged) parameters.
        fetch_variable = self._component.get_variable
        pair_weights = fetch_variable('weights_pair')
        source_weights = fetch_variable('weights_source')
        target_weights = fetch_variable('weights_target')
        label_biases = fetch_variable('biases')

        # Look up the token activations and use |stride| to fold batch and beam
        # into one leading dimension.
        source_input = network_units.lookup_named_tensor(
            'sources', linked_embeddings)
        target_input = network_units.lookup_named_tensor(
            'targets', linked_embeddings)
        source_tokens = tf.reshape(source_input.tensor,
                                   [stride, -1, self._source_dim])
        target_tokens = tf.reshape(target_input.tensor,
                                   [stride, -1, self._target_dim])

        # Pairwise, source-only, and target-only potentials.
        pair_potentials = digraph_ops.LabelPotentialsFromTokenPairs(
            source_tokens, target_tokens, pair_weights)
        source_potentials = digraph_ops.LabelPotentialsFromTokens(
            source_tokens, source_weights)
        target_potentials = digraph_ops.LabelPotentialsFromTokens(
            target_tokens, target_weights)

        # Sum all terms, including the biases, left to right.
        label_potentials = (
            pair_potentials + source_potentials + target_potentials +
            label_biases)

        # Merge the leading dimensions so every row holds one set of labels.
        return [tf.reshape(label_potentials, [-1, self._num_labels])]
コード例 #11
0
def main(unused_argv):
    """Trains the model described by the files under --model_dir.

    Reads 'config.txt', 'master.pbtxt', 'hyperparameters.pbtxt' and, when
    present, 'targets.pbtxt' from --model_dir, builds the training graph, loads
    the training and tuning corpora named in the config, and runs training.
    The checkpoint is written to 'checkpoints/best' and summaries to
    'tensorboard', both under --model_dir.  Any existing checkpoint directory
    is deleted first.

    Args:
      unused_argv: Ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # check.Ne on the two "is None" results fails when both flags are set or
    # both are unset, i.e. it requires exactly one flag of each pair.
    check.NotNone(FLAGS.model_dir, '--model_dir is required')
    check.Ne(
        FLAGS.pretrain_steps is None, FLAGS.pretrain_epochs is None,
        'Exactly one of --pretrain_steps or --pretrain_epochs is required')
    check.Ne(FLAGS.train_steps is None, FLAGS.train_epochs is None,
             'Exactly one of --train_steps or --train_epochs is required')

    # Every input and output lives under --model_dir.
    config_path = os.path.join(FLAGS.model_dir, 'config.txt')
    master_path = os.path.join(FLAGS.model_dir, 'master.pbtxt')
    hyperparameters_path = os.path.join(FLAGS.model_dir,
                                        'hyperparameters.pbtxt')
    targets_path = os.path.join(FLAGS.model_dir, 'targets.pbtxt')
    checkpoint_path = os.path.join(FLAGS.model_dir, 'checkpoints/best')
    tensorboard_dir = os.path.join(FLAGS.model_dir, 'tensorboard')

    # The config file holds a Python literal that is passed to defaultdict().
    # Because the default factory is bool, keys missing from the file read as
    # False.
    with tf.gfile.FastGFile(config_path) as config_file:
        config = collections.defaultdict(bool,
                                         ast.literal_eval(config_file.read()))
    train_corpus_path = config['train_corpus_path']
    tune_corpus_path = config['tune_corpus_path']
    projectivize_train_corpus = config['projectivize_train_corpus']

    # Training targets default to those derived from the master spec; an
    # explicit targets file, when present, replaces them.
    master = _read_text_proto(master_path, spec_pb2.MasterSpec)
    hyperparameters = _read_text_proto(hyperparameters_path,
                                       spec_pb2.GridPoint)
    targets = spec_builder.default_targets_from_spec(master)
    if tf.gfile.Exists(targets_path):
        targets = _read_text_proto(targets_path,
                                   spec_pb2.TrainingGridSpec).target

    # Build the TensorFlow graph.
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(hyperparameters.seed)
        builder = graph_builder.MasterBuilder(master, hyperparameters)
        # One trainer per training target.
        trainers = [
            builder.add_training_from_config(target) for target in targets
        ]
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    train_corpus = sentence_io.ConllSentenceReader(
        train_corpus_path, projectivize=projectivize_train_corpus).corpus()
    tune_corpus = sentence_io.ConllSentenceReader(tune_corpus_path,
                                                  projectivize=False).corpus()
    # Keep a reference to the unconverted tuning corpus for use as gold data.
    gold_tune_corpus = tune_corpus

    # Convert to char-based corpora, if requested.
    if config['convert_to_char_corpora']:
        # NB: Do not convert the |gold_tune_corpus|, which should remain word-based
        # for segmentation evaluation purposes.
        train_corpus = _convert_to_char_corpus(train_corpus)
        tune_corpus = _convert_to_char_corpus(tune_corpus)

    # Resolve the step counts from whichever flag of each pair was given, then
    # require one entry per training target.
    pretrain_steps = _get_steps(FLAGS.pretrain_steps, FLAGS.pretrain_epochs,
                                len(train_corpus))
    train_steps = _get_steps(FLAGS.train_steps, FLAGS.train_epochs,
                             len(train_corpus))
    check.Eq(len(targets), len(pretrain_steps),
             'Length mismatch between training targets and --pretrain_steps')
    check.Eq(len(targets), len(train_steps),
             'Length mismatch between training targets and --train_steps')

    # Ready to train!
    tf.logging.info('Training on %d sentences.', len(train_corpus))
    tf.logging.info('Tuning on %d sentences.', len(tune_corpus))

    tf.logging.info('Creating TensorFlow checkpoint dir...')
    summary_writer = trainer_lib.get_summary_writer(tensorboard_dir)

    # Start from an empty checkpoint directory: remove whatever currently
    # occupies the path (a directory or a plain file) and recreate it.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if tf.gfile.IsDirectory(checkpoint_dir):
        tf.gfile.DeleteRecursively(checkpoint_dir)
    elif tf.gfile.Exists(checkpoint_dir):
        tf.gfile.Remove(checkpoint_dir)
    tf.gfile.MakeDirs(checkpoint_dir)

    with tf.Session(FLAGS.tf_master, graph=graph) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())
        trainer_lib.run_training(sess, trainers, annotator,
                                 evaluation.parser_summaries, pretrain_steps,
                                 train_steps, train_corpus, tune_corpus,
                                 gold_tune_corpus, FLAGS.batch_size,
                                 summary_writer, FLAGS.report_every,
                                 builder.saver, checkpoint_path)

    tf.logging.info('Best checkpoint written to:\n%s', checkpoint_path)
コード例 #12
0
def LabelPotentialsFromTokenPairs(sources, targets, weights):
    r"""Computes a potential for every label of every aligned token pair.

    Source token t is paired with target token t of the same batch element,
    and each label receives the bilinear form between the two activation
    vectors under that label's slice of |weights|:

      labels[b,t,l] = \sum_{i,j} sources[b,t,i] * weights[l,i,j] * targets[b,t,j]

    Args:
      sources: [B,N,S] tensor of batched source token activations.
      targets: [B,N,T] tensor of batched target token activations.
      weights: [L,S,T] tensor of weights.

      B and N may be dynamic, but L, S and T must be static.  The dtypes of all
      arguments must be compatible.

    Returns:
      [B,N,L] tensor of label potentials as defined above, with the same dtype
      as the arguments.
    """
    check.Eq(sources.get_shape().ndims, 3, 'sources must be rank 3')
    check.Eq(targets.get_shape().ndims, 3, 'targets must be rank 3')
    check.Eq(weights.get_shape().ndims, 3, 'weights must be rank 3')

    # L, S and T must be static, and S and T must agree with the tokens.
    label_count, source_dim, target_dim = weights.get_shape().as_list()
    check.NotNone(label_count, 'unknown number of labels')
    check.NotNone(source_dim, 'unknown source activation dimension')
    check.NotNone(target_dim, 'unknown target activation dimension')
    check.Eq(sources.get_shape().as_list()[2], source_dim,
             'activation mismatch between weights and source tokens')
    check.Eq(targets.get_shape().as_list()[2], target_dim,
             'activation mismatch between weights and target tokens')

    base_dtypes = [
        sources.dtype.base_dtype,
        targets.dtype.base_dtype,
        weights.dtype.base_dtype,
    ]
    check.Same(base_dtypes, 'dtype mismatch')

    sources_dynamic_shape = tf.shape(sources)
    targets_dynamic_shape = tf.shape(targets)
    batch_size, num_tokens = (sources_dynamic_shape[0],
                              sources_dynamic_shape[1])

    # B and N are dynamic, so their agreement between |sources| and |targets|
    # can only be asserted when the graph runs.
    runtime_checks = [
        tf.assert_equal(batch_size, targets_dynamic_shape[0]),
        tf.assert_equal(num_tokens, targets_dynamic_shape[1])
    ]
    with tf.control_dependencies(runtime_checks):
        # No single op computes a vector-3tensor-vector product per token, so it
        # is assembled from reshape() and matmul().

        # First contract over T with one matmul:
        # [BN,T] x [LS,T]^T -> [BN,LS].
        flat_weights = tf.reshape(weights,
                                  [label_count * source_dim, target_dim])
        flat_targets = tf.reshape(targets, [-1, target_dim])
        flat_partial = tf.matmul(flat_targets, flat_weights, transpose_b=True)

        # Un-flatten to [B,N,L,S].
        partial_bxnxlxs = tf.reshape(
            flat_partial, [batch_size, num_tokens, label_count, source_dim])

        # Then contract over S with a batched matmul between the trailing [L,S]
        # matrices and the trailing [S] source vectors:
        # [B,N,L,S] x [B,N,1,S]^T -> [B,N,L,1].
        sources_bxnx1xs = tf.expand_dims(sources, 2)
        labels_bxnxlx1 = tf.matmul(partial_bxnxlxs,
                                   sources_bxnx1xs,
                                   transpose_b=True)

        # Drop the trailing unit dimension.
        return tf.squeeze(labels_bxnxlx1, [3])
コード例 #13
0
def ArcPotentialsFromTokens(source_tokens, target_tokens, weights):
    r"""Computes a potential for every source-to-target arc.

    Within each batch element, the potential of an arc is the bilinear form
    between the activations of its source and of its target under |weights|:

      arc[b,s,t] =
          \sum_{i,j} source_tokens[b,s,i] * weights[i,j] * target_tokens[b,t,j]

    Extending the token activations with bias terms turns this into a
    "biaffine" model (Dozat and Manning, 2017).

    Args:
      source_tokens: [B,N,S] tensor of batched activations for the source token
        in each arc.
      target_tokens: [B,N,T] tensor of batched activations for the target token
        in each arc.
      weights: [S,T] matrix of weights.

      B and N may only be known at run time, but S and T must be known
      statically.  The dtypes of all arguments must be compatible.

    Returns:
      [B,N,N] tensor A of arc potentials where A_{b,s,t} is the potential of
      the arc from s to t in batch element b, with the same dtype as the
      arguments.  The diagonal entries (i.e., where s==t) represent self-loops
      and may not be meaningful.
    """
    # Ranks must be known statically.
    check.Eq(source_tokens.get_shape().ndims, 3,
             'source_tokens must be rank 3')
    check.Eq(target_tokens.get_shape().ndims, 3,
             'target_tokens must be rank 3')
    check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

    # S and T must be known statically and must agree with the tokens.
    source_dim, target_dim = weights.get_shape().as_list()
    check.NotNone(source_dim, 'unknown source activation dimension')
    check.NotNone(target_dim, 'unknown target activation dimension')
    check.Eq(source_tokens.get_shape().as_list()[2], source_dim,
             'dimension mismatch between weights and source_tokens')
    check.Eq(target_tokens.get_shape().as_list()[2], target_dim,
             'dimension mismatch between weights and target_tokens')

    # Every argument must have the same base dtype.
    base_dtypes = [
        weights.dtype.base_dtype,
        source_tokens.dtype.base_dtype,
        target_tokens.dtype.base_dtype,
    ]
    check.Same(base_dtypes, 'dtype mismatch')

    sources_dynamic_shape = tf.shape(source_tokens)
    targets_dynamic_shape = tf.shape(target_tokens)
    batch_size, num_tokens = (sources_dynamic_shape[0],
                              sources_dynamic_shape[1])

    # B and N are dynamic, so their agreement between the two token tensors can
    # only be asserted when the graph runs.
    runtime_checks = [
        tf.assert_equal(batch_size, targets_dynamic_shape[0]),
        tf.assert_equal(num_tokens, targets_dynamic_shape[1])
    ]
    with tf.control_dependencies(runtime_checks):
        # Merge B and N so the first product is one large multiplication.
        flat_targets = tf.reshape(target_tokens, [-1, target_dim])

        # Matrices are row-major, so the RHS of each matmul is passed with its
        # transpose flag set; rows of the LHS then line up with rows of the RHS
        # and nothing has to be copied.  [BN,T] x [S,T]^T -> [BN,S].
        flat_projected = tf.matmul(flat_targets, weights, transpose_b=True)

        # The second product pairs tokens within a batch element, so split the
        # merged dimension back into [B,N,S] first.
        projected_bxnxs = tf.reshape(flat_projected,
                                     [batch_size, num_tokens, source_dim])

        # This product is repeated per batch element rather than being one
        # large multiplication; the pairwise nature of the computation does not
        # seem to allow otherwise.  [B,N,S] x [B,N,S]^T -> [B,N,N].
        return tf.matmul(source_tokens, projected_bxnxs, transpose_b=True)