Example 1
    def test_can_cast_dtype(self):
        # Use a single-use generator so that, once it is exhausted, subsequent lookups must be served from the cache
        def load_data():
            yield {"X": np.ones((1, 1), dtype=np.float32)}

        cache = DataLoaderCache(load_data())

        fp32_meta = TensorMetadata().add("X", dtype=np.float32, shape=(1, 1))
        cache.set_input_metadata(fp32_meta)
        feed_dict = cache[0]
        assert feed_dict["X"].dtype == np.float32

        fp64_meta = TensorMetadata().add("X", dtype=np.float64, shape=(1, 1))
        cache.set_input_metadata(fp64_meta)
        feed_dict = cache[0]
        assert feed_dict["X"].dtype == np.float64
Example 2
    def test_will_not_give_up_on_first_cache_miss(self):
        SHAPE = (32, 32)

        DATA = [OrderedDict()]
        DATA[0]["X"] = np.zeros(SHAPE, dtype=np.int64)
        DATA[0]["Y"] = np.zeros(SHAPE, dtype=np.int64)

        cache = DataLoaderCache(DATA)
        cache.set_input_metadata(TensorMetadata().add("X", dtype=np.int64, shape=SHAPE).add("Y", dtype=np.int64, shape=SHAPE))

        # Pre-populate the cache with an X of the wrong shape but a correctly shaped Y
        cache.cache[0] = OrderedDict()
        cache.cache[0]["X"] = np.ones((64, 64), dtype=np.int64)
        cache.cache[0]["Y"] = np.ones(SHAPE, dtype=np.int64)

        feed_dict = cache[0]
        # Cache cannot reuse X, so it'll reload - we'll get all 0s from the data loader
        assert np.all(feed_dict["X"] == 0)
        # Cache can reuse Y, even though it's after X, so we'll get ones from the cache
        assert np.all(feed_dict["Y"] == 1)