def decode_value(model, utterance, class_string, memory, cuda):
    """Decode the slot value for the act-slot pair encoded in ``class_string``.

    The utterance is encoded and the value decoder is run with ``beam_search``,
    conditioned on the first act-slot pair. The resulting ids are mapped back to
    words and the joined string is normalised with ``correct_value`` for the
    decoded slot.

    Args:
        model: Object exposing ``encoder``, ``enc_to_dec`` and ``value_decoder``.
        utterance: Raw input; passed to ``process_sent`` and
            ``ValueDataset.data_info``.
        class_string: Label string passed to ``ValueDataset.label_info``.
        memory: Vocabulary dict; this function reads ``'dec2idx'``,
            ``'idx2dec'`` and ``'idx2slot'``.
        cuda: Flag forwarded to the dataset helpers and ``beam_search``.

    Returns:
        ``[]`` if ``process_sent`` yields no sentences or ``correct_value``
        returns ``None``. Otherwise a one-element list holding the corrected
        value string.
    """
    # NOTE(review): another function named ``decode_value`` is defined later in
    # this file; at import time that later definition shadows this one --
    # confirm which variant callers are meant to get.
    sent_lis = process_sent(utterance)
    if len(sent_lis) == 0:
        return []

    data, lengths, extra_zeros, enc_batch_extend_vocab_idx, oov_list = (
        ValueDataset.data_info(utterance, memory, cuda))
    # Only the act-slot pairs are needed here. The other outputs of
    # label_info are not used during decoding.
    _, act_slot_pairs, _, _ = ValueDataset.label_info(
        class_string, memory, oov_list, cuda)

    # Encoder.
    outputs, hiddens = model.encoder(data, lengths)

    # Value decoder: beam search from the encoder-derived initial state,
    # conditioned on the first act-slot pair.
    s_decoder = model.enc_to_dec(hiddens)
    act_slot_ids = act_slot_pairs[0]
    # [1:-1] drops the first and last ids, presumably BOS/EOS.
    # NOTE(review): if beam_search can stop without emitting EOS, this drops a
    # real token -- confirm against beam_search.
    value_ids = beam_search(
        model.value_decoder, act_slot_ids, extra_zeros,
        enc_batch_extend_vocab_idx, s_decoder, outputs, lengths,
        len(memory['dec2idx']), cuda)[1:-1]

    # Ids inside the decoder vocabulary map through idx2dec. Larger ids index
    # into this example's oov_list (presumably copied out-of-vocabulary words).
    # NOTE(review): beam_search is sized with len(dec2idx) while the split here
    # uses len(idx2dec); these are assumed equal -- confirm.
    idx2dec = memory['idx2dec']
    vocab_size = len(idx2dec)
    value_lis = [
        idx2dec[vid] if vid < vocab_size else oov_list[vid - vocab_size]
        for vid in value_ids
    ]

    # Column 1 of the first act-slot row holds the slot id (as read by the
    # original indexing act_slot_pairs[0][0, 1]).
    slot = memory['idx2slot'][act_slot_ids[0, 1].item()]
    value = correct_value(slot, ' '.join(value_lis))
    if value is None:
        return []
    return [value]
def decode_value(model, cnet, class_string, memory, cuda):
    """Decode the slot value for the act-slot pair encoded in ``class_string``.

    The ``cnet`` input is encoded and the value decoder is run with
    ``beam_search``, conditioned on the first act-slot pair. The resulting ids
    are mapped back to words and joined with single spaces.

    Args:
        model: Object exposing ``encoder``, ``enc_to_dec`` and ``value_decoder``.
        cnet: Input passed to ``process_cn_example`` and
            ``ValueDataset.data_info`` (presumably a confusion network --
            confirm).
        class_string: Label string passed to ``ValueDataset.label_info``.
        memory: Vocabulary dict; this function reads ``'enc2idx'``,
            ``'dec2idx'`` and ``'idx2dec'``.
        cuda: Flag forwarded to the dataset helpers and ``beam_search``.

    Returns:
        ``[]`` if ``process_cn_example`` returns ``None``. Otherwise a
        one-element list holding the space-joined decoded value string.
    """
    # process_cn_example is used here only as a validity check; its result is
    # not otherwise consumed.
    if process_cn_example(cnet, memory['enc2idx']) is None:
        return []

    data, lengths, extra_zeros, enc_batch_extend_vocab_idx, oov_list = (
        ValueDataset.data_info(cnet, memory, cuda))
    # Only the act-slot pairs are needed here. The other outputs of
    # label_info are not used during decoding.
    _, act_slot_pairs, _, _ = ValueDataset.label_info(
        class_string, memory, oov_list, cuda)

    # Encoder.
    outputs, hiddens = model.encoder(data, lengths)

    # Value decoder: beam search from the encoder-derived initial state,
    # conditioned on the first act-slot pair.
    s_decoder = model.enc_to_dec(hiddens)
    act_slot_ids = act_slot_pairs[0]
    # [1:-1] drops the first and last ids, presumably BOS/EOS.
    # NOTE(review): if beam_search can stop without emitting EOS, this drops a
    # real token -- confirm against beam_search.
    value_ids = beam_search(
        model.value_decoder, act_slot_ids, extra_zeros,
        enc_batch_extend_vocab_idx, s_decoder, outputs, lengths,
        len(memory['dec2idx']), cuda)[1:-1]

    # Ids inside the decoder vocabulary map through idx2dec. Larger ids index
    # into this example's oov_list (presumably copied out-of-vocabulary words).
    # NOTE(review): beam_search is sized with len(dec2idx) while the split here
    # uses len(idx2dec); these are assumed equal -- confirm.
    idx2dec = memory['idx2dec']
    vocab_size = len(idx2dec)
    value_lis = [
        idx2dec[vid] if vid < vocab_size else oov_list[vid - vocab_size]
        for vid in value_ids
    ]
    return [' '.join(value_lis)]