resnet_cifar.py
# -*- coding:utf-8 -*-
"""
A reproduction of ResNet with identity mappings (pre-activation ResNet)
author: Tianz
"""
from keras.models import Model
from keras.layers import (Input, Dense, Dropout, Activation, Conv2D,
                          GlobalAveragePooling2D, BatchNormalization, add)
from keras.regularizers import l2
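# The network follows the pre-activation ResNet design named in the module
# docstring: BatchNorm and ReLU come before each convolution, and the
# shortcut path is left as an identity except where the feature-map shape
# changes between stages.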

def residual_block(x, nb_filters, strides=(1, 1),
                   dropout_rate=0., weight_decay=1E-4):
    """Two 3x3 convolutions with optional dropout in between; the
    pre-activation BN-ReLU for the first convolution is applied by
    the caller before entering the block."""
    x = Conv2D(nb_filters, (3, 3),
               kernel_initializer='he_normal',
               padding='same',
               strides=strides,
               use_bias=False,
               kernel_regularizer=l2(weight_decay))(x)
    if dropout_rate:
        x = Dropout(dropout_rate)(x)
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(x)
    x = Activation('relu')(x)
    x = Conv2D(nb_filters, (3, 3),
               kernel_initializer='he_normal',
               padding='same',
               use_bias=False,
               kernel_regularizer=l2(weight_decay))(x)
    return x

def create_ResNet(nb_classes, img_dim, nb_blocks=(4, 4, 4), k=1,
                  weight_decay=1E-4, droprate=0.):
    """
    :param nb_classes: number of classes in the dataset;
        for CIFAR-10, nb_classes should be 10
    :param img_dim: input shape of the model
    :param nb_blocks: number of residual blocks in each stage
    :param k: the widening factor; k=1 gives the original ResNet,
        while k>1 gives a wide ResNet (WRN)
    :param weight_decay: weight decay for L2 regularization
    :param droprate: dropout rate between the two convolutions of
        each block; defaults to 0.0
    :return: a ResNet or WRN Keras model
    """
    model_input = Input(shape=img_dim)
    stack = [16 * k, 32 * k, 64 * k]
    nb_filter = 16

    # Initial convolution
    y = Conv2D(nb_filter, (3, 3),
               kernel_initializer='he_normal',
               padding='same',
               use_bias=False,
               kernel_regularizer=l2(weight_decay))(model_input)
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(y)
    x = Activation('relu')(x)
    # stage 1
    x = residual_block(x, stack[0], dropout_rate=droprate)
    if stack[0] != 16:
        # Widen the shortcut with a 1x1 projection when k > 1
        y = Conv2D(stack[0], (1, 1),
                   kernel_initializer='he_normal',
                   padding='same',
                   use_bias=False,
                   kernel_regularizer=l2(weight_decay))(y)
    y = add([x, y])
    for _ in range(nb_blocks[0] - 1):
        x = BatchNormalization(axis=-1, epsilon=1.1e-5)(y)
        x = Activation('relu')(x)
        x = residual_block(x, stack[0], dropout_rate=droprate)
        y = add([x, y])
    # stage 2: downsample by 2 and double the width
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(y)
    y = Activation('relu')(x)
    x = residual_block(y, stack[1], strides=(2, 2), dropout_rate=droprate)
    # 1x1 strided projection so the shortcut matches the new shape
    y = Conv2D(stack[1], (1, 1), strides=(2, 2),
               kernel_initializer='he_normal',
               padding='valid',
               use_bias=False,
               kernel_regularizer=l2(weight_decay))(y)
    y = add([x, y])
    for _ in range(nb_blocks[1] - 1):
        x = BatchNormalization(axis=-1, epsilon=1.1e-5)(y)
        x = Activation('relu')(x)
        x = residual_block(x, stack[1], dropout_rate=droprate)
        y = add([x, y])
    # stage 3
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(y)
    y = Activation('relu')(x)
    x = residual_block(y, stack[2], strides=(2, 2), dropout_rate=droprate)
    y = Conv2D(stack[2], (1, 1), strides=(2, 2),
               kernel_initializer='he_normal',
               padding='valid',
               use_bias=False,
               kernel_regularizer=l2(weight_decay))(y)
    y = add([x, y])
    for _ in range(nb_blocks[2] - 1):
        x = BatchNormalization(axis=-1, epsilon=1.1e-5)(y)
        x = Activation('relu')(x)
        x = residual_block(x, stack[2], dropout_rate=droprate)
        y = add([x, y])

    # Classification head: BN-ReLU, global average pooling, softmax
    x = BatchNormalization(axis=-1, epsilon=1.1e-5)(y)
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)
    x = Dense(nb_classes,
              activation='softmax',
              kernel_regularizer=l2(weight_decay),
              bias_regularizer=l2(weight_decay))(x)
    model = Model(inputs=model_input, outputs=x)
    return model
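

if __name__ == '__main__':
    # Minimal usage sketch: build a widened variant for CIFAR-10-sized
    # inputs and compile it. The k, droprate, optimizer, and loss values
    # below are illustrative assumptions, not settings from this file.
    model = create_ResNet(nb_classes=10, img_dim=(32, 32, 3),
                          nb_blocks=(4, 4, 4), k=2, droprate=0.3)
    model.compile(optimizer='sgd',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()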