Exemplo n.º 1
0
def cal_period_rate(sec_data, by='month'):
    """Calculate the holding-period return of `sec_data` for every year, month or week.

    :param sec_data: price data indexed by timestamp; must contain the 'Close'
        column that is passed to cal_HPR
    :param by: period granularity: 'year', 'month' or 'week'
    :returns: util.df_2_timeseries of a frame with a 'period' column (period
        start label) and a 'rate' column (cal_HPR of 'Close' within the period)
    """
    start_date = sec_data.index.min().date()
    end_date = sec_data.index.max().date()

    # Build the list of (period_start, period_end) slice bounds
    periods = []

    # Yearly periods, labelled 'YYYY'
    if by == 'year':
        for year in range(start_date.year, end_date.year + 1):
            p = '%(year)s' % dict(year=year)
            periods.append((p, p))

    # Monthly periods, labelled 'YYYY-MM'; months without data are skipped below
    elif by == 'month':
        for year in range(start_date.year, end_date.year + 1):
            for month in range(1, 13):
                if year >= end_date.year and month > end_date.month:
                    break
                p = '%(year)s-%(month)02d' % dict(year=year, month=month)
                periods.append((p, p))

    # Weekly periods: from week_start up to the following Sunday
    elif by == 'week':
        week_start = start_date
        # `<=` so that a last week beginning exactly on end_date is kept
        while week_start <= end_date:
            week_end = week_start + datetime.timedelta(
                days=(6 - week_start.weekday()))
            periods.append((week_start, week_end))
            week_start = week_end + datetime.timedelta(days=1)
    else:
        print('Invalid period')

    # Holding-period return of each non-empty period
    period_rate = {'period': [], 'rate': []}
    for period_start, period_end in periods:
        # Slice with ISO strings so every period type uses pandas'
        # partial-string (whole-day / month / year) indexing
        tmp_data = sec_data[str(period_start):str(period_end)]
        if len(tmp_data) == 0:
            continue
        period_rate['period'].append(period_start)
        period_rate['rate'].append(
            cal_HPR(data=tmp_data, start=None, end=None, dim='Close'))

    period_rate = pd.DataFrame(period_rate)
    period_rate = util.df_2_timeseries(df=period_rate, time_col='period')

    return period_rate
Exemplo n.º 2
0
def _affordable_shares(cash, price, trading_fee):
    """Return how many whole shares `cash` buys at `price` after paying `trading_fee` (never negative)."""
    return max(0, math.floor((cash - trading_fee) / price))


def _append_trade_record(record, date, action, holding, price, cash, total):
    """Append one row to the column-oriented trade `record` dict."""
    record['date'].append(date)
    record['action'].append(action)
    record['holding'].append(holding)
    record['price'].append(price)
    record['cash'].append(cash)
    record['total'].append(total)


def back_test(signal,
              cash=0,
              stock=0,
              start_date=None,
              end_date=None,
              trading_fee=3,
              stop_profit=0.1,
              stop_loss=0.6,
              mode='earning',
              print_trading=True,
              plot_trading=True):
    """Back-test buy/sell signals and record the resulting trades.

    :param signal: time-indexed data with a 'signal' column ('b' buy, 's' sell,
        'n' none) plus 'Open' and 'Close' prices
    :param cash: starting cash
    :param stock: starting number of shares held
    :param start_date: first date of the back-test window (None = from the start)
    :param end_date: last date of the window (None = to the end)
    :param trading_fee: flat fee charged on every buy and every sell
    :param stop_profit: 'earning' mode sells once (close - buy) / buy exceeds this
    :param stop_loss: 'earning' mode sells once (close - buy) / buy drops below -stop_loss
    :param mode: 'earning' (buy on 'b', exit via stop-profit/stop-loss) or
        'signal' (buy on 'b', sell on 's')
    :param print_trading: print each trade
    :param plot_trading: plot prices with trade points and the total-value curve
    :returns: util.df_2_timeseries of the trade record (date, action, holding,
        price, cash, total); the last row is the state on the final date
    """
    # Restrict the signal to the requested window
    signal = signal[start_date:end_date]
    original_cash = cash
    original_stock = stock

    # Column-oriented trade log
    record = {
        'date': [],
        'action': [],
        'holding': [],
        'price': [],
        'cash': [],
        'total': []
    }

    # Earning mode: buy on 'b', exit on stop-profit / stop-loss
    if mode == 'earning':

        buy_signals = signal.query('signal == "b"').index.tolist()
        selling_date = signal.index.min()

        for date in buy_signals:

            # Act from the day after the signal; ignore signals that fall
            # inside an already-closed position
            tmp_data = signal[date:][1:]
            if (len(tmp_data) < 2) or (date < selling_date):
                continue

            # Buy at the opening price
            if stock == 0 and cash > 0:
                buying_date = tmp_data.index.min()
                buying_price = tmp_data.loc[buying_date, 'Open']
                # Clamped at 0: cash below the fee must not yield negative shares
                shares = _affordable_shares(cash, buying_price, trading_fee)

                if shares > 0:
                    stock = shares
                    cash = cash - stock * buying_price - trading_fee
                    total = (cash + stock * buying_price)
                    _append_trade_record(record, buying_date.date(), 'b',
                                         stock, buying_price, cash, total)

                    if print_trading:
                        print(
                            buying_date.date(),
                            '买入 %(stock)s, 价格%(price)s, 流动资金%(cash)s, 总值%(total)s'
                            % dict(stock=stock,
                                   price=buying_price,
                                   cash=cash,
                                   total=total))
                else:
                    print(buying_date.date(),
                          '买入 %(stock)s' % dict(stock=shares))

            # Sell at the closing price once a threshold is crossed.
            # NOTE(review): if a non-zero `stock` is passed in and no buy has
            # happened yet, buying_price is undefined here.
            if stock > 0:
                for index, row in tmp_data.iterrows():
                    selling_date = index
                    selling_price = row['Close']
                    change = (selling_price - buying_price) / buying_price

                    if change > stop_profit:
                        label = '止盈'  # take profit
                    elif change < -stop_loss:
                        label = '止损'  # stop loss
                    else:
                        continue

                    cash = cash + selling_price * stock - trading_fee
                    stock = 0
                    total = cash
                    if print_trading:
                        print(
                            selling_date.date(),
                            '%(label)s, 价格%(price)s, 流动资金%(cash)s, 总值%(total)s' %
                            dict(label=label,
                                 price=selling_price,
                                 cash=cash,
                                 total=total))

                    _append_trade_record(record, selling_date.date(), 's',
                                         stock, selling_price, cash, total)
                    break

    # Signal mode: buy on 'b', sell on 's'
    elif mode == 'signal':

        # Keep only the first signal of each run of identical non-'n'
        # signals, so trades alternate between buying and selling
        trading_signals = []
        last_signal = 'n'
        for index, row in signal.query('signal != "n"').iterrows():
            current_signal = row['signal']
            if current_signal != last_signal:
                trading_signals.append(index)
            last_signal = current_signal

        last_date = signal.index.max()
        for date in trading_signals:

            # A signal on the final day cannot be traded inside the window
            if date == last_date:
                print('信号于', date, '发出')
                break

            # Trade on the day after the signal
            tmp_signal = signal.loc[date, 'signal']
            trading_date = signal[date:][1:].index.min()

            # Buy at the opening price
            if tmp_signal == 'b':
                buying_price = signal.loc[trading_date, 'Open']
                shares = _affordable_shares(cash, buying_price, trading_fee)
                if shares > 0:
                    # Add to (instead of overwriting) any shares already held
                    stock = stock + shares
                    cash = cash - shares * buying_price - trading_fee
                    total = (cash + stock * buying_price)
                    if print_trading:
                        print(
                            trading_date.date(),
                            '买入 %(stock)s, 价格%(price)s, 流动资金%(cash)s, 总值%(total)s'
                            % dict(stock=stock,
                                   price=buying_price,
                                   cash=cash,
                                   total=total))

                    _append_trade_record(record, trading_date.date(), 'b',
                                         stock, buying_price, cash, total)
                else:
                    print(trading_date.date(),
                          '买入 %(stock)s' % dict(stock=shares))

            # Sell everything at the closing price
            elif tmp_signal == 's':
                if stock > 0:
                    selling_price = signal.loc[trading_date, 'Close']
                    cash = cash + selling_price * stock - trading_fee
                    stock = 0
                    total = cash + stock * selling_price
                    if print_trading:
                        print(
                            trading_date.date(),
                            '卖出, 价格%(price)s, 流动资金%(cash)s, 总值%(total)s' %
                            dict(price=selling_price, cash=cash, total=total))

                    _append_trade_record(record, trading_date.date(), 's',
                                         stock, selling_price, cash, total)

            else:
                print('invalid signal %s' % tmp_signal)

    # Unknown mode
    else:
        print('mode [%s] not found' % mode)

    # Record the state on the last date
    current_date = signal.index.max()
    current_price = signal.loc[current_date, 'Close']
    total = cash + stock * current_price
    _append_trade_record(record, current_date.date(),
                         signal.loc[current_date, 'signal'], stock,
                         current_price, cash, total)
    if print_trading:
        print(
            current_date.date(), '当前, 价格%(price)s, 总值%(total)s' %
            dict(price=current_price, total=total))

    # Convert the log into time-series data
    record = util.df_2_timeseries(pd.DataFrame(record), time_col='date')

    # Plot the back-test
    if plot_trading:
        buying_points = record.query('action == "b"')
        selling_points = record.query('action == "s"')

        plt.subplots(figsize=(20, 3))
        plt.plot(signal[['Close']])
        plt.scatter(buying_points.index, buying_points.price, c='green')
        plt.scatter(selling_points.index, selling_points.price, c='red')

        total_value_data = pd.merge(signal[['Close']],
                                    record[['cash', 'holding', 'action']],
                                    how='left',
                                    left_index=True,
                                    right_index=True)
        # Carry each trade's cash/holding forward to the following days
        total_value_data = total_value_data.ffill()
        # Days before the first recorded trade still hold the starting position
        total_value_data['cash'] = total_value_data['cash'].fillna(
            original_cash)
        total_value_data['holding'] = total_value_data['holding'].fillna(
            original_stock)
        total_value_data['original'] = original_cash
        total_value_data['total'] = total_value_data[
            'Close'] * total_value_data['holding'] + total_value_data['cash']
        total_value_data[['total', 'original']].plot(figsize=(20, 3))

    return record
Exemplo n.º 3
0
def cal_period_rate_risk(data, dim='value', by='month'):
    """Calculate return and risk measures for every year, month or week.

    :param data: original OHLCV data indexed by timestamp
    :param dim: column passed to ta_util.cal_change_rate for the daily change rate
    :param by: period granularity: 'year', 'month' or 'week'
    :returns: util.df_2_timeseries of a frame with period, start, end, HPR,
        EAR, APR, CCR, daily_rate_mean and daily_rate_std columns
    :raises: none
    """
    # Daily change rate of `dim` (read back below as the 'rate' column --
    # presumably written by cal_change_rate; confirm)
    data = ta_util.cal_change_rate(df=data, target_col=dim, periods=1)

    # Build the list of (period_start, period_end) slice bounds
    start_date = data.index.min().date()
    end_date = data.index.max().date()
    periods = []

    # by year, labelled 'YYYY'
    if by == 'year':
        for year in range(start_date.year, end_date.year + 1):
            p = '%(year)s' % dict(year=year)
            periods.append((p, p))

    # by month, labelled 'YYYY-MM'
    elif by == 'month':
        for year in range(start_date.year, end_date.year + 1):
            for month in range(1, 13):
                if year >= end_date.year and month > end_date.month:
                    break
                p = '%(year)s-%(month)02d' % dict(year=year, month=month)
                periods.append((p, p))

    # by week: from week_start up to the following Sunday
    elif by == 'week':
        week_start = start_date
        # `<=` so that a last week beginning exactly on end_date is kept
        while week_start <= end_date:
            week_end = week_start + datetime.timedelta(
                days=(6 - week_start.weekday()))
            periods.append((week_start, week_end))
            week_start = week_end + datetime.timedelta(days=1)
    else:
        print('Invalid period')

    # Return/risk for each period with more than one row
    period_rate = {
        'period': [],
        'start': [],
        'end': [],
        'HPR': [],
        'EAR': [],
        'APR': [],
        'CCR': [],
        'daily_rate_mean': [],
        'daily_rate_std': []
    }
    for period_start, period_end in periods:
        # Slice with ISO strings so every period type uses pandas'
        # partial-string (whole-day / month / year) indexing
        tmp_data = data[str(period_start):str(period_end)]
        if len(tmp_data) <= 1:
            continue
        period_rate['period'].append(period_start)
        period_rate['start'].append(period_start)
        period_rate['end'].append(period_end)
        # NOTE(review): the rates below always use 'Close' while the daily
        # change rate above uses `dim` -- confirm this is intended
        period_rate['HPR'].append(
            cal_HPR(data=tmp_data, start=None, end=None, dim='Close'))
        period_rate['EAR'].append(
            cal_EAR(data=tmp_data, start=None, end=None, dim='Close'))
        period_rate['APR'].append(
            cal_APR(data=tmp_data, start=None, end=None, dim='Close'))
        period_rate['CCR'].append(
            cal_CCR(data=tmp_data, start=None, end=None, dim='Close'))
        period_rate['daily_rate_mean'].append(tmp_data.rate.mean())
        period_rate['daily_rate_std'].append(tmp_data.rate.std())

    period_rate = pd.DataFrame(period_rate)
    period_rate = util.df_2_timeseries(df=period_rate, time_col='period')

    return period_rate
Exemplo n.º 4
0
def download_stock_data_from_tiger(sec_code,
                                   time_col='time',
                                   quote_client=None,
                                   download_limit=1200,
                                   start_date=None,
                                   end_date=None,
                                   file_path='drive/My Drive/stock_data_us/',
                                   file_format='.csv',
                                   is_return=False,
                                   is_print=True):
    """Incrementally download bars for `sec_code` from the Tiger API into the local file.

    :param sec_code: security code; also the local file name
    :param time_col: timestamp column in the bars returned by quote_client.get_bars
    :param quote_client: Tiger quote client whose get_bars() is called
    :param download_limit: rows requested per get_bars call; a full batch
        triggers another request for earlier bars
    :param start_date: 'YYYY-MM-DD'; replaced by the latest stored date when a
        local file already exists
    :param end_date: 'YYYY-MM-DD'; defaults to the current time
    :param file_path: directory prefix of the local file
    :param file_format: file extension of the local file
    :param is_return: return the merged data
    :param is_print: print a summary of the update
    :returns: merged data as a time series when is_return is True, else None
    """
    # The local file keeps dates in a 'Date' column
    date_col = 'Date'
    filename = file_path + sec_code + file_format

    stage = 'downloading_started'
    data = pd.DataFrame()
    try:
        # Load previously downloaded data, if any
        stage = 'loading_existed_data'
        if os.path.exists(filename):
            data = read_stock_data(sec_code,
                                   file_path=file_path,
                                   file_format=file_format,
                                   time_col=date_col)

        # Resume downloading from the latest stored date
        init_len = len(data)
        if init_len > 0:
            start_date = util.time_2_string(data.index.max(),
                                            date_format='%Y-%m-%d')

        stage = 'downloading_new_data'

        # Convert start/end dates into millisecond timestamps
        if start_date is not None:
            begin_time = round(
                time.mktime(util.string_2_time(start_date).timetuple()) * 1000)
        else:
            begin_time = 0
        if end_date is not None:
            end_time = round(
                time.mktime(util.string_2_time(end_date).timetuple()) * 1000)
        else:
            end_time = round(time.time() * 1000)

        # Page backwards: a full batch means earlier bars may still remain
        batches = []
        while True:
            tmp_data = quote_client.get_bars([sec_code],
                                             begin_time=begin_time,
                                             end_time=end_time,
                                             limit=download_limit)
            if len(tmp_data) == 0:
                break
            batches.append(tmp_data)
            next_end_time = int(tmp_data[time_col].min())
            # Stop on a partial batch, or if the window did not move back
            if len(tmp_data) < download_limit or next_end_time >= end_time:
                break
            end_time = next_end_time
        # Batches were fetched newest-first; restore chronological order
        new_data = pd.concat(batches[::-1]) if batches else pd.DataFrame()

        # Normalise the downloaded bars to the local layout
        stage = 'processing_new_data'
        if len(new_data) > 0:
            new_data = new_data.drop('symbol', axis=1, errors='ignore')
            new_data[time_col] = new_data[time_col].apply(
                lambda x: util.timestamp_2_time(x).date())
            new_data = new_data.rename(columns={
                'open': 'Open',
                'high': 'High',
                'low': 'Low',
                'close': 'Close',
                'volume': 'Volume',
                time_col: date_col
            })
            new_data['Adj Close'] = new_data['Close']
            new_data = util.df_2_timeseries(df=new_data, time_col=date_col)

            # Append to the existing data
            data = new_data if data.empty else pd.concat([data, new_data],
                                                         sort=False)

            # Deduplicate (newest wins), sort and save
            stage = 'saving_data'
            data = data.reset_index().drop_duplicates(subset=date_col,
                                                      keep='last')
            data = data.sort_values(by=date_col)
            data.to_csv(filename, index=False)
        elif init_len > 0:
            # Nothing new: expose the date index as a column, matching the
            # layout produced by the save branch above
            data = data.reset_index()

        # Report the change in record count
        if is_print:
            final_len = len(data)
            diff_len = final_len - init_len
            if final_len > 0 and date_col in data.columns:
                first_date = data[date_col].min().date()
                latest_date = data[date_col].max().date()
            else:
                first_date = latest_date = None
            print(
                '[From Tiger]%(sec_code)s: %(first_date)s - %(latest_date)s, 新增记录 %(diff_len)s/%(final_len)s, '
                % dict(diff_len=diff_len,
                       final_len=final_len,
                       first_date=first_date,
                       latest_date=latest_date,
                       sec_code=sec_code))

    except Exception as e:
        print(sec_code, stage, e)

    # Return the data; it may already be a time series if an error
    # interrupted the steps above
    if is_return:
        if date_col in data.columns:
            data = util.df_2_timeseries(data, time_col=date_col)
        return data
Exemplo n.º 5
0
def read_stock_data(sec_code,
                    time_col,
                    file_path='drive/My Drive/stock_data_us/',
                    file_format='.csv',
                    source='google_drive',
                    start_date=None,
                    end_date=None,
                    drop_cols=None,
                    drop_na=False,
                    sort_index=True):
    """Read stock data from a local file or from the web (Yahoo).

    :param sec_code: security code; also the local file name
    :param time_col: date column turned into the index (local files only)
    :param file_path: directory prefix of the local file
    :param file_format: '.csv' or '.xlsx'
    :param source: 'google_drive' (local file) or 'web' (pandas-datareader, Yahoo)
    :param start_date: first date kept (None = from the beginning)
    :param end_date: last date kept (None = to the end)
    :param drop_cols: columns to drop from local data (None = none)
    :param drop_na: drop columns that contain NA values
    :param sort_index: sort local data by index
    :returns: the data sliced to [start_date:end_date]; an empty DataFrame
        when nothing could be read (errors are printed, not raised)
    """
    if drop_cols is None:
        drop_cols = []

    # Defined up front so the error handler can always report it
    stage = 'reading_started'
    try:
        # Read from the local (Google Drive) file
        if source == 'google_drive':

            filename = file_path + sec_code + file_format
            if not os.path.exists(filename):
                print(filename, ' not exists')
                data = pd.DataFrame()

            elif file_format not in ('.csv', '.xlsx'):
                print('file format %s not supported' % file_format)
                data = pd.DataFrame()

            else:
                stage = 'reading_from_google_drive'
                if file_format == '.csv':
                    data = pd.read_csv(filename,
                                       encoding='utf8',
                                       engine='python')
                else:
                    data = pd.read_excel(filename)

                # Convert to a time series indexed by time_col
                stage = 'transforming_to_timeseries'
                data = util.df_2_timeseries(df=data, time_col=time_col)

                # Clean up
                stage = 'handling_invalid_data'
                data.drop(drop_cols, axis=1, inplace=True)
                if drop_na:
                    data.dropna(axis=1, inplace=True)
                if sort_index:
                    data.sort_index(inplace=True)

        # Download from the web
        elif source == 'web':

            stage = 'reading_from_pandas_datareader'
            data = web.DataReader(sec_code,
                                  'yahoo',
                                  start=start_date,
                                  end=end_date)

        else:
            print('source %s not found' % source)
            data = pd.DataFrame()
    except Exception as e:
        print(sec_code, stage, e)
        data = pd.DataFrame()

    return data[start_date:end_date]
Exemplo n.º 6
0
def download_stock_data_from_yahoo(sec_code,
                                   time_col='Date',
                                   start_date=None,
                                   end_date=None,
                                   file_path='drive/My Drive/stock_data_us/',
                                   file_format='.csv',
                                   is_return=False,
                                   is_print=True):
    """Incrementally download daily data for `sec_code` from Yahoo into the local file.

    :param sec_code: security code; also the local file name
    :param time_col: date column name of the local file
    :param start_date: first date to download; replaced by the latest stored
        date when a local file already exists
    :param end_date: last date to download
    :param file_path: directory prefix of the local file
    :param file_format: file extension of the local file
    :param is_return: return the merged data
    :param is_print: print a summary of the update
    :returns: merged data as a time series when is_return is True, else None
    """
    filename = file_path + sec_code + file_format

    stage = 'downloading_started'
    data = pd.DataFrame()
    try:
        # Load previously downloaded data, if any
        stage = 'loading_existed_data'
        if os.path.exists(filename):
            data = read_stock_data(sec_code,
                                   file_path=file_path,
                                   file_format=file_format,
                                   time_col=time_col)

        # Resume downloading from the latest stored date
        init_len = len(data)
        if init_len > 0:
            start_date = util.time_2_string(data.index.max(),
                                            date_format='%Y-%m-%d')

        # Download, merge and save
        stage = 'appending_new_data'
        tmp_data = web.DataReader(sec_code,
                                  'yahoo',
                                  start=start_date,
                                  end=end_date)
        if len(tmp_data) > 0:
            data = tmp_data if data.empty else pd.concat([data, tmp_data],
                                                         sort=False)

            # Deduplicate (newest wins), sort and save
            stage = 'saving_data'
            data = data.reset_index().drop_duplicates(subset=time_col,
                                                      keep='last')
            data = data.sort_values(by=time_col)
            data.to_csv(filename, index=False)
        elif init_len > 0:
            # Nothing new: expose the date index as a column, matching the
            # layout produced by the save branch above
            data = data.reset_index()

        # Report the change in record count
        if is_print:
            final_len = len(data)
            diff_len = final_len - init_len
            if final_len > 0 and time_col in data.columns:
                first_date = data[time_col].min().date()
                latest_date = data[time_col].max().date()
            else:
                first_date = latest_date = None
            print(
                '%(sec_code)s: %(first_date)s - %(latest_date)s, 新增记录 %(diff_len)s/%(final_len)s, '
                % dict(diff_len=diff_len,
                       final_len=final_len,
                       first_date=first_date,
                       latest_date=latest_date,
                       sec_code=sec_code))
    except Exception as e:
        print(sec_code, stage, e)

    # Return the data; it may already be a time series if an error
    # interrupted the steps above
    if is_return:
        if time_col in data.columns:
            data = util.df_2_timeseries(data, time_col=time_col)
        return data