Beispiel #1
0
def train():
  """Train tabular Q-values for the US and CH players.

  Every iteration performs two transitions through ``m.trans``: the first
  with ``A.US_MOVE`` as the mover, the second with ``A.CH_MOVE``. After each
  transition both Q-tables are updated with ``q_update``. The state is reset
  with ``first_state()`` every ``NRESTART`` iterations.

  Actions start out random; they are replaced by the argmax of the matching
  Q-table unless the iteration index is below ``NRANDOM`` or a fresh
  ``np.random.random`` draw falls below ``EPSILON``.

  The final tables are written with ``np.save`` under the names 'us_q' and
  'ch_q'.

  Returns:
    The tuple ``(us_q, ch_q)``.
  """
  def learn(prev, nxt, payoff, us_table, ch_table):
    # Update both tables from one transition; the action credited to each
    # player is read back from the successor state.
    us_table = q_update(prev, nxt, us_table,
      nxt[S.index.LAST_US_MOVE], payoff[R.index.US_PROFIT])
    ch_table = q_update(prev, nxt, ch_table,
      nxt[S.index.LAST_CH_MOVE], payoff[R.index.CH_PROFIT])
    return us_table, ch_table

  ch_q = zero_q(A.CH_MOVE)
  us_q = zero_q(A.US_MOVE)

  for step in tqdm(range(NITER)):
    if step % NRESTART == 0:
      state = first_state()

    # Begin with a fully random joint action.
    action = zero_action()
    action[A.index.US_MOVE] = random(A.US_MOVE)
    action[A.index.CH_MOVE] = random(A.CH_MOVE)

    # One decision per iteration covers both transitions below.
    greedy = not (step < NRANDOM or np.random.random(1)[0] < EPSILON)

    # Transition with the US player as mover.
    if greedy:
      action[A.index.US_MOVE] = np.argmax(us_q[tuple(state)])
    successor, payoff = m.trans(state, action, A.US_MOVE)
    us_q, ch_q = learn(state, successor, payoff, us_q, ch_q)
    state = successor

    # Transition with the CH player as mover, from the updated state.
    if greedy:
      action[A.index.CH_MOVE] = np.argmax(ch_q[tuple(state)])
    successor, payoff = m.trans(state, action, A.CH_MOVE)
    us_q, ch_q = learn(state, successor, payoff, us_q, ch_q)
    state = successor

  np.save('us_q', us_q)
  np.save('ch_q', ch_q)
  return us_q, ch_q
Beispiel #2
0
def train():
    """Train a tabular Q-function for the ``A.TRADE`` action.

    Every ``NRESTART`` iterations the state is reset to
    ``pm.trans_price(first_state())``. On the last iteration before a reset,
    ``end=True`` is passed to ``pm.trans_holding``. Each iteration applies
    ``pm.trans_holding``, then ``tr.trans``, then ``pm.trans_price`` and
    feeds the resulting transition to ``q_update``.

    The trade action is random unless the iteration index is at least
    ``NRANDOM`` and a fresh ``np.random.random`` draw is not below
    ``EPSILON``; in that case the argmax of the Q-table is used.

    The final table is written with ``np.save`` under the name 'pm_q'.

    Returns:
        The trained Q-table.
    """
    q = zero_q(A.TRADE)

    for step in tqdm(range(NITER)):
        phase = step % NRESTART
        if phase == 0:
            # Restart from the initial state with a price transition applied.
            state = pm.trans_price(first_state())
        last_round = phase == NRESTART - 1

        action = zero_action()
        action[A.index.TRADE] = random(A.TRADE)

        greedy = not (step < NRANDOM or np.random.random(1)[0] < EPSILON)
        if greedy:
            action[A.index.TRADE] = np.argmax(q[tuple(state)])

        held, reward = pm.trans_holding(state, action, end=last_round)
        successor = pm.trans_price(tr.trans(held))
        q = q_update(state, successor, q, action[A.index.TRADE],
                     reward[R.index.TRADER_PROFIT])
        state = successor

    np.save('pm_q', q)
    return q
Beispiel #3
0
def lock_extras(state):
  """Return a copy of ``state`` with the extra entries reset.

  The entries at ``I.LAST_US_MOVE``, ``I.LAST_CH_MOVE``, ``I.USEC_GROWTH``
  and ``I.CHEC_GROWTH`` are kept. Every other index in ``range(S.N)`` is
  overwritten with the value ``first_state()`` has at that index (per the
  original note: zero, which is what training used). ``state`` itself is
  not modified.
  """
  baseline = first_state()
  kept = {I.LAST_US_MOVE, I.LAST_CH_MOVE, I.USEC_GROWTH, I.CHEC_GROWTH}
  reset_idx = [idx for idx in range(S.N) if idx not in kept]

  locked = np.copy(state)
  locked[reset_idx] = baseline[reset_idx]
  return locked
Beispiel #4
0
def test_state_trans():
  """Show two transitions from the initial state with both sides de-escalating.

  The first transition uses ``A.US_MOVE`` as the mover and the second uses
  ``A.CH_MOVE``. The state is shown before the first transition and after
  each one; the reward from each transition is printed.
  """
  current = first_state()
  show_s(current)
  for mover in (A.US_MOVE, A.CH_MOVE):
    # Build a fresh action list for every call, as the transitions expect.
    joint_action = [A.US_MOVE.DEESCALATE, A.CH_MOVE.DEESCALATE]
    current, reward = trans(current, joint_action, mover)
    show_s(current)
    print(reward)
Beispiel #5
0
def test_train():
  """Run training, reload and display both Q-tables, then roll out 5 steps.

  The Q-tables are obtained from ``load_q()`` after ``train()`` returns.
  The rollout starts at ``first_state()`` and applies ``trans`` five times,
  showing every state along the way.
  """
  train()
  us_q, ch_q = load_q()

  tables = (('US Q', us_q, A.US_MOVE), ('CH Q', ch_q, A.CH_MOVE))
  for title, table, move in tables:
    print(title)
    show_q(table, move)

  current = first_state()
  show_s(current)
  for _ in range(5):
    current = trans(current)
    show_s(current)
Beispiel #6
0
def test_trans():
    """Step through four rounds while always choosing ``A.TRADE.BUY``.

    Every round applies ``tr.trans``, ``trans_price`` and ``trans_holding``
    in that order, then shows the state and prints the reward. Only the
    fourth round passes ``end=True`` to ``trans_holding``; the first three
    leave ``end`` at its default.
    """
    state = first_state()
    action = zero_action()
    action[A.index.TRADE] = A.TRADE.BUY

    # Per-round extra keyword arguments for trans_holding.
    round_kwargs = [{}, {}, {}, {'end': True}]
    for extra in round_kwargs:
        state = trans_price(tr.trans(state))
        state, reward = trans_holding(state, action, **extra)
        show_s(state)
        print(reward)
Beispiel #7
0
def plot_usec_growth():
  """Plot the US economic growth entry of the state over four rounds.

  Starting from ``first_state()``, ``trans`` is applied three times. The
  value ``state[I.USEC_GROWTH]`` is recorded for the initial state and after
  every transition, and each state is passed to ``show_s``. The resulting
  line plot is saved to 'usec.png'.
  """
  from matplotlib import pyplot as plt

  # Display names for the growth values; unknown values are shown as-is.
  display_names = {'MT3': 'HIGH', 'B23': 'MEDIUM', 'LT2': 'LOW'}
  n_rounds = 4

  fig, ax = plt.subplots(1, 1)
  fig.set_size_inches(*FIGSIZE)

  state = first_state()
  show_s(state)
  history = [state[I.USEC_GROWTH]]
  for _ in range(n_rounds - 1):
    state = trans(state)
    show_s(state)
    history.append(state[I.USEC_GROWTH])

  # Draw on the axes created above instead of pyplot's implicit "current"
  # axes, so the output does not depend on global figure state.
  ax.plot(range(n_rounds), history, 'o-')
  ax.set_yticks(list(range(S.USEC_GROWTH.N)))
  ax.set_yticklabels(
    [display_names.get(value, value) for value in S.USEC_GROWTH.values])
  ax.set_ylabel('us econ growth')
  ax.set_xlabel('round')
  ax.set_title('us econ growth projection')
  fig.savefig('usec.png')
Beispiel #8
0
def plot_price():
    """Plot six price trajectories of four rounds each and save the figure.

    Every trajectory starts from ``trans_price(first_state())`` and then
    alternates ``tr.trans`` and ``trans_price`` for three more rounds,
    recording ``price(state)`` after each price transition. All trajectories
    are drawn on the same axes and the figure is saved to 'price.png'.
    """
    from matplotlib import pyplot as plt

    n_paths = 6
    n_rounds = 4

    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(*FIGSIZE)

    for _path in range(n_paths):
        # NOTE(review): the original also built an unused zero_action() here;
        # it is dropped on the assumption that the call has no side effects.
        state = trans_price(first_state())
        history = [price(state)]
        for _round in range(n_rounds - 1):
            state = tr.trans(state)
            state = trans_price(state)
            history.append(price(state))
        # Draw on the axes created above instead of pyplot's implicit
        # "current" axes, so the output does not depend on global state.
        ax.plot(range(n_rounds), history, '-')

    ax.set_ylabel('price')
    ax.set_xlabel('round')
    ax.set_title('price projection')
    fig.savefig('price.png')
Beispiel #9
0
def test_show_s():
  """Display the initial state with ``show_s``."""
  initial = first_state()
  show_s(initial)