-
Notifications
You must be signed in to change notification settings - Fork 0
/
extract_back.py
117 lines (102 loc) · 3.36 KB
/
extract_back.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
print("Loading model")
import keras;
print(keras.__version__)
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
import sys
import keras
print(sys.version)
from PIL import Image
from scipy.misc.pilutil import imresize
import numpy as np
print("Loading model2")
from keras.models import load_model
print("Loading model3")
import tensorflow as tf
import io
import requests
import base64
import cv2
print("Loading model4")
MODEL_URL = 'https://gitlab.com/fast-science/background-removal-server/raw/master/webapp/model/main_model.hdf5'
MODEL_PATH = 'main_model.hdf5'
def download_model():
    """Download the model file to MODEL_PATH if it is not already present.

    Downloads into a temporary file and renames it into place so that a
    failed or cancelled download never leaves a truncated model file at
    MODEL_PATH. Exits the process with status 1 when the download fails.
    """
    if os.path.exists(MODEL_PATH):
        print("Model file is already downloaded.")
        return
    print("Model file is not available. downloading...")
    tmp_path = "{}.tmp".format(MODEL_PATH)
    # NOTE(review): shells out to wget; urllib.request.urlretrieve would
    # remove the external-tool dependency.
    exit_status = os.system("wget {} -O {}".format(MODEL_URL, tmp_path))
    if exit_status == 0:
        # os.replace is atomic on POSIX and avoids spawning a shell for `mv`.
        os.replace(tmp_path, MODEL_PATH)
    else:
        print("Failed to download the model file", file=sys.stderr)
        sys.exit(1)
def ml_predict(image):
    """Run the segmentation model on one preprocessed image.

    :param image: numpy array of shape (224, 224, 3) with values in [0, 1]
    :return: per-pixel class scores reshaped to (224, 224, n_classes)
    """
    with graph.as_default():
        # model.predict wants a batch dimension; prepend one of size 1.
        batch = image[None, :, :, :]
        raw = model.predict(batch)
    return raw.reshape((224, 224, -1))
# Probability cutoff for the person-class mask.
THRESHOLD = 0.5

def predict1(image):
    """Remove the background of the given image.

    :param image: numpy array of shape (H, W, 3) or (H, W, 4) (RGBA input)
    :return: RGBA numpy array of shape (H, W, 4): background pixels are
        zeroed in the RGB channels and the alpha channel carries the
        binary person mask scaled to 0/255.
    """
    print("predict1")
    height, width = image.shape[0], image.shape[1]
    # Model input shape = (224, 224, 3); scale pixel values to [0, 1].
    resized_image = imresize(image, (224, 224)) / 255.0
    # [0:3] keeps only the first 3 RGB channels and drops a possible
    # 4th alpha channel in case this is a PNG.
    prediction = ml_predict(resized_image[:, :, 0:3])
    print('PREDICTION COUNT', (prediction[:, :, 1] > 0.5).sum())
    # Channel 1 = predicted "person" class; channel 0 = background.
    # Resize the mask back to the original image size (imresize rescales
    # values to 0..255, hence the THRESHOLD * 255 comparison below).
    prediction = imresize(prediction[:, :, 1], (height, width))
    prediction[prediction < THRESHOLD * 255] = 0
    prediction[prediction >= THRESHOLD * 255] = 1
    # Mask each RGB channel with the binary person mask.
    res1 = prediction * image[:, :, 0]
    res2 = prediction * image[:, :, 1]
    res3 = prediction * image[:, :, 2]
    # BUGFIX: the original referenced an undefined name `res4` here, which
    # raised NameError on every call. Use the mask itself as the alpha
    # channel so background pixels become fully transparent.
    res4 = prediction * 255
    return np.dstack([res1, res2, res3, res4])
def read_image(image_spec):
    """Decode the image described by *image_spec*.

    The spec dict may carry inline base64 bytes under 'data' or a
    fetchable location under 'url'. Returns a PIL Image, or None when
    the spec is not a dict or has neither key.
    """
    if not isinstance(image_spec, dict):
        return None
    payload = None
    if 'data' in image_spec:
        payload = base64.b64decode(image_spec['data'])
    elif 'url' in image_spec:
        payload = requests.get(image_spec['url']).content
    if payload is None:
        return None
    return Image.open(fp=io.BytesIO(payload))
def write_image(image):
    """Encode an image as a base64 PNG response dict.

    :param image: numpy array or PIL Image (anything exposing .save(fp))
    :return: dict with the base64-encoded PNG bytes under 'data' and the
        MIME type under 'content-type'
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    fp = io.BytesIO()
    # BUGFIX: the original called image.save("fp.png", ...), writing a file
    # literally named "fp.png" to disk and leaving the in-memory buffer
    # empty — so the returned 'data' field was always an empty string.
    image.save(fp, format='png')
    return {
        "data": base64.b64encode(fp.getvalue()).decode('ascii'),
        "content-type": "image/png"
    }
def predict():
    """Run background removal on the bundled sample image.

    Reads "r7NEmMOIcjQ.jpg" from the working directory and returns the
    base64-PNG response dict produced by write_image.
    """
    print("predict")
    source = Image.open("r7NEmMOIcjQ.jpg")
    print(source)
    masked = predict1(np.array(source))
    return write_image(masked)
# Preload our model
# download_model()
print("Loading model5")
# compile=False loads weights/architecture only, skipping the training
# setup that inference does not need.
model = load_model(MODEL_PATH, compile=False)
print("Loading model6")
# Capture the default TF graph at load time; ml_predict re-enters it via
# graph.as_default() before each prediction (TF1/Keras pattern).
graph = tf.get_default_graph()
print("Loading model7")
# Runs one prediction at import time — reads "r7NEmMOIcjQ.jpg" from the
# working directory and will fail if that file is missing.
predict()
print("Loading model8")