コード例 #1
0
ファイル: vhus.py プロジェクト: aaa123git/tatk
    def __init__(
        self,
        archive_file=DEFAULT_ARCHIVE_FILE,
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_multiwoz.zip'
    ):
        """Build the VHUS user simulator and load its saved weights.

        Reads ``config.json`` from this module's directory, builds the VHUS
        network with the vocabulary sizes reported by ``UserDataManager``,
        switches it to eval mode, makes sure the model archive has been
        unpacked into ``save/`` next to this module, and finally calls
        ``self.load(config['load'])``.

        Args:
            archive_file: Path to a local zip archive. When it is not an
                existing file, ``model_file`` is resolved through
                ``cached_path`` instead.
            model_file: Location handed to ``cached_path`` when
                ``archive_file`` is missing (presumably downloaded and
                cached locally by that helper - verify against its
                implementation).

        Raises:
            Exception: If ``archive_file`` is not a file and ``model_file``
                is empty.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(module_dir, 'config.json'), 'r') as f:
            config = json.load(f)
        manager = UserDataManager()
        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size,
                         voc_sys_size).to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = manager
        self.user.eval()

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for VHUS Policy is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.join(module_dir, 'save')
        # exist_ok avoids the check-then-create race when several processes
        # construct the policy at the same time.
        os.makedirs(model_dir, exist_ok=True)
        if not os.path.exists(os.path.join(model_dir, 'best_simulator.mdl')):
            # The context manager closes the archive handle even when
            # extraction fails.
            with zipfile.ZipFile(archive_file, 'r') as archive:
                archive.extractall(model_dir)
        self.load(config['load'])
コード例 #2
0
    def __init__(self):
        """Create the VHUS user simulator from the bundled ``config.json``.

        The config file is read from the directory containing this module.
        The network is built with the vocabulary sizes returned by
        ``UserDataManager.get_voc_size``, moved to ``DEVICE``, switched to
        eval mode, and ``self.load`` is then called with ``config['load']``.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(module_dir, 'config.json')
        with open(config_path, 'r') as config_file:
            config = json.load(config_file)

        data_manager = UserDataManager()
        goal_vocab, usr_vocab, sys_vocab = data_manager.get_voc_size()
        simulator = VHUS(config, goal_vocab, usr_vocab, sys_vocab)
        self.user = simulator.to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = data_manager
        self.user.eval()

        self.load(config['load'])
コード例 #3
0
    def __init__(self, config):
        """Set up the VHUS model, its data splits, the optimizer and losses.

        Args:
            config: Mapping that supplies the keys read here
                (``print_per_batch``, ``save_dir``, ``save_per_epoch``,
                ``batchsz``, ``alpha`` and ``lr``). It is also passed
                unchanged to the ``VHUS`` constructor.
        """
        data_manager = UserDataManager()
        goal_vocab, usr_vocab, sys_vocab = data_manager.get_voc_size()
        model = VHUS(config, goal_vocab, usr_vocab, sys_vocab)
        self.user = model.to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = data_manager

        # Settings copied verbatim from the config.
        self.print_per_batch = config['print_per_batch']
        self.save_dir = config['save_dir']
        self.save_per_epoch = config['save_per_epoch']

        # Load the sequences, then split them into train / test / validation.
        goal_seqs, usr_da_seqs, sys_da_seqs = data_manager.data_loader_seg()
        (train_goals, train_usr, train_sys,
         test_goals, test_usr, test_sys,
         valid_goals, valid_usr, valid_sys) = \
            data_manager.train_test_val_split_seg(
                goal_seqs, usr_da_seqs, sys_da_seqs)

        # Each split is stored together with the batch size.
        batch_size = config['batchsz']
        self.data_train = (train_goals, train_usr, train_sys, batch_size)
        self.data_valid = (valid_goals, valid_usr, valid_sys, batch_size)
        self.data_test = (test_goals, test_usr, test_sys, batch_size)

        self.alpha = config['alpha']
        self.optim = torch.optim.Adam(self.user.parameters(), lr=config['lr'])
        self.nll_loss = nn.NLLLoss(ignore_index=0)  # PAD=0
        self.bce_loss = nn.BCEWithLogitsLoss()
コード例 #4
0
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_multiwoz.zip'):
        """Create the VHUS user simulator and delegate weight loading.

        Reads ``config.json`` from this module's directory, builds the
        network on ``DEVICE`` in eval mode, then hands ``archive_file``,
        ``model_file`` and ``config['load']`` to ``self.load``.

        Args:
            archive_file: Forwarded unchanged to ``self.load``.
            model_file: Forwarded unchanged to ``self.load``.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(module_dir, 'config.json')
        with open(config_path, 'r') as config_file:
            config = json.load(config_file)

        data_manager = UserDataManager()
        goal_vocab, usr_vocab, sys_vocab = data_manager.get_voc_size()
        simulator = VHUS(config, goal_vocab, usr_vocab, sys_vocab)
        self.user = simulator.to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = data_manager
        self.user.eval()

        self.load(archive_file, model_file, config['load'])
コード例 #5
0
    def __init__(self):
        """
        Initialise the agenda-based user policy (User_Policy_Agenda).

        Sets the fixed limits ``max_turn`` and ``max_initiative``, creates a
        default ``GoalGenerator``, resets the turn counter, goal and agenda,
        and finally runs ``Policy.__init__``.
        """
        self.max_turn, self.max_initiative = 40, 4

        self.goal_generator = GoalGenerator()

        # No dialog has started yet: turn counter at zero, no goal or agenda.
        self.__turn = 0
        self.goal = self.agenda = None

        Policy.__init__(self)
コード例 #6
0
    def __init__(self):
        """
        Initialise the agenda-based user policy (User_Policy_Agenda).

        The goal generator is built from the corpus file
        ``data/multiwoz/annotated_user_da_with_span_full.json``, located
        four directory levels above this module's absolute path.
        """
        self.max_turn, self.max_initiative = 40, 4

        # Walk up four directory levels from this file.
        base_dir = os.path.abspath(__file__)
        for _ in range(4):
            base_dir = os.path.dirname(base_dir)
        corpus_file = os.path.join(
            base_dir, 'data/multiwoz/annotated_user_da_with_span_full.json')
        self.goal_generator = GoalGenerator(corpus_path=corpus_file)

        # No dialog has started yet: turn counter at zero, no goal or agenda.
        self.__turn = 0
        self.goal = self.agenda = None

        Policy.__init__(self)
コード例 #7
0
    def __init__(self, goal_generator: GoalGenerator, seed=None):
        """
        Create a new goal from ``goal_generator.get_user_goal(seed)``.

        The ``'domain_ordering'`` entry is moved out of the goal dict into
        ``self.domains``. For every listed domain, each slot under ``'reqt'``
        is mapped to ``DEF_VAL_UNK``, and a domain that has a ``'book'``
        entry gets ``'booked'`` set to ``DEF_VAL_UNK``.

        Args:
            goal_generator (GoalGenerator): Goal generator.
            seed: Passed through to ``get_user_goal``.
        """
        goals = goal_generator.get_user_goal(seed)
        self.domain_goals = goals

        self.domains = list(goals['domain_ordering'])
        del goals['domain_ordering']

        for name in self.domains:
            domain_goal = goals[name]

            if 'reqt' in domain_goal:
                # Requested slots start with an unknown value.
                domain_goal['reqt'] = dict.fromkeys(domain_goal['reqt'],
                                                    DEF_VAL_UNK)

            if 'book' in domain_goal:
                domain_goal['booked'] = DEF_VAL_UNK
コード例 #8
0
    def __init__(self, max_goal_num=100, seed=2019):
        """
        Constructor for User_Policy_Agenda class with pre-drawn goal seeds.

        Args:
            max_goal_num (int): Number of seeds drawn into
                ``self.goal_seeds``.
            seed (int): Value passed to ``random.seed``. Note that this
                reseeds the module-level ``random`` generator, so it affects
                every other user of ``random`` in the process.
        """
        self.max_turn = 40
        self.max_initiative = 4

        # The corpus file lives four directory levels above this module.
        self.goal_generator = GoalGenerator(corpus_path=os.path.join(os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
            'data/multiwoz/annotated_user_da_with_span_full.json'))

        self.__turn = 0
        self.goal = None
        self.agenda = None

        random.seed(seed)
        # randint needs integer bounds: a float such as 1e7 is deprecated
        # since Python 3.10 and raises TypeError from 3.12 on. 10 ** 7 is the
        # same bound, so the drawn sequence is unchanged on older versions.
        self.goal_seeds = [random.randint(1, 10 ** 7)
                           for _ in range(max_goal_num)]

        Policy.__init__(self)
コード例 #9
0
class UserPolicyVHUS(Policy):
    """User policy that wraps a VHUS network to produce user acts."""

    def __init__(self):
        """Build the VHUS network from ``config.json`` and load its weights.

        The config file is read from this module's directory. The network is
        built with the vocabulary sizes from ``UserDataManager``, moved to
        ``DEVICE`` and switched to eval mode before ``self.load`` is called
        with ``config['load']``.
        """
        with open(
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'config.json'), 'r') as f:
            config = json.load(f)
        manager = UserDataManager()
        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size,
                         voc_sys_size).to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = manager
        self.user.eval()

        self.load(config['load'])

    def init_session(self):
        """Reset per-dialog state and draw a fresh user goal."""
        self.time_step = -1
        self.topic = 'NONE'
        self.goal = self.goal_gen.get_user_goal()
        self.goal_input = torch.LongTensor(
            self.manager.get_goal_id(self.manager.usrgoal2seq(self.goal)))
        self.goal_len_input = torch.LongTensor([len(self.goal_input)
                                                ]).squeeze()
        self.sys_da_id_stack = []  # to save sys da history

    def predict(self, state):
        """Predict an user act based on state and preorder system action.

        Args:
            state (dict):
                Dialog state. Only ``state['system_action']`` is read here.
        Returns:
            usr_action (tuple):
                User act.
            session_over (boolean):
                True to terminate session, otherwise session continues.
        """
        sys_action = state['system_action']
        sys_seq_turn = self.manager.sysda2seq(
            self.manager.ref_data2stand(sys_action), self.goal)
        # Append this turn to the history of system acts for the session.
        self.sys_da_id_stack += self.manager.get_sysda_id([sys_seq_turn])
        # Every turn is counted with a length of at least 1, even if empty.
        sys_seq_len = torch.LongTensor(
            [max(len(sen), 1) for sen in self.sys_da_id_stack])
        max_sen_len = sys_seq_len.max().item()
        sys_seq = torch.LongTensor(padding(self.sys_da_id_stack, max_sen_len))
        usr_a, terminal = self.user.select_action(self.goal_input,
                                                  self.goal_len_input, sys_seq,
                                                  sys_seq_len)
        usr_action = self.manager.usrseq2da(self.manager.id2sentence(usr_a),
                                            self.goal)

        return capital(usr_action), terminal

    def load(self, filename):
        """Load simulator weights from ``<filename>_simulator.mdl``.

        Args:
            filename (str): Checkpoint path prefix, relative to this
                module's directory.

        If the checkpoint file does not exist, nothing is loaded and no
        error is raised; the network keeps the weights it was built with.
        """
        user_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                filename + '_simulator.mdl')
        if os.path.exists(user_mdl):
            # map_location remaps the stored tensors onto DEVICE, so a
            # checkpoint saved on a GPU still loads on a CPU-only machine.
            self.user.load_state_dict(
                torch.load(user_mdl, map_location=DEVICE))
            print('<<user simulator>> loaded checkpoint from file: {}'.format(
                user_mdl))
コード例 #10
0
ファイル: train.py プロジェクト: youngornever/tatk
import os
import json
import logging
import sys
# Make the project root importable: start from this file's absolute path and
# strip five trailing path components.
root_dir = os.path.abspath(__file__)
for _ in range(5):
    root_dir = os.path.dirname(root_dir)
sys.path.append(root_dir)

from tatk.util.train_util import init_logging_handler
from tatk.task.multiwoz.goal_generator import GoalGenerator
from tatk.policy.vhus.multiwoz.usermanager import UserDataManager
from tatk.policy.vhus.train import VHUS_Trainer

if __name__ == '__main__':
    # Settings are read from config.json in the current working directory.
    with open('config.json', 'r') as config_file:
        config = json.load(config_file)

    init_logging_handler(config['log_dir'])
    data_manager = UserDataManager()
    goal_generator = GoalGenerator()
    trainer = VHUS_Trainer(config, data_manager, goal_generator)

    logging.debug('start training')

    # Each epoch runs one imitation pass, then a test pass whose return
    # value becomes the new best score.
    best_score = float('inf')
    for epoch in range(config['epoch']):
        trainer.imitating(epoch)
        best_score = trainer.imit_test(epoch, best_score)
コード例 #11
0
ファイル: vhus.py プロジェクト: aaa123git/tatk
class UserPolicyVHUS(Policy):
    """User policy that wraps a VHUS network to produce user acts."""

    def __init__(
        self,
        archive_file=DEFAULT_ARCHIVE_FILE,
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_multiwoz.zip'
    ):
        """Build the VHUS network and load its saved weights.

        Reads ``config.json`` from this module's directory, builds the
        network on ``DEVICE`` in eval mode, makes sure the model archive has
        been unpacked into ``save/`` next to this module, and finally calls
        ``self.load(config['load'])``.

        Args:
            archive_file: Path to a local zip archive. When it is not an
                existing file, ``model_file`` is resolved through
                ``cached_path`` instead.
            model_file: Location handed to ``cached_path`` when
                ``archive_file`` is missing (presumably downloaded and
                cached locally by that helper - verify against its
                implementation).

        Raises:
            Exception: If ``archive_file`` is not a file and ``model_file``
                is empty.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(module_dir, 'config.json'), 'r') as f:
            config = json.load(f)
        manager = UserDataManager()
        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size,
                         voc_sys_size).to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = manager
        self.user.eval()

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for VHUS Policy is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.join(module_dir, 'save')
        # exist_ok avoids the check-then-create race when several processes
        # construct the policy at the same time.
        os.makedirs(model_dir, exist_ok=True)
        if not os.path.exists(os.path.join(model_dir, 'best_simulator.mdl')):
            # The context manager closes the archive handle even when
            # extraction fails.
            with zipfile.ZipFile(archive_file, 'r') as archive:
                archive.extractall(model_dir)
        self.load(config['load'])

    def init_session(self):
        """Reset per-dialog state and draw a fresh user goal."""
        self.time_step = -1
        self.topic = 'NONE'
        self.goal = self.goal_gen.get_user_goal()
        self.goal_input = torch.LongTensor(
            self.manager.get_goal_id(self.manager.usrgoal2seq(self.goal)))
        self.goal_len_input = torch.LongTensor([len(self.goal_input)
                                                ]).squeeze()
        self.sys_da_id_stack = []  # to save sys da history

    def predict(self, state):
        """Predict an user act based on the preorder system action.

        Args:
            state:
                The system action itself; it is passed directly to
                ``manager.ref_data2stand`` without any unpacking.
        Returns:
            usr_action (tuple):
                User act.
            session_over (boolean):
                True to terminate session, otherwise session continues.
        """
        sys_action = state

        sys_seq_turn = self.manager.sysda2seq(
            self.manager.ref_data2stand(sys_action), self.goal)
        # Append this turn to the history of system acts for the session.
        self.sys_da_id_stack += self.manager.get_sysda_id([sys_seq_turn])
        # Every turn is counted with a length of at least 1, even if empty.
        sys_seq_len = torch.LongTensor(
            [max(len(sen), 1) for sen in self.sys_da_id_stack])
        max_sen_len = sys_seq_len.max().item()
        sys_seq = torch.LongTensor(padding(self.sys_da_id_stack, max_sen_len))
        usr_a, terminal = self.user.select_action(self.goal_input,
                                                  self.goal_len_input, sys_seq,
                                                  sys_seq_len)
        usr_action = self.manager.usrseq2da(self.manager.id2sentence(usr_a),
                                            self.goal)

        return capital(usr_action), terminal

    def load(self, filename):
        """Load simulator weights from ``<filename>_simulator.mdl``.

        Args:
            filename (str): Checkpoint path prefix, relative to this
                module's directory.

        If the checkpoint file does not exist, nothing is loaded and no
        error is raised; the network keeps the weights it was built with.
        """
        user_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                filename + '_simulator.mdl')
        if os.path.exists(user_mdl):
            # map_location remaps the stored tensors onto DEVICE, so a
            # checkpoint saved on a GPU still loads on a CPU-only machine.
            self.user.load_state_dict(
                torch.load(user_mdl, map_location=DEVICE))
            print('<<user simulator>> loaded checkpoint from file: {}'.format(
                user_mdl))

    def get_goal(self):
        """Return the goal drawn by the last ``init_session`` call."""
        return self.goal