class HumanLearner(BaseLearner):
    '''
    A learner whose answers are typed by a human through a view.

    Bits sent by the environment are decoded by an InputChannel. When the
    decoded message ends in two consecutive silence tokens, nothing is queued
    for output and the human is not already speaking, the view is asked for a
    line of text. That text is emitted one bit at a time through an
    OutputChannel; whenever the output buffer is empty a silence token is
    queued instead.
    '''

    def __init__(self, serializer):
        '''
        Takes the serialization protocol.

        :param serializer: serializer shared by the input and output channels;
            its SILENCE_TOKEN is used both to pad output and to detect silence.
        '''
        self._serializer = serializer
        self._input_channel = InputChannel(serializer)
        self._output_channel = OutputChannel(serializer)
        self._input_channel.message_updated.register(self.on_message)
        self.test_mode = False
        self.logger = logging.getLogger(__name__)
        self.speaking = False
        # Provided later through set_view(); ask_for_input() cannot work
        # without it.
        self._view = None

    def set_view(self, view):
        '''
        Sets the user interface to get the user input.

        :param view: object whose ``get_input()`` returns the typed text.
        '''
        self._view = view

    def reward(self, reward):
        '''
        Logs the received reward and clears both channels.

        :param reward: reward value given by the environment.
        '''
        self.logger.info("Reward received: {0}".format(reward))
        self._input_channel.clear()
        self._output_channel.clear()

    def next(self, input):
        '''
        Exchanges one bit with the environment.

        :param input: the bit produced by the environment.
        :returns: the next bit of the learner's output.
        '''
        # If the buffer is empty, fill it with silence
        if self._output_channel.is_empty():
            self.logger.debug("Output buffer is empty, filling with silence")
            # Add one silence token to the buffer
            self._output_channel.set_message(self._serializer.SILENCE_TOKEN)
        # Get the bit to return
        output = self._output_channel.consume_bit()
        # Interpret the bit coming from the environment
        self._input_channel.consume_bit(input)
        return output

    def on_message(self, message):
        '''
        Observer for the input channel's decoded message.

        :param message: the text decoded so far from the environment.
        '''
        # we ask for input on two consecutive silences
        two_silences = self._serializer.SILENCE_TOKEN * 2
        if (message[-2:] == two_silences and
                self._output_channel.is_empty() and not self.speaking):
            self.ask_for_input()
        elif self._output_channel.is_empty():
            # If we were speaking, we are not speaking anymore
            self.speaking = False

    def ask_for_input(self):
        '''
        Prompts the human through the view and queues the answer for output.

        Empty answers are replaced by a single space, and runs of dots are
        collapsed into a single dot.

        :raises RuntimeError: if no view has been set with set_view().
        '''
        if self._view is None:
            raise RuntimeError(
                "HumanLearner has no view; call set_view() before "
                "asking for input")
        output = self._view.get_input()
        self.logger.debug(
            "Received input from the human: '{0}'".format(output))
        # by default just send a space
        if not output:
            output = ' '
        # output is guaranteed non-empty here, so we are now speaking
        self.speaking = True
        # raw string: '\.' in a plain literal is an invalid escape sequence
        output = re.sub(r'\.+', '.', output)
        self._output_channel.set_message(output)
class Environment:
    '''
    The Environment is the one that communicates with the Learner,
    interpreting its output and reacting to it. The interaction is governed
    by an ongoing task which is picked by a TaskScheduler object.

    :param serializer: a Serializer object that translates text into binary
        and back.
    :param task_scheduler: a TaskScheduler object that determines which task
        is going to be run next.
    :param scramble: if True, the words outputted by the tasks are randomly
        scrambled.
    :param max_reward_per_task: maximum amount of reward that a learner can
        receive for a given task.
    '''

    def __init__(self, serializer, task_scheduler, scramble=False,
                 max_reward_per_task=10):
        # save parameters into member variables
        self._task_scheduler = task_scheduler
        self._serializer = serializer
        self._max_reward_per_task = max_reward_per_task
        # cumulative reward per task (keyed by task name)
        self._reward_per_task = defaultdict(int)
        # the event manager is the controller that dispatches
        # changes in the environment (like new inputs or state changes)
        # to handler functions in the tasks that tell the environment
        # how to react
        self.event_manager = EventManager()
        # initialize member variables
        self._current_task = None
        self._current_world = None
        # we hear our own output; note this listener is built before the
        # optional scrambling wrapper, so it uses the plain serializer
        self._output_channel_listener = InputChannel(serializer)
        if scramble:
            serializer = ScramblingSerializerWrapper(serializer)
        # output channel
        self._output_channel = OutputChannel(serializer)
        # input channel
        self._input_channel = InputChannel(serializer)
        # priority of ongoing message
        self._output_priority = 0
        # reward that is to be given to the learner at the end of the task
        self._reward = None
        # Current task time
        self._task_time = None
        # Task separator issued
        self._task_separator_issued = False
        # Internal logger
        self.logger = logging.getLogger(__name__)

        # signals
        self.world_updated = Observable()
        self.task_updated = Observable()

        # Register channel observers
        self._input_channel.sequence_updated.register(
            self._on_input_sequence_updated)
        self._input_channel.message_updated.register(
            self._on_input_message_updated)
        self._output_channel_listener.sequence_updated.register(
            self._on_output_sequence_updated)
        self._output_channel_listener.message_updated.register(
            self._on_output_message_updated)

    def next(self, learner_input):
        '''
        Main loop of the Environment. Receives one bit from the learner and
        produces a response (also one bit).

        :param learner_input: the learner's bit, or None if there is none.
        :returns: a tuple ``(output, reward)``. ``output`` is the bit sent
            back to the learner. ``reward`` is None except on the step where
            the environment switches to a new task; on that step it is the
            reward earned by the finished task, after the per-task cap.
        '''
        # Make sure we have a task
        if not self._current_task:
            self._switch_new_task()

        # If the task has not reached the end by either Timeout or
        # achieving the goal
        if not self._current_task.has_ended():
            # Check if a Timeout occurred
            self._current_task.check_timeout(self._task_time)
            # Process the input from the learner and raise events
            if learner_input is not None:
                # record the input from the learner and deserialize it
                # TODO this bit is dropped otherwise on a timeout...
                self._input_channel.consume_bit(learner_input)
            # fill the output buffer if the task hasn't produced any output
            if self._output_channel.is_empty():
                # demand for some output from the task (usually, silence)
                self._output_channel.set_message(
                    self._current_task.get_default_output())
                # remember if this was acting as a task separator
                if self._current_task.has_ended():
                    self._task_separator_issued = True
            # We are in the middle of the task, so no rewards are given
            reward = None
        elif self._output_channel.is_empty():
            # The task has ended and there is nothing else to say:
            # issue a silence, then return the reward and move to next task
            # TODO: task separation should be implemented at the task-level
            if self._task_separator_issued:
                # I have nothing to say, I have nothing to say, I have...
                raw_reward = self._reward if self._reward is not None else 0
                # Apply the per-task cap while _current_task still refers to
                # the task that earned the reward: _switch_new_task() below
                # replaces it, and capping afterwards would charge the reward
                # to the next task.
                reward = self._allowable_reward(raw_reward)
                self._switch_new_task()
                self._task_separator_issued = False
            else:
                self._output_channel.set_message(
                    self._current_task.get_default_output())
                self._task_separator_issued = True
                reward = None
        else:
            # Do Nothing until the output channel is empty
            reward = None

        # Get one bit from the output buffer and ship it
        output = self._output_channel.consume_bit()
        # we hear ourselves (WARNING: this can still generate behavior
        # in the task via the OutputMessageUpdated event)
        self._output_channel_listener.consume_bit(output)
        # advance time
        self._task_time += 1

        if reward is not None:
            # notify the scheduler (same point in the step as before)
            self._task_scheduler.reward(reward)

        return output, reward

    def get_reward_per_task(self):
        '''
        Returns a dictionary that contains the cumulative reward for each
        task, keyed by task name.
        '''
        return self._reward_per_task

    def _allowable_reward(self, reward):
        '''
        Checks if the reward is allowed within the limits of the
        `max_reward_per_task` parameter for the *current* task, and resets it
        to 0 if not.

        The check is made against the total accumulated *before* adding, so
        a single large reward can still push the total past the cap.

        :param reward: the reward the current task wants to grant.
        :returns: ``reward`` if granted (and accumulated), otherwise 0.
        '''
        task_name = self._current_task.get_name()
        if self._reward_per_task[task_name] < self._max_reward_per_task:
            self._reward_per_task[task_name] += reward
            return reward
        else:
            return 0

    def is_silent(self):
        '''
        Tells if the environment is sending any information through the
        output channel (delegates to the output channel's is_silent()).
        '''
        return self._output_channel.is_silent()

    def _on_input_sequence_updated(self, sequence):
        if self.event_manager.raise_event(SequenceReceived(sequence)):
            self.logger.debug("Sequence received by running task: '{0}'".format(
                sequence))

    def _on_input_message_updated(self, message):
        # send the current received message to the task
        if self.event_manager.raise_event(MessageReceived(
                message)):
            self.logger.debug("Message received by running task: '{0}'".format(
                message))

    def _on_output_sequence_updated(self, sequence):
        self.event_manager.raise_event(OutputSequenceUpdated(sequence))

    def _on_output_message_updated(self, message):
        self.event_manager.raise_event(OutputMessageUpdated(message))

    def set_reward(self, reward, message='', priority=0):
        '''
        Sets the reward that is going to be given to the learner once the
        task has sent all the remaining message. Also ends the current task
        and queues ``message`` with the given priority.
        '''
        self._reward = reward
        self._current_task.end()
        self.logger.debug('Setting reward {0} with message "{1}"'
                          ' and priority {2}'
                          .format(reward, message, priority))
        self.set_message(message, priority)

    def set_message(self, message, priority=0):
        '''
        Saves the message in the output buffer so it can be delivered
        bit by bit. It overwrites any previous content, unless the buffer is
        still busy with a message of strictly higher priority.
        '''
        if self._output_channel.is_empty() or priority >= self._output_priority:
            self.logger.debug('Setting message "{0}" with priority {1}'
                              .format(message, priority))
            self._output_channel.set_message(message)
            self._output_priority = priority
        else:
            self.logger.info(
                'Message "{0}" blocked because of '
                'low priority ({1}<{2}) '.format(
                    message, priority, self._output_priority)
            )

    def raise_event(self, event):
        return self.event_manager.raise_event(event)

    def raise_state_changed(self):
        '''
        This raises a StateChanged Event, meaning that something
        in the state of the world or the tasks changed (but we
        don't keep track what).

        :returns: True if the event was raised, False if there is no started
            task.
        '''
        # state changed events can only be raised if the current task is
        # started
        if self._current_task and self._current_task.has_started():
            # tasks that have a world should also take the world state as
            # an argument
            if self._current_world:
                self.raise_event(StateChanged(
                    self._current_world.state, self._current_task.state))
            else:
                self.raise_event(StateChanged(self._current_task.state))
            return True
        return False

    def _deregister_current_task(self):
        # deregister previous event managers
        if self._current_task:
            self._current_task.deinit()
            self._deregister_task_triggers(self._current_task)

    def _on_task_ended(self, task):
        assert (task == self._current_task)
        # when a task ends, it doesn't process any more events
        self._deregister_current_task()

    def _switch_new_task(self):
        '''
        Asks the task scheduler for a new task,
        reset buffers and time, and registers the event handlers
        '''
        # pick a new task
        self._current_task = self._task_scheduler.get_next_task()
        # register to the ending event
        self._current_task.ended_updated.register(self._on_task_ended)
        try:
            # This is to check whether the user didn't mess up in instantiating
            # the class (calling an unbound method raises TypeError)
            self._current_task.get_world()
        except TypeError:
            raise RuntimeError("The task {0} is not correctly instantiated. "
                               "Are you sure you are not forgetting to "
                               "instantiate the class?".format(
                                   self._current_task))
        self.logger.debug("Starting new task: {0}".format(self._current_task))

        # check if it has a world:
        if self._current_task.get_world() != self._current_world:
            # if we had an ongoing world, end it.
            if self._current_world:
                self._current_world.end()
                self._deregister_task_triggers(self._current_world)
            self._current_world = self._current_task.get_world()
            if self._current_world:
                # register new event handlers for the world
                self._register_task_triggers(self._current_world)
                # initialize the new world
                self._current_world.start(self)
            self.world_updated(self._current_world)

        # reset state
        self._task_time = 0
        self._reward = None
        self._input_channel.clear()
        self._output_channel.clear()
        self._output_channel_listener.clear()

        # register new event handlers
        self._register_task_triggers(self._current_task)
        # start the task, sending the current environment
        # so it can interact by sending back rewards and messages
        self._current_task.start(self)
        self.task_updated(self._current_task)

    def _deregister_task_triggers(self, task):
        for trigger in task.get_triggers():
            try:
                self.event_manager.deregister(task, trigger)
            except (ValueError, KeyError):
                # if the trigger was not registered, we don't worry about it
                pass
        task.clean_dynamic_handlers()

    def _register_task_triggers(self, task):
        for trigger in task.get_triggers():
            self._register_task_trigger(task, trigger)

    def _register_task_trigger(self, task, trigger):
        self.event_manager.register(task, trigger)
class ConsoleView(BaseView):
    '''
    Curses view showing the environment/learner exchange.

    The environment's decoded text is drawn on the top row and the learner's
    on the row beneath it. Each row ends with the most recent bits that have
    not been decoded yet. When ``show_world`` is set, the current world is
    drawn further down. The view also prompts the human for input lines.
    '''

    def __init__(self, env, session, serializer, show_world=False):
        super(ConsoleView, self).__init__(env, session)
        # Local history of both streams, kept here because the channels are
        # cleared between tasks and we want the display to keep scrolling.
        self.input_buffer = ''
        self.output_buffer = ''
        # typing this string at the prompt fast-forwards the task clock
        self.panic = 'SKIP'
        # decoders: one for what the learner says, one for the environment
        self._learner_channel = InputChannel(serializer)
        self._env_channel = InputChannel(serializer)
        observers = (
            (self._learner_channel,
             self.on_learner_sequence_updated,
             self.on_learner_message_updated),
            (self._env_channel,
             self.on_env_sequence_updated,
             self.on_env_message_updated),
        )
        for channel, on_sequence, on_message in observers:
            channel.sequence_updated.register(on_sequence)
            channel.message_updated.register(on_message)
        if show_world:
            # redraw the world panel whenever the environment swaps worlds
            env.world_updated.register(self.on_world_updated)
        # feed the raw tokens observed by the session into the decoders
        session.env_token_updated.register(self.on_env_token_updated)
        session.learner_token_updated.register(self.on_learner_token_updated)
        del self.info['current_task']

    def on_env_token_updated(self, token):
        self._env_channel.consume_bit(token)

    def on_learner_token_updated(self, token):
        self._learner_channel.consume_bit(token)

    def _draw_learner_line(self):
        '''Redraws the learner row (history + pending bits).'''
        pending = self._learner_channel.get_undeserialized()
        line = self.channel_to_str(self.input_buffer, pending)
        self._win.addstr(self._learner_seq_y, 0, line.encode(code))
        self._win.refresh()

    def _draw_env_line(self):
        '''Redraws the environment row (history + pending bits).'''
        pending = self._env_channel.get_undeserialized()
        line = self.channel_to_str(self.output_buffer, pending)
        self._win.addstr(self._teacher_seq_y, 0, line.encode(code))
        self._win.refresh()

    def on_learner_message_updated(self, message):
        decoded = self._learner_channel.get_text()
        if decoded:
            # messages grow one character at a time: append only the newest
            self.input_buffer = (
                self.input_buffer + decoded[-1])[-self._scroll_msg_length:]
        self._draw_learner_line()

    def on_learner_sequence_updated(self, sequence):
        self._draw_learner_line()

    def on_env_message_updated(self, message):
        decoded = self._env_channel.get_text()
        if decoded:
            self.output_buffer = (
                self.output_buffer + decoded[-1])[-self._scroll_msg_length:]
        self._draw_env_line()

    def on_env_sequence_updated(self, sequence):
        self._draw_env_line()

    def on_world_updated(self, world):
        if not world:
            self._worldwin.clear()
        else:
            world.state_updated.register(self.on_world_state_updated)
            self._worldwin.addstr(0, 0, str(world))
        self._worldwin.refresh()

    def on_world_state_updated(self, world):
        self._worldwin.addstr(0, 0, str(world))
        self._worldwin.refresh()

    def initialize(self):
        '''Starts curses and lays out the sub-windows.'''
        self._stdscr = curses.initscr()
        origin_x, origin_y = 0, 0
        # row positions of the two scrolling stream lines
        self._teacher_seq_y, self._learner_seq_y = 0, 1
        # world panel origin
        self._world_win_y, self._world_win_x = 3, 0
        # info box (reward/time) size
        self._info_win_width, self._info_win_height = 20, 2
        # user input prompt position
        self._user_input_win_y, self._user_input_win_x = 2, 10
        self.height, self.width = self._stdscr.getmaxyx()
        self._scroll_msg_length = self.width - self._info_win_width - 1

        self._win = self._stdscr.subwin(
            self.height, self.width, origin_y, origin_x)
        self._worldwin = self._win.subwin(
            self.height - self._world_win_y,
            self.width - self._world_win_x,
            self._world_win_y,
            self._world_win_x)
        # info box with reward and time, pinned to the top-right corner
        self._info_win = self._win.subwin(
            self._info_win_height,
            self._info_win_width,
            0,
            self.width - self._info_win_width)
        self._user_input_win = self._win.subwin(
            1,
            self.width - self._user_input_win_x,
            self._user_input_win_y,
            self._user_input_win_x)
        self._user_input_label_win = self._win.subwin(
            1,
            self._user_input_win_x - 1,
            self._user_input_win_y,
            0)
        curses.noecho()
        curses.cbreak()

    def get_input(self):
        '''
        Reads one line typed by the user.

        Typing the panic word returns an empty string and sets the
        environment's task time to infinity.
        '''
        self._user_input_label_win.addstr(0, 0, 'input:')
        self._user_input_label_win.refresh()
        curses.echo()
        max_chars = self.width - self._user_input_win_x
        typed = self._user_input_win.getstr(0, 0, max_chars).decode(code)
        curses.noecho()
        self._user_input_win.clear()
        if typed != self.panic:
            return typed
        self._env._task_time = float('inf')
        return ''

    def channel_to_str(self, text, bits):
        '''Right-aligns ``text`` (underscore-padded) and appends the last
        seven pending bits in brackets.'''
        width = self._scroll_msg_length - 10
        visible_text = text[-width:]
        recent_bits = bits[-7:]
        return "{0:_>{length}}[{1: <8}]".format(visible_text, recent_bits,
                                                 length=width)
class HumanLearner(BaseLearner):
    """
    A learner whose answers are typed by a human through a view.

    Bits sent by the environment are decoded by an InputChannel. When the
    decoded message ends in two consecutive silence tokens, nothing is queued
    for output and the human is not already speaking, the view is asked for a
    line of text. That text is emitted one bit at a time through an
    OutputChannel; whenever the output buffer is empty a silence token is
    queued instead.
    """

    def __init__(self, serializer):
        """
        Takes the serialization protocol.

        :param serializer: serializer shared by the input and output channels;
            its SILENCE_TOKEN is used both to pad output and to detect silence.
        """
        self._serializer = serializer
        self._input_channel = InputChannel(serializer)
        self._output_channel = OutputChannel(serializer)
        self._input_channel.message_updated.register(self.on_message)
        self.logger = logging.getLogger(__name__)
        self.speaking = False
        # Provided through set_view(); ask_for_input() needs it.
        self._view = None

    def set_view(self, view):
        """
        Sets the user interface used to read the human's input.

        :param view: object whose ``get_input()`` returns the typed text.
        """
        self._view = view

    def reward(self, reward):
        """
        Logs the received reward and clears both channels.

        :param reward: reward value given by the environment.
        """
        self.logger.info("Reward received: {0}".format(reward))
        self._input_channel.clear()
        self._output_channel.clear()

    def next(self, user_input):
        """
        Exchanges one bit with the environment. If the output buffer is
        empty, it is first filled with one silence token.

        :param user_input: the bit produced by the environment.
        :return: the next bit of the learner's output.
        """
        if self._output_channel.is_empty():
            self.logger.debug("Output buffer is empty, filling with silence")
            # Add 1 silence token to buffer
            self._output_channel.set_message(self._serializer.SILENCE_TOKEN)
        # Get the bit to return
        output = self._output_channel.consume_bit()
        # Interpret the bit coming from the environment
        self._input_channel.consume_bit(user_input)
        return output

    def on_message(self, message):
        """
        Observer for the input channel's decoded message: we ask for input
        on two consecutive silences.

        :param message: the text decoded so far from the environment.
        """
        two_silences = self._serializer.SILENCE_TOKEN * 2
        if (message[-2:] == two_silences and
                self._output_channel.is_empty() and not self.speaking):
            self.ask_for_input()
        elif self._output_channel.is_empty():
            # If speaking, changes to speaking off
            self.speaking = False

    def ask_for_input(self):
        """
        Prompts the human through the view and queues a non-empty answer for
        output, collapsing runs of dots into a single dot. Empty answers are
        ignored.

        :raises RuntimeError: if no view has been set with set_view().
        """
        if self._view is None:
            raise RuntimeError(
                "HumanLearner has no view; call set_view() before "
                "asking for input")
        output = self._view.get_input()
        self.logger.debug(
            "Received input from the human: '{0}'".format(output))
        if output:
            self.speaking = True
            # raw string: '\.' in a plain literal is an invalid escape
            output = re.sub(r'\.+', '.', output)
            self._output_channel.set_message(output)