Beispiel #1
0
    def __init__(self):
        """Set up the reader, model, evaluator and optimizer from ``cfg``.

        The model is wrapped in ``torch.nn.DataParallel`` unless
        ``cfg.cuda_device`` holds exactly one entry, and is moved to CUDA
        when ``cfg.cuda`` is truthy.  When ``cfg.limit_bspn_vocab`` /
        ``cfg.limit_aspn_vocab`` are set, the reader's ``bspn_masks`` /
        ``aspn_masks`` are converted once into long tensors and stored on
        the reader as ``bspn_masks_tensor`` / ``aspn_masks_tensor``.
        """

        def masks_as_tensors(masks):
            # Same conversion as before: float tensor first, then cast to
            # long, then passed through the project's cuda_ helper.
            return {
                key: cuda_(torch.Tensor(ids).long())
                for key, ids in masks.items()
            }

        self.reader = MultiWozReader()

        single_device = len(cfg.cuda_device) == 1
        model = DAMD(self.reader)
        if not single_device:
            model = torch.nn.DataParallel(model, device_ids=cfg.cuda_device)
        self.m = model

        self.evaluator = MultiWozEvaluator(self.reader)

        if cfg.cuda:
            self.m = self.m.cuda()

        # Only parameters that require gradients are handed to the optimizer.
        trainable = (p for p in self.m.parameters() if p.requires_grad)
        self.optim = Adam(lr=cfg.lr, params=trainable, weight_decay=5e-5)
        # presumably the index of the last finished epoch (-1 = none yet);
        # verify against the checkpoint-loading code.
        self.base_epoch = -1

        if cfg.limit_bspn_vocab:
            self.reader.bspn_masks_tensor = masks_as_tensors(
                self.reader.bspn_masks)
        if cfg.limit_aspn_vocab:
            self.reader.aspn_masks_tensor = masks_as_tensors(
                self.reader.aspn_masks)
Beispiel #2
0
    def _convert_batch_para(self, py_batch, mode, prev_a_py=None):
        """Turn one python-level batch into padded arrays and tensors.

        Args:
            py_batch: mapping that provides the ``'pv_resp'``, ``'usdx'``,
                ``'padx'`` and ``'pv_aspn'`` sequences for the batch.
            mode: when equal to ``'test'`` a non-empty ``prev_a_py`` replaces
                ``py_batch['pv_aspn']`` as the previous-act input.
            prev_a_py: optional list of previously generated act sequences
                (lists of token ids).  It is modified in place: every entry
                is cut right after its first ``<eos_a>`` token, or replaced
                by ``[<eos_a>]`` when that token is missing or is already
                the last element.

        Returns:
            ``(u_input, u_input_np, delex_para_input, delex_para_input_np,
            u_len, prev_dial_act_input)``.  ``u_len`` holds the lengths of
            the concatenated sequences before padding/truncation.
        """

        def pad_and_transpose(sequences):
            # assumes pad_sequences returns a 2-D (batch, length) array, so
            # the result is (length, batch) - TODO confirm.
            padded = pad_sequences(sequences,
                                   cfg.max_nl_length,
                                   padding='post',
                                   truncating='pre')
            return padded.transpose((1, 0))

        def as_long_tensor(array):
            return cuda_(torch.from_numpy(array).long())

        # User input is the previous response followed by the delexicalised
        # user utterance.
        combined_input = [
            prev_response + delex_user
            for prev_response, delex_user in zip(py_batch['pv_resp'],
                                                 py_batch['usdx'])
        ]
        u_input_np = pad_and_transpose(combined_input)
        u_len = np.array([len(sequence) for sequence in combined_input])
        delex_para_input_np = pad_and_transpose(py_batch['padx'])

        u_input = as_long_tensor(u_input_np)
        delex_para_input = as_long_tensor(delex_para_input_np)

        if mode == 'test' and prev_a_py:
            for i, acts in enumerate(prev_a_py):
                eos_a = self.reader.vocab.encode('<eos_a>')
                trimmed = [eos_a]
                if eos_a in acts:
                    cut = acts.index(eos_a)
                    if cut != len(acts) - 1:
                        trimmed = acts[:cut + 1]
                prev_a_py[i] = trimmed
                # NOTE: ids >= cfg.vocab_size are not mapped to <unk> here;
                # that remapping existed only as disabled code.
            prev_act_sequences = prev_a_py
        else:
            prev_act_sequences = py_batch['pv_aspn']

        prev_dial_act_input_np = pad_and_transpose(prev_act_sequences)
        prev_dial_act_input = as_long_tensor(prev_dial_act_input_np)

        return (u_input, u_input_np, delex_para_input, delex_para_input_np,
                u_len, prev_dial_act_input)
Beispiel #3
0
    def add_torch_input(self, inputs, mode='train', first_turn=False):
        """Add tensor versions of the numpy arrays in ``inputs`` and return it.

        For every enabled field ``X`` this writes ``inputs[X]`` (from
        ``X_unk_np``), ``inputs[X + '_nounk']`` and, where applicable,
        ``X_onehot``.  For fields that have a previous-turn counterpart it
        also writes ``pv_X``, ``pv_X_nounk``, ``X_4loss`` and ``pv_X_onehot``.
        ``inputs`` is updated in place.

        Args:
            inputs: dict holding the ``*_np`` / ``*_unk_np`` arrays and
                ``'db_np'``.
            mode: accepted for interface compatibility; not read here.
            first_turn: accepted for interface compatibility; not read here.

        Returns:
            The same ``inputs`` dict.
        """

        def as_long(array):
            return cuda_(torch.from_numpy(array).long())

        onehot_fields = (
            'user', 'usdx', 'bspn', 'aspn', 'pv_resp', 'pv_bspn', 'pv_aspn',
            'dspn', 'pv_dspn', 'bsdx', 'pv_bsdx',
        )
        raw_id_fields = ('user', 'usdx', 'resp', 'bspn')
        previous_turn_fields = ('resp', 'bspn', 'aspn', 'bsdx', 'dspn')

        inputs['db'] = cuda_(torch.from_numpy(inputs['db_np']).float())

        # Fields that can be switched off through the config.
        enabled = {
            'aspn': cfg.enable_aspn,
            'bspn': cfg.enable_bspn,
            'dspn': cfg.enable_dspn,
        }

        for name in ('user', 'usdx', 'resp', 'bspn', 'aspn', 'bsdx', 'dspn'):
            if not enabled.get(name, True):
                continue

            unk_key = name + '_unk_np'
            # '_unk_np' arrays have out-of-vocabulary ids replaced by <unk>.
            inputs[name] = as_long(inputs[unk_key])
            if name in raw_id_fields:
                # '_np' arrays keep the out-of-vocabulary ids.
                inputs[name + '_nounk'] = as_long(inputs[name + '_np'])
            else:
                inputs[name + '_nounk'] = inputs[name]

            if name in previous_turn_fields:
                pv_name = 'pv_' + name
                pv_unk_key = pv_name + '_unk_np'
                if pv_unk_key not in inputs:
                    # NOTE(review): this also skips the one-hot encoding of
                    # the current field below - confirm that is intended.
                    continue
                inputs[pv_name] = as_long(inputs[pv_unk_key])
                # Within this branch only 'bspn' can match this test.
                if name in ('user', 'usdx', 'bspn'):
                    inputs[pv_name + '_nounk'] = as_long(
                        inputs[pv_name + '_np'])
                    inputs[name + '_4loss'] = self.index_for_loss(name, inputs)
                else:
                    inputs[pv_name + '_nounk'] = inputs[pv_name]
                    inputs[name + '_4loss'] = inputs[name]
                if pv_name in onehot_fields:
                    inputs[pv_name + '_onehot'] = get_one_hot_input(
                        inputs[pv_unk_key])

            if name in onehot_fields:
                inputs[name + '_onehot'] = get_one_hot_input(inputs[unk_key])

        if cfg.multi_acts_training and 'aspn_aug_unk_np' in inputs:
            augmented = as_long(inputs['aspn_aug_unk_np'])
            inputs['aspn_aug'] = augmented
            inputs['aspn_aug_4loss'] = augmented

        return inputs
Beispiel #4
0
 def index_for_loss(self, item, inputs):
     """Build the loss targets for ``item``, masking uncopyable OOV labels.

     A label whose id is ``>= self.reader.vocab_size`` is kept only if the
     same id occurs in the same batch row of one of the copy sources for
     ``item``; otherwise it is replaced by id 2 (the ``<unk>`` id used by
     this project - verify against the vocabulary definition).

     Args:
         item: one of ``'bspn'``, ``'bsdx'``, ``'aspn'``, ``'dspn'``,
             ``'resp'``.  Any other value makes the method return ``None``.
         inputs: dict of 2-D numpy id arrays keyed ``'<name>_np'``.  The
             copy sources are concatenated along axis 1, so they must
             agree on their first dimension.

     Returns:
         A long tensor (passed through ``cuda_``) with the adjusted labels,
         or ``None`` for an unsupported ``item``.  When no copy source is
         configured for ``item`` the labels are returned unchanged.
     """
     raw_labels = inputs[item + '_np']
     if item == 'bspn':
         copy_sources = [
             inputs['user_np'], inputs['pv_resp_np'], inputs['pv_bspn_np']
         ]
     elif item == 'bsdx':
         copy_sources = [
             inputs['usdx_np'], inputs['pv_resp_np'], inputs['pv_bsdx_np']
         ]
     elif item == 'aspn':
         copy_sources = []
         if cfg.use_pvaspn:
             copy_sources.append(inputs['pv_aspn_np'])
         if cfg.enable_bspn:
             copy_sources.append(inputs[cfg.bspn_mode + '_np'])
     elif item == 'dspn':
         copy_sources = [inputs['pv_dspn_np']]
     elif item == 'resp':
         copy_sources = [inputs['usdx_np']]
         if cfg.enable_bspn:
             copy_sources.append(inputs[cfg.bspn_mode + '_np'])
         if cfg.enable_aspn:
             copy_sources.append(inputs['aspn_np'])
     else:
         return
     new_labels = np.copy(raw_labels)
     if copy_sources:
         oov_rows, oov_cols = np.where(raw_labels >= self.reader.vocab_size)
         copy_sources = np.concatenate(copy_sources, axis=1)
         # np.where yields *paired* row/column indices.  Walk them in
         # lockstep: nesting the two loops would visit the Cartesian
         # product, touching positions that are not OOV and overwriting
         # in-vocabulary labels that merely do not occur in the sources.
         for b, t in zip(oov_rows, oov_cols):
             if not np.any(copy_sources[b, :] == raw_labels[b, t]):
                 new_labels[b, t] = 2
     return cuda_(torch.from_numpy(new_labels).long())