Example no. 1
import os
import subprocess
import tempfile


def get_XLM_clozes_for_questions(clozes):
    # Translate question-formatted text into cloze form with a pretrained
    # XLM translation model (nq.en -> cloze.en): dump the raw sentences,
    # tokenize, apply BPE, translate, then restore the original segmentation.
    with tempfile.TemporaryDirectory() as tempdir:
        raw_cloze_file = os.path.join(tempdir, 'dev.nq.en')
        _dump_questions_for_translation(clozes, raw_cloze_file, wh_heuristic=True)

        tok_cloze_file = os.path.join(tempdir, 'dev.nq.tok')
        cmd = f'cat {raw_cloze_file} | {SRC_PREPROCESSING} > {tok_cloze_file}'
        subprocess.check_call(cmd, shell=True)

        bpe_cloze_file = os.path.join(tempdir, 'dev.nq.tok.bpe')
        cmd = f'{XLM_FASTBPE} applybpe {bpe_cloze_file} {tok_cloze_file} {BPE_CODES} {SRC_VOCAB}'
        subprocess.check_call(cmd, shell=True)

        # clear any stale output from a previous translation run
        temp_exp_dir = os.path.join(XLM_DATA_DIR, 'data/translate_nq_cloze_temp')
        if os.path.exists(temp_exp_dir):
            subprocess.check_call(f'rm -r {temp_exp_dir}', shell=True)
        translation_output_path = os.path.join(tempdir, 'dev.cloze.en')

        cmd = f'cat {bpe_cloze_file} | python {TRANSLATOR} --exp_name translate_nq_cloze_temp ' \
            f'--src_lang nq.en --tgt_lang cloze.en --model_path {XLM_MODEL} --output_path {translation_output_path}'
        subprocess.check_call(cmd, shell=True)
        restore_segmentation(translation_output_path)

        clozes_with_questions = _associate_clozes_to_clozes(clozes, translation_output_path, wh_heuristic=False)

    return clozes_with_questions
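For context, restore_segmentation (from XLM's utils) undoes fastBPE's "@@ " continuation markers in place; XLM implements it with a sed one-liner. A pure-Python sketch of the same behavior, assuming the standard "@@ "-suffix convention:

import os
import re

def restore_segmentation_sketch(path):
    # Undo BPE subword splits in place: "unsuper@@ vised" -> "unsupervised".
    assert os.path.isfile(path)
    with open(path, encoding='utf-8') as f:
        text = f.read()
    text = re.sub(r'(@@ )|(@@ ?$)', '', text, flags=re.MULTILINE)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)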
Example no. 2
def perform_translation(input_file_path, translation_directory,
                        cloze_train_path, question_train_path,
                        fasttext_vectors_path, checkpoint_path):
    params = get_params(
        exp_name='translation',
        dump_path=translation_directory,
        cloze_train_path=cloze_train_path,
        question_train_path=question_train_path,
        cloze_test_path=input_file_path,
        fasttext_vectors_path=fasttext_vectors_path,
        checkpoint_path=checkpoint_path,
    )

    # check parameters
    assert params.exp_name
    check_all_data_params(params)
    check_mt_model_params(params)
    data = load_data(params, mono_only=True)
    encoder, decoder, discriminator, lm = build_mt_model(params, data)
    # initialize trainer / reload checkpoint / initialize evaluator
    trainer = TrainerMT(encoder, decoder, discriminator, lm, data, params)
    trainer.reload_checkpoint()
    trainer.test_sharing()  # check parameters sharing
    evaluator = EvaluatorMT(trainer, data, params)

    with torch.no_grad():
        lang1, lang2 = 'cloze', 'question'

        evaluator.encoder.eval()
        evaluator.decoder.eval()
        lang1_id = evaluator.params.lang2id[lang1]
        lang2_id = evaluator.params.lang2id[lang2]

        translations = []
        dataset = evaluator.data['mono'][lang1]['test']
        dataset.batch_size = params.batch_size

        for i, (sent1, len1) in enumerate(
                dataset.get_iterator(shuffle=False, group_by_size=False)()):
            encoded = evaluator.encoder(sent1.cuda(), len1, lang1_id)
            sent2_, len2_, _ = evaluator.decoder.generate(encoded, lang2_id)
            lang1_text = convert_to_text(sent1, len1, evaluator.dico[lang1],
                                         lang1_id, evaluator.params)
            lang2_text = convert_to_text(sent2_, len2_, evaluator.dico[lang2],
                                         lang2_id, evaluator.params)
            translations += zip(lang1_text, lang2_text)

        # export sentences to hypothesis file and restore BPE segmentation
        out_name = os.path.join(translation_directory,
                                'output_translations.txt')
        with open(out_name, 'w', encoding='utf-8') as f:
            f.write('\n'.join(['\t'.join(st) for st in translations]) + '\n')
        restore_segmentation(out_name)

    return out_name
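Each line of output_translations.txt holds a source/translation pair joined by a tab (see the '\t'.join above), so a caller can recover the pairs with:

with open('output_translations.txt', encoding='utf-8') as f:
    pairs = [line.rstrip('\n').split('\t', 1) for line in f]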
Example no. 3
def create_reference_files(params, data):
    """
    Create reference files for BLEU evaluation.
    """
    params.ref_paths = {}

    for (lang1, lang2), v in data['para'].items():

        assert lang1 < lang2

        for data_set in ['valid', 'test']:

            # define data paths
            lang1_path = os.path.join(
                params.hyp_path,
                'ref.{0}-{1}.{2}.txt'.format(lang2, lang1, data_set))
            lang2_path = os.path.join(
                params.hyp_path,
                'ref.{0}-{1}.{2}.txt'.format(lang1, lang2, data_set))

            # store data paths
            params.ref_paths[(lang2, lang1, data_set)] = lang1_path
            params.ref_paths[(lang1, lang2, data_set)] = lang2_path

            # text sentences
            lang1_txt = []
            lang2_txt = []

            # convert to text
            for (sent1, len1), (sent2, len2) in get_iterator(
                    params, data, data_set, lang1, lang2):
                lang1_txt.extend(
                    convert_to_text(sent1, len1, data['dico'], params))
                lang2_txt.extend(
                    convert_to_text(sent2, len2, data['dico'], params))

            # replace <unk> by <<unk>> as these tokens cannot be counted in BLEU
            lang1_txt = [x.replace('<unk>', '<<unk>>') for x in lang1_txt]
            lang2_txt = [x.replace('<unk>', '<<unk>>') for x in lang2_txt]

            # export references
            with open(lang1_path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(lang1_txt) + '\n')
            with open(lang2_path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(lang2_txt) + '\n')

            # restore original segmentation
            restore_segmentation(lang1_path)
            restore_segmentation(lang2_path)
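The <unk> -> <<unk>> rewrite above keeps BLEU honest: a hypothesis that emits <unk> should not be credited for "matching" an unknown token in the reference, so the reference-side tokens are escaped to a string the model never produces. A tiny illustration:

ref = 'the <unk> jumped over the fence'
assert ref.replace('<unk>', '<<unk>>') == 'the <<unk>> jumped over the fence'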
Example no. 4
    def end_eval(self, text_z_prime, references, hypothesis, hypothesis2):
        # dump the generated text to a fresh output<i>.txt, undo BPE,
        # then score both hypothesis lists against the references with BLEU
        output_file, i = "output", 1
        while os.path.isfile(
                os.path.join(self.params.dump_path,
                             output_file + str(i) + '.txt')):
            i += 1
        output_file = os.path.join(self.params.dump_path,
                                   output_file + str(i) + '.txt')
        write_text_z_in_file(output_file, text_z_prime)
        restore_segmentation(output_file)

        # compute BLEU
        if references:
            use_list_bleu = False  # list-based BLEU kept for reference; file-based Moses BLEU is used below
            if use_list_bleu:
                bleu = multi_list_bleu(references, hypothesis)
                self.bleu = sum(bleu) / len(bleu)
                self.logger.info("average BLEU %s %s : %f" %
                                 ("input", "gen", self.bleu))
            else:
                # hypothesis / reference paths
                hyp_name, ref_name, i = "hyp", "ref", 1
                while os.path.isfile(
                        os.path.join(self.params.dump_path,
                                     hyp_name + str(i) + '.txt')):
                    i += 1
                ref_path = os.path.join(self.params.dump_path,
                                        ref_name + str(i) + '.txt')
                hyp_path = os.path.join(self.params.dump_path,
                                        hyp_name + str(i) + '.txt')
                hyp_path2 = os.path.join(self.params.dump_path,
                                         hyp_name + '_deb' + str(i) + '.txt')

                # export sentences to reference and hypothesis file
                with open(ref_path, 'w', encoding='utf-8') as f:
                    f.write('\n'.join(references) + '\n')
                with open(hyp_path, 'w', encoding='utf-8') as f:
                    f.write('\n'.join(hypothesis) + '\n')
                with open(hyp_path2, 'w', encoding='utf-8') as f:
                    f.write('\n'.join(hypothesis2) + '\n')

                restore_segmentation(ref_path)
                restore_segmentation(hyp_path)
                restore_segmentation(hyp_path2)

                # evaluate BLEU score
                self.bleu = eval_moses_bleu(ref_path, hyp_path)
                self.logger.info("BLEU input-gen : %f (%s, %s)" %
                                 (self.bleu, hyp_path, ref_path))
                bleu = eval_moses_bleu(ref_path, hyp_path2)
                self.logger.info("BLEU input-deb : %f (%s, %s)" %
                                 (bleu, hyp_path2, ref_path))
                bleu = eval_moses_bleu(hyp_path, hyp_path2)
                self.logger.info("BLEU gen-deb : %f (%s, %s)" %
                                 (bleu, hyp_path, hyp_path2))
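eval_moses_bleu, called above, lives in XLM's utils; it shells out to Moses' multi-bleu.perl and parses the score from its output. A sketch along those lines (the script path is a parameter here, whereas XLM hardcodes it):

import os
import subprocess

def eval_moses_bleu_sketch(ref, hyp, bleu_script='multi-bleu.perl'):
    # Score a hypothesis file against a reference file with Moses'
    # multi-bleu.perl and parse the leading "BLEU = xx.xx," field.
    assert os.path.isfile(ref) and os.path.isfile(hyp)
    p = subprocess.Popen('%s %s < %s' % (bleu_script, ref, hyp),
                         stdout=subprocess.PIPE, shell=True)
    result = p.communicate()[0].decode('utf-8')
    if result.startswith('BLEU'):
        return float(result[7:result.index(',')])
    return -1  # could not parse the script's output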
Example no. 5
def main(params):

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    torch.manual_seed(
        params.seed
    )  # Set random seed. NB: Multi-GPU also needs torch.cuda.manual_seed_all(params.seed)
    assert (params.sample_temperature
            == 0) or (params.beam_size == 1), 'Cannot sample with beam search.'
    assert params.amp <= 1, f'params.amp == {params.amp} not yet supported.'
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    encoder = TransformerModel(model_params,
                               dico,
                               is_encoder=True,
                               with_output=False).cuda().eval()
    decoder = TransformerModel(model_params,
                               dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()
    if all([k.startswith('module.') for k in reloaded['encoder'].keys()]):
        reloaded['encoder'] = {
            k[len('module.'):]: v
            for k, v in reloaded['encoder'].items()
        }
    encoder.load_state_dict(reloaded['encoder'])
    if all([k.startswith('module.') for k in reloaded['decoder'].keys()]):
        reloaded['decoder'] = {
            k[len('module.'):]: v
            for k, v in reloaded['decoder'].items()
        }
    decoder.load_state_dict(reloaded['decoder'])

    if params.amp != 0:
        models = apex.amp.initialize([encoder, decoder],
                                     opt_level=('O%i' % params.amp))
        encoder, decoder = models

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    # f = io.open(params.output_path, 'w', encoding='utf-8')

    hypothesis = [[] for _ in range(params.beam_size)]  # one list per beam rank
    for i in range(0, len(src_sent), params.batch_size):

        # prepare batch
        word_ids = [
            torch.LongTensor([dico.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + params.batch_size]
        ]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it
        encoded = encoder('fwd',
                          x=batch.cuda(),
                          lengths=lengths.cuda(),
                          langs=langs.cuda(),
                          causal=False)
        encoded = encoded.transpose(0, 1)
        max_len = int(1.5 * lengths.max().item() + 10)
        if params.beam_size == 1:
            decoded, dec_lengths = decoder.generate(
                encoded,
                lengths.cuda(),
                params.tgt_id,
                max_len=max_len,
                sample_temperature=(None if params.sample_temperature == 0 else
                                    params.sample_temperature))
        else:
            decoded, dec_lengths, all_hyp_strs = decoder.generate_beam(
                encoded,
                lengths.cuda(),
                params.tgt_id,
                beam_size=params.beam_size,
                length_penalty=params.length_penalty,
                early_stopping=params.early_stopping,
                max_len=max_len,
                output_all_hyps=True)
        # hypothesis.extend(convert_to_text(decoded, dec_lengths, dico, params))

        # convert sentences to words
        for j in range(decoded.size(1)):

            # remove delimiters
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip().replace('<unk>', '<<unk>>')
            target = " ".join([dico[sent[k].item()] for k in range(len(sent))
                               ]).replace('<unk>', '<<unk>>')
            if params.beam_size == 1:
                hypothesis[0].append(target)
            else:
                for hyp_rank in range(params.beam_size):
                    print(
                        all_hyp_strs[j]
                        [hyp_rank if hyp_rank < len(all_hyp_strs[j]) else -1])
                    hypothesis[hyp_rank].append(
                        all_hyp_strs[j]
                        [hyp_rank if hyp_rank < len(all_hyp_strs[j]) else -1])

            sys.stderr.write("%i / %i: %s -> %s\n" %
                             (i + j, len(src_sent), source.replace(
                                 '@@ ', ''), target.replace('@@ ', '')))
            # f.write(target + "\n")

    # f.close()

    # export sentences to reference and hypothesis files / restore BPE segmentation
    save_dir, split = params.output_path.rsplit('/', 1)
    for hyp_rank in range(len(hypothesis)):
        hyp_name = f'hyp.st={params.sample_temperature}.bs={params.beam_size}.lp={params.length_penalty}.es={params.early_stopping}.seed={params.seed if (len(hypothesis) == 1) else str(hyp_rank)}.{params.src_lang}-{params.tgt_lang}.{split}.txt'
        hyp_path = os.path.join(save_dir, hyp_name)
        with open(hyp_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(hypothesis[hyp_rank]) + '\n')
        restore_segmentation(hyp_path)

        # evaluate BLEU score
        if params.ref_path:
            bleu = eval_moses_bleu(params.ref_path, hyp_path)
            logger.info("BLEU %s %s : %f" % (hyp_path, params.ref_path, bleu))
Example no. 6
def run(model,
        params,
        dico,
        data,
        split,
        src_lang,
        trg_lang,
        gen_type="src2trg",
        alpha=1.,
        beta=1.,
        gamma=0.,
        uniform=False,
        iter_mult=1,
        mask_schedule="constant",
        constant_k=1,
        batch_size=8,
        gpu_id=0):
    #n_batches = math.ceil(len(srcs) / batch_size)
    if gen_type == "src2trg":
        ref_path = params.ref_paths[(src_lang, trg_lang, split)]
    elif gen_type == "trg2src":
        ref_path = params.ref_paths[(trg_lang, src_lang, split)]

    with open(ref_path, encoding="utf-8") as f:
        refs = [s.strip() for s in f]
    hypothesis = []
    #hypothesis_selected_pos = []
    for batch_n, batch in enumerate(
            get_iterator(params, data, split, "de", "en")):  # NB: pair hard-coded to de-en
        (src_x, src_lens), (trg_x, trg_lens) = batch

        batches, batches_src_lens, batches_trg_lens, total_scores = [], [], [], []
        #batches_selected_pos = []
        for i_topk_length in range(params.num_topk_lengths):

            # overwrite source/target lengths according to dataset stats if necessary
            if params.de2en_lengths is not None and params.en2de_lengths is not None:
                src_lens_item = src_lens[0].item() - 2  # remove BOS, EOS
                trg_lens_item = trg_lens[0].item() - 2  # remove BOS, EOS
                if gen_type == "src2trg":
                    if len(params.de2en_lengths[src_lens_item].keys()
                           ) < i_topk_length + 1:
                        break
                    data_trg_lens = sorted(
                        params.de2en_lengths[src_lens_item].items(),
                        key=operator.itemgetter(1))
                    data_trg_lens_item = data_trg_lens[-1 -
                                                       i_topk_length][0] + 2
                    # overwrite trg_lens
                    trg_lens = torch.ones_like(trg_lens) * data_trg_lens_item
                elif gen_type == "trg2src":
                    if len(params.en2de_lengths[trg_lens_item].keys()
                           ) < i_topk_length + 1:
                        break
                    data_src_lens = sorted(
                        params.en2de_lengths[trg_lens_item].items(),
                        key=operator.itemgetter(1))
                    # take i_topk_length most likely length and add BOS, EOS
                    data_src_lens_item = data_src_lens[-1 -
                                                       i_topk_length][0] + 2
                    # overwrite src_lens
                    src_lens = torch.ones_like(src_lens) * data_src_lens_item

            if gen_type == "src2trg":
                sent1_input = src_x
                sent2_input = create_masked_batch(trg_lens, params, dico)
                dec_len = torch.max(trg_lens).item() - 2  # cut BOS, EOS
            elif gen_type == "trg2src":
                sent1_input = create_masked_batch(src_lens, params, dico)
                sent2_input = trg_x
                dec_len = torch.max(src_lens).item() - 2  # cut BOS, EOS

            batch, lengths, positions, langs = concat_batches(sent1_input, src_lens, params.lang2id[src_lang], \
                                                              sent2_input, trg_lens, params.lang2id[trg_lang], \
                                                              params.pad_index, params.eos_index, \
                                                              reset_positions=True,
                                                              assert_eos=True) # not sure about it

            if gpu_id >= 0:
                batch, lengths, positions, langs, src_lens, trg_lens = \
                    to_cuda(batch, lengths, positions, langs, src_lens, trg_lens)

            with torch.no_grad():
                batch, total_score_argmax_toks = \
                    _evaluate_batch(model, params, dico, batch,
                                    lengths, positions, langs, src_lens, trg_lens,
                                    gen_type, alpha, beta, gamma, uniform,
                                    dec_len, iter_mult, mask_schedule, constant_k)
            batches.append(batch.clone())
            batches_src_lens.append(src_lens.clone())
            batches_trg_lens.append(trg_lens.clone())
            total_scores.append(total_score_argmax_toks)
            #batches_selected_pos.append(selected_pos)

        best_score_idx = np.array(total_scores).argmax()
        batch, src_lens, trg_lens = batches[best_score_idx], batches_src_lens[
            best_score_idx], batches_trg_lens[best_score_idx]
        #selected_pos = batches_selected_pos[best_score_idx]

        #if gen_type == "src2trg":
        #    hypothesis_selected_pos.append([selected_pos, trg_lens.item()-2])
        #elif gen_type == "trg2src":
        #    hypothesis_selected_pos.append([selected_pos, src_lens.item()-2])

        for batch_idx in range(batch_size):
            src_len = src_lens[batch_idx].item()
            tgt_len = trg_lens[batch_idx].item()
            if gen_type == "src2trg":
                generated = batch[src_len:src_len + tgt_len, batch_idx]
            else:
                generated = batch[:src_len, batch_idx]
            # extra <eos>
            eos_pos = (generated == params.eos_index).nonzero()
            if eos_pos.shape[0] > 2:
                generated = generated[:(eos_pos[1, 0].item() + 1)]
            hypothesis.extend(convert_to_text(generated.unsqueeze(1), \
                                torch.Tensor([generated.shape[0]]).int(), \
                                dico, params))

        print("Ex {0}\nRef: {1}\nHyp: {2}\n".format(
            batch_n, refs[batch_n].encode("utf-8"),
            hypothesis[-1].encode("utf-8")))

    hyp_path = os.path.join(params.hyp_path, 'decoding.txt')
    hyp_path_tok = os.path.join(params.hyp_path, 'decoding.tok.txt')
    #hyp_selected_pos_path = os.path.join(params.hyp_path, "selected_pos.pkl")

    # export sentences to hypothesis file / restore BPE segmentation
    # hyp_path is de-BPE'd by restore_segmentation below;
    # hyp_path_tok keeps the raw tokenized output
    with open(hyp_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(hypothesis) + '\n')
    with open(hyp_path_tok, 'w', encoding='utf-8') as f:
        f.write('\n'.join(hypothesis) + '\n')
    #with open(hyp_selected_pos_path, 'wb') as f:
    #    pkl.dump(hypothesis_selected_pos, f)
    restore_segmentation(hyp_path)

    # evaluate BLEU score
    bleu = eval_moses_bleu(ref_path, hyp_path)
    print("BLEU %s-%s; %s %s : %f" %
          (src_lang, trg_lang, hyp_path, ref_path, bleu))
    # write BLEU score result to file
    result_path = os.path.join(params.hyp_path, "result.txt")
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write("BLEU %s-%s; %s %s : %f\n" %
                (src_lang, trg_lang, hyp_path, ref_path, bleu))
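create_masked_batch is defined elsewhere in this codebase and is not shown; a plausible reconstruction, assuming it builds an all-<mask> batch of the requested lengths wrapped in </s>, which the model then fills in over iter_mult refinement iterations:

import torch

def create_masked_batch_sketch(lengths, params, dico):
    # Hypothetical reconstruction (dico unused here): every non-boundary
    # position starts as <mask>; decoding later replaces masks with tokens.
    batch = torch.full((lengths.max().item(), lengths.size(0)),
                       params.pad_index, dtype=torch.long)
    batch[0] = params.eos_index
    for j, l in enumerate(lengths.tolist()):
        batch[1:l - 1, j] = params.mask_index
        batch[l - 1, j] = params.eos_index
    return batch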
Example no. 7
def main(params):

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    enc_reload = reloaded['encoder']
    dec_reload = reloaded['decoder']
    if all([k.startswith('module.') for k in enc_reload.keys()]):
        enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}
        dec_reload = {k[len('module.'):]: v for k, v in dec_reload.items()}
    encoder = TransformerModel(model_params,
                               dico,
                               is_encoder=True,
                               with_output=True).cuda().eval()
    decoder = TransformerModel(model_params,
                               dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()
    encoder.load_state_dict(enc_reload)
    decoder.load_state_dict(dec_reload)
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    f = io.open(params.output_path, 'w', encoding='utf-8')

    for i in range(0, len(src_sent), params.batch_size):

        # prepare batch
        word_ids = [
            torch.LongTensor([dico.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + params.batch_size]
        ]
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        max_len = min(params.max_length, lengths.max().item())
        batch = torch.LongTensor(max_len,
                                 lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                if lengths[j] > max_len:
                    lengths[j] = max_len
                    s = s[:max_len - 2]
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it
        encoded = encoder('fwd',
                          x=batch.cuda(),
                          lengths=lengths.cuda(),
                          langs=langs.cuda(),
                          causal=False)
        encoded = encoded.transpose(0, 1)
        decoded, dec_lengths = decoder.generate(
            encoded,
            lengths.cuda(),
            params.tgt_id,
            max_len=int(1.5 * lengths.max().item() + 10))

        # convert sentences to words
        for j in range(decoded.size(1)):

            # remove delimiters
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip()
            target = " ".join([dico[sent[k].item()] for k in range(len(sent))])
            sys.stderr.write("%i / %i: %s -> %s\n" %
                             (i + j, len(src_sent), source, target))
            f.write(target + "\n")

    f.close()
    restore_segmentation(params.output_path)
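A typical invocation pipes BPE-tokenized text through the script; the file names and script name below are illustrative, and the flags mirror the parser implied above:

import subprocess

cmd = ('cat data/test.en.bpe | python translate.py '
       '--model_path mlm_en_fr.pth --src_lang en --tgt_lang fr '
       '--output_path output/test.fr.hyp')
subprocess.check_call(cmd, shell=True)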
Example no. 8
import os
import sys
import argparse

from src.utils import restore_segmentation

if __name__ == '__main__':

    # generate parser / parse parameters
    parser = argparse.ArgumentParser(description="Generate reference file")
    parser.add_argument("--output_path",
                        type=str,
                        default="",
                        help="Output reference path")
    params = parser.parse_args()
    assert params.output_path and not os.path.isfile(params.output_path)

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line.strip().replace('<unk>', '<<unk>>'))
    print("Read %i sentences from stdin." % len(src_sent))

    # export sentences to file / restore BPE segmentation
    with open(params.output_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(src_sent) + '\n')
    restore_segmentation(params.output_path)
    print("Restored segmentation")