예제 #1
0
def setUp():
    """Rebuild the TST schema from scratch and seed it with lookup data.

    Drops the tables, recreates the schema, then loads regions and stations
    so that rows inserted by the tests satisfy foreign-key constraints.
    """
    drop_all_tables()  # ensure everything is empty
    create_db()
    session = get_session(env='TST', echo=True)
    try:
        # load to satisfy fk constraints
        load_regions_and_stations(session)
    finally:
        # release the session even if seeding fails
        session.close()
예제 #2
0
 def test_update_captures_username(self):
     ''' When a db record is updated,
         modified_by should be changed to show who it was'''
     from contextlib import closing
     # run a load (as bikeshare_tst)
     session = get_session(env='TST', echo=True)
     try:
         etl(Station_Information, session)
         # pick one loaded station id using the default connection
         conn, cur = get_connection_and_cursor()
         with closing(conn):
             cur.execute('SELECT station_id FROM station_information;')
             row_id = cur.fetchone()
         # without a row the UPDATE below would be executed with no params
         self.assertIsNotNone(row_id, 'etl loaded no station_information rows')
         # connect as different user and update that single row
         conn, cur = get_connection_and_cursor(user='******',
                                               pw=environ['POSTGRES_PW_TST'])
         with closing(conn):
             cur.execute(
                 '''UPDATE station_information
                    SET station_name = 'not an actual station'
                    WHERE station_id = (%s);''', row_id)
             conn.commit()
             cur.execute('SELECT station_id, modified_by FROM station_information')
             db_data = cur.fetchall()
         correct = True
         for row in db_data:
             # the updated row must carry the second user; all others the loader
             if row[0] == row_id[0]:
                 user = '******'
             else:
                 user = environ['POSTGRES_USER_TST']
             if row[1] != user:
                 correct = False
                 print(f'user should be {user} but is actually {row[1]}')
                 break
         self.assertTrue(correct)
     finally:
         # clean up even on failure so later tests start from an empty table
         empty_table_station_information(session)
         session.close()
예제 #3
0
 def test_load_data_station_status(self):
     ''' Load data, then pull it from db and ensure match '''
     from contextlib import closing
     session = get_session(env='TST', echo=False)
     try:
         data = get_data_from_api()
         print(f'rows pulled: {len(data["data"]["stations"])}')
         # tstmp needed for creating Station_Status obj
         last_updated = data['last_updated']
         out = []
         comp = {}
         for row in data['data']['stations']:
             row['last_updated'] = last_updated
             record = Station_Status(row)
             out.append(record)
             comp[record.station_id] = record
         load_db(out, session)
         # connect and get all data from db; close once rows are fetched
         conn, cur = get_connection_and_cursor()
         with closing(conn):
             cur.execute(select_all_stmt())
             db_data = cur.fetchall()
         all_match = True
         if len(db_data) != len(comp):
             all_match = False
             print('lens dont match.')
             print(f'db: {len(db_data)}')
             print(f'comp: {len(comp)}')
         else:
             for row in db_data:
                 # comp is keyed by station_id, looked up via row[2]; an
                 # unknown id is reported as a mismatch instead of a KeyError
                 expected = comp.get(row[2])
                 if expected is None or row != expected.to_tuple():
                     all_match = False
                     print(f'row {row} didnt match {expected}')
                     break
         self.assertTrue(all_match)
     finally:
         # clean up even on failure so later tests start from an empty table
         empty_station_status_table(session)
         session.close()
예제 #4
0
 def test_get_data_returns_dict(self):
     ''' ensure we get a dict of Station_Information back '''
     session = get_session(env='TST', echo=True)
     try:
         metadata = create_metadata(Station_Information, session)
         data = get_data(Station_Information, metadata)
         self.assertIsInstance(data, dict)
     finally:
         # close even when the assertion fails so sessions don't accumulate
         session.close()
예제 #5
0
 def test_load_on_empty(self):
     ''' all records pulled down should be loaded into db '''
     from contextlib import closing
     session = get_session(env='TST', echo=True)
     try:
         metadata = create_metadata(System_Region, session)
         data = get_data(System_Region, metadata)
         compare_data(data, System_Region, metadata, session)
         session.commit()
         # get data from db; close the raw connection once rows are fetched
         conn, cur = get_connection_and_cursor()
         with closing(conn):
             cur.execute('''SELECT region_id,
                    region_name,
                    row_modified_tstmp,
                    load_id,
                    transtype,
                    modified_by
                    FROM system_regions''')
             db_data = cur.fetchall()
         if len(db_data) != len(data):
             # a row that never got loaded would slip past the per-row check
             correct = False
             print(f'db has {len(db_data)} rows but {len(data)} were pulled')
         else:
             correct = True
             # compare each val of each row from db to each row pulled from web
             for row in db_data:
                 expected = data.get(row[0])
                 orig = None if expected is None else expected.to_tuple()
                 correct = (row == orig)
                 if not correct:
                     print(f'row {row} didnt match orig {orig}')
                     break
         self.assertTrue(correct)
     finally:
         # clean up even on failure so later tests start from an empty db
         empty_db(session)
         session.close()
예제 #6
0
def setUp():
    """Rebuild the TST schema from scratch and seed it with region data.

    Besides the regions from ``load_regions``, one dummy region built by
    ``create_dummy_region`` is loaded (region_id 9999 per the original note).
    """
    drop_all_tables()  # ensure everything is empty
    create_db()
    session = get_session(env='TST', echo=True)
    try:
        load_regions(session)
        # load one dummy region with region_id 9999
        load_single_region(create_dummy_region(), session)
    finally:
        # release the session even if seeding fails
        session.close()
예제 #7
0
 def test_get_data_dict_key_is_id(self):
     ''' ensure dict key is row.id'''
     session = get_session(env='TST', echo=True)
     try:
         metadata = create_metadata(Station_Information, session)
         data = get_data(Station_Information, metadata)
         # collect every offending key so a failure says which ones broke
         mismatched = [key for key, record in data.items() if key != record.id]
         self.assertEqual([], mismatched)
     finally:
         session.close()
예제 #8
0
 def test_get_latest_data(self):
     ''' Ensure get_latest_from_db brings back the actual latest row '''
     session = get_session(env='TST', echo=True)
     try:
         # one base timestamp so d2 is deterministically 10s newer than d1
         base = time()
         # create a dummy record and load it
         d1 = create_dummy_data(base)
         load_single_status_row(d1, session)
         # create another dummy with a later tstmp
         d2 = create_dummy_data(base + 10)
         load_single_status_row(d2, session)
         # get latest (dict with station_id as key)
         latest = get_latest_from_db(session)
         self.assertEqual(d2, latest[d2.station_id])
     finally:
         # clean up even on failure so later tests start from an empty table
         empty_station_status_table(session)
         session.close()
예제 #9
0
 def test_comparison(self):
     ''' Ensure only changed data is loaded.'''
     from contextlib import closing
     session = get_session(env='TST', echo=True)
     try:
         # load first dummy file on empty db
         data_d1 = get_json_from_file('tests/dummy1.json')
         last_updated = data_d1['last_updated']
         for row in data_d1['data']['stations']:
             row['last_updated'] = last_updated
             session.add(Station_Status(row))
         session.commit()
         # compare and load file 2
         latest = get_latest_from_db(session)
         data_d2 = get_json_from_file('tests/dummy2.json')
         out, latest = get_changed_data(data_d2, latest)
         load_db(out, session)
         # compare and load file 3
         data_d3 = get_json_from_file('tests/dummy3.json')
         out, latest = get_changed_data(data_d3, latest)
         load_db(out, session)
         # pull all data from db; should match exactly data in desired file
         conn, cur = get_connection_and_cursor()
         with closing(conn):
             cur.execute(select_all_stmt())
             db_data = cur.fetchall()
         desired = get_json_from_file('tests/desired.json')['stations']
         all_match = True
         if len(db_data) != len(desired):
             all_match = False
             print(f"{db_data} should be same len as {desired}")
             print(f"{len(db_data)} != {len(desired)}")
         else:
             # rows are compared positionally: this relies on select_all_stmt()
             # returning rows in the same order as desired.json
             for row, wanted in zip(db_data, desired):
                 should_be = Station_Status(wanted).to_tuple()
                 # not worried about first and last vals (auto id and username)
                 if row[1:-1] != should_be[1:-1]:
                     all_match = False
                     print(f'row {row} should match row {should_be}')
                     print('except first and last vals (auto incremented id and username)')
                     break
         self.assertTrue(all_match)
     finally:
         # clean up even on failure so later tests start from an empty table
         empty_station_status_table(session)
         session.close()
예제 #10
0
 def test_load_on_empty(self):
     ''' all records pulled down should be loaded into db '''
     from contextlib import closing
     session = get_session(env='TST', echo=True)
     try:
         metadata = create_metadata(Station_Information, session)
         data = get_data(Station_Information, metadata)
         compare_data(data, Station_Information, metadata, session)
         session.commit()
         # get data from db; close the raw connection once rows are fetched
         conn, cur = get_connection_and_cursor()
         with closing(conn):
             cur.execute('''SELECT station_id,
                    short_name,
                    station_name,
                    lat,
                    lon,
                    capacity,
                    region_id,
                    eightd_has_key_dispenser,
                    rental_method_key,
                    rental_method_creditcard,
                    rental_method_paypass,
                    rental_method_applepay,
                    rental_method_androidpay,
                    rental_method_transitcard,
                    rental_method_accountnumber,
                    rental_method_phone,
                    row_modified_tstmp,
                    load_id,
                    transtype,
                    modified_by
                    FROM station_information;''')
             db_data = cur.fetchall()
         if len(db_data) != len(data):
             # a row that never got loaded would slip past the per-row check
             correct = False
             print(f'db has {len(db_data)} rows but {len(data)} were pulled')
         else:
             correct = True
             # compare each val of each row from db to each row pulled from web
             for row in db_data:
                 expected = data.get(row[0])
                 orig = None if expected is None else expected.to_tuple()
                 correct = (row == orig)
                 if not correct:
                     print(f'row {row} didnt match orig {orig}')
                     break
         self.assertTrue(correct)
     finally:
         # clean up even on failure so later tests start from an empty table
         empty_table_station_information(session)
         session.close()
def test_get_session():
    """Check that get_session() returns a session object, then release it."""
    session = get_session()
    assert session is not None
    # don't leave the session (and any connection it holds) open after the test
    session.close()
예제 #12
0
    """
    from src import sqla
    import csv

    f = utils.open_data_file("bug_summary.csv")
    f.readline()
    reader = csv.reader(f)
    for row in reader:
        br = sqla.BugRow(row)

        bug = session.query(Bug).filter_by(bzid=br.bug).scalar()
        bug.reported = br.reported

    session.commit()

session = utils.get_session()

# --- Convenience helpers meant to be run interactively (ideally only once) ---
# Each one works against the module-level ``session`` created above.


def pop_bs():
    """Call ``src.bugstate.populate_bugstates(session, False, 15)``."""
    import src.bugstate
    src.bugstate.populate_bugstates(session, False, 15)


def add_ids():
    """Run ``Bug.add_assignee_ids`` against the module-level session."""
    import src.bugs
    src.bugs.Bug.add_assignee_ids(session)


def add_comments(interval=0):
    """Run ``BugEvent.scrape_comment_events`` with the given ``interval``.

    ``interval`` is forwarded unchanged; its meaning is defined by
    ``scrape_comment_events``, which is not visible here.
    """
    import src.bug_events
    src.bug_events.BugEvent.scrape_comment_events(session, interval)

예제 #13
0
 def test_get_session_returns_Session(self):
     ''' get_session(env='TST') should return a Session instance '''
     session = get_session(env='TST')
     self.assertIsInstance(session, Session)
     # the original discarded the session without closing it
     session.close()
예제 #14
0
 def __init__(self, cls):
     """Store the target class and open a session via get_session()."""
     self.cls, self.session = cls, get_session()
예제 #15
0
 def test_update_only_updates_that_record(self):
     ''' load, then load an update. ensure only that record was updated'''
     from contextlib import closing
     # first get data and load
     session = get_session(env='TST', echo=True)
     try:
         metadata = create_metadata(Station_Information, session)
         data = get_data(Station_Information, metadata)
         compare_data(data, Station_Information, metadata, session)
         session.commit()
         # get tuple copies of each record that was loaded
         originals = {key: record.to_tuple() for key, record in data.items()}
         # now get data again
         m2 = create_metadata(Station_Information, session)
         d2 = get_data(Station_Information, m2)
         # get one record from data and make a change
         u_record = d2[next(iter(d2))]
         u_record.station_name = 'phoney balogna'
         u_data = {u_record.id: u_record}
         # load new record (should be update) with the metadata of the pull
         # that produced u_record, not the first load's metadata
         compare_data(u_data, Station_Information, m2, session)
         session.commit()
         # get all current data from db; close once rows are fetched
         conn, cur = get_connection_and_cursor()
         with closing(conn):
             cur.execute('''SELECT station_id,
                    short_name,
                    station_name,
                    lat,
                    lon,
                    capacity,
                    region_id,
                    eightd_has_key_dispenser,
                    rental_method_key,
                    rental_method_creditcard,
                    rental_method_paypass,
                    rental_method_applepay,
                    rental_method_androidpay,
                    rental_method_transitcard,
                    rental_method_accountnumber,
                    rental_method_phone,
                    row_modified_tstmp,
                    load_id,
                    transtype,
                    modified_by
                    FROM station_information;''')
             db_data = cur.fetchall()
         row_updated = True
         rows_match = True
         correct_trans = True
         # iterate through db data; stop at the first failed check
         for row in db_data:
             # unknown ids yield None and are reported instead of raising KeyError
             orig = originals.get(row[0])
             if row[0] == u_record.id:
                 trans = 'U'
                 # updated row should differ from orig and carry the new name
                 row_updated = (orig != row) and (row[2] == 'phoney balogna')
             else:
                 # all other records row should match orig
                 trans = 'I'
                 rows_match = (orig == row)
             # row[-2] is the transtype column of the SELECT above
             correct_trans = (row[-2] == trans)
             if not row_updated:
                 print(f'row {row} shouldnt match orig {orig}')
                 print('also, station_name in row should be phoney balogna')
             if not rows_match:
                 print(f'row {row} didnt match orig {orig}')
             if not correct_trans:
                 print(f'incorrect trans on row {row} -- should be {trans}')
             if not (row_updated and rows_match and correct_trans):
                 break
         self.assertTrue(row_updated and rows_match and correct_trans)
     finally:
         # clean up even on failure so later tests start from an empty table
         empty_table_station_information(session)
         session.close()
예제 #16
0
def setUp():
    """Reset the TST database to a freshly created, empty schema.

    A TST session is opened and immediately closed again; no rows are loaded.
    """
    # start from nothing, then rebuild the schema
    drop_all_tables()
    create_db()
    get_session(env='TST', echo=True).close()