def training(epoch, device, model, dataset_folder, sample_duration,
             spatial_transform, temporal_transform, boxes_file, splt_txt_path,
             cls2idx, batch_size, n_threads, lr, mode=1):
    """Run one training epoch of ``model`` over the 'train' split.

    The optimised loss depends on ``mode``:

    * 3 -- mean single-ROI bbox loss (``sgl_rois_bbox_loss``) only;
    * 4 -- mean RPN classification + mean RPN bbox loss;
    * 5 -- the sum of all three.

    Args:
        epoch: zero-based epoch index (only used for logging).
        device: device every batch tensor is moved to before the forward pass.
        model: network whose training forward returns the 9-tuple unpacked
            below.
        dataset_folder, sample_duration, spatial_transform,
        temporal_transform, boxes_file, splt_txt_path, cls2idx:
            forwarded to ``Video_Dataset_small_clip``.
        batch_size: base batch size; the loader uses ``batch_size * 16``.
        n_threads: currently unused (the loader uses 32 workers).
        lr: learning rate, only printed; no optimizer is created here.
        mode: which loss combination to optimise (3, 4 or 5).

    Returns:
        ``(model, loss_temp)`` where ``loss_temp`` is the summed (not averaged)
        loss over every step of the epoch.

    Raises:
        ValueError: if ``mode`` is not 3, 4 or 5.

    NOTE(review): ``optimizer`` is not a parameter; it is looked up as a
    module-level name -- confirm it is defined before this is called.
    """
    # Fail fast: for any other mode ``loss`` would never be assigned and the
    # first batch would crash with an UnboundLocalError.
    if mode not in (3, 4, 5):
        raise ValueError(
            'unsupported training mode: {!r} (expected 3, 4 or 5)'.format(mode))

    data = Video_Dataset_small_clip(
        dataset_folder, frames_dur=sample_duration,
        spatial_transform=spatial_transform,
        temporal_transform=temporal_transform, bboxes_file=boxes_file,
        split_txt_path=splt_txt_path, mode='train', classes_idx=cls2idx)
    # NOTE(review): the effective batch is 16x ``batch_size`` and the worker
    # count is fixed at 32 rather than ``n_threads`` -- confirm intended.
    data_loader = torch.utils.data.DataLoader(
        data, batch_size=batch_size * 16, shuffle=True, num_workers=32,
        pin_memory=True)

    model.train()
    loss_temp = 0
    n_steps = 0

    for batch in data_loader:
        clips, _h, _w, gt_tubes_r, gt_rois, _n_actions, _n_frames, im_info = batch
        clips_ = clips.to(device)
        gt_tubes_r_ = gt_tubes_r.to(device)
        gt_rois_ = gt_rois.to(device)
        im_info_ = im_info.to(device)
        # Every clip in the batch is given a start frame of 0.
        start_fr = torch.zeros(clips_.size(0)).to(device)

        (_, _,
         rpn_loss_cls, rpn_loss_bbox,
         _, _,  # the ``*_16`` RPN losses are not optimised here
         _,  # rois_label
         _, sgl_rois_bbox_loss) = model(clips_, im_info_, gt_tubes_r_,
                                        gt_rois_, start_fr)

        if mode == 3:
            loss = sgl_rois_bbox_loss.mean()
        elif mode == 4:
            loss = rpn_loss_cls.mean() + rpn_loss_bbox.mean()
        else:  # mode == 5
            loss = (rpn_loss_cls.mean() + rpn_loss_bbox.mean()
                    + sgl_rois_bbox_loss.mean())

        loss_temp += loss.item()
        n_steps += 1

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    print('Train Epoch: {} \tLoss: {:.6f}\t lr : {:.6f}'.format(
        epoch + 1, loss_temp / max(n_steps, 1), lr))

    return model, loss_temp
def validation(epoch, device, model, data_loader, n_threads):
    """Print proposal-tube recall of ``model`` at IoU thresholds 0.5/0.4/0.3.

    For each clip yielded by ``data_loader``, the model's proposed tubes are
    passed through ``tube_transform_inv`` together with the predicted
    single-ROI regression output. They are then compared to the ground-truth
    ROIs with ``tube_overlaps``.

    For every ground-truth row, the best overlap over all proposals is taken
    (max over dim 0, presumably ``(n_proposals, n_gt)`` -- TODO confirm). A row
    counts as detected when that overlap is strictly greater than the
    threshold. ``n_elems`` (rows of ``gt_tubes_r`` whose last column is
    non-zero) is used as the number of ground-truth tubes.

    Args:
        epoch: zero-based epoch index (only used for logging).
        device: device the batch tensors are moved to before the forward pass.
        model: network whose inference forward returns a 9-tuple with the
            proposed tubes first and the ROI regression output eighth.
        data_loader: iterable of 8-tuples
            ``(clips, h, w, gt_tubes_r, gt_rois, n_actions, n_frames, im_info)``.
        n_threads: unused; kept for signature compatibility.

    NOTE(review): ``sample_duration``, ``tube_transform_inv`` and
    ``tube_overlaps`` are not parameters and are looked up as module-level
    names -- confirm they are defined before calling.
    """
    thresholds = (0.5, 0.4, 0.3)  # Intersection Over Union thresholds
    true_pos = {t: 0 for t in thresholds}
    false_neg = {t: 0 for t in thresholds}

    model.eval()
    n_steps = 0

    with torch.no_grad():
        for batch in data_loader:
            n_steps += 1
            clips, _h, _w, gt_tubes_r, gt_rois, _n_act, _n_fr, im_info = batch
            clips_ = clips.to(device)
            gt_rois_ = gt_rois.to(device)
            im_info_ = im_info.to(device)

            # Inference forward: no ground truth is passed to the model.
            tubes, _, _, _, _, _, _, sgl_rois_bbox_pred, _ = model(
                clips_, im_info_, None, None, None)
            batch_size = len(tubes)

            tubes = tubes.view(-1, sample_duration * 4 + 2)
            tubes[:, 1:-1] = tube_transform_inv(
                tubes[:, 1:-1],
                sgl_rois_bbox_pred.view(-1, sample_duration * 4),
                (1.0, 1.0, 1.0, 1.0))
            tubes = tubes.view(batch_size, -1, sample_duration * 4 + 2)

            for i in range(tubes.size(0)):  # one clip of the batch at a time
                tubes_t = tubes[i, :, 1:-1].contiguous()
                gt_rois_t = gt_rois_[i, :, :, :4].contiguous().view(
                    -1, sample_duration * 4)
                rois_overlaps = tube_overlaps(tubes_t, gt_rois_t)

                gt_max_overlaps_sgl, _ = torch.max(rois_overlaps, 0)
                n_elems = gt_tubes_r[i, :, -1].ne(0).sum().item()

                for t in thresholds:
                    detected = gt_max_overlaps_sgl.gt(t).sum().item()
                    true_pos[t] += detected
                    false_neg[t] += n_elems - detected

    print(' -----------------------')
    print('| Validation Epoch: {: >3} | '.format(epoch + 1))
    print('|                       |')
    print('| Proposed Action Tubes |')
    print('|                       |')
    print('| Single frame          |')
    print('|                       |')
    print('| In {: >6} steps    :  |'.format(n_steps))
    for t in thresholds:
        total = true_pos[t] + false_neg[t]
        # Guard against an empty split / no ground truth.
        recall = float(true_pos[t]) / total if total else 0.0
        print('|                       |')
        print('| Threshold : {}       |'.format(t))
        print('|                       |')
        print(
            '| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'
            .format(true_pos[t], false_neg[t], recall))

    print(' -----------------------')
# Example #4
# 0
def validation(epoch, device, model, dataset_folder, sample_duration,
               spatial_transform, temporal_transform, boxes_file,
               splt_txt_path, cls2idx, batch_size, n_threads):
    """Print single-frame proposal recall of ``model`` at IoU > 0.5.

    Builds a 'test' split loader (batch size 16, shuffled), runs the model in
    inference mode and post-processes the tubes. The refined tubes are
    compared with the ground-truth ROIs via ``tube_overlaps``.

    Post-processing: ``tube_transform_inv`` is applied with the predicted
    single-ROI regression output, then ``clip_boxes``.

    Only ground-truth rows with at least one non-zero coordinate are counted.
    Such a row is detected when its best overlap over all proposals is
    strictly greater than 0.5.

    ``batch_size`` and ``n_threads`` are accepted for signature compatibility
    but not used.

    NOTE(review): ``tube_transform_inv``, ``clip_boxes`` and ``tube_overlaps``
    are module-level names defined elsewhere.
    """
    iou_thresh = 0.5  # Intersection Over Union thresh
    data = Video_Dataset_small_clip(dataset_folder,
                                    frames_dur=sample_duration,
                                    spatial_transform=spatial_transform,
                                    temporal_transform=temporal_transform,
                                    bboxes_file=boxes_file,
                                    split_txt_path=splt_txt_path,
                                    mode='test',
                                    classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=16,
                                              shuffle=True,
                                              num_workers=0,
                                              pin_memory=True)

    model.eval()

    sgl_true_pos = 0
    sgl_false_neg = 0
    n_steps = 0

    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            n_steps = step + 1
            print('step :', step)

            clips, _h, _w, _gt_tubes_r, gt_rois, _n_act, _n_fr, im_info = batch
            clips_ = clips.to(device)
            gt_rois_ = gt_rois.to(device)
            im_info_ = im_info.to(device)

            # Feed the device copies so inputs live where the model does.
            tubes, _, _, _, _, _, _, sgl_rois_bbox_pred, _ = model(
                clips_, im_info_, None, None, None)
            n_tubes = len(tubes)

            tubes = tubes.view(-1, sample_duration * 4 + 2)
            tubes[:, 1:-1] = tube_transform_inv(
                tubes[:, 1:-1],
                sgl_rois_bbox_pred.view(-1, sample_duration * 4),
                (1.0, 1.0, 1.0, 1.0))
            tubes = tubes.view(n_tubes, -1, sample_duration * 4 + 2)
            tubes[:, :, 1:-1] = clip_boxes(tubes[:, :, 1:-1], im_info_,
                                           tubes.size(0))

            for i in range(tubes.size(0)):  # one clip of the batch at a time
                tubes_t = tubes[i, :, 1:-1].contiguous()
                gt_rois_t = gt_rois_[i, :, :, :4].contiguous().view(
                    -1, sample_duration * 4)

                rois_overlaps = tube_overlaps(tubes_t, gt_rois_t)
                gt_max_overlaps_sgl, _ = torch.max(rois_overlaps, 0)

                # All-zero rows are treated as padding, not ground truth.
                non_empty_indices = gt_rois_t.ne(0).any(dim=1).nonzero().view(-1)
                n_elems = non_empty_indices.nelement()

                sgl_detected = gt_max_overlaps_sgl[non_empty_indices].gt(
                    iou_thresh).sum().item()
                sgl_true_pos += sgl_detected
                sgl_false_neg += n_elems - sgl_detected

    total = sgl_true_pos + sgl_false_neg
    recall = float(sgl_true_pos) / total if total else 0.0

    print(' -----------------------')
    print('| Validation Epoch: {: >3} | '.format(epoch + 1))
    print('|                       |')
    print('| Proposed Action Tubes |')
    print('|                       |')
    print('| Single frame          |')
    print('|                       |')
    print('| In {: >6} steps    :  |'.format(n_steps))
    print('|                       |')
    print(
        '| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'
        .format(sgl_true_pos, sgl_false_neg, recall))

    print(' -----------------------')
def validation(epoch, device, model, dataset_folder, sample_duration,
               spatial_transform, temporal_transform, boxes_file,
               splt_txt_path, cls2idx, batch_size, n_threads):
    """Print whole-tube proposal recall of ``model`` at IoU > 0.5.

    Builds a 'test' split loader (batch size 2, shuffled) and runs the model in
    inference mode. For each clip, the non-empty rows of ``gt_tubes_r`` are
    compared against the proposed tubes with ``bbox_overlaps_batch_3d``. Rows
    are non-empty when at least one entry is non-zero; clips with none are
    skipped.

    A ground-truth tube is detected when its best overlap (max over dim 1 of
    the returned overlaps) is strictly greater than 0.5.

    ``batch_size`` and ``n_threads`` are accepted for signature compatibility
    but not used.

    NOTE(review): ``bbox_overlaps_batch_3d`` is a module-level name defined
    elsewhere; its output layout is assumed from the ``max(.., 1)`` below --
    confirm.
    """
    iou_thresh = 0.5  # Intersection Over Union thresh
    data = Video_Dataset_small_clip(dataset_folder,
                                    frames_dur=sample_duration,
                                    spatial_transform=spatial_transform,
                                    temporal_transform=temporal_transform,
                                    bboxes_file=boxes_file,
                                    split_txt_path=splt_txt_path,
                                    mode='test',
                                    classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=2,
                                              shuffle=True,
                                              num_workers=0,
                                              pin_memory=True)
    model.eval()

    true_pos = 0
    false_neg = 0
    n_steps = 0

    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            n_steps = step + 1
            print('step :', step)

            clips, _h, _w, gt_tubes_r, _gt_rois, _n_act, _n_fr, im_info = batch
            clips_ = clips.to(device)
            im_info_ = im_info.to(device)

            # Feed the device copies so inputs live where the model does.
            tubes, _, _, _, _, _, _, _, _, _, _, _ = model(
                clips_, im_info_, None, None, None)

            for i in range(tubes.size(0)):  # one clip of the batch at a time
                tubes_t = tubes[i]
                gt_tub = gt_tubes_r[i]

                # All-zero rows are padding, not ground-truth tubes.
                non_empty = gt_tub.ne(0).any(dim=1).nonzero().view(-1)
                if non_empty.nelement() == 0:
                    continue
                gt_tub = gt_tub[non_empty]

                # Check one video at a time; type_as also moves gt to the
                # proposals' device.
                overlaps, _, _ = bbox_overlaps_batch_3d(
                    tubes_t, gt_tub.unsqueeze(0).type_as(tubes_t))

                gt_max_overlaps, _ = torch.max(overlaps, 1)
                detected = gt_max_overlaps.gt(iou_thresh).sum().item()
                true_pos += detected
                false_neg += gt_max_overlaps.nelement() - detected

    total = true_pos + false_neg
    recall = float(true_pos) / total if total else 0.0

    print(' -----------------------')
    print('| Validation Epoch: {: >3} | '.format(epoch + 1))
    print('|                       |')
    print('| Proposed Action Tubes |')
    print('|                       |')
    print(
        '| In {: >6} steps    :  |\n| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'
        .format(n_steps, true_pos, false_neg, recall))

    print(' -----------------------')