def run_generator_for_training(self, websocket):
        """Push generated data to the server until it reports FINISHED.

        Generator-based coroutine: ``websocket`` is driven with ``yield from``.
        Each round sends ``self.get_next_data_message()`` and then reads one
        ``ServerToSimulator`` reply.
        - SET_PROPERTIES replies are applied via ``self.handle_set_properties``.
        - A FINISHED reply ends the loop.
        - Every other message type is ignored.
        Progress is logged every 250 non-FINISHED replies.

        :raises RuntimeError: if ``self.is_generator`` is falsy. Because this
            is a generator function, that happens on the first resumption,
            not at call time. Also raised if a SET_PROPERTIES reply carries no
            ``set_properties_data``.
        """
        if not self.is_generator:
            raise RuntimeError(
                "Method run_generator_for_training should only be called "
                "when a generator is being used.")

        replies_seen = 0
        while True:
            # In generator mode every round opens by pushing the next data
            # message; there is no request from the server to wait for.
            outgoing = self.get_next_data_message()
            yield from websocket.send(outgoing.SerializeToString())

            payload = yield from websocket.recv()
            reply = ServerToSimulator()
            reply.ParseFromString(payload)

            kind = reply.message_type
            if kind == ServerToSimulator.SET_PROPERTIES:
                if not reply.HasField("set_properties_data"):
                    raise RuntimeError(
                        "Received a SET_PROPERTIES message that did "
                        "not contain set_properties_data.")
                self.handle_set_properties(reply.set_properties_data)
            elif kind == ServerToSimulator.FINISHED:
                log.info("Training is finished!")
                return
            # Any other message type is deliberately ignored.

            replies_seen += 1
            if not replies_seen % 250:
                log.info("Handled %i messages from the server so far",
                         replies_seen)
예제 #2
0
    def run(self):
        """Drive the simulator over a Tornado websocket until the driver ends.

        Connects to ``self.brain_api_url`` and authenticates with
        ``self.access_key``. It then repeatedly hands the last server message
        to ``self.driver.next``, run through ``self._sim_executor``, and sends
        the driver's reply back. This continues until the driver reaches
        ``DriverState.FINISHED``.

        When ``self.recording_file`` is set, every RECV/SEND message goes to
        ``self._record``, and ``None`` is put on ``self.recording_queue`` on
        exit. A closed connection is logged rather than re-raised. The socket
        is always closed on exit.

        This is a ``yield``-based coroutine. NOTE(review): presumably it is
        wrapped by a Tornado coroutine decorator outside this view; confirm.

        :raises RuntimeError: if no access key is configured, or if the
            driver returns no message while it is not yet FINISHED.
        """
        if not self.access_key:
            raise RuntimeError("Access Key was not set.")

        log.info("About to connect to %s", self.brain_api_url)

        request = HTTPRequest(self.brain_api_url,
                              connect_timeout=_INITIAL_CONNECT_TIMEOUT_SECS,
                              request_timeout=_INITIAL_CONNECT_TIMEOUT_SECS)
        request.headers['Authorization'] = self.access_key

        connection = yield websockets.websocket_connect(request)
        channel = _WrapSocket(connection)

        incoming = None

        try:
            # The driver begins unregistered: its first ``next`` call performs
            # the registration and every later call continues the session.
            while self.driver.state != DriverState.FINISHED:
                if self.recording_file:
                    yield self._record('RECV', incoming)

                outgoing = yield self._sim_executor.submit(
                    self.driver.next, incoming)

                if self.recording_file:
                    yield self._record('SEND', outgoing)

                # Once the driver is FINISHED there is nothing left to send or
                # receive; go straight back to the loop condition, which exits.
                if self.driver.state == DriverState.FINISHED:
                    continue

                if not outgoing:
                    raise RuntimeError(
                        "Driver did not return a message to send.")

                yield channel.send(outgoing.SerializeToString())

                # Only reached while the driver is still active.
                reply = yield channel.recv()
                incoming = None
                if reply:
                    incoming = ServerToSimulator()
                    incoming.ParseFromString(reply)

        except (websockets.WebSocketClosedError, ManualClosedException):
            log.error("Connection to '%s' is closed, code='%s', reason='%s'",
                      self.brain_api_url, connection.close_code,
                      connection.close_reason)
        finally:
            if self.recording_file:
                yield self.recording_queue.put(None)
            connection.close()
예제 #3
0
    async def run(self):
        """Drive the simulator over a ``websockets`` connection until done.

        Connects to ``self.brain_api_url`` with ``self.access_key`` in the
        ``Authorization`` header. It then repeatedly hands the last server
        message to ``self.driver.next`` and sends back whatever it returns,
        until the driver reaches ``DriverState.FINISHED``. While the loop
        runs, the driver state is logged at debug level at most once per
        ``state_log_interval`` seconds.

        When ``self.recording_file`` is set, every RECV/SEND message is passed
        to ``self._record``, and ``None`` is put on ``self.recording_queue``
        once the loop ends.

        A closed connection (``ConnectionClosed``) is logged rather than
        re-raised.

        :raises RuntimeError: if no access key is configured.
        """
        # The log throttle below measures an elapsed interval, so it uses a
        # monotonic clock. Wall-clock time can jump when the system clock is
        # adjusted, which would suppress or flood the periodic log.
        import time

        if not self.access_key:
            raise RuntimeError("Access Key was not set.")

        # Minimum number of seconds between two driver-state debug logs.
        state_log_interval = 10

        log.info("About to connect to %s", self.brain_api_url)
        async with websockets.connect(
                uri=self.brain_api_url,
                extra_headers={'Authorization': self.access_key}) as websocket:
            log.debug('Connection to %s established.', self.brain_api_url)
            input_message = None
            last_check = time.monotonic()
            try:
                log.debug('Starting execution loop...')
                # The driver starts out unregistered: the first ``next`` call
                # performs the registration and every later call continues
                # the session.
                while self.driver.state != DriverState.FINISHED:
                    now = time.monotonic()
                    if now - last_check > state_log_interval:
                        last_check = now
                        log.debug('Driver for %s currently in state %s...',
                                  self.brain_api_url, str(self.driver.state))

                    if self.recording_file:
                        await self._record('RECV', input_message)

                    # This is where the state-machine magic happens
                    output_message = self.driver.next(input_message)

                    if self.recording_file:
                        await self._record('SEND', output_message)

                    # A reply is sent whenever the driver produced one, even
                    # if it has just become FINISHED.
                    if output_message:
                        output_bytes = output_message.SerializeToString()
                        await websocket.send(output_bytes)

                    # Only wait for the server while the driver is active.
                    if self.driver.state != DriverState.FINISHED:
                        input_bytes = await websocket.recv()
                        if input_bytes:
                            input_message = ServerToSimulator()
                            input_message.ParseFromString(input_bytes)
                        else:
                            input_message = None

            except websockets.exceptions.ConnectionClosed as e:
                log.error(
                    "Connection to '%s' is closed, code='%s', reason='%s'",
                    self.brain_api_url, e.code, e.reason)
            finally:
                log.debug('Execution loop complete for %s!',
                          self.brain_api_url)
                if self.recording_file:
                    await self.recording_queue.put(None)
예제 #4
0
    def _on_message(self, ws, message):
        log.debug("ON_MESSAGE: %s", message)

        input_bytes = message
        if input_bytes:
            input_message = ServerToSimulator()
            input_message.ParseFromString(input_bytes)
        else:
            input_message = None

        try:
            self._handle_message(ws, input_message)
        except Exception as e:
            self._on_error(ws, e)
            ws.close()
    def run_simulator_for_prediction(self, websocket):
        """Stream simulator state to the server and apply its predictions.

        Generator-based coroutine (``websocket`` is driven with
        ``yield from``). Each round:
        - sends ``self.get_state_message()``;
        - waits for the server's reply;
        - passes the reply's ``prediction_data`` to ``self.handle_prediction``.
        A progress line is logged every 250 predictions. The loop has no exit
        condition of its own; it only ends when an exception propagates out
        of it.
        """
        predictions_handled = 0
        while True:
            # Outbound: the current simulator state.
            state_bytes = self.get_state_message().SerializeToString()
            yield from websocket.send(state_bytes)

            # Inbound: the server's prediction for that state.
            raw_reply = yield from websocket.recv()
            reply = ServerToSimulator()
            reply.ParseFromString(raw_reply)
            self.handle_prediction(reply.prediction_data)

            predictions_handled += 1
            if not predictions_handled % 250:
                log.info("Recieved %i predictions", predictions_handled)
예제 #6
0
def load_test_message_stream(path):
    """
    Load a recorded message stream.

    The message stream is a simple text file of protobuf messages sent and
    received during the course of a training or prediction run.

    The file is formatted as pairs of lines, one pair per message:
    Line N  : SEND|RECV
    Line N+1: ServerToSimulator (RECV) or SimulatorToServer (SEND) protobuf
              object serialized to text form with
              google.protobuf.text_format.MessageToString(), or None
              indicating an empty message.

    A trailing direction line with no message line after it is ignored.

    This output file can be easily re-created with any of the gym sample
    simulators by simply adding a "--messages-out <PATH>" parameter to the
    command line.

    :param path: Path to test file to load.
    :type path: string
    :return: List of TestMessage instances representing the back-and-forth
             communications between the simulator and its BRAIN.
    :raises RuntimeError: if a direction line is neither SEND nor RECV. The
             reported line number is that of the message line following it.
    """
    messages = []
    direction = None
    # Decode explicitly so the result does not depend on the platform's
    # default locale encoding.
    with open(path, 'r', encoding='utf-8') as infile:
        for line_number, raw_line in enumerate(infile, start=1):
            line = raw_line.strip()
            if line_number % 2 == 1:
                # Odd lines carry the direction of the message that follows.
                direction = line
                continue

            # Validate the direction for every message, including empty
            # ("None") ones, so a malformed header is never silently accepted.
            if direction not in ('RECV', 'SEND'):
                raise RuntimeError('Error loading file '
                                   'on line {}'.format(line_number))

            message_as_text = line
            if message_as_text == 'None':
                message = None
            else:
                message = (ServerToSimulator() if direction == 'RECV'
                           else SimulatorToServer())
                Merge(message_as_text, message)

            messages.append(TestMessage(direction=direction,
                                        message_as_text=message_as_text,
                                        message=message))

    return messages
    def recv_acknowledge_register(self, websocket):
        """Receive the server's ACKNOWLEDGE_REGISTER reply and load schemas.

        Generator-based coroutine (``websocket`` is driven with
        ``yield from``). It reads one message from ``websocket`` and checks
        that it is an ACKNOWLEDGE_REGISTER carrying
        ``acknowledge_register_data``. It then rebuilds
        ``self.properties_schema``, ``self.output_schema`` and
        ``self.prediction_schema`` from that data via
        ``MessageBuilder().reconstitute``.

        :raises RuntimeError: if the message has another type, or lacks
            ``acknowledge_register_data``.
        """
        raw = yield from websocket.recv()
        reply = ServerToSimulator()
        reply.ParseFromString(raw)

        kind = reply.message_type
        if kind != ServerToSimulator.ACKNOWLEDGE_REGISTER:
            raise RuntimeError(
                "Expected to receive an ACKNOWLEDGE_REGISTER message, but "
                "instead received message of type {}".format(kind))

        if not reply.HasField("acknowledge_register_data"):
            raise RuntimeError(
                "Received an ACKNOWLEDGE_REGISTER message that did "
                "not contain acknowledge_register_data.")

        ack = reply.acknowledge_register_data
        # Each schema is rebuilt with its own MessageBuilder instance.
        self.properties_schema = MessageBuilder().reconstitute(
            ack.properties_schema)
        self.output_schema = MessageBuilder().reconstitute(
            ack.output_schema)
        self.prediction_schema = MessageBuilder().reconstitute(
            ack.prediction_schema)
    def run_simulator_for_training(self, websocket):
        """Run the simulator side of a training session.

        Generator-based coroutine (``websocket`` is driven with
        ``yield from``).
        1. Announces readiness via ``self.send_ready``.
        2. Reads ``ServerToSimulator`` messages one at a time.
        3. Delegates each to ``self.handle_from_server`` until a FINISHED
           message arrives, then returns.
        A progress line is logged every 250 handled messages; the FINISHED
        message itself is not counted.
        """
        # The exchange currently opens with a ready message, not a register.
        # TODO: T365: Exchange should start with register first
        yield from self.send_ready(websocket)

        handled = 0
        while True:
            raw_message = yield from websocket.recv()
            incoming = ServerToSimulator()
            incoming.ParseFromString(raw_message)

            finished = incoming.message_type == ServerToSimulator.FINISHED
            if finished:
                log.info("Training is finished!")
                return

            # Every non-FINISHED message goes to the handler, which is also
            # given the websocket (presumably so it can reply).
            yield from self.handle_from_server(websocket, incoming)

            handled += 1
            if not handled % 250:
                log.info("Handled %i messages from the server so far",
                         handled)