def is_ready(bucket: Deque[Cut]):
    """
    Tell whether the cuts collected in ``bucket`` are close to filling a batch.

    A fresh ``TimeConstraint`` limited to ``self.max_duration`` is fed the
    bucket's cuts one at a time; the first time it reports being close to
    exceeding its limit, the bucket is considered ready.
    ``self`` is not a parameter here; presumably it is captured from an
    enclosing method (TODO confirm).

    :param bucket: cuts, or tuples of cuts, waiting to be batched.
    :return: ``True`` as soon as the accumulated cuts get close to the limit,
        ``False`` if the whole bucket stays comfortably below it.
    """
    budget = TimeConstraint(max_duration=self.max_duration)
    for entry in bucket:
        # Only the first cut of a tuple is measured against the limit.
        first_cut = entry[0] if isinstance(entry, tuple) else entry
        budget.add(first_cut)
        if budget.close_to_exceeding():
            return True
    return False
def __init__(
    self,
    source_cuts: CutSet,
    target_cuts: CutSet,
    max_source_frames: Optional[int] = None,
    max_source_samples: Optional[int] = None,
    max_source_duration: Optional[Seconds] = None,
    max_target_frames: Optional[int] = None,
    max_target_samples: Optional[int] = None,
    max_target_duration: Optional[Seconds] = None,
    max_cuts: Optional[int] = None,
    shuffle: bool = False,
    drop_last: bool = False,
    world_size: Optional[int] = None,
    rank: Optional[int] = None,
    seed: int = 0,
):
    """
    CutPairsSampler's constructor.

    :param source_cuts: the first ``CutSet`` to sample data from.
    :param target_cuts: the second ``CutSet`` to sample data from.
    :param max_source_frames: The maximum total number of feature frames from ``source_cuts``.
    :param max_source_samples: The maximum total number of audio samples from ``source_cuts``.
    :param max_source_duration: The maximum total recording duration from ``source_cuts``.
    :param max_target_frames: The maximum total number of feature frames from ``target_cuts``.
    :param max_target_samples: The maximum total number of audio samples from ``target_cuts``.
    :param max_target_duration: The maximum total recording duration from ``target_cuts``.
    :param max_cuts: The maximum number of cuts sampled to form a mini-batch.
        By default, this constraint is off.
    :param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
        Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.:
        `for epoch in range(10): for batch in dataset: ...` as every epoch will see a
        different cuts order.
    :param drop_last: When ``True``, the last batch is dropped if it's incomplete.
    :param world_size: Total number of distributed nodes. We will try to infer it by default.
    :param rank: Index of distributed node. We will try to infer it by default.
    :param seed: Random seed used to consistently shuffle the dataset across different processes.
    """
    super().__init__(
        shuffle=shuffle,
        world_size=world_size,
        rank=rank,
        seed=seed,
    )
    self.source_cuts = DataSource(source_cuts)
    self.target_cuts = DataSource(target_cuts)
    # Constraints: one set of frame/sample/duration limits per side of each pair.
    # The cut-count limit is not part of these; it is stored separately below.
    self.source_constraints = TimeConstraint(
        max_duration=max_source_duration,
        max_samples=max_source_samples,
        max_frames=max_source_frames,
    )
    self.target_constraints = TimeConstraint(
        max_duration=max_target_duration,
        max_samples=max_target_samples,
        max_frames=max_target_frames,
    )
    self.max_cuts = max_cuts
    self.drop_last = drop_last
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """
    Restore the state of the sampler that is described in a state_dict.
    This will result in the sampler yielding batches from where the previous training left it off.

    .. caution::
        The samplers are expected to be initialized with the same CutSets,
        but this is not explicitly checked anywhere.

    .. caution::
        The input ``state_dict`` is being mutated: we remove each consumed key, and expect
        it to be empty at the end of loading. If you don't want this behavior, pass a copy
        inside of this function (e.g., using ``import deepcopy``).

    .. note::
        For implementers of sub-classes of CutSampler: the flag ``self._just_restored_state`` has to be
        handled in ``__iter__`` to make it avoid resetting the just-restored state (only once).

    :param state_dict: state previously produced by ``state_dict()``; consumed keys are popped.
    """
    self.drop_last = state_dict.pop("drop_last")

    # Source-side limits: warn on mismatch, then adopt the received ones.
    source_constraints = TimeConstraint(**state_dict.pop("source_constraints"))
    if self.source_constraints != source_constraints:
        warnings.warn(
            "CutPairsSampler.load_state_dict(): Inconsistent source_constraint:\n"
            f"expected {self.source_constraints}\n"
            f"received {source_constraints}\n"
            f"We will overwrite the settings with the received state_dict."
        )
    self.source_constraints = source_constraints

    # Target-side limits: compare against the current *target* constraints.
    target_constraints = TimeConstraint(**state_dict.pop("target_constraints"))
    if self.target_constraints != target_constraints:
        warnings.warn(
            "CutPairsSampler.load_state_dict(): Inconsistent target_constraint:\n"
            f"expected {self.target_constraints}\n"
            f"received {target_constraints}\n"
            f"We will overwrite the settings with the received state_dict."
        )
    self.target_constraints = target_constraints

    # max_cuts is only compared; the value configured on this sampler is kept.
    max_cuts = state_dict.pop("max_cuts")
    if self.max_cuts != max_cuts:
        warnings.warn(
            "CutPairsSampler.load_state_dict(): Inconsistent max_cuts:\n"
            f"expected {self.max_cuts}\n"
            f"received {max_cuts}\n"
            f"We will ignore the received settings."
        )

    super().load_state_dict(state_dict)

    # Restore the data source's state: re-shuffle with the same seed/epoch,
    # then skip the cuts that were already sampled.
    if self.shuffle:
        self.source_cuts.shuffle(self.seed + self.epoch)
        self.target_cuts.shuffle(self.seed + self.epoch)
    self.source_cuts.fast_forward(self.diagnostics.total_cuts)
    self.target_cuts.fast_forward(self.diagnostics.total_cuts)
def __init__(
    self,
    cuts: CutSet,
    max_frames: int = None,
    max_samples: int = None,
    max_duration: Seconds = None,
    max_cuts: Optional[int] = None,
    shuffle: bool = False,
    drop_last: bool = False,
    strict: bool = False,
    world_size: Optional[int] = None,
    rank: Optional[int] = None,
    seed: int = 0,
):
    """
    Build a SimpleCutSampler over a single ``CutSet``.

    :param cuts: the ``CutSet`` that batches are drawn from.
    :param max_frames: upper bound on the summed feature frames of a mini-batch.
    :param max_samples: upper bound on the summed audio samples of a mini-batch.
    :param max_duration: upper bound on the summed recording duration of a mini-batch.
    :param max_cuts: upper bound on the number of cuts in a mini-batch; unlimited when ``None``.
    :param shuffle: when ``True``, the cut order is shuffled at the start of every
        iteration, so an outer epoch loop
        (`for epoch in range(10): for batch in dataset: ...`) sees a new order each epoch.
    :param drop_last: when ``True``, an incomplete final batch is dropped.
    :param strict: when ``True``, the dynamic batch size is determined by multiplying the
        longest cut seen so far (duration/num_frames/num_samples) by the current number of
        cuts in the mini-batch and checking that product against max_duration/etc.
        This makes GPU memory usage more predictable when cut durations vary a lot.
    :param world_size: total number of distributed nodes; inferred when not given.
    :param rank: index of this distributed node; inferred when not given.
    :param seed: random seed that keeps shuffling consistent across processes.
    """
    super().__init__(
        shuffle=shuffle,
        world_size=world_size,
        rank=rank,
        seed=seed,
    )
    self.data_source = DataSource(cuts)
    # Every batch-size limit (including the cut count and strictness) is
    # handled by one constraint object.
    limits = dict(
        max_frames=max_frames,
        max_samples=max_samples,
        max_duration=max_duration,
        max_cuts=max_cuts,
        strict=strict,
    )
    self.time_constraint = TimeConstraint(**limits)
    self.drop_last = drop_last
def __init__(
    self,
    datapipe: Iterable[Union[Cut, Tuple[Cut]]],
    max_frames: int = None,
    max_samples: int = None,
    max_duration: Seconds = None,
    max_cuts: Optional[int] = None,
    drop_last: bool = False,
) -> None:
    """
    Set up a batcher that groups the items of ``datapipe`` into batches.

    :param datapipe: iterable yielding single cuts or tuples of cuts.
    :param max_frames: limit on the summed feature frames of a batch.
    :param max_samples: limit on the summed audio samples of a batch.
    :param max_duration: limit on the summed duration of a batch.
    :param max_cuts: limit on the number of items in a batch; kept as a separate
        attribute rather than being passed to the time constraint.
    :param drop_last: when ``True``, an incomplete final batch is dropped.
    """
    self.datapipe = datapipe
    self.drop_last = drop_last
    self.max_cuts = max_cuts
    # Presumably holds items drawn from the datapipe but set aside for a later batch.
    self.reuse_cuts_buffer = deque()
    self.diagnostics = SamplingDiagnostics()
    self.time_constraint = TimeConstraint(
        max_frames=max_frames,
        max_samples=max_samples,
        max_duration=max_duration,
    )
class CutPairsSampler(CutSampler):
    """
    Samples pairs of cuts from a "source" and "target" CutSet.
    It expects that both CutSet's strictly consist of Cuts with corresponding IDs
    (the IDs of every sampled pair are asserted to match).
    Each mini-batch is built by :meth:`_next_batch` as a ``(source, target)`` pair of ``CutSet`` objects.

    When one of the ``max_source_*`` / ``max_target_*`` frame, sample, or duration limits
    is specified, the batch size is dynamic. Exactly zero or one of those constraints can be specified.
    Padding required to collate the batch does not contribute to max frames/samples/duration.
    """

    def __init__(
        self,
        source_cuts: CutSet,
        target_cuts: CutSet,
        max_source_frames: Optional[int] = None,
        max_source_samples: Optional[int] = None,
        max_source_duration: Optional[Seconds] = None,
        max_target_frames: Optional[int] = None,
        max_target_samples: Optional[int] = None,
        max_target_duration: Optional[Seconds] = None,
        max_cuts: Optional[int] = None,
        shuffle: bool = False,
        drop_last: bool = False,
        strict: bool = False,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
    ):
        """
        CutPairsSampler's constructor.

        :param source_cuts: the first ``CutSet`` to sample data from.
        :param target_cuts: the second ``CutSet`` to sample data from.
        :param max_source_frames: The maximum total number of feature frames from ``source_cuts``.
        :param max_source_samples: The maximum total number of audio samples from ``source_cuts``.
        :param max_source_duration: The maximum total recording duration from ``source_cuts``.
        :param max_target_frames: The maximum total number of feature frames from ``target_cuts``.
        :param max_target_samples: The maximum total number of audio samples from ``target_cuts``.
        :param max_target_duration: The maximum total recording duration from ``target_cuts``.
        :param max_cuts: The maximum number of cuts sampled to form a mini-batch.
            By default, this constraint is off.
        :param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
            Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.:
            `for epoch in range(10): for batch in dataset: ...` as every epoch will see a
            different cuts order.
        :param drop_last: When ``True``, the last batch is dropped if it's incomplete.
        :param strict: When ``True``, for the purposes of determining dynamic batch size,
            we take the longest cut sampled so far and multiply its duration/num_frames/num_samples
            by the number of cuts currently in mini-batch to check if it exceeded max_duration/etc.
            This can help make the GPU memory usage more predictable when there is a large variance
            in cuts duration.
        :param world_size: Total number of distributed nodes. We will try to infer it by default.
        :param rank: Index of distributed node. We will try to infer it by default.
        :param seed: Random seed used to consistently shuffle the dataset across different processes.
        """
        super().__init__(
            shuffle=shuffle,
            world_size=world_size,
            rank=rank,
            seed=seed,
        )
        self.source_cuts = DataSource(source_cuts)
        self.target_cuts = DataSource(target_cuts)
        # Constraints: each side of a pair gets its own limits; max_cuts and
        # strict are shared by both sides.
        self.source_constraints = TimeConstraint(
            max_duration=max_source_duration,
            max_samples=max_source_samples,
            max_frames=max_source_frames,
            max_cuts=max_cuts,
            strict=strict,
        )
        self.target_constraints = TimeConstraint(
            max_duration=max_target_duration,
            max_samples=max_target_samples,
            max_frames=max_target_frames,
            max_cuts=max_cuts,
            strict=strict,
        )
        self.drop_last = drop_last

    @property
    def remaining_duration(self) -> Optional[float]:
        """
        Remaining duration of data left in the sampler (may be inexact due to float arithmetic).
        Not available when the CutSet is read in lazy mode (returns None).

        .. note: For :class:`.CutPairsSampler` we return the source cuts duration.
        """
        return self.source_cuts.remaining_duration

    @property
    def remaining_cuts(self) -> Optional[int]:
        """
        Remaining number of cuts in the sampler.
        Not available when the CutSet is read in lazy mode (returns None).
        """
        return self.source_cuts.remaining_cuts

    @property
    def num_cuts(self) -> Optional[int]:
        """
        Total number of cuts in the sampler.
        Not available when the CutSet is read in lazy mode (returns None).
        """
        if self.source_cuts.is_lazy:
            return None
        return len(self.source_cuts)

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the current state of the sampler in a state_dict.
        Together with ``load_state_dict()``, this can be used to restore the
        training loop's state to the one stored in the state_dict.
        """
        state_dict = super().state_dict()
        state_dict.update(
            {
                "drop_last": self.drop_last,
                "source_constraints": self.source_constraints.state_dict(),
                "target_constraints": self.target_constraints.state_dict(),
            }
        )
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restore the state of the sampler that is described in a state_dict.
        This will result in the sampler yielding batches from where the previous training left it off.

        .. caution::
            The samplers are expected to be initialized with the same CutSets,
            but this is not explicitly checked anywhere.

        .. caution::
            The input ``state_dict`` is being mutated: we remove each consumed key, and expect
            it to be empty at the end of loading. If you don't want this behavior, pass a copy
            inside of this function (e.g., using ``import deepcopy``).

        .. note::
            For implementers of sub-classes of CutSampler: the flag ``self._just_restored_state`` has to be
            handled in ``__iter__`` to make it avoid resetting the just-restored state (only once).
        """
        self.drop_last = state_dict.pop("drop_last")

        source_constraints = TimeConstraint(**state_dict.pop("source_constraints"))
        if self.source_constraints != source_constraints:
            warnings.warn(
                "CutPairsSampler.load_state_dict(): Inconsistent source_constraint:\n"
                f"expected {self.source_constraints}\n"
                f"received {source_constraints}\n"
                f"We will overwrite the settings with the received state_dict."
            )
        self.source_constraints = source_constraints

        target_constraints = TimeConstraint(**state_dict.pop("target_constraints"))
        # Compare the current *target* constraints with the received target constraints.
        if self.target_constraints != target_constraints:
            warnings.warn(
                "CutPairsSampler.load_state_dict(): Inconsistent target_constraint:\n"
                f"expected {self.target_constraints}\n"
                f"received {target_constraints}\n"
                f"We will overwrite the settings with the received state_dict."
            )
        self.target_constraints = target_constraints

        super().load_state_dict(state_dict)

        # Restore the data source's state: re-shuffle deterministically, then skip
        # the cuts already consumed in the current epoch.
        if self.shuffle:
            self.source_cuts.shuffle(self.seed + self.epoch)
            self.target_cuts.shuffle(self.seed + self.epoch)
        self.source_cuts.fast_forward(self.diagnostics.current_epoch_stats.total_cuts)
        self.target_cuts.fast_forward(self.diagnostics.current_epoch_stats.total_cuts)

    def __iter__(self) -> "CutPairsSampler":
        """
        Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
        """
        # Restored state with load_state_dict()? Skip resetting only this once.
        if self._just_restored_state:
            return self
        # Reset the state to the beginning of the epoch.
        if self.shuffle:
            self.source_cuts.shuffle(self.seed + self.epoch)
            self.target_cuts.shuffle(self.seed + self.epoch)
        iter(self.source_cuts)
        iter(self.target_cuts)
        return self

    def _next_batch(self) -> Tuple[CutSet, CutSet]:
        """
        Draw the next mini-batch of cut pairs.

        Pairs are pulled until adding another one would exceed either side's constraints;
        the pair that overflowed is handed back to both data sources for the next batch.

        :return: a ``(source, target)`` tuple of ``CutSet`` objects of equal length.
        :raises StopIteration: when the data is exhausted and there is no batch to return
            (or the final partial batch is dropped because of ``drop_last``).
        """
        # Note: no actual data is loaded into memory yet because the manifests contain all the metadata
        # required to do this operation.
        self.source_constraints.reset()
        self.target_constraints.reset()
        source_cuts = []
        target_cuts = []
        while True:
            # Check that we have not reached the end of the dataset.
            try:
                # We didn't - grab the next cut
                next_source_cut = next(self.source_cuts)
                next_target_cut = next(self.target_cuts)
                assert next_source_cut.id == next_target_cut.id, (
                    "Sampled source and target cuts with differing IDs. "
                    "Ensure that your source and target cuts have the same length, "
                    "the same IDs, and the same order."
                )
            except StopIteration:
                # No more cuts to sample from: if we have a partial batch,
                # we may output it, unless the user requested to drop it.
                # We also check if the batch is "almost there" to override drop_last.
                if source_cuts and (
                    not self.drop_last
                    or self.source_constraints.close_to_exceeding()
                    or self.target_constraints.close_to_exceeding()
                ):
                    # We have a partial batch and we can return it.
                    assert len(source_cuts) == len(
                        target_cuts
                    ), "Unexpected state: some cuts in source / target are missing their counterparts..."
                    self.diagnostics.keep(source_cuts)
                    return CutSet.from_cuts(source_cuts), CutSet.from_cuts(target_cuts)
                else:
                    # There is nothing more to return or it's discarded:
                    # signal the iteration code to stop.
                    self.diagnostics.discard(source_cuts)
                    raise StopIteration()

            # Check whether the cuts we're about to sample satisfy optional user-requested predicate.
            if not self._filter_fn(next_source_cut) or not self._filter_fn(next_target_cut):
                # No - try another one.
                self.diagnostics.discard_single(next_source_cut)
                continue

            self.source_constraints.add(next_source_cut)
            self.target_constraints.add(next_target_cut)

            # Did adding this pair exceed either side's constraints?
            if (
                not self.source_constraints.exceeded()
                and not self.target_constraints.exceeded()
            ):
                # No - add the next cut to the batch, and keep trying.
                source_cuts.append(next_source_cut)
                target_cuts.append(next_target_cut)
            else:
                # Yes. Do we have at least one cut in the batch?
                if source_cuts:
                    # Yes. Return it, keeping the overflowing pair for the next batch.
                    self.source_cuts.take_back(next_source_cut)
                    self.target_cuts.take_back(next_target_cut)
                    break
                else:
                    # No. We'll warn the user that the constrains might be too tight,
                    # and return the cut anyway.
                    warnings.warn(
                        "The first cut drawn in batch collection violates one of the max_... constraints - "
                        "we'll return it anyway. Consider increasing max_source_frames/max_cuts/etc."
                    )
                    source_cuts.append(next_source_cut)
                    target_cuts.append(next_target_cut)

        assert len(source_cuts) == len(
            target_cuts
        ), "Unexpected state: some cuts in source / target are missing their counterparts..."
        self.diagnostics.keep(source_cuts)
        return CutSet.from_cuts(source_cuts), CutSet.from_cuts(target_cuts)
class DurationBatcher:
    """
    Groups the items of ``datapipe`` (single cuts or tuples of cuts) into batches
    limited by total frames/samples/duration and, optionally, by item count.

    When items are tuples, only the first cut of each tuple is measured against the
    limits and recorded in the sampling diagnostics.
    """

    def __init__(
        self,
        datapipe: Iterable[Union[Cut, Tuple[Cut]]],
        max_frames: int = None,
        max_samples: int = None,
        max_duration: Seconds = None,
        max_cuts: Optional[int] = None,
        drop_last: bool = False,
    ) -> None:
        """
        :param datapipe: iterable yielding single cuts or tuples of cuts.
        :param max_frames: limit on the summed feature frames of a batch.
        :param max_samples: limit on the summed audio samples of a batch.
        :param max_duration: limit on the summed duration of a batch.
        :param max_cuts: limit on the number of items in a batch (checked separately
            from the time constraint).
        :param drop_last: when ``True``, an incomplete final batch is dropped.
        """
        self.datapipe = datapipe
        # Items drawn from the datapipe that overflowed a batch; they start the next one.
        self.reuse_cuts_buffer = deque()
        self.drop_last = drop_last
        self.max_cuts = max_cuts
        self.diagnostics = SamplingDiagnostics()
        self.time_constraint = TimeConstraint(
            max_duration=max_duration, max_frames=max_frames, max_samples=max_samples
        )

    def __iter__(self) -> Generator[Union[CutSet, Tuple[CutSet]], None, None]:
        """Yield batches until the datapipe (and the reuse buffer) is exhausted."""
        self.cuts_iter = iter(self.datapipe)
        try:
            while True:
                yield self._collect_batch()
        except StopIteration:
            # _collect_batch signals the end of data with StopIteration; catching it here
            # keeps it from escaping the generator (which PEP 479 turns into RuntimeError).
            pass
        self.cuts_iter = None

    @staticmethod
    def _primary_cuts(cuts: List[Union[Cut, Tuple[Cut]]]) -> List[Cut]:
        """Return each item itself, or the first cut of each tuple item."""
        return [c[0] if isinstance(c, tuple) else c for c in cuts]

    def _collect_batch(self) -> Union[CutSet, Tuple[CutSet]]:
        """
        Collect one batch from the reuse buffer and the datapipe iterator.

        :return: a ``CutSet`` for single cuts or 1-tuples, otherwise a tuple of ``CutSet``
            objects (one per tuple position).
        :raises StopIteration: when there is nothing left to return, or the final
            partial batch is dropped because of ``drop_last``.
        """

        def detuplify(
            cuts: List[Union[Cut, Tuple[Cut]]]
        ) -> Union[CutSet, Tuple[CutSet]]:
            """Helper to do the right thing whether we sampled single cuts or cut tuples."""
            if isinstance(cuts[0], tuple):
                if len(cuts[0]) == 1:
                    cuts = CutSet.from_cuts(cs[0] for cs in cuts)
                    self.diagnostics.keep(cuts)
                    return cuts
                else:
                    tuple_of_cut_lists = list(zip(*cuts))
                    # Record every first-position cut of the batch, not just the first tuple.
                    self.diagnostics.keep(self._primary_cuts(cuts))
                    return tuple([CutSet.from_cuts(cs) for cs in tuple_of_cut_lists])
            else:
                self.diagnostics.keep(cuts)
                return CutSet.from_cuts(cuts)

        self.time_constraint.reset()
        cuts = []
        while True:
            # Check that we have not reached the end of the dataset.
            try:
                if self.reuse_cuts_buffer:
                    next_cut_or_tpl = self.reuse_cuts_buffer.popleft()
                else:
                    # If this doesn't raise (typical case), it's not the end: keep processing.
                    next_cut_or_tpl = next(self.cuts_iter)
            except StopIteration:
                # No more cuts to sample from: if we have a partial batch,
                # we may output it, unless the user requested to drop it.
                # We also check if the batch is "almost there" to override drop_last.
                if cuts and (
                    not self.drop_last or self.time_constraint.close_to_exceeding()
                ):
                    # We have a partial batch and we can return it.
                    return detuplify(cuts)
                else:
                    # There is nothing more to return or it's discarded:
                    # signal the iteration code to stop. Tuples are reduced to their
                    # first cut, the same cut the time constraint measures.
                    self.diagnostics.discard(self._primary_cuts(cuts))
                    raise StopIteration()

            # Track the duration/frames/etc. constraints.
            self.time_constraint.add(
                next_cut_or_tpl[0] if isinstance(next_cut_or_tpl, tuple) else next_cut_or_tpl
            )
            next_num_cuts = len(cuts) + 1

            # Did we exceed the max_frames and max_cuts constraints?
            if not self.time_constraint.exceeded() and (
                self.max_cuts is None or next_num_cuts <= self.max_cuts
            ):
                # No - add the next cut to the batch, and keep trying.
                cuts.append(next_cut_or_tpl)
            else:
                # Yes. Do we have at least one cut in the batch?
                if cuts:
                    # Yes. Return the batch, but keep the currently drawn cut for later.
                    self.reuse_cuts_buffer.append(next_cut_or_tpl)
                    break
                else:
                    # No. We'll warn the user that the constrains might be too tight,
                    # and return the cut anyway.
                    warnings.warn(
                        "The first cut drawn in batch collection violates "
                        "the max_frames, max_cuts, or max_duration constraints - "
                        "we'll return it anyway. "
                        "Consider increasing max_frames/max_cuts/max_duration."
                    )
                    cuts.append(next_cut_or_tpl)

        return detuplify(cuts)
def test_time_constraint_strictness():
    # Two constraints with the same 100s budget; only the strictness flag differs.
    normal = TimeConstraint(max_duration=100, strict=False)
    strict = TimeConstraint(max_duration=100, strict=True)
    both = (normal, strict)

    # Cut durations with a large spread that sum to exactly the budget.
    durations = [30.0, 30.0, 10.0, 10.0, 20.0]
    assert sum(durations) == pytest.approx(100.0)
    cuts = [dummy_cut(i, duration=d) for i, d in enumerate(durations)]
    *leading, final = cuts

    # Feed everything except the last cut: 80s accumulated.
    for cut in leading:
        for constraint in both:
            constraint.add(cut)
    for constraint in both:
        assert constraint.current == pytest.approx(80)

    # The non-strict constraint would still accept another cut in the batch;
    # the strict one reports it is close to exceeding (it would not).
    assert not normal.close_to_exceeding()
    assert strict.close_to_exceeding()

    for constraint in both:
        constraint.add(final)
    assert not normal.exceeded()
    assert strict.exceeded()
class SingleCutSampler(CutSampler):
    """
    Samples cuts from a CutSet to satisfy the input constraints.
    Each mini-batch is built by :meth:`_next_batch` as a ``CutSet``.

    When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified,
    the batch size is dynamic. Exactly zero or one of those constraints can be specified.
    Padding required to collate the batch does not contribute to max frames/samples/duration.

    Example usage::

        >>> dataset = K2SpeechRecognitionDataset(cuts)
        >>> sampler = SingleCutSampler(cuts, shuffle=True)
        >>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(
        self,
        cuts: CutSet,
        max_frames: Optional[int] = None,
        max_samples: Optional[int] = None,
        max_duration: Optional[Seconds] = None,
        max_cuts: Optional[int] = None,
        shuffle: bool = False,
        drop_last: bool = False,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
    ):
        """
        SingleCutSampler's constructor.

        :param cuts: the ``CutSet`` to sample data from.
        :param max_frames: The maximum total number of feature frames from ``cuts``.
        :param max_samples: The maximum total number of audio samples from ``cuts``.
        :param max_duration: The maximum total recording duration from ``cuts``.
        :param max_cuts: The maximum number of cuts sampled to form a mini-batch.
            By default, this constraint is off. When given, it must be positive.
        :param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
            Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.:
            `for epoch in range(10): for batch in dataset: ...` as every epoch will see a
            different cuts order.
        :param drop_last: When ``True``, the last batch is dropped if it's incomplete.
        :param world_size: Total number of distributed nodes. We will try to infer it by default.
        :param rank: Index of distributed node. We will try to infer it by default.
        :param seed: Random seed used to consistently shuffle the dataset across different processes.
        """
        super().__init__(
            shuffle=shuffle,
            world_size=world_size,
            rank=rank,
            seed=seed,
        )
        self.data_source = DataSource(cuts)
        self.time_constraint = TimeConstraint(
            max_duration=max_duration, max_frames=max_frames, max_samples=max_samples
        )
        self.drop_last = drop_last
        # The cut-count limit is checked separately from the time constraint in _next_batch.
        self.max_cuts = max_cuts
        # Constraints
        assert is_none_or_gt(self.max_cuts, 0)

    @property
    def remaining_duration(self) -> Optional[float]:
        """
        Remaining duration of data left in the sampler (may be inexact due to float arithmetic).
        Not available when the CutSet is read in lazy mode (returns None).
        """
        return self.data_source.remaining_duration

    @property
    def remaining_cuts(self) -> Optional[int]:
        """
        Remaining number of cuts in the sampler.
        Not available when the CutSet is read in lazy mode (returns None).
        """
        return self.data_source.remaining_cuts

    @property
    def num_cuts(self) -> Optional[int]:
        """
        Total number of cuts in the sampler.
        Not available when the CutSet is read in lazy mode (returns None).
        """
        if self.data_source.is_lazy:
            return None
        return len(self.data_source)

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the current state of the sampler in a state_dict.
        Together with ``load_state_dict()``, this can be used to restore the
        training loop's state to the one stored in the state_dict.
        """
        state_dict = super().state_dict()
        state_dict.update(
            {
                "drop_last": self.drop_last,
                "time_constraint": self.time_constraint.state_dict(),
                "max_cuts": self.max_cuts,
            }
        )
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restore the state of the sampler that is described in a state_dict.
        This will result in the sampler yielding batches from where the previous training left it off.

        .. caution::
            The samplers are expected to be initialized with the same CutSets,
            but this is not explicitly checked anywhere.

        .. caution::
            The input ``state_dict`` is being mutated: we remove each consumed key, and expect
            it to be empty at the end of loading. If you don't want this behavior, pass a copy
            inside of this function (e.g., using ``import deepcopy``).

        .. note::
            For implementers of sub-classes of CutSampler: the flag ``self._just_restored_state`` has to be
            handled in ``__iter__`` to make it avoid resetting the just-restored state (only once).
        """
        self.drop_last = state_dict.pop("drop_last")

        time_constraint = TimeConstraint(**state_dict.pop("time_constraint"))
        if self.time_constraint != time_constraint:
            warnings.warn(
                "SingleCutSampler.load_state_dict(): Inconsistent time_constraint:\n"
                f"expected {self.time_constraint}\n"
                f"received {time_constraint}\n"
                f"We will overwrite the settings with the received state_dict."
            )
        self.time_constraint = time_constraint

        max_cuts = state_dict.pop("max_cuts")
        if self.max_cuts != max_cuts:
            warnings.warn(
                "SingleCutSampler.load_state_dict(): Inconsistent max_cuts:\n"
                f"expected {self.max_cuts}\n"
                f"received {max_cuts}\n"
                f"We will overwrite the settings with the received state_dict."
            )
        self.max_cuts = max_cuts

        super().load_state_dict(state_dict)

        # Restore the data source's state: re-shuffle deterministically, then skip
        # the cuts that were already sampled.
        if self.shuffle:
            self.data_source.shuffle(self.seed + self.epoch)
        self.data_source.fast_forward(self.diagnostics.total_cuts)

    def __iter__(self) -> "SingleCutSampler":
        """
        Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
        """
        # Restored state with load_state_dict()? Skip resetting only this once.
        if self._just_restored_state:
            return self
        # Reset the state to the beginning of the epoch.
        if self.shuffle:
            self.data_source.shuffle(self.seed + self.epoch)
        iter(self.data_source)
        self.diagnostics.reset()
        return self

    def _next_batch(self) -> CutSet:
        """
        Draw the next mini-batch of cuts.

        Cuts are pulled until the time constraint is exceeded or ``max_cuts`` would be
        surpassed; the overflowing cut is handed back to the data source for the next batch.

        :raises StopIteration: when the data is exhausted and there is no batch to return
            (or the final partial batch is dropped because of ``drop_last``).
        """
        # Note: no actual data is loaded into memory yet because the manifests contain all the metadata
        # required to do this operation.
        self.time_constraint.reset()
        cuts = []
        while True:
            # Check that we have not reached the end of the dataset.
            try:
                # If this doesn't raise (typical case), it's not the end: keep processing.
                next_cut = next(self.data_source)
            except StopIteration:
                # No more cuts to sample from: if we have a partial batch,
                # we may output it, unless the user requested to drop it.
                # We also check if the batch is "almost there" to override drop_last.
                if cuts and (
                    not self.drop_last or self.time_constraint.close_to_exceeding()
                ):
                    # We have a partial batch and we can return it.
                    self.diagnostics.keep(cuts)
                    return CutSet.from_cuts(cuts)
                else:
                    # There is nothing more to return or it's discarded:
                    # signal the iteration code to stop.
                    self.diagnostics.discard(cuts)
                    raise StopIteration()

            # Check whether the cut we're about to sample satisfies optional user-requested predicate.
            if not self._filter_fn(next_cut):
                # No - try another one.
                self.diagnostics.discard_single(next_cut)
                continue

            # Track the duration/frames/etc. constraints.
            self.time_constraint.add(next_cut)
            next_num_cuts = len(cuts) + 1

            # Did we exceed the max_frames and max_cuts constraints?
            if not self.time_constraint.exceeded() and (
                self.max_cuts is None or next_num_cuts <= self.max_cuts
            ):
                # No - add the next cut to the batch, and keep trying.
                cuts.append(next_cut)
            else:
                # Yes. Do we have at least one cut in the batch?
                if cuts:
                    # Yes. Return the batch, but keep the currently drawn cut for later.
                    self.data_source.take_back(next_cut)
                    break
                else:
                    # No. We'll warn the user that the constrains might be too tight,
                    # and return the cut anyway.
                    warnings.warn(
                        "The first cut drawn in batch collection violates "
                        "the max_frames, max_cuts, or max_duration constraints - "
                        "we'll return it anyway. "
                        "Consider increasing max_frames/max_cuts/max_duration."
                    )
                    cuts.append(next_cut)

        self.diagnostics.keep(cuts)
        return CutSet.from_cuts(cuts)