コード例 #1
0
class SubscriberDbGrpc(SubscriberDbClient):
    """
    Handle subscriber actions by making calls over gRPC directly to the
    gateway.

    The client also keeps a local set of the subscriber IDs it has added
    (``_added_sids``). It uses that set to sanity-check the gateway's
    subscriber table after each addition.
    """
    def __init__(self):
        """ Init the gRPC stub.  """
        # SIDs successfully added (and not yet deleted) through this client.
        self._added_sids = set()
        self._subscriber_stub = SubscriberDBStub(
            get_rpc_channel("subscriberdb"))

    @staticmethod
    def _try_to_call(grpc_call):
        """
        Call into SubscriberDB, retrying while the service is unavailable.

        Args:
            grpc_call (callable): zero-argument callable that performs the RPC.
        Returns:
            Whatever ``grpc_call`` returns.
        Raises:
            grpc.RpcError: immediately for any non-UNAVAILABLE status, or
                after RETRY_COUNT attempts if the service stays UNAVAILABLE.
        """
        for attempt in range(RETRY_COUNT):
            try:
                return grpc_call()
            except grpc.RpcError as error:
                err_code = error.code()
                is_last_attempt = attempt == RETRY_COUNT - 1
                # If unavailable and attempts remain, back off and try again.
                if (err_code == grpc.StatusCode.UNAVAILABLE
                        and not is_last_attempt):
                    logging.warning("Subscriberdb unavailable, retrying...")
                    # Exponential backoff: RETRY_INTERVAL, 2x, 4x, ...
                    time.sleep(RETRY_INTERVAL * (2 ** attempt))
                    continue
                # Any other error, or retries exhausted: surface it instead
                # of silently returning None to the caller.
                logging.error("Subscriberdb grpc call failed with error : %s",
                              error)
                raise

    @staticmethod
    def _get_subscriberdb_data(sid):
        """
        Get subscriber data in protobuf format.

        Args:
            sid (str): string representation of the subscriber id
        Returns:
            subscriber_data (protos.subscriberdb_pb2.SubscriberData):
                full subscriber information for :sid: in protobuf format,
                with an ACTIVE LTE subscription keyed by KEY and
                lte_auth_next_seq set to 1.
        """
        sub_db_sid = SIDUtils.to_pb(sid)
        lte = LTESubscription()
        lte.state = LTESubscription.ACTIVE
        lte.auth_key = bytes.fromhex(KEY)
        state = SubscriberState()
        state.lte_auth_next_seq = 1
        return SubscriberData(sid=sub_db_sid, lte=lte, state=state)

    @staticmethod
    def _get_apn_data(sid, apn_list):
        """
        Get APN data in protobuf format.

        Args:
            sid (SubscriberID): protobuf subscriber id to update
            apn_list (list[dict]): APN configurations. Each dict must have
                the keys "apn_name", "qci", "priority", "pre_cap", "pre_vul",
                "mbr_ul" and "mbr_dl"; "pdn_type" is optional (defaults to 0).
        Returns:
            update (protos.subscriberdb_pb2.SubscriberUpdate)
        """
        update = SubscriberUpdate()
        update.data.sid.CopyFrom(sid)
        non_3gpp = update.data.non_3gpp
        for apn in apn_list:
            apn_config = non_3gpp.apn_config.add()
            apn_config.service_selection = apn["apn_name"]
            apn_config.qos_profile.class_id = apn["qci"]
            apn_config.qos_profile.priority_level = apn["priority"]
            apn_config.qos_profile.preemption_capability = apn["pre_cap"]
            apn_config.qos_profile.preemption_vulnerability = apn["pre_vul"]
            apn_config.ambr.max_bandwidth_ul = apn["mbr_ul"]
            apn_config.ambr.max_bandwidth_dl = apn["mbr_dl"]
            apn_config.pdn = apn.get("pdn_type", 0)
        return update

    def _check_invariants(self):
        """
        Assert preservation of invariants.

        Checks that the number of SIDs recorded locally matches the number
        of subscribers reported by the gateway.

        Raises:
            AssertionError: when invariants do not hold
        """
        sids_eq_len = len(self._added_sids) == len(self.list_subscriber_sids())
        assert sids_eq_len

    def add_subscriber(self, sid):
        """
        Add a subscriber with the default LTE profile, then check invariants.

        Args:
            sid (str): subscriber id, e.g. "IMSI001010000000001"
        """
        logging.info("Adding subscriber : %s", sid)
        sub_data = self._get_subscriberdb_data(sid)
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.AddSubscriber(sub_data))
        # Record the SID only once the gateway accepted it, so a failed RPC
        # does not leave the local bookkeeping out of sync.
        self._added_sids.add(sid)
        self._check_invariants()

    def delete_subscriber(self, sid):
        """
        Delete a subscriber from the gateway.

        Args:
            sid (str): subscriber id; the first four characters (the "IMSI"
                prefix) are stripped before building the protobuf id.
        """
        logging.info("Deleting subscriber : %s", sid)
        sid_pb = SubscriberID(id=sid[4:])
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.DeleteSubscriber(sid_pb))
        # Forget the SID only after the gateway confirmed the deletion.
        self._added_sids.discard(sid)

    def list_subscriber_sids(self):
        """
        List the subscriber ids known to the gateway.

        Returns:
            list[str]: ids prefixed with "IMSI".
        """
        sids_pb = SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.ListSubscribers(Void()).sids)
        return ['IMSI' + sid.id for sid in sids_pb]

    def config_apn_details(self, imsi, apn_list):
        """
        Replace the non-3GPP APN configuration of a subscriber.

        Args:
            imsi (str): subscriber id
            apn_list (list[dict]): APN configurations (see _get_apn_data)
        """
        sid = SIDUtils.to_pb(imsi)
        update_sub = self._get_apn_data(sid, apn_list)
        # Only the 'non_3gpp' field is applied by the update.
        update_sub.mask.paths.append('non_3gpp')
        SubscriberDbGrpc._try_to_call(
            lambda: self._subscriber_stub.UpdateSubscriber(update_sub))

    def clean_up(self):
        """
        Delete every subscriber on the gateway and verify nothing remains.
        """
        # Remove all sids
        for sid in self.list_subscriber_sids():
            self.delete_subscriber(sid)
        assert not self.list_subscriber_sids()
        assert not self._added_sids

    def wait_for_changes(self):
        """ No-op: on the gateway, changes propagate immediately. """
        return
コード例 #2
0
ファイル: rpc_tests.py プロジェクト: talkhasib/magma
class RpcTests(unittest.TestCase):
    """
    Tests for the SubscriberDB rpc servicer and stub
    """
    def setUp(self):
        # Create a SQLite-backed store inside a throwaway temporary directory
        self._tmpfile = tempfile.TemporaryDirectory()
        store = SqliteStore(self._tmpfile.name + '/')

        # Bind the rpc server to a free port
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10), )
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')

        # Add the servicer
        self._servicer = SubscriberDBRpcServicer(
            store=store, lte_processor=self._create_lte_processor_mock())
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()

        # Create a rpc stub; keep the channel so tearDown can close it.
        self._channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))
        self._stub = SubscriberDBStub(self._channel)

    def _create_lte_processor_mock(self):
        """
        Build a mock LTE processor whose get_sub_profile() returns a profile
        with max_ul_bit_rate=23 and max_dl_bit_rate=42.
        """
        lte_processor_mock = MagicMock()
        sub_profile_mock = MagicMock()
        lte_processor_mock.get_sub_profile.return_value = sub_profile_mock
        sub_profile_mock.max_ul_bit_rate = 23
        sub_profile_mock.max_dl_bit_rate = 42
        return lte_processor_mock

    def tearDown(self):
        # Release resources in reverse order of setUp. Close the client
        # channel first, then stop the server and wait until it has fully
        # shut down. Only then remove the directory backing the store, so no
        # in-flight handler can touch files that are being deleted.
        self._channel.close()
        self._rpc_server.stop(0).wait()
        self._tmpfile.cleanup()

    def test_get_invalid_subscriber(self):
        """
        Test if the rpc call returns NOT_FOUND
        """
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.GetSubscriberData(SIDUtils.to_pb('IMSI123'))
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)

    def test_add_delete_subscriber(self):
        """
        Test if AddSubscriber and DeleteSubscriber rpc call works
        """
        sid = SIDUtils.to_pb('IMSI1')
        data = SubscriberData(sid=sid)

        # Add subscriber
        self._stub.AddSubscriber(data)

        # Add subscriber again
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.AddSubscriber(data)
        self.assertEqual(err.exception.code(), grpc.StatusCode.ALREADY_EXISTS)

        # See if we can get the data for the subscriber
        self.assertEqual(self._stub.GetSubscriberData(sid).sid, data.sid)
        self.assertEqual(len(self._stub.ListSubscribers(Void()).sids), 1)
        self.assertEqual(self._stub.ListSubscribers(Void()).sids[0], sid)

        # Delete the subscriber
        self._stub.DeleteSubscriber(sid)
        self.assertEqual(len(self._stub.ListSubscribers(Void()).sids), 0)

    def test_update_subscriber(self):
        """
        Test if UpdateSubscriber rpc call works
        """
        sid = SIDUtils.to_pb('IMSI1')
        data = SubscriberData(sid=sid)

        # Add subscriber
        self._stub.AddSubscriber(data)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.lte.auth_key, b'')
        self.assertEqual(sub.state.lte_auth_next_seq, 0)

        # Update subscriber. lte_auth_next_seq is set in the payload, but
        # the field mask names only auth_key, so the sequence number must
        # not change.
        update = SubscriberUpdate()
        update.data.sid.CopyFrom(sid)
        update.data.lte.auth_key = b'\xab\xcd'
        update.data.state.lte_auth_next_seq = 1
        update.mask.paths.append('lte.auth_key')  # only auth_key
        self._stub.UpdateSubscriber(update)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.state.lte_auth_next_seq, 0)  # no change
        self.assertEqual(sub.lte.auth_key, b'\xab\xcd')

        # Now include the sequence number in the mask as well
        update.data.state.lte_auth_next_seq = 1
        update.mask.paths.append('state.lte_auth_next_seq')
        self._stub.UpdateSubscriber(update)

        sub = self._stub.GetSubscriberData(sid)
        self.assertEqual(sub.state.lte_auth_next_seq, 1)

        # Delete the subscriber
        self._stub.DeleteSubscriber(sid)

        # Updating a deleted subscriber must fail with NOT_FOUND
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.UpdateSubscriber(update)
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)