コード例 #1
0
def restore(from_path=PATH + '/data/backup', to_series='stock'):
    """
    restore(from_path=PATH+'/data/backup', to_series='stock'):

    Restore stock tables from CSV backup files.

    Every path returned by ``_getlist(from_path, [])`` is read with pandas.
    Its 'date' column is converted to ``pd.Timestamp``, and the frame is
    written to a table named after the file's base name (the text before
    the first '.'). Any existing table of that name is replaced.

    Input:
        from_path: (string): path of the csv backup files

        to_series: (string): series of database to recover

    Return:
        None
    """
    files = _getlist(from_path, [])
    cont6 = get_connection(series=to_series, stock_pool='6')
    cont3 = get_connection(series=to_series, stock_pool='3')
    cont0 = get_connection(series=to_series, stock_pool='0')
    cont = {'0': cont0, '3': cont3, '6': cont6}

    bar = ProgressBar(total=len(files))
    for path in files:
        code = path.split('/')[-1].split('.')[0]
        # Log only once `code` is known. Previously this ran before the
        # assignment: NameError on the first file, stale code afterwards.
        bar.log(code)
        df = pd.read_csv(path)
        df['date'] = df['date'].apply(pd.Timestamp)
        # NOTE(review): assumes codes look like 'sh600000', so code[2] is the
        # pool digit; any other digit raises KeyError here — confirm.
        df.to_sql(name=code,
                  con=cont[code[2]],
                  if_exists='replace',
                  index=False)
        bar.move()
コード例 #2
0
def backup_csv(from_series='stock', to_path=PATH + '/data/backup'):
    """
    backup_csv(from_series='stock', to_path=PATH+'/data/backup'):

    Back up every stock table of a database series as one CSV file per code.

    The codes of pools '0', '3' and '6' are gathered with ``get_stock_list``.
    Each table is read with ``select distinct *`` (so duplicate rows are
    dropped) and written to ``<to_path>/<code>.csv`` without the index.

    Input:
        from_series: (string): series of database to back up from

        to_path: (string): directory the CSV files are written to

    Return:
        None
    """
    codes = []
    for pool in ('0', '3', '6'):
        codes.extend(get_stock_list(series=from_series, stock_pool=pool))

    # Open one connection per pool (same order as before: 6, 3, 0).
    connections = {}
    for pool in ('6', '3', '0'):
        connections[pool] = get_connection(series=from_series, stock_pool=pool)

    progress = ProgressBar(total=len(codes))
    for code in codes:
        progress.log(code)
        query = 'select distinct * from ' + code + ';'
        # NOTE(review): assumes codes look like 'sh600000', so code[2] is
        # the pool digit — confirm against get_stock_list.
        frame = pd.read_sql(query, connections[code[2]])
        target = to_path + '/' + code + '.csv'
        frame.to_csv(path_or_buf=target, index=False)
        progress.move()
コード例 #3
0
ファイル: trade_system.py プロジェクト: jiahaoxing/quant
    def backtest(self, testlist='all', start='2014-01-01', end=None, savepath=PATH+'/data/backtest_records.csv'):
        """
        backtest(self, testlist='all', start='2014-01-01', end=None, savepath=PATH+'/data/backtest_records.csv'):

        Backtest the strategy and summarise the resulting trade records:
        success rate, average gain rate and average holding days.
        # TODO: a learning system should also report precision and recall

        Input:
            testlist: ('all' or list of code): codes to backtest; 'all' uses
                every file name (text before the first '.') found in FS_PATH

            start: (string of date): start date of the backtest data
                (inclusive)

            end: (string of date): end date of the backtest data (inclusive)

            savepath: (string): path the trade-record report is written to
                as CSV; None means do not save

        Return:
            None. When at least one record exists, sets self.avggainrate,
            self.successrate and self.keepdays.
            Otherwise prints 'No records'.
        """
        # `is 'all'` compared object identity, which is not guaranteed for
        # equal strings; compare by value (and only for str inputs).
        if isinstance(testlist, str) and testlist == 'all':
            testlist = [filename.split('.')[0] for filename in os.listdir(FS_PATH)]
        frames = []
        bar = ProgressBar(total=len(testlist))
        for code in testlist:
            bar.log(code)
            df = get_stock_data(code, start, end)
            df['date'] = df['date'].apply(lambda x: Timestamp(x))
            buy_record = self.buy(df)
            buy_and_sell_record = self.sell(df, buy_record)
            if buy_and_sell_record is None:
                # No records for this code; previously `.insert` below
                # raised AttributeError on None.
                bar.move()
                continue
            if len(buy_and_sell_record) > 0:
                buy_and_sell_record = buy_and_sell_record.apply(lambda record: self.integrate(df, record), axis=1)
            buy_and_sell_record.insert(0, 'code', [code] * len(buy_and_sell_record))
            frames.append(buy_and_sell_record)
            bar.move()
        # pd.concat raises on an empty list, e.g. empty testlist.
        records = pd.concat(frames) if frames else pd.DataFrame()
        if len(records) > 0:
            self.avggainrate = round(records['gainrate'].mean(), 4) - self.gainbias
            self.successrate = round(len(records[records['gainrate'] > self.gainbias]) / len(records), 4)
            self.keepdays = round(records['keepdays'].mean(), 2)
            if savepath is not None:
                records.to_csv(savepath)
                print('records is saved at '+savepath)
        else:
            print('No records')
コード例 #4
0
ファイル: download_data.py プロジェクト: fswzb/quant-1
def download_data_append_hfq(start_date, end_date=None, from_code=None,
                             update_list=False, to_series='stock'):
    """
    Download back-adjusted ('hfq') price data via tushare and append it to
    the per-pool stock databases.

    Input:
        start_date: (string of date): first date requested from tushare

        end_date: (string of date or None): last date requested; passed
            to tushare unchanged

        from_code: (string or None): resume point. Downloading starts with
            the code *after* ``from_code`` in the saved code list. If the
            code is not in the list, nothing is downloaded and a message
            is printed.

        update_list: (bool): refresh PATH/data/code_list.npy from
            ``ts.get_stock_basics()`` before downloading

        to_series: (string): series of database to append to

    Return:
        None

    Failures are logged to PATH/data/log/<today>.txt:
      - codes whose price data came back empty are logged as errors;
      - codes without turnover data are stored with NaN turnover and
        listed on stdout and in the log.
    """
    # Dict literal keeps the original connection order: 6, 3, 0.
    conns = {
        '6': get_connection(series=to_series, stock_pool='6'),
        '3': get_connection(series=to_series, stock_pool='3'),
        '0': get_connection(series=to_series, stock_pool='0'),
    }
    code_list_path = PATH + '/data/code_list.npy'
    if update_list:
        codelist = ts.get_stock_basics()
        # Best-effort removal (np.save overwrites anyway). os.system('rm')
        # never raised, so the old try/except was dead code.
        try:
            os.remove(code_list_path)
        except OSError:
            pass
        # Save as a unicode array: object arrays need allow_pickle to load.
        np.save(code_list_path, np.asarray(codelist.index, dtype=str))
    log_dir = PATH + '/data/log/' + str(date.today())
    with open(log_dir + '.txt', 'a', encoding='utf-8') as record:
        # allow_pickle only to read code lists saved by older versions as
        # object arrays. NOTE: it unpickles data; keep PATH trusted.
        stock_code_list = list(np.load(code_list_path, allow_pickle=True))

        # Resume *after* from_code (same semantics as the previous counter
        # loop); an unknown code yields an empty work list.
        if from_code is None:
            start_idx = 0
        elif from_code in stock_code_list:
            start_idx = stock_code_list.index(from_code) + 1
        else:
            start_idx = len(stock_code_list)
            print('from_code', from_code, 'not found in code list; nothing to download')

        remaining = stock_code_list[start_idx:]
        turnover_list = []
        # NOTE(review): count=start_idx with total=len(remaining) is kept
        # as-is; confirm this matches ProgressBar's semantics.
        bar = ProgressBar(count=start_idx, total=len(remaining))

        for stock_code in remaining:
            stock_code_sql = ('sh' if stock_code[0] == '6' else 'sz') + stock_code
            bar.log(stock_code_sql)

            data = ts.get_h_data(stock_code, start=start_date, end=end_date,
                                 retry_count=5, pause=0.1, autype='hfq', drop_factor=False)
            data_t = ts.get_hist_data(stock_code, start=start_date, end=end_date,
                                      retry_count=5, pause=0.1)

            if data is None or len(data) == 0:
                record.write('\nERROR:::data of ' + stock_code_sql
                             + ' may be missed. \n')
            else:
                if data_t is None or len(data_t) == 0:
                    data['turnover'] = np.nan  # np.NAN was removed in NumPy 2.0
                    print('No turnover data')
                    turnover_list.append(stock_code)
                else:
                    # NOTE(review): positional assignment assumes both frames
                    # have the same rows in the same order; a length
                    # mismatch raises ValueError — confirm.
                    data['turnover'] = list(data_t['turnover'] / 100)

                # Reverse row order before appending (presumably tushare
                # returns newest first — verify).
                data = data.reindex(index=data.index[::-1])
                # '6' -> pool 6, '3' -> pool 3, anything else -> pool 0.
                conn = conns.get(stock_code[0], conns['0'])
                data.to_sql(name=stock_code_sql, con=conn, if_exists='append')

            bar.move()

        if turnover_list:
            print('缺少换手率数据的有:')
            record.write('\n缺少换手率数据的有:\n')
            for code in turnover_list:
                print(code)
                record.write(code + ',')
            print('共', len(turnover_list), '只')