def _get_leaf_node_path(p: path.Path, t: prensor.Prensor) -> _LeafNodePath:
  """Builds the _LeafNodePath describing the route from the root of t to p.

  Args:
    p: path of the leaf inside t.
    t: the prensor that contains the leaf.

  Returns:
    A _LeafNodePath holding the root node tensor, the node tensor of every
    strict ancestor of p (shallowest first), and the leaf node tensor.

  Raises:
    ValueError: if the node at p is not a LeafNodeTensor, if p is the empty
      path, or if any strict ancestor of p is not a ChildNodeTensor.
  """
  leaf = t.get_descendant_or_error(p).node
  if not isinstance(leaf, prensor.LeafNodeTensor):
    raise ValueError("Expected Leaf Node at {} in {}".format(str(p), str(t)))
  if not p:
    raise ValueError("Leaf should not be at the root")
  # Per the original note, this raises a ValueError if the root is a leaf.
  root = _as_root_node_tensor(t.node)

  # Proper, non-empty prefixes of p: excludes both the root and p itself.
  ancestor_paths = [p.prefix(depth) for depth in range(1, len(p))]
  ancestors = [(ancestor_path, t.get_descendant_or_error(ancestor_path).node)
               for ancestor_path in ancestor_paths]

  non_child_paths = [
      ancestor_path for ancestor_path, node in ancestors
      if not isinstance(node, prensor.ChildNodeTensor)
  ]
  if non_child_paths:
    raise ValueError("Expected ChildNodeTensor at {} in {}".format(
        " ".join(str(x) for x in non_child_paths), str(t)))

  # Every ancestor passed the check above; this filter only narrows the type.
  child_nodes = [
      node for _, node in ancestors
      if isinstance(node, prensor.ChildNodeTensor)
  ]
  assert len(child_nodes) == len(ancestors)
  return _LeafNodePath(root, child_nodes, leaf)
def create_expression_from_prensor(
    t: prensor.Prensor) -> expression.Expression:
  """Recursively wraps a prensor in an expression.

  Args:
    t: The prensor to represent.

  Returns:
    An expression mirroring t. Each node of t becomes a _DirectExpression
    whose children are built recursively from t's children.
  """
  node_tensor = t.node
  sub_expressions = {}
  for child_name, child in t.get_children().items():
    sub_expressions[child_name] = create_expression_from_prensor(child)

  if isinstance(node_tensor, prensor.RootNodeTensor):
    # The root is always treated as repeated and carries no value dtype.
    is_repeated, value_dtype = True, None
  elif isinstance(node_tensor, prensor.ChildNodeTensor):
    is_repeated, value_dtype = node_tensor.is_repeated, None
  else:
    # Anything else is assumed to be a LeafNodeTensor, which has values.
    is_repeated = node_tensor.is_repeated
    value_dtype = node_tensor.values.dtype
  return _DirectExpression(is_repeated, value_dtype, node_tensor,
                           sub_expressions)
def _get_leaf_node_paths(
    t: prensor.Prensor) -> Mapping[path.Path, _LeafNodePath]:
  """Maps the path of every leaf in the prensor t to its _LeafNodePath."""
  leaf_paths = {}
  for descendant_path, descendant in t.get_descendants().items():
    if isinstance(descendant.node, prensor.LeafNodeTensor):
      leaf_paths[descendant_path] = _get_leaf_node_path(descendant_path, t)
  return leaf_paths
def _prensor_to_structured_tensor_helper(
    p: prensor.Prensor, nrows: tf.Tensor
) -> Union[tf.RaggedTensor, structured_tensor.StructuredTensor]:
  """Converts a non-root prensor into a value with `nrows` rows.

  A leaf node becomes a ragged tensor. Any other node must be a
  ChildNodeTensor; it becomes a structured tensor built from its children.
  """
  node = p.node
  if not isinstance(node, prensor.LeafNodeTensor):
    # Internal invariant: non-root, non-leaf nodes are child nodes.
    assert isinstance(node, prensor.ChildNodeTensor)
    field_map = _prensor_to_field_map(p.get_children(), node.size)
    return _child_node_to_structured_tensor(node, field_map, nrows)
  return _leaf_node_to_ragged_tensor(node, nrows)
def new_op(tree: prensor.Prensor,
           options: calculate_options.Options) -> prensor.LeafNodeTensor:
  """Applies `operation` to the ragged tensors at `paths` and wraps the result.

  `paths`, `operation`, `is_repeated` and `dtype` are free variables,
  presumably bound by the enclosing function (not visible here).

  Args:
    tree: the prensor whose ragged tensors are read.
    options: calculation options forwarded to the ragged conversion.

  Returns:
    The leaf node tensor built from the operation's result.

  Raises:
    ValueError: if the result's values dtype differs from `dtype`.
  """
  ragged_by_path = tree.get_ragged_tensors(options)
  inputs = [ragged_by_path[input_path] for input_path in paths]
  raw_output = operation(*inputs)
  # The first input is passed along as the reference for the output's shape.
  leaf = _ragged_as_leaf_node(raw_output, is_repeated, inputs[0], options)
  actual_dtype = leaf.values.dtype
  if actual_dtype != dtype:
    raise ValueError(
        "Type unmatched: actual ({})!= expected ({})".format(
            str(actual_dtype), str(dtype)))
  return leaf
def new_op(pren: prensor.Prensor,
           options: calculate_options.Options) -> prensor.LeafNodeTensor:
  """Maps the sparse tensors at `paths` through `operation` into a leaf node.

  `paths`, `operation`, `is_repeated` and `dtype` are free variables,
  presumably bound by the enclosing function (not visible here).

  Args:
    pren: the prensor whose sparse tensors are read.
    options: calculation options forwarded to the leaf conversion.

  Returns:
    The leaf node tensor built from the operation's result.

  Raises:
    ValueError: if the result's values dtype differs from `dtype`.
  """
  sparse_by_path = pren.get_sparse_tensors(options)
  inputs = [sparse_by_path[input_path] for input_path in paths]
  raw_output = operation(*inputs)
  # The leading dense dimension of the first input gives the parent size.
  parent_size = inputs[0].dense_shape[0]
  leaf = _as_leaf_node(raw_output, is_repeated, parent_size, options)
  actual_dtype = leaf.values.dtype
  if actual_dtype != dtype:
    raise ValueError(
        "Type unmatched: actual ({})!= expected ({})".format(
            str(actual_dtype), str(dtype)))
  return leaf
def prensor_to_structured_tensor(
    p: prensor.Prensor) -> structured_tensor.StructuredTensor:
  """Converts a root prensor into an equivalent StructuredTensor.

  Information about which fields are optional or repeated is not kept. If
  the field names do not satisfy StructuredTensor's naming requirements, the
  behavior is undefined.

  Args:
    p: the prensor to convert; its node must be a RootNodeTensor.

  Returns:
    An equivalent StructuredTensor.

  Raises:
    ValueError: if the root of the prensor is not a RootNodeTensor.
  """
  root = p.node
  if not isinstance(root, prensor.RootNodeTensor):
    raise ValueError("Must be a root prensor")
  field_map = _prensor_to_field_map(p.get_children(), root.size)
  return _root_node_to_structured_tensor(field_map)