Example #1
0
    def create_mtcnn_net(self):
        """Build the MTCNN stage networks and store them on the instance.

        For each of ``self.args.pnet``, ``self.args.rnet`` and
        ``self.args.onet`` that is a non-empty path, the matching network
        (PNet / RNet / ONet) is constructed with ``use_cuda=self.use_gpu``,
        its state dict is loaded from that path, and it is switched to eval
        mode.

        When ``self.use_gpu`` is true, the network is moved to the first GPU
        in ``self.gpu_ids`` and wrapped in ``torch.nn.DataParallel`` over
        ``self.gpu_ids``. Otherwise the checkpoint tensors are remapped onto
        the CPU.

        A stage whose path is empty (or None) is left as ``None``. The
        results are assigned to ``self.pnet_detector``,
        ``self.rnet_detector`` and ``self.onet_detector`` only after every
        requested stage has loaded.
        """

        def load_stage(net_cls, model_path):
            # Construct one stage, load its weights and put it in eval mode.
            net = net_cls(use_cuda=self.use_gpu)
            if self.use_gpu:
                net.load_state_dict(torch.load(model_path))
                # load_state_dict copies values into the module's existing
                # tensors without changing their device, and DataParallel
                # requires the wrapped module to already live on
                # device_ids[0]. Move it explicitly; this is harmless if the
                # module is already there.
                primary = self.gpu_ids[0] if self.gpu_ids else None
                net = net.cuda(primary)
                net = torch.nn.DataParallel(net, device_ids=self.gpu_ids)
            else:
                # Keep every tensor on the CPU, so checkpoints saved from a
                # GPU still load on CPU-only machines.
                net.load_state_dict(
                    torch.load(model_path,
                               map_location=lambda storage, loc: storage))
            net.eval()
            return net

        pnet, rnet, onet = None, None, None
        if self.args.pnet:
            pnet = load_stage(PNet, self.args.pnet)
        if self.args.rnet:
            rnet = load_stage(RNet, self.args.rnet)
        if self.args.onet:
            onet = load_stage(ONet, self.args.onet)

        self.pnet_detector = pnet
        self.rnet_detector = rnet
        self.onet_detector = onet
def create_mtcnn_net(p_model_path=None, r_model_path=None, o_model_path=None, use_cuda=True):
    """Load the PNet / RNet / ONet stages of MTCNN from checkpoint files.

    Each stage is built only when its path is not ``None``; otherwise that
    stage is returned as ``None``. Every loaded network is put in eval mode.

    Args:
        p_model_path: path to the PNet state-dict checkpoint, or None.
        r_model_path: path to the RNet state-dict checkpoint, or None.
        o_model_path: path to the ONet state-dict checkpoint, or None.
        use_cuda: when True (the default), the checkpoint is loaded with
            ``torch.load``'s default device placement, the path is printed,
            and the network is moved to the GPU with ``.cuda()``. When False,
            the checkpoint tensors are remapped onto the CPU.

    Returns:
        Tuple ``(pnet, rnet, onet)``; a stage whose path was None is None.
    """

    def _load_stage(net_cls, model_path, label):
        # Construct one stage, load its weights, place it, and set eval mode.
        net = net_cls(use_cuda=use_cuda)
        if use_cuda:
            print('{0}:{1}'.format(label, model_path))
            net.load_state_dict(torch.load(model_path))
            net.cuda()
        else:
            # Returning `storage` unchanged keeps every tensor on the CPU,
            # so GPU-saved checkpoints still load on CPU-only machines.
            net.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
        net.eval()
        return net

    pnet, rnet, onet = None, None, None

    if p_model_path is not None:
        pnet = _load_stage(PNet, p_model_path, 'p_model_path')
    if r_model_path is not None:
        rnet = _load_stage(RNet, r_model_path, 'r_model_path')
    if o_model_path is not None:
        onet = _load_stage(ONet, o_model_path, 'o_model_path')

    return pnet, rnet, onet