def __getitem__(self, idx):
    """
    :arg idx: int

    :returns input_seq: torch.LongTensor of input token indices
        output_seq: torch.LongTensor of output token indices (out-of-vocabulary
            tokens that also appear in the input get extended-vocab indices)
        input_string: str, the space-joined input tokens
        output_string: str, the space-joined output tokens
    """
    # A line containing the ' >>><<< ' separator holds a distinct input/output
    # pair; otherwise the line serves as both input and output (autoencoding).
    if ' >>><<< ' in self.data[idx]:
        parts = self.data[idx].split(' >>><<< ')
        input_token_list = parts[0].split()
        output_token_list = parts[1].split()
    else:
        input_token_list = self.data[idx].split()
        output_token_list = self.data[idx].split()

    # Add the start- and end-of-sequence markers and truncate to maxlen.
    input_token_list = (['<SOS>'] + input_token_list + ['<EOS>'])[:self.maxlen]
    output_token_list = (['<SOS>'] + output_token_list + ['<EOS>'])[:self.maxlen]

    # Map tokens to index sequences. Passing input_tokens while encoding the
    # output lets the extended (copy) vocabulary refer back to source positions.
    input_seq = tokens_to_seq(input_token_list, self.lang.tok_to_idx,
                              self.maxlen, self.use_extended_vocab)
    output_seq = tokens_to_seq(output_token_list, self.lang.tok_to_idx,
                               self.maxlen, self.use_extended_vocab,
                               input_tokens=input_token_list)

    if self.use_cuda:
        input_seq = input_seq.cuda()
        output_seq = output_seq.cuda()

    return input_seq, output_seq, ' '.join(input_token_list), ' '.join(output_token_list)
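# The token-to-index conversion above is done by tokens_to_seq, which is
# defined elsewhere in the repo. Below is a rough, illustrative sketch of the
# behavior __getitem__ relies on; the vocab layout (<PAD> at index 0, an
# '<UNK>' entry in tok_to_idx) and the OOV handling are assumptions, not the
# repo's actual implementation.
import torch

def tokens_to_seq_sketch(tokens, tok_to_idx, maxlen, use_extended_vocab, input_tokens=None):
    seq = torch.zeros(maxlen).long()  # assumed: index 0 is <PAD>
    for i, token in enumerate(tokens):
        idx = tok_to_idx.get(token, tok_to_idx['<UNK>'])
        if (idx == tok_to_idx['<UNK>'] and use_extended_vocab
                and input_tokens is not None and token in input_tokens):
            # Copy mechanism: an OOV output token that also occurs in the
            # input gets a temporary index just past the base vocabulary,
            # so the decoder can point at the source instead of emitting <UNK>.
            idx = len(tok_to_idx) + input_tokens.index(token)
        seq[i] = idx
    return seq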
def __getitem__(self, idx):
    """
    :arg idx: int

    :returns input_seq: torch.LongTensor of input token indices
        output_seq: torch.LongTensor of output token indices
        input_string: str, the space-joined input tokens
        output_string: str, the space-joined output tokens
    """
    data_pair = self.data[idx]

    # Add the start- and end-of-sentence markers and chop at the max length.
    input_token_list = (['<SOS>'] + data_pair[0] + ['<EOS>'])[:self.maxlen]
    output_token_list = (['<SOS>'] + data_pair[1] + ['<EOS>'])[:self.maxlen]

    # Turn the tokens into index sequences.
    input_seq = tokens_to_seq(input_token_list, self.lang.tok_to_idx,
                              self.maxlen, self.use_extended_vocab)
    output_seq = tokens_to_seq(output_token_list, self.lang.tok_to_idx,
                               self.maxlen, self.use_extended_vocab,
                               input_tokens=input_token_list)

    return input_seq, output_seq, ' '.join(input_token_list), ' '.join(output_token_list)
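# A minimal usage sketch for this __getitem__. The dataset class name and
# constructor arguments are hypothetical, not the repo's actual API; the
# batch size is arbitrary.
from torch.utils.data import DataLoader

dataset = SequencePairDataset(...)  # hypothetical constructor
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for input_seqs, output_seqs, input_strings, output_strings in loader:
    # input_seqs / output_seqs: LongTensors of shape (batch, maxlen);
    # the joined token strings ride along for logging and copy decoding.
    pass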
def get_response(self, input_string):
    """Tokenize a raw input string, run it through the model, and decode
    the predicted indices back into a string."""
    use_extended_vocab = isinstance(self.decoder, CopyNetDecoder)

    # Lazily build the spaCy tokenizer on first use.
    if not hasattr(self, 'parser_'):
        self.parser_ = English()

    idx_to_tok = self.lang.idx_to_tok
    tok_to_idx = self.lang.tok_to_idx

    # Normalize whitespace, tokenize, lowercase, and add <SOS>/<EOS> markers.
    input_tokens = self.parser_(' '.join(input_string.split()))
    input_tokens = ['<SOS>'] + [token.orth_.lower() for token in input_tokens] + ['<EOS>']
    input_seq = tokens_to_seq(input_tokens, tok_to_idx, len(input_tokens), use_extended_vocab)

    # Variable is the legacy (pre-0.4) autograd wrapper; shape is (1, seq_len).
    input_variable = Variable(input_seq).view(1, -1)
    if next(self.parameters()).is_cuda:
        input_variable = input_variable.cuda()

    outputs, idxs = self.forward(input_variable, [len(input_seq)])
    idxs = idxs.data.view(-1)

    # Cut the decoded sequence at the first <EOS> (assumed to be index 2).
    eos_idx = list(idxs).index(2) if 2 in list(idxs) else len(idxs)
    output_string = seq_to_string(idxs[:eos_idx + 1], idx_to_tok, input_tokens=input_tokens)

    return output_string
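# A short usage sketch: `model` is assumed to be a trained encoder-decoder
# module exposing get_response; the input sentence is arbitrary.
model.eval()  # disable dropout etc. for inference
reply = model.get_response("where is the nearest train station")
print(reply)  # tokens up to and including the first <EOS>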