def blockmatrix_irs(self):
    """Build a list of sample BlockMatrix IR nodes.

    Returns
    -------
    list
        The constructed IR nodes, in this order: native read, element-wise
        sum of the read with itself, negation, square root, the scalar /
        column-vector / row-vector ``ValueToBlockMatrix`` nodes, the three
        broadcasts built from them, the element-wise power node, the
        transpose-style broadcast and the dot product.
    """
    def float64_ref(name):
        # Wrap a reference to `name` as a float64 expression so that
        # expression-level functions and operators can be applied to it.
        return construct_expr(ir.Ref(name), hl.tfloat64)

    def broadcast(child, indices):
        # Every broadcast built here passes [2, 2] and 256 as the trailing
        # arguments (presumably shape and block size -- confirm against
        # ir.BlockMatrixBroadcast).
        return ir.BlockMatrixBroadcast(child, indices, [2, 2], 256)

    # Plain value IRs used as inputs to ValueToBlockMatrix.
    f64_value = ir.F64(2)
    f64_array = ir.MakeArray([ir.F64(3), ir.F64(2)], hl.tarray(hl.tfloat64))

    # Block matrix read from the example resource, plus element-wise maps.
    bm_read = ir.BlockMatrixRead(
        ir.BlockMatrixNativeReader(resource('blockmatrix_example/0')))
    bm_sum = ir.BlockMatrixMap2(
        bm_read,
        bm_read,
        ir.ApplyBinaryPrimOp('+', ir.Ref('l'), ir.Ref('r')))
    bm_negated = ir.BlockMatrixMap(
        bm_read,
        ir.ApplyUnaryPrimOp('-', ir.Ref('element')))
    bm_sqrt = ir.BlockMatrixMap(bm_read, hl.sqrt(float64_ref('element'))._ir)

    # Block matrices created from values.
    bm_from_scalar = ir.ValueToBlockMatrix(f64_value, [1, 1], 1)
    bm_from_col = ir.ValueToBlockMatrix(f64_array, [2, 1], 1)
    bm_from_row = ir.ValueToBlockMatrix(f64_array, [1, 2], 1)

    # Broadcasts of the value-derived block matrices.
    bm_bcast_scalar = broadcast(bm_from_scalar, [])
    bm_bcast_col = broadcast(bm_from_col, [0])
    bm_bcast_row = broadcast(bm_from_row, [1])

    bm_transposed = broadcast(bm_bcast_scalar, [1, 0])
    bm_dot = ir.BlockMatrixDot(bm_bcast_scalar, bm_transposed)

    # Element-wise power expressed through the expression-level `**`.
    power_ir = (float64_ref('l') ** float64_ref('r'))._ir
    bm_power = ir.BlockMatrixMap2(bm_from_scalar, bm_from_scalar, power_ir)

    return [
        bm_read,
        bm_sum,
        bm_negated,
        bm_sqrt,
        bm_from_scalar,
        bm_from_col,
        bm_from_row,
        bm_bcast_scalar,
        bm_bcast_col,
        bm_bcast_row,
        bm_power,
        bm_transposed,
        bm_dot,
    ]
def persist_expression(self, expr):
    """Execute `expr` as a literal on the Java backend and re-wrap the result.

    Parameters
    ----------
    expr
        Expression whose ``_ir`` is converted with ``_to_java_value_ir`` and
        passed to the Java backend's ``executeLiteral``.

    Returns
    -------
    An expression built by ``construct_expr`` from the ``JavaIR`` wrapping
    the backend result, with the same ``dtype`` as `expr`.
    """
    java_value_ir = self._to_java_value_ir(expr._ir)
    executed_literal = self._jbackend.executeLiteral(java_value_ir)
    return construct_expr(JavaIR(executed_literal), expr.dtype)
def maximal_independent_set(i, j, keep=True, tie_breaker=None, keyed=True) -> Table:
    """Return a table containing the vertices in a near
    `maximal independent set <https://en.wikipedia.org/wiki/Maximal_independent_set>`_
    of an undirected graph whose edges are given by a two-column table.

    Examples
    --------
    Run PC-relate and compute pairs of closely related individuals:

    >>> pc_rel = hl.pc_relate(dataset.GT, 0.001, k=2, statistics='kin')
    >>> pairs = pc_rel.filter(pc_rel['kin'] > 0.125)

    Starting from the above pairs, prune individuals from a dataset until no
    close relationships remain:

    >>> related_samples_to_remove = hl.maximal_independent_set(pairs.i, pairs.j, False)
    >>> result = dataset.filter_cols(
    ...     hl.is_defined(related_samples_to_remove[dataset.col_key]), keep=False)

    Starting from the above pairs, prune individuals from a dataset until no
    close relationships remain, preferring to keep cases over controls:

    >>> samples = dataset.cols()
    >>> pairs_with_case = pairs.key_by(
    ...     i=hl.struct(id=pairs.i, is_case=samples[pairs.i].is_case),
    ...     j=hl.struct(id=pairs.j, is_case=samples[pairs.j].is_case))

    >>> def tie_breaker(l, r):
    ...     return hl.cond(l.is_case & ~r.is_case, -1,
    ...                    hl.cond(~l.is_case & r.is_case, 1, 0))

    >>> related_samples_to_remove = hl.maximal_independent_set(
    ...    pairs_with_case.i, pairs_with_case.j, False, tie_breaker)
    >>> result = dataset.filter_cols(hl.is_defined(
    ...     related_samples_to_remove.key_by(
    ...        s = related_samples_to_remove.node.id.s)[dataset.col_key]), keep=False)

    Notes
    -----
    The vertex set of the graph is implicitly all the values realized by `i`
    and `j` on the rows of this table. Each row of the table corresponds to an
    undirected edge between the vertices given by evaluating `i` and `j` on
    that row. An undirected edge may appear multiple times in the table and
    will not affect the output. Vertices with self-edges are removed as they
    are not independent of themselves.

    The expressions for `i` and `j` must have the same type.

    The value of `keep` determines whether the vertices returned are those
    in the maximal independent set, or those in the complement of this set.
    This is useful if you need to filter a table without removing vertices that
    don't appear in the graph at all.

    This method implements a greedy algorithm which iteratively removes a
    vertex of highest degree until the graph contains no edges. The greedy
    algorithm always returns an independent set, but the set may not always
    be perfectly maximal.

    `tie_breaker` is a Python function taking two arguments---say `l` and
    `r`---each of which is an :class:`Expression` of the same type as `i` and
    `j`. `tie_breaker` returns a :class:`NumericExpression`, which defines an
    ordering on nodes. A pair of nodes can be ordered in one of three ways, and
    `tie_breaker` must encode the relationship as follows:

     - if ``l < r`` then ``tie_breaker`` evaluates to some negative integer
     - if ``l == r`` then ``tie_breaker`` evaluates to 0
     - if ``l > r`` then ``tie_breaker`` evaluates to some positive integer

    For example, the usual ordering on the integers is defined by: ``l - r``.

    The `tie_breaker` function must satisfy the following property:
    ``tie_breaker(l, r) == -tie_breaker(r, l)``.

    When multiple nodes have the same degree, this algorithm will order the
    nodes according to ``tie_breaker`` and remove the *largest* node.

    If `keyed` is ``False``, then a node may appear twice in the resulting
    table.

    Parameters
    ----------
    i : :class:`.Expression`
        Expression to compute one endpoint of an edge.
    j : :class:`.Expression`
        Expression to compute another endpoint of an edge.
    keep : :obj:`bool`
        If ``True``, return vertices in set. If ``False``, return vertices removed.
    tie_breaker : function
        Function used to order nodes with equal degree.
    keyed : :obj:`bool`
        If ``True``, key the resulting table by the `node` field, this requires
        a sort.

    Returns
    -------
    :class:`.Table`
        Table with the set of independent vertices. The table schema is one row
        field `node` which has the same type as input expressions `i` and `j`.

    Raises
    ------
    ValueError
        If `i` and `j` have different types, if `i` is not an expression of a
        :class:`.Table`, or if `i` and `j` come from different sources.
    """
    # --- argument validation -------------------------------------------------
    if i.dtype != j.dtype:
        raise ValueError("'maximal_independent_set' expects arguments `i` and `j` to have same type. "
                         "Found {} and {}.".format(i.dtype, j.dtype))

    source = i._indices.source
    if not isinstance(source, Table):
        if source is None:
            found = 'scalar expression'
        else:
            found = "expression of '{}'".format(source.__class__)
        raise ValueError("'maximal_independent_set' expects an expression of 'Table'. Found {}".format(found))

    if source != j._indices.source:
        raise ValueError(
            "'maximal_independent_set' expects arguments `i` and `j` to be expressions of the same Table. "
            "Found\n{}\n{}".format(i, j))

    node_t = i.dtype

    # --- optional tie breaker ------------------------------------------------
    # The user function is applied to element 0 of two 1-tuple variables named
    # 'l' and 'r'; the resulting IR is handed to the Java side as a string.
    tie_breaker_expr = None
    if tie_breaker:
        wrapped_node_t = ttuple(node_t)
        lhs = construct_variable('l', wrapped_node_t)
        rhs = construct_variable('r', wrapped_node_t)
        tie_breaker_expr = hl.float64(tie_breaker(lhs[0], rhs[0]))

    if tie_breaker_expr is None:
        t, _ = source._process_joins(i, j)
        tie_breaker_str = None
    else:
        t, _ = source._process_joins(i, j, tie_breaker_expr)
        tie_breaker_str = str(tie_breaker_expr._ir)

    # --- materialise the edge list -------------------------------------------
    # Write the unkeyed two-column edge table to a temp file and read it back,
    # so later steps work from the stored table.
    edge_table = t.select(__i=i, __j=j)
    edge_table = edge_table.key_by()
    edge_table = edge_table.select('__i', '__j')
    edges_path = new_temp_file()
    edge_table.write(edges_path)
    edges = hl.read_table(edges_path)

    # --- run the JVM implementation ------------------------------------------
    graph = Env.hail().utils.Graph
    backend = Env.spark_backend('maximal_independent_set')
    edges_jir = backend._to_java_value_ir(edges.collect(_localize=False)._ir)
    mis_jir = graph.pyMaximalIndependentSet(
        edges_jir,
        node_t._parsable_string(),
        tie_breaker_str)
    mis_nodes = construct_expr(ir.JavaIR(mis_jir), hl.tset(node_t))

    # --- build the output table ----------------------------------------------
    # One row per edge endpoint, filtered by membership in the computed set
    # (kept or dropped according to `keep`).
    nodes = edges.select(node=[edges.__i, edges.__j])
    nodes = nodes.explode(nodes.node)
    nodes = nodes.annotate_globals(mis_nodes=mis_nodes)
    in_set = nodes.mis_nodes.contains(nodes.node)
    nodes = nodes.filter(in_set, keep)
    nodes = nodes.select_globals()

    return nodes.key_by('node').distinct() if keyed else nodes