예제 #1
0
    def test_validate_page_blobs(self):
        """Validate correct set of data is sent for page blobs.

        Uploading a page blob status through the HostGAPlugin is expected to issue
        three HTTP requests: a PUT that creates the page blob, a POST to the health
        service, and a PUT ("?comp=page") that writes the 512-byte aligned page.
        """
        with mock_wire_protocol(DATA_FILE) as protocol:
            test_goal_state = protocol.client._goal_state  # pylint: disable=protected-access

            host_client = wire.HostPluginProtocol(
                wireserver_url, test_goal_state.container_id,
                test_goal_state.role_config_name)

            self.assertFalse(host_client.is_initialized)
            self.assertIsNone(host_client.api_versions)

            status_blob = protocol.client.status_blob
            status_blob.data = faux_status
            status_blob.type = page_blob_type
            status_blob.vm_status = restapi.VMStatus(message="Ready",
                                                     status="Ready")

            exp_method = 'PUT'
            exp_url = hostplugin_status_url

            # Page blobs are written in whole 512-byte pages: round the encoded
            # status up to the next multiple of 512 using integer ceiling division
            # (no float round-trip) and pad the data with spaces to that size.
            page_status = bytearray(status_blob.data, encoding='utf-8')
            page_size = ((len(page_status) + 511) // 512) * 512
            page_status = bytearray(status_blob.data.ljust(page_size),
                                    encoding='utf-8')
            page = bytearray(page_size)
            page[0:page_size] = page_status[0:len(page_status)]
            mock_response = MockResponse('', httpclient.OK)

            with patch.object(restutil,
                              "http_request",
                              return_value=mock_response) as patch_http:
                with patch.object(wire.HostPluginProtocol,
                                  "get_api_versions") as patch_get:
                    patch_get.return_value = api_versions
                    host_client.put_vm_status(status_blob, sas_url)

                    # assertEqual reports the actual count on failure
                    self.assertEqual(3, patch_http.call_count)

                    # first call is to host plugin
                    exp_data = self._hostplugin_data(
                        status_blob.get_page_blob_create_headers(page_size))
                    self._validate_hostplugin_args(
                        patch_http.call_args_list[0], test_goal_state,
                        exp_method, exp_url, exp_data)

                    # second call is to health service
                    self.assertEqual('POST',
                                     patch_http.call_args_list[1][0][0])
                    self.assertEqual(health_service_url,
                                     patch_http.call_args_list[1][0][1])

                    # last call is to host plugin
                    exp_data = self._hostplugin_data(
                        status_blob.get_page_blob_page_headers(0, page_size),
                        page)
                    exp_data['requestUri'] += "?comp=page"
                    self._validate_hostplugin_args(
                        patch_http.call_args_list[2], test_goal_state,
                        exp_method, exp_url, exp_data)
예제 #2
0
    def test_validate_http_request_when_uploading_status(self):
        """Validate correct set of data is sent to HostGAPlugin when reporting VM status.

        A block blob upload should result in exactly two HTTP requests: the PUT to
        the host plugin, followed by a POST to the health service.
        """

        with mock_wire_protocol(DATA_FILE) as protocol:
            test_goal_state = protocol.client._goal_state  # pylint: disable=protected-access
            plugin = protocol.client.get_host_plugin()

            status_blob = protocol.client.status_blob
            status_blob.data = faux_status
            status_blob.vm_status = restapi.VMStatus(message="Ready", status="Ready")

            exp_method = 'PUT'
            exp_url = hostplugin_status_url
            exp_data = self._hostplugin_data(
                status_blob.get_block_blob_headers(len(faux_status)),
                bytearray(faux_status, encoding='utf-8'))

            with patch.object(restutil, "http_request") as patch_http:
                patch_http.return_value = Mock(status=httpclient.OK)

                with patch.object(plugin, 'get_api_versions') as patch_api:
                    patch_api.return_value = API_VERSION
                    plugin.put_vm_status(status_blob, sas_url, block_blob_type)

                    # assertEqual reports the actual count on failure
                    self.assertEqual(2, patch_http.call_count)

                    # first call is to host plugin
                    self._validate_hostplugin_args(
                        patch_http.call_args_list[0],
                        test_goal_state,
                        exp_method, exp_url, exp_data)

                    # second call is to health service
                    self.assertEqual('POST', patch_http.call_args_list[1][0][0])
                    self.assertEqual(health_service_url, patch_http.call_args_list[1][0][1])
예제 #3
0
    def test_put_vm_log_should_raise_an_exception_when_request_fails(self):
        """put_vm_log() should raise an HttpError mentioning 410/Gone when the HostGAPlugin PUT fails."""
        def http_put_handler(url, *args, **kwargs):
            # Only the HostGAPlugin "put logs" request is intercepted; everything else falls through
            if not self.is_host_plugin_put_logs_request(url):
                return None
            # Record what was sent so it can be inspected if needed
            http_put_handler.args = args
            http_put_handler.kwargs = kwargs
            return MockResponse(body=ustr('Gone'), status_code=410)

        http_put_handler.args = []
        http_put_handler.kwargs = {}

        with mock_wire_protocol(DATA_FILE, http_put_handler=http_put_handler) as protocol:
            goal_state = protocol.client.get_goal_state()

            host_client = wire.HostPluginProtocol(wireserver_url,
                                                  goal_state.container_id,
                                                  goal_state.role_config_name)

            self.assertFalse(host_client.is_initialized, "Host plugin should not be initialized!")

            log_content = b"test"
            with self.assertRaises(HttpError) as context_manager:
                host_client.put_vm_log(log_content)

            raised_error = context_manager.exception
            self.assertIsInstance(raised_error, HttpError)
            error_text = ustr(raised_error)
            self.assertIn("410", error_text)
            self.assertIn("Gone", error_text)
예제 #4
0
    def test_validate_get_extension_artifacts(self):
        """get_artifact_request() should initialize the host plugin and return the artifact URL and headers."""
        with mock_wire_protocol(DATA_FILE) as protocol:
            test_goal_state = protocol.client._goal_state  # pylint: disable=protected-access

            expected_url = hostplugin.URI_FORMAT_GET_EXTENSION_ARTIFACT.format(
                wireserver_url, hostplugin.HOST_PLUGIN_PORT)
            expected_headers = {
                'x-ms-version': '2015-09-01',
                "x-ms-containerid": test_goal_state.container_id,
                "x-ms-host-config-name": test_goal_state.role_config_name,
                "x-ms-artifact-location": sas_url
            }

            host_client = wire.HostPluginProtocol(
                wireserver_url, test_goal_state.container_id,
                test_goal_state.role_config_name)
            self.assertFalse(host_client.is_initialized)
            self.assertIsNone(host_client.api_versions)
            self.assertIsNotNone(host_client.health_service)

            # Only the stubbed return value matters; the patch object itself is not needed
            with patch.object(wire.HostPluginProtocol,
                              "get_api_versions",
                              return_value=api_versions):
                actual_url, actual_headers = host_client.get_artifact_request(
                    sas_url)
                self.assertTrue(host_client.is_initialized)
                self.assertIsNotNone(host_client.api_versions)
                self.assertEqual(expected_url, actual_url)
                for header_name, header_value in expected_headers.items():
                    self.assertIn(header_name, actual_headers)
                    self.assertEqual(header_value, actual_headers[header_name])
예제 #5
0
 def test_it_should_report_error_if_plugin_settings_version_mismatch(self):
     """
     A goal state with mismatching PluginSettings versions should report exactly one failed
     PluginSettingsVersionMismatch event that names the handler, the expected version (1.0.0)
     and the mismatching versions (1.0.1, 1.0.2).
     """
     with mock_wire_protocol(
             mockwiredata.DATA_FILE_PLUGIN_SETTINGS_MISMATCH) as protocol:
         with patch("azurelinuxagent.common.protocol.goal_state.add_event"
                    ) as mock_add_event:
             # Forcing update of GoalState to allow the ExtConfig to report an event
             protocol.mock_wire_data.set_incarnation(2)
             protocol.client.update_goal_state()
             # Use .get() so add_event calls made without an 'op' keyword are skipped instead of raising KeyError
             plugin_setting_mismatch_calls = [
                 kw for _, kw in mock_add_event.call_args_list
                 if kw.get('op') == WALAEventOperation.PluginSettingsVersionMismatch
             ]
             self.assertEqual(
                 1, len(plugin_setting_mismatch_calls),
                 "PluginSettingsMismatch event should be reported once")

             mismatch_event = plugin_setting_mismatch_calls[0]
             message = mismatch_event['message']
             self.assertIn(
                 'ExtHandler PluginSettings Version Mismatch! Expected PluginSettings version: 1.0.0 for Handler: OSTCExtensions.ExampleHandlerLinux',
                 message,
                 "Invalid error message with incomplete data detected for PluginSettingsVersionMismatch"
             )
             # Check each version separately so a failure shows which one is missing
             for version in ("1.0.2", "1.0.1"):
                 self.assertIn(version, message,
                               "Error message should contain the incorrect versions")
             self.assertFalse(
                 mismatch_event['is_success'],
                 "The event should be false")
예제 #6
0
def _create_collect_logs_handler(iterations=1, systemd_present=True):
    """
    Yields a CollectLogsHandler configured for testing. The handler
        * talks to a mock_wire_protocol instead of the network,
        * executes its main loop only 'iterations' times (its 'stopped' check is patched), and
        * never sleeps between iterations (time.sleep is patched).

    systemd detection is patched to return 'systemd_present' and conf.get_collect_logs is patched to
    return True. All patches remain in effect until this generator is resumed or closed.

    Two helpers are attached to the yielded handler:
        * get_mock_wire_protocol() - returns the mock protocol
        * run_and_wait() - calls run() and then join() on the handler
    """
    with mock_wire_protocol(DATA_FILE) as protocol:
        util_mock = MagicMock()
        util_mock.get_protocol = Mock(return_value=protocol)
        with patch("azurelinuxagent.ga.collect_logs.get_protocol_util", return_value=util_mock):
            stop_sequence = [False] * iterations + [True]
            with patch("azurelinuxagent.ga.collect_logs.CollectLogsHandler.stopped", side_effect=stop_sequence):
                with patch("time.sleep"):
                    with patch("azurelinuxagent.common.osutil.systemd.is_systemd", return_value=systemd_present):
                        with patch("azurelinuxagent.ga.collect_logs.conf.get_collect_logs", return_value=True):
                            handler = get_collect_logs_handler()

                            def run_and_wait():
                                handler.run()
                                handler.join()

                            handler.get_mock_wire_protocol = lambda: protocol
                            handler.run_and_wait = run_and_wait
                            yield handler
예제 #7
0
    def _create_send_telemetry_events_handler(self,
                                              timeout=0.5,
                                              start_thread=True,
                                              batching_queue_limit=1):
        """
        Yields a SendTelemetryEventsHandler backed by a mock wire protocol.

        Telemetry POST requests are intercepted and appended to the handler's 'event_calls' list as
        (timestamp, body) tuples. The handler's _MIN_EVENTS_TO_BATCH and _MAX_TIMEOUT class attributes
        are patched to 'batching_queue_limit' and 'timeout'. When 'start_thread' is True the handler's
        thread is started and asserted to be alive before it is yielded. A get_mock_wire_protocol()
        helper returning the mock protocol is attached to the handler.
        """
        def http_post_handler(url, body, **__):
            # Non-telemetry requests are left to the default mock behavior
            if not self.is_telemetry_request(url):
                return None
            events_handler.event_calls.append((datetime.now(), body))
            return MockHttpResponse(status=200)

        with mock_wire_protocol(DATA_FILE, http_post_handler=http_post_handler) as protocol:
            util_mock = MagicMock()
            util_mock.get_protocol = Mock(return_value=protocol)

            events_handler = get_send_telemetry_events_handler(util_mock)
            events_handler.event_calls = []

            with patch("azurelinuxagent.ga.send_telemetry_events.SendTelemetryEventsHandler._MIN_EVENTS_TO_BATCH",
                       batching_queue_limit):
                with patch("azurelinuxagent.ga.send_telemetry_events.SendTelemetryEventsHandler._MAX_TIMEOUT",
                           timeout):
                    events_handler.get_mock_wire_protocol = lambda: protocol

                    if start_thread:
                        events_handler.start()
                        self.assertTrue(events_handler.is_alive(), "Thread didn't start properly!")

                    yield events_handler
예제 #8
0
 def _init_host(self):
     """Return a HostPluginProtocol built from the mock goal state, asserting it has a health service."""
     with mock_wire_protocol(DATA_FILE) as protocol:
         goal_state = protocol.client.get_goal_state()
         container_id = goal_state.container_id
         role_config_name = goal_state.role_config_name
         host_plugin = wire.HostPluginProtocol(wireserver_url, container_id, role_config_name)
         self.assertTrue(host_plugin.health_service is not None)
         return host_plugin
예제 #9
0
    def test_remote_access_handler_should_retrieve_users_when_it_is_invoked_the_first_time(self):
        """The first run() of RemoteAccessHandler should retrieve the current users exactly once."""
        mock_os_util = MagicMock()
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=mock_os_util):
            with mock_wire_protocol(DATA_FILE) as mock_protocol:
                rah = RemoteAccessHandler(mock_protocol)
                rah.run()

                # assertEqual shows the actual number of calls when the check fails
                self.assertEqual(1, mock_os_util.get_users.call_count, "The first invocation of remote access should have retrieved the current users")
예제 #10
0
    def test_remote_access_handler_should_retrieve_users_when_goal_state_contains_jit_users(self):
        """A goal state with JIT (remote access) users should make RemoteAccessHandler retrieve the current users."""
        mock_os_util = MagicMock()
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=mock_os_util):
            with mock_wire_protocol(DATA_FILE_REMOTE_ACCESS) as mock_protocol:
                rah = RemoteAccessHandler(mock_protocol)
                rah.run()

                # assertGreater shows the actual number of calls when the check fails
                self.assertGreater(mock_os_util.get_users.call_count, 0, "A goal state with jit users did not retrieve the current users")
예제 #11
0
    def test_validate_http_request_for_put_vm_log(self):
        """put_vm_log() should PUT the log content to the HostGAPlugin with the expected URL and headers."""
        def http_put_handler(url, *args, **kwargs):  # pylint: disable=inconsistent-return-statements
            if self.is_host_plugin_put_logs_request(url):
                # Capture the request so its body and headers can be checked below
                http_put_handler.args, http_put_handler.kwargs = args, kwargs
                return MockResponse(body=b'', status_code=200)

        http_put_handler.args, http_put_handler.kwargs = [], {}

        with mock_wire_protocol(DATA_FILE,
                                http_put_handler=http_put_handler) as protocol:
            test_goal_state = protocol.client.get_goal_state()

            expected_url = hostplugin.URI_FORMAT_PUT_LOG.format(
                wireserver_url, hostplugin.HOST_PLUGIN_PORT)
            expected_headers = {
                'x-ms-version':
                '2015-09-01',
                "x-ms-containerid":
                test_goal_state.container_id,
                "x-ms-vmagentlog-deploymentid":
                test_goal_state.role_config_name.split(".")[0],
                "x-ms-client-name":
                AGENT_NAME,
                "x-ms-client-version":
                AGENT_VERSION
            }

            host_client = wire.HostPluginProtocol(
                wireserver_url, test_goal_state.container_id,
                test_goal_state.role_config_name)

            self.assertFalse(host_client.is_initialized,
                             "Host plugin should not be initialized!")

            content = b"test"
            host_client.put_vm_log(content)
            self.assertTrue(host_client.is_initialized,
                            "Host plugin is not initialized!")

            urls = protocol.get_tracked_urls()

            self.assertEqual(expected_url, urls[0], "Unexpected request URL!")
            self.assertEqual(content, http_put_handler.args[0],
                             "Unexpected content for HTTP PUT request!")

            headers = http_put_handler.kwargs['headers']
            for header_name, header_value in expected_headers.items():
                self.assertIn(header_name, headers,
                              "Header {0} not found in headers!".format(header_name))
                self.assertEqual(header_value, headers[header_name],
                                 "Request headers don't match!")

            # Special check for correlation id header value, check for pattern, not exact value
            self.assertIn("x-ms-client-correlationid", headers,
                          "Correlation id not found in headers!")
            self.assertIsNotNone(
                UUID_PATTERN.match(headers["x-ms-client-correlationid"]),
                "Correlation id is not in GUID form!")
예제 #12
0
    def test_fetch_goal_state_should_raise_on_incomplete_goal_state(self):
        """
        Building a GoalState from the no-op goal state data should raise IncompleteGoalStateError after
        sleeping exactly _NUM_GS_FETCH_RETRIES times.
        """
        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
            wire_data = protocol.mock_wire_data
            wire_data.data_files = mockwiredata.DATA_FILE_NOOP_GS
            wire_data.reload()
            wire_data.set_incarnation(2)

            with patch('time.sleep') as mock_sleep:
                with self.assertRaises(IncompleteGoalStateError):
                    GoalState(protocol.client)
                sleep_count = mock_sleep.call_count
                self.assertEqual(_NUM_GS_FETCH_RETRIES, sleep_count, "Unexpected number of retries")
예제 #13
0
def _mock_wire_protocol():
    """
    Yields a mock wire protocol that azurelinuxagent.ga.monitor obtains through get_protocol_util().

    ProtocolUtil is a per-thread singleton, so its instances are cleared first to make sure test cases
    do not reuse state left behind by a previous test.
    """
    clear_singleton_instances(ProtocolUtil)

    with mock_wire_protocol(DATA_FILE) as protocol:
        util_mock = MagicMock()
        util_mock.get_protocol = Mock(return_value=protocol)
        with patch("azurelinuxagent.ga.monitor.get_protocol_util", return_value=util_mock):
            yield protocol
예제 #14
0
    def test_remote_access_handler_should_not_retrieve_users_when_goal_state_does_not_contain_jit_users(self):
        """A new goal state without JIT users should not trigger another retrieval of the current users."""
        mock_os_util = MagicMock()
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=mock_os_util):
            with mock_wire_protocol(DATA_FILE) as mock_protocol:
                rah = RemoteAccessHandler(mock_protocol)
                rah.run()  # this will trigger one call to retrieve the users

                mock_protocol.mock_wire_data.set_incarnation(123)  # mock a new goal state; the data file does not include any jit users
                rah.run()
                # assertEqual shows the actual number of calls when the check fails
                self.assertEqual(1, mock_os_util.get_users.call_count, "A goal state without jit users retrieved the current users")
예제 #15
0
 def test_update_remote_access_conf_remote_access(self):
     """The remote access data in DATA_FILE_REMOTE_ACCESS should contain a single user, 'testAccount'."""
     with mock_wire_protocol(
             mockwiredata.DATA_FILE_REMOTE_ACCESS) as protocol:
         # Fetch once and inspect the result rather than calling get_remote_access() for every assertion
         remote_access = protocol.client.get_remote_access()
         self.assertIsNotNone(remote_access)

         users = remote_access.user_list.users
         self.assertEqual(1, len(users))
         self.assertEqual('testAccount', users[0].name)
         self.assertEqual('encryptedPasswordString', users[0].encrypted_password)
예제 #16
0
    def create_mock_protocol():
        """
        Yields a mock wire protocol built from DATA_FILE_NO_EXT for the tests that use it.

        That data has no extensions, so the extension config starts out empty. The status upload blob
        (sas_url, page blob type) and an initial "Ready" VM status are filled in before yielding, and
        the client's update_goal_state is replaced with a Mock so tests can verify how it is called.
        """
        with mock_wire_protocol(DATA_FILE_NO_EXT) as protocol:
            client = protocol.client

            extension_config = client._goal_state.ext_conf  # pylint: disable=protected-access
            extension_config.status_upload_blob = sas_url
            extension_config.status_upload_blob_type = page_blob_type

            initial_status = restapi.VMStatus(status="Ready", message="Guest Agent is running")
            client.status_blob.set_vm_status(initial_status)

            client.update_goal_state = Mock()

            yield protocol
예제 #17
0
    def test_validate_block_blob(self):
        """
        Uploading a block blob status through the HostGAPlugin should issue exactly two HTTP
        requests: the PUT to the host plugin, followed by a POST to the health service.
        """
        with mock_wire_protocol(DATA_FILE) as protocol:
            test_goal_state = protocol.client._goal_state  # pylint: disable=protected-access

            host_client = wire.HostPluginProtocol(
                wireserver_url, test_goal_state.container_id,
                test_goal_state.role_config_name)
            self.assertFalse(host_client.is_initialized)
            self.assertIsNone(host_client.api_versions)
            self.assertIsNotNone(host_client.health_service)

            status_blob = protocol.client.status_blob
            status_blob.data = faux_status
            status_blob.type = block_blob_type
            status_blob.vm_status = restapi.VMStatus(message="Ready",
                                                     status="Ready")

            exp_method = 'PUT'
            exp_url = hostplugin_status_url
            exp_data = self._hostplugin_data(
                status_blob.get_block_blob_headers(len(faux_status)),
                bytearray(faux_status, encoding='utf-8'))

            with patch.object(restutil, "http_request") as patch_http:
                patch_http.return_value = Mock(status=httpclient.OK)

                with patch.object(wire.HostPluginProtocol,
                                  "get_api_versions") as patch_get:
                    patch_get.return_value = api_versions
                    host_client.put_vm_status(status_blob, sas_url)

                    # assertEqual reports the actual count on failure
                    self.assertEqual(2, patch_http.call_count)

                    # first call is to host plugin
                    self._validate_hostplugin_args(
                        patch_http.call_args_list[0], test_goal_state,
                        exp_method, exp_url, exp_data)

                    # second call is to health service
                    self.assertEqual('POST',
                                     patch_http.call_args_list[1][0][0])
                    self.assertEqual(health_service_url,
                                     patch_http.call_args_list[1][0][1])
예제 #18
0
def _create_collect_logs_handler(iterations=1,
                                 cgroups_enabled=True,
                                 collect_logs_conf=True):
    """
    Yields a CollectLogsHandler configured for testing. The handler
        * talks to a mock_wire_protocol instead of the network,
        * executes its main loop only 'iterations' times (its 'stopped' check is patched), and
        * never sleeps between iterations (time.sleep is patched).

    The CGroupConfigurator singleton's 'enabled' attribute is patched with a mock that returns
    'cgroups_enabled', and conf.get_collect_logs is patched to return 'collect_logs_conf'.
    All patches remain in effect until this generator is resumed or closed.

    Two helpers are attached to the yielded handler:
        * get_mock_wire_protocol() - returns the mock protocol
        * run_and_wait() - calls run() and then join() on the handler
    """
    with mock_wire_protocol(DATA_FILE) as protocol:
        util_mock = MagicMock()
        util_mock.get_protocol = Mock(return_value=protocol)
        with patch("azurelinuxagent.ga.collect_logs.get_protocol_util", return_value=util_mock):
            stop_sequence = [False] * iterations + [True]
            with patch("azurelinuxagent.ga.collect_logs.CollectLogsHandler.stopped", side_effect=stop_sequence):
                with patch("time.sleep"):
                    # The configurator is a singleton, so patching this instance affects the handler too
                    configurator = CGroupConfigurator.get_instance()
                    with patch.object(configurator, "enabled", return_value=cgroups_enabled):
                        with patch("azurelinuxagent.ga.collect_logs.conf.get_collect_logs",
                                   return_value=collect_logs_conf):
                            handler = get_collect_logs_handler()

                            def run_and_wait():
                                handler.run()
                                handler.join()

                            handler.get_mock_wire_protocol = lambda: protocol
                            handler.run_and_wait = run_and_wait
                            yield handler
예제 #19
0
def _create_monitor_handler(enabled_operations=None, iterations=1):
    """
    Creates an instance of MonitorHandler that
        * Uses a mock_wire_protocol for network requests,
        * Executes only the operations given in the 'enabled_operations' parameter
          (None or an empty collection means every operation is executed),
        * Runs its main loop only the number of times given in the 'iterations' parameter, and
        * Does not sleep at the end of each iteration

    The returned MonitorHandler is augmented with 2 methods:
        * get_mock_wire_protocol() - returns the mock protocol
        * run_and_wait() - invokes run() and join() on the MonitorHandler

    This is a generator: the handler is yielded while all of the patches are active.
    """
    # A mutable default ([]) would be shared across calls; use None and normalize it here instead.
    if enabled_operations is None:
        enabled_operations = []

    def run(self):
        # Delegate to the real PeriodicOperation.run only for the operations selected by the test
        if not enabled_operations or self._name in enabled_operations:  # pylint: disable=protected-access
            run.original_definition(self)

    # Capture the real implementation before PeriodicOperation.run is patched below
    run.original_definition = PeriodicOperation.run

    with mock_wire_protocol(DATA_FILE) as protocol:
        protocol_util = MagicMock()
        protocol_util.get_protocol = Mock(return_value=protocol)
        with patch("azurelinuxagent.ga.monitor.get_protocol_util",
                   return_value=protocol_util):
            with patch.object(PeriodicOperation,
                              "run",
                              side_effect=run,
                              autospec=True):
                with patch("azurelinuxagent.ga.monitor.MonitorHandler.stopped",
                           side_effect=[False] * iterations + [True]):
                    with patch("time.sleep"):

                        def run_and_wait():
                            monitor_handler.run()
                            monitor_handler.join()

                        monitor_handler = get_monitor_handler()
                        monitor_handler.get_mock_wire_protocol = lambda: protocol
                        monitor_handler.run_and_wait = run_and_wait
                        yield monitor_handler
예제 #20
0
    def initialize_event_logger(event_dir):
        """
        Sets up the event logger in 'event_dir' (created if missing) and initializes its common VM-info
        parameters. Goal state fields come from mockwiredata.DATA_FILE through a mock wire protocol;
        IMDS fields come from EventLoggerTools.mock_imds_data through a mocked IMDS client.
        """
        if not os.path.exists(event_dir):
            os.mkdir(event_dir)
        event.init_event_logger(event_dir)

        # Copy each mocked IMDS value onto an attribute of the same name
        imds_compute_info = tools.Mock()
        for field_name in ('location', 'subscriptionId', 'resourceGroupName', 'vmId', 'image_origin'):
            setattr(imds_compute_info, field_name, EventLoggerTools.mock_imds_data[field_name])

        imds_client = tools.Mock()
        imds_client.get_compute = tools.Mock(return_value=imds_compute_info)

        with mock_wire_protocol(mockwiredata.DATA_FILE) as mock_protocol:
            with tools.patch("azurelinuxagent.common.event.get_imds_client", return_value=imds_client):
                event.initialize_event_logger_vminfo_common_parameters(mock_protocol)
예제 #21
0
    def test_add_event_should_use_the_container_id_from_the_most_recent_goal_state(
            self):
        """Events created by add_event() should carry the ContainerId of the most recently fetched goal state."""
        def create_event_and_return_container_id():  # pylint: disable=inconsistent-return-statements
            event.add_event(name='Event')
            event_list = event.collect_events()
            # assertEquals is a deprecated alias (removed in Python 3.12); use assertEqual
            self.assertEqual(len(event_list.events), 1,
                             "Could not find the event created by add_event")

            for p in event_list.events[0].parameters:  # pylint: disable=invalid-name
                if p.name == 'ContainerId':
                    return p.value

            self.fail("Could not find Contained ID on event")

        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
            contained_id = create_event_and_return_container_id()
            # The expect value comes from DATA_FILE
            self.assertEqual(contained_id,
                             'c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2',
                             "Incorrect container ID")

            # After each goal state update, new events must report the updated container id
            for new_container_id in ('AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE',
                                     '11111111-2222-3333-4444-555555555555'):
                protocol.mock_wire_data.set_container_id(new_container_id)
                protocol.update_goal_state()
                contained_id = create_event_and_return_container_id()
                self.assertEqual(contained_id,
                                 new_container_id,
                                 "Incorrect container ID")
예제 #22
0
    def test_add_event_should_use_the_container_id_from_the_most_recent_goal_state(
            self):
        """
        Every event created by add_event() should report the ContainerId taken from the goal state
        that was fetched most recently.
        """
        def create_event_and_return_container_id():
            event.add_event(name='Event')
            collected = self._collect_events()
            self.assertEqual(len(collected), 1,
                             "Could not find the event created by add_event")

            container_ids = [param.value for param in collected[0].parameters
                             if param.name == CommonTelemetryEventSchema.ContainerId]
            if not container_ids:
                self.fail("Could not find Contained ID on event")
            return container_ids[0]

        def assert_next_event_uses(expected_id):
            self.assertEqual(create_event_and_return_container_id(),
                             expected_id,
                             "Incorrect container ID")

        with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
            # The initial container id is the one in DATA_FILE
            assert_next_event_uses('c6d5526c-5ac2-4200-b6e2-56f2b70c5ab2')

            for updated_id in ['AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE',
                               '11111111-2222-3333-4444-555555555555']:
                protocol.mock_wire_data.set_container_id(updated_id)
                protocol.update_goal_state()
                assert_next_event_uses(updated_id)
예제 #23
0
 def test_goal_state_with_no_remote_access(self):
     """With DATA_FILE's goal state, get_remote_access() should return None."""
     with mock_wire_protocol(mockwiredata.DATA_FILE) as protocol:
         remote_access = protocol.client.get_remote_access()
         self.assertIsNone(remote_access)
예제 #24
0
    def test_report_event_should_encode_call_stack_correctly(self):
        """
        The Message in some telemetry events that include call stacks are being truncated in Kusto. While the issue doesn't seem
        to be in the agent itself, this test verifies that the Message of the event we send in the HTTP request matches the
        Message we read from the event's file.
        """
        def get_event_message_from_event_file(event_file):
            """Return the Message parameter of the JSON telemetry event stored in 'event_file'."""
            with open(event_file, "rb") as fd:  # pylint: disable=invalid-name
                event_data = fd.read().decode(
                    "utf-8")  # event files are UTF-8 encoded
            telemetry_event = json.loads(event_data)

            for p in telemetry_event['parameters']:  # pylint: disable=invalid-name
                if p['name'] == GuestAgentExtensionEventsSchema.Message:
                    return p['value']

            raise ValueError(
                'Could not find the Message for the telemetry event in {0}'.
                format(event_file))

        def get_event_message_from_http_request_body(http_request_body):
            """Extract the Message parameter from the event embedded (as CDATA) in a telemetry request body."""
            # The XML for the event is sent over as a CDATA element ("Event") in the request's body
            request_body_xml_doc = textutil.parse_doc(http_request_body)

            event_node = textutil.find(request_body_xml_doc, "Event")
            if event_node is None:
                raise ValueError(
                    'Could not find the Event node in the XML document')
            if len(event_node.childNodes) != 1:
                raise ValueError(
                    'The Event node in the XML document should have exactly 1 child'
                )

            event_node_first_child = event_node.childNodes[0]
            if event_node_first_child.nodeType != xml.dom.Node.CDATA_SECTION_NODE:
                raise ValueError('The Event node contents should be CDATA')

            event_node_cdata = event_node_first_child.nodeValue

            # The CDATA will contain a sequence of "<Param Name='foo' Value='bar'/>" nodes, which
            # correspond to the parameters of the telemetry event.  Wrap those into a "Helper" node
            # and extract the "Message"
            event_xml_text = '<?xml version="1.0"?><Helper>{0}</Helper>'.format(
                event_node_cdata)
            event_xml_doc = textutil.parse_doc(event_xml_text)
            helper_node = textutil.find(event_xml_doc, "Helper")

            for child in helper_node.childNodes:
                # Only element nodes (<Param .../>) have attributes. Skip anything else, e.g. whitespace
                # text nodes, which (assuming a DOM/minidom tree) have no getAttribute() method.
                if child.nodeType != xml.dom.Node.ELEMENT_NODE:
                    continue
                if child.getAttribute(
                        'Name') == GuestAgentExtensionEventsSchema.Message:
                    return child.getAttribute('Value')

            raise ValueError(
                'Could not find the Message for the telemetry event. Request body: {0}'
                .format(http_request_body))

        def http_post_handler(url, body, **__):
            # Capture the body of the telemetry request; other requests use the default mock behavior
            if self.is_telemetry_request(url):
                http_post_handler.request_body = body
                return MockHttpResponse(status=200)
            return None

        http_post_handler.request_body = None

        with mock_wire_protocol(
                mockwiredata.DATA_FILE,
                http_post_handler=http_post_handler) as protocol:
            event_file_path = self._create_test_event_file(
                "event_with_callstack.waagent.tld")
            expected_message = get_event_message_from_event_file(
                event_file_path)

            event_list = self._collect_events()
            self._report_events(protocol, event_list)

            event_message = get_event_message_from_http_request_body(
                http_post_handler.request_body)

            self.assertEqual(
                event_message, expected_message,
                "The Message in the HTTP request does not match the Message in the event's *.tld file"
            )