Example #1
  def RunTraining(self, hyperparam_config):
    master_spec = self.LoadSpec('master_spec_link.textproto')

    self.assertIsInstance(hyperparam_config, spec_pb2.GridPoint)
    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    gold_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
    reader_strings = [
        gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
    ]
    tf.logging.info('Generating graph with config: %s', hyperparam_config)
    with tf.Graph().as_default():
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

      target = spec_pb2.TrainTarget()
      target.name = 'testTraining-all'
      train = builder.add_training_from_config(target)
      with self.test_session() as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        # Run one iteration of training and verify nothing crashes.
        logging.info('Training')
        sess.run(train['run'], feed_dict={train['input_batch']: reader_strings})
Example #2
def calculate_parse_metrics(gold_corpus, annotated_corpus):
  """Calculate POS/UAS/LAS accuracy based on gold and annotated sentences."""
  check.Eq(len(gold_corpus), len(annotated_corpus), 'Corpora are not aligned')
  num_tokens = 0
  num_correct_pos = 0
  num_correct_uas = 0
  num_correct_las = 0
  for gold_str, annotated_str in zip(gold_corpus, annotated_corpus):
    gold = sentence_pb2.Sentence()
    annotated = sentence_pb2.Sentence()
    gold.ParseFromString(gold_str)
    annotated.ParseFromString(annotated_str)
    check.Eq(gold.text, annotated.text, 'Text is not aligned')
    check.Eq(len(gold.token), len(annotated.token), 'Tokens are not aligned')
    tokens = list(zip(gold.token, annotated.token))  # materialize so it can be reused below
    num_tokens += len(tokens)
    num_correct_pos += sum(1 for x, y in tokens if x.tag == y.tag)
    num_correct_uas += sum(1 for x, y in tokens if x.head == y.head)
    num_correct_las += sum(1 for x, y in tokens
                           if x.head == y.head and x.label == y.label)

  tf.logging.info('Total num documents: %d', len(annotated_corpus))
  tf.logging.info('Total num tokens: %d', num_tokens)
  pos = num_correct_pos * 100.0 / num_tokens
  uas = num_correct_uas * 100.0 / num_tokens
  las = num_correct_las * 100.0 / num_tokens
  tf.logging.info('POS: %.2f%%', pos)
  tf.logging.info('UAS: %.2f%%', uas)
  tf.logging.info('LAS: %.2f%%', las)
  return pos, uas, las
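
A minimal usage sketch for the metric above (the import path and the _toy_parse helper are illustrative assumptions, not part of the example): POS counts matching tags, UAS matching heads, and LAS matching head/label pairs.

from syntaxnet import sentence_pb2  # assumed proto module path

def _toy_parse(tags, heads, labels):
  """Builds one serialized Sentence with the given per-token annotations."""
  s = sentence_pb2.Sentence()
  for i, (tag, head, label) in enumerate(zip(tags, heads, labels)):
    s.token.add(word='w%d' % i, start=i, end=i, tag=tag, head=head, label=label)
  return s.SerializeToString()

gold_corpus = [_toy_parse(['NN', 'VB'], [1, -1], ['nsubj', 'root'])]
annotated_corpus = [_toy_parse(['NN', 'NN'], [1, -1], ['nsubj', 'root'])]
# One of two tags differs; all heads and labels match:
# expected (pos, uas, las) == (50.0, 100.0, 100.0).
pos, uas, las = calculate_parse_metrics(gold_corpus, annotated_corpus)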
Example #3
def calculate_segmentation_metrics(gold_corpus, annotated_corpus):
    """Calculate precision/recall/f1 based on gold and annotated sentences."""
    check.Eq(len(gold_corpus), len(annotated_corpus),
             'Corpora are not aligned')
    num_gold_tokens = 0
    num_test_tokens = 0
    num_correct_tokens = 0

    def token_span(token):
        check.Ge(token.end, token.start)
        return (token.start, token.end)

    def ratio(numerator, denominator):
        check.Ge(numerator, 0)
        check.Ge(denominator, 0)
        if denominator > 0:
            return float(numerator) / denominator  # true division, also under Python 2
        elif numerator == 0:
            return 0.0  # map 0/0 to 0
        else:
            return float('inf')  # map x/0 to inf

    for gold_str, annotated_str in zip(gold_corpus, annotated_corpus):
        gold = sentence_pb2.Sentence()
        annotated = sentence_pb2.Sentence()
        gold.ParseFromString(gold_str)
        annotated.ParseFromString(annotated_str)
        check.Eq(gold.text, annotated.text, 'Text is not aligned')
        gold_spans = set()
        test_spans = set()
        for token in gold.token:
            check.NotIn(token_span(token), gold_spans, 'Duplicate token')
            gold_spans.add(token_span(token))
        for token in annotated.token:
            check.NotIn(token_span(token), test_spans, 'Duplicate token')
            test_spans.add(token_span(token))
        num_gold_tokens += len(gold_spans)
        num_test_tokens += len(test_spans)
        num_correct_tokens += len(gold_spans.intersection(test_spans))

    tf.logging.info('Total num documents: %d', len(annotated_corpus))
    tf.logging.info('Total gold tokens: %d', num_gold_tokens)
    tf.logging.info('Total test tokens: %d', num_test_tokens)
    precision = 100 * ratio(num_correct_tokens, num_test_tokens)
    recall = 100 * ratio(num_correct_tokens, num_gold_tokens)
    f1 = ratio(2 * precision * recall, precision + recall)
    tf.logging.info('Precision: %.2f%%', precision)
    tf.logging.info('Recall: %.2f%%', recall)
    tf.logging.info('F1: %.2f%%', f1)

    return round(precision, 2), round(recall, 2), round(f1, 2)
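
The segmentation metrics compare (start, end) byte spans rather than tag/head annotations. A small sketch of the expected behaviour, under the same assumptions as the sketch above (illustrative helper, assumed sentence_pb2 import):

def _toy_segmentation(text, spans):
  """Builds one serialized Sentence whose tokens cover the given byte spans."""
  s = sentence_pb2.Sentence(text=text)
  for start, end in spans:
    s.token.add(word=text[start:end + 1], start=start, end=end)
  return s.SerializeToString()

gold = [_toy_segmentation('ab cd', [(0, 1), (3, 4)])]
test = [_toy_segmentation('ab cd', [(0, 1), (3, 3)])]
# One of the two predicted spans matches a gold span:
# expected (precision, recall, f1) == (50.0, 50.0, 50.0).
precision, recall, f1 = calculate_segmentation_metrics(gold, test)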
Example #4
def pretty_print():
    _write_input(_read_output().strip())
    logging.set_verbosity(logging.INFO)
    with tf.Session() as sess:
        src = gen_parser_ops.document_source(
            batch_size=32,
            corpus_name='input-from-file-conll',
            task_context=task_context_path)
        sentence = sentence_pb2.Sentence()
        while True:
            documents, finished = sess.run(src)
            logging.info('Read %d documents', len(documents))
            for d in documents:
                sentence.ParseFromString(d)
                tr = asciitree.LeftAligned()
                d = to_dict(sentence)
                print('Input: %s' % sentence.text)
                print('Parse:')
                tr_str = tr(d)
                pat = re.compile(r'\s*@\d+$')
                for tr_ln in tr_str.splitlines():
                    print(pat.sub('', tr_ln))
            if finished:
                break
Example #5
def main(unused_argv):
  logging.set_verbosity(logging.INFO)
  with tf.Session() as sess:
    src = gen_parser_ops.document_source(batch_size=32,
                                         corpus_name=FLAGS.corpus_name,
                                         task_context=FLAGS.task_context)
    sentence = sentence_pb2.Sentence()
    
    while True:
      documents, finished = sess.run(src)
      logging.info('Read %d documents', len(documents))
      for d in documents:
        sentence.ParseFromString(d)
        tr = asciitree.LeftAligned()
        d = to_dict(sentence)
        print('Input: %s' % sentence.text)
        # Dump the serialized proto so downstream (non-Python) consumers can read it.
        serialized_str = sentence.SerializeToString()
        with open("/Users/yihed/Documents/workspace/Other/src/thmp/data/serializedSentence.txt", "wb") as out_file:
          out_file.write(serialized_str)
        print('Parse:')
        print(tr(d))

      if finished:
        break
Example #6
 def assertParseable(self, reader, expected_num, expected_last):
     sentences, last = reader.read()
     self.assertEqual(expected_num, len(sentences))
     self.assertEqual(expected_last, last)
     for s in sentences:
         pb = sentence_pb2.Sentence()
         pb.ParseFromString(s)
         self.assertGreater(len(pb.token), 0)
Example #7
 def ReadNextDocument(self, sess, sentence):
     sentence_str, = sess.run([sentence])
     if sentence_str:
         sentence_doc = sentence_pb2.Sentence()
         sentence_doc.ParseFromString(sentence_str[0])
     else:
         sentence_doc = None
     return sentence_doc
Example #8
 def ReadNextDocument(self, sess, doc_source):
     doc_str, last = sess.run(doc_source)
     if doc_str:
         doc = sentence_pb2.Sentence()
         doc.ParseFromString(doc_str[0])
     else:
         doc = None
     return doc, last
Example #9
 def testReadFirstSentence(self):
     reader = sentence_io.ConllSentenceReader(self.filepath, 1)
     sentences, last = reader.read()
     self.assertEqual(1, len(sentences))
     pb = sentence_pb2.Sentence()
     pb.ParseFromString(sentences[0])
     self.assertFalse(last)
     self.assertEqual(
         u'I knew I could do it properly if given the right kind of support .',
         pb.text)
Example #10
 def testGiveMeAName(self):
   document = sentence_pb2.Sentence()
   document.token.add(start=0, end=0, word='hi', head=1, label='something')
   document.token.add(start=1, end=1, word='there')
   contents = render_parse_tree_graphviz.parse_tree_graph(document)
   self.assertIn('<polygon', contents)
   self.assertIn('text/html;charset=utf-8;base64', contents)
   self.assertIn('something', contents)
   self.assertIn('hi', contents)
   self.assertIn('there', contents)
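
Outside the test harness, the same helper can produce a viewable file; a minimal sketch, assuming the imports below match the test's environment, with an illustrative output path:

from dragnn.python import render_parse_tree_graphviz  # assumed import path
from syntaxnet import sentence_pb2  # assumed import path

document = sentence_pb2.Sentence()
document.token.add(start=0, end=0, word='hi', head=1, label='something')
document.token.add(start=1, end=1, word='there')
html = render_parse_tree_graphviz.parse_tree_graph(document)
with open('parse_tree.html', 'w') as f:  # illustrative path
  f.write(html)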
Example #11
    def testModelExportProducesRunnableModel(self):
        # Get the master spec and params for this graph.
        master_spec = self.LoadSpec('ud-hungarian.master-spec')
        params_path = os.path.join(
            test_flags.source_root(), 'dragnn/python/testdata'
            '/ud-hungarian.params')

        # Export the graph via SavedModel. (Here, we maintain a handle to the graph
        # for comparison, but that's usually not necessary.)
        export_path = os.path.join(test_flags.temp_dir(), 'export')
        dragnn_model_saver_lib.clean_output_paths(export_path)
        saver_graph = tf.Graph()

        shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
            master_spec)

        dragnn_model_saver_lib.export_master_spec(master_spec, saver_graph)

        dragnn_model_saver_lib.export_to_graph(master_spec,
                                               params_path,
                                               export_path,
                                               saver_graph,
                                               export_moving_averages=False,
                                               build_runtime_graph=False)

        # Export the assets as well.
        dragnn_model_saver_lib.export_assets(master_spec,
                                             shortened_to_original,
                                             export_path)

        # Restore the graph from the checkpoint into a new Graph object.
        restored_graph = tf.Graph()
        restoration_config = tf.ConfigProto(log_device_placement=False,
                                            intra_op_parallelism_threads=10,
                                            inter_op_parallelism_threads=10)

        with tf.Session(graph=restored_graph,
                        config=restoration_config) as sess:
            tf.saved_model.loader.load(sess,
                                       [tf.saved_model.tag_constants.SERVING],
                                       export_path)

            test_doc = sentence_pb2.Sentence()
            text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
            test_reader_string = test_doc.SerializeToString()
            test_inputs = [test_reader_string]

            tf_out = sess.run('annotation/annotations:0',
                              feed_dict={
                                  'annotation/ComputeSession/InputBatch:0':
                                  test_inputs
                              })

            # We don't care about accuracy, only that the run sessions don't crash.
            del tf_out
Example #12
 def _add_sentence(self, tags, heads, labels, corpus):
     """Adds a sentence to the corpus."""
     sentence = sentence_pb2.Sentence()
     for tag, head, label in zip(tags, heads, labels):
         sentence.token.add(word='x',
                            start=0,
                            end=0,
                            tag=tag,
                            head=head,
                            label=label)
     corpus.append(sentence.SerializeToString())
Example #13
def _create_fake_corpus():
  """Returns a list of fake serialized sentences for tests."""
  num_docs = 4
  corpus = []
  for num_tokens in range(1, num_docs + 1):
    sentence = sentence_pb2.Sentence()
    sentence.text = 'x' * num_tokens
    for i in range(num_tokens):
      token = sentence.token.add()
      token.word = 'x'
      token.start = i
      token.end = i
    corpus.append(sentence.SerializeToString())
  return corpus
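
A quick sketch of consuming the fake corpus (purely illustrative): each entry round-trips through Sentence.FromString.

for serialized in _create_fake_corpus():
  sentence = sentence_pb2.Sentence.FromString(serialized)
  print('%d token(s), text=%r' % (len(sentence.token), sentence.text))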
Example #14
    def annotate_text(self, text):
        sentence = sentence_pb2.Sentence(
            text=text, token=[sentence_pb2.Token(word=text, start=-1, end=-1)])

        # preprocess
        with tf.Session(graph=tf.Graph()) as tmp_session:
            char_input = gen_parser_ops.char_token_generator(
                [sentence.SerializeToString()])
            preprocessed = tmp_session.run(char_input)[0]
        segmented, _ = self.segmenter_model(preprocessed)

        annotations, traces = self.parser_model(segmented[0])
        assert len(annotations) == 1
        assert len(traces) == 1
        return sentence_pb2.Sentence.FromString(annotations[0])
Example #15
def main(argv):
    del argv  # unused
    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    lexicon.build_lexicon(lexicon_dir,
                          training_sentence,
                          training_corpus_format='sentence-prototext')

    # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
    # sequence tagger.
    tagger = spec_builder.ComponentSpecBuilder('tagger')
    tagger.set_network_unit(name='FeedForwardNetwork',
                            hidden_layer_sizes='256')
    tagger.set_transition_system(name='tagger')
    tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
    tagger.add_rnn_link(embedding_dim=-1)
    tagger.fill_from_resources(lexicon_dir)

    master_spec = spec_pb2.MasterSpec()
    master_spec.component.extend([tagger.spec])

    hyperparam_config = spec_pb2.GridPoint()

    # Build the TensorFlow graph.
    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

        target = spec_pb2.TrainTarget()
        target.name = 'all'
        target.unroll_using_oracle.extend([True])
        dry_run = builder.add_training_from_config(target, trace_only=True)

    # Read in serialized protos from training data.
    sentence = sentence_pb2.Sentence()
    text_format.Merge(open(training_sentence).read(), sentence)
    training_set = [sentence.SerializeToString()]

    with tf.Session(graph=graph) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.initialize_all_variables())
        traces = sess.run(dry_run['traces'],
                          feed_dict={dry_run['input_batch']: training_set})

    with open('dragnn_tutorial_1.html', 'w') as f:
        f.write(
            visualization.trace_html(traces[0],
                                     height='300px').encode('utf-8'))
Example #16
def syntaxnet_tokenize(text):
    sentence = sentence_pb2.Sentence(
        text=text,
        token=[sentence_pb2.Token(word=text, start=-1, end=-1)]
    )

    # preprocess
    with tf.Session(graph=tf.Graph()) as tmp_session:
        char_input = gen_parser_ops.char_token_generator([sentence.SerializeToString()])
        preprocessed = tmp_session.run(char_input)[0]
    segmented, _ = segmenter_model(preprocessed)
    tokens = []

    for t in sentence_pb2.Sentence.FromString(segmented[0]).token:
        tokens.append(t.word)

    return tokens
Example #17
def syntaxnet_sentence(tokens):
    pb_tokens = []
    last_start = 0
    for token in tokens:
        token_bytes = token.encode("utf8")
        pb_tokens.append(sentence_pb2.Token(
            word=token_bytes, start=last_start, end=last_start + len(token_bytes) - 1)
        )
        last_start = last_start + len(token_bytes) + 1

    annotations, traces = parser_model(sentence_pb2.Sentence(
        text=u" ".join(tokens).encode("utf8"),
        token=pb_tokens
    ).SerializeToString())
    assert len(annotations) == 1
    assert len(traces) == 1
    return sentence_pb2.Sentence.FromString(annotations[0])
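
A sketch combining the two helpers above, assuming both segmenter_model and parser_model have been loaded as in the surrounding examples: segment raw text, then parse the resulting tokens and walk the dependency arcs.

tokens = syntaxnet_tokenize(u'John loves Mary.')
parsed = syntaxnet_sentence(tokens)
for index, token in enumerate(parsed.token):
    # token.head is the index of the head token, or -1 for the root.
    print('%d\t%s\thead=%d\tlabel=%s' % (index, token.word, token.head, token.label))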
Example #18
def annotate_text(text):
    """
    Segment and parse input text using syntaxnet models.
    """
    sentence = sentence_pb2.Sentence(
        text=text, token=[sentence_pb2.Token(word=text, start=-1, end=-1)])

    # preprocess
    with tf.Session(graph=tf.Graph()) as tmp_session:
        char_input = gen_parser_ops.char_token_generator(
            [sentence.SerializeToString()])
        preprocessed = tmp_session.run(char_input)[0]
    segmented, _ = SEGMENTER_MODEL(preprocessed)

    annotations, traces = PARSER_MODEL(segmented[0])
    assert len(annotations) == 1
    assert len(traces) == 1
    return sentence_pb2.Sentence.FromString(annotations[0]), traces[0]
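
A sketch of consuming the returned (sentence, trace) pair, assuming SEGMENTER_MODEL and PARSER_MODEL are loaded and that the DRAGNN visualization module is importable as below; the output path is illustrative:

from dragnn.python import visualization  # assumed import path

parsed, trace = annotate_text(u'John loves Mary.')
print('Parsed %d tokens from: %s' % (len(parsed.token), parsed.text))
with open('annotate_trace.html', 'w') as f:  # illustrative path
    f.write(visualization.trace_html(trace, height='300px').encode('utf-8'))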
Example #19
def _get_sentence_dict():
    logging.set_verbosity(logging.INFO)
    with tf.Session() as sess:
        src = gen_parser_ops.document_source(
            batch_size=32,
            corpus_name='input-from-file-conll',
            task_context=task_context_path)
        sentence = sentence_pb2.Sentence()
        result_dict = None
        while True:
            documents, finished = sess.run(src)
            for d in documents:
                sentence.ParseFromString(d)
                d = to_dict(sentence)
                result_dict = d

            if finished:
                break

    return result_dict
Example #20
def main(unused_argv):
    logging.set_verbosity(logging.INFO)
    with tf.Session() as sess:
        src = gen_parser_ops.document_source(batch_size=32,
                                             corpus_name=FLAGS.corpus_name,
                                             task_context=FLAGS.task_context)
        sentence = sentence_pb2.Sentence()
        while True:
            documents, finished = sess.run(src)
            logging.info('Read %d documents', len(documents))
            for d in documents:
                sentence.ParseFromString(d)
                tr = asciitree.LeftAligned()
                d = to_dict(sentence)
                print('Input: %s' % sentence.text)
                print('Parse:')
                print(tr(d))

            if finished:
                break
Example #21
def print_output(output_file, use_text_format, use_gold_segmentation, output):
    """Writes a set of sentences in CoNLL format.

  Args:
    output_file: The file to write to.
    use_text_format: Whether this computation used text-format input.
    use_gold_segmentation: Whether this computation used gold segmentation.
    output: A list of sentences to write to the output file.
  """
    with gfile.GFile(output_file, 'w') as f:
        f.write('## tf:{}\n'.format(use_text_format))
        f.write('## gs:{}\n'.format(use_gold_segmentation))
        for serialized_sentence in output:
            sentence = sentence_pb2.Sentence()
            sentence.ParseFromString(serialized_sentence)
            f.write('# text = {}\n'.format(sentence.text.encode('utf-8')))
            for i, token in enumerate(sentence.token):
                head = token.head + 1
                f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
                        (i + 1, token.word.encode('utf-8'), head,
                         token.label.encode('utf-8')))
            f.write('\n')
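
A hypothetical invocation (the path is illustrative, and `processed` stands for a list of serialized Sentence protos such as the ones produced by the annotation pipelines below):

print_output('/tmp/annotations.conll',
             use_text_format=False,
             use_gold_segmentation=True,
             output=processed)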
Example #22
def main(unused_argv):
    logging.set_verbosity(logging.INFO)
    with tf.Session() as sess:
        src = gen_parser_ops.document_source(batch_size=32,
                                             corpus_name=FLAGS.corpus_name,
                                             task_context=FLAGS.task_context)
        sentence = sentence_pb2.Sentence()
        while True:
            documents, finished = sess.run(src)
            logging.info('Read %d documents', len(documents))
            for d in documents:
                sentence.ParseFromString(d)
                tr = asciitree.LeftAligned()
                d = to_dict(sentence)
                print('Input: %s' % sentence.text)
                print('Parse:')
                tr_str = tr(d)
                pat = re.compile(r'\s*@\d+$')
                for tr_ln in tr_str.splitlines():
                    print(pat.sub('', tr_ln))

            if finished:
                break
Example #23
def inference(sess, graph, builder, annotator, text, enable_tracing=False):
    tokens = [
        sentence_pb2.Token(word=word, start=-1, end=-1)
        for word in text.split()
    ]
    sentence = sentence_pb2.Sentence()
    sentence.token.extend(tokens)
    if enable_tracing:
        annotations, traces = sess.run(
            [annotator['annotations'], annotator['traces']],
            feed_dict={
                annotator['input_batch']: [sentence.SerializeToString()]
            })
        #HTML(visualization.trace_html(traces[0]))
    else:
        annotations = sess.run(annotator['annotations'],
                               feed_dict={
                                   annotator['input_batch']:
                                   [sentence.SerializeToString()]
                               })

    parsed_sentence = sentence_pb2.Sentence.FromString(annotations[0])
    #HTML(render_parse_tree_graphviz.parse_tree_graph(parsed_sentence))
    return parsed_sentence
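
An illustrative call, assuming sess, graph, builder, and annotator come from a MasterBuilder set up as in the other examples; the commented-out HTML(...) lines above show the equivalent notebook rendering.

parsed = inference(sess, graph, builder, annotator, u'John loves Mary .')
for token in parsed.token:
    print('%s\t%s\thead=%d' % (token.word, token.label, token.head))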
Example #24
 def _parse(self):
     with self.graph.as_default():
         num_epochs = None
         num_docs = 0
         result = []
         while True:
             tf_epochs, _, tf_documents = self.sess.run([
                 self.parser.evaluation['epochs'],
                 self.parser.evaluation['eval_metrics'],
                 self.parser.evaluation['documents']
             ])
             if len(tf_documents) > 0:
                 for item in tf_documents:
                     doc = sentence_pb2.Sentence()
                     doc.ParseFromString(item)
                     result.append(ConvertToString(doc))
             if num_epochs is None:
                 num_epochs = tf_epochs
             elif num_epochs < tf_epochs:
                 break
         return result
Example #25
  def RunFullTrainingAndInference(self,
                                  test_name,
                                  master_spec_path=None,
                                  master_spec=None,
                                  component_weights=None,
                                  unroll_using_oracle=None,
                                  num_evaluated_components=1,
                                  expected_num_actions=None,
                                  expected=None,
                                  batch_size_limit=None):
    if not master_spec:
      master_spec = self.LoadSpec(master_spec_path)

    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    gold_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
    gold_reader_strings = [
        gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
    ]

    test_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
    test_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE_2, test_doc_2)
    test_reader_strings = [
        test_doc.SerializeToString(), test_doc.SerializeToString(),
        test_doc_2.SerializeToString(), test_doc.SerializeToString()
    ]

    if batch_size_limit is not None:
      gold_reader_strings = gold_reader_strings[:batch_size_limit]
      test_reader_strings = test_reader_strings[:batch_size_limit]

    with tf.Graph().as_default():
      tf.set_random_seed(1)
      hyperparam_config = spec_pb2.GridPoint()
      builder = graph_builder.MasterBuilder(
          master_spec, hyperparam_config, pool_scope=test_name)
      target = spec_pb2.TrainTarget()
      target.name = 'testFullInference-train-%s' % test_name
      if component_weights:
        target.component_weights.extend(component_weights)
      else:
        target.component_weights.extend([0] * len(master_spec.component))
        target.component_weights[-1] = 1.0
      if unroll_using_oracle:
        target.unroll_using_oracle.extend(unroll_using_oracle)
      else:
        target.unroll_using_oracle.extend([False] * len(master_spec.component))
        target.unroll_using_oracle[-1] = True
      train = builder.add_training_from_config(target)
      oracle_trace = builder.add_training_from_config(
          target, prefix='train_traced-', trace_only=True)
      builder.add_saver()

      anno = builder.add_annotation(test_name)
      trace = builder.add_annotation(test_name + '-traced', enable_tracing=True)

      # Verifies that the summaries can be built.
      for component in builder.components:
        component.get_summaries()

      config = tf.ConfigProto(
          intra_op_parallelism_threads=0, inter_op_parallelism_threads=0)
      with self.test_session(config=config) as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        logging.info('Dry run oracle trace...')
        traces = sess.run(
            oracle_trace['traces'],
            feed_dict={oracle_trace['input_batch']: gold_reader_strings})

        # Check that the oracle traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        logging.info('Simulating training...')
        break_iter = 400
        is_resolved = False
        for i in range(0,
                       400):  # needs ~100 iterations, but is not deterministic
          cost, eval_res_val = sess.run(
              [train['cost'], train['metrics']],
              feed_dict={train['input_batch']: gold_reader_strings})
          logging.info('cost = %s', cost)
          self.assertFalse(np.isnan(cost))
          total_val = eval_res_val.reshape((-1, 2))[:, 0].sum()
          correct_val = eval_res_val.reshape((-1, 2))[:, 1].sum()
          if correct_val == total_val and not is_resolved:
            logging.info('... converged on iteration %d with (correct, total) '
                         '= (%d, %d)', i, correct_val, total_val)
            is_resolved = True
            # Run for slightly longer than convergence to help with quantized
            # weight tiebreakers.
            break_iter = i + 50

          if i == break_iter:
            break

        # If training failed, report total/correct actions for each component.
        if not expected_num_actions:
          expected_num_actions = 4 * num_evaluated_components
        if (correct_val != total_val or correct_val != expected_num_actions or
            total_val != expected_num_actions):
          for c in xrange(len(master_spec.component)):
            logging.error('component %s:\nname=%s\ntotal=%s\ncorrect=%s', c,
                          master_spec.component[c].name, eval_res_val[2 * c],
                          eval_res_val[2 * c + 1])

        assert correct_val == total_val, 'Did not converge! %d vs %d.' % (
            correct_val, total_val)

        self.assertEqual(expected_num_actions, correct_val)
        self.assertEqual(expected_num_actions, total_val)

        builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))

        logging.info('Running test.')
        logging.info('Printing annotations')
        annotations = sess.run(
            anno['annotations'],
            feed_dict={anno['input_batch']: test_reader_strings})
        logging.info('Put %d inputs in, got %d annotations out.',
                     len(test_reader_strings), len(annotations))

        # Also run the annotation graph with tracing enabled.
        annotations_with_trace, traces = sess.run(
            [trace['annotations'], trace['traces']],
            feed_dict={trace['input_batch']: test_reader_strings})

        # The result of the two annotation graphs should be identical.
        self.assertItemsEqual(annotations, annotations_with_trace)

        # Check that the inference traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        self.assertEqual(len(test_reader_strings), len(annotations))
        pred_sentences = []
        for annotation in annotations:
          pred_sentences.append(sentence_pb2.Sentence())
          pred_sentences[-1].ParseFromString(annotation)

        if expected is None:
          expected = _TAGGER_EXPECTED_SENTENCES

        expected_sentences = [expected[i] for i in [0, 0, 1, 0]]

        for i, pred_sentence in enumerate(pred_sentences):
          self.assertProtoEquals(expected_sentences[i], pred_sentence)
Example #26
def main(argv):
  del argv  # unused
  # Constructs lexical resources for SyntaxNet in the given resource path, from
  # the training data.
  lexicon.build_lexicon(
      lexicon_dir,
      training_sentence,
      training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
  # sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  # Construct the ComponentSpec for parsing.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(
      name='FeedForwardNetwork',
      hidden_layer_sizes='256',
      layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32,
      source_layer='logits')

  # Recurrent connection for the arc-standard parser. For both tokens on the
  # stack, we connect to the last time step to either SHIFT or REDUCE that
  # token. This allows the parser to build up compositional representations of
  # phrases.
  parser.add_link(
      source=parser,  # recurrent connection
      name='rnn-stack',  # unique identifier
      fml='stack.focus stack(1).focus',  # look for both stack tokens
      source_translator='shift-reduce-step',  # maps token indices -> step
      embedding_dim=32)  # project down to 32 dims
  parser.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec, parser.spec])

  hyperparam_config = spec_pb2.GridPoint()

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True, True])
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data.
  sentence = sentence_pb2.Sentence()
  text_format.Merge(open(training_sentence).read(), sentence)
  training_set = [sentence.SerializeToString()]

  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.initialize_all_variables())
    traces = sess.run(
        dry_run['traces'], feed_dict={dry_run['input_batch']: training_set})

  with open('dragnn_tutorial_2.html', 'w') as f:
    f.write(
        visualization.trace_html(
            traces[0], height='400px', master_spec=master_spec).encode('utf-8'))
Example #27
 def add_sentence_for_segment_eval(starts, ends, corpus):
     """Adds a sentence to the corpus."""
     sentence = sentence_pb2.Sentence()
     for start, end in zip(starts, ends):
         sentence.token.add(word='x', start=start, end=end)
     corpus.append(sentence.SerializeToString())
Example #28
def main(unused_argv):

    # Parse the flags containing lists, using regular expressions.
    # This matches and extracts key=value pairs.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    # This matches strings separated by a comma. Does not return any empty
    # strings.
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    ## SEGMENTATION ##

    if not FLAGS.use_gold_segmentation:

        # Reads master spec.
        master_spec = spec_pb2.MasterSpec()
        with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
            text_format.Parse(fin.read(), master_spec)

        if FLAGS.complete_master_spec:
            spec_builder.complete_master_spec(master_spec, None,
                                              FLAGS.segmenter_resource_dir)

        # Graph building.
        tf.logging.info('Building the graph')
        g = tf.Graph()
        with g.as_default(), tf.device('/device:CPU:0'):
            hyperparam_config = spec_pb2.GridPoint()
            hyperparam_config.use_moving_average = True
            builder = graph_builder.MasterBuilder(master_spec,
                                                  hyperparam_config)
            annotator = builder.add_annotation()
            builder.add_saver()

        tf.logging.info('Reading documents...')
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()
        with tf.Session(graph=tf.Graph()) as tmp_session:
            char_input = gen_parser_ops.char_token_generator(input_corpus)
            char_corpus = tmp_session.run(char_input)
        check.Eq(len(input_corpus), len(char_corpus))

        session_config = tf.ConfigProto(
            log_device_placement=False,
            intra_op_parallelism_threads=FLAGS.threads,
            inter_op_parallelism_threads=FLAGS.threads)

        with tf.Session(graph=g, config=session_config) as sess:
            tf.logging.info('Initializing variables...')
            sess.run(tf.global_variables_initializer())
            tf.logging.info('Loading from checkpoint...')
            sess.run('save/restore_all',
                     {'save/Const:0': FLAGS.segmenter_checkpoint_file})

            tf.logging.info('Processing sentences...')

            processed = []
            start_time = time.time()
            run_metadata = tf.RunMetadata()
            for start in range(0, len(char_corpus), FLAGS.max_batch_size):
                end = min(start + FLAGS.max_batch_size, len(char_corpus))
                feed_dict = {annotator['input_batch']: char_corpus[start:end]}
                if FLAGS.timeline_output_file and end == len(char_corpus):
                    serialized_annotations = sess.run(
                        annotator['annotations'],
                        feed_dict=feed_dict,
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                    trace = timeline.Timeline(
                        step_stats=run_metadata.step_stats)
                    with open(FLAGS.timeline_output_file, 'w') as trace_file:
                        trace_file.write(trace.generate_chrome_trace_format())
                else:
                    serialized_annotations = sess.run(annotator['annotations'],
                                                      feed_dict=feed_dict)
                processed.extend(serialized_annotations)

            tf.logging.info('Processed %d documents in %.2f seconds.',
                            len(char_corpus),
                            time.time() - start_time)

        input_corpus = processed
    else:
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()

    ## PARSING ##

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
        text_format.Parse(fin.read(), master_spec)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.parser_resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')

    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=FLAGS.threads,
                                    inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all',
                 {'save/Const:0': FLAGS.parser_checkpoint_file})

        tf.logging.info('Processing sentences...')

        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        for start in range(0, len(input_corpus), FLAGS.max_batch_size):
            end = min(start + FLAGS.max_batch_size, len(input_corpus))
            feed_dict = {annotator['input_batch']: input_corpus[start:end]}
            for comp, beam_size in component_beam_sizes:
                feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
            for comp in components_to_locally_normalize:
                feed_dict['%s/LocallyNormalize:0' % comp] = True
            if FLAGS.timeline_output_file and end == len(input_corpus):
                serialized_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(annotator['annotations'],
                                                  feed_dict=feed_dict)
            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_corpus),
                        time.time() - start_time)

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as f:
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    f.write('#' + sentence.text.encode('utf-8') + '\n')
                    for i, token in enumerate(sentence.token):
                        head = token.head + 1
                        f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
                                (i + 1, token.word.encode('utf-8'), head,
                                 token.label.encode('utf-8')))
                    f.write('\n\n')
Example #29
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Parse the flags containing lists, using regular expressions.
    # This matches and extracts key=value pairs.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    # This matches strings separated by a comma. Does not return any empty
    # strings.
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.master_spec) as fin:
        text_format.Parse(fin.read(), master_spec)

    # Rewrite resource locations.
    if FLAGS.resource_dir:
        for component in master_spec.component:
            for resource in component.resource:
                for part in resource.part:
                    part.file_pattern = os.path.join(FLAGS.resource_dir,
                                                     part.file_pattern)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=FLAGS.threads,
                                    inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

        tf.logging.info('Processing sentences...')

        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        for start in range(0, len(input_corpus), FLAGS.max_batch_size):
            end = min(start + FLAGS.max_batch_size, len(input_corpus))
            feed_dict = {annotator['input_batch']: input_corpus[start:end]}
            for comp, beam_size in component_beam_sizes:
                feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
            for comp in components_to_locally_normalize:
                feed_dict['%s/LocallyNormalize:0' % comp] = True
            if FLAGS.timeline_output_file and end == len(input_corpus):
                serialized_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(annotator['annotations'],
                                                  feed_dict=feed_dict)
            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_corpus),
                        time.time() - start_time)
        pos, uas, las = evaluation.calculate_parse_metrics(
            input_corpus, processed)
        if FLAGS.log_file:
            with gfile.GFile(FLAGS.log_file, 'w') as f:
                f.write('%s\t%f\t%f\t%f\n' %
                        (FLAGS.language_name, pos, uas, las))

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as f:
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    f.write(text_format.MessageToString(sentence) + '\n\n')