Ejemplo n.º 1
0
 def build_collate_fn(
     cls, args: argparse.Namespace, train: bool
 ) -> Callable[
     [Collection[Tuple[str, Dict[str, np.ndarray]]]],
     Tuple[List[str], Dict[str, torch.Tensor]],
 ]:
     """Return the collate callable used to assemble mini-batches.

     Args:
         cls: Class this method is invoked on (not read here).
         args: Parsed task arguments (not read here).
         train: Training-mode flag (not read here).

     Returns:
         A ``CommonCollateFn`` constructed with ``float_pad_value=0.0``
         and ``int_pad_value=-1``.
     """
     assert check_argument_types()
     # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol,
     # hence integer padding uses -1 rather than 0.
     collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
     return collate_fn
Ejemplo n.º 2
0
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return the collate callable used to assemble mini-batches.

        Args:
            cls: Class this method is invoked on (not read here).
            args: Parsed task arguments (not read here).
            train: Training-mode flag (not read here).

        Returns:
            A ``CommonCollateFn`` constructed with ``float_pad_value=0.0``
            and ``int_pad_value=0``.
        """
        assert check_argument_types()

        collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=0)
        return collate_fn
Ejemplo n.º 3
0
def test_CommonCollateFn_repr(float_pad_value, int_pad_value, not_sequence):
    """Smoke test: building a CommonCollateFn and printing it must not raise.

    The three arguments are forwarded unchanged as keyword arguments to
    ``CommonCollateFn``; nothing is asserted about the printed text.
    """
    collate_fn = CommonCollateFn(
        float_pad_value=float_pad_value,
        int_pad_value=int_pad_value,
        not_sequence=not_sequence,
    )
    print(collate_fn)
Ejemplo n.º 4
0
def test_ChunkIterFactory():
    """Check that every value yielded by ChunkIterFactory has shape (2, 3)."""
    # The expected shape matches the batch_size=2 / chunk_length=3 passed
    # below (presumably (batch, chunk) -- confirm against ChunkIterFactory).
    expected_shape = (2, 3)

    factory = ChunkIterFactory(
        dataset=Dataset(),
        batches=[["a"], ["b"]],
        batch_size=2,
        chunk_length=3,
        collate_fn=CommonCollateFn(),
    )

    # build_iter(0) yields (key, batch) pairs; only the batch is inspected.
    for _key, chunk in factory.build_iter(0):
        for _name, value in chunk.items():
            assert value.shape == expected_shape
Ejemplo n.º 5
0
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Build the collate function for this task.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments (not read here).
            train: Training mode (not read here).

        Returns:
            A ``CommonCollateFn`` constructed with ``float_pad_value=0.0``
            and ``int_pad_value=-1``.

        """
        assert check_argument_types()

        collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
        return collate_fn
Ejemplo n.º 6
0
def test_(float_pad_value, int_pad_value, not_sequence):
    """Compare CommonCollateFn output with a hand-padded NumPy reference.

    Two samples are collated: key ``a`` holds float arrays of shape
    (3, 5) and (2, 5), key ``b`` holds integer arrays of length 4 and 3.
    The reference pads the shorter sample along its first axis with
    ``float_pad_value`` / ``int_pad_value`` and stacks the pair.

    Args:
        float_pad_value: Forwarded to ``CommonCollateFn`` and used to pad
            the reference for ``a``.
        int_pad_value: Forwarded to ``CommonCollateFn`` and used to pad
            the reference for ``b``.
        not_sequence: Forwarded to ``CommonCollateFn``; a key contained in
            it is exempt from the ``<key>_lengths`` check.
    """
    # Use an explicit integer dtype: the ``np.long`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    int_dtype = np.int64

    _common_collate_fn = CommonCollateFn(
        float_pad_value=float_pad_value,
        int_pad_value=int_pad_value,
        not_sequence=not_sequence,
    )
    data = [
        ("id",
         dict(a=np.random.randn(3, 5), b=np.random.randn(4).astype(int_dtype))),
        ("id2",
         dict(a=np.random.randn(2, 5), b=np.random.randn(3).astype(int_dtype))),
    ]
    t = _common_collate_fn(data)

    desired = dict(
        a=np.stack([
            data[0][1]["a"],
            # Second sample is one frame shorter: pad one row at the end.
            np.pad(
                data[1][1]["a"],
                [(0, 1), (0, 0)],
                mode="constant",
                constant_values=float_pad_value,
            ),
        ]),
        b=np.stack([
            data[0][1]["b"],
            # Second sample is one element shorter: pad one value at the end.
            np.pad(
                data[1][1]["b"],
                [(0, 1)],
                mode="constant",
                constant_values=int_pad_value,
            ),
        ]),
        a_lengths=np.array([3, 2], dtype=int_dtype),
        b_lengths=np.array([4, 3], dtype=int_dtype),
    )

    # t[1] is the collated mapping (t[0] is not inspected by this test).
    np.testing.assert_array_equal(t[1]["a"], desired["a"])
    np.testing.assert_array_equal(t[1]["b"], desired["b"])

    # Length entries are only checked for keys not listed in not_sequence.
    if "a" not in not_sequence:
        np.testing.assert_array_equal(t[1]["a_lengths"], desired["a_lengths"])
    if "b" not in not_sequence:
        np.testing.assert_array_equal(t[1]["b_lengths"], desired["b_lengths"])
Ejemplo n.º 7
0
 def build_collate_fn(cls, args, train):
     """Return a ``CommonCollateFn`` built with its default arguments.

     ``args`` and ``train`` are accepted but not read here.
     """
     collate_fn = CommonCollateFn()
     return collate_fn