def batch_by_size(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Build mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be less than N or a multiple of N (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
            only be created with the given shapes. *max_sentences* and
            *required_batch_size_multiple* will be ignored (default: None).

    Returns:
        Whatever the Cython routine produces: the result of
        ``batch_by_size_fast`` when *fixed_shapes* is None, otherwise the
        result of ``batch_fixed_shapes_fast``.

    Raises:
        ImportError: if ``fairseq.data.data_utils_fast`` cannot be imported.
    """
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fast,
            batch_fixed_shapes_fast,
        )
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            'Please build Cython components with: `pip install --editable .` '
            'or `python setup.py build_ext --inplace`'
        ) from err

    # A missing limit is forwarded to the Cython code as -1.
    max_tokens = max_tokens if max_tokens is not None else -1
    max_sentences = max_sentences if max_sentences is not None else -1
    bsz_mult = required_batch_size_multiple

    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    if fixed_shapes is None:
        return batch_by_size_fast(
            indices,
            num_tokens_fn,
            max_tokens,
            max_sentences,
            bsz_mult,
        )

    fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
    # np.lexsort uses the LAST key as the primary one: order by bsz, then
    # break ties by length. The keys must be the column values themselves;
    # an argsort() permutation is not a rank and mis-orders the shapes as
    # soon as three or more of them are given out of order.
    sort_order = np.lexsort([
        fixed_shapes[:, 1],  # length
        fixed_shapes[:, 0],  # bsz
    ])
    fixed_shapes_sorted = fixed_shapes[sort_order]
    return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
def batch_by_size_utils(
    self,
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Build mini-batches of indices, delegating to
    ``self.batch_by_size_utils_baseline_triplet`` (no fixed shapes) or to the
    Cython ``batch_fixed_shapes_fast`` routine (fixed shapes given).

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): forwarded unchanged to
            the batching helper (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, the
            (bsz, length) shapes are sorted and handed to
            ``batch_fixed_shapes_fast`` (default: None).

    Returns:
        The result of ``self.batch_by_size_utils_baseline_triplet`` when
        *fixed_shapes* is None, otherwise the result of
        ``batch_fixed_shapes_fast``.

    Raises:
        ImportError: if ``fairseq.data.data_utils_fast`` cannot be imported.
    """
    try:
        # Only the fixed-shape routine is used by this method.
        from fairseq.data.data_utils_fast import batch_fixed_shapes_fast
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Please build Cython components with: `pip install --editable .` "
            "or `python setup.py build_ext --inplace`"
        ) from err

    # A missing limit is forwarded as -1.
    max_tokens = max_tokens if max_tokens is not None else -1
    max_sentences = max_sentences if max_sentences is not None else -1
    bsz_mult = required_batch_size_multiple

    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    logger.info("enter batch by size utils")

    if fixed_shapes is None:
        # Token counts are only consumed on this path, so build them here
        # rather than unconditionally.
        # NOTE(review): num_tokens_fn receives the position 0..len(indices)-1,
        # not indices[position] -- confirm this is the layout the triplet
        # helper expects.
        num_tokens_vec = [num_tokens_fn(i) for i in range(indices.size)]
        return self.batch_by_size_utils_baseline_triplet(
            indices,
            num_tokens_vec,
            max_tokens,
            max_sentences,
            bsz_mult,
        )

    fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
    # np.lexsort uses the LAST key as the primary one: order by bsz, then
    # break ties by length. The keys must be the column values themselves;
    # an argsort() permutation is not a rank and mis-orders the shapes as
    # soon as three or more of them are given out of order.
    sort_order = np.lexsort([
        fixed_shapes[:, 1],  # length
        fixed_shapes[:, 0],  # bsz
    ])
    fixed_shapes_sorted = fixed_shapes[sort_order]
    return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
def batch_by_size(
    indices,
    num_tokens_fn,
    num_tokens_vec=None,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Build mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        num_tokens_vec (List[int], optional): precomputed vector of the number
            of tokens for each index in indices (to enable faster batch
            generation)
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be less than N or a multiple of N (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
            only be created with the given shapes. *max_sentences* and
            *required_batch_size_multiple* will be ignored (default: None).

    Returns:
        The result of ``batch_by_size_fn`` / ``batch_by_size_vec`` when
        *fixed_shapes* is None (the latter when *num_tokens_vec* is given),
        otherwise the result of ``batch_fixed_shapes_fast``.

    Raises:
        ImportError: if ``fairseq.data.data_utils_fast`` cannot be imported.
        ValueError: if importing the extension raises ValueError.
    """
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fn,
            batch_by_size_vec,
            batch_fixed_shapes_fast,
        )
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Please build Cython components with: "
            "`python setup.py build_ext --inplace`"
        ) from err
    except ValueError as err:
        raise ValueError(
            "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
        ) from err

    # added int() to avoid TypeError: an integer is required
    max_tokens = int(max_tokens) if max_tokens is not None else -1
    # A missing limit is forwarded to the Cython code as -1.
    max_sentences = max_sentences if max_sentences is not None else -1
    bsz_mult = required_batch_size_multiple

    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
        num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)

    if fixed_shapes is None:
        if num_tokens_vec is None:
            return batch_by_size_fn(
                indices,
                num_tokens_fn,
                max_tokens,
                max_sentences,
                bsz_mult,
            )
        return batch_by_size_vec(
            indices,
            num_tokens_vec,
            max_tokens,
            max_sentences,
            bsz_mult,
        )

    fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
    # np.lexsort uses the LAST key as the primary one: order by bsz, then
    # break ties by length. The keys must be the column values themselves;
    # an argsort() permutation is not a rank and mis-orders the shapes as
    # soon as three or more of them are given out of order.
    sort_order = np.lexsort([
        fixed_shapes[:, 1],  # length
        fixed_shapes[:, 0],  # bsz
    ])
    fixed_shapes_sorted = fixed_shapes[sort_order]
    return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)