Example #1
def detect(net,
           im_path,
           device,
           thresh=0.5,
           visualize=False,
           timers=None,
           pyramid=False,
           visualization_folder=None):
    """
    Main function to detect faces in an image
    :param net: The trained network
    :param im_path: The path to the image
    :param device: GPU or CPU device to be used
    :param thresh: Detections with a score lower than thresh are ignored
    :param visualize: Whether to visualize the detections
    :param timers: Timers for measuring detection time (if None, new timers will be created)
    :param pyramid: Whether to use an image pyramid during inference
    :param visualization_folder: If set, the visualizations are saved in this folder (when visualize=True)
    :return: cls_dets (bounding boxes concatenated with scores) and the timers
    """

    if not timers:
        timers = {'detect': Timer(), 'misc': Timer()}

    im = cv2.imread(im_path)
    imfname = os.path.basename(im_path)
    sys.stdout.flush()
    timers['detect'].tic()

    if not pyramid:
        im_scale = _compute_scaling_factor(im.shape, cfg.TEST.SCALES[0],
                                           cfg.TEST.MAX_SIZE)
        im_blob = _get_image_blob(im, [im_scale])[0]
        ssh_rois = forward(net, im_blob, im_scale, device, thresh)

    else:
        raise NotImplementedError('Pyramid inference is not implemented')

    timers['detect'].toc()
    timers['misc'].tic()

    nms_keep = nms(ssh_rois, cfg.TEST.RPN_NMS_THRESH)
    cls_dets = ssh_rois[nms_keep, :]

    if visualize:
        plt_name = os.path.splitext(imfname)[0] + '_detections_{}'.format(
            "SSH pytorch")
        visusalize_detections(im,
                              cls_dets,
                              plt_name=plt_name,
                              visualization_folder=visualization_folder)
    timers['misc'].toc()
    return cls_dets, timers
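
A minimal usage sketch for the detect function above, assuming the network class name SSH, the checkpoint layout ('model_state_dict') and the file names shown; these are illustrative assumptions, not part of the original example.

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = SSH()  # hypothetical constructor of the trained network
net.load_state_dict(torch.load('ssh.pth', map_location=device)['model_state_dict'])
net.to(device)
net.eval()

with torch.no_grad():
    cls_dets, timers = detect(net, 'demo.jpg', device, thresh=0.5,
                              visualize=True, visualization_folder='vis')
print('{} detections in {:.3f}s'.format(len(cls_dets), timers['detect'].average_time))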
Example #2
    def preprocessImage(self, im):
        """
        Perform basic processing on input image such as 
        scaling the image to appropriate sizes

        Arguments:
            im {[numpy.ndarray]} -- The image/frame to be processed by SSH detector

        Returns:
            im_info [torch.Tensor] -- array containing shapes of image_blobs
            im_data [torch.Tensor] -- the actual image data extracted from the image_blob
            im_scale [int] -- The scale factor
        """
        im_scale = _compute_scaling_factor(
            im.shape, cfg.TEST.SCALES[0], cfg.TEST.MAX_SIZE)
        im_blob = _get_image_blob(im, [im_scale])[0]
        im_info = np.array(
            [[im_blob['data'].shape[2], im_blob['data'].shape[3], im_scale]])
        im_data = im_blob['data']

        im_info = torch.from_numpy(im_info).to(self.device)
        im_data = torch.from_numpy(im_data).to(self.device)
        return im_info, im_data, im_scale
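
A short usage sketch for preprocessImage, assuming the method belongs to a detector wrapper class that stores the target device in self.device and the trained SSH network in self.net; the class name SSHDetector and the net attribute are assumptions for illustration.

import cv2
import torch

detector = SSHDetector()  # hypothetical wrapper exposing preprocessImage and .net
frame = cv2.imread('frame.jpg')
im_info, im_data, im_scale = detector.preprocessImage(frame)
with torch.no_grad():
    ssh_rois = detector.net(im_data, im_info)
# detections are in the scaled image; divide the coordinates by im_scale to map them back
boxes = ssh_rois[0, :, :4] / im_scale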
Example #3
    if os.path.isfile(saved_model_path):
        check_point = load_check_point(saved_model_path)
        net.load_state_dict(check_point['model_state_dict'])
        # print the name and shape of every parameter tensor in the loaded model
        for param_tensor in net.state_dict():
            print(param_tensor, "\t", net.state_dict()[param_tensor].size())

    net.to(device)
    net.eval()

    with torch.no_grad():

        im = cv2.imread(filepath)
        im_scale = _compute_scaling_factor(im.shape, cfg.TEST.SCALES[0],
                                           cfg.TEST.MAX_SIZE)
        im_blob = _get_image_blob(im, [im_scale])[0]

        im_info = np.array(
            [[im_blob['data'].shape[2], im_blob['data'].shape[3], im_scale]])
        im_data = im_blob['data']

        im_info = torch.from_numpy(im_info).to(device)
        im_data = torch.from_numpy(im_data).to(device)

        batch_size = im_data.size()[0]
        ssh_rois = net(im_data, im_info)

        # keep only the ROIs whose score (column 4) exceeds the threshold
        inds = (ssh_rois[:, :, 4] > thresh)
        # inds = inds.unsqueeze(2).expand(batch_size, inds.size()[1], 5)
        # ssh_roi_keep = ssh_rois[inds].view(batch_size, -1, 5)
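
The commented-out lines above sketch how the thresholded ROIs could be gathered. Below is a minimal completion of that step for the single-image case used here (batch size 1), using only names already in scope; mapping the boxes back to the original image scale is an assumption about the intended use, not something the snippet itself does.

        # ssh_rois has shape (batch, num_rois, 5); column 4 is the detection score
        ssh_roi_keep = ssh_rois[0][inds[0]]            # rows whose score exceeds thresh
        boxes = ssh_roi_keep[:, :4] / im_scale         # map boxes back to the original image size
        scores = ssh_roi_keep[:, 4]
        print('{} detections above {:.2f}'.format(boxes.size(0), thresh))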
Example #4
def train(net, optimizer, imdb, roidb, arg):
    max_iters = arg.max_iters
    iter = 1
    display_interval = cfg.TRAIN.DISPLAY
    train_data = RoIDataLayer(roidb, imdb.num_classes)

    loss_sum = 0
    m3_ssh_cls_loss_sum = 0
    m3_bbox_loss_sum = 0
    m2_ssh_cls_loss_sum = 0
    m2_bbox_loss_sum = 0
    m1_ssh_cls_loss_sum = 0
    m1_bbox_loss_sum = 0
    timer = {"forward": Timer(), "data": Timer()}

    im = cv2.imread("/home/dwang/SynologyDrive/pyt_example/data/datasets/wider/WIDER_train/images/28--Sports_Fan/28_Sports_Fan_Sports_Fan_28_39.jpg")

    im_scale = _compute_scaling_factor(im.shape, cfg.TRAIN.SCALES[0], cfg.TRAIN.MAX_SIZE)

    # ground-truth box given as [x, y, w, h]; convert it to [x1, y1, x2, y2, label]
    bbox = [122, 2, 752, 688]
    bbox = np.array([[bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3], 1]], np.float64)
    bbox *= im_scale
    bbox[:, 4] = 1  # restore the class label, which was also multiplied by im_scale
    im_blob = _get_image_blob(im, [im_scale])[0]

    im_info = np.array([[im_blob['data'].shape[2], im_blob['data'].shape[3], im_scale]])
    im_data = im_blob['data']

    im_data = torch.from_numpy(im_data).to(device)
    im_info = torch.from_numpy(im_info).to(device)
    # add a batch dimension to the ground-truth boxes
    gt_boxes = torch.from_numpy(bbox).to(device).unsqueeze(0).float()


    # img = np.squeeze(blobs['data'])
    #
    # img=img.transpose(1,2,0)
    # for i in range(len(blobs['gt_boxes'])):
    #     pt1 = tuple(blobs['gt_boxes'][i, 0:2])
    #     pt2 = tuple(blobs['gt_boxes'][i, 2:4])
    #     cv2.rectangle(img, pt1, pt2, (255, 255, 255))
    # cv2.imwrite("train.jpg", img)

    optimizer.zero_grad()
    # start timing the forward/backward pass (the toc() call below needs a matching tic())
    timer["forward"].tic()

    m3_ssh_cls_loss, m2_ssh_cls_loss, m1_ssh_cls_loss, \
    m3_bbox_loss, m2_bbox_loss, m1_bbox_loss = net(im_data, im_info, gt_boxes)

    # total loss: classification and box-regression losses from the three SSH detection modules (M1-M3)
    loss = (m3_ssh_cls_loss + m2_ssh_cls_loss + m1_ssh_cls_loss +
            m3_bbox_loss + m2_bbox_loss + m1_bbox_loss)

    m3_ssh_cls_loss_sum += m3_ssh_cls_loss.item()
    m3_bbox_loss_sum += m3_bbox_loss.item()
    m2_ssh_cls_loss_sum += m2_ssh_cls_loss.item()
    m2_bbox_loss_sum += m2_bbox_loss.item()
    m1_ssh_cls_loss_sum += m1_ssh_cls_loss.item()
    m1_bbox_loss_sum += m1_bbox_loss.item()

    loss.backward()

    torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
    optimizer.step()

    loss_sum += loss.item()

    timer["forward"].toc()

    # if m3_bbox_loss.item() == 0 :
    #     img = np.squeeze(blobs['data'])
    #
    #     img=img.transpose(1,2,0)
    #     for i in range(len(blobs['gt_boxes'])):
    #         pt1 = tuple(blobs['gt_boxes'][i, 0:2])
    #         pt2 = tuple(blobs['gt_boxes'][i, 2:4])
    #         cv2.rectangle(img, pt1, pt2, (255, 255, 255))
    #     cv2.imwrite("zero/loss_0_{}.jpg".format(iter), img)
    #     f = open("zero/loss_0_{}.txt".format(iter), "a")
    #     f.write(blobs['file_path'])


    if iter % display_interval == 0:
        loss_average = loss_sum / display_interval
        m3_ssh_cls_loss_average = m3_ssh_cls_loss_sum / display_interval
        m3_bbox_loss_average = m3_bbox_loss_sum / display_interval
        m2_ssh_cls_loss_average = m2_ssh_cls_loss_sum / display_interval
        m2_bbox_loss_average = m2_bbox_loss_sum / display_interval
        m1_ssh_cls_loss_average = m1_ssh_cls_loss_sum / display_interval
        m1_bbox_loss_average = m1_bbox_loss_sum / display_interval

        loss_sum = 0
        m3_ssh_cls_loss_sum = 0
        m3_bbox_loss_sum = 0
        m2_ssh_cls_loss_sum = 0
        m2_bbox_loss_sum = 0
        m1_ssh_cls_loss_sum = 0
        m1_bbox_loss_sum = 0

        print("------------------------iteration {}-----------{} left---------".format(iter, max_iters - iter))
        print("Average per iter: {:.4f} second.   ETA: {:.4f} hours".format(timer["forward"].average_time,
                                                                            (max_iters - iter) * (
                                                                                timer["forward"].average_time) / (
                                                                                        60 * 60)))
        print("Average data load time: {:.4f}".format(timer["data"].average_time))
        print('loss:{}\nm3 cls:{}\nm3 box:{}\nm2 cls:{}\nm2 box:{}'
              '\nm1 cls:{}\nm1 box:{} '.format(loss_average, m3_ssh_cls_loss_average, m3_bbox_loss_average,
                                               m2_ssh_cls_loss_average, m2_bbox_loss_average,
                                               m1_ssh_cls_loss_average, m1_bbox_loss_average))
        timer["forward"].reset()
        timer["data"].reset()

    if iter % cfg.TRAIN.CHECKPOINT == 0:
        save_check_point(arg.model_save_path, iter, loss, net, optimizer)
        print("check point saved")