def check_crosstable_dotprod(self, other_table, col1, join1, col2, join2=None, constraint={}):
    """
    Check that col1 is the sum of the product of the values in the columns of col2
    over rows of other_table with self.table.join1 = other_table.join2.

    There are some peculiarities of this method, resulting from its application
    to mf_subspaces.  col1 is allowed to be a pair (a list or a tuple of two
    column names), in which case the difference col1[0] - col1[1] will be compared.

    col2 does not take value col1 as a default, since they are playing different roles.

    INPUT:

    - ``other_table`` -- the name of the other table
    - ``col1`` -- a column of this table, or a pair of columns of this table
    - ``join1`` -- passed to ``_run_crosstable`` as the join column(s) on this table
    - ``col2`` -- an iterable of column names of ``other_table``; their product is summed
    - ``join2`` -- passed to ``_run_crosstable`` as the join column(s) on ``other_table``
    - ``constraint`` -- passed unchanged to ``_run_crosstable``

    OUTPUT:

    The result of ``_run_crosstable``.

    Raises ``ValueError`` if ``col1`` is a list or tuple whose length is not 2.
    """
    # Accept tuples as well as lists: a tuple used to be treated as a single column name.
    if isinstance(col1, (list, tuple)):
        if len(col1) != 2:
            raise ValueError("col1 must have length 2")
        col1 = SQL("t1.{0} - t1.{1}").format(Identifier(col1[0]), Identifier(col1[1]))
    factors = [SQL("t2.{0}").format(Identifier(col)) for col in col2]
    dotprod = SQL("SUM({0})").format(SQL(" * ").join(factors))
    return self._run_crosstable(dotprod, other_table, col1, join1, join2, constraint)
def check_string_concatenation(self, label_col, other_columns, constraint={}, sep='.', convert_to_base26={}):
    """
    Check that the label_column is the concatenation of the other columns
    with the given separator

    Input:

    - ``label_col`` -- the label_column
    - ``other_columns`` -- the other columns from which we can deduce the label
    - ``constraint`` -- a dictionary, as passed to the search method
    - ``sep`` -- the separator for the join
    - ``convert_to_base26`` -- a dictionary where the keys are columns that we
      need to convert to base26, and the values are the shifts to add to the
      column before the SQL function ``to_base26`` is applied

    Output:

    The result of ``_run_query`` for the condition
    "concatenation != label_col".
    """
    pieces = []
    for col in other_columns:
        # put the separator between consecutive terms, never before the first one
        if pieces:
            pieces.append(Literal(sep))
        if col in convert_to_base26:
            term = SQL('to_base26({0} + {1})').format(
                Identifier(col), Literal(int(convert_to_base26[col])))
        else:
            term = Identifier(col)
        pieces.append(term)
    concatenated = SQL(" || ").join(pieces)
    mismatch = SQL(" != ").join([concatenated, Identifier(label_col)])
    return self._run_query(mismatch, constraint)
def check_crosstable_aggregate(self, other_table, col1, join1, col2=None, join2=None, sort=None, truncate=None, constraint={}):
    """
    Check that col1 is the sorted array of values in col2 where join1 = join2

    Here col2 and join2 default to col1 and join1,
    and join1 and join2 are allowed to be lists of columns

    sort defaults to col2, but can be a list of columns in other_table

    If truncate is given, col1 is replaced by the slice ``t1.col1[:truncate]``
    before the comparison.

    The comparison itself is delegated to ``_run_crosstable`` with an ``ARRAY``
    subselect wrapper and the ``ORDER BY`` clause as extra.
    """
    # col2 has to pick up its default before col1 may be turned into a slice expression
    if col2 is None:
        col2 = col1
    if truncate is not None:
        sliced = "t1.{0}[:%s]" % (int(truncate))
        col1 = SQL(sliced).format(Identifier(col1))
    if sort is not None:
        sort_cols = [SQL("t2.{0}").format(Identifier(col)) for col in sort]
        order_by = SQL(" ORDER BY {0}").format(SQL(", ").join(sort_cols))
    else:
        order_by = SQL(" ORDER BY t2.{0}").format(Identifier(col2))
    return self._run_crosstable(
        col2, other_table, col1, join1, join2, constraint,
        subselect_wrapper="ARRAY", extra=order_by)
def check_array_product(self, array_column, value_column, constraint={}):
    """
    Checks that prod(array_column) == value_column

    The product is taken by applying the SQL aggregate ``PROD`` to the unnested
    array; the failure condition is handed to ``_run_query`` together with
    ``constraint``.
    """
    template = SQL("(SELECT PROD(s) FROM UNNEST({0}) s) != {1}")
    mismatch = template.format(Identifier(array_column), Identifier(value_column))
    return self._run_query(mismatch, constraint)
def check_letter_code(self, index_column, letter_code_column, constraint={}):
    """
    Check that letter_code_column equals to_base26(index_column - 1)

    ``to_base26`` is an SQL function invoked inside the query.
    The failure condition is handed to ``_run_query`` together with ``constraint``.
    """
    letter_code = Identifier(letter_code_column)
    index = Identifier(index_column)
    mismatch = SQL("{0} != to_base26({1} - 1)").format(letter_code, index)
    return self._run_query(mismatch, constraint)
def _run_query(self, condition=None, constraint={}, values=None, table=None, query=None, ratio=1):
    """
    Run a query to check a condition.

    The number of returned failures will be limited by the ``_cur_limit`` attribute of this ``TableChecker``.
    If ``_cur_label`` is set, only that label will be checked.

    INPUT:

    - ``condition`` -- an SQL object giving a condition on the search table
    - ``constraint`` -- a dictionary, as passed to the search method, or an SQL object
    - ``values`` -- a list of values to fill in for ``%s`` in the condition.
      The list passed in is never modified; a private copy is extended instead.
    - ``table`` -- an SQL object or string giving the table to execute this query on.
      Defaults to the table for this TableChecker.
    - ``query`` -- an SQL object giving the whole query, leaving out only the
      ``_cur_label`` and ``_cur_limit`` parts.  Note that ``condition``,
      ``constraint``, ``table`` and ``ratio`` will be ignored if query is provided.
    - ``ratio`` -- the ratio of rows in the table to run this query on.

    OUTPUT:

    A list containing the first entry of every row returned by the query
    (the label column when the query is built here).
    """
    # Work on a copy: callers may hand in a shared list (for instance a mutable
    # default argument), and the label is appended below.
    values = [] if values is None else list(values)
    label_col = Identifier(self.label_col)
    if query is None:
        if table is None:
            table = self.table.search_table
        if isinstance(table, string_types):
            if ratio == 1:
                table = Identifier(table)
            else:
                table = SQL("{0} TABLESAMPLE SYSTEM({1})").format(
                    Identifier(table), Literal(ratio))

        # WARNING: the following is not safe from SQL injection, so be careful if you copy this code
        query = SQL("SELECT {0} FROM {1} WHERE {2}").format(
            label_col, table, condition)
        if not isinstance(constraint, Composable):
            constraint, cvalues = self.table._parse_dict(constraint)
            if constraint is not None:
                values = values + cvalues
        if constraint is not None:
            query = SQL("{0} AND {1}").format(query, constraint)
    if self._cur_label is not None:
        query = SQL("{0} AND {1} = %s").format(query, label_col)
        values.append(self._cur_label)
    query = SQL("{0} LIMIT %s").format(query)
    cur = db._execute(query, values + [self._cur_limit])
    return [rec[0] for rec in cur]
def check_array_len_col(self, array_column, len_column, constraint={}, shift=0, array_dim=1):
    """
    Length of array_column matches len_column

    More precisely, check that the length of ``array_column`` along dimension
    ``array_dim`` equals ``len_column + shift``.  The failure condition is
    handed to ``_run_query`` together with ``constraint``.
    """
    template = SQL("array_length({0}, {3}) != {1} + {2}")
    args = (
        Identifier(array_column),
        Identifier(len_column),
        Literal(int(shift)),
        Literal(array_dim),
    )
    return self._run_query(template.format(*args), constraint)
def check_array_len_gte_constant(self, column, limit, constraint={}):
    """
    Length of array greater than or equal to limit

    The failure condition is that the length along the first dimension is
    smaller than ``limit``; ``limit`` is supplied as a query parameter.
    """
    too_short = SQL("array_length({0}, 1) < %s").format(Identifier(column))
    return self._run_query(too_short, constraint, [limit])
def check_an_length(self):
    """
    check that an_normalized is a list of pairs of length at least 1000

    The failure condition is that the first dimension is shorter than 1000
    or the second dimension is not 2.
    """
    # TIME > 3600s
    an = Identifier("an_normalized")
    bad_shape = SQL("array_length({0}, 1) < 1000 OR array_length({0}, 2) != 2").format(an)
    return self._run_query(bad_shape)
def _make_sql(self, s, tablename=None):
    """
    Create an SQL Composable object out of s.

    INPUT:

    - ``s`` -- a string, integer or Composable object
    - ``tablename`` -- a tablename prepended to the resulting object if ``s`` is a string

    OUTPUT:

    A ``Literal`` for an integer, ``s`` itself if it is already Composable,
    and otherwise an ``Identifier`` for ``s``, prefixed by ``tablename.`` when
    ``tablename`` is given.
    """
    if isinstance(s, integer_types):
        return Literal(s)
    if isinstance(s, Composable):
        return s
    if tablename is None:
        return Identifier(s)
    prefix = SQL(tablename + ".")
    return prefix + Identifier(s)
def _run_crosstable(self, quantity, other_table, col, join1, join2=None, constraint={}, values=None, subselect_wrapper="", extra=None):
    """
    Checks that `quantity` matches col

    INPUT:

    - ``quantity`` -- a column name or an SQL object giving some quantity from the ``other_table``
    - ``other_table`` -- the name of the other table
    - ``col`` -- an integer or the name of column to check against ``quantity``
    - ``join1`` -- a column or list of columns on self on which we will join the two tables
    - ``join2`` -- a column or list of columns (default: `None`) on ``other_table``
      on which we will join the two tables. If `None`, we take ``join2`` = ``join1``, see `_make_join`
    - ``constraint`` -- a dictionary, as passed to the search method
    - ``values`` -- a list of values to fill in for ``%s`` in the query (default: `None`,
      meaning no values).  The list is copied, so the caller's list is never modified.
    - ``subselect_wrapper`` -- a string, e.g., "ARRAY" to convert the inner select query
    - ``extra`` -- SQL object to append to the subquery.  This can hold additional
      constraints or set the sort order for the inner select query

    OUTPUT:

    The result of ``_run_query`` on the table aliased as ``t1``.
    """
    # WARNING: since it uses _run_query, this whole function is not safe against SQL injection,
    # so should only be run locally in data validation

    # A fresh list on every call: a shared default list could be extended in
    # place further down the call chain and leak parameters into later queries.
    values = [] if values is None else list(values)
    join = self._make_join(join1, join2)
    col = self._make_sql(col, "t1")
    if isinstance(quantity, string_types):
        quantity = SQL("t2.{0}").format(Identifier(quantity))
    # This is unsafe
    subselect_wrapper = SQL(subselect_wrapper)
    if extra is None:
        extra = SQL("")
    condition = SQL(
        "{0} != {1}(SELECT {2} FROM {3} t2 WHERE {4}{5})").format(
            col, subselect_wrapper, quantity, Identifier(other_table), join, extra)
    aliased_table = SQL("{0} t1").format(Identifier(self.table.search_table))
    return self._run_query(condition, constraint, values, table=aliased_table)
def check_array_bound(self, array_column, bound, constraint={}, upper=True):
    """
    Check that all entries in the array are <= bound (or >= if upper is False)

    The failure condition is ``NOT (bound >= ALL(array))`` when ``upper`` is
    true and ``NOT (bound <= ALL(array))`` otherwise.
    """
    if upper:
        op = '>='
    else:
        op = '<='
    template = "NOT ({0} %s ALL({1}))" % op
    violated = SQL(template).format(Literal(bound), Identifier(array_column))
    return self._run_query(violated, constraint=constraint)
def check_array_len_eq_constant(self, column, limit, constraint={}, array_dim=1):
    """
    Length of array equal to constant

    Both ``array_dim`` and ``limit`` are converted with ``int`` before being
    placed in the query as literals.
    """
    template = SQL("array_length({0}, {1}) != {2}")
    wrong_length = template.format(
        Identifier(column), Literal(int(array_dim)), Literal(int(limit)))
    return self._run_query(wrong_length, constraint)
def check_crosstable_sum(self, other_table, col1, join1, col2=None, join2=None, constraint={}):
    """
    Check that col1 is the sum of the values in col2 where join1 = join2

    Here col2 and join2 default to col1 and join1,
    and join1 and join2 are allowed to be lists of columns

    The comparison itself is delegated to ``_run_crosstable``.
    """
    summed_col = col1 if col2 is None else col2
    total = SQL("SUM(t2.{0})").format(Identifier(summed_col))
    return self._run_crosstable(total, other_table, col1, join1, join2, constraint)
def check_roots(self):
    """
    check that embedding_root_real and embedding_root_imag are present in
    mf_hecke_cc whenever field_poly is present

    The query selects labels of rows with a non-NULL field_poly that are joined
    (on hecke_orbit_code) to an mf_hecke_cc row in which both embedding_root_real
    and embedding_root_imag are NULL.
    """
    # TIME > 240s
    # I didn't manage to write a generic one for this one
    join = self._make_join('hecke_orbit_code', None)
    template = SQL(
        "SELECT t1.{0} FROM {1} t1, {2} t2 "
        "WHERE {3} AND t2.{4} is NULL AND t2.{5} is NULL "
        "AND t1.{6} IS NOT NULL")
    query = template.format(
        Identifier(self.table._label_col),
        Identifier(self.table.search_table),
        Identifier('mf_hecke_cc'),
        join,
        Identifier("embedding_root_real"),
        Identifier("embedding_root_imag"),
        Identifier("field_poly"))
    return self._run_query(query=query)
def check_sub_mul_positive(self):
    """
    sub_mult is positive

    The failure condition is ``sub_mult <= 0``.
    """
    sub_mult = Identifier('sub_mult')
    not_positive = SQL("{0} <= 0").format(sub_mult)
    return self._run_query(not_positive)
def check_label_hoc(self):
    """
    check that label is consistent with hecke_orbit_code

    The comparison uses the SQL function
    ``from_newform_label_to_hecke_orbit_code`` applied to label.
    """
    template = SQL("{0} != from_newform_label_to_hecke_orbit_code({1})")
    inconsistent = template.format(Identifier('hecke_orbit_code'), Identifier('label'))
    return self._run_query(inconsistent)
def check_label_conrey(self):
    """
    check that label is consistent with conrey_index, embedding_index

    The fifth and sixth dot-separated pieces of label are compared with
    conrey_index and embedding_index cast to text.
    """
    # TIME about 230s
    template = SQL("(string_to_array({0},'.'))[5:6] != array[{1}::text,{2}::text]")
    inconsistent = template.format(
        Identifier('label'),
        Identifier('conrey_index'),
        Identifier('embedding_index'))
    return self._run_query(inconsistent)
def check_eq(self, col1, col2, constraint={}):
    """
    Check that col1 equals col2

    The failure condition ``col1 != col2`` is handed to ``_run_query``
    together with ``constraint``.
    """
    left = Identifier(col1)
    right = Identifier(col2)
    differ = SQL("{0} != {1}").format(left, right)
    return self._run_query(differ, constraint)
def check_sorted(self, column):
    """
    Check that column is sorted

    The column is compared with the result of the SQL function ``sort``
    applied to it.
    """
    unsorted = SQL("{0} != sort({0})").format(Identifier(column))
    return self._run_query(unsorted)
def check_string_startswith(self, col, head, constraint={}):
    """
    Check that the values of col start with head

    INPUT:

    - ``col`` -- the column to check
    - ``head`` -- the string every value should start with; it is matched
      literally (``_``, ``%`` and backslash have no special meaning)
    - ``constraint`` -- passed unchanged to ``_run_query``

    OUTPUT:

    The result of ``_run_query`` for the condition ``NOT (col LIKE pattern)``.
    """
    # Backslash is the LIKE escape character, so it has to be doubled first;
    # doing it after the other replacements would double the backslashes they add.
    escaped = head.replace('\\', '\\\\').replace('_', r'\_').replace('%', r'\%')
    value = escaped + '%'
    return self._run_query(SQL("NOT ({0} LIKE %s)").format(
        Identifier(col)), constraint=constraint, values=[value])