Пример #1
0
def _get_leaf_node_path(p: path.Path, t: prensor.Prensor) -> _LeafNodePath:
    """Builds the _LeafNodePath that leads from the root of t down to p.

    Args:
        p: path of the leaf inside t; must not be the root (empty) path.
        t: the prensor that contains the leaf.

    Returns:
        A _LeafNodePath made of the root node tensor of t, the
        ChildNodeTensor of every strict ancestor of p (in order of
        increasing prefix length), and the leaf node at p.

    Raises:
        ValueError: if the node at p is not a LeafNodeTensor, if p is the
            root path, or if some strict ancestor of p is not a
            ChildNodeTensor. A ValueError is also expected from
            _as_root_node_tensor when the root of t is a leaf (not verified
            here).
    """
    leaf = t.get_descendant_or_error(p).node
    if not isinstance(leaf, prensor.LeafNodeTensor):
        raise ValueError("Expected Leaf Node at {} in {}".format(
            str(p), str(t)))
    if not p:
        raise ValueError("Leaf should not be at the root")
    # Expected to raise a ValueError if the root of t is a leaf.
    root = _as_root_node_tensor(t.node)

    # Look up every strict ancestor (neither the root nor p itself) before
    # validating any of them, so lookup failures surface first.
    ancestors = []
    for depth in range(1, len(p)):
        ancestor_path = p.prefix(depth)
        ancestors.append(
            (ancestor_path, t.get_descendant_or_error(ancestor_path).node))

    # Split the ancestors into valid struct nodes and offending paths; the
    # isinstance check also narrows the element type for type checkers.
    intermediate_nodes = []
    non_struct_paths = []
    for ancestor_path, node in ancestors:
        if isinstance(node, prensor.ChildNodeTensor):
            intermediate_nodes.append(node)
        else:
            non_struct_paths.append(ancestor_path)

    if non_struct_paths:
        raise ValueError("Expected ChildNodeTensor at {} in {}".format(
            " ".join(str(x) for x in non_struct_paths), str(t)))
    return _LeafNodePath(root, intermediate_nodes, leaf)
Пример #2
0
def create_expression_from_prensor(
        t: prensor.Prensor) -> expression.Expression:
    """Wraps a prensor in an expression tree that mirrors its structure.

    Every node of t becomes a _DirectExpression. The tree is built bottom-up,
    so each expression is created with the expressions of its children.

    Args:
        t: The prensor to represent.

    Returns:
        An expression representing the prensor.
    """
    node = t.node
    sub_expressions = {}
    for step, child in t.get_children().items():
        sub_expressions[step] = create_expression_from_prensor(child)

    if isinstance(node, prensor.RootNodeTensor):
        # The root is always marked repeated and carries no dtype.
        repeated, value_dtype = True, None
    elif isinstance(node, prensor.ChildNodeTensor):
        # Struct nodes carry no values, hence no dtype.
        repeated, value_dtype = node.is_repeated, None
    else:
        # Any other node is assumed to be a LeafNodeTensor.
        repeated, value_dtype = node.is_repeated, node.values.dtype
    return _DirectExpression(repeated, value_dtype, node, sub_expressions)
Пример #3
0
def _get_leaf_node_paths(
        t: prensor.Prensor) -> Mapping[path.Path, _LeafNodePath]:
    """Maps the path of every leaf in a prensor to its _LeafNodePath.

    Args:
        t: the prensor to scan.

    Returns:
        A mapping from each descendant path whose node is a LeafNodeTensor
        to the _LeafNodePath that _get_leaf_node_path builds for it, in the
        iteration order of t.get_descendants().
    """
    leaf_paths = {}
    for descendant_path, descendant in t.get_descendants().items():
        if not isinstance(descendant.node, prensor.LeafNodeTensor):
            continue
        leaf_paths[descendant_path] = _get_leaf_node_path(descendant_path, t)
    return leaf_paths
Пример #4
0
def _prensor_to_structured_tensor_helper(
    p: prensor.Prensor, nrows: tf.Tensor
) -> Union[tf.RaggedTensor, structured_tensor.StructuredTensor]:
    """Converts a non-root prensor into a tensor with a given number of rows.

    A leaf node is converted with _leaf_node_to_ragged_tensor. A child
    (struct) node is converted with _child_node_to_structured_tensor, using
    the field map that _prensor_to_field_map builds from the node's children.

    Args:
        p: a prensor whose node is a LeafNodeTensor or a ChildNodeTensor.
        nrows: the number of rows of the result.

    Returns:
        A tf.RaggedTensor for a leaf node, otherwise a StructuredTensor.

    Raises:
        ValueError: if the node of p is neither a LeafNodeTensor nor a
            ChildNodeTensor (for example, a RootNodeTensor).
    """
    node = p.node
    if isinstance(node, prensor.LeafNodeTensor):
        return _leaf_node_to_ragged_tensor(node, nrows)
    # An explicit check instead of `assert`: it stays active under
    # `python -O` and still narrows the type of `node` for type checkers.
    if not isinstance(node, prensor.ChildNodeTensor):
        raise ValueError(
            "Expected a LeafNodeTensor or ChildNodeTensor, got {}".format(
                type(node).__name__))
    # The children's field map is built with this node's size as its row
    # count; nrows applies to the struct node itself.
    return _child_node_to_structured_tensor(
        node, _prensor_to_field_map(p.get_children(), node.size), nrows)
Пример #5
0
 def new_op(tree: prensor.Prensor,
            options: calculate_options.Options) -> prensor.LeafNodeTensor:
     """Applies `operation` to the ragged tensors found at `paths`.

     `paths`, `operation`, `is_repeated` and `dtype` come from the enclosing
     scope.

     Args:
         tree: the prensor to read the input ragged tensors from.
         options: calculation options, forwarded to tree.get_ragged_tensors
             and to _ragged_as_leaf_node.

     Returns:
         The output of `operation`, converted into a LeafNodeTensor.

     Raises:
         ValueError: if the dtype of the result's values differs from
             `dtype`.
     """
     by_path = tree.get_ragged_tensors(options)
     inputs = [by_path[input_path] for input_path in paths]
     output = operation(*inputs)
     # The first input is passed as the reference ragged tensor for the
     # conversion (presumably for its row structure; TODO confirm).
     leaf = _ragged_as_leaf_node(output, is_repeated, inputs[0], options)
     actual_dtype = leaf.values.dtype
     if actual_dtype != dtype:
         raise ValueError(
             "Type unmatched: actual ({})!= expected ({})".format(
                 str(actual_dtype), str(dtype)))
     return leaf
Пример #6
0
 def new_op(pren: prensor.Prensor,
            options: calculate_options.Options) -> prensor.LeafNodeTensor:
     """Applies `operation` to the sparse tensors found at `paths`.

     `paths`, `operation`, `is_repeated` and `dtype` come from the enclosing
     scope.

     Args:
         pren: the prensor to read the input sparse tensors from.
         options: calculation options, forwarded to pren.get_sparse_tensors
             and to _as_leaf_node.

     Returns:
         The output of `operation`, converted into a LeafNodeTensor.

     Raises:
         ValueError: if the dtype of the result's values differs from
             `dtype`.
     """
     by_path = pren.get_sparse_tensors(options)
     inputs = [by_path[input_path] for input_path in paths]
     output = operation(*inputs)
     # Leading entry of the first input's dense_shape, passed to the
     # conversion helper.
     leading_dim = inputs[0].dense_shape[0]
     leaf = _as_leaf_node(output, is_repeated, leading_dim, options)
     actual_dtype = leaf.values.dtype
     if actual_dtype != dtype:
         raise ValueError(
             "Type unmatched: actual ({})!= expected ({})".format(
                 str(actual_dtype), str(dtype)))
     return leaf
Пример #7
0
def prensor_to_structured_tensor(
        p: prensor.Prensor) -> structured_tensor.StructuredTensor:
    """Converts a root prensor into a StructuredTensor.

    Whether fields are optional or repeated is not preserved. If the field
    names in the proto do not satisfy the naming rules of StructuredTensor,
    the behavior is undefined.

    Args:
        p: the prensor to convert; its node must be a RootNodeTensor.

    Returns:
        An equivalent StructuredTensor.

    Raises:
        ValueError: if the root of the prensor is not a RootNodeTensor.
    """
    root = p.node
    if not isinstance(root, prensor.RootNodeTensor):
        raise ValueError("Must be a root prensor")
    # The root's size is used as the row count for the top-level fields.
    field_map = _prensor_to_field_map(p.get_children(), root.size)
    return _root_node_to_structured_tensor(field_map)