forked from DingXiaoH/Centripetal-SGD
/
tfm_builder_densenet.py
121 lines (96 loc) · 4.52 KB
/
tfm_builder_densenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
from tensorflow.python.layers.convolutional import conv2d
import numpy as np
from tensorflow.python.layers.core import dense
from tensorflow.python.layers.core import dropout
from tensorflow.python.layers.pooling import max_pooling2d, average_pooling2d
from tensorflow.contrib.layers import flatten, batch_norm
import tensorflow as tf
class ModelBuilder(object):
    """Base class bundling the layer helpers shared by concrete model builders.

    Subclasses are expected to override ``build``. The batch-norm helper also
    honours three optional instance attributes when a subclass (or caller)
    sets them: ``bn_decay``, ``need_gamma`` and ``need_beta``.
    """

    def __init__(self, training):
        # Forwarded as the training flag to dropout and batch norm.
        self.training = training

    def build(self, img_input):
        """Placeholder meant to be overridden; the base version returns None."""
        return None

    def _relu(self, bottom):
        """Return ``tf.nn.relu`` applied to ``bottom``."""
        activated = tf.nn.relu(bottom)
        return activated

    def _flatten(self, bottom):
        """Return ``bottom`` passed through the contrib ``flatten`` layer."""
        flat = flatten(bottom)
        return flat

    def _dropout(self, bottom, drop_rate):
        """Apply dropout with rate ``drop_rate``, gated by ``self.training``."""
        return dropout(bottom, training=self.training, rate=drop_rate)

    def _batch_norm_default(self, bottom, scope, eps=1e-3, center=True, scale=True):
        """Batch-normalise ``bottom`` under variable scope ``scope``.

        Instance attributes take precedence over the arguments when present:
        ``bn_decay`` replaces the default decay of 0.9, ``need_gamma``
        replaces ``scale`` and ``need_beta`` replaces ``center``.
        No activation is applied.
        """
        momentum = getattr(self, 'bn_decay', 0.9)
        use_gamma = getattr(self, 'need_gamma', scale)
        use_beta = getattr(self, 'need_beta', center)
        return batch_norm(
            inputs=bottom,
            decay=momentum,
            center=use_beta,
            scale=use_gamma,
            activation_fn=None,
            is_training=self.training,
            scope=scope,
            epsilon=eps,
        )

    def _maxpool(self, bottom, stride):
        """Max-pool with a ``stride`` x ``stride`` window and the same stride."""
        return max_pooling2d(bottom, pool_size=[stride, stride], strides=[stride, stride])

    def _avgpool(self, bottom, stride):
        """Average-pool with a ``stride`` x ``stride`` window and the same stride."""
        return average_pooling2d(bottom, pool_size=[stride, stride], strides=[stride, stride])

    def _gap(self, bottom):
        """Average-pool using a window equal to dimensions 1 and 2 of ``bottom``.

        The window is read from the static shape, so those two dimensions
        are presumably height and width of an NHWC tensor -- TODO confirm.
        """
        static_shape = bottom.get_shape()
        rows, cols = static_shape[1], static_shape[2]
        return average_pooling2d(bottom, pool_size=[rows, cols], strides=[rows, cols])

    def _xavier_initializer(self):
        """Return a fresh ``tf.contrib.layers.xavier_initializer()``."""
        initializer = tf.contrib.layers.xavier_initializer()
        return initializer
class DC40Builder(ModelBuilder):
    """Builder for a three-stage densely connected classifier with per-layer widths.

    Every convolution takes its number of filters from ``deps``. With
    ``N`` layers per stage the indexing is:

    * ``deps[0]``                    -- stem convolution ``conv0``;
    * ``deps[(s - 1) * N + s + i]``  -- 3x3 conv of block ``i`` in stage ``s``
      (``s`` = 1..3, ``i`` = 0..N-1);
    * ``deps[s * (N + 1)]``          -- 1x1 conv of the transition that follows
      stage ``s`` (``s`` = 1, 2).

    Hence ``deps`` needs at least ``3 * N + 3`` entries.

    Args:
        training: training flag forwarded to the base class (used by batch norm).
        deps: indexable sequence of filter counts, laid out as described above.
        layers_per_stage: number of dense layers in each stage; stored as
            ``self.N``. Defaults to 12, the previously hard-coded value.
        num_classes: width of the final dense layer. Defaults to 10, the
            previously hard-coded value.
    """

    # Number of dense stages; a transition follows every stage except the last.
    _NUM_STAGES = 3

    def __init__(self, training, deps, layers_per_stage=12, num_classes=10):
        super(DC40Builder, self).__init__(training)
        self.deps = deps
        self.N = layers_per_stage
        self.num_classes = num_classes

    def build(self, img_input):
        """Assemble the network on top of ``img_input`` and return the logits.

        Raises:
            IndexError: if ``deps`` has fewer than ``3 * N + 3`` entries. The
                check runs before any op is created, so a bad ``deps`` no
                longer leaves a half-built graph behind.
        """
        required = self._NUM_STAGES * (self.N + 1)
        if len(self.deps) < required:
            raise IndexError(
                'deps must provide at least {} filter counts for N={}, got {}'.format(
                    required, self.N, len(self.deps)))

        def conv3x3(inputs, filters, stride, name):
            # Normal init with stddev = sqrt(2 / (9 * filters)): the 3x3 kernel
            # area times the number of output channels.
            init = tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / filters))
            return conv2d(inputs, filters, [3, 3], strides=[stride, stride], name=name,
                          padding='same', activation=None, use_bias=False,
                          kernel_initializer=init)

        def add_layer(name, inputs, cur_layer_idx):
            # BN -> ReLU -> 3x3 conv, then concatenate the new features onto
            # the layer input along axis 3.
            with tf.variable_scope(name):
                c = self._batch_norm_default(inputs, name)
                c = tf.nn.relu(c)
                c = conv3x3(c, self.deps[cur_layer_idx], 1, 'conv1')
                return tf.concat([inputs, c], 3)

        def add_transition(name, inputs, nb_filters):
            # BN -> ReLU -> 1x1 conv (library-default initializer) -> ReLU
            # -> 2x2 average pooling.
            with tf.variable_scope(name):
                t = self._batch_norm_default(inputs, name)
                t = tf.nn.relu(t)
                t = conv2d(t, nb_filters, [1, 1], strides=[1, 1], name='conv1',
                           padding='same', activation=None, use_bias=False)
                t = tf.nn.relu(t)
                return self._avgpool(t, 2)

        net = conv3x3(img_input, self.deps[0], 1, 'conv0')

        for stage in range(1, self._NUM_STAGES + 1):
            # Each stage consumes N consecutive deps entries; one extra entry
            # per preceding stage is taken by the stem / transition convs.
            first_idx = (stage - 1) * self.N + stage
            with tf.variable_scope('stage{}'.format(stage)):
                for i in range(self.N):
                    net = add_layer('block{}'.format(i), net, first_idx + i)
                # NOTE(review): the transition is created inside its stage's
                # variable scope; confirm this matches the variable names of
                # any existing checkpoints.
                if stage < self._NUM_STAGES:
                    net = add_transition('transition{}'.format(stage), net,
                                         nb_filters=self.deps[stage * (self.N + 1)])

        # NOTE(review): the head is built outside the stage scopes -- confirm
        # against existing checkpoints as well.
        net = self._batch_norm_default(net, scope='bnlast')
        net = tf.nn.relu(net)
        net = self._gap(net)
        net = self._flatten(net)
        # The layer keeps the name 'fc10' whatever num_classes is, so variable
        # names stay stable.
        logits = dense(net, self.num_classes, activation=None, use_bias=True,
                       kernel_initializer=self._xavier_initializer(), name='fc10')
        return logits