コード例 #1
0
def map_many_values(
        root: expression.Expression, parent_path: path.Path,
        source_fields: Sequence[path.Step],
        operation: Callable[..., tf.Tensor], dtype: tf.DType,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Maps several sibling fields into a newly created sibling field.

    Every source field must share one shape, and the output of `operation`
    must have that same shape too.

    Args:
      root: the original root expression.
      parent_path: common parent of the source fields and of the new field.
      source_fields: names of the source fields (children of `parent_path`);
        they must all have the same shape.
      operation: function from the source fields to the new field.
      dtype: dtype of the new field.
      new_field_name: name of the new field, created under `parent_path`.

    Returns:
      A `(new_root, new_path)` pair, where `new_path` is
      `parent_path.get_child(new_field_name)`.
    """
    result_path = parent_path.get_child(new_field_name)
    inputs = [
        root.get_descendant_or_error(parent_path.get_child(field))
        for field in source_fields
    ]
    mapped = _MapValuesExpression(inputs, operation, dtype)
    return expression_add.add_paths(root, {result_path: mapped}), result_path
コード例 #2
0
def _broadcast_impl(
        root: expression.Expression, origin: path.Path, sibling: path.Step,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Broadcasts the field at `origin` onto its sibling field `sibling`.

    Args:
      root: the root expression.
      origin: path of the field to broadcast.
      sibling: name of a sibling of `origin` (another child of `origin`'s
        parent); the broadcast field is created under it.
      new_field_name: name of the new field, a child of the sibling.

    Returns:
      A `(new_root, new_path)` pair, where `new_path` is
      `origin.get_parent().get_child(sibling).get_child(new_field_name)`.
    """
    sibling_path = origin.get_parent().get_child(sibling)
    # NOTE(review): origin is always wrapped in _BroadcastExpression, whether
    # or not it is a leaf; confirm non-leaf origins are not expected here.
    new_expr = _BroadcastExpression(
        root.get_descendant_or_error(origin),
        # Reuse sibling_path rather than rebuilding the same path again.
        root.get_descendant_or_error(sibling_path))
    new_path = sibling_path.get_child(new_field_name)
    return expression_add.add_paths(root, {new_path: new_expr}), new_path
コード例 #3
0
ファイル: promote.py プロジェクト: NeoTim/struct2tensor
def _promote_impl(
        root: expression.Expression, p: path.Path,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Promotes the field at `p` to a child of its grandparent under a new name.

    Args:
      root: the root expression.
      p: the path to promote; it may point at a leaf or at a child node.
      new_field_name: the name given to the promoted field.

    Returns:
      A `(new_root, new_path)` pair. `new_path` is a child of the grandparent
      of `p`. It holds a PromoteExpression when `p` is a leaf and a
      PromoteChildExpression otherwise.

    Raises:
      ValueError: if `p` has fewer than two steps.
    """
    if len(p) < 2:
        raise ValueError("Cannot do a promotion beyond the root: {}".format(
            str(p)))
    parent_path = p.get_parent()
    grandparent = parent_path.get_parent()
    source = root.get_descendant_or_error(p)
    destination = grandparent.get_child(new_field_name)

    # Leaves and non-leaf children need different promote expressions.
    factory = PromoteExpression if source.is_leaf else PromoteChildExpression
    promoted = factory(source, root.get_descendant_or_error(parent_path))
    return expression_add.add_paths(root, {destination: promoted}), destination
コード例 #4
0
def slice_expression(expr: expression.Expression, p: path.Path,
                     new_field_name: path.Step, begin: Optional[IndexValue],
                     end: Optional[IndexValue]) -> expression.Expression:
    """Adds a sliced copy of the field at `p` as a new sibling of `p`.

    Slicing follows the pattern of Python's slice(), applied to the elements
    of `p` within each parent. See the module-level comments for examples.

    Args:
      expr: the original root expression.
      p: the path to the source to be sliced.
      new_field_name: the name of the new subtree.
      begin: beginning index.
      end: end index.

    Returns:
      A new root expression: `expr` plus the sliced field at
      `p.get_parent().get_child(new_field_name)`.
    """
    masked_expr, mask_path = _get_slice_mask(expr, p, begin, end)
    mask_field = mask_path.field_list[-1]
    filtered_expr = filter_expression.filter_by_sibling(
        masked_expr, p, mask_field, new_field_name)
    # Many anonymous fields and intermediate expressions were created above;
    # copy only the final result (and its children) into `expr`.
    return expression_add.add_to(
        expr, {p.get_parent().get_child(new_field_name): filtered_expr})
コード例 #5
0
def filter_by_sibling(expr: expression.Expression, p: path.Path,
                      sibling_field_name: path.Step,
                      new_field_name: path.Step) -> expression.Expression:
    """Filters the field at `p` using one of its siblings as the mask.

    This is similar to boolean_mask. The field being filtered and the mask
    sibling must have identical shapes: each parent must have as many source
    children as sibling children.

    Args:
      expr: the root expression.
      p: a path to the source to be filtered.
      sibling_field_name: the sibling to use as a mask.
      new_field_name: the name of the new, filtered sibling.

    Returns:
      A new root containing the filtered field at
      `p.get_parent().get_child(new_field_name)`.
    """
    source_expr = expr.get_descendant_or_error(p)
    parent_path = p.get_parent()
    mask_expr = expr.get_descendant_or_error(
        parent_path.get_child(sibling_field_name))
    filtered = _FilterBySiblingExpression(source_expr, mask_expr)
    return expression_add.add_paths(
        expr, {parent_path.get_child(new_field_name): filtered})
コード例 #6
0
def promote_and_broadcast(
        root: expression.Expression, path_dictionary: Mapping[path.Step,
                                                              path.Path],
        dest_path_parent: path.Path) -> expression.Expression:
    """Promotes and broadcasts a set of paths to a particular location.

    Args:
      root: the original expression.
      path_dictionary: a map from destination field names to origin paths.
      dest_path_parent: the path under which the destination fields are
        created.

    Returns:
      A new expression where every origin path has been promoted and
      broadcast until it is a child of `dest_path_parent`, under the field
      name given by its key in `path_dictionary`.
    """
    result_paths = {}
    # Branch out: each field is promoted and broadcast in its own tree built
    # from `root`.
    for field_name, origin_path in path_dictionary.items():
        result_path = dest_path_parent.get_child(field_name)
        result_paths[result_path] = _promote_and_broadcast_name(
            root, origin_path, dest_path_parent, field_name)
    # Merge the generated fields from those separate trees into one new tree.
    return expression_add.add_to(root, result_paths)
コード例 #7
0
def _get_leaf_node_path(p: path.Path, t: prensor.Prensor) -> _LeafNodePath:
    """Builds a _LeafNodePath from the root of `t` down to the leaf at `p`.

    Args:
      p: path of the leaf inside `t`; must not be the root path.
      t: the prensor to walk.

    Returns:
      A _LeafNodePath made of the root node, the ChildNodeTensor of every
      strict ancestor of `p` (shallowest first), and the leaf itself.

    Raises:
      ValueError: if the node at `p` is not a LeafNodeTensor, if `p` is the
        root, if every strict ancestor is not a ChildNodeTensor, or if
        _as_root_node_tensor rejects the root node (e.g. a leaf at the root).
    """
    leaf = t.get_descendant_or_error(p).node
    if not isinstance(leaf, prensor.LeafNodeTensor):
        raise ValueError("Expected Leaf Node at {} in {}".format(
            str(p), str(t)))
    if not p:
        raise ValueError("Leaf should not be at the root")
    # Raises ValueError if the root itself is a leaf.
    root_node = _as_root_node_tensor(t.node)

    # Visit each strict ancestor (neither the root nor p itself), collecting
    # good child nodes and the paths of any ancestors with the wrong type.
    child_nodes = []
    non_child_paths = []
    for depth in range(1, len(p)):
        ancestor = p.prefix(depth)
        node = t.get_descendant_or_error(ancestor).node
        if isinstance(node, prensor.ChildNodeTensor):
            child_nodes.append(node)
        else:
            non_child_paths.append(ancestor)
    if non_child_paths:
        raise ValueError("Expected ChildNodeTensor at {} in {}".format(
            " ".join(str(x) for x in non_child_paths), str(t)))
    return _LeafNodePath(root_node, child_nodes, leaf)
コード例 #8
0
ファイル: index.py プロジェクト: pk-organics/struct2tensor
def get_index_from_end(
        t: expression.Expression, source_path: path.Path,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Gets, for each element, its position counted from the end of its list.

    Given an array ["a", "b", "c"], with indices [0, 1, 2], the result of this
    is [-3, -2, -1].

    Args:
      t: original expression.
      source_path: path in expression to get index of.
      new_field_name: the name of the new field.

    Returns:
      The new expression and the new path as a pair.
    """
    result_path = source_path.get_parent().get_child(new_field_name)
    scratch, index_path = get_positional_index(
        t, source_path, path.get_anonymous_field())
    scratch, count_path = size.size_anonymous(scratch, source_path)
    from_end = _PositionalIndexFromEndExpression(
        scratch.get_descendant_or_error(index_path),
        scratch.get_descendant_or_error(count_path))
    scratch = expression_add.add_paths(scratch, {result_path: from_end})
    # Copy only the result into `t`, leaving the anonymous helper fields out.
    return expression_add.add_to(t, {result_path: scratch}), result_path
コード例 #9
0
def _promote_and_broadcast_name(
        root: expression.Expression, origin: path.Path,
        dest_path_parent: path.Path,
        field_name: path.Step) -> expression.Expression:
    """Moves the field at `origin` under `dest_path_parent` as `field_name`.

    The field is first promoted/broadcast to an anonymous child of
    `dest_path_parent`, and that child is then added again under
    `field_name`.

    Returns:
      The promoted/broadcast root with
      `dest_path_parent.get_child(field_name)` added.
    """
    moved_root, anon_path = promote_and_broadcast_anonymous(
        root, origin, dest_path_parent)
    named_path = dest_path_parent.get_child(field_name)
    anon_expr = moved_root.get_descendant_or_error(anon_path)
    return expression_add.add_paths(moved_root, {named_path: anon_expr})
コード例 #10
0
ファイル: broadcast.py プロジェクト: google/struct2tensor
def _broadcast_impl(
        root: expression.Expression, origin: path.Path, sibling: path.Step,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Broadcasts the field at `origin` onto its sibling field `sibling`.

    Args:
      root: the root expression.
      origin: path of the field to broadcast (a leaf or a child node).
      sibling: name of a sibling of `origin` (another child of `origin`'s
        parent); the broadcast field is created under it.
      new_field_name: name of the new field, a child of the sibling.

    Returns:
      A `(new_root, new_path)` pair, where `new_path` is
      `origin.get_parent().get_child(sibling).get_child(new_field_name)`.
    """
    sibling_path = origin.get_parent().get_child(sibling)

    origin_expression = root.get_descendant_or_error(origin)

    # Leaves and non-leaf children need different broadcast expressions.
    broadcast_expression_factory = (_BroadcastExpression
                                    if origin_expression.is_leaf else
                                    _BroadcastChildExpression)

    new_expr = broadcast_expression_factory(
        origin_expression,
        # Reuse sibling_path rather than rebuilding the same path again.
        root.get_descendant_or_error(sibling_path))
    new_path = sibling_path.get_child(new_field_name)
    result = expression_add.add_paths(root, {new_path: new_expr})

    return result, new_path
コード例 #11
0
def _get_slice_mask(
        expr: expression.Expression, p: path.Path, begin: Optional[IndexValue],
        end: Optional[IndexValue]) -> Tuple[expression.Expression, path.Path]:
    """Gets a boolean mask for slicing the field at `p`.

    The elements of a path such as "foo.bar" can be viewed as a nested list
    of elements. Slicing keeps elements based on their position (index)
    within their parent's list. Each parent list has its own size, and begin
    and end are interpreted relative to that list, following Python's slice
    semantics:

    1. If begin is not present, begin_index is zero.
    2. If begin is non-negative, begin_index is begin.
    3. If begin is negative, begin_index is the size of the list + begin.
    4. If end is not present, end_index is the size of the list.
    5. If end is non-negative, end_index is end.
    6. If end is negative, end_index is the size of the list + end.

    The mask is True for every element whose index is in
    range(begin_index, end_index), and False elsewhere.

    The mask returned is a sibling of path p: for every element in p there is
    a corresponding element in the mask.

    Args:
      expr: the root expression.
      p: the path to be sliced.
      begin: the beginning index, or None.
      end: the ending index, or None.

    Returns:
      An (expression, path) pair. The expression contains all the children
      of `expr` plus an anonymous mask field, and the path points to that
      mask field.

    Raises:
      ValueError: if both begin and end are None.
    """
    if begin is None and end is None:
        raise ValueError("Must specify begin or end.")
    if begin is None:
        return _get_end_mask(expr, p, end)
    if end is None:
        return _get_begin_mask(expr, p, begin)
    work_expr, begin_mask = _get_begin_mask(expr, p, begin)
    work_expr, end_mask = _get_end_mask(work_expr, p, end)
    # An element is kept only if it passes both the begin and the end mask.
    return map_values.map_many_values(
        work_expr, p.get_parent(),
        [begin_mask.field_list[-1], end_mask.field_list[-1]],
        tf.logical_and, tf.bool, path.get_anonymous_field())
コード例 #12
0
def _get_mask(
    t: expression.Expression, p: path.Path, threshold: IndexValue,
    relation: Callable[[tf.Tensor, IndexValue], tf.Tensor]
) -> Tuple[expression.Expression, path.Path]:
    """Gets a mask based on a relation of each element's index to a threshold.

    If the threshold is non-negative, the mask is true where
    relation(index, threshold) is true. Here index is the element's position
    within its parent (0, 1, 2, ...).

    If the threshold is negative, the mask is true where
    relation(index - size, threshold) is true. Here size is the number of
    elements in the parent, so index - size counts from the end
    (..., -3, -2, -1).

    A Python int threshold selects the mask statically. Any other threshold
    (e.g. a tensor) is compared against zero at runtime with tf.cond.

    Args:
      t: expression to add the field to.
      p: path to create the mask for.
      threshold: the cutoff threshold.
      relation: tf.less or tf.greater_equal.

    Returns:
      A `(new_expr, mask_path)` pair. `new_expr` is `t` plus anonymous helper
      fields, and `mask_path` points to the boolean mask (a sibling of `p`).

    Raises:
      ValueError: if p is not in t.
    """
    if t.get_descendant(p) is None:
        raise ValueError("Path not found: {}".format(str(p)))
    work_expr, index_from_end = index.get_index_from_end(
        t, p, path.get_anonymous_field())
    work_expr, mask_for_negative_threshold = map_values.map_values_anonymous(
        work_expr, index_from_end,
        lambda x: relation(x, tf.cast(threshold, tf.int64)), tf.bool)

    work_expr, positional_index = index.get_positional_index(
        work_expr, p, path.get_anonymous_field())
    work_expr, mask_for_non_negative_threshold = map_values.map_values_anonymous(
        work_expr, positional_index,
        lambda x: relation(x, tf.cast(threshold, tf.int64)), tf.bool)

    if isinstance(threshold, int):
        if threshold >= 0:
            return work_expr, mask_for_non_negative_threshold
        return work_expr, mask_for_negative_threshold

    def tf_cond_on_threshold(a, b):
        # tf.cond takes *callables* for its branches; passing the tensors
        # directly raises TypeError.
        return tf.cond(tf.greater_equal(threshold, 0), lambda: a, lambda: b)

    return map_values.map_many_values(work_expr, p.get_parent(), [
        x.field_list[-1] for x in
        [mask_for_non_negative_threshold, mask_for_negative_threshold]
    ], tf_cond_on_threshold, tf.bool, path.get_anonymous_field())
コード例 #13
0
def _map_prensor_impl(
        root: expression.Expression, root_path: path.Path,
        paths_needed: Sequence[path.Path],
        operation: Callable[[prensor.Prensor, calculate_options.Options],
                            prensor.LeafNodeTensor], is_repeated: bool,
        dtype: tf.DType,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Adds a field computed by `operation` as a new child of `root_path`.

    The subtree at `root_path` is projected down to `paths_needed` before it
    is handed to _MapPrensorExpression together with `operation`.

    Returns:
      A `(new_root, new_path)` pair with
      `new_path == root_path.get_child(new_field_name)`.
    """
    subtree = root.get_descendant_or_error(root_path)
    projected = project.project(subtree, paths_needed)
    mapped = _MapPrensorExpression(projected, operation, is_repeated, dtype)
    result_path = root_path.get_child(new_field_name)
    return expression_add.add_paths(root, {result_path: mapped}), result_path
コード例 #14
0
ファイル: path_test.py プロジェクト: priyansh19/struct2tensor
 def test_cmp(self):
     """Checks ordering and equality between Path objects."""
     one = Path([1])
     # Paths built directly from step sequences.
     self.assertGreater(one, Path("foo"))
     self.assertLess(Path("foo"), one)
     self.assertGreater(one, Path([0]))
     # Paths built from dotted strings.
     foo = create_path("foo")
     foo_bar = create_path("foo.bar")
     foo_baz = create_path("foo.baz")
     self.assertGreater(foo_baz, foo)
     self.assertGreater(foo_baz, foo_bar)
     self.assertLess(foo, foo_bar)
     self.assertLess(foo_bar, foo_baz)
     # Two separately built, equal paths compare equal.
     self.assertEqual(foo_baz, create_path("foo.baz"))
コード例 #15
0
ファイル: size.py プロジェクト: pk-organics/struct2tensor
def _size_impl(
        root: expression.Expression, source_path: path.Path,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Adds a SizeExpression for `source_path` as a new sibling field.

    Returns:
      A `(new_root, new_path)` pair, where `new_path` is the sibling of
      `source_path` named `new_field_name`.

    Raises:
      ValueError: if `source_path` is the root or is not present in `root`.
    """
    if not source_path:
        raise ValueError("Cannot get the size of the root.")
    if root.get_descendant(source_path) is None:
        raise ValueError("Path not found: {}".format(str(source_path)))
    parent_path = source_path.get_parent()
    size_path = parent_path.get_child(new_field_name)
    size_expr = SizeExpression(root.get_descendant_or_error(source_path),
                               root.get_descendant_or_error(parent_path))
    return expression_add.add_paths(root, {size_path: size_expr}), size_path
コード例 #16
0
def _promote_impl(root: expression.Expression, p: path.Path,
                  new_field_name: path.Step
                 ) -> Tuple[expression.Expression, path.Path]:
  """Promotes the field at `p` to a child of its grandparent.

  Args:
    root: the root expression.
    p: the path to promote; must have at least two steps.
    new_field_name: the name of the promoted field.

  Returns:
    A `(new_root, new_path)` pair, where `new_path` is the grandparent of `p`
    extended with `new_field_name` and holds a PromoteExpression.

  Raises:
    ValueError: if `p` has fewer than two steps.
  """
  if len(p) < 2:
    raise ValueError("Cannot do a promotion beyond the root: {}".format(str(p)))
  parent_path = p.get_parent()
  promoted_path = parent_path.get_parent().get_child(new_field_name)
  promoted = PromoteExpression(root.get_descendant_or_error(p),
                               root.get_descendant_or_error(parent_path))
  new_root = expression_add.add_paths(root, {promoted_path: promoted})
  return new_root, promoted_path
コード例 #17
0
def promote_and_broadcast_anonymous(
        root: expression.Expression, origin: path.Path,
        new_parent: path.Path) -> Tuple[expression.Expression, path.Path]:
    """Promotes, then broadcasts, `origin` until its parent is `new_parent`.

    First the field is promoted upward until its parent is the least common
    ancestor of `origin` and `new_parent`. It is then broadcast downward, one
    step of `new_parent` at a time.

    Returns:
      A `(new_root, new_path)` pair where `new_path.get_parent()` equals
      `new_parent`.
    """
    common_ancestor = origin.get_least_common_ancestor(new_parent)
    current_expr = root
    current_path = origin

    # Promotion phase: climb until the parent is the common ancestor.
    while current_path.get_parent() != common_ancestor:
        current_expr, current_path = promote.promote_anonymous(
            current_expr, current_path)

    # Broadcast phase: the current parent has len(current_path) - 1 steps, so
    # the next step of new_parent to descend into is at that index.
    while current_path.get_parent() != new_parent:
        next_step = new_parent.field_list[len(current_path) - 1]
        current_expr, current_path = broadcast.broadcast_anonymous(
            current_expr, current_path, next_step)

    return current_expr, current_path
コード例 #18
0
def map_values(root: expression.Expression, source_path: path.Path,
               operation: Callable[[tf.Tensor], tf.Tensor], dtype: tf.DType,
               new_field_name: path.Step) -> expression.Expression:
    """Maps a single field into a new sibling field.

    The output of `operation` must have the same shape as its input.

    Args:
      root: original root.
      source_path: path of the field to map; must not be the root.
      operation: operation from the source field to the new field.
      dtype: type of the new field.
      new_field_name: name of the new sibling field.

    Returns:
      The new root expression (without the new path).

    Raises:
      ValueError: if `source_path` is the root.
    """
    if not source_path:
        raise ValueError('Cannot map the root.')
    new_root, _ = map_many_values(root, source_path.get_parent(),
                                  [source_path.field_list[-1]], operation,
                                  dtype, new_field_name)
    return new_root
コード例 #19
0
ファイル: index.py プロジェクト: pk-organics/struct2tensor
def get_positional_index(
        expr: expression.Expression, source_path: path.Path,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Gets each element's position within its parent.

    Given a field with parent_index [0,1,1,2,3,4,4], this returns:
    parent_index [0,1,1,2,3,4,4] and value [0,0,1,0,0,0,1].

    Args:
      expr: original expression.
      source_path: path in expression to get index of.
      new_field_name: the name of the new field.

    Returns:
      The new expression and the new path as a pair.
    """
    index_path = source_path.get_parent().get_child(new_field_name)
    index_expr = _PositionalIndexExpression(
        expr.get_descendant_or_error(source_path))
    return expression_add.add_paths(expr, {index_path: index_expr}), index_path
コード例 #20
0
def filter_by_child(expr: expression.Expression, p: path.Path,
                    child_field_name: path.Step,
                    new_field_name: path.Step) -> expression.Expression:
    """Filters the field at `p` by one of its optional boolean children.

    A parent is kept when the child field is present and True; otherwise it
    is dropped.

    Args:
      expr: the original expression.
      p: the path to filter.
      child_field_name: the boolean child field to use to filter.
      new_field_name: the name of the new, filtered version of `p`.

    Returns:
      The new root expression, with the filtered field as a sibling of `p`.
    """
    source = expr.get_descendant_or_error(p)
    mask_child = source.get_child_or_error(child_field_name)
    filtered = _FilterByChildExpression(source, mask_child)
    filtered_path = p.get_parent().get_child(new_field_name)
    return expression_add.add_paths(expr, {filtered_path: filtered})
コード例 #21
0
def map_prensor_to_prensor(root_expr: expression.Expression, source: path.Path,
                           paths_needed: Sequence[path.Path],
                           prensor_op: Callable[[prensor.Prensor],
                                                prensor.Prensor],
                           output_schema: Schema) -> expression.Expression:
    r"""Applies a prensor op to the subtree at `source` and merges the result.

    For example, suppose you have an op my_op that takes a prensor of the form:

      event
       / \
     foo   bar

    and produces a prensor whose schema is my_output_schema:

       event
        / \
     foo2 bar2

    If you give it an expression original with the schema:

     session
        |
      event
      /  \
    foo   bar

    result = map_prensor_to_prensor(
      original,
      path.Path(["session","event"]),
      [path.Path(["foo"]), path.Path(["bar"])],
      my_op,
      my_output_schema)

    Result will have the schema:

     session
        |
      event--------
      /  \    \    \
    foo   bar foo2 bar2

    Args:
      root_expr: the root expression.
      source: the path where the prensor op is applied.
      paths_needed: the paths (relative to `source`) needed by the op; the
        subtree at `source` is projected down to these before the op runs.
      prensor_op: the prensor op.
      output_schema: the output schema of the op.

    Returns:
      A new expression where each top-level field of the op's output is added
      as a child of `source`.
    """
    original_child = root_expr.get_descendant_or_error(source).project(
        paths_needed)
    prensor_child = _PrensorOpExpression(original_child, prensor_op,
                                         output_schema)
    # Graft every top-level output field of the op under `source`.
    paths_map = {
        source.get_child(k): prensor_child.get_child_or_error(k)
        for k in prensor_child.known_field_names()
    }
    result = expression_add.add_paths(root_expr, paths_map)
    return result