Exemplo n.º 1
0
    def __init__(self, load_agent_model_from_directory: str = None):
        """Build the state tracker, users, GUI and DQN agent for a dialogue session.

        Args:
            load_agent_model_from_directory: Optional directory handed to
                ``self.agent.load_agent_model``; when falsy, no model is loaded.
        """
        # Load database of movies (if you get an error unpickling movie_db.pkl then run pickle_converter.py)
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # only use it with a trusted movie_db.pkl.
        with open("resources/movie_db.pkl", "rb") as database_file:
            database = pickle.load(database_file, encoding="latin1")

        # Create state tracker
        self.state_tracker = StateTracker(database)

        # Create user simulator with list of user goals
        with open("resources/movie_user_goals.json", "r", encoding="utf-8") as goals_file:
            user_goals = json.load(goals_file)
        self.user_simulated = RulebasedUsersim(user_goals)

        # Create GUI for direct text interactions
        self.gui = ChatApplication()

        # Create user instance for direct text interactions
        self.user_interactive = User(nlu_path="user/regex_nlu.json", use_voice=False, gui=self.gui)

        # Create empty user (will be assigned on runtime)
        self.user = None

        # Create agent
        self.agent = DQNAgent(alpha=0.001, gamma=0.9, epsilon=0.5, epsilon_min=0.05,
                              n_actions=len(feasible_agent_actions), n_ordinals=3,
                              observation_dim=(StateTracker.state_size()),
                              batch_size=256, memory_len=80000, prioritized_memory=True,
                              replay_iter=16, replace_target_iter=200)
        if load_agent_model_from_directory:
            self.agent.load_agent_model(load_agent_model_from_directory)
Exemplo n.º 2
0
    def setUp(self):
        """Create a fresh tracker, a branch stack and two sample branches."""
        self.testTracker = StateTracker()
        self.branchStack = BranchStack()

        # Two small frames that share the key "b" with different values.
        first_frame = {"a": 1, "b": 2, "c": 3}
        second_frame = {"d": 4, "b": 3}
        self.frame1, self.frame2 = first_frame, second_frame

        # Wrap each frame in an interaction state ...
        first_state = InteractionState(first_frame)
        second_state = InteractionState(second_frame)
        self.IS1, self.IS2 = first_state, second_state

        # ... and each interaction state in a named branch.
        self.B1 = Branch("N.E.D.", first_state)
        self.B2 = Branch("U.W.S.", second_state)
Exemplo n.º 3
0
    def __init__(
        self,
        user_goals: List[UserGoal],
        emc_params: Dict,
        max_round_num: int,
        database: Dict,
        slot2values: Dict[str, List[Any]],
    ) -> None:
        """Create the user simulator, error model, state tracker and gym spaces.

        Args:
            user_goals: Goals forwarded to ``UserSimulator``.
            emc_params: Parameters forwarded to ``ErrorModelController``.
            max_round_num: Round limit forwarded to the user simulator and
                the state tracker.
            database: Database forwarded to ``StateTracker``.
            slot2values: Slot-to-values mapping forwarded to
                ``ErrorModelController``.
        """
        # Collaborators that drive a dialogue.
        self.user = UserSimulator(user_goals, max_round_num)
        self.emc = ErrorModelController(slot2values, emc_params)
        tracker = StateTracker(database, max_round_num)
        self.state_tracker = tracker

        # One discrete action per entry of AGENT_ACTIONS.
        n_actions = len(AGENT_ACTIONS)
        self.action_space = gym.spaces.Discrete(n_actions)

        # Observation space sized by the tracker's reported state size.
        state_size = tracker.get_state_size()
        self.observation_space = gym.spaces.multi_binary.MultiBinary(state_size)
Exemplo n.º 4
0
 def __init__(self, path):
     """Load recorded episodes from *path* and set up the bookkeeping state.

     Args:
         path: Location of a JSON file; its parsed content is stored in
             ``self.data``.
     """
     with open(path) as episode_file:
         self.data = json.load(episode_file)
     self.tracker = StateTracker()
     self.four = {}  # output result
     self.index = 0  # the index of output
     # Seed the action table with the two slot-free dialogue acts.
     self.feasible_action = {
         idx: {'diaact': act, 'inform_slots': {}, 'request_slots': {}}
         for idx, act in enumerate(('greeting', 'bye'))
     }
     # Next free key in the action table.
     self.feasible_action_index = len(self.feasible_action)
Exemplo n.º 5
0
class DialogEnv(gym.Env):
    """Gym environment wiring a user simulator, an error model and a state tracker."""

    def __init__(
        self,
        user_goals: List[UserGoal],
        emc_params: Dict,
        max_round_num: int,
        database: Dict,
        slot2values: Dict[str, List[Any]],
    ) -> None:
        """Create the collaborators and declare the action/observation spaces."""
        self.user = UserSimulator(user_goals, max_round_num)
        self.emc = ErrorModelController(slot2values, emc_params)
        tracker = StateTracker(database, max_round_num)
        self.state_tracker = tracker

        # One discrete action per entry of AGENT_ACTIONS.
        n_actions = len(AGENT_ACTIONS)
        self.action_space = gym.spaces.Discrete(n_actions)

        # Observation space sized by the tracker's reported state size.
        state_size = tracker.get_state_size()
        self.observation_space = gym.spaces.multi_binary.MultiBinary(state_size)

    def step(self, agent_action_index: int):
        """Apply one agent action and return ``(next_state, reward, done, success)``."""
        tracker = self.state_tracker
        chosen_action = map_index_to_action(agent_action_index)
        tracker.update_state_agent(chosen_action)

        user_reply, reward, finished, succeeded = self.user.step(chosen_action)
        # Error is only infused while the dialogue is still running.
        if not finished:
            self.emc.infuse_error(user_reply)
        tracker.update_state_user(user_reply)

        return tracker.get_state(finished), reward, finished, succeeded

    def reset(self):
        """Start a new episode and return the state after the first user action."""
        tracker = self.state_tracker
        tracker.reset()

        opening_action = self.user.reset()
        self.emc.infuse_error(opening_action)
        tracker.update_state_user(opening_action)

        return tracker.get_state()
Exemplo n.º 6
0
    def simulate(self,
                 num_steps,
                 measurement_speed,
                 is_graphing=False):  # TODO: implement graphing
        """Run ``num_steps`` iterations of ``step`` followed by ``spin_strategy``.

        Args:
            num_steps: Number of iterations; step indices run from 0 upwards.
            measurement_speed: Forwarded unchanged to every ``self.step`` call.
            is_graphing: Accepted but currently has no effect.
        """
        for step_index in range(num_steps):
            self.step(step_index, measurement_speed)
            self.spin_strategy()

        if not is_graphing:
            return
        # Graphing is not implemented yet; nothing further happens here.

    def create_plot(self):  # TODO: implement
        """Placeholder for plotting; currently does nothing and returns None."""
        return None


if __name__ == "__main__":
    # Simulation settings.
    num_steps = 10
    measurement_speed = 10

    # Bounds handed to both location trackers.
    x_min, x_max = -2, 2
    y_min, y_max = -2, 2

    # Initial state vector [0, 1, -1, 0] scaled by 1/sqrt(2), and a
    # two-element zero location shared by both trackers.
    init_state = (1 / math.sqrt(2)) * np.array([0, 1, -1, 0])
    init_loc = np.zeros(2)

    state_tracker = StateTracker(init_state)
    qubit1 = LocationTracker(init_loc, x_min, x_max, y_min, y_max, num_steps)
    qubit2 = LocationTracker(init_loc, x_min, x_max, y_min, y_max, num_steps)

    sim = Simulator(state_tracker, qubit1, qubit2)
    sim.simulate(num_steps, measurement_speed)
Exemplo n.º 7
0
    remove_empty_slots(database)

    # Load movie dict
    db_dict = pickle.load(open(DICT_FILE_PATH, 'rb'), encoding='latin1')

    # Load goal File
    user_goals = pickle.load(open(USER_GOALS_FILE_PATH, 'rb'),
                             encoding='latin1')

    # Init. Objects
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    # sarsa_agent = SARSAgent(state_tracker.get_state_size(), constants)
    sess = K.get_session()
    ac_agent = ActorCritic(state_tracker.get_state_size(), constants, sess)
    #dqn_agent = DQNAgent(state_tracker.get_state_size(), constants)


def run_round(state, warmup=False):
    # 1) Agent takes action given state tracker's representation of dialogue (state)
    agent_action = ac_agent.act(state)
    # 2) Update state tracker with the agent's action
    state_tracker.update_state_agent(agent_action)
    # 3) User takes action given agent action
    user_action, reward, done, success = user.step(agent_action)
    if not done:
        # 4) Infuse error into semantic frame level of user action
Exemplo n.º 8
0
class TestStateTracker(unittest.TestCase):
    """Unit tests for StateTracker's branch-stack and frame-merging operations."""

    def setUp(self):
        # Fresh tracker and stack for every test.
        self.testTracker = StateTracker()
        self.branchStack = BranchStack()

        # Two frames that overlap on key "b" with different values.
        self.frame1 = {"a": 1, "b": 2, "c": 3}
        self.frame2 = {"d": 4, "b": 3}

        self.IS1 = InteractionState(self.frame1)
        self.IS2 = InteractionState(self.frame2)

        self.B1 = Branch("N.E.D.", self.IS1)
        self.B2 = Branch("U.W.S.", self.IS2)

    def test_add_new_branch(self):
        # Each add_new_branch call is expected to grow the stack by one.
        self.testTracker.add_new_branch("context_test1")
        self.assertEqual((self.testTracker.branch_stack.size()), 1)

        self.testTracker.add_new_branch("context_test2")
        self.assertEqual((self.testTracker.branch_stack.size()), 2)

        # With B1 as parent, the new child's frame is expected to have the
        # same number of entries as the parent's frame.
        self.testTracker.reset_branch_stack()
        self.testTracker.branch_stack.push(self.B1)
        self.testTracker.add_new_branch("context_test3")

        child_branch_frame_size = len(
            self.testTracker.get_lastest_state().frame)
        parent_branch_frame_size = len(
            self.testTracker.get_parent_latest_state().frame)

        self.assertEqual(child_branch_frame_size, parent_branch_frame_size)

        self.testTracker.reset_branch_stack()

    def test_merge_dicts(self):
        # With an empty third argument, frame2's value for "b" wins.
        merged_frame1 = self.testTracker.merge_dicts(self.frame1, self.frame2,
                                                     [])
        self.assertEqual(merged_frame1, {"a": 1, "b": 3, "c": 3, 'd': 4})

        # With 'b' listed in the third argument, frame1's value for "b" is kept.
        merged_frame2 = self.testTracker.merge_dicts(self.frame1, self.frame2,
                                                     ['b'])
        self.assertEqual(merged_frame2, {"a": 1, "b": 2, "c": 3, 'd': 4})

    def test_merge_current_branch_with_parent(self):
        # Stack is [B1, B2]; merging should leave one branch that keeps B1's
        # context and holds the combined frame.
        self.testTracker.branch_stack.push(self.B1)
        self.testTracker.branch_stack.push(self.B2)

        self.testTracker.merge_current_branch_with_parent()

        self.assertEqual(self.testTracker.branch_stack.size(), 1)
        self.assertEqual(self.testTracker.get_current_branch().context,
                         "N.E.D.")
        self.assertEqual(self.testTracker.get_lastest_state().frame, {
            "a": 1,
            "b": 3,
            "c": 3,
            "d": 4
        })

        self.testTracker.reset_branch_stack()

    def test_commit_to_branch(self):
        self.testTracker.branch_stack.push(self.B1)

        # "f" is listed in the second argument and is expected to be absent
        # from the resulting frame.
        self.testTracker.commit_to_branch({
            "d": 4,
            "e": 5,
            "f": 6,
            "g": 7
        }, ["f"])

        self.assertEqual(self.testTracker.get_lastest_state().frame, {
            "a": 1,
            "b": 2,
            "c": 3,
            "d": 4,
            "e": 5,
            "g": 7
        })

        self.testTracker.reset_branch_stack()
Exemplo n.º 9
0
    remove_empty_slots(database)

    # Load movie dict
    db_dict = pickle.load(open(DICT_FILE_PATH, 'rb'), encoding='latin1')

    # Load goal File
    user_goals = pickle.load(open(USER_GOALS_FILE_PATH, 'rb'),
                             encoding='latin1')

    # Init. Objects
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    sarsa_agent = SARSAgent(state_tracker.get_state_size(), constants)
    #dqn_agent = DQNAgent(state_tracker.get_state_size(), constants)


def run_round(state, warmup=False):
    # 1) Agent takes action given state tracker's representation of dialogue (state)
    agent_action_index, agent_action = sarsa_agent.get_action(state,
                                                              use_rule=warmup)
    # 2) Update state tracker with the agent's action
    state_tracker.update_state_agent(agent_action)
    # 3) User takes action given agent action
    user_action, reward, done, success = user.step(agent_action)
    if not done:
        # 4) Infuse error into semantic frame level of user action
        emc.infuse_error(user_action)
Exemplo n.º 10
0
class GetFour:
    """Convert recorded dialogue episodes into experience tuples.

    Each agent turn of an episode becomes a record with the string keys
    ``"0"`` (action), ``"1"`` (state before the action), ``"2"`` (cumulative
    reward) and ``"3"`` (episode-over flag). Consecutive records are then
    paired into ``[s_t, a_t, r, s_t+1, episode_over]`` lists stored in
    ``self.four`` under increasing integer keys.
    """

    def __init__(self, path):
        """Load the episode file at *path* and initialise the bookkeeping state."""
        with open(path) as f:
            self.data = json.load(f)
        self.tracker = StateTracker()
        self.four = {}  # output result
        self.index = 0  # the index of output
        self.feasible_action = {
            0: {
                'diaact': 'greeting',
                'inform_slots': {},
                'request_slots': {}
            },
            1: {
                'diaact': 'bye',
                'inform_slots': {},
                'request_slots': {}
            }
        }
        self.feasible_action_index = 2

    def init_episode(self):
        """Reset the per-episode attributes and the state tracker."""
        self.num_turns = 0  # the number of turns for a episode
        self.tracker.initialize_episode()
        self.episode_over = False
        self.reward = 0
        self.a_s_r_over_history = []  # action_state pairs history
        self.action = {}  # the action now
        self.state = {}
        self.episode_status = -1

    def get_a_s_r_over(self, episode_record):  # episode_record = [{},{},{}
        """Replay one episode and return its list of agent-turn records.

        Args:
            episode_record: Sequence of action dicts, each with a ``"speaker"``
                key; agent actions are mutated in place by ``action_index``.

        Returns:
            ``self.a_s_r_over_history``; its second-to-last record carries the
            final cumulative reward and ``"3": True``.

        Raises:
            IndexError: If the episode produced fewer than two agent records.
        """
        self.init_episode()
        self.num_turns = len(episode_record)
        for i, record in enumerate(episode_record):
            self.action = record
            a_s_r_over = {"3": False}
            if self.action["speaker"] == "agent":
                # Snapshot the state before the tracker sees this agent action.
                self.state = self.tracker.get_state()
                self.tracker.update(agent_action=self.action)
                self.reward += self.reward_function(self.episode_status)
                a_s_r_over["0"] = self.action
                # Mutates self.action in place (drops 'speaker', masks slot values).
                self.action_index(self.action)
                a_s_r_over["1"] = self.state
                a_s_r_over["2"] = self.reward
                if a_s_r_over["1"]['agent_action'] is None:
                    a_s_r_over["1"]['agent_action'] = {
                        'diaact': 'greeting',
                        'inform_slots': {},
                        'request_slots': {}
                    }
                self.a_s_r_over_history.append(a_s_r_over)
            else:
                self.tracker.update(user_action=self.action)
                self.reward += self.reward_function(self.episode_status)
                # NOTE(review): this branch is unreachable -- i never reaches
                # self.num_turns -- and self.a_s_r is never defined. It looks
                # like a terminal record was intended for a closing user turn;
                # left unchanged because enabling it would alter the generated
                # data. Confirm the intent before fixing.
                if i == self.num_turns:
                    self.a_s_r["0"] = 0
                    self.a_s_r["1"] = self.state
                    self.a_s_r["2"] = self.reward
                    self.a_s_r_over_history.append(self.a_s_r)
        # when dialog over, update the latest reward
        self.episode_status = self.get_status(self.a_s_r_over_history[-1]["1"])
        self.reward += self.reward_function(self.episode_status)
        self.a_s_r_over_history[-2]["2"] = self.reward
        self.a_s_r_over_history[-2]["3"] = True
        return self.a_s_r_over_history

    # get four = [s_t, a_t, r, s_t+1, episode_over]
    def update_four(self, a_s_r_over_history):
        """Pair consecutive records into experience lists stored in ``self.four``.

        The last record only supplies the next state of its predecessor and
        never starts a tuple of its own.
        """
        for current, following in zip(a_s_r_over_history,
                                      a_s_r_over_history[1:]):
            self.four[self.index] = [
                current["1"],
                current["0"],
                current["2"],
                following["1"],
                current["3"],
            ]
            self.index += 1

    def get_four(self):
        """Process every episode in ``self.data`` and return ``self.four``.

        Episodes with two or fewer records are skipped.
        """
        for episode in self.data.values():
            if len(episode) <= 2:
                continue
            a_s_r = self.get_a_s_r_over(episode)
            self.update_four(a_s_r)
        return self.four

    def reward_function(self, episode_status):
        """Return the reward for *episode_status* (0 failed, 1 succeeded, else -1 per turn)."""
        if episode_status == 0:  # dialog failed
            return -self.num_turns
        if episode_status == 1:  # dialog succeed
            return 2 * self.num_turns
        return -1

    def get_status(self, state):
        """Return 1 if ``phone_number`` is among the informed slots, else 0.

        An empty slot collection now yields 0 (failed); the previous loop left
        the result unassigned in that case and raised UnboundLocalError.
        """
        inform_slots = state["current_slots"]["inform_slots"]
        return 1 if "phone_number" in inform_slots else 0

    # input: action   output: index of action and feasible_action
    def action_index(self, action):
        """Register *action* in ``self.feasible_action`` if it is not known yet.

        The action is modified in place: its ``'speaker'`` key is removed and
        every inform-slot value is replaced by ``'PLACEHOLDER'`` before the
        comparison against the known actions.
        """
        del action['speaker']
        for slot in action['inform_slots']:
            action['inform_slots'][slot] = 'PLACEHOLDER'
        already_known = any(
            self.feasible_action[i] == action
            for i in range(self.feasible_action_index))
        if not already_known:
            self.feasible_action[self.feasible_action_index] = action
            self.feasible_action_index += 1
Exemplo n.º 11
0
        for memory in memorys:
            state, agent_action_index, reward, next_state, done = memory
            state = ' '.join(map(str, state))
            next_state = ' '.join(map(str, next_state))
            memory = [state, agent_action_index, reward, next_state, done]
            memory = '||'.join(map(str, memory)) + '\n'
            f.writelines(memory)


if __name__ == "__main__":
    # Script entry point: collect warm-up experience, convert it to tuple
    # states, save it, and hand it to a DQN agent as its memory.
    from keras import backend as K
    K.clear_session()

    sim = UserSimulator()
    p = Policy()
    st = StateTracker()
    ms = warmup_run()
    # for i in range(6):
    #     print(ms[i])
    # print(st.get_tuple_state(ms[4]))
    memorys = []
    print("Get experience tuple state: ...")
    # Convert every warm-up memory with the state tracker.
    for m in ms:
        memorys.append(st.get_tuple_state(m))

    save_experience(memorys)

    with open("constants.json") as f:
        constants = json.load(f)
    # NOTE(review): state_size is not defined in this block; presumably a
    # module-level name -- confirm.
    dqn = DQNAgent(state_size, constants)
    dqn.memory = memorys
Exemplo n.º 12
0
from state_tracker import StateTracker
from state_tracking.branch_stack import BranchStack
from state_tracking.branch import Branch
from state_tracking.interaction_state import InteractionState

# Manual smoke script: push two branches, merge the top one into its parent,
# and print the resulting branch.
testTracker = StateTracker()
branchStack = BranchStack()

# Two frames that overlap on key "b"; frame2 additionally carries "e".
frame1 = {"a": 1, "b": 2, "c": 3}
frame2 = {"d": 4, "b": 3, "e": 5}

IS1 = InteractionState(frame1)
IS2 = InteractionState(frame2)

B1 = Branch("N.E.D.", IS1)
B2 = Branch("U.W.S.", IS2)

testTracker.branch_stack.push(B1)
testTracker.branch_stack.push(B2)

# Merge B2 into B1, passing ["e"] as the argument.
testTracker.merge_current_branch_with_parent(["e"])
# The standalone stack is filled here but not read again in this script.
branchStack.push(1)
branchStack.push(2)
branchStack.push(3)

# Show the current branch's context, latest frame and first history frame.
print(testTracker.get_current_branch().context)
print(testTracker.get_current_branch().latest_state.frame)
print(testTracker.get_current_branch().history[0].frame)
Exemplo n.º 13
0
from rb_policy import Policy

# Read the run configuration from constants.json.
with open("constants.json") as f:
    constants = json.load(f)

# Settings from the 'run' section of the constants file.
run_dict = constants['run']
USE_USERSIM = run_dict['usersim']
NUM_EP_TEST = run_dict['num_ep_run']
MAX_ROUND_NUM = run_dict['max_round_num']

# Init. Objects
# Either the simulated user or the real one, depending on the config flag.
user = UserSimulator() if USE_USERSIM else User(constants)
state_tracker = StateTracker()
# Original note: the trained parameters are already loaded (_load_weights()).
dqn_agent = DQNAgent(state_tracker.get_state_size(), constants)
policy = Policy()


def test_run():
    """
    Runs the loop that tests the agent.

    Tests the agent on the goal-oriented chatbot task. Only for evaluating a trained agent. Terminates when the episode
    reaches NUM_EP_TEST.

    """

    print('Testing Started...')
Exemplo n.º 14
0
    remove_empty_slots(database)

    # Load movie dict
    db_dict = pickle.load(open(DICT_FILE_PATH, 'rb'), encoding='latin1')

    # Load goal File
    user_goals = pickle.load(open(USER_GOALS_FILE_PATH, 'rb'),
                             encoding='latin1')

    # Init. Objects
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    acm = AdvantageACM(state_tracker.get_state_size(), constants)


def run_round(state, Agent_Actions, User_Actions, tot_slt_len, stp, q, warmup):
    u_r = 0  ##User Repeatition
    a_r = 0  ##Agent Repeatition
    a_q = 0  ##Agent Question
    u_q = q  ##User Question
    pen = 0  ##User asked Question Agent replied Question
    # 1) Agent takes action given state tracker's representation of dialogue (state)
    agent_action_index, agent_action = acm.act(state, warmup)
    print('Agent Action_Index:', agent_action_index)
    #print('Agent_Action:',agent_action)
    if (agent_action['intent'] == 'request'):
        a_q = 1
Exemplo n.º 15
0
    # Load movie dict
    # db_dict = pickle.load(open(DICT_FILE_PATH, 'rb'), encoding='latin1')
    db_dict = json.load(open(DICT_FILE_PATH, encoding='utf-8'))[0]

    # Load goal file
    # user_goals = pickle.load(open(USER_GOALS_FILE_PATH, 'rb'), encoding='latin1')
    user_goals = json.load(open(USER_GOALS_FILE_PATH, encoding='utf-8'))

    # Init. Objects
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    dqn_agent = DQNAgent(state_tracker.get_state_size(), constants)


def test_run():
    """
    Runs the loop that tests the agent.

    Tests the agent on the goal-oriented chatbot task. Only for evaluating a trained agent. Terminates when the episode
    reaches NUM_EP_TEST.

    """

    print('Testing Started...')
    episode = 0
    while episode < NUM_EP_TEST:
Exemplo n.º 16
0
    remove_empty_slots(database)

    # Load movie dict
    db_dict = pickle.load(open(DICT_FILE_PATH, 'rb'), encoding='latin1')

    # Load goal File
    user_goals = pickle.load(open(USER_GOALS_FILE_PATH, 'rb'),
                             encoding='latin1')

    # Init. Objects
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    dqn_agent = DuellingQNetworkAgent(state_tracker.get_state_size(),
                                      constants)


def run_round(state, warmup=False):
    # 1) Agent takes action given state tracker's representation of dialogue (state)
    agent_action_index, agent_action = dqn_agent.get_action(state,
                                                            use_rule=warmup)
    # 2) Update state tracker with the agent's action
    state_tracker.update_state_agent(agent_action)
    # 3) User takes action given agent action
    user_action, reward, done, success = user.step(agent_action)
    if not done:
        # 4) Infuse error into semantic frame level of user action
        emc.infuse_error(user_action)
Exemplo n.º 17
0
class Dialogue:
    """Runs dialogues between a DQN agent and a simulated or interactive user."""

    def __init__(self, load_agent_model_from_directory: str = None):
        """Build the state tracker, users, GUI and DQN agent.

        Args:
            load_agent_model_from_directory: Optional directory handed to
                ``self.agent.load_agent_model``; when falsy, no model is loaded.
        """
        # Load database of movies (if you get an error unpickling movie_db.pkl then run pickle_converter.py)
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # only use it with a trusted movie_db.pkl.
        with open("resources/movie_db.pkl", "rb") as database_file:
            database = pickle.load(database_file, encoding="latin1")

        # Create state tracker
        self.state_tracker = StateTracker(database)

        # Create user simulator with list of user goals
        with open("resources/movie_user_goals.json", "r", encoding="utf-8") as goals_file:
            user_goals = json.load(goals_file)
        self.user_simulated = RulebasedUsersim(user_goals)

        # Create GUI for direct text interactions
        self.gui = ChatApplication()

        # Create user instance for direct text interactions
        self.user_interactive = User(nlu_path="user/regex_nlu.json", use_voice=False, gui=self.gui)

        # Create empty user (will be assigned on runtime)
        self.user = None

        # Create agent
        self.agent = DQNAgent(alpha=0.001, gamma=0.9, epsilon=0.5, epsilon_min=0.05,
                              n_actions=len(feasible_agent_actions), n_ordinals=3,
                              observation_dim=(StateTracker.state_size()),
                              batch_size=256, memory_len=80000, prioritized_memory=True,
                              replay_iter=16, replace_target_iter=200)
        if load_agent_model_from_directory:
            self.agent.load_agent_model(load_agent_model_from_directory)

    def run(self, n_episodes, step_size=100, warm_up=False, interactive=False, learning=True):
        """
        Runs the loop that trains (or only evaluates) the agent.

        For every episode the agent and the user alternate until the user
        reports the dialogue as done. While ``learning`` is true each step is
        passed to ``self.agent.update``; the ``replay`` flag is set on every
        ``self.agent.replay_iter``-th step. Whenever ``episode`` is a multiple
        of ``step_size`` (including episode 0) the success rate and average
        reward of the episodes since the last report are printed, and the
        model is saved if that success rate is a new best (only when learning
        and not warming up). A final model is saved after the last episode
        under the same condition.

        Args:
            n_episodes: Number of episodes to run.
            step_size: Reporting interval in episodes.
            warm_up: Forwarded to the agent's ``choose_action`` and ``update``;
                also disables ``end_episode`` and model saving.
            interactive: Use the interactive user and GUI instead of the
                simulated user.
            learning: When false, epsilon is set to 0.0 and the agent is
                neither updated nor saved.
        """

        if interactive:
            self.user = self.user_interactive
            self.gui.window.update()
        else:
            self.user = self.user_simulated

        if not learning:
            self.agent.epsilon = 0.0

        batch_episode_rewards = []
        batch_successes = []
        batch_success_best = 0.0
        step_counter = 0

        for episode in range(n_episodes):

            # print("########################\n------ EPISODE {} ------\n########################".format(episode))
            self.episode_reset(interactive)
            done = False
            success = False
            episode_reward = 0

            # Initialize episode with first user and agent action
            prev_observation = self.state_tracker.get_state()
            # 1) Agent takes action given state tracker's representation of dialogue (observation)
            prev_agent_action = self.agent.choose_action(prev_observation, warm_up=warm_up)
            while not done:
                step_counter += 1
                # 2) 3) 4) 5) 6a)
                observation, reward, done, success = self.env_step(prev_agent_action, interactive)
                if learning:
                    replay = step_counter % self.agent.replay_iter == 0
                    # 6b) Add experience
                    self.agent.update(prev_observation, prev_agent_action, observation, reward, done,
                                      warm_up=warm_up, replay=replay)
                # 1) Agent takes action given state tracker's representation of dialogue (observation)
                agent_action = self.agent.choose_action(observation, warm_up=warm_up)

                episode_reward += reward
                prev_observation = observation
                prev_agent_action = agent_action

            if not warm_up and learning:
                self.agent.end_episode(n_episodes)

            # Evaluation
            # print("--- Episode: {} SUCCESS: {} REWARD: {} ---".format(episode, success, episode_reward))
            batch_episode_rewards.append(episode_reward)
            batch_successes.append(success)
            if episode % step_size == 0:
                # Check success rate
                success_rate = mean(batch_successes)
                avg_reward = mean(batch_episode_rewards)

                print('Episode: {} SUCCESS RATE: {} Avg Reward: {}'.format(episode, success_rate,
                                                                           avg_reward))
                if success_rate > batch_success_best and learning and not warm_up:
                    print('Episode: {} NEW BEST SUCCESS RATE: {} Avg Reward: {}'.format(episode, success_rate,
                                                                                        avg_reward))
                    self.agent.save_agent_model()
                    batch_success_best = success_rate
                batch_successes = []
                batch_episode_rewards = []

        if learning and not warm_up:
            # Save final model
            self.agent.save_agent_model()

    def env_step(self, agent_action, interactive=False):
        """Advance the dialogue by one agent action.

        Args:
            agent_action: Action chosen by the agent for this turn.
            interactive: When true, the agent's utterance is shown in the GUI.

        Returns:
            ``(observation, reward, done, success)`` where ``success`` is a
            bool that is true only if the user reported a success value of 1.
        """
        # 2) Update state tracker with the agent's action
        self.state_tracker.update_state_agent(agent_action)
        if interactive:
            self.gui.insert_message(agent_action.to_utterance(), "Shop Assistant")
        # print(agent_action)
        # 3) User takes action given agent action
        user_action, reward, done, success = self.user.get_action(agent_action)
        # print(user_action)
        # 4) Infuse error into user action (currently inactive)
        # 5) Update state tracker with user action
        self.state_tracker.update_state_user(user_action)
        # 6a) Get next state
        observation = self.state_tracker.get_state(done)
        # Compare by value: `is 1` tests object identity, which only works
        # by accident of integer caching and triggers a SyntaxWarning.
        return observation, reward, done, bool(success == 1)

    def episode_reset(self, interactive=False):
        """Reset tracker, user, agent turn counter and (optionally) the GUI,
        then feed the user's opening action into the state tracker."""
        # Reset the state tracker
        self.state_tracker.reset()
        # Reset the user
        self.user.reset()
        # Reset the agent
        self.agent.turn = 0
        # Reset the interactive GUI
        if interactive:
            self.gui.reset_text_widget()
            self.gui.insert_message("Guten Tag! Wie kann ich Ihnen heute helfen?", "Shop Assistant")
        # User start action
        user_action, _, _, _ = self.user.get_action(None)
        # print(user_action)
        self.state_tracker.update_state_user(user_action)
Exemplo n.º 18
0
from state_tracker import StateTracker
from user_simulator import UserSimulator
from agent_dqn import AgentDQN

# Settings dictionary handed to AgentDQN below.
params = {
    'experience_replay_pool_size': 10000,
    'dqn_hidden_size': 60,
    'gamma': 0.9,
    'predict_mode': True,
    'max_turn': 40,
    'trained_model_path': 'data/saved_model.p'
}

# Module-level collaborators shared by the functions in this script.
state_tracker = StateTracker()
# NOTE(review): the meaning of the argument 3 is not visible here -- confirm
# against UserSimulator's constructor.
usersim = UserSimulator(3)
agent = AgentDQN(params)


def run_episode(count):
    for i in range(count):
        print("dialog:", i)
        episode_over = False
        turn = 0
        state_tracker.initialize_episode()
        agent_action = {
            'diaact': 'greeting',
            'inform_slots': {},
            'request_slots': {}
        }
        state_tracker.update(agent_action=agent_action)
        print("sys:", agent_action)
Exemplo n.º 19
0
    remove_empty_slots(database)

    # Load movie dict
    db_dict = pickle.load(open(DICT_FILE_PATH, 'rb'), encoding='latin1')

    # Load goal File
    user_goals = pickle.load(open(USER_GOALS_FILE_PATH, 'rb'),
                             encoding='latin1')

    # Init. Objects
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    dqn_agent = DRQNAgent(
        state_tracker.get_state_size(),
        constants)  # the variable dqn agent is intialized to a DRQN agent


def run_round(states, warmup=False):
    # 1) Agent takes action given state tracker's representation of dialogue (state)
    #print(states[0].shape)
    if len(states) > 3:
        state_1 = np.stack((states[-4], states[-3], states[-2], states[-1]),
                           axis=0)
    else:
        state_1 = np.vstack((np.zeros(
            (4 - len(states), 224)), np.array(states)))
    agent_action_index, agent_action = dqn_agent.get_action(state_1,