예제 #1
0
    def test_progress_hooks_empty_pipeline(self):
        """Progress hooks publish a chunked trace even for a column-less pipeline.

        Runs a ten-session window in chunks of five sessions and checks the
        published trace against both chunk boundaries, with ``empty=True``.
        """
        trace_sink = TestingProgressPublisher()
        progress_hooks = [ProgressHooks.with_static_publisher(trace_sink)]
        empty_pipe = Pipeline({}, domain=US_EQUITIES)
        first_day, last_day = self.trading_days[[-10, -1]]

        # Two chunks of five sessions each: [-10, -6] and [-5, -1].
        chunk_bounds = [
            tuple(self.trading_days[[lo, hi]])
            for lo, hi in ((-10, -6), (-5, -1))
        ]

        self.run_chunked_pipeline(
            pipeline=empty_pipe,
            start_date=first_day,
            end_date=last_day,
            chunksize=5,
            hooks=progress_hooks,
        )

        self.verify_trace(
            trace_sink.trace,
            pipeline_start_date=first_day,
            pipeline_end_date=last_day,
            expected_chunks=chunk_bounds,
            empty=True,
        )
예제 #2
0
    def test_progress_hooks(self):
        """A chunked run of a multi-column pipeline publishes a complete trace.

        The trace is checked for the expected loaded inputs, the expected
        computed terms (including ``Everything()``, the default input for
        ``.rank()``), and both five-session chunks.
        """
        trace_sink = TestingProgressPublisher()
        progress_hooks = [ProgressHooks.with_static_publisher(trace_sink)]
        columns = {
            'bool_': TestingDataSet.bool_col.latest,
            'factor_rank': TrivialFactor().rank().zscore(),
            'prepopulated': PREPOPULATED_TERM,
        }
        first_day, last_day = self.trading_days[[-10, -1]]
        chunk_bounds = [
            tuple(self.trading_days[[lo, hi]])
            for lo, hi in ((-10, -6), (-5, -1))
        ]

        self.run_chunked_pipeline(
            pipeline=Pipeline(columns, domain=US_EQUITIES),
            start_date=first_day,
            end_date=last_day,
            chunksize=5,
            hooks=progress_hooks,
        )

        loaded = {TestingDataSet.bool_col}
        loaded.update(TrivialFactor.inputs)

        computed = {
            TestingDataSet.bool_col.latest,
            TrivialFactor(),
            TrivialFactor().rank(),
            TrivialFactor().rank().zscore(),
            Everything(),  # Default input for .rank().
        }

        self.verify_trace(
            trace_sink.trace,
            pipeline_start_date=first_day,
            pipeline_end_date=last_day,
            expected_loads=loaded,
            expected_computes=computed,
            expected_chunks=chunk_bounds,
        )
예제 #3
0
    def run_pipeline(self, pipeline, start_date, end_date=None,
                     chunksize=120, hooks=None):
        """Run ``pipeline`` from ``start_date`` through ``end_date``.

        Parameters
        ----------
        pipeline : Pipeline
            The pipeline to execute.
        start_date
            First date of the run.
        end_date, optional
            Last date of the run; when omitted, ``start_date`` is reused so
            the run covers a single date.
        chunksize : int, optional
            Number of dates per chunk. Any value ``<= 1`` disables chunking
            and delegates to the parent's ``run_pipeline``. Defaults to 120.
        hooks : list, optional
            Pipeline hooks. When omitted, a fresh ``CliProgressPublisher``
            wrapped in ``ProgressHooks`` is used.

        Returns
        -------
        The result of the parent class's ``run_pipeline`` (unchunked) or
        ``run_chunked_pipeline`` (chunked).
        """
        window_end = start_date if end_date is None else end_date

        if hooks is None:
            cli_publisher = CliProgressPublisher()
            hooks = [ProgressHooks.with_static_publisher(cli_publisher)]

        # Keep the exact ``<= 1`` comparison: it decides the unchunked path.
        if chunksize <= 1:
            log.info("Compute pipeline values without chunks.")
            return super().run_pipeline(pipeline, start_date, window_end,
                                        hooks)

        return super().run_chunked_pipeline(
            pipeline, start_date, window_end, chunksize, hooks,
        )
예제 #4
0
    def test_progress_hooks(self):
        """Progress tracking stays correct when per-chunk work differs.

        The pipeline includes PREPOPULATED_TERM. The first chunk ends before
        ``PREPOPULATED_TERM_CUTOFF`` and should receive that term through the
        initial workspace. The second chunk ends after the cutoff and should
        have to compute it explicitly. (This assumes the workspace is only
        pre-populated for chunks ending before the cutoff; confirm against
        the fixture that defines the cutoff.) The published trace must still
        cover both chunks.
        """
        publisher = TestingProgressPublisher()
        hooks = [ProgressHooks.with_static_publisher(publisher)]
        pipeline = Pipeline(
            {
                'bool_': TestingDataSet.bool_col.latest,
                'factor_rank': TrivialFactor().rank().zscore(),
                'prepopulated': PREPOPULATED_TERM,
            },
            domain=US_EQUITIES,
        )
        start_date, end_date = self.trading_days[[-10, -1]]
        expected_chunks = [
            tuple(self.trading_days[[-10, -6]]),
            tuple(self.trading_days[[-5, -1]]),
        ]

        # First chunk should get prepopulated term in initial workspace,
        # so it must end before the cutoff.
        self.assertLess(expected_chunks[0][1], self.PREPOPULATED_TERM_CUTOFF)

        # Second chunk should have to compute PREPOPULATED_TERM explicitly,
        # so it must end after the cutoff. Check the *second* chunk here;
        # re-checking the first chunk would leave this precondition untested.
        self.assertGreater(
            expected_chunks[1][1], self.PREPOPULATED_TERM_CUTOFF,
        )

        self.run_chunked_pipeline(
            pipeline=pipeline,
            start_date=start_date,
            end_date=end_date,
            chunksize=5,
            hooks=hooks,
        )

        self.verify_trace(
            publisher.trace,
            pipeline_start_date=start_date,
            pipeline_end_date=end_date,
            expected_chunks=expected_chunks,
        )
예제 #5
0
    def test_error_handling(self, chunked):
        """An exception raised inside a factor's compute propagates out.

        The raised error must reach the caller, and the last progress update
        must be in the ``'error'`` state. ``chunked`` selects between
        ``run_chunked_pipeline`` (chunks of five) and ``run_pipeline``.
        """
        trace_sink = TestingProgressPublisher()
        progress_hooks = [ProgressHooks.with_static_publisher(trace_sink)]

        class SomeError(Exception):
            pass

        class ExplodingFactor(CustomFactor):
            inputs = [TestingDataSet.float_col]
            window_length = 1

            def compute(self, *args, **kwargs):
                raise SomeError()

        first_day, last_day = self.trading_days[[-10, -1]]
        shared_kwargs = {
            'pipeline': Pipeline({'boom': ExplodingFactor()},
                                 domain=US_EQUITIES),
            'start_date': first_day,
            'end_date': last_day,
            'hooks': progress_hooks,
        }

        with self.assertRaises(SomeError):
            if not chunked:
                self.run_pipeline(**shared_kwargs)
            else:
                self.run_chunked_pipeline(chunksize=5, **shared_kwargs)

        last_event = trace_sink.trace[-1]
        self.assertEqual(last_event.state, 'error')
예제 #6
0
from datetime import time

import pytz

from zipline.data import bundles
from zipline.pipeline.data import EquityPricing
from zipline.pipeline.domain import (CN_EQUITIES, GENERIC, Domain,
                                     EquitySessionDomain)
from zipline.pipeline.engine import SimplePipelineEngine
from zipline.pipeline.hooks.progress import (IPythonWidgetProgressPublisher,
                                             ProgressHooks)
from zipline.pipeline.loaders import EquityPricingLoader
from zipline.pipeline.loaders.blaze import global_loader
from zipline.utils.memoize import remember_last
from zipline.utils.ts_utils import ensure_utc

# IANA time-zone name for the Shanghai market.
TZ = 'Asia/Shanghai'

# One IPython-widget progress publisher, and the pipeline hooks list that
# wraps it. Both are created once, at import time.
publisher = IPythonWidgetProgressPublisher()
hooks = [ProgressHooks.with_static_publisher(publisher)]


def create_domain(sessions,
                  data_query_time=time(0, 0, tzinfo=pytz.utc),
                  data_query_date_offset=0):
    if sessions.tz is None:
        sessions = sessions.tz_localize('UTC')

    return EquitySessionDomain(
        sessions,
        country_code='CN',
        data_query_time=data_query_time,