def __init__(self):
    self.reader = MultiWozReader()
    if len(cfg.cuda_device) == 1:
        self.m = DAMD(self.reader)
    else:
        m = DAMD(self.reader)
        self.m = torch.nn.DataParallel(m, device_ids=cfg.cuda_device)
        # print(self.m.module)
    self.evaluator = MultiWozEvaluator(self.reader)  # evaluator class
    if cfg.cuda:
        self.m = self.m.cuda()  # cfg.cuda_device[0]
    self.optim = Adam(lr=cfg.lr,
                      params=filter(lambda x: x.requires_grad, self.m.parameters()),
                      weight_decay=5e-5)
    self.base_epoch = -1

    if cfg.limit_bspn_vocab:
        self.reader.bspn_masks_tensor = {}
        for key, values in self.reader.bspn_masks.items():
            v_ = cuda_(torch.Tensor(values).long())
            self.reader.bspn_masks_tensor[key] = v_
    if cfg.limit_aspn_vocab:
        self.reader.aspn_masks_tensor = {}
        for key, values in self.reader.aspn_masks.items():
            v_ = cuda_(torch.Tensor(values).long())
            self.reader.aspn_masks_tensor[key] = v_
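# Note: `cuda_` is a helper defined elsewhere in this repo; it is assumed to move a
# tensor to the GPU only when cfg.cuda is enabled. A minimal sketch under that
# assumption (not necessarily the repo's exact implementation):
#
#     def cuda_(var):
#         return var.cuda() if cfg.cuda else var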
def _convert_batch_para(self, py_batch, mode, prev_a_py=None):
    combined_input = []
    combined_length = []
    for prev_response, delex_user in zip(py_batch['pv_resp'], py_batch['usdx']):
        combined_input.append(prev_response + delex_user)
        combined_length.append(len(prev_response + delex_user))
    u_input_np = pad_sequences(combined_input, cfg.max_nl_length,
                               padding='post', truncating='pre').transpose((1, 0))
    u_len = np.array(combined_length)
    delex_para_input_np = pad_sequences(py_batch['padx'], cfg.max_nl_length,
                                        padding='post', truncating='pre').transpose((1, 0))
    u_input = cuda_(torch.from_numpy(u_input_np).long())
    delex_para_input = cuda_(torch.from_numpy(delex_para_input_np).long())

    if mode == 'test':
        if prev_a_py:
            # clean up previously decoded dialog acts around the <eos_a> end-of-act marker
            for i in range(len(prev_a_py)):
                eob = self.reader.vocab.encode('<eos_a>')
                if eob in prev_a_py[i] and prev_a_py[i].index(eob) != len(prev_a_py[i]) - 1:
                    idx = prev_a_py[i].index(eob)
                    prev_a_py[i] = prev_a_py[i][:idx + 1]
                else:
                    prev_a_py[i] = [eob]
                '''
                for j, word in enumerate(prev_a_py[i]):
                    if word >= cfg.vocab_size:
                        prev_a_py[i][j] = 2  # unk
                '''
        else:
            prev_a_py = py_batch['pv_aspn']
        prev_dial_act_input_np = pad_sequences(prev_a_py, cfg.max_nl_length,
                                               padding='post', truncating='pre').transpose((1, 0))
        prev_dial_act_input = cuda_(torch.from_numpy(prev_dial_act_input_np).long())
    else:
        prev_dial_act_input_np = pad_sequences(py_batch['pv_aspn'], cfg.max_nl_length,
                                               padding='post', truncating='pre').transpose((1, 0))
        prev_dial_act_input = cuda_(torch.from_numpy(prev_dial_act_input_np).long())

    return u_input, u_input_np, delex_para_input, delex_para_input_np, u_len, prev_dial_act_input
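# `pad_sequences` is assumed to follow Keras-style semantics here: padding='post'
# appends zeros after each sequence and truncating='pre' drops the oldest tokens when a
# sequence exceeds cfg.max_nl_length; the final .transpose((1, 0)) turns the [B, T]
# array into the time-major [T, B] layout the encoder consumes. Toy illustration
# (hypothetical token ids):
#
#     pad_sequences([[5, 6], [7, 8, 9]], 3, padding='post', truncating='pre')
#     # -> array([[5, 6, 0],
#     #           [7, 8, 9]])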
def add_torch_input(self, inputs, mode='train', first_turn=False):
    need_onehot = ['user', 'usdx', 'bspn', 'aspn', 'pv_resp', 'pv_bspn',
                   'pv_aspn', 'dspn', 'pv_dspn', 'bsdx', 'pv_bsdx']
    inputs['db'] = cuda_(torch.from_numpy(inputs['db_np']).float())
    for item in ['user', 'usdx', 'resp', 'bspn', 'aspn', 'bsdx', 'dspn']:
        if not cfg.enable_aspn and item == 'aspn':
            continue
        if not cfg.enable_bspn and item == 'bspn':
            continue
        if not cfg.enable_dspn and item == 'dspn':
            continue
        inputs[item] = cuda_(torch.from_numpy(inputs[item + '_unk_np']).long())  # OOV words replaced with <unk>
        if item in ['user', 'usdx', 'resp', 'bspn']:
            inputs[item + '_nounk'] = cuda_(torch.from_numpy(inputs[item + '_np']).long())  # OOV words kept
        else:
            inputs[item + '_nounk'] = inputs[item]
        # print(item, inputs[item].size())

        if item in ['resp', 'bspn', 'aspn', 'bsdx', 'dspn']:
            if 'pv_' + item + '_unk_np' not in inputs:
                continue
            inputs['pv_' + item] = cuda_(torch.from_numpy(inputs['pv_' + item + '_unk_np']).long())
            if item in ['user', 'usdx', 'bspn']:
                inputs['pv_' + item + '_nounk'] = cuda_(torch.from_numpy(inputs['pv_' + item + '_np']).long())
                inputs[item + '_4loss'] = self.index_for_loss(item, inputs)
            else:
                inputs['pv_' + item + '_nounk'] = inputs['pv_' + item]
                inputs[item + '_4loss'] = inputs[item]
            if 'pv_' + item in need_onehot:
                inputs['pv_' + item + '_onehot'] = get_one_hot_input(inputs['pv_' + item + '_unk_np'])
        if item in need_onehot:
            inputs[item + '_onehot'] = get_one_hot_input(inputs[item + '_unk_np'])

    if cfg.multi_acts_training and 'aspn_aug_unk_np' in inputs:
        inputs['aspn_aug'] = cuda_(torch.from_numpy(inputs['aspn_aug_unk_np']).long())
        inputs['aspn_aug_4loss'] = inputs['aspn_aug']
    return inputs
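# `get_one_hot_input` comes from the repo's utilities; it is assumed to build a one-hot
# encoding of the unk-replaced id array for the copy/attention mechanism. A rough sketch
# under that assumption (the real helper's shape conventions may differ):
#
#     def get_one_hot_input_sketch(x_np):  # x_np: [B, T] int array with ids < cfg.vocab_size
#         x = cuda_(torch.from_numpy(x_np).long())
#         return torch.nn.functional.one_hot(x, cfg.vocab_size).float()  # [B, T, V]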
def index_for_loss(self, item, inputs):
    raw_labels = inputs[item + '_np']
    if item == 'bspn':
        copy_sources = [inputs['user_np'], inputs['pv_resp_np'], inputs['pv_bspn_np']]
    elif item == 'bsdx':
        copy_sources = [inputs['usdx_np'], inputs['pv_resp_np'], inputs['pv_bsdx_np']]
    elif item == 'aspn':
        copy_sources = []
        if cfg.use_pvaspn:
            copy_sources.append(inputs['pv_aspn_np'])
        if cfg.enable_bspn:
            copy_sources.append(inputs[cfg.bspn_mode + '_np'])
    elif item == 'dspn':
        copy_sources = [inputs['pv_dspn_np']]
    elif item == 'resp':
        copy_sources = [inputs['usdx_np']]
        if cfg.enable_bspn:
            copy_sources.append(inputs[cfg.bspn_mode + '_np'])
        if cfg.enable_aspn:
            copy_sources.append(inputs['aspn_np'])
    else:
        return
    new_labels = np.copy(raw_labels)
    if copy_sources:
        bidx, tidx = np.where(raw_labels >= self.reader.vocab_size)
        copy_sources = np.concatenate(copy_sources, axis=1)
        # np.where returns paired (batch, time) indices of the OOV labels, so iterate
        # them together; an OOV label that cannot be copied from any source is mapped
        # to <unk> so the loss never asks for an impossible copy
        for b, t in zip(bidx, tidx):
            oov_idx = raw_labels[b, t]
            if len(np.where(copy_sources[b, :] == oov_idx)[0]) == 0:
                new_labels[b, t] = 2  # <unk>
    return cuda_(torch.from_numpy(new_labels).long())
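# Illustration of the label post-processing above (hypothetical ids): suppose
# self.reader.vocab_size is 3000 and a 'bspn' label contains the OOV id 3105. If 3105
# does not occur anywhere in that example's concatenated copy sources (user utterance,
# previous response, previous belief span), the corresponding '_4loss' label is rewritten
# to 2 (<unk>). The resulting tensor is what add_torch_input stores, e.g.:
#
#     inputs['bspn_4loss'] = self.index_for_loss('bspn', inputs)  # cuda LongTensor, same shape as inputs['bspn_np']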