예제 #1
0
    def setDeviceRedirectionComponent(self, deviceRedirection: DeviceRedirectionMITM):
        """
        Replace the device redirection component this object observes.

        The previously observed component (if truthy) has this object removed from its
        observers; the new component (if truthy) gets this object added as an observer.
        The new reference is stored afterwards.
        :param deviceRedirection: component to observe; a falsy value only detaches.
        """
        previous = self.deviceRedirection
        if previous:
            previous.removeObserver(self)

        if deviceRedirection:
            deviceRedirection.addObserver(self)

        # Stored last, matching the original ordering relative to addObserver.
        self.deviceRedirection = deviceRedirection
예제 #2
0
 def setUp(self):
     """Build a DeviceRedirectionMITM whose collaborators are all mocks."""
     self.client, self.server = Mock(), Mock()
     self.log = Mock()
     self.statCounter = Mock()
     self.state = Mock()

     config = MagicMock()
     config.outDir = Path("/tmp")
     self.state.config = config

     self.mitm = DeviceRedirectionMITM(
         self.client, self.server, self.log, self.statCounter, self.state)
예제 #3
0
    def buildDeviceChannel(self, client: MCSServerChannel, server: MCSClientChannel):
        """
        Build the MITM component for the device redirection channel.

        Each side gets a security layer, a virtual channel layer and a device redirection
        layer chained onto its MCS channel; both device redirection layers get a
        LayerLogger, and a DeviceRedirectionMITM joining them is registered under the
        client channel ID. The crawler (when enabled in config) and the attacker (when
        set) are pointed at the new component.
        :param client: MCS channel for the client side
        :param server: MCS channel for the server side
        """
        channelName = MCSChannelName.DEVICE_REDIRECTION

        def makeStack(mode):
            # Tuple items evaluate left to right: security, virtual channel, redirection.
            return (
                self.state.createSecurityLayer(mode, True),
                VirtualChannelLayer(activateShowProtocolFlag=False),
                DeviceRedirectionLayer(),
            )

        # The client-facing side parses as SERVER and the server-facing side as CLIENT.
        clientStack = makeStack(ParserMode.SERVER)
        serverStack = makeStack(ParserMode.CLIENT)
        clientLayer = clientStack[-1]
        serverLayer = serverStack[-1]

        clientLayer.addObserver(LayerLogger(self.getClientLog(channelName)))
        serverLayer.addObserver(LayerLogger(self.getServerLog(channelName)))

        LayerChainItem.chain(client, *clientStack)
        LayerChainItem.chain(server, *serverStack)

        deviceRedirection = DeviceRedirectionMITM(
            clientLayer,
            serverLayer,
            self.getLog(channelName),
            self.config,
            self.statCounter,
            self.state,
        )
        self.channelMITMs[client.channelID] = deviceRedirection

        # self.crawler is only touched when the crawler is enabled.
        if self.config.enableCrawler:
            self.crawler.setDeviceRedirectionComponent(deviceRedirection)

        if self.attacker:
            self.attacker.setDeviceRedirectionComponent(deviceRedirection)
예제 #4
0
    def setDeviceRedirectionComponent(
            self, deviceRedirection: DeviceRedirectionMITM):
        """
        Observe a new device redirection component; only one is observed at a time.

        The previously observed component (if any) is detached first. When neither match
        patterns nor ignore patterns are set yet, preparePatterns() is called to load them.
        :param deviceRedirection: Reference to the component to observe.
        """
        current = self.deviceRedirection
        if current:
            current.removeObserver(self)

        if deviceRedirection:
            deviceRedirection.addObserver(self)
        self.deviceRedirection = deviceRedirection

        # De Morgan of "not match and not ignore"; short-circuits identically.
        patternsMissing = not (self.matchPatterns or self.ignorePatterns)
        if patternsMissing:
            self.preparePatterns()
예제 #5
0
 def setUp(self):
     """Create the forged directory-listing request under test."""
     # NOTE(review): the two zeros are presumably device ID / request ID and the Mock
     # stands in for the owning MITM; confirm against the ForgedRequest constructor.
     requestArgs = (0, 0, Mock(), "directory")
     self.request = DeviceRedirectionMITM.ForgedDirectoryListingRequest(*requestArgs)
예제 #6
0
 def setUp(self):
     """Create the forged file-read request under test."""
     # NOTE(review): the two zeros are presumably device ID / request ID and the Mock
     # stands in for the owning MITM; confirm against the ForgedRequest constructor.
     requestArgs = (0, 0, Mock(), "file")
     self.request = DeviceRedirectionMITM.ForgedFileReadRequest(*requestArgs)
예제 #7
0
 def setUp(self):
     """Create the base forged request under test."""
     mitmStandIn = Mock()
     self.request = DeviceRedirectionMITM.ForgedRequest(0, 0, mitmStandIn)
예제 #8
0
class DeviceRedirectionMITMTest(unittest.TestCase):
    """Unit tests for DeviceRedirectionMITM, driven entirely through mock collaborators."""

    def setUp(self):
        self.client = Mock()
        self.server = Mock()
        self.log = Mock()
        self.statCounter = Mock()
        self.state = Mock()
        self.state.config = MagicMock()
        self.state.config.outDir = Path("/tmp")
        self.mitm = DeviceRedirectionMITM(self.client, self.server, self.log,
                                          self.statCounter, self.state)

    @patch("pyrdp.mitm.FileMapping.FileMapping.generate")
    def sendCreateResponse(self, request, response, generate):
        """
        Helper (not a test): run handleCreateResponse with FileMapping.generate patched.
        :return: the patched generate mock, so callers can inspect its calls.
        """
        self.mitm.handleCreateResponse(request, response)
        return generate

    def test_stats(self):
        """Each handler bumps its dedicated counter in a real StatCounter."""
        self.mitm.handlePDU = Mock()
        self.mitm.statCounter = StatCounter()

        self.mitm.onClientPDUReceived(Mock())
        self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION],
                         1)
        self.assertEqual(
            self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_CLIENT], 1)

        self.mitm.onServerPDUReceived(Mock())
        self.assertEqual(self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION],
                         2)
        self.assertEqual(
            self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_SERVER], 1)

        self.mitm.handleIORequest(Mock())
        self.assertEqual(
            self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_IOREQUEST], 1)

        self.mitm.handleIOResponse(Mock())
        self.assertEqual(
            self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_IORESPONSE], 1)

        error = MockIOError()
        self.mitm.handleIORequest(error)
        self.mitm.handleIOResponse(error)
        self.assertEqual(
            self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_IOERROR], 1)

        self.mitm.handleCloseResponse(Mock(), Mock())
        self.assertEqual(
            self.mitm.statCounter.stats[STAT.DEVICE_REDIRECTION_FILE_CLOSE], 1)

        self.mitm.sendForgedFileRead(Mock(), Mock())
        self.assertEqual(
            self.mitm.statCounter.stats[
                STAT.DEVICE_REDIRECTION_FORGED_FILE_READ], 1)

        self.mitm.sendForgedDirectoryListing(Mock(), MagicMock())
        self.assertEqual(
            self.mitm.statCounter.stats[
                STAT.DEVICE_REDIRECTION_FORGED_DIRECTORY_LISTING], 1)

    def test_ioError_showsWarning(self):
        """An I/O error request/response pair produces exactly one warning."""
        self.log.warning = Mock()
        error = MockIOError()

        self.mitm.handleIORequest(error)
        self.mitm.handleIOResponse(error)
        self.log.warning.assert_called_once()

    def test_deviceListAnnounce_logsDevices(self):
        """Every announced device is logged and forwarded to the observer."""
        pdu = Mock()
        pdu.deviceList = [Mock(), Mock(), Mock()]

        self.mitm.observer = Mock()
        self.mitm.handleDeviceListAnnounceRequest(pdu)

        self.assertEqual(self.log.info.call_count, len(pdu.deviceList))
        self.assertEqual(self.mitm.observer.onDeviceAnnounce.call_count,
                         len(pdu.deviceList))

    def test_handleClientLogin_logsCredentials(self):
        """Credentials from either credentialsCandidate or inputBuffer are logged."""
        creds = "PASSWORD"
        self.log.info = Mock()

        self.state.credentialsCandidate = creds
        self.state.inputBuffer = ""
        self.mitm.handleClientLogin()
        self.log.info.assert_called_once()
        # assertIn reports the actual logged values on failure, unlike assertTrue(x in y).
        self.assertIn(creds, self.log.info.call_args[0][1].values())

        self.log.info.reset_mock()
        self.state.credentialsCandidate = ""
        self.state.inputBuffer = creds
        self.mitm.handleClientLogin()
        self.log.info.assert_called_once()
        self.assertIn(creds, self.log.info.call_args[0][1].values())

        # A USER_LOGGEDON PDU from the client must route to handleClientLogin.
        self.mitm.handleClientLogin = Mock()
        pdu = Mock(packetID=DeviceRedirectionPacketID.PAKID_CORE_USER_LOGGEDON)
        pdu.__class__ = DeviceRedirectionPDU

        self.mitm.handlePDU(pdu, self.client)
        self.mitm.handleClientLogin.assert_called_once()

    def test_handleIOResponse_uniqueResponse(self):
        """A response is dispatched once; a duplicate for the same request is dropped."""
        handler = Mock()
        self.mitm.responseHandlers[1234] = handler

        pdu = Mock(deviceID=0, completionID=0, majorFunction=1234, ioStatus=0)
        self.mitm.handleIORequest(pdu)
        self.mitm.handleIOResponse(pdu)
        handler.assert_called_once()

        # Second response should not go through
        self.mitm.handleIOResponse(pdu)
        handler.assert_called_once()

    def test_handleIOResponse_matchingOnly(self):
        """Responses whose deviceID or completionID mismatch are not dispatched and log an error."""
        handler = Mock()
        self.mitm.responseHandlers[1234] = handler

        request = Mock(deviceID=0, completionID=0)
        matching_response = Mock(deviceID=0,
                                 completionID=0,
                                 majorFunction=1234,
                                 ioStatus=0)
        bad_completionID = Mock(deviceID=0,
                                completionID=1,
                                majorFunction=1234,
                                ioStatus=0)
        bad_deviceID = Mock(deviceID=1,
                            completionID=0,
                            majorFunction=1234,
                            ioStatus=0)

        self.mitm.handleIORequest(request)
        self.mitm.handleIOResponse(matching_response)
        handler.assert_called_once()

        self.mitm.handleIORequest(request)

        self.mitm.handleIOResponse(bad_completionID)
        handler.assert_called_once()
        self.log.error.assert_called_once()
        self.log.error.reset_mock()

        self.mitm.handleIOResponse(bad_deviceID)
        handler.assert_called_once()
        self.log.error.assert_called_once()
        self.log.error.reset_mock()

    def test_handlePDU_hidesForgedResponses(self):
        """Responses to forged requests go to the forged handler and are not forwarded."""
        majorFunction = MajorFunction.IRP_MJ_CREATE
        handler = Mock()
        completionID = self.mitm.sendForgedFileRead(0, "forged")
        request = self.mitm.forgedRequests[(0, completionID)]
        request.handlers[majorFunction] = handler

        self.assertEqual(len(self.mitm.forgedRequests), 1)
        response = Mock(deviceID=0,
                        completionID=completionID,
                        majorFunction=majorFunction,
                        ioStatus=0)
        response.__class__ = DeviceIOResponsePDU
        self.mitm.handlePDU(response, self.mitm.server)
        handler.assert_called_once()
        self.mitm.server.sendPDU.assert_not_called()

    def test_handleCreateResponse_createsMapping(self):
        """A read-access create on a non-directory file creates one file mapping."""
        createRequest = Mock(
            deviceID=0,
            completionID=0,
            desiredAccess=(FileAccessMask.GENERIC_READ
                           | FileAccessMask.FILE_READ_DATA),
            createOptions=CreateOption.FILE_NON_DIRECTORY_FILE,
            path="file",
        )
        createResponse = Mock(deviceID=0, completionID=0, fileID=0)

        generate = self.sendCreateResponse(createRequest, createResponse)
        self.assertEqual(len(self.mitm.mappings), 1)
        generate.assert_called_once()

    def test_handleReadResponse_writesData(self):
        """Read responses write into the mapping only when the file ID matches."""
        request = Mock(
            deviceID=0,
            completionID=0,
            fileID=0,
            desiredAccess=(FileAccessMask.GENERIC_READ
                           | FileAccessMask.FILE_READ_DATA),
            createOptions=CreateOption.FILE_NON_DIRECTORY_FILE,
            path="file",
        )
        response = Mock(deviceID=0,
                        completionID=0,
                        fileID=0,
                        payload="test payload")
        self.mitm.saveMapping = Mock()

        self.sendCreateResponse(request, response)
        mapping = list(self.mitm.mappings.values())[0]
        mapping.write = Mock()

        self.mitm.handleReadResponse(request, response)
        mapping.write.assert_called_once()

        # Make sure it checks the file ID
        request.fileID, response.fileID = 1, 1
        self.mitm.handleReadResponse(request, response)
        mapping.write.assert_called_once()

    def test_handleCloseResponse_finalizesMapping(self):
        """Closing a mapped file finalizes its mapping."""
        request = Mock(
            deviceID=0,
            completionID=0,
            fileID=0,
            desiredAccess=(FileAccessMask.GENERIC_READ
                           | FileAccessMask.FILE_READ_DATA),
            createOptions=CreateOption.FILE_NON_DIRECTORY_FILE,
            path="file",
        )
        response = Mock(deviceID=0,
                        completionID=0,
                        fileID=0,
                        payload="test payload")
        self.mitm.saveMapping = Mock()

        self.sendCreateResponse(request, response)
        mapping = list(self.mitm.mappings.values())[0]
        mapping.finalize = Mock()

        self.mitm.handleCloseResponse(request, response)

        mapping.finalize.assert_called_once()

    def test_findNextRequestID_incrementsRequestID(self):
        """Each forged read consumes one request ID."""
        baseID = self.mitm.findNextRequestID()
        self.mitm.sendForgedFileRead(0, Mock())
        self.assertEqual(self.mitm.findNextRequestID(), baseID + 1)
        self.mitm.sendForgedFileRead(1, Mock())
        self.assertEqual(self.mitm.findNextRequestID(), baseID + 2)

    def test_sendForgedFileRead_failsWhenDisabled(self):
        """Forged file reads are refused when file extraction is disabled."""
        self.mitm.config.extractFiles = False
        self.assertFalse(self.mitm.sendForgedFileRead(1, "/test"))

    def test_sendForgedDirectoryListing_failsWhenDisabled(self):
        """Forged directory listings are refused when file extraction is disabled."""
        self.mitm.config.extractFiles = False
        self.assertFalse(self.mitm.sendForgedDirectoryListing(1, "/"))