Пример #1
0
 def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
     """Unpack compressed entries of the selected columns of ``self.data``.

     For every name in ``columns`` that exists in ``self.data``:
       * if the column object itself is compressed, it is replaced by
         ``unpack`` of it;
       * otherwise, if it is non-empty and its first element is
         compressed, it is replaced by a numpy array of the per-element
         ``unpack`` results.
     Other columns are left untouched. ``self.data`` is mutated in place
     and nothing is returned.
     """
     for name in columns:
         # Guard clause: ignore requested columns this batch does not hold.
         if name not in self.data:
             continue
         column = self.data[name]
         if is_compressed(column):
             self.data[name] = unpack(column)
         elif len(column) > 0 and is_compressed(column[0]):
             # Element-wise compressed: unpack each entry individually.
             unpacked = [unpack(item) for item in self.data[name]]
             self.data[name] = np.array(unpacked)
Пример #2
0
 def _decompress_in_place(path, value):
     """Unpack the compressed leaf ``value`` located at ``path`` inside ``self``.

     ``columns`` and ``self`` are free variables taken from the enclosing
     scope. Presumably this is invoked once per leaf by a structure-walking
     helper that supplies ``(path, value)`` pairs; verify against the caller.

     Leaves whose top-level key ``path[0]`` is not in ``columns`` are skipped.
     Otherwise the parent container is reached by indexing from ``self``
     along ``path[:-1]`` and the slot ``path[-1]`` is overwritten.
     """
     if path[0] in columns:
         parent = self
         for step in path[:-1]:
             parent = parent[step]
         slot = path[-1]
         if is_compressed(value):
             # Bulk compressed: the whole leaf is a single compressed blob.
             parent[slot] = unpack(value)
         elif len(value) > 0 and is_compressed(value[0]):
             # Non bulk compressed: each element was compressed separately.
             parent[slot] = np.array([unpack(o) for o in value])
Пример #3
0
    def decompress_if_needed(
        self, columns: Set[str] = frozenset(["obs",
                                             "new_obs"])) -> "SampleBatch":
        """Decompresses data buffers (per column if not compressed) in place.

        A selected column may itself be a (possibly nested) dict of
        sub-columns, e.g. ``{"c": <data>}``. In that case every leaf under
        it is decompressed and the dict containers are mutated in place
        (not replaced).

        Args:
            columns (Set[str]): The top-level columns to decompress.
                Default: Only decompress the obs and new_obs columns.

        Returns:
            SampleBatch: This very SampleBatch.
        """

        def _decompress_entry(container, key):
            """Decompresses ``container[key]`` in place, recursing into dicts."""
            value = container[key]
            if isinstance(value, dict):
                # Nested sub-columns: previously this fell through to
                # `value[0]` and raised KeyError. Recurse instead, mutating
                # the existing dict so its identity is preserved.
                for sub_key in list(value):
                    _decompress_entry(value, sub_key)
            # Bulk compressed: the whole column is one compressed blob.
            elif is_compressed(value):
                container[key] = unpack(value)
            # Non bulk compressed: each element was compressed separately.
            elif len(value) > 0 and is_compressed(value[0]):
                container[key] = np.array([unpack(o) for o in value])

        for key in columns:
            if key in self.keys():
                _decompress_entry(self, key)
        return self
Пример #4
0
    def test_compression(self):
        """Checks compress()/decompress_if_needed() mutate the batch in place.

        Covers both a flat column ("a") and a nested dict column ("b").
        """
        batch = SampleBatch({
            "a": np.array([1, 2, 3, 2, 3, 4]),
            "b": {
                "c": np.array([4, 5, 6, 5, 6, 7])
            },
        })

        # Compression must act on `batch` itself: reading the leaves back
        # from the same object should now yield compressed data.
        batch.compress(columns={"a", "b"}, bulk=True)
        self.assertTrue(is_compressed(batch["a"]))
        self.assertTrue(is_compressed(batch["b"]["c"]))
        # The nested column keeps its dict container.
        self.assertIsInstance(batch["b"], dict)

        # Decompression must likewise restore the original values in place.
        batch.decompress_if_needed(columns={"a", "b"})
        check(batch["a"], [1, 2, 3, 2, 3, 4])
        check(batch["b"]["c"], [4, 5, 6, 5, 6, 7])

        # Row iteration sees the restored values: skip row 0, inspect row 1.
        row_iter = batch.rows()
        next(row_iter)
        second_row = next(row_iter)
        check(second_row, {"a": 2, "b": {"c": 5}})