Example #1
    def __init__(self, load_agent_model_from_directory: "str | None" = None):
        # Load database of movies (if you get an error unpickling movie_db.pkl then run pickle_converter.py)
        database = pickle.load(open("resources/movie_db.pkl", "rb"), encoding="latin1")

        # Create state tracker
        self.state_tracker = StateTracker(database)

        # Create user simulator with list of user goals
        self.user_simulated = RulebasedUsersim(
            json.load(open("resources/movie_user_goals.json", "r", encoding="utf-8")))

        # Create GUI for direct text interactions
        self.gui = ChatApplication()

        # Create user instance for direct text interactions
        self.user_interactive = User(nlu_path="user/regex_nlu.json", use_voice=False, gui=self.gui)

        # Create empty user (will be assigned on runtime)
        self.user = None

        # Create agent
        self.agent = DQNAgent(alpha=0.001, gamma=0.9, epsilon=0.5, epsilon_min=0.05,
                              n_actions=len(feasible_agent_actions), n_ordinals=3,
                              observation_dim=(StateTracker.state_size()),
                              batch_size=256, memory_len=80000, prioritized_memory=True,
                              replay_iter=16, replace_target_iter=200)
        if load_agent_model_from_directory:
            self.agent.load_agent_model(load_agent_model_from_directory)
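The bare pickle.load(open(...)) and json.load(open(...)) calls above never close their file handles. A context-manager variant (a minimal sketch; the helper name load_movie_db is ours, not from the source) would be:

import pickle

def load_movie_db(path: str = "resources/movie_db.pkl") -> dict:
    # The with-block closes the file even if unpickling raises.
    with open(path, "rb") as fh:
        return pickle.load(fh, encoding="latin1")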
Example #2
    def setUp(self):
        self.testTracker = StateTracker()
        self.branchStack = BranchStack()

        self.frame1 = {"a": 1, "b": 2, "c": 3}
        self.frame2 = {"d": 4, "b": 3}

        self.IS1 = InteractionState(self.frame1)
        self.IS2 = InteractionState(self.frame2)

        self.B1 = Branch("N.E.D.", self.IS1)
        self.B2 = Branch("U.W.S.", self.IS2)
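A test method built on this fixture might look like the sketch below; push() also appears in the companion script in Example #9, while pop() is an assumption made here for illustration.

    def test_branch_stack_is_lifo(self):
        # The last branch pushed should come back first (pop() is assumed API).
        self.branchStack.push(self.B1)
        self.branchStack.push(self.B2)
        self.assertIs(self.branchStack.pop(), self.B2)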
Example #3
    def __init__(
        self,
        user_goals: List[UserGoal],
        emc_params: Dict,
        max_round_num: int,
        database: Dict,
        slot2values: Dict[str, List[Any]],
    ) -> None:

        self.user = UserSimulator(user_goals, max_round_num)
        self.emc = ErrorModelController(slot2values, emc_params)
        self.state_tracker = StateTracker(database, max_round_num)

        self.action_space = gym.spaces.Discrete(len(AGENT_ACTIONS))
        self.observation_space = gym.spaces.MultiBinary(
            self.state_tracker.get_state_size())
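Since the constructor wires up gym-style action and observation spaces, a driver loop could look like the sketch below, assuming the surrounding class (called DialogEnv here, a hypothetical name) also implements gym's reset()/step() contract:

env = DialogEnv(user_goals, emc_params, max_round_num, database, slot2values)
obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # random-policy baseline
    obs, reward, done, info = env.step(action)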
Example #4
    def __init__(self, path):
        with open(path) as f:
            self.data = json.load(f)
        self.tracker = StateTracker()
        self.four = {}  # output results
        self.index = 0  # index of the next output entry
        self.feasible_action = {
            0: {
                'diaact': 'greeting',
                'inform_slots': {},
                'request_slots': {}
            },
            1: {
                'diaact': 'bye',
                'inform_slots': {},
                'request_slots': {}
            }
        }
        self.feasible_action_index = 2
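A companion helper could grow this action table at runtime. The method below is a hypothetical addition (register_action is not part of the original class):

    def register_action(self, diaact, inform_slots=None, request_slots=None):
        # Append a new feasible action under the next free index.
        self.feasible_action[self.feasible_action_index] = {
            'diaact': diaact,
            'inform_slots': inform_slots or {},
            'request_slots': request_slots or {},
        }
        self.feasible_action_index += 1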
Example #5
from state_tracker import StateTracker
from user_simulator import UserSimulator
from agent_dqn import AgentDQN

params = {
    'experience_replay_pool_size': 10000,
    'dqn_hidden_size': 60,
    'gamma': 0.9,
    'predict_mode': True,
    'max_turn': 40,
    'trained_model_path': 'data/saved_model.p'
}

state_tracker = StateTracker()
usersim = UserSimulator(3)
agent = AgentDQN(params)


def run_episode(count):
    for i in range(count):
        print("dialog:", i)
        episode_over = False
        turn = 0
        state_tracker.initialize_episode()
        agent_action = {
            'diaact': 'greeting',
            'inform_slots': {},
            'request_slots': {}
        }
        state_tracker.update(agent_action=agent_action)
        print("sys:", agent_action)
Example #6
    remove_empty_slots(database)

    # Load movie dict
    db_dict = pickle.load(open(DICT_FILE_PATH, 'rb'), encoding='latin1')

    # Load goal File
    user_goals = pickle.load(open(USER_GOALS_FILE_PATH, 'rb'),
                             encoding='latin1')

    # Init. Objects
    if USE_USERSIM:
        user = UserSimulator(user_goals, constants, database)
    else:
        user = User(constants)
    emc = ErrorModelController(db_dict, constants)
    state_tracker = StateTracker(database, constants)
    sarsa_agent = SARSAgent(state_tracker.get_state_size(), constants)
    #dqn_agent = DQNAgent(state_tracker.get_state_size(), constants)


def run_round(state, warmup=False):
    # 1) Agent takes action given state tracker's representation of dialogue (state)
    agent_action_index, agent_action = sarsa_agent.get_action(state,
                                                              use_rule=warmup)
    # 2) Update state tracker with the agent's action
    state_tracker.update_state_agent(agent_action)
    # 3) User takes action given agent action
    user_action, reward, done, success = user.step(agent_action)
    if not done:
        # 4) Infuse error into semantic frame level of user action
        emc.infuse_error(user_action)
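    # Plausible remaining steps, sketched here because the excerpt cuts off
    # (update_state_user() and get_state() mirror the agent-side calls above
    # but are assumptions, not confirmed API):
    # 5) Update state tracker with the user's action
    state_tracker.update_state_user(user_action)
    # 6) Return the next state so the caller can build the experience tuple
    next_state = state_tracker.get_state(done)
    return next_state, reward, done, success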
Example #7
    def simulate(self,
                 num_steps,
                 measurement_speed,
                 is_graphing=False):  # TODO: implement graphing
        for i in range(num_steps):
            self.step(i, measurement_speed)
            self.spin_strategy()

        if is_graphing:
            pass

    def create_plot(self):  # TODO: implement
        pass


if __name__ == "__main__":
    init_state = (1 / math.sqrt(2)) * np.array([0, 1, -1, 0])
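    # Sanity check (added for illustration, not in the original script):
    # the singlet state (|01> - |10>)/sqrt(2) must have unit norm.
    assert np.isclose(np.linalg.norm(init_state), 1.0)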
    init_loc = np.zeros(2)
    x_min = -2
    x_max = 2
    y_min = -2
    y_max = 2
    num_steps = 10
    measurement_speed = 10
    state_tracker = StateTracker(init_state)
    qubit1 = LocationTracker(init_loc, x_min, x_max, y_min, y_max, num_steps)
    qubit2 = LocationTracker(init_loc, x_min, x_max, y_min, y_max, num_steps)
    sim = Simulator(state_tracker, qubit1, qubit2)
    sim.simulate(num_steps, measurement_speed)
Example #8
        for memory in memorys:
            state, agent_action_index, reward, next_state, done = memory
            state = ' '.join(map(str, state))
            next_state = ' '.join(map(str, next_state))
            memory = [state, agent_action_index, reward, next_state, done]
            memory = '||'.join(map(str, memory)) + '\n'
            f.write(memory)  # memory is a single '||'-joined line


if __name__ == "__main__":
    from keras import backend as K
    K.clear_session()

    sim = UserSimulator()
    p = Policy()
    st = StateTracker()
    ms = warmup_run()
    # for i in range(6):
    #     print(ms[i])
    # print(st.get_tuple_state(ms[4]))
    memorys = []
    print("Get experience tuple state: ...")
    for m in ms:
        memorys.append(st.get_tuple_state(m))

    save_experience(memorys)

    with open("constants.json") as f:
        constants = json.load(f)
    # NOTE: state_size is not defined in this excerpt; the full script must
    # set it earlier (e.g., from the state tracker's state dimensionality).
    dqn = DQNAgent(state_size, constants)
    dqn.memory = memorys
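A matching reader for the '||'-separated format written above might look like this sketch (the field types are assumptions: whitespace-joined numeric state vectors, an integer action index, a numeric reward, and a str(bool) done flag):

def load_experience(path):
    memories = []
    with open(path) as f:
        for line in f:
            state, action, reward, next_state, done = line.rstrip('\n').split('||')
            memories.append([
                list(map(float, state.split())),  # undo the ' '.join on the state
                int(action),
                float(reward),
                list(map(float, next_state.split())),
                done == 'True',  # round-trip the str(bool) flag
            ])
    return memories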
Example #9
from state_tracker import StateTracker
from state_tracking.branch_stack import BranchStack
from state_tracking.branch import Branch
from state_tracking.interaction_state import InteractionState

testTracker = StateTracker()
branchStack = BranchStack()

frame1 = {"a": 1, "b": 2, "c": 3}
frame2 = {"d": 4, "b": 3, "e": 5}

IS1 = InteractionState(frame1)
IS2 = InteractionState(frame2)

B1 = Branch("N.E.D.", IS1)
B2 = Branch("U.W.S.", IS2)

# Push two branches onto the tracker's own stack, then merge the current
# (top) branch with its parent, passing ["e"] as the merge argument.
testTracker.branch_stack.push(B1)
testTracker.branch_stack.push(B2)

testTracker.merge_current_branch_with_parent(["e"])

# The standalone BranchStack is exercised separately with plain integers.
branchStack.push(1)
branchStack.push(2)
branchStack.push(3)

# Inspect the branch that remains current after the merge.
print(testTracker.get_current_branch().context)
print(testTracker.get_current_branch().latest_state.frame)
print(testTracker.get_current_branch().history[0].frame)
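Unwinding the plain integer stack would confirm LIFO order; pop() and is_empty() below are assumed method names that the script above never exercises:

while not branchStack.is_empty():
    print(branchStack.pop())  # expected: 3, then 2, then 1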