Ejemplo n.º 1
0
 def process(expr):
     """Rewrite any retrievals in the given expression. Return a pair
     of the new expression, and a list of new clauses to be added
     for any retrievals not already seen.
     """
     nonlocal seen_map
     new_expr = replacer.process(expr)
     new_field_repls = replacer.field_repls - seen_field_repls
     new_map_repls = replacer.map_repls - seen_map_repls
     new_clauses = []
     
     for repl in new_field_repls:
         obj, field, value = repl
         seen_fields.add(field)
         seen_field_repls.add(repl)
         new_cl = L.Enumerator(L.tuplify((obj, value), lval=True),
                               L.ln(make_frel(field)))
         new_clauses.append(new_cl)
     
     for repl in new_map_repls:
         map, key, value = repl
         seen_map = True
         seen_map_repls.add(repl)
         new_cl = L.Enumerator(L.tuplify((map, key, value), lval=True),
                               L.ln(make_maprel()))
         new_clauses.append(new_cl)
     
     return new_expr, new_clauses
Ejemplo n.º 2
0
 def visit_Aggregate(self, node):
     """Distribute a min/max aggregate over the operands of a set union.

     After visiting children, ``min(S1 | ... | Sn)`` with more than one
     operand becomes ``min2(T1, ..., Tn)`` (``max2`` for max). Operands
     that are comprehensions or names become aggregates with the
     original options, so 'impl' is carried over; any other operand
     (presumably a set literal) becomes ``min2``/``max2`` applied
     directly to its elements. All other aggregates are returned as-is.
     """
     node = self.generic_visit(node)

     if node.op not in ('min', 'max'):
         return node
     pairwise = 'min2' if node.op == 'min' else 'max2'

     if not L.is_setunion(node.value):
         return node
     operands = L.get_setunion(node.value)
     if len(operands) == 1:
         # A single operand needs no splitting.
         return node

     def wrap(operand):
         # Queries and named sets keep an Aggregate node; set literals
         # become a direct call on their elements.
         if isinstance(operand, (L.Comp, L.Name)):
             return L.Aggregate(operand, node.op, node.options)
         call = L.pe('OP(__ARGS)', subst={'OP': L.ln(pairwise)})
         return call._replace(args=operand.elts)

     parts = tuple(wrap(operand) for operand in operands)
     top = L.pe('OP(__ARGS)', subst={'OP': L.ln(pairwise)})
     return top._replace(args=parts)
Ejemplo n.º 3
0
 def test_enumclause_setmatch(self):
     # A clause whose iterator is a setmatch must still convert to an
     # EnumClause.
     lhs = L.tuplify(['y'], lval=True)
     setmatch = L.SetMatch(L.ln('R'), 'bu', L.ln('x'))
     clause = EnumClause.from_AST(L.Enumerator(lhs, setmatch), DummyFactory)
     self.assertEqual(clause, EnumClause(('x', 'y'), 'R'))
Ejemplo n.º 4
0
 def visit_SetUpdate(self, node):
     """Attach demand-propagation code to updates of tracked relations.

     Updates to relations outside ``self.at_rels`` are not rewritten
     (the method returns None). Otherwise, code placed after the update
     calls the demand function (for 'add') or the undemand function
     (for any other op) on every tuple of the delta set, then clears
     the delta; the update is wrapped via ``self.with_outer_maint``.
     """
     relname = L.get_name(node.target)
     if relname not in self.at_rels:
         return

     # The new code goes after the update, even for removals. Where the
     # U-set update happens does not matter as long as the invariants
     # hold at that point.

     prefix = self.manager.namegen.next_prefix()
     proj_vars = [prefix + v for v in self.projvars]

     name_maker = L.N.demfunc if node.op == 'add' else L.N.undemfunc
     funcname = name_maker(self.demname)

     arg_nodes = tuple(L.ln(v) for v in proj_vars)
     call_func = L.Call(L.ln(funcname), arg_nodes, (), None, None)

     subst = {'S_PROJVARS': L.tuplify(proj_vars, lval=True),
              'DELTA': self.delta_name,
              'CALL_FUNC': call_func}
     postcode = L.pc('''
         for S_PROJVARS in DELTA.elements():
             CALL_FUNC
         DELTA.clear()
         ''', subst=subst)

     return self.with_outer_maint(node, funcname, L.ts(node), (), postcode)
Ejemplo n.º 5
0
def flatten_set_clause(cl, input_rels):
    """Rewrite a membership clause as a clause over the M-set.

    Membership over a comprehension, a special relation, or a relation
    named in *input_rels* is left alone. Both enumerator clauses and
    condition clauses of the form ``item in cont`` are handled; a
    rewritten condition clause remains a condition clause.

    Returns ``(clause, changed)``, where *changed* is True exactly when
    the clause was rewritten.
    """
    def is_exempt(rhs):
        # Comprehensions, special relations and input relations stay.
        if isinstance(rhs, L.Comp):
            return True
        return (isinstance(rhs, L.Name) and
                (is_specialrel(rhs.id) or rhs.id in input_rels))

    # Enumerator: "item in cont" -> "(cont, item) in <M-set>".
    if isinstance(cl, L.Enumerator) and not is_exempt(cl.iter):
        container = L.ContextSetter.run(cl.iter, L.Store)
        target = L.tuplify((container, cl.target), lval=True)
        return L.Enumerator(target, L.ln(make_mrel())), True

    # Condition: "item in cont" -> "(cont, item) in <M-set>".
    if isinstance(cl, L.expr) and L.is_cmp(cl):
        item, op, container = L.get_cmp(cl)
        if isinstance(op, L.In) and not is_exempt(container):
            rewritten = L.cmp(L.tuplify((container, item)),
                              L.In(),
                              L.ln(make_mrel()))
            return rewritten, True

    return cl, False
Ejemplo n.º 6
0
    def get_res_code(self):
        """Return an expression that looks up this query's result.

        Without parameters this is just a load of the result set's name.
        With parameters it is ``setmatch(RES, MASK, PARAMS)``, where MASK
        binds the first ``len(params)`` components of the result tuple
        and leaves the remaining ones unbound.
        """
        params = self.inccomp.comp.params
        if len(params) == 0:
            return L.ln(self.inccomp.name)

        # The mask width comes from the result expression, which must be
        # a tuple for a parameterized query.
        resexp = self.inccomp.spec.resexp
        assert isinstance(resexp, L.Tuple)
        n_bound = len(params)
        n_unbound = len(resexp.elts) - n_bound
        mask_node = Mask('b' * n_bound + 'u' * n_unbound).make_node()

        return L.pe('''
                setmatch(RES, MASK, PARAMS)
                ''',
                    subst={
                        'RES': L.ln(self.inccomp.name),
                        'MASK': mask_node,
                        'PARAMS': L.tuplify(params)
                    })
Ejemplo n.º 7
0
    def visit_Aggregate(self, node):
        """Distribute a min/max aggregate over the operands of a set union.

        After visiting children, ``min(S1 | ... | Sn)`` with more than
        one operand becomes ``min2(T1, ..., Tn)`` (``max2`` for max).
        Comprehension or name operands become aggregates that keep the
        original options ('impl' is carried over); any other operand
        (presumably a set literal) becomes ``min2``/``max2`` over its
        elements. All other aggregates are returned unchanged.
        """
        node = self.generic_visit(node)

        if node.op not in ('min', 'max'):
            return node
        func2 = {'min': 'min2', 'max': 'max2'}[node.op]

        if not L.is_setunion(node.value):
            return node
        operands = L.get_setunion(node.value)
        if len(operands) == 1:
            # A single operand needs no splitting.
            return node

        terms = []
        for operand in operands:
            if isinstance(operand, (L.Comp, L.Name)):
                terms.append(L.Aggregate(operand, node.op, node.options))
                continue
            # Set literal: fold its elements directly with min2/max2.
            literal_call = L.pe('OP(__ARGS)', subst={'OP': L.ln(func2)})
            terms.append(literal_call._replace(args=operand.elts))

        outer = L.pe('OP(__ARGS)', subst={'OP': L.ln(func2)})
        return outer._replace(args=tuple(terms))
Ejemplo n.º 8
0
def make_vareq_cond(eqs):
    """Return an ``and`` of ``v1 == v2`` comparisons, one per pair in *eqs*.

    Each element of *eqs* is a pair of variable names. NOTE(review):
    with an empty *eqs* the resulting BoolOp has no operands; callers
    presumably only pass non-empty lists -- confirm.
    """
    comparisons = []
    for left, right in eqs:
        comparisons.append(L.cmpeq(L.ln(left), L.ln(right)))
    return L.BoolOp(L.And(), tuple(comparisons))
Ejemplo n.º 9
0
 def test_enumclause_setmatch(self):
     # Converting "y in setmatch(R, 'bu', x)" should yield the plain
     # clause over R with tuple ('x', 'y').
     expected = EnumClause(('x', 'y'), 'R')
     source_ast = L.Enumerator(
         L.tuplify(['y'], lval=True),
         L.SetMatch(L.ln('R'), 'bu', L.ln('x')))
     self.assertEqual(EnumClause.from_AST(source_ast, DummyFactory),
                      expected)
Ejemplo n.º 10
0
    def __init__(self, aggrop, rel, relmask, params, oper_demname,
                 oper_demparams):
        """Build the AST node for this aggregate and store it in self.node.

        The node starts as a load of *rel*. If *params* is non-empty it
        is wrapped in a ``SetMatch`` using *relmask*'s mask string and the
        parameter tuple. If *oper_demname* is not None it is further
        wrapped in a ``DemQuery`` over *oper_demparams*. Finally it is
        wrapped in an ``Aggregate`` with operator *aggrop* and no options.
        """
        # Validate the argument itself: self.aggrop is not assigned in
        # this method, so reading it here would rely on it having been
        # set somewhere else beforehand.
        assert aggrop in ['count', 'sum', 'min', 'max']

        # AST node representation.
        node = L.ln(rel)
        if len(params) > 0:
            node = L.SetMatch(node, relmask.make_node().s, L.tuplify(params))
        if oper_demname is not None:
            node = L.DemQuery(oper_demname, [L.ln(p) for p in oper_demparams],
                              node)
        node = L.Aggregate(node, aggrop, None)
        self.node = node
Ejemplo n.º 11
0
    def visit_Comp(self, node):
        """Replace ``self.comp`` with a call to its query function.

        Children are transformed first. A comprehension that compares
        unequal to ``self.comp`` is returned unchanged. Otherwise
        ``self.need_func`` is set and the comprehension becomes
        ``QFUN(p1, ..., pn)``: QFUN is the query-function name derived
        from ``self.name`` and the arguments are ``self.comp.params``.
        """
        node = self.generic_visit(node)
        if node != self.comp:
            return node

        self.need_func = True
        funcname = L.N.queryfunc(self.name)
        args = tuple(L.ln(param) for param in self.comp.params)
        template = L.pe('QFUN(__ARGS)', subst={'QFUN': L.ln(funcname)})
        return template._replace(args=args)
Ejemplo n.º 12
0
Archivo: aggr.py Proyecto: IncOQ/incoq
 def __init__(self, aggrop, rel, relmask, params,
              oper_demname, oper_demparams):
     """Build the AST node for this aggregate and store it in self.node.

     The node starts as a load of *rel*. If *params* is non-empty it is
     wrapped in a ``SetMatch`` using *relmask*'s mask string and the
     parameter tuple. If *oper_demname* is not None it is further
     wrapped in a ``DemQuery`` over *oper_demparams*. Finally it is
     wrapped in an ``Aggregate`` with operator *aggrop* and no options.
     """
     # Validate the argument itself: self.aggrop is not assigned in this
     # method, so reading it here would rely on it having been set
     # somewhere else beforehand.
     assert aggrop in ['count', 'sum', 'min', 'max']

     # AST node representation.
     node = L.ln(rel)
     if len(params) > 0:
         node = L.SetMatch(node, relmask.make_node().s,
                           L.tuplify(params))
     if oper_demname is not None:
         node = L.DemQuery(oper_demname,
                           [L.ln(p) for p in oper_demparams], node)
     node = L.Aggregate(node, aggrop, None)
     self.node = node
Ejemplo n.º 13
0
Archivo: aggr.py Proyecto: IncOQ/incoq
 def make_addu_maint(self, prefix):
     """Generate code for after an addition to U.

     With half-demand, the code only makes sure an entry exists: it
     looks up the parameter key with a None default and, if absent,
     stores the zero mapval. Otherwise it recomputes the mapval by
     iterating the operand relation's setmatch for the parameters,
     assigns it under the key, and then calls the operand's demand
     function if the operand itself uses demand.
     """
     incaggr = self.incaggr
     assert incaggr.has_demand
     spec = incaggr.spec
     mv_var = prefix + 'val'
     elemvar = prefix + 'elem'

     if incaggr.half_demand:
         # Half-demand: nothing to propagate to the operand.
         half_subst = {'A': L.ln(incaggr.name),
                       'S_MV': L.sn(mv_var),
                       'MV': L.ln(mv_var),
                       'MASK': incaggr.aggrmask.make_node(),
                       'KEY': L.tuplify(incaggr.params),
                       'ZERO': self.make_zero_mapval_expr(),
                       'PREFIX': L.Str(prefix)}
         return L.pc('''
             S_MV = A.smdeflookup(MASK, KEY, None)
             if MV is None:
                 A.smassignkey(MASK, KEY, ZERO, PREFIX)
             ''', subst=half_subst)

     update_code = self.make_update_state_code(
         L.sn(mv_var), L.ln(mv_var), 'add', L.ln(elemvar), prefix)

     # Demand the operand too, when it is demand-driven.
     if spec.has_oper_demand:
         demand_args = tuple(L.ln(v) for v in spec.oper_demparams)
         demand_call = L.Call(L.ln(L.N.demfunc(spec.oper_demname)),
                              demand_args, (), None, None)
         propagate_code = (L.Expr(demand_call),)
     else:
         propagate_code = ()

     subst = {'S_MV': L.sn(mv_var),
              'ZERO': self.make_zero_mapval_expr(),
              'S_ELEM': L.sn(elemvar),
              'R': spec.rel,
              'RELMASK': spec.relmask.make_node(),
              'PARAMS': L.tuplify(spec.params),
              '<c>UPDATE_MAPVAL': update_code,
              'A': L.ln(incaggr.name),
              'AGGRMASK': incaggr.aggrmask.make_node(),
              'KEY': L.tuplify(incaggr.params),
              'MV': L.ln(mv_var),
              'PREFIX': L.Str(prefix),
              '<c>PROPAGATE_DEMAND': propagate_code}
     return L.pc('''
         S_MV = ZERO
         for S_ELEM in setmatch(R, RELMASK, PARAMS):
             UPDATE_MAPVAL
         A.smassignkey(AGGRMASK, KEY, MV, PREFIX)
         PROPAGATE_DEMAND
         ''', subst=subst)
Ejemplo n.º 14
0
    def make_removeu_maint(self, prefix):
        """Generate code for before a removal from U.

        With half-demand, the entry for the parameter key is looked up
        and deleted only when its count projection equals 0. Otherwise
        the operand's undemand function is called first (if the operand
        uses demand) and the entry is then deleted unconditionally.
        """
        incaggr = self.incaggr
        assert incaggr.has_demand
        spec = incaggr.spec
        mv_var = prefix + 'val'

        if incaggr.half_demand:
            # Half-demand: no operand demand to retract; just drop the
            # entry once nothing is counted under it.
            return L.pc('''
                S_MV = A.smlookup(MASK, KEY)
                if COUNT == 0:
                    A.smdelkey(MASK, KEY, PREFIX)
                ''',
                        subst={
                            'S_MV': L.sn(mv_var),
                            'COUNT': self.mapval_proj_count(L.ln(mv_var)),
                            'A': incaggr.name,
                            'MASK': incaggr.aggrmask.make_node(),
                            'KEY': L.tuplify(incaggr.params),
                            'PREFIX': L.Str(prefix)
                        })

        propagate_code = ()
        if spec.has_oper_demand:
            undem_name = L.ln(L.N.undemfunc(spec.oper_demname))
            undem_args = tuple(L.ln(v) for v in spec.oper_demparams)
            undem_call = L.Call(undem_name, undem_args, (), None, None)
            propagate_code = (L.Expr(undem_call), )

        return L.pc('''
            PROPAGATE_DEMAND
            A.smdelkey(MASK, KEY, PREFIX)
            ''',
                    subst={
                        'A': incaggr.name,
                        'MASK': incaggr.aggrmask.make_node(),
                        'KEY': L.tuplify(incaggr.params),
                        'PREFIX': L.Str(prefix),
                        '<c>PROPAGATE_DEMAND': propagate_code
                    })
Ejemplo n.º 15
0
 def visit_Comp(self, node):
     """Swap ``self.comp`` for a call to the query function of ``self.name``.

     Children are visited first. A different comprehension is returned
     untouched; a match sets ``self.need_func`` and is replaced with
     ``QFUN(*self.comp.params)``.
     """
     node = self.generic_visit(node)
     if node != self.comp:
         # Some other comprehension; leave it alone.
         return node

     self.need_func = True
     qfun = L.ln(L.N.queryfunc(self.name))
     call = L.pe('QFUN(__ARGS)', subst={'QFUN': qfun})
     arg_nodes = [L.ln(p) for p in self.comp.params]
     return call._replace(args=tuple(arg_nodes))
Ejemplo n.º 16
0
 def visit_Name(self, node):
     """Expand a placeholder name back into a field or map retrieval.

     A name found in ``self.field_exps`` becomes ``obj.field`` and one
     found in ``self.map_exps`` becomes ``map[key]``. The expansion keeps
     the original node's context and is itself visited. Any other name
     is returned unchanged.
     """
     name = node.id
     if name in self.field_exps:
         owner, attr = self.field_exps[name]
         expansion = L.Attribute(L.ln(owner), attr, node.ctx)
     elif name in self.map_exps:
         container, key = self.map_exps[name]
         expansion = L.Subscript(L.ln(container), L.Index(L.ln(key)),
                                 node.ctx)
     else:
         return node
     return self.generic_visit(expansion)
Ejemplo n.º 17
0
Archivo: aggr.py Proyecto: IncOQ/incoq
 def make_update_state_code(self, state_snode, state_lnode,
                            op, val_node, prefix):
     """Return code updating a min/max tree state for an operand change.

     The state read from *state_lnode* is unpacked as ``(tree, _)``.
     *val_node* is inserted into the tree for ``op == 'add'`` or deleted
     for ``op == 'remove'``. *state_snode* is then assigned
     ``(tree, tree.__min__())`` or ``(tree, tree.__max__())``, depending
     on ``self.kind``. The tree variable is named ``prefix + 'tree'``.
     """
     add_template = L.trim('''
         S_TREE, _ = STATE
         TREE[VAL] = None
         S_STATE = (TREE, TREE.MINMAX())
         ''')
     
     remove_template = L.trim('''
         S_TREE, _ = STATE
         del TREE[VAL]
         S_STATE = (TREE, TREE.MINMAX())
         ''')

     templates_by_op = {'add': add_template, 'remove': remove_template}
     chosen = templates_by_op[op]

     tree_name = prefix + 'tree'
     extreme_method = {'min': '__min__', 'max': '__max__'}[self.kind]
     substitutions = {'S_TREE': L.sn(tree_name),
                      'TREE': L.ln(tree_name),
                      '@MINMAX': extreme_method,
                      'STATE': state_lnode,
                      'S_STATE': state_snode,
                      'VAL': val_node}
     return L.pc(chosen, subst=substitutions)
Ejemplo n.º 18
0
 def test_enumclause_basic(self):
     cl = EnumClause(('x', 'y', 'x', '_'), 'R')

     # Parsing the equivalent membership expression gives the same clause.
     parsed = EnumClause.from_expr(L.pe('(x, y, x, _) in R'))
     self.assertEqual(parsed, cl)

     # Converting to an AST and back is lossless.
     built_ast = cl.to_AST()
     expected_ast = L.Enumerator(
         L.tuplify(['x', 'y', 'x', '_'], lval=True), L.ln('R'))
     self.assertEqual(built_ast, expected_ast)
     self.assertEqual(EnumClause.from_AST(expected_ast, DummyFactory), cl)

     # Boolean attributes.
     self.assertFalse(cl.isdelta)
     self.assertTrue(cl.has_wildcards)
     self.assertTrue(cl.robust)

     # Value attributes.
     expected_attrs = [
         ('enumlhs', ('x', 'y', 'x', '_')),
         ('enumvars', ('x', 'y')),
         ('pat_mask', (True, True, True, True)),
         ('enumrel', 'R'),
         ('vars', ('x', 'y')),
         ('eqvars', None),
         ('demname', None),
         ('demparams', ()),
     ]
     for attr, value in expected_attrs:
         self.assertEqual(getattr(cl, attr), value)
Ejemplo n.º 19
0
 def to_AST(self):
     """Return this clause as ``var in {REL.smlookup(MASK, KEYS)}``.

     ``self.lhs`` holds the key variables followed by the looked-up
     variable; the mask is built from the number of key variables.
     """
     lhs = self.lhs
     mask = Mask.from_keylen(len(lhs) - 1)
     keys_node = L.tuplify(lhs[:-1])
     target = L.sn(lhs[-1])
     lookup = L.SMLookup(L.ln(self.rel), mask.make_node().s,
                         keys_node, None)
     return L.Enumerator(target, L.Set((lookup,)))
Ejemplo n.º 20
0
def make_bindmatch(rel, mask, bvars, uvars, body):
    """Return code running *body* for each match of a pattern over *rel*.

    The shape depends on *mask*:

    - all components bound, no equalities: ``if BVARS in REL: BODY``;
    - all components unbound, no wildcards: ``for UVARS in REL: BODY``;
    - otherwise: ``for UVARS in setmatch(REL, MASK, BVARS): BODY``.

    *bvars* / *uvars* name the bound / unbound components and *body* is
    spliced in as BODY.
    """
    subst = {'REL': L.ln(rel),
             'MASK': mask.make_node(),
             'BVARS': L.tuplify(bvars),
             'UVARS': L.tuplify(uvars, lval=True),
             '<c>BODY': body}

    if mask.is_allbound and not mask.has_equalities:
        # Everything is bound: a membership test suffices.
        skeleton = L.trim('''
            if BVARS in REL:
                BODY
            ''')
    elif mask.is_allunbound and not mask.has_wildcards:
        # Nothing is bound: iterate the relation directly.
        skeleton = L.trim('''
            for UVARS in REL:
                BODY
            ''')
    else:
        skeleton = L.trim('''
            for UVARS in setmatch(REL, MASK, BVARS):
                BODY
            ''')

    return L.pc(skeleton, subst=subst)
Ejemplo n.º 21
0
 def visit_DemQuery(self, node):
     """Replace a demand-wrapped map lookup with a fresh variable.

     ``node.value`` is recorded in ``self.demwrapped_nodes`` before the
     children are visited. If the (visited) value is an ``SMLookup``
     without a default, the whole DemQuery is replaced by a variable.
     Equal DemQuery nodes share one variable through ``self.repls``. The
     first time a node is seen, a clause
     ``var in DEMQUERY(..., {smlookup})`` is appended to
     ``self.new_clauses``; later clause processing rewrites it (or
     rejects bad syntax). Other DemQuery nodes are returned as visited.
     """
     self.demwrapped_nodes.add(node.value)
     node = self.generic_visit(node)

     lookup = node.value
     if not isinstance(lookup, L.SMLookup):
         return node
     assert lookup.default is None

     var = self.repls.get(node, None)
     if var is None:
         var = next(self.namer)
         self.repls[node] = var
         clause_iter = node._replace(value=L.Set((lookup,)))
         self.new_clauses.append(L.Enumerator(L.sn(var), clause_iter))

     return L.ln(var)
Ejemplo n.º 22
0
    def make_update_state_code(self, state_snode, state_lnode, op, val_node,
                               prefix):
        """Return code that updates a min/max tree state.

        The state read from *state_lnode* is unpacked as ``(tree, _)``.
        *val_node* is inserted into the tree (``op == 'add'``) or deleted
        from it (``op == 'remove'``). *state_snode* is then assigned
        ``(tree, tree.__min__())`` or ``(tree, tree.__max__())``,
        according to ``self.kind``. The tree variable is ``prefix + 'tree'``.
        """
        add_template = L.trim('''
            S_TREE, _ = STATE
            TREE[VAL] = None
            S_STATE = (TREE, TREE.MINMAX())
            ''')

        remove_template = L.trim('''
            S_TREE, _ = STATE
            del TREE[VAL]
            S_STATE = (TREE, TREE.MINMAX())
            ''')

        chosen = {'add': add_template, 'remove': remove_template}[op]

        tree = prefix + 'tree'
        method = {'min': '__min__', 'max': '__max__'}[self.kind]
        return L.pc(chosen,
                    subst={
                        'S_TREE': L.sn(tree),
                        'TREE': L.ln(tree),
                        '@MINMAX': method,
                        'STATE': state_lnode,
                        'S_STATE': state_snode,
                        'VAL': val_node
                    })
Ejemplo n.º 23
0
    def visit_SetUpdate(self, node):
        """Hoist non-trivial operands of a set update into fresh variables.

        If both ``node.target`` and ``node.elem`` pass
        ``LegalUpdateValidator``, *node* is returned unchanged. Otherwise
        each failing operand is assigned to a fresh name from
        ``self.namegen`` (target first, then elem) and replaced by a load
        of it. The result is a tuple of those assignments followed by
        the rewritten update.
        """
        target_ok = LegalUpdateValidator.run(node.target)
        elem_ok = LegalUpdateValidator.run(node.elem)
        if target_ok and elem_ok:
            return node

        hoisted = []

        def hoist(expr):
            # Bind expr to a fresh variable; return a load of it.
            fresh = next(self.namegen)
            hoisted.append(L.Assign((L.sn(fresh), ), expr))
            return L.ln(fresh)

        if not target_ok:
            node = node._replace(target=hoist(node.target))
        if not elem_ok:
            node = node._replace(elem=hoist(node.elem))
        return tuple(hoisted) + (node, )
Ejemplo n.º 24
0
 def to_AST(self):
     """Return the wrapped clause's Enumerator with a demand-query iterator.

     The inner clause's iterator is wrapped in a ``DemQuery`` for
     ``self.demname`` whose arguments are loads of ``self.demparams``.
     """
     enum = self.cl.to_AST()
     assert isinstance(enum, L.Enumerator)
     dem_args = tuple(L.ln(p) for p in self.demparams)
     wrapped = L.DemQuery(self.demname, dem_args, enum.iter)
     return enum._replace(iter=wrapped)
Ejemplo n.º 25
0
    def make_retrieval_code(self):
        """Make code for retrieving the value of the aggregate result,
        including demanding it.

        With demand, ``RES.smlookup(AGGRMASK, params)`` is wrapped in
        ``DEMQUERY(name, [params], ...)``. Without demand,
        ``RES.smdeflookup`` is used with the zero mapval as default.
        Either way the lookup is passed through
        ``self.make_proj_mapval_code``.
        """
        incaggr = self.incaggr

        if not incaggr.has_demand:
            lookup = L.pe('''
                RES.smdeflookup(AGGRMASK, PARAMS_T, ZERO)
                ''',
                          subst={
                              'RES': incaggr.name,
                              'AGGRMASK': incaggr.aggrmask.make_node(),
                              'PARAMS_T': L.tuplify(incaggr.params),
                              'ZERO': self.make_zero_mapval_expr(),
                          })
        else:
            params_list = L.List(tuple(L.ln(p) for p in incaggr.params),
                                 L.Load())
            lookup = L.pe('''
                DEMQUERY(NAME, PARAMS_L, RES.smlookup(AGGRMASK, PARAMS_T))
                ''',
                          subst={
                              'NAME': incaggr.name,
                              'PARAMS_L': params_list,
                              'PARAMS_T': L.tuplify(incaggr.params),
                              'RES': incaggr.name,
                              'AGGRMASK': incaggr.aggrmask.make_node()
                          })

        return self.make_proj_mapval_code(lookup)
Ejemplo n.º 26
0
 def to_AST(self):
     """Return ``lhs[-1] in {REL.smlookup(MASK, lhs[:-1])}`` as an Enumerator.

     The mask's key length equals the number of key variables, which is
     every element of ``self.lhs`` except the last.
     """
     mask = Mask.from_keylen(len(self.lhs) - 1)
     keyvars, var = self.lhs[:-1], self.lhs[-1]
     mask_str = mask.make_node().s
     smlookup = L.SMLookup(L.ln(self.rel), mask_str,
                           L.tuplify(keyvars), None)
     return L.Enumerator(L.sn(var), L.Set((smlookup,)))
Ejemplo n.º 27
0
Archivo: aggr.py Proyecto: IncOQ/incoq
 def make_retrieval_code(self):
     """Make code for retrieving the value of the aggregate result,
     including demanding it.

     Demand-driven aggregates wrap ``RES.smlookup`` in ``DEMQUERY`` over
     the parameters; others use ``RES.smdeflookup`` with the zero mapval
     as default. The lookup is then projected via
     ``self.make_proj_mapval_code``.
     """
     incaggr = self.incaggr
     params_t = L.tuplify(incaggr.params)
     
     if incaggr.has_demand:
         params_l = L.List(tuple(L.ln(p) for p in incaggr.params), L.Load())
         code = L.pe('''
             DEMQUERY(NAME, PARAMS_L, RES.smlookup(AGGRMASK, PARAMS_T))
             ''', subst={'NAME': incaggr.name,
                         'PARAMS_L': params_l,
                         'PARAMS_T': params_t,
                         'RES': incaggr.name,
                         'AGGRMASK': incaggr.aggrmask.make_node()})
     else:
         code = L.pe('''
             RES.smdeflookup(AGGRMASK, PARAMS_T, ZERO)
             ''', subst={'RES': incaggr.name,
                         'AGGRMASK': incaggr.aggrmask.make_node(),
                         'PARAMS_T': params_t,
                         'ZERO': self.make_zero_mapval_expr()})
     
     return self.make_proj_mapval_code(code)
Ejemplo n.º 28
0
def structures_to_comps(ds, factory):
    """Convert tags and filters to the comprehensions that define them.

    Returns a list of ``(name, comp)`` pairs in the order of
    ``ds.structs`` (presumably dependency order -- confirm with the
    producer of *ds*). U-set structures are skipped.

    Raises AssertionError for a structure whose kind is not a tag,
    filter, or U-set.
    """
    tags_by_name = {t.name: t for t in ds.tags}
    result = []

    for s in ds.structs:
        if s.kind is KIND_TAG:
            # Tag: enumerate s.lhs over s.rel and yield s.var.
            cl = EnumClause(s.lhs, s.rel)
            spec = CompSpec(Join([cl], factory, None), L.ln(s.var), [])
        elif s.kind is KIND_FILTER:
            # Filter: join the predecessor tags with the relation.
            cls = []
            for tname in s.preds:
                t = tags_by_name[tname]
                cls.append(EnumClause([t.var], t.name))
            # Be sure to replace wildcards with fresh vars.
            lhs = inst_wildcards(s.lhs)
            cls.append(EnumClause(lhs, s.rel))
            spec = CompSpec(Join(cls, factory, None), L.tuplify(lhs), [])
        elif s.kind is KIND_USET:
            continue
        else:
            # Raise explicitly: an `assert` is stripped under -O, which
            # would let a stale `spec` from an earlier iteration be
            # appended for an unrecognized kind.
            raise AssertionError(
                'Unknown structure kind: {!r}'.format(s.kind))

        result.append((s.name, spec.to_comp({})))

    return result
Ejemplo n.º 29
0
    def test_enumclause_basic(self):
        clause = EnumClause(('x', 'y', 'x', '_'), 'R')

        # Construction from a parsed membership expression.
        from_expr = EnumClause.from_expr(L.pe('(x, y, x, _) in R'))
        self.assertEqual(from_expr, clause)

        # AST round-trip in both directions.
        as_ast = clause.to_AST()
        expected_ast = L.Enumerator(
            L.tuplify(['x', 'y', 'x', '_'], lval=True), L.ln('R'))
        self.assertEqual(as_ast, expected_ast)
        self.assertEqual(EnumClause.from_AST(expected_ast, DummyFactory),
                         clause)

        # Enumeration-related attributes.
        self.assertFalse(clause.isdelta)
        self.assertEqual(clause.enumlhs, ('x', 'y', 'x', '_'))
        self.assertEqual(clause.enumvars, ('x', 'y'))
        self.assertEqual(clause.pat_mask, (True, True, True, True))
        self.assertEqual(clause.enumrel, 'R')
        self.assertTrue(clause.has_wildcards)

        # General clause attributes.
        self.assertEqual(clause.vars, ('x', 'y'))
        self.assertEqual(clause.eqvars, None)
        self.assertTrue(clause.robust)
        self.assertEqual(clause.demname, None)
        self.assertEqual(clause.demparams, ())
Ejemplo n.º 30
0
 def visit_SetUpdate(self, node):
     """Hoist non-trivial operands of a set update into fresh variables.

     When both the target and the element pass ``LegalUpdateValidator``
     the update is returned as-is. Otherwise every failing operand
     (target first, then elem) is assigned to a fresh variable from
     ``self.namegen`` and replaced by it. A tuple of the assignments
     followed by the rewritten update is returned.
     """
     checks = [('target', LegalUpdateValidator.run(node.target)),
               ('elem', LegalUpdateValidator.run(node.elem))]
     if all(ok for _, ok in checks):
         return node

     prelude = ()
     for attr, ok in checks:
         if ok:
             continue
         fresh = next(self.namegen)
         prelude += (L.Assign((L.sn(fresh),), getattr(node, attr)),)
         node = node._replace(**{attr: L.ln(fresh)})
     return prelude + (node,)
Ejemplo n.º 31
0
 def visit_DelKey(self, node):
     """Hoist non-trivial operands of a key deletion into fresh variables.

     If both ``node.target`` and ``node.key`` pass
     ``LegalUpdateValidator``, *node* is returned unchanged. Otherwise
     each failing operand (target first, then key) is bound to a fresh
     name from ``self.namegen`` and replaced by a load of it. The result
     is a tuple of those assignments followed by the rewritten node.
     """
     target_ok = LegalUpdateValidator.run(node.target)
     key_ok = LegalUpdateValidator.run(node.key)
     if target_ok and key_ok:
         return node

     hoisted = []

     def hoist(expr):
         # Bind expr to a fresh variable; return a load of it.
         fresh = next(self.namegen)
         hoisted.append(L.Assign((L.sn(fresh),), expr))
         return L.ln(fresh)

     if not target_ok:
         node = node._replace(target=hoist(node.target))
     if not key_ok:
         node = node._replace(key=hoist(node.key))

     return tuple(hoisted) + (node,)
Ejemplo n.º 32
0
 def visit_DemQuery(self, node):
     """Return the cost of a demand query.

     The query is modeled as a call to the query function for
     ``node.demname`` with ``node.args``. The cost of that call is summed
     with the cost of evaluating ``node.value`` (the result retrieval).
     """
     qfunc_name = L.N.queryfunc(node.demname)
     call = L.Call(L.ln(qfunc_name), node.args, (), None, None)
     parts = [self.visit(call), self.visit(node.value)]
     return SumCost(parts)
Ejemplo n.º 33
0
    def visit_DelKey(self, node):
        """Hoist non-trivial operands of a key deletion into fresh variables.

        When both ``node.target`` and ``node.key`` pass
        ``LegalUpdateValidator``, *node* is returned unchanged. Otherwise
        each failing operand is assigned to a fresh variable (from
        ``self.namegen``) and the node is rewritten to use it. The
        returned tuple holds the assignments followed by the new node.
        """
        checks = [('target', LegalUpdateValidator.run(node.target)),
                  ('key', LegalUpdateValidator.run(node.key))]
        if all(ok for _, ok in checks):
            return node

        prelude = ()
        for attr, ok in checks:
            if not ok:
                fresh = next(self.namegen)
                prelude += (L.Assign((L.sn(fresh), ), getattr(node, attr)), )
                node = node._replace(**{attr: L.ln(fresh)})

        return prelude + (node, )
Ejemplo n.º 34
0
 def helper(self, node, no_update=False):
     """Hoist *node* into a fresh set variable and return a load of it.

     Appends ``var = Set()`` to ``self.pre_stmts``, followed by
     ``var.update(node)`` unless *no_update* is true. The variable name
     comes from ``self.namegen``.
     """
     fresh = next(self.namegen)

     if not no_update:
         template = L.trim('''
             S_VAR = Set()
             L_VAR.update(EXPR)
             ''')
     else:
         template = L.trim('''
             S_VAR = Set()
             ''')

     subst = {'L_VAR': L.ln(fresh), 'S_VAR': L.sn(fresh), 'EXPR': node}
     self.pre_stmts.extend(L.pc(template, subst=subst))
     return L.ln(fresh)
Ejemplo n.º 35
0
 def mainttest_helper(self, maskstr):
     """Generate 'add' auxmap maintenance code for relation R and *maskstr*.

     The manager's prefix generator is replaced so that generated names
     use the prefix '_', which keeps expected outputs easy to read.
     """
     self.manager.namegen.next_prefix = lambda: '_'
     spec = AuxmapSpec('R', Mask(maskstr))
     return make_auxmap_maint_code(self.manager, spec, L.ln('e'), 'add')
Ejemplo n.º 36
0
 def visit_Delete(self, node):
     """Rewrite an attribute deletion as a ``nsdelfield`` call.

     Only applies when ``self.rewrite_fields`` is true and
     ``L.is_delattr(node)`` holds; any other node is returned unchanged.
     """
     if not (self.rewrite_fields and L.is_delattr(node)):
         return node
     cont, field = L.get_delattr(node)
     # NOTE(review): FIELD is substituted as a name node (L.ln(field)),
     # not a string literal -- confirm this matches nsdelfield's API.
     field_node = L.ln(field)
     return L.pc('''
         CONT.nsdelfield(FIELD)
         ''', subst={'CONT': cont,
                     'FIELD': field_node})
Ejemplo n.º 37
0
    def make_update_mapval_code(self, mv_snode, mv_lnode, op, val_node,
                                prefix):
        """Produce code to make a new mapval, given an update to
        the corresponding operand. The mapval is read from mv_lnode
        and written to mv_snode.

        If the aggregate does not track counts, the mapval is the state
        itself and ``make_update_state_code`` is used directly. Otherwise
        the mapval is unpacked as ``(state, count)``, the state is updated,
        and the count is incremented for ``op == 'add'`` or decremented
        for ``op == 'remove'``.

        Raises AssertionError for any other *op*, unless
        ``make_update_state_code`` has already rejected it.
        """
        # If we don't track counts, the mapvals are the same as
        # the states.
        if not self.incaggr.tracks_counts:
            return self.make_update_state_code(mv_snode, mv_lnode, op,
                                               val_node, prefix)

        statevar = prefix + 'state'
        state_lnode = L.ln(statevar)
        state_snode = L.sn(statevar)
        countvar = prefix + 'count'

        updatestate_code = self.make_update_state_code(state_snode,
                                                       state_lnode, op,
                                                       val_node, prefix)

        if op == 'add':
            template = 'COUNTVAR + 1'
        elif op == 'remove':
            template = 'COUNTVAR - 1'
        else:
            # Explicit raise: `assert ()` is stripped under -O, which
            # would surface as an obscure UnboundLocalError below.
            raise AssertionError('Unknown op: {!r}'.format(op))
        new_count_node = L.pe(template, subst={'COUNTVAR': L.ln(countvar)})

        return L.pc('''
            S_STATE, S_COUNTVAR = MV
            UPDATE_STATE
            S_MV = STATE, NEW_COUNT
            ''',
                    subst={
                        'S_STATE': state_snode,
                        'S_COUNTVAR': L.sn(countvar),
                        'MV': mv_lnode,
                        '<c>UPDATE_STATE': updatestate_code,
                        'STATE': state_lnode,
                        'NEW_COUNT': new_count_node,
                        'S_MV': mv_snode
                    })
Ejemplo n.º 38
0
Archivo: aggr.py Proyecto: IncOQ/incoq
 def make_removeu_maint(self, prefix):
     """Generate code for before a removal from U.

     Under half-demand, the entry for the parameter key is looked up and
     deleted only if its count projection is 0. Otherwise the operand's
     undemand function is called first (when the operand uses demand),
     and the entry is then deleted unconditionally.
     """
     incaggr = self.incaggr
     assert incaggr.has_demand
     spec = incaggr.spec
     mv_var = prefix + 'val'

     if incaggr.half_demand:
         # Nothing to retract from the operand; delete only when empty.
         half_subst = {'S_MV': L.sn(mv_var),
                       'COUNT': self.mapval_proj_count(L.ln(mv_var)),
                       'A': incaggr.name,
                       'MASK': incaggr.aggrmask.make_node(),
                       'KEY': L.tuplify(incaggr.params),
                       'PREFIX': L.Str(prefix)}
         return L.pc('''
             S_MV = A.smlookup(MASK, KEY)
             if COUNT == 0:
                 A.smdelkey(MASK, KEY, PREFIX)
             ''', subst=half_subst)

     if spec.has_oper_demand:
         undem_args = tuple(L.ln(v) for v in spec.oper_demparams)
         undem_call = L.Call(L.ln(L.N.undemfunc(spec.oper_demname)),
                             undem_args, (), None, None)
         propagate_code = (L.Expr(undem_call),)
     else:
         propagate_code = ()

     subst = {'A': incaggr.name,
              'MASK': incaggr.aggrmask.make_node(),
              'KEY': L.tuplify(incaggr.params),
              'PREFIX': L.Str(prefix),
              '<c>PROPAGATE_DEMAND': propagate_code}
     return L.pc('''
         PROPAGATE_DEMAND
         A.smdelkey(MASK, KEY, PREFIX)
         ''', subst=subst)
Ejemplo n.º 39
0
 def get_code(self, bindenv, body):
     """Return the inner clause's code preceded by a bare DemQuery statement.

     The prepended statement is an ``Expr`` holding
     ``DemQuery(self.demname, <loads of self.demparams>, None)``.
     """
     # TODO: The None value litters the output with "None"s; a special
     # case in the translation of DemQuery could avoid this.
     inner = self.cl.get_code(bindenv, body)
     dem_args = tuple(L.ln(p) for p in self.demparams)
     marker = L.Expr(value=L.DemQuery(self.demname, dem_args, None))
     return (marker, ) + inner
Ejemplo n.º 40
0
def make_auxmap_maint_code(manager, spec, elem, addremove):
    """Construct auxmap maintenance code for a set update.
    
    manager provides fresh names (via manager.namegen.next_prefix()),
    spec provides the mask and the auxmap's name (spec.map_name), elem
    is the expression node for the updated element, and addremove is
    'add' or 'remove'.
    
    The generated code assigns elem to a tuple of fresh variables and
    then calls the matching image-set operation on the auxmap, passing
    the tuple of the mask's bound variables and the tuple of its
    unbound variables.
    """
    assert addremove in ['add', 'remove']
    
    prefix = manager.namegen.next_prefix()
    mask = spec.mask
    
    # Create fresh variables for the tuple components. (Named tupvars
    # rather than vars so as not to shadow the builtin vars().)
    tupvars = [prefix + str(i) for i in range(1, len(mask) + 1)]
    bvars, uvars, eqs = mask.split_vars(tupvars)
    
    vars_node = L.tuplify(tupvars, lval=True)
    map_node = L.ln(spec.map_name)
    bvars_node = L.tuplify(bvars)
    uvars_node = L.tuplify(uvars)
    
    # If there are equalities, include a conditional check for the
    # constraints being satisfied. If there are wildcards, make the
    # image set manipulation operations reference-counted.
    #
    # Avoid these in cases where we don't have equalities/wildcards,
    # to help reduce constant-factor bloat in code size and running
    # time.
    
    if mask.has_equalities:
        template = '''
            VARS = ELEM
            if EQCOND:
                MAP.IMGOP(BVARS, UVARS)
            '''
        eqcond = make_vareq_cond(eqs)
    else:
        template = '''
            VARS = ELEM
            MAP.IMGOP(BVARS, UVARS)
            '''
        # EQCOND does not appear in this template; None is a placeholder.
        eqcond = None
    
    if mask.has_wildcards:
        imgop = {'add': 'rcimgadd',
                 'remove': 'rcimgremove'}[addremove]
    else:
        imgop = {'add': 'imgadd',
                 'remove': 'imgremove'}[addremove]
    
    code = L.pc(template, subst={
        '@IMGOP': imgop,
        'VARS': vars_node,
        'ELEM': elem,
        'MAP': map_node,
        'BVARS': bvars_node,
        'UVARS': uvars_node,
        'EQCOND': eqcond})
    
    return code
Ejemplo n.º 41
0
    def test(self):
        # Checks DemClause end to end: AST conversion, derived
        # attributes, variable substitution, rating and code generation.
        clause = DemClause(EnumClause(['x', 'y'], 'R'), 'f', ['x'])

        # Converting to an AST and back must round-trip.
        actual_ast = clause.to_AST()
        expected_ast = L.Enumerator(
            L.tuplify(['x', 'y'], lval=True),
            L.DemQuery('f', (L.ln('x'),), L.ln('R')))
        self.assertEqual(actual_ast, expected_ast)
        rebuilt = DemClause.from_AST(expected_ast, DemClauseFactory)
        self.assertEqual(rebuilt, clause)

        # Derived attributes.
        self.assertEqual(clause.pat_mask, (True, True))
        self.assertEqual(clause.enumvars_tagsin, ('x',))
        self.assertEqual(clause.enumvars_tagsout, ('y',))

        # Substitution on a DemClause wrapping an EnumClause.
        renamed = clause.rewrite_subst({'x': 'z'}, DemClauseFactory)
        self.assertEqual(
            renamed, DemClause(EnumClause(['z', 'y'], 'R'), 'f', ['z']))

        # Substitution on a DemClause wrapping a LookupClause.
        lookup_clause = DemClause(LookupClause(['x', 'y'], 'R'), 'f', ['x'])
        renamed = lookup_clause.rewrite_subst({'x': 'z'}, DemClauseFactory)
        self.assertEqual(
            renamed, DemClause(LookupClause(['z', 'y'], 'R'), 'f', ['z']))

        # Rating: normal once 'x' is bound, unrunnable with nothing bound.
        self.assertEqual(clause.rate(['x']), Rate.NORMAL)
        self.assertEqual(clause.rate([]), Rate.UNRUNNABLE)

        # Code generation.
        actual_code = clause.get_code(['x'], L.pc('pass'))
        expected_code = L.pc('''
            DEMQUERY(f, [x], None)
            for y in setmatch(R, 'bu', x):
                pass
            ''')
        self.assertEqual(actual_code, expected_code)
Ejemplo n.º 42
0
 def test(self):
     # Exercises DemClause: AST round-trip, attributes, rewriting,
     # rating, and generated code.
     dem = DemClause(EnumClause(['x', 'y'], 'R'), 'f', ['x'])
     
     # AST round-trip.
     got_ast = dem.to_AST()
     want_ast = L.Enumerator(L.tuplify(['x', 'y'], lval=True),
                             L.DemQuery('f', (L.ln('x'),), L.ln('R')))
     self.assertEqual(got_ast, want_ast)
     self.assertEqual(DemClause.from_AST(want_ast, DemClauseFactory), dem)
     
     # Attributes, checked from a table of (name, expected value).
     expected_attrs = (('pat_mask', (True, True)),
                       ('enumvars_tagsin', ('x',)),
                       ('enumvars_tagsout', ('y',)))
     for attr_name, want in expected_attrs:
         self.assertEqual(getattr(dem, attr_name), want)
     
     # Plain rewriting.
     got = dem.rewrite_subst({'x': 'z'}, DemClauseFactory)
     want = DemClause(EnumClause(['z', 'y'], 'R'), 'f', ['z'])
     self.assertEqual(got, want)
     
     # Rewriting through a LookupClause.
     lookup_dem = DemClause(LookupClause(['x', 'y'], 'R'), 'f', ['x'])
     got = lookup_dem.rewrite_subst({'x': 'z'}, DemClauseFactory)
     want = DemClause(LookupClause(['z', 'y'], 'R'), 'f', ['z'])
     self.assertEqual(got, want)
     
     # Rating for each bound-variable set.
     for bound, want_rate in ((['x'], Rate.NORMAL),
                              ([], Rate.UNRUNNABLE)):
         self.assertEqual(dem.rate(bound), want_rate)
     
     # Code generation.
     got_code = dem.get_code(['x'], L.pc('pass'))
     want_code = L.pc('''
         DEMQUERY(f, [x], None)
         for y in setmatch(R, 'bu', x):
             pass
         ''')
     self.assertEqual(got_code, want_code)
Ejemplo n.º 43
0
 def visit_Tuple(self, node):
     """Replace a tuple node with a fresh tuple variable.
     
     A new Enumerator clause is appended to self.new_clauses. Its
     left-hand side is the fresh variable followed by the tuple's
     elements, and it iterates over the relation from
     make_trel(<arity>). That relation is recorded in self.trels.
     The fresh variable is returned as L.sn(...) (presumably a
     store-context Name).
     """
     # The elements are not visited here; the caller of this visitor
     # handles recursion.
     fresh_var = self.tupvar_namer.next()
     tuple_rel = make_trel(len(node.elts))
     lhs = L.tuplify((L.sn(fresh_var),) + node.elts, lval=True)
     self.new_clauses.append(L.Enumerator(lhs, L.ln(tuple_rel)))
     self.trels.add(tuple_rel)
     return L.sn(fresh_var)
Ejemplo n.º 44
0
    def visit_SetUpdate(self, node):
        """Route updates to self.rel through a fresh '_ft' variable.

        If the update's target is not a Name whose id is self.rel,
        None is returned. Otherwise the result is the code produced by
        make_flattup_code (given self.tuptype, the update's element, a
        store target for a fresh '_ft' variable, and a fresh '_t'
        name), followed by the update rewritten so that its element is
        a load of that '_ft' variable.
        """
        targets_rel = (isinstance(node.target, L.Name)
                       and node.target.id == self.rel)
        if not targets_rel:
            return None

        suffix = next(self.namegen)
        temp_name = '_t' + suffix
        flat_name = '_ft' + suffix
        prelude = make_flattup_code(self.tuptype, node.elem,
                                    L.sn(flat_name), temp_name)
        rewritten = node._replace(elem=L.ln(flat_name))
        return prelude + (rewritten,)
Ejemplo n.º 45
0
    def visit_ClassDef(self, node):
        """Ensure a visited class definition lists Set among its bases.

        Children are visited first. Every base must satisfy
        self.valid_baseclass (checked by assertion). If a Name 'Set'
        is not already among the bases, it is appended as the last one.
        """
        node = self.generic_visit(node)

        for base in node.bases:
            assert self.valid_baseclass(base), 'Illegal base class'

        set_base = L.ln('Set')
        if set_base in node.bases:
            return node
        return node._replace(bases=node.bases + (set_base,))
Ejemplo n.º 46
0
 def visit_ClassDef(self, node):
     """Add 'Set' as a base class of the visited class if it is missing.
     
     The node's children are visited first. An assertion checks that
     every existing base passes self.valid_baseclass.
     """
     node = self.generic_visit(node)
     bases = node.bases
     assert all(map(self.valid_baseclass, bases)), \
         'Illegal base class'
     
     set_base = L.ln('Set')
     if set_base not in bases:
         node = node._replace(bases=bases + (set_base,))
     return node
Ejemplo n.º 47
0
 def visit_Assign(self, node):
     """Rewrite an attribute assignment into an nsassignfield() call.
     
     This applies only when self.rewrite_fields is set and
     L.is_attrassign(node) holds. The container, field and value from
     L.get_attrassign(node) are substituted into
     CONT.nsassignfield(FIELD, VALUE), with the field given as L.ln(field).
     Any other node is returned unchanged. Children are not visited.
     """
     if not (self.rewrite_fields and L.is_attrassign(node)):
         return node
     cont, field, value = L.get_attrassign(node)
     replacements = {'CONT': cont,
                     'FIELD': L.ln(field),
                     'VALUE': value}
     return L.pc('''
         CONT.nsassignfield(FIELD, VALUE)
         ''', subst=replacements)
Ejemplo n.º 48
0
 def visit_DelKey(self, node):
     """Translate a map key deletion into a removal from the map set.
     
     Children are visited first. If self.use_mapset is false, the
     visited node is returned. Otherwise the result is code that
     removes the triple (TARGET, KEY, TARGET[KEY]) from the relation
     named by make_maprel().
     """
     node = self.generic_visit(node)
     if not self.use_mapset:
         return node
     
     # NOTE(review): the returned code contains only the mapset removal.
     # The DelKey node itself is not re-emitted, so this code does not
     # delete KEY from TARGET. Confirm this is intended (e.g. the map is
     # represented solely by the mapset at this stage).
     mapset_subst = {'MAPSET': L.ln(make_maprel()),
                     'TARGET': node.target,
                     'KEY': node.key}
     return L.pc('''
         MAPSET.remove((TARGET, KEY, TARGET[KEY]))
         ''', subst=mapset_subst)
Ejemplo n.º 49
0
 def visit_SetUpdate(self, node):
     """Rewrite an update to self.rel to use a fresh '_ft' variable.
     
     Returns None unless the update targets a Name whose id is
     self.rel. In that case it returns make_flattup_code's output for
     (self.tuptype, node.elem, store of '_ft<n>', '_t<n>'), followed by
     a copy of the update whose element is a load of '_ft<n>'. <n> is
     drawn from self.namegen.
     """
     target = node.target
     if not isinstance(target, L.Name) or target.id != self.rel:
         return None
     
     fresh = next(self.namegen)
     tmp_name, flat_name = '_t' + fresh, '_ft' + fresh
     prelude = make_flattup_code(self.tuptype, node.elem,
                                 L.sn(flat_name), tmp_name)
     return prelude + (node._replace(elem=L.ln(flat_name)),)
Ejemplo n.º 50
0
 def visit_Tuple(self, node):
     """Swap a tuple node for a fresh variable plus a tuple-relation clause.
     
     Side effects: appends an Enumerator over make_trel(<arity>) binding
     (fresh_var, *elements) to self.new_clauses, and adds that relation
     to self.trels. Returns L.sn(fresh_var).
     """
     # Recursion into the elements is the caller's responsibility.
     var_name = self.tupvar_namer.next()
     relation = make_trel(len(node.elts))
     bound_elts = (L.sn(var_name),) + node.elts
     clause = L.Enumerator(L.tuplify(bound_elts, lval=True),
                           L.ln(relation))
     self.new_clauses.append(clause)
     self.trels.add(relation)
     return L.sn(var_name)