-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
67 lines (58 loc) · 2.45 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals
from pymongo import MongoClient, ASCENDING
from datetime import timedelta, datetime, date
import pytz
from collections import defaultdict
import numpy as np
class TwseDailyDataset(object):
    """Read-only accessor for daily TWSE data stored in MongoDB.

    Documents in every collection are looked up by ``_id`` equal to the
    timezone-aware datetime produced by :meth:`date_to_datetime` (14:00
    Asia/Taipei on the given day).
    """

    def __init__(self, mongo_url='localhost', username='crawler',
                 password='crawler', db_name='twse_daily'):
        """Open a client for the daily-data database.

        Args:
            mongo_url: host name or MongoDB URI handed to ``MongoClient``.
            username: user to authenticate as (default ``'crawler'``).
            password: password for *username* (default ``'crawler'``).
            db_name: database holding the collections; also used as the
                ``authSource`` the user is defined in.

        ``Database.authenticate`` was removed in PyMongo 4, so credentials are
        passed to ``MongoClient`` instead. With this form PyMongo
        authenticates lazily on the first operation rather than inside this
        constructor.
        """
        client = MongoClient(host=mongo_url, username=username,
                             password=password, authSource=db_name)
        self.db = client[db_name]
        self.collections = [
            self.db[name]
            for name in ['BFI82U', 'FMTQIK', 'MI_INDEX', 'MI_MARGN', 'STOCK_DAY']
        ]

    def date_to_datetime(self, obj_or_year, month=None, day=None):
        """Convert a date to a datetime at 14:00 Taiwan time (market close).

        Accepts either one object with ``year``/``month``/``day`` attributes
        (e.g. a ``date``) or three integers ``year, month, day``.
        Returns a timezone-aware datetime localized to Asia/Taipei.
        When pymongo writes a datetime object it automatically converts it
        to UTC.
        """
        if month is None:
            year, month, day = obj_or_year.year, obj_or_year.month, obj_or_year.day
        else:
            year = obj_or_year
        dt = datetime(year=year, month=month, day=day, hour=14)
        return pytz.timezone('Asia/Taipei').localize(dt)

    def __getitem__(self, day):
        """Return every field stored for *day*, merged across collections.

        Collections are queried in ``self.collections`` order; on duplicate
        keys (including ``_id``) later collections overwrite earlier ones.
        Returns an empty dict when no collection has a document for the day.
        """
        data = {}
        # The lookup key is the same for every collection; compute it once.
        query = {'_id': self.date_to_datetime(day)}
        for coll in self.collections:
            doc = coll.find_one(query)
            if doc:
                data.update(doc)
        return data

    @staticmethod
    def _default_end_day():
        """Return the default last day for gen_features, judged in Taipei time."""
        # Take a single clock snapshot in Asia/Taipei so the result does not
        # depend on the host's local timezone or race across midnight.
        now = datetime.now(pytz.timezone('Asia/Taipei'))
        today = now.date()
        # Today counts only from 16:00 Taipei time onward (hour > 15).
        return today if now.hour > 15 else today - timedelta(days=1)

    def gen_features(self, feature_names, only_trading_day=True,
                     start_day=date(2004, 4, 1), end_day=None):
        """Build a per-day feature matrix for *feature_names*.

        Args:
            feature_names: field names to extract from the daily documents.
            only_trading_day: if True, keep only the days for which at least
                one collection returned a non-empty projection.
            start_day: first calendar day, inclusive (default 2004-04-01).
            end_day: last calendar day, inclusive. Defaults to today in
                Taipei time if it is 16:00 or later there, else yesterday.

        Returns:
            ``(dates, features)``. ``dates`` is a numpy array of ``date``
            objects. ``features`` is a float32 array of shape
            ``(len(dates), len(feature_names))``; fields missing on a kept
            day are left as 0.

        NOTE(review): a day whose documents exist but contain none of the
        requested fields yields ``{}`` from ``find_one`` (``_id`` is
        projected out) and is therefore treated as a non-trading day.
        Confirm this is intended.
        """
        if end_day is None:
            end_day = self._default_end_day()
        n_days = (end_day - start_day).days + 1
        # range() of a non-positive count is empty, so an inverted range
        # simply yields no dates.
        dates = [start_day + timedelta(days=k) for k in range(n_days)]

        masks = np.zeros(len(dates), dtype=bool)
        features = np.zeros((len(dates), len(feature_names)), dtype=np.float32)
        # Fetch only the requested fields and drop _id from the result.
        projection = {name: 1 for name in feature_names}
        projection['_id'] = 0

        for i, day in enumerate(dates):
            query = {'_id': self.date_to_datetime(day)}
            feat = {}
            for coll in self.collections:
                result = coll.find_one(query, projection)
                if result:
                    feat.update(result)
            if feat:
                masks[i] = True
                for j, name in enumerate(feature_names):
                    if name in feat:
                        features[i, j] = feat[name]

        dates = np.array(dates)
        if only_trading_day:
            features = features[masks]
            dates = dates[masks]
        return dates, features
# Shared module-level dataset built with the default connection settings
# (mongo_url='localhost'). It is created when this module is imported, so
# importing the module constructs a MongoClient as a side effect.
dailydata = TwseDailyDataset()