def __init__(self):
    """Start with empty lookup tables and an empty data-flow graph.

    ``create`` later rebuilds ``constTable``, ``iterTable``, ``symTable``,
    ``gradientTable``, ``funcTypeTable``, ``dfg`` and ``parseTree``.
    ``linkTable`` is only read there; presumably it is filled during
    statement traversal (NOTE(review): confirm where it is populated).
    """
    # Declaration lookup tables, keyed by symbol name.
    self.symTable, self.iterTable, self.constTable = {}, {}, {}
    # Gradient -> model links and gradient bookkeeping.
    self.linkTable, self.gradientTable = {}, {}
    self.dfg = DataFlowGraph()
    self.funcTypeTable = {}
    # No parse tree until one is handed to ``create``.
    self.parseTree = None
def fromList(self, l):
    """Rebuild a DataFlowGraph from a list of serialised node dicts.

    Each dict is loaded into a fresh ``DFGNode`` with ``fromDict`` and put
    into ``graph.nodes`` via ``insert(node.id, node)``. Once every node
    exists, the id lists stored in each node's ``children`` and ``parents``
    are rewritten in place with the objects returned by ``graph.get``.

    NOTE(review): inserting at ``node.id`` only gives id-ordered positions
    if ``l`` is already sorted by id - confirm with callers.

    Args:
        l: iterable of node dictionaries understood by ``DFGNode.fromDict``.

    Returns:
        The newly built DataFlowGraph (``self`` is not modified).
    """
    graph = DataFlowGraph()

    # Pass 1: materialise every node so that ids can be resolved later.
    for entry in l:
        fresh = DFGNode()
        fresh.fromDict(entry)
        graph.nodes.insert(fresh.id, fresh)

    # Pass 2: replace stored ids with node objects (children, then parents).
    for current in graph.nodes:
        for links in (current.children, current.parents):
            for pos, ident in enumerate(links):
                links[pos] = graph.get(ident)
    return graph
def create(self, tree):
    """Build the data-flow graph (DFG) for a parsed program.

    Steps:
      1. Reset ``self.dfg`` and add the ``source`` and ``sink`` nodes.
      2. Build the constant, iterator, symbol/gradient and function-type
         tables.
      3. Traverse every statement of the statement list (``tree.getChild(1)``).
      4. For every gradient/model pair in ``self.linkTable``, append an SGD
         update per element: a ``*`` node fed by the gradient and ``mu``,
         then a ``-`` node (``dataType='model'``) fed by that product and the
         current model node. The ``-`` node becomes the new model symbol.
      5. Connect dangling symbol nodes to the sink, compute distances to the
         sink, drop nodes that cannot reach it, and renumber the node ids.

    Gradient keys in ``linkTable`` may carry any number of subscripts, e.g.
    ``"g"``, ``"g[N]"`` or ``"g[N][3]"``. A subscript is either a literal
    integer or the name of an entry in ``constTable``.

    Args:
        tree: parse-tree root; child 1 is used as the statement list.

    Returns:
        The pruned ``DataFlowGraph``, which is also stored on ``self.dfg``.
    """
    from itertools import product

    self.parseTree = tree
    self.dfg = DataFlowGraph()

    # Create source and sink nodes first.
    source = DFGNode()
    source.operation = 'source'
    self.dfg.add(source)
    sink = DFGNode()
    sink.operation = 'sink'
    sink.dist2sink = 0
    self.dfg.add(sink)

    # Hash tables for the data declarations.
    self.constTable = self.createConstTable()
    self.iterTable = self.createIterTable(self.constTable)
    self.symTable, self.gradientTable = self.createSymbolTable(
        self.constTable)
    self.funcTypeTable = self.createFuncTypeTable()

    # Build the DFG body from every statement.
    statList = self.parseTree.getChild(1)
    if statList.children is not None:
        for stat in statList.children:
            self.statTraversal(stat)

    def dimSize(dim):
        # Literal sizes are digit strings and must become ints for range();
        # anything else names a declared constant.
        return int(dim) if dim.isdigit() else int(self.constTable[dim])

    def appendUpdate(gSym, wSym):
        # mult = gradient * mu ; sub = '-'(mult, current model value)
        mult = DFGNode()
        mult.operation = "*"
        mult.dataType = None
        self.dfg.add(mult)
        self.connectNode(self.symTable[gSym], mult)
        self.connectNode(self.symTable["mu"], mult)
        sub = DFGNode()
        sub.operation = "-"
        sub.dataType = 'model'
        self.dfg.add(sub)
        self.connectNode(mult, sub)
        self.connectNode(self.symTable[wSym], sub)
        # Later readers of the model symbol see the updated value.
        self.symTable[wSym] = sub

    # Append SGD updates for every linked gradient.
    for g in self.linkTable:
        base = g.split('[', 1)[0]
        if '[' in g:
            dims = g[g.find('[') + 1:g.rfind(']')].split('][')
        else:
            dims = []  # scalar gradient: no subscripts at all
        sizes = [dimSize(dim) for dim in dims]
        # product() of zero ranges yields a single empty tuple, which makes
        # the scalar case one iteration with an empty suffix.
        for index in product(*(range(size) for size in sizes)):
            suffix = ''.join('[' + str(k) + ']' for k in index)
            gSym = base + suffix
            if gSym in self.symTable:
                appendUpdate(gSym, self.linkTable[g] + suffix)

    # Symbol nodes nobody consumes are outputs: wire them to the sink.
    # Nodes with exactly one parent are skipped; NOTE(review): presumably
    # these are plain inputs hanging off ``source`` - confirm the intent.
    for node in self.symTable.values():
        if not node.children and len(node.parents) != 1:
            self.connectNode(node, sink)

    # Compute every node's distance to the sink.
    self.setDist2sink(sink)

    # Remove nodes that never reach the sink, unlinking them from their
    # neighbours before dropping them from the graph.
    removedNodes = [node for node in self.dfg.nodes if node.dist2sink is None]
    for node in removedNodes:
        for child in node.children:
            child.parents.remove(node)
        for parent in node.parents:
            parent.children.remove(node)
    for node in removedNodes:
        self.dfg.remove(node)

    self.dfg.updateId()
    return self.dfg