def create_preact_resnet(depth=200):
    '''Build a pre-activation ResNet classifier for 3x224x224 inputs.

    Every residual block is created with ``preact=True``, i.e. batchnorm and
    relu are moved in front of the convolutions. Because of that, one extra
    batchnorm + relu pair is appended after the last stage, before pooling.

    Args:
        depth: key into the module-level ``cfg`` table that gives the number
            of blocks per stage. Depths above 34 use ``bottleneck`` blocks,
            the others use ``basicblock``.

    Returns:
        the ``ffnet.FeedForwardNet``, ending in a 1000-unit Dense layer.
    '''
    model = ffnet.FeedForwardNet()

    # Stem: 7x7/2 conv -> bn -> relu -> 3x3/2 max-pool.
    model.add(Conv2D('input-conv', 64, 7, 2, pad=3, use_bias=False,
                     input_sample_shape=(3, 224, 224)))
    model.add(BatchNormalization('input-bn'))
    model.add(Activation('input_relu'))
    model.add(MaxPooling2D('input_pool', 3, 2, pad=1))

    blocks_per_stage = cfg[depth]
    if depth > 34:
        block_fn = bottleneck
        widths = [(64, 64, 256), (256, 128, 512),
                  (512, 256, 1024), (1024, 512, 2048)]
    else:
        block_fn = basicblock
        widths = [(64, 64, 64), (64, 128, 128),
                  (128, 256, 256), (256, 512, 512)]

    for stage_idx, (in_ch, mid_ch, out_ch) in enumerate(widths):
        # Only the first stage keeps the spatial resolution (stride 1).
        stride = 1 if stage_idx == 0 else 2
        stage(stage_idx, model, blocks_per_stage[stage_idx], in_ch, mid_ch,
              out_ch, stride, block_fn, preact=True)

    # Pre-activation blocks end without bn/relu, so normalise once more here.
    model.add(BatchNormalization('final-bn'))
    model.add(Activation('final-relu'))
    model.add(AvgPooling2D('avg', 7, 1, pad=0))
    model.add(Flatten('flat'))
    model.add(Dense('dense', 1000))
    return model
def add_dense_connected_layers(name, net, growth_rate):
    '''Append one DenseNet bottleneck layer: bn-relu-conv1x1-bn-relu-conv3x3.

    The 1x1 conv produces ``4 * growth_rate`` feature maps. The 3x3 conv
    (stride 1, pad 1) produces ``growth_rate`` maps. Both use the module-level
    ``conv_bias`` setting.

    Returns:
        the last layer added (the 3x3 convolution).
    '''
    # (layer index, number of filters, kernel size, padding) per sub-step.
    steps = ((1, 4 * growth_rate, 1, 0),
             (2, growth_rate, 3, 1))
    last_layer = None
    for idx, n_filters, ksize, padding in steps:
        net.add(BatchNormalization('%s/bn%d' % (name, idx)))
        net.add(Activation('%s/relu%d' % (name, idx)))
        last_layer = net.add(
            Conv2D('%s/conv%d' % (name, idx), n_filters, ksize, 1,
                   pad=padding, use_bias=conv_bias))
    return last_layer
def create_resnet(depth=18):
    '''Build the original (post-activation) ResNet for 3x224x224 inputs.

    Blocks are created with ``preact`` left at its default, so each block
    applies its relu after the residual addition.

    Args:
        depth: key into the module-level ``cfg`` table (blocks per stage).
            Depths above 34 use ``bottleneck`` blocks, the others use
            ``basicblock``.

    Returns:
        the ``ffnet.FeedForwardNet``, ending in a 1000-unit Dense layer.
    '''
    model = ffnet.FeedForwardNet()

    # Stem: 7x7/2 conv -> bn -> relu -> 3x3/2 max-pool.
    model.add(Conv2D('input-conv', 64, 7, 2, pad=3, use_bias=False,
                     input_sample_shape=(3, 224, 224)))
    model.add(BatchNormalization('input-bn'))
    model.add(Activation('input_relu'))
    model.add(MaxPooling2D('input_pool', 3, 2, pad=1))

    blocks_per_stage = cfg[depth]
    # (in_channels, mid_channels, out_channels, stride) for the 4 stages.
    if depth > 34:
        block_fn = bottleneck
        specs = ((64, 64, 256, 1), (256, 128, 512, 2),
                 (512, 256, 1024, 2), (1024, 512, 2048, 2))
    else:
        block_fn = basicblock
        specs = ((64, 64, 64, 1), (64, 128, 128, 2),
                 (128, 256, 256, 2), (256, 512, 512, 2))

    for stage_idx, (in_ch, mid_ch, out_ch, stride) in enumerate(specs):
        stage(stage_idx, model, blocks_per_stage[stage_idx], in_ch, mid_ch,
              out_ch, stride, block_fn)

    model.add(AvgPooling2D('avg', 7, 1, pad=0))
    model.add(Flatten('flat'))
    model.add(Dense('dense', 1000))
    return model
def conv2d(net, name, nb_filter, k, s=1, border_mode='SAME', src=None):
    '''Append a bias-free Conv2D followed by BatchNormalization and relu.

    Args:
        net: the net to add layers to.
        name: conv layer name; the bn/relu layers are named
            ``<name>/BatchNorm`` and ``<name>/relu``.
        nb_filter: number of output feature maps.
        k: kernel size. A list (e.g. ``[1, 7]``) is converted to a 2-tuple of
            its first two items; any other value is passed through unchanged.
        s: stride.
        border_mode: padding mode forwarded to Conv2D.
        src: optional source layer for the conv (defaults to the last layer).

    Returns:
        the relu layer, i.e. the last layer added.
    '''
    kernel = (k[0], k[1]) if type(k) is list else k
    conv_layer = Conv2D(name, nb_filter, kernel, s, border_mode=border_mode,
                        use_bias=False)
    net.add(conv_layer, src)
    net.add(BatchNormalization('%s/BatchNorm' % name))
    relu_layer = net.add(Activation(name + '/relu'))
    return relu_layer
def bottleneck(name, net, inplane, midplane, outplane, stride=1, preact=False,
               add_bn=False):
    '''Add a three-conv bottleneck residual block (a >= b <= c filters).

    Default layout::

        input - split - conv1-bn1-relu1 - conv2-bn2-relu2 - conv3-bn3
                      \\- shortcut (conv-bn or dummy) ---------------/ - add - relu

    Args:
        inplane: number of feature maps of the input.
        midplane: number of feature maps of the middle (1x1 and 3x3) convs.
        outplane: number of feature maps of the output.
        stride: stride of the middle 3x3 conv, also passed to the shortcut.
        preact: if true, put a bn + relu before conv1 (pre-activation, see the
            identity-mappings paper), drop bn3 and skip the final relu.
        add_bn: if true, move the last bn after the addition layer (used for
            resnet-50). Cannot be combined with ``preact``.

    Returns:
        the last added layer: the add, the add-bn, or the add-relu layer.
    '''
    assert not (
        preact and add_bn
    ), 'preact and batchnorm after addition cannot be true at the same time'

    split = net.add(Split(name + '-split', 2))
    if preact:
        net.add(BatchNormalization(name + '-preact-bn'))
        net.add(Activation(name + '-preact-relu'))

    # Residual branch: 1x1 reduce, 3x3 (strided), 1x1 expand (no relu).
    conv(net, name + '-0', midplane, 1, stride=1, pad=0, bn=True, relu=True)
    conv(net, name + '-1', midplane, 3, stride=stride, pad=1, bn=True,
         relu=True)
    bn_on_last_conv = not (preact or add_bn)
    residual = conv(net, name + '-2', outplane, 1, stride=1, pad=0,
                    bn=bn_on_last_conv, relu=False)

    # The shortcut keeps its bn unless that bn has been moved after the add.
    identity = shortcut(net, name, inplane, outplane, stride, split,
                        not add_bn)

    merged = net.add(Merge(name + '-add'), [residual, identity])
    if add_bn:
        merged = net.add(BatchNormalization(name + '-add-bn'))
    if preact:
        return merged
    return net.add(Activation(name + '-add-relu'))
def densenet_base(depth, growth_rate=32, reduction=0.5):
    '''Build a DenseNet feature extractor for 3x224x224 inputs.

    Rewritten to follow the pytorch (torchvision) DenseNet models. DenseNet-161
    is a special case that always uses a growth rate of 48.

    Args:
        depth: one of 121, 169, 201, 161. Any other value prints an error and
            calls ``sys.exit(-1)``.
        growth_rate: feature maps added per dense layer. Honoured for all
            depths except 161, which is forced to 48.
        reduction: compression factor applied by each intermediate transition
            (channels are floored to an int).

    Returns:
        the ``ffnet.FeedForwardNet``, ending after the final transition's
        bn-relu-global-pool-flatten (no classifier layer).
    '''
    stages_by_depth = {
        121: [6, 12, 24, 16],
        169: [6, 12, 32, 32],
        201: [6, 12, 48, 32],
        161: [6, 12, 36, 24],
    }
    if depth not in stages_by_depth:
        print('unknown depth: %d' % depth)
        sys.exit(-1)
    stages = stages_by_depth[depth]

    # Previously growth_rate was unconditionally overwritten (48/32), so the
    # argument was ignored. Only the 161 variant is defined with k=48.
    if depth == 161:
        growth_rate = 48

    net = ffnet.FeedForwardNet()
    n_channels = 2 * growth_rate
    net.add(
        Conv2D('input/conv', n_channels, 7, 2, pad=3, use_bias=conv_bias,
               input_sample_shape=(3, 224, 224)))
    net.add(BatchNormalization('input/bn'))
    net.add(Activation('input/relu'))
    net.add(MaxPooling2D('input/pool', 3, 2, pad=1))

    # Dense blocks 1-3, each followed by a compressing transition
    # (56x56 -> 28x28 -> 14x14 -> 7x7).
    for idx, n_layers in enumerate(stages[:-1], 1):
        n_channels = add_block('block%d' % idx, net, n_channels, n_layers,
                               growth_rate)
        # Compute the reduced width once, as an int, and reuse it for both
        # the transition conv and the next block's input width.
        n_channels = int(math.floor(n_channels * reduction))
        add_transition('trans%d' % idx, net, n_channels)

    # Dense block 4 (7x7), then the final bn-relu-global-pool-flatten.
    n_channels = add_block('block4', net, n_channels, stages[-1], growth_rate)
    add_transition('trans4', net, n_channels, True)
    return net
def add_transition(name, net, n_channels, last=False):
    '''Append a DenseNet transition: bn-relu, then either downsample or pool.

    Args:
        name: prefix for the added layer names.
        net: the net to add layers to.
        n_channels: output channels of the 1x1 conv (unused when ``last``).
        last: if true, finish the network instead: global average pooling over
            the relu output's spatial size, then a Flatten layer named 'flat'.
            Otherwise add a 1x1 conv followed by 2x2/2 average pooling.
    '''
    net.add(BatchNormalization('%s/norm' % name))
    relu_layer = net.add(Activation('%s/relu' % name))

    if not last:
        net.add(
            Conv2D('%s/conv' % name, n_channels, 1, 1, pad=0,
                   use_bias=conv_bias))
        net.add(AvgPooling2D('%s/pool' % name, 2, 2, pad=0))
        return

    # Kernel = spatial dims of the relu output; assumes the sample shape is
    # (C, H, W) so [1:3] is (H, W) -- TODO confirm.
    spatial_dims = relu_layer.get_output_sample_shape()[1:3]
    net.add(AvgPooling2D('%s/pool' % name, spatial_dims, pad=0))
    net.add(Flatten('flat'))
def conv(net, prefix, n, ksize, stride=1, pad=0, bn=True, relu=True,
         src=None):
    '''Add a convolution layer and optionally a batchnorm and relu layer.

    Args:
        prefix: prefix of the layer names ('-conv', '-bn', '-relu' appended).
        n: number of filters of the conv layer.
        ksize: kernel size.
        stride: conv stride.
        pad: conv padding.
        bn: if true, add a batchnorm after the conv.
        relu: if true, add a relu after the conv (and batchnorm, if any).
        src: optional source layer for the conv.

    Returns:
        the last added layer.
    '''
    last_layer = net.add(
        Conv2D(prefix + '-conv', n, ksize, stride, pad=pad,
               use_bias=conv_bias), src)
    # Optional followers, applied in order when their flag is set.
    followers = ((bn, BatchNormalization, '-bn'),
                 (relu, Activation, '-relu'))
    for enabled, layer_cls, suffix in followers:
        if enabled:
            last_layer = net.add(layer_cls(prefix + suffix))
    return last_layer
def create_layers(net, cfg, sample_shape, batch_norm=False):
    '''Append VGG-style feature layers described by ``cfg`` to ``net``.

    Each entry of ``cfg`` is either 'M' (2x2/2 max-pool) or an int, giving the
    number of filters of a 3x3 (pad 1) conv that is followed by an optional
    batchnorm and a relu. Every added layer consumes one running index, used
    in names of the form '<kind>/features.<index>'. Presumably this mirrors
    torchvision VGG ``features.<i>`` keys -- confirm before relying on it.

    Args:
        net: the net to add layers to.
        cfg: sequence of ints and 'M' markers.
        sample_shape: input sample shape, passed only to the first conv.
        batch_norm: if true, insert a batchnorm after every conv.

    Returns:
        ``net``.
    '''
    layer_idx = 0
    for spec in cfg:
        if spec == 'M':
            net.add(MaxPooling2D('pool/features.%d' % layer_idx, 2, 2, pad=0))
            layer_idx += 1
            continue

        net.add(
            Conv2D('conv/features.%d' % layer_idx, spec, 3, pad=1,
                   input_sample_shape=sample_shape))
        layer_idx += 1
        if batch_norm:
            net.add(BatchNormalization('bn/features.%d' % layer_idx))
            layer_idx += 1
        net.add(Activation('act/features.%d' % layer_idx))
        layer_idx += 1
        # Only the very first conv needs the explicit input shape.
        sample_shape = None
    return net
def create_wide_resnet(depth=50):
    '''Build a wide ResNet for 3x224x224 inputs.

    Similar to the original resnet, except that the bottleneck blocks use
    a <= b <= c filters (the middle width is not reduced).

    Args:
        depth: accepted for API symmetry but not used; the stage layout below
            is always (3, 4, 6, 3) bottleneck blocks.

    Returns:
        the ``ffnet.FeedForwardNet``, ending in a 1000-unit Dense layer.
    '''
    model = ffnet.FeedForwardNet()

    # Stem: 7x7/2 conv -> bn -> relu -> 3x3/2 max-pool.
    model.add(Conv2D('input-conv', 64, 7, 2, pad=3, use_bias=False,
                     input_sample_shape=(3, 224, 224)))
    model.add(BatchNormalization('input-bn'))
    model.add(Activation('input_relu'))
    model.add(MaxPooling2D('input_pool', 3, 2, pad=1))

    # (num_blocks, in_channels, mid_channels, out_channels, stride)
    stage_specs = ((3, 64, 128, 256, 1),
                   (4, 256, 256, 512, 2),
                   (6, 512, 512, 1024, 2),
                   (3, 1024, 1024, 2048, 2))
    for stage_idx, (n_blocks, in_ch, mid_ch, out_ch, stride) in enumerate(
            stage_specs):
        stage(stage_idx, model, n_blocks, in_ch, mid_ch, out_ch, stride,
              bottleneck)

    model.add(AvgPooling2D('avg_pool', 7, 1, pad=0))
    # NOTE(review): named 'flag' whereas other builders use 'flat'; kept as-is
    # in case the layer name is referenced elsewhere.
    model.add(Flatten('flag'))
    model.add(Dense('dense', 1000))
    return model
def inception_v3_base(name, sample_shape, final_endpoint, aux_endpoint,
                      depth_multiplier=1, min_depth=16):
    """Creates the Inception V3 network up to the given final endpoint.

    Layers are appended to a fresh ``ffnet.FeedForwardNet`` in a fixed order.
    Layer names follow the pattern 'InceptionV3/<block>/...'. The order of the
    ``net.add`` calls defines the graph, so do not reorder statements.

    Args:
        name: accepted but not used; the prefix is always 'InceptionV3'.
        sample_shape: input image sample shape, 3d tuple (only passed to the
            first conv layer).
        final_endpoint: name of the block to stop after. Only
            'InceptionV3/Conv2d_1a_3x3' and the 'InceptionV3/Mixed_*' blocks
            are checked, so the intermediate Conv2d_2a/2b/3b/4a layers cannot
            be used as endpoints.
        aux_endpoint: block name for the aux loss. When a checked block with
            this name is built, a 2-way Split named '<aux_endpoint>-aux' is
            appended after it and stored in ``end_points``. Mixed_7c is not
            checked.
        depth_multiplier: scales every channel count, via
            ``max(int(d * depth_multiplier), min_depth)``.
        min_depth: lower bound on every scaled channel count.

    Returns:
        net: the ffnet.FeedForwardNet built so far (no logits layer is added
            here).
        end_points: dict mapping block names (and the optional '-aux' split)
            to their output layer.

    Raises:
        AssertionError: if final_endpoint is not one of the checked blocks,
            i.e. construction reaches Mixed_7c without stopping and
            final_endpoint != 'InceptionV3/Mixed_7c'. The check is an
            ``assert``, so it is skipped under ``python -O``.
    """
    V3 = 'InceptionV3'
    end_points = {}
    net = ffnet.FeedForwardNet()

    def final_aux_check(block_name):
        # True -> caller should stop and return. As a side effect, attach the
        # aux Split right after the aux endpoint block.
        if block_name == final_endpoint:
            return True
        if block_name == aux_endpoint:
            aux = aux_endpoint + '-aux'
            end_points[aux] = net.add(Split(aux, 2))
        return False

    def depth(d):
        # Scale channel count d by depth_multiplier, but never below min_depth.
        return max(int(d * depth_multiplier), min_depth)

    blk = V3 + '/Conv2d_1a_3x3'
    # 299 x 299 x 3
    net.add(
        Conv2D(blk, depth(32), 3, 2, border_mode='VALID', use_bias=False,
               input_sample_shape=sample_shape))
    net.add(BatchNormalization(blk + '/BatchNorm'))
    end_points[blk] = net.add(Activation(blk + '/relu'))
    if final_aux_check(blk):
        return net, end_points

    # 149 x 149 x 32
    conv2d(net, '%s/Conv2d_2a_3x3' % V3, depth(32), 3, border_mode='VALID')
    # 147 x 147 x 32
    conv2d(net, '%s/Conv2d_2b_3x3' % V3, depth(64), 3)
    # 147 x 147 x 64
    net.add(MaxPooling2D('%s/MaxPool_3a_3x3' % V3, 3, 2, border_mode='VALID'))
    # 73 x 73 x 64
    conv2d(net, '%s/Conv2d_3b_1x1' % V3, depth(80), 1, border_mode='VALID')
    # 73 x 73 x 80.
    conv2d(net, '%s/Conv2d_4a_3x3' % V3, depth(192), 3, border_mode='VALID')
    # 71 x 71 x 192.
    net.add(MaxPooling2D('%s/MaxPool_5a_3x3' % V3, 3, 2, border_mode='VALID'))
    # 35 x 35 x 192.

    # Mixed blocks: a Split fans the input out to parallel branches (a branch
    # without src= continues from the previously added layer), and a Concat
    # joins the branch outputs.
    blk = V3 + '/Mixed_5b'
    s = net.add(Split('%s/Split' % blk, 4))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(48), 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_5x5' % blk, depth(64), 5)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(96), 3)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % blk, depth(96), 3)
    net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(32), 1)
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
                              [br0, br1, br2, br3])
    if final_aux_check(blk):
        return net, end_points

    # mixed_1: 35 x 35 x 288.
    blk = V3 + '/Mixed_5c'
    s = net.add(Split('%s/Split' % blk, 4))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
    # NOTE(review): the irregular Branch_1 names ('Conv2d_0b_1x1',
    # 'Conv_1_0c_5x5') presumably mirror upstream checkpoint names -- check
    # weight loading before "fixing" them.
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x1' % blk, depth(48), 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv_1_0c_5x5' % blk, depth(64), 5)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(96), 3)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % blk, depth(96), 3)
    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1),
                  src=s)
    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(64), 1)
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
                              [br0, br1, br2, br3])
    if final_aux_check(blk):
        return net, end_points

    # mixed_2: 35 x 35 x 288.
    blk = V3 + '/Mixed_5d'
    s = net.add(Split('%s/Split' % blk, 4))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(48), 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_5x5' % blk, depth(64), 5)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(96), 3)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % blk, depth(96), 3)
    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(64), 1)
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
                              [br0, br1, br2, br3])
    if final_aux_check(blk):
        return net, end_points

    # mixed_3: 17 x 17 x 768.
    # Grid reduction: stride-2 VALID convs plus a stride-2 max-pool branch.
    blk = V3 + '/Mixed_6a'
    s = net.add(Split('%s/Split' % blk, 3))
    # NOTE(review): named '_1x1' but uses a 3x3 kernel; presumably kept to
    # match upstream checkpoint names.
    br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_1x1' % blk, depth(384), 3, 2,
                 border_mode='VALID', src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x3' % blk, depth(96), 3)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_1x1' % blk, depth(96), 3, 2,
                 border_mode='VALID')
    br2 = net.add(
        MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % blk, 3, 2,
                     border_mode='VALID'), s)
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1, br2])
    if final_aux_check(blk):
        return net, end_points

    # mixed4: 17 x 17 x 768.
    # Mixed_6b..6e use factorised 1x7 / 7x1 convolutions; conv2d converts the
    # list kernels to tuples.
    blk = V3 + '/Mixed_6b'
    s = net.add(Split('%s/Split' % blk, 4))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(128), 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(128), [1, 7])
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(128), [1, 1],
                 src=s)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(128), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(128), [1, 7])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(128), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7])
    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
                              [br0, br1, br2, br3])
    if final_aux_check(blk):
        return net, end_points

    # mixed_5: 17 x 17 x 768.
    blk = V3 + '/Mixed_6c'
    s = net.add(Split('%s/Split' % blk, 4))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
                 src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(160), [1, 1],
                 src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(160), [1, 7])
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(160), [1, 1],
                 src=s)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(160), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(160), [1, 7])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(160), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7])
    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
                              [br0, br1, br2, br3])
    if final_aux_check(blk):
        return net, end_points

    # mixed_6: 17 x 17 x 768.
    blk = V3 + '/Mixed_6d'
    s = net.add(Split('%s/Split' % blk, 4))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
                 src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(160), [1, 1],
                 src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(160), [1, 7])
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(160), [1, 1],
                 src=s)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(160), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(160), [1, 7])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(160), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7])
    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
                              [br0, br1, br2, br3])
    if final_aux_check(blk):
        return net, end_points

    # mixed_7: 17 x 17 x 768 (all branches at 192 channels).
    blk = V3 + '/Mixed_6e'
    s = net.add(Split('%s/Split' % blk, 4))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
                 src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
                 src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(192), [1, 7])
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
                 src=s)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(192), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(192), [1, 7])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(192), [7, 1])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7])
    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
                              [br0, br1, br2, br3])
    if final_aux_check(blk):
        return net, end_points

    # mixed_8: 8 x 8 x 1280.
    # Second grid reduction (stride-2 VALID convs plus max-pool).
    blk = V3 + '/Mixed_7a'
    s = net.add(Split('%s/Split' % blk, 3))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
                 src=s)
    br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, depth(320), [3, 3], 2,
                 border_mode='VALID')
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
                 src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(192), [1, 7])
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1])
    br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % blk, depth(192), [3, 3], 2,
                 border_mode='VALID')
    br2 = net.add(
        MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % blk, 3, 2,
                     border_mode='VALID'), s)
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1, br2])
    if final_aux_check(blk):
        return net, end_points

    # mixed_9: 8 x 8 x 2048.
    # Branches 1 and 2 split again into parallel 1x3 / 3x1 convs that are
    # concatenated before the outer Concat.
    blk = V3 + '/Mixed_7b'
    s = net.add(Split('%s/Split' % blk, 4))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(320), 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(384), 1, src=s)
    s1 = net.add(Split('%s/Branch_1/Split1' % blk, 2))
    br11 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x3' % blk, depth(384), [1, 3],
                  src=s1)
    # NOTE(review): '0b_3x1' here vs '0c_3x1' in Mixed_7c; presumably matches
    # upstream checkpoint names.
    br12 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x1' % blk, depth(384), [3, 1],
                  src=s1)
    br1 = net.add(Concat('%s/Branch_1/Concat1' % blk, 1), [br11, br12])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(448), 1, src=s)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(384), 3)
    s2 = net.add(Split('%s/Branch_2/Split2' % blk, 2))
    br21 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % blk, depth(384), [1, 3],
                  src=s2)
    br22 = conv2d(net, '%s/Branch_2/Conv2d_0d_3x1' % blk, depth(384), [3, 1],
                  src=s2)
    br2 = net.add(Concat('%s/Branch_2/Concat2' % blk, 1), [br21, br22])
    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1),
                  src=s)
    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
                              [br0, br1, br2, br3])
    if final_aux_check(blk):
        return net, end_points

    # mixed_10: 8 x 8 x 2048.
    blk = V3 + '/Mixed_7c'
    s = net.add(Split('%s/Split' % blk, 4))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(320), 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(384), 1, src=s)
    s1 = net.add(Split('%s/Branch_1/Split1' % blk, 2))
    br11 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x3' % blk, depth(384), [1, 3],
                  src=s1)
    br12 = conv2d(net, '%s/Branch_1/Conv2d_0c_3x1' % blk, depth(384), [3, 1],
                  src=s1)
    br1 = net.add(Concat('%s/Branch_1/Concat1' % blk, 1), [br11, br12])
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(448), [1, 1],
                 src=s)
    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(384), [3, 3])
    s2 = net.add(Split('%s/Branch_2/Split2' % blk, 2))
    br21 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % blk, depth(384), [1, 3],
                  src=s2)
    br22 = conv2d(net, '%s/Branch_2/Conv2d_0d_3x1' % blk, depth(384), [3, 1],
                  src=s2)
    br2 = net.add(Concat('%s/Branch_2/Concat2' % blk, 1), [br21, br22])
    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1),
                  src=s)
    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
                              [br0, br1, br2, br3])
    # Last block: final_aux_check is not called, so reaching here is only
    # valid when Mixed_7c itself is the requested final endpoint.
    assert final_endpoint == blk, \
        'final_enpoint = %s is not in the net' % final_endpoint
    return net, end_points
def inception_v4_base(sample_shape, final_endpoint='InceptionV4/Mixed_7d',
                      aux_endpoint='Inception/Mixed_6e'):
    """Creates the Inception V4 network up to the given final endpoint.

    Endpoint name list:
        'InceptionV4/' + ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
        'Mixed_3a', 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c',
        'Mixed_5d', 'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c',
        'Mixed_6d', 'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h',
        'Mixed_7a', 'Mixed_7b', 'Mixed_7c', 'Mixed_7d']

    Args:
        sample_shape: input image sample shape, 3d tuple.
        final_endpoint: block name to stop after (one of the names above).
            The default is now 'InceptionV4/Mixed_7d'. The old default
            'Inception/Mixed_7d' could never match a block, so default calls
            always failed.
        aux_endpoint: block name after which a 2-way Split named
            '<aux_endpoint>-aux' is added for the aux loss. The default
            ('Inception/Mixed_6e') matches no block, so by default no aux
            split is added. The default is kept for backward compatibility.

    Returns:
        (net, end_points): the ffnet.FeedForwardNet built so far and a dict
        mapping block names (and the optional '-aux' split) to their output
        layer. Every early exit returns this same 2-tuple.

    Raises:
        AssertionError: if final_endpoint is not one of the endpoint names.
    """
    name = 'InceptionV4'
    end_points = {}
    net = ffnet.FeedForwardNet()

    def final_aux_check(block_name):
        # True -> stop building. As a side effect, attach the aux Split right
        # after the aux endpoint block.
        if block_name == final_endpoint:
            return True
        if block_name == aux_endpoint:
            aux = aux_endpoint + '-aux'
            end_points[aux] = net.add(Split(aux, 2))
        return False

    # 299 x 299 x 3
    blk = name + '/Conv2d_1a_3x3'
    net.add(
        Conv2D(blk, 32, 3, 2, border_mode='VALID', use_bias=False,
               input_sample_shape=sample_shape))
    net.add(BatchNormalization('%s/BatchNorm' % blk))
    end_points[blk] = net.add(Activation('%s/relu' % blk))
    if final_aux_check(blk):
        return net, end_points

    # 149 x 149 x 32
    blk = name + '/Conv2d_2a_3x3'
    end_points[blk] = conv2d(net, blk, 32, 3, border_mode='VALID')
    if final_aux_check(blk):
        return net, end_points

    # 147 x 147 x 32
    blk = name + '/Conv2d_2b_3x3'
    end_points[blk] = conv2d(net, blk, 64, 3)
    if final_aux_check(blk):
        return net, end_points

    # 147 x 147 x 64
    blk = name + '/Mixed_3a'
    s = net.add(Split('%s/Split' % blk, 2))
    br0 = net.add(
        MaxPooling2D('%s/Branch_0/MaxPool_0a_3x3' % blk, 3, 2,
                     border_mode='VALID'), s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_3x3' % blk, 96, 3, 2,
                 border_mode='VALID', src=s)
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1])
    if final_aux_check(blk):
        return net, end_points

    # 73 x 73 x 160
    blk = name + '/Mixed_4a'
    s = net.add(Split('%s/Split' % blk, 2))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, 64, 1, src=s)
    br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, 96, 3,
                 border_mode='VALID')
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, 64, 1, src=s)
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, 64, (1, 7))
    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, 64, (7, 1))
    br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % blk, 96, 3,
                 border_mode='VALID')
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1])
    if final_aux_check(blk):
        return net, end_points

    # 71 x 71 x 192
    blk = name + '/Mixed_5a'
    s = net.add(Split('%s/Split' % blk, 2))
    br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, 192, 3, 2,
                 border_mode='VALID', src=s)
    br1 = net.add(
        MaxPooling2D('%s/Branch_1/MaxPool_1a_3x3' % blk, 3, 2,
                     border_mode='VALID'), s)
    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1])
    if final_aux_check(blk):
        return net, end_points

    # 35 x 35 x 384
    # 4 x Inception-A blocks (Mixed_5b .. Mixed_5e)
    for idx in range(4):
        blk = name + '/Mixed_5' + chr(ord('b') + idx)
        end_points[blk] = block_inception_a(blk, net)
        if final_aux_check(blk):
            return net, end_points

    # 35 x 35 x 384
    # Reduction-A block
    blk = name + '/Mixed_6a'
    end_points[blk] = block_reduction_a(blk, net)
    if final_aux_check(blk):
        # Previously returned (net, end_points[blk], end_points), a 3-tuple,
        # unlike every other exit.
        return net, end_points

    # 17 x 17 x 1024
    # 7 x Inception-B blocks (Mixed_6b .. Mixed_6h)
    for idx in range(7):
        blk = name + '/Mixed_6' + chr(ord('b') + idx)
        end_points[blk] = block_inception_b(blk, net)
        if final_aux_check(blk):
            return net, end_points

    # 17 x 17 x 1024
    # Reduction-B block
    blk = name + '/Mixed_7a'
    end_points[blk] = block_reduction_b(blk, net)
    if final_aux_check(blk):
        return net, end_points

    # 8 x 8 x 1536
    # 3 x Inception-C blocks (Mixed_7b .. Mixed_7d)
    for idx in range(3):
        blk = name + '/Mixed_7' + chr(ord('b') + idx)
        end_points[blk] = block_inception_c(blk, net)
        if final_aux_check(blk):
            return net, end_points

    # Reaching here means final_endpoint matched no block. An explicit raise
    # (same type and message as the old assert) also fires under python -O,
    # where the assert would have been stripped and None returned.
    if final_endpoint != blk:
        raise AssertionError(
            'final_enpoint = %s is not in the net' % final_endpoint)
    return net, end_points