Example #1
def batcher(params, batch):
    """Encode a batch of sentences: tokenize, lower-case, apply the
    sentencepiece model, then embed with the trained encoder."""
    new_batch = []
    for p in batch:
        if params.tokenize:
            tok = params.entok.tokenize(p, escape=False)
            p = " ".join(tok)
        if params.lower_case:
            p = p.lower()
        p = params.sp.EncodeAsPieces(p)
        p = " ".join(p)
        p = Example(p, params.lower_case)
        p.populate_embeddings(params.model.vocab, params.model.zero_unk, params.model.ngrams)
        new_batch.append(p)
    x, l = params.model.torchify_batch(new_batch)
    vecs = params.model.encode(x, l)
    return vecs.detach().cpu().numpy()
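For context: a batcher with this (params, batch) signature matches what SentEval's evaluation engine expects. A minimal sketch of how such a batcher is typically wired up, assuming SentEval is installed; the task choice and the encoder fields attached to params_senteval are illustrative, not from the source:

import senteval

def prepare(params, samples):
    # nothing to precompute for a fixed, pre-trained sentence encoder
    return

params_senteval = {"task_path": "./data", "usepytorch": True, "kfold": 10}
# params_senteval must also carry the encoder state that batcher reads:
# params_senteval.update(dict(tokenize=True, lower_case=True,
#                             entok=..., sp=..., model=...))

se = senteval.engine.SE(params_senteval, batcher, prepare)
results = se.eval(["STS16"])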
Example #2

import io

def get_data(params):
    """Read tab-separated sentence pairs, skipping duplicate lines and
    pairs with a missing or empty side."""
    examples = []

    finished = set([])  # check for duplicates
    with io.open(params.data_file, 'r', encoding='utf-8') as f:
        for i in f:
            if i in finished:
                continue
            else:
                finished.add(i)

            i = i.split('\t')
            # skip malformed lines and pairs with an empty side
            if len(i) < 2 or len(i[0].strip()) == 0 or len(i[1].strip()) == 0:
                continue

            e = (Example(i[0]), Example(i[1]))
            examples.append(e)

    return examples
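To see the de-duplication and empty-pair filtering in isolation, here is a self-contained rerun of the same loop over an in-memory TSV; the Example stand-in and the toy data are illustrative only:

import io

class Example:  # minimal stand-in for the project's Example class
    def __init__(self, text):
        self.text = text.strip()

data = "hello\tworld\nhello\tworld\n\tmissing left\nfoo\tbar\n"
finished, pairs = set(), []
for line in io.StringIO(data):
    if line in finished:
        continue
    finished.add(line)
    cols = line.split('\t')
    if len(cols) < 2 or not cols[0].strip() or not cols[1].strip():
        continue
    pairs.append((Example(cols[0]), Example(cols[1])))

print(len(pairs))  # 2: the duplicate line and the empty-field line are dropped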
Example #3

from itertools import chain
from typing import List, Type

import numpy as np

# Example and BaseLanguageModel are project-local classes.
def lm_score(
        candidates: List[List[str]],
        log_prob_ctc: np.ndarray,
        lm: Type[BaseLanguageModel],
        ctc_weight: float
) -> List[str]:
    """
    Псевдокод:
    [[foo1, bar1], [foo2, bar2], [foo3, bar3]] ->
    [[foo1, foo2, foo3], [bar1, bar2, bar3]] ->
    [foo1, foo2, foo3, bar1, bar2, bar3] ->
    lm.predict ->
    [p(foo1), p(foo2), p(foo3), p(bar1), p(bar2), p(bar3)] ->
    [[p(foo1), p(foo2), p(foo3)], [p(bar1), p(bar2), p(bar3)]] (log_prob_lm) ->
    scores = log_prob_ctc * k_ctc + log_prob_lm * (1 - k_ctc)

    * Т.к. кандидаты, полученные от ctc могут быть произвольными, возможен случай одного пробельного символа.
    * Т.к. мы удаляем лишние пробелы, данные последовательности препащаются в пустые, то приводит к падению
    языковой модели.
    * Поэтому для получения вероятностей последовательности от языковой модели заменим пустые строки
    на какую-то маловероятную, чтоб модель выдавала низкие вероятности на таких последовательностях
    """
    top_paths = len(candidates)
    improbable_seq = '000000'

    def process_input_text(text):
        if text == "":
            return improbable_seq
        return text

    def process_output_text(text):
        if text == improbable_seq:
            return ""
        return text

    candidates_flat = list(map(process_input_text, chain(*zip(*candidates))))
    examples = [Example(text=text) for text in candidates_flat]
    log_prob_lm = lm.predict(examples)  # [num_examples * top_paths]
    log_prob_lm = log_prob_lm.reshape((-1, top_paths))
    scores = log_prob_ctc * ctc_weight + log_prob_lm * (1 - ctc_weight)  # [num_examples, top_paths]
    indices = scores.argmax(1)  # [num_examples]
    best_candidates = list(map(process_output_text, (
        candidates_flat[top_paths * id_example + id_path] for id_example, id_path in enumerate(indices)
    )))
    return best_candidates
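The weighting scheme is easy to check with toy numbers. A self-contained sketch using a stub language model and the [top_paths][num_examples] candidate layout the function expects; the stub, the probabilities, and the Example stand-in are all illustrative:

import numpy as np

class Example:  # minimal stand-in: lm_score only reads/writes a text field
    def __init__(self, text):
        self.text = text

class StubLM:
    def predict(self, examples):
        # pretend log-probability: longer strings are less likely,
        # so the '000000' placeholder scores poorly, as intended
        return np.array([-float(len(e.text)) for e in examples])

candidates = [["cat", "dog", ""], ["cab", "dig", "fog"]]  # 2 paths, 3 examples
log_prob_ctc = np.log(np.array([[0.6, 0.4],
                                [0.5, 0.5],
                                [0.7, 0.3]]))  # [num_examples, top_paths]

print(lm_score(candidates, log_prob_ctc, StubLM(), ctc_weight=0.5))
# ['cat', 'dog', 'fog'] -- the third example's empty CTC candidate loses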
Example #4
    def summary():
        if request.method == "POST":
            payload = request.get_json()
            # wrap the submitted source code as a single inference example
            example = [Example(
                source=payload["code"],
                target=None,
            )]

            t0 = time.time()
            message, length = inference(data=get_features(example))
            t1 = time.time()
            result = {
                'message': message,
                'time': (t1 - t0),
                'device': args.device_name,
                'length': length
            }
            logger.info(json.dumps(result, indent=4))
            return jsonify(**result)
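A hedged client-side sketch for exercising this handler; the /summary route, host, and port are assumptions (the route decorator is not shown in the snippet), and only the "code" field is taken from the handler above:

import requests

resp = requests.post("http://localhost:5000/summary",
                     json={"code": "def add(a, b):\n    return a + b"})
print(resp.json())  # keys per the handler: message, time, device, length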
Example #5
    def srun(self, example: Example, **kwargs) -> None:
        # segment the text with jieba and attach the tokens to the example
        tokens = jieba.lcut(example.text)

        example.set("tokens", tokens)
Example #6
    def srun(self, example: Example, **kwargs) -> None:
        # segment the text with Baidu's LAC tokenizer instead of jieba
        tokens = self.lac.run(example.text)

        example.set("tokens", tokens)
Example #7

def get_sequences(p1, p2, model, params, fr0=0, fr1=0):
    """Wrap a sentence pair as Examples and populate their embeddings,
    choosing the French or default vocabulary per sentence; fall back to
    the unknown-word embedding when a sentence yields no embeddings."""
    wp1 = Example(p1)
    wp2 = Example(p2)

    def populate(wp, vocab):
        wp.populate_embeddings(vocab, model.zero_unk, params.ngrams)
        if len(wp.embeddings) == 0:
            wp.embeddings.append(vocab[unk_string])

    if fr0 == 1 and fr1 == 1 and not model.share_vocab:
        populate(wp1, model.vocab_fr)
        populate(wp2, model.vocab_fr)
    elif fr0 == 0 and fr1 == 1 and not model.share_vocab:
        populate(wp1, model.vocab)
        populate(wp2, model.vocab_fr)
    else:
        populate(wp1, model.vocab)
        populate(wp2, model.vocab)

    return wp1, wp2
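A sketch of how a populated pair would typically be scored, reusing the torchify_batch/encode methods that appear in Example #1 and assuming model and params are the trained encoder and its hyperparameters; the cosine-similarity step is an assumption, not shown in the source:

import torch.nn.functional as F

wp1, wp2 = get_sequences("a small example .", "un petit exemple .",
                         model, params, fr0=0, fr1=1)
x, l = model.torchify_batch([wp1, wp2])
vecs = model.encode(x, l)                                # [2, dim]
score = F.cosine_similarity(vecs[0:1], vecs[1:2]).item()
print(score)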
Example #8

import glob
import os

import kenlm
# models, build_and_restore, infer_moder_dir, and the other helpers used
# below are project-local.
def main(args):
    print('load images')
    images = glob.glob(os.path.join(args.input_dir, "*"))
    examples = [Example(img=img) for img in images]

    print('inference of the first htr model')
    htr_model = build_and_restore(model_cls=models.HTRModel,
                                  model_dir=infer_moder_dir(
                                      'htr', val_mode=args.val_mode))
    char_prob = htr_model.get_ctc_prob(examples, batch_size=128)

    print("decoding...")
    lm = LanguageModelKen(lm=kenlm.Model(
        infer_moder_dir('lm/kneser-ney/model.arpa', val_mode=args.val_mode)))
    id2char = {v: k for k, v in htr_model.char2id.items()}
    states = get_beam_states(char_prob,
                             classes=id2char,
                             lm=lm,
                             beam_width=args.beam_width,
                             alpha=args.alpha,
                             beta=args.beta,
                             min_char_prob=0.001)
    candidates_str, log_prob_ctc = get_ctc_prob_and_candidates(
        states, beam_width=args.beam_width, id2char=id2char)

    print("flatten...")
    candidates = []
    img2id = {}
    for i, candidates_i in enumerate(candidates_str):
        img = images[i]
        img2id[img] = i
        for text in candidates_i:
            text_clean = process_input_text(text)
            x = Example(img=img, text=text_clean)
            candidates.append(x)

    print("attn scoring...")
    joint_model = build_and_restore(model_cls=models.JointModel,
                                    model_dir=infer_moder_dir(
                                        'joint', val_mode=args.val_mode),
                                    training=False)
    log_prob_joint = joint_model.predict(examples=examples,
                                         candidates=candidates,
                                         img2id=img2id,
                                         batch_size_enc=128,
                                         batch_size_dec=512)
    log_prob_joint = log_prob_joint.reshape((-1, args.beam_width))

    print("birnn scoring...")
    lm_birnn = build_and_restore(model_cls=models.BiRNNLanguageModel,
                                 model_dir=infer_moder_dir(
                                     'lm/birnn', val_mode=args.val_mode))
    log_prob_birnn = lm_birnn.predict(
        candidates, batch_size=256)  # [num_examples * beam_width]
    log_prob_birnn = log_prob_birnn.reshape((-1, args.beam_width))

    print("transformer scoring...")
    lm_transformer = build_and_restore(
        model_cls=models.TransformerLanguageModel,
        model_dir=infer_moder_dir('lm/transformer', val_mode=args.val_mode))
    log_prob_transformer = lm_transformer.predict(
        candidates, batch_size=256)  # [num_examples * beam_width]
    log_prob_transformer = log_prob_transformer.reshape((-1, args.beam_width))

    # weighting
    print("weighting...")
    scores = log_prob_ctc * args.w_ctc \
        + log_prob_birnn * args.w_birnn \
        + log_prob_joint * args.w_joint \
        + log_prob_transformer * args.w_transformer  # [num_examples, beam_width]
    indices = scores.argmax(1)  # [num_examples]
    texts = list(
        map(process_output_text,
            (candidates[args.beam_width * id_example + id_path].text
             for id_example, id_path in enumerate(indices))))

    # saving predictions
    print("saving...")
    save_predictions(output_dir=args.output_dir, images=images, texts=texts)
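The source omits the argument parser; a hedged reconstruction covering exactly the fields that main reads (flag types and defaults are guesses, not from the source):

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--val_mode", action="store_true")
    parser.add_argument("--beam_width", type=int, default=10)
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--w_ctc", type=float, default=1.0)
    parser.add_argument("--w_birnn", type=float, default=1.0)
    parser.add_argument("--w_joint", type=float, default=1.0)
    parser.add_argument("--w_transformer", type=float, default=1.0)
    main(parser.parse_args())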