def test_partitioning():
    schema = pa.schema(
        [pa.field('i64', pa.int64()), pa.field('f64', pa.float64())])
    for klass in [ds.DirectoryPartitioning, ds.HivePartitioning]:
        partitioning = klass(schema)
        assert isinstance(partitioning, ds.Partitioning)

    partitioning = ds.DirectoryPartitioning(
        pa.schema(
            [pa.field('group', pa.int64()), pa.field('key', pa.float64())]))
    expr = partitioning.parse('/3/3.14')
    assert isinstance(expr, ds.Expression)
    expected = (ds.field('group') == 3) & (ds.field('key') == 3.14)
    assert expr.equals(expected)

    with pytest.raises(pa.ArrowInvalid):
        partitioning.parse('/prefix/3/aaa')

    partitioning = ds.HivePartitioning(
        pa.schema(
            [pa.field('alpha', pa.int64()), pa.field('beta', pa.int64())]))
    expr = partitioning.parse('/alpha=0/beta=3')
    expected = ((ds.field('alpha') == ds.scalar(0)) &
                (ds.field('beta') == ds.scalar(3)))
    assert expr.equals(expected)
def test_with_partition_pruning():
    if skip:
        return

    filter_expression = ((ds.field('tip_amount') > 10) &
                         (ds.field('payment_type') > 2) &
                         (ds.field('VendorID') > 1))
    projection_cols = ['payment_type', 'tip_amount', 'VendorID']
    partitioning = ds.partitioning(
        pa.schema([("payment_type", pa.int32()), ("VendorID", pa.int32())]),
        flavor="hive")

    rados_parquet_dataset = ds.dataset(
        "file:///mnt/cephfs/nyc/",
        partitioning=partitioning,
        format=ds.RadosParquetFileFormat("/etc/ceph/ceph.conf"))
    parquet_dataset = ds.dataset("file:///mnt/cephfs/nyc/",
                                 partitioning=partitioning,
                                 format="parquet")

    rados_parquet_df = rados_parquet_dataset.to_table(
        columns=projection_cols, filter=filter_expression).to_pandas()
    parquet_df = parquet_dataset.to_table(
        columns=projection_cols, filter=filter_expression).to_pandas()

    assert rados_parquet_df.equals(parquet_df)
def test_fragments(tempdir):
    table, dataset = _create_dataset_for_fragments(tempdir)

    # list fragments
    fragments = list(dataset.get_fragments())
    assert len(fragments) == 2
    f = fragments[0]

    # file's schema does not include partition column
    phys_schema = f.schema.remove(f.schema.get_field_index('part'))
    assert f.format.inspect(f.path, f.filesystem) == phys_schema
    assert f.partition_expression.equals(ds.field('part') == 'a')

    # scanning fragment includes partition columns
    result = f.to_table()
    assert f.schema == result.schema
    assert result.column_names == ['f1', 'f2', 'part']
    assert len(result) == 4
    assert result.equals(table.slice(0, 4))

    # scanning fragments follow column projection
    fragments = list(dataset.get_fragments(columns=['f1', 'part']))
    assert len(fragments) == 2
    result = fragments[0].to_table()
    assert result.column_names == ['f1', 'part']
    assert len(result) == 4

    # scanning fragments follow filter predicate
    fragments = list(dataset.get_fragments(filter=ds.field('f1') < 2))
    assert len(fragments) == 2
    result = fragments[0].to_table()
    assert result.column_names == ['f1', 'f2', 'part']
    assert len(result) == 2
    result = fragments[1].to_table()
    assert len(result) == 0
def test_fragments(tempdir):
    table, dataset = _create_dataset_for_fragments(tempdir)

    # list fragments
    fragments = list(dataset.get_fragments())
    assert len(fragments) == 2
    f = fragments[0]

    physical_names = ['f1', 'f2']
    # file's schema does not include partition column
    assert f.physical_schema.names == physical_names
    assert f.format.inspect(f.path, f.filesystem) == f.physical_schema
    assert f.partition_expression.equals(ds.field('part') == 'a')

    # By default, the partition column is not part of the schema.
    result = f.to_table()
    assert result.column_names == physical_names
    assert result.equals(table.remove_column(2).slice(0, 4))

    # scanning fragment includes partition columns when given the proper
    # schema.
    result = f.to_table(schema=dataset.schema)
    assert result.column_names == ['f1', 'f2', 'part']
    assert result.equals(table.slice(0, 4))
    assert f.physical_schema == result.schema.remove(2)

    # scanning fragments follow filter predicate
    result = f.to_table(schema=dataset.schema, filter=ds.field('f1') < 2)
    assert result.column_names == ['f1', 'f2', 'part']
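The fragment tests in this collection rely on a `_create_dataset_for_fragments` helper that is not included here. A minimal sketch consistent with their assertions (8 rows, physical columns `f1`/`f2`, and a hive partition column `part` with values 'a' and 'b', 4 rows each) might look like the following; the exact helper in the original suite may differ.

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq


def _create_dataset_for_fragments(tempdir, chunk_size=None):
    # Hypothetical helper: build an 8-row table and write it as a
    # hive-partitioned parquet dataset split on the `part` column.
    table = pa.table(
        [range(8), [1] * 8, ['a'] * 4 + ['b'] * 4],
        names=['f1', 'f2', 'part'])
    path = str(tempdir / "test_parquet_dataset")
    # `row_group_size` controls how many rows land in each parquet row group,
    # which the row-group fragment tests below exercise via chunk_size=2.
    pq.write_to_dataset(table, path, partition_cols=["part"],
                        row_group_size=chunk_size)
    dataset = ds.dataset(path, format="parquet", partitioning="hive")
    return table, dataset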
def _query(
    self,
    filename,
    filter_expr=None,
    instrument_ids=None,
    start=None,
    end=None,
    ts_column="ts_event_ns",
):
    filters = [filter_expr] if filter_expr is not None else []
    if instrument_ids is not None:
        if not isinstance(instrument_ids, list):
            instrument_ids = [instrument_ids]
        filters.append(
            ds.field("instrument_id").isin(list(set(instrument_ids))))
    if start is not None:
        filters.append(ds.field(ts_column) >= start)
    if end is not None:
        filters.append(ds.field(ts_column) <= end)

    dataset = ds.dataset(
        f"{self.root}/{filename}.parquet/",
        partitioning="hive",
        filesystem=self.fs,
    )
    df = (dataset.to_table(filter=combine_filters(*filters))
          .to_pandas()
          .drop_duplicates())
    if "instrument_id" in df.columns:
        df = df.astype({"instrument_id": "category"})
    return df
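Several of the query helpers above and below call a `combine_filters` function that is not defined in this collection. A plausible sketch, assuming it simply AND-combines whatever expressions are passed in and tolerates an empty list:

import pyarrow.dataset as ds


def combine_filters(*filters):
    # Hypothetical helper: drop None entries, then AND the remaining
    # dataset expressions into a single filter (None means "no filter").
    filters = tuple(f for f in filters if f is not None)
    if len(filters) == 0:
        return None
    expr = filters[0]
    for f in filters[1:]:
        expr = expr & f
    return expr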
def get_dates_df(symbol: str, tick_type: str, start_date: str,
                 end_date: str, source: str = 'local') -> pd.DataFrame:
    if source == 'local':
        ds = get_local_dataset(tick_type=tick_type, symbol=symbol)
    elif source == 's3':
        ds = get_s3_dataset(tick_type=tick_type, symbol=symbol)
    filter_exp = (field('date') >= start_date) & (field('date') <= end_date)
    return ds.to_table(filter=filter_exp).to_pandas()
def test_fragments_reconstruct(tempdir):
    table, dataset = _create_dataset_for_fragments(tempdir)

    def assert_yields_projected(fragment, row_slice, columns):
        actual = fragment.to_table()
        assert actual.column_names == columns
        expected = table.slice(*row_slice).to_pandas()[[*columns]]
        assert actual.equals(pa.Table.from_pandas(expected))

    fragment = list(dataset.get_fragments())[0]
    parquet_format = fragment.format

    # manually re-construct a fragment, with explicit schema
    new_fragment = parquet_format.make_fragment(
        fragment.path, fragment.filesystem, schema=dataset.schema,
        partition_expression=fragment.partition_expression)
    assert new_fragment.to_table().equals(fragment.to_table())
    assert_yields_projected(new_fragment, (0, 4), table.column_names)

    # filter / column projection, inspected schema
    new_fragment = parquet_format.make_fragment(
        fragment.path, fragment.filesystem,
        columns=['f1'], filter=ds.field('f1') < 2,
        partition_expression=fragment.partition_expression)
    assert_yields_projected(new_fragment, (0, 2), ['f1'])

    # filter requiring cast / column projection, inspected schema
    new_fragment = parquet_format.make_fragment(
        fragment.path, fragment.filesystem,
        columns=['f1'], filter=ds.field('f1') < 2.0,
        partition_expression=fragment.partition_expression)
    assert_yields_projected(new_fragment, (0, 2), ['f1'])

    # filter on the partition column, explicit schema
    new_fragment = parquet_format.make_fragment(
        fragment.path, fragment.filesystem, schema=dataset.schema,
        filter=ds.field('part') == 'a',
        partition_expression=fragment.partition_expression)
    assert_yields_projected(new_fragment, (0, 4), table.column_names)

    # filter on the partition column, inspected schema
    with pytest.raises(ValueError, match="Field named 'part' not found"):
        new_fragment = parquet_format.make_fragment(
            fragment.path, fragment.filesystem,
            filter=ds.field('part') == 'a',
            partition_expression=fragment.partition_expression)
def _query(
    self,
    cls,
    filter_expr=None,
    instrument_ids=None,
    start=None,
    end=None,
    ts_column="ts_init",
    raise_on_empty=True,
    instrument_id_column="instrument_id",
    table_kwargs: Optional[Dict] = None,
    clean_instrument_keys=True,
    as_dataframe=True,
    **kwargs,
):
    filters = [filter_expr] if filter_expr is not None else []
    if instrument_ids is not None:
        if not isinstance(instrument_ids, list):
            instrument_ids = [instrument_ids]
        if clean_instrument_keys:
            instrument_ids = list(set(map(clean_key, instrument_ids)))
        filters.append(
            ds.field(instrument_id_column).cast("string").isin(instrument_ids))
    if start is not None:
        filters.append(
            ds.field(ts_column) >= int(pd.Timestamp(start).to_datetime64()))
    if end is not None:
        filters.append(
            ds.field(ts_column) <= int(pd.Timestamp(end).to_datetime64()))

    full_path = self._make_path(cls=cls)
    if not (self.fs.exists(full_path) or self.fs.isdir(full_path)):
        if raise_on_empty:
            raise FileNotFoundError(
                f"protocol={self.fs.protocol}, path={full_path}")
        else:
            return pd.DataFrame() if as_dataframe else None

    dataset = ds.dataset(full_path, partitioning="hive", filesystem=self.fs)
    table = dataset.to_table(filter=combine_filters(*filters),
                             **(table_kwargs or {}))
    mappings = self.load_inverse_mappings(path=full_path)
    if as_dataframe:
        return self._handle_table_dataframe(
            table=table, mappings=mappings,
            raise_on_empty=raise_on_empty, **kwargs)
    else:
        return self._handle_table_nautilus(
            table=table, cls=cls, mappings=mappings)
def test_data_catalog_generic_data(self):
    TestStubs.setup_news_event_persistence()
    process_files(
        glob_path=f"{TEST_DATA_DIR}/news_events.csv",
        reader=CSVReader(block_parser=TestStubs.news_event_parser),
        catalog=self.catalog,
    )
    df = self.catalog.generic_data(
        cls=NewsEventData, filter_expr=ds.field("currency") == "USD")
    assert len(df) == 22925
    data = self.catalog.generic_data(
        cls=NewsEventData,
        filter_expr=ds.field("currency") == "CHF",
        as_nautilus=True,
    )
    assert len(data) == 2745 and isinstance(data[0], GenericData)
def filter_by_time_period(parquet_partition_name: str, start: int, stop: int,
                          population_filter: list = None,
                          incl_attributes=False) -> Table:
    stop_missing: Expression = ~ds.field("stop_epoch_days").is_valid()
    start_epoch_le_start: Expression = ds.field('start_epoch_days') <= start
    start_epoch_ge_start: Expression = ds.field('start_epoch_days') >= start
    start_epoch_le_stop: Expression = ds.field('start_epoch_days') <= stop
    start_epoch_g_start: Expression = ds.field('start_epoch_days') > start
    stop_epoch_ge_start: Expression = ds.field('stop_epoch_days') >= start
    stop_epoch_le_stop: Expression = ds.field('stop_epoch_days') <= stop
    find_by_time_period_filter = (
        (start_epoch_le_start & stop_missing) |
        (start_epoch_le_start & stop_epoch_ge_start) |
        (start_epoch_ge_start & start_epoch_le_stop) |
        (start_epoch_g_start & stop_epoch_le_stop))
    if population_filter:
        population: Expression = ds.field("unit_id").isin(population_filter)
        find_by_time_period_filter = population & find_by_time_period_filter
    table = do_filter(find_by_time_period_filter, incl_attributes,
                      parquet_partition_name)
    return table
def test_complex_expr():
    expr = Expressions.or_(
        Expressions.and_(Expressions.greater_than('a', 1),
                         Expressions.equal("b", "US")),
        Expressions.equal("c", True))
    translated_dataset_filter = get_dataset_filter(expr, {
        'a': 'a',
        'b': 'b',
        'c': 'c'
    })
    dataset_filter = (((ds.field("a") > 1) & (ds.field("b") == "US")) |
                      (ds.field("c") == True))  # noqa: E712
    assert dataset_filter.equals(translated_dataset_filter)
def test_read_table_with_filter():
    table_path = "../rust/tests/data/delta-0.8.0-partitioned"
    dt = DeltaTable(table_path)
    expected = {
        "value": ["6", "7", "5"],
        "year": ["2021", "2021", "2021"],
        "month": ["12", "12", "12"],
        "day": ["20", "20", "4"],
    }
    filter_expr = (ds.field("year") == "2021") & (ds.field("month") == "12")

    dataset = dt.to_pyarrow_dataset()

    assert len(list(dataset.get_fragments(filter=filter_expr))) == 2
    assert dataset.to_table(filter=filter_expr).to_pydict() == expected
def test_fragments_parquet_row_groups_reconstruct(tempdir):
    table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=2)

    fragment = list(dataset.get_fragments())[0]
    parquet_format = fragment.format
    row_group_fragments = list(fragment.get_row_group_fragments())

    # manually re-construct row group fragments
    new_fragment = parquet_format.make_fragment(
        fragment.path, fragment.filesystem,
        partition_expression=fragment.partition_expression,
        row_groups=[0])
    result = new_fragment.to_table()
    assert result.equals(row_group_fragments[0].to_table())

    # manually re-construct a row group fragment with filter/column projection
    new_fragment = parquet_format.make_fragment(
        fragment.path, fragment.filesystem,
        partition_expression=fragment.partition_expression,
        row_groups={1})
    result = new_fragment.to_table(columns=['f1', 'part'],
                                   filter=ds.field('f1') < 3)
    assert result.column_names == ['f1', 'part']
    assert len(result) == 1

    # out of bounds row group index
    new_fragment = parquet_format.make_fragment(
        fragment.path, fragment.filesystem,
        partition_expression=fragment.partition_expression,
        row_groups={2})
    with pytest.raises(IndexError, match="trying to scan row group 2"):
        new_fragment.to_table()
def test_without_partition_pruning():
    if skip:
        return

    rados_parquet_dataset = ds.dataset(
        "file:///mnt/cephfs/nyc/",
        format=ds.RadosParquetFileFormat("/etc/ceph/ceph.conf"))
    parquet_dataset = ds.dataset("file:///mnt/cephfs/nyc/", format="parquet")

    rados_parquet_df = rados_parquet_dataset.to_table(
        columns=['DOLocationID', 'total_amount', 'fare_amount'],
        filter=(ds.field('total_amount') > 200)).to_pandas()
    parquet_df = parquet_dataset.to_table(
        columns=['DOLocationID', 'total_amount', 'fare_amount'],
        filter=(ds.field('total_amount') > 200)).to_pandas()

    assert rados_parquet_df.equals(parquet_df)
def dataset_batches(
    file_meta: FileMeta, fs: fsspec.AbstractFileSystem, n_rows: int
) -> Iterator[pd.DataFrame]:
    try:
        d: ds.Dataset = ds.dataset(file_meta.filename, filesystem=fs)
    except ArrowInvalid:
        return
    filter_expr = (ds.field("ts_init") >= file_meta.start) & (
        ds.field("ts_init") <= file_meta.end
    )
    scanner: ds.Scanner = d.scanner(filter=filter_expr, batch_size=n_rows)
    for batch in scanner.to_batches():
        if batch.num_rows == 0:
            break
        data = batch.to_pandas()
        if file_meta.instrument_id:
            data.loc[:, "instrument_id"] = file_meta.instrument_id
        yield data
def test_read_table_with_stats():
    table_path = "../rust/tests/data/COVID-19_NYT"
    dt = DeltaTable(table_path)
    dataset = dt.to_pyarrow_dataset()

    filter_expr = ds.field("date") > "2021-02-20"
    assert len(list(dataset.get_fragments(filter=filter_expr))) == 2

    data = dataset.to_table(filter=filter_expr)
    assert data.num_rows < 147181 + 47559

    filter_expr = ds.field("cases") < 0
    assert len(list(dataset.get_fragments(filter=filter_expr))) == 0

    data = dataset.to_table(filter=filter_expr)
    assert data.num_rows == 0
def test_catalog_generic_data_not_overwritten(self):
    # Arrange
    TestPersistenceStubs.setup_news_event_persistence()
    process_files(
        glob_path=f"{TEST_DATA_DIR}/news_events.csv",
        reader=CSVReader(block_parser=TestPersistenceStubs.news_event_parser),
        catalog=self.catalog,
    )
    objs = self.catalog.generic_data(
        cls=NewsEventData,
        filter_expr=ds.field("currency") == "USD",
        as_nautilus=True,
    )

    # Clear the catalog again
    data_catalog_setup()
    self.catalog = DataCatalog.from_env()

    assert (
        len(self.catalog.generic_data(NewsEventData, raise_on_empty=False,
                                      as_nautilus=True)) == 0
    )

    chunk1, chunk2 = objs[:10], objs[5:15]

    # Act, Assert
    write_objects(catalog=self.catalog, chunk=chunk1)
    assert len(self.catalog.generic_data(NewsEventData)) == 10

    write_objects(catalog=self.catalog, chunk=chunk2)
    assert len(self.catalog.generic_data(NewsEventData)) == 15
def test_dataset(dataset):
    assert isinstance(dataset, ds.Dataset)
    assert isinstance(dataset.schema, pa.Schema)

    # TODO(kszucs): test non-boolean Exprs for filter do raise
    expected_i64 = pa.array([0, 1, 2, 3, 4], type=pa.int64())
    expected_f64 = pa.array([0, 1, 2, 3, 4], type=pa.float64())
    for task in dataset.scan():
        assert isinstance(task, ds.ScanTask)
        for batch in task.execute():
            assert batch.column(0).equals(expected_i64)
            assert batch.column(1).equals(expected_f64)

    batches = dataset.to_batches()
    assert all(isinstance(batch, pa.RecordBatch) for batch in batches)

    table = dataset.to_table()
    assert isinstance(table, pa.Table)
    assert len(table) == 10

    condition = ds.field('i64') == 1
    scanner = ds.Scanner(dataset, use_threads=True, filter=condition)
    result = scanner.to_table().to_pydict()

    # don't rely on the scanning order
    assert result['i64'] == [1, 1]
    assert result['f64'] == [1., 1.]
    assert sorted(result['group']) == [1, 2]
    assert sorted(result['key']) == ['xxx', 'yyy']
def test_expression_construction():
    zero = ds.scalar(0)
    one = ds.scalar(1)
    true = ds.scalar(True)
    false = ds.scalar(False)
    string = ds.scalar("string")
    field = ds.field("field")

    zero | one == string
    ~true == false
    for typ in ("bool", pa.bool_()):
        field.cast(typ) == true

    field.isin([1, 2])

    with pytest.raises(TypeError):
        field.isin(1)

    # operations with non-scalar values
    with pytest.raises(TypeError):
        field == [1]

    with pytest.raises(TypeError):
        field != {1}

    with pytest.raises(TypeError):
        field & [1]

    with pytest.raises(TypeError):
        field | [1]
def test_expression_ergonomics():
    zero = ds.scalar(0)
    one = ds.scalar(1)
    true = ds.scalar(True)
    false = ds.scalar(False)
    string = ds.scalar("string")
    field = ds.field("field")

    assert one.equals(ds.ScalarExpression(1))
    assert zero.equals(ds.ScalarExpression(0))
    assert true.equals(ds.ScalarExpression(True))
    assert false.equals(ds.ScalarExpression(False))
    assert string.equals(ds.ScalarExpression("string"))
    assert field.equals(ds.FieldExpression("field"))

    expected = ds.AndExpression(ds.ScalarExpression(1), ds.ScalarExpression(0))
    for expr in [one & zero, 1 & zero, one & 0]:
        assert expr.equals(expected)

    expected = ds.OrExpression(ds.ScalarExpression(1), ds.ScalarExpression(0))
    for expr in [one | zero, 1 | zero, one | 0]:
        assert expr.equals(expected)

    comparison_ops = [
        (operator.eq, ds.CompareOperator.Equal),
        (operator.ne, ds.CompareOperator.NotEqual),
        (operator.ge, ds.CompareOperator.GreaterEqual),
        (operator.le, ds.CompareOperator.LessEqual),
        (operator.lt, ds.CompareOperator.Less),
        (operator.gt, ds.CompareOperator.Greater),
    ]
    for op, compare_op in comparison_ops:
        expr = op(zero, one)
        expected = ds.ComparisonExpression(compare_op, zero, one)
        assert expr.equals(expected)

    expr = ~true == false
    expected = ds.ComparisonExpression(
        ds.CompareOperator.Equal,
        ds.NotExpression(ds.ScalarExpression(True)),
        ds.ScalarExpression(False),
    )
    assert expr.equals(expected)

    for typ in ("bool", pa.bool_()):
        expr = field.cast(typ) == true
        expected = ds.ComparisonExpression(
            ds.CompareOperator.Equal,
            ds.CastExpression(ds.FieldExpression("field"), pa.bool_()),
            ds.ScalarExpression(True),
        )
        assert expr.equals(expected)

    expr = field.isin([1, 2])
    expected = ds.InExpression(ds.FieldExpression("field"), pa.array([1, 2]))
    assert expr.equals(expected)

    with pytest.raises(TypeError):
        field.isin(1)
def filter_by_time(parquet_partition_name: str, date: int,
                   population_filter: list = None,
                   incl_attributes=False) -> Table:
    stop_missing: Expression = ~ds.field("stop_epoch_days").is_valid()
    start_epoch_le_date: Expression = ds.field('start_epoch_days') <= date
    stop_epoch_ge_date: Expression = ds.field('stop_epoch_days') >= date
    find_by_time_filter = ((start_epoch_le_date & stop_missing) |
                           (start_epoch_le_date & stop_epoch_ge_date))
    if population_filter:
        population: Expression = ds.field("unit_id").isin(population_filter)
        find_by_time_filter = population & find_by_time_filter
    table = do_filter(find_by_time_filter, incl_attributes,
                      parquet_partition_name)
    return table
def test_filter_implicit_cast(tempdir):
    # ARROW-7652
    table = pa.table({'a': pa.array([0, 1, 2, 3, 4, 5], type=pa.int8())})
    _, path = _create_single_file(tempdir, table)
    dataset = ds.dataset(str(path))

    filter_ = ds.field('a') > 2
    assert len(dataset.to_table(filter=filter_)) == 3
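This test also assumes a `_create_single_file` helper that is not shown. A minimal sketch, assuming it just writes the given table to one parquet file under the temp directory and returns the table with its path:

import pyarrow as pa
import pyarrow.parquet as pq


def _create_single_file(base_dir, table=None):
    # Hypothetical helper: write a single parquet file and hand back
    # (table, path) so tests can open it with ds.dataset(str(path)).
    if table is None:
        table = pa.table({'a': range(9), 'b': [0.0] * 9})
    path = base_dir / "test.parquet"
    pq.write_table(table, str(path))
    return table, path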
def test_data_catalog_filter(self):
    # Arrange, Act
    deltas = self.catalog.order_book_deltas()
    filtered_deltas = self.catalog.order_book_deltas(
        filter_expr=ds.field("action") == "DELETE")

    # Assert
    assert len(deltas) == 2384
    assert len(filtered_deltas) == 351
def test_fragments_parquet_row_groups(tempdir):
    table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=2)

    fragment = list(dataset.get_fragments())[0]

    # list and scan row group fragments
    row_group_fragments = list(fragment.get_row_group_fragments())
    assert len(row_group_fragments) == 2
    result = row_group_fragments[0].to_table(schema=dataset.schema)
    assert result.column_names == ['f1', 'f2', 'part']
    assert len(result) == 2
    assert result.equals(table.slice(0, 2))

    fragment = list(dataset.get_fragments(filter=ds.field('f1') < 1))[0]
    row_group_fragments = list(fragment.get_row_group_fragments())
    assert len(row_group_fragments) == 1
    result = row_group_fragments[0].to_table(filter=ds.field('f1') < 1)
    assert len(result) == 1
def filter_by_fixed(parquet_partition_name: str,
                    population_filter: list = None,
                    incl_attributes=False) -> Table:
    if population_filter:
        fixed_filter: Expression = ds.field("unit_id").isin(population_filter)
        table = do_filter(fixed_filter, incl_attributes,
                          parquet_partition_name)
    else:
        table = do_filter(None, incl_attributes, parquet_partition_name)
    return table
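The `filter_by_time_period`, `filter_by_time`, and `filter_by_fixed` functions all delegate to a `do_filter` helper that is not part of this collection. A heavily hedged sketch of what such a helper could look like, assuming it scans the named parquet partition, applies the (possibly None) expression, and drops the epoch-day attribute columns unless they are requested; the column names and drop behaviour here are assumptions, not the original implementation:

import pyarrow.dataset as ds
from pyarrow import Table


def do_filter(filter_expr, incl_attributes: bool,
              parquet_partition_name: str) -> Table:
    # Hypothetical: open the partition as a dataset and push the filter down.
    dataset = ds.dataset(parquet_partition_name, format="parquet")
    table = dataset.to_table(filter=filter_expr)
    if not incl_attributes:
        # Hypothetical: omit the attribute columns when not requested.
        keep = [name for name in table.column_names
                if name not in ("start_epoch_days", "stop_epoch_days")]
        table = table.select(keep)
    return table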
def test_filesystem_dataset(mockfs):
    schema = pa.schema([pa.field('const', pa.int64())])
    file_format = ds.ParquetFileFormat()
    paths = ['subdir/1/xxx/file0.parquet', 'subdir/2/yyy/file1.parquet']
    partitions = [ds.ScalarExpression(True), ds.ScalarExpression(True)]

    dataset = ds.FileSystemDataset(schema,
                                   root_partition=None,
                                   file_format=file_format,
                                   filesystem=mockfs,
                                   paths_or_selector=paths,
                                   partitions=partitions)
    assert isinstance(dataset.format, ds.ParquetFileFormat)

    root_partition = ds.ComparisonExpression(ds.CompareOperator.Equal,
                                             ds.FieldExpression('level'),
                                             ds.ScalarExpression(1337))
    partitions = [
        ds.ComparisonExpression(ds.CompareOperator.Equal,
                                ds.FieldExpression('part'),
                                ds.ScalarExpression(1)),
        ds.ComparisonExpression(ds.CompareOperator.Equal,
                                ds.FieldExpression('part'),
                                ds.ScalarExpression(2))
    ]
    dataset = ds.FileSystemDataset(paths_or_selector=paths,
                                   schema=schema,
                                   root_partition=root_partition,
                                   filesystem=mockfs,
                                   partitions=partitions,
                                   file_format=file_format)
    assert dataset.partition_expression.equals(root_partition)
    assert set(dataset.files) == set(paths)

    fragments = list(dataset.get_fragments())
    for fragment, partition, path in zip(fragments, partitions, paths):
        assert fragment.partition_expression.equals(
            ds.AndExpression(root_partition, partition))
        assert fragment.path == path
        assert isinstance(fragment, ds.ParquetFileFragment)
        assert fragment.row_groups is None

        row_group_fragments = list(fragment.get_row_group_fragments())
        assert len(row_group_fragments) == 1
        assert isinstance(fragment, ds.ParquetFileFragment)
        assert row_group_fragments[0].path == path
        assert row_group_fragments[0].row_groups == {0}

    # test predicate pushdown using row group metadata
    fragments = list(dataset.get_fragments(filter=ds.field("const") == 0))
    assert len(fragments) == 2
    assert len(list(fragments[0].get_row_group_fragments())) == 1
    assert len(list(fragments[1].get_row_group_fragments())) == 0
def test_fragments_implicit_cast(tempdir):
    # ARROW-8693
    import pyarrow.parquet as pq

    table = pa.table([range(8), [1] * 4 + [2] * 4], names=['col', 'part'])
    path = str(tempdir / "test_parquet_dataset")
    pq.write_to_dataset(table, path, partition_cols=["part"])

    part = ds.partitioning(pa.schema([('part', 'int8')]), flavor="hive")
    dataset = ds.dataset(path, format="parquet", partitioning=part)
    fragments = dataset.get_fragments(filter=ds.field("part") >= 2)
    assert len(list(fragments)) == 1
def test_expression_serialization():
    a = ds.scalar(1)
    b = ds.scalar(1.1)
    c = ds.scalar(True)
    d = ds.scalar("string")
    e = ds.scalar(None)

    condition = ds.field('i64') > 5
    schema = pa.schema([
        pa.field('i64', pa.int64()),
        pa.field('f64', pa.float64())
    ])
    assert condition.validate(schema) == pa.bool_()

    assert condition.assume(ds.field('i64') == 5).equals(
        ds.scalar(False))
    assert condition.assume(ds.field('i64') == 7).equals(
        ds.scalar(True))

    all_exprs = [a, b, c, d, e, a == b, a > b, a & b, a | b, ~c,
                 d.is_valid(), a.cast(pa.int32(), safe=False),
                 a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
                 ds.field('i64') > 5, ds.field('i64') == 5,
                 ds.field('i64') == 7]
    for expr in all_exprs:
        assert isinstance(expr, ds.Expression)
        restored = pickle.loads(pickle.dumps(expr))
        assert expr.equals(restored)
def main(catalog: DataCatalog):
    """Rename match_id to trade_id in TradeTick"""
    fs: fsspec.AbstractFileSystem = catalog.fs

    print("Loading instrument ids")
    instrument_ids = catalog.query(
        TradeTick, table_kwargs={"columns": ["instrument_id"]}
    )["instrument_id"].unique()

    tmp_catalog = DataCatalog(str(catalog.path) + "_tmp")
    tmp_catalog.fs = catalog.fs

    for ins_id in tqdm(instrument_ids):
        # Load trades for instrument
        trades = catalog.trade_ticks(
            instrument_ids=[ins_id],
            projections={"trade_id": ds.field("match_id")},
            as_nautilus=True,
        )

        # Create temp parquet in case of error
        fs.move(
            f"{catalog.path}/data/trade_tick.parquet/instrument_id={ins_id}",
            f"{catalog.path}/data/trade_tick.parquet_tmp/instrument_id={ins_id}",
            recursive=True,
        )

        try:
            # Rewrite to new catalog
            write_objects(tmp_catalog, trades)

            # Ensure we can query again
            _ = tmp_catalog.trade_ticks(instrument_ids=[ins_id], as_nautilus=True)

            # Clear temp parquet
            fs.rm(
                f"{catalog.path}/data/trade_tick.parquet_tmp/instrument_id={ins_id}",
                recursive=True,
            )
        except Exception:
            warnings.warn(f"Failed to write or read instrument_id {ins_id}")
            fs.move(
                f"{catalog.path}/data/trade_tick.parquet_tmp/instrument_id={ins_id}",
                f"{catalog.path}/data/trade_tick.parquet/instrument_id={ins_id}",
                recursive=True,
            )
def test_data_catalog_query_filtered(self):
    ticks = self.catalog.trade_ticks()
    assert len(ticks) == 312

    ticks = self.catalog.trade_ticks(start="2019-12-20 20:56:18")
    assert len(ticks) == 123

    ticks = self.catalog.trade_ticks(start=1576875378384999936)
    assert len(ticks) == 123

    ticks = self.catalog.trade_ticks(
        start=datetime.datetime(2019, 12, 20, 20, 56, 18))
    assert len(ticks) == 123

    deltas = self.catalog.order_book_deltas()
    assert len(deltas) == 2384

    filtered_deltas = self.catalog.order_book_deltas(
        filter_expr=ds.field("action") == "DELETE")
    assert len(filtered_deltas) == 351