示例#1
0
 def _init_sender(self, db, vertex_sizes):
     """ Prepare the sending side of this connection from the database.

     For every label in ``self._send_labels`` the address details and the
     atom-ID-to-key mapping are read from ``db`` and cached on this object,
     and the size of the vertex is stored in ``vertex_sizes``.

     :param db: The database reader to query
     :param vertex_sizes: Mapping of label to vertex size; updated in place
     """
     # Only open a connection when none is held, so that calling this again
     # does not overwrite (and thereby leak) a socket that is still open
     if getattr(self, "_sender_connection", None) is None:
         self._sender_connection = EIEIOConnection()
     for label in self._send_labels:
         self._send_address_details[label] = self.__get_live_input_details(
             db, label)
         if self._machine_vertices:
             # A machine vertex is treated as a single atom with ID 0
             key, _ = db.get_machine_live_input_key(label)
             self._atom_id_to_key[label] = {0: key}
             vertex_sizes[label] = 1
         else:
             self._atom_id_to_key[label] = db.get_atom_id_to_key_mapping(
                 label)
             vertex_sizes[label] = len(self._atom_id_to_key[label])
 def test_send_start_resume_notification(self):
     """ Test the sending of the start/resume message of the notification\
         protocol
     """
     listener = EIEIOConnection()
     # Release the listening socket even if an assertion below fails
     self.addCleanup(listener.close)
     socket_addresses = [SocketAddress(
         "127.0.0.1", listener.local_port, None)]
     protocol = NotificationProtocol(socket_addresses, False)
     protocol.send_start_resume_notification()
     # The notification is expected to arrive on the local listener
     message = listener.receive_eieio_message(timeout=10)
     self.assertIsInstance(message, EIEIOCommandMessage)
     self.assertEqual(
         message.eieio_header.command,
         EIEIO_COMMAND_IDS.START_RESUME_NOTIFICATION.value)
示例#3
0
 def test_send_start_resume_notification(self):
     """ Test the sending of the start/resume message of the notification\
         protocol
     """
     listener = EIEIOConnection()
     # Release the listening socket even if an assertion below fails
     self.addCleanup(listener.close)
     socket_addresses = [
         SocketAddress("127.0.0.1", listener.local_port, None)
     ]
     protocol = NotificationProtocol(socket_addresses, False)
     protocol.send_start_resume_notification()
     # The notification is expected to arrive on the local listener
     message = listener.receive_eieio_message(timeout=10)
     self.assertIsInstance(message, EIEIOCommandMessage)
     self.assertEqual(message.eieio_header.command,
                      EIEIO_COMMAND_IDS.START_RESUME_NOTIFICATION.value)
示例#4
0
    def _init_receivers(self, db, vertex_sizes):
        """ Prepare the receiving side of this connection from the database.

        A connection and a listener are created once per local port and
        cached in ``self._receivers`` and ``self._listeners``.  For every
        receive label a port trigger message is sent to its board, and the
        mapping from key to (atom ID, label index) is recorded.

        :param db: The database reader to query
        :param vertex_sizes: Mapping of label to vertex size; updated in place
        """
        for label_id, label in enumerate(self._receive_labels):
            host, port, board_address = self.__get_live_output_details(
                db, label)
            if port in self._receivers:
                # Reuse the connection already bound to this port; without
                # this the name would be unbound (or left over from the
                # previous label) when the port has been seen before
                receiver = self._receivers[port]
            else:
                receiver = EIEIOConnection(local_port=port)
                listener = ConnectionListener(receiver)
                listener.add_callback(self._receive_packet_callback)
                listener.start()
                self._receivers[port] = receiver
                self._listeners[port] = listener

            send_port_trigger_message(receiver, board_address)
            logger.info("Listening for traffic from {} on {}:{}", label, host,
                        port)

            if self._machine_vertices:
                # A machine vertex is treated as a single atom with ID 0
                key, _ = db.get_machine_live_output_key(
                    label, self._live_packet_gather_label)
                self._key_to_atom_id_and_label[key] = (0, label_id)
                vertex_sizes[label] = 1
            else:
                key_to_atom_id = db.get_key_to_atom_id_mapping(label)
                for key, atom_id in iteritems(key_to_atom_id):
                    self._key_to_atom_id_and_label[key] = (atom_id, label_id)
                vertex_sizes[label] = len(key_to_atom_id)
示例#5
0
    def test_listener_creation(self):
        """ Check which connection each newly registered UDP listener gets.

        Listeners registered without a port, or with the port of the EIEIO
        connection handed to the transceiver, must compare equal to that
        connection; a listener on the next port number must not.
        """
        # Board connections, created in the order they are handed over
        scamp_connection = SCAMPConnection(remote_host=None)
        orig_connection = EIEIOConnection()
        connections = [scamp_connection, orig_connection]

        with Transceiver(version=5, connections=connections) as trnx:

            def listen(**port_args):
                # Every listener here has no callback and the same class
                return trnx.register_udp_listener(
                    callback=None, connection_class=EIEIOConnection,
                    **port_args)

            connection_1 = listen()
            connection_2 = listen()
            connection_3 = listen(local_port=orig_connection.local_port)
            connection_4 = listen(local_port=orig_connection.local_port + 1)

            for shared in (connection_1, connection_2, connection_3):
                assert shared == orig_connection
            assert connection_4 != orig_connection
    def test_listener_creation(self):
        """ Check that the buffer manager registers only one listener.

        Regression test for a problem where a separate listener was created
        for the buffer manager traffic of each board, where it is preferred
        that all traffic is received by a single listener.
        """
        # Create two vertices
        v1 = _TestVertex(10, "v1", 256)
        v2 = _TestVertex(10, "v2", 256)

        # Create two tags - important thing is port=None
        t1 = IPTag(board_address='127.0.0.1', destination_x=0,
                   destination_y=1, tag=1, port=None, ip_address=None,
                   strip_sdp=True, traffic_identifier='BufferTraffic')
        t2 = IPTag(board_address='127.0.0.1', destination_x=0,
                   destination_y=2, tag=1, port=None, ip_address=None,
                   strip_sdp=True, traffic_identifier='BufferTraffic')

        # Create 'Tags' object and add tags
        t = Tags()
        t.add_ip_tag(t1, v1)
        t.add_ip_tag(t2, v2)

        # Create board connections
        connections = []
        connections.append(SCAMPConnection(
            remote_host=None))
        connections.append(EIEIOConnection())

        # Create two placements and 'Placements' object
        pl1 = Placement(v1, 0, 1, 1)
        pl2 = Placement(v2, 0, 2, 1)
        pl = Placements([pl1, pl2])

        # Create transceiver
        trnx = Transceiver(version=5, connections=connections)
        # Alternatively, one can register a udp listener for testing via:
        # trnx.register_udp_listener(callback=None,
        #        connection_class=EIEIOConnection)

        # Create buffer manager
        bm = BufferManager(pl, t, trnx)

        # Register two listeners, and check the second listener uses the
        # first rather than creating a new one
        bm._add_buffer_listeners(vertex=v1)
        bm._add_buffer_listeners(vertex=v2)

        number_of_listeners = 0
        for pair in bm._transceiver._udp_listenable_connections_by_class[
                EIEIOConnection]:
            # Check if listener is registered on connection - we only expect
            # one listener to be registered, as all connections can use the
            # same listener for the buffer manager
            if pair[1] is not None:
                number_of_listeners += 1
            # Call form of print: valid on both Python 2 and Python 3
            print(pair)
        self.assertEqual(number_of_listeners, 1)
示例#7
0
    def __init_receivers(self, db, vertex_sizes):
        """ Prepare the receiving side of this connection from the database.

        One connection is shared by all receive labels.  For each distinct
        (board address, port, tag) the tag is updated and a port trigger
        message is sent; for each label the mapping from key to
        (atom ID, label index) is recorded.  The packet listener is started
        only after all of that is done.

        :param db: The database reader to query
        :param vertex_sizes: Mapping of label to vertex size; updated in place
        """
        # The single receive connection is created on first use only
        if self.__receiver_connection is None:
            self.__receiver_connection = EIEIOConnection()

        seen_targets = set()
        for label_id, label in enumerate(self.__receive_labels):
            details = self.__get_live_output_details(db, label)
            _, port, board_address, tag = details

            # Each distinct target only needs its tag updating once
            target = (board_address, port, tag)
            if target not in seen_targets:
                self.__update_tag(
                    self.__receiver_connection, board_address, tag)
                seen_targets.add(target)
                send_port_trigger_message(
                    self.__receiver_connection, board_address)

            logger.info(
                "Listening for traffic from {} on {}:{}", label,
                self.__receiver_connection.local_ip_address,
                self.__receiver_connection.local_port)

            if self.__machine_vertices:
                # A machine vertex is treated as a single atom with ID 0
                machine_key, _ = db.get_machine_live_output_key(
                    label, self.__live_packet_gather_label)
                self.__key_to_atom_id_and_label[machine_key] = (0, label_id)
                vertex_sizes[label] = 1
            else:
                atoms_by_key = db.get_key_to_atom_id_mapping(label)
                self.__key_to_atom_id_and_label.update(
                    (key, (atom_id, label_id))
                    for key, atom_id in iteritems(atoms_by_key))
                vertex_sizes[label] = len(atoms_by_key)

        # NOTE: The listener has to be set up last, as otherwise it will
        # receive the SCP messages sent above!
        if self.__receiver_listener is None:
            packet_listener = ConnectionListener(self.__receiver_connection)
            self.__receiver_listener = packet_listener
            packet_listener.add_callback(self.__do_receive_packet)
            packet_listener.start()
    def __init__(self, socket_addresses, wait_for_read_confirmation):
        """
        :param socket_addresses: The addresses to open a connection for;
            each must provide ``listen_port``, ``notify_host_name`` and
            ``notify_port_no``
        :param wait_for_read_confirmation: Whether to wait for confirmation
            that the database has been read before starting the simulation
        """
        self._socket_addresses = socket_addresses

        # Whether to hold the simulation start until the database has been
        # confirmed as read
        self._wait_for_read_confirmation = wait_for_read_confirmation
        self._wait_pool = ThreadPool(processes=1)

        # One connection is opened per socket address, in the order given
        self._data_base_message_connections = list()
        self._data_base_message_connections.extend(
            EIEIOConnection(
                local_port=address.listen_port,
                remote_host=address.notify_host_name,
                remote_port=address.notify_port_no)
            for address in socket_addresses)
 def _init_sender(self, db, vertex_sizes):
     """ Prepare the sending side of this connection from the database.

     The sender connection is created if there is not one yet.  For every
     label in ``self._send_labels`` the address details and the
     atom-ID-to-key mapping are read from ``db`` and cached on this object,
     and the size of the vertex is stored in ``vertex_sizes``.

     :param db: The database reader to query
     :param vertex_sizes: Mapping of label to vertex size; updated in place
     """
     if self._sender_connection is None:
         self._sender_connection = EIEIOConnection()
     for send_label in self._send_labels:
         address_details = self.__get_live_input_details(db, send_label)
         self._send_address_details[send_label] = address_details
         if not self._machine_vertices:
             atom_keys = db.get_atom_id_to_key_mapping(send_label)
         else:
             # A machine vertex is treated as a single atom with ID 0
             machine_key, _ = db.get_machine_live_input_key(send_label)
             atom_keys = {0: machine_key}
         self._atom_id_to_key[send_label] = atom_keys
         vertex_sizes[send_label] = len(atom_keys)
    def _init_receivers(self, db, vertex_sizes):
        """ Prepare the receiving side of this connection from the database.

        One connection is shared by all receive labels.  For each distinct
        (board address, port, tag) the tag is updated and a port trigger
        message is sent; for each label the mapping from key to
        (atom ID, label index) is recorded.  The packet listener is started
        only after all of that is done.

        :param db: The database reader to query
        :param vertex_sizes: Mapping of label to vertex size; updated in place
        """
        # All labels share one receiving connection, created on first use
        if self._receiver_connection is None:
            self._receiver_connection = EIEIOConnection()

        handled = set()
        for label_id, label in enumerate(self._receive_labels):
            _, port, board_address, tag = self.__get_live_output_details(
                db, label)

            # Each distinct (board, port, tag) only needs setting up once
            tag_id = (board_address, port, tag)
            if tag_id not in handled:
                self.__update_tag(
                    self._receiver_connection, board_address, tag)
                handled.add(tag_id)
                send_port_trigger_message(
                    self._receiver_connection, board_address)

            logger.info(
                "Listening for traffic from {} on {}:{}",
                label, self._receiver_connection.local_ip_address,
                self._receiver_connection.local_port)

            if not self._machine_vertices:
                atoms_by_key = db.get_key_to_atom_id_mapping(label)
                for key, atom_id in iteritems(atoms_by_key):
                    self._key_to_atom_id_and_label[key] = (atom_id, label_id)
                vertex_sizes[label] = len(atoms_by_key)
            else:
                # A machine vertex is treated as a single atom with ID 0
                machine_key, _ = db.get_machine_live_output_key(
                    label, self._live_packet_gather_label)
                self._key_to_atom_id_and_label[machine_key] = (0, label_id)
                vertex_sizes[label] = 1

        # NOTE: The listener has to be set up last, as otherwise it will
        # receive the SCP messages sent above!
        if self._receiver_listener is None:
            packet_listener = ConnectionListener(self._receiver_connection)
            self._receiver_listener = packet_listener
            packet_listener.add_callback(self._receive_packet_callback)
            packet_listener.start()
    def __init__(self, socket_addresses, wait_for_read_confirmation):
        """
        :param list(~spinn_utilities.socket_address.SocketAddress) \
                socket_addresses: Where to notify.
        :param bool wait_for_read_confirmation:
            Whether to wait for the other side to acknowledge
        """
        self.__socket_addresses = socket_addresses

        # Whether to hold the simulation start until the database has been
        # confirmed as read
        self.__wait_for_read_confirmation = wait_for_read_confirmation
        self.__wait_pool = ThreadPoolExecutor(max_workers=1)
        self.__wait_futures = list()
        self.__sent_visualisation_confirmation = False

        # One connection is opened per socket address, in the order given
        self.__database_message_connections = list()
        self.__database_message_connections.extend(
            EIEIOConnection(
                local_port=address.listen_port,
                remote_host=address.notify_host_name,
                remote_port=address.notify_port_no)
            for address in socket_addresses)
class LiveEventConnection(DatabaseConnection):
    """ A connection for receiving and sending live events from and to\
        SpiNNaker
    """
    __slots__ = [
        "_atom_id_to_key",
        "_init_callbacks",
        "_key_to_atom_id_and_label",
        "_live_event_callbacks",
        "_live_packet_gather_label",
        "_machine_vertices",
        "_pause_stop_callbacks",
        "_receive_labels",
        "_receiver_connection",
        "_receiver_listener",
        "_send_address_details",
        "_send_labels",
        "_sender_connection",
        "_start_resume_callbacks"]

    def __init__(self, live_packet_gather_label, receive_labels=None,
                 send_labels=None, local_host=None, local_port=NOTIFY_PORT,
                 machine_vertices=False):
        """
        :param live_packet_gather_label: The label of the LivePacketGather\
            vertex to which received events are being sent
        :param receive_labels: \
            Labels of vertices from which live events will be received.
        :type receive_labels: iterable of str
        :param send_labels: \
            Labels of vertices to which live events will be sent
        :type send_labels: iterable of str
        :param local_host: Optional specification of the local hostname or\
            IP address of the interface to listen on
        :type local_host: str
        :param local_port: Optional specification of the local port to listen\
            on. Must match the port that the toolchain will send the\
            notification on (19999 by default)
        :type local_port: int
        :param machine_vertices: If True, the machine-level database queries\
            are used for each label and each labelled vertex is treated as\
            having a single atom with ID 0
        :type machine_vertices: bool
        """
        # pylint: disable=too-many-arguments
        super(LiveEventConnection, self).__init__(
            self._start_resume_callback, self._stop_pause_callback,
            local_host=local_host, local_port=local_port)

        self.add_database_callback(self._read_database_callback)

        self._live_packet_gather_label = live_packet_gather_label
        # Keep private lists of the labels: they are iterated more than once
        # and looked up by position, which an arbitrary iterable (such as a
        # generator or a set) would not support
        self._receive_labels = (
            list(receive_labels) if receive_labels is not None else None)
        self._send_labels = (
            list(send_labels) if send_labels is not None else None)
        self._machine_vertices = machine_vertices
        self._sender_connection = None
        self._send_address_details = dict()
        self._atom_id_to_key = dict()
        self._key_to_atom_id_and_label = dict()
        # Receive callbacks are held per position in the receive labels
        self._live_event_callbacks = list()
        self._start_resume_callbacks = dict()
        self._pause_stop_callbacks = dict()
        self._init_callbacks = dict()
        if self._receive_labels is not None:
            for label in self._receive_labels:
                self._live_event_callbacks.append(list())
                self._start_resume_callbacks[label] = list()
                self._pause_stop_callbacks[label] = list()
                self._init_callbacks[label] = list()
        if self._send_labels is not None:
            for label in self._send_labels:
                self._start_resume_callbacks[label] = list()
                self._pause_stop_callbacks[label] = list()
                self._init_callbacks[label] = list()
        self._receiver_listener = None
        self._receiver_connection = None

    def add_init_callback(self, label, init_callback):
        """ Add a callback to be called to initialise a vertex

        :param label: The label of the vertex to be notified about. Must be\
            one of the vertices listed in the constructor
        :type label: str
        :param init_callback: A function to be called to initialise the\
            vertex. This should take as parameters the label of the vertex,\
            the number of neurons in the population, the run time of the\
            simulation in milliseconds, and the simulation timestep in\
            milliseconds
        :type init_callback: function(str, int, float, float) -> None
        """
        self._init_callbacks[label].append(init_callback)

    def add_receive_callback(self, label, live_event_callback):
        """ Add a callback for the reception of live events from a vertex

        :param label: The label of the vertex to be notified about. Must be\
            one of the vertices listed in the constructor
        :type label: str
        :param live_event_callback: A function to be called when events are\
            received. This should take as parameters the label of the vertex,\
            the simulation timestep when the event occurred, and an\
            array-like of atom IDs. For packets that carry no timestamp the\
            function is instead called once per event with the label and\
            the atom ID, followed by the payload when the event has one.
        :type live_event_callback: function(str, int, [int]) -> None
        """
        label_id = self._receive_labels.index(label)
        self._live_event_callbacks[label_id].append(live_event_callback)

    def add_start_callback(self, label, start_callback):
        """ Add a callback for the start of the simulation

        Deprecated: logs a warning and forwards to
        :py:meth:`add_start_resume_callback`.

        :param start_callback: A function to be called when the start\
            message has been received. This function should take the label of\
            the referenced vertex, and an instance of this class, which can\
            be used to send events
        :type start_callback: function(str, \
            :py:class:`SpynnakerLiveEventConnection`) -> None
        :param label: the label of the function to be sent
        :type label: str
        """
        logger.warning(
            "the method 'add_start_callback(label, start_callback)' is in "
            "deprecation, and will be replaced with the method "
            "'add_start_resume_callback(label, start_resume_callback)' in a "
            "future release.")
        self.add_start_resume_callback(label, start_callback)

    def add_start_resume_callback(self, label, start_resume_callback):
        """ Add a callback for the start and resume state of the simulation

        :param label: The label of the vertex to be notified about. Must be\
            one of the vertices listed in the constructor
        :type label: str
        :param start_resume_callback: A function to be called when the\
            start or resume message has been received. It is called in its\
            own thread with the label of the vertex and this connection.
        :type start_resume_callback: function(str, \
            :py:class:`SpynnakerLiveEventConnection`) -> None
        :rtype: None
        """
        self._start_resume_callbacks[label].append(start_resume_callback)

    def add_pause_stop_callback(self, label, pause_stop_callback):
        """ Add a callback for the pause and stop state of the simulation

        :param label: the label of the function to be sent
        :type label: str
        :param pause_stop_callback: A function to be called when the pause\
            or stop message has been received. This function should take the\
            label of the referenced  vertex, and an instance of this class,\
            which can be used to send events.
        :type pause_stop_callback: function(str, \
            :py:class:`SpynnakerLiveEventConnection`) -> None
        :rtype: None
        """
        self._pause_stop_callbacks[label].append(pause_stop_callback)

    def _read_database_callback(self, db_reader):
        """ Set up sending and receiving from the database, then run the\
            init callbacks of every vertex that was set up

        :param db_reader: The database reader to query
        """
        # Drop any connections left over from a previous run first
        self._handle_possible_rerun_state()

        vertex_sizes = OrderedDict()
        run_time_ms = db_reader.get_configuration_parameter_value(
            "runtime")
        # presumably the stored time step is in microseconds - TODO confirm
        machine_timestep_ms = db_reader.get_configuration_parameter_value(
            "machine_time_step") / 1000.0

        if self._send_labels is not None:
            self._init_sender(db_reader, vertex_sizes)

        if self._receive_labels is not None:
            self._init_receivers(db_reader, vertex_sizes)

        for label, vertex_size in iteritems(vertex_sizes):
            for init_callback in self._init_callbacks[label]:
                init_callback(
                    label, vertex_size, run_time_ms, machine_timestep_ms)

    def _init_sender(self, db, vertex_sizes):
        """ Create the sender connection if needed and cache, for each send\
            label, its address details and atom-ID-to-key mapping

        :param db: The database reader to query
        :param vertex_sizes: Mapping of label to vertex size; updated in place
        """
        if self._sender_connection is None:
            self._sender_connection = EIEIOConnection()
        for label in self._send_labels:
            self._send_address_details[label] = self.__get_live_input_details(
                db, label)
            if self._machine_vertices:
                # A machine vertex is treated as a single atom with ID 0
                key, _ = db.get_machine_live_input_key(label)
                self._atom_id_to_key[label] = {0: key}
                vertex_sizes[label] = 1
            else:
                self._atom_id_to_key[label] = db.get_atom_id_to_key_mapping(
                    label)
                vertex_sizes[label] = len(self._atom_id_to_key[label])

    def _init_receivers(self, db, vertex_sizes):
        """ Create the receiver connection if needed, update the tags of the\
            boards, record the key mappings of each receive label and\
            finally start the packet listener

        :param db: The database reader to query
        :param vertex_sizes: Mapping of label to vertex size; updated in place
        """
        # Set up a single connection for receive
        if self._receiver_connection is None:
            self._receiver_connection = EIEIOConnection()
        receivers = set()
        for label_id, label in enumerate(self._receive_labels):
            _, port, board_address, tag = self.__get_live_output_details(
                db, label)

            # Update the tag if not already done
            if (board_address, port, tag) not in receivers:
                self.__update_tag(
                    self._receiver_connection, board_address, tag)
                receivers.add((board_address, port, tag))
                send_port_trigger_message(
                    self._receiver_connection, board_address)

            logger.info(
                "Listening for traffic from {} on {}:{}",
                label, self._receiver_connection.local_ip_address,
                self._receiver_connection.local_port)

            if self._machine_vertices:
                # A machine vertex is treated as a single atom with ID 0
                key, _ = db.get_machine_live_output_key(
                    label, self._live_packet_gather_label)
                self._key_to_atom_id_and_label[key] = (0, label_id)
                vertex_sizes[label] = 1
            else:
                key_to_atom_id = db.get_key_to_atom_id_mapping(label)
                for key, atom_id in iteritems(key_to_atom_id):
                    self._key_to_atom_id_and_label[key] = (atom_id, label_id)
                vertex_sizes[label] = len(key_to_atom_id)

        # Last of all, set up the listener for packets
        # NOTE: Has to be done last as otherwise will receive SCP messages
        # sent above!
        if self._receiver_listener is None:
            self._receiver_listener = ConnectionListener(
                self._receiver_connection)
            self._receiver_listener.add_callback(self._receive_packet_callback)
            self._receiver_listener.start()

    def __get_live_input_details(self, db_reader, send_label):
        """ Get the details needed to send to a label, using the machine\
            level query when machine vertices are in use

        :param db_reader: The database reader to query
        :param send_label: The label of the vertex to send to
        """
        if self._machine_vertices:
            return db_reader.get_machine_live_input_details(send_label)
        return db_reader.get_live_input_details(send_label)

    def __get_live_output_details(self, db_reader, receive_label):
        """ Get the details needed to receive from a label, using the\
            machine level query when machine vertices are in use

        :param db_reader: The database reader to query
        :param receive_label: The label of the vertex to receive from
        :return: The host, port, board address and tag
        :rtype: tuple
        :raises Exception: If the tag does not strip the SDP headers
        """
        if self._machine_vertices:
            host, port, strip_sdp, board_address, tag = \
                db_reader.get_machine_live_output_details(
                    receive_label, self._live_packet_gather_label)
        else:
            host, port, strip_sdp, board_address, tag = \
                db_reader.get_live_output_details(
                    receive_label, self._live_packet_gather_label)
        if not strip_sdp:
            raise Exception("Currently, only IP tags which strip the SDP "
                            "headers are supported")
        return host, port, board_address, tag

    def __update_tag(self, connection, board_address, tag):
        """ Update an IP Tag with the sender's address and port; this\
            avoids issues with NAT firewalls

        The request is sent again after a timeout, up to three more times;
        when the last attempt also times out the timeout is re-raised.

        :param connection: The connection to send the request from
        :param board_address: The address of the board holding the tag
        :param tag: The ID of the tag to update
        :raises SpinnmanTimeoutException: If every attempt times out
        """
        logger.info("Updating tag for {}".format(board_address))
        request = IPTagSet(
            0, 0, [0, 0, 0, 0], 0, tag, strip=True, use_sender=True)
        request.sdp_header.flags = SDPFlag.REPLY_EXPECTED_NO_P2P
        update_sdp_header_for_udp_send(request.sdp_header, 0, 0)
        # NOTE(review): _TWO_SKIP looks like two bytes of padding ahead of
        # the request (the response is read from offset 2) - confirm
        data = _TWO_SKIP.pack() + request.bytestring
        sent = False
        tries_to_go = 3
        while not sent:
            try:
                connection.send_to(data, (board_address, SCP_SCAMP_PORT))
                response_data = connection.receive(1.0)
                request.get_scp_response().read_bytestring(response_data, 2)
                sent = True
            except SpinnmanTimeoutException:
                if not tries_to_go:
                    logger.info("No more tries - Error!")
                    reraise(*sys.exc_info())

                logger.info("Timeout, retrying")
                tries_to_go -= 1
        logger.info("Done updating tag for {}".format(board_address))

    def _handle_possible_rerun_state(self):
        """ Close and forget the sender connection, receiver connection and\
            receiver listener, where they exist
        """
        # reset from possible previous calls
        if self._sender_connection is not None:
            self._sender_connection.close()
            self._sender_connection = None
        if self._receiver_connection is not None:
            self._receiver_connection.close()
            self._receiver_connection = None
        if self._receiver_listener is not None:
            self._receiver_listener.close()
            self._receiver_listener = None

    def __launch_thread(self, kind, label, callback):
        """ Run a callback in a new thread, passing it the label and this\
            connection

        :param kind: Description of the callback, used in the thread name
        :param label: The label to pass to the callback
        :param callback: The function to call
        """
        thread = Thread(
            target=callback, args=(label, self),
            name="{} callback thread for live_event_connection {}:{}".format(
                kind, self._local_port, self._local_ip_address))
        thread.start()

    def _start_resume_callback(self):
        """ Launch every registered start/resume callback in its own thread
        """
        for label, callbacks in iteritems(self._start_resume_callbacks):
            for callback in callbacks:
                self.__launch_thread("start_resume", label, callback)

    def _stop_pause_callback(self):
        """ Launch every registered pause/stop callback in its own thread
        """
        for label, callbacks in iteritems(self._pause_stop_callbacks):
            for callback in callbacks:
                self.__launch_thread("pause_stop", label, callback)

    def _receive_packet_callback(self, packet):
        """ Pass a received packet to the receive callbacks; any error in\
            doing so is logged as a warning and not propagated

        :param packet: The received packet
        """
        try:
            if packet.eieio_header.is_time:
                self.__handle_time_packet(packet)
            else:
                self.__handle_no_time_packet(packet)
        except Exception:
            logger.warning("problem handling received packet", exc_info=True)

    def __handle_time_packet(self, packet):
        """ Handle a packet whose element payloads are timestamps, calling\
            each receive callback once per timestamp and label with all of\
            the atom IDs seen for that pair

        :param packet: The received packet
        """
        # Group the atom IDs by timestamp and then by label index, keeping
        # the timestamps in the order in which they were first seen
        key_times_labels = OrderedDict()
        while packet.is_next_element:
            element = packet.next_element
            time = element.payload
            key = element.key
            # Keys that are not from a known vertex are ignored
            if key in self._key_to_atom_id_and_label:
                atom_id, label_id = self._key_to_atom_id_and_label[key]
                labels_at_time = key_times_labels.setdefault(time, dict())
                labels_at_time.setdefault(label_id, list()).append(atom_id)

        for time, atoms_by_label in iteritems(key_times_labels):
            for label_id, atom_ids in iteritems(atoms_by_label):
                label = self._receive_labels[label_id]
                for callback in self._live_event_callbacks[label_id]:
                    callback(label, time, atom_ids)

    def __handle_no_time_packet(self, packet):
        """ Handle a packet without timestamps, calling each receive\
            callback once per element with the label and atom ID, plus the\
            payload when the element has one

        :param packet: The received packet
        """
        while packet.is_next_element:
            element = packet.next_element
            key = element.key
            # Keys that are not from a known vertex are ignored
            if key in self._key_to_atom_id_and_label:
                atom_id, label_id = self._key_to_atom_id_and_label[key]
                for callback in self._live_event_callbacks[label_id]:
                    if isinstance(element, KeyPayloadDataElement):
                        callback(self._receive_labels[label_id], atom_id,
                                 element.payload)
                    else:
                        callback(self._receive_labels[label_id], atom_id)

    def send_event(self, label, atom_id, send_full_keys=False):
        """ Send an event from a single atom

        :param label: \
            The label of the vertex from which the event will originate
        :type label: str
        :param atom_id: The ID of the atom sending the event
        :type atom_id: int
        :param send_full_keys: Determines whether to send full 32-bit keys,\
            getting the key for each atom from the database, or whether to\
            send 16-bit atom IDs directly
        :type send_full_keys: bool
        """
        self.send_events(label, [atom_id], send_full_keys)

    def send_events(self, label, atom_ids, send_full_keys=False):
        """ Send a number of events

        The events are split over as many messages as are needed. The\
        sender connection and address details are only available once the\
        database has been read.

        :param label: \
            The label of the vertex from which the events will originate
        :type label: str
        :param atom_ids: array-like of atom IDs sending events
        :type atom_ids: [int]
        :param send_full_keys: Determines whether to send full 32-bit keys,\
            getting the key for each atom from the database, or whether to\
            send 16-bit atom IDs directly
        :type send_full_keys: bool
        """
        max_keys = _MAX_HALF_KEYS_PER_PACKET
        msg_type = EIEIOType.KEY_16_BIT
        if send_full_keys:
            max_keys = _MAX_FULL_KEYS_PER_PACKET
            msg_type = EIEIOType.KEY_32_BIT

        pos = 0
        while pos < len(atom_ids):
            # Fill one message with at most max_keys events
            message = EIEIODataMessage.create(msg_type)
            events_in_packet = 0
            while pos < len(atom_ids) and events_in_packet < max_keys:
                key = atom_ids[pos]
                if send_full_keys:
                    key = self._atom_id_to_key[label][key]
                message.add_key(key)
                pos += 1
                events_in_packet += 1
            ip_address, port = self._send_address_details[label]
            self._sender_connection.send_eieio_message_to(
                message, ip_address, port)

    def close(self):
        """ Close this connection, including any sender connection,\
            receiver connection and receiver listener that were opened\
            when the database was read
        """
        # Without this the sockets and the listener would be left open
        self._handle_possible_rerun_state()
        DatabaseConnection.close(self)
示例#13
0
class LiveEventConnection(DatabaseConnection):
    """ A connection for receiving and sending live events from and to\
        SpiNNaker
    """
    __slots__ = [
        "_atom_id_to_key", "__error_keys", "__init_callbacks",
        "__key_to_atom_id_and_label", "__live_event_callbacks",
        "__live_packet_gather_label", "__machine_vertices",
        "__pause_stop_callbacks", "__receive_labels", "__receiver_connection",
        "__receiver_listener", "__send_address_details", "__send_labels",
        "__sender_connection", "__start_resume_callbacks"
    ]

    def __init__(self,
                 live_packet_gather_label,
                 receive_labels=None,
                 send_labels=None,
                 local_host=None,
                 local_port=NOTIFY_PORT,
                 machine_vertices=False):
        """
        :param str live_packet_gather_label:
            The label of the :py:class:`LivePacketGather` vertex to which
            received events are being sent
        :param iterable(str) receive_labels:
            Labels of vertices from which live events will be received.
        :param iterable(str) send_labels:
            Labels of vertices to which live events will be sent
        :param str local_host:
            Optional specification of the local hostname or IP address of the
            interface to listen on
        :param int local_port:
            Optional specification of the local port to listen on. Must match
            the port that the toolchain will send the notification on (19999
            by default)
        :param bool machine_vertices:
            True if the labels are to be looked up in the machine graph
            (one key per label, vertex size 1) rather than in the
            application graph
        """
        # pylint: disable=too-many-arguments
        super(LiveEventConnection, self).__init__(self.__do_start_resume,
                                                  self.__do_stop_pause,
                                                  local_host=local_host,
                                                  local_port=local_port)

        self.add_database_callback(self.__read_database_callback)

        self.__live_packet_gather_label = live_packet_gather_label
        # Private copies: the arguments may be one-shot iterables (such as
        # generators), so they must only ever be consumed once, here.
        self.__receive_labels = (list(receive_labels)
                                 if receive_labels is not None else None)
        self.__send_labels = (list(send_labels)
                              if send_labels is not None else None)
        self.__machine_vertices = machine_vertices
        self.__sender_connection = None
        self.__send_address_details = dict()
        # Also used by SpynnakerPoissonControlConnection
        self._atom_id_to_key = dict()
        self.__key_to_atom_id_and_label = dict()
        # One list of callbacks per receive label, in the same order as
        # (and indexed by position in) the list of receive labels
        self.__live_event_callbacks = list()
        self.__start_resume_callbacks = dict()
        self.__pause_stop_callbacks = dict()
        self.__init_callbacks = dict()
        # Iterate the stored copies rather than the original arguments, so
        # that the callback tables are filled even when the caller passed
        # an iterable that can only be traversed once
        if self.__receive_labels is not None:
            for label in self.__receive_labels:
                self.__live_event_callbacks.append(list())
                self.__start_resume_callbacks[label] = list()
                self.__pause_stop_callbacks[label] = list()
                self.__init_callbacks[label] = list()
        if self.__send_labels is not None:
            for label in self.__send_labels:
                self.__start_resume_callbacks[label] = list()
                self.__pause_stop_callbacks[label] = list()
                self.__init_callbacks[label] = list()
        self.__receiver_listener = None
        self.__receiver_connection = None
        # Unknown keys already warned about, so each is reported only once
        self.__error_keys = set()

    def add_send_label(self, label):
        """ Register a further vertex to which live events can be sent

        Does nothing for a label that is already known. The callback lists
        for the label are created if they do not exist yet.

        :param str label: The label of the vertex to add
        """
        if self.__send_labels is None:
            self.__send_labels = [label]
        elif label not in self.__send_labels:
            self.__send_labels.append(label)

        # The three callback tables are always populated together, so
        # checking one of them is enough
        if label in self.__start_resume_callbacks:
            return
        self.__start_resume_callbacks[label] = list()
        self.__pause_stop_callbacks[label] = list()
        self.__init_callbacks[label] = list()

    def add_receive_label(self, label):
        """ Register a further vertex from which live events can be received

        Does nothing for a label that is already known. The callback lists
        for the label are created if they do not exist yet.

        :param str label: The label of the vertex to add
        """
        labels = self.__receive_labels
        if labels is None:
            labels = self.__receive_labels = []
        if label not in labels:
            labels.append(label)
            # Keep the per-label event callback lists aligned with the
            # positions in the list of receive labels
            self.__live_event_callbacks.append([])

        if label in self.__start_resume_callbacks:
            return
        self.__start_resume_callbacks[label] = list()
        self.__pause_stop_callbacks[label] = list()
        self.__init_callbacks[label] = list()

    def add_init_callback(self, label, init_callback):
        """ Register a function to be called to initialise a vertex

        The function is called after the database has been read, with the
        label of the vertex, the size recorded for the vertex, the run time
        of the simulation in milliseconds and the simulation timestep in
        milliseconds.

        :param str label:
            The label of the vertex to be notified about. Must already be
            registered with this connection, otherwise a KeyError is raised
        :param init_callback: The function to call
        :type init_callback: callable(str, int, float, float) -> None
        """
        callbacks_for_label = self.__init_callbacks[label]
        callbacks_for_label.append(init_callback)

    def add_receive_callback(self,
                             label,
                             live_event_callback,
                             translate_key=True):
        """ Register a function to be called when live events arrive from a\
            vertex

        :param str label: The label of the vertex to be notified about.
            Must be one of the receive labels registered with this
            connection, otherwise a ValueError is raised
        :param live_event_callback: A function to be called when events are\
            received. For events carrying a time, it is given the label of\
            the vertex, the simulation timestep when the event occurred,\
            and an array-like of atom IDs (or keys).
        :type live_event_callback: callable(str, int, list(int)) -> None
        :param bool translate_key:
            True if the callback is to be given atom IDs, False if it is
            to be given the raw keys
        """
        # Callbacks are stored by position of the label, not by the label
        label_id = self.__receive_labels.index(label)
        message = "Receive callback {} registered to label {}".format(
            live_event_callback, label)
        logger.info(message)
        entry = (live_event_callback, translate_key)
        self.__live_event_callbacks[label_id].append(entry)

    def add_start_callback(self, label, start_callback):
        """ Add a callback for the start of the simulation

        Deprecated: logs a warning and then registers the callback through
        :py:meth:`add_start_resume_callback`.

        :param start_callback: A function to be called when the start\
            message has been received. This function should take the label of\
            the referenced vertex, and an instance of this class, which can\
            be used to send events
        :type start_callback: callable(str, LiveEventConnection) -> None
        :param str label: the label of the vertex the callback is for
        """
        logger.warning(
            "the method 'add_start_callback(label, start_callback)' is in "
            "deprecation, and will be replaced with the method "
            "'add_start_resume_callback(label, start_resume_callback)' in a "
            "future release.")
        self.add_start_resume_callback(label, start_callback)

    def add_start_resume_callback(self, label, start_resume_callback):
        """ Register a function to be called when the simulation starts or\
            resumes

        :param str label:
            The label of the vertex the callback is for. Must already be
            registered with this connection, otherwise a KeyError is raised
        :param start_resume_callback: A function to be called when the start\
            or resume message has been received. It is run in its own\
            thread and is given the label of the vertex and this\
            connection, which can be used to send events.
        :type start_resume_callback: callable(str, LiveEventConnection) -> None
        :rtype: None
        """
        callbacks_for_label = self.__start_resume_callbacks[label]
        callbacks_for_label.append(start_resume_callback)

    def add_pause_stop_callback(self, label, pause_stop_callback):
        """ Register a function to be called when the simulation pauses or\
            stops

        :param str label:
            The label of the vertex the callback is for. Must already be
            registered with this connection, otherwise a KeyError is raised
        :param pause_stop_callback: A function to be called when the pause\
            or stop message has been received. It is run in its own thread\
            and is given the label of the vertex and this connection,\
            which can be used to send events.
        :type pause_stop_callback: callable(str, LiveEventConnection) -> None
        :rtype: None
        """
        callbacks_for_label = self.__pause_stop_callbacks[label]
        callbacks_for_label.append(pause_stop_callback)

    def __read_database_callback(self, db_reader):
        """ Set up sending and receiving from the contents of the database,\
            then run the registered initialisation callbacks

        :param db_reader: The reader used to query the database
        """
        # Drop any connections left over from a previous run
        self.__handle_possible_rerun_state()

        read_config = db_reader.get_configuration_parameter_value
        run_time_ms = read_config("runtime")
        # NOTE(review): the division suggests the stored value is in
        # microseconds - confirm against the database writer
        machine_timestep_ms = read_config("machine_time_step") / 1000.0

        # Filled in by the two initialisers; ordered so that the callbacks
        # run in the order in which the labels were initialised
        sizes_by_label = OrderedDict()
        if self.__send_labels is not None:
            self.__init_sender(db_reader, sizes_by_label)
        if self.__receive_labels is not None:
            self.__init_receivers(db_reader, sizes_by_label)

        for label, n_atoms in iteritems(sizes_by_label):
            for init_callback in self.__init_callbacks[label]:
                init_callback(
                    label, n_atoms, run_time_ms, machine_timestep_ms)

    def __init_sender(self, db, vertex_sizes):
        """ Create the sending connection if needed and read, for each send\
            label, where to send to and which keys to use

        :param db: The reader used to query the database
        :param dict vertex_sizes:
            Updated in place with the size of the vertex of each send label
        """
        if self.__sender_connection is None:
            self.__sender_connection = UDPConnection()

        for label in self.__send_labels:
            self.__send_address_details[label] = \
                self.__get_live_input_details(db, label)

            if not self.__machine_vertices:
                atom_keys = db.get_atom_id_to_key_mapping(label)
                self._atom_id_to_key[label] = atom_keys
                vertex_sizes[label] = len(atom_keys)
                continue

            # A machine vertex has a single key, stored against atom 0
            key, _ = db.get_machine_live_input_key(label)
            self._atom_id_to_key[label] = {0: key}
            vertex_sizes[label] = 1

    def __init_receivers(self, db, vertex_sizes):
        """ Prepare to receive live events for every receive label

        Creates the shared receiving connection if needed, updates the IP
        tag of each distinct (board address, port, tag) combination, reads
        the key mappings of each label and finally starts the listener.

        :param db: The reader used to query the database
        :param dict vertex_sizes:
            Updated in place with the size of the vertex of each receive
            label
        """
        # Set up a single connection for receive
        if self.__receiver_connection is None:
            self.__receiver_connection = EIEIOConnection()
        # (board_address, port, tag) combinations already handled in this
        # call, so that each tag is only updated once
        receivers = set()
        for label_id, label in enumerate(self.__receive_labels):
            _, port, board_address, tag = self.__get_live_output_details(
                db, label)

            # Update the tag if not already done
            if (board_address, port, tag) not in receivers:
                self.__update_tag(self.__receiver_connection, board_address,
                                  tag)
                receivers.add((board_address, port, tag))
                send_port_trigger_message(self.__receiver_connection,
                                          board_address)

            logger.info("Listening for traffic from {} on {}:{}", label,
                        self.__receiver_connection.local_ip_address,
                        self.__receiver_connection.local_port)

            if self.__machine_vertices:
                # A machine vertex has a single key, reported as atom 0
                key, _ = db.get_machine_live_output_key(
                    label, self.__live_packet_gather_label)
                self.__key_to_atom_id_and_label[key] = (0, label_id)
                vertex_sizes[label] = 1
            else:
                key_to_atom_id = db.get_key_to_atom_id_mapping(label)
                for key, atom_id in iteritems(key_to_atom_id):
                    self.__key_to_atom_id_and_label[key] = (atom_id, label_id)
                vertex_sizes[label] = len(key_to_atom_id)

        # Last of all, set up the listener for packets
        # NOTE: Has to be done last as otherwise will receive SCP messages
        # sent above!
        if self.__receiver_listener is None:
            self.__receiver_listener = ConnectionListener(
                self.__receiver_connection)
            self.__receiver_listener.add_callback(self.__do_receive_packet)
            self.__receiver_listener.start()

    def __get_live_input_details(self, db_reader, send_label):
        """ Find the core that live input for a label is sent to, and the\
            IP address used to reach its chip

        :param db_reader: The reader used to query the database
        :param str send_label: The label of the vertex to send to
        :return: x, y, p, ip_address
        :rtype: tuple
        """
        if self.__machine_vertices:
            placement = db_reader.get_placement(send_label)
        else:
            # Only the first placement of the vertex is used
            placement = db_reader.get_placements(send_label)[0]
        x, y, p = placement
        return x, y, p, db_reader.get_ip_address(x, y)

    def __get_live_output_details(self, db_reader, receive_label):
        """ Read the details of the IP tag through which live output of a\
            label is received

        :param db_reader: The reader used to query the database
        :param str receive_label: The label of the vertex to receive from
        :return: host, port, board_address, tag
        :rtype: tuple
        :raises Exception:
            If no tag is found for the label, or if the tag found does
            not strip the SDP headers
        """
        gatherer_label = self.__live_packet_gather_label
        if self.__machine_vertices:
            details = db_reader.get_machine_live_output_details(
                receive_label, gatherer_label)
            not_found = "no live output tag found for {} in machine graph"
        else:
            details = db_reader.get_live_output_details(
                receive_label, gatherer_label)
            not_found = "no live output tag found for {} in app graph"

        host, port, strip_sdp, board_address, tag = details
        if host is None:
            raise Exception(not_found.format(receive_label))
        if not strip_sdp:
            raise Exception("Currently, only IP tags which strip the SDP "
                            "headers are supported")
        return host, port, board_address, tag

    def __update_tag(self, connection, board_address, tag):
        """ Ask a board to point an IP tag at the address and port that the\
            request is sent from, retrying on timeout

        The request is tried up to four times in total (the first attempt
        plus three retries); the timeout of the last attempt is re-raised.

        :param connection: The connection to send the request from
        :param board_address: The address of the board holding the tag
        :param tag: The ID of the tag to update
        :raises SpinnmanTimeoutException:
            If no response is received on any of the attempts
        """
        # Update an IP Tag with the sender's address and port
        # This avoids issues with NAT firewalls
        logger.debug("Updating tag for {}".format(board_address))
        request = IPTagSet(
            0, 0, [0, 0, 0, 0], 0, tag, strip=True, use_sender=True)
        request.sdp_header.flags = SDPFlag.REPLY_EXPECTED_NO_P2P
        update_sdp_header_for_udp_send(request.sdp_header, 0, 0)
        data = _TWO_SKIP.pack() + request.bytestring

        retries_left = 3
        while True:
            try:
                connection.send_to(data, (board_address, SCP_SCAMP_PORT))
                response_data = connection.receive(1.0)
                request.get_scp_response().read_bytestring(
                    response_data, _TWO_SKIP.size)
                break
            except SpinnmanTimeoutException:
                if retries_left == 0:
                    logger.info("No more tries - Error!")
                    reraise(*sys.exc_info())

                logger.info("Timeout, retrying")
                retries_left -= 1
        logger.debug("Done updating tag for {}".format(board_address))

    def __handle_possible_rerun_state(self):
        """ Close and forget the sender, the listener and the receiving\
            connection left over from any previous run
        """
        sender = self.__sender_connection
        if sender is not None:
            sender.close()
            self.__sender_connection = None

        # The listener is closed before the connection it listens on
        listener = self.__receiver_listener
        if listener is not None:
            listener.close()
            self.__receiver_listener = None

        receiver = self.__receiver_connection
        if receiver is not None:
            receiver.close()
            self.__receiver_connection = None

    def __launch_thread(self, kind, label, callback):
        """ Run a callback in a new thread, giving it the label and this\
            connection

        :param str kind: Description of the callback, used in the thread name
        :param str label: The label to pass to the callback
        :param callback: The function to run
        """
        thread_name = (
            "{} callback thread for live_event_connection {}:{}".format(
                kind, self._local_port, self._local_ip_address))
        Thread(target=callback, args=(label, self), name=thread_name).start()

    def __do_start_resume(self):
        """ Launch every registered start/resume callback, each in its own\
            thread
        """
        registered = self.__start_resume_callbacks
        for label, label_callbacks in iteritems(registered):
            for start_resume_callback in label_callbacks:
                self.__launch_thread(
                    "start_resume", label, start_resume_callback)

    def __do_stop_pause(self):
        """ Launch every registered pause/stop callback, each in its own\
            thread
        """
        registered = self.__pause_stop_callbacks
        for label, label_callbacks in iteritems(registered):
            for pause_stop_callback in label_callbacks:
                self.__launch_thread(
                    "pause_stop", label, pause_stop_callback)

    def __do_receive_packet(self, packet):
        """ Handle a received packet, choosing the handler by whether the\
            packet header says that it carries a time

        Any exception raised while handling the packet is logged as a
        warning and not propagated.

        :param packet: The packet received by the listener
        """
        # pylint: disable=broad-except
        logger.debug("Received packet")
        try:
            handler = (
                self.__handle_time_packet if packet.eieio_header.is_time
                else self.__handle_no_time_packet)
            handler(packet)
        except Exception:
            logger.warning("problem handling received packet", exc_info=True)

    def __handle_time_packet(self, packet):
        """ Pass the events of a packet whose payloads are times on to the\
            registered callbacks

        The events are grouped by time and then by label, so that each
        callback is called once per (time, label) pair with a list of all
        the atom IDs (or keys, if so registered) seen for that pair.
        Unknown keys are reported and otherwise ignored.

        :param packet: The packet to read the events from
        """
        known_keys = self.__key_to_atom_id_and_label
        # time -> label_id -> list; the outer level keeps arrival order
        keys_by_time = OrderedDict()
        atoms_by_time = OrderedDict()
        while packet.is_next_element:
            element = packet.next_element
            timestamp = element.payload
            key = element.key
            if key not in known_keys:
                self.__handle_unknown_key(key)
                continue
            atom_id, label_id = known_keys[key]
            keys_by_time.setdefault(timestamp, dict()).setdefault(
                label_id, list()).append(key)
            atoms_by_time.setdefault(timestamp, dict()).setdefault(
                label_id, list()).append(atom_id)

        for timestamp, keys_by_label in iteritems(keys_by_time):
            for label_id, keys in iteritems(keys_by_label):
                label = self.__receive_labels[label_id]
                atom_ids = atoms_by_time[timestamp][label_id]
                for c_back, use_atom in self.__live_event_callbacks[label_id]:
                    c_back(label, timestamp, atom_ids if use_atom else keys)

    def __handle_no_time_packet(self, packet):
        """ Pass the events of a packet that carries no times on to the\
            registered callbacks, one call per event

        Each callback is given the label and the atom ID (or key, if so
        registered) of the event, followed by the payload when the element
        has one. Unknown keys are reported and otherwise ignored.

        :param packet: The packet to read the events from
        """
        known_keys = self.__key_to_atom_id_and_label
        while packet.is_next_element:
            element = packet.next_element
            key = element.key
            if key not in known_keys:
                self.__handle_unknown_key(key)
                continue

            atom_id, label_id = known_keys[key]
            label = self.__receive_labels[label_id]
            for c_back, use_atom in self.__live_event_callbacks[label_id]:
                identifier = atom_id if use_atom else key
                if isinstance(element, KeyPayloadDataElement):
                    c_back(label, identifier, element.payload)
                else:
                    c_back(label, identifier)

    def __handle_unknown_key(self, key):
        """ Warn about a received key that matches no receive label; each\
            such key is only reported the first time it is seen

        :param key: The unexpected key
        """
        if key in self.__error_keys:
            return
        self.__error_keys.add(key)
        logger.warning("Received unexpected key {}".format(key))

    def send_event(self, label, atom_id, send_full_keys=False):
        """ Send an event from a single atom

        Convenience wrapper that calls :py:meth:`send_events` with a list
        holding just the one atom ID.

        :param str label:
            The label of the vertex from which the event will originate
        :param int atom_id: The ID of the atom sending the event
        :param bool send_full_keys:
            Determines whether to send full 32-bit keys, getting the key for
            each atom from the database, or whether to send 16-bit atom IDs
            directly
        """
        self.send_events(label, [atom_id], send_full_keys)

    def send_events(self, label, atom_ids, send_full_keys=False):
        """ Send a batch of events, using as many packets as are needed

        Each packet is wrapped in an SDP message addressed to the core
        recorded for the label.

        :param str label:
            The label of the vertex from which the events will originate
        :param list(int) atom_ids:
            array-like of the IDs of the atoms sending events
        :param bool send_full_keys:
            True to send the full 32-bit key of each atom, as looked up in
            the atom-to-key mapping of the vertex; False to send the atom
            IDs themselves as 16-bit keys
        """
        if send_full_keys:
            msg_type = EIEIOType.KEY_32_BIT
            max_keys = _MAX_FULL_KEYS_PER_PACKET
        else:
            msg_type = EIEIOType.KEY_16_BIT
            max_keys = _MAX_HALF_KEYS_PER_PACKET

        # Looked up before anything is built, so an unknown label fails
        # even when there are no events to send
        x, y, p, ip_address = self.__send_address_details[label]

        n_atoms = len(atom_ids)
        index = 0
        while index < n_atoms:
            message = EIEIODataMessage.create(msg_type)

            # Fill this packet until it is full or the atoms run out
            packet_end = min(index + max_keys, n_atoms)
            while index < packet_end:
                atom_id = atom_ids[index]
                if send_full_keys:
                    message.add_key(self._atom_id_to_key[label][atom_id])
                else:
                    message.add_key(atom_id)
                index += 1

            sdp_data = self.__get_sdp_data(message, x, y, p)
            self.__sender_connection.send_to(
                sdp_data, (ip_address, SCP_SCAMP_PORT))

    def send_event_with_payload(self, label, atom_id, payload):
        """ Send an event with a payload from a single atom

        Convenience wrapper that calls :py:meth:`send_events_with_payloads`
        with a list holding just the one (atom ID, payload) pair.

        :param str label:
            The label of the vertex from which the event will originate
        :param int atom_id: The ID of the atom sending the event
        :param int payload: The payload to send
        """
        self.send_events_with_payloads(label, [(atom_id, payload)])

    def send_events_with_payloads(self, label, atom_ids_and_payloads):
        """ Send a batch of events with payloads, using as many packets as\
            are needed

        The full key of each atom is always looked up in the atom-to-key
        mapping of the vertex; each packet is wrapped in an SDP message
        addressed to the core recorded for the label.

        :param str label:
            The label of the vertex from which the events will originate
        :param list(tuple(int,int)) atom_ids_and_payloads:
            array-like of (atom ID, payload) pairs, one per event
        """
        msg_type = EIEIOType.KEY_PAYLOAD_32_BIT
        max_keys = _MAX_FULL_KEYS_PAYLOADS_PER_PACKET
        x, y, p, ip_address = self.__send_address_details[label]
        keys_of_label = self._atom_id_to_key[label] \
            if atom_ids_and_payloads else None

        n_events = len(atom_ids_and_payloads)
        index = 0
        while index < n_events:
            message = EIEIODataMessage.create(msg_type)

            # Fill this packet until it is full or the events run out
            packet_end = min(index + max_keys, n_events)
            while index < packet_end:
                atom_id, payload = atom_ids_and_payloads[index]
                message.add_key_and_payload(
                    self._atom_id_to_key[label][atom_id], payload)
                index += 1

            sdp_data = self.__get_sdp_data(message, x, y, p)
            self.__sender_connection.send_to(
                sdp_data, (ip_address, SCP_SCAMP_PORT))

    def send_eieio_message(self, message, label):
        """ Send an EIEIO message, one-way via the live input, to the\
            vertex with the given label.

        Nothing is sent if the details stored for the label are None; a
        label with no stored details at all raises a KeyError.

        :param ~spinnman.messages.eieio.AbstractEIEIOMessage message:
            The EIEIO message to send
        :param str label: The label of the receiver machine vertex
        """
        target = self.__send_address_details[label]
        if target is not None:
            x, y, p, ip_address = target
            sdp_data = self.__get_sdp_data(message, x, y, p)
            self.__sender_connection.send_to(
                sdp_data, (ip_address, SCP_SCAMP_PORT))

    def close(self):
        """ Close the sender, listener and receiving connection held by\
            this object, then close the underlying database connection
        """
        self.__handle_possible_rerun_state()
        super(LiveEventConnection, self).close()

    @staticmethod
    def __get_sdp_data(message, x, y, p):
        """ Wrap an EIEIO message in an SDP message addressed to a core and\
            return the bytes to send

        :param message: The EIEIO message to wrap
        :param x: The x-coordinate of the destination chip
        :param y: The y-coordinate of the destination chip
        :param p: The ID of the destination core on that chip
        :return: The SDP message bytes, preceded by the ``_TWO_SKIP`` padding
        """
        # Create an SDP message - no reply so source is unimportant
        # SDP port can be anything except 0 as the target doesn't care
        header = SDPHeader(
            flags=SDPFlag.REPLY_NOT_EXPECTED, tag=0,
            destination_port=1, destination_cpu=p,
            destination_chip_x=x, destination_chip_y=y,
            source_port=0, source_cpu=0,
            source_chip_x=0, source_chip_y=0)
        sdp_message = SDPMessage(header, data=message.bytestring)
        padding = _TWO_SKIP.pack()
        return padding + sdp_message.bytestring
示例#14
0
class LiveEventConnection(DatabaseConnection):
    """ A connection for receiving and sending live events from and to\
        SpiNNaker
    """
    __slots__ = [
        "_atom_id_to_key", "_init_callbacks", "_key_to_atom_id_and_label",
        "_listeners", "_live_event_callbacks", "_live_packet_gather_label",
        "_machine_vertices", "_pause_stop_callbacks", "_receive_labels",
        "_receivers", "_send_address_details", "_send_labels",
        "_sender_connection", "_start_resume_callbacks"
    ]

    def __init__(self,
                 live_packet_gather_label,
                 receive_labels=None,
                 send_labels=None,
                 local_host=None,
                 local_port=NOTIFY_PORT,
                 machine_vertices=False):
        """
        :param live_packet_gather_label: The label of the LivePacketGather\
            vertex to which received events are being sent
        :param receive_labels: \
            Labels of vertices from which live events will be received.
        :type receive_labels: iterable of str
        :param send_labels: \
            Labels of vertices to which live events will be sent
        :type send_labels: iterable of str
        :param local_host: Optional specification of the local hostname or\
            IP address of the interface to listen on
        :type local_host: str
        :param local_port: Optional specification of the local port to listen\
            on. Must match the port that the toolchain will send the\
            notification on (19999 by default)
        :type local_port: int
        :param machine_vertices: True if the labels are to be looked up\
            using the machine-vertex queries of the database rather than\
            the default ones
        :type machine_vertices: bool
        """
        # pylint: disable=too-many-arguments
        super(LiveEventConnection, self).__init__(self._start_resume_callback,
                                                  self._stop_pause_callback,
                                                  local_host=local_host,
                                                  local_port=local_port)

        self.add_database_callback(self._read_database_callback)

        self._live_packet_gather_label = live_packet_gather_label
        # Store list copies: the labels are iterated again each time the
        # database is read and looked up with .index(), neither of which
        # works if the caller passed a generator or an unordered set
        self._receive_labels = (list(receive_labels)
                                if receive_labels is not None else None)
        self._send_labels = (list(send_labels)
                             if send_labels is not None else None)
        self._machine_vertices = machine_vertices
        self._sender_connection = None
        self._send_address_details = dict()
        self._atom_id_to_key = dict()
        self._key_to_atom_id_and_label = dict()
        # One list of callbacks per receive label, in the same order as
        # (and indexed by position in) the list of receive labels
        self._live_event_callbacks = list()
        self._start_resume_callbacks = dict()
        self._pause_stop_callbacks = dict()
        self._init_callbacks = dict()
        if self._receive_labels is not None:
            for label in self._receive_labels:
                self._live_event_callbacks.append(list())
                self._start_resume_callbacks[label] = list()
                self._pause_stop_callbacks[label] = list()
                self._init_callbacks[label] = list()
        if self._send_labels is not None:
            for label in self._send_labels:
                self._start_resume_callbacks[label] = list()
                self._pause_stop_callbacks[label] = list()
                self._init_callbacks[label] = list()
        # Receiving connections and their listeners, keyed by local port
        self._receivers = dict()
        self._listeners = dict()

    def add_init_callback(self, label, init_callback):
        """ Register a function to be called to initialise a vertex

        :param label: The label of the vertex to be notified about. Must be\
            one of the vertices listed in the constructor, otherwise a\
            KeyError is raised
        :type label: str
        :param init_callback: The function to call. It is given the label\
            of the vertex, the size recorded for the vertex, the run time\
            of the simulation in milliseconds, and the simulation timestep\
            in milliseconds
        :type init_callback: function(str, int, float, float) -> None
        """
        callbacks_for_label = self._init_callbacks[label]
        callbacks_for_label.append(init_callback)

    def add_receive_callback(self, label, live_event_callback):
        """ Register a function to be called when live events arrive from a\
            vertex

        :param label: The label of the vertex to be notified about. Must be\
            one of the receive labels listed in the constructor
        :type label: str
        :param live_event_callback: A function to be called when events are\
            received. This should take as parameters the label of the vertex,\
            the simulation timestep when the event occurred, and an\
            array-like of atom IDs.
        :type live_event_callback: function(str, int, [int]) -> None
        """
        # Callbacks are stored by position of the label, not by the label
        position = self._receive_labels.index(label)
        callbacks_for_label = self._live_event_callbacks[position]
        callbacks_for_label.append(live_event_callback)

    def add_start_callback(self, label, start_callback):
        """ Add a callback for the start of the simulation

        Deprecated: logs a warning and then registers the callback through
        :py:meth:`add_start_resume_callback`.

        :param start_callback: A function to be called when the start\
            message has been received. This function should take the label of\
            the referenced vertex, and an instance of this class, which can\
            be used to send events
        :type start_callback: function(str, \
            :py:class:`SpynnakerLiveEventConnection`) -> None
        :param label: the label of the vertex the callback is for
        :type label: str
        """
        logger.warning(
            "the method 'add_start_callback(label, start_callback)' is in "
            "deprecation, and will be replaced with the method "
            "'add_start_resume_callback(label, start_resume_callback)' in a "
            "future release.")
        self.add_start_resume_callback(label, start_callback)

    def add_start_resume_callback(self, label, start_resume_callback):
        """ Register a function for the start and resume state of the\
            simulation

        :param label: the label of the vertex the callback is for. Must be\
            one of the vertices listed in the constructor, otherwise a\
            KeyError is raised
        :type label: str
        :param start_resume_callback: The function to register; presumably\
            called with the label of the vertex and this connection, as\
            for the other state callbacks of this class
        :rtype: None
        """
        callbacks_for_label = self._start_resume_callbacks[label]
        callbacks_for_label.append(start_resume_callback)

    def add_pause_stop_callback(self, label, pause_stop_callback):
        """ Register a function for the pause and stop state of the\
            simulation

        :param label: the label of the vertex the callback is for. Must be\
            one of the vertices listed in the constructor, otherwise a\
            KeyError is raised
        :type label: str
        :param pause_stop_callback: A function to be called when the pause\
            or stop message has been received. This function should take the\
            label of the referenced vertex, and an instance of this class,\
            which can be used to send events.
        :type pause_stop_callback: function(str, \
            :py:class:`SpynnakerLiveEventConnection`) -> None
        :rtype: None
        """
        callbacks_for_label = self._pause_stop_callbacks[label]
        callbacks_for_label.append(pause_stop_callback)

    def _read_database_callback(self, db_reader):
        """ Set up sending and receiving from the contents of the database,\
            then run the registered initialisation callbacks

        :param db_reader: The reader used to query the database
        """
        self._handle_possible_rerun_state()

        read_config = db_reader.get_configuration_parameter_value
        run_time_ms = read_config("runtime")
        # NOTE(review): the division suggests the stored value is in
        # microseconds - confirm against the database writer
        machine_timestep_ms = read_config("machine_time_step") / 1000.0

        # Filled in by the two initialisers; ordered so that the callbacks
        # run in the order in which the labels were initialised
        sizes_by_label = OrderedDict()
        if self._send_labels is not None:
            self._init_sender(db_reader, sizes_by_label)
        if self._receive_labels is not None:
            self._init_receivers(db_reader, sizes_by_label)

        for label, n_atoms in iteritems(sizes_by_label):
            for init_callback in self._init_callbacks[label]:
                init_callback(
                    label, n_atoms, run_time_ms, machine_timestep_ms)

    def _init_sender(self, db, vertex_sizes):
        """ Create the sending connection and read, for each send label,\
            where to send to and which keys to use

        :param db: The reader used to query the database
        :param vertex_sizes: Updated in place with the size of the vertex\
            of each send label
        :type vertex_sizes: dict
        """
        self._sender_connection = EIEIOConnection()

        for label in self._send_labels:
            self._send_address_details[label] = \
                self.__get_live_input_details(db, label)

            if not self._machine_vertices:
                atom_keys = db.get_atom_id_to_key_mapping(label)
                self._atom_id_to_key[label] = atom_keys
                vertex_sizes[label] = len(atom_keys)
                continue

            # A machine vertex has a single key, stored against atom 0
            key, _ = db.get_machine_live_input_key(label)
            self._atom_id_to_key[label] = {0: key}
            vertex_sizes[label] = 1

    def _init_receivers(self, db, vertex_sizes):
        """ Prepare to receive live events for every receive label

        Opens one listening connection per distinct port, sends a port
        trigger message to the board of each label, and reads the key
        mappings of each label.

        :param db: The reader used to query the database
        :param vertex_sizes: Updated in place with the size of the vertex\
            of each receive label
        :type vertex_sizes: dict
        """
        for label_id, label in enumerate(self._receive_labels):
            host, port, board_address = self.__get_live_output_details(
                db, label)

            # Always use the connection that belongs to this label's port.
            # Reusing whatever the previous iteration happened to create
            # would send the trigger from the wrong port (or fail outright
            # if the very first label's port were already open).
            if port in self._receivers:
                receiver = self._receivers[port]
            else:
                receiver = EIEIOConnection(local_port=port)
                listener = ConnectionListener(receiver)
                listener.add_callback(self._receive_packet_callback)
                listener.start()
                self._receivers[port] = receiver
                self._listeners[port] = listener

            send_port_trigger_message(receiver, board_address)
            logger.info("Listening for traffic from {} on {}:{}", label, host,
                        port)

            if self._machine_vertices:
                # A machine vertex has a single key, reported as atom 0
                key, _ = db.get_machine_live_output_key(
                    label, self._live_packet_gather_label)
                self._key_to_atom_id_and_label[key] = (0, label_id)
                vertex_sizes[label] = 1
            else:
                key_to_atom_id = db.get_key_to_atom_id_mapping(label)
                for key, atom_id in iteritems(key_to_atom_id):
                    self._key_to_atom_id_and_label[key] = (atom_id, label_id)
                vertex_sizes[label] = len(key_to_atom_id)

    def __get_live_input_details(self, db_reader, send_label):
        if self._machine_vertices:
            return db_reader.get_machine_live_input_details(send_label)
        return db_reader.get_live_input_details(send_label)

    def __get_live_output_details(self, db_reader, receive_label):
        if self._machine_vertices:
            host, port, strip_sdp, board_address = \
                db_reader.get_machine_live_output_details(
                    receive_label, self._live_packet_gather_label)
        else:
            host, port, strip_sdp, board_address = \
                db_reader.get_live_output_details(
                    receive_label, self._live_packet_gather_label)
        if not strip_sdp:
            raise Exception("Currently, only IP tags which strip the SDP "
                            "headers are supported")
        return host, port, board_address

    def _handle_possible_rerun_state(self):
        # reset from possible previous calls
        if self._sender_connection is not None:
            self._sender_connection.close()
            self._sender_connection = None
        for port in self._receivers:
            self._receivers[port].close()
        self._receivers = dict()
        for port in self._listeners:
            self._listeners[port].close()
        self._listeners = dict()

    def __launch_thread(self, kind, label, callback):
        thread = Thread(
            target=callback,
            args=(label, self),
            name="{} callback thread for live_event_connection {}:{}".format(
                kind, self._local_port, self._local_ip_address))
        thread.start()

    def _start_resume_callback(self):
        """ Launch a thread for every registered start/resume callback,\
            passing each the label it was registered against

        :rtype: None
        """
        registered = (
            (label, callback)
            for label, callbacks in iteritems(self._start_resume_callbacks)
            for callback in callbacks)
        for label, callback in registered:
            self.__launch_thread("start_resume", label, callback)

    def _stop_pause_callback(self):
        """ Launch a thread for every registered pause/stop callback,\
            passing each the label it was registered against

        :rtype: None
        """
        registered = (
            (label, callback)
            for label, callbacks in iteritems(self._pause_stop_callbacks)
            for callback in callbacks)
        for label, callback in registered:
            self.__launch_thread("pause_stop", label, callback)

    def _receive_packet_callback(self, packet):
        """ Pass a received packet to the handler selected by the time flag\
            of its header

        Any exception raised while handling the packet is logged as a\
        warning and not propagated.

        :param packet: The packet that has been received
        :rtype: None
        """
        try:
            if packet.eieio_header.is_time:
                handler = self.__handle_time_packet
            else:
                handler = self.__handle_no_time_packet
            handler(packet)
        except Exception:
            logger.warning("problem handling received packet", exc_info=True)

    def __handle_time_packet(self, packet):
        """ Group the elements of a packet by time and by label, then call\
            the live event callbacks once per time and label

        Elements whose key is not known to this connection are ignored.

        :param packet: The packet to read the elements from
        :rtype: None
        """
        # time -> label ID -> atom IDs, with times kept in order of arrival
        atoms_by_time = OrderedDict()
        while packet.is_next_element:
            element = packet.next_element
            timestamp = element.payload
            key = element.key
            if key not in self._key_to_atom_id_and_label:
                continue
            atom_id, label_id = self._key_to_atom_id_and_label[key]
            atoms_by_label = atoms_by_time.setdefault(timestamp, dict())
            atoms_by_label.setdefault(label_id, list()).append(atom_id)

        for timestamp, atoms_by_label in atoms_by_time.items():
            for label_id, atom_ids in atoms_by_label.items():
                label = self._receive_labels[label_id]
                for callback in self._live_event_callbacks[label_id]:
                    callback(label, timestamp, atom_ids)

    def __handle_no_time_packet(self, packet):
        """ Call the live event callbacks for each element of a packet\
            that has a known key

        The payload is passed to the callback only when the element is a\
        KeyPayloadDataElement.

        :param packet: The packet to read the elements from
        :rtype: None
        """
        while packet.is_next_element:
            element = packet.next_element
            key = element.key
            if key not in self._key_to_atom_id_and_label:
                continue
            atom_id, label_id = self._key_to_atom_id_and_label[key]
            for callback in self._live_event_callbacks[label_id]:
                label = self._receive_labels[label_id]
                if not isinstance(element, KeyPayloadDataElement):
                    callback(label, atom_id)
                else:
                    callback(label, atom_id, element.payload)

    def send_event(self, label, atom_id, send_full_keys=False):
        """ Send a single event; a convenience wrapper around\
            :py:meth:`send_events` with a one-element list

        :param label: \
            The label of the vertex from which the event will originate
        :type label: str
        :param atom_id: The ID of the atom sending the event
        :type atom_id: int
        :param send_full_keys: \
            True to send the full 32-bit key of the atom, looked up from\
            the database mapping; False to send the atom ID directly as a\
            16-bit key
        :type send_full_keys: bool
        """
        single_atom = [atom_id]
        self.send_events(label, single_atom, send_full_keys)

    def send_events(self, label, atom_ids, send_full_keys=False):
        """ Send events from a number of atoms, split over as many packets\
            as are needed

        :param label: \
            The label of the vertex from which the events will originate
        :type label: str
        :param atom_ids: array-like of atom IDs sending events
        :type atom_ids: [int]
        :param send_full_keys: \
            True to send the full 32-bit key of each atom, looked up from\
            the database mapping; False to send the atom IDs directly as\
            16-bit keys
        :type send_full_keys: bool
        """
        if send_full_keys:
            keys_per_packet = _MAX_FULL_KEYS_PER_PACKET
            key_type = EIEIOType.KEY_32_BIT
        else:
            keys_per_packet = _MAX_HALF_KEYS_PER_PACKET
            key_type = EIEIOType.KEY_16_BIT

        n_events = len(atom_ids)
        for first in range(0, n_events, keys_per_packet):
            message = EIEIODataMessage.create(key_type)
            last = min(first + keys_per_packet, n_events)
            for index in range(first, last):
                atom_id = atom_ids[index]
                if send_full_keys:
                    message.add_key(self._atom_id_to_key[label][atom_id])
                else:
                    message.add_key(atom_id)
            # The target is looked up per packet, once the packet is built
            ip_address, port = self._send_address_details[label]
            self._sender_connection.send_eieio_message_to(
                message, ip_address, port)

    def close(self):
        """ Close this connection, along with any sender, receivers and\
            listeners that were set up when the database was read

        :rtype: None
        """
        try:
            # Release the connections and listeners opened from the database;
            # closing only the database connection would leave them open
            self._handle_possible_rerun_state()
        finally:
            # Always close the database connection, even if the above fails
            DatabaseConnection.close(self)
示例#15
0
    def _read_database_callback(self, database_reader):
        self._handle_possible_rerun_state()

        vertex_sizes = OrderedDict()
        run_time_ms = database_reader.get_configuration_parameter_value(
            "runtime")
        machine_timestep_ms = \
            database_reader.get_configuration_parameter_value(
                "machine_time_step") / 1000.0

        if self._send_labels is not None:
            self._sender_connection = EIEIOConnection()
            for send_label in self._send_labels:
                ip_address, port = None, None
                if self._machine_vertices:
                    ip_address, port = \
                        database_reader.get_machine_live_input_details(
                            send_label)
                else:
                    ip_address, port = database_reader.get_live_input_details(
                        send_label)
                self._send_address_details[send_label] = (ip_address, port)
                if self._machine_vertices:
                    key, _ = database_reader.get_machine_live_input_key(
                        send_label)
                    self._atom_id_to_key[send_label] = {0: key}
                    vertex_sizes[send_label] = 1
                else:
                    self._atom_id_to_key[send_label] = \
                        database_reader.get_atom_id_to_key_mapping(send_label)
                    vertex_sizes[send_label] = len(
                        self._atom_id_to_key[send_label])

        if self._receive_labels is not None:

            label_id = 0
            for receive_label in self._receive_labels:
                host, port, strip_sdp = None, None, None
                if self._machine_vertices:
                    host, port, strip_sdp, board_address = \
                        database_reader.get_machine_live_output_details(
                            receive_label, self._live_packet_gather_label)
                else:
                    host, port, strip_sdp, board_address = \
                        database_reader.get_live_output_details(
                            receive_label, self._live_packet_gather_label)
                if not strip_sdp:
                    raise Exception("Currently, only ip tags which strip the"
                                    " SDP headers are supported")
                if port not in self._receivers:
                    receiver = EIEIOConnection(local_port=port)
                    send_port_trigger_message(receiver, board_address)
                    listener = ConnectionListener(receiver)
                    listener.add_callback(self._receive_packet_callback)
                    listener.start()
                    self._receivers[port] = receiver
                    self._listeners[port] = listener
                logger.info("Listening for traffic from {} on {}:{}".format(
                    receive_label, host, port))

                if self._machine_vertices:
                    key, _ = database_reader.get_machine_live_output_key(
                        receive_label, self._live_packet_gather_label)
                    self._key_to_atom_id_and_label[key] = (0, label_id)
                    vertex_sizes[receive_label] = 1
                else:
                    key_to_atom_id = \
                        database_reader.get_key_to_atom_id_mapping(
                            receive_label)
                    for (key, atom_id) in key_to_atom_id.iteritems():
                        self._key_to_atom_id_and_label[key] = (atom_id,
                                                               label_id)
                    vertex_sizes[receive_label] = len(key_to_atom_id)

                label_id += 1

        for (label, vertex_size) in vertex_sizes.iteritems():
            for init_callback in self._init_callbacks[label]:
                init_callback(label, vertex_size, run_time_ms,
                              machine_timestep_ms)