forked from AustinNeverPee/Q-LearningTradingStrategy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
RL_strategy.py
149 lines (115 loc) · 4.29 KB
/
RL_strategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Q-Learning Stock Trading Algorithm
State: TP Matrix
Action: {buy, sell, hold}
Reward: difference between current and previous portfolio value
Approximate function: Convolutional Neural Network
Train agent for only one stock
Using epsilon-greedy algorithm to train
Author: YANG, Austin Liu
"""
import pytz
from datetime import datetime
from zipline.algorithm import TradingAlgorithm
from zipline.utils.factory import load_bars_from_yahoo
import os
import global_values as gv
import train
import test
import numpy as np
from tensorflow.contrib import learn
from tensorflow.contrib.learn.python import SKCompat
import pdb
# Training hyper-parameters.
training_steps = 100  # optimisation steps per action model in each Q_update call
Q_training_iters = 5  # trade-then-retrain iterations performed by agent_train
def initialize_log():
    """Set up file logging for this run.

    Ensures the directory ``log/<gv.directory_log>`` exists, then registers
    a file handler on ``gv.mylogger`` for ``gv.directory_log`` (presumably
    writing into that directory -- confirm in global_values).
    """
    log_dir = os.path.join('log', gv.directory_log)
    # Re-running with an existing run directory must not abort the program
    # before training even starts; only create the directory when missing.
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    gv.mylogger.addFileHandler(gv.directory_log)
def load_data(stocks=None, train_start=None, train_end=None,
              test_start=None, test_end=None):
    """Load training and testing bar data from Yahoo! Finance.

    Every argument is optional; omitted ones fall back to the original
    hard-coded setup (AAPL, train 2010-01-01..2014-01-01, test
    2014-01-01..2016-01-01, all midnight UTC).

    :param stocks: ticker symbol or sequence of ticker symbols
        (default ``['AAPL']``).
    :param train_start: start of the training window (tz-aware datetime).
    :param train_end: end of the training window (tz-aware datetime).
    :param test_start: start of the testing window (tz-aware datetime).
    :param test_end: end of the testing window (tz-aware datetime).
    :returns: ``[data_train, data_test]`` as produced by
        ``load_bars_from_yahoo``.
    """
    def _jan_first_utc(year):
        # Midnight UTC on January 1st of *year*.
        return datetime(year, 1, 1, 0, 0, 0, 0, pytz.utc)

    if stocks is None:
        stocks = ['AAPL']
    elif isinstance(stocks, str):
        # A bare ticker string would otherwise be split into characters.
        stocks = [stocks]

    train_start = _jan_first_utc(2010) if train_start is None else train_start
    train_end = _jan_first_utc(2014) if train_end is None else train_end
    test_start = _jan_first_utc(2014) if test_start is None else test_start
    test_end = _jan_first_utc(2016) if test_end is None else test_end

    # Give each fetch its own list so neither call can affect the other.
    data_train = load_bars_from_yahoo(stocks=list(stocks),
                                      start=train_start,
                                      end=train_end)
    data_test = load_bars_from_yahoo(stocks=list(stocks),
                                     start=test_start,
                                     end=test_end)
    return [data_train, data_test]
def Q_update():
    """Update the weights of the per-action Q models.

    For every action in ``gv.action_set`` (the "sell", "buy" and "hold"
    models), build an estimator from ``gv.cnn_model_fn`` rooted at
    ``gv.model_dirs[action]`` (presumably resuming from checkpoints saved
    there by earlier calls -- confirm). Then fit it for ``training_steps``
    steps on ``train.Q_data[action]`` / ``train.Q_labels[action]``,
    evaluate it on that same data, and log the evaluation metrics.
    """
    for action in gv.action_set:
        gv.mylogger.logger.info("Update " + action + " model")
        # Convert once; both training and scoring consume float32 arrays.
        features = train.Q_data[action].astype(np.float32)
        labels = train.Q_labels[action].astype(np.float32)
        estimator = SKCompat(learn.Estimator(
            model_fn=gv.cnn_model_fn,
            model_dir=gv.model_dirs[action]))
        estimator.fit(x=features, y=labels, steps=training_steps)
        # Score through the same SKCompat wrapper used for fit: passing x/y
        # straight to Estimator.evaluate is deprecated in tf.contrib.learn.
        eval_results = estimator.score(x=features, y=labels)
        gv.mylogger.logger.info(eval_results)
def agent_train(data_train):
    """Train the agent by learning from the environment.

    Runs ``Q_training_iters`` iterations. Each iteration:

    1. backtests the ``train`` strategy (``train.initialize`` /
       ``train.handle_data``) over *data_train* with daily bars and
       ``gv.capital_base`` starting capital, which produces the training set;
    2. retrains the Q models on that training set via ``Q_update``;
    3. decays the exploration rate ``gv.epsilon``.

    :param data_train: bar data for the training window (the first element
        returned by ``load_data``).
    """
    for iteration in range(Q_training_iters):
        gv.mylogger.logger.info("Agent Iteration :" + str(iteration + 1))
        algo = TradingAlgorithm(initialize=train.initialize,
                                handle_data=train.handle_data,
                                data_frequency='daily',
                                capital_base=gv.capital_base)
        # Only the run's side effect (the produced training set) matters
        # here; the returned performance result is not used.
        algo.run(data_train)
        Q_update()
        # NOTE(review): the exponent is applied to the already-decayed value,
        # so epsilon goes eps0**2, eps0**6, eps0**24, eps0**120, ...
        # Confirm this schedule is intended rather than eps0 ** (iteration + 2).
        gv.epsilon = pow(gv.epsilon, iteration + 2)
def agent_test(data_test):
    """Test the agent and check out the result of learning.

    Backtests the ``test`` strategy (``test.initialize`` /
    ``test.handle_data`` / ``test.analyze``) over *data_test* with daily
    bars and ``gv.capital_base`` starting capital.

    :param data_test: bar data for the testing window (the second element
        returned by ``load_data``).
    :returns: the result of ``TradingAlgorithm.run`` (zipline's daily
        performance DataFrame). It was previously discarded; returning it
        lets callers inspect the backtest, and existing callers that ignore
        the return value are unaffected.
    """
    algo_obj = TradingAlgorithm(initialize=test.initialize,
                                handle_data=test.handle_data,
                                analyze=test.analyze,
                                data_frequency='daily',
                                capital_base=gv.capital_base)
    return algo_obj.run(data_test)
if __name__ == '__main__':
    # Logging must be ready before any step below writes to it.
    initialize_log()
    # Fetch the training and testing windows in one go.
    data_train, data_test = load_data()
    # Learn on the training window, then evaluate on the held-out window.
    agent_train(data_train)
    agent_test(data_test)