def generateResults(encoder_decoder: EncoderDecoder, data_loader, resultFilename, input_tokens_list):
    """Decode every batch from ``data_loader`` and write one prediction per line.

    Rows of each batch are sorted by descending non-padding length before being
    fed to ``encoder_decoder`` together with those lengths. Predictions are then
    put back into the order the loader produced them, so line ``i`` of the
    output file corresponds to the ``i``-th example yielded by ``data_loader``.

    Args:
        encoder_decoder: model called as ``encoder_decoder(inputs, lengths)``
            returning ``(log_probs, output_seqs)``; ``lang.idx_to_tok`` maps
            indices back to tokens.
        data_loader: iterable of 4-tuples whose first element is a 2-D tensor
            of input indices (index 0 is treated as padding); the other three
            elements are ignored here.
        resultFilename: path of the text file to (over)write, UTF-8 encoded.
        input_tokens_list: per-example input tokens. They are only zipped with
            the predictions, so this bounds how many lines are written.

    Returns:
        None.
    """
    idx_to_tok = encoder_decoder.lang.idx_to_tok
    all_output_seqs = []
    for input_idxs, _, _, _ in tqdm(data_loader):
        # Number of non-padding (non-zero) indices in each row.
        input_lengths = (input_idxs != 0).long().sum(dim=1)
        sorted_lengths, order = torch.sort(input_lengths, descending=True)
        # Reorder rows longest-first and drop columns past the longest input.
        input_variable = Variable(input_idxs[order, :][:, :int(sorted_lengths[0])])
        _, output_seqs = encoder_decoder(input_variable, list(sorted_lengths))
        # assumes trim_seqs yields one sequence per row, in row order — TODO confirm
        trimmed = list(trim_seqs(output_seqs))
        # Undo the length sort: sorted row i came from original row order[i].
        restored = [None] * len(trimmed)
        for sorted_pos, original_pos in enumerate(order.tolist()):
            restored[original_pos] = trimmed[sorted_pos]
        all_output_seqs.extend(restored)

    with open(resultFilename, 'w', encoding='utf-8') as fo:
        for seq, _input_tokens in zip(all_output_seqs, input_tokens_list):
            # Cut after the first index 2 (presumably <EOS>); keep everything
            # when it never appears.
            eos_idx = seq.index(2) if 2 in seq else len(seq)
            # NOTE(review): the other seq_to_string call sites in this module
            # forward input_tokens; here None is passed — confirm intentional.
            string = seq_to_string(seq[:eos_idx + 1], idx_to_tok, input_tokens=None)
            fo.write(string + '\n')
    return None
def get_response(self, input_string):
    """Tokenize ``input_string``, run the model on it, and return the decoded reply.

    The text is whitespace-normalized, tokenized with an ``English()`` parser
    that is created on first use and cached as ``self.parser_``, lower-cased,
    and wrapped in ``<SOS>``/``<EOS>`` markers. The extended-vocabulary encoding
    is requested when ``self.decoder`` is a ``CopyNetDecoder``. The predicted
    index sequence is cut after the first index 2 (presumably the ``<EOS>``
    index — confirm against ``self.lang``).

    Args:
        input_string: raw text to respond to.

    Returns:
        The string ``seq_to_string`` builds from the truncated prediction.
    """
    extended_vocab = isinstance(self.decoder, CopyNetDecoder)
    if not hasattr(self, 'parser_'):
        # Build the tokenizer once and reuse it on later calls.
        self.parser_ = English()
    vocab = self.lang
    to_token = vocab.idx_to_tok
    to_index = vocab.tok_to_idx

    cleaned = ' '.join(input_string.split())
    words = [tok.orth_.lower() for tok in self.parser_(cleaned)]
    tokens = ['<SOS>'] + words + ['<EOS>']

    encoded = tokens_to_seq(tokens, to_index, len(tokens), extended_vocab)
    batch = Variable(encoded).view(1, -1)
    # Put the input on the GPU when the model's first parameter lives there.
    if next(self.parameters()).is_cuda:
        batch = batch.cuda()

    _, predicted = self.forward(batch, [len(encoded)])
    flat = predicted.data.view(-1)
    flat_list = list(flat)
    stop = flat_list.index(2) if 2 in flat_list else len(flat)
    return seq_to_string(flat[:stop + 1], to_token, input_tokens=tokens)
def singleOutput(input_seq, encoder_decoder: EncoderDecoder, input_tokens=None):
    """Decode a single input sequence and return the prediction as text.

    Args:
        input_seq: tensor of token indices for one example. It is flattened
            into a single row, and index 0 is treated as padding.
        encoder_decoder: model called as ``encoder_decoder(inputs, lengths)``
            returning ``(outputs, idxs)``; ``lang.idx_to_tok`` maps indices
            back to tokens.
        input_tokens: optional source tokens forwarded to ``seq_to_string``.

    Returns:
        The decoded string up to and including the first index 2 (presumably
        ``<EOS>``), with surrounding whitespace stripped.
    """
    idx_to_tok = encoder_decoder.lang.idx_to_tok
    # One-row batch: its length is the total count of non-padding indices.
    # A list of Python ints behaves the same on 0-d (torch >= 0.4) and
    # 1-element (older torch) reduction results, unlike sum(dim=0).unsqueeze(0).
    lengths = [int((input_seq != 0).long().sum())]
    input_variable = Variable(input_seq).view(1, -1)
    _, idxs = encoder_decoder(input_variable, lengths)
    idxs = idxs.data.view(-1)
    idx_list = list(idxs)
    eos_idx = idx_list.index(2) if 2 in idx_list else len(idxs)
    string = seq_to_string(idxs[:eos_idx + 1], idx_to_tok, input_tokens=input_tokens)
    return string.strip()
def print_output(input_seq, encoder_decoder: EncoderDecoder, input_tokens=None, target_tokens=None, target_seq=None):
    """Decode one input, print it next to its prediction, and return the prediction.

    When ``target_seq`` is given, the target is also printed, and the summed
    per-step scores of both the target (scored with teacher forcing 1.0) and
    the model's own prediction are printed as "log prob" values.

    Args:
        input_seq: tensor of token indices for one example. It is flattened
            into a single row, and index 0 is treated as padding.
        encoder_decoder: model called as ``encoder_decoder(inputs, lengths, ...)``
            returning ``(outputs, idxs)``; ``lang.idx_to_tok`` maps indices
            back to tokens.
        input_tokens: optional source tokens. Printed directly when given, and
            forwarded to ``seq_to_string``.
        target_tokens: optional reference tokens, printed in preference to
            rendering ``target_seq``.
        target_seq: optional tensor of reference indices. When None, no target
            is printed or scored.

    Returns:
        The flattened tensor of predicted indices (not truncated at EOS).
    """
    idx_to_tok = encoder_decoder.lang.idx_to_tok
    if input_tokens is not None:
        input_string = ' '.join(input_tokens)
    else:
        input_string = seq_to_string(input_seq, idx_to_tok)

    # One-row batch: its length is the total count of non-padding indices.
    # int(...) works whether the reduction yields a 0-d tensor (torch >= 0.4,
    # where list() over it raises) or a 1-element tensor/number (older torch).
    lengths = [int((input_seq != 0).long().sum())]
    input_variable = Variable(input_seq).view(1, -1)

    if target_tokens is not None:
        target_string = ' '.join(target_tokens)
    elif target_seq is not None:
        target_string = seq_to_string(target_seq, idx_to_tok, input_tokens=input_tokens)
    else:
        target_string = ''

    if target_seq is not None:
        # Only wrap the target when one was supplied; Variable(None) would fail.
        target_variable = Variable(target_seq).view(1, -1)
        target_list = list(target_seq)
        target_eos_idx = target_list.index(2) if 2 in target_list else len(target_seq)
        target_outputs, _ = encoder_decoder(input_variable, lengths,
                                            targets=target_variable,
                                            teacher_forcing=1.0)
        # Sum the model's score for each reference index up to and including EOS.
        target_log_prob = sum(
            target_outputs[0, step_idx, target_idx]
            for step_idx, target_idx in enumerate(target_seq[:target_eos_idx + 1])
        )

    outputs, idxs = encoder_decoder(input_variable, lengths)
    idxs = idxs.data.view(-1)
    idx_list = list(idxs)
    eos_idx = idx_list.index(2) if 2 in idx_list else len(idxs)
    string = seq_to_string(idxs[:eos_idx + 1], idx_to_tok, input_tokens=input_tokens)
    # Score of the model's own prediction, summed the same way as the target.
    log_prob = sum(
        outputs[0, step_idx, idx]
        for step_idx, idx in enumerate(idxs[:eos_idx + 1])
    )

    print('>', input_string, '\n', flush=True)
    if target_seq is not None:
        print('=', target_string, flush=True)
    print('<', string, flush=True)
    print('\n')
    if target_seq is not None:
        print('target log prob:', float(target_log_prob))
    print('output log prob:', float(log_prob))
    print('-' * 100, '\n')
    return idxs