def _get_persist_firewall_rules_handler(self, systemd=True):
    """Yield a PersistFirewallRulesHandler built against a partially mocked DefaultOSUtil.

    The OS util reports the per-test firewall wait flag, agent bin dir and systemd
    unit install dir. While the handler is yielded:
      * fileutil.mkdir (as seen by persist_firewall_rules) ignores its mode kwargs,
      * get_osutil returns the prepared OS util,
      * osutil.systemd.is_systemd returns ``systemd``,
      * shellutil's subprocess.Popen is routed to self.__mock_popen.

    The expected service name, network unit-file path and binary-file path are
    stored on ``self`` for use by the calling test.
    """
    os_util = DefaultOSUtil()
    os_util.get_firewall_will_wait = MagicMock(return_value=self.__test_wait)
    os_util.get_agent_bin_path = MagicMock(return_value=self.__agent_bin_dir)
    os_util.get_systemd_unit_file_install_path = MagicMock(return_value=self.__systemd_dir)

    name_format = PersistFirewallRulesHandler._AGENT_NETWORK_SETUP_NAME_FORMAT
    self._expected_service_name = name_format.format(os_util.get_service_name())
    self._network_service_unit_file = os.path.join(self.__systemd_dir, self._expected_service_name)
    self._binary_file = os.path.join(conf.get_lib_dir(), PersistFirewallRulesHandler.BINARY_FILE_NAME)

    # Only for these tests: drop mkdir's mode argument so they can run without sudo.
    real_mkdir = fileutil.mkdir

    def _mkdir_ignoring_mode(path, **mode):
        return real_mkdir(path)

    with patch("azurelinuxagent.common.persist_firewall_rules.fileutil.mkdir",
               side_effect=_mkdir_ignoring_mode):
        with patch("azurelinuxagent.common.persist_firewall_rules.get_osutil", return_value=os_util):
            with patch('azurelinuxagent.common.osutil.systemd.is_systemd', return_value=systemd):
                with patch("azurelinuxagent.common.utils.shellutil.subprocess.Popen",
                           side_effect=self.__mock_popen):
                    handler = PersistFirewallRulesHandler(self.__test_dst_ip, self.__test_uid)
                    yield handler
def test_detect_protocol_no_dhcp(self, WireProtocol, mock_get_lib_dir, _):
    """_detect_protocol uses the known WireServer IP when DHCP is unavailable, and
    raises ProtocolError when DHCP is available but running it fails."""
    WireProtocol.return_value.detect = Mock()
    mock_get_lib_dir.return_value = self.tmp_dir

    protocol_util = get_protocol_util()
    protocol_util.osutil = MagicMock()
    protocol_util.osutil.is_dhcp_available.return_value = False
    protocol_util.dhcp_handler = MagicMock()
    protocol_util.dhcp_handler.endpoint = None
    protocol_util.dhcp_handler.run = Mock()

    # Test wire protocol when no endpoint file has been written
    protocol_util._detect_protocol()
    self.assertEqual(KNOWN_WIRESERVER_IP, protocol_util.get_wireserver_endpoint())

    # Test wire protocol on dhcp failure
    protocol_util.osutil.is_dhcp_available.return_value = True
    protocol_util.dhcp_handler.run.side_effect = DhcpError()
    self.assertRaises(ProtocolError, protocol_util._detect_protocol)
def test_detect_protocol(self, WireProtocol, MetadataProtocol, _):
    """get_protocol prefers WireProtocol, falls back to MetadataProtocol when Wire
    detection fails, and raises ProtocolError when neither protocol is detected."""
    WireProtocol.return_value = MagicMock()
    MetadataProtocol.return_value = MagicMock()

    protocol_util = get_protocol_util()
    protocol_util.dhcp_handler = MagicMock()
    protocol_util.dhcp_handler.endpoint = "foo.bar"

    # Test wire protocol is available
    protocol = protocol_util.get_protocol()
    self.assertEqual(WireProtocol.return_value, protocol)

    # Test wire protocol is not available
    protocol_util.clear_protocol()
    WireProtocol.return_value.detect.side_effect = ProtocolError()
    protocol = protocol_util.get_protocol()
    self.assertEqual(MetadataProtocol.return_value, protocol)

    # Test no protocol is available
    protocol_util.clear_protocol()
    WireProtocol.return_value.detect.side_effect = ProtocolError()
    MetadataProtocol.return_value.detect.side_effect = ProtocolError()
    self.assertRaises(ProtocolError, protocol_util.get_protocol)
def test_provision_telemetry_fail(self, mock_util, distro_name, distro_version, distro_full_name, _):
    """
    When provisioning fails (reg_ssh_host_key raises ProvisionError), the first
    telemetry event reported is a "Provisioning failed: [ProvisionError] ..." message
    that carries the elapsed time.
    """
    ph = get_provision_handler(distro_name, distro_version, distro_full_name)
    ph.report_event = MagicMock()
    ph.reg_ssh_host_key = MagicMock(side_effect=ProvisionError("--unit-test--"))

    mock_osutil = MagicMock()
    mock_osutil.decode_customdata = Mock(return_value="")
    ph.osutil = mock_osutil
    ph.protocol_util.osutil = mock_osutil
    ph.protocol_util.get_protocol = MagicMock()

    ovfenv_file = os.path.join(self.tmp_dir, OVF_FILE_NAME)
    ovfenv_data = load_data("ovf-env.xml")
    fileutil.write_file(ovfenv_file, ovfenv_data)

    # Patch rather than assign, so the module-level conf function is restored
    # and the mock does not leak into other tests.
    with patch.object(conf, "get_dvd_mount_point", return_value=self.tmp_dir):
        ph.run()

    positional_args = ph.report_event.call_args_list[0][0]
    self.assertIsNotNone(
        re.match(r'Provisioning failed: \[ProvisionError\] --unit-test-- \(\d+\.\d+s\)', positional_args[0]))
def test_get_protocol_new_wireserver_agent_generates_certificates(self, mock_wire_client, mock_enable_firewall,
                                                                    mock_get_lib_dir, _):
    """
    This is for testing that a new WireServer Linux Agent generates appropriate
    certificates, protocol file, and endpoint file, without touching the firewall.
    """
    # Use the per-test temp dir so generated certificates and protocol/endpoint
    # files do not leak into the shared system temp directory.
    lib_dir = self.tmp_dir

    # Setup mocks
    mock_get_lib_dir.return_value = lib_dir
    mock_enable_firewall.return_value = True
    protocol_util = get_protocol_util()
    protocol_util.osutil = MagicMock()
    mock_wire_client.return_value = MagicMock()
    protocol_util.dhcp_handler = MagicMock()
    protocol_util.dhcp_handler.endpoint = KNOWN_WIRESERVER_IP

    # Run
    protocol_util.get_protocol()

    # Check that WireServer Certs exist
    for ws_cert in TestProtocolUtil.WIRESERVER_CERTIFICATES:
        self.assertTrue(os.path.isfile(os.path.join(lib_dir, ws_cert)))

    # Check firewall rules were not reset
    protocol_util.osutil.remove_firewall.assert_not_called()
    protocol_util.osutil.enable_firewall.assert_not_called()

    # Check Protocol File is updated to WireProtocol
    with open(os.path.join(lib_dir, PROTOCOL_FILE_NAME), "r") as f:
        self.assertEqual(f.read(), WIRE_PROTOCOL_NAME)

    # Check Endpoint file is updated to WireServer IP
    with open(os.path.join(lib_dir, ENDPOINT_FILE_NAME), 'r') as f:
        self.assertEqual(f.read(), KNOWN_WIRESERVER_IP)
def _provision_test(self, distro_name, distro_version, distro_full_name, ovf_file, provisionMessage, expect_success,
                    patch_write_agent_disabled, patch_get_instance_id):
    """
    Run provisioning against ``ovf_file`` and check the telemetry it reports.

    If ``expect_success``, exactly two events are expected:
      1. "Provisioning succeeded (<secs>s)" with an int duration and is_success=True
      2. a ProvisionGuestAgent event whose message equals ``provisionMessage``
    The agent must be written as disabled exactly once when ``provisionMessage``
    is 'false', and not at all otherwise.
    Otherwise, exactly one event is expected, reporting that the OVF failed validation
    because ProvisionGuestAgent was not found.
    """
    ph = get_provision_handler(distro_name, distro_version, distro_full_name)
    ph.report_event = MagicMock()
    ph.reg_ssh_host_key = MagicMock(return_value='--thumprint--')

    mock_osutil = MagicMock()
    mock_osutil.decode_customdata = Mock(return_value="")
    ph.osutil = mock_osutil
    ph.protocol_util.osutil = mock_osutil
    ph.protocol_util.get_protocol = MagicMock()

    ovfenv_file = os.path.join(self.tmp_dir, OVF_FILE_NAME)
    ovfenv_data = load_data(ovf_file)
    fileutil.write_file(ovfenv_file, ovfenv_data)

    # Patch rather than assign, so the module-level conf function is restored
    # and the mock does not leak into other tests.
    with patch.object(conf, "get_dvd_mount_point", return_value=self.tmp_dir):
        ph.run()

    if expect_success:
        self.assertEqual(2, ph.report_event.call_count)

        positional_args, kw_args = ph.report_event.call_args_list[0]
        # [call('Provisioning succeeded (146473.68s)', duration=65, is_success=True)]
        self.assertIsNotNone(re.match(r'Provisioning succeeded \(\d+\.\d+s\)', positional_args[0]))
        self.assertIsInstance(kw_args['duration'], int)
        self.assertTrue(kw_args['is_success'])

        _, kw_args = ph.report_event.call_args_list[1]
        self.assertEqual('ProvisionGuestAgent', kw_args['operation'])
        self.assertEqual(provisionMessage, kw_args['message'])
        self.assertTrue(kw_args['is_success'])

        expected_disabled_calls = 1 if provisionMessage == 'false' else 0
        self.assertEqual(expected_disabled_calls, patch_write_agent_disabled.call_count)
    else:
        self.assertEqual(1, ph.report_event.call_count)
        positional_args, kw_args = ph.report_event.call_args_list[0]
        # [call(u'[ProtocolError] Failed to validate OVF: ProvisionGuestAgent not found')]
        self.assertIn('Failed to validate OVF: ProvisionGuestAgent not found', positional_args[0])
        self.assertFalse(kw_args['is_success'])
def test_get_protocol_metadataserver_to_wireserver_update_removes_metadataserver_artifacts(
        self, mock_wire_client, mock_enable_firewall, mock_get_lib_dir, _):
    """
    This is for testing that agent upgrade from MetadataServer to WireServer protocol
    will clean up leftover MDS Certificates and reset firewall rules. Also check that
    WireServer certificates are present, and protocol/endpoint files are written to appropriately.
    """
    # Use the per-test temp dir so fixtures and generated files do not leak
    # into the shared system temp directory.
    lib_dir = self.tmp_dir

    # Setup Protocol file with MetadataProtocol
    with open(os.path.join(lib_dir, PROTOCOL_FILE_NAME), "w") as f:
        f.write(_METADATA_PROTOCOL_NAME)

    # Setup MDS Certificates
    mds_cert_paths = [os.path.join(lib_dir, mds_cert) for mds_cert in TestProtocolUtil.MDS_CERTIFICATES]
    for mds_cert_path in mds_cert_paths:
        open(mds_cert_path, "w").close()

    # Setup mocks
    mock_get_lib_dir.return_value = lib_dir
    mock_enable_firewall.return_value = True
    protocol_util = get_protocol_util()
    protocol_util.osutil = MagicMock()
    mock_wire_client.return_value = MagicMock()
    protocol_util.dhcp_handler = MagicMock()
    protocol_util.dhcp_handler.endpoint = KNOWN_WIRESERVER_IP

    # Run
    protocol_util.get_protocol()

    # Check MDS Certs do not exist
    for mds_cert_path in mds_cert_paths:
        self.assertFalse(os.path.exists(mds_cert_path))

    # Check that WireServer Certs exist
    for ws_cert in TestProtocolUtil.WIRESERVER_CERTIFICATES:
        self.assertTrue(os.path.isfile(os.path.join(lib_dir, ws_cert)))

    # Check firewall rules was reset
    protocol_util.osutil.remove_firewall.assert_called_once()
    protocol_util.osutil.enable_firewall.assert_called_once()

    # Check Protocol File is updated to WireProtocol
    with open(os.path.join(lib_dir, PROTOCOL_FILE_NAME), "r") as f:
        self.assertEqual(f.read(), WIRE_PROTOCOL_NAME)

    # Check Endpoint file is updated to WireServer IP
    with open(os.path.join(lib_dir, ENDPOINT_FILE_NAME), 'r') as f:
        self.assertEqual(f.read(), KNOWN_WIRESERVER_IP)
def test_get_protocol(self, WireProtocol, _):
    """With "WireProtocol" saved, get_protocol returns the WireProtocol instance and
    looks up the WireServer endpoint."""
    WireProtocol.return_value = MagicMock()

    protocol_util = get_protocol_util()
    protocol_util.get_wireserver_endpoint = Mock()
    protocol_util._detect_protocol = MagicMock()
    protocol_util._save_protocol("WireProtocol")

    protocol = protocol_util.get_protocol()

    self.assertEqual(WireProtocol.return_value, protocol)
    protocol_util.get_wireserver_endpoint.assert_any_call()
def test_get_protocol(self, WireProtocol, _):  # pylint: disable=invalid-name
    """A saved "WireProtocol" choice makes get_protocol hand back the WireProtocol
    instance, and the WireServer endpoint is looked up along the way."""
    WireProtocol.return_value = MagicMock()

    util = get_protocol_util()
    util.get_wireserver_endpoint = Mock()
    util._detect_protocol = MagicMock()  # pylint: disable=protected-access
    util._save_protocol("WireProtocol")  # pylint: disable=protected-access

    self.assertEqual(WireProtocol.return_value, util.get_protocol())
    util.get_wireserver_endpoint.assert_any_call()
def test_cleanup_metadata_server_artifacts_firewall_disabled(
        self, mock_os_getuid, mock_get_lib_dir, mock_enable_firewall):
    """With the firewall disabled, cleanup deletes the legacy Metadata Server
    certificate files and removes (but does not re-enable) the firewall rule."""
    # Use the per-test temp dir so fixture files do not leak into the shared
    # system temp directory.
    lib_dir = self.tmp_dir

    # Setup Certificate Files
    legacy_files = [
        os.path.join(lib_dir, name) for name in (
            _LEGACY_METADATA_SERVER_TRANSPORT_PRV_FILE_NAME,
            _LEGACY_METADATA_SERVER_TRANSPORT_CERT_FILE_NAME,
            _LEGACY_METADATA_SERVER_P7B_FILE_NAME,
        )
    ]
    for path in legacy_files:
        open(path, 'w').close()

    # Setup Mocks
    mock_get_lib_dir.return_value = lib_dir
    mock_enable_firewall.return_value = False
    fixed_uid = 0
    mock_os_getuid.return_value = fixed_uid
    mock_osutil = MagicMock()

    # Run
    migration_util.cleanup_metadata_server_artifacts(mock_osutil)

    # Assert files deleted
    for path in legacy_files:
        self.assertFalse(os.path.exists(path))

    # Assert Firewall rule calls
    mock_osutil.remove_firewall.assert_called_once_with(dst_ip=_KNOWN_METADATASERVER_IP, uid=fixed_uid)
    mock_osutil.enable_firewall.assert_not_called()
def test_logger_should_log_micro_seconds(self, mock_dt):
    """The logged timestamp must keep its microsecond field even when it is zero.

    datetime.isoformat() omits the fractional seconds when microsecond == 0. This
    pins the mocked utcnow() to such a value, logs one INFO line, and checks that
    the text before the INFO level tag parses with Logger.LogTimeFormatInUTC and
    equals the pinned time.
    """
    log_path = os.path.join(self.tmp_dir, "test.log")
    test_logger = logger.Logger()
    test_logger.add_appender(logger.AppenderType.FILE, logger.LogLevel.INFO, path=log_path)

    whole_second = datetime.utcnow().replace(microsecond=0)
    mock_dt.utcnow = MagicMock(return_value=whole_second)

    test_logger.info("The time should contain milli-seconds")

    with open(log_path, "r") as log_file:
        contents = log_file.read()

    try:
        level_tag = logger.LogLevel.STRINGS[logger.LogLevel.INFO]
        timestamp_text = contents.split(level_tag)[0].strip()
        logged_time = datetime.strptime(timestamp_text, logger.Logger.LogTimeFormatInUTC)
    except ValueError:
        self.fail("Ensure timestamp follows ISO-8601 format and has micro seconds in it")

    self.assertEqual(whole_second, logged_time, "Timestamps dont match")
def test_provision(self, mock_util, distro_name, distro_version, distro_full_name):
    """Smoke test: provisioning with the sample ovf-env.xml and a mocked OS util
    runs without raising."""
    provision_handler = get_provision_handler(distro_name, distro_version, distro_full_name)

    mock_osutil = MagicMock()
    mock_osutil.decode_customdata = Mock(return_value="")
    provision_handler.osutil = mock_osutil
    provision_handler.protocol_util.osutil = mock_osutil
    provision_handler.protocol_util.get_protocol = MagicMock()

    ovfenv_file = os.path.join(self.tmp_dir, OVF_FILE_NAME)
    ovfenv_data = load_data("ovf-env.xml")
    fileutil.write_file(ovfenv_file, ovfenv_data)

    # Patch rather than assign, so the module-level conf function is restored
    # and the mock does not leak into other tests.
    with patch.object(conf, "get_dvd_mount_point", return_value=self.tmp_dir):
        provision_handler.run()
def test_http_request_proxy_secure(self, HTTPConnection, HTTPSConnection):
    """A secure request through a proxy opens one HTTPS connection to the proxy
    host/port and sends the absolute https URL of the real target."""
    response = Mock(read=Mock(return_value="TheResults"))
    connection = MagicMock(getresponse=Mock(return_value=response))
    HTTPSConnection.return_value = connection

    result = restutil._http_request("GET", "foo", "/bar", proxy_host="foo.bar", proxy_port=23333, secure=True)

    HTTPConnection.assert_not_called()
    HTTPSConnection.assert_has_calls([call("foo.bar", 23333, timeout=10)])

    expected_headers = {
        'User-Agent': HTTP_USER_AGENT,
        'Connection': 'close'
    }
    expected_request = call(method="GET", url="https://foo:443/bar", body=None, headers=expected_headers)
    connection.request.assert_has_calls([expected_request])

    self.assertEqual(1, connection.getresponse.call_count)
    self.assertNotEqual(None, result)
    self.assertEqual("TheResults", result.read())
def test_error_heartbeat_creates_no_signal(self, patch_report_heartbeat, patch_http_get, patch_add_event, *args):
    """A host plugin heartbeat whose HTTP GET fails sends no health report, but
    emits one HostPluginHeartbeat telemetry event carrying the failure details."""
    monitor_handler = get_monitor_handler()
    protocol = WireProtocol('endpoint')
    protocol.update_goal_state = MagicMock()
    with patch('azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol', return_value=protocol):
        # Always stop the handler, even if an assertion below fails.
        try:
            monitor_handler.init_protocols()
            # Backdate the last heartbeat by one hour; presumably this makes the next heartbeat due.
            monitor_handler.last_host_plugin_heartbeat = datetime.datetime.utcnow() - timedelta(hours=1)

            patch_http_get.side_effect = IOError('client error')
            monitor_handler.send_host_plugin_heartbeat()

            # health report should not be made
            self.assertEqual(0, patch_report_heartbeat.call_count)

            # telemetry with failure details is sent
            self.assertEqual(1, patch_add_event.call_count)
            self.assertEqual('HostPluginHeartbeat', patch_add_event.call_args[1]['op'])
            self.assertIn('client error', patch_add_event.call_args[1]['message'])
            self.assertEqual(False, patch_add_event.call_args[1]['is_success'])
        finally:
            monitor_handler.stop()
def _create_collect_logs_handler(iterations=1, systemd_present=True):
    """
    Yield a CollectLogsHandler set up for tests. While it is yielded:
        * network requests go through a mock_wire_protocol built from DATA_FILE,
        * CollectLogsHandler.stopped returns False ``iterations`` times and then True,
          so the main loop runs exactly ``iterations`` times,
        * time.sleep is a no-op,
        * systemd.is_systemd returns ``systemd_present``,
        * conf.get_collect_logs returns True.

    Two helpers are attached to the yielded handler:
        * get_mock_wire_protocol() - returns the mock protocol
        * run_and_wait() - calls run() and then join() on the handler
    """
    with mock_wire_protocol(DATA_FILE) as protocol:
        fake_protocol_util = MagicMock()
        fake_protocol_util.get_protocol = Mock(return_value=protocol)
        stop_sequence = [False] * iterations + [True]

        with patch("azurelinuxagent.ga.collect_logs.get_protocol_util", return_value=fake_protocol_util):
            with patch("azurelinuxagent.ga.collect_logs.CollectLogsHandler.stopped", side_effect=stop_sequence):
                with patch("time.sleep"):
                    with patch("azurelinuxagent.common.osutil.systemd.is_systemd", return_value=systemd_present):
                        with patch("azurelinuxagent.ga.collect_logs.conf.get_collect_logs", return_value=True):
                            handler = get_collect_logs_handler()

                            def run_and_wait():
                                handler.run()
                                handler.join()

                            handler.get_mock_wire_protocol = lambda: protocol
                            handler.run_and_wait = run_and_wait
                            yield handler
def _mock_http_get(self, *_, **kwargs): if "foo.bar" == kwargs[ 'endpoint'] and not self._mock_imds_expect_fallback: raise Exception("Unexpected endpoint called") if self._mock_imds_primary_ioerror and "169.254.169.254" == kwargs[ 'endpoint']: raise HttpError( "[HTTP Failed] GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made" .format(kwargs['endpoint'], kwargs['resource_path'])) if self._mock_imds_secondary_ioerror and "foo.bar" == kwargs[ 'endpoint']: raise HttpError( "[HTTP Failed] GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made" .format(kwargs['endpoint'], kwargs['resource_path'])) if self._mock_imds_gone_error: raise ResourceGoneError("Resource is gone") if self._mock_imds_throttled: raise HttpError( "[HTTP Retry] GET http://{0}/metadata/{1} -- Status Code 429 -- 25 attempts made" .format(kwargs['endpoint'], kwargs['resource_path'])) resp = MagicMock() resp.reason = 'reason' if self._mock_imds_bad_request: resp.status = httpclient.NOT_FOUND resp.read.return_value = 'Mock not found' else: resp.status = httpclient.OK resp.read.return_value = 'Mock success response' return resp
def setUp(self):
    """Reset shared agent state and make ProtocolUtil.get_protocol return a WireProtocol stub.

    The patcher object itself (not the mock it installs) is kept in
    ``self.get_protocol``. NOTE(review): presumably stopped in tearDown, which is
    not visible here.
    """
    AgentTestCase.setUp(self)
    event.init_event_logger(os.path.join(self.tmp_dir, EVENTS_DIRECTORY))
    CGroupsTelemetry.reset()
    clear_singleton_instances(ProtocolUtil)

    stub_protocol = WireProtocol('endpoint')
    stub_protocol.update_goal_state = MagicMock()

    patcher = patch('azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol', return_value=stub_protocol)
    self.get_protocol = patcher
    patcher.start()
def test_http_request(self, HTTPConnection, HTTPSConnection):
    """_http_request returns the connection's response for plain and secure
    requests, with and without a proxy."""
    mock_http_conn = MagicMock()
    mock_http_resp = MagicMock()
    mock_http_conn.getresponse = Mock(return_value=mock_http_resp)
    HTTPConnection.return_value = mock_http_conn
    HTTPSConnection.return_value = mock_http_conn
    mock_http_resp.read = Mock(return_value="_(:3| <)_")

    def assert_response(resp):
        self.assertIsNotNone(resp)
        self.assertEqual("_(:3| <)_", resp.read())

    # Test http get
    assert_response(restutil._http_request("GET", "foo", "bar"))

    # Test https get
    assert_response(restutil._http_request("GET", "foo", "bar", secure=True))

    # Test http get with proxy
    assert_response(restutil._http_request("GET", "foo", "bar", proxy_host="foo.bar", proxy_port=23333))

    # Test https get with proxy
    assert_response(restutil._http_request("GET", "foo", "bar", proxy_host="foo.bar", proxy_port=23333,
                                           secure=True))
def _mock_wire_protocol():
    """Yield a mock WireProtocol (from DATA_FILE) that azurelinuxagent.ga.monitor
    receives through get_protocol_util().get_protocol()."""
    # ProtocolUtil is a per-thread singleton: clear it first so a previous test
    # case's instance (and its state) is not reused.
    clear_singleton_instances(ProtocolUtil)

    with mock_wire_protocol(DATA_FILE) as wire_protocol:
        fake_util = MagicMock()
        fake_util.get_protocol = Mock(return_value=wire_protocol)
        monitor_target = "azurelinuxagent.ga.monitor.get_protocol_util"
        with patch(monitor_target, return_value=fake_util):
            yield wire_protocol
def test_get_protocol_wireserver_to_wireserver_update_removes_metadataserver_artifacts(
        self, mock_enable_firewall, mock_get_lib_dir, _):
    """
    This is for testing that agent upgrade from WireServer to WireServer protocol
    will clean up leftover MDS Certificates (from a previous Metadata Server to Wireserver
    update, intermediate updated agent does not clean up MDS certificates) and reset firewall rules.
    We don't test that WireServer certificates, protocol file, or endpoint file were created
    because we already expect them to be created since we are updating from a WireServer agent.
    """
    # Use the per-test temp dir so fixtures and generated files do not leak
    # into the shared system temp directory.
    lib_dir = self.tmp_dir

    # Setup Protocol file with WireProtocol
    with open(os.path.join(lib_dir, PROTOCOL_FILE_NAME), "w") as f:
        f.write(WIRE_PROTOCOL_NAME)

    # Setup MDS Certificates
    mds_cert_paths = [os.path.join(lib_dir, mds_cert) for mds_cert in TestProtocolUtil.MDS_CERTIFICATES]
    for mds_cert_path in mds_cert_paths:
        open(mds_cert_path, "w").close()

    # Setup mocks
    mock_get_lib_dir.return_value = lib_dir
    mock_enable_firewall.return_value = True
    protocol_util = get_protocol_util()
    protocol_util.osutil = MagicMock()
    protocol_util.dhcp_handler = MagicMock()
    protocol_util.dhcp_handler.endpoint = KNOWN_WIRESERVER_IP

    # Run
    protocol_util.get_protocol()

    # Check MDS Certs do not exist
    for mds_cert_path in mds_cert_paths:
        self.assertFalse(os.path.exists(mds_cert_path))

    # Check firewall rules was reset
    protocol_util.osutil.remove_firewall.assert_called_once()
    protocol_util.osutil.enable_firewall.assert_called_once()
def test_get_protocol(self, WireProtocol, MetadataProtocol, _):
    """get_protocol follows the saved protocol name:
      * "WireProtocol" yields the Wire protocol (and looks up the WireServer endpoint);
      * once obtained, the protocol persists until clear_protocol() is called;
      * "MetadataProtocol" yields the Metadata protocol without an endpoint lookup;
      * an unknown name falls through to _detect_protocol, whose error propagates.
    """
    WireProtocol.return_value = MagicMock()
    MetadataProtocol.return_value = MagicMock()

    protocol_util = get_protocol_util()
    protocol_util.get_wireserver_endpoint = Mock()
    protocol_util._detect_protocol = MagicMock()

    # Test for wire protocol
    protocol_util._save_protocol("WireProtocol")
    protocol = protocol_util.get_protocol()
    self.assertEqual(WireProtocol.return_value, protocol)
    protocol_util.get_wireserver_endpoint.assert_any_call()

    # Test to ensure protocol persists
    protocol_util.get_wireserver_endpoint.reset_mock()
    protocol_util._save_protocol("MetadataProtocol")
    protocol = protocol_util.get_protocol()
    self.assertEqual(WireProtocol.return_value, protocol)
    protocol_util.get_wireserver_endpoint.assert_not_called()

    # Test for metadata protocol
    protocol_util.clear_protocol()
    protocol_util._save_protocol("MetadataProtocol")
    protocol = protocol_util.get_protocol()
    self.assertEqual(MetadataProtocol.return_value, protocol)
    protocol_util.get_wireserver_endpoint.assert_not_called()

    # Test for unknown protocol
    protocol_util.clear_protocol()
    protocol_util._save_protocol("Not_a_Protocol")
    protocol_util._detect_protocol.side_effect = NotImplementedError()
    self.assertRaises(NotImplementedError, protocol_util.get_protocol)
    protocol_util.get_wireserver_endpoint.assert_not_called()
def mock_http_post(self, url, *args, **kwargs):
    """Fake HTTP POST that only accepts URLs ending in /HealthService.

    Each accepted call increments ``self.call_counts['/HealthService']`` and
    returns a MagicMock response with status OK whose read() returns an empty
    UTF-8 body. Any other URL raises Exception("Bad url ...").
    """
    if not url.endswith('/HealthService'):
        raise Exception("Bad url {0}".format(url))

    self.call_counts['/HealthService'] += 1

    response = MagicMock()
    response.status = httpclient.OK
    body = ''
    response.read = Mock(return_value=body.encode("utf-8"))
    return response
def mock_http_put(self, url, *args, **kwargs):  # pylint: disable=unused-argument
    """Fake HTTP PUT that only accepts URLs ending in /vmAgentLog.

    Each accepted call increments ``self.call_counts['/vmAgentLog']`` and returns
    a MagicMock response with status OK whose read() returns an empty UTF-8 body.
    Any other URL raises Exception("Bad url ...").
    """
    if not url.endswith('/vmAgentLog'):
        raise Exception("Bad url {0}".format(url))

    self.call_counts['/vmAgentLog'] += 1

    response = MagicMock()
    response.status = httpclient.OK
    empty_body = ''.encode("utf-8")
    response.read = Mock(return_value=empty_body)
    return response
def mock_http_put(self, url, data, **_):
    """Fake HTTP PUT accepting URLs ending in /vmAgentLog or /StatusBlob.

    The matching ``self.call_counts`` entry is incremented; for /StatusBlob the
    uploaded ``data`` is also appended to ``self.status_blobs``. Returns a
    MagicMock response with status OK and an empty UTF-8 body from read(). Any
    other URL raises Exception("Bad url ...").
    """
    accepted_suffixes = ('/vmAgentLog', '/StatusBlob')
    matched = next((suffix for suffix in accepted_suffixes if url.endswith(suffix)), None)
    if matched is None:
        raise Exception("Bad url {0}".format(url))

    self.call_counts[matched] += 1
    if matched == '/StatusBlob':
        self.status_blobs.append(data)

    response = MagicMock()
    response.status = httpclient.OK
    response.read = Mock(return_value=''.encode("utf-8"))
    return response
def test_endpoint_file_states(self, mock_get_lib_dir, mock_fileutil, _):
    """Exercise ProtocolUtil's WireServer endpoint-file handling:
      * read errors, a missing file or an empty file fall back to KNOWN_WIRESERVER_IP;
      * a write error while setting the endpoint surfaces as OSUtilError;
      * clearing removes the endpoint file, and tolerates the file being gone (ENOENT).
    """
    mock_get_lib_dir.return_value = self.tmp_dir
    # Configure the fileutil mock injected by the patch decorator directly.
    # Rebinding the parameter to a new MagicMock would leave the patched module
    # untouched, so none of the failure cases below would be exercised.
    # NOTE(review): assumes the decorator patches the fileutil used by ProtocolUtil.

    protocol_util = get_protocol_util()
    endpoint_file = protocol_util._get_wireserver_endpoint_file_path()

    # Test get endpoint for io error
    mock_fileutil.read_file.side_effect = IOError()
    self.assertEqual(KNOWN_WIRESERVER_IP, protocol_util.get_wireserver_endpoint())

    # Test get endpoint when file not found
    mock_fileutil.read_file.side_effect = IOError(ENOENT, 'File not found')
    self.assertEqual(KNOWN_WIRESERVER_IP, protocol_util.get_wireserver_endpoint())

    # Test get endpoint for empty file; side_effect takes precedence over
    # return_value, so it must be cleared first.
    mock_fileutil.read_file.side_effect = None
    mock_fileutil.read_file.return_value = ""
    self.assertEqual(KNOWN_WIRESERVER_IP, protocol_util.get_wireserver_endpoint())

    # Test set endpoint for io error; pass the callable and its argument so
    # assertRaises actually invokes it under its exception check.
    mock_fileutil.write_file.side_effect = IOError()
    self.assertRaises(OSUtilError, protocol_util._set_wireserver_endpoint, 'abc')

    # Test clear endpoint for io error
    with open(endpoint_file, "w+") as ep_fd:
        ep_fd.write("")

    with patch('os.remove') as mock_remove:
        protocol_util._clear_wireserver_endpoint()
        self.assertEqual(1, mock_remove.call_count)
        self.assertEqual(endpoint_file, mock_remove.call_args_list[0][0][0])

    # Test clear endpoint when file not found: configure the patched os.remove
    # itself to raise ENOENT; the clear call must not propagate the error.
    with patch('os.remove') as mock_remove:
        mock_remove.side_effect = IOError(ENOENT, 'File not found')
        protocol_util._clear_wireserver_endpoint()
        self.assertEqual(1, mock_remove.call_count)
def setUp(self):
    """Create the per-test directories and default values used by the
    persist-firewall-rules tests, and point conf.get_lib_dir at a temp lib dir."""
    AgentTestCase.setUp(self)

    # Popen override hook; it must return (True/False, cmd-to-execute-if-True).
    self.__replace_popen_cmd = lambda *_: (False, "")
    self.__executed_commands = []
    self.__test_dst_ip = "1.2.3.4"
    self.__test_uid = 9999
    self.__test_wait = "-w"

    self.__systemd_dir = os.path.join(self.tmp_dir, "system")
    self.__agent_bin_dir = os.path.join(self.tmp_dir, "bin")
    self.__tmp_conf_lib = os.path.join(self.tmp_dir, "waagent")
    for directory in (self.__systemd_dir, self.__agent_bin_dir, self.__tmp_conf_lib):
        fileutil.mkdir(directory)

    conf.get_lib_dir = MagicMock(return_value=self.__tmp_conf_lib)
def test_read_response_error(self):
    """
    Validate the read_response_error method handles encoding correctly: for str,
    bytes and non-ASCII str bodies, the result contains "[status: reason]" and
    the body text "message".
    """
    responses = ['message', b'message', '\x80message\x80']
    response = MagicMock()
    response.status = 'status'
    response.reason = 'reason'
    with patch.object(response, 'read') as patch_response:
        for raw_body in responses:
            patch_response.return_value = raw_body
            result = restutil.read_response_error(response)
            # Attach the input/output to failures instead of printing on every run.
            details = "response={0!r} result={1!r}".format(raw_body, result)
            self.assertIn('[status: reason]', result, details)
            self.assertIn('message', result, details)
def test_http_request_proxy_with_no_proxy_check(self, _http_request, sleep, mock_get_http_proxy):  # pylint: disable=unused-argument
    """Hosts listed in the no_proxy environment variable skip the proxy lookup;
    other hosts still call get_http_proxy."""
    fake_response = MagicMock()
    fake_response.read = Mock(return_value="hehe")
    _http_request.return_value = fake_response
    mock_get_http_proxy.return_value = "host", 1234  # Return a host/port combination

    bypassed_hosts = ["foo.com", "www.google.com", "168.63.129.16"]
    with patch.dict(os.environ, {'no_proxy': ",".join(bypassed_hosts)}):
        # foo.com is in no_proxy: the proxy settings must not be consulted
        first = restutil.http_get("http://foo.com", use_proxy=True)
        self.assertEqual("hehe", first.read())
        self.assertEqual(0, mock_get_http_proxy.call_count)

        # bar.com is not in no_proxy: the proxy settings are consulted once
        second = restutil.http_get("http://bar.com", use_proxy=True)
        self.assertEqual("hehe", second.read())
        self.assertEqual(1, mock_get_http_proxy.call_count)
def _create_collect_logs_handler(iterations=1, cgroups_enabled=True, collect_logs_conf=True):
    """
    Yield a CollectLogsHandler set up for tests. While it is yielded:
        * network requests go through a mock_wire_protocol built from DATA_FILE,
        * CollectLogsHandler.stopped returns False ``iterations`` times and then True,
          so the main loop runs exactly ``iterations`` times,
        * time.sleep is a no-op,
        * the CGroupConfigurator singleton's ``enabled`` returns ``cgroups_enabled``,
        * conf.get_collect_logs returns ``collect_logs_conf``.

    Two helpers are attached to the yielded handler:
        * get_mock_wire_protocol() - returns the mock protocol
        * run_and_wait() - calls run() and then join() on the handler
    """
    with mock_wire_protocol(DATA_FILE) as protocol:
        fake_protocol_util = MagicMock()
        fake_protocol_util.get_protocol = Mock(return_value=protocol)
        stop_sequence = [False] * iterations + [True]

        with patch("azurelinuxagent.ga.collect_logs.get_protocol_util", return_value=fake_protocol_util):
            with patch("azurelinuxagent.ga.collect_logs.CollectLogsHandler.stopped", side_effect=stop_sequence):
                with patch("time.sleep"):
                    # The singleton instance itself is patched, so fetch it first.
                    configurator = CGroupConfigurator.get_instance()
                    with patch.object(configurator, "enabled", return_value=cgroups_enabled):
                        with patch("azurelinuxagent.ga.collect_logs.conf.get_collect_logs",
                                   return_value=collect_logs_conf):
                            handler = get_collect_logs_handler()

                            def run_and_wait():
                                handler.run()
                                handler.join()

                            handler.get_mock_wire_protocol = lambda: protocol
                            handler.run_and_wait = run_and_wait
                            yield handler
def test_http_request_with_retry(self, _http_request, sleep):
    """http_get returns the underlying response for http and https URLs, and
    raises restutil.HttpError when the request keeps failing with HTTPException
    or IOError. NOTE(review): ``sleep`` is presumably patched so retries are instant."""
    mock_httpresp = MagicMock()
    mock_httpresp.read = Mock(return_value="hehe")
    _http_request.return_value = mock_httpresp

    # Test http get
    resp = restutil.http_get("http://foo.bar")
    self.assertEqual("hehe", resp.read())

    # Test https get
    resp = restutil.http_get("https://foo.bar")
    self.assertEqual("hehe", resp.read())

    # Test http failure (HTTPException)
    _http_request.side_effect = httpclient.HTTPException("Http failure")
    self.assertRaises(restutil.HttpError, restutil.http_get, "http://foo.bar")

    # Test http failure (IOError)
    _http_request.side_effect = IOError("IO failure")
    self.assertRaises(restutil.HttpError, restutil.http_get, "http://foo.bar")