class RpcTests(unittest.TestCase):
    """
    Tests for the SubscriberDB rpc servicer and stub
    """

    def setUp(self):
        """
        Start a SubscriberDB gRPC server on a free local port, backed by a
        SqliteStore rooted in a fresh temporary directory, and create a stub
        connected to it.
        """
        # Create a SqliteStore inside a throwaway temporary directory
        self._tmpfile = tempfile.TemporaryDirectory()
        store = SqliteStore(self._tmpfile.name + '/')

        # Bind the rpc server to a free port (port 0 lets the OS choose)
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10),
        )
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')

        # Add the servicer
        self._servicer = SubscriberDBRpcServicer(
            store=store,
            lte_processor=self._create_lte_processor_mock(),
        )
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()

        # Create a rpc stub
        channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))
        self._stub = SubscriberDBStub(channel)

    def _create_lte_processor_mock(self):
        """
        Build a MagicMock LTE processor whose get_sub_profile() returns a
        profile with fixed uplink/downlink max bit rates (23 and 42).
        """
        lte_processor_mock = MagicMock()
        sub_profile_mock = MagicMock()
        lte_processor_mock.get_sub_profile.return_value = sub_profile_mock
        sub_profile_mock.max_ul_bit_rate = 23
        sub_profile_mock.max_dl_bit_rate = 42
        return lte_processor_mock

    def tearDown(self):
        """
        Stop the gRPC server, then delete the store's temporary directory.
        """
        # Stop the server first so no worker thread is still using the
        # store while its backing directory is being removed.
        self._rpc_server.stop(0)
        self._tmpfile.cleanup()

    def test_get_invalid_subscriber(self):
        """
        Test if the rpc call returns NOT_FOUND
        """
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.GetSubscriberData(SIDUtils.to_pb('IMSI123'))
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)

    def test_add_delete_subscriber(self):
        """
        Test if AddSubscriber and DeleteSubscriber rpc call works
        """
        sid = SIDUtils.to_pb('IMSI1')
        data = SubscriberData(sid=sid)

        # Add subscriber
        self._stub.AddSubscriber(data)

        # Add subscriber again; a duplicate must be rejected
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.AddSubscriber(data)
        self.assertEqual(err.exception.code(),
                         grpc.StatusCode.ALREADY_EXISTS)

        # See if we can get the data for the subscriber
        self.assertEqual(self._stub.GetSubscriberData(sid).sid, data.sid)
        self.assertEqual(len(self._stub.ListSubscribers(Void()).sids), 1)
        self.assertEqual(self._stub.ListSubscribers(Void()).sids[0], sid)

        # Delete the subscriber
        self._stub.DeleteSubscriber(sid)
        self.assertEqual(len(self._stub.ListSubscribers(Void()).sids), 0)

    def test_update_subscriber(self):
        """
        Test if UpdateSubscriber rpc call works
        """
        sid = SIDUtils.to_pb('IMSI1')
        data = SubscriberData(sid=sid)

        # Add subscriber
        self._stub.AddSubscriber(data)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.lte.auth_key, b'')
        self.assertEqual(sub.state.lte_auth_next_seq, 0)

        # Update subscriber, but only the fields named in the field mask
        update = SubscriberUpdate()
        update.data.sid.CopyFrom(sid)
        update.data.lte.auth_key = b'\xab\xcd'
        update.data.state.lte_auth_next_seq = 1
        update.mask.paths.append('lte.auth_key')  # only auth_key
        self._stub.UpdateSubscriber(update)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.state.lte_auth_next_seq, 0)  # no change
        self.assertEqual(sub.lte.auth_key, b'\xab\xcd')

        # lte_auth_next_seq is already 1 in the update payload; adding its
        # path to the mask is what makes it apply this time.
        update.mask.paths.append('state.lte_auth_next_seq')
        self._stub.UpdateSubscriber(update)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.state.lte_auth_next_seq, 1)

        # Delete the subscriber; updating it afterwards must fail
        self._stub.DeleteSubscriber(sid)

        with self.assertRaises(grpc.RpcError) as err:
            self._stub.UpdateSubscriber(update)
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)
class SessionRpcServicer(session_manager_pb2_grpc.LocalSessionManagerServicer):
    """
    gRPC based server for LocalSessionManager service
    """
    ALLOW_ALL_PRIORITY = 100
    REDIRECT_PRIORITY = 2000
    # Deadline passed as `timeout=` on outgoing pipelined/subscriberdb calls
    RPC_TIMEOUT = 5

    def __init__(self, service):
        """
        Create stubs to the local pipelined and subscriberdb services and
        read the captive-portal settings from the service config.

        Args:
            service: object exposing a `config` mapping; the keys
                'captive_portal_enabled', 'captive_portal_url',
                'bridge_interface', 'whitelisted_ips' and
                'subscriber_profile_substr_match' are read here.
        """
        chan = ServiceRegistry.get_rpc_channel('pipelined',
                                               ServiceRegistry.LOCAL)
        self._pipelined = PipelinedStub(chan)
        chan = ServiceRegistry.get_rpc_channel('subscriberdb',
                                               ServiceRegistry.LOCAL)
        self._subscriberdb = SubscriberDBStub(chan)
        self._enabled = service.config['captive_portal_enabled']
        self._captive_portal_address = service.config['captive_portal_url']
        self._local_ip = get_ip_from_if(service.config['bridge_interface'])
        self._whitelisted_ips = service.config['whitelisted_ips']
        self._sub_profile_substr = service.config[
            'subscriber_profile_substr_match']

    def add_to_server(self, server):
        """
        Add the servicer to a gRPC server
        """
        session_manager_pb2_grpc.add_LocalSessionManagerServicer_to_server(
            self, server,
        )

    def CreateSession(self, request, context):
        """
        Handles create session request from MME by installing the necessary
        flows in pipelined's enforcement app.

        Subscribers selected for the captive portal get whitelist rules plus
        a redirect rule; everyone else gets allow-all rules. RPC failures to
        subscriberdb/pipelined are copied onto `context` instead of raised.
        """
        sid = request.sid
        logging.info('Create session request for sid: %s', sid.id)
        try:
            # Gather the set of policy rules to use
            if self._captive_portal_enabled(sid):
                rules = []
                rules.extend(self._get_whitelisted_policies())
                rules.extend(self._get_redirect_policies())
            else:
                rules = self._get_allow_all_traffic_policies()

            # Activate the flows in the enforcement app in pipelined
            act_request = ActivateFlowsRequest(sid=sid,
                                               ip_addr=request.ue_ipv4,
                                               dynamic_rules=rules)
            act_response = self._pipelined.ActivateFlows(
                act_request, timeout=self.RPC_TIMEOUT)
            for res in act_response.dynamic_rule_results:
                if res.result != res.SUCCESS:
                    # Hmm rolling back partial success is difficult
                    # Let's just log this for now
                    logging.error('Failed to activate rule: %s', res.rule_id)
        except RpcError as err:
            self._set_rpc_error(context, err)
        return session_manager_pb2.LocalCreateSessionResponse()

    def EndSession(self, sid, context):
        """
        Handles end session request from MME by removing all the flows
        for the subscriber in pipelined's enforcement app.
        """
        logging.info('End session request for sid: %s', sid.id)
        try:
            self._pipelined.DeactivateFlows(DeactivateFlowsRequest(sid=sid),
                                            timeout=self.RPC_TIMEOUT)
        except RpcError as err:
            self._set_rpc_error(context, err)
        return session_manager_pb2.LocalEndSessionResponse()

    def ReportRuleStats(self, request, context):
        """
        Handles stats update from the enforcement app in pipelined. We are
        ignoring this for now, since the applications can poll pipelined for
        the flow stats.
        """
        logging.debug('Ignoring ReportRuleStats rpc')
        return Void()

    def _captive_portal_enabled(self, sid):
        """
        Decide whether `sid` should be sent through the captive portal.

        Returns False when the feature is disabled. Returns True for every
        subscriber when no profile substring is configured (empty or null).
        Otherwise it looks up the subscriber and matches the substring
        against its sub_profile. May raise RpcError from subscriberdb.
        """
        if not self._enabled:
            return False  # Service is disabled
        if not self._sub_profile_substr:
            # Empty or null substring: allow all subscribers. A null value
            # would otherwise raise TypeError in the `in` test below.
            return True
        sub = self._subscriberdb.GetSubscriberData(sid,
                                                   timeout=self.RPC_TIMEOUT)
        return self._sub_profile_substr in sub.sub_profile

    def _get_allow_all_traffic_policies(self):
        """
        Policy to allow all traffic to the internet
        """
        return [
            PolicyRule(
                id='allow_all_traffic',
                priority=self.ALLOW_ALL_PRIORITY,
                flow_list=[
                    FlowDescription(match=FlowMatch(
                        direction=FlowMatch.UPLINK)),
                    FlowDescription(match=FlowMatch(
                        direction=FlowMatch.DOWNLINK)),
                ],
            )
        ]

    def _get_whitelisted_policies(self):
        """
        Policies to allow http traffic to the whitelisted sites

        `self._whitelisted_ips` maps a host IP (or the placeholder 'local')
        to an iterable of TCP ports. One rule is produced per (ip, port)
        pair, matching uplink traffic to ip:port and downlink from it.
        """
        # NOTE(review): every rule here shares id='whitelist' and uses
        # ALLOW_ALL_PRIORITY (lower than REDIRECT_PRIORITY); confirm that
        # pipelined accepts duplicate rule ids and that its priority ordering
        # lets these rules win over the redirect rule.
        rules = []
        for ip, ports in self._whitelisted_ips.items():
            # 'local' stands for this gateway's own bridge-interface IP
            if ip == 'local':
                ip = self._local_ip
            for port in ports:
                rules.append(
                    PolicyRule(
                        id='whitelist',
                        priority=self.ALLOW_ALL_PRIORITY,
                        flow_list=[
                            FlowDescription(
                                match=FlowMatch(
                                    direction=FlowMatch.UPLINK,
                                    ip_proto=FlowMatch.IPPROTO_TCP,
                                    ipv4_dst=ip,
                                    tcp_dst=port)),
                            FlowDescription(
                                match=FlowMatch(
                                    direction=FlowMatch.DOWNLINK,
                                    ip_proto=FlowMatch.IPPROTO_TCP,
                                    ipv4_src=ip,
                                    tcp_src=port)),
                        ]))
        return rules

    def _get_redirect_policies(self):
        """
        Policy to redirect traffic to the captive portal
        """
        redirect_info = RedirectInformation(
            support=RedirectInformation.ENABLED,
            address_type=RedirectInformation.URL,
            server_address=self._captive_portal_address)
        return [
            PolicyRule(id='redirect',
                       priority=self.REDIRECT_PRIORITY,
                       redirect=redirect_info)
        ]

    def _set_rpc_error(self, context, err):
        """
        Log an upstream RpcError and copy its details and status code onto
        this call's servicer `context`.
        """
        logging.error(err.details())
        context.set_details(err.details())
        context.set_code(err.code())