Ejemplo n.º 1
0
    async def broadcaster_queue_loop(self):
        """Answer UDP search datagrams taken from ``command_bundle_queue``.

        Loops until the task is cancelled.  Each queue item is an
        ``(addr, commands)`` pair.  The commands are first passed to
        ``self.broadcaster.process_commands``; failures there are logged and
        the datagram is dropped.  Datagrams whose ``addr`` is in
        ``self.ignore_addresses`` are processed but never answered.

        A ``SearchResponse`` is produced only for names present in
        ``self.pvdb``.  Unknown names are skipped individually, so they
        neither produce a false claim of ownership nor suppress the replies
        for known names bundled in the same datagram.  If the datagram
        contained search requests and none of them matched, nothing is sent.
        """
        responses = []

        while True:
            try:
                addr, commands = await self.command_bundle_queue.get()
                self.broadcaster.process_commands(commands)
            except curio.TaskCancelled:
                break
            except Exception as ex:
                logger.error('Broadcaster command queue evaluation failed',
                             exc_info=ex)
                continue

            if addr in self.ignore_addresses:
                continue

            responses.clear()
            search_requested = False
            search_answered = False
            for command in commands:
                if isinstance(command, ca.VersionRequest):
                    responses.append(ca.VersionResponse(13))
                elif isinstance(command, ca.SearchRequest):
                    search_requested = True
                    pv_name = command.name.decode(STR_ENC)
                    if pv_name not in self.pvdb:
                        # Never advertise a PV this server does not host,
                        # whatever reply flag the client set.
                        continue

                    search_answered = True
                    # responding with an IP of `None` tells client to get IP
                    # address from packet
                    responses.append(ca.SearchResponse(self.port, None,
                                                       command.cid, 13))

            if search_requested and not search_answered:
                # Do not send any response to this datagram: every name it
                # asked about is unknown here.
                continue

            if responses:
                bytes_to_send = self.broadcaster.send(*responses)
                await self.udp_sock.sendto(bytes_to_send, addr)
Ejemplo n.º 2
0
    async def _broadcaster_evaluate(self, addr, commands):
        """Reply to the search requests of one datagram for PVs found here.

        ``commands`` is scanned for ``VersionRequest`` and ``SearchRequest``
        items.  Every searched name for which ``self[name]`` yields a
        non-``None`` entry gets a ``SearchResponse``; names that raise
        ``KeyError`` are ignored.  When at least one reply exists, the
        serialized bundle (prefixed with a ``VersionResponse`` if a version
        was requested) is sent to ``addr`` through every socket in
        ``self.udp_socks``.

        Raises
        ------
        CaprotoNetworkError
            If sending on one of the sockets fails with ``OSError``.
        """
        wants_version = False
        found = []
        for cmd in commands:
            if isinstance(cmd, ca.VersionRequest):
                wants_version = True
                continue
            if not isinstance(cmd, ca.SearchRequest):
                continue

            try:
                hosted = self[cmd.name] is not None
            except KeyError:
                hosted = False
            if not hosted:
                continue

            # An IP of `None` tells the client to take the address from the
            # datagram it receives.
            found.append(
                ca.SearchResponse(self.port, None, cmd.cid,
                                  ca.DEFAULT_PROTOCOL_VERSION)
            )

        if not found:
            return

        if wants_version:
            outgoing = [ca.VersionResponse(13), *found]
        else:
            outgoing = found
        bytes_to_send = self.broadcaster.send(*outgoing)

        for sock in self.udp_socks.values():
            try:
                await sock.sendto(bytes_to_send, addr)
            except OSError as exc:
                host, port = addr
                raise CaprotoNetworkError(f"Failed to send to {host}:{port}") from exc
Ejemplo n.º 3
0
def circuit_pair(request):
    """Return a ``(client, server)`` pair of ``ca.VirtualCircuit`` objects.

    The two circuits share the address ``('127.0.0.1', 5555)`` and have
    already exchanged a ``VersionRequest`` / ``VersionResponse`` (version 13,
    client priority 1) entirely in memory.  ``request`` is not used by the
    body.
    """
    address = ('127.0.0.1', 5555)
    priority, version = 1, 13

    def deliver(buffers, receiver):
        # Parse the serialized buffers on the receiving circuit and feed
        # every resulting command through its state machine.
        received, _ = receiver.recv(*buffers)
        for cmd in received:
            receiver.process_command(cmd)

    client = ca.VirtualCircuit(ca.CLIENT, address, priority)
    outbound = client.send(
        ca.VersionRequest(version=version, priority=priority))

    server = ca.VirtualCircuit(ca.SERVER, address, None)
    deliver(outbound, server)

    inbound = server.send(ca.VersionResponse(version=version))
    deliver(inbound, client)
    return client, server
Ejemplo n.º 4
0
    async def _broadcaster_evaluate(self, addr, commands):
        """Answer the search requests contained in one datagram.

        For each ``SearchRequest`` in ``commands``: a name for which
        ``self[name]`` yields a non-``None`` entry gets a ``SearchResponse``;
        an unknown name gets a ``NotFoundResponse`` only when the request's
        ``reply`` field equals ``ca.DO_REPLY``, and is otherwise ignored.
        If any reply was collected, the bundle (prefixed with a
        ``VersionResponse`` when a ``VersionRequest`` was seen) is sent to
        ``addr`` on every socket in ``self.udp_socks``.
        """
        wants_version = False
        answers = []
        for cmd in commands:
            if isinstance(cmd, ca.VersionRequest):
                wants_version = True
            if not isinstance(cmd, ca.SearchRequest):
                continue

            try:
                hosted = self[cmd.name] is not None
            except KeyError:
                hosted = False

            if hosted:
                # An IP of `None` tells the client to take the address from
                # the datagram it receives.
                answers.append(
                    ca.SearchResponse(self.port, None, cmd.cid,
                                      ca.DEFAULT_PROTOCOL_VERSION))
            elif cmd.reply == ca.DO_REPLY:
                answers.append(
                    ca.NotFoundResponse(
                        version=ca.DEFAULT_PROTOCOL_VERSION,
                        cid=cmd.cid))
            # Otherwise: not a known PV and no reply required.

        if not answers:
            return

        if wants_version:
            outgoing = [ca.VersionResponse(13), *answers]
        else:
            outgoing = answers
        bytes_to_send = self.broadcaster.send(*outgoing)

        for sock in self.udp_socks.values():
            await sock.sendto(bytes_to_send, addr)
Ejemplo n.º 5
0
def test_broadcaster_checks():
    """Check which command bundles a client ``Broadcaster`` accepts.

    A ``SearchRequest`` sent on its own must raise ``LocalProtocolError``
    and a ``SearchResponse`` received on its own must raise
    ``RemoteProtocolError``; each is accepted once bundled with the matching
    version command.
    """
    bcast = ca.Broadcaster(ca.CLIENT)

    # A lone SearchRequest is refused.
    with pytest.raises(ca.LocalProtocolError):
        bcast.send(ca.SearchRequest(name='LIRR', cid=0, version=13))

    # Repeater registration round trip.
    bcast.send(ca.RepeaterRegisterRequest('1.2.3.4'))
    confirm = ca.RepeaterConfirmResponse('5.6.7.8')
    received = bcast.recv(bytes(confirm), ('5.6.7.8', 6666))
    assert received[0] == confirm
    bcast.process_commands(received)

    # Still refused alone after registering; accepted with a VersionRequest.
    search = ca.SearchRequest(name='LIRR', cid=0, version=13)
    with pytest.raises(ca.LocalProtocolError):
        bcast.send(search)
    bcast.send(ca.VersionRequest(priority=0, version=13), search)

    found = ca.SearchResponse(port=6666, ip='1.2.3.4', cid=0, version=13)
    server_addr = ('1.2.3.4', 6666)

    # A SearchResponse without a preceding VersionResponse is rejected...
    lone = bcast.recv(bytes(found), server_addr)
    with pytest.raises(ca.RemoteProtocolError):
        bcast.process_commands(lone)

    # ...while the bundled pair is processed without error.
    bundled = bcast.recv(
        bytes(ca.VersionResponse(version=13)) + bytes(found), server_addr)
    bcast.process_commands(bundled)  # this gets both
Ejemplo n.º 6
0
def test_nonet():
    """Drive a full client/server exchange in memory, without sockets.

    The test walks through repeater registration, a PV search, circuit and
    channel creation, a subscription add/cancel, a read, a write and a
    channel clear, asserting the channel states after each step.

    It relies on names defined outside this function: the broadcasters
    ``cli_b`` / ``srv_b``, ``cli_addr``, ``pv1`` and the helpers
    ``cli_send`` / ``cli_recv`` / ``srv_send`` / ``srv_recv``.
    NOTE(review): those helpers presumably shuttle bytes between the two
    circuits -- confirm against their definitions.
    """
    # Register with the repeater.
    assert not cli_b._registered
    bytes_to_send = cli_b.send(ca.RepeaterRegisterRequest('0.0.0.0'))
    # Sending the request alone must not flip the flag.
    assert not cli_b._registered

    # Receive response
    data = bytes(ca.RepeaterConfirmResponse('127.0.0.1'))
    commands = cli_b.recv(data, cli_addr)
    cli_b.process_commands(commands)
    assert cli_b._registered

    # Search for pv1.
    # CA requires us to send a VersionRequest and a SearchRequest bundled into
    # one datagram.
    bytes_to_send = cli_b.send(ca.VersionRequest(0, ca.DEFAULT_PROTOCOL_VERSION),
                               ca.SearchRequest(pv1, 0,
                                                ca.DEFAULT_PROTOCOL_VERSION))

    # The server-side broadcaster parses the client's datagram and answers it.
    commands = srv_b.recv(bytes_to_send, cli_addr)
    srv_b.process_commands(commands)
    ver_req, search_req = commands
    bytes_to_send = srv_b.send(
        ca.VersionResponse(ca.DEFAULT_PROTOCOL_VERSION),
        ca.SearchResponse(5064, None, search_req.cid, ca.DEFAULT_PROTOCOL_VERSION))

    # Receive a VersionResponse and SearchResponse.
    commands = iter(cli_b.recv(bytes_to_send, cli_addr))
    command = next(commands)
    assert type(command) is ca.VersionResponse
    command = next(commands)
    assert type(command) is ca.SearchResponse
    address = ca.extract_address(command)

    # Build the client circuit and a channel for pv1 on it.
    circuit = ca.VirtualCircuit(our_role=ca.CLIENT,
                                address=address,
                                priority=0)
    circuit.log.setLevel('DEBUG')
    chan1 = ca.ClientChannel(pv1, circuit)
    assert chan1.states[ca.CLIENT] is ca.SEND_CREATE_CHAN_REQUEST
    assert chan1.states[ca.SERVER] is ca.IDLE

    srv_circuit = ca.VirtualCircuit(our_role=ca.SERVER,
                                    address=address, priority=None)

    # Version exchange on the circuit.
    cli_send(chan1.circuit, ca.VersionRequest(priority=0,
                                              version=ca.DEFAULT_PROTOCOL_VERSION))

    srv_recv(srv_circuit)

    srv_send(srv_circuit, ca.VersionResponse(version=ca.DEFAULT_PROTOCOL_VERSION))
    cli_recv(chan1.circuit)
    # Identify the client, then request the channel.
    cli_send(chan1.circuit, ca.HostNameRequest('localhost'))
    cli_send(chan1.circuit, ca.ClientNameRequest('username'))
    cli_send(chan1.circuit, ca.CreateChanRequest(name=pv1, cid=chan1.cid,
                                                 version=ca.DEFAULT_PROTOCOL_VERSION))
    assert chan1.states[ca.CLIENT] is ca.AWAIT_CREATE_CHAN_RESPONSE
    assert chan1.states[ca.SERVER] is ca.SEND_CREATE_CHAN_RESPONSE

    srv_recv(srv_circuit)
    # The client's view is unchanged; the server now has a matching channel.
    assert chan1.states[ca.CLIENT] is ca.AWAIT_CREATE_CHAN_RESPONSE
    assert chan1.states[ca.SERVER] is ca.SEND_CREATE_CHAN_RESPONSE
    srv_chan1, = srv_circuit.channels.values()
    assert srv_chan1.states[ca.CLIENT] is ca.AWAIT_CREATE_CHAN_RESPONSE
    assert srv_chan1.states[ca.SERVER] is ca.SEND_CREATE_CHAN_RESPONSE

    srv_send(srv_circuit, ca.CreateChanResponse(cid=chan1.cid, sid=1,
                                                data_type=5, data_count=1))
    assert srv_chan1.states[ca.CLIENT] is ca.CONNECTED
    assert srv_chan1.states[ca.SERVER] is ca.CONNECTED

    # At this point the CLIENT is not aware that we are CONNECTED because it
    # has not yet received the CreateChanResponse. It should not be allowed to
    # read or write.
    assert chan1.states[ca.CLIENT] is ca.AWAIT_CREATE_CHAN_RESPONSE
    assert chan1.states[ca.SERVER] is ca.SEND_CREATE_CHAN_RESPONSE

    # Try sending a premature read request.
    read_req = ca.ReadNotifyRequest(sid=srv_chan1.sid,
                                    data_type=srv_chan1.native_data_type,
                                    data_count=srv_chan1.native_data_count,
                                    ioid=0)
    with pytest.raises(ca.LocalProtocolError):
        cli_send(chan1.circuit, read_req)

    # The above failed because the sid is not recognized. Remove that failure
    # by editing the sid cache, and check that it *still* fails, this time
    # because of the state machine prohibiting this command before the channel
    # is in a CONNECTED state.
    chan1.circuit.channels_sid[1] = chan1
    with pytest.raises(ca.LocalProtocolError):
        cli_send(chan1.circuit, read_req)

    # Once the client receives the CreateChanResponse it is CONNECTED too.
    cli_recv(chan1.circuit)
    assert chan1.states[ca.CLIENT] is ca.CONNECTED
    assert chan1.states[ca.SERVER] is ca.CONNECTED

    # Test subscriptions.
    assert chan1.native_data_type and chan1.native_data_count
    add_req = ca.EventAddRequest(data_type=chan1.native_data_type,
                                 data_count=chan1.native_data_count,
                                 sid=chan1.sid,
                                 subscriptionid=0,
                                 low=0, high=0, to=0, mask=1)
    cli_send(chan1.circuit, add_req)
    srv_recv(srv_circuit)
    add_res = ca.EventAddResponse(data=(3,),
                                  data_type=chan1.native_data_type,
                                  data_count=chan1.native_data_count,
                                  subscriptionid=0,
                                  status=1)

    srv_send(srv_circuit, add_res)
    cli_recv(chan1.circuit)

    # Cancel the subscription that was just added.
    cancel_req = ca.EventCancelRequest(data_type=add_req.data_type,
                                       sid=add_req.sid,
                                       subscriptionid=add_req.subscriptionid)

    cli_send(chan1.circuit, cancel_req)
    srv_recv(srv_circuit)

    # Test reading.
    cli_send(chan1.circuit, ca.ReadNotifyRequest(data_type=5, data_count=1,
                                                 sid=chan1.sid,
                                                 ioid=12))
    srv_recv(srv_circuit)
    srv_send(srv_circuit, ca.ReadNotifyResponse(data=(3,),
                                                data_type=5, data_count=1,
                                                ioid=12, status=1))
    cli_recv(chan1.circuit)

    # Test writing.
    request = ca.WriteNotifyRequest(data_type=2, data_count=1,
                                    sid=chan1.sid,
                                    ioid=13, data=(4,))

    cli_send(chan1.circuit, request)
    srv_recv(srv_circuit)
    srv_send(srv_circuit, ca.WriteNotifyResponse(data_type=5, data_count=1,
                                                 ioid=13, status=1))
    cli_recv(chan1.circuit)

    # Test "clearing" (closing) the channel.
    cli_send(chan1.circuit, ca.ClearChannelRequest(sid=chan1.sid, cid=chan1.cid))
    assert chan1.states[ca.CLIENT] is ca.MUST_CLOSE
    assert chan1.states[ca.SERVER] is ca.MUST_CLOSE

    srv_recv(srv_circuit)
    assert srv_chan1.states[ca.CLIENT] is ca.MUST_CLOSE
    assert srv_chan1.states[ca.SERVER] is ca.MUST_CLOSE

    # Sending the ClearChannelResponse closes the server-side channel.
    srv_send(srv_circuit, ca.ClearChannelResponse(sid=chan1.sid, cid=chan1.cid))
    assert srv_chan1.states[ca.CLIENT] is ca.CLOSED
    assert srv_chan1.states[ca.SERVER] is ca.CLOSED
Ejemplo n.º 7
0
    async def _process_command(self, command):
        '''Process a command from a client, and return the server response

        Parameters
        ----------
        command : ca.Message or ca.DISCONNECTED
            The command received on this circuit.

        Returns
        -------
        to_send : list
            Response commands for the caller to send; empty when the command
            needs no direct reply or when the reply is sent from within this
            method (writes, ``EventsOnRequest`` backlog).

        Raises
        ------
        DisconnectedCircuit
            If ``command`` is the ``ca.DISCONNECTED`` sentinel.
        ca.RemoteProtocolError
            If accessing ``command.data_type`` of a read request raises
            ``ValueError``.
        '''
        tags = self._tags
        if command is ca.DISCONNECTED:
            raise DisconnectedCircuit()
        elif isinstance(command, ca.VersionRequest):
            to_send = [ca.VersionResponse(ca.DEFAULT_PROTOCOL_VERSION)]
        elif isinstance(command, ca.SearchRequest):
            pv_name = command.name
            try:
                self.context[pv_name]
            except KeyError:
                # Unknown PV: only answer when the client asked for a reply.
                if command.reply == ca.DO_REPLY:
                    to_send = [
                        ca.NotFoundResponse(
                            version=ca.DEFAULT_PROTOCOL_VERSION,
                            cid=command.cid)
                    ]
                else:
                    to_send = []
            else:
                to_send = [
                    ca.SearchResponse(self.context.port, None, command.cid,
                                      ca.DEFAULT_PROTOCOL_VERSION)
                ]
        elif isinstance(command, ca.CreateChanRequest):
            pvname = command.name
            try:
                db_entry = self.context[pvname]
            except KeyError:
                self.log.debug('Client requested invalid channel name: %s',
                               pvname)
                to_send = [ca.CreateChFailResponse(cid=command.cid)]
            else:

                access = db_entry.check_access(self.client_hostname,
                                               self.client_username)

                modifiers = ca.parse_record_field(pvname).modifiers
                data_type = db_entry.data_type
                data_count = db_entry.max_length
                # With the long-string modifier a STRING entry is advertised
                # as a CHAR array of the entry's long-string length.
                if ca.RecordModifiers.long_string in (modifiers or {}):
                    if data_type in (ChannelType.STRING, ):
                        data_type = ChannelType.CHAR
                        data_count = db_entry.long_string_max_length

                to_send = [
                    ca.AccessRightsResponse(cid=command.cid,
                                            access_rights=access),
                    ca.CreateChanResponse(data_type=data_type,
                                          data_count=data_count,
                                          cid=command.cid,
                                          sid=self.circuit.new_channel_id()),
                ]
        elif isinstance(command, ca.HostNameRequest):
            self.client_hostname = command.name
            to_send = []
        elif isinstance(command, ca.ClientNameRequest):
            self.client_username = command.name
            to_send = []
        elif isinstance(command, (ca.ReadNotifyRequest, ca.ReadRequest)):
            chan, db_entry = self._get_db_entry_from_command(command)
            try:
                data_type = command.data_type
            except ValueError:
                raise ca.RemoteProtocolError('Invalid data type')

            # If we are in the middle of processing a Write[Notify]Request,
            # allow a bit of time for that to (maybe) finish. Some requests
            # may take a long time, so give up rather quickly to avoid
            # introducing too much latency.
            await self.write_event.wait(timeout=WRITE_LOCK_TIMEOUT)

            read_data_type = data_type
            if chan.name.endswith('$'):
                try:
                    read_data_type = _LongStringChannelType(read_data_type)
                except ValueError:
                    # Not requesting a LONG_STRING type
                    ...

            metadata, data = await db_entry.auth_read(
                self.client_hostname,
                self.client_username,
                read_data_type,
                user_address=self.circuit.address,
            )

            # Truncate to the requested count; protocol versions below 13
            # are always truncated, even for a zero count.
            old_version = self.circuit.protocol_version < 13
            if command.data_count > 0 or old_version:
                data = data[:command.data_count]

            # This is a pass-through if arr is None.
            data = apply_arr_filter(chan.channel_filter.arr, data)
            # If the timestamp feature is active swap the timestamp.
            # Information must copied because not all clients will have the
            # timestamp filter
            if chan.channel_filter.ts and command.data_type in ca.time_types:
                time_type = type(metadata)
                now = ca.TimeStamp.from_unix_timestamp(time.time())
                metadata = time_type(
                    **ChainMap({'stamp': now},
                               dict((field, getattr(metadata, field))
                                    for field, _ in time_type._fields_)))
            notify = isinstance(command, ca.ReadNotifyRequest)
            data_count = db_entry.calculate_length(data)
            to_send = [
                chan.read(data=data,
                          data_type=command.data_type,
                          data_count=data_count,
                          status=1,
                          ioid=command.ioid,
                          metadata=metadata,
                          notify=notify)
            ]
        elif isinstance(command, (ca.WriteRequest, ca.WriteNotifyRequest)):
            chan, db_entry = self._get_db_entry_from_command(command)
            client_waiting = isinstance(command, ca.WriteNotifyRequest)

            async def handle_write():
                '''Wait for an asynchronous caput to finish'''
                try:
                    write_status = await db_entry.auth_write(
                        self.client_hostname,
                        self.client_username,
                        command.data,
                        command.data_type,
                        command.metadata,
                        user_address=self.circuit.address)
                except Exception as ex:
                    self.log.exception('Invalid write request by %s (%s): %r',
                                       self.client_username,
                                       self.client_hostname, command)
                    cid = self.circuit.channels_sid[command.sid].cid
                    response_command = ca.ErrorResponse(
                        command,
                        cid,
                        status=ca.CAStatus.ECA_PUTFAIL,
                        error_message=('Python exception: {} {}'
                                       ''.format(type(ex).__name__, ex)))
                    await self.send(response_command)
                else:
                    if client_waiting:
                        if write_status is None:
                            # errors can be passed back by exceptions, and
                            # returning none for write_status can just be
                            # considered laziness
                            write_status = True

                        response_command = chan.write(
                            ioid=command.ioid,
                            status=write_status,
                            data_count=db_entry.length)
                        await self.send(response_command)
                finally:
                    maybe_awaitable = self.write_event.set()
                    # The curio backend makes this an awaitable thing.
                    if maybe_awaitable is not None:
                        await maybe_awaitable

            # Mark a write as in progress; handle_write sets the event again
            # when it finishes, whatever the outcome.
            self.write_event.clear()
            await self._start_write_task(handle_write)
            to_send = []
        elif isinstance(command, ca.EventAddRequest):
            chan, db_entry = self._get_db_entry_from_command(command)
            # TODO no support for deprecated low/high/to

            read_data_type = command.data_type
            if chan.name.endswith('$'):
                try:
                    read_data_type = _LongStringChannelType(read_data_type)
                except ValueError:
                    # Not requesting a LONG_STRING type
                    ...

            sub = Subscription(mask=command.mask,
                               channel_filter=chan.channel_filter,
                               channel=chan,
                               circuit=self,
                               data_type=read_data_type,
                               data_count=command.data_count,
                               subscriptionid=command.subscriptionid,
                               db_entry=db_entry)
            sub_spec = SubscriptionSpec(db_entry=db_entry,
                                        data_type_name=read_data_type.name,
                                        mask=command.mask,
                                        channel_filter=chan.channel_filter)
            self.subscriptions[sub_spec].append(sub)
            self.context.subscriptions[sub_spec].append(sub)

            # If we are in the middle of processing a Write[Notify]Request,
            # allow a bit of time for that to (maybe) finish. Some requests
            # may take a long time, so give up rather quickly to avoid
            # introducing too much latency.
            if not self.write_event.is_set():
                await self.write_event.wait(timeout=WRITE_LOCK_TIMEOUT)

            await db_entry.subscribe(self.context.subscription_queue, sub_spec,
                                     sub)
            to_send = []
        elif isinstance(command, ca.EventCancelRequest):
            chan, db_entry = self._get_db_entry_from_command(command)
            removed = await self._cull_subscriptions(
                db_entry,
                lambda sub: sub.subscriptionid == command.subscriptionid)
            # Echo the data_count of the removed subscription when there was
            # one; otherwise fall back to the entry's current length.
            if removed:
                _, removed_sub = removed[0]
                data_count = removed_sub.data_count
            else:
                data_count = db_entry.length
            to_send = [
                chan.unsubscribe(command.subscriptionid,
                                 data_type=command.data_type,
                                 data_count=data_count)
            ]
        elif isinstance(command, ca.EventsOnRequest):
            # Immediately send most recent updates for all subscriptions.
            most_recent_updates = list(self.most_recent_updates.values())
            self.most_recent_updates.clear()
            if most_recent_updates:
                await self.send(*most_recent_updates)
            maybe_awaitable = self.events_on.set()
            # The curio backend makes this an awaitable thing.
            if maybe_awaitable is not None:
                await maybe_awaitable
            self.circuit.log.info("Client at %s:%d has turned events on.",
                                  *self.circuit.address)
            to_send = []
        elif isinstance(command, ca.EventsOffRequest):
            # The client has signaled that it does not think it will be able to
            # catch up to the backlog. Clear all updates queued to be sent...
            self.unexpired_updates.clear()
            # ...and tell the Context that any future updates from ChannelData
            # should not be added to this circuit's queue until further notice.
            self.events_on.clear()
            self.circuit.log.info("Client at %s:%d has turned events off.",
                                  *self.circuit.address)
            to_send = []
        elif isinstance(command, ca.ClearChannelRequest):
            chan, db_entry = self._get_db_entry_from_command(command)
            # Subscriptions store the channel object itself (see the
            # EventAddRequest branch), so match on the channel rather than
            # on the integer sid, which would never compare equal.
            await self._cull_subscriptions(
                db_entry, lambda sub: sub.channel == chan)
            to_send = [chan.clear()]
        elif isinstance(command, ca.EchoRequest):
            to_send = [ca.EchoResponse()]
        # NOTE(review): a command matching none of the branches above leaves
        # ``to_send`` unbound, so the ``return`` below raises
        # UnboundLocalError -- confirm that callers never pass such commands.
        if isinstance(command, ca.Message):
            tags['bytesize'] = len(command)
            self.log.debug("%r", command, extra=tags)
        return to_send
Ejemplo n.º 8
0
    async def _process_command(self, command):
        '''Process a command from a client, and return the server response

        Returns a list of response commands for the caller to send, or
        ``None`` when the command needs no direct reply (host/client name,
        writes, subscription add, a cancel for an unknown subscription id,
        or an unrecognised command).

        Raises ``DisconnectedCircuit`` when ``command`` is the
        ``ca.DISCONNECTED`` sentinel.
        '''
        def get_db_entry():
            # Resolve the command's sid to its channel and database entry.
            chan = self.circuit.channels_sid[command.sid]
            db_entry = self.context.pvdb[chan.name.decode(STR_ENC)]
            return chan, db_entry

        if command is ca.DISCONNECTED:
            raise DisconnectedCircuit()
        elif isinstance(command, ca.VersionRequest):
            return [ca.VersionResponse(13)]
        elif isinstance(command, ca.CreateChanRequest):
            db_entry = self.context.pvdb[command.name.decode(STR_ENC)]
            access = db_entry.check_access(self.client_hostname,
                                           self.client_username)

            return [ca.AccessRightsResponse(cid=command.cid,
                                            access_rights=access),
                    ca.CreateChanResponse(data_type=db_entry.data_type,
                                          data_count=len(db_entry),
                                          cid=command.cid,
                                          sid=self.circuit.new_channel_id()),
                    ]
        elif isinstance(command, ca.HostNameRequest):
            self.client_hostname = command.name.decode(STR_ENC)
        elif isinstance(command, ca.ClientNameRequest):
            self.client_username = command.name.decode(STR_ENC)
        elif isinstance(command, ca.ReadNotifyRequest):
            chan, db_entry = get_db_entry()
            metadata, data = await db_entry.auth_read(
                self.client_hostname, self.client_username,
                command.data_type)
            return [chan.read(data=data, data_type=command.data_type,
                              data_count=len(data), status=1,
                              ioid=command.ioid, metadata=metadata)
                    ]
        elif isinstance(command, (ca.WriteRequest, ca.WriteNotifyRequest)):
            chan, db_entry = get_db_entry()
            client_waiting = isinstance(command, ca.WriteNotifyRequest)

            async def handle_write():
                '''Wait for an asynchronous caput to finish'''
                try:
                    write_status = await db_entry.auth_write(
                        self.client_hostname, self.client_username,
                        command.data, command.data_type, command.metadata)
                except Exception as ex:
                    cid = self.circuit.channels_sid[command.sid].cid
                    response_command = ca.ErrorResponse(
                        command, cid,
                        status_code=ca.ECA_INTERNAL.code_with_severity,
                        error_message=('Python exception: {} {}'
                                       ''.format(type(ex).__name__, ex))
                    )
                else:
                    if write_status is None:
                        # errors can be passed back by exceptions, and
                        # returning none for write_status can just be
                        # considered laziness
                        write_status = True
                    response_command = chan.write(ioid=command.ioid,
                                                  status=write_status)

                # Only WriteNotifyRequest clients are sent the outcome.
                if client_waiting:
                    await self.send(response_command)

            await self.pending_tasks.spawn(handle_write, ignore_result=True)
            # TODO pretty sure using the taskgroup will bog things down,
            # but it suppresses an annoying warning message, so... there
        elif isinstance(command, ca.EventAddRequest):
            chan, db_entry = get_db_entry()
            # TODO no support for deprecated low/high/to
            sub = Subscription(mask=command.mask,
                               channel=chan,
                               circuit=self,
                               data_type=command.data_type,
                               data_count=command.data_count,
                               subscriptionid=command.subscriptionid)
            sub_spec = SubscriptionSpec(db_entry=db_entry,
                                        data_type=command.data_type)
            self.subscriptions[sub_spec].append(sub)
            self.context.subscriptions[sub_spec].append(sub)
            await db_entry.subscribe(self.context.subscription_queue, sub_spec)
        elif isinstance(command, ca.EventCancelRequest):
            chan, db_entry = get_db_entry()
            # Start from "not found" so that an unknown subscription id falls
            # through cleanly instead of raising UnboundLocalError below.
            sub = sub_spec = None
            # Search self.subscriptions for a Subscription with a matching id.
            for _sub_spec, _subs in self.subscriptions.items():
                for _sub in _subs:
                    if _sub.subscriptionid == command.subscriptionid:
                        sub_spec = _sub_spec
                        sub = _sub

            unsub_response = chan.unsubscribe(command.subscriptionid)

            if sub is not None:
                self.subscriptions[sub_spec].remove(sub)
                self.context.subscriptions[sub_spec].remove(sub)
                # Does anything else on the Context still care about sub_spec?
                # If not unsubscribe the Context's queue from the db_entry.
                if not self.context.subscriptions[sub_spec]:
                    queue = self.context.subscription_queue
                    await sub_spec.db_entry.unsubscribe(queue, sub_spec)
                return [unsub_response]
        elif isinstance(command, ca.ClearChannelRequest):
            chan, db_entry = get_db_entry()
            return [chan.disconnect()]
        elif isinstance(command, ca.EchoRequest):
            return [ca.EchoResponse()]
Ejemplo n.º 9
0
    async def _process_command(self, command):
        '''Process a command from a client, and return the server response

        Parameters
        ----------
        command
            A parsed client request, or the ``ca.DISCONNECTED`` sentinel.

        Returns
        -------
        list or None
            The responses to send back, for request types that are answered
            immediately. ``None`` for requests that are handled here without
            an immediate reply (host name, client name, write, event-add)
            and for commands that match none of the branches below.

        Raises
        ------
        DisconnectedCircuit
            If ``command`` is ``ca.DISCONNECTED``.
        '''
        def get_db_entry():
            # Resolve the channel addressed by the command's sid, and the
            # database entry the context holds under that channel's name.
            chan = self.circuit.channels_sid[command.sid]
            db_entry = self.context[chan.name]
            return chan, db_entry

        if command is ca.DISCONNECTED:
            raise DisconnectedCircuit()
        elif isinstance(command, ca.VersionRequest):
            return [ca.VersionResponse(ca.DEFAULT_PROTOCOL_VERSION)]
        elif isinstance(command, ca.CreateChanRequest):
            db_entry = self.context[command.name]
            access = db_entry.check_access(self.client_hostname,
                                           self.client_username)

            return [
                ca.AccessRightsResponse(cid=command.cid, access_rights=access),
                ca.CreateChanResponse(data_type=db_entry.data_type,
                                      data_count=len(db_entry),
                                      cid=command.cid,
                                      sid=self.circuit.new_channel_id()),
            ]
        elif isinstance(command, ca.HostNameRequest):
            # Remembered for the access checks performed on later requests.
            self.client_hostname = command.name
        elif isinstance(command, ca.ClientNameRequest):
            self.client_username = command.name
        elif isinstance(command, (ca.ReadNotifyRequest, ca.ReadRequest)):
            chan, db_entry = get_db_entry()
            metadata, data = await db_entry.auth_read(
                self.client_hostname,
                self.client_username,
                command.data_type,
                user_address=self.circuit.address)
            # This is a pass-through if arr is None.
            data = apply_arr_filter(chan.channel_filter.arr, data)
            # If the timestamp feature is active swap the timestamp.
            # Information must be copied because not all clients will have
            # the timestamp filter
            if chan.channel_filter.ts and command.data_type in ca.time_types:
                time_type = type(metadata)
                now = ca.TimeStamp.from_unix_timestamp(time.time())
                # Copy every field of the original metadata, then override
                # 'stamp' with the current time.
                fields = {field: getattr(metadata, field)
                          for field, _ in time_type._fields_}
                fields['stamp'] = now
                metadata = time_type(**fields)
            notify = isinstance(command, ca.ReadNotifyRequest)
            return [
                chan.read(data=data,
                          data_type=command.data_type,
                          data_count=len(data),
                          status=1,
                          ioid=command.ioid,
                          metadata=metadata,
                          notify=notify)
            ]
        elif isinstance(command, (ca.WriteRequest, ca.WriteNotifyRequest)):
            chan, db_entry = get_db_entry()
            # Only WriteNotifyRequest gets a response once the write is done.
            client_waiting = isinstance(command, ca.WriteNotifyRequest)

            async def handle_write():
                '''Wait for an asynchronous caput to finish'''
                try:
                    write_status = await db_entry.auth_write(
                        self.client_hostname,
                        self.client_username,
                        command.data,
                        command.data_type,
                        command.metadata,
                        user_address=self.circuit.address)
                except Exception as ex:
                    self.log.exception('Invalid write request by %s (%s): %r',
                                       self.client_username,
                                       self.client_hostname, command)
                    # Use the channel captured above rather than looking it
                    # up again: a second lookup could fail if the channel
                    # has been removed while the write was in progress.
                    response_command = ca.ErrorResponse(
                        command,
                        chan.cid,
                        status=ca.CAStatus.ECA_INTERNAL,
                        error_message=('Python exception: {} {}'
                                       ''.format(type(ex).__name__, ex)))
                else:
                    if write_status is None:
                        # errors can be passed back by exceptions, and
                        # returning none for write_status can just be
                        # considered laziness
                        write_status = True
                    response_command = chan.write(ioid=command.ioid,
                                                  status=write_status)

                if client_waiting:
                    await self.send(response_command)

            await self._start_write_task(handle_write)
            # TODO pretty sure using the taskgroup will bog things down,
            # but it suppresses an annoying warning message, so... there
        elif isinstance(command, ca.EventAddRequest):
            chan, db_entry = get_db_entry()
            # TODO no support for deprecated low/high/to
            sub = Subscription(mask=command.mask,
                               channel_filter=chan.channel_filter,
                               channel=chan,
                               circuit=self,
                               data_type=command.data_type,
                               data_count=command.data_count,
                               subscriptionid=command.subscriptionid,
                               db_entry=db_entry)
            sub_spec = SubscriptionSpec(db_entry=db_entry,
                                        data_type=command.data_type,
                                        mask=command.mask,
                                        channel_filter=chan.channel_filter)
            # Track the subscription both on this circuit and on the shared
            # context before subscribing the context's queue to the entry.
            self.subscriptions[sub_spec].append(sub)
            self.context.subscriptions[sub_spec].append(sub)
            await db_entry.subscribe(self.context.subscription_queue, sub_spec)
        elif isinstance(command, ca.EventCancelRequest):
            chan, db_entry = get_db_entry()
            await self._cull_subscriptions(
                db_entry,
                lambda sub: sub.subscriptionid == command.subscriptionid)
            return [
                chan.unsubscribe(command.subscriptionid,
                                 data_type=command.data_type)
            ]
        elif isinstance(command, ca.ClearChannelRequest):
            chan, db_entry = get_db_entry()
            # Subscriptions store the channel object itself (see the
            # EventAddRequest branch), so match on the channel rather than
            # on ``command.sid``; comparing against the sid would leave the
            # cleared channel's subscriptions in place.
            await self._cull_subscriptions(
                db_entry, lambda sub: sub.channel == chan)
            return [chan.disconnect()]
        elif isinstance(command, ca.EchoRequest):
            return [ca.EchoResponse()]
Ejemplo n.º 10
0
    async def _process_command(self, command):
        '''Handle one client command and produce the server's reply

        Returns a list of response commands for requests that are answered
        immediately, and ``None`` otherwise. Raises ``DisconnectedCircuit``
        when handed the ``ca.DISCONNECTED`` sentinel.
        '''
        def lookup():
            # Map the command's sid to its channel, then the channel's
            # decoded name to the matching pvdb entry.
            channel = self.circuit.channels_sid[command.sid]
            pv_name = channel.name.decode(SERVER_ENCODING)
            return channel, self.context.pvdb[pv_name]

        if command is ca.DISCONNECTED:
            raise DisconnectedCircuit()

        if isinstance(command, ca.CreateChanRequest):
            pv_name = command.name.decode(SERVER_ENCODING)
            entry = self.context.pvdb[pv_name]
            rights = entry.check_access(self.client_hostname,
                                        self.client_username)
            version_reply = ca.VersionResponse(13)
            access_reply = ca.AccessRightsResponse(cid=command.cid,
                                                   access_rights=rights)
            created_reply = ca.CreateChanResponse(
                data_type=entry.data_type,
                data_count=len(entry),
                cid=command.cid,
                sid=self.circuit.new_channel_id())
            return [version_reply, access_reply, created_reply]

        if isinstance(command, ca.HostNameRequest):
            self.client_hostname = command.name.decode(SERVER_ENCODING)
            return None

        if isinstance(command, ca.ClientNameRequest):
            self.client_username = command.name.decode(SERVER_ENCODING)
            return None

        if isinstance(command, ca.ReadNotifyRequest):
            channel, entry = lookup()
            metadata, values = await entry.get_dbr_data(command.data_type)
            read_reply = channel.read(data=values,
                                      data_type=command.data_type,
                                      data_count=len(values),
                                      status=1,
                                      ioid=command.ioid,
                                      metadata=metadata)
            return [read_reply]

        if isinstance(command, (ca.WriteRequest, ca.WriteNotifyRequest)):
            channel, entry = lookup()
            # Only a WriteNotifyRequest is sent a reply after the write.
            client_waiting = isinstance(command, ca.WriteNotifyRequest)

            async def handle_write():
                '''Run the caput to completion, replying if one is expected'''
                try:
                    outcome = await entry.set_dbr_data(
                        command.data, command.data_type, command.metadata)
                except Exception as exc:
                    cid = self.circuit.channels_sid[command.sid].cid
                    severity = ca.ECA_INTERNAL.code_with_severity
                    detail = 'Python exception: {} {}'.format(
                        type(exc).__name__, exc)
                    reply = ca.ErrorResponse(command,
                                             cid,
                                             status_code=severity,
                                             error_message=detail)
                else:
                    # A write that returns nothing is treated as a success;
                    # failures are expected to arrive as exceptions.
                    status = True if outcome is None else outcome
                    reply = channel.write(ioid=command.ioid, status=status)

                if client_waiting:
                    await self.send(reply)

            # TODO pretty sure using the taskgroup will bog things down,
            # but it suppresses an annoying warning message, so... there
            await self.pending_tasks.spawn(handle_write, ignore_result=True)
            return None

        if isinstance(command, ca.EventAddRequest):
            channel, entry = lookup()
            # TODO no support for deprecated low/high/to
            new_sub = Subscription(mask=command.mask,
                                   circuit=self,
                                   data_type=command.data_type,
                                   subscription_id=command.subscriptionid)

            # Context-wide bookkeeping: the first subscriber for an entry
            # also hooks the context's queue up to that entry.
            shared = self.context.subscriptions
            if entry not in shared:
                shared[entry] = []
                entry.subscribe(self.context.subscription_queue, channel)
            shared[entry].append(new_sub)

            # Per-circuit bookkeeping.
            self.subscriptions.setdefault(entry, []).append(new_sub)

            # send back a first monitor always
            metadata, values = await entry.get_dbr_data(command.data_type)
            first_update = channel.subscribe(
                data=values,
                data_type=command.data_type,
                data_count=len(values),
                subscriptionid=command.subscriptionid,
                metadata=metadata,
                status_code=1)
            return [first_update]

        if isinstance(command, ca.EventCancelRequest):
            channel, entry = lookup()
            matches = [
                candidate for candidate in self.subscriptions[entry]
                if candidate.subscription_id == command.subscriptionid
            ]
            if not matches:
                # Nothing registered under that id: no reply is produced.
                return None

            target = matches[0]
            cancel_reply = channel.unsubscribe(command.subscriptionid)

            shared = self.context.subscriptions
            shared[entry].remove(target)
            if not shared[entry]:
                # Last subscriber anywhere on the context: detach the queue
                # from the entry and drop the now-empty list.
                entry.subscribe(None)
                del shared[entry]
            self.subscriptions[entry].remove(target)
            return [cancel_reply]

        if isinstance(command, ca.ClearChannelRequest):
            channel, _ = lookup()
            return [channel.disconnect()]

        if isinstance(command, ca.EchoRequest):
            return [ca.EchoResponse()]

        # Any other command produces no reply.
        return None