def test_nle_v3_conversion(self):
    """Convert an NLE v3 ttyrec and check it against stored references.

    The recording is converted into a 70-frame buffer on a 120-column
    screen. The test checks:

    * ``convert`` reports exactly one unfilled slot;
    * frame 44 matches the reference screen dump;
    * the recorded actions and scores match the reference file.
    """
    seq_length = 70
    # This recording uses a 120-column screen instead of the module-level
    # COLUMNS; lower-case so it does not shadow that constant.
    columns = 120
    converter = Converter(ROWS, columns, TTYREC_V3)

    chars = np.zeros((seq_length, ROWS, columns), dtype=np.uint8)
    colors = np.zeros((seq_length, ROWS, columns), dtype=np.int8)
    cursors = np.zeros((seq_length, 2), dtype=np.int16)
    actions = np.zeros((seq_length,), dtype=np.uint8)
    timestamps = np.zeros((seq_length,), dtype=np.int64)
    scores = np.zeros((seq_length,), dtype=np.int32)

    converter.load_ttyrec(getfilename(TTYREC_NLE_V3))
    # convert() returns how many buffer slots it left unfilled.
    remaining = converter.convert(
        chars, colors, cursors, timestamps, actions, scores
    )
    assert remaining == 1

    # The fixture name uses a 1-based frame number: frame 44 is index 43.
    # The screen bytes are decoded as UTF-8, so read the reference as UTF-8
    # too rather than relying on the locale's default encoding.
    with open(getfilename(TTYREC_NLE_V3_FRAME_44), encoding="utf-8") as f:
        for row, line in enumerate(f):
            actual = chars[43][row].tobytes().decode("utf-8").rstrip()
            assert actual == line.rstrip()

    # Reference file: line 0 holds the actions, line 1 the scores, each
    # as space-separated integers.
    with open(
        getfilename(TTYREC_NLE_V3_ACTIONS_SCORES), encoding="utf-8"
    ) as f:
        lines = f.readlines()
    assert " ".join("%i" % a for a in actions) == lines[0].rstrip()
    assert " ".join("%i" % s for s in scores) == lines[1].rstrip()
def test_ttyrec_with_extra_data(self, seq_length=500):
    """Convert the 2018 ttyrec into a ``seq_length``-frame buffer.

    The check is that ``convert`` leaves exactly 165 slots of the buffer
    unfilled.
    """
    screen_shape = (seq_length, ROWS, COLUMNS)
    converter = Converter(ROWS, COLUMNS)
    converter.load_ttyrec(getfilename(TTYREC_2018))
    # NOTE(review): other tests in this file pass int16 cursors and a scores
    # buffer; confirm that uint16 cursors without scores is intended here.
    unfilled = converter.convert(
        np.zeros(screen_shape, dtype=np.uint8),  # chars
        np.zeros(screen_shape, dtype=np.int8),  # colors
        np.zeros((seq_length, 2), dtype=np.uint16),  # cursors
        np.zeros((seq_length,), dtype=np.int64),  # timestamps
        np.zeros((seq_length,), dtype=np.uint8),  # actions
    )
    assert unfilled == 165
def test_data(self):
    """Convert the 2020 ttyrec chunk by chunk and check it against references.

    The test does the following:

    * Repeatedly calls ``convert`` with ``SEQ_LENGTH``-frame buffers until a
      call leaves slots unfilled.
    * Checks every converted frame's cursor against the COLSROWS reference.
    * Checks every converted frame's timestamp against the TIMESTAMPS
      reference.
    * Compares the last converted frame's characters and colors with the
      FINALFRAME / FINALFRAMECOLORS dumps.
    """
    converter = Converter(ROWS, COLUMNS, TTYREC_V1)
    assert converter.rows == ROWS
    assert converter.cols == COLUMNS

    chars = np.zeros((SEQ_LENGTH, ROWS, COLUMNS), dtype=np.uint8)
    colors = np.zeros((SEQ_LENGTH, ROWS, COLUMNS), dtype=np.int8)
    cursors = np.zeros((SEQ_LENGTH, 2), dtype=np.int16)
    timestamps = np.zeros((SEQ_LENGTH,), dtype=np.int64)
    actions = np.zeros((SEQ_LENGTH,), dtype=np.uint8)
    scores = np.zeros((SEQ_LENGTH,), dtype=np.int32)

    converter.load_ttyrec(getfilename(TTYREC_2020))

    # Each reference line is "<col> <row>" as whitespace-separated ints.
    with open(getfilename(COLSROWS), encoding="utf-8") as f:
        colsrows = [tuple(int(i) for i in line.split()) for line in f]

    with bz2.BZ2File(getfilename(TIMESTAMPS)) as f:
        saved_timestamps = [float(line) for line in f]

    num_frames = 0
    while True:
        # convert() returns how many buffer slots it left unfilled; the loop
        # treats a non-zero value as the end of the recording.
        remaining = converter.convert(
            chars, colors, cursors, timestamps, actions, scores
        )
        filled = SEQ_LENGTH - remaining
        for (row, col), ts in zip(cursors[:filled], timestamps[:filled]):
            # Fail with a clear message rather than an IndexError if the
            # converter yields more frames than the reference lists.
            assert num_frames < len(colsrows), (
                "converter produced more frames than the COLSROWS reference"
            )
            assert (col, row) == colsrows[num_frames]
            # The buffer value divided by 1e6 is compared with the saved
            # value (microseconds -> seconds); approx wraps the expected value.
            assert float(ts) / 1e6 == pytest.approx(
                saved_timestamps[num_frames]
            )
            # Cursor col == converter.cols when it is offscreen (ie cropped).
            assert 0 <= col <= converter.cols
            assert 0 <= row < converter.rows
            num_frames += 1
        if remaining > 0:
            break

    assert num_frames == len(colsrows)

    # Index of the last slot written by the final convert() call.
    final_index = SEQ_LENGTH - remaining - 1
    # Screen bytes are decoded as UTF-8, so read the references as UTF-8.
    with open(getfilename(FINALFRAME), encoding="utf-8") as f:
        for row, line in enumerate(f):
            actual = chars[final_index][row].tobytes().decode(
                "utf-8").rstrip()
            assert actual == line.rstrip()

    # Reference colors are one comma-separated line of ints per screen row.
    with open(getfilename(FINALFRAMECOLORS), encoding="utf-8") as f:
        for row, line in enumerate(f):
            actual = ",".join(str(c) for c in colors[final_index][row])
            assert actual == line.rstrip()
def test_ibm_graphics(self):
    """Convert the IBMgraphics ttyrec and compare its 10th frame with the
    stored reference screen."""
    seq_length = 10
    converter = Converter(ROWS, COLUMNS)
    chars = np.zeros((seq_length, ROWS, COLUMNS), dtype=np.uint8)
    colors = np.zeros((seq_length, ROWS, COLUMNS), dtype=np.int8)
    # NOTE(review): other tests in this file pass int16 cursors plus a scores
    # buffer; confirm that uint16 cursors without scores is intended here.
    cursors = np.zeros((seq_length, 2), dtype=np.uint16)
    actions = np.zeros((seq_length,), dtype=np.uint8)
    timestamps = np.zeros((seq_length,), dtype=np.int64)

    converter.load_ttyrec(getfilename(TTYREC_IBMGRAPHICS))
    # A return of 0 means no slot was left unfilled, so the last slot of the
    # buffer holds frame 10.
    assert converter.convert(chars, colors, cursors, timestamps, actions) == 0

    # Screen bytes are decoded as UTF-8, so read the reference as UTF-8 too
    # instead of relying on the locale's default encoding.
    with open(
        getfilename(TTYREC_IBMGRAPHICS_FRAME_10), encoding="utf-8"
    ) as f:
        for row, line in enumerate(f):
            actual = chars[-1][row].tobytes().decode("utf-8").rstrip()
            assert actual == line.rstrip()
def test_shiftin_shiftout_graphics(self):
    """Convert the shift-in/shift-out ttyrec and compare its 10th frame with
    the stored reference screen."""
    seq_length = 10
    # This recording uses a 120-column screen instead of the module-level
    # COLUMNS; lower-case so it does not shadow that constant.
    columns = 120
    converter = Converter(ROWS, columns, TTYREC_V1)
    chars = np.zeros((seq_length, ROWS, columns), dtype=np.uint8)
    colors = np.zeros((seq_length, ROWS, columns), dtype=np.int8)
    cursors = np.zeros((seq_length, 2), dtype=np.int16)
    actions = np.zeros((seq_length,), dtype=np.uint8)
    timestamps = np.zeros((seq_length,), dtype=np.int64)
    scores = np.zeros((seq_length,), dtype=np.int32)

    converter.load_ttyrec(getfilename(TTYREC_SHIFTIN))
    # A return of 0 means every slot was filled.
    assert (converter.convert(chars, colors, cursors, timestamps, actions,
                              scores) == 0)

    # Frame 10 (1-based, per the fixture name) is the last buffer slot.
    # Screen bytes are decoded as UTF-8, so read the reference as UTF-8 too.
    last = seq_length - 1
    with open(getfilename(TTYREC_SHIFTIN_FRAME_10), encoding="utf-8") as f:
        for row, line in enumerate(f):
            actual = chars[last][row].tobytes().decode("utf-8").rstrip()
            assert actual == line.rstrip()
def test_illegal_buffers(self):
    """Check that ``convert`` rejects malformed output buffers.

    Every case starts from six well-formed 10-frame buffers:

    * chars
    * colors
    * cursors
    * timestamps
    * actions
    * scores

    It then swaps in one or more buffers with a bad shape, rank, dtype or
    Python type, and expects a ``ValueError`` whose message matches the
    given pattern.
    """
    converter = Converter(ROWS, COLUMNS, TTYREC_V1)
    converter.load_ttyrec(getfilename(TTYREC_2020))

    def expect_value_error(pattern, **bad_buffers):
        # Valid baseline; keyword overrides replace individual buffers.
        buffers = {
            "chars": np.zeros((10, ROWS, COLUMNS), dtype=np.uint8),
            "colors": np.zeros((10, ROWS, COLUMNS), dtype=np.int8),
            "cursors": np.zeros((10, 2), dtype=np.int16),
            "timestamps": np.zeros((10,), dtype=np.int64),
            "actions": np.zeros((10,), dtype=np.uint8),
            "scores": np.zeros((10,), dtype=np.int32),
        }
        buffers.update(bad_buffers)
        with pytest.raises(ValueError, match=pattern):
            converter.convert(
                buffers["chars"],
                buffers["colors"],
                buffers["cursors"],
                buffers["timestamps"],
                buffers["actions"],
                buffers["scores"],
            )

    # Wrong shapes.
    expect_value_error(
        re.escape("Array has wrong shape (expected [ 10 2 ], got [ 11 2 ])"),
        cursors=np.zeros((11, 2), dtype=np.int16),
    )
    expect_value_error(
        re.escape(
            "Array has wrong shape (expected [ 10 25 80 ], got [ 10 25 79 ])"
        ),
        colors=np.zeros((10, ROWS, COLUMNS - 1), dtype=np.int8),
    )
    expect_value_error(
        re.escape(
            "Array has wrong shape (expected [ 10 25 80 ], got [ 10 24 80 ])"
        ),
        chars=np.zeros((10, ROWS - 1, COLUMNS), dtype=np.uint8),
        colors=np.zeros((10, ROWS - 1, COLUMNS), dtype=np.int8),
    )
    # 11-frame screen/cursor buffers paired with 10-entry 1-D buffers.
    expect_value_error(
        re.escape("Array has wrong shape (expected [ 11 ], got [ 10 ])"),
        chars=np.zeros((11, ROWS, COLUMNS), dtype=np.uint8),
        colors=np.zeros((11, ROWS, COLUMNS), dtype=np.int8),
        cursors=np.zeros((11, 2), dtype=np.int16),
    )
    expect_value_error(
        re.escape("Array has wrong shape (expected [ 10 2 ], got [ 10 3 ])"),
        cursors=np.zeros((10, 3), dtype=np.int16),
    )

    # Wrong number of dimensions.
    expect_value_error(
        r"Array has wrong number of dimensions \(expected 3, got 4\)",
        chars=np.zeros((10, ROWS, COLUMNS, 7), dtype=np.uint8),
    )
    expect_value_error(
        r"Array has wrong number of dimensions \(expected 2, got 3\)",
        cursors=np.zeros((10, 2, 1), dtype=np.int16),
    )

    # Wrong element dtypes.
    expect_value_error(
        r"Buffer dtype mismatch",
        chars=np.zeros((10, ROWS, COLUMNS), dtype=np.float32),
        cursors=np.zeros((10, 2), dtype=np.uint8),
    )
    expect_value_error(
        r"Buffer dtype mismatch",
        cursors=np.zeros((10, 2), dtype=np.uint8),
    )

    # Objects that are not numpy arrays at all.
    expect_value_error(r"Numpy array required", chars="Hello")
    expect_value_error(r"Numpy array required", chars=np.uint(8))
    expect_value_error(
        r"Numpy array required",
        timestamps=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    )