コード例 #1
0
ファイル: yolo_api.py プロジェクト: zihan987/AlphAction
    def write_results(self,
                      prediction,
                      confidence,
                      num_classes,
                      nms=True,
                      nms_conf=0.4):
        """Threshold raw YOLO output and run class-wise NMS over the batch.

        Args:
            prediction: raw detector output, laid out (per the original
                layout note) as (batch_size, num_boxes,
                (xc, yc, w, h, box_conf, class scores...)).
            confidence: objectness threshold; a box is kept only if its
                column-4 score is strictly greater than this value.
            num_classes: number of class-score columns that follow column 4.
            nms: whether to apply non-maximum suppression per class.
            nms_conf: IoU threshold used by NMS.

        Returns:
            A tensor of rows (batch_ind, x1, y1, x2, y2, box_conf,
            class_score, class_idx), or the int 0 when no detection
            survives.  Detections whose class index is 0 are skipped.
        """
        args = self.detector_opt
        # The compiled nms op is used everywhere except Windows, where a
        # pure-torch greedy loop built on bbox_iou is used instead.
        use_nms_op = platform.system() != 'Windows'

        # Zero out every box whose objectness does not exceed the threshold.
        conf_mask = (prediction[:, :, 4] > confidence).float().unsqueeze(2)
        prediction = prediction * conf_mask

        # (xc, yc, w, h) -> (x1, y1, x2, y2); remaining columns untouched.
        center = prediction[:, :, :2]
        half_size = prediction[:, :, 2:4] / 2
        prediction = torch.cat(
            (center - half_size, center + half_size, prediction[:, :, 4:]), 2)

        detections = []
        for ind in range(prediction.size(0)):
            image_pred = prediction[ind]

            # Keep only the best class per box:
            # image_pred: (n, (x1, y1, x2, y2, c, s, idx of cls))
            max_conf, max_conf_score = torch.max(
                image_pred[:, 5:5 + num_classes], 1)
            image_pred = torch.cat(
                (image_pred[:, :5],
                 max_conf.float().unsqueeze(1),
                 max_conf_score.float().unsqueeze(1)), 1)

            # Drop the rows zeroed by the confidence mask. Boolean-mask
            # selection always stays 2-D, even for a single surviving row.
            image_pred_ = image_pred[image_pred[:, 4] != 0]
            if image_pred_.size(0) == 0:
                continue

            # NMS is performed class by class.
            for cls in unique(image_pred_[:, -1]):
                if cls == 0:
                    continue
                image_pred_class = image_pred_[image_pred_[:, -1] == cls]

                # Highest objectness first; both NMS paths rely on this.
                conf_sort_index = torch.sort(image_pred_class[:, 4],
                                             descending=True)[1]
                image_pred_class = image_pred_class[conf_sort_index]

                if nms:
                    if use_nms_op:
                        # nms op input: (n, (x1, y1, x2, y2, c));
                        # second return value is the kept indices.
                        _, inds = nms_wrapper.nms(image_pred_class[:, :5],
                                                  nms_conf)
                        image_pred_class = image_pred_class[inds]
                    else:
                        # Greedy NMS: keep the top box, drop every remaining
                        # box whose IoU with it reaches the threshold.
                        kept = []
                        remaining = image_pred_class
                        while remaining.size(0):
                            kept.append(remaining[0].unsqueeze(0))
                            if remaining.size(0) == 1:
                                break
                            ious = bbox_iou(kept[-1], remaining[1:], args)
                            remaining = remaining[1:][ious < nms_conf]
                        image_pred_class = torch.cat(kept).data

                # Prefix each row with its image index so detections from
                # the whole batch can live in one flat tensor.
                batch_ind = image_pred_class.new(image_pred_class.size(0),
                                                 1).fill_(ind)
                detections.append(torch.cat((batch_ind, image_pred_class), 1))

        if not detections:
            return 0
        # output: (n, (batch_ind, x1, y1, x2, y2, c, s, idx of cls))
        return torch.cat(detections)
コード例 #2
0
def write_results_half(prediction,
                       confidence,
                       num_classes,
                       nms=True,
                       nms_conf=0.4):
    """Half-precision variant: threshold YOLO output and run class-wise NMS.

    Args:
        prediction: raw detector output; presumably laid out as
            (batch_size, num_boxes, (xc, yc, w, h, box_conf, class
            scores...)) like the float variant -- TODO confirm.
        confidence: objectness threshold; a box is kept only if its
            column-4 score is strictly greater than this value.
        num_classes: number of class-score columns that follow column 4.
        nms: whether to apply non-maximum suppression per class.
        nms_conf: IoU threshold used by NMS.

    Returns:
        A tensor of rows (batch_ind, x1, y1, x2, y2, box_conf, class_score,
        class_idx), or the int 0 when no detection survives.
    """
    # Zero out every box whose objectness does not exceed the threshold.
    conf_mask = (prediction[:, :, 4] > confidence).half().unsqueeze(2)
    prediction = prediction * conf_mask

    # (xc, yc, w, h) -> (x1, y1, x2, y2); remaining columns untouched.
    center = prediction[:, :, :2]
    half_size = prediction[:, :, 2:4] / 2
    prediction = torch.cat(
        (center - half_size, center + half_size, prediction[:, :, 4:]), 2)

    detections = []
    for ind in range(prediction.size(0)):
        image_pred = prediction[ind]

        # Keep only the best class per box:
        # (x1, y1, x2, y2, c, s, idx of cls)
        max_conf, max_conf_score = torch.max(
            image_pred[:, 5:5 + num_classes], 1)
        image_pred = torch.cat(
            (image_pred[:, :5],
             max_conf.half().unsqueeze(1),
             max_conf_score.half().unsqueeze(1)), 1)

        # Drop rows zeroed by the confidence mask. Boolean-mask selection
        # stays 2-D even when exactly one row survives (squeezed index
        # tensors collapsed that case to 1-D and crashed below).
        image_pred_ = image_pred[image_pred[:, 4] != 0]
        if image_pred_.size(0) == 0:
            continue

        # NMS is performed class by class.
        img_classes = unique(image_pred_[:, -1].long()).half()
        for cls in img_classes:
            image_pred_class = image_pred_[image_pred_[:, -1] == cls]

            # Highest objectness first; greedy NMS relies on this order.
            conf_sort_index = torch.sort(image_pred_class[:, 4],
                                         descending=True)[1]
            image_pred_class = image_pred_class[conf_sort_index]

            if nms:
                # Greedy NMS: keep the top box, drop every remaining box
                # whose IoU with it is not below the threshold, repeat.
                kept = []
                remaining = image_pred_class
                while remaining.size(0):
                    kept.append(remaining[0].unsqueeze(0))
                    if remaining.size(0) == 1:
                        break
                    ious = bbox_iou(kept[-1], remaining[1:])
                    remaining = remaining[1:][ious < nms_conf]
                image_pred_class = torch.cat(kept)

            # Prefix each row with its image index so detections from the
            # whole batch can live in one flat tensor.
            batch_ind = image_pred_class.new(image_pred_class.size(0),
                                             1).fill_(ind)
            detections.append(torch.cat((batch_ind, image_pred_class), 1))

    # Previously an uninitialised (1, C+1) tensor leaked out here when
    # nothing was detected; report "no detections" as 0 instead.
    if not detections:
        return 0
    return torch.cat(detections)