Ejemplo n.º 1
0
 def testUpdateSaAddMark(self):
     """Test that when an SA has no mark, it can be updated to add a mark.

     For each IP version, an SA is allocated with ALLOCSPI (no mark is passed),
     then updated through AddSaInfo(..., is_update=True) with an exact-match
     mark. The test checks that DumpSaInfo still reports exactly one SA and
     that its XFRMA_MARK attribute equals the new mark, then deletes the SA.
     """
     for version in [4, 6]:
         spi = 0xABCD
         wildcard = net_test.GetWildcardAddress(version)
         # Test that an SA created with ALLOCSPI can be updated with the mark.
         # The allocation result is not needed: the SA is inspected through
         # DumpSaInfo below.
         self.xfrm.AllocSpi(wildcard, IPPROTO_ESP, spi, spi)
         mark = xfrm.ExactMatchMark(0xf00d)
         # is_update=True modifies the allocated SA instead of adding a new one.
         self.xfrm.AddSaInfo(wildcard,
                             wildcard,
                             spi,
                             xfrm.XFRM_MODE_TUNNEL,
                             0,
                             xfrm_base._ALGO_CBC_AES_256,
                             xfrm_base._ALGO_HMAC_SHA1,
                             None,
                             None,
                             mark,
                             0,
                             is_update=True)
         dump = self.xfrm.DumpSaInfo()
         # A single entry means the update changed the existing SA rather
         # than creating a second one.
         self.assertEquals(1, len(dump))
         _, attributes = dump[0]
         self.assertEquals(mark, attributes["XFRMA_MARK"])
         self.xfrm.DeleteSaInfo(wildcard, spi, IPPROTO_ESP, mark)
Ejemplo n.º 2
0
def _CreateReceiveSock(version, port=0):
    """Create a UDP socket to receive packets, bound to the wildcard address.

    Args:
      version: IP version used to pick the address family and the wildcard
        address via net_test.
      port: port to bind to; 0 (the default) lets the kernel choose one.

    Returns:
      A (socket, local_port) tuple, where local_port is the bound port.

    If binding or configuring the socket fails, the socket is closed before
    the exception propagates so the descriptor is not leaked.
    """
    read_sock = socket(net_test.GetAddressFamily(version), SOCK_DGRAM, 0)
    try:
        read_sock.bind((net_test.GetWildcardAddress(version), port))
        # The second parameter of the tuple is the port number regardless of AF.
        local_port = read_sock.getsockname()[1]
        # Guard against the eventuality of the receive failing.
        net_test.SetNonBlocking(read_sock.fileno())
    except Exception:
        read_sock.close()
        raise

    return read_sock, local_port
Ejemplo n.º 3
0
    def ParamTestSocketPolicySimple(self, params):
        """Test two-way traffic using transport mode and socket policies.

        Two sockets ("left" and "right") exchange a request and a response
        over transport-mode ESP. Each direction has its own SA and socket
        policy, and a TapTwister validator asserts that no UDP or TCP packet
        crosses the tap in the clear.

        Args:
          params: dict with keys "version" (IP version), "proto" (socket type;
            SOCK_STREAM selects the TCP server, anything else the UDP one), and
            "crypt", "auth", "aead". Each of the last three is either falsy
            (algorithm unused) or an object exposing name and key_len, plus
            trunc_len for "auth" and icv_len for "aead".

        Raises:
          Whatever exception the server thread hit, re-raised after the
          thread has been joined.
        """
        def AssertEncrypted(packet):
            # This gives a free pass to ICMP and ICMPv6 packets, which show up
            # nondeterministically in tests.
            self.assertEquals(None, packet.getlayer(scapy.UDP),
                              "UDP packet sent in the clear")
            self.assertEquals(None, packet.getlayer(scapy.TCP),
                              "TCP packet sent in the clear")

        def KeyedAlgo(xfrm_cls, algo, *extra_fields):
            """Return (xfrm_cls(fields), random key), or None if algo is unset."""
            if not algo:
                return None
            fields = (algo.name, algo.key_len) + tuple(
                getattr(algo, field) for field in extra_fields)
            # key_len is divided by 8 to get a byte count. Floor division keeps
            # it an int on Python 3 (os.urandom rejects floats).
            return (xfrm_cls(fields), os.urandom(algo.key_len // 8))

        # We create a pair of sockets, "left" and "right", that will talk to each
        # other using transport mode ESP. Because of TapTwister, both sockets
        # perceive each other as owning "remote_addr".
        netid = self.RandomNetid()
        family = net_test.GetAddressFamily(params["version"])
        local_addr = self.MyAddress(params["version"], netid)
        remote_addr = self.GetRemoteSocketAddress(params["version"])
        # Each side gets independent random keys.
        crypt_left = KeyedAlgo(xfrm.XfrmAlgo, params["crypt"])
        crypt_right = KeyedAlgo(xfrm.XfrmAlgo, params["crypt"])
        auth_left = KeyedAlgo(xfrm.XfrmAlgoAuth, params["auth"], "trunc_len")
        auth_right = KeyedAlgo(xfrm.XfrmAlgoAuth, params["auth"], "trunc_len")
        aead_left = KeyedAlgo(xfrm.XfrmAlgoAead, params["aead"], "icv_len")
        aead_right = KeyedAlgo(xfrm.XfrmAlgoAead, params["aead"], "icv_len")
        spi_left = 0xbeefface
        spi_right = 0xcafed00d
        req_ids = [100, 200, 300, 400]  # Used to match templates and SAs.

        # Left outbound SA
        self.xfrm.AddSaInfo(src=local_addr,
                            dst=remote_addr,
                            spi=spi_right,
                            mode=xfrm.XFRM_MODE_TRANSPORT,
                            reqid=req_ids[0],
                            encryption=crypt_right,
                            auth_trunc=auth_right,
                            aead=aead_right,
                            encap=None,
                            mark=None,
                            output_mark=None)
        # Right inbound SA
        self.xfrm.AddSaInfo(src=remote_addr,
                            dst=local_addr,
                            spi=spi_right,
                            mode=xfrm.XFRM_MODE_TRANSPORT,
                            reqid=req_ids[1],
                            encryption=crypt_right,
                            auth_trunc=auth_right,
                            aead=aead_right,
                            encap=None,
                            mark=None,
                            output_mark=None)
        # Right outbound SA
        self.xfrm.AddSaInfo(src=local_addr,
                            dst=remote_addr,
                            spi=spi_left,
                            mode=xfrm.XFRM_MODE_TRANSPORT,
                            reqid=req_ids[2],
                            encryption=crypt_left,
                            auth_trunc=auth_left,
                            aead=aead_left,
                            encap=None,
                            mark=None,
                            output_mark=None)
        # Left inbound SA
        self.xfrm.AddSaInfo(src=remote_addr,
                            dst=local_addr,
                            spi=spi_left,
                            mode=xfrm.XFRM_MODE_TRANSPORT,
                            reqid=req_ids[3],
                            encryption=crypt_left,
                            auth_trunc=auth_left,
                            aead=aead_left,
                            encap=None,
                            mark=None,
                            output_mark=None)

        # Make two sockets.
        sock_left = socket(family, params["proto"], 0)
        sock_left.settimeout(2.0)
        sock_left.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        self.SelectInterface(sock_left, netid, "mark")
        sock_right = socket(family, params["proto"], 0)
        sock_right.settimeout(2.0)
        sock_right.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        self.SelectInterface(sock_right, netid, "mark")

        # For TCP, set SO_LINGER to 0, to prevent TCP sockets from hanging around
        # in a TIME_WAIT state.
        if params["proto"] == SOCK_STREAM:
            net_test.DisableFinWait(sock_left)
            net_test.DisableFinWait(sock_right)

        # Apply the left outbound socket policy.
        xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_OUT,
                                    spi_right, req_ids[0], None)
        # Apply right inbound socket policy.
        xfrm_base.ApplySocketPolicy(sock_right, family, xfrm.XFRM_POLICY_IN,
                                    spi_right, req_ids[1], None)
        # Apply right outbound socket policy.
        xfrm_base.ApplySocketPolicy(sock_right, family, xfrm.XFRM_POLICY_OUT,
                                    spi_left, req_ids[2], None)
        # Apply left inbound socket policy.
        xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_IN,
                                    spi_left, req_ids[3], None)

        server_ready = threading.Event()
        # Exceptions thrown by the server thread. This is a list because the
        # nested server functions cannot rebind a local of this method (there
        # is no `nonlocal` in Python 2). Assigning a variable there would only
        # create a new local and silently drop the error.
        server_errors = []

        def TcpServer(sock, client_port):
            try:
                sock.listen(1)
                server_ready.set()
                accepted, peer = sock.accept()
                self.assertEquals(remote_addr, peer[0])
                self.assertEquals(client_port, peer[1])
                data = accepted.recv(2048)
                self.assertEquals(b"hello request", data)
                accepted.send(b"hello response")
            except Exception as e:
                server_errors.append(e)
            finally:
                sock.close()

        def UdpServer(sock, client_port):
            try:
                server_ready.set()
                data, peer = sock.recvfrom(2048)
                self.assertEquals(remote_addr, peer[0])
                self.assertEquals(client_port, peer[1])
                self.assertEquals(b"hello request", data)
                sock.sendto(b"hello response", peer)
            except Exception as e:
                server_errors.append(e)
            finally:
                sock.close()

        # Server and client need to know each other's port numbers in advance.
        wildcard_addr = net_test.GetWildcardAddress(params["version"])
        sock_left.bind((wildcard_addr, 0))
        sock_right.bind((wildcard_addr, 0))
        left_port = sock_left.getsockname()[1]
        right_port = sock_right.getsockname()[1]

        # Start the appropriate server type on sock_right.
        target = TcpServer if params["proto"] == SOCK_STREAM else UdpServer
        server = threading.Thread(target=target,
                                  args=(sock_right, left_port),
                                  name="SocketServer")
        server.start()
        # Wait for server to be ready before attempting to connect. TCP retries
        # hide this problem, but UDP will fail outright if the server socket has
        # not bound when we send.
        self.assertTrue(server_ready.wait(2.0),
                        "Timed out waiting for server thread")

        with TapTwister(fd=self.tuns[netid].fileno(),
                        validator=AssertEncrypted):
            try:
                sock_left.connect((remote_addr, right_port))
                sock_left.send(b"hello request")
                data = sock_left.recv(2048)
                self.assertEquals(b"hello response", data)
            finally:
                # Close the client first: this also unblocks a server still
                # waiting on the connection. Then reap the server thread, even
                # if the client side failed.
                sock_left.close()
                server.join()
        if server_errors:
            raise server_errors[0]
Ejemplo n.º 4
0
    def _TestSocketPolicy(self, version):
        """Check that an outbound XFRM socket policy gates a UDP socket.

        The sequence of checks is:
        - Without a policy, the connected socket sends in the clear.
        - With the policy but no matching SA, sends fail with EAGAIN.
        - A matching SA makes the packet go out as ESP.
        - A different destination still fails.
        - A second socket without the policy still sends in the clear.
        - Deleting the SA brings EAGAIN back.
        - Clearing the policy (once, twice, or on a socket that never had
          one) is safe and restores cleartext.

        Args:
          version: 4 or 6, or 5 for a dual-stack (AF_INET6) socket carrying
            IPv4 traffic, which uses IPv4 XFRM.
        """
        family = net_test.GetAddressFamily(version)
        sock = socket(family, SOCK_DGRAM, 0)
        netid = self.RandomNetid()
        self.SelectInterface(sock, netid, "mark")

        # Connect the UDP socket to the remote port 53.
        remote = self.GetRemoteSocketAddress(version)
        dest = (remote, 53)
        sock.connect(dest)
        src_addr, src_port = sock.getsockname()[:2]
        dst_addr = sock.getpeername()[0]
        if version == 5:
            # Strip the IPv4-mapped prefix so the expected packet uses plain
            # IPv4 addresses.
            src_addr, dst_addr = (a.replace("::ffff:", "")
                                  for a in (src_addr, dst_addr))

        reqid = 0

        def Send(sk, target=dest):
            sk.sendto(net_test.UDP_PAYLOAD, target)

        desc, cleartext = packets.UDP(version, src_addr, dst_addr,
                                      sport=src_port)
        Send(sock)
        self.ExpectPacketOn(netid, "Send after socket, expected %s" % desc,
                            cleartext)

        # Using IPv4 XFRM on a dual-stack socket requires setting an AF_INET
        # policy that's written in terms of IPv4 addresses.
        xfrm_version = version if version != 5 else 4
        xfrm_family = net_test.GetAddressFamily(xfrm_version)
        xfrm_base.ApplySocketPolicy(sock, xfrm_family, xfrm.XFRM_POLICY_OUT,
                                    TEST_SPI, reqid, None)

        # The policy level defaults to "require" and no SA matches the
        # template yet, so sending fails.
        self.assertRaisesErrno(EAGAIN, sock.sendto, net_test.UDP_PAYLOAD, dest)

        # A matching SA lets the packet out encrypted. Its SPI must match the
        # template, and its destination must match the packet's destination
        # (or the tunnel destination in tunnel mode).
        self.CreateNewSa(net_test.GetWildcardAddress(xfrm_version),
                         self.GetRemoteAddress(xfrm_version), TEST_SPI, reqid,
                         None)
        Send(sock)
        esp_length = xfrm_base.GetEspPacketLength(
            xfrm.XFRM_MODE_TRANSPORT, version, False, net_test.UDP_PAYLOAD,
            xfrm_base._ALGO_HMAC_SHA1, xfrm_base._ALGO_CBC_AES_256)
        self._ExpectEspPacketOn(netid, TEST_SPI, 1, esp_length, None, None)

        # Another destination has no matching SA, so it fails again.
        other_remote = self.GetOtherRemoteSocketAddress(version)
        self.assertRaisesErrno(EAGAIN, sock.sendto, net_test.UDP_PAYLOAD,
                               (other_remote, 53))

        # A second socket without the policy sends an unencrypted packet.
        plain_sock = socket(family, SOCK_DGRAM, 0)
        self.SelectInterface(plain_sock, netid, "mark")
        Send(plain_sock)
        captured = self.ReadAllPacketsOn(netid)
        self.assertEquals(1, len(captured))
        first = captured[0]
        if version == 6:
            protocol = first.nh
        else:
            protocol = first.proto
        self.assertEquals(IPPROTO_UDP, protocol)

        # Once the SA is gone, the policy-bearing socket errors out again.
        self.xfrm.DeleteSaInfo(self.GetRemoteAddress(xfrm_version), TEST_SPI,
                               IPPROTO_ESP)
        self.assertRaisesErrno(EAGAIN, sock.sendto, net_test.UDP_PAYLOAD, dest)

        # Clearing the policy restores cleartext, and doing it twice is safe.
        for label in ("Send after clear, expected %s",
                      "Send after clear 2, expected %s"):
            xfrm_base.SetPolicySockopt(sock, family, None)
            Send(sock)
            self.ExpectPacketOn(netid, label % desc, cleartext)

        # Clearing on a socket that never had a policy is also safe.
        # NOTE(review): this socket is always AF_INET6 while `family` may be
        # AF_INET for version 4. Confirm whether that mix is intentional.
        sock = socket(AF_INET6, SOCK_DGRAM, 0)
        xfrm_base.SetPolicySockopt(sock, family, None)
Ejemplo n.º 5
0
    def _CheckVtiInputOutput(self, vti, inner_version):
        """Check encrypted I/O through a VTI and PMTU handling on its underlay.

        Steps:
        1. Send a UDP packet through the VTI-backed network and expect an ESP
           packet on the underlying network.
        2. Twist that ESP packet back in, temporarily swapping the VTI's
           address, so that the decrypted inner packet is delivered to a local
           socket. Check the payload.
        3. Send another packet and answer it with an ICMP "packet too big".
           Check that the route to vti.remote on the underlying network now has
           RTAX_MTU == packets.PTB_MTU.
        4. Invalidate the destination cache afterwards.

        Both sockets created here are closed on exit, including on failure.

        Args:
          vti: the VTI under test. This method reads its local, remote, netid,
            underlying_netid, out_spi, tx, addrs and iface attributes.
          inner_version: IP version of the packets carried inside the tunnel.
        """
        local_outer = vti.local
        remote_outer = vti.remote

        # Create a socket to receive packets.
        read_sock = socket(net_test.GetAddressFamily(inner_version),
                           SOCK_DGRAM, 0)
        write_sock = None
        try:
            read_sock.bind((net_test.GetWildcardAddress(inner_version), 0))
            # The second parameter of the tuple is the port number regardless
            # of AF.
            port = read_sock.getsockname()[1]
            # Guard against the eventuality of the receive failing.
            csocket.SetSocketTimeout(read_sock, 100)

            # Send a packet out via the vti-backed network, bound for the port
            # number of the input socket.
            remote_inner = _GetRemoteInnerAddress(inner_version)
            write_sock = socket(net_test.GetAddressFamily(inner_version),
                                SOCK_DGRAM, 0)
            self.SelectInterface(write_sock, vti.netid, "mark")
            write_sock.sendto(net_test.UDP_PAYLOAD, (remote_inner, port))

            # Read a tunneled IP packet on the underlying (outbound) network
            # verifying that it is an ESP packet.
            self.assertSentPacket(vti)
            pkt = self._ExpectEspPacketOn(vti.underlying_netid, vti.out_spi,
                                          vti.tx, None, local_outer,
                                          remote_outer)

            # Perform an address switcheroo so that the inner address of the
            # remote end of the tunnel is now the address on the local VTI
            # interface. This way, the twisted inner packet finds a
            # destination via the VTI once decrypted.
            local_inner = vti.addrs[inner_version]
            self._SwapInterfaceAddress(vti.iface, new_addr=remote_inner,
                                       old_addr=local_inner)
            try:
                # Swap the packet's IP headers and write it back to the
                # underlying network.
                pkt = TunTwister.TwistPacket(pkt)
                self.ReceivePacketOn(vti.underlying_netid, pkt)
                self.assertReceivedPacket(vti)
                # Receive the decrypted packet on the dest port number.
                read_packet = read_sock.recv(4096)
                self.assertEquals(net_test.UDP_PAYLOAD, read_packet)
            finally:
                # Unwind the switcheroo
                self._SwapInterfaceAddress(vti.iface,
                                           new_addr=local_inner,
                                           old_addr=remote_inner)

            # Now attempt to provoke an ICMP error.
            # TODO: deduplicate with multinetwork_test.py.
            version = net_test.GetAddressVersion(vti.remote)
            # Address passed to ICMPPacketTooBig, presumably as the source
            # of the ICMP error, per outer IP version.
            intermediate = {
                4: "172.16.9.12",
                6: "2001:db8::1",
            }[version]

            write_sock.sendto(net_test.UDP_PAYLOAD, (remote_inner, port))
            self.assertSentPacket(vti)
            pkt = self._ExpectEspPacketOn(vti.underlying_netid, vti.out_spi,
                                          vti.tx, None, local_outer,
                                          remote_outer)
            myaddr = self.MyAddress(version, vti.underlying_netid)
            _, toobig = packets.ICMPPacketTooBig(version, intermediate, myaddr,
                                                 pkt)
            self.ReceivePacketOn(vti.underlying_netid, toobig)

            # Check that the packet too big reduced the MTU.
            routes = self.iproute.GetRoutes(vti.remote, 0,
                                            vti.underlying_netid, None)
            self.assertEquals(1, len(routes))
            rtmsg, attributes = routes[0]
            self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
            self.assertEquals(packets.PTB_MTU,
                              attributes["RTA_METRICS"]["RTAX_MTU"])

            # Clear PMTU information so that future tests don't have to worry
            # about it.
            self.InvalidateDstCache(version, vti.underlying_netid)
        finally:
            if write_sock is not None:
                write_sock.close()
            read_sock.close()