def get_positional_index(self) -> tf.Tensor:
    """Returns, for each element, its position among its parent's children.

    Elements that share a parent are numbered 0, 1, 2, ... in the order in
    which they appear in `parent_index`. For example, the parent indices

        [0, 0, 2]

    produce the positional index

        [0,  # first element under parent 0
         1,  # second element under parent 0
         0]  # first element under parent 2

    The work is delegated to the `run_length_before` op; see
    ops/run_length_before_op.cc for its exact semantics (NOTE(review): the
    example above uses non-decreasing parent indices — confirm whether the op
    requires that ordering). ChildNodeTensors compute their positional index
    in the same way.

    Returns:
      A tensor of positional indices, one per entry of `parent_index`.
    """
    parent_indices = self.parent_index
    return struct2tensor_ops.run_length_before(parent_indices)
def get_positional_index(node: prensor.NodeTensor) -> tf.Tensor:
    """Computes the positional index of every element of a NodeTensor.

    For a LeafNodeTensor or ChildNodeTensor, the positional index of an element
    is its ordinal among the siblings sharing the same parent, computed from
    `node.parent_index` by the `run_length_before` op. Any other node is
    treated as the root: its elements are simply numbered 0 .. size-1.

    Args:
      node: The node whose positional index is wanted.

    Returns:
      A tensor of positional indices for the elements of `node`.
    """
    has_parent = isinstance(
        node, (prensor.LeafNodeTensor, prensor.ChildNodeTensor))
    if not has_parent:
        # Anything without a parent_index falls through to here; presumably
        # this is always a RootNodeTensor — confirm no other NodeTensor
        # subclasses exist.
        return tf.range(node.size)
    return struct2tensor_ops.run_length_before(node.parent_index)
def test_run_length_before_empty(self):
    """Checks that run_length_before maps an empty int64 input to an empty output."""
    # Edge case: no indices at all, so there are no runs to count.
    empty_indices = tf.constant([], dtype=tf.int64)
    result = struct2tensor_ops.run_length_before(empty_indices)
    self.assertAllEqual(result, [])