Exemplo n.º 1
0
    def get_positional_index(self) -> tf.Tensor:
        """Returns the positional index of every element of this LeafNodeTensor.

        An element's positional index is its position among the elements that
        share the same parent.

        For example, given the parent indices [0, 0, 2], the positional index
        is:
        [
          0, # The 0th element of the 0th parent.
          1, # The 1st element of the 0th parent.
          0  # The 0th element of the 2nd parent.
        ].

        See ops/run_length_before_op.cc for details on the underlying op.

        Child NodeTensors compute this the same way.

        Returns:
          A tensor of positional indices.
        """
        # The op derives the positions purely from the parent indices.
        parent_indices = self.parent_index
        return struct2tensor_ops.run_length_before(parent_indices)
Exemplo n.º 2
0
def get_positional_index(node: prensor.NodeTensor) -> tf.Tensor:
    """Returns the positional index of each element of `node`.

    Leaf and child nodes carry a `parent_index`; for them the result is
    `run_length_before` applied to that parent index. Any other node is
    treated as a root node, and the result is `tf.range(node.size)`.

    Args:
      node: The node tensor to compute positional indices for.

    Returns:
      A tensor of positional indices.
    """
    has_parent_index = isinstance(
        node, (prensor.LeafNodeTensor, prensor.ChildNodeTensor)
    )
    if not has_parent_index:
        # RootNodeTensor
        return tf.range(node.size)
    return struct2tensor_ops.run_length_before(node.parent_index)
Exemplo n.º 3
0
 def test_run_length_before_empty(self):
     """Checks that run_length_before maps an empty int64 tensor to empty."""
     empty_parent_index = tf.constant([], dtype=tf.int64)
     self.assertAllEqual(
         struct2tensor_ops.run_length_before(empty_parent_index), []
     )