Example #1
    def score(dataset, beam_size) -> tuple[float, float]:
        hyp_file = os.path.join(_base, f"hyps_{dataset}_beam-{beam_size}.txt")
        src_file = cp["config"]["data"]["src"][f"{dataset}_path"]
        ref_file = cp["config"]["data"]["trg"][f"{dataset}_path"]

        src_file = fix_paths(src_file, "datasets")
        ref_file = fix_paths(ref_file, "datasets")

        fusion = cp["config"]["model"]["decoding"].get("fusion")
        # shrink the per-batch token budget as the beam grows, so that
        # decoding memory stays roughly constant
        batch_tokens = max(10000 // beam_size, 1000)

        # default to shallow fusion when an external LM and a fusion weight
        # are given but the checkpoint itself specifies no fusion method
        if fusion is None and lm is not None and fusion_a is not None:
            fusion = "shallow"

        seq2seq_translate(checkpoint=cp,
                          src_file=src_file,
                          out_file=hyp_file,
                          beam_size=beam_size,
                          length_penalty=1,
                          lm=lm,
                          fusion=fusion,
                          fusion_a=fusion_a,
                          batch_tokens=batch_tokens,
                          device=device)
        _mixed = compute_bleu_score(hyp_file, ref_file)        # mixed-case (cased) BLEU
        _lower = compute_bleu_score(hyp_file, ref_file, True)  # lowercased BLEU
        return _mixed, _lower
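
Note that score() above is a nested function: _base, cp, lm, fusion_a, and device are closure variables supplied by an enclosing evaluation routine that this excerpt omits. A minimal sketch of what that enclosing scope might look like (the function name, signature, and final loop are assumptions for illustration, not the original code):

def evaluate_checkpoint(path, lm=None, fusion_a=None, device="cpu"):
    cp = load_checkpoint(path)       # closure variable cp
    _base = os.path.dirname(path)    # closure variable _base
    # lm, fusion_a and device are likewise read by score() via its closure

    def score(dataset, beam_size) -> tuple[float, float]:
        ...  # body exactly as in Example #1 above

    # e.g. score the validation and test sets at two beam sizes
    return {(d, b): score(d, b) for d in ("val", "test") for b in (1, 5)}
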
def backtranslate(trainer: NmtPriorTrainer):
    cp = load_checkpoint(trainer.best_checkpoint)
    fusion = trainer.config["model"]["decoding"].get("fusion")
    src_file = fix_paths(trainer.config["data"]["backtranslate_path"],
                         "datasets")

    _base, _file = os.path.split(src_file)
    out_file = os.path.join(_base, f"{trainer.config['name']}.synthetic")

    # if trainer.config['resume_state_id'] is not None:
    #     out_file += "__" + trainer.config['resume_state_id']

    seq2seq_translate(checkpoint=cp,
                      src_file=src_file,
                      out_file=out_file,
                      beam_size=1,  # greedy decoding
                      length_penalty=1,
                      lm=None,
                      fusion=fusion,
                      fusion_a=None,
                      batch_tokens=trainer.config["batch_tokens"],
                      device=trainer.device)
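
Assuming a trained NmtPriorTrainer instance named trainer, a rough usage sketch follows; the pairing step is an assumption about how the synthetic file might be consumed, not part of the source:

backtranslate(trainer)  # writes "<config-name>.synthetic" next to the source file

src_file = fix_paths(trainer.config["data"]["backtranslate_path"], "datasets")
out_file = os.path.join(os.path.dirname(src_file),
                        f"{trainer.config['name']}.synthetic")
# pair each monolingual line with its synthetic translation
with open(src_file) as src, open(out_file) as hyp:
    pairs = [(s.rstrip("\n"), h.rstrip("\n")) for s, h in zip(src, hyp)]
print(f"built {len(pairs)} synthetic pairs")
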
Example #3
    def __init__(self,
                 input,
                 tokenize=None,
                 vocab=None,
                 vocab_size=None,
                 subword_path=None,
                 seq_len=0,
                 sos=False,
                 oovs=0,
                 lang="en",
                 subsample=0,
                 **kwargs):
        """
        Base Dataset for Language Modeling.

        Args:
            tokenize (callable): tokenization callable, which takes as input
                a string and returns a list of tokens
            input (str, list): the path to the data file, or a list of samples.
            vocab (Vocab): a vocab instance. If None, then build a new one
                from the Datasets data.
            vocab_size(int): if given, then trim the vocab to the given number.
        """
        self.input = input
        self.seq_len = seq_len
        self.subword_path = subword_path
        self.sos = sos
        self.oovs = oovs
        self.subsample = subsample

        # > define tokenization to be used -------------------------------
        # precedence: subword model > user-supplied callable > Moses tokenizer
        if self.subword_path is not None:
            subword = spm.SentencePieceProcessor()
            subword_path = fix_paths(subword_path, "datasets")
            subword.Load(subword_path + ".model")
            self.tokenize = lambda x: subword.EncodeAsPieces(x.rstrip())
        elif tokenize is not None:
            self.tokenize = tokenize
        else:
            self.tokenize = MosesTokenizer(lang=lang).tokenize

        # > Build Vocabulary --------------------------------------------
        self.vocab, is_vocab_built = self.init_vocab(vocab, subword_path, oovs)

        # > Cache text file ---------------------------------------------
        self.lengths = []
        _is_cached = False

        def _line_callback(x):
            _tokens = self.tokenize(x)
            self.lengths.append(len(self.add_special_tokens(_tokens)))

            if is_vocab_built is False:
                self.vocab.read_sequence(_tokens)

        # -------------------------------------------------------------
        # If there is a (vocab, lengths) tuple associated with the given input
        # file, then load them from cache and skip the recalculation
        # -------------------------------------------------------------
        _ckey = self._get_cache_key(input, vocab, self.tokenize,
                                    subword_path, vocab_size, self.subsample)
        _cfile = os.path.join(os.path.dirname(input), f".cache_{_ckey}")
        if os.path.isfile(_cfile):
            print("Loading data from cache...", end=" ")
            with open(_cfile, "rb") as f:
                _vocab, self.lengths = pickle.load(f)
                self.vocab = Vocab().from_vocab_instance(_vocab)
            print("done!")
            _is_cached = True

        # > Preprocessing ---------------------------------------------
        print("Preprocessing...")
        self.data = DatasetCache(input,
                                 callback=_line_callback,
                                 subsample=subsample)

        # the text file itself may already be cached by DatasetCache while
        # the (vocab, lengths) tuple is not (e.g., first run with these
        # settings), so recompute them over the cached data
        if _is_cached is False and len(self.lengths) == 0:
            for i in range(len(self.data)):
                _line_callback(self.data[i])

        # trim down the size of a newly created vocab
        if subword_path is None and vocab_size is not None:
            self.vocab.build_lookup(vocab_size)

        # -------------------------------------------------------------
        # save to cache if not already saved
        # -------------------------------------------------------------
        if _is_cached is False:
            print("Writing data to cache...")
            with open(_cfile, "wb") as f:
                pickle.dump((self.vocab, self.lengths), f)

        self.lengths = numpy.array(self.lengths)
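
For context, a hypothetical instantiation (the class name LMDataset and the paths are illustrative assumptions; the constructor arguments come from the signature above):

dataset = LMDataset("datasets/wiki/train.txt",
                    subword_path="datasets/wiki/bpe",  # loads "datasets/wiki/bpe.model"
                    seq_len=256,
                    sos=True)
print(len(dataset.lengths), "lines, vocab size", len(dataset.vocab))
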
Example #4
def seq2seq_translate(checkpoint, src_file, out_file, beam_size,
                      length_penalty, lm, fusion, fusion_a, batch_tokens,
                      device):
    # --------------------------------------
    # load checkpoint
    # --------------------------------------
    if isinstance(checkpoint, str):
        cp = load_checkpoint(checkpoint)
    else:
        cp = checkpoint
    src_vocab, trg_vocab = cp["vocab"]

    # --------------------------------------
    # load model
    # --------------------------------------
    model_type = cp["config"]["model"].get("type", "rnn")
    src_ntokens = len(src_vocab)
    trg_ntokens = len(trg_vocab)

    if model_type == "rnn":
        model = Seq2SeqRNN(src_ntokens, trg_ntokens, **cp["config"]["model"])
    elif model_type == "transformer":
        model = Seq2SeqTransformer(src_ntokens, trg_ntokens,
                                   **cp["config"]["model"])
    else:
        raise NotImplementedError

    model.load_state_dict(cp["model"])
    model.to(device)
    model.eval()

    # --------------------------------------
    # load prior
    # --------------------------------------
    if lm is not None:
        lm_cp = load_checkpoint(lm)
    elif fusion:
        lm_cp = load_checkpoint(
            fix_paths(cp["config"]["data"]["prior_path"], "checkpoints"))
    else:
        lm_cp = None

    if lm_cp is not None:
        lm = prior_model_from_checkpoint(lm_cp)
        lm.to(device)
        lm.eval()
    else:
        lm = None

    test_set = SequenceDataset(src_file,
                               vocab=src_vocab,
                               **{**cp["config"]["data"],
                                  "subsample": 0,  # 0 disables subsampling
                                  **cp["config"]["data"]["src"]})
    print(test_set)

    if batch_tokens is None:
        batch_tokens = cp["config"]["batch_tokens"]

    # source lengths doubled, presumably to budget for the generated targets
    sampler = BucketTokensSampler(test_set.lengths * 2, batch_tokens)
    data_loader = DataLoader(
        test_set,
        # num_workers=cp["config"].get("cores",
        #                              min(4, multiprocessing.cpu_count())),
        # pin_memory=cp["config"].get("pin_memory", True),
        num_workers=cp["config"].get("cores", 4),
        pin_memory=True,
        batch_sampler=sampler,
        collate_fn=LMCollate())

    # translate the data
    output_ids = seq2seq_translate_ids(model,
                                       data_loader,
                                       trg_vocab,
                                       beam_size=beam_size,
                                       length_penalty=length_penalty,
                                       lm=lm,
                                       fusion=fusion,
                                       fusion_a=fusion_a)

    # the bucketing sampler emits batches out of length-sorted order;
    # reverse_ids restores the original line order of the source file
    output_ids = output_ids[data_loader.batch_sampler.reverse_ids]
    seq2seq_output_ids_to_file(output_ids, trg_vocab, out_file)
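
Because the checkpoint argument may be either a loaded dict or a path (see the isinstance check above), the function can also be driven standalone. A sketch with shallow fusion; the paths are illustrative and fusion_a is assumed to be a scalar interpolation weight:

seq2seq_translate(checkpoint="checkpoints/deen_best.pt",
                  src_file="datasets/test.de",
                  out_file="hyps_test_beam-5.txt",
                  beam_size=5,
                  length_penalty=1,
                  lm="checkpoints/lm_de_best.pt",  # external prior LM
                  fusion="shallow",
                  fusion_a=0.1,       # assumed: LM interpolation weight
                  batch_tokens=None,  # falls back to cp["config"]["batch_tokens"]
                  device="cuda")
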