Ejemplo n.º 1
0
class RpcTests(unittest.TestCase):
    """
    Tests for the SubscriberDB rpc servicer and stub.

    Every test talks to a real gRPC server that is started in setUp on a
    free port and is backed by a SqliteStore living in a temporary directory.
    """

    def setUp(self):
        # Create a store rooted in a throw-away directory
        self._tmpfile = tempfile.TemporaryDirectory()
        store = SqliteStore(self._tmpfile.name + '/')

        # Bind the rpc server to a free port (port 0 lets the OS pick one)
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10), )
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')

        # Add the servicer
        self._servicer = SubscriberDBRpcServicer(
            store=store, lte_processor=self._create_lte_processor_mock())
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()

        # Create a rpc stub; keep the channel so tearDown can close it
        self._channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))
        self._stub = SubscriberDBStub(self._channel)

    def _create_lte_processor_mock(self):
        """
        Build a stand-in lte processor whose get_sub_profile() returns a
        profile with fixed uplink/downlink bit rates.
        """
        lte_processor_mock = MagicMock()
        sub_profile_mock = MagicMock()
        lte_processor_mock.get_sub_profile.return_value = sub_profile_mock
        sub_profile_mock.max_ul_bit_rate = 23
        sub_profile_mock.max_dl_bit_rate = 42
        return lte_processor_mock

    def tearDown(self):
        # Release resources in the reverse order of their creation: the
        # client channel first, then the server (waiting until it has really
        # stopped), and only then the directory the store was created in.
        self._channel.close()
        self._rpc_server.stop(0).wait()
        self._tmpfile.cleanup()

    def test_get_invalid_subscriber(self):
        """
        Test if the rpc call returns NOT_FOUND
        """
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.GetSubscriberData(SIDUtils.to_pb('IMSI123'))
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)

    def test_add_delete_subscriber(self):
        """
        Test if AddSubscriber and DeleteSubscriber rpc call works
        """
        sid = SIDUtils.to_pb('IMSI1')
        data = SubscriberData(sid=sid)

        # Add subscriber
        self._stub.AddSubscriber(data)

        # Adding the same subscriber again must be rejected
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.AddSubscriber(data)
        self.assertEqual(err.exception.code(), grpc.StatusCode.ALREADY_EXISTS)

        # See if we can get the data for the subscriber
        self.assertEqual(self._stub.GetSubscriberData(sid).sid, data.sid)
        sids = self._stub.ListSubscribers(Void()).sids
        self.assertEqual(len(sids), 1)
        self.assertEqual(sids[0], sid)

        # Delete the subscriber
        self._stub.DeleteSubscriber(sid)
        self.assertEqual(len(self._stub.ListSubscribers(Void()).sids), 0)

    def test_update_subscriber(self):
        """
        Test if UpdateSubscriber rpc call works
        """
        sid = SIDUtils.to_pb('IMSI1')
        data = SubscriberData(sid=sid)

        # Add subscriber
        self._stub.AddSubscriber(data)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.lte.auth_key, b'')
        self.assertEqual(sub.state.lte_auth_next_seq, 0)

        # Update subscriber: both fields are set in the payload, but the
        # mask names only auth_key
        update = SubscriberUpdate()
        update.data.sid.CopyFrom(sid)
        update.data.lte.auth_key = b'\xab\xcd'
        update.data.state.lte_auth_next_seq = 1
        update.mask.paths.append('lte.auth_key')  # only auth_key
        self._stub.UpdateSubscriber(update)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.state.lte_auth_next_seq, 0)  # no change
        self.assertEqual(sub.lte.auth_key, b'\xab\xcd')

        # Same payload (lte_auth_next_seq is still 1), now with the sequence
        # number included in the mask as well
        update.mask.paths.append('state.lte_auth_next_seq')
        self._stub.UpdateSubscriber(update)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.state.lte_auth_next_seq, 1)

        # Delete the subscriber
        self._stub.DeleteSubscriber(sid)

        # Updating a subscriber that no longer exists must fail
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.UpdateSubscriber(update)
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)
Ejemplo n.º 2
0
class SessionRpcServicer(session_manager_pb2_grpc.LocalSessionManagerServicer):
    """
    gRPC based server for LocalSessionManager service.

    On session creation it selects a set of policy rules (captive portal
    whitelist + redirect, or allow-all) and asks pipelined to activate them;
    on session end it asks pipelined to deactivate the subscriber's flows.
    """
    # Priority used for the allow-all rule and for the whitelist rules
    ALLOW_ALL_PRIORITY = 100
    # Priority used for the captive portal redirect rule
    REDIRECT_PRIORITY = 2000
    # Timeout, in seconds, applied to outgoing rpc calls
    RPC_TIMEOUT = 5

    def __init__(self, service):
        """
        Create the stubs for pipelined and subscriberdb and read the captive
        portal settings.

        Args:
            service: object exposing a `config` mapping with the keys
                'captive_portal_enabled', 'captive_portal_url',
                'bridge_interface', 'whitelisted_ips' and
                'subscriber_profile_substr_match'
        """
        chan = ServiceRegistry.get_rpc_channel('pipelined',
                                               ServiceRegistry.LOCAL)
        self._pipelined = PipelinedStub(chan)
        chan = ServiceRegistry.get_rpc_channel('subscriberdb',
                                               ServiceRegistry.LOCAL)
        self._subscriberdb = SubscriberDBStub(chan)
        self._enabled = service.config['captive_portal_enabled']
        self._captive_portal_address = service.config['captive_portal_url']
        self._local_ip = get_ip_from_if(service.config['bridge_interface'])
        self._whitelisted_ips = service.config['whitelisted_ips']
        self._sub_profile_substr = service.config[
            'subscriber_profile_substr_match']

    def add_to_server(self, server):
        """
        Add the servicer to a gRPC server
        """
        session_manager_pb2_grpc.add_LocalSessionManagerServicer_to_server(
            self,
            server,
        )

    def CreateSession(self, request, context):
        """
        Handles create session request from MME by installing the necessary
        flows in pipelined's enforcement app.

        Args:
            request: message providing `sid` and `ue_ipv4`
            context: gRPC context; receives the code and details of any
                RpcError raised by a downstream call

        Returns:
            An empty LocalCreateSessionResponse, also when a downstream
            rpc failed (the failure is reported through `context`).
        """
        sid = request.sid
        logging.info('Create session request for sid: %s', sid.id)
        try:
            # Gather the set of policy rules to use
            if self._captive_portal_enabled(sid):
                rules = []
                rules.extend(self._get_whitelisted_policies())
                rules.extend(self._get_redirect_policies())
            else:
                rules = self._get_allow_all_traffic_policies()

            # Activate the flows in the enforcement app in pipelined
            act_request = ActivateFlowsRequest(sid=sid,
                                               ip_addr=request.ue_ipv4,
                                               dynamic_rules=rules)
            act_response = self._pipelined.ActivateFlows(
                act_request, timeout=self.RPC_TIMEOUT)
            for res in act_response.dynamic_rule_results:
                if res.result != res.SUCCESS:
                    # Hmm rolling back partial success is difficult
                    # Let's just log this for now
                    logging.error('Failed to activate rule: %s', res.rule_id)
        except RpcError as err:
            self._set_rpc_error(context, err)
        return session_manager_pb2.LocalCreateSessionResponse()

    def EndSession(self, sid, context):
        """
        Handles end session request from MME by removing all the flows
        for the subscriber in pipelined's enforcement app.

        Args:
            sid: subscriber id message of the session to end
            context: gRPC context; receives the code and details of any
                RpcError raised by pipelined

        Returns:
            An empty LocalEndSessionResponse.
        """
        logging.info('End session request for sid: %s', sid.id)
        try:
            self._pipelined.DeactivateFlows(DeactivateFlowsRequest(sid=sid),
                                            timeout=self.RPC_TIMEOUT)
        except RpcError as err:
            self._set_rpc_error(context, err)
        return session_manager_pb2.LocalEndSessionResponse()

    def ReportRuleStats(self, request, context):
        """
        Handles stats update from the enforcement app in pipelined. We are
        ignoring this for now, since the applications can poll pipelined for
        the flow stats.
        """
        logging.debug('Ignoring ReportRuleStats rpc')
        return Void()

    def _captive_portal_enabled(self, sid):
        """
        Decide whether the captive portal rules apply to a subscriber.

        Returns False when the service is disabled, True when no profile
        substring is configured, and otherwise whether the configured
        substring occurs in the subscriber's `sub_profile` as reported by
        subscriberdb. May raise RpcError from the subscriberdb lookup.
        """
        if not self._enabled:
            return False  # Service is disabled

        # An unset value (None, e.g. an empty key in the config file) is
        # treated like the empty string instead of failing in the `in` test
        if not self._sub_profile_substr:
            return True  # Allow all subscribers

        sub = self._subscriberdb.GetSubscriberData(sid,
                                                   timeout=self.RPC_TIMEOUT)
        return self._sub_profile_substr in sub.sub_profile

    def _get_allow_all_traffic_policies(self):
        """ Policy to allow all traffic to the internet """
        return [
            PolicyRule(
                id='allow_all_traffic',
                priority=self.ALLOW_ALL_PRIORITY,
                flow_list=[
                    FlowDescription(match=FlowMatch(
                        direction=FlowMatch.UPLINK)),
                    FlowDescription(match=FlowMatch(
                        direction=FlowMatch.DOWNLINK)),
                ],
            )
        ]

    def _get_whitelisted_policies(self):
        """
        Policies to allow http traffic to the whitelisted sites.

        Builds one rule per (ip, port) pair of the configured whitelist,
        each matching uplink TCP traffic to that destination and downlink
        TCP traffic from it. A missing whitelist or port list yields no
        rules.
        """
        # NOTE(review): every rule shares the id 'whitelist' - confirm that
        # pipelined accepts several dynamic rules with the same id.
        rules = []
        for ip, ports in (self._whitelisted_ips or {}).items():
            # 'local' is a placeholder for the bridge interface address
            target_ip = self._local_ip if ip == 'local' else ip
            for port in ports or ():
                rules.append(
                    PolicyRule(
                        id='whitelist',
                        priority=self.ALLOW_ALL_PRIORITY,
                        flow_list=[
                            FlowDescription(
                                match=FlowMatch(direction=FlowMatch.UPLINK,
                                                ip_proto=FlowMatch.IPPROTO_TCP,
                                                ipv4_dst=target_ip,
                                                tcp_dst=port)),
                            FlowDescription(
                                match=FlowMatch(direction=FlowMatch.DOWNLINK,
                                                ip_proto=FlowMatch.IPPROTO_TCP,
                                                ipv4_src=target_ip,
                                                tcp_src=port)),
                        ]))
        return rules

    def _get_redirect_policies(self):
        """ Policy to redirect traffic to the captive portal """
        redirect_info = RedirectInformation(
            support=RedirectInformation.ENABLED,
            address_type=RedirectInformation.URL,
            server_address=self._captive_portal_address)
        return [
            PolicyRule(id='redirect',
                       priority=self.REDIRECT_PRIORITY,
                       redirect=redirect_info)
        ]

    def _set_rpc_error(self, context, err):
        """ Log a downstream RpcError and copy it onto the rpc context """
        logging.error(err.details())
        context.set_details(err.details())
        context.set_code(err.code())