def test_fallback_channel_503(self, patch_update, patch_put, patch_upload, _):
    """
    When host plugin returns a 503, we should fall back to the direct channel.

    NOTE(review): the patches are injected by decorators not visible here; presumably
    patch_put is configured to make the HostPlugin PUT fail with a 503 — confirm.

    Checks that:
      * both the HostPlugin PUT and the direct upload are issued exactly once,
      * the goal state is updated exactly once, with no keyword arguments,
      * the PUT targets sas_url,
      * HostPlugin is not the default channel afterwards.
    """
    test_goal_state = wire.GoalState(WireProtocolData(DATA_FILE).goal_state)
    status = restapi.VMStatus(status="Ready", message="Guest Agent is running")

    wire_protocol_client = wire.WireProtocol(wireserver_url).client
    wire_protocol_client.get_goal_state = Mock(return_value=test_goal_state)
    wire_protocol_client.ext_conf = wire.ExtensionsConfig(None)
    wire_protocol_client.ext_conf.status_upload_blob = sas_url
    wire_protocol_client.ext_conf.status_upload_blob_type = page_blob_type
    wire_protocol_client.status_blob.set_vm_status(status)

    # act
    wire_protocol_client.upload_status_blob()

    # assert direct route is called
    self.assertEqual(1, patch_upload.call_count, "Direct channel was not used")

    # assert host plugin route is called
    self.assertEqual(1, patch_put.call_count, "Host plugin was not used")

    # assert update goal state is only called once, non-forced (no keyword arguments)
    self.assertEqual(1, patch_update.call_count, "Update goal state unexpected call count")
    self.assertEqual(0, len(patch_update.call_args[1]), "Update goal state unexpected parameters")

    # ensure the correct url is used
    self.assertEqual(sas_url, patch_put.call_args[0][0])

    # ensure host plugin is not set as default
    self.assertFalse(wire.HostPluginProtocol.is_default_channel())
def test_get_protocol(self, WireProtocol, MetadataProtocol, _):
    """
    get_protocol() should build the protocol named in the saved protocol file and keep
    returning that same instance until clear_protocol() is called.

    Only the WireProtocol path is expected to query the wireserver endpoint. For an
    unrecognised saved name, the NotImplementedError raised by the mocked
    _detect_protocol is expected to propagate (presumably get_protocol falls back to
    detection in that case).
    """
    WireProtocol.return_value = MagicMock()
    MetadataProtocol.return_value = MagicMock()

    protocol_util = get_protocol_util()
    protocol_util.get_wireserver_endpoint = Mock()
    protocol_util._detect_protocol = MagicMock()

    # Test for wire protocol
    protocol_util._save_protocol("WireProtocol")
    protocol = protocol_util.get_protocol()
    # assertEqual, not the assertEquals alias that was removed in Python 3.12
    self.assertEqual(WireProtocol.return_value, protocol)
    protocol_util.get_wireserver_endpoint.assert_any_call()

    # Test to ensure protocol persists: saving another protocol name without clearing
    # must still return the WireProtocol instance obtained above.
    protocol_util.get_wireserver_endpoint.reset_mock()
    protocol_util._save_protocol("MetadataProtocol")
    protocol = protocol_util.get_protocol()
    self.assertEqual(WireProtocol.return_value, protocol)
    protocol_util.get_wireserver_endpoint.assert_not_called()

    # Test for metadata protocol
    protocol_util.clear_protocol()
    protocol_util._save_protocol("MetadataProtocol")
    protocol = protocol_util.get_protocol()
    self.assertEqual(MetadataProtocol.return_value, protocol)
    protocol_util.get_wireserver_endpoint.assert_not_called()

    # Test for unknown protocol
    protocol_util.clear_protocol()
    protocol_util._save_protocol("Not_a_Protocol")
    protocol_util._detect_protocol.side_effect = NotImplementedError()
    self.assertRaises(NotImplementedError, protocol_util.get_protocol)
    protocol_util.get_wireserver_endpoint.assert_not_called()
def test_default_channel(self, patch_update, patch_put, patch_upload, _):
    """
    Status uploads go through the HostPlugin by default.

    Verifies that:
      * the direct (public) channel is not used, and any errors on it are ignored,
      * the HostPlugin PUT is issued exactly once, against sas_url,
      * the goal state is refreshed exactly once, with no keyword arguments,
      * uploading status never changes the default channel (HostPlugin is not
        marked as the default afterwards).
    """
    goal_state = wire.GoalState(WireProtocolData(DATA_FILE).goal_state)
    vm_status = restapi.VMStatus(status="Ready", message="Guest Agent is running")

    client = wire.WireProtocol(wireserver_url).client
    client.get_goal_state = Mock(return_value=goal_state)
    client.ext_conf = wire.ExtensionsConfig(None)
    client.ext_conf.status_upload_blob = sas_url
    client.ext_conf.status_upload_blob_type = page_blob_type
    client.status_blob.set_vm_status(vm_status)

    client.upload_status_blob()

    # Only the HostPlugin channel should have been exercised.
    self.assertEqual(0, patch_upload.call_count, "Direct channel was used")
    self.assertEqual(1, patch_put.call_count, "Host plugin was not used")

    # A single goal state refresh, not forced (no keyword arguments).
    self.assertEqual(1, patch_update.call_count, "Unexpected call count")
    update_kwargs = patch_update.call_args[1]
    self.assertEqual(0, len(update_kwargs), "Unexpected parameters")

    put_url = patch_put.call_args[0][0]
    self.assertEqual(sas_url, put_url)

    # HostPlugin must not end up flagged as the default channel.
    self.assertFalse(wire.HostPluginProtocol.is_default_channel())
def test_validate_http_request_when_uploading_status(self):
    """
    Validate correct set of data is sent to HostGAPlugin when reporting VM status.

    put_vm_status is expected to issue exactly two HTTP requests: a PUT to the host
    plugin status URL carrying the block-blob headers and status payload, followed by
    a POST to the health service.
    """
    with mock_wire_protocol(DATA_FILE) as protocol:
        test_goal_state = protocol.client._goal_state  # pylint: disable=protected-access
        plugin = protocol.client.get_host_plugin()

        status_blob = protocol.client.status_blob
        status_blob.data = faux_status
        status_blob.vm_status = restapi.VMStatus(message="Ready", status="Ready")

        exp_method = 'PUT'
        exp_url = hostplugin_status_url
        exp_data = self._hostplugin_data(
            status_blob.get_block_blob_headers(len(faux_status)),
            bytearray(faux_status, encoding='utf-8'))

        with patch.object(restutil, "http_request") as patch_http:
            patch_http.return_value = Mock(status=httpclient.OK)

            with patch.object(plugin, 'get_api_versions') as patch_api:
                patch_api.return_value = API_VERSION
                plugin.put_vm_status(status_blob, sas_url, block_blob_type)

                # assertEqual reports the actual count on failure (assertTrue(x == 2) did not)
                self.assertEqual(2, patch_http.call_count)

                # first call is to host plugin
                self._validate_hostplugin_args(
                    patch_http.call_args_list[0],
                    test_goal_state,
                    exp_method, exp_url, exp_data)

                # second call is to health service
                self.assertEqual('POST', patch_http.call_args_list[1][0][0])
                self.assertEqual(health_service_url, patch_http.call_args_list[1][0][1])
def test_protocol_file_states(self, _):
    """
    clear_protocol() should clear the cached wireserver endpoint and call os.remove on
    the protocol file path, each exactly once per call.

    os.remove is mocked in both sections, so the protocol file written at the start is
    never actually deleted.
    """
    protocol_util = get_protocol_util()
    protocol_util._clear_wireserver_endpoint = Mock()
    protocol_file = protocol_util._get_protocol_file_path()

    def _check_single_clear(remove_mock):
        # One endpoint clear and one os.remove, aimed at the protocol file.
        self.assertEqual(1, protocol_util._clear_wireserver_endpoint.call_count)
        self.assertEqual(1, remove_mock.call_count)
        self.assertEqual(protocol_file, remove_mock.call_args_list[0][0][0])

    # Clear with an (empty) protocol file present on disk.
    with open(protocol_file, "w+") as proto_fd:
        proto_fd.write("")
    with patch('os.remove') as mock_remove:
        protocol_util.clear_protocol()
        _check_single_clear(mock_remove)

    # Clear a second time.
    # NOTE(review): because os.remove was mocked above, the protocol file still exists
    # here, so this section does not obviously exercise a "file not found" path — confirm intent.
    protocol_util._clear_wireserver_endpoint.reset_mock()
    with patch('os.remove') as mock_remove:
        protocol_util.clear_protocol()
        _check_single_clear(mock_remove)
def _create_collect_logs_handler(iterations=1, systemd_present=True):
    """
    Build a CollectLogsHandler configured for unit tests and yield it.

    While the yielded handler is in use:
      * network requests go through a mock_wire_protocol built from DATA_FILE;
      * CollectLogsHandler.stopped returns False for the first ``iterations`` calls and
        True on the next one, so the main loop runs ``iterations`` times;
      * time.sleep is patched, so no iteration actually sleeps;
      * os.path.exists reports ``systemd_present`` for SYSTEMD_RUN_PATH and defers to the
        real implementation for every other path;
      * conf.get_collect_logs returns True.

    The yielded handler gets two extra helpers:
      * get_mock_wire_protocol() - returns the mock protocol
      * run_and_wait() - calls run() and then join() on the handler
    """
    # Captured before patching so the fake below can delegate without recursing into itself.
    real_path_exists = os.path.exists

    def _fake_path_exists(path):
        # Only the systemd-run path is faked; everything else hits the real filesystem.
        # NOTE(review): the patch target goes through collect_logs.os.path, presumably the
        # shared os.path module, so this fake is likely visible process-wide while active.
        return systemd_present if path == SYSTEMD_RUN_PATH else real_path_exists(path)

    stopped_answers = [False] * iterations + [True]

    with mock_wire_protocol(DATA_FILE) as protocol:
        fake_protocol_util = MagicMock()
        fake_protocol_util.get_protocol = Mock(return_value=protocol)
        with patch("azurelinuxagent.ga.collect_logs.get_protocol_util", return_value=fake_protocol_util):
            with patch("azurelinuxagent.ga.collect_logs.CollectLogsHandler.stopped", side_effect=stopped_answers):
                with patch("time.sleep"):
                    with patch("azurelinuxagent.ga.collect_logs.os.path.exists", side_effect=_fake_path_exists):
                        with patch("azurelinuxagent.ga.collect_logs.conf.get_collect_logs", return_value=True):
                            handler = get_collect_logs_handler()

                            def run_and_wait():
                                handler.run()
                                handler.join()

                            handler.get_mock_wire_protocol = lambda: protocol
                            handler.run_and_wait = run_and_wait
                            yield handler
def test_endpoint_fallback(self):
    """
    Exercise ImdsClient.get_metadata across the primary/secondary endpoint fallback.

    Covers both regular and health queries, with the primary endpoint either reachable
    or failing with an IOError. For each mocked IMDS outcome, the test checks:
      * the reported success and service_error flags,
      * the response text,
      * the User-Agent header sent on every GET,
      * the number of GETs: 2 when the primary endpoint failed, otherwise 1.
    """
    # HTTP error status codes are tested in test_response_validation, none of which should
    # trigger a fallback. That is confirmed there: _assert_validation counts the HTTP GET
    # calls, enforces a single GET call (a fallback would cause 2) and checks the url called.
    test_subject = imds.ImdsClient("foo.bar")

    def _query(resource_path, is_health, has_primary_ioerror, **imds_setup):
        # A fresh GET mock per query, so call_count/call_args_list reflect only this query.
        test_subject._http_get = Mock(side_effect=self._mock_http_get)
        self._mock_imds_setup(primary_ioerror=has_primary_ioerror, **imds_setup)
        return test_subject.get_metadata(resource_path=resource_path, is_health=is_health)

    def _assert_http_gets(expected_useragent, has_primary_ioerror):
        # Every GET carries the expected User-Agent, and the secondary endpoint is only
        # tried (second GET) when the primary endpoint raised an IOError.
        for _, kwargs in test_subject._http_get.call_args_list:
            self.assertTrue('User-Agent' in kwargs['headers'])
            self.assertEqual(expected_useragent, kwargs['headers']['User-Agent'])
        self.assertEqual(2 if has_primary_ioerror else 1, test_subject._http_get.call_count)

    # ensure user-agent gets set correctly
    for is_health, expected_useragent in [(False, restutil.HTTP_USER_AGENT),
                                          (True, restutil.HTTP_USER_AGENT_HEALTH)]:
        # set a different resource path for health query to make debugging unit test easier
        resource_path = 'something/health' if is_health else 'something'

        for has_primary_ioerror in (False, True):
            # secondary endpoint unreachable
            result = _query(resource_path, is_health, has_primary_ioerror, secondary_ioerror=True)
            if has_primary_ioerror:
                # neither endpoint can be reached
                self.assertFalse(result.success)
                self.assertFalse(result.service_error)
                self.assertEqual('IMDS error in /metadata/{0}: Unable to connect to endpoint'.format(resource_path),
                                 result.response)
            else:
                self.assertTrue(result.success)
                self.assertFalse(result.service_error)
                self.assertEqual('Mock success response', result.response)
            _assert_http_gets(expected_useragent, has_primary_ioerror)

            # (label, IMDS mock setup, expected success, expected service_error, expected response)
            scenarios = [
                ('IMDS success', {}, True, False,
                 'Mock success response'),
                ('IMDS throttled', {'throttled': True}, False, False,
                 'IMDS error in /metadata/{0}: Throttled'.format(resource_path)),
                ('IMDS gone error', {'gone_error': True}, False, True,
                 'IMDS error in /metadata/{0}: HTTP Failed with Status Code 410: Gone'.format(resource_path)),
                ('IMDS bad request', {'bad_request': True}, False, False,
                 'IMDS error in /metadata/{0}: [HTTP Failed] [404: reason] Mock not found'.format(resource_path)),
            ]
            for label, imds_setup, expect_success, expect_service_error, expected_response in scenarios:
                result = _query(resource_path, is_health, has_primary_ioerror, **imds_setup)
                # bool() keeps the truthiness semantics of assertTrue/assertFalse
                self.assertEqual(expect_success, bool(result.success), label)
                self.assertEqual(expect_service_error, bool(result.service_error), label)
                self.assertEqual(expected_response, result.response, label)
                _assert_http_gets(expected_useragent, has_primary_ioerror)
def mock_crypt_util(self, *args, **kw):
    """
    Return a real CryptUtil built from the given arguments, with only its
    gen_transport_cert method replaced by a Mock whose side effect is
    self.mock_gen_trans_cert.
    """
    crypt_util = CryptUtil(*args, **kw)
    # Partial patch: this one instance method is stubbed, the rest stays real.
    crypt_util.gen_transport_cert = Mock(side_effect=self.mock_gen_trans_cert)
    return crypt_util
def mock_http_get(self, url, *args, **kwargs):
    """
    Fake HTTP GET that serves canned documents held on this object.

    The response is chosen by substring-matching ``url``, and the matching entry of
    ``self.call_counts`` is incremented. The returned MagicMock has
    ``status == httpclient.OK``. Its ``read()`` returns the content UTF-8 encoded, except
    for ExampleHandlerLinux packages, which are returned unencoded.

    HostPlugin ``extensionArtifact`` requests are redirected to the URL given in their
    ``x-ms-artifact-location`` header.

    Raises:
        ResourceGoneError: when emulating a stale goal state, on the first
            extensionArtifact request (the emulation is switched off afterwards).
        HttpError: when emulating a stale goal state, for any other request that
            reaches the artifact branch.
        ValueError: an extensionArtifact request lacks headers or the
            x-ms-artifact-location header.
        Exception: the URL matches none of the known patterns.
    """
    content = None

    resp = MagicMock()
    resp.status = httpclient.OK

    if "comp=versions" in url:  # wire server versions
        content = self.version_info
        self.call_counts["comp=versions"] += 1
    elif "/versions" in url:  # HostPlugin versions
        content = '["2015-09-01"]'
        self.call_counts["/versions"] += 1
    elif url.endswith("/health"):  # HostPlugin health
        content = ''
        self.call_counts["/health"] += 1
    elif "goalstate" in url:
        content = self.goal_state
        self.call_counts["goalstate"] += 1
    elif "hostingenvuri" in url:
        content = self.hosting_env
        self.call_counts["hostingenvuri"] += 1
    elif "sharedconfiguri" in url:
        content = self.shared_config
        self.call_counts["sharedconfiguri"] += 1
    elif "certificatesuri" in url:
        content = self.certs
        self.call_counts["certificatesuri"] += 1
    elif "extensionsconfiguri" in url:
        content = self.ext_conf
        self.call_counts["extensionsconfiguri"] += 1
    elif "remoteaccessinfouri" in url:
        content = self.remote_access
        self.call_counts["remoteaccessinfouri"] += 1
    elif ".vmSettings" in url or ".settings" in url:
        content = self.in_vm_artifacts_profile
        self.call_counts["in_vm_artifacts_profile"] += 1
    else:
        # A stale GoalState results in a 400 from the HostPlugin
        # for which the HTTP handler in restutil raises ResourceGoneError
        if self.emulate_stale_goal_state:
            if "extensionArtifact" in url:
                self.emulate_stale_goal_state = False
                self.call_counts["extensionArtifact"] += 1
                raise ResourceGoneError()
            else:
                raise HttpError()

        # For HostPlugin requests, replace the URL with that passed
        # via the x-ms-artifact-location header
        if "extensionArtifact" in url:
            self.call_counts["extensionArtifact"] += 1
            # .format() is required: passing kwargs as a second ValueError argument
            # left the "{0}" placeholder unfilled.
            if "headers" not in kwargs:
                raise ValueError("HostPlugin request is missing the HTTP headers: {0}".format(kwargs))
            if "x-ms-artifact-location" not in kwargs["headers"]:
                raise ValueError("HostPlugin request is missing the x-ms-artifact-location header: {0}".format(kwargs))
            url = kwargs["headers"]["x-ms-artifact-location"]

        if "manifest.xml" in url:
            content = self.manifest
            self.call_counts["manifest.xml"] += 1
        elif "manifest_of_ga.xml" in url:
            content = self.ga_manifest
            self.call_counts["manifest_of_ga.xml"] += 1
        elif "ExampleHandlerLinux" in url:
            content = self.ext
            self.call_counts["ExampleHandlerLinux"] += 1
            # returned as-is, without the UTF-8 encoding applied below
            resp.read = Mock(return_value=content)
            return resp
        elif ".vmSettings" in url or ".settings" in url:
            content = self.in_vm_artifacts_profile
            self.call_counts["in_vm_artifacts_profile"] += 1
        else:
            raise Exception("Bad url {0}".format(url))

    resp.read = Mock(return_value=content.encode("utf-8"))
    return resp
def _provision_test(self,  # pylint: disable=invalid-name,too-many-arguments
                    distro_name,
                    distro_version,
                    distro_full_name,
                    ovf_file,
                    provisionMessage,
                    expect_success,
                    patch_write_agent_disabled,
                    patch_get_instance_id):  # pylint: disable=unused-argument
    """
    Run the provision handler against ``ovf_file`` and check the telemetry it reports.

    When ``expect_success`` is true, exactly two events must be reported:
      1. Provision  - "Provisioning succeeded (<seconds>s)" with an int duration, successful
      2. GuestState - operation 'ProvisionGuestAgent' whose message is ``provisionMessage``
    In addition, ``patch_write_agent_disabled`` must be called exactly once when
    ``provisionMessage`` is 'false', and never otherwise.

    Otherwise, exactly one unsuccessful event must be reported, mentioning that
    ProvisionGuestAgent was not found while validating the OVF.
    """
    ph = get_provision_handler(distro_name,  # pylint: disable=invalid-name
                               distro_version,
                               distro_full_name)
    ph.report_event = MagicMock()
    ph.reg_ssh_host_key = MagicMock(return_value='--thumprint--')

    mock_osutil = MagicMock()
    mock_osutil.decode_customdata = Mock(return_value="")

    ph.osutil = mock_osutil
    ph.protocol_util.osutil = mock_osutil
    ph.protocol_util.get_protocol = MagicMock()

    # conf.get_dvd_mount_point is shared module-level state: restore the original after the
    # run so the temporary mount point does not leak into later tests.
    original_get_dvd_mount_point = conf.get_dvd_mount_point
    conf.get_dvd_mount_point = Mock(return_value=self.tmp_dir)
    try:
        ovfenv_file = os.path.join(self.tmp_dir, OVF_FILE_NAME)
        ovfenv_data = load_data(ovf_file)
        fileutil.write_file(ovfenv_file, ovfenv_data)

        ph.run()
    finally:
        conf.get_dvd_mount_point = original_get_dvd_mount_point

    if expect_success:
        self.assertEqual(2, ph.report_event.call_count)

        # e.g. [call('Provisioning succeeded (146473.68s)', duration=65, is_success=True)]
        positional_args, kw_args = ph.report_event.call_args_list[0]
        self.assertTrue(re.match(r'Provisioning succeeded \(\d+\.\d+s\)', positional_args[0]) is not None)
        self.assertTrue(isinstance(kw_args['duration'], int))
        self.assertTrue(kw_args['is_success'])

        _, kw_args = ph.report_event.call_args_list[1]
        self.assertEqual('ProvisionGuestAgent', kw_args['operation'])
        self.assertEqual(provisionMessage, kw_args['message'])
        self.assertTrue(kw_args['is_success'])

        # write_agent_disabled is expected exactly once when the reported message is 'false'
        expected_disabled_calls = 1 if provisionMessage == 'false' else 0
        self.assertEqual(expected_disabled_calls, patch_write_agent_disabled.call_count)
    else:
        self.assertEqual(1, ph.report_event.call_count)

        # e.g. [call(u'[ProtocolError] Failed to validate OVF: ProvisionGuestAgent not found')]
        positional_args, kw_args = ph.report_event.call_args_list[0]
        self.assertTrue('Failed to validate OVF: ProvisionGuestAgent not found' in positional_args[0])
        self.assertFalse(kw_args['is_success'])