예제 #1
0
    def __init__(self, param):
        """Assemble the state machine from a configuration mapping.

        Args:
            param: mapping read with ``param[...]`` for the keys
                ``'memento'``, ``'spider'``, ``'assembler'``,
                ``'feed_builder'`` and ``'transition_table'``.
        """
        self.logger = logging.getLogger(__name__)

        # Memento for the state machine; loaded immediately after creation.
        self.memento = Memento(param['memento'])
        self.memento.load()

        # Collaborators handed in by the caller, plus our own database handle.
        self.spider, self.assembler, self.feed_builder = (
            param['spider'],
            param['assembler'],
            param['feed_builder'],
        )
        self.database = Database()

        # Each state is constructed with this machine as its argument, in the
        # order listed, before the transition table / current state are set.
        state_classes = (
            ('InitialState', InitialState),
            ('LoadState', LoadState),
            ('SpiderState', SpiderState),
            ('AssemblerState', AssemblerState),
            ('DatabaseState', DatabaseState),
            ('FinalState', FinalState),
        )
        self.state_map = {
            state_name: state_cls(self)
            for state_name, state_cls in state_classes
        }

        self.transition_table = param['transition_table']

        # The machine always starts from the initial state.
        self.curr_state = 'InitialState'
예제 #2
0
class DatabaseTest(unittest.TestCase):
    """Checks the stock-symbol listing returned by ``Database``."""

    def setUp(self):
        # Every test method works against a freshly constructed Database.
        self.database = Database()

    def tearDown(self):
        # Release this test's reference to the Database instance.
        self.database = None

    def test_get_stock_symbol_list(self):
        # The first two entries must be these (symbol, date) pairs, in order.
        symbols = self.database.get_stock_symbol_list()
        expected_head = (
            ('1101', datetime.date(1962, 2, 9)),
            ('1102', datetime.date(1962, 6, 8)),
        )
        for position, expected in enumerate(expected_head):
            self.assertEqual(symbols[position], expected)
예제 #3
0
class DatabaseTest(unittest.TestCase):
    """Verifies ``Database.get_stock_symbol_list`` against known entries."""

    def setUp(self):
        # One Database instance per test method.
        db = Database()
        self.database = db

    def tearDown(self):
        # Drop the per-test Database reference.
        self.database = None

    def test_get_stock_symbol_list(self):
        # Spot-check the first two rows: stock symbol plus a date.
        date = datetime.date
        listing = self.database.get_stock_symbol_list()
        self.assertEqual(listing[0], ('1101', date(1962, 2, 9)))
        self.assertEqual(listing[1], ('1102', date(1962, 6, 8)))
예제 #4
0
class StockSymbolAnalyzer():
    """Convert ``StockSymbolList`` rows into dicts carrying a market suffix."""

    # Market-category label -> ticker suffix.
    #
    # The first two keys are the UTF-8 *bytes* of the labels U+4E0A U+5E02
    # and U+4E0A U+6AC3 spelled as escape sequences. They match byte strings
    # (Python 2 / undecoded driver output). On Python 3, however, they are
    # latin-1 mojibake and never equal the real decoded text, so the decoded
    # labels are listed as well. The original keys are kept for backward
    # compatibility.
    MARKET_CATEGORY = {
        '\xe4\xb8\x8a\xe5\xb8\x82': 'TW',
        '\xe4\xb8\x8a\xe6\xab\x83': 'TWO',
        u'\u4e0a\u5e02': 'TW',
        u'\u4e0a\u6ac3': 'TWO',
    }

    def __init__(self):
        self.database = Database()
        # Per-instance copy, so mutating one analyzer's mapping doesn't leak
        # into others.
        self.market_category = dict(self.MARKET_CATEGORY)

    def get_stock_symbol_list(self):
        """Return one dict per row of ``StockSymbolList``.

        Each row is read positionally: ``record[0]`` is the stock symbol,
        ``record[1]`` the listing date, and ``record[2]`` the market label.
        The market label is translated through ``self.market_category``.

        Returns:
            list of dicts with keys ``'stock_symbol'``, ``'listing_date'``
            and ``'market_category'``.

        Raises:
            KeyError: if a row's market label is not in ``market_category``.
        """
        records = self.database.get('StockSymbolList', None)
        return [
            {
                'stock_symbol': record[0],
                'listing_date': record[1],
                'market_category': self.market_category[record[2]],
            }
            for record in records
        ]
예제 #5
0
 def __init__(self):
     """Create the database handle and the market-category lookup table.

     ``market_category`` maps a market label to a ticker suffix ('TW' or
     'TWO'); presumably it is indexed with a row's market column by the
     caller -- confirm against usage.
     """
     self.database = Database()
     self.market_category = {
         # UTF-8 bytes of the labels U+4E0A U+5E02 / U+4E0A U+6AC3 spelled
         # as escapes. They match byte strings (Python 2 / undecoded driver
         # output); on Python 3 they are latin-1 mojibake. Kept for backward
         # compatibility.
         '\xe4\xb8\x8a\xe5\xb8\x82': 'TW',
         '\xe4\xb8\x8a\xe6\xab\x83': 'TWO',
         # The same labels as decoded text, so Unicode-returning sources
         # also resolve instead of raising KeyError.
         u'\u4e0a\u5e02': 'TW',
         u'\u4e0a\u6ac3': 'TWO',
     }
예제 #6
0
 def __init_value(self, param):
     """Build a TimeSeries for every account named in ``param``.

     Args:
         param: mapping whose ``'account_list'`` entry is iterated for
             account keys; the whole mapping is also forwarded to
             ``Database().get`` for each account.

     Returns:
         dict of account -> ``TimeSeries.create(<data fetched for it>)``.
     """
     db = Database()
     return {
         account: TimeSeries.create(db.get(account, param))
         for account in param['account_list']
     }
예제 #7
0
 def setUp(self):
     # Each test gets its own freshly constructed Database.
     fresh_db = Database()
     self.database = fresh_db
예제 #8
0
 def __init_value(self, param):
     # For each account in param["account_list"], fetch its data from a
     # new Database (forwarding the full param mapping) and wrap it with
     # TimeSeries.create; return the results keyed by account.
     source = Database()
     accounts = param["account_list"]
     series_by_account = {}
     for name in accounts:
         raw = source.get(name, param)
         series_by_account[name] = TimeSeries.create(raw)
     return series_by_account
예제 #9
0
 def setUp(self):
     # Attach a new Database instance before every test method runs.
     handle = Database()
     self.database = handle