Beispiel #1
0
    def setUp(self):
        AgentTestCase.setUp(self)

        # Extension handler "foo" at version 1.2.3; the handler instance is
        # created without a protocol.
        handler_properties = ExtHandlerProperties()
        handler_properties.version = "1.2.3"
        handler = ExtHandler(name='foo')
        handler.properties = handler_properties
        self.ext_handler = handler
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=handler,
                                                       protocol=None)

        def start_patcher(target, replacement):
            # Create a patcher for `target`, activate it and return it so the
            # caller can keep a reference to it.
            patcher = patch(target, replacement)
            patcher.start()
            return patcher

        # Redirect the handler's base and log directories into the test's
        # temporary directory. The lambdas read self.tmp_dir / self.log_dir
        # at call time, not at definition time.
        # NOTE(review): these patchers are not stopped here; presumably the
        # test case's tearDown stops them — confirm.
        self.mock_get_base_dir = start_patcher(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
            lambda *_: self.tmp_dir)

        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = start_patcher(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
            lambda *_: self.log_dir)

        # time.sleep is replaced by a call to mock_sleep(0.01).
        self.mock_sleep = start_patcher("time.sleep",
                                        lambda *_: mock_sleep(0.01))

        # Remember whether cgroups were enabled, then disable them for the test.
        configurator = CGroupConfigurator.get_instance()
        self.cgroups_enabled = configurator.enabled()
        configurator.disable()
    def test_it_should_clear_queue_before_stopping(self):
        # Two events, each tagged with a freshly generated id.
        events = [TelemetryEvent(eventId=ustr(uuid.uuid4())) for _ in range(2)]
        batch_wait = timedelta(seconds=10)
        min_batch_wait_target = "azurelinuxagent.ga.send_telemetry_events.SendTelemetryEventsHandler._MIN_BATCH_WAIT_TIME"

        with patch("time.sleep", lambda *_: mock_sleep(0.01)), \
                patch(min_batch_wait_target, batch_wait):
            with self._create_send_telemetry_events_handler(batching_queue_limit=5) as handler:
                for event in events:
                    handler.enqueue_event(event)

                # Two events with batching_queue_limit=5 and a 10s minimum batch
                # wait: nothing is expected to have been sent yet.
                self.assertEqual(0, len(handler.event_calls), "No events should have been logged")

                TestSendTelemetryEventsHandler._stop_handler(handler, timeout=0.01)

                # Once the handler is asked to stop, everything still queued must be sent.
                self._assert_test_data_in_event_body(handler, events)
Beispiel #3
0
    def setUp(self):
        AgentTestCase.setUp(self)

        # Extension handler "foo" at version 1.2.3, wrapped in an instance that
        # is constructed with a WireProtocol for endpoint "1.2.3.4".
        properties = ExtHandlerProperties()
        properties.version = "1.2.3"
        handler = ExtHandler(name='foo')
        handler.properties = properties
        self.ext_handler = handler
        wire_protocol = WireProtocol("1.2.3.4")
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=handler, protocol=wire_protocol)

        def start_patcher(target, replacement):
            # Create a patcher for `target`, activate it and return it so the
            # caller can keep a reference to it.
            patcher = patch(target, replacement)
            patcher.start()
            return patcher

        # Base and log directories are redirected into the test's temporary
        # directory; the lambdas read the attributes lazily at call time.
        # NOTE(review): these patchers are not stopped here; presumably the
        # test case's tearDown stops them — confirm.
        self.mock_get_base_dir = start_patcher("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
                                               lambda *_: self.tmp_dir)

        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = start_patcher("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
                                              lambda *_: self.log_dir)

        # time.sleep is replaced by a call to mock_sleep(0.01).
        self.mock_sleep = start_patcher("time.sleep", lambda *_: mock_sleep(0.01))
class CGroupConfiguratorSystemdTestCase(AgentTestCase):
    """Tests for CGroupConfigurator.

    CGroupConfigurator is accessed through get_instance(), which caches the
    object in CGroupConfigurator._instance; every test obtains a fresh
    instance through _get_new_cgroup_configurator_instance() so state does
    not leak between tests.
    """

    @classmethod
    def tearDownClass(cls):
        # Drop the cached configurator so later test classes start clean.
        CGroupConfigurator._instance = None
        AgentTestCase.tearDownClass()

    @staticmethod
    def _get_new_cgroup_configurator_instance(initialize=True):
        """Reset the cached CGroupConfigurator and return a new instance.

        :param initialize: when True, initialize() is called on the new
            instance inside the mock_cgroup_commands() context.
        :return: the new CGroupConfigurator.
        """
        CGroupConfigurator._instance = None
        configurator = CGroupConfigurator.get_instance()
        if initialize:
            with mock_cgroup_commands():
                configurator.initialize()
        return configurator

    def test_initialize_should_start_tracking_the_agent_cgroups(self):
        CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance()

        tracked = CGroupsTelemetry._tracked

        self.assertTrue(
            any(cg for cg in tracked
                if cg.name == 'walinuxagent.service' and 'cpu' in cg.path),
            "The Agent's CPU is not being tracked")
        self.assertTrue(
            any(cg for cg in tracked
                if cg.name == 'walinuxagent.service' and 'memory' in cg.path),
            "The Agent's memory is not being tracked")

    def test_enable_and_disable_should_change_the_enabled_state_of_cgroups(
            self):
        configurator = CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance()

        self.assertTrue(configurator.enabled(),
                        "CGroupConfigurator should be enabled by default")

        configurator.disable()
        self.assertFalse(configurator.enabled(),
                         "disable() should disable the CGroupConfigurator")

        configurator.enable()
        self.assertTrue(configurator.enabled(),
                        "enable() should enable the CGroupConfigurator")

    def test_enable_should_raise_CGroupsException_when_cgroups_are_not_supported(
            self):
        with patch(
                "azurelinuxagent.common.cgroupapi.CGroupsApi.cgroups_supported",
                return_value=False):
            configurator = CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance(
                initialize=False)
            configurator.initialize()

            with self.assertRaises(CGroupsException) as context_manager:
                configurator.enable()
            self.assertIn(
                "Attempted to enable cgroups, but they are not supported on the current platform",
                str(context_manager.exception))

    def test_disable_should_reset_tracked_cgroups(self):
        # Start tracking a couple of dummy cgroups
        CGroupsTelemetry.track_cgroup(
            CGroup("dummy", "/sys/fs/cgroup/memory/system.slice/dummy.service",
                   "cpu"))
        CGroupsTelemetry.track_cgroup(
            CGroup("dummy", "/sys/fs/cgroup/memory/system.slice/dummy.service",
                   "memory"))

        CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance().disable()

        self.assertEqual(len(CGroupsTelemetry._tracked), 0)

    def test_cgroup_operations_should_not_invoke_the_cgroup_api_when_cgroups_are_not_enabled(
            self):
        configurator = CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance()
        configurator.disable()

        # List of operations to test, and the functions to mock used in order to do verifications
        operations = [
            [
                lambda: configurator.create_extension_cgroups_root(),
                "azurelinuxagent.common.cgroupapi.SystemdCgroupsApi.create_extension_cgroups_root"
            ],
            [
                lambda: configurator.create_extension_cgroups("A.B.C-1.0.0"),
                "azurelinuxagent.common.cgroupapi.SystemdCgroupsApi.create_extension_cgroups"
            ],
            [
                lambda: configurator.remove_extension_cgroups("A.B.C-1.0.0"),
                "azurelinuxagent.common.cgroupapi.SystemdCgroupsApi.remove_extension_cgroups"
            ]
        ]

        for op in operations:
            with patch(op[1]) as mock_cgroup_api_operation:
                op[0]()

            self.assertEqual(mock_cgroup_api_operation.call_count, 0)

    def test_cgroup_operations_should_log_a_warning_when_the_cgroup_api_raises_an_exception(
            self):
        configurator = CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance()

        # cleanup_legacy_cgroups disables cgroups on error, so make disable() a no-op
        with patch.object(configurator, "disable"):
            # List of operations to test, and the functions to mock in order to raise exceptions
            operations = [
                [
                    lambda: configurator.create_extension_cgroups_root(),
                    "azurelinuxagent.common.cgroupapi.SystemdCgroupsApi.create_extension_cgroups_root"
                ],
                [
                    lambda: configurator.create_extension_cgroups("A.B.C-1.0.0"),
                    "azurelinuxagent.common.cgroupapi.SystemdCgroupsApi.create_extension_cgroups"
                ],
                [
                    lambda: configurator.remove_extension_cgroups("A.B.C-1.0.0"),
                    "azurelinuxagent.common.cgroupapi.SystemdCgroupsApi.remove_extension_cgroups"
                ]
            ]

            def raise_exception(*_):
                raise Exception("A TEST EXCEPTION")

            for op in operations:
                with patch(
                        "azurelinuxagent.common.cgroupconfigurator.logger.warn"
                ) as mock_logger_warn:
                    with patch(op[1], raise_exception):
                        op[0]()

                    self.assertEqual(mock_logger_warn.call_count, 1)

                    # The warning message is the first positional argument.
                    args = mock_logger_warn.call_args[0]
                    message = args[0]
                    self.assertIn("A TEST EXCEPTION", message)

    def test_get_processes_in_agent_cgroup_should_return_the_processes_within_the_agent_cgroup(
            self):
        with mock_cgroup_commands():
            configurator = CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance()

            processes = configurator.get_processes_in_agent_cgroup()

            self.assertTrue(
                len(processes) >= 2,
                "The cgroup should contain at least 2 procceses (daemon and extension handler): [{0}]"
                .format(processes))

            daemon_present = any("waagent -daemon" in command
                                 for (pid, command) in processes)
            self.assertTrue(
                daemon_present,
                "Could not find the daemon in the cgroup: [{0}]".format(
                    processes))

            # Raw string: avoids the invalid "\." escape warning in a plain literal.
            extension_handler_present = any(
                re.search(r"(WALinuxAgent-.+\.egg|waagent) -run-exthandlers",
                          command) for (pid, command) in processes)
            self.assertTrue(
                extension_handler_present,
                "Could not find the extension handler in the cgroup: [{0}]".
                format(processes))

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_not_use_systemd_when_cgroups_are_not_enabled(
            self, _):
        configurator = CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance()
        configurator.disable()

        with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                   wraps=subprocess.Popen) as patcher:
            configurator.start_extension_command(
                extension_name="Microsoft.Compute.TestExtension-1.2.3",
                command="date",
                timeout=300,
                shell=False,
                cwd=self.tmp_dir,
                env={},
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)

            command_calls = [
                args[0] for args, _ in patcher.call_args_list
                if len(args) > 0 and "date" in args[0]
            ]
            self.assertEqual(
                len(command_calls), 1,
                "The test command should have been called exactly once [{0}]".
                format(command_calls))
            self.assertNotIn(
                "systemd-run", command_calls[0],
                "The command should not have been invoked using systemd")
            self.assertEqual(command_calls[0], "date",
                             "The command line should not have been modified")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_use_systemd_run_when_cgroups_are_enabled(
            self, _):
        with mock_cgroup_commands():
            with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                       wraps=subprocess.Popen) as popen_patch:
                CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance(
                ).start_extension_command(
                    extension_name="Microsoft.Compute.TestExtension-1.2.3",
                    command="the-test-extension-command",
                    timeout=300,
                    shell=False,
                    cwd=self.tmp_dir,
                    env={},
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE)

                command_calls = [
                    args[0] for (args, _) in popen_patch.call_args_list
                    if "the-test-extension-command" in args[0]
                ]

                self.assertEqual(
                    len(command_calls), 1,
                    "The test command should have been called exactly once [{0}]"
                    .format(command_calls))
                self.assertIn(
                    "systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3",
                    command_calls[0],
                    "The extension should have been invoked using systemd")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_start_tracking_the_extension_cgroups(
            self, _):
        # CPU usage is initialized when we begin tracking a CPU cgroup; since this test does not retrieve the
        # CPU usage, there is no need for initialization
        with mock_cgroup_commands():
            CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance(
            ).start_extension_command(
                extension_name="Microsoft.Compute.TestExtension-1.2.3",
                command="test command",
                timeout=300,
                shell=False,
                cwd=self.tmp_dir,
                env={},
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)

        tracked = CGroupsTelemetry._tracked

        self.assertTrue(
            any(cg for cg in tracked
                if cg.name == 'Microsoft.Compute.TestExtension-1.2.3'
                and 'cpu' in cg.path),
            "The extension's CPU is not being tracked")
        self.assertTrue(
            any(cg for cg in tracked
                if cg.name == 'Microsoft.Compute.TestExtension-1.2.3'
                and 'memory' in cg.path),
            "The extension's memory is not being tracked")

    def test_start_extension_command_should_raise_an_exception_when_the_command_cannot_be_started(
            self):
        configurator = CGroupConfiguratorSystemdTestCase._get_new_cgroup_configurator_instance()

        original_popen = subprocess.Popen

        def mock_popen(command_arg, *args, **kwargs):
            # Fail only for the extension command; let any other call through.
            if "test command" in command_arg:
                raise Exception("A TEST EXCEPTION")
            return original_popen(command_arg, *args, **kwargs)

        with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                   side_effect=mock_popen):
            with self.assertRaises(Exception) as context_manager:
                configurator.start_extension_command(
                    extension_name="Microsoft.Compute.TestExtension-1.2.3",
                    command="test command",
                    timeout=300,
                    shell=False,
                    cwd=self.tmp_dir,
                    env={},
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE)

        # Checked outside assertRaises: inside it this line would never run,
        # because the call above raises.
        self.assertIn("A TEST EXCEPTION",
                      str(context_manager.exception))
Beispiel #5
0
class SystemdCgroupsApiTestCase(AgentTestCase):
    """Tests for SystemdCgroupsApi.

    Tests marked @attr('requires_sudo') assert that they run as root and
    inspect/clean up real systemd units via systemctl; the remaining tests
    patch subprocess.Popen and/or file reads instead.
    """

    def test_get_extensions_slice_root_name_should_return_the_root_slice_for_extensions(
            self):
        root_slice_name = SystemdCgroupsApi()._get_extensions_slice_root_name()
        self.assertEqual(root_slice_name,
                         "system-walinuxagent.extensions.slice")

    def test_get_extension_slice_name_should_return_the_slice_for_the_given_extension(
            self):
        extension_name = "Microsoft.Azure.DummyExtension-1.0"
        extension_slice_name = SystemdCgroupsApi()._get_extension_slice_name(
            extension_name)
        self.assertEqual(
            extension_slice_name,
            "system-walinuxagent.extensions-Microsoft.Azure.DummyExtension_1.0.slice"
        )

    @attr('requires_sudo')
    def test_create_extension_cgroups_root_should_create_extensions_root_slice(
            self):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        SystemdCgroupsApi().create_extension_cgroups_root()

        unit_name = SystemdCgroupsApi()._get_extensions_slice_root_name()
        _, status = shellutil.run_get_output(
            "systemctl status {0}".format(unit_name))
        self.assertIn("Loaded: loaded", status)
        self.assertIn("Active: active", status)

        # Remove the unit created by the test
        shellutil.run_get_output("systemctl stop {0}".format(unit_name))
        shellutil.run_get_output("systemctl disable {0}".format(unit_name))
        os.remove("/etc/systemd/system/{0}".format(unit_name))
        shellutil.run_get_output("systemctl daemon-reload")

    @attr('requires_sudo')
    def test_create_extension_cgroups_should_create_extension_slice(self):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        extension_name = "Microsoft.Azure.DummyExtension-1.0"
        cgroups = SystemdCgroupsApi().create_extension_cgroups(extension_name)
        cpu_cgroup, memory_cgroup = cgroups[0], cgroups[1]
        self.assertEqual(
            cpu_cgroup.path,
            "/sys/fs/cgroup/cpu/system.slice/Microsoft.Azure.DummyExtension_1.0"
        )
        self.assertEqual(
            memory_cgroup.path,
            "/sys/fs/cgroup/memory/system.slice/Microsoft.Azure.DummyExtension_1.0"
        )

        unit_name = SystemdCgroupsApi()._get_extension_slice_name(
            extension_name)
        self.assertEqual(
            "system-walinuxagent.extensions-Microsoft.Azure.DummyExtension_1.0.slice",
            unit_name)

        _, status = shellutil.run_get_output(
            "systemctl status {0}".format(unit_name))
        self.assertIn("Loaded: loaded", status)
        self.assertIn("Active: active", status)

        # Remove the unit created by the test
        shellutil.run_get_output("systemctl stop {0}".format(unit_name))
        shellutil.run_get_output("systemctl disable {0}".format(unit_name))
        os.remove("/etc/systemd/system/{0}".format(unit_name))
        shellutil.run_get_output("systemctl daemon-reload")

    def assert_cgroups_created(self, extension_cgroups):
        """Assert that extension_cgroups holds exactly one cpu and one memory
        cgroup, each located under a Microsoft.Compute.TestExtension_1.2.3_<id>.scope
        directory of system.slice."""
        self.assertEqual(
            len(extension_cgroups), 2,
            'start_extension_command did not return the expected number of cgroups'
        )

        cpu_found = memory_found = False

        for cgroup in extension_cgroups:
            # Literal dots are escaped so they do not match arbitrary characters.
            match = re.match(
                r'^/sys/fs/cgroup/(cpu|memory)/system\.slice/Microsoft\.Compute\.TestExtension_1\.2\.3_([a-f0-9-]+)\.scope$',
                cgroup.path)

            self.assertIsNotNone(
                match,
                "Unexpected path for cgroup: {0}".format(cgroup.path))

            if match.group(1) == 'cpu':
                cpu_found = True
            if match.group(1) == 'memory':
                memory_found = True

        self.assertTrue(cpu_found,
                        'start_extension_command did not return a cpu cgroup')
        self.assertTrue(
            memory_found,
            'start_extension_command did not return a memory cgroup')

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_create_extension_scopes(self, _):
        original_popen = subprocess.Popen

        def mock_popen(*_args, **kwargs):
            # Ignore the requested command line and always run "date".
            return original_popen("date", **kwargs)

        # we mock subprocess.Popen to execute a dummy command (date), so no actual cgroups are created; their paths
        # should be computed properly, though
        with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                   mock_popen):
            extension_cgroups = SystemdCgroupsApi().start_extension_command(
                extension_name="Microsoft.Compute.TestExtension-1.2.3",
                command="date",
                shell=False,
                timeout=300,
                cwd=self.tmp_dir,
                env={},
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)[0]

            self.assert_cgroups_created(extension_cgroups)

    @attr('requires_sudo')
    @patch('time.sleep', side_effect=lambda _: mock_sleep(0.2))
    def test_start_extension_command_should_use_systemd_and_not_the_fallback_option_if_successful(
            self, _):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen", wraps=subprocess.Popen) \
                        as patch_mock_popen:
                    extension_cgroups = SystemdCgroupsApi().start_extension_command(
                        extension_name="Microsoft.Compute.TestExtension-1.2.3",
                        command="date",
                        timeout=300,
                        shell=True,
                        cwd=self.tmp_dir,
                        env={},
                        stdout=stdout,
                        stderr=stderr)[0]

                    # We should have invoked the extension command only once and succeeded
                    self.assertEqual(1, patch_mock_popen.call_count)

                    command = patch_mock_popen.call_args[0][0]
                    self.assertIn("systemd-run --unit", command)

                    self.assert_cgroups_created(extension_cgroups)

    @patch('time.sleep', side_effect=lambda _: mock_sleep(0.2))
    def test_start_extension_command_should_use_fallback_option_if_systemd_fails(
            self, _):
        original_popen = subprocess.Popen

        def mock_popen(*args, **kwargs):
            # Inject a syntax error to the call
            systemd_command = args[0].replace('systemd-run',
                                              'systemd-run syntax_error')
            new_args = (systemd_command, )
            return original_popen(new_args, **kwargs)

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.add_event"
                           ) as mock_add_event:
                    with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen", side_effect=mock_popen) \
                            as patch_mock_popen:
                        # We expect this call to fail because of the syntax error
                        extension_cgroups, process_output = SystemdCgroupsApi(
                        ).start_extension_command(
                            extension_name=
                            "Microsoft.Compute.TestExtension-1.2.3",
                            command="date",
                            timeout=300,
                            shell=True,
                            cwd=self.tmp_dir,
                            env={},
                            stdout=stdout,
                            stderr=stderr)

                        # add_event is called with keyword arguments only.
                        kwargs = mock_add_event.call_args[1]
                        self.assertIn(
                            "Failed to run systemd-run for unit Microsoft.Compute.TestExtension_1.2.3",
                            kwargs['message'])
                        self.assertIn(
                            "Failed to find executable syntax_error: No such file or directory",
                            kwargs['message'])
                        self.assertEqual(False, kwargs['is_success'])
                        self.assertEqual('InvokeCommandUsingSystemd',
                                         kwargs['op'])

                        # We expect two calls to Popen, first for the systemd-run call, second for the fallback option
                        self.assertEqual(2, patch_mock_popen.call_count)

                        first_call_args = patch_mock_popen.mock_calls[0][1][0]
                        second_call_args = patch_mock_popen.mock_calls[1][1][0]
                        self.assertIn("systemd-run --unit", first_call_args)
                        self.assertNotIn("systemd-run --unit",
                                         second_call_args)

                        # No cgroups should have been created
                        self.assertEqual(extension_cgroups, [])

    @patch('time.sleep', side_effect=lambda _: mock_sleep(0.001))
    def test_start_extension_command_should_use_fallback_option_if_systemd_times_out(
            self, _):
        # Systemd has its own internal timeout which is shorter than what we define for extension operation timeout.
        # When systemd times out, it will write a message to stderr and exit with exit code 1.
        # In that case, we will internally recognize the failure due to the non-zero exit code, not as a timeout.
        original_popen = subprocess.Popen
        systemd_timeout_command = "echo 'Failed to start transient scope unit: Connection timed out' >&2 && exit 1"

        def mock_popen(*args, **kwargs):
            # If trying to invoke systemd, mock what would happen if systemd timed out internally:
            # write failure to stderr and exit with exit code 1.
            new_args = args
            if "systemd-run" in args[0]:
                new_args = (systemd_timeout_command, )

            return original_popen(new_args, **kwargs)

        expected_output = "[stdout]\n{0}\n\n\n[stderr]\n"

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen", side_effect=mock_popen) \
                        as patch_mock_popen:
                    extension_cgroups, process_output = SystemdCgroupsApi(
                    ).start_extension_command(
                        extension_name="Microsoft.Compute.TestExtension-1.2.3",
                        command="echo 'success'",
                        timeout=300,
                        shell=True,
                        cwd=self.tmp_dir,
                        env={},
                        stdout=stdout,
                        stderr=stderr)

                    # We expect two calls to Popen, first for the systemd-run call, second for the fallback option
                    self.assertEqual(2, patch_mock_popen.call_count)

                    first_call_args = patch_mock_popen.mock_calls[0][1][0]
                    second_call_args = patch_mock_popen.mock_calls[1][1][0]
                    self.assertIn("systemd-run --unit", first_call_args)
                    self.assertNotIn("systemd-run --unit", second_call_args)

                    self.assertEqual(extension_cgroups, [])
                    self.assertEqual(expected_output.format("success"),
                                     process_output)

    @attr('requires_sudo')
    @patch("azurelinuxagent.common.cgroupapi.add_event")
    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_not_use_fallback_option_if_extension_fails(
            self, *args):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen", wraps=subprocess.Popen) \
                        as patch_mock_popen:
                    with self.assertRaises(ExtensionError) as context_manager:
                        SystemdCgroupsApi().start_extension_command(
                            extension_name=
                            "Microsoft.Compute.TestExtension-1.2.3",
                            command="ls folder_does_not_exist",
                            timeout=300,
                            shell=True,
                            cwd=self.tmp_dir,
                            env={},
                            stdout=stdout,
                            stderr=stderr)

                    # These checks must be outside assertRaises: inside it they
                    # would never run, because the call above raises.

                    # We should have invoked the extension command only once, in the systemd-run case
                    self.assertEqual(1, patch_mock_popen.call_count)
                    command = patch_mock_popen.call_args[0][0]
                    self.assertIn("systemd-run --unit", command)

                    self.assertEqual(
                        context_manager.exception.code,
                        ExtensionErrorCodes.PluginUnknownFailure)
                    self.assertIn("Non-zero exit code",
                                  ustr(context_manager.exception))

    @attr('requires_sudo')
    @patch("azurelinuxagent.common.cgroupapi.add_event")
    def test_start_extension_command_should_not_use_fallback_option_if_extension_times_out(
            self, *args):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch(
                        "azurelinuxagent.common.utils.extensionprocessutil.wait_for_process_completion_or_timeout",
                        return_value=[True, None]):
                    with patch(
                            "azurelinuxagent.common.cgroupapi.SystemdCgroupsApi._is_systemd_failure",
                            return_value=False):
                        with self.assertRaises(
                                ExtensionError) as context_manager:
                            SystemdCgroupsApi().start_extension_command(
                                extension_name=
                                "Microsoft.Compute.TestExtension-1.2.3",
                                command="date",
                                timeout=300,
                                shell=True,
                                cwd=self.tmp_dir,
                                env={},
                                stdout=stdout,
                                stderr=stderr)

                        self.assertEqual(
                            context_manager.exception.code,
                            ExtensionErrorCodes.PluginHandlerScriptTimedout)
                        self.assertIn("Timeout",
                                      ustr(context_manager.exception))

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_capture_only_the_last_subprocess_output(
            self, _):
        original_popen = subprocess.Popen

        def mock_popen(*args, **kwargs):
            # Inject a syntax error to the call
            systemd_command = args[0].replace('systemd-run',
                                              'systemd-run syntax_error')
            new_args = (systemd_command, )
            return original_popen(new_args, **kwargs)

        expected_output = "[stdout]\n{0}\n\n\n[stderr]\n"

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.add_event"):
                    with patch(
                            "azurelinuxagent.common.cgroupapi.subprocess.Popen",
                            side_effect=mock_popen):
                        # We expect this call to fail because of the syntax error
                        extension_cgroups, process_output = SystemdCgroupsApi(
                        ).start_extension_command(
                            extension_name=
                            "Microsoft.Compute.TestExtension-1.2.3",
                            command="echo 'very specific test message'",
                            timeout=300,
                            shell=True,
                            cwd=self.tmp_dir,
                            env={},
                            stdout=stdout,
                            stderr=stderr)

                        self.assertEqual(
                            expected_output.format(
                                "very specific test message"), process_output)
                        self.assertEqual(extension_cgroups, [])

    @patch("azurelinuxagent.common.utils.fileutil.read_file")
    def test_create_agent_cgroups_should_create_cgroups_on_all_controllers(
            self, patch_read_file):
        # Contents returned for every fileutil.read_file call (stands in for /proc/self/cgroup)
        mock_proc_self_cgroup = '''12:blkio:/system.slice/walinuxagent.service
11:memory:/system.slice/walinuxagent.service
10:perf_event:/
9:hugetlb:/
8:freezer:/
7:net_cls,net_prio:/
6:devices:/system.slice/walinuxagent.service
5:cpuset:/
4:cpu,cpuacct:/system.slice/walinuxagent.service
3:pids:/system.slice/walinuxagent.service
2:rdma:/
1:name=systemd:/system.slice/walinuxagent.service
0::/system.slice/walinuxagent.service
'''
        patch_read_file.return_value = mock_proc_self_cgroup
        agent_cgroups = SystemdCgroupsApi().create_agent_cgroups()

        def assert_cgroup_created(controller):
            expected_cgroup_path = os.path.join(CGROUPS_FILE_SYSTEM_ROOT,
                                                controller, "system.slice",
                                                VM_AGENT_CGROUP_NAME)

            self.assertTrue(
                any(cgroups.path == expected_cgroup_path
                    for cgroups in agent_cgroups))
            self.assertTrue(
                any(cgroups.name == VM_AGENT_CGROUP_NAME
                    for cgroups in agent_cgroups))

        assert_cgroup_created("cpu")
        assert_cgroup_created("memory")
Beispiel #6
0
class FileSystemCgroupsApiTestCase(_MockedFileSystemTestCase):
    """
    Tests for FileSystemCgroupsApi.

    These tests rely on ``self.cgroups_file_system_root`` (the directory under
    which the "cpu"/"memory" controller directories are created) and
    ``self.tmp_dir``; neither is set in this class (presumably provided by
    _MockedFileSystemTestCase / AgentTestCase — confirm in the base class).
    """

    def test_cleanup_legacy_cgroups_should_move_daemon_pid_to_new_cgroup_and_remove_legacy_cgroups(
            self):
        """
        cleanup_legacy_cgroups() should add the daemon PID (read from the agent
        PID file) to the new agent cgroups, delete the legacy cgroups, and emit
        one successful 'CGroupsCleanUp' event per controller.
        """
        # Set up a mock /var/run/waagent.pid file
        daemon_pid = "42"
        daemon_pid_file = os.path.join(self.tmp_dir, "waagent.pid")
        fileutil.write_file(daemon_pid_file, daemon_pid + "\n")

        # Set up old controller cgroups and add the daemon PID to them
        legacy_cpu_cgroup = CGroupsTools.create_legacy_agent_cgroup(
            self.cgroups_file_system_root, "cpu", daemon_pid)
        legacy_memory_cgroup = CGroupsTools.create_legacy_agent_cgroup(
            self.cgroups_file_system_root, "memory", daemon_pid)

        # Set up new controller cgroups and add extension handler's PID to them
        new_cpu_cgroup = CGroupsTools.create_agent_cgroup(
            self.cgroups_file_system_root, "cpu", "999")
        new_memory_cgroup = CGroupsTools.create_agent_cgroup(
            self.cgroups_file_system_root, "memory", "999")

        with patch("azurelinuxagent.common.cgroupapi.add_event"
                   ) as mock_add_event:
            # Point the API at the mock PID file created above
            with patch(
                    "azurelinuxagent.common.cgroupapi.get_agent_pid_file_path",
                    return_value=daemon_pid_file):
                FileSystemCgroupsApi().cleanup_legacy_cgroups()

        # The method should have added the daemon PID to the new controllers and deleted the old ones
        new_cpu_contents = fileutil.read_file(
            os.path.join(new_cpu_cgroup, "cgroup.procs"))
        new_memory_contents = fileutil.read_file(
            os.path.join(new_memory_cgroup, "cgroup.procs"))

        self.assertIn(daemon_pid, new_cpu_contents)
        self.assertIn(daemon_pid, new_memory_contents)

        self.assertFalse(os.path.exists(legacy_cpu_cgroup))
        self.assertFalse(os.path.exists(legacy_memory_cgroup))

        # Assert the event parameters that were sent out
        event_kwargs = [kwargs for _, kwargs in mock_add_event.call_args_list]
        self.assertEqual(len(event_kwargs), 2)
        self.assertTrue(
            all(kwargs['op'] == 'CGroupsCleanUp' for kwargs in event_kwargs))
        self.assertTrue(all(kwargs['is_success'] for kwargs in event_kwargs))
        self.assertTrue(
            any(
                re.match(
                    r"Moved daemon's PID from legacy cgroup to /.*/cgroup/cpu/walinuxagent.service",
                    kwargs['message'])
                for kwargs in event_kwargs))
        self.assertTrue(
            any(
                re.match(
                    r"Moved daemon's PID from legacy cgroup to /.*/cgroup/memory/walinuxagent.service",
                    kwargs['message'])
                for kwargs in event_kwargs))

    def test_create_agent_cgroups_should_create_cgroups_on_all_controllers(
            self):
        """
        create_agent_cgroups() should create <root>/<controller>/VM_AGENT_CGROUP_NAME
        for cpu and memory and write the current process ID into its cgroup.procs.
        """
        agent_cgroups = FileSystemCgroupsApi().create_agent_cgroups()

        def assert_cgroup_created(controller):
            cgroup_path = os.path.join(self.cgroups_file_system_root,
                                       controller, VM_AGENT_CGROUP_NAME)
            self.assertTrue(
                any(cgroups.path == cgroup_path for cgroups in agent_cgroups))
            self.assertTrue(
                any(cgroups.name == VM_AGENT_CGROUP_NAME
                    for cgroups in agent_cgroups))
            self.assertTrue(os.path.exists(cgroup_path))
            cgroup_task = int(
                fileutil.read_file(os.path.join(cgroup_path, "cgroup.procs")))
            current_process = os.getpid()
            self.assertEqual(cgroup_task, current_process)

        assert_cgroup_created("cpu")
        assert_cgroup_created("memory")

    def test_create_extension_cgroups_root_should_create_root_directory_for_extensions(
            self):
        """create_extension_cgroups_root() should create walinuxagent.extensions under cpu and memory."""
        FileSystemCgroupsApi().create_extension_cgroups_root()

        cpu_cgroup = os.path.join(self.cgroups_file_system_root, "cpu",
                                  "walinuxagent.extensions")
        self.assertTrue(os.path.exists(cpu_cgroup))

        memory_cgroup = os.path.join(self.cgroups_file_system_root, "memory",
                                     "walinuxagent.extensions")
        self.assertTrue(os.path.exists(memory_cgroup))

    def test_create_extension_cgroups_should_create_cgroups_on_all_controllers(
            self):
        """
        create_extension_cgroups() should create one directory per controller;
        the '-' in the extension name is expected to become '_' in the path.
        """
        api = FileSystemCgroupsApi()
        api.create_extension_cgroups_root()
        extension_cgroups = api.create_extension_cgroups(
            "Microsoft.Compute.TestExtension-1.2.3")

        def assert_cgroup_created(controller):
            cgroup_path = os.path.join(
                self.cgroups_file_system_root, controller,
                "walinuxagent.extensions",
                "Microsoft.Compute.TestExtension_1.2.3")

            self.assertTrue(
                any(cgroups.path == cgroup_path
                    for cgroups in extension_cgroups))
            self.assertTrue(os.path.exists(cgroup_path))

        assert_cgroup_created("cpu")
        assert_cgroup_created("memory")

    def test_remove_extension_cgroups_should_remove_all_cgroups(self):
        """remove_extension_cgroups() should delete every directory create_extension_cgroups() made."""
        api = FileSystemCgroupsApi()
        api.create_extension_cgroups_root()
        extension_cgroups = api.create_extension_cgroups(
            "Microsoft.Compute.TestExtension-1.2.3")

        api.remove_extension_cgroups("Microsoft.Compute.TestExtension-1.2.3")

        for cgroup in extension_cgroups:
            self.assertFalse(os.path.exists(cgroup.path))

    def test_remove_extension_cgroups_should_log_a_warning_when_the_cgroup_contains_active_tasks(
            self):
        """
        When os.rmdir fails with EBUSY (errno 16), remove_extension_cgroups()
        should log a warning mentioning active tasks.
        """
        api = FileSystemCgroupsApi()
        api.create_extension_cgroups_root()
        api.create_extension_cgroups("Microsoft.Compute.TestExtension-1.2.3")

        with patch("azurelinuxagent.common.cgroupapi.logger.warn"
                   ) as mock_logger_warn:
            # Simulate the kernel refusing to remove a cgroup that still has tasks
            with patch("azurelinuxagent.common.cgroupapi.os.rmdir",
                       side_effect=OSError(16, "Device or resource busy")):
                api.remove_extension_cgroups(
                    "Microsoft.Compute.TestExtension-1.2.3")

            # Fail with a clear message (rather than a TypeError on unpacking
            # None) if no warning was logged at all
            self.assertIsNotNone(mock_logger_warn.call_args,
                                 "No warning was logged")
            message = mock_logger_warn.call_args[0][0]
            self.assertIn("still has active tasks", message)

    def test_get_extension_cgroups_should_return_all_cgroups(self):
        """get_extension_cgroups() should return the same cgroup paths that were created."""
        api = FileSystemCgroupsApi()
        api.create_extension_cgroups_root()
        created = api.create_extension_cgroups(
            "Microsoft.Compute.TestExtension-1.2.3")

        retrieved = api.get_extension_cgroups(
            "Microsoft.Compute.TestExtension-1.2.3")

        self.assertEqual(len(retrieved), len(created))

        for cgroup in created:
            self.assertTrue(
                any(retrieved_cgroup.path == cgroup.path
                    for retrieved_cgroup in retrieved))

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_add_the_child_process_to_the_extension_cgroup(
            self, _):
        """
        The PID printed by the extension command (via `echo $$`) should be the
        PID recorded in cgroup.procs of every extension cgroup returned.
        """
        api = FileSystemCgroupsApi()
        api.create_extension_cgroups_root()

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                extension_cgroups, process_output = api.start_extension_command(
                    extension_name="Microsoft.Compute.TestExtension-1.2.3",
                    command="echo $$",
                    timeout=300,
                    shell=True,
                    cwd=self.tmp_dir,
                    env={},
                    stdout=stdout,
                    stderr=stderr)

        # The expected format of the process output is [stdout]\n{PID}\n\n\n[stderr]\n"
        pattern = re.compile(r"\[stdout\]\n(\d+)\n\n\n\[stderr\]\n")
        m = pattern.match(process_output)
        if m is None:
            self.fail(
                "No PID could be extracted from the process output: {0}"
                .format(process_output))
        pid_from_output = int(m.group(1))

        for cgroup in extension_cgroups:
            cgroups_procs_path = os.path.join(cgroup.path, "cgroup.procs")
            with open(cgroups_procs_path, "r") as f:
                contents = f.read()
            pid_from_cgroup = int(contents)

            self.assertEqual(
                pid_from_output, pid_from_cgroup,
                "The PID from the process output ({0}) does not match the PID found in the "
                "process cgroup {1} ({2})".format(pid_from_output,
                                                  cgroups_procs_path,
                                                  pid_from_cgroup))
class SystemdCgroupsApiTestCase(AgentTestCase):
    """
    Tests for SystemdCgroupsApi.

    Tests wrapped in ``mock_cgroup_commands()`` presumably run against canned
    cgroup/systemd command output (confirm in the test tools). Tests marked
    ``@attr('requires_sudo')`` call systemctl through shellutil, assert that
    they run as root, and remove the units they created at the end.
    """

    def test_get_systemd_version_should_return_a_version_number(self):
        with mock_cgroup_commands():
            version_info = SystemdCgroupsApi.get_systemd_version()
            found = re.search(r"systemd \d+", version_info) is not None
            self.assertTrue(
                found, "Could not determine the systemd version: {0}".format(
                    version_info))

    def test_get_cpu_and_memory_mount_points_should_return_the_cgroup_mount_points(
            self):
        with mock_cgroup_commands():
            cpu, memory = SystemdCgroupsApi().get_cgroup_mount_points()
            self.assertEqual(
                cpu, '/sys/fs/cgroup/cpu,cpuacct',
                "The mount point for the CPU controller is incorrect")
            self.assertEqual(
                memory, '/sys/fs/cgroup/memory',
                "The mount point for the memory controller is incorrect")

    def test_get_cpu_and_memory_cgroup_relative_paths_for_process_should_return_the_cgroup_relative_paths(
            self):
        with mock_cgroup_commands():
            cpu, memory = SystemdCgroupsApi.get_process_cgroup_relative_paths(
                'self')
            self.assertEqual(
                cpu, "system.slice/walinuxagent.service",
                "The relative path for the CPU cgroup is incorrect")
            self.assertEqual(
                memory, "system.slice/walinuxagent.service",
                "The relative path for the memory cgroup is incorrect")

    def test_get_cgroup2_controllers_should_return_the_v2_cgroup_controllers(
            self):
        with mock_cgroup_commands():
            mount_point, controllers = SystemdCgroupsApi.get_cgroup2_controllers()

            self.assertEqual(mount_point, "/sys/fs/cgroup/unified",
                             "Invalid mount point for V2 cgroups")
            self.assertIn(
                "cpu", controllers,
                "The CPU controller is not in the list of V2 controllers")
            self.assertIn(
                "memory", controllers,
                "The memory controller is not in the list of V2 controllers")

    def test_get_unit_property_should_return_the_value_of_the_given_property(
            self):
        with mock_cgroup_commands():
            cpu_accounting = SystemdCgroupsApi.get_unit_property(
                "walinuxagent.service", "CPUAccounting")

            self.assertEqual(
                cpu_accounting, "no",
                "Property {0} of {1} is incorrect".format(
                    "CPUAccounting", "walinuxagent.service"))

    def test_get_extensions_slice_root_name_should_return_the_root_slice_for_extensions(
            self):
        root_slice_name = SystemdCgroupsApi()._get_extensions_slice_root_name()
        self.assertEqual(root_slice_name,
                         "system-walinuxagent.extensions.slice")

    def test_get_extension_slice_name_should_return_the_slice_for_the_given_extension(
            self):
        extension_name = "Microsoft.Azure.DummyExtension-1.0"
        extension_slice_name = SystemdCgroupsApi()._get_extension_slice_name(
            extension_name)
        self.assertEqual(
            extension_slice_name,
            "system-walinuxagent.extensions-Microsoft.Azure.DummyExtension_1.0.slice"
        )

    @attr('requires_sudo')
    def test_create_extension_cgroups_root_should_create_extensions_root_slice(
            self):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        SystemdCgroupsApi().create_extension_cgroups_root()

        unit_name = SystemdCgroupsApi()._get_extensions_slice_root_name()
        _, status = shellutil.run_get_output(
            "systemctl status {0}".format(unit_name))
        self.assertIn("Loaded: loaded", status)
        self.assertIn("Active: active", status)

        # Remove the real slice unit created on this host
        shellutil.run_get_output("systemctl stop {0}".format(unit_name))
        shellutil.run_get_output("systemctl disable {0}".format(unit_name))
        os.remove("/etc/systemd/system/{0}".format(unit_name))
        shellutil.run_get_output("systemctl daemon-reload")

    def test_get_processes_in_cgroup_should_return_the_processes_within_the_cgroup(
            self):
        with mock_cgroup_commands():
            processes = SystemdCgroupsApi.get_processes_in_cgroup(
                "/sys/fs/cgroup/cpu/system.slice/walinuxagent.service")

            self.assertTrue(
                len(processes) >= 2,
                "The cgroup should contain at least 2 processes (daemon and extension handler): [{0}]"
                .format(processes))

            daemon_present = any("waagent -daemon" in command
                                 for (_, command) in processes)
            self.assertTrue(
                daemon_present,
                "Could not find the daemon in the cgroup: [{0}]".format(
                    processes))

            extension_handler_present = any(
                re.search(r"(WALinuxAgent-.+\.egg|waagent) -run-exthandlers",
                          command) for (_, command) in processes)
            self.assertTrue(
                extension_handler_present,
                "Could not find the extension handler in the cgroup: [{0}]".
                format(processes))

    @attr('requires_sudo')
    def test_create_extension_cgroups_should_create_extension_slice(self):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        extension_name = "Microsoft.Azure.DummyExtension-1.0"
        cgroups = SystemdCgroupsApi().create_extension_cgroups(extension_name)
        # The test expects the cpu cgroup first and the memory cgroup second
        cpu_cgroup, memory_cgroup = cgroups[0], cgroups[1]
        self.assertEqual(
            cpu_cgroup.path,
            "/sys/fs/cgroup/cpu/system.slice/Microsoft.Azure.DummyExtension_1.0"
        )
        self.assertEqual(
            memory_cgroup.path,
            "/sys/fs/cgroup/memory/system.slice/Microsoft.Azure.DummyExtension_1.0"
        )

        unit_name = SystemdCgroupsApi()._get_extension_slice_name(
            extension_name)
        self.assertEqual(
            "system-walinuxagent.extensions-Microsoft.Azure.DummyExtension_1.0.slice",
            unit_name)

        _, status = shellutil.run_get_output(
            "systemctl status {0}".format(unit_name))
        self.assertIn("Loaded: loaded", status)
        self.assertIn("Active: active", status)

        # Remove the real slice unit created on this host
        shellutil.run_get_output("systemctl stop {0}".format(unit_name))
        shellutil.run_get_output("systemctl disable {0}".format(unit_name))
        os.remove("/etc/systemd/system/{0}".format(unit_name))
        shellutil.run_get_output("systemctl daemon-reload")

    def assert_cgroups_created(self, extension_cgroups):
        """
        Assert that extension_cgroups holds exactly two cgroups, one cpu and one
        memory, each at
        /sys/fs/cgroup/<controller>/system.slice/Microsoft.Compute.TestExtension_1.2.3_<id>.scope
        where <id> consists of hex digits and dashes.
        """
        self.assertEqual(
            len(extension_cgroups), 2,
            'start_extension_command did not return the expected number of cgroups'
        )

        cpu_found = memory_found = False

        for cgroup in extension_cgroups:
            match = re.match(
                r'^/sys/fs/cgroup/(cpu|memory)/system\.slice/Microsoft\.Compute\.TestExtension_1\.2\.3_([a-f0-9-]+)\.scope$',
                cgroup.path)

            self.assertIsNotNone(
                match, "Unexpected path for cgroup: {0}".format(cgroup.path))

            if match.group(1) == 'cpu':
                cpu_found = True
            if match.group(1) == 'memory':
                memory_found = True

        self.assertTrue(cpu_found,
                        'start_extension_command did not return a cpu cgroup')
        self.assertTrue(
            memory_found,
            'start_extension_command did not return a memory cgroup')

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_return_the_command_output(self, _):
        original_popen = subprocess.Popen

        def mock_popen(command, *args, **kwargs):
            # Replace the systemd-run invocation with a command whose output is known
            if command.startswith(
                    'systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3'
            ):
                command = "echo TEST_OUTPUT"
            return original_popen(command, *args, **kwargs)

        with mock_cgroup_commands():
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as output_file:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                           side_effect=mock_popen):
                    command_output = SystemdCgroupsApi(
                    ).start_extension_command(
                        extension_name="Microsoft.Compute.TestExtension-1.2.3",
                        command="A_TEST_COMMAND",
                        shell=True,
                        timeout=300,
                        cwd=self.tmp_dir,
                        env={},
                        stdout=output_file,
                        stderr=output_file)

                    self.assertIn("[stdout]\nTEST_OUTPUT\n", command_output,
                                  "The test output was not captured")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_execute_the_command_in_a_cgroup(
            self, _):
        with mock_cgroup_commands():
            SystemdCgroupsApi().start_extension_command(
                extension_name="Microsoft.Compute.TestExtension-1.2.3",
                command="test command",
                shell=False,
                timeout=300,
                cwd=self.tmp_dir,
                env={},
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)

            tracked = CGroupsTelemetry._tracked

            self.assertTrue(
                any(cg for cg in tracked
                    if cg.name == 'Microsoft.Compute.TestExtension-1.2.3'
                    and 'cpu' in cg.path),
                "The extension's CPU is not being tracked")
            self.assertTrue(
                any(cg for cg in tracked
                    if cg.name == 'Microsoft.Compute.TestExtension-1.2.3'
                    and 'memory' in cg.path),
                "The extension's memory is not being tracked")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_use_systemd_to_execute_the_command(
            self, _):
        with mock_cgroup_commands():
            with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                       wraps=subprocess.Popen) as popen_patch:
                SystemdCgroupsApi().start_extension_command(
                    extension_name="Microsoft.Compute.TestExtension-1.2.3",
                    command="the-test-extension-command",
                    timeout=300,
                    shell=True,
                    cwd=self.tmp_dir,
                    env={},
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE)

                extension_calls = [
                    args[0] for (args, _) in popen_patch.call_args_list
                    if "the-test-extension-command" in args[0]
                ]

                self.assertEqual(
                    1, len(extension_calls),
                    "The extension should have been invoked exactly once")
                self.assertIn(
                    "systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3",
                    extension_calls[0],
                    "The extension should have been invoked using systemd")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_invoke_the_command_directly_if_systemd_fails(
            self, _):
        original_popen = subprocess.Popen

        def mock_popen(command, *args, **kwargs):
            if command.startswith('systemd-run'):
                # Inject a syntax error to the call
                command = command.replace('systemd-run',
                                          'systemd-run syntax_error')
            return original_popen(command, *args, **kwargs)

        with tempfile.TemporaryFile(dir=self.tmp_dir,
                                    mode="w+b") as output_file:
            with patch("azurelinuxagent.common.cgroupapi.add_event"
                       ) as mock_add_event:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                           side_effect=mock_popen) as popen_patch:
                    CGroupsTelemetry.reset()

                    command = "echo TEST_OUTPUT"

                    command_output = SystemdCgroupsApi(
                    ).start_extension_command(
                        extension_name="Microsoft.Compute.TestExtension-1.2.3",
                        command=command,
                        timeout=300,
                        shell=True,
                        cwd=self.tmp_dir,
                        env={},
                        stdout=output_file,
                        stderr=output_file)

                    _, kwargs = mock_add_event.call_args
                    self.assertIn(
                        "Failed to run systemd-run for unit Microsoft.Compute.TestExtension_1.2.3",
                        kwargs['message'])
                    self.assertIn(
                        "Failed to find executable syntax_error: No such file or directory",
                        kwargs['message'])
                    self.assertEqual(False, kwargs['is_success'])
                    self.assertEqual('InvokeCommandUsingSystemd',
                                     kwargs['op'])

                    extension_calls = [
                        args[0] for (args, _) in popen_patch.call_args_list
                        if command in args[0]
                    ]

                    self.assertEqual(
                        2, len(extension_calls),
                        "The extension should have been invoked exactly twice")
                    self.assertIn(
                        "systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3",
                        extension_calls[0],
                        "The first call to the extension should have used systemd"
                    )
                    self.assertEqual(
                        command, extension_calls[1],
                        "The second call to the extension should not have used systemd"
                    )

                    self.assertEqual(len(CGroupsTelemetry._tracked), 0,
                                     "No cgroups should have been created")

                    self.assertIn("TEST_OUTPUT\n", command_output,
                                  "The test output was not captured")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_invoke_the_command_directly_if_systemd_times_out(
            self, _):
        # Systemd has its own internal timeout which is shorter than what we define for extension operation timeout.
        # When systemd times out, it will write a message to stderr and exit with exit code 1.
        # In that case, we will internally recognize the failure due to the non-zero exit code, not as a timeout.
        original_popen = subprocess.Popen
        systemd_timeout_command = "echo 'Failed to start transient scope unit: Connection timed out' >&2 && exit 1"

        def mock_popen(*args, **kwargs):
            # If trying to invoke systemd, mock what would happen if systemd timed out internally:
            # write failure to stderr and exit with exit code 1.
            new_args = args
            if "systemd-run" in args[0]:
                new_args = (systemd_timeout_command, )

            return original_popen(new_args, **kwargs)

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                           side_effect=mock_popen) as popen_patch:
                    CGroupsTelemetry.reset()

                    SystemdCgroupsApi().start_extension_command(
                        extension_name="Microsoft.Compute.TestExtension-1.2.3",
                        command="echo 'success'",
                        timeout=300,
                        shell=True,
                        cwd=self.tmp_dir,
                        env={},
                        stdout=stdout,
                        stderr=stderr)

                    extension_calls = [
                        args[0] for (args, _) in popen_patch.call_args_list
                        if "echo 'success'" in args[0]
                    ]

                    self.assertEqual(
                        2, len(extension_calls),
                        "The extension should have been invoked exactly twice")
                    self.assertIn(
                        "systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3",
                        extension_calls[0],
                        "The first call to the extension should have used systemd"
                    )
                    self.assertEqual(
                        "echo 'success'", extension_calls[1],
                        "The second call to the extension should not have used systemd"
                    )

                    self.assertEqual(len(CGroupsTelemetry._tracked), 0,
                                     "No cgroups should have been created")

    @attr('requires_sudo')
    @patch("azurelinuxagent.common.cgroupapi.add_event")
    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_not_use_fallback_option_if_extension_fails(
            self, *args):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen", wraps=subprocess.Popen) \
                        as patch_mock_popen:
                    with self.assertRaises(ExtensionError) as context_manager:
                        SystemdCgroupsApi().start_extension_command(
                            extension_name=
                            "Microsoft.Compute.TestExtension-1.2.3",
                            command="ls folder_does_not_exist",
                            timeout=300,
                            shell=True,
                            cwd=self.tmp_dir,
                            env={},
                            stdout=stdout,
                            stderr=stderr)

                    # These checks must be outside the assertRaises block: code after the
                    # raising call inside that block never executes.
                    # We should have invoked the extension command only once, in the systemd-run case
                    self.assertEqual(1, patch_mock_popen.call_count)
                    popen_command = patch_mock_popen.call_args[0][0]
                    self.assertIn("systemd-run --unit", popen_command)

                    self.assertEqual(
                        context_manager.exception.code,
                        ExtensionErrorCodes.PluginUnknownFailure)
                    self.assertIn("Non-zero exit code",
                                  ustr(context_manager.exception))

    @attr('requires_sudo')
    @patch("azurelinuxagent.common.cgroupapi.add_event")
    def test_start_extension_command_should_not_use_fallback_option_if_extension_times_out(
            self, *args):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                # Force the wait helper to report a timeout, and make sure the
                # failure is not classified as a systemd failure
                with patch(
                        "azurelinuxagent.common.utils.extensionprocessutil.wait_for_process_completion_or_timeout",
                        return_value=[True, None]):
                    with patch(
                            "azurelinuxagent.common.cgroupapi.SystemdCgroupsApi._is_systemd_failure",
                            return_value=False):
                        with self.assertRaises(
                                ExtensionError) as context_manager:
                            SystemdCgroupsApi().start_extension_command(
                                extension_name=
                                "Microsoft.Compute.TestExtension-1.2.3",
                                command="date",
                                timeout=300,
                                shell=True,
                                cwd=self.tmp_dir,
                                env={},
                                stdout=stdout,
                                stderr=stderr)

                        self.assertEqual(
                            context_manager.exception.code,
                            ExtensionErrorCodes.PluginHandlerScriptTimedout)
                        self.assertIn("Timeout",
                                      ustr(context_manager.exception))

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_capture_only_the_last_subprocess_output(
            self, _):
        original_popen = subprocess.Popen

        def mock_popen(*args, **kwargs):
            # Inject a syntax error to the call
            systemd_command = args[0].replace('systemd-run',
                                              'systemd-run syntax_error')
            new_args = (systemd_command, )
            return original_popen(new_args, **kwargs)

        expected_output = "[stdout]\n{0}\n\n\n[stderr]\n"

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.add_event"):
                    with patch(
                            "azurelinuxagent.common.cgroupapi.subprocess.Popen",
                            side_effect=mock_popen):
                        # The systemd-run invocation fails because of the injected syntax error;
                        # the returned output should come only from the direct (fallback) run
                        process_output = SystemdCgroupsApi(
                        ).start_extension_command(
                            extension_name=
                            "Microsoft.Compute.TestExtension-1.2.3",
                            command="echo 'very specific test message'",
                            timeout=300,
                            shell=True,
                            cwd=self.tmp_dir,
                            env={},
                            stdout=stdout,
                            stderr=stderr)

                        self.assertEqual(
                            expected_output.format(
                                "very specific test message"), process_output)
Beispiel #8
0
class DownloadExtensionTestCase(AgentTestCase):
    """
    Test cases for ExtHandlerInstance.download(): fetching the extension package, validating it
    as a ZIP file, expanding it into the extension directory, and falling back to the alternate
    package URIs when a download fails.
    """
    @classmethod
    def setUpClass(cls):
        AgentTestCase.setUpClass()
        # CGroupConfigurator, as seen by the exthandlers module, is replaced by a mock for the
        # whole lifetime of this test class
        cls.mock_cgroups = patch(
            "azurelinuxagent.ga.exthandlers.CGroupConfigurator")
        cls.mock_cgroups.start()

    @classmethod
    def tearDownClass(cls):
        cls.mock_cgroups.stop()

        AgentTestCase.tearDownClass()

    def setUp(self):
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.0.0"
        ext_handler = ExtHandler(name='Microsoft.CPlat.Core.RunCommandLinux')
        ext_handler.properties = ext_handler_properties

        protocol = WireProtocol(
            "http://Microsoft.CPlat.Core.RunCommandLinux/foo-bar")

        # Five alternate locations for the same package; they differ only in the storage
        # account index (00a .. 04a)
        uri_template = 'https://zrdfepirv2cy4prdstr{0:02d}a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
        self.pkg = ExtHandlerPackage()
        self.pkg.uris = []
        for index in range(5):
            version_uri = ExtHandlerVersionUri()
            version_uri.uri = uri_template.format(index)
            self.pkg.uris.append(version_uri)

        self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler,
                                                       protocol=protocol)
        self.ext_handler_instance.pkg = self.pkg

        # Redirect the extension, log and lib directories into the test's temporary directory
        self.extension_dir = os.path.join(
            self.tmp_dir, "Microsoft.CPlat.Core.RunCommandLinux-1.0.0")
        self.mock_get_base_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
            return_value=self.extension_dir)
        self.mock_get_base_dir.start()

        self.mock_get_log_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
            return_value=self.tmp_dir)
        self.mock_get_log_dir.start()

        self.agent_dir = self.tmp_dir
        self.mock_get_lib_dir = patch(
            "azurelinuxagent.ga.exthandlers.conf.get_lib_dir",
            return_value=self.agent_dir)
        self.mock_get_lib_dir.start()

    def tearDown(self):
        self.mock_get_lib_dir.stop()
        self.mock_get_log_dir.stop()
        self.mock_get_base_dir.stop()

        AgentTestCase.tearDown(self)

    # Name of the single script stored inside the test extension package
    _extension_command = "RunCommandLinux.sh"

    @staticmethod
    def _create_zip_file(filename):
        """Write a valid ZIP package at *filename* containing one shell-script entry."""
        zip_file = None
        try:
            zip_file = zipfile.ZipFile(filename, "w")
            info = zipfile.ZipInfo(
                DownloadExtensionTestCase._extension_command)
            info.date_time = time.localtime(time.time())[:6]
            info.compress_type = zipfile.ZIP_DEFLATED
            zip_file.writestr(
                info,
                "#!/bin/sh\necho 'RunCommandLinux executed successfully'\n")
        finally:
            if zip_file is not None:
                zip_file.close()

    @staticmethod
    def _create_invalid_zip_file(filename):
        """Write a plain-text file at *filename* that is not a valid ZIP archive."""
        with open(filename, "w") as invalid_file:
            invalid_file.write("An invalid ZIP file\n")

    def _get_extension_package_file(self):
        """Return the path where the downloaded package is expected to be stored."""
        return os.path.join(
            self.agent_dir,
            self.ext_handler_instance.get_extension_package_zipfile_name())

    def _get_extension_command_file(self):
        """Return the path where the expanded extension script is expected to be."""
        return os.path.join(self.extension_dir,
                            DownloadExtensionTestCase._extension_command)

    def _assert_download_and_expand_succeeded(self):
        self.assertTrue(
            os.path.exists(self._get_extension_package_file()),
            "The extension package was not downloaded to the expected location"
        )
        self.assertTrue(
            os.path.exists(self._get_extension_command_file()),
            "The extension package was not expanded to the expected location")

    def test_it_should_download_and_expand_extension_package(self):
        def download_ext_handler_pkg(_uri, destination):
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            with patch(
                    "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.report_event"
            ) as mock_report_event:
                self.ext_handler_instance.download()

        # first download attempt should succeed
        mock_download_ext_handler_pkg.assert_called_once()
        mock_report_event.assert_called_once()

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_existing_extension_package_when_already_downloaded(
            self):
        DownloadExtensionTestCase._create_zip_file(
            self._get_extension_package_file())

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg"
        ) as mock_download_ext_handler_pkg:
            with patch(
                    "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.report_event"
            ) as mock_report_event:
                self.ext_handler_instance.download()

        mock_download_ext_handler_pkg.assert_not_called()
        mock_report_event.assert_not_called()

        self.assertTrue(
            os.path.exists(self._get_extension_command_file()),
            "The extension package was not expanded to the expected location")

    def test_it_should_ignore_existing_extension_package_when_it_is_invalid(
            self):
        def download_ext_handler_pkg(_uri, destination):
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        DownloadExtensionTestCase._create_invalid_zip_file(
            self._get_extension_package_file())

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        mock_download_ext_handler_pkg.assert_called_once()

        self._assert_download_and_expand_succeeded()

    def test_it_should_maintain_extension_handler_state_when_good_zip_exists(
            self):
        DownloadExtensionTestCase._create_zip_file(
            self._get_extension_package_file())
        self.ext_handler_instance.set_handler_state(
            ExtHandlerState.NotInstalled)
        self.ext_handler_instance.download()
        self._assert_download_and_expand_succeeded()
        self.assertTrue(
            os.path.exists(
                os.path.join(self.ext_handler_instance.get_conf_dir(),
                             "HandlerState")),
            "Ensure that the HandlerState file exists on disk")
        self.assertEqual(
            self.ext_handler_instance.get_handler_state(),
            ExtHandlerState.NotInstalled,
            "Ensure that the state is maintained for extension HandlerState")

    def test_it_should_maintain_extension_handler_state_when_bad_zip_exists_and_recovers_with_good_zip(
            self):
        def download_ext_handler_pkg(_uri, destination):
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        DownloadExtensionTestCase._create_invalid_zip_file(
            self._get_extension_package_file())
        self.ext_handler_instance.set_handler_state(
            ExtHandlerState.NotInstalled)

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        mock_download_ext_handler_pkg.assert_called_once()
        self._assert_download_and_expand_succeeded()
        self.assertEqual(
            self.ext_handler_instance.get_handler_state(),
            ExtHandlerState.NotInstalled,
            "Ensure that the state is maintained for extension HandlerState")

    @patch('time.sleep', side_effect=lambda _: mock_sleep(0.001))
    def test_it_should_maintain_extension_handler_state_when_it_downloads_bad_zips(
            self, _):
        def download_ext_handler_pkg(_uri, destination):
            DownloadExtensionTestCase._create_invalid_zip_file(destination)
            return True

        self.ext_handler_instance.set_handler_state(
            ExtHandlerState.NotInstalled)

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg):
            with self.assertRaises(ExtensionDownloadError):
                self.ext_handler_instance.download()

        self.assertFalse(
            os.path.exists(self._get_extension_package_file()),
            "The bad zip extension package should not be downloaded to the expected location"
        )
        self.assertFalse(
            os.path.exists(self._get_extension_command_file()),
            "The extension package should not be expanded to the expected location due to bad zip"
        )
        self.assertEqual(
            self.ext_handler_instance.get_handler_state(),
            ExtHandlerState.NotInstalled,
            "Ensure that the state is maintained for extension HandlerState")

    def test_it_should_use_alternate_uris_when_download_fails(self):
        self.download_failures = 0

        def download_ext_handler_pkg(_uri, destination):
            # fail a few times, then succeed
            if self.download_failures < 3:
                self.download_failures += 1
                return False
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count,
                         self.download_failures + 1)

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_alternate_uris_when_download_raises_an_exception(
            self):
        self.download_failures = 0

        def download_ext_handler_pkg(_uri, destination):
            # fail a few times, then succeed
            if self.download_failures < 3:
                self.download_failures += 1
                raise Exception("Download failed")
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count,
                         self.download_failures + 1)

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_alternate_uris_when_it_downloads_an_invalid_package(
            self):
        self.download_failures = 0

        def download_ext_handler_pkg(_uri, destination):
            # fail a few times, then succeed
            if self.download_failures < 3:
                self.download_failures += 1
                DownloadExtensionTestCase._create_invalid_zip_file(destination)
            else:
                DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count,
                         self.download_failures + 1)

        self._assert_download_and_expand_succeeded()

    def test_it_should_raise_an_exception_when_all_downloads_fail(self):
        def download_ext_handler_pkg(_uri, _destination):
            DownloadExtensionTestCase._create_invalid_zip_file(
                self._get_extension_package_file())
            return True

        with patch("time.sleep", lambda *_: None):
            with patch(
                    "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                    side_effect=download_ext_handler_pkg
            ) as mock_download_ext_handler_pkg:
                with self.assertRaises(
                        ExtensionDownloadError) as context_manager:
                    self.ext_handler_instance.download()

        # every URI is retried NUMBER_OF_DOWNLOAD_RETRIES times before giving up
        self.assertEqual(mock_download_ext_handler_pkg.call_count,
                         NUMBER_OF_DOWNLOAD_RETRIES * len(self.pkg.uris))

        self.assertRegex(str(context_manager.exception),
                         "Failed to download extension")
        self.assertEqual(context_manager.exception.code,
                         ExtensionErrorCodes.PluginManifestDownloadError)

        self.assertFalse(os.path.exists(self.extension_dir),
                         "The extension directory was not removed")
        self.assertFalse(os.path.exists(self._get_extension_package_file()),
                         "The extension package was not removed")
class SystemdCgroupsApiTestCase(AgentTestCase):
    """Tests for SystemdCgroupsApi and the systemd helpers, run inside a mocked cgroup environment."""

    def test_get_systemd_version_should_return_a_version_number(self):
        with mock_cgroup_environment(self.tmp_dir):
            version_info = systemd.get_version()
            found = re.search(r"systemd \d+", version_info) is not None
            self.assertTrue(
                found, "Could not determine the systemd version: {0}".format(
                    version_info))

    def test_get_cpu_and_memory_mount_points_should_return_the_cgroup_mount_points(
            self):
        with mock_cgroup_environment(self.tmp_dir):
            cpu, memory = SystemdCgroupsApi().get_cgroup_mount_points()
            self.assertEqual(
                cpu, '/sys/fs/cgroup/cpu,cpuacct',
                "The mount point for the CPU controller is incorrect")
            self.assertEqual(
                memory, '/sys/fs/cgroup/memory',
                "The mount point for the memory controller is incorrect")

    def test_get_cpu_and_memory_cgroup_relative_paths_for_process_should_return_the_cgroup_relative_paths(
            self):
        with mock_cgroup_environment(self.tmp_dir):
            cpu, memory = SystemdCgroupsApi.get_process_cgroup_relative_paths(
                'self')
            self.assertEqual(
                cpu, "system.slice/walinuxagent.service",
                "The relative path for the CPU cgroup is incorrect")
            self.assertEqual(
                memory, "system.slice/walinuxagent.service",
                "The relative path for the memory cgroup is incorrect")

    def test_get_cgroup2_controllers_should_return_the_v2_cgroup_controllers(
            self):
        with mock_cgroup_environment(self.tmp_dir):
            mount_point, controllers = SystemdCgroupsApi.get_cgroup2_controllers(
            )

            self.assertEqual(mount_point, "/sys/fs/cgroup/unified",
                             "Invalid mount point for V2 cgroups")
            self.assertIn(
                "cpu", controllers,
                "The CPU controller is not in the list of V2 controllers")
            self.assertIn(
                "memory", controllers,
                "The memory controller is not in the list of V2 controllers")

    def test_get_unit_property_should_return_the_value_of_the_given_property(
            self):
        with mock_cgroup_environment(self.tmp_dir):
            cpu_accounting = systemd.get_unit_property("walinuxagent.service",
                                                       "CPUAccounting")

            self.assertEqual(
                cpu_accounting, "no",
                "Property {0} of {1} is incorrect".format(
                    "CPUAccounting", "walinuxagent.service"))

    def assert_cgroups_created(self, extension_cgroups):
        """
        Assert that *extension_cgroups* has exactly two entries, one cpu and one memory cgroup,
        each located under the TestExtension scope unit in system.slice.
        """
        self.assertEqual(
            len(extension_cgroups), 2,
            'start_extension_command did not return the expected number of cgroups'
        )

        cpu_found = memory_found = False

        for cgroup in extension_cgroups:
            # The dot in "system.slice" is escaped so it only matches a literal '.'
            match = re.match(
                r'^/sys/fs/cgroup/(cpu|memory)/system\.slice/Microsoft.Compute.TestExtension_1\.2\.3_([a-f0-9-]+)\.scope$',
                cgroup.path)

            self.assertTrue(
                match is not None,
                "Unexpected path for cgroup: {0}".format(cgroup.path))

            if match.group(1) == 'cpu':
                cpu_found = True
            if match.group(1) == 'memory':
                memory_found = True

        self.assertTrue(cpu_found,
                        'start_extension_command did not return a cpu cgroup')
        self.assertTrue(
            memory_found,
            'start_extension_command did not return a memory cgroup')

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_return_the_command_output(self, _):
        original_popen = subprocess.Popen

        def mock_popen(command, *args, **kwargs):
            # Replace the systemd-run invocation with a command whose output is known
            if command.startswith(
                    'systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3'
            ):
                command = "echo TEST_OUTPUT"
            return original_popen(command, *args, **kwargs)

        with mock_cgroup_environment(self.tmp_dir):
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as output_file:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                           side_effect=mock_popen):
                    command_output = SystemdCgroupsApi(
                    ).start_extension_command(
                        extension_name="Microsoft.Compute.TestExtension-1.2.3",
                        command="A_TEST_COMMAND",
                        shell=True,
                        timeout=300,
                        cwd=self.tmp_dir,
                        env={},
                        stdout=output_file,
                        stderr=output_file)

                    self.assertIn("[stdout]\nTEST_OUTPUT\n", command_output,
                                  "The test output was not captured")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_execute_the_command_in_a_cgroup(
            self, _):
        with mock_cgroup_environment(self.tmp_dir):
            SystemdCgroupsApi().start_extension_command(
                extension_name="Microsoft.Compute.TestExtension-1.2.3",
                command="test command",
                shell=False,
                timeout=300,
                cwd=self.tmp_dir,
                env={},
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)

            tracked = CGroupsTelemetry._tracked

            self.assertTrue(
                any(cg for cg in tracked
                    if cg.name == 'Microsoft.Compute.TestExtension-1.2.3'
                    and 'cpu' in cg.path),
                "The extension's CPU is not being tracked")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_use_systemd_to_execute_the_command(
            self, _):
        with mock_cgroup_environment(self.tmp_dir):
            with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                       wraps=subprocess.Popen) as popen_patch:
                SystemdCgroupsApi().start_extension_command(
                    extension_name="Microsoft.Compute.TestExtension-1.2.3",
                    command="the-test-extension-command",
                    timeout=300,
                    shell=True,
                    cwd=self.tmp_dir,
                    env={},
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE)

                # Keep only the Popen calls whose command line mentions the extension command
                extension_calls = [
                    args[0] for (args, _) in popen_patch.call_args_list
                    if "the-test-extension-command" in args[0]
                ]

                self.assertEqual(
                    1, len(extension_calls),
                    "The extension should have been invoked exactly once")
                self.assertIn(
                    "systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3",
                    extension_calls[0],
                    "The extension should have been invoked using systemd")
class CGroupConfiguratorSystemdTestCase(AgentTestCase):
    @classmethod
    def tearDownClass(cls):
        """Drop the cached CGroupConfigurator singleton, then run the base class teardown."""
        CGroupConfigurator._instance = None
        AgentTestCase.tearDownClass()

    @contextlib.contextmanager
    def _get_cgroup_configurator(self,
                                 initialize=True,
                                 enable=True,
                                 mock_commands=None):
        """
        Yield a newly created CGroupConfigurator that runs inside a mocked cgroup environment.

        :param initialize: when True, configurator.initialize() is called before yielding
        :param enable: when False (and initialize is True), enable() is patched out while
                       initialize() runs
        :param mock_commands: optional iterable of commands added to the mock environment
        """
        # Reset the singleton so get_instance() builds a fresh configurator
        CGroupConfigurator._instance = None
        new_configurator = CGroupConfigurator.get_instance()
        CGroupsTelemetry.reset()
        with mock_cgroup_environment(self.tmp_dir) as environment:
            if mock_commands is not None:
                for extra_command in mock_commands:
                    environment.add_command(extra_command)
            # Expose the mock environment so tests can resolve mapped file paths
            new_configurator.mocks = environment
            if initialize:
                if enable:
                    new_configurator.initialize()
                else:
                    with patch.object(new_configurator, "enable"):
                        new_configurator.initialize()
            yield new_configurator

    def test_initialize_should_enable_cgroups(self):
        # _get_cgroup_configurator() calls initialize() by default
        with self._get_cgroup_configurator() as configurator:
            is_enabled = configurator.enabled()
            self.assertTrue(is_enabled, "cgroups were not enabled")

    def test_initialize_should_start_tracking_the_agent_cgroups(self):
        with self._get_cgroup_configurator() as configurator:
            tracked_cgroups = CGroupsTelemetry._tracked

            self.assertTrue(configurator.enabled(),
                            "Cgroups should be enabled")

            agent_cpu_is_tracked = any(
                cgroup for cgroup in tracked_cgroups
                if cgroup.name == AGENT_NAME_TELEMETRY and 'cpu' in cgroup.path)
            failure_message = "The Agent's CPU is not being tracked. Tracked: {0}".format(
                tracked_cgroups)
            self.assertTrue(agent_cpu_is_tracked, failure_message)

    def test_initialize_should_start_tracking_other_controllers_when_one_is_not_present(
            self):
        command_mocks = [
            MockCommand(
                r"^mount -t cgroup$",
                '''cgroup on /sys/fs/cgroup/systemd type cgroup (rw,nosuid,nodev,noexec,relatime,xattr,name=systemd)
cgroup on /sys/fs/cgroup/rdma type cgroup (rw,nosuid,nodev,noexec,relatime,rdma)
cgroup on /sys/fs/cgroup/cpuset type cgroup (rw,nosuid,nodev,noexec,relatime,cpuset)
cgroup on /sys/fs/cgroup/net_cls,net_prio type cgroup (rw,nosuid,nodev,noexec,relatime,net_cls,net_prio)
cgroup on /sys/fs/cgroup/perf_event type cgroup (rw,nosuid,nodev,noexec,relatime,perf_event)
cgroup on /sys/fs/cgroup/hugetlb type cgroup (rw,nosuid,nodev,noexec,relatime,hugetlb)
cgroup on /sys/fs/cgroup/freezer type cgroup (rw,nosuid,nodev,noexec,relatime,freezer)
cgroup on /sys/fs/cgroup/pids type cgroup (rw,nosuid,nodev,noexec,relatime,pids)
cgroup on /sys/fs/cgroup/devices type cgroup (rw,nosuid,nodev,noexec,relatime,devices)
cgroup on /sys/fs/cgroup/cpu,cpuacct type cgroup (rw,nosuid,nodev,noexec,relatime,cpu,cpuacct)
cgroup on /sys/fs/cgroup/blkio type cgroup (rw,nosuid,nodev,noexec,relatime,blkio)
''')
        ]
        with self._get_cgroup_configurator(
                mock_commands=command_mocks) as configurator:
            tracked = CGroupsTelemetry._tracked

            self.assertTrue(configurator.enabled(),
                            "Cgroups should be enabled")
            self.assertFalse(
                any(cg for cg in tracked if cg.name == 'walinuxagent.service'
                    and 'memory' in cg.path),
                "The Agent's memory should not be tracked. Tracked: {0}".
                format(tracked))

    def test_initialize_should_not_enable_cgroups_when_the_cpu_and_memory_controllers_are_not_present(
            self):
        command_mocks = [
            MockCommand(
                r"^mount -t cgroup$",
                '''cgroup on /sys/fs/cgroup/systemd type cgroup (rw,nosuid,nodev,noexec,relatime,xattr,name=systemd)
cgroup on /sys/fs/cgroup/rdma type cgroup (rw,nosuid,nodev,noexec,relatime,rdma)
cgroup on /sys/fs/cgroup/cpuset type cgroup (rw,nosuid,nodev,noexec,relatime,cpuset)
cgroup on /sys/fs/cgroup/net_cls,net_prio type cgroup (rw,nosuid,nodev,noexec,relatime,net_cls,net_prio)
cgroup on /sys/fs/cgroup/perf_event type cgroup (rw,nosuid,nodev,noexec,relatime,perf_event)
cgroup on /sys/fs/cgroup/hugetlb type cgroup (rw,nosuid,nodev,noexec,relatime,hugetlb)
cgroup on /sys/fs/cgroup/freezer type cgroup (rw,nosuid,nodev,noexec,relatime,freezer)
cgroup on /sys/fs/cgroup/pids type cgroup (rw,nosuid,nodev,noexec,relatime,pids)
cgroup on /sys/fs/cgroup/devices type cgroup (rw,nosuid,nodev,noexec,relatime,devices)
cgroup on /sys/fs/cgroup/blkio type cgroup (rw,nosuid,nodev,noexec,relatime,blkio)
''')
        ]
        with self._get_cgroup_configurator(
                mock_commands=command_mocks) as configurator:
            tracked = CGroupsTelemetry._tracked

            self.assertFalse(configurator.enabled(),
                             "Cgroups should not be enabled")
            self.assertEqual(
                len(tracked), 0,
                "No cgroups should be tracked. Tracked: {0}".format(tracked))

    def test_initialize_should_not_enable_cgroups_when_the_agent_is_not_in_the_system_slice(
            self):
        command_mocks = [
            MockCommand(
                r"^mount -t cgroup$",
                '''cgroup on /sys/fs/cgroup/systemd type cgroup (rw,nosuid,nodev,noexec,relatime,xattr,name=systemd)
cgroup on /sys/fs/cgroup/rdma type cgroup (rw,nosuid,nodev,noexec,relatime,rdma)
cgroup on /sys/fs/cgroup/cpuset type cgroup (rw,nosuid,nodev,noexec,relatime,cpuset)
cgroup on /sys/fs/cgroup/net_cls,net_prio type cgroup (rw,nosuid,nodev,noexec,relatime,net_cls,net_prio)
cgroup on /sys/fs/cgroup/perf_event type cgroup (rw,nosuid,nodev,noexec,relatime,perf_event)
cgroup on /sys/fs/cgroup/hugetlb type cgroup (rw,nosuid,nodev,noexec,relatime,hugetlb)
cgroup on /sys/fs/cgroup/freezer type cgroup (rw,nosuid,nodev,noexec,relatime,freezer)
cgroup on /sys/fs/cgroup/pids type cgroup (rw,nosuid,nodev,noexec,relatime,pids)
cgroup on /sys/fs/cgroup/devices type cgroup (rw,nosuid,nodev,noexec,relatime,devices)
cgroup on /sys/fs/cgroup/blkio type cgroup (rw,nosuid,nodev,noexec,relatime,blkio)
''')
        ]

        with self._get_cgroup_configurator(
                mock_commands=command_mocks) as configurator:
            tracked = CGroupsTelemetry._tracked
            agent_drop_in_file_cpu_quota = configurator.mocks.get_mapped_path(
                UnitFilePaths.cpu_quota)

            self.assertFalse(configurator.enabled(),
                             "Cgroups should not be enabled")
            self.assertEqual(
                len(tracked), 0,
                "No cgroups should be tracked. Tracked: {0}".format(tracked))
            self.assertFalse(
                os.path.exists(agent_drop_in_file_cpu_quota),
                "{0} should not have been created".format(
                    agent_drop_in_file_cpu_quota))

    def test_initialize_should_not_create_unit_files(self):
        with self._get_cgroup_configurator() as configurator:
            # resolve the paths of the mocked unit files
            mocks = configurator.mocks
            azure_slice_unit_file = mocks.get_mapped_path(UnitFilePaths.azure)
            extensions_slice_unit_file = mocks.get_mapped_path(
                UnitFilePaths.vmextensions)
            agent_drop_in_file_slice = mocks.get_mapped_path(
                UnitFilePaths.slice)
            agent_drop_in_file_cpu_accounting = mocks.get_mapped_path(
                UnitFilePaths.cpu_accounting)

            # The mock creates the slice unit files; delete them
            os.remove(azure_slice_unit_file)
            os.remove(extensions_slice_unit_file)

            # The service file for the agent includes settings for the slice and cpu accounting,
            # but not for cpu quota; initialize() should not create drop-in files for the first
            # two, but it should create one for the cpu quota.
            # NOTE(review): initialize() has already run by this point and the two slice files
            # were deleted just above, so the first two checks hold trivially. Consider using
            # initialize=False, deleting the files, then calling initialize() before asserting.
            for unit_file in (azure_slice_unit_file,
                              extensions_slice_unit_file,
                              agent_drop_in_file_slice,
                              agent_drop_in_file_cpu_accounting):
                self.assertFalse(
                    os.path.exists(unit_file),
                    "{0} should not have been created".format(unit_file))

    def test_initialize_should_create_unit_files_when_the_agent_service_file_is_not_updated(
            self):
        with self._get_cgroup_configurator(initialize=False) as configurator:
            # resolve the paths of the mocked unit files
            mocks = configurator.mocks
            azure_slice_unit_file = mocks.get_mapped_path(UnitFilePaths.azure)
            extensions_slice_unit_file = mocks.get_mapped_path(
                UnitFilePaths.vmextensions)
            agent_drop_in_file_slice = mocks.get_mapped_path(
                UnitFilePaths.slice)
            agent_drop_in_file_cpu_accounting = mocks.get_mapped_path(
                UnitFilePaths.cpu_accounting)

            # The mock creates the service and slice unit files; swap in the older service file
            # and delete the slice files
            previous_service_file = os.path.join(
                data_dir, 'init', "walinuxagent.service.previous")
            mocks.add_data_file(previous_service_file,
                                UnitFilePaths.walinuxagent)
            os.remove(azure_slice_unit_file)
            os.remove(extensions_slice_unit_file)

            configurator.initialize()

            # The older service file did not set the slice or cpu parameters, so initialize()
            # is expected to create the unit/drop-in files that set them
            for unit_file in (azure_slice_unit_file,
                              extensions_slice_unit_file,
                              agent_drop_in_file_slice,
                              agent_drop_in_file_cpu_accounting):
                self.assertTrue(os.path.exists(unit_file),
                                "{0} was not created".format(unit_file))

    def test_enable_should_raise_cgroups_exception_when_cgroups_are_not_supported(
            self):
        # enable() must refuse to proceed when supported() reports False
        expected_message = "Attempted to enable cgroups, but they are not supported on the current platform"

        with self._get_cgroup_configurator(enable=False) as configurator:
            with patch.object(configurator, "supported", return_value=False):
                with self.assertRaises(CGroupsException) as raised:
                    configurator.enable()
                self.assertIn(expected_message, str(raised.exception))

    def test_enable_should_set_agent_cpu_quota_and_track_throttled_time(self):
        # enable() should write the agent's CPUQuota drop-in and switch on throttled-time tracking
        with self._get_cgroup_configurator(enable=False) as configurator:
            cpu_quota_drop_in = configurator.mocks.get_mapped_path(
                UnitFilePaths.cpu_quota)
            # Sanity check: a pre-existing drop-in would make the assertions below meaningless
            if os.path.exists(cpu_quota_drop_in):
                raise Exception(
                    "{0} should not have been created during test setup".format(
                        cpu_quota_drop_in))

            configurator.enable()

            expected_quota = "CPUQuota={0}%".format(_AGENT_CPU_QUOTA)
            self.assertTrue(os.path.exists(cpu_quota_drop_in),
                            "{0} was not created".format(cpu_quota_drop_in))

            quota_found = fileutil.findre_in_file(cpu_quota_drop_in,
                                                  expected_quota)
            drop_in_contents = fileutil.read_file(cpu_quota_drop_in)
            failure_message = "CPUQuota was not set correctly. Expected: {0}. Got:\n{1}".format(
                expected_quota, drop_in_contents)
            self.assertTrue(quota_found, failure_message)

            self.assertTrue(CGroupsTelemetry.get_track_throttled_time(),
                            "Throttle time should be tracked")

    def test_enable_should_not_track_throttled_time_when_setting_the_cpu_quota_fails(
            self):
        # If the CPUQuota drop-in cannot be written, enable() must leave throttled-time tracking off
        with self._get_cgroup_configurator(enable=False) as configurator:
            tracking_before_enable = CGroupsTelemetry.get_track_throttled_time()
            if tracking_before_enable:
                raise Exception(
                    "Test setup should not start tracking Throttle Time")

            # Register an Exception instead of file contents for the mocked cpu_quota drop-in, so that
            # setting the quota fails
            quota_failure = Exception("A TEST EXCEPTION")
            configurator.mocks.add_file(UnitFilePaths.cpu_quota, quota_failure)

            configurator.enable()

            tracking_after_enable = CGroupsTelemetry.get_track_throttled_time()
            self.assertFalse(tracking_after_enable,
                             "Throttle time should not be tracked")

    def test_disable_should_reset_cpu_quota_and_tracked_cgroups(self):
        # disable() should reset the agent's CPUQuota (empty value), stop tracking every cgroup and stop
        # tracking throttled time.
        with self._get_cgroup_configurator() as configurator:
            if len(CGroupsTelemetry._tracked) == 0:
                raise Exception(
                    "Test setup should have started tracking at least 1 cgroup (the agent's)"
                )
            # Use the public accessor, consistent with the other throttled-time tests in this class
            if not CGroupsTelemetry.get_track_throttled_time():
                raise Exception(
                    "Test setup should have started tracking Throttle Time")

            configurator.disable("UNIT TEST")

            agent_drop_in_file_cpu_quota = configurator.mocks.get_mapped_path(
                UnitFilePaths.cpu_quota)
            self.assertTrue(
                os.path.exists(agent_drop_in_file_cpu_quota),
                "{0} was not created".format(agent_drop_in_file_cpu_quota))
            # An empty "CPUQuota=" line is what the test expects after the quota is reset
            self.assertTrue(
                fileutil.findre_in_file(agent_drop_in_file_cpu_quota,
                                        "^CPUQuota=$"),
                "CPUQuota was not set correctly. Expected an empty value. Got:\n{0}"
                .format(fileutil.read_file(agent_drop_in_file_cpu_quota)))
            self.assertEqual(
                len(CGroupsTelemetry._tracked), 0,
                "No cgroups should be tracked after disable. Tracking: {0}".
                format(CGroupsTelemetry._tracked))
            self.assertFalse(
                CGroupsTelemetry.get_track_throttled_time(),
                "Throttle Time should not be tracked after disable")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_not_use_systemd_when_cgroups_are_not_enabled(
            self, _):
        # With cgroups disabled, the extension command must reach Popen unchanged (no systemd-run wrapper)
        with self._get_cgroup_configurator() as configurator:
            configurator.disable("UNIT TEST")

            with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                       wraps=subprocess.Popen) as popen_spy:
                configurator.start_extension_command(
                    extension_name="Microsoft.Compute.TestExtension-1.2.3",
                    command="date",
                    timeout=300,
                    shell=False,
                    cwd=self.tmp_dir,
                    env={},
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE)

                # Collect the first positional argument of every Popen call that mentions the test command
                command_calls = []
                for positional, _kwargs in popen_spy.call_args_list:
                    if positional and "date" in positional[0]:
                        command_calls.append(positional[0])

                self.assertEqual(
                    len(command_calls), 1,
                    "The test command should have been called exactly once [{0}]"
                    .format(command_calls))
                first_call = command_calls[0]
                self.assertNotIn(
                    "systemd-run", first_call,
                    "The command should not have been invoked using systemd")
                self.assertEqual(
                    first_call, "date",
                    "The command line should not have been modified")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_use_systemd_run_when_cgroups_are_enabled(
            self, _):
        # With cgroups enabled, the extension must be launched through systemd-run with a --unit name
        # derived from the extension name
        extension_command = "the-test-extension-command"

        with self._get_cgroup_configurator() as configurator:
            with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                       wraps=subprocess.Popen) as popen_spy:
                configurator.start_extension_command(
                    extension_name="Microsoft.Compute.TestExtension-1.2.3",
                    command=extension_command,
                    timeout=300,
                    shell=False,
                    cwd=self.tmp_dir,
                    env={},
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE)

                command_calls = [
                    positional[0]
                    for (positional, _kwargs) in popen_spy.call_args_list
                    if extension_command in positional[0]
                ]

                once_message = "The test command should have been called exactly once [{0}]".format(
                    command_calls)
                self.assertEqual(len(command_calls), 1, once_message)
                self.assertIn(
                    "systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3",
                    command_calls[0],
                    "The extension should have been invoked using systemd")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_start_tracking_the_extension_cgroups(
            self, _):
        # Tracking a CPU cgroup initializes its CPU usage. This test never reads the usage, so that
        # initialization needs no special handling.
        with self._get_cgroup_configurator() as configurator:
            configurator.start_extension_command(
                extension_name="Microsoft.Compute.TestExtension-1.2.3",
                command="test command",
                timeout=300,
                shell=False,
                cwd=self.tmp_dir,
                env={},
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)

        # Tracked cgroups that belong to the extension and whose path contains 'cpu'
        extension_cpu_cgroups = [
            cgroup for cgroup in CGroupsTelemetry._tracked
            if cgroup.name == 'Microsoft.Compute.TestExtension-1.2.3'
            and 'cpu' in cgroup.path
        ]

        self.assertTrue(any(extension_cpu_cgroups),
                        "The extension's CPU is not being tracked")

    def test_start_extension_command_should_raise_an_exception_when_the_command_cannot_be_started(
            self):
        # If Popen fails for the extension's command, start_extension_command() must propagate the error.
        with self._get_cgroup_configurator() as configurator:
            original_popen = subprocess.Popen

            # Fail only for the extension's command; every other Popen call goes through unchanged
            def mock_popen(command_arg, *args, **kwargs):
                if "test command" in command_arg:
                    raise Exception("A TEST EXCEPTION")
                return original_popen(command_arg, *args, **kwargs)

            with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                       side_effect=mock_popen):
                with self.assertRaises(Exception) as context_manager:
                    configurator.start_extension_command(
                        extension_name="Microsoft.Compute.TestExtension-1.2.3",
                        command="test command",
                        timeout=300,
                        shell=False,
                        cwd=self.tmp_dir,
                        env={},
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE)

                # Must be outside the assertRaises block: statements after the raising call inside that
                # block never execute, so the message would otherwise go unchecked.
                self.assertIn("A TEST EXCEPTION",
                              str(context_manager.exception))

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_disable_cgroups_and_invoke_the_command_directly_if_systemd_fails(
            self, _):
        # When the systemd-run invocation fails, the configurator should:
        #   - disable cgroups and emit exactly one failed CGroupsDisabled event,
        #   - re-run the extension's command directly (without systemd-run),
        #   - return the output of that direct run.
        with self._get_cgroup_configurator() as configurator:
            original_popen = subprocess.Popen

            def mock_popen(command, *args, **kwargs):
                if 'systemd-run' in command:
                    # Inject a syntax error into the systemd-run invocation so that it fails
                    command = command.replace('systemd-run',
                                              'systemd-run syntax_error')
                return original_popen(command, *args, **kwargs)

            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as output_file:
                with patch(
                        "azurelinuxagent.common.cgroupconfigurator.add_event"
                ) as mock_add_event:
                    with patch(
                            "azurelinuxagent.common.cgroupapi.subprocess.Popen",
                            side_effect=mock_popen) as popen_patch:
                        CGroupsTelemetry.reset()

                        command = "echo TEST_OUTPUT"

                        command_output = configurator.start_extension_command(
                            extension_name=
                            "Microsoft.Compute.TestExtension-1.2.3",
                            command=command,
                            timeout=300,
                            shell=True,
                            cwd=self.tmp_dir,
                            env={},
                            stdout=output_file,
                            stderr=output_file)

                        self.assertFalse(configurator.enabled(),
                                         "Cgroups should have been disabled")

                        # .get() skips add_event calls that do not pass 'op' by keyword instead of
                        # raising KeyError
                        disabled_events = [
                            kwargs
                            for _call_args, kwargs in mock_add_event.call_args_list
                            if kwargs.get('op') == WALAEventOperation.CGroupsDisabled
                        ]

                        # assertEqual reports the actual count on failure (assertTrue(len == 1) does not)
                        self.assertEqual(
                            1, len(disabled_events),
                            "Exactly one CGroupsDisabled telemetry event should have been issued. Found: {0}"
                            .format(disabled_events))
                        self.assertIn(
                            "Failed to start Microsoft.Compute.TestExtension-1.2.3 using systemd-run",
                            disabled_events[0]['message'],
                            "The systemd-run failure was not included in the telemetry message"
                        )
                        self.assertEqual(
                            False, disabled_events[0]['is_success'],
                            "The telemetry event should indicate a failure")

                        # Guard against Popen calls made without positional arguments
                        extension_calls = [
                            args[0] for (args, _kwargs) in popen_patch.call_args_list
                            if args and command in args[0]
                        ]

                        self.assertEqual(
                            2, len(extension_calls),
                            "The extension should have been invoked exactly twice"
                        )
                        self.assertIn(
                            "systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3",
                            extension_calls[0],
                            "The first call to the extension should have used systemd"
                        )
                        self.assertEqual(
                            command, extension_calls[1],
                            "The second call to the extension should not have used systemd"
                        )

                        self.assertEqual(
                            len(CGroupsTelemetry._tracked), 0,
                            "No cgroups should have been created")

                        self.assertIn("TEST_OUTPUT\n", command_output,
                                      "The test output was not captured")

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_disable_cgroups_and_invoke_the_command_directly_if_systemd_times_out(
            self, _):
        with self._get_cgroup_configurator() as configurator:
            # systemd has an internal timeout shorter than the extension operation timeout. When it
            # expires, systemd writes an error to stderr and exits with code 1. The agent therefore sees a
            # non-zero exit code rather than a timeout. The mocked systemd-run below reproduces that.
            systemd_run_timing_out = MockCommand(
                "systemd-run",
                return_value=1,
                stdout='',
                stderr='Failed to start transient scope unit: Connection timed out')
            configurator.mocks.add_command(systemd_run_timing_out)

            extension_command = "echo 'success'"

            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stdout:
                with tempfile.TemporaryFile(dir=self.tmp_dir,
                                            mode="w+b") as stderr:
                    with patch("subprocess.Popen",
                               wraps=subprocess.Popen) as popen_patch:
                        CGroupsTelemetry.reset()

                        configurator.start_extension_command(
                            extension_name=
                            "Microsoft.Compute.TestExtension-1.2.3",
                            command=extension_command,
                            timeout=300,
                            shell=True,
                            cwd=self.tmp_dir,
                            env={},
                            stdout=stdout,
                            stderr=stderr)

                        self.assertFalse(configurator.enabled(),
                                         "Cgroups should have been disabled")

                        extension_calls = [
                            call_args[0]
                            for (call_args, _kwargs) in popen_patch.call_args_list
                            if extension_command in call_args[0]
                        ]
                        self.assertEqual(
                            2, len(extension_calls),
                            "The extension should have been called twice. Got: {0}"
                            .format(extension_calls))

                        systemd_call, direct_call = extension_calls[0], extension_calls[1]
                        self.assertIn(
                            "systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3",
                            systemd_call,
                            "The first call to the extension should have used systemd"
                        )
                        self.assertNotIn(
                            "systemd-run", direct_call,
                            "The second call to the extension should not have used systemd"
                        )

                        self.assertEqual(
                            len(CGroupsTelemetry._tracked), 0,
                            "No cgroups should have been created")

    @attr('requires_sudo')
    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_not_use_fallback_option_if_extension_fails(
            self, *args):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        # Leaving the context immediately releases the mocks used to build the CGroupConfigurator, so they
        # do not conflict with the Popen spy below. The configurator object is still used afterwards.
        with self._get_cgroup_configurator() as configurator:
            pass

        failing_command = "ls folder_does_not_exist"

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                           wraps=subprocess.Popen) as popen_patch:
                    with self.assertRaises(ExtensionError) as raised:
                        configurator.start_extension_command(
                            extension_name=
                            "Microsoft.Compute.TestExtension-1.2.3",
                            command=failing_command,
                            timeout=300,
                            shell=True,
                            cwd=self.tmp_dir,
                            env={},
                            stdout=stdout,
                            stderr=stderr)

                    extension_calls = [
                        call_args[0]
                        for (call_args, _kwargs) in popen_patch.call_args_list
                        if failing_command in call_args[0]
                    ]

                    # A failure of the extension itself must not trigger the direct-invocation fallback
                    self.assertEqual(
                        1, len(extension_calls),
                        "The extension should have been invoked exactly once")
                    self.assertIn(
                        "systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3",
                        extension_calls[0],
                        "The first call to the extension should have used systemd"
                    )

                    error = raised.exception
                    error_text = ustr(error)
                    self.assertEqual(error.code,
                                     ExtensionErrorCodes.PluginUnknownFailure)
                    self.assertIn("Non-zero exit code", error_text)
                    # systemd-run ran and stderr was not truncated, so the scope name shows up in the
                    # reported process output
                    self.assertIn("Microsoft.Compute.TestExtension_1.2.3",
                                  error_text)

    @attr('requires_sudo')
    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    @patch(
        "azurelinuxagent.common.utils.extensionprocessutil.TELEMETRY_MESSAGE_MAX_LEN",
        5)
    def test_start_extension_command_should_not_use_fallback_option_if_extension_fails_with_long_output(
            self, *args):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        # Leaving the context immediately releases the mocks used to build the CGroupConfigurator, so they
        # do not conflict with the Popen spy below. The configurator object is still used afterwards.
        with self._get_cgroup_configurator() as configurator:
            pass

        # 20 characters on both streams exceeds the patched TELEMETRY_MESSAGE_MAX_LEN of 5, so both
        # stdout and stderr get truncated
        noisy_text = "a" * 20
        noisy_failing_command = "echo {0} && echo {0} >&2 && ls folder_does_not_exist".format(
            noisy_text)

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                           wraps=subprocess.Popen) as popen_patch:
                    with self.assertRaises(ExtensionError) as raised:
                        configurator.start_extension_command(
                            extension_name=
                            "Microsoft.Compute.TestExtension-1.2.3",
                            command=noisy_failing_command,
                            timeout=300,
                            shell=True,
                            cwd=self.tmp_dir,
                            env={},
                            stdout=stdout,
                            stderr=stderr)

                    extension_calls = [
                        call_args[0]
                        for (call_args, _kwargs) in popen_patch.call_args_list
                        if noisy_failing_command in call_args[0]
                    ]

                    self.assertEqual(
                        1, len(extension_calls),
                        "The extension should have been invoked exactly once")
                    self.assertIn(
                        "systemd-run --unit=Microsoft.Compute.TestExtension_1.2.3",
                        extension_calls[0],
                        "The first call to the extension should have used systemd"
                    )

                    error = raised.exception
                    error_text = ustr(error)
                    self.assertEqual(error.code,
                                     ExtensionErrorCodes.PluginUnknownFailure)
                    self.assertIn("Non-zero exit code", error_text)
                    # The output was truncated, so the scope name is absent from the message even though
                    # systemd-run did run
                    self.assertNotIn("Microsoft.Compute.TestExtension_1.2.3",
                                     error_text)

    @attr('requires_sudo')
    def test_start_extension_command_should_not_use_fallback_option_if_extension_times_out(
            self, *args):  # pylint: disable=unused-argument
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        with self._get_cgroup_configurator() as configurator:
            pass  # release the mocks used to create the test CGroupConfigurator so that they do not conflict the mock Popen below

        # Make the completion wait return [True, None] (presumably "timed out, no exit code"; confirm
        # against extensionprocessutil). Also force _is_systemd_failure to report False, so the failure
        # is attributed to the extension rather than to systemd.
        wait_for_completion_target = "azurelinuxagent.common.utils.extensionprocessutil.wait_for_process_completion_or_timeout"
        is_systemd_failure_target = "azurelinuxagent.common.cgroupapi.SystemdCgroupsApi._is_systemd_failure"

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch(wait_for_completion_target,
                           return_value=[True, None]):
                    with patch(is_systemd_failure_target, return_value=False):
                        with self.assertRaises(ExtensionError) as raised:
                            configurator.start_extension_command(
                                extension_name=
                                "Microsoft.Compute.TestExtension-1.2.3",
                                command="date",
                                timeout=300,
                                shell=True,
                                cwd=self.tmp_dir,
                                env={},
                                stdout=stdout,
                                stderr=stderr)

                        timeout_error = raised.exception
                        self.assertEqual(
                            timeout_error.code,
                            ExtensionErrorCodes.PluginHandlerScriptTimedout)
                        self.assertIn("Timeout", ustr(timeout_error))

    @patch('time.sleep', side_effect=lambda _: mock_sleep())
    def test_start_extension_command_should_capture_only_the_last_subprocess_output(
            self, _):
        with self._get_cgroup_configurator() as configurator:
            pass  # release the mocks used to create the test CGroupConfigurator so that they do not conflict the mock Popen below

        real_popen = subprocess.Popen

        def popen_with_broken_systemd_run(command, *args, **kwargs):
            # Inject a syntax error into any systemd-run invocation; other commands are passed through as-is
            broken_command = command.replace('systemd-run',
                                             'systemd-run syntax_error')
            return real_popen(broken_command, *args, **kwargs)

        expected_output = "[stdout]\n{0}\n\n\n[stderr]\n".format(
            "very specific test message")

        with tempfile.TemporaryFile(dir=self.tmp_dir, mode="w+b") as stdout:
            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                        mode="w+b") as stderr:
                with patch("azurelinuxagent.common.cgroupapi.subprocess.Popen",
                           side_effect=popen_with_broken_systemd_run):
                    # The systemd-run attempt fails because of the injected syntax error. The call itself is
                    # still expected to return, and the result must contain only the output of the last run.
                    process_output = configurator.start_extension_command(
                        extension_name="Microsoft.Compute.TestExtension-1.2.3",
                        command="echo 'very specific test message'",
                        timeout=300,
                        shell=True,
                        cwd=self.tmp_dir,
                        env={},
                        stdout=stdout,
                        stderr=stderr)

                    self.assertEqual(expected_output, process_output)

    def test_check_processes_in_agent_cgroup_should_raise_a_cgroups_exception_when_there_are_unexpected_processes_in_the_agent_cgroup(
            self):
        with self._get_cgroup_configurator() as configurator:
            pass  # release the mocks used to create the test CGroupConfigurator so that they do not conflict the mock Popen below

        # The test script recursively creates a given number of descendant processes, then it blocks until the
        # 'stop_file' exists. It produces an output file containing the PID of each descendant process.
        test_script = os.path.join(self.tmp_dir, "create_processes.sh")
        stop_file = os.path.join(self.tmp_dir, "create_processes.stop")
        AgentTestCase.create_script(
            test_script, """
#!/usr/bin/env bash
set -euo pipefail

if [[ $# != 2 ]]; then
    echo "Usage: $0 <output_file> <count>"
    exit 1
fi

echo $$ >> $1

if [[ $2 > 1 ]]; then
    $0 $1 $(($2 - 1))
else
    timeout 30s /usr/bin/env bash -c "while ! [[ -f {0} ]]; do sleep 0.25s; done"
fi

exit 0
""".format(stop_file))

        number_of_descendants = 3

        def wait_for_processes(processes_file):
            def _all_present():
                if os.path.exists(processes_file):
                    with open(processes_file, "r") as file_stream:
                        _all_present.processes = [
                            int(process)
                            for process in file_stream.read().split()
                        ]
                return len(_all_present.processes) >= number_of_descendants

            _all_present.processes = []

            if not wait_for(_all_present):
                raise Exception(
                    "Timeout waiting for processes. Expected {0}; got: {1}".
                    format(number_of_descendants,
                           format_processes(_all_present.processes)))

            return _all_present.processes

        threads = []

        try:
            #
            # Start the processes that will be used by the test. We use two sets of processes: the first set simulates a command executed by the agent
            # (e.g. iptables) and its child processes, if any. The second set of processes simulates an extension.
            #
            agent_command_output = os.path.join(self.tmp_dir,
                                                "agent_command.pids")
            agent_command = threading.Thread(
                target=lambda: shellutil.run_command([
                    test_script, agent_command_output,
                    str(number_of_descendants)
                ]))
            agent_command.start()
            threads.append(agent_command)
            agent_command_processes = wait_for_processes(agent_command_output)

            extension_output = os.path.join(self.tmp_dir, "extension.pids")

            def start_extension():
                original_sleep = time.sleep
                original_popen = subprocess.Popen

                # Extensions are started using systemd-run; mock Popen to remove the call to systemd-run; the test script creates a couple of
                # child processes, which would simulate the extension's processes.
                def mock_popen(command, *args, **kwargs):
                    match = re.match(
                        r"^systemd-run --unit=[^\s]+ --scope --slice=[^\s]+ (.+)",
                        command)
                    is_systemd_run = match is not None
                    if is_systemd_run:
                        command = match.group(1)
                    process = original_popen(command, *args, **kwargs)
                    if is_systemd_run:
                        start_extension.systemd_run_pid = process.pid
                    return process

                with patch(
                        'time.sleep', side_effect=lambda _: original_sleep(0.1)
                ):  # start_extension_command has a small delay; skip it
                    with patch(
                            "azurelinuxagent.common.cgroupapi.subprocess.Popen",
                            side_effect=mock_popen):
                        with tempfile.TemporaryFile(dir=self.tmp_dir,
                                                    mode="w+b") as stdout:
                            with tempfile.TemporaryFile(dir=self.tmp_dir,
                                                        mode="w+b") as stderr:
                                configurator.start_extension_command(
                                    extension_name="TestExtension",
                                    command="{0} {1} {2}".format(
                                        test_script, extension_output,
                                        number_of_descendants),
                                    timeout=30,
                                    shell=True,
                                    cwd=self.tmp_dir,
                                    env={},
                                    stdout=stdout,
                                    stderr=stderr)

            start_extension.systemd_run_pid = None

            extension = threading.Thread(target=start_extension)
            extension.start()
            threads.append(extension)
            extension_processes = wait_for_processes(extension_output)

            #
            # check_processes_in_agent_cgroup uses shellutil and the cgroups api to get the commands that are currently running;
            # wait for all the processes to show up
            #
            if not wait_for(lambda: len(shellutil.get_running_commands(
            )) > 0 and len(configurator._cgroups_api.get_systemd_run_commands(
            )) > 0):
                raise Exception(
                    "Timeout while attempting to track the child commands")

            #
            # Verify that check_processes_in_agent_cgroup raises when there are unexpected processes in the agent's cgroup.
            #
            # For the agent's processes, we use the current process and its parent (in the actual agent these would be the daemon and the extension
            # handler), and the commands started by the agent.
            #
            # For other processes, we use process 1, a process that already completed, and an extension. Note that extensions are started using
            # systemd-run and the process for that commands belongs to the agent's cgroup but the processes for the extension should be in a
            # different cgroup
            #
            def get_completed_process():
                random.seed()
                completed = random.randint(1000, 10000)
                while os.path.exists(
                        "/proc/{0}".format(completed)
                ):  # ensure we do not use an existing process
                    completed = random.randint(1000, 10000)
                return completed

            agent_processes = [
                os.getppid(), os.getpid()
            ] + agent_command_processes + [start_extension.systemd_run_pid]
            other_processes = [1, get_completed_process()
                               ] + extension_processes

            with patch(
                    "azurelinuxagent.common.cgroupconfigurator.CGroupsApi.get_processes_in_cgroup",
                    return_value=agent_processes + other_processes):
                with self.assertRaises(CGroupsException) as context_manager:
                    configurator._check_processes_in_agent_cgroup()

                # The list of processes in the message is an array of strings: "['foo', ..., 'bar']"
                message = ustr(context_manager.exception)
                search = re.search(
                    r'unexpected processes: \[(?P<processes>.+)\]', message)
                self.assertIsNotNone(
                    search,
                    "The event message is not in the expected format: {0}".
                    format(message))
                reported = search.group('processes').split(',')

                self.assertEqual(
                    len(other_processes), len(reported),
                    "An incorrect number of processes was reported. Expected: {0} Got: {1}"
                    .format(format_processes(other_processes), reported))
                for pid in other_processes:
                    self.assertTrue(
                        any("[PID: {0}]".format(pid) in reported_process
                            for reported_process in reported),
                        "Process {0} was not reported. Got: {1}".format(
                            format_processes([pid]), reported))
        finally:
            # create the file that stops the test processes and wait for them to complete
            open(stop_file, "w").close()
            for thread in threads:
                thread.join(timeout=5)

    def test_check_agent_throttled_time_should_raise_a_cgroups_exception_when_the_threshold_is_exceeded(
            self):
        """
        _check_agent_throttled_time must raise a CGroupsException mentioning that the
        agent has been throttled when the agent's throttled-time metric is one unit
        above _AGENT_THROTTLED_TIME_THRESHOLD.
        """
        # A single CPU throttled-time sample for the agent that is just over the threshold
        over_threshold = _AGENT_THROTTLED_TIME_THRESHOLD + 1
        throttled_sample = MetricValue(MetricsCategory.CPU_CATEGORY,
                                       MetricsCounter.THROTTLED_TIME,
                                       AGENT_NAME_TELEMETRY,
                                       over_threshold)

        with self.assertRaises(CGroupsException) as raised:
            CGroupConfigurator._Impl._check_agent_throttled_time([throttled_sample])

        error_text = ustr(raised.exception)
        self.assertIn("The agent has been throttled", error_text,
                      "An incorrect exception was raised")

    def test_check_cgroups_should_disable_cgroups_when_a_check_fails(self):
        """
        Verify that check_cgroups() disables cgroups when any one of its checks fails.

        For each check in turn, that check is mocked to raise a CGroupsException whose
        message is the check's name, and every other check is mocked to do nothing.
        The test then asserts that cgroups end up disabled. It also asserts that exactly
        one CGroupsDisabled event was emitted and that its message contains the exception.
        """
        with self._get_cgroup_configurator() as configurator:
            checks = [
                "_check_processes_in_agent_cgroup",
                "_check_agent_throttled_time"
            ]
            for method_to_fail in checks:
                patchers = []
                try:
                    # mock 'method_to_fail' to raise an exception and the rest to do nothing
                    for method_to_mock in checks:
                        side_effect = CGroupsException(
                            method_to_fail
                        ) if method_to_mock == method_to_fail else lambda *_: None
                        p = patch.object(configurator,
                                         method_to_mock,
                                         side_effect=side_effect)
                        p.start()
                        # Track the patcher only once it has started, so the cleanup in
                        # 'finally' never tries to stop a patcher that was never started.
                        patchers.append(p)

                    with patch(
                            "azurelinuxagent.common.cgroupconfigurator.add_event"
                    ) as add_event:
                        configurator.enable()

                        configurator.check_cgroups([])

                        self.assertFalse(
                            configurator.enabled(),
                            "An error in {0} should have disabled cgroups".
                            format(method_to_fail))

                        # Use .get() so that add_event calls which do not pass 'op' as a
                        # keyword are skipped, rather than aborting the test with a KeyError.
                        disable_events = [
                            kwargs for _, kwargs in add_event.call_args_list if
                            kwargs.get("op") == WALAEventOperation.CGroupsDisabled
                        ]
                        # assertEqual reports the actual count on failure (assertTrue would not)
                        self.assertEqual(
                            1, len(disable_events),
                            "Exactly 1 event should have been emitted when {0} fails. Got: {1}"
                            .format(method_to_fail, disable_events))
                        self.assertIn(
                            "[CGroupsException] {0}".format(method_to_fail),
                            disable_events[0]["message"],
                            "The error message is not correct when {0} failed".
                            format(method_to_fail))
                finally:
                    for p in patchers:
                        p.stop()