Пример #1
0
def make_env(fee_const=0.0025):
    """Build a ``TradeEnv`` whose ticker set is sampled once and then held fixed.

    Apart from ``fee_const``, all configuration is read from module-level
    names that must exist when this function is called: ``db_path``,
    ``ticker_codes``, ``ticker_number``, ``start_datetime``,
    ``end_datetime``, ``episode_length``, ``freq_str`` and ``window``.

    Args:
        fee_const: Fee constant forwarded to ``TradeEnv``. Defaults to
            0.0025, the value that used to be hard-coded here, so existing
            zero-argument calls behave exactly as before.

    Returns:
        The configured ``TradeEnv`` instance.
    """
    # Stock database handle.
    stock_db = StockDatabase(db_path)

    # Ticker sampler: draw ``ticker_number`` codes once and wrap the result
    # in ``ConstSamper`` so the ticker set stays fixed.
    ticker_names_sampler = ConstSamper(
        TickerSampler(ticker_codes, ticker_number).sample())

    start_datetime_sampler = DatetimeSampler(start_datetime=start_datetime,
                                             end_datetime=end_datetime,
                                             episode_length=episode_length,
                                             freq_str=freq_str)

    # One slot more than the number of tickers; presumably the extra entry
    # is the cash position — TODO confirm against PortfolioVectorSampler.
    portfolio_sampler = PortfolioVectorSampler(ticker_number + 1)

    sampler_manager = SamplerManager(
        ticker_names_sampler=ticker_names_sampler,
        datetime_sampler=start_datetime_sampler,
        portfolio_vector_sampler=portfolio_sampler,
    )

    # Price supplier, created with no ticker codes specified initially.
    price_supplier = StockDBPriceSupplier(
        stock_db,
        [],
        episode_length,
        freq_str,
        interpolate=True)

    # Portfolio transformer.
    portfolio_transformer = PortfolioTransformer(
        price_supplier,
        portfolio_restrictor=PortfolioRestrictorIdentity(),
        use_ohlc="Close",
        initial_all_assets=1e6,  # author's note: irrelevant to learning
        fee_calculator=FeeCalculatorFree())

    # Trading environment.
    trade_env = TradeEnv(portfolio_transformer,
                         sampler_manager,
                         window=window,
                         fee_const=fee_const)

    return trade_env
Пример #2
0
if __name__ == "__main__":
    # Script entry point: parses CLI options and sets up the objects used to
    # insert stock price data into the database with scheduling (per the
    # argparse description below).
    # NOTE(review): this span does not parse as shown. The CsvKobetsuInsert(
    # call for the tosho file is never closed, and the lines after it
    # (torch/pandas imports, a second db_path, date range) appear to come
    # from a different example that was spliced in.
    import argparse
    # NOTE(review): relies on `datetime` already being imported at module
    # level (not visible here); the `import datetime` further down is part of
    # the spliced-in fragment.
    print("[{}] schedule program start".format(str(datetime.datetime.now())))

    parser = argparse.ArgumentParser(
        description='insert data to database with scheduling')
    # --tempfile: whether to use a temporary file (forwarded as use_tempfile).
    parser.add_argument("--tempfile",
                        action="store_true",
                        help="tempfileを利用するかどうか")

    args = parser.parse_args()

    # Required instances.
    db_path = Path("db/stock_db") / Path("stock.db")
    stock_db = StockDatabase(db_path,
                             column_upper_limit=1000,
                             table_name_base="table")
    # CSV files listing the ticker codes to insert (paths relative to cwd).
    nikkei_code_file_path = Path("get_stock_price") / Path("nikkei225.csv")
    tosho_code_file_path = Path("get_stock_price") / Path("tosho.csv")

    stock_loader = YahooFinanceStockLoaderMin(None,
                                              past_day=5,
                                              stop_time_span=2.0,
                                              is_use_stop=False)  # do not stop

    # NOTE(review): stock_group is "nikkei_255" while the code file is
    # nikkei225.csv — possibly a typo for "nikkei_225". Confirm before
    # changing, since existing DB rows may already use this value.
    nikkei_kobetsu_insert = CsvKobetsuInsert(nikkei_code_file_path,
                                             stock_loader,
                                             stock_db,
                                             stock_group="nikkei_255",
                                             use_tempfile=args.tempfile)
    # NOTE(review): the call below is truncated; its remaining arguments
    # and closing parenthesis are missing from this fragment.
    tosho_kobetsu_insert = CsvKobetsuInsert(tosho_code_file_path,
    # ---- lines below appear to belong to a different example ----
    import torch.nn.functional as F

    import pandas as pd
    import collections

    import datetime
    from pytz import timezone
    from pathlib import Path

    from get_stock_price import StockDatabase

    from envs_ver2 import OneStockEnv, NormalizeState, NormalizeReward

    # NOTE(review): hard-coded absolute Windows path; only valid on the
    # author's machine.
    db_path = Path("E:/システムトレード入門/trade_system_git_workspace/db/stock_db"
                   ) / Path("stock.db")
    stock_db = StockDatabase(db_path)

    # Date range 2020-11-01 .. 2020-12-01 as timezone-aware Asia/Tokyo
    # datetimes (pytz zones are attached with localize()).
    jst_timezone = timezone("Asia/Tokyo")
    start_datetime = jst_timezone.localize(
        datetime.datetime(2020, 11, 1, 0, 0, 0))
    end_datetime = jst_timezone.localize(
        datetime.datetime(2020, 12, 1, 0, 0, 0))
    #end_datetime = get_next_workday_jp(start_datetime, days=11)  # one week (5 days) in business days

    # Ticker code(s) to use; alternatives kept commented out.
    #stock_names = "4755"
    stock_names = "9984"
    #stock_names = ["6502","4755"]
    #stock_list = ["4755","9984","6701","7203","7267"]

    use_ohlc = "Close"
Пример #4
0
if __name__ == "__main__":
    # Script entry point: makes local project directories importable, opens
    # the stock database and sets up a JST date range and ticker code.
    import sys
    # NOTE(review): hard-coded, machine-specific absolute Windows paths
    # appended to sys.path so the imports below resolve on the author's PC.
    sys.path.append(r"E:\システムトレード入門\tutorials\rl\pfrl")
    sys.path.append(r"E:\システムトレード入門\trade_system_git_workspace")

    import datetime
    from pytz import timezone
    from pathlib import Path

    import pfrl
    from get_stock_price import StockDatabase
    from envs_ver2 import OneStockEnv, NormalizeState, NormalizeReward

    db_path = Path("E:/システムトレード入門/trade_system_git_workspace/db/stock_db"
                   ) / Path("stock.db")
    stock_db = StockDatabase(db_path)

    # Date range 2020-11-01 .. 2020-12-01 as timezone-aware Asia/Tokyo
    # datetimes (pytz zones are attached with localize()).
    jst_timezone = timezone("Asia/Tokyo")
    start_datetime = jst_timezone.localize(
        datetime.datetime(2020, 11, 1, 0, 0, 0))
    end_datetime = jst_timezone.localize(
        datetime.datetime(2020, 12, 1, 0, 0, 0))
    #end_datetime = get_next_workday_jp(start_datetime, days=11)  # one week (5 days) in business days

    # Ticker code(s) to query; alternatives kept commented out.
    #stock_names = "4755"
    #stock_names = "9984"
    stock_names = "6502"
    #stock_names = ["6502","4755"]
    #stock_list = ["4755","9984","6701","7203","7267"]

    # NOTE(review): this call is cut off in this fragment; its remaining
    # arguments and closing parenthesis are not visible.
    stock_df = stock_db.search_span(stock_names=stock_names,
Пример #5
0
def make_env(
        db_path,
        csv_path,
        is_ticker_sample=True,
        start_datetime=jst.localize(datetime.datetime(2020, 11, 10, 0, 0, 0)),
        end_datetime=jst.localize(datetime.datetime(2020, 11, 20, 0, 0, 0)),
        episode_length=300,
        window=None,
        ticker_number=19,
        fee_const=0.0025,
):
    """Build a ``TradeEnv`` from a ticker-code CSV file and a stock database.

    Args:
        db_path: Path passed to ``StockDatabase``.
        csv_path: CSV read with ``pd.read_csv(csv_path, header=0)``. Its
            ``"code"`` column supplies the candidate ticker codes, which are
            converted to ``str``.
        is_ticker_sample: If true, a ``TickerSampler`` over all codes is
            used. Otherwise ``ticker_number`` codes are sampled once and
            held fixed via ``ConstSamper``.
        start_datetime: Start of the range given to ``DatetimeSampler``.
            The default is evaluated once, at function definition time.
        end_datetime: End of the range given to ``DatetimeSampler``. The
            default is evaluated once, at function definition time.
        episode_length: Forwarded to ``DatetimeSampler`` and
            ``StockDBPriceSupplier``.
        window: Forwarded to ``TradeEnv``. ``None`` (the default) means
            ``np.arange(0, 50)``, built fresh on every call.
        ticker_number: Number of tickers to sample. The portfolio vector
            sampler is sized ``ticker_number + 1``.
        fee_const: Forwarded to ``TradeEnv``.

    Returns:
        The configured ``TradeEnv`` instance.

    Note:
        ``freq_str`` is not a parameter; it is read from module scope.
    """
    if window is None:
        # Build a new array per call. A default ``np.arange(...)`` in the
        # signature would be one object shared by every env built with it.
        window = np.arange(0, 50)

    ticker_codes_df = pd.read_csv(csv_path, header=0)  # hand-made CSV
    ticker_codes = ticker_codes_df["code"].values.astype(str).tolist()

    # Stock database handle.
    stock_db = StockDatabase(db_path)

    # Ticker sampler.
    if is_ticker_sample:
        ticker_names_sampler = TickerSampler(
            all_ticker_names=ticker_codes,
            sampling_ticker_number=ticker_number)
    else:
        # Sample once and keep the ticker set fixed.
        ticker_names_sampler = ConstSamper(
            TickerSampler(ticker_codes, ticker_number).sample())

    start_datetime_sampler = DatetimeSampler(start_datetime=start_datetime,
                                             end_datetime=end_datetime,
                                             episode_length=episode_length,
                                             freq_str=freq_str)

    # One slot more than the number of tickers; presumably the extra entry
    # is the cash position — TODO confirm against PortfolioVectorSampler.
    portfolio_sampler = PortfolioVectorSampler(ticker_number + 1)

    sampler_manager = SamplerManager(
        ticker_names_sampler=ticker_names_sampler,
        datetime_sampler=start_datetime_sampler,
        portfolio_vector_sampler=portfolio_sampler,
    )

    # Price supplier, created with no ticker codes specified initially.
    price_supplier = StockDBPriceSupplier(
        stock_db,
        [],
        episode_length,
        freq_str,
        interpolate=True)

    # Portfolio transformer.
    portfolio_transformer = PortfolioTransformer(
        price_supplier,
        portfolio_restrictor=PortfolioRestrictorIdentity(),
        use_ohlc="Close",
        initial_all_assets=1e6,  # author's note: irrelevant to learning
        fee_calculator=FeeCalculatorFree())

    # Trading environment.
    trade_env = TradeEnv(portfolio_transformer,
                         sampler_manager,
                         window=window,
                         fee_const=fee_const)

    return trade_env
Пример #6
0
                             use_x_range=use_x_range,
                             use_y_range=use_y_range,
                             data_left_times=data_left_times,
                             is_notebook=is_notebook,
                             use_formatter=use_formatter)


if __name__ == "__main__":
    from tornado.ioloop import IOLoop  # サーバーをたてるのに必要
    from bokeh.server.server import Server  # サーバーを立てるのに必要
    from pytz import timezone

    from get_stock_price import StockDatabase

    db_path = Path("db/stock_db") / Path("stock.db")
    stock_db = StockDatabase(db_path)

    stock_name = "4755"  # 楽天
    stock_timestamp_df = stock_db.stock_timestamp(stock_names=["4755"],
                                                  to_tokyo=True)
    day_before = stock_timestamp_df.loc[0,
                                        "min_datetime"] + datetime.timedelta(
                                            days=10)  # 最初の日時から次の日とする.

    # 日時の取得
    jst_timezone = timezone("Asia/Tokyo")
    start_time = jst_timezone.localize(
        datetime.datetime(day_before.year, day_before.month, day_before.day, 9,
                          0, 0))
    #start_time = jst_timezone.localize(datetime.datetime(day_before.year, day_before.month, day_before.day, 12, 30, 0))
    end_time = jst_timezone.localize(