Ejemplo n.º 1
0
    def __init__(self,
                 input_shape,
                 class_num,
                 anchors,
                 batch_size=5,
                 is_predict=False,
                 train_taegets=None):
        """Assemble the backbone, region-proposal and head graph into a Model.

        Args:
            input_shape: Shape handed to the image ``Input`` and forwarded as
                ``image_shape`` to the RPN and the detection-target layer.
            class_num: Forwarded to the head network builder.
            anchors: Forwarded to ``RegionproposalNet``.
            batch_size: Forwarded to the RPN, the detection-target layer and
                the head network builder.
            is_predict: When truthy, no loss branches are built and the model
                outputs are ``[proposal regions, classes, offsets]``.
            train_taegets: Collection of ``TrainTarget`` members selecting
                which parts are trained (BACKBONE, RPN, HEAD). ``None`` is
                treated as an empty collection.

        The resulting output list is stored on the instance, wrapped in a
        ``Model``, and the mean of every output tensor is registered on that
        model through ``add_loss``.
        """
        self.__input_shape = input_shape

        # None means "train nothing"; membership tests below then all fail.
        targets = [] if train_taegets is None else train_taegets
        train_backbone = TrainTarget.BACKBONE in targets
        train_rpn = TrainTarget.RPN in targets
        train_head = TrainTarget.HEAD in targets

        image_input = Input(shape=self.__input_shape)
        model_inputs = [image_input]
        model_outputs = []

        # Backbone feature extractor built on top of the image input.
        backbone = ResNet(
            image_input.get_shape(),
            input_layers=image_input,
            trainable=train_backbone,
        ).get_residual_network()

        # Region proposal network stacked on the backbone output.
        rpn_builder = RegionproposalNet(
            backbone.get_shape(),
            anchors,
            input_layers=image_input,
            image_shape=self.__input_shape,
            prev_layers=backbone,
            batch_size=batch_size,
            is_predict=is_predict,
            trainable=train_rpn,
        )
        rpn_cls_probs, rpn_regions, rpn_prop_regs = rpn_builder.get_network()

        if is_predict:
            # Inference: run the head directly on the RPN proposals.
            pred_classes, pred_offsets = self.__head_net(
                backbone,
                rpn_prop_regs,
                class_num,
                batch_size=batch_size,
            )
            model_outputs = [rpn_prop_regs, pred_classes, pred_offsets]
        else:
            if train_rpn:
                # Extra inputs for the RPN losses
                # (presumably ground-truth labels and boxes -- TODO confirm).
                rpn_label_input = Input(shape=[None, 1], dtype='int32')
                rpn_box_input = Input(shape=[None, 4], dtype='float32')
                model_inputs.extend([rpn_label_input, rpn_box_input])

                rpn_cls_loss = RPClassLoss()([rpn_label_input, rpn_cls_probs])
                rpn_reg_loss = RPRegionLoss()(
                    [rpn_label_input, rpn_box_input, rpn_regions])
                model_outputs.extend([rpn_cls_loss, rpn_reg_loss])

            if train_head:
                # Extra inputs for the head losses
                # (presumably ground-truth labels and boxes -- TODO confirm).
                head_label_input = Input(shape=[None, 1], dtype='int32')
                head_box_input = Input(shape=[None, 4], dtype='float32')
                model_inputs.extend([head_label_input, head_box_input])

                # Derive per-region labels, offset labels and regions from
                # the supplied inputs and the RPN proposals.
                target_layer = DetectionTargetRegion(
                    positive_threshold=0.5,
                    positive_ratio=0.33,
                    image_shape=self.__input_shape,
                    batch_size=batch_size,
                    exclusion_threshold=0.1,
                    count_per_batch=64)
                target_labels, target_offsets, target_regions = target_layer(
                    [head_label_input, head_box_input, rpn_prop_regs])

                head_classes, head_offsets = self.__head_net(
                    backbone,
                    target_regions,
                    class_num,
                    batch_size=batch_size,
                )

                head_cls_loss = ClassLoss()([target_labels, head_classes])
                head_reg_loss = RegionLoss()(
                    [target_labels, target_offsets, head_offsets])
                model_outputs.extend([head_cls_loss, head_reg_loss])

        self.__network = model_outputs
        self.__model = Model(inputs=model_inputs, outputs=model_outputs)

        # NOTE(review): this also runs in predict mode, where the outputs are
        # predictions rather than loss tensors -- confirm that is intended.
        for tensor in model_outputs:
            self.__model.add_loss(tf.reduce_mean(tensor))