Example 1
def test_partitioning():
    schema = pa.schema(
        [pa.field('i64', pa.int64()),
         pa.field('f64', pa.float64())])
    for klass in [ds.DirectoryPartitioning, ds.HivePartitioning]:
        partitioning = klass(schema)
        assert isinstance(partitioning, ds.Partitioning)

    partitioning = ds.DirectoryPartitioning(
        pa.schema(
            [pa.field('group', pa.int64()),
             pa.field('key', pa.float64())]))
    expr = partitioning.parse('/3/3.14')
    assert isinstance(expr, ds.Expression)

    expected = (ds.field('group') == 3) & (ds.field('key') == 3.14)
    assert expr.equals(expected)

    with pytest.raises(pa.ArrowInvalid):
        partitioning.parse('/prefix/3/aaa')

    partitioning = ds.HivePartitioning(
        pa.schema(
            [pa.field('alpha', pa.int64()),
             pa.field('beta', pa.int64())]))
    expr = partitioning.parse('/alpha=0/beta=3')
    expected = ((ds.field('alpha') == ds.scalar(0)) &
                (ds.field('beta') == ds.scalar(3)))
    assert expr.equals(expected)

def test_with_partition_pruning():
    if skip:
        return
    filter_expression = ((ds.field('tip_amount') > 10) &
                         (ds.field('payment_type') > 2) &
                         (ds.field('VendorID') > 1))
    projection_cols = ['payment_type', 'tip_amount', 'VendorID']
    partitioning = ds.partitioning(pa.schema([("payment_type", pa.int32()),
                                              ("VendorID", pa.int32())]),
                                   flavor="hive")

    rados_parquet_dataset = ds.dataset(
        "file:///mnt/cephfs/nyc/",
        partitioning=partitioning,
        format=ds.RadosParquetFileFormat("/etc/ceph/ceph.conf"))
    parquet_dataset = ds.dataset("file:///mnt/cephfs/nyc/",
                                 partitioning=partitioning,
                                 format="parquet")

    rados_parquet_df = rados_parquet_dataset.to_table(
        columns=projection_cols, filter=filter_expression).to_pandas()

    parquet_df = parquet_dataset.to_table(
        columns=projection_cols, filter=filter_expression).to_pandas()

    assert rados_parquet_df.equals(parquet_df)
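
As an aside (not part of the test above), the same two partitioning schemes can also be built through the ds.partitioning() factory. A minimal usage sketch with illustrative field names:

import pyarrow as pa
import pyarrow.dataset as ds

# Directory partitioning encodes values purely by position in the path ("/3/3.14"),
# while hive partitioning embeds the field names as well ("/alpha=0/beta=3").
directory = ds.partitioning(pa.schema([("group", pa.int64()),
                                       ("key", pa.float64())]))
hive = ds.partitioning(pa.schema([("alpha", pa.int64()),
                                  ("beta", pa.int64())]),
                       flavor="hive")

expr = directory.parse("/3/3.14")       # (group == 3) & (key == 3.14)
expr = hive.parse("/alpha=0/beta=3")    # (alpha == 0) & (beta == 3)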
Example 3
def test_fragments(tempdir):
    table, dataset = _create_dataset_for_fragments(tempdir)

    # list fragments
    fragments = list(dataset.get_fragments())
    assert len(fragments) == 2
    f = fragments[0]

    # file's schema does not include partition column
    phys_schema = f.schema.remove(f.schema.get_field_index('part'))
    assert f.format.inspect(f.path, f.filesystem) == phys_schema
    assert f.partition_expression.equals(ds.field('part') == 'a')

    # scanning fragment includes partition columns
    result = f.to_table()
    assert f.schema == result.schema
    assert result.column_names == ['f1', 'f2', 'part']
    assert len(result) == 4
    assert result.equals(table.slice(0, 4))

    # scanning fragments follow column projection
    fragments = list(dataset.get_fragments(columns=['f1', 'part']))
    assert len(fragments) == 2
    result = fragments[0].to_table()
    assert result.column_names == ['f1', 'part']
    assert len(result) == 4

    # scanning fragments follow filter predicate
    fragments = list(dataset.get_fragments(filter=ds.field('f1') < 2))
    assert len(fragments) == 2
    result = fragments[0].to_table()
    assert result.column_names == ['f1', 'f2', 'part']
    assert len(result) == 2
    result = fragments[1].to_table()
    assert len(result) == 0
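
The _create_dataset_for_fragments helper used by the fragment tests is not shown on this page. A minimal sketch of what it could look like, assuming a hive-partitioned parquet layout with columns f1, f2 and a partition column part (names taken from the assertions above):

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq


def _create_dataset_for_fragments(tempdir, chunk_size=None):
    # 8 rows split across two partition directories, part=a and part=b;
    # chunk_size controls the parquet row-group size used by the
    # row-group fragment tests further down.
    table = pa.table(
        [list(range(8)), [1] * 8, ['a'] * 4 + ['b'] * 4],
        names=['f1', 'f2', 'part'])
    path = str(tempdir / "test_parquet_dataset")
    pq.write_to_dataset(table, path, partition_cols=["part"],
                        chunk_size=chunk_size)
    dataset = ds.dataset(path, format="parquet", partitioning="hive")
    return table, dataset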
Example 4
def test_fragments(tempdir):
    table, dataset = _create_dataset_for_fragments(tempdir)

    # list fragments
    fragments = list(dataset.get_fragments())
    assert len(fragments) == 2
    f = fragments[0]

    physical_names = ['f1', 'f2']
    # file's schema does not include partition column
    assert f.physical_schema.names == physical_names
    assert f.format.inspect(f.path, f.filesystem) == f.physical_schema
    assert f.partition_expression.equals(ds.field('part') == 'a')

    # By default, the partition column is not part of the schema.
    result = f.to_table()
    assert result.column_names == physical_names
    assert result.equals(table.remove_column(2).slice(0, 4))

    # scanning fragment includes partition columns when given the proper
    # schema.
    result = f.to_table(schema=dataset.schema)
    assert result.column_names == ['f1', 'f2', 'part']
    assert result.equals(table.slice(0, 4))
    assert f.physical_schema == result.schema.remove(2)

    # scanning fragments follow filter predicate
    result = f.to_table(schema=dataset.schema, filter=ds.field('f1') < 2)
    assert result.column_names == ['f1', 'f2', 'part']
Example 5
    def _query(
        self,
        filename,
        filter_expr=None,
        instrument_ids=None,
        start=None,
        end=None,
        ts_column="ts_event_ns",
    ):
        filters = [filter_expr] if filter_expr is not None else []
        if instrument_ids is not None:
            if not isinstance(instrument_ids, list):
                instrument_ids = [instrument_ids]
            filters.append(
                ds.field("instrument_id").isin(list(set(instrument_ids))))
        if start is not None:
            filters.append(ds.field(ts_column) >= start)
        if end is not None:
            filters.append(ds.field(ts_column) <= end)

        dataset = ds.dataset(
            f"{self.root}/{filename}.parquet/",
            partitioning="hive",
            filesystem=self.fs,
        )
        df = (dataset.to_table(filter=combine_filters(
            *filters)).to_pandas().drop_duplicates())
        if "instrument_id" in df.columns:
            df = df.astype({"instrument_id": "category"})
        return df
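
combine_filters is external to this snippet; a plausible minimal implementation simply AND-combines the collected dataset expressions (hypothetical helper; ds.Dataset.to_table accepts filter=None, so an empty list maps to an unfiltered scan):

import functools
import operator


def combine_filters(*filters):
    # AND together any number of ds.Expression objects.
    if not filters:
        return None
    return functools.reduce(operator.and_, filters)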
Example 6
def get_dates_df(symbol: str, tick_type: str, start_date: str, end_date: str, source: str='local') -> pd.DataFrame:
    
    if source == 'local':
        ds = get_local_dataset(tick_type=tick_type, symbol=symbol)
    elif source == 's3':
        ds = get_s3_dataset(tick_type=tick_type, symbol=symbol)
    else:
        raise ValueError(f"source must be 'local' or 's3', got {source!r}")
    
    filter_exp = (field('date') >= start_date) & (field('date') <= end_date)
    return ds.to_table(filter=filter_exp).to_pandas()
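
Here field is presumably pyarrow.dataset.field, and get_local_dataset / get_s3_dataset are defined elsewhere. A hedged sketch of the local variant, assuming ticks stored as hive-partitioned parquet under a per-symbol directory (the path layout is illustrative):

import pyarrow.dataset as pads  # separate alias, since `ds` is rebound to a Dataset above


def get_local_dataset(tick_type: str, symbol: str) -> pads.Dataset:
    # Hypothetical layout: /data/ticks/<tick_type>/symbol=<symbol>/date=<date>/*.parquet
    return pads.dataset(
        f"/data/ticks/{tick_type}/symbol={symbol}",
        format="parquet",
        partitioning="hive",
    )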
Example 7
def test_fragments_reconstruct(tempdir):
    table, dataset = _create_dataset_for_fragments(tempdir)

    def assert_yields_projected(fragment, row_slice, columns):
        actual = fragment.to_table()
        assert actual.column_names == columns

        expected = table.slice(*row_slice).to_pandas()[[*columns]]
        assert actual.equals(pa.Table.from_pandas(expected))

    fragment = list(dataset.get_fragments())[0]
    parquet_format = fragment.format

    # manually re-construct a fragment, with explicit schema
    new_fragment = parquet_format.make_fragment(
        fragment.path,
        fragment.filesystem,
        schema=dataset.schema,
        partition_expression=fragment.partition_expression)
    assert new_fragment.to_table().equals(fragment.to_table())
    assert_yields_projected(new_fragment, (0, 4), table.column_names)

    # filter / column projection, inspected schema
    new_fragment = parquet_format.make_fragment(
        fragment.path,
        fragment.filesystem,
        columns=['f1'],
        filter=ds.field('f1') < 2,
        partition_expression=fragment.partition_expression)
    assert_yields_projected(new_fragment, (0, 2), ['f1'])

    # filter requiring cast / column projection, inspected schema
    new_fragment = parquet_format.make_fragment(
        fragment.path,
        fragment.filesystem,
        columns=['f1'],
        filter=ds.field('f1') < 2.0,
        partition_expression=fragment.partition_expression)
    assert_yields_projected(new_fragment, (0, 2), ['f1'])

    # filter on the partition column, explicit schema
    new_fragment = parquet_format.make_fragment(
        fragment.path,
        fragment.filesystem,
        schema=dataset.schema,
        filter=ds.field('part') == 'a',
        partition_expression=fragment.partition_expression)
    assert_yields_projected(new_fragment, (0, 4), table.column_names)

    # filter on the partition column, inspected schema
    with pytest.raises(ValueError, match="Field named 'part' not found"):
        new_fragment = parquet_format.make_fragment(
            fragment.path,
            fragment.filesystem,
            filter=ds.field('part') == 'a',
            partition_expression=fragment.partition_expression)
Example 8
    def _query(
        self,
        cls,
        filter_expr=None,
        instrument_ids=None,
        start=None,
        end=None,
        ts_column="ts_init",
        raise_on_empty=True,
        instrument_id_column="instrument_id",
        table_kwargs: Optional[Dict] = None,
        clean_instrument_keys=True,
        as_dataframe=True,
        **kwargs,
    ):
        filters = [filter_expr] if filter_expr is not None else []
        if instrument_ids is not None:
            if not isinstance(instrument_ids, list):
                instrument_ids = [instrument_ids]
            if clean_instrument_keys:
                instrument_ids = list(set(map(clean_key, instrument_ids)))
            filters.append(
                ds.field(instrument_id_column).cast("string").isin(
                    instrument_ids))
        if start is not None:
            filters.append(
                ds.field(ts_column) >= int(
                    pd.Timestamp(start).to_datetime64()))
        if end is not None:
            filters.append(
                ds.field(ts_column) <= int(pd.Timestamp(end).to_datetime64()))

        full_path = self._make_path(cls=cls)
        if not (self.fs.exists(full_path) or self.fs.isdir(full_path)):
            if raise_on_empty:
                raise FileNotFoundError(
                    f"protocol={self.fs.protocol}, path={full_path}")
            else:
                return pd.DataFrame() if as_dataframe else None

        dataset = ds.dataset(full_path,
                             partitioning="hive",
                             filesystem=self.fs)
        table = dataset.to_table(filter=combine_filters(*filters),
                                 **(table_kwargs or {}))
        mappings = self.load_inverse_mappings(path=full_path)
        if as_dataframe:
            return self._handle_table_dataframe(table=table,
                                                mappings=mappings,
                                                raise_on_empty=raise_on_empty,
                                                **kwargs)
        else:
            return self._handle_table_nautilus(table=table,
                                               cls=cls,
                                               mappings=mappings)
Example 9
    def test_data_catalog_generic_data(self):
        TestStubs.setup_news_event_persistence()
        process_files(
            glob_path=f"{TEST_DATA_DIR}/news_events.csv",
            reader=CSVReader(block_parser=TestStubs.news_event_parser),
            catalog=self.catalog,
        )
        df = self.catalog.generic_data(cls=NewsEventData, filter_expr=ds.field("currency") == "USD")
        assert len(df) == 22925
        data = self.catalog.generic_data(
            cls=NewsEventData, filter_expr=ds.field("currency") == "CHF", as_nautilus=True
        )
        assert len(data) == 2745 and isinstance(data[0], GenericData)
Example 10
def filter_by_time_period(parquet_partition_name: str,
                          start: int,
                          stop: int,
                          population_filter: list = None,
                          incl_attributes=False) -> Table:
    stop_missing: Expression = ~ds.field("stop_epoch_days").is_valid()
    start_epoch_le_start: Expression = ds.field('start_epoch_days') <= start
    start_epoch_ge_start: Expression = ds.field('start_epoch_days') >= start
    start_epoch_le_stop: Expression = ds.field('start_epoch_days') <= stop
    start_epoch_g_start: Expression = ds.field('start_epoch_days') > start
    stop_epoch_ge_start: Expression = ds.field('stop_epoch_days') >= start
    stop_epoch_le_stop: Expression = ds.field('stop_epoch_days') <= stop

    find_by_time_period_filter = (
        (start_epoch_le_start & stop_missing) |
        (start_epoch_le_start & stop_epoch_ge_start) |
        (start_epoch_ge_start & start_epoch_le_stop) |
        (start_epoch_g_start & stop_epoch_le_stop))
    if population_filter:
        population: Expression = ds.field("unit_id").isin(population_filter)
        find_by_time_period_filter = population & find_by_time_period_filter

    table = do_filter(find_by_time_period_filter, incl_attributes,
                      parquet_partition_name)
    return table
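
do_filter is defined elsewhere; a minimal sketch under the assumption that it scans the named parquet partition as a dataset, applies the (possibly None) filter expression, and drops attribute columns unless incl_attributes is set (the "attribute_" column prefix is a guess):

from typing import Optional

import pyarrow.dataset as ds
from pyarrow import Table
from pyarrow.dataset import Expression


def do_filter(filter_expr: Optional[Expression], incl_attributes: bool,
              parquet_partition_name: str) -> Table:
    dataset = ds.dataset(parquet_partition_name, format="parquet")
    table = dataset.to_table(filter=filter_expr)
    if not incl_attributes:
        # Keep only non-attribute columns (hypothetical naming convention).
        keep = [name for name in table.column_names
                if not name.startswith("attribute_")]
        table = table.select(keep)
    return table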
Example 11
def test_complex_expr():
    expr = Expressions.or_(
        Expressions.and_(Expressions.greater_than('a', 1),
                         Expressions.equal("b", "US")),
        Expressions.equal("c", True))

    translated_dataset_filter = get_dataset_filter(expr, {
        'a': 'a',
        'b': 'b',
        'c': 'c'
    })
    dataset_filter = (((ds.field("a") > 1) & (ds.field("b") == "US")) |
                      (ds.field("c") == True))  # noqa: E712
    assert dataset_filter.equals(translated_dataset_filter)
Example 12
def test_read_table_with_filter():
    table_path = "../rust/tests/data/delta-0.8.0-partitioned"
    dt = DeltaTable(table_path)
    expected = {
        "value": ["6", "7", "5"],
        "year": ["2021", "2021", "2021"],
        "month": ["12", "12", "12"],
        "day": ["20", "20", "4"],
    }
    filter_expr = (ds.field("year") == "2021") & (ds.field("month") == "12")

    dataset = dt.to_pyarrow_dataset()

    assert len(list(dataset.get_fragments(filter=filter_expr))) == 2
    assert dataset.to_table(filter=filter_expr).to_pydict() == expected
Example 13
def test_fragments_parquet_row_groups_reconstruct(tempdir):
    table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=2)

    fragment = list(dataset.get_fragments())[0]
    parquet_format = fragment.format
    row_group_fragments = list(fragment.get_row_group_fragments())

    # manually re-construct row group fragments
    new_fragment = parquet_format.make_fragment(
        fragment.path, fragment.filesystem,
        partition_expression=fragment.partition_expression,
        row_groups=[0])
    result = new_fragment.to_table()
    assert result.equals(row_group_fragments[0].to_table())

    # manually re-construct a row group fragment with filter/column projection
    new_fragment = parquet_format.make_fragment(
        fragment.path, fragment.filesystem,
        partition_expression=fragment.partition_expression,
        row_groups={1})
    result = new_fragment.to_table(columns=['f1', 'part'],
                                   filter=ds.field('f1') < 3)
    assert result.column_names == ['f1', 'part']
    assert len(result) == 1

    # out of bounds row group index
    new_fragment = parquet_format.make_fragment(
        fragment.path, fragment.filesystem,
        partition_expression=fragment.partition_expression,
        row_groups={2})
    with pytest.raises(IndexError, match="trying to scan row group 2"):
        new_fragment.to_table()

def test_without_partition_pruning():
    if skip:
        return
    rados_parquet_dataset = ds.dataset(
        "file:///mnt/cephfs/nyc/",
        format=ds.RadosParquetFileFormat("/etc/ceph/ceph.conf"))
    parquet_dataset = ds.dataset("file:///mnt/cephfs/nyc/", format="parquet")

    rados_parquet_df = rados_parquet_dataset.to_table(
        columns=['DOLocationID', 'total_amount', 'fare_amount'],
        filter=(ds.field('total_amount') > 200)).to_pandas()
    parquet_df = parquet_dataset.to_table(
        columns=['DOLocationID', 'total_amount', 'fare_amount'],
        filter=(ds.field('total_amount') > 200)).to_pandas()

    assert rados_parquet_df.equals(parquet_df)
Example 15
def dataset_batches(
    file_meta: FileMeta, fs: fsspec.AbstractFileSystem, n_rows: int
) -> Iterator[pd.DataFrame]:
    try:
        d: ds.Dataset = ds.dataset(file_meta.filename, filesystem=fs)
    except ArrowInvalid:
        return
    filter_expr = (ds.field("ts_init") >= file_meta.start) & (ds.field("ts_init") <= file_meta.end)
    scanner: ds.Scanner = d.scanner(filter=filter_expr, batch_size=n_rows)
    for batch in scanner.to_batches():
        if batch.num_rows == 0:
            break
        data = batch.to_pandas()
        if file_meta.instrument_id:
            data.loc[:, "instrument_id"] = file_meta.instrument_id
        yield data
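
FileMeta is not defined in this snippet; judging by the attributes accessed above, it could be a small dataclass along these lines (hypothetical):

from dataclasses import dataclass
from typing import Optional


@dataclass
class FileMeta:
    filename: str                        # path handed to ds.dataset()
    start: int                           # inclusive lower bound on ts_init
    end: int                             # inclusive upper bound on ts_init
    instrument_id: Optional[str] = None  # stamped onto each batch when set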
Example 16
def test_read_table_with_stats():
    table_path = "../rust/tests/data/COVID-19_NYT"
    dt = DeltaTable(table_path)
    dataset = dt.to_pyarrow_dataset()

    filter_expr = ds.field("date") > "2021-02-20"
    assert len(list(dataset.get_fragments(filter=filter_expr))) == 2

    data = dataset.to_table(filter=filter_expr)
    assert data.num_rows < 147181 + 47559

    filter_expr = ds.field("cases") < 0
    assert len(list(dataset.get_fragments(filter=filter_expr))) == 0

    data = dataset.to_table(filter=filter_expr)
    assert data.num_rows == 0
Example 17
    def test_catalog_generic_data_not_overwritten(self):
        # Arrange
        TestPersistenceStubs.setup_news_event_persistence()
        process_files(
            glob_path=f"{TEST_DATA_DIR}/news_events.csv",
            reader=CSVReader(block_parser=TestPersistenceStubs.news_event_parser),
            catalog=self.catalog,
        )
        objs = self.catalog.generic_data(
            cls=NewsEventData, filter_expr=ds.field("currency") == "USD", as_nautilus=True
        )

        # Clear the catalog again
        data_catalog_setup()
        self.catalog = DataCatalog.from_env()

        assert (
            len(self.catalog.generic_data(NewsEventData, raise_on_empty=False, as_nautilus=True))
            == 0
        )

        chunk1, chunk2 = objs[:10], objs[5:15]

        # Act, Assert
        write_objects(catalog=self.catalog, chunk=chunk1)
        assert len(self.catalog.generic_data(NewsEventData)) == 10

        write_objects(catalog=self.catalog, chunk=chunk2)
        assert len(self.catalog.generic_data(NewsEventData)) == 15
Example 18
def test_dataset(dataset):
    assert isinstance(dataset, ds.Dataset)
    assert isinstance(dataset.schema, pa.Schema)

    # TODO(kszucs): test non-boolean Exprs for filter do raise

    expected_i64 = pa.array([0, 1, 2, 3, 4], type=pa.int64())
    expected_f64 = pa.array([0, 1, 2, 3, 4], type=pa.float64())
    for task in dataset.scan():
        assert isinstance(task, ds.ScanTask)
        for batch in task.execute():
            assert batch.column(0).equals(expected_i64)
            assert batch.column(1).equals(expected_f64)

    batches = dataset.to_batches()
    assert all(isinstance(batch, pa.RecordBatch) for batch in batches)

    table = dataset.to_table()
    assert isinstance(table, pa.Table)
    assert len(table) == 10

    condition = ds.field('i64') == 1
    scanner = ds.Scanner(dataset, use_threads=True, filter=condition)
    result = scanner.to_table().to_pydict()

    # don't rely on the scanning order
    assert result['i64'] == [1, 1]
    assert result['f64'] == [1., 1.]
    assert sorted(result['group']) == [1, 2]
    assert sorted(result['key']) == ['xxx', 'yyy']
Example 19
def test_expression_construction():
    zero = ds.scalar(0)
    one = ds.scalar(1)
    true = ds.scalar(True)
    false = ds.scalar(False)
    string = ds.scalar("string")
    field = ds.field("field")

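    # The statements below only exercise the operators; the test passes as
    # long as building these expressions does not raise.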
    zero | one == string
    ~true == false
    for typ in ("bool", pa.bool_()):
        field.cast(typ) == true

    field.isin([1, 2])

    with pytest.raises(TypeError):
        field.isin(1)

    # operations with non-scalar values
    with pytest.raises(TypeError):
        field == [1]

    with pytest.raises(TypeError):
        field != {1}

    with pytest.raises(TypeError):
        field & [1]

    with pytest.raises(TypeError):
        field | [1]
Example 20
def test_expression_ergonomics():
    zero = ds.scalar(0)
    one = ds.scalar(1)
    true = ds.scalar(True)
    false = ds.scalar(False)
    string = ds.scalar("string")
    field = ds.field("field")

    assert one.equals(ds.ScalarExpression(1))
    assert zero.equals(ds.ScalarExpression(0))
    assert true.equals(ds.ScalarExpression(True))
    assert false.equals(ds.ScalarExpression(False))
    assert string.equals(ds.ScalarExpression("string"))
    assert field.equals(ds.FieldExpression("field"))

    expected = ds.AndExpression(ds.ScalarExpression(1), ds.ScalarExpression(0))
    for expr in [one & zero, 1 & zero, one & 0]:
        assert expr.equals(expected)

    expected = ds.OrExpression(ds.ScalarExpression(1), ds.ScalarExpression(0))
    for expr in [one | zero, 1 | zero, one | 0]:
        assert expr.equals(expected)

    comparison_ops = [
        (operator.eq, ds.CompareOperator.Equal),
        (operator.ne, ds.CompareOperator.NotEqual),
        (operator.ge, ds.CompareOperator.GreaterEqual),
        (operator.le, ds.CompareOperator.LessEqual),
        (operator.lt, ds.CompareOperator.Less),
        (operator.gt, ds.CompareOperator.Greater),
    ]
    for op, compare_op in comparison_ops:
        expr = op(zero, one)
        expected = ds.ComparisonExpression(compare_op, zero, one)
        assert expr.equals(expected)

    expr = ~true == false
    expected = ds.ComparisonExpression(
        ds.CompareOperator.Equal,
        ds.NotExpression(ds.ScalarExpression(True)),
        ds.ScalarExpression(False)
    )
    assert expr.equals(expected)

    for typ in ("bool", pa.bool_()):
        expr = field.cast(typ) == true
        expected = ds.ComparisonExpression(
            ds.CompareOperator.Equal,
            ds.CastExpression(ds.FieldExpression("field"), pa.bool_()),
            ds.ScalarExpression(True)
        )
        assert expr.equals(expected)

    expr = field.isin([1, 2])
    expected = ds.InExpression(ds.FieldExpression("field"), pa.array([1, 2]))
    assert expr.equals(expected)

    with pytest.raises(TypeError):
        field.isin(1)
Example 21
def filter_by_time(parquet_partition_name: str,
                   date: int,
                   population_filter: list = None,
                   incl_attributes=False) -> Table:
    stop_missing: Expression = ~ds.field("stop_epoch_days").is_valid()
    start_epoch_le_date: Expression = ds.field('start_epoch_days') <= date
    stop_epoch_ge_date: Expression = ds.field('stop_epoch_days') >= date

    find_by_time_filter = ((start_epoch_le_date & stop_missing) |
                           (start_epoch_le_date & stop_epoch_ge_date))
    if population_filter:
        population: Expression = ds.field("unit_id").isin(population_filter)
        find_by_time_filter = population & find_by_time_filter

    table = do_filter(find_by_time_filter, incl_attributes,
                      parquet_partition_name)
    return table
Example 22
def test_filter_implicit_cast(tempdir):
    # ARROW-7652
    table = pa.table({'a': pa.array([0, 1, 2, 3, 4, 5], type=pa.int8())})
    _, path = _create_single_file(tempdir, table)
    dataset = ds.dataset(str(path))

    filter_ = ds.field('a') > 2
    assert len(dataset.to_table(filter=filter_)) == 3
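
_create_single_file is a fixture helper that is not shown here; a minimal sketch that writes the table to a single parquet file and returns it together with the path (signature inferred from the call above, details hypothetical):

import pyarrow.parquet as pq


def _create_single_file(tempdir, table):
    # Write `table` to one parquet file under the pytest tmp directory.
    path = tempdir / "single_file.parquet"
    pq.write_table(table, str(path))
    return table, path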
Example 23
    def test_data_catalog_filter(self):
        # Arrange, Act
        deltas = self.catalog.order_book_deltas()
        filtered_deltas = self.catalog.order_book_deltas(filter_expr=ds.field("action") == "DELETE")

        # Assert
        assert len(deltas) == 2384
        assert len(filtered_deltas) == 351
Example 24
def test_fragments_parquet_row_groups(tempdir):
    table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=2)

    fragment = list(dataset.get_fragments())[0]

    # list and scan row group fragments
    row_group_fragments = list(fragment.get_row_group_fragments())
    assert len(row_group_fragments) == 2
    result = row_group_fragments[0].to_table(schema=dataset.schema)
    assert result.column_names == ['f1', 'f2', 'part']
    assert len(result) == 2
    assert result.equals(table.slice(0, 2))

    fragment = list(dataset.get_fragments(filter=ds.field('f1') < 1))[0]
    row_group_fragments = list(fragment.get_row_group_fragments())
    assert len(row_group_fragments) == 1
    result = row_group_fragments[0].to_table(filter=ds.field('f1') < 1)
    assert len(result) == 1
Example 25
def filter_by_fixed(parquet_partition_name: str,
                    population_filter: list = None,
                    incl_attributes=False) -> Table:
    if population_filter:
        fixed_filter: Expression = ds.field("unit_id").isin(population_filter)
        table = do_filter(fixed_filter, incl_attributes,
                          parquet_partition_name)
    else:
        table = do_filter(None, incl_attributes, parquet_partition_name)
    return table
Example 26
def test_filesystem_dataset(mockfs):
    schema = pa.schema([pa.field('const', pa.int64())])

    file_format = ds.ParquetFileFormat()

    paths = ['subdir/1/xxx/file0.parquet', 'subdir/2/yyy/file1.parquet']
    partitions = [ds.ScalarExpression(True), ds.ScalarExpression(True)]

    dataset = ds.FileSystemDataset(schema,
                                   root_partition=None,
                                   file_format=file_format,
                                   filesystem=mockfs,
                                   paths_or_selector=paths,
                                   partitions=partitions)
    assert isinstance(dataset.format, ds.ParquetFileFormat)

    root_partition = ds.ComparisonExpression(ds.CompareOperator.Equal,
                                             ds.FieldExpression('level'),
                                             ds.ScalarExpression(1337))
    partitions = [
        ds.ComparisonExpression(ds.CompareOperator.Equal,
                                ds.FieldExpression('part'),
                                ds.ScalarExpression(1)),
        ds.ComparisonExpression(ds.CompareOperator.Equal,
                                ds.FieldExpression('part'),
                                ds.ScalarExpression(2))
    ]
    dataset = ds.FileSystemDataset(paths_or_selector=paths,
                                   schema=schema,
                                   root_partition=root_partition,
                                   filesystem=mockfs,
                                   partitions=partitions,
                                   file_format=file_format)
    assert dataset.partition_expression.equals(root_partition)
    assert set(dataset.files) == set(paths)

    fragments = list(dataset.get_fragments())
    for fragment, partition, path in zip(fragments, partitions, paths):
        assert fragment.partition_expression.equals(
            ds.AndExpression(root_partition, partition))
        assert fragment.path == path
        assert isinstance(fragment, ds.ParquetFileFragment)
        assert fragment.row_groups is None

        row_group_fragments = list(fragment.get_row_group_fragments())
        assert len(row_group_fragments) == 1
        assert isinstance(fragment, ds.ParquetFileFragment)
        assert row_group_fragments[0].path == path
        assert row_group_fragments[0].row_groups == {0}

    # test predicate pushdown using row group metadata
    fragments = list(dataset.get_fragments(filter=ds.field("const") == 0))
    assert len(fragments) == 2
    assert len(list(fragments[0].get_row_group_fragments())) == 1
    assert len(list(fragments[1].get_row_group_fragments())) == 0
Example 27
def test_fragments_implicit_cast(tempdir):
    # ARROW-8693
    import pyarrow.parquet as pq

    table = pa.table([range(8), [1] * 4 + [2] * 4], names=['col', 'part'])
    path = str(tempdir / "test_parquet_dataset")
    pq.write_to_dataset(table, path, partition_cols=["part"])

    part = ds.partitioning(pa.schema([('part', 'int8')]), flavor="hive")
    dataset = ds.dataset(path, format="parquet", partitioning=part)
    fragments = dataset.get_fragments(filter=ds.field("part") >= 2)
    assert len(list(fragments)) == 1
Example 28
def test_expression_serialization():
    a = ds.scalar(1)
    b = ds.scalar(1.1)
    c = ds.scalar(True)
    d = ds.scalar("string")
    e = ds.scalar(None)

    condition = ds.field('i64') > 5
    schema = pa.schema([
        pa.field('i64', pa.int64()),
        pa.field('f64', pa.float64())
    ])
    assert condition.validate(schema) == pa.bool_()

    assert condition.assume(ds.field('i64') == 5).equals(
        ds.scalar(False))

    assert condition.assume(ds.field('i64') == 7).equals(
        ds.scalar(True))

    all_exprs = [a, b, c, d, e, a == b, a > b, a & b, a | b, ~c,
                 d.is_valid(), a.cast(pa.int32(), safe=False),
                 a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
                 ds.field('i64') > 5, ds.field('i64') == 5,
                 ds.field('i64') == 7]
    for expr in all_exprs:
        assert isinstance(expr, ds.Expression)
        restored = pickle.loads(pickle.dumps(expr))
        assert expr.equals(restored)
Example 29
def main(catalog: DataCatalog):
    """Rename match_id to trade_id in TradeTick"""
    fs: fsspec.AbstractFileSystem = catalog.fs

    print("Loading instrument ids")
    instrument_ids = catalog.query(TradeTick,
                                   table_kwargs={"columns": ["instrument_id"]
                                                 })["instrument_id"].unique()

    tmp_catalog = DataCatalog(str(catalog.path) + "_tmp")
    tmp_catalog.fs = catalog.fs

    for ins_id in tqdm(instrument_ids):

        # Load trades for instrument
        trades = catalog.trade_ticks(
            instrument_ids=[ins_id],
            projections={"trade_id": ds.field("match_id")},
            as_nautilus=True,
        )

        # Create temp parquet in case of error
        fs.move(
            f"{catalog.path}/data/trade_tick.parquet/instrument_id={ins_id}",
            f"{catalog.path}/data/trade_tick.parquet_tmp/instrument_id={ins_id}",
            recursive=True,
        )

        try:
            # Rewrite to new catalog
            write_objects(tmp_catalog, trades)

            # Ensure we can query again
            _ = tmp_catalog.trade_ticks(instrument_ids=[ins_id],
                                        as_nautilus=True)

            # Clear temp parquet
            fs.rm(
                f"{catalog.path}/data/trade_tick.parquet_tmp/instrument_id={ins_id}",
                recursive=True)

        except Exception:
            warnings.warn(f"Failed to write or read instrument_id {ins_id}")
            fs.move(
                f"{catalog.path}/data/trade_tick.parquet_tmp/instrument_id={ins_id}",
                f"{catalog.path}/data/trade_tick.parquet/instrument_id={ins_id}",
                recursive=True,
            )
Example 30
    def test_data_catalog_query_filtered(self):
        ticks = self.catalog.trade_ticks()
        assert len(ticks) == 312

        ticks = self.catalog.trade_ticks(start="2019-12-20 20:56:18")
        assert len(ticks) == 123

        ticks = self.catalog.trade_ticks(start=1576875378384999936)
        assert len(ticks) == 123

        ticks = self.catalog.trade_ticks(start=datetime.datetime(2019, 12, 20, 20, 56, 18))
        assert len(ticks) == 123

        deltas = self.catalog.order_book_deltas()
        assert len(deltas) == 2384

        filtered_deltas = self.catalog.order_book_deltas(filter_expr=ds.field("action") == "DELETE")
        assert len(filtered_deltas) == 351