Exemplo n.º 1
0
from argparse import ArgumentParser
import json

# Command-line interface: only --config is mandatory; --num_candidates defaults to 1.
parser = ArgumentParser(description='Predict translation')
parser.add_argument('--source', type=str)
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--checkpoint', type=str)
parser.add_argument('--num_candidates', type=int, default=1)
args = parser.parse_args()

# The JSON config provides 'data_dir' and 'vocabulary_size' and is also passed to build_model.
with open(args.config) as f:
    config = json.loads(f.read())

print('Constructing dictionaries...')
# Both dictionaries are loaded with identical settings; only the mode differs
# (source first, then target).
source_dictionary, target_dictionary = (
    IndexDictionary.load(config['data_dir'], mode=mode, vocabulary_size=config['vocabulary_size'])
    for mode in ('source', 'target')
)

print('Building model...')
model = build_model(
    config,
    source_dictionary.vocabulary_size,
    target_dictionary.vocabulary_size,
)

predictor = Predictor(
    preprocess=IndexedInputTargetTranslationDataset.preprocess(source_dictionary),
    # Join the decoded tokens with single spaces, leaving out the '<EndSent>' marker.
    postprocess=lambda x: ' '.join(
        token for token in target_dictionary.tokenify_indexes(x) if token != '<EndSent>'
    ),
    model=model,
    checkpoint_filepath=args.checkpoint,
)

# Print every candidate returned for the --source text, numbered from 0.
candidates = predictor.predict_one(args.source, num_candidates=args.num_candidates)
for index, candidate in enumerate(candidates):
    print(f'Candidate {index} : {candidate}')
Exemplo n.º 2
0
def predict(dn, rn):
    """Translate the test sources of one dataset and write the predictions to a file.

    Reads ``../data/{dn}-{rn}-raw/src-test.txt`` line by line, translates the
    lines in batches of ``args.bs`` and writes one prediction per line to
    ``prediction-{dn}-{rn}.txt`` inside ``outputDir``.  When the input file is
    missing or the output file already exists, a message is printed and
    nothing else happens.

    Besides its arguments the function reads these module-level names:
    ``outputDir``, ``source_dictionary``, ``target_dictionary``, ``args``
    (``device``, ``no_cuda``, ``beam_size``, ``max_seq_len``, ``bs``),
    ``config``, ``checkpoint_path``, ``data_name``, ``START_TOKEN``,
    ``END_TOKEN`` and ``PAD_TOKEN``; it reads and may clear the module-level
    flag ``is_proof_process``.

    Args:
        dn: First component of the input directory / output file name; also
            checked for the substrings "balance" and "one".
        rn: Second component of the input directory / output file name.

    Returns:
        None.
    """
    input_path = os.path.join(f"../data/{dn}-{rn}-raw", "src-test.txt")
    if not os.path.isfile(input_path):
        print(f"File: {input_path} not exist.")
        return

    output_path = os.path.join(outputDir, f"prediction-{dn}-{rn}.txt")
    if os.path.isfile(output_path):
        print(f"File {output_path} already exists.")
        return

    # Purpose: convert a source line into its index sequence.
    preprocess = IndexedInputTargetTranslationDataset.preprocess(source_dictionary)

    def postprocess(indexes):
        """Turn predicted indexes back into a sentence without special tokens."""
        special_tokens = (END_TOKEN, START_TOKEN, PAD_TOKEN)
        return ''.join(
            token for token in target_dictionary.tokenize_indexes(indexes)
            if token not in special_tokens
        )

    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() and not args.no_cuda else 'cpu')

    print('Building model...')
    model = TransformerModel(source_dictionary.vocabulary_size, target_dictionary.vocabulary_size,
                             config['d_model'],
                             config['nhead'],
                             config['nhid'],
                             config['nlayers'])
    model.eval()
    # Load the weights on the CPU first; the translator is moved to `device` below.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint)
    translator = Translator(
        model=model,
        beam_size=args.beam_size,
        max_seq_len=args.max_seq_len,
        trg_bos_idx=target_dictionary.token_to_index(START_TOKEN),
        trg_eos_idx=target_dictionary.token_to_index(END_TOKEN)
    ).to(device)

    from utils.pipe import PAD_INDEX

    def pad_src(batch):
        """Right-pad every index list with PAD_INDEX to the longest one and stack them."""
        longest = max(len(sources) for sources in batch)
        return torch.tensor(
            [sources + [PAD_INDEX] * (longest - len(sources)) for sources in batch]
        )

    def is_proof(name):
        """Return True when *name* contains "balance" or "one"."""
        return "balance" in name or "one" in name

    def process(seq):
        """Strip a raw source line and, when required, append the ",$,1" suffix."""
        global is_proof_process
        seq = seq.strip()
        # The suffix is added only when `data_name` matches is_proof but `dn` does not.
        if is_proof(data_name) and not is_proof(dn):
            seq += ",$,1"
            if is_proof_process:
                # Announce the suffixing once; the module-level flag is cleared afterwards.
                print("processing")
                is_proof_process = False
        return seq

    def translate_batch(batch):
        """Translate a non-empty batch and return newline-terminated output lines."""
        # Presumably translate_sentence yields one index sequence per row - TODO confirm.
        predictions = translator.translate_sentence(pad_src(batch).to(device))
        return [postprocess(prediction).strip() + '\n' for prediction in predictions]

    batch_size = args.bs
    print(f"Output to {output_path}:")
    with open(output_path, 'w', encoding='utf-8') as out_file, \
            open(input_path, 'r', encoding='utf-8') as in_file:
        batch = []
        for line in tqdm(in_file):
            batch.append(preprocess(process(line)))
            if len(batch) >= batch_size:
                out_file.writelines(translate_batch(batch))
                batch.clear()
        # The trailing partial batch goes through exactly the same path as the
        # full ones, so every output line is post-processed identically.
        if batch:
            out_file.writelines(translate_batch(batch))
    print(f'[Info] {input_path} Finished.')