Example #1
    def __init__(self, pretrain=False):
        # Build the MultiWOZ config, the data manager and the VHUS user-simulator model.
        config = MultiWozConfig()
        manager = UserDataManager(config.data_dir, config.data_file)
        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size, voc_sys_size).to(device=DEVICE)
        self.optim = optim.Adam(self.user.parameters(), lr=config.lr_simu)
        self.goal_gen = GoalGenerator(config.data_dir + '/goal/goal_model.pkl')
        self.cfg = config
        self.manager = manager
        self.user.eval()

        if pretrain:
            # Pre-training: load the segmented dialog data and split it into
            # train / test / validation sets.
            self.print_per_batch = config.print_per_batch
            self.save_dir = config.save_dir
            self.save_per_epoch = config.save_per_epoch
            seq_goals, seq_usr_dass, seq_sys_dass = manager.data_loader_seg()
            train_goals, train_usrdas, train_sysdas, \
            test_goals, test_usrdas, test_sysdas, \
            val_goals, val_usrdas, val_sysdas = manager.train_test_val_split_seg(
                seq_goals, seq_usr_dass, seq_sys_dass)
            self.data_train = (train_goals, train_usrdas, train_sysdas, config.batchsz)
            self.data_valid = (val_goals, val_usrdas, val_sysdas, config.batchsz)
            self.data_test = (test_goals, test_usrdas, test_sysdas, config.batchsz)
            self.nll_loss = nn.NLLLoss(ignore_index=0)  # PAD token id is 0
            self.bce_loss = nn.BCEWithLogitsLoss()
        else:
            # Inference only: load pretrained simulator weights.
            self.load(config.load)
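
A note on the losses set up in the pretrain branch: nn.NLLLoss(ignore_index=0) keeps padded positions (PAD id 0) out of the token-level loss. The standalone sketch below uses toy tensors (not taken from the example above) purely to illustrate that masking:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Two target positions over a 5-token vocabulary; the second target is the PAD id 0
# and is therefore excluded from the averaged loss.
log_probs = F.log_softmax(torch.randn(2, 5), dim=-1)
targets = torch.tensor([3, 0])

loss = nn.NLLLoss(ignore_index=0)(log_probs, targets)
print(loss.item())  # mean over non-PAD positions only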
Example #2
            usr_action (tuple): User act.
            session_over (boolean): True to terminate the session; otherwise the session continues.
            reward (float): Reward given by the user.
        """
        # Convert the system dialog act to a token-id sequence and append it to the history stack.
        sys_seq_turn = self.manager.sysda2seq(self.manager.ref_data2stand(sys_action), self.goal)
        self.sys_da_id_stack += self.manager.get_sysda_id([sys_seq_turn])
        # Pad every turn in the stack to the same length so it can form a rectangular tensor.
        sys_seq_len = torch.LongTensor([max(len(sen), 1) for sen in self.sys_da_id_stack])
        max_sen_len = sys_seq_len.max().item()
        sys_seq = torch.LongTensor(padding(self.sys_da_id_stack, max_sen_len))
        # Let the VHUS model select the next user act and decide whether to end the session.
        usr_a, terminal = self.user.select_action(self.goal_input, self.goal_len_input, sys_seq, sys_seq_len)
        usr_action = self.manager.usrseq2da(self.manager.id2sentence(usr_a), self.goal)

        return capital(usr_action), terminal

if __name__ == '__main__':
    # Quick data-pipeline check: load the annotated MultiWOZ user dialog acts,
    # split them, build batch iterators and print the tensor shapes of one batch.
    manager = UserDataManager('../../../../data/multiwoz', 'annotated_user_da_with_span_full.json')
    seq_goals, seq_usr_dass, seq_sys_dass = manager.data_loader_seg()
    train_goals, train_usrdas, train_sysdas, \
    test_goals, test_usrdas, test_sysdas, \
    val_goals, val_usrdas, val_sysdas = manager.train_test_val_split_seg(
        seq_goals, seq_usr_dass, seq_sys_dass)
    data_train = batch_iter(train_goals, train_usrdas, train_sysdas, 32)
    data_valid = batch_iter(val_goals, val_usrdas, val_sysdas, 32)
    data_test = batch_iter(test_goals, test_usrdas, test_sysdas, 32)
    for data in data_train:
        batch_input = to_device(padding_data(data))
        break  # inspect only the first batch
    for k, v in batch_input.items():
        print(k, v.shape)
    voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
    cfg = MultiWozConfig()
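
Both examples lean on small helpers such as padding, padding_data and to_device that are defined elsewhere in the module and are not shown here. The sketch below is only an assumption of what such helpers plausibly do, written to make the data flow concrete; it is not the actual library code.

import torch

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def padding(seqs, max_len, pad_id=0):
    # Assumed behavior: right-pad every id sequence to max_len with the PAD id (0)
    # so the list of sequences can become a rectangular LongTensor.
    return [seq + [pad_id] * (max_len - len(seq)) for seq in seqs]

def to_device(batch):
    # Assumed behavior: move every tensor in a batch dict onto the target device.
    return {k: v.to(DEVICE) for k, v in batch.items()}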