Ejemplo n.º 1
0
    def _mock_http_get(self, *_, **kwargs):
        """
        Replacement for the HTTP GET call used by the IMDS tests.

        Reads kwargs['endpoint'] (and kwargs['resource_path'] when an error message
        is built) and consults the _mock_imds_* flags on the test instance to decide
        whether to raise an error or to return a mocked response.

        Returns a MagicMock with 'reason', 'status' and 'read()' configured; the
        status is NOT_FOUND when _mock_imds_bad_request is set and OK otherwise.
        """
        primary_endpoint = "169.254.169.254"
        fallback_endpoint = "foo.bar"

        def timeout_error():
            # The primary and the fallback endpoint report timeouts with the same message
            return HttpError(
                "[HTTP Failed] GET http://{0}/metadata/{1} -- IOError timed out -- 6 attempts made"
                .format(kwargs['endpoint'], kwargs['resource_path']))

        # The fallback endpoint may only be contacted when the test expects a fallback
        if kwargs['endpoint'] == fallback_endpoint and not self._mock_imds_expect_fallback:
            raise Exception("Unexpected endpoint called")

        if self._mock_imds_primary_ioerror and kwargs['endpoint'] == primary_endpoint:
            raise timeout_error()

        if self._mock_imds_secondary_ioerror and kwargs['endpoint'] == fallback_endpoint:
            raise timeout_error()

        if self._mock_imds_gone_error:
            raise ResourceGoneError("Resource is gone")

        if self._mock_imds_throttled:
            raise HttpError(
                "[HTTP Retry] GET http://{0}/metadata/{1} -- Status Code 429 -- 25 attempts made"
                .format(kwargs['endpoint'], kwargs['resource_path']))

        # No error was requested: build the response object
        response = MagicMock()
        response.reason = 'reason'
        if self._mock_imds_bad_request:
            status, body = httpclient.NOT_FOUND, 'Mock not found'
        else:
            status, body = httpclient.OK, 'Mock success response'
        response.status = status
        response.read.return_value = body
        return response
Ejemplo n.º 2
0
    def _get_persist_firewall_rules_handler(self, systemd=True):
        """
        Yields a PersistFirewallRulesHandler built on a DefaultOSUtil whose
        get_firewall_will_wait, get_agent_bin_path and
        get_systemd_unit_file_install_path are mocked.

        Also records the expected service name, the network service unit file path and
        the binary file path on the test instance. While the handler is yielded,
        fileutil.mkdir, get_osutil, is_systemd and subprocess.Popen are patched.

        :param systemd: value returned by the patched is_systemd()
        """
        mocked_osutil = DefaultOSUtil()
        mocked_osutil.get_firewall_will_wait = MagicMock(return_value=self.__test_wait)
        mocked_osutil.get_agent_bin_path = MagicMock(return_value=self.__agent_bin_dir)
        mocked_osutil.get_systemd_unit_file_install_path = MagicMock(return_value=self.__systemd_dir)

        name_format = PersistFirewallRulesHandler._AGENT_NETWORK_SETUP_NAME_FORMAT
        self._expected_service_name = name_format.format(mocked_osutil.get_service_name())

        self._network_service_unit_file = os.path.join(self.__systemd_dir, self._expected_service_name)
        self._binary_file = os.path.join(conf.get_lib_dir(), PersistFirewallRulesHandler.BINARY_FILE_NAME)

        # Only for these tests: drop the mode passed to mkdir so the tests can run without sudo
        real_mkdir = fileutil.mkdir

        def mkdir_ignoring_mode(path, **mode):  # pylint: disable=unused-argument
            return real_mkdir(path)

        with patch("azurelinuxagent.common.persist_firewall_rules.fileutil.mkdir", side_effect=mkdir_ignoring_mode):
            with patch("azurelinuxagent.common.persist_firewall_rules.get_osutil", return_value=mocked_osutil):
                with patch('azurelinuxagent.common.osutil.systemd.is_systemd', return_value=systemd):
                    with patch("azurelinuxagent.common.utils.shellutil.subprocess.Popen", side_effect=self.__mock_popen):
                        yield PersistFirewallRulesHandler(self.__test_dst_ip, self.__test_uid)
Ejemplo n.º 3
0
    def test_get_protocol_new_wireserver_agent_generates_certificates(self, mock_wire_client, mock_enable_firewall, mock_get_lib_dir, _):
        """
        Verifies that a new WireServer Linux Agent produces the expected certificates,
        protocol file, and endpoint file, without resetting the firewall rules.
        """
        # Arrange: the lib dir is the temp directory and the collaborators are mocks
        lib_dir = tempfile.gettempdir()
        mock_get_lib_dir.return_value = lib_dir
        mock_enable_firewall.return_value = True
        protocol_util = get_protocol_util()
        protocol_util.osutil = MagicMock()
        mock_wire_client.return_value = MagicMock()
        protocol_util.dhcp_handler = MagicMock()
        protocol_util.dhcp_handler.endpoint = KNOWN_WIRESERVER_IP

        # Act
        protocol_util.get_protocol()

        # Assert: every WireServer certificate is present
        for cert_name in TestProtocolUtil.WIRESERVER_CERTIFICATES:
            self.assertTrue(os.path.isfile(os.path.join(lib_dir, cert_name)))

        # Assert: the firewall rules were left alone
        protocol_util.osutil.remove_firewall.assert_not_called()
        protocol_util.osutil.enable_firewall.assert_not_called()

        # Assert: the protocol file names the WireProtocol
        protocol_file = os.path.join(lib_dir, PROTOCOL_FILE_NAME)
        with open(protocol_file, "r") as handle:
            self.assertEqual(handle.read(), WIRE_PROTOCOL_NAME)

        # Assert: the endpoint file holds the WireServer IP
        endpoint_file = os.path.join(lib_dir, ENDPOINT_FILE_NAME)
        with open(endpoint_file, 'r') as handle:
            self.assertEqual(handle.read(), KNOWN_WIRESERVER_IP)
Ejemplo n.º 4
0
    def test_detect_protocol(self, WireProtocol, MetadataProtocol, _):
        """
        Exercises get_protocol() in three situations: the wire protocol is available,
        only the metadata protocol is available, and no protocol is available.
        """
        WireProtocol.return_value = MagicMock()
        MetadataProtocol.return_value = MagicMock()

        protocol_util = get_protocol_util()

        protocol_util.dhcp_handler = MagicMock()
        protocol_util.dhcp_handler.endpoint = "foo.bar"

        # Test wire protocol is available
        protocol = protocol_util.get_protocol()
        # assertEqual replaces the deprecated assertEquals alias (removed in Python 3.12)
        self.assertEqual(WireProtocol.return_value, protocol)

        # Test wire protocol is not available
        protocol_util.clear_protocol()
        WireProtocol.return_value.detect.side_effect = ProtocolError()

        protocol = protocol_util.get_protocol()
        self.assertEqual(MetadataProtocol.return_value, protocol)

        # Test no protocol is available
        protocol_util.clear_protocol()
        WireProtocol.return_value.detect.side_effect = ProtocolError()
        MetadataProtocol.return_value.detect.side_effect = ProtocolError()

        self.assertRaises(ProtocolError, protocol_util.get_protocol)
Ejemplo n.º 5
0
    def test_detect_protocol_no_dhcp(self, WireProtocol, mock_get_lib_dir, _):
        """
        Checks _detect_protocol() when DHCP is reported as unavailable (the known
        WireServer IP is expected as endpoint) and when the DHCP handler fails
        (a ProtocolError is expected).
        """
        WireProtocol.return_value.detect = Mock()
        mock_get_lib_dir.return_value = self.tmp_dir

        protocol_util = get_protocol_util()

        # OS utilities report that DHCP is not available
        mocked_osutil = MagicMock()
        mocked_osutil.is_dhcp_available.return_value = False
        protocol_util.osutil = mocked_osutil

        # The DHCP handler has no endpoint
        mocked_dhcp = MagicMock()
        mocked_dhcp.endpoint = None
        mocked_dhcp.run = Mock()
        protocol_util.dhcp_handler = mocked_dhcp

        # The returned path is not used by the rest of this test
        endpoint_file = protocol_util._get_wireserver_endpoint_file_path()

        # Wire protocol when no endpoint file has been written
        protocol_util._detect_protocol()
        self.assertEqual(KNOWN_WIRESERVER_IP, protocol_util.get_wireserver_endpoint())

        # Wire protocol when DHCP is available but the handler raises
        mocked_osutil.is_dhcp_available.return_value = True
        mocked_dhcp.run.side_effect = DhcpError()

        self.assertRaises(ProtocolError, protocol_util._detect_protocol)
Ejemplo n.º 6
0
def _create_collect_logs_handler(iterations=1, systemd_present=True):
    """
    Yields a CollectLogsHandler prepared for tests. While the handler is yielded:
        * get_protocol_util() in the collect_logs module returns a mock whose get_protocol() returns a mock_wire_protocol,
        * CollectLogsHandler.stopped returns False 'iterations' times and then True,
        * time.sleep is patched,
        * is_systemd() returns 'systemd_present', and
        * conf.get_collect_logs() returns True

    Two helpers are attached to the yielded handler:
        * get_mock_wire_protocol() - returns the mock protocol
        * run_and_wait() - invokes run() and then join() on the handler

    """
    with mock_wire_protocol(DATA_FILE) as protocol:
        mock_protocol_util = MagicMock()
        mock_protocol_util.get_protocol = Mock(return_value=protocol)

        # one False per requested iteration, then True
        stopped_results = [False] * iterations + [True]

        with patch("azurelinuxagent.ga.collect_logs.get_protocol_util", return_value=mock_protocol_util):
            with patch("azurelinuxagent.ga.collect_logs.CollectLogsHandler.stopped", side_effect=stopped_results):
                with patch("time.sleep"):
                    with patch("azurelinuxagent.common.osutil.systemd.is_systemd", return_value=systemd_present):
                        with patch("azurelinuxagent.ga.collect_logs.conf.get_collect_logs", return_value=True):
                            handler = get_collect_logs_handler()

                            def get_mock_wire_protocol():
                                return protocol

                            def run_and_wait():
                                handler.run()
                                handler.join()

                            handler.get_mock_wire_protocol = get_mock_wire_protocol
                            handler.run_and_wait = run_and_wait
                            yield handler
Ejemplo n.º 7
0
    def test_get_protocol_metadataserver_to_wireserver_update_removes_metadataserver_artifacts(
            self, mock_wire_client, mock_enable_firewall, mock_get_lib_dir, _):
        """
        This is for testing that agent upgrade from MetadataServer to WireServer protocol
        will clean up leftover MDS Certificates and reset firewall rules. Also check that
        WireServer certificates are present, and protocol/endpoint files are written to appropriately.
        """
        # Setup Protocol file with MetadataProtocol
        lib_dir = tempfile.gettempdir()
        protocol_filename = os.path.join(lib_dir, PROTOCOL_FILE_NAME)
        with open(protocol_filename, "w") as f:
            f.write(_METADATA_PROTOCOL_NAME)

        # Setup MDS Certificates (empty placeholder files)
        mds_cert_paths = [
            os.path.join(lib_dir, mds_cert)
            for mds_cert in TestProtocolUtil.MDS_CERTIFICATES
        ]
        for mds_cert_path in mds_cert_paths:
            open(mds_cert_path, "w").close()

        # Setup mocks
        mock_get_lib_dir.return_value = lib_dir
        mock_enable_firewall.return_value = True
        protocol_util = get_protocol_util()
        protocol_util.osutil = MagicMock()
        mock_wire_client.return_value = MagicMock()
        protocol_util.dhcp_handler = MagicMock()
        protocol_util.dhcp_handler.endpoint = KNOWN_WIRESERVER_IP

        # Run
        protocol_util.get_protocol()

        # Check MDS Certs do not exist
        for mds_cert_path in mds_cert_paths:
            self.assertFalse(os.path.exists(mds_cert_path))

        # Check that WireServer Certs exist
        ws_cert_paths = [
            os.path.join(lib_dir, ws_cert)
            for ws_cert in TestProtocolUtil.WIRESERVER_CERTIFICATES
        ]
        for ws_cert_path in ws_cert_paths:
            self.assertTrue(os.path.isfile(ws_cert_path))

        # Check firewall rules were reset
        protocol_util.osutil.remove_firewall.assert_called_once()
        protocol_util.osutil.enable_firewall.assert_called_once()

        # Check Protocol File is updated to WireProtocol
        # (assertEqual replaces the deprecated assertEquals alias, removed in Python 3.12)
        with open(os.path.join(lib_dir, PROTOCOL_FILE_NAME), "r") as f:
            self.assertEqual(f.read(), WIRE_PROTOCOL_NAME)

        # Check Endpoint file is updated to WireServer IP
        with open(os.path.join(lib_dir, ENDPOINT_FILE_NAME), 'r') as f:
            self.assertEqual(f.read(), KNOWN_WIRESERVER_IP)
Ejemplo n.º 8
0
def _mock_wire_protocol():
    """
    Yields a mock wire protocol created from DATA_FILE. While it is yielded,
    get_protocol_util() in the monitor module returns a mock whose get_protocol()
    returns that protocol.
    """
    # ProtocolUtil is a singleton per thread; clear it so that a test case does not pick up
    # state left behind by a previous one
    clear_singleton_instances(ProtocolUtil)

    with mock_wire_protocol(DATA_FILE) as wire_protocol:
        mock_protocol_util = MagicMock()
        mock_protocol_util.get_protocol = Mock(return_value=wire_protocol)
        with patch("azurelinuxagent.ga.monitor.get_protocol_util", return_value=mock_protocol_util):
            yield wire_protocol
Ejemplo n.º 9
0
    def test_get_protocol(self, WireProtocol, _):
        """
        With "WireProtocol" saved as the protocol, get_protocol() is expected to return
        the (mocked) WireProtocol instance and to query the wireserver endpoint.
        """
        WireProtocol.return_value = MagicMock()

        protocol_util = get_protocol_util()
        protocol_util.get_wireserver_endpoint = Mock()
        protocol_util._detect_protocol = MagicMock()
        protocol_util._save_protocol("WireProtocol")

        protocol = protocol_util.get_protocol()

        # assertEqual replaces the deprecated assertEquals alias (removed in Python 3.12)
        self.assertEqual(WireProtocol.return_value, protocol)
        protocol_util.get_wireserver_endpoint.assert_any_call()
Ejemplo n.º 10
0
    def test_get_protocol(self, WireProtocol, _):  # pylint: disable=invalid-name
        """
        With "WireProtocol" saved as the protocol, get_protocol() is expected to return
        the (mocked) WireProtocol instance and to query the wireserver endpoint.
        """
        expected_protocol = MagicMock()
        WireProtocol.return_value = expected_protocol

        util = get_protocol_util()
        util.get_wireserver_endpoint = Mock()
        util._detect_protocol = MagicMock()  # pylint: disable=protected-access
        util._save_protocol("WireProtocol")  # pylint: disable=protected-access

        actual_protocol = util.get_protocol()

        self.assertEqual(WireProtocol.return_value, actual_protocol)
        util.get_wireserver_endpoint.assert_any_call()
Ejemplo n.º 11
0
    def mock_http_post(self, url, *args, **kwargs):
        """
        Mocked HTTP POST: accepts only URLs ending in '/HealthService', counts the
        call in self.call_counts and returns an OK response whose read() yields
        empty bytes. Any other URL raises an Exception.
        """
        health_service = '/HealthService'

        response = MagicMock()
        response.status = httpclient.OK

        # Guard clause: any other URL is a test error
        if not url.endswith(health_service):
            raise Exception("Bad url {0}".format(url))

        self.call_counts[health_service] += 1
        body = ''

        response.read = Mock(return_value=body.encode("utf-8"))
        return response
Ejemplo n.º 12
0
    def mock_http_put(self, url, *args, **kwargs):  # pylint: disable=unused-argument
        """
        Mocked HTTP PUT: accepts only URLs ending in '/vmAgentLog', counts the call
        in self.call_counts and returns an OK response whose read() yields empty
        bytes. Any other URL raises an Exception.
        """
        agent_log = '/vmAgentLog'

        response = MagicMock()
        response.status = httpclient.OK

        # Guard clause: any other URL is a test error
        if not url.endswith(agent_log):
            raise Exception("Bad url {0}".format(url))

        self.call_counts[agent_log] += 1
        body = ''

        response.read = Mock(return_value=body.encode("utf-8"))
        return response
Ejemplo n.º 13
0
    def test_cleanup_metadata_server_artifacts_firewall_disabled(
            self, mock_os_getuid, mock_get_lib_dir, mock_enable_firewall):
        """
        With the firewall setting mocked to False, cleanup_metadata_server_artifacts()
        is expected to delete the legacy metadata server files, remove the firewall
        rule for the metadata server IP and not enable the firewall.
        """
        # Create empty legacy certificate files in the temp directory
        lib_dir = tempfile.gettempdir()
        legacy_file_names = (
            _LEGACY_METADATA_SERVER_TRANSPORT_PRV_FILE_NAME,
            _LEGACY_METADATA_SERVER_TRANSPORT_CERT_FILE_NAME,
            _LEGACY_METADATA_SERVER_P7B_FILE_NAME,
        )
        legacy_files = [os.path.join(lib_dir, name) for name in legacy_file_names]
        for legacy_file in legacy_files:
            open(legacy_file, 'w').close()

        # Configure the mocks
        mock_get_lib_dir.return_value = lib_dir
        mock_enable_firewall.return_value = False
        expected_uid = 0
        mock_os_getuid.return_value = expected_uid
        mocked_osutil = MagicMock()

        # Run
        migration_util.cleanup_metadata_server_artifacts(mocked_osutil)

        # The legacy files must be gone
        for legacy_file in legacy_files:
            self.assertFalse(os.path.exists(legacy_file))

        # The firewall rule is removed exactly once and the firewall is not re-enabled
        mocked_osutil.remove_firewall.assert_called_once_with(
            dst_ip=_KNOWN_METADATASERVER_IP, uid=expected_uid)
        mocked_osutil.enable_firewall.assert_not_called()
Ejemplo n.º 14
0
    def test_logger_should_log_micro_seconds(self, mock_dt):
        """
        datetime.isoformat() omits the fractional part when microsecond == 0; this test
        checks that the logged timestamp can still be parsed with the logger's UTC
        format and matches the mocked time.
        """
        log_path = os.path.join(self.tmp_dir, "test.log")
        file_logger = logger.Logger()
        file_logger.add_appender(logger.AppenderType.FILE, logger.LogLevel.INFO, path=log_path)

        # Force a "now" whose microsecond component is zero
        frozen_now = datetime.utcnow().replace(microsecond=0)
        mock_dt.utcnow = MagicMock(return_value=frozen_now)

        file_logger.info("The time should contain milli-seconds")

        with open(log_path, "r") as log_file:
            contents = log_file.read()
            try:
                # The timestamp is the text that precedes the log level string
                level_text = logger.LogLevel.STRINGS[logger.LogLevel.INFO]
                timestamp_text = contents.split(level_text)[0].strip()
                logged_time = datetime.strptime(timestamp_text, logger.Logger.LogTimeFormatInUTC)
            except ValueError:
                self.fail("Ensure timestamp follows ISO-8601 format and has micro seconds in it")

            self.assertEqual(frozen_now, logged_time, "Timestamps dont match")
Ejemplo n.º 15
0
    def test_error_heartbeat_creates_no_signal(self, patch_report_heartbeat,
                                               patch_http_get, patch_add_event,
                                               *args):
        """
        When the HTTP GET made by the host plugin heartbeat raises an IOError, no health
        report is made and a single failed 'HostPluginHeartbeat' telemetry event that
        carries the error text is added.
        """
        handler = get_monitor_handler()
        wire_protocol = WireProtocol('endpoint')
        wire_protocol.update_goal_state = MagicMock()
        with patch('azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol', return_value=wire_protocol):
            handler.init_protocols()
            # Pretend that the last heartbeat was sent an hour ago
            one_hour_ago = datetime.datetime.utcnow() - timedelta(hours=1)
            handler.last_host_plugin_heartbeat = one_hour_ago

            patch_http_get.side_effect = IOError('client error')
            handler.send_host_plugin_heartbeat()

            # no health report is made
            self.assertEqual(0, patch_report_heartbeat.call_count)

            # exactly one telemetry event, carrying the failure details
            self.assertEqual(1, patch_add_event.call_count)
            event_kwargs = patch_add_event.call_args[1]
            self.assertEqual('HostPluginHeartbeat', event_kwargs['op'])
            self.assertTrue('client error' in event_kwargs['message'])

            self.assertEqual(False, event_kwargs['is_success'])
            handler.stop()
Ejemplo n.º 16
0
    def mock_http_put(self, url, data, **_):
        """
        Mocked HTTP PUT for '/vmAgentLog' and '/StatusBlob' URLs.

        Each accepted call is counted in self.call_counts; for '/StatusBlob' the
        payload is also appended to self.status_blobs. Any other URL raises an
        Exception. Returns an OK response whose read() yields empty bytes.
        """
        response = MagicMock()
        response.status = httpclient.OK

        if url.endswith('/vmAgentLog'):
            counted_suffix = '/vmAgentLog'
        elif url.endswith('/StatusBlob'):
            counted_suffix = '/StatusBlob'
        else:
            raise Exception("Bad url {0}".format(url))

        self.call_counts[counted_suffix] += 1
        if counted_suffix == '/StatusBlob':
            # keep the uploaded blob so that tests can inspect it
            self.status_blobs.append(data)

        response.read = Mock(return_value=''.encode("utf-8"))
        return response
Ejemplo n.º 17
0
 def test_provision(self, mock_util, distro_name, distro_version, distro_full_name):
     """
     Runs the provision handler for the given distro with mocked OS utilities and
     an ovf-env.xml placed in the (mocked) DVD mount point.
     """
     provision_handler = get_provision_handler(distro_name, distro_version,
                                               distro_full_name)
     mock_osutil = MagicMock()
     mock_osutil.decode_customdata = Mock(return_value="")

     provision_handler.osutil = mock_osutil
     provision_handler.protocol_util.osutil = mock_osutil
     provision_handler.protocol_util.get_protocol = MagicMock()

     # conf.get_dvd_mount_point is replaced on the module itself; remember the original
     # and restore it afterwards so the mock does not leak into other tests
     original_get_dvd_mount_point = conf.get_dvd_mount_point
     conf.get_dvd_mount_point = Mock(return_value=self.tmp_dir)
     try:
         ovfenv_file = os.path.join(self.tmp_dir, OVF_FILE_NAME)
         ovfenv_data = load_data("ovf-env.xml")
         fileutil.write_file(ovfenv_file, ovfenv_data)

         provision_handler.run()
     finally:
         conf.get_dvd_mount_point = original_get_dvd_mount_point
Ejemplo n.º 18
0
    def test_http_request_proxy_secure(self, HTTPConnection, HTTPSConnection):
        """
        A secure request through a proxy is expected to open only an HTTPS connection
        to the proxy and to request the full https URL of the target.
        """
        mock_response = Mock(read=Mock(return_value="TheResults"))
        mock_conn = MagicMock(getresponse=Mock(return_value=mock_response))

        HTTPSConnection.return_value = mock_conn

        resp = restutil._http_request(
            "GET", "foo", "/bar",
            proxy_host="foo.bar", proxy_port=23333, secure=True)

        # Only the HTTPS connection class is used, and it targets the proxy
        HTTPConnection.assert_not_called()
        HTTPSConnection.assert_has_calls([call("foo.bar", 23333, timeout=10)])

        # The request itself carries the full URL of the target
        expected_headers = {
            'User-Agent': HTTP_USER_AGENT,
            'Connection': 'close'
        }
        expected_request = call(method="GET",
                                url="https://foo:443/bar",
                                body=None,
                                headers=expected_headers)
        mock_conn.request.assert_has_calls([expected_request])

        self.assertEqual(1, mock_conn.getresponse.call_count)
        self.assertNotEqual(None, resp)
        self.assertEqual("TheResults", resp.read())
Ejemplo n.º 19
0
 def test_read_response_error(self):
     """
     Checks that read_response_error copes with str, bytes and non-ASCII response
     bodies: the result must contain the status/reason header and the message.
     """
     mock_response = MagicMock()
     mock_response.status = 'status'
     mock_response.reason = 'reason'
     bodies = ['message', b'message', '\x80message\x80']
     with patch.object(mock_response, 'read') as mock_read:
         for body in bodies:
             mock_read.return_value = body
             result = restutil.read_response_error(mock_response)
             header_present = '[status: reason]' in result
             print("RESPONSE: {0}".format(body))
             print("RESULT: {0}".format(result))
             print("PRESENT: {0}".format(header_present))
             self.assertTrue(header_present)
             self.assertTrue('message' in result)
Ejemplo n.º 20
0
 def setUp(self):
     """
     Initializes the event logger under the test's temp directory, resets the cgroups
     telemetry and the ProtocolUtil singletons, and starts a patch that makes
     ProtocolUtil.get_protocol return a WireProtocol with a mocked update_goal_state.
     """
     AgentTestCase.setUp(self)

     events_dir = os.path.join(self.tmp_dir, EVENTS_DIRECTORY)
     event.init_event_logger(events_dir)

     CGroupsTelemetry.reset()
     clear_singleton_instances(ProtocolUtil)

     wire_protocol = WireProtocol('endpoint')
     wire_protocol.update_goal_state = MagicMock()

     # The patcher is kept on the instance; it is only started here
     self.get_protocol = patch(
         'azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol',
         return_value=wire_protocol)
     self.get_protocol.start()
Ejemplo n.º 21
0
    def test_http_request_proxy_with_no_proxy_check(self, _http_request, sleep,
                                                    mock_get_http_proxy):  # pylint: disable=unused-argument
        """
        With 'no_proxy' set in the environment, a GET to a listed host must not look up
        the proxy, while a GET to an unlisted host must look it up once.
        """
        expected_body = "hehe"
        mock_response = MagicMock()
        mock_response.read = Mock(return_value=expected_body)
        _http_request.return_value = mock_response
        mock_get_http_proxy.return_value = ("host", 1234)  # a host/port combination

        bypassed_hosts = ["foo.com", "www.google.com", "168.63.129.16"]
        # (url, expected cumulative number of proxy lookups after the request)
        scenarios = (
            ("http://foo.com", 0),  # listed in no_proxy
            ("http://bar.com", 1),  # not listed in no_proxy
        )
        with patch.dict(os.environ, {'no_proxy': ",".join(bypassed_hosts)}):
            for url, expected_lookups in scenarios:
                resp = restutil.http_get(url, use_proxy=True)
                self.assertEqual(expected_body, resp.read())
                self.assertEqual(expected_lookups, mock_get_http_proxy.call_count)
Ejemplo n.º 22
0
def _create_collect_logs_handler(iterations=1,
                                 cgroups_enabled=True,
                                 collect_logs_conf=True):
    """
    Yields a CollectLogsHandler prepared for tests. While the handler is yielded:
        * get_protocol_util() in the collect_logs module returns a mock whose get_protocol() returns a mock_wire_protocol,
        * CollectLogsHandler.stopped returns False 'iterations' times and then True,
        * time.sleep is patched,
        * 'enabled' on the CGroupConfigurator singleton returns 'cgroups_enabled', and
        * conf.get_collect_logs() returns 'collect_logs_conf'

    Two helpers are attached to the yielded handler:
        * get_mock_wire_protocol() - returns the mock protocol
        * run_and_wait() - invokes run() and then join() on the handler

    """
    with mock_wire_protocol(DATA_FILE) as protocol:
        mock_protocol_util = MagicMock()
        mock_protocol_util.get_protocol = Mock(return_value=protocol)

        # one False per requested iteration, then True
        stopped_results = [False] * iterations + [True]

        with patch("azurelinuxagent.ga.collect_logs.get_protocol_util", return_value=mock_protocol_util):
            with patch("azurelinuxagent.ga.collect_logs.CollectLogsHandler.stopped", side_effect=stopped_results):
                with patch("time.sleep"):
                    # The configurator is a singleton; fetch the instance so it can be patched
                    configurator = CGroupConfigurator.get_instance()
                    with patch.object(configurator, "enabled", return_value=cgroups_enabled):
                        with patch("azurelinuxagent.ga.collect_logs.conf.get_collect_logs", return_value=collect_logs_conf):
                            handler = get_collect_logs_handler()

                            def get_mock_wire_protocol():
                                return protocol

                            def run_and_wait():
                                handler.run()
                                handler.join()

                            handler.get_mock_wire_protocol = get_mock_wire_protocol
                            handler.run_and_wait = run_and_wait
                            yield handler
Ejemplo n.º 23
0
    def test_http_request_with_retry(self, _http_request, sleep):
        """
        http_get() returns the response of the underlying request for http and https
        URLs, and raises restutil.HttpError when the underlying request raises an
        HTTPException or an IOError.
        """
        mock_httpresp = MagicMock()
        mock_httpresp.read = Mock(return_value="hehe")
        _http_request.return_value = mock_httpresp

        # Test http get
        # (assertEqual replaces the deprecated assertEquals alias, removed in Python 3.12)
        resp = restutil.http_get("http://foo.bar")
        self.assertEqual("hehe", resp.read())

        # Test https get
        resp = restutil.http_get("https://foo.bar")
        self.assertEqual("hehe", resp.read())

        # Test http failure
        _http_request.side_effect = httpclient.HTTPException("Http failure")
        self.assertRaises(restutil.HttpError, restutil.http_get, "http://foo.bar")

        # Test IO failure
        _http_request.side_effect = IOError("IO failure")
        self.assertRaises(restutil.HttpError, restutil.http_get, "http://foo.bar")
Ejemplo n.º 24
0
    def test_http_request_with_retry(self, _http_request, sleep):
        """
        http_get() returns the response of the underlying request for http and https
        URLs, and raises restutil.HttpError when the underlying request raises an
        HTTPException or an IOError.
        """
        mock_httpresp = MagicMock()
        mock_httpresp.read = Mock(return_value="hehe")
        _http_request.return_value = mock_httpresp

        # Test http get
        # (assertEqual replaces the deprecated assertEquals alias, removed in Python 3.12)
        resp = restutil.http_get("http://foo.bar")
        self.assertEqual("hehe", resp.read())

        # Test https get
        resp = restutil.http_get("https://foo.bar")
        self.assertEqual("hehe", resp.read())

        # Test http failure
        _http_request.side_effect = httpclient.HTTPException("Http failure")
        self.assertRaises(restutil.HttpError, restutil.http_get, "http://foo.bar")

        # Test IO failure
        _http_request.side_effect = IOError("IO failure")
        self.assertRaises(restutil.HttpError, restutil.http_get, "http://foo.bar")
Ejemplo n.º 25
0
    def test_get_protocol_wireserver_to_wireserver_update_removes_metadataserver_artifacts(
            self, mock_enable_firewall, mock_get_lib_dir, _):
        """
        Covers an agent update from WireServer to WireServer protocol: leftover MDS
        certificates (an intermediate agent updated from Metadata Server to WireServer does
        not clean them up) must be removed and the firewall rules must be reset.
        The WireServer certificates, protocol file and endpoint file are not checked here,
        since an agent that is already on WireServer is expected to have created them.
        """
        # Arrange: the protocol file already names the WireProtocol
        lib_dir = tempfile.gettempdir()
        protocol_file = os.path.join(lib_dir, PROTOCOL_FILE_NAME)
        with open(protocol_file, "w") as handle:
            handle.write(WIRE_PROTOCOL_NAME)

        # Arrange: empty placeholder files for the leftover MDS certificates
        leftover_certs = []
        for cert_name in TestProtocolUtil.MDS_CERTIFICATES:
            leftover_certs.append(os.path.join(lib_dir, cert_name))
        for cert_path in leftover_certs:
            open(cert_path, "w").close()

        # Arrange: mocks
        mock_get_lib_dir.return_value = lib_dir
        mock_enable_firewall.return_value = True
        protocol_util = get_protocol_util()
        protocol_util.osutil = MagicMock()
        protocol_util.dhcp_handler = MagicMock()
        protocol_util.dhcp_handler.endpoint = KNOWN_WIRESERVER_IP

        # Act
        protocol_util.get_protocol()

        # Assert: the MDS certificates were deleted
        for cert_path in leftover_certs:
            self.assertFalse(os.path.exists(cert_path))

        # Assert: the firewall rules were reset
        protocol_util.osutil.remove_firewall.assert_called_once()
        protocol_util.osutil.enable_firewall.assert_called_once()
Ejemplo n.º 26
0
 def mock_http_get(self, url, *args, **kwargs):
     """
     Mocked HTTP GET that serves canned content chosen by a marker found in the URL.

     Markers and the attribute of the test data that is served:
         "identity?"          -> self.identity
         "certificates_data"  -> self.certificates_data
         "certificates"       -> self.certificates
         "extensionHandlers"  -> self.ext_handlers
         "versionUri"         -> self.ext_handler_pkgs

     Returns an OK response whose read() yields the content encoded as UTF-8, or
     None when the selected content is None. Raises an Exception for any other URL.
     """
     content = None
     if u"identity?" in url:
         content = self.identity
     # "certificates_data" must be tested before "certificates": the latter is a
     # substring of the former and would otherwise shadow it
     elif u"certificates_data" in url:
         content = self.certificates_data
     elif u"certificates" in url:
         content = self.certificates
     elif u"extensionHandlers" in url:
         content = self.ext_handlers
     elif u"versionUri" in url:
         content = self.ext_handler_pkgs
     else:
         raise Exception("Bad url {0}".format(url))
     resp = MagicMock()
     resp.status = httpclient.OK
     if content is None:
         resp.read = Mock(return_value=None)
     else:
         resp.read = Mock(return_value=content.encode("utf-8"))
     return resp
Ejemplo n.º 27
0
    def test_http_request(self, HTTPConnection, HTTPSConnection):
        """
        _http_request() returns the response of the (mocked) connection for plain and
        secure requests, with and without a proxy.
        """
        mock_http_conn = MagicMock()
        mock_http_resp = MagicMock()
        mock_http_conn.getresponse = Mock(return_value=mock_http_resp)
        HTTPConnection.return_value = mock_http_conn
        HTTPSConnection.return_value = mock_http_conn

        mock_http_resp.read = Mock(return_value="_(:3| <)_")

        # assertEqual/assertNotEqual replace the deprecated assertEquals/assertNotEquals
        # aliases (removed in Python 3.12)

        # Test http get
        resp = restutil._http_request("GET", "foo", "bar")
        self.assertNotEqual(None, resp)
        self.assertEqual("_(:3| <)_", resp.read())

        # Test https get
        resp = restutil._http_request("GET", "foo", "bar", secure=True)
        self.assertNotEqual(None, resp)
        self.assertEqual("_(:3| <)_", resp.read())

        # Test http get with proxy
        mock_http_resp.read = Mock(return_value="_(:3| <)_")
        resp = restutil._http_request("GET", "foo", "bar", proxy_host="foo.bar", proxy_port=23333)
        self.assertNotEqual(None, resp)
        self.assertEqual("_(:3| <)_", resp.read())

        # Test https get again, after the proxied request
        resp = restutil._http_request("GET", "foo", "bar", secure=True)
        self.assertNotEqual(None, resp)
        self.assertEqual("_(:3| <)_", resp.read())

        # Test https get with proxy
        mock_http_resp.read = Mock(return_value="_(:3| <)_")
        resp = restutil._http_request("GET", "foo", "bar", proxy_host="foo.bar", proxy_port=23333, secure=True)
        self.assertNotEqual(None, resp)
        self.assertEqual("_(:3| <)_", resp.read())
Ejemplo n.º 28
0
def _create_monitor_handler(enabled_operations=None, iterations=1):
    """
    Creates an instance of MonitorHandler that
        * Uses a mock_wire_protocol for network requests,
        * Executes only the operations given in the 'enabled_operations' parameter
          (all operations when the parameter is omitted, None, or empty),
        * Runs its main loop only the number of times given in the 'iterations' parameter, and
        * Does not sleep at the end of each iteration

    The MonitorHandler is yielded (not returned) while all the patches are
    active, so this generator is presumably meant to be driven as a context
    manager -- NOTE(review): the wrapping decorator is not visible here.

    The yielded MonitorHandler is augmented with 2 methods:
        * get_mock_wire_protocol() - returns the mock protocol
        * run_and_wait() - invokes run() and join() on the MonitorHandler

    """
    # A None sentinel replaces the former mutable default ([]), which would be
    # one list object shared by every call that omits the argument.
    if enabled_operations is None:
        enabled_operations = []

    def run(self):
        # An empty filter means "run everything"; otherwise run only the
        # periodic operations whose name was explicitly enabled.
        if not enabled_operations or self._name in enabled_operations:
            run.original_definition(self)

    # Capture the real implementation before PeriodicOperation.run is patched below.
    run.original_definition = PeriodicOperation.run

    with mock_wire_protocol(DATA_FILE) as protocol:
        protocol_util = MagicMock()
        protocol_util.get_protocol = Mock(return_value=protocol)
        with patch("azurelinuxagent.ga.monitor.get_protocol_util",
                   return_value=protocol_util):
            with patch.object(PeriodicOperation,
                              "run",
                              side_effect=run,
                              autospec=True):
                # stopped() answers False 'iterations' times and then True,
                # which bounds the number of passes through the main loop.
                with patch("azurelinuxagent.ga.monitor.MonitorHandler.stopped",
                           side_effect=[False] * iterations + [True]):
                    with patch("time.sleep"):

                        def run_and_wait():
                            monitor_handler.run()
                            monitor_handler.join()

                        monitor_handler = get_monitor_handler()
                        monitor_handler.get_mock_wire_protocol = lambda: protocol
                        monitor_handler.run_and_wait = run_and_wait
                        yield monitor_handler
Ejemplo n.º 29
0
    def test_provision_telemetry_fail(self,
                                      mock_util,
                                      distro_name,
                                      distro_version,
                                      distro_full_name, _):
        """
        Run the provision handler with SSH host key registration forced to
        raise ProvisionError, and check that the first reported event carries
        the "Provisioning failed" message, including the elapsed time.
        """
        handler = get_provision_handler(distro_name, distro_version,
                                        distro_full_name)
        handler.report_event = MagicMock()

        # Make the provisioning step blow up with a recognizable message.
        forced_failure = ProvisionError("--unit-test--")
        handler.reg_ssh_host_key = MagicMock(side_effect=forced_failure)

        # The same OS util stub backs both the handler and its protocol util.
        osutil_stub = MagicMock()
        osutil_stub.decode_customdata = Mock(return_value="")
        handler.osutil = osutil_stub
        handler.protocol_util.osutil = osutil_stub
        handler.protocol_util.get_protocol = MagicMock()

        # Point the DVD mount at the temp dir and drop the OVF environment there.
        conf.get_dvd_mount_point = Mock(return_value=self.tmp_dir)
        fileutil.write_file(os.path.join(self.tmp_dir, OVF_FILE_NAME),
                            load_data("ovf-env.xml"))

        handler.run()

        first_call_args, _first_call_kwargs = handler.report_event.call_args_list[0]
        failure_pattern = r'Provisioning failed: \[ProvisionError\] --unit-test-- \(\d+\.\d+s\)'
        matched = re.match(failure_pattern, first_call_args[0])
        self.assertTrue(matched is not None)
Ejemplo n.º 30
0
    def test_read_response_bytes(self):
        """
        Feed restutil.read_response_error a response body built from raw byte
        values and check the formatted error text.

        The body contains non-ASCII bytes (c3 af c2 bb c2 bf) in front of the
        embedded XML; they do not appear in the expected message. The result
        must also survive being wrapped in an HttpError and converted with ustr.
        """
        hex_dump = ('7b:0a:20:20:20:20:22:65:72:72:6f:72:43:6f:64:65:22:'
                    '3a:20:22:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:'
                    '69:73:20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:'
                    '69:73:20:6f:70:65:72:61:74:69:6f:6e:2e:22:2c:0a:20:'
                    '20:20:20:22:6d:65:73:73:61:67:65:22:3a:20:22:c3:af:'
                    'c2:bb:c2:bf:3c:3f:78:6d:6c:20:76:65:72:73:69:6f:6e:'
                    '3d:22:31:2e:30:22:20:65:6e:63:6f:64:69:6e:67:3d:22:'
                    '75:74:66:2d:38:22:3f:3e:3c:45:72:72:6f:72:3e:3c:43:'
                    '6f:64:65:3e:49:6e:76:61:6c:69:64:42:6c:6f:62:54:79:'
                    '70:65:3c:2f:43:6f:64:65:3e:3c:4d:65:73:73:61:67:65:'
                    '3e:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:69:73:'
                    '20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:69:73:'
                    '20:6f:70:65:72:61:74:69:6f:6e:2e:0a:52:65:71:75:65:'
                    '73:74:49:64:3a:63:37:34:32:39:30:63:62:2d:30:30:30:'
                    '31:2d:30:30:62:35:2d:30:36:64:61:2d:64:64:36:36:36:'
                    '61:30:30:30:22:2c:0a:20:20:20:20:22:64:65:74:61:69:'
                    '6c:73:22:3a:20:22:22:0a:7d')
        octets = hex_dump.split(':')

        expected_message = ('[HTTP Failed] [status: reason] {\n    "errorCode": "The blob '
                            'type is invalid for this operation.",\n    '
                            '"message": "<?xml version="1.0" '
                            'encoding="utf-8"?>'
                            '<Error><Code>InvalidBlobType</Code><Message>The '
                            'blob type is invalid for this operation.\n'
                            'RequestId:c74290cb-0001-00b5-06da-dd666a000",'
                            '\n    "details": ""\n}')

        # One character per byte value, exactly as listed in the dump.
        body = ''.join([chr(int(octet, 16)) for octet in octets])

        response = MagicMock()
        response.status = 'status'
        response.reason = 'reason'

        with patch.object(response, 'read', return_value=body):
            actual = restutil.read_response_error(response)
            self.assertEqual(actual, expected_message)

            # The message must remain intact when carried by an HttpError.
            try:
                raise HttpError("{0}".format(actual))
            except HttpError as raised:
                self.assertTrue(actual in ustr(raised))
Ejemplo n.º 31
0
    def test_get_protocol(self, WireProtocol, MetadataProtocol, _):
        """
        Exercise get_protocol() on the protocol util against saved protocol names.

        Covers four situations:
          1. "WireProtocol" is saved: the wire protocol instance is returned
             and the wireserver endpoint is queried.
          2. A different name is saved without clearing first: the previously
             obtained wire protocol is still returned and the endpoint is not
             queried again.
          3. After clear_protocol(), "MetadataProtocol" is saved: the metadata
             protocol instance is returned.
          4. An unknown name is saved and detection raises NotImplementedError:
             the error propagates out of get_protocol().

        WireProtocol and MetadataProtocol are mock classes (presumably injected
        by patch decorators that are not visible here).
        """
        WireProtocol.return_value = MagicMock()
        MetadataProtocol.return_value = MagicMock()

        protocol_util = get_protocol_util()
        protocol_util.get_wireserver_endpoint = Mock()
        protocol_util._detect_protocol = MagicMock()

        # Test for wire protocol
        protocol_util._save_protocol("WireProtocol")

        protocol = protocol_util.get_protocol()
        self.assertEqual(WireProtocol.return_value, protocol)
        protocol_util.get_wireserver_endpoint.assert_any_call()

        # Test to ensure protocol persists
        protocol_util.get_wireserver_endpoint.reset_mock()
        protocol_util._save_protocol("MetadataProtocol")

        protocol = protocol_util.get_protocol()
        self.assertEqual(WireProtocol.return_value, protocol)
        protocol_util.get_wireserver_endpoint.assert_not_called()

        # Test for metadata protocol
        protocol_util.clear_protocol()
        protocol_util._save_protocol("MetadataProtocol")

        protocol = protocol_util.get_protocol()
        self.assertEqual(MetadataProtocol.return_value, protocol)
        protocol_util.get_wireserver_endpoint.assert_not_called()

        # Test for unknown protocol
        protocol_util.clear_protocol()
        protocol_util._save_protocol("Not_a_Protocol")
        protocol_util._detect_protocol.side_effect = NotImplementedError()

        self.assertRaises(NotImplementedError, protocol_util.get_protocol)
        protocol_util.get_wireserver_endpoint.assert_not_called()
Ejemplo n.º 32
0
    def _provision_test(self,
                        distro_name,
                        distro_version,
                        distro_full_name,
                        ovf_file,
                        provisionMessage,
                        expect_success,
                        patch_write_agent_disabled,
                        patch_get_instance_id):
        """
        Shared driver: run the provision handler against the given OVF data
        file and verify the telemetry it reports.

        When expect_success is true, two events are expected:
          1. the "Provisioning succeeded" event, with an int duration, and
          2. the ProvisionGuestAgent event, whose message equals provisionMessage.
        In that case patch_write_agent_disabled must have been called once if
        provisionMessage is 'false', and not at all otherwise.

        When expect_success is false, a single unsuccessful event reporting
        the OVF validation failure is expected.
        """
        handler = get_provision_handler(distro_name,
                                        distro_version,
                                        distro_full_name)
        handler.report_event = MagicMock()
        handler.reg_ssh_host_key = MagicMock(return_value='--thumprint--')

        # The same OS util stub backs both the handler and its protocol util.
        osutil_stub = MagicMock()
        osutil_stub.decode_customdata = Mock(return_value="")
        handler.osutil = osutil_stub
        handler.protocol_util.osutil = osutil_stub
        handler.protocol_util.get_protocol = MagicMock()

        # Point the DVD mount at the temp dir and drop the OVF environment there.
        conf.get_dvd_mount_point = Mock(return_value=self.tmp_dir)
        fileutil.write_file(os.path.join(self.tmp_dir, OVF_FILE_NAME),
                            load_data(ovf_file))

        handler.run()

        reported = handler.report_event

        if not expect_success:
            # Failure path: one event, e.g.
            # call(u'[ProtocolError] Failed to validate OVF: ProvisionGuestAgent not found')
            self.assertEqual(1, reported.call_count)
            failure_args, failure_kwargs = reported.call_args_list[0]
            self.assertTrue('Failed to validate OVF: ProvisionGuestAgent not found' in failure_args[0])
            self.assertFalse(failure_kwargs['is_success'])
            return

        self.assertEqual(2, reported.call_count)

        # First event, e.g.
        # call('Provisioning succeeded (146473.68s)', duration=65, is_success=True)
        provision_args, provision_kwargs = reported.call_args_list[0]
        succeeded = re.match(r'Provisioning succeeded \(\d+\.\d+s\)', provision_args[0])
        self.assertTrue(succeeded is not None)
        self.assertTrue(isinstance(provision_kwargs['duration'], int))
        self.assertTrue(provision_kwargs['is_success'])

        # Second event: the guest agent provisioning state.
        _guest_args, guest_kwargs = reported.call_args_list[1]
        self.assertTrue(guest_kwargs['operation'] == 'ProvisionGuestAgent')
        self.assertTrue(guest_kwargs['message'] == provisionMessage)
        self.assertTrue(guest_kwargs['is_success'])

        # The agent is written as disabled only when provisioning says 'false'.
        expected_disabled = (provisionMessage == 'false')
        self.assertTrue(patch_write_agent_disabled.call_count == expected_disabled)