コード例 #1
0
 def RegisteBackward(self):
     """Queue this operator for the backward pass when its output needs a gradient.

     Nothing happens if OutputNeedGrad() reports that no input requires a
     gradient. Otherwise the operator is handed to
     GOperatorManager.EnqueBackward. When DEBUG() is on, the registration and
     the resulting backward-queue length are printed.
     """
     # Guard clause: operators that no gradient flows through are skipped.
     if not self.OutputNeedGrad():
         return
     if DEBUG():
         tag = type(self).__name__ + self.DEBUGID
         print(tag + "->RegisteBackward")
     GOperatorManager.EnqueBackward(self)
     if DEBUG():
         queued = len(GOperatorManager.BackwardOperators)
         print("Num of Backward: " + str(queued))
コード例 #2
0
def ClearDataNode():
    """Empty the global data-node registry held by GDataNodeManager.

    When DEBUG() is on, a banner plus the number of registered data nodes
    before and after the clear are printed.
    """
    if DEBUG():
        print("=======Clear DataNode!========")
        before = len(GDataNodeManager.DataNodes)
        print("Num Of Data Befor Clear: " + str(before))
    GDataNodeManager.Clear()
    if DEBUG():
        after = len(GDataNodeManager.DataNodes)
        print("Num Of Data After Clear: " + str(after))
コード例 #3
0
def ClearOperator():
    """Empty the global operator queues held by GOperatorManager.

    When DEBUG() is on, a banner plus the "forward:backward" queue sizes are
    printed before and after the clear.
    """
    if DEBUG():
        print("=======Clear Operator!========")
        print("Num Of Data Befor Clear: " +
              str(len(GOperatorManager.ForwardOperators)) + ":" +
              str(len(GOperatorManager.BackwardOperators)))
    GOperatorManager.Clear()
    if DEBUG():
        # These are the post-clear counts, so label them "After Clear"
        # (same wording ClearDataNode uses).
        print("Num Of Data After Clear: " +
              str(len(GOperatorManager.ForwardOperators)) + ":" +
              str(len(GOperatorManager.BackwardOperators)))
コード例 #4
0
 def Backward(self):
     """Run the backward pass over every queued operator, newest first.

     The output of the most recently queued operator is seeded via
     OnesGrad() (presumably a gradient of ones — confirm). Then each queued
     operator's Backward() is called in reverse enqueue order. The queue
     itself is left unmodified, so repeated calls see the same order.
     """
     if not self.BackwardOperators:
         # No operator needs a gradient; indexing [-1] would raise IndexError.
         if DEBUG():
             print("No Backward Operator registered!")
         return
     self.BackwardOperators[-1].Output.OnesGrad()
     if DEBUG():
         print("Last Output Grad Oneing finished!")
     # Iterate a reversed view rather than calling list.reverse().
     # Reversing in place would mutate the queue, so a second call would
     # seed the wrong output and visit operators in forward order.
     for Operator in reversed(self.BackwardOperators):
         Operator.Backward()
コード例 #5
0
 def __call__(self, TensorInput1, TensorInput2):
     """Attach two input nodes to this operator and create its output node.

     Args:
         TensorInput1: first input, appended to self.Inputs.
         TensorInput2: second input, appended to self.Inputs after the first.

     Returns:
         The newly created DataNode, also stored as self.Output.
     """
     for tensor in (TensorInput1, TensorInput2):
         self.Inputs.append(tensor)
     output = DataNode()
     self.Output = output
     if DEBUG():
         tag = type(self).__name__ + self.DEBUGID
         print(tag + "->Builded")
     return output
コード例 #6
0
 def RegisteForward(self):
     """Append this operator to the global forward queue.

     When DEBUG() is on, prints a separator, the registration, and the new
     length of GOperatorManager.ForwardOperators.

     Returns:
         self, which allows the call to be chained.
     """
     GOperatorManager.EnqueForward(self)
     if DEBUG():
         print("-----------------------")
         print(type(self).__name__ + self.DEBUGID + "->RegisterForward")
         # This counts the forward queue, so the label must say "Forward".
         print("Num of Forward: " +
               str(len(GOperatorManager.ForwardOperators)))
     return self
コード例 #7
0
 def DEBUGself(self):
     """In debug mode, print which kind of data node was just created.

     The message is chosen from the two flags (each compared with `== True`):
         CanBeClear, NeedGrad both True   -> "IntemediateGrad" message
         CanBeClear True, NeedGrad not    -> "InteMediateConstant" message
         CanBeClear not, NeedGrad True    -> "Paremeter" message
         neither                          -> "Input" message
     """
     if not DEBUG():
         return
     # Keep the explicit `== True` tests so non-bool flag values behave
     # exactly as before.
     clearable = bool(self.CanBeClear == True)
     trainable = bool(self.NeedGrad == True)
     messages = {
         (True, True): "==>IntemediateGrad being Created!",
         (True, False): "==>InteMediateConstant being Created!",
         (False, True): "==>Paremeter being Created!",
         (False, False): "==>Input being Created",
     }
     print(messages[(clearable, trainable)])
コード例 #8
0
 def Backward(self):
     """Accumulate this operator's local gradients into its inputs.

     For each input whose NeedGrad is truthy, self.LocalGrad(input, g) is
     added onto that input's Grad, where g is self.Output.Grad. If an input's
     Grad is still None, AddGrad() is called on it first (presumably to
     allocate a zero gradient — confirm against DataNode).
     """
     if DEBUG():
         print(type(self).__name__ + self.DEBUGID + "->Backward Calculate")
     # The incoming gradient is the same for every input; read it once.
     downstream_grad = self.Output.Grad
     # Loop variable deliberately not named DataNode, which would shadow
     # the DataNode class.
     for node in self.Inputs:
         if not node.NeedGrad:
             continue
         if node.Grad is None:
             node.AddGrad()
         node.Grad = node.Grad + self.LocalGrad(node, downstream_grad)
コード例 #9
0
 def OutputNeedGrad(self):
     """Decide whether this operator's output needs a gradient.

     The result is True as soon as any input has NeedGrad == True. In that
     case self.Output.NeedGrad is also set to True. Otherwise the output flag
     is left untouched. When DEBUG() is on, the decision is printed.

     Returns:
         bool: True if at least one input requires a gradient.
     """
     # any() short-circuits on the first match, like the original break.
     needs_grad = any(node.NeedGrad == True for node in self.Inputs)
     if needs_grad:
         self.Output.NeedGrad = True
     if DEBUG():
         tag = type(self).__name__ + self.DEBUGID
         suffix = "->Need Backward!" if needs_grad else "->Do not Need Backward"
         print(tag + suffix)
     return needs_grad
コード例 #10
0
 def Registe(self, DataNode):
     """Store a data node in this registry's DataNodes list.

     Args:
         DataNode: the node to record (parameter name kept for keyword
             compatibility with existing callers).
     """
     self.DataNodes.append(DataNode)
     if not DEBUG():
         return
     print("##############")
     print("#####Add One DataNode to GLOBAL!")
     total = len(self.DataNodes)
     print("#####Num of DataNode: " + str(total))
コード例 #11
0
 def Forward(self):
     """Compute this operator's output, then offer it for backprop registration.

     Calls self.Calculate() followed by self.RegisteBackward(). When DEBUG()
     is on, a separator and a trace line are printed first.
     """
     if DEBUG():
         print("----------------------")
         label = type(self).__name__ + self.DEBUGID + "->Forward Calculate"
         print(label)
     self.Calculate()
     self.RegisteBackward()
コード例 #12
0
def Backward():
    """Start the global backward pass via GOperatorManager.Backward().

    A banner is printed first when DEBUG() is on.
    """
    if DEBUG():
        banner = "=======Backward Begining!======="
        print(banner)
    GOperatorManager.Backward()
コード例 #13
0
def Forward():
    """Start the global forward pass via GOperatorManager.Forward().

    A banner is printed first when DEBUG() is on.
    """
    if DEBUG():
        banner = "=======Forward Begining!======="
        print(banner)
    GOperatorManager.Forward()