示例#1
0
def update_record(imsi: str, ip_addr: str) -> None:
    """
    Store ``ip_addr`` for ``imsi`` via the local directoryd 'UpdateRecord' RPC.

    An IMSI without the "IMSI" prefix gets it prepended first. A failure to
    obtain the RPC channel, or an RPC error, is logged rather than raised.
    """
    try:
        channel = ServiceRegistry.get_rpc_channel(
            DIRECTORYD_SERVICE_NAME,
            ServiceRegistry.LOCAL,
        )
    except ValueError:
        logging.error('Cant get RPC channel to %s', DIRECTORYD_SERVICE_NAME)
        return
    stub = GatewayDirectoryServiceStub(channel)
    subscriber_id = imsi if imsi.startswith("IMSI") else "IMSI" + imsi
    try:
        # "hwid" is a placeholder; directoryd fills in the real location.
        request = UpdateRecordRequest(id=subscriber_id, location="hwid")
        request.fields[IPV4_ADDR_KEY] = ip_addr
        stub.UpdateRecord(request, DEFAULT_GRPC_TIMEOUT)
    except grpc.RpcError as rpc_err:
        logging.error(
            "UpdateRecordRequest error for id: %s, ipv4_addr: %s! [%s] %s",
            subscriber_id,
            ip_addr,
            rpc_err.code(),
            rpc_err.details(),
        )
示例#2
0
def get_record(imsi: str, field: str) -> 'str | None':
    """
    Read one field of a subscriber's record from the local directoryd.

    Makes an RPC call to the 'GetDirectoryField' method of the local
    directoryd service.

    Args:
        imsi: Subscriber IMSI, with or without the "IMSI" prefix.
        field: Key of the directory field to read.

    Returns:
        The field value, or None when the RPC channel cannot be obtained or
        the RPC fails (both cases are logged, not raised).
    """
    try:
        chan = ServiceRegistry.get_rpc_channel(DIRECTORYD_SERVICE_NAME,
                                               ServiceRegistry.LOCAL)
    except ValueError:
        logging.error('Cant get RPC channel to %s', DIRECTORYD_SERVICE_NAME)
        return None
    client = GatewayDirectoryServiceStub(chan)
    if not imsi.startswith("IMSI"):
        imsi = "IMSI" + imsi
    req = GetDirectoryFieldRequest(id=imsi, field_key=field)
    try:
        res = client.GetDirectoryField(req, DEFAULT_GRPC_TIMEOUT)
    except grpc.RpcError as err:
        logging.error(
            "GetDirectoryFieldRequest error for id: %s! [%s] %s",
            imsi,
            err.code(),
            err.details())
        return None
    # A proto3 string field is never None (unset reads as ""), so no
    # None-check is needed before returning it.
    return res.value
示例#3
0
def get_all_records(retries: int = 3, sleep_time: float = 0.1) -> [dict]:
    """
    Fetch every record stored in the local directoryd service.

    Makes an RPC call to the 'GetAllDirectoryRecords' method, trying up to
    ``retries`` times and waiting ``sleep_time`` seconds between attempts.

    Returns:
        The ``records`` field of the first successful response, or an empty
        list if the RPC channel is unavailable or every attempt fails
        (errors are logged, not raised).
    """
    try:
        chan = ServiceRegistry.get_rpc_channel(
            DIRECTORYD_SERVICE_NAME,
            ServiceRegistry.LOCAL,
        )
    except ValueError:
        logging.error('Cant get RPC channel to %s', DIRECTORYD_SERVICE_NAME)
        # Empty list rather than None so callers can always iterate.
        return []
    client = GatewayDirectoryServiceStub(chan)
    for attempt in range(retries):
        try:
            res = client.GetAllDirectoryRecords(Void(), DEFAULT_GRPC_TIMEOUT)
        except grpc.RpcError as err:
            logging.error(
                "GetAllDirectoryRecords error! [%s] %s",
                err.code(),
                err.details(),
            )
        else:
            if res.records is not None:
                return res.records
        # Back off before the next attempt whether the RPC failed or
        # returned nothing; skip the pointless sleep after the last try.
        if attempt < retries - 1:
            hub.sleep(sleep_time)
    return []
示例#4
0
 def __init__(self):
     """
     Initialize sessionManager util by creating gRPC stubs for sessiond
     and directoryd.
     """
     sessiond_channel = get_rpc_channel("sessiond")
     self._session_stub = SessionProxyResponderStub(sessiond_channel)
     directoryd_channel = get_rpc_channel("directoryd")
     self._directorydstub = GatewayDirectoryServiceStub(directoryd_channel)
示例#5
0
def _cleanup_subs():
    """Delete every record currently held by the local directoryd."""
    channel = ServiceRegistry.get_rpc_channel(
        DIRECTORYD_SERVICE_NAME,
        ServiceRegistry.LOCAL,
    )
    client = GatewayDirectoryServiceStub(channel)
    all_records = client.GetAllDirectoryRecords(Void()).records
    for rec in all_records:
        client.DeleteRecord(DeleteRecordRequest(id=rec.id))
示例#6
0
    def setUp(self):
        # Start an in-process gRPC server on a port chosen by the OS ("...:0").
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10))
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')

        # Register the directoryd servicer under test, then start serving.
        self._servicer = GatewayDirectoryServiceRpcServicer()
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()

        # Client stub connected to the port bound above.
        address = '0.0.0.0:{}'.format(port)
        self._stub = GatewayDirectoryServiceStub(grpc.insecure_channel(address))
示例#7
0
File: main.py  Project: sdechi/magma
def main():
    """
    main() for smsd.

    Wires an SmsRelay to directoryd and the MME SMS service (local) and to
    smsd in the cloud, registers it on the service's RPC server, then runs
    the service loop until it exits.
    """
    service = MagmaService('smsd', None)

    # Optionally pipe errors to Sentry
    sentry_init(service_name=service.name)

    directoryd_chan = ServiceRegistry.get_rpc_channel(
        'directoryd',
        ServiceRegistry.LOCAL,
    )
    mme_chan = ServiceRegistry.get_rpc_channel(
        'sms_mme_service',
        ServiceRegistry.LOCAL,
    )
    smsd_chan = ServiceRegistry.get_rpc_channel('smsd', ServiceRegistry.CLOUD)

    # Add all servicers to the server
    smsd_relay = SmsRelay(
        service.loop,
        GatewayDirectoryServiceStub(directoryd_chan),
        SMSOrc8rGatewayServiceStub(mme_chan),
        SmsDStub(smsd_chan),
    )
    smsd_relay.add_to_server(service.rpc_server)
    smsd_relay.start()

    # Run the service loop; clean up the service even if the loop raises.
    try:
        service.run()
    finally:
        service.close()
示例#8
0
    def setUp(self):
        # Start an in-process gRPC server on a port chosen by the OS ("...:0").
        self._rpc_server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10))
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')

        # While the servicer is built, make get_default_client hand back a
        # single fakeredis instance instead of a real Redis client.
        fake_client = fakeredis.FakeStrictRedis()
        client_factory = mock.MagicMock(return_value=fake_client)
        patcher = mock.patch(
            'magma.directoryd.rpc_servicer.get_default_client',
            client_factory,
        )
        with patcher:
            self._servicer = GatewayDirectoryServiceRpcServicer(False)
            self._servicer.add_to_server(self._rpc_server)
            self._rpc_server.start()

        # Client stub connected to the port bound above.
        address = '0.0.0.0:{}'.format(port)
        self._stub = GatewayDirectoryServiceStub(grpc.insecure_channel(address))
示例#9
0
def _load_subs(num_subs: int) -> List[UpdateRecordRequest]:
    """
    Load ``num_subs`` synthetic records into the local directoryd.

    Record ``i`` uses ``str(i).zfill(15)`` as both id and location.
    - "mac-addr" is six groups of ``str(i) * 2`` joined by ":".
    - "ipv4_addr" is four groups of ``str(i) * 3`` joined by ".".

    These are placeholder strings only: for larger ``i`` they are not valid
    addresses (e.g. i=3 gives "333.333.333.333").

    Returns:
        The UpdateRecordRequest messages that were sent, in order.
    """
    client = GatewayDirectoryServiceStub(
        ServiceRegistry.get_rpc_channel(
            DIRECTORYD_SERVICE_NAME,
            ServiceRegistry.LOCAL,
        ),
    )
    sids = []
    for i in range(num_subs):
        digits = str(i)
        fields = {
            "mac-addr": ":".join([digits * 2] * 6),
            "ipv4_addr": ".".join([digits * 3] * 4),
        }
        sid = UpdateRecordRequest(
            fields=fields,
            id=digits.zfill(15),
            location=digits.zfill(15),
        )
        client.UpdateRecord(sid)
        sids.append(sid)
    return sids
示例#10
0
class SessionManagerUtil(object):
    """
    Helper class to communicate with session manager for the tests.
    """
    def __init__(self):
        """
        Initialize sessionManager util.

        Creates gRPC stubs for the "sessiond", "abort_session_service" and
        "directoryd" channels.
        """
        self._session_stub = SessionProxyResponderStub(
            get_rpc_channel("sessiond"))
        self._abort_session_stub = AbortSessionResponderStub(
            get_rpc_channel("abort_session_service"))
        self._directorydstub = GatewayDirectoryServiceStub(
            get_rpc_channel("directoryd"))

    @staticmethod
    def _flow_ip_address(flow, ipv4_key, ipv6_key):
        """Build an IPAddress from flow[ipv4_key] or flow[ipv6_key], else None."""
        if flow.get(ipv4_key, None):
            return IPAddress(
                version=IPAddress.IPV4,
                address=flow.get(ipv4_key).encode('utf-8'))
        if flow.get(ipv6_key, None):
            return IPAddress(
                version=IPAddress.IPV6,
                address=flow.get(ipv6_key).encode('utf-8'))
        return None

    def get_flow_match(self, flow_list, flow_match_list):
        """
        Populates flow match list.

        Converts each dict in ``flow_list`` into a PERMIT FlowDescription and
        appends it to ``flow_match_list`` (mutated in place).

        Args:
            flow_list: Dicts that must contain "direction" and "ip_proto".
                They may contain "tcp_src_port"/"tcp_dst_port" (used for
                TCP), "udp_src_port"/"udp_dst_port" (used for UDP),
                "ipv4_src"/"ipv6_src" and "ipv4_dst"/"ipv6_dst".
            flow_match_list: List receiving the FlowDescription objects.
        """
        for flow in flow_list:
            flow_direction = flow["direction"]
            ip_protocol = flow["ip_proto"]

            # Ports that don't belong to the flow's protocol stay 0.
            tcp_src_port = tcp_dst_port = 0
            udp_src_port = udp_dst_port = 0
            if ip_protocol == FlowMatch.IPPROTO_TCP:
                tcp_src_port = (int(flow["tcp_src_port"])
                                if "tcp_src_port" in flow else 0)
                tcp_dst_port = (int(flow["tcp_dst_port"])
                                if "tcp_dst_port" in flow else 0)
            elif ip_protocol == FlowMatch.IPPROTO_UDP:
                udp_src_port = (int(flow["udp_src_port"])
                                if "udp_src_port" in flow else 0)
                udp_dst_port = (int(flow["udp_dst_port"])
                                if "udp_dst_port" in flow else 0)

            src_addr = self._flow_ip_address(flow, "ipv4_src", "ipv6_src")
            dst_addr = self._flow_ip_address(flow, "ipv4_dst", "ipv6_dst")

            flow_match_list.append(
                FlowDescription(
                    match=FlowMatch(
                        ip_dst=dst_addr,
                        ip_src=src_addr,
                        tcp_src=tcp_src_port,
                        tcp_dst=tcp_dst_port,
                        udp_src=udp_src_port,
                        udp_dst=udp_dst_port,
                        ip_proto=ip_protocol,
                        direction=flow_direction,
                    ),
                    action=FlowDescription.PERMIT,
                ))

    def _get_session_id(self, imsi):
        """
        Return the "session_id" field directoryd stores for ``imsi``.

        Raises:
            grpc.RpcError: if the directoryd lookup fails (logged first).
        """
        req = GetDirectoryFieldRequest(id=imsi, field_key="session_id")
        try:
            res = self._directorydstub.GetDirectoryField(
                req, DEFAULT_GRPC_TIMEOUT)
        except grpc.RpcError as err:
            logging.error(
                "GetDirectoryFieldRequest error for id: %s! [%s] %s",
                imsi,
                err.code(),
                err.details(),
            )
            # Without a session id the request cannot be built; surface the
            # real RPC failure instead of failing later on a missing result.
            raise
        return res.value

    def create_ReAuthRequest(self, imsi, policy_id, flow_list, qos):
        """
        Sends Policy RAR message to session manager.

        Installs a single dynamic PolicyRule (id ``policy_id``) built from
        ``flow_list`` and the ``qos`` dict. ``qos`` must provide the keys
        "qci", "max_req_bw_ul", "max_req_bw_dl", "gbr_ul", "gbr_dl",
        "arp_prio", "pre_cap", "pre_vul" and "priority".

        Raises:
            grpc.RpcError: if the session id lookup in directoryd fails.
        """
        print("Sending Policy RAR message to session manager")
        flow_match_list = []
        self.get_flow_match(flow_list, flow_match_list)

        policy_qos = FlowQos(
            qci=qos["qci"],
            max_req_bw_ul=qos["max_req_bw_ul"],
            max_req_bw_dl=qos["max_req_bw_dl"],
            gbr_ul=qos["gbr_ul"],
            gbr_dl=qos["gbr_dl"],
            arp=QosArp(
                priority_level=qos["arp_prio"],
                pre_capability=qos["pre_cap"],
                pre_vulnerability=qos["pre_vul"],
            ),
        )

        policy_rule = PolicyRule(
            id=policy_id,
            priority=qos["priority"],
            flow_list=flow_match_list,
            tracking_type=PolicyRule.NO_TRACKING,
            rating_group=1,
            monitoring_key=None,
            qos=policy_qos,
        )

        qos_info = QoSInformation(qci=qos["qci"])

        session_id = self._get_session_id(imsi)

        self._session_stub.PolicyReAuth(
            PolicyReAuthRequest(
                session_id=session_id,
                imsi=imsi,
                rules_to_remove=[],
                rules_to_install=[],
                dynamic_rules_to_install=[
                    DynamicRuleInstall(policy_rule=policy_rule)
                ],
                event_triggers=[],
                revalidation_time=None,
                usage_monitoring_credits=[],
                qos_info=qos_info,
            ))

    def create_AbortSessionRequest(self, imsi: str) -> AbortSessionResult:
        """
        Send an AbortSession request for ``imsi``'s session.

        Returns:
            The AbortSessionResult returned by the abort session service.

        Raises:
            grpc.RpcError: if the session id lookup in directoryd fails.
        """
        session_id = self._get_session_id(imsi)
        return self._abort_session_stub.AbortSession(
            AbortSessionRequest(
                session_id=session_id,
                user_name=imsi,
            ))
示例#11
0
class DirectorydRpcServiceTests(TestCase):
    """Tests for GatewayDirectoryServiceRpcServicer over a real local gRPC server."""

    @mock.patch("redis.Redis", MockRedis)
    def setUp(self):
        # Bind the rpc server to a free port
        thread_pool = futures.ThreadPoolExecutor(max_workers=10)
        self._rpc_server = grpc.server(thread_pool)
        port = self._rpc_server.add_insecure_port('0.0.0.0:0')

        # Add the servicer
        self._servicer = GatewayDirectoryServiceRpcServicer()
        self._servicer.add_to_server(self._rpc_server)
        self._rpc_server.start()

        # Create a rpc stub
        channel = grpc.insecure_channel('0.0.0.0:{}'.format(port))
        self._stub = GatewayDirectoryServiceStub(channel)

    @mock.patch("redis.Redis", MockRedis)
    def tearDown(self):
        self._rpc_server.stop(None)

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    def test_update_record(self):
        self._servicer._redis_dict.clear()

        req = UpdateRecordRequest()
        req.id = "IMSI555"
        self._stub.UpdateRecord(req)
        actual_record = self._servicer._redis_dict[req.id]
        self.assertEqual(actual_record.location_history, ['aaa-bbb'])
        self.assertEqual(actual_record.identifiers, {})

        req.fields["mac_addr"] = "aa:aa:bb:bb:cc:cc"
        req.fields["ipv4_addr"] = "192.168.172.12"

        self._stub.UpdateRecord(req)
        actual_record2 = self._servicer._redis_dict[req.id]
        self.assertEqual(actual_record2.location_history, ["aaa-bbb"])
        self.assertEqual(actual_record2.identifiers['mac_addr'],
                         "aa:aa:bb:bb:cc:cc")
        self.assertEqual(actual_record2.identifiers['ipv4_addr'],
                         "192.168.172.12")

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    def test_update_record_bad_location(self):
        # The location supplied by the client is ignored; the servicer
        # records its own (mocked) location instead.
        self._servicer._redis_dict.clear()

        req = UpdateRecordRequest()
        req.id = "IMSI556"
        req.location = "bbb-ccc"

        self._stub.UpdateRecord(req)
        actual_record = self._servicer._redis_dict[req.id]
        self.assertEqual(actual_record.location_history, ['aaa-bbb'])
        self.assertEqual(actual_record.identifiers, {})

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    def test_delete_record(self):
        self._servicer._redis_dict.clear()

        req = UpdateRecordRequest()
        req.id = "IMSI557"
        self._stub.UpdateRecord(req)
        self.assertIn(req.id, self._servicer._redis_dict)

        del_req = DeleteRecordRequest()
        del_req.id = "IMSI557"
        self._stub.DeleteRecord(del_req)
        self.assertNotIn(req.id, self._servicer._redis_dict)

        # Deleting an already-deleted record must report NOT_FOUND.
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.DeleteRecord(del_req)
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    def test_get_field(self):
        self._servicer._redis_dict.clear()

        req = UpdateRecordRequest()
        req.id = "IMSI557"
        req.fields["mac_addr"] = "aa:bb:aa:bb:aa:bb"
        self._stub.UpdateRecord(req)
        self.assertIn(req.id, self._servicer._redis_dict)

        get_req = GetDirectoryFieldRequest()
        get_req.id = "IMSI557"
        get_req.field_key = "mac_addr"
        ret = self._stub.GetDirectoryField(get_req)
        self.assertEqual("aa:bb:aa:bb:aa:bb", ret.value)

        # A field that was never set must report NOT_FOUND; only the RPC
        # itself belongs inside the assertRaises block.
        get_req.field_key = "ipv4_addr"
        with self.assertRaises(grpc.RpcError) as err:
            self._stub.GetDirectoryField(get_req)
        self.assertEqual(err.exception.code(), grpc.StatusCode.NOT_FOUND)

    @mock.patch("redis.Redis", MockRedis)
    @mock.patch('snowflake.snowflake', get_mock_snowflake)
    def test_get_all(self):
        self._servicer._redis_dict.clear()

        req = UpdateRecordRequest()
        req.id = "IMSI557"
        req.fields["mac_addr"] = "aa:bb:aa:bb:aa:bb"
        self._stub.UpdateRecord(req)
        self.assertIn(req.id, self._servicer._redis_dict)

        req2 = UpdateRecordRequest()
        req2.id = "IMSI556"
        req2.fields["ipv4_addr"] = "192.168.127.11"
        self._stub.UpdateRecord(req2)
        self.assertIn(req2.id, self._servicer._redis_dict)

        void_req = Void()
        ret = self._stub.GetAllDirectoryRecords(void_req)
        self.assertEqual(2, len(ret.records))
        # Index by id so a duplicated or missing record fails the test.
        records_by_id = {record.id: record for record in ret.records}
        self.assertEqual({"IMSI556", "IMSI557"}, set(records_by_id))
        self.assertEqual(records_by_id["IMSI556"].fields["ipv4_addr"],
                         "192.168.127.11")
        self.assertEqual(records_by_id["IMSI557"].fields["mac_addr"],
                         "aa:bb:aa:bb:aa:bb")