Пример #1
0
    def _get_outputs(
            self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
        """Build visualization outputs for the next batch of the dataset.

        One batch is taken from ``self._dataset_iter`` and re-batched through
        ``_batched_generator`` with ``internal_batch_size=1``; every resulting
        slice is handed to ``self._calculate_vis_output``.

        Returns:
            A list of ``(output, cache)`` pairs. A pair is produced only for
            slices where ``_calculate_vis_output`` returned something other
            than ``None``; ``cache`` is a ``SampleCache`` built from the same
            inputs, additional forward args and label.

        Raises:
            StopIteration: if ``self._dataset_iter`` has no more batches.
        """
        batch = next(self._dataset_iter)

        # Type ignore for issue with passing union to function taking generic
        # https://github.com/python/mypy/issues/1533
        slices = _batched_generator(  # type: ignore
            inputs=batch.inputs,
            additional_forward_args=batch.additional_args,
            target_ind=batch.labels,
            # should be 1 until we have batch label support
            internal_batch_size=1,
        )

        results = []
        for slice_inputs, slice_args, slice_label in slices:
            vis = self._calculate_vis_output(slice_inputs, slice_args,
                                             slice_label)
            if vis is None:
                # Nothing to show for this slice, so nothing to cache either.
                continue
            results.append(
                (vis, SampleCache(slice_inputs, slice_args, slice_label)))

        return results
Пример #2
0
    def _get_outputs(self) -> List[VisualizationOutput]:
        """Build visualization outputs for the next batch of ``self.dataset``.

        The batch is re-batched through ``_batched_generator`` with
        ``internal_batch_size=1`` and each slice is passed to
        ``self._calculate_vis_output``.

        Returns:
            A list with one ``(output, cache)`` pair per slice whose output is
            not ``None``; ``cache`` is a ``SampleCache`` built from that
            slice's inputs, additional forward args and label.

        Raises:
            StopIteration: if ``self.dataset`` has no more batches.
        """
        # NOTE(review): the declared return type is a list of
        # VisualizationOutput, but the elements appended below are
        # (output, SampleCache) tuples -- the annotation looks stale; confirm.
        batch = next(self.dataset)

        slices = _batched_generator(
            inputs=batch.inputs,
            additional_forward_args=batch.additional_args,
            target_ind=batch.labels,
            internal_batch_size=1,  # should be 1 until we have batch label support
        )

        collected = []
        for slice_inputs, slice_args, slice_label in slices:
            vis = self._calculate_vis_output(slice_inputs, slice_args, slice_label)
            if vis is None:
                continue
            sample_cache = SampleCache(slice_inputs, slice_args, slice_label)
            collected.append((vis, sample_cache))

        return collected
Пример #3
0
    def _get_outputs(
            self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
        """Build visualization outputs for the next (or a replayed) batch.

        A fresh batch is taken from ``self._dataset_iter`` while one is
        available. Fresh batches are remembered in ``self._dataset_cache``,
        which is trimmed to at most ``self._config.num_examples`` entries.
        Once the dataset is exhausted, the remembered batches are replayed
        in an endless loop.

        Returns:
            A list of ``(output, cache)`` pairs, one per slice for which
            ``self._calculate_vis_output`` returned something other than
            ``None``; ``cache`` is a ``SampleCache`` built from that slice's
            inputs, additional forward args and label.

        Raises:
            StopIteration: if the dataset is exhausted and no batch was ever
                cached (i.e. the dataset was empty).
        """
        # If we run out of new batches, then we need to
        # display data which was already shown before.
        # However, since the dataset given to us is a generator,
        # we can't reset it to return to the beginning.
        # Because of this, we store a small cache of stale
        # data, and iterate on it after the main generator
        # stops returning new batches.
        try:
            batch_data = next(self._dataset_iter)
        except StopIteration:
            # Replay a *snapshot* of the cache. ``cycle`` walks its argument
            # lazily on the first pass, so handing it the live list while that
            # list is still being appended to / popped from would skip or
            # duplicate batches in the replay order.
            self._dataset_iter = cycle(tuple(self._dataset_cache))
            batch_data = next(self._dataset_iter)
        else:
            # Only batches coming from the real dataset are cached; batches
            # that are already being replayed must not be fed back in.
            if not isinstance(self._dataset_iter, cycle):
                self._dataset_cache.append(batch_data)
                if len(self._dataset_cache) > self._config.num_examples:
                    self._dataset_cache.pop(0)

        vis_outputs = []

        # Type ignore for issue with passing union to function taking generic
        # https://github.com/python/mypy/issues/1533
        for (
                inputs,
                additional_forward_args,
                label,
        ) in _batched_generator(  # type: ignore
                inputs=batch_data.inputs,
                additional_forward_args=batch_data.additional_args,
                target_ind=batch_data.labels,
                internal_batch_size=
                1,  # should be 1 until we have batch label support
        ):
            output = self._calculate_vis_output(inputs,
                                                additional_forward_args, label)
            if output is not None:
                cache = SampleCache(inputs, additional_forward_args, label)
                vis_outputs.append((output, cache))

        return vis_outputs
Пример #4
0
    def test_batched_generator(self):
        """Check the slices yielded by ``_batched_generator``.

        The generator is called with two input tensors, additional forward
        args made of a tensor and the plain value ``5``, target ``7`` and a
        final positional argument of ``1`` (presumably the internal batch
        size -- confirm against the ``_batched_generator`` signature).

        For every yielded ``(inputs, additional_args, target)`` triple at
        position ``index`` the test asserts that the tensor parts match row
        ``index`` of the corresponding source array, and that the non-tensor
        additional arg and the target come through unchanged.
        """
        array1 = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        array2 = [[6, 7, 8], [0, 1, 2], [3, 4, 5]]
        array3 = [[0, 1, 2], [0, 0, 0], [0, 0, 0]]
        inp1, inp2, inp3 = (
            torch.tensor(array1),
            torch.tensor(array2),
            torch.tensor(array3),
        )
        for index, (inp, add, targ) in enumerate(
                _batched_generator((inp1, inp2), (inp3, 5), 7, 1)):
            assertTensorAlmostEqual(self, inp[0], array1[index])
            assertTensorAlmostEqual(self, inp[1], array2[index])
            assertTensorAlmostEqual(self, add[0], array3[index])
            self.assertEqual(add[1], 5)
            self.assertEqual(targ, 7)