Ejemplo n.º 1
0
    def test_ber_readLength(self):
        """
        @summary: test readLength function in ber module
                  small format: a single length byte (0x1a)
                  big format of size 1: 0x81 prefix followed by one length byte
                  big format of size 2: 0x82 prefix followed by a big-endian
                  16-bit length
        """
        s1 = type.Stream()
        s1.writeType(type.UInt8(0x1a))
        s1.pos = 0

        # assertEqual reports both values on failure, unlike assertTrue(a == b)
        self.assertEqual(ber.readLength(s1), 0x1a,
                         "readLength fail in small format")

        s2 = type.Stream()
        s2.writeType((type.UInt8(0x81), type.UInt8(0xab)))
        s2.pos = 0

        self.assertEqual(ber.readLength(s2), 0xab,
                         "readLength fail in big format of size 1")

        s3 = type.Stream()
        s3.writeType((type.UInt8(0x82), type.UInt16Be(0xabab)))
        s3.pos = 0

        self.assertEqual(ber.readLength(s3), 0xabab,
                         "readLength fail in big format of size 2")
Ejemplo n.º 2
0
    def test_x224_server_recvConnectionRequest_client_accept_ssl(self):
        """
        @summary: unit test for X224Server.recvConnectionRequest function
                  The client only offers PROTOCOL_HYBRID to a server built with
                  a key, a certificate and a trailing True flag (presumably
                  forceSSL, confirm against x224.Server). Any PDU the server
                  sends must be a TYPE_RDP_NEG_FAILURE / SSL_REQUIRED_BY_SERVER
                  connection confirm. Closing the transport counts as success.
        """
        class Transport(object):
            def send(self, data):
                # anything other than an "SSL required" refusal is a failure
                if not isinstance(data, x224.ServerConnectionConfirm):
                    raise X224Test.X224_FAIL()
                negotiation = data.protocolNeg
                isSslRefusal = (
                    negotiation.code.value == x224.NegociationType.TYPE_RDP_NEG_FAILURE
                    and negotiation.failureCode.value == x224.NegotiationFailureCode.SSL_REQUIRED_BY_SERVER)
                if not isSslRefusal:
                    raise X224Test.X224_FAIL()

            def close(self):
                # the server dropping the connection is the expected outcome
                raise X224Test.X224_PASS()

        request = x224.ClientConnectionRequestPDU()
        request.protocolNeg.selectedProtocol.value = x224.Protocols.PROTOCOL_HYBRID
        stream = type.Stream()
        stream.writeType(request)
        stream.pos = 0

        server = x224.Server(None, "key", "cert", True)
        server._transport = Transport()
        server.connect()

        self.assertRaises(X224Test.X224_PASS, server.recv, stream)
Ejemplo n.º 3
0
 def send(self, data):
     """
     @summary: stub transport send: re-serialize the outgoing PDU and check
               it is an X224 data header followed by the test payload
     """
     stream = type.Stream()
     stream.writeType(data)
     # rewind so the checks below read back what was just written
     stream.pos = 0
     header = x224.X224DataHeader()
     stream.readType(header)
     # constant=True presumably makes the read fail on a different payload
     payload = type.String('test_x224_layer_send', constant=True)
     stream.readType(payload)
     raise X224Test.X224_PASS()
    def test_tpkt_layer_recv_fastpath_ext_length(self):
        """
        @summary: a fast-path frame whose 16-bit length field carries the
                  0x8000 (extended length) flag must have its payload handed
                  to the fast-path listener
        """
        class FastPathLayer(tpkt.IFastPathListener):
            def setFastPathSender(self, fastPathSender):
                pass

            def recvFastPath(self, secflag, fastPathS):
                expected = type.String("test_tpkt_layer_recv_fastpath_ext_length",
                                       constant=True)
                fastPathS.readType(expected)
                raise TPKTTest.TPKT_PASS()

        payload = type.String("test_tpkt_layer_recv_fastpath_ext_length")

        # 1 action byte + 2 length bytes precede the payload; the length value
        # has 0x8000 set to exercise the extended-length form
        action = type.UInt8(tpkt.Action.FASTPATH_ACTION_FASTPATH)
        length = type.UInt16Be((type.sizeof(payload) + 3) | 0x8000)
        frame = type.Stream()
        frame.writeType((action, length, payload))

        tpktLayer = tpkt.TPKT(None)
        tpktLayer.initFastPath(FastPathLayer())
        tpktLayer.connect()
        self.assertRaises(TPKTTest.TPKT_PASS, tpktLayer.dataReceived,
                          frame.getvalue())
    def test_valid_client_licensing_error_message(self):
        """
        @summary: LicenseManager.recv must return a truthy value for the
                  message built by lic.createValidClientLicensingErrorMessage
        """
        manager = lic.LicenseManager(None)
        stream = type.Stream()
        stream.writeType(lic.createValidClientLicensingErrorMessage())
        # rewind before handing the stream to the manager
        stream.pos = 0

        self.assertTrue(manager.recv(stream), "Manager can retrieve valid case")
 def sendFlagged(self, flag, message):
     """
     @summary: stub transport: for SEC_LICENSE_PKT messages, check the payload
               decodes as a ClientNewLicenseRequest and record success in
               self._state. Other flags are ignored.
     """
     if flag == sec.SecurityFlag.SEC_LICENSE_PKT:
         stream = type.Stream()
         stream.writeType(message)
         stream.pos = 0
         stream.readType(lic.LicPacket(lic.ClientNewLicenseRequest()))
         self._state = True
Ejemplo n.º 7
0
            def send(self, data):
                """
                @summary: stub transport send: decode the outgoing PDU as a
                          ClientConnectionRequestPDU and fail unless its
                          negotiation code is TYPE_RDP_NEG_REQ
                """
                s = type.Stream()
                s.writeType(data)
                s.pos = 0
                t = x224.ClientConnectionRequestPDU()
                s.readType(t)

                # Compare the raw value, as the other x224 stubs do. Comparing
                # the wrapper object itself depends on its comparison overloads.
                if t.protocolNeg.code.value != x224.NegociationType.TYPE_RDP_NEG_REQ:
                    raise X224Test.X224_FAIL()
Ejemplo n.º 8
0
    def test_per_readLength(self):
        """
        @summary: test readLength function in per module
                  small format: a single length byte (0x1a)
                  big format: a big-endian 16-bit value with 0x8000 set; the
                  flag must not appear in the decoded length
        """
        s1 = type.Stream()
        s1.writeType(type.UInt8(0x1a))
        s1.pos = 0

        # assertEqual reports both values on failure, unlike assertTrue(a == b)
        self.assertEqual(per.readLength(s1), 0x1a,
                         "readLength fail in small format")

        s2 = type.Stream()
        s2.writeType(type.UInt16Be(0x1abc | 0x8000))
        s2.pos = 0

        self.assertEqual(per.readLength(s2), 0x1abc,
                         "readLength fail in big format")
Ejemplo n.º 9
0
    def test_per_readInteger(self):
        """
        @summary: test readInteger function in per module
                  nominal case: a PER length prefix followed by a UInt8,
                  UInt16Be or UInt32Be value decodes back to that value
                  error case: length prefixes 0, 3 and 5 must raise
                  error.InvalidValue
        """
        for intType in [type.UInt8, type.UInt16Be, type.UInt32Be]:
            v = intType(3)
            s = type.Stream()
            s.writeType((per.writeLength(type.sizeof(v)), v))
            s.pos = 0

            # message previously named readLength although readInteger is tested
            self.assertEqual(per.readInteger(s), 3,
                             "invalid readInteger for type %s" % intType)

        #error case
        for length in [0, 3, 5]:
            s = type.Stream()
            s.writeType(per.writeLength(length))
            s.pos = 0

            self.assertRaises(error.InvalidValue, per.readInteger, s)
Ejemplo n.º 10
0
 def test_x224_client_recvConnectionConfirm_negotiation_failure(self):
     """
     @summary: X224Client.recvConnectionConfirm must raise
               error.RDPSecurityNegoFail when the server connection confirm
               carries a TYPE_RDP_NEG_FAILURE negotiation code
     """
     confirm = x224.ServerConnectionConfirm()
     confirm.protocolNeg.code.value = x224.NegociationType.TYPE_RDP_NEG_FAILURE

     stream = type.Stream()
     stream.writeType(confirm)
     # rewind so the client reads the PDU from the start
     stream.pos = 0

     client = x224.Client(None)
     self.assertRaises(error.RDPSecurityNegoFail,
                       client.recvConnectionConfirm, stream)
Ejemplo n.º 11
0
 def test_x224_client_recvConnectionConfirm_negotiation_bad_protocol(self):
     """
     @summary: X224Client.recvConnectionConfirm must raise
               error.InvalidExpectedDataException when the server selects
               PROTOCOL_HYBRID, i.e. a protocol other than SSL or RDP
     """
     confirm = x224.ServerConnectionConfirm()
     confirm.protocolNeg.selectedProtocol.value = x224.Protocols.PROTOCOL_HYBRID

     stream = type.Stream()
     stream.writeType(confirm)
     # rewind so the client reads the PDU from the start
     stream.pos = 0

     client = x224.Client(None)
     self.assertRaises(error.InvalidExpectedDataException,
                       client.recvConnectionConfirm, stream)
Ejemplo n.º 12
0
 def test_x224_layer_recvData(self):
     """
     @summary: X224Layer.recvData given an X224 data header followed by a
               payload must pass a stream positioned on that payload to the
               presentation layer's recv
     """
     class Presentation(object):
         def recv(self, data):
             expected = type.String('test_x224_layer_recvData', constant = True)
             data.readType(expected)
             raise X224Test.X224_PASS()

     layer = x224.X224Layer(Presentation())

     frame = (x224.X224DataHeader(), type.String('test_x224_layer_recvData'))
     stream = type.Stream()
     stream.writeType(frame)
     # rewind so recvData starts at the header
     stream.pos = 0

     self.assertRaises(X224Test.X224_PASS, layer.recvData, stream)
Ejemplo n.º 13
0
    def test_x224_server_recvConnectionRequest_valid(self):
        """
        @summary:  unit test for X224Server.recvConnectionRequest function
                   Nominal case: the client offers PROTOCOL_SSL | PROTOCOL_RDP.
                   Any PDU the server sends must be a TYPE_RDP_NEG_RSP
                   confirm selecting PROTOCOL_SSL, TLS must be started on the
                   transport, and connect must be forwarded to the presentation.
        @note: x224.ServerTLSContext is replaced by a stub for the duration of
               the test and restored afterwards, so later tests see the real one
        """
        # module-level flags so the nested stub methods can set them
        global tls, connect_event
        tls = False
        connect_event = False

        class ServerTLSContext(object):
            def __init__(self, key, cert):
                pass

        # Remember what was there before patching. The sentinel distinguishes
        # "attribute absent" from "attribute set to None".
        missing = object()
        savedContext = getattr(x224, "ServerTLSContext", missing)
        x224.ServerTLSContext = ServerTLSContext
        try:
            class Transport(object):
                def __init__(self):
                    class TLS(object):
                        def startTLS(self, context):
                            global tls
                            tls = True

                    self.transport = TLS()

                def send(self, data):
                    if not isinstance(data, x224.ServerConnectionConfirm):
                        raise X224Test.X224_FAIL()
                    if data.protocolNeg.code.value != x224.NegociationType.TYPE_RDP_NEG_RSP or data.protocolNeg.selectedProtocol.value != x224.Protocols.PROTOCOL_SSL:
                        raise X224Test.X224_FAIL()

            class Presentation(object):
                def connect(self):
                    global connect_event
                    connect_event = True

            message = x224.ClientConnectionRequestPDU()
            message.protocolNeg.selectedProtocol.value = x224.Protocols.PROTOCOL_SSL | x224.Protocols.PROTOCOL_RDP
            s = type.Stream()
            s.writeType(message)
            s.pos = 0

            layer = x224.Server(Presentation(), "key", "cert")
            layer._transport = Transport()
            layer.connect()
            layer.recvConnectionRequest(s)
        finally:
            # undo the monkeypatch so the stub does not leak into other tests
            if savedContext is missing:
                del x224.ServerTLSContext
            else:
                x224.ServerTLSContext = savedContext

        self.assertTrue(tls, "TLS not started")
        self.assertTrue(connect_event, "connect event not forwarded")
Ejemplo n.º 14
0
    def test_x224_client_recvConnectionConfirm_ok(self):
        """
        @summary: nominal case of protocol negotiation
                  The server connection confirm selects PROTOCOL_SSL. After
                  X224Client.recvConnectionConfirm, the test checks three
                  things: TLS was started on the underlying transport, connect
                  was forwarded to the presentation layer, and a subsequent
                  layer.recv is routed to the layer's recvData.
        """
        # module-level flags so the nested stub methods below can set them
        global tls_begin, presentation_connect
        tls_begin = False
        presentation_connect = False

        # stub transport exposing .transport.startTLS, which records that
        # TLS was started
        class Transport(object):
            def __init__(self):
                class TLSTransport(object):
                    def startTLS(self, context):
                        global tls_begin
                        tls_begin = True

                self.transport = TLSTransport()

        # stub presentation layer recording that connect was forwarded
        class Presentation(object):
            def connect(self):
                global presentation_connect
                presentation_connect = True

        # replacement for layer.recvData: reaching it proves that data
        # received after the confirm is dispatched to recvData
        def recvData(data):
            raise X224Test.X224_PASS()

        message = x224.ServerConnectionConfirm()
        message.protocolNeg.selectedProtocol.value = x224.Protocols.PROTOCOL_SSL

        s = type.Stream()
        s.writeType(message)
        s.pos = 0
        layer = x224.Client(Presentation())
        layer._transport = Transport()
        # installed before recvConnectionConfirm, presumably so that the
        # layer's next receive state resolves to this stub (confirm in x224)
        layer.recvData = recvData

        layer.recvConnectionConfirm(s)

        self.assertTrue(tls_begin, "TLS is not started")
        self.assertTrue(presentation_connect, "connect event is not forwarded")
        self.assertRaises(X224Test.X224_PASS, layer.recv,
                          type.String('\x01\x02'))
Ejemplo n.º 15
0
    def test_new_license(self):
        """
        @summary: on the SERVERREQUEST license request, LicenseManager.recv
                  must return a falsy value and must send, flagged as
                  SEC_LICENSE_PKT, a packet that decodes as a
                  ClientNewLicenseRequest
        """
        class Transport(object):
            def __init__(self):
                # becomes True once a well-formed new license request is sent
                self._state = False

            def sendFlagged(self, flag, message):
                if flag != sec.SecurityFlag.SEC_LICENSE_PKT:
                    return
                s = type.Stream()
                s.writeType(message)
                s.pos = 0
                s.readType(lic.LicPacket(lic.ClientNewLicenseRequest()))
                self._state = True

        t = Transport()
        manager = lic.LicenseManager(t)

        s = type.Stream(SERVERREQUEST.decode("base64"))

        # Check both conditions separately: a combined
        # assertFalse(recv and _state) passes whenever recv returns False,
        # even if nothing was sent to the transport.
        finished = manager.recv(s)
        self.assertFalse(finished,
                         "License manager finished on a license request")
        self.assertTrue(t._state, "Bad message after license request")
    def test_tpkt_layer_recv(self):
        """
        @summary: a classic (FASTPATH_ACTION_X224) TPKT frame given to
                  TPKT.dataReceived must have its payload handed to the
                  presentation layer's recv
        """
        class Presentation(object):
            def connect(self):
                pass

            def recv(self, data):
                expected = type.String("test_tpkt_layer_recv", constant=True)
                data.readType(expected)
                raise TPKTTest.TPKT_PASS()

        payload = type.String("test_tpkt_layer_recv")

        # Header: action byte, a default-constructed UInt8, then a big-endian
        # 16-bit length covering the 4 header bytes plus the payload.
        header = (type.UInt8(tpkt.Action.FASTPATH_ACTION_X224),
                  type.UInt8(),
                  type.UInt16Be(type.sizeof(payload) + 4))
        frame = type.Stream()
        frame.writeType(header + (payload,))

        tpktLayer = tpkt.TPKT(Presentation())
        tpktLayer.connect()
        self.assertRaises(TPKTTest.TPKT_PASS, tpktLayer.dataReceived,
                          frame.getvalue())
    def test_new_license(self):
        """
        @summary: on the SERVERREQUEST license request, LicenseManager.recv
                  must return a falsy value and must send, flagged as
                  SEC_LICENSE_PKT, a packet that decodes as a
                  ClientNewLicenseRequest
        """
        class Transport(object):
            def __init__(self):
                # becomes True once a well-formed new license request is sent
                self._state = False

            def sendFlagged(self, flag, message):
                if flag != sec.SecurityFlag.SEC_LICENSE_PKT:
                    return
                s = type.Stream()
                s.writeType(message)
                s.pos = 0
                s.readType(lic.LicPacket(lic.ClientNewLicenseRequest()))
                self._state = True

            def getGCCServerSettings(self):
                # minimal stand-in exposing .SC_SECURITY.serverCertificate with
                # _is_readed = False; presumably the only part LicenseManager
                # reads here (confirm against lic)
                class A:
                    def __init__(self):
                        self._is_readed = False

                class B:
                    def __init__(self):
                        self.serverCertificate = A()

                class C:
                    def __init__(self):
                        self.SC_SECURITY = B()

                return C()

        t = Transport()
        manager = lic.LicenseManager(t)

        s = type.Stream(SERVERREQUEST.decode("base64"))

        # Check both conditions separately: a combined
        # assertFalse(recv and _state) passes whenever recv returns False,
        # even if nothing was sent to the transport.
        finished = manager.recv(s)
        self.assertFalse(finished,
                         "License manager finished on a license request")
        self.assertTrue(t._state, "Bad message after license request")