class Select(RelExpr):
    """Selection: the rows of the single input that satisfy ``cond``.

    The output schema is the input schema unchanged; only rows are filtered.
    """

    def __init__(self, cond, input):
        assert isinstance(cond, ValExpr)
        assert isinstance(input, RelExpr)
        super().__init__([input])
        self.cond = cond

    def __str__(self):
        source = self.inputs[0]
        return (literal(sym.SELECT) + literal(sym.ARG_L) + str(self.cond) +
                literal(sym.ARG_R) + ' ' + str(paren(source)))

    def validateSubtree(self, context: StatementContext):
        """Validate the input, then the condition, and set ``self.type``.

        Raises the error built by ``self.validationError`` when the
        condition's type is not ``ValType.BOOLEAN``.
        """
        source = self.inputs[0]
        source.validateSubtree(context)
        self.cond.validateSubtree(context, self)
        cond_type = self.cond.type
        if cond_type != ValType.BOOLEAN:
            message = 'selection condition {} has type {}; boolean expected'
            raise self.validationError(message.format(self.cond, cond_type.value))
        # selection keeps the input's attributes; only a fresh temp name is needed
        self.type = RelType(context.new_tmp(), source.type.attrs)

    def info(self):
        """Yield lines of a text tree: this node, its condition, then the input."""
        yield '{} => {}'.format(symbolic(sym.SELECT), self.type)
        yield '|' + self.cond.info()
        for idx, text in enumerate(self.inputs[0].info()):
            # first child line gets the branch marker; later lines are padded
            # to the same width so the subtree stays aligned
            yield ('\\_' if idx == 0 else '  ') + text

    def sql(self):
        """Yield SQL blocks: the input's blocks first, then this node's own."""
        source = self.inputs[0]
        yield from source.sql()
        template = '{}({}) AS (SELECT * FROM {} WHERE {})'
        yield template.format(self.type.sql_rel(),
                              ', '.join(self.type.sql_attrs()),
                              source.type.sql_rel(),
                              self.cond.sql(self))
class Cross(RelExpr):
    """Cartesian product of two inputs.

    The output schema is the left input's attributes followed by the right's.
    """

    def __init__(self, left, right):
        assert isinstance(left, RelExpr)
        assert isinstance(right, RelExpr)
        super().__init__([left, right])

    def __str__(self):
        lhs, rhs = self.inputs[0], self.inputs[1]
        return '{} {} {}'.format(paren(lhs), literal(sym.CROSS), paren(rhs))

    def validateSubtree(self, context: StatementContext):
        """Validate both inputs and set ``self.type``.

        Attribute pairs that ``can_be_confused_with`` each other only produce
        a logged warning; they do not fail validation.
        """
        lhs, rhs = self.inputs[0], self.inputs[1]
        lhs.validateSubtree(context)
        rhs.validateSubtree(context)
        for left_attr in lhs.type.attrs:
            for right_attr in rhs.type.attrs:
                if not left_attr.can_be_confused_with(right_attr):
                    continue
                logger.warning(
                    '{}: attributes {} from the left input and {} from the right'
                    ' become confused in the cross product output'.format(
                        self, left_attr.str_ref_only(), right_attr.str_ref_only()))
        self.type = RelType(context.new_tmp(), lhs.type.attrs + rhs.type.attrs)

    def info(self):
        """Yield lines of a text tree: this node, then left and right subtrees."""
        yield '{} => {}'.format(symbolic(sym.CROSS), self.type)
        # the left subtree carries a '|' rail so the right subtree's branch
        # below it still visually connects to this node
        for child, rail in ((self.inputs[0], '| '), (self.inputs[1], '  ')):
            for idx, text in enumerate(child.info()):
                yield ('\\_' if idx == 0 else rail) + text

    def sql(self):
        """Yield SQL blocks: left input's, right input's, then this node's."""
        lhs, rhs = self.inputs[0], self.inputs[1]
        yield from lhs.sql()
        yield from rhs.sql()
        template = '{}({}) AS (SELECT * FROM {}, {})'
        yield template.format(self.type.sql_rel(),
                              ', '.join(self.type.sql_attrs()),
                              lhs.type.sql_rel(),
                              rhs.type.sql_rel())
class Project(RelExpr):
    """Projection of the single input onto the expressions in ``attrs``.

    The generated SQL uses ``SELECT DISTINCT``, so duplicate rows are removed.
    """

    def __init__(self, attrs, input):
        assert isinstance(attrs, list)
        assert all(isinstance(attr, ValExpr) for attr in attrs)
        assert isinstance(input, RelExpr)
        super().__init__([input])
        self.attrs = attrs

    def __str__(self):
        listing = ', '.join(str(expr) for expr in self.attrs)
        return (literal(sym.PROJECT) + literal(sym.ARG_L) + listing +
                literal(sym.ARG_R) + ' ' + str(paren(self.inputs[0])))

    def validateSubtree(self, context: StatementContext):
        """Validate the input and each projected expression; set ``self.type``."""
        source = self.inputs[0]
        source.validateSubtree(context)
        specs = []
        for expr in self.attrs:
            expr.validateSubtree(context, self)
            if not isinstance(expr, AttrRef):
                # computed expression: anonymous, unqualified output column
                specs.append(AttrSpec(None, None, expr.type))
                continue
            # a bare attribute reference keeps the input's spec in the output
            _, position = expr.internal_ref
            specs.append(source.type.attrs[position])
        self.type = RelType(context.new_tmp(), specs)

    def info(self):
        """Yield lines of a text tree: this node, its expressions, then the input."""
        yield '{} => {}'.format(symbolic(sym.PROJECT), self.type)
        yield from ('|' + expr.info() for expr in self.attrs)
        for idx, text in enumerate(self.inputs[0].info()):
            yield ('\\_' if idx == 0 else '  ') + text

    def sql(self):
        """Yield SQL blocks: the input's blocks first, then this node's own."""
        source = self.inputs[0]
        yield from source.sql()
        columns = ', '.join(expr.sql(self) for expr in self.attrs)
        yield '{}({}) AS (SELECT DISTINCT {} FROM {})'.format(
            self.type.sql_rel(), ', '.join(self.type.sql_attrs()),
            columns, source.type.sql_rel())
class Aggr(RelExpr):
    """Grouping and aggregation over a single input.

    ``groupbys`` are the grouping expressions (possibly empty) and ``aggrs``
    the expressions computed per group. The output schema lists the
    group-by outputs first, followed by the aggregate outputs.
    """

    def __init__(self, groupbys, aggrs, input):
        assert isinstance(groupbys, list)
        assert all(isinstance(attr, ValExpr) for attr in groupbys)
        assert isinstance(aggrs, list)
        assert all(isinstance(attr, ValExpr) for attr in aggrs)
        assert isinstance(input, RelExpr)
        super(Aggr, self).__init__([input])
        self.groupbys = groupbys
        self.aggrs = aggrs

    def __str__(self):
        return literal(sym.AGGR) + literal(sym.ARG_L) +\
            ', '.join(str(attr) for attr in self.groupbys) +\
            (': ' if len(self.groupbys) > 0 else '') +\
            ', '.join(str(attr) for attr in self.aggrs) +\
            literal(sym.ARG_R) + ' ' + str(paren(self.inputs[0]))

    def _output_attrspec(self, expr):
        """Return the output AttrSpec for an already-validated expression."""
        if isinstance(expr, AttrRef):
            # a bare attribute reference keeps the input's spec in the output;
            # it must be resolved through *this* expression's own reference
            _, aidx = expr.internal_ref
            return self.inputs[0].type.attrs[aidx]
        return AttrSpec(None, None, expr.type)

    def validateSubtree(self, context: StatementContext):
        """Validate the input, group-bys and aggregates; set ``self.type``.

        Aggregates are validated with ``allow_aggr=True`` and then checked
        with ``checkSubtreeGroupInvariant``; errors from either propagate.
        """
        self.inputs[0].validateSubtree(context)
        output_attrspecs = list()
        for groupby in self.groupbys:
            groupby.validateSubtree(context, self)
            output_attrspecs.append(self._output_attrspec(groupby))
        for aggr in self.aggrs:
            aggr.validateSubtree(context, self, allow_aggr=True)
            aggr.checkSubtreeGroupInvariant(self)
            # a bare AttrRef among the aggregates is rare, but technically
            # possible; the helper resolves it via aggr (not a group-by)
            output_attrspecs.append(self._output_attrspec(aggr))
        self.type = RelType(context.new_tmp(), output_attrspecs)

    def info(self):
        """Yield lines of a text tree: node, group-bys, aggregates, input."""
        yield '{} => {}'.format(symbolic(sym.AGGR), self.type)
        for groupby in self.groupbys:
            yield '|group by: ' + groupby.info()
        for aggr in self.aggrs:
            yield '|' + aggr.info()
        for i, line in enumerate(self.inputs[0].info()):
            yield ('\\_' if i == 0 else '  ') + line

    def sql(self):
        """Yield SQL blocks: the input's blocks first, then this node's own."""
        yield from self.inputs[0].sql()
        # one comma-separated select list, so no dangling ', ' can appear
        # when either the group-by list or the aggregate list is empty
        select_list = ', '.join(
            attr.sql(self) for attr in self.groupbys + self.aggrs)
        group_by = ''
        if len(self.groupbys) > 0:
            group_by = ' GROUP BY ' + ', '.join(
                attr.sql(self) for attr in self.groupbys)
        yield '{}({}) AS (SELECT {} FROM {}{})'.format(
            self.type.sql_rel(), ', '.join(self.type.sql_attrs()),
            select_list, self.inputs[0].type.sql_rel(), group_by)
class RelRef(RelExpr):
    """Reference to a named relation: a DBMS table or a session view."""

    def __init__(self, rel):
        super().__init__()
        assert isinstance(rel, str)
        self.rel = rel

    def __str__(self):
        return self.rel

    def validateSubtree(self, context: StatementContext):
        """Resolve ``self.rel`` and set ``self.type`` and ``self.view``.

        A table in the DBMS takes precedence over a view of the same name.
        Raises the error built by ``self.validationError`` if neither exists.
        """
        if context.db.table_exists(self.rel):
            specs = [AttrSpec(self.rel, name, attr_type)
                     for name, attr_type in context.db.describe(self.rel)]
            self.type = RelType(context.new_tmp(), specs)
            self.view = None
            return
        view_def = context.views.raw_def(self.rel)
        if view_def is None:
            raise self.validationError('relation {} does not exist'.format(
                self.rel))
        self.view = RelExpr.from_view_def(view_def)
        self.view.validateSubtree(context)
        # re-qualify the view's attributes with the name used here
        specs = [AttrSpec(self.rel, spec.name, spec.type)
                 for spec in self.view.type.attrs]
        # reuse the view's temp relation instead of allocating a new one
        self.type = RelType(self.view.type.tmp, specs)

    def info(self):
        """Yield lines of a text tree; a view also expands its definition."""
        if self.view is None:
            yield 'RELATION {} => {}'.format(self, self.type)
            return
        yield 'VIEW {} => {}'.format(self, self.type)
        for idx, text in enumerate(self.view.info()):
            yield ('\\_' if idx == 0 else '  ') + text

    def sql(self):
        """Yield SQL blocks: a copy of the table, or the view's own blocks."""
        if self.view is not None:
            yield from self.view.sql()
            return
        yield '{}({}) AS (SELECT * FROM {})'.format(
            self.type.sql_rel(), ', '.join(self.type.sql_attrs()), self.rel)
class SetOp(RelExpr):
    """Common base of the binary set operations Union, Diff and Intersect.

    Both inputs must have the same number of attributes with pairwise equal
    types; the output reuses the left input's attributes.
    """

    def __init__(self, left, right):
        assert isinstance(left, RelExpr)
        assert isinstance(right, RelExpr)
        super().__init__([left, right])

    def op(self):
        """Return the RA symbol for this operation's concrete subclass."""
        # table is built per call: the subclasses are defined after SetOp
        for kind, symbol in ((Union, sym.UNION),
                             (Diff, sym.DIFF),
                             (Intersect, sym.INTERSECT)):
            if isinstance(self, kind):
                return symbol
        assert False

    def sql_op(self):
        """Return the SQL compound-select keyword for this operation."""
        for kind, keyword in ((Union, 'UNION'),
                              (Diff, 'EXCEPT'),
                              (Intersect, 'INTERSECT')):
            if isinstance(self, kind):
                return keyword
        assert False

    def __str__(self):
        lhs, rhs = self.inputs[0], self.inputs[1]
        return '{} {} {}'.format(paren(lhs), literal(self.op()), paren(rhs))

    def validateSubtree(self, context: StatementContext):
        """Validate both inputs for union-compatibility and set ``self.type``.

        Raises on an arity or positional type mismatch; a positional name
        mismatch only logs a warning.
        """
        lhs, rhs = self.inputs[0], self.inputs[1]
        lhs.validateSubtree(context)
        rhs.validateSubtree(context)
        left_attrs, right_attrs = lhs.type.attrs, rhs.type.attrs
        if len(left_attrs) != len(right_attrs):
            raise self.validationError(
                'input relations to a set operation'
                ' do not have the same number of attributes')
        for pos, (a0, a1) in enumerate(zip(left_attrs, right_attrs)):
            if a0.type != a1.type:
                raise self.validationError(
                    'input attributes at position {} have'
                    ' different types: {} vs. {}'.format(pos, a0, a1))
            if a0.name != a1.name:
                logger.warning(
                    '{}: input attributes at position {} have'
                    ' different names: {} vs. {}'.format(
                        self, pos, a0.str_ref_only(), a1.str_ref_only()))
        self.type = RelType(context.new_tmp(), left_attrs)

    def info(self):
        """Yield lines of a text tree: this node, then left and right subtrees."""
        yield '{} => {}'.format(symbolic(self.op()), self.type)
        for child, rail in ((self.inputs[0], '| '), (self.inputs[1], '  ')):
            for idx, text in enumerate(child.info()):
                yield ('\\_' if idx == 0 else rail) + text

    def sql(self):
        """Yield SQL blocks: left input's, right input's, then this node's."""
        lhs, rhs = self.inputs[0], self.inputs[1]
        yield from lhs.sql()
        yield from rhs.sql()
        # interestingly, in the SQL below, parentheses around the two
        # SELECT subqueries are not needed, and SQLite in fact would
        # not like them:
        yield '{}({}) AS (SELECT * FROM {} {} SELECT * FROM {})'.format(
            self.type.sql_rel(), ', '.join(self.type.sql_attrs()),
            lhs.type.sql_rel(), self.sql_op(), rhs.type.sql_rel())
class Join(RelExpr):
    """Theta join (explicit ``cond``) or natural join (``cond`` is None).

    For a natural join, ``validateSubtree`` infers ``self.pairs``: a list of
    ``(left_index, right_index)`` attribute positions whose names match and
    which are equated by the join. Matched right-side attributes are dropped
    from the output; all left-side attributes are kept.
    """

    def __init__(self, left, cond, right):
        assert isinstance(left, RelExpr)
        assert cond is None or isinstance(cond, ValExpr)
        assert isinstance(right, RelExpr)
        super(Join, self).__init__([left, right])
        self.cond = cond

    def __str__(self):
        cond = '' if self.cond is None else\
            literal(sym.ARG_L) + str(self.cond) + literal(sym.ARG_R)
        return '{} {}{} {}'.format(paren(self.inputs[0]), literal(sym.JOIN),
                                   cond, paren(self.inputs[1]))

    def validateSubtree(self, context: StatementContext):
        """Validate both inputs and the join (condition or inferred pairs)."""
        self.inputs[0].validateSubtree(context)
        self.inputs[1].validateSubtree(context)
        if self.cond is None:
            self._validate_natural(context)
        else:
            self._validate_theta(context)

    def _validate_natural(self, context):
        """Infer ``self.pairs`` for a natural join and set ``self.type``."""
        left_attrs = self.inputs[0].type.attrs
        right_attrs = self.inputs[1].type.attrs
        self.pairs = list()
        for i0, a0 in enumerate(left_attrs):
            if a0.name is None:
                # unnamed (computed) attributes cannot match by name
                continue
            matches = [(i1, a1) for i1, a1 in enumerate(right_attrs)
                       if a0.name == a1.name]
            if len(matches) > 1:
                raise self.validationError(
                    'ambiguity in natural join: {} from the left input'
                    ' matches multiple attributes on the right'
                    .format(a0.str_ref_only()))
            elif len(matches) == 1:
                i1, a1 = matches[0]
                try:
                    context.check.function_call(symbolic(sym.EQ),
                                                [a0.type, a1.type])
                except TypeSysError as e:
                    raise self.validationError(
                        'natural join cannot equate {} and {}'
                        .format(a0, a1)) from e
                self.pairs.append((i0, i1))
        if len(self.pairs) == 0:
            logger.warning(
                '{}: no attributes with matching names found;'
                ' natural join degenerates into cross product'.format(self))
        # each left attribute appears in at most one pair (checked above);
        # now make sure no right attribute is claimed by two left attributes
        for i0, i1 in self.pairs:
            if any(j0 != i0 and j1 == i1 for j0, j1 in self.pairs):
                raise self.validationError(
                    'ambiguity in natural join: {} from the right input'
                    ' matches multiple attributes on the left'
                    .format(right_attrs[i1].str_ref_only()))
        matched_right = {i1 for _, i1 in self.pairs}
        attrspecs = list(left_attrs)
        for i1, a1 in enumerate(right_attrs):
            if i1 in matched_right:
                continue
            attrspecs.append(a1)
            if any(a1.can_be_confused_with(a0) for a0 in left_attrs):
                # this shouldn't happen under natural join rule, but oh well:
                logger.warning('{}: attribute {} becomes confused with others'
                               ' in the join output'
                               .format(self, a1.str_ref_only()))
        self.type = RelType(context.new_tmp(), attrspecs)

    def _validate_theta(self, context):
        """Validate the explicit join condition and set ``self.type``."""
        self.cond.validateSubtree(context, self)
        if self.cond.type != ValType.BOOLEAN:
            raise self.validationError(
                'join condition {} has type {}; boolean expected'
                .format(self.cond, self.cond.type.value))
        left_attrs = self.inputs[0].type.attrs
        right_attrs = self.inputs[1].type.attrs
        for a0, a1 in itertools.product(left_attrs, right_attrs):
            if a0.can_be_confused_with(a1):
                logger.warning(
                    '{}: attributes {} from the left input and {} from the right'
                    ' become confused in the join output'
                    .format(self, a0.str_ref_only(), a1.str_ref_only()))
        self.type = RelType(context.new_tmp(), left_attrs + right_attrs)

    def info(self):
        """Yield lines of a text tree: node, condition/pairs, then subtrees."""
        yield '{} => {}'.format(symbolic(sym.JOIN), self.type)
        if self.cond is not None:
            yield '|' + self.cond.info()
        elif len(self.pairs) == 0:
            yield '|inferred: cross product with no join condition'
        else:
            for i0, i1 in self.pairs:
                yield '|inferred: {}[0.{}] = {}[1.{}]'.format(
                    self.inputs[0].type.attrs[i0].str_ref_only(), i0,
                    self.inputs[1].type.attrs[i1].str_ref_only(), i1)
        for i, line in enumerate(self.inputs[0].info()):
            yield ('\\_' if i == 0 else '| ') + line
        for i, line in enumerate(self.inputs[1].info()):
            yield ('\\_' if i == 0 else '  ') + line

    def sql(self):
        """Yield SQL blocks: left input's, right input's, then this node's."""
        lhs, rhs = self.inputs[0], self.inputs[1]
        yield from lhs.sql()
        yield from rhs.sql()
        select = '*'
        where = ''
        if self.cond is not None:
            where = ' WHERE {}'.format(self.cond.sql(self))
        elif len(self.pairs) > 0:
            left_rel = lhs.type.sql_rel()
            right_rel = rhs.type.sql_rel()
            matched_right = {i1 for _, i1 in self.pairs}
            # explicit column list: all left columns, unmatched right columns
            attrs = ['{}.{}'.format(left_rel, attr)
                     for attr in lhs.type.sql_attrs()]
            attrs += ['{}.{}'.format(right_rel, attr)
                      for i, attr in enumerate(rhs.type.sql_attrs())
                      if i not in matched_right]
            select = ', '.join(attrs)
            eqs = ['{}.{} = {}.{}'.format(left_rel, lhs.type.sql_attr(i0),
                                          right_rel, rhs.type.sql_attr(i1))
                   for i0, i1 in self.pairs]
            where = ' WHERE {}'.format(' AND '.join(eqs))
        yield '{}({}) AS (SELECT {} FROM {}, {}{})'.format(
            self.type.sql_rel(), ', '.join(self.type.sql_attrs()),
            select, lhs.type.sql_rel(), rhs.type.sql_rel(), where)