Example #1
0
    def test_error_heartbeat_creates_no_signal(self, patch_report_heartbeat,
                                               patch_http_get, patch_add_event,
                                               *args):
        """
        A failing host plugin heartbeat must not produce a health signal.

        When the HTTP GET raises, no heartbeat health report is made. Instead, a
        single 'HostPluginHeartbeat' telemetry event is sent, carrying the error
        text and is_success=False.
        """
        monitor_handler = get_monitor_handler()
        protocol = WireProtocol('endpoint')
        protocol.update_goal_state = MagicMock()
        with patch(
                'azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol',
                return_value=protocol):
            try:
                monitor_handler.init_protocols()
                # Back-date the last heartbeat by an hour before sending.
                monitor_handler.last_host_plugin_heartbeat = \
                    datetime.datetime.utcnow() - timedelta(hours=1)

                patch_http_get.side_effect = IOError('client error')
                monitor_handler.send_host_plugin_heartbeat()

                # health report should not be made
                self.assertEqual(0, patch_report_heartbeat.call_count)

                # telemetry with failure details is sent
                self.assertEqual(1, patch_add_event.call_count)
                event_kwargs = patch_add_event.call_args[1]
                self.assertEqual('HostPluginHeartbeat', event_kwargs['op'])
                self.assertIn('client error', event_kwargs['message'])
                self.assertEqual(False, event_kwargs['is_success'])
            finally:
                # Always stop the handler, even when an assertion above fails.
                monitor_handler.stop()
Example #2
0
    def _detect_wire_protocol(self):
        """
        Build and probe a WireProtocol for the WireServer endpoint.

        If the DHCP handler has no endpoint, DHCP is re-run when
        osutil.is_dhcp_available() is true. Otherwise the endpoint comes from
        get_wireserver_endpoint(). After a successful detect() the endpoint is
        saved.

        :returns: the WireProtocol whose detect() succeeded
        :raises ProtocolError: if DHCP fails or the WireServer does not respond.
            In the latter case the DHCP handler's endpoint is cleared and its
            skip_cache flag is set before re-raising.
        """
        endpoint = self.dhcp_handler.endpoint
        if endpoint is None:
            # Check if DHCP can be used to get the wire protocol endpoint.
            dhcp_available = self.osutil.is_dhcp_available()
            if dhcp_available:
                logger.info(
                    "WireServer endpoint is not found. Rerun dhcp handler")
                try:
                    self.dhcp_handler.run()
                except DhcpError as e:
                    raise ProtocolError(ustr(e))
                endpoint = self.dhcp_handler.endpoint
            else:
                logger.info("_detect_wire_protocol: DHCP not available")
                endpoint = self.get_wireserver_endpoint()

        try:
            protocol = WireProtocol(endpoint)
            protocol.detect()
            self._set_wireserver_endpoint(endpoint)
            return protocol
        except ProtocolError:
            logger.info("WireServer is not responding. Reset dhcp endpoint")
            self.dhcp_handler.endpoint = None
            self.dhcp_handler.skip_cache = True
            # Bare raise keeps the original traceback (``raise e`` loses it on Python 2).
            raise
Example #3
0
    def _detect_wire_protocol(self):
        """
        Build and probe a WireProtocol for the WireServer endpoint.

        If the DHCP handler has no endpoint, DHCP is re-run when
        osutil.is_dhcp_available() reports it available. Otherwise the endpoint
        read from file by _get_wireserver_endpoint() is used, falling back to
        the hardcoded endpoint returned by is_dhcp_available().

        :returns: the WireProtocol whose detect() succeeded
        :raises ProtocolError: if DHCP fails or the WireServer does not respond.
            In the latter case the DHCP handler's endpoint is cleared and its
            skip_cache flag is set before re-raising.
        """
        endpoint = self.dhcp_handler.endpoint
        if endpoint is None:
            # Check if DHCP can be used to get the wire protocol endpoint.
            (dhcp_available, conf_endpoint) = self.osutil.is_dhcp_available()
            if dhcp_available:
                logger.info("WireServer endpoint is not found. Rerun dhcp handler")
                try:
                    self.dhcp_handler.run()
                except DhcpError as e:
                    raise ProtocolError(ustr(e))
                endpoint = self.dhcp_handler.endpoint
            else:
                logger.info("_detect_wire_protocol: DHCP not available")
                endpoint = self._get_wireserver_endpoint()
                if endpoint is None:
                    endpoint = conf_endpoint
                    logger.info("Using hardcoded WireServer endpoint {0}", endpoint)
                else:
                    logger.info("WireServer endpoint {0} read from file", endpoint)

        try:
            protocol = WireProtocol(endpoint)
            protocol.detect()
            self._set_wireserver_endpoint(endpoint)
            return protocol
        except ProtocolError:
            logger.info("WireServer is not responding. Reset endpoint")
            self.dhcp_handler.endpoint = None
            self.dhcp_handler.skip_cache = True
            # Bare raise keeps the original traceback (``raise e`` loses it on Python 2).
            raise
Example #4
0
 def setUp(self):
     """
     Reset shared agent state and install a patched ProtocolUtil.get_protocol.

     The patch returns a WireProtocol('endpoint') whose update_goal_state is a
     MagicMock. The started patcher is kept on self.get_protocol (presumably
     stopped in a tearDown that is not shown here).
     """
     AgentTestCase.setUp(self)
     event.init_event_logger(os.path.join(self.tmp_dir, EVENTS_DIRECTORY))
     CGroupsTelemetry.reset()
     clear_singleton_instances(ProtocolUtil)

     fake_protocol = WireProtocol('endpoint')
     fake_protocol.update_goal_state = MagicMock()
     patch_target = 'azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol'
     self.get_protocol = patch(patch_target, return_value=fake_protocol)
     self.get_protocol.start()
Example #5
0
    def test_remote_access_handler_run_error(self, _):
        """
        An exception raised while fetching the incarnation is reported as a
        failed event whose message contains the exception text.
        """
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            mock_protocol = WireProtocol("foo.bar")
            mock_protocol.get_incarnation = MagicMock(side_effect=Exception("foobar!"))

            rah = RemoteAccessHandler(mock_protocol)
            rah.run()

            # Index 4 of the captured event data is checked for the message and
            # index 2 for the is_success flag.
            event_message = TestRemoteAccessHandler.eventing_data[4]
            check_message = "foobar!"
            self.assertIn(check_message, event_message,
                          "expected message {0} not found in {1}"
                          .format(check_message, event_message))
            self.assertEqual(False, TestRemoteAccessHandler.eventing_data[2], "is_success is true")
Example #6
0
    def _detect_protocol(self):
        """
        Probe the WireServer endpoint, making up to MAX_RETRY attempts.

        Each attempt resolves an endpoint and probes it with
        WireProtocol.detect(). If the DHCP handler has no endpoint, DHCP is
        re-run when it is available; otherwise get_wireserver_endpoint() is
        used. Failed attempts sleep PROBE_INTERVAL before the next one (except
        after the last).

        :returns: the first WireProtocol whose detect() succeeds
        :raises ProtocolNotFoundError: if every attempt fails with ProtocolError
        """
        self.clear_protocol()

        for retry in range(0, MAX_RETRY):
            try:
                endpoint = self.dhcp_handler.endpoint
                if endpoint is None:
                    # Check if DHCP can be used to get the wire protocol endpoint.
                    dhcp_available = self.osutil.is_dhcp_available()
                    if dhcp_available:
                        logger.info(
                            "WireServer endpoint is not found. Rerun dhcp handler"
                        )
                        try:
                            self.dhcp_handler.run()
                        except DhcpError as e:
                            raise ProtocolError(ustr(e))
                        endpoint = self.dhcp_handler.endpoint
                    else:
                        logger.info("_detect_protocol: DHCP not available")
                        endpoint = self.get_wireserver_endpoint()

                try:
                    protocol = WireProtocol(endpoint)
                    protocol.detect()
                    self._set_wireserver_endpoint(endpoint)
                    return protocol

                except ProtocolError:
                    logger.info(
                        "WireServer is not responding. Reset dhcp endpoint")
                    # Clearing the endpoint makes the next attempt resolve it again.
                    self.dhcp_handler.endpoint = None
                    self.dhcp_handler.skip_cache = True
                    # Bare raise keeps the traceback; the outer handler logs it.
                    raise

            except ProtocolError as e:
                logger.info("Protocol endpoint not found: {0}", e)

            if retry < MAX_RETRY - 1:
                logger.info("Retry detect protocol: retry={0}", retry)
                time.sleep(PROBE_INTERVAL)
        raise ProtocolNotFoundError("No protocol found.")
    def _create_mock(self, test_data, mock_http_get, MockCryptUtil):
        """
        Build an extension handler wired to a WireProtocol backed by test data.

        :param test_data: supplies mock_http_get and mock_crypt_util, which are
            installed as side effects of the two mocks below
        :param mock_http_get: mock HTTP GET; its side effect becomes
            test_data.mock_http_get
        :param MockCryptUtil: mock CryptUtil; its side effect becomes
            test_data.mock_crypt_util
        :returns: (handler, protocol). The protocol's report_ext_status and
            report_vm_status are MagicMocks, and
            handler.protocol_util.get_protocol returns the protocol.
        """
        handler = get_exthandlers_handler()

        # Mock protocol to return test data
        mock_http_get.side_effect = test_data.mock_http_get
        MockCryptUtil.side_effect = test_data.mock_crypt_util

        protocol = WireProtocol("foo.bar")
        protocol.detect()
        # Record status reports on mocks instead of calling the real methods.
        protocol.report_ext_status = MagicMock()
        protocol.report_vm_status = MagicMock()

        handler.protocol_util.get_protocol = Mock(return_value=protocol)
        return handler, protocol
    def setUp(self):
        """
        Prepare an ExtHandlerInstance for RunCommandLinux 1.0.0.

        The instance gets a package with five download URIs. Its base dir, log
        dir and the agent lib dir are patched to locations under self.tmp_dir.
        """
        AgentTestCase.setUp(self)

        handler_properties = ExtHandlerProperties()
        handler_properties.version = "1.0.0"
        handler = ExtHandler(name='Microsoft.CPlat.Core.RunCommandLinux')
        handler.properties = handler_properties

        protocol = WireProtocol("http://Microsoft.CPlat.Core.RunCommandLinux/foo-bar")

        package_uris = [
            'https://zrdfepirv2cy4prdstr00a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr01a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr02a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr03a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
            'https://zrdfepirv2cy4prdstr04a.blob.core.windows.net/f72653efd9e349ed9842c8b99e4c1712-foobar/Microsoft.CPlat.Core__RunCommandLinux__1.0.0',
        ]
        self.pkg = ExtHandlerPackage()
        self.pkg.uris = [ExtHandlerVersionUri() for _ in package_uris]
        for version_uri, uri_text in zip(self.pkg.uris, package_uris):
            version_uri.uri = uri_text

        self.ext_handler_instance = ExtHandlerInstance(ext_handler=handler, protocol=protocol)
        self.ext_handler_instance.pkg = self.pkg

        self.extension_dir = os.path.join(self.tmp_dir, "Microsoft.CPlat.Core.RunCommandLinux-1.0.0")
        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir", return_value=self.extension_dir)
        self.mock_get_base_dir.start()

        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir", return_value=self.tmp_dir)
        self.mock_get_log_dir.start()

        self.agent_dir = self.tmp_dir
        self.mock_get_lib_dir = patch("azurelinuxagent.ga.exthandlers.conf.get_lib_dir", return_value=self.agent_dir)
        self.mock_get_lib_dir.start()
    def test_remote_access_handler_run_error(self, _):
        """
        A RemoteAccessError raised while fetching the incarnation is reported
        as a failed event whose message contains the error text.
        """
        mock_protocol = WireProtocol("foo.bar")
        mock_protocol.get_incarnation = MagicMock(
            side_effect=RemoteAccessError("foobar!"))

        rah = RemoteAccessHandler(mock_protocol)
        rah.os_util = MockOSUtil()
        rah.run()

        # Index 4 of the captured event data is checked for the message and
        # index 2 for the is_success flag.
        event_message = TestRemoteAccessHandler.eventing_data[4]
        check_message = "foobar!"
        self.assertIn(
            check_message, event_message,
            "expected message {0} not found in {1}".format(
                check_message, event_message))
        self.assertEqual(False, TestRemoteAccessHandler.eventing_data[2],
                         "is_success is true")
Example #10
0
 def _detect_wire_protocol(self):
     """
     Build and probe a WireProtocol for the WireServer endpoint.

     If the DHCP handler has no endpoint, DHCP is re-run to obtain one. After a
     successful detect() the endpoint is saved and "WireProtocol" is recorded
     via save_protocol().

     :returns: the WireProtocol whose detect() succeeded
     :raises ProtocolError: if DHCP fails or the WireServer does not respond.
         In the latter case the DHCP handler's endpoint is cleared and its
         skip_cache flag is set before re-raising.
     """
     endpoint = self.dhcp_handler.endpoint
     if endpoint is None:
         logger.info("WireServer endpoint is not found. Rerun dhcp handler")
         try:
             self.dhcp_handler.run()
         except DhcpError as e:
             raise ProtocolError(ustr(e))
         endpoint = self.dhcp_handler.endpoint

     try:
         protocol = WireProtocol(endpoint)
         protocol.detect()
         self._set_wireserver_endpoint(endpoint)
         self.save_protocol("WireProtocol")
         return protocol
     except ProtocolError:
         logger.info("WireServer is not responding. Reset endpoint")
         self.dhcp_handler.endpoint = None
         self.dhcp_handler.skip_cache = True
         # Bare raise keeps the original traceback (``raise e`` loses it on Python 2).
         raise
Example #11
0
    def _detect_wire_protocol(self):
        """
        Build and probe a WireProtocol for the WireServer endpoint.

        If the DHCP handler has no endpoint, DHCP is re-run to obtain one.
        After a successful detect() the endpoint is saved and "WireProtocol" is
        recorded via save_protocol().

        :returns: the WireProtocol whose detect() succeeded
        :raises ProtocolError: if DHCP fails or the WireServer does not respond.
            In the latter case the DHCP handler's endpoint is cleared and its
            skip_cache flag is set before re-raising.
        """
        endpoint = self.dhcp_handler.endpoint
        if endpoint is None:
            logger.info("WireServer endpoint is not found. Rerun dhcp handler")
            try:
                self.dhcp_handler.run()
            except DhcpError as e:
                raise ProtocolError(ustr(e))
            endpoint = self.dhcp_handler.endpoint

        try:
            protocol = WireProtocol(endpoint)
            protocol.detect()
            self._set_wireserver_endpoint(endpoint)
            self.save_protocol("WireProtocol")
            return protocol
        except ProtocolError:
            logger.info("WireServer is not responding. Reset endpoint")
            self.dhcp_handler.endpoint = None
            self.dhcp_handler.skip_cache = True
            # Bare raise keeps the original traceback (``raise e`` loses it on Python 2).
            raise
Example #12
0
    def _test_getters(self, test_data, mock_restutil, MockCryptUtil, _):
        """
        Drive the WireProtocol getters against test data, then check certificates.

        Installs test_data's HTTP GET and crypt-util fakes and runs detect(),
        get_vminfo(), get_certs(), get_ext_handlers() and
        get_ext_handler_pkgs() for every handler. Then asserts that the expected
        .crt/.prv files exist in self.tmp_dir.
        """
        mock_restutil.http_get.side_effect = test_data.mock_http_get
        MockCryptUtil.side_effect = test_data.mock_crypt_util

        protocol = WireProtocol("foo.bar")
        protocol.detect()
        protocol.get_vminfo()
        protocol.get_certs()
        # The etag is not needed here.
        ext_handlers, _etag = protocol.get_ext_handlers()
        for ext_handler in ext_handlers.extHandlers:
            protocol.get_ext_handler_pkgs(ext_handler)

        crt1 = os.path.join(self.tmp_dir,
                            '33B0ABCE4673538650971C10F7D7397E71561F35.crt')
        crt2 = os.path.join(self.tmp_dir,
                            '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.crt')
        prv2 = os.path.join(self.tmp_dir,
                            '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.prv')

        self.assertTrue(os.path.isfile(crt1))
        self.assertTrue(os.path.isfile(crt2))
        self.assertTrue(os.path.isfile(prv2))
Example #13
0
    def _detect_wire_protocol(self):
        """
        Build and probe a WireProtocol for the WireServer endpoint.

        If the DHCP handler has no endpoint, DHCP is re-run when
        osutil.is_dhcp_available() reports it available. Otherwise the endpoint
        read from file by _get_wireserver_endpoint() is used, falling back to
        the hardcoded endpoint returned by is_dhcp_available(). After a
        successful detect() the endpoint is saved and "WireProtocol" is
        recorded via save_protocol().

        :returns: the WireProtocol whose detect() succeeded
        :raises ProtocolError: if DHCP fails or the WireServer does not respond.
            In the latter case the DHCP handler's endpoint is cleared and its
            skip_cache flag is set before re-raising.
        """
        endpoint = self.dhcp_handler.endpoint
        if endpoint is None:
            # Check if DHCP can be used to get the wire protocol endpoint.
            (dhcp_available, conf_endpoint) = self.osutil.is_dhcp_available()
            if dhcp_available:
                logger.info(
                    "WireServer endpoint is not found. Rerun dhcp handler")
                try:
                    self.dhcp_handler.run()
                except DhcpError as e:
                    raise ProtocolError(ustr(e))
                endpoint = self.dhcp_handler.endpoint
            else:
                logger.info("_detect_wire_protocol: DHCP not available")
                endpoint = self._get_wireserver_endpoint()
                if endpoint is None:
                    endpoint = conf_endpoint
                    logger.info("Using hardcoded WireServer endpoint {0}",
                                endpoint)
                else:
                    logger.info("WireServer endpoint {0} read from file",
                                endpoint)

        try:
            protocol = WireProtocol(endpoint)
            protocol.detect()
            self._set_wireserver_endpoint(endpoint)
            self.save_protocol("WireProtocol")
            return protocol
        except ProtocolError:
            logger.info("WireServer is not responding. Reset endpoint")
            self.dhcp_handler.endpoint = None
            self.dhcp_handler.skip_cache = True
            # Bare raise keeps the original traceback (``raise e`` loses it on Python 2).
            raise
    def test_it_should_report_an_error_if_the_wireserver_cannot_be_reached(
            self, patch_is_triggered, patch_add_event):
        """
        A ProtocolError from get_ext_handlers is reported as two failed events:
        GetArtifactExtended first, then ExtensionProcessing with the error text
        in its message.
        """
        test_message = "TEST MESSAGE"

        patch_is_triggered.return_value = True  # protocol errors are reported only after a delay; force the error to be reported now

        protocol = WireProtocol("foo.bar")
        protocol.get_ext_handlers = MagicMock(
            side_effect=ProtocolError(test_message))

        get_exthandlers_handler(protocol).run()

        self.assertEqual(2, patch_add_event.call_count)

        _, first_call_args = patch_add_event.call_args_list[0]
        self.assertEqual(WALAEventOperation.GetArtifactExtended,
                         first_call_args['op'])
        self.assertEqual(False, first_call_args['is_success'])

        _, second_call_args = patch_add_event.call_args_list[1]
        self.assertEqual(WALAEventOperation.ExtensionProcessing,
                         second_call_args['op'])
        self.assertEqual(False, second_call_args['is_success'])
        self.assertIn(test_message, second_call_args['message'])
Example #15
0
    def get_protocol(self):
        """
        Return the protocol instance, creating and caching it on first use.

        If the protocol file already records WIRE_PROTOCOL_NAME, a WireProtocol
        is built directly from get_wireserver_endpoint(). Otherwise the protocol
        is detected, its endpoint is registered with IOErrorCounter, and
        WIRE_PROTOCOL_NAME is saved. In both cases any MetadataServer artifacts
        still present are cleaned up afterwards. The whole call runs with
        self._lock held.

        :returns: protocol instance
        """
        self._lock.acquire()
        try:
            cached = self._protocol
            if cached is not None:
                return cached

            # A protocol file holding MetadataProtocol (anything other than
            # WIRE_PROTOCOL_NAME) deliberately falls through to _detect_protocol
            # so that the WireServer transport certificates get generated.
            path = self._get_protocol_file_path()
            saved_as_wire = os.path.isfile(path) and \
                fileutil.read_file(path) == WIRE_PROTOCOL_NAME

            if saved_as_wire:
                self._protocol = WireProtocol(self.get_wireserver_endpoint())
            else:
                logger.info("Detect protocol endpoint")
                detected = self._detect_protocol()
                IOErrorCounter.set_protocol_endpoint(
                    endpoint=detected.get_endpoint())
                self._save_protocol(WIRE_PROTOCOL_NAME)
                self._protocol = detected

            # Metadata artifacts may survive an earlier intermediate upgrade
            # (before 2.2.48) that switched to Wire without cleaning them up.
            # On the detection path this runs only after _detect_protocol
            # succeeded, so MDS certificates are kept if WireServer is
            # unreachable and the update has to roll back.
            if is_metadata_server_artifact_present():
                cleanup_metadata_server_artifacts(self.osutil)

            return self._protocol
        finally:
            self._lock.release()
    def _get_protocol(self):
        """
        Rebuild the protocol recorded by a previous detection.

        :returns: a WireProtocol (built from _get_wireserver_endpoint()) when
            the protocol file holds prots.WireProtocol, or a MetadataProtocol
            when it holds prots.MetadataProtocol
        :raises ProtocolNotFoundError: if the protocol file is missing or names
            any other protocol
        """
        path = self._get_protocol_file_path()
        if not os.path.isfile(path):
            raise ProtocolNotFoundError("No protocol found")

        saved_name = fileutil.read_file(path)
        if saved_name == prots.WireProtocol:
            return WireProtocol(self._get_wireserver_endpoint())
        if saved_name == prots.MetadataProtocol:
            return MetadataProtocol()
        raise ProtocolNotFoundError("Unknown protocol: {0}".format(saved_name))
Example #17
0
    def _get_protocol(self):
        """
        Rebuild the protocol recorded by a previous detection.

        The protocol name is read from PROTOCOL_FILE_NAME under
        conf.get_lib_dir().

        :returns: a WireProtocol (built from _get_wireserver_endpoint()) when
            the file holds "WireProtocol", or a MetadataProtocol when it holds
            "MetadataProtocol"
        :raises ProtocolNotFoundError: if the file is missing or names any
            other protocol
        """
        lib_dir = conf.get_lib_dir()
        path = os.path.join(lib_dir, PROTOCOL_FILE_NAME)
        if not os.path.isfile(path):
            raise ProtocolNotFoundError("No protocol found")

        saved_name = fileutil.read_file(path)
        if saved_name == "WireProtocol":
            return WireProtocol(self._get_wireserver_endpoint())
        if saved_name == "MetadataProtocol":
            return MetadataProtocol()
        raise ProtocolNotFoundError("Unknown protocol: {0}".format(saved_name))
Example #18
0
    def _create_mock(self, test_data, mock_http_get, MockCryptUtil):
        """
        Build an extension handler wired to a WireProtocol backed by test data.

        :param test_data: supplies mock_http_get and mock_crypt_util, which are
            installed as side effects of the two mocks below
        :param mock_http_get: mock HTTP GET; its side effect becomes
            test_data.mock_http_get
        :param MockCryptUtil: mock CryptUtil; its side effect becomes
            test_data.mock_crypt_util
        :returns: (handler, protocol). The protocol's report_ext_status and
            report_vm_status are MagicMocks, and
            handler.protocol_util.get_protocol returns the protocol.
        """
        handler = get_exthandlers_handler()

        # Mock protocol to return test data
        mock_http_get.side_effect = test_data.mock_http_get
        MockCryptUtil.side_effect = test_data.mock_crypt_util

        protocol = WireProtocol("foo.bar")
        protocol.detect()
        # Record status reports on mocks instead of calling the real methods.
        protocol.report_ext_status = MagicMock()
        protocol.report_vm_status = MagicMock()

        handler.protocol_util.get_protocol = Mock(return_value=protocol)
        return handler, protocol
        ) - timedelta(hours=1)
        monitor_handler.send_host_plugin_heartbeat()
        self.assertEqual(1, patch_report_heartbeat.call_count)
        self.assertEqual(1, args[5].call_count)
        self.assertEqual('HostPluginHeartbeatExtended',
                         args[5].call_args[1]['op'])
        self.assertEqual(False, args[5].call_args[1]['is_success'])
        monitor_handler.stop()


@patch('azurelinuxagent.common.event.EventLogger.add_event')
@patch("azurelinuxagent.common.utils.restutil.http_post")
@patch("azurelinuxagent.common.utils.restutil.http_get")
@patch('azurelinuxagent.common.protocol.wire.WireClient.get_goal_state')
@patch('azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol',
       return_value=WireProtocol('endpoint'))
class TestMonitorFailure(AgentTestCase):
    @patch(
        "azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_heartbeat"
    )
    def test_error_heartbeat_creates_no_signal(self, patch_report_heartbeat,
                                               *args):
        patch_http_get = args[2]
        patch_add_event = args[4]

        monitor_handler = get_monitor_handler()
        monitor_handler.init_protocols()
        monitor_handler.last_host_plugin_heartbeat = datetime.datetime.utcnow(
        ) - timedelta(hours=1)

        patch_http_get.side_effect = IOError('client error')
Example #20
0
class TestRemoteAccessHandler(AgentTestCase): # pylint: disable=too-many-public-methods
    eventing_data = [()]

    def setUp(self):
        """Reset per-thread protocol state and the captured event data."""
        super(TestRemoteAccessHandler, self).setUp()
        # Since ProtocolUtil is a singleton per thread, we need to clear it to ensure that the test cases do not
        # reuse a previous state
        clear_singleton_instances(ProtocolUtil)
        # Restore the class-level event capture to its initial value so a test never
        # sees an event recorded by an earlier one. (Deleting a for-loop variable
        # would only unbind that name and leave the stored data untouched.)
        TestRemoteAccessHandler.eventing_data = [()]

    # add_user tests
    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_add_user(self, *_):
        """
        _add_user creates a JIT_Account user whose expiry field is one day past
        the requested expiration date.
        """
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            username = "******"
            password = "******"
            expires_at = datetime.utcnow() + timedelta(days=1)
            add_user = handler._add_user # pylint: disable=protected-access
            add_user(username, password, expires_at)

            os_util = handler._os_util # pylint: disable=protected-access
            users = get_user_dictionary(os_util.get_users())
            self.assertTrue(username in users, "{0} missing from users".format(username))
            created = users[username]
            self.assertEqual(created[7], (expires_at + timedelta(days=1)).strftime("%Y-%m-%d"))
            self.assertEqual(created[4], "JIT_Account")

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_add_user_bad_creation_data(self, *_):
        """An empty username makes _add_user raise, and no user is created."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            expires_at = datetime.utcnow() + timedelta(days=1)
            add_user = handler._add_user # pylint: disable=protected-access
            self.assertRaisesRegex(Exception, "test exception for bad username",
                                   add_user, "", "******", expires_at)
            os_util = handler._os_util # pylint: disable=protected-access
            self.assertEqual(0, len(os_util.get_users()))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="")
    def test_add_user_bad_password_data(self, *_):
        """An empty password makes _add_user raise, and no user is created."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            expires_at = datetime.utcnow() + timedelta(days=1)
            add_user = handler._add_user # pylint: disable=protected-access
            self.assertRaisesRegex(Exception, "test exception for bad password",
                                   add_user, "******", "", expires_at)
            os_util = handler._os_util # pylint: disable=protected-access
            self.assertEqual(0, len(os_util.get_users()))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_add_user_already_existing(self, _):
        """
        Re-adding an existing user raises and leaves the original user and its
        expiry untouched.
        """
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            add_user = handler._add_user # pylint: disable=protected-access
            os_util = handler._os_util # pylint: disable=protected-access
            username = "******"
            password = "******"
            first_expiration = datetime.utcnow() + timedelta(days=1)
            expected_expiry = (first_expiration + timedelta(days=1)).strftime("%Y-%m-%d")

            add_user(username, password, first_expiration)
            users = get_user_dictionary(os_util.get_users())
            self.assertTrue(username in users, "{0} missing from users".format(username))
            self.assertEqual(1, len(users.keys()))
            self.assertEqual(users[username][7], expected_expiry)

            # A duplicate add must fail without creating or overwriting the user. The
            # user-add function itself is mocked; this checks that processing skips the
            # remaining calls after the initial failure.
            second_expiration = datetime.utcnow() + timedelta(days=5)
            self.assertRaises(Exception, add_user, username, password, second_expiration)

            users = get_user_dictionary(os_util.get_users())
            self.assertTrue(username in users, "{0} missing from users after dup user attempted".format(username))
            self.assertEqual(1, len(users.keys()))
            self.assertEqual(users[username][7], expected_expiry)

    # delete_user tests
    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_delete_user(self, *_):
        """A user created by _add_user is no longer listed after _remove_user."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            rah = RemoteAccessHandler(Mock())
            tstpassword = "******"
            tstuser = "******"
            expiration_date = datetime.utcnow() + timedelta(days=1)
            rah._add_user(tstuser, tstpassword, expiration_date) # pylint: disable=protected-access
            users = get_user_dictionary(rah._os_util.get_users()) # pylint: disable=protected-access
            self.assertTrue(tstuser in users, "{0} missing from users".format(tstuser))
            rah._remove_user(tstuser) # pylint: disable=protected-access
            # refresh users
            users = get_user_dictionary(rah._os_util.get_users()) # pylint: disable=protected-access
            self.assertFalse(tstuser in users)

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_handle_new_user(self, _):
        """
        The single account in the goal state, expiring in one day, is created
        as a JIT_Account user whose expiry field is one day past that date.
        """
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            remote_access = RemoteAccess(load_data('wire/remote_access_single_account.xml'))
            account = remote_access.user_list.users[0]
            username = account.name
            expires_at = datetime.utcnow() + timedelta(days=1)
            account.expiration = expires_at.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"

            handler._remote_access = remote_access # pylint: disable=protected-access
            handler._handle_remote_access() # pylint: disable=protected-access

            os_util = handler._os_util # pylint: disable=protected-access
            users = get_user_dictionary(os_util.get_users())
            self.assertTrue(username in users, "{0} missing from users".format(username))
            created = users[username]
            self.assertEqual(created[7], (expires_at + timedelta(days=1)).strftime("%Y-%m-%d"))
            self.assertEqual(created[4], "JIT_Account")

    def test_do_not_add_expired_user(self):
        """An account whose expiration is already in the past is not created."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            rah = RemoteAccessHandler(Mock())
            data_str = load_data('wire/remote_access_single_account.xml')
            remote_access = RemoteAccess(data_str)
            # Take the account name from the goal state (as test_handle_new_user does) rather
            # than hard-coding it, so the assertion cannot pass vacuously if the fixture changes.
            tstuser = remote_access.user_list.users[0].name
            expiration = (datetime.utcnow() - timedelta(days=2)).strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
            remote_access.user_list.users[0].expiration = expiration
            rah._remote_access = remote_access # pylint: disable=protected-access
            rah._handle_remote_access() # pylint: disable=protected-access
            users = get_user_dictionary(rah._os_util.get_users()) # pylint: disable=protected-access
            self.assertNotIn(tstuser, users, "expired user {0} was added".format(tstuser))

    def test_error_add_user(self):
        """
        With decrypt_secret not patched, an undecodable password makes
        _add_user raise a CryptError, and no user is created.
        """
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            expires_at = datetime.utcnow() + timedelta(days=1)
            expected_error = r"\[CryptError\] Error decoding secret\nInner error: Incorrect padding"
            add_user = handler._add_user # pylint: disable=protected-access
            self.assertRaisesRegex(Exception, expected_error, add_user, "******", "bad password", expires_at)
            os_util = handler._os_util # pylint: disable=protected-access
            self.assertEqual(0, len(get_user_dictionary(os_util.get_users())))

    def test_handle_remote_access_no_users(self):
        """A goal state with no accounts leaves the user list empty."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            handler._remote_access = RemoteAccess(load_data('wire/remote_access_no_accounts.xml')) # pylint: disable=protected-access
            handler._handle_remote_access() # pylint: disable=protected-access
            os_util = handler._os_util # pylint: disable=protected-access
            users = get_user_dictionary(os_util.get_users())
            self.assertEqual(0, len(users.keys()))

    def test_handle_remote_access_validate_jit_user_valid(self):
        """The comment "JIT_Account" is recognized by _is_jit_user."""
        handler = RemoteAccessHandler(Mock())
        jit_comment = "JIT_Account"
        is_jit_user = handler._is_jit_user # pylint: disable=protected-access
        self.assertTrue(is_jit_user(jit_comment),
                        "Did not identify '{0}' as a JIT_Account".format(jit_comment))

    def test_handle_remote_access_validate_jit_user_invalid(self):
        """
        _is_jit_user must not accept other comments, including None, empty and
        whitespace-only values. All misidentifications are reported in one
        failure.
        """
        rah = RemoteAccessHandler(Mock())
        is_jit_user = rah._is_jit_user # pylint: disable=protected-access
        test_users = ["John Doe", None, "", " "]
        failures = [
            "incorrectly identified '{0}' as a JIT_Account.".format(user)
            for user in test_users
            if is_jit_user(user)
        ]
        if failures:
            self.fail("  ".join(failures))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_multiple_users(self, _):
        """Both accounts in a two-account goal state with future expirations are created."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            rah = RemoteAccessHandler(Mock())
            data_str = load_data('wire/remote_access_two_accounts.xml')
            remote_access = RemoteAccess(data_str)
            testusers = []
            for count in range(2):
                remote_user = remote_access.user_list.users[count]
                testusers.append(remote_user.name)
                # Expirations are 1 and 2 days in the future respectively.
                expiration_date = datetime.utcnow() + timedelta(days=count + 1)
                remote_user.expiration = expiration_date.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
            rah._remote_access = remote_access # pylint: disable=protected-access
            rah._handle_remote_access() # pylint: disable=protected-access
            users = get_user_dictionary(rah._os_util.get_users()) # pylint: disable=protected-access
            for testuser in testusers:
                self.assertIn(testuser, users, "{0} missing from users".format(testuser))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_ten_users(self, _):
        """All ten accounts are created; ten is the maximum fabric supports in the goal state."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            remote_access = RemoteAccess(load_data('wire/remote_access_10_accounts.xml'))
            for number, account in enumerate(remote_access.user_list.users, 1):
                account.name = "tstuser{0}".format(number)
                expires_at = datetime.utcnow() + timedelta(days=number)
                account.expiration = expires_at.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
            handler._remote_access = remote_access # pylint: disable=protected-access
            handler._handle_remote_access() # pylint: disable=protected-access
            current_users = get_user_dictionary(handler._os_util.get_users()) # pylint: disable=protected-access
            self.assertEqual(10, len(current_users.keys()))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_user_removed(self, _):
        """Accounts dropped from the goal state are removed from the OS when it is processed again."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            rah = RemoteAccessHandler(Mock())
            data_str = load_data('wire/remote_access_10_accounts.xml')
            remote_access = RemoteAccess(data_str)
            count = 0
            for user in remote_access.user_list.users:
                count += 1
                user.name = "tstuser{0}".format(count)
                expiration_date = datetime.utcnow() + timedelta(days=count)
                user.expiration = expiration_date.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
            rah._remote_access = remote_access # pylint: disable=protected-access
            rah._handle_remote_access() # pylint: disable=protected-access
            users = get_user_dictionary(rah._os_util.get_users()) # pylint: disable=protected-access
            self.assertEqual(10, len(users.keys()))
            # Drop every account from the goal state and reprocess it; ``users`` above is only a snapshot,
            # so the OS users must be re-read for the assertion to mean anything.
            del rah._remote_access.user_list.users[:] # pylint: disable=protected-access
            rah._handle_remote_access() # pylint: disable=protected-access
            users = get_user_dictionary(rah._os_util.get_users()) # pylint: disable=protected-access
            self.assertEqual(0, len(users.keys()))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_bad_data_and_good_data(self, _):
        """One account with an empty name is skipped; the other nine are still created."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            remote_access = RemoteAccess(load_data('wire/remote_access_10_accounts.xml'))
            for number, account in enumerate(remote_access.user_list.users, 1):
                # The second account gets an invalid (empty) name.
                account.name = "" if number == 2 else "tstuser{0}".format(number)
                expires_at = datetime.utcnow() + timedelta(days=number)
                account.expiration = expires_at.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
            handler._remote_access = remote_access # pylint: disable=protected-access
            handler._handle_remote_access() # pylint: disable=protected-access
            current_users = get_user_dictionary(handler._os_util.get_users()) # pylint: disable=protected-access
            self.assertEqual(9, len(current_users.keys()))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_deleted_user_readded(self, _):
        """A JIT user wiped from the OS is re-created when the same goal state is processed again."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            remote_access = RemoteAccess(load_data('wire/remote_access_single_account.xml'))
            account = remote_access.user_list.users[0]
            username = account.name
            expires_at = datetime.utcnow() + timedelta(days=1)
            account.expiration = expires_at.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
            handler._remote_access = remote_access # pylint: disable=protected-access
            handler._handle_remote_access() # pylint: disable=protected-access
            os_util = handler._os_util # pylint: disable=protected-access
            self.assertTrue(username in get_user_dictionary(os_util.get_users()),
                            "{0} missing from users".format(username))
            os_util.__class__ = MockOSUtil
            # Simulate the account disappearing outside the agent's control.
            os_util.all_users.clear() # pylint: disable=no-member
            self.assertTrue(username not in get_user_dictionary(os_util.get_users()))
            handler._handle_remote_access() # pylint: disable=protected-access
            self.assertTrue(username in get_user_dictionary(os_util.get_users()),
                            "{0} missing from users".format(username))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    @patch('azurelinuxagent.common.osutil.get_osutil', return_value=MockOSUtil())
    @patch('azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol', return_value=WireProtocol("12.34.56.78"))
    @patch('azurelinuxagent.common.protocol.wire.WireProtocol.get_incarnation', return_value="1")
    @patch('azurelinuxagent.common.protocol.wire.WireClient.get_remote_access', return_value="asdf")
    def test_remote_access_handler_run_bad_data(self, _1, _2, _3, _4, _5):
        """Running the handler with unusable remote-access data leaves an existing JIT user in place."""
        # NOTE(review): the handler receives Mock() as its protocol, so the WireProtocol/WireClient patches
        # above may not be what run() consults; confirm which data path this test actually exercises.
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            rah = RemoteAccessHandler(Mock())
            tstpassword = "******"
            tstuser = "******"
            expiration_date = datetime.utcnow() + timedelta(days=1)
            pwd = tstpassword
            rah._add_user(tstuser, pwd, expiration_date) # pylint: disable=protected-access
            users = get_user_dictionary(rah._os_util.get_users()) # pylint: disable=protected-access
            self.assertTrue(tstuser in users, "{0} missing from users".format(tstuser))
            rah.run()
            # Re-read the OS users: the dictionary above is a snapshot taken before run().
            users = get_user_dictionary(rah._os_util.get_users()) # pylint: disable=protected-access
            self.assertTrue(tstuser in users, "{0} missing from users".format(tstuser))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_multiple_users_one_removed(self, _):
        """Removing one account from the goal state deletes exactly that OS user."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            rah = RemoteAccessHandler(Mock())
            data_str = load_data('wire/remote_access_10_accounts.xml')
            remote_access = RemoteAccess(data_str)
            count = 0
            for user in remote_access.user_list.users:
                count += 1
                user.name = "tstuser{0}".format(count)
                expiration_date = datetime.utcnow() + timedelta(days=count)
                user.expiration = expiration_date.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
            rah._remote_access = remote_access # pylint: disable=protected-access
            rah._handle_remote_access() # pylint: disable=protected-access
            users = rah._os_util.get_users() # pylint: disable=protected-access
            self.assertEqual(10, len(users))
            # now remove the user from RemoteAccess; keep its *name*, since get_users() yields tuples
            # whose first element is the user name, not RemoteAccess user objects.
            deleted_user = rah._remote_access.user_list.users[3].name # pylint: disable=protected-access
            del rah._remote_access.user_list.users[3] # pylint: disable=protected-access
            rah._handle_remote_access() # pylint: disable=protected-access
            users = rah._os_util.get_users() # pylint: disable=protected-access
            user_names = [u[0] for u in users]
            self.assertTrue(deleted_user not in user_names, "{0} still in users".format(deleted_user))
            self.assertEqual(9, len(users))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_multiple_users_null_remote_access(self, _):
        """After setting the remote-access data to None and reprocessing, no JIT users remain."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            remote_access = RemoteAccess(load_data('wire/remote_access_10_accounts.xml'))
            for number, account in enumerate(remote_access.user_list.users, 1):
                account.name = "tstuser{0}".format(number)
                expires_at = datetime.utcnow() + timedelta(days=number)
                account.expiration = expires_at.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
            handler._remote_access = remote_access # pylint: disable=protected-access
            handler._handle_remote_access() # pylint: disable=protected-access
            self.assertEqual(10, len(handler._os_util.get_users())) # pylint: disable=protected-access
            # Drop the remote-access data entirely and reprocess.
            handler._remote_access = None # pylint: disable=protected-access
            handler._handle_remote_access() # pylint: disable=protected-access
            self.assertEqual(0, len(handler._os_util.get_users())) # pylint: disable=protected-access

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_multiple_users_error_with_null_remote_access(self, _):
        """Reprocessing with remote-access data set to None removes all previously created JIT users.

        NOTE(review): despite its name, this test injects no error and is step-for-step identical to
        test_handle_remote_access_multiple_users_null_remote_access; confirm what error scenario was intended.
        """
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            remote_access = RemoteAccess(load_data('wire/remote_access_10_accounts.xml'))
            for number, account in enumerate(remote_access.user_list.users, 1):
                account.name = "tstuser{0}".format(number)
                expires_at = datetime.utcnow() + timedelta(days=number)
                account.expiration = expires_at.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
            handler._remote_access = remote_access # pylint: disable=protected-access
            handler._handle_remote_access() # pylint: disable=protected-access
            self.assertEqual(10, len(handler._os_util.get_users())) # pylint: disable=protected-access
            handler._remote_access = None # pylint: disable=protected-access
            handler._handle_remote_access() # pylint: disable=protected-access
            self.assertEqual(0, len(handler._os_util.get_users())) # pylint: disable=protected-access

    def test_remove_user_error(self):
        """Removing a user with an empty name surfaces the mock OS util's bad-data exception."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            with self.assertRaisesRegex(Exception, "test exception, bad data"):
                handler._remove_user("") # pylint: disable=protected-access

    def test_remove_user_not_exists(self):
        """Removing an unknown user surfaces the mock OS util's 'does not exist' exception."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            handler = RemoteAccessHandler(Mock())
            missing_user = "******"
            with self.assertRaisesRegex(Exception, "test exception, user does not exist to delete"):
                handler._remove_user(missing_user) # pylint: disable=protected-access

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret', return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_remove_and_add(self, _):
        """Renaming one goal-state account removes the old OS user and creates the new one."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            rah = RemoteAccessHandler(Mock())
            data_str = load_data('wire/remote_access_10_accounts.xml')
            remote_access = RemoteAccess(data_str)
            count = 0
            for user in remote_access.user_list.users:
                count += 1
                user.name = "tstuser{0}".format(count)
                expiration_date = datetime.utcnow() + timedelta(days=count)
                user.expiration = expiration_date.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
            rah._remote_access = remote_access # pylint: disable=protected-access
            rah._handle_remote_access() # pylint: disable=protected-access
            users = rah._os_util.get_users() # pylint: disable=protected-access
            self.assertEqual(10, len(users))
            # now replace one user in RemoteAccess with a new one
            new_user = "******"
            # Capture the old *name*: the user object itself is renamed in place on the next line.
            deleted_user = rah._remote_access.user_list.users[3].name # pylint: disable=protected-access
            rah._remote_access.user_list.users[3].name = new_user # pylint: disable=protected-access
            rah._handle_remote_access() # pylint: disable=protected-access
            users = rah._os_util.get_users() # pylint: disable=protected-access
            user_names = [u[0] for u in users]
            self.assertTrue(deleted_user not in user_names, "{0} still in users".format(deleted_user))
            self.assertTrue(new_user in user_names, "user {0} not in users".format(new_user))
            self.assertEqual(10, len(users))

    @patch('azurelinuxagent.ga.remoteaccess.add_event', side_effect=mock_add_event)
    def test_remote_access_handler_run_error(self, _):
        """When get_incarnation raises, run() reports the error through add_event with is_success False."""
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=MockOSUtil()):
            mock_protocol = WireProtocol("foo.bar")
            mock_protocol.get_incarnation = MagicMock(side_effect=Exception("foobar!"))

            rah = RemoteAccessHandler(mock_protocol)
            rah.run()
            check_message = "foobar!"
            # eventing_data is filled in by mock_add_event; the assertions treat index 4 as the message and
            # index 2 as is_success (assumed layout; confirm against mock_add_event).
            self.assertTrue(check_message in TestRemoteAccessHandler.eventing_data[4],
                            "expected message {0} not found in {1}"
                            .format(check_message, TestRemoteAccessHandler.eventing_data[4]))
            self.assertEqual(False, TestRemoteAccessHandler.eventing_data[2], "is_success is true")

    def test_remote_access_handler_should_retrieve_users_when_it_is_invoked_the_first_time(self):
        """The very first run() queries the OS user list exactly once."""
        mock_os_util = MagicMock()
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=mock_os_util):
            with mock_wire_protocol(DATA_FILE) as mock_protocol:
                rah = RemoteAccessHandler(mock_protocol)
                rah.run()

                # assertEqual reports the actual call count on failure, unlike assertTrue(len(...) == 1).
                self.assertEqual(1, mock_os_util.get_users.call_count, "The first invocation of remote access should have retrieved the current users")

    def test_remote_access_handler_should_retrieve_users_when_goal_state_contains_jit_users(self):
        """A goal state that contains JIT users makes run() query the OS user list."""
        mock_os_util = MagicMock()
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=mock_os_util):
            with mock_wire_protocol(DATA_FILE_REMOTE_ACCESS) as mock_protocol:
                rah = RemoteAccessHandler(mock_protocol)
                rah.run()

                self.assertGreater(mock_os_util.get_users.call_count, 0, "A goal state with jit users did not retrieve the current users")

    def test_remote_access_handler_should_not_retrieve_users_when_goal_state_does_not_contain_jit_users(self):
        """After the first run(), a new goal state without JIT users does not query the OS user list again."""
        mock_os_util = MagicMock()
        with patch("azurelinuxagent.ga.remoteaccess.get_osutil", return_value=mock_os_util):
            with mock_wire_protocol(DATA_FILE) as mock_protocol:
                rah = RemoteAccessHandler(mock_protocol)
                rah.run()  # this will trigger one call to retrieve the users

                mock_protocol.mock_wire_data.set_incarnation(123)  # mock a new goal state; the data file does not include any jit users
                rah.run()
                self.assertEqual(1, mock_os_util.get_users.call_count, "A goal state without jit users retrieved the current users")
class TestRemoteAccessHandler(AgentTestCase):
    eventing_data = [()]

    def setUp(self):
        """Reset shared singletons and the module/class-level capture buffers before each test."""
        super(TestRemoteAccessHandler, self).setUp()
        # Since ProtocolUtil is a singleton per thread, we need to clear it to ensure that the test cases do not
        # reuse a previous state
        clear_singleton_instances(ProtocolUtil)
        del info_messages[:]
        del error_messages[:]
        # Rebind the captured-event buffer to its initial value. Deleting a loop variable would only unbind
        # the local name and leave the previous test's event data visible.
        TestRemoteAccessHandler.eventing_data = [()]

    # add_user tests
    @patch('azurelinuxagent.common.logger.Logger.info', side_effect=log_info)
    @patch('azurelinuxagent.common.logger.Logger.error', side_effect=log_error)
    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_add_user(self, _1, _2, _3):
        """add_user creates a JIT_Account user and logs one success message and no errors.

        The test expects the stored expiry (field 7) to be the day after the requested expiration.
        """
        handler = RemoteAccessHandler(Mock())
        handler.os_util = MockOSUtil()
        username = "******"
        password = "******"
        requested = datetime.utcnow() + timedelta(days=1)
        handler.add_user(username, password, requested)
        created = get_user_dictionary(handler.os_util.get_users())
        self.assertTrue(username in created, "{0} missing from users".format(username))
        entry = created[username]
        expiry_day = (requested + timedelta(days=1)).strftime("%Y-%m-%d")
        self.assertEqual(entry[7], expiry_day)
        self.assertEqual(entry[4], "JIT_Account")
        self.assertEqual(0, len(error_messages))
        self.assertEqual(1, len(info_messages))
        success_line = "User '{0}' added successfully with expiration in {1}".format(username, expiry_day)
        self.assertEqual(info_messages[0], success_line)

    @patch('azurelinuxagent.common.logger.Logger.info', side_effect=log_info)
    @patch('azurelinuxagent.common.logger.Logger.error', side_effect=log_error)
    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_add_user_bad_creation_data(self, _1, _2, _3):
        """An empty username makes add_user raise RemoteAccessError without creating users or logging."""
        handler = RemoteAccessHandler(Mock())
        handler.os_util = MockOSUtil()
        username = ""
        password = "******"
        requested = datetime.utcnow() + timedelta(days=1)
        expected_error = "Error adding user {0}. test exception for bad username".format(username)
        with self.assertRaisesRegex(RemoteAccessError, expected_error):
            handler.add_user(username, password, requested)
        self.assertEqual(0, len(handler.os_util.get_users()))
        self.assertEqual(0, len(error_messages))
        self.assertEqual(0, len(info_messages))

    @patch('azurelinuxagent.common.logger.Logger.info', side_effect=log_info)
    @patch('azurelinuxagent.common.logger.Logger.error', side_effect=log_error)
    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="")
    def test_add_user_bad_password_data(self, _1, _2, _3):
        """An empty decrypted password makes add_user fail, clean up the partly created user and log the deletion."""
        import re
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        tstpassword = ""
        tstuser = "******"
        expiration = datetime.utcnow() + timedelta(days=1)
        pwd = tstpassword
        # Match the message literally: the username is interpolated verbatim and may contain regex
        # metacharacters (e.g. '*'), which would otherwise corrupt or invalidate the pattern.
        error = re.escape("Error adding user {0} cleanup successful\nInner error: test exception for bad password".format(
            tstuser))
        self.assertRaisesRegex(RemoteAccessError, error, rah.add_user, tstuser,
                               pwd, expiration)
        self.assertEqual(0, len(rah.os_util.get_users()))
        self.assertEqual(0, len(error_messages))
        self.assertEqual(1, len(info_messages))
        self.assertEqual("User deleted {0}".format(tstuser), info_messages[0])

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_add_user_already_existing(self, _):
        """Adding a user that already exists raises RemoteAccessError and leaves the original account unchanged."""
        handler = RemoteAccessHandler(Mock())
        handler.os_util = MockOSUtil()
        username = "******"
        password = "******"
        first_expiration = datetime.utcnow() + timedelta(days=1)
        expected_expiry = (first_expiration + timedelta(days=1)).strftime("%Y-%m-%d")
        handler.add_user(username, password, first_expiration)
        existing = get_user_dictionary(handler.os_util.get_users())
        self.assertTrue(username in existing, "{0} missing from users".format(username))
        self.assertEqual(1, len(existing.keys()))
        self.assertEqual(existing[username][7], expected_expiry)
        # The OS-level user add is mocked, so this does not exercise it; it checks that processing
        # stops after the initial failure instead of overwriting the existing account.
        second_expiration = datetime.utcnow() + timedelta(days=5)
        self.assertRaises(RemoteAccessError, handler.add_user, username, password, second_expiration)
        existing = get_user_dictionary(handler.os_util.get_users())
        self.assertTrue(username in existing,
                        "{0} missing from users after dup user attempted".format(username))
        self.assertEqual(1, len(existing.keys()))
        self.assertEqual(existing[username][7], expected_expiry)

    # delete_user tests
    @patch('azurelinuxagent.common.logger.Logger.info', side_effect=log_info)
    @patch('azurelinuxagent.common.logger.Logger.error', side_effect=log_error)
    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_delete_user(self, _1, _2, _3):
        """delete_user removes a previously added user; the add and the delete each log one info message."""
        handler = RemoteAccessHandler(Mock())
        handler.os_util = MockOSUtil()
        username = "******"
        password = "******"
        requested = datetime.utcnow() + timedelta(days=1)
        expiry_day = (requested + timedelta(days=1)).strftime("%Y-%m-%d")
        handler.add_user(username, password, requested)
        self.assertTrue(username in get_user_dictionary(handler.os_util.get_users()),
                        "{0} missing from users".format(username))
        handler.delete_user(username)
        self.assertFalse(username in get_user_dictionary(handler.os_util.get_users()))
        self.assertEqual(0, len(error_messages))
        self.assertEqual(2, len(info_messages))
        added_line = "User '{0}' added successfully with expiration in {1}".format(username, expiry_day)
        deleted_line = "User deleted {0}".format(username)
        self.assertEqual(added_line, info_messages[0])
        self.assertEqual(deleted_line, info_messages[1])

    def test_handle_failed_create_with_bad_data(self):
        """handle_failed_create with an empty username raises RemoteAccessError and leaves other users intact."""
        os_util = MockOSUtil()
        existing_name = "******"
        # MockOSUtil stores users as 8-tuples; only the name slot is populated here.
        os_util.all_users[existing_name] = (existing_name,) + (None,) * 7
        handler = RemoteAccessHandler(Mock())
        handler.os_util = os_util
        self.assertRaises(RemoteAccessError, handler.handle_failed_create, "")
        remaining = get_user_dictionary(handler.os_util.get_users())
        self.assertEqual(1, len(remaining.keys()))
        self.assertTrue(existing_name in remaining, "Expected user {0} missing".format(existing_name))

    @patch('azurelinuxagent.common.logger.Logger.info', side_effect=log_info)
    @patch('azurelinuxagent.common.logger.Logger.error', side_effect=log_error)
    def test_delete_user_does_not_exist(self, _1, _2):
        """Cleaning up a user that does not exist raises RemoteAccessError, keeps other users and logs nothing."""
        import re
        mock_os_util = MockOSUtil()
        testusr = "******"
        mock_os_util.all_users[testusr] = (testusr, None, None, None, None,
                                           None, None, None)
        rah = RemoteAccessHandler(Mock())
        rah.os_util = mock_os_util
        testuser = "******"
        # Match the message literally: the username is interpolated verbatim and may contain regex
        # metacharacters (e.g. '*'), which would otherwise corrupt or invalidate the pattern.
        error = re.escape("Failed to clean up after account creation for {0}.\n"
                          "Inner error: test exception, user does not exist to delete".format(testuser))
        self.assertRaisesRegex(RemoteAccessError, error,
                               rah.handle_failed_create, testuser)
        users = get_user_dictionary(rah.os_util.get_users())
        self.assertEqual(1, len(users.keys()))
        self.assertTrue(testusr in users,
                        "Expected user {0} missing".format(testusr))
        self.assertEqual(0, len(error_messages))
        self.assertEqual(0, len(info_messages))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_handle_new_user(self, _):
        """The single goal-state account is created as a JIT_Account expiring the day after its expiration."""
        handler = RemoteAccessHandler(Mock())
        handler.os_util = MockOSUtil()
        remote_access = RemoteAccess(load_data('wire/remote_access_single_account.xml'))
        account = remote_access.user_list.users[0]
        username = account.name
        requested = datetime.utcnow() + timedelta(days=1)
        account.expiration = requested.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
        handler.remote_access = remote_access
        handler.handle_remote_access()
        created = get_user_dictionary(handler.os_util.get_users())
        self.assertTrue(username in created, "{0} missing from users".format(username))
        entry = created[username]
        self.assertEqual(entry[7], (requested + timedelta(days=1)).strftime("%Y-%m-%d"))
        self.assertEqual(entry[4], "JIT_Account")

    def test_do_not_add_expired_user(self):
        """An account whose expiration is already in the past is not created."""
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        data_str = load_data('wire/remote_access_single_account.xml')
        remote_access = RemoteAccess(data_str)
        # Take the name from the fixture rather than hard-coding it, so the assertion cannot pass
        # vacuously if the fixture's account name ever differs.
        tstuser = remote_access.user_list.users[0].name
        expiration = (datetime.utcnow() - timedelta(days=2)
                      ).strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
        remote_access.user_list.users[0].expiration = expiration
        rah.remote_access = remote_access
        rah.handle_remote_access()
        users = get_user_dictionary(rah.os_util.get_users())
        self.assertFalse(tstuser in users, "expired user {0} was added".format(tstuser))

    @patch('azurelinuxagent.common.logger.Logger.info', side_effect=log_info)
    @patch('azurelinuxagent.common.logger.Logger.error', side_effect=log_error)
    def test_error_add_user(self, _1, _2):
        """A password that cannot be decoded makes add_user fail, clean up the user and log the deletion."""
        import re
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        tstuser = "******"
        expiration = datetime.utcnow() + timedelta(days=1)
        pwd = "bad password"
        # Build the expected message for *tstuser* (the user the cleanup deletes, per the log assertion below)
        # and escape it, so brackets and any metacharacters in the username are matched literally.
        expected = "Error adding user {0} cleanup successful\n" \
                   "Inner error: [CryptError] Error decoding secret\n" \
                   "Inner error: Incorrect padding".format(tstuser)
        self.assertRaisesRegex(RemoteAccessError, re.escape(expected), rah.add_user, tstuser,
                               pwd, expiration)
        users = get_user_dictionary(rah.os_util.get_users())
        self.assertEqual(0, len(users))
        self.assertEqual(0, len(error_messages))
        self.assertEqual(1, len(info_messages))
        self.assertEqual("User deleted {0}".format(tstuser), info_messages[0])

    def test_handle_remote_access_no_users(self):
        """A goal state with no remote-access accounts creates no users."""
        handler = RemoteAccessHandler(Mock())
        handler.os_util = MockOSUtil()
        handler.remote_access = RemoteAccess(load_data('wire/remote_access_no_accounts.xml'))
        handler.handle_remote_access()
        current_users = get_user_dictionary(handler.os_util.get_users())
        self.assertEqual(0, len(current_users.keys()))

    def test_handle_remote_access_validate_jit_user_valid(self):
        """validate_jit_user accepts the 'JIT_Account' comment."""
        handler = RemoteAccessHandler(Mock())
        jit_comment = "JIT_Account"
        is_jit = handler.validate_jit_user(jit_comment)
        self.assertTrue(is_jit, "Did not identify '{0}' as a JIT_Account".format(jit_comment))

    def test_handle_remote_access_validate_jit_user_invalid(self):
        """validate_jit_user rejects comments other than 'JIT_Account' (including None, empty and blank).

        Every misidentified comment is collected so a single failing run reports all of them.
        """
        rah = RemoteAccessHandler(Mock())
        test_users = ["John Doe", None, "", " "]
        failures = ["incorrectly identified '{0}' as a JIT_Account.".format(user)
                    for user in test_users
                    if rah.validate_jit_user(user)]
        if failures:
            self.fail("  ".join(failures))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_multiple_users(self, _):
        """Both accounts of a two-account goal state are created."""
        handler = RemoteAccessHandler(Mock())
        handler.os_util = MockOSUtil()
        remote_access = RemoteAccess(load_data('wire/remote_access_two_accounts.xml'))
        expected_names = []
        for index in range(2):
            account = remote_access.user_list.users[index]
            expected_names.append(account.name)
            # Distinct future expirations: 1 and 2 days from now.
            expires_at = datetime.utcnow() + timedelta(days=index + 1)
            account.expiration = expires_at.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
        handler.remote_access = remote_access
        handler.handle_remote_access()
        current_users = get_user_dictionary(handler.os_util.get_users())
        for name in expected_names:
            self.assertTrue(name in current_users, "{0} missing from users".format(name))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_ten_users(self, _):
        """All ten accounts are created; ten is the maximum fabric supports in the goal state."""
        handler = RemoteAccessHandler(Mock())
        handler.os_util = MockOSUtil()
        remote_access = RemoteAccess(load_data('wire/remote_access_10_accounts.xml'))
        for number, account in enumerate(remote_access.user_list.users, 1):
            account.name = "tstuser{0}".format(number)
            expires_at = datetime.utcnow() + timedelta(days=number)
            account.expiration = expires_at.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
        handler.remote_access = remote_access
        handler.handle_remote_access()
        current_users = get_user_dictionary(handler.os_util.get_users())
        self.assertEqual(10, len(current_users.keys()))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_user_removed(self, _):
        """Accounts are deleted once they disappear from the goal-state user list.

        Ten accounts are created first. The goal-state user list is then emptied and processed
        again, and no account should remain.
        """
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        data_str = load_data('wire/remote_access_10_accounts.xml')
        remote_access = RemoteAccess(data_str)
        count = 0
        for user in remote_access.user_list.users:
            count += 1
            user.name = "tstuser{0}".format(count)
            expiration_date = datetime.utcnow() + timedelta(days=count)
            user.expiration = expiration_date.strftime(
                "%a, %d %b %Y %H:%M:%S ") + "UTC"
        rah.remote_access = remote_access
        rah.handle_remote_access()
        users = get_user_dictionary(rah.os_util.get_users())
        self.assertEqual(10, len(users.keys()))
        # empty the goal-state user list and process it again
        del rah.remote_access.user_list.users[:]
        rah.handle_remote_access()
        # re-read the users: the dictionary above is a snapshot and cannot reflect deletions
        users = get_user_dictionary(rah.os_util.get_users())
        self.assertEqual(0, len(users.keys()))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_bad_data_and_good_data(self, _):
        """One account with an empty name does not prevent the other nine from being created."""
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        remote_access = RemoteAccess(load_data('wire/remote_access_10_accounts.xml'))
        for day, account in enumerate(remote_access.user_list.users, 1):
            # the second account is given an invalid (empty) name
            account.name = "" if day == 2 else "tstuser{0}".format(day)
            expires_at = datetime.utcnow() + timedelta(days=day)
            account.expiration = expires_at.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
        rah.remote_access = remote_access
        rah.handle_remote_access()
        created = get_user_dictionary(rah.os_util.get_users())
        self.assertEqual(9, len(created.keys()))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_deleted_user_readded(self, _):
        """A goal-state user wiped from the OS is re-created on the next handling pass."""
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        remote_access = RemoteAccess(load_data('wire/remote_access_single_account.xml'))
        account = remote_access.user_list.users[0]
        tstuser = account.name
        expires_at = datetime.utcnow() + timedelta(days=1)
        account.expiration = expires_at.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"
        rah.remote_access = remote_access

        rah.handle_remote_access()
        self.assertTrue(tstuser in get_user_dictionary(rah.os_util.get_users()),
                        "{0} missing from users".format(tstuser))

        # remove every user behind the handler's back
        rah.os_util.__class__ = MockOSUtil
        rah.os_util.all_users.clear()
        self.assertTrue(tstuser not in get_user_dictionary(rah.os_util.get_users()))

        # processing the same goal state again must bring the user back
        rah.handle_remote_access()
        self.assertTrue(tstuser in get_user_dictionary(rah.os_util.get_users()),
                        "{0} missing from users".format(tstuser))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    @patch('azurelinuxagent.common.osutil.get_osutil',
           return_value=MockOSUtil())
    @patch('azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol',
           return_value=WireProtocol("12.34.56.78"))
    @patch('azurelinuxagent.common.protocol.wire.WireProtocol.get_incarnation',
           return_value="1")
    @patch('azurelinuxagent.common.protocol.wire.WireClient.get_remote_access',
           return_value="asdf")
    def test_remote_access_handler_run_bad_data(self, _1, _2, _3, _4, _5):
        """An existing user survives run() when the remote-access payload is malformed.

        WireClient.get_remote_access is patched to return the string "asdf" instead of a
        RemoteAccess object.
        """
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        tstpassword = "******"
        tstuser = "******"
        expiration_date = datetime.utcnow() + timedelta(days=1)
        rah.add_user(tstuser, tstpassword, expiration_date)
        users = get_user_dictionary(rah.os_util.get_users())
        self.assertTrue(tstuser in users,
                        "{0} missing from users".format(tstuser))
        rah.run()
        # re-read the users: the dictionary above was taken before run() and cannot change
        users = get_user_dictionary(rah.os_util.get_users())
        self.assertTrue(tstuser in users,
                        "{0} missing from users".format(tstuser))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_multiple_users_one_removed(self, _):
        """Dropping one account from the goal state deletes exactly that user.

        Ten accounts are created first. The fourth is then removed from the goal-state list,
        which should leave nine users and no trace of the removed name.
        """
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        data_str = load_data('wire/remote_access_10_accounts.xml')
        remote_access = RemoteAccess(data_str)
        count = 0
        for user in remote_access.user_list.users:
            count += 1
            user.name = "tstuser{0}".format(count)
            expiration_date = datetime.utcnow() + timedelta(days=count)
            user.expiration = expiration_date.strftime(
                "%a, %d %b %Y %H:%M:%S ") + "UTC"
        rah.remote_access = remote_access
        rah.handle_remote_access()
        users = rah.os_util.get_users()
        self.assertEqual(10, len(users))
        # now remove the user from RemoteAccess; keep its name, since get_users() entries are
        # compared by name (first element), not against the goal-state user object
        deleted_user = rah.remote_access.user_list.users[3].name
        del rah.remote_access.user_list.users[3]
        rah.handle_remote_access()
        users = rah.os_util.get_users()
        self.assertTrue(deleted_user not in [u[0] for u in users],
                        "{0} still in users".format(deleted_user))
        self.assertEqual(9, len(users))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_multiple_users_null_remote_access(self, _):
        """Setting remote_access to None removes all ten previously created users."""
        handler = RemoteAccessHandler(Mock())
        handler.os_util = MockOSUtil()
        goal_state_access = RemoteAccess(load_data('wire/remote_access_10_accounts.xml'))
        for day, account in enumerate(goal_state_access.user_list.users, 1):
            account.name = "tstuser{0}".format(day)
            account.expiration = (datetime.utcnow() + timedelta(days=day)).strftime(
                "%a, %d %b %Y %H:%M:%S ") + "UTC"

        handler.remote_access = goal_state_access
        handler.handle_remote_access()
        self.assertEqual(10, len(handler.os_util.get_users()))

        # drop the remote access configuration entirely
        handler.remote_access = None
        handler.handle_remote_access()
        self.assertEqual(0, len(handler.os_util.get_users()))

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_multiple_users_error_with_null_remote_access(
            self, _):
        """Ten users are created, then all are removed once remote_access becomes None.

        NOTE(review): this body matches the plain null-remote-access test; nothing here
        injects an error despite the test name. Confirm the intended scenario.
        """
        def expiry_in(days):
            moment = datetime.utcnow() + timedelta(days=days)
            return moment.strftime("%a, %d %b %Y %H:%M:%S ") + "UTC"

        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        remote_access = RemoteAccess(load_data('wire/remote_access_10_accounts.xml'))
        for position, account in enumerate(remote_access.user_list.users, 1):
            account.name = "tstuser{0}".format(position)
            account.expiration = expiry_in(position)

        rah.remote_access = remote_access
        rah.handle_remote_access()
        self.assertEqual(10, len(rah.os_util.get_users()))

        rah.remote_access = None
        rah.handle_remote_access()
        self.assertEqual(0, len(rah.os_util.get_users()))

    def test_remove_user_error(self):
        """remove_user("") raises RemoteAccessError whose text includes the inner OS-util error."""
        import re
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        error = "Failed to delete user {0}\nInner error: test exception, bad data".format(
            "")
        # assertRaisesRegex interprets its argument as a pattern; the expected text is a literal
        # message, so escape it to keep metacharacters from changing what is matched
        self.assertRaisesRegex(RemoteAccessError, re.escape(error), rah.remove_user, "")

    def test_remove_user_not_exists(self):
        """Removing an unknown user raises RemoteAccessError whose text includes the inner error."""
        import re
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        user = "******"
        error = "Failed to delete user {0}\n" \
                "Inner error: test exception, user does not exist to delete".format(user)
        # assertRaisesRegex interprets its argument as a pattern; the user name contains '*',
        # which unescaped is an invalid pattern ("multiple repeat"), so match the message literally
        self.assertRaisesRegex(RemoteAccessError, re.escape(error), rah.remove_user, user)

    @patch('azurelinuxagent.common.utils.cryptutil.CryptUtil.decrypt_secret',
           return_value="]aPPEv}uNg1FPnl?")
    def test_handle_remote_access_remove_and_add(self, _):
        """Renaming a goal-state account replaces the old user with the new one.

        Ten accounts are created first. The fourth account is then renamed. After processing,
        the old name must be gone, the new name present, and the total still ten.
        """
        rah = RemoteAccessHandler(Mock())
        rah.os_util = MockOSUtil()
        data_str = load_data('wire/remote_access_10_accounts.xml')
        remote_access = RemoteAccess(data_str)
        count = 0
        for user in remote_access.user_list.users:
            count += 1
            user.name = "tstuser{0}".format(count)
            expiration_date = datetime.utcnow() + timedelta(days=count)
            user.expiration = expiration_date.strftime(
                "%a, %d %b %Y %H:%M:%S ") + "UTC"
        rah.remote_access = remote_access
        rah.handle_remote_access()
        users = rah.os_util.get_users()
        self.assertEqual(10, len(users))
        # rename one goal-state account; capture the old name first because the same object
        # is mutated by the rename
        new_user = "******"
        renamed_account = rah.remote_access.user_list.users[3]
        deleted_user = renamed_account.name
        renamed_account.name = new_user
        rah.handle_remote_access()
        users = rah.os_util.get_users()
        # get_users() entries carry the user name as their first element
        user_names = [u[0] for u in users]
        self.assertTrue(deleted_user not in user_names,
                        "{0} still in users".format(deleted_user))
        self.assertTrue(new_user in user_names,
                        "user {0} not in users".format(new_user))
        self.assertEqual(10, len(users))

    @patch('azurelinuxagent.ga.remoteaccess.add_event',
           side_effect=mock_add_event)
    def test_remote_access_handler_run_error(self, _):
        """An exception raised while polling the goal state is reported as a failed event.

        get_incarnation is made to raise RemoteAccessError("foobar!"). add_event is patched
        with mock_add_event, which presumably records its arguments in
        TestRemoteAccessHandler.eventing_data (confirm against mock_add_event).
        """
        mock_protocol = WireProtocol("foo.bar")
        mock_protocol.get_incarnation = MagicMock(
            side_effect=RemoteAccessError("foobar!"))

        rah = RemoteAccessHandler(mock_protocol)
        rah.os_util = MockOSUtil()
        rah.run()

        # NOTE(review): index 4 is assumed to be the event message and index 2 the is_success
        # flag, per the assertion messages below; confirm against mock_add_event
        check_message = "foobar!"
        event_message = TestRemoteAccessHandler.eventing_data[4]
        self.assertIn(check_message, event_message,
                      "expected message {0} not found in {1}".format(
                          check_message, event_message))
        self.assertEqual(False, TestRemoteAccessHandler.eventing_data[2],
                         "is_success is true")
Example #22
0
    def setUp(self):
        """Build a 'foo' extension handler (version 1.2.3) and redirect its directories.

        get_base_dir is patched to return the test's tmp_dir, and get_log_dir to return
        <tmp_dir>/log. time.sleep is routed through mock_sleep(0.01).
        """
        AgentTestCase.setUp(self)

        properties = ExtHandlerProperties()
        properties.version = "1.2.3"
        handler = ExtHandler(name='foo')
        handler.properties = properties
        self.ext_handler = handler
        self.ext_handler_instance = ExtHandlerInstance(ext_handler=handler,
                                                       protocol=WireProtocol("1.2.3.4"))

        self.log_dir = os.path.join(self.tmp_dir, "log")

        # the lambdas read self.tmp_dir / self.log_dir at call time
        self.mock_get_base_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_base_dir",
                                       lambda *_: self.tmp_dir)
        self.mock_get_base_dir.start()

        self.mock_get_log_dir = patch("azurelinuxagent.ga.exthandlers.ExtHandlerInstance.get_log_dir",
                                      lambda *_: self.log_dir)
        self.mock_get_log_dir.start()

        self.mock_sleep = patch("time.sleep", lambda *_: mock_sleep(0.01))
        self.mock_sleep.start()
Example #23
0
    def test_it_should_contain_all_helper_environment_variables(self):
        """Extension commands receive exactly the expected helper environment variables.

        The environment passed to subprocess.Popen, minus the agent's own os.environ, must
        equal the helper set. The launched script must also see each variable with its value.
        """
        wire_ip = str(uuid.uuid4())
        ext_handler_instance = ExtHandlerInstance(ext_handler=self.ext_handler, protocol=WireProtocol(wire_ip))

        expected_env = {
            ExtCommandEnvVariable.ExtensionSeqNumber: _DEFAULT_SEQ_NO,
            ExtCommandEnvVariable.ExtensionPath: self.tmp_dir,
            ExtCommandEnvVariable.ExtensionVersion: ext_handler_instance.ext_handler.properties.version,
            ExtCommandEnvVariable.WireProtocolAddress: wire_ip,
        }

        # script that prints only the helper variables
        command = """
            printenv | grep -E '(%s)'
        """ % '|'.join(expected_env.keys())

        script_name = 'printHelperEnvironments.sh'
        self.create_script(os.path.join(self.ext_handler_instance.get_base_dir(), script_name), command)

        with patch("subprocess.Popen", wraps=subprocess.Popen) as patch_popen:
            # the supported-features list has its own test; stub it out with an empty value here
            with patch("azurelinuxagent.ga.exthandlers.get_agent_supported_features_list_for_extensions",
                       return_value={}):
                output = ext_handler_instance.launch_command(script_name)

            popen_kwargs = patch_popen.call_args[1]
            added_env = dict((name, value) for (name, value) in popen_kwargs['env'].items()
                             if name not in os.environ)

            # fails if a helper environment variable is later added or removed
            self.assertEqual(expected_env, added_env)

            # each helper variable must reach the launched command with the expected value
            for name, value in expected_env.items():
                self.assertIn("%s=%s" % (name, value), output)
Example #24
0
def mock_wire_protocol(mock_wire_data_file, http_get_handler=None, http_post_handler=None, http_put_handler=None, fail_on_unknown_request=True):
    """
    Creates a WireProtocol object that handles requests to the WireServer and the Host GA Plugin (i.e. requests on the WireServer endpoint), plus
    some requests to storage (requests on the fake server 'mock-goal-state').

    The data returned by those requests is read from the files specified by 'mock_wire_data_file' (which must follow the structure of the data
    files defined in tests/protocol/mockwiredata.py).

    The caller can also provide handler functions for specific HTTP methods using the http_*_handler arguments. The return value of the handler
    function is interpreted similarly to the "return_value" argument of patch(): if it is an exception the exception is raised or, if it is
    any object other than None, the value is returned by the mock. If the handler function returns None the call is handled using the mock
    wireserver data or passed to the original restutil.http_request.

    The returned protocol object maintains a list of "tracked" urls. When a handler function returns a value that is not None the url for the
    request is automatically added to the tracked list. The handler function can add other items to this list using the track_url() method on
    the mock.

    The return value of this function is an instance of WireProtocol augmented with these properties/methods:

        * mock_wire_data - the WireProtocolData constructed from the mock_wire_data_file parameter.
        * start() - starts the patchers for http_request and CryptUtil
        * stop() - stops the patchers; calling it more than once is harmless
        * track_url(url) - adds the given item to the list of tracked urls.
        * get_tracked_urls() - returns the list of tracked urls.
        * set_http_handlers(http_get_handler=None, http_post_handler=None, http_put_handler=None) - replaces the HTTP handlers
          and clears the list of tracked urls.

    NOTE: This function patches common.utils.restutil.http_request and common.protocol.wire.CryptUtil; you need to be aware of this if your
          tests patch those methods or others in the call stack (e.g. restutil.get, restutil._http_request, etc)

    """
    tracked_urls = []

    # use a helper function to keep the HTTP handlers (they need to be modified by set_http_handlers() and
    # Python 2.* does not support nonlocal declarations)
    def http_handlers(get, post, put):
        http_handlers.get = get
        http_handlers.post = post
        http_handlers.put = put
        del tracked_urls[:]
    http_handlers(get=http_get_handler, post=http_post_handler, put=http_put_handler)

    #
    # function used to patch restutil.http_request
    #
    original_http_request = restutil.http_request

    def http_request(method, url, data, **kwargs): # pylint: disable=too-many-branches
        # if there is a handler for the request, use it
        handler = None
        if method == 'GET':
            handler = http_handlers.get
        elif method == 'POST':
            handler = http_handlers.post
        elif method == 'PUT':
            handler = http_handlers.put

        if handler is not None:
            if method == 'GET':
                return_value = handler(url, **kwargs)
            else:
                return_value = handler(url, data, **kwargs)
            if return_value is not None:
                tracked_urls.append(url)
                if isinstance(return_value, Exception):
                    raise return_value
                return return_value

        # if the request was not handled try to use the mock wireserver data
        if _USE_MOCK_WIRE_DATA_RE.match(url) is not None:
            if method == 'GET':
                return protocol.mock_wire_data.mock_http_get(url, **kwargs)
            if method == 'POST':
                return protocol.mock_wire_data.mock_http_post(url, data, **kwargs)
            if method == 'PUT':
                return protocol.mock_wire_data.mock_http_put(url, data, **kwargs)

        # the request was not handled; fail or call the original restutil.http_request
        if fail_on_unknown_request:
            raise ValueError('Unknown HTTP request: {0} [{1}]'.format(url, method))
        return original_http_request(method, url, data, **kwargs)

    #
    # functions to start/stop the mocks
    #
    def start():
        patched = patch("azurelinuxagent.common.utils.restutil.http_request", side_effect=http_request)
        patched.start()
        start.http_request_patch = patched

        patched = patch("azurelinuxagent.common.protocol.wire.CryptUtil", side_effect=protocol.mock_wire_data.mock_crypt_util)
        patched.start()
        start.crypt_util_patch = patched
    start.http_request_patch = None
    start.crypt_util_patch = None

    def stop():
        # Clear each reference after stopping so that a second stop() (e.g. an explicit call by the
        # test followed by the 'finally' below) is a no-op; before Python 3.8 (and in the 'mock'
        # backport) stopping an already-stopped patcher raises RuntimeError.
        if start.crypt_util_patch is not None:
            start.crypt_util_patch.stop()
            start.crypt_util_patch = None
        if start.http_request_patch is not None:
            start.http_request_patch.stop()
            start.http_request_patch = None

    #
    # create the protocol object
    #
    protocol = WireProtocol(restutil.KNOWN_WIRESERVER_IP)
    protocol.mock_wire_data = mockwiredata.WireProtocolData(mock_wire_data_file)
    protocol.start = start
    protocol.stop = stop
    protocol.track_url = lambda url: tracked_urls.append(url) # pylint: disable=unnecessary-lambda
    protocol.get_tracked_urls = lambda: tracked_urls
    protocol.set_http_handlers = lambda http_get_handler=None, http_post_handler=None, http_put_handler=None:\
        http_handlers(get=http_get_handler, post=http_post_handler, put=http_put_handler)

    # go do it
    try:
        protocol.start()
        protocol.detect()
        yield protocol
    finally:
        protocol.stop()
Example #25
0
        self.assertNotIn("PerfMetrics", message_json)

        message_json = generate_extension_metrics_telemetry_dictionary(schema_version=2.0,
                                                                       performance_metrics=None)
        self.assertEqual(message_json, None)

        message_json = generate_extension_metrics_telemetry_dictionary(schema_version="z",
                                                                       performance_metrics=None)
        self.assertEqual(message_json, None)


@patch('azurelinuxagent.common.event.EventLogger.add_event')
@patch("azurelinuxagent.common.utils.restutil.http_post")
@patch("azurelinuxagent.common.utils.restutil.http_get")
@patch('azurelinuxagent.common.protocol.wire.WireClient.get_goal_state')
# NOTE: this WireProtocol instance is created once, when the class is defined, and is shared by every test
@patch('azurelinuxagent.common.protocol.util.ProtocolUtil.get_protocol', return_value=WireProtocol('endpoint'))
class TestMonitorFailure(AgentTestCase):
    """Tests for monitor-handler failure paths.

    The class-level patches apply to every test method. Their mocks are passed after any
    method-level mocks, bottom-up. So in each test, *args is:
    (get_protocol, get_goal_state, http_get, http_post, add_event).
    """

    @patch("azurelinuxagent.common.protocol.healthservice.HealthService.report_host_plugin_heartbeat")
    def test_error_heartbeat_creates_no_signal(self, patch_report_heartbeat, *args):
        """A host plugin heartbeat whose HTTP GET raises IOError is exercised.

        The expectation is that no health report is sent.
        """
        patch_http_get = args[2]  # restutil.http_get mock (see class docstring for ordering)
        patch_add_event = args[4]  # EventLogger.add_event mock

        monitor_handler = get_monitor_handler()
        monitor_handler.init_protocols()
        # make the last heartbeat an hour old, presumably so the next heartbeat is due immediately
        monitor_handler.last_host_plugin_heartbeat = datetime.datetime.utcnow() - timedelta(hours=1)

        patch_http_get.side_effect = IOError('client error')
        monitor_handler.send_host_plugin_heartbeat()

        # health report should not be made
        # NOTE(review): the assertions that should follow are missing here; neither
        # patch_report_heartbeat nor patch_add_event is ever checked in this body.
Example #26
0
    def _test_getters(self, test_data, mock_restutil, MockCryptUtil, _):
        mock_restutil.http_get.side_effect = test_data.mock_http_get
        MockCryptUtil.side_effect = test_data.mock_crypt_util

        protocol = WireProtocol("foo.bar")
        protocol.detect()
        protocol.get_vminfo()
        protocol.get_certs()
        ext_handlers, etag = protocol.get_ext_handlers()
        for ext_handler in ext_handlers.extHandlers:
            protocol.get_ext_handler_pkgs(ext_handler)

        crt1 = os.path.join(self.tmp_dir,
                            '33B0ABCE4673538650971C10F7D7397E71561F35.crt')
        crt2 = os.path.join(self.tmp_dir,
                            '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.crt')
        prv2 = os.path.join(self.tmp_dir,
                            '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.prv')

        self.assertTrue(os.path.isfile(crt1))
        self.assertTrue(os.path.isfile(crt2))
        self.assertTrue(os.path.isfile(prv2))