Ejemplo n.º 1
0
 def __init__(self):
     """Build the detector's sub-modules.

     Creates a ResNet50 backbone, an FPN on top of it, an RPN, an RCNN
     head and a single cascade head named 'cascade_0'.
     """
     super().__init__()
     # Hold the backbone in a local so the FPN wraps the very same instance.
     trunk = ResNet50(config.backbone_freeze_at, False)
     self.resnet50 = trunk
     # NOTE(review): 2 and 6 look like the first/last pyramid levels --
     # confirm against the FPN signature.
     self.FPN = FPN(trunk, 2, 6)
     self.RPN = RPN(config.rpn_channel)
     self.RCNN = RCNN()
     self.Cascade_0 = Cascade('cascade_0')
Ejemplo n.º 2
0
 def __init__(self):
     """Build backbone, FPN, RPN and heads, then create random sample inputs.

     Backbone stages are frozen according to ``config.backbone_freeze_at``:
     ``conv1`` when it is >= 1 and ``layer1`` when it is >= 2.
     """
     super().__init__()

     # Backbone.
     self.resnet50 = ResNet50()

     # Freeze the early backbone stages; each entry is
     # (minimum freeze level, attribute name of the stage on the backbone).
     for min_level, stage_name in ((1, "conv1"), (2, "layer1")):
         if config.backbone_freeze_at >= min_level:
             for param in getattr(self.resnet50, stage_name).parameters():
                 param.requires_grad = False

     # Feature pyramid on top of the backbone.
     self.backbone = FPN(self.resnet50)

     # Region proposal network.
     self.RPN = RPN(config.rpn_channel)

     # Heads. A second stage, Cascade('cascade_1'), is currently disabled.
     self.Cascade_0 = Cascade('cascade_0')
     self.RCNN = RCNN()

     def _random_f32(shape):
         # Uniform random float32 tensor of the requested shape.
         return mge.tensor(
             np.random.random(shape).astype(np.float32), dtype="float32",
         )

     # Random sample inputs, drawn in the order image -> im_info -> gt_boxes.
     # NOTE(review): presumably placeholders for tracing/debugging -- confirm.
     self.inputs = {
         "image": _random_f32([2, 3, 224, 224]),
         "im_info": _random_f32([2, 5]),
         "gt_boxes": _random_f32([2, 100, 5]),
     }
Ejemplo n.º 3
0
 def __init__(self):
     """Build the detector's sub-modules.

     Creates a ResNet50 backbone, an FPN on top of it, and the RetinaNet
     head, anchor and criteria objects.
     """
     super().__init__()
     # Hold the backbone in a local so the FPN wraps the very same instance.
     trunk = ResNet50(config.backbone_freeze_at, False)
     self.resnet50 = trunk
     # NOTE(review): 3 and 7 look like the first/last pyramid levels --
     # confirm against the FPN signature.
     self.FPN = FPN(trunk, 3, 7)
     self.R_Head = RetinaNet_Head()
     self.R_Anchor = RetinaNet_Anchor()
     self.R_Criteria = RetinaNet_Criteria()
Ejemplo n.º 4
0
 def __init__(self):
     """Build the detector's sub-modules and check the class count.

     Creates a ResNet50 backbone, an FPN on top of it, an RPN and an RCNN
     head, then asserts that ``config.num_classes`` equals 2.
     """
     super().__init__()
     # Hold the backbone in a local so the FPN wraps the very same instance.
     trunk = ResNet50(config.backbone_freeze_at, False)
     self.resnet50 = trunk
     # NOTE(review): 2 and 6 look like the first/last pyramid levels --
     # confirm against the FPN signature.
     self.FPN = FPN(trunk, 2, 6)
     self.RPN = RPN(config.rpn_channel)
     self.RCNN = RCNN()
     # Checked only after all sub-modules have been constructed.
     assert config.num_classes == 2, 'Only support two class(1fg/1bg).'
Ejemplo n.º 5
0
def create_triq_model(n_quality_levels,
                      input_shape=(None, None, 3),
                      backbone='resnet50',
                      transformer_params=(2, 32, 8, 64),
                      maximum_position_encoding=193,
                      vis=False,
                      dropout_rate=0.1):
    """
    Creates the hybrid TRIQ model and prints its summary.

    :param n_quality_levels: number of quality levels, use 5 to predict quality distribution
    :param input_shape: input shape
    :param backbone: backbone net, either 'resnet50' or 'vgg16' (exact match)
    :param transformer_params: Transformer parameters, read by index as
        (num_layers, d_model, num_heads, mlp_dim)
    :param maximum_position_encoding: the maximal number of positional embeddings
    :param vis: flag to visualize attention weight maps
    :param dropout_rate: dropout rate handed to the transformer, 0.1 by default
    :return: TRIQ model
    :raises NotImplementedError: if ``backbone`` is not a supported name
    """
    # Reject an unknown backbone before any layer is created, and say which
    # value was rejected so the caller can fix the call site.
    supported_backbones = ('resnet50', 'vgg16')
    if backbone not in supported_backbones:
        raise NotImplementedError(
            'Unsupported backbone {!r}; expected one of {}'.format(
                backbone, ', '.join(repr(name) for name in supported_backbones)))

    inputs = Input(shape=input_shape)
    if backbone == 'resnet50':
        backbone_model = ResNet50(inputs,
                                  return_feature_maps=False,
                                  return_last_map=True)
    else:
        backbone_model = VGG16(inputs, return_last_map=True)

    # Output of the backbone, fed to the transformer.
    C5 = backbone_model.output

    transformer = TriQImageQualityTransformer(
        num_layers=transformer_params[0],
        d_model=transformer_params[1],
        num_heads=transformer_params[2],
        mlp_dim=transformer_params[3],
        dropout=dropout_rate,
        n_quality_levels=n_quality_levels,
        maximum_position_encoding=maximum_position_encoding,
        vis=vis)
    outputs = transformer(C5)

    model = Model(inputs=inputs, outputs=outputs)
    # Side effect kept from the original: the summary is printed on every call.
    model.summary()
    return model
Ejemplo n.º 6
0
    def __init__(self):
        """Build backbone, FPN, RetinaNet head, anchors and criteria, then
        create random sample inputs.

        For ``config.backbone_freeze_at`` >= 1 the parameters of ``conv1``
        are iterated and ``detach()`` is called on each; for >= 2 the same is
        done for ``layer1``.
        """
        super().__init__()

        # Backbone.
        self.resnet50 = ResNet50()

        # Walk the early backbone stages; each entry is
        # (minimum freeze level, attribute name of the stage on the backbone).
        # NOTE(review): the detached tensor is only bound to the loop
        # variable and never written back to the module, so this does not
        # appear to freeze anything. The earlier `requires_grad = False`
        # variant was disabled -- confirm the intended freezing mechanism.
        for min_level, stage_name in ((1, "conv1"), (2, "layer1")):
            if config.backbone_freeze_at >= min_level:
                for param in getattr(self.resnet50, stage_name).parameters():
                    param = param.detach()

        # Feature pyramid on top of the backbone.
        self.backbone = FPN(self.resnet50)

        # Detection head (an RPN is not built in this variant).
        self.head = RetinaNetHead()

        # Anchor generator.
        self.anchor_generator = RetinaNetAnchorV2()

        # Loss criteria.
        self.criteria = RetinaNetCriteriaV2()

        def _random_f32(shape):
            # Uniform random float32 tensor of the requested shape.
            return mge.tensor(
                np.random.random(shape).astype(np.float32),
                dtype="float32",
            )

        # Random sample inputs, drawn in the order image -> im_info -> gt_boxes.
        # NOTE(review): presumably placeholders for tracing/debugging -- confirm.
        self.inputs = {
            "image": _random_f32([2, 3, 756, 1400]),
            "im_info": _random_f32([2, 6]),
            "gt_boxes": _random_f32([2, 500, 5]),
        }