Example #1
def main(args):
    dummy_input = getInput(args.img_size)  # build the network's dummy input
    # Load the model (swap in RNet or ONet to export those networks instead)
    model = PNet()
    # model = RNet()
    # model = ONet()
    # Guard against a missing checkpoint, then load the weights onto the CPU
    if args.model_path and os.path.isfile(args.model_path):
        print("=> loading checkpoint '{}'".format(args.model_path))
        model.load_state_dict(
            torch.load(args.model_path, map_location='cpu'))
        print("load cls model successfully")
    else:
        print("=> no checkpoint found at '{}'".format(args.model_path))
        return
    model.to('cpu')
    model.eval()
    pre = model(dummy_input)
    print("the pre:{}".format(pre))
    # save the ONNX model
    torch2onnx(args, model, dummy_input)
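
Example #1 calls two helpers, getInput and torch2onnx, that are not shown above. A minimal sketch of what they might look like follows; the input shape, opset version, and the onnx_path argument are assumptions for illustration, not part of the original code.

import torch

def getInput(img_size):
    # Assumed: a dummy batch of one 3-channel square image used to trace the graph
    return torch.randn(1, 3, img_size, img_size)

def torch2onnx(args, model, dummy_input):
    # Assumed: export through torch.onnx.export; args.onnx_path is a hypothetical argument
    torch.onnx.export(model,
                      dummy_input,
                      args.onnx_path,
                      input_names=['input'],
                      output_names=['output'],
                      opset_version=11)
    print("ONNX model saved to {}".format(args.onnx_path))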
Example #2
def create_mtcnn_net(p_model_path=None,
                     r_model_path=None,
                     o_model_path=None,
                     use_cuda=True):

    pnet, rnet, onet = None, None, None

    if p_model_path is not None:
        pnet = PNet(use_cuda=use_cuda)
        if use_cuda:
            print('p_model_path:{0}'.format(p_model_path))
            pnet.load_state_dict(torch.load(p_model_path))
            pnet.cuda()
        else:
            # forcing all GPU tensors to be in CPU while loading
            pnet.load_state_dict(
                torch.load(p_model_path,
                           map_location=lambda storage, loc: storage))
        pnet.eval()

    if r_model_path is not None:
        rnet = RNet(use_cuda=use_cuda)
        if use_cuda:
            print('r_model_path:{0}'.format(r_model_path))
            rnet.load_state_dict(torch.load(r_model_path))
            rnet.cuda()
        else:
            rnet.load_state_dict(
                torch.load(r_model_path,
                           map_location=lambda storage, loc: storage))
        rnet.eval()

    if o_model_path is not None:
        onet = ONet(use_cuda=use_cuda)
        if use_cuda:
            print('o_model_path:{0}'.format(o_model_path))
            onet.load_state_dict(torch.load(o_model_path))
            onet.cuda()
        else:
            onet.load_state_dict(
                torch.load(o_model_path,
                           map_location=lambda storage, loc: storage))
        onet.eval()

    return pnet, rnet, onet
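
A typical call to create_mtcnn_net from Example #2 might look like the sketch below; the checkpoint paths are placeholders, not files guaranteed to exist in the original repository.

import torch

pnet, rnet, onet = create_mtcnn_net(
    p_model_path='model_store/pnet_epoch.pt',  # placeholder checkpoint paths
    r_model_path='model_store/rnet_epoch.pt',
    o_model_path='model_store/onet_epoch.pt',
    use_cuda=torch.cuda.is_available())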
Example #3
    def create_mtcnn_net(self):
        '''Create the MTCNN sub-networks (P-Net, R-Net, O-Net).'''
        pnet, rnet, onet = None, None, None

        if len(self.args.pnet_file) > 0:
            pnet = PNet(use_cuda=self.args.use_cuda)
            if self.args.use_cuda:
                pnet.load_state_dict(torch.load(self.args.pnet_file))
                # Wrap with DataParallel so inference runs on the configured GPUs
                pnet = torch.nn.DataParallel(
                    pnet, device_ids=self.args.gpu_ids).cuda()
            else:
                # Remap GPU tensors to CPU while loading the checkpoint
                pnet.load_state_dict(
                    torch.load(self.args.pnet_file,
                               map_location=lambda storage, loc: storage))
            pnet.eval()

        if len(self.args.rnet_file) > 0:
            rnet = RNet(use_cuda=self.args.use_cuda)
            if self.args.use_cuda:
                rnet.load_state_dict(torch.load(self.args.rnet_file))
                rnet = torch.nn.DataParallel(
                    rnet, device_ids=self.args.gpu_ids).cuda()
            else:
                rnet.load_state_dict(
                    torch.load(self.args.rnet_file,
                               map_location=lambda storage, loc: storage))
            rnet.eval()

        if len(self.args.onet_file) > 0:
            onet = ONet(use_cuda=self.args.use_cuda)
            if self.args.use_cuda:
                onet.load_state_dict(torch.load(self.args.onet_file))
                onet = torch.nn.DataParallel(
                    onet, device_ids=self.args.gpu_ids).cuda()
            else:
                onet.load_state_dict(
                    torch.load(self.args.onet_file,
                               map_location=lambda storage, loc: storage))
            onet.eval()

        self.pnet_detector = pnet
        self.rnet_detector = rnet
        self.onet_detector = onet
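
The method in Example #3 expects self.args to carry pnet_file, rnet_file, onet_file, use_cuda, and gpu_ids. A minimal sketch of such a namespace is shown below; the concrete values are illustrative only.

import torch
from argparse import Namespace

args = Namespace(
    pnet_file='model_store/pnet_epoch.pt',  # illustrative checkpoint paths
    rnet_file='model_store/rnet_epoch.pt',
    onet_file='model_store/onet_epoch.pt',
    use_cuda=torch.cuda.is_available(),
    gpu_ids=[0])                            # device ids passed to DataParallel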