예제 #1
0
    def load_model(self, model, modelname):
        """Partially load pretrained weights from ``modelname`` into ``model``.

        Only checkpoint entries whose keys also exist in ``model.state_dict()``
        are copied; every other parameter of ``model`` keeps its current value.
        If no key matches directly, the checkpoint is retried as one saved from
        a ``torch.nn.DataParallel`` wrapper: a leading ``"module."`` prefix is
        stripped from its keys before matching again.

        Args:
            model: The module to load weights into (modified in place).
            modelname: Path (or file-like object) accepted by ``torch.load``.

        Returns:
            The same ``model`` object, with the matching weights loaded.

        Raises:
            WorkFlow.WFException: If no checkpoint entry matches any key of
                ``model`` (neither directly nor after prefix stripping).
        """
        preTrainDict = torch.load(modelname)
        model_dict = model.state_dict()

        # Keep only the entries the target model actually has.
        preTrainDictTemp = {k: v for k, v in preTrainDict.items() if k in model_dict}

        if 0 == len(preTrainDictTemp):
            self.logger.info("Does not find any module to load. Try DataParallel version.")
            # DataParallel prefixes every key with "module."; strip it only where
            # it is actually present so unrelated keys are never truncated.
            prefix = "module."
            for k, v in preTrainDict.items():
                if k.startswith(prefix):
                    kk = k[len(prefix):]
                    if kk in model_dict:
                        preTrainDictTemp[kk] = v

        # Always continue with the filtered entries: checkpoint keys the model
        # lacks must never reach load_state_dict, whose strict mode rejects them
        # as unexpected keys.
        preTrainDict = preTrainDictTemp

        if 0 == len(preTrainDict):
            raise WorkFlow.WFException("Could not load model from %s." % (modelname), "load_model")

        # Merge into the model's own state dict so missing entries keep their
        # current values, then load the complete dict.
        model_dict.update(preTrainDict)
        model.load_state_dict(model_dict)
        return model
예제 #2
0
    def load_modules(self, fn, map_location=None):
        """Load a saved state dict from ``fn`` into ``self.csn`` (strict load).

        Args:
            fn: Path (or file-like object) accepted by ``torch.load``.
            map_location: Optional ``torch.load`` ``map_location`` (e.g.
                ``"cpu"``) used to remap storages, for instance to load a
                checkpoint saved on GPU on a CPU-only host. ``None`` keeps
                ``torch.load``'s default behavior.

        Raises:
            WorkFlow.WFException: If called before initialization, i.e. when
                ``self.isInitialized`` is falsy.
        """
        # Any falsy flag (False, None, 0) means "not initialized yet".
        if not self.isInitialized:
            raise WorkFlow.WFException("Cannot load module before initialization.", "load_modules")

        self.csn.load_state_dict(torch.load(fn, map_location=map_location))