示例#1
0
    def test_nle_v3_conversion(self):
        """Convert the NLE v3 ttyrec fixture and check frame 44, actions and scores.

        With a 70-frame buffer, ``convert`` is expected to return 1. Frame 44
        (index 43) is compared row-by-row against the saved text dump, and the
        full action and score buffers are compared against the first and
        second lines of the actions/scores fixture.
        """
        seq_length = 70
        # This recording uses a 120-column screen instead of the module default.
        COLUMNS = 120
        converter = Converter(ROWS, COLUMNS, TTYREC_V3)

        chars = np.zeros((seq_length, ROWS, COLUMNS), dtype=np.uint8)
        colors = np.zeros((seq_length, ROWS, COLUMNS), dtype=np.int8)
        cursors = np.zeros((seq_length, 2), dtype=np.int16)
        actions = np.zeros((seq_length), dtype=np.uint8)
        timestamps = np.zeros((seq_length, ), dtype=np.int64)
        scores = np.zeros((seq_length), dtype=np.int32)

        converter.load_ttyrec(getfilename(TTYREC_NLE_V3))
        assert (converter.convert(chars, colors, cursors, timestamps, actions,
                                  scores) == 1)

        # Frame 44 in 1-based numbering, as in the fixture name.
        frame_44 = chars[43]
        # Read fixtures as UTF-8 explicitly so the comparison with the
        # UTF-8-decoded rows does not depend on the platform locale.
        frame_path = getfilename(TTYREC_NLE_V3_FRAME_44)
        with open(frame_path, encoding="utf-8") as f:
            for row, line in enumerate(f):
                actual = frame_44[row].tobytes().decode("utf-8").rstrip()
                assert actual == line.rstrip()

        scores_path = getfilename(TTYREC_NLE_V3_ACTIONS_SCORES)
        with open(scores_path, encoding="utf-8") as f:
            lines = f.readlines()
        # Line 0: space-separated actions; line 1: space-separated scores.
        assert " ".join("%i" % a for a in actions) == lines[0].rstrip()
        assert " ".join("%i" % s for s in scores) == lines[1].rstrip()
示例#2
0
    def test_ttyrec_with_extra_data(self, seq_length=500):
        """Convert the ``TTYREC_2018`` recording and check the returned count.

        All buffers are sized for ``seq_length`` frames. A single ``convert``
        call on this recording must return 165.
        """
        frame_shape = (seq_length, ROWS, COLUMNS)
        char_buf = np.zeros(frame_shape, dtype=np.uint8)
        color_buf = np.zeros(frame_shape, dtype=np.int8)
        cursor_buf = np.zeros((seq_length, 2), dtype=np.uint16)
        time_buf = np.zeros((seq_length,), dtype=np.int64)
        action_buf = np.zeros(seq_length, dtype=np.uint8)

        conv = Converter(ROWS, COLUMNS)
        conv.load_ttyrec(getfilename(TTYREC_2018))
        leftover = conv.convert(char_buf, color_buf, cursor_buf, time_buf,
                                action_buf)
        assert leftover == 165
示例#3
0
    def test_data(self):
        """Convert the 2020 ttyrec in SEQ_LENGTH chunks and check every frame.

        Each converted frame's cursor is compared with the COLSROWS fixture
        ("col row" per line). Its timestamp, scaled by 1e-6, is compared with
        the bz2-compressed TIMESTAMPS fixture. After the final chunk, the last
        frame's characters and colours are compared with FINALFRAME and
        FINALFRAMECOLORS.
        """
        converter = Converter(ROWS, COLUMNS, TTYREC_V1)
        assert converter.rows == ROWS
        assert converter.cols == COLUMNS

        chars = np.zeros((SEQ_LENGTH, ROWS, COLUMNS), dtype=np.uint8)
        colors = np.zeros((SEQ_LENGTH, ROWS, COLUMNS), dtype=np.int8)
        cursors = np.zeros((SEQ_LENGTH, 2), dtype=np.int16)
        timestamps = np.zeros((SEQ_LENGTH, ), dtype=np.int64)
        actions = np.zeros((SEQ_LENGTH), dtype=np.uint8)
        scores = np.zeros((SEQ_LENGTH), dtype=np.int32)

        converter.load_ttyrec(getfilename(TTYREC_2020))
        # Read text fixtures as UTF-8 explicitly rather than the locale default.
        with open(getfilename(COLSROWS), encoding="utf-8") as f:
            colsrows = [tuple(int(i) for i in line.split()) for line in f]

        # BZ2File yields bytes lines; float() accepts bytes directly.
        with bz2.BZ2File(getfilename(TIMESTAMPS)) as f:
            saved_timestamps = [float(line) for line in f]

        num_frames = 0
        while True:
            remaining = converter.convert(chars, colors, cursors, timestamps,
                                          actions, scores)
            # Only the first `filled` slots were written by this call.
            filled = SEQ_LENGTH - remaining
            for (row, col), ts in zip(cursors[:filled], timestamps[:filled]):
                # Fail with a clear message instead of an IndexError below.
                assert num_frames < len(colsrows), (
                    "converter produced more frames than COLSROWS lists")
                assert (col, row) == colsrows[num_frames]
                assert pytest.approx(float(ts) /
                                     1e6) == saved_timestamps[num_frames]
                # Cursor col == converter.cols when it is offscreen (i.e. cropped).
                assert 0 <= col <= converter.cols
                assert 0 <= row < converter.rows
                num_frames += 1
            # Stop once convert() leaves slots unused (presumably end of input).
            if remaining > 0:
                break

        assert num_frames == len(colsrows)
        # Index of the last slot written by the final convert() call.
        final_index = SEQ_LENGTH - remaining - 1
        with open(getfilename(FINALFRAME), encoding="utf-8") as f:
            for row, line in enumerate(f):

                actual = chars[final_index][row].tobytes().decode(
                    "utf-8").rstrip()
                assert actual == line.rstrip()
        # Each fixture line is one row of colour codes, comma-separated.
        with open(getfilename(FINALFRAMECOLORS), encoding="utf-8") as f:
            for row, line in enumerate(f):
                actual = ",".join(str(c) for c in colors[final_index][row])
                assert actual == line.rstrip()
示例#4
0
    def test_ibm_graphics(self):
        """Convert the TTYREC_IBMGRAPHICS recording and check its last frame.

        ``convert`` is expected to return 0 for the 10-frame buffer. The last
        frame is then compared row-by-row with the saved frame-10 text dump.
        """
        seq_length = 10
        converter = Converter(ROWS, COLUMNS)

        chars = np.zeros((seq_length, ROWS, COLUMNS), dtype=np.uint8)
        colors = np.zeros((seq_length, ROWS, COLUMNS), dtype=np.int8)
        cursors = np.zeros((seq_length, 2), dtype=np.uint16)
        actions = np.zeros((seq_length), dtype=np.uint8)
        timestamps = np.zeros((seq_length, ), dtype=np.int64)

        converter.load_ttyrec(getfilename(TTYREC_IBMGRAPHICS))
        assert converter.convert(chars, colors, cursors, timestamps,
                                 actions) == 0

        # Read the expected frame as UTF-8 explicitly, matching how the
        # converted rows are decoded, rather than the locale's default.
        expected_path = getfilename(TTYREC_IBMGRAPHICS_FRAME_10)
        with open(expected_path, encoding="utf-8") as f:
            for row, line in enumerate(f):
                actual = chars[-1][row].tobytes().decode("utf-8").rstrip()
                assert actual == line.rstrip()
示例#5
0
    def test_shiftin_shiftout_graphics(self):
        """Convert the TTYREC_SHIFTIN recording and check its 10th frame.

        A 120-column screen and the TTYREC_V1 format are used. ``convert`` is
        expected to return 0 for the 10-frame buffer, and frame 10 (index 9)
        is compared row-by-row with the saved text dump.
        """
        seq_length = 10
        # This recording uses a 120-column screen instead of the module default.
        COLUMNS = 120
        converter = Converter(ROWS, COLUMNS, TTYREC_V1)

        chars = np.zeros((seq_length, ROWS, COLUMNS), dtype=np.uint8)
        colors = np.zeros((seq_length, ROWS, COLUMNS), dtype=np.int8)
        cursors = np.zeros((seq_length, 2), dtype=np.int16)
        actions = np.zeros((seq_length), dtype=np.uint8)
        timestamps = np.zeros((seq_length, ), dtype=np.int64)
        scores = np.zeros((seq_length), dtype=np.int32)

        converter.load_ttyrec(getfilename(TTYREC_SHIFTIN))
        assert (converter.convert(chars, colors, cursors, timestamps, actions,
                                  scores) == 0)

        # Read the expected frame as UTF-8 explicitly, matching how the
        # converted rows are decoded, rather than the locale's default.
        expected_path = getfilename(TTYREC_SHIFTIN_FRAME_10)
        with open(expected_path, encoding="utf-8") as f:
            for row, line in enumerate(f):
                actual = chars[9][row].tobytes().decode("utf-8").rstrip()
                assert actual == line.rstrip()
示例#6
0
    def test_illegal_buffers(self):
        """Check that convert() raises ValueError for malformed output buffers.

        Cases cover wrong shape, wrong number of dimensions, wrong dtype, and
        non-ndarray arguments. Each case reassigns only the buffers it changes.
        All other buffers keep the value left by an earlier case, so the cases
        depend on their order. The shape messages hard-code 25 and 80, i.e.
        they assume ROWS == 25 and COLUMNS == 80.
        """
        converter = Converter(ROWS, COLUMNS, TTYREC_V1)
        converter.load_ttyrec(getfilename(TTYREC_2020))

        # Case: cursors has 11 frames while every other buffer has 10.
        chars = np.zeros((10, ROWS, COLUMNS), dtype=np.uint8)
        colors = np.zeros((10, ROWS, COLUMNS), dtype=np.int8)
        cursors = np.zeros((11, 2), dtype=np.int16)
        actions = np.zeros((10), dtype=np.uint8)
        timestamps = np.zeros((10, ), dtype=np.int64)
        scores = np.zeros((10), dtype=np.int32)

        with pytest.raises(
                ValueError,
                match=re.escape(
                    "Array has wrong shape (expected [ 10 2 ], got [ 11 2 ])"),
        ):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: colors is one column short; chars and cursors are valid again.
        chars = np.zeros((10, ROWS, COLUMNS), dtype=np.uint8)
        colors = np.zeros((10, ROWS, COLUMNS - 1), dtype=np.int8)
        cursors = np.zeros((10, 2), dtype=np.int16)
        with pytest.raises(
                ValueError,
                match=re.escape(
                    "Array has wrong shape (expected [ 10 25 80 ], got [ 10 25 79 ])"
                ),
        ):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: chars and colors are both one row short.
        chars = np.zeros((10, ROWS - 1, COLUMNS), dtype=np.uint8)
        colors = np.zeros((10, ROWS - 1, COLUMNS), dtype=np.int8)
        cursors = np.zeros((10, 2), dtype=np.int16)
        with pytest.raises(
                ValueError,
                match=re.escape(
                    "Array has wrong shape (expected [ 10 25 80 ], got [ 10 24 80 ])"
                ),
        ):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: chars/colors/cursors have 11 frames but the 1-D buffers
        # (timestamps, actions, and scores from the first case) have 10.
        # Presumably the expected length is taken from the leading buffers.
        chars = np.zeros((11, ROWS, COLUMNS), dtype=np.uint8)
        colors = np.zeros((11, ROWS, COLUMNS), dtype=np.int8)
        cursors = np.zeros((11, 2), dtype=np.int16)
        actions = np.zeros((10), dtype=np.uint8)
        timestamps = np.zeros((10, ), dtype=np.int64)
        with pytest.raises(
                ValueError,
                match=re.escape(
                    "Array has wrong shape (expected [ 11 ], got [ 10 ])"),
        ):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: back to 10 frames everywhere, but cursors has 3 columns, not 2.
        chars = np.zeros((10, ROWS, COLUMNS), dtype=np.uint8)
        colors = np.zeros((10, ROWS, COLUMNS), dtype=np.int8)
        cursors = np.zeros((10, 3), dtype=np.int16)
        with pytest.raises(
                ValueError,
                match=re.escape(
                    "Array has wrong shape (expected [ 10 2 ], got [ 10 3 ])"),
        ):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: chars has an extra (4th) dimension.
        chars = np.zeros((10, ROWS, COLUMNS, 7), dtype=np.uint8)
        cursors = np.zeros((10, 2), dtype=np.int16)
        with pytest.raises(
                ValueError,
                match=
                r"Array has wrong number of dimensions \(expected 3, got 4\)",
        ):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: chars valid again; cursors has an extra (3rd) dimension.
        chars = np.zeros((10, ROWS, COLUMNS), dtype=np.uint8)
        cursors = np.zeros((10, 2, 1), dtype=np.int16)
        with pytest.raises(
                ValueError,
                match=
                r"Array has wrong number of dimensions \(expected 2, got 3\)",
        ):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: chars is float32 and cursors uint8. Both differ from the dtypes
        # used in the valid setup, and either mismatch satisfies this match.
        chars = np.zeros((10, ROWS, COLUMNS), dtype=np.float32)
        cursors = np.zeros((10, 2), dtype=np.uint8)
        with pytest.raises(ValueError, match=r"Buffer dtype mismatch"):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: chars valid again; only cursors has the wrong dtype (uint8).
        chars = np.zeros((10, ROWS, COLUMNS), dtype=np.uint8)
        cursors = np.zeros((10, 2), dtype=np.uint8)
        with pytest.raises(ValueError, match=r"Buffer dtype mismatch"):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: chars is a str rather than an ndarray.
        chars = "Hello"
        cursors = np.zeros((10, 2), dtype=np.int16)
        with pytest.raises(ValueError, match=r"Numpy array required"):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: chars is a NumPy scalar, which is still not an ndarray.
        chars = np.uint(8)
        with pytest.raises(ValueError, match=r"Numpy array required"):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)

        # Case: chars valid again; timestamps is a plain Python list.
        chars = np.zeros((10, ROWS, COLUMNS), dtype=np.uint8)
        timestamps = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        with pytest.raises(ValueError, match=r"Numpy array required"):
            converter.convert(chars, colors, cursors, timestamps, actions,
                              scores)