# Example No. 1
def test_circuit_properties():
    """Check VirtualCircuit's host/port/key properties and priority rules.

    A CLIENT circuit must be given its priority at construction time, while a
    SERVER circuit may defer it; ``key`` is unavailable until the priority is
    known. Sending a VersionRequest whose priority disagrees with the
    circuit's priority is a local protocol error.
    """
    host = '127.0.0.1'
    port = 5555
    prio = 1

    circuit = ca.VirtualCircuit(ca.CLIENT, (host, port), prio)
    # These used to be bare comparisons whose results were discarded, so
    # nothing was actually being checked.
    assert circuit.host == host
    assert circuit.port == port
    assert circuit.key == ((host, port), prio)

    # CLIENT circuit needs to know its prio at init time. Use a well-formed
    # address so the only thing wrong with this construction is prio=None.
    with pytest.raises(ca.CaprotoRuntimeError):
        ca.VirtualCircuit(ca.CLIENT, (host, port), None)

    # SERVER circuit does not
    srv_circuit = ca.VirtualCircuit(ca.SERVER, (host, port), None)

    # 'key' is not defined until prio is defined
    with pytest.raises(ca.CaprotoRuntimeError):
        srv_circuit.key
    srv_circuit.priority = prio
    assert srv_circuit.key == ((host, port), prio)

    # VersionRequest priority must match prio set above.
    with pytest.raises(ca.LocalProtocolError):
        circuit.send(
            ca.VersionRequest(version=ca.DEFAULT_PROTOCOL_VERSION, priority=2))
# Example No. 2
def make_channel(pv_name, udp_sock, priority, timeout):
    """Create a ClientChannel for ``pv_name`` and block until it is connected.

    Looks up the server address with ``search``, then reuses the client
    VirtualCircuit cached in ``global_circuits`` under (address, priority), or
    creates and caches one. If ``sockets`` has no TCP socket for that circuit
    yet, one is opened. A new connection is initialized with a VersionRequest,
    the host name and the client name. The channel-creation request is then
    sent, and the circuit is read until the channel reaches CONNECTED.

    Parameters
    ----------
    pv_name
        Name of the process variable.
    udp_sock
        Socket handed to ``search`` for the name lookup.
    priority
        Circuit priority; part of the circuit cache key.
    timeout : float
        Seconds. Used for the search, for the TCP connect, and for the wait
        on channel creation.

    Returns
    -------
    ca.ClientChannel
        The connected channel.

    Raises
    ------
    CaprotoTimeoutError
        If reading the circuit raises ``socket.timeout``, or if more than
        ``timeout`` seconds have elapsed when a read returns.
    CaprotoError
        If the circuit reports DISCONNECTED while waiting.

    On any exception (including KeyboardInterrupt, since BaseException is
    caught) the circuit's socket is closed and the circuit is removed from
    both ``sockets`` and ``global_circuits``. This happens even when the
    circuit already existed before this call.
    """
    # Adapter that attaches the PV name as extra logging context.
    log = logging.LoggerAdapter(logging.getLogger('caproto.ch'),
                                {'pv': pv_name})
    address = search(pv_name, udp_sock, timeout)
    try:
        circuit = global_circuits[(address, priority)]
    except KeyError:

        circuit = global_circuits[(address, priority)] = ca.VirtualCircuit(
            our_role=ca.CLIENT, address=address, priority=priority)

    chan = ca.ClientChannel(pv_name, circuit)
    # One TCP socket per circuit; remember whether we just opened it so the
    # connection handshake is sent only once.
    new = False
    if chan.circuit not in sockets:
        new = True
        sockets[chan.circuit] = socket.create_connection(
            chan.circuit.address, timeout)
        circuit.our_address = sockets[chan.circuit].getsockname()
    try:
        if new:
            # Initialize our new TCP-based CA connection with a VersionRequest.
            send(
                chan.circuit,
                ca.VersionRequest(priority=priority,
                                  version=ca.DEFAULT_PROTOCOL_VERSION),
                pv_name)
            send(chan.circuit, chan.host_name(socket.gethostname()))
            send(chan.circuit, chan.client_name(getpass.getuser()))
        send(chan.circuit, chan.create(), pv_name)
        t = time.monotonic()
        while True:
            try:
                commands = recv(chan.circuit)
                # The overall deadline is only checked after a read returns,
                # so the total wait can exceed ``timeout`` by up to one read.
                if time.monotonic() - t > timeout:
                    raise socket.timeout
            except socket.timeout:
                raise CaprotoTimeoutError("Timeout while awaiting channel "
                                          "creation.")
            tags = {
                'direction': '<<<---',
                'our_address': chan.circuit.our_address,
                'their_address': chan.circuit.address
            }
            # Log every received message at DEBUG and abort on disconnect
            # before checking whether the channel is connected.
            for command in commands:
                if isinstance(command, ca.Message):
                    tags['bytesize'] = len(command)
                    logger.debug("%r", command, extra=tags)
                elif command is ca.DISCONNECTED:
                    raise CaprotoError('Disconnected during initialization')
            if chan.states[ca.CLIENT] is ca.CONNECTED:
                log.info("Channel connected.")
                break

    except BaseException:
        # Tear down the whole (possibly shared) circuit on any failure.
        sockets[chan.circuit].close()
        del sockets[chan.circuit]
        del global_circuits[(chan.circuit.address, chan.circuit.priority)]
        raise
    return chan
# Example No. 3
def circuit_pair(request):
    """Build a client/server pair of VirtualCircuits that have shaken hands.

    The client sends a VersionRequest, which the server receives and
    processes. The server then answers with a VersionResponse, which the
    client receives and processes. ``request`` is accepted but unused
    (presumably a pytest fixture request object).

    Returns
    -------
    (client_circuit, server_circuit)
    """
    address = ('127.0.0.1', 5555)
    priority = 1
    version = 13

    def deliver(receiver, buffers):
        # Parse the raw buffers on the receiving circuit and feed every
        # resulting command into its state machine.
        parsed, _ = receiver.recv(*buffers)
        for cmd in parsed:
            receiver.process_command(cmd)

    client = ca.VirtualCircuit(ca.CLIENT, address, priority)
    outgoing = client.send(
        ca.VersionRequest(version=version, priority=priority))

    # The server circuit's priority is not known until it sees the request.
    server = ca.VirtualCircuit(ca.SERVER, address, None)
    deliver(server, outgoing)

    deliver(client, server.send(ca.VersionResponse(version=version)))
    return client, server
# Example No. 4
def make_channel(pv_name, udp_sock, priority, timeout):
    """Create a ClientChannel for ``pv_name`` and block until it is connected.

    Looks up the server address with ``search``, then reuses the client
    VirtualCircuit cached in ``global_circuits`` under (address, priority), or
    creates and caches one. If ``sockets`` has no TCP socket for that circuit
    yet, one is opened. A new connection is initialized with a VersionRequest,
    the host name and the client name. The channel-creation request is then
    sent, and the circuit is read until the channel reaches CONNECTED.

    Parameters
    ----------
    pv_name
        Name of the process variable.
    udp_sock
        Socket handed to ``search`` for the name lookup.
    priority
        Circuit priority; part of the circuit cache key.
    timeout : float
        Seconds. Used for the search, for the TCP connect, and for the wait
        on channel creation.

    Returns
    -------
    ca.ClientChannel
        The connected channel.

    Raises
    ------
    CaprotoTimeoutError
        If reading the circuit raises ``socket.timeout``, or if more than
        ``timeout`` seconds have elapsed when a read returns.
    CaprotoError
        If the circuit reports DISCONNECTED while waiting.

    On any exception (including KeyboardInterrupt) the circuit's socket is
    closed and the circuit is removed from ``sockets`` and
    ``global_circuits``.
    """
    log = logging.getLogger(f'caproto.ch.{pv_name}.{priority}')
    address = search(pv_name, udp_sock, timeout)
    try:
        circuit = global_circuits[(address, priority)]
    except KeyError:
        circuit = global_circuits[(address, priority)] = ca.VirtualCircuit(
            our_role=ca.CLIENT, address=address, priority=priority)

    chan = ca.ClientChannel(pv_name, circuit)
    # One TCP socket per circuit; only a freshly opened one gets the
    # connection handshake.
    new = False
    if chan.circuit not in sockets:
        new = True
        sockets[chan.circuit] = socket.create_connection(
            chan.circuit.address, timeout)

    try:
        if new:
            # Initialize our new TCP-based CA connection with a VersionRequest.
            send(
                chan.circuit,
                ca.VersionRequest(priority=priority,
                                  version=ca.DEFAULT_PROTOCOL_VERSION))
            send(chan.circuit, chan.host_name(socket.gethostname()))
            send(chan.circuit, chan.client_name(getpass.getuser()))
        send(chan.circuit, chan.create())
        t = time.monotonic()
        while True:
            try:
                commands = recv(chan.circuit)
                if time.monotonic() - t > timeout:
                    raise socket.timeout
            except socket.timeout:
                raise CaprotoTimeoutError("Timeout while awaiting channel "
                                          "creation.")
            # Scan for a disconnect *before* declaring success. Otherwise a
            # batch holding both the CreateChanResponse and DISCONNECTED would
            # return a channel on a dead circuit.
            for command in commands:
                if command is ca.DISCONNECTED:
                    raise CaprotoError('Disconnected during initialization')
            if chan.states[ca.CLIENT] is ca.CONNECTED:
                # Lazy %-args: formatting is deferred to the logging framework.
                log.info('%s connected', pv_name)
                break

    except BaseException:
        # Tear down the whole (possibly shared) circuit on any failure.
        sockets[chan.circuit].close()
        del sockets[chan.circuit]
        del global_circuits[(chan.circuit.address, chan.circuit.priority)]
        raise
    return chan
# Example No. 5
    async def tcp_handler(self, client, addr):
        '''Serve one newly accepted TCP client until its circuit disconnects.

        The connection is wrapped in a server-side caproto VirtualCircuit
        whose priority is left as None. The wrapper is registered in
        ``self.circuits`` and its logger is set to ``self.log_level``. It is
        started with ``run()``, and ``recv()`` is then called repeatedly
        until DisconnectedCircuit is raised.
        '''
        curio_circuit = CurioVirtualCircuit(
            ca.VirtualCircuit(ca.SERVER, addr, None), client, self)
        self.circuits.add(curio_circuit)
        curio_circuit.circuit.log.setLevel(self.log_level)

        await curio_circuit.run()

        try:
            while True:
                await curio_circuit.recv()
        except DisconnectedCircuit:
            # The peer went away; simply return.
            pass
# Example No. 6
    def get_circuit(self, address, priority):
        """
        Return the VirtualCircuit cached under (address, priority).

        A cached circuit is reused only while it reports ``connected``.
        Otherwise (no entry, or the cached one is not connected) a new client
        circuit is built. Its logger is set to ``self.log_level``, and it
        replaces the cache entry.
        """
        cache_key = (address, priority)
        cached = self.circuits.get(cache_key)
        if cached is not None and cached.connected:
            return cached

        replacement = VirtualCircuit(
            ca.VirtualCircuit(
                our_role=ca.CLIENT,
                address=address,
                priority=priority,
            ))
        replacement.circuit.log.setLevel(self.log_level)
        self.circuits[cache_key] = replacement
        return replacement
# Example No. 7
    async def get_circuit(self, address, priority):
        """
        Look up the VirtualCircuit for (address, priority), creating it on demand.

        The first circuit in ``self.circuits`` with a matching address and
        priority is returned as-is. Otherwise a new client circuit bound to
        ``self.nursery`` is appended to the list, its ``connect`` coroutine is
        scheduled with ``self.nursery.start_soon``, and it is returned.
        """
        existing = next(
            (vc for vc in self.circuits
             if vc.circuit.address == address
             and vc.circuit.priority == priority),
            None)
        if existing is not None:
            return existing

        new_circuit = VirtualCircuit(
            ca.VirtualCircuit(our_role=ca.CLIENT, address=address,
                              priority=priority),
            nursery=self.nursery)
        self.circuits.append(new_circuit)
        self.nursery.start_soon(new_circuit.connect)
        return new_circuit
# Example No. 8
    async def get_circuit(self, address, priority):
        """
        Return the VirtualCircuit for (address, priority), creating it if absent.

        Existing circuits are found by scanning ``self.circuits``. A new client
        circuit is appended to that list, and its ``connect`` coroutine is
        spawned as a daemon curio task before the circuit is returned.
        """
        for candidate in self.circuits:
            inner = candidate.circuit
            if inner.address == address and inner.priority == priority:
                return candidate

        fresh = VirtualCircuit(ca.VirtualCircuit(our_role=ca.CLIENT,
                                                 address=address,
                                                 priority=priority))
        self.circuits.append(fresh)
        await curio.spawn(fresh.connect, daemon=True)
        return fresh
# Example No. 9
    def get_circuit(self, address, priority):
        """
        Find or create the VirtualCircuit keyed by (address, priority).

        Existing circuits are matched by scanning ``self.circuits``. A newly
        created client circuit has its logger set to ``self.log_level`` and is
        appended to that list.
        """
        for candidate in self.circuits:
            inner = candidate.circuit
            if inner.address != address or inner.priority != priority:
                continue
            return candidate

        created = VirtualCircuit(
            ca.VirtualCircuit(our_role=ca.CLIENT,
                              address=address,
                              priority=priority))
        created.circuit.log.setLevel(self.log_level)
        self.circuits.append(created)
        return created
# Example No. 10
    async def tcp_handler(self, client, addr):
        '''Handler for each new TCP client to the server

        Wraps the connection in ``self.CircuitClass`` around a server-side
        caproto VirtualCircuit (priority initially None) and registers it in
        ``self.circuits``. The circuit is started with ``run()`` and then read
        until DisconnectedCircuit, at which point
        ``self.circuit_disconnected`` is awaited. A KeyboardInterrupt during
        the read loop is converted into ``self.ServerExit``.
        '''
        cavc = ca.VirtualCircuit(ca.SERVER, addr, None)
        circuit = self.CircuitClass(cavc, client, self)
        self.circuits.add(circuit)
        # Log only host and port. An IPv6 peer address is a 4-tuple
        # (host, port, flowinfo, scope_id), and splatting all of it into a
        # three-placeholder format makes logging fail to render the message.
        host, port = addr[0], addr[1]
        self.log.info('Connected to new client at %s:%d (total: %d).',
                      host, port, len(self.circuits))

        await circuit.run()

        try:
            while True:
                try:
                    await circuit.recv()
                except DisconnectedCircuit:
                    await self.circuit_disconnected(circuit)
                    break
        except KeyboardInterrupt as ex:
            self.log.debug('TCP handler received KeyboardInterrupt')
            raise self.ServerExit() from ex
        self.log.info('Disconnected from client at %s:%d (total: %d).',
                      host, port, len(self.circuits))
# Example No. 11
def test_nonet():
    """Drive a full client/server CA exchange without real network I/O.

    Presumably "nonet" means "no network". Buffers produced by one side are
    handed directly to the other side's ``recv`` instead of being sent over a
    socket. The test covers:

    - repeater registration;
    - the UDP version/search handshake;
    - TCP virtual-circuit setup;
    - channel creation, including the check that the client may not read
      before it has received the CreateChanResponse;
    - subscriptions, reads, writes, and channel clearing.

    Relies on module-level objects defined elsewhere (``cli_b``, ``srv_b``,
    ``cli_addr``, ``pv1``) and on helpers (``cli_send``, ``cli_recv``,
    ``srv_send``, ``srv_recv``) whose exact behavior is not visible here.
    """
    # Register with the repeater.
    assert not cli_b._registered
    # NOTE: these bytes are never delivered; ``bytes_to_send`` is overwritten
    # below and only the repeater's confirmation is simulated.
    bytes_to_send = cli_b.send(ca.RepeaterRegisterRequest('0.0.0.0'))
    # Sending the request alone does not mark the client as registered.
    assert not cli_b._registered

    # Receive response
    data = bytes(ca.RepeaterConfirmResponse('127.0.0.1'))
    commands = cli_b.recv(data, cli_addr)
    cli_b.process_commands(commands)
    assert cli_b._registered

    # Search for pv1.
    # CA requires us to send a VersionRequest and a SearchRequest bundled into
    # one datagram.
    bytes_to_send = cli_b.send(ca.VersionRequest(0, ca.DEFAULT_PROTOCOL_VERSION),
                               ca.SearchRequest(pv1, 0,
                                                ca.DEFAULT_PROTOCOL_VERSION))

    # Hand the client's datagram straight to ``srv_b``.
    commands = srv_b.recv(bytes_to_send, cli_addr)
    srv_b.process_commands(commands)
    ver_req, search_req = commands
    # NOTE(review): 5064 is presumably the TCP port advertised to the client,
    # and None presumably leaves the address to a default; confirm against
    # the SearchResponse signature.
    bytes_to_send = srv_b.send(
        ca.VersionResponse(ca.DEFAULT_PROTOCOL_VERSION),
        ca.SearchResponse(5064, None, search_req.cid, ca.DEFAULT_PROTOCOL_VERSION))

    # Receive a VersionResponse and SearchResponse.
    commands = iter(cli_b.recv(bytes_to_send, cli_addr))
    command = next(commands)
    assert type(command) is ca.VersionResponse
    command = next(commands)
    assert type(command) is ca.SearchResponse
    address = ca.extract_address(command)

    # Client side of the TCP virtual circuit, plus a channel on it.
    circuit = ca.VirtualCircuit(our_role=ca.CLIENT,
                                address=address,
                                priority=0)
    circuit.log.setLevel('DEBUG')
    chan1 = ca.ClientChannel(pv1, circuit)
    assert chan1.states[ca.CLIENT] is ca.SEND_CREATE_CHAN_REQUEST
    assert chan1.states[ca.SERVER] is ca.IDLE

    # The server-side circuit starts with priority=None; presumably it takes
    # the priority from the client's VersionRequest.
    srv_circuit = ca.VirtualCircuit(our_role=ca.SERVER,
                                    address=address, priority=None)

    cli_send(chan1.circuit, ca.VersionRequest(priority=0,
                                              version=ca.DEFAULT_PROTOCOL_VERSION))

    srv_recv(srv_circuit)

    srv_send(srv_circuit, ca.VersionResponse(version=ca.DEFAULT_PROTOCOL_VERSION))
    cli_recv(chan1.circuit)
    cli_send(chan1.circuit, ca.HostNameRequest('localhost'))
    cli_send(chan1.circuit, ca.ClientNameRequest('username'))
    cli_send(chan1.circuit, ca.CreateChanRequest(name=pv1, cid=chan1.cid,
                                                 version=ca.DEFAULT_PROTOCOL_VERSION))
    # Sending CreateChanRequest advances the client channel's state machine.
    assert chan1.states[ca.CLIENT] is ca.AWAIT_CREATE_CHAN_RESPONSE
    assert chan1.states[ca.SERVER] is ca.SEND_CREATE_CHAN_RESPONSE

    # The server side gets its own channel object, in the same state.
    srv_recv(srv_circuit)
    assert chan1.states[ca.CLIENT] is ca.AWAIT_CREATE_CHAN_RESPONSE
    assert chan1.states[ca.SERVER] is ca.SEND_CREATE_CHAN_RESPONSE
    srv_chan1, = srv_circuit.channels.values()
    assert srv_chan1.states[ca.CLIENT] is ca.AWAIT_CREATE_CHAN_RESPONSE
    assert srv_chan1.states[ca.SERVER] is ca.SEND_CREATE_CHAN_RESPONSE

    srv_send(srv_circuit, ca.CreateChanResponse(cid=chan1.cid, sid=1,
                                                data_type=5, data_count=1))
    assert srv_chan1.states[ca.CLIENT] is ca.CONNECTED
    assert srv_chan1.states[ca.SERVER] is ca.CONNECTED

    # At this point the CLIENT is not aware that we are CONNECTED because it
    # has not yet received the CreateChanResponse. It should not be allowed to
    # read or write.
    assert chan1.states[ca.CLIENT] is ca.AWAIT_CREATE_CHAN_RESPONSE
    assert chan1.states[ca.SERVER] is ca.SEND_CREATE_CHAN_RESPONSE

    # Try sending a premature read request.
    read_req = ca.ReadNotifyRequest(sid=srv_chan1.sid,
                                    data_type=srv_chan1.native_data_type,
                                    data_count=srv_chan1.native_data_count,
                                    ioid=0)
    with pytest.raises(ca.LocalProtocolError):
        cli_send(chan1.circuit, read_req)

    # The above failed because the sid is not recognized. Remove that failure
    # by editing the sid cache, and check that it *still* fails, this time
    # because of the state machine prohibiting this command before the channel
    # is in a CONNECTED state.
    chan1.circuit.channels_sid[1] = chan1
    with pytest.raises(ca.LocalProtocolError):
        cli_send(chan1.circuit, read_req)

    # Now deliver the CreateChanResponse to the client.
    cli_recv(chan1.circuit)
    assert chan1.states[ca.CLIENT] is ca.CONNECTED
    assert chan1.states[ca.SERVER] is ca.CONNECTED

    # Test subscriptions.
    assert chan1.native_data_type and chan1.native_data_count
    add_req = ca.EventAddRequest(data_type=chan1.native_data_type,
                                 data_count=chan1.native_data_count,
                                 sid=chan1.sid,
                                 subscriptionid=0,
                                 low=0, high=0, to=0, mask=1)
    cli_send(chan1.circuit, add_req)
    srv_recv(srv_circuit)
    add_res = ca.EventAddResponse(data=(3,),
                                  data_type=chan1.native_data_type,
                                  data_count=chan1.native_data_count,
                                  subscriptionid=0,
                                  status=1)

    srv_send(srv_circuit, add_res)
    cli_recv(chan1.circuit)

    cancel_req = ca.EventCancelRequest(data_type=add_req.data_type,
                                       sid=add_req.sid,
                                       subscriptionid=add_req.subscriptionid)

    cli_send(chan1.circuit, cancel_req)
    srv_recv(srv_circuit)

    # Test reading.
    cli_send(chan1.circuit, ca.ReadNotifyRequest(data_type=5, data_count=1,
                                                 sid=chan1.sid,
                                                 ioid=12))
    srv_recv(srv_circuit)
    srv_send(srv_circuit, ca.ReadNotifyResponse(data=(3,),
                                                data_type=5, data_count=1,
                                                ioid=12, status=1))
    cli_recv(chan1.circuit)

    # Test writing.
    request = ca.WriteNotifyRequest(data_type=2, data_count=1,
                                    sid=chan1.sid,
                                    ioid=13, data=(4,))

    cli_send(chan1.circuit, request)
    srv_recv(srv_circuit)
    # NOTE(review): the response's data_type (5) differs from the request's
    # (2); presumably this is not cross-checked by the library. Confirm the
    # intent.
    srv_send(srv_circuit, ca.WriteNotifyResponse(data_type=5, data_count=1,
                                                 ioid=13, status=1))
    cli_recv(chan1.circuit)

    # Test "clearing" (closing) the channel.
    cli_send(chan1.circuit, ca.ClearChannelRequest(sid=chan1.sid, cid=chan1.cid))
    assert chan1.states[ca.CLIENT] is ca.MUST_CLOSE
    assert chan1.states[ca.SERVER] is ca.MUST_CLOSE

    srv_recv(srv_circuit)
    assert srv_chan1.states[ca.CLIENT] is ca.MUST_CLOSE
    assert srv_chan1.states[ca.SERVER] is ca.MUST_CLOSE

    # Only the server side is checked for CLOSED; the client never receives
    # the ClearChannelResponse in this test.
    srv_send(srv_circuit, ca.ClearChannelResponse(sid=chan1.sid, cid=chan1.cid))
    assert srv_chan1.states[ca.CLIENT] is ca.CLOSED
    assert srv_chan1.states[ca.SERVER] is ca.CLOSED
# Example No. 12
def test_circuit_equality():
    """Circuits with equal role, address and priority compare equal.

    A circuit that differs only in priority compares unequal.
    """
    address = ('asdf', 1234)
    first = ca.VirtualCircuit(ca.CLIENT, address, 1)
    second = ca.VirtualCircuit(ca.CLIENT, address, 1)
    other_priority = ca.VirtualCircuit(ca.CLIENT, address, 2)
    assert first == second
    assert second != other_priority
# Example No. 13
def main(*, skip_monitor_section=False):
    """Walk through a live CA client session step by step.

    The session proceeds in this order:

    1. Register with the CA repeater.
    2. Broadcast a version/search request for ``pv1`` to every address from
       ``ca.get_address_list()``.
    3. Open a TCP virtual circuit (logging at DEBUG) to the responding server
       and create a channel on it.
    4. Exercise a subscription, then optionally monitor until Ctrl-C.
    5. Cancel the subscription, read, write, sleep 2 s, and read again.
    6. Clear the channel.

    Parameters
    ----------
    skip_monitor_section : bool, optional
        When False (the default), print incoming commands until Ctrl-C is
        pressed before continuing with the remaining steps.

    Both the UDP broadcast socket and the circuit's TCP socket are closed
    even if a step fails part-way through.
    """
    # A broadcast socket
    udp_sock = ca.bcast_socket()
    try:
        # Register with the repeater.
        bytes_to_send = b.send(ca.RepeaterRegisterRequest('0.0.0.0'))

        # TODO: for test environment with specific hosts listed in
        # EPICS_CA_ADDR_LIST
        if False:
            fake_reg = (('127.0.0.1', ca.EPICS_CA1_PORT),
                        [ca.RepeaterConfirmResponse(
                            repeater_address='127.0.0.1')])
            b.command_queue.put(fake_reg)
            # NOTE(review): this branch used to fall through to
            # ``b.process_commands(commands)`` with ``commands`` unbound
            # (NameError). Presumably the queued response is consumed via
            # ``command_queue`` instead; confirm before enabling this branch.
        else:
            udp_sock.sendto(bytes_to_send, ('', CA_REPEATER_PORT))

            # Receive response
            data, address = udp_sock.recvfrom(1024)
            commands = b.recv(data, address)
            b.process_commands(commands)

        # Search for pv1.
        # CA requires us to send a VersionRequest and a SearchRequest bundled
        # into one datagram.
        bytes_to_send = b.send(ca.VersionRequest(0, 13),
                               ca.SearchRequest(pv1, 0, 13))
        for host in ca.get_address_list():
            # Entries may carry an explicit port as "host:port".
            if ':' in host:
                host, _, specified_port = host.partition(':')
                udp_sock.sendto(bytes_to_send, (host, int(specified_port)))
            else:
                udp_sock.sendto(bytes_to_send, (host, CA_SERVER_PORT))
        print('searching for %s' % pv1)
        # Receive a VersionResponse and SearchResponse.
        bytes_received, address = udp_sock.recvfrom(1024)
        commands = b.recv(bytes_received, address)
        b.process_commands(commands)
        c1, c2 = commands
        assert type(c1) is ca.VersionResponse
        assert type(c2) is ca.SearchResponse
        address = ca.extract_address(c2)

        circuit = ca.VirtualCircuit(our_role=ca.CLIENT,
                                    address=address,
                                    priority=0)
        circuit.log.setLevel('DEBUG')
        chan1 = ca.ClientChannel(pv1, circuit)
        sockets[chan1.circuit] = socket.create_connection(
            chan1.circuit.address)
        try:
            # Initialize our new TCP-based CA connection with a
            # VersionRequest.
            send(chan1.circuit, ca.VersionRequest(priority=0, version=13))
            recv(chan1.circuit)
            # Send info about us.
            send(chan1.circuit, ca.HostNameRequest('localhost'))
            send(chan1.circuit, ca.ClientNameRequest('username'))
            send(chan1.circuit,
                 ca.CreateChanRequest(name=pv1, cid=chan1.cid, version=13))
            recv(chan1.circuit)

            # Test subscriptions.
            assert chan1.native_data_type and chan1.native_data_count
            add_req = ca.EventAddRequest(data_type=chan1.native_data_type,
                                         data_count=chan1.native_data_count,
                                         sid=chan1.sid,
                                         subscriptionid=0,
                                         low=0,
                                         high=0,
                                         to=0,
                                         mask=1)
            send(chan1.circuit, add_req)

            recv(chan1.circuit)

            if not skip_monitor_section:
                try:
                    print('Monitoring until Ctrl-C is hit. Meanwhile, use '
                          'caput to change the value and watch for commands '
                          'to arrive here.')
                    while True:
                        commands = recv(chan1.circuit)
                        if commands:
                            print(commands)
                except KeyboardInterrupt:
                    pass

            cancel_req = ca.EventCancelRequest(
                data_type=add_req.data_type,
                sid=add_req.sid,
                subscriptionid=add_req.subscriptionid)

            send(chan1.circuit, cancel_req)
            # The single-element unpack asserts exactly one command arrived.
            commands, = recv(chan1.circuit)

            # Test reading.
            send(
                chan1.circuit,
                ca.ReadNotifyRequest(data_type=2, data_count=1, sid=chan1.sid,
                                     ioid=12))
            commands, = recv(chan1.circuit)

            # Test writing.
            request = ca.WriteNotifyRequest(data_type=2,
                                            data_count=1,
                                            sid=chan1.sid,
                                            ioid=13,
                                            data=(4, ))

            send(chan1.circuit, request)
            recv(chan1.circuit)
            time.sleep(2)
            send(
                chan1.circuit,
                ca.ReadNotifyRequest(data_type=2, data_count=1, sid=chan1.sid,
                                     ioid=14))
            recv(chan1.circuit)

            # Test "clearing" (closing) the channel.
            send(chan1.circuit, ca.ClearChannelRequest(chan1.sid, chan1.cid))
            recv(chan1.circuit)
        finally:
            # Always release the circuit's TCP socket, even after a failure.
            tcp_sock = sockets.pop(chan1.circuit, None)
            if tcp_sock is not None:
                tcp_sock.close()
    finally:
        udp_sock.close()