Пример #1
0
    def _bin_lambda_method(self, name, f, input_type, ret_type_f, *args):
        """Build a ``LambdaClassMethod`` AST node named ``name``.

        ``f`` is called with a fresh variable reference of ``input_type`` that
        shares this expression's indices and aggregations. The lambda body it
        returns, this expression's AST and the ASTs of ``args`` form the node.
        The result type is ``ret_type_f`` applied to the body's type.
        """
        extra_exprs = (to_expr(arg) for arg in args)
        var_name = Env.get_uid()
        var_ref = expressions.construct_expr(VariableReference(var_name),
                                             input_type, self._indices,
                                             self._aggregations)
        body = to_expr(f(var_ref))

        # NOTE(review): only ``self`` and the lambda body are unified; the
        # indices/aggregations of ``args`` are not taken into account.
        indices, aggregations = unify_all(self, body)
        node = LambdaClassMethod(name, var_name, self._ast, body._ast,
                                 *(extra._ast for extra in extra_exprs))
        result_type = ret_type_f(body._type)
        return expressions.construct_expr(node, result_type, indices,
                                          aggregations)
Пример #2
0
Файл: nd.py Проект: saponas/hail
def svd(nd, full_matrices=True, compute_uv=True):
    """Performs a singular value decomposition.

    :param nd: :class:`.NDArrayExpression`
        A 2 dimensional ndarray, shape(M, N).
    :param full_matrices: `bool`
        If True (default), u and vt have dimensions (M, M) and (N, N) respectively. Otherwise, they have dimensions
        (M, K) and (K, N), where K = min(M, N)
    :param compute_uv: `bool`
        If True (default), compute the singular vectors u and v. Otherwise, only return a single ndarray, s.

    Returns
    -------
    If `compute_uv` is True, a tuple ``(u, s, vt)`` of float64 ndarrays; otherwise only ``s``.

    - u: :class:`.NDArrayExpression`
        The left singular vectors (2 dimensional).
    - s: :class:`.NDArrayExpression`
        The singular values (1 dimensional).
    - vt: :class:`.NDArrayExpression`
        The right singular vectors (2 dimensional).
    """
    float_nd = nd.map(lambda x: hl.float64(x))
    ir = NDArraySVD(float_nd._ir, full_matrices, compute_uv)

    if compute_uv:
        return_type = ttuple(tndarray(tfloat64, 2), tndarray(tfloat64, 1),
                             tndarray(tfloat64, 2))
    else:
        return_type = tndarray(tfloat64, 1)
    # Carry the input's indices/aggregations over to the result (as `qr` does);
    # otherwise an SVD of a row/entry-indexed ndarray would be treated as a
    # global-scope value.
    return construct_expr(ir, return_type, nd._indices, nd._aggregations)
Пример #3
0
 def _persist(self):
     src = self._indices.source
     if src is not None:
         raise ValueError(
             "Can only persist a scalar (no Table/MatrixTable source)")
     executed_jir = Env.backend().persist_ir(self._ir)
     return expressions.construct_expr(executed_jir, self.dtype)
Пример #4
0
def solve_triangular(nd_coef, nd_dep, lower=False):
    """Solve a triangular linear system.

    Parameters
    ----------
    nd_coef : :class:`.NDArrayNumericExpression`, (N, N)
        Triangular coefficient matrix.
    nd_dep : :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Dependent variables.
    lower : `bool` or :class:`.BooleanExpression`
        If true, nd_coef is interpreted as a lower triangular matrix.
        If false (default), nd_coef is interpreted as an upper triangular matrix.

    Returns
    -------
    :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Solution to the triangular system Ax = B. Shape is same as shape of B.

    """
    from hail.expr.expressions import unify_all

    # A plain Python bool (including the default ``False``) has no ``_ir``;
    # coerce it to a boolean expression. Boolean expressions pass through.
    lower = hl.bool(lower)

    nd_dep_ndim_orig = nd_dep.ndim
    nd_coef, nd_dep = solve_helper(nd_coef, nd_dep, nd_dep_ndim_orig)
    return_type = hl.tndarray(hl.tfloat64, 2)
    ir = Apply("linear_triangular_solve", return_type, nd_coef._ir, nd_dep._ir,
               lower._ir)
    # The result depends on every input, so combine all their indices and
    # aggregations rather than taking them from ``nd_coef`` alone.
    indices, aggregations = unify_all(nd_coef, nd_dep, lower)
    result = construct_expr(ir, return_type, indices, aggregations)
    if nd_dep_ndim_orig == 1:
        # A vector right-hand side comes back as a 1-column matrix; flatten it.
        result = result.reshape((-1))
    return result
Пример #5
0
def define_function(f, *param_types, _name=None):
    """Trace ``f`` into an IR function and register it with the JVM backend.

    ``f`` is called once with one ``Ref`` expression per entry of
    ``param_types``. The traced body is rendered with ``CSERenderer``, parsed,
    and registered under ``_name`` (a fresh uid when ``_name`` is None). The
    function is also recorded via ``register_session_function``.

    Returns
    -------
    Function
        A wrapper whose call emits an ``Apply`` of the registered function.
    """
    if _name is None:
        mname = Env.get_uid()
    else:
        mname = _name
    param_names = [Env.get_uid() for _ in param_types]
    placeholders = [construct_expr(Ref(pname), ptype)
                    for pname, ptype in zip(param_names, param_types)]
    body = f(*placeholders)
    ret_type = body.dtype

    renderer = CSERenderer(stop_at_jir=True)
    rendered = renderer(body._ir)
    jbody = body._ir.parse(rendered,
                           ref_map=dict(zip(param_names, param_types)),
                           ir_map=renderer.jirs)

    param_type_strings = [ptype._parsable_string() for ptype in param_types]
    Env.hail().expr.ir.functions.IRFunctionRegistry.pyRegisterIR(
        mname, param_names, param_type_strings,
        ret_type._parsable_string(), jbody)
    register_session_function(mname, param_types, ret_type)

    @typecheck(args=expr_any)
    def f(*args):
        indices, aggregations = unify_all(*args)
        call = Apply(mname, ret_type, *(a._ir for a in args))
        return construct_expr(call, ret_type, indices, aggregations)

    return Function(f, param_types, ret_type, mname)
Пример #6
0
 def transform_ir(agg, continuation):
     """Run ``continuation`` on ``agg`` only when ``pred`` holds for it.

     ``agg`` is bound to a fresh name. ``pred`` (closed over, with ``uid`` as
     its variable) is evaluated against that binding. The true branch emits
     the continuation; the false branch emits an empty ``Begin``.
     """
     bound_name = Env.get_uid()
     bound_ref = expressions.construct_variable(
         bound_name, self._type, agg._indices, agg._aggregations)
     source_ir = agg._ir
     guard = Let(uid, bound_ref._ir, pred._ir)
     on_true = continuation(bound_ref)._ir
     branch = If(guard, on_true, Begin([]))
     wrapped = Let(bound_name, source_ir, branch)
     return expressions.construct_expr(wrapped, self._type, self._indices,
                                       self._aggregations)
Пример #7
0
    def _ir_lambda_method(self, irf, f, input_type, ret_type_f, *args):
        """Build an IR lambda-style node via ``irf``.

        ``f`` receives a fresh variable of ``input_type`` that shares this
        expression's indices and aggregations. ``irf`` is then called as
        ``irf(self._ir, var_name, body_ir, *extra_irs)``. The result type is
        ``ret_type_f`` applied to the lambda body's type.
        """
        extra_irs = (to_expr(arg)._ir for arg in args)
        var_name = Env.get_uid()
        var = expressions.construct_variable(var_name, input_type,
                                             self._indices, self._aggregations)
        body = to_expr(f(var))

        indices, aggregations = unify_all(self, body)
        node = irf(self._ir, var_name, body._ir, *extra_irs)
        result_type = ret_type_f(body._type)
        return expressions.construct_expr(node, result_type, indices,
                                          aggregations)
Пример #8
0
 def _map(self, f):
     """Return an ``Aggregable`` whose elements are ``f`` applied to each element.

     ``f`` is traced once on a fresh variable reference typed like this
     aggregable. The result is a ``'map'`` ``LambdaClassMethod`` node.
     """
     var_name = Env.get_uid()
     element = expressions.construct_expr(VariableReference(var_name),
                                          self._type, self._indices,
                                          self._aggregations)
     result = f(element)
     indices, aggregations = unify_all(element, result)
     node = LambdaClassMethod('map', var_name, self._ast, result._ast)
     return expressions.Aggregable(node, result.dtype, indices, aggregations)
Пример #9
0
    def _ir_lambda_method(self, irf, f, input_type, ret_type_f, *args):
        """Apply the IR constructor ``irf`` to a lambda traced from ``f``.

        The lambda variable has ``input_type`` and inherits this expression's
        indices and aggregations. ``args`` are converted to IR and appended
        to the ``irf`` call.
        """
        trailing = (to_expr(arg)._ir for arg in args)
        lambda_name = Env.get_uid()
        lambda_var = expressions.construct_variable(
            lambda_name, input_type, self._indices, self._aggregations)
        lambda_body = to_expr(f(lambda_var))

        indices, aggregations = unify_all(self, lambda_body)
        built = irf(self._ir, lambda_name, lambda_body._ir, *trailing)
        out_type = ret_type_f(lambda_body._type)
        return expressions.construct_expr(built, out_type, indices,
                                          aggregations)
Пример #10
0
        def transform_ir(agg, continuation):
            """Loop over ``res_expr`` evaluated with ``uid`` bound to ``agg``.

            Each element is exposed to ``continuation`` through a fresh
            variable of ``res_expr``'s element type. The loop is emitted
            as an ``ArrayFor``.
            """
            elt_name = Env.get_uid()
            elt_type = res_expr.dtype.element_type
            elt = expressions.construct_variable(
                elt_name, elt_type, res_expr._indices, res_expr._aggregations)

            source = ToArray(Let(uid, agg._ir, res_expr._ir))
            loop = ArrayFor(source, elt_name, continuation(elt)._ir)
            return expressions.construct_expr(loop, elt_type, indices,
                                              aggregations)
Пример #11
0
Файл: nd.py Проект: saponas/hail
def qr(nd, mode="reduced"):
    """Performs a QR decomposition.

    :param nd: A 2 dimensional ndarray, shape(M, N)
    :param mode: One of "reduced", "complete", "r", or "raw".

        If K = min(M, N), then:

        - `reduced`: returns q and r with dimensions (M, K), (K, N)
        - `complete`: returns q and r with dimensions (M, M), (M, N)
        - `r`: returns only r with dimensions (K, N)
        - `raw`: returns h, tau with dimensions (N, M), (K,)

    Returns
    -------
    - q: ndarray of float64
        A matrix with orthonormal columns.
    - r: ndarray of float64
        The upper-triangular matrix R.
    - (h, tau): ndarrays of float64
        The array h contains the Householder reflectors that generate q along with r.
        The tau array contains scaling factors for the reflectors

    Raises
    ------
    ValueError
        If `mode` is not one of the supported modes.
    """

    assert nd.ndim == 2, "QR decomposition requires 2 dimensional ndarray"

    if mode not in ("reduced", "r", "raw", "complete"):
        raise ValueError(f"Unrecognized mode '{mode}' for QR decomposition")

    float_nd = nd.map(lambda x: hl.float64(x))
    ir = NDArrayQR(float_nd._ir, mode)

    matrix_type = tndarray(tfloat64, 2)
    if mode == "r":
        result_type = matrix_type
    elif mode == "raw":
        result_type = ttuple(matrix_type, tndarray(tfloat64, 1))
    else:
        # "reduced" and "complete" both yield the pair (q, r).
        result_type = ttuple(matrix_type, matrix_type)
    return construct_expr(ir, result_type, nd._indices, nd._aggregations)
Пример #12
0
 def _bin_op(self, name, other, ret_type):
     """Combine ``self`` and ``other`` with the binary operator ``name``.

     Arithmetic with a primitive numeric result becomes ``ApplyBinaryOp``.
     Comparisons become ``ApplyComparisonOp``. Everything else is a named
     ``Apply``.
     """
     rhs = to_expr(other)
     indices, aggregations = unify_all(self, rhs)
     if name in {'+', '-', '*', '/', '//'} and ret_type in {tint32, tint64, tfloat32, tfloat64}:
         node = ApplyBinaryOp(name, self._ir, rhs._ir)
     elif name in {"==", "!=", "<", "<=", ">", ">="}:
         node = ApplyComparisonOp(name, self._ir, rhs._ir)
     else:
         node = Apply(name, self._ir, rhs._ir)
     return expressions.construct_expr(node, ret_type, indices, aggregations)
Пример #13
0
 def _bin_op(self, name, other, ret_type):
     """Build the IR for ``self <name> other`` with result type ``ret_type``."""
     other_expr = to_expr(other)
     indices, aggregations = unify_all(self, other_expr)
     left, right = self._ir, other_expr._ir
     is_prim_arith = (name in {'+', '-', '*', '/', '//'}
                      and ret_type in {tint32, tint64, tfloat32, tfloat64})
     if is_prim_arith:
         op = ApplyBinaryOp(name, left, right)
     elif name in {"==", "!=", "<", "<=", ">", ">="}:
         op = ApplyComparisonOp(name, left, right)
     else:
         # Fall back to a named function for non-primitive operations.
         op = Apply(name, left, right)
     return expressions.construct_expr(op, ret_type, indices, aggregations)
Пример #14
0
 def make_loop(*recur_exprs):
     """Emit a ``Recur`` node for a recursive call of the enclosing loop.

     The call must pass exactly one expression per loop argument (``args``),
     each with the same dtype as the corresponding loop argument.

     Raises
     ------
     TypeError
         On an arity mismatch, or listing every argument whose type differs.
     """
     if len(recur_exprs) != len(args):
         raise TypeError('Recursive call in loop has wrong number of arguments')
     mismatches = [
         f'\n  at argument index {i}, loop arg type: {expr.dtype}, '
         f'recur arg type: {rexpr.dtype}'
         for i, (rexpr, expr) in enumerate(zip(recur_exprs, args))
         if rexpr.dtype != expr.dtype
     ]
     if mismatches:
         raise TypeError('Type error in recursive call,' + ''.join(mismatches))
     recur_irs = [rexpr._ir for rexpr in recur_exprs]
     indices, aggregations = unify_all(*recur_exprs)
     return construct_expr(ir.Recur(loop_name, recur_irs, typ), typ, indices, aggregations)
Пример #15
0
Файл: nd.py Проект: saponas/hail
def inv(nd):
    """Performs a matrix inversion.

    :param nd: A 2 dimensional, square ndarray, shape (N, N)

    Returns
    -------
    - a: ndarray of float64
        The inverted matrix, shape (N, N)
    """

    assert nd.ndim == 2, "Matrix inversion requires 2 dimensional ndarray"

    float_nd = nd.map(lambda x: hl.float64(x))
    ir = NDArrayInv(float_nd._ir)
    # Keep the input's indices/aggregations so the result stays bound to the
    # same Table/MatrixTable scope as ``nd`` (consistent with `qr`).
    return construct_expr(ir, tndarray(tfloat64, 2), nd._indices,
                          nd._aggregations)
Пример #16
0
 def _ir_lambda_method2(self, other, irf, f, input_type1, input_type2,
                        ret_type_f, *args):
     """Build a two-variable IR lambda node via ``irf``.

     ``f`` receives two fresh variables. The first has ``input_type1`` and
     takes ``self``'s indices/aggregations; the second has ``input_type2``
     and takes ``other``'s. ``irf`` is called as
     ``irf(self._ir, other._ir, name1, name2, body_ir, *extra_irs)``.
     """
     extra_irs = (to_expr(arg)._ir for arg in args)
     left_name = Env.get_uid()
     right_name = Env.get_uid()
     left_var = expressions.construct_variable(left_name, input_type1,
                                               self._indices,
                                               self._aggregations)
     right_var = expressions.construct_variable(right_name, input_type2,
                                                other._indices,
                                                other._aggregations)
     body = to_expr(f(left_var, right_var))
     indices, aggregations = unify_all(self, other, body)
     node = irf(self._ir, other._ir, left_name, right_name, body._ir,
                *extra_irs)
     result_type = ret_type_f(body._type)
     return expressions.construct_expr(node, result_type, indices,
                                       aggregations)
Пример #17
0
def define_function(f, *param_types, _name=None, type_args=()):
    """Trace ``f`` into an IR function and register it with the backend.

    One ``Ref`` placeholder is created per entry of ``param_types``, with
    names derived from the function name. ``f`` is traced once, and the body
    is handed to ``Env.backend().register_ir_function``.

    Returns
    -------
    Function
        A wrapper whose call emits ``Apply(mname, ..., type_args=type_args)``.
    """
    mname = Env.get_uid() if _name is None else _name
    param_names = [Env.get_uid(mname) for _ in param_types]
    placeholders = [construct_expr(Ref(pname), ptype)
                    for pname, ptype in zip(param_names, param_types)]
    body = f(*placeholders)
    ret_type = body.dtype

    backend = Env.backend()
    backend.register_ir_function(mname, type_args, param_names,
                                 param_types, ret_type, body)

    @typecheck(args=expr_any)
    def f(*args):
        indices, aggregations = unify_all(*args)
        call = Apply(mname, ret_type, *(a._ir for a in args),
                     type_args=type_args)
        return construct_expr(call, ret_type, indices, aggregations)

    return Function(f, param_types, ret_type, mname, type_args)
Пример #18
0
    def or_error(self, message):
        """Finish the case statement by throwing an error with the given message.

        Notes
        -----
        If no condition from a :meth:`.CaseBuilder.when` call is ``True``, then
        an error is thrown.

        Parameters
        ----------
        message : :class:`.Expression` of type :data:`tstr`

        Returns
        -------
        :class:`.Expression`
        """
        if not self._cases:
            raise ExpressionException("'or_error' cannot be called without at least one 'when' call")
        die_ir = ir.Die(message._ir, self._ret_type)
        fallback = construct_expr(die_ir, self._ret_type)
        return self._finish(fallback)
Пример #19
0
 def _bin_op(self, name, other, ret_type):
     """Build ``self <name> other`` as an IR node of type ``ret_type``.

     Arithmetic with a primitive numeric result uses ``ApplyBinaryPrimOp``.
     Comparisons use ``ApplyComparisonOp``. Any other operator becomes a
     named ``Apply``; operator symbols are translated to the registered
     function names below, and unknown names pass through unchanged.
     """
     other = to_expr(other)
     indices, aggregations = unify_all(self, other)
     lhs, rhs = self._ir, other._ir
     if name in {'+', '-', '*', '/', '//'} and ret_type in {tint32, tint64, tfloat32, tfloat64}:
         op = ir.ApplyBinaryPrimOp(name, lhs, rhs)
     elif name in {"==", "!=", "<", "<=", ">", ">="}:
         op = ir.ApplyComparisonOp(name, lhs, rhs)
     else:
         function_names = {
             '+': 'add',
             '-': 'sub',
             '*': 'mul',
             '/': 'div',
             '//': 'floordiv',
             '%': 'mod',
             '**': 'pow',
         }
         op = ir.Apply(function_names.get(name, name), ret_type, lhs, rhs)
     return expressions.construct_expr(op, ret_type, indices, aggregations)
Пример #20
0
    def _slice(self, ret_type, start=None, stop=None, step=None):
        if start is not None:
            start = to_expr(start)
            start_ast = start._ast
        else:
            start_ast = None
        if stop is not None:
            stop = to_expr(stop)
            stop_ast = stop._ast
        else:
            stop_ast = None
        if step is not None:
            raise NotImplementedError(
                'Variable slice step size is not currently supported')

        non_null = [x for x in [start, stop] if x is not None]
        indices, aggregations = unify_all(self, *non_null)
        return expressions.construct_expr(
            Index(self._ast, Slice(start_ast, stop_ast)), ret_type, indices,
            aggregations)
Пример #21
0
def define_function(f, *param_types):
    """Trace ``f`` into an IR function and register it on the JVM side.

    ``f`` is called once with one ``Ref`` expression per entry of
    ``param_types``. The body is rendered with ``Renderer``, parsed,
    registered under a fresh uid, and recorded with ``register_function``.

    Returns
    -------
    Function
        A wrapper whose call emits an ``Apply`` of the registered function.
    """
    mname = Env.get_uid()
    param_names = [Env.get_uid() for _ in param_types]
    typed_params = list(zip(param_names, param_types))
    body = f(*[construct_expr(Ref(pname), ptype) for pname, ptype in typed_params])
    ret_type = body.dtype

    renderer = Renderer(stop_at_jir=True)
    rendered = renderer(body._ir)
    jbody = body._ir.parse(rendered, ref_map=dict(typed_params), ir_map=renderer.jirs)

    type_strings = [ptype._parsable_string() for ptype in param_types]
    Env.hail().expr.ir.functions.IRFunctionRegistry.pyRegisterIR(
        mname, param_names, type_strings, ret_type._parsable_string(), jbody)
    register_function(mname, param_types, ret_type)

    @typecheck(args=expr_any)
    def f(*args):
        indices, aggregations = unify_all(*args)
        call = Apply(mname, *(a._ir for a in args))
        return construct_expr(call, ret_type, indices, aggregations)

    return Function(f, mname)
Пример #22
0
Файл: nd.py Проект: saponas/hail
def concatenate(nds, axis=0):
    """Join a sequence of arrays along an existing axis.

    Examples
    --------

    >>> x = hl.nd.array([[1., 2.], [3., 4.]])
    >>> y = hl.nd.array([[5.], [6.]])
    >>> hl.eval(hl.nd.concatenate([x, y], axis=1))
    array([[1., 2., 5.],
           [3., 4., 6.]])
    >>> x = hl.nd.array([1., 2.])
    >>> y = hl.nd.array([3., 4.])
    >>> hl.eval(hl.nd.concatenate((x, y), axis=0))
    array([1., 2., 3., 4.])

    Parameters
    ----------
    :param nds: a1, a2, …sequence of array_like
        The arrays must have the same shape, except in the dimension corresponding to axis (the first, by default).
        Note: unlike Numpy, the numerical element type of each array_like must match.
    :param axis: int, optional
        The axis along which the arrays will be joined. Default is 0.
        Note: unlike Numpy, if provided, axis cannot be None.

    Returns
    -------
    - res: ndarray
        The concatenated array

    Raises
    ------
    ValueError
        If `nds` is an empty list/tuple, or if its arrays do not all have the
        same number of dimensions.
    """
    is_sequence = isinstance(nds, (list, tuple))
    if is_sequence and len(nds) == 0:
        raise ValueError("concatenate requires at least one ndarray")

    head_nd = nds[0]
    head_ndim = head_nd.ndim
    # ndim is part of each array's static type, so check it eagerly in Python.
    # Building an `hl.case().when(...).or_error(...)` expression here would do
    # nothing unless that expression were actually used in the result.
    if is_sequence and any(arr.ndim != head_ndim for arr in nds):
        raise ValueError("Mismatched ndim")

    makearr = aarray(nds)
    concat_ir = NDArrayConcat(makearr._ir, axis)

    # Keep the indices/aggregations of the input arrays so a concatenation of
    # row/entry-indexed ndarrays stays in the right scope.
    return construct_expr(concat_ir,
                          tndarray(head_nd._type.element_type, head_ndim),
                          makearr._indices, makearr._aggregations)
Пример #23
0
    def or_error(self, message):
        """Finish the case statement by throwing an error with the given message.

        Notes
        -----
        If no condition from a :meth:`.CaseBuilder.when` call is ``True``, then
        an error is thrown.

        Parameters
        ----------
        message : :class:`.Expression` of type :obj:`.tstr`

        Returns
        -------
        :class:`.Expression`
        """
        if not self._cases:
            raise ExpressionException(
                "'or_error' cannot be called without at least one 'when' call")
        ret_type = self._ret_type
        error_expr = construct_expr(ir.Die(message._ir, ret_type), ret_type)
        return self._finish(error_expr)
Пример #24
0
Файл: nd.py Проект: saponas/hail
def solve(a, b):
    """Solve a linear system.

    Parameters
    ----------
    a : :class:`.NDArrayNumericExpression`, (N, N)
        Coefficient matrix.
    b : :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Dependent variables.

    Returns
    -------
    :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Solution to the system Ax = B. Shape is same as shape of B.

    """
    assert a.ndim == 2
    assert b.ndim == 1 or b.ndim == 2

    b_ndim_orig = b.ndim

    # A vector right-hand side is solved as a single-column matrix and
    # flattened back at the end.
    if b_ndim_orig == 1:
        b = b.reshape((-1, 1))

    if a.dtype.element_type != hl.tfloat64:
        a = a.map(lambda e: hl.float64(e))
    if b.dtype.element_type != hl.tfloat64:
        b = b.map(lambda e: hl.float64(e))

    from hail.expr.expressions import unify_all

    return_type = hl.tndarray(hl.tfloat64, 2)
    ir = Apply("linear_solve", return_type, a._ir, b._ir)
    # The solution depends on both operands, so unify their indices and
    # aggregations instead of taking them from ``a`` alone.
    indices, aggregations = unify_all(a, b)
    result = construct_expr(ir, return_type, indices, aggregations)

    if b_ndim_orig == 1:
        result = result.reshape((-1))
    return result
Пример #25
0
def solve(a, b, no_crash=False):
    """Solve a linear system.

    Parameters
    ----------
    a : :class:`.NDArrayNumericExpression`, (N, N)
        Coefficient matrix.
    b : :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Dependent variables.
    no_crash : `bool`
        If True, call ``linear_solve_no_crash`` and return a struct with fields
        ``solution`` and ``failed`` instead of the bare solution.

    Returns
    -------
    :class:`.NDArrayNumericExpression`, (N,) or (N, K)
        Solution to the system Ax = B. Shape is same as shape of B.
        With ``no_crash=True``, a :class:`.StructExpression` with fields
        ``solution`` (as above) and ``failed`` (`bool`).

    """
    from hail.expr.expressions import unify_all

    b_ndim_orig = b.ndim
    a, b = solve_helper(a, b, b_ndim_orig)
    if no_crash:
        name = "linear_solve_no_crash"
        return_type = hl.tstruct(solution=hl.tndarray(hl.tfloat64, 2),
                                 failed=hl.tbool)
    else:
        name = "linear_solve"
        return_type = hl.tndarray(hl.tfloat64, 2)

    ir = Apply(name, return_type, a._ir, b._ir)
    # The solution depends on both operands, so unify their indices and
    # aggregations instead of taking them from ``a`` alone.
    indices, aggregations = unify_all(a, b)
    result = construct_expr(ir, return_type, indices, aggregations)

    if b_ndim_orig == 1:
        # A vector right-hand side comes back as a single column; flatten it.
        if no_crash:
            result = hl.struct(solution=result.solution.reshape((-1)),
                               failed=result.failed)
        else:
            result = result.reshape((-1))
    return result
Пример #26
0
 def _method(self, name, ret_type, *args):
     """Emit ``Apply(name, self, *args)`` with result type ``ret_type``."""
     arg_exprs = tuple(map(to_expr, args))
     indices, aggregations = unify_all(self, *arg_exprs)
     call_irs = [self._ir] + [arg._ir for arg in arg_exprs]
     return expressions.construct_expr(Apply(name, *call_irs), ret_type,
                                       indices, aggregations)
Пример #27
0
def _to_expr(e, dtype):
    if e is None:
        return None
    elif isinstance(e, Expression):
        if e.dtype != dtype:
            assert is_numeric(dtype), 'expected {}, got {}'.format(
                dtype, e.dtype)
            if dtype == tfloat64:
                return hl.float64(e)
            elif dtype == tfloat32:
                return hl.float32(e)
            elif dtype == tint64:
                return hl.int64(e)
            else:
                assert dtype == tint32
                return hl.int32(e)
        return e
    elif not is_compound(dtype):
        # these are not container types and cannot contain expressions if we got here
        return e
    elif isinstance(dtype, tstruct):
        new_fields = []
        found_expr = False
        for f, t in dtype.items():
            value = _to_expr(e[f], t)
            found_expr = found_expr or isinstance(value, Expression)
            new_fields.append(value)

        if not found_expr:
            return e
        else:
            exprs = [
                new_fields[i] if isinstance(new_fields[i], Expression) else
                hl.literal(new_fields[i], dtype[i])
                for i in range(len(new_fields))
            ]
            fields = {name: expr for name, expr in zip(dtype.keys(), exprs)}
            from .typed_expressions import StructExpression
            return StructExpression._from_fields(fields)

    elif isinstance(dtype, tarray):
        elements = []
        found_expr = False
        for element in e:
            value = _to_expr(element, dtype.element_type)
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        else:
            assert (len(elements) > 0)
            exprs = [
                element if isinstance(element, Expression) else hl.literal(
                    element, dtype.element_type) for element in elements
            ]
            indices, aggregations = unify_all(*exprs)
        ir = MakeArray([e._ir for e in exprs], None)
        return expressions.construct_expr(ir, dtype, indices, aggregations)
    elif isinstance(dtype, tset):
        elements = []
        found_expr = False
        for element in e:
            value = _to_expr(element, dtype.element_type)
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        else:
            assert (len(elements) > 0)
            exprs = [
                element if isinstance(element, Expression) else hl.literal(
                    element, dtype.element_type) for element in elements
            ]
            indices, aggregations = unify_all(*exprs)
            ir = ToSet(MakeArray([e._ir for e in exprs], None))
            return expressions.construct_expr(ir, dtype, indices, aggregations)
    elif isinstance(dtype, ttuple):
        elements = []
        found_expr = False
        assert len(e) == len(dtype.types)
        for i in range(len(e)):
            value = _to_expr(e[i], dtype.types[i])
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        else:
            exprs = [
                elements[i] if isinstance(elements[i], Expression) else
                hl.literal(elements[i], dtype.types[i])
                for i in range(len(elements))
            ]
            indices, aggregations = unify_all(*exprs)
            ir = MakeTuple([expr._ir for expr in exprs])
            return expressions.construct_expr(ir, dtype, indices, aggregations)
    elif isinstance(dtype, tdict):
        keys = []
        values = []
        found_expr = False
        for k, v in e.items():
            k_ = _to_expr(k, dtype.key_type)
            v_ = _to_expr(v, dtype.value_type)
            found_expr = found_expr or isinstance(k_, Expression)
            found_expr = found_expr or isinstance(v_, Expression)
            keys.append(k_)
            values.append(v_)
        if not found_expr:
            return e
        else:
            assert len(keys) > 0
            # Here I use `to_expr` to call `lit` the keys and values separately.
            # I anticipate a common mode is statically-known keys and Expression
            # values.
            key_array = to_expr(keys, tarray(dtype.key_type))
            value_array = to_expr(values, tarray(dtype.value_type))
            return hl.dict(hl.zip(key_array, value_array))
    else:
        raise NotImplementedError(dtype)
Пример #28
0
 def _unary_op(self, name):
     """Apply unary operator ``name``; type, indices and aggregations are kept."""
     node = ApplyUnaryOp(name, self._ir)
     return expressions.construct_expr(node, self._type, self._indices,
                                       self._aggregations)
Пример #29
0
 def _method(self, name, ret_type, *args):
     """Emit a ``ClassMethod`` AST node calling ``name`` on ``self``."""
     arg_exprs = tuple(map(to_expr, args))
     indices, aggregations = unify_all(self, *arg_exprs)
     arg_asts = [arg._ast for arg in arg_exprs]
     node = ClassMethod(name, self._ast, *arg_asts)
     return expressions.construct_expr(node, ret_type, indices, aggregations)
Пример #30
0
 def _field(self, name, ret_type):
     """Wrap a ``GetField`` of ``name`` on this expression's IR as ``ret_type``."""
     getter = GetField(self._ir, name)
     return expressions.construct_expr(getter, ret_type, self._indices,
                                       self._aggregations)
Пример #31
0
 def _method(self, name, ret_type, *args):
     """Build ``Apply(name, self, *args)``; ``args`` are converted via ``to_expr``."""
     converted = tuple(to_expr(value) for value in args)
     indices, aggregations = unify_all(self, *converted)
     call = Apply(name, self._ir, *[value._ir for value in converted])
     return expressions.construct_expr(call, ret_type, indices, aggregations)
Пример #32
0
 def _index(self, ret_type, key):
     """Build ``self[key]`` as an ``Index`` AST node of type ``ret_type``."""
     key_expr = to_expr(key)
     indices, aggregations = unify_all(self, key_expr)
     node = Index(self._ast, key_expr._ast)
     return expressions.construct_expr(node, ret_type, indices, aggregations)
Пример #33
0
 def _unary_op(self, name):
     """Return ``ApplyUnaryOp(name, self)`` with this expression's metadata."""
     applied = ApplyUnaryOp(name, self._ir)
     return expressions.construct_expr(
         applied, self._type, self._indices, self._aggregations)
Пример #34
0
 def transform_ir(agg, continuation):
     """Bind ``agg`` to ``uid``, evaluate ``mapped``, and pass it to ``continuation``."""
     indices, aggregations = unify_all(ref, mapped, agg)
     bound = Let(uid, agg._ir, mapped._ir)
     mapped_expr = expressions.construct_expr(bound, mapped._type, indices,
                                              aggregations)
     return continuation(mapped_expr)
Пример #35
0
def _to_expr(e, dtype):
    """Convert a (possibly nested) Python value into ``dtype``, if needed.

    ``None`` becomes ``hl.null(dtype)``. An Expression of a different numeric
    type is cast to ``dtype``; otherwise it is returned unchanged. For
    container types the value is walked recursively. If no nested Expression
    is found, the original Python value ``e`` is returned unchanged.
    Otherwise an Expression of type ``dtype`` is built, and plain leaves are
    wrapped with ``hl.literal`` using their declared type.

    Raises
    ------
    NotImplementedError
        For a compound ``dtype`` not handled below.
    """
    if e is None:
        return hl.null(dtype)
    elif isinstance(e, Expression):
        if e.dtype != dtype:
            assert is_numeric(dtype), 'expected {}, got {}'.format(dtype, e.dtype)
            if dtype == tfloat64:
                return hl.float64(e)
            elif dtype == tfloat32:
                return hl.float32(e)
            elif dtype == tint64:
                return hl.int64(e)
            else:
                assert dtype == tint32
                return hl.int32(e)
        return e
    elif not is_compound(dtype):
        # these are not container types and cannot contain expressions if we got here
        return e
    elif isinstance(dtype, tstruct):
        # (field name, converted value, field type), in field order.
        converted = []
        found_expr = False
        for f, t in dtype.items():
            value = _to_expr(e[f], t)
            found_expr = found_expr or isinstance(value, Expression)
            converted.append((f, value, t))

        if not found_expr:
            return e
        # Take each field's type from the same items() walk so names, values
        # and types cannot drift out of alignment.
        fields = {f: value if isinstance(value, Expression) else hl.literal(value, t)
                  for f, value, t in converted}
        from .typed_expressions import StructExpression
        return StructExpression._from_fields(fields)
    elif isinstance(dtype, tarray):
        elements = []
        found_expr = False
        for element in e:
            value = _to_expr(element, dtype.element_type)
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        assert len(elements) > 0
        exprs = [element if isinstance(element, Expression)
                 else hl.literal(element, dtype.element_type)
                 for element in elements]
        indices, aggregations = unify_all(*exprs)
        ir = MakeArray([expr._ir for expr in exprs], None)
        return expressions.construct_expr(ir, dtype, indices, aggregations)
    elif isinstance(dtype, tset):
        elements = []
        found_expr = False
        for element in e:
            value = _to_expr(element, dtype.element_type)
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        assert len(elements) > 0
        exprs = [element if isinstance(element, Expression)
                 else hl.literal(element, dtype.element_type)
                 for element in elements]
        indices, aggregations = unify_all(*exprs)
        ir = ToSet(MakeArray([expr._ir for expr in exprs], None))
        return expressions.construct_expr(ir, dtype, indices, aggregations)
    elif isinstance(dtype, ttuple):
        elements = []
        found_expr = False
        assert len(e) == len(dtype.types)
        for i, t in enumerate(dtype.types):
            value = _to_expr(e[i], t)
            found_expr = found_expr or isinstance(value, Expression)
            elements.append(value)
        if not found_expr:
            return e
        exprs = [value if isinstance(value, Expression) else hl.literal(value, t)
                 for value, t in zip(elements, dtype.types)]
        indices, aggregations = unify_all(*exprs)
        ir = MakeTuple([expr._ir for expr in exprs])
        return expressions.construct_expr(ir, dtype, indices, aggregations)
    elif isinstance(dtype, tdict):
        keys = []
        values = []
        found_expr = False
        for k, v in e.items():
            k_ = _to_expr(k, dtype.key_type)
            v_ = _to_expr(v, dtype.value_type)
            found_expr = found_expr or isinstance(k_, Expression)
            found_expr = found_expr or isinstance(v_, Expression)
            keys.append(k_)
            values.append(v_)
        if not found_expr:
            return e
        assert len(keys) > 0
        # Convert keys and values separately with `to_expr`: a common case is
        # statically-known keys with Expression values, so the keys can stay
        # a single literal array.
        key_array = to_expr(keys, tarray(dtype.key_type))
        value_array = to_expr(values, tarray(dtype.value_type))
        return hl.dict(hl.zip(key_array, value_array))
    else:
        raise NotImplementedError(dtype)
Пример #36
0
def loop(f: Callable, typ, *args):
    r"""Define and call a tail-recursive function with given arguments.

    Notes
    -----
    The argument `f` must be a function where the first argument defines the
    recursive call, and the remaining arguments are the arguments to the
    recursive function, e.g. to define the recursive function

    .. math::

        f(x, y) = \begin{cases}
        y & \textrm{if } x \equiv 0 \\
        f(x - 1, y + x) & \textrm{otherwise}
        \end{cases}


    we would write:

    >>> f = lambda recur, x, y: hl.if_else(x == 0, y, recur(x - 1, y + x))

    Full recursion is not supported, and any non-tail-recursive methods will
    throw an error when called.

    This means that the result of any recursive call within the function must
    also be the result of the entire function, without modification. Let's
    consider two different recursive definitions for the triangle function
    :math:`f(x) = 0 + 1 + \dots + x`:

    >>> def triangle1(x):
    ...     if x == 1:
    ...         return x
    ...     return x + triangle1(x - 1)

    >>> def triangle2(x, total):
    ...     if x == 0:
    ...         return total
    ...     return triangle2(x - 1, total + x)

    The first function definition, `triangle1`, will call itself and then add x.
    This is an example of a non-tail recursive function, since `triangle1(9)`
    needs to modify the result of the inner recursive call to `triangle1(8)` by
    adding 9 to the result.

    The second function is tail recursive: the result of `triangle2(9, 0)` is
    the same as the result of the inner recursive call, `triangle2(8, 9)`.

    Examples
    --------
    To find the sum of all the numbers from n=1...10:

    >>> triangle_f = lambda f, x, total: hl.if_else(x == 0, total, f(x - 1, total + x))
    >>> x = hl.experimental.loop(triangle_f, hl.tint32, 10, 0)
    >>> hl.eval(x)
    55

    Let's say we want to find the root of a polynomial equation:

    >>> def polynomial(x):
    ...     return 5 * x**3 - 2 * x - 1

    We'll use `Newton's method <https://en.wikipedia.org/wiki/Newton%27s_method>`_
    to find it, so we'll also define the derivative:

    >>> def derivative(x):
    ...     return 15 * x**2 - 2

    and starting at :math:`x_0 = 0`, we'll compute the next step :math:`x_{i+1} = x_i - \frac{f(x_i)}{f'(x_i)}`
    until the difference between :math:`x_{i}` and :math:`x_{i+1}` falls below
    our convergence threshold:

    >>> threshold = 0.005
    >>> def find_root(f, guess, error):
    ...     converged = hl.is_defined(error) & (error < threshold)
    ...     new_guess = guess - (polynomial(guess) / derivative(guess))
    ...     new_error = hl.abs(new_guess - guess)
    ...     return hl.if_else(converged, guess, f(new_guess, new_error))
    >>> x = hl.experimental.loop(find_root, hl.tfloat, 0.0, hl.missing(hl.tfloat))
    >>> hl.eval(x)
    0.8052291984599675

    Warning
    -------
    Using arguments of a type other than numeric types and booleans can cause
    memory issues if you expect the recursive call to happen many times.

    Parameters
    ----------
    f : function ( (marker, \*args) -> :class:`.Expression` )
        Function of one callable marker, denoting where the recursive call (or calls) is located,
        and any number of `args`, the loop variables.
    typ : :class:`str` or :class:`.HailType`
        Type the loop returns.
    args : variable-length args of :class:`.Expression`
        Expressions to initialize the loop values.

    Returns
    -------
    :class:`.Expression`
        Result of the loop with `args` as initial loop values.

    Raises
    ------
    TypeError
        If the type of the expression returned by `f` differs from `typ`, if a
        recursive call is not in tail position, or if a recursive call passes
        the wrong number of arguments or arguments of the wrong types.
    """

    # Unique name tying every `Recur` node produced by `make_loop` to the
    # `TailLoop` node built at the end of this function.
    loop_name = Env.get_uid()

    def contains_recursive_call(non_recursive):
        """Return True if the IR subtree contains a `Recur` targeting this loop."""
        if isinstance(non_recursive,
                      ir.Recur) and non_recursive.name == loop_name:
            return True
        # Generator rather than a list so `any` stops at the first match
        # instead of traversing every remaining subtree.
        return any(
            contains_recursive_call(c) for c in non_recursive.children)

    def check_tail_recursive(loop_ir):
        """Raise TypeError if a recursive call appears outside tail position.

        Only the branches of `If` and the body of `Let` are treated as tail
        positions; any other node that contains a recursive call is rejected.
        """
        if isinstance(loop_ir, ir.If):
            if contains_recursive_call(loop_ir.cond):
                raise TypeError(
                    "branch condition can't contain recursive call!")
            check_tail_recursive(loop_ir.cnsq)
            check_tail_recursive(loop_ir.altr)
        elif isinstance(loop_ir, ir.Let):
            if contains_recursive_call(loop_ir.value):
                raise TypeError(
                    "bound value used in other expression can't contain recursive call!"
                )
            check_tail_recursive(loop_ir.body)
        elif isinstance(loop_ir, ir.TailLoop):
            # NOTE(review): only the inner loop's initial parameters are
            # checked here; its body is not inspected for calls to this
            # (outer) loop -- confirm that is intended.
            if any(contains_recursive_call(x) for n, x in loop_ir.params):
                raise TypeError(
                    "parameters passed to inner loop can't contain recursive call!"
                )
        elif not isinstance(loop_ir,
                            ir.Recur) and contains_recursive_call(loop_ir):
            raise TypeError(
                "found recursive expression outside of tail position!")

    @typecheck(recur_exprs=expr_any)
    def make_loop(*recur_exprs):
        """The `recur` marker handed to `f`: build a `Recur` node for this loop.

        The arguments must match the initial `args` in number and dtype.
        """
        if len(recur_exprs) != len(args):
            raise TypeError(
                'Recursive call in loop has wrong number of arguments')
        # Collect every mismatching argument so the error lists them all at once.
        err = None
        for i, (rexpr, expr) in enumerate(zip(recur_exprs, args)):
            if rexpr.dtype != expr.dtype:
                if err is None:
                    err = 'Type error in recursive call,'
                err += f'\n  at argument index {i}, loop arg type: {expr.dtype}, '
                err += f'recur arg type: {rexpr.dtype}'
        if err is not None:
            raise TypeError(err)
        irs = [expr._ir for expr in recur_exprs]
        indices, aggregations = unify_all(*recur_exprs)
        return construct_expr(ir.Recur(loop_name, irs, typ), typ, indices,
                              aggregations)

    # Pairs of (fresh variable name, initial-value IR) that seed the TailLoop.
    uid_irs = []
    # Variables bound to those names; these are what `f` receives as loop vars.
    loop_vars = []

    for expr in args:
        uid = Env.get_uid()
        loop_vars.append(
            construct_variable(uid, expr._type, expr._indices,
                               expr._aggregations))
        uid_irs.append((uid, expr._ir))

    loop_f = to_expr(f(make_loop, *loop_vars))
    # NOTE(review): `typ` is compared directly against a HailType here; a
    # `str` type (allowed by the docstring) would only match if it has
    # already been converted upstream -- TODO confirm.
    if loop_f.dtype != typ:
        raise TypeError(
            f"requested type {typ} does not match inferred type {loop_f.dtype}"
        )
    check_tail_recursive(loop_f._ir)
    indices, aggregations = unify_all(*args, loop_f)

    return construct_expr(ir.TailLoop(loop_name, loop_f._ir, uid_irs),
                          loop_f.dtype, indices, aggregations)
Пример #37
0
 def f(*args):
     """Wrap ``Apply(mname, ret_type, ...)`` over the IRs of ``args`` as an expression.

     ``mname`` and ``ret_type`` are resolved from a scope not visible here.
     The result carries the unified indices and aggregations of ``args``.
     """
     indices, aggregations = unify_all(*args)
     arg_irs = [arg._ir for arg in args]
     apply_ir = Apply(mname, ret_type, *arg_irs)
     return construct_expr(apply_ir, ret_type, indices, aggregations)
Пример #38
0
 def f(*args):
     """Wrap ``Apply(mname, ...)`` over the IRs of ``args`` as an expression of ``ret_type``.

     ``mname`` and ``ret_type`` are resolved from a scope not visible here.
     The result carries the unified indices and aggregations of ``args``.
     """
     indices, aggregations = unify_all(*args)
     arg_irs = [arg._ir for arg in args]
     apply_ir = Apply(mname, *arg_irs)
     return construct_expr(apply_ir, ret_type, indices, aggregations)
Пример #39
0
 def _field(self, name, ret_type):
     """Return an expression of type ``ret_type`` that reads field ``name`` from this one.

     The new expression reuses this expression's indices and aggregations.
     """
     field_ir = GetField(self._ir, name)
     return expressions.construct_expr(
         field_ir, ret_type, self._indices, self._aggregations)