def setUp(self):
        """
        Prepare a RunCommandLinux 1.0.0 handler instance whose package lists five
        download URIs, and redirect its base, log and lib directories into self.tmp_dir.
        """
        AgentTestCase.setUp(self)

        handler_properties = ExtHandlerProperties()
        handler_properties.version = "1.0.0"
        handler = ExtHandler(name='Microsoft.CPlat.Core.RunCommandLinux')
        handler.properties = handler_properties

        wire_protocol = WireProtocol("http://Microsoft.CPlat.Core.RunCommandLinux/foo-bar")

        # The same package served from five storage accounts (00a .. 04a)
        package_locations = (
            'https://zrdfepirv2cy4prdstr00a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr01a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr02a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr03a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr04a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
        )
        self.pkg = ExtHandlerPackage()
        self.pkg.uris = [ExtHandlerVersionUri() for _ in package_locations]
        for version_uri, location in zip(self.pkg.uris, package_locations):
            version_uri.uri = location

        self.ext_handler_instance = ExtHandlerInstance(ext_handler=handler, protocol=wire_protocol)
        self.ext_handler_instance.pkg = self.pkg

        # Class-level patches: every ExtHandlerInstance sees these directories while the test runs
        self.extension_dir = os.path.join(self.tmp_dir, "Microsoft.CPlat.Core.RunCommandLinux-1.0.0")
        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
                                       return_value=self.extension_dir)
        self.mock_get_base_dir.start()

        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
                                      return_value=self.tmp_dir)
        self.mock_get_log_dir.start()

        self.agent_dir = self.tmp_dir
        self.mock_get_lib_dir = patch("azurelinuxagent.ga.exthandlers.conf.get_lib_dir",
                                      return_value=self.agent_dir)
        self.mock_get_lib_dir.start()
Exemple #2
0
    def setUp(self):
        """
        Build a 'foo' 1.2.3 handler instance, create a fake cgroup hierarchy
        (cpu and memory) under self.tmp_dir, and redirect the handler's base and
        log directories into self.tmp_dir.
        """
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.2.3"
        self.ext_handler = ExtHandler(name='foo')
        self.ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(
            ext_handler=self.ext_handler, protocol=None)

        # Fake cgroup root with cpu and memory controller directories
        self.base_cgroups = os.path.join(self.tmp_dir, "cgroup")
        os.mkdir(self.base_cgroups)
        os.mkdir(os.path.join(self.base_cgroups, "cpu"))
        os.mkdir(os.path.join(self.base_cgroups, "memory"))

        self.mock__base_cgroups = patch(
            "azurelinuxagent.common.cgroups.BASE_CGROUPS", self.base_cgroups)
        self.mock__base_cgroups.start()

        self.mock_get_base_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
            lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # This must be an attribute: the get_log_dir lambda below reads self.log_dir
        # when it is called. A local variable left the attribute unset, so the
        # patched get_log_dir raised AttributeError.
        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
            lambda *_: self.log_dir)
        self.mock_get_log_dir.start()
    def setUp(self):
        """
        Build a 'foo' 1.2.3 handler instance, redirect its base/log directories into
        self.tmp_dir, replace time.sleep, and disable the CGroupConfigurator after
        remembering whether it was enabled.
        """
        AgentTestCase.setUp(self)

        props = ExtHandlerProperties()
        props.version = "1.2.3"
        self.ext_handler = ExtHandler(name='foo')
        self.ext_handler.properties = props
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=None)

        base_dir_target = "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir"
        self.mock_get_base_dir = patch(base_dir_target, lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # The lambda reads self.log_dir lazily, when get_log_dir is invoked
        self.log_dir = os.path.join(self.tmp_dir, "log")
        log_dir_target = "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir"
        self.mock_get_log_dir = patch(log_dir_target, lambda *_: self.log_dir)
        self.mock_get_log_dir.start()

        # Route every time.sleep call to mock_sleep(0.01) (defined outside this view),
        # presumably to keep retry/poll loops fast
        self.mock_sleep = patch("time.sleep", lambda *_: mock_sleep(0.01))
        self.mock_sleep.start()

        # Save the enabled state (presumably restored in tearDown), then disable cgroups
        configurator = CGroupConfigurator.get_instance()
        self.cgroups_enabled = configurator.enabled()
        configurator.disable()
Exemple #4
0
    def test_it_should_contain_all_helper_environment_variables(self):
        """
        launch_command must add exactly the helper variables below to the
        environment it passes to subprocess.Popen, and the launched script must
        see each of them with the expected value.
        """
        wire_ip = str(uuid.uuid4())
        ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=WireProtocol(wire_ip))

        expected_env = {
            ExtCommandEnvVariable.ExtensionSeqNumber: _DEFAULT_SEQ_NO,
            ExtCommandEnvVariable.ExtensionPath: self.tmp_dir,
            ExtCommandEnvVariable.ExtensionVersion: ext_handler_instance.ext_handler.properties.version,
            ExtCommandEnvVariable.WireProtocolAddress: wire_ip,
        }

        # The script prints only the environment lines whose names match one of the helper variables
        command = """
            printenv | grep -E '(%s)'
        """ % '|'.join(expected_env)

        script_name = 'printHelperEnvironments.sh'
        script_path = os.path.join(self.ext_handler_instance.get_base_dir(), script_name)
        self.create_script(script_path, command)

        with patch("subprocess.Popen", wraps=subprocess.Popen) as popen_spy:
            # get_agent_supported_features_list_for_extensions has its own test, so it is stubbed out here
            with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
                       return_value={}):
                output = ext_handler_instance.launch_command(script_name)

            _, popen_kwargs = popen_spy.call_args
            added_env = dict((name, value) for (name, value) in popen_kwargs['env'].items()
                             if name not in os.environ)

            # Fails whenever a helper variable is added or removed without updating this test
            self.assertEqual(expected_env, added_env)

            # Each helper variable must be visible, with its expected value, to the launched command
            for name, value in expected_env.items():
                self.assertIn("%s=%s" % (name, value), output)
    def setUp(self):
        """
        Build a 'foo' 1.2.3 handler instance and patch only that instance's
        get_base_dir to return <tmp_dir>/extension_directory.
        """
        AgentTestCase.setUp(self)

        props = ExtHandlerProperties()
        props.version = "1.2.3"
        handler = ExtHandler(name='foo')
        handler.properties = props
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=handler, protocol=None)

        # patch.object targets this instance, not the ExtHandlerInstance class
        self.extension_directory = os.path.join(self.tmp_dir, "extension_directory")
        self.mock_get_base_dir = patch.object(self.ext_handler_instance,
                                              "get_base_dir",
                                              return_value=self.extension_directory)
        self.mock_get_base_dir.start()
    def setUp(self):
        """
        Prepare a RunCommandLinux 1.0.0 handler instance whose package lists five
        download URIs, and redirect its base, log and lib directories into self.tmp_dir.
        """
        AgentTestCase.setUp(self)

        handler_properties = ExtHandlerProperties()
        handler_properties.version = "1.0.0"
        handler = ExtHandler(name='Microsoft.CPlat.Core.RunCommandLinux')
        handler.properties = handler_properties

        wire_protocol = WireProtocol("http://Microsoft.CPlat.Core.RunCommandLinux/foo-bar")

        # The same package served from five storage accounts (00a .. 04a)
        package_locations = (
            'https://zrdfepirv2cy4prdstr00a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr01a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr02a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr03a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr04a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
        )
        self.pkg = ExtHandlerPackage()
        self.pkg.uris = [ExtHandlerVersionUri() for _ in package_locations]
        for version_uri, location in zip(self.pkg.uris, package_locations):
            version_uri.uri = location

        self.ext_handler_instance = ExtHandlerInstance(ext_handler=handler, protocol=wire_protocol)
        self.ext_handler_instance.pkg = self.pkg

        # Class-level patches: every ExtHandlerInstance sees these directories while the test runs
        self.extension_dir = os.path.join(self.tmp_dir, "Microsoft.CPlat.Core.RunCommandLinux-1.0.0")
        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
                                       return_value=self.extension_dir)
        self.mock_get_base_dir.start()

        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
                                      return_value=self.tmp_dir)
        self.mock_get_log_dir.start()

        self.agent_dir = self.tmp_dir
        self.mock_get_lib_dir = patch("azurelinuxagent.ga.exthandlers.conf.get_lib_dir",
                                      return_value=self.agent_dir)
        self.mock_get_lib_dir.start()
Exemple #7
0
    def setUp(self):
        """
        Build a 'foo' 1.2.3 handler instance bound to a WireProtocol for 1.2.3.4,
        redirect its base/log directories into self.tmp_dir, and replace time.sleep.
        """
        AgentTestCase.setUp(self)

        props = ExtHandlerProperties()
        props.version = "1.2.3"
        self.ext_handler = ExtHandler(name='foo')
        self.ext_handler.properties = props
        protocol = WireProtocol("1.2.3.4")
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=protocol)

        base_dir_target = "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir"
        self.mock_get_base_dir = patch(base_dir_target, lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # The lambda reads self.log_dir lazily, when get_log_dir is invoked
        self.log_dir = os.path.join(self.tmp_dir, "log")
        log_dir_target = "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir"
        self.mock_get_log_dir = patch(log_dir_target, lambda *_: self.log_dir)
        self.mock_get_log_dir.start()

        # Route every time.sleep call to mock_sleep(0.01) (defined outside this view),
        # presumably to keep retry/poll loops fast
        self.mock_sleep = patch("time.sleep", lambda *_: mock_sleep(0.01))
        self.mock_sleep.start()
Exemple #8
0
    def setUp(self):
        """
        Build a 'foo' 1.2.3 handler instance with a single package URI
        (http://bar/foo__1.2.3) and patch that instance's get_base_dir to
        return <tmp_dir>/extension_directory.
        """
        AgentTestCase.setUp(self)

        props = ExtHandlerProperties()
        props.version = "1.2.3"
        handler = ExtHandler(name='foo')
        handler.properties = props
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=handler, protocol=None)

        package_uri = ExtHandlerVersionUri()
        package_uri.uri = "http://bar/foo__1.2.3"
        self.ext_handler_instance.pkg = ExtHandlerPackage(props.version)
        self.ext_handler_instance.pkg.uris.append(package_uri)

        # base_dir is where the tests place package artifacts; extension_directory is the patched get_base_dir
        self.base_dir = self.tmp_dir
        self.extension_directory = os.path.join(self.tmp_dir, "extension_directory")
        self.mock_get_base_dir = patch.object(self.ext_handler_instance,
                                              "get_base_dir",
                                              return_value=self.extension_directory)
        self.mock_get_base_dir.start()
Exemple #9
0
    def assert_extension_sequence_number(
            self,  # pylint: disable=too-many-arguments
            patch_get_largest_seq,
            patch_add_event,
            goal_state_sequence_number,
            disk_sequence_number,
            expected_sequence_number):
        """
        Check the sequence number and status path chosen by get_status_file_path.

        The helper sets goal_state_sequence_number on a fresh Extension and makes the
        patched largest-sequence lookup return disk_sequence_number. It then checks:
          * exactly one 'SequenceNumberMismatch' event (is_success=False) is reported
            when the goal-state value is an integer that differs from the disk value,
            and no event is reported otherwise;
          * the returned sequence number equals expected_sequence_number;
          * the path ends in '/foo-1.2.3/status/<seq>.status' when seq > -1, else is None.
        """
        ext = Extension()
        ext.sequenceNumber = goal_state_sequence_number
        patch_get_largest_seq.return_value = disk_sequence_number

        ext_handler_props = ExtHandlerProperties()
        ext_handler_props.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_props

        instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=None)
        seq, path = instance.get_status_file_path(ext)

        # Non-integer goal-state values (e.g. "abc" or None) are treated as "not an int".
        # int(None) raises TypeError, not ValueError, so both must be caught.
        try:
            gs_seq_int = int(goal_state_sequence_number)
            gs_int = True
        except (TypeError, ValueError):
            gs_seq_int = None
            gs_int = False

        if gs_int and gs_seq_int != disk_sequence_number:
            self.assertEqual(1, patch_add_event.call_count)
            _, kw_args = patch_add_event.call_args
            self.assertEqual('SequenceNumberMismatch', kw_args['op'])
            self.assertEqual(False, kw_args['is_success'])
            self.assertEqual(
                'Goal state: {0}, disk: {1}'.format(gs_seq_int,
                                                    disk_sequence_number),
                kw_args['message'])
        else:
            self.assertEqual(0, patch_add_event.call_count)

        self.assertEqual(expected_sequence_number, seq)
        if seq > -1:
            self.assertTrue(
                path.endswith('/foo-1.2.3/status/{0}.status'.format(
                    expected_sequence_number)))
        else:
            self.assertIsNone(path)
Exemple #10
0
    def setUp(self):
        """
        Build a 'foo' 1.2.3 handler instance and redirect its base and log
        directories into self.tmp_dir.
        """
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler,
                                                       protocol=None)

        self.mock_get_base_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
            lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # This must be an attribute: the get_log_dir lambda below reads self.log_dir
        # when it is called. A local variable left the attribute unset, so the
        # patched get_log_dir raised AttributeError.
        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
            lambda *_: self.log_dir)
        self.mock_get_log_dir.start()
Exemple #11
0
    def assert_extension_sequence_number(self, patch_get_largest_seq=None,
                                         goal_state_sequence_number=None,
                                         disk_sequence_number=None,
                                         expected_sequence_number=None):
        """
        Check the sequence number and status path chosen by get_status_file_path.

        The helper sets goal_state_sequence_number on a fresh Extension and makes the
        patched largest-sequence lookup return disk_sequence_number. It then asserts
        that the returned sequence number equals expected_sequence_number, and that
        the path ends in '/foo-1.2.3/status/<seq>.status' when seq > -1, else is None.
        Despite its None default, patch_get_largest_seq must be a mock.
        """
        extension = Extension()
        extension.sequenceNumber = goal_state_sequence_number
        patch_get_largest_seq.return_value = disk_sequence_number

        props = ExtHandlerProperties()
        props.version = "1.2.3"
        handler = ExtHandler(name='foo')
        handler.properties = props

        handler_instance = ExtHandlerInstance(ext_handler=handler, protocol=None)
        actual_seq, status_path = handler_instance.get_status_file_path(extension)

        self.assertEqual(expected_sequence_number, actual_seq)
        if actual_seq > -1:
            expected_suffix = '/foo-1.2.3/status/{0}.status'.format(expected_sequence_number)
            self.assertTrue(status_path.endswith(expected_suffix))
        else:
            self.assertIsNone(status_path)
    def assert_extension_sequence_number(self,
                                         patch_get_largest_seq,
                                         patch_add_event,
                                         goal_state_sequence_number,
                                         disk_sequence_number,
                                         expected_sequence_number):
        """
        Check the sequence number and status path chosen by get_status_file_path.

        The helper sets goal_state_sequence_number on a fresh Extension and makes the
        patched largest-sequence lookup return disk_sequence_number. It then checks:
          * exactly one 'SequenceNumberMismatch' event (is_success=False) is reported
            when the goal-state value is an integer that differs from the disk value,
            and no event is reported otherwise;
          * the returned sequence number equals expected_sequence_number;
          * the path ends in '/foo-1.2.3/status/<seq>.status' when seq > -1, else is None.
        """
        ext = Extension()
        ext.sequenceNumber = goal_state_sequence_number
        patch_get_largest_seq.return_value = disk_sequence_number

        ext_handler_props = ExtHandlerProperties()
        ext_handler_props.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_props

        instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=None)
        seq, path = instance.get_status_file_path(ext)

        # Non-integer goal-state values (e.g. "abc" or None) are treated as "not an int".
        # int(None) raises TypeError, not ValueError, so both must be caught.
        try:
            gs_seq_int = int(goal_state_sequence_number)
            gs_int = True
        except (TypeError, ValueError):
            gs_seq_int = None
            gs_int = False

        if gs_int and gs_seq_int != disk_sequence_number:
            self.assertEqual(1, patch_add_event.call_count)
            _, kw_args = patch_add_event.call_args
            self.assertEqual('SequenceNumberMismatch', kw_args['op'])
            self.assertEqual(False, kw_args['is_success'])
            self.assertEqual('Goal state: {0}, disk: {1}'
                             .format(gs_seq_int, disk_sequence_number),
                             kw_args['message'])
        else:
            self.assertEqual(0, patch_add_event.call_count)

        self.assertEqual(expected_sequence_number, seq)
        if seq > -1:
            self.assertTrue(path.endswith('/foo-1.2.3/status/{0}.status'.format(expected_sequence_number)))
        else:
            self.assertIsNone(path)
    def setUp(self):
        """
        Build a 'foo' 1.2.3 handler instance and redirect its base and log
        directories into self.tmp_dir.
        """
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=None)

        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # This must be an attribute: the get_log_dir lambda below reads self.log_dir
        # when it is called. A local variable left the attribute unset, so the
        # patched get_log_dir raised AttributeError.
        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir)
        self.mock_get_log_dir.start()
    def setUp(self):
        """
        Build a 'foo' 1.2.3 handler instance, create a fake cgroup hierarchy
        (cpu and memory) under self.tmp_dir, and redirect the handler's base and
        log directories into self.tmp_dir.
        """
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.2.3"
        self.ext_handler = ExtHandler(name='foo')
        self.ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=None)

        # Fake cgroup root with cpu and memory controller directories
        self.base_cgroups = os.path.join(self.tmp_dir, "cgroup")
        os.mkdir(self.base_cgroups)
        os.mkdir(os.path.join(self.base_cgroups, "cpu"))
        os.mkdir(os.path.join(self.base_cgroups, "memory"))

        self.mock__base_cgroups = patch("azurelinuxagent.common.cgroups.BASE_CGROUPS", self.base_cgroups)
        self.mock__base_cgroups.start()

        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # This must be an attribute: the get_log_dir lambda below reads self.log_dir
        # when it is called. A local variable left the attribute unset, so the
        # patched get_log_dir raised AttributeError.
        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir)
        self.mock_get_log_dir.start()
Exemple #15
0
    def test_command_extension_log_truncates_correctly(self, mock_log_dir):
        """
        A pre-existing CommandExecution.log larger than execution_log_max_size is
        expected to be trimmed when an ExtHandlerInstance is constructed, keeping
        only the most recent line.
        """
        log_root = os.path.join(self.tmp_dir, "log_directory")
        mock_log_dir.return_value = log_root

        props = ExtHandlerProperties()
        props.version = "1.2.3"
        handler = ExtHandler(name='foo')
        handler.properties = props

        first_line = "This is the first line!"
        second_line = "This is the second line."

        handler_log_dir = os.path.join(log_root, "foo")
        log_file_path = os.path.join(handler_log_dir, "CommandExecution.log")

        fileutil.mkdir(handler_log_dir, mode=0o755)
        with open(log_file_path, "a") as log_file:
            log_file.write("{first_line}\n{second_line}\n".format(first_line=first_line, second_line=second_line))

        # The limit is below the current file size (both lines plus newlines), so the oldest line must go
        max_size = len(first_line) + len(second_line) // 2
        _ = ExtHandlerInstance(ext_handler=handler, protocol=None, execution_log_max_size=max_size)

        with open(log_file_path) as log_file:
            self.assertEqual(log_file.read(), "{second_line}\n".format(second_line=second_line))
class ExtHandlerInstanceTestCase(AgentTestCase):
    """Tests for ExtHandlerInstance.remove_ext_handler."""

    def setUp(self):
        """
        Build a 'foo' 1.2.3 handler instance with a single package URI and patch
        that instance's get_base_dir to return <tmp_dir>/extension_directory.
        """
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler,
                                                       protocol=None)

        pkg_uri = ExtHandlerVersionUri()
        pkg_uri.uri = "http://bar/foo__1.2.3"
        self.ext_handler_instance.pkg = ExtHandlerPackage(
            ext_handler_properties.version)
        self.ext_handler_instance.pkg.uris.append(pkg_uri)

        self.base_dir = self.tmp_dir
        self.extension_directory = os.path.join(self.tmp_dir,
                                                "extension_directory")
        self.mock_get_base_dir = patch.object(
            self.ext_handler_instance,
            "get_base_dir",
            return_value=self.extension_directory)
        self.mock_get_base_dir.start()

    def tearDown(self):
        self.mock_get_base_dir.stop()
        # Run the base class teardown as well; skipping it bypasses whatever
        # cleanup AgentTestCase performs for what its setUp created.
        AgentTestCase.tearDown(self)

    def test_rm_ext_handler_dir_should_remove_the_extension_packages(self):
        os.mkdir(self.extension_directory)
        open(os.path.join(self.extension_directory, "extension_file1"),
             'w').close()
        open(os.path.join(self.extension_directory, "extension_file2"),
             'w').close()
        open(os.path.join(self.extension_directory, "extension_file3"),
             'w').close()
        open(os.path.join(self.base_dir, "foo__1.2.3.zip"), 'w').close()

        self.ext_handler_instance.remove_ext_handler()

        # Both the extension directory and the downloaded package must be gone
        self.assertFalse(os.path.exists(self.extension_directory))
        self.assertFalse(
            os.path.exists(os.path.join(self.base_dir, "foo__1.2.3.zip")))

    def test_rm_ext_handler_dir_should_remove_the_extension_directory(self):
        os.mkdir(self.extension_directory)
        os.mknod(os.path.join(self.extension_directory, "extension_file1"))
        os.mknod(os.path.join(self.extension_directory, "extension_file2"))
        os.mknod(os.path.join(self.extension_directory, "extension_file3"))

        self.ext_handler_instance.remove_ext_handler()

        self.assertFalse(os.path.exists(self.extension_directory))

    def test_rm_ext_handler_dir_should_not_report_an_event_if_the_extension_directory_does_not_exist(
            self):
        if os.path.exists(self.extension_directory):
            os.rmdir(self.extension_directory)

        with patch.object(self.ext_handler_instance,
                          "report_event") as mock_report_event:
            self.ext_handler_instance.remove_ext_handler()

        mock_report_event.assert_not_called()

    def test_rm_ext_handler_dir_should_not_report_an_event_if_a_child_is_removed_asynchronously_while_deleting_the_extension_directory(
            self):
        os.mkdir(self.extension_directory)
        os.mknod(os.path.join(self.extension_directory, "extension_file1"))
        os.mknod(os.path.join(self.extension_directory, "extension_file2"))
        os.mknod(os.path.join(self.extension_directory, "extension_file3"))

        #
        # Some extensions uninstall asynchronously, and the files we are trying to remove may be removed
        # while shutil.rmtree is traversing the extension's directory. Mock this by deleting a file
        # twice (the second call will produce "[Errno 2] No such file or directory", which should not be
        # reported as a telemetry event).
        # In order to mock this, we need to know that remove_ext_handler invokes Python's shutil.rmtree,
        # which in turn invokes os.unlink (Python 3) or os.remove (Python 2).
        #
        remove_api_name = "unlink" if sys.version_info >= (3, 0) else "remove"

        original_remove_api = getattr(shutil.os, remove_api_name)

        extension_directory = self.extension_directory

        def mock_remove(path, dir_fd=None):
            if dir_fd is not None:  # path is relative to dir_fd; make it absolute
                path = os.path.join(extension_directory, path)

            if path.endswith("extension_file2"):
                original_remove_api(path)
                mock_remove.file_deleted_asynchronously = True
            original_remove_api(path)

        mock_remove.file_deleted_asynchronously = False

        with patch.object(shutil.os, remove_api_name, mock_remove):
            with patch.object(self.ext_handler_instance,
                              "report_event") as mock_report_event:
                self.ext_handler_instance.remove_ext_handler()

        mock_report_event.assert_not_called()

        # The next 2 asserts are checks on the mock itself, in case the implementation of remove_ext_handler changes (mocks may need to be updated then)
        self.assertTrue(mock_remove.file_deleted_asynchronously
                        )  # verify the mock was actually called
        self.assertFalse(
            os.path.exists(self.extension_directory)
        )  # verify the error produced by the mock did not prevent the deletion

    def test_rm_ext_handler_dir_should_report_an_event_if_an_error_occurs_while_deleting_the_extension_directory(
            self):
        os.mkdir(self.extension_directory)
        os.mknod(os.path.join(self.extension_directory, "extension_file1"))
        os.mknod(os.path.join(self.extension_directory, "extension_file2"))
        os.mknod(os.path.join(self.extension_directory, "extension_file3"))

        # The mock below relies on the knowledge that remove_ext_handler invokes Python's shutil.rmtree,
        # which in turn invokes os.unlink (Python 3) or os.remove (Python 2)
        remove_api_name = "unlink" if sys.version_info >= (3, 0) else "remove"

        original_remove_api = getattr(shutil.os, remove_api_name)

        extension_directory = self.extension_directory

        def mock_remove(path, dir_fd=None):
            # The fd-based rmtree passes names relative to dir_fd. Resolve them so that
            # the real unlink of the other files succeeds and the only failure is the
            # mocked one below, rather than a stray "No such file or directory".
            if dir_fd is not None:
                path = os.path.join(extension_directory, path)
            if path.endswith("extension_file2"):
                raise IOError("A mocked error")
            original_remove_api(path)

        with patch.object(shutil.os, remove_api_name, mock_remove):
            with patch.object(self.ext_handler_instance,
                              "report_event") as mock_report_event:
                self.ext_handler_instance.remove_ext_handler()

        # Fail with a clear message instead of a TypeError when unpacking call_args=None
        self.assertTrue(mock_report_event.called, "remove_ext_handler did not report an event")
        _, kwargs = mock_report_event.call_args
        self.assertIn("A mocked error", kwargs["message"])
Exemple #17
0
    def test_telemetry_with_tracked_cgroup(self):
        """
        Launch a short CPU/memory-consuming command for extension 'foobar' 1.0.0 and
        check that its cpu and memory cgroups are tracked by CGroupsTelemetry. Also
        check that the metrics returned by report_all_tracked contain
        num_summarization_values entries for 'cur_mem', 'max_mem' and 'cur_cpu'.
        """
        # NOTE(review): this asserts (fails) rather than skips when not running as root;
        # a skip decorator may exist outside this view - confirm.
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        # This test has some timing issues when systemd is managing cgroups, so we force the file system API
        # by creating a new instance of the CGroupConfigurator
        with patch("azurelinuxagent.common.cgroupapi.CGroupsApi._is_systemd", return_value=False):
            # Save the current singleton so it can be restored in the finally clause below
            cgroup_configurator_instance = CGroupConfigurator._instance
            CGroupConfigurator._instance = None

            try:
                # Number of poll_all_tracked() calls made after the command is launched (0.5s apart)
                max_num_polls = 30
                # Per-iteration sleep, in seconds, inside the launched script
                time_to_wait = 3
                extn_name = "foobar-1.0.0"
                # Expected number of entries per metric in report_all_tracked(); entries 5 and 6
                # are checked to be strings and entries 0-4 to be numeric below
                num_summarization_values = 7

                # The test expects two cgroups here (presumably one each for cpu and memory)
                cgs = make_new_cgroup(extn_name)
                self.assertEqual(len(cgs), 2)

                ext_handler_properties = ExtHandlerProperties()
                ext_handler_properties.version = "1.0.0"
                self.ext_handler = ExtHandler(name='foobar')
                self.ext_handler.properties = ext_handler_properties
                self.ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=None)

                # Script that starts a background python process which grows and clears a list
                # 5 times, sleeping time_to_wait seconds each iteration. create_script's return
                # value is used as the command to launch.
                command = self.create_script("keep_cpu_busy_and_consume_memory_for_5_seconds", '''
nohup python -c "import time

for i in range(5):
    x = [1, 2, 3, 4, 5] * (i * 1000)
    time.sleep({0})
    x *= 0
    print('Test loop')" &
'''.format(time_to_wait))

                self.log_dir = os.path.join(self.tmp_dir, "log")

                # Redirect the handler's base and log directories only while launching the
                # command (the names bound by "as" are not used afterwards)
                with patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir) as \
                        patch_get_base_dir:
                    with patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir) as \
                            patch_get_log_dir:
                        self.ext_handler_instance.launch_command(command)

                #
                # If the test is made to run using the systemd API, then the paths of the cgroups need to be checked differently:
                #
                #     self.assertEquals(len(CGroupsTelemetry._tracked), 2)
                #     cpu = os.path.join(BASE_CGROUPS, "cpu", "system.slice", r"foobar_1.0.0_.*\.scope")
                #     self.assertTrue(any(re.match(cpu, tracked.path) for tracked in CGroupsTelemetry._tracked))
                #     memory = os.path.join(BASE_CGROUPS, "memory", "system.slice", r"foobar_1.0.0_.*\.scope")
                #     self.assertTrue(any(re.match(memory, tracked.path) for tracked in CGroupsTelemetry._tracked))
                #
                self.assertTrue(CGroupsTelemetry.is_tracked(os.path.join(
                    BASE_CGROUPS, "cpu", "walinuxagent.extensions", "foobar_1.0.0")))
                self.assertTrue(CGroupsTelemetry.is_tracked(os.path.join(
                    BASE_CGROUPS, "memory", "walinuxagent.extensions", "foobar_1.0.0")))

                # Collect samples for roughly max_num_polls * 0.5 seconds
                for i in range(max_num_polls):
                    CGroupsTelemetry.poll_all_tracked()
                    time.sleep(0.5)

                collected_metrics = CGroupsTelemetry.report_all_tracked()

                self.assertIn("memory", collected_metrics[extn_name])
                self.assertIn("cur_mem", collected_metrics[extn_name]["memory"])
                self.assertIn("max_mem", collected_metrics[extn_name]["memory"])
                self.assertEqual(len(collected_metrics[extn_name]["memory"]["cur_mem"]), num_summarization_values)
                self.assertEqual(len(collected_metrics[extn_name]["memory"]["max_mem"]), num_summarization_values)

                self.assertIsInstance(collected_metrics[extn_name]["memory"]["cur_mem"][5], str)
                self.assertIsInstance(collected_metrics[extn_name]["memory"]["cur_mem"][6], str)
                self.assertIsInstance(collected_metrics[extn_name]["memory"]["max_mem"][5], str)
                self.assertIsInstance(collected_metrics[extn_name]["memory"]["max_mem"][6], str)

                self.assertIn("cpu", collected_metrics[extn_name])
                self.assertIn("cur_cpu", collected_metrics[extn_name]["cpu"])
                self.assertEqual(len(collected_metrics[extn_name]["cpu"]["cur_cpu"]), num_summarization_values)

                self.assertIsInstance(collected_metrics[extn_name]["cpu"]["cur_cpu"][5], str)
                self.assertIsInstance(collected_metrics[extn_name]["cpu"]["cur_cpu"][6], str)

                # The first five summary values are numeric
                for i in range(5):
                    self.assertGreater(collected_metrics[extn_name]["memory"]["cur_mem"][i], 0)
                    self.assertGreater(collected_metrics[extn_name]["memory"]["max_mem"][i], 0)
                    self.assertGreaterEqual(collected_metrics[extn_name]["cpu"]["cur_cpu"][i], 0)
                    # Equal because CPU could be zero for minimum value.
            finally:
                # Restore the configurator singleton saved above
                CGroupConfigurator._instance = cgroup_configurator_instance
Exemple #18
0
class LaunchCommandTestCase(AgentTestCase):
    """
    Test cases for launch_command
    """
    @classmethod
    def setUpClass(cls):
        AgentTestCase.setUpClass()
        # CGroups and CGroupsTelemetry, as referenced by azurelinuxagent.ga.exthandlers, are replaced by mocks for the
        # whole class so that these tests do not exercise the real cgroup code
        cls.mock_cgroups = patch("azurelinuxagent.ga.exthandlers.CGroups")
        cls.mock_cgroups.start()

        cls.mock_cgroups_telemetry = patch(
            "azurelinuxagent.ga.exthandlers.CGroupsTelemetry")
        cls.mock_cgroups_telemetry.start()

    @classmethod
    def tearDownClass(cls):
        # stop the class-level patches in the reverse order they were started
        cls.mock_cgroups_telemetry.stop()
        cls.mock_cgroups.stop()

        AgentTestCase.tearDownClass()

    def setUp(self):
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler,
                                                       protocol=None)

        # the scripts created by the tests are placed directly in the temporary directory
        self.mock_get_base_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
            lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # get_log_dir() returns the local log_dir computed here (this class never sets a self.log_dir attribute)
        log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
            lambda *_: log_dir)
        self.mock_get_log_dir.start()

    def tearDown(self):
        self.mock_get_log_dir.stop()
        self.mock_get_base_dir.stop()

        AgentTestCase.tearDown(self)

    def _create_script(self, file_name, contents):
        """
        Creates an executable script with the given contents.
        If file_name ends with ".py", it creates a Python3 script, otherwise it creates a bash script

        The script is written to the (mocked) base directory of the extension. The return value is the bare file
        name, not the full path; the tests pass it as-is to launch_command.
        """
        file_path = os.path.join(self.ext_handler_instance.get_base_dir(),
                                 file_name)

        with open(file_path, "w") as script:
            if file_name.endswith(".py"):
                script.write("#!/usr/bin/env python3\n")
            else:
                script.write("#!/usr/bin/env bash\n")
            script.write(contents)

        # owner read/write/execute only
        os.chmod(file_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

        return file_name

    @staticmethod
    def _output_regex(stdout, stderr):
        """
        Returns a regex matching the "[stdout] ... [stderr] ..." layout of the output returned by launch_command.
        stdout and stderr are inserted verbatim (they are not escaped).
        """
        return r"\[stdout\]\s+{0}\s+\[stderr\]\s+{1}".format(stdout, stderr)

    @staticmethod
    def _find_process(command):
        """
        Returns True if the command line of any process listed in /proc contains the given command, False otherwise.
        """
        for pid in os.listdir('/proc'):
            if not pid.isdigit():
                continue
            try:
                with open(os.path.join('/proc', pid, 'cmdline'),
                          'r') as cmdline:
                    if any(command in line for line in cmdline.readlines()):
                        return True
            except IOError:  # proc has already terminated
                continue
        return False

    def test_it_should_capture_the_output_of_the_command(self):
        stdout = "stdout" * 5
        stderr = "stderr" * 5

        command = self._create_script(
            "produce_output.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")

'''.format(stdout, stderr))

        def list_directory():
            base_dir = self.ext_handler_instance.get_base_dir()
            return [i for i in os.listdir(base_dir)
                    if not i.endswith(".tld")]  # ignore telemetry files

        files_before = list_directory()

        output = self.ext_handler_instance.launch_command(command)

        files_after = list_directory()

        self.assertRegex(output,
                         LaunchCommandTestCase._output_regex(stdout, stderr))

        # any temporary files used to capture the output must have been removed
        self.assertListEqual(
            files_before, files_after,
            "Not all temporary files were deleted. File list: {0}".format(
                files_after))

    def test_it_should_raise_an_exception_when_the_command_times_out(self):
        extension_error_code = 1234
        stdout = "stdout" * 7
        stderr = "stderr" * 7

        # the signal file is used by the test command to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the test command produces some output then goes into an infinite loop
        command = self._create_script(
            "produce_output_then_hang.py", '''
import sys
import time

sys.stdout.write("{0}")
sys.stdout.flush()

sys.stderr.write("{1}")
sys.stderr.flush()

with open("{2}", "w") as file:
    while True:
        file.write(".")
        time.sleep(1)

'''.format(stdout, stderr, signal_file))

        # mock time.sleep to wait for the signal file (launch_command implements the time out using polling and sleep)
        original_sleep = time.sleep

        def sleep(seconds):
            # once the command has produced its output, sleeping becomes a no-op so the timeout elapses immediately
            if not os.path.exists(signal_file):
                original_sleep(seconds)

        timeout = 60

        start_time = time.time()

        with patch("time.sleep", side_effect=sleep,
                   autospec=True) as mock_sleep:

            with self.assertRaises(ExtensionError) as context_manager:
                self.ext_handler_instance.launch_command(
                    command,
                    timeout=timeout,
                    extension_error_code=extension_error_code)

            # the command name and its output should be part of the message
            message = str(context_manager.exception)
            self.assertRegex(
                message, r"Timeout\(\d+\):\s+{0}\s+{1}".format(
                    command,
                    LaunchCommandTestCase._output_regex(stdout, stderr)))

            # the exception code should be as specified in the call to launch_command
            self.assertEqual(context_manager.exception.code,
                             extension_error_code)

            # the timeout period should have elapsed
            self.assertGreaterEqual(mock_sleep.call_count, timeout)

            # the command should have been terminated
            self.assertFalse(LaunchCommandTestCase._find_process(command),
                             "The command was not terminated")

        # as a check for the test itself, verify it completed in just a few seconds
        self.assertLessEqual(time.time() - start_time, 5)

    def test_it_should_raise_an_exception_when_the_command_fails(self):
        extension_error_code = 2345
        stdout = "stdout" * 3
        stderr = "stderr" * 3
        exit_code = 99

        command = self._create_script(
            "fail.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")
exit({2})

'''.format(stdout, stderr, exit_code))

        # the output is captured as part of the exception message
        with self.assertRaises(ExtensionError) as context_manager:
            self.ext_handler_instance.launch_command(
                command, extension_error_code=extension_error_code)

        message = str(context_manager.exception)
        self.assertRegex(
            message, r"Non-zero exit code: {0}.+{1}\s+{2}".format(
                exit_code, command,
                LaunchCommandTestCase._output_regex(stdout, stderr)))

        self.assertEqual(context_manager.exception.code, extension_error_code)

    def test_it_should_not_wait_for_child_process(self):
        stdout = "stdout"
        stderr = "stderr"

        # the child sleeps for 60 seconds while the parent writes its output and exits immediately
        command = self._create_script(
            "start_child_process.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    time.sleep(60)
else:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    
'''.format(stdout, stderr))

        start_time = time.time()

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(time.time() - start_time, 5)

        # Also check that we capture the parent's output
        self.assertRegex(output,
                         LaunchCommandTestCase._output_regex(stdout, stderr))

    def test_it_should_capture_the_output_of_child_process(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"
        more_parent_stdout = "MORE PARENT STDOUT"
        more_parent_stderr = "MORE PARENT STDERR"

        # the child process uses the signal file to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self._create_script(
            "start_child_with_output.py", '''
import os
import sys
import time

sys.stdout.write("{0}")
sys.stderr.write("{1}")

pid = os.fork()

if pid == 0:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")
    
    open("{6}", "w").close()
else:
    sys.stdout.write("{4}")
    sys.stderr.write("{5}")
    
    while not os.path.exists("{6}"):
        time.sleep(0.5)
    
'''.format(parent_stdout, parent_stderr, child_stdout, child_stderr,
           more_parent_stdout, more_parent_stderr, signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

        self.assertIn(more_parent_stdout, output)
        self.assertIn(more_parent_stderr, output)

    def test_it_should_capture_the_output_of_child_process_that_fails_to_start(
            self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"

        command = self._create_script(
            "start_child_that_fails.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    exit(1)
else:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")

'''.format(child_stdout, child_stderr, parent_stdout, parent_stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

    def test_it_should_execute_commands_with_no_output(self):
        # file used to verify the command completed successfully
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self._create_script(
            "create_file.py", '''
open("{0}", "w").close()

'''.format(signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertTrue(os.path.exists(signal_file))
        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

    def test_it_should_not_capture_the_output_of_commands_that_do_their_own_redirection(
            self):
        # the test script redirects its output to this file
        command_output_file = os.path.join(self.tmp_dir, "command_output.txt")
        stdout = "STDOUT"
        stderr = "STDERR"

        # the test script mimics the redirection done by the Custom Script extension
        command = self._create_script(
            "produce_output", '''
exec &> {0}
echo {1}
>&2 echo {2}

'''.format(command_output_file, stdout, stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

        with open(command_output_file, "r") as command_output:
            output = command_output.read()
            self.assertEqual(output, "{0}\n{1}\n".format(stdout, stderr))

    def test_it_should_truncate_the_command_output(self):
        stdout = "STDOUT"
        stderr = "STDERR"

        # each stream produces roughly TELEMETRY_MESSAGE_MAX_LEN characters, so together they exceed the limit
        command = self._create_script(
            "produce_long_output.py", '''
import sys

sys.stdout.write( "{0}" * {1})
sys.stderr.write( "{2}" * {3})
'''.format(stdout, int(TELEMETRY_MESSAGE_MAX_LEN / len(stdout)), stderr,
           int(TELEMETRY_MESSAGE_MAX_LEN / len(stderr))))

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
        self.assertIn(stdout, output)
        self.assertIn(stderr, output)

    def test_it_should_read_only_the_head_of_large_outputs(self):
        command = self._create_script(
            "produce_long_output.py", '''
import sys

sys.stdout.write("O" * 5 * 1024 * 1024)
sys.stderr.write("E" * 5 * 1024 * 1024)
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to format_stdout_stderr, which takes the
        # return value of the calls to file.read(). The intention of the test is to verify we never read (and load in memory)
        # more than a few KB of data from the files used to capture stdout/stderr
        with patch('azurelinuxagent.ga.exthandlers.format_stdout_stderr',
                   side_effect=format_stdout_stderr) as mock_format:
            output = self.ext_handler_instance.launch_command(command)

        self.assertGreaterEqual(len(output), 1024)
        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)

        # checked through call_count rather than assert_called_once(), which old mock versions do not implement
        self.assertEqual(1, mock_format.call_count)

        args, _ = mock_format.call_args
        stdout, stderr = args

        self.assertGreaterEqual(len(stdout), 1024)
        self.assertLessEqual(len(stdout), TELEMETRY_MESSAGE_MAX_LEN)

        self.assertGreaterEqual(len(stderr), 1024)
        self.assertLessEqual(len(stderr), TELEMETRY_MESSAGE_MAX_LEN)

    def test_it_should_handle_errors_while_reading_the_command_output(self):
        command = self._create_script(
            "produce_output.py", '''
import sys

sys.stdout.write("STDOUT")
sys.stderr.write("STDERR")
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to _capture_process_output, which will
        # call file.read() and we force stdout/stderr to be None; this will produce an exception when trying to use these files.
        original_capture_process_output = ExtHandlerInstance._capture_process_output

        def capture_process_output(process, stdout_file, stderr_file, cmd,
                                   timeout, code):
            return original_capture_process_output(process, None, None, cmd,
                                                   timeout, code)

        with patch(
                'azurelinuxagent.ga.exthandlers.ExtHandlerInstance._capture_process_output',
                side_effect=capture_process_output):
            output = self.ext_handler_instance.launch_command(command)

        self.assertIn("[stderr]\nCannot read stdout/stderr:", output)
Exemple #19
0
class TestCGroupsTelemetry(AgentTestCase):
    def setUp(self):
        """Runs the base fixture, then starts every test from an empty CGroupsTelemetry state."""
        AgentTestCase.setUp(self)
        # clear any cgroups/metrics left behind by other tests
        CGroupsTelemetry.reset()

    def tearDown(self):
        """Runs the base teardown, then clears CGroupsTelemetry so no tracked cgroups leak into later tests."""
        AgentTestCase.tearDown(self)
        CGroupsTelemetry.reset()

    def _assert_cgroup_metrics_equal(self, cpu_usage, memory_usage, max_memory_usage):
        """
        Asserts that every entry in CGroupsTelemetry._cgroup_metrics holds exactly the given lists of samples.

        :param cpu_usage: expected samples of get_cpu_usage()
        :param memory_usage: expected samples of get_memory_usage()
        :param max_memory_usage: expected samples of get_max_memory_usage()

        Note: the check passes trivially when there are no entries; callers that care assert the number of entries
        separately.
        """
        # only the metric objects are needed, not their keys
        for cgroup_metric in CGroupsTelemetry._cgroup_metrics.values():
            self.assertListEqual(cgroup_metric.get_memory_usage()._data, memory_usage)
            self.assertListEqual(cgroup_metric.get_max_memory_usage()._data, max_memory_usage)
            self.assertListEqual(cgroup_metric.get_cpu_usage()._data, cpu_usage)

    @patch("azurelinuxagent.common.cgroup.CpuCgroup._get_current_cpu_total")
    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil.get_total_cpu_ticks_since_boot")
    def test_telemetry_polling_with_active_cgroups(self, *args):
        """
        While the tracked cgroups are active, each poll appends one sample per metric. report_all_tracked() then keeps
        one metrics entry per extension but clears the collected samples.
        """
        num_extensions = 5
        for i in range(num_extensions):
            dummy_cpu_cgroup = CGroup.create("dummy_cpu_path_{0}".format(i), "cpu", "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_cpu_cgroup)

            dummy_memory_cgroup = CGroup.create("dummy_memory_path_{0}".format(i), "memory",
                                                "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_memory_cgroup)

        with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage") as patch_get_memory_max_usage:
            with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage") as patch_get_memory_usage:
                with patch("azurelinuxagent.common.cgroup.CpuCgroup._get_cpu_percent") as patch_get_cpu_percent:
                    with patch("azurelinuxagent.common.cgroup.CGroup.is_active") as patch_is_active:
                        patch_is_active.return_value = True

                        current_cpu = 30
                        current_memory = 209715200  # example 200 MB
                        current_max_memory = 471859200  # example 450 MB

                        patch_get_cpu_percent.return_value = current_cpu
                        patch_get_memory_usage.return_value = current_memory
                        patch_get_memory_max_usage.return_value = current_max_memory

                        # after the N-th poll each metric should hold N identical samples
                        for data_count in range(1, 10):
                            CGroupsTelemetry.poll_all_tracked()
                            self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), num_extensions)
                            self._assert_cgroup_metrics_equal(
                                cpu_usage=[current_cpu] * data_count,
                                memory_usage=[current_memory] * data_count,
                                max_memory_usage=[current_max_memory] * data_count)

                        CGroupsTelemetry.report_all_tracked()

                        # the entries survive the report, but their samples have been cleared
                        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), num_extensions)
                        self._assert_cgroup_metrics_equal([], [], [])

    @patch("azurelinuxagent.common.cgroup.CpuCgroup._get_current_cpu_total")
    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil.get_total_cpu_ticks_since_boot")
    def test_telemetry_polling_with_inactive_cgroups(self, *args):
        """
        Inactive cgroups are no longer tracked after the first poll, but the sample collected by that poll is kept
        until report_all_tracked(), which then drops the metrics entries altogether.
        """
        num_extensions = 5
        for i in range(num_extensions):
            dummy_cpu_cgroup = CGroup.create("dummy_cpu_path_{0}".format(i), "cpu", "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_cpu_cgroup)

            dummy_memory_cgroup = CGroup.create("dummy_memory_path_{0}".format(i), "memory",
                                                "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_memory_cgroup)

        with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage") as patch_get_memory_max_usage:
            with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage") as patch_get_memory_usage:
                with patch("azurelinuxagent.common.cgroup.CpuCgroup._get_cpu_percent") as patch_get_cpu_percent:
                    with patch("azurelinuxagent.common.cgroup.CGroup.is_active") as patch_is_active:
                        patch_is_active.return_value = False

                        no_extensions_expected = 0
                        data_count = 1
                        current_cpu = 30
                        current_memory = 209715200  # example 200 MB
                        current_max_memory = 471859200  # example 450 MB

                        patch_get_cpu_percent.return_value = current_cpu
                        patch_get_memory_usage.return_value = current_memory
                        patch_get_memory_max_usage.return_value = current_max_memory

                        for i in range(num_extensions):
                            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_cpu_path_{0}".format(i)))
                            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_memory_path_{0}".format(i)))

                        CGroupsTelemetry.poll_all_tracked()

                        # the poll stops tracking the inactive cgroups...
                        for i in range(num_extensions):
                            self.assertFalse(CGroupsTelemetry.is_tracked("dummy_cpu_path_{0}".format(i)))
                            self.assertFalse(CGroupsTelemetry.is_tracked("dummy_memory_path_{0}".format(i)))

                        # ...but keeps the data it collected
                        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), num_extensions)
                        self._assert_cgroup_metrics_equal(
                            cpu_usage=[current_cpu] * data_count,
                            memory_usage=[current_memory] * data_count,
                            max_memory_usage=[current_max_memory] * data_count)

                        CGroupsTelemetry.report_all_tracked()

                        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), no_extensions_expected)
                        self._assert_cgroup_metrics_equal([], [], [])

    @patch("azurelinuxagent.common.cgroup.CpuCgroup._get_current_cpu_total")
    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil.get_total_cpu_ticks_since_boot")
    def test_telemetry_polling_with_changing_cgroups_state(self, *args):
        """
        Cgroups that are active on the first poll and inactive on the second stay tracked after the first poll and are
        untracked by the second. The samples of both polls are kept until report_all_tracked(), which then drops the
        metrics entries.
        """
        num_extensions = 5
        for i in range(num_extensions):
            dummy_cpu_cgroup = CGroup.create("dummy_cpu_path_{0}".format(i), "cpu", "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_cpu_cgroup)

            dummy_memory_cgroup = CGroup.create("dummy_memory_path_{0}".format(i), "memory",
                                                "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_memory_cgroup)

        with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage") as patch_get_memory_max_usage:
            with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage") as patch_get_memory_usage:
                with patch("azurelinuxagent.common.cgroup.CpuCgroup._get_cpu_percent") as patch_get_cpu_percent:
                    with patch("azurelinuxagent.common.cgroup.CGroup.is_active") as patch_is_active:
                        patch_is_active.return_value = True

                        no_extensions_expected = 0
                        expected_data_count = 2  # one sample from each of the two polls below

                        current_cpu = 30
                        current_memory = 209715200  # example 200 MB
                        current_max_memory = 471859200  # example 450 MB

                        patch_get_cpu_percent.return_value = current_cpu
                        patch_get_memory_usage.return_value = current_memory
                        patch_get_memory_max_usage.return_value = current_max_memory

                        for i in range(num_extensions):
                            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_cpu_path_{0}".format(i)))
                            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_memory_path_{0}".format(i)))

                        # first poll: the cgroups are active and stay tracked
                        CGroupsTelemetry.poll_all_tracked()

                        for i in range(num_extensions):
                            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_cpu_path_{0}".format(i)))
                            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_memory_path_{0}".format(i)))

                        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), num_extensions)

                        # second poll: the cgroups have become inactive and are no longer tracked
                        patch_is_active.return_value = False
                        CGroupsTelemetry.poll_all_tracked()

                        for i in range(num_extensions):
                            self.assertFalse(CGroupsTelemetry.is_tracked("dummy_cpu_path_{0}".format(i)))
                            self.assertFalse(CGroupsTelemetry.is_tracked("dummy_memory_path_{0}".format(i)))

                        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), num_extensions)
                        self._assert_cgroup_metrics_equal(
                            cpu_usage=[current_cpu] * expected_data_count,
                            memory_usage=[current_memory] * expected_data_count,
                            max_memory_usage=[current_max_memory] * expected_data_count)

                        CGroupsTelemetry.report_all_tracked()

                        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), no_extensions_expected)
                        self._assert_cgroup_metrics_equal([], [], [])

    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil._get_proc_stat")
    @patch("azurelinuxagent.common.logger.periodic_warn")
    @patch("azurelinuxagent.common.utils.fileutil.read_file")
    def test_telemetry_polling_to_not_generate_transient_logs_ioerror_file_not_found(self, mock_read_file,
                                                                                     patch_periodic_warn, *args):
        """
        When every read of a cgroup file fails with IOError(errno=ENOENT), polling must not emit any periodic warning.
        """
        extension_count = 1
        for index in range(extension_count):
            extension_name = "dummy_extension_{0}".format(index)
            CGroupsTelemetry.track_cgroup(CGroup.create("dummy_cpu_path_{0}".format(index), "cpu", extension_name))
            CGroupsTelemetry.track_cgroup(CGroup.create("dummy_memory_path_{0}".format(index), "memory",
                                                        extension_name))

        self.assertEqual(0, patch_periodic_warn.call_count)

        # Not expecting logs present for io_error with errno=errno.ENOENT
        file_not_found = IOError()
        file_not_found.errno = errno.ENOENT
        mock_read_file.side_effect = file_not_found

        for _ in range(1, 10):
            CGroupsTelemetry.poll_all_tracked()
            self.assertEqual(0, patch_periodic_warn.call_count)

    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil._get_proc_stat")
    @patch("azurelinuxagent.common.logger.periodic_warn")
    @patch("azurelinuxagent.common.utils.fileutil.read_file")
    def test_telemetry_polling_to_generate_transient_logs_ioerror_permission_denied(self, mock_read_file,
                                                                                    patch_periodic_warn, *args):
        """
        When every read of a cgroup file fails with IOError(errno=EPERM), polling must emit periodic warnings.
        """
        extension_count = 1
        controller_count = 2
        is_active_checks_per_controller = 2

        for index in range(extension_count):
            extension_name = "dummy_extension_{0}".format(index)
            CGroupsTelemetry.track_cgroup(CGroup.create("dummy_cpu_path_{0}".format(index), "cpu", extension_name))
            CGroupsTelemetry.track_cgroup(CGroup.create("dummy_memory_path_{0}".format(index), "memory",
                                                        extension_name))

        self.assertEqual(0, patch_periodic_warn.call_count)

        # Expecting logs to be present for different kind of errors
        permission_denied = IOError()
        permission_denied.errno = errno.EPERM
        mock_read_file.side_effect = permission_denied

        # each collect per controller would generate a log statement, and each cgroup would invoke a
        # is active check raising an exception
        warnings_expected = controller_count + is_active_checks_per_controller

        # The expected total is the same after every poll, i.e. all the warnings must come from the first poll and
        # later polls must not add any.
        # NOTE(review): presumably the failing cgroups stop being tracked after the first poll — confirm.
        for _ in range(1, 10):
            CGroupsTelemetry.poll_all_tracked()
            self.assertEqual(warnings_expected, patch_periodic_warn.call_count)

    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil._get_proc_stat")
    @patch("azurelinuxagent.common.utils.fileutil.read_file")
    def test_telemetry_polling_to_generate_transient_logs_index_error(self, mock_read_file, *args):
        """
        A non-IOError failure while reading the cgroup data (empty file contents) is logged exactly once across
        repeated polls.
        """
        extension_count = 1
        for index in range(extension_count):
            extension_name = "dummy_extension_{0}".format(index)
            CGroupsTelemetry.track_cgroup(CGroup.create("dummy_cpu_path_{0}".format(index), "cpu", extension_name))
            CGroupsTelemetry.track_cgroup(CGroup.create("dummy_memory_path_{0}".format(index), "memory",
                                                        extension_name))

        # Generating a different kind of error (non-IOError) to check the logging.
        # Trying to invoke IndexError during the getParameter call
        mock_read_file.return_value = ''

        with patch("azurelinuxagent.common.logger.periodic_warn") as patch_periodic_warn:
            # called only once at start, and then gets removed from the tracked data.
            for _ in range(1, 10):
                CGroupsTelemetry.poll_all_tracked()
                self.assertEqual(1, patch_periodic_warn.call_count)

    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil.get_total_cpu_ticks_since_boot")
    @patch("azurelinuxagent.common.cgroup.CpuCgroup._get_current_cpu_total")
    @patch("azurelinuxagent.common.cgroup.CpuCgroup._update_cpu_data")
    def test_telemetry_calculations(self, *args):
        """
        Polls num_polls random samples per metric and checks that report_all_tracked() returns
        num_summarization_values summary values per metric, the first five of which equal generate_metric_list() of
        the raw samples.
        """
        num_polls = 10
        num_extensions = 1
        num_summarization_values = 7

        cpu_percent_values = [random.randint(0, 100) for _ in range(num_polls)]

        # only verifying calculations and not validity of the values.
        memory_usage_values = [random.randint(0, 8 * 1024 ** 3) for _ in range(num_polls)]
        max_memory_usage_values = [random.randint(0, 8 * 1024 ** 3) for _ in range(num_polls)]

        for i in range(num_extensions):
            dummy_cpu_cgroup = CGroup.create("dummy_cpu_path_{0}".format(i), "cpu", "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_cpu_cgroup)

            dummy_memory_cgroup = CGroup.create("dummy_memory_path_{0}".format(i), "memory",
                                                "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_memory_cgroup)

        self.assertEqual(2 * num_extensions, len(CGroupsTelemetry._tracked))

        with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage") as patch_get_memory_max_usage:
            with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage") as patch_get_memory_usage:
                with patch("azurelinuxagent.common.cgroup.CpuCgroup._get_cpu_percent") as patch_get_cpu_percent:
                    with patch("azurelinuxagent.common.cgroup.CGroup.is_active") as patch_is_active:
                        # the cgroups stay active for all the polls
                        patch_is_active.return_value = True

                        # each poll collects the next value of every sample list
                        for cpu_percent, memory_usage, max_memory_usage in zip(cpu_percent_values,
                                                                               memory_usage_values,
                                                                               max_memory_usage_values):
                            patch_get_cpu_percent.return_value = cpu_percent
                            patch_get_memory_usage.return_value = memory_usage
                            patch_get_memory_max_usage.return_value = max_memory_usage
                            CGroupsTelemetry.poll_all_tracked()

        collected_metrics = CGroupsTelemetry.report_all_tracked()
        for i in range(num_extensions):
            name = "dummy_extension_{0}".format(i)

            self.assertIn(name, collected_metrics)
            self.assertIn("memory", collected_metrics[name])
            self.assertIn("cur_mem", collected_metrics[name]["memory"])
            self.assertIn("max_mem", collected_metrics[name]["memory"])
            self.assertEqual(num_summarization_values, len(collected_metrics[name]["memory"]["cur_mem"]))
            self.assertEqual(num_summarization_values, len(collected_metrics[name]["memory"]["max_mem"]))

            self.assertListEqual(generate_metric_list(memory_usage_values),
                                 collected_metrics[name]["memory"]["cur_mem"][0:5])
            self.assertListEqual(generate_metric_list(max_memory_usage_values),
                                 collected_metrics[name]["memory"]["max_mem"][0:5])

            self.assertIn("cpu", collected_metrics[name])
            self.assertIn("cur_cpu", collected_metrics[name]["cpu"])
            self.assertEqual(num_summarization_values, len(collected_metrics[name]["cpu"]["cur_cpu"]))
            self.assertListEqual(generate_metric_list(cpu_percent_values),
                                 collected_metrics[name]["cpu"]["cur_cpu"][0:5])

    # _get_proc_stat is mocked so that this test can also run on Mac and other systems;
    # the test never needs the actual contents of /proc/stat
    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil._get_proc_stat")
    def test_cgroup_tracking(self, *args):
        """Every cpu and memory cgroup handed to track_cgroup must then be reported as tracked."""
        extension_count = 5
        controllers_per_extension = 2  # one "cpu" and one "memory" cgroup per extension

        for index in range(extension_count):
            extension_name = "dummy_extension_{0}".format(index)
            cpu_cgroup = CGroup.create("dummy_cpu_path_{0}".format(index), "cpu", extension_name)
            CGroupsTelemetry.track_cgroup(cpu_cgroup)
            memory_cgroup = CGroup.create("dummy_memory_path_{0}".format(index), "memory", extension_name)
            CGroupsTelemetry.track_cgroup(memory_cgroup)

        for index in range(extension_count):
            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_cpu_path_{0}".format(index)))
            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_memory_path_{0}".format(index)))

        expected_tracked = extension_count * controllers_per_extension
        self.assertEqual(expected_tracked, len(CGroupsTelemetry._tracked))

    # _get_proc_stat is mocked so that this test can also run on Mac and other systems;
    # the test never needs the actual contents of /proc/stat
    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil._get_proc_stat")
    def test_cgroup_pruning(self, *args):
        """prune_all_tracked must remove every previously tracked cgroup."""
        extension_count = 5
        controllers_per_extension = 2  # one "cpu" and one "memory" cgroup per extension

        for index in range(extension_count):
            extension_name = "dummy_extension_{0}".format(index)
            cpu_cgroup = CGroup.create("dummy_cpu_path_{0}".format(index), "cpu", extension_name)
            CGroupsTelemetry.track_cgroup(cpu_cgroup)
            memory_cgroup = CGroup.create("dummy_memory_path_{0}".format(index), "memory", extension_name)
            CGroupsTelemetry.track_cgroup(memory_cgroup)

        # sanity check: everything is tracked before pruning
        for index in range(extension_count):
            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_cpu_path_{0}".format(index)))
            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_memory_path_{0}".format(index)))

        self.assertEqual(extension_count * controllers_per_extension, len(CGroupsTelemetry._tracked))

        CGroupsTelemetry.prune_all_tracked()

        # after pruning nothing is tracked any more
        for index in range(extension_count):
            self.assertFalse(CGroupsTelemetry.is_tracked("dummy_cpu_path_{0}".format(index)))
            self.assertFalse(CGroupsTelemetry.is_tracked("dummy_memory_path_{0}".format(index)))

        self.assertEqual(0, len(CGroupsTelemetry._tracked))

    # _get_proc_stat is mocked so that this test can also run on Mac and other systems;
    # the test never needs the actual contents of /proc/stat
    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil._get_proc_stat")
    def test_cgroup_is_tracked(self, *args):
        """is_tracked is True for registered cgroup paths and False for paths never registered."""
        extension_count = 5
        for index in range(extension_count):
            extension_name = "dummy_extension_{0}".format(index)
            cpu_cgroup = CGroup.create("dummy_cpu_path_{0}".format(index), "cpu", extension_name)
            CGroupsTelemetry.track_cgroup(cpu_cgroup)
            memory_cgroup = CGroup.create("dummy_memory_path_{0}".format(index), "memory", extension_name)
            CGroupsTelemetry.track_cgroup(memory_cgroup)

        for index in range(extension_count):
            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_cpu_path_{0}".format(index)))
            self.assertTrue(CGroupsTelemetry.is_tracked("dummy_memory_path_{0}".format(index)))

        for unknown_path in ("not_present_cpu_dummy_path", "not_present_memory_dummy_path"):
            self.assertFalse(CGroupsTelemetry.is_tracked(unknown_path))

    # mocking get_proc_stat to make it run on Mac and other systems
    # this test does not need to read the values of the /proc/stat file
    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil._get_proc_stat")
    def test_process_cgroup_metric_with_incorrect_cgroups_mounted(self, *args):
        """When every cpu and memory usage read fails, _process_cgroup_metric yields {} for each metrics entry."""
        num_extensions = 5
        for i in range(num_extensions):
            dummy_cpu_cgroup = CGroup.create("dummy_cpu_path_{0}".format(i), "cpu", "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_cpu_cgroup)

            dummy_memory_cgroup = CGroup.create("dummy_memory_path_{0}".format(i), "memory",
                                                "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_memory_cgroup)

        with patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage") as patch_get_cpu_usage:
            with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage") as patch_get_memory_usage:
                # simulate cgroup files that cannot be read
                patch_get_cpu_usage.side_effect = Exception("File not found")
                patch_get_memory_usage.side_effect = Exception("File not found")

                for _ in range(1, 10):
                    CGroupsTelemetry.poll_all_tracked()

                self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), num_extensions)

                for name, cgroup_metrics in CGroupsTelemetry._cgroup_metrics.items():
                    self.assertEqual(CGroupsTelemetry._process_cgroup_metric(cgroup_metrics), {},
                                     "Expected no processed metrics for {0}".format(name))

    @patch("azurelinuxagent.common.cgroup.CpuCgroup._get_current_cpu_total")
    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil.get_total_cpu_ticks_since_boot")
    def test_process_cgroup_metric_with_no_memory_cgroup_mounted(self, *args):
        """CPU samples keep accumulating when memory usage cannot be read; report_all_tracked then empties them."""
        num_extensions = 5

        for i in range(num_extensions):
            dummy_cpu_cgroup = CGroup.create("dummy_cpu_path_{0}".format(i), "cpu", "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_cpu_cgroup)

            dummy_memory_cgroup = CGroup.create("dummy_memory_path_{0}".format(i), "memory",
                                                "dummy_extension_{0}".format(i))
            CGroupsTelemetry.track_cgroup(dummy_memory_cgroup)

        with patch("azurelinuxagent.common.cgroup.CpuCgroup._get_cpu_percent") as patch_get_cpu_percent:
            with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage") as patch_get_memory_usage:
                with patch("azurelinuxagent.common.cgroup.CGroup.is_active") as patch_is_active:
                    patch_is_active.return_value = True
                    # simulate a missing memory controller: every memory read fails
                    patch_get_memory_usage.side_effect = Exception("File not found")

                    current_cpu = 30
                    patch_get_cpu_percent.return_value = current_cpu

                    for data_count in range(1, 10):
                        CGroupsTelemetry.poll_all_tracked()

                        # one more CPU sample per poll, and never any memory samples
                        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), num_extensions)
                        self._assert_cgroup_metrics_equal(cpu_usage=[current_cpu] * data_count, memory_usage=[], max_memory_usage=[])

                    CGroupsTelemetry.report_all_tracked()

                    # the entries remain, but the test expects their samples to have been cleared by the report
                    self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), num_extensions)
                    self._assert_cgroup_metrics_equal([], [], [])

    @patch("azurelinuxagent.common.cgroup.CpuCgroup._get_current_cpu_total")
    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil.get_total_cpu_ticks_since_boot")
    def test_process_cgroup_metric_with_no_cpu_cgroup_mounted(self, *args):
        """Memory samples keep accumulating when CPU usage cannot be read; report_all_tracked then empties them."""
        extension_count = 5
        for index in range(extension_count):
            extension_name = "dummy_extension_{0}".format(index)
            CGroupsTelemetry.track_cgroup(CGroup.create("dummy_cpu_path_{0}".format(index), "cpu", extension_name))
            CGroupsTelemetry.track_cgroup(CGroup.create("dummy_memory_path_{0}".format(index), "memory", extension_name))

        memory_in_use = 209715200  # example 200 MB
        peak_memory = 471859200  # example 450 MB

        with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage") as mock_max_memory_usage:
            with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage") as mock_memory_usage:
                with patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage") as mock_cpu_usage:
                    with patch("azurelinuxagent.common.cgroup.CGroup.is_active") as mock_is_active:
                        mock_is_active.return_value = True
                        # simulate a missing cpu controller: every CPU read fails
                        mock_cpu_usage.side_effect = Exception("File not found")
                        mock_memory_usage.return_value = memory_in_use
                        mock_max_memory_usage.return_value = peak_memory

                        for samples in range(1, 10):
                            CGroupsTelemetry.poll_all_tracked()
                            self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), extension_count)
                            self._assert_cgroup_metrics_equal(
                                cpu_usage=[],
                                memory_usage=[memory_in_use] * samples,
                                max_memory_usage=[peak_memory] * samples)

                        CGroupsTelemetry.report_all_tracked()

                        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), extension_count)
                        self._assert_cgroup_metrics_equal([], [], [])

    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage")
    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage")
    @patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage")
    @patch("azurelinuxagent.common.osutil.default.DefaultOSUtil.get_total_cpu_ticks_since_boot")
    def test_extension_temetry_not_sent_for_empty_perf_metrics(self, *args):
        """With inactive cgroups and no processed metrics, report_all_tracked returns nothing."""
        extension_count = 5
        for index in range(extension_count):
            extension_name = "dummy_extension_{0}".format(index)
            CGroupsTelemetry.track_cgroup(CGroup.create("dummy_cpu_path_{0}".format(index), "cpu", extension_name))
            CGroupsTelemetry.track_cgroup(CGroup.create("dummy_memory_path_{0}".format(index), "memory", extension_name))

        process_metric_target = "azurelinuxagent.common.cgroupstelemetry.CGroupsTelemetry._process_cgroup_metric"
        with patch(process_metric_target) as mock_process_cgroup_metric:
            with patch("azurelinuxagent.common.cgroup.CGroup.is_active") as mock_is_active:
                mock_is_active.return_value = False
                mock_process_cgroup_metric.return_value = {}

                for _ in range(1, 10):
                    CGroupsTelemetry.poll_all_tracked()

                reported = CGroupsTelemetry.report_all_tracked()
                self.assertEqual(0, len(reported))

    @skip_if_predicate_false(are_cgroups_enabled, "Does not run when Cgroups are not enabled")
    @skip_if_predicate_true(is_trusty_in_travis, "Does not run on Trusty in Travis")
    @attr('requires_sudo')
    def test_telemetry_with_tracked_cgroup(self):
        self.assertTrue(i_am_root(), "Test does not run when non-root")

        # This test has some timing issues when systemd is managing cgroups, so we force the file system API
        # by creating a new instance of the CGroupConfigurator
        with patch("azurelinuxagent.common.cgroupapi.CGroupsApi._is_systemd", return_value=False):
            cgroup_configurator_instance = CGroupConfigurator._instance
            CGroupConfigurator._instance = None

            try:
                max_num_polls = 30
                time_to_wait = 3
                extn_name = "foobar-1.0.0"
                num_summarization_values = 7

                cgs = make_new_cgroup(extn_name)
                self.assertEqual(len(cgs), 2)

                ext_handler_properties = ExtHandlerProperties()
                ext_handler_properties.version = "1.0.0"
                self.ext_handler = ExtHandler(name='foobar')
                self.ext_handler.properties = ext_handler_properties
                self.ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=None)

                command = self.create_script("keep_cpu_busy_and_consume_memory_for_5_seconds", '''
nohup python -c "import time

for i in range(5):
    x = [1, 2, 3, 4, 5] * (i * 1000)
    time.sleep({0})
    x *= 0
    print('Test loop')" &
'''.format(time_to_wait))

                self.log_dir = os.path.join(self.tmp_dir, "log")

                with patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir) as \
                        patch_get_base_dir:
                    with patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir) as \
                            patch_get_log_dir:
                        self.ext_handler_instance.launch_command(command)

                #
                # If the test is made to run using the systemd API, then the paths of the cgroups need to be checked differently:
                #
                #     self.assertEquals(len(CGroupsTelemetry._tracked), 2)
                #     cpu = os.path.join(BASE_CGROUPS, "cpu", "system.slice", r"foobar_1.0.0_.*\.scope")
                #     self.assertTrue(any(re.match(cpu, tracked.path) for tracked in CGroupsTelemetry._tracked))
                #     memory = os.path.join(BASE_CGROUPS, "memory", "system.slice", r"foobar_1.0.0_.*\.scope")
                #     self.assertTrue(any(re.match(memory, tracked.path) for tracked in CGroupsTelemetry._tracked))
                #
                self.assertTrue(CGroupsTelemetry.is_tracked(os.path.join(
                    BASE_CGROUPS, "cpu", "walinuxagent.extensions", "foobar_1.0.0")))
                self.assertTrue(CGroupsTelemetry.is_tracked(os.path.join(
                    BASE_CGROUPS, "memory", "walinuxagent.extensions", "foobar_1.0.0")))

                for i in range(max_num_polls):
                    CGroupsTelemetry.poll_all_tracked()
                    time.sleep(0.5)

                collected_metrics = CGroupsTelemetry.report_all_tracked()

                self.assertIn("memory", collected_metrics[extn_name])
                self.assertIn("cur_mem", collected_metrics[extn_name]["memory"])
                self.assertIn("max_mem", collected_metrics[extn_name]["memory"])
                self.assertEqual(len(collected_metrics[extn_name]["memory"]["cur_mem"]), num_summarization_values)
                self.assertEqual(len(collected_metrics[extn_name]["memory"]["max_mem"]), num_summarization_values)

                self.assertIsInstance(collected_metrics[extn_name]["memory"]["cur_mem"][5], str)
                self.assertIsInstance(collected_metrics[extn_name]["memory"]["cur_mem"][6], str)
                self.assertIsInstance(collected_metrics[extn_name]["memory"]["max_mem"][5], str)
                self.assertIsInstance(collected_metrics[extn_name]["memory"]["max_mem"][6], str)

                self.assertIn("cpu", collected_metrics[extn_name])
                self.assertIn("cur_cpu", collected_metrics[extn_name]["cpu"])
                self.assertEqual(len(collected_metrics[extn_name]["cpu"]["cur_cpu"]), num_summarization_values)

                self.assertIsInstance(collected_metrics[extn_name]["cpu"]["cur_cpu"][5], str)
                self.assertIsInstance(collected_metrics[extn_name]["cpu"]["cur_cpu"][6], str)

                for i in range(5):
                    self.assertGreater(collected_metrics[extn_name]["memory"]["cur_mem"][i], 0)
                    self.assertGreater(collected_metrics[extn_name]["memory"]["max_mem"][i], 0)
                    self.assertGreaterEqual(collected_metrics[extn_name]["cpu"]["cur_cpu"][i], 0)
                    # Equal because CPU could be zero for minimum value.
            finally:
                CGroupConfigurator._instance = cgroup_configurator_instance
class LaunchCommandTestCase(AgentTestCase):
    """
    Test cases for launch_command
    """
    def setUp(self):
        """Prepare a 'foo' 1.2.3 ExtHandlerInstance rooted in tmp_dir, stub time.sleep and disable cgroups.

        The cgroups state found on entry is remembered in self.cgroups_enabled so that tearDown can restore it.
        """
        AgentTestCase.setUp(self)

        properties = ExtHandlerProperties()
        properties.version = "1.2.3"
        self.ext_handler = ExtHandler(name='foo')
        self.ext_handler.properties = properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=None)

        # the extension's base directory is the test's temporary directory...
        base_dir_target = "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir"
        self.mock_get_base_dir = patch(base_dir_target, lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # ...and its log directory is a "log" subdirectory of it
        self.log_dir = os.path.join(self.tmp_dir, "log")
        log_dir_target = "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir"
        self.mock_get_log_dir = patch(log_dir_target, lambda *_: self.log_dir)
        self.mock_get_log_dir.start()

        # every time.sleep() call is redirected to mock_sleep(0.01), whatever duration was requested
        self.mock_sleep = patch("time.sleep", lambda *_: mock_sleep(0.01))
        self.mock_sleep.start()

        self.cgroups_enabled = CGroupConfigurator.get_instance().enabled()
        CGroupConfigurator.get_instance().disable()

    def tearDown(self):
        """Restore the cgroups state captured in setUp, then stop the patches started there.

        Stopping the patches and running the base-class teardown happen in ``finally`` clauses so that a
        failure while restoring the cgroups configuration cannot leak the get_log_dir/get_base_dir/time.sleep
        patches into the tests that run afterwards.
        """
        try:
            if self.cgroups_enabled:
                CGroupConfigurator.get_instance().enable()
            else:
                CGroupConfigurator.get_instance().disable()
        finally:
            try:
                self.mock_get_log_dir.stop()
                self.mock_get_base_dir.stop()
                self.mock_sleep.stop()
            finally:
                AgentTestCase.tearDown(self)

    @staticmethod
    def _output_regex(stdout, stderr):
        """Return a regex for the '[stdout] <stdout> [stderr] <stderr>' text produced by launch_command.

        The arguments are inserted verbatim (they are not escaped), so callers must pass regex-safe text.
        """
        pattern_template = r"\[stdout\]\s+{0}\s+\[stderr\]\s+{1}"
        return pattern_template.format(stdout, stderr)

    @staticmethod
    def _find_process(command):
        """Return True if ``command`` appears in the cmdline of any process listed under /proc.

        Processes that disappear while being inspected (IOError) and processes whose cmdline cannot be
        decoded with the default text encoding (UnicodeDecodeError) are skipped rather than aborting the scan.
        """
        for pid in os.listdir('/proc'):
            if not pid.isdigit():
                continue  # not a process directory
            try:
                with open(os.path.join('/proc', pid, 'cmdline'), 'r') as cmdline:
                    for line in cmdline.readlines():
                        if command in line:
                            return True
            except (IOError, UnicodeDecodeError):  # proc has already terminated, or its cmdline is not decodable
                continue
        return False

    def test_it_should_capture_the_output_of_the_command(self):
        stdout = "stdout" * 5
        stderr = "stderr" * 5

        command = self.create_script(
            "produce_output.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")

'''.format(stdout, stderr))

        def list_directory():
            base_dir = self.ext_handler_instance.get_base_dir()
            return [i for i in os.listdir(base_dir)
                    if not i.endswith(".tld")]  # ignore telemetry files

        files_before = list_directory()

        output = self.ext_handler_instance.launch_command(command)

        files_after = list_directory()

        self.assertRegex(output,
                         LaunchCommandTestCase._output_regex(stdout, stderr))

        self.assertListEqual(
            files_before, files_after,
            "Not all temporary files were deleted. File list: {0}".format(
                files_after))

    def test_it_should_raise_an_exception_when_the_command_times_out(self):
        extension_error_code = ExtensionErrorCodes.PluginHandlerScriptTimedout
        stdout = "stdout" * 7
        stderr = "stderr" * 7

        # the signal file is used by the test command to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the test command produces some output then goes into an infinite loop
        command = self.create_script(
            "produce_output_then_hang.py", '''
import sys
import time

sys.stdout.write("{0}")
sys.stdout.flush()

sys.stderr.write("{1}")
sys.stderr.flush()

with open("{2}", "w") as file:
    while True:
        file.write(".")
        time.sleep(1)

'''.format(stdout, stderr, signal_file))

        # mock time.sleep to wait for the signal file (launch_command implements the time out using polling and sleep)
        original_sleep = time.sleep

        def sleep(seconds):
            if not os.path.exists(signal_file):
                original_sleep(seconds)

        timeout = 60

        start_time = time.time()

        with patch("time.sleep", side_effect=sleep,
                   autospec=True) as mock_sleep:

            with self.assertRaises(ExtensionError) as context_manager:
                self.ext_handler_instance.launch_command(
                    command,
                    timeout=timeout,
                    extension_error_code=extension_error_code)

            # the command name and its output should be part of the message
            message = str(context_manager.exception)
            command_full_path = os.path.join(self.tmp_dir,
                                             command.lstrip(os.path.sep))
            self.assertRegex(
                message, r"Timeout\(\d+\):\s+{0}\s+{1}".format(
                    command_full_path,
                    LaunchCommandTestCase._output_regex(stdout, stderr)))

            # the exception code should be as specified in the call to launch_command
            self.assertEquals(context_manager.exception.code,
                              extension_error_code)

            # the timeout period should have elapsed
            self.assertGreaterEqual(mock_sleep.call_count, timeout)

            # the command should have been terminated
            self.assertFalse(LaunchCommandTestCase._find_process(command),
                             "The command was not terminated")

        # as a check for the test itself, verify it completed in just a few seconds
        self.assertLessEqual(time.time() - start_time, 5)

    def test_it_should_raise_an_exception_when_the_command_fails(self):
        extension_error_code = 2345
        stdout = "stdout" * 3
        stderr = "stderr" * 3
        exit_code = 99

        command = self.create_script(
            "fail.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")
exit({2})

'''.format(stdout, stderr, exit_code))

        # the output is captured as part of the exception message
        with self.assertRaises(ExtensionError) as context_manager:
            self.ext_handler_instance.launch_command(
                command, extension_error_code=extension_error_code)

        message = str(context_manager.exception)
        self.assertRegex(
            message, r"Non-zero exit code: {0}.+{1}\s+{2}".format(
                exit_code, command,
                LaunchCommandTestCase._output_regex(stdout, stderr)))

        self.assertEquals(context_manager.exception.code, extension_error_code)

    def test_it_should_not_wait_for_child_process(self):
        stdout = "stdout"
        stderr = "stderr"

        command = self.create_script(
            "start_child_process.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    time.sleep(60)
else:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    
'''.format(stdout, stderr))

        start_time = time.time()

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(time.time() - start_time, 5)

        # Also check that we capture the parent's output
        self.assertRegex(output,
                         LaunchCommandTestCase._output_regex(stdout, stderr))

    def test_it_should_capture_the_output_of_child_process(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"
        more_parent_stdout = "MORE PARENT STDOUT"
        more_parent_stderr = "MORE PARENT STDERR"

        # the child process uses the signal file to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self.create_script(
            "start_child_with_output.py", '''
import os
import sys
import time

sys.stdout.write("{0}")
sys.stderr.write("{1}")

pid = os.fork()

if pid == 0:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")
    
    open("{6}", "w").close()
else:
    sys.stdout.write("{4}")
    sys.stderr.write("{5}")
    
    while not os.path.exists("{6}"):
        time.sleep(0.5)
    
'''.format(parent_stdout, parent_stderr, child_stdout, child_stderr,
           more_parent_stdout, more_parent_stderr, signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

        self.assertIn(more_parent_stdout, output)
        self.assertIn(more_parent_stderr, output)

    def test_it_should_capture_the_output_of_child_process_that_fails_to_start(
            self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"

        command = self.create_script(
            "start_child_that_fails.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    exit(1)
else:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")

'''.format(child_stdout, child_stderr, parent_stdout, parent_stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

    def test_it_should_execute_commands_with_no_output(self):
        # file used to verify the command completed successfully
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self.create_script(
            "create_file.py", '''
open("{0}", "w").close()

'''.format(signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertTrue(os.path.exists(signal_file))
        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

    def test_it_should_not_capture_the_output_of_commands_that_do_their_own_redirection(
            self):
        # the test script redirects its output to this file
        command_output_file = os.path.join(self.tmp_dir, "command_output.txt")
        stdout = "STDOUT"
        stderr = "STDERR"

        # the test script mimics the redirection done by the Custom Script extension
        command = self.create_script(
            "produce_output", '''
exec &> {0}
echo {1}
>&2 echo {2}

'''.format(command_output_file, stdout, stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

        with open(command_output_file, "r") as command_output:
            output = command_output.read()
            self.assertEquals(output, "{0}\n{1}\n".format(stdout, stderr))

    def test_it_should_truncate_the_command_output(self):
        stdout = "STDOUT"
        stderr = "STDERR"

        command = self.create_script(
            "produce_long_output.py", '''
import sys

sys.stdout.write( "{0}" * {1})
sys.stderr.write( "{2}" * {3})
'''.format(stdout, int(TELEMETRY_MESSAGE_MAX_LEN / len(stdout)), stderr,
           int(TELEMETRY_MESSAGE_MAX_LEN / len(stderr))))

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
        self.assertIn(stdout, output)
        self.assertIn(stderr, output)

    def test_it_should_read_only_the_head_of_large_outputs(self):
        command = self.create_script(
            "produce_long_output.py", '''
import sys

sys.stdout.write("O" * 5 * 1024 * 1024)
sys.stderr.write("E" * 5 * 1024 * 1024)
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to format_stdout_stderr, which takes the
        # return value of the calls to file.read(). The intention of the test is to verify we never read (and load in memory)
        # more than a few KB of data from the files used to capture stdout/stderr
        with patch(
                'azurelinuxagent.common.utils.extensionprocessutil.format_stdout_stderr',
                side_effect=format_stdout_stderr) as mock_format:
            output = self.ext_handler_instance.launch_command(command)

        self.assertGreaterEqual(len(output), 1024)
        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)

        mock_format.assert_called_once()

        args, kwargs = mock_format.call_args
        stdout, stderr = args

        self.assertGreaterEqual(len(stdout), 1024)
        self.assertLessEqual(len(stdout), TELEMETRY_MESSAGE_MAX_LEN)

        self.assertGreaterEqual(len(stderr), 1024)
        self.assertLessEqual(len(stderr), TELEMETRY_MESSAGE_MAX_LEN)

    def test_it_should_handle_errors_while_reading_the_command_output(self):
        command = self.create_script(
            "produce_output.py", '''
import sys

sys.stdout.write("STDOUT")
sys.stderr.write("STDERR")
''')
        # Mocking the call to file.read() is difficult, so instead we mock the call to_capture_process_output,
        # which will call file.read() and we force stdout/stderr to be None; this will produce an exception when
        # trying to use these files.
        original_capture_process_output = read_output

        def capture_process_output(stdout_file, stderr_file):
            return original_capture_process_output(None, None)

        with patch(
                'azurelinuxagent.common.utils.extensionprocessutil.read_output',
                side_effect=capture_process_output):
            output = self.ext_handler_instance.launch_command(command)

        self.assertIn("[stderr]\nCannot read stdout/stderr:", output)

    def test_it_should_contain_all_helper_environment_variables(self):
        """
        Checks that launch_command passes exactly the expected helper environment variables (sequence number,
        extension path and extension version) to the command, and that the command sees their expected values.
        """
        instance = self.ext_handler_instance
        helper_env_vars = {
            ExtCommandEnvVariable.ExtensionSeqNumber: instance.get_seq_no(),
            ExtCommandEnvVariable.ExtensionPath: self.tmp_dir,
            ExtCommandEnvVariable.ExtensionVersion: instance.ext_handler.properties.version,
        }

        grep_alternatives = '|'.join(helper_env_vars.keys())
        script_body = """
            printenv | grep -E '(%s)'
        """ % grep_alternatives

        test_file = self.create_script('printHelperEnvironments.sh', script_body)

        with patch("subprocess.Popen", wraps=subprocess.Popen) as patch_popen:
            output = instance.launch_command(test_file)

            _, popen_kwargs = patch_popen.call_args
            # keep only the variables that launch_command added on top of the agent's own environment
            helper_only_env = dict(
                (name, value) for (name, value) in popen_kwargs['env'].items() if name not in os.environ)

            # Fails whenever a helper environment variable is added or removed without updating this test
            self.assertEqual(helper_env_vars, helper_only_env)

            # Each helper variable must be visible, with its expected value, to the extension command
            for name, value in helper_env_vars.items():
                self.assertIn("%s=%s" % (name, value), output)

    def test_telemetry_with_tracked_cgroup(self, *_):
        self.assertTrue(i_am_root(), "Test does not run when non-root")
        CGroupConfigurator._instance = None

        max_num_polls = 30
        time_to_wait = 3
        extn_name = "foobar-1.0.0"
        num_summarization_values = 7

        cgs = make_new_cgroup(extn_name)
        self.assertEqual(len(cgs), 2)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.0.0"
        self.ext_handler = ExtHandler(name='foobar')
        self.ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(
            ext_handler=self.ext_handler, protocol=None)

        command = self.create_script(
            "keep_cpu_busy_and_consume_memory_for_5_seconds", '''
nohup python -c "import time

for i in range(5):
    x = [1, 2, 3, 4, 5] * (i * 1000)
    time.sleep({0})
    x *= 0
    print('Test loop')" &
'''.format(time_to_wait))

        self.log_dir = os.path.join(self.tmp_dir, "log")

        with patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir) as \
                patch_get_base_dir:
            with patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir) as \
                    patch_get_log_dir:
                self.ext_handler_instance.launch_command(command)

        self.assertTrue(
            CGroupsTelemetry.is_tracked(
                os.path.join(BASE_CGROUPS, "cpu", "walinuxagent.extensions",
                             "foobar_1.0.0")))
        self.assertTrue(
            CGroupsTelemetry.is_tracked(
                os.path.join(BASE_CGROUPS, "memory", "walinuxagent.extensions",
                             "foobar_1.0.0")))

        for i in range(max_num_polls):
            CGroupsTelemetry.poll_all_tracked()
            time.sleep(0.5)

        collected_metrics = CGroupsTelemetry.report_all_tracked()

        self.assertIn("memory", collected_metrics[extn_name])
        self.assertIn("cur_mem", collected_metrics[extn_name]["memory"])
        self.assertIn("max_mem", collected_metrics[extn_name]["memory"])
        self.assertEqual(
            len(collected_metrics[extn_name]["memory"]["cur_mem"]),
            num_summarization_values)
        self.assertEqual(
            len(collected_metrics[extn_name]["memory"]["max_mem"]),
            num_summarization_values)

        self.assertIsInstance(
            collected_metrics[extn_name]["memory"]["cur_mem"][5], str)
        self.assertIsInstance(
            collected_metrics[extn_name]["memory"]["cur_mem"][6], str)
        self.assertIsInstance(
            collected_metrics[extn_name]["memory"]["max_mem"][5], str)
        self.assertIsInstance(
            collected_metrics[extn_name]["memory"]["max_mem"][6], str)

        self.assertIn("cpu", collected_metrics[extn_name])
        self.assertIn("cur_cpu", collected_metrics[extn_name]["cpu"])
        self.assertEqual(len(collected_metrics[extn_name]["cpu"]["cur_cpu"]),
                         num_summarization_values)

        self.assertIsInstance(
            collected_metrics[extn_name]["cpu"]["cur_cpu"][5], str)
        self.assertIsInstance(
            collected_metrics[extn_name]["cpu"]["cur_cpu"][6], str)

        for i in range(5):
            self.assertGreater(
                collected_metrics[extn_name]["memory"]["cur_mem"][i], 0)
            self.assertGreater(
                collected_metrics[extn_name]["memory"]["max_mem"][i], 0)
            self.assertGreaterEqual(
                collected_metrics[extn_name]["cpu"]["cur_cpu"][i], 0)
class LaunchCommandTestCase(AgentTestCase):
    """
    Test cases for launch_command
    """
    @classmethod
    def setUpClass(cls):
        AgentTestCase.setUpClass()

        # CGroups and CGroupsTelemetry, as seen by azurelinuxagent.ga.exthandlers, are mocked for the whole class
        cls.mock_cgroups = patch("azurelinuxagent.ga.exthandlers.CGroups")
        cls.mock_cgroups.start()

        cls.mock_cgroups_telemetry = patch("azurelinuxagent.ga.exthandlers.CGroupsTelemetry")
        cls.mock_cgroups_telemetry.start()

    @classmethod
    def tearDownClass(cls):
        cls.mock_cgroups_telemetry.stop()
        cls.mock_cgroups.stop()

        AgentTestCase.tearDownClass()

    def setUp(self):
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=None)

        # the extension's base directory is the test's temporary directory
        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # The get_log_dir mock reads self.log_dir each time it is called, so the path must be stored on the instance;
        # a local variable is invisible to the lambda and would leave self.log_dir unassigned in this class.
        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir)
        self.mock_get_log_dir.start()

    def tearDown(self):
        self.mock_get_log_dir.stop()
        self.mock_get_base_dir.stop()

        AgentTestCase.tearDown(self)

    def _create_script(self, file_name, contents):
        """
        Creates an executable script with the given contents.
        If file_name ends with ".py", it creates a Python3 script, otherwise it creates a bash script

        The script is written to the extension's base directory; the returned value is file_name itself (not the
        full path), which is what the tests pass to launch_command.
        """
        file_path = os.path.join(self.ext_handler_instance.get_base_dir(), file_name)

        with open(file_path, "w") as script:
            if file_name.endswith(".py"):
                script.write("#!/usr/bin/env python3\n")
            else:
                script.write("#!/usr/bin/env bash\n")
            script.write(contents)

        os.chmod(file_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

        return file_name

    @staticmethod
    def _output_regex(stdout, stderr):
        """Regex matching the "[stdout] ... [stderr] ..." format produced for the output of a command."""
        return r"\[stdout\]\s+{0}\s+\[stderr\]\s+{1}".format(stdout, stderr)

    @staticmethod
    def _find_process(command):
        """Returns True if the command line of any running process contains the given command."""
        for pid in [pid for pid in os.listdir('/proc') if pid.isdigit()]:
            try:
                with open(os.path.join('/proc', pid, 'cmdline'), 'r') as cmdline:
                    for line in cmdline.readlines():
                        if command in line:
                            return True
            except IOError:  # proc has already terminated
                continue
        return False

    def test_it_should_capture_the_output_of_the_command(self):
        stdout = "stdout" * 5
        stderr = "stderr" * 5

        command = self._create_script("produce_output.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")

'''.format(stdout, stderr))

        def list_directory():
            base_dir = self.ext_handler_instance.get_base_dir()
            return [i for i in os.listdir(base_dir) if not i.endswith(".tld")] # ignore telemetry files

        files_before = list_directory()

        output = self.ext_handler_instance.launch_command(command)

        files_after = list_directory()

        self.assertRegex(output, LaunchCommandTestCase._output_regex(stdout, stderr))

        self.assertListEqual(files_before, files_after, "Not all temporary files were deleted. File list: {0}".format(files_after))

    def test_it_should_raise_an_exception_when_the_command_times_out(self):
        extension_error_code = 1234
        stdout = "stdout" * 7
        stderr = "stderr" * 7

        # the signal file is used by the test command to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the test command produces some output then goes into an infinite loop
        command = self._create_script("produce_output_then_hang.py", '''
import sys
import time

sys.stdout.write("{0}")
sys.stdout.flush()

sys.stderr.write("{1}")
sys.stderr.flush()

with open("{2}", "w") as file:
    while True:
        file.write(".")
        time.sleep(1)

'''.format(stdout, stderr, signal_file))

        # mock time.sleep to wait for the signal file (launch_command implements the time out using polling and sleep)
        original_sleep = time.sleep

        def sleep(seconds):
            if not os.path.exists(signal_file):
                original_sleep(seconds)

        timeout = 60

        start_time = time.time()

        with patch("time.sleep", side_effect=sleep, autospec=True) as mock_sleep:

            with self.assertRaises(ExtensionError) as context_manager:
                self.ext_handler_instance.launch_command(command, timeout=timeout, extension_error_code=extension_error_code)

            # the command name and its output should be part of the message
            message = str(context_manager.exception)
            self.assertRegex(message, r"Timeout\(\d+\):\s+{0}\s+{1}".format(command, LaunchCommandTestCase._output_regex(stdout, stderr)))

            # the exception code should be as specified in the call to launch_command
            self.assertEqual(context_manager.exception.code, extension_error_code)

            # the timeout period should have elapsed
            self.assertGreaterEqual(mock_sleep.call_count, timeout)

            # the command should have been terminated
            self.assertFalse(LaunchCommandTestCase._find_process(command), "The command was not terminated")

        # as a check for the test itself, verify it completed in just a few seconds
        self.assertLessEqual(time.time() - start_time, 5)

    def test_it_should_raise_an_exception_when_the_command_fails(self):
        extension_error_code = 2345
        stdout = "stdout" * 3
        stderr = "stderr" * 3
        exit_code = 99

        command = self._create_script("fail.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")
exit({2})

'''.format(stdout, stderr, exit_code))

        # the output is captured as part of the exception message
        with self.assertRaises(ExtensionError) as context_manager:
            self.ext_handler_instance.launch_command(command, extension_error_code=extension_error_code)

        message = str(context_manager.exception)
        self.assertRegex(message, r"Non-zero exit code: {0}.+{1}\s+{2}".format(exit_code, command, LaunchCommandTestCase._output_regex(stdout, stderr)))

        self.assertEqual(context_manager.exception.code, extension_error_code)

    def test_it_should_not_wait_for_child_process(self):
        stdout = "stdout"
        stderr = "stderr"

        command = self._create_script("start_child_process.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    time.sleep(60)
else:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    
'''.format(stdout, stderr))

        start_time = time.time()

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(time.time() - start_time, 5)

        # Also check that we capture the parent's output
        self.assertRegex(output, LaunchCommandTestCase._output_regex(stdout, stderr))

    def test_it_should_capture_the_output_of_child_process(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"
        more_parent_stdout = "MORE PARENT STDOUT"
        more_parent_stderr = "MORE PARENT STDERR"

        # the child process uses the signal file to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self._create_script("start_child_with_output.py", '''
import os
import sys
import time

sys.stdout.write("{0}")
sys.stderr.write("{1}")

pid = os.fork()

if pid == 0:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")
    
    open("{6}", "w").close()
else:
    sys.stdout.write("{4}")
    sys.stderr.write("{5}")
    
    while not os.path.exists("{6}"):
        time.sleep(0.5)
    
'''.format(parent_stdout, parent_stderr, child_stdout, child_stderr, more_parent_stdout, more_parent_stderr, signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

        self.assertIn(more_parent_stdout, output)
        self.assertIn(more_parent_stderr, output)

    def test_it_should_capture_the_output_of_child_process_that_fails_to_start(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"

        command = self._create_script("start_child_that_fails.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    exit(1)
else:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")

'''.format(child_stdout, child_stderr, parent_stdout, parent_stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

    def test_it_should_execute_commands_with_no_output(self):
        # file used to verify the command completed successfully
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self._create_script("create_file.py", '''
open("{0}", "w").close()

'''.format(signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertTrue(os.path.exists(signal_file))
        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

    def test_it_should_not_capture_the_output_of_commands_that_do_their_own_redirection(self):
        # the test script redirects its output to this file
        command_output_file = os.path.join(self.tmp_dir, "command_output.txt")
        stdout = "STDOUT"
        stderr = "STDERR"

        # the test script mimics the redirection done by the Custom Script extension
        command = self._create_script("produce_output", '''
exec &> {0}
echo {1}
>&2 echo {2}

'''.format(command_output_file, stdout, stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

        with open(command_output_file, "r") as command_output:
            output = command_output.read()
            self.assertEqual(output, "{0}\n{1}\n".format(stdout, stderr))

    def test_it_should_truncate_the_command_output(self):
        stdout = "STDOUT"
        stderr = "STDERR"

        command = self._create_script("produce_long_output.py", '''
import sys

sys.stdout.write( "{0}" * {1})
sys.stderr.write( "{2}" * {3})
'''.format(stdout, int(TELEMETRY_MESSAGE_MAX_LEN / len(stdout)), stderr, int(TELEMETRY_MESSAGE_MAX_LEN / len(stderr))))

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
        self.assertIn(stdout, output)
        self.assertIn(stderr, output)

    def test_it_should_read_only_the_head_of_large_outputs(self):
        command = self._create_script("produce_long_output.py", '''
import sys

sys.stdout.write("O" * 5 * 1024 * 1024)
sys.stderr.write("E" * 5 * 1024 * 1024)
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to format_stdout_stderr, which takes the
        # return value of the calls to file.read(). The intention of the test is to verify we never read (and load in memory)
        # more than a few KB of data from the files used to capture stdout/stderr
        with patch('azurelinuxagent.ga.exthandlers.format_stdout_stderr', side_effect=format_stdout_stderr) as mock_format:
            output = self.ext_handler_instance.launch_command(command)

        self.assertGreaterEqual(len(output), 1024)
        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)

        mock_format.assert_called_once()

        # positional arguments of the single call are (stdout, stderr)
        stdout, stderr = mock_format.call_args[0]

        self.assertGreaterEqual(len(stdout), 1024)
        self.assertLessEqual(len(stdout), TELEMETRY_MESSAGE_MAX_LEN)

        self.assertGreaterEqual(len(stderr), 1024)
        self.assertLessEqual(len(stderr), TELEMETRY_MESSAGE_MAX_LEN)

    def test_it_should_handle_errors_while_reading_the_command_output(self):
        command = self._create_script("produce_output.py", '''
import sys

sys.stdout.write("STDOUT")
sys.stderr.write("STDERR")
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to _capture_process_output, which will
        # call file.read() and we force stdout/stderr to be None; this will produce an exception when trying to use these files.
        original_capture_process_output = ExtHandlerInstance._capture_process_output

        def capture_process_output(process, stdout_file, stderr_file, cmd, timeout, code):
            return original_capture_process_output(process, None, None, cmd, timeout, code)

        with patch('azurelinuxagent.ga.exthandlers.ExtHandlerInstance._capture_process_output', side_effect=capture_process_output):
            output = self.ext_handler_instance.launch_command(command)

        self.assertIn("[stderr]\nCannot read stdout/stderr:", output)

class TestCGroupsTelemetry(AgentTestCase):
    TestProcessIds = ["1000", "1001", "1002"]
    TestProcStatmMemoryUsed = 1234
    TestProcComm = "python"
    TestProcCommandLine = "python -u bin/WALinuxAgent-2.2.45-py2.7.egg -run-exthandlers"
    NumSummarizationValues = 7

    @classmethod
    def setUpClass(cls):
        AgentTestCase.setUpClass()

        # Use the default value for memory used from proc_statm
        cls.mock_get_memory_usage_from_proc_statm = patch(
            "azurelinuxagent.common.resourceusage.MemoryResourceUsage."
            "get_memory_usage_from_proc_statm",
            return_value=TestCGroupsTelemetry.TestProcStatmMemoryUsed)
        cls.mock_get_memory_usage_from_proc_statm.start()

        # Every cgroup reports the same fixed list of tracked process ids (TestProcessIds)
        cls.mock_get_tracked_processes = patch(
            "azurelinuxagent.common.cgroup.CGroup.get_tracked_processes",
            return_value=TestCGroupsTelemetry.TestProcessIds)
        cls.mock_get_tracked_processes.start()

        # Every tracked process reports the same name and command line
        cls.mock_get_proc_name = patch(
            "azurelinuxagent.common.resourceusage.ProcessInfo.get_proc_name",
            return_value=TestCGroupsTelemetry.TestProcComm)
        cls.mock_get_proc_name.start()

        cls.mock_get_proc_cmdline = patch(
            "azurelinuxagent.common.resourceusage.ProcessInfo.get_proc_cmdline",
            return_value=TestCGroupsTelemetry.TestProcCommandLine)
        cls.mock_get_proc_cmdline.start()

        # CPU Cgroups compute usage based on /proc/stat and /sys/fs/cgroup/.../cpuacct.stat; redirect reads of those
        # files to the mock data under data_dir/cgroups and let every other read go through unchanged
        original_read_file = fileutil.read_file

        def mock_read_file(filepath, *args, **kwargs):
            # forward positional arguments too, so callers passing read_file options positionally keep working
            if filepath == "/proc/stat":
                filepath = os.path.join(data_dir, "cgroups", "proc_stat_t0")
            elif filepath.endswith("/cpuacct.stat"):
                filepath = os.path.join(data_dir, "cgroups", "cpuacct.stat_t0")
            return original_read_file(filepath, *args, **kwargs)

        cls._mock_read_cpu_cgroup_file = patch(
            "azurelinuxagent.common.utils.fileutil.read_file",
            side_effect=mock_read_file)
        cls._mock_read_cpu_cgroup_file.start()

    @classmethod
    def tearDownClass(cls):
        """Stops, in start order, the class-wide patches started by setUpClass, then tears down the base class."""
        started_patches = (
            "mock_get_memory_usage_from_proc_statm",
            "mock_get_tracked_processes",
            "mock_get_proc_name",
            "mock_get_proc_cmdline",
            "_mock_read_cpu_cgroup_file",
        )
        for attribute_name in started_patches:
            getattr(cls, attribute_name).stop()

        AgentTestCase.tearDownClass()

    def setUp(self):
        """Runs the base set-up, then clears CGroupsTelemetry so every test starts with nothing tracked."""
        super(TestCGroupsTelemetry, self).setUp()
        CGroupsTelemetry.reset()

    def tearDown(self):
        """Runs the base tear-down, then clears whatever CGroupsTelemetry tracked during the test."""
        super(TestCGroupsTelemetry, self).tearDown()
        CGroupsTelemetry.reset()

    @staticmethod
    def _track_new_extension_cgroups(num_extensions):
        """
        For each i in range(num_extensions), creates and tracks a "cpu" CGroup (path dummy_cpu_path_<i>) and then a
        "memory" CGroup (path dummy_memory_path_<i>), both belonging to extension dummy_extension_<i>.
        """
        for index in range(num_extensions):
            extension_name = "dummy_extension_{0}".format(index)
            for controller in ("cpu", "memory"):
                cgroup_path = "dummy_{0}_path_{1}".format(controller, index)
                CGroupsTelemetry.track_cgroup(CGroup.create(cgroup_path, controller, extension_name))

    def _assert_cgroups_are_tracked(self, num_extensions):
        """Asserts that the cpu and memory dummy cgroups of the first num_extensions extensions are tracked."""
        for index in range(num_extensions):
            for controller in ("cpu", "memory"):
                self.assertTrue(CGroupsTelemetry.is_tracked("dummy_{0}_path_{1}".format(controller, index)))

    def _assert_calculated_resource_metrics_equal(self,
                                                  cpu_usage,
                                                  memory_usage,
                                                  max_memory_usage,
                                                  memory_statm_memory_usage,
                                                  proc_ids=None):
        """
        Asserts that every entry of CGroupsTelemetry._cgroup_metrics holds exactly the given memory, max-memory and
        cpu data, and that each of its per-process statm metrics belongs to one of proc_ids and holds
        memory_statm_memory_usage.

        proc_ids falls back to TestProcessIds when it is None or empty.
        """
        expected_pids = proc_ids if proc_ids else TestCGroupsTelemetry.TestProcessIds
        expected_processes = [CGroupsTelemetry.get_process_info_summary(pid) for pid in expected_pids]

        for cgroup_metric in CGroupsTelemetry._cgroup_metrics.values():
            self.assertListEqual(cgroup_metric.get_memory_metrics()._data, memory_usage)
            self.assertListEqual(cgroup_metric.get_max_memory_metrics()._data, max_memory_usage)
            self.assertListEqual(cgroup_metric.get_cpu_metrics()._data, cpu_usage)

            for process_metric in cgroup_metric.get_proc_statm_memory_metrics():
                self.assertIn(process_metric.pid_name_cmdline, expected_processes)
                self.assertListEqual(process_metric.resource_metric._data, memory_statm_memory_usage)

    def _assert_polled_metrics_equal(self,
                                     metrics,
                                     cpu_metric_value,
                                     memory_metric_value,
                                     max_memory_metric_value,
                                     proc_stat_memory_usage_value,
                                     pids=None):
        """
        Asserts that every metric returned by a poll has an expected category/counter and the expected value.

        :param metrics: the polled metrics; each exposes category, counter, instance and value
        :param cpu_metric_value: expected value of the "Process" / "% Processor Time" counter
        :param memory_metric_value: expected value of the "Memory" / "Total Memory Usage" counter
        :param max_memory_metric_value: expected value of the "Memory" / "Max Memory Usage" counter
        :param proc_stat_memory_usage_value: expected value of every "Memory" / "Memory Used by Process" counter
        :param pids: processes a "Memory Used by Process" counter may refer to; TestProcessIds when None or empty
        """
        # The acceptable process instances do not depend on the metric being checked, so compute them only once
        expected_pids = pids if pids else TestCGroupsTelemetry.TestProcessIds
        processes_instances = [CGroupsTelemetry.get_process_info_summary(pid) for pid in expected_pids]

        for metric in metrics:
            self.assertIn(metric.category, ["Process", "Memory"])
            if metric.category == "Process":
                self.assertEqual(metric.counter, "% Processor Time")
                self.assertEqual(metric.value, cpu_metric_value)
            if metric.category == "Memory":
                self.assertIn(metric.counter, [
                    "Total Memory Usage", "Max Memory Usage",
                    "Memory Used by Process"
                ])
                if metric.counter == "Total Memory Usage":
                    self.assertEqual(metric.value, memory_metric_value)
                elif metric.counter == "Max Memory Usage":
                    self.assertEqual(metric.value, max_memory_metric_value)
                elif metric.counter == "Memory Used by Process":
                    self.assertIn(metric.instance, processes_instances)
                    self.assertEqual(metric.value, proc_stat_memory_usage_value)

    def _assert_extension_metrics_data(self,
                                       collected_metrics,
                                       num_extensions,
                                       cpu_percent_values,
                                       proc_stat_memory_usage_values,
                                       memory_usage_values,
                                       max_memory_usage_values,
                                       is_cpu_present=True,
                                       is_memory_present=True):
        """
        Asserts that the report produced by CGroupsTelemetry.report_all_tracked() holds the expected summaries for
        each dummy_extension_<i>, i in range(num_extensions).

        Every summarized series must have NumSummarizationValues entries whose first five equal
        generate_metric_list() of the corresponding raw values. When neither cpu nor memory metrics are expected,
        the report must be an empty dict.
        """
        num_summarization_values = TestCGroupsTelemetry.NumSummarizationValues
        # one proc_statm entry is expected per process id in TestProcessIds
        num_tracked_processes = len(TestCGroupsTelemetry.TestProcessIds)

        if not (is_cpu_present or is_memory_present):
            # assertEqual, not the assertEquals alias (deprecated, removed in Python 3.12)
            self.assertEqual(collected_metrics, {})
            return

        for i in range(num_extensions):
            name = "dummy_extension_{0}".format(i)

            if is_memory_present:
                self.assertIn(name, collected_metrics)
                extension_metrics = collected_metrics[name]

                self.assertIn("memory", extension_metrics)
                memory_metrics = extension_metrics["memory"]
                self.assertIn("cur_mem", memory_metrics)
                self.assertIn("max_mem", memory_metrics)
                self.assertEqual(num_summarization_values, len(memory_metrics["cur_mem"]))
                self.assertEqual(num_summarization_values, len(memory_metrics["max_mem"]))

                self.assertIn("proc_statm_memory", extension_metrics)
                proc_statm_metrics = extension_metrics["proc_statm_memory"]
                self.assertEqual(num_tracked_processes, len(proc_statm_metrics))
                for tracked_process in proc_statm_metrics:
                    self.assertEqual(num_summarization_values, len(proc_statm_metrics[tracked_process]))
                    self.assertListEqual(
                        generate_metric_list(proc_stat_memory_usage_values),
                        proc_statm_metrics[tracked_process][0:5])

                self.assertListEqual(generate_metric_list(memory_usage_values), memory_metrics["cur_mem"][0:5])
                self.assertListEqual(generate_metric_list(max_memory_usage_values), memory_metrics["max_mem"][0:5])

            if is_cpu_present:
                self.assertIn("cpu", collected_metrics[name])
                cpu_metrics = collected_metrics[name]["cpu"]
                self.assertIn("cur_cpu", cpu_metrics)
                self.assertEqual(num_summarization_values, len(cpu_metrics["cur_cpu"]))
                self.assertListEqual(generate_metric_list(cpu_percent_values), cpu_metrics["cur_cpu"][0:5])

    def test_telemetry_polling_with_active_cgroups(self, *_):
        """
        Polls active dummy cgroups repeatedly with mocked cpu/memory readings and verifies, after each poll, both
        the returned metrics and the samples accumulated by CGroupsTelemetry. Then verifies the summarized report,
        and that after reporting the cgroups are still tracked but their accumulated samples are empty.
        """
        num_extensions = 3

        self._track_new_extension_cgroups(num_extensions)

        with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage") as patch_get_memory_max_usage:
            with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage") as patch_get_memory_usage:
                with patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage") as patch_get_cpu_usage:
                    with patch("azurelinuxagent.common.cgroup.CGroup.is_active") as patch_is_active:
                        patch_is_active.return_value = True

                        current_cpu = 30
                        current_memory = 209715200  # example 200 MB
                        current_max_memory = 471859200  # example 450 MB
                        current_proc_statm = TestCGroupsTelemetry.TestProcStatmMemoryUsed

                        # 1 CPU metric + 1 Current Memory + 1 Max memory + num_processes * memory from statm
                        num_processes = len(TestCGroupsTelemetry.TestProcessIds)
                        num_of_metrics_per_extn_expected = 1 + 1 + 1 + num_processes * 1
                        patch_get_cpu_usage.return_value = current_cpu
                        patch_get_memory_usage.return_value = current_memory
                        patch_get_memory_max_usage.return_value = current_max_memory
                        num_polls = 10

                        for data_count in range(1, num_polls + 1):
                            metrics = CGroupsTelemetry.poll_all_tracked()

                            self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), num_extensions)
                            # every poll appends one more sample to each accumulated series
                            self._assert_calculated_resource_metrics_equal(
                                cpu_usage=[current_cpu] * data_count,
                                memory_usage=[current_memory] * data_count,
                                max_memory_usage=[current_max_memory] * data_count,
                                proc_ids=TestCGroupsTelemetry.TestProcessIds,
                                memory_statm_memory_usage=[current_proc_statm] * data_count)
                            self.assertEqual(len(metrics), num_extensions * num_of_metrics_per_extn_expected)
                            self._assert_polled_metrics_equal(
                                metrics, current_cpu, current_memory, current_max_memory, current_proc_statm)

        collected_metrics = CGroupsTelemetry.report_all_tracked()

        self._assert_extension_metrics_data(
            collected_metrics,
            num_extensions,
            [current_cpu] * num_polls,
            [TestCGroupsTelemetry.TestProcStatmMemoryUsed] * num_polls,
            [current_memory] * num_polls,
            [current_max_memory] * num_polls,
            is_cpu_present=False)

        # after reporting, the cgroups remain tracked but the accumulated samples are empty
        # (the empty proc_ids argument falls back to TestProcessIds)
        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), num_extensions)
        self._assert_calculated_resource_metrics_equal([], [], [], [], [])

    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage",
           side_effect=raise_ioerror)
    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage",
           side_effect=raise_ioerror)
    @patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage",
           side_effect=raise_ioerror)
    @patch("azurelinuxagent.common.cgroup.CGroup.is_active",
           return_value=False)
    def test_telemetry_polling_with_inactive_cgroups(self, *_):
        """
        With every cgroup reported inactive and the CPU/memory usage reads raising IOError, a poll stops
        tracking all the cgroups and yields no metrics; reporting afterwards removes the per-extension entries.
        """
        extension_count = 5

        self._track_new_extension_cgroups(extension_count)
        self._assert_cgroups_are_tracked(extension_count)

        polled = CGroupsTelemetry.poll_all_tracked()

        for index in range(extension_count):
            for controller in ("cpu", "memory"):
                self.assertFalse(
                    CGroupsTelemetry.is_tracked("dummy_{0}_path_{1}".format(controller, index)))

        # The per-extension entries are still present after the poll, but they hold no data.
        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), extension_count)
        self._assert_calculated_resource_metrics_equal([], [], [], [],
                                                       proc_ids=None)
        self.assertEqual(len(polled), 0)

        collected = CGroupsTelemetry.report_all_tracked()
        self._assert_extension_metrics_data(collected,
                                            extension_count, [], [], [], [],
                                            is_cpu_present=False,
                                            is_memory_present=False)
        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), 0)
        self._assert_calculated_resource_metrics_equal([], [], [], [], [])

    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage")
    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage")
    @patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage")
    @patch("azurelinuxagent.common.cgroup.CGroup.is_active")
    @patch(
        "azurelinuxagent.common.resourceusage.MemoryResourceUsage.get_memory_usage_from_proc_statm"
    )
    def test_telemetry_polling_with_changing_cgroups_state(
            self, patch_get_statm, patch_is_active, patch_get_cpu_usage,
            patch_get_mem, patch_get_max_mem, *args):
        """
        Poll once while the cgroups are active, then make them inactive with every read failing.
        The second poll must stop tracking them, while the data from the first poll is kept until
        report_all_tracked() is called.
        """
        extension_count = 5
        self._track_new_extension_cgroups(extension_count)

        patch_is_active.return_value = True

        cpu_value = 30
        memory_value = 209715200  # example 200 MB
        max_memory_value = 471859200  # example 450 MB
        statm_value = 20000000

        patch_get_cpu_usage.return_value = cpu_value
        patch_get_mem.return_value = memory_value
        patch_get_max_mem.return_value = max_memory_value
        patch_get_statm.return_value = statm_value

        self._assert_cgroups_are_tracked(extension_count)
        CGroupsTelemetry.poll_all_tracked()
        self._assert_cgroups_are_tracked(extension_count)

        # From here on, every cgroup reports itself inactive and every usage read fails.
        patch_is_active.return_value = False
        for failing_mock in (patch_get_cpu_usage, patch_get_mem, patch_get_max_mem, patch_get_statm):
            failing_mock.side_effect = raise_ioerror

        CGroupsTelemetry.poll_all_tracked()

        for index in range(extension_count):
            for controller in ("cpu", "memory"):
                self.assertFalse(
                    CGroupsTelemetry.is_tracked("dummy_{0}_path_{1}".format(controller, index)))

        # Only the single successful poll is buffered for each extension.
        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), extension_count)
        self._assert_calculated_resource_metrics_equal(
            cpu_usage=[cpu_value],
            memory_usage=[memory_value],
            max_memory_usage=[max_memory_value],
            proc_ids=TestCGroupsTelemetry.TestProcessIds,
            memory_statm_memory_usage=[statm_value])

        CGroupsTelemetry.report_all_tracked()

        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), 0)
        self._assert_calculated_resource_metrics_equal([], [], [], [], [])

    # fileutil.read_file is patched in the test body to raise IOError(ENOENT), so the test does not depend on the
    # contents of /proc/stat (or any other file read through read_file) and can run on Mac and other systems.
    @patch("azurelinuxagent.common.logger.periodic_warn")
    def test_telemetry_polling_to_not_generate_transient_logs_ioerror_file_not_found(
            self, patch_periodic_warn):
        """Polling must not emit periodic warnings when reads fail with IOError(errno=ENOENT)."""
        self._track_new_extension_cgroups(1)
        self.assertEqual(0, patch_periodic_warn.call_count)

        # An IOError with errno=ENOENT is not expected to produce any log entry.
        file_not_found = IOError()
        file_not_found.errno = errno.ENOENT

        with patch("azurelinuxagent.common.utils.fileutil.read_file",
                   side_effect=file_not_found):
            for _ in range(1, 10):
                CGroupsTelemetry.poll_all_tracked()
                self.assertEqual(0, patch_periodic_warn.call_count)

    @patch("azurelinuxagent.common.logger.periodic_warn")
    def test_telemetry_polling_to_generate_transient_logs_ioerror_permission_denied(
            self, patch_periodic_warn):
        """
        Reads failing with IOError(errno=EPERM) must be logged. The warning count is checked after
        every poll and is expected to stay at the value reached by the first poll.
        """
        num_extensions = 1
        num_controllers = 2
        is_active_check_per_controller = 2
        self._track_new_extension_cgroups(num_extensions)

        self.assertEqual(0, patch_periodic_warn.call_count)

        permission_error = IOError()
        permission_error.errno = errno.EPERM

        # Each collect per controller generates a log statement, and each cgroup's is_active check
        # raises an exception that is logged as well.
        expected_warnings = num_controllers + is_active_check_per_controller

        with patch("azurelinuxagent.common.utils.fileutil.read_file",
                   side_effect=permission_error):
            for _ in range(1, 10):
                CGroupsTelemetry.poll_all_tracked()
                self.assertEqual(expected_warnings, patch_periodic_warn.call_count)

    def test_telemetry_polling_to_generate_transient_logs_index_error(self):
        """A non-IOError raised while reading cgroup data must produce one periodic warning per controller."""
        self._track_new_extension_cgroups(1)

        # read_file returns an empty string so that parsing fails with a non-IOError
        # (intended to be an IndexError during the parameter lookup).
        with patch("azurelinuxagent.common.utils.fileutil.read_file", return_value=''), \
                patch("azurelinuxagent.common.logger.periodic_warn") as patch_periodic_warn:
            # 1 periodic warning for the cpu cgroup and 1 for the memory cgroup.
            warnings_expected = 2
            for _ in range(1, 10):
                CGroupsTelemetry.poll_all_tracked()
                self.assertEqual(warnings_expected, patch_periodic_warn.call_count)

    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage")
    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage")
    @patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage")
    @patch("azurelinuxagent.common.cgroup.CGroup.is_active")
    @patch(
        "azurelinuxagent.common.resourceusage.MemoryResourceUsage.get_memory_usage_from_proc_statm"
    )
    def test_telemetry_calculations(self, patch_get_statm, patch_is_active,
                                    patch_get_cpu_usage,
                                    patch_get_memory_usage,
                                    patch_get_memory_max_usage, *args):
        """
        Feed random CPU/memory readings over several polls and check that the metrics returned by each
        poll, and the data returned by report_all_tracked(), match those readings.
        """
        num_polls = 10
        num_extensions = 1
        num_controllers = 2  # cpu and memory
        # 1 CPU metric + 1 Current Memory + 1 Max memory + num_processes (3) * memory from statm
        num_of_metrics_per_extn_expected = 1 + 1 + 1 + 3 * 1

        # The readings are random: only the calculations are verified, not the validity of the values.
        cpu_percent_values = [random.randint(0, 100) for _ in range(num_polls)]
        memory_usage_values = [
            random.randint(0, 8 * 1024**3) for _ in range(num_polls)
        ]
        max_memory_usage_values = [
            random.randint(0, 8 * 1024**3) for _ in range(num_polls)
        ]
        proc_stat_memory_usage_values = [
            random.randint(0, 8 * 1024**3) for _ in range(num_polls)
        ]

        self._track_new_extension_cgroups(num_extensions)
        self.assertEqual(num_controllers * num_extensions, len(CGroupsTelemetry._tracked))

        # The cgroups stay active for the whole test; only the readings change from poll to poll.
        patch_is_active.return_value = True

        for cpu, memory, max_memory, statm in zip(cpu_percent_values,
                                                  memory_usage_values,
                                                  max_memory_usage_values,
                                                  proc_stat_memory_usage_values):
            patch_get_cpu_usage.return_value = cpu
            patch_get_memory_usage.return_value = memory
            patch_get_memory_max_usage.return_value = max_memory
            patch_get_statm.return_value = statm

            metrics = CGroupsTelemetry.poll_all_tracked()

            self.assertEqual(len(metrics), num_of_metrics_per_extn_expected * num_extensions)
            self._assert_polled_metrics_equal(metrics, cpu, memory, max_memory, statm)

        collected_metrics = CGroupsTelemetry.report_all_tracked()
        self._assert_extension_metrics_data(collected_metrics, num_extensions,
                                            cpu_percent_values,
                                            proc_stat_memory_usage_values,
                                            memory_usage_values,
                                            max_memory_usage_values)

    def test_cgroup_tracking(self, *args):
        """Tracking N extensions registers one cgroup per controller (cpu, memory) for each extension."""
        extension_count, controllers_per_extension = 5, 2
        self._track_new_extension_cgroups(extension_count)
        self._assert_cgroups_are_tracked(extension_count)
        expected_tracked = extension_count * controllers_per_extension
        self.assertEqual(expected_tracked, len(CGroupsTelemetry._tracked))

    def test_cgroup_pruning(self, *args):
        """prune_all_tracked() stops tracking every extension cgroup."""
        extension_count, controllers_per_extension = 5, 2
        self._track_new_extension_cgroups(extension_count)
        self._assert_cgroups_are_tracked(extension_count)
        self.assertEqual(extension_count * controllers_per_extension,
                         len(CGroupsTelemetry._tracked))

        CGroupsTelemetry.prune_all_tracked()

        for index in range(extension_count):
            for controller in ("cpu", "memory"):
                self.assertFalse(
                    CGroupsTelemetry.is_tracked("dummy_{0}_path_{1}".format(controller, index)))

        self.assertEqual(0, len(CGroupsTelemetry._tracked))

    def test_cgroup_is_tracked(self, *args):
        """is_tracked() returns False for paths that were never registered."""
        extension_count = 5
        self._track_new_extension_cgroups(extension_count)
        self._assert_cgroups_are_tracked(extension_count)
        for untracked_path in ("not_present_cpu_dummy_path", "not_present_memory_dummy_path"):
            self.assertFalse(CGroupsTelemetry.is_tracked(untracked_path))

    @patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage",
           side_effect=raise_ioerror)
    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage",
           side_effect=raise_ioerror)
    def test_process_cgroup_metric_with_incorrect_cgroups_mounted(self, *args):
        """
        When the CPU and memory usage reads raise IOError, polling yields no metrics and processing
        each extension's buffered data yields an empty dict.
        """
        extension_count = 5
        self._track_new_extension_cgroups(extension_count)

        for _ in range(1, 10):
            self.assertEqual(len(CGroupsTelemetry.poll_all_tracked()), 0)

        self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), extension_count)

        for cgroup_metrics in CGroupsTelemetry._cgroup_metrics.values():
            processed = CGroupsTelemetry._process_cgroup_metric(cgroup_metrics)
            self.assertEqual(processed, {})  # empty

    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage",
           side_effect=raise_ioerror)
    def test_process_cgroup_metric_with_no_memory_cgroup_mounted(self, *args):
        """Only CPU data is collected when reading memory usage always raises IOError."""
        extension_count = 5
        self._track_new_extension_cgroups(extension_count)

        cpu_value = 30
        with patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage",
                   return_value=cpu_value), \
                patch("azurelinuxagent.common.cgroup.CGroup.is_active",
                      return_value=True):
            for poll_number in range(1, 10):
                metrics = CGroupsTelemetry.poll_all_tracked()

                self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), extension_count)
                self._assert_calculated_resource_metrics_equal(
                    cpu_usage=[cpu_value] * poll_number,
                    memory_usage=[],
                    max_memory_usage=[],
                    proc_ids=[],
                    memory_statm_memory_usage=[])
                # Only the CPU metric is populated for each extension.
                self.assertEqual(len(metrics), extension_count)
                self._assert_polled_metrics_equal(metrics, cpu_value, 0, 0, 0)

            CGroupsTelemetry.report_all_tracked()

            self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), extension_count)
            self._assert_calculated_resource_metrics_equal([], [], [], [], [])

    @patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage",
           side_effect=raise_ioerror)
    def test_process_cgroup_metric_with_no_cpu_cgroup_mounted(self, *args):
        """Only memory data is collected when reading CPU usage always raises IOError."""
        extension_count = 5
        self._track_new_extension_cgroups(extension_count)

        memory_value = 209715200  # example 200 MB
        max_memory_value = 471859200  # example 450 MB
        statm_value = TestCGroupsTelemetry.TestProcStatmMemoryUsed
        poll_total = 10

        with patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage",
                   return_value=max_memory_value), \
                patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage",
                      return_value=memory_value), \
                patch("azurelinuxagent.common.cgroup.CGroup.is_active",
                      return_value=True):
            for poll_number in range(1, poll_total + 1):
                metrics = CGroupsTelemetry.poll_all_tracked()
                self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), extension_count)
                self._assert_calculated_resource_metrics_equal(
                    cpu_usage=[],
                    memory_usage=[memory_value] * poll_number,
                    max_memory_usage=[max_memory_value] * poll_number,
                    memory_statm_memory_usage=[statm_value] * poll_number,
                    proc_ids=TestCGroupsTelemetry.TestProcessIds)
                # Only memory is populated (CPU is not), giving 5 metrics per cgroup.
                self.assertEqual(len(metrics), extension_count * 5)
                self._assert_polled_metrics_equal(
                    metrics, 0, memory_value, max_memory_value, statm_value)

            collected_metrics = CGroupsTelemetry.report_all_tracked()
            self._assert_extension_metrics_data(
                collected_metrics,
                extension_count, [],
                [statm_value] * poll_total, [memory_value] * poll_total,
                [max_memory_value] * poll_total,
                is_cpu_present=False)

            self.assertEqual(len(CGroupsTelemetry._cgroup_metrics), extension_count)
            self._assert_calculated_resource_metrics_equal([], [], [], [], [])

    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_memory_usage",
           side_effect=raise_ioerror)
    @patch("azurelinuxagent.common.cgroup.MemoryCgroup.get_max_memory_usage",
           side_effect=raise_ioerror)
    @patch("azurelinuxagent.common.cgroup.CpuCgroup.get_cpu_usage",
           side_effect=raise_ioerror)
    def test_extension_telemetry_not_sent_for_empty_perf_metrics(self, *args):
        """Nothing is polled or reported when the cgroups are inactive and processing yields an empty dict."""
        extension_count = 5
        self._track_new_extension_cgroups(extension_count)

        with patch("azurelinuxagent.common.cgroupstelemetry.CGroupsTelemetry._process_cgroup_metric",
                   return_value={}), \
                patch("azurelinuxagent.common.cgroup.CGroup.is_active",
                      return_value=False):
            for _ in range(1, 10):
                self.assertEqual(0, len(CGroupsTelemetry.poll_all_tracked()))

            self.assertEqual(0, len(CGroupsTelemetry.report_all_tracked()))

    @skip_if_predicate_true(
        lambda: True,
        "Skipping this test currently: We need two different tests - one for "
        "FileSystemCgroupAPI based test and one for SystemDCgroupAPI based test. @vrdmr will "
        "be splitting this test in subsequent PRs")
    @skip_if_predicate_false(are_cgroups_enabled,
                             "Does not run when Cgroups are not enabled")
    @skip_if_predicate_true(is_trusty_in_travis,
                            "Does not run on Trusty in Travis")
    @attr('requires_sudo')
    @patch("azurelinuxagent.common.cgroupconfigurator.get_osutil",
           return_value=DefaultOSUtil())
    @patch("azurelinuxagent.common.cgroupapi.CGroupsApi._is_systemd",
           return_value=False)
    def test_telemetry_with_tracked_cgroup(self, *_):
        self.assertTrue(i_am_root(), "Test does not run when non-root")
        CGroupConfigurator._instance = None

        max_num_polls = 30
        time_to_wait = 3
        extn_name = "foobar-1.0.0"
        num_summarization_values = 7

        cgs = make_new_cgroup(extn_name)
        self.assertEqual(len(cgs), 2)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.0.0"
        self.ext_handler = ExtHandler(name='foobar')
        self.ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(
            ext_handler=self.ext_handler, protocol=None)

        command = self.create_script(
            "keep_cpu_busy_and_consume_memory_for_5_seconds", '''
nohup python -c "import time

for i in range(5):
    x = [1, 2, 3, 4, 5] * (i * 1000)
    time.sleep({0})
    x *= 0
    print('Test loop')" &
'''.format(time_to_wait))

        self.log_dir = os.path.join(self.tmp_dir, "log")

        with patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir) as \
                patch_get_base_dir:
            with patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir) as \
                    patch_get_log_dir:
                self.ext_handler_instance.launch_command(command)

        self.assertTrue(
            CGroupsTelemetry.is_tracked(
                os.path.join(BASE_CGROUPS, "cpu", "walinuxagent.extensions",
                             "foobar_1.0.0")))
        self.assertTrue(
            CGroupsTelemetry.is_tracked(
                os.path.join(BASE_CGROUPS, "memory", "walinuxagent.extensions",
                             "foobar_1.0.0")))

        for i in range(max_num_polls):
            CGroupsTelemetry.poll_all_tracked()
            time.sleep(0.5)

        collected_metrics = CGroupsTelemetry.report_all_tracked()

        self.assertIn("memory", collected_metrics[extn_name])
        self.assertIn("cur_mem", collected_metrics[extn_name]["memory"])
        self.assertIn("max_mem", collected_metrics[extn_name]["memory"])
        self.assertEqual(
            len(collected_metrics[extn_name]["memory"]["cur_mem"]),
            num_summarization_values)
        self.assertEqual(
            len(collected_metrics[extn_name]["memory"]["max_mem"]),
            num_summarization_values)

        self.assertIsInstance(
            collected_metrics[extn_name]["memory"]["cur_mem"][5], str)
        self.assertIsInstance(
            collected_metrics[extn_name]["memory"]["cur_mem"][6], str)
        self.assertIsInstance(
            collected_metrics[extn_name]["memory"]["max_mem"][5], str)
        self.assertIsInstance(
            collected_metrics[extn_name]["memory"]["max_mem"][6], str)

        self.assertIn("cpu", collected_metrics[extn_name])
        self.assertIn("cur_cpu", collected_metrics[extn_name]["cpu"])
        self.assertEqual(len(collected_metrics[extn_name]["cpu"]["cur_cpu"]),
                         num_summarization_values)

        self.assertIsInstance(
            collected_metrics[extn_name]["cpu"]["cur_cpu"][5], str)
        self.assertIsInstance(
            collected_metrics[extn_name]["cpu"]["cur_cpu"][6], str)

        for i in range(5):
            self.assertGreater(
                collected_metrics[extn_name]["memory"]["cur_mem"][i], 0)
            self.assertGreater(
                collected_metrics[extn_name]["memory"]["max_mem"][i], 0)
            self.assertGreaterEqual(
                collected_metrics[extn_name]["cpu"]["cur_cpu"][i], 0)
class DownloadExtensionTestCase(AgentTestCase):
    """
    Test cases for ExtHandlerInstance.download(): downloading the extension package, falling back to the
    alternate URIs when a download fails or returns an invalid package, and expanding the package into
    the extension directory.

    NOTE(review): another class named DownloadExtensionTestCase is defined immediately after this one in
    the same module. That later definition rebinds the module-level name, so test loaders never collect
    the tests in this class. This version patches azurelinuxagent.ga.exthandlers.CGroups/CGroupsTelemetry
    while the later one patches CGroupConfigurator -- confirm which definition is current and delete
    the other.
    """
    # Name of the script stored inside the test extension package.
    _extension_command = "RunCommandLinux.sh"

    # Alternate (fake) blob URIs for the package; {0} is replaced by a two-digit storage account index.
    _package_uri_template = 'https://zrdfepirv2cy4prdstr{0:02d}a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
    _package_uri_count = 5

    @classmethod
    def setUpClass(cls):
        AgentTestCase.setUpClass()
        # Replace the cgroup helpers referenced by exthandlers for the whole test class.
        cls.mock_cgroups = patch("azurelinuxagent.ga.exthandlers.CGroups")
        cls.mock_cgroups.start()

        cls.mock_cgroups_telemetry = patch("azurelinuxagent.ga.exthandlers.CGroupsTelemetry")
        cls.mock_cgroups_telemetry.start()

    @classmethod
    def tearDownClass(cls):
        cls.mock_cgroups_telemetry.stop()
        cls.mock_cgroups.stop()

        AgentTestCase.tearDownClass()

    def setUp(self):
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.0.0"
        ext_handler = ExtHandler(name='Microsoft.CPlat.Core.RunCommandLinux')
        ext_handler.properties = ext_handler_properties

        protocol = WireProtocol("http://Microsoft.CPlat.Core.RunCommandLinux/foo-bar")

        # Several URIs for the same package, so download() can fall back to the next one on failure.
        self.pkg = ExtHandlerPackage()
        uris = []
        for index in range(DownloadExtensionTestCase._package_uri_count):
            version_uri = ExtHandlerVersionUri()
            version_uri.uri = DownloadExtensionTestCase._package_uri_template.format(index)
            uris.append(version_uri)
        self.pkg.uris = uris

        self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=protocol)
        self.ext_handler_instance.pkg = self.pkg

        # Redirect the extension, log and agent library directories into the test's temporary directory.
        self.extension_dir = os.path.join(self.tmp_dir, "Microsoft.CPlat.Core.RunCommandLinux-1.0.0")
        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", return_value=self.extension_dir)
        self.mock_get_base_dir.start()

        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", return_value=self.tmp_dir)
        self.mock_get_log_dir.start()

        self.agent_dir = self.tmp_dir
        self.mock_get_lib_dir = patch("azurelinuxagent.ga.exthandlers.conf.get_lib_dir", return_value=self.agent_dir)
        self.mock_get_lib_dir.start()

    def tearDown(self):
        self.mock_get_lib_dir.stop()
        self.mock_get_log_dir.stop()
        self.mock_get_base_dir.stop()

        AgentTestCase.tearDown(self)

    @staticmethod
    def _create_zip_file(filename):
        """Write a valid ZIP package to *filename* containing a single shell script (the extension command)."""
        with zipfile.ZipFile(filename, "w") as zip_file:
            info = zipfile.ZipInfo(DownloadExtensionTestCase._extension_command)
            # Stamp the entry with the current local time.
            info.date_time = time.localtime(time.time())[:6]
            info.compress_type = zipfile.ZIP_DEFLATED
            zip_file.writestr(info, "#!/bin/sh\necho 'RunCommandLinux executed successfully'\n")

    @staticmethod
    def _create_invalid_zip_file(filename):
        """Write a plain-text file (not a ZIP archive) to *filename*."""
        with open(filename, "w") as invalid_file:
            invalid_file.write("An invalid ZIP file\n")

    def _get_extension_package_file(self):
        """Path where the package downloaded from the first URI is expected to be saved."""
        return os.path.join(self.agent_dir, os.path.basename(self.pkg.uris[0].uri) + ".zip")

    def _get_extension_command_file(self):
        """Path where the extension command is expected after the package is expanded."""
        return os.path.join(self.extension_dir, DownloadExtensionTestCase._extension_command)

    def _assert_download_and_expand_succeeded(self):
        self.assertTrue(os.path.exists(self._get_extension_package_file()), "The extension package was not downloaded to the expected location")
        self.assertTrue(os.path.exists(self._get_extension_command_file()), "The extension package was not expanded to the expected location")

    def test_it_should_download_and_expand_extension_package(self):
        def download_ext_handler_pkg(_uri, destination):
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        # first download attempt should succeed
        # (explicit call_count check: assert_called_once() is missing from older mock versions)
        self.assertEqual(1, mock_download_ext_handler_pkg.call_count)

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_existing_extension_package_when_already_downloaded(self):
        DownloadExtensionTestCase._create_zip_file(self._get_extension_package_file())

        with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg") as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        # explicit call_count check: assert_not_called() is missing from older mock versions
        self.assertEqual(0, mock_download_ext_handler_pkg.call_count)

        self.assertTrue(os.path.exists(self._get_extension_command_file()), "The extension package was not expanded to the expected location")

    def test_it_should_ignore_existing_extension_package_when_it_is_invalid(self):
        def download_ext_handler_pkg(_uri, destination):
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        DownloadExtensionTestCase._create_invalid_zip_file(self._get_extension_package_file())

        with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        self.assertEqual(1, mock_download_ext_handler_pkg.call_count)

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_alternate_uris_when_download_fails(self):
        self.download_failures = 0

        def download_ext_handler_pkg(_uri, destination):
            # fail a few times, then succeed
            if self.download_failures < 3:
                self.download_failures += 1
                return False
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count, self.download_failures + 1)

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_alternate_uris_when_download_raises_an_exception(self):
        self.download_failures = 0

        def download_ext_handler_pkg(_uri, destination):
            # fail a few times, then succeed
            if self.download_failures < 3:
                self.download_failures += 1
                raise Exception("Download failed")
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count, self.download_failures + 1)

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_alternate_uris_when_it_downloads_an_invalid_package(self):
        self.download_failures = 0

        def download_ext_handler_pkg(_uri, destination):
            # fail a few times, then succeed
            if self.download_failures < 3:
                self.download_failures += 1
                DownloadExtensionTestCase._create_invalid_zip_file(destination)
            else:
                DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count, self.download_failures + 1)

        self._assert_download_and_expand_succeeded()

    def test_it_should_raise_an_exception_when_all_downloads_fail(self):
        def download_ext_handler_pkg(_uri, _destination):
            DownloadExtensionTestCase._create_invalid_zip_file(self._get_extension_package_file())
            return True

        # Skip the waits between download retries.
        with patch("time.sleep", lambda *_: None):
            with patch("azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg", side_effect=download_ext_handler_pkg) as mock_download_ext_handler_pkg:
                with self.assertRaises(ExtensionDownloadError) as context_manager:
                    self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count, NUMBER_OF_DOWNLOAD_RETRIES * len(self.pkg.uris))

        self.assertRegex(str(context_manager.exception), "Failed to download extension")
        self.assertEqual(context_manager.exception.code, ExtensionErrorCodes.PluginManifestDownloadError)

        self.assertFalse(os.path.exists(self.extension_dir), "The extension directory was not removed")
        self.assertFalse(os.path.exists(self._get_extension_package_file()), "The extension package was not removed")
class DownloadExtensionTestCase(AgentTestCase):
    """
    Test cases for ExtHandlerInstance.download(): fetching the extension package (falling back to
    alternate URIs when a download fails, raises, or yields an invalid package) and expanding it
    into the extension directory
    """
    @classmethod
    def setUpClass(cls):
        AgentTestCase.setUpClass()
        # CGroupConfigurator is replaced with a mock for every test in this class
        cls.mock_cgroups = patch(
            "azurelinuxagent.ga.exthandlers.CGroupConfigurator")
        cls.mock_cgroups.start()

    @classmethod
    def tearDownClass(cls):
        cls.mock_cgroups.stop()

        AgentTestCase.tearDownClass()

    def setUp(self):
        """
        Create an ExtHandlerInstance for Microsoft.CPlat.Core.RunCommandLinux 1.0.0 whose package
        lists 5 download URIs, and redirect its base, log and lib directories into the temp dir.
        """
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.0.0"
        ext_handler = ExtHandler(name='Microsoft.CPlat.Core.RunCommandLinux')
        ext_handler.properties = ext_handler_properties

        protocol = WireProtocol(
            "http://Microsoft.CPlat.Core.RunCommandLinux/foo-bar")

        # five alternate locations for the same package; they differ only in the storage account
        # number (00 .. 04)
        uri_template = 'https://zrdfepirv2cy4prdstr{0:02d}a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0'
        self.pkg = ExtHandlerPackage()
        self.pkg.uris = [ExtHandlerVersionUri() for _ in range(5)]
        for index, version_uri in enumerate(self.pkg.uris):
            version_uri.uri = uri_template.format(index)

        self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler,
                                                       protocol=protocol)
        self.ext_handler_instance.pkg = self.pkg

        self.extension_dir = os.path.join(
            self.tmp_dir, "Microsoft.CPlat.Core.RunCommandLinux-1.0.0")
        self.mock_get_base_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
            return_value=self.extension_dir)
        self.mock_get_base_dir.start()

        self.mock_get_log_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
            return_value=self.tmp_dir)
        self.mock_get_log_dir.start()

        # the lib dir (where the package ZIP is stored) is the temp dir as well
        self.agent_dir = self.tmp_dir
        self.mock_get_lib_dir = patch(
            "azurelinuxagent.ga.exthandlers.conf.get_lib_dir",
            return_value=self.agent_dir)
        self.mock_get_lib_dir.start()

    def tearDown(self):
        self.mock_get_lib_dir.stop()
        self.mock_get_log_dir.stop()
        self.mock_get_base_dir.stop()

        AgentTestCase.tearDown(self)

    _extension_command = "RunCommandLinux.sh"

    @staticmethod
    def _create_zip_file(filename):
        """Write a valid ZIP package at `filename` containing a single shell script, RunCommandLinux.sh."""
        zip_file = None
        try:
            zip_file = zipfile.ZipFile(filename, "w")
            info = zipfile.ZipInfo(
                DownloadExtensionTestCase._extension_command)
            info.date_time = time.localtime(time.time())[:6]
            info.compress_type = zipfile.ZIP_DEFLATED
            zip_file.writestr(
                info,
                "#!/bin/sh\necho 'RunCommandLinux executed successfully'\n")
        finally:
            if zip_file is not None:
                zip_file.close()

    @staticmethod
    def _create_invalid_zip_file(filename):
        """Write a plain-text file at `filename` that is not a valid ZIP archive."""
        with open(filename, "w") as invalid_file:
            invalid_file.write("An invalid ZIP file\n")

    def _get_extension_package_file(self):
        """Path of the extension package ZIP inside the (patched) lib dir."""
        return os.path.join(
            self.agent_dir,
            self.ext_handler_instance.get_extension_package_zipfile_name())

    def _get_extension_command_file(self):
        """Path where the expanded package's script is expected inside the extension directory."""
        return os.path.join(self.extension_dir,
                            DownloadExtensionTestCase._extension_command)

    def _assert_download_and_expand_succeeded(self):
        """Assert that the package ZIP exists and that its script was expanded into the extension directory."""
        self.assertTrue(
            os.path.exists(self._get_extension_package_file()),
            "The extension package was not downloaded to the expected location"
        )
        self.assertTrue(
            os.path.exists(self._get_extension_command_file()),
            "The extension package was not expanded to the expected location")

    def test_it_should_download_and_expand_extension_package(self):
        def download_ext_handler_pkg(_uri, destination):
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            with patch(
                    "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.report_event"
            ) as mock_report_event:
                self.ext_handler_instance.download()

        # first download attempt should succeed
        mock_download_ext_handler_pkg.assert_called_once()
        mock_report_event.assert_called_once()

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_existing_extension_package_when_already_downloaded(
            self):
        DownloadExtensionTestCase._create_zip_file(
            self._get_extension_package_file())

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg"
        ) as mock_download_ext_handler_pkg:
            with patch(
                    "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.report_event"
            ) as mock_report_event:
                self.ext_handler_instance.download()

        # a valid package is already present, so nothing should be downloaded
        mock_download_ext_handler_pkg.assert_not_called()
        mock_report_event.assert_not_called()

        self.assertTrue(
            os.path.exists(self._get_extension_command_file()),
            "The extension package was not expanded to the expected location")

    def test_it_should_ignore_existing_extension_package_when_it_is_invalid(
            self):
        def download_ext_handler_pkg(_uri, destination):
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        DownloadExtensionTestCase._create_invalid_zip_file(
            self._get_extension_package_file())

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        mock_download_ext_handler_pkg.assert_called_once()

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_alternate_uris_when_download_fails(self):
        self.download_failures = 0

        def download_ext_handler_pkg(_uri, destination):
            # report failure for the first 3 attempts, then succeed
            if self.download_failures < 3:
                self.download_failures += 1
                return False
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count,
                         self.download_failures + 1)

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_alternate_uris_when_download_raises_an_exception(
            self):
        self.download_failures = 0

        def download_ext_handler_pkg(_uri, destination):
            # raise for the first 3 attempts, then succeed
            if self.download_failures < 3:
                self.download_failures += 1
                raise Exception("Download failed")
            DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count,
                         self.download_failures + 1)

        self._assert_download_and_expand_succeeded()

    def test_it_should_use_alternate_uris_when_it_downloads_an_invalid_package(
            self):
        self.download_failures = 0

        def download_ext_handler_pkg(_uri, destination):
            # produce an invalid package for the first 3 attempts, then a valid one
            if self.download_failures < 3:
                self.download_failures += 1
                DownloadExtensionTestCase._create_invalid_zip_file(destination)
            else:
                DownloadExtensionTestCase._create_zip_file(destination)
            return True

        with patch(
                "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                side_effect=download_ext_handler_pkg
        ) as mock_download_ext_handler_pkg:
            self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count,
                         self.download_failures + 1)

        self._assert_download_and_expand_succeeded()

    def test_it_should_raise_an_exception_when_all_downloads_fail(self):
        def download_ext_handler_pkg(_uri, _destination):
            # every attempt reports success but leaves an invalid ZIP file behind
            DownloadExtensionTestCase._create_invalid_zip_file(
                self._get_extension_package_file())
            return True

        # make sleep a no-op so any delays between attempts do not slow down the test
        with patch("time.sleep", lambda *_: None):
            with patch(
                    "azurelinuxagent.common.protocol.wire.WireProtocol.download_ext_handler_pkg",
                    side_effect=download_ext_handler_pkg
            ) as mock_download_ext_handler_pkg:
                with self.assertRaises(
                        ExtensionDownloadError) as context_manager:
                    self.ext_handler_instance.download()

        self.assertEqual(mock_download_ext_handler_pkg.call_count,
                         NUMBER_OF_DOWNLOAD_RETRIES * len(self.pkg.uris))

        self.assertRegex(str(context_manager.exception),
                         "Failed to download extension")
        self.assertEqual(context_manager.exception.code,
                         ExtensionErrorCodes.PluginManifestDownloadError)

        self.assertFalse(os.path.exists(self.extension_dir),
                         "The extension directory was not removed")
        self.assertFalse(os.path.exists(self._get_extension_package_file()),
                         "The extension package was not removed")
Example #26
0
class LaunchCommandTestCase(AgentTestCase):
    """
    Test cases for launch_command
    """

    def setUp(self):
        """
        Build an ExtHandlerInstance for handler 'foo' version 1.2.3 (wire address "1.2.3.4"), point
        its base dir at the temp dir and its log dir at <temp dir>/log, and patch time.sleep.
        """
        AgentTestCase.setUp(self)

        properties = ExtHandlerProperties()
        properties.version = "1.2.3"
        handler = ExtHandler(name='foo')
        handler.properties = properties
        self.ext_handler = handler
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=handler, protocol=WireProtocol("1.2.3.4"))

        # the extension's base directory is the test's temp dir itself
        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir)
        self.mock_get_log_dir.start()

        # every time.sleep call is routed to the mock_sleep helper with a fixed 0.01 argument
        self.mock_sleep = patch("time.sleep", lambda *_: mock_sleep(0.01))
        self.mock_sleep.start()

    def tearDown(self):
        """Stop the patches started in setUp, then run the base-class cleanup."""
        for patcher in (self.mock_get_log_dir, self.mock_get_base_dir, self.mock_sleep):
            patcher.stop()

        AgentTestCase.tearDown(self)

    @staticmethod
    def _output_regex(stdout, stderr):
        """
        Build a regex matching "[stdout]", the given stdout, "[stderr]", the given stderr, separated
        by whitespace. The arguments are inserted unescaped, so they act as regex fragments.
        """
        template = r"\[stdout\]\s+{0}\s+\[stderr\]\s+{1}"
        return template.format(stdout, stderr)

    @staticmethod
    def _find_process(command):
        """
        Return True if `command` appears in the /proc/<pid>/cmdline of any running process,
        False otherwise.

        Processes that exit while being inspected, or whose cmdline cannot be decoded as text in
        the current locale, are skipped instead of aborting the search.
        """
        for pid in os.listdir('/proc'):
            if not pid.isdigit():
                continue
            try:
                with open(os.path.join('/proc', pid, 'cmdline'), 'r') as cmdline:
                    for line in cmdline.readlines():
                        if command in line:
                            return True
            except IOError:  # proc has already terminated
                continue
            except UnicodeDecodeError:  # an unrelated process has a cmdline that is not valid text in this locale
                continue
        return False

    def test_it_should_capture_the_output_of_the_command(self):
        stdout = "stdout" * 5
        stderr = "stderr" * 5

        command = "produce_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")

'''.format(stdout, stderr))

        def list_directory():
            base_dir = self.ext_handler_instance.get_base_dir()
            return [i for i in os.listdir(base_dir) if not i.endswith(AGENT_EVENT_FILE_EXTENSION)] # ignore telemetry files

        files_before = list_directory()

        output = self.ext_handler_instance.launch_command(command)

        files_after = list_directory()

        self.assertRegex(output, LaunchCommandTestCase._output_regex(stdout, stderr))

        self.assertListEqual(files_before, files_after, "Not all temporary files were deleted. File list: {0}".format(files_after))

    def test_it_should_raise_an_exception_when_the_command_times_out(self):
        """
        When the command runs past its timeout, launch_command should raise ExtensionError with the
        caller-supplied error code and a "Timeout(...)" message containing the command's full path
        and its output; afterwards the command's process should no longer be found in /proc.
        """
        extension_error_code = ExtensionErrorCodes.PluginHandlerScriptTimedout
        stdout = "stdout" * 7
        stderr = "stderr" * 7

        # the signal file is used by the test command to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the test command produces some output then goes into an infinite loop
        command = "produce_output_then_hang.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys
import time

sys.stdout.write("{0}")
sys.stdout.flush()

sys.stderr.write("{1}")
sys.stderr.flush()

with open("{2}", "w") as file:
    while True:
        file.write(".")
        time.sleep(1)

'''.format(stdout, stderr, signal_file))

        # mock time.sleep to wait for the signal file (launch_command implements the time out using polling and sleep)
        # Until the signal file exists this delegates to `original_sleep`; afterwards it returns immediately, so the
        # polling loop runs through the timeout without real waiting. Note that `original_sleep` is captured before
        # the patch below but after setUp, which already patched time.sleep, so it is setUp's replacement.
        def sleep(seconds):
            if not os.path.exists(signal_file):
                sleep.original_sleep(seconds)
        sleep.original_sleep = time.sleep

        timeout = 60

        # wall-clock start, used at the end to check the test itself completed quickly
        start_time = time.time()

        with patch("time.sleep", side_effect=sleep, autospec=True) as mock_sleep:  # pylint: disable=redefined-outer-name

            with self.assertRaises(ExtensionError) as context_manager:
                self.ext_handler_instance.launch_command(command, timeout=timeout, extension_error_code=extension_error_code)

            # the command name and its output should be part of the message
            message = str(context_manager.exception)
            command_full_path = os.path.join(self.tmp_dir, command.lstrip(os.path.sep))
            self.assertRegex(message, r"Timeout\(\d+\):\s+{0}\s+{1}".format(command_full_path, LaunchCommandTestCase._output_regex(stdout, stderr)))

            # the exception code should be as specified in the call to launch_command
            self.assertEqual(context_manager.exception.code, extension_error_code)

            # the timeout period should have elapsed
            self.assertGreaterEqual(mock_sleep.call_count, timeout)

            # The command should have been terminated.
            # The /proc file system may still include the process when we do this check so we try a few times after a short delay; note that we
            # are mocking sleep, so we need to use the original implementation.
            terminated = False
            i = 0
            while not terminated and i < 4:
                if not LaunchCommandTestCase._find_process(command):
                    terminated = True
                else:
                    sleep.original_sleep(0.25)
                i += 1

            self.assertTrue(terminated, "The command was not terminated")

        # as a check for the test itself, verify it completed in just a few seconds
        self.assertLessEqual(time.time() - start_time, 5)

    def test_it_should_raise_an_exception_when_the_command_fails(self):
        extension_error_code = 2345
        stdout = "stdout" * 3
        stderr = "stderr" * 3
        exit_code = 99

        command = "fail.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")
exit({2})

'''.format(stdout, stderr, exit_code))

        # the output is captured as part of the exception message
        with self.assertRaises(ExtensionError) as context_manager:
            self.ext_handler_instance.launch_command(command, extension_error_code=extension_error_code)

        message = str(context_manager.exception)
        self.assertRegex(message, r"Non-zero exit code: {0}.+{1}\s+{2}".format(exit_code, command, LaunchCommandTestCase._output_regex(stdout, stderr)))

        self.assertEqual(context_manager.exception.code, extension_error_code)

    def test_it_should_not_wait_for_child_process(self):
        stdout = "stdout"
        stderr = "stderr"

        command = "start_child_process.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    time.sleep(60)
else:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    
'''.format(stdout, stderr))

        start_time = time.time()

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(time.time() - start_time, 5)

        # Also check that we capture the parent's output
        self.assertRegex(output, LaunchCommandTestCase._output_regex(stdout, stderr))

    def test_it_should_capture_the_output_of_child_process(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"
        more_parent_stdout = "MORE PARENT STDOUT"
        more_parent_stderr = "MORE PARENT STDERR"

        # the child process uses the signal file to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = "start_child_with_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import sys
import time

sys.stdout.write("{0}")
sys.stderr.write("{1}")

pid = os.fork()

if pid == 0:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")
    
    open("{6}", "w").close()
else:
    sys.stdout.write("{4}")
    sys.stderr.write("{5}")
    
    while not os.path.exists("{6}"):
        time.sleep(0.5)
    
'''.format(parent_stdout, parent_stderr, child_stdout, child_stderr, more_parent_stdout, more_parent_stderr, signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

        self.assertIn(more_parent_stdout, output)
        self.assertIn(more_parent_stderr, output)

    def test_it_should_capture_the_output_of_child_process_that_fails_to_start(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"

        command = "start_child_that_fails.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    exit(1)
else:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")

'''.format(child_stdout, child_stderr, parent_stdout, parent_stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

    def test_it_should_execute_commands_with_no_output(self):
        # file used to verify the command completed successfully
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = "create_file.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
open("{0}", "w").close()

'''.format(signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertTrue(os.path.exists(signal_file))
        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

    def test_it_should_not_capture_the_output_of_commands_that_do_their_own_redirection(self):
        # the test script redirects its output to this file
        command_output_file = os.path.join(self.tmp_dir, "command_output.txt")
        stdout = "STDOUT"
        stderr = "STDERR"

        # the test script mimics the redirection done by the Custom Script extension
        command = "produce_output"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
exec &> {0}
echo {1}
>&2 echo {2}

'''.format(command_output_file, stdout, stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

        with open(command_output_file, "r") as command_output:
            output = command_output.read()
            self.assertEqual(output, "{0}\n{1}\n".format(stdout, stderr))

    def test_it_should_truncate_the_command_output(self):
        stdout = "STDOUT"
        stderr = "STDERR"

        command = "produce_long_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write( "{0}" * {1})
sys.stderr.write( "{2}" * {3})
'''.format(stdout, int(TELEMETRY_MESSAGE_MAX_LEN / len(stdout)), stderr, int(TELEMETRY_MESSAGE_MAX_LEN / len(stderr))))

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
        self.assertIn(stdout, output)
        self.assertIn(stderr, output)

    def test_it_should_read_only_the_head_of_large_outputs(self):
        command = "produce_long_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("O" * 5 * 1024 * 1024)
sys.stderr.write("E" * 5 * 1024 * 1024)
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to format_stdout_stderr, which takes the
        # return value of the calls to file.read(). The intention of the test is to verify we never read (and load in memory)
        # more than a few KB of data from the files used to capture stdout/stderr
        with patch('azurelinuxagent.common.utils.extensionprocessutil.format_stdout_stderr', side_effect=format_stdout_stderr) as mock_format:
            output = self.ext_handler_instance.launch_command(command)

        self.assertGreaterEqual(len(output), 1024)
        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)

        mock_format.assert_called_once()

        args, kwargs = mock_format.call_args  # pylint: disable=unused-variable
        stdout, stderr = args

        self.assertGreaterEqual(len(stdout), 1024)
        self.assertLessEqual(len(stdout), TELEMETRY_MESSAGE_MAX_LEN)

        self.assertGreaterEqual(len(stderr), 1024)
        self.assertLessEqual(len(stderr), TELEMETRY_MESSAGE_MAX_LEN)

    def test_it_should_handle_errors_while_reading_the_command_output(self):
        command = "produce_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("STDOUT")
sys.stderr.write("STDERR")
''')
        # Mocking the call to file.read() is difficult, so instead we mock the call to_capture_process_output,
        # which will call file.read() and we force stdout/stderr to be None; this will produce an exception when
        # trying to use these files.
        original_capture_process_output = read_output

        def capture_process_output(stdout_file, stderr_file):  # pylint: disable=unused-argument
            return original_capture_process_output(None, None)

        with patch('azurelinuxagent.common.utils.extensionprocessutil.read_output', side_effect=capture_process_output):
            output = self.ext_handler_instance.launch_command(command)

        self.assertIn("[stderr]\nCannot read stdout/stderr:", output)

    def test_it_should_contain_all_helper_environment_variables(self):
        """
        The environment passed to subprocess.Popen should add exactly the helper variables (sequence
        number, extension path, extension version, wire protocol address) on top of os.environ, and the
        command should see each of them with the expected value.
        """
        wire_ip = str(uuid.uuid4())
        instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=WireProtocol(wire_ip))

        helper_env_vars = {
            ExtCommandEnvVariable.ExtensionSeqNumber: _DEFAULT_SEQ_NO,
            ExtCommandEnvVariable.ExtensionPath: self.tmp_dir,
            ExtCommandEnvVariable.ExtensionVersion: instance.ext_handler.properties.version,
            ExtCommandEnvVariable.WireProtocolAddress: wire_ip,
        }

        command = """
            printenv | grep -E '(%s)'
        """ % '|'.join(helper_env_vars.keys())

        script_name = 'printHelperEnvironments.sh'
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), script_name), command)

        with patch("subprocess.Popen", wraps=subprocess.Popen) as patch_popen:
            # supported features are covered by a separate test, so none are advertised here
            with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
                       return_value={}):
                output = instance.launch_command(script_name)

            popen_kwargs = patch_popen.call_args[1]
            added_env = dict((name, value) for (name, value) in popen_kwargs['env'].items() if name not in os.environ)

            # fails whenever a helper environment variable is added or removed without updating this test
            self.assertEqual(helper_env_vars, added_env)

            # the command must see every helper variable with its expected value
            for name, value in helper_env_vars.items():
                self.assertIn("%s=%s" % (name, value), output)

    def test_it_should_pass_supported_features_list_as_environment_variables(self):

        class TestFeature(AgentSupportedFeature):

            def __init__(self, name, version, supported):
                super(TestFeature, self).__init__(name=name,
                                                  version=version,
                                                  supported=supported)

        test_name = str(uuid.uuid4())
        test_version = str(uuid.uuid4())

        command = "check_env.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import json
import sys

features = os.getenv("{0}")
if not features:
    print("{0} not found in environment")
    sys.exit(0)
l = json.loads(features)
found = False
for feature in l:
    if feature['Key'] == "{1}" and feature['Value'] == "{2}":
        found = True
        break
    
print("Found Feature %s: %s" % ("{1}", found))
'''.format(ExtCommandEnvVariable.ExtensionSupportedFeatures, test_name, test_version))

        # It should include all supported features and pass it as Environment Variable to extensions
        test_supported_features = {test_name: TestFeature(name=test_name, version=test_version, supported=True)}
        with patch("azurelinuxagent.common.agent_supported_feature.__EXTENSION_ADVERTISED_FEATURES",
                   test_supported_features):
            output = self.ext_handler_instance.launch_command(command)

            self.assertIn("[stdout]\nFound Feature {0}: True".format(test_name), output, "Feature not found")

        # It should not include the feature if feature not supported
        test_supported_features = {
            test_name: TestFeature(name=test_name, version=test_version, supported=False),
            "testFeature": TestFeature(name="testFeature", version="1.2.1", supported=True)
        }
        with patch("azurelinuxagent.common.agent_supported_feature.__EXTENSION_ADVERTISED_FEATURES",
                   test_supported_features):
            output = self.ext_handler_instance.launch_command(command)

            self.assertIn("[stdout]\nFound Feature {0}: False".format(test_name), output, "Feature wrongfully found")

        # It should not include the SupportedFeatures Key in Environment variables if no features supported
        test_supported_features = {test_name: TestFeature(name=test_name, version=test_version, supported=False)}
        with patch("azurelinuxagent.common.agent_supported_feature.__EXTENSION_ADVERTISED_FEATURES",
                   test_supported_features):
            output = self.ext_handler_instance.launch_command(command)

            self.assertIn(
                "[stdout]\n{0} not found in environment".format(ExtCommandEnvVariable.ExtensionSupportedFeatures),
                output, "Environment variable should not be found")