def rpn_net(config, stage='train'):
    """Build a standalone region proposal network (RPN) model.

    Args:
        config: configuration object. This function reads IMAGES_PER_GPU,
            IMAGE_INPUT_SHAPE, MAX_GT_INSTANCES, RPN_ANCHOR_NUM, the
            RPN_ANCHOR_* anchor settings, BACKBONE_STRIDE,
            RPN_TRAIN_ANCHORS_PER_IMAGE, POST_NMS_ROIS_INFERENCE and
            RPN_NMS_THRESHOLD_INFERENCE.
        stage: 'train' builds the training model; any other value builds the
            inference model.

    Returns:
        For 'train': a ``Model`` with inputs
        ``[image, image_meta, gt_class_ids, gt_boxes]`` and outputs
        ``[rpn class loss, rpn bbox loss]``.
        Otherwise: a ``Model`` with inputs ``[image, image_meta]``. Its
        outputs are the first two results of ``RpnToProposal``: the proposal
        boxes and their class scores.

    NOTE(review): ``image_meta`` is declared as a model input, but no layer
    in this graph consumes it. Also, ``frcnn`` passes ``input_image_meta``
    to ``RpnToProposal`` as a fifth input, while the inference branch here
    passes only four. Confirm that ``RpnToProposal`` supports both call forms.
    """
    images_per_gpu = config.IMAGES_PER_GPU

    # Model inputs.
    image = Input(shape=config.IMAGE_INPUT_SHAPE)
    gt_class_ids = Input(shape=(config.MAX_GT_INSTANCES, 1 + 1))
    gt_boxes = Input(shape=(config.MAX_GT_INSTANCES, 4 + 1))
    image_meta = Input(shape=(12, ))

    # Backbone features and the RPN's regression / classification predictions.
    features = resnet50(image)
    pred_deltas, pred_logits = rpn(features, config.RPN_ANCHOR_NUM)

    # Anchors generated from the feature map.
    anchor_layer = Anchor(config.RPN_ANCHOR_HEIGHTS,
                          config.RPN_ANCHOR_WIDTHS,
                          config.RPN_ANCHOR_BASE_SIZE,
                          config.RPN_ANCHOR_RATIOS,
                          config.RPN_ANCHOR_SCALES,
                          config.BACKBONE_STRIDE,
                          name='gen_anchors')
    anchors, anchors_tag = anchor_layer(features)
    # Clipping anchors to the window is currently disabled:
    # anchors = UniqueClipBoxes(config.IMAGE_INPUT_SHAPE, name='clip_anchors')(anchors)
    # windows = Lambda(lambda x: x[:, 7:11])(input_image_meta)
    # anchors = ClipBoxes()([anchors, windows])

    if stage != 'train':
        # Inference: apply the predicted deltas/scores to the anchors.
        proposal_layer = RpnToProposal(
            images_per_gpu,
            output_box_num=config.POST_NMS_ROIS_INFERENCE,
            iou_threshold=config.RPN_NMS_THRESHOLD_INFERENCE,
            name='rpn2proposals')
        proposals, proposal_scores, _ = proposal_layer(
            [pred_deltas, pred_logits, anchors, anchors_tag])
        return Model(inputs=[image, image_meta],
                     outputs=[proposals, proposal_scores])

    # Training: build classification and regression targets from ground truth.
    target_layer = RpnTarget(images_per_gpu,
                             config.RPN_TRAIN_ANCHORS_PER_IMAGE,
                             name='rpn_target')
    # Target outputs begin with [deltas, cls_ids, indices, ...].
    rpn_targets = target_layer([gt_boxes, gt_class_ids, anchors, anchors_tag])
    target_deltas, target_cls_ids, target_indices = rpn_targets[:3]

    # The losses are wrapped in Lambda layers so they can be returned as
    # model outputs.
    cls_loss_layer = Lambda(lambda x: rpn_cls_loss(*x), name='rpn_class_loss')
    cls_loss = cls_loss_layer([pred_logits, target_cls_ids, target_indices])
    regress_loss_layer = Lambda(lambda x: rpn_regress_loss(*x),
                                name='rpn_bbox_loss')
    regress_loss = regress_loss_layer(
        [pred_deltas, target_deltas, target_indices])

    return Model(inputs=[image, image_meta, gt_class_ids, gt_boxes],
                 outputs=[cls_loss, regress_loss])
def frcnn(config, stage='train'):
    """Build the two-stage Faster R-CNN model (RPN + detection head).

    Args:
        config: configuration object. Among the attributes read here are
            IMAGES_PER_GPU, IMAGE_INPUT_SHAPE, MAX_GT_INSTANCES, base_fn,
            head_fn, RPN_ANCHOR_NUM and the RPN_ANCHOR_* settings,
            BACKBONE_STRIDE, POST_NMS_ROIS_TRAINING/INFERENCE,
            RPN_NMS_THRESHOLD_TRAINING/INFERENCE, RPN_TRAIN_ANCHORS_PER_IMAGE,
            TRAIN_ROIS_PER_IMAGE, ROI_POSITIVE_RATIO, NUM_CLASSES,
            IMAGE_MAX_DIM, POOL_SIZE, RCNN_FC_LAYERS_SIZE, GPU_COUNT and the
            DETECTION_* settings.
        stage: 'train' builds the training model; any other value builds the
            inference model.

    Returns:
        For 'train': a ``Model`` with inputs
        ``[input_image, input_image_meta, gt_class_ids, gt_boxes]``. Its
        outputs are four losses (RPN class, RPN bbox, RCNN bbox, RCNN class)
        followed by eleven statistics tensors taken from the RPN and
        detection target layers. The model is wrapped in ``ParallelModel``
        when ``config.GPU_COUNT > 1``.
        Otherwise: a ``Model`` with inputs ``[input_image, input_image_meta]``
        and outputs ``[detect_boxes, class_scores, detect_class_ids,
        detect_class_logits, image_meta]``.
    """
    batch_size = config.IMAGES_PER_GPU
    # Inputs
    input_image = Input(shape=config.IMAGE_INPUT_SHAPE, name='input_image')
    input_image_meta = Input(shape=(12, ), name='input_image_meta')
    gt_class_ids = Input(shape=(config.MAX_GT_INSTANCES, 1 + 1),
                         name='input_gt_class_ids')
    gt_boxes = Input(shape=(config.MAX_GT_INSTANCES, 4 + 1),
                     name='input_gt_boxes')

    # Backbone features and RPN predictions
    features = config.base_fn(input_image)
    boxes_regress, class_logits = rpn(features, config.RPN_ANCHOR_NUM)

    # Generate anchors
    anchors, anchors_tag = Anchor(config.RPN_ANCHOR_HEIGHTS,
                                  config.RPN_ANCHOR_WIDTHS,
                                  config.RPN_ANCHOR_BASE_SIZE,
                                  config.RPN_ANCHOR_RATIOS,
                                  config.RPN_ANCHOR_SCALES,
                                  config.BACKBONE_STRIDE,
                                  name='gen_anchors')(features)
    # Clipping anchors to the input shape is currently disabled:
    # anchors = UniqueClipBoxes(config.IMAGE_INPUT_SHAPE, name='clip_anchors')(anchors)
    # Columns 7:11 of image_meta; presumably the image window
    # (TODO confirm layout). Only the inference branch below consumes it,
    # via ClipBoxes.
    windows = Lambda(lambda x: x[:, 7:11])(input_image_meta)
    # anchors = ClipBoxes()([anchors, windows])

    # Apply RPN classification and regression to produce proposals
    output_box_num = config.POST_NMS_ROIS_TRAINING if stage == 'train' else config.POST_NMS_ROIS_INFERENCE
    iou_threshold = config.RPN_NMS_THRESHOLD_TRAINING if stage == 'train' else config.RPN_NMS_THRESHOLD_INFERENCE
    proposal_boxes, _, _ = RpnToProposal(batch_size,
                                         output_box_num=output_box_num,
                                         iou_threshold=iou_threshold,
                                         name='rpn2proposals')([
                                             boxes_regress, class_logits,
                                             anchors, anchors_tag,
                                             input_image_meta
                                         ])
    # Clipping proposals to the image window is currently disabled:
    # proposal_boxes_coordinate, proposal_boxes_tag = Lambda(lambda x: [x[..., :4], x[..., 4:]])(proposal_boxes)
    # proposal_boxes_coordinate = ClipBoxes()([proposal_boxes_coordinate, windows])
    # proposal_boxes_coordinate = UniqueClipBoxes(config.IMAGE_INPUT_SHAPE,
    #                                             name='clip_proposals')(proposal_boxes_coordinate)
    # (then re-attach the tag column before returning)
    # proposal_boxes = Lambda(lambda x: tf.concat(x, axis=-1))([proposal_boxes_coordinate, proposal_boxes_tag])

    if stage == 'train':
        # RPN classification and regression targets
        rpn_targets = RpnTarget(batch_size,
                                config.RPN_TRAIN_ANCHORS_PER_IMAGE,
                                name='rpn_target')([
                                    gt_boxes, gt_class_ids, anchors,
                                    anchors_tag
                                ])  # [deltas,cls_ids,indices,..]
        rpn_deltas, rpn_cls_ids, anchor_indices = rpn_targets[:3]
        # RPN loss layers
        cls_loss_rpn = Lambda(lambda x: rpn_cls_loss(*x),
                              name='rpn_class_loss')(
                                  [class_logits, rpn_cls_ids, anchor_indices])
        regress_loss_rpn = Lambda(lambda x: rpn_regress_loss(*x),
                                  name='rpn_bbox_loss')([
                                      boxes_regress, rpn_deltas, anchor_indices
                                  ])

        # Classification and regression targets for the detection head
        detect_targets = DetectTarget(
            batch_size,
            config.TRAIN_ROIS_PER_IMAGE,
            config.ROI_POSITIVE_RATIO,
            name='rcnn_target')([gt_boxes, gt_class_ids, proposal_boxes])
        roi_deltas, roi_class_ids, train_rois = detect_targets[:3]

        # Detection head on the sampled training ROIs
        rcnn_deltas, rcnn_class_logits = rcnn(
            features,
            train_rois,
            config.NUM_CLASSES,
            config.IMAGE_MAX_DIM,
            config.head_fn,
            pool_size=config.POOL_SIZE,
            fc_layers_size=config.RCNN_FC_LAYERS_SIZE)

        # Detection head losses
        regress_loss_rcnn = Lambda(lambda x: detect_regress_loss(*x),
                                   name='rcnn_bbox_loss')([
                                       rcnn_deltas, roi_deltas, roi_class_ids
                                   ])
        cls_loss_rcnn = Lambda(lambda x: detect_cls_loss(*x),
                               name='rcnn_class_loss')(
                                   [rcnn_class_logits, roi_class_ids])

        # Statistics beyond the first three target outputs, given stable
        # layer names so they can be reported as custom metrics.
        gt_num, positive_num, negative_num, rpn_miss_gt_num, rpn_gt_min_max_iou = rpn_targets[
            3:]
        rcnn_miss_gt_num, rcnn_miss_gt_num_as, gt_min_max_iou, pos_roi_num, neg_roi_num, roi_num = detect_targets[
            3:]
        gt_num = _named_identity(gt_num, 'identity_gt_num')
        positive_num = _named_identity(positive_num, 'identity_positive_num')
        negative_num = _named_identity(negative_num, 'identity_negative_num')
        rpn_miss_gt_num = _named_identity(rpn_miss_gt_num,
                                          'identity_rpn_miss_gt_num')
        rpn_gt_min_max_iou = _named_identity(rpn_gt_min_max_iou,
                                             'identity_rpn_gt_min_max_iou')
        rcnn_miss_gt_num = _named_identity(rcnn_miss_gt_num,
                                           'identity_rcnn_miss_gt_num')
        rcnn_miss_gt_num_as = _named_identity(rcnn_miss_gt_num_as,
                                              'identity_rcnn_miss_gt_num_as')
        gt_min_max_iou = _named_identity(gt_min_max_iou,
                                         'identity_gt_min_max_iou')
        pos_roi_num = _named_identity(pos_roi_num, 'identity_pos_roi_num')
        neg_roi_num = _named_identity(neg_roi_num, 'identity_neg_roi_num')
        roi_num = _named_identity(roi_num, 'identity_roi_num')

        # Build the model. With the parallel (multi-GPU) model, every custom
        # metric must also be a model output.
        model = Model(
            inputs=[input_image, input_image_meta, gt_class_ids, gt_boxes],
            outputs=[
                cls_loss_rpn, regress_loss_rpn, regress_loss_rcnn,
                cls_loss_rcnn
            ] + [
                gt_num, positive_num, negative_num, rpn_miss_gt_num,
                rpn_gt_min_max_iou, roi_num, pos_roi_num, neg_roi_num,
                rcnn_miss_gt_num, rcnn_miss_gt_num_as, gt_min_max_iou
            ])
        # Multi-GPU training
        if config.GPU_COUNT > 1:
            model = ParallelModel(model, config.GPU_COUNT)
        return model
    else:  # Inference stage
        # Detection head on all proposals
        rcnn_deltas, rcnn_class_logits = rcnn(
            features,
            proposal_boxes,
            config.NUM_CLASSES,
            config.IMAGE_MAX_DIM,
            config.head_fn,
            pool_size=config.POOL_SIZE,
            fc_layers_size=config.RCNN_FC_LAYERS_SIZE)
        # Post-process the deltas together with the class logits
        # (per-class handling inside deal_delta). Use the same `Lambda`
        # name as the rest of this function rather than `layers.Lambda`,
        # which relies on a separate `layers` import being present.
        rcnn_deltas = Lambda(lambda x: deal_delta(*x),
                             name='deal_delta')(
                                 [rcnn_deltas, rcnn_class_logits])
        # Apply classification and regression to get the final detections
        detect_boxes, class_scores, detect_class_ids, detect_class_logits = ProposalToDetectBox(
            score_threshold=config.DETECTION_MIN_CONFIDENCE,
            output_box_num=config.DETECTION_MAX_INSTANCES,
            iou_threshold=config.DETECTION_NMS_THRESHOLD,
            name='proposals2detectboxes')(
                [rcnn_deltas, rcnn_class_logits, proposal_boxes])
        # Clip the box coordinates (first 4 columns) to `windows`, keeping the
        # trailing tag column(s) untouched.
        detect_boxes_coordinate, detect_boxes_tag = Lambda(
            lambda x: [x[..., :4], x[..., 4:]])(detect_boxes)
        detect_boxes_coordinate = ClipBoxes()(
            [detect_boxes_coordinate, windows])
        # Re-attach the tag column(s) before returning
        detect_boxes = Lambda(lambda x: tf.concat(x, axis=-1))(
            [detect_boxes_coordinate, detect_boxes_tag])
        image_meta = Lambda(lambda x: tf.identity(x))(input_image_meta)  # returned unchanged
        return Model(inputs=[input_image, input_image_meta],
                     outputs=[
                         detect_boxes, class_scores, detect_class_ids,
                         detect_class_logits, image_meta
                     ])


def _named_identity(tensor, name):
    """Pass ``tensor`` through an identity ``Lambda`` layer named ``name``.

    Used by ``frcnn`` to give the target-layer statistics stable, readable
    layer names so they can be exposed as model outputs / custom metrics.
    """
    return Lambda(lambda x: tf.identity(x), name=name)(tensor)