예제 #1
0
파일: tests.py 프로젝트: manongbang/snake
class SimTrajectoryTestCase(unittest.TestCase):
    """End-to-end check that SimTrajectory runs a full simulated episode."""

    def setUp(self):
        # 30-day trading window on a fixed stock; reset loads price data.
        self.days = 30
        self.env = FastTradingEnv(name='000333.SZ', days=self.days)
        self.env.reset()

    def test_run_trajectory(self):
        """Run one trajectory and expect one history entry per trading day."""
        self.assertTrue(self.env)
        resnet_model = ResnetTradingModel(
            name='test_resnet_model',
            model_dir='./test_models',
            input_shape=(self.days, 5),  # open, high, low, close, volume
        )
        self.assertTrue(resnet_model)
        model_policy = ModelTradingPolicy(
            action_options=self.env.action_options(),
            model=resnet_model,
            debug=False)
        self.assertTrue(model_policy)

        # start trajectory
        t = SimTrajectory(
            env=self.env,
            model_policy=model_policy,
            debug=False,
        )
        t.sim_run(rounds_per_step=23)
        # print() call form works on both Python 2 and Python 3
        print(t.history)
        self.assertEqual(len(t.history), self.days)
예제 #2
0
파일: tests.py 프로젝트: manongbang/snake
class ModelPolicyTestCase(unittest.TestCase):
    """Drive an MCTS batch with a model-backed policy and inspect its q_table."""

    def setUp(self):
        # 30-day trading window on a fixed stock
        self.days = 30
        self.env = FastTradingEnv(name='000333.SZ', days=self.days)

    def test_usage_sample(self):
        """Typical usage flow: model -> policy -> MCTS batch -> q_table."""
        self.assertTrue(self.env)
        resnet_model = ResnetTradingModel(
            name='test_resnet_model',
            model_dir='test_models',
            input_shape=(self.days, 5),  # open, high, low, close, volume
        )
        self.assertTrue(resnet_model)
        exploit_policy = ModelTradingPolicy(
            action_options=self.env.action_options(),
            model=resnet_model,
            debug=False)
        self.assertTrue(exploit_policy)
        # init env and save snapshot
        self.env.reset()
        snapshot_v0 = self.env.snapshot()
        # init mcts block
        mcts_block = MCTSBuilder(self.env, debug=True)
        mcts_block.clean_up()
        # run batch and get q_table of next step
        root_node = mcts_block.run_batch(policy=exploit_policy,
                                         env_snapshot=snapshot_v0,
                                         batch_size=50)
        self.assertTrue(root_node)
        root_node.show_graph(name='model_policy_tree')
        q_table = root_node.q_table
        self.assertTrue(q_table)
        # print() call form works on both Python 2 and Python 3
        print(q_table)
        print(root_node.show_final_state())
예제 #3
0
파일: tests.py 프로젝트: manongbang/snake
 def setUp(self):
     """Build a 10-day env, a random policy, and a node class bound to both."""
     self.stock_name = '000333.SZ'
     self.days = 10
     self.env = FastTradingEnv(name=self.stock_name, days=self.days)
     self.policy = RandomTradingPolicy(
         action_options=self.env.action_options())
     # dynamically named node class wired to this env and a fresh graph
     node_klass_name = 'Env_{name}_TradingNode'.format(name=self.stock_name)
     self.TradingEnvNode = klass_factory(
         node_klass_name,
         init_args={'env': self.env, 'graph': Tree()},
         base_klass=TradingNode)
예제 #4
0
def sim_run_func(params):
    """Run one simulated trajectory in a worker process and return its history.

    Expected ``params`` keys:
        stock_name (str): stock symbol used to build the env.
        input_shape (tuple): (days, features); days sizes the env window.
        rounds_per_step (int): simulation rounds per trajectory step.
        model_name (str): name of the resnet model to load.
        model_dir (str): directory the model is loaded from.
        sim_explore_rate (float): exploration rate for the sim trajectory.
        specific_model_name (str, optional): exact model snapshot to load.
        debug (bool, optional): verbose policy/trajectory logging.

    Returns:
        list: the simulated trajectory's collected history.
    """
    # local imports so the function is self-contained in a worker process
    from policy.resnet_trading_model import ResnetTradingModel
    from policy.model_policy import ModelTradingPolicy
    from trajectory.sim_trajectory import SimTrajectory

    # get input parameters
    stock_name = params['stock_name']
    input_shape = params['input_shape']
    rounds_per_step = params['rounds_per_step']
    model_name = params['model_name']
    model_dir = params['model_dir']
    sim_explore_rate = params['sim_explore_rate']
    specific_model_name = params.get('specific_model_name')
    debug = params.get('debug', False)
    # create env
    _env = FastTradingEnv(name=stock_name,
                          days=input_shape[0],
                          use_adjust_close=False)
    _env.reset()
    logger.debug('created env[{name}:{shape}]'.format(name=stock_name,
                                                      shape=input_shape))
    # load model
    _model = ResnetTradingModel(
        name=model_name,
        model_dir=model_dir,
        load_model=True,
        input_shape=input_shape,
        specific_model_name=specific_model_name,
    )
    logger.debug('loaded model[{d}/{name}]'.format(d=model_dir,
                                                   name=model_name))
    _policy = ModelTradingPolicy(action_options=_env.action_options(),
                                 model=_model,
                                 debug=debug)
    logger.debug('built policy with model[{name}]'.format(name=model_name))
    # start sim trajectory
    _sim = SimTrajectory(env=_env,
                         model_policy=_policy,
                         explore_rate=sim_explore_rate,
                         debug=debug)
    logger.debug('start simulate trajectory, rounds_per_step({r})'.format(
        r=rounds_per_step))
    _sim.sim_run(rounds_per_step=rounds_per_step)
    # fix: log message previously misspelled 'simluate'
    logger.debug('finished simulate trajectory, history size({s})'.format(
        s=len(_sim.history)))
    # collect data
    return _sim.history
예제 #5
0
파일: tests.py 프로젝트: manongbang/snake
class TradingNodeTestCase(unittest.TestCase):
    """Episode bookkeeping of TradingNode trees produced by klass_factory."""

    def setUp(self):
        self.stock_name = '000333.SZ'
        self.days = 10
        self.env = FastTradingEnv(name=self.stock_name, days=self.days)
        action_options = self.env.action_options()
        self.policy = RandomTradingPolicy(action_options=action_options)
        # node class bound to this env and a fresh graph
        self.TradingEnvNode = klass_factory(
            'Env_{name}_TradingNode'.format(name=self.stock_name),
            init_args={
                'env': self.env,
                'graph': Tree(),
            },
            base_klass=TradingNode)

    def run_one_episode(self, root_node, debug=False):
        """Step the policy from root_node until the episode terminates."""
        self.assertTrue(self.env and self.policy and root_node)
        self.env.reset()
        current_node = root_node
        while current_node:
            if debug:
                # print() call form works on both Python 2 and Python 3
                print(current_node._state)
            current_node = current_node.step(self.policy)
        return root_node

    def test_basic(self):
        """One episode: root is returned unchanged, episode count is 1."""
        self.assertTrue(self.env.name)
        self.assertTrue(self.TradingEnvNode)
        start_node = self.TradingEnvNode(state=None)
        root_node = self.run_one_episode(start_node, debug=True)
        self.assertTrue(root_node)
        self.assertTrue(start_node)
        self.assertEqual(root_node, start_node)
        self.assertEqual(root_node.get_episode_count(), 1)
        root_node.show_graph(name='basic')

    def test_multiple_episode(self):
        """Episode count accumulates across repeated runs on the same root."""
        self.assertTrue(self.TradingEnvNode)
        count = 100
        root_node = self.TradingEnvNode(state=None)
        for i in range(count):
            root_node = self.run_one_episode(root_node)
        self.assertTrue(root_node)
        self.assertEqual(root_node.get_episode_count(), count)
        # TODO: test edges
        root_node.show_graph(name='multi_episode')
예제 #6
0
파일: tests.py 프로젝트: manongbang/snake
class RandomTradingPolicyTestCase(unittest.TestCase):
    """RandomTradingPolicy must always pick one of the env's action options."""

    def setUp(self):
        # 100-day window on a fixed stock
        self.env = FastTradingEnv(name='000333.SZ', days=100)

    def test_trading_policy(self):
        options = self.env.action_options()
        policy = RandomTradingPolicy(action_options=options)
        self.assertTrue(policy)
        state = 'test state'
        picked = policy.get_action(state)
        self.assertIn(picked, options)
예제 #7
0
    def __init__(self, env, model_policy, explore_rate=1e-01, debug=False):
        """Bind the main env and policies; prepare per-step scratch state."""
        assert (env and model_policy)
        self._main_env = env
        self._exploit_policy = model_policy
        self._explore_rate = explore_rate
        self._debug = debug
        # simulation policy shares the main env's action space
        self._sim_policy = SimPolicy(
            action_options=self._main_env.action_options())

        # reset on every step of the trajectory: rolling history plus a
        # throwaway env copy mirroring the main env's stock and window
        self._sim_history = []
        self._tmp_env = FastTradingEnv(name=self._main_env.name,
                                       days=self._main_env.days,
                                       use_adjust_close=False)
예제 #8
0
 def evaluate(self, basic_model, evaluate_model, valid_stocks, rounds):
     """Compare two saved models over `rounds` randomly chosen stock episodes.

     Both models are loaded by name from self._model_dir with the shared
     self._input_shape.  Each round picks a random stock, runs both models
     from the same env snapshot, and accumulates each run's final
     'real_reward'.  Returns (basic_avg_reward, evaluate_avg_reward).

     NOTE(review): rounds == 0 would divide by zero at the return; callers
     presumably pass rounds >= 1 — confirm.
     """
     # load both models
     from policy.resnet_trading_model import ResnetTradingModel
     bm = ResnetTradingModel(
         name=basic_model,
         model_dir=self._model_dir,
         input_shape=self._input_shape,
         load_model=True,
         specific_model_name=basic_model
     )
     em = ResnetTradingModel(
         name=evaluate_model,
         model_dir=self._model_dir,
         input_shape=self._input_shape,
         load_model=True,
         specific_model_name=evaluate_model
     )
     # start evaluation
     _count = 0
     basic_avg_reward, evaluate_avg_reward = 0.0, 0.0
     while _count < rounds:
         stock_name = np.random.choice(valid_stocks, 1)[0]
         try:
             env = FastTradingEnv(
                 name=stock_name, days=self._input_shape[0], use_adjust_close=False
             )
         except Exception as e:
             # skip stocks whose data fails to load; failed picks do not
             # consume a round (could loop forever if every stock fails)
             logger.exception('env init error, {e}'.format(e=e))
             continue
         # snapshot BEFORE the first run so the second model replays the
         # identical market window
         env_snapshot = env.snapshot()
         basic_evals = self.evaluate_on_env(bm, env)
         basic_avg_reward += basic_evals[-1]['real_reward']
         env.recover(env_snapshot)
         evaluate_evals = self.evaluate_on_env(em, env)
         evaluate_avg_reward += evaluate_evals[-1]['real_reward']
         _count += 1
     return basic_avg_reward / rounds, evaluate_avg_reward / rounds
예제 #9
0
 def setUp(self):
     """Create a fresh 200-day trading env before each test."""
     window = 200
     self.days = window
     self.env = FastTradingEnv(name='000333.SZ', days=window)
예제 #10
0
class FastTradingEnvTestCase(unittest.TestCase):
    """Behavioral tests for FastTradingEnv: stepping, snapshots, rewards."""

    def setUp(self):
        # 200-day trading window on a fixed stock
        self.days = 200
        self.env = FastTradingEnv(name='000333.SZ', days=self.days)

    def test_normal_run(self):
        """Random actions always terminate after exactly `days` steps."""
        self.assertTrue(self.env)
        episodes = 10
        for _ in range(episodes):
            self.env.reset()
            done = False
            count = 0
            while not done:
                action = np.random.choice(self.env.action_options(),
                                          1)  # random
                observation, reward, done, info = self.env.step(action)
                self.assertTrue(reward >= -1.0)
                count += 1
        self.assertEqual(count, self.days)

    def test_performance(self):
        """Time a full random episode; prints the average wall time."""
        setup = """import numpy; from envs.fast_trading_env import FastTradingEnv; env = FastTradingEnv(name='000333.SZ', days=200)"""
        code = """env.reset()
done = False
while not done:
    action = numpy.random.choice(env.action_options(), 1)
    obs, reward, done, info = env.step(action)
"""
        count = 100
        total = timeit.timeit(code, setup=setup, number=count)
        # print() call form works on both Python 2 and Python 3
        print('avg: {t} seconds'.format(t=total / count))

    def test_snapshot_recover(self):
        """Recovering a mid-episode snapshot replays identical rewards."""
        self.assertTrue(self.env)
        snapshot = None

        self.env.reset()
        done = False
        count = 0
        sum_reward = 0.0
        snapshot_sum_reward = 0.0
        while not done:
            action = 1  # fix to long
            observation, reward, done, info = self.env.step(action)
            count += 1
            sum_reward += reward
            # integer division keeps Python 2/3 behavior identical
            if count == self.days // 2:
                # do snapshot
                snapshot = self.env.snapshot()
                snapshot_sum_reward = sum_reward
        self.assertTrue(snapshot)
        self.assertEqual(count, self.days)

        # recover to snapshot
        self.env.recover(snapshot)
        done = False
        count = 0
        recover_sum_reward = snapshot_sum_reward
        while not done:
            action = 1  # fix to long
            observation, reward, done, info = self.env.step(action)
            count += 1
            recover_sum_reward += reward
        self.assertEqual(count, self.days // 2)
        self.assertEqual(sum_reward, recover_sum_reward)

    def test_buy_hold_to_end(self):
        """Holding long pays zero reward until the terminal step."""
        self.env.reset()
        done = False
        count = 0
        while not done:
            action = 1  # fix to long
            observation, reward, done, info = self.env.step(action)
            count += 1
            if not done:
                self.assertAlmostEqual(reward, 0.0)
            else:
                self.assertGreater(reward, 0.0)
예제 #11
0
파일: tests.py 프로젝트: manongbang/snake
 def setUp(self):
     """Create a 30-day trading env and reset it so price data is loaded."""
     window = 30
     self.days = window
     self.env = FastTradingEnv(name='000333.SZ', days=window)
     self.env.reset()
예제 #12
0
파일: tests.py 프로젝트: manongbang/snake
 def setUp(self):
     """Bind a fixed stock symbol and a 30-day trading env."""
     symbol, window = '000333.SZ', 30
     self.stock_name = symbol
     self.days = window
     self.env = FastTradingEnv(name=symbol, days=window)
예제 #13
0
파일: tests.py 프로젝트: manongbang/snake
class MCTSBuilderTestCase(unittest.TestCase):
    """MCTSBuilder batch runs, snapshot restarts, and profiling."""

    def setUp(self):
        self.stock_name = '000333.SZ'
        self.days = 30
        self.env = FastTradingEnv(name=self.stock_name, days=self.days)

    def test_mcts_batch_debug(self):
        """A batch of random rollouts yields a non-negative q_table."""
        policy = RandomTradingPolicy(action_options=self.env.action_options())
        self.env.reset()
        block = MCTSBuilder(self.env, debug=True)

        snapshot_v0 = self.env.snapshot()
        block.clean_up()
        root_node = block.run_batch(policy,
                                    env_snapshot=snapshot_v0,
                                    batch_size=10)
        self.assertTrue(root_node)
        root_node.show_graph(name='mcts_batch')
        self.assertTrue(root_node.q_table)
        for action, t_reward in enumerate(root_node.q_table):
            self.assertGreaterEqual(t_reward, 0.0)
        # print() call form works on both Python 2 and Python 3
        print(root_node.q_table)

    def test_mcts_start_from_snapshot(self):
        """A run can start from any env snapshot, not just the initial one."""
        # buy and hold policy
        hold_policy = HoldTradingPolicy(
            action_options=self.env.action_options(), action_idx=1)
        self.env.reset()
        block = MCTSBuilder(self.env, debug=True)

        # first run
        snapshot_v0 = self.env.snapshot()
        block.clean_up()
        root_node = block.run_once(hold_policy, env_snapshot=snapshot_v0)
        self.assertTrue(root_node)
        root_node.show_graph(name='mcts_batch_from_snapshot')

        # generate another snapshot halfway through the episode
        # (integer division keeps Python 2/3 behavior identical)
        mid = self.days // 2
        self.env.recover(snapshot_v0)
        fast_moving(self.env, hold_policy, steps=mid)
        snapshot_v1 = self.env.snapshot()

        # second run from snapshot
        block.clean_up()
        self.env.reset()  # call reset to randomly initialize env
        # recover env to snapshot_mid
        root_node = block.run_once(hold_policy, env_snapshot=snapshot_v1)
        self.assertTrue(root_node)
        root_node.show_graph()

    def test_profile(self):
        """Profile a 100-rollout batch run with cProfile."""
        # buy and hold policy
        policy = HoldTradingPolicy(action_options=self.env.action_options(),
                                   action_idx=1)
        self.env.reset()

        with Profiling(cProfile.Profile()):
            snapshot_v0 = self.env.snapshot()
            block = MCTSBuilder(self.env, debug=False)
            block.clean_up()
            block.run_batch(policy, env_snapshot=snapshot_v0, batch_size=100)