Exemplo n.º 1
0
class F3Net(nn.Module):
    """Xception-based classifier with optional FAD / LFS input heads.

    ``mode`` selects which branches are built in ``__init__`` and executed in
    ``forward``:

    * ``'FAD'``      -- ``FAD_Head`` followed by an Xception whose first
      convolution takes 12 input channels (stride 2).
    * ``'LFS'``      -- ``LFS_Head`` followed by an Xception whose first
      convolution takes ``LFS_M`` input channels (stride 1).
    * ``'Both'``     -- both branches; their pooled features are concatenated
      along dim 1 before the classifier.
    * ``'Original'`` -- a plain Xception applied directly to the input.

    The classifier input width is 4096 for ``'Both'``/``'Mix'`` and 2048
    otherwise (presumably the channel count of ``Xception.features`` --
    confirm against the Xception implementation).

    NOTE: ``'Mix'`` is accepted by the constructor (the classifier is sized
    for it) but no branch is built for it here, so ``forward`` raises
    ``ValueError`` in that mode.

    Args:
        num_classes: Output size of the final linear layer; also passed to
            every ``Xception`` constructor.
        img_width: Input width; must equal ``img_height``.
        img_height: Input height; must equal ``img_width``.
        LFS_window_size: Window size forwarded to ``LFS_Head``.
        LFS_stride: Accepted for interface compatibility; not used here.
        LFS_M: Forwarded to ``LFS_Head`` and used as the input channel count
            of the LFS backbone's first convolution.
        mode: One of the modes listed above.
        device: Accepted for interface compatibility; not used here.
    """

    # Modes for which ``forward`` has an implementation.
    _SUPPORTED_MODES = ('FAD', 'LFS', 'Both', 'Original')

    def __init__(self,
                 num_classes=1,
                 img_width=299,
                 img_height=299,
                 LFS_window_size=10,
                 LFS_stride=2,
                 LFS_M=6,
                 mode='FAD',
                 device=None):
        super(F3Net, self).__init__()
        assert img_width == img_height, "F3Net expects square inputs"
        img_size = img_width
        self.num_classes = num_classes
        self.mode = mode
        self.window_size = LFS_window_size
        self._LFS_M = LFS_M

        # init branches
        if mode in ('FAD', 'Both'):
            self.FAD_head = FAD_Head(img_size)
            self.init_xcep_FAD()

        if mode in ('LFS', 'Both'):
            self.LFS_head = LFS_Head(img_size, LFS_window_size, LFS_M)
            self.init_xcep_LFS()

        if mode == 'Original':
            self.init_xcep()

        # classifier
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(
            4096 if self.mode in ('Both', 'Mix') else 2048,
            num_classes)
        self.dp = nn.Dropout(p=0.2)

    def _build_xcep(self, in_channels=None, stride=2):
        """Create an Xception loaded (non-strictly) with pretrained weights.

        To get a good performance, using ImageNet-pretrained Xception model
        is recommended.

        Args:
            in_channels: When ``None`` the backbone is returned as loaded.
                Otherwise ``conv1`` is replaced by a fresh
                ``nn.Conv2d(in_channels, 32, 3, stride, 0, bias=False)`` and
                seeded from the pretrained ``conv1`` kernel.
            stride: Stride of the replacement ``conv1``.

        Returns:
            The configured ``Xception`` instance.
        """
        backbone = Xception(self.num_classes)
        state_dict = get_xcep_state_dict()

        if in_channels is None:
            backbone.load_state_dict(state_dict, False)
            return backbone

        pretrained_conv1 = state_dict['conv1.weight'].data
        backbone.load_state_dict(state_dict, False)

        # Let the new conv1 reuse the old parameters to balance the network:
        # every complete group of 3 input channels receives the pretrained
        # kernel divided by (in_channels / 3). Channels beyond the last
        # complete group keep Conv2d's default initialisation.
        backbone.conv1 = nn.Conv2d(in_channels, 32, 3, stride, 0, bias=False)
        scale = float(in_channels / 3.0)
        with torch.no_grad():
            for group in range(int(in_channels / 3)):
                first = group * 3
                backbone.conv1.weight[:, first:first + 3, :, :] = (
                    pretrained_conv1 / scale)
        return backbone

    def init_xcep_FAD(self):
        """Build ``self.FAD_xcep``: 12-channel, stride-2 first convolution."""
        self.FAD_xcep = self._build_xcep(in_channels=12, stride=2)

    def init_xcep_LFS(self):
        """Build ``self.LFS_xcep``: ``LFS_M``-channel, stride-1 first conv."""
        self.LFS_xcep = self._build_xcep(in_channels=self._LFS_M, stride=1)

    def init_xcep(self):
        """Build ``self.xcep``: unmodified Xception with pretrained weights."""
        self.xcep = self._build_xcep()

    def _run_branch(self, head, backbone, x):
        """Apply ``head``, then ``backbone.features``, then pool/flatten."""
        return self._norm_fea(backbone.features(head(x)))

    def forward(self, x):
        """Run the branch(es) selected by ``self.mode`` and classify.

        Returns:
            A tuple ``(y, f)`` where ``y`` is the pooled feature tensor
            (concatenated over both branches in ``'Both'`` mode) and ``f`` is
            the output of dropout followed by the linear classifier on ``y``.

        Raises:
            ValueError: If ``self.mode`` has no forward implementation
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if self.mode == 'FAD':
            y = self._run_branch(self.FAD_head, self.FAD_xcep, x)
        elif self.mode == 'LFS':
            y = self._run_branch(self.LFS_head, self.LFS_xcep, x)
        elif self.mode == 'Original':
            y = self._norm_fea(self.xcep.features(x))
        elif self.mode == 'Both':
            fea_FAD = self._run_branch(self.FAD_head, self.FAD_xcep, x)
            fea_LFS = self._run_branch(self.LFS_head, self.LFS_xcep, x)
            y = torch.cat((fea_FAD, fea_LFS), dim=1)
        else:
            raise ValueError(
                "F3Net.forward does not support mode {!r}; expected one of "
                "{}".format(self.mode, self._SUPPORTED_MODES))

        f = self.dp(y)
        f = self.fc(f)
        return y, f

    def _norm_fea(self, fea):
        """ReLU, global-average-pool to 1x1, and flatten to (batch, -1)."""
        f = self.relu(fea)
        f = F.adaptive_avg_pool2d(f, (1, 1))
        f = f.view(f.size(0), -1)
        return f
Exemplo n.º 2
0
import torch
import base64
from io import BufferedReader, BytesIO
from skimage import io, color
import numpy as np
from PIL import Image
from flask import Flask, jsonify, request
from werkzeug.utils import secure_filename
from xception import Xception

# Module-level setup: create the Flask application and load the Xception
# checkpoint once at import time.
app = Flask(__name__)
# The model is loaded onto the CPU.
device = torch.device('cpu')
model = Xception()
# Checkpoint path is relative to the process working directory.
ckpt_dir = 'log_path/Xception_trained_model.pth'
# map_location remaps the stored tensors onto `device` while loading.
checkpoint = torch.load(ckpt_dir, map_location=device)
# The checkpoint is indexed by 'model_state_dict' to get the model weights.
model.load_state_dict(checkpoint['model_state_dict'])
# Put the model in evaluation mode.
model.eval()
# Module-level message string, initially empty.
# NOTE(review): its consumers are not visible in this part of the file.
message = ''


def crop_image(file):
    detector_ori = dlib.get_frontal_face_detector()
    # open the image file
    try:
        img = io.imread(file)
    except Exception as e:
        message = "While processing, " + str(e)
        return message

    # If the resolution is less than 128x128 then skip
    img_height = img.shape[0]