示例#1
0
class DialogEnv(gym.Env):
    """Gym environment that couples a user simulator, an error model and a state tracker.

    The agent picks a discrete action index; the environment maps it to an
    agent action, lets the simulated user answer, optionally passes the answer
    through the error model, and exposes the state tracker's representation as
    the observation.
    """

    def __init__(
        self,
        user_goals: List[UserGoal],
        emc_params: Dict,
        max_round_num: int,
        database: Dict,
        slot2values: Dict[str, List[Any]],
    ) -> None:
        """Create the collaborating components and declare the gym spaces.

        Args:
            user_goals: Goals handed to the ``UserSimulator``.
            emc_params: Parameters handed to the ``ErrorModelController``.
            max_round_num: Round limit shared by the user simulator and the
                state tracker.
            database: Database handed to the ``StateTracker``.
            slot2values: Slot-to-values mapping handed to the
                ``ErrorModelController``.
        """
        self.user = UserSimulator(user_goals, max_round_num)
        self.emc = ErrorModelController(slot2values, emc_params)
        self.state_tracker = StateTracker(database, max_round_num)

        # One discrete action per entry of AGENT_ACTIONS.
        n_actions = len(AGENT_ACTIONS)
        self.action_space = gym.spaces.Discrete(n_actions)

        # The observation is a binary vector whose length the tracker reports.
        state_size = self.state_tracker.get_state_size()
        self.observation_space = gym.spaces.multi_binary.MultiBinary(state_size)

    def step(self, agent_action_index: int):
        """Advance the dialogue by one agent turn and one user turn.

        Args:
            agent_action_index: Index understood by ``map_index_to_action``.

        Returns:
            A tuple ``(next_state, reward, done, success)``. Note that the
            fourth element is the user simulator's success value rather than
            the ``info`` dict of the standard gym API.
        """
        tracker = self.state_tracker
        chosen_action = map_index_to_action(agent_action_index)
        tracker.update_state_agent(chosen_action)

        user_reply, reward, done, success = self.user.step(chosen_action)

        # The error model is skipped for the user action that ends the episode.
        # Its return value is not used, so it presumably edits the action in
        # place -- TODO confirm against ErrorModelController.
        episode_continues = not done
        if episode_continues:
            self.emc.infuse_error(user_reply)

        tracker.update_state_user(user_reply)
        return tracker.get_state(done), reward, done, success

    def reset(self):
        """Start a new episode and return the initial observation.

        The tracker is reset first, then the user's opening action is passed
        through the error model and recorded in the tracker.
        """
        tracker = self.state_tracker
        tracker.reset()

        opening_action = self.user.reset()
        self.emc.infuse_error(opening_action)
        tracker.update_state_user(opening_action)

        return tracker.get_state()
示例#2
0
class Dialogue:
    """Dialogue loop connecting a state tracker, a user and a DQN agent.

    The user is either a rule-based simulator or an interactive user attached
    to a chat GUI; which one is used is decided per call to :meth:`run`.
    """

    def __init__(self, load_agent_model_from_directory: str = None):
        """Load resources and build the tracker, users, GUI and agent.

        Args:
            load_agent_model_from_directory: When truthy, passed to
                ``DQNAgent.load_agent_model`` after the agent is created.
        """
        # Load database of movies (if you get an error unpickling movie_db.pkl then run pickle_converter.py)
        # NOTE(security): unpickling can execute arbitrary code; movie_db.pkl
        # must come from a trusted source.
        with open("resources/movie_db.pkl", "rb") as db_file:
            database = pickle.load(db_file, encoding="latin1")

        # Create state tracker
        self.state_tracker = StateTracker(database)

        # Create user simulator with list of user goals
        with open("resources/movie_user_goals.json", "r", encoding="utf-8") as goals_file:
            user_goals = json.load(goals_file)
        self.user_simulated = RulebasedUsersim(user_goals)

        # Create GUI for direct text interactions
        self.gui = ChatApplication()

        # Create user instance for direct text interactions
        self.user_interactive = User(nlu_path="user/regex_nlu.json", use_voice=False, gui=self.gui)

        # Create empty user (will be assigned on runtime)
        self.user = None

        # Create agent
        self.agent = DQNAgent(alpha=0.001, gamma=0.9, epsilon=0.5, epsilon_min=0.05,
                              n_actions=len(feasible_agent_actions), n_ordinals=3,
                              observation_dim=(StateTracker.state_size()),
                              batch_size=256, memory_len=80000, prioritized_memory=True,
                              replay_iter=16, replace_target_iter=200)
        if load_agent_model_from_directory:
            self.agent.load_agent_model(load_agent_model_from_directory)

    def run(self, n_episodes, step_size=100, warm_up=False, interactive=False, learning=True):
        """Run ``n_episodes`` dialogues, optionally training the agent.

        Every environment step increments a global step counter. When
        ``learning`` is set, each transition is passed to ``agent.update`` with
        ``replay=True`` on every step whose counter is a multiple of
        ``agent.replay_iter``. Every ``step_size`` episodes (including episode
        0) the mean success and mean reward of the episodes since the previous
        report are printed; if that success rate beats the best one seen so
        far (and the run is learning and not warming up) the agent model is
        saved. The model is saved once more after the last episode under the
        same conditions.

        Args:
            n_episodes: Number of episodes to run.
            step_size: Reporting period in episodes; must be non-zero.
            warm_up: Forwarded to ``agent.choose_action`` and ``agent.update``;
                also disables ``agent.end_episode`` and model saving.
            interactive: Use the GUI-backed interactive user instead of the
                rule-based simulator.
            learning: When false, ``agent.epsilon`` is set to 0.0 and the agent
                is neither updated nor saved.
        """
        if interactive:
            self.user = self.user_interactive
            self.gui.window.update()
        else:
            self.user = self.user_simulated

        if not learning:
            self.agent.epsilon = 0.0

        batch_episode_rewards = []
        batch_successes = []
        batch_success_best = 0.0
        step_counter = 0

        for episode in range(n_episodes):
            self.episode_reset(interactive)
            done = False
            success = False
            episode_reward = 0

            # Initialize episode with first user and agent action
            prev_observation = self.state_tracker.get_state()
            # 1) Agent takes action given state tracker's representation of dialogue (observation)
            prev_agent_action = self.agent.choose_action(prev_observation, warm_up=warm_up)
            while not done:
                step_counter += 1
                # 2) 3) 4) 5) 6a)
                observation, reward, done, success = self.env_step(prev_agent_action, interactive)
                if learning:
                    replay = step_counter % self.agent.replay_iter == 0
                    # 6b) Add experience
                    self.agent.update(prev_observation, prev_agent_action, observation, reward, done,
                                      warm_up=warm_up, replay=replay)
                # 1) Agent takes action given state tracker's representation of dialogue (observation)
                # Note: this is also called once on the terminal observation;
                # that action is never executed.
                agent_action = self.agent.choose_action(observation, warm_up=warm_up)

                episode_reward += reward
                prev_observation = observation
                prev_agent_action = agent_action

            if not warm_up and learning:
                self.agent.end_episode(n_episodes)

            # Evaluation
            batch_episode_rewards.append(episode_reward)
            batch_successes.append(success)
            if episode % step_size == 0:
                # Check success rate
                success_rate = mean(batch_successes)
                avg_reward = mean(batch_episode_rewards)

                print('Episode: {} SUCCESS RATE: {} Avg Reward: {}'.format(episode, success_rate,
                                                                           avg_reward))
                if success_rate > batch_success_best and learning and not warm_up:
                    print('Episode: {} NEW BEST SUCCESS RATE: {} Avg Reward: {}'.format(episode, success_rate,
                                                                                        avg_reward))
                    self.agent.save_agent_model()
                    batch_success_best = success_rate
                # Start a fresh reporting window.
                batch_successes = []
                batch_episode_rewards = []

        if learning and not warm_up:
            # Save final model
            self.agent.save_agent_model()

    def env_step(self, agent_action, interactive=False):
        """Apply one agent action and collect the user's response.

        Args:
            agent_action: Action chosen by the agent.
            interactive: When true, the agent's utterance is shown in the GUI.

        Returns:
            A tuple ``(observation, reward, done, success)`` where ``success``
            is a plain ``bool`` that is true only when the user reported a
            success value equal to 1.
        """
        # 2) Update state tracker with the agent's action
        self.state_tracker.update_state_agent(agent_action)
        if interactive:
            self.gui.insert_message(agent_action.to_utterance(), "Shop Assistant")
        # 3) User takes action given agent action
        user_action, reward, done, success = self.user.get_action(agent_action)
        # 4) Infuse error into user action (currently inactive)
        # 5) Update state tracker with user action
        self.state_tracker.update_state_user(user_action)
        # 6a) Get next state
        observation = self.state_tracker.get_state(done)
        # Compare by value: identity checks against int literals depend on
        # interpreter caching and misclassify equal-but-distinct objects.
        return observation, reward, done, bool(success == 1)

    def episode_reset(self, interactive=False):
        """Reset tracker, user, agent turn counter and (optionally) the GUI.

        After the reset the user's opening action is requested with
        ``get_action(None)`` and recorded in the state tracker.

        Args:
            interactive: When true, the chat widget is cleared and a greeting
                is inserted.
        """
        # Reset the state tracker
        self.state_tracker.reset()
        # Reset the user
        self.user.reset()
        # Reset the agent
        self.agent.turn = 0
        # Reset the interactive GUI
        if interactive:
            self.gui.reset_text_widget()
            self.gui.insert_message("Guten Tag! Wie kann ich Ihnen heute helfen?", "Shop Assistant")
        # User start action
        user_action, _, _, _ = self.user.get_action(None)
        self.state_tracker.update_state_user(user_action)
示例#3
0
文件: get_four.py 项目: Helicqin/DQN
class GetFour:
    """Turn recorded dialogue episodes into ``[s_t, a_t, r, s_t+1, over]`` transitions.

    Episodes are read from a JSON file. Each episode is a list of turn dicts
    carrying at least ``speaker``, ``diaact``, ``inform_slots`` and
    ``request_slots``. While replaying the turns through a ``StateTracker`` the
    class also collects the set of distinct (delexicalised) agent actions in
    ``feasible_action``.
    """

    def __init__(self, path):
        """Load the episode file and set up the tracker and output containers.

        Args:
            path: Path of the JSON file holding the recorded episodes.
        """
        with open(path) as f:
            self.data = json.load(f)
        self.tracker = StateTracker()
        self.four = {}  # output result
        self.index = 0  # the index of output
        # Known agent actions, keyed by consecutive integer ids.
        self.feasible_action = {
            0: {
                'diaact': 'greeting',
                'inform_slots': {},
                'request_slots': {}
            },
            1: {
                'diaact': 'bye',
                'inform_slots': {},
                'request_slots': {}
            }
        }
        self.feasible_action_index = 2  # next free id in feasible_action

    def init_episode(self):
        """Reset all per-episode bookkeeping and the state tracker."""
        self.num_turns = 0  # the number of turns for a episode
        self.tracker.initialize_episode()
        self.episode_over = False
        self.reward = 0
        self.a_s_r_over_history = []  # action_state pairs history
        self.action = {}  # the action now
        self.state = {}
        self.episode_status = -1  # -1: running, 0: failed, 1: succeeded

    def get_a_s_r_over(self, episode_record):  # episode_record = [{},{},{}
        """Replay one episode and return one record per agent turn.

        Each record is a dict with keys ``"0"`` (agent action, delexicalised in
        place by :meth:`action_index`), ``"1"`` (tracker state before the
        action), ``"2"`` (cumulative reward) and ``"3"`` (episode-over flag).

        The second-to-last record receives the final reward and the over flag,
        because :meth:`update_four` uses the last record only as the
        next-state of its predecessor.

        Args:
            episode_record: List of turn dicts; agent turns are mutated.

        Returns:
            The list of records. It is empty when the episode has no agent
            turn; with a single agent turn no record is marked as terminal.
        """
        self.init_episode()
        self.num_turns = len(episode_record)
        for turn in episode_record:
            self.action = turn
            if turn["speaker"] == "agent":
                self.state = self.tracker.get_state()
                self.tracker.update(agent_action=turn)
                self.reward += self.reward_function(self.episode_status)
                # The recorded action is the same object that action_index
                # rewrites below, so the record holds the delexicalised form.
                a_s_r_over = {"3": False, "0": turn}
                self.action_index(turn)
                a_s_r_over["1"] = self.state
                a_s_r_over["2"] = self.reward
                if a_s_r_over["1"]['agent_action'] is None:
                    # No previous agent action yet: substitute a greeting.
                    a_s_r_over["1"]['agent_action'] = {
                        'diaact': 'greeting',
                        'inform_slots': {},
                        'request_slots': {}
                    }
                self.a_s_r_over_history.append(a_s_r_over)
            else:
                self.tracker.update(user_action=turn)
                self.reward += self.reward_function(self.episode_status)

        history = self.a_s_r_over_history
        if not history:
            # No agent turn at all: nothing to evaluate or mark.
            return history

        # when dialog over, update the latest reward
        self.episode_status = self.get_status(history[-1]["1"])
        self.reward += self.reward_function(self.episode_status)
        if len(history) >= 2:
            history[-2]["2"] = self.reward
            history[-2]["3"] = True
        return history

    # get four = [s_t, a_t, r, s_t+1, episode_over]
    def update_four(self, a_s_r_over_history):
        """Append one transition per consecutive pair of records to ``self.four``.

        Args:
            a_s_r_over_history: Records as produced by :meth:`get_a_s_r_over`.
                The last record contributes only its state (as ``s_t+1``).
        """
        for current, following in zip(a_s_r_over_history, a_s_r_over_history[1:]):
            self.four[self.index] = [
                current["1"],
                current["0"],
                current["2"],
                following["1"],
                current["3"],
            ]
            self.index += 1

    def get_four(self):
        """Process every episode with more than two turns and return ``self.four``."""
        for episode in self.data.values():
            if len(episode) <= 2:
                continue
            self.update_four(self.get_a_s_r_over(episode))
        return self.four

    def reward_function(self, episode_status):
        """Return the reward for a status: ``-num_turns`` on failure (0),
        ``2 * num_turns`` on success (1) and ``-1`` otherwise."""
        if episode_status == 0:  # dialog failed
            return -self.num_turns
        if episode_status == 1:  # dialog succeed
            return 2 * self.num_turns
        return -1

    def get_status(self, state):
        """Return 1 if ``phone_number`` is among the state's current inform
        slots, else 0 (including when there are no inform slots)."""
        inform_slots = state["current_slots"]["inform_slots"]
        return 1 if "phone_number" in inform_slots else 0

    # input: action   output: index of action and feasible_action
    def action_index(self, action):
        """Delexicalise ``action`` in place and register it if it is new.

        The ``speaker`` key is removed (if present) and every inform-slot value
        is replaced by ``'PLACEHOLDER'``. If no equal action is already in
        ``feasible_action`` the (same) dict is stored under the next free id.

        Args:
            action: Agent action dict; it is modified in place.
        """
        # pop() instead of del so an already-processed action does not raise.
        action.pop('speaker', None)
        inform_slots = action['inform_slots']
        for slot in inform_slots:
            inform_slots[slot] = 'PLACEHOLDER'
        if action not in self.feasible_action.values():
            self.feasible_action[self.feasible_action_index] = action
            self.feasible_action_index += 1