def load_test_model(opt, model_path=None):
    if model_path is None:
        model_path = opt.models[0]
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)

    model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    vocab = checkpoint['vocab']
    if inputters.old_style_vocab(vocab):
        fields = inputters.load_old_vocab(
            vocab, opt.data_type, dynamic_dict=model_opt.copy_attn
        )
    else:
        fields = vocab

    model = build_base_model(model_opt, fields, use_gpu(opt), checkpoint,
                             opt.gpu)
    if opt.fp32:
        model.float()

    model.eval()
    model.generator.eval()
    return fields, model, model_opt
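A minimal usage sketch for this loader (not part of the original example): the field names follow OpenNMT's translate options and the checkpoint path is a placeholder.

from argparse import Namespace

test_opt = Namespace(models=["models/demo_step_10000.pt"],  # placeholder checkpoint
                     gpu=-1,        # run on CPU
                     fp32=True,     # cast the model back to full precision
                     data_type="text")
fields, model, model_opt = load_test_model(test_opt)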
Example #2
def test_method(self):
    # `param_setting` and `methodname` are free variables captured from the
    # enclosing test-registration helper (see the sketch below).
    opt = copy.deepcopy(self.opt)
    if param_setting:
        for param, setting in param_setting:
            setattr(opt, param, setting)
    ArgumentParser.update_model_opts(opt)
    getattr(self, methodname)(opt)
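`param_setting` and `methodname` are free variables here; in OpenNMT-py's test suite this snippet sits inside a test-registration helper roughly like the sketch below (reconstructed from the closure variables, not taken from this listing).

def _add_test(param_setting, methodname):
    # Attach a dynamically named test method to the TestModel case defined later.
    def test_method(self):
        opt = copy.deepcopy(self.opt)
        if param_setting:
            for param, setting in param_setting:
                setattr(opt, param, setting)
        ArgumentParser.update_model_opts(opt)
        getattr(self, methodname)(opt)

    if param_setting:
        name = "test_" + methodname + "_" + "_".join(
            str(p) + "_" + str(s) for p, s in param_setting)
    else:
        name = "test_" + methodname
    setattr(TestModel, name, test_method)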
Example #3
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        # src_shard is a list of raw source lines; in the original run,
        # len(src_shard) == 2507 and
        # src_shard[0].decode("utf-8") == 'आपकी कार में ब्लैक बॉक्स\n'
        logger.info("Translating shard %d." % i)
        print("in translate")
        import os
        print("in translate.py pwd = ", os.getcwd())
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,  # e.g. tgt_shard[0] == 'a black box in your car\n'
            tgt_path=opt.tgt,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
def translate_file(input_filename, output_filename):
    parser = ArgumentParser(description='translation')
    opts.config_opts(parser)
    opts.translate_opts(parser)

    args = f'''-model Experiments/Checkpoints/retrosynthesis_augmented_medium/retrosynthesis_aug_medium_model_step_100000.pt 
                -src MCTS_data/{input_filename}.txt 
                -output MCTS_data/{output_filename}.txt 
                -batch_size 128 
                -replace_unk
                -max_length 200 
                -verbose 
                -beam_size 10 
                -n_best 10 
                -min_length 5 
                -gpu 0'''

    # parse_args expects a list of tokens, so split the option string first
    opt = parser.parse_args(args.split())
    translator = build_translator(opt, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        scores, predictions = translator.translate(src=src_shard,
                                                   tgt=tgt_shard,
                                                   src_dir=opt.src_dir,
                                                   batch_size=opt.batch_size,
                                                   attn_debug=opt.attn_debug)

    return scores, predictions
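A hedged usage sketch: the file stems below are placeholders, resolved against the MCTS_data/ directory hardcoded inside translate_file.

scores, predictions = translate_file("round_1_input", "round_1_output")
print(predictions[0][:3])  # top n_best predictions for the first source line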
Example #5
def main():
    parser = ArgumentParser()
    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.global_opts(parser)
    opt = parser.parse_args()
    with open(os.path.join(dir_path, 'opt_data'), 'wb') as f:
        pickle.dump(opt, f)
    def eval_impl(self,
            processed_data_dir: Path,
            model_dir: Path,
            beam_search_size: int,
            k: int
    ) -> List[List[Tuple[str, float]]]:
        from roosterize.ml.onmt.CustomTranslator import CustomTranslator
        from onmt.utils.misc import split_corpus
        from onmt.utils.parse import ArgumentParser
        from translate import _get_parser as translate_get_parser

        src_path = processed_data_dir/"src.txt"
        tgt_path = processed_data_dir/"tgt.txt"

        best_step = IOUtils.load(model_dir/"best-step.json", IOUtils.Format.json)
        self.logger.info(f"Taking best step at {best_step}")

        candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

        with IOUtils.cd(self.open_nmt_path):
            parser = translate_get_parser()
            # parse_args expects a list of tokens, so split the option string
            opt = parser.parse_args(
                (f" -model {model_dir}/models/ckpt_step_{best_step}.pt"
                 f" -src {src_path}"
                 f" -tgt {tgt_path}").split()
            )
            opt.output = f"{model_dir}/last-pred.txt"
            opt.beam_size = beam_search_size
            opt.gpu = 0 if torch.cuda.is_available() else -1
            opt.n_best = k
            opt.block_ngram_repeat = 1
            opt.ignore_when_blocking = ["_"]

            # translate.main
            ArgumentParser.validate_translate_opts(opt)

            translator = CustomTranslator.build_translator(opt, report_score=False)
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)

            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                self.logger.info("Translating shard %d." % i)
                _, _, candidates_logprobs_shard = translator.translate(
                    src=src_shard,
                    tgt=tgt_shard,
                    src_dir=opt.src_dir,
                    batch_size=opt.batch_size,
                    attn_debug=opt.attn_debug
                )
                candidates_logprobs.extend(candidates_logprobs_shard)
            # end for
        # end with

        # Reformat candidates
        candidates_logprobs: List[List[Tuple[str, float]]] = [[("".join(c), l) for c, l in cl] for cl in candidates_logprobs]

        return candidates_logprobs
Example #7
    def __init__(self):
        parser = ArgumentParser()
        opts.config_opts(parser)
        opts.translate_opts(parser)  # translate options must be registered before parsing
        self.opt = parser.parse_args()
        ArgumentParser.validate_translate_opts(self.opt)
        self.translator = build_translator(self.opt, report_score=True)

        self.mecab = MeCab.Tagger("-Owakati")
        self.mecab.parse("")
def evaluate_translation_on_datasets(opt):
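    # Assumed to be defined at module level in the original script:
    # logging_file_path, model_file_path, dataset_files, dataset_root_path,
    # sentences_per_dataset_max and evaluate_translation.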

    opt.log_file = logging_file_path
    opt.models = [model_file_path]
    opt.n_best = 1
    opt.beam_size = 5
    opt.report_bleu = False
    opt.report_rouge = False

    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)

    for dataset_file in dataset_files:

        if '_jp_spaced.txt' in dataset_file:

            src_file_path = os.path.join(dataset_root_path, dataset_file)
            tgt_file_path = os.path.join(dataset_root_path, dataset_file[:-len('_jp_spaced.txt')] + '_en_spaced.txt')

            num_lines = sum(1 for line in open(tgt_file_path))

            if num_lines > sentences_per_dataset_max:

                src_tmp_file_path = src_file_path[:-4] + '_tmp.txt'
                tgt_tmp_file_path = tgt_file_path[:-4] + '_tmp.txt'

                with open(src_file_path, 'r') as src_file, open(tgt_file_path, 'r') as tgt_file:

                    src_lines = src_file.read().splitlines()
                    tgt_lines = tgt_file.read().splitlines()

                    pairs = list(zip(src_lines, tgt_lines))
                    random.shuffle(pairs)

                    pairs = pairs[:sentences_per_dataset_max]

                    with open(src_tmp_file_path, 'w') as src_tmp_file, open(tgt_tmp_file_path, 'w') as tgt_tmp_file:
                        for pair in pairs:
                            src_tmp_file.write(pair[0]+'\n')
                            tgt_tmp_file.write(pair[1]+'\n')

                src_file_path = src_tmp_file_path
                tgt_file_path = tgt_tmp_file_path

            opt.src = src_file_path
            opt.tgt = tgt_file_path

            ArgumentParser.validate_translate_opts(opt)

            average_pred_score = evaluate_translation(translator, opt, src_file_path, tgt_file_path)

            logger.info('{}: {}'.format(dataset_file, average_pred_score))

            if num_lines > sentences_per_dataset_max:
                os.remove(src_tmp_file_path)
                os.remove(tgt_tmp_file_path)
Example #9
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)

    logger.info("Translating{}".format(opt.data))
    translator.translate(data=opt.data,
                         batch_size=opt.batch_size,
                         attn_debug=opt.attn_debug)
Example #10
def build_vocabulary(ds: Dataset):
    base_args = [
        "-config", f"{path.join(ds.path, 'config.yaml')}", "-n_sample", "10000"
    ]

    parser = ArgumentParser(description='vocab.py')
    dynamic_prepare_opts(parser, build_vocab_only=True)

    options, unknown = parser.parse_known_args(base_args)
    build_vocab_main(options)

    return options, unknown
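A minimal usage sketch, assuming a hypothetical stand-in for the Dataset type whose path attribute points at a directory containing config.yaml.

from dataclasses import dataclass

@dataclass
class Dataset:
    path: str  # directory holding config.yaml (placeholder)

options, unknown = build_vocabulary(Dataset(path="data/my_corpus"))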
Example #11
def load_pre_train(path):
    logger.info('Loading pre-train model from %s' % path)
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)

    opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
    model_opt = opt
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    logger.info('Loading vocab from checkpoint at %s.' % path)
    fields = checkpoint['vocab']
    model = build_model(model_opt, opt, fields, checkpoint)
    return model
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file,
                                  opt.shard_size,
                                  iter_func=constraint_iter_func,
                                  binary=False)

    translator = build_translator(opt, report_score=True, logger=logger)

    def create_src_shards(path, opt, binary=True):
        if opt.data_type == 'imgvec':
            assert opt.shard_size <= 0
            return [path]
        else:
            if opt.data_type == 'none':
                return [None] * 99999
            else:
                return split_corpus(path, opt.shard_size, binary=binary)

    src_shards = create_src_shards(opt.src, opt)
    if opt.agenda:
        agenda_shards = create_src_shards(opt.agenda, opt, False)

    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    if not opt.agenda:
        shards = zip(src_shards, tgt_shards)
    else:
        shards = zip(src_shards, agenda_shards, tgt_shards)

    for i, flat_shard in enumerate(shards):
        if not opt.agenda:
            src_shard, tgt_shard = flat_shard
            agenda_shard = None
        else:
            src_shard, agenda_shard, tgt_shard = flat_shard
        logger.info("Translating shard %d." % i)

        tag_shard = None
        if opt.constraint_file:
            tag_shard = next(tag_shards)

        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             agenda=agenda_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug,
                             tag_shard=tag_shard)
Example #13
def translate(opt): 
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    # shard_pairs = zip(src_shards, tgt_shards)
    # print("number of shards: ", len(src_shards), len(tgt_shards))

    # load emotions
    tgt_emotion_shards = [None]*100
    if opt.target_emotions_path != "":
        print("Loading target emotions...")
        tgt_emotions = read_emotion_file(opt.target_emotions_path)
        tgt_emotion_shards = split_emotions(tgt_emotions, opt.shard_size)
        # print("number of shards: ", len(tgt_emotion_shards))
    
    tgt_concept_embedding_shards = [None]*100
    if opt.target_concept_embedding != "":
        print("Loading target_concept_embedding...")
        tgt_concept_embedding = load_pickle(opt.target_concept_embedding)
        tgt_concept_embedding_shards = split_emotions(tgt_concept_embedding, opt.shard_size)
        # print("number of shards: ", len(tgt_concept_embedding_shards))
    
    tgt_concept_words_shards = [None]*100
    if opt.target_concept_words != "":
        print("Loading target_concept_words...")
        tgt_concept_words = load_pickle(opt.target_concept_words)
        # tgt_concept_words_shards = split_emotions(zip(tgt_concept_words), opt.shard_size)
        tgt_concept_words_shards = [tgt_concept_words]
        # print("number of shards: ", len(tgt_concept_words_shards))
    
    shard_pairs = zip(src_shards, tgt_shards, tgt_emotion_shards, tgt_concept_embedding_shards, tgt_concept_words_shards)

    for i, (src_shard, tgt_shard, tgt_emotion_shard, tgt_concept_embedding_shard, tgt_concept_words_shard) in enumerate(shard_pairs):
        # print(len(src_shard), len(tgt_shard), len(tgt_emotion_shard))
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            tgt_emotion_shard=tgt_emotion_shard,
            rerank=opt.rerank,
            emotion_lexicon=opt.emotion_lexicon,
            tgt_concept_embedding_shard=tgt_concept_embedding_shard,
            tgt_concept_words_shard=tgt_concept_words_shard
            )
Example #14
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    src_nfeats = count_features(opt.train_src[0]) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt[0])  # tgt always text so far
    if len(opt.train_src) > 1 and opt.data_type == 'text':
        for src, tgt in zip(opt.train_src[1:], opt.train_tgt[1:]):
            assert src_nfeats == count_features(src),\
                "%s seems to mismatch features of "\
                "the other source datasets" % src
            assert tgt_nfeats == count_features(tgt),\
                "%s seems to mismatch features of "\
                "the other target datasets" % tgt
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    if opt.disable_eos_sampling:
        eos_token = "<blank>"
        logger.info("Using NO eos token")
    else:
        eos_token = "</s>"
        logger.info("Using standard eos token")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc,
                                  eos=eos_token)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tgt_reader, align_reader,
                       opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader,
                           align_reader, opt)
Example #15
def main(opt):
    ArgumentParser.validate_translate_opts(opt)

    if not os.path.exists(opt.output_dir):
        os.makedirs(opt.output_dir)

    if 'n_latent' not in vars(opt):
        vars(opt)['n_latent'] = vars(opt)['n_translate_latent']
    logger = init_logger(opt.log_file)

    if 'use_segments' not in vars(opt):
        vars(opt)['use_segments'] = opt.n_translate_segments != 0
        vars(opt)['max_segments'] = opt.n_translate_segments

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    n_latent = opt.n_latent

    if n_latent > 1:
        for latent_idx in range(n_latent):
            output_path = opt.output_dir + '/output_%d' % (latent_idx)
            out_file = codecs.open(output_path, 'w+', 'utf-8')
            translator.out_file = out_file

            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                logger.info("Translating shard %d." % i)
                translator.translate(src=src_shard,
                                     tgt=tgt_shard,
                                     src_dir=opt.src_dir,
                                     batch_size=opt.batch_size,
                                     attn_debug=opt.attn_debug,
                                     latent_idx=latent_idx)
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
                if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)
    else:
        output_path = opt.output_dir + '/output'
        out_file = codecs.open(output_path, 'w+', 'utf-8')
        translator.out_file = out_file

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(src=src_shard,
                                 tgt=tgt_shard,
                                 src_dir=opt.src_dir,
                                 batch_size=opt.batch_size,
                                 attn_debug=opt.attn_debug)
    def load_model(self, path=None):

        if path is None or not os.path.exists(
                os.path.abspath(os.path.join(os.getcwd(), path))):
            print("No model present at the specified path : {}".format(path))

        parser = ArgumentParser(description='translate.py')
        opts.config_opts(parser)

        opts.translate_opts(parser)
        opt = parser.parse_args(["--model", path])
        self.model = build_translator(opt, report_score=True)
        return
Example #17
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    #check_existing_pt_files(opt)

    init_logger(opt.log_file)
    #if not os.path.exists(os.path.dirname(opt.save_data)):
    #    os.makedirs(os.path.dirname(opt.save_data))
    #    logger.info("Creating dirs..."+os.path.dirname(opt.save_data))
    logger.info("Extracting features...")
    fields = get_fields(opt)
    logger.info("Building & saving training data...")
    build_save_dataset(opt, fields)
Example #18
def _get_parser():
    parser = ArgumentParser(description='translate.py')

    opts.config_opts(parser)
    opts.translate_opts(parser)
    parser.add(
        '--model',
        '-model',
        dest='models',
        metavar='MODEL',
        nargs='+',
        type=str,
        default=[
            "F:/Project/Python/selfProject/translate_NMT/transflate_NMT/data_step_100.pt"
        ],
        required=False,
        help="Checkpoint file(s) of the trained model to use")

    parser.add(
        '--src',
        '-src',
        required=False,
        default="F:/Project/Python/selfProject/translate_NMT/transflate_NMT/data/src-test.txt",
        help="Path to your own source test file")

    parser.add(
        '--output',
        '-output',
        default='F:/Project/Python/selfProject/translate_NMT/transflate_NMT/data/pred.txt',
        help="Output path for the test predictions (change to your own)")
    return parser
Example #19
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    ArgumentParser._get_all_transform_translate(opt)
    ArgumentParser._validate_transforms_opts(opt)
    ArgumentParser.validate_translate_opts_dynamic(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)

    data_reader = InferenceDataReader(opt.src, opt.tgt, opt.src_feats)

    # Build transforms
    transforms_cls = get_transforms_cls(opt._all_transform)
    transforms = make_transforms(opt, transforms_cls, translator.fields)
    data_transform = [
        transforms[name] for name in opt.transforms if name in transforms
    ]
    transform = TransformPipe.build_from(data_transform)

    for i, (src_shard, tgt_shard, feats_shard) in enumerate(data_reader):
        logger.info("Translating shard %d." % i)
        translator.translate_dynamic(src=src_shard,
                                     transform=transform,
                                     src_feats=feats_shard,
                                     tgt=tgt_shard,
                                     batch_size=opt.batch_size,
                                     batch_type=opt.batch_type,
                                     attn_debug=opt.attn_debug,
                                     align_debug=opt.align_debug)
def get_default_opts():
    parser = ArgumentParser(description='data sample prepare')
    dynamic_prepare_opts(parser)

    default_opts = [
        '-config', 'data/data.yaml', '-src_vocab', 'data/vocab-train.src',
        '-tgt_vocab', 'data/vocab-train.tgt'
    ]

    opt = parser.parse_known_args(default_opts)[0]
    # Inject some dummy training options that may be needed when building fields
    opt.copy_attn = False
    ArgumentParser.validate_prepare_opts(opt)
    return opt
Example #21
def _get_model_opts(opt, checkpoint=None):
    """Get `model_opt` to build model, may load from `checkpoint` if any."""
    if checkpoint is not None:
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        if (opt.tensorboard_log_dir == model_opt.tensorboard_log_dir
                and hasattr(model_opt, 'tensorboard_log_dir_dated')):
            # ensure tensorboard output is written in the directory
            # of previous checkpoints
            opt.tensorboard_log_dir_dated = model_opt.tensorboard_log_dir_dated
    else:
        model_opt = opt
    return model_opt
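A hedged sketch of how _get_model_opts is typically called when resuming training; the checkpoint path is a placeholder and opt stands for the already-parsed training options.

checkpoint = torch.load("models/model_step_50000.pt",
                        map_location=lambda storage, loc: storage)
model_opt = _get_model_opts(opt, checkpoint=checkpoint)  # options come from the checkpoint

model_opt = _get_model_opts(opt, checkpoint=None)  # fresh training: opt is used as-is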
Example #22
def build_vocab_main(opts):
    """Apply transforms to samples of specified data and build vocab from it.

    Transforms that need vocab will be disabled in this.
    Built vocab is saved in plain text format as follows and can be passed as
    `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """

    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert opts.n_sample == -1 or opts.n_sample > 1, \
        f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)
    transforms_cls = get_transforms_cls(opts._all_transform)
    fields = None

    transforms = make_transforms(opts, transforms_cls, fields)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter, src_feats_counter = build_vocab(
        opts, transforms, n_sample=opts.n_sample)

    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
    for feat_name, feat_counter in src_feats_counter.items():
        logger.info(f"Counters {feat_name}:{len(feat_counter)}")

    def save_counter(counter, save_path):
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
        with open(save_path, "w", encoding="utf8") as fo:
            for tok, count in counter.most_common():
                fo.write(tok + "\t" + str(count) + "\n")

    if opts.share_vocab:
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share:{len(src_counter)}")
        save_counter(src_counter, opts.src_vocab)
    else:
        save_counter(src_counter, opts.src_vocab)
        save_counter(tgt_counter, opts.tgt_vocab)

    for k, v in src_feats_counter.items():
        save_counter(v, opts.src_feats_vocab[k])
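Since the docstring above describes the saved vocab as plain "<token>\t<count>" lines, it can be read back into a Counter with a few lines (a sketch, not part of the original):

from collections import Counter

def load_saved_vocab(path):
    counter = Counter()
    with open(path, encoding="utf8") as fi:
        for line in fi:
            tok, count = line.rstrip("\n").split("\t")
            counter[tok] = int(count)
    return counter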
Example #23
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
        # print(src_nfeats)
    # exit()
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    # exit()
    ##################=======================================
    tt_nfeats = tgt_nfeats
    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tt_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    ########===============================================================
    tt_reader = inputters.str2reader["text"].from_opt(opt)

    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    # for k,v in fields.items():
    #     if(k in ['src','tgt','tt']):
    #         print(("preprocess .py preprocess() fields_item",k,v.fields[0][1].include_lengths))
    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tt_reader, tgt_reader,
                       align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tt_reader, tgt_reader,
                           align_reader, opt)
    def __init__(self):
        # Load the model based on the options given on the command line
        parser = ArgumentParser()
        opts.config_opts(parser)
        opts.translate_opts(parser)
        self.opt = parser.parse_args()
        ArgumentParser.validate_translate_opts(self.opt)
        self.translator = build_translator(self.opt, report_score=True)

        # Use MeCab for word segmentation (wakati-gaki)
        self.mecab = MeCab.Tagger("-Owakati")
        self.mecab.parse("")

        # Dictionary that stores the previous response per user
        self.prev_uttr_dict = {}
Example #25
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    confnet_nfeats = 0
    for src, tgt, cnet in zip(opt.train_src, opt.train_tgt, opt.train_confnet):
        src_nfeats += count_features(src) if opt.data_type == 'text' or opt.data_type == 'lattice' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
        #confnet_nfeats += count_features(cnet) if opt.data_type == 'lattice' \
        #    else 0
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    logger.info(" * number of confnet features: %d." % confnet_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  confnet_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  ans_truncate=opt.src_seq_length_trunc,
                                  ques_truncate=opt.confnet_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)
    #print('fields done')
    ans_reader = inputters.str2reader["text"].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)
    ques_reader = inputters.str2reader["lattice"].from_opt(opt)
    #print('src_reader', ques_reader)
    #print('tgt_reader', tgt_reader)
    #print('aglign_reader', align_reader)
    #print('confnet_reader', ans_reader)
    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, ques_reader, ans_reader, tgt_reader,
                       align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, ques_reader, ans_reader,
                           tgt_reader, align_reader, opt)
Example #26
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(
        opt.train_src) if opt.data_type == 'text' else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far

    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    if len(opt.src_vocab) > 0:
        assert len(opt.src_vocab) == len(
            opt.train_src
        ), "you should provide src vocab for each dataset if you want to use your own vocab"
    for i, (train_src,
            train_tgt) in enumerate(zip(opt.train_src, opt.train_tgt)):
        valid_src = opt.valid_src[i]
        valid_tgt = opt.valid_tgt[i]
        logger.info("Working on %d dataset..." % i)
        logger.info("Building `Fields` object...")
        fields = inputters.get_fields(opt.data_type,
                                      src_nfeats,
                                      tgt_nfeats,
                                      dynamic_dict=opt.dynamic_dict,
                                      src_truncate=opt.src_seq_length_trunc,
                                      tgt_truncate=opt.tgt_seq_length_trunc)

        src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
        tgt_reader = inputters.str2reader["text"].from_opt(opt)

        logger.info("Building & saving training data...")
        train_dataset_files = build_save_dataset('train', fields, src_reader,
                                                 tgt_reader, opt, i, train_src,
                                                 train_tgt, valid_src,
                                                 valid_tgt)

        if opt.valid_src and opt.valid_tgt:
            logger.info("Building & saving validation data...")
            build_save_dataset('valid', fields, src_reader, tgt_reader, opt, i,
                               train_src, train_tgt, valid_src, valid_tgt)

        logger.info("Building & saving vocabulary...")
        build_save_vocab(train_dataset_files, fields, opt, i)
def _get_parser():
    parser = ArgumentParser(description='run_kp_eval.py')

    opts.config_opts(parser)
    opts.translate_opts(parser)

    return parser
    def __init__(self):
        # Boilerplate: build the parser and parse a fixed set of options
        parser = ArgumentParser()
        opts.config_opts(parser)
        opts.translate_opts(parser)
        self.opt = parser.parse_args(args=[
            "-model", "../models/model.pt", "-src", "None", "-replace_unk",
            "--beam_size", "10", "--min_length", "7", "--block_ngram_repeat",
            "2"
        ])
        ArgumentParser.validate_translate_opts(self.opt)
        self.translator = build_translator(self.opt, report_score=True)

        # Use MeCab for word segmentation (wakati-gaki)
        self.mecab = MeCab.Tagger("-Owakati")
        self.mecab.parse("")
Example #29
def _get_parser():
    parser = ArgumentParser(description='translate.py')

    opts.config_opts(parser)
    opts.translate_opts(parser)
    opts.mmod_finetune_translate_opts(parser)
    return parser
Example #30
def _get_parser():
    parser = ArgumentParser(description='build_copy_transformer.py')

    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    return parser
Example #31
def _get_parser():
    parser = ArgumentParser(description='train.py')

    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.train_opts(parser)
    return parser
    def parse_opt(self, opt):
        """Parse the option set passed by the user using `onmt.opts`

       Args:
           opt (dict): Options passed by the user

       Returns:
           opt (argparse.Namespace): full set of options for the Translator
        """

        prec_argv = sys.argv
        sys.argv = sys.argv[:1]
        parser = ArgumentParser()
        onmt.opts.translate_opts(parser)

        models = opt['models']
        if not isinstance(models, (list, tuple)):
            models = [models]
        opt['models'] = [os.path.join(self.model_root, model)
                         for model in models]
        opt['src'] = "dummy_src"

        for (k, v) in opt.items():
            if k == 'models':
                sys.argv += ['-model']
                sys.argv += [str(model) for model in v]
            elif type(v) == bool:
                sys.argv += ['-%s' % k]
            else:
                sys.argv += ['-%s' % k, str(v)]

        opt = parser.parse_args()
        ArgumentParser.validate_translate_opts(opt)
        opt.cuda = opt.gpu > -1

        sys.argv = prec_argv
        return opt
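A hedged usage sketch: parse_opt expects a plain dict of translate options, turns each key into a -key flag (booleans become bare flags), and returns the parsed namespace; server_model and the option values below are placeholders.

opt = server_model.parse_opt({
    "models": "demo-model_step_100000.pt",  # resolved against self.model_root
    "beam_size": 5,
    "replace_unk": True,
})
print(opt.cuda, opt.n_best)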
Example #33
def load_test_model(opt, model_path=None):
    if model_path is None:
        model_path = opt.models[0]
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)

    model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    vocab = checkpoint['vocab']
    if inputters.old_style_vocab(vocab):
        fields = inputters.load_old_vocab(
            vocab, opt.data_type, dynamic_dict=model_opt.copy_attn
        )
    else:
        fields = vocab

    model = build_base_model(model_opt, fields, use_gpu(opt), checkpoint,
                             opt.gpu)
    if opt.fp32:
        model.float()
    model.eval()
    model.generator.eval()
    return fields, model, model_opt
Example #34
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset(
        'train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #35
import copy
import unittest
import math

import torch

import onmt
import onmt.inputters
import onmt.opts
from onmt.model_builder import build_embeddings, \
    build_encoder, build_decoder
from onmt.encoders.image_encoder import ImageEncoder
from onmt.encoders.audio_encoder import AudioEncoder
from onmt.utils.parse import ArgumentParser

parser = ArgumentParser(description='train.py')
onmt.opts.model_opts(parser)
onmt.opts.train_opts(parser)

# The -data option is required but not used in this test, so pass a dummy value.
opt = parser.parse_known_args(['-data', 'dummy'])[0]


class TestModel(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        super(TestModel, self).__init__(*args, **kwargs)
        self.opt = opt

    def get_field(self):
        src = onmt.inputters.get_fields("text", 0, 0)["src"]
Example #36
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()