Esempio n. 1
0
 def __init__(self):
     """Initialise empty lookup tables, a fresh DataFlowGraph and no parse tree."""
     # Symbol, iterator, constant, link and gradient tables all start empty.
     (self.symTable, self.iterTable, self.constTable,
      self.linkTable, self.gradientTable) = ({}, {}, {}, {}, {})
     self.dfg = DataFlowGraph()
     self.funcTypeTable = {}
     self.parseTree = None
Esempio n. 2
0
    def fromList(self, l):
        """Rebuild a DataFlowGraph from a list of serialised node dicts.

        Each dict is turned into a ``DFGNode`` via ``fromDict`` and placed into
        the new graph with ``graph.nodes.insert(node.id, node)``.  Once every
        node exists, each entry of ``children`` / ``parents`` (presumably node
        ids as read from the dicts) is replaced in place by ``graph.get(entry)``.

        :param l: iterable of node dictionaries.
        :return: the newly built DataFlowGraph; ``self`` is not touched.
        """
        graph = DataFlowGraph()

        # Pass 1: materialise every node so the id lookups in pass 2 can succeed.
        for serialised in l:
            restored = DFGNode()
            restored.fromDict(serialised)
            # NOTE(review): if ``nodes`` is a plain list, insert(id, ...) only puts
            # each node at index == id for suitably ordered input — confirm.
            graph.nodes.insert(restored.id, restored)

        # Pass 2: swap the stored references for the actual node objects.
        for restored in graph.nodes:
            restored.children[:] = [graph.get(cid) for cid in restored.children]
            restored.parents[:] = [graph.get(pid) for pid in restored.parents]

        return graph
Esempio n. 3
0
    def create(self, tree):
        """Build the data-flow graph for a parsed program.

        Steps, in order:

        1. Reset ``self.dfg`` and add a ``'source'`` node and a ``'sink'`` node
           (the sink gets ``dist2sink = 0``).
        2. Build the constant, iterator, symbol/gradient and function-type
           tables.
        3. Traverse every statement of ``tree.getChild(1)``.
        4. For every gradient/weight pair in ``self.linkTable`` append a ``'*'``
           node fed by (gradient, ``mu``) and a ``'-'`` node fed by (that
           product, current weight). Array gradients of any rank get one pair
           per element.
        5. Connect output symbols to the sink, compute distances to the sink,
           and drop every node whose ``dist2sink`` is still ``None``.
        6. Call ``self.dfg.updateId()``.

        :param tree: parse tree; child 1 is assumed to hold the statement list
            (based on ``getChild(1)`` below — confirm against the grammar).
        :return: the finished graph, which is also stored in ``self.dfg``.
        """
        self.parseTree = tree
        self.dfg = DataFlowGraph()

        # Source and sink are created before any other node.
        source = DFGNode()
        source.operation = 'source'
        self.dfg.add(source)

        sink = DFGNode()
        sink.operation = 'sink'
        sink.dist2sink = 0
        self.dfg.add(sink)

        # Declaration tables; the iterator and symbol tables are derived from
        # the constant table, so it is built first.
        self.constTable = self.createConstTable()
        self.iterTable = self.createIterTable(self.constTable)
        self.symTable, self.gradientTable = self.createSymbolTable(
            self.constTable)
        self.funcTypeTable = self.createFuncTypeTable()

        # Lower every statement into the graph (statTraversal's return value
        # is not needed here).
        statList = self.parseTree.getChild(1)
        if statList.children is not None:
            for stat in statList.children:
                self.statTraversal(stat)

        self._appendSGD()

        # Symbol nodes that feed nothing become outputs. Nodes with exactly one
        # parent are skipped (presumably declarations fed only by the source —
        # confirm).
        for node in self.symTable.values():
            if not node.children and len(node.parents) != 1:
                self.connectNode(node, sink)

        self.setDist2sink(sink)
        self._removeUnreachableNodes()

        self.dfg.updateId()
        return self.dfg

    def _appendSGD(self):
        """Append an SGD weight update for every entry of ``self.linkTable``.

        ``linkTable`` maps a gradient declaration (``'g'``, ``'g[N]'``,
        ``'g[N][M]'``, ...) to the base name of the weight it updates. A scalar
        gets one update; an array gets one update per element, for any number
        of dimensions, in row-major order.
        """
        import itertools

        for g, wBase in self.linkTable.items():
            dims = self._gradientDims(g)
            if not dims:
                self._addSGDUpdate(g, wBase)
                continue
            gBase = g[:g.find('[')]
            for idx in itertools.product(*[range(d) for d in dims]):
                suffix = ''.join('[' + str(k) + ']' for k in idx)
                self._addSGDUpdate(gBase + suffix, wBase + suffix)

    def _gradientDims(self, g):
        """Return the dimension sizes declared in ``g``.

        ``'g'`` gives ``[]``, and ``'g[3][N]'`` gives ``[3, self.constTable['N']]``.
        Literal sizes are converted to ``int``. Named sizes are looked up in
        ``self.constTable``, which raises ``KeyError`` if the name is undefined.
        """
        if '[' not in g:
            # A scalar. Without this check the split below would return [g]
            # and the name would be misread as a size.
            return []
        inner = g[g.find('[') + 1:g.rfind(']')]
        return [int(tok) if tok.isdigit() else self.constTable[tok]
                for tok in inner.split('][')]

    def _addSGDUpdate(self, gKey, wKey):
        """Append the ``'*'`` and ``'-'`` update nodes for one gradient element.

        Does nothing if ``gKey`` has no entry in ``self.symTable``. Afterwards
        ``self.symTable[wKey]`` points at the new ``'-'`` node.
        """
        if gKey not in self.symTable:
            return
        mult = DFGNode()
        mult.operation = "*"
        mult.dataType = None
        self.dfg.add(mult)
        self.connectNode(self.symTable[gKey], mult)
        self.connectNode(self.symTable["mu"], mult)

        sub = DFGNode()
        sub.operation = "-"
        sub.dataType = 'model'
        self.dfg.add(sub)
        self.connectNode(mult, sub)
        self.connectNode(self.symTable[wKey], sub)
        self.symTable[wKey] = sub

    def _removeUnreachableNodes(self):
        """Unlink and remove every node whose ``dist2sink`` is still ``None``.

        These are presumably the nodes setDist2sink could not reach — confirm.
        All such nodes are unlinked from their neighbours first and removed from
        ``self.dfg`` afterwards, so ``self.dfg.nodes`` is never mutated while
        it is being scanned.
        """
        dead = [node for node in self.dfg.nodes if node.dist2sink is None]
        for node in dead:
            for child in node.children:
                child.parents.remove(node)
            for parent in node.parents:
                parent.children.remove(node)
        for node in dead:
            self.dfg.remove(node)