Example #1
def validation(epoch, device, model, dataset_folder, sample_duration,
               spatial_transform, temporal_transform, boxes_file,
               splt_txt_path, cls2idx, batch_size, n_threads):

    iou_thresh = 0.5    # IoU threshold for recall@0.5
    iou_thresh_4 = 0.4  # IoU threshold for recall@0.4
    iou_thresh_3 = 0.3  # IoU threshold for recall@0.3

    data = Video(dataset_folder,
                 frames_dur=sample_duration,
                 spatial_transform=spatial_transform,
                 temporal_transform=temporal_transform,
                 json_file=boxes_file,
                 split_txt_path=splt_txt_path,
                 mode='test',
                 classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=n_threads,
                                              pin_memory=True)
    model.eval()
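    # NOTE: the evaluation loop below would normally also run under
    # torch.no_grad() to avoid building the autograd graph.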

    sgl_true_pos = 0
    sgl_false_neg = 0

    sgl_true_pos_4 = 0
    sgl_false_neg_4 = 0

    sgl_true_pos_3 = 0
    sgl_false_neg_3 = 0

    for step, data in enumerate(data_loader):

        clips, h, w, gt_tubes_r, gt_rois, n_actions, n_frames, im_info = data

        clips_ = clips.to(device)
        gt_tubes_r_ = gt_tubes_r.to(device)
        gt_rois_ = gt_rois.float().to(device)
        n_actions_ = n_actions.to(device)
        im_info_ = im_info.to(device)
        start_fr = torch.zeros(clips_.size(0)).to(device)

        tubes, _, _, _, _, _, _, \
        sgl_rois_bbox_pred, _ = model(clips_,
                                      im_info_,
                                      None, None,
                                      None)

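        # tubes carries sample_duration * 4 box coordinates per row, framed by
        # an index column in front and a score column at the end (assumed
        # layout); decode the regression offsets, then clip to the image.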
        n_tubes = len(tubes)
        tubes = tubes.view(-1, sample_duration * 4 + 2)
        tubes[:,1:-1] = tube_transform_inv(tubes[:,1:-1],\
                                           sgl_rois_bbox_pred.view(-1,sample_duration*4),(1.0,1.0,1.0,1.0))
        tubes = tubes.view(n_tubes, -1, sample_duration * 4 + 2)
        tubes[:, :, 1:-1] = clip_boxes(tubes[:, :, 1:-1], im_info,
                                       tubes.size(0))

        for i in range(tubes.size(0)):  # iterate over each sample in the batch

            tubes_t = tubes[i, :, 1:-1].contiguous()
            gt_rois_t = gt_rois_[i, :, :, :4].contiguous().view(
                -1, sample_duration * 4)

            rois_overlaps = tube_overlaps(tubes_t, gt_rois_t)
            gt_max_overlaps_sgl, _ = torch.max(rois_overlaps, 0)
            non_empty_indices = gt_rois_t.ne(0).any(dim=1).nonzero().view(-1)
            n_elems = non_empty_indices.nelement()

            # recall @ IoU 0.5
            gt_max_overlaps_sgl_ = torch.where(
                gt_max_overlaps_sgl > iou_thresh, gt_max_overlaps_sgl,
                torch.zeros_like(gt_max_overlaps_sgl))

            sgl_detected = gt_max_overlaps_sgl_[non_empty_indices].ne(0).sum()
            sgl_true_pos += sgl_detected
            sgl_false_neg += n_elems - sgl_detected

            # recall @ IoU 0.4
            gt_max_overlaps_sgl_ = torch.where(
                gt_max_overlaps_sgl > iou_thresh_4, gt_max_overlaps_sgl,
                torch.zeros_like(gt_max_overlaps_sgl))
            sgl_detected = gt_max_overlaps_sgl_[non_empty_indices].ne(0).sum()
            sgl_true_pos_4 += sgl_detected
            sgl_false_neg_4 += n_elems - sgl_detected

            # recall @ IoU 0.3
            gt_max_overlaps_sgl_ = torch.where(
                gt_max_overlaps_sgl > iou_thresh_3, gt_max_overlaps_sgl,
                torch.zeros_like(gt_max_overlaps_sgl))
            sgl_detected = gt_max_overlaps_sgl_[non_empty_indices].ne(0).sum()
            sgl_true_pos_3 += sgl_detected
            sgl_false_neg_3 += n_elems - sgl_detected

    recall = float(sgl_true_pos) / (float(sgl_true_pos) + float(sgl_false_neg))
    recall_4 = float(sgl_true_pos_4) / (float(sgl_true_pos_4) +
                                        float(sgl_false_neg_4))
    recall_3 = float(sgl_true_pos_3) / (float(sgl_true_pos_3) +
                                        float(sgl_false_neg_3))

    with open('../images_etc/recall_jhmdb.txt', 'a') as f:
        f.write('| Validation Epoch: {: >3} |\n'.format(epoch + 1))
        f.write('| Threshold : 0.5       |\n')
        f.write(
            '| True_pos   --> {: >6} |\n| False_neg  --> {: >6} |\n| Recall     --> {: >6.4f} |\n'
            .format(sgl_true_pos, sgl_false_neg, recall))
        f.write('| Threshold : 0.4       |\n')
        f.write(
            '| True_pos   --> {: >6} |\n| False_neg  --> {: >6} |\n| Recall     --> {: >6.4f} |\n'
            .format(sgl_true_pos_4, sgl_false_neg_4, recall_4))
        f.write('| Threshold : 0.3       |\n')
        f.write(
            '| True_pos   --> {: >6} |\n| False_neg  --> {: >6} |\n| Recall     --> {: >6.4f} |\n'
            .format(sgl_true_pos_3, sgl_false_neg_3, recall_3))

    print(' -----------------------')
    print('| Validation Epoch: {: >3} | '.format(epoch + 1))
    print('|                       |')
    print('| Proposed Action Tubes |')
    print('|                       |')
    print('| Single frame          |')
    print('|                       |')
    print('| In {: >6} steps    :  |'.format(step))
    print('|                       |')
    print('| Threshold : 0.5       |')
    print('|                       |')
    print(
        '| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'
        .format(sgl_true_pos, sgl_false_neg, recall))
    print('|                       |')
    print('| Threshold : 0.4       |')
    print('|                       |')
    print(
        '| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'
        .format(sgl_true_pos_4, sgl_false_neg_4, recall_4))
    print('|                       |')
    print('| Threshold : 0.3       |')
    print('|                       |')
    print(
        '| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'
        .format(sgl_true_pos_3, sgl_false_neg_3, recall_3))

    print(' -----------------------')
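The three threshold blocks above all follow the same counting pattern: take each ground-truth tube's best-overlapping proposal, zero it out below the IoU threshold, and count the survivors as true positives. A minimal, self-contained sketch of that pattern, with a toy overlaps matrix standing in for the output of tube_overlaps:

import torch

def recall_at_iou(overlaps, iou_thresh=0.5):
    # overlaps: (n_proposals, n_gt) IoU matrix; a ground truth counts as
    # detected when its best-matching proposal exceeds the threshold
    best_per_gt, _ = overlaps.max(dim=0)
    detected = (best_per_gt > iou_thresh).sum().item()
    return detected / max(best_per_gt.numel(), 1)

# toy example: 3 proposals vs. 2 ground-truth tubes
overlaps = torch.tensor([[0.62, 0.10],
                         [0.40, 0.55],
                         [0.05, 0.30]])
print(recall_at_iou(overlaps, 0.5))  # 1.0 -- both GTs matched above 0.5
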
Example #2
def training(epoch,
             device,
             model,
             dataset_folder,
             sample_duration,
             spatial_transform,
             temporal_transform,
             boxes_file,
             splt_txt_path,
             cls2idx,
             batch_size,
             n_threads,
             lr,
             mode=1):

    print('sample_duration :', sample_duration)
    data = Video(dataset_folder,
                 frames_dur=sample_duration,
                 spatial_transform=spatial_transform,
                 temporal_transform=temporal_transform,
                 json_file=boxes_file,
                 split_txt_path=splt_txt_path,
                 mode='train',
                 classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=batch_size * 16,
                                              shuffle=True,
                                              num_workers=n_threads,
                                              pin_memory=True)

    model.train()
    loss_temp = 0

    for step, data in enumerate(data_loader):

        clips, h, w, gt_tubes_r, gt_rois, n_actions, n_frames, im_info = data
        clips_ = clips.to(device)
        gt_tubes_r_ = gt_tubes_r.to(device)
        gt_rois_ = gt_rois.float().to(device)
        n_actions_ = n_actions.to(device)
        im_info_ = im_info.to(device)
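        # start_fr: presumably the first-frame index of each clip within its
        # source video; every clip here starts at frame 0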
        start_fr = torch.zeros(clips_.size(0)).to(device)

        inputs = Variable(clips_)

        tubes, _, \
        rpn_loss_cls, rpn_loss_bbox, \
        rpn_loss_cls_16, \
        rpn_loss_bbox_16, rois_label, \
        sgl_rois_bbox_pred, sgl_rois_bbox_loss = model(inputs,
                                                       im_info_,
                                                       gt_tubes_r_, gt_rois_,
                                                       start_fr)
        if mode == 3:
            loss = sgl_rois_bbox_loss.mean()
        elif mode == 4:
            loss = rpn_loss_cls.mean() + rpn_loss_bbox.mean()
        elif mode == 5:
            loss = (rpn_loss_cls.mean() + rpn_loss_bbox.mean() +
                    sgl_rois_bbox_loss.mean())
        else:
            # modes 1 and 2 are not handled in this snippet; fail loudly
            # instead of hitting a NameError on `loss` below
            raise ValueError('unsupported training mode: {}'.format(mode))

        loss_temp += loss.item()

        # backward
        # NOTE: `optimizer` is not an argument here; it is assumed to exist
        # at module scope.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('Train Epoch: {} \tLoss: {:.6f}\t lr : {:.6f}'.format(
        epoch + 1, loss_temp / (step + 1), lr))

    return model, loss_temp
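The original body also carried commented-out `step % 4` logic around zero_grad/step, which hints at gradient accumulation: updating the weights only every few batches to simulate a larger effective batch size. A minimal sketch of that pattern, assuming a generic model, optimizer, criterion, and data_loader rather than the tube network above:

def train_accumulate(model, optimizer, criterion, data_loader, accum_steps=4):
    # weights are updated every `accum_steps` batches; gradients from the
    # intermediate batches simply add up in .grad
    model.train()
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(data_loader):
        loss = criterion(model(inputs), targets) / accum_steps  # keep loss scale comparable
        loss.backward()
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    # a trailing partial accumulation is dropped here for simplicity
    return model
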
Example #3
def training(
    epoch,
    device,
    model,
    dataset_folder,
    sample_duration,
    spatial_transform,
    temporal_transform,
    boxes_file,
    splt_txt_path,
    cls2idx,
    batch_size,
    n_threads,
    lr,
):

    data = Video(dataset_folder,
                 frames_dur=sample_duration,
                 spatial_transform=spatial_transform,
                 temporal_transform=temporal_transform,
                 json_file=boxes_file,
                 split_txt_path=splt_txt_path,
                 mode='train',
                 classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=n_threads,
                                              pin_memory=True)

    model.train()
    loss_temp = 0

    for step, data in enumerate(data_loader):

        clips, (h, w), gt_tubes_r, gt_rois, n_actions, n_frames = data
        clips = clips.to(device)
        gt_tubes_r = gt_tubes_r.to(device)
        gt_rois = gt_rois.to(device)
        n_actions = n_actions.to(device)
        # [H, W, T] per ground-truth tube; `sample_size` is assumed to be
        # defined at module scope
        im_info = torch.Tensor([[sample_size, sample_size, n_frames]] *
                               gt_tubes_r.size(1)).to(device)
        inputs = Variable(clips)
        rois,  bbox_pred, cls_prob, \
        rpn_loss_cls, rpn_loss_bbox, \
        act_loss_cls, act_loss_bbox  = model(inputs,
                                             im_info,
                                             gt_tubes_r, gt_rois,
                                             n_actions)
        loss = (rpn_loss_cls.mean() + rpn_loss_bbox.mean() +
                act_loss_cls.mean() + act_loss_bbox.mean())
        loss_temp += loss.item()

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('Train Epoch: {} \tLoss: {:.6f}\t lr : {:.6f}'.format(
        epoch + 1, loss_temp / (step + 1), lr))

    return model, loss_temp
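For context, a hypothetical driver loop for the training() function above; every name besides `epoch` is assumed to be prepared earlier in the script, and the checkpoint interval is arbitrary:

for epoch in range(n_epochs):
    model, loss_temp = training(epoch, device, model, dataset_folder,
                                sample_duration, spatial_transform,
                                temporal_transform, boxes_file,
                                splt_txt_path, cls2idx, batch_size,
                                n_threads, lr)
    if (epoch + 1) % 5 == 0:  # arbitrary checkpoint interval
        torch.save(model.state_dict(), 'model_epoch_{}.pth'.format(epoch + 1))
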
Example #4
def validation(epoch, device, model, dataset_folder, sample_duration,
               spatial_transform, temporal_transform, boxes_file,
               splt_txt_path, cls2idx, batch_size, n_threads):

    iou_thresh = 0.5  # Intersection Over Union thresh
    data = Video(dataset_folder,
                 frames_dur=sample_duration,
                 spatial_transform=spatial_transform,
                 temporal_transform=temporal_transform,
                 json_file=boxes_file,
                 split_txt_path=splt_txt_path,
                 mode='val',
                 classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=n_threads,
                                              pin_memory=True)
    model.eval()
    true_pos = torch.zeros(1).long().to(device)
    false_neg = torch.zeros(1).long().to(device)
    for step, data in enumerate(data_loader):

        # NOTE: debug leftover -- validation stops after the first two batches
        if step == 2:
            break

        clips, (h, w), gt_tubes_r, gt_rois, n_actions, n_frames = data
        clips = clips.to(device)
        gt_tubes_r = gt_tubes_r.to(device)
        n_actions = n_actions.to(device)
        im_info = torch.Tensor([[sample_size, sample_size, n_frames]] *
                               gt_tubes_r.size(1)).to(device)
        inputs = Variable(clips)
        tubes, bbox_pred, cls_prob = model(inputs, im_info, gt_tubes_r,
                                           gt_rois, n_actions)

        overlaps = bbox_overlaps_batch_3d(
            tubes.squeeze(0), gt_tubes_r)  # check one video each time
        gt_max_overlaps, _ = torch.max(overlaps, 1)
        gt_max_overlaps = torch.where(
            gt_max_overlaps > iou_thresh, gt_max_overlaps,
            torch.zeros_like(gt_max_overlaps).type_as(gt_max_overlaps))
        detected = gt_max_overlaps.ne(0).sum()
        n_elements = gt_max_overlaps.nelement()
        true_pos += detected
        false_neg += n_elements - detected

    recall = true_pos.float() / (true_pos.float() + false_neg.float())
    print('recall :', recall)
    print(' -----------------------')
    print('| Validation Epoch: {: >3} | '.format(epoch + 1))
    print('|                       |')
    print('| Proposed Action Tubes |')
    print('|                       |')
    print(
        '| In {: >6} steps    :  |\n| True_pos   --> {: >6} |\n| False_neg  --> {: >6} | \n| Recall     --> {: >6.4f} |'
        .format(step,
                true_pos.cpu().tolist()[0],
                false_neg.cpu().tolist()[0],
                recall.cpu().tolist()[0]))
    print(' -----------------------')
    classes = [
        # ... earlier class names truncated in the source listing ...
        'swing_baseball', 'walk'
    ]

    cls2idx = {cls: idx for idx, cls in enumerate(classes)}

    spatial_transform = Compose([
        Scale(sample_size),
        ToTensor(),
        Normalize(mean, [1, 1, 1])
    ])
    temporal_transform = LoopPadding(sample_duration)

    data = Video(dataset_folder,
                 frames_dur=sample_duration,
                 spatial_transform=spatial_transform,
                 temporal_transform=temporal_transform,
                 json_file=boxes_file,
                 split_txt_path=splt_txt_path,
                 mode='train',
                 classes_idx=cls2idx)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=n_threads,
                                              pin_memory=True)

    n_classes = len(classes)
    resnet_shortcut = 'A'

    ## ResNet 34 init
    model = resnet34(num_classes=400,
                     shortcut_type=resnet_shortcut,