示例#1
0
def _get_mask(
    t: expression.Expression, p: path.Path, threshold: IndexValue,
    relation: Callable[[tf.Tensor, IndexValue], tf.Tensor]
) -> Tuple[expression.Expression, path.Path]:
    """Gets a mask based on a relation of the index to a threshold.

    Two candidate masks are built as anonymous siblings of `p`:

    * For a non-negative threshold, the mask is true wherever
      relation(positional_index, threshold) is true.
    * For a negative threshold, the mask is true wherever
      relation(index_from_end, threshold) is true, where index_from_end is
      the value produced by index.get_index_from_end (a negative offset from
      the end of the list, e.g. -1 for the last element).

    If `threshold` is a Python int, the matching mask is selected directly.
    Otherwise (presumably a scalar tensor; IndexValue is declared elsewhere)
    a third anonymous field is created that picks between the two masks with
    tf.cond on `threshold >= 0`.

    Args:
      t: expression to add the field to.
      p: path to create the mask for.
      threshold: the cutoff threshold.
      relation: tf.less or tf.greater_equal.

    Returns:
      A pair of the new expression and the path of the boolean mask field.

    Raises:
      ValueError: if p is not in t.
    """
    if t.get_descendant(p) is None:
        raise ValueError("Path not found: {}".format(str(p)))

    def compare_to_threshold(x):
        # The threshold is cast to int64 before being handed to `relation`.
        return relation(x, tf.cast(threshold, tf.int64))

    work_expr, index_from_end = index.get_index_from_end(
        t, p, path.get_anonymous_field())
    work_expr, mask_for_negative_threshold = map_values.map_values_anonymous(
        work_expr, index_from_end, compare_to_threshold, tf.bool)

    work_expr, positional_index = index.get_positional_index(
        work_expr, p, path.get_anonymous_field())
    work_expr, mask_for_non_negative_threshold = (
        map_values.map_values_anonymous(
            work_expr, positional_index, compare_to_threshold, tf.bool))

    if isinstance(threshold, int):
        # The sign is known statically, so no runtime conditional is needed.
        if threshold >= 0:
            return work_expr, mask_for_non_negative_threshold
        return work_expr, mask_for_negative_threshold

    def tf_cond_on_threshold(a, b):
        # tf.cond requires callables for both branches; passing the tensors
        # themselves raises a TypeError, so wrap each one in a lambda.
        return tf.cond(
            tf.greater_equal(threshold, 0), lambda: a, lambda: b)

    return map_values.map_many_values(work_expr, p.get_parent(), [
        x.field_list[-1] for x in
        [mask_for_non_negative_threshold, mask_for_negative_threshold]
    ], tf_cond_on_threshold, tf.bool, path.get_anonymous_field())
示例#2
0
def get_index_from_end(
        t: expression.Expression, source_path: path.Path,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Computes, for each element, its offset counted back from the list end.

    For an array ["a", "b", "c"] whose positional indices are [0, 1, 2], the
    resulting field holds [-3, -2, -1].

    The new field is created as a sibling of `source_path` (a child of its
    parent) named `new_field_name`.

    Args:
      t: original expression
      source_path: path in expression to get index of.
      new_field_name: the name of the new field.

    Returns:
      The new expression and the new path as a pair.
    """
    target_path = source_path.get_parent().get_child(new_field_name)

    # Build the two inputs (positional index and size) as anonymous fields.
    scratch, index_path = get_positional_index(
        t, source_path, path.get_anonymous_field())
    scratch, length_path = size.size_anonymous(scratch, source_path)

    index_expr = scratch.get_descendant_or_error(index_path)
    length_expr = scratch.get_descendant_or_error(length_path)
    from_end_expr = _PositionalIndexFromEndExpression(index_expr, length_expr)
    scratch = expression_add.add_paths(scratch, {target_path: from_end_expr})

    # Graft only the target path back onto the original expression, which
    # removes the intermediate anonymous nodes.
    return expression_add.add_to(t, {target_path: scratch}), target_path
示例#3
0
 def test_get_positional_index_calculate(self):
   """Checks the computed positional index values for user.friends."""
   nested_prensor = prensor_test_util.create_nested_prensor()
   root_expr = create_expression.create_expression_from_prensor(
       nested_prensor)
   friends_path = path.Path(["user", "friends"])
   result_root, index_path = index.get_positional_index(
       root_expr, friends_path, path.get_anonymous_field())
   index_field = result_root.get_descendant_or_error(index_path)
   computed = expression_test_util.calculate_value_slowly(index_field)
   # Each value is the element's position within its parent's list.
   self.assertAllEqual(computed.parent_index, [0, 1, 1, 2, 3])
   self.assertAllEqual(computed.values, [0, 0, 1, 0, 0])
示例#4
0
def size_anonymous(root, source_path):
    """Measures the size of a field and stores it under an anonymous name.

    Delegates to `_size_impl`, supplying a freshly generated anonymous field
    name for the result.

    Args:
      root: the original expression.
      source_path: the source path to measure. Cannot be root.

    Returns:
      The new expression and the new field as a pair.
    """
    anonymous_step = path.get_anonymous_field()
    return _size_impl(root, source_path, anonymous_step)
示例#5
0
def _get_slice_mask(
        expr: expression.Expression, p: path.Path, begin: Optional[IndexValue],
        end: Optional[IndexValue]) -> Tuple[expression.Expression, path.Path]:
    """Builds a mask selecting the slice [begin, end) of a path.

    The elements of a path such as "foo.bar" can be viewed as nested lists,
    where each element has a position within its parent's list. Slicing
    keeps the elements whose position falls in a range defined relative to
    each parent list:

    1. If begin is not given, only the end mask is used (the range starts at
       the first element).
    2. If begin is negative, begin_index is the size of the particular
       list + begin.
    3. If end is not given, only the begin mask is used (no upper bound).
    4. If end is negative, end_index is the length of the list + end.
    5. If end is non-negative, end_index is end.

    When both bounds are given, the begin mask and the end mask are combined
    element-wise with tf.logical_and into a new anonymous boolean field.

    The mask returned is a sibling of path p, with one mask element for every
    element in p.

    Args:
      expr: the root expression
      p: the path to be sliced
      begin: the beginning index
      end: the ending index

    Returns:
      An expression,path pair, where the expression contains all the children
      in `expr` and an anonymous field of the mask and the path points to
      the mask field.

    Raises:
      ValueError: if both begin and end are None.
    """
    if begin is None and end is None:
        raise ValueError("Must specify begin or end.")
    if begin is None:
        return _get_end_mask(expr, p, end)
    if end is None:
        return _get_begin_mask(expr, p, begin)

    # Both bounds present: build each mask, then AND them together.
    with_begin, begin_mask_path = _get_begin_mask(expr, p, begin)
    with_both, end_mask_path = _get_end_mask(with_begin, p, end)
    mask_steps = [
        begin_mask_path.field_list[-1],
        end_mask_path.field_list[-1],
    ]
    return map_values.map_many_values(
        with_both, p.get_parent(), mask_steps, tf.logical_and, tf.bool,
        path.get_anonymous_field())
示例#6
0
 def test_get_index_from_end(self):
   """Checks the static properties of the index-from-end field."""
   nested_prensor = prensor_test_util.create_nested_prensor()
   root_expr = create_expression.create_expression_from_prensor(
       nested_prensor)
   friends_path = path.Path(["user", "friends"])
   result_root, from_end_path = index.get_index_from_end(
       root_expr, friends_path, path.get_anonymous_field())
   from_end_field = result_root.get_descendant_or_error(from_end_path)

   # Schema-level properties of the new field.
   self.assertTrue(from_end_field.is_repeated)
   self.assertEqual(from_end_field.type, tf.int64)
   self.assertTrue(from_end_field.is_leaf)

   # The field is calculation-equal to itself but not to the source root.
   self.assertTrue(from_end_field.calculation_equal(from_end_field))
   self.assertFalse(from_end_field.calculation_equal(root_expr))

   computed = expression_test_util.calculate_value_slowly(from_end_field)
   self.assertEqual(computed.values.dtype, tf.int64)
   self.assertEqual(from_end_field.known_field_names(), frozenset())
示例#7
0
  def test_get_index_from_end_calculate(self):
    """Checks the computed index-from-end values for user.friends."""
    nested_prensor = prensor_test_util.create_nested_prensor()
    root_expr = create_expression.create_expression_from_prensor(
        nested_prensor)
    friends_path = path.Path(["user", "friends"])
    result_root, from_end_path = index.get_index_from_end(
        root_expr, friends_path, path.get_anonymous_field())
    print("test_get_index_from_end_calculate: new_path: {}".format(
        from_end_path))

    from_end_field = result_root.get_descendant_or_error(from_end_path)
    print("test_get_index_from_end_calculate: new_field: {}".format(
        str(from_end_field)))

    computed = expression_test_util.calculate_value_slowly(from_end_field)
    print("test_get_index_from_end_calculate: leaf_node: {}".format(
        str(computed)))

    # Values count back from the end of each parent's list (-1 is last).
    self.assertAllEqual(computed.parent_index, [0, 1, 1, 2, 3])
    self.assertAllEqual(computed.values, [-1, -2, -1, -1, -1])
示例#8
0
def map_values_anonymous(root, source_path, operation, dtype):
    """Applies an operation to a field, storing the result as a new sibling.

    The result is placed under a freshly generated anonymous field name next
    to the source field. The shape of the output must be the same as the
    input.

    Args:
      root: original root.
      source_path: source of the operation.
      operation: operation from source_fields to new field.
      dtype: type of new field.

    Returns:
      The new expression and the new path as a pair.

    Raises:
      ValueError: if source_path is empty (i.e. refers to the root).
    """
    if source_path:
        parent_path = source_path.get_parent()
        source_step = source_path.field_list[-1]
        return map_many_values(root, parent_path, [source_step], operation,
                               dtype, path.get_anonymous_field())
    raise ValueError('Cannot map the root.')
示例#9
0
def promote_anonymous(root: expression.Expression,
                      p: path.Path) -> Tuple[expression.Expression, path.Path]:
    """Promotes `p` into a new anonymous child of its grandparent.

    Delegates to `_promote_impl`, supplying a freshly generated anonymous
    field name for the promoted field.

    Args:
      root: the original expression.
      p: the path to promote.

    Returns:
      The result of `_promote_impl`; per the annotation, an
      (expression, path) pair.
    """
    anonymous_step = path.get_anonymous_field()
    return _promote_impl(root, p, anonymous_step)
示例#10
0
def broadcast_anonymous(root, origin, sibling):
    """Broadcasts `origin` with respect to `sibling` under an anonymous name.

    Delegates to `_broadcast_impl`, supplying a freshly generated anonymous
    field name for the new field.

    Args:
      root: the original expression.
      origin: the path being broadcast.
      sibling: the sibling passed through to `_broadcast_impl`.

    Returns:
      Whatever `_broadcast_impl` returns (presumably the new expression and
      the new path as a pair -- confirm against `_broadcast_impl`).
    """
    anonymous_step = path.get_anonymous_field()
    return _broadcast_impl(root, origin, sibling, anonymous_step)
示例#11
0
def broadcast_anonymous(
        root: expression.Expression, origin: path.Path,
        sibling: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Broadcasts `origin` with respect to `sibling` under an anonymous name.

    Delegates to `_broadcast_impl`, supplying a freshly generated anonymous
    field name for the new field.

    Args:
      root: the original expression.
      origin: the path being broadcast.
      sibling: the sibling step passed through to `_broadcast_impl`.

    Returns:
      The result of `_broadcast_impl`; per the annotation, an
      (expression, path) pair.
    """
    anonymous_step = path.get_anonymous_field()
    return _broadcast_impl(root, origin, sibling, anonymous_step)
示例#12
0
def promote_anonymous(root, p):
    """Promotes `p` into a new anonymous child of its grandparent.

    Delegates to `_promote_impl`, supplying a freshly generated anonymous
    field name for the promoted field.

    Args:
      root: the original expression.
      p: the path to promote.

    Returns:
      Whatever `_promote_impl` returns (presumably the new expression and
      the new path as a pair -- confirm against `_promote_impl`).
    """
    anonymous_step = path.get_anonymous_field()
    return _promote_impl(root, p, anonymous_step)