コード例 #1
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Forward an 'add_stock' vault request for the stock described by ``data``.

     Reads 'acronym', 'price' (converted with float) and 'shares'
     (converted with int) from ``data``.
     """
     stock = ManagedStock(data['acronym'],
                          last_price=float(data['price']),
                          shares=int(data['shares']))
     request = Message('vault_request', 'add_stock', stock)
     handler.send(request)
コード例 #2
0
    def setUp(self) -> None:
        """Start a TestQSM and a JSONController, then open a client handler."""
        qsm = TestQSM('test')
        qsm.start()
        self.uut = qsm

        controller = JSONController('test_json', output_directory, 1)
        controller.start()
        self.jcontr = controller

        self.client = MessageHandler('client')
コード例 #3
0
    def __init__(self, name: str, msg_list: list = None):
        """Create a queued state machine named ``name``.

        :param name: name of the machine; also passed to its MessageHandler.
        :param msg_list: message titles to subscribe to. A copy is taken, so the
            caller's list is never modified; 'all' is always added to the copy.
        """
        super().__init__()
        self.is_exitting = False
        self.name = name

        self.mappings = {}
        self.setup_states()

        # Copy so that appending 'all' below does not mutate the caller's list
        # (which may be shared between machines or come from a saved dict).
        self.msg_list = list(msg_list) if msg_list is not None else []
        if 'all' not in self.msg_list:
            self.msg_list.append('all')

        self.setup_msg_mappings(self.msg_list)

        self.handler = MessageHandler(self.name, self.msg_list)  # Because otherwise we can't join from 'final'
        self.states = Queue()
        self.append_state('init')
コード例 #4
0
class SQLTestCase(unittest.TestCase):
    """Integration tests for StockVault, driven through 'vault_request' messages.

    Every test gets a vault rooted at TEST_DB_DIR; tearDown deletes the directory.
    """

    # Scratch directory for the test database.
    TEST_DB_DIR = '/tmp/robin_test_db'
    # Seconds to give the vault to process an asynchronously sent request.
    SETTLE_TIME = 0.5

    def setUp(self) -> None:
        StockVault.setup_instance(db_directory=self.TEST_DB_DIR)
        self.client = MessageHandler('client_handler')

    def tearDown(self) -> None:
        import shutil

        StockVault.sjoin()
        self.client.join()
        # rmtree removes the directory together with everything inside it,
        # including sub-directories that a per-entry os.remove() cannot delete.
        shutil.rmtree(self.TEST_DB_DIR)

    def test_file_creation(self):
        sql.setup_db(self.TEST_DB_DIR)
        self.assertTrue(path.exists(path.join(self.TEST_DB_DIR, db_name)),
                        'Can create test files without error')

    def test_add_stock(self):
        self.client.send(Message('vault_request', 'add_monitor',
                                 Stock('AAPL')))
        time.sleep(self.SETTLE_TIME)
        names = StockVault.get_stock_names()
        self.assertIn('AAPL', names,
                      'An inserted stock should appear in the database')

    def test_remove_stock(self):
        self.client.send(Message('vault_request', 'add_monitor',
                                 Stock('AAPL')))
        self.client.send(
            Message('vault_request', 'remove_monitor', Stock('AAPL')))
        time.sleep(self.SETTLE_TIME)
        names = StockVault.get_stock_names()
        self.assertNotIn('AAPL', names,
                         'A removed stock should not appear in the database')

    def test_get_info(self):
        self.client.send(Message('vault_request', 'add_monitor',
                                 Stock('AAPL')))
        # Same settle time as the other tests; the former 0.1 s was shorter than
        # everywhere else and likely let the lookup race the insert.
        time.sleep(self.SETTLE_TIME)
        info = StockVault.get_info('AAPL')
        self.assertIsNotNone(
            info,
            'Inserted database element should be able to be found by info grab')
        self.assertEqual(
            info.acronym, 'AAPL',
            'Inserted database element should be able to be found by info grab'
        )
コード例 #5
0
 def setUp(self) -> None:
     """Open a handler subscribed to the titles the CLI commands publish on."""
     topics = ['trade_request', 'monitor_config', 'vault_request']
     self.handler = MessageHandler('test_handler', topics)
コード例 #6
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Send an 'add_balance_update' vault request built from ``data``.

     ``data['amount']`` seeds the BalanceUpdate and ``data['date']`` is parsed
     with ``date_format``.
     """
     update = BalanceUpdate(data['amount'])
     stamp = dt.datetime.strptime(data['date'], date_format)
     update.date = stamp
     handler.send(Message('vault_request', 'add_balance_update', update))
コード例 #7
0
 def setUp(self) -> None:
     """Create a vault under /tmp/robin_test_db and a client message handler."""
     db_dir = '/tmp/robin_test_db'
     StockVault.setup_instance(db_directory=db_dir)
     self.client = MessageHandler('client_handler')
コード例 #8
0
        time.sleep(self.interval)


class MarketTimer(Timer):
    """A Normal Timer class that automatically pauses during after hours.

    The timer runs Monday-Friday from 09:30 up to (not including) 16:00 and is
    paused outside that window. Times come from ``dt.datetime.now()``, i.e. the
    host's naive local clock -- NOTE(review): this presumably assumes the host
    runs on US/Eastern time; confirm.
    """

    @staticmethod
    def _market_open(now: dt.datetime) -> bool:
        """Return True if ``now`` falls inside the trading window above."""
        if now.isoweekday() in (6, 7):  # Saturday, Sunday
            return False
        if now.hour == 9:
            return now.minute >= 30
        return 9 < now.hour < 16

    def idle_state(self):
        # Take a single snapshot of the clock: reading now() separately for the
        # weekday, hour and minute can straddle a boundary (e.g. 09:59:59.999
        # yields hour 9 then minute 0) and pause/unpause at the wrong moment.
        is_open = self._market_open(dt.datetime.now())
        if not self.paused and not is_open:
            print('Pausing Timer')
            self.paused = True
        elif self.paused and is_open:
            print('Unpausing Timer')
            self.paused = False
        super().idle_state()


if __name__ == "__main__":
    from tradebot.messaging.message import MessageHandler

    # Manual smoke test: start a Timer and report every message received on
    # the 'timer' title.
    demo_timer = Timer('timer1')
    listener = MessageHandler('timer_rx', ['timer'])
    demo_timer.start()

    while True:
        tick = listener.receive()
        if tick is not None:
            print('Timer triggered')
コード例 #9
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Publish ``data['limit']`` as a 'limit' monitor_config message."""
     limit = data['limit']
     handler.send(Message('monitor_config', 'limit', limit))
コード例 #10
0
class StateTestCase(unittest.TestCase):
    """Integration tests for saving and loading QSM state via JSONController."""

    def setUp(self) -> None:
        self.uut = TestQSM('test')
        self.uut.start()

        self.jcontr = JSONController('test_json', output_directory, 1)
        self.jcontr.start()

        self.client = MessageHandler('client')

    def tearDown(self) -> None:
        self.uut.join()
        self.jcontr.join()
        self.client.join()

    def test_save(self):
        self.client.send(Message('json_request', 'save'))
        time.sleep(1)
        save_path = path.join(output_directory, state_save_file)
        self.assertTrue(path.exists(save_path),
                        'Saving generates a state file')
        self.assertGreater(path.getsize(save_path), 0,
                           'Saving should populate file')

    def test_load(self):
        self.client.send(Message('json_request', 'save'))
        time.sleep(1)
        data = JSONController.load(output_directory)
        self.assertEqual(len(data), 2,
                         "Loading should've created two controllers")
        self.assertTrue(any(d.name == 'test' for d in data),
                        "TestQSM should've been loaded by the controller")

    def test_config_filename(self):
        # NOTE(review): /tmp/robin_test_2/ is never cleaned up by this test.
        self.client.send(
            Message('json_config', 'filename', '/tmp/robin_test_2/'))
        self.client.send(Message('json_request', 'save'))
        time.sleep(1)
        self.assertTrue(
            path.exists(path.join('/tmp/robin_test_2/', state_save_file)),
            'Modifying filename, configures to save in new location')

    def test_config_count(self):
        self.client.send(Message('json_config', 'count', 2))
        second_uut = TestQSM('test2')
        second_uut.start()
        try:
            self.client.send(Message('json_request', 'save'))
            time.sleep(1)
            data = JSONController.load(output_directory)
            self.assertEqual(
                len(data), 3,
                "Modifying the count value should've generated three controllers, not {}"
                .format(len(data)))
        finally:
            # tearDown only joins self.uut; without this the extra state
            # machine is never told to exit.
            second_uut.join()
コード例 #11
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Send a monitor 'update' when the title is 'force'; otherwise defer to the parent."""
     if data['title'] != 'force':
         super().handle(handler, data)
         return
     handler.send(Message('monitor_config', 'update'))
コード例 #12
0
class QSM(Process, Dictable):
    """A queued state machine that runs in its own process.

    States are queued as ``{'state': name, 'payload': obj}`` dicts on
    ``self.states`` and dispatched through ``self.mappings``. Whenever the queue
    is empty the 'idle' state runs, which polls ``self.handler`` and queues a
    state named after each received message's title.
    """

    def __init__(self, name: str, msg_list: list = None):
        """
        :param name: name of the machine; also passed to its MessageHandler.
        :param msg_list: message titles to subscribe to. A copy is taken, so the
            caller's list is never modified; 'all' is always added to the copy.
        """
        super().__init__()
        self.is_exitting = False
        self.name = name

        self.mappings = {}
        self.setup_states()

        # Copy so that appending 'all' below does not mutate the caller's list
        # (which may be shared between machines or come from a saved dict).
        self.msg_list = list(msg_list) if msg_list is not None else []
        if 'all' not in self.msg_list:
            self.msg_list.append('all')

        self.setup_msg_mappings(self.msg_list)

        self.handler = MessageHandler(self.name, self.msg_list)  # Because otherwise we can't join from 'final'
        self.states = Queue()
        self.append_state('init')

    @property
    def dict(self) -> dict:
        """Serializable description: name, message list and handler state."""
        print('Converting qsm to dict')
        d = {'name': self.name,
             'msg_list': self.msg_list,
             'handler': self.handler.dict}
        return d

    @staticmethod
    def from_dict(d: dict) -> object:
        """Rebuild a state machine from the output of the ``dict`` property.

        Always returns a plain ``QSM`` (not a subclass), not yet started.
        """
        sm = QSM(d['name'], d['msg_list'])
        # NOTE(review): the handler created by __init__ is replaced without
        # being joined; confirm whether it needs explicit cleanup.
        sm.handler = MessageHandler.from_dict(d['handler'])
        # No append_state('init') here: __init__ already queued 'init', and
        # queueing it a second time made initial_state() run twice.
        return sm

    def join(self, timeout: Optional[float] = -1) -> None:
        """Queue the 'exit' state, join the process, then join the handler.

        NOTE(review): with the default timeout of -1, multiprocessing's join
        appears to poll once and return instead of waiting for the process;
        pass ``timeout=None`` to block until it has exited.
        """
        self.append_state('exit')
        super().join(timeout=timeout)
        self.handler.join()

    def run(self) -> None:
        """Dispatch queued states until the 'final' state sets is_exitting."""
        while not self.is_exitting:
            try:
                s = self.states.get_nowait()
            except Empty:
                self.mappings['idle']()
                continue
            # Dispatch outside the try so an Empty raised *inside* a state
            # method is not mistaken for an empty state queue.
            if s['payload'] is not None:
                self.mappings[s['state']](s['payload'])
            else:
                self.mappings[s['state']]()

    def setup_msg_mappings(self, msg_list: list):
        """Map each title in ``msg_list`` to this object's ``<title>_msg`` method in self.mappings.

        :raises AttributeError: if a title has no matching ``<title>_msg`` method.
        """
        for m in msg_list:
            # getattr instead of eval: titles can come from a saved state file
            # (see from_dict), and eval would execute arbitrary expressions.
            self.mappings[m] = getattr(self, m + '_msg')

    def all_msg(self, msg: Message):
        """Handle the 'all' title: on 'save', publish this machine's state.

        Sends a 'json_update' message, titled with this machine's name, carrying
        its class name, module name, source path and ``dict``.
        """
        if msg.msg == 'save':
            self.handler.send(Message('json_update', self.name, {'dtype': str(type(self).__name__),
                                                                 'package': inspect.getmodulename(
                                                                     inspect.getfile(self.__class__)),
                                                                 'path': str(inspect.getfile(self.__class__)),
                                                                 'data': self.dict}))

    def setup_states(self):
        """Sets up the self.mappings dictionary, mapping state names to state methods,
        by default 'init', 'idle', 'exit', and 'final' states are set up."""
        self.mappings['init'] = self.initial_state
        self.mappings['idle'] = self.idle_state
        self.mappings['exit'] = self.exit_state
        self.mappings['final'] = self.final_state

    def append_state(self, state: str, payload: object = None):
        """Queue ``state`` with an optional payload; prints and drops it if the queue is full."""
        try:
            self.states.put_nowait({'state': state, 'payload': payload})
        except Full:
            print('{} state queue is full, skipping {}'.format(self.name, state))

    def append_states(self, states: list, payloads: list = None):
        """Queue several states in order; ``payloads``, if given, is indexed in parallel."""
        for i, s in enumerate(states):
            p = payloads[i] if payloads is not None else None
            self.append_state(s, p)

    def initial_state(self):
        """This is the first state to execute, always"""
        pass

    def idle_state(self):
        """This is the state that is triggered when the state machine has nowhere to go to."""
        m = self.handler.receive(block=False)
        if m is not None:
            print('Received message {}'.format(m))
            self.append_state(m.title, m)

    def exit_state(self):
        """This state is triggered when the qsm should exit, enqueues the 'final' state"""
        self.append_state('final')

    def final_state(self):
        """This is the final state to execute, always"""
        self.is_exitting = True
コード例 #13
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Forward ``data`` unchanged as a 'sell' trade request."""
     sell_request = Message('trade_request', 'sell', data)
     handler.send(sell_request)
コード例 #14
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Broadcast a 'load' message on the 'all' title; ``data`` is unused."""
     broadcast = Message('all', 'load')
     handler.send(broadcast)
コード例 #15
0
def __setup_handler(name: str = 'cli'):
    """Create the module-level ``cli_handler`` on first use; later calls do nothing."""
    global cli_handler

    if cli_handler is not None:
        return
    cli_handler = MessageHandler(name)
コード例 #16
0
if __name__ == '__main__':
    # Entry point: read the configured stocks, build the messaging modules and
    # start them. NOTE(review): this listing may be truncated; `rx` is unused
    # and `t` is never started in the visible code.
    print('Welcome to iggy\'s Robinhood Trading bot')
    print('Reading in stock info')
    # Each entry is used below as s[0] (an object with .acronym) and s[1]
    # (its limit) -- shape inferred from usage; TODO confirm against read_stocks.
    stocks = read_stocks('./stocks.txt')

    print('Configuring limit info')
    # Map each stock acronym to its limit entry.
    limit_dict = {}
    for s in stocks:
        print('configuring {}'.format(s[0]))
        limit_dict[s[0].acronym] = s[1]

    print('Creating modules')
    # Presumably the timer ticks on the 'timer' title, and the relay turns each
    # tick into a ('monitor_config', 'update') message.
    # NOTE(review): unless started further down, `t` never runs and the relay
    # never fires.
    t = timer.MarketTimer('update_trigger', interval=10)
    relay = TimerRelay('timer_relay', Message('monitor_config', 'update'))
    rx = MessageHandler('receiver', ['monitor_config'])

    p = PyrhAdapter()
    p.login()

    # The comprehension variable `t` is local to the comprehension in Python 3
    # and does not overwrite the timer above.
    dm = monitor.StockMonitor('monitor', [Stock(t[0].acronym) for t in stocks],
                              limit_dict, p)
    dc = datacontroller.DataController('data_controller')
    # tc = tradecontroller.TradeController('trade_controller', 0)

    print('Starting modules')
    relay.start()
    p.start()
    dm.start()
    dc.start()
    # tc.start()
コード例 #17
0
from tradebot.messaging.qsm import QSM
from tradebot.messaging.message import Message


class TimerRelay(QSM):
    """State machine that sends a fixed message each time a 'timer' message arrives."""

    def __init__(self, name: str, target_msg: Message):
        """
        :param name: state machine name.
        :param target_msg: message sent every time the relay fires.
        """
        super().__init__(name, ['timer'])
        self.target = target_msg

    def timer_msg(self, msg: Message):
        """Handle a 'timer' message (its content is ignored) by sending the target."""
        announcement = 'Relay {name} triggered'.format(name=self.name)
        print(announcement)
        self.handler.send(self.target)


if __name__ == '__main__':
    from tradebot.controllers.timer import Timer
    from tradebot.messaging.message import MessageHandler

    # Manual smoke test: a Timer drives the relay, which forwards a
    # 'something' message to the receiver below.
    ticker = Timer('timer1', 1)
    relay = TimerRelay('relay', Message('something'))
    receiver = MessageHandler('receiver', ['something'])

    relay.start()
    ticker.start()

    while True:
        forwarded = receiver.receive()
        if forwarded is not None:
            print('Relay fired')
コード例 #18
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Send a monitor 'add' request for the stock named by ``data['acronym']``."""
     stock = Stock(data['acronym'])
     handler.send(Message('monitor_config', 'add', stock))
コード例 #19
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Send an 'update_monitor' vault request carrying the ask/bid fields of ``data``."""
     quote = Stock(data['acronym'],
                   data['ask_price'], data['bid_price'],
                   data['ask_size'], data['bid_size'])
     handler.send(Message('vault_request', 'update_monitor', quote))
コード例 #20
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Send an 'add_transaction' vault request built from ``data``.

     ``data['date']`` is parsed with ``date_format``; the second and third
     StockTransaction arguments are fixed to '' and True.
     """
     transaction = StockTransaction(data['id'], '', True,
                                    data['price'], data['shares'])
     transaction.date = dt.datetime.strptime(data['date'], date_format)
     request = Message('vault_request', 'add_transaction', transaction)
     handler.send(request)
コード例 #21
0
 def from_dict(d: dict) -> object:
     """Rebuild a state machine from a dict produced by its ``dict`` property.

     :param d: mapping with 'name', 'msg_list' and 'handler' keys.
     :return: a new, not yet started, ``QSM``.
     """
     sm = QSM(d['name'], d['msg_list'])
     # NOTE(review): the handler created by QSM.__init__ is replaced without
     # being joined; confirm whether it needs explicit cleanup.
     sm.handler = MessageHandler.from_dict(d['handler'])
     # No append_state('init') here: QSM.__init__ already queued 'init', and
     # queueing it a second time made initial_state() run twice.
     return sm
コード例 #22
0
        second_uut.start()
        self.client.send(Message('json_request', 'save'))
        time.sleep(1)
        data = JSONController.load(output_directory)
        self.assertTrue(
            len(data) == 3,
            "Modifying the count value should've generated three controllers, not {}"
            .format(len(data)))


if __name__ == '__main__':
    # Manual smoke test: save state through the JSONController and reload it.
    # unittest.main()
    uut = TestQSM()
    uut.start()

    jcontr = JSONController('test_json', output_directory, 1)
    jcontr.start()

    client = MessageHandler('client')

    try:
        client.send(Message('json_request', 'save'))
        time.sleep(1)
        data = JSONController.load(output_directory)
        assert len(data) == 2

        # NOTE(review): QSM.from_dict returns a plain QSM, so this only holds if
        # TestQSM or JSONController.load restores the concrete type; confirm.
        assert any(isinstance(d, TestQSM) for d in data)
    finally:
        # Stop the state machines; presumably they are non-daemon processes, and
        # multiprocessing would otherwise wait on them forever at exit.
        uut.join()
        jcontr.join()
        client.join()
コード例 #23
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Send a 'remove_stock' vault request for the table id and share count in ``data``."""
     target = ManagedStock('None', data['id'], data['shares'])
     handler.send(Message('vault_request', 'remove_stock', target))
コード例 #24
0
 def handle(self, handler: MessageHandler, data: dict) -> None:
     """Send an 'update_stock' vault request built from ``data``."""
     acronym, table_id = data['acronym'], data['id']
     shares, price = data['shares'], data['price']
     updated = ManagedStock(acronym, table_id, shares, price)
     handler.send(Message('vault_request', 'update_stock', updated))
コード例 #25
0
class CLITest(unittest.TestCase):
    """Checks that parse_command turns CLI text into the expected messages."""

    def setUp(self) -> None:
        # Listen on every title the commands under test publish to.
        self.handler = MessageHandler('test_handler', ['trade_request', 'monitor_config', 'vault_request'])

    def tearDown(self) -> None:
        self.handler.join()
        # Reset so the CLI creates a fresh handler in the next test.
        # NOTE(review): the previous cli.cli_handler is dropped without being
        # joined; confirm whether it needs cleanup.
        cli.cli_handler = None

    def test_add(self):
        parse_command('add stock AAPL shares=10; add stock AAPL 1 3.0; add stock AAPL shares=10 price=4.50')
        msg = self.handler.receive()

        self.assertEqual(msg.title, 'vault_request', 'Add command should produce vault_request message')
        self.assertEqual(msg.msg, 'add_stock', 'Add command should produce an add_stock message')
        self.assertEqual(msg.payload, ManagedStock('AAPL', shares=10, last_price=-1))

        msg = self.handler.receive()
        self.assertEqual(msg.payload, ManagedStock('AAPL', shares=1, last_price=3))

        msg = self.handler.receive()
        self.assertEqual(msg.payload, ManagedStock('AAPL', shares=10, last_price=4.5))

    def test_remove(self):
        parse_command('remove stock 123; remove stock 123 10')

        msg = self.handler.receive()
        self.assertEqual(msg.title, 'vault_request', 'Remove command should produce vault_request message')
        self.assertEqual(msg.msg, 'remove_stock', 'Remove command should produce a remove_stock message')

        self.assertEqual(msg.payload, ManagedStock('None', table_id=123, shares=-1))

        msg = self.handler.receive()
        self.assertEqual(msg.payload, ManagedStock('None', table_id=123, shares=10))

    def test_list(self):
        # Only checks that both forms parse without raising; the message
        # assertions below are currently disabled.
        parse_command('list; list acronym AAPL')

        # msg = self.handler.receive()
        # self.assertEqual(msg.title, 'vault_request', 'List command should produce vault_request message')
        # self.assertEqual(msg.msg, 'get_stock_names', 'List command should produce a get_stock_names message')

    def test_limit(self):
        parse_command('add limit 123; add limit 123 % 1.05 0.95; add limit 123 0.95 0.85')

        msg = self.handler.receive()
        self.assertEqual(msg.title, 'monitor_config', 'Limit command should produce monitor_config message')
        self.assertEqual(msg.msg, 'limit', 'Limit command should produce a limit message')

        self.assertEqual(msg.payload, LimitDescriptor(123))

        msg = self.handler.receive()
        self.assertEqual(msg.payload, LimitDescriptor(123))

        msg = self.handler.receive()
        self.assertEqual(msg.payload, LimitDescriptor(123, upper=0.95, lower=0.85))

    def test_buy(self):
        parse_command('buy AAPL; buy AAPL 10; buy AAPL 10 3.75')

        msg = self.handler.receive()
        self.assertEqual(msg.title, 'trade_request', 'Buy command should produce a trade_request message')
        self.assertEqual(msg.msg, 'buy', 'Buy command should produce a buy message')

        self.assertEqual(msg.payload, ManagedStock('AAPL', last_price=-1))

        msg = self.handler.receive()
        self.assertEqual(msg.payload, ManagedStock('AAPL', shares=10, last_price=-1))

        msg = self.handler.receive()
        self.assertEqual(msg.payload, ManagedStock('AAPL', shares=10, last_price=3.75))

    def test_sell(self):
        parse_command('sell 123; sell 123 10')

        msg = self.handler.receive()
        self.assertEqual(msg.title, 'trade_request', 'Sell command should produce a trade_request message')
        self.assertEqual(msg.msg, 'sell', 'Sell command should produce a sell message')

        self.assertEqual(msg.payload['id'], 123)
        self.assertEqual(msg.payload['shares'], -1)

        msg = self.handler.receive()
        self.assertEqual(msg.payload['id'], 123)
        self.assertEqual(msg.payload['shares'], 10)

    def test_update(self):
        parse_command('update')

        msg = self.handler.receive()
        self.assertEqual(msg.title, 'monitor_config', 'Update command should produce a monitor_config message')
        self.assertEqual(msg.msg, 'update', 'Update command should produce an update message')

    def test_export(self):
        parse_command('export')

        msg = self.handler.receive()
        self.assertEqual(msg.title, 'all', 'Export command should produce a global message')
        self.assertEqual(msg.msg, 'save', 'Export command should produce a save message')

    def test_import(self):
        parse_command('import')

        msg = self.handler.receive()
        self.assertEqual(msg.title, 'all', 'Import command should produce a global message')
        self.assertEqual(msg.msg, 'load', 'Import command should produce a load message')