def FProp(self, theta, inputs, paddings, state0=None, labels=None):
  """Computes xent loss given the language model input activations.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: Input ids. An int32 tensor of shape [time, batch].
    paddings: A 0/1 tensor of shape [time, batch].
    state0: Not used for Transformer.
    labels: If not None, a `.NestedMap` containing the following fields:

      - class_weights, a tensor with shape [time, batch] containing the
        weights for each target word.
      - class_ids, a tensor with shape [time, batch] of int32 dtype containing
        the target class labels.
      - class_probabilities, a tensor with shape [time, batch, vocab_size] of
        float values indicating class-membership probabilities.

  Returns:
    If `labels` is not None, returns (xent_output, state1), where
    `xent_output` is a `.NestedMap` as defined by `SoftmaxLayer`'s return
    value and `state1` is the next recurrent state. Otherwise, `xent_output`
    only contains the softmax logits.
  """
  p = self.params
  # Sanity-check input shapes: ids and paddings are both [time, batch].
  ids = py_utils.HasRank(inputs, 2)
  paddings = py_utils.HasShape(paddings, tf.shape(ids))
  # Run the transformer stack. Note that `labels` is required by this path:
  # class_ids and class_weights are threaded through so the stack can compute
  # the per-token xent alongside the logits.
  per_example_xent, logits = self.stack.FProp(
      theta.stack, ids, paddings, None, None, None, None,
      tf.cast(labels.class_ids, py_utils.FPropDtype(p)),
      labels.class_weights)
  per_example_argmax = py_utils.ArgMax(logits)
  # Aggregate the per-token xent, weighted so padded positions contribute 0.
  total_xent = tf.reduce_sum(per_example_xent * labels.class_weights)
  total_weights = tf.reduce_sum(labels.class_weights)
  xent_output = py_utils.NestedMap(
      total_weight=total_weights,
      per_example_xent=per_example_xent,
      logits=logits,
      per_example_argmax=per_example_argmax,
      avg_xent=total_xent / total_weights,
      total_xent=total_xent)
  # Transformer is stateless, so the next recurrent state is empty.
  return xent_output, {}
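# A minimal standalone sketch (not part of this layer's API) of the weighted
# xent aggregation that FProp performs, written in plain TensorFlow rather
# than through the Lingvo stack. Shapes, tensor names, and the helper name
# `_WeightedXentSketch` are illustrative assumptions: time=3, batch=2,
# vocab_size=5. Assumes the module-level `import tensorflow as tf` above.
def _WeightedXentSketch():
  logits = tf.random.normal([3, 2, 5])
  class_ids = tf.random.uniform([3, 2], maxval=5, dtype=tf.int32)
  class_weights = tf.constant([[1., 1.], [1., 1.], [1., 0.]])  # 0 = padding.
  # Per-token xent, analogous to the `per_example_xent` the stack returns.
  per_example_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=class_ids, logits=logits)
  # Weighted aggregation, mirroring the tail of FProp above.
  total_xent = tf.reduce_sum(per_example_xent * class_weights)
  total_weights = tf.reduce_sum(class_weights)
  return total_xent / total_weights  # `avg_xent` over non-padded tokens.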
def Compute(x):
  # Evaluates py_utils.ArgMax on a constant input in a fresh graph/session
  # (TF1-style test helper), returning both the input and the argmax.
  with self.session(graph=tf.Graph()) as sess:
    x = tf.constant(x)
    y = py_utils.ArgMax(x)
    return sess.run([x, y])
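# Hypothetical usage of Compute (a sketch, not from the original test):
# assuming py_utils.ArgMax reduces over the last dimension, as its use on
# [time, batch, vocab_size] logits in FProp suggests, its output should match
# numpy's argmax. Tie-breaking is assumed to agree with np.argmax, and a
# module-level `import numpy as np` is assumed.
inputs = np.random.rand(4, 3, 7).astype(np.float32)
x_val, y_val = Compute(inputs)
self.assertAllEqual(np.argmax(x_val, axis=-1), y_val)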