Exemplo n.º 1
0
def main(_):
    """Evaluate a restored checkpoint on a parallel test set with references.

    Reads the source/target vocabularies and the three test files named by
    FLAGS, converts every sentence to a token-id sequence (capped by
    ``FLAGS.max_seq_length``), restores the model saved in
    ``FLAGS.checkpoint_dir`` and passes everything to ``evaluate``.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    assert FLAGS.checkpoint_dir, "--checkpoint_dir is required."
    assert FLAGS.source_test_path, "--source_test_path is required."
    assert FLAGS.target_test_path, "--target_test_path is required."
    assert FLAGS.reference_test_path, "--reference_test_path is required."
    # Message now names the real flag (it previously read "--souce_vocab_path").
    assert FLAGS.source_vocab_path, "--source_vocab_path is required."
    assert FLAGS.target_vocab_path, "--target_vocab_path is required."

    # Read vocabularies (only the forward word -> id mappings are needed).
    source_vocab, _ = utils.initialize_vocabulary(FLAGS.source_vocab_path)
    target_vocab, _ = utils.initialize_vocabulary(FLAGS.target_vocab_path)

    # Read test set.
    source_sentences, target_sentences, references = utils.read_data_with_ref(
        FLAGS.source_test_path, FLAGS.target_test_path,
        FLAGS.reference_test_path)

    # Convert sentences to token ids sequences.
    source_sentences_ids = [
        utils.sentence_to_token_ids(sent, source_vocab, FLAGS.max_seq_length)
        for sent in source_sentences
    ]
    target_sentences_ids = [
        utils.sentence_to_token_ids(sent, target_vocab, FLAGS.max_seq_length)
        for sent in target_sentences
    ]

    utils.reset_graph()
    with tf.Session() as sess:
        # Restore saved model.
        utils.restore_model(sess, FLAGS.checkpoint_dir)

        # Recover placeholders and ops for evaluation by their graph names.
        x_source = sess.graph.get_tensor_by_name("x_source:0")
        source_seq_length = sess.graph.get_tensor_by_name(
            "source_seq_length:0")

        x_target = sess.graph.get_tensor_by_name("x_target:0")
        target_seq_length = sess.graph.get_tensor_by_name(
            "target_seq_length:0")

        labels = sess.graph.get_tensor_by_name("labels:0")

        placeholders = [
            x_source, source_seq_length, x_target, target_seq_length, labels
        ]

        probs = sess.graph.get_tensor_by_name("feed_forward/output/probs:0")

        # Run evaluation.
        evaluate(sess, source_sentences, target_sentences, references,
                 source_sentences_ids, target_sentences_ids, probs,
                 placeholders)
 def read_vocab(self):
     """Load the vocabularies listed in ``self.filenames.vocab``.

     Entries whose matching flag in ``self.binary`` is true get ``None``
     instead of a vocabulary, because those encoders take pre-computed
     features rather than tokens. The full list is kept in ``self.vocabs``;
     its first ``len(self.src_ext)`` entries become ``self.src_vocab`` and
     the rest ``self.trg_vocab``.
     """
     loaded = []
     for path, is_binary in zip(self.filenames.vocab, self.binary):
         # Pre-computed feature inputs have no vocabulary file to read.
         loaded.append(None if is_binary else utils.initialize_vocabulary(path))
     self.vocabs = loaded
     split_at = len(self.src_ext)
     self.src_vocab = self.vocabs[:split_at]
     self.trg_vocab = self.vocabs[split_at:]
def fsns_val_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8):
    """Convert the FSNS validation annotations into a MindRecord dataset.

    For every image returned by ``create_fsns_label``, the annotation is
    turned into token ids with ``serialize_annotation``. The ids are
    zero-padded to ``config.max_length`` and written together with the raw
    image bytes and the annotation text. The decoder target is the decoder
    input shifted left by one position.

    A sample is skipped when any of these holds:
      * the annotation is longer than ``config.max_length - 2``;
      * ``serialize_annotation`` returns ``None``;
      * the serialized label has more than ``config.max_length`` ids.

    Args:
        mindrecord_dir: Directory the MindRecord output is written to.
        prefix: Base file name of the MindRecord output inside ``mindrecord_dir``.
        file_num: Number of output files, passed through to ``FileWriter``.
    """
    anno_file_dirs = [config.val_annotation_file]
    images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.val_data_root,
                                                                 anno_file_dirs=anno_file_dirs)
    vocab, _ = initialize_vocabulary(config.vocab_path)

    data_schema = {"image": {"type": "bytes"},
                   "decoder_input": {"type": "int32", "shape": [-1]},
                   "decoder_target": {"type": "int32", "shape": [-1]},
                   "annotation": {"type": "string"}}

    mindrecord_path = os.path.join(mindrecord_dir, prefix)

    writer = FileWriter(mindrecord_path, file_num)
    writer.add_schema(data_schema, "ocr")

    # The limits are identical for every sample, so compute them once.
    label_max_len = config.max_length
    # Presumably leaves room for start/end tokens in the label -- TODO confirm.
    text_max_len = config.max_length - 2

    for img_id in images:
        image_path = image_path_dict[img_id]
        annotation = image_anno_dict[img_id]

        if len(annotation) > text_max_len:
            continue
        label = serialize_annotation(image_path, annotation, vocab)
        if label is None:
            continue

        label_len = len(label)
        if label_len > label_max_len:
            continue
        # Zero-pad the label up to the fixed decoder length.
        label = np.concatenate((label, np.zeros(label_max_len - label_len, dtype=np.int32)))

        decoder_input = np.asarray(label).astype(np.int32)
        # Next-token target: the decoder input shifted left by one position.
        target = decoder_input[1:].astype(np.int32)

        with open(image_path, 'rb') as f:
            img = f.read()

        row = {"image": img,
               "decoder_input": decoder_input,
               "decoder_target": target,
               "annotation": str(annotation)}

        writer.write_raw_data([row])
    writer.commit()
def decode():
    """Interactively translate sentences read from standard input.

    Creates the model with ``create_model(sess, True)`` and forces a batch
    size of 1. Loads the vocabulary ``vocab<FLAGS.vocab_size>`` from
    ``FLAGS.data_dir``. Then, for every line read after a ``> `` prompt, it
    greedily decodes one sentence: argmax over each step's logits, cut at the
    first ``utils.EOS_ID``, and the resulting tokens are printed. Sentences
    too long for every bucket are reported and skipped instead of aborting
    the session. Stops at end of input.
    """
    with tf.Session() as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies.
        vocab_path = os.path.join(FLAGS.data_dir, "vocab%d" % FLAGS.vocab_size)
        vocab, rev_vocab = utils.initialize_vocabulary(vocab_path)

        while True:
            # Prompt and read one line; an empty string means end of input.
            sys.stdout.write("> ")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
            if not sentence:
                break

            # Get token-ids for the input sentence.
            sentence_tokens = utils.basic_tokenizer(
                tf.compat.as_bytes(sentence))
            token_ids = utils.sentence_to_token_ids(sentence_tokens, vocab)

            # Pick the first bucket whose encoder size is strictly larger than
            # the input. Previously, min([]) raised ValueError when none fit.
            bucket_id = next((b for b, bucket in enumerate(_buckets)
                              if bucket[0] > len(token_ids)), None)
            if bucket_id is None:
                print("Sentence is too long ({} tokens) for any bucket; "
                      "skipping.".format(len(token_ids)))
                continue

            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            # Get output logits for the sentence.
            _, _, output_logits = model.step(sess, encoder_inputs,
                                             decoder_inputs, target_weights,
                                             bucket_id, True)
            # Greedy decoding: argmax over axis 1 yields one id per batch row.
            # The batch has a single row, so take element [0] explicitly
            # (int() of a 1-element array is deprecated in recent NumPy).
            outputs = [
                int(np.argmax(logit, axis=1)[0]) for logit in output_logits
            ]
            # If there is an EOS symbol in outputs, cut them at that point.
            if utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(utils.EOS_ID)]
            # Print the decoded target-language sentence.
            print(" ".join(
                [tf.compat.as_str(rev_vocab[output]) for output in outputs]))
Exemplo n.º 5
0
def main(_):
    """Train the BiRNN parallel-sentence classifier and checkpoint it every epoch.

    Builds source/target vocabularies from the training files. Reads the
    training data and, only when both validation paths are given, the
    validation data. Then builds the model graph and runs
    ``FLAGS.num_epochs`` epochs of mini-batch training.

    Training summaries are written every ``FLAGS.steps_per_checkpoint``
    steps. At the end of every epoch the metrics are printed, a checkpoint
    is saved to ``FLAGS.checkpoint_dir`` and, if present, the validation set
    is evaluated with ``eval_epoch``.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    assert FLAGS.source_train_path, ("--source_train_path is required.")
    assert FLAGS.target_train_path, ("--target_train_path is required.")

    # Validation is optional and needs both sides of the corpus. Every
    # validation-only object below is guarded by this flag; previously the
    # code referenced valid_iterator / valid_summary_writer unconditionally
    # and raised NameError when no validation data was supplied.
    has_valid = bool(FLAGS.source_valid_path and FLAGS.target_valid_path)

    # Create vocabularies. Both files are written next to the source training file.
    source_vocab_path = os.path.join(os.path.dirname(FLAGS.source_train_path),
                                     "vocabulary.source")
    target_vocab_path = os.path.join(os.path.dirname(FLAGS.source_train_path),
                                     "vocabulary.target")
    utils.create_vocabulary(source_vocab_path, FLAGS.source_train_path, FLAGS.source_vocab_size)
    utils.create_vocabulary(target_vocab_path, FLAGS.target_train_path, FLAGS.target_vocab_size)

    # Read vocabularies; the reverse (id -> word) mappings are not needed here.
    source_vocab = utils.initialize_vocabulary(source_vocab_path)[0]
    target_vocab = utils.initialize_vocabulary(target_vocab_path)[0]

    # Read parallel sentences.
    parallel_data = utils.read_data(FLAGS.source_train_path, FLAGS.target_train_path,
                                    source_vocab, target_vocab)

    # Read validation data set.
    valid_data = None
    if has_valid:
        valid_data = utils.read_data(FLAGS.source_valid_path, FLAGS.target_valid_path,
                                     source_vocab, target_vocab)

    # Initialize BiRNN.
    config = Config(len(source_vocab),
                    len(target_vocab),
                    FLAGS.embedding_size,
                    FLAGS.state_size,
                    FLAGS.hidden_size,
                    FLAGS.num_layers,
                    FLAGS.learning_rate,
                    FLAGS.max_gradient_norm,
                    FLAGS.use_lstm,
                    FLAGS.use_mean_pooling,
                    FLAGS.use_max_pooling,
                    FLAGS.source_embeddings_path,
                    FLAGS.target_embeddings_path,
                    FLAGS.fix_pretrained)

    model = BiRNN(config)

    # Build graph.
    model.build_graph()

    # Train model.
    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        train_iterator = utils.TrainingIteratorRandom(parallel_data, FLAGS.num_negative)
        train_summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "train"), sess.graph)

        valid_iterator = None
        valid_summary_writer = None
        if has_valid:
            valid_iterator = utils.EvalIterator(valid_data)
            valid_summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "valid"), sess.graph)

        epoch_loss = 0
        epoch_completed = 0
        batch_completed = 0

        # Float division so that Python 2 integer division cannot floor the
        # step count before np.ceil rounds it up.
        num_iter = int(np.ceil(train_iterator.size * FLAGS.num_epochs / float(FLAGS.batch_size)))
        start_time = time.time()
        print("Training model on {} sentence pairs per epoch:".
              format(train_iterator.size))

        for step in xrange(num_iter):
            source, target, label = train_iterator.next_batch(FLAGS.batch_size)
            source_len = utils.sequence_length(source)
            target_len = utils.sequence_length(target)
            feed_dict = {model.x_source: source,
                         model.x_target: target,
                         model.labels: label,
                         model.source_seq_length: source_len,
                         model.target_seq_length: target_len,
                         model.input_dropout: FLAGS.keep_prob_input,
                         model.output_dropout: FLAGS.keep_prob_output,
                         model.decision_threshold: FLAGS.decision_threshold}

            # Index [1] presumably selects the update op of a tf.metrics
            # (value, update_op) pair, i.e. the running epoch value -- the
            # local variables are reset at every epoch boundary below.
            (_, loss_value, epoch_accuracy,
             epoch_precision, epoch_recall) = sess.run([model.train_op,
                                                        model.mean_loss,
                                                        model.accuracy[1],
                                                        model.precision[1],
                                                        model.recall[1]],
                                                       feed_dict=feed_dict)
            epoch_loss += loss_value
            batch_completed += 1
            # Write the model's training summaries.
            if step % FLAGS.steps_per_checkpoint == 0:
                summary = sess.run(model.summaries, feed_dict=feed_dict)
                train_summary_writer.add_summary(summary, global_step=step)
            # End of current epoch.
            if train_iterator.epoch_completed > epoch_completed:
                epoch_time = time.time() - start_time
                epoch_loss /= batch_completed
                epoch_f1 = utils.f1_score(epoch_precision, epoch_recall)
                epoch_completed += 1
                print("Epoch {} in {:.0f} sec\n"
                      "  Training: Loss = {:.6f}, Accuracy = {:.4f}, "
                      "Precision = {:.4f}, Recall = {:.4f}, F1 = {:.4f}"
                      .format(epoch_completed, epoch_time,
                              epoch_loss, epoch_accuracy,
                              epoch_precision, epoch_recall, epoch_f1))
                # Save a model checkpoint.
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir, "model.ckpt")
                model.saver.save(sess, checkpoint_path, global_step=step)
                # Evaluate model on the validation set.
                if has_valid:
                    eval_epoch(sess, model, valid_iterator, valid_summary_writer)
                # Initialize local variables for new epoch.
                batch_completed = 0
                epoch_loss = 0
                sess.run(tf.local_variables_initializer())
                start_time = time.time()

        print("Training done with {} steps.".format(num_iter))
        train_summary_writer.close()
        if valid_summary_writer is not None:
            valid_summary_writer.close()
Exemplo n.º 6
0
def main(_):
    """Extract likely parallel sentence pairs from all source/target article pairs.

    Articles are the files in ``FLAGS.extract_dir`` whose names end with
    ``FLAGS.source_language`` (source side) or ``FLAGS.target_language``
    (target side). Every source article is compared with every target
    article (Cartesian product). Each pair returned by ``extract_pairs`` is
    appended to three output files: the source sentence, the target
    sentence, and the score on its own line.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    import os  # needed to list FLAGS.extract_dir below

    assert FLAGS.checkpoint_dir, "--checkpoint_dir is required."
    assert FLAGS.extract_dir, "--extract_dir is required."
    assert FLAGS.source_vocab_path, "--source_vocab_path is required."
    assert FLAGS.target_vocab_path, "--target_vocab_path is required."
    assert FLAGS.source_output_path, "--source_output_path is required."
    assert FLAGS.target_output_path, "--target_output_path is required."
    assert FLAGS.score_output_path, "--score_output_path is required."
    assert FLAGS.source_language, "--source_language is required."
    assert FLAGS.target_language, "--target_language is required."

    # Read vocabularies.
    source_vocab, _ = utils.initialize_vocabulary(FLAGS.source_vocab_path)
    target_vocab, _ = utils.initialize_vocabulary(FLAGS.target_vocab_path)

    # Collect article files by language suffix. Previously source_paths and
    # target_paths were never defined, so the extraction loop raised NameError.
    source_paths = []
    target_paths = []
    for filename in os.listdir(FLAGS.extract_dir):
        if filename.endswith(FLAGS.source_language):
            source_paths.append(os.path.join(FLAGS.extract_dir, filename))
        elif filename.endswith(FLAGS.target_language):
            target_paths.append(os.path.join(FLAGS.extract_dir, filename))
    # Sorting makes the output order deterministic across runs.
    source_paths.sort()
    target_paths.sort()

    utils.reset_graph()
    with tf.Session() as sess:
        # Restore saved model.
        utils.restore_model(sess, FLAGS.checkpoint_dir)

        # Recover placeholders and ops for extraction.
        x_source = sess.graph.get_tensor_by_name("x_source:0")
        source_seq_length = sess.graph.get_tensor_by_name(
            "source_seq_length:0")

        x_target = sess.graph.get_tensor_by_name("x_target:0")
        target_seq_length = sess.graph.get_tensor_by_name(
            "target_seq_length:0")

        labels = sess.graph.get_tensor_by_name("labels:0")

        placeholders = [
            x_source, source_seq_length, x_target, target_seq_length, labels
        ]

        probs = sess.graph.get_tensor_by_name("feed_forward/output/probs:0")

        # The former pre-loop read_docs()/extract_pairs() call was removed:
        # it used source_sentences_ids/target_sentences_ids before they were
        # assigned (UnboundLocalError), and its result was discarded anyway.
        with open(FLAGS.source_output_path, mode="w", encoding="utf-8") as source_output_file, \
                open(FLAGS.target_output_path, mode="w", encoding="utf-8") as target_output_file, \
                open(FLAGS.score_output_path, mode="w", encoding="utf-8") as score_output_file:

            # Compare every source article with every target article.
            for source_path, target_path in itertools.product(
                    source_paths, target_paths):
                # Read sentences from articles.
                source_sentences, target_sentences = read_articles(
                    source_path, target_path)

                # Convert sentences to token ids sequences.
                source_sentences_ids = [
                    utils.sentence_to_token_ids(sent, source_vocab,
                                                FLAGS.max_seq_length)
                    for sent in source_sentences
                ]
                target_sentences_ids = [
                    utils.sentence_to_token_ids(sent, target_vocab,
                                                FLAGS.max_seq_length)
                    for sent in target_sentences
                ]

                # Extract sentence pairs.
                pairs = extract_pairs(sess, source_sentences, target_sentences,
                                      source_sentences_ids,
                                      target_sentences_ids, probs,
                                      placeholders)
                if not pairs:
                    continue
                for source_sentence, target_sentence, score in pairs:
                    source_output_file.write(source_sentence)
                    target_output_file.write(target_sentence)
                    score_output_file.write(str(score) + "\n")
Exemplo n.º 7
0
def main(_):
    """Extract parallel sentence pairs from aligned source/target articles.

    Files in ``FLAGS.extract_dir`` are split into source and target articles
    by name suffix (``FLAGS.source_language`` is checked first, then
    ``FLAGS.target_language``). Each list is sorted, and the two lists are
    paired position by position with ``zip``, so surplus articles on the
    longer side are ignored.

    For every article pair, the sentences are mapped to token ids and scored
    by ``extract_pairs`` using the restored model. Each returned pair is
    appended to the source, target and score output files.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    for flag_name in ("checkpoint_dir", "extract_dir",
                      "source_vocab_path", "target_vocab_path",
                      "source_output_path", "target_output_path",
                      "score_output_path",
                      "source_language", "target_language"):
        assert getattr(FLAGS, flag_name), "--{} is required.".format(flag_name)

    # Only the forward (word -> id) vocabularies are used below.
    source_vocab, _unused_rev = utils.initialize_vocabulary(FLAGS.source_vocab_path)
    target_vocab, _unused_rev = utils.initialize_vocabulary(FLAGS.target_vocab_path)

    # Split directory entries by language suffix; a name matching the source
    # suffix is never also counted as a target article.
    extract_dir = FLAGS.extract_dir
    entries = os.listdir(extract_dir)
    source_paths = sorted(
        os.path.join(extract_dir, entry) for entry in entries
        if entry.endswith(FLAGS.source_language))
    target_paths = sorted(
        os.path.join(extract_dir, entry) for entry in entries
        if not entry.endswith(FLAGS.source_language)
        and entry.endswith(FLAGS.target_language))

    utils.reset_graph()
    with tf.Session() as sess:
        utils.restore_model(sess, FLAGS.checkpoint_dir)

        # Look the tensors up by the names they were saved under.
        tensor = sess.graph.get_tensor_by_name
        placeholders = [tensor("x_source:0"),
                        tensor("source_seq_length:0"),
                        tensor("x_target:0"),
                        tensor("target_seq_length:0"),
                        tensor("labels:0")]
        probs = tensor("feed_forward/output/probs:0")
        source_final_state_ph = tensor("birnn/source_final_state_ph:0")

        with open(FLAGS.source_output_path, mode="w", encoding="utf-8") as src_out, \
                open(FLAGS.target_output_path, mode="w", encoding="utf-8") as trg_out, \
                open(FLAGS.score_output_path, mode="w", encoding="utf-8") as score_out:
            for source_path, target_path in zip(source_paths, target_paths):
                src_sents, trg_sents = read_articles(source_path, target_path)

                max_len = FLAGS.max_seq_length
                src_ids = [utils.sentence_to_token_ids(s, source_vocab, max_len)
                           for s in src_sents]
                trg_ids = [utils.sentence_to_token_ids(s, target_vocab, max_len)
                           for s in trg_sents]

                found = extract_pairs(sess, src_sents, trg_sents,
                                      src_ids, trg_ids,
                                      probs, placeholders, source_final_state_ph)
                # A falsy result (e.g. None or an empty list) writes nothing.
                for src_sent, trg_sent, score in found or ():
                    src_out.write(src_sent)
                    trg_out.write(trg_sent)
                    score_out.write(str(score) + "\n")