コード例 #1
0
    def init_class_fixtures(cls):
        """Create the two environment / data-portal pairs used by the tests.

        The first pair (``env``, ``sim_params``, ``data_portal``) starts at
        the first key of ``LEVERAGED_ETFS``; the second pair (``env2``,
        ``sim_params2``, ``data_portal2``) starts on
        ``trading_day_before_first_kd``.  Both environments register the same
        five symbols on the "TEST" exchange.
        """
        super(SecurityListTestCase, cls).init_class_fixtures()

        cls.start = pd.Timestamp(list(LEVERAGED_ETFS.keys())[0])
        last_day = pd.Timestamp('2015-02-17', tz='utc')
        cls.extra_knowledge_date = pd.Timestamp('2015-01-27', tz='utc')
        cls.trading_day_before_first_kd = pd.Timestamp('2015-01-23', tz='utc')
        tickers = ['AAPL', 'GOOG', 'BZQ', 'URTY', 'JFT']

        def equity_frame(first_session, last_session):
            # One metadata record per ticker, all sharing the same lifetime.
            rows = []
            for ticker in tickers:
                rows.append({
                    'start_date': first_session,
                    'end_date': last_session,
                    'symbol': ticker,
                    'exchange': "TEST",
                })
            return pd.DataFrame.from_records(rows)

        # First environment: assets live from cls.start to the fixed end date.
        first_env_cm = tmp_trading_env(
            equities=equity_frame(cls.start, last_day),
            load=cls.make_load_function(),
        )
        cls.env = cls.enter_class_context(first_env_cm)

        cls.sim_params = factory.create_simulation_parameters(
            start=cls.start,
            num_days=4,
            trading_calendar=cls.trading_calendar,
        )

        second_params = factory.create_simulation_parameters(
            start=cls.trading_day_before_first_kd,
            num_days=4,
        )
        cls.sim_params2 = second_params

        # Second environment: assets live exactly as long as the second
        # simulation window.
        second_env_cm = tmp_trading_env(
            equities=equity_frame(second_params.start_session,
                                  second_params.end_session),
            load=cls.make_load_function(),
        )
        cls.env2 = cls.enter_class_context(second_env_cm)

        cls.tempdir = cls.enter_class_context(tmp_dir())
        cls.tempdir2 = cls.enter_class_context(tmp_dir())

        cls.data_portal = create_data_portal(
            asset_finder=cls.env.asset_finder,
            tempdir=cls.tempdir,
            sim_params=cls.sim_params,
            sids=range(0, 5),
            trading_calendar=cls.trading_calendar,
        )

        cls.data_portal2 = create_data_portal(
            asset_finder=cls.env2.asset_finder,
            tempdir=cls.tempdir2,
            sim_params=cls.sim_params2,
            sids=range(0, 5),
            trading_calendar=cls.trading_calendar,
        )
コード例 #2
0
ファイル: test_fetcher.py プロジェクト: zhoukalex/catalyst
    def test_fetcher_in_before_trading_start(self):
        self.responses.add(
            self.responses.GET,
            'https://fake.urls.com/fetcher_nflx_data.csv',
            body=NFLX_DATA,
            content_type='text/csv',
        )

        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2013-06-13", tz='UTC'),
            end=pd.Timestamp("2013-11-15", tz='UTC'),
            data_frequency="minute"
        )

        results = self.run_algo("""
from catalyst.api import fetch_csv, record, symbol

def initialize(context):
    fetch_csv('https://fake.urls.com/fetcher_nflx_data.csv',
               date_column = 'Settlement Date',
               date_format = '%m/%d/%y')
    context.stock = symbol('NFLX')

def before_trading_start(context, data):
    record(Short_Interest = data.current(context.stock, 'dtc'))
""", sim_params=sim_params, data_frequency="minute")

        values = results["Short_Interest"]
        np.testing.assert_array_equal(values[0:33], np.full(33, np.nan))
        np.testing.assert_array_almost_equal(values[33:44], [1.690317] * 11)
        np.testing.assert_array_almost_equal(values[44:55], [2.811858] * 11)
        np.testing.assert_array_almost_equal(values[55:64], [2.50233] * 9)
        np.testing.assert_array_almost_equal(values[64:75], [2.550829] * 11)
        np.testing.assert_array_almost_equal(values[75:], [2.64484] * 35)
コード例 #3
0
    def _test_before_trading_start(self, test_name, num_days, freq,
                                   emission_rate):
        """Check that ``before_trading_start`` ran at 8:45 US/Eastern daily.

        Runs ``BeforeTradingAlgorithm`` against a ``FakeDataPortal`` with
        ``BenchmarkSource.get_value`` patched to
        ``self.fake_minutely_benchmark``, then compares the timestamps the
        algorithm collected in ``before_trading_at`` with 8:45 US/Eastern on
        every session of the simulation.

        :param test_name: label of the parameterized case; not used here.
        :param num_days: number of sessions to simulate.
        :param freq: value passed as ``data_frequency``.
        :param emission_rate: value passed as ``emission_rate``.
        """
        params = factory.create_simulation_parameters(
            num_days=num_days,
            data_frequency=freq,
            emission_rate=emission_rate)

        with patch.object(BenchmarkSource, "get_value",
                          self.fake_minutely_benchmark):
            algo = BeforeTradingAlgorithm(sim_params=params, env=self.env)
            algo.run(FakeDataPortal(self.env))

            self.assertEqual(len(algo.perf_tracker.sim_params.sessions),
                             num_days)

            bts_minutes = days_at_time(params.sessions, time(8, 45),
                                       "US/Eastern")

            # Report the timestamps that were actually compared, so a failure
            # shows the 8:45 expectations rather than the session labels.
            self.assertTrue(
                bts_minutes.equals(pd.DatetimeIndex(algo.before_trading_at)),
                "Expected %s but was %s." %
                (bts_minutes, algo.before_trading_at))
コード例 #4
0
ファイル: test_fetcher.py プロジェクト: zhoukalex/catalyst
    def test_fetcher_universe_non_security_return(self):
        self.responses.add(
            self.responses.GET,
            'https://fake.urls.com/bad_fetcher_universe_data.csv',
            body=NON_ASSET_FETCHER_UNIVERSE_DATA,
            content_type='text/csv',
        )

        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-09", tz='UTC'),
            end=pd.Timestamp("2006-01-10", tz='UTC')
        )

        self.run_algo(
            """
from catalyst.api import fetch_csv

def initialize(context):
    fetch_csv(
        'https://fake.urls.com/bad_fetcher_universe_data.csv',
        date_format='%m/%d/%Y'
    )

def handle_data(context, data):
    if len(data.fetcher_assets) > 0:
        raise Exception("Shouldn't be any assets in fetcher_assets!")
            """,
            sim_params=sim_params,
        )
コード例 #5
0
    def test_fetcher_universe_non_security_return(self):
        self.responses.add(
            self.responses.GET,
            'https://fake.urls.com/bad_fetcher_universe_data.csv',
            body=NON_ASSET_FETCHER_UNIVERSE_DATA,
            content_type='text/csv',
        )

        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-09", tz='UTC'),
            end=pd.Timestamp("2006-01-10", tz='UTC'))

        self.run_algo(
            """
from catalyst.api import fetch_csv

def initialize(context):
    fetch_csv(
        'https://fake.urls.com/bad_fetcher_universe_data.csv',
        date_format='%m/%d/%Y'
    )

def handle_data(context, data):
    if len(data.fetcher_assets) > 0:
        raise Exception("Shouldn't be any assets in fetcher_assets!")
            """,
            sim_params=sim_params,
        )
コード例 #6
0
ファイル: test_fetcher.py プロジェクト: zhoukalex/catalyst
    def test_minutely_fetcher(self):
        self.responses.add(
            self.responses.GET,
            'https://fake.urls.com/aapl_minute_csv_data.csv',
            body=AAPL_MINUTE_CSV_DATA,
            content_type='text/csv',
        )

        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-03", tz='UTC'),
            end=pd.Timestamp("2006-01-10", tz='UTC'),
            emission_rate="minute",
            data_frequency="minute"
        )

        test_algo = TradingAlgorithm(
            script="""
from catalyst.api import fetch_csv, record, sid

def initialize(context):
    fetch_csv('https://fake.urls.com/aapl_minute_csv_data.csv')

def handle_data(context, data):
    record(aapl_signal=data.current(sid(24), "signal"))
""", sim_params=sim_params, data_frequency="minute", env=self.env)

        # manually setting data portal and getting generator because we need
        # the minutely emission packets here.  TradingAlgorithm.run() only
        # returns daily packets.
        test_algo.data_portal = FetcherDataPortal(self.env,
                                                  self.trading_calendar)
        gen = test_algo.get_generator()
        perf_packets = list(gen)

        signal = [result["minute_perf"]["recorded_vars"]["aapl_signal"] for
                  result in perf_packets if "minute_perf" in result]

        self.assertEqual(6 * 390, len(signal))

        # csv data is:
        # symbol,date,signal
        # aapl,1/4/06 5:31AM, 1
        # aapl,1/4/06 11:30AM, 2
        # aapl,1/5/06 5:31AM, 1
        # aapl,1/5/06 11:30AM, 3
        # aapl,1/9/06 5:31AM, 1
        # aapl,1/9/06 11:30AM, 4 for dates 1/3 to 1/10

        # 2 signals per day, only last signal is taken. So we expect
        # 390 bars of signal NaN on 1/3
        # 390 bars of signal 2 on 1/4
        # 390 bars of signal 3 on 1/5
        # 390 bars of signal 3 on 1/6 (forward filled)
        # 390 bars of signal 4 on 1/9
        # 390 bars of signal 4 on 1/9 (forward filled)

        np.testing.assert_array_equal([np.NaN] * 390, signal[0:390])
        np.testing.assert_array_equal([2] * 390, signal[390:780])
        np.testing.assert_array_equal([3] * 780, signal[780:1560])
        np.testing.assert_array_equal([4] * 780, signal[1560:])
コード例 #7
0
    def test_fetcher_in_before_trading_start(self):
        self.responses.add(
            self.responses.GET,
            'https://fake.urls.com/fetcher_nflx_data.csv',
            body=NFLX_DATA,
            content_type='text/csv',
        )

        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2013-06-13", tz='UTC'),
            end=pd.Timestamp("2013-11-15", tz='UTC'),
            data_frequency="minute")

        results = self.run_algo("""
from catalyst.api import fetch_csv, record, symbol

def initialize(context):
    fetch_csv('https://fake.urls.com/fetcher_nflx_data.csv',
               date_column = 'Settlement Date',
               date_format = '%m/%d/%y')
    context.stock = symbol('NFLX')

def before_trading_start(context, data):
    record(Short_Interest = data.current(context.stock, 'dtc'))
""",
                                sim_params=sim_params,
                                data_frequency="minute")

        values = results["Short_Interest"]
        np.testing.assert_array_equal(values[0:33], np.full(33, np.nan))
        np.testing.assert_array_almost_equal(values[33:44], [1.690317] * 11)
        np.testing.assert_array_almost_equal(values[44:55], [2.811858] * 11)
        np.testing.assert_array_almost_equal(values[55:64], [2.50233] * 9)
        np.testing.assert_array_almost_equal(values[64:75], [2.550829] * 11)
        np.testing.assert_array_almost_equal(values[75:], [2.64484] * 35)
コード例 #8
0
    def test_algo_without_rl_violation_after_delete(self):
        """Run a BZQ algo after a delete statement for BZQ has been added.

        No exception is expected here: the algorithm is simply run to
        completion inside a temporary copy of the security list.
        """
        sim_params = factory.create_simulation_parameters(
            start=self.extra_knowledge_date, num_days=4)

        bzq_record = {
            'symbol': 'BZQ',
            'start_date': sim_params.start_session,
            'end_date': sim_params.end_session,
            'exchange': "TEST",
        }
        equities = pd.DataFrame.from_records([bzq_record])

        with TempDirectory() as scratch_dir, security_list_copy():
            with tmp_trading_env(equities=equities,
                                 load=self.make_load_function()) as env:
                # Write a new delete-statement file to disk that removes BZQ.
                add_security_data([], ['BZQ'])

                portal = create_data_portal(
                    env.asset_finder,
                    scratch_dir,
                    sim_params,
                    range(0, 5),
                    trading_calendar=self.trading_calendar,
                )

                unchecked_algo = RestrictedAlgoWithoutCheck(
                    symbol='BZQ', sim_params=sim_params, env=env)
                unchecked_algo.run(portal)
コード例 #9
0
    def test_fetcher_bad_data(self):
        self.responses.add(
            self.responses.GET,
            'https://fake.urls.com/fetcher_nflx_data.csv',
            body=NFLX_DATA,
            content_type='text/csv',
        )

        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2013-06-12", tz='UTC'),
            end=pd.Timestamp("2013-06-14", tz='UTC'),
            data_frequency="minute")

        results = self.run_algo("""
from catalyst.api import fetch_csv, symbol
import numpy as np

def initialize(context):
    fetch_csv('https://fake.urls.com/fetcher_nflx_data.csv',
               date_column = 'Settlement Date',
               date_format = '%m/%d/%y')
    context.nflx = symbol('NFLX')
    context.aapl = symbol('AAPL')

def handle_data(context, data):
    assert np.isnan(data.current(context.nflx, 'invalid_column'))
    assert np.isnan(data.current(context.aapl, 'invalid_column'))
    assert np.isnan(data.current(context.aapl, 'dtc'))
""",
                                sim_params=sim_params,
                                data_frequency="minute")

        self.assertEqual(3, len(results))
コード例 #10
0
ファイル: test_fetcher.py プロジェクト: zhoukalex/catalyst
    def test_fetcher_bad_data(self):
        self.responses.add(
            self.responses.GET,
            'https://fake.urls.com/fetcher_nflx_data.csv',
            body=NFLX_DATA,
            content_type='text/csv',
        )

        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2013-06-12", tz='UTC'),
            end=pd.Timestamp("2013-06-14", tz='UTC'),
            data_frequency="minute"
        )

        results = self.run_algo("""
from catalyst.api import fetch_csv, symbol
import numpy as np

def initialize(context):
    fetch_csv('https://fake.urls.com/fetcher_nflx_data.csv',
               date_column = 'Settlement Date',
               date_format = '%m/%d/%y')
    context.nflx = symbol('NFLX')
    context.aapl = symbol('AAPL')

def handle_data(context, data):
    assert np.isnan(data.current(context.nflx, 'invalid_column'))
    assert np.isnan(data.current(context.aapl, 'invalid_column'))
    assert np.isnan(data.current(context.aapl, 'dtc'))
""", sim_params=sim_params, data_frequency="minute")

        self.assertEqual(3, len(results))
コード例 #11
0
    def test_before_trading_start(self, test_name, num_days, freq,
                                  emission_rate):
        """Check that ``before_trading_start`` ran at 8:45 US/Eastern daily.

        Runs ``BeforeTradingAlgorithm`` against a ``FakeDataPortal`` with
        ``BenchmarkSource.get_value`` patched to
        ``self.fake_minutely_benchmark``, then compares the timestamps the
        algorithm collected in ``before_trading_at`` with 8:45 US/Eastern on
        every session of the simulation.

        :param test_name: label of the parameterized case; not used here.
        :param num_days: number of sessions to simulate.
        :param freq: value passed as ``data_frequency``.
        :param emission_rate: value passed as ``emission_rate``.
        """
        params = factory.create_simulation_parameters(
            num_days=num_days, data_frequency=freq,
            emission_rate=emission_rate)

        with patch.object(BenchmarkSource, "get_value",
                          self.fake_minutely_benchmark):
            algo = BeforeTradingAlgorithm(sim_params=params, env=self.env)
            algo.run(FakeDataPortal(self.env))

            self.assertEqual(
                len(algo.perf_tracker.sim_params.sessions),
                num_days
            )

            bts_minutes = days_at_time(
                params.sessions, time(8, 45), "US/Eastern"
            )

            # Report the timestamps that were actually compared, so a failure
            # shows the 8:45 expectations rather than the session labels.
            self.assertTrue(
                bts_minutes.equals(
                    pd.DatetimeIndex(algo.before_trading_at)
                ),
                "Expected %s but was %s." % (bts_minutes,
                                             algo.before_trading_at))
コード例 #12
0
 def test_minutely_emissions_generate_performance_stats_for_last_day(self):
     """A one-day minute/minute run must report exactly one session."""
     minute_params = factory.create_simulation_parameters(
         num_days=1, data_frequency='minute', emission_rate='minute')
     benchmark_patch = patch.object(
         BenchmarkSource, "get_value", self.fake_minutely_benchmark)
     with benchmark_patch:
         noop = NoopAlgorithm(sim_params=minute_params, env=self.env)
         noop.run(FakeDataPortal(self.env))
         session_count = len(noop.perf_tracker.sim_params.sessions)
         self.assertEqual(session_count, 1)
コード例 #13
0
 def test_minutely_emissions_generate_performance_stats_for_last_day(self):
     """Run a no-op algo for one minute-emission day; expect one session."""
     param_kwargs = dict(
         num_days=1,
         data_frequency='minute',
         emission_rate='minute',
     )
     params = factory.create_simulation_parameters(**param_kwargs)
     with patch.object(
             BenchmarkSource, "get_value", self.fake_minutely_benchmark):
         algo = NoopAlgorithm(sim_params=params, env=self.env)
         portal = FakeDataPortal(self.env)
         algo.run(portal)
         sessions = algo.perf_tracker.sim_params.sessions
         self.assertEqual(len(sessions), 1)
コード例 #14
0
ファイル: test_fetcher.py プロジェクト: zhoukalex/catalyst
    def test_fetcher_universe(self, name, data, column_name):
        # Patching fetch_url here rather than using responses because (a) it's
        # easier given the paramaterization, and (b) there are enough tests
        # using responses that the fetch_url code is getting a good workout so
        # we don't have to use it in every test.
        with patch('catalyst.sources.requests_csv.PandasRequestsCSV.fetch_url',
                   new=lambda *a, **k: data):
            sim_params = factory.create_simulation_parameters(
                start=pd.Timestamp("2006-01-09", tz='UTC'),
                end=pd.Timestamp("2006-01-11", tz='UTC')
            )

            algocode = """
from pandas import Timestamp
from pandas.tseries.tools import normalize_date
from catalyst.api import fetch_csv, record, sid, get_datetime

def initialize(context):
    fetch_csv(
        'https://dl.dropbox.com/u/16705795/dtoc_history.csv',
        date_format='%m/%d/%Y'{token}
    )
    context.expected_sids = {{
        Timestamp('2006-01-09 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-10 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-11 00:00:00+0000', tz='UTC'):[24, 3766, 5061, 14848]
    }}
    context.bar_count = 0

def handle_data(context, data):
    expected = context.expected_sids[normalize_date(get_datetime())]
    actual = data.fetcher_assets
    for stk in expected:
        if stk not in actual:
            raise Exception(
                "{{stk}} is missing on dt={{dt}}".format(
                    stk=stk, dt=get_datetime()))

    record(sid_count=len(actual))
    record(bar_count=context.bar_count)
    context.bar_count += 1
            """
            replacement = ""
            if column_name:
                replacement = ",symbol_column='%s'\n" % column_name
            real_algocode = algocode.format(token=replacement)

            results = self.run_algo(real_algocode, sim_params=sim_params)

            self.assertEqual(len(results), 3)
            self.assertEqual(3, results["sid_count"].iloc[0])
            self.assertEqual(3, results["sid_count"].iloc[1])
            self.assertEqual(4, results["sid_count"].iloc[2])
コード例 #15
0
    def test_fetcher_universe(self, name, data, column_name):
        # Patching fetch_url here rather than using responses because (a) it's
        # easier given the paramaterization, and (b) there are enough tests
        # using responses that the fetch_url code is getting a good workout so
        # we don't have to use it in every test.
        with patch('catalyst.sources.requests_csv.PandasRequestsCSV.fetch_url',
                   new=lambda *a, **k: data):
            sim_params = factory.create_simulation_parameters(
                start=pd.Timestamp("2006-01-09", tz='UTC'),
                end=pd.Timestamp("2006-01-11", tz='UTC'))

            algocode = """
from pandas import Timestamp
from catalyst.api import fetch_csv, record, sid, get_datetime
from catalyst.utils.compat import normalize_date

def initialize(context):
    fetch_csv(
        'https://dl.dropbox.com/u/16705795/dtoc_history.csv',
        date_format='%m/%d/%Y'{token}
    )
    context.expected_sids = {{
        Timestamp('2006-01-09 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-10 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-11 00:00:00+0000', tz='UTC'):[24, 3766, 5061, 14848]
    }}
    context.bar_count = 0

def handle_data(context, data):
    expected = context.expected_sids[normalize_date(get_datetime())]
    actual = data.fetcher_assets
    for stk in expected:
        if stk not in actual:
            raise Exception(
                "{{stk}} is missing on dt={{dt}}".format(
                    stk=stk, dt=get_datetime()))

    record(sid_count=len(actual))
    record(bar_count=context.bar_count)
    context.bar_count += 1
            """
            replacement = ""
            if column_name:
                replacement = ",symbol_column='%s'\n" % column_name
            real_algocode = algocode.format(token=replacement)

            results = self.run_algo(real_algocode, sim_params=sim_params)

            self.assertEqual(len(results), 3)
            self.assertEqual(3, results["sid_count"].iloc[0])
            self.assertEqual(3, results["sid_count"].iloc[1])
            self.assertEqual(4, results["sid_count"].iloc[2])
コード例 #16
0
ファイル: test_fetcher.py プロジェクト: zhoukalex/catalyst
    def test_fetcher_universe_minute(self):
        self.responses.add(
            self.responses.GET,
            'https://fake.urls.com/fetcher_universe_data.csv',
            body=FETCHER_UNIVERSE_DATA,
            content_type='text/csv',
        )

        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-09", tz='UTC'),
            end=pd.Timestamp("2006-01-11", tz='UTC'),
            data_frequency="minute"
        )

        results = self.run_algo(
            """
from pandas import Timestamp
from catalyst.api import fetch_csv, record, get_datetime

def initialize(context):
    fetch_csv(
        'https://fake.urls.com/fetcher_universe_data.csv',
        date_format='%m/%d/%Y'
    )
    context.expected_sids = {
        Timestamp('2006-01-09 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-10 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-11 00:00:00+0000', tz='UTC'):[24, 3766, 5061, 14848]
    }
    context.bar_count = 0

def handle_data(context, data):
    expected = context.expected_sids[get_datetime().replace(hour=0, minute=0)]
    actual = data.fetcher_assets
    for stk in expected:
        if stk not in actual:
            raise Exception("{stk} is missing".format(stk=stk))

    record(sid_count=len(actual))
    record(bar_count=context.bar_count)
    context.bar_count += 1
        """, sim_params=sim_params, data_frequency="minute"
        )

        self.assertEqual(3, len(results))
        self.assertEqual(3, results["sid_count"].iloc[0])
        self.assertEqual(3, results["sid_count"].iloc[1])
        self.assertEqual(4, results["sid_count"].iloc[2])
コード例 #17
0
    def test_fetcher_universe_minute(self):
        self.responses.add(
            self.responses.GET,
            'https://fake.urls.com/fetcher_universe_data.csv',
            body=FETCHER_UNIVERSE_DATA,
            content_type='text/csv',
        )

        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-09", tz='UTC'),
            end=pd.Timestamp("2006-01-11", tz='UTC'),
            data_frequency="minute")

        results = self.run_algo("""
from pandas import Timestamp
from catalyst.api import fetch_csv, record, get_datetime

def initialize(context):
    fetch_csv(
        'https://fake.urls.com/fetcher_universe_data.csv',
        date_format='%m/%d/%Y'
    )
    context.expected_sids = {
        Timestamp('2006-01-09 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-10 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-11 00:00:00+0000', tz='UTC'):[24, 3766, 5061, 14848]
    }
    context.bar_count = 0

def handle_data(context, data):
    expected = context.expected_sids[get_datetime().replace(hour=0, minute=0)]
    actual = data.fetcher_assets
    for stk in expected:
        if stk not in actual:
            raise Exception("{stk} is missing".format(stk=stk))

    record(sid_count=len(actual))
    record(bar_count=context.bar_count)
    context.bar_count += 1
        """,
                                sim_params=sim_params,
                                data_frequency="minute")

        self.assertEqual(3, len(results))
        self.assertEqual(3, results["sid_count"].iloc[0])
        self.assertEqual(3, results["sid_count"].iloc[1])
        self.assertEqual(4, results["sid_count"].iloc[2])
コード例 #18
0
    def test_algo_with_rl_violation_cumulative(self):
        """
        After adding a new restriction, run long after both knowledge
        dates and verify that a stock from the original restriction set
        is still rejected.
        """
        week_after_start = self.start + timedelta(days=7)
        sim_params = factory.create_simulation_parameters(
            start=week_after_start, num_days=4)

        with security_list_copy():
            # New restriction on AAPL; nothing is deleted.
            add_security_data(['AAPL'], [])
            unchecked_algo = RestrictedAlgoWithoutCheck(
                symbol='BZQ', sim_params=sim_params, env=self.env)
            with self.assertRaises(TradingControlViolation) as ctx:
                unchecked_algo.run(self.data_portal)

            self.check_algo_exception(unchecked_algo, ctx, 0)
コード例 #19
0
    def test_algo_with_rl_violation_after_knowledge_date(self):
        """Trading BZQ a week after ``self.start`` must raise a violation.

        A fresh data portal is built for a five-day window and the unchecked
        algorithm is expected to raise ``TradingControlViolation``.
        """
        week_after_start = self.start + timedelta(days=7)
        sim_params = factory.create_simulation_parameters(
            start=week_after_start, num_days=5)

        portal = create_data_portal(
            self.env.asset_finder,
            self.tempdir,
            sim_params=sim_params,
            sids=range(0, 5),
            trading_calendar=self.trading_calendar,
        )

        unchecked_algo = RestrictedAlgoWithoutCheck(
            symbol='BZQ', sim_params=sim_params, env=self.env)
        with self.assertRaises(TradingControlViolation) as ctx:
            unchecked_algo.run(portal)

        self.check_algo_exception(unchecked_algo, ctx, 0)
コード例 #20
0
    def transaction_sim(self, **params):
        """Utility that simulates orders against a synthetic trade history
        and asserts the resulting transactions.

        Required keys in ``params``:

        - ``trade_count`` / ``trade_interval``: size the minute window when
          ``trade_interval`` is shorter than one day; an interval of a day or
          more selects the daily-data path instead.
        - ``order_count``: number of orders to place.
        - ``order_amount``: amount of each order.
        - ``order_interval``: spacing added to the order date after each
          order.
        - ``expected_txn_count``: expected number of transactions.
        - ``expected_txn_volume``: expected sum of transaction amounts.

        Optional keys:

        - ``alternate``: if truthy, orders alternate in sign.
        - ``complete_fill``: if truthy, every transaction amount must equal
          the amount of the order at the same position.
        - ``default_slippage``: if truthy, ``None`` is passed to ``Blotter``
          as the slippage function instead of ``FixedSlippage()``.
        """
        trade_count = params['trade_count']
        trade_interval = params['trade_interval']
        order_count = params['order_count']
        order_amount = params['order_amount']
        order_interval = params['order_interval']
        expected_txn_count = params['expected_txn_count']
        expected_txn_volume = params['expected_txn_volume']

        # optional parameters
        # ---------------------
        # if present, alternate between long and short sales
        alternate = params.get('alternate')

        # if present, expect transaction amounts to match orders exactly.
        complete_fill = params.get('complete_fill')

        asset1 = self.asset_finder.retrieve_asset(1)
        metadata = make_simple_equity_info([asset1.sid], self.start, self.end)
        with TempDirectory() as tempdir, \
                tmp_trading_env(equities=metadata,
                                load=self.make_load_function()) as env:

            # Sub-daily trade interval: build minute bars and a minute reader.
            if trade_interval < timedelta(days=1):
                sim_params = factory.create_simulation_parameters(
                    start=self.start,
                    end=self.end,
                    data_frequency="minute"
                )

                # Window length: minutes spanned by the trades, plus 100.
                minutes = self.trading_calendar.minutes_window(
                    sim_params.first_open,
                    int((trade_interval.total_seconds() / 60) * trade_count)
                    + 100)

                # Constant price (10.1) and volume (100) for every minute.
                price_data = np.array([10.1] * len(minutes))
                assets = {
                    asset1.sid: pd.DataFrame({
                        "open": price_data,
                        "high": price_data,
                        "low": price_data,
                        "close": price_data,
                        "volume": np.array([100] * len(minutes)),
                        "dt": minutes
                    }).set_index("dt")
                }

                write_bcolz_minute_data(
                    self.trading_calendar,
                    self.trading_calendar.sessions_in_range(
                        self.trading_calendar.minute_to_session_label(
                            minutes[0]
                        ),
                        self.trading_calendar.minute_to_session_label(
                            minutes[-1]
                        )
                    ),
                    tempdir.path,
                    iteritems(assets),
                )

                equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

                data_portal = DataPortal(
                    env.asset_finder, self.trading_calendar,
                    first_trading_day=equity_minute_reader.first_trading_day,
                    minute_reader=equity_minute_reader,
                )
            else:
                # Daily path: default simulation window, one bar per session.
                sim_params = factory.create_simulation_parameters(
                    data_frequency="daily"
                )

                days = sim_params.sessions

                # Keyed by the literal sid 1 (the minute path uses asset1.sid).
                assets = {
                    1: pd.DataFrame({
                        "open": [10.1] * len(days),
                        "high": [10.1] * len(days),
                        "low": [10.1] * len(days),
                        "close": [10.1] * len(days),
                        "volume": [100] * len(days),
                        "day": [day.value for day in days]
                    }, index=days)
                }

                path = os.path.join(tempdir.path, "testdata.bcolz")
                BcolzDailyBarWriter(path, self.trading_calendar, days[0],
                                    days[-1]).write(
                    assets.items()
                )

                equity_daily_reader = BcolzDailyBarReader(path)

                data_portal = DataPortal(
                    env.asset_finder, self.trading_calendar,
                    first_trading_day=equity_daily_reader.first_trading_day,
                    daily_reader=equity_daily_reader,
                )

            # FixedSlippage unless the caller asked for "default_slippage";
            # presumably Blotter substitutes its own default for None --
            # TODO confirm against Blotter.
            if "default_slippage" not in params or \
               not params["default_slippage"]:
                slippage_func = FixedSlippage()
            else:
                slippage_func = None

            blotter = Blotter(sim_params.data_frequency, slippage_func)

            start_date = sim_params.first_open

            # With -1 as the base, alternator ** k flips sign on every order.
            if alternate:
                alternator = -1
            else:
                alternator = 1

            # NOTE(review): the tracker is given self.env, not the temporary
            # `env` created above -- confirm this is intentional.
            tracker = PerformanceTracker(sim_params, self.trading_calendar,
                                         self.env)

            # replicate what tradesim does by going through every minute or day
            # of the simulation and processing open orders each time
            if sim_params.data_frequency == "minute":
                ticks = minutes
            else:
                ticks = days

            transactions = []

            order_list = []
            order_date = start_date
            for tick in ticks:
                blotter.current_dt = tick
                if tick >= order_date and len(order_list) < order_count:
                    # place an order
                    direction = alternator ** len(order_list)
                    order_id = blotter.order(
                        asset1,
                        order_amount * direction,
                        MarketOrder())
                    order_list.append(blotter.orders[order_id])
                    order_date = order_date + order_interval
                    # move after market orders to just after market next
                    # market open.
                    # NOTE(review): `minute >= 00` is always true, so only the
                    # hour check is effective here.
                    if order_date.hour >= 21:
                        if order_date.minute >= 00:
                            order_date = order_date + timedelta(days=1)
                            order_date = order_date.replace(hour=14, minute=30)
                else:
                    # Ticks on which no order is placed: turn open orders
                    # into transactions.
                    # NOTE(review): the lambda closes over the loop variable;
                    # assumed to be called only within this iteration.
                    bar_data = BarData(
                        data_portal=data_portal,
                        simulation_dt_func=lambda: tick,
                        data_frequency=sim_params.data_frequency,
                        trading_calendar=self.trading_calendar,
                        restrictions=NoRestrictions(),
                    )
                    txns, _, closed_orders = blotter.get_transactions(bar_data)
                    for txn in txns:
                        tracker.process_transaction(txn)
                        transactions.append(txn)

                    blotter.prune_orders(closed_orders)

            # Every order targets asset1 with the expected signed amount.
            for i in range(order_count):
                order = order_list[i]
                self.assertEqual(order.asset, asset1)
                self.assertEqual(order.amount, order_amount * alternator ** i)

            if complete_fill:
                self.assertEqual(len(transactions), len(order_list))

            # Sum the filled amounts; with complete_fill, transaction i must
            # match order i exactly.
            total_volume = 0
            for i in range(len(transactions)):
                txn = transactions[i]
                total_volume += txn.amount
                if complete_fill:
                    order = order_list[i]
                    self.assertEqual(order.amount, txn.amount)

            self.assertEqual(total_volume, expected_txn_volume)

            self.assertEqual(len(transactions), expected_txn_count)

            # The tracked position must agree with the net filled volume.
            cumulative_pos = tracker.position_tracker.positions[asset1]
            if total_volume == 0:
                self.assertIsNone(cumulative_pos)
            else:
                self.assertEqual(total_volume, cumulative_pos.amount)

            # the open orders should not contain the asset.
            oo = blotter.open_orders
            self.assertNotIn(
                asset1,
                oo,
                "Entry is removed when no open orders"
            )
コード例 #21
0
def create_test_catalyst(**config):
    """
    Build a test algorithm from ``config`` and return its generator.

    :param config: keyword configuration, read as a dict with:

        - sid_list - a list of asset IDs. Takes precedence over ``sid``.
        - sid - a single asset ID, used when ``sid_list`` is absent.
          One of ``sid_list`` / ``sid`` is required.
        - order_count - the number of orders the test algo will place,
          defaults to 100
        - order_amount - the number of shares per order, defaults to 100
        - concurrent_trades - forwarded as ``concurrent`` to the default
          daily trade source, defaults to False
        - trading_calendar - defaults to ``get_calendar("NYSE")``
        - sim_params - simulation parameters. Defaults to
          ``factory.create_simulation_parameters()`` when the default
          algorithm is built.
        - algorithm - optional parameter providing an algorithm. Defaults
          to a ``TestAlgorithm`` built from the values above.
        - slippage - optional, passed to the default ``TestAlgorithm``.
        - skip_data - if this key is present (whatever its value), no
          trade source is read and no data portal is attached.
        - trade_source - optional iterable of trades (each exposing a
          ``sid`` attribute). If not present,
          ``factory.create_daily_trade_source`` provides the trades.
        - env, tempdir - required unless ``skip_data`` is given; used to
          build the data portal.
        - benchmark_source - assigned to the algorithm's
          ``benchmark_return_source``, defaults to None.

    :return: the result of ``get_generator()`` on the algorithm.
    :raises ValueError: if neither ``sid_list`` nor ``sid`` is given.
    """
    if 'sid_list' in config:
        sid_list = config['sid_list']
    elif 'sid' in config:
        sid_list = [config['sid']]
    else:
        raise ValueError("simfactory create_test_catalyst() requires "
                         "argument 'sid_list' or 'sid'")

    concurrent_trades = config.get('concurrent_trades', False)
    order_count = config.get('order_count', 100)
    order_amount = config.get('order_amount', 100)

    # Only build the default calendar when the caller did not supply one.
    if 'trading_calendar' in config:
        trading_calendar = config['trading_calendar']
    else:
        trading_calendar = get_calendar("NYSE")

    # -------------------
    # Create the Algo
    # -------------------
    if 'algorithm' in config:
        test_algo = config['algorithm']
    else:
        # Only build default simulation parameters when none were given.
        if 'sim_params' in config:
            algo_sim_params = config['sim_params']
        else:
            algo_sim_params = factory.create_simulation_parameters()

        test_algo = TestAlgorithm(sid_list[0],
                                  order_amount,
                                  order_count,
                                  sim_params=algo_sim_params,
                                  trading_calendar=trading_calendar,
                                  slippage=config.get('slippage'),
                                  identifiers=sid_list)

    # -------------------
    # Trade Source
    # -------------------
    if 'skip_data' not in config:
        if 'trade_source' in config:
            trade_source = config['trade_source']
        else:
            trade_source = factory.create_daily_trade_source(
                sid_list,
                test_algo.sim_params,
                test_algo.trading_environment,
                trading_calendar,
                concurrent=concurrent_trades,
            )

        # Group the trades by asset ID, preserving their order.
        trades_by_sid = {}
        for trade in trade_source:
            trades_by_sid.setdefault(trade.sid, []).append(trade)

        # Fall back to the algorithm's own simulation parameters (the ones
        # the default trade source above uses) when the caller gave none,
        # instead of failing with a KeyError.
        if 'sim_params' in config:
            portal_sim_params = config['sim_params']
        else:
            portal_sim_params = test_algo.sim_params

        data_portal = create_data_portal_from_trade_history(
            config['env'].asset_finder, trading_calendar, config['tempdir'],
            portal_sim_params, trades_by_sid)

        test_algo.data_portal = data_portal

    # -------------------
    # Benchmark source
    # -------------------

    test_algo.benchmark_return_source = config.get('benchmark_source', None)

    # ------------------
    # generator/simulator
    sim = test_algo.get_generator()

    return sim
コード例 #22
0
    def test_minutely_fetcher(self):
        self.responses.add(
            self.responses.GET,
            'https://fake.urls.com/aapl_minute_csv_data.csv',
            body=AAPL_MINUTE_CSV_DATA,
            content_type='text/csv',
        )

        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-03", tz='UTC'),
            end=pd.Timestamp("2006-01-10", tz='UTC'),
            emission_rate="minute",
            data_frequency="minute")

        test_algo = TradingAlgorithm(script="""
from catalyst.api import fetch_csv, record, sid

def initialize(context):
    fetch_csv('https://fake.urls.com/aapl_minute_csv_data.csv')

def handle_data(context, data):
    record(aapl_signal=data.current(sid(24), "signal"))
""",
                                     sim_params=sim_params,
                                     data_frequency="minute",
                                     env=self.env)

        # manually setting data portal and getting generator because we need
        # the minutely emission packets here.  TradingAlgorithm.run() only
        # returns daily packets.
        test_algo.data_portal = FetcherDataPortal(self.env,
                                                  self.trading_calendar)
        gen = test_algo.get_generator()
        perf_packets = list(gen)

        signal = [
            result["minute_perf"]["recorded_vars"]["aapl_signal"]
            for result in perf_packets if "minute_perf" in result
        ]

        self.assertEqual(6 * 390, len(signal))

        # csv data is:
        # symbol,date,signal
        # aapl,1/4/06 5:31AM, 1
        # aapl,1/4/06 11:30AM, 2
        # aapl,1/5/06 5:31AM, 1
        # aapl,1/5/06 11:30AM, 3
        # aapl,1/9/06 5:31AM, 1
        # aapl,1/9/06 11:30AM, 4 for dates 1/3 to 1/10

        # 2 signals per day, only last signal is taken. So we expect
        # 390 bars of signal NaN on 1/3
        # 390 bars of signal 2 on 1/4
        # 390 bars of signal 3 on 1/5
        # 390 bars of signal 3 on 1/6 (forward filled)
        # 390 bars of signal 4 on 1/9
        # 390 bars of signal 4 on 1/9 (forward filled)

        np.testing.assert_array_equal([np.NaN] * 390, signal[0:390])
        np.testing.assert_array_equal([2] * 390, signal[390:780])
        np.testing.assert_array_equal([3] * 780, signal[780:1560])
        np.testing.assert_array_equal([4] * 780, signal[1560:])
コード例 #23
0
ファイル: run_algo.py プロジェクト: dyangelo-grullon/catalyst
def _build_algo_and_data(handle_data, initialize, before_trading_start,
                         analyze, algofile, algotext, defines, data_frequency,
                         capital_base, data, bundle, bundle_timestamp, start,
                         end, output, print_algo, local_namespace, environ,
                         live, exchange, algo_namespace, base_currency,
                         live_graph, analyze_live, simulate_orders,
                         stats_output):
    """Assemble the pieces needed to run an algorithm and delegate.

    Builds the namespace, the exchanges, the trading environment, the
    pipeline-loader chooser, the simulation parameters and the algorithm
    keyword arguments, then hands everything to
    ``_build_live_algo_and_data`` (when ``live`` is truthy) or
    ``_build_backtest_algo_and_data`` and returns that call's result.

    The algorithm is given either as source (``algotext``, or ``algofile``
    which is read when ``algotext`` is None) or as the callables
    ``initialize`` / ``handle_data`` / ``before_trading_start`` /
    ``analyze``. ``data`` and ``output`` are accepted but not used in this
    function.
    """
    namespace = _build_namespace(algotext, local_namespace, defines)

    # Read the algorithm from the file only when no source text was given.
    # (Previously the condition was inverted: supplied text was overwritten
    # by ``algofile.read()`` and a file-only algorithm was never read.)
    if algotext is None and algofile is not None:
        algotext = algofile.read()

    if print_algo:
        _pretty_print_code(algotext)

    mode = _mode(simulate_orders, live)
    log.info('running algo in {mode} mode'.format(mode=mode))

    exchanges = _build_exchanges_dict(exchange, live, simulate_orders,
                                      base_currency)

    open_calendar = get_calendar('OPEN')

    env = TradingEnvironment(
        load=partial(load_crypto_market_data,
                     environ=environ,
                     start_dt=start,
                     end_dt=end),
        environ=environ,
        exchange_tz='UTC',
        asset_db_path=None)  # We don't need an asset db, we have exchanges

    env.asset_finder = ExchangeAssetFinder(exchanges=exchanges)

    # Bound to the caller's data_frequency: the live override below happens
    # after this line and therefore does not affect the loader.
    choose_loader = partial(_choose_loader, data_frequency)

    if live:
        start, end = _get_live_time_range()
        # TODO double check if this is the desired behavior
        data_frequency = 'minute'

    sim_params = create_simulation_parameters(start=start,
                                              end=end,
                                              capital_base=capital_base,
                                              emission_rate=data_frequency,
                                              data_frequency=data_frequency)

    if algotext is None:
        # No source text at all: the algorithm is given as callables.
        algorithm_class_kwargs = {
            'initialize': initialize,
            'handle_data': handle_data,
            'before_trading_start': before_trading_start,
            'analyze': analyze
        }
    else:
        algorithm_class_kwargs = {
            'algo_filename': getattr(algofile, 'name', '<algorithm>'),
            'script': algotext
        }

    if live:
        return _build_live_algo_and_data(
            sim_params, exchanges, env, open_calendar, simulate_orders,
            algo_namespace, capital_base, live_graph, stats_output,
            analyze_live, base_currency, namespace, choose_loader,
            algorithm_class_kwargs)
    else:
        return _build_backtest_algo_and_data(exchanges, bundle, env, environ,
                                             bundle_timestamp, open_calendar,
                                             start, end, namespace,
                                             choose_loader, sim_params,
                                             algorithm_class_kwargs)
コード例 #24
0
def _run(handle_data,
         initialize,
         before_trading_start,
         analyze,
         algofile,
         algotext,
         defines,
         data_frequency,
         capital_base,
         data,
         bundle,
         bundle_timestamp,
         start,
         end,
         output,
         print_algo,
         local_namespace,
         environ,
         live,
         exchange,
         algo_namespace,
         quote_currency,
         live_graph,
         analyze_live,
         simulate_orders,
         auth_aliases,
         stats_output):
    """Run a backtest for the given algorithm.

    This is shared between the cli and :func:`catalyst.run_algo`.
    """
    # TODO: refactor for more granularity
    if algotext is not None:
        if local_namespace:
            ip = get_ipython()  # noqa
            namespace = ip.user_ns
        else:
            namespace = {}

        for assign in defines:
            try:
                name, value = assign.split('=', 2)
            except ValueError:
                raise ValueError(
                    'invalid define %r, should be of the form name=value' %
                    assign,
                )
            try:
                # evaluate in the same namespace so names may refer to
                # eachother
                namespace[name] = eval(value, namespace)
            except Exception as e:
                raise ValueError(
                    'failed to execute definition for name %r: %s' % (name, e),
                )
    elif defines:
        raise _RunAlgoError(
            'cannot pass define without `algotext`',
            "cannot pass '-D' / '--define' without '-t' / '--algotext'",
        )
    else:
        namespace = {}
        if algofile is not None:
            algotext = algofile.read()

    if print_algo:
        if PYGMENTS:
            highlight(
                algotext,
                PythonLexer(),
                TerminalFormatter(),
                outfile=sys.stdout,
            )
        else:
            click.echo(algotext)

    log.info('Catalyst version {}'.format(catalyst.__version__))
    if not DISABLE_ALPHA_WARNING:
        log.warn(ALPHA_WARNING_MESSAGE)
        # sleep(3)

    if live:
        if simulate_orders:
            mode = 'paper-trading'
        else:
            mode = 'live-trading'
    else:
        mode = 'backtest'

    log.info('running algo in {mode} mode'.format(mode=mode))

    exchange_name = exchange
    if exchange_name is None:
        raise ValueError('Please specify at least one exchange.')

    if isinstance(auth_aliases, string_types):
        aliases = auth_aliases.split(',')
        if len(aliases) < 2 or len(aliases) % 2 != 0:
            raise ValueError(
                'the `auth_aliases` parameter must contain an even list '
                'of comma-delimited values. For example, '
                '"binance,auth2" or "binance,auth2,bittrex,auth2".'
            )

        auth_aliases = dict(zip(aliases[::2], aliases[1::2]))

    exchange_list = [x.strip().lower() for x in exchange.split(',')]
    exchanges = dict()
    for name in exchange_list:
        if auth_aliases is not None and name in auth_aliases:
            auth_alias = auth_aliases[name]
        else:
            auth_alias = None

        exchanges[name] = get_exchange(
            exchange_name=name,
            quote_currency=quote_currency,
            must_authenticate=(live and not simulate_orders),
            skip_init=True,
            auth_alias=auth_alias,
        )

    open_calendar = get_calendar('OPEN')

    env = TradingEnvironment(
        load=partial(
            load_crypto_market_data,
            environ=environ,
            start_dt=start,
            end_dt=end
        ),
        environ=environ,
        exchange_tz='UTC',
        asset_db_path=None  # We don't need an asset db, we have exchanges
    )
    env.asset_finder = ExchangeAssetFinder(exchanges=exchanges)

    def choose_loader(column):
        bound_cols = TradingPairPricing.columns
        if column in bound_cols:
            return ExchangePricingLoader(data_frequency)
        raise ValueError(
            "No PipelineLoader registered for column %s." % column
        )

    if live:
        # TODO: fix the start data.
        # is_start checks if a start date was specified by user
        # needed for live clock
        is_start = True

        if start is None:
            start = pd.Timestamp.utcnow()
            is_start = False
        elif start:
            assert pd.Timestamp.utcnow() <= start, \
                "specified start date is in the past."
        elif start and end:
            assert start < end, "start date is later than end date."

        # TODO: fix the end data.
        # is_end checks if an end date was specified by user
        # needed for live clock
        is_end = True

        if end is None:
            end = start + timedelta(hours=8760)
            is_end = False

        data = DataPortalExchangeLive(
            exchanges=exchanges,
            asset_finder=env.asset_finder,
            trading_calendar=open_calendar,
            first_trading_day=pd.to_datetime('today', utc=True)
        )

        sim_params = create_simulation_parameters(
            start=start,
            end=end,
            capital_base=capital_base,
            emission_rate='minute',
            data_frequency='minute'
        )

        # TODO: use the constructor instead
        sim_params._arena = 'live'

        algorithm_class = partial(
            ExchangeTradingAlgorithmLive,
            exchanges=exchanges,
            algo_namespace=algo_namespace,
            live_graph=live_graph,
            simulate_orders=simulate_orders,
            stats_output=stats_output,
            analyze_live=analyze_live,
            start=start,
            is_start=is_start,
            end=end,
            is_end=is_end,
        )
    elif exchanges:
        # Removed the existing Poloniex fork to keep things simple
        # We can add back the complexity if required.

        # I don't think that we should have arbitrary price data bundles
        # Instead, we should center this data around exchanges.
        # We still need to support bundles for other misc data, but we
        # can handle this later.

        if (start and start != pd.tslib.normalize_date(start)) or \
                (end and end != pd.tslib.normalize_date(end)):
            # todo: add to Sim_Params the option to
            # start & end at specific times
            log.warn(
                "Catalyst currently starts and ends on the start and "
                "end of the dates specified, respectively. We hope to "
                "Modify this and support specific times in a future release."
            )

        data = DataPortalExchangeBacktest(
            exchange_names=[ex_name for ex_name in exchanges],
            asset_finder=None,
            trading_calendar=open_calendar,
            first_trading_day=start,
            last_available_session=end
        )

        sim_params = create_simulation_parameters(
            start=start,
            end=end,
            capital_base=capital_base,
            data_frequency=data_frequency,
            emission_rate=data_frequency,
        )

        algorithm_class = partial(
            ExchangeTradingAlgorithmBacktest,
            exchanges=exchanges
        )

    elif bundle is not None:
        bundle_data = load(
            bundle,
            environ,
            bundle_timestamp,
        )

        prefix, connstr = re.split(
            r'sqlite:///',
            str(bundle_data.asset_finder.engine.url),
            maxsplit=1,
        )
        if prefix:
            raise ValueError(
                "invalid url %r, must begin with 'sqlite:///'" %
                str(bundle_data.asset_finder.engine.url),
            )

        env = TradingEnvironment(asset_db_path=connstr, environ=environ)
        first_trading_day = \
            bundle_data.equity_minute_bar_reader.first_trading_day

        data = DataPortal(
            env.asset_finder, open_calendar,
            first_trading_day=first_trading_day,
            equity_minute_reader=bundle_data.equity_minute_bar_reader,
            equity_daily_reader=bundle_data.equity_daily_bar_reader,
            adjustment_reader=bundle_data.adjustment_reader,
        )

    perf = algorithm_class(
        namespace=namespace,
        env=env,
        get_pipeline_loader=choose_loader,
        sim_params=sim_params,
        **{
            'initialize': initialize,
            'handle_data': handle_data,
            'before_trading_start': before_trading_start,
            'analyze': analyze,
        } if algotext is None else {
            'algo_filename': getattr(algofile, 'name', '<algorithm>'),
            'script': algotext,
        }
    ).run(
        data,
        overwrite_sim_params=False,
    )

    if output == '-':
        click.echo(str(perf))
    elif output != os.devnull:  # make the catalyst magic not write any data
        perf.to_pickle(output)

    return perf
コード例 #25
0
def _run(handle_data, initialize, before_trading_start, analyze, algofile,
         algotext, defines, data_frequency, capital_base, data, bundle,
         bundle_timestamp, start, end, output, print_algo, local_namespace,
         environ):
    """Run a backtest for the given algorithm.

    This is shared between the cli and :func:`catalyst.run_algo`.
    """
    if algotext is not None:
        if local_namespace:
            ip = get_ipython()  # noqa
            namespace = ip.user_ns
        else:
            namespace = {}

        for assign in defines:
            try:
                name, value = assign.split('=', 2)
            except ValueError:
                raise ValueError(
                    'invalid define %r, should be of the form name=value' %
                    assign, )
            try:
                # evaluate in the same namespace so names may refer to
                # eachother
                namespace[name] = eval(value, namespace)
            except Exception as e:
                raise ValueError(
                    'failed to execute definition for name %r: %s' %
                    (name, e), )
    elif defines:
        raise _RunAlgoError(
            'cannot pass define without `algotext`',
            "cannot pass '-D' / '--define' without '-t' / '--algotext'",
        )
    else:
        namespace = {}
        if algofile is not None:
            algotext = algofile.read()

    if print_algo:
        if PYGMENTS:
            highlight(
                algotext,
                PythonLexer(),
                TerminalFormatter(),
                outfile=sys.stdout,
            )
        else:
            click.echo(algotext)

    if bundle is not None:
        bundles = bundle.split(',')

        def get_trading_env_and_data(bundles):
            env = data = None

            b = 'poloniex'
            if len(bundles) == 0:
                return env, data
            elif len(bundles) == 1:
                b = bundles[0]

            bundle_data = load(
                b,
                environ,
                bundle_timestamp,
            )

            prefix, connstr = re.split(
                r'sqlite:///',
                str(bundle_data.asset_finder.engine.url),
                maxsplit=1,
            )
            if prefix:
                raise ValueError(
                    "invalid url %r, must begin with 'sqlite:///'" %
                    str(bundle_data.asset_finder.engine.url), )

            open_calendar = get_calendar('OPEN')

            env = TradingEnvironment(
                load=partial(load_crypto_market_data, environ=environ),
                bm_symbol='USDT_BTC',
                trading_calendar=open_calendar,
                asset_db_path=connstr,
                environ=environ,
            )

            first_trading_day = bundle_data.minute_bar_reader.first_trading_day

            data = DataPortal(
                env.asset_finder,
                open_calendar,
                first_trading_day=first_trading_day,
                minute_reader=bundle_data.minute_bar_reader,
                five_minute_reader=bundle_data.five_minute_bar_reader,
                daily_reader=bundle_data.daily_bar_reader,
                adjustment_reader=bundle_data.adjustment_reader,
            )

            return env, data

        def get_loader_for_bundle(b):
            bundle_data = load(
                b,
                environ,
                bundle_timestamp,
            )

            if b == 'poloniex':
                return CryptoPricingLoader(
                    bundle_data,
                    data_frequency,
                    CryptoPricing,
                )
            elif b == 'quandl':
                return USEquityPricingLoader(
                    bundle_data,
                    data_frequency,
                    USEquityPricing,
                )
            raise ValueError("No PipelineLoader registered for bundle %s." % b)

        loaders = [get_loader_for_bundle(b) for b in bundles]
        env, data = get_trading_env_and_data(bundles)

        def choose_loader(column):
            for loader in loaders:
                if column in loader.columns:
                    return loader
            raise ValueError("No PipelineLoader registered for column %s." %
                             column)

    else:
        env = TradingEnvironment(environ=environ)
        choose_loader = None

    perf = TradingAlgorithm(
        namespace=namespace,
        env=env,
        get_pipeline_loader=choose_loader,
        sim_params=create_simulation_parameters(
            start=start,
            end=end,
            capital_base=capital_base,
            data_frequency=data_frequency,
            emission_rate=data_frequency,
        ),
        **{
            'initialize': initialize,
            'handle_data': handle_data,
            'before_trading_start': before_trading_start,
            'analyze': analyze,
        } if algotext is None else {
            'algo_filename': getattr(algofile, 'name', '<algorithm>'),
            'script': algotext,
        }).run(
            data,
            overwrite_sim_params=False,
        )

    if output == '-':
        click.echo(str(perf))
    elif output != os.devnull:  # make the catalyst magic not write any data
        perf.to_pickle(output)

    return perf
コード例 #26
0
ファイル: run_algo.py プロジェクト: sompalarya1/catalyst
def _run(handle_data, initialize, before_trading_start, analyze, algofile,
         algotext, defines, data_frequency, capital_base, data, bundle,
         bundle_timestamp, start, end, output, print_algo, local_namespace,
         environ, live, exchange, algo_namespace, base_currency, live_graph,
         analyze_live, simulate_orders, stats_output):
    """Run a backtest for the given algorithm.

    This is shared between the cli and :func:`catalyst.run_algo`.
    """
    if algotext is not None:
        if local_namespace:
            ip = get_ipython()  # noqa
            namespace = ip.user_ns
        else:
            namespace = {}

        for assign in defines:
            try:
                name, value = assign.split('=', 2)
            except ValueError:
                raise ValueError(
                    'invalid define %r, should be of the form name=value' %
                    assign, )
            try:
                # evaluate in the same namespace so names may refer to
                # eachother
                namespace[name] = eval(value, namespace)
            except Exception as e:
                raise ValueError(
                    'failed to execute definition for name %r: %s' %
                    (name, e), )
    elif defines:
        raise _RunAlgoError(
            'cannot pass define without `algotext`',
            "cannot pass '-D' / '--define' without '-t' / '--algotext'",
        )
    else:
        namespace = {}
        if algofile is not None:
            algotext = algofile.read()

    if print_algo:
        if PYGMENTS:
            highlight(
                algotext,
                PythonLexer(),
                TerminalFormatter(),
                outfile=sys.stdout,
            )
        else:
            click.echo(algotext)

    mode = 'paper-trading' if simulate_orders else 'live-trading' \
        if live else 'backtest'
    log.info('running algo in {mode} mode'.format(mode=mode))

    exchange_name = exchange
    if exchange_name is None:
        raise ValueError('Please specify at least one exchange.')

    exchange_list = [x.strip().lower() for x in exchange.split(',')]

    exchanges = dict()
    for exchange_name in exchange_list:
        exchanges[exchange_name] = get_exchange(
            exchange_name=exchange_name,
            base_currency=base_currency,
            must_authenticate=(live and not simulate_orders),
            skip_init=True,
        )

    open_calendar = get_calendar('OPEN')

    env = TradingEnvironment(
        load=partial(load_crypto_market_data,
                     environ=environ,
                     start_dt=start,
                     end_dt=end),
        environ=environ,
        exchange_tz='UTC',
        asset_db_path=None  # We don't need an asset db, we have exchanges
    )
    env.asset_finder = ExchangeAssetFinder(exchanges=exchanges)

    def choose_loader(column):
        bound_cols = TradingPairPricing.columns
        if column in bound_cols:
            return ExchangePricingLoader(data_frequency)
        raise ValueError("No PipelineLoader registered for column %s." %
                         column)

    if live:
        start = pd.Timestamp.utcnow()

        # TODO: fix the end data.
        end = start + timedelta(hours=8760)

        data = DataPortalExchangeLive(exchanges=exchanges,
                                      asset_finder=env.asset_finder,
                                      trading_calendar=open_calendar,
                                      first_trading_day=pd.to_datetime(
                                          'today', utc=True))

        def fetch_capital_base(exchange, attempt_index=0):
            """
            Fetch the base currency amount required to bootstrap
            the algorithm against the exchange.

            The algorithm cannot continue without this value.

            :param exchange: the targeted exchange
            :param attempt_index:
            :return capital_base: the amount of base currency available for
            trading
            """
            try:
                log.debug('retrieving capital base in {} to bootstrap '
                          'exchange {}'.format(base_currency, exchange_name))
                balances = exchange.get_balances()
            except ExchangeRequestError as e:
                if attempt_index < 20:
                    log.warn('could not retrieve balances on {}: {}'.format(
                        exchange.name, e))
                    sleep(5)
                    return fetch_capital_base(exchange, attempt_index + 1)

                else:
                    raise ExchangeRequestErrorTooManyAttempts(
                        attempts=attempt_index, error=e)

            if base_currency in balances:
                base_currency_available = balances[base_currency]['free']
                log.info(
                    'base currency available in the account: {} {}'.format(
                        base_currency_available, base_currency))

                return base_currency_available
            else:
                raise BaseCurrencyNotFoundError(base_currency=base_currency,
                                                exchange=exchange_name)

        if not simulate_orders:
            for exchange_name in exchanges:
                exchange = exchanges[exchange_name]
                balance = fetch_capital_base(exchange)

                if balance < capital_base:
                    raise NotEnoughCapitalError(
                        exchange=exchange_name,
                        base_currency=base_currency,
                        balance=balance,
                        capital_base=capital_base,
                    )

        sim_params = create_simulation_parameters(start=start,
                                                  end=end,
                                                  capital_base=capital_base,
                                                  emission_rate='minute',
                                                  data_frequency='minute')

        # TODO: use the constructor instead
        sim_params._arena = 'live'

        algorithm_class = partial(
            ExchangeTradingAlgorithmLive,
            exchanges=exchanges,
            algo_namespace=algo_namespace,
            live_graph=live_graph,
            simulate_orders=simulate_orders,
            stats_output=stats_output,
            analyze_live=analyze_live,
        )
    elif exchanges:
        # Removed the existing Poloniex fork to keep things simple
        # We can add back the complexity if required.

        # I don't think that we should have arbitrary price data bundles
        # Instead, we should center this data around exchanges.
        # We still need to support bundles for other misc data, but we
        # can handle this later.

        data = DataPortalExchangeBacktest(
            exchange_names=[exchange_name for exchange_name in exchanges],
            asset_finder=None,
            trading_calendar=open_calendar,
            first_trading_day=start,
            last_available_session=end)

        sim_params = create_simulation_parameters(
            start=start,
            end=end,
            capital_base=capital_base,
            data_frequency=data_frequency,
            emission_rate=data_frequency,
        )

        algorithm_class = partial(ExchangeTradingAlgorithmBacktest,
                                  exchanges=exchanges)

    elif bundle is not None:
        bundle_data = load(
            bundle,
            environ,
            bundle_timestamp,
        )

        prefix, connstr = re.split(
            r'sqlite:///',
            str(bundle_data.asset_finder.engine.url),
            maxsplit=1,
        )
        if prefix:
            raise ValueError(
                "invalid url %r, must begin with 'sqlite:///'" %
                str(bundle_data.asset_finder.engine.url), )

        env = TradingEnvironment(asset_db_path=connstr, environ=environ)
        first_trading_day = \
            bundle_data.equity_minute_bar_reader.first_trading_day

        data = DataPortal(
            env.asset_finder,
            open_calendar,
            first_trading_day=first_trading_day,
            equity_minute_reader=bundle_data.equity_minute_bar_reader,
            equity_daily_reader=bundle_data.equity_daily_bar_reader,
            adjustment_reader=bundle_data.adjustment_reader,
        )

    perf = algorithm_class(
        namespace=namespace,
        env=env,
        get_pipeline_loader=choose_loader,
        sim_params=sim_params,
        **{
            'initialize': initialize,
            'handle_data': handle_data,
            'before_trading_start': before_trading_start,
            'analyze': analyze,
        } if algotext is None else {
            'algo_filename': getattr(algofile, 'name', '<algorithm>'),
            'script': algotext,
        }).run(
            data,
            overwrite_sim_params=False,
        )

    if output == '-':
        click.echo(str(perf))
    elif output != os.devnull:  # make the catalyst magic not write any data
        perf.to_pickle(output)

    return perf
コード例 #27
0
ファイル: run_algo.py プロジェクト: halferemite/catalyst
def _run(handle_data, initialize, before_trading_start, analyze, algofile,
         algotext, defines, data_frequency, capital_base, data, bundle,
         bundle_timestamp, start, end, output, print_algo, local_namespace,
         environ, live, exchange, algo_namespace, base_currency, live_graph):
    """Run a backtest for the given algorithm.

    This is shared between the cli and :func:`catalyst.run_algo`.
    """
    if algotext is not None:
        if local_namespace:
            ip = get_ipython()  # noqa
            namespace = ip.user_ns
        else:
            namespace = {}

        for assign in defines:
            try:
                name, value = assign.split('=', 2)
            except ValueError:
                raise ValueError(
                    'invalid define %r, should be of the form name=value' %
                    assign, )
            try:
                # evaluate in the same namespace so names may refer to
                # eachother
                namespace[name] = eval(value, namespace)
            except Exception as e:
                raise ValueError(
                    'failed to execute definition for name %r: %s' %
                    (name, e), )
    elif defines:
        raise _RunAlgoError(
            'cannot pass define without `algotext`',
            "cannot pass '-D' / '--define' without '-t' / '--algotext'",
        )
    else:
        namespace = {}
        if algofile is not None:
            algotext = algofile.read()

    if print_algo:
        if PYGMENTS:
            highlight(
                algotext,
                PythonLexer(),
                TerminalFormatter(),
                outfile=sys.stdout,
            )
        else:
            click.echo(algotext)

    mode = 'live' if live else 'backtest'
    log.info('running algo in {mode} mode'.format(mode=mode))

    exchange_name = exchange
    if exchange_name is None:
        raise ValueError('Please specify at least one exchange.')

    exchange_list = [x.strip().lower() for x in exchange.split(',')]

    exchanges = dict()
    for exchange_name in exchange_list:

        # Looking for the portfolio from the cache first
        portfolio = get_algo_object(algo_name=algo_namespace,
                                    key='portfolio_{}'.format(exchange_name),
                                    environ=environ)

        if portfolio is None:
            portfolio = ExchangePortfolio(start_date=pd.Timestamp.utcnow())

        # This corresponds to the json file containing api token info
        exchange_auth = get_exchange_auth(exchange_name)

        if live and (exchange_auth['key'] == ''
                     or exchange_auth['secret'] == ''):
            raise ExchangeAuthEmpty(exchange=exchange_name.title(),
                                    filename=os.path.join(
                                        get_exchange_folder(
                                            exchange_name, environ),
                                        'auth.json'))

        if exchange_name == 'bitfinex':
            exchanges[exchange_name] = Bitfinex(key=exchange_auth['key'],
                                                secret=exchange_auth['secret'],
                                                base_currency=base_currency,
                                                portfolio=portfolio)
        elif exchange_name == 'bittrex':
            exchanges[exchange_name] = Bittrex(key=exchange_auth['key'],
                                               secret=exchange_auth['secret'],
                                               base_currency=base_currency,
                                               portfolio=portfolio)
        elif exchange_name == 'poloniex':
            exchanges[exchange_name] = Poloniex(key=exchange_auth['key'],
                                                secret=exchange_auth['secret'],
                                                base_currency=base_currency,
                                                portfolio=portfolio)
        else:
            raise ExchangeNotFoundError(exchange_name=exchange_name)

    open_calendar = get_calendar('OPEN')

    env = TradingEnvironment(
        load=partial(load_crypto_market_data,
                     environ=environ,
                     start_dt=start,
                     end_dt=end),
        environ=environ,
        exchange_tz='UTC',
        asset_db_path=None  # We don't need an asset db, we have exchanges
    )
    env.asset_finder = AssetFinderExchange()
    choose_loader = None  # TODO: use the DataPortal for in the algorithm class for this

    if live:
        start = pd.Timestamp.utcnow()

        # TODO: fix the end data.
        end = start + timedelta(hours=8760)

        data = DataPortalExchangeLive(exchanges=exchanges,
                                      asset_finder=env.asset_finder,
                                      trading_calendar=open_calendar,
                                      first_trading_day=pd.to_datetime(
                                          'today', utc=True))

        def fetch_capital_base(exchange, attempt_index=0):
            """
            Fetch the base currency amount required to bootstrap
            the algorithm against the exchange.

            The algorithm cannot continue without this value.

            :param exchange: the targeted exchange
            :param attempt_index:
            :return capital_base: the amount of base currency available for
            trading
            """
            try:
                log.debug('retrieving capital base in {} to bootstrap '
                          'exchange {}'.format(base_currency, exchange_name))
                balances = exchange.get_balances()
            except ExchangeRequestError as e:
                if attempt_index < 20:
                    log.warn('could not retrieve balances on {}: {}'.format(
                        exchange.name, e))
                    sleep(5)
                    return fetch_capital_base(exchange, attempt_index + 1)

                else:
                    raise ExchangeRequestErrorTooManyAttempts(
                        attempts=attempt_index, error=e)

            if base_currency in balances:
                return balances[base_currency]
            else:
                raise BaseCurrencyNotFoundError(base_currency=base_currency,
                                                exchange=exchange_name)

        capital_base = 0
        for exchange_name in exchanges:
            exchange = exchanges[exchange_name]
            capital_base += fetch_capital_base(exchange)

        sim_params = create_simulation_parameters(start=start,
                                                  end=end,
                                                  capital_base=capital_base,
                                                  emission_rate='minute',
                                                  data_frequency='minute')

        # TODO: use the constructor instead
        sim_params._arena = 'live'

        algorithm_class = partial(ExchangeTradingAlgorithmLive,
                                  exchanges=exchanges,
                                  algo_namespace=algo_namespace,
                                  live_graph=live_graph)
    else:
        # Removed the existing Poloniex fork to keep things simple
        # We can add back the complexity if required.

        # I don't think that we should have arbitrary price data bundles
        # Instead, we should center this data around exchanges.
        # We still need to support bundles for other misc data, but we
        # can handle this later.

        data = DataPortalExchangeBacktest(exchanges=exchanges,
                                          asset_finder=None,
                                          trading_calendar=open_calendar,
                                          first_trading_day=start,
                                          last_available_session=end)

        sim_params = create_simulation_parameters(
            start=start,
            end=end,
            capital_base=capital_base,
            data_frequency=data_frequency,
            emission_rate=data_frequency,
        )

        algorithm_class = partial(ExchangeTradingAlgorithmBacktest,
                                  exchanges=exchanges)

    perf = algorithm_class(
        namespace=namespace,
        env=env,
        get_pipeline_loader=choose_loader,
        sim_params=sim_params,
        **{
            'initialize': initialize,
            'handle_data': handle_data,
            'before_trading_start': before_trading_start,
            'analyze': analyze,
        } if algotext is None else {
            'algo_filename': getattr(algofile, 'name', '<algorithm>'),
            'script': algotext,
        }).run(
            data,
            overwrite_sim_params=False,
        )

    if output == '-':
        click.echo(str(perf))
    elif output != os.devnull:  # make the catalyst magic not write any data
        perf.to_pickle(output)

    return perf
コード例 #28
0
ファイル: test_finance.py プロジェクト: zhoukalex/catalyst
    def transaction_sim(self, **params):
        """Simulate turning orders into transactions and assert the outcome.

        Constant-price bar data (price 10.1, volume 100) is written to a
        temporary directory, either per minute (when ``trade_interval`` is
        shorter than one day) or per day. Market orders are then placed
        while stepping through the bars and the resulting transactions are
        checked against the expectations given in ``params``.

        Required keys: ``trade_count``, ``trade_interval``, ``order_count``,
        ``order_amount``, ``order_interval``, ``expected_txn_count`` and
        ``expected_txn_volume``.

        Optional keys: ``alternate`` (alternate the sign of successive
        orders), ``complete_fill`` (every transaction amount must equal its
        order amount) and ``default_slippage`` (when truthy, no slippage
        model is passed to the blotter; otherwise ``FixedSlippage`` is).
        """
        # Mandatory settings, a missing key raises KeyError.
        trade_count = params['trade_count']
        trade_interval = params['trade_interval']
        order_count = params['order_count']
        order_amount = params['order_amount']
        order_interval = params['order_interval']
        expected_txn_count = params['expected_txn_count']
        expected_txn_volume = params['expected_txn_volume']

        # Optional settings, None when absent.
        alternate = params.get('alternate')
        complete_fill = params.get('complete_fill')

        asset1 = self.asset_finder.retrieve_asset(1)
        metadata = make_simple_equity_info([asset1.sid], self.start, self.end)
        with TempDirectory() as tempdir, \
                tmp_trading_env(equities=metadata,
                                load=self.make_load_function()) as env:

            if trade_interval < timedelta(days=1):
                # Intraday trades: build and store minute bars.
                sim_params = factory.create_simulation_parameters(
                    start=self.start,
                    end=self.end,
                    data_frequency="minute"
                )

                window_length = int(
                    (trade_interval.total_seconds() / 60) * trade_count
                ) + 100
                minutes = self.trading_calendar.minutes_window(
                    sim_params.first_open, window_length)

                bar_count = len(minutes)
                price_data = np.array([10.1] * bar_count)
                minute_frame = pd.DataFrame({
                    "open": price_data,
                    "high": price_data,
                    "low": price_data,
                    "close": price_data,
                    "volume": np.array([100] * bar_count),
                    "dt": minutes
                }).set_index("dt")
                assets = {asset1.sid: minute_frame}

                first_session = \
                    self.trading_calendar.minute_to_session_label(minutes[0])
                last_session = \
                    self.trading_calendar.minute_to_session_label(minutes[-1])
                write_bcolz_minute_data(
                    self.trading_calendar,
                    self.trading_calendar.sessions_in_range(first_session,
                                                            last_session),
                    tempdir.path,
                    iteritems(assets),
                )

                equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

                data_portal = DataPortal(
                    env.asset_finder, self.trading_calendar,
                    first_trading_day=equity_minute_reader.first_trading_day,
                    equity_minute_reader=equity_minute_reader,
                )
            else:
                # Daily trades: build and store one bar per session.
                sim_params = factory.create_simulation_parameters(
                    data_frequency="daily"
                )

                days = sim_params.sessions
                session_count = len(days)

                daily_frame = pd.DataFrame({
                    "open": [10.1] * session_count,
                    "high": [10.1] * session_count,
                    "low": [10.1] * session_count,
                    "close": [10.1] * session_count,
                    "volume": [100] * session_count,
                    "day": [day.value for day in days]
                }, index=days)
                assets = {1: daily_frame}

                path = os.path.join(tempdir.path, "testdata.bcolz")
                daily_writer = BcolzDailyBarWriter(
                    path, self.trading_calendar, days[0], days[-1])
                daily_writer.write(assets.items())

                equity_daily_reader = BcolzDailyBarReader(path)

                data_portal = DataPortal(
                    env.asset_finder, self.trading_calendar,
                    first_trading_day=equity_daily_reader.first_trading_day,
                    equity_daily_reader=equity_daily_reader,
                )

            # A truthy "default_slippage" means: pass no slippage model.
            slippage_func = (None if params.get("default_slippage")
                             else FixedSlippage())

            blotter = Blotter(sim_params.data_frequency, slippage_func)

            start_date = sim_params.first_open

            # -1 flips the sign of every other order, 1 keeps it constant.
            alternator = -1 if alternate else 1

            tracker = PerformanceTracker(sim_params, self.trading_calendar,
                                         self.env)

            # replicate what tradesim does by going through every minute or day
            # of the simulation and processing open orders each time
            ticks = minutes if sim_params.data_frequency == "minute" else days

            transactions = []
            order_list = []
            order_date = start_date
            for tick in ticks:
                blotter.current_dt = tick

                order_is_due = (tick >= order_date and
                                len(order_list) < order_count)
                if not order_is_due:
                    # No order on this bar: let the blotter fill what is open.
                    bar_data = BarData(
                        data_portal=data_portal,
                        simulation_dt_func=lambda: tick,
                        data_frequency=sim_params.data_frequency,
                        trading_calendar=self.trading_calendar,
                        restrictions=NoRestrictions(),
                    )
                    txns, _, closed_orders = blotter.get_transactions(bar_data)
                    for txn in txns:
                        tracker.process_transaction(txn)
                        transactions.append(txn)

                    blotter.prune_orders(closed_orders)
                    continue

                # place an order
                direction = alternator ** len(order_list)
                order_id = blotter.order(
                    asset1,
                    order_amount * direction,
                    MarketOrder())
                order_list.append(blotter.orders[order_id])
                order_date = order_date + order_interval
                # move after market orders to just after market next
                # market open.
                if order_date.hour >= 21 and order_date.minute >= 0:
                    next_day = order_date + timedelta(days=1)
                    order_date = next_day.replace(hour=14, minute=30)

            # Every requested order must exist with the expected signed amount.
            for index in range(order_count):
                placed = order_list[index]
                self.assertEqual(placed.asset, asset1)
                self.assertEqual(placed.amount,
                                 order_amount * alternator ** index)

            if complete_fill:
                self.assertEqual(len(transactions), len(order_list))

            total_volume = 0
            for index, txn in enumerate(transactions):
                total_volume += txn.amount
                if complete_fill:
                    self.assertEqual(order_list[index].amount, txn.amount)

            self.assertEqual(total_volume, expected_txn_volume)

            self.assertEqual(len(transactions), expected_txn_count)

            cumulative_pos = tracker.position_tracker.positions[asset1]
            if total_volume == 0:
                self.assertIsNone(cumulative_pos)
            else:
                self.assertEqual(total_volume, cumulative_pos.amount)

            # the open orders should not contain the asset.
            oo = blotter.open_orders
            self.assertNotIn(
                asset1,
                oo,
                "Entry is removed when no open orders"
            )
コード例 #29
0
def _run(handle_data, initialize, before_trading_start, analyze, algofile,
         algotext, defines, data_frequency, capital_base, data, bundle,
         bundle_timestamp, start, end, output, print_algo, local_namespace,
         environ, live, exchange, algo_namespace, base_currency, live_graph):
    """Run a backtest for the given algorithm.

    This is shared between the cli and :func:`catalyst.run_algo`.
    """
    if algotext is not None:
        if local_namespace:
            ip = get_ipython()  # noqa
            namespace = ip.user_ns
        else:
            namespace = {}

        for assign in defines:
            try:
                name, value = assign.split('=', 2)
            except ValueError:
                raise ValueError(
                    'invalid define %r, should be of the form name=value' %
                    assign, )
            try:
                # evaluate in the same namespace so names may refer to
                # eachother
                namespace[name] = eval(value, namespace)
            except Exception as e:
                raise ValueError(
                    'failed to execute definition for name %r: %s' %
                    (name, e), )
    elif defines:
        raise _RunAlgoError(
            'cannot pass define without `algotext`',
            "cannot pass '-D' / '--define' without '-t' / '--algotext'",
        )
    else:
        namespace = {}
        if algofile is not None:
            algotext = algofile.read()

    if print_algo:
        if PYGMENTS:
            highlight(
                algotext,
                PythonLexer(),
                TerminalFormatter(),
                outfile=sys.stdout,
            )
        else:
            click.echo(algotext)

    mode = 'live' if live else 'backtest'
    log.info('running algo in {mode} mode'.format(mode=mode))

    if live and exchange is not None:
        exchange_name = exchange
        start = pd.Timestamp.utcnow()
        end = start + timedelta(minutes=1439)

        portfolio = get_algo_object(algo_name=algo_namespace,
                                    key='portfolio_{}'.format(exchange_name),
                                    environ=environ)
        if portfolio is None:
            portfolio = ExchangePortfolio(start_date=pd.Timestamp.utcnow())

        exchange_auth = get_exchange_auth(exchange_name)
        if exchange_name == 'bitfinex':
            exchange = Bitfinex(key=exchange_auth['key'],
                                secret=exchange_auth['secret'],
                                base_currency=base_currency,
                                portfolio=portfolio)
        elif exchange_name == 'bittrex':
            exchange = Bittrex(key=exchange_auth['key'],
                               secret=exchange_auth['secret'],
                               base_currency=base_currency,
                               portfolio=portfolio)
        else:
            raise NotImplementedError('exchange not supported: %s' %
                                      exchange_name)

    open_calendar = get_calendar('OPEN')
    sim_params = create_simulation_parameters(
        start=start,
        end=end,
        capital_base=capital_base,
        data_frequency=data_frequency,
        emission_rate=data_frequency,
    )

    if live and exchange is not None:
        env = TradingEnvironment(environ=environ,
                                 exchange_tz='UTC',
                                 asset_db_path=None)
        env.asset_finder = AssetFinderExchange(exchange)

        data = DataPortalExchange(exchange=exchange,
                                  asset_finder=env.asset_finder,
                                  trading_calendar=open_calendar,
                                  first_trading_day=pd.to_datetime('today',
                                                                   utc=True))
        choose_loader = None

        def fetch_capital_base(attempt_index=0):
            """
            Fetch the base currency amount required to bootstrap
            the algorithm against the exchange.

            The algorithm cannot continue without this value.

            :param attempt_index:
            :return capital_base: the amount of base currency available for
            trading
            """
            try:
                log.debug('retrieving capital base in {} to bootstrap '
                          'exchange {}'.format(base_currency, exchange_name))
                balances = exchange.get_balances()
            except ExchangeRequestError as e:
                if attempt_index < 20:
                    sleep(5)
                    return fetch_capital_base(attempt_index + 1)
                else:
                    raise ExchangeRequestErrorTooManyAttempts(
                        attempts=attempt_index, error=e)

            if base_currency in balances:
                return balances[base_currency]
            else:
                raise BaseCurrencyNotFoundError(base_currency=base_currency,
                                                exchange=exchange_name)

        sim_params = create_simulation_parameters(
            start=start,
            end=end,
            capital_base=fetch_capital_base(),
            emission_rate='minute',
            data_frequency='minute')

    elif bundle is not None:
        bundles = bundle.split(',')

        def get_trading_env_and_data(bundles):
            env = data = None

            b = 'poloniex'
            if len(bundles) == 0:
                return env, data
            elif len(bundles) == 1:
                b = bundles[0]

            bundle_data = load(
                b,
                environ,
                bundle_timestamp,
            )

            prefix, connstr = re.split(
                r'sqlite:///',
                str(bundle_data.asset_finder.engine.url),
                maxsplit=1,
            )
            if prefix:
                raise ValueError(
                    "invalid url %r, must begin with 'sqlite:///'" %
                    str(bundle_data.asset_finder.engine.url), )

            env = TradingEnvironment(
                load=partial(load_crypto_market_data,
                             bundle=b,
                             bundle_data=bundle_data,
                             environ=environ),
                bm_symbol='USDT_BTC',
                trading_calendar=open_calendar,
                asset_db_path=connstr,
                environ=environ,
            )

            first_trading_day = bundle_data.minute_bar_reader.first_trading_day

            data = DataPortal(
                env.asset_finder,
                open_calendar,
                first_trading_day=first_trading_day,
                minute_reader=bundle_data.minute_bar_reader,
                five_minute_reader=bundle_data.five_minute_bar_reader,
                daily_reader=bundle_data.daily_bar_reader,
                adjustment_reader=bundle_data.adjustment_reader,
            )

            return env, data

        def get_loader_for_bundle(b):
            bundle_data = load(
                b,
                environ,
                bundle_timestamp,
            )

            if b == 'poloniex':
                return CryptoPricingLoader(
                    bundle_data,
                    data_frequency,
                    CryptoPricing,
                )
            elif b == 'quandl':
                return USEquityPricingLoader(
                    bundle_data,
                    data_frequency,
                    USEquityPricing,
                )
            raise ValueError("No PipelineLoader registered for bundle %s." % b)

        loaders = [get_loader_for_bundle(b) for b in bundles]
        env, data = get_trading_env_and_data(bundles)

        def choose_loader(column):
            for loader in loaders:
                if column in loader.columns:
                    return loader
            raise ValueError("No PipelineLoader registered for column %s." %
                             column)

    else:
        env = TradingEnvironment(environ=environ)
        choose_loader = None

    TradingAlgorithmClass = (partial(ExchangeTradingAlgorithm,
                                     exchange=exchange,
                                     algo_namespace=algo_namespace,
                                     live_graph=live_graph)
                             if live and exchange else TradingAlgorithm)

    perf = TradingAlgorithmClass(
        namespace=namespace,
        env=env,
        get_pipeline_loader=choose_loader,
        sim_params=sim_params,
        **{
            'initialize': initialize,
            'handle_data': handle_data,
            'before_trading_start': before_trading_start,
            'analyze': analyze,
        } if algotext is None else {
            'algo_filename': getattr(algofile, 'name', '<algorithm>'),
            'script': algotext,
        }).run(
            data,
            overwrite_sim_params=False,
        )

    if output == '-':
        click.echo(str(perf))
    elif output != os.devnull:  # make the catalyst magic not write any data
        perf.to_pickle(output)

    return perf
コード例 #30
0
ファイル: simfactory.py プロジェクト: zhoukalex/catalyst
def create_test_catalyst(**config):
    """
    Build a test algorithm from ``config`` and return its generator.

       :param config: A configuration object that is a dict with:

           - sid - an integer, which will be used as the asset ID.
           - sid_list - optional list of asset IDs; takes precedence over
             sid. One of sid_list or sid is required.
           - order_count - the number of orders the test algo will place,
             defaults to 100
           - order_amount - the number of shares per order, defaults to 100
           - trading_calendar - optional calendar, defaults to the "NYSE"
             calendar.
           - sim_params - optional simulation parameters. Defaults to
             ``factory.create_simulation_parameters()`` when the algorithm
             is created here.
           - algorithm - optional parameter providing an algorithm. defaults
             to :py:class:`catalyst.test.algorithms.TestAlgorithm`
           - trade_source - optional parameter to specify trades, if present.
             If not present a daily trade source is created with
             ``factory.create_daily_trade_source``.
           - concurrent_trades - passed as ``concurrent`` to the default
             trade source, defaults to False.
           - skip_data - if present, no trade source is read and no data
             portal is attached to the algorithm.
           - env, tempdir - required unless skip_data is present; the
             data portal is built from ``env.asset_finder`` and written
             under ``tempdir``.
           - slippage: optional parameter that configures the
             :py:class:`catalyst.gens.tradingsimulation.TransactionSimulator`.
             Expects an object with a simulate method, such as
             :py:class:`catalyst.gens.tradingsimulation.FixedSlippage`.
             :py:mod:`catalyst.finance.trading`
           - benchmark_source - optional, stored on the algorithm as
             ``benchmark_return_source`` (None when absent).

       :return: the generator returned by the algorithm's
           ``get_generator()``.
       :raises Exception: when neither sid_list nor sid is given.
       """
    if 'sid_list' in config:
        sid_list = config['sid_list']
    elif 'sid' in config:
        sid_list = [config['sid']]
    else:
        raise Exception("simfactory create_test_catalyst() requires "
                        "argument 'sid_list' or 'sid'")

    concurrent_trades = config.get('concurrent_trades', False)
    order_count = config.get('order_count', 100)
    order_amount = config.get('order_amount', 100)
    # Build the default calendar only when the caller did not supply one.
    if 'trading_calendar' in config:
        trading_calendar = config['trading_calendar']
    else:
        trading_calendar = get_calendar("NYSE")

    # -------------------
    # Create the Algo
    # -------------------
    if 'algorithm' in config:
        test_algo = config['algorithm']
    else:
        # Likewise, only create default simulation parameters on demand.
        if 'sim_params' in config:
            algo_sim_params = config['sim_params']
        else:
            algo_sim_params = factory.create_simulation_parameters()

        test_algo = TestAlgorithm(
            sid_list[0],
            order_amount,
            order_count,
            sim_params=algo_sim_params,
            trading_calendar=trading_calendar,
            slippage=config.get('slippage'),
            identifiers=sid_list
        )

    # -------------------
    # Trade Source
    # -------------------
    if 'skip_data' not in config:
        if 'trade_source' in config:
            trade_source = config['trade_source']
        else:
            trade_source = factory.create_daily_trade_source(
                sid_list,
                test_algo.sim_params,
                test_algo.trading_environment,
                trading_calendar,
                concurrent=concurrent_trades,
            )

        # Group the trades by asset id, keeping their original order.
        trades_by_sid = {}
        for trade in trade_source:
            trades_by_sid.setdefault(trade.sid, []).append(trade)

        # Fall back to the algorithm's own simulation parameters (the ones
        # the default trade source is built with) instead of failing with
        # a KeyError when the caller relied on the defaults.
        if 'sim_params' in config:
            portal_sim_params = config['sim_params']
        else:
            portal_sim_params = test_algo.sim_params

        data_portal = create_data_portal_from_trade_history(
            config['env'].asset_finder,
            trading_calendar,
            config['tempdir'],
            portal_sim_params,
            trades_by_sid
        )

        test_algo.data_portal = data_portal

    # -------------------
    # Benchmark source
    # -------------------

    test_algo.benchmark_return_source = config.get('benchmark_source', None)

    # ------------------
    # generator/simulator
    sim = test_algo.get_generator()

    return sim