def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, logger=logger, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)

    features_shards = []
    features_names = []
    for feat_name, feat_path in opt.src_feats.items():
        features_shards.append(split_corpus(feat_path, opt.shard_size))
        features_names.append(feat_name)
    shard_pairs = zip(src_shards, tgt_shards, *features_shards)

    for i, (src_shard, tgt_shard, *features_shard) in enumerate(shard_pairs):
        features_shard_ = defaultdict(list)
        for j, x in enumerate(features_shard):
            features_shard_[features_names[j]] = x
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            src_feats=features_shard_,
            tgt=tgt_shard,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug)
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    # MultimodalTranslator inherits from Translator and mainly overrides the translate method.
    # Some methods were not overridden but were copied verbatim into MultimodalTranslator;
    # they could instead delegate to the parent implementation via
    # super(MultimodalTranslator, self).method() (a sketch of this follows below).
    # To be cleaned up later.
    test_img_feats = np.load(opt.path_to_test_img_feats)
    test_img_feats = test_img_feats.astype(np.float32)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug,
            test_img_feats=test_img_feats,
            multimodal_model_type=opt.multimodal_model_type)
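# Hedged sketch (assumption, not part of the original code): the delegation the comment
# above suggests. Instead of copying unchanged methods into MultimodalTranslator, the
# subclass could keep only the multimodal bookkeeping and defer everything else to the
# parent class via super(). The extra keyword arguments here are illustrative placeholders.
from onmt.translate.translator import Translator


class MultimodalTranslator(Translator):
    def translate(self, src, tgt=None, test_img_feats=None,
                  multimodal_model_type=None, **kwargs):
        # store the multimodal inputs for use by the overridden batch logic ...
        self.test_img_feats = test_img_feats
        self.multimodal_model_type = multimodal_model_type
        # ... and reuse the parent implementation for the rest of the pipeline
        return super().translate(src=src, tgt=tgt, **kwargs)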
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, logger=logger, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    # wei 20200809
    nfr_tag_shards = split_corpus(opt.src_nfr_tag, opt.shard_size)
    flat_tag_shards = split_corpus(opt.src_flat_tag, opt.shard_size)
    # end wei
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    # shard_pairs = zip(src_shards, tgt_shards)  # 20200809 wei
    shard_pairs = zip(src_shards, nfr_tag_shards, flat_tag_shards, tgt_shards)

    # for i, (src_shard, tgt_shard) in enumerate(shard_pairs):  # 20200809
    for i, (src_shard, nfr_tag_shard, flat_tag_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            # wei 20200809
            nfr_tag=nfr_tag_shard,
            flat_tag=flat_tag_shard,
            # end wei
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug)
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)

    if opt.tree:
        tree_shards = split_corpus(opt.tree, opt.shard_size)
        shard_pairs = zip(src_shards, tgt_shards, tree_shards)
        for i, (src_shard, tgt_shard, tree_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                tree=tree_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                batch_type=opt.batch_type,
                attn_debug=opt.attn_debug,
                align_debug=opt.align_debug)
    else:
        shard_pairs = zip(src_shards, tgt_shards)
        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                tree=None,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                batch_type=opt.batch_type,
                attn_debug=opt.attn_debug,
                align_debug=opt.align_debug)
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        # src_shard is a list, e.g. len(src_shard) == 2507 and
        # src_shard[0].decode("utf-8") == 'आपकी कार में ब्लैक बॉक्स\n'
        logger.info("Translating shard %d." % i)
        # debugging output
        print("in translate")
        import os
        print("in translate.py pwd = ", os.getcwd())
        translator.translate(
            src=src_shard,   # e.g. src_shard[0] == 'आपकी कार में ब्लैक बॉक्स\n'
            tgt=tgt_shard,   # e.g. tgt_shard[0] == 'a black box in your car\n'
            tgt_path=opt.tgt,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
def shard_iterator(srcs, tgts, ids, aligns, existing_shards,
                   existing_fields, corpus_type, opt):
    """Builds a single iterator yielding every shard of every corpus."""
    for src, tgt, maybe_id, maybe_align in zip(srcs, tgts, ids, aligns):
        if maybe_id in existing_shards:
            if opt.overwrite:
                logger.warning(
                    "Overwrite shards for corpus {}".format(maybe_id))
            else:
                if corpus_type == "train":
                    assert existing_fields is not None, \
                        ("A 'vocab.pt' file should be passed to "
                         "`-src_vocab` when adding a corpus to "
                         "a set of already existing shards.")
                logger.warning("Ignore corpus {} because "
                               "shards already exist".format(maybe_id))
                continue
        if ((corpus_type == "train" or opt.filter_valid)
                and tgt is not None):
            filter_pred = partial(
                inputters.filter_example,
                use_src_len=opt.data_type == "text",
                max_src_len=opt.src_seq_length,
                max_tgt_len=opt.tgt_seq_length)
        else:
            filter_pred = None
        src_shards = split_corpus(src, opt.shard_size)
        tgt_shards = split_corpus(tgt, opt.shard_size)
        align_shards = split_corpus(maybe_align, opt.shard_size)
        for i, (ss, ts, a_s) in enumerate(
                zip(src_shards, tgt_shards, align_shards)):
            yield (i, (ss, ts, a_s, maybe_id, filter_pred))
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    # yida translate
    tag_src_shards = split_corpus(opt.tag_src, opt.shard_size) \
        if opt.tag_src is not None else repeat(None)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    # yida translate
    shard_pairs = zip(src_shards, tgt_shards, tag_src_shards)

    # yida translate
    for i, (src_shard, tgt_shard, tag_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            # yida translate
            tag_src=tag_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug)
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    lemma_shards = split_corpus(opt.lemma, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards, lemma_shards)

    lemma_align = open(opt.lemma_align, 'rb').readlines()
    if opt.gpu >= 0:
        topic_matrix = torch.load(opt.topic_matrix,
                                  map_location=torch.device(opt.gpu))
    else:
        topic_matrix = torch.load(opt.topic_matrix)
    if not opt.fp32:
        topic_matrix = topic_matrix.half()

    for i, (src_shard, tgt_shard, lemma_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            lemma_align=lemma_align,
            topic_matrix=topic_matrix,
            src=src_shard,
            tgt=tgt_shard,
            lemma=lemma_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
def eval_impl(self,
              processed_data_dir: Path,
              model_dir: Path,
              beam_search_size: int,
              k: int
              ) -> List[List[Tuple[str, float]]]:
    from roosterize.ml.onmt.CustomTranslator import CustomTranslator
    from onmt.utils.misc import split_corpus
    from onmt.utils.parse import ArgumentParser
    from translate import _get_parser as translate_get_parser

    src_path = processed_data_dir / "src.txt"
    tgt_path = processed_data_dir / "tgt.txt"

    best_step = IOUtils.load(model_dir / "best-step.json", IOUtils.Format.json)
    self.logger.info(f"Taking best step at {best_step}")

    candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

    with IOUtils.cd(self.open_nmt_path):
        parser = translate_get_parser()
        opt = parser.parse_args(
            f" -model {model_dir}/models/ckpt_step_{best_step}.pt"
            f" -src {src_path}"
            f" -tgt {tgt_path}"
        )
        opt.output = f"{model_dir}/last-pred.txt"
        opt.beam_size = beam_search_size
        opt.gpu = 0 if torch.cuda.is_available() else -1
        opt.n_best = k
        opt.block_ngram_repeat = 1
        opt.ignore_when_blocking = ["_"]

        # translate.main
        ArgumentParser.validate_translate_opts(opt)
        translator = CustomTranslator.build_translator(opt, report_score=False)
        src_shards = split_corpus(opt.src, opt.shard_size)
        tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
            if opt.tgt is not None else repeat(None)
        shard_pairs = zip(src_shards, tgt_shards)

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            self.logger.info("Translating shard %d." % i)
            _, _, candidates_logprobs_shard = translator.translate(
                src=src_shard, tgt=tgt_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug
            )
            candidates_logprobs.extend(candidates_logprobs_shard)
        # end for
    # end with

    # Reformat candidates
    candidates_logprobs: List[List[Tuple[str, float]]] = [
        [("".join(c), l) for c, l in cl] for cl in candidates_logprobs]

    return candidates_logprobs
def get_shard_pairs(src, tgt, shard_size):
    if not tgt or tgt == '-':
        # src is a single tab-separated file holding both sides of each pair
        for pair_shard in split_corpus(src, shard_size):
            pairs = [pair.split(b'\t') for pair in pair_shard]
            yield zip(*pairs)
    else:
        src_shards = split_corpus(src, shard_size)
        tgt_shards = split_corpus(tgt, shard_size)
        # a bare `return` value is discarded inside a generator, so yield the pairs instead
        yield from zip(src_shards, tgt_shards)
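# Hedged usage sketch (assumption, not from the original code): iterate over the shard
# pairs produced above and report their sizes. The file names "test.src"/"test.tgt" and
# the shard size of 10000 are placeholders.
def _count_shard_pairs():
    for i, (src_shard, tgt_shard) in enumerate(
            get_shard_pairs("test.src", "test.tgt", 10000)):
        print("shard %d: %d source / %d target lines"
              % (i, len(list(src_shard)), len(list(tgt_shard))))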
def build_save_dataset(corpus_type, fields, opt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src = opt.train_src
        tgt = opt.train_tgt
        ans = opt.train_ans
    else:
        src = opt.valid_src
        tgt = opt.valid_tgt
        ans = opt.valid_ans

    logger.info("Reading source answer and target files: %s %s %s."
                % (src, ans, tgt))

    src_shards = split_corpus(src, opt.shard_size)
    tgt_shards = split_corpus(tgt, opt.shard_size)
    ans_shards = split_corpus(ans, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards, ans_shards)
    dataset_paths = []

    for i, (src_shard, tgt_shard, ans_shard) in enumerate(shard_pairs):
        assert len(src_shard) == len(tgt_shard) == len(ans_shard)
        logger.info("Building shard %d." % i)
        dataset = inputters.build_dataset(
            fields, opt.data_type,
            src=src_shard,
            tgt=tgt_shard,
            ans=ans_shard,
            src_dir=opt.src_dir,
            src_seq_len=opt.src_seq_length,
            tgt_seq_len=opt.tgt_seq_length,
            ans_seq_len=opt.ans_seq_length,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            use_filter_pred=corpus_type == 'train' or opt.filter_valid)

        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s."
                    % (i, corpus_type, data_path))
        dataset.save(data_path)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
def build_save_dataset(corpus_type, fields, src_reader, history_reader,
                       ans_reader, tgt_reader, opt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src = opt.train_src
        history = opt.train_history
        ans = opt.train_ans
        tgt = opt.train_tgt
    else:
        src = opt.valid_src
        history = opt.valid_history
        ans = opt.valid_ans
        tgt = opt.valid_tgt

    logger.info("Reading source and target files: %s %s %s %s."
                % (src, history, ans, tgt))

    src_shards = split_corpus(src, opt.shard_size)
    history_shards = split_corpus(history, opt.shard_size)
    ans_shards = split_corpus(ans, opt.shard_size)
    tgt_shards = split_corpus(tgt, opt.shard_size)
    shard_pairs = zip(src_shards, history_shards, ans_shards, tgt_shards)
    dataset_paths = []
    if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
        filter_pred = partial(
            inputters.filter_example,
            use_src_len=opt.data_type == "text",
            use_history_len=False,
            max_src_len=opt.src_seq_length,
            max_history_len=-1,
            max_tgt_len=opt.tgt_seq_length)
    else:
        filter_pred = None
    logger.info("filter_pred is not used:{}".format(filter_pred))

    for i, (src_shard, history_shard, ans_shard, tgt_shard) in enumerate(shard_pairs):
        assert len(src_shard) == len(tgt_shard) \
            and len(src_shard) == len(history_shard) \
            and len(src_shard) == len(ans_shard)
        logger.info("Building shard %d." % i)
        dataset = inputters.Dataset(
            fields,
            readers=[src_reader, history_reader, ans_reader, tgt_reader]
            if tgt_reader else [src_reader, history_reader, ans_reader],
            data=([("src", src_shard), ("history", history_shard),
                   ("ans", ans_shard), ("tgt", tgt_shard)]
                  if tgt_reader else
                  [("src", src_shard), ("history", history_shard),
                   ("ans", ans_shard)]),
            dirs=[opt.src_dir, opt.src_dir, opt.src_dir, None]
            if tgt_reader else [opt.src_dir, opt.src_dir, opt.src_dir],
            sort_key=inputters.str2sortkey[opt.data_type],
            filter_pred=None
        )

        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s."
                    % (i, corpus_type, data_path))
        dataset.save(data_path)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt,
                       index, train_src, train_tgt, valid_src, valid_tgt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src = train_src
        tgt = train_tgt
    else:
        src = valid_src
        tgt = valid_tgt

    logger.info("Reading source and target files: %s %s." % (src, tgt))

    src_shards = split_corpus(src, opt.shard_size)
    tgt_shards = split_corpus(tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)
    dataset_paths = []
    if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
        filter_pred = partial(
            inputters.filter_example,
            use_src_len=opt.data_type == "text",
            max_src_len=opt.src_seq_length,
            max_tgt_len=opt.tgt_seq_length)
    else:
        filter_pred = None

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        assert i == 0, "only a single shard is allowed"
        assert len(src_shard) == len(tgt_shard)
        logger.info("Building shard %d." % i)
        dataset = inputters.Dataset(
            fields,
            readers=[src_reader, tgt_reader] if tgt_reader else [src_reader],
            data=([("src", src_shard), ("tgt", tgt_shard)]
                  if tgt_reader else [("src", src_shard)]),
            dirs=[opt.src_dir, None] if tgt_reader else [opt.src_dir],
            sort_key=inputters.str2sortkey[opt.data_type],
            filter_pred=filter_pred)

        # data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, index)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s."
                    % (i, corpus_type, data_path))
        dataset.save(data_path)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file, opt.shard_size,
                                  iter_func=constraint_iter_func, binary=False)

    translator = build_translator(opt, report_score=True, logger=logger)

    def create_src_shards(path, opt, binary=True):
        if opt.data_type == 'imgvec':
            assert opt.shard_size <= 0
            return [path]
        else:
            if opt.data_type == 'none':
                return [None] * 99999
            else:
                return split_corpus(path, opt.shard_size, binary=binary)

    src_shards = create_src_shards(opt.src, opt)
    if opt.agenda:
        agenda_shards = create_src_shards(opt.agenda, opt, False)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    if not opt.agenda:
        shards = zip(src_shards, tgt_shards)
    else:
        shards = zip(src_shards, agenda_shards, tgt_shards)

    for i, flat_shard in enumerate(shards):
        if not opt.agenda:
            src_shard, tgt_shard = flat_shard
            agenda_shard = None
        else:
            src_shard, agenda_shard, tgt_shard = flat_shard
        logger.info("Translating shard %d." % i)

        tag_shard = None
        if opt.constraint_file:
            tag_shard = next(tag_shards)

        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            agenda=agenda_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug,
            tag_shard=tag_shard)
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    if not os.path.exists(opt.output_dir):
        os.makedirs(opt.output_dir)
    if 'n_latent' not in vars(opt):
        vars(opt)['n_latent'] = vars(opt)['n_translate_latent']
    logger = init_logger(opt.log_file)

    if 'use_segments' not in vars(opt):
        vars(opt)['use_segments'] = opt.n_translate_segments != 0
        vars(opt)['max_segments'] = opt.n_translate_segments

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    n_latent = opt.n_latent
    if n_latent > 1:
        for latent_idx in range(n_latent):
            output_path = opt.output_dir + '/output_%d' % (latent_idx)
            out_file = codecs.open(output_path, 'w+', 'utf-8')
            translator.out_file = out_file
            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                logger.info("Translating shard %d." % i)
                translator.translate(
                    src=src_shard,
                    tgt=tgt_shard,
                    src_dir=opt.src_dir,
                    batch_size=opt.batch_size,
                    attn_debug=opt.attn_debug,
                    latent_idx=latent_idx)
            # the shard iterators are exhausted after each pass, so rebuild them
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
                if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)
    else:
        output_path = opt.output_dir + '/output'
        out_file = codecs.open(output_path, 'w+', 'utf-8')
        translator.out_file = out_file
        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug)
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    # shard_pairs = zip(src_shards, tgt_shards)
    # print("number of shards: ", len(src_shards), len(tgt_shards))

    # load emotions
    tgt_emotion_shards = [None] * 100
    if opt.target_emotions_path != "":
        print("Loading target emotions...")
        tgt_emotions = read_emotion_file(opt.target_emotions_path)
        tgt_emotion_shards = split_emotions(tgt_emotions, opt.shard_size)
        # print("number of shards: ", len(tgt_emotion_shards))

    tgt_concept_embedding_shards = [None] * 100
    if opt.target_concept_embedding != "":
        print("Loading target_concept_embedding...")
        tgt_concept_embedding = load_pickle(opt.target_concept_embedding)
        tgt_concept_embedding_shards = split_emotions(tgt_concept_embedding,
                                                      opt.shard_size)
        # print("number of shards: ", len(tgt_concept_embedding_shards))

    tgt_concept_words_shards = [None] * 100
    if opt.target_concept_words != "":
        print("Loading target_concept_words...")
        tgt_concept_words = load_pickle(opt.target_concept_words)
        # tgt_concept_words_shards = split_emotions(zip(tgt_concept_words), opt.shard_size)
        tgt_concept_words_shards = [tgt_concept_words]
        # print("number of shards: ", len(tgt_concept_words_shards))

    shard_pairs = zip(src_shards, tgt_shards, tgt_emotion_shards,
                      tgt_concept_embedding_shards, tgt_concept_words_shards)

    for i, (src_shard, tgt_shard, tgt_emotion_shard,
            tgt_concept_embedding_shard,
            tgt_concept_words_shard) in enumerate(shard_pairs):
        # print(len(src_shard), len(tgt_shard), len(tgt_emotion_shard))
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            tgt_emotion_shard=tgt_emotion_shard,
            rerank=opt.rerank,
            emotion_lexicon=opt.emotion_lexicon,
            tgt_concept_embedding_shard=tgt_concept_embedding_shard,
            tgt_concept_words_shard=tgt_concept_words_shard
        )
def main(opt):
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else [None] * opt.shard_size
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        # logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src = opt.train_src
        tgt = opt.train_tgt
    else:
        src = opt.valid_src
        tgt = opt.valid_tgt

    logger.info("Reading source and target files: %s %s." % (src, tgt))

    src_shards = split_corpus(src, opt.shard_size)
    tgt_shards = split_corpus(tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)
    dataset_paths = []
    if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
        filter_pred = partial(
            inputters.filter_example,
            use_src_len=opt.data_type == "text",
            max_src_len=opt.src_seq_length,
            max_tgt_len=opt.tgt_seq_length)
    else:
        filter_pred = None

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        assert len(src_shard) == len(tgt_shard)
        logger.info("Building shard %d." % i)
        dataset = inputters.Dataset(
            fields,
            readers=[src_reader, tgt_reader] if tgt_reader else [src_reader],
            data=([("src", src_shard), ("tgt", tgt_shard)]
                  if tgt_reader else [("src", src_shard)]),
            dirs=[opt.src_dir, None] if tgt_reader else [opt.src_dir],
            sort_key=inputters.str2sortkey[opt.data_type],
            filter_pred=filter_pred
        )

        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s."
                    % (i, corpus_type, data_path))
        dataset.save(data_path)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
def translate_file(input_filename, output_filename):
    parser = ArgumentParser(description='translation')
    opts.config_opts(parser)
    opts.translate_opts(parser)
    # print(opts)

    # NOTE: input_filename and output_filename are currently unused; the paths below are hardcoded.
    args = f'''-model m16_step_44000.pt
               -src source_products_16.txt
               -output op_16x_4400_50_10.txt
               -batch_size 128 -replace_unk -max_length 200
               -verbose -beam_size 50 -n_best 10 -min_length 5'''
    opt = parser.parse_args(args.split())  # parse_args expects a token list, not a raw string
    # print(opt.model)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        scores, predictions = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
    return scores, predictions
def translate_file(input_filename, output_filename):
    parser = ArgumentParser(description='translation')
    opts.config_opts(parser)
    opts.translate_opts(parser)

    args = f'''-model Experiments/Checkpoints/retrosynthesis_augmented_medium/retrosynthesis_aug_medium_model_step_100000.pt
               -src MCTS_data/{input_filename}.txt
               -output MCTS_data/{output_filename}.txt
               -batch_size 128 -replace_unk -max_length 200 -verbose
               -beam_size 10 -n_best 10 -min_length 5 -gpu 0'''
    opt = parser.parse_args(args.split())  # parse_args expects a token list, not a raw string
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        scores, predictions = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
    return scores, predictions
def summarize(self, texts):
    '''
    Arguments:
        texts: list(str)
    '''
    if self.extract:
        summaries = self.extractor.summarize(texts)
    else:
        summaries = texts

    with open('src.txt', "w", encoding='utf-8') as src_f:
        for summary in summaries:
            src_f.write(
                '[CLS] '
                + ' '.join(self.tokenizer.tokenize(strip_accents(summary)))
                + '\n')

    src_shards = split_corpus(self.opt.src, self.opt.shard_size)
    for i, src_shard in enumerate(src_shards):
        self.translator.translate(
            src=src_shard,
            batch_size=self.opt.batch_size,
            attn_debug=self.opt.attn_debug)

    output = []
    for line in fileinput.FileInput("output.out", inplace=1):
        line = line.replace(" ##", "").replace(" .", ".").replace(
            " ,", ",").replace(" !", "!").replace(" ?", "?").replace("\n", "")
        print(line)
        output.append(line)

    os.remove('src.txt')
    os.remove("output.out")
    return output
def main():
    parser = _get_parser()
    opt = parser.parse_args()

    fields, model, model_opts = load_test_model(opt)
    src_shards = split_corpus(opt.src, opt.shard_size)
    for src in src_shards:
        store_encoder_attn(model.encoder, src, {'src': fields['src']},
                           opt.batch_size, opt.gpu, opt.out_file)
def create_src_shards(src, opt):
    if opt.data_type == 'imgvec':
        assert opt.shard_size <= 0
        return [src]
    elif opt.data_type == 'none':
        return [None] * 99999
    else:
        return split_corpus(src, opt.shard_size)
def create_src_shards(path, opt, binary=True):
    if opt.data_type == 'imgvec':
        assert opt.shard_size <= 0
        return [path]
    else:
        if opt.data_type == 'none':
            return [None] * 99999
        else:
            return split_corpus(path, opt.shard_size, binary=binary)
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file, opt.shard_size,
                                  iter_func=constraint_iter_func, binary=False)

    with open("opt.pkl", 'wb') as f1:
        pickle.dump(opt, f1)
    with open("opt.pkl", 'rb') as f1:
        opt1 = pickle.load(f1)

    translator = build_translator(opt, report_score=True)
    if opt.data_type == 'imgvec':
        assert opt.shard_size <= 0
        src_shards = [opt.src]
    else:
        if opt.data_type == 'none':
            src_shards = [None] * 99999
        else:
            src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)

        tag_shard = None
        if opt.constraint_file:
            tag_shard = next(tag_shards)

        all_scores, all_predictions = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug,
            tag_shard=tag_shard)

        with open("result_{}.pickle".format(i), 'wb') as f1:
            pickle.dump(all_predictions, f1)
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    history_shards = split_corpus(opt.history, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else [None] * opt.shard_size
    shard_pairs = zip(src_shards, history_shards, tgt_shards)

    for i, (src_shard, history_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            history=history_shard,
            tgt=tgt_shard,  # pass the current shard rather than the opt.tgt path
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True, logger=logger)
    translator.out_file = codecs.open(opt.output, 'w+', 'utf-8')
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug,
            opt=opt)
def __call__(self, src_embed=None, hidden_state=None):
    translator = build_translator(self.opt, report_score=True)
    src_shards = split_corpus(self.opt.src, self.opt.shard_size)
    tgt_shards = split_corpus(self.opt.tgt, self.opt.shard_size)
    tgt2_shards = split_corpus(self.opt.tgt2, self.opt.shard_size)
    shard_trips = zip(src_shards, tgt_shards, tgt2_shards)

    for i, (src_shard, tgt_shard, tgt2_shard) in enumerate(shard_trips):
        # only the first shard is processed: the loop returns immediately
        return translator.translate_gold_diff(
            src=src_shard,
            tgt=tgt_shard,
            tgt2=tgt2_shard,
            src_dir=self.opt.src_dir,
            batch_size=self.opt.batch_size,
            batch_type=self.opt.batch_type,
            attn_debug=self.opt.attn_debug,
            align_debug=self.opt.align_debug,
            src_embed=src_embed,
            hidden_state=hidden_state)
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    if opt.data_type == 'graph':
        node1_shards = split_corpus(opt.src_node1, opt.shard_size)
        node2_shards = split_corpus(opt.src_node2, opt.shard_size)
        shard_pairs = zip(src_shards, node1_shards, node2_shards, tgt_shards)
        for i, (src_shard, node1_shard, node2_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                src_node1=node1_shard,
                src_node2=node2_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug
            )
    else:
        shard_pairs = zip(src_shards, tgt_shards)
        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug
            )
def nmt_filter_dataset(opt):
    opt.src = os.path.join(dataset_root_path, src_file)
    opt.tgt = os.path.join(dataset_root_path, tgt_file)
    opt.shard_size = 1
    opt.log_file = logging_file_path
    opt.models = [model_file_path]
    opt.n_best = 1
    opt.beam_size = 1
    opt.report_bleu = False
    opt.report_rouge = False

    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)

    src_file_path = os.path.join(dataset_root_path, src_file)
    tgt_file_path = os.path.join(dataset_root_path, tgt_file)
    src_shards = split_corpus(src_file_path, opt.shard_size)
    tgt_shards = split_corpus(tgt_file_path, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)

    pred_scores = []
    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        start_time = time.time()
        shard_pred_scores, shard_pred_sentences = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
        print("--- %s seconds ---" % (time.time() - start_time))
        pred_scores += [scores[0] for scores in shard_pred_scores]

    average_pred_score = torch.mean(torch.stack(pred_scores)).detach()
    return average_pred_score
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        reps_tensor = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug)
        torch.save(reps_tensor, opt.out_reps)
        print(reps_tensor.size())