def _broadcast_impl(
    root: expression.Expression, origin: path.Path, sibling: path.Step,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Adds a field under `sibling` that broadcasts the node found at `origin`.

  The new field lives at `<parent of origin>.<sibling>.<new_field_name>`. The
  node at `origin` is always wrapped in a `_BroadcastExpression`; this variant
  does no leaf/non-leaf dispatch (presumably only meant for leaf origins --
  confirm against callers).

  Args:
    root: the root expression.
    origin: path of the field to broadcast.
    sibling: name of the sibling of `origin` to broadcast onto.
    new_field_name: name of the new child created under the sibling.

  Returns:
    A pair (new_root, new_path) where new_path is the path of the new field.
  """
  target_path = origin.get_parent().get_child(sibling)
  origin_node = root.get_descendant_or_error(origin)
  target_node = root.get_descendant_or_error(target_path)
  broadcast_node = _BroadcastExpression(origin_node, target_node)
  broadcast_path = target_path.get_child(new_field_name)
  expanded_root = expression_add.add_paths(root,
                                           {broadcast_path: broadcast_node})
  return expanded_root, broadcast_path
def get_index_from_end(
    t: expression.Expression, source_path: path.Path,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Gets, for each element, its offset from the end of its parent's list.

  Given an array ["a", "b", "c"], with indices [0, 1, 2], the result of this
  is [-3, -2, -1].

  Args:
    t: original expression
    source_path: path in expression to get index of.
    new_field_name: the name of the new field.

  Returns:
    The new expression and the new path as a pair.
  """
  result_path = source_path.get_parent().get_child(new_field_name)

  # Build the two inputs (positional index and per-parent size) as anonymous
  # scratch fields on a working copy of the tree.
  scratch, positional_path = get_positional_index(t, source_path,
                                                  path.get_anonymous_field())
  scratch, count_path = size.size_anonymous(scratch, source_path)

  from_end_node = _PositionalIndexFromEndExpression(
      scratch.get_descendant_or_error(positional_path),
      scratch.get_descendant_or_error(count_path))
  scratch = expression_add.add_paths(scratch, {result_path: from_end_node})

  # Graft only the result back onto the original tree so the anonymous
  # scratch fields are not carried along.
  return expression_add.add_to(t, {result_path: scratch}), result_path
def _promote_impl(
    root: expression.Expression, p: path.Path,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Promotes a path to be a child of its grandparent, and gives it a name.

  The promoted field is added at `<grandparent of p>.<new_field_name>`. If the
  node at `p` is a leaf it is wrapped in a PromoteExpression; otherwise it is
  wrapped in a PromoteChildExpression. Either wrapper is built from the node
  at `p` and the node at its parent.

  Args:
    root: The root expression.
    p: The path to promote. This can be the path to a leaf or child node. It
      must have at least two steps.
    new_field_name: The name of the promoted field.

  Returns:
    A pair (new_root, new_path): new_root is `root` with the promoted field
    added via expression_add.add_paths, and new_path is the path of the
    promoted field.

  Raises:
    ValueError: if p has fewer than two steps (promotion beyond the root).
  """
  if len(p) < 2:
    raise ValueError("Cannot do a promotion beyond the root: {}".format(
        str(p)))
  parent_path = p.get_parent()
  grandparent_path = parent_path.get_parent()
  p_expression = root.get_descendant_or_error(p)
  new_path = grandparent_path.get_child(new_field_name)
  # Leaves and non-leaves need different promote expression types.
  if p_expression.is_leaf:
    promote_expression_factory = PromoteExpression
  else:
    promote_expression_factory = PromoteChildExpression
  return expression_add.add_paths(
      root, {
          new_path:
              promote_expression_factory(
                  p_expression, root.get_descendant_or_error(parent_path))
      }), new_path
def slice_expression(expr: expression.Expression, p: path.Path,
                     new_field_name: path.Step, begin: Optional[IndexValue],
                     end: Optional[IndexValue]) -> expression.Expression:
  """Creates a new subtree holding a slice of the field at `p`.

  Slicing follows the semantics of Python's slice(); see the module-level
  comments for examples. The sliced subtree is added as a sibling of `p`
  named `new_field_name`.

  Args:
    expr: the original root expression
    p: the path to the source to be sliced.
    new_field_name: the name of the new subtree.
    begin: beginning index
    end: end index.

  Returns:
    A new root expression.
  """
  masked_root, mask_path = _get_slice_mask(expr, p, begin, end)
  mask_field = mask_path.field_list[-1]
  filtered_root = filter_expression.filter_by_sibling(masked_root, p,
                                                      mask_field,
                                                      new_field_name)
  sliced_path = p.get_parent().get_child(new_field_name)
  # Building the mask left many anonymous intermediate fields on
  # filtered_root; graft only the sliced subtree (and its children) onto the
  # original root.
  return expression_add.add_to(expr, {sliced_path: filtered_root})
def filter_by_sibling(expr: expression.Expression, p: path.Path,
                      sibling_field_name: path.Step,
                      new_field_name: path.Step) -> expression.Expression:
  """Filters the field at `p` using one of its siblings as a mask.

  This is similar to boolean_mask. The path being filtered and the sibling
  must have identical shape, i.e. each parent object must have an equal
  number of source and sibling children.

  Args:
    expr: the root expression.
    p: a path to the source to be filtered.
    sibling_field_name: the sibling to use as a mask.
    new_field_name: a new sibling to create.

  Returns:
    a new root.
  """
  source_node = expr.get_descendant_or_error(p)
  shared_parent = p.get_parent()
  mask_node = expr.get_descendant_or_error(
      shared_parent.get_child(sibling_field_name))
  filtered_node = _FilterBySiblingExpression(source_node, mask_node)
  filtered_path = shared_parent.get_child(new_field_name)
  return expression_add.add_paths(expr, {filtered_path: filtered_node})
def _broadcast_impl(
    root: expression.Expression, origin: path.Path, sibling: path.Step,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Broadcasts the node at `origin` onto its sibling `sibling`.

  A leaf origin is wrapped in `_BroadcastExpression`; a non-leaf origin is
  wrapped in `_BroadcastChildExpression`. The result is added at
  `<parent of origin>.<sibling>.<new_field_name>`.

  Args:
    root: the root expression.
    origin: path of the field to broadcast.
    sibling: name of the sibling of `origin` to broadcast onto.
    new_field_name: name of the new child created under the sibling.

  Returns:
    A pair (new_root, new_path) where new_path is the path of the new field.
  """
  target_path = origin.get_parent().get_child(sibling)
  origin_node = root.get_descendant_or_error(origin)
  if origin_node.is_leaf:
    make_broadcast = _BroadcastExpression
  else:
    make_broadcast = _BroadcastChildExpression
  target_node = root.get_descendant_or_error(target_path)
  broadcast_node = make_broadcast(origin_node, target_node)
  broadcast_path = target_path.get_child(new_field_name)
  expanded_root = expression_add.add_paths(root,
                                           {broadcast_path: broadcast_node})
  return expanded_root, broadcast_path
def _get_slice_mask(
    expr: expression.Expression, p: path.Path, begin: Optional[IndexValue],
    end: Optional[IndexValue]) -> Tuple[expression.Expression, path.Path]:
  """Gets a boolean mask for slicing a path.

  The elements of a path such as "foo.bar" can be viewed as a nested list of
  lists: each parent groups its children into a list, and every element has
  a position within its parent's list. Slicing a path applies the same
  [begin, end) range to each of those per-parent lists independently.

  For a parent list of size n, the range is resolved as follows:
    1. If begin is not present, begin_index is 0.
    2. If begin is negative, begin_index is n + begin.
    3. If begin is non-negative, begin_index is begin.
    4. If end is not present, there is no upper bound (end_index is n).
    5. If end is negative, end_index is n + end.
    6. If end is non-negative, end_index is end.

  The mask is True for elements whose position is in
  range(begin_index, end_index), and False elsewhere. For example, with
  begin=1 and end=-1, a parent list ["a", "b", "c", "d"] yields the mask
  [False, True, True, False].

  The mask returned is a sibling of path p: for every element in p there is
  a corresponding element in the mask.

  Args:
    expr: the root expression
    p: the path to be sliced
    begin: the beginning index
    end: the ending index

  Returns:
    An (expression, path) pair, where the expression contains all the
    children in `expr` plus an anonymous mask field (and anonymous
    intermediate fields), and the path points to the mask field.

  Raises:
    ValueError: if neither begin nor end is specified.
  """
  if begin is None and end is None:
    raise ValueError("Must specify begin or end.")
  if begin is None:
    return _get_end_mask(expr, p, end)
  if end is None:
    return _get_begin_mask(expr, p, begin)
  work_expr, begin_mask = _get_begin_mask(expr, p, begin)
  work_expr, end_mask = _get_end_mask(work_expr, p, end)
  # Keep an element only if it is both at/after begin AND before end.
  return map_values.map_many_values(
      work_expr, p.get_parent(),
      [mask.field_list[-1] for mask in [begin_mask, end_mask]],
      tf.logical_and, tf.bool, path.get_anonymous_field())
def _get_mask(
    t: expression.Expression, p: path.Path, threshold: IndexValue,
    relation: Callable[[tf.Tensor, IndexValue], tf.Tensor]
) -> Tuple[expression.Expression, path.Path]:
  """Gets a mask based on a relation of the index to a threshold.

  If the threshold is non-negative, the mask is true where
  relation(positional_index, threshold) is true. If the threshold is
  negative, the mask is true where relation(index_from_end, threshold) is
  true, where index_from_end = positional_index - size of the parent list
  (it runs from -size to -1).

  When `threshold` is a Python int its sign is known statically and only the
  needed mask is built. Otherwise (e.g. a tensor threshold) both masks are
  built and one is selected at run time.

  Args:
    t: expression to add the field to.
    p: path to create the mask for.
    threshold: the cutoff threshold.
    relation: tf.less or tf.greater_equal.

  Returns:
    An (expression, path) pair: the expression is `t` plus anonymous
    intermediate fields, and the path points to the boolean mask, which is a
    sibling of p.

  Raises:
    ValueError: if p is not in t.
  """
  if t.get_descendant(p) is None:
    raise ValueError("Path not found: {}".format(str(p)))

  def compare_to_threshold(x):
    return relation(x, tf.cast(threshold, tf.int64))

  def add_mask_for_non_negative_threshold(work_expr):
    # Mask computed from each element's position from the start of its list.
    work_expr, positional_index = index.get_positional_index(
        work_expr, p, path.get_anonymous_field())
    return map_values.map_values_anonymous(work_expr, positional_index,
                                           compare_to_threshold, tf.bool)

  def add_mask_for_negative_threshold(work_expr):
    # Mask computed from each element's (negative) offset from the end.
    work_expr, index_from_end = index.get_index_from_end(
        work_expr, p, path.get_anonymous_field())
    return map_values.map_values_anonymous(work_expr, index_from_end,
                                           compare_to_threshold, tf.bool)

  if isinstance(threshold, int):
    if threshold >= 0:
      return add_mask_for_non_negative_threshold(t)
    return add_mask_for_negative_threshold(t)

  # The sign of the threshold is only known at run time: build both masks
  # and choose between them per element group with tf.cond.
  work_expr, mask_for_negative_threshold = add_mask_for_negative_threshold(t)
  work_expr, mask_for_non_negative_threshold = (
      add_mask_for_non_negative_threshold(work_expr))

  def tf_cond_on_threshold(a, b):
    # tf.cond requires callables for its branches, not tensors.
    return tf.cond(tf.greater_equal(threshold, 0), lambda: a, lambda: b)

  return map_values.map_many_values(work_expr, p.get_parent(), [
      x.field_list[-1]
      for x in [mask_for_non_negative_threshold, mask_for_negative_threshold]
  ], tf_cond_on_threshold, tf.bool, path.get_anonymous_field())
def _promote_impl(root: expression.Expression, p: path.Path,
                  new_field_name: path.Step
                 ) -> Tuple[expression.Expression, path.Path]:
  """Promotes the field at `p` to be a child of its grandparent.

  The node at `p` and its parent are wrapped in a PromoteExpression (no
  leaf/non-leaf dispatch in this variant), which is added at
  `<grandparent of p>.<new_field_name>`.

  Args:
    root: the root expression.
    p: the path to promote; must have at least two steps.
    new_field_name: the name of the promoted field.

  Returns:
    A pair (new_root, new_path) where new_path is the promoted field's path.

  Raises:
    ValueError: if p has fewer than two steps.
  """
  if len(p) < 2:
    raise ValueError("Cannot do a promotion beyond the root: {}".format(str(p)))
  owner_path = p.get_parent()
  promoted_path = owner_path.get_parent().get_child(new_field_name)
  promoted_node = PromoteExpression(
      root.get_descendant_or_error(p),
      root.get_descendant_or_error(owner_path))
  expanded_root = expression_add.add_paths(root, {promoted_path: promoted_node})
  return expanded_root, promoted_path
def _size_impl(
    root: expression.Expression, source_path: path.Path,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Adds a sibling of `source_path` holding a SizeExpression of it.

  Args:
    root: the root expression.
    source_path: path of the field whose size is computed; not the root.
    new_field_name: name of the new sibling field.

  Returns:
    A pair (new_root, new_path) where new_path is the size field's path.

  Raises:
    ValueError: if source_path is the root or is not found in root.
  """
  if not source_path:
    raise ValueError("Cannot get the size of the root.")
  if root.get_descendant(source_path) is None:
    raise ValueError("Path not found: {}".format(str(source_path)))
  owner_path = source_path.get_parent()
  size_path = owner_path.get_child(new_field_name)
  size_node = SizeExpression(root.get_descendant_or_error(source_path),
                             root.get_descendant_or_error(owner_path))
  return expression_add.add_paths(root, {size_path: size_node}), size_path
def map_values(root: expression.Expression, source_path: path.Path,
               operation: Callable[[tf.Tensor], tf.Tensor], dtype: tf.DType,
               new_field_name: path.Step) -> expression.Expression:
  """Maps a single field into a new sibling field.

  The shape of the output must be the same as the input. This is a thin
  wrapper over map_many_values with a single source field.

  Args:
    root: original root.
    source_path: source of the operation.
    operation: operation from source_fields to new field.
    dtype: type of new field.
    new_field_name: name of the new field.

  Returns:
    The new expression.
  """
  if not source_path:
    raise ValueError('Cannot map the root.')
  mapped = map_many_values(root, source_path.get_parent(),
                           [source_path.field_list[-1]], operation, dtype,
                           new_field_name)
  # map_many_values returns (expression, path); only the expression is
  # exposed here.
  return mapped[0]
def get_positional_index(
    expr: expression.Expression, source_path: path.Path,
    new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Gets each element's position within its parent's list.

  Given a field with parent_index [0,1,1,2,3,4,4], this returns:
  parent_index [0,1,1,2,3,4,4] and value [0,0,1,0,0,0,1]

  Args:
    expr: original expression
    source_path: path in expression to get index of.
    new_field_name: the name of the new field.

  Returns:
    The new expression and the new path as a pair.
  """
  index_path = source_path.get_parent().get_child(new_field_name)
  index_node = _PositionalIndexExpression(
      expr.get_descendant_or_error(source_path))
  expanded = expression_add.add_paths(expr, {index_path: index_node})
  return expanded, index_path
def filter_by_child(expr: expression.Expression, p: path.Path,
                    child_field_name: path.Step,
                    new_field_name: path.Step) -> expression.Expression:
  """Filters the field at `p` by an optional boolean child field.

  A parent is kept if the child field is present and True; otherwise the
  parent is dropped. The filtered copy is added as a sibling of `p`.

  Args:
    expr: the original expression
    p: the path to filter.
    child_field_name: the boolean child field to use to filter.
    new_field_name: the new, filtered version of path.

  Returns:
    The new root expression.
  """
  parent_node = expr.get_descendant_or_error(p)
  flag_node = parent_node.get_child_or_error(child_field_name)
  filtered_node = _FilterByChildExpression(parent_node, flag_node)
  filtered_path = p.get_parent().get_child(new_field_name)
  return expression_add.add_paths(expr, {filtered_path: filtered_node})