示例#1
0
    def filter_indices_by_size(self, indices, max_sizes):
        """Drop sample indices whose sizes exceed the allowed maximum.

        The work is delegated to
        ``data_utils.filter_paired_dataset_indices_by_size``, which is given
        this dataset's ``src_sizes`` and ``tgt_sizes``.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size;
                a list or tuple defines the limit separately for src and tgt

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        filter_paired = data_utils.filter_paired_dataset_indices_by_size
        result = filter_paired(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )
        return result
示例#2
0
    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices, removing those that are longer
        than specified in max_sizes.

        ``self.sizes`` is split into source and target sizes when it has a
        second axis with more than one column (column 0 is used as the source
        sizes, column 1 as the target sizes). Otherwise ``self.sizes`` is
        passed through unchanged as the source sizes and the target sizes
        are ``None``.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        sizes = self.sizes
        # The rank must be checked before reading shape[1]: a 1-D sizes array
        # has no second axis, and indexing shape[1] on it raises IndexError
        # instead of falling back to the unpaired case.
        is_paired = len(sizes.shape) > 1 and sizes.shape[1] > 1
        if is_paired:
            src_sizes = sizes[:, 0]
            tgt_sizes = sizes[:, 1]
        else:
            src_sizes = sizes
            tgt_sizes = None

        return data_utils.filter_paired_dataset_indices_by_size(
            src_sizes, tgt_sizes, indices, max_sizes)