Пример #1
0
def add_paths(root: expression.Expression,
              path_map: Mapping[path.Path, expression.Expression]):
    """Builds a new expression from `root` with the `path_map` subtrees added.

    Each key of `path_map` must be a path whose parent can already be found
    under `root` and which is not itself present under `root`.

    Use with care: nothing here checks that a new expression makes sense at
    the position where it is placed in the tree. This is a building block for
    wrapping a new Expression type and is not meant for end users.

    Prefer add_to to add_paths.

    Args:
      root: the root of the tree.
      path_map: a map from a path to the new subtree placed at that path.

    Returns:
      a new tree with the nodes from the root and the new subtrees.

    Raises:
      ValueError: if the parent of a path is missing from `root`, or if the
        path is already set in `root`.
    """
    for new_path in path_map:
        parent_expr = root.get_descendant(new_path.get_parent())
        if parent_expr is None:
            raise ValueError("No parent of {}".format(new_path))
        existing_expr = root.get_descendant(new_path)
        if existing_expr is not None:
            raise ValueError("Path already set: {}".format(str(new_path)))
    # Only the second element of the create_subtrees result is used here.
    _unused, subtree_maps = create_subtrees(path_map)
    return _AddPathsExpression(root, subtree_maps)
Пример #2
0
def _promote_impl(
        root: expression.Expression, p: path.Path,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Adds a promoted version of `p` as a named child of its grandparent.

    Args:
      root: The root expression.
      p: The path to promote. This can be the path to a leaf or child node.
      new_field_name: The name of the promoted field.

    Returns:
      A pair of the expression returned by expression_add.add_paths (which
      holds the new promote expression) and the path of the promoted field.

    Raises:
      ValueError: if `p` has fewer than two steps, so it has no grandparent.
    """
    if len(p) < 2:
        message = "Cannot do a promotion beyond the root: {}".format(str(p))
        raise ValueError(message)

    parent_path = p.get_parent()
    grandparent_path = parent_path.get_parent()
    origin_expr = root.get_descendant_or_error(p)
    promoted_path = grandparent_path.get_child(new_field_name)

    # Leaves and child (non-leaf) nodes are promoted by different classes.
    factory = (PromoteExpression
               if origin_expr.is_leaf else PromoteChildExpression)

    parent_expr = root.get_descendant_or_error(parent_path)
    promoted_expr = factory(origin_expr, parent_expr)
    new_root = expression_add.add_paths(root, {promoted_path: promoted_expr})
    return new_root, promoted_path
Пример #3
0
def filter_by_sibling(expr: expression.Expression, p: path.Path,
                      sibling_field_name: path.Step,
                      new_field_name: path.Step) -> expression.Expression:
    """Creates a filtered copy of `p`, using a sibling field as the mask.

    This is similar to boolean_mask. The path being filtered and the sibling
    must have identical shapes (e.g., each parent object must have an equal
    number of source and sibling children).

    Args:
      expr: the root expression.
      p: a path to the source to be filtered.
      sibling_field_name: the sibling to use as a mask.
      new_field_name: a new sibling to create.

    Returns:
      a new root.
    """
    source_expr = expr.get_descendant_or_error(p)
    parent_path = p.get_parent()
    mask_path = parent_path.get_child(sibling_field_name)
    mask_expr = expr.get_descendant_or_error(mask_path)
    filtered_expr = _FilterBySiblingExpression(source_expr, mask_expr)
    # The filtered field is created next to the source and the mask.
    filtered_path = parent_path.get_child(new_field_name)
    return expression_add.add_paths(expr, {filtered_path: filtered_expr})
Пример #4
0
def add_to(root: expression.Expression,
           origins: Mapping[path.Path, expression.Expression]):
    """Copies the subtree at each path from its origin tree into the root.

    This operation can be used to reduce the number of expressions in the
    graph. Requirements:
    1. The path must be present in the associated origin.
    2. The root must not have any of the paths already.
    3. The root must already have the parent of the path.
    4. The parent of the path in the root must be a source expression of the
       parent of the path in the origin.

    Args:
      root: the original tree that has new expressions added to it.
      origins: mapping from path to trees that have subtrees at the path.

    Returns:
      A tree with the root and the additional subtrees.

    Raises:
      ValueError: if the parent in the root is not a true source of the parent
        in the origin, or if the root already contains one of the paths.
    """
    # First pass: validate every path before anything is collected.
    for new_path, origin_root in origins.items():
        parent_path = new_path.get_parent()
        parent_in_root = root.get_descendant_or_error(parent_path)
        parent_in_origin = origin_root.get_descendant_or_error(parent_path)
        if not _is_true_source_expression(parent_in_root, parent_in_origin):
            message = "Not a true source for tree with {}".format(
                str(new_path))
            raise ValueError(message)
        if root.get_descendant(new_path) is not None:
            raise ValueError("Already contains {}.".format(str(new_path)))

    # Second pass: look up the subtree to copy for each path.
    path_map = {}
    for new_path, origin_root in origins.items():
        path_map[new_path] = origin_root.get_descendant_or_error(new_path)
    return add_paths(root, path_map)
Пример #5
0
def calculate_value_slowly(
        expr: expression.Expression,
        destinations: Optional[Sequence[expression.Expression]] = None,
        options: Optional[calculate_options.Options] = None
) -> prensor.NodeTensor:
    """Computes the node tensor of an expression by plain recursion.

    No common subexpression elimination or caching of node tensors is done:
    every source expression is recomputed recursively, so this will likely be
    very slow for larger operations.

    Args:
      expr: The expression to calculate.
      destinations: Where the calculation will be used (None implies directly)
      options: Calculation options for individual calculations. When None,
        calculate_options.get_default_options() is used.

    Returns:
      The node tensor of the expression.
    """
    if options is None:
        effective_options = calculate_options.get_default_options()
    else:
        effective_options = options

    # Each source is calculated with `expr` as its only destination.
    source_values = []
    for source_expr in expr.get_source_expressions():
        source_values.append(
            calculate_value_slowly(source_expr, [expr], effective_options))

    if destinations is None:
        used_by = []
    else:
        used_by = destinations
    return expr.calculate(source_values, used_by, effective_options)
Пример #6
0
def _broadcast_impl(
        root: expression.Expression, origin: path.Path, sibling: path.Step,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Adds a _BroadcastExpression of `origin` as a child of its sibling.

    Args:
      root: the root expression.
      origin: the path of the field to broadcast.
      sibling: the name of a sibling of `origin` (a field with the same
        parent) that receives the new field.
      new_field_name: the name of the new field created under the sibling.

    Returns:
      A pair of the expression returned by expression_add.add_paths and the
      path of the new field.
    """
    sibling_path = origin.get_parent().get_child(sibling)
    origin_expr = root.get_descendant_or_error(origin)
    sibling_expr = root.get_descendant_or_error(sibling_path)
    broadcast_expr = _BroadcastExpression(origin_expr, sibling_expr)
    broadcast_path = sibling_path.get_child(new_field_name)
    new_root = expression_add.add_paths(root, {broadcast_path: broadcast_expr})
    return new_root, broadcast_path
Пример #7
0
def _promote_impl(root: expression.Expression, p: path.Path,
                  new_field_name: path.Step
                 ) -> Tuple[expression.Expression, path.Path]:
  """Adds a PromoteExpression for `p` as a named child of `p`'s grandparent.

  Args:
    root: the root expression.
    p: the path to promote; it must have at least two steps.
    new_field_name: the name of the promoted field.

  Returns:
    A pair of the expression returned by expression_add.add_paths and the path
    of the promoted field.

  Raises:
    ValueError: if `p` has fewer than two steps, so it has no grandparent.
  """
  if len(p) < 2:
    message = "Cannot do a promotion beyond the root: {}".format(str(p))
    raise ValueError(message)
  parent_path = p.get_parent()
  promoted_path = parent_path.get_parent().get_child(new_field_name)
  origin_expr = root.get_descendant_or_error(p)
  parent_expr = root.get_descendant_or_error(parent_path)
  promoted_expr = PromoteExpression(origin_expr, parent_expr)
  new_root = expression_add.add_paths(root, {promoted_path: promoted_expr})
  return new_root, promoted_path
Пример #8
0
def _size_impl(
        root: expression.Expression, source_path: path.Path,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Adds a SizeExpression for `source_path` as a new sibling field.

    Args:
      root: the root expression.
      source_path: the path whose size is taken; must not be the root.
      new_field_name: the name of the new field, created under the parent of
        `source_path`.

    Returns:
      A pair of the expression returned by expression_add.add_paths and the
      path of the new field.

    Raises:
      ValueError: if `source_path` is empty or cannot be found under `root`.
    """
    if not source_path:
        raise ValueError("Cannot get the size of the root.")
    if root.get_descendant(source_path) is None:
        message = "Path not found: {}".format(str(source_path))
        raise ValueError(message)

    parent_path = source_path.get_parent()
    size_path = parent_path.get_child(new_field_name)
    source_expr = root.get_descendant_or_error(source_path)
    parent_expr = root.get_descendant_or_error(parent_path)
    size_expr = SizeExpression(source_expr, parent_expr)
    new_root = expression_add.add_paths(root, {size_path: size_expr})
    return new_root, size_path
Пример #9
0
def map_many_values(
        root: expression.Expression, parent_path: path.Path,
        source_fields: Sequence[path.Step],
        operation: Callable[..., tf.Tensor], dtype: tf.DType,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Combines several sibling fields into one new sibling field.

    All source fields must have the same shape, and the output of `operation`
    must have that shape as well.

    Args:
      root: original root.
      parent_path: parent path of all sources and the new field.
      source_fields: source fields of the operation. Must have the same shape.
      operation: operation from source_fields to new field.
      dtype: type of new field.
      new_field_name: name of the new field.

    Returns:
      The new expression and the new path as a pair.
    """
    mapped_path = parent_path.get_child(new_field_name)

    # Resolve each source field to its expression, keeping the given order.
    source_exprs = []
    for field_name in source_fields:
        field_path = parent_path.get_child(field_name)
        source_exprs.append(root.get_descendant_or_error(field_path))

    mapped_expr = _MapValuesExpression(source_exprs, operation, dtype)
    new_root = expression_add.add_paths(root, {mapped_path: mapped_expr})
    return new_root, mapped_path
Пример #10
0
 def __init__(self, original_root: expression.Expression,
              field_name: path.Step):
     """Stores `original_root` and its child `field_name` as the new root.

     Args:
       original_root: the expression whose child becomes the new root.
       field_name: the name of the child of `original_root` to use.

     Raises:
       ValueError: if the `type` of the selected child is not None; the error
         text reports this as the new root not being a message type.
     """
     super().__init__(True, None)
     self._field_name = field_name
     self._original_root = original_root
     new_root = original_root.get_child_or_error(field_name)
     self._new_root = new_root
     if new_root.type is not None:
         message = "New root must be a message type: {}".format(
             str(self._field_name))
         raise ValueError(message)
Пример #11
0
def _broadcast_impl(
        root: expression.Expression, origin: path.Path, sibling: path.Step,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Broadcasts `origin` to `sibling`, adding the result under the sibling.

    Args:
      root: the root expression.
      origin: the path of the field to broadcast; a leaf or a child node.
      sibling: the name of a sibling of `origin` (a field with the same
        parent) that receives the new field.
      new_field_name: the name of the new field created under the sibling.

    Returns:
      A pair of the expression returned by expression_add.add_paths and the
      path of the new field.
    """
    sibling_path = origin.get_parent().get_child(sibling)
    origin_expr = root.get_descendant_or_error(origin)

    # Leaves and child (non-leaf) nodes are broadcast by different classes.
    if origin_expr.is_leaf:
        factory = _BroadcastExpression
    else:
        factory = _BroadcastChildExpression

    sibling_expr = root.get_descendant_or_error(sibling_path)
    broadcast_expr = factory(origin_expr, sibling_expr)
    broadcast_path = sibling_path.get_child(new_field_name)
    new_root = expression_add.add_paths(root, {broadcast_path: broadcast_expr})
    return new_root, broadcast_path
Пример #12
0
def _get_mask(
    t: expression.Expression, p: path.Path, threshold: IndexValue,
    relation: Callable[[tf.Tensor, IndexValue], tf.Tensor]
) -> Tuple[expression.Expression, path.Path]:
    """Gets a mask based on a relation of the index to a threshold.

    If the threshold is non-negative, then we create a mask that is true if
    the relation(index, threshold) is true.

    If the threshold is negative, then we create a mask that is true if
    the relation(size - index, threshold) is true.

    Both candidate masks are always added to the expression as anonymous
    fields. When `threshold` is a Python int the matching mask is picked
    immediately; otherwise the choice is deferred to a `tf.cond` on the sign
    of the threshold, stored in a third anonymous field.

    Args:
      t: expression to add the field to.
      p: path to create the mask for.
      threshold: the cutoff threshold.
      relation: tf.less or tf.greater_equal.

    Returns:
      A pair of the expression with the mask fields added and the path of the
      boolean mask on the fields to keep.

    Raises:
      ValueError: if p is not in t.
    """
    if t.get_descendant(p) is None:
        raise ValueError("Path not found: {}".format(str(p)))
    work_expr, index_from_end = index.get_index_from_end(
        t, p, path.get_anonymous_field())
    work_expr, mask_for_negative_threshold = map_values.map_values_anonymous(
        work_expr, index_from_end,
        lambda x: relation(x, tf.cast(threshold, tf.int64)), tf.bool)

    work_expr, positional_index = index.get_positional_index(
        work_expr, p, path.get_anonymous_field())
    work_expr, mask_for_non_negative_threshold = map_values.map_values_anonymous(
        work_expr, positional_index,
        lambda x: relation(x, tf.cast(threshold, tf.int64)), tf.bool)

    if isinstance(threshold, int):
        if threshold >= 0:
            return work_expr, mask_for_non_negative_threshold
        return work_expr, mask_for_negative_threshold

    def tf_cond_on_threshold(a, b):
        # tf.cond requires callables for its branches; passing the already
        # computed mask tensors directly raises TypeError, so wrap them.
        return tf.cond(tf.greater_equal(threshold, 0), lambda: a, lambda: b)

    # Both masks are siblings of `p`, so they are addressed by their last step.
    mask_fields = [
        mask_path.field_list[-1] for mask_path in
        [mask_for_non_negative_threshold, mask_for_negative_threshold]
    ]
    return map_values.map_many_values(work_expr, p.get_parent(), mask_fields,
                                      tf_cond_on_threshold, tf.bool,
                                      path.get_anonymous_field())
Пример #13
0
def _map_prensor_impl(
        root: expression.Expression, root_path: path.Path,
        paths_needed: Sequence[path.Path],
        operation: Callable[[prensor.Prensor, calculate_options.Options],
                            prensor.LeafNodeTensor], is_repeated: bool,
        dtype: tf.DType,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Adds a _MapPrensorExpression as a new child field of `root_path`.

    Args:
      root: the root expression.
      root_path: the path of the node that is projected and gets the new field.
      paths_needed: the paths passed to project.project on the node at
        `root_path`.
      operation: the operation handed to _MapPrensorExpression.
      is_repeated: handed to _MapPrensorExpression unchanged.
      dtype: handed to _MapPrensorExpression unchanged.
      new_field_name: the name of the new field under `root_path`.

    Returns:
      A pair of the expression returned by expression_add.add_paths and the
      path of the new field.
    """
    node_expr = root.get_descendant_or_error(root_path)
    projected_expr = project.project(node_expr, paths_needed)
    mapped_expr = _MapPrensorExpression(projected_expr, operation,
                                        is_repeated, dtype)
    mapped_path = root_path.get_child(new_field_name)
    new_root = expression_add.add_paths(root, {mapped_path: mapped_expr})
    return new_root, mapped_path
Пример #14
0
    def __init__(self, expr: expression.Expression):
        """Construct a node in the graph.

        The node starts with no destinations and no value; its sources are the
        earliest equal calculations of the source expressions of `expr`.

        Args:
          expr: must be the result of _get_earliest_equal_calculation(...)
        """
        self.expression = expr
        canonical_sources = []
        for source_expr in expr.get_source_expressions():
            canonical_sources.append(
                _get_earliest_equal_calculation(source_expr))
        self.sources = canonical_sources
        self.destinations = []  # type: List[_ExpressionNode]
        self.value = None
Пример #15
0
def project(expr: expression.Expression,
            paths: Sequence[path.Path]) -> expression.Expression:
  """Selects a subtree of an expression.

  Paths not selected are removed.
  Paths that are selected are "known", such that if calculate_prensors is
  called, they will be in the result.

  Args:
    expr: the original expression.
    paths: the paths to include.

  Returns:
    A projected expression.

  Raises:
    ValueError: if any of `paths` cannot be found in `expr`; the message lists
      every missing path.
  """
  missing_paths = []
  for candidate in paths:
    if expr.get_descendant(candidate) is None:
      missing_paths.append(candidate)

  if not missing_paths:
    return _ProjectExpression(expr, paths)

  listing = ", ".join(str(missing) for missing in missing_paths)
  raise ValueError("{} Path(s) missing in project: {}".format(
      len(missing_paths), listing))
Пример #16
0
def get_positional_index(
        expr: expression.Expression, source_path: path.Path,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Adds a field holding the position of each value within its parent.

    Given a field with parent_index [0,1,1,2,3,4,4], this returns:
    parent_index [0,1,1,2,3,4,4] and value [0,0,1,0,0,0,1]

    Args:
      expr: original expression
      source_path: path in expression to get index of.
      new_field_name: the name of the new field, created as a sibling of
        `source_path`.

    Returns:
      The new expression and the new path as a pair.
    """
    index_path = source_path.get_parent().get_child(new_field_name)
    source_expr = expr.get_descendant_or_error(source_path)
    index_expr = _PositionalIndexExpression(source_expr)
    new_root = expression_add.add_paths(expr, {index_path: index_expr})
    return new_root, index_path
Пример #17
0
def filter_by_child(expr: expression.Expression, p: path.Path,
                    child_field_name: path.Step,
                    new_field_name: path.Step) -> expression.Expression:
    """Creates a filtered copy of `p`, using an optional boolean child field.

    If the child field is present and True, then keep that parent.
    Otherwise, drop the parent.

    Args:
      expr: the original expression
      p: the path to filter.
      child_field_name: the boolean child field to use to filter.
      new_field_name: the new, filtered version of path, created as a sibling
        of `p`.

    Returns:
      The new root expression.
    """
    source_expr = expr.get_descendant_or_error(p)
    mask_expr = source_expr.get_child_or_error(child_field_name)
    filtered_expr = _FilterByChildExpression(source_expr, mask_expr)
    filtered_path = p.get_parent().get_child(new_field_name)
    return expression_add.add_paths(expr, {filtered_path: filtered_expr})
Пример #18
0
def create_transformed_field(
        expr: expression.Expression, source_path: path.CoercableToPath,
        dest_field: StrStep,
        transform_fn: TransformFn) -> expression.Expression:
    """Adds a sibling field that applies `transform_fn` to serialized protos.

    The transform_fn argument should take the form:

    def transform_fn(parent_indices, values):
      ...
      return (transformed_parent_indices, transformed_values)

    Given:
    - parent_indices: an int64 vector of non-decreasing parent message indices.
    - values: a string vector of serialized protos having the same shape as
      `parent_indices`.
    `transform_fn` must return new parent indices and serialized values
    encoding the same proto message as the passed in `values`. These two
    vectors must have the same size, but it need not be the same as the input
    arguments.

    Args:
      expr: a source expression containing `source_path`.
      source_path: the path to the proto message field to transform.
      dest_field: the name of the newly created field. This field will be a
        sibling of the field identified by `source_path`.
      transform_fn: a callable that accepts parent_indices and serialized proto
        values and returns possibly modified parent_indices and values.

    Returns:
      An expression.

    Raises:
      ValueError: if the source path is not a proto message field.
    """
    resolved_path = path.create_path(source_path)
    source_expr = expr.get_descendant_or_error(resolved_path)
    if not isinstance(source_expr, _ProtoChildExpression):
        raise ValueError(
            "Expected _ProtoChildExpression for field {}, but found {}.".
            format(str(resolved_path), source_expr))

    def chained_transform(parent_indices: tf.Tensor,
                          values: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        # Run the transform already attached to the source first, then the
        # user supplied one on its output.
        upstream_indices, upstream_values = source_expr.transform_fn(
            parent_indices, values)
        return transform_fn(upstream_indices, upstream_values)

    # In order to be able to propagate fields needed for parsing, the source
    # expression of _TransformProtoChildExpression must always be the original
    # _ProtoChildExpression before any transformation. Two sequentially applied
    # transforms would therefore share one source and both act on it directly,
    # instead of one operating on the output of the other. Chaining the
    # functions works around that. The downside is that the initial transform
    # may be applied redundantly if other expressions derive directly from it.
    already_transformed = isinstance(source_expr,
                                     _TransformProtoChildExpression)
    final_transform = chained_transform if already_transformed else transform_fn

    transformed_expr = _TransformProtoChildExpression(
        parent=source_expr._parent,  # pylint: disable=protected-access
        desc=source_expr._desc,  # pylint: disable=protected-access
        is_repeated=source_expr.is_repeated,
        name_as_field=source_expr.name_as_field,
        transform_fn=final_transform)
    dest_path = resolved_path.get_parent().get_child(dest_field)
    return expression_add.add_paths(expr, {dest_path: transformed_expr})
Пример #19
0
def create_transformed_field(
        expr: expression.Expression, source_path: path.CoercableToPath,
        dest_field: StrStep,
        transform_fn: TransformFn) -> expression.Expression:
    """Adds a sibling field that applies `transform_fn` to serialized protos.

    The transform_fn argument should take the form:

    def transform_fn(parent_indices, values):
      ...
      return (transformed_parent_indices, transformed_values)

    Given:
    - parent_indices: an int64 vector of non-decreasing parent message indices.
    - values: a string vector of serialized protos having the same shape as
      `parent_indices`.
    `transform_fn` must return new parent indices and serialized values
    encoding the same proto message as the passed in `values`. These two
    vectors must have the same size, but it need not be the same as the input
    arguments.

    Note:
      If CalculateOptions.use_string_view (set at calculate time, thus this
      Expression cannot know beforehand) is True, `values` passed to
      `transform_fn` are string views pointing all the way back to the
      original input tensor (of serialized root protos). `transform_fn` must
      maintain such views and avoid creating new values that are either not
      string views into the root protos or self-owned strings. Downstream
      decoding ops still produce string views referring into their input
      (which are string views into the root proto) and only hold a reference
      to the original root proto tensor to keep it alive, so the input tensor
      may get destroyed after the decoding op.

      In short, you can do element-wise transforms to `values`, but can't
      mutate the contents of elements in `values` or create new elements.

      To lift this restriction, a decoding op must be told to hold a reference
      of the input tensors of all its upstream decoding ops.

    Args:
      expr: a source expression containing `source_path`.
      source_path: the path to the proto message field to transform.
      dest_field: the name of the newly created field. This field will be a
        sibling of the field identified by `source_path`.
      transform_fn: a callable that accepts parent_indices and serialized proto
        values and returns possibly modified parent_indices and values. Note
        that when CalculateOptions.use_string_view is set, transform_fn should
        not have any stateful side effecting uses of serialized proto inputs.
        Doing so could cause segfaults as the backing string tensor lifetime is
        not guaranteed when the side effecting operations are run.

    Returns:
      An expression.

    Raises:
      ValueError: if the source path is not a proto message field.
    """
    resolved_path = path.create_path(source_path)
    source_expr = expr.get_descendant_or_error(resolved_path)
    if not isinstance(source_expr, _ProtoChildExpression):
        raise ValueError(
            "Expected _ProtoChildExpression for field {}, but found {}.".
            format(str(resolved_path), source_expr))

    def chained_transform(parent_indices: tf.Tensor,
                          values: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        # Run the transform already attached to the source first, then the
        # user supplied one on its output.
        upstream_indices, upstream_values = source_expr.transform_fn(
            parent_indices, values)
        return transform_fn(upstream_indices, upstream_values)

    # In order to be able to propagate fields needed for parsing, the source
    # expression of _TransformProtoChildExpression must always be the original
    # _ProtoChildExpression before any transformation. Two sequentially applied
    # transforms would therefore share one source and both act on it directly,
    # instead of one operating on the output of the other. Chaining the
    # functions works around that. The downside is that the initial transform
    # may be applied redundantly if other expressions derive directly from it.
    already_transformed = isinstance(source_expr,
                                     _TransformProtoChildExpression)
    final_transform = chained_transform if already_transformed else transform_fn

    transformed_expr = _TransformProtoChildExpression(
        parent=source_expr._parent,  # pylint: disable=protected-access
        desc=source_expr._desc,  # pylint: disable=protected-access
        is_repeated=source_expr.is_repeated,
        name_as_field=source_expr.name_as_field,
        transform_fn=final_transform,
        backing_str_tensor=source_expr._backing_str_tensor)  # pylint: disable=protected-access
    dest_path = resolved_path.get_parent().get_child(dest_field)
    return expression_add.add_paths(expr, {dest_path: transformed_expr})
Пример #20
0
 def calculation_equal(self, expr: expression.Expression) -> bool:
   """Returns the result of `expr.calculation_is_identity()`."""
   is_identity = expr.calculation_is_identity()
   return is_identity
def map_prensor_to_prensor(root_expr: expression.Expression, source: path.Path,
                           paths_needed: Sequence[path.Path],
                           prensor_op: Callable[[prensor.Prensor],
                                                prensor.Prensor],
                           output_schema: Schema) -> expression.Expression:
    r"""Applies a prensor op below `source` and merges its output fields.

    For example, suppose you have an op my_op, that takes a prensor of the
    form:

      event
       / \
     foo   bar

    and produces a prensor of the form my_output_schema:

       event
        / \
     foo2 bar2

    If you give it an expression original with the schema:

     session
        |
      event
      /  \
    foo   bar

    result = map_prensor_to_prensor(
      original,
      path.Path(["session","event"]),
      my_paths_needed,
      my_op,
      my_output_schema)

    Result will have the schema:

     session
        |
      event--------
      /  \    \    \
    foo   bar foo2 bar2

    Args:
      root_expr: the root expression
      source: the path where the prensor op is applied.
      paths_needed: the paths needed for the op; they are passed to `project`
        on the expression found at `source`.
      prensor_op: the prensor op
      output_schema: the output schema of the op.

    Returns:
      A new expression where the prensor is merged.
    """
    source_expr = root_expr.get_descendant_or_error(source)
    projected_source = source_expr.project(paths_needed)
    op_expr = _PrensorOpExpression(projected_source, prensor_op,
                                   output_schema)

    # Every known field of the op output becomes a new child of `source`.
    paths_map = {}
    for field_name in op_expr.known_field_names():
        merged_path = source.get_child(field_name)
        paths_map[merged_path] = op_expr.get_child_or_error(field_name)
    return expression_add.add_paths(root_expr, paths_map)