Ejemplo n.º 1
0
    def matrix_irs(self):
        """Build a list of sample matrix IR nodes.

        The ``example.8bits.bgen`` resource is indexed first.  The nodes are
        then constructed, mostly on top of three shared inputs: a native
        matrix table read, a native table read and a range read.

        Returns
        -------
        list
            The constructed matrix IR nodes, in a fixed order.
        """
        hl.index_bgen(
            resource('example.8bits.bgen'),
            reference_genome=hl.get_reference('GRCh37'),
            contig_recoding={'01': '1'},
        )

        def x_struct():
            # Build a fresh ``{x: I64(20)}`` struct node on every call so that
            # no node instance is shared between the IRs that use it.
            return ir.MakeStruct([('x', ir.I64(20))])

        # Aggregation expression reused by both aggregate-by-key nodes.
        collect_agg = ir.MakeStruct(
            [('x', ir.ApplyAggOp('Collect', [], None, [ir.I32(0)]))])

        # Shared inputs.
        native_reader = ir.MatrixNativeReader(
            resource('backward_compatability/1.0.0/matrix_table/0.hmt'), None, False)
        mt_native = ir.MatrixRead(native_reader, False, False)

        table_reader = ir.TableNativeReader(
            resource('backward_compatability/1.0.0/table/0.ht'), None, False)
        ht_native = ir.TableRead(table_reader, False)

        mt_range = ir.MatrixRead(ir.MatrixRangeReader(1, 1, 10))

        nodes = [
            # Nodes built on the range read.
            ir.MatrixRepartition(mt_range, 100, ir.RepartitionStrategy.SHUFFLE),
            ir.MatrixUnionRows(mt_range, mt_range),
            ir.MatrixDistinctByRow(mt_range),
            # Head, cast round trip and aggregations on the native read.
            ir.MatrixRowsHead(mt_native, 5),
            ir.MatrixColsHead(mt_native, 5),
            ir.CastTableToMatrix(
                ir.CastMatrixToTable(mt_native, '__entries', '__cols'),
                '__entries',
                '__cols',
                [],
            ),
            ir.MatrixAggregateRowsByKey(mt_native, collect_agg, collect_agg),
            ir.MatrixAggregateColsByKey(mt_native, collect_agg, collect_agg),
            # Plain reads from each reader kind.
            mt_native,
            mt_range,
            ir.MatrixRead(
                ir.MatrixVCFReader(
                    resource('sample.vcf'), ['GT'], hl.tfloat64,
                    None, None, None, None,
                    False, True, False, True,
                    None, None, None)),
            ir.MatrixRead(
                ir.MatrixBGENReader(
                    resource('example.8bits.bgen'), None, {}, 10, 1, None)),
            # Filters.
            ir.MatrixFilterRows(mt_native, ir.FalseIR()),
            ir.MatrixFilterCols(mt_native, ir.FalseIR()),
            ir.MatrixFilterEntries(mt_native, ir.FalseIR()),
            # Column selection, re-keying and mapping.
            ir.MatrixChooseCols(mt_native, [1, 0]),
            ir.MatrixMapCols(mt_native, x_struct(), ['x']),
            ir.MatrixKeyRowsBy(mt_native, ['row_i64'], False),
            ir.MatrixMapRows(ir.MatrixKeyRowsBy(mt_native, []), x_struct()),
            ir.MatrixMapEntries(mt_native, x_struct()),
            ir.MatrixMapGlobals(mt_native, x_struct()),
            # Collect, explode and annotate.
            ir.MatrixCollectColsByKey(mt_native),
            ir.MatrixExplodeRows(mt_native, ['row_aset']),
            ir.MatrixExplodeCols(mt_native, ['col_aset']),
            ir.MatrixAnnotateRowsTable(mt_native, ht_native, '__foo'),
            ir.MatrixAnnotateColsTable(mt_native, ht_native, '__foo'),
            # Generic apply, rename and interval filter.
            ir.MatrixToMatrixApply(
                mt_native,
                {'name': 'MatrixFilterPartitions', 'parts': [0], 'keep': True}),
            ir.MatrixRename(
                mt_native,
                {'global_f32': 'global_foo'},
                {'col_f32': 'col_foo'},
                {'row_aset': 'row_aset2'},
                {'entry_f32': 'entry_foo'}),
            ir.MatrixFilterIntervals(
                mt_native,
                [hl.utils.Interval(hl.utils.Struct(row_idx=0),
                                   hl.utils.Struct(row_idx=10))],
                hl.tstruct(row_idx=hl.tint32),
                keep=False),
        ]

        return nodes
Ejemplo n.º 2
0
def filter_intervals(ds, intervals, keep=True) -> Union[Table, MatrixTable]:
    """Filter rows with a list of intervals.

    Examples
    --------

    Filter to loci falling within one interval:

    >>> ds_result = hl.filter_intervals(dataset, [hl.parse_locus_interval('17:38449840-38530994')])

    Remove all loci within list of intervals:

    >>> intervals = [hl.parse_locus_interval(x) for x in ['1:50M-75M', '2:START-400000', '3-22']]
    >>> ds_result = hl.filter_intervals(dataset, intervals, keep=False)

    Notes
    -----
    Based on the `keep` argument, this method will either restrict to points
    in the supplied interval ranges, or remove all rows in those ranges.

    When ``keep=True``, partitions that don't overlap any supplied interval
    will not be loaded at all.  This enables :func:`.filter_intervals` to be
    used for reasonably low-latency queries of small ranges of the dataset, even
    on large datasets.

    The `intervals` expression is evaluated eagerly with :func:`.eval` before
    the filter is built, and missing intervals are rejected.

    Parameters
    ----------
    ds : :class:`.MatrixTable` or :class:`.Table`
        Dataset to filter.
    intervals : :class:`.ArrayExpression` with elements of type :py:data:`.tinterval`
        Intervals to filter on.  The point type of the interval must
        be a prefix of the key or equal to the first field of the key.
    keep : :obj:`bool`
        If ``True``, keep only rows that fall within any interval in `intervals`.
        If ``False``, keep only rows that fall outside all intervals in
        `intervals`.

    Returns
    -------
    :class:`.MatrixTable` or :class:`.Table`

    Raises
    ------
    TypeError
        If `ds` is neither a :class:`.Table` nor a :class:`.MatrixTable`, if
        the dataset has no key fields, if the interval point type is
        incompatible with the key type, or if `intervals` contains a missing
        value.

    """

    if isinstance(ds, MatrixTable):
        k_type = ds.row_key.dtype
    elif isinstance(ds, Table):
        k_type = ds.key.dtype
    else:
        # Explicit check rather than ``assert``: assertions are stripped
        # under ``python -O`` and would let a bad argument through.
        raise TypeError(
            "'filter_intervals' expects a 'Table' or 'MatrixTable', found '{}'".format(
                type(ds).__name__))

    point_type = intervals.dtype.element_type.point_type

    if len(k_type) == 0:
        # Without this guard, ``k_type[0]`` below fails with an opaque
        # IndexError when the dataset is unkeyed.
        raise TypeError(
            "'filter_intervals' requires a dataset with at least one key field, "
            "but the dataset key is empty.")

    def is_struct_prefix(partial, full):
        # Field names must match the leading fields of ``full`` in order ...
        if list(partial) != list(full)[:len(partial)]:
            return False
        # ... and every shared field must have the same type.
        return not any(full[name] != field_type for name, field_type in partial.items())

    if point_type == k_type[0]:
        # Bare point of the first key field's type: wrap it into a one-field
        # struct so it is expressed in terms of the key.
        needs_wrapper = True
        k_name = k_type.fields[0]
        point_type = hl.tstruct(**{k_name: k_type[k_name]})
    elif isinstance(point_type, tstruct) and is_struct_prefix(point_type, k_type):
        needs_wrapper = False
    else:
        raise TypeError(
            "The point type is incompatible with key type of the dataset ('{}', '{}')".format(repr(point_type),
                                                                                              repr(k_type)))

    def wrap_input(interval):
        if interval is None:
            raise TypeError("'filter_intervals' does not allow missing values in 'intervals'.")
        if not needs_wrapper:
            return interval
        # Lift both endpoints into single-field structs named after the first
        # key field, preserving endpoint inclusiveness.
        return Interval(Struct(**{k_name: interval.start}),
                        Struct(**{k_name: interval.end}),
                        interval.includes_start,
                        interval.includes_end)

    # Evaluate the expression to Python values, then normalize each interval.
    intervals = [wrap_input(i) for i in hl.eval(intervals)]

    if isinstance(ds, MatrixTable):
        return MatrixTable(ir.MatrixFilterIntervals(ds._mir, intervals, point_type, keep))
    return Table(ir.TableFilterIntervals(ds._tir, intervals, point_type, keep))