def conv_block(input_tensor,
               kernel_size,
               filters,
               stage,
               block,
               use_bias=True,
               pretrain=False,
               strides=2,
               data_dict=None):
    """Bottleneck residual block whose shortcut path has its own conv layer.

    Main path: 1x1 conv -> kernel_size x kernel_size conv -> 1x1 conv, each
    followed by `setool.batch_norm_liqi`, with a ReLU after the first two.
    Shortcut path: 1x1 conv + `setool.batch_norm_liqi` on the block input.
    The two paths are summed and passed through a final ReLU.

    # Arguments
        input_tensor: input tensor
        kernel_size: kernel size of the middle conv layer of the main path
        filters: sequence of three integers, the output channel counts of the
            three main-path conv layers (the third is also used by the
            shortcut conv)
        stage: current stage label, used for generating layer names
        block: string such as 'a', 'b', ..., current block label, used for
            generating layer names
        use_bias: accepted but not referenced anywhere in this function
        pretrain: accepted but not forwarded; every conv layer is created
            with pretrain=False
        strides: stride of the first main-path conv and of the shortcut conv
        data_dict: passed unchanged to every `setool.conv_op` call

    # Returns
        Output tensor of the final ReLU, named 'res<stage><block>_out'.
    """
    reduce_ch, mid_ch, out_ch = filters
    label = str(stage) + block
    conv_prefix = 'res' + label + '_branch'
    bn_prefix = 'bn' + label + '_branch'

    def conv_bn(tensor, suffix, size, n_out, **stride_kwargs):
        # One conv layer followed by its normalisation layer; both share the
        # same branch suffix in their names.
        conv = setool.conv_op(input_op=tensor,
                              name=conv_prefix + suffix,
                              kh=size,
                              kw=size,
                              n_out=n_out,
                              data_dict=data_dict,
                              pretrain=False,
                              **stride_kwargs)
        return setool.batch_norm_liqi(x=conv, name=bn_prefix + suffix)

    # Only the first main-path conv and the shortcut conv are strided; the
    # other two convs rely on setool.conv_op's default stride.
    strided = {'dh': strides, 'dw': strides}

    # Main path.
    main = tf.nn.relu(conv_bn(input_tensor, '2a', 1, reduce_ch, **strided))
    main = tf.nn.relu(conv_bn(main, '2b', kernel_size, mid_ch))
    main = conv_bn(main, '2c', 1, out_ch)

    # Shortcut path: projects the input to the main path's output shape.
    shortcut = conv_bn(input_tensor, '1', 1, out_ch, **strided)

    return tf.nn.relu(main + shortcut, name='res' + label + '_out')
    def build_rpn_model(self):
        """Build the RPN head on top of `self.in_feature`.

        A shared conv layer (512 outputs, stride `self.strides`) followed by
        ReLU feeds two 1x1 'VALID' conv layers: one with 2 outputs per anchor
        (class scores) and one with 4 outputs per anchor (box regression).

        # Returns
            A list `[rpn_class_logits, rpn_probs, rpn_bbox]`:
            rpn_class_logits: class conv output reshaped to
                [GLOBAL_BATCH_SIZE, -1, 2]
            rpn_probs: softmax of `rpn_class_logits`
            rpn_bbox: box conv output reshaped to [GLOBAL_BATCH_SIZE, -1, 4]
        """
        trunk = setool.conv_op(input_op=self.in_feature,
                               name='rpn_conv_shared',
                               dh=self.strides,
                               dw=self.strides,
                               n_out=512)
        trunk = tf.nn.relu(trunk)

        def per_anchor_head(layer_name, values_per_anchor):
            # 1x1 conv producing `values_per_anchor` numbers for every anchor,
            # flattened so that each row holds the values of one anchor.
            raw = setool.conv_op(
                input_op=trunk,
                name=layer_name,
                n_out=values_per_anchor * self.anchors_per_location,
                kh=1,
                kw=1,
                padding='VALID')
            return tf.reshape(raw, [GLOBAL_BATCH_SIZE, -1, values_per_anchor])

        # Classification branch.
        rpn_class_logits = per_anchor_head('rpn_class_raw', 2)
        rpn_probs = tf.nn.softmax(rpn_class_logits, name="rpn_class_xxx")

        # Box regression branch.
        rpn_bbox = per_anchor_head('rpn_bbox_pred', 4)

        return [rpn_class_logits, rpn_probs, rpn_bbox]
    def build(self, mode, config, images):
        """Build the ResNet-50 backbone, the feature pyramid and the RPN heads.

        # Arguments
            mode: 'training' or 'inference'; it is only validated here
            config: not read by this method; the anchor settings are taken
                from `self.config` instead
            images: input tensor handed to `resnet_graph`

        # Returns
            A tuple `(rpn_class_logits, rpn_bbox, rpn_class)`. Each element is
            the corresponding per-level RPN output concatenated along axis 1
            in the order P2, P3, P4, P5, P6.

        # Side effects
            Sets `self.anchors` to the result of
            `utils.generate_pyramid_anchors`.
        """
        assert mode in ['training', 'inference']

        C2, C3, C4, C5 = resnet_graph(images, "resnet50", stage5=True)

        def spatial_size(tensor):
            # [height, width] of `tensor`, assuming the 4-D
            # [batch, height, width, channels] layout that
            # tf.image.resize_images expects for batched input. The static
            # size is preferred so the graph is unchanged when shapes are
            # known; otherwise fall back to the run-time shape.
            static_hw = tensor.get_shape().as_list()[1:3]
            if None not in static_hw:
                return static_hw
            return tf.shape(tensor)[1:3]

        def merge_top_down(backbone_map, name, coarser):
            # 1x1 lateral conv on the backbone map plus the coarser pyramid
            # level resized to the lateral map's own size. Deriving the size
            # from the tensor (instead of hard-coded 64/128/256) keeps the
            # pyramid valid for any input resolution.
            lateral = setool.conv_op(input_op=backbone_map,
                                     name=name,
                                     kh=1,
                                     kw=1,
                                     n_out=256)
            return lateral + tf.image.resize_images(coarser,
                                                    spatial_size(lateral))

        # Top-down pathway.
        P5 = setool.conv_op(input_op=C5,
                            name='fpn_c5p5',
                            kh=1,
                            kw=1,
                            n_out=256)
        P4 = merge_top_down(C4, 'fpn_c4p4', P5)
        P3 = merge_top_down(C3, 'fpn_c3p3', P4)
        P2 = merge_top_down(C2, 'fpn_c2p2', P3)

        # Output conv on every merged level (kernel size and stride are
        # setool.conv_op's defaults).
        P2 = setool.conv_op(input_op=P2, name='fpn_p2', n_out=256)
        P3 = setool.conv_op(input_op=P3, name='fpn_p3', n_out=256)
        P4 = setool.conv_op(input_op=P4, name='fpn_p4', n_out=256)
        P5 = setool.conv_op(input_op=P5, name='fpn_p5', n_out=256)
        # P6 is P5 subsampled with a stride-2 pooling op.
        P6 = setool.mpool_op(input_tensor=P5, k=1, s=2, name="fpn_p6")

        pyramid = [P2, P3, P4, P5, P6]

        # Generate anchors.
        # Example values noted by the original author: scales
        # (32, 64, 128, 256, 512), 3 ratios, backbone shapes
        # [256, 128, 64, 32, 16], backbone strides [4, 8, 16, 32, 64],
        # anchor stride 1.
        self.anchors = utils.generate_pyramid_anchors(
            self.config.RPN_ANCHOR_SCALES, self.config.RPN_ANCHOR_RATIOS,
            self.config.BACKBONE_SHAPES, self.config.BACKBONE_STRIDES,
            self.config.RPN_ANCHOR_STRIDE)

        # One RPN head per pyramid level. The heads are created coarsest
        # level first (P6 ... P2) to preserve the original graph construction
        # order, then put back into P2 ... P6 order for concatenation.
        heads = [
            RPN_net(level,
                    anchor_stride=self.config.RPN_ANCHOR_STRIDE
                    ).build_rpn_model()
            for level in reversed(pyramid)
        ]
        heads.reverse()

        # Each head is [class_logits, class_probs, bbox]; merge the levels
        # along the anchor axis.
        rpn_class_logits = tf.concat([head[0] for head in heads], 1)
        rpn_class = tf.concat([head[1] for head in heads], 1)
        rpn_bbox = tf.concat([head[2] for head in heads], 1)

        return rpn_class_logits, rpn_bbox, rpn_class
def resnet_graph(input_image,
                 architecture,
                 stage5=False,
                 data_dict=None,
                 pretrain=False):
    """Build the ResNet backbone and return its per-stage feature maps.

    # Arguments
        input_image: input tensor fed to the first conv layer
        architecture: "resnet50" or "resnet101"; selects how many identity
            blocks follow the conv block of stage 4
        stage5: if True, stage 5 is built; otherwise C5 is None
        data_dict: passed unchanged to every conv layer / block
        pretrain: accepted but not forwarded; every layer and block is
            created with pretrain=False

    # Returns
        A tuple `(C2, C3, C4, C5)` holding the outputs of stages 2 to 5
        (C5 is None when `stage5` is False).
    """
    assert architecture in ["resnet50", "resnet101"]

    def head_block(tensor, widths, stage, **extra):
        # First block of a stage: the block with a conv layer on its shortcut.
        return conv_block(tensor,
                          3,
                          list(widths),
                          stage=stage,
                          block='a',
                          pretrain=False,
                          data_dict=data_dict,
                          **extra)

    def tail_blocks(tensor, widths, stage, labels):
        # Remaining blocks of a stage, one identity block per label.
        for label in labels:
            tensor = identity_block(tensor,
                                    3,
                                    list(widths),
                                    stage=stage,
                                    block=label,
                                    pretrain=False,
                                    data_dict=data_dict)
        return tensor

    # Stage 1: 7x7 stride-2 conv, normalisation, ReLU, then 3x3 stride-2 pool.
    net = setool.conv_op(input_op=input_image,
                         name='conv1',
                         kh=7,
                         kw=7,
                         dh=2,
                         dw=2,
                         n_out=64,
                         data_dict=data_dict,
                         pretrain=False)
    net = tf.nn.relu(setool.batch_norm_liqi(x=net, name='bn_conv1'))
    net = setool.mpool_op(input_tensor=net, k=3, s=2)

    # Stage 2 (its conv block uses stride 1; later stages use conv_block's
    # default stride).
    widths = (64, 64, 256)
    net = head_block(net, widths, 2, strides=1)
    C2 = net = tail_blocks(net, widths, 2, 'bc')

    # Stage 3
    widths = (128, 128, 512)
    net = head_block(net, widths, 3)
    C3 = net = tail_blocks(net, widths, 3, 'bcd')

    # Stage 4
    widths = (256, 256, 1024)
    net = head_block(net, widths, 4)
    # NOTE(review): the reference ResNet-50 layout has five identity blocks
    # in stage 4; only one is built here for "resnet50". Confirm that this
    # reduction is intentional before relying on reference weights.
    n_identity = {"resnet50": 1, "resnet101": 22}[architecture]
    # Labels run 'b', 'c', ... (chr(98) == 'b').
    C4 = net = tail_blocks(net, widths, 4,
                           [chr(98 + i) for i in range(n_identity)])

    # Stage 5 (optional)
    if stage5:
        widths = (512, 512, 2048)
        net = head_block(net, widths, 5)
        C5 = net = tail_blocks(net, widths, 5, 'bc')
    else:
        C5 = None

    return C2, C3, C4, C5