def _get_mask(
    t: expression.Expression, p: path.Path, threshold: IndexValue,
    relation: Callable[[tf.Tensor, IndexValue], tf.Tensor]
) -> Tuple[expression.Expression, path.Path]:
  """Gets a mask based on a relation of the index to a threshold.

  If the threshold is non-negative, the mask is true where
  relation(index, threshold) is true, where index is the positional index of
  each element within its parent list (0, 1, 2, ...).

  If the threshold is negative, the mask is true where
  relation(index_from_end, threshold) is true, where index_from_end counts
  from the end of the parent list (..., -2, -1), i.e. index - size.

  When `threshold` is a Python int its sign is known while building the
  expression, so only the needed mask is constructed. Otherwise both masks are
  built and one is selected at run time with `tf.cond`.

  Args:
    t: expression to add the field to.
    p: path to create the mask for.
    threshold: the cutoff threshold. Either a Python int or a value usable as
      a `tf.cond` predicate operand (presumably a scalar integer tensor --
      TODO confirm against IndexValue).
    relation: tf.less or tf.greater_equal.

  Returns:
    A pair of the new expression and the path of an anonymous boolean field,
    a child of p's parent, holding the mask for the elements of p.

  Raises:
    ValueError: if p is not in t.
  """
  if t.get_descendant(p) is None:
    raise ValueError("Path not found: {}".format(str(p)))

  def compare_to_threshold(x):
    return relation(x, tf.cast(threshold, tf.int64))

  def add_non_negative_mask(expr):
    # Mask derived from the positional index (0, 1, 2, ...).
    expr, positional_index = index.get_positional_index(
        expr, p, path.get_anonymous_field())
    return map_values.map_values_anonymous(expr, positional_index,
                                           compare_to_threshold, tf.bool)

  def add_negative_mask(expr):
    # Mask derived from the index counted from the end (..., -2, -1).
    expr, index_from_end = index.get_index_from_end(
        expr, p, path.get_anonymous_field())
    return map_values.map_values_anonymous(expr, index_from_end,
                                           compare_to_threshold, tf.bool)

  if isinstance(threshold, int):
    # Sign is static: build only the mask that will actually be used.
    if threshold >= 0:
      return add_non_negative_mask(t)
    return add_negative_mask(t)

  # Sign is only known at run time: build both masks and select one.
  work_expr, mask_for_negative_threshold = add_negative_mask(t)
  work_expr, mask_for_non_negative_threshold = add_non_negative_mask(work_expr)

  def tf_cond_on_threshold(non_negative_mask, negative_mask):
    # tf.cond requires callables for its branches; passing the tensors
    # themselves raises a TypeError.
    return tf.cond(
        tf.greater_equal(threshold, 0), lambda: non_negative_mask,
        lambda: negative_mask)

  return map_values.map_many_values(
      work_expr, p.get_parent(), [
          x.field_list[-1]
          for x in [mask_for_non_negative_threshold, mask_for_negative_threshold]
      ], tf_cond_on_threshold, tf.bool, path.get_anonymous_field())
def get_index_from_end(t: expression.Expression, source_path: path.Path,
                       new_field_name: path.Step
                      ) -> Tuple[expression.Expression, path.Path]:
  """Computes each element's offset from the end of its parent list.

  For a list ["a", "b", "c"], whose positional indices are [0, 1, 2], the
  resulting values are [-3, -2, -1].

  Args:
    t: original expression.
    source_path: path in the expression whose index is computed.
    new_field_name: name of the new field, created as a sibling of
      source_path.

  Returns:
    A pair of the new expression and the path of the new field.
  """
  # Intermediate anonymous fields (positional index, size) are built on a
  # scratch expression derived from t.
  scratch, index_path = get_positional_index(t, source_path,
                                             path.get_anonymous_field())
  scratch, size_path = size.size_anonymous(scratch, source_path)
  from_end = _PositionalIndexFromEndExpression(
      scratch.get_descendant_or_error(index_path),
      scratch.get_descendant_or_error(size_path))
  target = source_path.get_parent().get_child(new_field_name)
  scratch = expression_add.add_paths(scratch, {target: from_end})
  # Graft only the new field onto the original expression, so the anonymous
  # intermediate nodes are not carried into the result.
  return expression_add.add_to(t, {target: scratch}), target
def test_get_positional_index_calculate(self):
  """Positional index of user.friends restarts at 0 under each parent."""
  root = create_expression.create_expression_from_prensor(
      prensor_test_util.create_nested_prensor())
  friends = path.Path(["user", "friends"])
  indexed_root, index_path = index.get_positional_index(
      root, friends, path.get_anonymous_field())
  result = expression_test_util.calculate_value_slowly(
      indexed_root.get_descendant_or_error(index_path))
  self.assertAllEqual(result.parent_index, [0, 1, 1, 2, 3])
  self.assertAllEqual(result.values, [0, 0, 1, 0, 0])
def size_anonymous(root, source_path):
  """Stores the size of `source_path` in a new anonymous sibling field.

  Args:
    root: the original expression.
    source_path: the path whose size is measured. Cannot be root.

  Returns:
    A pair of the new expression and the path of the new field.
  """
  anonymous_field = path.get_anonymous_field()
  return _size_impl(root, source_path, anonymous_field)
def _get_slice_mask(
    expr: expression.Expression, p: path.Path, begin: Optional[IndexValue],
    end: Optional[IndexValue]) -> Tuple[expression.Expression, path.Path]:
  """Gets a mask for slicing a path.

  One way to consider the elements of a path "foo.bar" is as a list of lists
  of lists of elements. Slicing a path slices the innermost lists, based upon
  each element's position within its parent list. Each parent list has its
  own size, and the beginning and end of the slice are interpreted relative
  to that list.

  The range is specified with a beginning and an end:
  1. If begin is not present, begin_index is implied to be zero.
  2. If begin is negative, begin_index is the size of the list + begin.
  3. If begin is non-negative, begin_index is begin.
  4. If end is not present, end_index is the size of the list, i.e. the slice
     runs to the end of the list.
  5. If end is negative, end_index is the size of the list + end.
  6. If end is non-negative, end_index is end.

  The mask is True for all elements in range(begin_index, end_index) and
  False elsewhere.

  The returned mask is a sibling of path p: for every element in p there is a
  corresponding element in the mask.

  Args:
    expr: the root expression.
    p: the path to be sliced.
    begin: the beginning index, or None.
    end: the ending index, or None.

  Returns:
    An (expression, path) pair. The expression contains all the children of
    `expr` plus an anonymous field holding the mask, and the path points to
    that mask field.

  Raises:
    ValueError: if both begin and end are None.
  """
  if begin is None and end is None:
    raise ValueError("Must specify begin or end.")
  if begin is None:
    return _get_end_mask(expr, p, end)
  if end is None:
    return _get_begin_mask(expr, p, begin)

  work_expr, begin_mask = _get_begin_mask(expr, p, begin)
  work_expr, end_mask = _get_end_mask(work_expr, p, end)
  # Both masks live under p's parent, so they can be combined element-wise
  # there.
  return map_values.map_many_values(
      work_expr, p.get_parent(),
      [begin_mask.field_list[-1], end_mask.field_list[-1]], tf.logical_and,
      tf.bool, path.get_anonymous_field())
def test_get_index_from_end(self):
  """Checks the static properties and value dtype of the index-from-end field."""
  root = create_expression.create_expression_from_prensor(
      prensor_test_util.create_nested_prensor())
  result_root, result_path = index.get_index_from_end(
      root, path.Path(["user", "friends"]), path.get_anonymous_field())
  field = result_root.get_descendant_or_error(result_path)
  # Static shape/type of the new field.
  self.assertTrue(field.is_repeated)
  self.assertEqual(field.type, tf.int64)
  self.assertTrue(field.is_leaf)
  # calculation_equal is reflexive and distinguishes the field from the root.
  self.assertTrue(field.calculation_equal(field))
  self.assertFalse(field.calculation_equal(root))
  # Computed values carry the declared dtype.
  node = expression_test_util.calculate_value_slowly(field)
  self.assertEqual(node.values.dtype, tf.int64)
  self.assertEqual(field.known_field_names(), frozenset())
def test_get_index_from_end_calculate(self):
  """Checks index-from-end values for user.friends in the nested prensor.

  With parent_index [0, 1, 1, 2, 3], parents 0, 2 and 3 each have a single
  friend (value -1), and parent 1 has two friends (values -2, -1).
  """
  expr = create_expression.create_expression_from_prensor(
      prensor_test_util.create_nested_prensor())
  new_root, new_path = index.get_index_from_end(
      expr, path.Path(["user", "friends"]), path.get_anonymous_field())
  new_field = new_root.get_descendant_or_error(new_path)
  leaf_node = expression_test_util.calculate_value_slowly(new_field)
  self.assertAllEqual(leaf_node.parent_index, [0, 1, 1, 2, 3])
  self.assertAllEqual(leaf_node.values, [-1, -2, -1, -1, -1])
def map_values_anonymous(root, source_path, operation, dtype):
  """Applies `operation` to a field, storing the result in a new anonymous sibling.

  The output of `operation` must have the same shape as its input.

  Args:
    root: original root.
    source_path: path of the field fed to `operation`.
    operation: function from the source field's values to the new field's.
    dtype: type of the new field.

  Returns:
    A pair of the new expression and the path of the new field.

  Raises:
    ValueError: if source_path is empty (i.e. the root).
  """
  if not source_path:
    raise ValueError('Cannot map the root.')
  parent = source_path.get_parent()
  source_step = source_path.field_list[-1]
  return map_many_values(root, parent, [source_step], operation, dtype,
                         path.get_anonymous_field())
def promote_anonymous(root: expression.Expression,
                      p: path.Path) -> Tuple[expression.Expression, path.Path]:
  """Promotes the field at `p` to a new anonymous child of its grandparent.

  Args:
    root: the original expression.
    p: path of the field to promote.

  Returns:
    The (expression, path) pair produced by `_promote_impl`.
  """
  new_name = path.get_anonymous_field()
  return _promote_impl(root, p, new_name)
def broadcast_anonymous(root, origin, sibling):
  """Broadcasts the field at `origin` relative to `sibling`, under an anonymous name.

  Thin wrapper that delegates to `_broadcast_impl` with a freshly generated
  anonymous field name; see that function for the exact semantics.

  Args:
    root: the original expression.
    origin: path of the field to broadcast.
    sibling: the step the field is broadcast relative to.

  Returns:
    Whatever `_broadcast_impl` returns (presumably an (expression, path)
    pair -- TODO confirm).
  """
  new_name = path.get_anonymous_field()
  return _broadcast_impl(root, origin, sibling, new_name)
def broadcast_anonymous(
    root: expression.Expression, origin: path.Path,
    sibling: path.Step) -> Tuple[expression.Expression, path.Path]:
  """Broadcasts the field at `origin` relative to `sibling`, under an anonymous name.

  Delegates to `_broadcast_impl` with a freshly generated anonymous field
  name; see that function for the exact semantics.

  Args:
    root: the original expression.
    origin: path of the field to broadcast.
    sibling: the step the field is broadcast relative to.

  Returns:
    The (expression, path) pair produced by `_broadcast_impl`.
  """
  new_name = path.get_anonymous_field()
  return _broadcast_impl(root, origin, sibling, new_name)
def promote_anonymous(root, p):
  """Promotes the field at `p` to a new anonymous child of its grandparent.

  Args:
    root: the original expression.
    p: path of the field to promote.

  Returns:
    Whatever `_promote_impl` returns for the new anonymous field.
  """
  new_name = path.get_anonymous_field()
  return _promote_impl(root, p, new_name)