/
roialign.py
60 lines (46 loc) · 2.24 KB
/
roialign.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import backend as K
import utils
class RoiAlignLayer(layers.Layer):
    """Crop and resize a fixed-size feature patch for every ROI.

    Each ROI is assigned to one of four feature maps (levels 2 to 5)
    according to its size, the matching feature map is sampled with
    ``tf.image.crop_and_resize``, and the pooled patches are returned in the
    same order as the incoming ROIs.

    Args:
        pool_shape: (height, width) of the pooled patch produced per ROI.
        image_shape: Shape of the input image; only the first two entries
            are read, to compute the image area.
        **kwargs: Forwarded to ``layers.Layer``.

    Call inputs (a list):
        input[0]: ROIs, rank-3 with a last dimension of 4 that is split into
            (y1, x1, y2, x2). NOTE(review): the coordinates are presumably
            normalized to [0, 1], which is what ``tf.image.crop_and_resize``
            expects - confirm against the caller.
        input[1:]: Feature maps, indexed 0..3 for levels 2..5.

    Call output:
        Tensor of shape (batch, num_rois, pool_height, pool_width, channels).
    """

    def __init__(self, pool_shape, image_shape, **kwargs):
        super(RoiAlignLayer, self).__init__(**kwargs)
        self.pool_shape = tuple(pool_shape)
        self.image_shape = tuple(image_shape)

    def call(self, input):
        """Pool one patch per ROI from the feature map matching its size.

        The parameter keeps the name ``input`` (although it shadows the
        builtin) so the method signature stays unchanged.
        """
        rois = input[0]
        features = input[1:]

        # Pick a feature level for every ROI from its size relative to the
        # image area; the result is clamped to the levels 2..5.
        y1, x1, y2, x2 = tf.split(rois, 4, axis=2)
        h = y2 - y1
        w = x2 - x1
        area = tf.cast(self.image_shape[0] * self.image_shape[1], tf.float32)
        spec = utils.log2(tf.sqrt(h * w) / (224.0 / tf.sqrt(area)))
        roi_level = tf.minimum(
            5, tf.maximum(2, 4 + tf.cast(tf.round(spec), tf.int32)))
        roi_level = tf.squeeze(roi_level, 2)

        pooled = []
        roi_to_level = []
        for i, level in enumerate(range(2, 6)):
            # (batch index, roi index) of every ROI assigned to this level.
            ix = tf.where(tf.equal(roi_level, level))
            roi_with_level = tf.gather_nd(rois, ix)
            # crop_and_resize needs the batch index of each box.
            roi_ids = tf.cast(ix[:, 0], tf.int32)
            # Remember which ROIs this level handled so that the original
            # order can be restored after concatenation.
            roi_to_level.append(ix)

            # Do not back-propagate through the box coordinates / indices.
            roi_with_level = tf.stop_gradient(roi_with_level)
            roi_ids = tf.stop_gradient(roi_ids)

            pooled.append(tf.image.crop_and_resize(
                features[i], roi_with_level, roi_ids, self.pool_shape,
                method="bilinear"))

        # Flat (total_rois, pool_h, pool_w, channels), grouped by level.
        pooled = tf.concat(pooled, axis=0)

        # Append each row's position inside `pooled` as a third column:
        # (batch index, roi index, position in pooled).
        roi_to_level = tf.concat(roi_to_level, axis=0)
        roi_range = tf.expand_dims(tf.range(tf.shape(roi_to_level)[0]), 1)
        roi_to_level = tf.concat(
            [tf.cast(roi_to_level, tf.int32), roi_range], axis=1)

        # Rearrange pooled features to match the order of the original boxes:
        # sort by batch index, then by roi index. The combined key is only
        # collision-free while there are fewer than 100000 ROIs per image.
        sorting_ts = roi_to_level[:, 0] * 100000 + roi_to_level[:, 1]
        # top_k sorts in descending order; reverse it to get ascending order.
        ix = tf.nn.top_k(
            sorting_ts, k=tf.shape(roi_to_level)[0]).indices[::-1]
        ix = tf.gather(roi_to_level[:, 2], ix)
        pooled = tf.gather(pooled, ix)

        # Restore the (batch, num_rois) leading dimensions. A plain
        # expand_dims(pooled, 0) would be correct only for batch size 1 and
        # would disagree with compute_output_shape for larger batches.
        out_shape = tf.concat([tf.shape(rois)[:2], tf.shape(pooled)[1:]], axis=0)
        pooled = tf.reshape(pooled, out_shape)
        return pooled

    def compute_output_shape(self, input_shape):
        """Return (batch, num_rois) + pool_shape + (channels,).

        The channel count is taken from the first feature map's shape.
        """
        # tuple() makes the concatenation work whether the ROI shape arrives
        # as a tuple, a list, or a TensorShape.
        return (tuple(input_shape[0][:2]) + self.pool_shape
                + (input_shape[1][-1],))

    def get_config(self):
        """Return the layer config, including the constructor arguments."""
        config = super(RoiAlignLayer, self).get_config()
        config.update({
            "pool_shape": list(self.pool_shape),
            "image_shape": list(self.image_shape),
        })
        return config