def translate(opt):
    """Translate source shards, feeding per-word source features to the translator.

    Each feature in ``opt.src_feats`` (name -> path) is sharded in lockstep
    with the source and re-associated with its name for every shard.
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, logger=logger, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)

    # One shard iterator per source feature, kept in the same order as names.
    features_shards = []
    features_names = []
    for feat_name, feat_path in opt.src_feats.items():
        features_shards.append(split_corpus(feat_path, opt.shard_size))
        features_names.append(feat_name)

    shard_pairs = zip(src_shards, tgt_shards, *features_shards)

    for i, (src_shard, tgt_shard, *features_shard) in enumerate(shard_pairs):
        # Re-pair each feature shard with its name. A plain dict is enough;
        # the original defaultdict(list) default factory was never exercised.
        features_shard_ = dict(zip(features_names, features_shard))
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            src_feats=features_shard_,
            tgt=tgt_shard,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug)
def translate(opt):
    """Run shard-by-shard translation, optionally pairing each shard with a
    tree-structure file when ``opt.tree`` is set."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)

    # Options identical for every shard.
    common = dict(
        src_dir=opt.src_dir,
        batch_size=opt.batch_size,
        batch_type=opt.batch_type,
        attn_debug=opt.attn_debug,
        align_debug=opt.align_debug,
    )

    if opt.tree:
        tree_shards = split_corpus(opt.tree, opt.shard_size)
        for idx, (src_shard, tgt_shard, tree_shard) in enumerate(
                zip(src_shards, tgt_shards, tree_shards)):
            logger.info("Translating shard %d." % idx)
            translator.translate(src=src_shard, tgt=tgt_shard,
                                 tree=tree_shard, **common)
    else:
        for idx, (src_shard, tgt_shard) in enumerate(
                zip(src_shards, tgt_shards)):
            logger.info("Translating shard %d." % idx)
            translator.translate(src=src_shard, tgt=tgt_shard,
                                 tree=None, **common)
def translate(opt):
    """Translate shards with auxiliary NFR / flat tag streams aligned to the source."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, logger=logger, report_score=True)

    # Source, the two tag streams, and target are sharded in lockstep.
    src_shards = split_corpus(opt.src, opt.shard_size)
    nfr_tag_shards = split_corpus(opt.src_nfr_tag, opt.shard_size)
    flat_tag_shards = split_corpus(opt.src_flat_tag, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)

    shards = zip(src_shards, nfr_tag_shards, flat_tag_shards, tgt_shards)
    for i, (src_shard, nfr_tag_shard, flat_tag_shard, tgt_shard) in enumerate(shards):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            nfr_tag=nfr_tag_shard,
            flat_tag=flat_tag_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug)
def main(opt):
    """Translate games read from stdin, printing the best-scoring prediction
    per event group.

    The translator insists on an output file, so a throwaway StringIO is
    attached and cleared between calls to keep memory bounded.
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    # Virtual file to collect predicted output (not actually needed, but
    # OpenNMT still requires an out_file).
    f_output = io.StringIO()
    translator = build_translator(opt, report_score=False, out_file=f_output)
    for i, (comments, game) in enumerate(yield_games(sys.stdin)):
        logger.info("Translating shard %d." % i)
        events = []
        for event_group in game:
            # PRED SCORE = cumulated log likelihood of the generated sequence.
            scores, predictions = translator.translate(
                src=event_group, batch_size=opt.batch_size)
            # BUG FIX: truncate(0) alone leaves the stream position at the
            # end, so subsequent writes NUL-pad up to that offset and memory
            # still grows. Rewind first, then truncate.
            f_output.seek(0)
            f_output.truncate(0)
            text_output = [p[0] for p in predictions]
            # Normalize each cumulative log-likelihood by its token count.
            normalized_scores = [s[0].item() / len(t.split())
                                 for s, t in zip(scores, text_output)]
            max_index = normalized_scores.index(max(normalized_scores))
            events.append((normalized_scores[max_index],
                           detokenize(text_output[max_index])))
        for comm in comments:
            print(comm)
        print(" ".join([t for s, t in events]))
        print()
def translate(opt):
    """Translate with on-the-fly data transforms applied to every shard."""
    ArgumentParser.validate_translate_opts(opt)
    ArgumentParser._get_all_transform_translate(opt)
    ArgumentParser._validate_transforms_opts(opt)
    ArgumentParser.validate_translate_opts_dynamic(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)
    data_reader = InferenceDataReader(opt.src, opt.tgt, opt.src_feats)

    # Build the transform pipeline from the transforms requested in opt,
    # skipping any name that was not materialized.
    transforms_cls = get_transforms_cls(opt._all_transform)
    transforms = make_transforms(opt, transforms_cls, translator.fields)
    selected = [transforms[name] for name in opt.transforms
                if name in transforms]
    transform = TransformPipe.build_from(selected)

    for i, (src_shard, tgt_shard, feats_shard) in enumerate(data_reader):
        logger.info("Translating shard %d." % i)
        translator.translate_dynamic(
            src=src_shard,
            transform=transform,
            src_feats=feats_shard,
            tgt=tgt_shard,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug)
def main(opt):
    """Translate src shards against optional reference tgt shards.

    When no target is given, every source shard is paired with ``None``.
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)
    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        # The original printed debug strings and the cwd (with an in-loop
        # `import os`) on every shard; that leftover noise was removed.
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            tgt_path=opt.tgt,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
def main(opt):
    """Translate with lemma streams, lemma alignments and a preloaded topic matrix."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    lemma_shards = split_corpus(opt.lemma, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards, lemma_shards)
    # BUG FIX: close the alignment file instead of leaking the handle.
    with open(opt.lemma_align, 'rb') as f:
        lemma_align = f.readlines()
    # Load the topic matrix onto the requested device.
    if opt.gpu >= 0:
        topic_matrix = torch.load(opt.topic_matrix,
                                  map_location=torch.device(opt.gpu))
    else:
        topic_matrix = torch.load(opt.topic_matrix)
    if not opt.fp32:
        # Match half-precision inference unless fp32 is forced.
        topic_matrix = topic_matrix.half()
    for i, (src_shard, tgt_shard, lemma_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            lemma_align=lemma_align,
            topic_matrix=topic_matrix,
            src=src_shard,
            tgt=tgt_shard,
            lemma=lemma_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
def parse_opt(opt, model_root):
    """Convert a dict of user options into a validated translate Namespace.

    Model paths are resolved against ``model_root``; ``src`` is set to a
    dummy value because the parser requires it.

    Args:
        opt (dict): options supplied by the caller.
        model_root (str): directory containing the model files.

    Returns:
        argparse.Namespace: the full, validated option set.
    """
    argv = []
    parser = ArgumentParser()
    onmt.opts.model_opts(parser)
    onmt.opts.translate_opts(parser)

    models = opt["models"]
    if not isinstance(models, (list, tuple)):
        models = [models]
    opt["models"] = [os.path.join(model_root, model) for model in models]
    opt["src"] = "dummy_src"

    for (k, v) in opt.items():
        if k == "models":
            argv += ["-model"]
            argv += [str(model) for model in v]
        elif type(v) == bool:
            # BUG FIX: only emit store_true flags when the value is True.
            # The original appended the flag unconditionally, so passing
            # False still switched the option on.
            if v:
                argv += ["-%s" % k]
        else:
            argv += ["-%s" % k, str(v)]

    opt = parser.parse_args(argv)
    ArgumentParser.validate_translate_opts(opt)
    opt.cuda = opt.gpu > -1
    return opt
def translate(opt):
    """Translate shards alongside precomputed test image features (multimodal NMT)."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    # NOTE: MultimodalTranslator subclasses Translator and mainly overrides
    # translate(); several methods were copied verbatim instead of delegating
    # through super() — to be cleaned up later.
    test_img_feats = np.load(opt.path_to_test_img_feats).astype(np.float32)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)

    for i, (src_shard, tgt_shard) in enumerate(zip(src_shards, tgt_shards)):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug,
            test_img_feats=test_img_feats,
            multimodal_model_type=opt.multimodal_model_type)
def translate(opt):
    """Translate shards, optionally feeding a parallel tag-source stream."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    # Tag source and target are optional: substitute endless None streams.
    tag_src_shards = (split_corpus(opt.tag_src, opt.shard_size)
                      if opt.tag_src is not None else repeat(None))
    tgt_shards = (split_corpus(opt.tgt, opt.shard_size)
                  if opt.tgt is not None else repeat(None))

    for i, (src_shard, tgt_shard, tag_shard) in enumerate(
            zip(src_shards, tgt_shards, tag_src_shards)):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tag_src=tag_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug)
def __init__(self, language='en', method='bert', extract=False):
    """Set up the translator and tokenizer for one language/method pair.

    Arguments:
        language: "en" or "de"
        method: "bert" or "conv"
        extract: whether to enable TF-IDF keyword extraction
    """
    self.method = method
    self.opt = self._get_opt(language, self.method)
    # Prefer GPU 0 whenever CUDA is available.
    if torch.cuda.is_available():
        self.opt.gpu = 0
    ArgumentParser.validate_translate_opts(self.opt)
    self.translator = build_translator(self.opt, report_score=True)

    self.language = language
    # Pick the pretrained tokenizer matching the language; other values
    # leave self.tokenizer unset, exactly as before.
    pretrained_by_lang = {
        'en': 'bert-base-uncased',
        'de': 'bert-base-multilingual-cased',
    }
    if self.language in pretrained_by_lang:
        self.tokenizer = BertTokenizer.from_pretrained(
            pretrained_by_lang[self.language])

    self.extract = extract
    if self.extract:
        self.extractor = TFIDF()
def eval_impl(self, processed_data_dir: Path, model_dir: Path, beam_search_size: int, k: int) -> List[List[Tuple[str, float]]]:
    """Run beam-search inference with the best saved checkpoint.

    Translates src.txt/tgt.txt from `processed_data_dir` using the checkpoint
    recorded in `model_dir`/best-step.json, and returns, for each input, the
    k best candidate strings paired with their log-probabilities.
    """
    from roosterize.ml.onmt.CustomTranslator import CustomTranslator
    from onmt.utils.misc import split_corpus
    from onmt.utils.parse import ArgumentParser
    from translate import _get_parser as translate_get_parser

    src_path = processed_data_dir/"src.txt"
    tgt_path = processed_data_dir/"tgt.txt"

    # The checkpoint step to load was recorded during training.
    best_step = IOUtils.load(model_dir/"best-step.json", IOUtils.Format.json)
    self.logger.info(f"Taking best step at {best_step}")

    candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

    # OpenNMT scripts expect to run from their own directory.
    with IOUtils.cd(self.open_nmt_path):
        # Build translate options through OpenNMT's own CLI parser, then
        # override the decoding parameters programmatically.
        parser = translate_get_parser()
        opt = parser.parse_args(
            f" -model {model_dir}/models/ckpt_step_{best_step}.pt"
            f" -src {src_path}"
            f" -tgt {tgt_path}"
        )
        opt.output = f"{model_dir}/last-pred.txt"
        opt.beam_size = beam_search_size
        opt.gpu = 0 if torch.cuda.is_available() else -1
        opt.n_best = k
        opt.block_ngram_repeat = 1
        # "_" never appears in outputs, so blocking on it disables nothing
        # while keeping the option well-formed.
        opt.ignore_when_blocking = ["_"]

        # Mirrors translate.main from OpenNMT.
        ArgumentParser.validate_translate_opts(opt)
        translator = CustomTranslator.build_translator(opt, report_score=False)
        src_shards = split_corpus(opt.src, opt.shard_size)
        tgt_shards = split_corpus(opt.tgt, opt.shard_size) if opt.tgt is not None else repeat(None)
        shard_pairs = zip(src_shards, tgt_shards)

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            self.logger.info("Translating shard %d." % i)
            # CustomTranslator additionally returns per-input candidates
            # with log-probs; the first two return values are unused here.
            _, _, candidates_logprobs_shard = translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug
            )
            candidates_logprobs.extend(candidates_logprobs_shard)
        # end for
    # end with

    # Reformat candidates: join sub-tokens into a single string per candidate.
    candidates_logprobs: List[List[Tuple[str, float]]] = [[("".join(c), l) for c, l in cl] for cl in candidates_logprobs]

    return candidates_logprobs
def __init__(self):
    """Load the translator from command-line options and set up MeCab tokenization."""
    parser = ArgumentParser()
    opts.config_opts(parser)
    self.opt = parser.parse_args()
    ArgumentParser.validate_translate_opts(self.opt)
    self.translator = build_translator(self.opt, report_score=True)
    # Wakati-style (space-separated) tokenizer.
    self.mecab = Mecab.Tagger("-Owakati")
    # BUG FIX: Tagger has no `parce` method — the warm-up call must be
    # `parse`. The original raised AttributeError at construction time.
    self.mecab.parse("")
def evaluate_translation_on_datasets(opt):
    """Evaluate the fixed model on every *_jp_spaced.txt dataset.

    Datasets larger than `sentences_per_dataset_max` are randomly
    sub-sampled into temporary files, which are removed afterwards.
    """
    opt.log_file = logging_file_path
    opt.models = [model_file_path]
    opt.n_best = 1
    opt.beam_size = 5
    opt.report_bleu = False
    opt.report_rouge = False
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    for dataset_file in dataset_files:
        if '_jp_spaced.txt' in dataset_file:
            src_file_path = os.path.join(dataset_root_path, dataset_file)
            tgt_file_path = os.path.join(
                dataset_root_path,
                dataset_file[:-len('_jp_spaced.txt')] + '_en_spaced.txt')
            # BUG FIX: close the handle used for line counting (the original
            # left it open).
            with open(tgt_file_path) as f:
                num_lines = sum(1 for line in f)
            if num_lines > sentences_per_dataset_max:
                # Write a shuffled, truncated copy of the pair to temp files.
                src_tmp_file_path = src_file_path[:-4] + '_tmp.txt'
                tgt_tmp_file_path = tgt_file_path[:-4] + '_tmp.txt'
                with open(src_file_path, 'r') as src_file, \
                        open(tgt_file_path, 'r') as tgt_file:
                    src_lines = src_file.read().splitlines()
                    tgt_lines = tgt_file.read().splitlines()
                pairs = list(zip(src_lines, tgt_lines))
                random.shuffle(pairs)
                pairs = pairs[:sentences_per_dataset_max]
                with open(src_tmp_file_path, 'w') as src_tmp_file, \
                        open(tgt_tmp_file_path, 'w') as tgt_tmp_file:
                    for pair in pairs:
                        src_tmp_file.write(pair[0] + '\n')
                        tgt_tmp_file.write(pair[1] + '\n')
                src_file_path = src_tmp_file_path
                tgt_file_path = tgt_tmp_file_path
            opt.src = src_file_path
            opt.tgt = tgt_file_path
            ArgumentParser.validate_translate_opts(opt)
            average_pred_score = evaluate_translation(
                translator, opt, src_file_path, tgt_file_path)
            logger.info('{}: {}'.format(dataset_file, average_pred_score))
            if num_lines > sentences_per_dataset_max:
                os.remove(src_tmp_file_path)
                os.remove(tgt_tmp_file_path)
def main(opt):
    """Translate a whole dataset in a single call (no sharding)."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    logger.info("Translating{}".format(opt.data))
    translate_kwargs = dict(data=opt.data,
                            batch_size=opt.batch_size,
                            attn_debug=opt.attn_debug)
    translator.translate(**translate_kwargs)
def main(opt):
    """Translate with optional constraint tags and an optional agenda stream."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file, opt.shard_size,
                                  iter_func=constraint_iter_func, binary=False)

    translator = build_translator(opt, report_score=True, logger=logger)

    def create_src_shards(path, opt, binary=True):
        # Image-vector data cannot be sharded; 'none' produces dummy shards.
        if opt.data_type == 'imgvec':
            assert opt.shard_size <= 0
            return [path]
        if opt.data_type == 'none':
            return [None] * 99999
        return split_corpus(path, opt.shard_size, binary=binary)

    src_shards = create_src_shards(opt.src, opt)
    if opt.agenda:
        agenda_shards = create_src_shards(opt.agenda, opt, False)
    tgt_shards = (split_corpus(opt.tgt, opt.shard_size)
                  if opt.tgt is not None else repeat(None))

    shards = (zip(src_shards, agenda_shards, tgt_shards) if opt.agenda
              else zip(src_shards, tgt_shards))

    for i, flat_shard in enumerate(shards):
        # Unpack according to whether the agenda stream is present.
        if opt.agenda:
            src_shard, agenda_shard, tgt_shard = flat_shard
        else:
            src_shard, tgt_shard = flat_shard
            agenda_shard = None
        logger.info("Translating shard %d." % i)
        tag_shard = next(tag_shards) if opt.constraint_file else None
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             agenda=agenda_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug,
                             tag_shard=tag_shard)
def translate(opt):
    """Translate shards with optional target emotions, concept embeddings
    and concept words used for emotional reranking."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = (split_corpus(opt.tgt, opt.shard_size)
                  if opt.tgt is not None else repeat(None))

    # Each optional stream defaults to 100 None "shards" so zip() below
    # still pairs every source shard with something.
    tgt_emotion_shards = [None] * 100
    if opt.target_emotions_path != "":
        print("Loading target emotions...")
        tgt_emotions = read_emotion_file(opt.target_emotions_path)
        tgt_emotion_shards = split_emotions(tgt_emotions, opt.shard_size)

    tgt_concept_embedding_shards = [None] * 100
    if opt.target_concept_embedding != "":
        print("Loading target_concept_embedding...")
        tgt_concept_embedding = load_pickle(opt.target_concept_embedding)
        tgt_concept_embedding_shards = split_emotions(
            tgt_concept_embedding, opt.shard_size)

    tgt_concept_words_shards = [None] * 100
    if opt.target_concept_words != "":
        print("Loading target_concept_words...")
        tgt_concept_words = load_pickle(opt.target_concept_words)
        # Concept words are not sharded: a single shard carries them all.
        tgt_concept_words_shards = [tgt_concept_words]

    shard_iter = zip(src_shards, tgt_shards, tgt_emotion_shards,
                     tgt_concept_embedding_shards, tgt_concept_words_shards)
    for i, (src_shard, tgt_shard, tgt_emotion_shard,
            tgt_concept_embedding_shard,
            tgt_concept_words_shard) in enumerate(shard_iter):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            tgt_emotion_shard=tgt_emotion_shard,
            rerank=opt.rerank,
            emotion_lexicon=opt.emotion_lexicon,
            tgt_concept_embedding_shard=tgt_concept_embedding_shard,
            tgt_concept_words_shard=tgt_concept_words_shard)
def main(opt):
    """Translate once per latent class (when n_latent > 1) or once overall,
    writing predictions under opt.output_dir.

    Shard iterators are rebuilt for every latent pass because split_corpus
    returns one-shot generators.
    """
    ArgumentParser.validate_translate_opts(opt)
    if not os.path.exists(opt.output_dir):
        os.makedirs(opt.output_dir)
    # Back-fill translate-time option names expected by older checkpoints.
    if 'n_latent' not in vars(opt):
        vars(opt)['n_latent'] = vars(opt)['n_translate_latent']
    logger = init_logger(opt.log_file)
    if 'use_segments' not in vars(opt):
        vars(opt)['use_segments'] = opt.n_translate_segments != 0
        vars(opt)['max_segments'] = opt.n_translate_segments
    translator = build_translator(opt, report_score=True)

    n_latent = opt.n_latent
    if n_latent > 1:
        for latent_idx in range(n_latent):
            # Fresh one-shot shard iterators for this latent pass.
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
                if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)
            output_path = opt.output_dir + '/output_%d' % (latent_idx)
            out_file = codecs.open(output_path, 'w+', 'utf-8')
            translator.out_file = out_file
            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                logger.info("Translating shard %d." % i)
                translator.translate(src=src_shard, tgt=tgt_shard,
                                     src_dir=opt.src_dir,
                                     batch_size=opt.batch_size,
                                     attn_debug=opt.attn_debug,
                                     latent_idx=latent_idx)
            # BUG FIX: close each per-latent output file; the original
            # leaked every handle by rebinding without closing. (Also
            # removed the wasted extra shard generators the original
            # built after the final iteration.)
            out_file.close()
    else:
        src_shards = split_corpus(opt.src, opt.shard_size)
        tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
            if opt.tgt is not None else repeat(None)
        shard_pairs = zip(src_shards, tgt_shards)
        output_path = opt.output_dir + '/output'
        out_file = codecs.open(output_path, 'w+', 'utf-8')
        translator.out_file = out_file
        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(src=src_shard, tgt=tgt_shard,
                                 src_dir=opt.src_dir,
                                 batch_size=opt.batch_size,
                                 attn_debug=opt.attn_debug)
        out_file.close()
def __init__(self):
    """Load the model from command-line options and prepare MeCab tokenization."""
    # Build the option set from the command line.
    parser = ArgumentParser()
    opts.config_opts(parser)
    opts.translate_opts(parser)
    self.opt = parser.parse_args()
    ArgumentParser.validate_translate_opts(self.opt)
    self.translator = build_translator(self.opt, report_score=True)
    # MeCab in wakati (space-separated) mode for tokenization; the empty
    # parse() call primes it once up front.
    self.mecab = MeCab.Tagger("-Owakati")
    self.mecab.parse("")
    # Remembers the previous response per conversation.
    self.prev_uttr_dict = {}
def __init__(self):
    """Build a translator with a fixed set of decoding options for the
    bundled model and prepare MeCab word segmentation."""
    parser = ArgumentParser()
    opts.config_opts(parser)
    opts.translate_opts(parser)
    # Hard-coded decoding configuration.
    self.opt = parser.parse_args(args=[
        "-model", "../models/model.pt",
        "-src", "None",
        "-replace_unk",
        "--beam_size", "10",
        "--min_length", "7",
        "--block_ngram_repeat", "2",
    ])
    ArgumentParser.validate_translate_opts(self.opt)
    self.translator = build_translator(self.opt, report_score=True)
    # MeCab (wakati mode) for word segmentation; prime it once.
    self.mecab = MeCab.Tagger("-Owakati")
    self.mecab.parse("")
def main(opt):
    """Translate shards (with optional constraint tags) and pickle each
    shard's predictions to result_<i>.pickle."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file, opt.shard_size,
                                  iter_func=constraint_iter_func, binary=False)
    # NOTE: the original dumped `opt` to opt.pkl and immediately re-read it
    # into an unused variable — debug leftovers, removed.
    translator = build_translator(opt, report_score=True)
    if opt.data_type == 'imgvec':
        # Image-vector input cannot be sharded.
        assert opt.shard_size <= 0
        src_shards = [opt.src]
    elif opt.data_type == 'none':
        src_shards = [None] * 99999
    else:
        src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)
    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        tag_shard = next(tag_shards) if opt.constraint_file else None
        all_scores, all_predictions = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug,
            tag_shard=tag_shard)
        with open("result_{}.pickle".format(i), 'wb') as f1:
            pickle.dump(all_predictions, f1)
def _build_translator(args):
    """Initialize a seq2seq translator model from CLI-style args.

    Returns the (translator, opt) pair.
    """
    # Imports are local so OpenNMT is only required when this path is used.
    from onmt.utils.parse import ArgumentParser
    import onmt.opts as opts
    from onmt.translate.translator import build_translator

    parser = ArgumentParser()
    opts.config_opts(parser)
    opts.translate_opts(parser)
    opt = parser.parse_args(args=args)
    ArgumentParser.validate_translate_opts(opt)

    translator = build_translator(opt, report_score=False)
    return translator, opt
def main(opt):
    """Translate src shards paired with dialogue-history shards."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    history_shards = split_corpus(opt.history, opt.shard_size)
    # BUG FIX: with no target, supply an endless stream of None. The
    # original used [None] * opt.shard_size, which is empty when
    # shard_size == 0 and silently truncated the zip below.
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, history_shards, tgt_shards)
    for i, (src_shard, history_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        # BUG FIX: pass the current target shard; the original passed the
        # raw opt.tgt path while the zipped tgt_shard went unused.
        translator.translate(src=src_shard,
                             history=history_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug)
def get_onmt_opt(translation_model: Iterable[str],
                 src_file: Optional[str] = None,
                 output_file: Optional[str] = None,
                 n_best: int = 1,
                 log_probs: bool = False):
    """Build a validated OpenNMT translate option namespace.

    `src_file` / `output_file` fall back to '(unused)' placeholders when the
    caller feeds data programmatically instead of through files.
    """
    src = src_file if src_file is not None else '(unused)'
    output = output_file if output_file is not None else '(unused)'
    # BUG FIX: build the argv list directly instead of formatting a string
    # and calling .split(), which broke on paths containing spaces.
    args = ['--model', *translation_model, '--src', src, '--output', output]
    if log_probs:
        args.append('--log_probs')
    if n_best != 1:
        args.extend(['--n_best', str(n_best)])
    parser = onmt_parser()
    opt = parser.parse_args(args)
    ArgumentParser.validate_translate_opts(opt)
    return opt
def main(opt):
    """Shard-by-shard translation, writing predictions to opt.output."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True, logger=logger)
    # Redirect predictions to the requested output file.
    translator.out_file = codecs.open(opt.output, 'w+', 'utf-8')
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = (split_corpus(opt.tgt, opt.shard_size)
                  if opt.tgt is not None else repeat(None))
    for i, (src_shard, tgt_shard) in enumerate(zip(src_shards, tgt_shards)):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug,
                             opt=opt)
def main(opt):
    """Translate shards; graph-typed data carries two extra node streams."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = (split_corpus(opt.tgt, opt.shard_size)
                  if opt.tgt is not None else repeat(None))

    # Options identical for every shard.
    common = dict(src_dir=opt.src_dir,
                  batch_size=opt.batch_size,
                  attn_debug=opt.attn_debug)

    if opt.data_type == 'graph':
        node1_shards = split_corpus(opt.src_node1, opt.shard_size)
        node2_shards = split_corpus(opt.src_node2, opt.shard_size)
        quads = zip(src_shards, node1_shards, node2_shards, tgt_shards)
        for i, (src_shard, node1_shard, node2_shard, tgt_shard) in enumerate(quads):
            logger.info("Translating shard %d." % i)
            translator.translate(src=src_shard,
                                 tgt=tgt_shard,
                                 src_node1=node1_shard,
                                 src_node2=node2_shard,
                                 **common)
    else:
        for i, (src_shard, tgt_shard) in enumerate(zip(src_shards, tgt_shards)):
            logger.info("Translating shard %d." % i)
            translator.translate(src=src_shard, tgt=tgt_shard, **common)
def prepare_translators(langspecf):
    """(Re)build the global 1-best and 5-best translators when the language
    spec changes (or on first use)."""
    global translatorbest, translatorbigram, langspec

    with open(os.path.join(dir_path, 'opt_data'), 'rb') as f:
        opt = pickle.load(f)

    if not langspec or langspec != langspecf:
        model_path = os.path.join(dir_path, 'model', langspecf['model'])

        # Single best-hypothesis translator.
        opt.models = [model_path]
        opt.n_best = 1
        ArgumentParser.validate_translate_opts(opt)
        logger = init_logger(opt.log_file)
        translatorbest = build_translator(opt, report_score=True)

        # Short 5-best translator (max two tokens) for bigram suggestions.
        opt.models = [model_path]
        opt.n_best = 5
        opt.max_length = 2
        ArgumentParser.validate_translate_opts(opt)
        logger = init_logger(opt.log_file)
        translatorbigram = build_translator(opt, report_score=True)

        langspec = langspecf
def main(opt):
    """Translate multi-source data and return, per input, the candidate
    sequences paired with their log-probabilities.

    Input files are expected next to opt.src, named
    "<key>.<data_mode>.txt" for each key in src.<type>... plus tgt.
    """
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    abs_path = os.path.dirname(opt.src)
    src_mode = opt.data_mode

    candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

    # Pick the translator variant: "patype0" source types need the
    # AP-type-appended translator.
    if "patype0" in opt.src_types:
        translator = MultiSourceAPTypeAppendedTranslator.build_translator(
            opt.src_types, opt, report_score=True)
    else:
        translator = MultiSourceAPTranslator.build_translator(
            opt.src_types, opt, report_score=True)

    # One raw-data stream per source type, plus the target stream.
    raw_data_keys = ["src.{}".format(src_type)
                     for src_type in opt.src_types] + (["tgt"])
    raw_data_paths: Dict[str, str] = {
        k: "{0}/{1}.{2}.txt".format(abs_path, k, src_mode)
        for k in raw_data_keys
    }
    # Materialize all shards so each stream can be indexed by shard number.
    raw_data_shards: Dict[str, list] = {
        k: list(split_corpus(p, opt.shard_size))
        for k, p in raw_data_paths.items()
    }

    # All streams are sharded identically; iterate by the first one's length.
    for i in range(len(list(raw_data_shards.values())[0])):
        logger.info("Translating shard %d." % i)
        # The first two return values (scores, predictions) are unused here.
        _, _, candidates_logprobs_shard = translator.translate(
            {k: v[i] for k, v in raw_data_shards.items()},
            True,
            src_dir=None,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
        candidates_logprobs.extend(candidates_logprobs_shard)

    # Reformat candidates: join sub-tokens into a single string per candidate.
    candidates_logprobs: List[List[Tuple[str, float]]] = [[
        ("".join(c), l) for c, l in cl
    ] for cl in candidates_logprobs]

    return candidates_logprobs
def translate(opt):
    """Translate shards, emitting hidden states into opt.output when requested."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    # out_file=None: output is routed through out_dir below instead.
    # (Original note: changed from opt.output to None to work around a bug.)
    translator = build_translator(opt, report_score=True, out_file=None)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    for i, (src_shard, tgt_shard) in enumerate(zip(src_shards, tgt_shards)):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             batch_type=opt.batch_type,
                             attn_debug=opt.attn_debug,
                             align_debug=opt.align_debug,
                             generate_hidden_states=opt.gen_hidden_states,
                             out_dir=opt.output)
def translate(opt):
    """Translate shards and save the representation tensor the translator
    returns to opt.out_reps."""
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    for i, (src_shard, tgt_shard) in enumerate(zip(src_shards, tgt_shards)):
        logger.info("Translating shard %d." % i)
        reps_tensor = translator.translate(src=src_shard,
                                           tgt=tgt_shard,
                                           src_dir=opt.src_dir,
                                           batch_size=opt.batch_size,
                                           batch_type=opt.batch_type,
                                           attn_debug=opt.attn_debug,
                                           align_debug=opt.align_debug)
        # NOTE(review): saved inside the loop, so each shard overwrites
        # opt.out_reps — presumably only one shard is expected; verify.
        torch.save(reps_tensor, opt.out_reps)
        print(reps_tensor.size())
def parse_opt(self, opt):
    """Parse the option set passed by the user using `onmt.opts`

    Args:
        opt (dict): Options passed by the user

    Returns:
        opt (argparse.Namespace): full set of options for the Translator
    """
    parser = ArgumentParser()
    onmt.opts.translate_opts(parser)

    models = opt['models']
    if not isinstance(models, (list, tuple)):
        models = [models]
    opt['models'] = [os.path.join(self.model_root, model)
                     for model in models]
    opt['src'] = "dummy_src"

    # Build an argv list and hand it to the parser directly instead of
    # mutating (and later restoring) the global sys.argv as before.
    argv = []
    for (k, v) in opt.items():
        if k == 'models':
            argv += ['-model']
            argv += [str(model) for model in v]
        elif type(v) == bool:
            # BUG FIX: only pass store_true flags when the value is True;
            # the original appended the flag even for False, which still
            # switched the option on.
            if v:
                argv += ['-%s' % k]
        else:
            argv += ['-%s' % k, str(v)]

    opt = parser.parse_args(argv)
    ArgumentParser.validate_translate_opts(opt)
    opt.cuda = opt.gpu > -1
    return opt