Example #1
0
  def FProp(self, theta, inputs, paddings, state0=None, labels=None):
    """Computes xent loss given the language model input activations.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: Input ids. An int32 tensor of shape [time, batch].
      paddings: A 0/1 tensor of shape [time, batch].
      state0: Not used for Transformer.
      labels: A `.NestedMap` containing the following fields:

        - class_weights: a tensor with shape [time, batch] containing the
          weights for each target word.
        - class_ids: a tensor with shape [time, batch] of int32 dtype
          containing the target class labels.
        - class_probabilities: a tensor with shape [time, batch, vocab_size]
          of float values indicating class-membership probabilities. This
          field is not read by this method.

        Although it defaults to None (kept for signature compatibility),
        `labels` is required: the targets are passed directly into the stack,
        which produces the per-example xent together with the logits.

    Returns:
      A tuple (xent_output, state1). `xent_output` is a `.NestedMap` with the
      fields total_weight, per_example_xent, logits, per_example_argmax,
      avg_xent and total_xent. `state1` is always an empty dict, since no
      recurrent state is carried between calls.

    Raises:
      ValueError: If `labels` is None.
    """
    if labels is None:
      # The stack consumes the targets to compute the xent, so there is no
      # labels-free path; fail clearly instead of with an AttributeError.
      raise ValueError('FProp requires `labels` to compute the xent loss.')
    p = self.params
    ids = py_utils.HasRank(inputs, 2)
    paddings = py_utils.HasShape(paddings, tf.shape(ids))
    class_weights = labels.class_weights
    # The int32 class ids are cast to the forward-prop float dtype before they
    # are handed to the stack. NOTE(review): presumably the stack requires all
    # of its tensor inputs in FPropDtype -- confirm against the stack's FProp.
    per_example_xent, logits = self.stack.FProp(
        theta.stack, ids, paddings, None, None, None, None,
        tf.cast(labels.class_ids, py_utils.FPropDtype(p)), class_weights)
    per_example_argmax = py_utils.ArgMax(logits)
    # Weight each position's xent by its class weight before summing.
    total_xent = tf.reduce_sum(per_example_xent * class_weights)
    total_weights = tf.reduce_sum(class_weights)
    xent_output = py_utils.NestedMap(
        total_weight=total_weights,
        per_example_xent=per_example_xent,
        logits=logits,
        per_example_argmax=per_example_argmax,
        avg_xent=total_xent / total_weights,
        total_xent=total_xent)
    return xent_output, {}
Example #2
0
 def Compute(x):
   """Evaluates `x` as a constant and its `py_utils.ArgMax` in a new graph.

   Returns the list `[x_value, argmax_value]` produced by `sess.run`.
   """
   isolated_graph = tf.Graph()
   with self.session(graph=isolated_graph) as sess:
     const_x = tf.constant(x)
     fetches = [const_x, py_utils.ArgMax(const_x)]
     return sess.run(fetches)