def batch_by_size(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
):
    """
    Group dataset indices into size-bucketed mini-batches.

    Batches may mix sequences of different lengths. The heavy lifting is
    delegated to ``batch_by_size_fast``.

    NOTE(review): a later ``def batch_by_size`` in this module rebinds the
    same name, so this definition is shadowed at import time. Also,
    ``batch_by_size_fast`` is not imported inside this function; it must come
    from module scope — confirm.

    Args:
        indices (List[int]): ordered list of dataset indices; a generator is
            materialised into an int64 numpy array first.
        num_tokens_fn (callable): returns the number of tokens at a given index.
        max_tokens (int, optional): max number of tokens in each batch
            (default: None, passed on as ``sys.maxsize``).
        max_sentences (int, optional): max number of sentences in each batch
            (default: None, passed on as ``sys.maxsize``).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
    """
    # "No limit" is expressed as the largest platform integer.
    token_cap = sys.maxsize if max_tokens is None else max_tokens
    sentence_cap = sys.maxsize if max_sentences is None else max_sentences

    # Generators are single-pass, so turn them into a concrete array.
    if isinstance(indices, types.GeneratorType):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    return batch_by_size_fast(
        indices,
        num_tokens_fn,
        token_cap,
        sentence_cap,
        required_batch_size_multiple,
    )
def batch_by_size(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Bucket dataset indices into mini-batches by size.

    Batches may contain sequences of different lengths. The work is done by
    the compiled ``fairseq.data.data_utils_fast`` extension; an ImportError
    with build instructions is raised if it is not available.

    Args:
        indices (List[int]): ordered list of dataset indices; any non-ndarray
            iterable is converted to an int64 numpy array.
        num_tokens_fn (callable): function that returns the number of tokens
            at a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None, passed on as -1).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None, passed on as -1).
        required_batch_size_multiple (int, optional): require batch size to
            be less than N or a multiple of N (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
            only be created with the given (bsz, length) shapes.
            *max_sentences* and *required_batch_size_multiple* will be
            ignored (default: None).

    Returns:
        Whatever ``batch_by_size_fast`` / ``batch_fixed_shapes_fast`` return.

    Raises:
        ImportError: if the Cython extension has not been built.
    """
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fast,
            batch_fixed_shapes_fast,
        )
    except ImportError as err:
        # Keep the original import failure as the cause for debugging.
        raise ImportError(
            'Please build Cython components with: `pip install --editable .` '
            'or `python setup.py build_ext --inplace`'
        ) from err

    max_tokens = max_tokens if max_tokens is not None else -1
    max_sentences = max_sentences if max_sentences is not None else -1
    bsz_mult = required_batch_size_multiple

    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    if fixed_shapes is None:
        return batch_by_size_fast(
            indices,
            num_tokens_fn,
            max_tokens,
            max_sentences,
            bsz_mult,
        )

    fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
    # np.lexsort uses the LAST key as the primary key: order by bsz, then by
    # length. Sort on the column values themselves -- argsort() returns a
    # permutation (not ranks), so using it as a key mis-orders 3+ shapes.
    sort_order = np.lexsort((
        fixed_shapes[:, 1],  # length (secondary)
        fixed_shapes[:, 0],  # bsz (primary)
    ))
    fixed_shapes_sorted = fixed_shapes[sort_order]
    return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
def batch_by_size_dep(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
):
    """
    Bucket dataset indices into mini-batches by size.

    Batches may contain sequences of different lengths. The work is done by
    ``batch_by_size_fast`` from the compiled ``fairseq.data.data_utils_fast``
    extension; an ImportError with build instructions is raised if it is not
    available.

    Args:
        indices (List[int]): ordered list of dataset indices; any non-ndarray
            iterable (list, generator, ...) is converted to an int64 numpy
            array.
        num_tokens_fn (callable): function that returns the number of tokens
            at a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None, passed on as -1).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None, passed on as -1).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).

    Returns:
        Whatever ``batch_by_size_fast`` returns.

    Raises:
        ImportError: if the Cython extension has not been built.
    """
    try:
        from fairseq.data.data_utils_fast import batch_by_size_fast
    except ImportError as err:
        # Keep the original import failure as the cause for debugging.
        raise ImportError(
            'Please build Cython components with: `pip install --editable .` '
            'or `python setup.py build_ext --inplace`'
        ) from err

    max_tokens = max_tokens if max_tokens is not None else -1
    max_sentences = max_sentences if max_sentences is not None else -1
    bsz_mult = required_batch_size_multiple

    # Previously only generators were converted, so a plain list (the
    # documented input type) reached the extension unconverted. Convert any
    # non-ndarray iterable, matching batch_by_size.
    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)