Example #1
    def as_generator(self, shuffle=False, n_workers=0):
        """Return a generator that yields the entire dataset once

        This method is intended to act as a lightweight wrapper around the
        torch.utils.data.DataLoader class, which has built-in shuffling of the
        data without loading it all into memory. This method purposely removes
        the added batch dimension from DataLoader such that each element
        yielded is still a single sample, just as if it came from indexing into
        this class, e.g. AugmentedDataset[10].

        :param shuffle: if True, shuffle the data before returning it
        :type shuffle: bool
        :param n_workers: number of subprocesses to use for data loading
        :type n_workers: int
        :return: generator that cycles over the dataset indefinitely
        :rtype: generator
        """

        data_loader = DataLoader(
            dataset=self, shuffle=shuffle, num_workers=n_workers
        )
        for sample in cycle(data_loader):
            # DataLoader (with its default batch_size=1) prepends a batch
            # dimension to every tensor; strip it so each yielded element is
            # a single sample
            yield {key: val[0] for key, val in sample.items()}
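Below is a minimal, self-contained sketch of the batch-dimension behavior this
wrapper relies on, using a hypothetical ToyDataset: with its default
batch_size=1, DataLoader prepends a leading dimension to every tensor, and
indexing with [0] strips it again.

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        """Hypothetical map-style dataset that yields dict samples"""

        def __len__(self):
            return 3

        def __getitem__(self, idx):
            return {'element': torch.full((2,), float(idx)),
                    'label': torch.tensor(idx)}

    loader = DataLoader(ToyDataset(), shuffle=False, num_workers=0)
    sample = next(iter(loader))
    assert sample['element'].shape == (1, 2)   # batch dimension added
    assert sample['element'][0].shape == (2,)  # batch dimension removed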
Example #2
    def train(self,
              network,
              train_dataset,
              n_steps_per_epoch,
              validation_dataset=None,
              n_validation_steps=None,
              metrics=None,
              callbacks=None):
        """Train the network as specified via the __init__ parameters

        :param network: network object to use for training
        :type network: networks.alexnet_pytorch.AlexNet
        :param train_dataset: dataset that iterates over the training data
         (cycled indefinitely during training)
        :type train_dataset: torch.utils.data.DataLoader
        :param n_steps_per_epoch: number of batches to train on in one epoch
        :type n_steps_per_epoch: int
        :param validation_dataset: optional dataset that iterates over the
         validation data (cycled indefinitely during validation)
        :type validation_dataset: torch.utils.data.DataLoader
        :param n_validation_steps: number of batches to validate on after each
         epoch
        :type n_validation_steps: int
        :param metrics: metrics to be evaluated by the model during training
         and testing
        :type metrics: list[object]
        :param callbacks: callbacks to be used during training
        :type callbacks: list[object]
        """

        self.optimizer = self._init_optimizer(network)

        model = Model(network, n_outputs=network.n_outputs, gpu_id=self.gpu_id)
        model.compile(optimizer=self.optimizer,
                      loss=self.loss,
                      metrics=metrics,
                      loss_weights=self.loss_weights)

        if validation_dataset is not None:
            validation_dataset = cycle(validation_dataset)

        model.fit_generator(generator=cycle(train_dataset),
                            n_steps_per_epoch=n_steps_per_epoch,
                            n_epochs=self.n_epochs,
                            validation_data=validation_dataset,
                            n_validation_steps=n_validation_steps,
                            callbacks=callbacks)
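The Model class and its fit_generator method are not shown in these examples.
The sketch below is a hypothetical reconstruction of the epoch/step loop such
a method presumably runs; it illustrates why both datasets are wrapped in
cycle: a fixed number of batches is drawn with next() each epoch, so the
generators must never exhaust.

    def fit_generator(generator, n_steps_per_epoch, n_epochs,
                      validation_data=None, n_validation_steps=None,
                      callbacks=None):
        """Hypothetical sketch; the real Model.fit_generator may differ"""

        for epoch in range(n_epochs):
            for _ in range(n_steps_per_epoch):
                batch = next(generator)  # never StopIteration: input is cycled
                # ... forward pass, loss computation, backward pass, step ...
            if validation_data is not None:
                for _ in range(n_validation_steps):
                    val_batch = next(validation_data)
                    # ... forward pass and metric accumulation only ...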
Example #3
    def test_cycle__bad_iterable(self):
        """Test `cycle` when the iterable doesn't implement `__iter__`"""

        def iterable():
            for element in range(5):
                yield element

        with pytest.raises(AttributeError):
            # the generator function itself is passed (not the generator it
            # would return), and function objects have no __iter__ attribute
            cycle_iter = cycle(iterable)
            for _ in range(5):
                next(cycle_iter)
Example #4
    def test_cycle__deterministic(self):
        """Test `cycle` when the iterable yields elements deterministically"""

        class DeterministicIter(object):

            def __iter__(self):
                for element in range(5):
                    yield element

        cycle_iter = cycle(DeterministicIter())
        for idx_element in range(10):
            expected_element = idx_element % 5
            element = next(cycle_iter)
            assert element == expected_element
Example #5
    def test_cycle__nondeterministic(self):
        """Test `cycle` when the iterable yields elements non-deterministically

        This compares the output of the `cycle` function with that of
        `itertools.cycle`. The former should yield a different set of
        elements on each pass over the iterable, whereas the latter should
        yield the same set on every pass, since it caches the elements
        returned during the first pass.
        """

        class NonDeterministicIter(object):

            def __init__(self, seeds):
                self.cycle = 0
                self.seeds = seeds

            def __iter__(self):

                np.random.seed(self.seeds[self.cycle])
                for element in np.random.randint(0, 1000, size=5):
                    yield element
                self.cycle += 1

        itertools_cycle_iter = itertools_cycle(
            NonDeterministicIter(seeds=[1226, 42])
        )
        non_itertools_cycle_iter = cycle(
            NonDeterministicIter(seeds=[1226, 42])
        )

        itertools_batch1 = [next(itertools_cycle_iter) for _ in range(5)]
        itertools_batch2 = [next(itertools_cycle_iter) for _ in range(5)]

        non_itertools_batch1 = [
            next(non_itertools_cycle_iter) for _ in range(5)
        ]
        non_itertools_batch2 = [
            next(non_itertools_cycle_iter) for _ in range(5)
        ]

        assert np.array_equal(itertools_batch1, itertools_batch2)
        assert np.array_equal(itertools_batch1, non_itertools_batch1)
        assert not np.array_equal(non_itertools_batch1, non_itertools_batch2)
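Taken together, the three tests above pin down the behavior of the custom
cycle under test: it must raise AttributeError for objects without __iter__,
and it must invoke __iter__ afresh on every pass instead of caching the first
pass the way itertools.cycle does. A minimal sketch consistent with those
tests (the project's actual implementation may differ):

    def cycle(iterable):
        """Cycle over `iterable` indefinitely, re-running __iter__ each pass

        Unlike itertools.cycle, no elements are cached, so iterables that
        shuffle or regenerate their data in __iter__ yield fresh elements
        on every pass.
        """

        while True:
            # access __iter__ directly (rather than calling iter()) so that
            # objects without it raise AttributeError, per
            # test_cycle__bad_iterable
            for element in iterable.__iter__():
                yield element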
Example #6
        def mock_call(shuffle=False, n_workers=1):
            """Mock __call__ magic method"""

            for element in cycle(np.arange(4, dtype='float32')):
                yield {'element': element, 'label': 1}
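For reference, a short usage sketch of the mock (assuming the `cycle` and
`np` names from the examples above are in scope): because the underlying
array is deterministic, repeated passes yield identical elements.

    gen = mock_call()
    first_pass = [next(gen)['element'] for _ in range(4)]
    second_pass = [next(gen)['element'] for _ in range(4)]
    assert first_pass == second_pass == [0.0, 1.0, 2.0, 3.0]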