예제 #1
0
 def imit_test(self, epoch, best):
     """
     Evaluate the user simulator on the validation and test splits.

     The two losses returned by `self.user_loop` (`loss_a`, `loss_t`) are
     averaged over the validation batches and logged. When their sum is lower
     than `best`, the model is saved under the tag 'best'. The same two means
     are then computed and logged for the test split.

     Args:
         epoch: epoch index, used only in the log messages.
         best: lowest validation loss (loss_a + loss_t) seen so far.

     Returns:
         The updated best validation loss (unchanged if no improvement).
     """
     def mean_losses(dataset):
         # Mean of each loss over the batches actually produced for `dataset`.
         # Counting batches explicitly avoids dividing by the last zero-based
         # index, which under-counts by one and is zero for a single batch.
         total_a, total_t, num_batches = 0., 0., 0
         for data in batch_iter(dataset[0], dataset[1], dataset[2], dataset[3]):
             loss_a, loss_t = self.user_loop(data)
             total_a += loss_a.item()
             total_t += loss_t.item()
             num_batches += 1
         if num_batches == 0:
             # NaN never compares lower than `best`, so an empty split cannot
             # trigger a checkpoint save.
             return float('nan'), float('nan')
         return total_a / num_batches, total_t / num_batches

     a_loss, t_loss = mean_losses(self.data_valid)
     logging.debug('<<user simulator>> validation, epoch {}, loss_a:{}, loss_t:{}'.format(epoch, a_loss, t_loss))
     loss = a_loss + t_loss
     if loss < best:
         logging.info('<<user simulator>> best model saved')
         best = loss
         self.save(self.save_dir, 'best')

     a_loss, t_loss = mean_losses(self.data_test)
     logging.debug('<<user simulator>> test, epoch {}, loss_a:{}, loss_t:{}'.format(epoch, a_loss, t_loss))
     return best
예제 #2
0
 def imitating(self, epoch):
     """
     Run one epoch of imitation learning (behavioral cloning) on the user simulator.

     Every batch of the training split goes through `self.user_loop`; the sum
     of the two returned losses is back-propagated and the optimizer stepped.
     The mean of each loss over the last `self.print_per_batch` batches is
     logged at that interval, and the model is saved every
     `self.save_per_epoch` epochs. The model is left in eval mode on return.
     """
     self.user.train()
     running_a = running_t = 0.
     train_set = self.data_train
     batches = batch_iter(train_set[0], train_set[1], train_set[2], train_set[3])

     for step, batch in enumerate(batches, start=1):
         self.optim.zero_grad()
         loss_a, loss_t = self.user_loop(batch)
         running_a += loss_a.item()
         running_t += loss_t.item()
         (loss_a + loss_t).backward()
         self.optim.step()

         if step % self.print_per_batch != 0:
             continue
         # The log reports the zero-based batch index, hence step - 1.
         message = '<<user simulator>> epoch {}, iter {}, loss_a:{}, loss_t:{}'.format(
             epoch, step - 1,
             running_a / self.print_per_batch,
             running_t / self.print_per_batch)
         logging.debug(message)
         running_a = running_t = 0.

     if (epoch + 1) % self.save_per_epoch == 0:
         self.save(self.save_dir, epoch)
     self.user.eval()
예제 #3
0
    def test(self):
        """
        Evaluate the user simulator on the test split and print the scores.

        Two lines are printed:
          1. accumulated TP, FP, FN counts for the decoded user actions and
             the resulting F1;
          2. number of correct terminal predictions, number of terminal
             predictions, and their ratio.

        Any ratio whose denominator is zero (no matches, or no test batches)
        is reported as 0.0 instead of raising ZeroDivisionError.
        """
        def sequential(da_seq):
            # Flatten a token sequence into '<act>-<word>' strings: a token
            # containing '-' becomes the current act, and each following plain
            # token is attached to it. Special tokens and brackets are skipped,
            # as are plain tokens seen before any act.
            da = []
            cur_act = None
            for word in da_seq:
                if word in ('<PAD>', '<UNK>', '<SOS>', '<EOS>', '(', ')'):
                    continue
                if '-' in word:
                    cur_act = word
                elif cur_act is not None:
                    da.append(cur_act + '-' + word)
            return da

        def f1(pred, real):
            # Count TP/FP/FN by membership: an element of `real` found anywhere
            # in `pred` is a TP, otherwise an FN; an element of `pred` absent
            # from `real` is an FP. An empty `real` contributes nothing.
            # NOTE(review): the caller passes lists whose elements are whole
            # per-sample act lists, so matching is on complete act lists and is
            # position-independent within the batch - confirm this is intended.
            if not real:
                return 0, 0, 0
            TP, FP, FN = 0, 0, 0
            for item in real:
                if item in pred:
                    TP += 1
                else:
                    FN += 1
            for item in pred:
                if item not in real:
                    FP += 1
            return TP, FP, FN

        data_test_iter = batch_iter(self.data_test[0], self.data_test[1], self.data_test[2], self.data_test[3])
        a_TP, a_FP, a_FN, t_corr, t_tot = 0, 0, 0, 0, 0
        eos_id = self.user.usr_decoder.eos_id
        for data in data_test_iter:
            batch_input = to_device(padding_data(data))
            a_weights, t_weights, argu = self.user(
                batch_input['goals'], batch_input['goals_length'],
                batch_input['posts'], batch_input['posts_length'],
                batch_input['origin_responses'])

            # One argmax row per element of a_weights; the transpose regroups
            # the ids so that each row holds one sample's decoded sequence.
            usr_a = [a_weight.argmax(1).cpu().numpy() for a_weight in a_weights]
            usr_a = np.array(usr_a).T.tolist()
            a = []
            for ua in usr_a:
                if eos_id in ua:
                    # Keep only the ids before the first end-of-sequence id.
                    ua = ua[:ua.index(eos_id)]
                a.append(sequential(self.manager.id2sentence(ua)))

            # Reference sequences, with the first and last id of each stripped.
            targets_a = []
            for ua_sess in data[1]:
                for ua in ua_sess:
                    targets_a.append(sequential(self.manager.id2sentence(ua[1:-1])))

            TP, FP, FN = f1(a, targets_a)
            a_TP += TP
            a_FP += FP
            a_FN += FN

            # A terminal is predicted where t_weights >= 0.
            t = t_weights.ge(0).cpu().tolist()
            targets_t = batch_input['terminal'].cpu().long().tolist()
            judge = np.array(t) == np.array(targets_t)
            t_corr += judge.sum()
            t_tot += judge.size

        # Guard every ratio: with no matches or no batches the denominators are zero.
        prec = a_TP / (a_TP + a_FP) if a_TP + a_FP else 0.
        rec = a_TP / (a_TP + a_FN) if a_TP + a_FN else 0.
        F1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.
        t_acc = t_corr / t_tot if t_tot else 0.
        print(a_TP, a_FP, a_FN, F1)
        print(t_corr, t_tot, t_acc)
예제 #4
0
        sys_seq_len = torch.LongTensor([max(len(sen), 1) for sen in self.sys_da_id_stack])
        max_sen_len = sys_seq_len.max().item()
        sys_seq = torch.LongTensor(padding(self.sys_da_id_stack, max_sen_len))
        usr_a, terminal = self.user.select_action(self.goal_input, self.goal_len_input, sys_seq, sys_seq_len)
        usr_action = self.manager.usrseq2da(self.manager.id2sentence(usr_a), self.goal)
        
        return capital(usr_action), terminal

if __name__ == '__main__':
    # Smoke run: load the user data, build one padded batch and push it
    # through a freshly constructed VHUS model, printing the output sizes.
    manager = UserDataManager(
        '../../../../data/multiwoz',
        'annotated_user_da_with_span_full.json',
    )
    seq_goals, seq_usr_dass, seq_sys_dass = manager.data_loader_seg()
    (train_goals, train_usrdas, train_sysdas,
     test_goals, test_usrdas, test_sysdas,
     val_goals, val_usrdas, val_sysdas) = manager.train_test_val_split_seg(
        seq_goals, seq_usr_dass, seq_sys_dass)

    data_train = batch_iter(train_goals, train_usrdas, train_sysdas, 32)
    data_valid = batch_iter(val_goals, val_usrdas, val_sysdas, 32)
    data_test = batch_iter(test_goals, test_usrdas, test_sysdas, 32)

    # Only the first training batch is needed.
    for first_batch in data_train:
        batch_input = to_device(padding_data(first_batch))
        break
    for field, tensor in batch_input.items():
        print(field, tensor.shape)

    voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
    cfg = MultiWozConfig()
    user = VHUS(cfg, voc_goal_size, voc_usr_size, voc_sys_size)

    # batch_input['origin_responses'] is not passed in this call.
    a_weights, t_weights, _ = user(
        batch_input['goals'], batch_input['goals_length'],
        batch_input['posts'], batch_input['posts_length'])
    print(len(a_weights))  # a_weights noted by the original author as [L, B, V]
    print(t_weights.shape)  # noted by the original author as [B]