def load_test_model(opt, model_path=None):
    if model_path is None:
        model_path = opt.models[0]
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)

    model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    vocab = checkpoint['vocab']
    if inputters.old_style_vocab(vocab):
        fields = inputters.load_old_vocab(
            vocab, opt.data_type, dynamic_dict=model_opt.copy_attn
        )
    else:
        fields = vocab

    model = build_base_model(model_opt, fields, use_gpu(opt), checkpoint,
                             opt.gpu)
    if opt.fp32:
        model.float()
    model.eval()
    model.generator.eval()
    return fields, model, model_opt
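# --- Hedged usage sketch (not from the original file) -----------------------
# Assuming the OpenNMT-py modules used by load_test_model are importable and a
# real checkpoint exists at the placeholder path below, the loader can be
# driven with a small options namespace; the attribute names mirror what the
# function reads (models, data_type, gpu, fp32).
from argparse import Namespace

demo_opt = Namespace(
    models=["demo_model_step_1000.pt"],  # placeholder checkpoint path
    data_type="text",
    gpu=-1,       # CPU
    fp32=False,
)
fields, model, model_opt = load_test_model(demo_opt)
print("loaded %d parameters" % sum(p.numel() for p in model.parameters()))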
def test_method(self):
    # `param_setting` and `methodname` are expected to come from the enclosing
    # scope (this method is typically generated inside a loop over test cases).
    opt = copy.deepcopy(self.opt)
    if param_setting:
        for param, setting in param_setting:
            setattr(opt, param, setting)
    ArgumentParser.update_model_opts(opt)
    getattr(self, methodname)(opt)
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        # src_shard is a list with one entry per source line, e.g.
        # len(src_shard) == 2507 and
        # src_shard[0].decode("utf-8") == 'आपकी कार में ब्लैक बॉक्स\n'
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,  # e.g. tgt_shard[0] == 'a black box in your car\n'
            tgt_path=opt.tgt,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
def translate_file(input_filename, output_filename):
    parser = ArgumentParser(description='translation')
    opts.config_opts(parser)
    opts.translate_opts(parser)
    args = f'''-model Experiments/Checkpoints/retrosynthesis_augmented_medium/retrosynthesis_aug_medium_model_step_100000.pt
               -src MCTS_data/{input_filename}.txt
               -output MCTS_data/{output_filename}.txt
               -batch_size 128 -replace_unk -max_length 200 -verbose
               -beam_size 10 -n_best 10 -min_length 5 -gpu 0'''
    opt = parser.parse_args(args)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        scores, predictions = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
    return scores, predictions
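# Hedged usage sketch (not from the original file): the file stems below are
# placeholders for files that would live under the MCTS_data/ layout
# hard-coded above.
scores, predictions = translate_file("round0_products", "round0_predictions")
# translator.translate returns one list of n_best candidates per source line,
# so predictions[0] holds the 10 best candidates for the first input.
print(len(predictions), "inputs translated; top candidate:", predictions[0][0])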
def main():
    parser = ArgumentParser()
    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.global_opts(parser)
    opt = parser.parse_args()

    with open(os.path.join(dir_path, 'opt_data'), 'wb') as f:
        pickle.dump(opt, f)
def eval_impl(self,
              processed_data_dir: Path,
              model_dir: Path,
              beam_search_size: int,
              k: int
              ) -> List[List[Tuple[str, float]]]:
    from roosterize.ml.onmt.CustomTranslator import CustomTranslator
    from onmt.utils.misc import split_corpus
    from onmt.utils.parse import ArgumentParser
    from translate import _get_parser as translate_get_parser

    src_path = processed_data_dir / "src.txt"
    tgt_path = processed_data_dir / "tgt.txt"

    best_step = IOUtils.load(model_dir / "best-step.json",
                             IOUtils.Format.json)
    self.logger.info(f"Taking best step at {best_step}")

    candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

    with IOUtils.cd(self.open_nmt_path):
        parser = translate_get_parser()
        opt = parser.parse_args(
            f" -model {model_dir}/models/ckpt_step_{best_step}.pt"
            f" -src {src_path}"
            f" -tgt {tgt_path}"
        )
        opt.output = f"{model_dir}/last-pred.txt"
        opt.beam_size = beam_search_size
        opt.gpu = 0 if torch.cuda.is_available() else -1
        opt.n_best = k
        opt.block_ngram_repeat = 1
        opt.ignore_when_blocking = ["_"]

        # translate.main
        ArgumentParser.validate_translate_opts(opt)
        translator = CustomTranslator.build_translator(opt, report_score=False)
        src_shards = split_corpus(opt.src, opt.shard_size)
        tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
            if opt.tgt is not None else repeat(None)
        shard_pairs = zip(src_shards, tgt_shards)

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            self.logger.info("Translating shard %d." % i)
            _, _, candidates_logprobs_shard = translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug
            )
            candidates_logprobs.extend(candidates_logprobs_shard)
        # end for
    # end with

    # Reformat candidates
    candidates_logprobs: List[List[Tuple[str, float]]] = [
        [("".join(c), l) for c, l in cl]
        for cl in candidates_logprobs
    ]

    return candidates_logprobs
def __init__(self):
    parser = ArgumentParser()
    opts.config_opts(parser)
    opts.translate_opts(parser)  # needed so model/beam options exist on self.opt
    self.opt = parser.parse_args()
    ArgumentParser.validate_translate_opts(self.opt)
    self.translator = build_translator(self.opt, report_score=True)
    self.mecab = MeCab.Tagger("-Owakati")
    self.mecab.parse("")
def evaluate_translation_on_datasets(opt):
    opt.log_file = logging_file_path
    opt.models = [model_file_path]
    opt.n_best = 1
    opt.beam_size = 5
    opt.report_bleu = False
    opt.report_rouge = False

    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)

    for dataset_file in dataset_files:
        if '_jp_spaced.txt' in dataset_file:
            src_file_path = os.path.join(dataset_root_path, dataset_file)
            tgt_file_path = os.path.join(
                dataset_root_path,
                dataset_file[:-len('_jp_spaced.txt')] + '_en_spaced.txt')

            num_lines = sum(1 for line in open(tgt_file_path))
            if num_lines > sentences_per_dataset_max:
                src_tmp_file_path = src_file_path[:-4] + '_tmp.txt'
                tgt_tmp_file_path = tgt_file_path[:-4] + '_tmp.txt'
                with open(src_file_path, 'r') as src_file, \
                        open(tgt_file_path, 'r') as tgt_file:
                    src_lines = src_file.read().splitlines()
                    tgt_lines = tgt_file.read().splitlines()
                    pairs = list(zip(src_lines, tgt_lines))
                    random.shuffle(pairs)
                    pairs = pairs[:sentences_per_dataset_max]
                    with open(src_tmp_file_path, 'w') as src_tmp_file, \
                            open(tgt_tmp_file_path, 'w') as tgt_tmp_file:
                        for pair in pairs:
                            src_tmp_file.write(pair[0] + '\n')
                            tgt_tmp_file.write(pair[1] + '\n')
                src_file_path = src_tmp_file_path
                tgt_file_path = tgt_tmp_file_path

            opt.src = src_file_path
            opt.tgt = tgt_file_path
            ArgumentParser.validate_translate_opts(opt)
            average_pred_score = evaluate_translation(
                translator, opt, src_file_path, tgt_file_path)
            logger.info('{}: {}'.format(dataset_file, average_pred_score))

            if num_lines > sentences_per_dataset_max:
                os.remove(src_tmp_file_path)
                os.remove(tgt_tmp_file_path)
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    logger.info("Translating {}".format(opt.data))
    translator.translate(data=opt.data,
                         batch_size=opt.batch_size,
                         attn_debug=opt.attn_debug)
def build_vocabulary(ds: Dataset):
    # Returns the parsed options and any unrecognized arguments.
    base_args = ([
        "-config", f"{path.join(ds.path, 'config.yaml')}",
        "-n_sample", "10000"
    ])
    parser = ArgumentParser(description='vocab.py')
    dynamic_prepare_opts(parser, build_vocab_only=True)
    options, unknown = parser.parse_known_args(base_args)
    build_vocab_main(options)
    return options, unknown
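# Hedged usage sketch (not from the original file): a minimal stand-in for the
# Dataset type, matching the single attribute build_vocabulary reads (ds.path);
# the directory is a placeholder and must contain a config.yaml.
from dataclasses import dataclass

@dataclass
class DemoDataset:
    path: str

options, unknown = build_vocabulary(DemoDataset(path="experiments/demo"))
print("vocab built from", options.config, "- unparsed args:", unknown)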
def load_pre_train(path):
    logger.info('Loading pre-train model from %s' % path)
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
    model_opt = opt
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
    fields = checkpoint['vocab']
    model = build_model(model_opt, opt, fields, checkpoint)
    return model
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file, opt.shard_size,
                                  iter_func=constraint_iter_func, binary=False)

    translator = build_translator(opt, report_score=True, logger=logger)

    def create_src_shards(path, opt, binary=True):
        if opt.data_type == 'imgvec':
            assert opt.shard_size <= 0
            return [path]
        else:
            if opt.data_type == 'none':
                return [None] * 99999
            else:
                return split_corpus(path, opt.shard_size, binary=binary)

    src_shards = create_src_shards(opt.src, opt)
    if opt.agenda:
        agenda_shards = create_src_shards(opt.agenda, opt, False)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    if not opt.agenda:
        shards = zip(src_shards, tgt_shards)
    else:
        shards = zip(src_shards, agenda_shards, tgt_shards)

    for i, flat_shard in enumerate(shards):
        if not opt.agenda:
            src_shard, tgt_shard = flat_shard
            agenda_shard = None
        else:
            src_shard, agenda_shard, tgt_shard = flat_shard
        logger.info("Translating shard %d." % i)

        tag_shard = None
        if opt.constraint_file:
            tag_shard = next(tag_shards)

        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            agenda=agenda_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug,
            tag_shard=tag_shard)
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    # Load emotions
    tgt_emotion_shards = [None] * 100
    if opt.target_emotions_path != "":
        print("Loading target emotions...")
        tgt_emotions = read_emotion_file(opt.target_emotions_path)
        tgt_emotion_shards = split_emotions(tgt_emotions, opt.shard_size)

    tgt_concept_embedding_shards = [None] * 100
    if opt.target_concept_embedding != "":
        print("Loading target_concept_embedding...")
        tgt_concept_embedding = load_pickle(opt.target_concept_embedding)
        tgt_concept_embedding_shards = split_emotions(tgt_concept_embedding,
                                                      opt.shard_size)

    tgt_concept_words_shards = [None] * 100
    if opt.target_concept_words != "":
        print("Loading target_concept_words...")
        tgt_concept_words = load_pickle(opt.target_concept_words)
        tgt_concept_words_shards = [tgt_concept_words]

    shard_pairs = zip(src_shards, tgt_shards, tgt_emotion_shards,
                      tgt_concept_embedding_shards, tgt_concept_words_shards)

    for i, (src_shard, tgt_shard, tgt_emotion_shard,
            tgt_concept_embedding_shard,
            tgt_concept_words_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            tgt_emotion_shard=tgt_emotion_shard,
            rerank=opt.rerank,
            emotion_lexicon=opt.emotion_lexicon,
            tgt_concept_embedding_shard=tgt_concept_embedding_shard,
            tgt_concept_words_shard=tgt_concept_words_shard
        )
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src[0]) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt[0])  # tgt always text so far
    if len(opt.train_src) > 1 and opt.data_type == 'text':
        for src, tgt in zip(opt.train_src[1:], opt.train_tgt[1:]):
            assert src_nfeats == count_features(src), \
                "%s seems to mismatch features of " \
                "the other source datasets" % src
            assert tgt_nfeats == count_features(tgt), \
                "%s seems to mismatch features of " \
                "the other target datasets" % tgt
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    if opt.disable_eos_sampling:
        eos_token = "<blank>"
        logger.info("Using NO eos token")
    else:
        eos_token = "</s>"
        logger.info("Using standard eos token")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        with_align=opt.train_align[0] is not None,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc,
        eos=eos_token)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tgt_reader,
                       align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader,
                           align_reader, opt)
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    if not os.path.exists(opt.output_dir):
        os.makedirs(opt.output_dir)
    if 'n_latent' not in vars(opt):
        vars(opt)['n_latent'] = vars(opt)['n_translate_latent']
    logger = init_logger(opt.log_file)

    if 'use_segments' not in vars(opt):
        vars(opt)['use_segments'] = opt.n_translate_segments != 0
        vars(opt)['max_segments'] = opt.n_translate_segments

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    n_latent = opt.n_latent
    if n_latent > 1:
        for latent_idx in range(n_latent):
            output_path = opt.output_dir + '/output_%d' % (latent_idx)
            out_file = codecs.open(output_path, 'w+', 'utf-8')
            translator.out_file = out_file
            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                logger.info("Translating shard %d." % i)
                translator.translate(
                    src=src_shard,
                    tgt=tgt_shard,
                    src_dir=opt.src_dir,
                    batch_size=opt.batch_size,
                    attn_debug=opt.attn_debug,
                    latent_idx=latent_idx)
            # Re-create the shard iterators consumed by the pass above
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
                if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)
    else:
        output_path = opt.output_dir + '/output'
        out_file = codecs.open(output_path, 'w+', 'utf-8')
        translator.out_file = out_file
        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug)
def load_model(self, path=None):
    if path is None or not os.path.exists(
            os.path.abspath(os.path.join(os.getcwd(), path))):
        print("No model present at the specified path: {}".format(path))
        return
    parser = ArgumentParser(description='translate.py')
    opts.config_opts(parser)
    opts.translate_opts(parser)
    opt = parser.parse_args(["--model", path])
    self.model = build_translator(opt, report_score=True)
    return
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    # check_existing_pt_files(opt)

    init_logger(opt.log_file)
    # if not os.path.exists(os.path.dirname(opt.save_data)):
    #     os.makedirs(os.path.dirname(opt.save_data))
    #     logger.info("Creating dirs..." + os.path.dirname(opt.save_data))

    logger.info("Extracting features...")
    fields = get_fields(opt)

    logger.info("Building & saving training data...")
    build_save_dataset(opt, fields)
def _get_parser():
    parser = ArgumentParser(description='translate.py')

    opts.config_opts(parser)
    opts.translate_opts(parser)
    parser.add(
        '--model', '-model', dest='models', metavar='MODEL',
        nargs='+', type=str,
        default=[
            "F:/Project/Python/selfProject/translate_NMT/transflate_NMT/data_step_100.pt"
        ],
        required=False,
        help="Trained model file(s) to use")
    parser.add(
        '--src', '-src', required=False,
        default="F:/Project/Python/selfProject/translate_NMT/transflate_NMT/data/src-test.txt",
        help="Path to the source test file to translate")
    parser.add(
        '--output', '-output',
        default='F:/Project/Python/selfProject/translate_NMT/transflate_NMT/data/pred.txt',
        help="Path to write the predictions to; change to your own")
    return parser
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    ArgumentParser._get_all_transform_translate(opt)
    ArgumentParser._validate_transforms_opts(opt)
    ArgumentParser.validate_translate_opts_dynamic(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)

    data_reader = InferenceDataReader(opt.src, opt.tgt, opt.src_feats)

    # Build transforms
    transforms_cls = get_transforms_cls(opt._all_transform)
    transforms = make_transforms(opt, transforms_cls, translator.fields)
    data_transform = [
        transforms[name] for name in opt.transforms if name in transforms
    ]
    transform = TransformPipe.build_from(data_transform)

    for i, (src_shard, tgt_shard, feats_shard) in enumerate(data_reader):
        logger.info("Translating shard %d." % i)
        translator.translate_dynamic(
            src=src_shard,
            transform=transform,
            src_feats=feats_shard,
            tgt=tgt_shard,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug)
def get_default_opts():
    parser = ArgumentParser(description='data sample prepare')
    dynamic_prepare_opts(parser)

    default_opts = [
        '-config', 'data/data.yaml',
        '-src_vocab', 'data/vocab-train.src',
        '-tgt_vocab', 'data/vocab-train.tgt'
    ]

    opt = parser.parse_known_args(default_opts)[0]
    # Inject some dummy training options that may be needed when building fields
    opt.copy_attn = False
    ArgumentParser.validate_prepare_opts(opt)
    return opt
def _get_model_opts(opt, checkpoint=None):
    """Get `model_opt` to build model, may load from `checkpoint` if any."""
    if checkpoint is not None:
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        if (opt.tensorboard_log_dir == model_opt.tensorboard_log_dir and
                hasattr(model_opt, 'tensorboard_log_dir_dated')):
            # ensure tensorboard output is written in the directory
            # of previous checkpoints
            opt.tensorboard_log_dir_dated = model_opt.tensorboard_log_dir_dated
    else:
        model_opt = opt
    return model_opt
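# Hedged sketch of the two paths through _get_model_opts (not from the
# original file); the checkpoint filename is a placeholder.
import torch
from argparse import Namespace

train_opt = Namespace(tensorboard_log_dir="runs/exp1")  # illustrative only

# Fresh training: the current options double as the model options.
assert _get_model_opts(train_opt) is train_opt

# Resuming: model options are restored from the checkpoint's stored `opt`.
checkpoint = torch.load("demo_step_5000.pt",
                        map_location=lambda storage, loc: storage)
model_opt = _get_model_opts(train_opt, checkpoint=checkpoint)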
def build_vocab_main(opts):
    """Apply transforms to samples of specified data and build vocab from it.

    Transforms that need vocab will be disabled in this.
    Built vocab is saved in plain text format as following and can be passed
    as `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """

    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert opts.n_sample == -1 or opts.n_sample > 1, \
        f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)
    transforms_cls = get_transforms_cls(opts._all_transform)
    fields = None

    transforms = make_transforms(opts, transforms_cls, fields)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter, src_feats_counter = build_vocab(
        opts, transforms, n_sample=opts.n_sample)

    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
    for feat_name, feat_counter in src_feats_counter.items():
        logger.info(f"Counters {feat_name}:{len(feat_counter)}")

    def save_counter(counter, save_path):
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
        with open(save_path, "w", encoding="utf8") as fo:
            for tok, count in counter.most_common():
                fo.write(tok + "\t" + str(count) + "\n")

    if opts.share_vocab:
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share:{len(src_counter)}")
        save_counter(src_counter, opts.src_vocab)
    else:
        save_counter(src_counter, opts.src_vocab)
        save_counter(tgt_counter, opts.tgt_vocab)

    for k, v in src_feats_counter.items():
        save_counter(v, opts.src_feats_vocab[k])
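# The docstring above fixes the on-disk vocab format (one "<tok>\t<count>" per
# line), so a small reader for such a file is easy to sketch; the example path
# is a placeholder, not from the original file.
from collections import Counter

def read_plain_vocab(path):
    """Read a vocab file written by save_counter back into a Counter."""
    counter = Counter()
    with open(path, encoding="utf8") as fi:
        for line in fi:
            tok, count = line.rstrip("\n").split("\t")
            counter[tok] = int(count)
    return counter

# e.g. read_plain_vocab("data/vocab-train.src")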
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    # The extra 'tt' side uses the same number of features as the target side.
    tt_nfeats = tgt_nfeats

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tt_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        with_align=opt.train_align[0] is not None,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tt_reader = inputters.str2reader["text"].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tt_reader, tgt_reader,
                       align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tt_reader, tgt_reader,
                           align_reader, opt)
def __init__(self):
    # Load the model based on the options given on the command line
    parser = ArgumentParser()
    opts.config_opts(parser)
    opts.translate_opts(parser)
    self.opt = parser.parse_args()
    ArgumentParser.validate_translate_opts(self.opt)
    self.translator = build_translator(self.opt, report_score=True)

    # Use MeCab for word segmentation
    self.mecab = MeCab.Tagger("-Owakati")
    self.mecab.parse("")

    # Dictionary that stores the previous response per user
    self.prev_uttr_dict = {}
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    confnet_nfeats = 0
    for src, tgt, cnet in zip(opt.train_src, opt.train_tgt,
                              opt.train_confnet):
        src_nfeats += count_features(src) \
            if opt.data_type == 'text' or opt.data_type == 'lattice' else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
        # confnet_nfeats += count_features(cnet) if opt.data_type == 'lattice' \
        #     else 0
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    logger.info(" * number of confnet features: %d." % confnet_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        confnet_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        with_align=opt.train_align[0] is not None,
        ans_truncate=opt.src_seq_length_trunc,
        ques_truncate=opt.confnet_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    ans_reader = inputters.str2reader["text"].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)
    ques_reader = inputters.str2reader["lattice"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, ques_reader, ans_reader, tgt_reader,
                       align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, ques_reader, ans_reader,
                           tgt_reader, align_reader, opt)
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    if len(opt.src_vocab) > 0:
        assert len(opt.src_vocab) == len(opt.train_src), \
            "you should provide src vocab for each dataset " \
            "if you want to use your own vocab"

    for i, (train_src, train_tgt) in enumerate(zip(opt.train_src,
                                                   opt.train_tgt)):
        valid_src = opt.valid_src[i]
        valid_tgt = opt.valid_tgt[i]
        logger.info("Working on %d dataset..." % i)

        logger.info("Building `Fields` object...")
        fields = inputters.get_fields(
            opt.data_type,
            src_nfeats,
            tgt_nfeats,
            dynamic_dict=opt.dynamic_dict,
            src_truncate=opt.src_seq_length_trunc,
            tgt_truncate=opt.tgt_seq_length_trunc)

        src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
        tgt_reader = inputters.str2reader["text"].from_opt(opt)

        logger.info("Building & saving training data...")
        train_dataset_files = build_save_dataset(
            'train', fields, src_reader, tgt_reader, opt, i,
            train_src, train_tgt, valid_src, valid_tgt)

        if opt.valid_src and opt.valid_tgt:
            logger.info("Building & saving validation data...")
            build_save_dataset('valid', fields, src_reader, tgt_reader, opt,
                               i, train_src, train_tgt, valid_src, valid_tgt)

        logger.info("Building & saving vocabulary...")
        build_save_vocab(train_dataset_files, fields, opt, i)
def _get_parser():
    parser = ArgumentParser(description='run_kp_eval.py')

    opts.config_opts(parser)
    opts.translate_opts(parser)
    return parser
def __init__(self):
    # Standard OpenNMT option setup
    parser = ArgumentParser()
    opts.config_opts(parser)
    opts.translate_opts(parser)
    self.opt = parser.parse_args(args=[
        "-model", "../models/model.pt",
        "-src", "None",
        "-replace_unk",
        "--beam_size", "10",
        "--min_length", "7",
        "--block_ngram_repeat", "2"
    ])
    ArgumentParser.validate_translate_opts(self.opt)
    self.translator = build_translator(self.opt, report_score=True)

    # Use MeCab for word segmentation
    self.mecab = MeCab.Tagger("-Owakati")
    self.mecab.parse("")
def _get_parser():
    parser = ArgumentParser(description='translate.py')

    opts.config_opts(parser)
    opts.translate_opts(parser)
    opts.mmod_finetune_translate_opts(parser)
    return parser
def _get_parser():
    parser = ArgumentParser(description='build_copy_transformer.py')

    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    return parser
def _get_parser():
    parser = ArgumentParser(description='train.py')

    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    return parser
def parse_opt(self, opt):
    """Parse the option set passed by the user using `onmt.opts`

    Args:
        opt (dict): Options passed by the user

    Returns:
        opt (argparse.Namespace): full set of options for the Translator
    """

    prec_argv = sys.argv
    sys.argv = sys.argv[:1]
    parser = ArgumentParser()
    onmt.opts.translate_opts(parser)

    models = opt['models']
    if not isinstance(models, (list, tuple)):
        models = [models]
    opt['models'] = [os.path.join(self.model_root, model)
                     for model in models]
    opt['src'] = "dummy_src"

    for (k, v) in opt.items():
        if k == 'models':
            sys.argv += ['-model']
            sys.argv += [str(model) for model in v]
        elif type(v) == bool:
            sys.argv += ['-%s' % k]
        else:
            sys.argv += ['-%s' % k, str(v)]

    opt = parser.parse_args()
    ArgumentParser.validate_translate_opts(opt)
    opt.cuda = opt.gpu > -1

    sys.argv = prec_argv
    return opt
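# Hedged example (not from the original file) of the dict-style options that
# parse_opt converts into CLI arguments: booleans become bare flags, other
# values become '-key value', and 'models' expands to '-model ...'. The model
# name is a placeholder resolved under self.model_root.
user_conf = {
    "models": "demo_step_10000.pt",
    "beam_size": 5,
    "n_best": 1,
    "replace_unk": True,   # emitted as the bare flag '-replace_unk'
    "gpu": -1,
}
# opt = server_model.parse_opt(user_conf)  # hypothetical ServerModel instance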
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset(
        'train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
import copy
import unittest
import math

import torch

import onmt
import onmt.inputters
import onmt.opts
from onmt.model_builder import build_embeddings, \
    build_encoder, build_decoder
from onmt.encoders.image_encoder import ImageEncoder
from onmt.encoders.audio_encoder import AudioEncoder
from onmt.utils.parse import ArgumentParser

parser = ArgumentParser(description='train.py')
onmt.opts.model_opts(parser)
onmt.opts.train_opts(parser)

# -data option is required, but not used in this test, so dummy.
opt = parser.parse_known_args(['-data', 'dummy'])[0]


class TestModel(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        super(TestModel, self).__init__(*args, **kwargs)
        self.opt = opt

    def get_field(self):
        src = onmt.inputters.get_fields("text", 0, 0)["src"]
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()