Code example #1
    def state_update(self, obs, action):
        '''Update the tracked dialog state from a new user observation.

        Args:
            obs: raw user input (utterance text, or a dialog act for
                act-in agents).
            action: the system action taken in the previous turn.

        Returns:
            A 3-tuple (input_act, state, encoded_state):
            input_act     -- NLU parse of obs, or obs itself when no NLU is set
            state         -- tracker state after update, or input_act when no DST
            encoded_state -- state passed through self.state_encoder, or state
                             unchanged when no encoder is set
        '''
        # Open a new history turn with the system action; the user side of
        # the turn is appended after state tracking, below.
        if self.dst:
            self.dst.state['history'].append([str(action)])

        # NLU parsing: parse obs against the flattened dialog history
        # (sum(..., []) concatenates the per-turn lists). Without an NLU
        # module, obs is already an act and is used directly.
        input_act = self.nlu.parse(
            obs,
            sum(self.dst.state['history'], [])
            if self.dst else []) if self.nlu else obs

        # state tracking
        state = self.dst.update(input_act) if self.dst else input_act

        # Close the turn opened above with the raw user observation.
        if self.dst:
            self.dst.state['history'][-1].append(str(obs))

        # encode state
        encoded_state = self.state_encoder.encode(
            state) if self.state_encoder else state

        # Record the user action on the tracker: the NLU parse when an NLU
        # is present, otherwise the raw obs. MDBTTracker is excluded in the
        # act-in case -- NOTE(review): presumably it manages user_action
        # itself; confirm against word_dst.MDBTTracker.
        if self.nlu and self.dst:
            self.dst.state['user_action'] = input_act
        elif self.dst and not isinstance(
                self.dst, word_dst.MDBTTracker):  # for act-in act-out agent
            self.dst.state['user_action'] = obs

        logger.nl(f'User utterance: {obs}')
        logger.act(f'Inferred user action: {input_act}')
        logger.state(f'Dialog state: {state}')

        return input_act, state, encoded_state
Code example #2
def gen_avg_result(agent, env, num_eval=NUM_EVAL):
    '''Run num_eval dialog sessions and return metrics averaged over them.

    Args:
        agent: the dialog agent under evaluation.
        env: environment providing clock and (optionally) an evaluator.
        num_eval: number of sessions to run (default NUM_EVAL).

    Returns:
        A 7-tuple (mean_return, mean_len, mean_success, mean_p, mean_r,
        mean_f1, mean_book_rate). The last five are None when no session
        produced that metric (e.g. no evaluator, or F1/book rate undefined).
    '''
    def _mean_or_none(values):
        # Average of collected values; None when nothing was recorded.
        return np.mean(values) if values else None

    returns, lens = [], []
    successes, precs, recs, f1s, book_rates = [], [], [], [], []
    for _ in range(num_eval):
        returns.append(gen_result(agent, env))
        lens.append(env.clock.t)
        if env.evaluator:
            successes.append(env.evaluator.task_success())
            _p, _r, _f1 = env.evaluator.inform_F1()
            # inform_F1 returns None components when F1 is undefined
            # for this session; only record complete triples.
            if _f1 is not None:
                precs.append(_p)
                recs.append(_r)
                f1s.append(_f1)
            _book = env.evaluator.book_rate()
            if _book is not None:
                book_rates.append(_book)
        elif hasattr(env, 'get_task_success'):
            # Fallback for environments that report success directly.
            successes.append(env.get_task_success())
        logger.nl('---A dialog session is done---')
    return (np.mean(returns), np.mean(lens),
            _mean_or_none(successes), _mean_or_none(precs),
            _mean_or_none(recs), _mean_or_none(f1s),
            _mean_or_none(book_rates))
Code example #3
 def run_rl(self):
     '''Run the main RL loop until clock.max_frame'''
     logger.info(
         f'Running RL loop for trial {self.spec["meta"]["trial"]} session {self.index}'
     )
     clock = self.env.clock
     obs = self.env.reset()
     clock.tick('t')
     self.agent.reset(obs)
     done = False
     while True:
         if util.epi_done(done):  # before starting another episode
             logger.nl(f'A dialog session is done')
             # Checkpoint at the episode boundary before any reset,
             # so episode-level stats are captured on the finished episode.
             self.try_ckpt(self.agent, self.env)
             if clock.get() < clock.max_frame:  # reset and continue
                 clock.tick('epi')
                 obs = self.env.reset()
                 self.agent.reset(obs)
                 done = False
         # Second ckpt call covers checkpoints that fall mid-episode.
         # NOTE(review): presumably try_ckpt is idempotent per frame, so the
         # back-to-back calls at episode boundaries are harmless -- confirm.
         self.try_ckpt(self.agent, self.env)
         if clock.get() >= clock.max_frame:  # finish
             break
         clock.tick('t')
         action = self.agent.act(obs)
         next_obs, reward, done, info = self.env.step(action)
         self.agent.update(obs, action, reward, next_obs, done)
         obs = next_obs
Code example #4
    def act(self, obs):
        '''Standard act method from algorithm.'''
        # Pick an action from the current encoded dialog state and
        # remember it on the body for downstream components.
        chosen = self.algorithm.act(self.body.encoded_state)
        self.body.action = chosen

        # Decode the raw action into a natural-language utterance; the
        # intermediate act form is not needed here.
        _, utterance = self.action_decode(chosen, self.body.state)

        logger.act(f'System action: {chosen}')
        logger.nl(f'System utterance: {utterance}')

        return utterance