def make_sampler(args) -> FrameSampler:
    """Construct the frame sampler selected by ``args.sampler``.

    ``args.sampler`` chooses the sampler kind:

    * ``"full"`` -> ``FullVideoSampler()``
    * ``"clip"`` -> ``ClipSampler(args.sampler_clip_length)``
    * ``"tsn"``  -> ``TemporalSegmentSampler(args.sampler_tsn_segment_count,
      args.sampler_tsn_segment_length)``

    Only the attributes needed by the chosen sampler are read from ``args``.

    Raises:
        ValueError: if ``args.sampler`` is none of the values above.
    """
    kind = args.sampler
    if kind == "full":
        return FullVideoSampler()
    if kind == "clip":
        return ClipSampler(args.sampler_clip_length)
    if kind == "tsn":
        return TemporalSegmentSampler(
            args.sampler_tsn_segment_count, args.sampler_tsn_segment_length
        )
    raise ValueError("Expected --sampler to be one of 'full', 'clip', or 'tsn'")
def temporal_segment_sampler():
    """Build a TemporalSegmentSampler with randomly chosen parameters.

    ``segment_count`` is drawn from [1, 100] and ``snippet_length`` from
    [1, 1000], both bounds inclusive.

    Uses :func:`random.randint` rather than Hypothesis'
    ``SearchStrategy.example()``. Hypothesis documents ``example()`` as meant
    only for interactive exploration and warns
    (``NonInteractiveExampleWarning``) when it is called from test code.
    """
    import random

    segment_count = random.randint(1, 100)
    snippet_length = random.randint(1, 1000)
    return TemporalSegmentSampler(segment_count, snippet_length)
def sample(self, video_length, segment_count, snippet_length, test=False):
    """Sample frame indices from a video of ``video_length`` frames.

    Builds a TemporalSegmentSampler from ``segment_count``,
    ``snippet_length`` and the ``test`` flag. It samples once and returns the
    result converted by ``frame_idx_to_list``.
    """
    tsn = TemporalSegmentSampler(segment_count, snippet_length, test=test)
    return frame_idx_to_list(tsn.sample(video_length))
def test_segment_count_should_be_greater_than_0(self):
    """Constructing a sampler with zero segments raises ValueError."""
    with pytest.raises(ValueError):
        TemporalSegmentSampler(segment_count=0, snippet_length=1)
def test_str(self):
    """str() lists the constructor arguments, including test=True."""
    sampler = TemporalSegmentSampler(1, 5, test=True)
    expected = (
        "TemporalSegmentSampler(segment_count=1, snippet_length=5, test=True)"
    )
    assert str(sampler) == expected
def test_repr(self):
    """repr() lists the constructor arguments, including test=False."""
    sampler = TemporalSegmentSampler(1, 5, test=False)
    expected = (
        "TemporalSegmentSampler(segment_count=1, snippet_length=5, test=False)"
    )
    assert repr(sampler) == expected
def test_raises_value_error_when_sampling_from_a_video_of_0_frames(self):
    """Sampling from a video with no frames raises ValueError."""
    empty_video_length = 0
    # Construct outside the raises-block so only sample() is under test.
    sampler = TemporalSegmentSampler(segment_count=1, snippet_length=1)
    with pytest.raises(ValueError):
        sampler.sample(empty_video_length)
def instantiate(self):
    """Create a TemporalSegmentSampler from this object's attributes.

    ``self.frame_count`` is passed as ``segment_count``.
    ``self.snippet_length`` is passed as ``snippet_length``.
    ``self.test_mode`` is passed as ``test``.
    """
    # NOTE(review): frame_count feeds segment_count; presumably intentional
    # (one snippet per counted frame) — confirm against the caller/config.
    kwargs = dict(
        segment_count=self.frame_count,
        snippet_length=self.snippet_length,
        test=self.test_mode,
    )
    return TemporalSegmentSampler(**kwargs)