def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
    """Decompress the given columns of `self.data` in place, if needed.

    Two layouts are handled: a bulk-compressed column (the entire
    buffer is a single compressed blob) and a per-row compressed
    column (each element of the buffer is compressed individually).
    Columns not present in `self.data` are skipped silently.
    """
    for col in columns:
        if col not in self.data:
            continue
        value = self.data[col]
        if is_compressed(value):
            # Bulk-compressed: unpack the whole column at once.
            self.data[col] = unpack(value)
        elif len(value) > 0 and is_compressed(value[0]):
            # Per-row compressed: unpack every element individually.
            self.data[col] = np.array([unpack(row) for row in value])
def _decompress_in_place(path, value):
    # Decompress a single (possibly nested) leaf of the data dict in place.
    # NOTE(review): `columns` and `self` are free variables from an
    # enclosing scope not visible here -- this is a nested helper;
    # confirm its contract against the outer method.
    # Skip leaves whose top-level key is not among the requested columns.
    if path[0] not in columns:
        return
    # Walk down the nested structure to the dict directly holding the leaf.
    curr = self
    for p in path[:-1]:
        curr = curr[p]
    # Bulk compressed.
    if is_compressed(value):
        curr[path[-1]] = unpack(value)
    # Non bulk compressed.
    elif len(value) > 0 and is_compressed(value[0]):
        curr[path[-1]] = np.array([unpack(o) for o in value])
def decompress_if_needed(
        self,
        columns: Set[str] = frozenset(["obs", "new_obs"])) -> "SampleBatch":
    """Decompresses data buffers (per column if not compressed) in place.

    Args:
        columns: The columns to decompress. Default: Only decompress
            the obs and new_obs columns.

    Returns:
        SampleBatch: This very (now decompressed) SampleBatch.
    """
    for key in columns:
        if key in self.keys():
            arr = self[key]
            if is_compressed(arr):
                # Bulk compressed: the whole column is one blob.
                self[key] = unpack(arr)
            elif len(arr) > 0 and is_compressed(arr[0]):
                # Per-row compressed: iterate the already-fetched `arr`
                # instead of re-looking up `self[key]`.
                self[key] = np.array([unpack(o) for o in arr])
    return self
def test_compression(self):
    """Tests, whether compression and decompression work properly."""
    batch = SampleBatch({
        "a": np.array([1, 2, 3, 2, 3, 4]),
        "b": {
            "c": np.array([4, 5, 6, 5, 6, 7])
        },
    })
    # Bulk-compress both columns and verify the buffers were replaced
    # in place (nested dict structure must survive).
    batch.compress(columns={"a", "b"}, bulk=True)
    self.assertTrue(is_compressed(batch["a"]))
    self.assertTrue(is_compressed(batch["b"]["c"]))
    self.assertTrue(isinstance(batch["b"], dict))
    # Decompress again and verify the original values are restored.
    batch.decompress_if_needed(columns={"a", "b"})
    check(batch["a"], [1, 2, 3, 2, 3, 4])
    check(batch["b"]["c"], [4, 5, 6, 5, 6, 7])
    # Row iteration should yield fully decompressed, nested rows.
    row_iter = batch.rows()
    next(row_iter)
    check(next(row_iter), {"a": 2, "b": {"c": 5}})