def _load_model(self, base_dir, master_spec_name):
        master_spec = spec_pb2.MasterSpec()
        with open(os.path.join(base_dir, master_spec_name)) as f:
            text_format.Merge(f.read(), master_spec)
        spec_builder.complete_master_spec(master_spec, None, base_dir)

        graph = tf.Graph()
        with graph.as_default():
            hyperparam_config = spec_pb2.GridPoint()
            builder = graph_builder.MasterBuilder(
                master_spec,
                hyperparam_config
            )
            annotator = builder.add_annotation(enable_tracing=True)
            builder.add_saver()

        sess = tf.Session(graph=graph)
        with graph.as_default():
            builder.saver.restore(sess, os.path.join(base_dir, "checkpoint"))

        def annotate_sentence(sentence):
            with graph.as_default():
                return sess.run(
                    [annotator['annotations'], annotator['traces']],
                    feed_dict={annotator['input_batch']: [sentence]}
                )
        return annotate_sentence

    def load_model(self, base_dir, master_spec_name, checkpoint_name):
        # Read the master spec
        master_spec = spec_pb2.MasterSpec()
        with open(os.path.join(base_dir, master_spec_name), "r") as f:
            text_format.Merge(f.read(), master_spec)
        spec_builder.complete_master_spec(master_spec, None, base_dir)
        logging.set_verbosity(logging.WARN)  # Turn off TensorFlow spam.

        # Initialize a graph
        graph = tf.Graph()
        with graph.as_default():
            hyperparam_config = spec_pb2.GridPoint()
            builder = graph_builder.MasterBuilder(master_spec,
                                                  hyperparam_config)
            # This is the component that will annotate test sentences.
            annotator = builder.add_annotation(enable_tracing=True)
            # "Savers" can save and load models; here, we're only going to load.
            builder.add_saver()

        sess = tf.Session(graph=graph)
        with graph.as_default():
            # sess.run(tf.global_variables_initializer())
            # sess.run('save/restore_all', {'save/Const:0': os.path.join(base_dir, checkpoint_name)})
            builder.saver.restore(sess, os.path.join(base_dir,
                                                     checkpoint_name))

        def annotate_sentence(sentence):
            with graph.as_default():
                return sess.run(
                    [annotator['annotations'], annotator['traces']],
                    feed_dict={annotator['input_batch']: [sentence]})

        return annotate_sentence
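
For reference, the returned annotate_sentence closure is fed serialized Sentence protos through annotator['input_batch']. A minimal usage sketch, assuming the method lives on a wrapper object called parser here and that the model directory layout matches the arguments above:

from syntaxnet import sentence_pb2

# Build a trivial one-token Sentence proto; the input batch takes serialized protos.
sentence = sentence_pb2.Sentence()
sentence.text = 'Hello'
token = sentence.token.add()
token.word = 'Hello'
token.start = 0
token.end = 4

annotate = parser.load_model('/path/to/model', 'master_spec', 'checkpoint')
annotations, traces = annotate(sentence.SerializeToString())

# Each element of annotations is a serialized, annotated Sentence proto.
parsed = sentence_pb2.Sentence()
parsed.ParseFromString(annotations[0])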
Example #3
    def load_model(self, base_dir, master_spec_name, checkpoint_name="checkpoint", rename=True):
        try:
            master_spec = spec_pb2.MasterSpec()
            with open(os.path.join(base_dir, master_spec_name)) as f:
                text_format.Merge(f.read(), master_spec)
            spec_builder.complete_master_spec(master_spec, None, base_dir)

            graph = tf.Graph()
            with graph.as_default():
                hyperparam_config = spec_pb2.GridPoint()
                builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
                annotator = builder.add_annotation(enable_tracing=True)
                builder.add_saver()

            sess = tf.Session(graph=graph)
            with graph.as_default():
                builder.saver.restore(sess, os.path.join(base_dir, checkpoint_name))

            def annotate_sentence(sentence):
                with graph.as_default():
                    return sess.run([annotator['annotations'], annotator['traces']],
                                    feed_dict={annotator['input_batch']: [sentence]})
        except Exception:
            if rename:
                self.rename_vars(base_dir, checkpoint_name)
                return self.load_model(base_dir, master_spec_name, checkpoint_name, False)
            raise Exception(
                'Cannot load model: spec expects references to */kernel tensors instead of '
                '*/weights. Try running with rename=True or run rename_vars() to convert '
                'existing checkpoint files into the supported format.')

        return annotate_sentence
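
rename_vars() is not shown above; the error message indicates it rewrites checkpoint variables named */weights to the */kernel names the spec expects. One possible sketch of such a conversion with TF 1.x checkpoint utilities (not the original implementation):

import os
import tensorflow as tf

def rename_vars(base_dir, checkpoint_name):
    # Read every variable from the old checkpoint, rename '/weights' suffixes
    # to '/kernel', and write the result back to the same checkpoint path.
    checkpoint_path = os.path.join(base_dir, checkpoint_name)
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    with tf.Graph().as_default(), tf.Session() as sess:
        new_vars = []
        for old_name in reader.get_variable_to_shape_map():
            value = reader.get_tensor(old_name)
            new_vars.append(tf.Variable(value, name=old_name.replace('/weights', '/kernel')))
        sess.run(tf.global_variables_initializer())
        tf.train.Saver(new_vars).save(sess, checkpoint_path)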
Example #4
def load_master_spec(spec_file, resource_path):
    tf.logging.info('Loading MasterSpec...')
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(spec_file, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))
    return master_spec
Example #5
def load_master_spec(spec_file, resource_path):
    tf.logging.info('Loading MasterSpec...')
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(spec_file, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))
    return master_spec
def export(master_spec_path, params_path, resource_path, export_path,
           export_moving_averages):
    """Restores a model and exports it in SavedModel form.

  This method loads a graph specified by the spec at master_spec_path and the
  params in params_path. It then saves the model in SavedModel format to the
  location specified in export_path.

  Args:
    master_spec_path: Path to a proto-text master spec.
    params_path: Path to the parameters file to export.
    resource_path: Path to resources in the master spec.
    export_path: Path to export the SavedModel to.
    export_moving_averages: Whether to export the moving average parameters.
  """
    # Old CoNLL checkpoints did not need a known-word-map. Create a temporary if
    # that file is missing.
    if not tf.gfile.Exists(os.path.join(resource_path, 'known-word-map')):
        with tf.gfile.FastGFile(os.path.join(resource_path, 'known-word-map'),
                                'w') as out_file:
            out_file.write('This file intentionally left blank.')

    graph = tf.Graph()
    master_spec = spec_pb2.MasterSpec()
    with tf.gfile.FastGFile(master_spec_path) as fin:
        text_format.Parse(fin.read(), master_spec)

    # This is a workaround for an issue where the segmenter master-spec had a
    # spurious resource in it; this resource was not respected in the spec-builder
    # and ended up crashing the saver (since it didn't really exist).
    for component in master_spec.component:
        del component.resource[:]

    spec_builder.complete_master_spec(master_spec, None, resource_path)

    # Remove '/' if it exists at the end of the export path, ensuring that
    # path utils work correctly.
    stripped_path = export_path.rstrip('/')
    saver_lib.clean_output_paths(stripped_path)

    short_to_original = saver_lib.shorten_resource_paths(master_spec)
    saver_lib.export_master_spec(master_spec, graph)
    saver_lib.export_to_graph(master_spec, params_path, stripped_path, graph,
                              export_moving_averages)
    saver_lib.export_assets(master_spec, short_to_original, stripped_path)
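
The docstring above spells out the arguments; a minimal invocation sketch (the paths below are placeholders, not paths from any particular repository):

export(
    master_spec_path='/path/to/parser_spec.textproto',
    params_path='/path/to/checkpoint',
    resource_path='/path/to/resources',
    export_path='/tmp/parser_savedmodel',
    export_moving_averages=True)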
Example #8
def main(unused_argv):
    logging.set_verbosity(logging.INFO)
    check.IsTrue(FLAGS.checkpoint_filename)
    check.IsTrue(FLAGS.tensorboard_dir)
    check.IsTrue(FLAGS.resource_path)

    if not gfile.IsDirectory(FLAGS.resource_path):
        gfile.MakeDirs(FLAGS.resource_path)

    training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
    tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

    # SummaryWriter for TensorBoard
    tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
    tf.logging.info('Deleting prior data if exists...')

    stats_file = '%s.stats' % FLAGS.checkpoint_filename
    try:
        stats = gfile.GFile(stats_file, 'r').readlines()[0].split(',')
        stats = [int(x) for x in stats]
    except errors.OpError:
        stats = [-1, 0, 0]

    tf.logging.info('Read ckpt stats: %s', str(stats))
    do_restore = True
    if stats[0] < FLAGS.job_id:
        do_restore = False
        tf.logging.info('Deleting last job: %d', stats[0])
        try:
            gfile.DeleteRecursively(FLAGS.tensorboard_dir)
            gfile.Remove(FLAGS.checkpoint_filename)
        except errors.OpError as err:
            tf.logging.error('Unable to delete prior files: %s', err)
        stats = [FLAGS.job_id, 0, 0]

    tf.logging.info('Creating the directory again...')
    gfile.MakeDirs(FLAGS.tensorboard_dir)
    tf.logging.info('Created! Instantiating SummaryWriter...')
    summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)
    tf.logging.info('Creating TensorFlow checkpoint dir...')
    gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    if FLAGS.compute_lexicon:
        logging.info('Computing lexicon...')
        lexicon.build_lexicon(FLAGS.resource_path,
                              training_corpus_path,
                              morph_to_pos=True)

    tf.logging.info('Loading MasterSpec...')
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.dragnn_spec, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))

    # Build the TensorFlow graph.
    tf.logging.info('Building Graph...')
    hyperparam_config = spec_pb2.GridPoint()
    try:
        text_format.Parse(FLAGS.hyperparams, hyperparam_config)
    except text_format.ParseError:
        text_format.Parse(base64.b64decode(FLAGS.hyperparams),
                          hyperparam_config)
    g = tf.Graph()
    with g.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        component_targets = [
            spec_pb2.TrainTarget(name=component.name,
                                 max_index=idx + 1,
                                 unroll_using_oracle=[False] * idx + [True])
            for idx, component in enumerate(master_spec.component)
            if 'shift-only' not in component.transition_system.registered_name
        ]
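        # For example, if master_spec.component held just ['tagger', 'parser']
        # (neither of them shift-only), this would yield:
        #   TrainTarget(name='tagger', max_index=1, unroll_using_oracle=[True])
        #   TrainTarget(name='parser', max_index=2, unroll_using_oracle=[False, True])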
        trainers = [
            builder.add_training_from_config(target)
            for target in component_targets
        ]
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    training_set = ConllSentenceReader(
        training_corpus_path,
        projectivize=FLAGS.projectivize_training_set,
        morph_to_pos=True).corpus()
    tune_set = ConllSentenceReader(tune_corpus_path,
                                   projectivize=False,
                                   morph_to_pos=True).corpus()

    # Ready to train!
    logging.info('Training on %d sentences.', len(training_set))
    logging.info('Tuning on %d sentences.', len(tune_set))

    pretrain_steps = [10000, 0]
    tagger_steps = 100000
    train_steps = [tagger_steps, 8 * tagger_steps]

    with tf.Session(FLAGS.tf_master, graph=g) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())

        if do_restore:
            tf.logging.info('Restoring from checkpoint...')
            builder.saver.restore(sess, FLAGS.checkpoint_filename)

            prev_tagger_steps = stats[1]
            prev_parser_steps = stats[2]
            tf.logging.info('adjusting schedule from steps: %d, %d',
                            prev_tagger_steps, prev_parser_steps)
            pretrain_steps[0] = max(pretrain_steps[0] - prev_tagger_steps, 0)
            tf.logging.info('new pretrain steps: %d', pretrain_steps[0])

        trainer_lib.run_training(sess, trainers, annotator,
                                 evaluation.parser_summaries, pretrain_steps,
                                 train_steps, training_set, tune_set, tune_set,
                                 FLAGS.batch_size, summary_writer,
                                 FLAGS.report_every, builder.saver,
                                 FLAGS.checkpoint_filename, stats)
Example #9
def build_complete_master_spec(resource_path):
    tf.logging.info('Building MasterSpec...')
    master_spec = build_master_spec()
    spec_builder.complete_master_spec(master_spec, None, resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))
    return master_spec
Example #10
def main(unused_argv):

  # Parse the flags containing lists, using regular expressions.
  # This matches and extracts key=value pairs.
  component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                    FLAGS.inference_beam_size)
  # This matches strings separated by a comma. Does not return any empty
  # strings.
  components_to_locally_normalize = re.findall(r'[^,]+',
                                               FLAGS.locally_normalize)

  # Reads master spec.
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.master_spec) as fin:
    text_format.Parse(fin.read(), master_spec)

  # Rewrite resource locations.
  if FLAGS.resource_dir:
    for component in master_spec.component:
      for resource in component.resource:
        for part in resource.part:
          part.file_pattern = os.path.join(FLAGS.resource_dir,
                                           part.file_pattern)

  if FLAGS.complete_master_spec:
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_dir)

  # Graph building.
  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()

  tf.logging.info('Reading documents...')
  input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()
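  # The segmenter consumes character-level tokens, so convert the word-level
  # CoNLL sentences into character-based Sentence protos in a throwaway session.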
  with tf.Session(graph=tf.Graph()) as tmp_session:
    char_input = gen_parser_ops.char_token_generator(input_corpus)
    char_corpus = tmp_session.run(char_input)
  check.Eq(len(input_corpus), len(char_corpus))

  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())

    tf.logging.info('Loading from checkpoint...')
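    # Run the restore op of the tf.train.Saver created by builder.add_saver():
    # 'save/restore_all' is that op, and 'save/Const:0' is the filename tensor
    # from which it reads the checkpoint path.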
    sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

    tf.logging.info('Processing sentences...')

    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    for start in range(0, len(char_corpus), FLAGS.max_batch_size):
      end = min(start + FLAGS.max_batch_size, len(char_corpus))
      feed_dict = {annotator['input_batch']: char_corpus[start:end]}
      for comp, beam_size in component_beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in components_to_locally_normalize:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
      if FLAGS.timeline_output_file and end == len(char_corpus):
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(FLAGS.timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict)
      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(char_corpus), time.time() - start_time)
    evaluation.calculate_segmentation_metrics(input_corpus, processed)

    if FLAGS.output_file:
      with gfile.GFile(FLAGS.output_file, 'w') as f:
        for serialized_sentence in processed:
          sentence = sentence_pb2.Sentence()
          sentence.ParseFromString(serialized_sentence)
          f.write(text_format.MessageToString(sentence) + '\n\n')
Example #11
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Parse the flags containing lists, using regular expressions.
    # This matches and extracts key=value pairs.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    # This matches strings separated by a comma. Does not return any empty
    # strings.
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.master_spec) as fin:
        text_format.Parse(fin.read(), master_spec)

    # Rewrite resource locations.
    if FLAGS.resource_dir:
        for component in master_spec.component:
            for resource in component.resource:
                for part in resource.part:
                    part.file_pattern = os.path.join(FLAGS.resource_dir,
                                                     part.file_pattern)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=FLAGS.threads,
                                    inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

        tf.logging.info('Processing sentences...')

        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        for start in range(0, len(input_corpus), FLAGS.max_batch_size):
            end = min(start + FLAGS.max_batch_size, len(input_corpus))
            feed_dict = {annotator['input_batch']: input_corpus[start:end]}
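            # Apply per-component decoding overrides parsed from the flags by
            # feeding beam sizes and local-normalization switches into the
            # graph's named tensors.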
            for comp, beam_size in component_beam_sizes:
                feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
            for comp in components_to_locally_normalize:
                feed_dict['%s/LocallyNormalize:0' % comp] = True
            if FLAGS.timeline_output_file and end == len(input_corpus):
                serialized_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(annotator['annotations'],
                                                  feed_dict=feed_dict)
            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_corpus),
                        time.time() - start_time)
        pos, uas, las = evaluation.calculate_parse_metrics(
            input_corpus, processed)
        if FLAGS.log_file:
            with gfile.GFile(FLAGS.log_file, 'w') as f:
                f.write('%s\t%f\t%f\t%f\n' %
                        (FLAGS.language_name, pos, uas, las))

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as f:
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    f.write(text_format.MessageToString(sentence) + '\n\n')
Example #12
File: trainer.py  Project: JiweiHe/models
def main(unused_argv):
  logging.set_verbosity(logging.INFO)
  check.IsTrue(FLAGS.checkpoint_filename)
  check.IsTrue(FLAGS.tensorboard_dir)
  check.IsTrue(FLAGS.resource_path)

  if not gfile.IsDirectory(FLAGS.resource_path):
    gfile.MakeDirs(FLAGS.resource_path)

  training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
  tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

  # SummaryWriter for TensorBoard
  tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
  tf.logging.info('Deleting prior data if exists...')

  stats_file = '%s.stats' % FLAGS.checkpoint_filename
  try:
    stats = gfile.GFile(stats_file, 'r').readlines()[0].split(',')
    stats = [int(x) for x in stats]
  except errors.OpError:
    stats = [-1, 0, 0]

  tf.logging.info('Read ckpt stats: %s', str(stats))
  do_restore = True
  if stats[0] < FLAGS.job_id:
    do_restore = False
    tf.logging.info('Deleting last job: %d', stats[0])
    try:
      gfile.DeleteRecursively(FLAGS.tensorboard_dir)
      gfile.Remove(FLAGS.checkpoint_filename)
    except errors.OpError as err:
      tf.logging.error('Unable to delete prior files: %s', err)
    stats = [FLAGS.job_id, 0, 0]

  tf.logging.info('Creating the directory again...')
  gfile.MakeDirs(FLAGS.tensorboard_dir)
  tf.logging.info('Created! Instantiating SummaryWriter...')
  summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)
  tf.logging.info('Creating TensorFlow checkpoint dir...')
  gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

  # Constructs lexical resources for SyntaxNet in the given resource path, from
  # the training data.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    lexicon.build_lexicon(
        FLAGS.resource_path, training_corpus_path, morph_to_pos=True)

  tf.logging.info('Loading MasterSpec...')
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.dragnn_spec, 'r') as fin:
    text_format.Parse(fin.read(), master_spec)
  spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_path)
  logging.info('Constructed master spec: %s', str(master_spec))

  # Build the TensorFlow graph.
  tf.logging.info('Building Graph...')
  hyperparam_config = spec_pb2.GridPoint()
  try:
    text_format.Parse(FLAGS.hyperparams, hyperparam_config)
  except text_format.ParseError:
    text_format.Parse(base64.b64decode(FLAGS.hyperparams), hyperparam_config)
  g = tf.Graph()
  with g.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    component_targets = [
        spec_pb2.TrainTarget(
            name=component.name,
            max_index=idx + 1,
            unroll_using_oracle=[False] * idx + [True])
        for idx, component in enumerate(master_spec.component)
        if 'shift-only' not in component.transition_system.registered_name
    ]
    trainers = [
        builder.add_training_from_config(target) for target in component_targets
    ]
    annotator = builder.add_annotation()
    builder.add_saver()

  # Read in serialized protos from training data.
  training_set = ConllSentenceReader(
      training_corpus_path,
      projectivize=FLAGS.projectivize_training_set,
      morph_to_pos=True).corpus()
  tune_set = ConllSentenceReader(
      tune_corpus_path, projectivize=False, morph_to_pos=True).corpus()

  # Ready to train!
  logging.info('Training on %d sentences.', len(training_set))
  logging.info('Tuning on %d sentences.', len(tune_set))

  pretrain_steps = [10000, 0]
  tagger_steps = 100000
  train_steps = [tagger_steps, 8 * tagger_steps]

  with tf.Session(FLAGS.tf_master, graph=g) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())

    if do_restore:
      tf.logging.info('Restoring from checkpoint...')
      builder.saver.restore(sess, FLAGS.checkpoint_filename)

      prev_tagger_steps = stats[1]
      prev_parser_steps = stats[2]
      tf.logging.info('adjusting schedule from steps: %d, %d',
                      prev_tagger_steps, prev_parser_steps)
      pretrain_steps[0] = max(pretrain_steps[0] - prev_tagger_steps, 0)
      tf.logging.info('new pretrain steps: %d', pretrain_steps[0])

    trainer_lib.run_training(
        sess, trainers, annotator, evaluation.parser_summaries, pretrain_steps,
        train_steps, training_set, tune_set, tune_set, FLAGS.batch_size,
        summary_writer, FLAGS.report_every, builder.saver,
        FLAGS.checkpoint_filename, stats)
Example #13
def build_complete_master_spec(resource_path):
    tf.logging.info('Building MasterSpec...')
    master_spec = build_master_spec()
    spec_builder.complete_master_spec(master_spec, None, resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))
    return master_spec
Example #14
def main(unused_argv):

    # Parse the flags containing lists, using regular expressions.
    # This matches and extracts key=value pairs.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    # This matches strings separated by a comma. Does not return any empty
    # strings.
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    ## SEGMENTATION ##

    if not FLAGS.use_gold_segmentation:

        # Reads master spec.
        master_spec = spec_pb2.MasterSpec()
        with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
            text_format.Parse(fin.read(), master_spec)

        if FLAGS.complete_master_spec:
            spec_builder.complete_master_spec(master_spec, None,
                                              FLAGS.segmenter_resource_dir)

        # Graph building.
        tf.logging.info('Building the graph')
        g = tf.Graph()
        with g.as_default(), tf.device('/device:CPU:0'):
            hyperparam_config = spec_pb2.GridPoint()
            hyperparam_config.use_moving_average = True
            builder = graph_builder.MasterBuilder(master_spec,
                                                  hyperparam_config)
            annotator = builder.add_annotation()
            builder.add_saver()

        tf.logging.info('Reading documents...')
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()
        with tf.Session(graph=tf.Graph()) as tmp_session:
            char_input = gen_parser_ops.char_token_generator(input_corpus)
            char_corpus = tmp_session.run(char_input)
        check.Eq(len(input_corpus), len(char_corpus))

        session_config = tf.ConfigProto(
            log_device_placement=False,
            intra_op_parallelism_threads=FLAGS.threads,
            inter_op_parallelism_threads=FLAGS.threads)

        with tf.Session(graph=g, config=session_config) as sess:
            tf.logging.info('Initializing variables...')
            sess.run(tf.global_variables_initializer())
            tf.logging.info('Loading from checkpoint...')
            sess.run('save/restore_all',
                     {'save/Const:0': FLAGS.segmenter_checkpoint_file})

            tf.logging.info('Processing sentences...')

            processed = []
            start_time = time.time()
            run_metadata = tf.RunMetadata()
            for start in range(0, len(char_corpus), FLAGS.max_batch_size):
                end = min(start + FLAGS.max_batch_size, len(char_corpus))
                feed_dict = {annotator['input_batch']: char_corpus[start:end]}
                if FLAGS.timeline_output_file and end == len(char_corpus):
                    serialized_annotations = sess.run(
                        annotator['annotations'],
                        feed_dict=feed_dict,
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                    trace = timeline.Timeline(
                        step_stats=run_metadata.step_stats)
                    with open(FLAGS.timeline_output_file, 'w') as trace_file:
                        trace_file.write(trace.generate_chrome_trace_format())
                else:
                    serialized_annotations = sess.run(annotator['annotations'],
                                                      feed_dict=feed_dict)
                processed.extend(serialized_annotations)

            tf.logging.info('Processed %d documents in %.2f seconds.',
                            len(char_corpus),
                            time.time() - start_time)

        input_corpus = processed
    else:
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()

    ## PARSING ##

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
        text_format.Parse(fin.read(), master_spec)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.parser_resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')

    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=FLAGS.threads,
                                    inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all',
                 {'save/Const:0': FLAGS.parser_checkpoint_file})

        tf.logging.info('Processing sentences...')

        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        for start in range(0, len(input_corpus), FLAGS.max_batch_size):
            end = min(start + FLAGS.max_batch_size, len(input_corpus))
            feed_dict = {annotator['input_batch']: input_corpus[start:end]}
            for comp, beam_size in component_beam_sizes:
                feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
            for comp in components_to_locally_normalize:
                feed_dict['%s/LocallyNormalize:0' % comp] = True
            if FLAGS.timeline_output_file and end == len(input_corpus):
                serialized_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(annotator['annotations'],
                                                  feed_dict=feed_dict)
            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_corpus),
                        time.time() - start_time)

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as f:
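                # Emit each sentence in a simplified CoNLL-style layout: ID and
                # FORM, four '_' placeholder columns, HEAD (1-based), DEPREL,
                # and two trailing placeholders.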
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    f.write('#' + sentence.text.encode('utf-8') + '\n')
                    for i, token in enumerate(sentence.token):
                        head = token.head + 1
                        f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
                                (i + 1, token.word.encode('utf-8'), head,
                                 token.label.encode('utf-8')))
                    f.write('\n\n')