Exemplo n.º 1
0
    def test_min_score(self):
        """Frames scoring below ``min_score`` are left out of the result set."""
        origin = Frame(
            frame_id=0,
            timestamp=0,
            ball_coordinates=Point(x=0, y=0),
            home_players_coordinates={},
            away_players_coordinates={}
        )

        # Ball moved to (0, 68): this frame will score less than 90.
        far_frame = replace(origin, frame_id=1, ball_coordinates=Point(x=0, y=68))
        # Same content as the query, only the id differs.
        identical_frame = replace(origin, frame_id=3)

        resultset = SearchEngine.search(
            TrackingDataset(dataset_id="test", frames=[far_frame, identical_frame]),
            matcher=MunkresMatcher(origin),
            min_score=90,
        )

        expected = [Result(frame_id=3, score=100)]
        assert resultset.results == expected
Exemplo n.º 2
0
    def test_single_match(self):
        """A dataset holding a copy of the query frame yields one result scored 100."""
        query = Frame(frame_id=0,
                      timestamp=0,
                      ball_coordinates=Point(x=0, y=0),
                      home_players_coordinates={},
                      away_players_coordinates={})

        candidates = [replace(query, frame_id=1)]
        dataset = TrackingDataset(dataset_id="test", frames=candidates)

        matcher = MunkresMatcher(query)
        resultset = SearchEngine.search(dataset, matcher=matcher)

        assert resultset.results == [Result(frame_id=1, score=100)]
Exemplo n.º 3
0
    def test_unknown_frame(self):
        """Searching by a frame id that the stored dataset lacks raises IndexError."""
        service = self._init_search_service()
        empty_dataset = TrackingDataset(dataset_id="test", frames=[])
        service.repository.save(empty_dataset)

        with pytest.raises(IndexError):
            service.search_by_frame("test", 1)
Exemplo n.º 4
0
    def test_same_frame(self):
        """Searching by the only frame in a dataset yields no results."""
        repository = MemoryRepository()
        only_frame = Frame(frame_id=1,
                           timestamp=0,
                           ball_coordinates=Point(x=0, y=0),
                           home_players_coordinates={},
                           away_players_coordinates={})
        repository.save(TrackingDataset(dataset_id="test", frames=[only_frame]))

        service = self._init_search_service(repository=repository)
        outcome = service.search_by_frame("test", 1)

        assert not len(outcome.results)
Exemplo n.º 5
0
    def test_single_frame(self):
        """Searching by frame 1 returns frame 2 (ball shifted by 1) with score 99.5."""
        service = self._init_search_service()

        # (frame_id, timestamp, ball x); the ball's y is 0 in both frames.
        specs = ((1, 0, 0), (2, 0.1, 1))
        frames = [
            Frame(frame_id=frame_id,
                  timestamp=timestamp,
                  ball_coordinates=Point(x=ball_x, y=0),
                  home_players_coordinates={},
                  away_players_coordinates={})
            for frame_id, timestamp, ball_x in specs
        ]
        service.repository.save(TrackingDataset(dataset_id="test", frames=frames))

        outcome = service.search_by_frame("test", 1)
        assert outcome.results == [Result(frame_id=2, score=99.5)]
Exemplo n.º 6
0
    def test_empty_resultset(self):
        """Searching a dataset with no frames returns no results."""
        query = Frame(frame_id=0,
                      timestamp=0,
                      ball_coordinates=Point(x=0, y=0),
                      home_players_coordinates={},
                      away_players_coordinates={})
        empty_dataset = TrackingDataset(dataset_id="test", frames=[])

        outcome = SearchEngine.search(empty_dataset, matcher=MunkresMatcher(query))

        assert not len(outcome.results)
Exemplo n.º 7
0
    def parse(self,
              home_data: str,
              away_data: str,
              sample_rate=1 / 25,
              **kwargs) -> TrackingDataset:
        """Build a TrackingDataset from paired home/away CSV text.

        Both inputs are split into lines and walked in lockstep. Extra lines
        in the longer input are ignored, because ``zip`` stops at the shorter
        one. Lines 0 and 2 are skipped. Line 1 supplies the jersey numbers:
        every other player column, starting with the first. Every later line
        has the layout ``period, frame_id, time, <player column pairs>...,
        ball_x, ball_y`` and may become a Frame.

        :param home_data: CSV text for the home team.
        :param away_data: CSV text for the away team. It must agree with
            ``home_data`` on frame id and ball position, line by line.
        :param sample_rate: fraction of data lines to keep. Every
            ``round(1 / sample_rate)``-th data line (and at least every
            line), counting from line 3, is turned into a Frame.
        :param kwargs: passed through unchanged to ``TrackingDataset``.
        :returns: TrackingDataset holding the kept frames in input order.
        :raises ValueError: when the two inputs disagree on a line's frame id
            or ball position. This is checked for every data line, whether or
            not the line is kept.
        """
        frames = []

        home_jersey_numbers = []
        away_jersey_numbers = []

        # Integer stride between kept data lines. A float modulo by
        # `1 / sample_rate` only works when that value is exactly an integer
        # in floating point; rounding makes the stride robust. max(1, ...)
        # prevents a zero stride (modulo by zero) when sample_rate > 1.
        step = max(1, round(1 / sample_rate))

        def players_coordinates(jersey_numbers, players):
            # Player columns come in consecutive pairs. Pair k belongs to
            # jersey_numbers[k]. Pairs whose first value is 'NaN' are omitted.
            return {
                jersey_numbers[i // 2]:
                self._create_point(players[i], players[i + 1])
                for i in range(0, len(players), 2)
                if players[i] != 'NaN'
            }

        for line_idx, (home_line, away_line) in enumerate(
                zip(home_data.splitlines(keepends=False),
                    away_data.splitlines(keepends=False))):
            if line_idx in (0, 2):
                continue

            (_, home_frame_id, home_time, *home_players,
             home_ball_x, home_ball_y) = home_line.split(",")
            (_, away_frame_id, _, *away_players,
             away_ball_x, away_ball_y) = away_line.split(",")

            if line_idx == 1:
                home_jersey_numbers = [
                    int(number) for number in home_players[::2]
                ]
                away_jersey_numbers = [
                    int(number) for number in away_players[::2]
                ]
                continue

            # ValueError is an Exception subclass, so callers that catch
            # Exception keep working.
            if home_frame_id != away_frame_id:
                raise ValueError(
                    f"Input file mismatch (frame_id): {home_frame_id} != {away_frame_id}"
                )

            if home_ball_x != away_ball_x or home_ball_y != away_ball_y:
                raise ValueError(
                    f"Input file mismatch (ball): ({home_ball_x}, {home_ball_y}) != ({away_ball_x}, {away_ball_y})"
                )

            if (line_idx - 3) % step != 0:
                continue

            # Home and away ball values are equal at this point (checked
            # above), so the home side alone is enough.
            if 'NaN' in (home_ball_x, home_ball_y):
                continue

            frames.append(Frame(
                frame_id=int(home_frame_id),
                timestamp=float(home_time),
                home_players_coordinates=players_coordinates(
                    home_jersey_numbers, home_players),
                away_players_coordinates=players_coordinates(
                    away_jersey_numbers, away_players),
                ball_coordinates=self._create_point(home_ball_x, home_ball_y)))

        return TrackingDataset(frames=frames, **kwargs)