def main():
    """
    Entry point for the policydb service.

    Creates the 'policydb' MagmaService, registers a SessionRpcServicer
    (built from the service config and a local subscriberdb stub) on the
    service's rpc server, starts the streamer client when
    'enable_streaming' is set in the service config, then runs the service
    and closes it afterwards.
    """
    service = MagmaService('policydb', mconfigs_pb2.PolicyDB())

    # Register the servicers on the rpc server
    subscriberdb_channel = ServiceRegistry.get_rpc_channel(
        'subscriberdb', ServiceRegistry.LOCAL,
    )
    subscriberdb_stub = SubscriberDBStub(subscriberdb_channel)
    SessionRpcServicer(service.config, subscriberdb_stub).add_to_server(
        service.rpc_server,
    )

    # Stream "policydb" updates only when the config asks for it
    streaming_enabled = service.config['enable_streaming']
    if not streaming_enabled:
        logging.info('enable_streaming set to False. Streamer disabled!')
    else:
        callbacks = {"policydb": PolicyDBStreamerCallback(service.loop)}
        StreamerClient(callbacks, service.loop).start()

    # Run the service, then release its resources
    service.run()
    service.close()
def __init__(self):
    """
    Set up the client state: an empty set of added subscriber ids and a
    SubscriberDB gRPC stub on the "subscriberdb" channel.
    """
    self._added_sids = set()
    subscriberdb_channel = get_rpc_channel("subscriberdb")
    self._subscriber_stub = SubscriberDBStub(subscriberdb_channel)
class SubscriberDbGrpc(SubscriberDbClient):
    """
    Handle subscriber actions by making calls over gRPC directly to the
    gateway.
    """

    def __init__(self):
        """ Init the gRPC stub. """
        # Subscriber ids added through this client; compared against the
        # ids listed by subscriberdb in _check_invariants().
        self._added_sids = set()
        self._subscriber_stub = SubscriberDBStub(
            get_rpc_channel("subscriberdb"))

    @staticmethod
    def _try_to_call(grpc_call):
        """
        Call into SubscriberDB, retrying with exponential backoff while the
        call fails with status UNAVAILABLE.

        Args:
            grpc_call (callable): zero-argument callable performing the RPC

        Returns:
            The value returned by the first successful grpc_call()

        Raises:
            grpc.RpcError: immediately for any status other than
                UNAVAILABLE, or the last UNAVAILABLE error once all
                RETRY_COUNT attempts have failed
        """
        last_error = None
        for attempt in range(RETRY_COUNT):
            try:
                return grpc_call()
            except grpc.RpcError as error:
                # RpcErrors raised by stub invocations also implement
                # grpc.Call, which exposes the status code directly.
                if error.code() != grpc.StatusCode.UNAVAILABLE:
                    logging.error(
                        "Subscriberdb grpc call failed with error : %s",
                        error)
                    raise
                last_error = error
                logging.warning("Subscriberdb unavailable, retrying...")
                # No point backing off after the final attempt
                if attempt < RETRY_COUNT - 1:
                    time.sleep(RETRY_INTERVAL * (2 ** attempt))
        if last_error is not None:
            # Every attempt was UNAVAILABLE: surface the failure instead of
            # silently handing None back to the caller.
            raise last_error
        return None

    @staticmethod
    def _get_subscriberdb_data(sid):
        """
        Get subscriber data in protobuf format.

        Args:
            sid (str): string representation of the subscriber id

        Returns:
            subscriber_data (protos.subscriberdb_pb2.SubscriberData):
                full subscriber information for :sid: in protobuf format.
        """
        sub_db_sid = SIDUtils.to_pb(sid)
        lte = LTESubscription()
        lte.state = LTESubscription.ACTIVE
        lte.auth_key = bytes.fromhex(KEY)
        state = SubscriberState()
        state.lte_auth_next_seq = 1
        return SubscriberData(sid=sub_db_sid, lte=lte, state=state)

    @staticmethod
    def _get_apn_data(sid, apn_list):
        """
        Get APN data in protobuf format.

        Args:
            sid: protobuf subscriber id, copied into the update
            apn_list: list of APN configurations; each entry is indexed by
                "apn_name", "qci", "priority", "pre_cap", "pre_vul",
                "mbr_ul", "mbr_dl" and, optionally, "pdn_type"
                (0 is used when "pdn_type" is absent)

        Returns:
            update (protos.subscriberdb_pb2.SubscriberUpdate)
        """
        update = SubscriberUpdate()
        update.data.sid.CopyFrom(sid)
        non_3gpp = update.data.non_3gpp
        for apn in apn_list:
            apn_config = non_3gpp.apn_config.add()
            apn_config.service_selection = apn["apn_name"]
            apn_config.qos_profile.class_id = apn["qci"]
            apn_config.qos_profile.priority_level = apn["priority"]
            apn_config.qos_profile.preemption_capability = apn["pre_cap"]
            apn_config.qos_profile.preemption_vulnerability = apn["pre_vul"]
            apn_config.ambr.max_bandwidth_ul = apn["mbr_ul"]
            apn_config.ambr.max_bandwidth_dl = apn["mbr_dl"]
            apn_config.pdn = apn.get("pdn_type", 0)
        return update

    def _check_invariants(self):
        """
        Assert preservation of invariants: the number of locally tracked
        sids must equal the number of sids listed by subscriberdb.

        Raises:
            AssertionError: when invariants do not hold
        """
        sids_eq_len = len(self._added_sids) == len(self.list_subscriber_sids())
        assert sids_eq_len

    def add_subscriber(self, sid):
        """
        Track :sid: locally, add it to subscriberdb and check invariants.

        Args:
            sid (str): string representation of the subscriber id
        """
        logging.info("Adding subscriber : %s", sid)
        self._added_sids.add(sid)
        sub_data = self._get_subscriberdb_data(sid)
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.AddSubscriber(sub_data))
        self._check_invariants()

    def delete_subscriber(self, sid):
        """
        Stop tracking :sid: locally and delete it from subscriberdb.

        Args:
            sid (str): subscriber id string; its first four characters are
                dropped to form the protobuf id (list_subscriber_sids()
                prepends 'IMSI' to the ids it returns)
        """
        logging.info("Deleting subscriber : %s", sid)
        self._added_sids.discard(sid)
        sid_pb = SubscriberID(id=sid[4:])
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.DeleteSubscriber(sid_pb))

    def list_subscriber_sids(self):
        """
        List the subscribers known to subscriberdb.

        Returns:
            list of str: each listed id prefixed with 'IMSI'
        """
        sids_pb = SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.ListSubscribers(Void()).sids)
        return ['IMSI' + sid.id for sid in sids_pb]

    def config_apn_details(self, imsi, apn_list):
        """
        Update the subscriber's non-3GPP APN configuration.

        Args:
            imsi (str): string representation of the subscriber id
            apn_list: list of APN configurations, see _get_apn_data()
        """
        sid = SIDUtils.to_pb(imsi)
        update_sub = self._get_apn_data(sid, apn_list)
        # Only the 'non_3gpp' path is listed in the update mask
        update_sub.mask.paths.append('non_3gpp')
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.UpdateSubscriber(update_sub))

    def clean_up(self):
        """
        Delete every subscriber listed by subscriberdb and assert that both
        subscriberdb and the local tracking set end up empty.
        """
        # Remove all sids
        for sid in self.list_subscriber_sids():
            self.delete_subscriber(sid)
        assert not self.list_subscriber_sids()
        assert not self._added_sids

    def wait_for_changes(self):
        """ No-op: on gateway, changes propagate immediately. """
        return
def main():
    """
    Entry point for the mobilityd service.

    Builds the mobility store, the IPv4 and IPv6 allocators and the
    IPAddressManager, registers the MobilityServiceRpcServicer on the
    service's rpc server, adds the IPv4/IPv6 blocks found in mconfig, then
    runs the service and closes it afterwards.
    """
    service = MagmaService('mobilityd', mconfigs_pb2.MobilityD())

    # Service config values are looked up first; mconfig or literal
    # defaults are the fallbacks
    config = service.config
    mconfig = service.mconfig
    use_multi_apn = config.get('multi_apn', mconfig.multi_apn_ip_alloc)
    use_static_ip = config.get('static_ip', mconfig.static_ip_enabled)
    ip_allocator_type = mconfig.ip_allocator_type
    dhcp_interface = config.get('dhcp_iface', 'dhcp0')
    dhcp_retries = config.get('retry_limit', 300)

    # TODO: consider adding gateway mconfig to decide whether to
    # persist to Redis
    redis_client = get_default_client()
    persist_to_redis = config.get('persist_to_redis', False)
    redis_port = config.get('redis_port', 6380)
    store = MobilityStore(redis_client, persist_to_redis, redis_port)

    subscriberdb_channel = ServiceRegistry.get_rpc_channel(
        'subscriberdb', ServiceRegistry.LOCAL,
    )
    subscriberdb_stub = SubscriberDBStub(subscriberdb_channel)
    ipv4_allocator = _get_ipv4_allocator(
        store, ip_allocator_type, use_static_ip, use_multi_apn,
        dhcp_interface, dhcp_retries, subscriberdb_stub,
    )

    # IPv6 allocator: for now only IP_POOL mode is supported for IPv6
    prefix_alloc_mode = mconfig.ipv6_prefix_allocation_type
    if not prefix_alloc_mode:
        prefix_alloc_mode = DEFAULT_IPV6_PREFIX_ALLOC_MODE
    ipv6_allocator = IPv6AllocatorPool(
        store=store,
        session_prefix_alloc_mode=prefix_alloc_mode,
    )

    ip_address_man = IPAddressManager(ipv4_allocator, ipv6_allocator, store)

    # Register the servicer on the rpc server
    print_payload = config.get('print_grpc_payload', False)
    mobility_service_servicer = MobilityServiceRpcServicer(
        ip_address_man, print_payload,
    )
    mobility_service_servicer.add_to_server(service.rpc_server)

    # IP blocks come from mconfig. There is no dynamic reloading for now;
    # a restart is assumed after updates.
    logging.info('Adding IPv4 block')
    v4_block = _get_ip_block(mconfig.ip_block)
    if v4_block is not None:
        try:
            mobility_service_servicer.add_ip_block(v4_block)
        except OverlappedIPBlocksError:
            logging.warning("Overlapped IPv4 block: %s", v4_block)

    logging.info('Adding IPv6 block')
    v6_block = _get_ip_block(mconfig.ipv6_block)
    if v6_block is not None:
        try:
            mobility_service_servicer.add_ip_block(v6_block)
        except OverlappedIPBlocksError:
            logging.warning("Overlapped IPv6 block: %s", v6_block)

    # Run the service, then release its resources
    service.run()
    service.close()
class RpcTests(unittest.TestCase):
    """
    Tests for the SubscriberDB rpc servicer and stub
    """

    def setUp(self):
        """
        Start a gRPC server with a SubscriberDBRpcServicer backed by a
        store in a temporary directory, and connect a stub to it.
        """
        # Create a store rooted at a temporary directory
        self._tmpfile = tempfile.TemporaryDirectory()
        store = SqliteStore(self._tmpfile.name + '/')

        # Bind the rpc server to a free port
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10),
        )
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')

        # Add the servicer
        self._servicer = SubscriberDBRpcServicer(
            store=store,
            lte_processor=self._create_lte_processor_mock())
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()

        # Create a rpc stub
        channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))
        self._stub = SubscriberDBStub(channel)

    def _create_lte_processor_mock(self):
        """
        Build a mock LTE processor whose get_sub_profile() returns a
        profile mock with fixed max UL/DL bit rates (23 / 42).
        """
        lte_processor_mock = MagicMock()
        sub_profile_mock = MagicMock()
        lte_processor_mock.get_sub_profile.return_value = sub_profile_mock
        sub_profile_mock.max_ul_bit_rate = 23
        sub_profile_mock.max_dl_bit_rate = 42
        return lte_processor_mock

    def tearDown(self):
        """
        Tear down in reverse order of setUp: stop the server before
        removing the directory its store lives in.
        """
        self._rpc_server.stop(0)
        self._tmpfile.cleanup()

    def test_get_invalid_subscriber(self):
        """
        Test if the rpc call returns NOT_FOUND
        """
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.GetSubscriberData(SIDUtils.to_pb('IMSI123'))
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)

    def test_add_delete_subscriber(self):
        """
        Test if AddSubscriber and DeleteSubscriber rpc call works
        """
        sid = SIDUtils.to_pb('IMSI1')
        data = SubscriberData(sid=sid)

        # Add subscriber
        self._stub.AddSubscriber(data)

        # Add subscriber again
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.AddSubscriber(data)
        self.assertEqual(err.exception.code(), grpc.StatusCode.ALREADY_EXISTS)

        # See if we can get the data for the subscriber
        self.assertEqual(self._stub.GetSubscriberData(sid).sid, data.sid)
        self.assertEqual(len(self._stub.ListSubscribers(Void()).sids), 1)
        self.assertEqual(self._stub.ListSubscribers(Void()).sids[0], sid)

        # Delete the subscriber
        self._stub.DeleteSubscriber(sid)
        self.assertEqual(len(self._stub.ListSubscribers(Void()).sids), 0)

    def test_update_subscriber(self):
        """
        Test if UpdateSubscriber rpc call works
        """
        sid = SIDUtils.to_pb('IMSI1')
        data = SubscriberData(sid=sid)

        # Add subscriber
        self._stub.AddSubscriber(data)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.lte.auth_key, b'')
        self.assertEqual(sub.state.lte_auth_next_seq, 0)

        # Update subscriber
        update = SubscriberUpdate()
        update.data.sid.CopyFrom(sid)
        update.data.lte.auth_key = b'\xab\xcd'
        update.data.state.lte_auth_next_seq = 1
        update.mask.paths.append('lte.auth_key')  # only auth_key
        self._stub.UpdateSubscriber(update)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.state.lte_auth_next_seq, 0)  # no change
        self.assertEqual(sub.lte.auth_key, b'\xab\xcd')

        # Adding the state path to the mask makes the sequence update apply
        update.data.state.lte_auth_next_seq = 1
        update.mask.paths.append('state.lte_auth_next_seq')
        self._stub.UpdateSubscriber(update)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.state.lte_auth_next_seq, 1)

        # Delete the subscriber
        self._stub.DeleteSubscriber(sid)

        # Updating a deleted subscriber must fail with NOT_FOUND
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.UpdateSubscriber(update)
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)
def main():
    """
    Entry point for the mobilityd service.

    Builds the mobility store, the IPv4 and IPv6 allocators and the
    IPAddressManager, adds the IPv4/IPv6 blocks found in mconfig to the
    manager, registers the MobilityServiceRpcServicer on the service's rpc
    server, then runs the service and closes it afterwards.
    """
    service = MagmaService('mobilityd', mconfigs_pb2.MobilityD())

    # Optionally pipe errors to Sentry
    sentry_init(service_name=service.name)

    # Service config values are looked up first; mconfig or module
    # defaults are the fallbacks
    config = service.config
    mconfig = service.mconfig
    use_multi_apn = config.get('multi_apn', mconfig.multi_apn_ip_alloc)
    use_static_ip = config.get('static_ip', mconfig.static_ip_enabled)
    ip_allocator_type = mconfig.ip_allocator_type
    dhcp_interface = config.get('dhcp_iface', 'dhcp0')
    dhcp_retries = config.get('retry_limit', RETRY_LIMIT)

    # TODO: consider adding gateway mconfig to decide whether to
    # persist to Redis
    redis_client = get_default_client()
    persist_to_redis = config.get('persist_to_redis', False)
    redis_port = config.get('redis_port', DEFAULT_REDIS_PORT)
    store = MobilityStore(redis_client, persist_to_redis, redis_port)

    subscriberdb_channel = ServiceRegistry.get_rpc_channel(
        'subscriberdb', ServiceRegistry.LOCAL,
    )
    subscriberdb_stub = SubscriberDBStub(subscriberdb_channel)
    ipv4_allocator = _get_ipv4_allocator(
        store, ip_allocator_type, use_static_ip, use_multi_apn,
        dhcp_interface, dhcp_retries, subscriberdb_stub,
    )

    # IPv6 allocator: for now only IP_POOL mode is supported for IPv6
    prefix_alloc_mode = _get_value_or_default(
        mconfig.ipv6_prefix_allocation_type,
        DEFAULT_IPV6_PREFIX_ALLOC_MODE,
    )
    ipv6_allocator = IPv6AllocatorPool(
        store=store,
        session_prefix_alloc_mode=prefix_alloc_mode,
    )

    ip_address_man = IPAddressManager(ipv4_allocator, ipv6_allocator, store)

    # IP blocks come from mconfig. There is no dynamic reloading for now;
    # a restart is assumed after updates.
    v4_block = _get_ip_block(mconfig.ip_block, "IPv4")
    if v4_block is not None:
        logging.info('Adding IPv4 block')
        try:
            existing_blocks = ip_address_man.list_added_ip_blocks()
            if v4_block not in existing_blocks:
                # The configured block is new: drop the previously added
                # blocks before adding it
                ip_address_man.remove_ip_blocks(*existing_blocks, force=True)
                ip_address_man.add_ip_block(v4_block)
        except OverlappedIPBlocksError:
            logging.warning("Overlapped IPv4 block: %s", v4_block)

    v6_block = _get_ip_block(mconfig.ipv6_block, "IPv6")
    if v6_block is not None:
        logging.info('Adding IPv6 block')
        try:
            current_v6_block = ip_address_man.get_assigned_ipv6_block()
            if v6_block != current_v6_block:
                ip_address_man.add_ip_block(v6_block)
        except OverlappedIPBlocksError:
            logging.warning("Overlapped IPv6 block: %s", v6_block)

    # Register the servicer on the rpc server
    print_payload = config.get('print_grpc_payload', False)
    mobility_service_servicer = MobilityServiceRpcServicer(
        ip_address_man, print_payload,
    )
    mobility_service_servicer.add_to_server(service.rpc_server)

    # Run the service, then release its resources
    service.run()
    service.close()
class SessionRpcServicer(session_manager_pb2_grpc.LocalSessionManagerServicer):
    """
    gRPC based server for LocalSessionManager service
    """

    # Priority given to the allow-all and whitelist rules
    ALLOW_ALL_PRIORITY = 100
    # Priority given to the captive portal redirect rule
    REDIRECT_PRIORITY = 2000
    # Timeout passed to every outgoing stub call (presumably seconds, as
    # is usual for gRPC timeouts)
    RPC_TIMEOUT = 5

    def __init__(self, service):
        """
        Create the pipelined and subscriberdb stubs and read the captive
        portal settings from the service config.

        Args:
            service: service object whose `config` mapping provides
                'captive_portal_enabled', 'captive_portal_url',
                'bridge_interface', 'whitelisted_ips' and
                'subscriber_profile_substr_match'
        """
        chan = ServiceRegistry.get_rpc_channel(
            'pipelined', ServiceRegistry.LOCAL)
        self._pipelined = PipelinedStub(chan)
        chan = ServiceRegistry.get_rpc_channel(
            'subscriberdb', ServiceRegistry.LOCAL)
        self._subscriberdb = SubscriberDBStub(chan)
        self._enabled = service.config['captive_portal_enabled']
        self._captive_portal_address = service.config['captive_portal_url']
        self._local_ip = get_ip_from_if(service.config['bridge_interface'])
        self._whitelisted_ips = service.config['whitelisted_ips']
        self._sub_profile_substr = service.config[
            'subscriber_profile_substr_match']

    def add_to_server(self, server):
        """
        Add the servicer to a gRPC server
        """
        session_manager_pb2_grpc.add_LocalSessionManagerServicer_to_server(
            self, server,
        )

    def CreateSession(self, request, context):
        """
        Handles create session request from MME by installing the necessary
        flows in pipelined's enforcement app.

        Any RpcError raised while gathering the rules or activating the
        flows is logged and copied onto `context`; an empty
        LocalCreateSessionResponse is returned either way.
        """
        sid = request.sid
        logging.info('Create session request for sid: %s', sid.id)
        try:
            # Gather the set of policy rules to use
            if self._captive_portal_enabled(sid):
                rules = []
                rules.extend(self._get_whitelisted_policies())
                rules.extend(self._get_redirect_policies())
            else:
                rules = self._get_allow_all_traffic_policies()

            # Activate the flows in the enforcement app in pipelined
            act_request = ActivateFlowsRequest(
                sid=sid, ip_addr=request.ue_ipv4, dynamic_rules=rules)
            act_response = self._pipelined.ActivateFlows(
                act_request, timeout=self.RPC_TIMEOUT)
            for res in act_response.dynamic_rule_results:
                if res.result != res.SUCCESS:
                    # Hmm rolling back partial success is difficult
                    # Let's just log this for now
                    logging.error('Failed to activate rule: %s', res.rule_id)
        except RpcError as err:
            self._set_rpc_error(context, err)
        return session_manager_pb2.LocalCreateSessionResponse()

    def EndSession(self, sid, context):
        """
        Handles end session request from MME by removing all the flows
        for the subscriber in pipelined's enforcement app.

        An RpcError from pipelined is logged and copied onto `context`; an
        empty LocalEndSessionResponse is returned either way.
        """
        logging.info('End session request for sid: %s', sid.id)
        try:
            self._pipelined.DeactivateFlows(
                DeactivateFlowsRequest(sid=sid), timeout=self.RPC_TIMEOUT)
        except RpcError as err:
            self._set_rpc_error(context, err)
        return session_manager_pb2.LocalEndSessionResponse()

    def ReportRuleStats(self, request, context):
        """
        Handles stats update from the enforcement app in pipelined. We are
        ignoring this for now, since the applications can poll pipelined
        for the flow stats.
        """
        logging.debug('Ignoring ReportRuleStats rpc')
        return Void()

    def _captive_portal_enabled(self, sid):
        """
        Decide whether the captive portal applies to the subscriber.

        Returns False when the feature is disabled and True when the
        profile substring is empty; otherwise looks the subscriber up in
        subscriberdb and returns whether the configured substring occurs
        in its sub_profile. The lookup may raise RpcError.
        """
        if not self._enabled:
            return False  # Service is disabled

        if self._sub_profile_substr == '':
            return True  # Allow all subscribers

        sub = self._subscriberdb.GetSubscriberData(
            sid, timeout=self.RPC_TIMEOUT)
        return self._sub_profile_substr in sub.sub_profile

    def _get_allow_all_traffic_policies(self):
        """ Policy to allow all traffic to the internet """
        return [
            PolicyRule(
                id='allow_all_traffic',
                priority=self.ALLOW_ALL_PRIORITY,
                flow_list=[
                    FlowDescription(match=FlowMatch(
                        direction=FlowMatch.UPLINK)),
                    FlowDescription(match=FlowMatch(
                        direction=FlowMatch.DOWNLINK)),
                ],
            )
        ]

    def _get_whitelisted_policies(self):
        """
        Policies to allow http traffic to the whitelisted sites.

        Builds one rule per (ip, port) pair of the whitelist config, each
        matching uplink TCP traffic to ip:port and downlink TCP traffic
        from ip:port.
        """
        rules = []
        for ip, ports in self._whitelisted_ips.items():
            # 'local' stands for the IP read from the configured bridge
            # interface; resolve it once per entry, not once per port
            addr = self._local_ip if ip == 'local' else ip
            for port in ports:
                rules.append(
                    PolicyRule(
                        id='whitelist',
                        priority=self.ALLOW_ALL_PRIORITY,
                        flow_list=[
                            FlowDescription(
                                match=FlowMatch(
                                    direction=FlowMatch.UPLINK,
                                    ip_proto=FlowMatch.IPPROTO_TCP,
                                    ipv4_dst=addr,
                                    tcp_dst=port)),
                            FlowDescription(
                                match=FlowMatch(
                                    direction=FlowMatch.DOWNLINK,
                                    ip_proto=FlowMatch.IPPROTO_TCP,
                                    ipv4_src=addr,
                                    tcp_src=port)),
                        ]))
        return rules

    def _get_redirect_policies(self):
        """ Policy to redirect traffic to the captive portal """
        redirect_info = RedirectInformation(
            support=RedirectInformation.ENABLED,
            address_type=RedirectInformation.URL,
            server_address=self._captive_portal_address)
        return [
            PolicyRule(
                id='redirect',
                priority=self.REDIRECT_PRIORITY,
                redirect=redirect_info)
        ]

    def _set_rpc_error(self, context, err):
        """ Log the RPC error and copy its details and code onto context """
        logging.error(err.details())
        context.set_details(err.details())
        context.set_code(err.code())
def _cleanup_subs():
    """
    Delete every subscriber currently listed by the local subscriberdb.
    """
    channel = ServiceRegistry.get_rpc_channel(
        'subscriberdb', ServiceRegistry.LOCAL,
    )
    subscriberdb = SubscriberDBStub(channel)
    listed = subscriberdb.ListSubscribers(Void())
    for entry in listed.sids:
        # Rebuild the 'IMSI'-prefixed string form before converting back
        # to a protobuf subscriber id
        imsi = 'IMSI%s' % entry.id
        subscriberdb.DeleteSubscriber(SIDUtils.to_pb(imsi))
class SubscriberDbGrpc(SubscriberDbClient):
    """
    Handle subscriber actions by making calls over gRPC directly to the
    gateway.
    """

    def __init__(self):
        """ Init the gRPC stub. """
        # Subscriber ids added through this client; compared against the
        # ids listed by subscriberdb in _check_invariants().
        self._added_sids = set()
        self._subscriber_stub = SubscriberDBStub(
            get_rpc_channel("subscriberdb"))

    @staticmethod
    def _try_to_call(grpc_call):
        """
        Call into SubscriberDB, retrying with exponential backoff while the
        call fails with status UNAVAILABLE.

        Args:
            grpc_call (callable): zero-argument callable performing the RPC

        Returns:
            The value returned by the first successful grpc_call()

        Raises:
            grpc.RpcError: immediately for any status other than
                UNAVAILABLE, or the last UNAVAILABLE error once all
                RETRY_COUNT attempts have failed
        """
        last_error = None
        for attempt in range(RETRY_COUNT):
            try:
                return grpc_call()
            except grpc.RpcError as error:
                # RpcErrors raised by stub invocations also implement
                # grpc.Call, which exposes the status code directly.
                if error.code() != grpc.StatusCode.UNAVAILABLE:
                    logging.error(
                        "Subscriberdb grpc call failed with error : %s",
                        error)
                    raise
                last_error = error
                logging.warning("Subscriberdb unavailable, retrying...")
                # No point backing off after the final attempt
                if attempt < RETRY_COUNT - 1:
                    time.sleep(RETRY_INTERVAL * (2 ** attempt))
        if last_error is not None:
            # Every attempt was UNAVAILABLE: surface the failure instead of
            # silently handing None back to the caller.
            raise last_error
        return None

    @staticmethod
    def _get_subscriberdb_data(sid):
        """
        Get subscriber data in protobuf format.

        Args:
            sid (str): string representation of the subscriber id

        Returns:
            subscriber_data (protos.subscriberdb_pb2.SubscriberData):
                full subscriber information for :sid: in protobuf format.
        """
        sub_db_sid = SIDUtils.to_pb(sid)
        lte = LTESubscription()
        lte.state = LTESubscription.ACTIVE
        lte.auth_key = bytes.fromhex(KEY)
        state = SubscriberState()
        state.lte_auth_next_seq = 1
        return SubscriberData(sid=sub_db_sid, lte=lte, state=state)

    def _check_invariants(self):
        """
        Assert preservation of invariants: the number of locally tracked
        sids must equal the number of sids listed by subscriberdb.

        Raises:
            AssertionError: when invariants do not hold
        """
        sids_eq_len = len(self._added_sids) == len(self.list_subscriber_sids())
        assert sids_eq_len

    def add_subscriber(self, sid):
        """
        Track :sid: locally, add it to subscriberdb and check invariants.

        Args:
            sid (str): string representation of the subscriber id
        """
        logging.info("Adding subscriber : %s", sid)
        self._added_sids.add(sid)
        sub_data = self._get_subscriberdb_data(sid)
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.AddSubscriber(sub_data))
        self._check_invariants()

    def delete_subscriber(self, sid):
        """
        Stop tracking :sid: locally and delete it from subscriberdb.

        Args:
            sid (str): subscriber id string; its first four characters are
                dropped to form the protobuf id (list_subscriber_sids()
                prepends 'IMSI' to the ids it returns)
        """
        logging.info("Deleting subscriber : %s", sid)
        self._added_sids.discard(sid)
        sid_pb = SubscriberID(id=sid[4:])
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.DeleteSubscriber(sid_pb))

    def list_subscriber_sids(self):
        """
        List the subscribers known to subscriberdb.

        Returns:
            list of str: each listed id prefixed with 'IMSI'
        """
        sids_pb = SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.ListSubscribers(Void()).sids)
        return ['IMSI' + sid.id for sid in sids_pb]

    def clean_up(self):
        """
        Delete every subscriber listed by subscriberdb and assert that both
        subscriberdb and the local tracking set end up empty.
        """
        # Remove all sids
        for sid in self.list_subscriber_sids():
            self.delete_subscriber(sid)
        assert not self.list_subscriber_sids()
        assert not self._added_sids

    def wait_for_changes(self):
        """ No-op: on gateway, changes propagate immediately. """
        return
def __init__(self):
    """
    Initialize the base class with a SubscriberDB gRPC stub created on
    the "subscriberdb" channel.
    """
    subscriberdb_channel = get_rpc_channel("subscriberdb")
    subscriberdb_stub = SubscriberDBStub(subscriberdb_channel)
    super().__init__(subscriberdb_stub)