コード例 #1
0
def _format_stream_output(word_list, read_counts, waitk, detok, detc):
    """Render one hypothesis as the multi-line record written in stream mode.

    The hypothesis is replayed step by step. At step ``j`` the read count
    ``read_counts[j]`` and the target token ``word_list[j]`` are consumed;
    the text emitted so far is re-joined, BPE-merged, detokenized and
    detruecased, and only the newly produced suffix is written out. A step
    with a positive read count ``r`` contributes ``r - 1`` empty lines followed
    by a line holding the new suffix; a step without a read appends its suffix
    to the most recent line.

    Args:
        word_list: Target tokens of the hypothesis. They are decoded with
            UTF-8, so they are expected to be ``bytes``.
        read_counts: Per-step read counts (``real_read[idx].numpy()`` in the
            caller). Only its length and its integer elements are used.
        waitk: When positive, the hypothesis is left-padded with ``waitk - 1``
            empty tokens; otherwise with ``len(read_counts) - 1`` empty tokens.
        detok: Object exposing ``detokenize(list_of_tokens)``.
        detc: Object exposing ``detruecase(text)`` whose result is joined with
            single spaces.

    Returns:
        bytes: The output lines joined with ``b"\\n"`` and terminated by
        ``b" \\n"``.
    """
    if waitk > 0:
        # for wait-k models, wait k words in the beginning
        lead = waitk - 1
    else:
        # for full sentence model, wait until the end
        lead = len(read_counts) - 1
    word_list = [b''] * lead + word_list

    n_read = len(read_counts)
    n_words = len(word_list)

    final_output = []
    real_output = []
    sent = ''

    for j in range(max(n_read, n_words)):
        # append number of reads at step j
        r = read_counts[j] if j < n_read else 0
        if r > 0:
            final_output += [b''] * (r - 1)

        # append number of writes at step j
        w = word_list[j] if j < n_words else b''
        real_output.append(w.decode('utf-8'))

        _sent = ' '.join(real_output)
        if len(_sent) > 0:
            # A sentinel token 'a' is appended before post-processing and its
            # single character is stripped again afterwards.
            # NOTE(review): presumably this keeps the detokenizer/detruecaser
            # from treating the partial sentence as finished - confirm.
            _sent += ' a'
            _sent = ' '.join(_sent.split())
            # merge BPE sub-word units ("@@ " continuation markers)
            _sent = _sent.replace('@@ ', '')
            _sent = detok.detokenize(_sent.split())
            _sent = detc.detruecase(_sent)
            _sent = ' '.join(_sent)
            _sent = _sent[:-1].strip()

        # only the text that is new relative to the previous step is written
        incre = _sent[len(sent):]
        sent = _sent
        chunk = incre.encode()

        if r > 0:
            # if there is read, append a word to write
            final_output.append(chunk)
        else:
            # if there is no read, append word to the final write
            if j >= n_words:
                break
            if final_output:
                final_output[-1] += chunk
            else:
                # No line exists yet (first step without a read): start one
                # instead of failing with IndexError on final_output[-1].
                final_output.append(chunk)

    return b"\n".join(final_output) + b" \n"


def do_predict(args):
    """Run beam-search inference and write the hypotheses to a file.

    Builds the prediction reader and the Transformer under a dygraph guard,
    loads the parameters stored under ``args.init_from_params``, decodes every
    batch from ``args.predict_file`` with ``transformer.beam_search`` and
    writes up to ``args.n_best`` hypotheses per input to ``args.output_file``
    (opened in binary mode and flushed after every hypothesis).

    When ``args.stream`` is set each batch must carry a sixth ``real_read``
    field and every hypothesis is written in the multi-line streaming format;
    otherwise each hypothesis is one line of space-joined tokens.

    Side effects:
        Sets ``args.src_vocab_size``, ``args.trg_vocab_size``,
        ``args.bos_idx``, ``args.eos_idx`` and ``args.unk_idx`` from the
        reader's vocabulary summary.

    Args:
        args: Configuration namespace. Attributes read here: ``use_cuda``,
            ``predict_file``, ``src_vocab_fpath``, ``trg_vocab_fpath``,
            ``token_delimiter``, ``batch_size``, ``pool_size``,
            ``special_token``, ``max_length``, ``n_head``, ``n_layer``,
            ``d_key``, ``d_value``, ``d_model``, ``d_inner_hid``,
            ``prepostprocess_dropout``, ``attention_dropout``,
            ``relu_dropout``, ``preprocess_cmd``, ``postprocess_cmd``,
            ``weight_sharing``, ``init_from_params``, ``output_file``,
            ``stream``, ``beam_size``, ``max_out_len``, ``waitk``, ``n_best``.

    Raises:
        AssertionError: If ``args.init_from_params`` is empty.
    """
    if args.use_cuda:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()

    # define the data generator
    # NOTE(review): a second DataProcessor configuration used to be kept here
    # as a disabled string literal; it differed only by the extra keyword
    # arguments only_src=True, stream=args.stream and
    # src_bpe_dict=args.src_bpe_dict. The streaming branch below unpacks a
    # sixth `real_read` field - confirm that this configuration yields it
    # when args.stream is set.
    processor = reader.DataProcessor(fpattern=args.predict_file,
                                     src_vocab_fpath=args.src_vocab_fpath,
                                     trg_vocab_fpath=args.trg_vocab_fpath,
                                     token_delimiter=args.token_delimiter,
                                     use_token_batch=False,
                                     batch_size=args.batch_size,
                                     device_count=1,
                                     pool_size=args.pool_size,
                                     sort_type=reader.SortType.NONE,
                                     shuffle=False,
                                     shuffle_batch=False,
                                     start_mark=args.special_token[0],
                                     end_mark=args.special_token[1],
                                     unk_mark=args.special_token[2],
                                     max_length=args.max_length,
                                     n_head=args.n_head)
    batch_generator = processor.data_generator(phase="predict", place=place)
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()
    trg_idx2word = reader.DataProcessor.load_dict(
        dict_path=args.trg_vocab_fpath, reverse=True)

    with fluid.dygraph.guard(place):
        # define data loader
        test_loader = fluid.io.DataLoader.from_generator(capacity=10)
        test_loader.set_batch_generator(batch_generator, places=place)

        # define model
        transformer = Transformer(
            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
            args.d_inner_hid, args.prepostprocess_dropout,
            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
            args.eos_idx)

        # load the trained model
        assert args.init_from_params, (
            "Please set init_from_params to load the infer model.")
        model_dict, _ = fluid.load_dygraph(
            os.path.join(args.init_from_params, "transformer"))
        # to avoid a longer length than training, reset the size of position
        # encoding to max_length
        model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
            args.max_length + 1, args.d_model)
        model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
            args.max_length + 1, args.d_model)
        transformer.load_dict(model_dict)

        # set evaluate mode
        transformer.eval()

        # The context manager closes the output file even if decoding fails.
        with open(args.output_file, "wb") as f:
            detok = MosesDetokenizer(lang='en')
            detc = MosesDetruecaser()

            for input_data in test_loader():
                if args.stream:
                    (src_word, src_pos, src_slf_attn_bias, trg_word,
                     trg_src_attn_bias, real_read) = input_data
                else:
                    (src_word, src_pos, src_slf_attn_bias, trg_word,
                     trg_src_attn_bias) = input_data
                    real_read = None

                # the beam scores are not used when writing the output
                finished_seq, _ = transformer.beam_search(
                    src_word,
                    src_pos,
                    src_slf_attn_bias,
                    trg_word,
                    trg_src_attn_bias,
                    bos_id=args.bos_idx,
                    eos_id=args.eos_idx,
                    beam_size=args.beam_size,
                    max_len=args.max_out_len,
                    waitk=args.waitk,
                    stream=args.stream)
                finished_seq = finished_seq.numpy()

                for idx, ins in enumerate(finished_seq):
                    # read counts are only available (and needed) in stream mode
                    read_counts = (real_read[idx].numpy()
                                   if args.stream else None)
                    for beam_idx, beam in enumerate(ins):
                        if beam_idx >= args.n_best:
                            break
                        id_list = post_process_seq(beam, args.bos_idx,
                                                   args.eos_idx)
                        word_list = [
                            trg_idx2word[token_id] for token_id in id_list
                        ]

                        if args.stream:
                            sequence = _format_stream_output(
                                word_list, read_counts, args.waitk, detok,
                                detc)
                        else:
                            sequence = b" ".join(word_list) + b"\n"
                        f.write(sequence)
                        # flush per hypothesis so partial results are visible
                        f.flush()