예제 #1
0
    def test_add_three_subscribers(self):
        """
        Push quota updates for three subscribers in a single call, then wait
        for the testing controller inside a SnapshotVerifier context.

        Two subscribers are reported with NO_QUOTA and one with VALID_QUOTA.
        """
        # (IMSI, UE MAC address, quota state) for each subscriber under test
        subscribers = (
            ('IMSI010000000088888', '5e:cc:cc:b1:49:4b',
             SubscriberQuotaUpdate.NO_QUOTA),
            ('IMSI010000111111118', '5e:a:cc:af:aa:fe',
             SubscriberQuotaUpdate.NO_QUOTA),
            ('IMSI010002222222222', '5e:bb:cc:aa:aa:fe',
             SubscriberQuotaUpdate.VALID_QUOTA),
        )
        quota_updates = [
            SubscriberQuotaUpdate(
                sid=SubscriberID(id=imsi), mac_addr=mac, update_type=state,
            )
            for imsi, mac, state in subscribers
        ]

        # All three updates are delivered to the controller as one list
        self.check_quota_controller.update_subscriber_quota_state(
            quota_updates,
        )

        verifier = SnapshotVerifier(
            self, self.BRIDGE, self.service_manager, include_stats=False,
        )
        with verifier:
            wait_after_send(self.testing_controller)
예제 #2
0
    def test_ocs_failure(self):
        """
        Verify that a failed OCS update cuts service off until a later update
        succeeds.

        Steps: create a session with an initial credit grant, send traffic,
        wait for the failing update, check that no packets get through, swap
        in a succeeding update mock, check that packets flow again, and end
        the session.
        """
        util = self.test_util
        controller = util.controller
        subscriber = SubContextConfig(
            'IMSI001010000088888', '192.168.128.74', default_ambr_config, 4,
        )
        quota = 1024

        create_response = session_manager_pb2.CreateSessionResponse(
            credits=[create_update_response(subscriber.imsi, 1, quota)],
            static_rules=[
                session_manager_pb2.StaticRuleInstall(rule_id="simple_match"),
            ],
        )
        controller.mock_create_session = Mock(return_value=create_response)

        # The first update handler is built with success=False
        update_complete = hub.Queue()
        failing_update = get_standard_update_response(
            update_complete, None, quota, success=False,
        )
        controller.mock_update_session = Mock(side_effect=failing_update)

        controller.mock_terminate_session = Mock(
            return_value=session_manager_pb2.SessionTerminateResponse(),
        )

        create_request = session_manager_pb2.LocalCreateSessionRequest(
            sid=SubscriberID(id=subscriber.imsi),
            ue_ipv4=subscriber.ip,
        )
        util.sessiond.CreateSession(create_request)
        self.assertEqual(controller.mock_create_session.call_count, 1)

        packets = get_packets_for_flows(
            subscriber, util.static_rules["simple_match"].flow_list,
        )
        # One packet more than quota / packet length
        packet_count = int(quota / len(packets[0])) + 1
        sender = util.get_packet_sender([subscriber], packets, packet_count)

        # Right after session init, data can flow
        self.assertGreater(util.thread.run_in_greenthread(sender), 0)

        # Wait for the failed update to be reported
        self.assertIsNotNone(get_from_queue(update_complete))
        hub.sleep(2)

        # No data can be sent anymore
        self.assertEqual(util.thread.run_in_greenthread(sender), 0)

        succeeding_update = get_standard_update_response(
            update_complete, None, quota, success=True,
        )
        controller.mock_update_session = Mock(side_effect=succeeding_update)

        # Wait for the second update cycle to reactivate the service
        hub.sleep(4)
        self.assertGreater(util.thread.run_in_greenthread(sender), 0)

        util.sessiond.EndSession(SubscriberID(id=subscriber.imsi))
        self.assertEqual(controller.mock_terminate_session.call_count, 1)
예제 #3
0
def send_create_session(client, args):
    """
    Create PCRF and OCS accounts for ``args.imsi``, then send a CreateSession
    followed by an EndSession through ``client``.

    Every gRPC failure is printed and otherwise ignored, so the later steps
    still run after an earlier one fails.

    Args:
        client: stub exposing ``CreateSession`` and ``EndSession``
        args: parsed arguments; only ``args.imsi`` is read
    """
    def report(err):
        # Single place for the failure message shared by all four calls
        print("gRPC failed with %s: %s" % (err.code(), err.details()))

    subscriber = SubContextConfig('IMSI' + args.imsi, '192.168.128.74', 4)

    # PCRF first, then OCS; each is attempted regardless of the other
    for create_account in (create_account_in_PCRF, create_account_in_OCS):
        try:
            create_account(args.imsi)
        except grpc.RpcError as e:
            report(e)

    create_req = LocalCreateSessionRequest(
        sid=SubscriberID(id=subscriber.imsi),
        ue_ipv4=subscriber.ip,
    )
    print(
        "Sending LocalCreateSessionRequest with following fields:\n %s" %
        create_req,
    )
    try:
        client.CreateSession(create_req)
    except grpc.RpcError as e:
        report(e)

    end_req = SubscriberID(id=subscriber.imsi)
    print("Sending EndSession with following fields:\n %s" % end_req)
    try:
        client.EndSession(end_req)
    except grpc.RpcError as e:
        report(e)
예제 #4
0
    def test_basic_init(self):
        """
        Create a session whose response installs one static rule and one
        usage monitor, send traffic for that rule, check that a monitoring
        update arrives and that exactly one update request was made, then
        terminate the session.
        """
        util = self.test_util
        controller = util.controller
        subscriber = SubContextConfig(
            'IMSI001010000088888', '192.168.128.74', default_ambr_config, 4,
        )
        quota = 1024  # bytes

        monitored_rule = session_manager_pb2.StaticRuleInstall(
            rule_id="monitor_rule",
        )
        monitor = create_monitor_response(
            subscriber.imsi, "mkey1", quota,
            session_manager_pb2.PCC_RULE_LEVEL,
        )
        controller.mock_create_session = Mock(
            return_value=session_manager_pb2.CreateSessionResponse(
                credits=[],
                static_rules=[monitored_rule],
                dynamic_rules=[],
                usage_monitors=[monitor],
            ),
        )

        controller.mock_terminate_session = Mock(
            return_value=session_manager_pb2.SessionTerminateResponse(),
        )

        # Only the monitor queue is supplied to the update handler
        monitor_complete = hub.Queue()
        update_handler = get_standard_update_response(
            None, monitor_complete, quota,
        )
        controller.mock_update_session = Mock(side_effect=update_handler)

        create_request = session_manager_pb2.LocalCreateSessionRequest(
            sid=SubscriberID(id=subscriber.imsi),
            ue_ipv4=subscriber.ip,
        )
        util.sessiond.CreateSession(create_request)
        self.assertEqual(controller.mock_create_session.call_count, 1)

        packets = get_packets_for_flows(
            subscriber, util.static_rules["monitor_rule"].flow_list,
        )
        # One packet more than quota / packet length
        packet_count = int(quota / len(packets[0])) + 1
        sender = util.get_packet_sender([subscriber], packets, packet_count)
        util.thread.run_in_greenthread(sender)

        self.assertIsNotNone(get_from_queue(monitor_complete))
        self.assertEqual(controller.mock_update_session.call_count, 1)

        util.sessiond.EndSession(SubscriberID(id=subscriber.imsi))
        self.assertEqual(controller.mock_terminate_session.call_count, 1)
예제 #5
0
    def test_rules_with_failed_credit(self):
        """
        Verify that no data flows when the session is created but the credit
        responses are either a failure or a success carrying zero units.

        One packet is sent per dynamic-rule flow and the sender is expected
        to report a packet difference of 0.
        """
        util = self.test_util
        controller = util.controller
        subscriber = SubContextConfig(
            'IMSI001010000088888', '192.168.128.74', default_ambr_config, 4,
        )

        uplink_rules = [
            create_uplink_rule("rule2", 2, '46.10.0.1'),
            create_uplink_rule("rule3", 3, '47.10.0.1'),
        ]
        credit_responses = [
            # failed update
            create_update_response(subscriber.imsi, 1, 0, success=False),
            # successful update, no credit
            create_update_response(subscriber.imsi, 1, 0, success=True),
        ]
        # no credit for RG 1
        static_installs = [
            session_manager_pb2.StaticRuleInstall(rule_id="simple_match"),
        ]
        dynamic_installs = [
            session_manager_pb2.DynamicRuleInstall(policy_rule=rule)
            for rule in uplink_rules
        ]
        controller.mock_create_session = Mock(
            return_value=session_manager_pb2.CreateSessionResponse(
                credits=credit_responses,
                static_rules=static_installs,
                dynamic_rules=dynamic_installs,
            ),
        )

        controller.mock_terminate_session = Mock(
            return_value=session_manager_pb2.SessionTerminateResponse(),
        )

        create_request = session_manager_pb2.LocalCreateSessionRequest(
            sid=SubscriberID(id=subscriber.imsi),
            ue_ipv4=subscriber.ip,
        )
        util.sessiond.CreateSession(create_request)
        self.assertEqual(controller.mock_create_session.call_count, 1)

        # Only the first flow of each dynamic rule is exercised
        first_flows = [rule.flow_list[0] for rule in uplink_rules]
        packets = get_packets_for_flows(subscriber, first_flows)
        sender = util.get_packet_sender([subscriber], packets, 1)
        pkt_diff = util.thread.run_in_greenthread(sender)
        self.assertEqual(pkt_diff, 0)

        util.sessiond.EndSession(SubscriberID(id=subscriber.imsi))
        self.assertEqual(controller.mock_terminate_session.call_count, 1)
예제 #6
0
 def __init__(self):
     """
     Build the fixed SetSMSessionContext request kept in
     ``self._set_session``.

     Every field value is hard-coded; enum-typed fields are filled in via
     ``Name()`` lookups on numeric values.
     """
     common = CommonSessionContext(
         sid=SubscriberID(id="IMSI12345"),
         apn=bytes("BLR", 'utf-8'),
         rat_type=RATType.Name(2),
         sm_session_state=SMSessionFSMState.Name(0),
         sm_session_version=0,
     )
     m5g_context = M5GSMSessionContext(
         pdu_session_id=2,
         request_type=RequestType.Name(0),
         gnode_endpoint=TeidSet(
             teid=10001, end_ipv4_addr="192.168.60.141",
         ),
         # The PDU address is carried in a RedirectServer message
         pdu_address=RedirectServer(
             redirect_address_type=RedirectServer.IPV4,
             redirect_server_address="192.168.128.12",
         ),
         pdu_session_type=PduSessionType.Name(0),
         ssc_mode=SscMode.Name(2),
     )
     self._set_session = SetSMSessionContext(
         common_context=common,
         rat_specific_context=RatSpecificContext(
             m5gsm_session_context=m5g_context,
         ),
     )
예제 #7
0
    def generate_m5g_auth_vector(self, imsi: str, snni: bytes):
        """
        Produce the m5g auth vector for a subscriber from the crypto
        settings and keys held in the store.

        Args:
            imsi: subscriber IMSI used to look up the stored profile
            snni: passed through to ``Milenage.generate_m5gran_vector``
        Returns:
            whatever ``Milenage.generate_m5gran_vector`` returns
        Raises:
            ServiceNotActive: the LTE subscription state is not ACTIVE, or
                NT_5GC is among the forbidden network types
            CryptoError: the auth algorithm is not MILENAGE, the key is not
                16 bytes long, or the stored OPc is neither empty nor
                16 bytes long
        """
        sid = SIDUtils.to_str(SubscriberID(id=imsi, type=SubscriberID.IMSI))
        profile = self._store.get_subscriber_data(sid)
        lte = profile.lte

        if lte.state != LTESubscription.ACTIVE:
            raise ServiceNotActive("5G service not active for %s" % sid)

        forbidden = profile.sub_network.forbidden_network_types
        if CoreNetworkType.NT_5GC in forbidden:
            raise ServiceNotActive("5G services not allowed for %s" % sid)

        if lte.auth_algo != LTESubscription.MILENAGE:
            raise CryptoError(
                "Unknown crypto (%s) for %s" % (lte.auth_algo, sid),
            )

        if len(lte.auth_key) != 16:
            raise CryptoError("Subscriber key not valid for %s" % sid)

        # An empty stored OPc means it is derived from the key and self._op;
        # any other length than 16 is rejected.
        opc_len = len(lte.auth_opc)
        if opc_len not in (0, 16):
            raise CryptoError("Subscriber OPc is invalid length for %s" % sid)
        if opc_len == 0:
            opc = Milenage.generate_opc(lte.auth_key, self._op)
        else:
            opc = lte.auth_opc

        # Fetching the next sequence number also advances the stored counter
        next_seq = self.get_next_lte_auth_seq(imsi)
        sqn = self.seq_to_sqn(next_seq)
        return Milenage(self._amf).generate_m5gran_vector(
            lte.auth_key, opc, sqn, snni,
        )
예제 #8
0
def _ip_desc_to_proto(desc):
    """
    Build the protobuf form of an IP descriptor.

    The version of both the address and the block is taken from
    ``desc.ip_block.version``, and the subscriber id is always typed as IMSI.

    Args:
        desc (magma.mobilityd.IPDesc): IP descriptor
    Returns:
        proto (protos.keyval_pb2.IPDesc): protobuf of :desc:
    """
    block = desc.ip_block
    address_pb = IPAddress(
        version=_ip_version_int_to_proto(block.version),
        address=desc.ip.packed,
    )
    block_pb = IPBlock(
        version=_ip_version_int_to_proto(block.version),
        net_address=block.network_address.packed,
        prefix_len=block.prefixlen,
    )
    return IPDesc(
        ip=address_pb,
        ip_block=block_pb,
        state=_desc_state_str_to_proto(desc.state),
        sid=SubscriberID(id=desc.sid, type=SubscriberID.IMSI),
    )
예제 #9
0
 def __init__(self):
     """
     Build the fixed SetSMSessionContext request, including a subscribed
     QoS section, and keep it in ``self._set_session``.

     Every field value is hard-coded; enum-typed fields are filled in via
     ``Name()`` lookups on numeric values.
     """
     common = CommonSessionContext(
         sid=SubscriberID(id="IMSI12345"),
         ue_ipv4="192.168.128.11",
         apn=bytes("BLR", 'utf-8'),
         rat_type=RATType.Name(2),
         sm_session_state=SMSessionFSMState.Name(0),
         sm_session_version=0,
     )
     m5g_context = M5GSMSessionContext(
         pdu_session_id=1,
         request_type=RequestType.Name(0),
         gnode_endpoint=TeidSet(
             teid=10000, end_ipv4_addr="192.168.60.141",
         ),
         pdu_session_type=PduSessionType.Name(0),
         ssc_mode=SscMode.Name(2),
         subscribed_qos=M5GQosInformationRequest(
             apn_ambr_ul=750000,
             apn_ambr_dl=1000000,
             priority_level=1,
             preemption_capability=1,
             preemption_vulnerability=1,
             qos_class_id=9,
             br_unit=M5GQosInformationRequest.BitrateUnitsAMBR.Name(1),
         ),
     )
     self._set_session = SetSMSessionContext(
         common_context=common,
         rat_specific_context=RatSpecificContext(
             m5gsm_session_context=m5g_context,
         ),
     )
예제 #10
0
 def __init__(self):
     """
     Build the fixed SetSMSessionContext request for subscriber
     "IMSI987654" and keep it in ``self._set_session``.

     Every field value is hard-coded; enum-typed fields are filled in via
     ``Name()`` lookups on numeric values.
     """
     common = CommonSessionContext(
         sid=SubscriberID(id="IMSI987654"),
         ue_ipv4="192.168.128.111",
         apn=bytes("BLR", 'utf-8'),
         rat_type=RATType.Name(2),
         sm_session_state=SMSessionFSMState.Name(0),
         sm_session_version=0,
     )
     m5g_context = M5GSMSessionContext(
         pdu_session_id=2,
         request_type=RequestType.Name(0),
         gnode_endpoint=TeidSet(
             teid=300, end_ipv4_addr="192.168.60.141",
         ),
         pdu_session_type=PduSessionType.Name(0),
         ssc_mode=SscMode.Name(2),
     )
     self._set_session = SetSMSessionContext(
         common_context=common,
         rat_specific_context=RatSpecificContext(
             m5gsm_session_context=m5g_context,
         ),
     )
예제 #11
0
 def __init__(self):
     """
     Build the fixed SetSMSessionContext request, including a default
     AMBR, and keep it in ``self._set_session``.

     Every field value is hard-coded; enum-typed fields are filled in via
     ``Name()`` lookups on numeric values.
     """
     common = CommonSessionContext(
         sid=SubscriberID(id="IMSI12345"),
         ue_ipv4="192.168.128.11",
         apn=bytes("BLR", 'utf-8'),
         rat_type=RATType.Name(2),
         sm_session_state=SMSessionFSMState.Name(0),
         sm_session_version=0,
     )
     m5g_context = M5GSMSessionContext(
         pdu_session_id=1,
         request_type=RequestType.Name(0),
         gnode_endpoint=TeidSet(
             teid=10000, end_ipv4_addr="192.168.60.141",
         ),
         pdu_session_type=PduSessionType.Name(0),
         ssc_mode=SscMode.Name(2),
         default_ambr=AggregatedMaximumBitrate(
             max_bandwidth_ul=750000,
             max_bandwidth_dl=1000000,
         ),
     )
     self._set_session = SetSMSessionContext(
         common_context=common,
         rat_specific_context=RatSpecificContext(
             m5gsm_session_context=m5g_context,
         ),
     )
예제 #12
0
파일: processor.py 프로젝트: zhluo94/magma
    def generate_lte_auth_vector(self, imsi, plmn):
        """
        Produce the LTE auth vector for a subscriber from the crypto
        settings and keys held in the store.

        Args:
            imsi: subscriber IMSI used to look up the stored profile
            plmn: passed through to ``Milenage.generate_eutran_vector``
        Returns:
            whatever ``Milenage.generate_eutran_vector`` returns
        Raises:
            CryptoError: the LTE subscription is not ACTIVE, the auth
                algorithm is not MILENAGE, the key is not 16 bytes long, or
                the stored OPc is neither empty nor 16 bytes long
        """
        sid = SIDUtils.to_str(SubscriberID(id=imsi, type=SubscriberID.IMSI))
        lte = self._store.get_subscriber_data(sid).lte

        if lte.state != LTESubscription.ACTIVE:
            raise CryptoError("LTE service not active for %s" % sid)

        if lte.auth_algo != LTESubscription.MILENAGE:
            raise CryptoError(
                "Unknown crypto (%s) for %s" % (lte.auth_algo, sid),
            )

        if len(lte.auth_key) != 16:
            raise CryptoError("Subscriber key not valid for %s" % sid)

        opc_len = len(lte.auth_opc)
        if opc_len == 16:
            opc = lte.auth_opc
        elif opc_len != 0:
            raise CryptoError("Subscriber OPc is invalid length for %s" % sid)
        elif imsi == SRSUE_IMSI:
            # No stored OPc and the IMSI matches SRSUE_IMSI: use the fixed
            # hex-encoded SRSUE_OPC instead of deriving one
            opc = bytes.fromhex(SRSUE_OPC)
        else:
            # No stored OPc: derive it from the key and self._op
            opc = Milenage.generate_opc(lte.auth_key, self._op)

        # Fetching the next sequence number also advances the stored counter
        next_seq = self.get_next_lte_auth_seq(imsi)
        sqn = self.seq_to_sqn(next_seq)
        return Milenage(self._amf).generate_eutran_vector(
            lte.auth_key, opc, sqn, plmn,
        )
예제 #13
0
 def delete_subscriber(self, sid):
     """
     Forget ``sid`` locally and ask the subscriber stub to delete it.

     Args:
         sid: subscriber id string; its first four characters are dropped
             before it is sent (presumably the "IMSI" prefix - confirm
             against callers)
     """
     logging.info("Deleting subscriber : %s", sid)
     self._added_sids.discard(sid)

     request = SubscriberID(id=sid[4:])

     def _delete():
         return self._subscriber_stub.DeleteSubscriber(request)

     # The RPC is handed over as a callable to the shared call wrapper
     SubscriberDbGrpc._try_to_call(_delete)
예제 #14
0
 def GetSubscriberIDFromIP(self, ip_addr, context):
     """
     Look up the subscriber that holds the given IP address.

     Args:
         ip_addr: request message; ``ip_addr.address`` is parsed with
             ``ipaddress.ip_address``
         context: gRPC context, used to report NOT_FOUND
     Returns:
         the subscriber id converted with ``SIDUtils.to_pb``, or an empty
         SubscriberID (with NOT_FOUND set on ``context``) when the
         allocator knows no subscriber for the address
     """
     requested_ip = ipaddress.ip_address(ip_addr.address)
     owner = self._ipv4_allocator.get_sid_for_ip(requested_ip)
     if owner is not None:
         return SIDUtils.to_pb(owner)

     context.set_details('IP address %s not found' % str(requested_ip))
     context.set_code(grpc.StatusCode.NOT_FOUND)
     return SubscriberID()
예제 #15
0
파일: processor.py 프로젝트: zhluo94/magma
    def set_next_lte_auth_seq(self, imsi, seq):
        """
        Overwrite the subscriber's stored next LTE auth sequence number.

        Args:
            imsi: subscriber IMSI
            seq: value written to ``state.lte_auth_next_seq``
        """
        sid_pb = SubscriberID(id=imsi, type=SubscriberID.IMSI)
        store_key = SIDUtils.to_str(sid_pb)

        # The store's edit context manager hands out the record to modify
        with self._store.edit_subscriber(store_key) as subscriber:
            subscriber.state.lte_auth_next_seq = seq
예제 #16
0
def create_account_in_PCRF(imsi):
    """
    Clear all subscribers through a MockPCRFStub, then create one account
    for ``imsi``. Progress is printed to stdout.

    Args:
        imsi: id placed in the SubscriberID sent to ``CreateAccount``
    """
    channel = ServiceRegistry.get_rpc_channel('pcrf', ServiceRegistry.CLOUD)
    stub = MockPCRFStub(channel)

    # Start from a clean slate before adding the account
    print("Clearing accounts in PCRF")
    stub.ClearSubscribers(Void())

    print("Creating account in PCRF")
    account_id = SubscriberID(id=imsi)
    stub.CreateAccount(account_id)
예제 #17
0
def create_account_in_OCS(imsi):
    """
    Clear all subscribers through a MockOCSStub, then create one account
    for ``imsi``. Progress is printed to stdout.

    Args:
        imsi: id placed in the SubscriberID sent to ``CreateAccount``
    """
    channel = ServiceRegistry.get_rpc_channel('ocs', ServiceRegistry.CLOUD)
    stub = MockOCSStub(channel)

    # Start from a clean slate before adding the account
    print("Clearing accounts in OCS")
    stub.ClearSubscribers(Void())

    print("Creating account in OCS")
    account_id = SubscriberID(id=imsi)
    stub.CreateAccount(account_id)
예제 #18
0
파일: processor.py 프로젝트: zhluo94/magma
 def get_sub_data(self, imsi):
     """
     Fetch the stored profile of one subscriber.

     Args:
         imsi: IMSI string
     Returns:
         SubscriberData proto struct, as returned by the store
     """
     sid_pb = SubscriberID(id=imsi, type=SubscriberID.IMSI)
     store_key = SIDUtils.to_str(sid_pb)
     return self._store.get_subscriber_data(store_key)
예제 #19
0
    def test_activate_flows_req(self):
        """
        Call ActivateFlows on the pipelined servicer and check, argument by
        argument, what was passed to ``activate_rules`` on both
        ``self._enforcement_stats`` and ``self._enforcer_app``.
        """
        policy_rule = PolicyRule(id="rule1", priority=100, flow_list=[])
        policies = [VersionedPolicy(rule=policy_rule, version=1)]
        req = ActivateFlowsRequest(
            sid=SubscriberID(id="imsi12345"),
            ip_addr="1.2.3.4",
            msisdn=b'magma',
            uplink_tunnel=0x1,
            downlink_tunnel=0x2,
            policies=policies,
        )
        expected_ip = IPAddress(
            version=IPAddress.IPV4,
            address=req.ip_addr.encode('utf-8'),
        )

        self.pipelined_srv.ActivateFlows(req, MagicMock())

        # Positional arguments are compared one by one instead of using
        # assert_called_with, because of protos comparison. Both apps are
        # expected to receive the same arguments.
        for app in (self._enforcement_stats, self._enforcer_app):
            passed = app.activate_rules.call_args.args
            assert passed[0] == req.sid.id
            assert passed[1] == req.msisdn
            assert passed[2] == req.uplink_tunnel
            assert passed[3].version == expected_ip.version
            assert passed[3].address == expected_ip.address
            assert passed[4] == req.apn_ambr
            assert passed[5][0].version == policies[0].version
            assert passed[6] == req.shard_id
            assert passed[7] == 0
예제 #20
0
    def GetSubscriberIDFromIP(self, ip_addr, context):
        """
        Look up the subscriber that holds the given IP address.

        Args:
            ip_addr: request message; ``ip_addr.address`` is parsed with
                ``ipaddress.ip_address``
            context: gRPC context, used to report NOT_FOUND
        Returns:
            the subscriber id (only the part before the first '.')
            converted with ``SIDUtils.to_pb``, or an empty SubscriberID
            (with NOT_FOUND set on ``context``) when the allocator knows no
            subscriber for the address
        """
        requested_ip = ipaddress.ip_address(ip_addr.address)
        owner = self._ipv4_allocator.get_sid_for_ip(requested_ip)

        if owner is None:
            context.set_details('IP address %s not found' % str(requested_ip))
            context.set_code(grpc.StatusCode.NOT_FOUND)
            return SubscriberID()

        # Handle the composite key case: anything from the first '.' onwards
        # is discarded before conversion
        base_sid = owner.partition('.')[0]
        return SIDUtils.to_pb(base_sid)
예제 #21
0
def _build_add_subs_data(num_subs: int, input_file: str):
    """
    Write SubscriberData messages, converted to dicts and dumped as one
    compact JSON list, to ``input_file``.

    Each subscriber id is its index zero-padded to 15 digits, and each
    profile carries a single APN configuration selecting TEST_APN.

    NOTE(review): indices run over range(1, num_subs), which produces
    num_subs - 1 entries - confirm that this count is intended.

    Args:
        num_subs: exclusive upper bound of the generated indices
        input_file: path of the file to (over)write
    """
    def _as_request_dict(index):
        sid = SubscriberID(id=str(index).zfill(15))
        profile = Non3GPPUserProfile(
            apn_config=[APNConfiguration(service_selection=TEST_APN)],
        )
        subscriber = SubscriberData(sid=sid, non_3gpp=profile)
        return json_format.MessageToDict(subscriber)

    payload = [_as_request_dict(index) for index in range(1, num_subs)]

    # Separators without spaces keep the output as small as possible
    with open(input_file, 'w') as out:
        json.dump(payload, out, separators=(',', ':'))
예제 #22
0
 def get_all_subscribers(self):
     """
     Return six fixed SubscriberData entries, with ids "IMSI111" through
     "IMSI666", in that order.
     """
     imsis = (
         "IMSI111",
         "IMSI222",
         "IMSI333",
         "IMSI444",
         "IMSI555",
         "IMSI666",
     )
     return [SubscriberData(sid=SubscriberID(id=imsi)) for imsi in imsis]
예제 #23
0
    def test_str_conversion(self):
        """
        Exercise SIDUtils.to_str / SIDUtils.to_pb on valid ids and check
        that malformed strings are rejected with ValueError.
        """
        explicit_type = SubscriberID(id='12345', type=SubscriberID.IMSI)
        # By default the id type is IMSI, so this must behave the same
        default_type = SubscriberID(id='12345')

        for sid in (explicit_type, default_type):
            self.assertEqual(SIDUtils.to_str(sid), 'IMSI12345')
            self.assertEqual(SIDUtils.to_pb('IMSI12345'), sid)

        # Invalid strings must raise ValueError
        for malformed in ('IMS', 'IMSI12345a', ''):
            with self.assertRaises(ValueError):
                SIDUtils.to_pb(malformed)
예제 #24
0
    def test_add_three_subscribers(self):
        """
        Send one quota update per subscriber for three subscribers, compare
        the bridge against its snapshot, then send TERMINATE updates for the
        second and third subscribers.
        """
        imsi_1, mac_1 = 'IMSI010000000088888', '5e:cc:cc:b1:49:4b'
        imsi_2, mac_2 = 'IMSI010000111111118', '5e:a:cc:af:aa:fe'
        imsi_3, mac_3 = 'IMSI010002222222222', '5e:bb:cc:aa:aa:fe'

        def push_update(imsi, mac, state):
            # One controller call per subscriber, each with a single update
            self.check_quota_controller.update_subscriber_quota_state(
                SubscriberQuotaUpdate(
                    sid=SubscriberID(id=imsi),
                    mac_addr=mac,
                    update_type=state,
                ),
            )

        # Add the subscribers together with their UE MAC addresses
        push_update(imsi_1, mac_1, SubscriberQuotaUpdate.NO_QUOTA)
        push_update(imsi_2, mac_2, SubscriberQuotaUpdate.NO_QUOTA)
        push_update(imsi_3, mac_3, SubscriberQuotaUpdate.VALID_QUOTA)

        wait_after_send(self.testing_controller)

        assert_bridge_snapshot_match(self, self.BRIDGE, self.service_manager)

        # Terminate two of the three subscribers after the snapshot check
        push_update(imsi_2, mac_2, SubscriberQuotaUpdate.TERMINATE)
        push_update(imsi_3, mac_3, SubscriberQuotaUpdate.TERMINATE)
예제 #25
0
    def test_rule_with_no_credit(self):
        """
        Verify that no data passes when the create-session response installs
        a static rule but carries no credit at all.

        One packet is sent per flow of the rule and the sender is expected
        to report a packet difference of 0.
        """
        util = self.test_util
        controller = util.controller
        subscriber = SubContextConfig(
            'IMSI001010000088888', '192.168.128.74', default_ambr_config, 4,
        )

        # no credit for RG 1: the response lists the rule and nothing else
        rule_install = session_manager_pb2.StaticRuleInstall(
            rule_id="simple_match",
        )
        controller.mock_create_session = Mock(
            return_value=session_manager_pb2.CreateSessionResponse(
                static_rules=[rule_install],
            ),
        )
        controller.mock_terminate_session = Mock(
            return_value=session_manager_pb2.SessionTerminateResponse(),
        )

        create_request = session_manager_pb2.LocalCreateSessionRequest(
            sid=SubscriberID(id=subscriber.imsi),
            ue_ipv4=subscriber.ip,
        )
        util.sessiond.CreateSession(create_request)
        self.assertEqual(controller.mock_create_session.call_count, 1)

        packets = get_packets_for_flows(
            subscriber, util.static_rules["simple_match"].flow_list,
        )
        sender = util.get_packet_sender([subscriber], packets, 1)
        pkt_diff = util.thread.run_in_greenthread(sender)
        self.assertEqual(pkt_diff, 0)

        util.sessiond.EndSession(SubscriberID(id=subscriber.imsi))
        self.assertEqual(controller.mock_terminate_session.call_count, 1)
예제 #26
0
def _load_subs(num_subs: int) -> List[SubscriberID]:
    """
    Add subscribers through the local 'subscriberdb' service and return
    their ids.

    Each subscriber id is its index zero-padded to 15 digits, and each
    profile carries a single APN configuration selecting "magma.ipv4".

    NOTE(review): indices run over range(1, num_subs), which adds
    num_subs - 1 subscribers - confirm that this count is intended.

    Args:
        num_subs: exclusive upper bound of the generated indices
    Returns:
        the SubscriberIDs that were added, in index order
    """
    channel = ServiceRegistry.get_rpc_channel(
        'subscriberdb', ServiceRegistry.LOCAL,
    )
    subscriberdb = SubscriberDBStub(channel)

    loaded = []
    for index in range(1, num_subs):
        sid = SubscriberID(id=str(index).zfill(15))
        profile = Non3GPPUserProfile(
            apn_config=[APNConfiguration(service_selection="magma.ipv4")],
        )
        subscriberdb.AddSubscriber(SubscriberData(sid=sid, non_3gpp=profile))
        loaded.append(sid)
    return loaded
예제 #27
0
 def __init__(self):
     """
     Build the fixed SetSmNotificationContext request kept in
     ``self._set_session``.

     Every field value is hard-coded; enum-typed fields are filled in via
     ``Name()`` lookups on numeric values.
     """
     common = CommonSessionContext(
         sid=SubscriberID(id="IMSI12345"),
         apn=bytes("BLR", 'utf-8'),
         rat_type=RATType.Name(2),
         sm_session_state=SMSessionFSMState.Name(2),
         sm_session_version=6,
     )
     notification = RatSpecificNotification(
         pdu_session_id=1,
         request_type=RequestType.Name(1),
         notify_ue_event=5,
     )
     self._set_session = SetSmNotificationContext(
         common_context=common,
         rat_specific_notification=notification,
     )
예제 #28
0
    def get_next_lte_auth_seq(self, imsi):
        """
        Return the sequence number to use for the next auth operation and
        advance the stored counter by one.

        Args:
            imsi: subscriber IMSI
        Returns:
            the value of ``state.lte_auth_next_seq`` before the increment
        """
        store_key = SIDUtils.to_str(
            SubscriberID(id=imsi, type=SubscriberID.IMSI),
        )

        # The 3GPP TS 33.102 spec allows the counter to wrap around its
        # maximum value; when that happens the re-synchronization mechanism
        # is what brings USIM and HSS back in step.
        with self._store.edit_subscriber(store_key) as subscriber:
            state = subscriber.state
            current = state.lte_auth_next_seq
            state.lte_auth_next_seq = current + 1
        return current
예제 #29
0
    def resync_lte_auth_seq(self, imsi, rand, auts):
        """
        Validate a re-synchronization request and, if accepted, store the
        next sequence number derived from the AUTS sent by the U-SIM.

        Args:
            imsi: subscriber IMSI
            rand: passed to ``Milenage.generate_resync``
            auts: resync token; everything from index 6 onwards is compared
                with the computed MAC
        Raises:
            CryptoError: the LTE subscription is not ACTIVE, the auth
                algorithm is not MILENAGE, the key is not 16 bytes long, the
                stored OPc is neither empty nor 16 bytes long, the computed
                MAC does not match, or the U-SIM's sequence number is behind
                the network's by at most 2**28
        """
        sid = SIDUtils.to_str(SubscriberID(id=imsi, type=SubscriberID.IMSI))
        profile = self._store.get_subscriber_data(sid)
        lte = profile.lte

        if lte.state != LTESubscription.ACTIVE:
            raise CryptoError("LTE service not active for %s" % sid)

        if lte.auth_algo != LTESubscription.MILENAGE:
            raise CryptoError(
                "Unknown crypto (%s) for %s" % (lte.auth_algo, sid),
            )

        if len(lte.auth_key) != 16:
            raise CryptoError("Subscriber key not valid for %s" % sid)

        # An empty stored OPc means it is derived from the key and self._op;
        # any other length than 16 is rejected.
        opc_len = len(lte.auth_opc)
        if opc_len not in (0, 16):
            raise CryptoError("Subscriber OPc is invalid length for %s" % sid)
        if opc_len == 0:
            opc = Milenage.generate_opc(lte.auth_key, self._op)
        else:
            opc = lte.auth_opc

        # Re-synchronization uses a dummy all-zero AMF
        milenage = Milenage(b'\x00\x00')
        sqn_ms, mac_s = milenage.generate_resync(
            auts, lte.auth_key, opc, rand,
        )

        if mac_s != auts[6:]:
            raise CryptoError("Invalid resync authentication code")

        seq_ms = self.sqn_to_seq(sqn_ms)

        # current_seq_number is the sequence number the network sent to the
        # mobile station as part of the original auth request.
        current_seq_number = profile.state.lte_auth_next_seq - 1
        seq_delta = current_seq_number - seq_ms

        # The U-SIM being behind by a delta within range shouldn't have
        # happened; in every other case its value is adopted.
        if 0 < seq_delta <= (2**28):
            raise CryptoError(
                "Re-sync delta in range but UE rejected "
                "auth: %d" % seq_delta,
            )
        self.set_next_lte_auth_seq(imsi, seq_ms + 1)
예제 #30
0
def generate_subs(num_subs: int) -> List[SubscriberID]:
    """Return a list of num_subs many SubscriberIDs

    The ids are the indices 1..num_subs (inclusive), each zero-padded to
    15 digits. A num_subs of 0 or less yields an empty list.

    Args:
        num_subs (int): number of SubscriberIDs to generate

    Returns:
        List[SubscriberID]: Created list of SubscriberIDs
    """
    digit_num = 15
    # range() excludes its stop value, so the upper bound is num_subs + 1
    # to produce exactly num_subs ids starting at index 1.
    return [
        SubscriberID(id=str(index).zfill(digit_num))
        for index in range(1, num_subs + 1)
    ]