예제 #1
0
    def __init__(self, pnet_param, rnet_param, onet_param, isCuda=True):
        """Build the P-Net / R-Net / O-Net cascade and load trained weights.

        Args:
            pnet_param: Checkpoint passed to ``torch.load`` for the P-Net
                state dict (typically a file path).
            rnet_param: Checkpoint for the R-Net state dict.
            onet_param: Checkpoint for the O-Net state dict.
            isCuda: If True, move all three networks to the GPU with
                ``.cuda()`` before loading their weights.
        """
        self.isCuda = isCuda
        self.pnet = Nets.PNet()
        self.rnet = Nets.RNet()
        self.onet = Nets.ONet()

        if self.isCuda:
            self.pnet.cuda()
            self.rnet.cuda()
            self.onet.cuda()

        # Load the trained network parameters. Deserialize onto the CPU:
        # load_state_dict() copies each tensor into the module's existing
        # parameters (already on the GPU when isCuda is True), and
        # map_location='cpu' keeps checkpoints that were saved from a GPU
        # loadable on machines without CUDA.
        self.pnet.load_state_dict(torch.load(pnet_param, map_location='cpu'))
        self.rnet.load_state_dict(torch.load(rnet_param, map_location='cpu'))
        self.onet.load_state_dict(torch.load(onet_param, map_location='cpu'))

        # The networks are used for inference only: switch to eval mode.
        self.pnet.eval()
        self.rnet.eval()
        self.onet.eval()
        # Image preprocessing pipeline: a single ToTensor() conversion.
        self.__image_transform = transforms.Compose([transforms.ToTensor()])
예제 #2
0
    def __init__(self, pnet_param, rnet_param, onet_param, isCuda=False):
        """Build the P-Net / R-Net / O-Net cascade and load trained weights.

        Args:
            pnet_param: Checkpoint passed to ``torch.load`` for the P-Net
                state dict (typically a file path).
            rnet_param: Checkpoint for the R-Net state dict.
            onet_param: Checkpoint for the O-Net state dict.
            isCuda: If True, move all three networks to the GPU with
                ``.cuda()`` before loading their weights.
        """
        self.isCuda = isCuda
        # Instantiate the networks.
        self.pnet = Nets.PNet()
        self.rnet = Nets.RNet()
        self.onet = Nets.ONet()
        # Move the networks to the GPU. Call .cuda() on the module itself;
        # the previous self.pnet().cuda() invoked the network's forward pass
        # with no input instead of moving its parameters.
        if self.isCuda:
            self.pnet.cuda()
            self.rnet.cuda()
            self.onet.cuda()
        # Load the trained parameters. map_location='cpu' lets GPU-saved
        # checkpoints load without CUDA; load_state_dict() then copies the
        # tensors onto whatever device the module's parameters live on.
        self.pnet.load_state_dict(torch.load(pnet_param, map_location='cpu'))
        self.rnet.load_state_dict(torch.load(rnet_param, map_location='cpu'))
        self.onet.load_state_dict(torch.load(onet_param, map_location='cpu'))

        # The networks are used for inference only: switch to eval mode.
        self.pnet.eval()
        self.rnet.eval()
        self.onet.eval()
        # Convert input images to tensors. Compose expects a list of
        # transforms, not a single transform object.
        self.__image_transform = transforms.Compose([transforms.ToTensor()])
예제 #3
0
import Nets
import train_nets

if __name__ == '__main__':
    # Train a fresh P-Net with train_nets.Trainer. The second argument looks
    # like the parameter file path and the third the training-data directory
    # (presumably the 12x12 crops used for P-Net) -- confirm against Trainer.
    pnet_model = Nets.PNet()
    pnet_trainer = train_nets.Trainer(
        pnet_model,
        './param/pnet1.pth',
        r'E:\data\data\12',
    )
    pnet_trainer.train()


예제 #4
0
파일: mtcnn.py 프로젝트: cay125/mtcnn
if __name__ == '__main__':
    pnet = PNet(test=True)
    pnet.load_state_dict(
        torch.load('models/pnet_2019_09_04', map_location='cpu')['weights'])
    rnet = RNet(test=True)
    rnet.load_state_dict(
        torch.load('models/rnet_2019_09_10', map_location='cpu')['weights'])
    onet = ONet(test=True)
    onet.load_state_dict(
        torch.load('models/onet_2019_09_11', map_location='cpu')['weights'])
    pnet2 = PNet(test=True)
    pnet2.load_state_dict(
        torch.load('models/pnet_2019_09_08_18_07',
                   map_location='cpu')['weights'])
    pnet3 = Nets.PNet(test=True)
    pnet3.load_state_dict(
        torch.load('../mtcnn-pytorch/scripts/models/pnet_20190908_final.pkl',
                   map_location='cpu')['weights'])
    detector = mtcnn(pnet, rnet, onet, mini_face_size=20, name="1", show=True)
    detector2 = mtcnn(pnet2,
                      rnet,
                      onet,
                      mini_face_size=20,
                      name="2",
                      show=True)
    detector3 = mtcnn(pnet3,
                      rnet,
                      onet,
                      mini_face_size=20,
                      name="3",