예제 #1
0
# Build a Faster R-CNN whose backbone is the convolutional trunk of `net`
# (presumably a torchvision ResNet-18/34 given out_channels=512 below —
# confirm). The pooling/classifier head is dropped, so the backbone emits a
# single feature map.
backbone = torch.nn.Sequential(
    net.conv1, net.bn1, net.relu, net.maxpool,
    net.layer1, net.layer2, net.layer3, net.layer4,
)

# FasterRCNN needs the channel count of the backbone output.
backbone.out_channels = 512

# AnchorGenerator expects ONE tuple of sizes / aspect ratios PER feature map.
# `((32, 64, 128, 256, 512))` without a trailing comma is just a flat tuple
# (extra parentheses do not nest), which torchvision reads as five feature maps
# with one size each. With this single-map backbone, that either silently uses
# only the 32px anchors (older torchvision) or fails an assertion at forward
# time (newer). The trailing commas make these proper tuples-of-tuples.
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

# A single-Tensor backbone output is stored by torchvision (>= 0.5) in an
# OrderedDict under the string key '0' (it was the integer 0 in 0.3/0.4),
# so featmap_names must name '0' for RoI pooling to find the feature map.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                output_size=7,
                                                sampling_ratio=2)

model = FasterRCNN(backbone,
                   num_classes=91,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)
model = model.cuda()



# Loss, optimizer and training driver for the model built above.
# NOTE(review): torchvision's FasterRCNN computes its own losses in train mode;
# whether Solver actually uses `criterion` cannot be determined from here.
criterion = torch.nn.CrossEntropyLoss().cuda()

# Plain SGD with momentum and L2 weight decay over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

# Positional arguments are passed in the same order Solver received them
# originally: model, checkpoint path, train/val loaders, loss, optimizer.
solver = Solver(
    model,
    "./models/rcnn_4.pth",
    trainLoader,
    valLoader,
    criterion,
    optimizer,
    logfile="./logs/rcnn_resnet18.log",
    print_freq=20,
    save_name="rcnn",
)

# Run training; 4 is presumably the number of epochs — confirm in Solver.train.
solver.train(4)



# Rebuild the detector for inference. It reuses `backbone` and
# `anchor_generator` from the training setup above, then loads trained
# weights from a checkpoint.
#
# The backbone returns a single Tensor. torchvision (>= 0.5) wraps it in an
# OrderedDict under the string key '0' (it was the integer 0 in 0.3/0.4).
# More generally, a backbone may return an OrderedDict[Tensor], and
# featmap_names selects which of those feature maps RoI pooling uses.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                output_size=7,
                                                sampling_ratio=2)

# Put the pieces together inside a FasterRCNN model.
# num_classes=2 is presumably background + one object class — confirm.
model = FasterRCNN(backbone,
                   num_classes=2,
                   rpn_anchor_generator=anchor_generator,
                   min_size=800,
                   max_size=1200,
                   box_roi_pool=roi_pooler,
                   box_detections_per_img=200)

model.cuda()
model.load_state_dict(torch.load('./checkpoint/efficient_model_L_7.pth'))
# Inference mode: torchvision detection models return detections, not losses.
model.eval()

# Run the detector on one image and draw the predicted boxes.
# `img` (a tensor) and `imge` (presumably the same image as an RGB PIL image /
# array) are defined elsewhere — NOTE(review): confirm they are the same image.
start = time.time()
print(img.size())
# Inference only: disable autograd so no graph / activation buffers are kept.
with torch.no_grad():
    results = model([img.cuda()])
print("inference time: {:.3f}s".format(time.time() - start))

open_cv_image = np.array(imge)
# OpenCV expects BGR channel order; the source image is treated as RGB.
open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_RGB2BGR)
for box in results[0]['boxes']:
    # Each box is (x1, y1, x2, y2); it is drawn via the (x, y, w, h)
    # rect overload of cv2.rectangle.
    x1, y1, x2, y2 = (int(v) for v in box[:4].tolist())
    cv2.rectangle(open_cv_image, (x1, y1, x2 - x1, y2 - y1),
                  (255, 225, 0), 2)

# Save before displaying, because the display below blocks on a key press.
cv2.imwrite("demo.jpg", open_cv_image)
cv2.imshow("sd", open_cv_image)
# imshow only paints while the HighGUI event loop runs; without waitKey the
# window never renders in a script.
cv2.waitKey(0)
cv2.destroyAllWindows()