Ejemplo n.º 1
0
    def __init__(self, name, proto=None, weightsFile=None, filePath=None):
        """Build the solver node together with its parameter-tree widget.

        Args:
            name: Node name, forwarded to ``Node.__init__``.
            proto: Optional value passed to ``setProto``. When ``filePath`` is
                not given it is also treated as a path whose directory becomes
                the working directory.
            weightsFile: Stored unchanged as ``self.weights``.
            filePath: Working directory. Defaults to the directory of ``proto``
                or, when ``proto`` is ``None`` as well, to ``os.curdir``.
        """
        Node.__init__(self, name, allowAddInput=True, allowAddOutput=True, allowRemove=True)

        # Parameter group generated from the solver message descriptor, with
        # every direct child reset to its default value.
        descriptor = SolverProto.DESCRIPTOR
        self.param = LParameter.create(repeated=False, fieldDescriptor=descriptor, expanded=True)
        for child in self.param.children():
            child.setToDefault()

        # Tree widget exposing the parameters; it doubles as the node's UI.
        tree = ptree.ParameterTree()
        tree.addParameters(self.param, depth=1, showTop=False)
        tree.setMinimumHeight(325)
        tree.setVerticalScrollBarPolicy(pg.QtCore.Qt.ScrollBarAlwaysOff)
        self.ui = tree

        self.proto = SolverProto()
        self.trainLoss = None
        self.testLoss = None
        self.testAcc = None
        self.niter = None
        self.testInterval = None

        hasProto = proto is not None
        if hasProto:
            self.setProto(proto)

        # An explicit filePath always wins, regardless of where the solver
        # prototxt lives; a temporary file is written into that directory for
        # training. Without it we fall back to the solver file's directory and
        # rely on the train file being there as well (otherwise training
        # crashes).
        if filePath is None:
            filePath = os.path.dirname(proto) if hasProto else os.curdir
        self.filePath = filePath
        self.weights = weightsFile
Ejemplo n.º 2
0
    def updateSpecificParam(self, layerType):
        """Rebuild the type-dependent parameter groups for ``layerType``.

        The ``<name>_param`` field belonging to the layer type is looked up on
        the proto descriptor and installed as ``self.specificParam`` at index 4
        of ``self.param``, replacing the previously installed one. A previously
        installed ``self.baseParam`` is removed; a new one is added for
        ``'Data'`` layers (``transform_param``) and for layer types whose name
        contains ``'Loss'`` (``loss_param``). Finally ``self.nodeName`` is set
        to ``layerType``.

        Args:
            layerType: Layer type name, used as key into ``_param_names``.
        """
        layerSpec = self.proto
        try:
            specificParamName = _param_names[layerType] + '_param'
            specificFieldDescriptor = layerSpec.DESCRIPTOR.fields_by_name[specificParamName]
            child = LParameter.create(fieldDescriptor=specificFieldDescriptor, expanded=True)
            # add specific parameter type to top level param
            if self.specificParam is not None:
                self.param.removeChild(self.specificParam)
            self.specificParam = self.param.insertChild(4, child)
            self.specificParam.setToDefault()
        except KeyError:
            # No type-specific field is known for this layer type, so nothing
            # is installed.
            # NOTE(review): a specific param left over from a previous type
            # stays in the tree in this case - confirm that this is intended.
            pass

        if self.baseParam is not None:
            self.param.removeChild(self.baseParam)
            # Forget the detached child. Otherwise a later call for a type
            # without an additional param would try to remove the same,
            # already detached, child a second time.
            self.baseParam = None

        if layerType == 'Data':
            self.addAdditionalParam('transform_param')

        if 'Loss' in layerType:
            self.addAdditionalParam('loss_param')

        self.nodeName = layerType
Ejemplo n.º 3
0
    def __init__(self, name):
        """Create a layer node with its parameter tree and signal wiring.

        When the node's ``nodeName`` is something other than the generic
        ``"Layer"``, the node is additionally configured from a fresh
        ``LayerProto`` of that type and given one default bottom
        (``"bottom1"``) and one default top (the lower-cased node name).

        Args:
            name: Node name, forwarded to ``Node.__init__``.
        """
        Node.__init__(self, name, allowAddInput=True, allowAddOutput=True, allowRemove=True)

        # Parameter group generated from the 'layer' field of NetParameter,
        # with every direct child reset to its default value.
        layerField = NetParameter.DESCRIPTOR.fields_by_name['layer']
        self.param = LParameter.create(repeated=False, fieldDescriptor=layerField, expanded=True)
        for child in self.param.children():
            child.setToDefault()

        tree = ptree.ParameterTree()
        tree.addParameters(self.param, depth=1, showTop=False)

        self.proto = LayerProto()
        self.specificParam = None
        self.baseParam = None

        # Renaming the node and any terminal change are forwarded to handlers.
        self.sigRenamed.connect(self.nameChanged)
        for terminalSignal in (self.sigTerminalRenamed,
                               self.sigTerminalAdded,
                               self.sigTerminalRemoved):
            terminalSignal.connect(self.updateBlobs)

        tree.setMinimumHeight(325)
        tree.setVerticalScrollBarPolicy(pg.QtCore.Qt.ScrollBarAlwaysOff)
        self.ui = tree

        # NOTE: sigChildAdded of bottoms/tops is not connected to
        # updateTerminals; terminals are refreshed explicitly below.
        self.bottoms = self.param.child('bottom')
        self.tops = self.param.child('top')

        if self.nodeName == "Layer":
            return

        # Concrete layer type: configure from a spec carrying just the type.
        typedSpec = LayerProto()
        typedSpec.type = self.nodeName
        self.configFromLayerSpec(typedSpec)

        defaultTop = str(self.nodeName).lower()
        # Add the default blobs without emitting parameter signals, then
        # refresh the terminals once.
        with signalsBlocked(self.param):
            self.bottoms.addNew("bottom1")
            self.tops.addNew(defaultTop)
        self.updateTerminals()
Ejemplo n.º 4
0
 def addAdditionalParam(self, name):
     """Install the proto field ``name`` as ``self.baseParam``.

     A parameter is created from the field descriptor called ``name`` on
     ``self.proto``, inserted at index 5 of ``self.param``, stored as
     ``self.baseParam`` and reset to its default value.

     Raises:
         KeyError: If the proto descriptor has no field called ``name``.
     """
     fieldsByName = self.proto.DESCRIPTOR.fields_by_name
     extraParam = LParameter.create(fieldDescriptor=fieldsByName[name])
     inserted = self.param.insertChild(5, extraParam)
     self.baseParam = inserted
     inserted.setToDefault()