Пример #1
0
def main():
    """Evaluate a trained model on every split of every provided split file,
    logging per-split and mean F-score / diversity."""
    args = init_helper.get_arguments()

    init_helper.init_logger(args.model_dir, args.log_file)
    init_helper.set_random_seed(args.seed)
    logger.info(vars(args))

    # Build the model once; per-split weights are loaded inside the loop.
    model = get_model(args.model, **vars(args)).eval().to(args.device)

    for raw_path in args.splits:
        split_path = Path(raw_path)
        stats = data_helper.AverageMeter('fscore', 'diversity')

        for split_idx, split in enumerate(data_helper.load_yaml(split_path)):
            # Restore the checkpoint trained for this particular split.
            ckpt_path = data_helper.get_ckpt_path(
                args.model_dir, split_path, split_idx)
            state_dict = torch.load(
                str(ckpt_path), map_location=lambda storage, loc: storage)
            model.load_state_dict(state_dict)

            val_loader = data_helper.DataLoader(
                data_helper.VideoDataset(split['test_keys']), shuffle=False)

            fscore, diversity = evaluate(
                model, val_loader, args.nms_thresh, args.device)
            stats.update(fscore=fscore, diversity=diversity)

            logger.info(f'{split_path.stem} split {split_idx}: diversity: '
                        f'{diversity:.4f}, F-score: {fscore:.4f}')

        # Mean over all splits of this split file.
        logger.info(f'{split_path.stem}: diversity: {stats.diversity:.4f}, '
                    f'F-score: {stats.fscore:.4f}')
Пример #2
0
def evaluate(model, val_loader, nms_thresh, device):
    """Run the model over a validation loader and return the mean
    (F-score, diversity) across all videos."""
    model.eval()
    stats = data_helper.AverageMeter('fscore', 'diversity')

    with torch.no_grad():
        for test_key, seq, _, cps, n_frames, nfps, picks, user_summary in val_loader:
            seq_len = len(seq)
            batch = torch.from_numpy(seq).unsqueeze(0).to(device)

            pred_cls, pred_bboxes = model.predict(batch)

            # Clamp boxes into [0, seq_len] and make them integer indices.
            pred_bboxes = np.clip(
                pred_bboxes, 0, seq_len).round().astype(np.int32)
            pred_cls, pred_bboxes = bbox_helper.nms(
                pred_cls, pred_bboxes, nms_thresh)

            pred_summ = vsumm_helper.bbox2summary(
                seq_len, pred_cls, pred_bboxes, cps, n_frames, nfps, picks)

            # TVSum keys are scored with the average-over-users protocol,
            # everything else takes the max.
            eval_metric = 'avg' if 'tvsum' in test_key else 'max'
            fscore = vsumm_helper.get_summ_f1score(
                pred_summ, user_summary, eval_metric)

            diversity = vsumm_helper.get_summ_diversity(
                vsumm_helper.downsample_summ(pred_summ), seq)
            stats.update(fscore=fscore, diversity=diversity)

    return stats.fscore, stats.diversity
Пример #3
0
def main():
    """Evaluate a binary goal classifier wrapped as a soccer-video
    summarizer on every split of the given split file."""
    args = parse_args()
    logger.add(
        f"binary_classification_{args.model_name}_{args.splits.split('/')[-1]}.log"
    )

    # Restore the trained classifier and move it to the available device.
    model = ModelBuilder(args.model_name, num_classes=2).get_model()
    model.load_state_dict(torch.load(args.model_weights))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    summarizator = SoccerSummarizator(
        model=model,
        device=device,
        batch_size=args.batch_size,
        binary_closing=True,
        classification_type="goals_from_celebration")

    split_path = Path(args.splits)
    splits = data_helper.load_yaml(split_path)
    stats = data_helper.AverageMeter('fscore', 'diversity')

    for split_idx, split in enumerate(splits):
        loader = data_helper.DataLoader(
            data_helper.VideoDataset(split['test_keys']), shuffle=False)
        fscore, diversity = evaluate(summarizator, loader)
        logger.info(
            f"F-score:{round(fscore, 2)} Diversity:{round(diversity, 2)} on {split_idx} split"
        )
        stats.update(fscore=fscore, diversity=diversity)

    # Mean over all splits.
    logger.info(f'{split_path.stem}: diversity: {stats.diversity:.4f}, '
                f'F-score: {stats.fscore:.4f}')
def main():
    """Train a model on every split of every split file and dump the
    per-split and mean F-scores to a YAML results file."""
    args = init_helper.get_arguments()

    init_helper.init_logger(args.model_dir, args.log_file)
    init_helper.set_random_seed(args.seed)
    logger.info(vars(args))

    model_dir = Path(args.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    data_helper.get_ckpt_dir(model_dir).mkdir(parents=True, exist_ok=True)
    trainer = get_trainer(args.model)

    # Snapshot the exact arguments used for this run.
    data_helper.dump_yaml(vars(args), model_dir / 'args.yml')

    for raw_path in args.splits:
        split_path = Path(raw_path)
        results = {}
        stats = data_helper.AverageMeter('fscore')

        for split_idx, split in enumerate(data_helper.load_yaml(split_path)):
            logger.info(f'Start training on {split_path.stem}: split {split_idx}')
            ckpt_path = data_helper.get_ckpt_path(model_dir, split_path, split_idx)
            fscore = trainer(args, split, ckpt_path)
            stats.update(fscore=fscore)
            results[f'split{split_idx}'] = float(fscore)

        results['mean'] = float(stats.fscore)
        data_helper.dump_yaml(results, model_dir / f'{split_path.stem}.yml')

        logger.info(f'Training done on {split_path.stem}. F-score: {stats.fscore:.4f}')
Пример #5
0
def test_average_meter():
    """AverageMeter starts every key at zero and tracks the running mean."""
    count = 100
    xs = [random.randint(0, 100) for _ in range(count)]
    ys = [random.randint(0, 100) for _ in range(count)]

    meter = data_helper.AverageMeter('x', 'y')
    assert meter.x == 0.0
    assert meter.y == 0.0

    for x_val, y_val in zip(xs, ys):
        meter.update(x=x_val, y=y_val)

    # After all updates each attribute equals the arithmetic mean.
    assert meter.x == sum(xs) / count
    assert meter.y == sum(ys) / count
Пример #6
0
def evaluate(summarizator, val_loader):
    """Evaluate a summarizator against user summaries on a validation loader.

    Args:
        summarizator: callable that consumes a video stream (plus ffprobe
            metadata) and exposes the produced summary via `.summary`.
        val_loader: iterable yielding (test_key, seq, _, cps, n_frames,
            nfps, picks, user_summary) tuples.

    Returns:
        Tuple (mean F-score, diversity). NOTE: diversity is never updated
        inside the loop, so the second value is always the meter's initial
        0.0 — callers should rely on the F-score only.
    """
    stats = data_helper.AverageMeter('fscore', 'diversity')
    with torch.no_grad():
        for test_key, seq, _, cps, n_frames, nfps, picks, user_summary in val_loader:
            video_name = os.path.basename(test_key)
            video_path = os.path.join(ORIGINAL_VIDEOS_PATH,
                                      f"{video_name}.mp4")
            # Was a bare debug print; route through the logger instead.
            logger.info(video_path)

            video = skvideo.io.vreader(video_path)
            videometadata = skvideo.io.ffprobe(video_path)

            summarizator(video, videometadata=videometadata)
            pred_summ = summarizator.summary

            # Soccer videos have a single reference summary protocol here.
            eval_metric = 'max'
            fscore = vsumm_helper.get_summ_f1score(pred_summ, user_summary,
                                                   eval_metric)
            stats.update(fscore=fscore)

    return stats.fscore, stats.diversity
Пример #7
0
def train(args, split, save_path):
    """Train an anchor-based DSNet on one train/test split.

    Args:
        args: parsed argument namespace carrying model, optimizer, anchor
            and sampling hyper-parameters (fields referenced below).
        split: dict with 'train_keys' and 'test_keys' lists of dataset keys.
        save_path: path where the best checkpoint (by val F-score) is saved.

    Returns:
        The maximum validation F-score observed over all epochs.
    """
    model = DSNet(base_model=args.base_model,
                  num_feature=args.num_feature,
                  num_hidden=args.num_hidden,
                  anchor_scales=args.anchor_scales,
                  num_head=args.num_head)
    model = model.to(args.device)

    model.apply(xavier_init)

    # Only optimize parameters that require gradients.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(parameters,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # -1 guarantees the first epoch's score becomes the initial best.
    max_val_fscore = -1

    train_set = data_helper.VideoDataset(split['train_keys'])
    train_loader = data_helper.DataLoader(train_set, shuffle=True)

    val_set = data_helper.VideoDataset(split['test_keys'])
    val_loader = data_helper.DataLoader(val_set, shuffle=False)

    for epoch in range(args.max_epoch):
        model.train()
        stats = data_helper.AverageMeter('loss', 'cls_loss', 'loc_loss')

        for _, seq, gtscore, cps, n_frames, nfps, picks, _ in train_loader:
            # Build the binary keyshot target from ground-truth scores,
            # then downsample it to match the (sub-sampled) feature length.
            keyshot_summ = vsumm_helper.get_keyshot_summ(
                gtscore, cps, n_frames, nfps, picks)
            target = vsumm_helper.downsample_summ(keyshot_summ)

            # Skip videos whose ground truth selects no segment at all.
            if not target.any():
                continue

            target_bboxes = bbox_helper.seq2bbox(target)
            # lr2cw: presumably converts (left, right) boxes to
            # (center, width) form — per the helper name; verify in helper.
            target_bboxes = bbox_helper.lr2cw(target_bboxes)
            anchors = anchor_helper.get_anchors(target.size,
                                                args.anchor_scales)
            # Get class and location label for positive samples
            cls_label, loc_label = anchor_helper.get_pos_label(
                anchors, target_bboxes, args.pos_iou_thresh)

            # Get negative samples
            num_pos = cls_label.sum()
            cls_label_neg, _ = anchor_helper.get_pos_label(
                anchors, target_bboxes, args.neg_iou_thresh)
            # Sample a number of negatives proportional to the positives.
            cls_label_neg = anchor_helper.get_neg_label(
                cls_label_neg, int(args.neg_sample_ratio * num_pos))

            # Get incomplete samples
            cls_label_incomplete, _ = anchor_helper.get_pos_label(
                anchors, target_bboxes, args.incomplete_iou_thresh)
            # Exclude anchors already chosen as negatives from the
            # incomplete pool before sampling.
            cls_label_incomplete[cls_label_neg != 1] = 1
            cls_label_incomplete = anchor_helper.get_neg_label(
                cls_label_incomplete,
                int(args.incomplete_sample_ratio * num_pos))

            # Propagate the sampled negative/incomplete markers (-1) into
            # the final class label used by the loss.
            cls_label[cls_label_neg == -1] = -1
            cls_label[cls_label_incomplete == -1] = -1

            cls_label = torch.tensor(cls_label,
                                     dtype=torch.float32).to(args.device)
            loc_label = torch.tensor(loc_label,
                                     dtype=torch.float32).to(args.device)

            # Add a batch dimension before the forward pass.
            seq = torch.tensor(seq, dtype=torch.float32).unsqueeze(0).to(
                args.device)

            pred_cls, pred_loc = model(seq)

            loc_loss = calc_loc_loss(pred_loc, loc_label, cls_label)
            cls_loss = calc_cls_loss(pred_cls, cls_label)

            # Weighted sum of classification and localization losses.
            loss = cls_loss + args.lambda_reg * loc_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            stats.update(loss=loss.item(),
                         cls_loss=cls_loss.item(),
                         loc_loss=loc_loss.item())

        val_fscore, _ = evaluate(model, val_loader, args.nms_thresh,
                                 args.device)

        # Checkpoint only when the validation F-score improves.
        if max_val_fscore < val_fscore:
            max_val_fscore = val_fscore
            torch.save(model.state_dict(), str(save_path))

        logger.info(
            f'Epoch: {epoch}/{args.max_epoch} '
            f'Loss: {stats.cls_loss:.4f}/{stats.loc_loss:.4f}/{stats.loss:.4f} '
            f'F-score cur/max: {val_fscore:.4f}/{max_val_fscore:.4f}')

    return max_val_fscore
Пример #8
0
def train(args, split, save_path):
    """Train an anchor-free DSNetAF model on one train/test split.

    Args:
        args: parsed argument namespace with model/optimizer/loss
            hyper-parameters; args.saved_ckpt optionally names a checkpoint
            to resume from.
        split: dict with 'train_keys' and 'test_keys' lists of dataset keys.
        save_path: path where the best checkpoint (by val F-score) is saved.

    Returns:
        The maximum validation F-score observed over all epochs.
    """
    model = DSNetAF(base_model=args.base_model,
                    num_feature=args.num_feature,
                    num_hidden=args.num_hidden,
                    num_head=args.num_head)
    model = model.to(args.device)

    model.train()

    # Only optimize parameters that require gradients.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(parameters,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # -1 guarantees the first epoch's score becomes the initial best.
    max_val_fscore = -1

    train_set = data_helper.VideoDataset(split['train_keys'])
    train_loader = data_helper.DataLoader(train_set, shuffle=True)

    val_set = data_helper.VideoDataset(split['test_keys'])
    val_loader = data_helper.DataLoader(val_set, shuffle=False)

    # Optionally resume from a previously saved checkpoint.
    if args.saved_ckpt:
        state_dict = torch.load(str(args.saved_ckpt),
                                map_location=torch.device(args.device))
        model.load_state_dict(state_dict)
        # Fixed message: original lacked the space after 'from'.
        print(f'Model loaded from {args.saved_ckpt}')
    for epoch in range(args.max_epoch):
        model.train()
        stats = data_helper.AverageMeter('loss', 'cls_loss', 'loc_loss',
                                         'ctr_loss')

        for _, seq, gtscore, change_points, n_frames, nfps, picks, _ in train_loader:
            # Build the binary keyshot target from ground-truth scores,
            # then downsample it to match the feature sequence length.
            keyshot_summ = vsumm_helper.get_keyshot_summ(
                gtscore, change_points, n_frames, nfps, picks)
            target = vsumm_helper.downsample_summ(keyshot_summ)

            # Skip videos whose ground truth selects no segment at all.
            if not target.any():
                continue

            # Add a batch dimension before the forward pass.
            seq = torch.tensor(seq, dtype=torch.float32).unsqueeze(0).to(
                args.device)

            # Anchor-free targets: per-position class, offsets, centerness.
            cls_label = target
            loc_label = anchor_free_helper.get_loc_label(target)
            ctr_label = anchor_free_helper.get_ctr_label(target, loc_label)

            pred_cls, pred_loc, pred_ctr = model(seq)

            cls_label = torch.tensor(cls_label,
                                     dtype=torch.float32).to(args.device)
            loc_label = torch.tensor(loc_label,
                                     dtype=torch.float32).to(args.device)
            ctr_label = torch.tensor(ctr_label,
                                     dtype=torch.float32).to(args.device)

            cls_loss = calc_cls_loss(pred_cls, cls_label, args.cls_loss)
            loc_loss = calc_loc_loss(pred_loc, loc_label, cls_label,
                                     args.reg_loss)
            ctr_loss = calc_ctr_loss(pred_ctr, ctr_label, cls_label)

            # Weighted sum of classification, localization and centerness.
            loss = cls_loss + args.lambda_reg * loc_loss + args.lambda_ctr * ctr_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            stats.update(loss=loss.item(),
                         cls_loss=cls_loss.item(),
                         loc_loss=loc_loss.item(),
                         ctr_loss=ctr_loss.item())

        val_fscore, _ = evaluate(model, val_loader, args.nms_thresh,
                                 args.device)

        # Checkpoint only when the validation F-score improves.
        if max_val_fscore < val_fscore:
            max_val_fscore = val_fscore
            torch.save(model.state_dict(), str(save_path))

        logger.info(
            f'Epoch: {epoch}/{args.max_epoch} '
            f'Loss: {stats.cls_loss:.4f}/{stats.loc_loss:.4f}/{stats.ctr_loss:.4f}/{stats.loss:.4f} '
            f'F-score cur/max: {val_fscore:.4f}/{max_val_fscore:.4f}')

    return max_val_fscore