class Simulator_WS(object):
    """
    Websocket client that drives a local simulator on behalf of a brain.

    Each call to :meth:`run` performs one round trip with the backend:
    any queued predictions are stepped through the simulator, the reply
    selected by ``_dispatch_send`` (keyed on the previously received
    server message type) is written, and the server's response is read
    and dispatched through ``_dispatch_recv``.
    """

    class SimStep(object):
        """
        Internal class used for keeping track of batch-processed round
        trips through the simulator. Packed into a protobuf message at the
        end and sent over the wire.
        """
        def __init__(self):
            self.prediction = None
            self.state = None
            self.reward = 0.0
            self.terminal = False

    def __init__(self, brain, sim, simulator_name):
        """
        :param brain: brain whose URLs, access key and user info are used
                      to open the websocket.
        :param sim: simulator object; its ``predict`` flag selects the
                    prediction or the training message flow.
        :param simulator_name: name sent to the server on registration.
        """
        self.brain = brain
        self.name = simulator_name
        self.objective_name = None
        self._sim = sim
        self._ws = None
        self._prev_message_type = ServerToSimulator.UNKNOWN

        # acknowledge_register
        # schemas are of type DescriptorProto
        self._properties_schema = None
        self._output_schema = None
        self._prediction_schema = None
        self._sim_id = -1

        # set_properties
        self._init_properties = {}

        # current batch of simulation steps
        self._sim_steps = []

        # Caching actions for predictor
        self._predictor_action = None

        # protobuf descriptor cache
        self._inkling = InklingMessageFactory()

        # Maps the *previously received* server message type to the name
        # of the method that fills in our next outgoing message.
        self._dispatch_send = {
            ServerToSimulator.UNKNOWN:
                '_send_registration',
            ServerToSimulator.ACKNOWLEDGE_REGISTER:
                '_send_initial_state' if self._sim.predict else '_send_ready',
            ServerToSimulator.SET_PROPERTIES:
                '_send_initial_state' if self._sim.predict else '_send_ready',
            ServerToSimulator.START:
                '_send_initial_state',
            ServerToSimulator.PREDICTION:
                '_send_state',
            ServerToSimulator.RESET:
                '_send_initial_state' if self._sim.predict else '_send_ready',
            ServerToSimulator.STOP:
                '_send_initial_state' if self._sim.predict else '_send_ready',
        }

        # Maps an incoming server message type to its handler method name.
        self._dispatch_recv = {
            ServerToSimulator.ACKNOWLEDGE_REGISTER: '_on_acknowledge_register',
            ServerToSimulator.SET_PROPERTIES: '_on_set_properties',
            ServerToSimulator.START: '_on_start',
            ServerToSimulator.PREDICTION: '_on_prediction',
            ServerToSimulator.RESET: '_on_reset',
            ServerToSimulator.STOP: '_on_stop',
            ServerToSimulator.FINISHED: '_on_finished'
        }

    def _new_state_message(self):
        """
        Generate an InklingMessage for holding simulator state

        :return: state message built from the server-provided output schema
        """
        return self._inkling.new_message_from_proto(self._output_schema)

    def _dict_for_message(self, message):
        """
        Unpack a protobuf message into a Python dictionary keyed by field
        name.

        :return: dictionary; empty when ``message`` is None.
        """
        # If the message is bogus, return an empty dictionary rather
        # than crashing.
        if message is None:
            return {}
        return {field.name: getattr(message, field.name)
                for field in message.DESCRIPTOR.fields}

    def _send_registration(self, to_server):
        """Fill ``to_server`` with a REGISTER message carrying our name."""
        log.simulator_ws('Sending Registration')
        to_server.message_type = SimulatorToServer.REGISTER
        to_server.register_data.simulator_name = self.name

    def _send_ready(self, to_server):
        """Fill ``to_server`` with a READY message for this sim id."""
        log.simulator_ws('Sending Ready')
        to_server.message_type = SimulatorToServer.READY
        to_server.sim_id = self._sim_id

    def _send_initial_state(self, to_server):
        """
        Start a new episode on the simulator and report its initial state.

        :raises EpisodeStartError: if the simulator's episode-start hook
                                   raises.
        """
        log.simulator_ws('Sending initial State')
        to_server.message_type = SimulatorToServer.STATE
        to_server.sim_id = self._sim_id
        state = to_server.state_data.add()
        try:
            initial_state = self._sim._on_episode_start(self._init_properties)
        except Exception as e:
            raise EpisodeStartError(e)
        state_message = self._new_state_message()
        convert_state_to_proto(state_message, initial_state)
        state.state = state_message.SerializeToString()
        state.reward = 0.0
        state.terminal = False
        # action_taken is deliberately left unset: no action produced the
        # initial state.

    def _send_state(self, to_server):
        """
        Pack every simulated step of the current batch into a STATE
        message, then clear the batch.
        """
        log.simulator_ws('Sending State')
        to_server.message_type = SimulatorToServer.STATE
        to_server.sim_id = self._sim_id
        for step in self._sim_steps:
            if step.state:
                state = to_server.state_data.add()
                state.state = step.state.SerializeToString()
                state.reward = step.reward
                state.terminal = step.terminal
                state.action_taken = step.prediction
            else:
                log.simulator("WARNING: Missing step in send_state")
        self._sim_steps = []

    def _on_acknowledge_register(self, from_server):
        """Store the schemas and sim id assigned by the server."""
        log.simulator_ws('Acknowledging Registration')
        data = from_server.acknowledge_register_data
        self._properties_schema = data.properties_schema
        self._output_schema = data.output_schema
        self._prediction_schema = data.prediction_schema
        self._sim_id = data.sim_id

    def _on_set_properties(self, from_server):
        """
        Store the prediction schema, reward name and the initial
        properties used for subsequent episode starts.
        """
        log.simulator_ws('Setting properties')
        data = from_server.set_properties_data
        self._prediction_schema = data.prediction_schema
        self.objective_name = data.reward_name
        dynamic_properties = data.dynamic_properties
        properties_message = self._inkling.message_for_dynamic_message(
            dynamic_properties, self._properties_schema)
        self._init_properties = self._dict_for_message(properties_message)

    def _on_start(self, from_server):
        """START carries no data to record."""
        pass

    def _on_prediction(self, from_server):
        """Queue one SimStep per prediction in the incoming batch."""
        log.simulator_ws('On Prediction')
        for p_data in from_server.prediction_data:
            step = self.SimStep()
            step.prediction = p_data.dynamic_prediction
            self._sim_steps.append(step)
            # Convert server msg to action dict and save it for predictor
            self._cache_action_for_predictor(step.prediction)

    def _on_reset(self, from_server):
        """RESET carries no data to record."""
        pass

    def _on_stop(self, from_server):
        """Notify the simulator that the current episode has finished."""
        self._sim._on_episode_finish()

    def _on_finished(self, from_server):
        """FINISHED is handled by :meth:`run`, which closes the socket."""
        pass

    def _dump_message(self, message, fname):
        """Helper function for dumping protobuf message contents"""
        with open(fname, 'wb') as f:
            f.write(message.SerializeToString())

    def _on_send(self, to_server):
        """
        Message handler for sending messages to the server: fill
        ``to_server`` according to the last message type received.

        If that type has no entry in ``_dispatch_send`` (e.g. FINISHED),
        ``to_server`` is left untouched and only a log line is emitted.
        """
        method_name = self._dispatch_send.get(
            self._prev_message_type, 'default')
        # The fallback must accept ``to_server`` like every real handler;
        # a zero-argument lambda raised TypeError here.
        method = getattr(self, method_name,
                         lambda _to_server: log.simulator("Finished"))
        method(to_server)

    def _on_recv(self, from_server):
        """
        Message handler for server messages: dispatch to the matching
        ``_on_*`` method and remember the type for the next send.

        :raises BonsaiServerError: for a message type with no handler.
        """
        def _raise(msg):
            raise BonsaiServerError(
                "Received unknown message ({}) from server".format(
                    msg.message_type))

        method_name = self._dispatch_recv.get(
            from_server.message_type, 'default')
        method = getattr(self, method_name, _raise)
        method(from_server)
        self._prev_message_type = from_server.message_type

    def _cache_action_for_predictor(self, prediction):
        """
        Converts a server prediction into an action dictionary and saves
        it for the predictor class
        """
        action_message = self._inkling.message_for_dynamic_message(
            prediction, self._prediction_schema)
        self._predictor_action = self._dict_for_message(action_message)

    @gen.coroutine
    def _connect(self):
        """
        Fire up a websocket connection to the prediction or simulation
        URL of the brain.

        :return: (via ``gen.Return``) None on success, otherwise the
                 ``repr`` of the exception that prevented connecting.
        """
        try:
            if self._sim.predict is True:
                url = self.brain._prediction_url()
            else:
                url = self.brain._simulation_url()
            log.info("trying to connect: {}".format(url))
            req = HTTPRequest(
                url,
                connect_timeout=_CONNECT_TIMEOUT_SECS,
                request_timeout=_CONNECT_TIMEOUT_SECS)
            req.headers['Authorization'] = self.brain.config.accesskey
            req.headers['User-Agent'] = self.brain._user_info
            self._ws = yield websocket_connect(req)
        except Exception as e:
            raise gen.Return(repr(e))
        else:
            raise gen.Return(None)

    def _advance(self, step):
        """
        Helper function to advance the simulator and process the resulting
        state for transmission.

        Fills ``step.state``/``reward``/``terminal``. On a terminal step the
        episode is finished and a new one started immediately; errors from
        either hook are reported as EpisodeStartError.

        :raises SimulateError: if the simulator's simulate hook raises.
        """
        log.simulator_ws('Advancing')
        action_message = self._inkling.message_for_dynamic_message(
            step.prediction, self._prediction_schema)
        action = self._dict_for_message(action_message)

        try:
            state, reward, terminal = self._sim._on_simulate(action)
        except Exception as e:
            raise SimulateError(e)

        state_message = self._new_state_message()
        convert_state_to_proto(state_message, state)
        log.simulator("{}".format(MessageToJson(state_message)))

        step.state = state_message
        step.reward = reward
        step.terminal = terminal

        if terminal:
            try:
                self._sim._on_episode_finish()
                self._sim._on_episode_start(self._init_properties)
            except Exception as e:
                raise EpisodeStartError(e)

    @gen.coroutine
    def close_connection(self):
        """ Close the websocket connection, if one was ever opened. """
        if self._ws is not None:
            yield self._ws.close()

    @gen.coroutine
    def run(self):
        """
        Run loop called from Simulator. Encapsulates one round trip to the
        backend, which might include a simulation loop.

        :return: (via ``gen.Return``) False once the server sent FINISHED
                 and the socket was closed, True otherwise.
        :raises BonsaiServerError: if connecting fails, or the websocket is
                                   closed while writing or reading.
        """
        # Grab a web socket connection if needed
        if self._ws is None:
            message = yield self._connect()
            # If the connection failed, report
            if message is not None:
                raise BonsaiServerError(
                    "Error while connecting to websocket: {}".format(message))

        # If there is a batch of predictions cued up, step through it
        if self._prev_message_type == ServerToSimulator.PREDICTION:
            for step in self._sim_steps:
                self._advance(step)

        # send message to server
        to_server = SimulatorToServer()
        self._on_send(to_server)

        if to_server.message_type:
            out_bytes = to_server.SerializeToString()
            try:
                yield self._ws.write_message(out_bytes, binary=True)
            except (StreamClosedError, WebSocketClosedError):
                raise BonsaiServerError(
                    "Websocket connection closed. Code: {}, Reason: {}".format(
                        self._ws.close_code, self._ws.close_reason))

        # read response from server; tornado yields None once closed
        in_bytes = yield self._ws.read_message()
        if in_bytes is None:
            raise BonsaiServerError(
                "Websocket connection closed. Code: {}, Reason: {}".format(
                    self._ws.close_code, self._ws.close_reason))

        from_server = ServerToSimulator()
        from_server.ParseFromString(in_bytes)
        self._on_recv(from_server)

        if self._prev_message_type == ServerToSimulator.FINISHED:
            yield self._ws.close()
            raise gen.Return(False)

        # You've come this far, celebrate!
        raise gen.Return(True)
class Simulator_WS(object):
    # Async (aiohttp-based) websocket driver: get_next_event() advances an
    # internal state machine fed by server messages and returns the next
    # event; run() executes that event against the local simulator.

    class SimStep(object):
        """
        Internal class used for keeping track of batch-processed round
        trips through the simulator. Packed into a protobuf message at the
        end and sent over the wire.
        """
        def __init__(self):
            self.prediction = None
            self.state = None
            self.reward = 0.0
            self.terminal = False

    def __init__(self, brain, sim, simulator_name):
        """
        :param brain: brain passed to SimulatorConnection.
        :param sim: simulator object; ``sim.predict`` selects the
                    prediction or training message flow.
        :param simulator_name: name sent to the server on registration.
        """
        self.brain = brain
        self.name = simulator_name
        self._sim = sim
        # Must run after self._sim is set: the reset touches
        # self._sim._reset_rate_counter.
        self._reset_simulator_ws()
        self._sim_connection = SimulatorConnection(brain, sim.predict)

        # protobuf descriptor cache
        self._inkling = InklingMessageFactory()

        # Maps the *previously received* server message type to the name
        # of the method that fills our next outgoing message. In predict
        # mode only registration and predictions are expected; the other
        # types map to _unsupported, which raises.
        self._dispatch_send = {
            ServerToSimulator.UNKNOWN:
                '_send_registration',
            ServerToSimulator.ACKNOWLEDGE_REGISTER:
                '_send_initial_state' if self._sim.predict else '_send_ready',
            ServerToSimulator.SET_PROPERTIES:
                '_unsupported' if self._sim.predict else '_send_ready',
            ServerToSimulator.START:
                '_unsupported' if self._sim.predict else '_send_initial_state',
            ServerToSimulator.PREDICTION:
                '_send_state',
            ServerToSimulator.RESET:
                '_unsupported' if self._sim.predict else '_send_ready',
            ServerToSimulator.STOP:
                '_unsupported' if self._sim.predict else '_send_ready',
        }

        # Maps an incoming server message type to its handler method name.
        self._dispatch_recv = {
            ServerToSimulator.ACKNOWLEDGE_REGISTER: '_on_acknowledge_register',
            ServerToSimulator.SET_PROPERTIES: '_on_set_properties',
            ServerToSimulator.START: '_on_start',
            ServerToSimulator.PREDICTION: '_on_prediction',
            ServerToSimulator.RESET: '_on_reset',
            ServerToSimulator.STOP: '_on_stop',
            ServerToSimulator.FINISHED: '_on_finished'
        }

    def _reset_simulator_ws(self):
        """ Reset all protocol state back to "not yet registered".

        Called from __init__ and again after every disconnect, so the next
        round trip starts with a fresh REGISTER.
        """
        log.simulator_ws('Resetting simulator_ws')
        self.objective_name = None
        self._prev_message_type = ServerToSimulator.UNKNOWN

        # acknowledge_register
        # schemas are of type DescriptorProto
        self._properties_schema = None
        self._output_schema = None
        self._prediction_schema = None
        self._sim_id = 0

        # set_properties
        self._init_properties = {}
        self._initial_state = None

        # current batch of simulation steps
        self._sim_steps = []
        self._step_iter = None
        # One-element list rather than a bool, presumably so SimulateEvent
        # can write the terminal flag back through the shared reference
        # (verify in SimulateEvent).
        self._prev_step_terminal = [False]
        self._prev_step_finish = False

        # Caching actions for predictor
        self._predictor_action = None
        self._sim._reset_rate_counter = True

        # Handle on WS receive for cleanup
        self._receive_handle = None

    def _new_state_message(self):
        """
        Generate an InklingMessage for holding simulator state

        :return: state message built from the server-provided output schema
        """
        return self._inkling.new_message_from_proto(self._output_schema)

    def _send_registration(self, to_server):
        """Fill ``to_server`` with a REGISTER message carrying our name."""
        log.simulator_ws('Sending Registration')
        to_server.message_type = SimulatorToServer.REGISTER
        to_server.register_data.simulator_name = self.name

    def _send_ready(self, to_server):
        """Fill ``to_server`` with a READY message for this sim id."""
        log.simulator_ws('Sending Ready')
        to_server.message_type = SimulatorToServer.READY
        to_server.sim_id = self._sim_id

    def _send_initial_state(self, to_server):
        """Report the initial state previously filled in by run()."""
        log.simulator_ws('Sending initial State')
        to_server.message_type = SimulatorToServer.STATE
        to_server.sim_id = self._sim_id
        state = to_server.state_data.add()
        state.state = self._initial_state.SerializeToString()
        state.reward = 0.0
        state.terminal = False
        self._prev_step_terminal[0] = state.terminal
        # state.action_taken = #...
        # no-op for init state

    def _send_state(self, to_server):
        """
        Pack every simulated step of the current batch into a STATE
        message, then clear the batch and its iterator.
        """
        log.simulator_ws('Sending State')
        to_server.message_type = SimulatorToServer.STATE
        to_server.sim_id = self._sim_id
        for step in self._sim_steps:
            if step.state:
                state = to_server.state_data.add()
                state.state = step.state.SerializeToString()
                state.reward = step.reward
                state.terminal = step.terminal
                log.action(
                    self._inkling.message_for_dynamic_message(
                        step.prediction, self._prediction_schema))
                state.action_taken = step.prediction
            else:
                log.simulator("WARNING: Missing step in send_state")
        self._sim_steps = []
        self._step_iter = None

    def _unsupported(self, to_server):
        """
        Send handler for message types not expected in the current mode.

        :raises BonsaiServerError: always, naming the unexpected type.
        """
        descriptor = ServerToSimulator.MessageType.DESCRIPTOR
        raise BonsaiServerError("Unexpected Message during {}: {}".format(
            "prediction" if self._sim.predict else "training",
            descriptor.values_by_number[self._prev_message_type].name))

    def _on_acknowledge_register(self, from_server):
        """
        Store the schemas and sim id assigned by the server, and configure
        the simulator's writer (if any) with the schema field names.
        """
        log.simulator_ws('Acknowledging Registration')
        data = from_server.acknowledge_register_data
        self._properties_schema = data.properties_schema
        self._output_schema = data.output_schema
        self._prediction_schema = data.prediction_schema
        if self._sim.writer is not None:
            self._configure_writer()
        self._sim_id = data.sim_id
        log.info("Starting {} ID: <{}>".format(
            "Prediction" if self._sim.predict else "Training",
            self._sim_id))

    def _on_set_properties(self, from_server):
        """
        Store the prediction schema, reward name and the initial
        properties used for subsequent episode starts.
        """
        log.simulator_ws('Setting properties')
        data = from_server.set_properties_data
        self._prediction_schema = data.prediction_schema
        self.objective_name = data.reward_name
        dynamic_properties = data.dynamic_properties
        properties_message = self._inkling.message_for_dynamic_message(
            dynamic_properties, self._properties_schema)
        self._init_properties = dict_for_message(properties_message)

    def _on_start(self, from_server):
        """START carries no data; get_next_event() creates the event."""
        log.simulator_ws('On Start')

    def _on_prediction(self, from_server):
        """
        Queue one SimStep per prediction in the incoming batch and reset
        the iterator consumed by _process_sim_step().
        """
        log.simulator_ws('On Prediction')
        for p_data in from_server.prediction_data:
            step = self.SimStep()
            step.prediction = p_data.dynamic_prediction
            self._sim_steps.append(step)
            # Convert server msg to action dict and save it for predictor
            self._cache_action_for_predictor(step.prediction)
        self._step_iter = iter(self._sim_steps)

    def _on_reset(self, from_server):
        """RESET carries no data to record."""
        log.simulator_ws('On Reset')

    def _on_stop(self, from_server):
        """STOP: episode finish is emitted by get_next_event(), not here."""
        log.simulator_ws('On Stop')
        # fire the finished message if the previous step wasn't terminal
        # as it will already have been called
        # if not self._prev_step_terminal:
        #     self._sim._on_episode_finish()

    def _on_finished(self, from_server):
        """FINISHED carries no data; run() closes the connection."""
        log.simulator_ws('On Finished')

    def _dump_message(self, message, fname):
        '''Helper function for dumping protobuf message contents'''
        with open(fname, 'wb') as f:
            f.write(message.SerializeToString())

    def _on_send(self, to_server):
        '''
        Message handler for sending messages to the server: fill
        ``to_server`` according to the last message type received. Types
        without a send handler (e.g. FINISHED) only log.
        '''
        method_name = self._dispatch_send.get(self._prev_message_type,
                                              'default')
        method = getattr(self, method_name,
                         lambda x: log.simulator("Finished"))
        method(to_server)

    def _on_recv(self, from_server):
        '''
        Message handler for server messages: dispatch to the matching
        ``_on_*`` method and remember the type for the next send.

        :raises BonsaiServerError: for a message type with no handler.
        '''
        def _raise(msg):
            raise BonsaiServerError(
                "Received unknown message ({}) from server".format(
                    msg.message_type))

        method_name = self._dispatch_recv.get(from_server.message_type,
                                              'default')
        method = getattr(self, method_name, _raise)
        method(from_server)
        self._prev_message_type = from_server.message_type

    def _cache_action_for_predictor(self, prediction):
        """
        Converts a server prediction into an action dictionary and saves
        it for the predictor class
        """
        action_message = self._inkling.message_for_dynamic_message(
            prediction, self._prediction_schema)
        self._predictor_action = dict_for_message(action_message)

    def _configure_writer(self):
        """Enable writer keys for every schema field plus fixed metadata."""
        self._sim.writer.enable_keys(
            self._fields_for_schema(self._properties_schema), 'config')
        self._sim.writer.enable_keys(
            self._fields_for_schema(self._prediction_schema), 'action')
        self._sim.writer.enable_keys(
            self._fields_for_schema(self._output_schema), 'state')
        self._sim.writer.enable_keys(
            ['reward', 'terminal', 'time', 'simulator', 'predict', 'sim_id'])
        self._sim.writer.enable_keys([
            'episode_reward', 'episode_count', 'episode_rate',
            'iteration_count', 'iteration_rate'
        ], 'statistics')

    def _fields_for_schema(self, schema):
        """Return the field names of the message described by ``schema``."""
        msg = self._inkling.new_message_from_proto(schema)
        return [f.name for f in msg.DESCRIPTOR.fields]

    async def _ws_send_recv(self):
        """
        One websocket round trip: send the message chosen by _on_send (if
        it set a message type), then read one reply and dispatch it.

        On a closed socket, a ClientError while sending, or a close/error
        frame while receiving, the connection is torn down via
        _handle_disconnect (which also resets protocol state) and the
        method returns without dispatching anything.
        """
        to_server = SimulatorToServer()
        self._on_send(to_server)
        log.pb("to_server: {}".format(MessageToJson(to_server)))

        if to_server.message_type:
            out_bytes = to_server.SerializeToString()
            try:
                with self._sim_connection.lock:
                    if self._sim_connection.client.closed:
                        await self._handle_disconnect(
                            "Attempted write to closed web socket")
                        return
                    await self._sim_connection.client.send_bytes(out_bytes)
            except ClientError as e:
                await self._handle_disconnect(e)
                return

        with self._sim_connection.lock:
            log.network('Reading response from server')
            # Keep the receive future on self so it can be found for
            # cleanup (see _receive_handle in _reset_simulator_ws).
            self._receive_handle = ensure_future(
                self._sim_connection.client.receive())
            msg = await self._receive_handle
            log.network('Received response from server')

        if msg.type == WSMsgType.CLOSE or msg.type == WSMsgType.CLOSED \
                or msg.type == WSMsgType.ERROR or \
                isinstance(msg.data, EofStream):
            await self._handle_disconnect(msg.extra)
            return

        from_server = ServerToSimulator()
        from_server.ParseFromString(msg.data)
        log.pb("from_server: {}".format(MessageToJson(from_server)))
        self._on_recv(from_server)

    async def _handle_disconnect(self, message=None):
        """Tear down the connection and reset protocol state."""
        await self._sim_connection.handle_disconnect(message)
        self._reset_simulator_ws()

    def _process_sim_step(self):
        """
        Turn the next queued SimStep into an event.

        :return: EpisodeStartEvent if the previous step finished an
                 episode (the step's prediction is then not applied and
                 its state slot receives the new initial state),
                 SimulateEvent otherwise, or None when there is no batch
                 or it is exhausted.
        """
        try:
            if not self._step_iter:
                return None
            event = None
            step = next(self._step_iter)
            step.state = self._new_state_message()
            if self._prev_step_finish:
                event = EpisodeStartEvent(self._init_properties, step.state)
                self._prev_step_finish = False
            else:
                action_message = self._inkling.message_for_dynamic_message(
                    step.prediction, self._prediction_schema)
                action = dict_for_message(action_message)
                event = SimulateEvent(action, step, self._prev_step_terminal)
            return event
        except StopIteration:
            return None

    async def get_next_event(self):
        """ Update the internal event machine and return the next event
        for processing.

        While a prediction batch is pending, events come from the batch
        (with an EpisodeFinishEvent injected after a terminal step); only
        when it is exhausted is the next server round trip performed.
        A failed connect returns UnknownEvent after resetting state.
        """
        # Grab a web socket connection if needed
        if self._sim_connection.client is None:
            message = await self._sim_connection.connect()
            # If the connection failed, report
            if message is not None:
                await self._handle_disconnect(message)
                return UnknownEvent()

        if self._prev_message_type == ServerToSimulator.PREDICTION:
            if self._prev_step_terminal[0]:
                self._prev_step_terminal[0] = False
                self._prev_step_finish = True
                event = EpisodeFinishEvent()
            else:
                event = self._process_sim_step()
            if event is not None:
                return event

        await self._ws_send_recv()

        # Map the message just received to the event run() should execute.
        pmt = self._prev_message_type
        if pmt == ServerToSimulator.ACKNOWLEDGE_REGISTER:
            if self._sim.predict:
                self._initial_state = self._new_state_message()
                event = EpisodeStartEvent(self._init_properties,
                                          self._initial_state)
                self._prev_step_finish = False
            else:
                event = UnknownEvent()
        elif (pmt == ServerToSimulator.SET_PROPERTIES or
              pmt == ServerToSimulator.RESET):
            event = UnknownEvent()
        elif pmt == ServerToSimulator.STOP:
            # If a terminal step already produced EpisodeFinishEvent, do
            # not finish the episode a second time.
            if self._prev_step_finish:
                event = UnknownEvent()
                self._prev_step_finish = False
            else:
                event = EpisodeFinishEvent()
        elif pmt == ServerToSimulator.START:
            self._initial_state = self._new_state_message()
            event = EpisodeStartEvent(self._init_properties,
                                      self._initial_state)
            self._prev_step_finish = False
        elif pmt == ServerToSimulator.PREDICTION:
            event = self._process_sim_step()
        elif pmt == ServerToSimulator.FINISHED:
            event = FinishedEvent()
        else:
            event = UnknownEvent()

        return event

    async def run(self):
        """ Run loop called from Simulator. Encapsulates one round trip
        to the backend, which might include a simulation loop.

        :return: False after a FinishedEvent (connection closed), True
                 otherwise.
        :raises EpisodeStartError, SimulateError, EpisodeFinishError:
                 wrapping exceptions raised by the simulator hooks.
        """
        event = await self.get_next_event()

        if isinstance(event, EpisodeStartEvent):
            log.event("Episode Start")
            try:
                state = self._sim._on_episode_start(event.initial_properties)
            except Exception as e:
                raise EpisodeStartError(e)
            event.initial_state = state
            log.simulator("initial state: {}".format(event.initial_state))
            log.simulator_ws('\tES')
        elif isinstance(event, SimulateEvent):
            log.event("Simulate")
            try:
                log.simulator("action: {}".format(event.action))
                event.state, event.reward, event.terminal = \
                    self._sim._on_simulate(event.action)
            except Exception as e:
                raise SimulateError(e)
            log.simulator_ws('\tT' if event.terminal else '\tS')
            log.simulator("state: {}".format(event.state))
        elif isinstance(event, EpisodeFinishEvent):
            log.event("Episode Finish")
            try:
                self._sim._on_episode_finish()
            except Exception as e:
                raise EpisodeFinishError(e)
            log.simulator_ws('\tF')
        elif isinstance(event, FinishedEvent):
            log.event("Finished")
            await self._sim_connection.close()
            return False
        elif isinstance(event, UnknownEvent):
            log.event("No Operation")

        # Only events that produced a new state are recorded.
        if isinstance(event, EpisodeStartEvent) or \
                isinstance(event, SimulateEvent):
            self._sim.flush_record()

        return True
class Simulator_WS(object):
    """
    Tornado websocket driver built as an event machine.

    :meth:`get_next_event` advances the protocol state using server
    messages and returns the next event; :meth:`run` executes that event
    against the local simulator.
    """

    class SimStep(object):
        """
        Internal class used for keeping track of batch-processed round
        trips through the simulator. Packed into a protobuf message at the
        end and sent over the wire.
        """
        def __init__(self):
            self.prediction = None
            self.state = None
            self.reward = 0.0
            self.terminal = False

    def __init__(self, brain, sim, simulator_name):
        """
        :param brain: brain whose URLs, access key and user info are used
                      to open the websocket.
        :param sim: simulator object; ``sim.predict`` selects the
                    prediction or training message flow.
        :param simulator_name: name sent to the server on registration.
        """
        self.brain = brain
        self.name = simulator_name
        self.objective_name = None
        self._sim = sim
        self._ws = None
        self._prev_message_type = ServerToSimulator.UNKNOWN

        # acknowledge_register
        # schemas are of type DescriptorProto
        self._properties_schema = None
        self._output_schema = None
        self._prediction_schema = None
        self._sim_id = -1

        # set_properties
        self._init_properties = {}
        self._initial_state = None

        # current batch of simulation steps
        self._sim_steps = []
        self._step_iter = None
        # One-element list so the flag can be shared by reference with
        # SimulateEvent (assumed to write it back; verify in SimulateEvent).
        self._prev_step_terminal = [False]
        self._prev_step_finish = False

        # Caching actions for predictor
        self._predictor_action = None

        # protobuf descriptor cache
        self._inkling = InklingMessageFactory()

        # Maps the *previously received* server message type to the name
        # of the method that fills our next outgoing message. In predict
        # mode the types that should not occur map to _unsupported.
        self._dispatch_send = {
            ServerToSimulator.UNKNOWN:
                '_send_registration',
            ServerToSimulator.ACKNOWLEDGE_REGISTER:
                '_send_initial_state' if self._sim.predict else '_send_ready',
            ServerToSimulator.SET_PROPERTIES:
                '_unsupported' if self._sim.predict else '_send_ready',
            ServerToSimulator.START:
                '_unsupported' if self._sim.predict else '_send_initial_state',
            ServerToSimulator.PREDICTION:
                '_send_state',
            ServerToSimulator.RESET:
                '_unsupported' if self._sim.predict else '_send_ready',
            ServerToSimulator.STOP:
                '_unsupported' if self._sim.predict else '_send_ready',
        }

        # Maps an incoming server message type to its handler method name.
        self._dispatch_recv = {
            ServerToSimulator.ACKNOWLEDGE_REGISTER: '_on_acknowledge_register',
            ServerToSimulator.SET_PROPERTIES: '_on_set_properties',
            ServerToSimulator.START: '_on_start',
            ServerToSimulator.PREDICTION: '_on_prediction',
            ServerToSimulator.RESET: '_on_reset',
            ServerToSimulator.STOP: '_on_stop',
            ServerToSimulator.FINISHED: '_on_finished'
        }

    def _new_state_message(self):
        """
        Generate an InklingMessage for holding simulator state

        :return: state message built from the server-provided output schema
        """
        return self._inkling.new_message_from_proto(self._output_schema)

    def _send_registration(self, to_server):
        """Fill ``to_server`` with a REGISTER message carrying our name."""
        log.simulator_ws('Sending Registration')
        to_server.message_type = SimulatorToServer.REGISTER
        to_server.register_data.simulator_name = self.name

    def _send_ready(self, to_server):
        """Fill ``to_server`` with a READY message for this sim id."""
        log.simulator_ws('Sending Ready')
        to_server.message_type = SimulatorToServer.READY
        to_server.sim_id = self._sim_id

    def _send_initial_state(self, to_server):
        """Report the initial state previously filled in by run()."""
        log.simulator_ws('Sending initial State')
        to_server.message_type = SimulatorToServer.STATE
        to_server.sim_id = self._sim_id
        state = to_server.state_data.add()
        state.state = self._initial_state.SerializeToString()
        state.reward = 0.0
        state.terminal = False
        self._prev_step_terminal[0] = state.terminal
        # action_taken is deliberately left unset: no action produced the
        # initial state.

    def _send_state(self, to_server):
        """
        Pack every simulated step of the current batch into a STATE
        message, then clear the batch and its iterator.
        """
        log.simulator_ws('Sending State')
        to_server.message_type = SimulatorToServer.STATE
        to_server.sim_id = self._sim_id
        for step in self._sim_steps:
            if step.state:
                state = to_server.state_data.add()
                state.state = step.state.SerializeToString()
                state.reward = step.reward
                state.terminal = step.terminal
                state.action_taken = step.prediction
            else:
                log.simulator("WARNING: Missing step in send_state")
        self._sim_steps = []
        self._step_iter = None

    def _unsupported(self, to_server):
        """
        Send handler for message types not expected in the current mode.

        :raises BonsaiServerError: always, naming the unexpected type.
        """
        descriptor = ServerToSimulator.MessageType.DESCRIPTOR
        raise BonsaiServerError("Unexpected Message during {}: {}".format(
            "prediction" if self._sim.predict else "training",
            descriptor.values_by_number[self._prev_message_type].name))

    def _on_acknowledge_register(self, from_server):
        """
        Store the schemas and sim id assigned by the server, and configure
        the simulator's writer (if any) with the schema field names.
        """
        log.simulator_ws('Acknowledging Registration')
        data = from_server.acknowledge_register_data
        self._properties_schema = data.properties_schema
        self._output_schema = data.output_schema
        self._prediction_schema = data.prediction_schema
        if self._sim.writer is not None:
            self._configure_writer()
        self._sim_id = data.sim_id

    def _on_set_properties(self, from_server):
        """
        Store the prediction schema, reward name and the initial
        properties used for subsequent episode starts.
        """
        log.simulator_ws('Setting properties')
        data = from_server.set_properties_data
        self._prediction_schema = data.prediction_schema
        self.objective_name = data.reward_name
        dynamic_properties = data.dynamic_properties
        properties_message = self._inkling.message_for_dynamic_message(
            dynamic_properties, self._properties_schema)
        self._init_properties = dict_for_message(properties_message)

    def _on_start(self, from_server):
        """START carries no data; get_next_event() creates the event."""
        pass

    def _on_prediction(self, from_server):
        """
        Queue one SimStep per prediction in the incoming batch and reset
        the iterator consumed by _process_sim_step().
        """
        log.simulator_ws('On Prediction')
        for p_data in from_server.prediction_data:
            step = self.SimStep()
            step.prediction = p_data.dynamic_prediction
            self._sim_steps.append(step)
            # Convert server msg to action dict and save it for predictor
            self._cache_action_for_predictor(step.prediction)
        self._step_iter = iter(self._sim_steps)

    def _on_reset(self, from_server):
        """RESET carries no data to record."""
        pass

    def _on_stop(self, from_server):
        """
        STOP carries no data. The episode-finish notification is produced
        by get_next_event(), which emits an EpisodeFinishEvent unless a
        terminal step already finished the episode.
        """
        pass

    def _on_finished(self, from_server):
        """FINISHED carries no data; run() closes the connection."""
        pass

    def _dump_message(self, message, fname):
        """Helper function for dumping protobuf message contents"""
        with open(fname, 'wb') as f:
            f.write(message.SerializeToString())

    def _on_send(self, to_server):
        """
        Message handler for sending messages to the server: fill
        ``to_server`` according to the last message type received. Types
        without a send handler (e.g. FINISHED) only log.
        """
        method_name = self._dispatch_send.get(self._prev_message_type,
                                              'default')
        method = getattr(self, method_name,
                         lambda _to_server: log.simulator("Finished"))
        method(to_server)

    def _on_recv(self, from_server):
        """
        Message handler for server messages: dispatch to the matching
        ``_on_*`` method and remember the type for the next send.

        :raises BonsaiServerError: for a message type with no handler.
        """
        def _raise(msg):
            raise BonsaiServerError(
                "Received unknown message ({}) from server".format(
                    msg.message_type))

        method_name = self._dispatch_recv.get(from_server.message_type,
                                              'default')
        method = getattr(self, method_name, _raise)
        method(from_server)
        self._prev_message_type = from_server.message_type

    def _cache_action_for_predictor(self, prediction):
        """
        Converts a server prediction into an action dictionary and saves
        it for the predictor class
        """
        action_message = self._inkling.message_for_dynamic_message(
            prediction, self._prediction_schema)
        self._predictor_action = dict_for_message(action_message)

    @gen.coroutine
    def _connect(self):
        """
        Fire up a websocket connection to the prediction or simulation
        URL of the brain.

        :return: (via ``gen.Return``) None on success, otherwise the
                 ``repr`` of the exception that prevented connecting.
        """
        try:
            if self._sim.predict is True:
                url = self.brain._prediction_url()
            else:
                url = self.brain._simulation_url()
            log.info("trying to connect: {}".format(url))
            req = HTTPRequest(url,
                              connect_timeout=_CONNECT_TIMEOUT_SECS,
                              request_timeout=_CONNECT_TIMEOUT_SECS)
            req.headers['Authorization'] = self.brain.config.accesskey
            req.headers['User-Agent'] = self.brain._user_info
            self._ws = yield websocket_connect(req)
        except Exception as e:
            raise gen.Return(repr(e))
        else:
            raise gen.Return(None)

    def _configure_writer(self):
        """Enable writer keys for every schema field plus fixed metadata."""
        self._sim.writer.enable_keys(
            self._fields_for_schema(self._properties_schema), 'config')
        self._sim.writer.enable_keys(
            self._fields_for_schema(self._prediction_schema), 'action')
        self._sim.writer.enable_keys(
            self._fields_for_schema(self._output_schema), 'state')
        self._sim.writer.enable_keys(
            ['reward', 'terminal', 'time', 'simulator', 'predict', 'sim_id'])
        self._sim.writer.enable_keys([
            'episode_reward', 'episode_count', 'episode_rate',
            'iteration_count', 'iteration_rate'
        ], 'statistics')

    def _fields_for_schema(self, schema):
        """Return the field names of the message described by ``schema``."""
        msg = self._inkling.new_message_from_proto(schema)
        return [f.name for f in msg.DESCRIPTOR.fields]

    @gen.coroutine
    def _ws_send_recv(self):
        """
        One websocket round trip: send the message chosen by _on_send (if
        it set a message type), then read one reply and dispatch it.

        :raises BonsaiServerError: if the websocket is closed while
                                   writing or reading.
        """
        to_server = SimulatorToServer()
        self._on_send(to_server)
        log.pb("to_server: {}".format(MessageToJson(to_server)))

        if to_server.message_type:
            out_bytes = to_server.SerializeToString()
            try:
                yield self._ws.write_message(out_bytes, binary=True)
            except (StreamClosedError, WebSocketClosedError):
                raise BonsaiServerError(
                    "Websocket connection closed. Code: {}, Reason: {}".format(
                        self._ws.close_code, self._ws.close_reason))

        # read response from server; tornado yields None once closed
        in_bytes = yield self._ws.read_message()
        if in_bytes is None:
            raise BonsaiServerError(
                "Websocket connection closed. Code: {}, Reason: {}".format(
                    self._ws.close_code, self._ws.close_reason))

        from_server = ServerToSimulator()
        from_server.ParseFromString(in_bytes)
        log.pb("from_server: {}".format(MessageToJson(from_server)))
        self._on_recv(from_server)

    @gen.coroutine
    def close_connection(self):
        """ Close the websocket connection, if one was ever opened. """
        if self._ws is not None:
            yield self._ws.close()

    def _process_sim_step(self):
        """
        Turn the next queued SimStep into an event.

        :return: EpisodeStartEvent if the previous step finished an
                 episode (the step's prediction is then not applied and its
                 state slot receives the new initial state), SimulateEvent
                 otherwise, or None when there is no batch or it is
                 exhausted.
        """
        # _send_state() clears the iterator; next(None) would raise
        # TypeError rather than signal "no more steps".
        if self._step_iter is None:
            return None
        try:
            step = next(self._step_iter)
        except StopIteration:
            return None

        step.state = self._new_state_message()
        if self._prev_step_finish:
            self._prev_step_finish = False
            return EpisodeStartEvent(self._init_properties, step.state)

        action_message = self._inkling.message_for_dynamic_message(
            step.prediction, self._prediction_schema)
        action = dict_for_message(action_message)
        return SimulateEvent(action, step, self._prev_step_terminal)

    @gen.coroutine
    def get_next_event(self):
        """
        Update the internal event machine and return the next event for
        processing.

        While a prediction batch is pending, events come from the batch
        (with an EpisodeFinishEvent injected after a terminal step); only
        when it is exhausted is the next server round trip performed.

        :raises BonsaiServerError: if the websocket cannot be opened.
        """
        # Grab a web socket connection if needed
        if self._ws is None:
            message = yield self._connect()
            # If the connection failed, report
            if message is not None:
                raise BonsaiServerError(
                    "Error while connecting to websocket: {}".format(message))

        if self._prev_message_type == ServerToSimulator.PREDICTION:
            if self._prev_step_terminal[0]:
                self._prev_step_terminal[0] = False
                self._prev_step_finish = True
                event = EpisodeFinishEvent()
            else:
                event = self._process_sim_step()
            if event is not None:
                raise gen.Return(event)

        yield self._ws_send_recv()

        # Map the message just received to the event run() should execute.
        # This must be a single if/elif chain: a separate `if` here used to
        # overwrite the predict-mode EpisodeStartEvent with UnknownEvent.
        pmt = self._prev_message_type
        if pmt == ServerToSimulator.ACKNOWLEDGE_REGISTER:
            if self._sim.predict:
                self._initial_state = self._new_state_message()
                event = EpisodeStartEvent(self._init_properties,
                                          self._initial_state)
                self._prev_step_finish = False
            else:
                event = UnknownEvent()
        elif (pmt == ServerToSimulator.SET_PROPERTIES or
              pmt == ServerToSimulator.RESET):
            event = UnknownEvent()
        elif pmt == ServerToSimulator.STOP:
            # If a terminal step already produced EpisodeFinishEvent, do
            # not finish the episode a second time.
            if self._prev_step_finish:
                event = UnknownEvent()
                self._prev_step_finish = False
            else:
                event = EpisodeFinishEvent()
        elif pmt == ServerToSimulator.START:
            self._initial_state = self._new_state_message()
            event = EpisodeStartEvent(self._init_properties,
                                      self._initial_state)
            self._prev_step_finish = False
        elif pmt == ServerToSimulator.PREDICTION:
            event = self._process_sim_step()
        elif pmt == ServerToSimulator.FINISHED:
            event = FinishedEvent()
        else:
            event = UnknownEvent()

        raise gen.Return(event)

    @gen.coroutine
    def run(self):
        """
        Run loop called from Simulator. Encapsulates one round trip to the
        backend, which might include a simulation loop.

        :return: (via ``gen.Return``) False after a FinishedEvent (the
                 connection is closed), True otherwise.
        :raises EpisodeStartError, SimulateError, EpisodeFinishError:
                 wrapping exceptions raised by the simulator hooks.
        """
        event = yield self.get_next_event()

        if isinstance(event, EpisodeStartEvent):
            log.event("Episode Start")
            try:
                state = self._sim._on_episode_start(event.initial_properties)
            except Exception as e:
                raise EpisodeStartError(e)
            event.initial_state = state
            log.simulator("initial state: {}".format(event.initial_state))
            log.simulator_ws('\tES')
        elif isinstance(event, SimulateEvent):
            log.event("Simulate")
            try:
                event.state, event.reward, event.terminal = \
                    self._sim._on_simulate(event.action)
            except Exception as e:
                raise SimulateError(e)
            log.simulator_ws('\tT' if event.terminal else '\tS')
            log.simulator("state: {}".format(event.state))
        elif isinstance(event, EpisodeFinishEvent):
            log.event("Episode Finish")
            try:
                self._sim._on_episode_finish()
            except Exception as e:
                raise EpisodeFinishError(e)
            log.simulator_ws('\tF')
        elif isinstance(event, FinishedEvent):
            log.event("Finished")
            # Wait for the close so failures surface instead of being lost
            # in an unobserved future.
            yield self.close_connection()
            raise gen.Return(False)
        elif isinstance(event, UnknownEvent):
            log.event("No Operation")

        raise gen.Return(True)