Ejemplo n.º 1
0
    def __init__(
        self,
        dataset,
        batch_size: int,
        batches: Union[AbsSampler, Sequence[Sequence[Any]]],
        chunk_length: Union[int, str],
        chunk_shift_ratio: float = 0.5,
        num_cache_chunks: int = 1024,
        num_samples_per_epoch: Union[int, None] = None,
        seed: int = 0,
        shuffle: bool = False,
        num_workers: int = 0,
        collate_fn=None,
        pin_memory: bool = False,
    ):
        """Set up the per-sample iterator factory and the chunk-length candidates.

        Args:
            dataset: Forwarded unchanged to ``SequenceIterFactory``.
            batch_size: Stored on the instance; also the lower bound applied
                to ``num_cache_chunks``.
            batches: Sampler or sequence of batches; every batch must contain
                exactly one element.
            chunk_length: Either a single integer (one fixed candidate) or a
                comma-separated string whose items are an integer ("5") or an
                inclusive range ("3-5"), e.g. "5,8" or "3-5".
            chunk_shift_ratio: Stored on the instance as given.
            num_cache_chunks: Stored as ``max(num_cache_chunks, batch_size)``.
            num_samples_per_epoch: Forwarded to ``SequenceIterFactory`` as
                ``num_iters_per_epoch``.
            seed: Stored on the instance and forwarded to
                ``SequenceIterFactory``.
            shuffle: Stored on the instance and forwarded to
                ``SequenceIterFactory``.
            num_workers: Forwarded to ``SequenceIterFactory``.
            collate_fn: Forwarded to ``SequenceIterFactory``.
            pin_memory: Forwarded to ``SequenceIterFactory``.

        Raises:
            ValueError: If ``chunk_length`` is an empty string, contains a
                non-integer item, has an item with more than one "-", or
                specifies a range whose start exceeds its end.
        """
        assert check_argument_types()
        assert all(len(x) == 1 for x in batches), "batch-size must be 1"

        self.per_sample_iter_factory = SequenceIterFactory(
            dataset=dataset,
            batches=batches,
            num_iters_per_epoch=num_samples_per_epoch,
            seed=seed,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
        )

        # The chunk cache is never allowed to be smaller than one mini-batch.
        self.num_cache_chunks = max(num_cache_chunks, batch_size)
        if isinstance(chunk_length, str):
            if not chunk_length:
                raise ValueError("e.g. 5,8 or 3-5: but got empty string")

            self.chunk_lengths = []
            for x in chunk_length.split(","):
                try:
                    sps = [int(v) for v in x.split("-")]
                except ValueError as err:
                    # Chain explicitly so the offending token stays visible
                    # in the traceback as the direct cause.
                    raise ValueError(
                        f"e.g. 5,8 or 3-5: but got {chunk_length}") from err

                if len(sps) > 2:
                    raise ValueError(
                        f"e.g. 5,8 or 3-5: but got {chunk_length}")
                elif len(sps) == 2:
                    if sps[0] > sps[1]:
                        # A reversed range would expand to nothing and be
                        # dropped silently, possibly leaving no candidates.
                        raise ValueError(
                            "range start must not exceed range end: "
                            f"but got {x} in {chunk_length}")
                    # Append all numbers between the range into the candidates
                    self.chunk_lengths.extend(range(sps[0], sps[1] + 1))
                else:
                    self.chunk_lengths.append(sps[0])
        else:
            # Single candidates: Fixed chunk length
            self.chunk_lengths = [chunk_length]

        self.chunk_shift_ratio = chunk_shift_ratio
        self.batch_size = batch_size
        self.seed = seed
        self.shuffle = shuffle
Ejemplo n.º 2
0
def test_SequenceIterFactory_without_num_iters_per_epoch_deterministic(
        collate):
    """Check that building the iterator twice for one epoch gives equal batches.

    The factory is created with ``shuffle=True`` and without
    ``num_iters_per_epoch``; for each epoch 1..9 two iterators are built and
    compared element by element.
    """
    factory = SequenceIterFactory(
        dataset=Dataset(),
        batches=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
        shuffle=True,
        collate_fn=collate,
    )
    for epoch in range(1, 10):
        first_pass = factory.build_iter(epoch)
        second_pass = factory.build_iter(epoch)
        for lhs, rhs in zip(first_pass, second_pass):
            assert (lhs == rhs).all()
Ejemplo n.º 3
0
def test_SequenceIterFactory(collate):
    """Check the batches produced over epochs 1..4 with 3 iterations per epoch.

    With five two-element batches and ``num_iters_per_epoch=3``, the expected
    output continues through the batch list from one epoch to the next and
    wraps around to the start once the list is exhausted.
    """
    factory = SequenceIterFactory(
        dataset=Dataset(),
        batches=[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
        num_iters_per_epoch=3,
        collate_fn=collate,
    )

    observed = []
    for epoch in range(1, 5):
        epoch_batches = []
        for batch in factory.build_iter(epoch):
            epoch_batches.append([int(value) for value in batch])
        observed.append(epoch_batches)

    expected = [
        [[0, 1], [2, 3], [4, 5]],
        [[6, 7], [8, 9], [0, 1]],
        [[2, 3], [4, 5], [6, 7]],
        [[8, 9], [0, 1], [2, 3]],
    ]
    assert observed == expected