def generate_with_greedy_search(texts):
    tokens = tokenizer.tokenize(texts).to_tensor(default_value=args.pad_id)
    decoded_tokens, perplexity = greedy_search(model, tokens, bos_id, eos_id,
                                               args.max_sequence_length,
                                               args.pad_id)
    sentences = tokenizer.detokenize(decoded_tokens)
    return {"sentences": sentences, "perplexity": perplexity}
Example #2
def test_search():
    model = RNNSeq2Seq(
        cell_type="SimpleRNN",
        vocab_size=100,
        hidden_dim=32,
        num_encoder_layers=1,
        num_decoder_layers=1,
        dropout=0.0,
        use_bidirectional=True,
    )

    batch_size = 8
    encoder_sequence = 10
    decoder_sequence = 15
    bos_id = 2
    eos_id = 3
    max_sequence_length = 17
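    # Build the model by calling it once on random encoder and decoder token ids.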
    output = model((
        tf.random.uniform((batch_size, encoder_sequence),
                          maxval=100,
                          dtype=tf.int32),
        tf.random.uniform((batch_size, decoder_sequence),
                          maxval=100,
                          dtype=tf.int32),
    ))

    encoder_input = tf.random.uniform((batch_size, encoder_sequence),
                                      maxval=100,
                                      dtype=tf.int32)
    decoder_input = tf.random.uniform((batch_size, decoder_sequence),
                                      maxval=100,
                                      dtype=tf.int32)
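    # With a beam size of 1, beam search should be equivalent to greedy search.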
    beam_result, beam_ppl = beam_search(model, encoder_input, 1, bos_id,
                                        eos_id, max_sequence_length)
    greedy_result, greedy_ppl = greedy_search(model, encoder_input, bos_id,
                                              eos_id, max_sequence_length)

    tf.debugging.assert_equal(beam_result[:, 0, :], greedy_result)
    tf.debugging.assert_near(tf.squeeze(beam_ppl), greedy_ppl)
Example #3
                src_id, tgt_id = data_id
                src_mask, tgt_mask = data_mask
                src_lengths, tgt_lengths = data_lengths
                if USE_GPU:
                    src_id = Variable(src_id).cuda()
                    tgt_id = Variable(tgt_id).cuda()
                    src_mask = Variable(src_mask).cuda()
                    tgt_mask = Variable(tgt_mask).cuda()
                else:
                    src_id = Variable(src_id)
                    tgt_id = Variable(tgt_id)
                    src_mask = Variable(src_mask)
                    tgt_mask = Variable(tgt_mask)

                pred_ids, _, _ = greedy_search(
                    model, src_id, src_mask, src_lengths, tgt_dict,
                    config['evaluation']['max_decode_len'],
                    USE_GPU)  # N x max_decode_len
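                # Convert the predicted id lists back to token symbols and collect them as candidates.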
                pred_ids_lst = pred_ids.data.tolist()
                pred_batch_lst = tgt_dict.convert_id_lst_to_symbol_lst(
                    pred_ids_lst)
                cand_lst.extend(pred_batch_lst)
                # single ref.
                gold_batch_lst = [tup[1] for tup in data_symbol]
                gold_lst.extend(gold_batch_lst)

            ngram_bleus, bleu, bp, hyp_ref_len, ratio = bleu_calulator.calc_bleu(
                cand_lst, [gold_lst])
            print(
                'BLEU: %2.2f (%2.2f, %2.2f, %2.2f, %2.2f) BP: %.5f ratio: %.5f (%d/%d)'
                % (bleu * 100, ngram_bleus[0] * 100, ngram_bleus[1] * 100,
                   ngram_bleus[2] * 100, ngram_bleus[3] * 100, bp, ratio,
                   hyp_ref_len[0], hyp_ref_len[1]))
Example #4
            model = MODEL_MAP[args.model_name](**json.load(f))
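        # Calling the model on symbolic Keras inputs builds its weights before loading the checkpoint.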
        model((tf.keras.Input([None], dtype=tf.int32), tf.keras.Input([None], dtype=tf.int32)))
        model.load_weights(args.model_path)
        logger.info("Loaded weights of model")

    # Evaluate
    bleu_sum = 0.0
    perplexity_sum = 0.0
    total = 0
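    # Tokenizing the empty string yields just the BOS and EOS ids.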
    bos_id, eos_id = tokenizer.tokenize("").numpy().tolist()
    dataset_tqdm = tqdm(dataset)
    for batch_input, batch_true_answer in dataset_tqdm:
        num_batch = len(batch_true_answer)
        if args.beam_size > 0:
            batch_pred_answer, perplexity = beam_search(
                model, batch_input, args.beam_size, bos_id, eos_id, args.max_sequence_length
            )
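            # Keep only the best (first) beam hypothesis for each example.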
            batch_pred_answer = batch_pred_answer[:, 0, :]
        else:
            batch_pred_answer, perplexity = greedy_search(model, batch_input, bos_id, eos_id, args.max_sequence_length)
        perplexity_sum += tf.math.reduce_sum(perplexity).numpy()

        for true_answer, pred_answer in zip(batch_true_answer, batch_pred_answer):
            bleu_sum += calculat_bleu_score(true_answer.numpy().tolist(), pred_answer.numpy().tolist())

        total += num_batch
        dataset_tqdm.set_description(f"Perplexity: {perplexity_sum / total}, BLEU: {bleu_sum / total}")

    logger.info("Finished evalaution!")
    logger.info(f"Perplexity: {perplexity_sum / total}, BLEU: {bleu_sum / total}")
Example #5
        with tf.io.gfile.GFile(args.model_config_path) as f:
            model = MODEL_MAP[args.model_name](**json.load(f))
        model.load_weights(args.model_path)
        logger.info("Loaded weights of model")

    # Inference
    logger.info("Start Inference")
    outputs = []
    bos_id, eos_id = tokenizer.tokenize("").numpy().tolist()

    for batch_input in dataset:
        if args.beam_size > 0:
            batch_output = beam_search(model, batch_input, args.beam_size, bos_id, eos_id, args.max_sequence_length)
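            # beam_search returns (decoded ids, perplexity); keep the ids of the top beam.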
            batch_output = batch_output[0][:, 0, :].numpy()
        else:
            batch_output = greedy_search(model, batch_input, bos_id, eos_id, args.max_sequence_length)[0].numpy()
        outputs.extend(batch_output)
    outputs = [tokenizer.detokenize(output).numpy().decode("UTF8") for output in outputs]
    logger.info("Ended Inference, Start to save...")

    # Save file
    if args.save_pair:
        with open(args.dataset_path) as f, open(args.output_path, "w") as fout:
            wtr = csv.writer(fout, delimiter="\t")
            wtr.writerow(["EncodedSentence", "DecodedSentence"])

            for input_sentence, decoded_sentence in zip(f.read().split("\n"), outputs):
                wtr.writerow((input_sentence, decoded_sentence))
        logger.info(f"Saved (original sentence,decoded sentence) pairs to {args.output_path}")

    else: