Exemple #1
0
def main():
    """Build a DETR model, restore its checkpoint and pass it to ``gen_wts``.

    The model is configured with 91 classes and 100 object queries on top of
    a 6-layer encoder / 6-layer decoder transformer.  Its parameters are read
    from the ``'model'`` entry of ``./detr-r50-e632da11.pth``; the model is
    then moved to the CUDA device, put in eval mode and handed to
    ``gen_wts(model, "detr")``.
    """
    num_classes = 91
    device = torch.device('cuda')

    backbone = build_backbone()

    transformer = Transformer(
        d_model=256,
        dropout=0.1,
        nhead=8,
        dim_feedforward=2048,
        num_encoder_layers=6,
        num_decoder_layers=6,
        normalize_before=False,
        return_intermediate_dec=True,
    )

    model = DETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=100,
        aux_loss=True,
    )

    # Deserialize onto the CPU so loading does not depend on the device the
    # checkpoint was saved from; ``model.to(device)`` below does the transfer.
    checkpoint = torch.load('./detr-r50-e632da11.pth', map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    # Eval mode is set before the model is handed to gen_wts.
    model.eval()

    gen_wts(model, "detr")
Exemple #2
0
def create_model(weights):
    """Construct a DETR model and fill it from a checkpoint file.

    Args:
        weights: Checkpoint location passed to ``torch.load``; the loaded
            object must expose the state dict under the ``'model'`` key.

    Returns:
        The ``DETR`` instance (91 classes, 100 queries) with the checkpoint's
        state dict loaded.
    """
    detr_backbone = build_backbone()
    detr_transformer = Transformer(
        d_model=256,
        return_intermediate_dec=True,
    )
    detector = DETR(
        detr_backbone,
        detr_transformer,
        num_classes=91,
        num_queries=100,
    )

    # The checkpoint is deserialized onto the CPU; only its 'model' entry
    # is used.
    saved = torch.load(weights, map_location='cpu')
    detector.load_state_dict(saved['model'])

    return detector
Exemple #3
0
    aux_loss = True

    backbone = build_backbone(lr_backbone, masks, backbone, dilation, hidden_dim, position_embedding)
    transformer = build_transformer(hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm)
    model = DETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=num_queries,
        aux_loss=aux_loss,
    )
    transform = make_coco_transforms()
    postprocessors = PostProcess()

    checkpoint = torch.load('/home/palm/PycharmProjects/detr/snapshots/1/checkpoint00295.pth')
    model.load_state_dict(checkpoint['model'])

    train_ints, valid_ints, labels, max_box_per_image = create_csv_training_instances(
        '/home/palm/PycharmProjects/algea/dataset/train_annotations',
        '/home/palm/PycharmProjects/algea/dataset/test_annotations',
        '/home/palm/PycharmProjects/algea/dataset/classes',
    )
    # os.listdir()
    all_detections = []
    all_annotations = []
    model.cuda()
    for instance in valid_ints:
        t = time.time()
        all_annotation = all_annotation_from_instance(instance)
        target_image_ori = Image.open(instance["filename"])
        target_image = transform(target_image_ori)