def map_many_values(
    root: expression.Expression, parent_path: path.Path,
    source_fields: Sequence[path.Step], operation: Callable[..., tf.Tensor],
    dtype: tf.DType,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Maps several sibling fields into one new sibling field.

  Every source field must have the same shape, and the values produced by
  `operation` must have that shape as well.

  Args:
    root: the original root expression.
    parent_path: the common parent of every source field and of the new field.
    source_fields: names of the sibling fields handed to `operation`, in order.
      They must all share one shape.
    operation: function from the source field values to the new field values.
    dtype: dtype of the new field.
    new_field_name: name of the new sibling field.

  Returns:
    A pair of (new root expression, path of the new field).
  """
  new_path = parent_path.get_child(new_field_name)
  sources = [
      root.get_descendant_or_error(parent_path.get_child(field_name))
      for field_name in source_fields
  ]
  mapped = _MapValuesExpression(sources, operation, dtype)
  new_root = expression_add.add_paths(root, {new_path: mapped})
  return new_root, new_path
def _broadcast_impl(
    root: expression.Expression, origin: path.Path, sibling: path.Step,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Broadcasts the field at `origin` into its sibling named `sibling`.

  The broadcast result is added as a child of the sibling, so the new field
  lives at `<parent of origin>.<sibling>.<new_field_name>`.

  Args:
    root: the root expression.
    origin: path of the field to broadcast.
    sibling: name of the sibling of `origin` to broadcast into.
    new_field_name: name of the new child created under the sibling.

  Returns:
    A pair of (new root expression, path of the new field).
  """
  sibling_path = origin.get_parent().get_child(sibling)
  origin_expr = root.get_descendant_or_error(origin)
  sibling_expr = root.get_descendant_or_error(sibling_path)
  broadcast_expr = _BroadcastExpression(origin_expr, sibling_expr)
  new_path = sibling_path.get_child(new_field_name)
  return expression_add.add_paths(root, {new_path: broadcast_expr}), new_path
def _promote_impl(
    root: expression.Expression, p: path.Path,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Promotes the field at `p` to be a child of its grandparent.

  A leaf at `p` is wrapped in a PromoteExpression; any other node is wrapped
  in a PromoteChildExpression.

  Args:
    root: the root expression.
    p: path to promote; may point at a leaf or at a child node.
    new_field_name: name of the promoted field under the grandparent.

  Returns:
    A pair of (new root expression, path of the promoted field).

  Raises:
    ValueError: if `p` has fewer than two steps, since there is then no
      grandparent to promote into.
  """
  if len(p) < 2:
    raise ValueError("Cannot do a promotion beyond the root: {}".format(
        str(p)))
  parent_path = p.get_parent()
  new_path = parent_path.get_parent().get_child(new_field_name)
  p_expr = root.get_descendant_or_error(p)
  factory = PromoteExpression if p_expr.is_leaf else PromoteChildExpression
  promoted = factory(p_expr, root.get_descendant_or_error(parent_path))
  return expression_add.add_paths(root, {new_path: promoted}), new_path
def slice_expression(expr: expression.Expression, p: path.Path,
                     new_field_name: path.Step, begin: Optional[IndexValue],
                     end: Optional[IndexValue]) -> expression.Expression:
  """Creates a new sibling subtree holding a slice of the field at `p`.

  The slice follows the pattern of Python's slice(begin, end), applied to
  each parent's list of `p` elements. See the module-level comments for
  examples.

  Args:
    expr: the original root expression.
    p: path of the field to slice.
    new_field_name: name of the new sibling holding the sliced field.
    begin: beginning index, or None.
    end: end index, or None.

  Returns:
    A new root expression: `expr` plus the sliced subtree at
    `p.get_parent().get_child(new_field_name)`.
  """
  work_expr, mask_path = _get_slice_mask(expr, p, begin, end)
  mask_field = mask_path.field_list[-1]
  work_expr = filter_expression.filter_by_sibling(work_expr, p, mask_field,
                                                  new_field_name)
  result_path = p.get_parent().get_child(new_field_name)
  # work_expr carries several anonymous helper fields; graft only the final
  # sliced subtree (with its children) back onto the original root.
  return expression_add.add_to(expr, {result_path: work_expr})
def filter_by_sibling(expr: expression.Expression, p: path.Path,
                      sibling_field_name: path.Step,
                      new_field_name: path.Step) -> expression.Expression:
  """Filters the field at `p` using a sibling field as the mask.

  This is similar to boolean_mask. The field being filtered and the mask
  sibling must have identical shapes: each parent must hold as many mask
  children as source children.

  Args:
    expr: the root expression.
    p: path of the field to filter.
    sibling_field_name: name of the sibling used as the mask.
    new_field_name: name of the new filtered sibling to create.

  Returns:
    A new root expression.
  """
  parent_path = p.get_parent()
  source_expr = expr.get_descendant_or_error(p)
  mask_expr = expr.get_descendant_or_error(
      parent_path.get_child(sibling_field_name))
  filtered = _FilterBySiblingExpression(source_expr, mask_expr)
  return expression_add.add_paths(
      expr, {parent_path.get_child(new_field_name): filtered})
def promote_and_broadcast(
    root: expression.Expression, path_dictionary: Mapping[path.Step, path.Path],
    dest_path_parent: path.Path) -> expression.Expression:
  """Promotes and broadcasts a set of paths to a common destination parent.

  Args:
    root: the original expression.
    path_dictionary: map from destination field name to origin path.
    dest_path_parent: path under which every destination field is created.

  Returns:
    A new expression in which each origin path has been promoted and
    broadcast until it is a child of `dest_path_parent`, under its
    destination field name.
  """
  # Each field is promoted/broadcast in its own branch of the expression
  # tree; the final add_to collects the generated fields onto `root`.
  branches = {
      dest_path_parent.get_child(field_name): _promote_and_broadcast_name(
          root, origin_path, dest_path_parent, field_name)
      for field_name, origin_path in path_dictionary.items()
  }
  return expression_add.add_to(root, branches)
def _get_leaf_node_path(p: path.Path, t: prensor.Prensor) -> _LeafNodePath:
  """Builds the _LeafNodePath from the root of `t` down to the leaf at `p`.

  Args:
    p: path of a leaf in `t`; must not be the root.
    t: the prensor to walk.

  Returns:
    A _LeafNodePath of (root node, the ChildNodeTensors strictly between the
    root and `p`, the LeafNodeTensor at `p`).

  Raises:
    ValueError: if the node at `p` is not a LeafNodeTensor, if `p` is the
      root, or if any strict non-root ancestor of `p` is not a
      ChildNodeTensor.
  """
  leaf = t.get_descendant_or_error(p).node
  if not isinstance(leaf, prensor.LeafNodeTensor):
    raise ValueError("Expected Leaf Node at {} in {}".format(
        str(p), str(t)))
  if not p:
    raise ValueError("Leaf should not be at the root")
  # NOTE(review): _as_root_node_tensor is presumably what raises ValueError
  # if the root of t is a leaf; its definition is not visible here.
  root_node = _as_root_node_tensor(t.node)
  # Walk every non-empty proper prefix of p: excludes the root and p itself.
  pairs = []
  for length in range(1, len(p)):
    ancestor = p.prefix(length)
    pairs.append((t.get_descendant_or_error(ancestor).node, ancestor))
  bad_ancestors = [
      ancestor for node, ancestor in pairs
      if not isinstance(node, prensor.ChildNodeTensor)
  ]
  if bad_ancestors:
    raise ValueError("Expected ChildNodeTensor at {} in {}".format(
        " ".join([str(x) for x in bad_ancestors]), str(t)))
  # After the check above every node passes this filter; the isinstance test
  # only narrows the element type for static type checkers.
  child_nodes = [
      node for node, _ in pairs if isinstance(node, prensor.ChildNodeTensor)
  ]
  assert len(child_nodes) == len(pairs)
  return _LeafNodePath(root_node, child_nodes, leaf)
def get_index_from_end(
    t: expression.Expression, source_path: path.Path,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Computes each element's position counted from the end of its list.

  For a list ["a", "b", "c"] with positional indices [0, 1, 2], the result
  is [-3, -2, -1].

  Args:
    t: the original expression.
    source_path: path of the field whose index is computed.
    new_field_name: name of the new sibling field.

  Returns:
    A pair of (new expression, path of the new field).
  """
  new_path = source_path.get_parent().get_child(new_field_name)
  work_expr, index_path = get_positional_index(t, source_path,
                                               path.get_anonymous_field())
  work_expr, size_path = size.size_anonymous(work_expr, source_path)
  from_end = _PositionalIndexFromEndExpression(
      work_expr.get_descendant_or_error(index_path),
      work_expr.get_descendant_or_error(size_path))
  work_expr = expression_add.add_paths(work_expr, {new_path: from_end})
  # Graft only the new field onto t, dropping the anonymous positional-index
  # and size helper fields.
  return expression_add.add_to(t, {new_path: work_expr}), new_path
def _promote_and_broadcast_name(
    root: expression.Expression, origin: path.Path, dest_path_parent: path.Path,
    field_name: path.Step) -> expression.Expression:
  """Moves `origin` under `dest_path_parent` with the name `field_name`.

  The field is first promoted/broadcast to an anonymous child of
  `dest_path_parent`, then that child is also added under `field_name`.

  Args:
    root: the original expression.
    origin: path of the field to move.
    dest_path_parent: the destination parent path.
    field_name: name of the field under `dest_path_parent`.

  Returns:
    A new root expression containing the named field.
  """
  work_expr, anonymous_path = promote_and_broadcast_anonymous(
      root, origin, dest_path_parent)
  target_path = dest_path_parent.get_child(field_name)
  moved = work_expr.get_descendant_or_error(anonymous_path)
  return expression_add.add_paths(work_expr, {target_path: moved})
def _broadcast_impl(
    root: expression.Expression, origin: path.Path, sibling: path.Step,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Broadcasts `origin` into its sibling `sibling` within an expression.

  A leaf origin is wrapped in a _BroadcastExpression; any other node is
  wrapped in a _BroadcastChildExpression. The result becomes a child of the
  sibling named `new_field_name`.

  Args:
    root: the root expression.
    origin: path of the field to broadcast.
    sibling: name of the sibling of `origin` to broadcast into.
    new_field_name: name of the new child created under the sibling.

  Returns:
    A pair of (new root expression, path of the new field).
  """
  sibling_path = origin.get_parent().get_child(sibling)
  origin_expr = root.get_descendant_or_error(origin)
  if origin_expr.is_leaf:
    factory = _BroadcastExpression
  else:
    factory = _BroadcastChildExpression
  new_expr = factory(origin_expr, root.get_descendant_or_error(sibling_path))
  new_path = sibling_path.get_child(new_field_name)
  new_root = expression_add.add_paths(root, {new_path: new_expr})
  return new_root, new_path
def _get_slice_mask(
    expr: expression.Expression, p: path.Path, begin: Optional[IndexValue],
    end: Optional[IndexValue]) -> Tuple[expression.Expression, path.Path]:
  """Builds a boolean mask selecting the slice [begin, end) of the field at p.

  Think of the elements at a path such as "foo.bar" as nested lists: every
  parent holds a list of its children at `p`. The slice is applied to each
  such list independently, based on each element's position in that list.

  For a given list of size n:
  1. If begin is None, there is no lower bound (as if begin_index were 0).
  2. If begin is non-negative, begin_index is begin; if negative, it is
     n + begin.
  3. If end is None, there is no upper bound (as if end_index were n).
  4. If end is non-negative, end_index is end; if negative, it is n + end.

  The mask is True for elements whose position is in
  range(begin_index, end_index) and False elsewhere. It is an anonymous
  sibling of `p` with exactly one entry per element of `p`.

  Args:
    expr: the root expression.
    p: the path to be sliced.
    begin: the beginning index, or None.
    end: the ending index, or None.

  Returns:
    An (expression, path) pair: the expression contains everything in `expr`
    plus the anonymous mask field, and the path points at the mask field.

  Raises:
    ValueError: if both begin and end are None.
  """
  if begin is None and end is None:
    raise ValueError("Must specify begin or end.")
  if begin is None:
    return _get_end_mask(expr, p, end)
  if end is None:
    return _get_begin_mask(expr, p, begin)
  work_expr, begin_mask_path = _get_begin_mask(expr, p, begin)
  work_expr, end_mask_path = _get_end_mask(work_expr, p, end)
  mask_fields = [begin_mask_path.field_list[-1], end_mask_path.field_list[-1]]
  # An element is kept only if it satisfies both the lower and upper bound.
  return map_values.map_many_values(work_expr, p.get_parent(), mask_fields,
                                    tf.logical_and, tf.bool,
                                    path.get_anonymous_field())
def _get_mask(
    t: expression.Expression, p: path.Path, threshold: IndexValue,
    relation: Callable[[tf.Tensor, IndexValue], tf.Tensor]
) -> Tuple[expression.Expression, path.Path]:
  """Gets a mask based on a relation of each element's index to a threshold.

  If the threshold is non-negative, the mask is true where
  relation(index, threshold) holds, where index is the element's position
  within its parent list (0, 1, 2, ...).
  If the threshold is negative, the mask is true where
  relation(index - size, threshold) holds, i.e. the index is counted from the
  end of the parent list (-size, ..., -2, -1).

  When `threshold` is a Python int, the choice between the two masks is made
  while building the expression. Otherwise both masks are combined with
  tf.cond on the sign of `threshold`.

  Args:
    t: expression to add the mask field to.
    p: path to create the mask for.
    threshold: the cutoff threshold.
    relation: tf.less or tf.greater_equal.

  Returns:
    A pair of (new expression, path of the anonymous boolean mask field,
    which is a sibling of p).

  Raises:
    ValueError: if p is not in t.
  """
  if t.get_descendant(p) is None:
    raise ValueError("Path not found: {}".format(str(p)))
  work_expr, index_from_end = index.get_index_from_end(
      t, p, path.get_anonymous_field())
  work_expr, mask_for_negative_threshold = map_values.map_values_anonymous(
      work_expr, index_from_end,
      lambda x: relation(x, tf.cast(threshold, tf.int64)), tf.bool)
  work_expr, positional_index = index.get_positional_index(
      work_expr, p, path.get_anonymous_field())
  work_expr, mask_for_non_negative_threshold = map_values.map_values_anonymous(
      work_expr, positional_index,
      lambda x: relation(x, tf.cast(threshold, tf.int64)), tf.bool)

  if isinstance(threshold, int):
    # Static threshold: pick the right mask at expression-construction time.
    if threshold >= 0:
      return work_expr, mask_for_non_negative_threshold
    return work_expr, mask_for_negative_threshold

  def tf_cond_on_threshold(non_negative_mask, negative_mask):
    # tf.cond requires callables for both branches; passing the mask tensors
    # directly raises TypeError, so wrap each one in a lambda.
    return tf.cond(
        tf.greater_equal(threshold, 0),
        lambda: non_negative_mask,
        lambda: negative_mask)

  return map_values.map_many_values(work_expr, p.get_parent(), [
      x.field_list[-1]
      for x in [mask_for_non_negative_threshold, mask_for_negative_threshold]
  ], tf_cond_on_threshold, tf.bool, path.get_anonymous_field())
def _map_prensor_impl(
    root: expression.Expression, root_path: path.Path,
    paths_needed: Sequence[path.Path],
    operation: Callable[[prensor.Prensor, calculate_options.Options],
                        prensor.LeafNodeTensor], is_repeated: bool,
    dtype: tf.DType,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Adds a field computed by `operation` over a projected subtree.

  The subtree at `root_path` is projected down to `paths_needed`, and the
  result of `operation` on it becomes a new child of `root_path`.

  Args:
    root: the root expression.
    root_path: path of the subtree the operation reads; also the parent of
      the new field.
    paths_needed: paths (within the subtree) kept by the projection.
    operation: maps the projected prensor and calculation options to a leaf.
    is_repeated: passed through to _MapPrensorExpression.
    dtype: dtype of the new field.
    new_field_name: name of the new child of `root_path`.

  Returns:
    A pair of (new root expression, path of the new field).
  """
  subtree = root.get_descendant_or_error(root_path)
  projected = project.project(subtree, paths_needed)
  mapped = _MapPrensorExpression(projected, operation, is_repeated, dtype)
  new_path = root_path.get_child(new_field_name)
  new_root = expression_add.add_paths(root, {new_path: mapped})
  return new_root, new_path
def test_cmp(self):
  """Checks ordering and equality of paths."""
  # A path with an integer step compares greater than one with a string step,
  # and integer steps compare numerically.
  int_path = Path([1])
  self.assertGreater(int_path, Path("foo"))
  self.assertLess(Path("foo"), int_path)
  self.assertGreater(int_path, Path([0]))
  # A longer path compares greater than its own prefix.
  foo = create_path("foo")
  foo_bar = create_path("foo.bar")
  foo_baz = create_path("foo.baz")
  self.assertGreater(foo_baz, foo)
  # Sibling paths compare by their final step.
  self.assertGreater(foo_baz, foo_bar)
  self.assertLess(foo, foo_bar)
  self.assertLess(foo_bar, foo_baz)
  self.assertEqual(foo_baz, create_path("foo.baz"))
def _size_impl(
    root: expression.Expression, source_path: path.Path,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Adds a sibling of `source_path` holding a SizeExpression for it.

  Args:
    root: the root expression.
    source_path: path of the field whose size is computed; not the root.
    new_field_name: name of the new sibling field.

  Returns:
    A pair of (new root expression, path of the size field).

  Raises:
    ValueError: if source_path is the root or is not present in root.
  """
  if not source_path:
    raise ValueError("Cannot get the size of the root.")
  if root.get_descendant(source_path) is None:
    raise ValueError("Path not found: {}".format(str(source_path)))
  parent_path = source_path.get_parent()
  size_path = parent_path.get_child(new_field_name)
  size_expr = SizeExpression(root.get_descendant_or_error(source_path),
                             root.get_descendant_or_error(parent_path))
  return expression_add.add_paths(root, {size_path: size_expr}), size_path
def _promote_impl(root: expression.Expression, p: path.Path,
                  new_field_name: path.Step
                 ) -> Tuple[expression.Expression, path.Path]:
  """Promotes the field at `p` to a child of its grandparent via PromoteExpression.

  Args:
    root: the root expression.
    p: path of the field to promote.
    new_field_name: name of the promoted field under the grandparent.

  Returns:
    A pair of (new root expression, path of the promoted field).

  Raises:
    ValueError: if `p` has fewer than two steps.
  """
  if len(p) < 2:
    raise ValueError("Cannot do a promotion beyond the root: {}".format(str(p)))
  parent_path = p.get_parent()
  destination = parent_path.get_parent().get_child(new_field_name)
  promoted = PromoteExpression(
      root.get_descendant_or_error(p),
      root.get_descendant_or_error(parent_path))
  return expression_add.add_paths(root, {destination: promoted}), destination
def promote_and_broadcast_anonymous(
    root: expression.Expression, origin: path.Path,
    new_parent: path.Path) -> Tuple[expression.Expression, path.Path]:
  """Promotes then broadcasts `origin` until its parent is `new_parent`.

  Args:
    root: the root expression.
    origin: path of the field to move.
    new_parent: the desired parent path.

  Returns:
    A pair of (new expression, path of the anonymous moved field).
  """
  common_ancestor = origin.get_least_common_ancestor(new_parent)
  work_expr = root
  work_path = origin
  # Phase 1: promote upward until the field hangs off the common ancestor.
  while work_path.get_parent() != common_ancestor:
    work_expr, work_path = promote.promote_anonymous(work_expr, work_path)
  # Phase 2: broadcast downward one level per iteration. The step taken is
  # new_parent's step at index len(work_path) - 1, i.e. the step just below
  # work_path's current parent (presumably a prefix of new_parent).
  while work_path.get_parent() != new_parent:
    next_step = new_parent.field_list[len(work_path) - 1]
    work_expr, work_path = broadcast.broadcast_anonymous(
        work_expr, work_path, next_step)
  return work_expr, work_path
def map_values(root: expression.Expression, source_path: path.Path,
               operation: Callable[[tf.Tensor], tf.Tensor], dtype: tf.DType,
               new_field_name: path.Step) -> expression.Expression:
  """Maps one field into a new sibling field.

  The output of `operation` must have the same shape as its input.

  Args:
    root: the original root expression.
    source_path: path of the source field; must not be the root.
    operation: function from the source field values to the new field values.
    dtype: dtype of the new field.
    new_field_name: name of the new sibling field.

  Returns:
    The new root expression.

  Raises:
    ValueError: if source_path is the root.
  """
  if not source_path:
    raise ValueError('Cannot map the root.')
  parent_path = source_path.get_parent()
  source_field = source_path.field_list[-1]
  result = map_many_values(root, parent_path, [source_field], operation, dtype,
                           new_field_name)
  return result[0]
def get_positional_index(
    expr: expression.Expression, source_path: path.Path,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Computes each element's position within its parent list.

  For a field with parent_index [0,1,1,2,3,4,4], the new field has
  parent_index [0,1,1,2,3,4,4] and values [0,0,1,0,0,0,1].

  Args:
    expr: the original expression.
    source_path: path of the field whose index is computed.
    new_field_name: name of the new sibling field.

  Returns:
    A pair of (new expression, path of the new field).
  """
  index_path = source_path.get_parent().get_child(new_field_name)
  index_expr = _PositionalIndexExpression(
      expr.get_descendant_or_error(source_path))
  return expression_add.add_paths(expr, {index_path: index_expr}), index_path
def filter_by_child(expr: expression.Expression, p: path.Path,
                    child_field_name: path.Step,
                    new_field_name: path.Step) -> expression.Expression:
  """Filters the field at `p` by one of its optional boolean children.

  A parent is kept when the child field is present and True, and dropped
  otherwise.

  Args:
    expr: the original expression.
    p: path of the field to filter.
    child_field_name: name of the boolean child used as the filter.
    new_field_name: name of the new, filtered sibling of `p`.

  Returns:
    The new root expression.
  """
  parent_expr = expr.get_descendant_or_error(p)
  flag_expr = parent_expr.get_child_or_error(child_field_name)
  filtered = _FilterByChildExpression(parent_expr, flag_expr)
  target_path = p.get_parent().get_child(new_field_name)
  return expression_add.add_paths(expr, {target_path: filtered})
def map_prensor_to_prensor(
    root_expr: expression.Expression, source: path.Path,
    paths_needed: Sequence[path.Path],
    prensor_op: Callable[[prensor.Prensor], prensor.Prensor],
    output_schema: Schema) -> expression.Expression:
  r"""Applies a prensor op to a subtree and merges its output into the tree.

  For example, suppose you have an op my_op that takes a prensor of the form:

      event
       /  \
     foo   bar

  and produces a prensor of the form my_output_schema:

      event
       /  \
     foo2  bar2

  If you give it an expression `original` with the schema:

     session
        |
      event
      /  \
    foo   bar

  result = map_prensor_to_prensor(
      original,
      path.Path(["session", "event"]),
      [path.Path(["foo"]), path.Path(["bar"])],
      my_op,
      my_output_schema)

  Result will have the schema:

     session
        |
      event--------
      /  \    \    \
    foo   bar foo2 bar2

  Args:
    root_expr: the root expression.
    source: path of the subtree the prensor op is applied to.
    paths_needed: paths (relative to `source`) that the op reads.
    prensor_op: the prensor op.
    output_schema: the output schema of the op.

  Returns:
    A new expression in which every top-level field of the op's output is a
    child of `source`.
  """
  source_expr = root_expr.get_descendant_or_error(source)
  projected = source_expr.project(paths_needed)
  op_expr = _PrensorOpExpression(projected, prensor_op, output_schema)
  new_children = {}
  for field_name in op_expr.known_field_names():
    new_children[source.get_child(field_name)] = op_expr.get_child_or_error(
        field_name)
  return expression_add.add_paths(root_expr, new_children)