class TestAttributeMonitorHierarchical(unittest.TestCase):
    """Tests for ``AttributeMonitor`` tracking a dotted (nested) name.

    The fixture monitor tracks ``"obj.y"``; recorded objects therefore carry
    an ``obj`` attribute which itself has a ``y`` attribute, and the results
    are read back from ``history_.obj.y``.
    """

    def setUp(self):
        self.n = 60
        self.names = ["obj.y"]
        self.rng = np.random.default_rng(0)
        self.y = self.rng.normal(size=self.n)

        self.monitor = AttributeMonitor(self.names)
        self.monitor.setup(self.n)

    def test_record_works_correctly(self):
        """Per-sample ``record`` stores the nested value under ``obj.y``."""
        for i in range(self.n):
            tracked = SimpleNamespace(obj=SimpleNamespace(y=self.y[i]))
            self.monitor.record(tracked)

        self.assertTrue(hasattr(self.monitor.history_, "obj"))
        self.assertTrue(hasattr(self.monitor.history_.obj, "y"))
        np.testing.assert_allclose(self.monitor.history_.obj.y, self.y)

    def test_record_batch_works_correctly(self):
        """``record_batch`` with fixed-size chunks reproduces ``obj.y``."""
        # 13 does not divide 60, so the final chunk is shorter than the rest.
        chunk_size = 13
        for i in range(0, self.n, chunk_size):
            tracked = SimpleNamespace(obj=SimpleNamespace(
                y=self.y[i:i + chunk_size]))
            self.monitor.record_batch(tracked)

        self.assertTrue(hasattr(self.monitor.history_, "obj"))
        self.assertTrue(hasattr(self.monitor.history_.obj, "y"))
        np.testing.assert_allclose(self.monitor.history_.obj.y, self.y)

    def test_raises_value_error_if_not_all_variables_are_the_same_length(self):
        """``record_batch`` rejects attributes with mismatched lengths."""
        # Uses its own flat-name monitor rather than the hierarchical fixture.
        monitor_alt = AttributeMonitor(["x", "y"])
        monitor_alt.setup(3)

        obj = SimpleNamespace(x=np.zeros(3), y=[1])
        with self.assertRaises(ValueError):
            monitor_alt.record_batch(obj)

    def test_monitor_nothing(self):
        """A monitor with no names leaves ``history_`` empty."""
        # Uses its own monitor rather than the hierarchical fixture.
        monitor = AttributeMonitor([])
        monitor.setup(5)

        obj = SimpleNamespace()
        monitor.record_batch(obj)

        self.assertEqual(len(monitor.history_.__dict__), 0)

    def test_one_by_one_same_as_batch_when_step_is_one(self):
        """``record`` and variable-size ``record_batch`` agree on ``obj.y``."""
        monitor_alt = AttributeMonitor(self.names)
        monitor_alt.setup(self.n)

        rng = np.random.default_rng(0)
        v = rng.normal(size=self.n)

        # The fixture tracks the nested name "obj.y", so the recorded objects
        # must expose ``obj.y`` (not a flat ``x``).
        for value in v:
            self.monitor.record(SimpleNamespace(obj=SimpleNamespace(y=value)))

        # Feed the same data in randomly sized batches of 1..9 samples,
        # clipped so the final batch does not run past ``self.n``.
        i = 0
        while i < self.n:
            k = min(rng.integers(1, 10), self.n - i)
            monitor_alt.record_batch(
                SimpleNamespace(obj=SimpleNamespace(y=v[i:i + k])))
            i += k

        np.testing.assert_allclose(self.monitor.history_.obj.y,
                                   monitor_alt.history_.obj.y)


class TestAttributeMonitorRecordBatch(unittest.TestCase):
    """Tests for ``AttributeMonitor.record_batch`` on a flat name ``"x"``."""

    def setUp(self):
        self.n = 60
        self.names = ["x"]
        self.monitor = AttributeMonitor(self.names)
        self.monitor.setup(self.n)

    def test_raises_value_error_if_not_all_variables_are_the_same_length(self):
        """``record_batch`` rejects attributes with mismatched lengths."""
        monitor_alt = AttributeMonitor(["x", "y"])
        monitor_alt.setup(3)

        obj = SimpleNamespace(x=np.zeros(3), y=[1])
        with self.assertRaises(ValueError):
            monitor_alt.record_batch(obj)

    def test_one_by_one_same_as_batch_when_step_is_one(self):
        """``record`` and variable-size ``record_batch`` agree (default step)."""
        monitor_alt = AttributeMonitor(self.names)
        monitor_alt.setup(self.n)

        rng = np.random.default_rng(0)
        v = rng.normal(size=self.n)

        for i in range(self.n):
            obj = SimpleNamespace(x=v[i])
            self.monitor.record(obj)

        # Randomly sized batches of 1..9 samples, clipped at the end.
        i = 0
        while i < self.n:
            k = min(rng.integers(1, 10), self.n - i)
            crt_v = v[i:i + k]
            obj_batch = SimpleNamespace(x=crt_v)
            monitor_alt.record_batch(obj_batch)

            i += k

        np.testing.assert_allclose(self.monitor.history_.x,
                                   monitor_alt.history_.x)

    def test_one_by_one_same_as_batch_when_step_is_not_one(self):
        """``record`` and ``record_batch`` agree when ``step`` > 1."""
        step = 7
        # Configure both monitors identically: ``step`` is given to the
        # constructor before ``setup`` runs. Assigning ``.step`` on the
        # fixture monitor after its ``setup`` (as before) would leave it
        # configured differently from ``monitor_alt`` if ``setup`` depends
        # on ``step``.
        monitor = AttributeMonitor(self.names, step=step)
        monitor.setup(self.n)
        monitor_alt = AttributeMonitor(self.names, step=step)
        monitor_alt.setup(self.n)

        rng = np.random.default_rng(1)
        v = rng.normal(size=self.n)

        for i in range(self.n):
            obj = SimpleNamespace(x=v[i])
            monitor.record(obj)

        # Randomly sized batches of 1..9 samples, clipped at the end; batch
        # boundaries deliberately do not align with ``step``.
        i = 0
        while i < self.n:
            k = min(rng.integers(1, 10), self.n - i)
            crt_v = v[i:i + k]
            obj_batch = SimpleNamespace(x=crt_v)
            monitor_alt.record_batch(obj_batch)

            i += k

        np.testing.assert_allclose(monitor.history_.x,
                                   monitor_alt.history_.x)

    def test_store_copies_objects(self):
        """Mutating a recorded object afterwards does not alter history."""
        lst = [1, 2, 3]
        obj = SimpleNamespace(x=[lst])
        self.monitor.record_batch(obj)

        lst0 = copy.copy(lst)
        lst[2] = -1

        np.testing.assert_allclose(self.monitor.history_.x[0], lst0)

    def test_monitor_nothing(self):
        """A monitor with no names leaves ``history_`` empty."""
        monitor = AttributeMonitor([])
        monitor.setup(5)

        obj = SimpleNamespace()
        monitor.record_batch(obj)

        self.assertEqual(len(monitor.history_.__dict__), 0)