def serializeUniformNode(compiler, translator, self, holdingSlot, refs, root):
    """Build the statements that serialize one uniform node.

    ``refs`` is expected to hold exactly one reference per Python type
    (asserted below).  When only one type is present the result of
    ``handleUniformType`` is returned unchanged.  Otherwise a chain of
    nested ``ast.Switch`` nodes is built, one per type, each guarded by a
    rewrite of ``typeCheckTemplate`` applied to ``root``; the innermost
    fallthrough is an ``ast.Assert`` on the extracted constant ``False``.
    In that case a one-element list holding the outermost switch is
    returned.

    ``self`` is passed straight through to ``handleUniformType`` together
    with ``holdingSlot`` and ``root``.
    """
    typeCheck = symbols.SymbolRewriter(compiler.extractor, typeCheckTemplate)
    orderedTypes = sorted(set(ref.xtype.obj.pythonType() for ref in refs))

    # The tree transform should guarantee there's only one object per type.
    refByType = {}
    for ref in refs:
        pyType = ref.xtype.obj.pythonType()
        assert pyType not in refByType
        refByType[pyType] = ref

    assert len(orderedTypes) > 0

    if len(orderedTypes) == 1:
        onlyType = orderedTypes[0]
        return handleUniformType(
            compiler, translator, self, holdingSlot,
            refByType[onlyType], root, onlyType)

    # One (condition, suite) pair per type, in sorted type order.
    guarded = []
    for pyType in orderedTypes:
        cond = typeCheck.rewrite(root=root, type=pyType)
        body = handleUniformType(
            compiler, translator, self, holdingSlot,
            refByType[pyType], root, pyType)
        guarded.append((cond, ast.Suite(body)))

    # Start from the "no type matched" fallthrough and wrap it from the
    # last type to the first, so the first type ends up outermost.
    falseConstant = ast.Existing(compiler.extractor.getObject(False))
    chain = ast.Suite([ast.Assert(falseConstant, None)])
    for cond, suite in reversed(guarded):
        chain = ast.Switch(
            ast.Condition(ast.Suite([]), cond),
            suite,
            ast.Suite([chain]))

    return [chain]
def visitSwitch(self, node):
    """Fold a two-way graph switch into one suite holding an ``ast.Switch``.

    The fold is abandoned (returning ``None`` without touching the graph)
    when either branch exit lies in a different region than ``node``, or
    when ``getCommonExit`` / ``getError`` report failure for the pair of
    branches.  Otherwise the entries of ``node`` are redirected to a new
    ``graph.Suite`` containing the switch, both branch nodes are
    destroyed, and the shared normal/error exits are re-attached.
    """
    trueBranch = self.getSwitchExit(node, 'true')
    falseBranch = self.getSwitchExit(node, 'false', ignoreRegion=True)

    bothLocal = (node.region is trueBranch.region
                 and node.region is falseBranch.region)
    if not bothLocal:
        return

    ok, normalExit = self.getCommonExit(trueBranch, falseBranch)
    if not ok:
        return

    ok, errorExit = self.getError(trueBranch, falseBranch)
    if not ok:
        return

    switchOp = ast.Switch(
        ast.Condition(ast.Suite([]), node.condition),
        ast.Suite(trueBranch.ops),
        ast.Suite(falseBranch.ops))

    folded = graph.Suite(node.region)
    folded.ops.append(switchOp)

    # Reconnect the graph: the new suite takes over the switch's entries
    # before the now-absorbed branches are torn down.
    node.redirectEntries(folded)
    trueBranch.destroy()
    falseBranch.destroy()

    if normalExit:
        folded.setExit('normal', normalExit)
        if isinstance(normalExit, graph.Merge):
            normalExit.simplify()

    if errorExit:
        folded.setExit('error', errorExit)

    self.simplifySuite(folded)
def bindUniforms(compiler, translator, uniformSlot):
    """Produce the uniform-binding code for ``uniformSlot``.

    The body is the statement list from ``serializeUniformNode`` applied
    to the slot's merged references, or an empty suite when the slot's
    annotation has no references.  The result is ``uniformCodeTemplate``
    rewritten with the ``self`` and ``shader`` locals as arguments.
    """
    template = symbols.SymbolRewriter(compiler.extractor, uniformCodeTemplate)

    selfLocal = ast.Local('self')
    shaderLocal = ast.Local('shader')

    # Stays empty when no uniforms are used.
    statements = []
    if uniformSlot.annotation.references:
        mergedRefs = uniformSlot.annotation.references.merged
        statements = serializeUniformNode(
            compiler, translator, selfLocal, uniformSlot,
            mergedRefs, shaderLocal)

    return template.rewrite(
        args=[selfLocal, shaderLocal],
        body=ast.Suite(statements))
def evaluate(compiler, g):
    """Reduce the graph ``g`` and store the resulting suite on ``g.code.ast``.

    Steps, in order: run ``dom.evaluate`` from the entry terminal (storing
    each result on the node's ``data`` attribute), run ``findLoops`` on the
    entry's data, collect a processing order with ``Search``, and apply a
    ``Compactor`` to every node in that order.  Afterwards the node
    reached through the entry terminal's ``'entry'`` exit must have no
    ``'normal'`` exit left; its ops become the new AST.
    """
    def successors(node):
        return node.normalForward()

    def attach(node, djnode):
        node.data = djnode

    dom.evaluate([g.entryTerminal], successors, attach)
    findLoops(g.entryTerminal.data)

    search = Search()
    search.process(g.entryTerminal)
    visitOrder = search.order

    compact = Compactor(compiler, g, search.loops)
    for node in visitOrder:
        compact(node)

    reduced = g.entryTerminal.getExit('entry')
    assert reduced.getExit('normal') is None, "could not reduce?"

    g.code.ast = ast.Suite(reduced.ops)
def bindStreams(compiler, translator, context):
    """Generate the stream-binding code for a shader context.

    Every original parameter after the first two becomes an argument local
    of the generated code.  For parameters whose shader field name is in
    ``translator.liveInputs`` a ``bind_stream_<type name>`` statement is
    emitted from ``streamBindTemplate``; such a parameter must carry
    exactly one merged reference and it must be an intrinsic object.

    Returns ``streamCodeTemplate`` rewritten with the argument locals,
    their names, and the collected bind statements.
    """
    codeTemplate = symbols.SymbolRewriter(compiler.extractor, streamCodeTemplate)
    bindTemplate = symbols.SymbolRewriter(compiler.extractor, streamBindTemplate)

    originalParams = context.originalParams
    # Only an older, disabled zip-based variant of the loop below consumed
    # this value; the attribute read is kept as-is.
    currentParams = context.code.codeparameters

    selfLocal = ast.Local('self')

    streamLocals = []
    bindStatements = []

    for param in originalParams.params[2:]:
        streamLocal = ast.Local(param.name)
        streamLocals.append(streamLocal)

        ioname = context.shaderdesc.fields[param]
        if ioname not in translator.liveInputs:
            continue

        merged = param.annotation.references.merged
        assert len(merged) == 1
        obj = merged[0]
        assert intrinsics.isIntrinsicObject(obj)

        pyType = obj.xtype.obj.pythonType()
        methodName = "bind_stream_" + pyType.__name__

        structInfo = translator.ioRefInfo.get(ioname)
        shaderName = structInfo.lut.subpools[pyType].name

        bindStatements.append(
            bindTemplate.rewrite(
                self=selfLocal, attr=methodName,
                shaderName=shaderName, name=streamLocal))

    arguments = [selfLocal] + streamLocals
    argumentNames = [argument.name for argument in arguments]

    return codeTemplate.rewrite(
        args=arguments,
        argnames=argumentNames,
        body=ast.Suite(bindStatements))
def __init__(self, compiler, graph, opPathLength, clone):
    """Set up the analysis state.

    ``compiler`` supplies the console and extractor, ``graph`` is kept as
    the store graph (its ``canonical`` attribute is also retained),
    ``opPathLength`` controls how many previous ops a context remembers,
    and ``clone`` records whether code should be copied before it is
    annotated.

    Besides plain bookkeeping containers this builds an "external"
    function and context used for creating bogus slots, and makes sure the
    ``tuple`` and ``dict`` class objects are loaded.
    """
    self.decompileTime = 0
    self.console = compiler.console
    self.extractor = compiler.extractor

    # Should we copy the code before annotating it?
    self.clone = clone

    # Has the context been constructed?
    self.liveContexts = set()
    self.liveCode = set()

    # Constraint information, for debugging.
    self.constraints = []

    # The worklist.
    self.dirty = collections.deque()

    self.canonical = graph.canonical
    self._canonicalContext = util.canonical.CanonicalCache(
        base.AnalysisContext)

    # Controls how many previous ops are remembered by a context.
    # TODO remember prior CPA signatures?
    self.opPathLength = opPathLength

    self.cache = {}

    # Information for contextual operations.
    self.opAllocates = collections.defaultdict(set)
    self.opReads = collections.defaultdict(set)
    self.opModifies = collections.defaultdict(set)
    self.opInvokes = collections.defaultdict(set)

    self.codeContexts = collections.defaultdict(set)

    self.storeGraph = graph

    # Set up the "external" context, used for creating bogus slots.
    self.externalOp = util.canonical.Sentinel('<externalOp>')

    returnLocal = ast.Local('internal_return')
    externalParams = ast.CodeParameters(
        None, [], [], [], None, None, [returnLocal])
    self.externalFunction = ast.Code('external', externalParams, ast.Suite([]))

    signature = self._signature(self.externalFunction, None, ())
    initialPath = self.initialOpPath()
    self.externalFunctionContext = self._canonicalContext(
        signature, initialPath, self.storeGraph)
    self.codeContexts[self.externalFunction].add(
        self.externalFunctionContext)

    # For vargs.
    self.tupleClass = self.extractor.getObject(tuple)
    self.ensureLoaded(self.tupleClass)

    # For kargs.
    self.dictionaryClass = self.extractor.getObject(dict)
    self.ensureLoaded(self.dictionaryClass)

    self.entryPointOp = {}
def methodCallToTypeSwitch(self, node, arg, pos, targets):
    """Split the op ``node`` into an ``ast.TypeSwitch`` over ``arg``.

    ``groupTypes`` decides how the types of ``arg`` (the argument at
    ``pos``) are grouped.  With no grouping, or at most one group,
    ``None`` is returned so that no trivial type switch is created.
    Otherwise each group gets a case holding a copy of the op rewritten
    to use a type-filtered stand-in for ``arg``; the copy is passed back
    through ``self`` and then either discarded (``targets`` is ``None``)
    or assigned to ``targets``.
    """
    # TODO if mutable types are allowed, we should be looking at the
    #      LowLevel type slot?
    # TODO locally rebuild read/modify/allocate information using
    #      filtered invokes.
    # TODO should the return value be SSAed?  This might interfere with
    #      nested type switches.  If so, retarget the return value and
    #      fix up return types.
    groups = self.groupTypes(node, arg, pos)

    # Don't create trivial type switches.
    if groups is None or len(groups) <= 1:
        return None

    cases = []
    for group in groups:
        # Create a filtered version of the argument.
        if isinstance(arg, ast.Local):
            filteredName = arg.name
        else:
            filteredName = None
        filtered = ast.Local(filteredName)
        filtered.annotation = self.filterReferenceAnnotationByType(
            arg.annotation, group)

        # Kill contexts where the filtered expression has no references.
        # (In these contexts, the new op will never be evaluated.)
        annotation = node.annotation
        mask = self.makeRemapMask(filtered.annotation)
        if -1 in mask:
            annotation = annotation.contextSubset(mask)

        # Filter out invocations that don't have the right type for the
        # given parameter.
        annotation = self.filterOpAnnotationByType(annotation, group, pos)

        # Rewrite the op to use the filtered local instead of the
        # original argument.
        specialized = rewrite.rewriteTerm(node, {arg: filtered})
        assert specialized is not node
        specialized.annotation = annotation

        # Try to reduce it to a direct call.
        specialized = self(specialized)

        # Create the suite for this case.
        if targets is None:
            statement = ast.Discard(specialized)
        else:
            # HACK should SSA it?
            statement = ast.Assign(specialized, list(targets))

        caseTypes = [self.existingFromObj(t) for t in group]
        cases.append(
            ast.TypeSwitchCase(caseTypes, filtered, ast.Suite([statement])))

    return ast.TypeSwitch(arg, cases)
def visitTryExceptFinally(self, node):
    """Rewrite a try/except/else/finally node, merging the local frames.

    The body is visited with the current frame.  Each handler, and then
    the default handler, is visited with a fresh ``LocalFrame`` built on
    an ``ExceptionMerge`` of ``self.exceptLocals``.  The else clause is
    visited with the frame left by the body and replaced by an empty suite
    when its visit yields a false value.  The frames from all of those
    paths are combined with ``mergeFrames`` before the finally clause is
    visited.

    When the node has a finally clause, a new entry is pushed on each of
    the break/continue/return/raise handler stacks up front and popped
    again before the finally clause is visited.
    """
    if node.finally_:
        self.handlers.breaks.new()
        self.handlers.continues.new()
        self.handlers.returns.new()
        self.handlers.raises.new()

    # The four annotation fields are not used below; the lookup and the
    # four-way unpack are kept unchanged.
    r, m, fr, fm = self.annotate[node]

    body = self(node.body)
    frameAfterBody = self.locals

    # (inserter, frame) pairs, one per path feeding the final merge.
    merges = []

    # Raise / may raise.
    visitedHandlers = []
    for handler in node.handlers:
        self.locals = LocalFrame(ExceptionMerge(self.exceptLocals))
        visited = self(handler)
        visitedHandlers.append(visited)
        merges.append((TailInserter(visited.body), self.locals))

    self.locals = LocalFrame(ExceptionMerge(self.exceptLocals))
    default = self(node.defaultHandler)
    merges.append((TailInserter(default), self.locals))

    # Normal.
    self.locals = frameAfterBody
    else_ = self(node.else_) or ast.Suite([])
    merges.append((TailInserter(else_), self.locals))

    self.locals = mergeFrames(merges)

    if node.finally_:
        self.handlers.breaks.pop()
        self.handlers.continues.pop()
        self.handlers.returns.pop()
        self.handlers.raises.pop()

    # All paths.
    finally_ = self(node.finally_)

    return node.reconstruct(body, visitedHandlers, default, else_, finally_)
def visitUnpackSequence(self, node):
    """Rewrite an unpack-sequence node against the current local frame.

    The source expression is visited first, then each target is passed
    through ``self.locals.writeLocal`` and the node is reconstructed with
    the results.  When exception handling is active, the reconstructed
    node is wrapped in a suite and followed by one merge-marked assignment
    per target, pairing the new local with ``exceptLocal`` of the original
    target.
    """
    source = self(node.expr)

    renamed = []
    for target in node.targets:
        renamed.append(self.locals.writeLocal(target))

    rebuilt = node.reconstruct(source, renamed)
    if not self.hasExceptionHandling:
        return rebuilt

    wrapped = ast.Suite([rebuilt])
    for original, fresh in zip(node.targets, renamed):
        merge = ast.Assign(fresh, self.exceptLocal(original))
        merge.markMerge()
        wrapped.append(merge)
    return wrapped
def visitAssign(self, node):
    """Rewrite an assignment, renaming its targets in the current frame.

    Three outcomes are possible:

    * None of the targets has any recorded use: the assignment is dropped
      (``None``) when the expression is pure, otherwise only the visited
      expression is kept inside an ``ast.Discard``.
    * The original expression is a local: the single target is redefined
      to the reached value and no assignment is emitted, apart from a
      merge-marked assignment when exception handling is active.
    * Anything else: the targets are renamed through ``writeLocal`` and a
      new assignment is returned, wrapped in a suite with one merge-marked
      assignment per target when exception handling is active.
    """
    assert self.locals

    # Built as a full list so every target's use record is looked up.
    used = [len(self.localuses[lcl]) > 0 for lcl in node.lcls]
    if not any(used):
        if node.expr.isPure():
            return None
        return ast.Discard(self(node.expr))

    expr = self(node.expr)
    assert self.locals, node.expr

    if not isinstance(node.expr, ast.Local):
        renames = [self.locals.writeLocal(lcl) for lcl in node.lcls]
        for rename in renames:
            self.defns[rename] = expr

        assignment = ast.Assign(expr, renames)
        if not self.hasExceptionHandling:
            return assignment

        # Create a merge for exception handling.
        statements = [assignment]
        for lcl, rename in zip(node.lcls, renames):
            exceptTarget = self.exceptLocal(lcl)
            merge = ast.Assign(rename, exceptTarget)
            merge.markMerge()
            statements.append(merge)
        return ast.Suite(statements)

    # Assign local to local.  Nullop for SSA.
    assert len(node.lcls) == 1
    expr = self.reach(expr)
    self.locals.redefineLocal(node.lcls[0], expr)

    if not self.hasExceptionHandling:
        return None

    # Create a merge for exception handling.
    exceptTarget = self.exceptLocal(node.lcls[0])
    merge = ast.Assign(expr, exceptTarget)
    merge.markMerge()
    return merge
def postProcessCode(self, code):
    """Apply the collected parameter and group rewrites to ``code``.

    The per-code working state (``rewrites``, ``remap``, ``fields``,
    ``header``) is reset first.  After the parameters and every group have
    been processed, the rewrites are applied and the code's AST is
    wrapped in a suite that starts with an ``ast.InputBlock`` built from
    the header.  The SSA transform is re-run when ``ssaBroken`` is set,
    and the simplifier always runs last.
    """
    self.rewrites = {}
    self.remap = {}
    self.fields = {}
    self.header = []

    self.processParameters(code)

    for groupName, members in self.groups.iteritems():
        self.processGroup(code, groupName, members)

    rewrite.rewrite(self.compiler, code, self.rewrites)

    inputBlock = ast.InputBlock(self.header)
    code.ast = ast.Suite([inputBlock, code.ast])

    if self.ssaBroken:
        ssatransform.evaluateCode(self.compiler, code)

    simplify.evaluateCode(self.compiler, self.prgm, code)
def visitTypeSwitch(self, node):
    """Gather the case exits of a graph type switch and build AST cases.

    One exit is fetched per case of ``node.original``; every exit must be
    in the same region as ``node``, and both ``getCommonExit`` and
    ``getError`` must report success for the full set of exits.
    """
    caseCount = len(node.original.cases)

    exits = []
    for index in range(caseCount):
        exits.append(self.getSwitchExit(node, index))

    for caseExit in exits:
        assert node.region is caseExit.region

    ok, commonExit = self.getCommonExit(*exits)
    assert ok

    ok, errorExit = self.getError(*exits)
    assert ok

    cases = []
    for case, caseExit in zip(node.original.cases, exits):
        cases.append(
            ast.TypeSwitchCase(case.types, case.expr, ast.Suite(caseExit.ops)))
    # NOTE(review): the visible body ends here; ``cases``, ``commonExit``
    # and ``errorExit`` are not consumed by anything shown -- confirm
    # whether the remainder of this method was lost.
def enterExcept(self, r):
    """Enter one more level of exception handling for the locals in ``r``.

    The rename table for the new level starts as a copy of the enclosing
    level's table when one exists.  Every local in ``r`` that has no entry
    yet is cloned, recorded in the table, and paired with its visited
    value in a merge-marked assignment.  The table is pushed on
    ``self.exceptRename`` and the suite of assignments is returned.
    """
    self.exceptLevel += 1

    if self.exceptLevel > 1:
        renames = dict(self.exceptRename[-1])
    else:
        renames = {}

    merges = ast.Suite([])
    for lcl in r:
        if lcl in renames:
            continue

        current = self(lcl)
        clone = lcl.clone()
        renames[lcl] = clone

        assignment = ast.Assign(current, clone)
        assignment.markMerge()
        merges.append(assignment)

    self.exceptRename.append(renames)
    return merges
def visitMerge(self, node):
    """Simplify a merge node, turning a loop header into an ``ast.While``.

    A merge that is not a known loop header is only simplified.  For a
    loop header the single normal successor is taken as the loop preamble
    and its normal successor as the loop switch; a missing switch is
    treated as a degenerate loop whose condition is the extracted constant
    ``True``.  The preamble, condition, body and else suites are folded
    into one ``ast.While`` held by a new ``graph.Suite`` that replaces the
    merge's normal exit, after which the absorbed graph nodes are
    destroyed.
    """
    if node not in self.loops:
        node.simplify()
        return

    assert node.numPrev() == 1

    preamble = node.getExit('normal')

    if isinstance(preamble, graph.Switch):
        assert False
    else:
        assert isinstance(preamble, graph.Suite)
        switch = preamble.getExit('normal')

    if switch is None:
        # Degenerate loop: the whole preamble is the body, and an empty
        # preamble plus an always-true condition are synthesized.
        body = preamble
        preamble = graph.Suite(body.region)
        alwaysTrue = ast.Existing(self.compiler.extractor.getObject(True))
        switch = graph.Switch(body.region, alwaysTrue)
        else_ = graph.Suite(body.region)
    else:
        assert isinstance(switch, graph.Switch)

        body = self.getSwitchExit(switch, 'true')
        else_ = self.getSwitchExit(switch, 'false', ignoreRegion=True)

        switch.killExit('true')
        switch.killExit('false')

    if node in self.breaks:
        # The recorded break target follows the loop; the else suite may
        # only lead there (or nowhere).
        afterLoop = self.breaks[node]
        elseExit = else_.getExit('normal')
        assert elseExit is None or elseExit is afterLoop
        else_.killExit('normal')
    else:
        # Without a break target, the former else node itself becomes
        # what follows the loop, and the loop gets an empty else.
        afterLoop = else_
        else_ = graph.Suite(else_.region)

    bodyAst = KillContinues()(ast.Suite(body.ops))

    loop = ast.While(
        ast.Condition(ast.Suite(preamble.ops), switch.condition),
        bodyAst,
        ast.Suite(else_.ops))

    result = graph.Suite(node.region)
    result.ops.append(loop)

    # Splice the new suite in between the merge and what follows the loop.
    node.killExit('normal')
    node.setExit('normal', result)
    result.setExit('normal', afterLoop)

    if isinstance(afterLoop, graph.Merge):
        afterLoop.simplify()

    node.simplify()

    self.simplifySuite(result)

    # Everything folded into the loop is no longer part of the graph.
    preamble.destroy()
    switch.destroy()
    body.destroy()
    else_.destroy()
def visitSuite(self, node):
    """Visit only the last block of a multi-block suite.

    A suite with more than one block is rebuilt as a new ``ast.Suite``
    whose final block has been passed through ``self``; the leading blocks
    are reused as they are.  A suite with zero or one block is returned
    unchanged, without visiting anything.
    """
    if len(node.blocks) <= 1:
        return node

    leading = node.blocks[:-1]
    visitedLast = self(node.blocks[-1])
    return ast.Suite(leading + [visitedLast])
[], None, None ) ) def bindUniform(compiler, self, name, t, value): bind = symbols.SymbolRewriter(compiler.extractor, uniformBindTemplate) methodName = "bind_uniform_" + t.__name__ return bind.rewrite(self=self, methodName=methodName, name=name, value=value) typeCheckTemplate = ast.Call( ast.GetGlobal(existingConstant('isinstance')), [symbols.Symbol('root'), existingSymbol('type')], [], None, None ) typeBindTemplate = ast.Suite([ ]) # self -> the serializing class # root -> the current uniform local def serializeUniformNode(compiler, translator, self, holdingSlot, refs, root): check = symbols.SymbolRewriter(compiler.extractor, typeCheckTemplate) types = sorted(set([ref.xtype.obj.pythonType() for ref in refs])) # The tree transform should guarantee there's only one object per type typeLUT = {} for ref in refs: t = ref.xtype.obj.pythonType() assert t not in typeLUT typeLUT[t] = ref
def makeLoopMerge(self, lcls, read, hot, modify):
    """Create the entry suite, loop merge and local frame for a loop.

    The entry suite starts empty and is handed to the ``LoopMerge``
    through a ``HeadInserter``; it holds the partial merges
    (entry -> condition).  Returns ``(entrySuite, loopMerge, frame)``
    where ``frame`` is a ``LocalFrame`` built on the loop merge.
    """
    entrySuite = ast.Suite([])
    inserter = HeadInserter(entrySuite)
    loopMerge = LoopMerge(inserter, lcls, read, hot, modify)
    return entrySuite, loopMerge, LocalFrame(loopMerge)