Пример #1
0
def training(epoch, device, model, dataset_folder, sample_duration, spatial_transform, temporal_transform, boxes_file, splt_txt_path, cls2idx, batch_size, n_threads, lr, mode = 1):
    """Run one optimisation pass over the 'train' split.

    Besides its arguments this function reads the module-level names
    ``vid2idx``, ``vid_names``, ``n_devs`` and ``optimizer``; they must be
    defined before it is called.

    Args:
        epoch: Zero-based epoch index; logged as ``epoch + 1``.
        device: Device the per-batch metadata and boxes are moved to.
        model: Called as ``model(n_devs, dataset_folder, vid_names, clips,
            vid_id, boxes, mode, cls2idx, n_actions, n_frames, h, w)`` and
            expected to return ``(tubes, prob_out, cls_loss)``.
        dataset_folder: Forwarded to ``video_names`` and to the model.
        boxes_file: Forwarded to ``video_names``.
        splt_txt_path: Split file path forwarded to ``video_names``.
        cls2idx: Forwarded to ``video_names`` (as ``classes_idx``) and to
            the model.
        batch_size: Batch size of the data loader.
        lr: Only echoed in the log line.
        mode: Forwarded unchanged to the model.
        sample_duration, spatial_transform, temporal_transform, n_threads:
            Accepted for interface compatibility; not used in this body.

    Returns:
        ``(model, loss_temp)`` where ``loss_temp`` is the sum over all
        batches of the mean classification loss.
    """
    vid_name_loader = video_names(dataset_folder, splt_txt_path, boxes_file, vid2idx, mode='train', classes_idx=cls2idx)
    # NOTE(review): the worker count is fixed at 32; ``n_threads`` is not consulted.
    data_loader = torch.utils.data.DataLoader(vid_name_loader, batch_size=batch_size,
                                              shuffle=True, num_workers=32, pin_memory=True)

    model.train()
    loss_temp = 0
    n_steps = 0  # batches actually processed; keeps the average defined for an empty loader

    for data in data_loader:
        vid_id, clips, boxes, n_frames, n_actions, h, w, target = data

        vid_id = vid_id.to(device)
        boxes = boxes.to(device)
        n_frames = n_frames.to(device)
        n_actions = n_actions.int().to(device)

        # ``clips``, ``h`` and ``w`` are handed to the model exactly as the
        # loader produced them (they are not moved to ``device`` here).
        tubes, prob_out, cls_loss = model(n_devs, dataset_folder,
                                          vid_names, clips, vid_id,
                                          boxes,
                                          mode, cls2idx, n_actions, n_frames, h, w)

        # cls_loss may hold several values (reduced with mean() to a scalar).
        loss = cls_loss.mean()
        loss_temp += loss.item()

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n_steps += 1

    print('Train Epoch: {} \tLoss: {:.6f}\t lr : {:.6f}'.format(
        epoch+1, loss_temp/max(n_steps, 1), lr))

    return model, loss_temp
Пример #2
0
def _safe_ratio(hits, misses):
    """Return hits / (hits + misses), or 0 when both counts are zero."""
    return float(hits) / (hits + misses) if hits > 0 or misses > 0 else 0


def _count_at_threshold(iou_thresh, gt_max_overlaps, max_overlaps,
                        argmax_overlaps, non_empty_indices, preds_, targ):
    """Score one sample at one IoU threshold.

    Returns ``(detected, tp, fp, fn)``:
      * detected - non-empty ground-truth tubes whose best overlap exceeds
        ``iou_thresh``;
      * tp - number of distinct labels among correctly classified,
        overlapping tubes;
      * fp - surplus correctly classified tubes beyond the first per label;
      * fn - non-background predictions without an overlapping ground truth,
        plus overlapping tubes whose label differs from the target.
    """
    detected = gt_max_overlaps[non_empty_indices].gt(iou_thresh).sum().item()

    overlapping = max_overlaps.gt(iou_thresh)
    hit_idx = overlapping.nonzero().view(-1)
    miss_idx = overlapping.eq(0).nonzero().view(-1)

    # Non-background tubes that overlap no ground-truth tube.
    fn = preds_[miss_idx].ne(0).sum().item()

    hit_preds = preds_[hit_idx]
    is_correct = hit_preds == targ[argmax_overlaps[hit_idx]]
    fn += is_correct.eq(0).sum().item()

    correct_preds = hit_preds[is_correct.nonzero().view(-1)]
    tp = 0
    fp = 0
    for label in torch.unique(correct_preds):
        # One tube per label counts as the true positive; the rest are duplicates.
        fp += correct_preds.eq(label).sum().item() - 1
        tp += 1
    return detected, tp, fp, fn


def validation(epoch, device, model, dataset_folder, sample_duration,
               spatial_transform, temporal_transform, boxes_file,
               splt_txt_path, cls2idx, batch_size, n_threads):
    """Evaluate ``model`` on the 'test' split at IoU thresholds 0.5, 0.4 and 0.3.

    Prints a precision/recall report and appends the same report to
    ``../images_etc/validation_jhmdb_linear.txt``. Returns ``None``.

    Besides its arguments this function reads the module-level names
    ``vid2idx``, ``vid_names``, ``n_devs``, ``video_names`` and
    ``tube_overlaps``. The loader batch size is ``n_devs``; ``batch_size``,
    ``n_threads``, ``sample_duration``, ``spatial_transform`` and
    ``temporal_transform`` are accepted but not used in this body.
    """
    # (report header line, IoU threshold)
    thresholds = (
        ('| Threshold : 0.5       |', 0.5),
        ('| Threshold : 0.4       |', 0.4),
        ('| Threshold : 0.3       |', 0.3),
    )

    vid_name_loader = video_names(dataset_folder,
                                  splt_txt_path,
                                  boxes_file,
                                  vid2idx,
                                  mode='test',
                                  classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(
        vid_name_loader,
        batch_size=n_devs,
        num_workers=2 * n_devs,
        pin_memory=True,
        shuffle=True)
    model.eval()

    # One set of counters per threshold. true_pos/false_neg feed recall;
    # tp/fp/fn feed the precision section of the report.
    stats = [
        {'true_pos': 0, 'false_neg': 0, 'tp': 0, 'fp': 0, 'fn': 0}
        for _ in thresholds
    ]

    tubes_sum = 0
    step = 0  # keeps the report well-defined when the loader yields nothing
    for step, data in enumerate(data_loader):

        print('step =>', step)

        vid_id, clips, boxes, n_frames, n_actions, h, w, target = data

        vid_id = vid_id.int()
        clips = clips.to(device)
        boxes = boxes.to(device)
        n_frames = n_frames.to(device)
        n_actions = n_actions.int().to(device)
        target = target.to(device)
        mode = 'test'
        with torch.no_grad():
            tubes, prob_out, tubes_size = model(n_devs, dataset_folder,
                                                vid_names, clips, vid_id,
                                                None,
                                                mode, cls2idx, None, n_frames, h, w)

        if tubes.dim() == 1:
            # A 1-D result is treated as "no tubes proposed": every non-empty
            # ground-truth tube of every sample is a miss at all thresholds.
            for i in range(clips.size(0)):
                n_frames_i = int(n_frames[i])
                box = boxes[i, :int(n_actions[i]), :n_frames_i, :4].contiguous()
                box = box.view(-1, n_frames_i * 4)

                n_elems = box.ne(0).any(dim=1).nonzero().view(-1).nelement()
                for counts in stats:
                    counts['false_neg'] += n_elems
            continue

        prob_out = F.softmax(prob_out, 2)
        _, predictions = torch.max(prob_out, dim=2)

        for i in range(clips.size(0)):
            n_tubes_i = tubes_size[i].int().item()
            n_frames_i = n_frames[i].int().item()
            n_actions_i = int(n_actions[i])

            # The sample's label, repeated once per ground-truth action.
            targ = target[i].expand(n_actions_i)

            tubes_ = tubes[i, :n_tubes_i, :n_frames_i]
            preds_ = predictions[i, :n_tubes_i]

            box = boxes[i, :n_actions_i, :n_frames_i, :4].contiguous()
            box = box.view(-1, n_frames_i * 4).contiguous().type_as(tubes)
            overlaps = tube_overlaps(
                tubes_.view(-1, n_frames_i * 4).float(),
                box.float())
            # Best proposal per ground-truth tube / best ground truth per proposal.
            gt_max_overlaps, _ = torch.max(overlaps, 0)
            max_overlaps, argmax_overlaps = torch.max(overlaps, 1)

            # Ground-truth rows that are not all-zero padding.
            non_empty_indices = box.ne(0).any(dim=1).nonzero().view(-1)
            n_elems = non_empty_indices.nelement()
            # NOTE(review): this adds the batch dimension of ``tubes`` once per
            # sample, not the number of tubes of this sample - confirm intent.
            tubes_sum += tubes.size(0)

            for counts, (_, iou_thresh) in zip(stats, thresholds):
                detected, tp, fp, fn = _count_at_threshold(
                    iou_thresh, gt_max_overlaps, max_overlaps,
                    argmax_overlaps, non_empty_indices, preds_, targ)
                counts['true_pos'] += detected
                counts['false_neg'] += n_elems - detected
                counts['tp'] += tp
                counts['fp'] += fp
                counts['fn'] += fn

    # ---- report ---------------------------------------------------------
    bar = '|                       |'
    body = [
        bar,
        '| we have {: >6} tubes  |'.format(tubes_sum),
        bar,
        '| Proposed Action Tubes |',
        bar,
        '| Single frame          |',
        bar,
        '| In {: >6} steps    :  |'.format(step),
        bar,
        '| Precision             |',
        bar,
    ]
    for (header, _), counts in zip(thresholds, stats):
        body.extend([
            header,
            bar,
            '| True_pos   --> {: >6} |\n| False_pos  --> {: >6} |\n| False_neg  --> {: >6} | \n| Precision  --> {: >6.4f} |'
            .format(counts['tp'], counts['fp'], counts['fn'],
                    _safe_ratio(counts['tp'], counts['fp'])),
            bar,
        ])
    body.extend(['| Recall                |', bar])
    for idx, ((header, _), counts) in enumerate(zip(thresholds, stats)):
        if idx:
            body.append(bar)
        body.extend([
            header,
            bar,
            '| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'
            .format(counts['true_pos'], counts['false_neg'],
                    _safe_ratio(counts['true_pos'], counts['false_neg'])),
        ])

    print(' -----------------------\n')
    print('| Validation Epoch: {: >3} |\n'.format(epoch + 1))
    for line in body:
        print(line)
    print(' -----------------------')

    # The context manager guarantees the log file is flushed and closed.
    with open('../images_etc/validation_jhmdb_linear.txt', 'a') as file_p:
        file_p.write(' -----------------------\n')
        file_p.write('| Validation Epoch: {: >3} | '.format(epoch + 1))
        file_p.writelines(line + '\n' for line in body)
        file_p.write(' -----------------------')
def _recall_or_zero(true_pos, false_neg):
    """Return true_pos / (true_pos + false_neg), or 0 when nothing was counted."""
    total = true_pos + false_neg
    return float(true_pos) / total if total > 0 else 0


def validation(epoch, device, model, dataset_folder, sample_duration, spatial_transform, temporal_transform, boxes_file, splt_txt_path, cls2idx, batch_size, n_threads):
    """Evaluate ``model`` on the 'test' split: tube recall and mAP.

    Recall is accumulated at IoU thresholds 0.5, 0.4 and 0.3 and printed;
    per-video detections and ground truth are collected and passed to
    ``calculate_mAP`` once per threshold. Returns ``None``.

    Besides its arguments this function reads the module-level names
    ``vid2idx``, ``vid_names``, ``n_devs``, ``video_names``,
    ``tube_overlaps`` and ``calculate_mAP``. The loader always uses a batch
    size of 1; ``batch_size``, ``n_threads``, ``sample_duration``,
    ``spatial_transform`` and ``temporal_transform`` are accepted but not
    used in this body.
    """
    iou_thresh = 0.5 # Intersection Over Union thresh
    iou_thresh_4 = 0.4 # Intersection Over Union thresh
    iou_thresh_3 = 0.3 # Intersection Over Union thresh
    thresholds = (iou_thresh, iou_thresh_4, iou_thresh_3)

    # Detections scoring below this are dropped before mAP.
    # NOTE(review): the score is read from ``prob_out`` as returned by the
    # model (no softmax is applied here) - confirm it is already a probability.
    confidence_thresh = 0.2

    vid_name_loader = video_names(dataset_folder, splt_txt_path, boxes_file, vid2idx, mode='test', classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(vid_name_loader, batch_size=1, num_workers=8*n_devs, pin_memory=True,
                                              shuffle=True)

    model.eval()

    # Recall counters, one slot per entry of ``thresholds``.
    true_pos = [0] * len(thresholds)
    false_neg = [0] * len(thresholds)

    groundtruth_dic = {}
    detection_dic = {}

    step = 0  # keeps the report well-defined when the loader yields nothing
    for step, data in enumerate(data_loader):

        print('step =>',step)

        vid_id, clips, boxes, n_frames, n_actions, h, w, target = data
        vid_id = vid_id.int()
        clips = clips.to(device)
        boxes = boxes.to(device)
        n_frames = n_frames.to(device)
        n_actions = n_actions.int().to(device)
        target = target.to(device)

        mode = 'test'

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            tubes, prob_out, n_tubes = model(n_devs, dataset_folder,
                                             vid_names, clips, vid_id,
                                             None,
                                             mode, cls2idx, None, n_frames, h, w)

        # This check must run before anything indexes ``tubes`` per sample.
        if tubes.dim() == 1:
            # A 1-D result is treated as "no tubes proposed": every non-empty
            # ground-truth tube is a miss at all thresholds.
            # NOTE(review): such videos get no entry in detection_dic /
            # groundtruth_dic - confirm that is what calculate_mAP expects.
            for i in range(clips.size(0)):
                n_frames_i = int(n_frames[i])
                box = boxes[i, :int(n_actions[i]), :n_frames_i, :4].contiguous()
                box = box.view(-1, n_frames_i*4)

                n_elems = box.ne(0).any(dim=1).nonzero().view(-1).nelement()
                for k in range(len(thresholds)):
                    false_neg[k] += n_elems
            continue

        for i in range(clips.size(0)):
            n_tubes_i = int(n_tubes[i])
            n_frames_i = int(n_frames[i])
            n_actions_i = int(n_actions[i])

            # ---- detections / ground truth for mAP ----------------------
            # Best class and its score for every tube of this sample.
            best_prob, best_cls = torch.max(prob_out[i], 1)
            f_prob = best_prob[:n_tubes_i]
            cls_int = best_cls[:n_tubes_i]

            # Keep confident, non-background (class != 0) tubes.
            keep_ = f_prob.ge(confidence_thresh) & cls_int.ne(0)
            keep_indices = keep_.nonzero().view(-1)

            tubes_i = tubes[i, :n_tubes_i, :n_frames_i].contiguous().view(-1, n_frames_i*4)

            # Row layout: [class, score, 4 coordinates per frame ...].
            f_tubes = torch.cat([cls_int.view(-1,1).type_as(tubes),
                                 f_prob.view(-1,1).type_as(tubes),
                                 tubes_i], dim=1)
            f_tubes = f_tubes[keep_indices].contiguous()

            # Row layout: [label, 4 coordinates per frame ...].
            # The flat view only works when ``boxes[i]`` holds exactly one
            # action row - presumably one action per video; verify.
            f_boxes = torch.cat([target[i].view(-1).type_as(boxes),
                                 boxes[i,:,:n_frames_i,:4].contiguous().view(n_frames_i*4)]).unsqueeze(0)
            v_name = vid_names[vid_id[i]].split('/')[1]

            detection_dic[v_name] = f_tubes.float()
            groundtruth_dic[v_name] = f_boxes.type_as(f_tubes)

            # ---- recall -------------------------------------------------
            print('boxes.shape:',boxes.shape)

            box = boxes[i, :n_actions_i, :n_frames_i, :4].contiguous()
            print('box.shape :',box.shape)
            box = box.view(-1, n_frames_i*4).contiguous().type_as(tubes)
            print('box.shape :',box.shape)
            print('n_frames :',n_frames)
            print('n_tubes :',n_tubes)
            overlaps = tube_overlaps(tubes_i.float(), box.float())
            # Best proposal overlap for each ground-truth tube.
            gt_max_overlaps, _ = torch.max(overlaps, 0)

            # Ground-truth rows that are not all-zero padding.
            non_empty_indices = box.ne(0).any(dim=1).nonzero().view(-1)
            n_elems = non_empty_indices.nelement()

            print('gt_max_overlaps :',gt_max_overlaps)

            for k, thresh in enumerate(thresholds):
                detected = gt_max_overlaps[non_empty_indices].gt(thresh).sum().item()
                true_pos[k] += detected
                false_neg[k] += n_elems - detected

        # TODO: add classification step

    bar = '|                       |'
    threshold_headers = (
        '| Threshold : 0.5       |',
        '| Threshold : 0.4       |',
        '| Threshold : 0.3       |',
    )

    print(' -----------------------')
    print('| Validation Epoch: {: >3} | '.format(epoch+1))
    print(bar)
    print('| Proposed Action Tubes |')
    print(bar)
    print('| Single frame          |')
    print(bar)
    print('| In {: >6} steps    :  |'.format(step))
    print(bar)
    for k, header in enumerate(threshold_headers):
        if k:
            print(bar)
        print(header)
        print(bar)
        print('| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'.format(
            true_pos[k], false_neg[k], _recall_or_zero(true_pos[k], false_neg[k])))

    print(' -----------------------')

    print(' -----------------')
    print('|                  |')
    print('| mAP Thresh : 0.5 |')
    print('|                  |')
    print(' ------------------')
    calculate_mAP(detection_dic, groundtruth_dic, iou_thresh)

    print(' -------------------')
    print('|                   |')
    print('| mAP Thresh : 0.4  |')
    print('|                   |')
    print(' -------------------')
    calculate_mAP(detection_dic, groundtruth_dic, iou_thresh_4)

    print(' ------------------')
    print('|                  |')
    print('| mAP Thresh : 0.3 |')
    print('|                  |')
    print(' ------------------')
    calculate_mAP(detection_dic, groundtruth_dic, iou_thresh_3)