コード例 #1
0
def main():
    """Fit a GPU XGBoost regressor and print its mean validation correlation.

    Prints the dataset shapes, the regressor's ``score`` on the training
    data, and finally the correlation averaged over all validation videos.
    """
    # compose data for xgboost
    train, train_target, val, val_target = load_dataset()  # np arrays [X 2048] [X 15] (large X)
    print(train.shape, train_target.shape, val[0].shape, val_target[0].shape)

    regressor = xgb.XGBRegressor(tree_method='gpu_hist', predictor='gpu_predictor')
    regressor.fit(train, train_target)
    print('Score:', regressor.score(train, train_target))

    video_count = len(val)
    corr_total = 0.0
    for idx, video_feats in enumerate(val):
        preds = torch.from_numpy(regressor.predict(video_feats)).cuda()
        # append a copy of the last prediction, then add a trailing axis (T -> T 1)
        preds = torch.cat([preds, preds[-1:]]).unsqueeze(1)
        # presumably resamples from rate 1 to rate 6 -- confirm against interpolate_output
        preds = interpolate_output(preds, 1, 6)
        # trim the predictions to the number of labelled steps for this video
        target = val_target[idx]
        preds = preds[:target.shape[0]]
        labels = torch.from_numpy(target).cuda().unsqueeze(1)  # T -> T 1
        corr, _ = correlation(preds, labels)
        corr_total += corr.item()
    print('Correlation:', corr_total / video_count)
コード例 #2
0
def test_result(filt_emp):
    """Print and return the correlation of ``filt_emp`` with ``vid_labels``.

    ``filt_emp`` is extended with a copy of its last entry, passed through
    ``interpolate_output(..., 1, 6)``, cut to the length of the module-level
    ``vid_labels`` and compared with it via ``correlation``.
    """
    global best_conf  # declared here but never assigned inside this function
    # do a correlation evaluation
    padded = np.concatenate((filt_emp, filt_emp[-1:]))
    output = interpolate_output(torch.from_numpy(padded), 1, 6)
    label_count = vid_labels.shape[0]
    cor, _ = correlation(output[:label_count], torch.from_numpy(vid_labels))
    score = cor.item()
    print('Correlation:', score)
    return score
コード例 #3
0
def _frame_timestamps(frame_count):
    """Yield ``(frame_index, timestamp)`` for ``frame_count`` frames, 6 per second.

    The timestamp advances by 1000000 for every whole second; the six frames
    inside a second use the fixed offsets below. Time starts at 0.
    """
    offsets = (0, 166666, 333333, 500000, 666666, 833333)
    for frame_idx in range(int(frame_count)):
        second, slot = divmod(frame_idx, 6)
        yield frame_idx, 1000000 * second + offsets[slot]


def main_test():
    """Run the checkpointed model on the test set and write ``test_output.csv``.

    Loads the checkpoint named by ``args.resume``, predicts every video served
    by the test loader, and writes one CSV row per frame, with videos ordered
    as they appear in ``args.test_vidmap``. Videos in the hard-coded
    ``missing`` list receive all-zero rows.

    Raises:
        SystemExit: with status 1 when ``args.resume`` is unset or not a file.
        KeyError: when a video id listed in ``args.test_vidmap`` has no
            entry in the collected results.
    """
    print('Running test...')
    torch.multiprocessing.set_sharing_strategy('file_system')
    model = Baseline()
    if args.use_swa:
        model = torch.optim.swa_utils.AveragedModel(model)
    model = torch.nn.DataParallel(model).cuda()
    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['state_dict'])
        print('Loaded ckpt at epoch:', args.start_epoch)
    else:
        print('No model given. Abort!')
        # raise directly instead of relying on the site-provided exit() helper
        raise SystemExit(1)

    test_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=None,
            vidmap_path=args.test_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='test',
            test_freq=args.test_freq
        ),
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False
    )

    model.eval()
    batch_time = AverageMeter()

    t_start = time.time()

    outputs = []
    with torch.no_grad():
        for i, (img_feat, au_feat, frame_count, vid) in enumerate(test_loader):
            img_feat = torch.stack(img_feat).cuda()
            au_feat = torch.stack(au_feat).cuda()
            assert len(au_feat.size()) == 3, 'bad auf %s' % (vid)
            output = model(img_feat, au_feat)  # [Clip S 15]
            # flatten the clips into one sequence
            output = rearrange(output, 'Clip S C -> (Clip S) C')
            # repeat the last frame to avoid missing
            output = torch.cat([output, output[-1:]])
            if args.train_freq < args.test_freq:
                output = interpolate_output(output, args.train_freq, 6)
            # truncate extra frames
            assert output.size(0) >= frame_count, '{}/{}'.format(output.size(0), frame_count)
            output = output[:frame_count]
            outputs.append((vid, frame_count, output.cpu().detach().numpy()))

            # update statistics
            batch_time.update(time.time() - t_start)
            t_start = time.time()

            if i % args.print_freq == 0:
                progress = ('Test: [{0}/{1}]\t'
                            'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                    i, len(test_loader), batch_time=batch_time))
                print(progress)

    # NOTE(review): the column is labelled milliseconds, yet timestamps advance
    # by 1000000 per second -- confirm the unit the consumer expects.
    header = 'Video ID,Timestamp (milliseconds),amusement,anger,awe,concentration,confusion,contempt,contentment,disappointment,doubt,elation,interest,pain,sadness,surprise,triumph\n'

    # video id -> CSV lines, in frame order
    final_res = {}
    for vid, frame_count, out in outputs:
        rows = final_res.setdefault(vid, [])
        for frame_idx, timestamp in _frame_timestamps(frame_count):
            scores = ','.join([str(x) for x in out[frame_idx]])
            rows.append('{},{},{}\n'.format(vid, timestamp, scores))

    # fixed for now: these videos get a zero for each of the 15 classes
    missing = [('WKXrnB7alT8', 2919), ('o0ooW14pIa4', 3733), ('GufMoL_MuNE',2038), ('Uee0Tv1rTz8', 1316), ('ScvvOWtb04Q', 152), ('R9kJlLungmo', 3609),('QMW3GuohzzE', 822), ('fjJYTW2n6rk', 4108), ('rbTIMt0VcLw', 1084),('L9cdaj74kLo', 3678), ('l-ka23gU4NA', 1759)]
    zero_scores = ',0' * 15
    for vid, length in missing:
        rows = final_res.setdefault(vid, [])
        for _, timestamp in _frame_timestamps(length):
            rows.append('{},{}{}\n'.format(vid, timestamp, zero_scores))

    print('Write test outputs...')
    # Read the video order first, inside a context manager so the handle is
    # closed; the first space-separated token of each non-blank line is the id.
    with open(args.test_vidmap) as vidmap_file:
        vid_order = [line.strip().split(' ')[0]
                     for line in vidmap_file if line.strip()]
    with open('test_output.csv', 'w') as file:
        file.write(header)
        for vid in tqdm(vid_order):
            file.writelines(final_res[vid])
コード例 #4
0
def validate(val_loader, model, accuracy, epoch, log=None, tb_writer=None):
    """Evaluate ``model`` on ``val_loader``; return ``(corr_avg, loss_avg)``.

    Args:
        val_loader: iterable of ``(img_feat, au_feat, labels, frame_count)``.
        model: network to evaluate; put into eval mode here.
        accuracy: callable returning ``(mean_correlation, per_class_correlation)``.
        epoch: epoch index used in the NaN assertion message and TensorBoard.
        log: optional file-like object that also receives the printed lines.
        tb_writer: optional writer exposing ``add_scalar``.

    Returns:
        Tuple of ``correlations.avg`` and ``losses.avg``.
    """
    def report(line):
        # echo to stdout and, when a log handle was given, append it there too
        print(line)
        if log is not None:
            log.write(line + '\n')
            log.flush()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    correlations = AverageMeter()

    model.eval()
    tick = time.time()

    # load 1 video at a time for now (loader is not batched when batch_size == None)
    # but since we split one vid into multiple small clips, the input shape is still batch-like
    with torch.no_grad():
        for step, (img_feat, au_feat, labels, frame_count) in enumerate(val_loader):
            data_time.update(time.time() - tick)

            img_feat = torch.stack(img_feat).cuda()
            au_feat = torch.stack(au_feat).cuda()
            labels = torch.stack(labels).cuda()

            pred = model(img_feat, au_feat)  # [Clip S 15]
            # flatten clips into one sequence and drop padding past frame_count
            pred = rearrange(pred, 'Clip S C -> (Clip S) C')
            pred = torch.cat([pred, pred[-1:]])  # repeat the last frame to avoid missing
            if args.train_freq < args.val_freq:
                pred = interpolate_output(pred, args.train_freq, args.val_freq)
            pred = pred[:frame_count]
            labels = rearrange(labels, 'Clip S C -> (Clip S) C')[:frame_count]

            loss = loss_function(pred, labels, validate=True)

            # mean and per-class correlation
            mean_cor, cor = accuracy(pred, labels)
            # update statistics
            losses.update(loss.item())
            assert not math.isnan(mean_cor.item()), 'at epoch %d' % (epoch)
            correlations.update(mean_cor.item())

            now = time.time()
            batch_time.update(now - tick)
            tick = time.time()

            if step % args.print_freq != 0:
                continue
            report(('Val: [{0}/{1}]\t'
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Corr: {corr.val:.4f} ({corr.avg:.4f}, {corr.count})'.format(
                        step, len(val_loader), batch_time=batch_time,
                        loss=losses, corr=correlations)))

    report('Validate Results: Corr:{corr.avg:.4f} Loss {loss.avg:.5f}'
           .format(corr=correlations, loss=losses))

    if tb_writer is not None:
        tb_writer.add_scalar('loss/validate', losses.avg, epoch)
        tb_writer.add_scalar('acc/validate_corr', correlations.avg, epoch)

    return correlations.avg, losses.avg
コード例 #5
0
def validate(val_loader, model, accuracy, epoch, log=None, tb_writer=None):
    """Evaluate ``model`` on ``val_loader`` and return the averaged correlation.

    Args:
        val_loader: iterable of ``(img_feat, au_feat, labels, frame_count)``.
        model: network to evaluate; put into eval mode here.
        accuracy: callable returning ``(mean_correlation, per_class_correlation)``.
        epoch: epoch index used in the NaN assertion message and TensorBoard.
        log: optional file-like object that also receives the printed lines.
        tb_writer: optional writer exposing ``add_scalar``.

    Returns:
        ``correlations.avg`` of the per-item mean correlation meter. When
        ``args.cls_mask`` is set, each item's mean is taken over the masked
        classes only.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    correlations = AverageMeter()
    pc_corr = []
    model.eval()
    t_start = time.time()

    # load 1 video at a time for now (loader is not batched when batch_size == None)
    # but since we split one vid into multiple small clips, the input shape is still batch-like
    with torch.no_grad():
        for i, (img_feat, au_feat, labels,
                frame_count) in enumerate(val_loader):
            data_time.update(time.time() - t_start)

            img_feat = torch.stack(img_feat).cuda()
            au_feat = torch.stack(au_feat).cuda()
            labels = torch.stack(labels).cuda()
            if args.model == 'EmoBase':
                output, _, _ = model(img_feat, au_feat)
            else:
                if args.repeat_sample:
                    img_feat = rearrange(
                        img_feat, 'Clip R S C -> (Clip R) S C')  # Clip R S C
                output = model_output(model, img_feat, au_feat)

            # flatten clips into one sequence; with repeat_sample the R axis
            # (fixed at 6 here) is interleaved inside each step
            if args.repeat_sample:
                output = rearrange(output, '(Clip R) S C -> (Clip S R) C', R=6)
            else:
                output = rearrange(output, 'Clip S C -> (Clip S) C')

            # repeat the last frame to avoid missing
            output = torch.cat([output, output[-1:]])
            if not args.repeat_sample and args.train_freq < args.val_freq:
                output = interpolate_output(output, args.train_freq,
                                            args.val_freq)
            # drop everything past the real frame count
            output = output[:frame_count]
            labels = rearrange(labels, 'Clip S C -> (Clip S) C')[:frame_count]

            loss, c_loss, t_loss = loss_function(output,
                                                 labels,
                                                 args,
                                                 validate=True)

            # mean and per-class correlation
            mean_cor, cor = accuracy(output, labels)
            # identity test: `!= None` would compare element-wise for an
            # array-like mask and is not a reliable "is it set" check
            if args.cls_mask is not None:
                mean_cor = torch.mean(cor[args.cls_mask])
            pc_corr.append(cor)
            # update statistics
            losses.update(loss.item())
            assert not math.isnan(mean_cor.item()), 'at epoch %d' % (epoch)
            correlations.update(mean_cor.item())

            batch_time.update(time.time() - t_start)
            t_start = time.time()

            if i % args.print_freq == 0:
                progress = ('Val: [{0}/{1}]\t'
                            'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                            'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                            'Corr: {corr.val:.4f} ({corr.avg:.4f})'.format(
                                i,
                                len(val_loader),
                                batch_time=batch_time,
                                loss=losses,
                                corr=correlations))
                print(progress)
                if log is not None:
                    log.write(progress + '\n')
                    log.flush()

    summary = (
        'Validate Results: Corr:{corr.avg:.4f} Loss {loss.avg:.5f}'.format(
            corr=correlations, loss=losses))
    print(summary)
    # torch.stack raises on an empty list, so skip the per-class report when
    # the loader produced no items
    if pc_corr:
        pc_corr = torch.stack(pc_corr, dim=0)
        pc_corr = torch.mean(pc_corr, dim=0).cpu().numpy()
        print('Per-Class Corr:', pc_corr)
    if log is not None:
        log.write(summary + '\n')
        log.flush()

    if tb_writer is not None:
        tb_writer.add_scalar('loss/validate', losses.avg, epoch)
        tb_writer.add_scalar('acc/validate_corr', correlations.avg, epoch)

    return correlations.avg