def export(master_spec_path, params_path, export_path, export_moving_averages,
           build_runtime_graph):
    """Restores a model and exports it in SavedModel form.

    This method loads a graph specified by the spec at master_spec_path and the
    params in params_path. It then saves the model in SavedModel format to the
    location specified in export_path.

    Args:
      master_spec_path: Path to a proto-text master spec.
      params_path: Path to the parameters file to export.
      export_path: Path to export the SavedModel to.
      export_moving_averages: Whether to export the moving average parameters.
      build_runtime_graph: Whether to build a graph for use by the runtime.
    """

    graph = tf.Graph()
    master_spec = spec_pb2.MasterSpec()
    with tf.gfile.FastGFile(master_spec_path) as fin:
        text_format.Parse(fin.read(), master_spec)

    # Remove '/' if it exists at the end of the export path, ensuring that
    # path utils work correctly.
    stripped_path = export_path.rstrip('/')
    saver_lib.clean_output_paths(stripped_path)

    short_to_original = saver_lib.shorten_resource_paths(master_spec)
    saver_lib.export_master_spec(master_spec, graph)
    saver_lib.export_to_graph(master_spec, params_path, stripped_path, graph,
                              export_moving_averages, build_runtime_graph)
    saver_lib.export_assets(master_spec, short_to_original, stripped_path)
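# A minimal usage sketch for export() above; the paths below are hypothetical
# placeholders for a trained model's text-format master spec, checkpoint, and
# output directory.
#
#   export(master_spec_path='/tmp/model/master_spec.textproto',
#          params_path='/tmp/model/checkpoint',
#          export_path='/tmp/model/saved_model',
#          export_moving_averages=True,
#          build_runtime_graph=False)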
    def load_model(self, base_dir, master_spec_name, checkpoint_name="checkpoint", rename=True):
        try:
            master_spec = spec_pb2.MasterSpec()
            with open(os.path.join(base_dir, master_spec_name)) as f:
                text_format.Merge(f.read(), master_spec)
            spec_builder.complete_master_spec(master_spec, None, base_dir)

            graph = tf.Graph()
            with graph.as_default():
                hyperparam_config = spec_pb2.GridPoint()
                builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
                annotator = builder.add_annotation(enable_tracing=True)
                builder.add_saver()

            sess = tf.Session(graph=graph)
            with graph.as_default():
                builder.saver.restore(sess, os.path.join(base_dir, checkpoint_name))

            def annotate_sentence(sentence):
                with graph.as_default():
                    return sess.run([annotator['annotations'], annotator['traces']],
                                    feed_dict={annotator['input_batch']: [sentence]})
        except Exception:
            if rename:
                self.rename_vars(base_dir, checkpoint_name)
                return self.load_model(base_dir, master_spec_name, checkpoint_name, False)
            raise Exception(
                'Cannot load model: spec expects references to */kernel tensors instead of '
                '*/weights. Try running with rename=True or run rename_vars() to convert '
                'existing checkpoint files into the supported format.')

        return annotate_sentence
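    # Usage sketch for load_model() above. Here `loader` and
    # `serialized_sentence` are hypothetical placeholders: `loader` is an
    # instance of the class defining load_model()/rename_vars(), and
    # `serialized_sentence` is a serialized sentence_pb2.Sentence proto.
    #
    #   annotate = loader.load_model('/tmp/model', 'master_spec.textproto')
    #   annotations, traces = annotate(serialized_sentence)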
    def _load_model(self, base_dir, master_spec_name):
        master_spec = spec_pb2.MasterSpec()
        with open(os.path.join(base_dir, master_spec_name)) as f:
            text_format.Merge(f.read(), master_spec)
        spec_builder.complete_master_spec(master_spec, None, base_dir)

        graph = tf.Graph()
        with graph.as_default():
            hyperparam_config = spec_pb2.GridPoint()
            builder = graph_builder.MasterBuilder(
                master_spec,
                hyperparam_config
            )
            annotator = builder.add_annotation(enable_tracing=True)
            builder.add_saver()

        sess = tf.Session(graph=graph)
        with graph.as_default():
            builder.saver.restore(sess, os.path.join(base_dir, "checkpoint"))

        def annotate_sentence(sentence):
            with graph.as_default():
                return sess.run(
                    [annotator['annotations'], annotator['traces']],
                    feed_dict={annotator['input_batch']: [sentence]}
                )
        return annotate_sentence
Example #4
def _make_basic_master_spec():
    """Constructs a simple spec.

  Modified version of nlp/saft/opensource/dragnn/tools/parser_trainer.py

  Returns:
    spec_pb2.MasterSpec instance.
  """
    # Construct the "lookahead" ComponentSpec. This is a simple left-to-right RNN
    # sequence model, which encodes the context to the left of each token. It has
    # no loss except for the downstream components.
    lookahead = spec_builder.ComponentSpecBuilder('lookahead')
    lookahead.set_network_unit(name='FeedForwardNetwork',
                               hidden_layer_sizes='256')
    lookahead.set_transition_system(name='shift-only', left_to_right='true')
    lookahead.add_fixed_feature(name='words',
                                fml='input.word',
                                embedding_dim=64)
    lookahead.add_rnn_link(embedding_dim=-1)

    # Construct the ComponentSpec for parsing.
    parser = spec_builder.ComponentSpecBuilder('parser')
    parser.set_network_unit(name='FeedForwardNetwork',
                            hidden_layer_sizes='256')
    parser.set_transition_system(name='arc-standard')
    parser.add_token_link(source=lookahead,
                          fml='input.focus',
                          embedding_dim=32)

    master_spec = spec_pb2.MasterSpec()
    master_spec.component.extend([lookahead.spec, parser.spec])
    return master_spec
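# A minimal sketch of wiring the spec from _make_basic_master_spec() into a
# graph, following the MasterBuilder pattern used in the main() examples below.
# Note that the spec has not been filled from lexical resources here, so a real
# build would first need spec_builder.complete_master_spec() or the builders'
# fill_from_resources() with a lexicon directory.
basic_spec = _make_basic_master_spec()
graph = tf.Graph()
with graph.as_default():
    builder = graph_builder.MasterBuilder(basic_spec, spec_pb2.GridPoint())
    annotator = builder.add_annotation()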
Example #5
    def LoadSpec(self, spec_path):
        master_spec = spec_pb2.MasterSpec()
        testdata = os.path.join(test_flags.source_root(),
                                'dragnn/core/testdata')
        with open(os.path.join(testdata, spec_path), 'r') as fin:
            text_format.Parse(fin.read().replace('TESTDATA', testdata), master_spec)
        return master_spec
def main(argv):
    del argv  # unused
    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    lexicon.build_lexicon(lexicon_dir,
                          training_sentence,
                          training_corpus_format='sentence-prototext')

    # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
    # sequence tagger.
    tagger = spec_builder.ComponentSpecBuilder('tagger')
    tagger.set_network_unit(name='FeedForwardNetwork',
                            hidden_layer_sizes='256')
    tagger.set_transition_system(name='tagger')
    tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
    tagger.add_rnn_link(embedding_dim=-1)
    tagger.fill_from_resources(lexicon_dir)

    master_spec = spec_pb2.MasterSpec()
    master_spec.component.extend([tagger.spec])

    hyperparam_config = spec_pb2.GridPoint()

    # Build the TensorFlow graph.
    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

        target = spec_pb2.TrainTarget()
        target.name = 'all'
        target.unroll_using_oracle.extend([True])
        dry_run = builder.add_training_from_config(target, trace_only=True)

    # Read in serialized protos from training data.
    sentence = sentence_pb2.Sentence()
    text_format.Merge(open(training_sentence).read(), sentence)
    training_set = [sentence.SerializeToString()]

    with tf.Session(graph=graph) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.initialize_all_variables())
        traces = sess.run(dry_run['traces'],
                          feed_dict={dry_run['input_batch']: training_set})

    with open('dragnn_tutorial_1.html', 'w') as f:
        f.write(
            visualization.trace_html(traces[0],
                                     height='300px').encode('utf-8'))
    def __init__(self):
        self.spec = spec_pb2.MasterSpec()
        self.hyperparams = spec_pb2.GridPoint()
        self.lookup_component = {
            'previous': MockComponent(self, spec_pb2.ComponentSpec())
        }
Example #8
def main(unused_argv):

    # Parse the flags containing lists, using regular expressions.
    # This matches and extracts key=value pairs.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    # This matches strings separated by a comma. Does not return any empty
    # strings.
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)
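    # For example (hypothetical flag values),
    # --inference_beam_size='tagger=8,parser=64' yields
    # component_beam_sizes == [('tagger', '8'), ('parser', '64')], and
    # --locally_normalize='tagger,parser' yields
    # components_to_locally_normalize == ['tagger', 'parser'].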

    ## SEGMENTATION ##

    if not FLAGS.use_gold_segmentation:

        # Reads master spec.
        master_spec = spec_pb2.MasterSpec()
        with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
            text_format.Parse(fin.read(), master_spec)

        if FLAGS.complete_master_spec:
            spec_builder.complete_master_spec(master_spec, None,
                                              FLAGS.segmenter_resource_dir)

        # Graph building.
        tf.logging.info('Building the graph')
        g = tf.Graph()
        with g.as_default(), tf.device('/device:CPU:0'):
            hyperparam_config = spec_pb2.GridPoint()
            hyperparam_config.use_moving_average = True
            builder = graph_builder.MasterBuilder(master_spec,
                                                  hyperparam_config)
            annotator = builder.add_annotation()
            builder.add_saver()

        tf.logging.info('Reading documents...')
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()
        with tf.Session(graph=tf.Graph()) as tmp_session:
            char_input = gen_parser_ops.char_token_generator(input_corpus)
            char_corpus = tmp_session.run(char_input)
        check.Eq(len(input_corpus), len(char_corpus))

        session_config = tf.ConfigProto(
            log_device_placement=False,
            intra_op_parallelism_threads=FLAGS.threads,
            inter_op_parallelism_threads=FLAGS.threads)

        with tf.Session(graph=g, config=session_config) as sess:
            tf.logging.info('Initializing variables...')
            sess.run(tf.global_variables_initializer())
            tf.logging.info('Loading from checkpoint...')
            sess.run('save/restore_all',
                     {'save/Const:0': FLAGS.segmenter_checkpoint_file})

            tf.logging.info('Processing sentences...')

            processed = []
            start_time = time.time()
            run_metadata = tf.RunMetadata()
            for start in range(0, len(char_corpus), FLAGS.max_batch_size):
                end = min(start + FLAGS.max_batch_size, len(char_corpus))
                feed_dict = {annotator['input_batch']: char_corpus[start:end]}
                if FLAGS.timeline_output_file and end == len(char_corpus):
                    serialized_annotations = sess.run(
                        annotator['annotations'],
                        feed_dict=feed_dict,
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                    trace = timeline.Timeline(
                        step_stats=run_metadata.step_stats)
                    with open(FLAGS.timeline_output_file, 'w') as trace_file:
                        trace_file.write(trace.generate_chrome_trace_format())
                else:
                    serialized_annotations = sess.run(annotator['annotations'],
                                                      feed_dict=feed_dict)
                processed.extend(serialized_annotations)

            tf.logging.info('Processed %d documents in %.2f seconds.',
                            len(char_corpus),
                            time.time() - start_time)

        input_corpus = processed
    else:
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()

    ## PARSING ##

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
        text_format.Parse(fin.read(), master_spec)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.parser_resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')

    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=FLAGS.threads,
                                    inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all',
                 {'save/Const:0': FLAGS.parser_checkpoint_file})

        tf.logging.info('Processing sentences...')

        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        for start in range(0, len(input_corpus), FLAGS.max_batch_size):
            end = min(start + FLAGS.max_batch_size, len(input_corpus))
            feed_dict = {annotator['input_batch']: input_corpus[start:end]}
            for comp, beam_size in component_beam_sizes:
                feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
            for comp in components_to_locally_normalize:
                feed_dict['%s/LocallyNormalize:0' % comp] = True
            if FLAGS.timeline_output_file and end == len(input_corpus):
                serialized_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(annotator['annotations'],
                                                  feed_dict=feed_dict)
            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_corpus),
                        time.time() - start_time)

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as f:
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    f.write('#' + sentence.text.encode('utf-8') + '\n')
                    for i, token in enumerate(sentence.token):
                        head = token.head + 1
                        f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
                                (i + 1, token.word.encode('utf-8'), head,
                                 token.label.encode('utf-8')))
                    f.write('\n\n')
    def __init__(self):
        self.spec = spec_pb2.MasterSpec()
        self.hyperparams = spec_pb2.GridPoint()
        self.lookup_component = {'mock': MockComponent()}
        self.build_runtime_graph = False
Example #10
def main(argv):
  del argv  # unused
  # Constructs lexical resources for SyntaxNet in the given resource path, from
  # the training data.
  lexicon.build_lexicon(
      lexicon_dir,
      training_sentence,
      training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
  # sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  # Construct the ComponentSpec for parsing.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(
      name='FeedForwardNetwork',
      hidden_layer_sizes='256',
      layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32,
      source_layer='logits')

  # Recurrent connection for the arc-standard parser. For both tokens on the
  # stack, we connect to the last time step to either SHIFT or REDUCE that
  # token. This allows the parser to build up compositional representations of
  # phrases.
  parser.add_link(
      source=parser,  # recurrent connection
      name='rnn-stack',  # unique identifier
      fml='stack.focus stack(1).focus',  # look for both stack tokens
      source_translator='shift-reduce-step',  # maps token indices -> step
      embedding_dim=32)  # project down to 32 dims
  parser.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec, parser.spec])

  hyperparam_config = spec_pb2.GridPoint()

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True, True])
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data.
  sentence = sentence_pb2.Sentence()
  text_format.Merge(open(training_sentence).read(), sentence)
  training_set = [sentence.SerializeToString()]

  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.initialize_all_variables())
    traces = sess.run(
        dry_run['traces'], feed_dict={dry_run['input_batch']: training_set})

  with open('dragnn_tutorial_2.html', 'w') as f:
    f.write(
        visualization.trace_html(
            traces[0], height='400px', master_spec=master_spec).encode('utf-8'))