# Example 1
def translate(opt):
    """Translate a corpus shard-by-shard, passing per-shard source features.

    Validates the translate options, builds a translator, then feeds it
    aligned shards of source, target and every source-feature corpus
    listed in ``opt.src_feats``.
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    # One shard iterator per feature, in a fixed name order so the
    # positional shards coming out of zip() can be matched back to names.
    features_names = list(opt.src_feats.keys())
    features_shards = [split_corpus(path, opt.shard_size)
                       for path in opt.src_feats.values()]
    shard_pairs = zip(src_shards, tgt_shards, *features_shards)

    for i, (src_shard, tgt_shard, *features_shard) in enumerate(shard_pairs):
        # Rebuild the {feature_name: shard} mapping for this shard.
        # (A plain dict suffices; the original defaultdict(list) was only
        # ever assigned to, never relying on its default behaviour.)
        features_shard_ = dict(zip(features_names, features_shard))
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             src_feats=features_shard_,
                             tgt=tgt_shard,
                             batch_size=opt.batch_size,
                             batch_type=opt.batch_type,
                             attn_debug=opt.attn_debug,
                             align_debug=opt.align_debug)
def translate(opt):
    """Translate a corpus shard-by-shard, optionally with tree inputs.

    When ``opt.tree`` is set, tree shards are read alongside source and
    target; otherwise ``tree=None`` is passed for every shard.  The two
    previously duplicated branches are merged into one loop.
    """
    from itertools import repeat  # local: keeps this block self-contained

    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    # Without a tree file the tree "shards" are an endless stream of None;
    # zip() still stops when src/tgt shards run out.
    tree_shards = (split_corpus(opt.tree, opt.shard_size) if opt.tree
                   else repeat(None))

    for i, (src_shard, tgt_shard, tree_shard) in enumerate(
            zip(src_shards, tgt_shards, tree_shards)):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             tree=tree_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             batch_type=opt.batch_type,
                             attn_debug=opt.attn_debug,
                             align_debug=opt.align_debug)
# Example 3
def translate(opt):
    """Shard-wise translation that also feeds NFR and flat tag corpora."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)

    # The tag corpora are sharded in lockstep with src/tgt so each call
    # below receives aligned slices of all four inputs.
    src_shards = split_corpus(opt.src, opt.shard_size)
    nfr_tag_shards = split_corpus(opt.src_nfr_tag, opt.shard_size)
    flat_tag_shards = split_corpus(opt.src_flat_tag, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)

    shards = zip(src_shards, nfr_tag_shards, flat_tag_shards, tgt_shards)
    for idx, (src_shard, nfr_shard, flat_shard, tgt_shard) in enumerate(shards):
        logger.info("Translating shard %d." % idx)
        translator.translate(
            src=src_shard,
            nfr_tag=nfr_shard,
            flat_tag=flat_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug)
def main(opt):
    """Translate games read from stdin, keeping the best-scoring
    prediction per event group and printing comments plus the chosen text.

    NOTE(review): the exact structure yielded by ``yield_games`` is not
    visible here; it is assumed to produce ``(comments, game)`` pairs
    where a game is a sequence of event groups — confirm at its definition.
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    # OpenNMT insists on an out_file even though only the returned
    # scores/predictions are used, so hand it a throwaway in-memory buffer.
    f_output = io.StringIO()

    translator = build_translator(opt, report_score=False, out_file=f_output)

    for i, (comments, game) in enumerate(yield_games(sys.stdin)):
        logger.info("Translating shard %d." % i)

        events = []
        for event_group in game:
            scores, predictions = translator.translate(
                src=event_group, batch_size=opt.batch_size)

            # Discard buffered output so it cannot grow without bound.
            # BUG FIX: truncate(0) alone leaves the stream position at the
            # end, so subsequent writes kept padding the buffer with NULs
            # and memory still grew; rewind before truncating.
            f_output.seek(0)
            f_output.truncate(0)

            # PRED SCORE = cumulated log likelihood of the generated
            # sequence; normalise by token count so long outputs are not
            # penalised.
            text_output = [p[0] for p in predictions]
            normalized_scores = [s[0].item() / len(t.split())
                                 for s, t in zip(scores, text_output)]

            max_index = normalized_scores.index(max(normalized_scores))
            events.append((normalized_scores[max_index],
                           detokenize(text_output[max_index])))

        for comm in comments:
            print(comm)
        print(" ".join(t for s, t in events))
        print()
# Example 5
def translate(opt):
    """Translate using the dynamic (transform-based) inference pipeline."""
    ArgumentParser.validate_translate_opts(opt)
    ArgumentParser._get_all_transform_translate(opt)
    ArgumentParser._validate_transforms_opts(opt)
    ArgumentParser.validate_translate_opts_dynamic(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)

    data_reader = InferenceDataReader(opt.src, opt.tgt, opt.src_feats)

    # Instantiate the requested transforms and chain them into one pipe.
    transforms_cls = get_transforms_cls(opt._all_transform)
    transforms = make_transforms(opt, transforms_cls, translator.fields)
    active = [transforms[name] for name in opt.transforms if name in transforms]
    transform = TransformPipe.build_from(active)

    for shard_idx, (src_shard, tgt_shard, feats_shard) in enumerate(data_reader):
        logger.info("Translating shard %d." % shard_idx)
        translator.translate_dynamic(src=src_shard,
                                     transform=transform,
                                     src_feats=feats_shard,
                                     tgt=tgt_shard,
                                     batch_size=opt.batch_size,
                                     batch_type=opt.batch_type,
                                     attn_debug=opt.attn_debug,
                                     align_debug=opt.align_debug)
# Example 6
def main(opt):
    """Translate a sharded corpus; the reference target is optional.

    FIX: removed leftover debug ``print`` statements and an ``import os``
    inside the loop; they served no purpose in production runs.
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    # Without a reference target, pair every src shard with None.
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        # src_shard is a list of raw byte lines, e.g.
        # src_shard[0].decode("utf-8") == 'आपकी कार में ब्लैक बॉक्स\n'
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            tgt_path=opt.tgt,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
# Example 7
def main(opt):
    """Topic-aware translation: loads lemma alignments and a topic matrix
    once, then forwards them to the translator for every shard."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    lemma_shards = split_corpus(opt.lemma, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards, lemma_shards)

    # FIX: read via a context manager so the file handle is not leaked
    # (the original used a bare open(...).readlines()).
    with open(opt.lemma_align, 'rb') as f:
        lemma_align = f.readlines()

    # Load the topic matrix onto the requested device; halve precision
    # unless fp32 was explicitly requested.
    if opt.gpu >= 0:
        topic_matrix = torch.load(opt.topic_matrix,
                                  map_location=torch.device(opt.gpu))
    else:
        topic_matrix = torch.load(opt.topic_matrix)
    if not opt.fp32:
        topic_matrix = topic_matrix.half()

    for i, (src_shard, tgt_shard, lemma_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(lemma_align=lemma_align,
                             topic_matrix=topic_matrix,
                             src=src_shard,
                             tgt=tgt_shard,
                             lemma=lemma_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug)
# Example 8
def parse_opt(opt, model_root):
    """Convert an opt dict into parsed OpenNMT translate options.

    Model paths are resolved relative to *model_root*, a dummy src is
    injected, and everything is round-tripped through the real argument
    parser so defaults and validation apply.

    :param opt: dict of option name -> value (mutated in place for
        "models" and "src")
    :param model_root: directory that model filenames are relative to
    :return: the parsed argparse Namespace, with ``cuda`` derived from gpu
    """
    argv = []
    parser = ArgumentParser()
    onmt.opts.model_opts(parser)
    onmt.opts.translate_opts(parser)

    models = opt["models"]
    if not isinstance(models, (list, tuple)):
        models = [models]
    opt["models"] = [os.path.join(model_root, model) for model in models]
    opt["src"] = "dummy_src"

    for (k, v) in opt.items():
        if k == "models":
            argv += ["-model"]
            argv += [str(model) for model in v]
        elif isinstance(v, bool):  # FIX: isinstance, not type(v) == bool
            # NOTE(review): the bare flag is emitted for False values too,
            # which a store_true option would read as True — confirm
            # callers never pass False here.
            argv += ["-%s" % k]
        else:
            argv += ["-%s" % k, str(v)]

    opt = parser.parse_args(argv)
    ArgumentParser.validate_translate_opts(opt)
    opt.cuda = opt.gpu > -1

    return opt
# Example 9
def translate(opt):
    """Multimodal translation: pairs each text shard with precomputed
    image features loaded from disk."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    # MultimodalTranslator subclasses Translator and mainly overrides
    # translate(); some unchanged methods were copied verbatim instead of
    # delegating via super(MultimodalTranslator, self).method() — to be
    # cleaned up later.

    img_feats = np.load(opt.path_to_test_img_feats).astype(np.float32)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)

    for idx, (src_shard, tgt_shard) in enumerate(zip(src_shards, tgt_shards)):
        logger.info("Translating shard %d." % idx)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             batch_type=opt.batch_type,
                             attn_debug=opt.attn_debug,
                             align_debug=opt.align_debug,
                             test_img_feats=img_feats,
                             multimodal_model_type=opt.multimodal_model_type)
# Example 10
def translate(opt):
    """Translate with optional per-token tag inputs (yida variant)."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    # Tag source and target are both optional; fall back to None streams.
    tag_src_shards = (split_corpus(opt.tag_src, opt.shard_size)
                      if opt.tag_src is not None else repeat(None))
    tgt_shards = (split_corpus(opt.tgt, opt.shard_size)
                  if opt.tgt is not None else repeat(None))

    triples = zip(src_shards, tgt_shards, tag_src_shards)
    for idx, (src_shard, tgt_shard, tag_shard) in enumerate(triples):
        logger.info("Translating shard %d." % idx)
        translator.translate(
            src=src_shard,
            tag_src=tag_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug)
# Example 11
    def __init__(self, language='en', method='bert', extract=False):
        """Set up a translator and tokenizer for the given language.

        Arguments:
            language: "en" or "de"
            method:   "bert" or "conv"
            extract:  whether to attach a TFIDF extractor
        """
        self.method = method
        self.opt = self._get_opt(language, self.method)

        # Prefer the first GPU whenever CUDA is available.
        if torch.cuda.is_available():
            self.opt.gpu = 0

        ArgumentParser.validate_translate_opts(self.opt)
        self.translator = build_translator(self.opt, report_score=True)

        self.language = language
        # Only these two languages get a tokenizer; others leave it unset,
        # matching the original conditional.
        tokenizer_name = {'en': 'bert-base-uncased',
                          'de': 'bert-base-multilingual-cased'}.get(language)
        if tokenizer_name is not None:
            self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

        self.extract = extract
        if self.extract:
            self.extractor = TFIDF()
    def eval_impl(self,
            processed_data_dir: Path,
            model_dir: Path,
            beam_search_size: int,
            k: int
    ) -> List[List[Tuple[str, float]]]:
        """Run beam-search inference with the best checkpoint and return,
        per source example, the top-k candidates with their log-probs.

        :param processed_data_dir: directory containing src.txt / tgt.txt
        :param model_dir: directory holding models/ and best-step.json
        :param beam_search_size: beam width for decoding
        :param k: number of candidates (n_best) to keep per example
        :return: for each example, a list of (candidate, logprob) pairs,
            where each candidate's subtokens are joined into one string
        """
        from roosterize.ml.onmt.CustomTranslator import CustomTranslator
        from onmt.utils.misc import split_corpus
        from onmt.utils.parse import ArgumentParser
        from translate import _get_parser as translate_get_parser

        src_path = processed_data_dir/"src.txt"
        tgt_path = processed_data_dir/"tgt.txt"

        # Use the checkpoint that training recorded as best.
        best_step = IOUtils.load(model_dir/"best-step.json", IOUtils.Format.json)
        self.logger.info(f"Taking best step at {best_step}")

        candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

        # The translate options are resolved relative to the OpenNMT
        # checkout, so chdir into it for the whole translation run.
        with IOUtils.cd(self.open_nmt_path):
            parser = translate_get_parser()
            opt = parser.parse_args(
                f" -model {model_dir}/models/ckpt_step_{best_step}.pt"
                f" -src {src_path}"
                f" -tgt {tgt_path}"
            )
            opt.output = f"{model_dir}/last-pred.txt"
            opt.beam_size = beam_search_size
            opt.gpu = 0 if torch.cuda.is_available() else -1
            opt.n_best = k
            # Forbid repeating any token during decoding, except "_".
            opt.block_ngram_repeat = 1
            opt.ignore_when_blocking = ["_"]

            # translate.main
            ArgumentParser.validate_translate_opts(opt)

            translator = CustomTranslator.build_translator(opt, report_score=False)
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)

            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                self.logger.info("Translating shard %d." % i)
                # CustomTranslator additionally returns per-example
                # candidate/logprob pairs as the third element.
                _, _, candidates_logprobs_shard = translator.translate(
                    src=src_shard,
                    tgt=tgt_shard,
                    src_dir=opt.src_dir,
                    batch_size=opt.batch_size,
                    attn_debug=opt.attn_debug
                )
                candidates_logprobs.extend(candidates_logprobs_shard)
            # end for
        # end with

        # Reformat candidates: join each candidate's subtoken list into a
        # single string, keeping its log-probability.
        candidates_logprobs: List[List[Tuple[str, float]]] = [[("".join(c), l) for c, l in cl] for cl in candidates_logprobs]

        return candidates_logprobs
# Example 13
    def __init__(self):
        """Build a translator from CLI options and prime a MeCab tokenizer."""
        parser = ArgumentParser()
        opts.config_opts(parser)
        self.opt = parser.parse_args()
        ArgumentParser.validate_translate_opts(self.opt)
        self.translator = build_translator(self.opt, report_score=True)

        # MeCab in wakati (space-separated) mode for tokenisation.
        # NOTE(review): the module is referenced as ``Mecab`` here but
        # ``MeCab`` in sibling snippets — confirm the actual import name.
        self.mecab = Mecab.Tagger("-Owakati")
        # FIX: was ``parce("")`` — an AttributeError at runtime; the
        # warm-up call is ``parse("")`` as in the sibling implementations.
        self.mecab.parse("")
# Example 14
def evaluate_translation_on_datasets(opt):
    """Evaluate the translator over every *_jp_spaced.txt dataset file.

    Oversized datasets are randomly subsampled to
    ``sentences_per_dataset_max`` sentence pairs via temporary files,
    which are removed afterwards.  Relies on module-level globals:
    ``logging_file_path``, ``model_file_path``, ``dataset_files``,
    ``dataset_root_path`` and ``sentences_per_dataset_max``.
    """
    opt.log_file = logging_file_path
    opt.models = [model_file_path]
    opt.n_best = 1
    opt.beam_size = 5
    opt.report_bleu = False
    opt.report_rouge = False

    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)

    for dataset_file in dataset_files:

        if '_jp_spaced.txt' not in dataset_file:
            continue  # guard clause replaces the original nested if

        src_file_path = os.path.join(dataset_root_path, dataset_file)
        tgt_file_path = os.path.join(
            dataset_root_path,
            dataset_file[:-len('_jp_spaced.txt')] + '_en_spaced.txt')

        # FIX: count lines with a context manager so the handle is closed
        # (the original leaked it via a bare open() in a genexp).
        with open(tgt_file_path) as f:
            num_lines = sum(1 for _ in f)

        if num_lines > sentences_per_dataset_max:
            src_tmp_file_path = src_file_path[:-4] + '_tmp.txt'
            tgt_tmp_file_path = tgt_file_path[:-4] + '_tmp.txt'

            with open(src_file_path, 'r') as src_file, \
                    open(tgt_file_path, 'r') as tgt_file:
                src_lines = src_file.read().splitlines()
                tgt_lines = tgt_file.read().splitlines()

                # Subsample a random subset of aligned pairs.
                pairs = list(zip(src_lines, tgt_lines))
                random.shuffle(pairs)
                pairs = pairs[:sentences_per_dataset_max]

                with open(src_tmp_file_path, 'w') as src_tmp_file, \
                        open(tgt_tmp_file_path, 'w') as tgt_tmp_file:
                    for src_line, tgt_line in pairs:
                        src_tmp_file.write(src_line + '\n')
                        tgt_tmp_file.write(tgt_line + '\n')

            src_file_path = src_tmp_file_path
            tgt_file_path = tgt_tmp_file_path

        opt.src = src_file_path
        opt.tgt = tgt_file_path

        ArgumentParser.validate_translate_opts(opt)

        average_pred_score = evaluate_translation(
            translator, opt, src_file_path, tgt_file_path)
        logger.info('{}: {}'.format(dataset_file, average_pred_score))

        # Clean up the temporary subsample files.
        if num_lines > sentences_per_dataset_max:
            os.remove(src_tmp_file_path)
            os.remove(tgt_tmp_file_path)
# Example 15
def main(opt):
    """Translate ``opt.data`` in a single call (no sharding)."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)

    # FIX: the log message lacked a space before the dataset name
    # ("Translating{}" rendered as e.g. "Translatingtest.txt").
    logger.info("Translating {}".format(opt.data))
    translator.translate(data=opt.data,
                         batch_size=opt.batch_size,
                         attn_debug=opt.attn_debug)
def main(opt):
    """Translate with optional agenda inputs and decoding constraints."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    # Constraint shards are consumed one per text shard inside the loop.
    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file,
                                  opt.shard_size,
                                  iter_func=constraint_iter_func,
                                  binary=False)

    translator = build_translator(opt, report_score=True, logger=logger)

    def create_src_shards(path, opt, binary=True):
        # Image-vector data cannot be sharded; 'none' yields placeholders.
        if opt.data_type == 'imgvec':
            assert opt.shard_size <= 0
            return [path]
        if opt.data_type == 'none':
            return [None] * 99999
        return split_corpus(path, opt.shard_size, binary=binary)

    src_shards = create_src_shards(opt.src, opt)
    if opt.agenda:
        agenda_shards = create_src_shards(opt.agenda, opt, False)

    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    shards = (zip(src_shards, agenda_shards, tgt_shards) if opt.agenda
              else zip(src_shards, tgt_shards))

    for i, flat_shard in enumerate(shards):
        if opt.agenda:
            src_shard, agenda_shard, tgt_shard = flat_shard
        else:
            src_shard, tgt_shard = flat_shard
            agenda_shard = None
        logger.info("Translating shard %d." % i)

        # Pull the matching constraint shard, if constraints were given.
        tag_shard = next(tag_shards) if opt.constraint_file else None

        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             agenda=agenda_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug,
                             tag_shard=tag_shard)
# Example 17
def translate(opt):
    """Emotion-/concept-conditioned translation.

    Optional emotion, concept-embedding and concept-word inputs are
    loaded up front and sharded alongside the text corpora; inputs that
    are not supplied fall back to streams of None.
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    # Default to (up to 100) None shards when an input is not supplied.
    emotion_shards = [None] * 100
    if opt.target_emotions_path != "":
        print("Loading target emotions...")
        emotions = read_emotion_file(opt.target_emotions_path)
        emotion_shards = split_emotions(emotions, opt.shard_size)

    concept_embedding_shards = [None] * 100
    if opt.target_concept_embedding != "":
        print("Loading target_concept_embedding...")
        concept_embedding = load_pickle(opt.target_concept_embedding)
        concept_embedding_shards = split_emotions(concept_embedding,
                                                  opt.shard_size)

    concept_words_shards = [None] * 100
    if opt.target_concept_words != "":
        print("Loading target_concept_words...")
        # Concept words are kept as a single un-sharded payload.
        concept_words_shards = [load_pickle(opt.target_concept_words)]

    shard_pairs = zip(src_shards, tgt_shards, emotion_shards,
                      concept_embedding_shards, concept_words_shards)

    for i, (src_shard, tgt_shard, emotion_shard, concept_embedding_shard,
            concept_words_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            tgt_emotion_shard=emotion_shard,
            rerank=opt.rerank,
            emotion_lexicon=opt.emotion_lexicon,
            tgt_concept_embedding_shard=concept_embedding_shard,
            tgt_concept_words_shard=concept_words_shard
            )
# Example 18
def main(opt):
    """Translate once per latent variable, writing each pass to its own
    file under ``opt.output_dir``.

    With n_latent > 1 the whole corpus is decoded n_latent times
    (output_0, output_1, ...); otherwise a single pass writes to 'output'.
    """
    ArgumentParser.validate_translate_opts(opt)

    if not os.path.exists(opt.output_dir):
        os.makedirs(opt.output_dir)

    # Back-fill option aliases when invoked with translate-time names.
    if 'n_latent' not in vars(opt):
        vars(opt)['n_latent'] = vars(opt)['n_translate_latent']
    logger = init_logger(opt.log_file)

    if 'use_segments' not in vars(opt):
        vars(opt)['use_segments'] = opt.n_translate_segments != 0
        vars(opt)['max_segments'] = opt.n_translate_segments

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    n_latent = opt.n_latent

    if n_latent > 1:
        for latent_idx in range(n_latent):
            # Redirect the translator's output to output_<latent_idx>.
            output_path = opt.output_dir + '/output_%d' % (latent_idx)
            out_file = codecs.open(output_path, 'w+', 'utf-8')
            translator.out_file = out_file

            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                logger.info("Translating shard %d." % i)
                translator.translate(src=src_shard,
                                     tgt=tgt_shard,
                                     src_dir=opt.src_dir,
                                     batch_size=opt.batch_size,
                                     attn_debug=opt.attn_debug,
                                     latent_idx=latent_idx)
            # The shard generators were exhausted by the loop above, so
            # rebuild them before the next latent pass.
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
                if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)
    else:
        # Single-latent case: one pass, one output file.
        output_path = opt.output_dir + '/output'
        out_file = codecs.open(output_path, 'w+', 'utf-8')
        translator.out_file = out_file

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(src=src_shard,
                                 tgt=tgt_shard,
                                 src_dir=opt.src_dir,
                                 batch_size=opt.batch_size,
                                 attn_debug=opt.attn_debug)
    def __init__(self):
        # Load the model based on the options given on the command line.
        parser = ArgumentParser()
        opts.config_opts(parser)
        opts.translate_opts(parser)
        self.opt = parser.parse_args()
        ArgumentParser.validate_translate_opts(self.opt)
        self.translator = build_translator(self.opt, report_score=True)

        # Use MeCab for word segmentation (wakati-gaki mode); the empty
        # parse() call warms up the tagger.
        self.mecab = MeCab.Tagger("-Owakati")
        self.mecab.parse("")

        # Dictionary holding the previous response per conversation.
        self.prev_uttr_dict = {}
    def __init__(self):
        # Boilerplate: build the option parser with fixed arguments.
        parser = ArgumentParser()
        opts.config_opts(parser)
        opts.translate_opts(parser)
        self.opt = parser.parse_args(args=[
            "-model", "../models/model.pt", "-src", "None", "-replace_unk",
            "--beam_size", "10", "--min_length", "7", "--block_ngram_repeat",
            "2"
        ])
        ArgumentParser.validate_translate_opts(self.opt)
        self.translator = build_translator(self.opt, report_score=True)

        # Use MeCab for word segmentation; the empty parse() warms it up.
        self.mecab = MeCab.Tagger("-Owakati")
        self.mecab.parse("")
# Example 21
def main(opt):
    """Translate shards and pickle each shard's predictions to
    ``result_<i>.pickle``.

    FIX: removed a debug pickle round-trip of ``opt`` (it was dumped to
    "opt.pkl" and immediately reloaded into an unused variable).
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file,
                                  opt.shard_size,
                                  iter_func=constraint_iter_func,
                                  binary=False)

    translator = build_translator(opt, report_score=True)

    # Image-vector data cannot be sharded; 'none' yields placeholders.
    if opt.data_type == 'imgvec':
        assert opt.shard_size <= 0
        src_shards = [opt.src]
    else:
        if opt.data_type == 'none':
            src_shards = [None] * 99999
        else:
            src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)

        tag_shard = None
        if opt.constraint_file:
            tag_shard = next(tag_shards)

        all_scores, all_predictions = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug,
            tag_shard=tag_shard)
        # Persist predictions per shard; scores are currently unused.
        with open("result_{}.pickle".format(i), 'wb') as f1:
            pickle.dump(all_predictions, f1)
# Example 22
def _build_translator(args):
    """Initialize a seq2seq translator model from CLI-style args.

    Returns the translator together with the parsed options.
    """
    from onmt.utils.parse import ArgumentParser
    import onmt.opts as opts
    from onmt.translate.translator import build_translator

    parser = ArgumentParser()
    opts.config_opts(parser)
    opts.translate_opts(parser)

    opt = parser.parse_args(args=args)
    ArgumentParser.validate_translate_opts(opt)

    translator = build_translator(opt, report_score=False)
    return translator, opt
# Example 23
def main(opt):
    """Translate src shards paired with dialogue-history shards."""
    from itertools import repeat  # local: keeps this block self-contained

    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    history_shards = split_corpus(opt.history, opt.shard_size)
    # FIX: was ``[None] * opt.shard_size`` which breaks for shard_size<=0
    # and can run short; an endless None stream is what zip() expects.
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, history_shards, tgt_shards)

    for i, (src_shard, history_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        # FIX: pass the shard itself, not the path ``opt.tgt``, matching
        # every other translate() call in this codebase.
        translator.translate(src=src_shard,
                             history=history_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug)
def get_onmt_opt(translation_model: Iterable[str],
                 src_file: Optional[str] = None,
                 output_file: Optional[str] = None,
                 n_best: int = 1,
                 log_probs: bool = False):
    """Build parsed OpenNMT translate options for the given model(s).

    ``src``/``output`` default to the placeholder '(unused)' when not
    supplied, since OpenNMT's parser requires them.
    """
    src = src_file if src_file is not None else '(unused)'
    output = output_file if output_file is not None else '(unused)'
    # FIX: build the argv list directly instead of formatting one string
    # and calling split(), which broke on paths containing spaces.
    args = ['--model', *translation_model, '--src', src, '--output', output]
    if log_probs:
        args.append('--log_probs')
    if n_best != 1:
        args.extend(['--n_best', str(n_best)])

    parser = onmt_parser()
    opt = parser.parse_args(args)
    ArgumentParser.validate_translate_opts(opt)

    return opt
# Example 25
def main(opt):
    """Translate shard-by-shard, forwarding the full opt object."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True, logger=logger)
    # Write predictions to the requested output file in UTF-8.
    translator.out_file = codecs.open(opt.output, 'w+', 'utf-8')

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = (split_corpus(opt.tgt, opt.shard_size)
                  if opt.tgt is not None else repeat(None))

    for idx, (src_shard, tgt_shard) in enumerate(zip(src_shards, tgt_shards)):
        logger.info("Translating shard %d." % idx)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug,
                             opt=opt)
# Example 26
def main(opt):
    """Translate shards; graph-typed data additionally carries node inputs.

    The two previously duplicated translate() calls are merged: both
    branches share the same base keyword arguments, and the graph branch
    only adds the node shards.
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    if opt.data_type == 'graph':
        node1_shards = split_corpus(opt.src_node1, opt.shard_size)
        node2_shards = split_corpus(opt.src_node2, opt.shard_size)
        shards = zip(src_shards, node1_shards, node2_shards, tgt_shards)
    else:
        shards = zip(src_shards, tgt_shards)

    for i, shard in enumerate(shards):
        logger.info("Translating shard %d." % i)
        kwargs = dict(src_dir=opt.src_dir,
                      batch_size=opt.batch_size,
                      attn_debug=opt.attn_debug)
        if opt.data_type == 'graph':
            src_shard, node1_shard, node2_shard, tgt_shard = shard
            # Node inputs are only supplied in the graph case, exactly as
            # the original duplicated call did.
            kwargs.update(src_node1=node1_shard, src_node2=node2_shard)
        else:
            src_shard, tgt_shard = shard
        translator.translate(src=src_shard, tgt=tgt_shard, **kwargs)
# Example 27
def prepare_translators(langspecf):
    """Build (and cache in module globals) the two translators for a spec.

    Rebuilds ``translatorbest`` (single best hypothesis) and
    ``translatorbigram`` (5-best, max length 2) only when the requested
    language spec differs from the cached one.

    Args:
        langspecf (dict): language spec; ``langspecf['model']`` names the
            model file under ``dir_path/model``.
    """
    global translatorbest, translatorbigram, langspec
    with open(os.path.join(dir_path, 'opt_data'), 'rb') as f:
        # NOTE(review): pickle.load on a local data file; do not point this
        # at untrusted input.
        opt = pickle.load(f)

    if not langspec or langspec != langspecf:
        model_path = os.path.join(dir_path, 'model', langspecf['model'])

        # Single-best translator.
        opt.models = [model_path]
        opt.n_best = 1
        ArgumentParser.validate_translate_opts(opt)
        init_logger(opt.log_file)  # called for its side effect only
        translatorbest = build_translator(opt, report_score=True)

        # 5-best translator limited to very short outputs.
        opt.models = [model_path]
        opt.n_best = 5
        opt.max_length = 2
        ArgumentParser.validate_translate_opts(opt)
        init_logger(opt.log_file)
        translatorbigram = build_translator(opt, report_score=True)

        langspec = langspecf
예제 #28
0
def main(opt):
    """Translate multi-source data shard by shard.

    Builds one shard list per raw-data key (``src.<type>`` plus ``tgt``),
    feeds the translator one shard-slice at a time, and returns the
    accumulated candidates with their log-probabilities, each candidate
    joined into a single string.
    """
    ArgumentParser.validate_translate_opts(opt)

    logger = init_logger(opt.log_file)
    abs_path = os.path.dirname(opt.src)
    src_mode = opt.data_mode
    candidates_logprobs: List[List[Tuple[List[str], float]]] = []

    # Pick the translator variant depending on the source types present.
    if "patype0" in opt.src_types:
        build = MultiSourceAPTypeAppendedTranslator.build_translator
    else:
        build = MultiSourceAPTranslator.build_translator
    translator = build(opt.src_types, opt, report_score=True)

    keys = ["src.{}".format(src_type) for src_type in opt.src_types] + ["tgt"]
    raw_data_paths: Dict[str, str] = {
        key: "{0}/{1}.{2}.txt".format(abs_path, key, src_mode)
        for key in keys
    }
    raw_data_shards: Dict[str, list] = {
        key: list(split_corpus(path, opt.shard_size))
        for key, path in raw_data_paths.items()
    }

    num_shards = len(next(iter(raw_data_shards.values())))
    for shard_idx in range(num_shards):
        logger.info("Translating shard %d." % shard_idx)
        _, _, shard_cands = translator.translate(
            {key: shards[shard_idx]
             for key, shards in raw_data_shards.items()},
            True,
            src_dir=None,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
        candidates_logprobs.extend(shard_cands)

    # Reformat: join each candidate's token list into one string.
    candidates_logprobs: List[List[Tuple[str, float]]] = [
        [("".join(tokens), logprob) for tokens, logprob in cand_list]
        for cand_list in candidates_logprobs
    ]
    return candidates_logprobs
예제 #29
0
def translate(opt):
    """Translate sharded corpora, directing output via ``out_dir``.

    Iterates source/target shards in lockstep and runs the translator on
    each pair, forwarding batching, debug, and hidden-state options.
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    # out_file=None: output handling is delegated to out_dir=opt.output in
    # the translate() call below. NOTE(review): the original passed
    # opt.output here too and was changed to None as a bug workaround —
    # confirm against the translator's out_file/out_dir semantics.
    translator = build_translator(opt, report_score=True, out_file=None)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             batch_type=opt.batch_type,
                             attn_debug=opt.attn_debug,
                             align_debug=opt.align_debug,
                             generate_hidden_states=opt.gen_hidden_states,
                             out_dir=opt.output)
예제 #30
0
def translate(opt):
    """Translate all shards and save representation tensors to disk.

    ``reps_tensor`` is overwritten on every iteration, so only the LAST
    shard's representations are saved to ``opt.out_reps`` — with multiple
    shards earlier results are discarded (use a shard_size large enough to
    hold the whole corpus in one shard to keep everything).

    Raises:
        ValueError: if the corpora yield no shards (the original code hit
            an opaque NameError in that case).
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)

    reps_tensor = None  # stays None when the corpus produces no shards
    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        reps_tensor = translator.translate(src=src_shard,
                                           tgt=tgt_shard,
                                           src_dir=opt.src_dir,
                                           batch_size=opt.batch_size,
                                           batch_type=opt.batch_type,
                                           attn_debug=opt.attn_debug,
                                           align_debug=opt.align_debug)

    if reps_tensor is None:
        raise ValueError(
            "No shards produced from %r; nothing to save." % opt.src)
    torch.save(reps_tensor, opt.out_reps)
    print(reps_tensor.size())
예제 #31
0
    def parse_opt(self, opt):
        """Parse the option set passed by the user using `onmt.opts`.

        Temporarily rewrites ``sys.argv`` so the onmt argument parser can
        consume the user's options, then restores it.

        Args:
            opt (dict): Options passed by the user

        Returns:
            opt (argparse.Namespace): full set of options for the Translator
        """
        prec_argv = sys.argv
        sys.argv = sys.argv[:1]
        try:
            parser = ArgumentParser()
            onmt.opts.translate_opts(parser)

            models = opt['models']
            if not isinstance(models, (list, tuple)):
                models = [models]
            opt['models'] = [os.path.join(self.model_root, model)
                             for model in models]
            # The parser requires a src argument; real input is fed later.
            opt['src'] = "dummy_src"

            for (k, v) in opt.items():
                if k == 'models':
                    sys.argv += ['-model']
                    sys.argv += [str(model) for model in v]
                elif isinstance(v, bool):
                    # Boolean options are store_true-style flags: emit the
                    # flag only when the value is True. (The original
                    # appended the flag even for False, which argparse
                    # treats as enabling the option.)
                    if v:
                        sys.argv += ['-%s' % k]
                else:
                    sys.argv += ['-%s' % k, str(v)]

            opt = parser.parse_args()
            ArgumentParser.validate_translate_opts(opt)
            opt.cuda = opt.gpu > -1
        finally:
            # Always restore the caller's argv, even if parsing raises.
            sys.argv = prec_argv
        return opt