Code example #1
File: stock_env_test.py Project: linbirg/RL
from stock_env import StockEnv


env = StockEnv()

if __name__ == '__main__':
    env.render()
    # print(env.step(1))
    s, r, done = env.step(1)
    print(s)
    print(s.shape)
    print(r)

    print("=====================")
    s, r, done = env.step(0)
    print(s)
    print(r)

    print("====================")
    s, r, done = env.step(2)
    print(s)
    print(r)


    
Code example #2
File: A3C.py Project: linbirg/RL
class Worker(object):
    def __init__(self, name, globalAC):
        self.env = StockEnv()
        self.name = name
        self.AC = ACNet(name, globalAC)

    def work(self):
        global GLOBAL_RUNNING_R, GLOBAL_EP
        total_step = 1
        buffer_s, buffer_a, buffer_r = [], [], []
        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
            s = self.env.reset()
            ep_r = 0
            for ep_t in range(MAX_EP_STEP):
                if self.name == 'W_0':
                    self.env.render()
                a = self.AC.choose_action(s)
                s_, r, done = self.env.step(a)
                if ep_t == MAX_EP_STEP - 1: done = True
                ep_r += r
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append(r)

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:  # update global and assign to local net
                    if done:
                        v_s_ = 0  # terminal
                    else:
                        v_s_ = SESS.run(self.AC.v,
                                        {self.AC.s: s_[np.newaxis, :]})[0, 0]
                    buffer_v_target = []
                    for r in buffer_r[::-1]:  # reverse buffer r
                        v_s_ = r + GAMMA * v_s_
                        buffer_v_target.append(v_s_)
                    buffer_v_target.reverse()

                    buffer_s, buffer_a, buffer_v_target = np.vstack(
                        buffer_s), np.vstack(buffer_a), np.vstack(
                            buffer_v_target)
                    feed_dict = {
                        self.AC.s: buffer_s,
                        self.AC.a_his: buffer_a,
                        self.AC.v_target: buffer_v_target,
                    }
                    test = self.AC.update_global(feed_dict)
                    buffer_s, buffer_a, buffer_r = [], [], []
                    self.AC.pull_global()

                s = s_
                total_step += 1
                if done:
                    if len(GLOBAL_RUNNING_R) == 0:  # record running episode reward
                        GLOBAL_RUNNING_R.append(ep_r)
                    else:
                        GLOBAL_RUNNING_R.append(0.9 * GLOBAL_RUNNING_R[-1] +
                                                0.1 * ep_r)
                    print(
                        self.name,
                        "Ep:",
                        GLOBAL_EP,
                        "| Ep_r: %i" % GLOBAL_RUNNING_R[-1],
                        '| Var:',
                        test,
                    )
                    GLOBAL_EP += 1
                    break
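
The loop in the middle of work() that walks buffer_r backwards is the standard A3C n-step target computation: v_target[t] = r[t] + GAMMA * v_target[t+1], seeded with 0 at a terminal state or with the critic's estimate of V(s_) otherwise. Pulled out as a standalone helper (a sketch only; the project keeps this inline):

import numpy as np


def discounted_targets(rewards, bootstrap_value, gamma=0.9):
    # v_target[t] = r[t] + gamma * v_target[t + 1], seeded with the bootstrap value
    # (0 for a terminal state, otherwise the critic's estimate of V(s_))
    v_s_ = bootstrap_value
    targets = []
    for r in rewards[::-1]:  # walk the reward buffer backwards
        v_s_ = r + gamma * v_s_
        targets.append(v_s_)
    targets.reverse()
    return np.vstack(targets)  # column vector, matching np.vstack(buffer_v_target)


# Example: rewards [1, 0, 2] with bootstrap value 10
# discounted_targets([1, 0, 2], 10).ravel() -> [9.91, 9.9, 11.0]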
Code example #3
    p0 = env.reset()
    agent.update_value(p0)

    for t in range(MAX_T):
        # select the next action
        action = agent.select_action(p0)
        # execute the next action and get next state and reward
        p = env.step()

        for i, a in enumerate(action):
            agent.act(i, a, p[i])

        agent.update_value(p)

        # render the portfolio value graph
        env.render(ax, agent.value)

        # prepare for next iteration
        p0 = p

        if agent.value[-1] >= TERMINAL_VALUE:
            print(
                "Episode %d finished after %f time steps with total value = %f"
                % (episode, t, agent.value[-1]))
            break
        elif agent.value[-1] <= 0:
            print(
                "Episode %d terminated after %f time steps with total value = %f. No more assets."
                % (episode, t, agent.value[-1]))
            break
        elif t >= MAX_T - 1:
            pass  # body truncated in this excerpt; code example #4 shows the matching branch
Code example #4
            reward.append(r)
        for i, r in enumerate(reward):
            if r == 0.0:
                reward[i] = (agent.value[-1] - agent.value[-2]) / agent.value[-1]

        # update the q table
        agent.update_Q(p0, action, reward, lr)

        # prepare for next iteration
        p0 = p
        agent.update_value(p)
        for i in reward:
            total_reward += i

        # render the portfolio value graph
        env.render(ax, agent.value[2:])
        # env.render(ax2) #ToDo: uncomment this to get a graph of the agent's performance and the stock behavior

        if agent.value[-1] >= TERMINAL_VALUE:
            print("Episode %d finished after %f time steps with total reward = %f. Total Value = %f"
                  % (episode, t, total_reward, agent.value[-1]))
            break
        elif agent.value[-1] <= 20:
            print("Episode %d terminated after %f time steps with total reward = %f. Total Value = %f. No more assets."
                  % (episode, t, total_reward, agent.value[-1]))
            break
        elif t >= MAX_T - 1:
            print("Episode %d terminated after %f time steps with total reward = %f. Total Value = %f"
                  % (episode, t, total_reward, agent.value[-1]))
            break
Code example #5
class Worker(object):
    def __init__(self, name, globalAC):
        self.env = StockEnv()
        self.name = name
        self.AC = ACNet(name,self.env.get_state().shape[0], 4, globalAC)

    def _update_global_reward(self, ep_r):
        global GLOBAL_RUNNING_R, GLOBAL_EP
        if len(GLOBAL_RUNNING_R) == 0:  # record running episode reward
            GLOBAL_RUNNING_R.append(ep_r)
        else:
            GLOBAL_RUNNING_R.append(0.99 * GLOBAL_RUNNING_R[-1] + 0.01 * ep_r)
        # log and count the episode whether or not it was the first one
        logger.debug(
            [self.name,
             "Ep:",
             GLOBAL_EP,
             "| Ep_r: %i" % GLOBAL_RUNNING_R[-1]]
        )
        GLOBAL_EP += 1

    def _update_globa_acnet(self, done, s_, buffer_s, buffer_a, buffer_r):
        if done:
            v_s_ = 0  # terminal state: bootstrap value is zero
        else:
            v_s_ = SESS.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[0, 0]
        # compute discounted value targets by walking the reward buffer backwards
        buffer_v_target = []
        for r in buffer_r[::-1]:
            v_s_ = r + GAMMA * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()

        buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.array(
            buffer_a), np.vstack(buffer_v_target)
        feed_dict = {
            self.AC.s: buffer_s,
            self.AC.a_his: buffer_a,
            self.AC.v_target: buffer_v_target,
        }
        self.AC.update_global(feed_dict)

    def work(self):
        total_step = 1
        buffer_s, buffer_a, buffer_r = [], [], []
        self.env.reset()
        if self.name == 'W_0':
            self.env.render()
        while not COORD.should_stop():
            ep_r = 0
            while True:
                s = self.env._get_state()
                a, p = self.AC.choose_action(s)
                s_, r, done = self.env.step(a)
                if done: 
                    r = -0.5
                ep_r += r
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append(r)

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
                    self._update_globa_acnet(done, s_, buffer_s, buffer_a,
                                             buffer_r)
                    buffer_s, buffer_a, buffer_r = [], [], []
                    self.AC.pull_global()
                # s = s_
                total_step += 1
                if done:
                    self._update_global_reward(ep_r)
                    break
                
                if self.name == 'W_0':
                    logger.debug(["s", s, " a:", a, " p:", p, " r:", r, " total_step:", total_step, 'total', self.env.total])
                    time.sleep(0.5)

    def train(self):
        global GLOBAL_RUNNING_R, GLOBAL_EP
        total_step = 1
        buffer_s, buffer_a, buffer_r = [], [], []
        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
            s = self.env.reset()
            ep_r = 0
            while True:
                # if self.name == 'W_0':
                    # self.env.render()
                a, p = self.AC.choose_action(s)
                s_, r, done = self.env.step(a)
                if done: r = -0.5
                ep_r += r
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append(r)

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:  # update global and assign to local net
                    self._update_globa_acnet(done, s_, buffer_s, buffer_a,
                                             buffer_r)
                    buffer_s, buffer_a, buffer_r = [], [], []
                    self.AC.pull_global()

                if done:
                    self._update_global_reward(ep_r)
                    logger.debug(["s", s, " a:", a, " p:", p, " r:", r, " total_step:", total_step, 'total', self.env.total])
                    break

                s = s_
                total_step += 1
Code example #6
class Worker(object):
    GAMMA = 0.9
    GLOBAL_RUNNING_R = []
    GLOBAL_EP = 0

    def __init__(self, sess, name, N_S, N_A, globalAC):
        self.SESS = sess
        self.N_S = N_S
        self.N_A = N_A
        self.env = StockEnv()
        self.name = name
        self.AC = A3CNet(self.SESS, self.name, self.N_S, self.N_A, globalAC)
        # self.saver = tf.train.Saver()

    def _record_global_reward_and_print(self, global_runing_rs, ep_r,
                                        global_ep, total_step):
        global_runing_rs.append(ep_r)
        try:
            print(self.name, "Ep:", global_ep,
                  "| Ep_r: %i" % global_runing_rs[-1], "| total step:",
                  total_step)
        except Exception as e:
            print(e)

    def train(self):
        buffer_s, buffer_a, buffer_r = [], [], []
        s = self.env.reset()
        ep_r = 0
        total_step = 1

        def reset():
            nonlocal ep_r, total_step
            self.env.reset()
            ep_r = 0
            total_step = 1

        while not COORD.should_stop() and self.GLOBAL_EP < MAX_GLOBAL_EP:
            # s = self.env.reset()
            # ep_r = 0
            # total_step = 1
            reset()
            while total_step < MAX_TOTAL_STEP:
                try:
                    s = self.env.get_state()
                    a, p = self.AC.choose_action(s)
                    s_, r, done = self.env.step(a)
                    if done:
                        r = -2

                    ep_r += r
                    buffer_s.append(s)
                    buffer_a.append(a)
                    buffer_r.append(r)

                    if total_step % UPDATE_GLOBAL_ITER == 0 or done:  # update global and assign to local net
                        self.AC.update(done, s_, buffer_r, buffer_s, buffer_a)
                        buffer_s, buffer_a, buffer_r = [], [], []

                    if done:
                        self._record_global_reward_and_print(
                            self.GLOBAL_RUNNING_R, ep_r, self.GLOBAL_EP,
                            total_step)
                        Worker.GLOBAL_EP += 1  # bump the class-level counter shared by all workers
                        reset()

                    # s = s_
                    total_step += 1
                    if self.name == 'W_0':
                        self.env.render()
                        time.sleep(0.05)
                        logger.debug([
                            "s ", s, " v ",
                            self.AC.get_v(s), " a ", a, " p ", p, " ep_r ",
                            ep_r, " total ", self.env.total, " acct ",
                            self.env.acct
                        ])
                except Exception as e:
                    print(e)

            try:
                print(self.name, " not done,may be donkey!", " total_step:",
                      total_step)
            except Exception as e:
                print(e)
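
The worker classes above reference module-level globals such as SESS, COORD and the hyperparameters (GAMMA, UPDATE_GLOBAL_ITER, MAX_GLOBAL_EP, ...) without showing how training is launched. The driver below is a sketch under the assumption of a TF1-style session and the two-argument ACNet constructor from code example #2; the global-net constructor call, the scope name and the worker count are guesses, not the project's code.

import multiprocessing
import threading

import tensorflow.compat.v1 as tf  # the examples use TF1-style sessions

tf.disable_v2_behavior()

SESS = tf.Session()
COORD = tf.train.Coordinator()

with tf.device("/cpu:0"):
    GLOBAL_AC = ACNet('Global_Net', None)  # shared parameter net; exact signature assumed
    workers = [Worker('W_%i' % i, GLOBAL_AC)
               for i in range(multiprocessing.cpu_count())]

SESS.run(tf.global_variables_initializer())

threads = []
for worker in workers:
    t = threading.Thread(target=worker.work)  # or worker.train, depending on the example
    t.start()
    threads.append(t)
COORD.join(threads)  # wait for all worker threads to finish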