Esempio n. 1
0
def main():
    """
    main() for policydb.

    Registers a SessionRpcServicer (backed by a stub for the local
    subscriberdb service) on the policydb gRPC server, optionally starts a
    StreamerClient for the "policydb" stream, then runs the service loop and
    closes the service once the loop returns.
    """
    service = MagmaService('policydb', mconfigs_pb2.PolicyDB())

    # Add all servicers to the server. The session servicer is handed a stub
    # for the local subscriberdb service.
    chan = ServiceRegistry.get_rpc_channel('subscriberdb',
                                           ServiceRegistry.LOCAL)
    subscriberdb_stub = SubscriberDBStub(chan)
    session_servicer = SessionRpcServicer(service.config, subscriberdb_stub)
    session_servicer.add_to_server(service.rpc_server)

    # Start a background thread to stream updates from the cloud
    if service.config['enable_streaming']:
        callback = PolicyDBStreamerCallback(service.loop)
        stream = StreamerClient({"policydb": callback}, service.loop)
        stream.start()
    else:
        logging.info('enable_streaming set to False. Streamer disabled!')

    # Run the service loop (presumably blocks until shutdown)
    service.run()

    # Cleanup the service
    service.close()
Esempio n. 2
0
 def __init__(self):
     """Start with no tracked SIDs and open a gRPC stub to subscriberdb."""
     self._added_sids = set()
     channel = get_rpc_channel("subscriberdb")
     self._subscriber_stub = SubscriberDBStub(channel)
Esempio n. 3
0
class SubscriberDbGrpc(SubscriberDbClient):
    """
    Handle subscriber actions by making calls over gRPC directly to the
    gateway.
    """
    def __init__(self):
        """ Init the gRPC stub.  """
        # SIDs added through this client; _check_invariants() compares their
        # count with the gateway's subscriber list.
        self._added_sids = set()
        self._subscriber_stub = SubscriberDBStub(
            get_rpc_channel("subscriberdb"))

    @staticmethod
    def _try_to_call(grpc_call):
        """
        Attempt to call into SubscriberDB and retry if unavailable.

        UNAVAILABLE errors are retried, for at most RETRY_COUNT attempts in
        total, sleeping RETRY_INTERVAL * 2**attempt seconds between attempts.
        Any other gRPC error is logged and re-raised immediately.

        Args:
            grpc_call: zero-argument callable that performs the RPC.
        Returns:
            Whatever grpc_call() returns.
        Raises:
            grpc.RpcError: for a non-UNAVAILABLE status, or when SubscriberDB
                is still UNAVAILABLE on the final attempt.
        """
        for attempt in range(RETRY_COUNT):
            try:
                return grpc_call()
            except grpc.RpcError as error:
                err_code = error.code()
                if err_code != grpc.StatusCode.UNAVAILABLE:
                    logging.error(
                        "Subscriberdb grpc call failed with error : %s",
                        error)
                    raise
                if attempt == RETRY_COUNT - 1:
                    # Out of retries: surface the error instead of silently
                    # returning None to callers that expect a response.
                    logging.error(
                        "Subscriberdb still unavailable after %d attempts, "
                        "giving up", RETRY_COUNT)
                    raise
                # If unavailable, back off exponentially and try again
                logging.warning("Subscriberdb unavailable, retrying...")
                time.sleep(RETRY_INTERVAL * (2 ** attempt))

    @staticmethod
    def _get_subscriberdb_data(sid):
        """
        Get subscriber data in protobuf format.

        Args:
            sid (str): string representation of the subscriber id
        Returns:
            subscriber_data (protos.subscriberdb_pb2.SubscriberData):
                full subscriber information for :sid: in protobuf format,
                with an ACTIVE LTE subscription keyed by KEY and
                lte_auth_next_seq set to 1.
        """
        sub_db_sid = SIDUtils.to_pb(sid)
        lte = LTESubscription()
        lte.state = LTESubscription.ACTIVE
        lte.auth_key = bytes.fromhex(KEY)
        state = SubscriberState()
        state.lte_auth_next_seq = 1
        return SubscriberData(sid=sub_db_sid, lte=lte, state=state)

    @staticmethod
    def _get_apn_data(sid, apn_list):
        """
        Get APN data in protobuf format.

        Args:
            sid: subscriber id protobuf (as produced by SIDUtils.to_pb),
                copied into the update.
            apn_list : list of APN configuration dicts. Each must provide
                "apn_name", "qci", "priority", "pre_cap", "pre_vul",
                "mbr_ul" and "mbr_dl"; "pdn_type" is optional (default 0).
        Returns:
            update (protos.subscriberdb_pb2.SubscriberUpdate): update whose
                data.non_3gpp holds one apn_config per entry. The field mask
                is left empty for the caller to fill in.
        """
        # APN
        update = SubscriberUpdate()
        update.data.sid.CopyFrom(sid)
        non_3gpp = update.data.non_3gpp
        for apn in apn_list:
            apn_config = non_3gpp.apn_config.add()
            apn_config.service_selection = apn["apn_name"]
            apn_config.qos_profile.class_id = apn["qci"]
            apn_config.qos_profile.priority_level = apn["priority"]
            apn_config.qos_profile.preemption_capability = apn["pre_cap"]
            apn_config.qos_profile.preemption_vulnerability = apn["pre_vul"]
            apn_config.ambr.max_bandwidth_ul = apn["mbr_ul"]
            apn_config.ambr.max_bandwidth_dl = apn["mbr_dl"]
            apn_config.pdn = apn.get("pdn_type", 0)
        return update

    def _check_invariants(self):
        """
        Assert preservation of invariants.

        Raises:
            AssertionError: when invariants do not hold
        """
        sids_eq_len = len(self._added_sids) == len(self.list_subscriber_sids())
        assert sids_eq_len

    def add_subscriber(self, sid):
        """Add subscriber `sid` (e.g. 'IMSI...') and check invariants."""
        logging.info("Adding subscriber : %s", sid)
        self._added_sids.add(sid)
        sub_data = self._get_subscriberdb_data(sid)
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.AddSubscriber(sub_data))
        self._check_invariants()

    def delete_subscriber(self, sid):
        """Delete subscriber `sid` (a string carrying the 'IMSI' prefix)."""
        logging.info("Deleting subscriber : %s", sid)
        self._added_sids.discard(sid)
        # Strip the 4-character 'IMSI' prefix that list_subscriber_sids adds
        sid_pb = SubscriberID(id=sid[4:])
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.DeleteSubscriber(sid_pb))

    def list_subscriber_sids(self):
        """Return all subscriber ids on the gateway as 'IMSI<id>' strings."""
        sids_pb = SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.ListSubscribers(Void()).sids)
        sids = ['IMSI' + sid.id for sid in sids_pb]
        return sids

    def config_apn_details(self, imsi, apn_list):
        """
        Push the APN list for `imsi` to SubscriberDB.

        Sends an UpdateSubscriber whose field mask is limited to 'non_3gpp',
        so only that part of the subscriber record is targeted.
        """
        sid = SIDUtils.to_pb(imsi)
        update_sub = self._get_apn_data(sid, apn_list)
        update_sub.mask.paths.append('non_3gpp')
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.UpdateSubscriber(update_sub))

    def clean_up(self):
        """Delete every subscriber on the gateway and assert none remain."""
        # Remove all sids
        for sid in self.list_subscriber_sids():
            self.delete_subscriber(sid)
        assert not self.list_subscriber_sids()
        assert not self._added_sids

    def wait_for_changes(self):
        """No-op for this client."""
        # On gateway, changes propagate immediately
        return
Esempio n. 4
0
def main():
    """
    main() for MobilityD.

    Wires a MobilityStore into the IPv4 and IPv6 allocators, builds the
    IPAddressManager, registers MobilityServiceRpcServicer, loads the IP
    blocks named in mconfig, and finally runs and closes the service.
    """
    service = MagmaService('mobilityd', mconfigs_pb2.MobilityD())
    cfg = service.config
    mcfg = service.mconfig

    # Service-config values override the mconfig-provided defaults
    multi_apn = cfg.get('multi_apn', mcfg.multi_apn_ip_alloc)
    static_ip = cfg.get('static_ip', mcfg.static_ip_enabled)
    allocator_type = mcfg.ip_allocator_type
    dhcp_iface = cfg.get('dhcp_iface', 'dhcp0')
    dhcp_retry_limit = cfg.get('retry_limit', 300)

    # TODO: consider adding gateway mconfig to decide whether to
    # persist to Redis
    redis_client = get_default_client()
    persist_to_redis = cfg.get('persist_to_redis', False)
    redis_port = cfg.get('redis_port', 6380)
    store = MobilityStore(redis_client, persist_to_redis, redis_port)

    subscriberdb_chan = ServiceRegistry.get_rpc_channel(
        'subscriberdb', ServiceRegistry.LOCAL)
    subscriberdb_stub = SubscriberDBStub(subscriberdb_chan)
    ipv4_allocator = _get_ipv4_allocator(
        store, allocator_type, static_ip, multi_apn,
        dhcp_iface, dhcp_retry_limit, subscriberdb_stub)

    # Only IP_POOL mode is supported for IPv6 for now; fall back to the
    # default prefix allocation mode when mconfig leaves it unset/empty.
    ipv6_mode = mcfg.ipv6_prefix_allocation_type
    if not ipv6_mode:
        ipv6_mode = DEFAULT_IPV6_PREFIX_ALLOC_MODE
    ipv6_allocator = IPv6AllocatorPool(
        store=store, session_prefix_alloc_mode=ipv6_mode)

    ip_address_man = IPAddressManager(ipv4_allocator, ipv6_allocator, store)

    servicer = MobilityServiceRpcServicer(
        ip_address_man, cfg.get('print_grpc_payload', False))
    servicer.add_to_server(service.rpc_server)

    def add_block_if_present(block, overlap_msg):
        """Add `block` through the servicer; overlap is only logged."""
        if block is None:
            return
        try:
            servicer.add_ip_block(block)
        except OverlappedIPBlocksError:
            logging.warning(overlap_msg, block)

    # Blocks are loaded once from mconfig; updates require a restart.
    logging.info('Adding IPv4 block')
    add_block_if_present(_get_ip_block(mcfg.ip_block),
                         "Overlapped IPv4 block: %s")
    logging.info('Adding IPv6 block')
    add_block_if_present(_get_ip_block(mcfg.ipv6_block),
                         "Overlapped IPv6 block: %s")

    service.run()
    service.close()
Esempio n. 5
0
class RpcTests(unittest.TestCase):
    """
    Tests for the SubscriberDB rpc servicer and stub
    """
    def setUp(self):
        # Create a SQLite-backed store inside a fresh temporary directory
        self._tmpfile = tempfile.TemporaryDirectory()
        store = SqliteStore(self._tmpfile.name + '/')

        # Bind the rpc server to a free port ('0' asks for any port; the
        # chosen one is returned)
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10))
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')

        # Add the servicer
        self._servicer = SubscriberDBRpcServicer(
            store=store, lte_processor=self._create_lte_processor_mock())
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()

        # Create a rpc stub; keep the channel so tearDown can close it
        self._channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))
        self._stub = SubscriberDBStub(self._channel)

    def _create_lte_processor_mock(self):
        """Mock LTE processor whose sub profile has fixed UL/DL bit rates."""
        lte_processor_mock = MagicMock()
        sub_profile_mock = MagicMock()
        lte_processor_mock.get_sub_profile.return_value = sub_profile_mock
        sub_profile_mock.max_ul_bit_rate = 23
        sub_profile_mock.max_dl_bit_rate = 42
        return lte_processor_mock

    def tearDown(self):
        # Tear down in reverse order of setUp: close the client channel,
        # stop the server and wait until it has fully stopped, and only then
        # delete the directory holding the store's files.
        self._channel.close()
        self._rpc_server.stop(0).wait()
        self._tmpfile.cleanup()

    def test_get_invalid_subscriber(self):
        """
        Test if the rpc call returns NOT_FOUND
        """
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.GetSubscriberData(SIDUtils.to_pb('IMSI123'))
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)

    def test_add_delete_subscriber(self):
        """
        Test if AddSubscriber and DeleteSubscriber rpc call works
        """
        sid = SIDUtils.to_pb('IMSI1')
        data = SubscriberData(sid=sid)

        # Add subscriber
        self._stub.AddSubscriber(data)

        # Add subscriber again
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.AddSubscriber(data)
        self.assertEqual(err.exception.code(), grpc.StatusCode.ALREADY_EXISTS)

        # See if we can get the data for the subscriber
        self.assertEqual(self._stub.GetSubscriberData(sid).sid, data.sid)
        self.assertEqual(len(self._stub.ListSubscribers(Void()).sids), 1)
        self.assertEqual(self._stub.ListSubscribers(Void()).sids[0], sid)

        # Delete the subscriber
        self._stub.DeleteSubscriber(sid)
        self.assertEqual(len(self._stub.ListSubscribers(Void()).sids), 0)

    def test_update_subscriber(self):
        """
        Test if UpdateSubscriber rpc call works and honours the field mask
        """
        sid = SIDUtils.to_pb('IMSI1')
        data = SubscriberData(sid=sid)

        # Add subscriber
        self._stub.AddSubscriber(data)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.lte.auth_key, b'')
        self.assertEqual(sub.state.lte_auth_next_seq, 0)

        # Update subscriber
        update = SubscriberUpdate()
        update.data.sid.CopyFrom(sid)
        update.data.lte.auth_key = b'\xab\xcd'
        update.data.state.lte_auth_next_seq = 1
        update.mask.paths.append('lte.auth_key')  # only auth_key
        self._stub.UpdateSubscriber(update)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.state.lte_auth_next_seq, 0)  # no change
        self.assertEqual(sub.lte.auth_key, b'\xab\xcd')

        # update.data.state.lte_auth_next_seq is still 1; now include it in
        # the mask so it gets applied.
        update.mask.paths.append('state.lte_auth_next_seq')
        self._stub.UpdateSubscriber(update)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.state.lte_auth_next_seq, 1)

        # Delete the subscriber
        self._stub.DeleteSubscriber(sid)

        with self.assertRaises(grpc.RpcError) as err:
            self._stub.UpdateSubscriber(update)
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)
Esempio n. 6
0
def main():
    """
    Start mobilityd.

    Builds the IPv4/IPv6 allocators over a MobilityStore, reconciles the IP
    blocks configured in mconfig with the ones already known to the
    IPAddressManager, registers the MobilityServiceRpcServicer, then runs and
    closes the service.
    """
    service = MagmaService('mobilityd', mconfigs_pb2.MobilityD())

    # Optionally pipe errors to Sentry
    sentry_init(service_name=service.name)

    cfg = service.config
    mcfg = service.mconfig

    # Service-config values override the mconfig-provided defaults
    multi_apn = cfg.get('multi_apn', mcfg.multi_apn_ip_alloc)
    static_ip = cfg.get('static_ip', mcfg.static_ip_enabled)
    allocator_type = mcfg.ip_allocator_type
    dhcp_iface = cfg.get('dhcp_iface', 'dhcp0')
    dhcp_retry_limit = cfg.get('retry_limit', RETRY_LIMIT)

    # TODO: consider adding gateway mconfig to decide whether to
    # persist to Redis
    redis_client = get_default_client()
    store = MobilityStore(
        redis_client,
        cfg.get('persist_to_redis', False),
        cfg.get('redis_port', DEFAULT_REDIS_PORT),
    )

    subscriberdb_stub = SubscriberDBStub(
        ServiceRegistry.get_rpc_channel('subscriberdb', ServiceRegistry.LOCAL),
    )
    ipv4_allocator = _get_ipv4_allocator(
        store, allocator_type, static_ip, multi_apn,
        dhcp_iface, dhcp_retry_limit, subscriberdb_stub,
    )

    # Only IP_POOL mode is supported for IPv6 for now
    ipv6_mode = _get_value_or_default(
        mcfg.ipv6_prefix_allocation_type,
        DEFAULT_IPV6_PREFIX_ALLOC_MODE,
    )
    ipv6_allocator = IPv6AllocatorPool(
        store=store, session_prefix_alloc_mode=ipv6_mode,
    )

    ip_address_man = IPAddressManager(ipv4_allocator, ipv6_allocator, store)

    def sync_ipv4_block(block):
        """Make `block` the only IPv4 block unless it is already added."""
        if block is None:
            return
        logging.info('Adding IPv4 block')
        try:
            current_blocks = ip_address_man.list_added_ip_blocks()
            if block in current_blocks:
                return
            # Cleanup previously allocated IP blocks before adding the new one
            ip_address_man.remove_ip_blocks(*current_blocks, force=True)
            ip_address_man.add_ip_block(block)
        except OverlappedIPBlocksError:
            logging.warning("Overlapped IPv4 block: %s", block)

    def sync_ipv6_block(block):
        """Add `block` unless it is already the assigned IPv6 block."""
        if block is None:
            return
        logging.info('Adding IPv6 block')
        try:
            if block == ip_address_man.get_assigned_ipv6_block():
                return
            ip_address_man.add_ip_block(block)
        except OverlappedIPBlocksError:
            logging.warning("Overlapped IPv6 block: %s", block)

    # Blocks are loaded once from mconfig; updates require a restart.
    sync_ipv4_block(_get_ip_block(mcfg.ip_block, "IPv4"))
    sync_ipv6_block(_get_ip_block(mcfg.ipv6_block, "IPv6"))

    servicer = MobilityServiceRpcServicer(
        ip_address_man, cfg.get('print_grpc_payload', False),
    )
    servicer.add_to_server(service.rpc_server)
    service.run()

    # Cleanup the service
    service.close()
Esempio n. 7
0
class SessionRpcServicer(session_manager_pb2_grpc.LocalSessionManagerServicer):
    """
    gRPC based server for LocalSessionManager service.

    Session requests are turned into flow (de)activation calls on the local
    pipelined service. Subscribers selected for the captive portal get
    whitelist + redirect policies; everyone else gets allow-all policies.
    """
    ALLOW_ALL_PRIORITY = 100
    REDIRECT_PRIORITY = 2000
    # Passed as `timeout=` to every outgoing gRPC call (seconds in gRPC)
    RPC_TIMEOUT = 5

    def __init__(self, service):
        """
        Args:
            service: owning service; its `config` must provide
                captive_portal_enabled, captive_portal_url,
                bridge_interface, whitelisted_ips and
                subscriber_profile_substr_match.
        """
        chan = ServiceRegistry.get_rpc_channel('pipelined',
                                               ServiceRegistry.LOCAL)
        self._pipelined = PipelinedStub(chan)
        chan = ServiceRegistry.get_rpc_channel('subscriberdb',
                                               ServiceRegistry.LOCAL)
        self._subscriberdb = SubscriberDBStub(chan)
        self._enabled = service.config['captive_portal_enabled']
        self._captive_portal_address = service.config['captive_portal_url']
        # Address substituted for the 'local' key of whitelisted_ips
        self._local_ip = get_ip_from_if(service.config['bridge_interface'])
        self._whitelisted_ips = service.config['whitelisted_ips']
        self._sub_profile_substr = service.config[
            'subscriber_profile_substr_match']

    def add_to_server(self, server):
        """
        Add the servicer to a gRPC server
        """
        session_manager_pb2_grpc.add_LocalSessionManagerServicer_to_server(
            self,
            server,
        )

    def CreateSession(self, request, context):
        """
        Handles create session request from MME by installing the necessary
        flows in pipelined's enforcement app.

        RPC failures (including the subscriberdb lookup) are reported through
        `context`; an empty response is returned either way.
        """
        sid = request.sid
        logging.info('Create session request for sid: %s', sid.id)
        try:
            # Gather the set of policy rules to use
            if self._captive_portal_enabled(sid):
                rules = []
                rules.extend(self._get_whitelisted_policies())
                rules.extend(self._get_redirect_policies())
            else:
                rules = self._get_allow_all_traffic_policies()

            # Activate the flows in the enforcement app in pipelined
            act_request = ActivateFlowsRequest(sid=sid,
                                               ip_addr=request.ue_ipv4,
                                               dynamic_rules=rules)
            act_response = self._pipelined.ActivateFlows(
                act_request, timeout=self.RPC_TIMEOUT)
            for res in act_response.dynamic_rule_results:
                if res.result != res.SUCCESS:
                    # Hmm rolling back partial success is difficult
                    # Let's just log this for now
                    logging.error('Failed to activate rule: %s', res.rule_id)
        except RpcError as err:
            self._set_rpc_error(context, err)
        return session_manager_pb2.LocalCreateSessionResponse()

    def EndSession(self, sid, context):
        """
        Handles end session request from MME by removing all the flows
        for the subscriber in pipelined's enforcement app.
        """
        logging.info('End session request for sid: %s', sid.id)
        try:
            self._pipelined.DeactivateFlows(DeactivateFlowsRequest(sid=sid),
                                            timeout=self.RPC_TIMEOUT)
        except RpcError as err:
            self._set_rpc_error(context, err)
        return session_manager_pb2.LocalEndSessionResponse()

    def ReportRuleStats(self, request, context):
        """
        Handles stats update from the enforcement app in pipelined. We are
        ignoring this for now, since the applications can poll pipelined for
        the flow stats.
        """
        logging.debug('Ignoring ReportRuleStats rpc')
        return Void()

    def _captive_portal_enabled(self, sid):
        """
        Return True when captive-portal policies should apply to `sid`.

        With an empty substring every subscriber qualifies; otherwise the
        subscriber's sub_profile (fetched from subscriberdb, which may raise
        RpcError) must contain the configured substring.
        """
        if not self._enabled:
            return False  # Service is disabled

        if self._sub_profile_substr == '':
            return True  # Allow all subscribers

        sub = self._subscriberdb.GetSubscriberData(sid,
                                                   timeout=self.RPC_TIMEOUT)
        return self._sub_profile_substr in sub.sub_profile

    def _get_allow_all_traffic_policies(self):
        """ Policy to allow all traffic to the internet """
        return [
            PolicyRule(
                id='allow_all_traffic',
                priority=self.ALLOW_ALL_PRIORITY,
                flow_list=[
                    FlowDescription(match=FlowMatch(
                        direction=FlowMatch.UPLINK)),
                    FlowDescription(match=FlowMatch(
                        direction=FlowMatch.DOWNLINK)),
                ],
            )
        ]

    def _get_whitelisted_policies(self):
        """
        Policies to allow http traffic to the whitelisted sites.

        `self._whitelisted_ips` maps an IP (or the key 'local', replaced by
        the bridge interface address) to the TCP ports to allow. One rule
        with an uplink and a downlink TCP flow is built per (ip, port).
        """
        rules = []
        for ip, ports in self._whitelisted_ips.items():
            # Resolve the 'local' placeholder once per entry rather than
            # rebinding the loop variable inside the port loop.
            target_ip = self._local_ip if ip == 'local' else ip
            for port in ports:
                # NOTE(review): every whitelist rule shares id='whitelist';
                # confirm pipelined accepts duplicate rule ids.
                rules.append(
                    PolicyRule(
                        id='whitelist',
                        priority=self.ALLOW_ALL_PRIORITY,
                        flow_list=[
                            FlowDescription(
                                match=FlowMatch(direction=FlowMatch.UPLINK,
                                                ip_proto=FlowMatch.IPPROTO_TCP,
                                                ipv4_dst=target_ip,
                                                tcp_dst=port)),
                            FlowDescription(
                                match=FlowMatch(direction=FlowMatch.DOWNLINK,
                                                ip_proto=FlowMatch.IPPROTO_TCP,
                                                ipv4_src=target_ip,
                                                tcp_src=port)),
                        ]))
        return rules

    def _get_redirect_policies(self):
        """ Policy to redirect traffic to the captive portal """
        redirect_info = RedirectInformation(
            support=RedirectInformation.ENABLED,
            address_type=RedirectInformation.URL,
            server_address=self._captive_portal_address)
        return [
            PolicyRule(id='redirect',
                       priority=self.REDIRECT_PRIORITY,
                       redirect=redirect_info)
        ]

    def _set_rpc_error(self, context, err):
        """Log `err` and copy its details and status code onto `context`."""
        logging.error(err.details())
        context.set_details(err.details())
        context.set_code(err.code())
Esempio n. 8
0
def _cleanup_subs():
    """Delete every subscriber listed by the local subscriberdb service."""
    channel = ServiceRegistry.get_rpc_channel('subscriberdb',
                                              ServiceRegistry.LOCAL)
    stub = SubscriberDBStub(channel)

    listing = stub.ListSubscribers(Void())
    for sub_id in listing.sids:
        # Rebuild the SubscriberID from its IMSI string form before deleting
        stub.DeleteSubscriber(SIDUtils.to_pb('IMSI%s' % sub_id.id))
Esempio n. 9
0
class SubscriberDbGrpc(SubscriberDbClient):
    """
    Handle subscriber actions by making calls over gRPC directly to the
    gateway.
    """

    def __init__(self):
        """ Init the gRPC stub.  """
        # SIDs added through this client; _check_invariants() compares their
        # count with the gateway's subscriber list.
        self._added_sids = set()
        self._subscriber_stub = SubscriberDBStub(
            get_rpc_channel("subscriberdb"))

    @staticmethod
    def _try_to_call(grpc_call):
        """
        Attempt to call into SubscriberDB and retry if unavailable.

        UNAVAILABLE errors are retried, for at most RETRY_COUNT attempts in
        total, sleeping RETRY_INTERVAL * 2**attempt seconds between attempts.
        Any other gRPC error is logged and re-raised immediately.

        Args:
            grpc_call: zero-argument callable that performs the RPC.
        Returns:
            Whatever grpc_call() returns.
        Raises:
            grpc.RpcError: for a non-UNAVAILABLE status, or when SubscriberDB
                is still UNAVAILABLE on the final attempt.
        """
        for attempt in range(RETRY_COUNT):
            try:
                return grpc_call()
            except grpc.RpcError as error:
                err_code = error.code()
                if err_code != grpc.StatusCode.UNAVAILABLE:
                    logging.error(
                        "Subscriberdb grpc call failed with error : %s",
                        error)
                    raise
                if attempt == RETRY_COUNT - 1:
                    # Out of retries: surface the error instead of silently
                    # returning None to callers that expect a response.
                    logging.error(
                        "Subscriberdb still unavailable after %d attempts, "
                        "giving up", RETRY_COUNT)
                    raise
                # If unavailable, back off exponentially and try again
                logging.warning("Subscriberdb unavailable, retrying...")
                time.sleep(RETRY_INTERVAL * (2 ** attempt))

    @staticmethod
    def _get_subscriberdb_data(sid):
        """
        Get subscriber data in protobuf format.

        Args:
            sid (str): string representation of the subscriber id
        Returns:
            subscriber_data (protos.subscriberdb_pb2.SubscriberData):
                full subscriber information for :sid: in protobuf format,
                with an ACTIVE LTE subscription keyed by KEY and
                lte_auth_next_seq set to 1.
        """
        sub_db_sid = SIDUtils.to_pb(sid)
        lte = LTESubscription()
        lte.state = LTESubscription.ACTIVE
        lte.auth_key = bytes.fromhex(KEY)
        state = SubscriberState()
        state.lte_auth_next_seq = 1
        return SubscriberData(sid=sub_db_sid, lte=lte, state=state)

    def _check_invariants(self):
        """
        Assert preservation of invariants.

        Raises:
            AssertionError: when invariants do not hold
        """
        sids_eq_len = len(self._added_sids) == len(self.list_subscriber_sids())
        assert sids_eq_len

    def add_subscriber(self, sid):
        """Add subscriber `sid` (e.g. 'IMSI...') and check invariants."""
        logging.info("Adding subscriber : %s", sid)
        self._added_sids.add(sid)
        sub_data = self._get_subscriberdb_data(sid)
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.AddSubscriber(sub_data))
        self._check_invariants()

    def delete_subscriber(self, sid):
        """Delete subscriber `sid` (a string carrying the 'IMSI' prefix)."""
        logging.info("Deleting subscriber : %s", sid)
        self._added_sids.discard(sid)
        # Strip the 4-character 'IMSI' prefix that list_subscriber_sids adds
        sid_pb = SubscriberID(id=sid[4:])
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.DeleteSubscriber(sid_pb))

    def list_subscriber_sids(self):
        """Return all subscriber ids on the gateway as 'IMSI<id>' strings."""
        sids_pb = SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.ListSubscribers(Void()).sids)
        sids = ['IMSI' + sid.id for sid in sids_pb]
        return sids

    def clean_up(self):
        """Delete every subscriber on the gateway and assert none remain."""
        # Remove all sids
        for sid in self.list_subscriber_sids():
            self.delete_subscriber(sid)
        assert not self.list_subscriber_sids()
        assert not self._added_sids

    def wait_for_changes(self):
        """No-op for this client."""
        # On gateway, changes propagate immediately
        return
Esempio n. 10
0
 def __init__(self):
     """Build a SubscriberDB gRPC stub and pass it to the base initializer."""
     subscriberdb_stub = SubscriberDBStub(get_rpc_channel("subscriberdb"))
     super().__init__(subscriberdb_stub)