def test_UnsortedBatchSampler_repr(shape_files, drop_last):
    """Smoke test: an UnsortedBatchSampler built from the first shape file
    (batch size 2) can be converted to text and printed without error."""
    first_shape_file = shape_files[0]
    batch_sampler = UnsortedBatchSampler(
        2,
        key_file=first_shape_file,
        drop_last=drop_last,
    )
    print(batch_sampler)
def test_UnsortedBatchSampler_len(shape_files, drop_last):
    """Smoke test: ``len()`` can be taken on an UnsortedBatchSampler built
    from the first shape file (batch size 2)."""
    first_shape_file = shape_files[0]
    batch_sampler = UnsortedBatchSampler(
        2,
        key_file=first_shape_file,
        drop_last=drop_last,
    )
    # The value is deliberately unused; the test only checks that len()
    # does not raise.
    _ = len(batch_sampler)
def build_batch_sampler(
    type: str,
    batch_size: int,
    batch_bins: int,
    shape_files: Union[Tuple[str, ...], List[str]],
    sort_in_batch: str = "descending",
    sort_batch: str = "ascending",
    drop_last: bool = False,
    min_batch_size: int = 1,
    fold_lengths: Sequence[int] = (),
    padding: bool = True,
    utt2category_file: Union[str, None] = None,
) -> AbsSampler:
    """Helper function to instantiate BatchSampler.

    Args:
        type: mini-batch type. One of "unsorted", "sorted", "folded",
            "numel", or "length".
        batch_size: The mini-batch size. Used for "unsorted", "sorted"
            and "folded" mode.
        batch_bins: Used for "numel" and "length" mode.
        shape_files: Text files describing the length and dimension of
            each feature, e.g. ``uttA 1330,80``. Only the first file is
            used for "unsorted" and "sorted" mode; all files are used for
            "folded", "numel" and "length" mode.
        sort_in_batch: Passed through as ``sort_in_batch`` for "sorted",
            "folded", "numel" and "length" mode.
        sort_batch: Passed through as ``sort_batch`` for "sorted",
            "folded", "numel" and "length" mode.
        drop_last: Passed through as ``drop_last`` for every mode.
        min_batch_size: Used for "folded", "numel" and "length" mode.
        fold_lengths: Used for "folded" mode. Must contain exactly one
            entry per shape file.
        padding: Whether sequences are input as a padded tensor or not.
            Used for "numel" and "length" mode.
        utt2category_file: Used for "folded" mode only. May be None.

    Returns:
        The sampler instance selected by ``type``.

    Raises:
        ValueError: If ``shape_files`` is empty, if ``type`` is "folded"
            and ``len(fold_lengths) != len(shape_files)``, or if ``type``
            is not one of the supported values.
    """
    # Runtime check of the annotated argument types (helper defined elsewhere).
    assert check_argument_types()
    if len(shape_files) == 0:
        raise ValueError("No shape file are given")

    if type == "unsorted":
        retval = UnsortedBatchSampler(
            batch_size=batch_size, key_file=shape_files[0], drop_last=drop_last
        )

    elif type == "sorted":
        retval = SortedBatchSampler(
            batch_size=batch_size,
            shape_file=shape_files[0],
            sort_in_batch=sort_in_batch,
            sort_batch=sort_batch,
            drop_last=drop_last,
        )

    elif type == "folded":
        # One fold length per shape file; reject mismatches up front.
        if len(fold_lengths) != len(shape_files):
            raise ValueError(
                f"The number of fold_lengths must be equal to "
                f"the number of shape_files: "
                f"{len(fold_lengths)} != {len(shape_files)}"
            )
        retval = FoldedBatchSampler(
            batch_size=batch_size,
            shape_files=shape_files,
            fold_lengths=fold_lengths,
            sort_in_batch=sort_in_batch,
            sort_batch=sort_batch,
            drop_last=drop_last,
            min_batch_size=min_batch_size,
            utt2category_file=utt2category_file,
        )

    elif type == "numel":
        retval = NumElementsBatchSampler(
            batch_bins=batch_bins,
            shape_files=shape_files,
            sort_in_batch=sort_in_batch,
            sort_batch=sort_batch,
            drop_last=drop_last,
            padding=padding,
            min_batch_size=min_batch_size,
        )

    elif type == "length":
        retval = LengthBatchSampler(
            batch_bins=batch_bins,
            shape_files=shape_files,
            sort_in_batch=sort_in_batch,
            sort_batch=sort_batch,
            drop_last=drop_last,
            padding=padding,
            min_batch_size=min_batch_size,
        )

    else:
        raise ValueError(f"Not supported: {type}")

    # Runtime check of the annotated return type (helper defined elsewhere).
    assert check_return_type(retval)
    return retval