コード例 #1
0
def results(id):
    """Show, or download as CSV, the output of a stored query.

    Looks up the row with primary key ``id`` in ``query`` (joined to its
    department), executes that row's stored SQL text, and then:

    * if the ``download`` request argument is non-blank, returns every
      result row as a CSV attachment named after the query;
    * otherwise renders ``results.html`` with the query details and the
      first five result rows.

    Returns a plain-text message when no query with that id exists.
    """
    from contextlib import closing

    select_stmt = "SELECT q.name, d.name as department, q.description, q.query " \
                  "FROM query q JOIN departments d ON q.department_id = d.id " \
                  "WHERE q.id=%s;"

    # psycopg2's ``with conn`` only commits/rolls back the transaction; it
    # does not close the connection.  ``closing`` releases it.  Both
    # statements share this one connection, each in its own transaction.
    with closing(RealDictConnection(dsn=local)) as conn:
        with conn, conn.cursor() as cursor:
            cursor.execute(select_stmt, (str(id), ))
            res = cursor.fetchone()
        if not res:
            return 'Query with id %s does not exist!' % str(id)

        # The stored SQL text is executed as-is (it is read from the
        # ``query`` table, not taken from this request).
        with conn, conn.cursor() as cursor:
            cursor.execute(res['query'])
            result = cursor.fetchall()
            # Fall back to the cursor description so a query that returns
            # no rows does not crash on ``result[0]``.
            if result:
                header = list(result[0].keys())
            else:
                header = [column[0] for column in cursor.description]

    if request.args.get('download', '').strip():
        si = StringIO()
        writer = csv.writer(si)
        writer.writerow(header)
        writer.writerows([row.values() for row in result])
        output = make_response(si.getvalue())
        # Quote the filename so query names containing spaces or ';' do not
        # break the header.
        output.headers["Content-Disposition"] = 'attachment; filename="%s.csv"' \
                                                % str(res['name']).replace('"', '')
        output.headers["Content-type"] = "text/csv"
        return output

    return render_template('results.html',
                           details=res,
                           rows=result[0:5],
                           id=id,
                           header=header)
コード例 #2
0
def update_course_sections(conn: RealDictConnection, semester_id: str, course_sections: List[CourseSection]):
    """Replace every course section (and its periods) stored for a semester.

    All existing ``course_section_periods`` and ``course_sections`` rows for
    ``semester_id`` are deleted and the given sections re-inserted on
    ``conn``'s current transaction.  The transaction is committed only when
    at least one section was supplied; otherwise it is rolled back so the
    deletes are not left pending on the connection.
    """
    with conn.cursor() as c:
        c.execute(
            "DELETE FROM course_section_periods WHERE semester_id=%s", (semester_id,))
        c.execute("DELETE FROM course_sections WHERE semester_id=%s", (semester_id,))

        print(f"Adding {len(course_sections)} sections...", flush=True)
        for course_section in course_sections:
            record = course_section.to_record()

            # Insert the section row itself
            q = Query \
                .into(course_sections_t) \
                .columns(*record.keys()) \
                .insert(*record.values())
            c.execute(str(q))

            # Insert all of the section's periods in one multi-row INSERT
            if course_section.periods:
                # NOTE(review): column names come from ``period.dict()`` but
                # values from ``period.to_record()``; confirm both produce
                # keys in the same order.
                q = Query \
                    .into(periods_t) \
                    .columns(*course_section.periods[0].dict().keys())

                for period in course_section.periods:
                    q = q.insert(*period.to_record().values())

                c.execute(str(q))

    if course_sections:
        conn.commit()
        print("Done!", flush=True)
    else:
        # Actually discard the DELETEs; otherwise a later commit on this
        # connection would still wipe the semester's sections.
        conn.rollback()
        print("No course sections found... rolling back any changes")
コード例 #3
0
ファイル: db.py プロジェクト: 123joshuawu/orca
def update_course_sections(conn: RealDictConnection, semester_id: str,
                           course_sections: List[CourseSection]):
    """Replace all course sections (and their periods) for a semester.

    Deletes the semester's existing ``course_section_periods`` and
    ``course_sections`` rows, inserts every section in ``course_sections``
    together with its periods, then commits.  Each generated INSERT is
    printed before it is executed.
    """
    with conn.cursor() as c:
        c.execute("DELETE FROM course_section_periods WHERE semester_id=%s",
                  (semester_id, ))
        c.execute("DELETE FROM course_sections WHERE semester_id=%s",
                  (semester_id, ))

        print(f"Adding {len(course_sections)} sections")
        for course_section in course_sections:
            record = course_section.to_record()

            # Insert the section row itself
            q = Query \
                .into(course_sections_t) \
                .columns(*record.keys()) \
                .insert(*record.values())
            print(str(q))
            c.execute(str(q))

            # Insert all of the section's periods in one multi-row INSERT
            if course_section.periods:
                q = Query \
                    .into(periods_t) \
                    .columns(*course_section.periods[0].dict().keys())

                for period in course_section.periods:
                    q = q.insert(*period.to_record().values())

                # Work around pypika rendering an empty list as ``ARRAY[]``
                # (an empty array with no element type) by substituting the
                # untyped literal ``'{}'``.  NOTE: this is a plain text
                # replace, so it would also hit "ARRAY[]" inside a string value.
                # https://github.com/kayak/pypika/issues/527
                workaround = str(q).replace("ARRAY[]", "'{}'")
                print(workaround)
                c.execute(workaround)
    conn.commit()
コード例 #4
0
 def setUp(self):
     """Reinstall the schema and test data, then open a RealDictConnection
     to the node database as ``self._connection``."""
     logger = logging.getLogger("setUp")
     logger.debug("installing schema and test data")
     _install_schema_and_test_data()
     logger.debug("creating database connection")
     dsn = get_node_database_dsn(_node_name, _database_password,
                                 _database_host, _database_port)
     self._connection = RealDictConnection(dsn)
     logger.debug("setup done")
コード例 #5
0
def edit(id):
    """View, save, or delete the stored query with primary key ``id``.

    The request arguments select the action:

    * ``save`` non-blank   -> update ``description``/``query`` from the
      request arguments, then redirect to ``home``;
    * ``delete`` non-blank -> delete the query, then redirect to ``home``;
    * otherwise            -> render ``edit.html`` with the query details
      (``details`` is None when no such query exists).
    """
    from contextlib import closing

    select_stmt = "SELECT q.name, d.name as department, q.query, q.description " \
                  "FROM query q JOIN departments d ON q.department_id = d.id WHERE q.id=%s;"
    update_stmt = "UPDATE query SET description=%s, query=%s WHERE id=%s;"
    delete_stmt = "DELETE FROM query WHERE id=%s;"

    # NOTE(review): save/delete are driven by ``request.args`` (the query
    # string), so state is presumably changed by GET requests; consider
    # POST plus CSRF protection.
    if request.args.get('save', '').strip():
        description = request.args.get('description', '').strip()
        query = request.args.get('query', '').strip()

        # ``with conn`` commits on success (rolls back on error); ``closing``
        # is what actually releases the connection.
        with closing(psycopg2.connect(dsn=local)) as conn:
            with conn, conn.cursor() as cursor:
                cursor.execute(update_stmt, (description, query, str(id)))
        return redirect(url_for('home'))

    if request.args.get('delete', '').strip():
        with closing(psycopg2.connect(dsn=local)) as conn:
            with conn, conn.cursor() as cursor:
                cursor.execute(delete_stmt, (str(id), ))
        return redirect(url_for('home'))

    with closing(RealDictConnection(dsn=local)) as conn:
        with conn, conn.cursor() as cursor:
            cursor.execute(select_stmt, (str(id), ))
            res = cursor.fetchone()
    return render_template('edit.html', details=res)
コード例 #6
0
ファイル: db.py プロジェクト: 123joshuawu/orca
def search_course_sections(conn: RealDictConnection, semester_id: str,
                           limit: int, offset: int, **search):
    """Search one semester's course sections with optional filters.

    Recognised ``search`` keywords (each optional; a missing or falsy value
    applies no filter):

    * ``course_number``, ``course_subject_prefix`` -- exact match
    * ``course_title`` -- case-insensitive substring match (ILIKE); note that
      ``%`` and ``_`` in the search text act as LIKE wildcards
    * ``has_seats`` -- True: only sections with enrollments below
      max_enrollments; False: only full sections; None/absent: no filter

    Results are paginated with ``limit``/``offset`` and converted with
    ``records_to_sections``.
    """
    q: QueryBuilder = (course_sections_q.select("*").where(
        course_sections_t.semester_id == semester_id).limit(limit).offset(
            offset))

    # Values that require exact matches
    for col in ["course_number", "course_subject_prefix"]:
        value = search.get(col)
        if value:
            q = q.where(course_sections_t[col] == value)

    # Values that require wildcards
    for col in ["course_title"]:
        value = search.get(col)
        if value:
            q = q.where(course_sections_t[col].ilike(f"%{value}%"))

    # has_seats is tri-state: compare with ``==`` (not truthiness) so that
    # None/absent applies no filter at all.
    has_seats = search.get("has_seats")
    if has_seats == False:  # noqa: E712
        q = q.where(
            course_sections_t.enrollments >= course_sections_t.max_enrollments)
    elif has_seats == True:  # noqa: E712
        q = q.where(
            course_sections_t.enrollments < course_sections_t.max_enrollments)

    with conn.cursor() as c:
        c.execute(q.get_sql())
        records = c.fetchall()

    return records_to_sections(conn, semester_id, records)
コード例 #7
0
ファイル: db.py プロジェクト: 123joshuawu/orca
def fetch_course_sections(conn: RealDictConnection, semester_id: str,
                          crns: List[str]) -> List[CourseSection]:
    """Fetch the course sections with the given CRNs in a semester.

    Each returned ``CourseSection`` is built with its periods attached.
    Only two queries are issued in total: one for the sections and one for
    all of their periods.  An empty ``crns`` returns ``[]`` without
    querying.
    """
    if not crns:
        # An empty IN-list would render as ``IN ()``, which is not valid SQL.
        return []

    with conn.cursor() as c:
        # Sections matching the semester and CRNs
        q: QueryBuilder = (course_sections_q.select("*").where(
            course_sections_t.semester_id == semester_id).where(
                course_sections_t.crn.isin(crns)))

        c.execute(q.get_sql())
        course_section_records = c.fetchall()

        # Fetch every period for all requested sections at once instead of
        # one query per section; they are matched to sections below.
        q = periods_q.where(
            periods_t.semester_id == semester_id).where(periods_t.crn.isin(crns))

        c.execute(q.get_sql())
        period_records = c.fetchall()

    # Group period records by CRN once (linear) rather than rescanning the
    # whole period list for every section.  Fetch order is preserved.
    periods_by_crn = {}
    for period_record in period_records:
        periods_by_crn.setdefault(period_record["crn"], []).append(period_record)

    sections = []
    for record in course_section_records:
        periods = [
            CourseSectionPeriod.from_record(period_record)
            for period_record in periods_by_crn.get(record["crn"], [])
        ]
        sections.append(CourseSection.from_record(record, periods))

    return sections
コード例 #8
0
ファイル: db.py プロジェクト: 123joshuawu/orca
def fetch_course_subject_prefixes(conn: RealDictConnection) -> List[str]:
    """Return the distinct course subject prefixes, sorted ascending.

    Prefixes are taken from every row of the course sections table; no
    semester filter is applied.
    """
    cursor = conn.cursor()
    column = "course_subject_prefix"
    query: QueryBuilder = (
        Query.from_(course_sections_t)
        .select(course_sections_t[column])
        .groupby(course_sections_t[column])
        .orderby(course_sections_t[column])
    )

    cursor.execute(query.get_sql())
    return [row["course_subject_prefix"] for row in cursor.fetchall()]
コード例 #9
0
 def setUp(self):
     """Load a fresh schema plus test data and connect to the node DB."""
     setup_log = logging.getLogger("setUp")
     setup_log.debug("installing schema and test data")
     _install_schema_and_test_data()
     setup_log.debug("creating database connection")
     node_dsn = get_node_database_dsn(_node_name,
                                      _database_password,
                                      _database_host,
                                      _database_port)
     self._connection = RealDictConnection(node_dsn)
     setup_log.debug("setup done")
コード例 #10
0
ファイル: db.py プロジェクト: 123joshuawu/orca
def fetch_course_section_periods(conn: RealDictConnection, semester_id: str,
                                 crn: str) -> List[CourseSectionPeriod]:
    """Load the periods of a single course section (semester + CRN)."""
    sql = "SELECT * FROM course_section_periods WHERE semester_id=%s and crn=%s"
    cursor = conn.cursor()
    cursor.execute(sql, (semester_id, crn))
    return [CourseSectionPeriod.from_record(row) for row in cursor.fetchall()]
コード例 #11
0
ファイル: get_node_ids.py プロジェクト: HackLinux/nimbus.io
def get_node_ids(node_name):
    """
    return a two-way id/name mapping for the nodes in node_name's cluster

    Every node whose cluster_id equals that of ``node_name`` (so including
    ``node_name`` itself) appears twice: ``node_dict[id] == name`` and
    ``node_dict[name] == id``.  The cursor and connection are closed even
    if the query fails.
    """
    from contextlib import closing

    psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
    psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
    query = """select id, name from nimbusio_central.node 
               where cluster_id = 
                   (select cluster_id from nimbusio_central.node
                    where name = %s)"""

    with closing(RealDictConnection(get_central_database_dsn())) as connection:
        with closing(connection.cursor()) as cursor:
            cursor.execute(query, [node_name, ])
            rows = cursor.fetchall()

    # we assume node-name will never be the same as node-id
    node_dict = dict()
    for entry in rows:
        node_dict[entry["id"]] = entry["name"]
        node_dict[entry["name"]] = entry["id"]

    return node_dict
コード例 #12
0
def get_node_ids(node_name):
    """
    return a dict mapping node id -> name and node name -> id

    Covers every node sharing a cluster_id with ``node_name`` (including
    that node).  Both the cursor and the connection are released even when
    the query raises.
    """
    from contextlib import closing

    psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
    psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
    query = """select id, name from nimbusio_central.node 
               where cluster_id = 
                   (select cluster_id from nimbusio_central.node
                    where name = %s)"""

    with closing(RealDictConnection(get_central_database_dsn())) as connection:
        with closing(connection.cursor()) as cursor:
            cursor.execute(query, [
                node_name,
            ])
            rows = cursor.fetchall()

    # we assume node-name will never be the same as node-id
    node_dict = dict()
    for entry in rows:
        node_dict[entry["id"]] = entry["name"]
        node_dict[entry["name"]] = entry["id"]

    return node_dict
コード例 #13
0
ファイル: db.py プロジェクト: 123joshuawu/orca
def populate_course_periods(conn: RealDictConnection, semester_id: str,
                            courses: List[Course], include_periods: bool):
    """Set ``course.sections`` in place for every course in ``courses``.

    Runs one query per course, selecting the sections in ``semester_id``
    whose subject prefix, number and title match the course, ordered by
    ``section_id``.  ``include_periods`` is currently ignored.
    """
    cursor = conn.cursor()
    table = course_sections_t

    for course in courses:
        query: QueryBuilder = (
            course_sections_q.select("*")
            .where(table.semester_id == semester_id)
            .where(table.course_subject_prefix == course.subject_prefix)
            .where(table.course_number == course.number)
            .where(table.course_title == course.title)
            .orderby(table.section_id)
        )

        cursor.execute(query.get_sql())
        # TODO: check include_periods
        # NOTE(review): elsewhere CourseSection.from_record is called with a
        # second ``periods`` argument; confirm the one-argument form is valid.
        course.sections = [CourseSection.from_record(record)
                           for record in cursor.fetchall()]
コード例 #14
0
def get_rs(queries):
    """Run several queries in a single SERIALIZABLE transaction.

    ``queries`` maps a result name to a dict holding the SQL under
    ``'query'`` and optional parameters under ``'args'`` (default ``{}``).
    Returns ``{name: rows}`` where ``rows`` is that query's ``fetchall()``.
    The connection is always closed; on error the transaction is discarded.
    """
    rs = dict()
    connection = RealDictConnection(config.conn_string)
    try:
        # psycopg2 opens its own transaction before the first statement, so a
        # hand-written BEGIN would run inside it (PostgreSQL warns "there is
        # already a transaction in progress") and a raw COMMIT would bypass
        # psycopg2's transaction tracking.  Let the driver manage it.
        connection.set_session(isolation_level='SERIALIZABLE')
        connection.set_client_encoding('latin1')
        cursor = connection.cursor()
        try:
            for query_nome, query in queries.items():
                cursor.execute(query['query'], query.get('args', {}))
                rs[query_nome] = cursor.fetchall()
        finally:
            cursor.close()
        connection.commit()
    finally:
        # Closing without a commit (i.e. after an error) discards the
        # transaction.
        connection.close()

    return rs
コード例 #15
0
ファイル: db.py プロジェクト: 123joshuawu/orca
def fetch_courses_without_sections(conn: RealDictConnection, semester_id: str,
                                   limit: int, offset: int,
                                   **search) -> List[Course]:
    """Return one ``Course`` per distinct (prefix, number, title) in a semester.

    Section rows are grouped into courses; no sections are loaded.  Results
    are paginated with ``limit``/``offset``.  ``search`` is accepted but
    currently unused.
    """
    cursor = conn.cursor()

    t = course_sections_t
    query: QueryBuilder = (
        course_sections_q
        .select(t.semester_id)
        .select(t.course_subject_prefix.as_("subject_prefix"))
        .select(t.course_number.as_("number"))
        .select(t.course_title.as_("title"))
        .where(t.semester_id == semester_id)
        .groupby(t.semester_id)
        .groupby(t.course_subject_prefix)
        .groupby(t.course_number)
        .groupby(t.course_title)
        .limit(limit)
        .offset(offset)
    )

    cursor.execute(query.get_sql())
    return [Course(**row) for row in cursor.fetchall()]
コード例 #16
0
def get_node_databases():
    """
    open one RealDictConnection per node; return them keyed by node name
    """
    psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
    psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)

    connections = {}
    for node_name, _configured_host, node_port, node_password in zip(
            _node_names, _node_database_hosts,
            _node_database_ports, _node_database_passwords):
        # XXX: temporary expedient until boostrap is fixed -- the configured
        # host is ignored and every node database is reached on localhost.
        dsn = get_node_database_dsn(node_name, node_password, 'localhost',
                                    node_port)
        connections[node_name] = RealDictConnection(dsn)

    return connections
コード例 #17
0
ファイル: proto.py プロジェクト: jMyles/records
class Warehouse(object):
    """Small query helper around one RealDictConnection.

    ``register_hstore`` is applied to the connection at construction.
    """

    def __init__(self, conn_str):
        self.db = RealDictConnection(conn_str)
        register_hstore(self.db)

    def query(self, q, params=None, fetchall=False):
        """Execute ``q`` with ``params`` and return its rows.

        The statement is executed immediately.  With ``fetchall=True`` the
        rows are returned as a list and the cursor is closed right away.
        Otherwise a generator over the rows is returned; it closes the
        cursor once the rows have been fully consumed (or the generator is
        closed mid-iteration).  The cursor is also closed if ``execute``
        raises.
        """
        cursor = self.db.cursor()
        try:
            cursor.execute(q, params)
        except Exception:
            cursor.close()
            raise

        if fetchall:
            try:
                return list(cursor)
            finally:
                cursor.close()
        return self._iter_and_close(cursor)

    @staticmethod
    def _iter_and_close(cursor):
        """Yield every row from ``cursor``, closing it when iteration ends."""
        try:
            for row in cursor:
                yield row
        finally:
            cursor.close()

    def query_file(self, path, params=None, fetchall=False):
        """Read SQL from the file at ``path`` and run it via ``query``."""
        with open(path) as f:
            query = f.read()

        return self.query(query, params=params, fetchall=fetchall)
コード例 #18
0
class TestSegmentVisibility(unittest.TestCase):
    """
    test segment visibility subsystem
    """
    def setUp(self):
        """Reinstall schema and fixtures, then connect to the node DB."""
        logger = logging.getLogger("setUp")
        logger.debug("installing schema and test data")
        _install_schema_and_test_data()
        logger.debug("creating database connection")
        dsn = get_node_database_dsn(_node_name, _database_password,
                                    _database_host, _database_port)
        self._connection = RealDictConnection(dsn)
        logger.debug("setup done")

    def tearDown(self):
        log = logging.getLogger("tearDown")
        log.debug("teardown starts")
        if hasattr(self, "_connection"):
            self._connection.close()
            delattr(self, "_connection")
        log.debug("teardown done")

    def _retrieve_collectables(self, versioned):
        """
        Return the set of (key, unified_id) pairs that collectable_archive
        reports for the test collection.

        Callers intersect other query results with this set to verify that
        no collectable row is listed.  Fails the test if the set is empty,
        because the fixture data is expected to contain some garbage.
        """
        collectable_sql = collectable_archive(_test_collection_id,
                                              versioned=versioned)
        query_args = {"collection_id": _test_collection_id}

        cursor = self._connection.cursor()
        cursor.execute(collectable_sql, query_args)
        collectable_rows = cursor.fetchall()
        cursor.close()

        # there should always be some garbage; none means something is wrong
        self.assertGreater(len(collectable_rows), 0)

        return {(row["key"], row["unified_id"]) for row in collectable_rows}

    #@unittest.skip("isolate test")
    def test_no_such_collectable(self):
        """
        A bogus unified_id must produce no collectable rows.
        """
        _log = logging.getLogger("test_no_such_collectable")

        sql_text = collectable_archive(_test_collection_id,
                                       versioned=True,
                                       key=_test_key,
                                       unified_id=_test_no_such_unified_id)
        query_args = dict(collection_id=_test_collection_id,
                          key=_test_key,
                          unified_id=_test_no_such_unified_id)

        cursor = self._connection.cursor()
        cursor.execute(sql_text, query_args)
        found = cursor.fetchall()
        cursor.close()

        self.assertEqual(len(found), 0, found)

    #@unittest.skip("isolate test")
    def test_list(self):
        """
        test listing keys and versions of keys

        For list_versions (unversioned, then versioned) and list_keys (both
        modes) check that no listed row is garbage-collectable and every key
        starts with the test prefix.  Also check per-key row counts and that
        list_keys agrees row-by-row with the matching list_versions result
        (latest version per key when versioned).
        """
        log = logging.getLogger("test_list")

        versioned = False
        sql_text = list_versions(_test_collection_id,
                                 versioned=versioned,
                                 prefix=_test_prefix,
                                 limit=None)

        args = {
            "collection_id": _test_collection_id,
            "prefix": _test_prefix,
        }

        if _write_debug_sql:
            with open("/tmp/debug_unversioned_rows.sql",
                      "w") as debug_sql_file:
                debug_sql_file.write(mogrify(sql_text, args))

        cursor = self._connection.cursor()
        cursor.execute(sql_text, args)
        unversioned_rows = cursor.fetchall()
        cursor.close()

        collectable_set = self._retrieve_collectables(versioned)
        test_set = set([(
            r["key"],
            r["unified_id"],
        ) for r in unversioned_rows])
        collectable_intersection = test_set & collectable_set
        self.assertEqual(len(collectable_intersection), 0,
                         collectable_intersection)

        # check that there's no more than one row per key for a non-versioned
        # collection
        # check that every row begins with prefix
        unversioned_key_counts = Counter()
        for row in unversioned_rows:
            unversioned_key_counts[row["key"]] += 1
            self.assertTrue(row["key"].startswith(_test_prefix))
        for key, value in unversioned_key_counts.items():
            self.assertEqual(value, 1, (key, value))

        versioned = True
        sql_text = list_versions(_test_collection_id,
                                 versioned=versioned,
                                 prefix=_test_prefix)

        args = {
            "collection_id": _test_collection_id,
            "prefix": _test_prefix,
        }

        if _write_debug_sql:
            with open("/tmp/debug_versioned_rows.sql", "w") as debug_sql_file:
                debug_sql_file.write(mogrify(sql_text, args))

        cursor = self._connection.cursor()
        cursor.execute(sql_text, args)
        versioned_rows = cursor.fetchall()
        cursor.close()

        collectable_set = self._retrieve_collectables(versioned)
        test_set = set([(
            r["key"],
            r["unified_id"],
        ) for r in versioned_rows])
        collectable_intersection = test_set & collectable_set
        self.assertEqual(len(collectable_intersection), 0,
                         collectable_intersection)

        # Keep the last row listed for each key: walking the rows backwards,
        # setdefault keeps the first hit, which is the last occurrence.  Then
        # restore forward order.  list() is required on Python 3, where
        # values() is a view without reverse(); it is harmless on Python 2.
        latest_versioned_rows = OrderedDict()
        for row in versioned_rows[::-1]:
            latest_versioned_rows.setdefault(row["key"], row)
        latest_versioned_rows = list(latest_versioned_rows.values())
        latest_versioned_rows.reverse()
        self.assertLessEqual(len(latest_versioned_rows), len(versioned_rows))

        versioned_key_counts = Counter()
        for row in versioned_rows:
            versioned_key_counts[row["key"]] += 1
            self.assertTrue(row["key"].startswith(_test_prefix))

        # check that there's >= as many rows now as above.
        for key, value in versioned_key_counts.items():
            self.assertTrue(value >= unversioned_key_counts[key], (key, value))

        # check that the list keys result is consistent with list_versions in
        # above (although there could be extra columns.)  Note that
        # list_keys(versioned=True) may have records that
        # list_versions(versioned=False) does not have, because there are more
        # ways for a segment to become eligible for garbage collection in an
        # unversioned collection.

        for versioned in [
                False,
                True,
        ]:
            sql_text = list_keys(_test_collection_id,
                                 versioned=versioned,
                                 prefix=_test_prefix,
                                 limit=None)

            args = {
                "collection_id": _test_collection_id,
                "prefix": _test_prefix,
            }

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            key_rows = cursor.fetchall()
            cursor.close()

            if _write_debug_sql:
                debug_filename = "/tmp/debug_key_rows_versioned_%r.sql" % (
                    versioned, )
                with open(debug_filename, "w") as debug_sql_file:
                    debug_sql_file.write(mogrify(sql_text, args))

            collectable_set = self._retrieve_collectables(versioned)
            test_set = set([(
                r["key"],
                r["unified_id"],
            ) for r in key_rows])
            collectable_intersection = test_set & collectable_set
            self.assertEqual(len(collectable_intersection), 0,
                             collectable_intersection)

            if versioned:
                # a list of keys with versioning on may have keys that don't
                # show up in the list of unversioned rows.  That's because in
                # an unversioned collection, keys end when another key is
                # added.  So it's possible for that plus a tombstone to cause a
                # situation where an archive is not eligible for garbage
                # collection in a versioned collection, but it is eligible for
                # garbage collection in an unversioned collection.
                self.assertGreaterEqual(len(key_rows), len(unversioned_rows), (
                    len(key_rows),
                    len(unversioned_rows),
                    versioned,
                ))
            else:
                self.assertEqual(len(key_rows), len(unversioned_rows), (
                    len(key_rows),
                    len(unversioned_rows),
                    versioned,
                ))

            key_counts = Counter()
            for row in key_rows:
                key_counts[row["key"]] += 1
                self.assertTrue(row["key"].startswith(_test_prefix))
            for key, value in key_counts.items():
                self.assertEqual(value, 1, (key, value))

            if versioned:
                for key_row, version_row in zip(key_rows,
                                                latest_versioned_rows):
                    self.assertEqual(key_row["key"], version_row["key"])
                    self.assertEqual(key_row["unified_id"],
                                     version_row["unified_id"])
            else:
                for key_row, version_row in zip(key_rows, unversioned_rows):
                    self.assertEqual(key_row["key"], version_row["key"])
                    self.assertEqual(key_row["unified_id"],
                                     version_row["unified_id"])

    #@unittest.skip("isolate test")
    def test_limits_and_markers(self):
        """
        check that the limits and markers work correctly.

        Take the result with limit=None as a baseline, then page through it
        with limit=1, passing the previous row's key (and, for
        list_versions, unified_id) as the marker.  Each page must equal the
        corresponding baseline row, and repeated runs of the same page query
        must return the same row.
        """
        log = logging.getLogger("test_limits_and_markers")

        for versioned in [True, False]:
            sql_text = list_keys(_test_collection_id,
                                 versioned=versioned,
                                 prefix=_test_prefix)

            args = {
                "collection_id": _test_collection_id,
                "prefix": _test_prefix,
            }

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            baseline_rows = cursor.fetchall()
            cursor.close()

            key_marker = None
            for row in baseline_rows:
                sql_text = list_keys(_test_collection_id,
                                     versioned=versioned,
                                     prefix=_test_prefix,
                                     key_marker=key_marker,
                                     limit=1)

                args = {
                    "collection_id": _test_collection_id,
                    "prefix": _test_prefix,
                    "key_marker": key_marker,
                    "limit": 1
                }

                cursor = self._connection.cursor()
                cursor.execute(sql_text, args)
                test_row = cursor.fetchone()
                cursor.close()

                # fetchone() returns None for an empty page; report that as a
                # test failure rather than a TypeError on subscripting.
                self.assertIsNotNone(test_row, (versioned, key_marker))
                self.assertEqual(test_row["key"], row["key"],
                                 (test_row["key"], row["key"]))
                self.assertEqual(test_row["unified_id"], row["unified_id"],
                                 (test_row["unified_id"], row["unified_id"]))

                key_marker = test_row["key"]

        for versioned in [True, False]:
            sql_text = list_versions(_test_collection_id,
                                     versioned=versioned,
                                     prefix=_test_prefix,
                                     limit=None)

            args = {
                "collection_id": _test_collection_id,
                "prefix": _test_prefix,
            }

            if _write_debug_sql:
                with open("/tmp/debug_all.sql", "w") as debug_sql_file:
                    debug_sql_file.write(mogrify(sql_text, args))

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            baseline_rows = cursor.fetchall()
            cursor.close()
            baseline_set = set([(
                r["key"],
                r["unified_id"],
            ) for r in baseline_rows])
            key_marker = None
            version_marker = None
            for row_idx, row in enumerate(baseline_rows):
                sql_text = list_versions(_test_collection_id,
                                         versioned=versioned,
                                         prefix=_test_prefix,
                                         key_marker=key_marker,
                                         version_marker=version_marker,
                                         limit=1)

                args = {
                    "collection_id": _test_collection_id,
                    "prefix": _test_prefix,
                    "limit": 1
                }

                if key_marker is not None:
                    args["key_marker"] = key_marker
                if version_marker is not None:
                    args["version_marker"] = version_marker

                if _write_debug_sql:
                    debug_filename = "/tmp/debug_%s.sql" % (row_idx, )
                    with open(debug_filename, "w") as debug_sql_file:
                        debug_sql_file.write(mogrify(sql_text, args))

                # this result should always be stable. is it?
                last_time = None
                for _ in range(5):
                    cursor = self._connection.cursor()
                    cursor.execute(sql_text, args)
                    test_row = cursor.fetchone()
                    cursor.close()
                    if last_time is not None:
                        # a unittest assertion, not a bare assert, so the
                        # check survives ``python -O``
                        self.assertEqual(test_row, last_time)
                    last_time = test_row

                self.assertIsNotNone(test_row, (row_idx, versioned))

                # make sure it's in the result somewhere. below we test if it's
                # in the right order.
                self.assertIn((test_row["key"], test_row["unified_id"]),
                              baseline_set)

                log.info("{0}, {1}".format(test_row["key"], row["key"]))
                log.debug(sql_text)

                self.assertEqual(
                    test_row["key"], row["key"],
                    (row_idx, versioned, test_row["key"], row["key"]))
                self.assertEqual(test_row["unified_id"], row["unified_id"],
                                 (row_idx, versioned, test_row["unified_id"],
                                  row["unified_id"]))

                key_marker = test_row["key"]
                version_marker = test_row["unified_id"]

    #@unittest.skip("isolate test")
    def test_version_for_key(self):
        """
        version_for_key

        For each list_keys row, version_for_key with no unified_id must
        return that row's key and unified_id (possibly several rows, one per
        conjoined part).  A query for a nonexistent unified_id must return
        no rows.
        """
        log = logging.getLogger("test_version_for_key")

        # check that for every row in list_keys, calling version_for_key with
        # unified_id=None should return the same row, regardless of it being
        # versioned or not.
        for versioned in [True, False]:
            sql_text = list_keys(_test_collection_id,
                                 versioned=versioned,
                                 prefix=_test_prefix)

            args = {
                "collection_id": _test_collection_id,
                "prefix": _test_prefix,
            }

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            baseline_rows = cursor.fetchall()
            cursor.close()

            for row in baseline_rows:
                sql_text = version_for_key(_test_collection_id,
                                           versioned=versioned,
                                           key=row["key"])

                args = {
                    "collection_id": _test_collection_id,
                    "key": row["key"]
                }

                cursor = self._connection.cursor()
                if _write_debug_sql:
                    with open("/tmp/debug.sql", "w") as debug_sql_file:
                        debug_sql_file.write(mogrify(sql_text, args))
                cursor.execute(sql_text, args)
                test_rows = cursor.fetchall()
                cursor.close()

                # 2012-12-20 dougfort -- list_keys and list_versions only
                # retrieve one conjoined part, but version_for_key retrieves
                # all conjoined parts. So we may have more than one row here.
                self.assertGreater(len(test_rows), 0)
                for test_row in test_rows:
                    self.assertEqual(test_row["key"], row["key"],
                                     (test_row["key"], row["key"]))
                    self.assertEqual(
                        test_row["unified_id"], row["unified_id"],
                        (test_row["unified_id"], row["unified_id"]))

        # check that these return empty
        for versioned in [True, False]:
            sql_text = version_for_key(_test_collection_id,
                                       versioned=versioned,
                                       key=_test_key,
                                       unified_id=_test_no_such_unified_id)

            # The key parameter must be the one the SQL was generated for.
            # Using ``row`` from the loop above would bind an unrelated key,
            # and ``row`` is unbound if that loop saw no rows.
            args = {
                "collection_id": _test_collection_id,
                "key": _test_key,
                "unified_id": _test_no_such_unified_id
            }

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            test_rows = cursor.fetchall()
            cursor.close()
            self.assertEqual(len(test_rows), 0, test_rows)

    #@unittest.skip("isolate test")
    def test_version_for_key_find_all_same_rows(self):
        """
        check that version_for_key(versioned=False) finds, for every key, the
        same row that list_keys(versioned=False) returns: at least one row
        comes back per key, and every returned row carries that key and
        unified_id
        """
        # XXX: this looks like it tests the same stuff as the previous
        # entry?
        log = logging.getLogger("test_version_for_key_find_all_same_rows")

        def fetch_rows(query, query_args):
            # run one query on a fresh cursor and close it straight away
            cur = self._connection.cursor()
            cur.execute(query, query_args)
            fetched = cur.fetchall()
            cur.close()
            return fetched

        expected_rows = fetch_rows(
            list_keys(_test_collection_id,
                      versioned=False,
                      prefix=_test_prefix),
            {
                "collection_id": _test_collection_id,
                "prefix": _test_prefix,
            })

        for expected in expected_rows:
            wanted_key = expected["key"]
            found_rows = fetch_rows(
                version_for_key(_test_collection_id,
                                versioned=False,
                                key=wanted_key),
                {
                    "collection_id": _test_collection_id,
                    "key": wanted_key,
                })

            self.assertTrue(len(found_rows) > 0)
            for found in found_rows:
                self.assertEqual(found["key"], wanted_key)
                self.assertEqual(found["unified_id"], expected["unified_id"])

    #@unittest.skip("isolate test")
    def test_list_versions_same_rows(self):
        """
        check that version_for_key(versioned=True), given an explicit
        unified_id, finds every row that list_versions(versioned=True)
        returns, and that every row it finds has that same key and unified_id
        """
        sql_text = list_versions(_test_collection_id,
                                 versioned=True,
                                 prefix=_test_prefix)

        args = {
            "collection_id": _test_collection_id,
            "prefix": _test_prefix,
        }

        # honour the module-wide debug switch like the other tests do,
        # instead of unconditionally writing to /tmp on every run
        if _write_debug_sql:
            with open("/tmp/debug.sql", "w") as debug_sql_file:
                debug_sql_file.write(mogrify(sql_text, args))

        cursor = self._connection.cursor()
        cursor.execute(sql_text, args)
        list_versions_rows = cursor.fetchall()
        cursor.close()

        for list_versions_row in list_versions_rows:
            sql_text = version_for_key(
                _test_collection_id,
                versioned=True,
                key=list_versions_row["key"],
                unified_id=list_versions_row["unified_id"])

            args = {
                "collection_id": _test_collection_id,
                "key": list_versions_row["key"],
                "unified_id": list_versions_row["unified_id"]
            }

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            version_for_key_rows = cursor.fetchall()
            cursor.close()

            # version_for_key may return several rows for one version
            # (conjoined parts); each must match the listed version
            self.assertTrue(
                len(version_for_key_rows) > 0,
                "{0} {1}".format(args, list_versions_row))
            for version_for_key_row in version_for_key_rows:
                self.assertEqual(version_for_key_row["key"],
                                 list_versions_row["key"])
                self.assertEqual(version_for_key_row["unified_id"],
                                 list_versions_row["unified_id"],
                                 list_versions_row)

    #@unittest.skip("isolate test")
    def test_list_keys_vs_list_versions(self):
        """
        check that this can ONLY find the same rows list_versions returns
        above IF they are also in the result that list_keys returns
        (i.e. some of them should be findable, some not.)
        """
        # Background: list_keys returns the newest version of every key and
        # list_versions returns every version of every key. For an
        # unversioned collection both should find the same rows (the
        # list_keys output carries an extra column). In a versioned
        # collection, any version of a key that is not the newest should be
        # unreachable unless versioning is requested.
        #
        # Plan:
        # 1. collect (key, unified_id) pairs from
        #    list_versions(_test_collection_id, versioned=True)
        # 2. collect the same pairs from list_keys(versioned=True) and, for
        #    reference, from list_keys(versioned=False)
        # 3. pairs listed by list_versions but not by list_keys are older
        #    versions; assert that classification row by row
        # 4. look every pair up with version_for_key, passing its unified_id,
        #    and check it is found (or correctly not found):
        #    - older versions: unreachable with versioned=False, reachable
        #      with versioned=True
        #    - pairs from list_keys(versioned=True): always reachable with
        #      versioned=True; reachable with versioned=False unless
        #      list_keys(versioned=False) did not list them
        listing_args = {
            "collection_id": _test_collection_id,
            "prefix": _test_prefix,
        }

        def fetch_all(query, query_args):
            # run one query on a fresh cursor and close it straight away
            cur = self._connection.cursor()
            cur.execute(query, query_args)
            fetched = cur.fetchall()
            cur.close()
            return fetched

        def as_pairs(rows):
            # reduce rows to the (key, unified_id) identity used for comparing
            return set((r["key"], r["unified_id"], ) for r in rows)

        def lookup(key, unified_id, versioned):
            # version_for_key for one explicit version; returns the rows found
            # together with the args used (they go into failure messages)
            query = version_for_key(_test_collection_id,
                                    versioned=versioned,
                                    key=key,
                                    unified_id=unified_id)
            lookup_args = {
                "collection_id": _test_collection_id,
                "key": key,
                "unified_id": unified_id
            }
            return fetch_all(query, lookup_args), lookup_args

        # 1. every version of every key
        every_version_rows = fetch_all(
            list_versions(_test_collection_id,
                          versioned=True,
                          prefix=_test_prefix),
            listing_args)
        every_version_pairs = as_pairs(every_version_rows)

        # 2. newest version of every key, with and without versioning
        newest_pairs = as_pairs(
            fetch_all(
                list_keys(_test_collection_id,
                          versioned=True,
                          prefix=_test_prefix),
                listing_args))

        unversioned_newest_pairs = as_pairs(
            fetch_all(
                list_keys(_test_collection_id,
                          versioned=False,
                          prefix=_test_prefix),
                listing_args))

        # pairs that only list_keys(versioned=True) reports
        versioned_only_pairs = newest_pairs - unversioned_newest_pairs

        # 3. anything listed as a version but not as a key is an older version
        stale_pairs = every_version_pairs - newest_pairs

        for version_row in every_version_rows:
            pair = (version_row["key"], version_row["unified_id"], )
            self.assertIn(pair, every_version_pairs)
            if pair not in newest_pairs:
                self.assertIn(pair, stale_pairs)
            else:
                self.assertNotIn(pair, stale_pairs)

        # 4a. older versions: hidden without versioning, visible with it
        for key, unified_id in stale_pairs:
            for versioned in [False, True, ]:
                found_rows, _ = lookup(key, unified_id, versioned)
                if versioned:
                    self.assertTrue(len(found_rows) > 0)
                else:
                    self.assertEqual(len(found_rows), 0)

        # 4b. newest versions: always visible with versioning; visible without
        #     it unless only the versioned listing reported them
        for key, unified_id in newest_pairs:
            for versioned in [False, True, ]:
                found_rows, lookup_args = lookup(key, unified_id, versioned)
                message = "versioned={0} {1}".format(versioned, lookup_args)
                hidden = (not versioned
                          and (key, unified_id, ) in versioned_only_pairs)
                if hidden:
                    self.assertTrue(len(found_rows) == 0, message)
                else:
                    self.assertTrue(len(found_rows) > 0, message)
コード例 #19
0
import psycopg2.extensions
from psycopg2.extras import RealDictConnection

psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)

# NOTE(review): credentials are hard-coded here; consider reading them from
# configuration or the environment instead of committing them to source.
conn_string = "host=%s port=%s dbname=%s user=%s password=%s" % (
    "localhost4", "5432", "fahstats", "kakaostats", "aroeira")

connection = RealDictConnection(conn_string)
# psycopg2 opens a transaction implicitly before the first execute(), so a
# "begin transaction isolation level serializable" statement sent through
# execute() lands inside an already-open transaction. PostgreSQL then only
# warns and the isolation level is never applied. Configure the level on the
# connection instead, before any statement runs, so that psycopg2's own
# BEGIN uses it.
connection.set_session(isolation_level="SERIALIZABLE")
connection.set_client_encoding('latin1')
cursor = connection.cursor()

query = """\
    set session work_mem = 2097152
    ;
    create table donor_work_old (like donors_old including constraints including defaults)
    ;
    alter table donors_old rename to donors_old_back
;"""

cursor.execute(query)

query = """\
    select distinct data
    from donors_old_back
;"""

cursor.execute(query)
rs = cursor.fetchall()
# NOTE(review): nothing in this fragment calls connection.commit(); the
# create/rename above stay uncommitted until a later commit — confirm the
# rest of the script commits.
コード例 #20
0
ファイル: db.py プロジェクト: 123joshuawu/orca
def fetch_semesters(conn: RealDictConnection) -> List[Semester]:
    """
    Load every row of the ``semesters`` table, ordered by ``semester_id``,
    and convert each one with ``Semester.from_record``.

    :param conn: open database connection whose cursor rows are passed to
                 ``Semester.from_record``.
    :returns: the converted semesters, in ``semester_id`` order.
    """
    # Use the cursor as a context manager so it is closed once the rows are
    # fetched, even if the query raises (the original never closed it).
    with conn.cursor() as c:
        c.execute("SELECT * FROM semesters ORDER BY semester_id")
        records = c.fetchall()
    return [Semester.from_record(record) for record in records]
コード例 #21
0
class TestSegmentVisibility(unittest.TestCase):
    """
    test segment visibility subsystem
    """
    def setUp(self):
        """
        Reinstall the schema and test data, then open a RealDictConnection
        to the node database for the test to use.
        """
        log = logging.getLogger("setUp")
        log.debug("installing schema and test data")
        _install_schema_and_test_data()

        log.debug("creating database connection")
        node_dsn = get_node_database_dsn(_node_name,
                                         _database_password,
                                         _database_host,
                                         _database_port)
        self._connection = RealDictConnection(node_dsn)
        log.debug("setup done")

    def tearDown(self):
        log = logging.getLogger("tearDown")        
        log.debug("teardown starts")
        if hasattr(self, "_connection"):
            self._connection.close()
            delattr(self, "_connection")
        log.debug("teardown done")

    def _retrieve_collectables(self, versioned):
        """
        Return the set of (key, unified_id) pairs that collectable_archive
        reports for the test collection.

        Fails the calling test if that set is empty: the test data is
        expected to always contain some garbage.
        """
        query = collectable_archive(_test_collection_id,
                                    versioned=versioned)

        cursor = self._connection.cursor()
        cursor.execute(query, {"collection_id" : _test_collection_id, })
        collectable_rows = cursor.fetchall()
        cursor.close()

        # no garbage at all means something is wrong with the data or query
        self.assertGreater(len(collectable_rows), 0)

        return set((row["key"], row["unified_id"], )
                   for row in collectable_rows)

    #@unittest.skip("isolate test")
    def test_no_such_collectable(self):
        """
        collectable_archive must report no rows for a unified_id that does
        not exist in the test data
        """
        log = logging.getLogger("test_no_such_collectable")

        query_args = {"collection_id" : _test_collection_id,
                      "key"           : _test_key,
                      "unified_id"    : _test_no_such_unified_id}
        query = collectable_archive(query_args["collection_id"],
                                    versioned=True,
                                    key=query_args["key"],
                                    unified_id=query_args["unified_id"])

        cursor = self._connection.cursor()
        cursor.execute(query, query_args)
        collectable_rows = cursor.fetchall()
        cursor.close()

        self.assertEqual(len(collectable_rows), 0, collectable_rows)

    #@unittest.skip("isolate test")
    def test_list(self):
        """
        test listing keys and versions of keys

        Cross-checks list_versions and list_keys, for both the versioned and
        unversioned cases: no listed (key, unified_id) may appear in
        collectable_archive's output, every key must start with the test
        prefix, unversioned listings hold one row per key, and list_keys must
        agree row-for-row with the corresponding list_versions result.
        """
        versioned = False
        sql_text = list_versions(_test_collection_id, 
                                 versioned=versioned, 
                                 prefix=_test_prefix, 
                                 limit=None)

        args = {"collection_id" : _test_collection_id,
                "prefix"        : _test_prefix, }

        if _write_debug_sql:
            with open("/tmp/debug_unversioned_rows.sql", "w") as debug_sql_file:
                debug_sql_file.write(mogrify(sql_text, args))

        cursor = self._connection.cursor()
        cursor.execute(sql_text, args)
        unversioned_rows = cursor.fetchall()
        cursor.close()

        collectable_set = self._retrieve_collectables(versioned)
        test_set = set([(r["key"], r["unified_id"], ) for r in unversioned_rows])
        collectable_intersection = test_set & collectable_set
        self.assertEqual(len(collectable_intersection), 0, 
                         collectable_intersection)

        # check that there's no more than one row per key for a non-versioned 
        # collection
        # check that every row begins with prefix
        unversioned_key_counts = Counter()
        for row in unversioned_rows:
            unversioned_key_counts[row["key"]] += 1
            self.assertTrue(row["key"].startswith(_test_prefix))
        for key, value in unversioned_key_counts.items():
            self.assertEqual(value, 1, (key, value))

        versioned = True
        sql_text = list_versions(_test_collection_id, 
                                 versioned=versioned, 
                                 prefix=_test_prefix)

        args = {"collection_id" : _test_collection_id,
                "prefix"        : _test_prefix, }

        if _write_debug_sql:
            with open("/tmp/debug_versioned_rows.sql", "w") as debug_sql_file:
                debug_sql_file.write(mogrify(sql_text, args))

        cursor = self._connection.cursor()
        cursor.execute(sql_text, args)
        versioned_rows = cursor.fetchall()
        cursor.close()

        collectable_set = self._retrieve_collectables(versioned)
        test_set = set([(r["key"], r["unified_id"], ) for r in versioned_rows])
        collectable_intersection = test_set & collectable_set
        self.assertEqual(len(collectable_intersection), 0, 
                         collectable_intersection)

        # keep the last row (in versioned_rows order) for each key: walking
        # backwards lets setdefault record that row first, then reversing
        # restores forward ordering
        latest_versioned_rows = OrderedDict()
        for row in versioned_rows[::-1]:
            latest_versioned_rows.setdefault(row["key"], row)
        # list() is required on Python 3, where values() is a view with no
        # reverse() method
        latest_versioned_rows = list(latest_versioned_rows.values())
        latest_versioned_rows.reverse()
        self.assertLessEqual(len(latest_versioned_rows), len(versioned_rows))

        versioned_key_counts = Counter()
        for row in versioned_rows:
            versioned_key_counts[row["key"]] += 1
            self.assertTrue(row["key"].startswith(_test_prefix))

        # check that there's >= as many rows now as above.
        for key, value in versioned_key_counts.items():
            self.assertTrue(value >= unversioned_key_counts[key], (key, value))

        # check that the list keys result is consistent with list_versions in
        # above (although there could be extra columns.)  Note that
        # list_keys(versioned=True) may have records that
        # list_versions(versioned=False) does not have, because there are more
        # ways for a segment to become eligible for garbage collection in an
        # unversioned collection.

        for versioned in [False, True, ]:
            sql_text = list_keys(_test_collection_id, 
                                 versioned=versioned, 
                                 prefix=_test_prefix,
                                 limit=None)

            args = {"collection_id" : _test_collection_id,
                    "prefix"        : _test_prefix, }

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            key_rows = cursor.fetchall()
            cursor.close()

            if _write_debug_sql:
                debug_filename = "/tmp/debug_key_rows_versioned_%r.sql" % ( versioned, )
                with open(debug_filename, "w") as debug_sql_file:
                    debug_sql_file.write(mogrify(sql_text, args))

            collectable_set = self._retrieve_collectables(versioned)
            test_set = set([(r["key"], r["unified_id"], ) for r in key_rows])
            collectable_intersection = test_set & collectable_set
            self.assertEqual(len(collectable_intersection), 0, 
                             collectable_intersection)

            if versioned:
                # a list of keys with versioning on may have keys that don't
                # show up in the list of unversioned rows.  That's because in
                # an unversioned collection, keys end when another key is
                # added.  So it's possible for that plus a tombstone to cause a
                # situation where an archive is not eligible for garbage
                # collection in a versioned collection, but it is eligible for
                # garbage collection in an unversioned collection.
                self.assertGreaterEqual(len(key_rows), len(unversioned_rows), 
                    (len(key_rows), len(unversioned_rows), versioned, ))
            else:
                self.assertEqual(len(key_rows), len(unversioned_rows), 
                    (len(key_rows), len(unversioned_rows), versioned, ))

            key_counts = Counter()
            for row in key_rows:
                key_counts[row["key"]] += 1
                self.assertTrue(row["key"].startswith(_test_prefix))
            for key, value in key_counts.items():
                self.assertEqual(value, 1, (key, value))

            if versioned:
                for key_row, version_row in zip(key_rows, latest_versioned_rows):
                    self.assertEqual(key_row["key"], version_row["key"])
                    self.assertEqual(key_row["unified_id"], version_row["unified_id"])
            else:
                for key_row, version_row in zip(key_rows, unversioned_rows):
                    self.assertEqual(key_row["key"], version_row["key"])
                    self.assertEqual(key_row["unified_id"], version_row["unified_id"])

    #@unittest.skip("isolate test")
    def test_limits_and_markers(self):
        """
        check that the limits and markers work correctly. 
        perhaps take the result with limit=None, and run a series of queries 
        with limit=1 for each of those rows, checking results.

        For list_keys and list_versions, both versioned and unversioned,
        paging through the listing one row at a time (limit=1, with the
        previous row as marker) must reproduce the unlimited listing in order.
        """
        log = logging.getLogger("test_limits_and_markers")

        for versioned in [True, False]:
            sql_text = list_keys(_test_collection_id, 
                                 versioned=versioned, 
                                 prefix=_test_prefix)

            args = {"collection_id" : _test_collection_id,
                    "prefix"        : _test_prefix, }

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            baseline_rows = cursor.fetchall()
            cursor.close()

            key_marker = None
            for row in baseline_rows:
                sql_text = list_keys(_test_collection_id, 
                                     versioned=versioned, 
                                     prefix=_test_prefix,
                                     key_marker=key_marker,
                                     limit=1)

                args = {"collection_id" : _test_collection_id,
                        "prefix"        : _test_prefix, 
                        "key_marker"    : key_marker,
                        "limit"         : 1}

                cursor = self._connection.cursor()
                cursor.execute(sql_text, args)
                test_row = cursor.fetchone()
                cursor.close()

                # fetchone() returns None for an empty page; fail with a
                # readable message rather than a TypeError on subscripting
                self.assertIsNotNone(test_row,
                                     (versioned, key_marker, row["key"]))

                self.assertEqual(test_row["key"], row["key"], 
                                 (test_row["key"], row["key"]))
                self.assertEqual(test_row["unified_id"], row["unified_id"], 
                                 (test_row["unified_id"], row["unified_id"]))

                key_marker = test_row["key"]

        for versioned in [True, False]:
            sql_text = list_versions(_test_collection_id, 
                                 versioned=versioned, 
                                 prefix=_test_prefix,
                                 limit=None)

            args = {"collection_id" : _test_collection_id,
                    "prefix"        : _test_prefix, }

            if _write_debug_sql:
                with open("/tmp/debug_all.sql", "w") as debug_sql_file:
                    debug_sql_file.write(mogrify(sql_text, args))

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            baseline_rows = cursor.fetchall()
            cursor.close()
            baseline_set = set([(r["key"], r["unified_id"], ) 
                                for r in baseline_rows])
            key_marker = None
            version_marker = None
            for row_idx, row in enumerate(baseline_rows):
                sql_text = list_versions(_test_collection_id, 
                                     versioned=versioned, 
                                     prefix=_test_prefix,
                                     key_marker=key_marker,
                                     version_marker=version_marker,
                                     limit=1)

                args = {"collection_id" : _test_collection_id,
                        "prefix"        : _test_prefix, 
                        "limit"         : 1}

                # markers are only bound once a previous page has set them
                if key_marker is not None:
                    args["key_marker"] = key_marker
                if version_marker is not None:
                    args["version_marker"] = version_marker

                if _write_debug_sql:
                    debug_filename = "/tmp/debug_%s.sql" % (row_idx, )
                    with open(debug_filename, "w") as debug_sql_file:
                        debug_sql_file.write(mogrify(sql_text, args))

                # this result should always be stable. is it?
                last_time = None
                for _ in range(5):
                    cursor = self._connection.cursor()
                    cursor.execute(sql_text, args)
                    test_row = cursor.fetchone()
                    cursor.close()
                    if last_time is not None:
                        # a unittest assertion, unlike a bare assert, is not
                        # stripped when Python runs with -O
                        self.assertEqual(test_row, last_time)
                    last_time = test_row

                self.assertIsNotNone(test_row, (row_idx, versioned, row["key"]))

                # make sure it's in the result somewhere. below we test if it's
                # in the right order.
                self.assertEqual(
                    (test_row["key"], test_row["unified_id"]) in baseline_set,
                    True)
                
                log.info("{0}, {1}".format(test_row["key"], row["key"]))
                log.debug(sql_text)

                self.assertEqual(test_row["key"], row["key"], 
                                 (row_idx, versioned, test_row["key"], row["key"]))
                self.assertEqual(test_row["unified_id"], row["unified_id"], 
                                 (row_idx, versioned, test_row["unified_id"], row["unified_id"]))

                key_marker = test_row["key"]
                version_marker = test_row["unified_id"]

    #@unittest.skip("isolate test")
    def test_version_for_key(self):
        """
        version_for_key 

        For every row list_keys returns, version_for_key without a unified_id
        must return that same row (possibly as several conjoined parts),
        whether or not the collection is versioned. Asking for a unified_id
        that does not exist must return no rows.
        """
        # check that for every row in list_keys, calling version_for_key with
        # unified_id=None should return the same row, regardless of it being 
        # versioned or not.
        for versioned in [True, False]:
            sql_text = list_keys(_test_collection_id, 
                                 versioned=versioned, 
                                 prefix=_test_prefix)

            args = {"collection_id" : _test_collection_id,
                    "prefix"        : _test_prefix, }

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            baseline_rows = cursor.fetchall()
            cursor.close()

            for row in baseline_rows:
                sql_text = version_for_key(_test_collection_id, 
                                           versioned=versioned, 
                                           key=row["key"])

                args = {"collection_id" : _test_collection_id,
                        "key"           : row["key"]} 

                cursor = self._connection.cursor()
                if _write_debug_sql:
                    with open("/tmp/debug.sql", "w") as debug_sql_file:
                        debug_sql_file.write(mogrify(sql_text, args))
                cursor.execute(sql_text, args)
                test_rows = cursor.fetchall()
                cursor.close()

                # 2012-12-20 dougfort -- list_keys and list_versions only
                # retrieve one conjoined part, but version_for_key retrieves
                # all conjoined parts. So we may have more than one row here.
                self.assertTrue(len(test_rows) > 0) 
                for test_row in test_rows:
                    self.assertEqual(test_row["key"], row["key"], 
                                     (test_row["key"], row["key"]))
                    self.assertEqual(test_row["unified_id"], row["unified_id"], 
                                     (test_row["unified_id"], row["unified_id"]))

        # check that these return empty
        for versioned in [True, False]:
            sql_text = version_for_key(_test_collection_id, 
                                       versioned=versioned, 
                                       key=_test_key,
                                       unified_id=_test_no_such_unified_id)

            # bind the same key the SQL was generated for; the loop variable
            # from the block above is unrelated (and undefined if that loop
            # found no rows)
            args = {"collection_id" : _test_collection_id,
                    "key"           : _test_key,
                    "unified_id"    : _test_no_such_unified_id} 

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            test_rows = cursor.fetchall()
            cursor.close()
            self.assertEqual(len(test_rows), 0, test_rows)

    #@unittest.skip("isolate test")
    def test_version_for_key_find_all_same_rows(self):
        """
        every row list_keys(versioned=False) returns must be found again by
        version_for_key(versioned=False) for the same key, and every row
        found must carry that key and unified_id
        """
        # XXX: this looks like it tests the same stuff as the previous
        # entry?
        log = logging.getLogger("test_version_for_key_find_all_same_rows")

        def run_query(query, query_args):
            # one query per fresh cursor, closed immediately afterwards
            cur = self._connection.cursor()
            cur.execute(query, query_args)
            found = cur.fetchall()
            cur.close()
            return found

        listed_rows = run_query(
            list_keys(_test_collection_id, 
                      versioned=False, 
                      prefix=_test_prefix),
            {"collection_id" : _test_collection_id,
             "prefix"        : _test_prefix, })

        for listed in listed_rows:
            wanted_key = listed["key"]
            matches = run_query(
                version_for_key(_test_collection_id, 
                                versioned=False, 
                                key=wanted_key),
                {"collection_id" : _test_collection_id,
                 "key"           : wanted_key, })

            self.assertTrue(len(matches) > 0)
            for match in matches:
                self.assertEqual(match["key"], wanted_key)
                self.assertEqual(match["unified_id"], listed["unified_id"])

    #@unittest.skip("isolate test")
    def test_list_versions_same_rows(self):
        """
        check that version_for_key(versioned=True), given an explicit
        unified_id, finds every row that list_versions(versioned=True)
        returns, and that every row it finds has that same key and unified_id
        """
        sql_text = list_versions(_test_collection_id, 
                                 versioned=True, 
                                 prefix=_test_prefix) 

        args = {"collection_id" : _test_collection_id,
                "prefix"        : _test_prefix, }

        # honour the module-wide debug switch like the other tests do,
        # instead of unconditionally writing to /tmp on every run
        if _write_debug_sql:
            with open("/tmp/debug.sql", "w") as debug_sql_file:
                debug_sql_file.write(mogrify(sql_text, args))

        cursor = self._connection.cursor()
        cursor.execute(sql_text, args)
        list_versions_rows = cursor.fetchall()
        cursor.close()

        for list_versions_row in list_versions_rows:
            sql_text = version_for_key(_test_collection_id, 
                                       versioned=True,
                                       key=list_versions_row["key"], 
                                       unified_id=list_versions_row["unified_id"]) 

            args = {"collection_id" : _test_collection_id,
                    "key"           : list_versions_row["key"], 
                    "unified_id"    : list_versions_row["unified_id"]}

            cursor = self._connection.cursor()
            cursor.execute(sql_text, args)
            version_for_key_rows = cursor.fetchall()
            cursor.close()

            # version_for_key may return several rows for one version
            # (conjoined parts); each must match the listed version
            self.assertTrue(len(version_for_key_rows) > 0, 
                            "{0} {1}".format(args, list_versions_row))
            for version_for_key_row in version_for_key_rows:
                self.assertEqual(version_for_key_row["key"],
                                 list_versions_row["key"])
                self.assertEqual(version_for_key_row["unified_id"],
                                 list_versions_row["unified_id"],
                                 list_versions_row)

    #@unittest.skip("isolate test")
    def test_list_keys_vs_list_versions(self):
        """
        check that version_for_key can ONLY find the rows list_versions
        returns with versioned=False IF they are also in the result that
        list_keys returns (i.e. some of them should be findable, some not.)
        """
        # Background: list_keys returns the newest version of every key.
        # list_versions returns every version of every key.
        # If a collection is unversioned, output from list_keys and
        # list_versions should find the same rows
        # (although the output from list_keys has an extra column.)
        # In other words, in a versioned collection, any version of a key
        # that isn't the newest version should be unreachable.
        #
        # Plan:
        # 1. get the full output from list_versions(_test_collection_id,
        #    versioned=True) and save it.
        # 2. get the full output from list_keys(_test_collection_id,
        #    versioned=True) and save it.
        # 3. compare the results to determine which keys are older versions.
        # 4. for each row, call version_for_key with specific unified_id and
        #    versioned arguments and verify finding (or correctly not
        #    finding) the result.
        #
        # So the rows that are in the output of list_versions but are NOT in
        # the output of list_keys should be rows that are older versions.
        # Rows are compared as (key, unified_id) tuples, which discards the
        # extra list_keys column.
        # If we call version_for_key on rows that are only in list_versions
        # with versioned=False and specify their unified_id, they should not
        # be reachable; with versioned=True they should be reachable.

        def _fetchall(sql_text, args):
            # Close the cursor even when execute/fetchall raises, so a failing
            # query does not leave open cursors on the shared connection.
            cursor = self._connection.cursor()
            try:
                cursor.execute(sql_text, args)
                return cursor.fetchall()
            finally:
                cursor.close()

        def _key_set(rows):
            # (key, unified_id) is what identifies a row across both queries.
            return {(r["key"], r["unified_id"], ) for r in rows}

        def _version_for_key_rows(key, unified_id, versioned):
            # Look up one specific version of a key.
            return _fetchall(
                version_for_key(_test_collection_id,
                                versioned=versioned,
                                key=key,
                                unified_id=unified_id),
                {"collection_id" : _test_collection_id,
                 "key"           : key,
                 "unified_id"    : unified_id})

        prefix_args = {"collection_id" : _test_collection_id,
                       "prefix"        : _test_prefix, }

        # 1. full output of list_versions(versioned=True)
        list_versions_rows = _fetchall(
            list_versions(_test_collection_id,
                          versioned=True,
                          prefix=_test_prefix),
            prefix_args)
        list_versions_set = _key_set(list_versions_rows)

        # 2. full output of list_keys(versioned=True)
        list_keys_set = _key_set(_fetchall(
            list_keys(_test_collection_id,
                      versioned=True,
                      prefix=_test_prefix),
            prefix_args))

        # Rows that list_keys only returns when versioned=True; step 4 expects
        # these to be unreachable through version_for_key(versioned=False).
        unversioned_list_keys_set = _key_set(_fetchall(
            list_keys(_test_collection_id,
                      versioned=False,
                      prefix=_test_prefix),
            prefix_args))
        versioned_only_reachable_set = \
            list_keys_set - unversioned_list_keys_set

        # 3. rows in list_versions but not in list_keys are older versions
        older_version_set = list_versions_set - list_keys_set

        # NOTE(review): given how the sets above are built, these checks hold
        # by construction. The substantive assumption (every older row has a
        # newer version of the same key in list_keys) is not asserted here.
        for list_versions_row in list_versions_rows:
            test_tuple = (list_versions_row["key"],
                          list_versions_row["unified_id"], )
            self.assertIn(test_tuple, list_versions_set)
            if test_tuple in list_keys_set:
                self.assertNotIn(test_tuple, older_version_set)
            else:
                self.assertIn(test_tuple, older_version_set)

        # 4a. Older versions must be unreachable with versioned=False and
        #     reachable with versioned=True.
        for key, unified_id in older_version_set:
            for versioned in [False, True, ]:
                test_rows = _version_for_key_rows(key, unified_id, versioned)
                message = "versioned={0} key={1!r} unified_id={2!r}".format(
                    versioned, key, unified_id)
                if not versioned:
                    self.assertEqual(len(test_rows), 0, message)
                else:
                    self.assertGreater(len(test_rows), 0, message)

        # 4b. Rows returned by list_keys must always be reachable with
        #     versioned=True. With versioned=False they are reachable only
        #     if they are not in versioned_only_reachable_set.
        for key, unified_id in list_keys_set:
            for versioned in [False, True, ]:
                test_rows = _version_for_key_rows(key, unified_id, versioned)
                message = "versioned={0} key={1!r} unified_id={2!r}".format(
                    versioned, key, unified_id)
                if (not versioned
                    and (key, unified_id, ) in versioned_only_reachable_set
                ):
                    self.assertEqual(len(test_rows), 0, message)
                else:
                    self.assertGreater(len(test_rows), 0, message)
コード例 #22
0
ファイル: proto.py プロジェクト: jMyles/records
 def __init__(self, conn_str):
     """Open a dict-row psycopg2 connection and enable hstore on it.

     :param conn_str: connection string handed to ``RealDictConnection``.

     The connection is kept on ``self.db``. ``register_hstore`` is then
     called on it so that hstore values are adapted to and from Python dicts.
     """
     connection = RealDictConnection(conn_str)
     self.db = connection
     register_hstore(connection)
コード例 #23
0
ファイル: update_feeds.py プロジェクト: bchandos/rss_feed
from datetime import datetime, timedelta
import re
from urllib.request import urlopen, Request
from psycopg2.extras import RealDictConnection
import os
import time
import xml.etree.ElementTree as ET

from dateutil.parser import parse

# Shared database connection for this module; DATABASE_URL is required.
db = RealDictConnection(os.environ['DATABASE_URL'])

# FLASK_ENV is optional (newer Flask releases deprecate it). Reading it with
# .get() means an unset variable selects the production timings instead of
# raising KeyError at import time.
if os.environ.get('FLASK_ENV') == 'development':
    # WAIT_MINUTES may be overridden through the environment.
    WAIT_MINUTES = int(os.environ.get('WAIT_MINUTES', 2))
    # NOTE(review): presumably a polling/sleep interval; confirm with callers.
    WAIT_SECONDS = 6
else:
    WAIT_MINUTES = int(os.environ.get('WAIT_MINUTES', 15))
    WAIT_SECONDS = 60


def download_feed(feed_url):
    """Fetch *feed_url* and parse the response body as XML.

    Returns the root ``xml.etree.ElementTree.Element`` when the server
    answers with status 200 and a Content-Type containing "xml" (matched
    case-insensitively). Otherwise returns None.

    Any error while building the request or opening the URL is treated as
    "no feed" and also yields None. Errors raised while reading or parsing
    the body (e.g. ``ET.ParseError`` for malformed XML) propagate.
    """
    try:
        req = Request(feed_url, headers={'User-Agent': 'Mozilla/5.0'})
        f = urlopen(req)
    except Exception:
        # Best-effort fetch: an invalid or unreachable feed yields None.
        # Unlike the previous bare ``except:``, KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        return None

    with f:
        # Check the status first, as before; this also avoids calling
        # getheader() on responses that never reached HTTP.
        if f.getcode() != 200:
            return None
        # getheader() returns None when the header is absent. Treat that as
        # "not a feed" rather than crashing with TypeError. Media types are
        # case-insensitive, so compare in lower case.
        content_type = f.getheader('Content-Type') or ''
        if 'xml' not in content_type.lower():
            return None
        # NOTE(review): xml.etree is not hardened against maliciously
        # constructed XML; consider defusedxml if feed sources are untrusted.
        return ET.fromstring(f.read())
コード例 #24
0
import sys
sys.path.append("/fahstats/scripts/python")
from setup import connStr
from psycopg2.extras import RealDictConnection


def fetchsome(cursor, arraysize=50000):
    """Yield rows from *cursor* one at a time, fetching them in batches.

    Each batch comes from ``cursor.fetchmany(arraysize)``, so at most
    *arraysize* fetched rows are buffered at once. Iteration stops at the
    first empty (falsy) batch.
    """
    batch = cursor.fetchmany(arraysize)
    while batch:
        for row in batch:
            yield row
        batch = cursor.fetchmany(arraysize)


# Connect using the "backend" entry of setup.connStr. RealDictConnection
# cursors return each row as a dict keyed by column name.
db = RealDictConnection(connStr["backend"])
cursor = db.cursor()

# Select (serial, serial) pairs from select_two_serial_dates_x_days_last_batch(50).
# A pair is kept only when the datas.data of its d1_serial is no earlier than
# 8 weeks before the latest datas.data that has have_data set.
# NOTE(review): the columns are swapped on output (d1_serial AS d0,
# d0_serial AS d1). Presumably intentional; confirm with the code that
# consumes serial_dates.
query = """
select d1_serial as d0, d0_serial as d1
from select_two_serial_dates_x_days_last_batch(50) as tsd
inner join datas as d on d.data_serial = tsd.d1_serial
where d.data >= (
  select max(data) from datas where have_data)
  - '8 weeks'::interval
;"""
cursor.execute(query)
# All matching pairs, materialized at once (rows keyed by "d0" / "d1").
serial_dates = cursor.fetchall()

query = """
select distinct serial_date
from team_active_members_history