예제 #1
0
def add_paths(root: expression.Expression,
              path_map: Mapping[path.Path, expression.Expression]):
    """Returns a new expression: `root` with the subtrees of `path_map` grafted on.

    This is a low-level building block meant for implementers wrapping a new
    Expression type, not for end users. Nothing here checks that a new subtree
    makes sense at the position where it is placed in the tree.

    Prefer add_to to add_paths.

    Args:
      root: the root of the tree to extend.
      path_map: a map from each new path to the subtree to place at that path.

    Returns:
      A new tree containing the nodes of `root` plus the new subtrees.

    Raises:
      ValueError: if the parent of a new path is missing from `root`, or if a
        new path is already present in `root`.
    """
    for new_path in path_map:
        # Each new path must hang off a node that `root` already has...
        parent_expr = root.get_descendant(new_path.get_parent())
        if parent_expr is None:
            raise ValueError("No parent of {}".format(new_path))
        # ...and must not overwrite a node that `root` already has.
        existing_expr = root.get_descendant(new_path)
        if existing_expr is not None:
            raise ValueError("Path already set: {}".format(str(new_path)))
    # create_subtrees returns a pair; only its second element is needed here.
    _unused, nested_subtrees = create_subtrees(path_map)
    return _AddPathsExpression(root, nested_subtrees)
예제 #2
0
def add_to(root: expression.Expression,
           origins: Mapping[path.Path, expression.Expression]):
    """Grafts subtrees taken from other trees (`origins`) onto `root`.

    Reusing existing subtrees instead of building new expressions can reduce
    the number of expressions in the graph. For every entry (p, origin):
      1. `origin` must have a descendant at p.
      2. `root` must not already have a descendant at p.
      3. `root` must already have the parent of p.
      4. The parent of p in `root` must be a true source expression of the
         parent of p in `origin`.

    Args:
      root: the original tree that receives the new subtrees.
      origins: mapping from each path to a tree that has a subtree at that path.

    Returns:
      A tree with the nodes of `root` plus the copied subtrees.

    Raises:
      ValueError: if condition 2 or 4 fails (add_paths may also raise).
        Conditions 1 and 3 are enforced via get_descendant_or_error, whose
        error type is defined elsewhere.
    """
    # First validate every entry; subtrees are only extracted afterwards.
    for target, origin in origins.items():
        parent_path = target.get_parent()
        root_parent = root.get_descendant_or_error(parent_path)
        origin_parent = origin.get_descendant_or_error(parent_path)
        if not _is_true_source_expression(root_parent, origin_parent):
            raise ValueError("Not a true source for tree with {}".format(
                str(target)))
        if root.get_descendant(target) is not None:
            raise ValueError("Already contains {}.".format(str(target)))

    copied_subtrees = {
        target: origin.get_descendant_or_error(target)
        for target, origin in origins.items()
    }
    return add_paths(root, copied_subtrees)
예제 #3
0
def _get_mask(
    t: expression.Expression, p: path.Path, threshold: IndexValue,
    relation: Callable[[tf.Tensor, IndexValue], tf.Tensor]
) -> Tuple[expression.Expression, path.Path]:
    """Gets a mask based on a relation of the index to a threshold.

    Two candidate boolean masks are added as anonymous fields:
      * for a non-negative threshold: relation(positional_index, threshold),
        using the index from index.get_positional_index;
      * for a negative threshold: relation(index_from_end, threshold), using
        the index from index.get_index_from_end (presumably a negative offset
        counted from the end of the parent -- TODO confirm against `index`).

    If `threshold` is a Python int, the matching mask is selected statically.
    Otherwise (e.g. a scalar tensor), a third anonymous field is added under
    p's parent whose values are chosen at run time with tf.cond.

    Args:
      t: expression to add the field to.
      p: path to create the mask for.
      threshold: the cutoff threshold.
      relation: a comparison such as tf.less or tf.greater_equal.

    Returns:
      A pair (new expression, path of the boolean mask field in it).

    Raises:
      ValueError: if p is not in t.
    """
    if t.get_descendant(p) is None:
        raise ValueError("Path not found: {}".format(str(p)))

    def compare_to_threshold(x):
        # The threshold is cast to int64 before it is compared to the index.
        return relation(x, tf.cast(threshold, tf.int64))

    work_expr, index_from_end = index.get_index_from_end(
        t, p, path.get_anonymous_field())
    work_expr, mask_for_negative_threshold = map_values.map_values_anonymous(
        work_expr, index_from_end, compare_to_threshold, tf.bool)

    work_expr, positional_index = index.get_positional_index(
        work_expr, p, path.get_anonymous_field())
    work_expr, mask_for_non_negative_threshold = map_values.map_values_anonymous(
        work_expr, positional_index, compare_to_threshold, tf.bool)

    if isinstance(threshold, int):
        # Static threshold: choose the mask at graph-construction time.
        if threshold >= 0:
            return work_expr, mask_for_non_negative_threshold
        return work_expr, mask_for_negative_threshold

    def tf_cond_on_threshold(non_negative_mask, negative_mask):
        # tf.cond takes callables for its branches; passing the tensors
        # directly raises a TypeError. tf.cond needs a scalar predicate, so
        # this assumes a scalar threshold and selects one whole mask.
        return tf.cond(tf.greater_equal(threshold, 0),
                       lambda: non_negative_mask,
                       lambda: negative_mask)

    return map_values.map_many_values(
        work_expr, p.get_parent(),
        [mask.field_list[-1]
         for mask in (mask_for_non_negative_threshold,
                      mask_for_negative_threshold)],
        tf_cond_on_threshold, tf.bool, path.get_anonymous_field())
예제 #4
0
def _size_impl(
        root: expression.Expression, source_path: path.Path,
        new_field_name: path.Step) -> Tuple[expression.Expression, path.Path]:
    """Adds a SizeExpression for `source_path` as a new sibling field.

    The new field is named `new_field_name` and is placed under the parent of
    `source_path`.

    Args:
      root: the expression to extend.
      source_path: the path whose size is computed; must not be the root.
      new_field_name: the step name of the new field.

    Returns:
      A pair (new root expression, path of the new size field).

    Raises:
      ValueError: if `source_path` is falsy (the root path) or is not present
        in `root`.
    """
    if not source_path:
        raise ValueError("Cannot get the size of the root.")
    if root.get_descendant(source_path) is None:
        raise ValueError("Path not found: {}".format(str(source_path)))

    parent_path = source_path.get_parent()
    size_path = parent_path.get_child(new_field_name)
    size_expr = SizeExpression(root.get_descendant_or_error(source_path),
                               root.get_descendant_or_error(parent_path))
    extended_root = expression_add.add_paths(root, {size_path: size_expr})
    return extended_root, size_path
예제 #5
0
def project(expr: expression.Expression,
            paths: Sequence[path.Path]) -> expression.Expression:
  """Selects a subtree of `expr`.

  Paths that are not selected are removed. Selected paths become "known":
  if calculate_prensors is called, they will be in the result.

  Args:
    expr: the original expression.
    paths: the paths to include; each must exist in `expr`.

  Returns:
    A projected expression.

  Raises:
    ValueError: if any of `paths` is missing from `expr`; the message lists
      every missing path.
  """
  missing = []
  for candidate in paths:
    if expr.get_descendant(candidate) is None:
      missing.append(candidate)
  if not missing:
    return _ProjectExpression(expr, paths)
  raise ValueError("{} Path(s) missing in project: {}".format(
      len(missing), ", ".join(map(str, missing))))