def load_model():
    """Build the landmark model selected by ``args.backbone`` and load its weights.

    Reads two module-level names: ``args`` (``args.backbone`` picks the
    architecture) and ``map_location`` (forwarded to ``torch.load``).
    Supported backbones are ``'MobileNet'``, ``'PFLD'`` and ``'MobileFaceNet'``;
    each one loads a fixed checkpoint file from the ``checkpoint/`` directory
    (download links are next to each branch).

    Returns:
        The model with the checkpoint's ``'state_dict'`` loaded. It is not
        switched to eval mode here.

    Raises:
        ValueError: if ``args.backbone`` is not a supported backbone name.
    """
    if args.backbone == 'MobileNet':
        # 136 outputs -- presumably 68 (x, y) landmark pairs.
        model = MobileNet_GDConv(136)
        # Wrap before loading: DataParallel prefixes state_dict keys with
        # 'module.', which this checkpoint presumably expects.
        model = torch.nn.DataParallel(model)
        # download model from https://drive.google.com/file/d/1Le5UdpMkKOTRr1sTp4lwkw8263sbgdSe/view?usp=sharing
        checkpoint = torch.load(
            'checkpoint/mobilenet_224_model_best_gdconv_external.pth.tar',
            map_location=map_location)
        print('Use MobileNet as backbone')
    elif args.backbone == 'PFLD':
        model = PFLDInference()
        # download from https://drive.google.com/file/d/1zgQdcVuuHS73jiNmqPToOPDS9PjCl9cy/view?usp=sharing
        checkpoint = torch.load('checkpoint/pfld_model_best.pth.tar',
                                map_location=map_location)
        print('Use PFLD as backbone')
    elif args.backbone == 'MobileFaceNet':
        model = MobileFaceNet([112, 112], 136)
        # download from https://drive.google.com/file/d/1_tWbsAnnfmlKddsrxX85WMiY9H-BuxL-/view?usp=sharing
        checkpoint = torch.load('checkpoint/mobilefacenet_model_best.pth.tar',
                                map_location=map_location)
        print('Use MobileFaceNet as backbone')
    else:
        # Fail here with a clear message; otherwise `model` and `checkpoint`
        # would be unbound on the lines below.
        raise ValueError(f'not supported backbone: {args.backbone!r}')
    model.load_state_dict(checkpoint['state_dict'])
    return model
def load_model():
    """Build the landmark model selected by ``args.backbone`` and load its weights.

    Reads two module-level names: ``args`` (``args.backbone`` picks the
    architecture) and ``map_location`` (forwarded to ``torch.load``).
    Supported backbones are ``'MobileNet'``, ``'PFLD'`` and ``'MobileFaceNet'``;
    each one loads a fixed checkpoint file from the ``checkpoint/`` directory
    (download links are next to each branch).

    Returns:
        The model with the checkpoint's ``'state_dict'`` loaded. It is not
        switched to eval mode here.

    Raises:
        ValueError: if ``args.backbone`` is not a supported backbone name.
    """
    if args.backbone == 'MobileNet':
        # 136 outputs -- presumably 68 (x, y) landmark pairs.
        model = MobileNet_GDConv(136)
        # Wrap before loading: DataParallel prefixes state_dict keys with
        # 'module.', which this checkpoint presumably expects.
        model = torch.nn.DataParallel(model)
        # download model from https://drive.google.com/file/d/1Le5UdpMkKOTRr1sTp4lwkw8263sbgdSe/view?usp=sharing
        checkpoint = torch.load(
            'checkpoint/mobilenet_224_model_best_gdconv_external.pth.tar',
            map_location=map_location)
        print('Use MobileNet as backbone')
    elif args.backbone == 'PFLD':
        model = PFLDInference()
        # download from https://drive.google.com/file/d/1gjgtm6qaBQJ_EY7lQfQj3EuMJCVg9lVu/view?usp=sharing
        checkpoint = torch.load('checkpoint/pfld_model_best.pth.tar',
                                map_location=map_location)
        print('Use PFLD as backbone')
    elif args.backbone == 'MobileFaceNet':
        model = MobileFaceNet([112, 112], 136)
        # download from https://drive.google.com/file/d/1T8J73UTcB25BEJ_ObAJczCkyGKW5VaeY/view?usp=sharing
        checkpoint = torch.load('checkpoint/mobilefacenet_model_best.pth.tar',
                                map_location=map_location)
        print('Use MobileFaceNet as backbone')
    else:
        # Fail here with a clear message; otherwise `model` and `checkpoint`
        # would be unbound on the lines below.
        raise ValueError(f'not supported backbone: {args.backbone!r}')
    model.load_state_dict(checkpoint['state_dict'])
    return model
Пример #3
0
	def __init__(self, model = 'PLFD', checkpoint = os.path.join(BASE_DIR, 'thirdparty', \
			'pytorch_face_landmark','checkpoint','pfld_model_best.pth.tar'), \
			detector = 'MTCNN'):
		"""Load a pytorch_face_landmark landmark model plus a face detector.

		Args:
			model: landmark backbone: 'MobileNet', 'MobileFaceNet' or 'PLFD'
				('PFLD' is also accepted as an alias of 'PLFD').
			checkpoint: path to a torch checkpoint with a 'state_dict' entry
				matching the chosen backbone. NOTE: the default points at the
				PFLD checkpoint, so pass an explicit path for other backbones.
			detector: 'MTCNN', 'FaceBoxes' or 'Retinaface'; any other value
				falls back to dlib's frontal face detector.

		Sets:
			self.model: the landmark network, weights loaded, in eval mode.
			self.size: 224 for 'MobileNet', 112 otherwise (presumably the
				square input resolution the backbone expects).
			self.detect_fun: callable taking an image and returning the chosen
				detector's faces. The dlib path converts with
				cv2.COLOR_BGR2GRAY, so images are expected in BGR order.

		Raises:
			ValueError: if `model` is not a supported backbone name.
		"""
		if model not in ('MobileNet', 'MobileFaceNet', 'PLFD', 'PFLD'):
			raise ValueError('unsupported landmark model: {!r}'.format(model))
		import torch
		landmark_root = os.path.join(BASE_DIR, 'thirdparty', 'pytorch_face_landmark')
		# The `models.*` imports below resolve against this directory. Append it
		# only once so repeated construction does not keep growing sys.path.
		if landmark_root not in sys.path:
			sys.path.append(landmark_root)
		if model == 'MobileNet':
			from models.basenet import MobileNet_GDConv
			# Wrap before loading: DataParallel prefixes state_dict keys with
			# 'module.', which this checkpoint presumably expects.
			self.model = torch.nn.DataParallel(MobileNet_GDConv(136))
			self.size = 224
		elif model == 'MobileFaceNet':
			from models.mobilefacenet import MobileFaceNet
			self.model = MobileFaceNet([112, 112], 136)
			self.size = 112
		else:  # 'PLFD' or its alias 'PFLD'
			from models.pfld_compressed import PFLDInference
			self.model = PFLDInference()
			self.size = 112
		# The network was just built on the CPU, so load the tensors there too;
		# this also lets checkpoints saved on a GPU load on machines without CUDA.
		state = torch.load(checkpoint, map_location='cpu')
		self.model.load_state_dict(state['state_dict'])
		self.model.eval()
		if detector == 'MTCNN':
			from MTCNN import detect_faces
			# Reverse the channel axis (presumably BGR -> RGB) for MTCNN.
			self.detect_fun = lambda x: detect_faces(x[:,:,::-1])
		elif detector == 'FaceBoxes':
			from FaceBoxes import FaceBoxes
			self.detector = FaceBoxes()
			# NOTE(review): `face_boxex` looks like a typo of "face_boxes";
			# confirm against the FaceBoxes API before renaming.
			self.detect_fun = lambda x: self.detector.face_boxex(x)
		elif detector == 'Retinaface':
			from Retinaface import Retinaface
			self.detector = Retinaface.Retinaface()
			self.detect_fun = lambda x: self.detector(x)
		else:
			# Fallback for any other detector name: dlib on a grayscale copy.
			import dlib
			self.detector = dlib.get_frontal_face_detector()
			self.detect_fun = lambda x: self.detector(cv2.cvtColor(x,cv2.COLOR_BGR2GRAY))
def load_model():
    """Create a bare MobileNet_GDConv landmark network and load ``args.checkpoint``.

    Reads two module-level names: ``args`` (``args.checkpoint`` is the path
    handed to ``torch.load``) and ``map_location`` (forwarded to ``torch.load``).

    Returns:
        The network with the checkpoint's ``'state_dict'`` loaded; it is not
        switched to eval mode here.
    """
    # The network is deliberately NOT wrapped in torch.nn.DataParallel, so the
    # checkpoint's state_dict keys must not carry DataParallel's 'module.' prefix.
    net = MobileNet_GDConv(136)
    weights = torch.load(args.checkpoint, map_location=map_location)['state_dict']
    net.load_state_dict(weights)
    return net