Example #1
0
    def __init__(self, config, manager, goal_gen):
        """Set up the VHUS model, the data partitions, the optimizer and losses.

        Args:
            config: Mapping read for ``print_per_batch``, ``save_dir``,
                ``save_per_epoch``, ``batchsz``, ``alpha`` and ``lr``; it is
                also passed to ``VHUS`` unchanged.
            manager: Data manager providing ``get_voc_size``,
                ``data_loader_seg`` and ``train_test_val_split_seg``.
            goal_gen: Goal generator; only stored on the instance here.
        """
        goal_vocab, usr_vocab, sys_vocab = manager.get_voc_size()
        model = VHUS(config, goal_vocab, usr_vocab, sys_vocab)
        self.user = model.to(device=DEVICE)
        self.goal_gen = goal_gen
        self.manager = manager

        # Reporting and checkpointing cadence.
        self.print_per_batch = config['print_per_batch']
        self.save_dir = config['save_dir']
        self.save_per_epoch = config['save_per_epoch']

        # The splitter returns nine sequences ordered train, test, validation.
        goals, usr_acts, sys_acts = manager.data_loader_seg()
        (goals_tr, usr_tr, sys_tr,
         goals_te, usr_te, sys_te,
         goals_va, usr_va, sys_va) = manager.train_test_val_split_seg(
             goals, usr_acts, sys_acts)

        # Each partition is stored together with the batch size.
        batch_size = config['batchsz']
        self.data_train = (goals_tr, usr_tr, sys_tr, batch_size)
        self.data_valid = (goals_va, usr_va, sys_va, batch_size)
        self.data_test = (goals_te, usr_te, sys_te, batch_size)

        self.alpha = config['alpha']
        self.optim = torch.optim.Adam(self.user.parameters(), lr=config['lr'])
        # Index 0 is the PAD token and is excluded from the NLL loss.
        self.nll_loss = nn.NLLLoss(ignore_index=0)
        self.bce_loss = nn.BCEWithLogitsLoss()
Example #2
0
    def __init__(
        self,
        archive_file=DEFAULT_ARCHIVE_FILE,
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_camrest.zip'
    ):
        """Build the VHUS user policy and load its pretrained weights.

        Reads ``config.json`` from this module's directory, creates the
        ``VHUS`` network on ``DEVICE`` in eval mode, makes sure the model
        archive has been extracted into the ``save`` directory next to this
        module, and finally calls ``self.load(config['load'])``.

        Args:
            archive_file: Path of a local zip archive with the model. When it
                is not an existing file, ``model_file`` is resolved through
                ``cached_path`` instead.
            model_file: Fallback location handed to ``cached_path``.

        Raises:
            Exception: If ``archive_file`` is not a file and ``model_file``
                is empty.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(module_dir, 'config.json'), 'r') as f:
            config = json.load(f)
        manager = UserDataManager()
        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size,
                         voc_sys_size).to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = manager
        self.user.eval()

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for VHUS Policy is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.join(module_dir, 'save')
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(model_dir, exist_ok=True)
        # Extract only once: skip when the checkpoint is already in place.
        if not os.path.exists(os.path.join(model_dir, 'best_simulator.mdl')):
            # The context manager closes the archive even if extraction fails.
            with zipfile.ZipFile(archive_file, 'r') as archive:
                archive.extractall(model_dir)
        self.load(config['load'])
Example #3
0
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_multiwoz.zip'):
        """Build the VHUS user policy and load its weights.

        Reads ``config.json`` from this module's directory, creates the
        ``VHUS`` network on ``DEVICE`` in eval mode and delegates checkpoint
        retrieval to ``self.load``.

        Args:
            archive_file: Forwarded to ``self.load`` as its first argument.
            model_file: Forwarded to ``self.load`` as its second argument.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(module_dir, 'config.json')
        with open(config_path, 'r') as config_file:
            config = json.load(config_file)

        data_manager = UserDataManager()
        goal_vocab, usr_vocab, sys_vocab = data_manager.get_voc_size()
        network = VHUS(config, goal_vocab, usr_vocab, sys_vocab)
        self.user = network.to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = data_manager
        # Inference only: switch the network to evaluation mode.
        self.user.eval()

        self.load(archive_file, model_file, config['load'])
Example #4
0
    def __init__(self):
        """Build the VHUS user policy and load the configured checkpoint.

        Reads ``config.json`` from this module's directory, creates the
        ``VHUS`` network on ``DEVICE`` in eval mode and calls
        ``self.load(config['load'])``.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(module_dir, 'config.json')
        with open(config_path, 'r') as config_file:
            config = json.load(config_file)

        data_manager = UserDataManager()
        goal_vocab, usr_vocab, sys_vocab = data_manager.get_voc_size()
        network = VHUS(config, goal_vocab, usr_vocab, sys_vocab)
        self.user = network.to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = data_manager
        # Inference only: switch the network to evaluation mode.
        self.user.eval()

        self.load(config['load'])
Example #5
0
class UserPolicyVHUS(UserPolicyVHUSAbstract):
    """VHUS user policy whose predicted actions are passed through ``capital``."""

    def __init__(
        self,
        archive_file=DEFAULT_ARCHIVE_FILE,
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_multiwoz.zip'
    ):
        """Build the VHUS network and load its weights.

        Reads ``config.json`` from this module's directory, creates the
        ``VHUS`` network on ``DEVICE`` in eval mode and delegates checkpoint
        retrieval to ``self.load``.

        Args:
            archive_file: Forwarded to ``self.load`` as its first argument.
            model_file: Forwarded to ``self.load`` as its second argument.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(module_dir, 'config.json')
        with open(config_path, 'r') as config_file:
            config = json.load(config_file)

        data_manager = UserDataManager()
        goal_vocab, usr_vocab, sys_vocab = data_manager.get_voc_size()
        network = VHUS(config, goal_vocab, usr_vocab, sys_vocab)
        self.user = network.to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = data_manager
        # Inference only: switch the network to evaluation mode.
        self.user.eval()

        self.load(archive_file, model_file, config['load'])

    def predict(self, state):
        """Return the parent's prediction with the user action run through ``capital``.

        Args:
            state: Passed unchanged to the parent class's ``predict``.

        Returns:
            A ``(usr_action, terminal)`` pair: the converted action and the
            parent's terminal flag, untouched.
        """
        raw_action, session_over = super().predict(state)
        return capital(raw_action), session_over
Example #6
0
class UserPolicyVHUS(Policy):
    """User policy backed by the VHUS network, used for inference only."""

    def __init__(self):
        """Build the VHUS network and load the configured checkpoint.

        Reads ``config.json`` from this module's directory, creates the
        ``VHUS`` network on ``DEVICE`` in eval mode and calls
        ``self.load(config['load'])``.
        """
        with open(
                os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'config.json'), 'r') as f:
            config = json.load(f)
        manager = UserDataManager()
        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size,
                         voc_sys_size).to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = manager
        self.user.eval()

        self.load(config['load'])

    def init_session(self):
        """Reset the per-dialog state and draw a new user goal."""
        self.time_step = -1
        self.topic = 'NONE'
        self.goal = self.goal_gen.get_user_goal()
        # The goal is converted to a sequence, then to ids, once per session.
        self.goal_input = torch.LongTensor(
            self.manager.get_goal_id(self.manager.usrgoal2seq(self.goal)))
        # squeeze() turns the one-element length tensor into a scalar tensor.
        self.goal_len_input = torch.LongTensor([len(self.goal_input)
                                                ]).squeeze()
        self.sys_da_id_stack = []  # to save sys da history

    def predict(self, state):
        """Predict a user act from the dialog state's latest system action.

        The system action is appended to the session history, and the whole
        history is padded and handed to the network together with the goal.

        Args:
            state (dict):
                Dialog state; only its ``'system_action'`` entry is read.
        Returns:
            usr_action:
                User act built by ``manager.usrseq2da`` and passed through
                ``capital``.
            terminal:
                Second value returned by ``self.user.select_action``
                (presumably True when the session should end; confirm
                against ``VHUS.select_action``).
        """
        sys_action = state['system_action']
        sys_seq_turn = self.manager.sysda2seq(
            self.manager.ref_data2stand(sys_action), self.goal)
        self.sys_da_id_stack += self.manager.get_sysda_id([sys_seq_turn])
        # Lengths are clamped to at least 1 so an empty turn still has a
        # non-zero length.
        sys_seq_len = torch.LongTensor(
            [max(len(sen), 1) for sen in self.sys_da_id_stack])
        max_sen_len = sys_seq_len.max().item()
        sys_seq = torch.LongTensor(padding(self.sys_da_id_stack, max_sen_len))
        usr_a, terminal = self.user.select_action(self.goal_input,
                                                  self.goal_len_input, sys_seq,
                                                  sys_seq_len)
        usr_action = self.manager.usrseq2da(self.manager.id2sentence(usr_a),
                                            self.goal)

        return capital(usr_action), terminal

    def load(self, filename):
        """Load network weights from ``<filename>_simulator.mdl``.

        The path is resolved relative to this module's directory. When the
        file does not exist the network keeps its current weights and a
        warning is printed.

        Args:
            filename (str):
                Checkpoint path prefix, relative to this module's directory.
        """
        user_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                filename + '_simulator.mdl')
        if os.path.exists(user_mdl):
            # map_location keeps a checkpoint saved on another device (e.g. a
            # GPU) loadable when DEVICE is different (e.g. CPU only).
            self.user.load_state_dict(
                torch.load(user_mdl, map_location=DEVICE))
            print('<<user simulator>> loaded checkpoint from file: {}'.format(
                user_mdl))
        else:
            # Make the fallback to the current weights visible.
            print('<<user simulator>> checkpoint not found, weights not '
                  'loaded: {}'.format(user_mdl))
Example #7
0
class VHUS_Trainer():
    """Trainer for the VHUS user simulator (behavioral cloning).

    Builds the ``VHUS`` network from the vocabulary sizes reported by
    ``manager``, splits the segmented corpus into train/validation/test
    partitions and offers training (``imitating``), loss-based evaluation
    (``imit_test``), metric-based evaluation (``test``) and checkpointing
    (``save``).
    """

    def __init__(self, config, manager, goal_gen):
        """Create the model, the data partitions, the optimizer and losses.

        Args:
            config: Mapping read for ``print_per_batch``, ``save_dir``,
                ``save_per_epoch``, ``batchsz``, ``alpha`` and ``lr``; it is
                also passed to ``VHUS`` unchanged.
            manager: Data manager providing ``get_voc_size``,
                ``data_loader_seg`` and ``train_test_val_split_seg``.
            goal_gen: Goal generator; only stored on the instance here.
        """
        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size, voc_sys_size).to(device=DEVICE)
        self.goal_gen = goal_gen
        self.manager = manager

        self.print_per_batch = config['print_per_batch']
        self.save_dir = config['save_dir']
        self.save_per_epoch = config['save_per_epoch']
        seq_goals, seq_usr_dass, seq_sys_dass = manager.data_loader_seg()
        # The splitter returns nine sequences ordered train, test, validation.
        train_goals, train_usrdas, train_sysdas, \
        test_goals, test_usrdas, test_sysdas, \
        val_goals, val_usrdas, val_sysdas = manager.train_test_val_split_seg(
            seq_goals, seq_usr_dass, seq_sys_dass)
        # Each partition is stored together with the batch size.
        self.data_train = (train_goals, train_usrdas, train_sysdas, config['batchsz'])
        self.data_valid = (val_goals, val_usrdas, val_sysdas, config['batchsz'])
        self.data_test = (test_goals, test_usrdas, test_sysdas, config['batchsz'])
        self.alpha = config['alpha']
        self.optim = torch.optim.Adam(self.user.parameters(), lr=config['lr'])
        self.nll_loss = nn.NLLLoss(ignore_index=0)  # PAD=0
        self.bce_loss = nn.BCEWithLogitsLoss()

    @staticmethod
    def _batches(dataset):
        """Return a batch iterator over a ``(goals, usr_das, sys_das, batchsz)`` tuple."""
        return batch_iter(dataset[0], dataset[1], dataset[2], dataset[3])

    def _mean_losses(self, dataset):
        """Average the ``user_loop`` losses over every batch of ``dataset``.

        Returns:
            ``(mean_loss_a, mean_loss_t)`` as floats, or ``None`` when the
            dataset yields no batch.
        """
        a_loss, t_loss, num_batches = 0., 0., 0
        # Evaluation only: gradients are not needed, so no graph is built.
        with torch.no_grad():
            for data in self._batches(dataset):
                loss_a, loss_t = self.user_loop(data)
                a_loss += loss_a.item()
                t_loss += loss_t.item()
                num_batches += 1
        if not num_batches:
            return None
        return a_loss / num_batches, t_loss / num_batches

    def user_loop(self, data):
        """Run the model on one batch and compute the two training losses.

        Args:
            data: One raw batch as produced by ``batch_iter``; it is padded
                by ``padding_data`` and moved by ``to_device``.

        Returns:
            ``(loss_a, loss_t)``: ``loss_a`` is the NLL loss averaged over
            the entries of ``a_weights`` plus ``alpha * kl_gaussian(argu)``;
            ``loss_t`` is the BCE-with-logits loss on the terminal output.
        """
        batch_input = to_device(padding_data(data))
        a_weights, t_weights, argu = self.user(
            batch_input['goals'], batch_input['goals_length'],
            batch_input['posts'], batch_input['posts_length'],
            batch_input['origin_responses'])

        # Presumably a_weights holds one score tensor per decoding step;
        # confirm against VHUS.forward.
        loss_a, targets_a = 0, batch_input['origin_responses'][:, 1:]  # remove sos_id
        num_steps = 0
        for num_steps, a_weight in enumerate(a_weights, start=1):
            loss_a += self.nll_loss(a_weight, targets_a[:, num_steps - 1])
        # Average over the number of steps. Dividing by the last index was
        # off by one and divided by zero for a single step.
        if num_steps:
            loss_a /= num_steps
        loss_t = self.bce_loss(t_weights, batch_input['terminal'])
        loss_a += self.alpha * kl_gaussian(argu)
        return loss_a, loss_t

    def imitating(self, epoch):
        """
        train the user simulator by simple imitation learning (behavioral cloning)

        Runs one pass over the training partition, logs the running mean
        losses every ``print_per_batch`` batches, saves a checkpoint every
        ``save_per_epoch`` epochs and leaves the model in eval mode.
        """
        self.user.train()
        a_loss, t_loss = 0., 0.
        for i, data in enumerate(self._batches(self.data_train)):
            self.optim.zero_grad()
            loss_a, loss_t = self.user_loop(data)
            a_loss += loss_a.item()
            t_loss += loss_t.item()
            loss = loss_a + loss_t
            loss.backward()
            self.optim.step()

            if (i+1) % self.print_per_batch == 0:
                a_loss /= self.print_per_batch
                t_loss /= self.print_per_batch
                logging.debug('<<user simulator>> epoch {}, iter {}, loss_a:{}, loss_t:{}'.format(epoch, i, a_loss, t_loss))
                # Restart the running sums for the next reporting window.
                a_loss, t_loss = 0., 0.

        if (epoch+1) % self.save_per_epoch == 0:
            self.save(self.save_dir, epoch)
        self.user.eval()

    def imit_test(self, epoch, best):
        """
        provide an unbiased evaluation of the user simulator fit on the training dataset

        Computes mean losses on the validation partition, saves the model as
        ``'best'`` when the summed validation loss improves on ``best``, then
        logs the mean losses on the test partition.

        Args:
            epoch: Epoch number, used only in log messages.
            best: Best (lowest) summed validation loss seen so far.

        Returns:
            The updated best summed validation loss.
        """
        valid = self._mean_losses(self.data_valid)
        if valid is None:
            logging.warning('<<user simulator>> validation, epoch {}: no batches, skipped'.format(epoch))
        else:
            a_loss, t_loss = valid
            logging.debug('<<user simulator>> validation, epoch {}, loss_a:{}, loss_t:{}'.format(epoch, a_loss, t_loss))
            loss = a_loss + t_loss
            if loss < best:
                logging.info('<<user simulator>> best model saved')
                best = loss
                self.save(self.save_dir, 'best')

        test = self._mean_losses(self.data_test)
        if test is None:
            logging.warning('<<user simulator>> test, epoch {}: no batches, skipped'.format(epoch))
        else:
            a_loss, t_loss = test
            logging.debug('<<user simulator>> test, epoch {}, loss_a:{}, loss_t:{}'.format(epoch, a_loss, t_loss))
        return best

    def test(self):
        """Evaluate on the test partition and print action F1 and terminal accuracy.

        Prints ``TP FP FN F1`` for the predicted dialog-act lists, then
        ``correct total accuracy`` for the terminal prediction. A ratio with
        a zero denominator is reported as 0.
        """
        skipped_tokens = ('<PAD>', '<UNK>', '<SOS>', '<EOS>', '(', ')')

        def sequential(da_seq):
            """Turn a token sequence into a list of ``'<act>-<word>'`` strings."""
            da = []
            cur_act = None
            for word in da_seq:
                if word in skipped_tokens:
                    continue
                # A token containing '-' opens a new act; other tokens are
                # attached to the current act.
                if '-' in word:
                    cur_act = word
                else:
                    if cur_act is None:
                        continue
                    da.append(cur_act+'-'+word)
            return da

        def f1(pred, real):
            """Count true positives, false positives and false negatives by membership."""
            if not real:
                return 0, 0, 0
            TP, FP, FN = 0, 0, 0
            for item in real:
                if item in pred:
                    TP += 1
                else:
                    FN += 1
            for item in pred:
                if item not in real:
                    FP += 1
            return TP, FP, FN

        a_TP, a_FP, a_FN, t_corr, t_tot = 0, 0, 0, 0, 0
        eos_id = self.user.usr_decoder.eos_id
        # Evaluation only: gradients are not needed.
        with torch.no_grad():
            for data in self._batches(self.data_test):
                batch_input = to_device(padding_data(data))
                a_weights, t_weights, argu = self.user(
                    batch_input['goals'], batch_input['goals_length'],
                    batch_input['posts'], batch_input['posts_length'],
                    batch_input['origin_responses'])
                # Take the arg-max id of every entry, then transpose so that
                # each row is one predicted id sequence.
                usr_a = [a_weight.argmax(1).cpu().numpy() for a_weight in a_weights]
                usr_a = np.array(usr_a).T.tolist()
                a = []
                for ua in usr_a:
                    # Cut the prediction at the first end-of-sequence id.
                    if eos_id in ua:
                        ua = ua[:ua.index(eos_id)]
                    a.append(sequential(self.manager.id2sentence(ua)))
                targets_a = []
                for ua_sess in data[1]:
                    for ua in ua_sess:
                        # Drop the first and last id of every target sequence.
                        targets_a.append(sequential(self.manager.id2sentence(ua[1:-1])))
                TP, FP, FN = f1(a, targets_a)
                a_TP += TP
                a_FP += FP
                a_FN += FN

                # A non-negative logit counts as a positive terminal prediction.
                t = t_weights.ge(0).cpu().tolist()
                targets_t = batch_input['terminal'].cpu().long().tolist()
                judge = np.array(t) == np.array(targets_t)
                t_corr += judge.sum()
                t_tot += judge.size

        # Guard every ratio so an empty or degenerate test set does not
        # raise ZeroDivisionError.
        prec = a_TP / (a_TP + a_FP) if (a_TP + a_FP) else 0.
        rec = a_TP / (a_TP + a_FN) if (a_TP + a_FN) else 0.
        F1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.
        print(a_TP, a_FP, a_FN, F1)
        print(t_corr, t_tot, t_corr/t_tot if t_tot else 0.)

    def save(self, directory, epoch):
        """Save the model weights to ``<directory>/<epoch>_simulator.mdl``.

        Args:
            directory: Target directory; created when missing.
            epoch: Epoch number or tag (e.g. ``'best'``) used as file prefix.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(directory, exist_ok=True)

        torch.save(self.user.state_dict(), os.path.join(directory, str(epoch) + '_simulator.mdl'))
        logging.info('<<user simulator>> epoch {}: saved network to mdl'.format(epoch))
Example #8
0
class UserPolicyVHUS(Policy):
    """User policy backed by the VHUS network, used for inference only."""

    def __init__(
        self,
        archive_file=DEFAULT_ARCHIVE_FILE,
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_camrest.zip'
    ):
        """Build the VHUS network and load its pretrained weights.

        Reads ``config.json`` from this module's directory, creates the
        ``VHUS`` network on ``DEVICE`` in eval mode, makes sure the model
        archive has been extracted into the ``save`` directory next to this
        module, and finally calls ``self.load(config['load'])``.

        Args:
            archive_file: Path of a local zip archive with the model. When it
                is not an existing file, ``model_file`` is resolved through
                ``cached_path`` instead.
            model_file: Fallback location handed to ``cached_path``.

        Raises:
            Exception: If ``archive_file`` is not a file and ``model_file``
                is empty.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(module_dir, 'config.json'), 'r') as f:
            config = json.load(f)
        manager = UserDataManager()
        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size,
                         voc_sys_size).to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = manager
        self.user.eval()

        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for VHUS Policy is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.join(module_dir, 'save')
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(model_dir, exist_ok=True)
        # Extract only once: skip when the checkpoint is already in place.
        if not os.path.exists(os.path.join(model_dir, 'best_simulator.mdl')):
            # The context manager closes the archive even if extraction fails.
            with zipfile.ZipFile(archive_file, 'r') as archive:
                archive.extractall(model_dir)
        self.load(config['load'])

    def init_session(self):
        """Reset the per-dialog state and draw a new user goal."""
        self.time_step = -1
        self.topic = 'NONE'
        self.goal = self.goal_gen.get_user_goal()
        # The goal is converted to a sequence, then to ids, once per session.
        self.goal_input = torch.LongTensor(
            self.manager.get_goal_id(self.manager.usrgoal2seq(self.goal)))
        # squeeze() turns the one-element length tensor into a scalar tensor.
        self.goal_len_input = torch.LongTensor([len(self.goal_input)
                                                ]).squeeze()
        self.sys_da_id_stack = []  # to save sys da history

    def predict(self, state):
        """Predict a user act from the latest system action.

        The system action is appended to the session history, and the whole
        history is padded and handed to the network together with the goal.

        Args:
            state:
                Used directly as the system action of the current turn and
                passed to ``manager.ref_data2stand``.
        Returns:
            usr_action:
                User act built by ``manager.usrseq2da``.
            terminal:
                Second value returned by ``self.user.select_action``
                (presumably True when the session should end; confirm
                against ``VHUS.select_action``).
        """
        sys_action = state

        sys_seq_turn = self.manager.sysda2seq(
            self.manager.ref_data2stand(sys_action), self.goal)
        self.sys_da_id_stack += self.manager.get_sysda_id([sys_seq_turn])
        # Lengths are clamped to at least 1 so an empty turn still has a
        # non-zero length.
        sys_seq_len = torch.LongTensor(
            [max(len(sen), 1) for sen in self.sys_da_id_stack])
        max_sen_len = sys_seq_len.max().item()
        sys_seq = torch.LongTensor(padding(self.sys_da_id_stack, max_sen_len))
        usr_a, terminal = self.user.select_action(self.goal_input,
                                                  self.goal_len_input, sys_seq,
                                                  sys_seq_len)
        usr_action = self.manager.usrseq2da(self.manager.id2sentence(usr_a),
                                            self.goal)

        return usr_action, terminal

    def load(self, filename):
        """Load network weights from ``<filename>_simulator.mdl``.

        The path is resolved relative to this module's directory. When the
        file does not exist the network keeps its current weights and a
        warning is printed.

        Args:
            filename (str):
                Checkpoint path prefix, relative to this module's directory.
        """
        user_mdl = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                filename + '_simulator.mdl')
        if os.path.exists(user_mdl):
            # map_location keeps a checkpoint saved on another device (e.g. a
            # GPU) loadable when DEVICE is different (e.g. CPU only).
            self.user.load_state_dict(
                torch.load(user_mdl, map_location=DEVICE))
            print('<<user simulator>> loaded checkpoint from file: {}'.format(
                user_mdl))
        else:
            # Make the fallback to the current weights visible.
            print('<<user simulator>> checkpoint not found, weights not '
                  'loaded: {}'.format(user_mdl))