Пример #1
0
    def test_fallback_channel_503(self, patch_update, patch_put, patch_upload,
                                  _):
        """
        When host plugin returns a 503, we should fall back to the direct channel.

        After a single status upload, both the host plugin route (patch_put) and
        the direct route (patch_upload) must have been used exactly once. The
        goal state must have been updated exactly once without keyword arguments,
        i.e. not forced. The host plugin must not have been made the default
        channel.
        """
        test_goal_state = wire.GoalState(
            WireProtocolData(DATA_FILE).goal_state)
        status = restapi.VMStatus(status="Ready",
                                  message="Guest Agent is running")

        wire_protocol_client = wire.WireProtocol(wireserver_url).client
        wire_protocol_client.get_goal_state = Mock(
            return_value=test_goal_state)
        wire_protocol_client.ext_conf = wire.ExtensionsConfig(None)
        wire_protocol_client.ext_conf.status_upload_blob = sas_url
        wire_protocol_client.ext_conf.status_upload_blob_type = page_blob_type
        wire_protocol_client.status_blob.set_vm_status(status)

        # act
        wire_protocol_client.upload_status_blob()

        # assert direct route is called (the fallback)
        self.assertEqual(1, patch_upload.call_count,
                         "Direct channel was not used")

        # assert host plugin route is called
        self.assertEqual(1, patch_put.call_count, "Host plugin was not used")

        # assert update goal state is only called once, non-forced
        self.assertEqual(1, patch_update.call_count,
                         "Update goal state unexpected call count")
        # call_args[1] is the kwargs dict; it must be empty (no force flag)
        self.assertEqual(0, len(patch_update.call_args[1]),
                         "Update goal state unexpected parameters")

        # ensure the correct url is used
        self.assertEqual(sas_url, patch_put.call_args[0][0])

        # ensure host plugin is not set as default
        self.assertFalse(wire.HostPluginProtocol.is_default_channel())
    def test_get_protocol(self, WireProtocol, MetadataProtocol, _):
        """
        get_protocol() builds the protocol named in the saved protocol file and
        keeps returning it until clear_protocol() is called.

        For an unrecognised saved name, the test makes _detect_protocol() raise
        NotImplementedError and expects get_protocol() to propagate it.
        """
        WireProtocol.return_value = MagicMock()
        MetadataProtocol.return_value = MagicMock()

        protocol_util = get_protocol_util()
        protocol_util.get_wireserver_endpoint = Mock()
        protocol_util._detect_protocol = MagicMock()

        # Test for wire protocol
        protocol_util._save_protocol("WireProtocol")

        protocol = protocol_util.get_protocol()
        self.assertEqual(WireProtocol.return_value, protocol)
        protocol_util.get_wireserver_endpoint.assert_any_call()

        # Test to ensure protocol persists: a newly saved name is ignored
        # until the protocol is cleared.
        protocol_util.get_wireserver_endpoint.reset_mock()
        protocol_util._save_protocol("MetadataProtocol")

        protocol = protocol_util.get_protocol()
        self.assertEqual(WireProtocol.return_value, protocol)
        protocol_util.get_wireserver_endpoint.assert_not_called()

        # Test for metadata protocol
        protocol_util.clear_protocol()
        protocol_util._save_protocol("MetadataProtocol")

        protocol = protocol_util.get_protocol()
        self.assertEqual(MetadataProtocol.return_value, protocol)
        protocol_util.get_wireserver_endpoint.assert_not_called()

        # Test for unknown protocol
        protocol_util.clear_protocol()
        protocol_util._save_protocol("Not_a_Protocol")
        protocol_util._detect_protocol.side_effect = NotImplementedError()

        self.assertRaises(NotImplementedError, protocol_util.get_protocol)
        protocol_util.get_wireserver_endpoint.assert_not_called()
Пример #3
0
    def test_default_channel(self, patch_update, patch_put, patch_upload, _):
        """
        Status uploads default to the HostPlugin channel.

        Errors on the public channel are ignored: the direct route must not be
        touched at all. Uploading status must also never change which channel
        is considered the default.
        """
        goal_state = wire.GoalState(WireProtocolData(DATA_FILE).goal_state)
        vm_status = restapi.VMStatus(status="Ready",
                                     message="Guest Agent is running")

        client = wire.WireProtocol(wireserver_url).client
        client.get_goal_state = Mock(return_value=goal_state)
        client.ext_conf = wire.ExtensionsConfig(None)
        client.ext_conf.status_upload_blob = sas_url
        client.ext_conf.status_upload_blob_type = page_blob_type
        client.status_blob.set_vm_status(vm_status)

        client.upload_status_blob()

        # Only the HostPlugin route may be used; the direct route stays untouched.
        self.assertEqual(0, patch_upload.call_count, "Direct channel was used")
        self.assertEqual(1, patch_put.call_count, "Host plugin was not used")

        # Exactly one goal-state update, with no keyword arguments (not forced).
        # The count is checked first so call_args is known to be populated.
        self.assertEqual(1, patch_update.call_count, "Unexpected call count")
        update_kwargs = patch_update.call_args[1]
        self.assertEqual(0, len(update_kwargs), "Unexpected parameters")

        # The HostPlugin PUT targets the status blob SAS URL.
        put_positional = patch_put.call_args[0]
        self.assertEqual(sas_url, put_positional[0])

        # HostPluginProtocol must not have been flagged as the default channel.
        self.assertFalse(wire.HostPluginProtocol.is_default_channel())
Пример #4
0
    def test_validate_http_request_when_uploading_status(self):
        """Validate correct set of data is sent to HostGAPlugin when reporting VM status

        put_vm_status() must issue exactly two HTTP requests. The first is a
        PUT of the status blob to the HostGAPlugin status URL. The second is a
        POST to the health service URL.
        """

        with mock_wire_protocol(DATA_FILE) as protocol:
            test_goal_state = protocol.client._goal_state  # pylint: disable=protected-access
            plugin = protocol.client.get_host_plugin()

            status_blob = protocol.client.status_blob
            status_blob.data = faux_status
            status_blob.vm_status = restapi.VMStatus(message="Ready",
                                                     status="Ready")

            exp_method = 'PUT'
            exp_url = hostplugin_status_url
            exp_data = self._hostplugin_data(
                status_blob.get_block_blob_headers(len(faux_status)),
                bytearray(faux_status, encoding='utf-8'))

            with patch.object(restutil, "http_request") as patch_http:
                patch_http.return_value = Mock(status=httpclient.OK)

                with patch.object(plugin, 'get_api_versions') as patch_api:
                    patch_api.return_value = API_VERSION
                    plugin.put_vm_status(status_blob, sas_url, block_blob_type)

                    # assertEqual (rather than assertTrue(x == 2)) reports the
                    # actual call count on failure
                    self.assertEqual(2, patch_http.call_count)

                    # first call is to host plugin
                    self._validate_hostplugin_args(
                        patch_http.call_args_list[0], test_goal_state,
                        exp_method, exp_url, exp_data)

                    # second call is to health service: positional args are (method, url, ...)
                    self.assertEqual('POST',
                                     patch_http.call_args_list[1][0][0])
                    self.assertEqual(health_service_url,
                                     patch_http.call_args_list[1][0][1])
Пример #5
0
    def test_protocol_file_states(self, _):
        """
        clear_protocol() clears the wireserver endpoint and removes the saved
        protocol file.

        The check is run twice against the same protocol file. Each call must
        clear the endpoint once and call os.remove once with the protocol file
        path.

        NOTE(review): os.remove is patched in both passes, so the file written
        below is never actually deleted and both passes see the same on-disk
        state. Neither an IOError from os.remove nor a missing protocol file is
        simulated here. TODO: add side_effect-based cases if those paths need
        coverage.
        """
        protocol_util = get_protocol_util()
        protocol_util._clear_wireserver_endpoint = Mock()

        protocol_file = protocol_util._get_protocol_file_path()

        # Create an empty protocol file for clear_protocol() to remove.
        with open(protocol_file, "w+") as proto_fd:
            proto_fd.write("")

        for _attempt in range(2):
            # Start each pass with a clean call count on the endpoint mock.
            protocol_util._clear_wireserver_endpoint.reset_mock()

            with patch('os.remove') as mock_remove:
                protocol_util.clear_protocol()
                self.assertEqual(1, protocol_util._clear_wireserver_endpoint.call_count)
                self.assertEqual(1, mock_remove.call_count)
                self.assertEqual(protocol_file, mock_remove.call_args_list[0][0][0])
Пример #6
0
def _create_collect_logs_handler(iterations=1, systemd_present=True):
    """
    Yield a CollectLogsHandler configured for unit tests.

    While the yielded handler is in use:
        * network requests go through a mock_wire_protocol;
        * CollectLogsHandler.stopped returns False `iterations` times and then
          True, so the main loop runs `iterations` times;
        * time.sleep is patched, so iterations do not sleep;
        * os.path.exists reports `systemd_present` for SYSTEMD_RUN_PATH and
          defers to the real implementation for every other path;
        * conf.get_collect_logs returns True.

    Two helpers are attached to the yielded handler:
        * get_mock_wire_protocol() - returns the mock protocol
        * run_and_wait() - calls run() and then join() on the handler

    NOTE(review): this is a generator; it is presumably consumed through
    contextlib.contextmanager (no decorator is visible here). Confirm.
    """
    # Grab the real implementation before it is patched below, so the fake can
    # defer to it instead of recursing into the patched mock.
    real_exists = os.path.exists

    def fake_exists(filepath):
        return systemd_present if filepath == SYSTEMD_RUN_PATH else real_exists(filepath)

    with mock_wire_protocol(DATA_FILE) as protocol:
        protocol_util = MagicMock()
        protocol_util.get_protocol = Mock(return_value=protocol)
        stopped_sequence = [False] * iterations + [True]

        with patch("azurelinuxagent.ga.collect_logs.get_protocol_util", return_value=protocol_util):
            with patch("azurelinuxagent.ga.collect_logs.CollectLogsHandler.stopped", side_effect=stopped_sequence):
                with patch("time.sleep"):
                    with patch("azurelinuxagent.ga.collect_logs.os.path.exists", side_effect=fake_exists):
                        with patch("azurelinuxagent.ga.collect_logs.conf.get_collect_logs", return_value=True):
                            handler = get_collect_logs_handler()

                            def run_and_wait():
                                handler.run()
                                handler.join()

                            handler.get_mock_wire_protocol = lambda: protocol
                            handler.run_and_wait = run_and_wait
                            yield handler
Пример #7
0
    def test_endpoint_fallback(self):
        """
        get_metadata() issues a second HTTP GET (the fallback) only when the
        primary endpoint raises an IO error. Every GET carries the User-Agent
        expected for the query kind (regular vs. health).

        http error status codes are tested in test_response_validation, none of
        which should trigger a fallback. This is confirmed as _assert_validation
        will count http GET calls and enforces a single GET call (fallback would
        cause 2) and checks the url called.
        """
        test_subject = imds.ImdsClient("foo.bar")

        def assert_scenario(resource_path, is_health, expected_useragent, has_primary_ioerror,
                            setup_kwargs, expect_success, expect_service_error, expected_response):
            # Run one get_metadata() call against a freshly mocked IMDS and verify the outcome.
            test_subject._http_get = Mock(side_effect=self._mock_http_get)
            self._mock_imds_setup(primary_ioerror=has_primary_ioerror, **setup_kwargs)
            result = test_subject.get_metadata(resource_path=resource_path, is_health=is_health)

            if expect_success:
                self.assertTrue(result.success)
            else:
                self.assertFalse(result.success)
            if expect_service_error:
                self.assertTrue(result.service_error)
            else:
                self.assertFalse(result.service_error)
            self.assertEqual(expected_response, result.response)

            # Every GET issued (primary and, on fallback, secondary) must carry the User-Agent.
            for _, kwargs in test_subject._http_get.call_args_list:
                self.assertIn('User-Agent', kwargs['headers'])
                self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
            # A second GET is expected only after a primary IO error.
            self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

        # ensure user-agent gets set correctly
        for is_health, expected_useragent in [(False, restutil.HTTP_USER_AGENT), (True, restutil.HTTP_USER_AGENT_HEALTH)]:
            # set a different resource path for health query to make debugging unit test easier
            resource_path = 'something/health' if is_health else 'something'

            for has_primary_ioerror in (False, True):
                if has_primary_ioerror:
                    unreachable_response = 'IMDS error in /metadata/{0}: Unable to connect to endpoint'.format(resource_path)
                else:
                    unreachable_response = 'Mock success response'

                # (mock setup kwargs, expect success, expect service error, expected response)
                scenarios = [
                    # secondary endpoint unreachable
                    ({'secondary_ioerror': True}, not has_primary_ioerror, False, unreachable_response),
                    # IMDS success
                    ({}, True, False, 'Mock success response'),
                    # IMDS throttled
                    ({'throttled': True}, False, False,
                     'IMDS error in /metadata/{0}: Throttled'.format(resource_path)),
                    # IMDS gone error
                    ({'gone_error': True}, False, True,
                     'IMDS error in /metadata/{0}: HTTP Failed with Status Code 410: Gone'.format(resource_path)),
                    # IMDS bad request
                    ({'bad_request': True}, False, False,
                     'IMDS error in /metadata/{0}: [HTTP Failed] [404: reason] Mock not found'.format(resource_path)),
                ]

                for setup_kwargs, expect_success, expect_service_error, expected_response in scenarios:
                    assert_scenario(resource_path, is_health, expected_useragent, has_primary_ioerror,
                                    setup_kwargs, expect_success, expect_service_error, expected_response)
Пример #8
0
 def mock_crypt_util(self, *args, **kw):
     """Return a real CryptUtil built from the given arguments, with only
     gen_transport_cert replaced by a Mock that delegates to
     self.mock_gen_trans_cert."""
     util = CryptUtil(*args, **kw)
     # Partial patch: every other CryptUtil method keeps its real implementation.
     util.gen_transport_cert = Mock(side_effect=self.mock_gen_trans_cert)
     return util
Пример #9
0
    def mock_http_get(self, url, *args, **kwargs):
        """
        Fake HTTP GET that serves canned WireServer / HostGAPlugin content.

        The URL is matched by substring against known endpoints. The matching
        counter in self.call_counts is incremented. A MagicMock response with
        status OK is returned, whose read() yields the canned content encoded
        as UTF-8. The "ExampleHandlerLinux" payload is returned without
        encoding.

        For HostPlugin "extensionArtifact" requests, the real target URL is
        taken from the x-ms-artifact-location header before matching.

        Raises:
            ResourceGoneError: once, for an extensionArtifact request while
                self.emulate_stale_goal_state is set (the flag is then cleared).
            HttpError: for any other unmatched URL while
                self.emulate_stale_goal_state is set.
            ValueError: if a HostPlugin request lacks headers or the
                x-ms-artifact-location header.
            Exception: for any URL that matches no known endpoint.
        """
        content = None

        resp = MagicMock()
        resp.status = httpclient.OK

        if "comp=versions" in url:  # wire server versions
            content = self.version_info
            self.call_counts["comp=versions"] += 1
        elif "/versions" in url:  # HostPlugin versions
            content = '["2015-09-01"]'
            self.call_counts["/versions"] += 1
        elif url.endswith("/health"):  # HostPlugin health
            content = ''
            self.call_counts["/health"] += 1
        elif "goalstate" in url:
            content = self.goal_state
            self.call_counts["goalstate"] += 1
        elif "hostingenvuri" in url:
            content = self.hosting_env
            self.call_counts["hostingenvuri"] += 1
        elif "sharedconfiguri" in url:
            content = self.shared_config
            self.call_counts["sharedconfiguri"] += 1
        elif "certificatesuri" in url:
            content = self.certs
            self.call_counts["certificatesuri"] += 1
        elif "extensionsconfiguri" in url:
            content = self.ext_conf
            self.call_counts["extensionsconfiguri"] += 1
        elif "remoteaccessinfouri" in url:
            content = self.remote_access
            self.call_counts["remoteaccessinfouri"] += 1
        elif ".vmSettings" in url or ".settings" in url:
            content = self.in_vm_artifacts_profile
            self.call_counts["in_vm_artifacts_profile"] += 1

        else:
            # A stale GoalState results in a 400 from the HostPlugin
            # for which the HTTP handler in restutil raises ResourceGoneError
            if self.emulate_stale_goal_state:
                if "extensionArtifact" in url:
                    self.emulate_stale_goal_state = False
                    self.call_counts["extensionArtifact"] += 1
                    raise ResourceGoneError()
                else:
                    raise HttpError()

            # For HostPlugin requests, replace the URL with that passed
            # via the x-ms-artifact-location header
            if "extensionArtifact" in url:
                self.call_counts["extensionArtifact"] += 1
                # The messages are formatted explicitly; passing kwargs as a
                # second ValueError argument would leave "{0}" unfilled.
                if "headers" not in kwargs:
                    raise ValueError("HostPlugin request is missing the HTTP headers: {0}".format(kwargs))
                if "x-ms-artifact-location" not in kwargs["headers"]:
                    raise ValueError("HostPlugin request is missing the x-ms-artifact-location header: {0}".format(kwargs))
                url = kwargs["headers"]["x-ms-artifact-location"]

            if "manifest.xml" in url:
                content = self.manifest
                self.call_counts["manifest.xml"] += 1
            elif "manifest_of_ga.xml" in url:
                content = self.ga_manifest
                self.call_counts["manifest_of_ga.xml"] += 1
            elif "ExampleHandlerLinux" in url:
                # Extension package payload is served as-is (not UTF-8 encoded).
                content = self.ext
                self.call_counts["ExampleHandlerLinux"] += 1
                resp.read = Mock(return_value=content)
                return resp
            elif ".vmSettings" in url or ".settings" in url:
                # Reachable when the artifact-location header points at a settings blob.
                content = self.in_vm_artifacts_profile
                self.call_counts["in_vm_artifacts_profile"] += 1
            else:
                raise Exception("Bad url {0}".format(url))

        resp.read = Mock(return_value=content.encode("utf-8"))
        return resp
Пример #10
0
    def _provision_test(
            self,  # pylint: disable=invalid-name,too-many-arguments
            distro_name,
            distro_version,
            distro_full_name,
            ovf_file,
            provisionMessage,
            expect_success,
            patch_write_agent_disabled,
            patch_get_instance_id):  # pylint: disable=unused-argument
        """
        Run provisioning against the given OVF file and check the reported events.

        When expect_success is true, the agent must issue exactly two telemetry
        messages:

         1. Provision  -- 'Provisioning succeeded (<n.nn>s)' with an int duration
         2. GuestState -- operation 'ProvisionGuestAgent' whose message equals
                          provisionMessage

        patch_write_agent_disabled must also be called once when
        provisionMessage is 'false', and never otherwise.

        When expect_success is false, exactly one failed event is expected,
        reporting that the OVF lacks ProvisionGuestAgent.

        conf.get_dvd_mount_point is replaced for the duration of the run and
        restored afterwards, so the mock does not leak into other tests.
        """
        ph = get_provision_handler(
            distro_name,  # pylint: disable=invalid-name
            distro_version,
            distro_full_name)
        ph.report_event = MagicMock()
        ph.reg_ssh_host_key = MagicMock(return_value='--thumprint--')

        mock_osutil = MagicMock()
        mock_osutil.decode_customdata = Mock(return_value="")

        ph.osutil = mock_osutil
        ph.protocol_util.osutil = mock_osutil
        ph.protocol_util.get_protocol = MagicMock()

        # Point the DVD mount at the temp dir holding the OVF file. The original
        # function is restored in `finally` so later tests see the real one.
        original_get_dvd_mount_point = conf.get_dvd_mount_point
        conf.get_dvd_mount_point = Mock(return_value=self.tmp_dir)
        try:
            ovfenv_file = os.path.join(self.tmp_dir, OVF_FILE_NAME)
            ovfenv_data = load_data(ovf_file)
            fileutil.write_file(ovfenv_file, ovfenv_data)

            ph.run()
        finally:
            conf.get_dvd_mount_point = original_get_dvd_mount_point

        if expect_success:
            self.assertEqual(2, ph.report_event.call_count)
            positional_args, kw_args = ph.report_event.call_args_list[0]
            # [call('Provisioning succeeded (146473.68s)', duration=65, is_success=True)]
            self.assertTrue(
                re.match(r'Provisioning succeeded \(\d+\.\d+s\)',
                         positional_args[0]) is not None)
            self.assertTrue(isinstance(kw_args['duration'], int))
            self.assertTrue(kw_args['is_success'])

            _, kw_args = ph.report_event.call_args_list[1]
            self.assertEqual('ProvisionGuestAgent', kw_args['operation'])
            self.assertEqual(provisionMessage, kw_args['message'])
            self.assertTrue(kw_args['is_success'])

            # Compare counts explicitly instead of relying on True == 1.
            expected_disabled_calls = 1 if provisionMessage == 'false' else 0
            self.assertEqual(expected_disabled_calls,
                             patch_write_agent_disabled.call_count)

        else:
            self.assertEqual(1, ph.report_event.call_count)
            positional_args, kw_args = ph.report_event.call_args_list[0]
            # [call(u'[ProtocolError] Failed to validate OVF: ProvisionGuestAgent not found')]
            self.assertIn(
                'Failed to validate OVF: ProvisionGuestAgent not found',
                positional_args[0])
            self.assertFalse(kw_args['is_success'])