def forward(self, n_devs, dataset_folder, vid_names, clips, vid_id, boxes,
                mode, cls2idx, num_actions, num_frames, h_, w_):
        '''
        Propose action tubes for one whole video and classify them.

        Procedure:
          1. Wrap the video in a `single_video` dataset and feed it to
             `self.act_net` in mini-batches of clips (under `torch.no_grad`),
             collecting per-clip tube proposals, pooled features and
             actioness scores.
          2. Score the overlap of every proposal of clip i (second half of
             its frames) with every proposal of clip i+1 (first half), and
             let `self.calc` link proposals across clips.
          3. Paint the linked proposals into per-frame boxes (`final_tubes`).
          4. Training only: replace the linked tubes with the ground-truth
             tubes plus 5 randomly drawn low-overlap (< 0.3) background
             tubes and build the classification targets.
          5. Average the pooled features of every tube segment over the
             last axis and classify the sequence with `self.act_rnn`.

        Arguments (as used by this method):
          n_devs, mode   : unused here.
          dataset_folder, vid_names, vid_id, h_, w_, cls2idx :
                           forwarded to `single_video`.
          clips          : video tensor; dim 0 is squeezed, then cut to
                           `num_frames` frames.
          boxes          : ground-truth boxes, read only when
                           `self.training`; after squeeze/permute it is
                           indexed as (frame, action, field) with field 4
                           read as the class label.
          num_actions    : number of ground-truth actions (training only).
          num_frames     : number of valid frames; used with `.type_as`, so
                           it is expected to be a tensor.

        Returns a 3-tuple:
          training : (final_tubes, prob_out, cls_loss)
          eval     : (final_tubes, prob_out, None)
          no tubes : (empty cuda tensor, empty cuda tensor, None)
        '''

        clips = clips.squeeze(0)
        clips = clips[:num_frames]

        print('num_frames :', num_frames)
        print('clips.shape :', clips.shape)

        if self.training:
            # (action, frame, field) -> (frame, action, field)
            boxes = boxes.squeeze(0).permute(1, 0, 2).cpu()
            boxes = boxes[:num_frames, :num_actions]

        batch_size = 2  # clips pushed through act_net per step

        num_images = 1
        rois_per_image = int(cfg.TRAIN.BATCH_SIZE /
                             num_images) if self.training else 150

        data = single_video(dataset_folder,
                            h_,
                            w_,
                            vid_names,
                            vid_id,
                            frames_dur=self.sample_duration,
                            sample_size=self.sample_size,
                            classes_idx=cls2idx,
                            n_frames=num_frames)

        data_loader = torch.utils.data.DataLoader(
            data,
            batch_size=batch_size,
            pin_memory=False,
            shuffle=False)

        n_clips = len(data)

        features = torch.zeros(n_clips, rois_per_image, self.p_feat_size,
                               self.sample_duration)
        p_tubes = torch.zeros(n_clips, rois_per_image, self.sample_duration *
                              4)  # all the proposed tube-rois
        actioness_score = torch.zeros(n_clips, rois_per_image)

        # Only built in training mode; kept defined so the NaN diagnostic
        # further down can always print it.
        target_lbl = None

        if self.training:

            f_gt_tubes = torch.zeros(n_clips, num_actions,
                                     self.sample_duration * 4)  # gt_tubes
            tubes_labels = torch.zeros(n_clips, rois_per_image)  # tubes rois
            labels = torch.zeros(num_actions)

            for i in range(num_actions):
                # frames in which action i carries a non-zero label
                idx = boxes[:, i, 4].nonzero().view(-1)
                # boxes is (frame, action, field): read the label of action
                # i at its first labelled frame.
                labels[i] = boxes[idx[0], i, 4]

        for step, dt in enumerate(data_loader):

            print('\tstep :', step)

            frame_indices, im_info, start_fr = dt
            clips_ = clips[frame_indices].cuda()

            if self.training:
                boxes_ = boxes[frame_indices].cuda()
                box_ = boxes_.permute(0, 2, 1,
                                      3).float().contiguous()[:, :, :, :-1]
            else:
                box_ = None

            im_info = im_info.cuda()
            start_fr = start_fr.cuda()

            with torch.no_grad():
                tubes, pooled_feat, \
                rpn_loss_cls,  rpn_loss_bbox, \
                _,_, rois_label, \
                sgl_rois_bbox_pred, sgl_rois_bbox_loss = self.act_net(clips_.permute(0,2,1,3,4),
                                                            im_info,
                                                            None,
                                                            box_,
                                                            start_fr)

            pooled_feat = pooled_feat.view(-1, rois_per_image,
                                           self.p_feat_size,
                                           self.sample_duration)

            idx_s = step * batch_size
            idx_e = step * batch_size + batch_size

            features[idx_s:idx_e] = pooled_feat
            # first and last column of every tube are dropped; the last one
            # is stored separately as the actioness score
            p_tubes[idx_s:idx_e, ] = tubes[:, :, 1:-1]
            actioness_score[idx_s:idx_e] = tubes[:, :, -1]

            if self.training:

                box = boxes_.permute(0, 2, 1, 3).contiguous()[:, :, :, :-2]
                box = box.contiguous().view(box.size(0), box.size(1), -1)

                f_gt_tubes[idx_s:idx_e] = box
                tubes_labels[idx_s:idx_e] = rois_label.squeeze(-1).type_as(
                    tubes_labels)

        ########################################################
        #          Calculate overlaps and connections          #
        ########################################################

        overlaps_scores = torch.zeros(n_clips, rois_per_image,
                                      rois_per_image)

        half = int(self.sample_duration * 4 / 2)
        for i in range(n_clips - 1):
            overlaps_scores[i] = tube_overlaps(p_tubes[i, :, half:],
                                               p_tubes[i + 1, :, :half])

        if n_clips > 1:
            final_scores, final_poss = self.calc(
                overlaps_scores.cuda(), actioness_score.cuda(),
                torch.Tensor([n_clips]), torch.Tensor([rois_per_image]))
        else:
            # a single clip: every proposal is its own one-segment tube
            offset = torch.arange(rois_per_image).float()
            final_poss = torch.stack([torch.zeros((rois_per_image)), offset],
                                     dim=1).unsqueeze(1).long()

        ## Now connect the tubes
        final_tubes = torch.zeros(final_poss.size(0), num_frames, 4)
        f_tubes = []
        for i in range(final_poss.size(0)):
            tub = []
            for j in range(final_poss.size(1)):

                # curr_ is read as (clip index, proposal index)
                curr_ = final_poss[i, j]
                start_fr = curr_[0] * int(self.sample_duration / 2)
                end_fr = min((curr_[0] * int(self.sample_duration / 2) +
                              self.sample_duration).type_as(num_frames),
                             num_frames).type_as(start_fr)

                # a clip index of -1 ends the tube
                if curr_[0] == -1:
                    break

                curr_frames = p_tubes[curr_[0], curr_[1]]
                tub.append((curr_[0].item(), curr_[1].item()))
                ## TODO change with avg
                final_tubes[i, start_fr:end_fr] = torch.max(
                    curr_frames.view(-1, 4).contiguous()[:(end_fr -
                                                           start_fr).long()],
                    final_tubes[i, start_fr:end_fr].type_as(curr_frames))
            f_tubes.append(tub)

        ###################################################
        #          Choose gth Tubes for RCNN\TCN          #
        ###################################################
        if self.training:

            ##  calculate overlaps
            boxes_ = boxes.permute(1, 0, 2).contiguous()
            boxes_ = boxes_[:, :, :4].contiguous().view(num_actions, -1)

            overlaps = tube_overlaps(final_tubes.view(-1, num_frames * 4),
                                     boxes_.type_as(final_tubes))
            max_overlaps, _ = torch.max(overlaps, 1)
            max_overlaps = max_overlaps.clamp_(min=0)
            ## TODO change numbers
            bg_tubes_indices = max_overlaps.lt(0.3).nonzero()
            if bg_tubes_indices.size(0) > 0:
                # draw 5 background tubes (with replacement)
                bg_tubes_indices_picked = (torch.rand(5) *
                                           bg_tubes_indices.size(0)).long()
                bg_tubes_list = [
                    f_tubes[i]
                    for i in bg_tubes_indices[bg_tubes_indices_picked]
                ]
            else:
                # no low-overlap tube exists: indexing the empty index
                # tensor would raise, so train without background samples
                bg_tubes_list = []
            bg_labels = torch.zeros(len(bg_tubes_list))

            gt_tubes_list = [[] for i in range(num_actions)]

            for i in range(n_clips):

                overlaps = tube_overlaps(p_tubes[i], f_gt_tubes[i])
                max_overlaps, argmax_overlaps = torch.max(overlaps, 0)

                for j in range(num_actions):
                    # NOTE(review): appends proposal index j rather than
                    # argmax_overlaps[j]; presumably act_net places the gt
                    # tubes at the front of its proposals -- confirm.
                    if max_overlaps[j] == 1.0:
                        gt_tubes_list[j].append((i, j))

            ## concate fb, bg tubes
            f_tubes = gt_tubes_list + bg_tubes_list
            target_lbl = torch.cat([labels, bg_labels], dim=0)

        ##############################################

        if not f_tubes:
            print('------------------')
            print('    empty tube    ')
            return torch.Tensor([]).cuda(), torch.Tensor([]).cuda(), None

        ## calculate input rois
        prob_out = torch.zeros(len(f_tubes), self.n_classes).cuda()

        for i in range(len(f_tubes)):

            seq = f_tubes[i]
            tmp_tube = torch.Tensor(len(seq), 6)
            feats = torch.Tensor(len(seq), self.p_feat_size)

            for j in range(len(seq)):

                # average the pooled feature over its last axis
                feats[j] = features[seq[j][0], seq[j][1]].mean(1)
                tmp_tube[j] = p_tubes[seq[j]][1:7]

            prob_out[i] = self.act_rnn(feats.cuda())
            # x != x is only true for NaN: dump the inputs and abort
            if prob_out[i, 0] != prob_out[i, 0]:
                print('tmp_tube :',tmp_tube, ' prob_out :', prob_out ,' feats :',feats.cpu().numpy(), ' numpy(), feats.shape  :,', feats.shape ,' target_lbl :',target_lbl, \
                      ' \ntmp_tube :',tmp_tube, )
                exit(-1)

        # ##########################################
        # #           Time for Linear Loss         #
        # ##########################################

        cls_loss = torch.Tensor([0]).cuda()

        final_tubes = final_tubes.type_as(final_poss)
        # # classification probability
        if self.training:
            cls_loss = F.cross_entropy(prob_out.cpu(),
                                       target_lbl.long()).cuda()

        if self.training:
            return final_tubes, prob_out, cls_loss,
        else:
            return final_tubes, prob_out, None
def validation(epoch, device, model, dataset_folder, sample_duration, spatial_transform, temporal_transform, boxes_file, splt_txt_path, cls2idx, batch_size, n_threads):
    """Measure the recall of the proposed action tubes and print a report.

    For every video the model's tubes are compared with the ground-truth
    tubes via `tube_overlaps`; a ground-truth tube counts as detected when
    its best overlap exceeds the threshold. Recall is reported for the
    thresholds 0.5, 0.4 and 0.3.

    Only `epoch`, `device`, `model`, `boxes_file` and `cls2idx` are read from
    the arguments; the other parameters are unused.

    NOTE(review): `dataset_frames`, `split_txt_path`, `vid2idx`, `n_devs` and
    `vid_names` are not parameters (the parameters are spelled
    `dataset_folder` / `splt_txt_path`); presumably they are module-level
    globals -- confirm.

    NOTE(review): the loop stops after step 10, which looks like a debugging
    limit -- confirm before relying on the numbers.
    """

    iou_thresholds = (0.5, 0.4, 0.3)  # Intersection Over Union thresholds

    vid_name_loader = Video_Dataset_whole_video(dataset_frames, split_txt_path, boxes_file, vid2idx, mode='test')
    # One video per step: the per-sample code below assumes batch size 1.
    data_loader = torch.utils.data.DataLoader(vid_name_loader, batch_size=1, num_workers=8*n_devs, pin_memory=True,
                                              shuffle=True)

    model.eval()

    # Plain Python ints, one slot per threshold.
    true_pos = [0] * len(iou_thresholds)
    false_neg = [0] * len(iou_thresholds)

    step = 0  # keeps the report printable when the loader is empty
    for step, data  in enumerate(data_loader):

        if step == 10:
            break
        print('step =>',step)

        vid_id, clips, boxes, n_frames, n_actions, h, w =data
        vid_id = vid_id.int()
        clips = clips.to(device)
        boxes = boxes.to(device)
        n_frames = n_frames.to(device)
        n_actions = n_actions.int().to(device)

        mode = 'test'

        with torch.no_grad():
            tubes,  \
            prob_out, _ =  model(n_devs, dataset_frames, \
                                    vid_names, clips, vid_id,  \
                                    None, \
                                    mode, cls2idx, None, n_frames, h, w)

        print('tubes.shape :',tubes.shape)
        print('prob_out.shape :',prob_out.shape)

        # A 1-D `tubes` is treated as "no tubes proposed" (same check as the
        # other validation variant in this file).
        no_tubes = tubes.dim() == 1

        for i in range(clips.size(0)):
            n_fr = int(n_frames[i].item())
            n_act = int(n_actions[i].item())

            box = boxes[i,:n_act, :n_fr,:4].contiguous()
            box = box.view(-1,n_fr*4)

            # ground-truth rows that contain at least one non-zero value
            non_empty_indices =  box.ne(0).any(dim=1).nonzero().view(-1)
            n_elems = non_empty_indices.nelement()

            if no_tubes:
                # nothing proposed: every ground-truth tube is missed
                for k in range(len(iou_thresholds)):
                    false_neg[k] += n_elems
                continue

            overlaps = tube_overlaps(tubes.view(-1,n_fr*4).float(), box.float())

            # best proposal for every ground-truth tube
            gt_max_overlaps, argmax_gt_overlaps = torch.max(overlaps, 0)

            print('gt_max_overlaps :',gt_max_overlaps)
            if gt_max_overlaps[0] > 0.5 :
                print('argmax_gt_overlaps :',argmax_gt_overlaps, argmax_gt_overlaps.shape, gt_max_overlaps )
                print('tubes_t[max_indices[0]] :',tubes[argmax_gt_overlaps[0]])
                print('gt_rois_t[0] :',box[0])

            for k, thresh in enumerate(iou_thresholds):
                # zero out every overlap that does not exceed the threshold
                gt_max_overlaps_ = torch.where(gt_max_overlaps > thresh, gt_max_overlaps, torch.zeros_like(gt_max_overlaps).type_as(gt_max_overlaps))
                print('gt_max_overlaps_ :',gt_max_overlaps_)
                detected =  gt_max_overlaps_[non_empty_indices].ne(0).sum().item()
                true_pos[k] += detected
                false_neg[k] += n_elems - detected

    # Recall is defined as 0 when nothing was counted (avoids 0/0).
    recalls = [
        float(tp) / (tp + fn) if (tp + fn) else 0.0
        for tp, fn in zip(true_pos, false_neg)
    ]

    print(' -----------------------')
    print('| Validation Epoch: {: >3} | '.format(epoch+1))
    print('|                       |')
    print('| Proposed Action Tubes |')
    print('|                       |')
    print('| Single frame          |')
    print('|                       |')
    print('| In {: >6} steps    :  |'.format(step))
    for thresh, tp, fn, recall in zip(iou_thresholds, true_pos, false_neg, recalls):
        print('|                       |')
        print('| Threshold : {}       |'.format(thresh))
        print('|                       |')
        print('| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'.format(
            tp, fn, recall))


    print(' -----------------------')
Пример #3
0
def _ratio_or_zero(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 if denominator is 0."""
    return float(numerator) / denominator if denominator else 0.0


def _validation_report_lines(epoch, tubes_sum, step, stats):
    """Build the validation report as a list of text lines.

    `stats` is a list of (iou_threshold, counters) pairs where counters is a
    dict with the integer keys 'true_pos', 'false_neg' (recall counters) and
    'tp', 'fp', 'fn' (precision counters).
    """
    bar = '|                       |'
    lines = [
        ' -----------------------',
        '| Validation Epoch: {: >3} |'.format(epoch+1),
        bar,
        '| we have {: >6} tubes  |'.format(tubes_sum),
        bar,
        '| Proposed Action Tubes |',
        bar,
        '| Single frame          |',
        bar,
        '| In {: >6} steps    :  |'.format(step),
        bar,
        '| Precision             |',
    ]
    for thresh, counters in stats:
        precision = _ratio_or_zero(counters['tp'], counters['tp'] + counters['fp'])
        lines += [
            bar,
            '| Threshold : {}       |'.format(thresh),
            bar,
            '| True_pos   --> {: >6} |\n| False_pos  --> {: >6} |\n| False_neg  --> {: >6} | \n| Precision  --> {: >6.4f} |'.format(
                counters['tp'], counters['fp'], counters['fn'], precision),
        ]
    lines += [bar, '| Recall                |']
    for thresh, counters in stats:
        recall = _ratio_or_zero(counters['true_pos'],
                                counters['true_pos'] + counters['false_neg'])
        lines += [
            bar,
            '| Threshold : {}       |'.format(thresh),
            bar,
            '| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'.format(
                counters['true_pos'], counters['false_neg'], recall),
        ]
    lines.append(' -----------------------')
    return lines


def validation(epoch, device, model, dataset_folder, sample_duration, spatial_transform, temporal_transform, boxes_file, splt_txt_path, cls2idx, batch_size, n_threads):
    """Evaluate tube recall and classification precision, then report them.

    For every video the model's tubes are compared with the ground-truth
    tubes via `tube_overlaps`, for the IoU thresholds 0.5, 0.4 and 0.3:

    * recall: a ground-truth tube is detected when its best overlap with any
      proposed tube exceeds the threshold;
    * precision: among the proposed tubes whose best overlap exceeds the
      threshold and whose predicted class equals `target` of the matched
      ground-truth tube, one per distinct label counts as true positive and
      the remaining duplicates as false positives. Non-background
      predictions without a match and mismatching predictions are added to
      the 'False_neg' counter of the precision table.

    The same report is printed and appended to 'validation.txt'.

    Only `epoch`, `device`, `model`, `dataset_folder`, `boxes_file` and
    `cls2idx` are read from the arguments; the other parameters are unused.

    NOTE(review): `split_txt_path`, `vid2idx`, `n_devs` and `vid_names` are
    not parameters (the parameter is spelled `splt_txt_path`); presumably
    they are module-level globals -- confirm.
    """

    iou_thresholds = (0.5, 0.4, 0.3)  # Intersection Over Union thresholds

    vid_name_loader = video_names(dataset_folder, split_txt_path, boxes_file, vid2idx, mode='test', classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(vid_name_loader, batch_size=n_devs, num_workers=8*n_devs, pin_memory=True,
                                              shuffle=True)
    model.eval()

    # One set of plain-int counters per threshold.
    stats = [
        (thresh, {'true_pos': 0, 'false_neg': 0, 'tp': 0, 'fp': 0, 'fn': 0})
        for thresh in iou_thresholds
    ]

    tubes_sum = 0
    step = 0  # keeps the report printable when the loader is empty
    for step, data  in enumerate(data_loader):

        print('step =>',step)

        vid_id, clips, boxes, n_frames, n_actions, h, w, target =data
        vid_id = vid_id.int()
        clips = clips.to(device)
        boxes = boxes.to(device)
        n_frames = n_frames.to(device)
        n_actions = n_actions.int().to(device)
        target = target.to(device)
        mode = 'test'
        with torch.no_grad():
            tubes,  \
            prob_out, tubes_size =  model(n_devs, dataset_folder, \
                                    vid_names, clips, vid_id,  \
                                    None, \
                                    mode, cls2idx, None, n_frames, h, w)

        # A 1-D `tubes` means nothing was proposed: every non-empty
        # ground-truth tube of the batch is a miss at every threshold.
        if tubes.dim() == 1:

            for i in range(clips.size(0)):
                n_fr = n_frames[i].int().item()
                n_act = n_actions[i].int().item()

                box = boxes[i,:n_act, :n_fr,:4].contiguous()
                box = box.view(-1,n_fr*4)

                non_empty_indices =  box.ne(0).any(dim=1).nonzero().view(-1)
                n_elems = non_empty_indices.nelement()
                for _, counters in stats:
                    counters['false_neg'] += n_elems
            continue

        prob_out = F.softmax(prob_out,2)
        _, predictions = torch.max(prob_out,dim=2)

        for i in range(clips.size(0)):

            n_tubes = tubes_size[i].int().item()
            n_fr = n_frames[i].int().item()
            n_act = n_actions[i].int().item()

            tubes_ = tubes[i,:n_tubes,:n_fr]
            preds_ = predictions[i,:n_tubes]

            box = boxes[i,:n_act, :n_fr,:4].contiguous()
            box = box.view(-1,n_fr*4).contiguous().type_as(tubes)
            overlaps = tube_overlaps(tubes_.view(-1,n_fr*4).float(), box.float())
            # best proposal per ground-truth tube / best gt per proposal
            gt_max_overlaps, argmax_gt_overlaps = torch.max(overlaps, 0)
            max_overlaps, argmax_overlaps = torch.max(overlaps, 1)

            # ground-truth rows that contain at least one non-zero value
            non_empty_indices =  box.ne(0).any(dim=1).nonzero().view(-1)
            n_elems = non_empty_indices.nelement()
            # NOTE(review): adds the batch dimension of `tubes`, not
            # tubes_size[i] -- confirm this is the intended tube count.
            tubes_sum += tubes.size(0)

            for thresh, counters in stats:
                # ---- recall: ground-truth tubes covered by a proposal ----
                detected = (gt_max_overlaps[non_empty_indices] > thresh).sum().item()
                counters['true_pos'] += detected
                counters['false_neg'] += n_elems - detected

                # ---- precision: class of the overlapping proposals ----
                overlapping = max_overlaps > thresh
                fg_idx = overlapping.nonzero().view(-1)
                bg_idx = overlapping.eq(0).nonzero().view(-1)
                # non-background predictions with no overlapping gt tube
                counters['fn'] += preds_[bg_idx].ne(0).sum().item()

                fg_preds = preds_[fg_idx]  # overlapping predictions
                # NOTE(review): `target` is indexed with the gt index only,
                # not with the sample index i -- confirm its layout.
                correct = fg_preds == target[argmax_overlaps[fg_idx]]
                counters['fn'] += correct.eq(0).sum().item()

                fg_preds = fg_preds[correct.ne(0).nonzero().view(-1)]
                for label in torch.unique(fg_preds):
                    # one hit per label; further hits are duplicates
                    counters['fp'] += fg_preds.eq(label).sum().item() - 1
                    counters['tp'] += 1

    report = '\n'.join(_validation_report_lines(epoch, tubes_sum, step, stats))

    print(report)
    with open('validation.txt','a') as filep:
        filep.write(report + '\n')
    def forward(self, n_devs, dataset_folder, vid_names, clips, vid_id, boxes,
                mode, cls2idx, num_actions, num_frames, h_, w_):
        '''
        Propose, link and classify action tubes for one whole video.

        The video is cut into clips by a ``single_video`` dataset. Every batch
        of clips is passed through ``self.act_net`` (under
        ``torch.no_grad()``) to obtain per-clip tube proposals and pooled
        features. Proposals of consecutive clips are linked into video-level
        tubes through ``self.calc`` and every resulting tube is scored by
        ``self.act_rnn`` on the mean of its clip features.

        Args:
            n_devs: not used inside this method.
            dataset_folder, vid_names, vid_id, h_, w_: forwarded unchanged to
                ``single_video``.
            clips: frames of the video; dim 0 is squeezed away and the result
                is truncated to ``num_frames``.
            boxes: ground-truth boxes; only read when ``self.training``.
            mode: only compared against ``'extract'`` (feature extraction).
            cls2idx: forwarded to ``single_video`` as ``classes_idx``.
            num_actions: number of ground-truth actions; only used when
                ``self.training``.
            num_frames: number of valid frames (used both for slicing and
                with ``type_as``, so presumably a tensor -- confirm with
                the caller).

        Returns:
            A 3-tuple whose content depends on the path taken:
              * no tube found: two empty CUDA tensors and ``None``;
              * ``mode == 'extract'``: ``(f_feats, target_lbl, f_feats_len)``;
              * training: ``(None, prob_out, cls_loss)``;
              * evaluation: ``(ret_tubes, ret_prob_out, n_kept)`` where both
                tensors are padded with -1 up to ``conf.UPDATE_THRESH`` tubes
                and ``n_kept`` is a 1-element CUDA tensor.
        '''

        clips = clips.squeeze(0)
        # Frame count before truncation; used to size the padded eval output.
        ret_n_frames = clips.size(0)
        clips = clips[:num_frames]

        if self.training:
            # After the permute the first two dims are indexed as
            # [frame, action] below; negative entries are clamped to 0.
            boxes = boxes.squeeze(0).permute(1, 0, 2).cpu()
            boxes = boxes[:num_frames, :num_actions].clamp_(min=0)

        batch_size = 4  # clips per act_net call (local, not a parameter)

        num_images = 1
        # conf.TRAIN.BATCH_SIZE rois per clip when training, 150 otherwise.
        rois_per_image = int(conf.TRAIN.BATCH_SIZE /
                             num_images) if self.training else 150

        data = single_video(dataset_folder,
                            h_,
                            w_,
                            vid_names,
                            vid_id,
                            frames_dur=self.sample_duration,
                            sample_size=self.sample_size,
                            classes_idx=cls2idx,
                            n_frames=num_frames)

        data_loader = torch.utils.data.DataLoader(
            data,
            batch_size=batch_size,
            pin_memory=False,  # num_workers=num_workers, pin_memory=True,
            # shuffle=False, num_workers=8)
            shuffle=False)

        n_clips = data.__len__()

        # Per-clip buffers, filled batch by batch in the loop below.
        features = torch.zeros(n_clips, rois_per_image, self.p_feat_size,
                               self.sample_duration).type_as(clips)
        p_tubes = torch.zeros(n_clips, rois_per_image, self.sample_duration *
                              4).type_as(clips)  # all the proposed tube-rois
        actioness_score = torch.zeros(n_clips, rois_per_image).type_as(clips)
        # NOTE(review): overlaps_scores is allocated but not used further in
        # this method.
        overlaps_scores = torch.zeros(n_clips, rois_per_image,
                                      rois_per_image).type_as(clips)

        f_tubes = []

        if self.training:

            f_gt_tubes = torch.zeros(n_clips, num_actions,
                                     self.sample_duration * 4)  # gt_tubes
            # NOTE(review): tubes_labels and loops are not used further in
            # this method.
            tubes_labels = torch.zeros(n_clips, rois_per_image)  # tubes rois
            loops = int(np.ceil(n_clips / batch_size))
            labels = torch.zeros(num_actions)

            # Label of each action = column 4 of its first frame whose
            # column 4 is non-zero.
            for i in range(num_actions):
                idx = boxes[:, i, 4].nonzero().view(-1)
                labels[i] = boxes[idx[0], i, 4]

        ## Init connect thresh
        self.calc.thresh = self.connection_thresh

        for step, dt in enumerate(data_loader):

            frame_indices, im_info, start_fr = dt
            clips_ = clips[frame_indices].cuda()

            if self.training:
                boxes_ = boxes[frame_indices].cuda()
                # Swap dims 1 and 2 and drop the last column before handing
                # the boxes to act_net.
                box_ = boxes_.permute(0, 2, 1,
                                      3).float().contiguous()[:, :, :, :-1]
            else:
                box_ = None

            im_info = im_info.cuda()
            start_fr = start_fr.cuda()

            # The clip-level network is run without gradient tracking.
            with torch.no_grad():
                tubes, pooled_feat, \
                rpn_loss_cls,  rpn_loss_bbox, \
                _,_, rois_label, \
                sgl_rois_bbox_pred, sgl_rois_bbox_loss = self.act_net(clips_.permute(0,2,1,3,4),
                                                            im_info,
                                                            None,
                                                            box_,
                                                            start_fr)
            # Average over the last two dims, then regroup per clip.
            pooled_feat = pooled_feat.mean(-1).mean(-1)
            pooled_feat = pooled_feat.view(-1, rois_per_image,
                                           self.p_feat_size,
                                           self.sample_duration)

            # regression
            n_tubes = len(tubes)
            if not self.training:
                # Columns 1:-1 of every tube hold sample_duration * 4 box
                # coordinates; they are refined with the predicted deltas and
                # then clipped with clip_boxes.
                tubes = tubes.view(-1, self.sample_duration * 4 + 2)
                tubes[:,1:-1] = tube_transform_inv(tubes[:,1:-1],\
                                               sgl_rois_bbox_pred.view(-1,self.sample_duration*4),(1.0,1.0,1.0,1.0))
                tubes = tubes.view(n_tubes, rois_per_image,
                                   self.sample_duration * 4 + 2)
                tubes[:, :, 1:-1] = clip_boxes(tubes[:, :, 1:-1], im_info,
                                               tubes.size(0))

            # NOTE(review): indexes_ is not used further in this method.
            indexes_ = (torch.arange(0, tubes.size(0)) *
                        int(self.sample_duration / 2) +
                        start_fr[0].cpu()).unsqueeze(1)
            indexes_ = indexes_.expand(tubes.size(0),
                                       tubes.size(1)).type_as(tubes)

            # Clip range covered by this batch (the last batch may be short).
            idx_s = step * batch_size
            idx_e = min(step * batch_size + batch_size, n_clips)

            features[idx_s:idx_e] = pooled_feat
            p_tubes[idx_s:idx_e, ] = tubes[:, :, 1:-1]
            # The last tube column is stored as the actioness score.
            actioness_score[idx_s:idx_e] = tubes[:, :, -1]

            if self.training:

                # Drop the last two columns and flatten each action's boxes
                # over the clip into one ground-truth tube row.
                box = boxes_.permute(0, 2, 1, 3).contiguous()[:, :, :, :-2]
                box = box.contiguous().view(box.size(0), box.size(1), -1)

                f_gt_tubes[idx_s:idx_e] = box

            # connection algo
            for i in range(idx_s, idx_e):
                if i == 0:

                    # Init tensors for connecting
                    offset = torch.arange(0, rois_per_image).int().cuda()
                    ones_t = torch.ones(rois_per_image).int().cuda()
                    # Template of "empty" linked tubes: every slot is -1.
                    zeros_t = torch.zeros(rois_per_image, n_clips,
                                          2).int().cuda() - 1

                    # pos[t, k] = (clip index, roi index) of the k-th element
                    # of linked tube t; -1 marks an unused slot.
                    pos = torch.zeros(rois_per_image, n_clips,
                                      2).int().cuda() - 1  # initial pos
                    pos[:, 0, 0] = 0
                    pos[:, 0, 1] = offset.contiguous(
                    )  # contains the current tubes to be connected
                    pos_indices = torch.zeros(rois_per_image).int().cuda(
                    )  # contains the pos of the last element of the previous tensor
                    actioness_scr = actioness_score[0].float().cuda(
                    )  # actioness sum of active tubes
                    overlaps_scr = torch.zeros(rois_per_image).float().cuda(
                    )  # overlaps  sum of active tubes
                    final_scores = torch.Tensor().float().cuda(
                    )  # final scores
                    final_poss = torch.Tensor().int().cuda()  # final tubes

                    continue

                # Overlap between the second half of the previous clip's
                # tubes and the first half of the current clip's tubes.
                overlaps_ = tube_overlaps(
                    p_tubes[i - 1, :,
                            int(self.sample_duration * 4 / 2):],
                    p_tubes[i, :, :int(self.sample_duration * 4 /
                                       2)]).type_as(p_tubes)

                # self.calc returns the updated tube positions and running
                # sums together with the scores of the tubes it produced.
                pos, pos_indices, \
                f_scores, actioness_scr, \
                overlaps_scr = self.calc(torch.Tensor([n_clips]),torch.Tensor([rois_per_image]),torch.Tensor([pos.size(0)]),
                                         pos, pos_indices, actioness_scr, overlaps_scr,
                                         overlaps_, actioness_score[i], torch.Tensor([i]))

                # Too many active tubes: let self.calc.update_scores rebuild
                # the final and active sets (presumably pruning them --
                # confirm against its implementation).
                if pos.size(0) > self.update_thresh:

                    final_scores, final_poss, pos , pos_indices, \
                    actioness_scr, overlaps_scr,  f_scores = self.calc.update_scores(final_scores,final_poss, f_scores, pos, pos_indices, actioness_scr, overlaps_scr)

                # Promote 0-dim results to 1-dim so torch.cat below works.
                if f_scores.dim() == 0:
                    f_scores = f_scores.unsqueeze(0)
                    pos = pos.unsqueeze(0)
                    pos_indices = pos_indices.unsqueeze(0)
                    actioness_scr = actioness_scr.unsqueeze(0)
                    overlaps_scr = overlaps_scr.unsqueeze(0)
                if final_scores.dim() == 0:
                    final_scores = final_scores.unsqueeze(0)
                    final_poss = final_poss.unsqueeze(0)

                # NOTE(review): the bare `except:` clauses below also catch
                # KeyboardInterrupt/SystemExit and terminate the whole process
                # with exit(-1) after dumping the operands.
                try:
                    final_scores = torch.cat((final_scores, f_scores))
                except:
                    print('final_scores :', final_scores)
                    print('final_scores.shape :', final_scores.shape)
                    print('final_scores.dim() :', final_scores.dim())
                    print('f_scores :', f_scores)
                    print('f_scores.shape :', f_scores.shape)
                    print('f_scores.dim() :', f_scores.dim())
                    exit(-1)
                try:
                    final_poss = torch.cat((final_poss, pos))
                except:
                    print('final_poss :', final_poss)
                    print('final_poss.shape :', final_poss.shape)
                    print('final_poss.dim() :', final_poss.dim())
                    print('pos :', pos)
                    print('pos.shape :', pos.shape)
                    print('pos.dim() :', pos.dim())
                    exit(-1)

                # add new tubes
                # Every roi of clip i also starts a new one-element tube.
                pos = torch.cat((pos, zeros_t))
                pos[-rois_per_image:, 0, 0] = ones_t * i
                pos[-rois_per_image:, 0, 1] = offset

                pos_indices = torch.cat(
                    (pos_indices, torch.zeros(
                        (rois_per_image)).type_as(pos_indices)))
                actioness_scr = torch.cat((actioness_scr, actioness_score[i]))
                overlaps_scr = torch.cat(
                    (overlaps_scr, torch.zeros(
                        (rois_per_image)).type_as(overlaps_scr)))

        ## add only last layers
        ## TODO check again
        # Rois of the last clip whose actioness reaches the threshold are
        # added as single-clip tubes.
        # NOTE(review): only final_poss grows here, final_scores does not.
        indices = actioness_score[-1].ge(self.calc.thresh).nonzero().view(-1)
        if indices.nelement() > 0:
            zeros_t[:, 0, 0] = idx_e - 1
            zeros_t[:, 0, 1] = offset
            final_poss = torch.cat([final_poss, zeros_t[indices]])

        if pos.size(0) > self.update_thresh:
            print('Updating thresh...', final_scores.shape, final_poss.shape,
                  pos.shape, f_scores.shape, pos_indices.shape)
            final_scores, final_poss, pos , pos_indices, \
                actioness_scr, overlaps_scr,  f_scores = self.calc.update_scores(final_scores,final_poss, f_scores, pos, pos_indices, actioness_scr, overlaps_scr)
            print('Updating thresh...', final_scores.shape, final_poss.shape,
                  pos.shape, f_scores.shape, pos_indices.shape)

        # Per-frame boxes of every linked tube over the whole video.
        final_tubes = torch.zeros(final_poss.size(0), num_frames, 4)

        # f_tubes[i] lists the (clip index, roi index) pairs of tube i.
        f_tubes = []

        for i in range(final_poss.size(0)):
            tub = []
            for j in range(final_poss.size(1)):

                curr_ = final_poss[i, j]
                # Consecutive clips start sample_duration / 2 frames apart;
                # the end frame is capped at num_frames.
                start_fr = curr_[0] * int(self.sample_duration / 2)
                end_fr = min((curr_[0] * int(self.sample_duration / 2) +
                              self.sample_duration).type_as(num_frames),
                             num_frames).type_as(start_fr)

                # -1 marks the end of the tube.
                if curr_[0] == -1:
                    break

                curr_frames = p_tubes[curr_[0], curr_[1]]
                tub.append((curr_[0].item(), curr_[1].item()))
                ## TODO change with avg
                # Frames covered by two clips keep the element-wise maximum
                # of the two boxes.
                final_tubes[i, start_fr:end_fr] = torch.max(
                    curr_frames.view(-1, 4).contiguous()[:(end_fr -
                                                           start_fr).long()],
                    final_tubes[i, start_fr:end_fr].type_as(curr_frames))
            f_tubes.append(tub)

        ###################################################
        #          Choose gth Tubes for RCNN\TCN          #
        ###################################################
        if self.training:

            # # get gt tubes and feats
            ##  calculate overlaps

            # One flattened row of 4 * frames coordinates per action.
            boxes_ = boxes.permute(1, 0, 2).contiguous()
            boxes_ = boxes_[:, :, :4].contiguous().view(num_actions, -1)

            if final_tubes.nelement() == 0:

                # Diagnostics only; execution continues below.
                print('problem final_tubes ...')
                print('boxes :', boxes.cpu().numpy())
                print('boxes_ :', boxes_)
                print('boxes_.shape :', boxes_.shape)
                print('final_tubes :', final_tubes)
                print('self.calc.thresh:', self.calc.thresh)
                print('final_scores :', final_scores.shape)
                print('final_pos.shape :', final_poss.shape)

            if final_tubes.nelement() > 0:
                overlaps = tube_overlaps(final_tubes.view(-1, num_frames * 4),
                                         boxes_.type_as(final_tubes))
                max_overlaps, _ = torch.max(overlaps, 1)
                max_overlaps = max_overlaps.clamp_(min=0)

                ## TODO change numbers
                # Background candidates: best overlap with any gt below 0.3;
                # 9 of them are drawn at random with replacement.
                bg_tubes_indices = max_overlaps.lt(0.3).nonzero()
                if bg_tubes_indices.nelement() > 0:
                    bg_tubes_indices_picked = (
                        torch.rand(9) * bg_tubes_indices.size(0)).long()
                    bg_tubes_list = [
                        f_tubes[i]
                        for i in bg_tubes_indices[bg_tubes_indices_picked]
                    ]
                    bg_labels = torch.zeros(len(bg_tubes_list))
                else:
                    bg_tubes_list = []
                    bg_labels = torch.Tensor([])
            else:
                bg_tubes_list = []
                bg_labels = torch.Tensor([])

            gt_tubes_list = [[] for i in range(num_actions)]

            for i in range(n_clips):

                overlaps = tube_overlaps(p_tubes[i],
                                         f_gt_tubes[i].type_as(p_tubes))
                max_overlaps, argmax_overlaps = torch.max(overlaps, 0)

                # A clip joins action j's tube only when some proposal matches
                # the ground truth exactly (overlap == 1.0).
                # NOTE(review): the stored roi index is j, not
                # argmax_overlaps[j]; presumably act_net puts the gt tubes in
                # the first roi slots while training -- confirm.
                for j in range(num_actions):
                    if max_overlaps[j] == 1.0:
                        gt_tubes_list[j].append((i, j))
            gt_tubes_list = [i for i in gt_tubes_list if i != []]
            if len(gt_tubes_list) != num_actions:
                print('len(gt_tubes_list :', len(gt_tubes_list))
                print('num_actions :', num_actions)
                print('boxes.cpu().numpy() :', boxes.cpu().numpy())

            ## concate fb, bg tubes
            # NOTE(review): empty lists were filtered out just above, so
            # gt_tubes_list can never equal [[]] and this branch looks
            # unreachable.
            if gt_tubes_list == [[]]:
                print('overlaps :', overlaps)
                print('max_overlaps :', max_overlaps)
                print('p_tubes :', p_tubes)
                print('f_gt_tubes :', f_gt_tubes)
                exit(-1)
            # In training the linked tubes are replaced by gt + background.
            if bg_tubes_list != []:
                f_tubes = gt_tubes_list + bg_tubes_list
                target_lbl = torch.cat([labels, bg_labels], dim=0)
            else:
                f_tubes = gt_tubes_list
                target_lbl = labels

        ##############################################

        if len(f_tubes) == 0:
            print('------------------')
            print('    empty tube    ')
            print(' vid_id :', vid_id)
            print('self.calc.thresh :', self.calc.thresh)
            return torch.Tensor([]).cuda(), torch.Tensor([]).cuda(), None
        # NOTE(review): max_seq / max_length are not used further in this
        # method.
        max_seq = reduce(lambda x, y: y if len(y) > len(x) else x, f_tubes)
        max_length = len(max_seq)

        ## calculate input rois
        prob_out = torch.zeros(len(f_tubes), self.n_classes).cuda()
        # Padded per-tube features (-1 where a tube has no clip).
        # NOTE(review): 64 is hard-coded here while `feats` below uses
        # self.p_feat_size; presumably the two are equal -- confirm.
        f_feats = torch.zeros(len(f_tubes), n_clips, 64,
                              self.sample_duration).type_as(features) - 1
        f_feats_len = torch.zeros(len(f_tubes)).type_as(features) - 1

        for i in range(len(f_tubes)):

            seq = f_tubes[i]

            feats = torch.Tensor(len(seq), self.p_feat_size,
                                 self.sample_duration)

            # Gather the pooled feature of every (clip, roi) in the tube.
            for j in range(len(seq)):

                feats[j] = features[seq[j][0], seq[j][1]]

            f_feats_len[i] = len(seq)
            f_feats[i, :len(seq)] = feats
            # The classifier sees the mean over the tube's clips, flattened
            # to a single row.
            prob_out[i] = self.act_rnn(
                feats.mean(0).view(1, -1).contiguous().cuda())

        if mode == 'extract':
            # now we use mean so we can have a tensor containing all features
            # NOTE(review): target_lbl is only assigned in the training branch
            # above, so 'extract' outside training would raise NameError.
            target_lbl = target_lbl.cuda()
            return f_feats, target_lbl, f_feats_len
        # ##########################################
        # #           Time for Linear Loss         #
        # ##########################################

        cls_loss = torch.Tensor([0]).cuda()

        final_tubes = final_tubes.type_as(final_poss)
        # # classification probability
        if self.training:
            cls_loss = F.cross_entropy(prob_out.cpu(),
                                       target_lbl.long()).cuda()

        if self.training:
            return None, prob_out, cls_loss,
        else:

            # init padding tubes because of multi-GPU system
            # NOTE(review): torch.sort is ascending by default, so this keeps
            # the conf.UPDATE_THRESH lowest-scored entries; it also assumes
            # final_scores lines up with final_tubes -- confirm both.
            if final_tubes.size(0) > conf.UPDATE_THRESH:
                _, indices = torch.sort(final_scores)
                final_tubes = final_tubes[
                    indices[:conf.UPDATE_THRESH]].contiguous()
                prob_out = prob_out[indices[:conf.UPDATE_THRESH]].contiguous()

            max_prob_out, _ = torch.max(prob_out, 1)

            # Tube coordinates plus best class score, as input for tube NMS.
            f_tubes = torch.cat([
                final_tubes.view(-1, num_frames * 4),
                max_prob_out.view(-1, 1).type_as(final_tubes)
            ],
                                dim=1)

            keep = torch.Tensor(py_cpu_nms_tubes(f_tubes.float(), 0.5)).long()
            final_tubes = final_tubes[keep]
            prob_out = prob_out[keep]

            # Fixed-size outputs filled with -1; only the first
            # final_tubes.size(0) rows (and num_frames frames) are valid.
            ret_tubes = torch.zeros(1, conf.UPDATE_THRESH, ret_n_frames,
                                    4).type_as(final_tubes).float() - 1
            ret_prob_out = torch.zeros(
                1, conf.UPDATE_THRESH,
                self.n_classes).type_as(final_tubes).float() - 1
            ret_tubes[0, :final_tubes.size(0), :num_frames] = final_tubes
            ret_prob_out[0, :final_tubes.size(0)] = prob_out
            return ret_tubes, ret_prob_out, torch.Tensor([final_tubes.size(0)
                                                          ]).cuda()
def validation(epoch, device, model, dataset_folder, sample_duration, spatial_transform, temporal_transform, boxes_file, splt_txt_path, cls2idx, batch_size, n_threads):
    """Evaluate tube recall and mAP of ``model`` on the test split.

    For every test video the model's tubes are compared with the ground-truth
    tubes. Recall is accumulated at IoU thresholds 0.5, 0.4 and 0.3, and the
    confident non-background detections are handed to ``calculate_mAP`` at
    the same three thresholds. Everything is reported with ``print``.

    Args:
        epoch: zero-based epoch index, only used in the printed report.
        device: device the batch tensors are moved to.
        model: network to evaluate; it is switched to eval mode.
        dataset_folder: dataset root, forwarded to ``video_names`` and to
            the model.
        sample_duration, spatial_transform, temporal_transform, batch_size,
            n_threads: accepted for interface compatibility; not used here
            (the loader always yields one video per step).
        boxes_file: ground-truth boxes file, forwarded to ``video_names``.
        splt_txt_path: split definition path, forwarded to ``video_names``.
        cls2idx: class-name to index mapping.

    Returns:
        None.

    Note:
        Relies on the module-level names ``vid2idx``, ``n_devs`` and
        ``vid_names``.
    """

    iou_thresh = 0.5 # Intersection Over Union thresh
    iou_thresh_4 = 0.4 # Intersection Over Union thresh
    iou_thresh_3 = 0.3 # Intersection Over Union thresh

    confidence_thresh = 0.2
    # Use the parameter: the original read `split_txt_path`, which silently
    # ignored the argument (or raised NameError without such a global).
    vid_name_loader = video_names(dataset_folder, splt_txt_path, boxes_file, vid2idx, mode='test', classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(vid_name_loader, batch_size=1, num_workers=8*n_devs, pin_memory=True,
                                              shuffle=True)

    model.eval()

    true_pos = 0
    false_neg = 0

    true_pos_4 = 0
    false_neg_4 = 0

    true_pos_3 = 0
    false_neg_3 = 0

    groundtruth_dic = {}
    detection_dic = {}

    # Keeps the final report printable even when the loader yields nothing.
    step = 0
    for step, data in enumerate(data_loader):

        print('step =>',step)

        vid_id, clips, boxes, n_frames, n_actions, h, w, target = data
        vid_id = vid_id.int()
        clips = clips.to(device)
        boxes = boxes.to(device)
        n_frames = n_frames.to(device)
        n_actions = n_actions.int().to(device)
        target = target.to(device)

        mode = 'test'

        # Evaluation needs no autograd graph.
        with torch.no_grad():
            tubes,  \
            prob_out, n_tubes =  model(n_devs, dataset_folder, \
                                    vid_names, clips, vid_id,  \
                                    None, \
                                    mode, cls2idx, None, n_frames, h, w)

        # The model answers with empty 1-dim tensors (and no tube count) when
        # it found no tube. This must be handled BEFORE prob_out / n_tubes are
        # indexed: every ground-truth tube is then a miss.
        # NOTE(review): such videos get no entry in the mAP dictionaries --
        # confirm that calculate_mAP is meant to skip them.
        if tubes.dim() == 1:

            for i in range(clips.size(0)):

                n_fr = int(n_frames[i])
                box = boxes[i,:int(n_actions[i]), :n_fr,:4].contiguous()
                box = box.view(-1,n_fr*4)

                n_elems = box.ne(0).any(dim=1).nonzero().view(-1).nelement()
                false_neg += n_elems
                false_neg_4 += n_elems
                false_neg_3 += n_elems
            continue

        # Iterate over what the loader delivered, not the `batch_size`
        # argument (the loader batch is fixed to 1 above).
        for i in range(clips.size(0)):

            n_tub = int(n_tubes[i])
            n_fr = int(n_frames[i])

            # ---- detections for mAP ----
            # Best class and its probability for every valid tube; the rows
            # beyond n_tub are padding.
            max_prob, cls_int = torch.max(prob_out[i],1)
            f_prob = max_prob[:n_tub].type_as(prob_out)
            cls_int = cls_int[:n_tub]

            # Per-tube filter: confident and not background (class 0). The
            # original tested only the last tube's probability (`f_prob[j]`).
            keep_ = f_prob.ge(confidence_thresh) & cls_int.ne(0)

            keep_indices = keep_.nonzero().view(-1)
            f_tubes = torch.cat([cls_int.view(-1,1).type_as(tubes),f_prob.view(-1,1).type_as(tubes), \
                                 tubes[i,:n_tub,:n_fr].contiguous().view(-1,n_fr*4)], dim=1)
            f_tubes = f_tubes[keep_indices].contiguous()
            # NOTE(review): the flat view below only works when the video
            # holds a single ground-truth action -- confirm for the dataset.
            f_boxes = torch.cat([target.type_as(boxes),boxes[i,:,:n_fr,:4].contiguous().view(n_fr*4)]).unsqueeze(0)
            v_name = vid_names[vid_id[i]].split('/')[1]

            detection_dic[v_name] = f_tubes.float()
            groundtruth_dic[v_name] = f_boxes.type_as(f_tubes)

            # ---- recall of the proposed tubes ----
            print('boxes.shape:',boxes.shape)

            box = boxes[i,:int(n_actions[i]), :n_fr,:4].contiguous()
            print('box.shape :',box.shape)
            box = box.view(-1,n_fr*4).contiguous().type_as(tubes)
            print('box.shape :',box.shape)
            print('n_frames :',n_frames)
            print('n_tubes :',n_tubes)
            overlaps = tube_overlaps(tubes[i,:n_tub,:n_fr].contiguous().view(-1,n_fr*4).float(), box.float())
            gt_max_overlaps, _ = torch.max(overlaps, 0)

            # Ground-truth rows that are entirely zero are padding.
            non_empty_indices =  box.ne(0).any(dim=1).nonzero().view(-1)
            n_elems = non_empty_indices.nelement()

            print('gt_max_overlaps :',gt_max_overlaps)

            best_overlaps = gt_max_overlaps[non_empty_indices]

            # 0.5 thresh
            detected = best_overlaps.gt(iou_thresh).sum().item()
            true_pos += detected
            false_neg += n_elems - detected

            # 0.4 thresh
            detected = best_overlaps.gt(iou_thresh_4).sum().item()
            true_pos_4 += detected
            false_neg_4 += n_elems - detected

            # 0.3 thresh
            detected = best_overlaps.gt(iou_thresh_3).sum().item()
            true_pos_3 += detected
            false_neg_3 += n_elems - detected

    # Recall is reported as 0 when no ground-truth tube was counted at all.
    total = true_pos + false_neg
    recall = float(true_pos) / total if total else 0.0
    total_4 = true_pos_4 + false_neg_4
    recall_4 = float(true_pos_4) / total_4 if total_4 else 0.0
    total_3 = true_pos_3 + false_neg_3
    recall_3 = float(true_pos_3) / total_3 if total_3 else 0.0


    print(' -----------------------')
    print('| Validation Epoch: {: >3} | '.format(epoch+1))
    print('|                       |')
    print('| Proposed Action Tubes |')
    print('|                       |')
    print('| Single frame          |')
    print('|                       |')
    print('| In {: >6} steps    :  |'.format(step))
    print('|                       |')
    print('| Threshold : 0.5       |')
    print('|                       |')
    print('| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'.format(
        true_pos, false_neg, recall))
    print('|                       |')
    print('| Threshold : 0.4       |')
    print('|                       |')
    print('| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'.format(
        true_pos_4, false_neg_4, recall_4))
    print('|                       |')
    print('| Threshold : 0.3       |')
    print('|                       |')
    print('| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'.format(
        true_pos_3, false_neg_3, recall_3))


    print(' -----------------------')

    print(' -----------------')
    print('|                  |')
    print('| mAP Thresh : 0.5 |')
    print('|                  |')
    print(' ------------------')
    calculate_mAP(detection_dic, groundtruth_dic, iou_thresh)

    print(' -------------------')
    print('|                   |')
    print('| mAP Thresh : 0.4  |')
    print('|                   |')
    print(' -------------------')
    calculate_mAP(detection_dic, groundtruth_dic, iou_thresh_4)

    print(' ------------------')
    print('|                  |')
    print('| mAP Thresh : 0.3 |')
    print('|                  |')
    print(' ------------------')
    calculate_mAP(detection_dic, groundtruth_dic, iou_thresh_3)