Exemplo n.º 1
0
Arquivo: hub.py Projeto: zivzone/d2go
def _make_detr(
    backbone_name: str,
    dilation=False,
    num_classes=91,
    mask=False,
    *,
    hidden_dim: int = 256,
    num_queries: int = 100,
):
    """Build a DETR detector, or a DETRsegm segmentation model, from scratch.

    Args:
        backbone_name: name forwarded to ``Backbone`` (e.g. ``"resnet50"``).
        dilation: forwarded to ``Backbone`` as ``dilation``.
        num_classes: number of classes passed to ``DETR``.
        mask: when True, the backbone is built with
            ``return_interm_layers=True`` and the detector is wrapped in
            ``DETRsegm``.
        hidden_dim: transformer width (``d_model``). The sine positional
            encoding is built with ``hidden_dim // 2`` features. This is
            keyword-only and defaults to 256, the previously hard-coded value.
        num_queries: number of object queries passed to ``DETR``. This is
            keyword-only and defaults to 100, the previously hard-coded value.

    Returns:
        ``DETRsegm(detr)`` when ``mask`` is True, otherwise the ``DETR`` model.

    NOTE(review): pretrained checkpoints presumably assume the default
    ``hidden_dim``/``num_queries``. Non-default values are only meaningful when
    training from scratch. Confirm before loading weights.
    """
    backbone = Backbone(backbone_name, train_backbone=True, return_interm_layers=mask, dilation=dilation)
    # NOTE(review): the positional encoding presumably emits features for both
    # spatial axes (2 * (hidden_dim // 2) channels). That sum only equals
    # hidden_dim when hidden_dim is even. Confirm against PositionEmbeddingSine.
    pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
    backbone_with_pos_enc = Joiner(backbone, pos_enc)
    # The joined module exposes the backbone's channel count to DETR.
    backbone_with_pos_enc.num_channels = backbone.num_channels
    transformer = Transformer(d_model=hidden_dim, return_intermediate_dec=True)
    detr = DETR(backbone_with_pos_enc, transformer, num_classes=num_classes, num_queries=num_queries)
    if mask:
        return DETRsegm(detr)
    return detr
Exemplo n.º 2
0
def build_backbone(lr_backbone, masks, backbone, dilation, hidden_dim,
                   position_embedding):
    """Build a feature backbone joined with a positional encoding.

    Args:
        lr_backbone: backbone learning rate. The backbone is trained only when
            this is greater than 0.
        masks: forwarded as ``return_interm_layers`` to the backbone.
        backbone: backbone name. A name containing ``'resnet'`` builds a
            ``Backbone``, and a name containing ``'mobilenet'`` builds an
            ``MNetBackbone``. ``'resnet'`` is checked first.
        dilation: forwarded to ``Backbone``. It is not used for MobileNet.
        hidden_dim: forwarded to ``build_position_encoding``.
        position_embedding: position-encoding spec forwarded to
            ``build_position_encoding``.

    Returns:
        A ``Joiner`` of the backbone and the positional encoding, with
        ``num_channels`` copied from the backbone.

    Raises:
        ValueError: if ``backbone`` contains neither ``'resnet'`` nor
            ``'mobilenet'``.
    """
    is_resnet = 'resnet' in backbone
    # Reject unsupported names before constructing anything. Otherwise the
    # name string would reach Joiner and fail with an unrelated error.
    if not is_resnet and 'mobilenet' not in backbone:
        raise ValueError(
            f"Unsupported backbone {backbone!r}: expected a name containing "
            f"'resnet' or 'mobilenet'")

    train_backbone = lr_backbone > 0
    return_interm_layers = masks
    # The position encoding is built before the backbone, as in the original
    # order, so module construction (and any RNG use it implies) is unchanged.
    pos_enc = build_position_encoding(hidden_dim, position_embedding)
    if is_resnet:
        body = Backbone(backbone, train_backbone, return_interm_layers,
                        dilation)
    else:
        body = MNetBackbone(train_backbone, return_interm_layers)
    model = Joiner(body, pos_enc)
    model.num_channels = body.num_channels
    return model
Exemplo n.º 3
0
 def test_backbone_script(self):
     """Check that a ResNet-50 ``Backbone`` can be compiled by TorchScript.

     The test passes as long as ``torch.jit.script`` does not raise. The
     scripted module itself is not inspected.
     """
     train_backbone, return_interm_layers, dilation = True, False, False
     resnet = Backbone("resnet50", train_backbone, return_interm_layers, dilation)
     torch.jit.script(resnet)  # noqa