Example 1
def demo_plt(img_id=0):
    net = build_msc('test', 11)  # initialize SSD
    print(net)
    net.load_weights(
        '/media/sunwl/Datum/Projects/GraduationProject/Multi_Scale_CNN_512/weights/v2_vhr.pth'
    )
    testset = VHRDetection(VHRroot, ['test2'], None, AnnotationTransform_VHR)
    # image = testset.pull_image(img_id)
    image = cv2.imread('demos/038.jpg')
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # View the sampled input image before transform
    plt.figure(figsize=(10, 10))
    plt.imshow(rgb_image)

    x = cv2.resize(rgb_image, (512, 512)).astype(np.float32)  # resize to the 512x512 network input
    x -= (104.0, 117.0, 123.0)                  # subtract the VGG-style channel means
    x = x[:, :, ::-1].copy()                    # back to BGR channel order for the VGG base
    x = torch.from_numpy(x).permute(2, 0, 1)    # HWC -> CHW

    xx = Variable(x.unsqueeze(0))  # wrap tensor in Variable
    if torch.cuda.is_available():
        xx = xx.cuda()
    y = net(xx)
    #
    plt.figure(figsize=(10, 10))
    colors = plt.cm.hsv(np.linspace(0, 1, 11)).tolist()
    plt.imshow(rgb_image.astype(np.uint8))  # plot the image for matplotlib
    currentAxis = plt.gca()

    detections = y.data

    # scale each detection back up to the image
    scale = torch.Tensor(rgb_image.shape[1::-1]).repeat(2)
    for i in range(detections.size(1)):
        j = 0
        while j < detections.size(2) and detections[0, i, j, 0] >= 0.5:  # guard against running past top_k
            score = detections[0, i, j, 0]
            label_name = labels[i - 1]
            display_txt = '%s: %.2f' % (label_name, score)
            pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
            color = colors[i]
            coords = (pt[0], pt[1]), pt[2] - pt[0] + 1, pt[3] - pt[1] + 1
            currentAxis.add_patch(
                plt.Rectangle(*coords,
                              fill=False,
                              edgecolor=color,
                              linewidth=2))
            currentAxis.text(pt[0],
                             pt[1],
                             display_txt,
                             bbox={
                                 'facecolor': color,
                                 'alpha': 0.5
                             })
            j += 1
    plt.savefig(
        '/media/sunwl/Datum/Projects/GraduationProject/Multi_Scale_CNN_512/outputs/{:03}.png'
        .format(img_id))
    plt.show()
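
The resize / mean-subtraction / channel-flip / permute block above is repeated verbatim in the next example. A minimal sketch of the same preprocessing as a standalone helper, assuming the 512x512 input size and the (104, 117, 123) means used above (the function name preprocess_vhr is made up for illustration):

import cv2
import numpy as np
import torch


def preprocess_vhr(bgr_image, size=512, mean=(104.0, 117.0, 123.0)):
    # Mirrors the inline preprocessing in demo_plt / demo_cv2.
    rgb = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    x = cv2.resize(rgb, (size, size)).astype(np.float32)
    x -= mean                                   # subtract VGG-style channel means
    x = x[:, :, ::-1].copy()                    # back to BGR channel order for the VGG base
    x = torch.from_numpy(x).permute(2, 0, 1)    # HWC -> CHW
    return x.unsqueeze(0)                       # add the batch dimension

With such a helper, xx = preprocess_vhr(cv2.imread('demos/038.jpg')) would replace the whole block between the resize call and the unsqueeze.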
Example 2
def demo_cv2(img_id=0):
    net = build_msc('test', 11)  # initialize SSD
    print(net)
    # net.load_weights('/media/sunwl/Datum/Projects/GraduationProject/Multi_Scale_CNN_512/weights/v2_vhr.pth')
    net.load_weights(
        '/media/sunwl/Datum/Projects/GraduationProject/Multi_Scale_CNN_512/weights/msc512_vhr_80000.pth'
    )
    testset = VHRDetection(VHRroot, ['test2'], None, AnnotationTransform_VHR)
    image = testset.pull_image(img_id)
    # image = cv2.imread('demos/089.jpg')
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    x = cv2.resize(rgb_image, (512, 512)).astype(np.float32)
    x -= (104.0, 117.0, 123.0)
    x = x[:, :, ::-1].copy()
    x = torch.from_numpy(x).permute(2, 0, 1)

    xx = Variable(x.unsqueeze(0))  # wrap tensor in Variable
    if torch.cuda.is_available():
        xx = xx.cuda()
    y = net(xx)
    colors = plt.cm.hsv(np.linspace(0, 1, 11)).tolist()
    detections = y.data

    # scale each detection back up to the image
    scale = torch.Tensor(rgb_image.shape[1::-1]).repeat(2)
    # convert back to BGR so OpenCV drawing and display use the expected channel order
    bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    im2show = np.copy(bgr_image)
    for i in range(detections.size(1)):
        j = 0
        while j < detections.size(2) and detections[0, i, j, 0] >= 0.6:  # guard against running past top_k
            score = detections[0, i, j, 0]
            label_name = labels[i - 1]
            display_txt = '%s: %.2f' % (label_name, score)
            pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
            color = colors[i]
            color = [int(c * 255) for c in color[:3]]
            # cv2.rectangle expects integer pixel coordinates
            coords = tuple(int(v) for v in pt[:4])
            cv2.rectangle(im2show,
                          coords[0:2],
                          coords[2:4],
                          color,
                          thickness=2)
            cv2.putText(im2show,
                        display_txt, (int(coords[0]), int(coords[1]) - 3),
                        cv2.FONT_HERSHEY_PLAIN,
                        1.0,
                        color,
                        thickness=1)
            j += 1
    cv2.imshow('original', bgr_image)
    cv2.imshow('demo', im2show)
    cv2.imwrite(
        os.path.join(
            '/media/sunwl/Datum/Projects/GraduationProject/Multi_Scale_CNN_512',
            "outputs", "{:03d}.jpg".format(img_id)), im2show)
    cv2.waitKey(0)
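
Both demos walk the raw output tensor the same way: detections[0, i, j, 0] is the confidence of the j-th detection of class i, and indices 1:5 hold the normalized box corners. Below is a short sketch that collects everything above a threshold into plain Python tuples, assuming that layout and that per-class detections are sorted by descending score (as the while loops above imply); the helper name collect_detections is made up:

def collect_detections(detections, scale, labels, threshold=0.5):
    # detections: (1, num_classes, top_k, 5) tensor; scale: (w, h, w, h) tensor.
    results = []
    for i in range(1, detections.size(1)):      # class 0 is assumed to be background
        for j in range(detections.size(2)):
            score = detections[0, i, j, 0]
            if score < threshold:
                break                           # remaining entries are lower-scoring
            pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
            results.append((labels[i - 1], float(score), tuple(pt)))
    return results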
Example 3
num_classes = len(VHR_CLASSES) + 1
batch_size = args.batch_size
accum_batch_size = 16
iter_size = accum_batch_size // batch_size  # integer division so iter_size is a whole count of steps
max_iter = 120000
weight_decay = args.weight_decay
stepvalues = (80000, 100000, 120000)
gamma = args.gamma
momentum = args.momentum

if args.visdom:
    import visdom

    viz = visdom.Visdom()

msc_net = build_msc('train', num_classes)
net = msc_net

if args.cuda:
    net = torch.nn.DataParallel(msc_net)
    cudnn.benchmark = True

if args.cuda:
    net = net.cuda()


def xavier(param):
    init.xavier_uniform(param)


def weights_init(model):
    # Body assumed (the snippet is cut off here): the usual pattern is to
    # Xavier-initialize convolution weights and zero their biases.
    if isinstance(model, torch.nn.Conv2d):
        xavier(model.weight.data)
        model.bias.data.zero_()
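
For context, here is a minimal sketch of how weights_init and the hyper-parameters parsed above are typically wired together in ssd.pytorch-style training scripts. The extras / loc / conf sub-module names and args.lr are assumptions (they are not visible in this snippet), which is why the sketch guards them with hasattr:

import torch.optim as optim

# Initialize only the layers trained from scratch, leaving any pretrained
# base untouched, then build the SGD optimizer from the parsed arguments.
for name in ('extras', 'loc', 'conf'):          # assumed sub-module names
    if hasattr(msc_net, name):
        getattr(msc_net, name).apply(weights_init)

optimizer = optim.SGD(net.parameters(),
                      lr=args.lr,               # assumed to be parsed alongside the other args
                      momentum=momentum,
                      weight_decay=weight_decay)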
Example 4
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    evaluate_detections(all_boxes, output_dir, dataset)
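
The all_boxes structure pickled above follows the usual SSD / Faster R-CNN evaluation convention (assumed here, since its construction is not shown in this snippet): all_boxes[cls][img] is an (N, 5) float array of [x1, y1, x2, y2, score] rows, with class index 0 reserved for background. A minimal illustration of that layout:

import numpy as np

num_classes = 11                                # len(VHR_CLASSES) + 1, as elsewhere in these examples
num_images = 10                                 # placeholder; normally len(dataset)
all_boxes = [[np.empty((0, 5), dtype=np.float32) for _ in range(num_images)]
             for _ in range(num_classes)]
# e.g. the detections of class 3 in image 7 would be stored as
# all_boxes[3][7] = np.array([[x1, y1, x2, y2, score], ...], dtype=np.float32)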


def evaluate_detections(box_list, output_dir, dataset):
    write_vhr_results_file(box_list, dataset)
    do_python_eval(output_dir)


if __name__ == '__main__':
    # load net
    num_classes = len(VHR_CLASSES) + 1  # +1 background
    net = build_msc('test', num_classes)  # initialize SSD
    net.load_state_dict(torch.load(args.trained_model))
    net.eval()
    print('Finished loading model!')
    print(net)
    # load data
    dataset = VHRDetection(args.vhr_root, ['test2'],
                           BaseTransform(512, dataset_mean),
                           AnnotationTransform_VHR())

    if args.cuda:
        net = net.cuda()
        cudnn.benchmark = True
    # evaluation
    test_net(args.save_folder,
             net,