Example #1
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    features_shards = []
    features_names = []
    for feat_name, feat_path in opt.src_feats.items():
        features_shards.append(split_corpus(feat_path, opt.shard_size))
        features_names.append(feat_name)
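    # align source, target, and per-feature shard streams; each iteration
    # of the zip yields one matched set of shards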
    shard_pairs = zip(src_shards, tgt_shards, *features_shards)

    for i, (src_shard, tgt_shard, *features_shard) in enumerate(shard_pairs):
        features_shard_ = defaultdict(list)
        for feat_name, feat_shard in zip(features_names, features_shard):
            features_shard_[feat_name] = feat_shard
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             src_feats=features_shard_,
                             tgt=tgt_shard,
                             batch_size=opt.batch_size,
                             batch_type=opt.batch_type,
                             attn_debug=opt.attn_debug,
                             align_debug=opt.align_debug)
Example #2
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    # MultimodalTranslator inherits from Translator and mainly overrides
    # the translate() method. Some methods were not overridden but copied
    # verbatim into MultimodalTranslator; the subclass could instead call
    # the parent's methods via super(MultimodalTranslator, self).method().
    # To be cleaned up later.

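    # pre-extracted image features for the test set, cast to float32 below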
    test_img_feats = np.load(opt.path_to_test_img_feats)
    test_img_feats = test_img_feats.astype(np.float32)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             batch_type=opt.batch_type,
                             attn_debug=opt.attn_debug,
                             align_debug=opt.align_debug,
                             test_img_feats=test_img_feats,
                             multimodal_model_type=opt.multimodal_model_type)
Example #3
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    # wei 20200809
    nfr_tag_shards = split_corpus(opt.src_nfr_tag, opt.shard_size)
    flat_tag_shards = split_corpus(opt.src_flat_tag, opt.shard_size)
    # end wei
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    # shard_pairs = zip(src_shards, tgt_shards)    # 20200809 wei
    shard_pairs = zip(src_shards, nfr_tag_shards, flat_tag_shards, tgt_shards)

    # for i, (src_shard, tgt_shard) in enumerate(shard_pairs):    # 20200809
    for i, (src_shard, nfr_tag_shard, flat_tag_shard,
            tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            # wei 20200809
            nfr_tag=nfr_tag_shard,
            flat_tag=flat_tag_shard,
            # end wei
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug)
Example #4
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    if opt.tree:
        tree_shards = split_corpus(opt.tree, opt.shard_size)
        shard_pairs = zip(src_shards, tgt_shards, tree_shards)
        for i, (src_shard, tgt_shard, tree_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(src=src_shard,
                                 tgt=tgt_shard,
                                 tree=tree_shard,
                                 src_dir=opt.src_dir,
                                 batch_size=opt.batch_size,
                                 batch_type=opt.batch_type,
                                 attn_debug=opt.attn_debug,
                                 align_debug=opt.align_debug)
    else:
        shard_pairs = zip(src_shards, tgt_shards)
        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(src=src_shard,
                                 tgt=tgt_shard,
                                 tree=None,
                                 src_dir=opt.src_dir,
                                 batch_size=opt.batch_size,
                                 batch_type=opt.batch_type,
                                 attn_debug=opt.attn_debug,
                                 align_debug=opt.align_debug)
Example #5
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        # src_shard is a list of raw byte strings, one per source line, e.g.
        # src_shard[0].decode("utf-8") == 'आपकी कार में ब्लैक बॉक्स\n'
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,  # e.g. tgt_shard[0] == b'a black box in your car\n'
            tgt_path=opt.tgt,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
Example #6
def shard_iterator(srcs, tgts, ids, aligns, existing_shards,
                   existing_fields, corpus_type, opt):
    """
    Builds a single iterator yielding every shard of every corpus.
    """
    for src, tgt, maybe_id, maybe_align in zip(srcs, tgts, ids, aligns):
        if maybe_id in existing_shards:
            if opt.overwrite:
                logger.warning(
                    "Overwrite shards for corpus {}".format(maybe_id))
            else:
                if corpus_type == "train":
                    assert existing_fields is not None,\
                        ("A 'vocab.pt' file should be passed to "
                         "`-src_vocab` when adding a corpus to "
                         "a set of already existing shards.")
                logger.warning("Ignore corpus {} because "
                               "shards already exist".format(maybe_id))
                continue
        if ((corpus_type == "train" or opt.filter_valid)
                and tgt is not None):
            filter_pred = partial(inputters.filter_example,
                                  use_src_len=opt.data_type == "text",
                                  max_src_len=opt.src_seq_length,
                                  max_tgt_len=opt.tgt_seq_length)
        else:
            filter_pred = None
        src_shards = split_corpus(src, opt.shard_size)
        tgt_shards = split_corpus(tgt, opt.shard_size)
        align_shards = split_corpus(maybe_align, opt.shard_size)
        for i, (ss, ts, a_s) in enumerate(
                zip(src_shards, tgt_shards, align_shards)):
            yield (i, (ss, ts, a_s, maybe_id, filter_pred))
Example #7
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    # yida translate
    tag_src_shards = split_corpus(opt.tag_src, opt.shard_size) \
        if opt.tag_src is not None else repeat(None)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    # yida translate
    shard_pairs = zip(src_shards, tgt_shards, tag_src_shards)

    # yida translate
    for i, (src_shard, tgt_shard, tag_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            # yida translate
            tag_src=tag_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug)
Example #8
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    lemma_shards = split_corpus(opt.lemma, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards, lemma_shards)
    with open(opt.lemma_align, 'rb') as f:
        lemma_align = f.readlines()
    if opt.gpu >= 0:
        topic_matrix = torch.load(opt.topic_matrix,
                                  map_location=torch.device(opt.gpu))
    else:
        topic_matrix = torch.load(opt.topic_matrix)
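    # cast to fp16 unless full precision was requested (presumably to match
    # the loaded model's precision)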
    if not opt.fp32:
        topic_matrix = topic_matrix.half()
    for i, (src_shard, tgt_shard, lemma_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(lemma_align=lemma_align,
                             topic_matrix=topic_matrix,
                             src=src_shard,
                             tgt=tgt_shard,
                             lemma=lemma_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug)
Example #9
    def eval_impl(self,
            processed_data_dir: Path,
            model_dir: Path,
            beam_search_size: int,
            k: int
    ) -> List[List[Tuple[str, float]]]:
        from roosterize.ml.onmt.CustomTranslator import CustomTranslator
        from onmt.utils.misc import split_corpus
        from onmt.utils.parse import ArgumentParser
        from translate import _get_parser as translate_get_parser

        src_path = processed_data_dir/"src.txt"
        tgt_path = processed_data_dir/"tgt.txt"

        best_step = IOUtils.load(model_dir/"best-step.json", IOUtils.Format.json)
        self.logger.info(f"Taking best step at {best_step}")

        candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

        with IOUtils.cd(self.open_nmt_path):
            parser = translate_get_parser()
            # parse_args expects a list of tokens, so split the option string
            opt = parser.parse_args(
                (f" -model {model_dir}/models/ckpt_step_{best_step}.pt"
                 f" -src {src_path}"
                 f" -tgt {tgt_path}").split()
            )
            opt.output = f"{model_dir}/last-pred.txt"
            opt.beam_size = beam_search_size
            opt.gpu = 0 if torch.cuda.is_available() else -1
            opt.n_best = k
            opt.block_ngram_repeat = 1
            opt.ignore_when_blocking = ["_"]

            # translate.main
            ArgumentParser.validate_translate_opts(opt)

            translator = CustomTranslator.build_translator(opt, report_score=False)
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)

            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                self.logger.info("Translating shard %d." % i)
                _, _, candidates_logprobs_shard = translator.translate(
                    src=src_shard,
                    tgt=tgt_shard,
                    src_dir=opt.src_dir,
                    batch_size=opt.batch_size,
                    attn_debug=opt.attn_debug
                )
                candidates_logprobs.extend(candidates_logprobs_shard)
            # end for
        # end with

        # Reformat candidates
        candidates_logprobs: List[List[Tuple[str, float]]] = [[("".join(c), l) for c, l in cl] for cl in candidates_logprobs]

        return candidates_logprobs
Example #10
def get_shard_pairs(src, tgt, shard_size):
    if not tgt or tgt == '-':
        for pair_shard in split_corpus(src, shard_size):
            pairs = [pair.split(b'\t') for pair in pair_shard]
            yield zip(*pairs)
    else:
        src_shards = split_corpus(src, shard_size)
        tgt_shards = split_corpus(tgt, shard_size)
        # a bare `return zip(...)` inside a generator discards its value,
        # so the caller would receive no pairs; yield them instead
        yield from zip(src_shards, tgt_shards)
Example #11
def build_save_dataset(corpus_type, fields, opt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src = opt.train_src
        tgt = opt.train_tgt
        ans = opt.train_ans
    else:
        src = opt.valid_src
        tgt = opt.valid_tgt
        ans = opt.valid_ans

    logger.info("Reading source answer and target files: %s %s %s." %
                (src, ans, tgt))

    src_shards = split_corpus(src, opt.shard_size)
    tgt_shards = split_corpus(tgt, opt.shard_size)
    ans_shards = split_corpus(ans, opt.shard_size)

    shard_pairs = zip(src_shards, tgt_shards, ans_shards)
    dataset_paths = []

    for i, (src_shard, tgt_shard, ans_shard) in enumerate(shard_pairs):
        assert len(src_shard) == len(tgt_shard) == len(ans_shard)
        logger.info("Building shard %d." % i)
        dataset = inputters.build_dataset(
            fields,
            opt.data_type,
            src=src_shard,
            tgt=tgt_shard,
            ans=ans_shard,
            src_dir=opt.src_dir,
            src_seq_len=opt.src_seq_length,
            tgt_seq_len=opt.tgt_seq_length,
            ans_seq_len=opt.ans_seq_length,
            sample_rate=opt.sample_rate,
            window_size=opt.window_size,
            window_stride=opt.window_stride,
            window=opt.window,
            use_filter_pred=corpus_type == 'train' or opt.filter_valid)

        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s." %
                    (i, corpus_type, data_path))

        dataset.save(data_path)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
Example #12
def build_save_dataset(corpus_type, fields, src_reader, history_reader, ans_reader, tgt_reader, opt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src = opt.train_src
        history = opt.train_history
        ans = opt.train_ans
        tgt = opt.train_tgt
    else:
        src = opt.valid_src
        history = opt.valid_history
        ans = opt.valid_ans
        tgt = opt.valid_tgt

    logger.info("Reading source and target files: %s %s %s %s." % (src, history, ans, tgt))

    src_shards = split_corpus(src, opt.shard_size)
    history_shards = split_corpus(history, opt.shard_size)
    ans_shards = split_corpus(ans, opt.shard_size)
    tgt_shards = split_corpus(tgt, opt.shard_size)
    shard_pairs = zip(src_shards, history_shards, ans_shards, tgt_shards)
    dataset_paths = []
    if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
        filter_pred = partial(
            inputters.filter_example, use_src_len=opt.data_type == "text", use_history_len=False,
            max_src_len=opt.src_seq_length, max_history_len=-1, max_tgt_len=opt.tgt_seq_length)
    else:
        filter_pred = None
    logger.info("filter_pred is not used:{}".format(filter_pred))
    for i, (src_shard, history_shard, ans_shard, tgt_shard) in enumerate(shard_pairs):
        assert len(src_shard) == len(history_shard) \
            == len(ans_shard) == len(tgt_shard)
        logger.info("Building shard %d." % i)
        dataset = inputters.Dataset(
            fields,
            readers=[src_reader, history_reader, ans_reader, tgt_reader] if tgt_reader else [src_reader, history_reader, ans_reader],
            data=([("src", src_shard), ("history", history_shard), ("ans", ans_shard), ("tgt", tgt_shard)]
                  if tgt_reader else [("src", src_shard), ("history", history_shard), ("ans", ans_shard)]),
            dirs=[opt.src_dir, opt.src_dir, opt.src_dir, None] if tgt_reader else [opt.src_dir, opt.src_dir, opt.src_dir],
            sort_key=inputters.str2sortkey[opt.data_type],
            filter_pred=None
        )

        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s."
                    % (i, corpus_type, data_path))
        dataset.save(data_path)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
Example #13
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt, index,
                       train_src, train_tgt, valid_src, valid_tgt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src = train_src
        tgt = train_tgt
    else:
        src = valid_src
        tgt = valid_tgt

    logger.info("Reading source and target files: %s %s." % (src, tgt))

    src_shards = split_corpus(src, opt.shard_size)
    tgt_shards = split_corpus(tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)
    dataset_paths = []
    if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
        filter_pred = partial(inputters.filter_example,
                              use_src_len=opt.data_type == "text",
                              max_src_len=opt.src_seq_length,
                              max_tgt_len=opt.tgt_seq_length)
    else:
        filter_pred = None
    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        assert i == 0, "only a single shard is allowed"
        assert len(src_shard) == len(tgt_shard)
        logger.info("Building shard %d." % i)
        dataset = inputters.Dataset(
            fields,
            readers=[src_reader, tgt_reader] if tgt_reader else [src_reader],
            data=([("src", src_shard),
                   ("tgt", tgt_shard)] if tgt_reader else [("src",
                                                            src_shard)]),
            dirs=[opt.src_dir, None] if tgt_reader else [opt.src_dir],
            sort_key=inputters.str2sortkey[opt.data_type],
            filter_pred=filter_pred)

        #data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type,
                                               index)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s." %
                    (i, corpus_type, data_path))

        dataset.save(data_path)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
Example #14
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file,
                                  opt.shard_size,
                                  iter_func=constraint_iter_func,
                                  binary=False)

    translator = build_translator(opt, report_score=True, logger=logger)

    def create_src_shards(path, opt, binary=True):
        if opt.data_type == 'imgvec':
            assert opt.shard_size <= 0
            return [path]
        else:
            if opt.data_type == 'none':
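                # no real source input: hand back a large pool of None shards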
                return [None] * 99999
            else:
                return split_corpus(path, opt.shard_size, binary=binary)

    src_shards = create_src_shards(opt.src, opt)
    if opt.agenda:
        agenda_shards = create_src_shards(opt.agenda, opt, False)

    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    if not opt.agenda:
        shards = zip(src_shards, tgt_shards)
    else:
        shards = zip(src_shards, agenda_shards, tgt_shards)

    for i, flat_shard in enumerate(shards):
        if not opt.agenda:
            src_shard, tgt_shard = flat_shard
            agenda_shard = None
        else:
            src_shard, agenda_shard, tgt_shard = flat_shard
        logger.info("Translating shard %d." % i)

        tag_shard = None
        if opt.constraint_file:
            tag_shard = next(tag_shards)

        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             agenda=agenda_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug,
                             tag_shard=tag_shard)
Example #15
def main(opt):
    ArgumentParser.validate_translate_opts(opt)

    if not os.path.exists(opt.output_dir):
        os.makedirs(opt.output_dir)

    if 'n_latent' not in vars(opt):
        vars(opt)['n_latent'] = vars(opt)['n_translate_latent']
    logger = init_logger(opt.log_file)

    if 'use_segments' not in vars(opt):
        vars(opt)['use_segments'] = opt.n_translate_segments != 0
        vars(opt)['max_segments'] = opt.n_translate_segments

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    n_latent = opt.n_latent

    if n_latent > 1:
        for latent_idx in range(n_latent):
            output_path = opt.output_dir + '/output_%d' % (latent_idx)
            out_file = codecs.open(output_path, 'w+', 'utf-8')
            translator.out_file = out_file

            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                logger.info("Translating shard %d." % i)
                translator.translate(src=src_shard,
                                     tgt=tgt_shard,
                                     src_dir=opt.src_dir,
                                     batch_size=opt.batch_size,
                                     attn_debug=opt.attn_debug,
                                     latent_idx=latent_idx)
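            # split_corpus returns one-shot generators, now exhausted;
            # rebuild the shard iterators before the next latent index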
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
                if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)
    else:
        output_path = opt.output_dir + '/output'
        out_file = codecs.open(output_path, 'w+', 'utf-8')
        translator.out_file = out_file

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(src=src_shard,
                                 tgt=tgt_shard,
                                 src_dir=opt.src_dir,
                                 batch_size=opt.batch_size,
                                 attn_debug=opt.attn_debug)
Example #16
def translate(opt): 
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    # shard_pairs = zip(src_shards, tgt_shards)
    # print("number of shards: ", len(src_shards), len(tgt_shards))

    # load emotions
    tgt_emotion_shards = [None]*100  # placeholders; 100 is assumed to exceed the shard count
    if opt.target_emotions_path != "":
        print("Loading target emotions...")
        tgt_emotions = read_emotion_file(opt.target_emotions_path)
        tgt_emotion_shards = split_emotions(tgt_emotions, opt.shard_size)
        # print("number of shards: ", len(tgt_emotion_shards))
    
    tgt_concept_embedding_shards = [None]*100
    if opt.target_concept_embedding != "":
        print("Loading target_concept_embedding...")
        tgt_concept_embedding = load_pickle(opt.target_concept_embedding)
        tgt_concept_embedding_shards = split_emotions(tgt_concept_embedding, opt.shard_size)
        # print("number of shards: ", len(tgt_concept_embedding_shards))
    
    tgt_concept_words_shards = [None]*100
    if opt.target_concept_words != "":
        print("Loading target_concept_words...")
        tgt_concept_words = load_pickle(opt.target_concept_words)
        # tgt_concept_words_shards = split_emotions(zip(tgt_concept_words), opt.shard_size)
        tgt_concept_words_shards = [tgt_concept_words]
        # print("number of shards: ", len(tgt_concept_words_shards))
    
    shard_pairs = zip(src_shards, tgt_shards, tgt_emotion_shards,
                      tgt_concept_embedding_shards, tgt_concept_words_shards)

    for i, (src_shard, tgt_shard, tgt_emotion_shard,
            tgt_concept_embedding_shard,
            tgt_concept_words_shard) in enumerate(shard_pairs):
        # print(len(src_shard), len(tgt_shard), len(tgt_emotion_shard))
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            tgt_emotion_shard=tgt_emotion_shard,
            rerank=opt.rerank,
            emotion_lexicon=opt.emotion_lexicon,
            tgt_concept_embedding_shard=tgt_concept_embedding_shard,
            tgt_concept_words_shard=tgt_concept_words_shard
            )
Example #17
def main(opt):
    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    # repeat(None) supplies a None target per shard, however many shards
    # there are ([None]*opt.shard_size would not match the shard count)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)
    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        #logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug)
Example #18
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
    assert corpus_type in ['train', 'valid']

    if corpus_type == 'train':
        src = opt.train_src
        tgt = opt.train_tgt
    else:
        src = opt.valid_src
        tgt = opt.valid_tgt

    logger.info("Reading source and target files: %s %s." % (src, tgt))

    src_shards = split_corpus(src, opt.shard_size)
    tgt_shards = split_corpus(tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)
    dataset_paths = []
    if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
        filter_pred = partial(
            inputters.filter_example, use_src_len=opt.data_type == "text",
            max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length)
    else:
        filter_pred = None
    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        assert len(src_shard) == len(tgt_shard)
        logger.info("Building shard %d." % i)
        dataset = inputters.Dataset(
            fields,
            readers=[src_reader, tgt_reader] if tgt_reader else [src_reader],
            data=([("src", src_shard), ("tgt", tgt_shard)]
                  if tgt_reader else [("src", src_shard)]),
            dirs=[opt.src_dir, None] if tgt_reader else [opt.src_dir],
            sort_key=inputters.str2sortkey[opt.data_type],
            filter_pred=filter_pred
        )

        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
        dataset_paths.append(data_path)

        logger.info(" * saving %sth %s data shard to %s."
                    % (i, corpus_type, data_path))

        dataset.save(data_path)

        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return dataset_paths
Example #19
def translate_file(input_filename, output_filename):
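    # note: input_filename and output_filename are unused here; the model,
    # source, and output paths are hardcoded in the args string below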
    parser = ArgumentParser(description='translation')
    opts.config_opts(parser)
    opts.translate_opts(parser)
    # print(opts)
    args = '''-model m16_step_44000.pt
                -src source_products_16.txt
                -output op_16x_4400_50_10.txt
                -batch_size 128
                -replace_unk
                -max_length 200
                -verbose
                -beam_size 50
                -n_best 10
                -min_length 5'''
    # parse_args expects a list of tokens, not a single string
    opt = parser.parse_args(args.split())
    # print(opt.model)
    translator = build_translator(opt, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        scores, predictions = translator.translate(src=src_shard,
                                                   tgt=tgt_shard,
                                                   src_dir=opt.src_dir,
                                                   batch_size=opt.batch_size,
                                                   attn_debug=opt.attn_debug)

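    # only the final shard's scores and predictions survive the loop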
    return scores, predictions
Example #20
def translate_file(input_filename, output_filename):
    parser = ArgumentParser(description='translation')
    opts.config_opts(parser)
    opts.translate_opts(parser)

    args = f'''-model Experiments/Checkpoints/retrosynthesis_augmented_medium/retrosynthesis_aug_medium_model_step_100000.pt 
                -src MCTS_data/{input_filename}.txt 
                -output MCTS_data/{output_filename}.txt 
                -batch_size 128 
                -replace_unk
                -max_length 200 
                -verbose 
                -beam_size 10 
                -n_best 10 
                -min_length 5 
                -gpu 0'''

    # parse_args expects a list of tokens, not a single string
    opt = parser.parse_args(args.split())
    translator = build_translator(opt, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        scores, predictions = translator.translate(src=src_shard,
                                                   tgt=tgt_shard,
                                                   src_dir=opt.src_dir,
                                                   batch_size=opt.batch_size,
                                                   attn_debug=opt.attn_debug)

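    # only the final shard's scores and predictions survive the loop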
    return scores, predictions
Example #21
    def summarize(self, texts):
        '''
        Arguments:
            texts: list(str)
        '''
        if self.extract:
            summaries = self.extractor.summarize(texts)
        else:
            summaries = texts

        with open('src.txt', "w", encoding='utf-8') as src_f:
            for summary in summaries:
                src_f.write(
                    '[CLS] ' +
                    ' '.join(self.tokenizer.tokenize(strip_accents(summary))) +
                    '\n')

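        # self.opt.src is assumed to point at the src.txt file written above,
        # so the translator consumes the summaries just serialized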
        src_shards = split_corpus(self.opt.src, self.opt.shard_size)
        for i, src_shard in enumerate(src_shards):
            self.translator.translate(src=src_shard,
                                      batch_size=self.opt.batch_size,
                                      attn_debug=self.opt.attn_debug)

        output = []
        for line in fileinput.FileInput("output.out", inplace=1):
            line = line.replace(" ##", "").replace(" .", ".").replace(
                " ,", ",").replace(" !", "!").replace(" ?",
                                                      "?").replace("\n", "")
            print(line)
            output.append(line)

        os.remove('src.txt')
        os.remove("output.out")

        return output
Example #22
def main():
    parser = _get_parser()
    opt = parser.parse_args()
    fields, model, model_opts = load_test_model(opt)
    src_shards = split_corpus(opt.src, opt.shard_size)
    for src in src_shards:
        store_encoder_attn(model.encoder, src, {'src': fields['src']},
                           opt.batch_size, opt.gpu, opt.out_file)
Example #23
def create_src_shards(src, opt):
    if opt.data_type == 'imgvec':
        assert opt.shard_size <= 0
        return [src]
    elif opt.data_type == 'none':
        return [None]*99999
    else:
        return split_corpus(src, opt.shard_size)
Example #24
def create_src_shards(path, opt, binary=True):
    if opt.data_type == 'imgvec':
        assert opt.shard_size <= 0
        return [path]
    else:
        if opt.data_type == 'none':
            return [None] * 99999
        else:
            return split_corpus(path, opt.shard_size, binary=binary)
Example #25
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file,
                                  opt.shard_size,
                                  iter_func=constraint_iter_func,
                                  binary=False)

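    # round-trip opt through pickle (apparently a leftover serialization
    # check; opt1 is never used)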
    with open("opt.pkl", 'wb') as f1:
        pickle.dump(opt, f1)
    with open("opt.pkl", 'rb') as f1:
        opt1 = pickle.load(f1)
    translator = build_translator(opt, report_score=True)

    if opt.data_type == 'imgvec':
        assert opt.shard_size <= 0
        src_shards = [opt.src]
    else:
        if opt.data_type == 'none':
            src_shards = [None] * 99999
        else:
            src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)

        tag_shard = None
        if opt.constraint_file:
            tag_shard = next(tag_shards)

        all_scores, all_predictions = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug,
            tag_shard=tag_shard)
        with open("result_{}.pickle".format(i), 'wb') as f1:
            pickle.dump(all_predictions, f1)
Example #26
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    history_shards = split_corpus(opt.history, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, history_shards, tgt_shards)

    for i, (src_shard, history_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             history=history_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug)
Example #27
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True, logger=logger)
    translator.out_file = codecs.open(opt.output, 'w+', 'utf-8')
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug,
                             opt=opt)
Example #28
    def __call__(self, src_embed=None, hidden_state=None):
        translator = build_translator(self.opt, report_score=True)
        src_shards = split_corpus(self.opt.src, self.opt.shard_size)
        tgt_shards = split_corpus(self.opt.tgt, self.opt.shard_size)
        tgt2_shards = split_corpus(self.opt.tgt2, self.opt.shard_size)
        shard_trips = zip(src_shards, tgt_shards, tgt2_shards)

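        # the return inside the loop means only the first shard is processed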
        for i, (src_shard, tgt_shard, tgt2_shard) in enumerate(shard_trips):
            return translator.translate_gold_diff(
                src=src_shard,
                tgt=tgt_shard,
                tgt2=tgt2_shard,
                src_dir=self.opt.src_dir,
                batch_size=self.opt.batch_size,
                batch_type=self.opt.batch_type,
                attn_debug=self.opt.attn_debug,
                align_debug=self.opt.align_debug,
                src_embed=src_embed,
                hidden_state=hidden_state)
Example #29
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    if opt.data_type == 'graph':

        node1_shards = split_corpus(opt.src_node1, opt.shard_size)
        node2_shards = split_corpus(opt.src_node2, opt.shard_size)

        shard_pairs = zip(src_shards, node1_shards, node2_shards, tgt_shards)

        for i, (src_shard, node1_shard, node2_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                src_node1=node1_shard,
                src_node2=node2_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug
                )

    else:

        shard_pairs = zip(src_shards, tgt_shards)

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                attn_debug=opt.attn_debug
                )
Example #30
def nmt_filter_dataset(opt):

    opt.src = os.path.join(dataset_root_path, src_file)
    opt.tgt = os.path.join(dataset_root_path, tgt_file)
    opt.shard_size = 1

    opt.log_file = logging_file_path
    opt.models = [model_file_path]
    opt.n_best = 1
    opt.beam_size = 1
    opt.report_bleu = False
    opt.report_rouge = False

    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)

    src_file_path = os.path.join(dataset_root_path, src_file)
    tgt_file_path = os.path.join(dataset_root_path, tgt_file)

    src_shards = split_corpus(src_file_path, opt.shard_size)
    tgt_shards = split_corpus(tgt_file_path, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)

    pred_scores = []

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):

        start_time = time.time()
        shard_pred_scores, shard_pred_sentences = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
        print("--- %s seconds ---" % (time.time() - start_time))

        pred_scores += [scores[0] for scores in shard_pred_scores]

    average_pred_score = torch.mean(torch.stack(pred_scores)).detach()

    return average_pred_score
Example #31
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        reps_tensor = translator.translate(src=src_shard,
                                           tgt=tgt_shard,
                                           src_dir=opt.src_dir,
                                           batch_size=opt.batch_size,
                                           batch_type=opt.batch_type,
                                           attn_debug=opt.attn_debug,
                                           align_debug=opt.align_debug)

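    # reps_tensor holds only the output of the final shard at this point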
    torch.save(reps_tensor, opt.out_reps)
    print(reps_tensor.size())