def __init__(
    self,
    datapipe: Iterable[Union[Cut, Tuple[Cut]]],
    max_frames: int = None,
    max_samples: int = None,
    max_duration: Seconds = None,
    max_cuts: Optional[int] = None,
    drop_last: bool = False,
) -> None:
    """Store the batching configuration for items drawn from ``datapipe``.

    :param datapipe: iterable yielding single cuts or tuples of cuts.
    :param max_frames: forwarded to ``TimeConstraint``.
    :param max_samples: forwarded to ``TimeConstraint``.
    :param max_duration: forwarded to ``TimeConstraint``.
    :param max_cuts: optional cap on the number of items in one batch.
    :param drop_last: stored as ``self.drop_last`` (presumably controls whether
        a trailing partial batch is dropped — confirm against the batching code).
    """
    # Keyword arguments for the time/size constraint tracker.
    limits = dict(
        max_duration=max_duration,
        max_frames=max_frames,
        max_samples=max_samples,
    )
    self.datapipe = datapipe
    self.max_cuts = max_cuts
    self.drop_last = drop_last
    # Queue for items that were drawn but not yet placed into a batch.
    self.reuse_cuts_buffer = deque()
    self.diagnostics = SamplingDiagnostics()
    self.time_constraint = TimeConstraint(**limits)
def __init__(
    self,
    iterator: Iterable,
    predicate: Callable[[Cut], bool],
    diagnostics: Optional[SamplingDiagnostics] = None,
) -> None:
    """Wrap ``iterator`` together with a filtering ``predicate``.

    :param iterator: the source iterable.
    :param predicate: callable applied to each cut; must be callable.
    :param diagnostics: optional ``SamplingDiagnostics``; when it is ``None``
        a new ``SamplingDiagnostics()`` is used instead (via ``ifnone``).
    :raises AssertionError: if ``predicate`` is not callable.
    """
    self.iterator, self.predicate = iterator, predicate
    self.diagnostics = ifnone(diagnostics, SamplingDiagnostics())
    # Validation happens after the attributes are stored, as in the original.
    assert callable(predicate), f"LazyFilter: 'predicate' arg must be callable (got {predicate})."
def __init__(
    self,
    cuts: Iterable[Union[Cut, Tuple[Cut]]],
    duration_bins: List[Seconds],
    max_duration: float,
    max_cuts: Optional[int] = None,
    drop_last: bool = False,
    buffer_size: int = 10000,
    strict: bool = False,
    rng: Optional[random.Random] = None,
    diagnostics: Optional[SamplingDiagnostics] = None,
) -> None:
    """Configure duration-bucketed batching.

    :param cuts: iterable of single cuts or tuples of cuts.
    :param duration_bins: sorted bucket boundaries; ``len(duration_bins) + 1``
        buckets are created.
    :param max_duration: duration budget used by the buffer-size heuristic.
    :param max_cuts: optional cap stored as ``self.max_cuts``.
    :param drop_last: stored as ``self.drop_last``.
    :param buffer_size: used by the heuristic below to warn about likely
        under-sized buffers.
    :param strict: stored as ``self.strict``.
    :param rng: random generator; a new ``random.Random()`` is created if None.
    :param diagnostics: optional ``SamplingDiagnostics``; falls back to a new
        instance via ``ifnone``.
    :raises AssertionError: if ``duration_bins`` is not in sorted order.
    """
    self.cuts = cuts
    self.duration_bins = duration_bins
    self.max_duration = max_duration
    self.max_cuts = max_cuts
    self.drop_last = drop_last
    self.buffer_size = buffer_size
    self.strict = strict
    self.diagnostics = ifnone(diagnostics, SamplingDiagnostics())
    if rng is None:
        rng = random.Random()
    self.rng = rng

    # Compare as a list so that sorted tuples and 1-D arrays are accepted too
    # (a tuple never equals a list, and an array comparison is element-wise).
    assert list(duration_bins) == sorted(duration_bins), (
        "Argument list for 'duration_bins' is expected to be in "
        f"sorted order (got: {duration_bins})."
    )

    # A heuristic diagnostic first, for finding the right settings.
    # Skipped when there are no bins: np.mean([]) emits a RuntimeWarning and
    # returns NaN, which would make the comparison meaningless.
    if len(duration_bins) > 0:
        mean_duration = np.mean(duration_bins)
        expected_buffer_duration = buffer_size * mean_duration
        expected_bucket_duration = expected_buffer_duration / (
            len(duration_bins) + 1
        )
        if expected_bucket_duration < max_duration:
            warnings.warn(
                f"Your 'buffer_size' setting of {buffer_size} might be too low to satisfy "
                f"a 'max_duration' of {max_duration} (given our best guess)."
            )

    # Init: create empty buckets (note: `num_buckets = len(duration_bins) + 1`).
    self.buckets: List[Deque[Union[Cut, Tuple[Cut]]]] = [
        deque() for _ in range(len(duration_bins) + 1)
    ]
class DurationBatcher:
    """Group items (cuts or tuples of cuts) from ``datapipe`` into batches.

    A batch grows until adding the next item would exceed the duration /
    frames / samples limits tracked by ``TimeConstraint`` or the optional
    ``max_cuts`` limit. For tuples, only the first element of each tuple
    counts towards the time constraint.
    """

    def __init__(
        self,
        datapipe: Iterable[Union[Cut, Tuple[Cut]]],
        max_frames: Optional[int] = None,
        max_samples: Optional[int] = None,
        max_duration: Optional[Seconds] = None,
        max_cuts: Optional[int] = None,
        drop_last: bool = False,
    ) -> None:
        """
        :param datapipe: iterable yielding single cuts or tuples of cuts.
        :param max_frames: forwarded to ``TimeConstraint``.
        :param max_samples: forwarded to ``TimeConstraint``.
        :param max_duration: forwarded to ``TimeConstraint``.
        :param max_cuts: optional cap on the number of items per batch.
        :param drop_last: drop a trailing partial batch at the end of the
            input, unless the time constraint reports it is close to exceeding.
        """
        self.datapipe = datapipe
        # Holds an item that was drawn but did not fit into the previous batch;
        # it is consumed first when the next batch is collected.
        self.reuse_cuts_buffer = deque()
        self.drop_last = drop_last
        self.max_cuts = max_cuts
        self.diagnostics = SamplingDiagnostics()
        self.time_constraint = TimeConstraint(
            max_duration=max_duration,
            max_frames=max_frames,
            max_samples=max_samples,
        )

    def __iter__(self) -> Generator[Union[CutSet, Tuple[CutSet]], None, None]:
        """Yield batches until ``_collect_batch`` signals exhaustion."""
        self.cuts_iter = iter(self.datapipe)
        try:
            while True:
                yield self._collect_batch()
        except StopIteration:
            # Raised explicitly by _collect_batch when there is nothing left.
            pass
        self.cuts_iter = None

    def _detuplify(
        self, cuts: List[Union[Cut, Tuple[Cut]]]
    ) -> Union[CutSet, Tuple[CutSet]]:
        """Convert a collected batch to CutSet(s) and record it as kept.

        Single cuts and 1-tuples become one ``CutSet``; n-tuples become a tuple
        of ``CutSet`` objects, one per tuple position. For tuples, diagnostics
        track only the first-position cut of every item.
        """
        if not isinstance(cuts[0], tuple):
            self.diagnostics.keep(cuts)
            return CutSet.from_cuts(cuts)
        if len(cuts[0]) == 1:
            cut_set = CutSet.from_cuts(cs[0] for cs in cuts)
            self.diagnostics.keep(cut_set)
            return cut_set
        tuple_of_cut_lists = list(zip(*cuts))
        # Record the first stream (one cut per sampled tuple), consistent with
        # the 1-tuple branch above -- not the cuts of the first tuple only.
        self.diagnostics.keep(tuple_of_cut_lists[0])
        return tuple(CutSet.from_cuts(cs) for cs in tuple_of_cut_lists)

    def _discard(self, cuts: List[Union[Cut, Tuple[Cut]]]) -> None:
        """Record items that will not be emitted in the diagnostics."""
        if cuts and isinstance(cuts[0], tuple):
            # Mirror _detuplify: track the first-position cut of every tuple.
            self.diagnostics.discard([tpl[0] for tpl in cuts])
        else:
            self.diagnostics.discard(cuts)

    def _collect_batch(self) -> Union[CutSet, Tuple[CutSet]]:
        """Draw items until the next one would violate a constraint.

        :return: a ``CutSet`` or a tuple of ``CutSet`` objects (see ``_detuplify``).
        :raises StopIteration: when the input is exhausted and there is no
            partial batch to return (or it is dropped due to ``drop_last``).
        """
        self.time_constraint.reset()
        cuts = []
        while True:
            # Check that we have not reached the end of the dataset.
            try:
                if self.reuse_cuts_buffer:
                    next_cut_or_tpl = self.reuse_cuts_buffer.popleft()
                else:
                    # If this doesn't raise (typical case), it's not the end: keep processing.
                    next_cut_or_tpl = next(self.cuts_iter)
            except StopIteration:
                # No more cuts to sample from: if we have a partial batch,
                # we may output it, unless the user requested to drop it.
                # We also check if the batch is "almost there" to override drop_last.
                if cuts and (
                    not self.drop_last or self.time_constraint.close_to_exceeding()
                ):
                    # We have a partial batch and we can return it.
                    return self._detuplify(cuts)
                # There is nothing more to return or it's discarded:
                # signal the iteration code to stop.
                self._discard(cuts)
                raise StopIteration()

            # Track the duration/frames/etc. constraints.
            # For tuples, only the first element counts towards them.
            self.time_constraint.add(
                next_cut_or_tpl[0]
                if isinstance(next_cut_or_tpl, tuple)
                else next_cut_or_tpl
            )
            next_num_cuts = len(cuts) + 1

            # Did we exceed the max_frames and max_cuts constraints?
            if not self.time_constraint.exceeded() and (
                self.max_cuts is None or next_num_cuts <= self.max_cuts
            ):
                # No - add the next cut to the batch, and keep trying.
                cuts.append(next_cut_or_tpl)
                continue

            # Yes. Do we have at least one cut in the batch?
            if cuts:
                # Yes. Return the batch, but keep the currently drawn cut for later.
                self.reuse_cuts_buffer.append(next_cut_or_tpl)
                break

            # No. We'll warn the user that the constrains might be too tight,
            # and accept the cut anyway (the loop continues; the next drawn
            # item will not fit, so this cut ends up alone in its batch).
            warnings.warn(
                "The first cut drawn in batch collection violates "
                "the max_frames, max_cuts, or max_duration constraints - "
                "we'll return it anyway. "
                "Consider increasing max_frames/max_cuts/max_duration."
            )
            cuts.append(next_cut_or_tpl)

        return self._detuplify(cuts)