コード例 #1
0
def create_neural_net(input_file, batch_size=50):
    """Build the TEST-phase network spec and return it as a ``NetParameter``.

    Graph, front to back:

    * an LMDB ``Data`` layer (``data`` + ``label`` tops, TEST phase only);
    * a 4x4, 16-output convolution with ``lr_mult = decay_mult = 0`` and
      the ``dct4`` weight filler, followed by an in-place ``QuantTruncAbs``
      (TRUNCABS, threshold 8);
    * nine residual blocks produced by the ``add_*_block`` helpers;
    * 8x8 average pooling and a 2-output ``InnerProduct``, feeding
      ``Accuracy`` (TEST only) and ``SoftmaxWithLoss``.

    Args:
        input_file: LMDB path used as the ``Data`` layer ``source``.
        batch_size: Batch size of the ``Data`` layer.

    Returns:
        Whatever ``NetSpec.to_proto()`` returns for the assembled net.
    """
    net = caffe.NetSpec()

    def attach(builder, bottom, channels, names):
        # Call one block helper, register its tops on the NetSpec under the
        # given (whitespace-separated) names in order, and return the
        # second-to-last top -- the ``res*`` top the next block consumes.
        blobs = list(builder(bottom, channels))
        wanted = names.split()
        if len(blobs) != len(wanted):
            raise ValueError('expected %d tops from block helper, got %d'
                             % (len(wanted), len(blobs)))
        for name, blob in zip(wanted, blobs):
            setattr(net, name, blob)
        return blobs[-2]

    # Input: labelled samples read from LMDB, only in the TEST phase.
    net.data, net.label = L.Data(batch_size=batch_size,
                                 source=input_file,
                                 backend=caffe.params.Data.LMDB,
                                 ntop=2,
                                 include=dict(phase=caffe.TEST),
                                 name='juniward04')

    # Pre-processing: a non-learned (lr_mult/decay_mult = 0) 4x4 convolution
    # initialised by the 'dct4' filler, then QuantTruncAbs applied in place.
    frozen = [dict(lr_mult=0, decay_mult=0)]
    net.conv1 = L.Convolution(net.data,
                              num_output=16,
                              kernel_size=4,
                              stride=1,
                              pad=1,
                              weight_filler=dict(type='dct4'),
                              param=frozen,
                              bias_term=False)
    net.quanttruncabs = L.QuantTruncAbs(
        net.conv1,
        process=caffe_pb2.QuantTruncAbsParameter.TRUNCABS,
        threshold=8,
        in_place=True)

    # (block helper, channel width, names of the tops it returns, in order).
    # The top names are part of the emitted proto, so they are spelled
    # exactly as before -- including "con64_5" and "con32_5".
    stages = (
        # block 1: uses the dedicated first-block downsampling helper
        (add_downsampling_block_1, 12,
         'conv1_proj bn2 scale2 conv512_1 bn2_1 scale2_1 relu512_1 '
         'conv512_to_256 bn2_2 scale2_2 res512_to_256 relu512_to_256'),
        # block 2
        (add_skip_block, 24,
         'conv256_1 bn2_4 scale2_4 relu256_1 conv256_2 bn2_5 scale2_5 '
         'relu256_2 conv256_3 bn2_6 scale2_6 res256_3 relu256_3'),
        # block 3
        (add_downsampling_block, 24,
         'res256_3_proj bn2_7 scale2_7 conv256_4 bn2_8 scale2_8 relu256_4 '
         'conv256_5 bn2_9 scale2_9 relu256_5 conv256_to_128 bn2_10 '
         'scale2_10 res256_to_128 relu256_to_128'),
        # block 4
        (add_skip_block, 48,
         'conv128_1 bn2_11 scale2_11 relu128_1 conv128_2 bn2_12 scale2_12 '
         'relu128_2 conv128_3 bn2_13 scale2_13 res128_3 relu128_3'),
        # block 5
        (add_downsampling_block, 48,
         'res128_3_proj bn2_14 scale2_14 conv128_4 bn2_15 scale2_15 '
         'relu128_4 conv128_5 bn2_16 scale2_16 relu128_5 conv128_to_64 '
         'bn2_17 scale2_17 res128_to_64 relu128_to_64'),
        # block 6 (a 96-channel skip block) is disabled, so block 7 reads
        # block 5's output directly.
        # block 7
        (add_downsampling_block, 96,
         'res64_3_proj bn2_21 scale2_21 conv64_4 bn2_22 scale2_22 relu64_4 '
         'con64_5 bn2_23 scale2_23 relu64_5 conv64_to_32 bn2_24 scale2_24 '
         'res64_to_32 relu64_to_32'),
        # block 8
        (add_skip_block, 192,
         'conv32_1 bn2_25 scale2_25 relu32_1 conv32_2 bn2_26 scale2_26 '
         'relu32_2 conv32_3 bn2_27 scale2_27 res32_3 relu32_3'),
        # block 9
        (add_downsampling_block, 192,
         'res32_3_proj bn2_28 scale2_28 conv32_4 bn2_29 scale2_29 relu32_4 '
         'con32_5 bn2_30 scale2_30 relu32_5 conv32_to_16 bn2_31 scale2_31 '
         'res32_to_16 relu32_to_16'),
        # block 10
        (add_skip_block, 384,
         'conv16_1 bn2_32 scale2_32 relu16_1 conv16_2 bn2_33 scale2_33 '
         'relu16_2 conv16_3 bn2_34 scale2_34 res16_3 relu16_3'),
    )

    top = net.quanttruncabs
    for builder, width, names in stages:
        top = attach(builder, top, width, names)

    # Head. NOTE(review): kernel_size=8 is fixed; presumably it matches the
    # spatial size of res16_3 for the expected input resolution -- confirm.
    net.global_pool = L.Pooling(net.res16_3,
                                pool=caffe_pb2.PoolingParameter.AVE,
                                kernel_size=8,
                                stride=1)
    net.fc = L.InnerProduct(net.global_pool,
                            param=[dict(lr_mult=1), dict(lr_mult=2)],
                            num_output=2,
                            weight_filler=dict(type='xavier'),
                            bias_filler=dict(type='constant'))
    # Accuracy carries a TEST-phase rule; the loss has no phase rule.
    net.accuracy = L.Accuracy(net.fc,
                              net.label,
                              include=dict(phase=caffe.TEST))
    net.loss = L.SoftmaxWithLoss(net.fc, net.label)

    return net.to_proto()
コード例 #2
0
def create_neural_net(input_file, batch_size=50):
    """Build the TRAIN-phase deep residual network spec as a ``NetParameter``.

    Graph, front to back:

    * an LMDB ``Data`` layer (``data`` + ``label`` tops, TRAIN phase only);
    * a 4x4, 16-output convolution with ``lr_mult = decay_mult = 0`` and
      the ``dct4`` weight filler, followed by an in-place ``QuantTruncAbs``
      (TRUNCABS, threshold 8);
    * a 12-channel downsampling block;
    * five stages of nine skip blocks each (24, 48, 96, 192 and 384
      channels), with a downsampling block between consecutive stages;
    * 8x8 average pooling and a 2-output ``InnerProduct``, feeding
      ``Accuracy`` (TEST only) and ``SoftmaxWithLoss``.

    NOTE(review): the ``Data`` layer is restricted to TRAIN while
    ``Accuracy`` is restricted to TEST, so in the TEST phase ``label`` would
    have no producer. Presumably this spec is only used as the training net
    -- confirm.

    Args:
        input_file: LMDB path used as the ``Data`` layer ``source``.
        batch_size: Batch size of the ``Data`` layer.

    Returns:
        Whatever ``NetSpec.to_proto()`` returns for the assembled net.
    """
    net = caffe.NetSpec()

    def skip_tops(conv_a, bn_a, conv_b, bn_b):
        # Names, in order, for the nine tops add_skip_block returns.
        return ['conv' + conv_a, 'bn' + bn_a, 'scale' + bn_a, 'relu' + conv_a,
                'conv' + conv_b, 'bn' + bn_b, 'scale' + bn_b,
                'res' + conv_b, 'relu' + conv_b]

    def down_tops(proj, bn_proj, conv_mid, bn_mid, span, bn_out):
        # Names, in order, for the twelve tops add_downsampling_block returns.
        return [proj, 'bn' + bn_proj, 'scale' + bn_proj,
                'conv' + conv_mid, 'bn' + bn_mid, 'scale' + bn_mid,
                'relu' + conv_mid,
                'conv' + span, 'bn' + bn_out, 'scale' + bn_out,
                'res' + span, 'relu' + span]

    def attach(builder, bottom, channels, names):
        # Call one block helper, register its tops on the NetSpec under
        # ``names`` in order, and return the second-to-last top -- the
        # ``res*`` top the next block consumes.
        blobs = list(builder(bottom, channels))
        if len(blobs) != len(names):
            raise ValueError('expected %d tops from block helper, got %d'
                             % (len(names), len(blobs)))
        for name, blob in zip(names, blobs):
            setattr(net, name, blob)
        return blobs[-2]

    # Input: labelled samples read from LMDB, only in the TRAIN phase.
    net.data, net.label = L.Data(batch_size=batch_size,
                                 source=input_file,
                                 backend=caffe.params.Data.LMDB,
                                 ntop=2,
                                 include=dict(phase=caffe.TRAIN),
                                 name='juniward04')

    # Pre-processing: a non-learned (lr_mult/decay_mult = 0) 4x4 convolution
    # initialised by the 'dct4' filler, then QuantTruncAbs applied in place.
    frozen = [dict(lr_mult=0, decay_mult=0)]
    net.conv1 = L.Convolution(net.data,
                              num_output=16,
                              kernel_size=4,
                              stride=1,
                              pad=1,
                              weight_filler=dict(type='dct4'),
                              param=frozen,
                              bias_term=False)
    net.quanttruncabs = L.QuantTruncAbs(
        net.conv1,
        process=caffe_pb2.QuantTruncAbsParameter.TRUNCABS,
        threshold=8,
        in_place=True)

    # block 1: initial 12-channel downsampling block.
    top = attach(add_downsampling_block, net.quanttruncabs, 12,
                 down_tops('conv1_proj', '2', '512_1', '2_1',
                           '512_to_256', '2_2'))

    # One row per resolution stage:
    #   scale tag, channel width,
    #   suffixes for the leading skip block (conv_a, bn_a, conv_b, bn_b),
    #   first conv index and BN column used by the 8 extra skip blocks,
    #   suffixes for the trailing downsampling block (None after the last).
    # The resulting names are part of the emitted proto and reproduce the
    # original hand-written ones exactly; note the 16 stage numbers its
    # extra convs from 3 instead of 4.
    stages = (
        ('256', 24, ('256_1', '2_3', '256_2', '2_4'), 4, 1,
         ('res256_2_proj', '2_5', '256_3', '2_6', '256_to_128', '2_7')),
        ('128', 48, ('128_1', '2_8', '128_2', '2_9'), 4, 3,
         ('res128_2_proj', '2_10', '128_3', '2_11', '128_to_64', '2_12')),
        ('64', 96, ('64_1', '2_13', '64_2', '2_14'), 4, 5,
         ('res64_2_proj', '2_15', '64_3', '2_16', '64_to_32', '2_17')),
        ('32', 192, ('32_1', '2_18', '32_2', '2_19'), 4, 7,
         ('res32_2_proj', '2_20', '32_3', '2_21', '32_to_16', '2_22')),
        ('16', 384, ('16_1', '2_23', '16_2', '2_24'), 3, 9, None),
    )

    for scale, width, lead, first_conv, bn_col, down in stages:
        top = attach(add_skip_block, top, width, skip_tops(*lead))
        # Extra skip blocks "<stage>_1" .. "<stage>_8": BN rows 3..10, with
        # the conv index advancing by two per block.
        for row in range(3, 11):
            conv_idx = first_conv + 2 * (row - 3)
            top = attach(add_skip_block, top, width,
                         skip_tops('%s_%d' % (scale, conv_idx),
                                   '%d_%d' % (row, bn_col),
                                   '%s_%d' % (scale, conv_idx + 1),
                                   '%d_%d' % (row, bn_col + 1)))
        if down is not None:
            top = attach(add_downsampling_block, top, width, down_tops(*down))

    # Head. NOTE(review): kernel_size=8 is fixed; presumably it matches the
    # spatial size of res16_18 for the expected input resolution -- confirm.
    net.global_pool = L.Pooling(net.res16_18,
                                pool=caffe_pb2.PoolingParameter.AVE,
                                kernel_size=8,
                                stride=1)
    net.fc = L.InnerProduct(net.global_pool,
                            param=[dict(lr_mult=1), dict(lr_mult=2)],
                            num_output=2,
                            weight_filler=dict(type='xavier'),
                            bias_filler=dict(type='constant'))
    # Accuracy carries a TEST-phase rule; the loss has no phase rule.
    net.accuracy = L.Accuracy(net.fc,
                              net.label,
                              include=dict(phase=caffe.TEST))
    net.loss = L.SoftmaxWithLoss(net.fc, net.label)

    return net.to_proto()