示例#1
0
def generateResults(encoder_decoder: EncoderDecoder, data_loader,
                    resultFilename, input_tokens_list):
    """Decode every batch of ``data_loader`` and write one prediction per line.

    Each batch is sorted by decreasing input length before being fed to the
    model (the model receives ``list(sorted_lengths)``). The predictions are
    then put back into the order the data loader produced them, so that
    line *i* of the output file corresponds to example *i*.

    Args:
        encoder_decoder: model called as ``encoder_decoder(inputs, lengths)``;
            the second item it returns holds the predicted index sequences.
            ``encoder_decoder.lang.idx_to_tok`` maps indices back to tokens.
        data_loader: iterable of 4-tuples whose first item is a padded input
            index tensor (padding is index 0, per the ``!= 0`` mask below).
        resultFilename: path of the text file to (over)write, UTF-8 encoded.
        input_tokens_list: per-example input tokens. Only used to bound how
            many predictions are written (``zip`` stops at the shorter one).

    Returns:
        None.
    """
    idx_to_tok = encoder_decoder.lang.idx_to_tok
    all_output_seqs = []

    for input_idxs, _, _, _ in tqdm(data_loader):
        input_lengths = (input_idxs != 0).long().sum(dim=1)

        # The model is fed the batch sorted by decreasing length.
        sorted_lengths, order = torch.sort(input_lengths, descending=True)
        input_variable = Variable(input_idxs[order, :][:, :max(input_lengths)])

        _, output_seqs = encoder_decoder(input_variable, list(sorted_lengths))
        sorted_outputs = trim_seqs(output_seqs)

        # Row ``sorted_pos`` of the model output belongs to original row
        # ``order[sorted_pos]``; undo the sort so results keep the loader's
        # order (and therefore line up with input_tokens_list).
        original_order = order.tolist()
        batch_outputs = [None] * len(original_order)
        for sorted_pos, original_pos in enumerate(original_order):
            batch_outputs[original_pos] = sorted_outputs[sorted_pos]
        all_output_seqs.extend(batch_outputs)

    with open(resultFilename, 'w', encoding='utf-8') as fo:
        for seq, _input_tokens in zip(all_output_seqs, input_tokens_list):
            # Index 2 is used as the end-of-sequence marker, as in the other
            # decoding helpers in this file; keep it and drop what follows.
            # With no marker, keep the whole sequence.
            eos_idx = seq.index(2) if 2 in seq else len(seq)
            # NOTE(review): unlike the other helpers here, input tokens are not
            # passed to seq_to_string -- confirm this is intended for copied
            # (extended-vocab) indices.
            string = seq_to_string(seq[:eos_idx + 1],
                                   idx_to_tok,
                                   input_tokens=None)
            fo.write(string + '\n')

    return None
    def get_response(self, input_string):
        """Return the model's decoded reply to a raw text string.

        The text is whitespace-normalised and tokenised with a lazily built
        ``English()`` parser (cached as ``self.parser_``). Tokens are
        lower-cased and wrapped in ``<SOS>``/``<EOS>``, then encoded with
        ``tokens_to_seq`` (extended vocabulary when the decoder is a
        ``CopyNetDecoder``). The result is run through ``self.forward`` as a
        batch of one. The predicted indices are turned back into a string,
        keeping everything up to and including the first index 2 (or all of
        them if 2 never appears).
        """
        copy_vocab = isinstance(self.decoder, CopyNetDecoder)

        # Create the tokenizer on first use only, then reuse it.
        if not hasattr(self, 'parser_'):
            self.parser_ = English()

        to_token = self.lang.idx_to_tok
        to_index = self.lang.tok_to_idx

        normalised = ' '.join(input_string.split())
        words = [tok.orth_.lower() for tok in self.parser_(normalised)]
        input_tokens = ['<SOS>'] + words + ['<EOS>']

        encoded = tokens_to_seq(input_tokens, to_index, len(input_tokens),
                                copy_vocab)
        batch = Variable(encoded).view(1, -1)

        # Follow the model onto the GPU when its parameters live there.
        if next(self.parameters()).is_cuda:
            batch = batch.cuda()

        _, predicted = self.forward(batch, [len(encoded)])
        predicted = predicted.data.view(-1)
        predicted_list = list(predicted)
        if 2 in predicted_list:
            stop = predicted_list.index(2)
        else:
            stop = len(predicted)

        return seq_to_string(predicted[:stop + 1],
                             to_token,
                             input_tokens=input_tokens)
示例#3
0
def singleOutput(input_seq,
                 encoder_decoder: EncoderDecoder,
                 input_tokens=None):
    """Decode a single input sequence and return the prediction as text.

    Args:
        input_seq: index tensor for one example; zeros are treated as padding
            when computing its length.
        encoder_decoder: model called as ``encoder_decoder(inputs, lengths)``.
            Its second return value holds the predicted indices, and
            ``encoder_decoder.lang.idx_to_tok`` maps them back to tokens.
        input_tokens: optional source tokens forwarded to ``seq_to_string``.

    Returns:
        The decoded string, up to and including the first index 2 (or the
        full prediction if 2 never appears), with surrounding whitespace
        stripped.
    """
    vocab_lookup = encoder_decoder.lang.idx_to_tok

    # Count non-zero entries along dim 0, then add a leading dimension.
    non_pad_mask = (input_seq != 0).long()
    lengths = non_pad_mask.sum(dim=0).unsqueeze(0)
    batch = Variable(input_seq).view(1, -1)

    _, predicted = encoder_decoder(batch, lengths)
    flat = predicted.data.view(-1)
    as_list = list(flat)
    if 2 in as_list:
        cut = as_list.index(2)
    else:
        cut = len(flat)

    decoded = seq_to_string(flat[:cut + 1],
                            vocab_lookup,
                            input_tokens=input_tokens)
    return decoded.strip()
示例#4
0
def print_output(input_seq,
                 encoder_decoder: EncoderDecoder,
                 input_tokens=None,
                 target_tokens=None,
                 target_seq=None):
    """Print an input, its optional target, and the model's prediction.

    Prints the input string, the target string (only when ``target_seq`` is
    given), the decoded prediction, and the summed per-step scores labelled
    as log probabilities. When ``target_seq`` is given, its score is taken
    from a teacher-forced pass over the target.

    Args:
        input_seq: index tensor for one example; zeros are treated as padding
            when computing its length.
        encoder_decoder: model called as ``encoder_decoder(inputs, lengths,
            ...)``; ``encoder_decoder.lang.idx_to_tok`` maps indices to tokens.
        input_tokens: optional source tokens, used for display and forwarded
            to ``seq_to_string``.
        target_tokens: optional target tokens used for display.
        target_seq: optional target index tensor. Omitting it skips the target
            line and the target score.

    Returns:
        The flattened predicted index tensor (full decoder output, not cut at
        index 2).
    """
    idx_to_tok = encoder_decoder.lang.idx_to_tok
    if input_tokens is not None:
        input_string = ' '.join(input_tokens)
    else:
        input_string = seq_to_string(input_seq, idx_to_tok)

    lengths = list((input_seq != 0).long().sum(dim=0))
    input_variable = Variable(input_seq).view(1, -1)

    if target_tokens is not None:
        target_string = ' '.join(target_tokens)
    elif target_seq is not None:
        target_string = seq_to_string(target_seq,
                                      idx_to_tok,
                                      input_tokens=input_tokens)
    else:
        target_string = ''

    if target_seq is not None:
        # Build the target batch only when a target exists: wrapping None in
        # Variable(...).view(...) fails, which broke calls without target_seq.
        target_variable = Variable(target_seq).view(1, -1)
        target_list = list(target_seq)
        target_eos_idx = (target_list.index(2) if 2 in target_list
                          else len(target_seq))
        # teacher_forcing=1.0 feeds the gold target, so target_outputs scores
        # the reference sequence step by step.
        target_outputs, _ = encoder_decoder(input_variable,
                                            lengths,
                                            targets=target_variable,
                                            teacher_forcing=1.0)
        target_log_prob = sum(
            target_outputs[0, step_idx, target_idx]
            for step_idx, target_idx in enumerate(
                target_seq[:target_eos_idx + 1]))

    outputs, idxs = encoder_decoder(input_variable, lengths)
    idxs = idxs.data.view(-1)
    idx_list = list(idxs)
    eos_idx = idx_list.index(2) if 2 in idx_list else len(idxs)
    string = seq_to_string(idxs[:eos_idx + 1],
                           idx_to_tok,
                           input_tokens=input_tokens)
    # Score of the model's own prediction, up to and including index 2.
    log_prob = sum(
        outputs[0, step_idx, idx]
        for step_idx, idx in enumerate(idxs[:eos_idx + 1]))

    print('>', input_string, '\n', flush=True)

    if target_seq is not None:
        print('=', target_string, flush=True)
    print('<', string, flush=True)

    print('\n')

    if target_seq is not None:
        print('target log prob:', float(target_log_prob))
    print('output log prob:', float(log_prob))

    print('-' * 100, '\n')

    return idxs