Code Example #1
File: evaluator.py Project: dhanya1/full_cyclist
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Parse the flags containing lists, using regular expressions.
    # This matches and extracts key=value pairs.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    # This matches strings separated by a comma. Does not return any empty
    # strings.
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.master_spec) as fin:
        text_format.Parse(fin.read(), master_spec)

    # Rewrite resource locations.
    if FLAGS.resource_dir:
        for component in master_spec.component:
            for resource in component.resource:
                for part in resource.part:
                    part.file_pattern = os.path.join(FLAGS.resource_dir,
                                                     part.file_pattern)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=FLAGS.threads,
                                    inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

        tf.logging.info('Processing sentences...')

        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        for start in range(0, len(input_corpus), FLAGS.max_batch_size):
            end = min(start + FLAGS.max_batch_size, len(input_corpus))
            feed_dict = {annotator['input_batch']: input_corpus[start:end]}
            for comp, beam_size in component_beam_sizes:
                feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
            for comp in components_to_locally_normalize:
                feed_dict['%s/LocallyNormalize:0' % comp] = True
            if FLAGS.timeline_output_file and end == len(input_corpus):
                serialized_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(annotator['annotations'],
                                                  feed_dict=feed_dict)
            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_corpus),
                        time.time() - start_time)
        pos, uas, las = evaluation.calculate_parse_metrics(
            input_corpus, processed)
        if FLAGS.log_file:
            with gfile.GFile(FLAGS.log_file, 'w') as f:
                f.write('%s\t%f\t%f\t%f\n' %
                        (FLAGS.language_name, pos, uas, las))

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as f:
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    f.write(text_format.MessageToString(sentence) + '\n\n')
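The two regular expressions at the top of this example convert comma-separated flag strings into Python values. A minimal standalone sketch, using hypothetical values in place of FLAGS.inference_beam_size and FLAGS.locally_normalize:

import re

# Hypothetical flag values, for illustration only.
inference_beam_size = 'tagger=8,parser=16'
locally_normalize = 'tagger,parser'

# Extracts (component, beam size) pairs; the sizes come back as strings.
print(re.findall(r'([^=,]+)=(\d+)', inference_beam_size))
# [('tagger', '8'), ('parser', '16')]

# Splits on commas without returning any empty strings.
print(re.findall(r'[^,]+', locally_normalize))
# ['tagger', 'parser']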
Code Example #2
  def testCalculateParseMetrics(self):
    pos, uas, las = evaluation.calculate_parse_metrics(
        self._gold_corpus, self._test_corpus)
    self.assertEqual(75, pos)
    self.assertEqual(50, uas)
    self.assertEqual(25, las)
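The three returned values are percentages: pos is part-of-speech tagging accuracy, uas is the unlabeled attachment score (tokens with the correct head), and las is the labeled attachment score (correct head and correct dependency label). The expected numbers are consistent with, for example, a four-token fixture corpus in which three tokens are tagged correctly (75), two receive the correct head (50), and one receives the correct head and label (25); the actual fixture is built in the test's setUp and is not shown here.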
Code Example #3
def main(unused_argv):

  # Parse the flags containing lists, using regular expressions.
  # This matches and extracts key=value pairs.
  component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                    FLAGS.inference_beam_size)
  # This matches strings separated by a comma. Does not return any empty
  # strings.
  components_to_locally_normalize = re.findall(r'[^,]+',
                                               FLAGS.locally_normalize)

  ## SEGMENTATION ##

  if not FLAGS.use_gold_segmentation:

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
      text_format.Parse(fin.read(), master_spec)

    if FLAGS.complete_master_spec:
      spec_builder.complete_master_spec(
          master_spec, None, FLAGS.segmenter_resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
      hyperparam_config = spec_pb2.GridPoint()
      hyperparam_config.use_moving_average = True
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
      annotator = builder.add_annotation()
      builder.add_saver()

    tf.logging.info('Reading documents...')
    if FLAGS.text_format:
      char_corpus = sentence_io.FormatSentenceReader(
          FLAGS.input_file, 'untokenized-text').corpus()
    else:
      input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()
      with tf.Session(graph=tf.Graph()) as tmp_session:
        char_input = gen_parser_ops.char_token_generator(input_corpus)
        char_corpus = tmp_session.run(char_input)
      check.Eq(len(input_corpus), len(char_corpus))

    session_config = tf.ConfigProto(
        log_device_placement=False,
        intra_op_parallelism_threads=FLAGS.threads,
        inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
      tf.logging.info('Initializing variables...')
      sess.run(tf.global_variables_initializer())
      tf.logging.info('Loading from checkpoint...')
      sess.run('save/restore_all',
               {'save/Const:0': FLAGS.segmenter_checkpoint_file})

      tf.logging.info('Processing sentences...')

      processed = []
      start_time = time.time()
      run_metadata = tf.RunMetadata()
      for start in range(0, len(char_corpus), FLAGS.max_batch_size):
        end = min(start + FLAGS.max_batch_size, len(char_corpus))
        feed_dict = {annotator['input_batch']: char_corpus[start:end]}
        if FLAGS.timeline_output_file and end == len(char_corpus):
          serialized_annotations = sess.run(
              annotator['annotations'], feed_dict=feed_dict,
              options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
              run_metadata=run_metadata)
          trace = timeline.Timeline(step_stats=run_metadata.step_stats)
          with open(FLAGS.timeline_output_file, 'w') as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        else:
          serialized_annotations = sess.run(
              annotator['annotations'], feed_dict=feed_dict)
        processed.extend(serialized_annotations)

      tf.logging.info('Processed %d documents in %.2f seconds.',
                      len(char_corpus), time.time() - start_time)

    input_corpus = processed
  else:
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

  ## PARSING ##

  # Reads master spec.
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
    text_format.Parse(fin.read(), master_spec)

  if FLAGS.complete_master_spec:
    spec_builder.complete_master_spec(
        master_spec, None, FLAGS.parser_resource_dir)

  # Graph building.
  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()

  tf.logging.info('Reading documents...')

  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())

    tf.logging.info('Loading from checkpoint...')
    sess.run('save/restore_all', {'save/Const:0': FLAGS.parser_checkpoint_file})

    tf.logging.info('Processing sentences...')

    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    for start in range(0, len(input_corpus), FLAGS.max_batch_size):
      end = min(start + FLAGS.max_batch_size, len(input_corpus))
      feed_dict = {annotator['input_batch']: input_corpus[start:end]}
      for comp, beam_size in component_beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in components_to_locally_normalize:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
      if FLAGS.timeline_output_file and end == len(input_corpus):
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(FLAGS.timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict)
      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(input_corpus), time.time() - start_time)
    _, uas, las = evaluation.calculate_parse_metrics(input_corpus, processed)
    tf.logging.info('UAS: %.2f', uas)
    tf.logging.info('LAS: %.2f', las)

    if FLAGS.output_file:
      with gfile.GFile(FLAGS.output_file, 'w') as f:
        f.write('## tf:{}\n'.format(FLAGS.text_format))
        f.write('## gs:{}\n'.format(FLAGS.use_gold_segmentation))
        for serialized_sentence in processed:
          sentence = sentence_pb2.Sentence()
          sentence.ParseFromString(serialized_sentence)
          f.write('# text = {}\n'.format(sentence.text.encode('utf-8')))
          for i, token in enumerate(sentence.token):
            head = token.head + 1
            f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n'%(
                i + 1,
                token.word.encode('utf-8'), head,
                token.label.encode('utf-8')))
          f.write('\n')
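The output loop above writes one CoNLL-style block per sentence: a '# text = ...' comment line followed by a 10-column, tab-separated line per token (1-based index, word form, four unused columns, head index, dependency label, two unused columns). Since token.head is -1 for the root token in the Sentence proto, head = token.head + 1 gives 0 for the root and 1-based indices for everything else. A hypothetical two-token block, with illustrative labels only, would look like:

# text = Hello world
1	Hello	_	_	_	_	0	ROOT	_	_
2	world	_	_	_	_	1	dep	_	_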
Code Example #4
def run_parser(input_data, parser_model, session_config, beam_sizes,
               locally_normalized_components, max_batch_size,
               timeline_output_file):
    """Runs the provided segmenter model on the provided character corpus.

  Args:
    input_data: Input corpus to parse.
    parser_model: Path to a SavedModel file containing the parser graph.
    session_config: A session configuration object.
    beam_sizes: A dict of component names : beam sizes (optional).
    locally_normalized_components: A list of components to normalize (optional).
    max_batch_size: The maximum batch size to use.
    timeline_output_file: Filepath for timeline export. Does not export if None.

  Returns:
    A list of parsed sentences.
  """
    parser_graph = tf.Graph()
    with tf.Session(graph=parser_graph, config=session_config) as sess:
        tf.logging.info('Initializing parser model...')
        tf.saved_model.loader.load(sess,
                                   [tf.saved_model.tag_constants.SERVING],
                                   parser_model)

        tf.logging.info('Parsing sentences...')

        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        tf.logging.info('Corpus length is %d' % len(input_data))
        for start in range(0, len(input_data), max_batch_size):
            # Set up the input and output.
            end = min(start + max_batch_size, len(input_data))
            feed_dict = {
                'annotation/ComputeSession/InputBatch:0': input_data[start:end]
            }
            for comp, beam_size in beam_sizes:
                feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
            for comp in locally_normalized_components:
                feed_dict['%s/LocallyNormalize:0' % comp] = True
            output_node = 'annotation/annotations:0'

            # Process.
            tf.logging.info('Processing examples %d to %d' % (start, end))
            if timeline_output_file and end == len(input_data):
                serialized_annotations = sess.run(
                    output_node,
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(output_node,
                                                  feed_dict=feed_dict)

            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_data),
                        time.time() - start_time)
        _, uas, las = evaluation.calculate_parse_metrics(input_data, processed)
        tf.logging.info('UAS: %.2f', uas)
        tf.logging.info('LAS: %.2f', las)

    return processed
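A minimal, hypothetical invocation of run_parser; the model path, component name, and corpus variable are placeholders, and beam_sizes is passed as (component, size) pairs because that is how the function iterates it:

session_config = tf.ConfigProto(
    log_device_placement=False,
    intra_op_parallelism_threads=4,
    inter_op_parallelism_threads=4)

parsed = run_parser(
    input_data=serialized_sentences,        # list of serialized Sentence protos
    parser_model='/path/to/parser_saved_model',
    session_config=session_config,
    beam_sizes=[('parser', 8)],
    locally_normalized_components=['parser'],
    max_batch_size=1024,
    timeline_output_file=None)              # skip the Chrome trace export

Note that run_parser also computes UAS/LAS against input_data, so the input corpus is expected to carry gold heads and labels.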
Code Example #5
File: parse_to_conll.py Project: NoPointExc/models
def run_parser(input_data, parser_model, session_config, beam_sizes,
               locally_normalized_components, max_batch_size,
               timeline_output_file):
  """Runs the provided segmenter model on the provided character corpus.

  Args:
    input_data: Input corpus to parse.
    parser_model: Path to a SavedModel file containing the parser graph.
    session_config: A session configuration object.
    beam_sizes: A list of (component name, beam size) pairs (optional).
    locally_normalized_components: A list of components to normalize (optional).
    max_batch_size: The maximum batch size to use.
    timeline_output_file: Filepath for timeline export. Does not export if None.

  Returns:
    A list of parsed sentences.
  """
  parser_graph = tf.Graph()
  with tf.Session(graph=parser_graph, config=session_config) as sess:
    tf.logging.info('Initializing parser model...')
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
                               parser_model)

    tf.logging.info('Parsing sentences...')

    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    tf.logging.info('Corpus length is %d' % len(input_data))
    for start in range(0, len(input_data), max_batch_size):
      # Set up the input and output.
      end = min(start + max_batch_size, len(input_data))
      feed_dict = {
          'annotation/ComputeSession/InputBatch:0': input_data[start:end]
      }
      for comp, beam_size in beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in locally_normalized_components:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
      output_node = 'annotation/annotations:0'

      # Process.
      tf.logging.info('Processing examples %d to %d' % (start, end))
      if timeline_output_file and end == len(input_data):
        serialized_annotations = sess.run(
            output_node,
            feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(output_node, feed_dict=feed_dict)

      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(input_data), time.time() - start_time)
    _, uas, las = evaluation.calculate_parse_metrics(input_data, processed)
    tf.logging.info('UAS: %.2f', uas)
    tf.logging.info('LAS: %.2f', las)

  return processed
Code Example #6
File: evaluator.py Project: knathanieltucker/models
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  # Parse the flags containing lists, using regular expressions.
  # This matches and extracts key=value pairs.
  component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                    FLAGS.inference_beam_size)
  # This matches strings separated by a comma. Does not return any empty
  # strings.
  components_to_locally_normalize = re.findall(r'[^,]+',
                                               FLAGS.locally_normalize)

  # Reads master spec.
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.master_spec) as fin:
    text_format.Parse(fin.read(), master_spec)

  # Rewrite resource locations.
  if FLAGS.resource_dir:
    for component in master_spec.component:
      for resource in component.resource:
        for part in resource.part:
          part.file_pattern = os.path.join(FLAGS.resource_dir,
                                           part.file_pattern)

  if FLAGS.complete_master_spec:
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_dir)

  # Graph building.
  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()

  tf.logging.info('Reading documents...')
  input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())

    tf.logging.info('Loading from checkpoint...')
    sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

    tf.logging.info('Processing sentences...')

    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    for start in range(0, len(input_corpus), FLAGS.max_batch_size):
      end = min(start + FLAGS.max_batch_size, len(input_corpus))
      feed_dict = {annotator['input_batch']: input_corpus[start:end]}
      for comp, beam_size in component_beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in components_to_locally_normalize:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
      if FLAGS.timeline_output_file and end == len(input_corpus):
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(FLAGS.timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict)
      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(input_corpus), time.time() - start_time)
    pos, uas, las = evaluation.calculate_parse_metrics(input_corpus, processed)

    if FLAGS.output_file:
      with gfile.GFile(FLAGS.output_file, 'w') as f:
        for serialized_sentence in processed:
          sentence = sentence_pb2.Sentence()
          sentence.ParseFromString(serialized_sentence)
          f.write(text_format.MessageToString(sentence) + '\n\n')
Code Example #7
def main(unused_argv):

    # Parse the flags containing lists, using regular expressions.
    # This matches and extracts key=value pairs.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    # This matches strings separated by a comma. Does not return any empty
    # strings.
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    ## SEGMENTATION ##

    if not FLAGS.use_gold_segmentation:

        # Reads master spec.
        master_spec = spec_pb2.MasterSpec()
        with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
            text_format.Parse(fin.read(), master_spec)

        if FLAGS.complete_master_spec:
            spec_builder.complete_master_spec(master_spec, None,
                                              FLAGS.segmenter_resource_dir)

        # Graph building.
        tf.logging.info('Building the graph')
        g = tf.Graph()
        with g.as_default(), tf.device('/device:CPU:0'):
            hyperparam_config = spec_pb2.GridPoint()
            hyperparam_config.use_moving_average = True
            builder = graph_builder.MasterBuilder(master_spec,
                                                  hyperparam_config)
            annotator = builder.add_annotation()
            builder.add_saver()

        tf.logging.info('Reading documents...')
        if FLAGS.text_format:
            char_corpus = sentence_io.FormatSentenceReader(
                FLAGS.input_file, 'untokenized-text').corpus()
        else:
            input_corpus = sentence_io.ConllSentenceReader(
                FLAGS.input_file).corpus()
            with tf.Session(graph=tf.Graph()) as tmp_session:
                char_input = gen_parser_ops.char_token_generator(input_corpus)
                char_corpus = tmp_session.run(char_input)
            check.Eq(len(input_corpus), len(char_corpus))

        session_config = tf.ConfigProto(
            log_device_placement=False,
            intra_op_parallelism_threads=FLAGS.threads,
            inter_op_parallelism_threads=FLAGS.threads)

        with tf.Session(graph=g, config=session_config) as sess:
            tf.logging.info('Initializing variables...')
            sess.run(tf.global_variables_initializer())
            tf.logging.info('Loading from checkpoint...')
            sess.run('save/restore_all',
                     {'save/Const:0': FLAGS.segmenter_checkpoint_file})

            tf.logging.info('Processing sentences...')

            processed = []
            start_time = time.time()
            run_metadata = tf.RunMetadata()
            for start in range(0, len(char_corpus), FLAGS.max_batch_size):
                end = min(start + FLAGS.max_batch_size, len(char_corpus))
                feed_dict = {annotator['input_batch']: char_corpus[start:end]}
                if FLAGS.timeline_output_file and end == len(char_corpus):
                    serialized_annotations = sess.run(
                        annotator['annotations'],
                        feed_dict=feed_dict,
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                    trace = timeline.Timeline(
                        step_stats=run_metadata.step_stats)
                    with open(FLAGS.timeline_output_file, 'w') as trace_file:
                        trace_file.write(trace.generate_chrome_trace_format())
                else:
                    serialized_annotations = sess.run(annotator['annotations'],
                                                      feed_dict=feed_dict)
                processed.extend(serialized_annotations)

            tf.logging.info('Processed %d documents in %.2f seconds.',
                            len(char_corpus),
                            time.time() - start_time)

        input_corpus = processed
    else:
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()

    ## PARSING ##

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
        text_format.Parse(fin.read(), master_spec)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.parser_resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')

    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=FLAGS.threads,
                                    inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all',
                 {'save/Const:0': FLAGS.parser_checkpoint_file})

        tf.logging.info('Processing sentences...')

        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        for start in range(0, len(input_corpus), FLAGS.max_batch_size):
            end = min(start + FLAGS.max_batch_size, len(input_corpus))
            feed_dict = {annotator['input_batch']: input_corpus[start:end]}
            for comp, beam_size in component_beam_sizes:
                feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
            for comp in components_to_locally_normalize:
                feed_dict['%s/LocallyNormalize:0' % comp] = True
            if FLAGS.timeline_output_file and end == len(input_corpus):
                serialized_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(annotator['annotations'],
                                                  feed_dict=feed_dict)
            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_corpus),
                        time.time() - start_time)
        _, uas, las = evaluation.calculate_parse_metrics(
            input_corpus, processed)
        tf.logging.info('UAS: %.2f', uas)
        tf.logging.info('LAS: %.2f', las)

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as f:
                f.write('## tf:{}\n'.format(FLAGS.text_format))
                f.write('## gs:{}\n'.format(FLAGS.use_gold_segmentation))
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    f.write('# text = {}\n'.format(
                        sentence.text.encode('utf-8')))
                    for i, token in enumerate(sentence.token):
                        head = token.head + 1
                        f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
                                (i + 1, token.word.encode('utf-8'), head,
                                 token.label.encode('utf-8')))
                    f.write('\n')