# Example no. 1
class LaunchCommandTestCase(AgentTestCase):
    """
    Test cases for launch_command
    """
    @classmethod
    def setUpClass(cls):
        """
        Patches CGroups and CGroupsTelemetry in azurelinuxagent.ga.exthandlers
        for the whole test class; the patches are stopped in tearDownClass.
        """
        AgentTestCase.setUpClass()
        cls.mock_cgroups = patch("azurelinuxagent.ga.exthandlers.CGroups")
        cls.mock_cgroups.start()

        cls.mock_cgroups_telemetry = patch(
            "azurelinuxagent.ga.exthandlers.CGroupsTelemetry")
        cls.mock_cgroups_telemetry.start()

    @classmethod
    def tearDownClass(cls):
        """
        Stops the class-level patches in the reverse order they were started.
        """
        cls.mock_cgroups_telemetry.stop()
        cls.mock_cgroups.stop()

        AgentTestCase.tearDownClass()

    def setUp(self):
        """
        Creates the ExtHandlerInstance under test (extension 'foo', version
        1.2.3) and redirects its base and log directories into the test's
        temporary directory.
        """
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler,
                                                       protocol=None)

        self.mock_get_base_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
            lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # The patched get_log_dir returns self.log_dir when it is called, so the
        # path must be stored on the instance, not in a local variable.
        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = patch(
            "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
            lambda *_: self.log_dir)
        self.mock_get_log_dir.start()

    def tearDown(self):
        """
        Stops the per-test patches started in setUp.
        """
        self.mock_get_log_dir.stop()
        self.mock_get_base_dir.stop()

        AgentTestCase.tearDown(self)

    def _create_script(self, file_name, contents):
        """
        Creates an executable script with the given contents.
        If file_name ends with ".py", it creates a Python3 script, otherwise it creates a bash script

        The script is written to the handler's base directory with owner-only
        read/write/execute permissions. Returns file_name (not the full path);
        the tests pass that value to launch_command as the command.
        """
        file_path = os.path.join(self.ext_handler_instance.get_base_dir(),
                                 file_name)

        with open(file_path, "w") as script:
            if file_name.endswith(".py"):
                script.write("#!/usr/bin/env python3\n")
            else:
                script.write("#!/usr/bin/env bash\n")
            script.write(contents)

        os.chmod(file_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

        return file_name

    @staticmethod
    def _output_regex(stdout, stderr):
        """
        Returns a regex matching the "[stdout] ... [stderr] ..." section of the
        output produced by launch_command. stdout and stderr are escaped, so
        they are matched literally even if they contain regex metacharacters.
        """
        import re
        return r"\[stdout\]\s+{0}\s+\[stderr\]\s+{1}".format(re.escape(stdout),
                                                              re.escape(stderr))

    @staticmethod
    def _find_process(command):
        """
        Returns True if any process listed under /proc has `command` as a
        substring of its command line (Linux only), False otherwise.
        """
        for pid in os.listdir('/proc'):
            if not pid.isdigit():
                continue
            try:
                with open(os.path.join('/proc', pid, 'cmdline'),
                          'r') as cmdline:
                    if command in cmdline.read():
                        return True
            except (IOError, OSError, UnicodeDecodeError):
                # the process has already terminated, is not accessible, or
                # its command line is not valid text in the current encoding
                continue
        return False

    def test_it_should_capture_the_output_of_the_command(self):
        stdout = "stdout" * 5
        stderr = "stderr" * 5

        command = self._create_script(
            "produce_output.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")

'''.format(stdout, stderr))

        def list_directory():
            base_dir = self.ext_handler_instance.get_base_dir()
            return [i for i in os.listdir(base_dir)
                    if not i.endswith(".tld")]  # ignore telemetry files

        files_before = list_directory()

        output = self.ext_handler_instance.launch_command(command)

        files_after = list_directory()

        self.assertRegex(output, self._output_regex(stdout, stderr))

        self.assertListEqual(
            files_before, files_after,
            "Not all temporary files were deleted. File list: {0}".format(
                files_after))

    def test_it_should_raise_an_exception_when_the_command_times_out(self):
        import re

        extension_error_code = 1234
        stdout = "stdout" * 7
        stderr = "stderr" * 7

        # the signal file is used by the test command to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the test command produces some output then goes into an infinite loop
        command = self._create_script(
            "produce_output_then_hang.py", '''
import sys
import time

sys.stdout.write("{0}")
sys.stdout.flush()

sys.stderr.write("{1}")
sys.stderr.flush()

with open("{2}", "w") as file:
    while True:
        file.write(".")
        time.sleep(1)

'''.format(stdout, stderr, signal_file))

        # mock time.sleep to wait for the signal file (launch_command implements the time out using polling and sleep);
        # once the signal file exists the mock returns immediately, so the timeout elapses without real waiting
        original_sleep = time.sleep

        def sleep(seconds):
            if not os.path.exists(signal_file):
                original_sleep(seconds)

        timeout = 60

        start_time = time.time()

        with patch("time.sleep", side_effect=sleep,
                   autospec=True) as mock_sleep:

            with self.assertRaises(ExtensionError) as context_manager:
                self.ext_handler_instance.launch_command(
                    command,
                    timeout=timeout,
                    extension_error_code=extension_error_code)

            # the command name and its output should be part of the message
            message = str(context_manager.exception)
            self.assertRegex(
                message, r"Timeout\(\d+\):\s+{0}\s+{1}".format(
                    re.escape(command),
                    self._output_regex(stdout, stderr)))

            # the exception code should be as specified in the call to launch_command
            self.assertEqual(context_manager.exception.code,
                             extension_error_code)

            # the timeout period should have elapsed
            self.assertGreaterEqual(mock_sleep.call_count, timeout)

            # the command should have been terminated
            self.assertFalse(self._find_process(command),
                             "The command was not terminated")

        # as a check for the test itself, verify it completed in just a few seconds
        self.assertLessEqual(time.time() - start_time, 5)

    def test_it_should_raise_an_exception_when_the_command_fails(self):
        import re

        extension_error_code = 2345
        stdout = "stdout" * 3
        stderr = "stderr" * 3
        exit_code = 99

        command = self._create_script(
            "fail.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")
exit({2})

'''.format(stdout, stderr, exit_code))

        # the output is captured as part of the exception message
        with self.assertRaises(ExtensionError) as context_manager:
            self.ext_handler_instance.launch_command(
                command, extension_error_code=extension_error_code)

        message = str(context_manager.exception)
        self.assertRegex(
            message, r"Non-zero exit code: {0}.+{1}\s+{2}".format(
                exit_code, re.escape(command),
                self._output_regex(stdout, stderr)))

        self.assertEqual(context_manager.exception.code, extension_error_code)

    def test_it_should_not_wait_for_child_process(self):
        stdout = "stdout"
        stderr = "stderr"

        # the forked child sleeps for 60 seconds; launch_command must return well before that
        command = self._create_script(
            "start_child_process.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    time.sleep(60)
else:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    
'''.format(stdout, stderr))

        start_time = time.time()

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(time.time() - start_time, 5)

        # Also check that we capture the parent's output
        self.assertRegex(output, self._output_regex(stdout, stderr))

    def test_it_should_capture_the_output_of_child_process(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"
        more_parent_stdout = "MORE PARENT STDOUT"
        more_parent_stderr = "MORE PARENT STDERR"

        # the child process uses the signal file to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self._create_script(
            "start_child_with_output.py", '''
import os
import sys
import time

sys.stdout.write("{0}")
sys.stderr.write("{1}")

pid = os.fork()

if pid == 0:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")
    
    open("{6}", "w").close()
else:
    sys.stdout.write("{4}")
    sys.stderr.write("{5}")
    
    while not os.path.exists("{6}"):
        time.sleep(0.5)
    
'''.format(parent_stdout, parent_stderr, child_stdout, child_stderr,
           more_parent_stdout, more_parent_stderr, signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

        self.assertIn(more_parent_stdout, output)
        self.assertIn(more_parent_stderr, output)

    def test_it_should_capture_the_output_of_child_process_that_fails_to_start(
            self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"

        command = self._create_script(
            "start_child_that_fails.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    exit(1)
else:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")

'''.format(child_stdout, child_stderr, parent_stdout, parent_stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

    def test_it_should_execute_commands_with_no_output(self):
        # file used to verify the command completed successfully
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self._create_script(
            "create_file.py", '''
open("{0}", "w").close()

'''.format(signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertTrue(os.path.exists(signal_file))
        self.assertRegex(output, self._output_regex('', ''))

    def test_it_should_not_capture_the_output_of_commands_that_do_their_own_redirection(
            self):
        # the test script redirects its output to this file
        command_output_file = os.path.join(self.tmp_dir, "command_output.txt")
        stdout = "STDOUT"
        stderr = "STDERR"

        # the test script mimics the redirection done by the Custom Script extension
        command = self._create_script(
            "produce_output", '''
exec &> {0}
echo {1}
>&2 echo {2}

'''.format(command_output_file, stdout, stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertRegex(output, self._output_regex('', ''))

        with open(command_output_file, "r") as command_output:
            output = command_output.read()
            self.assertEqual(output, "{0}\n{1}\n".format(stdout, stderr))

    def test_it_should_truncate_the_command_output(self):
        stdout = "STDOUT"
        stderr = "STDERR"

        # each stream alone is about TELEMETRY_MESSAGE_MAX_LEN long, so the combined output must be truncated
        command = self._create_script(
            "produce_long_output.py", '''
import sys

sys.stdout.write( "{0}" * {1})
sys.stderr.write( "{2}" * {3})
'''.format(stdout, int(TELEMETRY_MESSAGE_MAX_LEN / len(stdout)), stderr,
           int(TELEMETRY_MESSAGE_MAX_LEN / len(stderr))))

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
        self.assertIn(stdout, output)
        self.assertIn(stderr, output)

    def test_it_should_read_only_the_head_of_large_outputs(self):
        command = self._create_script(
            "produce_long_output.py", '''
import sys

sys.stdout.write("O" * 5 * 1024 * 1024)
sys.stderr.write("E" * 5 * 1024 * 1024)
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to format_stdout_stderr, which takes the
        # return value of the calls to file.read(). The intention of the test is to verify we never read (and load in memory)
        # more than a few KB of data from the files used to capture stdout/stderr
        with patch('azurelinuxagent.ga.exthandlers.format_stdout_stderr',
                   side_effect=format_stdout_stderr) as mock_format:
            output = self.ext_handler_instance.launch_command(command)

        self.assertGreaterEqual(len(output), 1024)
        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)

        # an explicit call_count check works with every mock version (assert_called_once() is missing in old ones)
        self.assertEqual(mock_format.call_count, 1)

        args, _ = mock_format.call_args
        stdout, stderr = args

        self.assertGreaterEqual(len(stdout), 1024)
        self.assertLessEqual(len(stdout), TELEMETRY_MESSAGE_MAX_LEN)

        self.assertGreaterEqual(len(stderr), 1024)
        self.assertLessEqual(len(stderr), TELEMETRY_MESSAGE_MAX_LEN)

    def test_it_should_handle_errors_while_reading_the_command_output(self):
        command = self._create_script(
            "produce_output.py", '''
import sys

sys.stdout.write("STDOUT")
sys.stderr.write("STDERR")
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to _capture_process_output, which will
        # call file.read() and we force stdout/stderr to be None; this will produce an exception when trying to use these files.
        original_capture_process_output = ExtHandlerInstance._capture_process_output

        def capture_process_output(process, stdout_file, stderr_file, cmd,
                                   timeout, code):
            return original_capture_process_output(process, None, None, cmd,
                                                   timeout, code)

        with patch(
                'azurelinuxagent.ga.exthandlers.ExtHandlerInstance._capture_process_output',
                side_effect=capture_process_output):
            output = self.ext_handler_instance.launch_command(command)

        self.assertIn("[stderr]\nCannot read stdout/stderr:", output)
# Example no. 2
class LaunchCommandTestCase(AgentTestCase):
    """
    Test cases for launch_command
    """
    @classmethod
    def setUpClass(cls):
        """
        Patches CGroups and CGroupsTelemetry in azurelinuxagent.ga.exthandlers for the whole test class;
        the patches are stopped in tearDownClass.
        """
        AgentTestCase.setUpClass()
        cls.mock_cgroups = patch("azurelinuxagent.ga.exthandlers.CGroups")
        cls.mock_cgroups.start()

        cls.mock_cgroups_telemetry = patch("azurelinuxagent.ga.exthandlers.CGroupsTelemetry")
        cls.mock_cgroups_telemetry.start()

    @classmethod
    def tearDownClass(cls):
        """
        Stops the class-level patches in the reverse order they were started.
        """
        cls.mock_cgroups_telemetry.stop()
        cls.mock_cgroups.stop()

        AgentTestCase.tearDownClass()

    def setUp(self):
        """
        Creates the ExtHandlerInstance under test (extension 'foo', version 1.2.3) and redirects its base and
        log directories into the test's temporary directory.
        """
        AgentTestCase.setUp(self)

        ext_handler_properties = ExtHandlerProperties()
        ext_handler_properties.version = "1.2.3"
        ext_handler = ExtHandler(name='foo')
        ext_handler.properties = ext_handler_properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=ext_handler, protocol=None)

        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        # The patched get_log_dir returns self.log_dir when called, so the path must live on the instance
        self.log_dir = os.path.join(self.tmp_dir, "log")
        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", lambda *_: self.log_dir)
        self.mock_get_log_dir.start()

    def tearDown(self):
        """
        Stops the per-test patches started in setUp.
        """
        self.mock_get_log_dir.stop()
        self.mock_get_base_dir.stop()

        AgentTestCase.tearDown(self)

    def _create_script(self, file_name, contents):
        """
        Creates an executable script with the given contents.
        If file_name ends with ".py", it creates a Python3 script, otherwise it creates a bash script

        The script is written to the handler's base directory with owner-only read/write/execute permissions.
        Returns file_name (not the full path); the tests pass that value to launch_command as the command.
        """
        file_path = os.path.join(self.ext_handler_instance.get_base_dir(), file_name)

        with open(file_path, "w") as script:
            if file_name.endswith(".py"):
                script.write("#!/usr/bin/env python3\n")
            else:
                script.write("#!/usr/bin/env bash\n")
            script.write(contents)

        os.chmod(file_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

        return file_name

    @staticmethod
    def _output_regex(stdout, stderr):
        """
        Returns a regex matching the "[stdout] ... [stderr] ..." section of the output produced by launch_command.
        stdout and stderr are escaped, so they are matched literally even if they contain regex metacharacters.
        """
        import re
        return r"\[stdout\]\s+{0}\s+\[stderr\]\s+{1}".format(re.escape(stdout), re.escape(stderr))

    @staticmethod
    def _find_process(command):
        """
        Returns True if any process listed under /proc has `command` as a substring of its command line
        (Linux only), False otherwise.
        """
        for pid in os.listdir('/proc'):
            if not pid.isdigit():
                continue
            try:
                with open(os.path.join('/proc', pid, 'cmdline'), 'r') as cmdline:
                    if command in cmdline.read():
                        return True
            except (IOError, OSError, UnicodeDecodeError):
                # the process has already terminated, is not accessible, or its command line is not valid text
                continue
        return False

    def test_it_should_capture_the_output_of_the_command(self):
        stdout = "stdout" * 5
        stderr = "stderr" * 5

        command = self._create_script("produce_output.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")

'''.format(stdout, stderr))

        def list_directory():
            base_dir = self.ext_handler_instance.get_base_dir()
            return [i for i in os.listdir(base_dir) if not i.endswith(".tld")] # ignore telemetry files

        files_before = list_directory()

        output = self.ext_handler_instance.launch_command(command)

        files_after = list_directory()

        self.assertRegex(output, self._output_regex(stdout, stderr))

        self.assertListEqual(files_before, files_after, "Not all temporary files were deleted. File list: {0}".format(files_after))

    def test_it_should_raise_an_exception_when_the_command_times_out(self):
        import re

        extension_error_code = 1234
        stdout = "stdout" * 7
        stderr = "stderr" * 7

        # the signal file is used by the test command to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the test command produces some output then goes into an infinite loop
        command = self._create_script("produce_output_then_hang.py", '''
import sys
import time

sys.stdout.write("{0}")
sys.stdout.flush()

sys.stderr.write("{1}")
sys.stderr.flush()

with open("{2}", "w") as file:
    while True:
        file.write(".")
        time.sleep(1)

'''.format(stdout, stderr, signal_file))

        # mock time.sleep to wait for the signal file (launch_command implements the time out using polling and sleep);
        # once the signal file exists the mock returns immediately, so the timeout elapses without real waiting
        original_sleep = time.sleep

        def sleep(seconds):
            if not os.path.exists(signal_file):
                original_sleep(seconds)

        timeout = 60

        start_time = time.time()

        with patch("time.sleep", side_effect=sleep, autospec=True) as mock_sleep:

            with self.assertRaises(ExtensionError) as context_manager:
                self.ext_handler_instance.launch_command(command, timeout=timeout, extension_error_code=extension_error_code)

            # the command name and its output should be part of the message
            message = str(context_manager.exception)
            self.assertRegex(message, r"Timeout\(\d+\):\s+{0}\s+{1}".format(re.escape(command), self._output_regex(stdout, stderr)))

            # the exception code should be as specified in the call to launch_command
            self.assertEqual(context_manager.exception.code, extension_error_code)

            # the timeout period should have elapsed
            self.assertGreaterEqual(mock_sleep.call_count, timeout)

            # the command should have been terminated
            self.assertFalse(self._find_process(command), "The command was not terminated")

        # as a check for the test itself, verify it completed in just a few seconds
        self.assertLessEqual(time.time() - start_time, 5)

    def test_it_should_raise_an_exception_when_the_command_fails(self):
        import re

        extension_error_code = 2345
        stdout = "stdout" * 3
        stderr = "stderr" * 3
        exit_code = 99

        command = self._create_script("fail.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")
exit({2})

'''.format(stdout, stderr, exit_code))

        # the output is captured as part of the exception message
        with self.assertRaises(ExtensionError) as context_manager:
            self.ext_handler_instance.launch_command(command, extension_error_code=extension_error_code)

        message = str(context_manager.exception)
        self.assertRegex(message, r"Non-zero exit code: {0}.+{1}\s+{2}".format(exit_code, re.escape(command), self._output_regex(stdout, stderr)))

        self.assertEqual(context_manager.exception.code, extension_error_code)

    def test_it_should_not_wait_for_child_process(self):
        stdout = "stdout"
        stderr = "stderr"

        # the forked child sleeps for 60 seconds; launch_command must return well before that
        command = self._create_script("start_child_process.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    time.sleep(60)
else:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    
'''.format(stdout, stderr))

        start_time = time.time()

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(time.time() - start_time, 5)

        # Also check that we capture the parent's output
        self.assertRegex(output, self._output_regex(stdout, stderr))

    def test_it_should_capture_the_output_of_child_process(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"
        more_parent_stdout = "MORE PARENT STDOUT"
        more_parent_stderr = "MORE PARENT STDERR"

        # the child process uses the signal file to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self._create_script("start_child_with_output.py", '''
import os
import sys
import time

sys.stdout.write("{0}")
sys.stderr.write("{1}")

pid = os.fork()

if pid == 0:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")
    
    open("{6}", "w").close()
else:
    sys.stdout.write("{4}")
    sys.stderr.write("{5}")
    
    while not os.path.exists("{6}"):
        time.sleep(0.5)
    
'''.format(parent_stdout, parent_stderr, child_stdout, child_stderr, more_parent_stdout, more_parent_stderr, signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

        self.assertIn(more_parent_stdout, output)
        self.assertIn(more_parent_stderr, output)

    def test_it_should_capture_the_output_of_child_process_that_fails_to_start(self):
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"

        command = self._create_script("start_child_that_fails.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    exit(1)
else:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")

'''.format(child_stdout, child_stderr, parent_stdout, parent_stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

    def test_it_should_execute_commands_with_no_output(self):
        # file used to verify the command completed successfully
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self._create_script("create_file.py", '''
open("{0}", "w").close()

'''.format(signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertTrue(os.path.exists(signal_file))
        self.assertRegex(output, self._output_regex('', ''))

    def test_it_should_not_capture_the_output_of_commands_that_do_their_own_redirection(self):
        # the test script redirects its output to this file
        command_output_file = os.path.join(self.tmp_dir, "command_output.txt")
        stdout = "STDOUT"
        stderr = "STDERR"

        # the test script mimics the redirection done by the Custom Script extension
        command = self._create_script("produce_output", '''
exec &> {0}
echo {1}
>&2 echo {2}

'''.format(command_output_file, stdout, stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertRegex(output, self._output_regex('', ''))

        with open(command_output_file, "r") as command_output:
            output = command_output.read()
            self.assertEqual(output, "{0}\n{1}\n".format(stdout, stderr))

    def test_it_should_truncate_the_command_output(self):
        stdout = "STDOUT"
        stderr = "STDERR"

        # each stream alone is about TELEMETRY_MESSAGE_MAX_LEN long, so the combined output must be truncated
        command = self._create_script("produce_long_output.py", '''
import sys

sys.stdout.write( "{0}" * {1})
sys.stderr.write( "{2}" * {3})
'''.format(stdout, int(TELEMETRY_MESSAGE_MAX_LEN / len(stdout)), stderr, int(TELEMETRY_MESSAGE_MAX_LEN / len(stderr))))

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
        self.assertIn(stdout, output)
        self.assertIn(stderr, output)

    def test_it_should_read_only_the_head_of_large_outputs(self):
        command = self._create_script("produce_long_output.py", '''
import sys

sys.stdout.write("O" * 5 * 1024 * 1024)
sys.stderr.write("E" * 5 * 1024 * 1024)
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to format_stdout_stderr, which takes the
        # return value of the calls to file.read(). The intention of the test is to verify we never read (and load in memory)
        # more than a few KB of data from the files used to capture stdout/stderr
        with patch('azurelinuxagent.ga.exthandlers.format_stdout_stderr', side_effect=format_stdout_stderr) as mock_format:
            output = self.ext_handler_instance.launch_command(command)

        self.assertGreaterEqual(len(output), 1024)
        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)

        # an explicit call_count check works with every mock version (assert_called_once() is missing in old ones)
        self.assertEqual(mock_format.call_count, 1)

        args, _ = mock_format.call_args
        stdout, stderr = args

        self.assertGreaterEqual(len(stdout), 1024)
        self.assertLessEqual(len(stdout), TELEMETRY_MESSAGE_MAX_LEN)

        self.assertGreaterEqual(len(stderr), 1024)
        self.assertLessEqual(len(stderr), TELEMETRY_MESSAGE_MAX_LEN)

    def test_it_should_handle_errors_while_reading_the_command_output(self):
        command = self._create_script("produce_output.py", '''
import sys

sys.stdout.write("STDOUT")
sys.stderr.write("STDERR")
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to _capture_process_output, which will
        # call file.read() and we force stdout/stderr to be None; this will produce an exception when trying to use these files.
        original_capture_process_output = ExtHandlerInstance._capture_process_output

        def capture_process_output(process, stdout_file, stderr_file, cmd, timeout, code):
            return original_capture_process_output(process, None, None, cmd, timeout, code)

        with patch('azurelinuxagent.ga.exthandlers.ExtHandlerInstance._capture_process_output', side_effect=capture_process_output):
            output = self.ext_handler_instance.launch_command(command)

        self.assertIn("[stderr]\nCannot read stdout/stderr:", output)
class LaunchCommandTestCase(AgentTestCase):
    """
    Test cases for launch_command
    """
    def setUp(self):
        """
        Prepares the ExtHandlerInstance under test for a dummy 'foo' extension
        (version 1.2.3), points its base and log directories at the temporary
        test directory, routes time.sleep through mock_sleep(0.01), and
        disables the cgroup configurator after recording whether it was
        enabled (tearDown uses that flag to restore it).
        """
        AgentTestCase.setUp(self)

        properties = ExtHandlerProperties()
        properties.version = "1.2.3"
        self.ext_handler = ExtHandler(name='foo')
        self.ext_handler.properties = properties
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=None)

        base_dir_target = "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir"
        self.mock_get_base_dir = patch(base_dir_target, lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        self.log_dir = os.path.join(self.tmp_dir, "log")
        log_dir_target = "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir"
        self.mock_get_log_dir = patch(log_dir_target, lambda *_: self.log_dir)
        self.mock_get_log_dir.start()

        # every time.sleep(...) call is replaced by mock_sleep(0.01), whatever duration was requested
        self.mock_sleep = patch("time.sleep", lambda *_: mock_sleep(0.01))
        self.mock_sleep.start()

        # remember the configurator's current state before turning it off for the test
        self.cgroups_enabled = CGroupConfigurator.get_instance().enabled()
        CGroupConfigurator.get_instance().disable()

    def tearDown(self):
        """
        Puts the cgroup configurator back into the state recorded by setUp,
        then stops the patches started there and runs the base tearDown.
        """
        if not self.cgroups_enabled:
            CGroupConfigurator.get_instance().disable()
        else:
            CGroupConfigurator.get_instance().enable()

        for patcher in (self.mock_get_log_dir, self.mock_get_base_dir, self.mock_sleep):
            patcher.stop()

        AgentTestCase.tearDown(self)

    @staticmethod
    def _output_regex(stdout, stderr):
        """
        Returns a regex matching the "[stdout] ... [stderr] ..." section of the
        output produced by launch_command. stdout and stderr are escaped, so
        they are matched literally even if they contain regex metacharacters.
        """
        import re
        return r"\[stdout\]\s+{0}\s+\[stderr\]\s+{1}".format(re.escape(stdout),
                                                              re.escape(stderr))

    @staticmethod
    def _find_process(command):
        """
        Returns True if any process listed under /proc has `command` as a
        substring of its command line (Linux only), False otherwise.
        """
        for pid in os.listdir('/proc'):
            if not pid.isdigit():
                continue
            try:
                with open(os.path.join('/proc', pid, 'cmdline'),
                          'r') as cmdline:
                    if command in cmdline.read():
                        return True
            except (IOError, OSError, UnicodeDecodeError):
                # the process has already terminated, is not accessible, or
                # its command line is not valid text in the current encoding
                continue
        return False

    def test_it_should_capture_the_output_of_the_command(self):
        stdout = "stdout" * 5
        stderr = "stderr" * 5

        command = self.create_script(
            "produce_output.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")

'''.format(stdout, stderr))

        def list_directory():
            base_dir = self.ext_handler_instance.get_base_dir()
            return [i for i in os.listdir(base_dir)
                    if not i.endswith(".tld")]  # ignore telemetry files

        files_before = list_directory()

        output = self.ext_handler_instance.launch_command(command)

        files_after = list_directory()

        self.assertRegex(output,
                         LaunchCommandTestCase._output_regex(stdout, stderr))

        self.assertListEqual(
            files_before, files_after,
            "Not all temporary files were deleted. File list: {0}".format(
                files_after))

    def test_it_should_raise_an_exception_when_the_command_times_out(self):
        extension_error_code = ExtensionErrorCodes.PluginHandlerScriptTimedout
        stdout = "stdout" * 7
        stderr = "stderr" * 7

        # the signal file is used by the test command to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the test command produces some output then goes into an infinite loop
        command = self.create_script(
            "produce_output_then_hang.py", '''
import sys
import time

sys.stdout.write("{0}")
sys.stdout.flush()

sys.stderr.write("{1}")
sys.stderr.flush()

with open("{2}", "w") as file:
    while True:
        file.write(".")
        time.sleep(1)

'''.format(stdout, stderr, signal_file))

        # mock time.sleep to wait for the signal file (launch_command implements the time out using polling and sleep)
        original_sleep = time.sleep

        def sleep(seconds):
            if not os.path.exists(signal_file):
                original_sleep(seconds)

        timeout = 60

        start_time = time.time()

        with patch("time.sleep", side_effect=sleep,
                   autospec=True) as mock_sleep:

            with self.assertRaises(ExtensionError) as context_manager:
                self.ext_handler_instance.launch_command(
                    command,
                    timeout=timeout,
                    extension_error_code=extension_error_code)

            # the command name and its output should be part of the message
            message = str(context_manager.exception)
            command_full_path = os.path.join(self.tmp_dir,
                                             command.lstrip(os.path.sep))
            self.assertRegex(
                message, r"Timeout\(\d+\):\s+{0}\s+{1}".format(
                    command_full_path,
                    LaunchCommandTestCase._output_regex(stdout, stderr)))

            # the exception code should be as specified in the call to launch_command
            self.assertEquals(context_manager.exception.code,
                              extension_error_code)

            # the timeout period should have elapsed
            self.assertGreaterEqual(mock_sleep.call_count, timeout)

            # the command should have been terminated
            self.assertFalse(LaunchCommandTestCase._find_process(command),
                             "The command was not terminated")

        # as a check for the test itself, verify it completed in just a few seconds
        self.assertLessEqual(time.time() - start_time, 5)

    def test_it_should_raise_an_exception_when_the_command_fails(self):
        extension_error_code = 2345
        stdout = "stdout" * 3
        stderr = "stderr" * 3
        exit_code = 99

        command = self.create_script(
            "fail.py", '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")
exit({2})

'''.format(stdout, stderr, exit_code))

        # the output is captured as part of the exception message
        with self.assertRaises(ExtensionError) as context_manager:
            self.ext_handler_instance.launch_command(
                command, extension_error_code=extension_error_code)

        message = str(context_manager.exception)
        self.assertRegex(
            message, r"Non-zero exit code: {0}.+{1}\s+{2}".format(
                exit_code, command,
                LaunchCommandTestCase._output_regex(stdout, stderr)))

        self.assertEquals(context_manager.exception.code, extension_error_code)

    def test_it_should_not_wait_for_child_process(self):
        """
        launch_command should return as soon as the command itself exits, without waiting for the child process
        the command forks (the child sleeps for 60 seconds). It should still capture the parent's output.

        NOTE(review): nothing in this test terminates the forked child, so it presumably keeps running for about
        60 seconds after the test finishes.
        """
        stdout = "stdout"
        stderr = "stderr"

        # the parent writes its output and exits right away; the child only sleeps
        command = self.create_script(
            "start_child_process.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    time.sleep(60)
else:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    
'''.format(stdout, stderr))

        start_time = time.time()

        output = self.ext_handler_instance.launch_command(command)

        # well under the child's 60-second sleep, so launch_command cannot have waited for the child
        self.assertLessEqual(time.time() - start_time, 5)

        # Also check that we capture the parent's output
        self.assertRegex(output,
                         LaunchCommandTestCase._output_regex(stdout, stderr))

    def test_it_should_capture_the_output_of_child_process(self):
        """
        Output written by a process forked by the command should be captured along with the parent's output.

        The forked process inherits the command's stdout/stderr. The test also checks the output the parent
        writes both before and after the fork.

        Because the script writes without flushing before it forks, the pre-fork output may appear twice in the
        captured output. The assertIn checks below tolerate that.

        NOTE(review): the child creates the signal file before it exits (and flushes its buffered output), so the
        parent may exit slightly before the child's output is written; confirm this cannot make the test flaky.
        """
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"
        more_parent_stdout = "MORE PARENT STDOUT"
        more_parent_stderr = "MORE PARENT STDERR"

        # the child process uses the signal file to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the parent waits (polling every 0.5 seconds) until the child has created the signal file
        command = self.create_script(
            "start_child_with_output.py", '''
import os
import sys
import time

sys.stdout.write("{0}")
sys.stderr.write("{1}")

pid = os.fork()

if pid == 0:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")
    
    open("{6}", "w").close()
else:
    sys.stdout.write("{4}")
    sys.stderr.write("{5}")
    
    while not os.path.exists("{6}"):
        time.sleep(0.5)
    
'''.format(parent_stdout, parent_stderr, child_stdout, child_stderr,
           more_parent_stdout, more_parent_stderr, signal_file))

        output = self.ext_handler_instance.launch_command(command)

        # output written before the fork
        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        # output written by the child
        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

        # output written by the parent after the fork
        self.assertIn(more_parent_stdout, output)
        self.assertIn(more_parent_stderr, output)

    def test_it_should_capture_the_output_of_child_process_that_fails_to_start(
            self):
        """
        Output written by a forked child process that exits with code 1 should be captured along with the
        parent's output.

        NOTE(review): the parent does not wait for the child, so the child's output may be written after the
        parent has exited; the test relies on it being captured anyway — confirm this is not timing-dependent.
        """
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"

        # the child writes its output and exits with 1; the parent writes its output and exits normally
        command = self.create_script(
            "start_child_that_fails.py", '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    exit(1)
else:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")

'''.format(child_stdout, child_stderr, parent_stdout, parent_stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

    def test_it_should_execute_commands_with_no_output(self):
        # file used to verify the command completed successfully
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = self.create_script(
            "create_file.py", '''
open("{0}", "w").close()

'''.format(signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertTrue(os.path.exists(signal_file))
        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

    def test_it_should_not_capture_the_output_of_commands_that_do_their_own_redirection(
            self):
        # the test script redirects its output to this file
        command_output_file = os.path.join(self.tmp_dir, "command_output.txt")
        stdout = "STDOUT"
        stderr = "STDERR"

        # the test script mimics the redirection done by the Custom Script extension
        command = self.create_script(
            "produce_output", '''
exec &> {0}
echo {1}
>&2 echo {2}

'''.format(command_output_file, stdout, stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

        with open(command_output_file, "r") as command_output:
            output = command_output.read()
            self.assertEquals(output, "{0}\n{1}\n".format(stdout, stderr))

    def test_it_should_truncate_the_command_output(self):
        stdout = "STDOUT"
        stderr = "STDERR"

        command = self.create_script(
            "produce_long_output.py", '''
import sys

sys.stdout.write( "{0}" * {1})
sys.stderr.write( "{2}" * {3})
'''.format(stdout, int(TELEMETRY_MESSAGE_MAX_LEN / len(stdout)), stderr,
           int(TELEMETRY_MESSAGE_MAX_LEN / len(stderr))))

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
        self.assertIn(stdout, output)
        self.assertIn(stderr, output)

    def test_it_should_read_only_the_head_of_large_outputs(self):
        command = self.create_script(
            "produce_long_output.py", '''
import sys

sys.stdout.write("O" * 5 * 1024 * 1024)
sys.stderr.write("E" * 5 * 1024 * 1024)
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to format_stdout_stderr, which takes the
        # return value of the calls to file.read(). The intention of the test is to verify we never read (and load in memory)
        # more than a few KB of data from the files used to capture stdout/stderr
        with patch(
                'azurelinuxagent.common.utils.extensionprocessutil.format_stdout_stderr',
                side_effect=format_stdout_stderr) as mock_format:
            output = self.ext_handler_instance.launch_command(command)

        self.assertGreaterEqual(len(output), 1024)
        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)

        mock_format.assert_called_once()

        args, kwargs = mock_format.call_args
        stdout, stderr = args

        self.assertGreaterEqual(len(stdout), 1024)
        self.assertLessEqual(len(stdout), TELEMETRY_MESSAGE_MAX_LEN)

        self.assertGreaterEqual(len(stderr), 1024)
        self.assertLessEqual(len(stderr), TELEMETRY_MESSAGE_MAX_LEN)

    def test_it_should_handle_errors_while_reading_the_command_output(self):
        command = self.create_script(
            "produce_output.py", '''
import sys

sys.stdout.write("STDOUT")
sys.stderr.write("STDERR")
''')
        # Mocking the call to file.read() is difficult, so instead we mock the call to_capture_process_output,
        # which will call file.read() and we force stdout/stderr to be None; this will produce an exception when
        # trying to use these files.
        original_capture_process_output = read_output

        def capture_process_output(stdout_file, stderr_file):
            return original_capture_process_output(None, None)

        with patch(
                'azurelinuxagent.common.utils.extensionprocessutil.read_output',
                side_effect=capture_process_output):
            output = self.ext_handler_instance.launch_command(command)

        self.assertIn("[stderr]\nCannot read stdout/stderr:", output)

    def test_it_should_contain_all_helper_environment_variables(self):
        """
        The environment passed to the command should contain exactly the expected helper variables beyond this
        process's own environment, and the command should see their expected values.
        """
        helper_env_vars = {
            ExtCommandEnvVariable.ExtensionSeqNumber: self.ext_handler_instance.get_seq_no(),
            ExtCommandEnvVariable.ExtensionPath: self.tmp_dir,
            ExtCommandEnvVariable.ExtensionVersion: self.ext_handler_instance.ext_handler.properties.version
        }

        # the script prints only the environment variables whose text matches one of the helper variable names
        command = """
            printenv | grep -E '(%s)'
        """ % '|'.join(helper_env_vars.keys())

        test_file = self.create_script('printHelperEnvironments.sh', command)

        with patch("subprocess.Popen", wraps=subprocess.Popen) as patch_popen:
            output = self.ext_handler_instance.launch_command(test_file)

            popen_kwargs = patch_popen.call_args[1]
            # variables not inherited from this process's environment
            added_env = dict((name, value) for (name, value) in popen_kwargs['env'].items()
                             if name not in os.environ)

            # This check will fail if any helper environment variables are added/removed later on
            self.assertEqual(helper_env_vars, added_env)

            # the command itself should see the expected values
            for name, value in helper_env_vars.items():
                self.assertIn("%s=%s" % (name, value), output)
Esempio n. 4
0
class LaunchCommandTestCase(AgentTestCase):
    """
    Test cases for launch_command
    """

    def setUp(self):
        """
        Creates an ExtHandlerInstance for extension 'foo' (version 1.2.3) that uses a WireProtocol for 1.2.3.4.

        Its base directory is redirected to the test's temporary directory and its log directory to a "log"
        subdirectory of it. time.sleep is replaced with a function that calls mock_sleep(0.01).
        """
        AgentTestCase.setUp(self)

        properties = ExtHandlerProperties()
        properties.version = "1.2.3"
        self.ext_handler = ExtHandler(name='foo')
        self.ext_handler.properties = properties
        protocol = WireProtocol("1.2.3.4")
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=protocol)

        self.log_dir = os.path.join(self.tmp_dir, "log")

        base_dir_target = "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir"
        log_dir_target = "azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir"
        self.mock_get_base_dir = patch(base_dir_target, lambda *_: self.tmp_dir)
        self.mock_get_log_dir = patch(log_dir_target, lambda *_: self.log_dir)
        self.mock_sleep = patch("time.sleep", lambda *_: mock_sleep(0.01))

        # start in the same order as they are stopped in reverse by tearDown
        for patcher in (self.mock_get_base_dir, self.mock_get_log_dir, self.mock_sleep):
            patcher.start()

    def tearDown(self):
        """
        Stops the patches started by setUp, then runs the base class clean-up.
        """
        for patcher in (self.mock_get_log_dir, self.mock_get_base_dir, self.mock_sleep):
            patcher.stop()

        AgentTestCase.tearDown(self)

    @staticmethod
    def _output_regex(stdout, stderr):
        """
        Returns a regular expression matching text of the form "[stdout] <stdout> [stderr] <stderr>"
        (with any whitespace between the parts).

        stdout and stderr are inserted into the expression as-is; they are not escaped.
        """
        return r"\[stdout\]\s+%s\s+\[stderr\]\s+%s" % (stdout, stderr)

    @staticmethod
    def _find_process(command):
        """
        Returns True if the command line of any running process contains the given command.

        Command lines are read from /proc/<pid>/cmdline as bytes and decoded as UTF-8 with replacement characters.
        Reading them in text mode would decode with the locale's encoding, which can raise UnicodeDecodeError
        for unrelated processes whose arguments are not valid text in that encoding.
        """
        for pid in (entry for entry in os.listdir('/proc') if entry.isdigit()):
            try:
                with open(os.path.join('/proc', pid, 'cmdline'), 'rb') as cmdline:
                    contents = cmdline.read().decode('utf-8', 'replace')
            except (IOError, OSError):  # proc has already terminated (or its cmdline cannot be read)
                continue
            if command in contents:
                return True
        return False

    def test_it_should_capture_the_output_of_the_command(self):
        stdout = "stdout" * 5
        stderr = "stderr" * 5

        command = "produce_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")

'''.format(stdout, stderr))

        def list_directory():
            base_dir = self.ext_handler_instance.get_base_dir()
            return [i for i in os.listdir(base_dir) if not i.endswith(AGENT_EVENT_FILE_EXTENSION)] # ignore telemetry files

        files_before = list_directory()

        output = self.ext_handler_instance.launch_command(command)

        files_after = list_directory()

        self.assertRegex(output, LaunchCommandTestCase._output_regex(stdout, stderr))

        self.assertListEqual(files_before, files_after, "Not all temporary files were deleted. File list: {0}".format(files_after))

    def test_it_should_raise_an_exception_when_the_command_times_out(self):
        """
        A command that does not finish within the timeout should be terminated.

        launch_command should raise an ExtensionError whose message contains the full path of the command and its
        output, and whose code is the one passed to launch_command. The test itself must finish in a few seconds.
        """
        extension_error_code = ExtensionErrorCodes.PluginHandlerScriptTimedout
        stdout = "stdout" * 7
        stderr = "stderr" * 7

        # the signal file is used by the test command to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the test command produces some output then goes into an infinite loop
        # (the signal file exists as soon as the script opens it, after the output has been flushed)
        command = "produce_output_then_hang.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys
import time

sys.stdout.write("{0}")
sys.stdout.flush()

sys.stderr.write("{1}")
sys.stderr.flush()

with open("{2}", "w") as file:
    while True:
        file.write(".")
        time.sleep(1)

'''.format(stdout, stderr, signal_file))

        # mock time.sleep to wait for the signal file (launch_command implements the time out using polling and sleep)
        # once the signal file exists the mock returns immediately, so the timeout elapses without real waiting
        def sleep(seconds):
            if not os.path.exists(signal_file):
                sleep.original_sleep(seconds)
        # NOTE(review): setUp has already replaced time.sleep (with a lambda that calls mock_sleep(0.01)), so this
        # captures that replacement rather than the real time.sleep; the 0.25-second retry delay below is therefore
        # presumably not honored as written — confirm the retries wait long enough.
        sleep.original_sleep = time.sleep

        timeout = 60

        start_time = time.time()

        with patch("time.sleep", side_effect=sleep, autospec=True) as mock_sleep:  # pylint: disable=redefined-outer-name

            with self.assertRaises(ExtensionError) as context_manager:
                self.ext_handler_instance.launch_command(command, timeout=timeout, extension_error_code=extension_error_code)

            # the command name and its output should be part of the message
            message = str(context_manager.exception)
            command_full_path = os.path.join(self.tmp_dir, command.lstrip(os.path.sep))
            self.assertRegex(message, r"Timeout\(\d+\):\s+{0}\s+{1}".format(command_full_path, LaunchCommandTestCase._output_regex(stdout, stderr)))

            # the exception code should be as specified in the call to launch_command
            self.assertEqual(context_manager.exception.code, extension_error_code)

            # the timeout period should have elapsed
            # (the polling in launch_command called the mocked sleep at least `timeout` times)
            self.assertGreaterEqual(mock_sleep.call_count, timeout)

            # The command should have been terminated.
            # The /proc file system may still include the process when we do this check so we try a few times after a short delay; note that we
            # are mocking sleep, so we need to use the original implementation.
            terminated = False
            i = 0
            while not terminated and i < 4:
                if not LaunchCommandTestCase._find_process(command):
                    terminated = True
                else:
                    sleep.original_sleep(0.25)
                i += 1

            self.assertTrue(terminated, "The command was not terminated")

        # as a check for the test itself, verify it completed in just a few seconds
        self.assertLessEqual(time.time() - start_time, 5)

    def test_it_should_raise_an_exception_when_the_command_fails(self):
        extension_error_code = 2345
        stdout = "stdout" * 3
        stderr = "stderr" * 3
        exit_code = 99

        command = "fail.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("{0}")
sys.stderr.write("{1}")
exit({2})

'''.format(stdout, stderr, exit_code))

        # the output is captured as part of the exception message
        with self.assertRaises(ExtensionError) as context_manager:
            self.ext_handler_instance.launch_command(command, extension_error_code=extension_error_code)

        message = str(context_manager.exception)
        self.assertRegex(message, r"Non-zero exit code: {0}.+{1}\s+{2}".format(exit_code, command, LaunchCommandTestCase._output_regex(stdout, stderr)))

        self.assertEqual(context_manager.exception.code, extension_error_code)

    def test_it_should_not_wait_for_child_process(self):
        """
        launch_command should return as soon as the command itself exits, without waiting for the child process
        the command forks (the child sleeps for 60 seconds). It should still capture the parent's output.

        NOTE(review): nothing in this test terminates the forked child, so it presumably keeps running for about
        60 seconds after the test finishes.
        """
        stdout = "stdout"
        stderr = "stderr"

        # the parent writes its output and exits right away; the child only sleeps
        # (the child runs in a separate Python process, so the time.sleep patch from setUp does not affect it)
        command = "start_child_process.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    time.sleep(60)
else:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    
'''.format(stdout, stderr))

        start_time = time.time()

        output = self.ext_handler_instance.launch_command(command)

        # well under the child's 60-second sleep, so launch_command cannot have waited for the child
        self.assertLessEqual(time.time() - start_time, 5)

        # Also check that we capture the parent's output
        self.assertRegex(output, LaunchCommandTestCase._output_regex(stdout, stderr))

    def test_it_should_capture_the_output_of_child_process(self):
        """
        Output written by a process forked by the command should be captured along with the parent's output.

        The forked process inherits the command's stdout/stderr. The test also checks the output the parent
        writes both before and after the fork.

        Because the script writes without flushing before it forks, the pre-fork output may appear twice in the
        captured output. The assertIn checks below tolerate that.

        NOTE(review): the child creates the signal file before it exits (and flushes its buffered output), so the
        parent may exit slightly before the child's output is written; confirm this cannot make the test flaky.
        """
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"
        more_parent_stdout = "MORE PARENT STDOUT"
        more_parent_stderr = "MORE PARENT STDERR"

        # the child process uses the signal file to indicate it has produced output
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        # the parent waits (polling every 0.5 seconds) until the child has created the signal file
        command = "start_child_with_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import sys
import time

sys.stdout.write("{0}")
sys.stderr.write("{1}")

pid = os.fork()

if pid == 0:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")
    
    open("{6}", "w").close()
else:
    sys.stdout.write("{4}")
    sys.stderr.write("{5}")
    
    while not os.path.exists("{6}"):
        time.sleep(0.5)
    
'''.format(parent_stdout, parent_stderr, child_stdout, child_stderr, more_parent_stdout, more_parent_stderr, signal_file))

        output = self.ext_handler_instance.launch_command(command)

        # output written before the fork
        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        # output written by the child
        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

        # output written by the parent after the fork
        self.assertIn(more_parent_stdout, output)
        self.assertIn(more_parent_stderr, output)

    def test_it_should_capture_the_output_of_child_process_that_fails_to_start(self):
        """
        Output written by a forked child process that exits with code 1 should be captured along with the
        parent's output.

        NOTE(review): the parent does not wait for the child, so the child's output may be written after the
        parent has exited; the test relies on it being captured anyway — confirm this is not timing-dependent.
        """
        parent_stdout = "PARENT STDOUT"
        parent_stderr = "PARENT STDERR"
        child_stdout = "CHILD STDOUT"
        child_stderr = "CHILD STDERR"

        # the child writes its output and exits with 1; the parent writes its output and exits normally
        command = "start_child_that_fails.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import sys
import time

pid = os.fork()

if pid == 0:
    sys.stdout.write("{0}")
    sys.stderr.write("{1}")
    exit(1)
else:
    sys.stdout.write("{2}")
    sys.stderr.write("{3}")

'''.format(child_stdout, child_stderr, parent_stdout, parent_stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertIn(parent_stdout, output)
        self.assertIn(parent_stderr, output)

        self.assertIn(child_stdout, output)
        self.assertIn(child_stderr, output)

    def test_it_should_execute_commands_with_no_output(self):
        # file used to verify the command completed successfully
        signal_file = os.path.join(self.tmp_dir, "signal_file.txt")

        command = "create_file.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
open("{0}", "w").close()

'''.format(signal_file))

        output = self.ext_handler_instance.launch_command(command)

        self.assertTrue(os.path.exists(signal_file))
        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

    def test_it_should_not_capture_the_output_of_commands_that_do_their_own_redirection(self):
        # the test script redirects its output to this file
        command_output_file = os.path.join(self.tmp_dir, "command_output.txt")
        stdout = "STDOUT"
        stderr = "STDERR"

        # the test script mimics the redirection done by the Custom Script extension
        command = "produce_output"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
exec &> {0}
echo {1}
>&2 echo {2}

'''.format(command_output_file, stdout, stderr))

        output = self.ext_handler_instance.launch_command(command)

        self.assertRegex(output, LaunchCommandTestCase._output_regex('', ''))

        with open(command_output_file, "r") as command_output:
            output = command_output.read()
            self.assertEqual(output, "{0}\n{1}\n".format(stdout, stderr))

    def test_it_should_truncate_the_command_output(self):
        stdout = "STDOUT"
        stderr = "STDERR"

        command = "produce_long_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write( "{0}" * {1})
sys.stderr.write( "{2}" * {3})
'''.format(stdout, int(TELEMETRY_MESSAGE_MAX_LEN / len(stdout)), stderr, int(TELEMETRY_MESSAGE_MAX_LEN / len(stderr))))

        output = self.ext_handler_instance.launch_command(command)

        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)
        self.assertIn(stdout, output)
        self.assertIn(stderr, output)

    def test_it_should_read_only_the_head_of_large_outputs(self):
        command = "produce_long_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("O" * 5 * 1024 * 1024)
sys.stderr.write("E" * 5 * 1024 * 1024)
''')

        # Mocking the call to file.read() is difficult, so instead we mock the call to format_stdout_stderr, which takes the
        # return value of the calls to file.read(). The intention of the test is to verify we never read (and load in memory)
        # more than a few KB of data from the files used to capture stdout/stderr
        with patch('azurelinuxagent.common.utils.extensionprocessutil.format_stdout_stderr', side_effect=format_stdout_stderr) as mock_format:
            output = self.ext_handler_instance.launch_command(command)

        self.assertGreaterEqual(len(output), 1024)
        self.assertLessEqual(len(output), TELEMETRY_MESSAGE_MAX_LEN)

        mock_format.assert_called_once()

        args, kwargs = mock_format.call_args  # pylint: disable=unused-variable
        stdout, stderr = args

        self.assertGreaterEqual(len(stdout), 1024)
        self.assertLessEqual(len(stdout), TELEMETRY_MESSAGE_MAX_LEN)

        self.assertGreaterEqual(len(stderr), 1024)
        self.assertLessEqual(len(stderr), TELEMETRY_MESSAGE_MAX_LEN)

    def test_it_should_handle_errors_while_reading_the_command_output(self):
        command = "produce_output.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import sys

sys.stdout.write("STDOUT")
sys.stderr.write("STDERR")
''')
        # Mocking the call to file.read() is difficult, so instead we mock the call to_capture_process_output,
        # which will call file.read() and we force stdout/stderr to be None; this will produce an exception when
        # trying to use these files.
        original_capture_process_output = read_output

        def capture_process_output(stdout_file, stderr_file):  # pylint: disable=unused-argument
            return original_capture_process_output(None, None)

        with patch('azurelinuxagent.common.utils.extensionprocessutil.read_output', side_effect=capture_process_output):
            output = self.ext_handler_instance.launch_command(command)

        self.assertIn("[stderr]\nCannot read stdout/stderr:", output)

    def test_it_should_contain_all_helper_environment_variables(self):
        """
        The environment passed to the command should contain exactly the expected helper variables beyond this
        process's own environment, including the wire protocol address, and the command should see their values.
        """
        wire_ip = str(uuid.uuid4())
        ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=WireProtocol(wire_ip))

        helper_env_vars = {
            ExtCommandEnvVariable.ExtensionSeqNumber: _DEFAULT_SEQ_NO,
            ExtCommandEnvVariable.ExtensionPath: self.tmp_dir,
            ExtCommandEnvVariable.ExtensionVersion: ext_handler_instance.ext_handler.properties.version,
            ExtCommandEnvVariable.WireProtocolAddress: wire_ip,
        }

        # the script prints only the environment variables whose text matches one of the helper variable names
        command = """
            printenv | grep -E '(%s)'
        """ % '|'.join(helper_env_vars.keys())

        test_file = 'printHelperEnvironments.sh'
        # get_base_dir is patched at class level in setUp, so every instance resolves to self.tmp_dir
        self.create_script(os.path.join(ext_handler_instance.get_base_dir(), test_file), command)

        with patch("subprocess.Popen", wraps=subprocess.Popen) as patch_popen:
            # Returning empty list for get_agent_supported_features_list_for_extensions as we have a separate test for it
            with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
                       return_value={}):
                output = ext_handler_instance.launch_command(test_file)

            popen_kwargs = patch_popen.call_args[1]
            # variables not inherited from this process's environment
            added_env = dict((name, value) for (name, value) in popen_kwargs['env'].items()
                             if name not in os.environ)

            # This check will fail if any helper environment variables are added/removed later on
            self.assertEqual(helper_env_vars, added_env)

            # the command itself should see the expected values
            for name, value in helper_env_vars.items():
                self.assertIn("%s=%s" % (name, value), output)

    def test_it_should_pass_supported_features_list_as_environment_variables(self):
        """
        The agent's supported extension features should be passed to the command through the
        ExtCommandEnvVariable.ExtensionSupportedFeatures environment variable, as JSON.

        The test patches the module-level __EXTENSION_ADVERTISED_FEATURES dict of agent_supported_feature and
        checks three cases:
          - a supported feature is present in the variable;
          - an unsupported feature is omitted while another supported feature is present;
          - the variable is not set at all when no feature is supported.
        """

        # minimal concrete feature whose support flag is set by the test
        class TestFeature(AgentSupportedFeature):

            def __init__(self, name, version, supported):
                super(TestFeature, self).__init__(name=name,
                                                  version=version,
                                                  supported=supported)

        test_name = str(uuid.uuid4())
        test_version = str(uuid.uuid4())

        # the script prints "<variable> not found in environment" when the variable is unset or empty, otherwise
        # "Found Feature <name>: True/False" depending on whether an entry with the test's Key/Value is in the JSON list
        command = "check_env.py"
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), command), '''
import os
import json
import sys

features = os.getenv("{0}")
if not features:
    print("{0} not found in environment")
    sys.exit(0)
l = json.loads(features)
found = False
for feature in l:
    if feature['Key'] == "{1}" and feature['Value'] == "{2}":
        found = True
        break
    
print("Found Feature %s: %s" % ("{1}", found))
'''.format(ExtCommandEnvVariable.ExtensionSupportedFeatures, test_name, test_version))

        # It should include all supported features and pass it as Environment Variable to extensions
        test_supported_features = {test_name: TestFeature(name=test_name, version=test_version, supported=True)}
        with patch("azurelinuxagent.common.agent_supported_feature.__EXTENSION_ADVERTISED_FEATURES",
                   test_supported_features):
            output = self.ext_handler_instance.launch_command(command)

            self.assertIn("[stdout]\nFound Feature {0}: True".format(test_name), output, "Feature not found")

        # It should not include the feature if feature not supported
        # ("testFeature" is supported so that the variable is still set and the script reaches the lookup)
        test_supported_features = {
            test_name: TestFeature(name=test_name, version=test_version, supported=False),
            "testFeature": TestFeature(name="testFeature", version="1.2.1", supported=True)
        }
        with patch("azurelinuxagent.common.agent_supported_feature.__EXTENSION_ADVERTISED_FEATURES",
                   test_supported_features):
            output = self.ext_handler_instance.launch_command(command)

            self.assertIn("[stdout]\nFound Feature {0}: False".format(test_name), output, "Feature wrongfully found")

        # It should not include the SupportedFeatures Key in Environment variables if no features supported
        test_supported_features = {test_name: TestFeature(name=test_name, version=test_version, supported=False)}
        with patch("azurelinuxagent.common.agent_supported_feature.__EXTENSION_ADVERTISED_FEATURES",
                   test_supported_features):
            output = self.ext_handler_instance.launch_command(command)

            self.assertIn(
                "[stdout]\n{0} not found in environment".format(ExtCommandEnvVariable.ExtensionSupportedFeatures),
                output, "Environment variable should not be found")