Ejemplo n.º 1
0
Archivo: misc.py Proyecto: rcownie/hail
def get_select_exprs(caller, exprs, named_exprs, indices, protect_keys=True):
    """Turn the arguments of a select-style call into ordered field assignments.

    Positional ``exprs`` may be strings, which are looked up on
    ``indices.source``, or any value accepted by ``to_expr``. Each positional
    expression must have an IR flagged ``is_nested_field`` and is stored under
    that IR's ``name``. Keyword ``named_exprs`` are converted with ``to_expr``
    and stored under their keyword.

    When ``protect_keys`` is true every output name is passed to
    ``check_keys``. Keyword names are additionally passed to
    ``check_collisions``, and the complete set of names goes through
    ``check_field_uniqueness``.

    Returns an ``OrderedDict`` of name -> expression, positional entries first.

    Raises ``ExpressionException`` when a positional expression is not a
    nested field.
    """
    from hail.expr.expressions import to_expr, ExpressionException, analyze

    positional = [
        indices.source[item] if isinstance(item, str) else to_expr(item)
        for item in exprs
    ]
    keyword = {name: to_expr(value) for name, value in named_exprs.items()}

    assignments = OrderedDict()
    for field in positional:
        # Anything other than a plain field reference has no implicit name,
        # so the caller has to supply one through a keyword argument.
        if not field._ir.is_nested_field:
            raise ExpressionException(
                "method '{}' expects keyword arguments for complex expressions"
                .format(caller))
        analyze(caller, field, indices, broadcast=False)
        field_name = field._ir.name
        if protect_keys:
            check_keys(field_name, indices)
        assignments[field_name] = field

    for name, value in keyword.items():
        if protect_keys:
            check_keys(name, indices)
        check_collisions(indices.source._fields, name, indices)
        assignments[name] = value

    check_field_uniqueness(assignments.keys())
    return assignments
Ejemplo n.º 2
0
def get_select_exprs(caller, exprs, named_exprs, indices, protect_keys=True):
    """Turn the arguments of a select-style call into ordered field assignments.

    Positional ``exprs`` may be strings, which are looked up on
    ``indices.source``, or any value accepted by ``to_expr``. Each positional
    expression must carry the same indices as ``indices`` and have an AST
    flagged ``is_nested_field``; it is stored under that AST's ``name``.
    Keyword ``named_exprs`` are converted with ``to_expr`` and stored under
    their keyword.

    When ``protect_keys`` is true every output name is passed to
    ``check_keys``. Keyword names are additionally passed to
    ``check_collisions``, and the complete set of names goes through
    ``check_field_uniqueness``.

    Returns an ``OrderedDict`` of name -> expression, positional entries first.

    Raises ``ExpressionException`` when a positional expression has different
    indices or is not a nested field.
    """
    # Import only what this function uses.
    from hail.expr.expressions import to_expr, ExpressionException

    exprs = [
        indices.source[e] if isinstance(e, str) else to_expr(e)
        for e in exprs
    ]
    named_exprs = {k: to_expr(v) for k, v in named_exprs.items()}
    assignments = OrderedDict()

    for e in exprs:
        if not e._indices == indices:
            raise ExpressionException(
                "method '{}' parameter 'exprs' expects {}-indexed fields,"
                " found indices {}".format(caller, list(indices.axes),
                                           list(e._indices.axes)))
        # Anything other than a plain field reference has no implicit name,
        # so the caller has to supply one through a keyword argument.
        if not e._ast.is_nested_field:
            raise ExpressionException(
                "method '{}' expects keyword arguments for complex expressions"
                .format(caller))
        name = e._ast.name
        if protect_keys:
            check_keys(name, indices)
        assignments[name] = e

    for k, e in named_exprs.items():
        if protect_keys:
            check_keys(k, indices)
        check_collisions(indices.source._fields, k, indices)
        assignments[k] = e

    check_field_uniqueness(assignments.keys())
    return assignments
Ejemplo n.º 3
0
Archivo: misc.py Proyecto: troels/hail
def get_key_by_exprs(caller, exprs, named_exprs, indices, override_protected_indices=None):
    """Resolve the arguments of a key_by-style call.

    Positional ``exprs`` may be strings, which are looked up on
    ``indices.source``, or expressions, which are used as given. Each must
    pass ``analyze`` and have an IR flagged ``is_nested_field``; its IR
    ``name`` becomes a key field. Keyword ``named_exprs`` are converted with
    ``to_expr`` and their keywords are appended to the key after the
    positional names.

    Returns a ``(final_key, bindings)`` tuple. ``final_key`` is the list of
    key field names in order. ``bindings`` maps name -> expression for every
    keyword argument and for every positional expression that is not found in
    ``indices.source._fields_inverse``.

    Raises ``ExpressionException`` when a positional expression is not a
    nested field.
    """
    from hail.expr.expressions import to_expr, ExpressionException, analyze
    exprs = [indices.source[e] if isinstance(e, str) else e for e in exprs]
    named_exprs = {k: to_expr(v) for k, v in named_exprs.items()}

    final_key = []
    bindings = {}
    for e in exprs:
        analyze(caller, e, indices, broadcast=False)
        if not e._ir.is_nested_field:
            raise ExpressionException(f"{caller!r} expects keyword arguments for complex expressions\n"
                                      f"  Correct:   ht = ht.key_by('x')\n"
                                      f"  Correct:   ht = ht.key_by(ht.x)\n"
                                      f"  Correct:   ht = ht.key_by(x = ht.x.replace(' ', '_'))\n"
                                      f"  INCORRECT: ht = ht.key_by(ht.x.replace(' ', '_'))")

        name = e._ir.name
        final_key.append(name)

        # Expressions already present among the source's fields need no
        # binding; everything else is bound under its name.
        if e not in indices.source._fields_inverse:
            bindings[name] = e

    final_key.extend(named_exprs)
    bindings.update(named_exprs)
    check_collisions(caller, final_key, indices, override_protected_indices=override_protected_indices)
    return final_key, bindings
Ejemplo n.º 4
0
 def coerce(self, x) -> Expression:
     """Convert ``x`` with ``to_expr`` and coerce it to this object's type.

     The expression is returned unchanged when ``_requires_conversion``
     reports that its dtype needs no conversion; otherwise the result of
     ``_coerce`` is returned.

     Raises ``ExpressionException`` if ``can_coerce`` rejects the dtype.
     """
     x = to_expr(x)
     if self.can_coerce(x.dtype):
         # Skip ``_coerce`` when ``_requires_conversion`` says none is needed.
         return self._coerce(x) if self._requires_conversion(x.dtype) else x
     raise ExpressionException(f"cannot coerce type '{x.dtype}' to type '{self.str_t}'")
Ejemplo n.º 5
0
 def coerce(self, x) -> Expression:
     """Return ``x`` as an expression of this object's type.

     ``x`` is first converted with ``to_expr``. If ``can_coerce`` rejects the
     resulting dtype an ``ExpressionException`` is raised. Otherwise the
     expression is passed through ``_coerce`` only when
     ``_requires_conversion`` says a conversion is needed.
     """
     x = to_expr(x)

     # Guard: reject dtypes this object cannot handle.
     if not self.can_coerce(x.dtype):
         raise ExpressionException(f"cannot coerce type '{x.dtype}' to type '{self.str_t}'")

     # Guard: nothing to convert, hand the expression back as is.
     if not self._requires_conversion(x.dtype):
         return x

     return self._coerce(x)
Ejemplo n.º 6
0
def get_annotate_exprs(caller, named_exprs, indices):
    """Convert and validate the keyword arguments of an annotate-style call.

    Every value in ``named_exprs`` is converted with ``to_expr``. Each name is
    then passed to ``check_keys`` and ``check_collisions``, and rejected if it
    appears in ``indices.key``.

    Returns a dict of name -> converted expression.

    Raises ``ExpressionException`` when a name would overwrite a key field.
    """
    from hail.expr.expressions import to_expr, ExpressionException

    converted = {name: to_expr(value) for name, value in named_exprs.items()}
    for name in converted:
        check_keys(name, indices)
        key = indices.key
        if key and name in key.keys():
            raise ExpressionException("'{}' cannot overwrite key field: {}"
                                      .format(caller, repr(name)))
        check_collisions(indices.source._fields, name, indices)
    return converted
Ejemplo n.º 7
0
def get_annotate_exprs(caller, named_exprs, indices):
    """Prepare the keyword arguments of an annotate-style call.

    All values are converted with ``to_expr`` first. Then, for each field
    name in turn: ``check_keys`` is called, the name is rejected if
    ``indices.key`` is truthy and contains it, and ``check_collisions`` is
    called against ``indices.source._fields``.

    Returns the dict of converted expressions keyed by field name.

    Raises ``ExpressionException`` if a name matches a key field.
    """
    from hail.expr.expressions import to_expr, ExpressionException

    annotations = {}
    for field, value in named_exprs.items():
        annotations[field] = to_expr(value)

    for field in annotations:
        check_keys(field, indices)
        if indices.key:
            if field in indices.key.keys():
                message = "'{}' cannot overwrite key field: {}".format(caller, repr(field))
                raise ExpressionException(message)
        check_collisions(indices.source._fields, field, indices)

    return annotations
Ejemplo n.º 8
0
def get_select_exprs(caller, exprs, named_exprs, indices, protect_keys=True):
    """Build the ordered name -> expression mapping for a select-style call.

    Positional ``exprs`` are resolved first: strings are looked up on
    ``indices.source`` and everything else goes through ``to_expr``. Each
    resolved expression must have an IR flagged ``is_nested_field`` and is
    stored under that IR's ``name`` after being passed to ``analyze``.
    Keyword ``named_exprs`` follow, converted with ``to_expr`` and checked
    with ``check_collisions``.

    ``check_keys`` is applied to every name when ``protect_keys`` is true, and
    ``check_field_uniqueness`` is applied to the final set of names.

    Returns an ``OrderedDict``.

    Raises ``ExpressionException`` for a positional expression that is not a
    nested field.
    """
    from hail.expr.expressions import to_expr, ExpressionException, analyze

    def as_expression(item):
        # A string names an entry of the source object.
        if isinstance(item, str):
            return indices.source[item]
        return to_expr(item)

    positional = list(map(as_expression, exprs))
    keyword = {}
    for key, value in named_exprs.items():
        keyword[key] = to_expr(value)

    selected = OrderedDict()

    for expression in positional:
        if not expression._ir.is_nested_field:
            raise ExpressionException("method '{}' expects keyword arguments for complex expressions".format(caller))
        analyze(caller, expression, indices, broadcast=False)
        field_name = expression._ir.name
        if protect_keys:
            check_keys(field_name, indices)
        selected[field_name] = expression

    for key, expression in keyword.items():
        if protect_keys:
            check_keys(key, indices)
        check_collisions(indices.source._fields, key, indices)
        selected[key] = expression

    check_field_uniqueness(selected.keys())
    return selected
Ejemplo n.º 9
0
def get_select_exprs(caller, exprs, named_exprs, indices, base_struct):
    """Apply the arguments of a select-style call to ``base_struct``.

    The result always starts with the fields in ``indices.protected_key``.
    Positional ``exprs`` (strings looked up on ``indices.source``, or
    expressions used as given) must have an IR flagged ``is_nested_field``;
    those found in ``indices.source._fields_inverse`` are selected by name,
    the others are inserted under their IR ``name``. Keyword ``named_exprs``
    are converted with ``to_expr`` and inserted under their keyword.

    Every requested name is passed to ``check_keys`` together with the set of
    protected key names, and the final field list goes through
    ``check_collisions``.

    Returns the value produced by selecting and annotating ``base_struct``;
    iterating it yields the final field names in order.

    Raises ``ExpressionException`` when a positional expression is not a
    nested field.
    """
    from hail.expr.expressions import to_expr, ExpressionException, analyze

    positional = [indices.source[item] if isinstance(item, str) else item for item in exprs]
    keyword = {name: to_expr(value) for name, value in named_exprs.items()}

    kept = indices.protected_key[:]      # existing fields selected by name
    protected = set(kept)
    ordered = kept[:]                    # every output field, in final order
    added = {}                           # name -> expression to insert

    for expression in positional:
        if not expression._ir.is_nested_field:
            raise ExpressionException(
                f"{caller!r} expects keyword arguments for complex expressions\n"
                f"  Correct:   ht = ht.select('x')\n"
                f"  Correct:   ht = ht.select(ht.x)\n"
                f"  Correct:   ht = ht.select(x = ht.x.replace(' ', '_'))\n"
                f"  INCORRECT: ht = ht.select(ht.x.replace(' ', '_'))")
        analyze(caller, expression, indices, broadcast=False)

        field_name = expression._ir.name
        check_keys(caller, field_name, protected)
        ordered.append(field_name)
        if expression in indices.source._fields_inverse:
            kept.append(field_name)
        else:
            added[field_name] = expression

    for name, expression in keyword.items():
        check_keys(caller, name, protected)
        ordered.append(name)
        added[name] = expression

    check_collisions(caller, ordered, indices)

    # A plain select-then-annotate already yields the requested order when all
    # kept fields come before all inserted ones; this avoids clogging the IR
    # with redundant field names.
    natural_order = ordered == kept + list(added)
    narrowed = base_struct.select(*kept)
    if natural_order:
        result = narrowed.annotate(**added)
    else:
        result = narrowed._annotate_ordered(added, ordered)

    assert list(result) == ordered
    return result
Ejemplo n.º 10
0
 def check(self, x: Any, caller: str, param: str) -> Any:
     """Coerce ``x`` for type checking.

     ``x`` is converted with ``to_expr`` and passed to ``self.coerce``; the
     coerced value is returned. ``caller`` and ``param`` are accepted but not
     used by this implementation.

     Raises ``TypecheckFailure`` (chained to the original error) when the
     conversion or coercion raises ``ExpressionException``.
     """
     try:
         coerced = self.coerce(to_expr(x))
     except ExpressionException as err:
         raise TypecheckFailure from err
     return coerced
Ejemplo n.º 11
0
 def check(self, x: Any, caller: str, param: str) -> Any:
     """Return ``x`` converted with ``to_expr`` and coerced by ``self.coerce``.

     An ``ExpressionException`` raised by either step is re-raised as a
     ``TypecheckFailure`` with the original exception as its cause. The
     ``caller`` and ``param`` arguments are not read here.
     """
     try:
         as_expression = to_expr(x)
         return self.coerce(as_expression)
     except ExpressionException as failure:
         raise TypecheckFailure from failure
Ejemplo n.º 12
0
def loop(f: Callable, typ, *args):
    r"""Define and call a tail-recursive function with given arguments.

    Notes
    -----
    The argument `f` must be a function where the first argument defines the
    recursive call, and the remaining arguments are the arguments to the
    recursive function, e.g. to define the recursive function

    .. math::

        f(x, y) = \begin{cases}
        y & \textrm{if } x \equiv 0 \\
        f(x - 1, y + x) & \textrm{otherwise}
        \end{cases}


    we would write:

    >>> f = lambda recur, x, y: hl.if_else(x == 0, y, recur(x - 1, y + x))

    Full recursion is not supported, and any non-tail-recursive methods will
    throw an error when called.

    This means that the result of any recursive call within the function must
    also be the result of the entire function, without modification. Let's
    consider two different recursive definitions for the triangle function
    :math:`f(x) = 0 + 1 + \dots + x`:

    >>> def triangle1(x):
    ...     if x == 1:
    ...         return x
    ...     return x + triangle1(x - 1)

    >>> def triangle2(x, total):
    ...     if x == 0:
    ...         return total
    ...     return triangle2(x - 1, total + x)

    The first function definition, `triangle1`, will call itself and then add x.
    This is an example of a non-tail recursive function, since `triangle1(9)`
    needs to modify the result of the inner recursive call to `triangle1(8)` by
    adding 9 to the result.

    The second function is tail recursive: the result of `triangle2(9, 0)` is
    the same as the result of the inner recursive call, `triangle2(8, 9)`.

    Example
    -------
    To find the sum of all the numbers from n=1...10:

    >>> triangle_f = lambda f, x, total: hl.if_else(x == 0, total, f(x - 1, total + x))
    >>> x = hl.experimental.loop(triangle_f, hl.tint32, 10, 0)
    >>> hl.eval(x)
    55

    Let's say we want to find the root of a polynomial equation:

    >>> def polynomial(x):
    ...     return 5 * x**3 - 2 * x - 1

    We'll use `Newton's method <https://en.wikipedia.org/wiki/Newton%27s_method>`_
    to find it, so we'll also define the derivative:

    >>> def derivative(x):
    ...     return 15 * x**2 - 2

    and starting at :math:`x_0 = 0`, we'll compute the next step :math:`x_{i+1} = x_i - \frac{f(x_i)}{f'(x_i)}`
    until the difference between :math:`x_{i}` and :math:`x_{i+1}` falls below
    our convergence threshold:

    >>> threshold = 0.005
    >>> def find_root(f, guess, error):
    ...     converged = hl.is_defined(error) & (error < threshold)
    ...     new_guess = guess - (polynomial(guess) / derivative(guess))
    ...     new_error = hl.abs(new_guess - guess)
    ...     return hl.if_else(converged, guess, f(new_guess, new_error))
    >>> x = hl.experimental.loop(find_root, hl.tfloat, 0.0, hl.missing(hl.tfloat))
    >>> hl.eval(x)
    0.8052291984599675

    Warning
    -------
    Using arguments of a type other than numeric types and booleans can cause
    memory issues if you expect the recursive call to happen many times.

    Parameters
    ----------
    f : function ( (marker, \*args) -> :class:`.Expression` )
        Function of one callable marker, denoting where the recursive call (or calls) is located,
        and many `args`, the loop variables.
    typ : :class:`str` or :class:`.HailType`
        Type the loop returns.
    args : variable-length args of :class:`.Expression`
        Expressions to initialize the loop values.

    Returns
    -------
    :class:`.Expression`
        Result of the loop with `args` as initial loop values.

    Raises
    ------
    TypeError
        If a recursive call has the wrong number or types of arguments, if the
        type of the loop body differs from `typ`, or if a recursive call
        appears outside of tail position.
    """

    loop_name = Env.get_uid()

    def contains_recursive_call(non_recursive):
        """Return True if this IR node or any descendant is a Recur of this loop."""
        if isinstance(non_recursive,
                      ir.Recur) and non_recursive.name == loop_name:
            return True
        # Generator rather than a list so the search stops at the first hit.
        return any(
            contains_recursive_call(c) for c in non_recursive.children)

    def check_tail_recursive(loop_ir):
        """Raise TypeError if a recursive call is found outside of tail position."""
        if isinstance(loop_ir, ir.If):
            if contains_recursive_call(loop_ir.cond):
                raise TypeError(
                    "branch condition can't contain recursive call!")
            # Both branches are themselves in tail position.
            check_tail_recursive(loop_ir.cnsq)
            check_tail_recursive(loop_ir.altr)
        elif isinstance(loop_ir, ir.Let):
            if contains_recursive_call(loop_ir.value):
                raise TypeError(
                    "bound value used in other expression can't contain recursive call!"
                )
            check_tail_recursive(loop_ir.body)
        elif isinstance(loop_ir, ir.TailLoop):
            # Only the inner loop's parameters are inspected here.
            if any(contains_recursive_call(x) for _, x in loop_ir.params):
                raise TypeError(
                    "parameters passed to inner loop can't contain recursive call!"
                )
        elif not isinstance(loop_ir,
                            ir.Recur) and contains_recursive_call(loop_ir):
            raise TypeError(
                "found recursive expression outside of tail position!")

    @typecheck(recur_exprs=expr_any)
    def make_loop(*recur_exprs):
        """Build the Recur expression handed to `f` as its recursion marker."""
        if len(recur_exprs) != len(args):
            raise TypeError(
                'Recursive call in loop has wrong number of arguments')
        # Report every mismatching argument at once, not just the first.
        mismatches = [
            f'\n  at argument index {i}, loop arg type: {expr.dtype}, '
            f'recur arg type: {rexpr.dtype}'
            for i, (rexpr, expr) in enumerate(zip(recur_exprs, args))
            if rexpr.dtype != expr.dtype
        ]
        if mismatches:
            raise TypeError('Type error in recursive call,' + ''.join(mismatches))
        irs = [expr._ir for expr in recur_exprs]
        indices, aggregations = unify_all(*recur_exprs)
        return construct_expr(ir.Recur(loop_name, irs, typ), typ, indices,
                              aggregations)

    uid_irs = []
    loop_vars = []

    # NOTE(review): the docstring examples pass plain Python numbers, while
    # the code below reads expression attributes from each argument. This
    # presumably relies on a typecheck decorator, not visible here, coercing
    # `args` to expressions. Confirm.
    for expr in args:
        # Each initial value gets a fresh variable that `f` sees as its loop
        # argument; the (uid, ir) pair seeds the TailLoop parameters.
        uid = Env.get_uid()
        loop_vars.append(
            construct_variable(uid, expr._type, expr._indices,
                               expr._aggregations))
        uid_irs.append((uid, expr._ir))

    loop_f = to_expr(f(make_loop, *loop_vars))
    if loop_f.dtype != typ:
        raise TypeError(
            f"requested type {typ} does not match inferred type {loop_f.dtype}"
        )
    check_tail_recursive(loop_f._ir)
    indices, aggregations = unify_all(*args, loop_f)

    return construct_expr(ir.TailLoop(loop_name, loop_f._ir, uid_irs),
                          loop_f.dtype, indices, aggregations)