Esempio n. 1
0
    def test_sql(self, db_exists, pool):
        """Compare a dataset chosen by SQL with one chosen by explicit gameids.

        One dataset is built from a ``subselect_sql`` query, the other from
        ``gameids=[6, 7]``. The metadata of game 6 and every array of every
        paired minibatch must be identical between the two.
        """
        query = """
            SELECT gameid FROM datasets
            WHERE gameid >=6 AND dataset_name='basictest'
            """
        # Settings shared by both datasets; only the selection method differs.
        shared = dict(
            seq_length=100,
            batch_size=2,
            threadpool=pool,
            shuffle=False,
        )
        data = dataset.TtyrecDataset("basictest", subselect_sql=query, **shared)
        data2 = dataset.TtyrecDataset("basictest", gameids=[6, 7], **shared)

        for ds in (data, data2):
            ds.populate_metadata()

        assert len(data.get_meta(6)) == 27
        assert data.get_meta(6) == data2.get_meta(6)

        # Walk both datasets in lockstep and compare each named array.
        for sql_batch, id_batch in zip(data, data2):
            for key in sql_batch:
                np.testing.assert_array_equal(sql_batch[key], id_batch[key])
Esempio n. 2
0
    def test_start_and_end(self, db_exists):
        """Check that identical ttyrecs give identical rows and others differ.

        The first minibatch of a dataset over gameids 1-3 must have three
        equal rows for every key except "gameids", and its first row must
        differ from the first row of a dataset built only from gameid 4.
        """
        # gameids [1,2,3] are all the same tty_rec, rowid 4 is different
        same_games = dataset.TtyrecDataset(
            "basictest",
            seq_length=100,
            batch_size=3,
            gameids=range(1, 4),
            shuffle=False,
        )
        other_game = dataset.TtyrecDataset(
            "basictest",
            seq_length=100,
            batch_size=1,
            gameids=[4, 4, 4],
            shuffle=False,
        )

        same_batch = next(iter(same_games))
        other_batch = next(iter(other_game))

        for key in same_batch:
            if key == "gameids":
                continue
            rows = same_batch[key]
            other_rows = other_batch[key]
            # The three rows of the first dataset are the same, and they
            # differ from the row of the second dataset.
            reference = rows[0]
            np.testing.assert_array_equal(reference, rows[1])
            np.testing.assert_array_equal(reference, rows[2])
            with np.testing.assert_raises(AssertionError):
                np.testing.assert_array_equal(reference, other_rows[0])
Esempio n. 3
0
    def test_get_ttyrec(self, db_exists, pool):
        """Check that ``get_ttyrec`` chunks reproduce the minibatch data.

        For chunk sizes 5 and 50, the chunks of game 4 are concatenated along
        axis 1 and compared with row 0 of the first minibatch up to the first
        "done" index; everything from that index onwards must equal 0.
        """
        data = dataset.TtyrecDataset(
            "basictest",
            seq_length=100,
            batch_size=1,
            gameids=[4, 5],
            threadpool=pool,
            shuffle=False,
        )

        mb = next(iter(data))
        # Index of the first "done" flag in the only batch row.
        done_positions = np.where(mb["done"][0] == 1)[0]
        reset = done_positions[0]
        assert reset == 31

        for chunk_size in (5, 50):
            chunks = data.get_ttyrec(4, chunk_size=chunk_size)

            # Gather the pieces of every key in the order they are yielded.
            pieces = {key: [] for key in mb}
            for chunk in chunks:
                for key in chunk:
                    pieces[key].append(chunk[key])

            joined = {}
            for key, parts in pieces.items():
                joined[key] = np.concatenate(parts, axis=1)

            for key in mb:
                from_chunks = joined[key][0]
                from_batch = mb[key][0]
                np.testing.assert_array_equal(from_chunks[:reset],
                                              from_batch[:reset])
                np.testing.assert_equal(from_chunks[reset:], 0)
Esempio n. 4
0
 def test_dataset_metadata(self, db_exists, pool):
     """Check that game metadata can be loaded and read back by row id.

     The dataset is restricted via ``subselect_sql`` to the games of user
     'ccc' in 'basictest'. After ``populate_metadata()``, the metadata of
     the first game id in the first minibatch must expose the same values
     by position and by column name.
     """
     # Test we can retrieve metadata from database and access it later
     # Everything after the gameid, path should be stored as a list in metadata
     #
     # The name is a single-quoted string literal: double quotes denote
     # identifiers in SQL and only work as strings through a legacy SQLite
     # fallback that may be disabled.
     sql1 = """
         SELECT games.gameid FROM games
         INNER JOIN datasets
         ON datasets.gameid = games.gameid
         WHERE datasets.dataset_name='basictest'
         AND games.name='ccc'
         """
     data1 = dataset.TtyrecDataset(
         "basictest",
         seq_length=100,
         batch_size=1,
         threadpool=pool,
         subselect_sql=sql1,
     )
     data1.populate_metadata()
     gameids = next(iter(data1))["gameids"]
     rowid = gameids[0][0]

     # These are SQLiteRows, which can be accessed like lists or
     # dictionaries. One lookup is enough: the row id does not change, so
     # repeating the same assertions in a loop added nothing.
     meta = data1.get_meta(rowid)
     assert meta[17] == "ascended"
     assert meta["death"] == "ascended"
     assert meta[2] == 999
     assert meta["points"] == 999
Esempio n. 5
0
    def test_char_frame(self, db_exists, pool):
        """Compare decoded frames with reference cursor, time and frame data.

        Each step before the first "done" flag is checked against the saved
        cursor positions and timestamps. When a "done" flag appears, the
        frame just before it is compared with the saved final frame
        (characters and colors) and the test stops.
        """
        # gameids [1,2,3] are all the same tty_rec, rowid 4 is different
        seq_length = 20
        data = dataset.TtyrecDataset(
            "basictest",
            seq_length=seq_length,
            rows=25,
            batch_size=3,
            gameids=[1, 2, 3, 4, 1, 2, 3, 4],
            threadpool=pool,
            shuffle=False,
        )

        with open(getfilename(FINALFRAME), "r") as frame_file:
            expected_lines = [text.rstrip() for text in frame_file]

        with open(getfilename(FINALFRAMECOLORS), "r") as color_file:
            expected_colors = [
                text.rstrip().split(",") for text in color_file
            ]

        with open(getfilename(COLSROWS)) as cursor_file:
            expected_cursors = [
                tuple(int(token) for token in text.split())
                for text in cursor_file
            ]

        with bz2.BZ2File(getfilename(TIMESTAMPS)) as ts_file:
            expected_ts = [float(raw) * 1e6 for raw in ts_file]

        frame_idx = 0  # position in the reference cursor/timestamp lists
        for mb in data:
            chars, colors, curs, ts, done, _ = mb.values()
            batch_ids, seq_ids = np.where(done == 1)
            # NOTE(review): `.any()` is falsy when the only index is 0, so
            # this presumably relies on "done" never being set at step 0.
            found_done = seq_ids.any()
            steps = seq_ids[0] if found_done else seq_length

            for step in range(steps):
                reference = expected_cursors[frame_idx]
                # The reference stores (col, row); curs stores (row, col).
                np.testing.assert_array_equal(reference[0], curs[:, step, 1])
                np.testing.assert_array_equal(reference[1], curs[:, step, 0])
                np.testing.assert_array_almost_equal(expected_ts[frame_idx],
                                                     ts[:, step],
                                                     decimal=0)
                frame_idx += 1

            if found_done:
                # All three batch rows must finish at the same step.
                np.testing.assert_array_equal(batch_ids, np.arange(3))
                np.testing.assert_array_equal(seq_ids, seq_ids[0])

                last_step = seq_ids[0] - 1
                final_frame = chars[0][last_step]
                char_frame = []
                for row in final_frame:
                    char_frame.append(row.tobytes().decode("utf-8").rstrip())
                color_frame = colors[0][last_step]

                for actual, expected in zip(char_frame, expected_lines):
                    assert actual == expected
                for actual_row, expected_row in zip(color_frame,
                                                    expected_colors):
                    for got, want in zip(actual_row, expected_row):
                        assert got == int(want)
                break

            if frame_idx > 2431:
                raise AssertionError  # this should have ended!
Esempio n. 6
0
    def test_dataset_gameids(self, db_exists, pool):
        """Check that a batch of 7 over gameids 1-7 contains each id once.

        The ids are passed in random order; the first minibatch must have 6
        entries, and the unique values of column 0 of its "gameids" array
        must be exactly 1 through 7.
        """
        # Random ordering of the ids 1..7 without repetition.
        shuffled_ids = np.random.choice(7, 7, replace=False) + 1
        data = dataset.TtyrecDataset(
            "basictest",
            batch_size=7,
            threadpool=pool,
            gameids=shuffled_ids,
        )
        first_batch = next(iter(data))
        batch_gameids = first_batch["gameids"]
        assert len(first_batch) == 6

        seen_ids = np.unique(batch_gameids[:, 0])
        np.testing.assert_array_equal(seen_ids, np.arange(1, 8))
Esempio n. 7
0
    def test_multipart_game(self, db_exists, pool):
        """Compare a multipart game with the same parts as separate games.

        ``data1`` reads the game selected by the SQL as is. ``data2`` is
        patched so that the first three parts of that game are registered
        under three synthetic game ids. Apart from "gameids" and "done",
        every array of each paired minibatch must match.
        """
        # This test selects a multipart game

        # Use sql to select all of user "ccc" ttyrecs (3 ttyrecs, gameid 7)
        # as one game
        # eg: [(7, 0, /path/to/A), (7, 1, /path/to/B), (7, 2, /path/to/C)]
        sql1 = "SELECT gameid FROM games WHERE games.name='ccc'"
        shared = dict(
            seq_length=100,
            batch_size=1,
            threadpool=pool,
            subselect_sql=sql1,
        )
        data1 = dataset.TtyrecDataset("basictest", **shared)

        # Hack data2 to treat the user "ccc" ttyrecs (7, 8, 9) as separate
        # games
        # eg: [(0, 7, /path/to/A), (1, 7, /path/to/B), (2, 7, /path/to/C)]
        data2 = dataset.TtyrecDataset("basictest", **shared)
        ccc_gameid = data2._gameids[0]
        synthetic_ids = [1000, 2000, 3000]
        for part_idx, fake_id in enumerate(synthetic_ids):
            data2._games[fake_id] = [data2._games[ccc_gameid][part_idx]]
        data2._gameids = synthetic_ids

        for whole, split in zip(data1, data2):
            for key in whole:
                value = whole[key]
                if key == "gameids":
                    assert np.all(np.isin(value, [0, 7]))
                elif key == "done":
                    # The multipart game never reports "done" in this run.
                    np.testing.assert_equal(value, np.zeros_like(value))
                else:
                    np.testing.assert_array_equal(value, split[key])
Esempio n. 8
0
 def test_no_looping(self, db_exists, pool, batch_size):
     """Count "done" flags over a full pass of gameids 1-7.

     The total number of "done" flags must equal ``7 - batch_size``. The
     loop stops early as soon as the count exceeds that, so the final
     assertion then fails.
     """
     # We expect as many done as (len(gameids) - batchsize)
     expected_done = 7 - batch_size
     data = dataset.TtyrecDataset(
         "basictest",
         seq_length=100,
         batch_size=batch_size,
         threadpool=pool,
         gameids=range(1, 8),
     )

     done_count = 0
     for mb in data:
         done_count += np.sum(mb["done"] == 1)
         if done_count > expected_done:
             break

     assert done_count == expected_done
Esempio n. 9
0
    def test_shuffle(self, db_exists, batch_size):
        """Check that the ``shuffle`` attribute controls iteration order.

        With ``shuffle=False`` two full passes must give identical data and
        game ids. With ``shuffle=True`` the game ids must differ from the
        unshuffled pass and between two shuffled passes. Switching back to
        ``False`` must reproduce the first pass.
        """
        data = dataset.TtyrecDataset(
            "basictest",
            seq_length=1000,
            batch_size=batch_size,
            threadpool=None,
            gameids=(1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1),
            shuffle=False,
        )

        def get_data():
            """Run one pass and join all minibatches into full arrays."""
            per_batch = []
            per_batch_ids = []
            for mb in data:
                # Copy each array and keep it for the later concatenation.
                per_batch.append([arr.copy() for arr in mb.values()])
                per_batch_ids.append(mb["gameids"].copy())
            joined = [
                np.concatenate(parts, axis=1) for parts in zip(*per_batch)
            ]
            return joined, np.concatenate(per_batch_ids)

        def assert_differ(left, right):
            """Fail unless the two arrays are different."""
            with np.testing.assert_raises(AssertionError):
                np.testing.assert_array_equal(left, right)

        # Two unshuffled passes are identical.
        b1, g1 = get_data()
        b2, g2 = get_data()
        for first, second in zip(b1, b2):
            np.testing.assert_array_equal(first, second)
        np.testing.assert_array_equal(g1, g2)

        # Shuffled passes differ from the unshuffled one and from each other.
        data.shuffle = True
        _, g3 = get_data()
        _, g4 = get_data()
        # Note: the last tensor in the batch is the gameid
        assert_differ(g3, g2)
        assert_differ(g3, g4)

        # Turning shuffling off restores the original order.
        data.shuffle = False
        b5, g5 = get_data()
        for first, fifth in zip(b1, b5):
            np.testing.assert_array_equal(first, fifth)
        np.testing.assert_array_equal(g1, g5)
Esempio n. 10
0
    def test_minibatches(self, db_exists, pool):
        """Check row contents and reset handling in the first minibatch.

        Rows 0-2 must match each other and differ from row 3. Row 3 must
        have its first "done" flag at index 31, and the data following that
        index must repeat the start of the row. No row may begin with a
        "done" flag.
        """
        data = dataset.TtyrecDataset(
            "basictest",
            seq_length=50,
            batch_size=4,
            threadpool=pool,
            gameids=range(1, 8),
            shuffle=False,
        )
        # starting gameids = [TTYREC, TTYREC, TTYREC, TTYREC2]

        mb = next(iter(data))
        for name, array in mb.items():
            if name == "gameids":
                continue
            # Test first three rows are the same, and differ from the fourth
            first_row = array[0]
            np.testing.assert_array_equal(first_row, array[1])
            np.testing.assert_array_equal(first_row, array[2])
            with np.testing.assert_raises(AssertionError):
                np.testing.assert_array_equal(first_row, array[3])

        # Check resetting occurred
        done_positions = np.where(mb["done"][3] == 1)[0]
        reset = done_positions[0]
        assert reset == 31

        # Check the data at location is the same. Note reset occurs for batch 4
        seq = 10
        skipped = ("done", "gameids")
        for name, array in mb.items():
            if name in skipped:
                continue
            fourth_row = array[3]
            np.testing.assert_array_equal(fourth_row[:seq],
                                          fourth_row[reset:reset + seq])

        # No leading 1s
        assert np.all(mb["done"][:, 0] == 0)