Exemplo n.º 1
0
    def SingleRuleSql(self,
                      rule,
                      allocator=None,
                      external_vocabulary=None,
                      is_combine=False):
        """Producing SQL for a given rule in the program.

        Args:
          rule: The rule to translate.
          allocator: Names allocator; a new one is created via
            self.NewNamesAllocator() when not given (or falsy).
          external_vocabulary: Passed through to
            rule_translate.ExtractRuleStructure.
          is_combine: If true, the rule is first decorated by the execution
            dialect's DecorateCombineRule with a freshly allocated variable.

        Returns:
          The SQL string for the rule. It is prefixed with '/* nil */' when one
          of the rule's tables refers to the 'nil' predicate, marking the rule
          for deletion.

        Raises:
          rule_translate.RuleCompileException: if SQL generation fails with a
            'maximum recursion' RuntimeError.
        """
        allocator = allocator or self.NewNamesAllocator()
        r = rule
        if is_combine:
            r = self.execution.dialect.DecorateCombineRule(
                r, allocator.AllocateVar())
        s = rule_translate.ExtractRuleStructure(r, allocator,
                                                external_vocabulary)

        # Eliminate what can be eliminated before injections; after
        # RunInjections the elimination is repeated and asserted to be full.
        s.ElliminateInternalVariables(assert_full_ellimination=False)

        self.RunInjections(s, allocator)
        s.ElliminateInternalVariables(assert_full_ellimination=True)
        s.UnificationsToConstraints()
        try:
            sql = s.AsSql(self.MakeSubqueryTranslator(allocator),
                          self.flag_values)
        except RuntimeError as runtime_error:
            # RecursionError is a RuntimeError subclass; it is recognized by
            # its message and reported as a compile error of this rule.
            if not str(runtime_error).startswith('maximum recursion'):
                # Bare raise re-raises the active exception unchanged.
                raise
            raise rule_translate.RuleCompileException(
                RecursionError(), s.full_rule_text) from runtime_error
        if 'nil' in s.tables.values():
            # Mark rule for deletion.
            sql = '/* nil */' + sql
        return sql
Exemplo n.º 2
0
    def RunInjections(self, s, allocator):
        """Repeatedly inlines (injects) single-rule predicates into structure s.

        On each pass, every table of s is examined. A table qualifies for
        injection when:
          - its predicate is defined by exactly one rule,
          - that rule has no 'distinct_denoted' key, and
          - self.annotations.OkInjection allows it.

        For a qualifying table, the rule's structure is extracted and its tables
        take the place of the injected table. The structure is merged into s via
        InjectStructure. Each variable s read from the injected table is then
        unified with the matching expression of the injected rule's select.
        Non-qualifying tables are kept as they are. Passes repeat until a pass
        leaves s.tables unchanged.

        Args:
          s: Rule structure to modify in place.
          allocator: Names allocator used when extracting the injected rules.

        Raises:
          rule_translate.RuleCompileException: if the number of passes exceeds
            sys.getrecursionlimit() (wrapping a RecursionError), or if the rule
            reads an argument that the injected predicate does not provide.
        """
        iterations = 0
        while True:
            iterations += 1
            # The interpreter recursion limit is reused as a cap on the number
            # of injection passes; presumably this stops injections that never
            # converge.
            if iterations > sys.getrecursionlimit():
                raise rule_translate.RuleCompileException(
                    RecursionError(), s.full_rule_text)

            new_tables = collections.OrderedDict()
            for table_name_rsql, table_predicate_rsql in s.tables.items():
                rules = list(self.GetPredicateRules(table_predicate_rsql))
                if (len(rules) == 1 and ('distinct_denoted' not in rules[0])
                        and
                        self.annotations.OkInjection(table_predicate_rsql)):
                    [r] = rules
                    rs = rule_translate.ExtractRuleStructure(
                        r, allocator, None)
                    rs.ElliminateInternalVariables(
                        assert_full_ellimination=False)
                    # The injected rule's own tables take the place of the
                    # injected table.
                    new_tables.update(rs.tables)
                    InjectStructure(s, rs)

                    # Rebuild the variable maps without entries that point to
                    # the injected table; each such clause variable is instead
                    # unified with what the injected rule selects.
                    new_vars_map = {}
                    new_inv_vars_map = {}
                    for (table_name,
                         table_var), clause_var in s.vars_map.items():
                        if table_name != table_name_rsql:
                            new_vars_map[table_name, table_var] = clause_var
                            new_inv_vars_map[clause_var] = (table_name,
                                                            table_var)
                        else:
                            if table_var not in rs.select:
                                if '*' in rs.select:
                                    # The injected rule selects '*' (a whole
                                    # record); read the field by subscripting
                                    # that record with the argument name.
                                    subscript = {
                                        'literal': {
                                            'the_symbol': {
                                                'symbol': table_var
                                            }
                                        }
                                    }
                                    s.vars_unification.append({
                                        'left': {
                                            'variable': {
                                                'var_name': clause_var
                                            }
                                        },
                                        'right': {
                                            'subscript': {
                                                'subscript': subscript,
                                                'record': rs.select['*']
                                            }
                                        }
                                    })
                                else:
                                    extra_hint = '' if table_var != '*' else (
                                        ' Are you using ..<rest of> for injectible predicate? '
                                        'Please list the fields that you extract explicitly. '
                                        'Tracking bug: b/131759583.')
                                    raise rule_translate.RuleCompileException(
                                        color.Format(
                                            'Predicate {warning}{table_predicate_rsql}{end} '
                                            'does not have an argument '
                                            '{warning}{table_var}{end}, but '
                                            'this rule tries to access it. {extra_hint}',
                                            dict(table_predicate_rsql=
                                                 table_predicate_rsql,
                                                 table_var=table_var,
                                                 extra_hint=extra_hint)),
                                        s.full_rule_text)
                            else:
                                # Bind the clause variable to the expression the
                                # injected rule selects for this argument.
                                s.vars_unification.append({
                                    'left': {
                                        'variable': {
                                            'var_name': clause_var
                                        }
                                    },
                                    'right':
                                    rs.select[table_var]
                                })
                    s.vars_map = new_vars_map
                    s.inv_vars_map = new_inv_vars_map
                else:
                    new_tables[table_name_rsql] = table_predicate_rsql
            # Fixed point reached: this pass injected nothing new.
            if s.tables == new_tables:
                break
            s.tables = new_tables
Exemplo n.º 3
0
    def FunctionSql(self, name, allocator=None, internal_mode=False):
        """Compiles a single-rule predicate into a SQL temporary function.

        Requirements on the predicate:
          - it must be defined by exactly one rule;
          - its arguments must be passed through unrenamed;
          - it must have a value ('logica_value');
          - after injections, it must have no tables, constraints or
            unnestings left.

        Args:
          name: Name of the predicate; also used as the SQL function name.
          allocator: Names allocator; a new one is created via
            self.NewNamesAllocator() when not given (or falsy).
          internal_mode: If true, also return a call template for the function.

        Returns:
          The formatted 'CREATE TEMP FUNCTION' statement. When internal_mode
          is true, a tuple (call_template, sql) is returned instead.
          call_template looks like 'name({arg1}, {arg2})', with one
          '{...}' placeholder per argument.

        Raises:
          rule_translate.RuleCompileException: if the predicate cannot be
            compiled as a function.
        """
        # TODO: Refactor this into FunctionSqlInternal and FunctionSql.
        if not allocator:
            allocator = self.NewNamesAllocator()

        rules = list(self.GetPredicateRules(name))

        # Check that the predicate is defined via a single rule.
        if not rules:
            raise rule_translate.RuleCompileException(
                color.Format(
                    'No rules are defining {warning}{name}{end}, but compilation '
                    'was requested.', dict(name=name)), r'        ¯\_(ツ)_/¯')
        elif len(rules) > 1:
            raise rule_translate.RuleCompileException(
                color.Format(
                    'Predicate {warning}{name}{end} is defined by more than 1 rule '
                    'and can not be compiled into a function.',
                    dict(name=name)),
                '\n\n'.join(r['full_text'] for r in rules))
        [rule] = rules

        # Extract structure and assert that it is isomorphic to a function.
        s = rule_translate.ExtractRuleStructure(rule,
                                                external_vocabulary=None,
                                                names_allocator=allocator)

        # Argument names for the call template, computed before positional
        # (integer) keys are renamed; integer keys become 'col<N>'.
        udf_variables = [
            v if isinstance(v, str) else 'col%d' % v for v in s.select
            if v != 'logica_value'
        ]
        s.select = self.TurnPositionalIntoNamed(s.select)

        variables = [v for v in s.select if v != 'logica_value']
        # NOTE(review): s.select was passed through TurnPositionalIntoNamed
        # above; confirm whether integer keys can still reach this check.
        if 0 in variables:
            raise rule_translate.RuleCompileException(
                color.Format(
                    'Predicate {warning}{name}{end} must have all arguments named for '
                    'compilation as a function.', dict(name=name)),
                rule['full_text'])
        # Each argument must be passed through as the variable of the same name.
        for v in variables:
            if ('variable' not in s.select[v]
                    or s.select[v]['variable']['var_name'] != v):
                raise rule_translate.RuleCompileException(
                    color.Format(
                        'Predicate {warning}{name}{end} must not rename arguments '
                        'for compilation as a function.', dict(name=name)),
                    rule['full_text'])

        vocabulary = {v: v for v in variables}
        s.external_vocabulary = vocabulary
        self.RunInjections(s, allocator)
        s.ElliminateInternalVariables(assert_full_ellimination=True)
        s.UnificationsToConstraints()
        # This SQL is only used in the error messages below; the function body
        # is produced separately from 'logica_value'.
        sql = s.AsSql(subquery_encoder=self.MakeSubqueryTranslator(allocator))
        if s.constraints or s.unnestings or s.tables:
            raise rule_translate.RuleCompileException(
                color.Format(
                    'Predicate {warning}{name}{end} is not a simple function, but '
                    'compilation as function was requested. Full SQL:\n{sql}',
                    dict(name=name, sql=sql)), rule['full_text'])
        if 'logica_value' not in s.select:
            # Pass name and sql as format arguments, like the error above.
            # Interpolating sql into the pattern would leave {name} unfilled,
            # and braces inside the SQL would break the format.
            raise rule_translate.RuleCompileException(
                color.Format(
                    'Predicate {warning}{name}{end} does not have a value, but '
                    'compilation as function was requested. Full SQL:\n{sql}',
                    dict(name=name, sql=sql)), rule['full_text'])

        # pylint: disable=g-long-lambda
        # Compile the function!
        ql = expr_translate.QL(
            vocabulary,
            self.MakeSubqueryTranslator(allocator),
            lambda message: rule_translate.RuleCompileException(
                message, rule['full_text']),
            self.flag_values,
            custom_udfs=self.custom_udfs,
            dialect=self.execution.dialect)
        value_sql = ql.ConvertToSql(s.select['logica_value'])

        sql = 'CREATE TEMP FUNCTION {name}({signature}) AS ({value})'.format(
            name=name,
            signature=', '.join('%s ANY TYPE' % v for v in variables),
            value=value_sql)

        sql = FormatSql(sql)

        if internal_mode:
            return ('%s(%s)' % (name, ', '.join('{%s}' % v
                                                for v in udf_variables)), sql)

        return sql