class F3Net(nn.Module):
    """Xception-based classifier with optional FAD / LFS input heads.

    ``mode`` selects which branches are built and evaluated:

    * ``'FAD'``      -- ``FAD_Head`` followed by an Xception backbone.
    * ``'LFS'``      -- ``LFS_Head`` followed by an Xception backbone.
    * ``'Both'``     -- both branches; their pooled features are concatenated
      (FAD first, then LFS) along dim 1.
    * ``'Original'`` -- a plain Xception backbone applied to the raw input.

    ``'Mix'`` is only recognised when sizing the final linear layer (4096
    inputs, like ``'Both'``); this class builds no branch for it and
    ``forward`` rejects it.

    Args:
        num_classes: Output size of the final linear layer; also passed to
            every ``Xception`` constructor.
        img_width: Input width; must equal ``img_height``.
        img_height: Input height; must equal ``img_width``.
        LFS_window_size: Forwarded to ``LFS_Head`` and stored as
            ``self.window_size``.
        LFS_stride: Accepted for interface compatibility; not used here.
        LFS_M: Forwarded to ``LFS_Head``; also the number of input channels
            of the LFS backbone's first convolution.
        mode: One of the modes listed above.
        device: Accepted for interface compatibility; not used here.
    """

    # Modes for which forward() has an evaluation path.
    _FORWARD_MODES = ('FAD', 'LFS', 'Both', 'Original')

    def __init__(self, num_classes=1, img_width=299, img_height=299,
                 LFS_window_size=10, LFS_stride=2, LFS_M=6, mode='FAD',
                 device=None):
        super(F3Net, self).__init__()
        assert img_width == img_height
        img_size = img_width
        self.num_classes = num_classes
        self.mode = mode
        self.window_size = LFS_window_size
        self._LFS_M = LFS_M

        # init branches
        if mode in ('FAD', 'Both'):
            self.FAD_head = FAD_Head(img_size)
            self.init_xcep_FAD()

        if mode in ('LFS', 'Both'):
            self.LFS_head = LFS_Head(img_size, LFS_window_size, LFS_M)
            self.init_xcep_LFS()

        if mode == 'Original':
            self.init_xcep()

        # classifier
        self.relu = nn.ReLU(inplace=True)
        # Two concatenated branches feed 4096 features into the classifier,
        # a single branch feeds 2048.
        fc_in_features = 4096 if self.mode in ('Both', 'Mix') else 2048
        self.fc = nn.Linear(fc_in_features, num_classes)
        self.dp = nn.Dropout(p=0.2)

    def _load_xcep(self):
        """Build an Xception and load the pretrained weights non-strictly.

        Returns:
            Tuple ``(xcep, state_dict)`` so callers can reuse the pretrained
            tensors (e.g. ``conv1.weight``) after replacing layers.
        """
        xcep = Xception(self.num_classes)
        # To get a good performance, using ImageNet-pretrained Xception model
        # is recommended
        state_dict = get_xcep_state_dict()
        # strict=False: keys that do not match the model are tolerated.
        xcep.load_state_dict(state_dict, False)
        return xcep, state_dict

    def init_xcep_FAD(self):
        """Create ``self.FAD_xcep`` with a 12-channel first convolution."""
        self.FAD_xcep, state_dict = self._load_xcep()
        conv1_data = state_dict['conv1.weight'].data

        # copy on conv1
        # let new conv1 use old param to balance the network
        self.FAD_xcep.conv1 = nn.Conv2d(12, 32, 3, 2, 0, bias=False)
        # Each of the 4 groups of 3 input channels receives the pretrained
        # kernel scaled by 1/4.
        for i in range(4):
            self.FAD_xcep.conv1.weight.data[:, i * 3:(i + 1) * 3, :, :] = \
                conv1_data / 4.0

    def init_xcep_LFS(self):
        """Create ``self.LFS_xcep`` with an ``LFS_M``-channel first convolution."""
        self.LFS_xcep, state_dict = self._load_xcep()
        conv1_data = state_dict['conv1.weight'].data

        # copy on conv1
        # let new conv1 use old param to balance the network
        self.LFS_xcep.conv1 = nn.Conv2d(self._LFS_M, 32, 3, 1, 0, bias=False)
        # Only complete groups of 3 input channels are overwritten; when
        # LFS_M is not a multiple of 3 the trailing channels keep the default
        # Conv2d initialisation.
        group_count = int(self._LFS_M / 3)
        scale = float(self._LFS_M / 3.0)
        for i in range(group_count):
            self.LFS_xcep.conv1.weight.data[:, i * 3:(i + 1) * 3, :, :] = \
                conv1_data / scale

    def init_xcep(self):
        """Create ``self.xcep``, an unmodified Xception with pretrained weights."""
        self.xcep, _ = self._load_xcep()

    def forward(self, x):
        """Run the branch(es) selected by ``self.mode`` and classify.

        Args:
            x: Input batch handed to the head(s) / backbone.

        Returns:
            Tuple ``(y, f)``: ``y`` is the pooled, flattened feature tensor
            (before dropout) and ``f`` is the output of dropout followed by
            the linear classifier.

        Raises:
            ValueError: If ``self.mode`` has no evaluation path in this class.
        """
        # Fail early with a clear message instead of hitting an unbound local
        # further down.
        if self.mode not in self._FORWARD_MODES:
            raise ValueError(
                "F3Net.forward does not support mode {!r}; expected one of "
                "{}".format(self.mode, ', '.join(self._FORWARD_MODES)))

        if self.mode == 'Original':
            y = self._norm_fea(self.xcep.features(x))
        else:
            branch_feas = []
            if self.mode in ('FAD', 'Both'):
                fea_FAD = self.FAD_head(x)
                fea_FAD = self.FAD_xcep.features(fea_FAD)
                branch_feas.append(self._norm_fea(fea_FAD))
            if self.mode in ('LFS', 'Both'):
                fea_LFS = self.LFS_head(x)
                fea_LFS = self.LFS_xcep.features(fea_LFS)
                branch_feas.append(self._norm_fea(fea_LFS))
            # A single branch is returned as-is; two branches are joined
            # FAD-first along the feature dimension.
            if len(branch_feas) == 1:
                y = branch_feas[0]
            else:
                y = torch.cat(branch_feas, dim=1)

        f = self.dp(y)
        f = self.fc(f)
        return y, f

    def _norm_fea(self, fea):
        """Apply ReLU, global average pooling to 1x1, and flatten per sample."""
        f = self.relu(fea)
        f = F.adaptive_avg_pool2d(f, (1, 1))
        f = f.view(f.size(0), -1)
        return f
import torch import base64 from io import BufferedReader, BytesIO from skimage import io, color import numpy as np from PIL import Image from flask import Flask, jsonify, request from werkzeug.utils import secure_filename from xception import Xception app = Flask(__name__) device = torch.device('cpu') model = Xception() ckpt_dir = 'log_path/Xception_trained_model.pth' checkpoint = torch.load(ckpt_dir, map_location=device) model.load_state_dict(checkpoint['model_state_dict']) model.eval() message = '' def crop_image(file): detector_ori = dlib.get_frontal_face_detector() # open the image file try: img = io.imread(file) except Exception as e: message = "While processing, " + str(e) return message # If the resolution is less than 128x128 then skip img_height = img.shape[0]