def test_memory_tracking_performance_impact(self):
     """
     Time a ResNet-18 forward pass with and without memory tracking.

     Each pass runs inside its own ``with_timing`` section, labelled
     "no_tracking" and "with_tracking" respectively. Nothing is asserted
     here, so the test can only fail by raising.
     """
     torch.manual_seed(0)
     input_shape = (1, 3, 224, 224)
     network = models.resnet18()

     # Reference forward pass on the model before any tracker is attached.
     with with_timing("no_tracking"):
         network(torch.randn(size=input_shape))

     # The tracker is built and attached inside the timed section, so its
     # setup is included in the "with_tracking" section too.
     with with_timing("with_tracking"):
         memory_tracker = LayerwiseMemoryTracker()
         memory_tracker.monitor(network)
         network(torch.randn(size=input_shape))
    def test_find_best_reset_points_performance(self):
        """
        Test that the algorithm is O(N**2) complexity for N activations

        ``find_best_reset_points`` is timed once on 1,000 and once on 2,000
        random activations, with the same number of checkpoints. The larger
        run must not be faster than the smaller one and must take at most
        6x as long (quadratic scaling predicts about 4x; the extra margin
        is presumably there to absorb timing noise).
        """
        import numpy as np

        # Private generator with a fixed seed: the inputs are identical on
        # every run, so a failure can be replayed, and numpy's global random
        # state is left untouched for other tests.
        rng = np.random.RandomState(0)
        activations_1000 = list(
            rng.randint(low=0, high=1_000_000, size=1_000))
        activations_2000 = list(
            rng.randint(low=0, high=1_000_000, size=2_000))
        nb_checkpoints = 10

        with with_timing(name="best_reset_points_1000") as timer_1000:
            find_best_reset_points(activations_1000,
                                   nb_checkpoints=nb_checkpoints)
        with with_timing(name="best_reset_points_2000") as timer_2000:
            find_best_reset_points(activations_2000,
                                   nb_checkpoints=nb_checkpoints)

        # Read the timings only after both sections have exited.
        time_1000_ms = timer_1000.elapsed_time_ms
        time_2000_ms = timer_2000.elapsed_time_ms
        timings = f"N=1000: {time_1000_ms} ms, N=2000: {time_2000_ms} ms"

        # Lower bound: doubling the input must not make the call faster.
        self.assertGreaterEqual(
            time_2000_ms, time_1000_ms,
            msg=f"larger input ran faster than smaller one ({timings})")
        # Upper bound: doubling the input may cost at most 6x.
        self.assertLessEqual(
            time_2000_ms, time_1000_ms * 6,
            msg=f"runtime grew by more than 6x ({timings})")