예제 #1
0
 def __init__(
     self,
     datapipe: Iterable[Union[Cut, Tuple[Cut]]],
     max_frames: int = None,
     max_samples: int = None,
     max_duration: Seconds = None,
     max_cuts: Optional[int] = None,
     drop_last: bool = False,
 ) -> None:
     """Set up the batcher state.

     Args:
         datapipe: Iterable of cuts, or tuples of cuts, stored as-is.
         max_frames: Forwarded to ``TimeConstraint`` as ``max_frames``.
         max_samples: Forwarded to ``TimeConstraint`` as ``max_samples``.
         max_duration: Forwarded to ``TimeConstraint`` as ``max_duration``.
         max_cuts: Stored on the instance; ``None`` by default.
             (Presumably a cap on the number of cuts per batch -- the code
             that consumes it is not part of this snippet.)
         drop_last: Stored on the instance; ``False`` by default.
     """
     # Plain configuration copied from the arguments.
     self.datapipe = datapipe
     self.max_cuts = max_cuts
     self.drop_last = drop_last

     # Mutable working state, created empty.
     self.reuse_cuts_buffer = deque()
     self.diagnostics = SamplingDiagnostics()

     # All three time-related limits are bundled into a single constraint object.
     limits = dict(
         max_duration=max_duration,
         max_frames=max_frames,
         max_samples=max_samples,
     )
     self.time_constraint = TimeConstraint(**limits)
예제 #2
0
    def __init__(
        self,
        iterator: Iterable,
        predicate: Callable[[Cut], bool],
        diagnostics: Optional[SamplingDiagnostics] = None,
    ) -> None:
        """Store the wrapped iterable and the filtering predicate.

        Args:
            iterator: The iterable to be filtered; stored as-is.
            predicate: Callable taking a cut and returning a bool; stored as-is.
            diagnostics: Optional diagnostics object. It is resolved through
                ``ifnone`` together with a freshly created
                ``SamplingDiagnostics()`` (presumably the fresh one is used
                only when ``diagnostics`` is ``None`` -- confirm against
                ``ifnone``).

        Raises:
            AssertionError: If ``predicate`` is not callable.
        """
        self.predicate = predicate
        self.iterator = iterator
        self.diagnostics = ifnone(diagnostics, SamplingDiagnostics())

        # Fail early on a non-callable predicate instead of at first use.
        assert callable(self.predicate), (
            f"LazyFilter: 'predicate' arg must be callable (got {predicate})."
        )
예제 #3
0
    def __init__(
        self,
        cuts: Iterable[Union[Cut, Tuple[Cut]]],
        duration_bins: List[Seconds],
        max_duration: float,
        max_cuts: Optional[int] = None,
        drop_last: bool = False,
        buffer_size: int = 10000,
        strict: bool = False,
        rng: Optional[random.Random] = None,
        diagnostics: Optional[SamplingDiagnostics] = None,
    ) -> None:
        """Store the bucketing configuration and create the empty buckets.

        Args:
            cuts: Iterable of cuts, or tuples of cuts, stored as-is.
            duration_bins: Bin boundaries; must already be in ascending order.
                ``len(duration_bins) + 1`` buckets are created.
            max_duration: Stored as-is; also used by the ``buffer_size``
                sanity heuristic below.
            max_cuts: Stored as-is; ``None`` by default.
            drop_last: Stored as-is.
            buffer_size: Stored as-is; also used by the sanity heuristic.
            strict: Stored as-is.
            rng: Random number generator; a new ``random.Random()`` is
                created when ``None``.
            diagnostics: Optional diagnostics object, resolved through
                ``ifnone`` together with a fresh ``SamplingDiagnostics()``.

        Raises:
            AssertionError: If ``duration_bins`` is not equal to its sorted
                version.
        """
        self.cuts = cuts
        self.duration_bins = duration_bins
        self.max_duration = max_duration
        self.max_cuts = max_cuts
        self.drop_last = drop_last
        self.buffer_size = buffer_size
        self.strict = strict
        self.diagnostics = ifnone(diagnostics, SamplingDiagnostics())
        self.rng = random.Random() if rng is None else rng

        assert duration_bins == sorted(duration_bins), (
            f"Argument list for 'duration_bins' is expected to be in "
            f"sorted order (got: {duration_bins}).")

        # N bin boundaries split the duration axis into N + 1 buckets.
        num_buckets = len(duration_bins) + 1

        # A heuristic diagnostic first, for finding the right settings.
        # Without any bin boundary there is nothing to average: np.mean([])
        # yields NaN plus a NumPy RuntimeWarning, and the NaN comparison could
        # never trigger our own warning anyway, so the check is skipped.
        if num_buckets > 1:
            mean_duration = np.mean(duration_bins)
            expected_buffer_duration = buffer_size * mean_duration
            expected_bucket_duration = expected_buffer_duration / num_buckets
            if expected_bucket_duration < max_duration:
                warnings.warn(
                    f"Your 'buffer_size' setting of {buffer_size} might be too low to satisfy "
                    f"a 'max_duration' of {max_duration} (given our best guess).")

        # Init: create empty buckets.
        self.buckets: List[Deque[Union[Cut, Tuple[Cut]]]] = [
            deque() for _ in range(num_buckets)
        ]
예제 #4
0
class DurationBatcher:
    """Group items drawn from ``datapipe`` into batches limited by a time constraint.

    Iterating over an instance yields one batch at a time. A batch is closed
    as soon as adding the next item would make ``time_constraint.exceeded()``
    true or would push the item count above ``max_cuts``; that item is then
    kept in ``reuse_cuts_buffer`` and becomes the first item of the next batch.
    """

    def __init__(
        self,
        datapipe: Iterable[Union[Cut, Tuple[Cut]]],
        max_frames: int = None,
        max_samples: int = None,
        max_duration: Seconds = None,
        max_cuts: Optional[int] = None,
        drop_last: bool = False,
    ) -> None:
        """Set up the batcher state.

        Args:
            datapipe: Iterable of cuts, or tuples of cuts, to draw from.
            max_frames: Forwarded to ``TimeConstraint`` as ``max_frames``.
            max_samples: Forwarded to ``TimeConstraint`` as ``max_samples``.
            max_duration: Forwarded to ``TimeConstraint`` as ``max_duration``.
            max_cuts: Maximum number of items per batch; ``None`` disables
                the count limit.
            drop_last: When true, a trailing partial batch is discarded
                unless the time constraint reports it is close to exceeding.
        """
        # Plain configuration copied from the arguments.
        self.datapipe = datapipe
        self.max_cuts = max_cuts
        self.drop_last = drop_last

        # Mutable working state, created empty.
        self.reuse_cuts_buffer = deque()
        self.diagnostics = SamplingDiagnostics()
        self.time_constraint = TimeConstraint(
            max_duration=max_duration,
            max_frames=max_frames,
            max_samples=max_samples,
        )

    def __iter__(self) -> Generator[Union[CutSet, Tuple[CutSet]], None, None]:
        """Yield batches until ``_collect_batch`` signals that none are left.

        A fresh iterator over ``self.datapipe`` is kept in ``self.cuts_iter``
        while batches are produced and is cleared once the data runs out.
        """
        self.cuts_iter = iter(self.datapipe)
        try:
            while True:
                batch = self._collect_batch()
                yield batch
        except StopIteration:
            # ``_collect_batch`` raises this when there is nothing more to emit.
            pass
        self.cuts_iter = None

    def _collect_batch(self) -> Union[CutSet, Tuple[CutSet]]:
        """Assemble and return the next batch.

        Returns:
            A ``CutSet`` when the items are single cuts or 1-tuples, otherwise
            a tuple with one ``CutSet`` per tuple position.

        Raises:
            StopIteration: When the source is exhausted and there is no
                partial batch to return (either it is empty or it is dropped).
        """

        def detuplify(
            cuts: List[Union[Cut, Tuple[Cut]]]
        ) -> Union[CutSet, Tuple[CutSet]]:
            """Build the batch object for either single cuts or tuples of cuts."""
            head = cuts[0]

            # Plain cuts: the list itself is the batch.
            if not isinstance(head, tuple):
                self.diagnostics.keep(cuts)
                return CutSet.from_cuts(cuts)

            # 1-tuples: unwrap each tuple and behave like the plain-cut case.
            if len(head) == 1:
                unwrapped = CutSet.from_cuts(cs[0] for cs in cuts)
                self.diagnostics.keep(unwrapped)
                return unwrapped

            # Longer tuples: transpose into one sequence per tuple position.
            per_position = list(zip(*cuts))
            self.diagnostics.keep(head)
            return tuple([CutSet.from_cuts(group) for group in per_position])

        self.time_constraint.reset()
        batch = []
        while True:
            # Leftovers from the previous batch take priority over fresh items.
            try:
                candidate = (self.reuse_cuts_buffer.popleft()
                             if self.reuse_cuts_buffer else next(self.cuts_iter))
            except StopIteration:
                # The source is exhausted. A partial batch is returned unless
                # ``drop_last`` is set; even then it is returned when the
                # time constraint says it is close to being exceeded.
                if batch and (not self.drop_last
                              or self.time_constraint.close_to_exceeding()):
                    return detuplify(batch)

                # Nothing to return, or the partial batch is dropped:
                # record the discard and signal the iteration code to stop.
                try:
                    self.diagnostics.discard(batch)
                except AttributeError:  # accounts for cuts being a tuple
                    self.diagnostics.discard(batch[0])
                raise StopIteration()

            # For tuples, only the first element counts toward the constraint.
            measured = candidate[0] if isinstance(candidate,
                                                  tuple) else candidate
            self.time_constraint.add(measured)

            within_limits = not self.time_constraint.exceeded() and (
                self.max_cuts is None or len(batch) + 1 <= self.max_cuts)
            if within_limits:
                # The candidate fits: take it and keep drawing.
                batch.append(candidate)
                continue

            if batch:
                # The batch is full: save the candidate for the next batch.
                self.reuse_cuts_buffer.append(candidate)
                return detuplify(batch)

            # Even the very first item violates the limits. Warn that the
            # constraints might be too tight and take the item anyway.
            warnings.warn(
                "The first cut drawn in batch collection violates "
                "the max_frames, max_cuts, or max_duration constraints - "
                "we'll return it anyway. "
                "Consider increasing max_frames/max_cuts/max_duration."
            )
            batch.append(candidate)