Example #1
    def _insert_new_locations(self):
        """Checks for new city/country combinations and appends them to the geographic
        data table in mysql.
        """
        limit = self.test_limit if self.test else None
        with db_session(self.engine) as session:
            existing_location_ids = {
                i[0]
                for i in session.query(Geographic.id).all()
            }
            new_locations = []
            for city, country, key in (session.query(
                    self.city_col, self.country_col,
                    self.location_key_col).distinct(
                        self.location_key_col).limit(limit)):
                if key not in existing_location_ids and key is not None:
                    logging.info(f"new location {city}, {country}")
                    new_locations.append(
                        dict(id=key, city=city, country=country))
                    existing_location_ids.add(key)

        if new_locations:
            logging.warning(
                f"Adding {len(new_locations)} new locations to database")
            insert_data(self.db_config_env, "mysqldb", self.database, Base,
                        Geographic, new_locations)
Example #2
    def tests_insert_and_exists(self):
        data = [
            {
                "_id": 10,
                "_another_id": 2,
                "some_field": 20
            },
            {
                "_id": 10,
                "_another_id": 2,
                "some_field": 30
            },  # <--- Dupe pk, so should be ignored
            {
                "_id": 20,
                "_another_id": 2,
                "some_field": 30
            }
        ]
        objs = insert_data("MYSQLDBCONF", "mysqldb", "production_tests", Base,
                           DummyModel, data)
        self.assertEqual(len(objs), 2)

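        # A second insert of the same data should return no new objects
        # (all primary keys already exist in the table)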
        objs = insert_data("MYSQLDBCONF", "mysqldb", "production_tests", Base,
                           DummyModel, data)
        self.assertEqual(len(objs), 0)
Example #3
    def run(self):
        data = extract_data(limit=1000 if self.test else None)
        logging.info(f'Got {len(data)} rows')
        database = 'dev' if self.test else 'production'
        for chunk in split_batches(data, 10000):
            logging.info(f'Inserting chunk of size {len(chunk)}')
            insert_data('MYSQLDB', 'mysqldb', database,
                        Base, ApplnFamily, chunk, low_memory=True)
        self.output().touch()
Example #4
    def run(self):
        db = 'production' if not self.test else 'dev'

        keys = self.get_abstract_file_keys(bucket, key_prefix)
        
        engine = get_mysql_engine(self.db_config_env, 'mysqldb', db)
        with db_session(engine) as session:
            
            # In test mode, restrict to projects already present in the Projects table
            existing_projects = None
            if self.test:
                existing_projects = set()
                projects = session.query(Projects.application_id).distinct()
                for p in projects:
                    existing_projects.add(int(p.application_id))

            # Projects which already have mesh terms linked to them
            projects_done = set()
            projects_mesh = session.query(ProjectMeshTerms.project_id).distinct()
            for p in projects_mesh:
                projects_done.add(int(p.project_id))
            
            mesh_term_ids = {int(m.id) for m in session.query(MeshTerms.id).all()}

        logging.info('Inserting associations')
        
        for key_count, key in enumerate(keys):
            if self.test and (key_count > 2):
                continue
            # collect mesh results from the s3 file and group them by project id;
            # each project id maps to a set of mesh terms and corresponding term ids
            df_mesh = retrieve_mesh_terms(bucket, key)
            project_terms = self.format_mesh_terms(df_mesh)
            # go through documents
            for project_count, (project_id, terms) in enumerate(project_terms.items()):
                rows = []
                if self.test and (project_count > 2):
                    continue
                if project_id in projects_done:
                    continue
                # in test mode, skip projects that are not in the Projects table
                if existing_projects is not None and project_id not in existing_projects:
                    continue

                for term, term_id in zip(terms['terms'], terms['ids']):
                    term_id = int(term_id)
                    # add term to mesh term table if not present
                    if term_id not in mesh_term_ids:
                        objs = insert_data(
                                self.db_config_env, 'mysqldb', db, Base, MeshTerms, 
                                [{'id': term_id, 'term': term}], low_memory=True)
                        mesh_term_ids.update({term_id})
                    # prepare row to be added to project-mesh_term link table
                    rows.append({'project_id': project_id, 'mesh_term_id': term_id})
                # insert rows into the link table
                insert_data(self.db_config_env, 'mysqldb', db, Base,
                            ProjectMeshTerms, rows, low_memory=True)
        self.output().touch() # populate project-mesh_term link table
Example #5
    def run(self):
        """Collect and process organizations, categories and long descriptions."""

        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)
        try_until_allowed(Base.metadata.create_all, self.engine)
        limit = 2000 if self.test else None
        batch_size = 30 if self.test else 1000

        with db_session(self.engine) as session:
            all_orgs = session.query(
                Organisation.id, Organisation.addresses).limit(limit).all()
            existing_org_location_ids = session.query(
                OrganisationLocation.id).all()
        logging.info(f"{len(all_orgs)} organisations retrieved from database")
        logging.info(
            f"{len(existing_org_location_ids)} organisations have previously been processed"
        )

        # convert to a list of dictionaries with the nested addresses unpacked
        orgs = get_orgs_to_process(all_orgs, existing_org_location_ids)
        logging.info(f"{len(orgs)} new organisations to geocode")

        total_batches = ceil(len(orgs) / batch_size)
        logging.info(f"{total_batches} batches")
        completed_batches = 0
        for batch in split_batches(orgs, batch_size=batch_size):
            # geocode first to add missing country for UK
            batch = map(geocode_uk_with_postcode, batch)
            batch = map(add_country_details, batch)

            # remove data not in OrganisationLocation columns
            org_location_cols = OrganisationLocation.__table__.columns.keys()
            batch = [{k: v
                      for k, v in org.items() if k in org_location_cols}
                     for org in batch]

            insert_data(self.db_config_env, 'mysqldb', database, Base,
                        OrganisationLocation, batch)
            completed_batches += 1
            logging.info(
                f"Completed {completed_batches} of {total_batches} batches")

            if self.test and completed_batches > 1:
                logging.warning("Breaking after 2 batches in test mode")
                break

        # mark as done
        logging.warning("Finished task")
        self.output().touch()
Example #6
File: run.py Project: yitzikc/nesta
def run():
    batch_file = os.environ['BATCHPAR_batch_file']
    bucket = os.environ['BATCHPAR_bucket']
    db_name = os.environ['BATCHPAR_db_name']
    db_env = "BATCHPAR_config"
    db_section = "mysqldb"

    # Setup the database connectors
    engine = get_mysql_engine(db_env, db_section, db_name)
    try_until_allowed(Base.metadata.create_all, engine)

    # Retrieve RCNs to iterate over
    s3 = boto3.resource('s3')
    obj = s3.Object(bucket, batch_file)
    all_rcn = json.loads(obj.get()['Body']._raw_stream.read())
    logging.info(f"{len(all_rcn)} project RCNs retrieved from s3")

    # Retrieve and collect data for each project RCN
    data = defaultdict(list)
    for i, rcn in enumerate(all_rcn):
        logging.info(i)
        project, orgs, reports, pubs = fetch_data(rcn)
        if project is None:
            continue
        _topics = project.pop('topics')
        _calls = project.pop('proposal_call')
        # NB: Order below matters due to FK constraints!
        data['projects'].append(project)
        data['reports'] += prepare_data(reports, rcn)
        data['publications'] += prepare_data(pubs, rcn)
        data['organisations'] += extract_core_orgs(orgs, rcn)
        data['project_organisations'] += prepare_data(orgs, rcn)
        for topics, project_topics in split_links(_topics, rcn):
            data['topics'].append(topics)
            data['project_topics'].append(project_topics)
        for calls, project_calls in split_links(_calls, rcn):
            data['proposal_calls'].append(calls)
            data['project_proposal_calls'].append(project_calls)

    # Pipe the data to the db
    for table_prefix, rows in data.items():
        table_name = f'cordis_{table_prefix}'
        logging.info(table_name)
        _class = get_class_by_tablename(Base, table_name)
        insert_data(db_env,
                    db_section,
                    db_name,
                    Base,
                    _class,
                    rows,
                    low_memory=True)
Example #7
File: run.py Project: yitzikc/nesta
def run():
    test = literal_eval(os.environ["BATCHPAR_test"])
    db_name = os.environ["BATCHPAR_db_name"]
    batch_size = int(os.environ["BATCHPAR_batch_size"])  # example parameter
    s3_path = os.environ["BATCHPAR_outinfo"]
    start_string = os.environ["BATCHPAR_start_string"]  # example parameter
    offset = int(os.environ["BATCHPAR_offset"])

    # reduce records in test mode
    if test:
        limit = 50
        logging.info(f"Limiting to {limit} rows in test mode")
    else:
        limit = batch_size

    logging.info(f"Processing {offset} - {offset + limit}")

    # database setup
    logging.info(f"Using {db_name} database")
    engine = get_mysql_engine("BATCHPAR_config", "mysqldb", db_name)
    try_until_allowed(Base.metadata.create_all, engine)

    with db_session(engine) as session:
        # consider moving this query and the one from the prepare step into a package
        batch_records = (session.query(MyTable.id, MyTable.name).filter(
            MyTable.founded_on > '2007-01-01').offset(offset).limit(limit))
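    # NB: the query object above is lazy; it is executed when iterated over below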

    # process and insert data
    processed_batch = []
    for row in batch_records:
        processed_row = some_func(start_string=start_string, row=row)
        processed_batch.append(processed_row)

    logging.info(f"Inserting {len(processed_batch)} rows")
    insert_data("BATCHPAR_config",
                'mysqldb',
                db_name,
                Base,
                MyOtherTable,
                processed_batch,
                low_memory=True)

    logging.info(f"Marking task as done to {s3_path}")
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")

    logging.info("Batch job complete.")
Example #8
    def run(self):

        # Get all UK geographies, and group by country and base
        gss_codes = get_gss_codes()
        country_codes = defaultdict(lambda: defaultdict(list))
        for code in gss_codes:
            country = code[0]
            base = code[0:3]
            # Shortened test mode
            if not self.production and base not in ("S32", "S23"):
                continue
            country_codes[country][base].append(code)

        # Iterate through country and base
        output = []
        for country, base_codes in country_codes.items():
            # Try to find children for each base...
            for base in base_codes.keys():
                for base_, codes in base_codes.items():
                    # ...except for the base of the parent
                    if base == base_:
                        continue
                    output += get_children(base, codes)

        # Write to database
        _class = get_class_by_tablename(Base, "onsOpenGeo_geographic_lookup")
        objs = insert_data(MYSQLDB_ENV, "mysqldb",
                           "production" if self.production else "dev", Base,
                           _class, output)
        self.output().touch()
Example #9
File: run.py Project: hmessafi/nesta
def run():
    logging.getLogger().setLevel(logging.INFO)

    # Fetch the input parameters
    group_urlnames = literal_eval(os.environ["BATCHPAR_group_urlnames"])
    group_urlnames = [x.decode("utf8") for x in group_urlnames]
    s3_path = os.environ["BATCHPAR_outinfo"]
    db = os.environ["BATCHPAR_db"]

    # Generate the groups for these members
    _output = []
    for urlname in group_urlnames:
        _info = get_group_details(urlname, max_results=200)
        if len(_info) == 0:
            continue
        _output.append(_info)
    logging.info("Processed %s groups", len(_output))

    # Flatten the output
    output = flatten_data(_output,
                          keys=[('category', 'name'), ('category',
                                                       'shortname'),
                                ('category', 'id'), 'created', 'country',
                                'city', 'description', 'id', 'lat', 'lon',
                                'members', 'name', 'topics', 'urlname'])

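    # NB: only the single flattened group at index 48 is inserted here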
    objs = insert_data("BATCHPAR_config", "mysqldb", db, Base, Group,
                       output[48:49])

    # Mark the task as done
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")

    return len(objs)
Example #10
File: run.py Project: hmessafi/nesta
def run():
    logging.getLogger().setLevel(logging.INFO)
    
    # Fetch the input parameters
    member_ids = literal_eval(os.environ["BATCHPAR_member_ids"])
    s3_path = os.environ["BATCHPAR_outinfo"]
    db = os.environ["BATCHPAR_db"]

    # Generate the groups for these members
    output = []
    for member_id in member_ids:
        response = get_member_details(member_id, max_results=200)
        output += get_member_groups(response)
    logging.info("Got %s groups", len(output))
    
    # Load connection to the db, and create the tables
    objs = insert_data("BATCHPAR_config", "mysqldb", db,
                       Base, GroupMember, output)
    logging.info("Inserted %s groups", len(objs))
    
    # Mark the task as done
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")

    return len(objs)
Example #11
    def run(self):
        config, geogs_list, dataset_id, date_format = process_config(self.config_name,
                                                                     test=not self.production)
        for igeo, geographies in enumerate(geogs_list):
            if igeo == 0:
                continue
            logging.debug(f"Geography number {igeo}")
            done = False
            record_offset = 0
            while not done:
                logging.debug(f"\tOffset of {record_offset}")
                df, done, record_offset = batch_request(config, dataset_id, geographies,
                                                        date_format, max_api_calls=10,
                                                        record_offset=record_offset)
                data = {self.config_name: df}
                tables = reformat_nomis_columns(data)
                for name, table in tables.items():
                    name = name.split('-sic')[0]  # strip any SIC-code suffix from the name
                    logging.debug(f"\t\tInserting {len(table)} into nomis_{name}...")
                    _class = get_class_by_tablename(Base, f"nomis_{name}")
                    objs = insert_data(MYSQLDB_ENV, "mysqldb",
                                       "production" if self.production else "dev",
                                       Base, _class, table, low_memory=True)
                    logging.debug(f"\t\tInserted {len(objs)}")
                    
        #data = get_nomis_data(self.config_name, test=not self.production)   
        #tables = reformat_nomis_columns({self.config_name:data})

        self.output().touch()
Example #12
    def test_object_to_dict(self):
        parents = [{
            "_id": 10,
            "_another_id": 2,
            "some_field": 20
        }, {
            "_id": 20,
            "_another_id": 2,
            "some_field": 20
        }]
        _parents = insert_data("MYSQLDBCONF", "mysqldb", "production_tests",
                               Base, DummyModel, parents)
        assert len(parents) == len(_parents)

        children = [{
            "_id": 10,
            "parent_id": 10
        }, {
            "_id": 10,
            "parent_id": 20
        }, {
            "_id": 20,
            "parent_id": 20
        }, {
            "_id": 30,
            "parent_id": 20
        }]
        _children = insert_data("MYSQLDBCONF", "mysqldb", "production_tests",
                                Base, DummyChild, children)
        assert len(children) == len(_children)

        # Re-retrieve parents from the database
        found_children = set()
        engine = get_mysql_engine("MYSQLDBCONF", "mysqldb")
        with db_session(engine) as session:
            for p in session.query(DummyModel).all():
                row = object_to_dict(p)
                assert type(row) is dict
                assert len(row['children']) > 0
                _found_children = set(
                    (c['_id'], c['parent_id']) for c in row['children'])
                found_children = found_children.union(_found_children)
                _row = object_to_dict(p, shallow=True)
                assert 'children' not in _row
                del row['children']
                assert row == _row
            assert len(found_children) == len(children) == len(_children)
Example #13
File: run.py Project: hmessafi/nesta
def run():
    test = literal_eval(os.environ["BATCHPAR_test"])
    db_name = os.environ["BATCHPAR_db_name"]
    table = os.environ["BATCHPAR_table"]
    batch_size = int(os.environ["BATCHPAR_batch_size"])
    s3_path = os.environ["BATCHPAR_outinfo"]

    logging.warning(f"Processing {table} file")

    # database setup
    engine = get_mysql_engine("BATCHPAR_config", "mysqldb", db_name)
    try_until_allowed(Base.metadata.create_all, engine)
    table_name = f"crunchbase_{table}"
    table_class = get_class_by_tablename(Base, table_name)

    # collect file
    nrows = 1000 if test else None
    df = get_files_from_tar([table], nrows=nrows)[0]
    logging.warning(f"{len(df)} rows in file")

    # get primary key fields and set of all those already existing in the db
    pk_cols = list(table_class.__table__.primary_key.columns)
    pk_names = [pk.name for pk in pk_cols]
    with db_session(engine) as session:
        existing_rows = set(session.query(*pk_cols).all())

    # process and insert data
    processed_rows = process_non_orgs(df, existing_rows, pk_names)
    for batch in split_batches(processed_rows, batch_size):
        insert_data("BATCHPAR_config",
                    'mysqldb',
                    db_name,
                    Base,
                    table_class,
                    batch,
                    low_memory=True)

    logging.warning(f"Marking task as done to {s3_path}")
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")

    logging.warning("Batch job complete.")
Example #14
    def run(self):
        '''Run the data collection'''
        #engine = get_mysql_engine(MYSQLDB_ENV, "mysqldb",
        #                          self.db_config['database'])
        #Base.metadata.create_all(engine)
        #Session = sessionmaker(engine)
        #session = Session()
        wiki_date = find_latest_wikidump()
        ngrams = extract_ngrams(wiki_date)
        if self.test:
            ngrams = list(ngrams)[0:100]
        #for n in ngrams:
        #    ngram = WiktionaryNgram(ngram=n)
        #    session.add(ngram)
        #session.commit()
        #session.close()
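        # bulk insert the n-grams in one call, rather than via the
        # session-based approach commented out above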
        insert_data(MYSQLDB_ENV, "mysqldb", self.db_config['database'], Base,
                    WiktionaryNgram, [dict(ngram=n) for n in ngrams])

        self.output().touch()
Example #15
    def _insert_new_locations_no_id(self):
        """Checks for new city/country combinations and appends them to the geographic
        data table in mysql IF NO location_key_col IS PROVIDED.
        """
        limit = self.test_limit if self.test else None
        with db_session(self.engine) as session:
            existing_location_ids = {
                i[0]
                for i in session.query(Geographic.id).all()
            }
            new_locations = []
            all_locations = {(city, country)
                             for city, country in (session.query(
                                 self.city_col, self.country_col).limit(limit))
                             }
            nulls = []
            for city, country in all_locations:
                if self.country_is_iso2:
                    country = country_iso_code_to_name(country, iso2=True)
                if city is None or country is None:
                    nulls.append((city, country))
                    continue
                key = generate_composite_key(city, country)
                if key not in existing_location_ids and key is not None:
                    logging.info(f"new location {city}, {country}")
                    new_locations.append(
                        dict(id=key, city=city, country=country))
                    existing_location_ids.add(key)

        if len(nulls) > 0:
            logging.warning(f"{len(nulls)} locations had a null city or "
                            "country, so won't be processed.")
            logging.warning(nulls)
        if new_locations:
            logging.warning(
                f"Adding {len(new_locations)} new locations to database")
            insert_data(self.db_config_env, "mysqldb", self.database, Base,
                        Geographic, new_locations)
Example #16
    def test_db_session_query(self):
        parents = [{
            "_id": i,
            "_another_id": i,
            "some_field": 20
        } for i in range(0, 1000)]
        _parents = insert_data("MYSQLDBCONF", "mysqldb", "production_tests",
                               Base, DummyModel, parents)

        # Re-retrieve parents from the database
        engine = get_mysql_engine("MYSQLDBCONF", "mysqldb")

        # Test for limit = 3
        limit = 3
        old_db = mock.MagicMock()
        old_db.is_active = False
        n_rows = 0
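        # when the yielded session changes between chunks, the previous session
        # should have released its connections before the new one is used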
        for db, row in db_session_query(query=DummyModel,
                                        engine=engine,
                                        chunksize=10,
                                        limit=limit):
            assert type(row) is DummyModel
            if old_db != db:
                assert len(old_db.transaction._connections) == 0
                assert len(db.transaction._connections) > 0
            old_db = db
            n_rows += 1
        assert n_rows == limit

        # Test for limit = None
        old_db = mock.MagicMock()
        old_db.is_active = False
        n_rows = 0
        for db, row in db_session_query(query=DummyModel,
                                        engine=engine,
                                        chunksize=100,
                                        limit=None):
            assert type(row) is DummyModel
            if old_db != db:
                assert len(old_db.transaction._connections) == 0
                assert len(db.transaction._connections) > 0
            old_db = db
            n_rows += 1
        assert n_rows == len(parents) == 1000
Example #17
File: run.py Project: yitzikc/nesta
def run():
    logging.getLogger().setLevel(logging.INFO)

    # Fetch the input parameters
    group_urlname = os.environ["BATCHPAR_group_urlname"]
    group_id = os.environ["BATCHPAR_group_id"]
    s3_path = os.environ["BATCHPAR_outinfo"]
    db = os.environ["BATCHPAR_db"]

    # Collect members
    logging.info("Getting %s", group_urlname)
    output = get_all_members(group_id, group_urlname, max_results=200)
    logging.info("Got %s members", len(output))

    # Load connection to the db, and create the tables
    objs = insert_data("BATCHPAR_config", "mysqldb", db, Base, GroupMember,
                       output)
    # Mainly for testing
    return len(objs)
Example #18
File: run.py Project: yitzikc/nesta
def run():
    PAGE_SIZE = int(os.environ['BATCHPAR_PAGESIZE'])
    page = int(os.environ['BATCHPAR_page'])
    db = os.environ["BATCHPAR_db"]
    s3_path = os.environ["BATCHPAR_outinfo"]

    data = defaultdict(list)

    # Get all projects on this page
    projects = read_xml_from_url(TOP_URL, p=page, s=PAGE_SIZE)
    for project in projects.getchildren():        
        # Extract the data for the project into 'row'
        # Then recursively extract data from nested rows into the parent 'row'
        _, row = extract_data(project)
        extract_data_recursive(project, row)
        # Flatten out any list data directly into 'data'
        unpack_list_data(row, data)
        # Append the row
        data[row.pop('entity')].append(row)

    # Much of the participant data is repeated so remove overlaps
    if 'participant' in data:
        deduplicate_participants(data)
    # Finally, extract links between entities and the core projects
    extract_link_table(data)
    
    objs = []
    for table_name, rows in data.items():
        _class = get_class_by_tablename(Base, f"gtr_{table_name}")
        # Remove any fields that aren't in the ORM
        cleaned_rows = [{k:v for k, v in row.items() if k in _class.__dict__}
                        for row in rows]
        objs += insert_data("BATCHPAR_config", "mysqldb", db,
                            Base, _class, cleaned_rows)

    # Mark the task as done
    if s3_path != "":
        s3 = boto3.resource('s3')
        s3_obj = s3.Object(*parse_s3_path(s3_path))
        s3_obj.put(Body="")

    return len(objs)
Example #19
File: run.py Project: hmessafi/nesta
def run():
    logging.getLogger().setLevel(logging.INFO)

    # Fetch the input parameters
    iso2 = os.environ["BATCHPAR_iso2"]
    name = os.environ["BATCHPAR_name"]
    category = os.environ["BATCHPAR_cat"]
    coords = literal_eval(os.environ["BATCHPAR_coords"])
    radius = float(os.environ["BATCHPAR_radius"])
    s3_path = os.environ["BATCHPAR_outinfo"]
    db = os.environ["BATCHPAR_db"]

    # Get the data
    mcg = MeetupCountryGroups(country_code=iso2,
                              category=category,
                              coords=coords,
                              radius=radius)
    mcg.get_groups_recursive()
    output = flatten_data(mcg.groups,
                          country_name=name,
                          country=iso2,
                          timestamp=func.utc_timestamp(),
                          keys=[('category', 'name'), ('category',
                                                       'shortname'),
                                ('category', 'id'), 'description', 'created',
                                'country', 'city', 'id', 'lat', 'lon',
                                'members', 'name', 'topics', 'urlname'])

    # Add the data
    objs = insert_data("BATCHPAR_config", "mysqldb", db, Base, Group, output)

    # Mark the task as done
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")

    # Mainly for testing
    return len(objs)
Example #20
    def run(self):
        # Load the input data (note the input contains the path
        # to the output)
        _filename = self.cherry_picked
        if _filename is None:
            _body = self.input().open("rb")
            _filename = _body.read().decode('utf-8')
        obj = s3.S3Target(f"{self.raw_s3_path_prefix}/"
                          f"{_filename}").open('rb')
        data = json.load(obj)

        # Get DB connections and settings
        database = 'dev' if self.test else 'production'
        engine = get_mysql_engine(self.db_conf_env, 'mysqldb', database)
        ArticleTopic.__table__.drop(engine)
        CorExTopic.__table__.drop(engine)

        # Insert the topic names data
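        # topic ids are taken from the trailing index in the topic name, shifted to be 1-based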
        topics = [{
            'id': int(topic_name.split('_')[-1]) + 1,
            'terms': terms
        } for topic_name, terms in data['data']['topic_names'].items()]
        insert_data(self.db_conf_env,
                    'mysqldb',
                    database,
                    Base,
                    CorExTopic,
                    topics,
                    low_memory=True)
        logging.info(f'Inserted {len(topics)} topics')

        # Insert article topic weight data
        topic_articles = []
        done_ids = set()
        for row in data['data']['rows']:
            article_id = row.pop('id')
            if article_id in done_ids:
                continue
            done_ids.add(article_id)
            topic_articles += [{
                'topic_id': int(topic_name.split('_')[-1]) + 1,
                'topic_weight': weight,
                'article_id': article_id
            } for topic_name, weight in row.items()]
            # Flush
            if len(topic_articles) > self.insert_batch_size:
                insert_data(self.db_conf_env,
                            'mysqldb',
                            database,
                            Base,
                            ArticleTopic,
                            topic_articles,
                            low_memory=True)
                topic_articles = []

        # Final flush
        if len(topic_articles) > 0:
            insert_data(self.db_conf_env,
                        'mysqldb',
                        database,
                        Base,
                        ArticleTopic,
                        topic_articles,
                        low_memory=True)

        # Touch the output
        self.output().touch()
Example #21
    def run(self):
        """Collect and process organizations, categories and long descriptions."""

        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)
        try_until_allowed(Base.metadata.create_all, self.engine)

        # collect files
        nrows = 200 if self.test else None
        cat_groups, orgs, org_descriptions = get_files_from_tar(
            ['category_groups', 'organizations', 'organization_descriptions'],
            nrows=nrows)
        # process category_groups
        cat_groups = rename_uuid_columns(cat_groups)
        insert_data(self.db_config_env,
                    'mysqldb',
                    database,
                    Base,
                    CategoryGroup,
                    cat_groups.to_dict(orient='records'),
                    low_memory=True)

        # process organizations and categories
        with db_session(self.engine) as session:
            existing_orgs = session.query(Organization.id).all()
        existing_orgs = {org[0] for org in existing_orgs}

        logging.info("Summary of organisation data:")
        logging.info(f"Total number of organisations:\t {len(orgs)}")
        logging.info(
            f"Number of organisations already in the database:\t {len(existing_orgs)}"
        )
        logging.info(f"Number of category groups and text descriptions:\t"
                     f"{len(cat_groups)}, {len(org_descriptions)}")

        processed_orgs, org_cats, missing_cat_groups = process_orgs(
            orgs, existing_orgs, cat_groups, org_descriptions)
        # Insert CatGroups
        insert_data(self.db_config_env, 'mysqldb', database, Base,
                    CategoryGroup, missing_cat_groups)
        # Insert orgs in batches
        n_batches = round(len(processed_orgs) / self.insert_batch_size)
        logging.info(
            f"Inserting {n_batches} batches of size {self.insert_batch_size}")
        for i, batch in enumerate(
                split_batches(processed_orgs, self.insert_batch_size)):
            if i % 100 == 0:
                logging.info(f"Inserting batch {i} of {n_batches}")
            insert_data(self.db_config_env,
                        'mysqldb',
                        database,
                        Base,
                        Organization,
                        batch,
                        low_memory=True)

        # link table needs to be inserted via non-bulk method to enforce relationship
        logging.info("Filtering duplicates...")
        org_cats, existing_org_cats, failed_org_cats = filter_out_duplicates(
            self.db_config_env,
            'mysqldb',
            database,
            Base,
            OrganizationCategory,
            org_cats,
            low_memory=True)
        logging.info(
            f"Inserting {len(org_cats)} org categories "
            f"({len(existing_org_cats)} already existed and {len(failed_org_cats)} failed)"
        )
        #org_cats = [OrganizationCategory(**org_cat) for org_cat in org_cats]
        with db_session(self.engine) as session:
            session.add_all(org_cats)

        # mark as done
        self.output().touch()
Example #22
File: run.py Project: yitzikc/nesta
def run():
    db_name = os.environ["BATCHPAR_db_name"]
    s3_path = os.environ["BATCHPAR_outinfo"]
    start_cursor = int(os.environ["BATCHPAR_start_cursor"])
    end_cursor = int(os.environ["BATCHPAR_end_cursor"])
    batch_size = end_cursor - start_cursor
    logging.warning(f"Retrieving {batch_size} articles between {start_cursor - 1}:{end_cursor - 1}")

    # Setup the database connectors
    engine = get_mysql_engine("BATCHPAR_config", "mysqldb", db_name)
    try_until_allowed(Base.metadata.create_all, engine)

    # load arxiv subject categories to database
    bucket = 'innovation-mapping-general'
    cat_file = 'arxiv_classification/arxiv_subject_classifications.csv'
    load_arxiv_categories("BATCHPAR_config", db_name, bucket, cat_file)

    # process data
    articles = []
    article_cats = []
    resumption_token = request_token()
    for row in retrieve_arxiv_batch_rows(start_cursor, end_cursor, resumption_token):
        with db_session(engine) as session:
            categories = row.pop('categories', [])
            articles.append(row)
            for cat in categories:
                # TODO: this is inefficient and should be queried once into a set;
                # see the iterative process.
                try:
                    session.query(Category).filter(Category.id == cat).one()
                except NoResultFound:
                    logging.warning(f"missing category: '{cat}' for article {row['id']}.  Adding to Category table")
                    session.add(Category(id=cat))
                article_cats.append(dict(article_id=row['id'], category_id=cat))

    inserted_articles, existing_articles, failed_articles = insert_data(
                                                "BATCHPAR_config", "mysqldb", db_name,
                                                Base, Article, articles,
                                                return_non_inserted=True)
    logging.warning(f"total article categories: {len(article_cats)}")
    inserted_article_cats, existing_article_cats, failed_article_cats = insert_data(
                                                "BATCHPAR_config", "mysqldb", db_name,
                                                Base, ArticleCategory, article_cats,
                                                return_non_inserted=True)

    # sanity checks before the batch is marked as done
    logging.warning((f'inserted articles: {len(inserted_articles)} ',
                     f'existing articles: {len(existing_articles)} ',
                     f'failed articles: {len(failed_articles)}'))
    logging.warning((f'inserted article categories: {len(inserted_article_cats)} ',
                     f'existing article categories: {len(existing_article_cats)} ',
                     f'failed article categories: {len(failed_article_cats)}'))
    if len(inserted_articles) + len(existing_articles) + len(failed_articles) != batch_size:
        raise ValueError(f'Inserted articles do not match original data.')
    if len(inserted_article_cats) + len(existing_article_cats) + len(failed_article_cats) != len(article_cats):
        raise ValueError(f'Inserted article categories do not match original data.')

    # Mark the task as done
    s3 = boto3.resource('s3')
    s3_obj = s3.Object(*parse_s3_path(s3_path))
    s3_obj.put(Body="")
Example #23
    def run(self):
        # s3 setup
        s3 = boto3.resource('s3')
        intermediate_file = s3.Object(BUCKET, INTERMEDIATE_FILE)

        # database setup
        database = 'dev' if self.test else 'production'
        logging.info(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)
        Base.metadata.create_all(self.engine)

        eu = get_eu_countries()
        logging.info(f"Retrieved {len(eu)} EU countries")

        with db_session(self.engine) as session:
            all_fos_ids = {f.id for f in (session
                                          .query(FieldOfStudy.id)
                                          .all())}
            logging.info(f"{len(all_fos_ids):,} fields of study in database")

            eu_grid_ids = {i.id for i in (session
                                          .query(Institute.id)
                                          .filter(Institute.country.in_(eu))
                                          .all())}
            logging.info(f"{len(eu_grid_ids):,} EU institutes in GRID")

        try:
            processed_grid_ids = set(json.loads(intermediate_file
                                                .get()['Body']
                                                ._raw_stream.read()))

            logging.info(f"{len(processed_grid_ids)} previously processed institutes")
            eu_grid_ids = eu_grid_ids - processed_grid_ids
            logging.info(f"{len(eu_grid_ids):,} institutes to process")
        except ClientError:
            logging.info("Unable to load file of processed institutes, starting from scratch")
            processed_grid_ids = set()

        if self.test:
            self.batch_size = 500
            batch_limit = 1
        else:
            batch_limit = None
        testing_finished = False

        row_count = 0
        for institute_count, grid_id in enumerate(eu_grid_ids):
            paper_ids, author_ids = set(), set()
            data = {Paper: [],
                    Author: [],
                    PaperAuthor: set(),
                    PaperFieldsOfStudy: set(),
                    PaperLanguage: set()}

            if not institute_count % 50:
                logging.info(f"{institute_count:,} of {len(eu_grid_ids):,} institutes processed")

            if not check_institute_exists(grid_id):
                logging.debug(f"{grid_id} not found in MAG")
                continue

            # these tables have data stored in sets for deduping so the fieldnames will
            # need to be added when converting to a list of dicts for loading to the db
            field_names_to_add = {PaperAuthor: ('paper_id', 'author_id'),
                                  PaperFieldsOfStudy: ('paper_id', 'field_of_study_id'),
                                  PaperLanguage: ('paper_id', 'language')}

            logging.info(f"Querying MAG for {grid_id}")
            for row in query_by_grid_id(grid_id,
                                        from_date=self.from_date,
                                        min_citations=self.min_citations,
                                        batch_size=self.batch_size,
                                        batch_limit=batch_limit):

                fos_id = row['fieldOfStudyId']
                if fos_id not in all_fos_ids:
                    logging.info(f"Getting missing field of study {fos_id} from MAG")
                    update_field_of_study_ids_sparql(self.engine, fos_ids=[fos_id])
                    all_fos_ids.add(fos_id)

                # the returned data is normalised and therefore contains many duplicates
                paper_id = row['paperId']
                if paper_id not in paper_ids:
                    data[Paper].append({'id': paper_id,
                                        'title': row['paperTitle'],
                                        'citation_count': row['paperCitationCount'],
                                        'created_date': row['paperCreatedDate'],
                                        'doi': row.get('paperDoi'),
                                        'book_title': row.get('bookTitle')})
                    paper_ids.add(paper_id)

                author_id = row['authorId']
                if author_id not in author_ids:
                    data[Author].append({'id': author_id,
                                         'name': row['authorName'],
                                         'grid_id': grid_id})
                    author_ids.add(author_id)

                data[PaperAuthor].add((row['paperId'], row['authorId']))

                data[PaperFieldsOfStudy].add((row['paperId'], row['fieldOfStudyId']))

                try:
                    data[PaperLanguage].add((row['paperId'], row['paperLanguage']))
                except KeyError:
                    # language is an optional field
                    pass

                row_count += 1
                if self.test and row_count >= 1000:
                    logging.warning("Breaking after 1000 rows in test mode")
                    testing_finished = True
                    break

            # write out to SQL
            for table, rows in data.items():
                if table in field_names_to_add:
                    rows = [{k: v for k, v in zip(field_names_to_add[table], row)}
                            for row in rows]
                logging.debug(f"Writing {len(rows):,} rows to {table.__table__.name}")

                for batch in split_batches(rows, self.insert_batch_size):
                    insert_data('MYSQLDB', 'mysqldb', database, Base, table, batch)

            # flag institute as completed on S3
            processed_grid_ids.add(grid_id)
            intermediate_file.put(Body=json.dumps(list(processed_grid_ids)))

            if testing_finished:
                break


        # mark as done
        logging.info("Task complete")
        self.output().touch()