Example #1
    def test_reader_with_limit(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        """ 1. feed full dataset """
        src_init = core.Net('src_init')
        src_values = Struct(('label', np.array(range(100))))
        src_blobs = NewRecord(src_init, src_values)
        src_ds = Dataset(src_blobs)
        FeedRecord(src_blobs, src_values, ws)
        ws.run(src_init)

        """ 2. Read with limit smaller than size of dataset """
        dst_init = core.Net('dst_init')
        dst_ds = Dataset(src_values.clone_schema())
        dst_ds.init_empty(dst_init)
        ws.run(dst_init)

        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=10)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertFalse(ws.blobs[str(reader.data_finished())].fetch())
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(10)))

        """ 3. Read with limit larger than size of dataset """
        ws.run(dst_init)
        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=110)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100)))
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())
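
Note: these snippets are test methods lifted from Caffe2's dataio/pipeline tests, so they omit their imports. A minimal sketch of the imports they appear to assume is given below; exact module paths can differ between Caffe2 versions, so treat this as an assumption rather than the canonical header.

import math
import os
import shutil
import tempfile
import unittest

import numpy as np
import numpy.testing as npt

from caffe2.python import core, schema, workspace
from caffe2.python.cached_reader import CachedReader
from caffe2.python.dataio import (
    CompositeReader,
    CompositeReaderBuilder,
    ReaderWithLimit,
)
from caffe2.python.dataset import Dataset
from caffe2.python.db_file_reader import DBFileReader
from caffe2.python.net_builder import ops
from caffe2.python.pipeline import pipe
from caffe2.python.queue_util import Queue
from caffe2.python.schema import (
    FeedRecord,
    FetchRecord,
    InitEmptyRecord,
    NewRecord,
    Struct,
)
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup, WorkspaceType, final_output
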
Example #2
    def test_composite_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_dses = [
            make_source_dataset(ws, offset=offset, size=size, name=name)
            for (name, offset) in zip(names, offsets)
        ]

        data = [ws.fetch_blob(str(src.field_blobs[0])) for src in src_dses]
        # Sanity check we didn't overwrite anything
        for d, offset in zip(data, offsets):
            npt.assert_array_equal(d, range(offset, offset + size))

        # Make an identically-sized empty destination dataset
        dst_ds_schema = schema.Struct(
            *[(name, src_ds.content().clone_schema())
              for name, src_ds in zip(names, src_dses)])
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader = CompositeReader(names,
                                     [src_ds.reader() for src_ds in src_dses])
            pipe(reader, dst_ds.writer(), num_runtime_threads=3)
        session.run(tg)

        for i in range(num_srcs):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[names[i]].label())))
            npt.assert_array_equal(data[i], written_data, "i: {}".format(i))
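
Several examples also call helpers that are not shown: make_source_dataset, make_destination_dataset, init_dataset (in the older snippets) and read_all_data. The sketch below is one plausible shape for them, consistent with how they are used here; it is an assumption, not the original test utilities.

def make_source_dataset(ws, size=100, offset=0, name=None):
    # Build a dataset whose single 'label' field holds the values
    # [offset, offset + size) and feed it into the workspace.
    name = name or 'src'
    src_init = core.Net('{}_init'.format(name))
    with core.NameScope(name):
        src_values = Struct(('label', np.array(range(offset, offset + size))))
        src_blobs = NewRecord(src_init, src_values)
        src_ds = Dataset(src_blobs, name=name)
        FeedRecord(src_blobs, src_values, ws)
    ws.run(src_init)
    return src_ds


# The older snippets call the same kind of helper init_dataset.
init_dataset = make_source_dataset


def make_destination_dataset(ws, schema_, name=None):
    # Create an empty dataset with the given schema, ready to be written to.
    name = name or 'dst'
    dst_init = core.Net('{}_init'.format(name))
    with core.NameScope(name):
        dst_ds = Dataset(schema_, name=name)
        dst_ds.init_empty(dst_init)
    ws.run(dst_init)
    return dst_ds


def read_all_data(ws, reader, session):
    # Drain a reader into a fresh dataset and return its 'label' column.
    dst_ds = make_destination_dataset(ws, reader.schema().clone_schema())
    with TaskGroup() as tg:
        pipe(reader, dst_ds.writer(), num_runtime_threads=8)
    session.run(tg)
    return ws.blobs[str(dst_ds.content().label())].fetch()
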
Example #3
    def test_composite_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_dses = [make_source_dataset(ws, offset=offset, size=size, name=name)
                    for (name, offset) in zip(names, offsets)]

        data = [ws.fetch_blob(str(src.field_blobs[0])) for src in src_dses]
        # Sanity check we didn't overwrite anything
        for d, offset in zip(data, offsets):
            npt.assert_array_equal(d, range(offset, offset + size))

        # Make an identically-sized empty destination dataset
        dst_ds_schema = schema.Struct(
            *[
                (name, src_ds.content().clone_schema())
                for name, src_ds in zip(names, src_dses)
            ]
        )
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader = CompositeReader(names,
                                     [src_ds.reader() for src_ds in src_dses])
            pipe(reader, dst_ds.writer(), num_runtime_threads=3)
        session.run(tg)

        for i in range(num_srcs):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[names[i]].label())))
            npt.assert_array_equal(data[i], written_data, "i: {}".format(i))
Example #4
    def test_reader_with_limit(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        """ 1. feed full dataset """
        src_init = core.Net('src_init')
        src_values = Struct(('label', np.array(range(100))))
        src_blobs = NewRecord(src_init, src_values)
        src_ds = Dataset(src_blobs)
        FeedRecord(src_blobs, src_values, ws)
        ws.run(src_init)
        """ 2. Read with limit smaller than size of dataset """
        dst_init = core.Net('dst_init')
        dst_ds = Dataset(src_values.clone_schema())
        dst_ds.init_empty(dst_init)
        ws.run(dst_init)

        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=10)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertFalse(ws.blobs[str(reader.data_finished())].fetch())
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(10)))
        """ 3. Read with limit larger than size of dataset """
        ws.run(dst_init)
        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=110)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100)))
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())
Example #5
    def test_composite_reader_builder(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_ds_builders = [
            TestReaderBuilder(offset=offset, size=size, name=name)
            for (name, offset) in zip(names, offsets)
        ]

        # Make an identically-sized empty destination dataset
        dst_ds_schema = schema.Struct(
            *[(name, src_ds_builder.schema())
              for name, src_ds_builder in zip(names, src_ds_builders)])
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader_builder = CompositeReaderBuilder(names, src_ds_builders)
            reader_builder.setup(ws=ws)
            pipe(reader_builder.new_reader(),
                 dst_ds.writer(),
                 num_runtime_threads=3)
        session.run(tg)

        for name, offset in zip(names, offsets):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[name].label())))
            npt.assert_array_equal(range(offset, offset + size), written_data,
                                   "name: {}".format(name))
Example #6
    def test_composite_reader_builder(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_ds_builders = [
            TestReaderBuilder(offset=offset, size=size, name=name)
            for (name, offset) in zip(names, offsets)
        ]

        # Make an identically-sized empty destination dataset
        dst_ds_schema = schema.Struct(
            *[
                (name, src_ds_builder.schema())
                for name, src_ds_builder in zip(names, src_ds_builders)
            ]
        )
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader_builder = CompositeReaderBuilder(
                names, src_ds_builders)
            reader_builder.setup(ws=ws)
            pipe(reader_builder.new_reader(), dst_ds.writer(),
                 num_runtime_threads=3)
        session.run(tg)

        for name, offset in zip(names, offsets):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[name].label())))
            npt.assert_array_equal(range(offset, offset + size), written_data,
                                   "name: {}".format(name))
Example #7
    def test_dequeue_many(self):
        init_net = core.Net('init')
        N = 17
        NUM_DEQUEUE_RECORDS = 3
        src_values = Struct(
            ('uid', np.array(range(N))),
            ('value', 0.1 * np.array(range(N))))
        expected_dst = Struct(
            ('uid', 2 * np.array(range(N))),
            ('value', np.array(N * [0.0])))

        with core.NameScope('init'):
            src_blobs = NewRecord(init_net, src_values)
            dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())
            counter = init_net.Const(0)
            ONE = init_net.Const(1)

        def proc1(rec):
            with core.NameScope('proc1'):
                out = NewRecord(ops, rec)
            ops.Add([rec.uid(), rec.uid()], [out.uid()])
            out.value.set(blob=rec.value(), unsafe=True)
            return out

        def proc2(rec):
            with core.NameScope('proc2'):
                out = NewRecord(ops, rec)
            out.uid.set(blob=rec.uid(), unsafe=True)
            ops.Sub([rec.value(), rec.value()], [out.value()])
            ops.Add([counter, ONE], [counter])
            return out

        src_ds = Dataset(src_blobs)
        dst_ds = Dataset(dst_blobs)

        with TaskGroup() as tg:
            out1 = pipe(
                src_ds.reader(),
                output=Queue(
                    capacity=11, num_dequeue_records=NUM_DEQUEUE_RECORDS),
                processor=proc1)
            out2 = pipe(out1, processor=proc2)
            pipe(out2, dst_ds.writer())

        ws = workspace.C.Workspace()
        FeedRecord(src_blobs, src_values, ws)
        session = LocalSession(ws)
        session.run(init_net)
        session.run(tg)
        output = FetchRecord(dst_blobs, ws=ws)
        num_dequeues = ws.blobs[str(counter)].fetch()

        # With N = 17 and NUM_DEQUEUE_RECORDS = 3, the pipeline should have
        # performed ceil(17 / 3) = 6 dequeue operations.
        self.assertEqual(
            num_dequeues, int(math.ceil(float(N) / NUM_DEQUEUE_RECORDS)))

        for a, b in zip(output.field_blobs(), expected_dst.field_blobs()):
            np.testing.assert_array_equal(a, b)
Example #9
    def test_runtime_threads(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_ds = init_dataset(ws)
        totals = [None] * 3

        def proc(rec):
            # executed once
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # executed once per thread
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            # executed on each iteration
            ops.CountUp(counter1)
            ops.CountUp(task_counter)
            # executed once per thread
            with ops.task_instance_exit():
                with ops.loop(ops.RetrieveCount(task_counter)):
                    ops.CountUp(counter2)
                ops.CountUp(counter3)
            # executed once
            with ops.task_exit():
                totals[0] = final_output(ops.RetrieveCount(counter1))
                totals[1] = final_output(ops.RetrieveCount(counter2))
                totals[2] = final_output(ops.RetrieveCount(counter3))
            return rec

        # Read full data set from original reader
        with TaskGroup() as tg:
            pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 100)
        self.assertEqual(totals[1].fetch(), 100)
        self.assertEqual(totals[2].fetch(), 8)

        # Read with a count-limited reader
        with TaskGroup() as tg:
            q1 = pipe(src_ds.reader(), num_runtime_threads=2)
            q2 = pipe(
                ReaderWithLimit(q1.reader(), num_iter=25),
                num_runtime_threads=3)
            pipe(q2, processor=proc, num_runtime_threads=6)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 25)
        self.assertEqual(totals[1].fetch(), 25)
        self.assertEqual(totals[2].fetch(), 6)
Example #10
    def test_runtime_threads(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_ds = init_dataset(ws)
        totals = [None] * 3

        def proc(rec):
            # executed once
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # executed once per thread
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            # executed on each iteration
            ops.CountUp(counter1)
            ops.CountUp(task_counter)
            # executed once per thread
            with ops.task_instance_exit():
                with ops.loop(ops.RetrieveCount(task_counter)):
                    ops.CountUp(counter2)
                ops.CountUp(counter3)
            # executed once
            with ops.task_exit():
                totals[0] = final_output(ops.RetrieveCount(counter1))
                totals[1] = final_output(ops.RetrieveCount(counter2))
                totals[2] = final_output(ops.RetrieveCount(counter3))
            return rec

        """ 1. Feed full dataset """
        with TaskGroup() as tg:
            pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 100)
        self.assertEqual(totals[1].fetch(), 100)
        self.assertEqual(totals[2].fetch(), 8)

        """ 2. Add a few steps in between """
        with TaskGroup() as tg:
            q1 = pipe(src_ds.reader(), num_runtime_threads=2)
            q2 = pipe(
                ReaderWithLimit(q1.reader(), num_iter=25),
                num_runtime_threads=3)
            pipe(q2, processor=proc, num_runtime_threads=6)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 25)
        self.assertEqual(totals[1].fetch(), 25)
        self.assertEqual(totals[2].fetch(), 6)
Example #11
    def test_reader_with_limit(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        """ 1. feed full dataset """
        src_ds = init_dataset(ws)

        """ 2. Read with limit smaller than size of dataset """
        dst_init = core.Net('dst_init')
        with core.NameScope('dst'):
            dst_ds = Dataset(src_ds.content().clone_schema())
            dst_ds.init_empty(dst_init)
        ws.run(dst_init)

        # WorkspaceType.GLOBAL is required because we are fetching
        # reader.data_finished() after the TaskGroup finishes.
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=10)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)

        self.assertFalse(ws.blobs[str(reader.data_finished())].fetch())
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(10))
        )

        """ 3. Read with limit larger than size of dataset """
        ws.run(dst_init)
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=110)
            pipe(reader, dst_ds.writer(), num_runtime_threads=8)
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100))
        )
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())

        """ 4. Read without counter """
        ws.run(dst_init)
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=None)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100))
        )
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())

        """ 5. Read using the same reader without resetting workspace """
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            sorted(list(range(100)) * 2)
        )
Example #12
    def test_local_session(self):
        init_net = core.Net('init')
        src_values = Struct(
            ('uid', np.array([1, 2, 6])),
            ('value', np.array([1.4, 1.6, 1.7])))
        expected_dst = Struct(
            ('uid', np.array([2, 4, 12])),
            ('value', np.array([0.0, 0.0, 0.0])))

        with core.NameScope('init'):
            src_blobs = NewRecord(init_net, src_values)
            dst_blobs = InitEmptyRecord(init_net, src_values.clone_schema())

        def proc1(rec):
            net = core.Net('proc1')
            with core.NameScope('proc1'):
                out = NewRecord(net, rec)
            net.Add([rec.uid(), rec.uid()], [out.uid()])
            out.value.set(blob=rec.value(), unsafe=True)
            return [net], out

        def proc2(rec):
            net = core.Net('proc2')
            with core.NameScope('proc2'):
                out = NewRecord(net, rec)
            out.uid.set(blob=rec.uid(), unsafe=True)
            net.Sub([rec.value(), rec.value()], [out.value()])
            return [net], out

        src_ds = Dataset(src_blobs)
        dst_ds = Dataset(dst_blobs)

        with TaskGroup() as tg:
            out1 = pipe(src_ds.reader(), processor=proc1)
            out2 = pipe(out1, processor=proc2)
            pipe(out2, dst_ds.writer())

        ws = workspace.C.Workspace()
        FeedRecord(src_blobs, src_values, ws)
        session = LocalSession(ws)
        session.run(init_net)
        session.run(tg)
        output = FetchRecord(dst_blobs, ws=ws)

        for a, b in zip(output.field_blobs(), expected_dst.field_blobs()):
            np.testing.assert_array_equal(a, b)
Example #14
    def test_db_file_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Build a cache DB file.
        cached_reader = CachedReader(
            self._build_source_reader(ws, 100),
            db_path=db_path,
            db_type='LevelDB',
        )
        build_cache_step = cached_reader.build_cache_step()
        session.run(build_cache_step)

        # Read data from cache DB file.
        db_file_reader = DBFileReader(
            db_path=db_path,
            db_type='LevelDB',
        )
        data = self._read_all_data(ws, db_file_reader, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)
Example #15
    def test_db_file_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Build a cache DB file.
        cached_reader = CachedReader(
            self._build_source_reader(ws, 100),
            db_path=db_path,
            db_type='LevelDB',
        )
        build_cache_step = cached_reader.build_cache_step()
        session.run(build_cache_step)

        # Read data from cache DB file.
        workspace.ResetWorkspace()
        db_file_reader = DBFileReader(
            db_path=db_path,
            db_type='LevelDB',
        )
        data = self._read_all_data(ws, db_file_reader, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)
Example #16
    def test_cached_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        def build_source_reader(size):
            src_ds = init_dataset(ws, size)
            return src_ds.reader()

        with tempfile.NamedTemporaryFile(delete=False) as f:
            path = f.name
            f.close()
            os.remove(path)

            """ 1. Read data for the first time. """
            cached_reader1 = CachedReader(build_source_reader(100))
            init_step = cached_reader1.build_cache(path)
            session.run(init_step)

            data = read_all_data(ws, cached_reader1, session)
            self.assertEqual(sorted(data), list(range(100)))

            """ 2. Read data from cache. """
            workspace.ResetWorkspace()
            cached_reader2 = CachedReader(build_source_reader(200))
            init_step = cached_reader2.build_cache(path)
            session.run(init_step)

            data = read_all_data(ws, cached_reader2, session)
            self.assertEqual(sorted(data), list(range(100)))

            shutil.rmtree(path)

            """ 3. We removed cache so we expect to receive data from original
            reader. """
            workspace.ResetWorkspace()
            cached_reader3 = CachedReader(build_source_reader(300))
            init_step = cached_reader3.build_cache(path)
            session.run(init_step)

            data = read_all_data(ws, cached_reader3, session)
            self.assertEqual(sorted(data), list(range(300)))

            shutil.rmtree(path)
Example #17
    def test_cached_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        def build_source_reader(size):
            src_ds = make_source_dataset(ws, size)
            return src_ds.reader()

        # Make a temp file path as cache_path
        with tempfile.NamedTemporaryFile(delete=False) as f:
            cache_path = f.name
            f.close()
            os.remove(cache_path)

        # Read data for the first time.
        cached_reader1 = CachedReader(build_source_reader(100))
        init_step = cached_reader1.build_cache(cache_path)
        session.run(init_step)

        data = read_all_data(ws, cached_reader1, session)
        self.assertEqual(sorted(data), list(range(100)))

        # Read data from cache.
        workspace.ResetWorkspace()
        cached_reader2 = CachedReader(build_source_reader(200))
        init_step = cached_reader2.build_cache(cache_path)
        session.run(init_step)

        data = read_all_data(ws, cached_reader2, session)
        self.assertEqual(sorted(data), list(range(100)))

        shutil.rmtree(cache_path)

        # We removed cache so we expect to receive data from original reader
        workspace.ResetWorkspace()
        cached_reader3 = CachedReader(build_source_reader(300))
        init_step = cached_reader3.build_cache(cache_path)
        session.run(init_step)

        data = read_all_data(ws, cached_reader3, session)
        self.assertEqual(sorted(data), list(range(300)))

        shutil.rmtree(cache_path)
Example #18
    def test_cached_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Read data for the first time.
        cached_reader1 = CachedReader(
            self._build_source_reader(ws, 100),
            db_path,
        )
        build_cache_step = cached_reader1.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader1, session)
        self.assertEqual(sorted(data), list(range(100)))

        # Read data from cache.
        workspace.ResetWorkspace()
        cached_reader2 = CachedReader(
            self._build_source_reader(ws, 200),
            db_path,
        )
        build_cache_step = cached_reader2.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader2, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)

        # We removed cache so we expect to receive data from original reader.
        workspace.ResetWorkspace()
        cached_reader3 = CachedReader(
            self._build_source_reader(ws, 300),
            db_path,
        )
        build_cache_step = cached_reader3.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader3, session)
        self.assertEqual(sorted(data), list(range(300)))

        self._delete_path(db_path)
Example #19
    def test_cached_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Read data for the first time.
        cached_reader1 = CachedReader(
            self._build_source_reader(ws, 100), db_path,
        )
        build_cache_step = cached_reader1.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader1, session)
        self.assertEqual(sorted(data), list(range(100)))

        # Read data from cache.
        workspace.ResetWorkspace()
        cached_reader2 = CachedReader(
            self._build_source_reader(ws, 200), db_path,
        )
        build_cache_step = cached_reader2.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader2, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)

        # We removed cache so we expect to receive data from original reader.
        workspace.ResetWorkspace()
        cached_reader3 = CachedReader(
            self._build_source_reader(ws, 300), db_path,
        )
        build_cache_step = cached_reader3.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader3, session)
        self.assertEqual(sorted(data), list(range(300)))

        self._delete_path(db_path)
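
Examples #14, #15, #18 and #19 additionally use test-class helpers that are not shown: _make_temp_path, _build_source_reader, _read_all_data and _delete_path. The class below is a hypothetical sketch of such a base class, written against the make_source_dataset/read_all_data sketches above; the name ReaderTestCase and the helper bodies are assumptions, not the original code.

class ReaderTestCase(unittest.TestCase):
    # Hypothetical base class for the CachedReader/DBFileReader examples.

    def _make_temp_path(self):
        # Reserve a unique temp path for the cache DB, then free it so the
        # DB implementation can create it.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            path = f.name
        os.remove(path)
        return path

    def _build_source_reader(self, ws, size):
        return make_source_dataset(ws, size=size).reader()

    def _read_all_data(self, ws, reader, session):
        return read_all_data(ws, reader, session)

    def _delete_path(self, path):
        # LevelDB caches are directories; fall back to unlink for plain files.
        if os.path.isdir(path):
            shutil.rmtree(path)
        elif os.path.exists(path):
            os.remove(path)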