class MpkDbTests(TestCase):
    """Unit tests for ``MpkDb`` with ``sqlite3`` and ``os.environ`` mocked out."""

    def setUp(self):
        super().setUp()
        self.environ_dict = {'TRAM_ROOT': 'mock_value'}
        # Each patcher's stop() is registered with addCleanup so the patches
        # are undone after every test (even if setUp fails part-way).
        # Without this they stay active and leak into all later tests.
        environ_patcher = patch.dict('db.mpk.os.environ', self.environ_dict)
        self.os_environ_mock = environ_patcher.start()
        self.addCleanup(environ_patcher.stop)
        sqlite_patcher = patch('db.mpk.sqlite3')
        self.sqlite_mock = sqlite_patcher.start()
        self.addCleanup(sqlite_patcher.stop)
        # connect() -> connection, connection.cursor() -> cursor, so tests can
        # inspect the exact cursor MpkDb ends up using.
        self.mock_db_connection = MagicMock()
        self.mock_db_cursor = MagicMock()
        self.mock_db_connection.cursor.return_value = self.mock_db_cursor
        self.sqlite_mock.connect.return_value = self.mock_db_connection
        self.mpk = MpkDb()

    def tearDown(self):
        # The patches themselves are undone by the cleanups registered in setUp.
        super().tearDown()

    def test_constructor_default(self):
        """MpkDb opens $TRAM_ROOT/data/baza.ready.zip and keeps its cursor."""
        mock_db_file = self.environ_dict['TRAM_ROOT'] + '/data/baza.ready.zip'
        self.assertEqual(self.mpk.db_file, mock_db_file)
        self.sqlite_mock.connect.assert_called_once_with(mock_db_file)
        self.mock_db_connection.cursor.assert_called_once_with()
        self.assertEqual(self.mpk.cursor, self.mock_db_cursor)

    def test_get_lines(self):
        """get_lines flattens single-column rows into a list with one query."""
        self.mock_db_cursor.fetchall.return_value = [(18, ), (19, ), (20, )]
        res = self.mpk.get_lines()
        self.assertEqual(res, [18, 19, 20])
        self.assertEqual(self.mock_db_cursor.execute.call_count, 1)

    def test_get_line_points(self):
        """get_line_points groups (variant, point) rows by variant, in order."""
        mock_line = 18
        mock_ret_dict = {963234: [11948, 9801, 9799], 963235: [7888, 9565]}
        self.mock_db_cursor.fetchall.return_value = [(963234, 11948),
                                                     (963234, 9801),
                                                     (963234, 9799),
                                                     (963235, 7888),
                                                     (963235, 9565)]
        ret = self.mpk.get_line_points(mock_line)
        self.assertEqual(ret, mock_ret_dict)
        self.assertEqual(self.mock_db_cursor.execute.call_count, 1)
Example #2
0
class TimetableWorker(YieldPeriodicCallback):
    """Periodic worker that refreshes the MPK database and the przystanki cache.

    On every tick, :meth:`run` does three things:

    * asks the ``get_db_link`` endpoint where the current database lives;
    * downloads that database when a refresh is due;
    * re-fetches point data for every configured line into ``PrzystankiDb``.
    """

    # Upper bound on retained status entries. The worker runs indefinitely and
    # records several entries per run, so an unbounded list would keep growing.
    max_status_entries = 1000

    def __init__(self):
        self.number = 1  # ordinal of the next run; used for logging/status only
        self.last_db_update = None  # time of the last successful db download
        # Directory (with trailing slash) where the downloaded db is stored.
        self.db_file = os.environ['TRAM_ROOT'] + '/data/'
        self.config = Config()
        # Bypasses the 23h gate in get_new_db once; presumably set to True
        # from outside this class (nothing here sets it) — confirm with callers.
        self.force_update = False
        # Newest-first list of (message, timestamp-string) tuples.
        self.status = [('not running', str(datetime.datetime.now()))]

        self.db = MpkDb()
        self.przystanki_db = PrzystankiDb()
        self.mpk_link = self.config['get_db_link']
        self.mpk_point_data = self.config['get_point_data_link']
        self.headers = self.config['mpk_headers']
        self.httpclient = AsyncHTTPClient()
        YieldPeriodicCallback.__init__(self,
                                       self.run,
                                       # presumably minutes -> milliseconds
                                       self.config['ttworker_refresh_period'] *
                                       60000,
                                       faststart=True)
        self.update_status('TTworker initialised')

    def update_status(self, message):
        """Record ``message`` with the current time as the newest status entry.

        Only the newest ``max_status_entries`` entries are kept.
        """
        self.status.insert(0, (message, str(datetime.datetime.now())))
        del self.status[self.max_status_entries:]

    def get_status(self, number=1):
        """Return up to ``number`` most recent status entries, newest first.

        The request itself is still recorded, but only after the snapshot is
        taken. Otherwise the result would always start with this call's own
        'get_status requested' entry, and the default call would never show
        what the worker is actually doing.
        """
        recent = self.status[0:number]
        self.update_status('get_status requested')
        return recent

    @coroutine
    def get_new_db(self, res):
        """Download a fresh MPK database if a refresh is due.

        :param res: response from the ``get_db_link`` endpoint. Its JSON body
            must carry the download URL under the ``'d'`` key.

        A refresh is due when any of these holds:

        * ``force_update`` is set;
        * no download has succeeded yet;
        * the last successful download is older than 23 hours.

        It only happens if ``refresh_przystanki_db`` is enabled in the config.
        The body is written to ``baza.zip`` and then moved over
        ``baza.ready.zip``, so the ready file is never half-written.
        """
        self.update_status('fetching new db version')
        due = (self.force_update or
               self.last_db_update is None or
               datetime.datetime.now() - self.last_db_update >
               datetime.timedelta(hours=23))
        if not due:
            logging.info("24 hours didnt pass")
            return
        if not self.config['refresh_przystanki_db']:
            logging.info('db refresh disabled by config')
            return
        zwrotka = json.loads(res.body.decode('utf-8'))
        logging.info("got %s from mpk", zwrotka)
        logging.info("downloading new db")
        self.update_status('downloading db')
        baza = yield self.httpclient.fetch(zwrotka['d'],
                                           request_timeout=600)
        self.update_status('saving db')
        with open(self.db_file + 'baza.zip', 'wb') as f:
            f.write(baza.body)
        logging.info("new db saved")
        self.update_status('db saved')
        # os.replace, unlike os.rename, also overwrites an existing target on
        # Windows, so the second and later refreshes don't fail there.
        os.replace(self.db_file + 'baza.zip',
                   self.db_file + 'baza.ready.zip')
        # NOTE(review): self.db was opened in __init__. Whether MpkDb sees the
        # replaced file without reconnecting is not visible here — confirm.
        # Mark the refresh done only after it fully succeeded, so a failed
        # download is retried on the next run instead of ~23 hours later.
        self.force_update = False
        self.last_db_update = datetime.datetime.now()

    @coroutine
    def push_to_przystanki(self, body, res):
        """Store one point's data, parsed from an MPK point-data response.

        :param body: the request payload that produced ``res``. It supplies
            ``pointId`` and ``variantId``.
        :param res: response whose JSON body holds the point data under ``'d'``.
        """
        logging.debug('point data response: %s', res)
        data = json.loads(res.body.decode('utf-8'))['d']
        to_push = {
            'pointId': body['pointId'],
            'pointName': data['StopName'],
            'variantId': body['variantId'],
            'lineName': data['LineName'],
            'pointTime': json.dumps(data['PointTime']),
            'route': data['Route']
        }
        self.przystanki_db.insert(to_push)

    @coroutine
    def fill_przystanki_db(self, lines):
        """Rebuild the przystanki cache from MPK point data for ``lines``.

        The call does nothing in two cases:

        * ``refresh_przystanki_db`` is ``False`` in the config;
        * ``lines`` is ``None``.

        Otherwise it clears the table, then fetches and stores every point of
        every variant of each line. A point whose response cannot be parsed
        or stored is logged and skipped. Errors raised by the HTTP fetch
        itself propagate to the caller.
        """
        if self.config['refresh_przystanki_db'] is False:
            logging.info('przystanki.db untouched due to config')
            return
        if lines is None:
            # Checked before clear_table() so a missing 'lines' config entry
            # does not wipe the cache and then crash while iterating None.
            logging.warning('no lines configured; przystanki.db untouched')
            return
        self.update_status('filling przystanki db cache')
        self.przystanki_db.clear_table()
        for line in lines:
            logging.info('fetching line %s', line)
            self.update_status('fetching line %s' % line)
            line_points = self.db.get_line_points(line)
            for variant, points in line_points.items():
                for point in points:
                    body = {
                        "variantId": variant,
                        "pointId": point,
                        "lineName": line
                    }
                    # Use the returned future rather than a callback, so the
                    # point is stored before the next request is made.
                    res = yield self.httpclient.fetch(self.mpk_point_data,
                                                      method='POST',
                                                      body=json.dumps(body),
                                                      headers=self.headers)
                    try:
                        yield self.push_to_przystanki(body, res)
                    except Exception:
                        # One malformed point must not abandon the whole
                        # refresh: report it and carry on with the next one.
                        logging.exception('failed to store point %s of line %s',
                                          point, line)

    @coroutine
    def get_line_position(self, line, time=None):
        """Placeholder: not implemented yet.

        :param time: time as ``'H:MM'``. It defaults to the current local time,
            computed per call. A default expression in the signature would be
            evaluated once, at import.
        """
        if time is None:
            time = datetime.datetime.now().strftime('%-H:%M')
        # TODO: not implemented.

    @coroutine
    def run(self):
        """Run one refresh cycle: update the MPK db if due, then the cache."""
        logging.info("running for %d time", self.number)
        self.update_status("running for %d time" % self.number)
        res = yield self.httpclient.fetch(self.mpk_link,
                                          method='POST',
                                          body=urllib.parse.urlencode({}),
                                          headers=self.headers)
        try:
            # Awaited instead of being passed as a fetch callback. This way
            # the download finishes before the cache refresh starts, and its
            # errors surface here instead of in a detached coroutine.
            yield self.get_new_db(res)
        except Exception:
            # A failed db download should not block the cache refresh below.
            logging.exception('db refresh failed')
            self.update_status('db refresh failed')
        yield self.fill_przystanki_db(self.config.get('lines'))
        self.number += 1