Ejemplo n.º 1
0
    def from_matrix_table(
            mt: MatrixTable,
            entry_field: str,
            *,
            n_partitions: Optional[int] = None,
            block_size: Optional[int] = None,
            sort_columns: bool = False
    ) -> 'DNDArray':
        """Build a ``DNDArray`` of square-ish blocks from one entry field of ``mt``.

        The matrix is cut into ``block_size`` x ``block_size`` tiles. Each tile
        becomes one row of the resulting table, identified by its block-row
        index ``r`` and block-column index ``c`` and carrying the tile data in
        an ndarray field named ``block``. The table globals record ``n_rows``,
        ``n_cols``, ``n_block_rows``, ``n_block_cols`` and ``block_size``.

        Parameters
        ----------
        mt
            Source matrix table. Its row must not already contain fields
            named ``r``, ``c`` or ``block``.
        entry_field
            Name of the entry field whose values fill the blocks.
        n_partitions
            Defaults to ``mt.n_partitions()``. NOTE(review): within this
            method the value is only compared against zero to detect an empty
            input; it is not used to partition the output.
        block_size
            Edge length of a block; defaults to
            ``DNDArray.default_block_size``.
        sort_columns
            When true, the columns are reordered by column key before
            blocking. When false, the columns must already be ordered by key.

        Returns
        -------
        DNDArray
            Wrapper around the table of blocks.

        Raises
        ------
        ValueError
            If ``sort_columns`` is false and some column key is greater than
            the key of the column that follows it.
        AssertionError
            If ``n_partitions`` is zero but ``mt`` has rows or columns, or if
            ``mt.row`` already has a field named ``r``, ``c`` or ``block``.
            NOTE(review): these checks are ``assert`` statements and vanish
            under ``python -O``.
        """
        n_partitions = mt.n_partitions() if n_partitions is None else n_partitions
        block_size = DNDArray.default_block_size if block_size is None else block_size

        # Degenerate input: hand back an empty table with an empty 0x0 block
        # field and all-zero metadata.
        if n_partitions == 0:
            assert mt.count_cols() == 0
            assert mt.count_rows() == 0
            empty = range_table(0, 0)
            empty_block = nd.array([]).reshape((0, 0))
            empty = empty.annotate(r=0, c=0, block=empty_block)
            empty = empty.select_globals(
                n_rows=0, n_cols=0, n_block_rows=0, n_block_cols=0, block_size=0)
            return DNDArray(empty)

        # These field names are introduced below, so they must be free.
        for reserved in ('r', 'c', 'block'):
            assert reserved not in mt.row

        n_rows, n_cols = mt.count()
        # Ceiling division: a partially filled trailing block still counts.
        round_up = block_size - 1
        n_block_rows = (n_rows + round_up) // block_size
        n_block_cols = (n_cols + round_up) // block_size

        # Temporary field names.
        entries_uid = Env.get_uid()
        cols_uid = Env.get_uid()
        row_idx_uid = Env.get_uid()
        col_blocks_uid = Env.get_uid()

        if not sort_columns:
            # Check that no column key is greater than its successor.
            keys = mt.col_key.collect(_localize=False)
            inversions = hl.range(hl.len(keys) - 1).map(
                lambda i: keys[i] > keys[i + 1])
            if any(inversions.collect()[0]):
                raise ValueError(
                    'from_matrix_table: columns are not in sorted order. You may request a '
                    'sort with sort_columns=True.')
        else:
            # Tag each column with its current position, sort the
            # (key, position) pairs by key, and reorder the columns to match.
            col_idx_uid = Env.get_uid()
            indexed_cols = mt.add_col_index(col_idx_uid).key_cols_by().cols()
            keyed_positions = indexed_cols.select(
                key=indexed_cols.row.select(*mt.col_key),
                index=indexed_cols[col_idx_uid])
            by_key = hl.sorted(keyed_positions.collect(_localize=False),
                               key=lambda pair: pair.key)
            mt = mt.choose_cols(by_key['index'].collect()[0])

        # Argument-less select_* calls, then one table row per matrix row with
        # the entries gathered into a single array field.
        stripped = mt.select_globals().select_rows().select_cols()
        localized = stripped.add_row_index(row_idx_uid).localize_entries(entries_uid, cols_uid)
        # FIXME: remove when ndarray support structs
        values = localized.annotate(
            **{entries_uid: localized[entries_uid][entry_field]})

        # Cut every row into per-block-column slices, one table row per slice.
        sliced = values.annotate(**{
            col_blocks_uid: hl.range(n_block_cols).map(
                lambda c: hl.struct(
                    c=c,
                    entries=values[entries_uid][(c * block_size):((c + 1) * block_size)]))
        })
        pieces = sliced.explode(col_blocks_uid)
        pieces = pieces.select(row_idx_uid, **pieces[col_blocks_uid])
        pieces = pieces.annotate(r=hl.int(pieces[row_idx_uid] // block_size))
        pieces = pieces.key_by(pieces.r, pieces.c)

        # Reassemble each (r, c) block, ordering its row slices by the
        # original row index.
        grouped = pieces.group_by(pieces.r, pieces.c)
        assembled = grouped.aggregate(
            entries=hl.sorted(
                hl.agg.collect(
                    hl.struct(row_index=pieces[row_idx_uid], entries=pieces.entries)),
                key=lambda piece: piece.row_index
            ).map(lambda piece: piece.entries))
        blocks = assembled.select(block=hl.nd.array(assembled.entries))
        blocks = blocks.select_globals(
            n_rows=n_rows,
            n_cols=n_cols,
            n_block_rows=n_block_rows,
            n_block_cols=n_block_cols,
            block_size=block_size)

        # Write to a temp file, then read back with one interval per block.
        fname = new_temp_file()
        blocks = blocks.key_by(blocks.r, blocks.c)
        blocks.write(fname, _codec_spec=DNDArray.fast_codec_spec)
        block_intervals = [
            hl.Interval(hl.Struct(r=i, c=j),
                        hl.Struct(r=i, c=j + 1))
            for i in range(n_block_rows)
            for j in range(n_block_cols)]
        return DNDArray(hl.read_table(fname, _intervals=block_intervals))
Ejemplo n.º 2
0
    def from_matrix_table(
            mt: MatrixTable,
            entrc_field: str,
            *,
            n_partitions: Optional[int] = None,
            block_size: Optional[int] = None
    ) -> 'DNDArray':
        """Build a ``DNDArray`` of blocks from one entry field of ``mt``.

        The matrix is cut into ``block_size`` x ``block_size`` tiles. Each tile
        becomes one row of the resulting table, identified by its block-row
        index ``r`` and block-column index ``c`` and carrying the tile data in
        an ndarray field named ``block``. The table globals record
        ``r_field``, ``c_field``, ``n_rows``, ``n_cols``, ``n_block_rows``,
        ``n_block_cols`` and ``block_size``.

        Parameters
        ----------
        mt
            Source matrix table. Its row must not already contain fields
            named ``r``, ``c`` or ``block``.
        entrc_field
            Name of the entry field whose values fill the blocks.
            NOTE(review): the name looks like a typo for ``entry_field``; it
            is part of the caller-visible signature and is kept as is.
        n_partitions
            Defaults to ``mt.n_partitions()``. NOTE(review): within this
            method the value is only compared against zero to detect an empty
            input; it is not used to partition the output.
        block_size
            Edge length of a block; defaults to
            ``DNDArray.default_block_size``.

        Returns
        -------
        DNDArray
            Wrapper around the table of blocks.

        Raises
        ------
        AssertionError
            If ``n_partitions`` is zero but ``mt`` has rows or columns, or if
            ``mt.row`` already has a field named ``r``, ``c`` or ``block``.
            NOTE(review): these checks are ``assert`` statements and vanish
            under ``python -O``.
        """
        n_partitions = mt.n_partitions() if n_partitions is None else n_partitions
        block_size = DNDArray.default_block_size if block_size is None else block_size

        # Degenerate input: hand back an empty table with an empty 0x0 block
        # field and all-zero size metadata.
        if n_partitions == 0:
            assert mt.count_cols() == 0
            assert mt.count_rows() == 0
            empty = range_table(0, 0)
            empty_block = nd.array([]).reshape((0, 0))
            empty = empty.annotate(r=0, c=0, block=empty_block)
            empty = empty.select_globals(
                r_field='r', c_field='c',
                n_rows=0, n_cols=0, n_block_rows=0, n_block_cols=0, block_size=0)
            return DNDArray(empty)

        # These field names are introduced below, so they must be free.
        for reserved in ('r', 'c', 'block'):
            assert reserved not in mt.row

        n_rows, n_cols = mt.count()
        # Ceiling division: a partially filled trailing block still counts.
        round_up = block_size - 1
        n_block_rows = (n_rows + round_up) // block_size
        n_block_cols = (n_cols + round_up) // block_size

        # Temporary field names.
        entries_uid = Env.get_uid()
        cols_uid = Env.get_uid()
        row_idx_uid = Env.get_uid()
        col_blocks_uid = Env.get_uid()

        # Argument-less select_* calls, then one table row per matrix row with
        # the entries gathered into a single array field.
        stripped = mt.select_globals().select_rows().select_cols()
        localized = stripped.add_row_index(row_idx_uid).localize_entries(entries_uid, cols_uid)
        # FIXME: remove when ndarray support structs
        values = localized.annotate(
            **{entries_uid: localized[entries_uid][entrc_field]})

        # Cut every row into per-block-column slices, one table row per slice.
        sliced = values.annotate(**{
            col_blocks_uid: hl.range(n_block_cols).map(
                lambda c: hl.struct(
                    c=c,
                    entries=values[entries_uid][(c * block_size):((c + 1) * block_size)]))
        })
        pieces = sliced.explode(col_blocks_uid)
        pieces = pieces.select(row_idx_uid, **pieces[col_blocks_uid])
        pieces = pieces.annotate(r=hl.int(pieces[row_idx_uid] // block_size))
        pieces = pieces.key_by(pieces.r, pieces.c)

        # Reassemble each (r, c) block, ordering its row slices by the
        # original row index.
        grouped = pieces.group_by(pieces.r, pieces.c)
        assembled = grouped.aggregate(
            entries=hl.sorted(
                hl.agg.collect(
                    hl.struct(row_index=pieces[row_idx_uid], entries=pieces.entries)),
                key=lambda piece: piece.row_index
            ).map(lambda piece: piece.entries))
        blocks = assembled.select(block=hl.nd.array(assembled.entries))
        blocks = blocks.select_globals(
            r_field='r',
            c_field='c',
            n_rows=n_rows,
            n_cols=n_cols,
            n_block_rows=n_block_rows,
            n_block_cols=n_block_cols,
            block_size=block_size)

        # Write to a temp file, then read back with one interval per block.
        fname = new_temp_file()
        blocks = blocks.key_by(blocks.r, blocks.c)
        blocks.write(fname, _codec_spec=DNDArray.fast_codec_spec)
        block_intervals = [
            hl.Interval(hl.Struct(r=i, c=j),
                        hl.Struct(r=i, c=j + 1))
            for i in range(n_block_rows)
            for j in range(n_block_cols)]
        return DNDArray(hl.read_table(fname, _intervals=block_intervals))