def initialize(drop_views=False):
    """Initializes the tables for the dashboard."""
    database = Database()
    LOGGER.info('Initializing the database ...')
    database.initialize(drop_views=drop_views)
    LOGGER.info('Loading member ids into the participant match table ...')
    name_resolver = NameResolver(database=database)
    name_resolver.load_member_ids()
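# A minimal usage sketch for initialize(), assuming it is wired up as a
# command line job. The '--drop-views' flag is an illustrative assumption;
# the real CLI wiring is not shown in these snippets.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Initialize dashboard tables')
    parser.add_argument('--drop-views', action='store_true',
                        help='Drop and rebuild the materialized views')
    args = parser.parse_args()
    initialize(drop_views=args.drop_views)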
def __init__(self, database=None):
    daiquiri.setup(level=logging.INFO)
    self.logger = daiquiri.getLogger(__name__)

    self.database = Database(database=database)
    self.path = os.path.dirname(os.path.realpath(__file__))
    self.url = 'https://www.zip-codes.com/cache/kml-zip/'
    self.search = SearchEngine(simple_zipcode=True)
    self.zip_code_cache = {}
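# A method sketch, not part of the original class: how the uszipcode
# SearchEngine initialized above might be used together with the
# zip_code_cache to avoid repeated lookups. The method name lookup_zip
# is a hypothetical addition for illustration.
def lookup_zip(self, zip_code):
    """Returns zip code metadata, caching results between calls."""
    if zip_code not in self.zip_code_cache:
        result = self.search.by_zipcode(zip_code)
        self.zip_code_cache[zip_code] = result.to_dict()
    return self.zip_code_cache[zip_code]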
def __init__(self):
    daiquiri.setup(level=logging.INFO)
    self.logger = daiquiri.getLogger(__name__)

    self.database = Database()
    self.events_manager = Events()
    self.network = None
    self.metrics = {
        'node_connectivity': self.node_connectivity,
        'edge_connectivity': self.edge_connectivity,
        'density': self.density,
        'membership_scaled_density': self.membership_scaled_density
    }
def log_request(request, user, authorized, database=None):
    """Logs the API request to the database for security monitoring
    and analyzing user metrics.

    Parameters
    ----------
    request : request
        The flask request object for the API call
    user : str
        The user who made the API call. Pulled from the JWT.
    authorized : boolean
        Indicates whether or not the user was authorized to access
        the endpoint.
    database : shir_connect.database.database
        A Shir Connect database object. Primarily used for testing

    Returns
    -------
    Logs information about the request to the Postgres database.
    """
    # Don't write logs to the table during unit tests or development
    if conf.SHIR_CONNECT_ENV in ['DEV', 'TEST']:
        return None

    if not database:
        database = Database(database='postgres', schema='application_logs')

    # By default, the remote_addr attribute on the Flask request object
    # is the IP address of the referrer, which in our case is NGINX. We
    # configure NGINX to put the real remote_addr in the header so we're
    # able to track it.
    remote_addr = request.environ.get('HTTP_X_FORWARDED_FOR',
                                      request.remote_addr)

    item = {
        'id': uuid.uuid4().hex,
        'application_user': user,
        'authorized': authorized,
        'base_url': request.base_url,
        'endpoint': request.endpoint,
        'host': request.host,
        'host_url': request.host_url,
        'query_string': request.query_string.decode('utf-8'),
        'referrer': request.referrer,
        'remote_addr': remote_addr,
        'scheme': request.scheme,
        'url': request.url,
        'url_root': request.url_root,
        'user_agent': str(request.user_agent),
        'load_datetime': datetime.datetime.now()
    }
    database.load_item(item, 'shir_connect_logs')
def map_authorize():
    """Checks whether the user is authorized to view the map."""
    database = Database()
    jwt_user = get_jwt_identity()
    user = database.get_item('users', jwt_user)

    authorized = conf.MAP_GROUP in user['modules']
    log_request(request, jwt_user, authorized)

    if not authorized:
        response = {'message': '%s does not have access to the map' % (jwt_user)}
        return jsonify(response), 403
    else:
        response = {'message': '%s is authorized to view the map' % (jwt_user)}
        return jsonify(response), 200
def count_bad_login_attempts(user, domain, reset_date):
    """Counts the number of bad login attempts the user has made
    on the specified domain. This is used to put a lock on the
    account if they have made too many bad authentication requests.

    Parameters
    ----------
    user: string, the name of the user
    domain: string, the prefix for the host url. For
        https://dev.shirconnect.com, the domain would be 'dev'
    reset_date: string, the date when the user last set their password.
        The bad login count should return to zero after a password reset.

    Returns
    -------
    count: int, the number of bad login attempts
    """
    database = Database(database='postgres', schema='application_logs')
    sql = """
        SELECT COUNT(id) AS bad_login_attempts
        FROM application_logs.shir_connect_logs
        WHERE application_user = '{user}'
          AND host = '{domain}.shirconnect.com'
          AND authorized = FALSE
          AND load_datetime > NOW() - INTERVAL '1 DAY'
          AND load_datetime > '{reset_date}'
    """.format(user=user, domain=domain, reset_date=reset_date)
    df = pd.read_sql(sql, database.connection)
    bad_login_attempts = df.loc[0]['bad_login_attempts']
    return int(bad_login_attempts)
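# A hedged sketch of how the count might gate logins. The threshold and
# the is_account_locked helper are assumptions for illustration; only
# count_bad_login_attempts above comes from the source.
MAX_BAD_LOGIN_ATTEMPTS = 10  # assumed lockout threshold

def is_account_locked(user, domain, reset_date):
    """Returns True when the user has made too many bad login attempts."""
    attempts = count_bad_login_attempts(user, domain, reset_date)
    return attempts >= MAX_BAD_LOGIN_ATTEMPTS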
def __init__(self, database=None):
    daiquiri.setup(level=logging.INFO)
    self.logger = daiquiri.getLogger(__name__)

    self.database = database if database else Database()
    self.access_groups = conf.ACCESS_GROUPS
    self.user_roles = conf.USER_ROLES
def __init__(self, database=None):
    daiquiri.setup(level=logging.INFO)
    self.logger = daiquiri.getLogger(__name__)

    self.database = database if database else Database()
    # Share the same database connection with the member loader
    self.member_loader = MemberLoader(database=self.database)
    self.allowed_extensions = conf.ALLOWED_EXTENSIONS
def member_authorize():
    """Checks to see if the user is authorized to see members."""
    database = Database()
    jwt_user = get_jwt_identity()
    user = database.get_item('users', jwt_user)

    authorized = conf.TRENDS_GROUP in user['modules']
    log_request(request, jwt_user, authorized)

    if not authorized:
        response = {'message': '%s does not have access to trends' % (jwt_user)}
        return jsonify(response), 403
    else:
        del user['password']
        del user['temporary_password']
        return jsonify(user), 200
def test_read_table():
    database = Database()
    df = database.read_table('event_aggregates', limit=10)
    assert len(df) == 10

    df = database.read_table(
        'event_aggregates',
        query=('name', 'Rodef 2100'),
        where=[('start_datetime', {'>=': "'2018-01-01'"})],
        limit=10
    )
    assert len(df) > 0
    for i in df.index:
        row = df.loc[i]
        assert str(row['start_datetime']) >= '2018-01-01'
        assert '2100' in row['name'] or 'rodef' in row['name'].lower()

    count = database.count_rows('event_aggregates',
                                query=('name', 'Rodef 2100'))
    assert count > 0
def export_event_aggregates():
    """Exports the event aggregates as a csv."""
    database = Database()

    # Make sure the user has access to the module
    jwt_user = get_jwt_identity()
    user = database.get_item('users', jwt_user)
    authorized = conf.EVENT_GROUP in user['modules']
    log_request(request, jwt_user, authorized)
    if not authorized:
        response = {'message': '%s does not have access to events' % (jwt_user)}
        return jsonify(response), 403

    q = request.args.get('q')
    query = ('name', q) if q else None
    df = database.read_table('event_aggregates', query=query)

    buffer = StringIO()
    df.to_csv(buffer, encoding='utf-8', index=False)
    output = make_response(buffer.getvalue())
    output.headers["Content-Disposition"] = "attachment; filename=export.csv"
    output.headers["Content-type"] = "text/csv"
    return output
def initialize_log_table():
    """Builds the table in the postgres database that is used for
    storing application logs."""
    database = Database(database='postgres')

    LOGGER.info('Creating the application_logs schema ...')
    schema_sql = "CREATE SCHEMA IF NOT EXISTS application_logs"
    database.run_query(schema_sql)

    table_sql = """
        CREATE TABLE IF NOT EXISTS application_logs.shir_connect_logs (
            id text,
            application_user text,
            authorized boolean,
            base_url text,
            endpoint text,
            host text,
            host_url text,
            query_string text,
            referrer text,
            remote_addr text,
            scheme text,
            url text,
            url_root text,
            user_agent text,
            load_datetime timestamp
        )
    """
    LOGGER.info('Creating the shir_connect_logs table ...')
    database.run_query(table_sql)
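# A hedged smoke test for the log table, assuming a configured Postgres
# connection and that load_item tolerates a partial column set. Writes one
# synthetic row and reads it back; the columns mirror the log_request item.
def check_log_table():
    database = Database(database='postgres', schema='application_logs')
    item = {'id': uuid.uuid4().hex,
            'application_user': 'smoke_test',
            'authorized': False,
            'load_datetime': datetime.datetime.now()}
    database.load_item(item, 'shir_connect_logs')
    return database.get_item('shir_connect_logs', item['id'])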
def get_new_members():
    """Pulls in a list of the most recent members of the Congregation."""
    database = Database()
    jwt_user = get_jwt_identity()
    authorized = utils.check_access(jwt_user, conf.REPORT_GROUP, database)
    utils.log_request(request, jwt_user, authorized)
    if not authorized:
        response = {'message': '{} does not have access to reports.'.format(jwt_user)}
        return jsonify(response), 403

    limit_param = request.args.get('limit')
    limit = limit_param if limit_param else 25
    new_members = database.read_table('members_view', limit=limit,
                                      order='desc', sort='membership_date')
    response = database.to_json(new_members)
    return jsonify(response)
def build_email_content():
    """Builds the email with the weekly usage statistics."""
    database = Database(database='postgres', schema='application_logs')

    all_time_sql = """
        SELECT COUNT(*) as total, application_user, remote_addr,
               host, user_agent
        FROM application_logs.shir_connect_logs
        WHERE load_datetime >= '2018-08-01'
          AND application_user <> 'true'
        GROUP BY application_user, host, remote_addr, user_agent
        ORDER BY application_user ASC
    """
    all_time_stats = pd.read_sql(all_time_sql, database.connection)
    all_time_html = all_time_stats.to_html()

    weekly_sql = """
        SELECT COUNT(*) as total, application_user, remote_addr,
               host, user_agent
        FROM application_logs.shir_connect_logs
        WHERE load_datetime >= NOW() - INTERVAL '7 DAYS'
        GROUP BY application_user, host, remote_addr, user_agent
        ORDER BY application_user ASC
    """
    weekly_stats = pd.read_sql(weekly_sql, database.connection)
    weekly_html = weekly_stats.to_html()

    html = """
        <h3>Usage Statistics</h3>
        <p>Greetings Fiddlers! Here are the latest usage statistics
        for Shir Connect.</p>
        <h4>Overall Usage</h4>
        {all_time_html}
        <h4>Weekly Usage</h4>
        {weekly_html}
    """.format(all_time_html=all_time_html, weekly_html=weekly_html)
    return html
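# A hedged sketch of delivering the report with the standard library.
# The SMTP host, sender, and recipient addresses are placeholders; the
# real delivery mechanism is not shown in these snippets.
import smtplib
from email.mime.text import MIMEText

def send_usage_email(smtp_host='localhost',
                     sender='noreply@shirconnect.com',
                     recipients=None):
    """Sends the usage statistics email built by build_email_content."""
    recipients = recipients or ['admin@shirconnect.com']
    msg = MIMEText(build_email_content(), 'html')
    msg['Subject'] = 'Shir Connect usage statistics'
    msg['From'] = sender
    msg['To'] = ', '.join(recipients)
    with smtplib.SMTP(smtp_host) as server:
        server.send_message(msg)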
def add_fake_names():
    """Adds fake names that can be used in demo mode."""
    database = Database(database='trs')
    fake_news = FakeNews(database=database)
    fake_news.build_fake_data()
def match_participants():
    """Runs the fuzzy matching algorithm to match up attendees
    and members."""
    database = Database(database='trs')
    participant_matcher = ParticipantMatcher(database=database)
    participant_matcher.run()
    participant_matcher.estimate_unknown_ages()
def refresh_materialized_views():
    """Refreshes the materialized views for Shir Connect."""
    database = Database(database='trs')
    database.refresh_views()
    print('Materialized views have been refreshed!')
def test_initialize():
    database = Database()
    database.initialize()
    assert database.connection.status == 1
class ParticipantMatcher:
    def __init__(self, database=None):
        daiquiri.setup(level=logging.INFO)
        self.logger = daiquiri.getLogger(__name__)

        self.database = Database() if not database else database
        self.name_resolver = NameResolver(database=self.database)
        self.participants = Participants(database=self.database)
        self.avg_event_age = {}

    def run(self, limit=1000, iters=None):
        """Adds attendees that have not been matched up to a participant
        to the look up table. If there are no matches for an attendee,
        a new participant id is created."""
        n = 0
        while True:
            missing_attendees = self._get_missing_attendees(limit=limit)
            count = len(missing_attendees)
            if (iters and n >= iters) or count == 0:
                msg = 'Participant matcher has finished processing.'
                self.logger.info(msg)
                break
            msg = 'Iteration {} | Processing {} missing attendees'
            self.logger.info(msg.format(n + 1, count))
            for i in range(count):
                attendee = dict(missing_attendees.loc[i])
                if attendee['first_name'] and attendee['last_name']:
                    self._process_attendee(attendee)
            n += 1

    def estimate_unknown_ages(self):
        """Finds estimated ages for any participant whose age is
        unknown or whose age is an estimate."""
        unknowns = self._get_unknown_ages()
        for i, unknown in enumerate(unknowns):
            if i % 1000 == 0:
                msg = 'Estimated age for {} participants.'.format(i)
                self.logger.info(msg)
            estimated_age = self._estimate_participant_age(unknown['id'])
            if not estimated_age:
                continue
            now = datetime.datetime.now()
            estimated_birth_date = now - datetime.timedelta(estimated_age * 365)
            estimated_birth_date = "'{}'".format(str(estimated_birth_date)[:10])
            self.database.update_column(table='participant_match',
                                        item_id=unknown['id'],
                                        column='birth_date',
                                        value=estimated_birth_date)
            self.database.update_column(table='participant_match',
                                        item_id=unknown['id'],
                                        column='is_birth_date_estimated',
                                        value=True)

    def _process_attendee(self, attendee):
        """Adds a link to attendee_to_participant if the attendee has a
        match. Otherwise a new participant id is created for the attendee."""
        # Cache the average age for the event so it doesn't
        # have to be pulled from the database each time
        event_id = attendee['event_id']
        if event_id not in self.avg_event_age:
            age = self._get_avg_event_age(event_id)
            self.avg_event_age[event_id] = age
        else:
            age = self.avg_event_age[event_id]

        match = self.name_resolver.find_best_match(
            first_name=attendee['first_name'],
            last_name=attendee['last_name'],
            email=attendee['email'],
            age=age)
        if match:
            participant_id = match['id']
        else:
            # If there is no participant match, a new participant
            # is created and added to the database
            participant_id = uuid.uuid4().hex
            participant = {
                'id': participant_id,
                'first_name': attendee['first_name'],
                'last_name': attendee['last_name'],
                'email': attendee['email']
            }
            self.database.load_item(participant, 'participant_match')

        # Insert the attendee to participant match into the database
        item = {'id': attendee['id'], 'participant_id': participant_id}
        self.database.load_item(item, 'attendee_to_participant')

    def _get_missing_attendees(self, limit=1000):
        """Pulls a list of attendees that have not yet been matched
        to a participant."""
        sql = """
            SELECT id, event_id, first_name, last_name, email
            FROM {schema}.attendees
            WHERE id NOT IN (SELECT id FROM {schema}.attendee_to_participant)
              AND first_name IS NOT NULL
              AND last_name IS NOT NULL
            ORDER BY event_id ASC
            LIMIT {limit}
        """.format(schema=self.database.schema, limit=limit)
        df = pd.read_sql(sql, self.database.connection)
        return df

    def _get_avg_event_age(self, event_id):
        """Computes the median age of the attendees of an event."""
        if not isinstance(event_id, list):
            event_id = [str(event_id)]
        else:
            event_id = [str(x) for x in event_id]

        sql = """
            SELECT PERCENTILE_CONT(0.5)
                   WITHIN GROUP (ORDER BY avg_age) as avg_age
            FROM (
                SELECT event_id,
                       PERCENTILE_CONT(0.5)
                       WITHIN GROUP (ORDER BY z.age) as avg_age
                FROM {schema}.attendees x
                INNER JOIN {schema}.attendee_to_participant y
                    ON x.id = y.id
                INNER JOIN {schema}.participants z
                    ON y.participant_id = z.participant_id
                WHERE event_id = ANY(%(event_id)s)
                  AND z.age IS NOT NULL
                GROUP BY event_id
            ) a
        """.format(schema=self.database.schema)
        df = self.database.fetch_df(sql, params={'event_id': event_id})
        avg_age = None
        if len(df) > 0:
            avg_age = df.loc[0]['avg_age']
        return avg_age

    def _estimate_participant_age(self, participant_id):
        """Estimates a participant's age based on who they've
        attended events with."""
        events = self.participants.get_participant_events(participant_id)
        if len(events) == 0:
            return None
        event_id = [x['event_id'] for x in events]
        age = self._get_avg_event_age(event_id)
        return age

    def _get_unknown_ages(self):
        """Pulls all participant ids that have a null birth date
        or an estimated birth date."""
        sql = """
            SELECT id
            FROM {schema}.participant_match
            WHERE is_birth_date_estimated = TRUE
               OR birth_date IS NULL
        """.format(schema=self.database.schema)
        df = pd.read_sql(sql, self.database.connection)
        results = self.database.to_json(df)
        return results
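# A hedged usage sketch for ParticipantMatcher: process a single bounded
# batch, then refresh age estimates. The production job (see
# match_participants above) runs without an iteration limit.
matcher = ParticipantMatcher()
matcher.run(limit=100, iters=1)  # match one batch of 100 attendees
matcher.estimate_unknown_ages()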
def test_load_items():
    database = Database()
    database.delete_item('members', 'testid1')
    database.delete_item('members', 'testid2')

    columns = database.get_columns('members')
    item1 = {x: None for x in columns}
    item1['id'] = 'testid1'
    item2 = {x: None for x in columns}
    item2['id'] = 'testid2'
    items = [item1, item2]
    database.load_items(items, 'members')

    item1_ = database.get_item('members', 'testid1')
    assert item1_['id'] == 'testid1'
    item2_ = database.get_item('members', 'testid2')
    assert item2_['id'] == 'testid2'

    database.delete_item('members', 'testid1')
    item1_ = database.get_item('members', 'testid1')
    assert item1_ is None
    database.delete_item('members', 'testid2')
    item2_ = database.get_item('members', 'testid2')
    assert item2_ is None
class EventbriteLoader(object):
    """Loads data from Eventbrite into Postgres."""
    def __init__(self, eventbrite_org, database=None):
        daiquiri.setup(level=logging.INFO)
        self.logger = daiquiri.getLogger(__name__)

        self.database = Database(database=database)
        self.eventbrite = Eventbrite()
        self.eventbrite_org = eventbrite_org

    def run(self, test=False):
        """Runs the data load process."""
        last_load_date = self.database.last_event_load_date()
        if last_load_date:
            look_back = datetime.datetime.now() - datetime.timedelta(days=1)
            first_event = min(look_back, last_load_date)
            start = str(first_event)[:10]
            self.logger.info('Loading events starting at %s' % (start))
        else:
            self.logger.info('Loading events from the first available event')
            start = None

        events = self.get_events(start=start, page=1)
        num_events = events['pagination']['object_count']
        if num_events > 0:
            self.logger.info('There are %s events to process' % (num_events))
        else:
            self.logger.info('There are no new events. Exiting')
            return

        more_events = True
        while more_events:
            for event in events['events']:
                if not event:
                    continue
                msg = "Loading information for %s" % (event['name']['text'])
                self.logger.info(msg)

                # Load the event into the database. Delete the current
                # entry in order to maintain the unique index
                event_id = event['id']
                if not test:
                    self.database.delete_item('events', event_id)
                    self.load_event(event)

                # Load the venue, if it does not already
                # appear in the database
                venue_id = event['venue_id']
                venue_ = self.database.get_item('venues', venue_id)
                if venue_id and not venue_:
                    venue = self.get_venue(venue_id)
                    if not test:
                        self.load_venue(venue)

                attendees = self.get_attendees(event_id, page=1)
                more_attendees = True
                while more_attendees:
                    if not attendees:
                        break
                    for attendee in attendees['attendees']:
                        if not attendee:
                            continue
                        if not test:
                            self.database.delete_item('attendees',
                                                      attendee['id'],
                                                      {'event_id': event_id})
                            self.load_attendee(attendee)
                    if test or not attendees['pagination']['has_more_items']:
                        more_attendees = False
                        break
                    else:
                        page = attendees['pagination']['page_number'] + 1
                        attendees = self.get_attendees(event_id, page)

                # Sleep to avoid the Eventbrite rate limit
                if test:
                    return
                else:
                    time.sleep(60)

            if not events['pagination']['has_more_items']:
                more_events = False
                break
            else:
                page = events['pagination']['page_number'] + 1
                msg = 'Pulling events on page %s' % (page)
                self.logger.info(msg)
                events = self.get_events(start, page)

    def get_events(self, start, page=1):
        """Pulls events from Eventbrite and sleeps if the rate limit
        has been exceeded."""
        org_id = self.eventbrite_org
        events = self.eventbrite.get_events(org_id=org_id, start=start,
                                            page=page)
        if not events:
            # Sleep until the Eventbrite rate limit resets
            self.logger.info('Rate limit exceeded. Sleeping 60 mins')
            time.sleep(3600)
            events = self.eventbrite.get_events(org_id=org_id, start=start,
                                                page=page)
        return events

    def get_attendees(self, event_id, page=1):
        """Pulls attendees from Eventbrite and sleeps if the rate limit
        has been exceeded."""
        attendees = self.eventbrite.get_attendees(event_id, page)
        if not attendees:
            # If attendees comes back as None, sleep until the
            # Eventbrite rate limit resets
            self.logger.info('Rate limit exceeded. Sleeping 60 mins')
            time.sleep(3600)
            attendees = self.eventbrite.get_attendees(event_id, page)
        return attendees

    def get_venue(self, venue_id, page=1):
        """Pulls a venue and sleeps if the rate limit has been exceeded."""
        venue = self.eventbrite.get_venue(venue_id, page)
        if not venue:
            self.logger.info('Rate limit exceeded. Sleeping 60 mins')
            time.sleep(3600)
            venue = self.eventbrite.get_venue(venue_id, page)
        return venue

    def load_event(self, event):
        """Loads an event into the database."""
        event_ = deepcopy(event)
        event_['start_datetime'] = arrow.get(event_['start']['utc']).datetime
        event_['end_datetime'] = arrow.get(event_['end']['utc']).datetime
        event_['description'] = event_['description']['text']
        event_['name'] = event_['name']['text']
        event_['load_datetime'] = datetime.datetime.utcnow()
        self.database.load_item(event_, 'events')

    def load_attendee(self, attendee):
        """Loads an attendee into the database."""
        attendee_ = deepcopy(attendee)
        profile = attendee_['profile']
        if 'name' in profile:
            attendee_['name'] = profile['name']
        if 'first_name' in profile:
            attendee_['first_name'] = profile['first_name']
        if 'last_name' in profile:
            attendee_['last_name'] = profile['last_name']
        if 'email' in profile:
            attendee_['email'] = profile['email']
        cost = attendee_['costs']['gross']['major_value']
        attendee_['cost'] = float(cost)
        attendee_['load_datetime'] = datetime.datetime.utcnow()
        self.database.load_item(attendee_, 'attendees')

    def load_order(self, order):
        """Loads an order into the database."""
        order_ = deepcopy(order)
        cost = order_['costs']['gross']['major_value']
        order_['cost'] = float(cost)
        order_['load_datetime'] = datetime.datetime.utcnow()
        self.database.load_item(order_, 'orders')

    def load_venue(self, venue):
        """Loads a venue into the database."""
        venue_ = deepcopy(venue)
        for key in venue_['address']:
            venue_[key] = venue_['address'][key]
        venue_['latitude'] = float(venue_['latitude'])
        venue_['longitude'] = float(venue_['longitude'])
        self.database.load_item(venue_, 'venues')
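# A hedged usage sketch for EventbriteLoader. The organization id is a
# placeholder, and the Eventbrite client is assumed to pick up its OAuth
# token from configuration.
loader = EventbriteLoader(eventbrite_org='YOUR_ORG_ID')
loader.run()  # incremental load starting from the last event load date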
def test_get_columns():
    database = Database()
    columns = database.get_columns('events')
    assert len(columns) > 0
class MM2000:
    def __init__(self, database=None):
        daiquiri.setup(level=logging.INFO)
        self.logger = daiquiri.getLogger(__name__)

        # Load column mapping configs
        self.path = os.path.dirname(os.path.realpath(__file__))
        filename = self.path + '/member_columns.yml'
        with open(filename, 'r') as f:
            self.column_mapping = yaml.safe_load(f)

        self.database = Database() if not database else database
        self.fake_news = FakeNews(database=self.database)

    #####################################
    # Methods for loading MM2000 members
    #####################################

    def load(self, df):
        """Loads the data into the member database."""
        self.logger.info('Parsing MM2000 data.')
        items = self.parse_mm2000(df)

        self.logger.info('Backing up current member table.')
        self.database.backup_table('members')
        self.logger.info('Truncating current member table.')
        self.database.truncate_table('members')
        self.logger.info('Loading updated member data.')
        for item in items:
            self.database.load_item(item, 'members')

        self.logger.info('Checking updated columns.')
        good_columns = self.check_columns()
        if good_columns:
            self.logger.info('Generating demo data.')
            self.fake_news.fake_names()
            self.logger.info('Refreshing materialized views.')
            self.database.refresh_view('members_view')
            self.database.refresh_view('participants')
        else:
            self.logger.warning('Column mismatch in upload.')
            self.database.revert_table('members')
            return False
        return True

    def parse_mm2000(self, df):
        """Converts the MM2000 export into a list of rows."""
        column_mapping = self.column_mapping['MM2000']
        items = []
        for group in column_mapping:
            column_map = column_mapping[group]['columns']
            df_group = _group_mm2000(df, column_map)
            if 'id_extension' in column_mapping[group]:
                id_extension = column_mapping[group]['id_extension']
            else:
                id_extension = None

            for i in df_group.index:
                item = dict(df_group.loc[i])
                item = _parse_postal_code(item)
                item = _check_mm2000_active(item)

                # ID extension for children and spouses,
                # since a family shares the same id
                item['household_id'] = item['id']
                if id_extension:
                    item['id'] += id_extension

                # Remove invalid birth dates
                item = _parse_mm2000_date(item, 'birth_date')
                item = _parse_mm2000_date(item, 'membership_date')

                # Skip if the member is under the minimum age
                # that we keep in the database
                too_young = utils.check_age(item['birth_date'], min_age=18)
                if too_young:
                    continue

                # Children only have a full name, not separate
                # first and last names
                if 'first_name' not in item and item['full_name']:
                    item['first_name'] = item['full_name'].split()[0]
                if 'last_name' not in item and item['full_name']:
                    item['last_name'] = item['full_name'].split()[-1]

                if not item['first_name'] or not item['last_name']:
                    continue
                else:
                    items.append(item)
        return items

    def check_columns(self):
        """Checks to make sure the columns are the same in the new table."""
        new_columns = self.database.get_columns('members')
        old_columns = self.database.get_columns('members_backup')
        for column in new_columns:
            if column not in old_columns:
                return False
        return True

    ###########################################
    # Methods for handling MM2000 resignations
    ###########################################

    def load_resignations(self, df):
        """Loads MM2000 resignation data into the database."""
        _validate_resignation_data(df)

        # Map the file column names to the database column names
        df = df.rename(columns=self.column_mapping['MM2000 Resignations'])
        # Drop any rows where the resignation date is null
        df = df.dropna(axis=0, how='any', subset=['resignation_date'])

        for i in df.index:
            member = dict(df.loc[i])
            member = _parse_mm2000_date(member, 'resignation_date')
            resignation_date = str(member['resignation_date'])[:10]

            # TODO: This logic is specific to TRS because that's how they
            # track people who rejoined the congregation. We may have to
            # update this if another client uses MM2000
            if 'Comment1' in member:
                if 'rejoin' in str(member['Comment1']).lower():
                    resignation_date = None
            if 'Comment2' in member:
                if 'rejoin' in str(member['Comment2']).lower():
                    resignation_date = None

            # Quote the date for SQL, or clear it for members who rejoined
            if resignation_date:
                resignation_date = "'{}'".format(resignation_date)
            else:
                resignation_date = 'NULL'

            sql = """
                UPDATE {schema}.members
                SET resignation_date = {resignation_date}
                WHERE (household_id = '{member_id}' OR id = '{member_id}')
            """.format(schema=self.database.schema,
                       resignation_date=resignation_date,
                       member_id=member['id'])
            self.database.run_query(sql)

            reason = _find_resignation_reason(member['resignation_reason'])
            sql = """
                UPDATE {schema}.members
                SET resignation_reason = '{reason}'
                WHERE (household_id = '{member_id}' OR id = '{member_id}')
            """.format(schema=self.database.schema,
                       reason=reason,
                       member_id=member['id'])
            self.database.run_query(sql)

        self.database.refresh_views()
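# A hedged usage sketch for MM2000 loads. The file paths are placeholders;
# the loader expects pandas DataFrames built from MM2000 exports.
mm2000 = MM2000()
members = pd.read_csv('mm2000_members_export.csv')  # assumed export path
if not mm2000.load(members):
    raise RuntimeError('Member upload reverted due to column mismatch')

resignations = pd.read_csv('mm2000_resignations.csv')  # assumed export path
mm2000.load_resignations(resignations)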
def refresh_views():
    """Refreshes the materialized views for the dashboard."""
    database = Database()
    LOGGER.info('Refreshing materialized views ...')
    database.refresh_views()
def __init__(self, database=None):
    daiquiri.setup(level=logging.INFO)
    self.logger = daiquiri.getLogger(__name__)
    self.database = database if database else Database()
class NameResolver:
    """Resolves the names of participants using participant characteristics."""
    def __init__(self, database=None):
        daiquiri.setup(level=logging.INFO)
        self.logger = daiquiri.getLogger(__name__)

        self.path = os.path.dirname(os.path.realpath(__file__))
        self.database = Database() if not database else database
        self.lookup = self._read_names_file()
        self.average_age = None

    def load_member_ids(self):
        """Loads member information into the participant match table.
        Only loads names that have not already been loaded into the
        database."""
        sql = """
            INSERT INTO {schema}.participant_match
            (id, member_id, first_name, last_name, nickname, email,
             birth_date, is_birth_date_estimated)
            SELECT uuid_generate_v4(), id as member_id, first_name,
                   last_name, nickname, email, birth_date, false
            FROM {schema}.members
            WHERE id NOT IN (SELECT member_id
                             FROM {schema}.participant_match)
        """.format(schema=self.database.schema)
        self.database.run_query(sql)

    def get_fuzzy_matches(self, first_name, last_name, tolerance=1):
        """Returns all names from the participants table that are within
        edit distance tolerance of the first name and last name."""
        # Add PostgreSQL escape characters
        first_name = first_name.replace("'", "''")
        last_name = last_name.replace("'", "''")

        select, conditions = self._first_name_sql(first_name, tolerance)
        sql = """
            SELECT id, member_id, first_name, last_name, nickname,
                   email, birth_date, is_birth_date_estimated
            FROM (
                SELECT *, {select}
                FROM {schema}.participant_match
            ) x
            WHERE (({conditions}) AND last_name = '{last_name}')
        """.format(select=select, conditions=conditions,
                   schema=self.database.schema, last_name=last_name)
        df = pd.read_sql(sql, self.database.connection)
        results = self.database.to_json(df)
        return results

    def find_best_match(self, first_name, last_name, email=None, age=None):
        """Finds the best match, given the criteria that are provided.
        If there are no matches, None is returned."""
        matches = self.get_fuzzy_matches(first_name, last_name)
        if not self.average_age:
            self.average_age = self._get_average_age()
        if not matches:
            return None

        for match in matches:
            if not match['birth_date'] or match['birth_date'] < 0:
                match['age'] = self.average_age
            else:
                match['age'] = compute_age(match['birth_date'])
            match_score = compute_match_score(match,
                                              first_name=first_name,
                                              email=email,
                                              age=age)
            match['match_score'] = match_score

        sorted_matches = sorted(matches, key=lambda k: k['match_score'],
                                reverse=True)
        return sorted_matches[0]

    def _get_average_age(self):
        """Pulls the average participant age. Used when an observation
        does not have an age recorded."""
        sql = """
            SELECT AVG(age) as avg_age
            FROM (
                SELECT DATE_PART('year', AGE(now(), birth_date)) as age
                FROM {schema}.participant_match
                WHERE birth_date is not null
            ) x
        """.format(schema=self.database.schema)
        df = pd.read_sql(sql, self.database.connection)
        avg_age = None
        if len(df) > 0:
            avg_age = df.loc[0]['avg_age']
        return avg_age

    def _read_names_file(self):
        """Reads names.csv, which contains mappings of names to nicknames."""
        filename = os.path.join(self.path, 'names.csv')
        lookup = collections.defaultdict(list)
        with open(filename) as f:
            reader = csv.reader(f)
            for line in reader:
                matches = set(line)
                for match in matches:
                    lookup[match].append(matches)
        return lookup

    def _lookup_name(self, name):
        """Generates a set of equivalent nicknames."""
        name = name.lower()
        if name not in self.lookup:
            return {name}
        names = functools.reduce(operator.or_, self.lookup[name])
        names.add(name)
        return names

    def _first_name_sql(self, first_name, tolerance=1):
        """Generates the select and where statements for the first name
        fuzzy match."""
        nicknames = self._lookup_name(first_name)
        first_name_selects = []
        first_name_conditions = []
        for i, name in enumerate(nicknames):
            col_name = "match_first_name_{}".format(i)
            select = " lower('{}') as {} ".format(name, col_name)
            first_name_selects.append(select)
            edit_distance = """
                (levenshtein(lower(first_name), {col}) <= {tolerance}
                 OR levenshtein(lower(nickname), {col}) <= {tolerance})
            """.format(col=col_name, tolerance=tolerance)
            first_name_conditions.append(edit_distance)
        name_select = ", ".join(first_name_selects)
        name_conditions = " OR ".join(first_name_conditions)
        return name_select, name_conditions
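# A hedged usage sketch for NameResolver: resolve an attendee to a known
# participant. The name, email, and age are fabricated for illustration.
resolver = NameResolver()
best = resolver.find_best_match(first_name='Robert', last_name='Smith',
                                email='bob.smith@example.com', age=42)
if best:
    print(best['id'], best['match_score'])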